-
Notifications
You must be signed in to change notification settings - Fork 57
/
modified_m2.py
83 lines (74 loc) · 3.22 KB
/
modified_m2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorbayes.layers import Constant, Placeholder, Dense, GaussianSample
from tensorbayes.distributions import log_bernoulli_with_logits, log_normal
from tensorbayes.tbutils import cross_entropy_with_logits
from tensorbayes.nbutils import show_graph
from tensorbayes.utils import progbar
import numpy as np
import sys
from shared_subgraphs import qy_graph, qz_graph, labeled_loss
from utils import train
# Load the MNIST data sets from the local "MNIST_data/" directory, asking for
# one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# The first command-line argument selects how `custom_layer` treats its input
# (identity / relu / layer). Fail with a usage hint when it is missing.
if len(sys.argv) < 2:
    raise Exception('Pass an argument specifying identity/relu/layer\n'
                    'e.g. python modified_m2.py identity')
method = sys.argv[1]
def custom_layer(zy, reuse):
    """Transform ``zy`` according to the module-level ``method`` setting.

    The choice of transformation is deliberately left as a hyperparameter:

    * ``'identity'`` -- return ``zy`` unchanged.
    * ``'relu'``     -- return ``tf.nn.relu(zy)``.
    * ``'layer'``    -- pass ``zy`` through ``Dense`` (size argument 512,
      scope ``'layer1'``, ReLU activation), forwarding ``reuse``.

    Args:
        zy: Tensor to transform.
        reuse: Forwarded to ``Dense`` as its ``reuse`` keyword; only used
            when ``method == 'layer'``.

    Returns:
        The transformed tensor.

    Raises:
        Exception: If ``method`` is not one of the three known options.
    """
    if method == 'layer':
        return Dense(zy, 512, 'layer1', tf.nn.relu, reuse=reuse)
    if method == 'relu':
        return tf.nn.relu(zy)
    if method == 'identity':
        return zy
    raise Exception('Undefined method')
def px_graph(z, y):
    """Build the generative sub-graph and return the logits for x.

    ``z`` is first shifted and scaled by a mean and variance computed from
    ``y`` (so it becomes a sample from one of the Gaussian mixture
    components), then passed through ``custom_layer`` and two ``Dense``
    layers to produce the output logits.

    Args:
        z: Latent sample tensor.
        y: Tensor the mean/variance of the z-transform are computed from.

    Returns:
        Output of the final ``Dense`` layer (size argument 784, scope
        ``'logit'``).
    """
    # Variables are shared across calls: once anything exists under the 'px'
    # scope, every layer built here is created with reuse switched on.
    existing_px_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='px')
    reuse = bool(existing_px_vars)

    # -- transform z to be a sample from one of the Gaussian mixture components
    with tf.variable_scope('z_transform'):
        component_mean = Dense(y, 64, 'zm', reuse=reuse)
        # softplus activation on the variance head
        component_var = Dense(y, 64, 'zv', tf.nn.softplus, reuse=reuse)

    # -- p(x)
    with tf.variable_scope('px'):
        with tf.name_scope('layer1'):
            component_std = tf.sqrt(component_var)
            zy = component_mean + component_std * z
            hidden = custom_layer(zy, reuse)
        hidden = Dense(hidden, 512, 'layer2', tf.nn.relu, reuse=reuse)
        logits = Dense(hidden, 784, 'logit', reuse=reuse)
    return logits
# ---------------------------------------------------------------------------
# Graph construction: inputs, proposal over y, and one inference/generation
# branch per candidate y.
# ---------------------------------------------------------------------------
tf.reset_default_graph()
x = Placeholder((None, 784), 'x')

# binarize data and create a y "placeholder"
with tf.name_scope('x_binarized'):
    # Stochastic binarization: an entry is 1.0 where x exceeds a fresh
    # uniform(0, 1) draw and 0.0 otherwise.
    xb = tf.cast(tf.greater(x, tf.random_uniform(tf.shape(x), 0, 1)),
                 tf.float32)
with tf.name_scope('y_'):
    # All-zeros tensor with one row per input row and 10 columns; a one-hot
    # row is added to it below to form each candidate y.
    y_ = tf.fill(tf.pack([tf.shape(x)[0], 10]), 0.0)

# propose distribution over y
qy_logit, qy = qy_graph(xb)

# for each proposed y, infer z and reconstruct x
# NOTE: `range` is used instead of the Python-2-only `xrange` so the script
# also runs on Python 3; the loops are tiny, so there is no cost on Python 2.
z, zm, zv, px_logit = [[None] * 10 for _ in range(4)]
for i in range(10):
    with tf.name_scope('graphs/hot_at{:d}'.format(i)):
        # Row i of the 10x10 identity matrix is added to the zeros tensor,
        # giving a y that is hot at position i.
        y = tf.add(y_, Constant(np.eye(10)[i], name='hot_at_{:d}'.format(i)))
        z[i], zm[i], zv[i] = qz_graph(xb, y)
        px_logit[i] = px_graph(z[i], y)
# ---------------------------------------------------------------------------
# Loss: negated cross-entropy term of q(y) plus the qy-weighted sum of the
# per-y losses.
# ---------------------------------------------------------------------------
# Aggressive name scoping for pretty graph visualization :P
with tf.name_scope('loss'):
    with tf.name_scope('neg_entropy'):
        nent = -cross_entropy_with_logits(qy_logit, qy)
    losses = [None] * 10
    # NOTE: `range` replaces the Python-2-only `xrange` so this also runs on
    # Python 3; behavior on Python 2 is unchanged.
    for i in range(10):
        with tf.name_scope('loss_at{:d}'.format(i)):
            # NOTE(review): the trailing Constant(0), Constant(1) look like
            # the prior mean and variance for z -- confirm against
            # shared_subgraphs.labeled_loss.
            losses[i] = labeled_loss(xb, px_logit[i], z[i], zm[i], zv[i],
                                     Constant(0), Constant(1))
    with tf.name_scope('final_loss'):
        # Weight each per-y loss by column i of qy and add the
        # negative-entropy term.
        loss = tf.add_n([nent] + [qy[:, i] * losses[i] for i in range(10)])

# Adam with its default settings minimizes the combined loss.
train_step = tf.train.AdamOptimizer().minimize(loss)
# ---------------------------------------------------------------------------
# Session setup and training.
# ---------------------------------------------------------------------------
sess = tf.Session()

# Initialize every variable in the graph before training starts.
# Change initialization protocol depending on tensorflow version: newer
# releases use tf.global_variables_initializer() instead.
init_op = tf.initialize_all_variables()
sess.run(init_op)

# Bundle what the shared training loop needs, then train for 1000 epochs,
# logging to a file named after the chosen method.
sess_info = (sess, qy_logit, nent, loss, train_step)
log_path = 'logs/modified_m2_method={:s}.log'.format(method)
train(log_path, mnist, sess_info, epochs=1000)