-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathIPN V2+_test.py
70 lines (68 loc) · 2.97 KB
/
IPN V2+_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import torch.nn as nn
import logging
import sys
import os
import model
import numpy as np
import scipy.misc as misc
from options.test_options import TestOptions
import natsort
from scipy import io
def _save_gray_bmp(path, image):
    """Write a 2-D uint8 array to *path* as an 8-bit grayscale BMP.

    Stdlib/numpy-only stand-in for ``scipy.misc.imsave``, which no longer
    exists in SciPy >= 1.2. Rows are written bottom-up (positive height in the
    info header) and each row is zero-padded to a multiple of 4 bytes, as the
    BMP format requires.

    Raises:
        ValueError: if *image* is not 2-D (from the shape unpacking).
    """
    import struct

    image = np.asarray(image, dtype=np.uint8)
    height, width = image.shape
    row_size = (width + 3) & ~3  # BMP rows are padded to a 4-byte boundary
    pixel_offset = 14 + 40 + 256 * 4  # file header + info header + palette
    image_size = row_size * height
    file_header = struct.pack('<2sIHHI', b'BM', pixel_offset + image_size, 0, 0, pixel_offset)
    info_header = struct.pack('<IiiHHIIiiII', 40, width, height, 1, 8, 0, image_size, 0, 0, 256, 0)
    # Identity grayscale palette: entry i is (B, G, R, reserved) = (i, i, i, 0).
    palette = bytes(c for level in range(256) for c in (level, level, level, 0))
    rows = np.zeros((height, row_size), dtype=np.uint8)
    rows[:, :width] = image[::-1]  # positive height means bottom row first
    with open(path, 'wb') as handle:
        handle.write(file_header + info_header + palette + rows.tobytes())


def test_net(net, device, options=None):
    """Run *net* over the validation and test samples and save class-1 maps.

    Sample names come from the naturally-sorted listing of
    ``<dataroot>/<modality_filename[0]>``: the slice ``val_ids`` followed by the
    slice ``test_ids``. For each name, ``<feature_dir>/<name>.npy`` is loaded
    into a single-sample batch and passed through the network. A softmax is
    applied over dim 1, and channel 1 of the first batch item, scaled by 255
    and cast to uint8, is written to ``<saveroot>/test_results_V2+/<name>.bmp``.
    Each processed name is printed.

    Args:
        net: Model whose forward returns a 2-tuple; the first element is the
            per-class score map that gets softmaxed, the second is ignored.
        device: torch device the input batch is moved to.
        options: Parsed options object. Defaults to the module-level ``opt``
            created by the ``__main__`` block.
    """
    if options is None:
        options = opt  # module-level global created in the __main__ block
    data_size = options.data_size
    test_results = os.path.join(options.saveroot, 'test_results_V2+')
    os.makedirs(test_results, exist_ok=True)  # saving fails if the folder is missing
    # scipy.misc.imsave is gone in SciPy >= 1.2 (or without PIL); fall back to
    # the stdlib BMP writer. imsave passes uint8 data through unscaled, so the
    # saved pixel values are the same either way.
    save_image = getattr(misc, 'imsave', None) or _save_gray_bmp
    net.eval()
    # Reused single-sample batch buffer of shape (1, channels, data_size[1], data_size[2]);
    # assumes data_size is (depth, height, width)-like -- TODO confirm.
    test_images = np.zeros((1, options.channels, data_size[1], data_size[2]))
    val_ids = options.val_ids
    test_ids = options.test_ids
    names = natsort.natsorted(
        os.listdir(os.path.join(options.dataroot, options.modality_filename[0])))
    samples = names[val_ids[0]:val_ids[1]] + names[test_ids[0]:test_ids[1]]
    with torch.no_grad():  # inference only: do not build autograd graphs
        for cube in samples:
            test_images[0, :, :, :] = np.load(os.path.join(options.feature_dir, cube + '.npy'))
            images = torch.from_numpy(test_images).to(device=device, dtype=torch.float32)
            pred, _featuremap = net(images)  # feature map is not used here
            pred = torch.nn.functional.softmax(pred, dim=1)
            result = pred[0, 1, :, :].cpu().numpy() * 255
            save_image(os.path.join(test_results, cube + '.bmp'), result.astype(np.uint8))
            print(cube)
if __name__ == '__main__':
    # Logging setup.
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    # Parse options; `opt` must stay module-global because test_net reads it.
    opt = TestOptions().parse()
    # Restrict visible GPUs before torch first initialises CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Build the network selected by opt.plane_perceptron.
    if opt.plane_perceptron == 'UNet_3Plus':
        net = model.UNet_3Plus(in_channels=opt.plane_perceptron_channels,
                               channels=opt.plane_perceptron_channels,
                               n_classes=opt.n_classes)
    elif opt.plane_perceptron == 'UNet':
        net = model.UNet(in_channels=opt.plane_perceptron_channels,
                         channels=opt.plane_perceptron_channels,
                         n_classes=opt.n_classes)
    else:
        raise ValueError(f'Unsupported plane_perceptron: {opt.plane_perceptron!r}')

    # Restore weights from the sub-folder of best_model_V2+ that sorts last
    # (natural order), presumably the most recent run -- confirm naming scheme.
    best_root = os.path.join(opt.saveroot, 'best_model_V2+')
    run_dirs = natsort.natsorted(os.listdir(best_root))
    if not run_dirs:
        raise FileNotFoundError(f'No checkpoint folders found in {best_root}')
    bestmodelpath = os.path.join(best_root, run_dirs[-1])
    # Sort so the chosen file does not depend on os.listdir's arbitrary order.
    checkpoints = natsort.natsorted(os.listdir(bestmodelpath))
    if not checkpoints:
        raise FileNotFoundError(f'No checkpoint file found in {bestmodelpath}')
    restore_path = os.path.join(bestmodelpath, checkpoints[0])
    print(restore_path)
    net.load_state_dict(torch.load(restore_path, map_location=device))
    # Move the model to the selected device.
    net.to(device=device)

    try:
        test_net(net=net, device=device)
    except KeyboardInterrupt:
        # Carried over from the training script: test_net does not modify the
        # weights, so this saves the loaded checkpoint unchanged.
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        # The original sys.exit/except SystemExit dance always ended in
        # os._exit(0); flush first because os._exit skips buffered-I/O cleanup.
        sys.stdout.flush()
        os._exit(0)