Skip to content
This repository has been archived by the owner on Jan 10, 2023. It is now read-only.

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 235733444
  • Loading branch information
mariolucic authored and Marvin182 committed Feb 26, 2019
1 parent 560697e commit e0b739f
Show file tree
Hide file tree
Showing 126 changed files with 10,399 additions and 10,677 deletions.
186 changes: 113 additions & 73 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,73 +1,113 @@
## Compare GAN code.

This is the code that was used in "Are GANs Created Equal? A Large-Scale Study"
paper (https://arxiv.org/abs/1711.10337) and in "The GAN Landscape: Losses,
Architectures, Regularization, and Normalization"
(https://arxiv.org/abs/1807.04720).

If you want to see the version used only in the first paper - please see the
*v1* branch of this repository.

## Pre-trained models

The pre-trained models are available on TensorFlow Hub. Please see
[this colab](https://colab.research.google.com/github/google/compare_gan/blob/master/compare_gan/src/tfhub_models.ipynb)
for an example how to use them.

### Best hyperparameters

This repository also contains the values for the best hyperparameters for
different combinations of models, regularizations and penalties. You can see
them in `generate_tasks_lib.py` file and train using
`--experiment=best_models_sndcgan`

### Installation:

To install, run:

```shell
python -m pip install -e . --user
```

After installing, make sure to run

```shell
compare_gan_prepare_datasets.sh
```

It will download all the necessary datasets and frozen TF graphs. By default it
will store them in `/tmp/datasets`.

WARNING: by default this script only downloads and installs small datasets - it
doesn't download celebaHQ or lsun bedrooms.

* **Lsun bedrooms dataset**: If you want to install lsun-bedrooms you need to
run t2t-datagen yourself (this dataset will take couple hours to download
and unpack).

* **CelebaHQ dataset**: currently it is not available in tensor2tensor. Please
use the
[ProgressiveGAN github](https://github.com/tkarras/progressive_growing_of_gans)
for instructions on how to prepare it.

### Running

compare_gan has two binaries:

* `generate_tasks` - that creates a list of files with parameters to execute
* `run_one_task` - that executes a given task, both training and evaluation,
and stores results in the CSV file.

```shell
# Create tasks for experiment "test" in directory /tmp/results. See "src/generate_tasks_lib.py" to see other possible experiments.
compare_gan_generate_tasks --workdir=/tmp/results --experiment=test

# Run task 0 (training and eval)
compare_gan_run_one_task --workdir=/tmp/results --task_num=0 --dataset_root=/tmp/datasets

# Run task 1 (training and eval)
compare_gan_run_one_task --workdir=/tmp/results --task_num=1 --dataset_root=/tmp/datasets
```

Results (all computed metrics) will be stored in
`/tmp/results/TASK_NUM/scores.csv`.
# Compare GAN

This repository offers TensorFlow implementations for many components related to
**Generative Adversarial Networks**:

* losses (such non-saturating GAN, least-squares GAN, and WGAN),
* penalties (such as the gradient penalty),
* normalization techniques (such as spectral normalization, batch
normalization, and layer normalization),
* neural architectures (BigGAN, ResNet, DCGAN), and
* evaluation metrics (FID score, Inception Score, precision-recall, and KID
score).

The code is **configurable via [Gin](https://github.com/google/gin-config)** and
runs on **GPU/TPU/CPUs**. Several research papers make use of this repository,
including:

1. [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337)
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/tree/v1)
\
Mario Lucic*, Karol Kurach*, Marcin Michalski, Sylvain Gelly, Olivier
Bousquet **[NeurIPS 2018]**

2. [The GAN Landscape: Losses, Architectures, Regularization, and Normalization](https://arxiv.org/abs/1807.04720)
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/tree/v2)
\
Karol Kurach*, Mario Lucic*, Xiaohua Zhai, Marcin Michalski, Sylvain Gelly
**[2018]**

3. [Assessing Generative Models via Precision and Recall](https://arxiv.org/abs/1806.00035)
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/blob/560697ee213f91048c6b4231ab79fcdd9bf20381/compare_gan/src/prd_score.py)
\
Mehdi S. M. Sajjadi, Olivier Bachem, Mario Lucic, Olivier Bousquet, Sylvain
Gelly **[NeurIPS 2018]**

4. [GILBO: One Metric to Measure Them All](https://arxiv.org/abs/1802.04874)
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/blob/560697ee213f91048c6b4231ab79fcdd9bf20381/compare_gan/src/gilbo.py)
\
Alexander A. Alemi, Ian Fischer **[NeurIPS 2018]**

5. [A Case for Object Compositionality in Deep Generative Models of Images](https://arxiv.org/abs/1810.10340)
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/tree/v2_multigan)
\
Sjoerd van Steenkiste, Karol Kurach, Sylvain Gelly **[2018]**

6. [On Self Modulation for Generative Adversarial Networks](https://arxiv.org/abs/1810.01365)
[<font color="green">[Code]</font>](https://github.com/google/compare_gan) \
Ting Chen, Mario Lucic, Neil Houlsby, Sylvain Gelly **[ICLR 2019]**

7. [Self-Supervised Generative Adversarial Networks](https://arxiv.org/abs/1811.11212)
[<font color="green">[Code]</font>](https://github.com/google/compare_gan) \
Ting Chen, Xiaohua Zhai, Marvin Ritter, Mario Lucic, Neil Houlsby **[CVPR
2019]**


## Installation

You can easily install the library and all necessary dependencies by running:
`pip install -e .` from the `compare_gan/` folder.

## Running experiments

Simply run the `main.py` passing a `--model_dir` (this is where checkpoints are
stored) and a `--gin_config` (defies which model on which data set and other
options). We provide several example configurations in the `example_configs/`
folder, namely:

* **dcgan_celeba64**: DCGAN architecture with non-saturating loss on CelebA
64x64px
* **resnet_cifar10**: ResNet architecture with non-saturating loss and
spectral normalization on CIFAR-10
* **resnet_lsun-bedroom128**: ResNet architecture with WGAN loss and gradient
penalty on LSUN-bedrooms 128x128px
* **sndcgan_celebahq128**: SN-DCGAN architecture with non-saturating loss and
spectral normalization on CelebA-HQ 128x128px
* **biggan_imagenet128**: BigGAN architecture with hinge loss and spectral
normalization on ImageNet 128x128px

### Training and evaluation

To see all available options please run `python main.py --help`. Main options:

* To **train** the model use `--schedule=train` (default). Training is resumed
from the last saved checkpoint.
* To **evaluate** all checkpoints use `--schedule=continuous_eval
--eval_every_steps=0`. To evaluate only checkpoints where the step size is
divisible by 5000, use `--schedule=continuous_eval --eval_every_steps=5000`.
By default, 3 averaging runs are used to estimate the Inception Score and
the FID score. Keep in mind that when running locally on a single GPU it may
not be possible to run training and evaluation simultaneously due to memory
constraints.
* To **train and evaluate** the model use `--schedule=eval_after_train
--eval_every_steps=0`.

### Training on Cloud TPUs

We recommend using the
[ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu) to create
a Cloud TPU and corresponding Compute Engine VM. We use v3-128 Cloud TPU v3 Pod
for training models on ImageNet in 128x128 resolutions. You can use smaller
slices if you reduce the batch size (`options.batch_size` in the Gin config) or
model parameters. Keep in mind that the model quality might change. Before
training make sure that the environment variable `TPU_NAME` is set. Running
evaluation on TPUs is currently not supported. Use a VM with a single GPU
instead.

### Datasets

Compare GAN uses [TensorFlow Datasets](https://www.tensorflow.org/datasets) and
it will automatically download and prepare the data. For ImageNet you will need
to download the archive yourself. For CelebAHq you need to download and prepare
the images on your own. If you are using TPUs make sure to point the training
script to your Google Storage Bucket (`--tfds_data_dir`).
File renamed without changes.
127 changes: 127 additions & 0 deletions compare_gan/architectures/abstract_arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines interfaces for generator and discriminator networks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
from compare_gan import utils
import gin
import six
import tensorflow as tf


@gin.configurable("G", blacklist=["name", "image_shape"])
@six.add_metaclass(abc.ABCMeta)
class AbstractGenerator(object):
"""Interface for generator architectures."""

def __init__(self,
name="generator",
image_shape=None,
batch_norm_fn=None,
spectral_norm=False):
"""Constructor for all generator architectures.
Args:
name: Scope name of the generator.
image_shape: Image shape to be generated, [height, width, colors].
batch_norm_fn: Function for batch normalization or None.
spectral_norm: If True use spectral normalization for all weights.
"""
self._name = name
self._image_shape = image_shape
self._batch_norm_fn = batch_norm_fn
self._spectral_norm = spectral_norm

def __call__(self, z, y, is_training, reuse=tf.AUTO_REUSE):
with tf.variable_scope(self._name, values=[z, y], reuse=reuse):
outputs = self.apply(z=z, y=y, is_training=is_training)
return outputs

def batch_norm(self, inputs, **kwargs):
if self._batch_norm_fn is None:
return inputs
args = kwargs.copy()
args["inputs"] = inputs
if "use_sn" not in args:
args["use_sn"] = self._spectral_norm
return utils.call_with_accepted_args(self._batch_norm_fn, **args)

@abc.abstractmethod
def apply(self, z, y, is_training):
"""Apply the generator on a input.
Args:
z: `Tensor` of shape [batch_size, z_dim] with latent code.
y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
labels.
is_training: Boolean, whether the architecture should be constructed for
training or inference.
Returns:
Generated images of shape [batch_size] + self.image_shape.
"""


@gin.configurable("D", blacklist=["name"])
@six.add_metaclass(abc.ABCMeta)
class AbstractDiscriminator(object):
"""Interface for discriminator architectures."""

def __init__(self,
name="discriminator",
batch_norm_fn=None,
layer_norm=False,
spectral_norm=False):
self._name = name
self._batch_norm_fn = batch_norm_fn
self._layer_norm = layer_norm
self._spectral_norm = spectral_norm

def __call__(self, x, y, is_training, reuse=tf.AUTO_REUSE):
with tf.variable_scope(self._name, values=[x, y], reuse=reuse):
outputs = self.apply(x=x, y=y, is_training=is_training)
return outputs

def batch_norm(self, inputs, **kwargs):
if self._batch_norm_fn is None:
return inputs
args = kwargs.copy()
args["inputs"] = inputs
if "use_sn" not in args:
args["use_sn"] = self._spectral_norm
return utils.call_with_accepted_args(self._batch_norm_fn, **args)


@abc.abstractmethod
def apply(self, x, y, is_training):
"""Apply the discriminator on a input.
Args:
x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images.
y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
labels.
is_training: Boolean, whether the architecture should be constructed for
training or inference.
Returns:
Tuple of 3 Tensors, the final prediction of the discriminator, the logits
before the final output activation function and logits form the second
last layer.
"""
Loading

0 comments on commit e0b739f

Please sign in to comment.