# gmvae.py
# Forked from RuiShu/vae-clustering.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorbayes.layers import Constant, Placeholder, Dense, GaussianSample
from tensorbayes.distributions import log_bernoulli_with_logits, log_normal
from tensorbayes.tbutils import cross_entropy_with_logits
from tensorbayes.nbutils import show_graph
from tensorbayes.utils import progbar
import numpy as np
import sys
from shared_subgraphs import qy_graph, qz_graph, labeled_loss
from utils import train
# Load MNIST from the local "MNIST_data/" directory with one-hot label vectors.
# This runs at import time; presumably the helper fetches the files when they
# are missing -- confirm against the installed TensorFlow's input_data module.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
def px_graph(z, y):
    """Build the generative sub-graph: a prior over z given y, and logits for x given z.

    The first call creates the variables under the 'pz' and 'px' variable
    scopes; every later call finds existing 'px' variables in the graph and
    therefore asks the layers to reuse them, so all calls share one set of
    weights.

    Args:
        z: latent tensor fed to the 'px' layers.
        y: label tensor fed to the 'pz' layers.

    Returns:
        A tuple ``(zm, zv, px_logit)``:
        zm       -- output of the 'zm' Dense layer applied to y (size 64).
        zv       -- output of the 'zv' Dense layer applied to y (size 64),
                    passed through softplus.
        px_logit -- output of the final 'logit' Dense layer (size 784) on top
                    of two size-512 ReLU layers applied to z.
    """
    # Any variable already registered under 'px' means the graph was built before.
    existing_px_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='px')
    share_weights = bool(existing_px_vars)

    # -- p(z)
    with tf.variable_scope('pz'):
        zm = Dense(y, 64, 'zm', reuse=share_weights)
        zv = Dense(y, 64, 'zv', tf.nn.softplus, reuse=share_weights)

    # -- p(x)
    with tf.variable_scope('px'):
        hidden = z
        for layer_name in ('layer1', 'layer2'):
            hidden = Dense(hidden, 512, layer_name, tf.nn.relu, reuse=share_weights)
        px_logit = Dense(hidden, 784, 'logit', reuse=share_weights)

    return zm, zv, px_logit
# ---------------------------------------------------------------------------
# Graph construction and training for the GMVAE script.
# ---------------------------------------------------------------------------
tf.reset_default_graph()
x = Placeholder((None, 784), 'x')

# binarize data and create a y "placeholder"
with tf.name_scope('x_binarized'):
    # A pixel becomes 1.0 when its value exceeds a fresh uniform draw from [0, 1).
    xb = tf.cast(tf.greater(x, tf.random_uniform(tf.shape(x), 0, 1)), tf.float32)
with tf.name_scope('y_'):
    # tf.pack was renamed to tf.stack (and later removed); use whichever name
    # the installed TensorFlow provides.
    stack = tf.stack if hasattr(tf, 'stack') else tf.pack
    # All-zero (batch, 10) tensor; adding a one-hot row to it below yields the
    # same label for every example in the batch.
    y_ = tf.fill(stack([tf.shape(x)[0], 10]), 0.0)

# propose distribution over y
qy_logit, qy = qy_graph(xb)

# for each proposed y, infer z and reconstruct x
# (`range` instead of `xrange`: identical here and also valid on Python 3)
z, zm, zv, zm_prior, zv_prior, px_logit = [[None] * 10 for i in range(6)]
for i in range(10):
    with tf.name_scope('graphs/hot_at{:d}'.format(i)):
        y = tf.add(y_, Constant(np.eye(10)[i], name='hot_at_{:d}'.format(i)))
        z[i], zm[i], zv[i] = qz_graph(xb, y)
        zm_prior[i], zv_prior[i], px_logit[i] = px_graph(z[i], y)

# Aggressive name scoping for pretty graph visualization :P
with tf.name_scope('loss'):
    with tf.name_scope('neg_entropy'):
        # Negated cross-entropy of qy against its own logits.
        nent = -cross_entropy_with_logits(qy_logit, qy)
    losses = [None] * 10
    for i in range(10):
        with tf.name_scope('loss_at{:d}'.format(i)):
            losses[i] = labeled_loss(xb, px_logit[i], z[i], zm[i], zv[i], zm_prior[i], zv_prior[i])
    with tf.name_scope('final_loss'):
        # Each per-label loss is weighted by the matching column of qy, then
        # summed together with the negative-entropy term.
        loss = tf.add_n([nent] + [qy[:, i] * losses[i] for i in range(10)])

train_step = tf.train.AdamOptimizer().minimize(loss)
sess = tf.Session()

# Pick the variable initializer that matches the installed TensorFlow instead
# of asking the reader to swap lines by hand: newer releases provide
# tf.global_variables_initializer, older ones only tf.initialize_all_variables.
if hasattr(tf, 'global_variables_initializer'):
    init_op = tf.global_variables_initializer()
else:
    init_op = tf.initialize_all_variables()
sess.run(init_op)

sess_info = (sess, qy_logit, nent, loss, train_step)
train('logs/gmvae.log', mnist, sess_info, epochs=1000)