-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
77 lines (70 loc) · 3.03 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import time
import os
from .options.test_options import TestOptions
from .data.data_loader import CreateDataLoader
from .models.models import create_model
from .util.visualizer import Visualizer
from pdb import set_trace as st
from .util import html
import uuid
import cv2
class Enlighten_GAN:
    """Wrapper around a pretrained EnlightenGAN model for single-image inference.

    The model is created once in ``__init__`` from a fixed set of test options.
    ``get_enlightened_image`` then runs one image through it by writing the image
    into ``<dataroot>/testA``, loading it back through the project data loader,
    predicting, and resizing the output to the input's size.
    """

    # Original hard-coded defaults, kept so existing callers see no change.
    DEFAULT_DATAROOT = '/home/azureuser/imagewizard/EnlightenGAN/test_dataset'
    DEFAULT_WEIGHT_FOLDER = "weights/EnlightenGAN/"

    def __init__(self, dataroot=None, weight_folder=None):
        """Build the test options and create the model.

        Args:
            dataroot: Dataset root directory; input images are written to its
                ``testA`` subfolder. Defaults to ``DEFAULT_DATAROOT``.
            weight_folder: Value stored in ``opt.weight_folder``. Defaults to
                ``DEFAULT_WEIGHT_FOLDER``.
        """
        self.opt = TestOptions().parse()
        self.opt.nThreads = 1  # test code only supports nThreads = 1
        self.opt.batchSize = 1  # test code only supports batchSize = 1
        self.opt.serial_batches = True  # no shuffle
        self.opt.no_flip = True  # no flip
        self.opt.dataroot = dataroot if dataroot is not None else self.DEFAULT_DATAROOT
        self.opt.weight_folder = (
            weight_folder if weight_folder is not None else self.DEFAULT_WEIGHT_FOLDER
        )
        self.opt.name = 'enlightening'
        self.opt.model = 'single'
        self.opt.gpu_ids = -1
        self.opt.which_direction = 'AtoB'
        self.opt.no_dropout = True
        self.opt.dataset_mode = 'unaligned'
        self.opt.which_model_netG = 'sid_unet_resize'
        self.opt.skip = 1
        self.opt.use_norm = 1
        self.opt.use_wgan = 0
        self.opt.self_attention = True
        self.opt.times_residual = True
        self.opt.instance_norm = 0
        self.opt.resize_or_crop = 'no'
        self.opt.which_epoch = 200
        self.model = create_model(self.opt)

    @staticmethod
    def _remove_quietly(path):
        """Delete ``path``; a file that does not exist is not an error."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def get_enlightened_image(self, image):
        """Run ``image`` through the model and return the enhanced result.

        Args:
            image: Image accepted by ``cv2.imwrite``; ``image.shape[0]`` and
                ``image.shape[1]`` are used as the output height and width.

        Returns:
            The output of ``visualizer.save_images`` resized to the input's
            width and height. Returns ``None`` in three cases: the temporary
            file cannot be written, the dataset yields nothing, or any
            exception is raised (the exception is printed).

        The temporary PNG written to ``<dataroot>/testA`` is always removed
        before returning.
        """
        img_name = os.path.join(self.opt.dataroot, 'testA', str(uuid.uuid4()) + '.png')
        try:
            # cv2.imwrite reports failure through its return value, not by raising.
            if not cv2.imwrite(img_name, image):
                print('Could not write temporary image ' + img_name)
                return None
            print('Creating DataLoader ...')
            data_loader = CreateDataLoader(self.opt)
            print('Creating DataSet ...')
            dataset = data_loader.load_data()
            print('Creating Visualiser ...')
            visualizer = Visualizer(self.opt)
            print("Length of Dataset " + str(len(dataset)))
            print(dataset)
            # NOTE(review): the loader is built from the whole dataroot and this
            # loop returns after the first item. If testA holds other files
            # (leftovers or concurrent calls), the result may come from a
            # different image -- confirm the folder has a single user at a time.
            for data in dataset:
                print("Setting Inputs")
                self.model.set_input(data)
                print("Predicting from Model")
                visuals = self.model.predict()
                print("Saving Visualiser")
                enhanced = visualizer.save_images(visuals)
                # cv2.resize takes (width, height): restore the input's size.
                return cv2.resize(enhanced, (int(image.shape[1]), int(image.shape[0])))
            return None
        except Exception as e:
            # Best-effort API: report the failure and signal it with None.
            print(e)
            return None
        finally:
            # Runs on every path, including early returns and errors raised
            # before the file existed.
            self._remove_quietly(img_name)