# 1. 데이터 준비 (Data preparation)

# 1-1. 필요한 라이브러리 불러오기 (Import the required libraries)

import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms

# 1-2. 데이터셋 로드 및 분할 (Load and split the dataset)

# Load MNIST (downloading it to ../data if missing). Only the raw `.data` /
# `.targets` tensors are used below, so the ToTensor transform is not
# actually applied by this script.
train = datasets.MNIST(
    '../data', train=True, download=True,
    transform=transforms.Compose([transforms.ToTensor(), ]),
)
test = datasets.MNIST(
    '../data', train=False, download=True,
    transform=transforms.Compose([transforms.ToTensor(), ]),
)

# Convert to float and divide by 255 to scale pixel values into [0, 1].
x = train.data.float() / 255
y = train.targets

# Flatten every dimension after the sample dimension into one feature vector.
x = x.view(x.size(0), -1)
print(x.shape, y.shape)

input_size = x.size(-1)
# Labels are the digits 0..9, so the class count is max label + 1 (= 10).
# Tensor.max() avoids iterating over every element in Python like max(y) did.
output_size = int(y.max()) + 1

print('input_size : %d, output_size : %d' % (input_size, output_size))
ratios = [.8, .2]

train_cnt = int(x.size(0) * ratios[0])
# Validation takes the remainder so the two counts always sum to the dataset
# size: Tensor.split() with a list of sizes raises otherwise, and truncating
# both ratios independently with int() can drop samples.
valid_cnt = x.size(0) - train_cnt
test_cnt = len(test.data)
cnts = [train_cnt, valid_cnt]

print("Train %d / Valid %d / Test %d samples." % (train_cnt, valid_cnt, test_cnt))

# Shuffle once so the train/valid split is a random partition.
indices = torch.randperm(x.size(0))

x = torch.index_select(x, dim=0, index=indices)
y = torch.index_select(y, dim=0, index=indices)

# x and y become lists: [train, valid], then the test split is appended.
x = list(x.split(cnts, dim=0))
y = list(y.split(cnts, dim=0))

x += [(test.data.float() / 255.).view(test_cnt, -1)]
y += [test.targets]

for x_i, y_i in zip(x, y):
    print(x_i.size(), y_i.size())

# 2. 학습 코드 구현 (Implement the training code)

# 2-1. 서브모듈 선언 (Declare the submodules)

import torch.nn as nn

class Block(nn.Module):
    """One fully connected unit: Linear -> LeakyReLU -> regularizer.

    The regularizer is ``nn.BatchNorm1d(output_size)`` when
    ``use_batch_norm`` is true, otherwise ``nn.Dropout(dropout_p)``.

    Args:
        input_size: number of input features for the linear layer.
        output_size: number of output features for the linear layer.
        use_batch_norm: pick batch normalization (True) or dropout (False).
        dropout_p: dropout probability; only used when ``use_batch_norm``
            is False.
    """

    def __init__(self, input_size, output_size, use_batch_norm=True, dropout_p=.4):
        super().__init__()

        # Remember the configuration on the instance.
        self.input_size, self.output_size = input_size, output_size
        self.use_batch_norm = use_batch_norm
        self.dropout_p = dropout_p

        # Create the linear layer first so its random weight initialization
        # happens before any other submodule is constructed.
        linear = nn.Linear(input_size, output_size)
        activation = nn.LeakyReLU()
        if use_batch_norm:
            regularizer = nn.BatchNorm1d(output_size)
        else:
            regularizer = nn.Dropout(dropout_p)

        # Positional order (0: linear, 1: activation, 2: regularizer)
        # determines the state_dict key names, e.g. "block.0.weight".
        self.block = nn.Sequential(*[linear, activation, regularizer])

    def forward(self, x):
        """Run ``x`` through the linear -> activation -> regularizer stack."""
        return self.block(x)