This repository has been archived by the owner on Jan 10, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 317
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Project import generated by Copybara.
PiperOrigin-RevId: 235733444
- Loading branch information
1 parent
560697e
commit e0b739f
Showing
126 changed files
with
10,399 additions
and
10,677 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,113 @@ | ||
## Compare GAN code. | ||
|
||
This is the code that was used in "Are GANs Created Equal? A Large-Scale Study" | ||
paper (https://arxiv.org/abs/1711.10337) and in "The GAN Landscape: Losses, | ||
Architectures, Regularization, and Normalization" | ||
(https://arxiv.org/abs/1807.04720). | ||
|
||
If you want to see the version used only in the first paper - please see the | ||
*v1* branch of this repository. | ||
|
||
## Pre-trained models | ||
|
||
The pre-trained models are available on TensorFlow Hub. Please see | ||
[this colab](https://colab.research.google.com/github/google/compare_gan/blob/master/compare_gan/src/tfhub_models.ipynb) | ||
for an example how to use them. | ||
|
||
### Best hyperparameters | ||
|
||
This repository also contains the values for the best hyperparameters for | ||
different combinations of models, regularizations and penalties. You can see | ||
them in `generate_tasks_lib.py` file and train using | ||
`--experiment=best_models_sndcgan` | ||
|
||
### Installation: | ||
|
||
To install, run: | ||
|
||
```shell | ||
python -m pip install -e . --user | ||
``` | ||
|
||
After installing, make sure to run | ||
|
||
```shell | ||
compare_gan_prepare_datasets.sh | ||
``` | ||
|
||
It will download all the necessary datasets and frozen TF graphs. By default it | ||
will store them in `/tmp/datasets`. | ||
|
||
WARNING: by default this script only downloads and installs small datasets - it | ||
doesn't download celebaHQ or lsun bedrooms. | ||
|
||
* **Lsun bedrooms dataset**: If you want to install lsun-bedrooms you need to | ||
run t2t-datagen yourself (this dataset will take couple hours to download | ||
and unpack). | ||
|
||
* **CelebaHQ dataset**: currently it is not available in tensor2tensor. Please | ||
use the | ||
[ProgressiveGAN github](https://github.com/tkarras/progressive_growing_of_gans) | ||
for instructions on how to prepare it. | ||
|
||
### Running | ||
|
||
compare_gan has two binaries: | ||
|
||
* `generate_tasks` - that creates a list of files with parameters to execute | ||
* `run_one_task` - that executes a given task, both training and evaluation, | ||
and stores results in the CSV file. | ||
|
||
```shell | ||
# Create tasks for experiment "test" in directory /tmp/results. See "src/generate_tasks_lib.py" to see other possible experiments. | ||
compare_gan_generate_tasks --workdir=/tmp/results --experiment=test | ||
|
||
# Run task 0 (training and eval) | ||
compare_gan_run_one_task --workdir=/tmp/results --task_num=0 --dataset_root=/tmp/datasets | ||
|
||
# Run task 1 (training and eval) | ||
compare_gan_run_one_task --workdir=/tmp/results --task_num=1 --dataset_root=/tmp/datasets | ||
``` | ||
|
||
Results (all computed metrics) will be stored in | ||
`/tmp/results/TASK_NUM/scores.csv`. | ||
# Compare GAN | ||
|
||
This repository offers TensorFlow implementations for many components related to | ||
**Generative Adversarial Networks**: | ||
|
||
* losses (such non-saturating GAN, least-squares GAN, and WGAN), | ||
* penalties (such as the gradient penalty), | ||
* normalization techniques (such as spectral normalization, batch | ||
normalization, and layer normalization), | ||
* neural architectures (BigGAN, ResNet, DCGAN), and | ||
* evaluation metrics (FID score, Inception Score, precision-recall, and KID | ||
score). | ||
|
||
The code is **configurable via [Gin](https://github.com/google/gin-config)** and | ||
runs on **GPU/TPU/CPUs**. Several research papers make use of this repository, | ||
including: | ||
|
||
1. [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337) | ||
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/tree/v1) | ||
\ | ||
Mario Lucic*, Karol Kurach*, Marcin Michalski, Sylvain Gelly, Olivier | ||
Bousquet **[NeurIPS 2018]** | ||
|
||
2. [The GAN Landscape: Losses, Architectures, Regularization, and Normalization](https://arxiv.org/abs/1807.04720) | ||
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/tree/v2) | ||
\ | ||
Karol Kurach*, Mario Lucic*, Xiaohua Zhai, Marcin Michalski, Sylvain Gelly | ||
**[2018]** | ||
|
||
3. [Assessing Generative Models via Precision and Recall](https://arxiv.org/abs/1806.00035) | ||
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/blob/560697ee213f91048c6b4231ab79fcdd9bf20381/compare_gan/src/prd_score.py) | ||
\ | ||
Mehdi S. M. Sajjadi, Olivier Bachem, Mario Lucic, Olivier Bousquet, Sylvain | ||
Gelly **[NeurIPS 2018]** | ||
|
||
4. [GILBO: One Metric to Measure Them All](https://arxiv.org/abs/1802.04874) | ||
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/blob/560697ee213f91048c6b4231ab79fcdd9bf20381/compare_gan/src/gilbo.py) | ||
\ | ||
Alexander A. Alemi, Ian Fischer **[NeurIPS 2018]** | ||
|
||
5. [A Case for Object Compositionality in Deep Generative Models of Images](https://arxiv.org/abs/1810.10340) | ||
[<font color="green">[Code]</font>](https://github.com/google/compare_gan/tree/v2_multigan) | ||
\ | ||
Sjoerd van Steenkiste, Karol Kurach, Sylvain Gelly **[2018]** | ||
|
||
6. [On Self Modulation for Generative Adversarial Networks](https://arxiv.org/abs/1810.01365) | ||
[<font color="green">[Code]</font>](https://github.com/google/compare_gan) \ | ||
Ting Chen, Mario Lucic, Neil Houlsby, Sylvain Gelly **[ICLR 2019]** | ||
|
||
7. [Self-Supervised Generative Adversarial Networks](https://arxiv.org/abs/1811.11212) | ||
[<font color="green">[Code]</font>](https://github.com/google/compare_gan) \ | ||
Ting Chen, Xiaohua Zhai, Marvin Ritter, Mario Lucic, Neil Houlsby **[CVPR | ||
2019]** | ||
|
||
|
||
## Installation | ||
|
||
You can easily install the library and all necessary dependencies by running: | ||
`pip install -e .` from the `compare_gan/` folder. | ||
|
||
## Running experiments | ||
|
||
Simply run the `main.py` passing a `--model_dir` (this is where checkpoints are | ||
stored) and a `--gin_config` (defies which model on which data set and other | ||
options). We provide several example configurations in the `example_configs/` | ||
folder, namely: | ||
|
||
* **dcgan_celeba64**: DCGAN architecture with non-saturating loss on CelebA | ||
64x64px | ||
* **resnet_cifar10**: ResNet architecture with non-saturating loss and | ||
spectral normalization on CIFAR-10 | ||
* **resnet_lsun-bedroom128**: ResNet architecture with WGAN loss and gradient | ||
penalty on LSUN-bedrooms 128x128px | ||
* **sndcgan_celebahq128**: SN-DCGAN architecture with non-saturating loss and | ||
spectral normalization on CelebA-HQ 128x128px | ||
* **biggan_imagenet128**: BigGAN architecture with hinge loss and spectral | ||
normalization on ImageNet 128x128px | ||
|
||
### Training and evaluation | ||
|
||
To see all available options please run `python main.py --help`. Main options: | ||
|
||
* To **train** the model use `--schedule=train` (default). Training is resumed | ||
from the last saved checkpoint. | ||
* To **evaluate** all checkpoints use `--schedule=continuous_eval | ||
--eval_every_steps=0`. To evaluate only checkpoints where the step size is | ||
divisible by 5000, use `--schedule=continuous_eval --eval_every_steps=5000`. | ||
By default, 3 averaging runs are used to estimate the Inception Score and | ||
the FID score. Keep in mind that when running locally on a single GPU it may | ||
not be possible to run training and evaluation simultaneously due to memory | ||
constraints. | ||
* To **train and evaluate** the model use `--schedule=eval_after_train | ||
--eval_every_steps=0`. | ||
|
||
### Training on Cloud TPUs | ||
|
||
We recommend using the | ||
[ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu) to create | ||
a Cloud TPU and corresponding Compute Engine VM. We use v3-128 Cloud TPU v3 Pod | ||
for training models on ImageNet in 128x128 resolutions. You can use smaller | ||
slices if you reduce the batch size (`options.batch_size` in the Gin config) or | ||
model parameters. Keep in mind that the model quality might change. Before | ||
training make sure that the environment variable `TPU_NAME` is set. Running | ||
evaluation on TPUs is currently not supported. Use a VM with a single GPU | ||
instead. | ||
|
||
### Datasets | ||
|
||
Compare GAN uses [TensorFlow Datasets](https://www.tensorflow.org/datasets) and | ||
it will automatically download and prepare the data. For ImageNet you will need | ||
to download the archive yourself. For CelebAHq you need to download and prepare | ||
the images on your own. If you are using TPUs make sure to point the training | ||
script to your Google Storage Bucket (`--tfds_data_dir`). |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 Google LLC & Hwalsuk Lee. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Defines interfaces for generator and discriminator networks.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import abc | ||
from compare_gan import utils | ||
import gin | ||
import six | ||
import tensorflow as tf | ||
|
||
|
||
@gin.configurable("G", blacklist=["name", "image_shape"]) | ||
@six.add_metaclass(abc.ABCMeta) | ||
class AbstractGenerator(object): | ||
"""Interface for generator architectures.""" | ||
|
||
def __init__(self, | ||
name="generator", | ||
image_shape=None, | ||
batch_norm_fn=None, | ||
spectral_norm=False): | ||
"""Constructor for all generator architectures. | ||
Args: | ||
name: Scope name of the generator. | ||
image_shape: Image shape to be generated, [height, width, colors]. | ||
batch_norm_fn: Function for batch normalization or None. | ||
spectral_norm: If True use spectral normalization for all weights. | ||
""" | ||
self._name = name | ||
self._image_shape = image_shape | ||
self._batch_norm_fn = batch_norm_fn | ||
self._spectral_norm = spectral_norm | ||
|
||
def __call__(self, z, y, is_training, reuse=tf.AUTO_REUSE): | ||
with tf.variable_scope(self._name, values=[z, y], reuse=reuse): | ||
outputs = self.apply(z=z, y=y, is_training=is_training) | ||
return outputs | ||
|
||
def batch_norm(self, inputs, **kwargs): | ||
if self._batch_norm_fn is None: | ||
return inputs | ||
args = kwargs.copy() | ||
args["inputs"] = inputs | ||
if "use_sn" not in args: | ||
args["use_sn"] = self._spectral_norm | ||
return utils.call_with_accepted_args(self._batch_norm_fn, **args) | ||
|
||
@abc.abstractmethod | ||
def apply(self, z, y, is_training): | ||
"""Apply the generator on a input. | ||
Args: | ||
z: `Tensor` of shape [batch_size, z_dim] with latent code. | ||
y: `Tensor` of shape [batch_size, num_classes] with one hot encoded | ||
labels. | ||
is_training: Boolean, whether the architecture should be constructed for | ||
training or inference. | ||
Returns: | ||
Generated images of shape [batch_size] + self.image_shape. | ||
""" | ||
|
||
|
||
@gin.configurable("D", blacklist=["name"]) | ||
@six.add_metaclass(abc.ABCMeta) | ||
class AbstractDiscriminator(object): | ||
"""Interface for discriminator architectures.""" | ||
|
||
def __init__(self, | ||
name="discriminator", | ||
batch_norm_fn=None, | ||
layer_norm=False, | ||
spectral_norm=False): | ||
self._name = name | ||
self._batch_norm_fn = batch_norm_fn | ||
self._layer_norm = layer_norm | ||
self._spectral_norm = spectral_norm | ||
|
||
def __call__(self, x, y, is_training, reuse=tf.AUTO_REUSE): | ||
with tf.variable_scope(self._name, values=[x, y], reuse=reuse): | ||
outputs = self.apply(x=x, y=y, is_training=is_training) | ||
return outputs | ||
|
||
def batch_norm(self, inputs, **kwargs): | ||
if self._batch_norm_fn is None: | ||
return inputs | ||
args = kwargs.copy() | ||
args["inputs"] = inputs | ||
if "use_sn" not in args: | ||
args["use_sn"] = self._spectral_norm | ||
return utils.call_with_accepted_args(self._batch_norm_fn, **args) | ||
|
||
|
||
@abc.abstractmethod | ||
def apply(self, x, y, is_training): | ||
"""Apply the discriminator on a input. | ||
Args: | ||
x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images. | ||
y: `Tensor` of shape [batch_size, num_classes] with one hot encoded | ||
labels. | ||
is_training: Boolean, whether the architecture should be constructed for | ||
training or inference. | ||
Returns: | ||
Tuple of 3 Tensors, the final prediction of the discriminator, the logits | ||
before the final output activation function and logits form the second | ||
last layer. | ||
""" |
Oops, something went wrong.