-
Notifications
You must be signed in to change notification settings - Fork 6
/
test_old.py
142 lines (120 loc) · 5.55 KB
/
test_old.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import warnings
from utils import *
import tqdm
import torch
import argparse
from torch.utils.data import DataLoader
from model.fe import FeatureExtractor
from model.pmaa import PMAA
from dataset_old import MultipleDataset
from torch.utils.data import DataLoader, random_split
import numpy as np
import os
# Expose only the first GPU to CUDA. This runs at import time; it only takes
# effect if no CUDA context has been created yet.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
# Silence every Python warning for the remainder of the run.
warnings.simplefilter("ignore")
def test_and_visualization(opt, model_GEN, test_loader, criterion):
    """Run the generator over ``test_loader``, save visualizations and score them.

    For every sample, the three input frames, the prediction, the ground
    truth and the three cloud-mask heatmaps are written into
    ``<opt.predict_image_path>/<opt.test_mode>/psnr_<psnr>_ssim_<ssim>/``.

    Args:
        opt: parsed command-line options; ``test_mode`` and
            ``predict_image_path`` are read.
        model_GEN: generator, called on the stacked inputs and unpacked as
            ``(fake_B, cloud_mask, _)``.
        test_loader: iterable yielding ``(real_A, real_B, image_names)`` where
            ``real_A`` holds (at least) three input tensors.
        criterion: loss applied to ``(fake_B, real_B)``; used only for the
            progress-bar display (a running sum over batches).

    Returns:
        tuple: ``(mean PSNR, mean SSIM)`` over all evaluated samples, computed
        by ``psnr_ssim_cal`` on the ``get_rgb_stgan`` renderings.
    """
    model_GEN.eval()
    psnr_list = []
    ssim_list = []
    total_loss = 0
    save_path = os.path.join(opt.predict_image_path, opt.test_mode)
    # Using tqdm as a context manager guarantees the bar is closed on error.
    with torch.no_grad(), tqdm.tqdm(total=len(test_loader), ncols=0,
                                    desc=f"{opt.test_mode}",
                                    unit=" step") as pbar:
        for (real_A, real_B, image_names) in test_loader:
            real_A[0], real_A[1], real_A[2], real_B = real_A[0].cuda(
            ), real_A[1].cuda(), real_A[2].cuda(), real_B.cuda()
            # Stack the three input frames along a new dim 1; the frames are
            # already on the GPU, so the stacked tensor is too.
            real_A_input = torch.stack((real_A[0], real_A[1], real_A[2]), 1)
            fake_B, cloud_mask, _ = model_GEN(real_A_input)
            loss = criterion(fake_B, real_B)
            # Iterate over the samples actually present in this batch: the
            # loader may yield a final batch smaller than opt.batch_size
            # (drop_last=False), which previously raised IndexError.
            for batch in range(real_B.shape[0]):
                # Strip the directory and the extension (of any length).
                image_name = os.path.splitext(
                    os.path.basename(image_names[batch]))[0]
                output, label = fake_B[batch], real_B[batch]
                input_1, input_2, input_3 = (
                    get_rgb_stgan(real_A[0][batch]),
                    get_rgb_stgan(real_A[1][batch]),
                    get_rgb_stgan(real_A[2][batch]),
                )
                output_rgb, label_rgb = get_rgb_stgan(
                    output), get_rgb_stgan(label)
                psnr, ssim = psnr_ssim_cal(label_rgb, output_rgb)
                psnr_list.append(psnr)
                ssim_list.append(ssim)
                # One folder per sample, named by its scores.
                save_dir = os.path.join(
                    save_path, f"psnr_{psnr:.3f}_ssim_{ssim:.3f}")
                os.makedirs(save_dir, exist_ok=True)
                for idx, real_img in enumerate([input_1, input_2, input_3]):
                    save_image(real_img, save_dir, image_name +
                               f'_real_A{idx + 1}.png')
                save_heatmap([cloud_mask[0][batch], cloud_mask[1][batch],
                              cloud_mask[2][batch]], save_dir, image_name)
                save_image(output_rgb, save_dir, image_name + '_fake_B.png')
                save_image(label_rgb, save_dir, image_name + '_real_B.png')
            total_loss += loss.item()
            pbar.update()
            pbar.set_postfix(
                loss_val=f"{total_loss:.4f}"
            )
        psnr = np.mean(np.array(psnr_list))
        ssim = np.mean(np.array(ssim_list))
        pbar.set_postfix(loss_val=f"{total_loss:.4f}",
                         psnr=f"{psnr:.3f}", ssim=f"{ssim:.3f}")
    return psnr, ssim
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # --- Paths ---
    parser.add_argument("--load_gen", type=str, default='checkpoints34/STGAN_multipleImage/G_best_PSNR_27.034_SSIM_0.848.pth',
                        help="which checkpoint you want to use for generator")
    parser.add_argument("--predict_image_path", type=str,
                        default='./image_out34',
                        help="directory where predicted images are saved")
    parser.add_argument("--root", type=str, default='data',
                        help="Path to dataset")
    # --- Parameters ---
    # NOTE(review): --image_size and --out_channel are parsed but never read
    # by this script.
    parser.add_argument("--image_size", type=int,
                        default=256, help="image size")
    parser.add_argument("--in_channel", type=int, default=4,
                        help="the number of input channels")
    parser.add_argument("--out_channel", type=int, default=4,
                        help="the number of output channels")
    # --- Base options ---
    # NOTE(review): --test_mode only names the output folder / progress bar;
    # the evaluated data is always the third random_split subset below.
    parser.add_argument("--test_mode", type=str, default='test',
                        help="which data_mode you want to use?(val/test)")
    parser.add_argument("--n_cpu", type=int, default=0,
                        help="number of cpu threads to use during batch generation")
    parser.add_argument("--batch_size", type=int,
                        default=1, help="size of the batches")
    parser.add_argument("--gpu_id", type=str, default='0', help="gpu id")
    opt = parser.parse_args()

    # Honour --gpu_id. CUDA only reads this variable when its context is
    # first created; presumably none exists yet, since this script has made
    # no .cuda() call so far.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id

    random_seed_general = 2022
    fixed_seed(random_seed_general)
    os.makedirs(os.path.join(opt.predict_image_path,
                             opt.test_mode), exist_ok=True)

    total_data = MultipleDataset(
        root=os.path.join(opt.root, "multipleImage"),
        band=opt.in_channel,
    )
    # Reproduce the train/val/test split with a fixed generator seed. The
    # hard-coded lengths sum to 3130, so the dataset must contain exactly
    # that many samples (random_split raises otherwise).
    _, _, test_data = random_split(
        dataset=total_data,
        lengths=(2504, 313, 313),
        generator=torch.Generator().manual_seed(2022),
    )
    test_loader = DataLoader(test_data, batch_size=opt.batch_size,
                             shuffle=False, num_workers=opt.n_cpu, drop_last=False)

    # --- Define model ---
    # NOTE(review): meaning of PMAA's positional args (32, 3) is defined in
    # model.pmaa — confirm there.
    model_GEN = PMAA(32, 3)

    def replace_batchnorm(model):
        """Recursively swap every BatchNorm2d submodule for an InstanceNorm2d.

        The replacement keeps ``num_features`` and otherwise uses
        InstanceNorm2d's default arguments. NOTE(review): presumably mirrors
        the same swap done at training time so the checkpoint keys match —
        confirm against the training script.
        """
        for name, child in model.named_children():
            if isinstance(child, torch.nn.BatchNorm2d):
                setattr(model, name, torch.nn.InstanceNorm2d(child.num_features))
            else:
                replace_batchnorm(child)

    replace_batchnorm(model_GEN)
    # Deserialize onto the CPU (where the model still lives) so a checkpoint
    # saved from any GPU index loads regardless of which devices are visible.
    model_GEN.load_state_dict(torch.load(opt.load_gen, map_location="cpu"))
    print('load transformer model successfully!')
    model_GEN = model_GEN.cuda()
    criterion = torch.nn.L1Loss().cuda()
    test_and_visualization(opt=opt, model_GEN=model_GEN,
                           test_loader=test_loader, criterion=criterion)