-
Notifications
You must be signed in to change notification settings - Fork 1
/
mnist_dcgan.py
41 lines (36 loc) · 1.32 KB
/
mnist_dcgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import tensorflow as tf
from src.models.gans import DCGAN
import argparse
from src import argparser
def main(argv=None, flags=None):
    """Build the MNIST DCGAN and run the action selected by ``flags.mode``.

    Used as the entry point for ``tf.app.run``, which calls it with the
    leftover command-line arguments.

    Args:
        argv: Remaining command-line arguments supplied by ``tf.app.run``;
            not used by this function.
        flags: Parsed options object exposing ``mode`` and ``checkpoint``.
            Defaults to the module-level ``FLAGS`` assigned in the
            ``__main__`` block, so it only needs to be passed when ``main``
            is called directly (e.g. from another script or a test).

    Raises:
        ValueError: If ``flags.mode`` is not one of ``'generate'``,
            ``'predict'``, ``'save_weights'`` or ``'train'``. This is checked
            before any TensorFlow session or model is created.
    """
    if flags is None:
        flags = FLAGS  # assigned at module level by the __main__ block

    nepochs = 1000

    # Each supported mode maps to the DCGAN call it triggers.
    actions = {
        'train': lambda model: model.train(nepochs),
        'generate': lambda model: model.generate(),
        'predict': lambda model: model.predict(),
        'save_weights': lambda model: model.save_final_layer(),
    }
    if flags.mode not in actions:
        # Fail fast: an unknown mode used to build the whole model and then
        # exit without doing anything or reporting a problem.
        raise ValueError(
            'Unknown mode {!r}; expected one of: {}'.format(
                flags.mode, ', '.join(sorted(actions))))

    # Checkpoints and TensorBoard logs for a run share the same run name.
    run_name = str(flags.checkpoint)
    checkpoint_dir = os.path.join('/fred/oz012/Bruno/checkpoints', run_name)
    tensorboard_dir = os.path.join('/fred/oz012/Bruno/tensorboard', run_name)

    with tf.Session() as sess:
        dcgan = DCGAN(sess=sess,
                      in_height=28,
                      in_width=28,
                      nchannels=1,
                      batch_size=128,
                      noise_dim=100,
                      mode='original',
                      opt_pars=(0.0001, 0.5, 0.999),
                      d_iters=1,
                      data_name='mnist',
                      pics_save_names=('mnist_data_', 'mnist_gen_'),
                      checkpoint_dir=checkpoint_dir,
                      tensorboard_dir=tensorboard_dir)
        actions[flags.mode](dcgan)
if __name__ == '__main__':
    # Build the command-line interface via the project's shared argparse
    # helper; ``main`` reads the resulting module-level ``FLAGS``.
    cli_parser = argparse.ArgumentParser()
    FLAGS, _extra_args = argparser.add_args(cli_parser)

    # Emit TensorFlow's most detailed log output while running.
    tf.logging.set_verbosity(tf.logging.DEBUG)

    # Hand control to TensorFlow's app runner, which invokes ``main``.
    tf.app.run(main=main)