-
Notifications
You must be signed in to change notification settings - Fork 2
/
ablation_vis.py
81 lines (62 loc) · 2.78 KB
/
ablation_vis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
from PIL import Image
import os
from util import semantic_to_mask, get_confusion_matrix, get_miou, get_classification_report
import torch
import torch.nn.functional as F
import cv2
from data_loader import get_dataloader
import torch.nn as nn
# Expose only GPU 0 to this process. This is set at import time, before any
# CUDA call in this module, so it takes effect when the driver initialises.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
def _load_model(path):
    """Load a pickled ``nn.DataParallel`` checkpoint and return its inner module.

    The checkpoint is a whole pickled model, not a state dict, so it needs
    ``weights_only=False`` on PyTorch versions where that flag defaults to True.
    Unpickling can execute arbitrary code: only point this at trusted files.
    """
    import inspect

    kwargs = {'map_location': 'cpu'}
    # ``weights_only`` is not a parameter of older torch.load; only pass it
    # when this torch version accepts it.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        kwargs['weights_only'] = False
    return torch.load(path, **kwargs).module


def _normalize_channels(image):
    """Min-max scale each channel of an (H, W, C) array to [0, 1].

    Returns a new float32 array, so integer inputs are not truncated.
    A constant channel (max == min) becomes all zeros rather than NaN.
    """
    # NOTE(review): assumes the model expects float32 input; the original
    # kept the on-disk dtype, which only worked if the data was already float32.
    image = image.astype(np.float32)
    lo = image.min(axis=(0, 1), keepdims=True)
    span = image.max(axis=(0, 1), keepdims=True) - lo
    span[span == 0] = 1.0  # avoid 0/0 on flat channels
    return (image - lo) / span


def _forward(model, variant, image):
    """Run ``model`` on ``image`` and return its segmentation output as numpy.

    The ablation variants return different structures:
    ``BEM`` models return ``(pred, final)`` and the output is ``final['fine']``;
    ``SEM`` models return ``(aux, pred)``; every other variant returns ``pred``.
    """
    if variant == "BEM":
        _, final = model(image)
        output = final['fine']
    elif variant == "SEM":
        _, output = model(image)
    else:
        output = model(image)
    return output.cpu().detach().numpy()


def _colorize(pred):
    """Render a 2-D label map as RGB: class 1 (NPC) red, class 2 (NPL) green."""
    color = np.zeros(pred.shape + (3,), dtype=np.uint8)
    color[pred == 1, 0] = 255
    color[pred == 2, 1] = 255
    return color


@torch.no_grad()
def generate_test(input_dir="../data/NPC20_V1/val/image",
                  mask_dir="../data/NPC20_V1/val/mask",
                  output_dir="../data/NPC20_V1/out_ablation/",
                  model_dir="./exp/ablation"):
    """Save colour-coded predictions of every ablation checkpoint on a split.

    For each checkpoint in ``model_dir`` and each ``.npy`` image in ``input_dir``:
    normalise the image per channel, run the model, compare the predicted
    labels with the mask of the same stem in ``mask_dir`` (confusion matrix and
    per-class IoU, both printed), and write a PNG to
    ``<output_dir>/<checkpoint stem>/<image stem>_<n>_<score>.png``.

    ``n`` is the number of classes with non-zero IoU minus one (presumably
    discounting background). ``score`` is ``(IoU[1] + IoU[2]) / n``.

    The model variant (``BEM``, ``SEM``, or other) is read from the second
    ``'+'``-separated field of the checkpoint file name.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    labels = [0, 1, 2]
    files = sorted(os.listdir(input_dir))
    for model_name in sorted(os.listdir(model_dir)):
        model = _load_model(os.path.join(model_dir, model_name))
        # With CUDA_VISIBLE_DEVICES set to a single GPU this branch will
        # normally not trigger; kept for multi-GPU environments.
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            model = nn.DataParallel(model)
        model = model.to(device)
        model.eval()

        # Names without a '+' field previously crashed with IndexError; they
        # now fall through to the plain single-output forward path.
        parts = model_name.split('+')
        variant = parts[1] if len(parts) > 1 else ""
        model_out_dir = os.path.join(output_dir, model_name.split('.p')[0])
        os.makedirs(model_out_dir, exist_ok=True)

        print("start vis process for model", model_name)
        for file in files:
            stem = file.split('.')[0]
            image = _normalize_channels(np.load(os.path.join(input_dir, file)))
            # (H, W, C) -> (1, C, H, W)
            image = torch.from_numpy(image.transpose((2, 0, 1))).unsqueeze(dim=0).to(device)
            output = _forward(model, variant, image)
            pred = semantic_to_mask(output, labels).squeeze()

            mask = np.load(os.path.join(mask_dir, stem + ".npy"))
            miou = get_miou(get_confusion_matrix(mask, pred, labels))
            print(miou)
            n_fg = (miou != 0).sum() - 1
            score = (miou[1] + miou[2]) / (n_fg + 1e-6)
            print(variant, file)

            out_path = os.path.join(model_out_dir, stem) + "_" + str(n_fg) + "_" + str(score)[:6] + ".png"
            Image.fromarray(_colorize(pred)).save(out_path)
if __name__ == "__main__":
    import sys

    generate_test()
    # ``exit`` is injected by the ``site`` module for interactive use and is
    # missing under ``python -S``; ``sys.exit`` is always available.
    sys.exit(0)