Skip to content
This repository has been archived by the owner on Jan 10, 2023. It is now read-only.

Commit

Permalink
internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 219105444
  • Loading branch information
cyfra committed Oct 29, 2018
1 parent 51af943 commit 560697e
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 16 deletions.
22 changes: 12 additions & 10 deletions compare_gan/src/multi_gan/README.md
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
# A Case for Object Compositionality in Generative Models of Images
# A Case for Object Compositionality in Deep Generative Models of Images

![clevr-generated](illustrations/clevr_generated.png)

This is the code repository complementing the
["A Case for Object Compositionality in Generative Models of Images"](todo).
["A Case for Object Compositionality in Deep Generative Models of Images"](https://arxiv.org/pdf/1810.10340.pdf).

## Datasets

The following provides an overview of the datasets that were used.
The following provides an overview of the datasets that were used. Corresponding
.tfrecords files for all custom datasets are available [here](https://goo.gl/Eub81x).

### Multi-MNIST

All Multi-MNIST datasets are available [here](todo). In these dataset each image
consists of 3 (potentially overlapping) MNIST digits. Digits are obtained from
the original dataset (using train/test/valid respectively), re-scaled and
randomly placed in the image. A fixed-offset from the border ensures that digits
appear entirely in the image.
In these dataset each image consists of 3 (potentially overlapping) MNIST digits.
Digits are obtained from the original dataset (using train/test/valid respectively),
re-scaled and randomly placed in the image. A fixed-offset from the border ensures
that digits appear entirely in the image.

#### Uniform
#### Independent

All digits 0-9 have an equal chance of appearing in the image. This roughly
results in a uniform distribution over all 3-tuples of digits in the images. In
Expand All @@ -41,7 +43,7 @@ Digits are sampled [uniformly](#uniform) and colored either
appears in a Multi-MNIST image. Digits are drawn one by one into the canvas
without blending colors, such that overlapping digits occlude one another.

#### RGB Occluded + CIFAR10
### CIFAR10 + RGB MM

Draws digits from [rgb occluded](#rgb-occluded) images on top of a randomly
sampled CIFAR10 image (resized to 64 x 64 using bilinear interpolation).
Expand Down
19 changes: 19 additions & 0 deletions compare_gan/src/multi_gan/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Init.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
2 changes: 1 addition & 1 deletion compare_gan/src/multi_gan/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def unpack_clevr_image(image_data):
def load_clevr(dataset_name, split_name, num_threads, buffer_size):
del dataset_name
filenames = tf.data.Dataset.list_files(
os.path.join(FLAGS.multigan_dataset_root, "clevr_%s*" % split_name))
os.path.join(FLAGS.multigan_dataset_root, "clevr/%s*" % split_name))

return tf.data.TFRecordDataset(
filenames,
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
39 changes: 34 additions & 5 deletions compare_gan/src/multi_gan/visualize_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
"Which checkpoint(s) to evaluate for a given study/task."
"Supports {'all', <int>}.")
flags.DEFINE_enum("visualization_type", "multi_image", [
"multi_image", "latent", "multi_latent"], "How to visualize this GAN.")
"image", "multi_image", "latent", "multi_latent"],
"How to visualize this GAN.")
flags.DEFINE_integer("batch_size", 64, "Size of the batch.")
flags.DEFINE_integer("images_per_fig", 4, "How many images to stack in a fig.")
FLAGS = flags.FLAGS
Expand Down Expand Up @@ -206,6 +207,25 @@ def SaveMultiGanGeneratorImages(aggregated_images, generated_images, save_dir):
plt.close(fig)


def SaveGeneratorImages(aggregated_images, save_dir):
"""Visualizes the aggregated output of all generators.
Args:
aggregated_images: The aggregated image (B, W, H, C).
save_dir: The path to the directory in which the figure should be saved.
"""

n_images, _, _, _ = aggregated_images.shape

for i in range(n_images):
fig, ax = plt.subplots(figsize=(5, 5))

PlotImage(ax, aggregated_images[i])
plt.savefig(os.path.join(save_dir, "generator_images_%d.png" % i),
bbox_inches="tight")
plt.close(fig)


def GetMultiGANGeneratorsOp(graph, gan_type, architecture, aggregate):
"""Returns the op to obtain the output of all generators."""

Expand Down Expand Up @@ -278,6 +298,16 @@ def EvalCheckpoint(checkpoint_path, task_workdir, options, out_cp_dir):
fake_images, generator_preds = sess.run(fetches, feed_dict=feed_dict)
SaveMultiGanGeneratorImages(fake_images, generator_preds, out_cp_dir)

# Compute outputs for GeneratorImages.
elif FLAGS.visualization_type == "image":
# Construct feed dict
z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
feed_dict = {gan.z: z_sample}

# Fetch data and save images.
fake_images = sess.run(gan.fake_images, feed_dict=feed_dict)
SaveGeneratorImages(fake_images, out_cp_dir)

# Compute outputs for MultiGanLatentTraversalImages
elif (FLAGS.visualization_type == "multi_latent" and
"MultiGAN" in options["gan_type"]):
Expand Down Expand Up @@ -362,9 +392,8 @@ def EvalTask(options, task_workdir, out_dir):
if FLAGS.checkpoint == "all":
all_checkpoint_paths = checkpoint_state.all_model_checkpoint_paths
else:
all_checkpoint_paths = [
cp_path for cp_path in checkpoint_state.all_model_checkpoint_paths if
cp_path.split("-")[-1] == FLAGS.checkpoint]
all_checkpoint_paths = ["%s/checkpoint/%s.model-%s" % (
FLAGS.eval_task_workdir, options["gan_type"], FLAGS.checkpoint)]

for checkpoint_path in all_checkpoint_paths:
out_cp_dir = os.path.join(
Expand Down Expand Up @@ -405,6 +434,6 @@ def main(unused_argv):
EvalTask(options, task_workdir, out_dir)

if __name__ == "__main__":
flags.mark_flag_as_required("eval_workdir")
flags.mark_flag_as_required("eval_task_workdir")
flags.mark_flag_as_required("out_dir")
tf.app.run()

0 comments on commit 560697e

Please sign in to comment.