이미지 출력을 위한 전처리

def normalize_image(image):
    image_min = image.min()
    image_max = image.max()
    image.clamp_(min=image_min, max=image_max) ------ torch.clamp는 주어진 최소(min), 최대(max)의 범주에 이미지가 위치하도록 합니다.
    image.add_(-image_min).div_(image_max-image_min+1e-5) ------ ①
    return image

정확하게 예측한 이미지 출력 함수

def plot_most_correct(correct, classes, n_images, normalize=True):
    rows = int(np.sqrt(n_images)) ------ np.sqrt는 제곱근을 계산(0.5를 거듭제곱)
    cols = int(np.sqrt(n_images))
    fig = plt.figure(figsize=(25,20))
    for i in range(rows*cols):
        ax = fig.add_subplot(rows, cols, i+1) ------ 출력하려는 그래프 개수만큼 subplot을 만듭니다.
        image, true_label, probs = correct[i]
        image = image.permute(1, 2, 0) ------ ①
        true_prob = probs[true_label]
        correct_prob, correct_label = torch.max(probs, dim=0)
        true_class = classes[true_label]
        correct_class = classes[correct_label]

        if normalize: ------ 본래 이미지대로 출력하기 위해 normalize_image 함수 호출
            image = normalize_image(image)

        ax.imshow(image.cpu().numpy())
        ax.set_title(f'true label: {true_class} ({true_prob:.3f})\\n' \\
                     f'pred label: {correct_class} ({correct_prob:.3f})')
        ax.axis('off')

    fig.subplots_adjust(hspace=0.4)

예측

classes = test_dataset.classes
N_IMAGES = 5
plot_most_correct(correct_examples, classes, N_IMAGES)