-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcifar10_dcgan.py
31 lines (27 loc) · 933 Bytes
/
cifar10_dcgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import tensorflow as tf
from src.models.gans import DCGAN
import argparse
from src import argparser
def main(argv=None, nepochs=1000,
         checkpoint_root='/fred/oz012/Bruno/checkpoints/'):
    """Build a WGAN-GP DCGAN for CIFAR-10 and train it when requested.

    Intended as the ``main`` callback for ``tf.app.run()``. It reads the
    module-level ``FLAGS`` object, which is only assigned inside this
    script's ``if __name__ == '__main__':`` block. Calling ``main`` after a
    plain import therefore raises ``NameError``.

    Args:
        argv: Leftover command-line arguments passed in by ``tf.app.run``.
            Unused.
        nepochs: Number of epochs passed to ``DCGAN.train``. Defaults to the
            original hard-coded value.
        checkpoint_root: Directory under which the per-run checkpoint
            directory (named after ``FLAGS.checkpoint``) is placed. Defaults
            to the original hard-coded cluster path. A trailing slash is
            optional because the path is built with ``os.path.join``.
    """
    del argv  # unused; accepted only to match tf.app.run's calling convention
    checkpoint_dir = os.path.join(checkpoint_root, str(FLAGS.checkpoint))
    with tf.Session() as sess:
        dcgan = DCGAN(sess=sess,
                      in_height=32,
                      in_width=32,
                      nchannels=3,
                      batch_size=128,
                      noise_dim=100,
                      mode='wgan-gp',
                      # presumably (learning_rate, beta1, beta2) for the
                      # optimizer -- confirm against DCGAN's signature
                      opt_pars=(0.0001, 0., 0.9),
                      # presumably critic updates per generator update -- confirm
                      d_iters=5,
                      data_name='cifar10',
                      checkpoint_dir=checkpoint_dir)
        if FLAGS.mode == 'train':
            dcgan.train(nepochs)
        else:
            # Previously any other mode silently did nothing after building
            # the graph; make that visible to the user.
            tf.logging.warning(
                "Mode %r is not handled by this script (only 'train' is); "
                "nothing to do.", FLAGS.mode)
if __name__ == '__main__':
    # Parse the project's command-line options. main() reads the resulting
    # FLAGS as a module-level global, so it must be bound here before
    # tf.app.run() invokes main().
    FLAGS, _ = argparser.add_args(argparse.ArgumentParser())
    # Emit DEBUG-level TensorFlow log output for this run.
    tf.logging.set_verbosity(tf.logging.DEBUG)
    # tf.app.run() looks up and calls this module's main().
    tf.app.run()