이미지 출력을 위한 전처리
def normalize_image(image):
image_min = image.min()
image_max = image.max()
image.clamp_(min=image_min, max=image_max) ------ torch.clamp는 주어진 최소(min), 최대(max)의 범주에 이미지가 위치하도록 합니다.
image.add_(-image_min).div_(image_max-image_min+1e-5) ------ ①
return image
정확하게 예측한 이미지 출력 함수
def plot_most_correct(correct, classes, n_images, normalize=True):
rows = int(np.sqrt(n_images)) ------ np.sqrt는 제곱근을 계산(0.5를 거듭제곱)
cols = int(np.sqrt(n_images))
fig = plt.figure(figsize=(25,20))
for i in range(rows*cols):
ax = fig.add_subplot(rows, cols, i+1) ------ 출력하려는 그래프 개수만큼 subplot을 만듭니다.
image, true_label, probs = correct[i]
image = image.permute(1, 2, 0) ------ ①
true_prob = probs[true_label]
correct_prob, correct_label = torch.max(probs, dim=0)
true_class = classes[true_label]
correct_class = classes[correct_label]
if normalize: ------ 본래 이미지대로 출력하기 위해 normalize_image 함수 호출
image = normalize_image(image)
ax.imshow(image.cpu().numpy())
ax.set_title(f'true label: {true_class} ({true_prob:.3f})\\n' \\
f'pred label: {correct_class} ({correct_prob:.3f})')
ax.axis('off')
fig.subplots_adjust(hspace=0.4)
예측
classes = test_dataset.classes
N_IMAGES = 5
plot_most_correct(correct_examples, classes, N_IMAGES)