Day 12 수정
안녕하세요! 오늘 한 내용 중에 마지막 best_checkpoint.pth 로드해서 테스트 하는 부분 코드 다음과 같이 수정 부탁드립니다 😄
이미지와 라벨을 gpu로 옮겨주세요. 모델은 GPU에 있기 때문에 사진이랑 라벨로 올려주어야 합니다. imgs, labels = imgs.to(device), labels.to(device)
imshow 그래프 함수를 다음과 같이 넣어주세요! matplotlib이나 numpy는 GPU 텐서 처리가 불가해서 CPU로 옮긴 다음에 .numpy()를 해줘야합니다.
def imshow(input):
input = input.cpu().numpy().transpose((1, 2, 0))
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])
input = std * input + mean
input = np.clip(input, 0, 1)
plt.imshow(input)
plt.axis('off')
plt.show()
감사합니다! 좋은 주말 되세요~!
13회 조회