enable grain and tfrecord for mlperf dataset in SD base_2_base training (#130)
aireenmei authored Oct 30, 2024
1 parent 51e1db1 commit 9238bc9
Showing 9 changed files with 242 additions and 19 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/UnitTests.yml
@@ -39,10 +39,9 @@ jobs:
- name: Install dependencies
run: |
pip install -e .
pip install -U -r requirements.txt
export PATH=$PATH:$HOME/.local/bin
pip uninstall jax jaxlib libtpu-nightly libtpu -y
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
bash setup.sh MODE=stable
export PATH=$PATH:$HOME/.local/bin
pip install ruff
pip install isort
pip install pytest
12 changes: 11 additions & 1 deletion setup.sh
@@ -23,6 +23,16 @@
set -e
export DEBIAN_FRONTEND=noninteractive

(sudo bash || bash) <<'EOF'
apt update && \
apt install -y numactl lsb-release gnupg curl net-tools iproute2 procps lsof git ethtool && \
export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | tee /etc/apt/sources.list.d/gcsfuse.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
apt update -y && apt -y install gcsfuse
rm -rf /var/lib/apt/lists/*
EOF

# Set environment variables from command line arguments
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
@@ -97,4 +107,4 @@ else
fi

# Install dependencies from requirements.txt
pip3 install -U -r requirements.txt
pip3 install -U -r requirements.txt
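
setup.sh now installs gcsfuse from Google's apt repository so datasets in GCS can be mounted locally before training. A minimal post-install check, not part of the commit and only a sketch of how one might verify the setup:

import shutil
import subprocess

# Confirm the gcsfuse binary installed by setup.sh is on PATH before running setup_gcsfuse.sh.
assert shutil.which("gcsfuse") is not None, "gcsfuse missing; re-run setup.sh"
print(subprocess.check_output(["gcsfuse", "--version"], text=True))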
52 changes: 52 additions & 0 deletions setup_gcsfuse.sh
@@ -0,0 +1,52 @@
#!/bin/bash

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Description:
# bash setup_gcsfuse.sh DATASET_GCS_BUCKET=maxdiffusion-github-runner-test-assets MOUNT_PATH=/tmp/gcsfuse

set -e -x

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
echo "$KEY"="$VALUE"
done

if [[ -z ${DATASET_GCS_BUCKET} || -z ${MOUNT_PATH} ]]; then
echo "Please set arguments: DATASET_GCS_BUCKET and MOUNT_PATH"
exit 1
fi

if [[ "$DATASET_GCS_BUCKET" =~ gs:\/\/ ]] ; then
DATASET_GCS_BUCKET="${DATASET_GCS_BUCKET/gs:\/\//}"
echo "Removed gs:// from GCS bucket name, GCS bucket is $DATASET_GCS_BUCKET"
fi

if [[ -d ${MOUNT_PATH} ]]; then
echo "$MOUNT_PATH exists, removing..."
fusermount -u $MOUNT_PATH || rm -rf $MOUNT_PATH
fi

mkdir -p $MOUNT_PATH

# see https://cloud.google.com/storage/docs/gcsfuse-cli for all configurable options of gcsfuse CLI
# Grain uses _PROCESS_MANAGEMENT_MAX_THREADS = 64 (https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)
# Please make sure max-conns-per-host > grain_worker_count * _PROCESS_MANAGEMENT_MAX_THREADS

gcsfuse -o ro --implicit-dirs --http-client-timeout=5s --max-conns-per-host=2000 \
--debug_fuse_errors --debug_fuse --debug_gcs --debug_invariants --debug_mutex \
--log-file=$HOME/gcsfuse.json "$DATASET_GCS_BUCKET" "$MOUNT_PATH"
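
The --max-conns-per-host value is chosen to satisfy the constraint in the comment above (max-conns-per-host > grain_worker_count * _PROCESS_MANAGEMENT_MAX_THREADS). A quick sanity check of that arithmetic, as an illustrative Python sketch rather than part of the commit:

GRAIN_WORKER_COUNT = 4                # grain_worker_count default added to base_2_base.yml
PROCESS_MANAGEMENT_MAX_THREADS = 64   # Grain's _PROCESS_MANAGEMENT_MAX_THREADS
MAX_CONNS_PER_HOST = 2000             # value passed to gcsfuse above

# gcsfuse must allow more concurrent GCS connections than Grain's workers can open.
assert MAX_CONNS_PER_HOST > GRAIN_WORKER_COUNT * PROCESS_MANAGEMENT_MAX_THREADS  # 2000 > 256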
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -158,6 +158,8 @@ jax_cache_dir: ''
hf_data_dir: ''
hf_train_files: ''
hf_access_token: ''
grain_train_files: ''
grain_worker_count: 4
image_column: 'image'
caption_column: 'text'
resolution: 512
79 changes: 79 additions & 0 deletions src/maxdiffusion/input_pipeline/_grain_data_processing.py
@@ -0,0 +1,79 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import dataclasses
import glob
import tensorflow as tf
import numpy as np
import grain.python as grain

from maxdiffusion import multihost_dataloading


def make_grain_iterator(
config,
dataloading_host_index,
dataloading_host_count,
mesh,
global_batch_size,
):
"""Use Grain data input pipeline with ArrayRecord data format"""
data_files = glob.glob(config.grain_train_files)
data_source = grain.ArrayRecordDataSource(data_files)

operations = []
operations.append(ParseFeatures())
operations.append(grain.Batch(batch_size=global_batch_size // dataloading_host_count, drop_remainder=True))

index_sampler = grain.IndexSampler(
num_records=len(data_source),
num_epochs=None,
shard_options=grain.ShardOptions(
shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=True
),
shuffle=True,
seed=config.seed,
)

dataloader = grain.DataLoader(
data_source=data_source,
operations=operations,
sampler=index_sampler,
worker_count=config.grain_worker_count,
)

data_iter = multihost_dataloading.MultiHostDataLoadIterator(dataloader, mesh)
return data_iter


@dataclasses.dataclass
class ParseFeatures(grain.MapTransform):
"""Parse serialized example"""

def __init__(self):
self.feature_description = {
"moments": tf.io.FixedLenFeature([], tf.string),
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),
}

def map(self, example):
def _parse(example):
features = tf.io.parse_single_example(example, self.feature_description)
moments = tf.io.parse_tensor(np.asarray(features["moments"]), out_type=tf.float32)
clip_embeddings = tf.io.parse_tensor(np.asarray(features["clip_embeddings"]), out_type=tf.float32)
return {"pixel_values": moments, "input_ids": clip_embeddings}

return _parse(example)
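
ParseFeatures expects each ArrayRecord entry to be a serialized tf.train.Example with "moments" and "clip_embeddings" bytes features. The following is a hypothetical sketch of writing one such record; it is not part of the commit, assumes the array_record package is installed, and uses illustrative shapes for a 512px SD 2 base setup:

import numpy as np
import tensorflow as tf
from array_record.python.array_record_module import ArrayRecordWriter

def _tensor_feature(t):
  # Serialize a tensor the same way ParseFeatures deserializes it (tf.io.parse_tensor).
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(t).numpy()]))

moments = tf.constant(np.random.randn(64, 64, 8).astype(np.float32))         # VAE encoder moments, 512 / 8 = 64
clip_embeddings = tf.constant(np.random.randn(77, 1024).astype(np.float32))  # text-encoder hidden states

example = tf.train.Example(features=tf.train.Features(feature={
    "moments": _tensor_feature(moments),
    "clip_embeddings": _tensor_feature(clip_embeddings),
}))

writer = ArrayRecordWriter("/tmp/sample.arrayrecord", "group_size:1")  # path is illustrative
writer.write(example.SerializeToString())
writer.close()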
15 changes: 7 additions & 8 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -87,30 +87,29 @@ def make_tfrecord_iterator(
maxdiffusion/pedagogical_examples/to_tfrecords.py
"""
feature_description = {
"latents": tf.io.FixedLenFeature([], tf.string),
"hidden_states": tf.io.FixedLenFeature([], tf.string),
"moments": tf.io.FixedLenFeature([], tf.string),
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),
}

def _parse_tfrecord_fn(example):
return tf.io.parse_single_example(example, feature_description)

def prepare_sample(features):
latents = tf.io.parse_tensor(tnp.asarray(features["latents"]), out_type=tf.float32)
hidden_states = tf.io.parse_tensor(tnp.asarray(features["hidden_states"]), out_type=tf.float32)
return {"pixel_values": latents, "input_ids": hidden_states}
moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
return {"pixel_values": moments, "input_ids": clip_embeddings}

filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
train_ds = (
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
.map(prepare_sample, num_parallel_calls=AUTOTUNE)
.shuffle(global_batch_size * 10)
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
.prefetch(AUTOTUNE)
.repeat(-1)
.prefetch(AUTOTUNE)
)

train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)

train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
return train_iter
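
The reworked pipeline shards the TFRecord files across hosts before parsing and then batches each host's slice, so the per-host batch is the global batch divided by the number of dataloading hosts. A small worked example with illustrative numbers, not part of the commit:

global_batch_size = 32        # per_device_batch_size * jax.device_count()
dataloading_host_count = 4    # jax.process_count()
per_host_batch = global_batch_size // dataloading_host_count  # 8 examples per host per step
# .shard(num_shards=4, index=i) gives host i every 4th record, so the hosts
# read disjoint subsets of the data while the global batch stays at 32.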
11 changes: 10 additions & 1 deletion src/maxdiffusion/input_pipeline/input_pipeline_interface.py
@@ -21,6 +21,7 @@
import jax

from maxdiffusion.input_pipeline import _hf_data_processing
from maxdiffusion.input_pipeline import _grain_data_processing
from maxdiffusion.input_pipeline import _tfds_data_processing
from maxdiffusion import multihost_dataloading
from maxdiffusion.maxdiffusion_utils import tokenize_captions, transform_images, vae_apply
@@ -61,6 +62,14 @@ def make_data_iterator(
tokenize_fn=tokenize_fn,
image_transforms_fn=image_transforms_fn,
)
elif config.dataset_type == "grain":
return _grain_data_processing.make_grain_iterator(
config,
dataloading_host_index,
dataloading_host_count,
mesh,
global_batch_size,
)
elif config.dataset_type == "tf":
return _tfds_data_processing.make_tf_iterator(
config,
@@ -80,7 +89,7 @@ def make_data_iterator(
global_batch_size,
)
else:
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf)"
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"


def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, vae, vae_params):
64 changes: 60 additions & 4 deletions src/maxdiffusion/tests/input_pipeline_interface_test.py
@@ -18,6 +18,7 @@
from functools import partial
import pathlib
import shutil
import subprocess
import unittest
from absl.testing import absltest

@@ -425,13 +426,68 @@ def test_make_pokemon_iterator_sdxl_cache(self):
config.resolution // vae_scale_factor,
)

def test_make_laion_grain_iterator(self):
try:
subprocess.check_output(
[
"bash",
"setup_gcsfuse.sh",
"DATASET_GCS_BUCKET=maxdiffusion-github-runner-test-assets",
"MOUNT_PATH=/tmp/gcsfuse",
],
stderr=subprocess.STDOUT,
)
except subprocess.CalledProcessError as e:
raise ValueError(f"setup_gcsfuse failed with error: {e.output}") from e
pyconfig.initialize(
[
None,
os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"),
"grain_train_files=/tmp/gcsfuse/datasets/array-record/laion400m/tf_records_512_encoder_state_fp32/*.arrayrecord",
"dataset_type=grain",
],
unittest=True,
)
config = pyconfig.config
global_batch_size = config.per_device_batch_size * jax.device_count()
devices_array = max_utils.create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)

pipeline, _ = FlaxStableDiffusionPipeline.from_pretrained(
config.pretrained_model_name_or_path,
revision=config.revision,
dtype=config.activations_dtype,
safety_checker=None,
feature_extractor=None,
from_pt=config.from_pt,
)

train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size)
data = next(train_iterator)
device_count = jax.device_count()

vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1)
encoder_hidden_states = data["input_ids"]

# TODO - laion dataset was prepared with an extra dim.
# need to preprocess the dataset with dim removed.
if len(encoder_hidden_states.shape) == 4:
encoder_hidden_states = jnp.squeeze(encoder_hidden_states)

assert encoder_hidden_states.shape == (device_count, 77, 1024)
assert data["pixel_values"].shape == (
config.total_train_batch_size,
config.resolution // vae_scale_factor,
config.resolution // vae_scale_factor,
8,
)

def test_make_laion_tfrecord_iterator(self):
pyconfig.initialize(
[
None,
os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"),
"cache_latents_text_encoder_outputs=True",
"train_data_dir=gs://jfacevedo-maxdiffusion/laion400m/processed/laion400m_tfrec",
"train_data_dir=gs://jfacevedo-maxdiffusion/laion400m/raw_data/tf_records_512_encoder_state_fp32",
"dataset_type=tfrecord",
],
unittest=True,
@@ -464,10 +520,10 @@ def test_make_laion_tfrecord_iterator(self):

assert encoder_hidden_states.shape == (device_count, 77, 1024)
assert data["pixel_values"].shape == (
device_count,
pipeline.unet.config.in_channels,
config.total_train_batch_size,
config.resolution // vae_scale_factor,
config.resolution // vae_scale_factor,
8,
)

def test_tfrecord(self):
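
Both laion tests assert the same decoded shapes, which follow from the SD 2 base VAE and text encoder geometry. A short sketch of that arithmetic, illustrative and not part of the commit:

resolution = 512
vae_block_out_channels = 4                            # len(pipeline.vae.config.block_out_channels) for SD 2 base
vae_scale_factor = 2 ** (vae_block_out_channels - 1)  # 8
latent_size = resolution // vae_scale_factor          # 64
moments_channels = 8                                  # 4 latent means + 4 log-variances per spatial position
# pixel_values: (total_train_batch_size, 64, 64, 8)
# input_ids:    (batch, 77, 1024) -> 77 CLIP tokens, 1024-dim hidden states for SD 2 base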
21 changes: 19 additions & 2 deletions src/maxdiffusion/trainers/stable_diffusion_trainer.py
@@ -29,7 +29,7 @@
from maxdiffusion import (FlaxDDPMScheduler, maxdiffusion_utils, train_utils, max_utils, max_logging)

from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator)

from maxdiffusion.models.vae_flax import FlaxDiagonalGaussianDistribution

from maxdiffusion.checkpointing.base_stable_diffusion_checkpointer import (STABLE_DIFFUSION_CHECKPOINT)

@@ -67,6 +67,19 @@ def get_shaped_batch(self, config, pipeline):
pipeline.text_encoder.config.hidden_size,
)
input_ids_dtype = jnp.float32
elif config.dataset_type in ("tfrecord", "grain"):
batch_image_shape = (
total_train_batch_size,
config.resolution // vae_scale_factor,
config.resolution // vae_scale_factor,
8,
)
batch_ids_shape = (
total_train_batch_size,
pipeline.text_encoder.config.max_position_embeddings,
pipeline.text_encoder.config.hidden_size,
)
input_ids_dtype = jnp.float32
else:
batch_image_shape = (total_train_batch_size, 3, config.resolution, config.resolution)
batch_ids_shape = (total_train_batch_size, pipeline.text_encoder.config.max_position_embeddings)
@@ -240,10 +253,14 @@ def _train_step(unet_state, vae_state, text_encoder_state, batch, train_rng, pip
state_params = {"unet": unet_state.params}

def compute_loss(state_params):

if config.dataset_type == "tf" and config.cache_latents_text_encoder_outputs:
latents = batch["pixel_values"]
encoder_hidden_states = batch["input_ids"]
elif config.dataset_type in ("tfrecord", "grain"):
latents = FlaxDiagonalGaussianDistribution(batch["pixel_values"]).sample(sample_rng)
latents = jnp.transpose(latents, (0, 3, 1, 2))
latents = latents * pipeline.vae.config.scaling_factor
encoder_hidden_states = batch["input_ids"]
else:
# Convert images to latent space
vae_outputs = pipeline.vae.apply(
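
The new FlaxDiagonalGaussianDistribution import lets the trainer turn stored VAE moments into latents for tfrecord and grain batches without running the VAE encoder. A minimal sketch of that step in isolation, with illustrative shapes; 0.18215 is the usual SD 2 base scaling factor, while the real code reads pipeline.vae.config.scaling_factor:

import jax
import jax.numpy as jnp
from maxdiffusion.models.vae_flax import FlaxDiagonalGaussianDistribution

sample_rng = jax.random.PRNGKey(0)
moments = jnp.zeros((4, 64, 64, 8))                                      # NHWC batch of stored VAE moments
latents = FlaxDiagonalGaussianDistribution(moments).sample(sample_rng)   # (4, 64, 64, 4)
latents = jnp.transpose(latents, (0, 3, 1, 2))                           # NCHW, as the UNet expects
latents = latents * 0.18215                                              # pipeline.vae.config.scaling_factor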
