From 9238bc91d733adb481423f9ed95843de22943042 Mon Sep 17 00:00:00 2001
From: aireenmei <12836798+aireenmei@users.noreply.github.com>
Date: Wed, 30 Oct 2024 15:54:58 -0700
Subject: [PATCH] enable grain and tfrecord for mlperf dataset in SD base_2_base training (#130)

---
 .github/workflows/UnitTests.yml               |  5 +-
 setup.sh                                      | 12 ++-
 setup_gcsfuse.sh                              | 52 ++++++++++++
 src/maxdiffusion/configs/base_2_base.yml      |  2 +
 .../input_pipeline/_grain_data_processing.py  | 79 +++++++++++++++++++
 .../input_pipeline/_tfds_data_processing.py   | 15 ++--
 .../input_pipeline_interface.py               | 11 ++-
 .../tests/input_pipeline_interface_test.py    | 64 ++++++++++++++-
 .../trainers/stable_diffusion_trainer.py      | 21 ++++-
 9 files changed, 242 insertions(+), 19 deletions(-)
 create mode 100755 setup_gcsfuse.sh
 create mode 100644 src/maxdiffusion/input_pipeline/_grain_data_processing.py

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 81395580..73dad075 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -39,10 +39,9 @@ jobs:
     - name: Install dependencies
       run: |
         pip install -e .
-        pip install -U -r requirements.txt
-        export PATH=$PATH:$HOME/.local/bin
         pip uninstall jax jaxlib libtpu-nightly libtpu -y
-        pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+        bash setup.sh MODE=stable
+        export PATH=$PATH:$HOME/.local/bin
         pip install ruff
         pip install isort
         pip install pytest
diff --git a/setup.sh b/setup.sh
index f1c0e49f..a8f32fdc 100644
--- a/setup.sh
+++ b/setup.sh
@@ -23,6 +23,16 @@ set -e

 export DEBIAN_FRONTEND=noninteractive

+(sudo bash || bash) <<'EOF'
+apt update && \
+apt install -y numactl lsb-release gnupg curl net-tools iproute2 procps lsof git ethtool && \
+export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
+echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | tee /etc/apt/sources.list.d/gcsfuse.list
+curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
+apt update -y && apt -y install gcsfuse
+rm -rf /var/lib/apt/lists/*
+EOF
+
 # Set environment variables from command line arguments
 for ARGUMENT in "$@"; do
   IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
@@ -97,4 +107,4 @@ else
 fi

 # Install dependencies from requirements.txt
-pip3 install -U -r requirements.txt
\ No newline at end of file
+pip3 install -U -r requirements.txt
diff --git a/setup_gcsfuse.sh b/setup_gcsfuse.sh
new file mode 100755
index 00000000..4b965c34
--- /dev/null
+++ b/setup_gcsfuse.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Description:
+# bash setup_gcsfuse.sh DATASET_GCS_BUCKET=maxdiffusion-github-runner-test-assets MOUNT_PATH=/tmp/gcsfuse
+
+set -e -x
+
+# Set environment variables
+for ARGUMENT in "$@"; do
+  IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
+  export "$KEY"="$VALUE"
+  echo "$KEY"="$VALUE"
+done
+
+if [[ -z ${DATASET_GCS_BUCKET} || -z ${MOUNT_PATH} ]]; then
+  echo "Please set arguments: DATASET_GCS_BUCKET and MOUNT_PATH"
+  exit 1
+fi
+
+if [[ "$DATASET_GCS_BUCKET" =~ gs:\/\/ ]] ; then
+  DATASET_GCS_BUCKET="${DATASET_GCS_BUCKET/gs:\/\//}"
+  echo "Removed gs:// from GCS bucket name, GCS bucket is $DATASET_GCS_BUCKET"
+fi
+
+if [[ -d ${MOUNT_PATH} ]]; then
+  echo "$MOUNT_PATH exists, removing..."
+  fusermount -u $MOUNT_PATH || rm -rf $MOUNT_PATH
+fi
+
+mkdir -p $MOUNT_PATH
+
+# see https://cloud.google.com/storage/docs/gcsfuse-cli for all configurable options of gcsfuse CLI
+# Grain uses _PROCESS_MANAGEMENT_MAX_THREADS = 64 (https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)
+# Please make sure max-conns-per-host > grain_worker_count * _PROCESS_MANAGEMENT_MAX_THREADS
+
+gcsfuse -o ro --implicit-dirs --http-client-timeout=5s --max-conns-per-host=2000 \
+  --debug_fuse_errors --debug_fuse --debug_gcs --debug_invariants --debug_mutex \
+  --log-file=$HOME/gcsfuse.json "$DATASET_GCS_BUCKET" "$MOUNT_PATH"
diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml
index a310555d..f9c36ddb 100644
--- a/src/maxdiffusion/configs/base_2_base.yml
+++ b/src/maxdiffusion/configs/base_2_base.yml
@@ -158,6 +158,8 @@ jax_cache_dir: ''
 hf_data_dir: ''
 hf_train_files: ''
 hf_access_token: ''
+grain_train_files: ''
+grain_worker_count: 4
 image_column: 'image'
 caption_column: 'text'
 resolution: 512
diff --git a/src/maxdiffusion/input_pipeline/_grain_data_processing.py b/src/maxdiffusion/input_pipeline/_grain_data_processing.py
new file mode 100644
index 00000000..5ba3b637
--- /dev/null
+++ b/src/maxdiffusion/input_pipeline/_grain_data_processing.py
@@ -0,0 +1,79 @@
+"""
+ Copyright 2024 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+import dataclasses
+import glob
+import tensorflow as tf
+import numpy as np
+import grain.python as grain
+
+from maxdiffusion import multihost_dataloading
+
+
+def make_grain_iterator(
+    config,
+    dataloading_host_index,
+    dataloading_host_count,
+    mesh,
+    global_batch_size,
+):
+  """Use Grain data input pipeline with ArrayRecord data format"""
+  data_files = glob.glob(config.grain_train_files)
+  data_source = grain.ArrayRecordDataSource(data_files)
+
+  operations = []
+  operations.append(ParseFeatures())
+  operations.append(grain.Batch(batch_size=global_batch_size // dataloading_host_count, drop_remainder=True))
+
+  index_sampler = grain.IndexSampler(
+      num_records=len(data_source),
+      num_epochs=None,
+      shard_options=grain.ShardOptions(
+          shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=True
+      ),
+      shuffle=True,
+      seed=config.seed,
+  )
+
+  dataloader = grain.DataLoader(
+      data_source=data_source,
+      operations=operations,
+      sampler=index_sampler,
+      worker_count=config.grain_worker_count,
+  )
+
+  data_iter = multihost_dataloading.MultiHostDataLoadIterator(dataloader, mesh)
+  return data_iter
+
+
+@dataclasses.dataclass
+class ParseFeatures(grain.MapTransform):
+  """Parse serialized example"""
+
+  def __init__(self):
+    self.feature_description = {
+        "moments": tf.io.FixedLenFeature([], tf.string),
+        "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
+    }
+
+  def map(self, example):
+    def _parse(example):
+      features = tf.io.parse_single_example(example, self.feature_description)
+      moments = tf.io.parse_tensor(np.asarray(features["moments"]), out_type=tf.float32)
+      clip_embeddings = tf.io.parse_tensor(np.asarray(features["clip_embeddings"]), out_type=tf.float32)
+      return {"pixel_values": moments, "input_ids": clip_embeddings}
+
+    return _parse(example)
diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
index 654c1d38..89226f37 100644
--- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
+++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -87,30 +87,29 @@ def make_tfrecord_iterator(
   maxdiffusion/pedagogical_examples/to_tfrecords.py
   """
   feature_description = {
-      "latents": tf.io.FixedLenFeature([], tf.string),
-      "hidden_states": tf.io.FixedLenFeature([], tf.string),
+      "moments": tf.io.FixedLenFeature([], tf.string),
+      "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
   }

   def _parse_tfrecord_fn(example):
     return tf.io.parse_single_example(example, feature_description)

   def prepare_sample(features):
-    latents = tf.io.parse_tensor(tnp.asarray(features["latents"]), out_type=tf.float32)
-    hidden_states = tf.io.parse_tensor(tnp.asarray(features["hidden_states"]), out_type=tf.float32)
-    return {"pixel_values": latents, "input_ids": hidden_states}
+    moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
+    clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
+    return {"pixel_values": moments, "input_ids": clip_embeddings}

   filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
   train_ds = (
       tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
+      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
       .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
       .map(prepare_sample, num_parallel_calls=AUTOTUNE)
       .shuffle(global_batch_size * 10)
       .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
-      .prefetch(AUTOTUNE)
       .repeat(-1)
+      .prefetch(AUTOTUNE)
   )
-  train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-
   train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter

diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
index 00bfcd9b..3a78ff09 100644
--- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
+++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
@@ -21,6 +21,7 @@
 import jax

 from maxdiffusion.input_pipeline import _hf_data_processing
+from maxdiffusion.input_pipeline import _grain_data_processing
 from maxdiffusion.input_pipeline import _tfds_data_processing
 from maxdiffusion import multihost_dataloading
 from maxdiffusion.maxdiffusion_utils import tokenize_captions, transform_images, vae_apply
@@ -61,6 +62,14 @@ def make_data_iterator(
         tokenize_fn=tokenize_fn,
         image_transforms_fn=image_transforms_fn,
     )
+  elif config.dataset_type == "grain":
+    return _grain_data_processing.make_grain_iterator(
+        config,
+        dataloading_host_index,
+        dataloading_host_count,
+        mesh,
+        global_batch_size,
+    )
   elif config.dataset_type == "tf":
     return _tfds_data_processing.make_tf_iterator(
         config,
@@ -80,7 +89,7 @@ def make_data_iterator(
         global_batch_size,
     )
   else:
-    assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf)"
+    assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"


 def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, vae, vae_params):
diff --git a/src/maxdiffusion/tests/input_pipeline_interface_test.py b/src/maxdiffusion/tests/input_pipeline_interface_test.py
index d7959583..9afd2267 100644
--- a/src/maxdiffusion/tests/input_pipeline_interface_test.py
+++ b/src/maxdiffusion/tests/input_pipeline_interface_test.py
@@ -18,6 +18,7 @@
 from functools import partial
 import pathlib
 import shutil
+import subprocess
 import unittest

 from absl.testing import absltest
@@ -425,13 +426,68 @@ def test_make_pokemon_iterator_sdxl_cache(self):
         config.resolution // vae_scale_factor,
     )

+  def test_make_laion_grain_iterator(self):
+    try:
+      subprocess.check_output(
+          [
+              "bash",
+              "setup_gcsfuse.sh",
+              "DATASET_GCS_BUCKET=maxdiffusion-github-runner-test-assets",
+              "MOUNT_PATH=/tmp/gcsfuse",
+          ],
+          stderr=subprocess.STDOUT,
+      )
+    except subprocess.CalledProcessError as e:
+      raise ValueError(f"setup_gcsfuse failed with error: {e.output}") from e
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"),
+            "grain_train_files=/tmp/gcsfuse/datasets/array-record/laion400m/tf_records_512_encoder_state_fp32/*.arrayrecord",
+            "dataset_type=grain",
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    global_batch_size = config.per_device_batch_size * jax.device_count()
+    devices_array = max_utils.create_device_mesh(config)
+    mesh = Mesh(devices_array, config.mesh_axes)
+
+    pipeline, _ = FlaxStableDiffusionPipeline.from_pretrained(
+        config.pretrained_model_name_or_path,
+        revision=config.revision,
+        dtype=config.activations_dtype,
+        safety_checker=None,
+        feature_extractor=None,
+        from_pt=config.from_pt,
+    )
+
+    train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size)
+    data = next(train_iterator)
+    device_count = jax.device_count()
+
+    vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1)
+    encoder_hidden_states = data["input_ids"]
+
+    # TODO - laion dataset was prepared with an extra dim.
+    # need to preprocess the dataset with dim removed.
+    if len(encoder_hidden_states.shape) == 4:
+      encoder_hidden_states = jnp.squeeze(encoder_hidden_states)
+
+    assert encoder_hidden_states.shape == (device_count, 77, 1024)
+    assert data["pixel_values"].shape == (
+        config.total_train_batch_size,
+        config.resolution // vae_scale_factor,
+        config.resolution // vae_scale_factor,
+        8,
+    )
+
   def test_make_laion_tfrecord_iterator(self):
     pyconfig.initialize(
         [
             None,
             os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"),
-            "cache_latents_text_encoder_outputs=True",
-            "train_data_dir=gs://jfacevedo-maxdiffusion/laion400m/processed/laion400m_tfrec",
+            "train_data_dir=gs://jfacevedo-maxdiffusion/laion400m/raw_data/tf_records_512_encoder_state_fp32",
             "dataset_type=tfrecord",
         ],
         unittest=True,
@@ -464,10 +520,10 @@ def test_make_laion_tfrecord_iterator(self):

     assert encoder_hidden_states.shape == (device_count, 77, 1024)
     assert data["pixel_values"].shape == (
-        device_count,
-        pipeline.unet.config.in_channels,
+        config.total_train_batch_size,
         config.resolution // vae_scale_factor,
         config.resolution // vae_scale_factor,
+        8,
     )

   def test_tfrecord(self):
diff --git a/src/maxdiffusion/trainers/stable_diffusion_trainer.py b/src/maxdiffusion/trainers/stable_diffusion_trainer.py
index c46ea38e..f8d56e1b 100644
--- a/src/maxdiffusion/trainers/stable_diffusion_trainer.py
+++ b/src/maxdiffusion/trainers/stable_diffusion_trainer.py
@@ -29,7 +29,7 @@
 from maxdiffusion import (FlaxDDPMScheduler, maxdiffusion_utils, train_utils, max_utils, max_logging)

 from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator)
-
+from maxdiffusion.models.vae_flax import FlaxDiagonalGaussianDistribution
 from maxdiffusion.checkpointing.base_stable_diffusion_checkpointer import (STABLE_DIFFUSION_CHECKPOINT)


@@ -67,6 +67,19 @@ def get_shaped_batch(self, config, pipeline):
         pipeline.text_encoder.config.hidden_size,
       )
       input_ids_dtype = jnp.float32
+    elif config.dataset_type in ("tfrecord", "grain"):
+      batch_image_shape = (
+        total_train_batch_size,
+        config.resolution // vae_scale_factor,
+        config.resolution // vae_scale_factor,
+        8,
+      )
+      batch_ids_shape = (
+        total_train_batch_size,
+        pipeline.text_encoder.config.max_position_embeddings,
+        pipeline.text_encoder.config.hidden_size,
+      )
+      input_ids_dtype = jnp.float32
     else:
       batch_image_shape = (total_train_batch_size, 3, config.resolution, config.resolution)
       batch_ids_shape = (total_train_batch_size, pipeline.text_encoder.config.max_position_embeddings)
@@ -240,10 +253,14 @@ def _train_step(unet_state, vae_state, text_encoder_state, batch, train_rng, pip
     state_params = {"unet": unet_state.params}

     def compute_loss(state_params):
-      if config.dataset_type == "tf" and config.cache_latents_text_encoder_outputs:
         latents = batch["pixel_values"]
         encoder_hidden_states = batch["input_ids"]
+      elif config.dataset_type in ("tfrecord", "grain"):
+        latents = FlaxDiagonalGaussianDistribution(batch["pixel_values"]).sample(sample_rng)
+        latents = jnp.transpose(latents, (0, 3, 1, 2))
+        latents = latents * pipeline.vae.config.scaling_factor
+        encoder_hidden_states = batch["input_ids"]
       else:
         # Convert images to latent space
         vae_outputs = pipeline.vae.apply(