-
Notifications
You must be signed in to change notification settings - Fork 1
/
tests.py
47 lines (36 loc) · 1.46 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import torch
from torch.utils.data import DataLoader
from torchsummary import summary
from datasets import *
from models import *
from parameters import args
def extract_embeddings(model, dataloader, device):
    """Run ``model`` on every anchor batch of ``dataloader`` and collect outputs.

    Args:
        model: Network called as ``model(anchor)``. It should already be on
            ``device`` and in eval mode.
        dataloader: Iterable yielding ``(anchor, positive, negative, label)``
            batches. Only ``anchor`` and ``label`` are used.
        device: Device each anchor batch is moved to before the forward pass.

    Returns:
        ``(embeddings, labels)`` as NumPy arrays. ``embeddings`` has one row
        per sample: each batch output is reshaped to its first two dimensions.
        ``labels`` is the concatenation of every batch's labels.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    embeddings = []
    labels = []
    with torch.no_grad():
        for anchor, _, _, label in dataloader:
            # The model lives on ``device``. Feeding it a tensor on another
            # device (e.g. a CPU batch to a CUDA model) raises a RuntimeError.
            output = model(anchor.to(device)).detach().cpu().numpy()
            # Keep only the first two dims. This reshape only succeeds when any
            # trailing dims are size 1 (e.g. (N, D, 1, 1) -> (N, D)).
            # NOTE(review): confirm the model's output shape.
            embeddings.append(output.reshape(output.shape[0], output.shape[1]))
            labels.append(label.detach().cpu().numpy())
    if not embeddings:
        raise ValueError("dataloader yielded no batches; nothing to embed")
    return np.concatenate(embeddings, 0), np.concatenate(labels, 0)


def main():
    """Embed the test split with saved TripletNet weights and plot a 2-D t-SNE.

    Builds the test dataloader and loads the checkpoint named by
    ``model_weights``. It then embeds every test anchor and projects the
    embeddings to 2-D with t-SNE. The result is drawn as a scatter plot
    coloured by label, saved to ``<model_weights>.png`` and then shown.
    """
    # make_datapath_dic / ImageTransform / TripletDataset / TripletNet come
    # from the project's star imports.
    test_dic = make_datapath_dic('test')
    transform = ImageTransform(64)
    test_dataset = TripletDataset(test_dic, transform=transform, phase='test')
    batch_size = 32
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TripletNet().to(device)
    model.eval()
    # Hard-coded checkpoint file; also used as the stem of the output plot name.
    model_weights = '0.14895377705494564.pth'
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_weights, map_location=device))
    summary(model, (3, 64, 64))

    predicted_metrics, test_labels = extract_embeddings(model, test_dataloader, device)

    tsne_metrics = TSNE(n_components=2, random_state=0).fit_transform(predicted_metrics)
    plt.scatter(tsne_metrics[:, 0], tsne_metrics[:, 1], c=test_labels)
    plt.colorbar()
    plt.savefig(model_weights + '.png')
    plt.show()


if __name__ == '__main__':
    main()