-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathtest_pt.py
41 lines (30 loc) · 1.35 KB
/
test_pt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# test script
# adapted from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from network_pt import Net
def format_classes(indices, classes):
    """Return the class names for *indices*, each right-aligned to width 5, space-separated.

    Args:
        indices: iterable of integer class indices (ints or 0-d tensors,
            e.g. a 1-D label/prediction tensor).
        classes: sequence mapping index -> class name.

    Returns:
        One string such as ``'  car plane'``; empty string for no indices.
    """
    # int() accepts both Python ints and 0-d tensors, so a tensor can be passed directly.
    return ' '.join('%5s' % classes[int(i)] for i in indices)


def images_to_strip(images):
    """Tile a normalized image batch left-to-right into one uint8 HWC array.

    Args:
        images: float tensor of shape (N, C, H, W) produced by
            ``Normalize((0.5, ...), (0.5, ...))``, i.e. nominally in [-1, 1].

    Returns:
        Contiguous numpy uint8 array of shape (H, N*W, C), suitable for
        ``PIL.Image.fromarray``.
    """
    # Undo Normalize(mean=0.5, std=0.5): y = (x - 0.5) / 0.5  =>  x = y / 2 + 0.5, back in [0, 1].
    unnormalized = images.detach().cpu() / 2 + 0.5
    # Concatenate the N (C, H, W) images along the width axis -> (C, H, N*W).
    strip = torch.cat(list(unnormalized), dim=2)
    # Round (not truncate) so float error cannot turn pixel k into k-1, and clamp
    # so out-of-range values saturate instead of wrapping when cast to uint8.
    pixels = (strip * 255).round().clamp(0, 255).to(torch.uint8)
    return pixels.permute(1, 2, 0).contiguous().numpy()


if __name__ == '__main__':
    ## cifar-10 dataset (test split), normalized to [-1, 1] per channel
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    batch_size = 4
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    ## load the trained model
    model = Net()
    # map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True if the installed torch version supports it).
    model.load_state_dict(torch.load('saved_model.pt', map_location='cpu'))
    model.eval()  # put any dropout/batch-norm layers Net may have into inference mode

    ## inference on the first test batch
    images, labels = next(iter(testloader))
    print('Ground-truth: ', format_classes(labels, classes))
    with torch.no_grad():  # no autograd graph needed for inference
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
    print('Predicted: ', format_classes(predicted, classes))

    # save the batch as one image strip
    im = Image.fromarray(images_to_strip(images))
    im.save("test_pt_images.jpg")
    print('test_pt_images.jpg saved.')