-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathother_utils.py
200 lines (164 loc) · 6.02 KB
/
other_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""General utilities for displaying and loading data, RGB to gray
function, and testing source/target generators
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
import os
import math
def rgb2gray(rgb):
    """Convert from color image (RGB) to grayscale

    Reference: opencv.org
    Formula: grayscale = 0.299*red + 0.587*green + 0.114*blue

    Arguments:
        rgb (tensor): array whose last axis holds the color channels;
            only the first three entries of that axis are used, so an
            extra trailing channel is ignored

    Returns:
        Array of weighted channel sums; the last axis of the input
        is collapsed
    """
    luma_weights = [0.299, 0.587, 0.114]
    # keep only the first three channels of the last axis
    color_channels = rgb[..., :3]
    return np.dot(color_channels, luma_weights)
def display_images(imgs,
                   filename,
                   title='',
                   imgs_dir=None,
                   show=False):
    """Display images in an nxn grid and save the grid to a file

    Arguments:
        imgs (tensor): array of images indexed as
            (count, rows, cols, channels); count must be a
            perfect square
        filename (string): filename to save the displayed image
        title (string): title on the displayed image
        imgs_dir (string): directory where to save the files
            ('saved_images' when None); created if it does not exist
        show (bool): whether to display the image or not
            (False during training, True during testing)

    Raises:
        AssertionError: if the number of images is not a perfect square
    """
    count = imgs.shape[0]
    rows = imgs.shape[1]
    cols = imgs.shape[2]
    channels = imgs.shape[3]
    side = int(math.sqrt(count))
    # Raised explicitly (instead of using `assert`) so that the check
    # still runs when Python is started with -O.
    if side * side != count:
        raise AssertionError(
            "number of images (%d) must be a perfect square" % count)

    # create saved_images folder
    if imgs_dir is None:
        imgs_dir = 'saved_images'
    save_dir = os.path.join(os.getcwd(), imgs_dir)
    # Try to create the folder and tolerate "already exists": this avoids
    # the race between checking for the folder and creating it.
    try:
        os.makedirs(save_dir)
    except OSError:
        if not os.path.isdir(save_dir):
            raise
    filename = os.path.join(imgs_dir, filename)

    # arrange the images as a side x side grid of tiles
    if channels == 1:
        grid = imgs.reshape((side, side, rows, cols))
    else:
        grid = imgs.reshape((side, side, rows, cols, channels))
    # join each grid row horizontally, then stack the rows vertically
    grid = np.vstack([np.hstack(grid_row) for grid_row in grid])

    plt.figure()
    try:
        plt.axis('off')
        plt.title(title)
        if channels == 1:
            plt.imshow(grid, interpolation='none', cmap='gray')
        else:
            plt.imshow(grid, interpolation='none')
        plt.savefig(filename)
        if show:
            plt.show()
    finally:
        # always release the figures, even if drawing or saving failed
        plt.close('all')
def test_generator(generators,
                   test_data,
                   step,
                   titles,
                   dirs,
                   todisplay=100,
                   show=False):
    """Test the generator models

    Runs both generators on the test data, feeds each prediction to the
    other generator to obtain the reconstructions, and saves the four
    resulting image grids with display_images.

    Arguments:
        generators (tuple): source and target generators
        test_data (tuple): source and target test data
        step (int): step number during training (0 during testing)
        titles (tuple): titles on the displayed image, in the order
            (predicted source, predicted target,
             reconstructed source, reconstructed target)
        dirs (tuple): folders to save the outputs of testings, in the
            order (source outputs, target outputs)
        todisplay (int): number of images to display (must be
            perfect square)
        show (bool): whether to display the image or not
            (False during training, True during testing)
    """
    g_source, g_target = generators
    test_source_data, test_target_data = test_data
    (title_pred_source,
     title_pred_target,
     title_reco_source,
     title_reco_target) = titles
    dir_pred_source, dir_pred_target = dirs

    # predict the output from test data, then map each prediction back
    # through the opposite generator
    pred_target_data = g_target.predict(test_source_data)
    pred_source_data = g_source.predict(test_target_data)
    reco_source_data = g_source.predict(pred_target_data)
    reco_target_data = g_target.predict(pred_source_data)

    step_filename = '%06d.png' % step
    step_label = " Step: {:,}".format(step)

    # (images, filename, output folder, title) for each grid to save;
    # only the predicted-target title carries the step label
    outputs = (
        (pred_target_data, step_filename,
         dir_pred_target, title_pred_target + step_label),
        (pred_source_data, step_filename,
         dir_pred_source, title_pred_source),
        (reco_source_data, "reconstructed_source.png",
         dir_pred_source, title_reco_source),
        (reco_target_data, "reconstructed_target.png",
         dir_pred_target, title_reco_target),
    )
    for batch, out_name, out_dir, out_title in outputs:
        # display the 1st todisplay images
        display_images(batch[:todisplay],
                       filename=out_name,
                       imgs_dir=out_dir,
                       title=out_title,
                       show=show)
def load_data(data, titles, filenames, todisplay=100):
    """Generic loaded data transformation

    Saves grids of the first test images with display_images, then
    converts all four arrays to float32 and divides them by 255.

    Arguments:
        data (tuple): source, target, test source, test target data
        titles (tuple): titles of the test and source images to display
        filenames (tuple): filenames of the test and source images to
            display
        todisplay (int): number of images to display (must be
            perfect square)

    Returns:
        data (tuple): the scaled (source, target, test source,
            test target) arrays
        shapes (tuple): (source_shape, target_shape), each a
            (rows, cols, channels) tuple
    """
    def _scale(pixels):
        # cast to float32 and divide by 255
        return pixels.astype('float32') / 255

    def _image_shape(batch):
        # per-image shape: axes 1..3 of the batch array
        dims = batch.shape
        return (dims[1], dims[2], dims[3])

    source_data, target_data, test_source_data, test_target_data = data
    test_source_filename, test_target_filename = filenames
    test_source_title, test_target_title = titles

    # display test target images
    display_images(test_target_data[:todisplay],
                   filename=test_target_filename,
                   title=test_target_title)

    # display test source images
    display_images(test_source_data[:todisplay],
                   filename=test_source_filename,
                   title=test_source_title)

    # normalize images
    target_data = _scale(target_data)
    test_target_data = _scale(test_target_data)
    source_data = _scale(source_data)
    test_source_data = _scale(test_source_data)

    scaled = (source_data, target_data, test_source_data, test_target_data)
    shapes = (_image_shape(source_data), _image_shape(target_data))
    return scaled, shapes