diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e306c83
--- /dev/null
+++ b/README.md
@@ -0,0 +1,134 @@
+# Generative Adversarial Networks
+
+**Generative Adversarial Network**
+
+arXiv: [https://arxiv.org/abs/1406.2661](https://arxiv.org/abs/1406.2661)
+
+**Conditional Generative Adversarial Nets**
+
+arXiv: https://arxiv.org/abs/1411.1784
+
++ [机器之心 (Synced): a complete theoretical derivation and implementation of GANs](https://github.com/jiqizhixin/ML-Tutorial-Experiment)
+
++ [A complete theoretical derivation and implementation of GANs](https://www.jiqizhixin.com/articles/2017-10-1-1)
+
+  A very good theoretical derivation: the language is accessible and the formulas are clear. Well worth a read.
+
++ [Hung-yi Lee's (NTU) GAN lecture video](https://www.youtube.com/watch?v=0CKeqXl5IY0) (YouTube; requires a VPN in mainland China)
+
+## github
+
++ [nightrome/really-awesome-gan](https://github.com/nightrome/really-awesome-gan)
+
++ [wiseodd/generative-models](https://github.com/wiseodd/generative-models)
+
++ [bstriner/keras-adversarial](https://github.com/bstriner/keras-adversarial/)
+
++ [kvfrans/generative-adversial](https://github.com/kvfrans/generative-adversial)
+
++ [osh/KerasGAN](https://github.com/osh/KerasGAN)
+
++ [Eyyub/tensorflow-cdcgan](https://github.com/Eyyub/tensorflow-cdcgan)
+
++ [soumith/ganhacks (practical tips for training GANs)](https://github.com/soumith/ganhacks)
+
+## How training works, in plain language
+
+We train the generator and the discriminator together so that both grow stronger, alternating between them so that neither network gets too far ahead of the other.
+
+Why does alternating make both networks stronger, rather than training each one separately as far as it will go?
+
+Because if one network becomes too strong while the other lags behind, both end up weaker. A network competing against a poor opponent has no idea its opponent is poor, so it concludes it is doing excellently; the self-satisfied network simply overfits to its inferior adversary.
+
+**Training the discriminator**
+
+Show it an image from the training set and an image produced by the generator: it should output 0 when given a generated image and 1 when given a real one.
+
+Technically, this is just a cross-entropy loss handed to an off-the-shelf optimizer. Piece of cake!
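+
+As a minimal sketch (not this repo's exact code; `g`, `d`, `real_batch`, `batch_size`, and `latent_size` stand in for a generator model, a compiled binary-classifier discriminator, a batch of real images, and the usual hyperparameters), one discriminator update in Keras looks roughly like this:
+
+```python
+import numpy as np
+
+# g, d, real_batch, batch_size, latent_size are assumed to exist (see lead-in).
+noise = np.random.uniform(-1, 1, size=(batch_size, latent_size))
+fake_batch = g.predict(noise)                       # generator's forgeries
+x = np.concatenate([fake_batch, real_batch])        # half fake, half real
+y = np.array([0] * batch_size + [1] * batch_size)   # 0 = generated, 1 = real
+d_loss = d.train_on_batch(x, y)                     # one cross-entropy step
+```
+
+The training loops in `gan.py` and `cgan_keras.py` below follow exactly this pattern.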
+
+**Training the generator**
+
+The generator must try to make the discriminator output 1 when it is shown a generated image.
+
+Now comes the interesting part.
+
+    Suppose the generator produces an image and the discriminator decides it has a 0.4 probability of being real. How should the generator adjust the image to raise that probability, say to 0.41?
+
+The answer:
+
+To train the generator, the discriminator has to tell the generator how to adjust so that the images it produces become more realistic.
+
+The generator has to ask the discriminator for advice!
+
+Intuitively, the discriminator tells the generator how much to nudge each pixel to make the whole image a little more realistic.
+
+Technically, you backpropagate the gradient of the discriminator's output with respect to the generated image. Doing so gives you a gradient tensor with exactly the same shape as the image.
+
+If you added those gradients to the generated image, it would look a little more realistic to the discriminator.
+
+But we don't just add the gradients to the image.
+
+Instead, we backpropagate these image gradients further, into the weights that make up the generator, so that the generator learns how to produce this new image itself.
+
+To repeat: to generate good images, you have to show your work to the teacher and get feedback!
+
+It would be cruel if the discriminator refused to help, because the generator's job (actually producing the images) is harder than the discriminator's!
+
+That is how the generator is trained.
+
+Train the generator and the discriminator back and forth like this until they reach an equilibrium.
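+
+In Keras, this "ask the teacher for advice" step is usually implemented by stacking the generator and a frozen discriminator into one model and training the stack against the label 1, which is exactly what `gan.py` below does. A sketch, assuming `g_on_d` is that stacked model d(g(z)) and `d` the discriminator from the previous sketch:
+
+```python
+import numpy as np
+
+# g_on_d and d are assumed handles (see lead-in).
+noise = np.random.uniform(-1, 1, size=(batch_size, latent_size))
+d.trainable = False                     # freeze the teacher: only g's weights move
+g_loss = g_on_d.train_on_batch(noise, np.ones((batch_size, 1)))  # "make d say 1"
+d.trainable = True                      # unfreeze before the next d update
+```
+
+Gradients flow backwards through the (unchanged) discriminator into the generator's weights, which is precisely the "advice" described above.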
+
+If you are confused, here is an intuitive picture of the two networks learning to talk in the early, clueless stage:
+
+    G: I have a face image here. Is it realistic enough compared with the ones you have seen?
+
+    D: Fairly realistic, but also fairly like something you generated. (For real images the discriminator outputs a probability around 0.4.) I am not sure, but I would guess you gave me a generated image.
+
+    G: You guessed right! It is one I generated. How should I adjust it to make it more realistic?
+
+    D: Let me think (actually running backpropagation in its head)... I think you should add a pair of eyes to the image. Face images usually have eyes.
+
+    (Technically speaking: I think you should increase the intensity of pixel 0 by 1, decrease pixel 1 by 5, ..., and increase pixel 4095 by 8.)
+
+    G: Got it. (Backpropagates those gradients into all of its weights.)
+
+**Dumbness**
+
+The dialogue above happens early on, when both sides are still naive. The discriminator is not even sure a face should have eyes; it almost called a generated, eyeless face real! (A well-trained discriminator would reject that image outright, because a face image must have eyes!)
+
+After some training, both become smarter and smarter, until they approach a highly trained, near-optimal state.
+
+Here is what the dialogue feels like once both networks are near that optimum:
+
+    G: I have a face image here. Is it realistic enough compared with the face images you have seen?
+
+    D: This image looks genuinely real. (For real images the discriminator now outputs a probability around 0.5.) But whether it actually is real, I have no idea, because you have obviously gotten very good at generating realistic images.
+
+    G: It is one I generated. I know it already looks real, but I want more. How should I tweak it to make it even more realistic?
+
+    D: Let me think (backpropagation again)... I think the image already has everything I would expect. It looks very real to me: it clearly has eyes, a mouth, ears and hair, and it shows a young boy's face. I have nothing to suggest. But if you insist, you could remove the young boy's stubble.
+
+    (Technically speaking: I think you should increase the intensity of pixel 0 by 6, decrease pixel 1 by 7, ..., and increase pixel 4095 by 2.)
+
+    G: Got it. (Backpropagates those gradients into all of its weights.)
+
+**Cleverness**
+
+Once both are well trained, the generator produces realistic images and the discriminator can no longer tell generated images from real ones.
+
+Without any supervision, both have come to understand stubble, eyes, mouths, hair, and young faces.
+
+You have reached an equilibrium.
+
+If you keep pushing the generator to make its images ever more "realistic", you will most likely just overfit to the discriminator's quirks, such as its opinion that a young boy should never have stubble. The discriminator will form opinions like that, and they may be wrong, just as you should not lean too heavily on any one teacher's advice. Training further gains you nothing.
+
+**Conclusion**
+
+The two networks are not merely fighting; they have to cooperate toward a common goal. Throughout training, the discriminator has to teach the generator how to make small adjustments to its output, while at the same time learning to be a better teacher.
+
+They grow stronger together and, in the ideal case, reach an equilibrium.
+
+
+## The convergence problem
+
+One of the main problems with GANs is convergence. Even with a well-tuned architecture, training stability is not guaranteed: more epochs do not necessarily mean better samples, and you cannot tell when to stop training. In other words, the loss values do not correlate with image quality.
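+
+Because the loss is uninformative, the scripts in this repo judge progress by eye instead: every few hundred iterations they write a grid of samples to disk. A minimal sketch of that idea (`dump_samples` is a hypothetical helper; it assumes `combine_images` from the scripts below and a trained generator `g`):
+
+```python
+import numpy as np
+from PIL import Image
+
+def dump_samples(g, step, latent_size=100):
+    # Save a 10x10 grid of generator samples for visual inspection.
+    samples = g.predict(np.random.uniform(-1, 1, size=(100, latent_size)))
+    grid = combine_images(samples) * 127.5 + 127.5   # map tanh range back to [0, 255]
+    Image.fromarray(grid.astype(np.uint8)).save('./images/step_{}.png'.format(step))
+```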
diff --git a/cgan_keras.py b/cgan_keras.py
new file mode 100644
index 0000000..828508e
--- /dev/null
+++ b/cgan_keras.py
@@ -0,0 +1,247 @@
+# -*- coding: utf-8 -*-
+
+"""
+@author:sunwill
+
+An implementation of a conditional generative adversarial network using Keras
+
+reference paper:
+arXiv:https://arxiv.org/abs/1411.1784
+
+"""
+from __future__ import print_function
+
+import math
+import numpy as np
+from PIL import Image
+from keras.datasets.mnist import load_data
+from keras.optimizers import Adam
+from keras.layers import concatenate
+from keras.models import Sequential, Model
+from keras.layers import Dense, Reshape, Conv2D, UpSampling2D, Input, Flatten, LeakyReLU, Dropout
+from keras.losses import binary_crossentropy
+from keras.utils import plot_model, to_categorical
+
+image_size = 28
+image_channel = 1
+
+latent_size = 100
+y_dim = 10
+batch_size = 64
+epochs = 30
+learning_rate = 2e-4
+
+
+def generator():
+    cnn = Sequential()
+
+    cnn.add(Dense(1024, input_dim=latent_size + y_dim, activation='tanh'))
+    cnn.add(Dense(128 * 7 * 7, activation='tanh'))
+    cnn.add(Reshape((7, 7, 128)))
+
+    # upsample to (14, 14, ...)
+    cnn.add(UpSampling2D(size=(2, 2)))
+    cnn.add(Conv2D(256, 5, padding='same',
+                   activation='tanh',
+                   kernel_initializer='glorot_normal'))
+
+    # upsample to (28, 28, ...)
+    cnn.add(UpSampling2D(size=(2, 2)))
+    cnn.add(Conv2D(128, 5, padding='same',
+                   activation='tanh',
+                   kernel_initializer='glorot_normal'))
+
+    # reduce to a single output channel
+    cnn.add(Conv2D(1, 2, padding='same',
+                   activation='tanh',
+                   kernel_initializer='glorot_normal'))
+
+    input1 = Input(shape=(latent_size,))
+    input2 = Input(shape=(y_dim,))
+    inputs = concatenate([input1, input2], axis=1)
+    outs = cnn(inputs)
+
+    return Model(inputs=[input1, input2], outputs=outs)
+
+
+def discriminator():
+
+    cnn = Sequential()
+
+    cnn.add(Conv2D(32, 3, padding='same', strides=2,
+                   input_shape=(28, 28, image_channel + y_dim)))
+    cnn.add(LeakyReLU())
+    cnn.add(Dropout(0.3))
+
+    cnn.add(Conv2D(64, 3, padding='same', strides=1))
+    cnn.add(LeakyReLU())
+    cnn.add(Dropout(0.3))
+
+    cnn.add(Conv2D(128, 3, padding='same', strides=2))
+    cnn.add(LeakyReLU())
+    cnn.add(Dropout(0.3))
+
+    cnn.add(Conv2D(256, 3, padding='same', strides=1))
+    cnn.add(LeakyReLU())
+    cnn.add(Dropout(0.3))
+
+    cnn.add(Flatten())
+    cnn.add(Dense(1, activation='sigmoid'))
+
+    inputs = Input(shape=(image_size, image_size, image_channel + y_dim))
+    outs = cnn(inputs)
+
+    return Model(inputs=inputs, outputs=outs)
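+
+
+# Editor's note on the conditioning, describing what the code in this file does:
+# the one-hot label y enters both networks. The generator simply concatenates
+# y (10,) with the noise z (100,) into a single 110-dim input vector. The
+# discriminator's input is an image, so y must first be given a spatial layout:
+# in train() below it is reshaped to (1, 1, 10), tiled to (28, 28, 10), and
+# concatenated with the image along the channel axis, producing the
+# (28, 28, 11)-shaped input declared above (image_channel + y_dim).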
+
+
+def disc_on_gen(g, d):
+
+    input1 = Input(shape=(image_size, image_size, y_dim))
+    input2 = Input(shape=(latent_size,))
+    input3 = Input(shape=(y_dim,))
+
+    g_out = g([input2, input3])
+    d_input = concatenate([g_out, input1], axis=3)
+    d.trainable = False
+    outs = d(d_input)
+    model = Model(inputs=[input1, input2, input3], outputs=outs)
+    return model
+
+
+def combine_images(images):
+    num = images.shape[0]
+    images = np.reshape(images, (-1, 28, 28))
+    width = int(math.sqrt(num))
+    height = int(math.ceil(float(num) / width))
+    shape = images.shape[1:3]
+    image = np.zeros((height * shape[0], width * shape[1]),
+                     dtype=images.dtype)
+    for index, img in enumerate(images):
+        i = int(index / width)
+        j = index % width
+        image[i * shape[0]:(i + 1) * shape[0], j * shape[1]:(j + 1) * shape[1]] = img[:, :]
+    return image
+
+
+def random_sample(x, y, batch_size):
+    x_bs = []
+    y_bs = []
+    i = 0
+    while i < batch_size:
+        rand = np.random.randint(0, x.shape[0])
+        x_bs.append(x[rand])
+        y_bs.append(y[rand])
+        i += 1
+    return np.array(x_bs), np.array(y_bs)
+
+
+def train():
+    (X_train, y_train), (X_test, y_test) = load_data()
+    y_train = to_categorical(y_train)
+    print(X_train.shape)  ## (60000, 28, 28)
+    print(y_train.shape)  ## (60000, 10)
+
+    num_samples = X_train.shape[0]
+
+    X_train = np.expand_dims(X_train, axis=3)
+    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
+
+    g = generator()
+    d = discriminator()
+    d_on_g = disc_on_gen(g, d)
+
+    g_optimizer = Adam(lr=learning_rate)
+    d_optimizer = Adam(lr=learning_rate)
+
+    g.compile(loss=binary_crossentropy, optimizer=g_optimizer)
+    d_on_g.compile(loss='binary_crossentropy', optimizer=g_optimizer)
+
+    d.trainable = True
+    d.compile(loss='binary_crossentropy', optimizer=d_optimizer, metrics=['accuracy'])
+
+    plot_model(g, to_file='./model/cgan_generator.png', show_shapes=True)
+    plot_model(d, to_file='./model/cgan_discriminator.png', show_shapes=True)
+    plot_model(d_on_g, to_file='./model/cgan.png', show_shapes=True)
+    p = 0
+    for epoch in range(epochs):  ## outer loop over epochs
+        print('epoch {}/{}'.format(epoch + 1, epochs))
+
+        for i in range(num_samples // batch_size):  ## iterations within one epoch
+            ## sample random noise (uniform in [-1, 1])
+            noise = np.random.uniform(-1, 1, size=(batch_size, latent_size))
+            ## randomly sample real images together with their labels
+            x_bs, y_bs = random_sample(X_train, y_train, batch_size)
+
+            generate_images = g.predict([noise, y_bs], verbose=0)
+            # print(generate_images.shape)
+            ## every 500 iterations, save the generated images
+            if i % 500 == 0:
+                images = combine_images(generate_images)
+                images = images * 127.5 + 127.5
+                Image.fromarray(images.astype(np.uint8)).save('./images/generated_{}_{}.png'.format(str(epoch + 1), i))
+
+            ## train the discriminator
+            xs = np.concatenate([generate_images, x_bs])
+            ys = np.concatenate([y_bs, y_bs])
+            ys = np.reshape(ys, newshape=[-1, 1, 1, y_dim])
+            ys = np.tile(ys, [1, 28, 28, 1])
+            X = np.concatenate([xs, ys], axis=3)
+
+            y = np.array([0] * batch_size + [1] * batch_size)
+
+            d_loss, acc = d.train_on_batch(X, y)
+            if i % 100 == 0:
+                print('epoch {}, iter {}, d_loss = {}, acc = {}'.format(epoch + 1, i, d_loss, acc))
+
+            ## train the generator, keeping the discriminator fixed
+            d.trainable = False
+
+            g_loss = d_on_g.train_on_batch([ys[:batch_size], noise, y_bs], np.array([1] * batch_size))
+            if i % 100 == 0:
+                print('epoch {}, iter {}, g_loss = {}'.format(epoch + 1, i, g_loss))
+
+            d.trainable = True
+            if i % 500 == 0:
+                noise = np.random.uniform(-1, 1, size=(100, latent_size))
+                ys = np.zeros(shape=(100, y_dim))
+                for c in range(10):
+                    ys[c * 10:(c + 1) * 10, c] = 1
+                generate_images = g.predict([noise, ys])
+                images = combine_images(generate_images)
+                images = images * 127.5 + 127.5
+                Image.fromarray(images.astype(np.uint8)).save('./logs/cgan_{}.png'.format(p))
+                p += 1
+
+        ## overwritten every epoch; generate() below loads these paths
+        g.save_weights('./logs/cgan_generator.h5')
+        d.save_weights('./logs/cgan_discriminator.h5')
+
+
+def generate(batch_size, flag=True):
+    g = generator()
+    g.compile(optimizer=Adam(lr=learning_rate), loss='binary_crossentropy')
+    g.load_weights('./logs/cgan_generator.h5')
+    ## sample labels for the images to be generated
+    ys = to_categorical(np.random.randint(0, 10, size=(batch_size * 10,)), num_classes=y_dim)
+    if flag:  ## generate a surplus of images and keep the ones d rates highest
+        d = discriminator()
+        d.compile(optimizer=Adam(lr=learning_rate), loss='binary_crossentropy')
+        d.load_weights('./logs/cgan_discriminator.h5')
+        noise = np.random.uniform(-1, 1, size=(batch_size * 10, latent_size))
+        generate_images = g.predict([noise, ys])
+        ## d expects the tiled label maps concatenated onto the image channels
+        ys_maps = np.tile(np.reshape(ys, [-1, 1, 1, y_dim]), [1, 28, 28, 1])
+        d_pred = d.predict(np.concatenate([generate_images, ys_maps], axis=3))
+        index = np.reshape(np.arange(0, batch_size * 10), (-1, 1))
+        index_with_prob = list(np.append(index, d_pred, axis=1))
+        index_with_prob.sort(key=lambda x: x[1], reverse=True)  ## sort by probability, not index
+        nices = np.zeros(shape=((batch_size,) + generate_images.shape[1:]))
+        for i in range(batch_size):
+            idx = int(index_with_prob[i][0])
+            nices[i] = generate_images[idx]
+        images = combine_images(nices)
+    else:
+        noise = np.random.uniform(-1, 1, size=(batch_size, latent_size))
+        generate_images = g.predict([noise, ys[:batch_size]])
+        images = combine_images(generate_images)
+
+    images = images * 127.5 + 127.5
+    Image.fromarray(images.astype(np.uint8)).save('./generated_images.png')
+
+
+train()
+
+# generate(64)
diff --git a/cgan_tf.py b/cgan_tf.py
new file mode 100644
index 0000000..562fa13
--- /dev/null
+++ b/cgan_tf.py
@@ -0,0 +1,175 @@
+# -*- coding: utf-8 -*-
+
+"""
+@author:sunwill
+
+An implementation of a conditional generative adversarial network using TensorFlow
+
+reference paper:
+arXiv:https://arxiv.org/abs/1411.1784
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+from PIL import Image
+import tensorflow as tf
+from tensorflow.examples.tutorials.mnist import input_data
+import tensorflow.contrib.layers as tcl
+
+save_dir = './logs/'
+
+learning_rate = 0.0002
+batch_size = 64
+# X_dim = 784
+image_size = 28
+y_dim = 10
+z_dim = 100
+
+mnist = input_data.read_data_sets('./datasets/mnist/', one_hot=True)
+images = mnist.train.images
+labels = mnist.train.labels
+
+print(images.shape)
+print(labels.shape)
+
+
+def get_shape(tensor):  # static shape
+    return tensor.get_shape().as_list()
+
+
+def lkrelu(x, slope=0.01):  # leaky ReLU
+    return tf.maximum(slope * x, x)
+
+
+## building the generator
+def generator(zs, ys, reuse=False, is_training=True):
+    with tf.variable_scope('Generator', initializer=tf.truncated_normal_initializer(stddev=0.02), reuse=reuse):
+        # The contrib layers create their variables automatically and infer
+        # shapes from the input.
+        x = tf.concat([zs, ys], axis=1)
+
+        g = tcl.fully_connected(x, 7 * 7 * 512, activation_fn=lkrelu, normalizer_fn=tcl.batch_norm)
+        g = tf.reshape(g, (-1, 7, 7, 512))
+
+        print(get_shape(g))
+        g = tcl.conv2d(g, 128, 3, stride=1,  # (batch, 7, 7, 128)
+                       activation_fn=lkrelu, normalizer_fn=tcl.batch_norm, padding='SAME',
+                       weights_initializer=tf.random_normal_initializer(0, 0.02))
+        print(get_shape(g))
+        g = tcl.conv2d_transpose(g, 64, 4, stride=2,  # (batch, 14, 14, 64)
+                                 activation_fn=lkrelu, normalizer_fn=tcl.batch_norm, padding='SAME',
+                                 weights_initializer=tf.random_normal_initializer(0, 0.02))
+        print(get_shape(g))
+        g = tcl.conv2d_transpose(g, 1, 4, stride=2,  # (batch, 28, 28, 1)
+                                 activation_fn=tf.nn.tanh, padding='SAME',
+                                 weights_initializer=tf.random_normal_initializer(0, 0.02))
+        print(get_shape(g))
+        return g
+
+
+## building the discriminator
+def discriminator(xs, ys, reuse=False, is_training=True):
+    with tf.variable_scope('Discriminator', initializer=tf.truncated_normal_initializer(stddev=0.02), reuse=reuse):
+        # A typical convolutional network for classifying images; the tiled
+        # label maps are concatenated onto the first conv layer's feature maps.
+        x = tf.layers.conv2d(xs, 32, kernel_size=[5, 5], padding='SAME')
+        x = tf.concat([x, tf.tile(tf.reshape(ys, [-1, 1, 1, get_shape(ys)[-1]]),
+                                  [1, tf.shape(x)[1], tf.shape(x)[2], 1])], axis=3)
+        x = lkrelu(x)
+        x = tf.layers.conv2d(x, 16, kernel_size=[5, 5], padding='SAME')
+        x = tf.layers.batch_normalization(x, axis=3, training=is_training)
+        x = lkrelu(x)
+        x = tf.layers.conv2d(x, 8, kernel_size=[5, 5])
+        x = tf.layers.batch_normalization(x, axis=3, training=is_training)
+        x = lkrelu(x)
+        x = tf.contrib.layers.flatten(x)
+        x = tf.layers.dense(x, 1024)
+        x = tf.layers.dropout(x, rate=0.5)
+        x = tf.layers.dense(x, 1)
+        return tf.nn.sigmoid(x), x
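+
+
+# Editor's note: the discriminator returns both the sigmoid probability and the
+# raw logit. The probabilities are only used for the accuracy metric below; the
+# losses are built from the logits via tf.nn.sigmoid_cross_entropy_with_logits,
+# which fuses the sigmoid and the cross-entropy for numerical stability.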
+
+
+def sample_data(z_dim, batch_size):
+    return np.random.uniform(-1, 1, size=(batch_size, z_dim))
+
+
+def combine_images(images):
+    images = np.reshape(images, (-1, 28, 28))
+    num = images.shape[0]
+    width = int(math.sqrt(num))
+    height = int(math.ceil(float(num) / width))
+    shape = images.shape[1:]
+    image = np.zeros((height * shape[0], width * shape[1]),
+                     dtype=images.dtype)
+    for index, img in enumerate(images):
+        i = int(index / width)
+        j = index % width
+        image[i * shape[0]:(i + 1) * shape[0], j * shape[1]:(j + 1) * shape[1]] = img
+    return image
+
+
+noise = tf.placeholder(tf.float32, shape=(None, z_dim))
+X = tf.placeholder(tf.float32, shape=(None, image_size, image_size, 1))
+y = tf.placeholder(tf.float32, shape=(None, y_dim))
+
+generated_images = generator(noise, y)  # pixel values between -1 and 1
+
+d_real, d_real_logit = discriminator(X, y)
+d_fake, d_fake_logit = discriminator(generated_images, y, reuse=True)
+
+d_out = tf.concat(values=[d_real, d_fake], axis=0)
+d_true = tf.concat(values=[tf.ones_like(d_real, dtype=tf.int64), tf.zeros_like(d_fake, dtype=tf.int64)], axis=0)
+corrected = tf.equal(tf.cast(tf.greater(d_out, 0.5), tf.int64), d_true)
+d_accuracy = tf.reduce_mean(tf.cast(corrected, tf.float32))
+
+d_loss_real = tf.reduce_mean(
+    tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real_logit), logits=d_real_logit))
+d_loss_fake = tf.reduce_mean(
+    tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake_logit), logits=d_fake_logit))
+d_loss = d_loss_real + d_loss_fake
+
+gen_loss = tf.reduce_mean(
+    tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_fake_logit), logits=d_fake_logit))
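+
+# Editor's note: gen_loss labels the *fake* logits as 1, i.e. it maximizes
+# log D(G(z)) instead of minimizing log(1 - D(G(z))). This is the standard
+# "non-saturating" generator loss from Goodfellow et al. (2014, the paper cited
+# in the README); it gives the generator much stronger gradients early in
+# training, when the discriminator rejects everything.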
+
+gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator')
+disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator')
+
+gen_train_op = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5).minimize(gen_loss, var_list=gen_vars)
+disc_train_op = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5).minimize(d_loss, var_list=disc_vars)
+
+init = tf.global_variables_initializer()
+saver = tf.train.Saver()
+
+with tf.Session() as sess:
+    sess.run(init)
+    for epoch in range(100000):
+        if epoch % 1000 == 0:
+            n_sample = 100
+            noise_sample = sample_data(z_dim, n_sample)
+            y_sample = np.zeros(shape=(n_sample, y_dim))
+            for c in range(10):
+                y_sample[c * 10:(c + 1) * 10, c] = 1
+            gen_images = sess.run(generated_images, feed_dict={noise: noise_sample, y: y_sample})
+            images = combine_images(gen_images)
+            images = images * 127.5 + 127.5
+            Image.fromarray(images.astype(np.uint8)).save(save_dir + 'generated_{}.png'.format(epoch // 1000))
+
+        x_batch, y_batch = mnist.train.next_batch(batch_size)
+        x_batch = np.reshape(x_batch, [-1, image_size, image_size, 1])
+        ## NOTE: input_data already scales the pixels to [0, 1]. In the author's
+        ## experiments the network trained without the rescaling below and
+        ## failed with it, so it is left commented out.
+        # x_batch = (x_batch.astype(np.uint8) - 127.5) / 127.5
+        noise_batch = sample_data(z_dim, batch_size)
+        d_loss_, _, acc = sess.run([d_loss, disc_train_op, d_accuracy],
+                                   feed_dict={X: x_batch, y: y_batch, noise: noise_batch})
+        g_loss_, _ = sess.run([gen_loss, gen_train_op], feed_dict={noise: noise_batch, y: y_batch})
+
+        if epoch % 100 == 0:
+            print('epoch {}, d_loss={}, acc = {}, g_loss={}'.format(epoch, d_loss_, acc, g_loss_))
+            saver.save(sess, './model/cgan_tf.ckpt')
diff --git a/gan.py b/gan.py
new file mode 100644
index 0000000..a27aec9
--- /dev/null
+++ b/gan.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+
+"""
+@author:sunwill
+
+An implementation of a generative adversarial network using Keras
+
+reference paper:
+arXiv:https://arxiv.org/abs/1406.2661
+"""
+from __future__ import print_function
+
+import math
+import numpy as np
+from PIL import Image
+from keras.datasets.mnist import load_data
+from keras.optimizers import SGD
+from keras.losses import binary_crossentropy
+from keras.utils import plot_model
+from keras.models import Model, Sequential
+from keras.layers import Dense, Input, Conv2D, BatchNormalization, MaxPooling2D, Flatten, Reshape, \
+    UpSampling2D
+
+
+image_size = 28
+image_channel = 1
+
+latent_size = 100
+batch_size = 32
+epochs = 30
+learning_rate = 1e-3
+
+
+## the generator g
+def generator(latent_size):
+    input = Input(shape=(latent_size,))
+
+    x = Dense(128, activation='tanh')(input)
+    x = Dense(128 * 7 * 7, activation='tanh')(x)
+    x = BatchNormalization(axis=1)(x)
+    x = Reshape((7, 7, 128))(x)
+    x = UpSampling2D(size=(2, 2))(x)
+    x = Conv2D(64, kernel_size=(3, 3), activation='tanh', padding='same')(x)
+    x = UpSampling2D(size=(2, 2))(x)
+    output = Conv2D(1, kernel_size=(5, 5), activation='tanh', padding='same')(x)
+    model = Model(inputs=input, outputs=output)
+
+    return model
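+
+
+# Editor's note: the final tanh keeps the generated pixels in [-1, 1], which
+# matches the (X - 127.5) / 127.5 normalization applied to the real MNIST
+# images below. Real and generated images therefore live in the same value
+# range by the time they reach the discriminator.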
+
+
+## the discriminator d
+def discriminator(image_size, image_channel):
+    input = Input(shape=(image_size, image_size, image_channel))
+
+    x = Conv2D(64, kernel_size=(5, 5), activation='tanh', padding='same')(input)
+    x = MaxPooling2D(pool_size=(2, 2))(x)
+
+    x = Conv2D(128, kernel_size=(5, 5), activation='tanh', padding='same')(x)
+    x = MaxPooling2D(pool_size=(2, 2))(x)
+
+    x = Flatten()(x)
+
+    x = Dense(1024, activation='tanh')(x)
+    output = Dense(1, activation='sigmoid')(x)
+
+    model = Model(inputs=input, outputs=output)
+    return model
+
+
+## stack the generator and the discriminator into one model for training g
+def generator_on_discriminator(g, d):
+    ## chain the generator and discriminator architectures defined above into
+    ## one large network that classifies the generated images
+    model = Sequential()
+    ## add the generator first, then freeze d by making it non-trainable;
+    ## with d fixed, training this stack optimizes only the generator, using
+    ## the discriminator's verdict on the generated images as the signal
+    model.add(g)
+    d.trainable = False
+    model.add(d)
+    return model
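+
+
+# Editor's note on the freezing trick: in the Keras versions this script
+# targets, the trainable flag is captured when a model is *compiled*. d is
+# compiled on its own while trainable=True further down, so d.train_on_batch
+# keeps updating it; the stacked model is compiled after d.trainable is set to
+# False, so g_on_d.train_on_batch only ever updates the generator's weights.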
+
+
+def random_sample(data, batch_size):
+    ret_data = []
+    i = 0
+    while i < batch_size:
+        rand = np.random.randint(0, data.shape[0])
+        ret_data.append(data[rand])
+        i += 1
+    return np.array(ret_data)
+
+
+def combine_images(images):
+    ## tile the generated images into one grid image
+    num = images.shape[0]
+    width = int(math.sqrt(num))
+    height = int(math.ceil(float(num) / width))
+    shape = images.shape[1:3]
+    image = np.zeros((height * shape[0], width * shape[1]),
+                     dtype=images.dtype)
+    for index, img in enumerate(images):
+        i = int(index / width)
+        j = index % width
+        image[i * shape[0]:(i + 1) * shape[0], j * shape[1]:(j + 1) * shape[1]] = \
+            img[:, :, 0]
+    return image
+
+
+(X_train, y_train), (X_test, y_test) = load_data()
+print(X_train.shape)  ## (60000, 28, 28)
+print(X_test.shape)  ## (10000, 28, 28)
+
+num_samples = X_train.shape[0]
+
+X_train = np.expand_dims(X_train, axis=3)
+X_train = (X_train.astype(np.float32) - 127.5) / 127.5
+
+g = generator(latent_size)
+d = discriminator(image_size, image_channel)
+g_on_d = generator_on_discriminator(g, d)
+
+g_optimizer = SGD(lr=0.001, momentum=0.9, nesterov=True)
+d_optimizer = SGD(lr=0.001, momentum=0.9, nesterov=True)
+
+g.compile(loss=binary_crossentropy, optimizer='SGD')
+g_on_d.compile(loss='binary_crossentropy', optimizer=g_optimizer)
+
+d.trainable = True
+d.compile(loss='binary_crossentropy', optimizer=d_optimizer, metrics=['accuracy'])
+
+plot_model(g, to_file='./model/gan_generator.png', show_shapes=True)
+plot_model(d, to_file='./model/gan_discriminator.png', show_shapes=True)
+
+for epoch in range(epochs):  ## outer loop over epochs
+    print('epoch {}/{}'.format(epoch + 1, epochs))
+
+    for i in range(num_samples // batch_size):  ## iterations within one epoch
+        ## sample random noise (uniform in [-1, 1])
+        # noise = np.random.uniform(-1, 1, size=(batch_size, image_size, image_size, image_channel))
+        noise = np.random.uniform(-1, 1, size=(batch_size, 100))
+        ## randomly sample real images
+        real_samples = random_sample(X_train, batch_size)
+
+        generate_images = g.predict(noise, verbose=0)
+        # print(generate_images.shape)
+        ## every 500 iterations, save the generated images
+        if i % 500 == 0:
+            images = combine_images(generate_images)
+            images = images * 127.5 + 127.5
+            Image.fromarray(images.astype(np.uint8)).save('./images/generated_{}_{}.png'.format(str(epoch + 1), i))
+
+        ## train the discriminator
+        X = np.concatenate([generate_images, real_samples])
+        y = np.array([0] * batch_size + [1] * batch_size)
+
+        d_loss, acc = d.train_on_batch(X, y)
+        if i % 100 == 0:
+            print('epoch {}, iter {}, d_loss = {}, acc = {}'.format(epoch + 1, i, d_loss, acc))
+
+        ## train the generator, keeping the discriminator fixed
+        d.trainable = False
+
+        g_loss = g_on_d.train_on_batch(noise, np.array([1] * batch_size))
+        if i % 100 == 0:
+            print('epoch {}, iter {}, g_loss = {}'.format(epoch + 1, i, g_loss))
+
+        d.trainable = True
+
+    g.save_weights('./logs/generator_epoch{}.h5'.format(epoch))
+    d.save_weights('./logs/discriminator_epoch{}.h5'.format(epoch))
+
+
+def generate(batch_size, flag=True):
+    g = generator(latent_size)
+    g.compile(optimizer='SGD', loss='binary_crossentropy')
+    g.load_weights('./logs/generator_epoch{}.h5'.format(epochs - 1))
+    if flag:  ## generate a surplus of images and keep the ones d rates highest
+        d = discriminator(image_size, image_channel)
+        d.compile(optimizer='SGD', loss='binary_crossentropy')
+        d.load_weights('./logs/discriminator_epoch{}.h5'.format(epochs - 1))
+        noise = np.random.uniform(-1, 1, size=(batch_size * 10, latent_size))
+        generate_images = g.predict(noise)
+        d_pred = d.predict(generate_images)
+        index = np.reshape(np.arange(0, batch_size * 10), (-1, 1))
+        index_with_prob = list(np.append(index, d_pred, axis=1))
+        index_with_prob.sort(key=lambda x: x[1], reverse=True)  ## sort by probability, not index
+        nices = np.zeros(shape=((batch_size,) + generate_images.shape[1:]))
+        for i in range(batch_size):
+            idx = int(index_with_prob[i][0])
+            nices[i] = generate_images[idx]
+        images = combine_images(nices)
+    else:
+        noise = np.random.uniform(-1, 1, size=(batch_size, latent_size))
+        generate_images = g.predict(noise)
+        images = combine_images(generate_images)
+
+    images = images * 127.5 + 127.5
+    Image.fromarray(images.astype(np.uint8)).save('./generated_images.png')
+
+generate(64)
\ No newline at end of file
diff --git a/images/cgan_gen.gif b/images/cgan_gen.gif
new file mode 100644
index 0000000..a2d7538
Binary files /dev/null and b/images/cgan_gen.gif differ
diff --git a/images/gen.gif b/images/gen.gif
new file mode 100644
index 0000000..dcdaeee
Binary files /dev/null and b/images/gen.gif differ
diff --git a/logs/cgan_0.png b/logs/cgan_0.png
new file mode 100644
index 0000000..8ad850c
Binary files /dev/null and b/logs/cgan_0.png differ
diff --git a/logs/cgan_1.png b/logs/cgan_1.png
new file mode 100644
index 0000000..ccfe772
Binary files /dev/null and b/logs/cgan_1.png differ
diff --git a/logs/cgan_10.png b/logs/cgan_10.png
new file mode 100644
index 0000000..02fd3f3
Binary files /dev/null and b/logs/cgan_10.png differ
diff --git a/logs/cgan_11.png b/logs/cgan_11.png
new file mode 100644
index 0000000..378c32f
Binary files /dev/null and b/logs/cgan_11.png differ
diff --git a/logs/cgan_12.png b/logs/cgan_12.png
new file mode 100644
index 0000000..83bc301
Binary files /dev/null and b/logs/cgan_12.png differ
diff --git a/logs/cgan_13.png b/logs/cgan_13.png
new file mode 100644
index 0000000..71a49a1
Binary files /dev/null and b/logs/cgan_13.png differ
diff --git a/logs/cgan_14.png b/logs/cgan_14.png
new file mode 100644
index 0000000..c7e5d36
Binary files /dev/null and b/logs/cgan_14.png differ
diff --git a/logs/cgan_15.png b/logs/cgan_15.png
new file mode 100644
index 0000000..d3327c0
Binary files /dev/null and b/logs/cgan_15.png differ
diff --git a/logs/cgan_16.png b/logs/cgan_16.png
new file mode 100644
index 0000000..263dc54
Binary files /dev/null and b/logs/cgan_16.png differ
diff --git a/logs/cgan_17.png b/logs/cgan_17.png
new file mode 100644
index 0000000..b9368a0
Binary files /dev/null and b/logs/cgan_17.png differ
diff --git a/logs/cgan_18.png b/logs/cgan_18.png
new file mode 100644
index 0000000..741cd12
Binary files /dev/null and b/logs/cgan_18.png differ
diff --git a/logs/cgan_19.png b/logs/cgan_19.png
new file mode 100644
index 0000000..5084922
Binary files /dev/null and b/logs/cgan_19.png differ
diff --git a/logs/cgan_2.png b/logs/cgan_2.png
new file mode 100644
index 0000000..15cd250
Binary files /dev/null and b/logs/cgan_2.png differ
diff --git a/logs/cgan_20.png b/logs/cgan_20.png
new file mode 100644
index 0000000..7600f6c
Binary files /dev/null and b/logs/cgan_20.png differ
diff --git a/logs/cgan_21.png b/logs/cgan_21.png
new file mode 100644
index 0000000..86dfb31
Binary files /dev/null and b/logs/cgan_21.png differ
diff --git a/logs/cgan_22.png b/logs/cgan_22.png
new file mode 100644
index 0000000..73dfdc0
Binary files /dev/null and b/logs/cgan_22.png differ
diff --git a/logs/cgan_23.png b/logs/cgan_23.png
new file mode 100644
index 0000000..7a0c33b
Binary files /dev/null and b/logs/cgan_23.png differ
diff --git a/logs/cgan_24.png b/logs/cgan_24.png
new file mode 100644
index 0000000..29edefd
Binary files /dev/null and b/logs/cgan_24.png differ
diff --git a/logs/cgan_25.png b/logs/cgan_25.png
new file mode 100644
index 0000000..41636d4
Binary files /dev/null and b/logs/cgan_25.png differ
diff --git a/logs/cgan_26.png b/logs/cgan_26.png
new file mode 100644
index 0000000..5e7623d
Binary files /dev/null and b/logs/cgan_26.png differ
diff --git a/logs/cgan_27.png b/logs/cgan_27.png
new file mode 100644
index 0000000..727f9f2
Binary files /dev/null and b/logs/cgan_27.png differ
diff --git a/logs/cgan_28.png b/logs/cgan_28.png
new file mode 100644
index 0000000..d78b3d4
Binary files /dev/null and b/logs/cgan_28.png differ
diff --git a/logs/cgan_29.png b/logs/cgan_29.png
new file mode 100644
index 0000000..c41dd19
Binary files /dev/null and b/logs/cgan_29.png differ
diff --git a/logs/cgan_3.png b/logs/cgan_3.png
new file mode 100644
index 0000000..358ff31
Binary files /dev/null and b/logs/cgan_3.png differ
diff --git a/logs/cgan_4.png b/logs/cgan_4.png
new file mode 100644
index 0000000..09db451
Binary files /dev/null and b/logs/cgan_4.png differ
diff --git a/logs/cgan_5.png b/logs/cgan_5.png
new file mode 100644
index 0000000..834c2e3
Binary files /dev/null and b/logs/cgan_5.png differ
diff --git a/logs/cgan_6.png b/logs/cgan_6.png
new file mode 100644
index 0000000..df57dd7
Binary files /dev/null and b/logs/cgan_6.png differ
diff --git a/logs/cgan_7.png b/logs/cgan_7.png
new file mode 100644
index 0000000..1fc4af7
Binary files /dev/null and b/logs/cgan_7.png differ
diff --git a/logs/cgan_8.png b/logs/cgan_8.png
new file mode 100644
index 0000000..7c8419c
Binary files /dev/null and b/logs/cgan_8.png differ
diff --git a/logs/cgan_9.png b/logs/cgan_9.png
new file mode 100644
index 0000000..3614657
Binary files /dev/null and b/logs/cgan_9.png differ
diff --git a/logs/generated_0.png b/logs/generated_0.png
new file mode 100644
index 0000000..d3f63ec
Binary files /dev/null and b/logs/generated_0.png differ
diff --git a/logs/generated_1.0.png b/logs/generated_1.0.png
new file mode 100644
index 0000000..c08e4f1
Binary files /dev/null and b/logs/generated_1.0.png differ
diff --git a/logs/generated_10.0.png b/logs/generated_10.0.png
new file mode 100644
index 0000000..6c78093
Binary files /dev/null and b/logs/generated_10.0.png differ
diff --git a/logs/generated_11.0.png b/logs/generated_11.0.png
new file mode 100644
index 0000000..1b9a7af
Binary files /dev/null and b/logs/generated_11.0.png differ
diff --git a/logs/generated_12.0.png b/logs/generated_12.0.png
new file mode 100644
index 0000000..663cbba
Binary files /dev/null and b/logs/generated_12.0.png differ
diff --git a/logs/generated_13.0.png b/logs/generated_13.0.png
new file mode 100644
index 0000000..91bcd34
Binary files /dev/null and b/logs/generated_13.0.png differ
diff --git a/logs/generated_14.0.png b/logs/generated_14.0.png
new file mode 100644
index 0000000..cb53b85
Binary files /dev/null and b/logs/generated_14.0.png differ
diff --git a/logs/generated_15.0.png b/logs/generated_15.0.png
new file mode 100644
index 0000000..3b34d06
Binary files /dev/null and b/logs/generated_15.0.png differ
diff --git a/logs/generated_16.0.png b/logs/generated_16.0.png
new file mode 100644
index 0000000..c3bd60b
Binary files /dev/null and b/logs/generated_16.0.png differ
diff --git a/logs/generated_17.0.png b/logs/generated_17.0.png
new file mode 100644
index 0000000..0b2d210
Binary files /dev/null and b/logs/generated_17.0.png differ
diff --git a/logs/generated_18.0.png b/logs/generated_18.0.png
new file mode 100644
index 0000000..d86628e
Binary files /dev/null and b/logs/generated_18.0.png differ
diff --git a/logs/generated_19.0.png b/logs/generated_19.0.png
new file mode 100644
index 0000000..971cbcf
Binary files /dev/null and b/logs/generated_19.0.png differ
diff --git a/logs/generated_2.0.png b/logs/generated_2.0.png
new file mode 100644
index 0000000..2ffbe4c
Binary files /dev/null and b/logs/generated_2.0.png differ
diff --git a/logs/generated_20.0.png b/logs/generated_20.0.png
new file mode 100644
index 0000000..d715918
Binary files /dev/null and b/logs/generated_20.0.png differ
diff --git a/logs/generated_21.0.png b/logs/generated_21.0.png
new file mode 100644
index 0000000..b6520fd
Binary files /dev/null and b/logs/generated_21.0.png differ
diff --git a/logs/generated_22.0.png b/logs/generated_22.0.png
new file mode 100644
index 0000000..771e4f4
Binary files /dev/null and b/logs/generated_22.0.png differ
diff --git a/logs/generated_23.0.png b/logs/generated_23.0.png
new file mode 100644
index 0000000..6391e55
Binary files /dev/null and b/logs/generated_23.0.png differ
diff --git a/logs/generated_24.0.png b/logs/generated_24.0.png
new file mode 100644
index 0000000..9b02f22
Binary files /dev/null and b/logs/generated_24.0.png differ
diff --git a/logs/generated_25.0.png b/logs/generated_25.0.png
new file mode 100644
index 0000000..b924bd2
Binary files /dev/null and b/logs/generated_25.0.png differ
diff --git a/logs/generated_26.0.png b/logs/generated_26.0.png
new file mode 100644
index 0000000..8e62896
Binary files /dev/null and b/logs/generated_26.0.png differ
diff --git a/logs/generated_27.0.png b/logs/generated_27.0.png
new file mode 100644
index 0000000..4ba9cd0
Binary files /dev/null and b/logs/generated_27.0.png differ
diff --git a/logs/generated_28.0.png b/logs/generated_28.0.png
new file mode 100644
index 0000000..5a8aa08
Binary files /dev/null and b/logs/generated_28.0.png differ
diff --git a/logs/generated_29.0.png b/logs/generated_29.0.png
new file mode 100644
index 0000000..16fee09
Binary files /dev/null and b/logs/generated_29.0.png differ
diff --git a/logs/generated_3.0.png b/logs/generated_3.0.png
new file mode 100644
index 0000000..c76b58b
Binary files /dev/null and b/logs/generated_3.0.png differ
diff --git a/logs/generated_4.0.png b/logs/generated_4.0.png
new file mode 100644
index 0000000..a5b20b9
Binary files /dev/null and b/logs/generated_4.0.png differ
diff --git a/logs/generated_5.0.png b/logs/generated_5.0.png
new file mode 100644
index 0000000..57a3655
Binary files /dev/null and b/logs/generated_5.0.png differ
diff --git a/logs/generated_6.0.png b/logs/generated_6.0.png
new file mode 100644
index 0000000..81010a7
Binary files /dev/null and b/logs/generated_6.0.png differ
diff --git a/logs/generated_7.0.png b/logs/generated_7.0.png
new file mode 100644
index 0000000..c2ae745
Binary files /dev/null and b/logs/generated_7.0.png differ
diff --git a/logs/generated_8.0.png b/logs/generated_8.0.png
new file mode 100644
index 0000000..fcfb7a5
Binary files /dev/null and b/logs/generated_8.0.png differ
diff --git a/logs/generated_9.0.png b/logs/generated_9.0.png
new file mode 100644
index 0000000..aca531c
Binary files /dev/null and b/logs/generated_9.0.png differ
diff --git a/model/cgan.png b/model/cgan.png
new file mode 100644
index 0000000..2620e4d
Binary files /dev/null and b/model/cgan.png differ
diff --git a/model/cgan_discriminator.png b/model/cgan_discriminator.png
new file mode 100644
index 0000000..4078472
Binary files /dev/null and b/model/cgan_discriminator.png differ
diff --git a/model/cgan_generator.png b/model/cgan_generator.png
new file mode 100644
index 0000000..136662a
Binary files /dev/null and b/model/cgan_generator.png differ
diff --git a/model/gan_discriminator.png b/model/gan_discriminator.png
new file mode 100644
index 0000000..21c9480
Binary files /dev/null and b/model/gan_discriminator.png differ
diff --git a/model/gan_generator.png b/model/gan_generator.png
new file mode 100644
index 0000000..19472db
Binary files /dev/null and b/model/gan_generator.png differ