From 83e16121aac7925a0465fefb095bc3e773e3cc33 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Wed, 11 Dec 2024 20:51:03 +0000
Subject: [PATCH] formatting

---
 src/maxdiffusion/loaders/lora_pipeline.py        |  4 ++--
 src/maxdiffusion/maxdiffusion_utils.py           |  3 ++-
 .../models/modeling_flax_pytorch_utils.py        | 24 +++++++++----------
 3 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py
index c1d085f..fe4d6e9 100644
--- a/src/maxdiffusion/loaders/lora_pipeline.py
+++ b/src/maxdiffusion/loaders/lora_pipeline.py
@@ -135,7 +135,7 @@ def rename_for_interceptor(params_keys, network_alphas):
   def make_lora_interceptor(cls, params, rank, network_alphas):
     # Only unet interceptor supported for now.
     network_alphas_for_interceptor = {}
-    
+
     unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()
     lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
     network_alphas_for_interceptor.update(unet_alphas)
@@ -150,7 +150,7 @@ def make_lora_interceptor(cls, params, rank, network_alphas):
       text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas)
       lora_keys.extend(text_encoder_2_keys)
       network_alphas_for_interceptor.update(text_encoder_2_alphas)
-    
+
     def _intercept(next_fn, args, kwargs, context):
       mod = context.module
       while mod is not None:
diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py
index 5ddac4e..029641e 100644
--- a/src/maxdiffusion/maxdiffusion_utils.py
+++ b/src/maxdiffusion/maxdiffusion_utils.py
@@ -49,12 +49,13 @@ def _noop_interceptor(next_fn, args, kwargs, context):
   if len(lora_config["lora_model_name_or_path"]) > 0:
     # For now only first lora supported. In the future, they will be merged
     # before being loaded.
+    # TODO - merge LoRAs here.
     params, rank, network_alphas = pipeline.load_lora_weights(
         lora_config["lora_model_name_or_path"][0],
         weight_name=lora_config["weight_name"][0],
         params=params,
         adapter_name=lora_config["adapter_name"][0],
-        unet_config=pipeline.unet.config
+        unet_config=pipeline.unet.config,
     )

     interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas)
diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
index 7e5e0ae..3fa5417 100644
--- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
+++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -139,7 +139,7 @@ def create_flax_params_from_pytorch_state(
     network_alpha_value = get_network_alpha_value(pt_key, network_alphas)

     # rename text encoders fc1 lora layers.
-    pt_key = pt_key.replace("lora_linear_layer","lora")
+    pt_key = pt_key.replace("lora_linear_layer", "lora")

     # only rename the unet keys, text encoders are already correct.
     if "unet" in pt_key:
@@ -157,20 +157,20 @@ def create_flax_params_from_pytorch_state(
     flax_key_list = [*pt_tuple_key]
     if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key:
       rename_from_to = (
-        ("to_k_lora", ("k_proj", "lora")),
-        ("to_q_lora", ("q_proj", "lora")),
-        ("to_v_lora", ("v_proj", "lora")),
-        ("to_out_lora", ("out_proj", "lora")),
-        ("weight", "kernel"),
+          ("to_k_lora", ("k_proj", "lora")),
+          ("to_q_lora", ("q_proj", "lora")),
+          ("to_v_lora", ("v_proj", "lora")),
+          ("to_out_lora", ("out_proj", "lora")),
+          ("weight", "kernel"),
       )
     # the unet
-    else:
+    else:
       rename_from_to = (
-        ("to_k_lora", ("to_k", "lora")),
-        ("to_q_lora", ("to_q", "lora")),
-        ("to_v_lora", ("to_v", "lora")),
-        ("to_out_lora", ("to_out_0", "lora")),
-        ("weight", "kernel"),
+          ("to_k_lora", ("to_k", "lora")),
+          ("to_q_lora", ("to_q", "lora")),
+          ("to_v_lora", ("to_v", "lora")),
+          ("to_out_lora", ("to_out_0", "lora")),
+          ("weight", "kernel"),
       )
     for rename_from, rename_to in rename_from_to:
       tmp = []