-
Notifications
You must be signed in to change notification settings - Fork 32
/
utils.py
140 lines (107 loc) · 3.45 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# -*- coding: utf-8 -*-
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from keras import backend as K
def read_img(img_path, preprocess_input, size):
    """Read an image from disk and build the preprocessed network input.

    Args:
        img_path: path of image.
        preprocess_input: callable applied to the resized, batched image.
        size: target size handed to cv2.resize (OpenCV order: (width, height)).

    Returns:
        img: original image exactly as loaded by cv2.imread.
        pimg: resized image with a leading batch axis, after preprocess_input.

    Raises:
        OSError: if cv2.imread could not load img_path.
    """
    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None instead of raising;
    # fail here with a clear message rather than deep inside cv2.resize.
    if img is None:
        raise OSError(
            'could not read image (missing or unsupported file): {!r}'.format(img_path)
        )

    # NOTE(review): cv2.imread loads channels in BGR order; whether
    # preprocess_input expects BGR or RGB is not visible here - confirm.
    resized = cv2.resize(img, size)
    batch = np.expand_dims(resized, axis=0)
    pimg = preprocess_input(batch)
    return img, pimg
def deprocess_image(x):
    """Convert a tensor into a valid uint8 image.

    The input is standardized to mean 0 and std 0.1, shifted to be centered
    on 0.5, clipped to [0, 1] and scaled to [0, 255].

    Args:
        x: tensor of filter (array-like). It is not modified; the function
            works on a floating-point copy.

    Returns:
        x: deprocessed uint8 array. When the Keras image data format is
            'channels_first' the axes are reordered with transpose((1, 2, 0)).
    """
    # Work on a float copy: the in-place arithmetic below would otherwise
    # overwrite the caller's array and fail on integer dtypes.
    x = np.array(x, copy=True)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(np.float64)

    # normalize tensor: center on 0., ensure std is 0.1
    x -= x.mean()
    x /= (x.std() + 1e-5)  # small constant avoids division by zero
    x *= 0.1

    # shift to be centered on 0.5 and clip to [0, 1]
    x += 0.5
    x = np.clip(x, 0, 1)

    # scale to [0, 255]; values are already bounded, so no second clip needed
    x *= 255
    if K.image_data_format() == 'channels_first':
        # move the channel axis last (assumes a 3-D (C, H, W) tensor)
        x = x.transpose((1, 2, 0))
    return x.astype('uint8')
def normalize(x):
    """Scale a tensor by the root of its mean squared value.

    The divisor is sqrt(mean(square(x))) plus the Keras backend epsilon,
    which keeps the division defined when x is all zeros.

    Args:
        x: gradient.

    Returns:
        x: the rescaled gradient.
    """
    root_mean_square = K.sqrt(K.mean(K.square(x)))
    divisor = root_mean_square + K.epsilon()
    return x / divisor
def vis_conv(images, n, name, t):
    """Tile conv outputs or conv filters into an n x n grid, save and show it.

    Args:
        images: tiles to draw. With t == 'filter' tile k is images[k] and is
            written into a 3-channel cell; with t == 'conv' tile k is
            images[..., k] and is written into a single-channel cell.
        n: number of col and row; tiles 0 .. n * n - 1 are used.
        name: save name, used as the suffix of the output file name.
        t: vis type, either 'filter' or 'conv'.

    Raises:
        ValueError: if t is neither 'filter' nor 'conv'.
    """
    if t not in ('filter', 'conv'):
        raise ValueError("t must be 'filter' or 'conv', got {!r}".format(t))

    size = 64
    margin = 5
    # n tiles per side are separated by n - 1 margins. The previous fixed
    # 7 margins only fitted n <= 8 and left a black border for smaller n.
    side = n * size + max(n - 1, 0) * margin
    if t == 'filter':
        results = np.zeros((side, side, 3))
    else:
        results = np.zeros((side, side))
    # NOTE(review): results is float64; if 'filter' tiles hold 0-255 integer
    # values, imshow clips float RGB data to [0, 1] - confirm the intended
    # value range of the tiles.

    step = size + margin
    for i in range(n):
        for j in range(n):
            index = i + (j * n)
            tile = images[index] if t == 'filter' else images[..., index]
            tile = cv2.resize(tile, (size, size))
            # Put the result in the square `(i, j)` of the results grid
            row_start = i * step
            col_start = j * step
            results[row_start:row_start + size, col_start:col_start + size] = tile

    # Display the results grid
    plt.imshow(results)
    # Build the path portably and make sure the output directory exists.
    os.makedirs('images', exist_ok=True)
    plt.savefig(os.path.join('images', '{}_{}.jpg'.format(t, name)), dpi=600)
    plt.show()
def vis_heatmap(img, heatmap):
    """Visualize a heatmap next to and on top of the original image.

    Draws three panels (resized original, raw heatmap, heatmap overlaid on
    the original), saves the figure under the images directory as
    heatmap.jpg and shows it.

    Args:
        img: original image in BGR channel order (it is converted with
            cv2.COLOR_BGR2RGB before display).
        heatmap: heatmap. Presumably normalized to [0, 1], since it is scaled
            by 255 and cast to uint8; values outside that range are clipped.
    """
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    plt.figure()

    # top-left: the original image at a fixed preview size
    plt.subplot(221)
    plt.imshow(cv2.resize(img, (224, 224)))
    plt.axis('off')

    # top-right: the raw heatmap
    plt.subplot(222)
    plt.imshow(heatmap)
    plt.axis('off')

    # bottom: heatmap blended over the original image
    plt.subplot(212)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    # clip before the cast so out-of-range values do not wrap around in uint8
    heatmap = np.uint8(255 * np.clip(heatmap, 0, 1))
    # We apply the heatmap to the original image
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; match the RGB order img was converted to
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # keep the blend a 0-255 uint8 image: imshow clips float RGB to [0, 1]
    superimposed_img = np.clip(heatmap * 0.4 + img, 0, 255).astype('uint8')
    plt.imshow(superimposed_img)
    plt.axis('off')

    plt.tight_layout()
    # Build the path portably and make sure the output directory exists.
    os.makedirs('images', exist_ok=True)
    plt.savefig(os.path.join('images', 'heatmap.jpg'), dpi=600)
    plt.show()