-
Notifications
You must be signed in to change notification settings - Fork 70
/
Middleburry_Test.py
104 lines (71 loc) · 2.76 KB
/
Middleburry_Test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import sys
import time
import copy
import shutil
import random
import pdb
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from dataset.transforms import Resize
import config
import myutils
from torch.utils.data import DataLoader
# Parse the run configuration; `unparsed` holds whatever get_args() returns
# as its second value and is not used further in this script.
args, unparsed = config.get_args()
cwd = os.getcwd()

# Seed the default RNG, and the CUDA RNG too when running on the GPU.
torch.manual_seed(args.random_seed)
if args.cuda:
    torch.cuda.manual_seed(args.random_seed)
device = torch.device('cuda' if args.cuda else 'cpu')

from dataset.Middleburry import get_loader

# Evaluation loader: batch size 1, no shuffling.
test_loader = get_loader(
    args.data_root,
    1,
    shuffle=False,
    num_workers=args.num_workers,
)

from model.FLAVR_arch import UNet_3D_3D

_model_name = args.model.lower()
print("Building model: %s" % _model_name)
model = UNet_3D_3D(
    _model_name,
    n_inputs=args.nbr_frame,
    n_outputs=args.n_outputs,
    joinType=args.joinType,
)

# The network is always wrapped in DataParallel before being moved to the
# target device, regardless of how many devices are in use.
model = torch.nn.DataParallel(model).to(device)

_param_count = sum(p.numel() for p in model.parameters())
print("#params", _param_count)
def make_image(img):
    """Convert a float image tensor into a ``uint8`` NumPy image array.

    The values are multiplied by 255, clamped to ``[0, 255]`` and rounded,
    then the axes are reordered from ``(0, 1, 2)`` to ``(1, 2, 0)`` and the
    result is copied to host memory.

    Args:
        img: A 3-D ``torch.Tensor``. Presumably channels-first ``(C, H, W)``
            with values nominally in ``[0, 1]`` (inferred from the permute
            and the 255 scaling) -- confirm against callers.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` with the axes permuted to
        ``(1, 2, 0)``.
    """
    # detach() replaces the legacy `.data` attribute: it yields the same
    # graph-free view of the values without bypassing autograd's checks.
    quantized = img.detach().mul(255.).clamp(0, 255).round()
    return quantized.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
# Sequence names that test() processes; loader entries whose name is not in
# this list are skipped.
folderList = [
    'Backyard',
    'Basketball',
    'Dumptruck',
    'Evergreen',
    'Mequon',
    'Schefflera',
    'Teddy',
    'Urban',
]
def test(args):
    """Run the model over the test loader and save one output frame per sequence.

    For every loader entry whose sequence name is listed in ``folderList``,
    the input frames are resized so that height and width are multiples of 8,
    passed through the module-level ``model``, resized back to the original
    resolution and written to ``Middleburry/<name>/frame10i11.png``.
    The forward-pass time of each sequence is printed.

    Args:
        args: Parsed options; only ``args.loss`` is read here.

    Returns:
        None.
    """
    # Imported locally because the file does not import imageio at the top;
    # hoisted out of the loop so it is resolved once.
    import imageio

    # NOTE(review): the meters returned here are never read in this function.
    # The call is kept in case init_meters validates args.loss -- confirm and
    # drop it if it has no side effects.
    myutils.init_meters(args.loss)
    model.eval()

    on_cuda = device.type == "cuda"

    with torch.no_grad():
        for images, name in test_loader:
            seq_name = name[0]
            if seq_name not in folderList:
                continue

            images = torch.stack(images, dim=1).squeeze(0)

            # Round the spatial size down to a multiple of 8 before the
            # forward pass (presumably required by the network's
            # down-sampling -- TODO confirm); the output is resized back.
            H, W = images[0].shape[-2:]
            transform = Resize((8 * (H // 8), 8 * (W // 8)))
            rev_transforms = Resize((H, W))

            # Use the configured device instead of a hard-coded .cuda() so
            # the inputs always live where the model does (CPU runs included).
            images = transform(images).unsqueeze(0).to(device)
            images = torch.unbind(images, dim=1)

            # CUDA kernels run asynchronously; synchronise around the forward
            # pass so the printed time covers the actual computation.
            if on_cuda:
                torch.cuda.synchronize()
            start_time = time.time()
            out = model(images)
            if on_cuda:
                torch.cuda.synchronize()
            print("Time Taken", time.time() - start_time)

            out = torch.cat(out)
            out = rev_transforms(out)
            output_image = make_image(out.squeeze(0))

            # exist_ok=True lets the script be re-run without failing on
            # output directories created by a previous run.
            os.makedirs("Middleburry/%s/" % seq_name, exist_ok=True)
            imageio.imwrite("Middleburry/%s/frame10i11.png" % seq_name, output_image)

    return
def main(args):
    """Load checkpoint weights into the module-level model and run ``test``.

    Args:
        args: Parsed options. ``args.load_from`` must be the path of a
            checkpoint file whose ``"state_dict"`` entry matches the model's
            parameter names exactly (loaded with ``strict=True``).

    Raises:
        ValueError: If ``args.load_from`` is ``None``.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if args.load_from is None:
        raise ValueError("args.load_from must be set to a checkpoint path")

    # map_location keeps a checkpoint saved on GPU loadable when the run is
    # configured for CPU (and vice versa).
    checkpoint = torch.load(args.load_from, map_location=device)
    model.load_state_dict(checkpoint["state_dict"], strict=True)
    test(args)


# Script entry point: evaluate using the module-level parsed arguments.
if __name__ == "__main__":
    main(args)