  • [Pytorch] ResNet (residual learning model)
    Computer Vision/pytorch 2023. 1. 13. 23:09

     

    When a CNN gets deeper, vanishing gradients occur easily. ResNet counters this by adding the block's input back onto its output (a skip connection), so the gradient can flow through the identity path and stay alive.
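
    The core idea is just "out = F(x) + x". A minimal sketch of why the added identity path helps (the names f and x here are illustrative, not from the post):

    import torch
    import torch.nn as nn

    f = nn.Linear(8, 8)                       # stand-in for a block's conv stack
    x = torch.randn(4, 8, requires_grad=True)
    out = f(x) + x                            # skip connection: identity path added
    out.sum().backward()
    print(x.grad.abs().mean())                # the "+ x" term guarantees a direct gradient path back to x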

     

    Block part

    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class BasicBlock(nn.Module):
      def __init__(self, input_dim, output_dim, stride=1):
        super(BasicBlock, self).__init__()
        # the first conv applies the stride; the second keeps the spatial size
        self.conv1 = nn.Conv2d(input_dim, output_dim, 3, stride, 1)
        self.conv2 = nn.Conv2d(output_dim, output_dim, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(output_dim)
        self.bn2 = nn.BatchNorm2d(output_dim)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        # identity shortcut by default; project with a 1x1 conv when the
        # channel count or spatial size changes, so the addition lines up
        self.shortcut = nn.Sequential()
        if input_dim != output_dim or stride != 1:
          self.shortcut = nn.Sequential(
              nn.Conv2d(input_dim, output_dim, 1, stride),
              nn.BatchNorm2d(output_dim)
          )

      def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)    # second conv feeds the second batch norm
        out = self.bn2(out)
        out += self.shortcut(x)  # reuse the input: the residual connection
        out = self.relu2(out)
        return out
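
    A quick shape check, continuing from the class above (the tensor sizes are illustrative):

    block = BasicBlock(16, 32, stride=1)
    x = torch.randn(1, 16, 32, 32)
    print(block(x).shape)   # torch.Size([1, 32, 32, 32]); the shortcut projects 16 -> 32 channels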

    Model part

    class ResNet(nn.Module):
        def __init__(self, num_classes=2):
            super(ResNet, self).__init__()
            # stem: 3-channel RGB input -> 16 feature maps
            self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(16)
            # three stages of two blocks each; the first block of a stage
            # widens the channels (16 -> 32 -> 64)
            self.layer1 = BasicBlock(16, 16, 1)
            self.layer2 = BasicBlock(16, 16, 1)
            self.layer3 = BasicBlock(16, 32, 1)
            self.layer4 = BasicBlock(32, 32, 1)
            self.layer5 = BasicBlock(32, 64, 1)
            self.layer6 = BasicBlock(64, 64, 1)
            # 64*4*4 assumes a 32x32 input: 32 / 8 = 4 after the pooling below
            self.linear = nn.Linear(64 * 4 * 4, num_classes)

        def forward(self, x):
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.layer1(out)
            out = self.layer2(out)
            out = self.layer3(out)
            out = self.layer4(out)
            out = self.layer5(out)
            out = self.layer6(out)
            out = F.avg_pool2d(out, 8)        # 32x32 -> 4x4
            out = out.view(out.size(0), -1)   # flatten to (batch, 64*4*4)
            out = self.linear(out)
            return out
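
    A minimal smoke test, assuming CIFAR-sized 32x32 RGB inputs (which is what the 64*4*4 linear layer implies):

    model = ResNet(num_classes=2)
    x = torch.randn(4, 3, 32, 32)   # batch of 4 RGB 32x32 images
    print(model(x).shape)           # torch.Size([4, 2])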

    I read the AlexNet paper and implemented it, but the deeper I stacked the layers, the worse training got; after digging through other papers I found ResNet.
