forked from Sohl-Dickstein/Diffusion-Probabilistic-Models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
viz.py
158 lines (139 loc) · 5.37 KB
/
viz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Tools for plotting / visualization
"""
import matplotlib
matplotlib.use('Agg') # no displayed figures -- need to call before loading pylab
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import warnings
def is_square(shp, n_colors=1):
    """
    Test, elementwise, whether the entries of shp are perfect squares, either
    directly or after dividing out the number of color channels.

    shp      : int or array-like of ints (typically an array shape).
    n_colors : number of color channels that may be folded into each entry.

    Returns a numpy boolean (scalar or array, matching shp).
    """
    counts = np.array(shp)
    # square as-is, e.g. 784 = 28**2
    plain_root = np.round(np.sqrt(counts))
    plain_square = counts == plain_root ** 2
    # square once the color channels are divided out, e.g. 3072 = 3 * 32**2
    channel_root = np.round(np.sqrt(counts / float(n_colors)))
    color_square = counts == n_colors * channel_root ** 2
    return plain_square | color_square
def show_receptive_fields(theta, P=None, n_colors=None, max_display=100, grid_wa=None):
    """
    Display receptive fields in a grid of subplots on the current figure.

    Tries to intelligently guess whether to treat the rows, the columns, or the
    last two axes together as containing the receptive fields. It does this by
    checking which axes are square numbers (optionally after dividing out the
    color channels) -- so you can get unexpected plots if the wrong axis is a
    square number, or if multiple axes are. A trailing axis whose length equals
    n_colors is treated as color channels and folded into the axis before it.

    theta       : array of receptive fields, rank >= 1.
    P           : optional matrix applied as np.dot(P, theta) before display;
                  when given, the image width is derived from P.shape[0].
    n_colors    : number of color channels. Defaults to 3 if the last axis of
                  theta has length 3, otherwise 1. Only 1 and 3 can be drawn.
    max_display : maximum number of receptive fields to draw.
    grid_wa     : number of subplot rows; defaults to ceil(sqrt(#fields)).

    Returns True if the fields were plotted, or False (nothing drawn) if no
    axis looks image-shaped or theta is empty.
    """
    shp = np.array(theta.shape)
    if n_colors is None:
        n_colors = 3 if shp[-1] == 3 else 1
    # multiply colors in as appropriate. Requires a preceding axis to fold
    # into -- a 1-d input of length n_colors used to raise IndexError here.
    if len(shp) > 1 and shp[-1] == n_colors:
        shp[-2] *= n_colors
        theta = theta.reshape(shp[:-1])
        shp = np.array(theta.shape)
    if len(shp) > 2:
        # merge last two axes
        shp[-2] *= shp[-1]
        theta = theta.reshape(shp[:-1])
        shp = np.array(theta.shape)
    if len(shp) > 2:
        # merge leading axes
        theta = theta.reshape((-1, shp[-1]))
        shp = np.array(theta.shape)
    if len(shp) == 1:
        theta = theta.reshape((-1, 1))
        shp = np.array(theta.shape)

    # figure out the right orientation, by looking for the axis with a square
    # number of entries, up to number of colors. transpose if required
    is_sqr = is_square(shp, n_colors=n_colors)
    if is_sqr[0] and is_sqr[1]:
        warnings.warn("Unsure of correct matrix orientation. "
                      "Assuming receptive fields along first dimension.")
    elif is_sqr[1]:
        theta = theta.T
    elif not is_sqr[0]:
        # neither direction corresponds well to an image.
        # NOTE if you delete this return, the code will still work; the
        # receptive fields just won't look very image-like.
        return False

    theta = theta[:, :max_display].copy()
    if theta.size == 0:
        # nothing to draw; avoids dividing by a zero grid width below and
        # calling np.min on an empty array
        return False

    n_pixels = theta.shape[0] if P is None else P.shape[0]
    img_w = int(np.ceil(np.sqrt(n_pixels / float(n_colors))))
    nf = theta.shape[1]
    if grid_wa is None:
        grid_wa = int(np.ceil(np.sqrt(float(nf))))
    grid_wb = int(np.ceil(nf / float(grid_wa)))

    if P is not None:
        theta = np.dot(P, theta)
    # one shared color scale for every patch
    vmin = np.min(theta)
    vmax = np.max(theta)
    vrange = vmax - vmin
    if vrange == 0:
        # constant theta (e.g. a zero-initialized bias): without this guard
        # every patch becomes 0/0 = NaN and renders blank
        vrange = 1.

    for jj in range(nf):
        plt.subplot(grid_wa, grid_wb, jj + 1)
        # zero-pad so a non-square pixel count still fills a square patch
        ptch = np.zeros((n_colors * img_w ** 2,))
        ptch[:theta.shape[0]] = theta[:, jj]
        if n_colors == 3:
            ptch = ptch.reshape((n_colors, img_w, img_w))
            ptch = ptch.transpose((1, 2, 0))  # move color channels to end
        else:
            ptch = ptch.reshape((img_w, img_w))
        ptch -= vmin
        ptch /= vrange
        plt.imshow(ptch, interpolation='nearest', cmap=cm.Greys_r)
        plt.axis('off')
    return True
def plot_parameter(theta_in, base_fname_part1, base_fname_part2="", title = '', n_colors=None):
    """
    Save both a raw and a receptive field style plot of the contents of theta_in.

    theta_in         : array-like value or scalar to plot.
    base_fname_part1 : mandatory root of the output filenames.
    base_fname_part2 : optional suffix placed after '_raw_' / '_rf_'.
    title            : plot title; also prefixes the printed summary line.
    n_colors         : forwarded to show_receptive_fields.

    Prints summary statistics, always writes
    <part1>_raw_<part2>.pdf, and writes <part1>_rf_<part2>.pdf when the value
    is a matrix or a vector with a square number of entries and
    show_receptive_fields manages to lay it out.
    """
    # np.array copies its input and also accepts plain Python scalars/lists,
    # which have no .copy() method (calling theta_in.copy() first failed on them)
    theta = np.array(theta_in)
    # single-argument print gives the same output under Python 2 and 3
    print("%s min %g median %g mean %g max %g shape %s" % (
        title, np.min(theta), np.median(theta), np.mean(theta), np.max(theta),
        theta.shape))
    theta = np.squeeze(theta)
    if theta.ndim == 0:
        # it's a scalar -- make it a 1d array
        theta = theta.reshape((1,))
    if theta.ndim > 2:
        # flatten everything after the first axis so it can be shown as a matrix
        theta = theta.reshape((theta.shape[0], -1))
    shp = theta.shape

    ## display basic figure
    plt.figure(figsize=[8, 8])
    if len(shp) == 1:
        plt.plot(theta, '.', alpha=0.5)
    else:
        plt.imshow(theta, interpolation='nearest', aspect='auto', cmap=cm.Greys_r)
        plt.colorbar()
    plt.title(title)
    plt.savefig(base_fname_part1 + '_raw_' + base_fname_part2 + '.pdf')
    plt.close()

    ## also display it in basis function view if it's a matrix, or
    ## if it's a bias with a square number of entries
    if len(shp) >= 2 or is_square(shp[0]):
        if len(shp) == 1:
            theta = theta.reshape((-1, 1))
        plt.figure(figsize=[8, 8])
        if show_receptive_fields(theta, n_colors=n_colors):
            # separate the title from the suffix (previously rendered e.g.
            # "Wreceptive fields")
            rf_title = (title + " receptive fields") if title else "receptive fields"
            plt.suptitle(rf_title)
            plt.savefig(base_fname_part1 + '_rf_' + base_fname_part2 + '.pdf')
        plt.close()
def plot_images(X, fname):
    """
    Plot images in a grid and save them to disk.

    X is expected to be a 4d tensor of dimensions
    [# images]x[# colors]x[height]x[width]. A 3d tensor
    [# images]x[height]x[width] is treated as single-channel images.

    Writes fname + '.pdf' when the images could be laid out (otherwise warns),
    and always writes the array to fname + '.npz' under the key 'X'.

    NOTE(review): show_receptive_fields folds a trailing axis whose length
    equals n_colors into the pixel axis, so if # images == # colors (> 1) the
    grid will likely be garbled -- confirm whether callers can hit this.
    """
    X = np.asarray(X)
    # previously X.shape[1] was used unconditionally, which for 3d input took
    # the image height as the channel count and crashed while reshaping
    n_colors = X.shape[1] if X.ndim == 4 else 1
    ## plot
    # one column per image; each column is the image flattened in
    # [color, height, width] order, which show_receptive_fields un-flattens
    Xcol = X.reshape((X.shape[0], -1)).T
    plt.figure(figsize=[8, 8])
    try:
        if show_receptive_fields(Xcol, n_colors=n_colors):
            plt.savefig(fname + '.pdf')
        else:
            warnings.warn('Images unexpected shape.')
    finally:
        # release the figure even if plotting/saving raises
        plt.close()
    ## save as a .npz file
    np.savez(fname + '.npz', X=X)