forked from shaohua0116/VAE-Tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ae.py
65 lines (54 loc) · 2.14 KB
/
ae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class Autoencoder(object):
    """Fully-connected autoencoder trained with a binary cross-entropy loss.

    The model is built in TensorFlow 1.x graph mode (placeholder + session).
    Encoder widths: input -> 256 -> 128 -> 64 -> n_z (all ELU); the decoder
    mirrors them back to the input size with a sigmoid output layer.

    NOTE(review): ``tf`` and ``fc`` are not defined in this file; ``fc`` is
    presumably a fully-connected layer helper (e.g. slim's
    ``fully_connected``) imported by the caller/notebook — confirm.

    Args:
        learning_rate: Learning rate for the Adam optimizer.
        batch_size: Stored on the instance; not used inside this class.
        n_z: Number of units in the latent code ``z``.
        input_dim: Size of each (flattened) input vector. When None (the
            default), the module-level ``input_dim`` global is used, which is
            what earlier versions of this class always did.
    """

    def __init__(self, learning_rate=1e-4, batch_size=64, n_z=16,
                 input_dim=None):
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_z = n_z
        self.input_dim = input_dim
        # Every instance builds its graph from scratch in an empty default
        # graph.
        tf.reset_default_graph()
        self.build()
        self.sess = tf.InteractiveSession()
        self.sess.run(tf.global_variables_initializer())

    def build(self):
        """Build the network, the reconstruction loss and the training op."""
        # Explicit constructor argument wins; otherwise fall back to the
        # module-level ``input_dim`` global (original behaviour).
        explicit_dim = getattr(self, 'input_dim', None)
        dim = explicit_dim if explicit_dim is not None else input_dim

        self.x = tf.placeholder(
            name='x', dtype=tf.float32, shape=[None, dim])

        # Encode: x -> z (a deterministic code; no mean/variance as in a VAE)
        f1 = fc(self.x, 256, scope='enc_fc1', activation_fn=tf.nn.elu)
        f2 = fc(f1, 128, scope='enc_fc2', activation_fn=tf.nn.elu)
        f3 = fc(f2, 64, scope='enc_fc3', activation_fn=tf.nn.elu)
        self.z = fc(f3, self.n_z, scope='enc_fc4', activation_fn=tf.nn.elu)

        # Decode: z -> x_hat. The sigmoid keeps every output in (0, 1) so it
        # can be used as a probability in the cross-entropy below.
        g1 = fc(self.z, 64, scope='dec_fc1', activation_fn=tf.nn.elu)
        g2 = fc(g1, 128, scope='dec_fc2', activation_fn=tf.nn.elu)
        g3 = fc(g2, 256, scope='dec_fc3', activation_fn=tf.nn.elu)
        self.x_hat = fc(g3, dim, scope='dec_fc4', activation_fn=tf.sigmoid)

        # Reconstruction loss: binary cross-entropy per example,
        #   H(x, x_hat) = -sum_i [x_i*log(x_hat_i) + (1-x_i)*log(1-x_hat_i)],
        # summed over features (axis=1) and then averaged over the batch.
        # epsilon keeps the log arguments strictly positive when x_hat
        # reaches exactly 0 or 1.
        # NOTE(review): this loss presumably expects x scaled to [0, 1] —
        # confirm against the data pipeline.
        epsilon = 1e-10
        per_example_loss = -tf.reduce_sum(
            self.x * tf.log(epsilon + self.x_hat) +
            (1 - self.x) * tf.log(epsilon + 1 - self.x_hat),
            axis=1,
        )
        self.recon_loss = tf.reduce_mean(per_example_loss)
        self.train_op = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate).minimize(self.recon_loss)
        self.losses = {
            'recon_loss': self.recon_loss,
        }

    def run_single_step(self, x):
        """Run one optimisation step on batch ``x``.

        Returns:
            dict mapping loss names (``'recon_loss'``) to their evaluated
            values for this batch (computed in the same ``sess.run`` call as
            the update).
        """
        _, losses = self.sess.run(
            [self.train_op, self.losses],
            feed_dict={self.x: x}
        )
        return losses

    def reconstructor(self, x):
        """Return the reconstruction ``x_hat`` of the batch ``x`` (x -> x_hat)."""
        x_hat = self.sess.run(self.x_hat, feed_dict={self.x: x})
        return x_hat

    def close(self):
        """Release the TensorFlow session held by this model.

        The constructor opens an ``InteractiveSession`` and nothing else in
        this class closes it; call this when the model is no longer needed.
        """
        self.sess.close()