-
Notifications
You must be signed in to change notification settings - Fork 4
/
test_real.py
76 lines (53 loc) · 2.05 KB
/
test_real.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import tensorflow as tf
import numpy as np
import glob
import os
from argparse import ArgumentParser
import scipy.io as sio
import model
def build_parser():
    """Return the ArgumentParser for this inference script.

    Every option is a plain string:

    * ``--gpu``       -- later written to ``CUDA_VISIBLE_DEVICES``
                         (default ``'0'``)
    * ``--datapath``  -- root folder whose sub-folders are the dataset
                         splits (default ``'LR_mat(x2)'``)
    * ``--modelpath`` -- folder holding the ``model`` checkpoint
                         (default ``'Model_B'``)
    """
    # (flag, attribute name on the parsed namespace, default value)
    option_specs = (
        ('--gpu', 'gpu', '0'),
        ('--datapath', 'datapath', 'LR_mat(x2)'),
        ('--modelpath', 'modelpath', 'Model_B'),
    )
    cli = ArgumentParser()
    for flag, destination, fallback in option_specs:
        cli.add_argument(flag, dest=destination, default=fallback)
    return cli
parser = build_parser()
option = parser.parse_args()

# Process environment, configured at import time, i.e. before main() opens
# a tf.Session. CUDA_VISIBLE_DEVICES only takes effect if it is set before
# CUDA is initialised.
#   TF_CPP_MIN_LOG_LEVEL=3   -> suppress TensorFlow's C++ log output
#   CUDA_DEVICE_ORDER        -> number GPUs by PCI bus id (matches nvidia-smi)
#   CUDA_VISIBLE_DEVICES     -> GPU(s) chosen with --gpu
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '3',
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    'CUDA_VISIBLE_DEVICES': option.gpu,
})

# Session config: let TensorFlow claim at most 90% of each visible GPU's memory.
conf = tf.ConfigProto(
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.9),
)

# Dataset: sub-folder names under --datapath; main() processes Sets[ind].
Sets = ['Test', 'Part1', 'Part2', 'Part3']
ind = 0
def main():
    """Run the pretrained REF network on every ``.mat`` file of one dataset split.

    Input files are ``<option.datapath>/<Sets[ind]>/*.mat``, processed in
    sorted order. For each file:

    1. Load the ``'REF'`` array and pass it through ``np.tanh``.
    2. Feed it to ``model.REF_Network``, with weights restored from
       ``<option.modelpath>/model``.
    3. Write the network output to
       ``REF_result/<Sets[ind]>/<stem>_REF.mat`` under the key ``'REF'``.

    Relies on the module-level ``option``, ``conf``, ``Sets`` and ``ind``.
    """
    scale = 2
    model_type = 'REF'
    datapath = option.datapath
    modelpath = option.modelpath
    split = Sets[ind]

    img_list = sorted(glob.glob(os.path.join(datapath, split, '*.mat')))
    print(img_list)
    if not img_list:
        # Without this the script silently wrote nothing; say why.
        print('No .mat files found in %s' % os.path.join(datapath, split))
        return

    savefolder = os.path.join('REF_result', split)
    if not os.path.exists(savefolder):
        os.makedirs(savefolder)

    ckpt_model = os.path.join(modelpath, 'model')
    # The placeholder has a static [1, HH, WW, 1] shape, so one network
    # instance is needed per distinct image size. Cache those instances
    # instead of rebuilding the network (plus a new Saver) for every file,
    # which made the default graph grow with the number of files processed.
    networks = {}

    with tf.Session(config=conf) as sess:
        for path in img_list:
            image = sio.loadmat(path)
            input_REF = np.tanh(image['REF'])
            # Unpacking into two names requires a 2-D array.
            HH, WW = np.shape(input_REF)
            print(ckpt_model, os.path.basename(path))

            if (HH, WW) not in networks:
                input_img = tf.placeholder(tf.float32, [1, HH, WW, 1])
                testCNN = model.REF_Network(input_img, scale, reuse=tf.AUTO_REUSE)
                networks[(HH, WW)] = (input_img, testCNN.output)
                # Restore after each build so that any variable created by the
                # new instance is also loaded from the checkpoint.
                tf.train.Saver().restore(sess, ckpt_model)
            input_img, output = networks[(HH, WW)]

            # Add batch and channel axes: (HH, WW) -> (1, HH, WW, 1).
            out = sess.run(output, feed_dict={input_img: input_REF[None, :, :, None]})

            stem = os.path.splitext(os.path.basename(path))[0]
            out_name = '%s_%s.mat' % (stem, model_type)
            sio.savemat(os.path.join(savefolder, out_name),
                        {model_type: out[0, :, :, 0]})
if __name__ == '__main__':
    # Script entry point: run inference on the selected split, then report
    # completion.
    main()
    print('Done')