From efa11e465997cc770d8c542de7a367baa9d2a3ae Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Fri, 13 Dec 2024 11:20:16 -0800
Subject: [PATCH] Support multi-LoRA loading (#137)

* adds multi lora support

* update readme
---
 README.md                                  | 11 +++++++
 src/maxdiffusion/generate_sdxl.py          | 15 ++++++---
 src/maxdiffusion/loaders/lora_pipeline.py  | 28 +++++++++--------
 src/maxdiffusion/maxdiffusion_utils.py     | 23 ++++++++------
 .../models/modeling_flax_pytorch_utils.py  | 31 ++++++++++++-------
 5 files changed, 68 insertions(+), 40 deletions(-)

diff --git a/README.md b/README.md
index 8f21ff5..c80b467 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@
 [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
 
 # What's new?
+- **`2024/12/12`**: Load multiple LoRAs for inference.
 - **`2024/10/22`**: LoRA support for Hyper SDXL.
 - **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
 - **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x,2.x is now supported.
@@ -33,6 +34,7 @@ MaxDiffusion supports
 * Stable Diffusion XL (training and inference).
 * Stable Diffusion Lightning (inference).
 * Hyper-SD XL LoRA loading (inference).
+* Loading multiple LoRAs (SDXL inference).
 * ControlNet inference (Stable Diffusion 1.4 & SDXL).
 * Dreambooth training support for Stable Diffusion 1.x,2.x.
@@ -45,6 +47,7 @@
   * [Dreambooth](#dreambooth)
 * [Inference](#inference)
   * [Hyper-SD XL LoRA](#hyper-sdxl-lora)
+  * [Loading Multiple LoRAs](#loading-multiple-loras)
   * [SDXL Lightning](#sdxl-lightning)
   * [ControlNet](#controlnet)
 * [Comparison To Alternatives](#comparison-to-alternatives)
@@ -139,6 +142,14 @@ To generate images, run the following command:
 ```bash
 python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=2 do_classifier_free_guidance=False prompt="a photograph of a cat wearing a hat riding a skateboard in a park." per_device_batch_size=1 pretrained_model_name_or_path="Lykon/AAM_XL_AnimeMix" from_pt=True revision=main diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}'
 ```
 
+## Loading Multiple LoRAs
+
+Loading multiple LoRAs for inference is supported, either from local files or from the Hugging Face Hub:
+
+```bash
+python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=30 do_classifier_free_guidance=True prompt="ultra detailed diagram blueprint of a papercut Sitting MaineCoon cat, wide canvas, ampereart, electrical diagram, bl3uprint, papercut" per_device_batch_size=1 diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors","TheLastBen/Papercut_SDXL"], "weight_name" : ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors","papercut.safetensors"], "adapter_name" : ["blueprint","papercut"], "scale": [0.8, 0.7], "from_pt": ["true", "true"]}'
+```
+
 ## SDXL Lightning
 
 Single- and multi-host inference is supported with sharding annotations:
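The `lora_config` flag in the commands above parses to a JSON object of parallel lists, where entry i of each list describes the i-th adapter. A minimal sketch of the parsed structure, using the values from the multi-LoRA command:

```python
# Parsed form of the lora_config flag above: parallel lists, one entry per adapter.
lora_config = {
    "lora_model_name_or_path": ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors", "TheLastBen/Papercut_SDXL"],
    "weight_name": ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors", "papercut.safetensors"],
    "adapter_name": ["blueprint", "papercut"],
    "scale": [0.8, 0.7],
    "from_pt": ["true", "true"],
}
# The loading loop indexes every list with the same i, so they must be equal length.
assert len({len(v) for v in lora_config.values()}) == 1
```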
diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py
index f7d1194..fcdd84a 100644
--- a/src/maxdiffusion/generate_sdxl.py
+++ b/src/maxdiffusion/generate_sdxl.py
@@ -16,6 +16,7 @@
 
 import functools
 from absl import app
+from contextlib import ExitStack
 from typing import Sequence
 import time
@@ -233,14 +234,15 @@ def run(config):
   params["unet"] = unet_params
 
   # maybe load LoRAs and create interceptors
-  params, lora_interceptor = maybe_load_lora(config, pipeline, params)
+  params, lora_interceptors = maybe_load_lora(config, pipeline, params)
 
   if config.lightning_repo:
     pipeline, params = load_sdxllightning_unet(config, pipeline, params)
 
   # Don't restore the full train state; instead, just restore params
   # and create an inference state.
-  with nn.intercept_methods(lora_interceptor):
+  with ExitStack() as stack:
+    for interceptor in lora_interceptors:
+      stack.enter_context(nn.intercept_methods(interceptor))
     unet_state, unet_state_shardings = max_utils.setup_initial_state(
         model=pipeline.unet,
         tx=None,
@@ -254,7 +256,8 @@ def run(config):
   vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
       pipeline, params, checkpoint_item_name="vae_state", is_training=False
   )
-  with nn.intercept_methods(lora_interceptor):
+  with ExitStack() as stack:
+    for interceptor in lora_interceptors:
+      stack.enter_context(nn.intercept_methods(interceptor))
     text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
         pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
     )
@@ -293,11 +296,13 @@ def run(config):
   )
 
   s = time.time()
-  with nn.intercept_methods(lora_interceptor):
+  with ExitStack() as stack:
+    for interceptor in lora_interceptors:
+      stack.enter_context(nn.intercept_methods(interceptor))
     p_run_inference(states).block_until_ready()
   print("compile time: ", (time.time() - s))
 
   s = time.time()
-  with nn.intercept_methods(lora_interceptor):
+  with ExitStack() as stack:
+    for interceptor in lora_interceptors:
+      stack.enter_context(nn.intercept_methods(interceptor))
     images = p_run_inference(states).block_until_ready()
   print("inference time: ", (time.time() - s))
   images = jax.experimental.multihost_utils.process_allgather(images)
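`contextlib.ExitStack` is what lets the script enter a number of `nn.intercept_methods` contexts that is only known at runtime. A self-contained sketch of its semantics, with dummy context managers standing in for the Flax interceptors:

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def fake_interceptor(name):
    # stand-in for nn.intercept_methods(interceptor)
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

interceptors = [fake_interceptor("blueprint"), fake_interceptor("papercut")]
with ExitStack() as stack:
    for cm in interceptors:
        stack.enter_context(cm)  # entered immediately, exit deferred to block end
    print("inference would run with all interceptors active")
# Output shows LIFO unwinding: enter blueprint, enter papercut,
# inference..., exit papercut, exit blueprint.
```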
diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py
index fe4d6e9..812a3ff 100644
--- a/src/maxdiffusion/loaders/lora_pipeline.py
+++ b/src/maxdiffusion/loaders/lora_pipeline.py
@@ -88,7 +88,7 @@ def load_lora_weights(
     return params, rank, network_alphas
 
   @classmethod
-  def _get_lora_layer(cls, module_path, module, rank, network_alphas):
+  def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name):
     is_conv = any("conv" in str_ for str_ in module_path)
     network_alpha = network_alphas.get(module_path, None)
     if is_conv:
@@ -105,7 +105,7 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):
           dtype=module.dtype,
           weights_dtype=module.param_dtype,
           precision=module.precision,
-          name="lora",
+          name=f"lora-{adapter_name}",
       )
     else:
       lora_module = LoRALinearLayer(
@@ -115,16 +115,17 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):
           dtype=module.dtype,
           weights_dtype=module.param_dtype,
           precision=module.precision,
-          name="lora",
+          name=f"lora-{adapter_name}",
       )
     return lora_module
 
-  def rename_for_interceptor(params_keys, network_alphas):
+  def rename_for_interceptor(params_keys, network_alphas, adapter_name):
     new_params_keys = []
     new_network_alphas = {}
+    lora_name = f"lora-{adapter_name}"
     for layer_lora in params_keys:
-      if "lora" in layer_lora:
-        new_layer_lora = layer_lora[: layer_lora.index("lora")]
+      if lora_name in layer_lora:
+        new_layer_lora = layer_lora[: layer_lora.index(lora_name)]
         if new_layer_lora not in new_params_keys:
           new_params_keys.append(new_layer_lora)
           network_alpha = network_alphas[layer_lora]
@@ -132,22 +133,23 @@ def rename_for_interceptor(params_keys, network_alphas):
     return new_params_keys, new_network_alphas
 
   @classmethod
-  def make_lora_interceptor(cls, params, rank, network_alphas):
+  def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
     # Only unet interceptor supported for now.
     network_alphas_for_interceptor = {}
 
     unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()
-    lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
+    lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas, adapter_name)
     network_alphas_for_interceptor.update(unet_alphas)
 
     text_encoder_keys = flax.traverse_util.flatten_dict(params["text_encoder"]).keys()
-    text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas)
+    text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas, adapter_name)
     lora_keys.extend(text_encoder_keys)
     network_alphas_for_interceptor.update(text_encoder_alphas)
-
     if "text_encoder_2" in params.keys():
       text_encoder_2_keys = flax.traverse_util.flatten_dict(params["text_encoder_2"]).keys()
-      text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas)
+      text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(
+          text_encoder_2_keys, network_alphas, adapter_name
+      )
       lora_keys.extend(text_encoder_2_keys)
       network_alphas_for_interceptor.update(text_encoder_2_alphas)
@@ -161,7 +163,7 @@ def _intercept(next_fn, args, kwargs, context):
       if context.method_name == "__call__":
         module_path = context.module.path
         if module_path in lora_keys:
-          lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor)
+          lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor, adapter_name)
           return lora_layer(h, *args, **kwargs)
       return h
@@ -290,5 +292,5 @@ def load_lora(cls, state_dict, network_alphas, params, adapter_name=None, _pipeline=None):
             `default_{i}` where i is the total number of adapters being loaded.
     """
     # Load the layers corresponding to Unet.
-    params, rank, network_alphas = convert_lora_pytorch_state_dict_to_flax(state_dict, params, network_alphas)
+    params, rank, network_alphas = convert_lora_pytorch_state_dict_to_flax(state_dict, params, network_alphas, adapter_name)
     return params, rank, network_alphas
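Naming each injected layer `lora-{adapter_name}` instead of a fixed `"lora"` is what keeps the adapters' parameters distinct and lets `rename_for_interceptor` recover, per adapter, the module paths to wrap. A pure-Python sketch of that key handling; the flattened keys are made up but follow the shape the renaming produces:

```python
adapter_name = "papercut"  # hypothetical adapter
lora_name = f"lora-{adapter_name}"
params_keys = [
    ("unet", "down_blocks_1", "attentions_0", "to_q", lora_name, "up", "kernel"),
    ("unet", "down_blocks_1", "attentions_0", "to_q", lora_name, "down", "kernel"),
]
module_paths = []
for key in params_keys:
    if lora_name in key:
        # everything before the lora component identifies the wrapped module
        path = key[: key.index(lora_name)]
        if path not in module_paths:
            module_paths.append(path)
print(module_paths)  # [('unet', 'down_blocks_1', 'attentions_0', 'to_q')]
```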
diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py
index 029641e..60b1730 100644
--- a/src/maxdiffusion/maxdiffusion_utils.py
+++ b/src/maxdiffusion/maxdiffusion_utils.py
@@ -45,21 +45,24 @@ def maybe_load_lora(config, pipeline, params):
   def _noop_interceptor(next_fn, args, kwargs, context):
     return next_fn(*args, **kwargs)
 
   lora_config = config.lora_config
-  interceptor = _noop_interceptor
+  interceptors = [_noop_interceptor]
   if len(lora_config["lora_model_name_or_path"]) > 0:
     # In the future, LoRAs will be merged before being loaded.
     # TODO - merge LoRAs here.
-    params, rank, network_alphas = pipeline.load_lora_weights(
-        lora_config["lora_model_name_or_path"][0],
-        weight_name=lora_config["weight_name"][0],
-        params=params,
-        adapter_name=lora_config["adapter_name"][0],
-        unet_config=pipeline.unet.config,
-    )
-    interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas)
+    interceptors = []
+    for i in range(len(lora_config["lora_model_name_or_path"])):
+      params, rank, network_alphas = pipeline.load_lora_weights(
+          lora_config["lora_model_name_or_path"][i],
+          weight_name=lora_config["weight_name"][i],
+          params=params,
+          adapter_name=lora_config["adapter_name"][i],
+          unet_config=pipeline.unet.config,
+      )
+      interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i])
+      interceptors.append(interceptor)
 
-  return params, interceptor
+  return params, interceptors
 
 
 def vae_apply(images, sample_rng, vae, vae_params):
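The loop above indexes the parallel lists with `range(len(...))`; an equivalent `zip`-based formulation pairs the entries explicitly. A sketch with hypothetical config values:

```python
lora_config = {
    "lora_model_name_or_path": ["TheLastBen/Papercut_SDXL"],
    "weight_name": ["papercut.safetensors"],
    "adapter_name": ["papercut"],
}
# zip walks the parallel lists together, one iteration per adapter
for path, weight, adapter in zip(
    lora_config["lora_model_name_or_path"],
    lora_config["weight_name"],
    lora_config["adapter_name"],
):
    print(f"load {weight} from {path} as adapter {adapter!r}")
```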
diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
index 3fa5417..5cb7b9e 100644
--- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
+++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -130,7 +130,13 @@ def get_network_alpha_value(pt_key, network_alphas):
 
 def create_flax_params_from_pytorch_state(
-    pt_state_dict, unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, network_alphas, is_lora=False
+    pt_state_dict,
+    unet_state_dict,
+    text_encoder_state_dict,
+    text_encoder_2_state_dict,
+    network_alphas,
+    adapter_name,
+    is_lora=False,
 ):
   rank = None
   renamed_network_alphas = {}
@@ -157,19 +163,21 @@ def create_flax_params_from_pytorch_state(
     flax_key_list = [*pt_tuple_key]
     if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key:
       rename_from_to = (
-          ("to_k_lora", ("k_proj", "lora")),
-          ("to_q_lora", ("q_proj", "lora")),
-          ("to_v_lora", ("v_proj", "lora")),
-          ("to_out_lora", ("out_proj", "lora")),
+          ("to_k_lora", ("k_proj", f"lora-{adapter_name}")),
+          ("to_q_lora", ("q_proj", f"lora-{adapter_name}")),
+          ("to_v_lora", ("v_proj", f"lora-{adapter_name}")),
+          ("to_out_lora", ("out_proj", f"lora-{adapter_name}")),
+          ("lora", f"lora-{adapter_name}"),
           ("weight", "kernel"),
       )
     # the unet
     else:
       rename_from_to = (
-          ("to_k_lora", ("to_k", "lora")),
-          ("to_q_lora", ("to_q", "lora")),
-          ("to_v_lora", ("to_v", "lora")),
-          ("to_out_lora", ("to_out_0", "lora")),
+          ("to_k_lora", ("to_k", f"lora-{adapter_name}")),
+          ("to_q_lora", ("to_q", f"lora-{adapter_name}")),
+          ("to_v_lora", ("to_v", f"lora-{adapter_name}")),
+          ("to_out_lora", ("to_out_0", f"lora-{adapter_name}")),
+          ("lora", f"lora-{adapter_name}"),
           ("weight", "kernel"),
       )
     for rename_from, rename_to in rename_from_to:
@@ -206,11 +214,10 @@ def create_flax_params_from_pytorch_state(
       if network_alpha_value >= 0:
         renamed_network_alphas[tuple(flax_key_list)] = network_alpha_value
-
   return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas
 
 
-def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas):
+def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas, adapter_name):
  # Step 1: Convert pytorch tensor to numpy
  # sometimes we load weights in bf16 and numpy doesn't support it
  pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}
@@ -223,7 +230,7 @@ def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas):
   text_encoder_2_params = None
   (unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, network_alphas) = (
       create_flax_params_from_pytorch_state(
-          pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, network_alphas, is_lora=True
+          pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, network_alphas, adapter_name, is_lora=True
       )
   )
   params["unet"] = unflatten_dict(unet_state_dict)
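The hunks above show the `rename_from_to` tables but elide the loop that applies them. A sketch of how such a table could map one PyTorch LoRA key to its Flax counterpart, assuming each matching path component is replaced and tuple replacements are spliced in (the key and loop here are illustrative, not the module's actual code):

```python
adapter_name = "papercut"  # hypothetical adapter
rename_from_to = (
    ("to_q_lora", ("to_q", f"lora-{adapter_name}")),
    ("weight", "kernel"),
)
pt_key = ("down_blocks_1", "attentions_0", "to_q_lora", "up", "weight")
flax_key = []
for part in pt_key:
    # replace the component if the table has an entry for it
    for rename_from, rename_to in rename_from_to:
        if part == rename_from:
            part = rename_to
            break
    if isinstance(part, tuple):
        flax_key.extend(part)  # tuple replacements expand into several components
    else:
        flax_key.append(part)
print(tuple(flax_key))
# ('down_blocks_1', 'attentions_0', 'to_q', 'lora-papercut', 'up', 'kernel')
```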