I'm Lim

[논문 구현] LeNet 본문

Classification/Implementation

[논문 구현] LeNet

imlim 2023. 3. 27. 08:41

1. LeNet 구현

 

위 그림을 참조하여 아래와 같이 코드 구현을 진행하였습니다.

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1, padding=0)
       
        self.fc1 = nn.Linear(in_features=120, out_features=84)
        self.fc2 = nn.Linear(in_features=84, out_features=10)

        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        x = self.avg_pool(F.tanh(self.conv1(x)))
        x = self.avg_pool(F.tanh(self.conv2(x)))
        x = F.tanh(self.conv3(x))

        x = x.view(x.size(0), -1)

        x = F.tanh(self.fc1(x))
        x = self.fc2(x)

        return x

※ Loss function으로 CrossEntropyLoss를 사용하였습니다.

 Optimizer로 SGD를 사용하였습니다.

2. LeNet 학습 결과

학습을 위해 MNIST 데이터 셋을 사용하였고, Batch size : 128, epoch : 100, Learning Rate : 1e-3, Weight decay : 1e-5으로 설정하였습니다.

 

그다지 어렵지 않은 MNIST 데이터 셋이라 그런지 Overfitting은 발생하지 않았습니다. 

Comments