In the previous section, we learned about generative models: models that can generate new images similar to the ones in the training dataset. VAE was a good example of a generative model.
However, if we try to generate something really meaningful with a VAE, such as a painting at reasonable resolution, we will see that training does not converge well. There is another architecture specifically targeted at generative models: Generative Adversarial Networks, or GANs.
The main idea of a GAN is to have two neural networks that are trained against each other:
- Generator is a network that takes a random vector and produces an image as a result
- Discriminator is a network that takes an image and tells whether it is a real image (from the training dataset) or one produced by the generator. It is essentially an image classifier.
The architecture of the discriminator does not differ from that of an ordinary image classification network. In the simplest case it can be a fully-connected classifier, but most probably it will be a convolutional network.
A GAN based on convolutional networks is called a DCGAN (Deep Convolutional GAN).
A CNN discriminator consists of the following layers: several convolutions + poolings (with decreasing spatial size), one or more fully-connected layers to obtain a "feature vector", and a final binary classifier.
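The layer list above could be sketched in PyTorch roughly as follows. This is a minimal illustration, not a reference implementation: the 1x28x28 input size, the channel counts, the `feature_dim` parameter, and the class name `Discriminator` are all assumptions made for the example.

```python
import torch
from torch import nn


class Discriminator(nn.Module):
    """Sketch of a CNN discriminator for 1x28x28 images (input size is an assumption)."""

    def __init__(self, channels: int = 1, feature_dim: int = 64):
        super().__init__()
        # Convolutions + poolings: spatial size shrinks 28 -> 14 -> 7
        self.features = nn.Sequential(
            nn.Conv2d(channels, 16, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2),
        )
        # Fully-connected layer that produces the "feature vector"
        self.to_vector = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, feature_dim),
            nn.LeakyReLU(0.2),
        )
        # Final binary classifier: a single logit, real (1) vs. fake (0)
        self.classifier = nn.Linear(feature_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.to_vector(self.features(x)))


disc = Discriminator()
logits = disc(torch.randn(8, 1, 28, 28))  # a batch of 8 random "images"
print(logits.shape)  # torch.Size([8, 1])
```

The network returns raw logits rather than probabilities, so it would be paired with a loss such as `nn.BCEWithLogitsLoss`, which applies the sigmoid internally.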
The generator is slightly more tricky. You can think of it as a reversed discriminator: starting from a latent vector (in place of a feature vector), it uses a fully-connected layer to convert it into the required size/shape, followed by deconvolutions + upscaling.
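As a rough sketch of this "reversed" structure, the PyTorch module below expands a latent vector with a fully-connected layer and then upscales it with transposed convolutions. The latent size of 100, the 7 -> 14 -> 28 upscaling path, the `Tanh` output, and the class name `Generator` are assumptions chosen for illustration.

```python
import torch
from torch import nn


class Generator(nn.Module):
    """Sketch of a generator producing 1x28x28 images (sizes are assumptions)."""

    def __init__(self, latent_dim: int = 100, channels: int = 1):
        super().__init__()
        # Fully-connected layer expands the latent vector into a 64x7x7 feature map
        self.project = nn.Sequential(
            nn.Linear(latent_dim, 64 * 7 * 7),
            nn.ReLU(),
        )
        # Transposed convolutions upscale the feature map: 7 -> 14 -> 28
        self.upscale = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        x = self.project(z).view(-1, 64, 7, 7)
        return self.upscale(x)


gen = Generator()
images = gen(torch.randn(8, 100))  # 8 random latent vectors
print(images.shape)  # torch.Size([8, 1, 28, 28])
```

With `kernel_size=4`, `stride=2`, and `padding=1`, each transposed convolution exactly doubles the spatial size, since the output size is `(in - 1) * stride - 2 * padding + kernel_size`.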
Because a convolution layer is implemented as a linear filter traversing the image, a deconvolution is essentially similar to a convolution and can be implemented using the same layer logic.
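To make the similarity concrete, here is a small plain-Python sketch in one dimension. It assumes that "deconvolution" means a transposed convolution, as it usually does in deep learning frameworks. The function names `conv1d` and `conv_transpose1d` are made up for this example.

```python
def conv1d(signal, kernel, stride=1):
    """Plain 1D convolution as used in CNNs: slide the kernel and take dot products."""
    k = len(kernel)
    return [
        sum(signal[i + j] * kernel[j] for j in range(k))
        for i in range(0, len(signal) - k + 1, stride)
    ]


def conv_transpose1d(signal, kernel, stride=1):
    """Transposed 1D convolution: each input value adds a scaled copy of the kernel."""
    k = len(kernel)
    out = [0.0] * ((len(signal) - 1) * stride + k)
    for i, value in enumerate(signal):
        for j in range(k):
            out[i * stride + j] += value * kernel[j]
    return out


small = [1.0, 2.0, 3.0]
kernel = [1.0, 1.0]

upscaled = conv_transpose1d(small, kernel, stride=2)
print(upscaled)                            # [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
print(conv1d(upscaled, kernel, stride=2))  # [2.0, 4.0, 6.0]
```

Both functions walk the same kernel over the data with the same stride. The convolution gathers several inputs into one output (length 6 -> 3), while the transposed version scatters one input into several outputs (length 3 -> 6).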
GANs are called adversarial because there is constant competition between the generator and the discriminator. During this competition, both the generator and the discriminator improve, so the network learns to produce better and better pictures.
The training happens in two stages:
- Training the discriminator. This is pretty straightforward: we generate a batch of images with the generator and label them 0 (fake image), and take a batch of images from the input dataset and label them 1 (real image). We compute the discriminator loss on both batches and perform backpropagation.
- Training the generator. This is slightly more tricky, because we do not know the expected output of the generator directly. We take the whole GAN network, consisting of the generator followed by the discriminator, feed it random vectors, and expect the result to be 1 (corresponding to real images). We freeze the parameters of the discriminator (we do not want it to be trained at this step) and then perform backpropagation.
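The two stages above can be sketched as a single training step in PyTorch. This is a simplified illustration under several assumptions: the tiny fully-connected networks stand in for a real generator and discriminator, the images are flattened to 784 values, and the optimizer, learning rates, and the name `train_step` are chosen only for the example.

```python
import torch
from torch import nn

latent_dim = 16

# Tiny fully-connected stand-ins for the two networks (assumption for brevity)
generator = nn.Sequential(
    nn.Linear(latent_dim, 64), nn.ReLU(), nn.Linear(64, 784), nn.Tanh()
)
discriminator = nn.Sequential(
    nn.Linear(784, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1)
)

loss_fn = nn.BCEWithLogitsLoss()
g_opt = torch.optim.Adam(generator.parameters(), lr=2e-4)
d_opt = torch.optim.Adam(discriminator.parameters(), lr=2e-4)


def train_step(real_images: torch.Tensor):
    batch = real_images.size(0)
    ones = torch.ones(batch, 1)    # label 1: real image
    zeros = torch.zeros(batch, 1)  # label 0: fake image

    # Stage 1: train the discriminator on a fake batch and a real batch
    fake_images = generator(torch.randn(batch, latent_dim)).detach()  # no gradient into the generator
    d_loss = loss_fn(discriminator(fake_images), zeros) + loss_fn(discriminator(real_images), ones)
    d_opt.zero_grad()
    d_loss.backward()
    d_opt.step()

    # Stage 2: train the generator through the frozen discriminator
    for p in discriminator.parameters():
        p.requires_grad_(False)
    g_loss = loss_fn(discriminator(generator(torch.randn(batch, latent_dim))), ones)
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()
    for p in discriminator.parameters():
        p.requires_grad_(True)  # unfreeze for the next step

    return d_loss.item(), g_loss.item()


# Random data standing in for a batch of real images scaled to [-1, 1]
d_loss, g_loss = train_step(torch.rand(32, 784) * 2 - 1)
print(f"discriminator loss: {d_loss:.3f}, generator loss: {g_loss:.3f}")
```

Freezing is done here by switching off `requires_grad` on the discriminator parameters. Because `g_opt` only holds the generator's parameters, the discriminator would not be updated in stage 2 either way, but freezing also avoids computing gradients that are never used.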
During this process, neither the generator loss nor the discriminator loss goes down significantly. In the ideal situation, they should oscillate, which corresponds to both networks improving their performance.
GANs are known to be especially difficult to train. Here are a few problems:
- Mode Collapse. This term means that the generator learns to produce one successful image that tricks the discriminator, rather than a variety of different images.
- Sensitivity to hyperparameters. Often a GAN does not converge at all, and then a decrease in the learning rate suddenly leads to convergence.
- Keeping a balance between the generator and the discriminator. In many cases the discriminator loss can drop to zero relatively quickly, which leaves the generator unable to train further. To overcome this, we can try setting different learning rates for the generator and the discriminator, or skip discriminator training if its loss is already too low.
- Training for high resolution. This is the same problem as with autoencoders: reconstructing too many layers of a convolutional network leads to artifacts. It is typically solved with so-called progressive growing, where the first few layers are trained on low-res images, and then further layers are "unblocked" or added. Another solution is to add extra connections between layers and train several resolutions at once; see the Multi-Scale Gradient GANs paper for details.
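The two balancing suggestions from the list above (different learning rates, and skipping discriminator updates when its loss is very low) might look like this in PyTorch. The placeholder networks, the specific learning rates, the `min_loss` threshold of 0.1, and the helper name `should_train_discriminator` are illustrative assumptions, not recommended values.

```python
import torch
from torch import nn

# Placeholder networks, only here so the optimizers have parameters (assumption)
generator = nn.Linear(16, 784)
discriminator = nn.Linear(784, 1)

# Different learning rates: in this sketch the discriminator learns more slowly
g_opt = torch.optim.Adam(generator.parameters(), lr=2e-4)
d_opt = torch.optim.Adam(discriminator.parameters(), lr=5e-5)


def should_train_discriminator(last_d_loss: float, min_loss: float = 0.1) -> bool:
    """Return False when the discriminator loss is already below the threshold."""
    return last_d_loss >= min_loss


print(should_train_discriminator(0.65))  # True: discriminator still needs training
print(should_train_discriminator(0.02))  # False: skip this update, let the generator catch up
```

In a training loop, the discriminator stage would simply be wrapped in `if should_train_discriminator(previous_d_loss):`, while the generator stage runs on every step.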
- Marco Pasini, 10 Lessons I Learned Training GANs for one Year
- StyleGAN, a de facto GAN architecture to consider
- Creating Generative Art using GANs on Azure ML