import torch
import torch.nn as nn
import torch.nn.functional as F
# Epsilon passed as ``eps`` to the BatchNorm1d layers defined in this file.
# NOTE(review): this is far smaller than PyTorch's default of 1e-5 — confirm
# that such a small value is intentional.
BN_EPS = 1e-10
class ConvBnRelu1d(nn.Module):
    """Bias-free Conv1d followed by BatchNorm1d, ReLU and optional dropout.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        dropout: Dropout (p=0.2) is appended only when this is exactly
            ``True``; any other value disables it.
        kernel_size: Convolution kernel size.
        padding: Zero padding applied on both sides by the convolution.
        stride: Convolution stride.
    """

    def __init__(self, in_channels, out_channels, dropout=False, kernel_size=3, padding=1, stride=1):
        super().__init__()
        # The convolution has no bias; BatchNorm1d follows immediately.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=False,
        )
        self.batchnorm = nn.BatchNorm1d(out_channels, eps=BN_EPS)
        self.relu = nn.ReLU(inplace=True)
        if dropout is True:
            self.dropout = nn.Dropout(0.2)
        else:
            self.dropout = None

    def forward(self, x):
        """Apply conv -> batchnorm -> relu (-> dropout) and return the result."""
        out = self.relu(self.batchnorm(self.conv(x)))
        if self.dropout is None:
            return out
        return self.dropout(out)

class BnRelu1d(nn.Module):
    """BatchNorm1d followed by ReLU and optional dropout.

    Args:
        out_channels: Number of channels of the tensor being normalized.
        dropout: Dropout (p=0.5) is appended only when this is exactly
            ``True``; any other value disables it.
    """

    def __init__(self, out_channels, dropout=False):
        super().__init__()
        self.batchnorm = nn.BatchNorm1d(out_channels, eps=BN_EPS)
        self.relu = nn.ReLU(inplace=True)
        if dropout is True:
            self.dropout = nn.Dropout(0.5)
        else:
            self.dropout = None

    def forward(self, x):
        """Apply batchnorm -> relu (-> dropout) and return the result."""
        out = self.relu(self.batchnorm(x))
        if self.dropout is None:
            return out
        return self.dropout(out)

class Maxpool1d(nn.Module):
    """Skip-path transform: max-pool by 2, then a 1x1 conv to remap channels.

    The pooling always halves the length (rounding up). When ``stride`` is 1
    the pooled tensor is zero-padded on the right back to the input length,
    so it can be added to a branch that did not downsample.

    Args:
        in_channels: Number of channels of the incoming skip tensor.
        out_channels: Number of channels produced by the 1x1 convolution.
        stride: Stride of the branch this skip path is paired with. Only the
            value 1 changes behavior (it enables the length-restoring pad).
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2, ceil_mode=True)
        self.conv1x1 = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        # Kept under its original attribute name; it stores the stride and is
        # only compared against 1 in forward().
        self.pad = stride

    def forward(self, x):
        """Return the pooled (and, for stride 1, re-padded) channel-remapped tensor."""
        length = x.size(-1)
        x = self.maxpool(x)
        if self.pad == 1:
            # Pad by the exact shortfall. Padding by ceil(L / 2) as before
            # produced L + 1 samples for odd L, which cannot be added to a
            # length-L branch.
            x = F.pad(x, (0, length - x.size(-1)))
        return self.conv1x1(x)

class ResidualUnit(nn.Module):
    """Residual block combining a convolutional branch with a pooled skip path.

    The main branch is ConvBnRelu1d (with dropout) followed by a Conv1d with
    kernel size 11 and padding 5. The skip branch is a Maxpool1d. The two
    branches are summed, and the sum is also passed through BnRelu1d (with
    dropout).

    Args:
        in_channels: Channels of both inputs to ``forward``.
        out_channels: Channels of both outputs of ``forward``.
        kernel_size: Kernel size of the first convolution in the main branch.
        padding: Padding of the first convolution in the main branch.
        stride: Stride of the first convolution; also forwarded to the skip
            path.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
        super().__init__()
        self.maxpool1d = Maxpool1d(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
        )
        self.ConvBnRelu = ConvBnRelu1d(
            in_channels,
            out_channels,
            dropout=True,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        self.conv1d = nn.Conv1d(out_channels, out_channels, kernel_size=11, padding=5)
        self.BnRelu = BnRelu1d(out_channels, dropout=True)

    def forward(self, x, x_skip):
        """Return ``(activated, pre_activation)``.

        ``pre_activation`` is the sum of the main branch and the skip branch;
        ``activated`` is that sum after BnRelu1d.
        """
        shortcut = self.maxpool1d(x_skip)
        main = self.conv1d(self.ConvBnRelu(x))
        merged = main + shortcut
        return self.BnRelu(merged), merged

class ECGmodel(nn.Module):
    """1-D residual network: a stem convolution, four residual units, one linear head.

    With the default arguments the layers are identical to the original
    hard-coded model: 12 input channels, 21 outputs, and a linear layer sized
    for inputs of length 1024 (64 channels * 32 samples after five stride-2
    stages).

    Args:
        in_channels: Channels of the input signal (presumably ECG leads —
            TODO confirm).
        num_classes: Number of outputs of the final linear layer.
        input_length: Length of the input along the last dimension. It fixes
            the size of the linear layer, so inputs of a different
            downsampled length will fail in ``forward``.
    """

    # The stem plus four residual units each use stride 2.
    _NUM_DOWNSAMPLING_STAGES = 5

    def __init__(self, in_channels=12, num_classes=21, input_length=1024):
        super().__init__()
        self.conv = ConvBnRelu1d(in_channels, 512, stride=2)
        self.res1 = ResidualUnit(512, 256, stride=2)
        self.res2 = ResidualUnit(256, 128, stride=2)
        self.res3 = ResidualUnit(128, 128, stride=2)
        self.res4 = ResidualUnit(128, 64, stride=2)

        # Each stride-2 stage (kernel 3, padding 1) maps length n to ceil(n / 2).
        final_length = input_length
        for _ in range(self._NUM_DOWNSAMPLING_STAGES):
            final_length = (final_length + 1) // 2
        self.dense = nn.Linear(64 * final_length, num_classes)

    def forward(self, x):
        """Map a ``(batch, in_channels, input_length)`` tensor to ``(batch, num_classes)``."""
        x = self.conv(x)
        # The first residual unit uses the stem output for both of its inputs.
        x, x_skip = self.res1(x, x)
        x, x_skip = self.res2(x, x_skip)
        x, x_skip = self.res3(x, x_skip)
        x, _ = self.res4(x, x_skip)
        # Flatten channels and length into one feature vector per sample.
        x = x.flatten(start_dim=1)
        return self.dense(x)

# Check the model summary

from torchsummary import summary

# torchsummary creates its dummy input on ``device`` (CUDA by default when
# available), so the model has to live on that same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ECGmodel().to(device)
# The default ECGmodel's linear layer expects 64 * 32 features, i.e. an input
# of 1024 samples after five stride-2 stages. A length of 10000 yields
# 64 * 313 features and fails in that layer.
summary(model, (12, 1024), device=device)
import torch
import torch.nn as nn
import torch.nn.functional as F
# Epsilon passed as ``eps`` to the BatchNorm1d layers defined in this file.
# NOTE(review): this is far smaller than PyTorch's default of 1e-5 — confirm
# that such a small value is intentional.
BN_EPS = 1e-10
class ConvBnRelu1d(nn.Module):
    """Bias-free Conv1d followed by BatchNorm1d, ReLU and optional dropout.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        dropout: Dropout (p=0.5) is appended only when this is exactly
            ``True``; any other value disables it.
        kernel_size: Convolution kernel size.
        padding: Zero padding applied on both sides by the convolution.
        stride: Convolution stride.
    """

    def __init__(self, in_channels, out_channels, dropout=False, kernel_size=3, padding=1, stride=1):
        super().__init__()
        # The convolution has no bias; BatchNorm1d follows immediately.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=False,
        )
        self.batchnorm = nn.BatchNorm1d(out_channels, eps=BN_EPS)
        self.relu = nn.ReLU(inplace=True)
        if dropout is True:
            self.dropout = nn.Dropout(0.5)
        else:
            self.dropout = None

    def forward(self, x):
        """Apply conv -> batchnorm -> relu (-> dropout) and return the result."""
        out = self.relu(self.batchnorm(self.conv(x)))
        if self.dropout is None:
            return out
        return self.dropout(out)

class BnRelu1d(nn.Module):
    """BatchNorm1d followed by ReLU and optional dropout.

    Args:
        out_channels: Number of channels of the tensor being normalized.
        dropout: Dropout (p=0.5) is appended only when this is exactly
            ``True``; any other value disables it.
    """

    def __init__(self, out_channels, dropout=False):
        super().__init__()
        self.batchnorm = nn.BatchNorm1d(out_channels, eps=BN_EPS)
        self.relu = nn.ReLU(inplace=True)
        if dropout is True:
            self.dropout = nn.Dropout(0.5)
        else:
            self.dropout = None

    def forward(self, x):
        """Apply batchnorm -> relu (-> dropout) and return the result."""
        out = self.relu(self.batchnorm(x))
        if self.dropout is None:
            return out
        return self.dropout(out)

class Maxpool1d(nn.Module):
    """Skip-path transform: max-pool by 2, then a 1x1 conv to remap channels.

    The pooling always halves the length (rounding up). When ``stride`` is 1
    the pooled tensor is zero-padded on the right back to the input length,
    so it can be added to a branch that did not downsample.

    Args:
        in_channels: Number of channels of the incoming skip tensor.
        out_channels: Number of channels produced by the 1x1 convolution.
        stride: Stride of the branch this skip path is paired with. Only the
            value 1 changes behavior (it enables the length-restoring pad).
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2, ceil_mode=True)
        self.conv1x1 = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        # Kept under its original attribute name; it stores the stride and is
        # only compared against 1 in forward().
        self.pad = stride

    def forward(self, x):
        """Return the pooled (and, for stride 1, re-padded) channel-remapped tensor."""
        length = x.size(-1)
        x = self.maxpool(x)
        if self.pad == 1:
            # Pad by the exact shortfall. Padding by ceil(L / 2) as before
            # produced L + 1 samples for odd L, which cannot be added to a
            # length-L branch.
            x = F.pad(x, (0, length - x.size(-1)))
        return self.conv1x1(x)

class ResidualUnit(nn.Module):
    """Residual block combining a convolutional branch with a pooled skip path.

    The main branch is ConvBnRelu1d (with dropout) followed by a Conv1d with
    kernel size 11 and padding 5. The skip branch is a Maxpool1d. The two
    branches are summed, and the sum is also passed through BnRelu1d (with
    dropout).

    Args:
        in_channels: Channels of both inputs to ``forward``.
        out_channels: Channels of both outputs of ``forward``.
        kernel_size: Kernel size of the first convolution in the main branch.
        padding: Padding of the first convolution in the main branch.
        stride: Stride of the first convolution; also forwarded to the skip
            path.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
        super().__init__()
        self.maxpool1d = Maxpool1d(in_channels=in_channels, out_channels=out_channels, stride=stride)
        self.ConvBnRelu = ConvBnRelu1d(
            in_channels,
            out_channels,
            dropout=True,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        self.conv1d = nn.Conv1d(out_channels, out_channels, kernel_size=11, padding=5)
        self.BnRelu = BnRelu1d(out_channels, dropout=True)

    def forward(self, x, x_skip):
        """Return ``(activated, pre_activation)``.

        ``pre_activation`` is the sum of the main branch and the skip branch;
        ``activated`` is that sum after BnRelu1d.
        """
        x_skip = self.maxpool1d(x_skip)
        x = self.ConvBnRelu(x)
        x = self.conv1d(x)
        # The debugging print() calls that dumped both tensor sizes on every
        # forward pass were removed; a shape mismatch still raises at the
        # addition below.
        x_skip = x + x_skip
        x = self.BnRelu(x_skip)
        return x, x_skip

class ECGmodel(nn.Module):
    """1-D residual network: a stem convolution, four residual units, one linear head.

    With the default arguments the layers are identical to the original
    hard-coded model: 12 input channels, 21 outputs, and a linear layer sized
    for inputs of length 1024 (64 channels * 32 samples after five stride-2
    stages).

    Args:
        in_channels: Channels of the input signal (presumably ECG leads —
            TODO confirm).
        num_classes: Number of outputs of the final linear layer.
        input_length: Length of the input along the last dimension. It fixes
            the size of the linear layer, so inputs of a different
            downsampled length will fail in ``forward``.
    """

    # The stem plus four residual units each use stride 2.
    _NUM_DOWNSAMPLING_STAGES = 5

    def __init__(self, in_channels=12, num_classes=21, input_length=1024):
        super().__init__()
        self.conv = ConvBnRelu1d(in_channels, 512, stride=2)
        self.res1 = ResidualUnit(512, 256, stride=2)
        self.res2 = ResidualUnit(256, 128, stride=2)
        self.res3 = ResidualUnit(128, 128, stride=2)
        self.res4 = ResidualUnit(128, 64, stride=2)

        # Each stride-2 stage (kernel 3, padding 1) maps length n to ceil(n / 2).
        final_length = input_length
        for _ in range(self._NUM_DOWNSAMPLING_STAGES):
            final_length = (final_length + 1) // 2
        self.dense = nn.Linear(64 * final_length, num_classes)

    def forward(self, x):
        """Map a ``(batch, in_channels, input_length)`` tensor to ``(batch, num_classes)``."""
        x = self.conv(x)
        # The first residual unit uses the stem output for both of its inputs.
        x, x_skip = self.res1(x, x)
        x, x_skip = self.res2(x, x_skip)
        x, x_skip = self.res3(x, x_skip)
        x, _ = self.res4(x, x_skip)
        # Flatten channels and length into one feature vector per sample.
        x = x.flatten(start_dim=1)
        return self.dense(x)