Dreambooth training (#90)
* dreambooth input pipeline and training script

---------

Co-authored-by: Juan Acevedo <[email protected]>
entrpn and jfacevedo-google authored Jul 21, 2024
1 parent 57629bc commit 67ed9cc
Showing 43 changed files with 1,242 additions and 251 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -36,6 +36,7 @@ MaxDiffusion supports
* [Getting Started](#getting-started)
* [Local Development for single host](#getting-started-local-development-for-single-host)
* [Training](#training)
* [Dreambooth](#dreambooth)
* [Inference](#inference)
* [SDXL Lightning](#sdxl-lightning)
* [ControlNet](#controlnet)
@@ -90,6 +91,14 @@ After installation completes, run the training script.
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base_2_base.yml run_name="my_run" pretrained_model_name_or_path=<your_saved_checkpoint_path> from_pt=False attention=dot_product
```

## Dreambooth

**Stable Diffusion 1.x, 2.x**

DreamBooth binds the subject to a rare identifier token in the instance prompt (here `ohwx`), while the class prompt describes the broader class used for prior preservation.

```bash
python src/maxdiffusion/dreambooth/train_dreambooth.py src/maxdiffusion/configs/base15.yml \
  class_data_dir=<your-class-dir> instance_data_dir=<your-instance-dir> \
  instance_prompt="a photo of ohwx dog" class_prompt="a photo of a dog" \
  max_train_steps=150 cache_dir=<your-cache-dir> \
  activations_dtype=bfloat16 weights_dtype=float32 per_device_batch_size=1 \
  enable_profiler=False precision=DEFAULT cache_dreambooth_dataset=False \
  learning_rate=4e-6 output_dir=<your-output-dir> num_class_images=100 \
  run_name=<your-run-name> base_output_directory=gs://<your-bucket-name>
```

## Inference

To generate images, run the following command:
35 changes: 29 additions & 6 deletions src/maxdiffusion/configs/base15.yml
@@ -26,7 +26,17 @@ log_period: 10000000000 # Flushes Tensorboard
pretrained_model_name_or_path: 'runwayml/stable-diffusion-v1-5'
unet_checkpoint: ''
revision: 'flax'
dtype: 'bfloat16'

# This will convert the weights to this dtype.
weights_dtype: 'float32'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
precision: "DEFAULT"
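The `weights_dtype` / `activations_dtype` split mirrors Flax's `param_dtype` / `dtype` arguments, with `precision` forwarded to the underlying matmuls. A minimal sketch (illustrative only, not code from this commit) of how the three settings typically map onto a single Flax layer:

```python
# Illustrative only: weights stored in float32, activations computed in bfloat16,
# matmul precision left at DEFAULT -- mirroring the config keys above.
import jax
import jax.numpy as jnp
import flax.linen as nn

weights_dtype = jnp.float32             # weights_dtype: 'float32'
activations_dtype = jnp.bfloat16        # activations_dtype: 'bfloat16'
precision = jax.lax.Precision.DEFAULT   # precision: "DEFAULT"

layer = nn.Dense(features=8, dtype=activations_dtype,
                 param_dtype=weights_dtype, precision=precision)
x = jnp.ones((2, 4), dtype=activations_dtype)
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)              # params stay float32, y is bfloat16
```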
# orbax or diffusers
checkpoint_type: 'orbax'
# Set true to load weights from pytorch
from_pt: False
split_head_dim: True
@@ -88,11 +98,13 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_length', 'fsdp'],
['out_channels', 'fsdp'],
['activation_batch', ['data','fsdp']],
['activation_heads', 'tensor'],
['activation_kv', 'tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['length', 'fsdp']
]
data_sharding: [['data', 'fsdp', 'tensor']]
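Each rule maps a logical axis name used inside the model to one of the mesh axes declared in `mesh_axes`. A minimal sketch (using Flax's logical-partitioning helpers; illustrative, not code from this commit) of how a rule set like the one above resolves an annotated array:

```python
# Illustrative only: resolve logical axis names to mesh axes via a rule set
# shaped like logical_axis_rules above.
import flax.linen as nn

rules = (
    ("activation_batch", ("data", "fsdp")),
    ("embed", "fsdp"),
    ("heads", "tensor"),
)

# A parameter annotated with logical axes ('embed', 'heads') is sharded
# over the ('fsdp', 'tensor') mesh axes.
spec = nn.logical_to_mesh_axes(("embed", "heads"), rules)
print(spec)  # PartitionSpec('fsdp', 'tensor')
```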

Expand Down Expand Up @@ -157,6 +169,7 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_weight_decay: 1.e-2 # AdamW Weight decay
max_grad_norm: 1.0

enable_profiler: True
# Skip first n steps for profiling, to omit things like compilation and to give
@@ -178,4 +191,14 @@ enable_mllog: False
controlnet_model_name_or_path: 'lllyasviel/sd-controlnet-canny'
controlnet_from_pt: True
controlnet_conditioning_scale: 1.0
controlnet_image: 'https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg'

# dreambooth - this script always uses prior preservation.
instance_data_dir: ''
class_data_dir: ''
instance_prompt: ''
class_prompt: ''
prior_loss_weight: 1.0
num_class_images: 100
# If true, set dataset_save_location.
cache_dreambooth_dataset: False
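`prior_loss_weight` controls the standard DreamBooth prior-preservation term: the class ("prior") half of the batch gets its own loss that is scaled and added to the instance loss. A minimal sketch of that combination (illustrative, not the actual training step in this commit):

```python
# Illustrative only: instance and class examples are stacked in one batch and
# the class ("prior") term is scaled by prior_loss_weight.
import jax.numpy as jnp

def dreambooth_loss(model_pred, target, prior_loss_weight=1.0):
    # First half of the batch: instance examples; second half: class examples.
    inst_pred, prior_pred = jnp.split(model_pred, 2, axis=0)
    inst_target, prior_target = jnp.split(target, 2, axis=0)
    instance_loss = jnp.mean((inst_pred - inst_target) ** 2)
    prior_loss = jnp.mean((prior_pred - prior_target) ** 2)
    return instance_loss + prior_loss_weight * prior_loss
```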
17 changes: 15 additions & 2 deletions src/maxdiffusion/configs/base21.yml
@@ -26,7 +26,10 @@ log_period: 10000000000 # Flushes Tensorboard
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-1'
unet_checkpoint: ''
revision: 'bf16'
dtype: 'bfloat16'
# This will convert the weights to this dtype.
weights_dtype: 'float32'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'
# Set true to load weights from pytorch
from_pt: False
split_head_dim: True
@@ -174,4 +177,14 @@ guidance_scale: 7.5
guidance_rescale: 0.0
num_inference_steps: 30

enable_mllog: False

# dreambooth - this script always uses prior preservation.
instance_data_dir: ''
class_data_dir: ''
instance_prompt: ''
class_prompt: ''
prior_loss_weight: 1.0
num_class_images: 100
# If true, set dataset_save_location.
cache_dreambooth_dataset: False
18 changes: 16 additions & 2 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -26,7 +26,11 @@ log_period: 10000000000 # Flushes Tensorboard
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-base'
unet_checkpoint: ''
revision: 'main'
dtype: 'bfloat16'

# This will convert the weights to this dtype.
weights_dtype: 'float32'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
@@ -190,4 +194,14 @@ extracted_files_dir: ""
tfrecords_dir: ""
no_records_per_shard: 1000

enable_mllog: False

# dreambooth - this script always uses prior preservation.
instance_data_dir: ''
class_data_dir: ''
instance_prompt: ''
class_prompt: ''
prior_loss_weight: 1.0
num_class_images: 100
# If true, set dataset_save_location.
cache_dreambooth_dataset: False
6 changes: 5 additions & 1 deletion src/maxdiffusion/configs/base_xl.yml
@@ -26,7 +26,11 @@ log_period: 100
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0'
unet_checkpoint: ''
revision: 'refs/pr/95'
dtype: 'bfloat16'
# This will convert the weights to this dtype.
# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
weights_dtype: 'float32'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'
# Set true to load weights from pytorch
from_pt: False
split_head_dim: True
5 changes: 4 additions & 1 deletion src/maxdiffusion/configs/base_xl_lightning.yml
@@ -25,7 +25,10 @@ log_period: 100
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0'
unet_checkpoint: ''
revision: 'refs/pr/95'
dtype: 'bfloat16'
# This will convert the weights to this dtype.
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'
# Set true to load weights from pytorch
from_pt: False
split_head_dim: True
1 change: 1 addition & 0 deletions src/maxdiffusion/configuration_utils.py
@@ -582,6 +582,7 @@ def to_json_saveable(value):
config_dict.pop("_ignore_files", None)
config_dict.pop("_use_default_values", None)
config_dict.pop("mesh", None)
config_dict.pop("precision", None)

return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

7 changes: 3 additions & 4 deletions src/maxdiffusion/controlnet/generate_controlnet_replicated.py
@@ -14,7 +14,6 @@
limitations under the License.
"""

import os
from typing import Sequence
from absl import app

@@ -23,13 +22,10 @@
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax.experimental.compilation_cache import compilation_cache as cc
from maxdiffusion import pyconfig
from maxdiffusion.utils import load_image
from maxdiffusion import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel

cc.set_cache_dir(os.path.expanduser("~/jax_cache"))

def run(config):

rng = jax.random.PRNGKey(config.seed)
@@ -84,6 +80,9 @@ def run(config):

def main(argv: Sequence[str]) -> None:
  pyconfig.initialize(argv)
  config = pyconfig.config
  if len(config.cache_dir) > 0:
    jax.config.update("jax_compilation_cache_dir", config.cache_dir)
  run(pyconfig.config)

if __name__ == "__main__":
17 changes: 6 additions & 11 deletions src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py
@@ -14,7 +14,6 @@
limitations under the License.
"""

import os
from typing import Sequence
from absl import app

@@ -23,24 +22,17 @@
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax.experimental.compilation_cache import compilation_cache as cc
from maxdiffusion.utils import load_image
from PIL import Image
from maxdiffusion import pyconfig
from maxdiffusion import FlaxStableDiffusionXLControlNetPipeline, FlaxControlNetModel
from maxdiffusion.max_utils import (
get_dtype
)
import cv2

cc.set_cache_dir(os.path.expanduser("~/jax_cache"))

def create_key(seed=0):
return jax.random.PRNGKey(seed)

def run(config):
rng = jax.random.PRNGKey(config.seed)
weight_dtype = get_dtype(config)

prompts = config.prompt
negative_prompts = config.negative_prompt
@@ -57,14 +49,14 @@ def run(config):
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
config.controlnet_model_name_or_path,
from_pt=config.controlnet_from_pt,
dtype=weight_dtype
dtype=config.activations_dtype
)

pipe, params = FlaxStableDiffusionXLControlNetPipeline.from_pretrained(
config.pretrained_model_name_or_path,
controlnet=controlnet,
revision=config.revision,
dtype=weight_dtype
dtype=config.activations_dtype
)

scheduler_state = params.pop("scheduler")
@@ -97,10 +89,13 @@ def run(config):

output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
output_images[0].save("generated_image.png")
return output_images[0]
return output_images

def main(argv: Sequence[str]) -> None:
  pyconfig.initialize(argv)
  config = pyconfig.config
  if len(config.cache_dir) > 0:
    jax.config.update("jax_compilation_cache_dir", config.cache_dir)
  run(pyconfig.config)

if __name__ == "__main__":
10 changes: 10 additions & 0 deletions src/maxdiffusion/dreambooth/dreambooth_constants.py
@@ -0,0 +1,10 @@
INSTANCE_IMAGES = "instance_images"
INSTANCE_IMAGE_LATENTS = "instance_image_latents"
INSTANCE_PROMPT_IDS = "instance_prompt_ids"
INSTANCE_PROMPT_INPUT_IDS = "instance_prompt_input_ids"
CLASS_IMAGES = "class_images"
CLASS_IMAGE_LATENTS = "class_image_latents"
CLASS_PROMPT_IDS = "class_prompt_ids"
CLASS_PROMPT_INPUT_IDS = "class_prompt_input_ids"
INSTANCE_DATASET_NAME = "instance_dataset"
CLASS_DATASET_NAME = "class_dataset"
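These constants are the string keys presumably shared between the dreambooth input pipeline and the training script added in this commit. A hypothetical sketch of a batch dict keyed by them (shapes are placeholders, not taken from this commit):

```python
# Hypothetical illustration only: a combined batch keyed by the constants above.
# Shapes are placeholders; real arrays come from the dreambooth input pipeline.
import jax.numpy as jnp
from maxdiffusion.dreambooth import dreambooth_constants as c

batch = {
    c.INSTANCE_IMAGE_LATENTS: jnp.zeros((1, 4, 64, 64)),
    c.INSTANCE_PROMPT_INPUT_IDS: jnp.zeros((1, 77), dtype=jnp.int32),
    c.CLASS_IMAGE_LATENTS: jnp.zeros((1, 4, 64, 64)),
    c.CLASS_PROMPT_INPUT_IDS: jnp.zeros((1, 77), dtype=jnp.int32),
}
```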