-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathshared_subgraphs.py
34 lines (31 loc) · 1.36 KB
/
shared_subgraphs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import tensorflow as tf
from tensorbayes.layers import Constant, Placeholder, Dense, GaussianSample
from tensorbayes.distributions import log_bernoulli_with_logits, log_normal
from tensorbayes.tbutils import cross_entropy_with_logits
import numpy as np
import sys
# vae subgraphs
def qy_graph(x, k=10):
    """Build the q(y) subgraph under the 'qy' variable scope.

    Two 512-unit ReLU ``Dense`` layers are stacked on ``x``, followed by a
    ``k``-unit ``Dense`` layer with no activation argument whose output is
    passed through a softmax.

    Args:
        x: Input tensor fed to the first ``Dense`` layer.
        k: Number of output units of the logit layer (default 10).

    Returns:
        A ``(qy_logit, qy)`` tuple: the output of the 'logit' layer and its
        softmax (op name 'prob').
    """
    # If variables already exist under 'qy', later calls ask Dense to reuse
    # them instead of creating new ones.
    # NOTE(review): the collection filter matches on the variable-name prefix,
    # so this presumably assumes the function is not called from inside an
    # enclosing variable scope -- confirm against callers.
    existing_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='qy')
    share = len(existing_vars) > 0

    with tf.variable_scope('qy'):
        hidden = x
        for layer_name in ('layer1', 'layer2'):
            hidden = Dense(hidden, 512, layer_name, tf.nn.relu, reuse=share)
        logits = Dense(hidden, k, 'logit', reuse=share)
        probs = tf.nn.softmax(logits, name='prob')
    return logits, probs
def qz_graph(x, y):
    """Build the q(z) subgraph under the 'qz' variable scope.

    ``x`` and ``y`` are concatenated along axis 1, passed through two
    512-unit ReLU ``Dense`` layers, and then mapped by two separate 64-unit
    ``Dense`` heads ('zm' with no activation argument, 'zv' with softplus).
    The two heads are handed to ``GaussianSample`` to produce ``z``.

    Args:
        x: First input tensor of the concatenation.
        y: Second input tensor of the concatenation.

    Returns:
        A ``(z, zm, zv)`` tuple: the ``GaussianSample`` output and the
        outputs of the 'zm' and 'zv' layers (presumably the mean and
        variance of q(z) -- confirm against ``GaussianSample``).
    """
    # If variables already exist under 'qz', later calls ask Dense to reuse
    # them instead of creating new ones.
    already_built = tf.get_collection(tf.GraphKeys.VARIABLES, scope='qz')
    share = len(already_built) > 0

    with tf.variable_scope('qz'):
        # Axis-first argument order: this file targets the pre-1.0
        # tf.concat signature.
        joint = tf.concat(1, (x, y), name='xy/concat')
        hidden = Dense(joint, 512, 'layer1', tf.nn.relu, reuse=share)
        hidden = Dense(hidden, 512, 'layer2', tf.nn.relu, reuse=share)
        mean = Dense(hidden, 64, 'zm', reuse=share)
        # softplus keeps the 'zv' output positive.
        var = Dense(hidden, 64, 'zv', tf.nn.softplus, reuse=share)
        sample = GaussianSample(mean, var, 'z')
    return sample, mean, var
def labeled_loss(x, px_logit, z, zm, zv, zm_prior, zv_prior, k=10):
    """Return the loss for one (x, y) pairing.

    The result is the sum of three terms:

    * ``-log_bernoulli_with_logits(x, px_logit)``
    * ``log_normal(z, zm, zv) - log_normal(z, zm_prior, zv_prior)``
    * ``-log(1 / k)``, a constant offset. This reads as the negative
      log-probability of a label under a uniform prior over ``k`` classes
      (TODO confirm against the model definition).

    Args:
        x: Observed data passed to ``log_bernoulli_with_logits``.
        px_logit: Logits passed to ``log_bernoulli_with_logits``.
        z: Value at which both ``log_normal`` terms are evaluated.
        zm, zv: Second and third arguments of the first ``log_normal`` call
            (presumably posterior mean and variance -- confirm).
        zm_prior, zv_prior: Second and third arguments of the second
            ``log_normal`` call (presumably prior mean and variance).
        k: Number of classes used for the constant term. The default of 10
            reproduces the previously hard-coded ``np.log(0.1)`` exactly;
            pass the same ``k`` that is given to ``qy_graph``.

    Returns:
        The combined loss, as returned by the tensor arithmetic above.
    """
    reconstruction = -log_bernoulli_with_logits(x, px_logit)
    log_ratio = log_normal(z, zm, zv) - log_normal(z, zm_prior, zv_prior)
    # 1.0 / k forces float division (also under Python 2) and equals the
    # literal 0.1 when k == 10, so default behaviour is unchanged.
    return reconstruction + log_ratio - np.log(1.0 / k)