From f75de2f0492a311a2df6200d7883710449b11a82 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Tue, 10 Dec 2024 00:38:01 +0000
Subject: [PATCH 1/4] supports other format loras for unet

---
 src/maxdiffusion/maxdiffusion_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py
index fb1a827..5ddac4e 100644
--- a/src/maxdiffusion/maxdiffusion_utils.py
+++ b/src/maxdiffusion/maxdiffusion_utils.py
@@ -54,6 +54,7 @@ def _noop_interceptor(next_fn, args, kwargs, context):
         weight_name=lora_config["weight_name"][0],
         params=params,
         adapter_name=lora_config["adapter_name"][0],
+        unet_config=pipeline.unet.config
     )
 
     interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas)

From 718141fb5c2f4afb2f1dbe7a93306c7a82bb92d7 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Tue, 10 Dec 2024 16:50:02 +0000
Subject: [PATCH 2/4] don't rename key for text encoders so that pytree matches original.

---
 src/maxdiffusion/models/modeling_flax_pytorch_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
index 5f02ec8..9da8646 100644
--- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
+++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -137,7 +137,11 @@ def create_flax_params_from_pytorch_state(
   # Need to change some parameters name to match Flax names
   for pt_key, pt_tensor in pt_state_dict.items():
     network_alpha_value = get_network_alpha_value(pt_key, network_alphas)
-    renamed_pt_key = rename_key(pt_key)
+
+    # only rename the unet keys, text encoders are already correct.
+    if "unet" in pt_key:
+      renamed_pt_key = rename_key(pt_key)
+
     pt_tuple_key = tuple(renamed_pt_key.split("."))
     # conv
     if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:

From 115467fd7e3b74dc20b73cca15f758e59b1aa81f Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Wed, 11 Dec 2024 19:42:30 +0000
Subject: [PATCH 3/4] add text encoder support.

---
 src/maxdiffusion/generate_sdxl.py                 | 15 +++++-----
 src/maxdiffusion/loaders/lora_pipeline.py         | 28 ++++++++++++++-----
 .../models/modeling_flax_pytorch_utils.py         | 23 ++++++++++++---
 3 files changed, 48 insertions(+), 18 deletions(-)

diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py
index be750fb..f7d1194 100644
--- a/src/maxdiffusion/generate_sdxl.py
+++ b/src/maxdiffusion/generate_sdxl.py
@@ -225,6 +225,7 @@ def run(config):
       pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False
   )
 
+  # load unet params from orbax checkpoint
   unet_params = load_params_from_path(
       config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "unet_state"
   )
@@ -253,14 +254,14 @@
   vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
       pipeline, params, checkpoint_item_name="vae_state", is_training=False
   )
 
-  text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
-      pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
-  )
-
-  text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state(
-      pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False
-  )
+  with nn.intercept_methods(lora_interceptor):
+    text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
+        pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
+    )
+    text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state(
+        pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False
+    )
 
   states = {}
   state_shardings = {}

diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py
index 4b50743..c1d085f 100644
--- a/src/maxdiffusion/loaders/lora_pipeline.py
+++ b/src/maxdiffusion/loaders/lora_pipeline.py
@@ -121,22 +121,36 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):
   def rename_for_interceptor(params_keys, network_alphas):
     new_params_keys = []
+    new_network_alphas = {}
     for layer_lora in params_keys:
       if "lora" in layer_lora:
         new_layer_lora = layer_lora[: layer_lora.index("lora")]
         if new_layer_lora not in new_params_keys:
           new_params_keys.append(new_layer_lora)
           network_alpha = network_alphas[layer_lora]
-          del network_alphas[layer_lora]
-          network_alphas[new_layer_lora] = network_alpha
-    return new_params_keys, network_alphas
+          new_network_alphas[new_layer_lora] = network_alpha
+    return new_params_keys, new_network_alphas
 
   @classmethod
   def make_lora_interceptor(cls, params, rank, network_alphas):
     # Only unet interceptor supported for now.
+    network_alphas_for_interceptor = {}
+
     unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()
-    unet_lora_keys, network_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
-
+    lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
+    network_alphas_for_interceptor.update(unet_alphas)
+
+    text_encoder_keys = flax.traverse_util.flatten_dict(params["text_encoder"]).keys()
+    text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas)
+    lora_keys.extend(text_encoder_keys)
+    network_alphas_for_interceptor.update(text_encoder_alphas)
+
+    if "text_encoder_2" in params.keys():
+      text_encoder_2_keys = flax.traverse_util.flatten_dict(params["text_encoder_2"]).keys()
+      text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas)
+      lora_keys.extend(text_encoder_2_keys)
+      network_alphas_for_interceptor.update(text_encoder_2_alphas)
+
     def _intercept(next_fn, args, kwargs, context):
       mod = context.module
       while mod is not None:
@@ -146,8 +160,8 @@
       h = next_fn(*args, **kwargs)
       if context.method_name == "__call__":
         module_path = context.module.path
-        if module_path in unet_lora_keys:
-          lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas)
+        if module_path in lora_keys:
+          lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor)
           return lora_layer(h, *args, **kwargs)
       return h
 
diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
index 9da8646..7e5e0ae 100644
--- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
+++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -137,11 +137,15 @@ def create_flax_params_from_pytorch_state(
   # Need to change some parameters name to match Flax names
   for pt_key, pt_tensor in pt_state_dict.items():
     network_alpha_value = get_network_alpha_value(pt_key, network_alphas)
-
+
+    # rename text encoders fc1 lora layers.
+    pt_key = pt_key.replace("lora_linear_layer","lora")
+
     # only rename the unet keys, text encoders are already correct.
if "unet" in pt_key: renamed_pt_key = rename_key(pt_key) - + else: + renamed_pt_key = pt_key pt_tuple_key = tuple(renamed_pt_key.split(".")) # conv if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: @@ -151,13 +155,24 @@ def create_flax_params_from_pytorch_state( flax_tensor = pt_tensor else: flax_key_list = [*pt_tuple_key] - for rename_from, rename_to in ( + if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key: + rename_from_to = ( + ("to_k_lora", ("k_proj", "lora")), + ("to_q_lora", ("q_proj", "lora")), + ("to_v_lora", ("v_proj", "lora")), + ("to_out_lora", ("out_proj", "lora")), + ("weight", "kernel"), + ) + # the unet + else: + rename_from_to = ( ("to_k_lora", ("to_k", "lora")), ("to_q_lora", ("to_q", "lora")), ("to_v_lora", ("to_v", "lora")), ("to_out_lora", ("to_out_0", "lora")), ("weight", "kernel"), - ): + ) + for rename_from, rename_to in rename_from_to: tmp = [] for s in flax_key_list: if s == rename_from: From 83e16121aac7925a0465fefb095bc3e773e3cc33 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 11 Dec 2024 20:51:03 +0000 Subject: [PATCH 4/4] formatting --- src/maxdiffusion/loaders/lora_pipeline.py | 4 ++-- src/maxdiffusion/maxdiffusion_utils.py | 3 ++- .../models/modeling_flax_pytorch_utils.py | 24 +++++++++---------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py index c1d085f..fe4d6e9 100644 --- a/src/maxdiffusion/loaders/lora_pipeline.py +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -135,7 +135,7 @@ def rename_for_interceptor(params_keys, network_alphas): def make_lora_interceptor(cls, params, rank, network_alphas): # Only unet interceptor supported for now. network_alphas_for_interceptor = {} - + unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys() lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas) network_alphas_for_interceptor.update(unet_alphas) @@ -150,7 +150,7 @@ def make_lora_interceptor(cls, params, rank, network_alphas): text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas) lora_keys.extend(text_encoder_2_keys) network_alphas_for_interceptor.update(text_encoder_2_alphas) - + def _intercept(next_fn, args, kwargs, context): mod = context.module while mod is not None: diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index 5ddac4e..029641e 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -49,12 +49,13 @@ def _noop_interceptor(next_fn, args, kwargs, context): if len(lora_config["lora_model_name_or_path"]) > 0: # For now only first lora supported. In the future, they will be merged # before being loaded. + # TODO - merge LoRAs here. 
     params, rank, network_alphas = pipeline.load_lora_weights(
         lora_config["lora_model_name_or_path"][0],
         weight_name=lora_config["weight_name"][0],
         params=params,
         adapter_name=lora_config["adapter_name"][0],
-        unet_config=pipeline.unet.config
+        unet_config=pipeline.unet.config,
     )
 
     interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas)

diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
index 7e5e0ae..3fa5417 100644
--- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
+++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -139,7 +139,7 @@ def create_flax_params_from_pytorch_state(
     network_alpha_value = get_network_alpha_value(pt_key, network_alphas)
 
     # rename text encoders fc1 lora layers.
-    pt_key = pt_key.replace("lora_linear_layer","lora")
+    pt_key = pt_key.replace("lora_linear_layer", "lora")
 
     # only rename the unet keys, text encoders are already correct.
     if "unet" in pt_key:
       renamed_pt_key = rename_key(pt_key)
@@ -157,20 +157,20 @@
       flax_key_list = [*pt_tuple_key]
       if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key:
         rename_from_to = (
-          ("to_k_lora", ("k_proj", "lora")),
-          ("to_q_lora", ("q_proj", "lora")),
-          ("to_v_lora", ("v_proj", "lora")),
-          ("to_out_lora", ("out_proj", "lora")),
-          ("weight", "kernel"),
+            ("to_k_lora", ("k_proj", "lora")),
+            ("to_q_lora", ("q_proj", "lora")),
+            ("to_v_lora", ("v_proj", "lora")),
+            ("to_out_lora", ("out_proj", "lora")),
+            ("weight", "kernel"),
         )
       # the unet
-      else:
+      else:
         rename_from_to = (
-          ("to_k_lora", ("to_k", "lora")),
-          ("to_q_lora", ("to_q", "lora")),
-          ("to_v_lora", ("to_v", "lora")),
-          ("to_out_lora", ("to_out_0", "lora")),
-          ("weight", "kernel"),
+            ("to_k_lora", ("to_k", "lora")),
+            ("to_q_lora", ("to_q", "lora")),
+            ("to_v_lora", ("to_v", "lora")),
+            ("to_out_lora", ("to_out_0", "lora")),
+            ("weight", "kernel"),
         )
       for rename_from, rename_to in rename_from_to:
         tmp = []
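
---
Usage sketch (illustrative, not part of the patches): after PATCH 3/4, the single interceptor returned by make_lora_interceptor is expected to match LoRA layers in the unet and in both text encoders, so callers wrap text-encoder state creation the same way the unet path already is. A minimal sketch of the intended call pattern, assuming the pipeline, params, and checkpoint_loader objects used in generate_sdxl.py above:

  import flax.linen as nn

  # params, rank, and network_alphas come from pipeline.load_lora_weights(...),
  # as in maxdiffusion_utils.py (PATCH 1/4 and 4/4).
  lora_interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas)

  # Any Flax module __call__ executed under this context manager is checked against
  # lora_keys; matching modules have their output passed through the LoRA layer
  # built by _get_lora_layer (PATCH 3/4).
  with nn.intercept_methods(lora_interceptor):
    text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
        pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
    )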