
Commit

formatting
entrpn committed Dec 11, 2024
1 parent 115467f commit 83e1612
Showing 3 changed files with 16 additions and 15 deletions.

src/maxdiffusion/loaders/lora_pipeline.py (4 changes: 2 additions & 2 deletions)

@@ -135,7 +135,7 @@ def rename_for_interceptor(params_keys, network_alphas):
   def make_lora_interceptor(cls, params, rank, network_alphas):
     # Only unet interceptor supported for now.
     network_alphas_for_interceptor = {}

     unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()
     lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
     network_alphas_for_interceptor.update(unet_alphas)

@@ -150,7 +150,7 @@ def make_lora_interceptor(cls, params, rank, network_alphas):
     text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas)
     lora_keys.extend(text_encoder_2_keys)
     network_alphas_for_interceptor.update(text_encoder_2_alphas)

     def _intercept(next_fn, args, kwargs, context):
       mod = context.module
       while mod is not None:
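
For context, rename_for_interceptor operates on the flattened parameter keys produced by flax.traverse_util.flatten_dict, which turns a nested params dict into tuple-valued keys. The snippet below is a minimal, standalone illustration of that flattening; the layer names are invented for illustration and are not taken from this commit.

import flax

# Toy nested params dict with made-up layer names.
params = {"unet": {"down_blocks_0": {"attentions_0": {"to_q": {"lora": {"up": {"kernel": [[0.0]]}}}}}}}

# flatten_dict keys are tuples of path components, which is the shape of the
# keys that rename_for_interceptor receives above.
for key in flax.traverse_util.flatten_dict(params["unet"]).keys():
  print(key)  # ('down_blocks_0', 'attentions_0', 'to_q', 'lora', 'up', 'kernel')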

src/maxdiffusion/maxdiffusion_utils.py (3 changes: 2 additions & 1 deletion)

@@ -49,12 +49,13 @@ def _noop_interceptor(next_fn, args, kwargs, context):
   if len(lora_config["lora_model_name_or_path"]) > 0:
     # For now only first lora supported. In the future, they will be merged
     # before being loaded.
+    # TODO - merge LoRAs here.
     params, rank, network_alphas = pipeline.load_lora_weights(
         lora_config["lora_model_name_or_path"][0],
         weight_name=lora_config["weight_name"][0],
         params=params,
         adapter_name=lora_config["adapter_name"][0],
-        unet_config=pipeline.unet.config
+        unet_config=pipeline.unet.config,
     )
     interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas)

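The interceptor returned by make_lora_interceptor has the standard Flax interceptor signature (next_fn, args, kwargs, context), the same one _noop_interceptor uses above. Below is a minimal sketch of how such an interceptor can be activated, assuming a Flax version that provides nn.intercept_methods; the model and interceptor here are placeholders, not maxdiffusion's actual call site.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(4, name="proj")(x)

def logging_interceptor(next_fn, args, kwargs, context):
  # Same signature as _noop_interceptor: inspect the module being called,
  # then delegate to the original method.
  print("intercepted:", type(context.module).__name__, context.method_name)
  return next_fn(*args, **kwargs)

model = TinyModel()
x = jnp.ones((1, 8))
variables = model.init(jax.random.PRNGKey(0), x)

# Every module method invoked inside this context is routed through the
# interceptor, which is how a LoRA interceptor can rewrite matmul outputs.
with nn.intercept_methods(logging_interceptor):
  y = model.apply(variables, x)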

src/maxdiffusion/models/modeling_flax_pytorch_utils.py (24 changes: 12 additions & 12 deletions)

@@ -139,7 +139,7 @@ def create_flax_params_from_pytorch_state(
     network_alpha_value = get_network_alpha_value(pt_key, network_alphas)

     # rename text encoders fc1 lora layers.
-    pt_key = pt_key.replace("lora_linear_layer","lora")
+    pt_key = pt_key.replace("lora_linear_layer", "lora")

     # only rename the unet keys, text encoders are already correct.
     if "unet" in pt_key:

@@ -157,20 +157,20 @@
     flax_key_list = [*pt_tuple_key]
     if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key:
       rename_from_to = (
-          ("to_k_lora", ("k_proj", "lora")),
-          ("to_q_lora", ("q_proj", "lora")),
-          ("to_v_lora", ("v_proj", "lora")),
-          ("to_out_lora", ("out_proj", "lora")),
-          ("weight", "kernel"),
+          ("to_k_lora", ("k_proj", "lora")),
+          ("to_q_lora", ("q_proj", "lora")),
+          ("to_v_lora", ("v_proj", "lora")),
+          ("to_out_lora", ("out_proj", "lora")),
+          ("weight", "kernel"),
       )
     # the unet
-    else:
+    else:
       rename_from_to = (
-          ("to_k_lora", ("to_k", "lora")),
-          ("to_q_lora", ("to_q", "lora")),
-          ("to_v_lora", ("to_v", "lora")),
-          ("to_out_lora", ("to_out_0", "lora")),
-          ("weight", "kernel"),
+          ("to_k_lora", ("to_k", "lora")),
+          ("to_q_lora", ("to_q", "lora")),
+          ("to_v_lora", ("to_v", "lora")),
+          ("to_out_lora", ("to_out_0", "lora")),
+          ("weight", "kernel"),
       )
     for rename_from, rename_to in rename_from_to:
       tmp = []
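
The rename_from_to tables above map individual components of a PyTorch LoRA key onto their Flax counterparts: attention projection names expand into a two-part path ending in "lora", and "weight" becomes "kernel". Below is a simplified, standalone sketch of that substitution; rename_key and the example key are assumptions for illustration, not this file's exact implementation.

# Simplified sketch of the key substitution (unet variant of the table).
rename_from_to = (
    ("to_k_lora", ("to_k", "lora")),
    ("to_q_lora", ("to_q", "lora")),
    ("to_v_lora", ("to_v", "lora")),
    ("to_out_lora", ("to_out_0", "lora")),
    ("weight", "kernel"),
)

def rename_key(pt_tuple_key, rename_from_to):
  flax_key = []
  for part in pt_tuple_key:
    for rename_from, rename_to in rename_from_to:
      if part == rename_from:
        # A renamed component may expand into several path components.
        flax_key.extend(rename_to if isinstance(rename_to, tuple) else (rename_to,))
        break
    else:
      flax_key.append(part)
  return tuple(flax_key)

print(rename_key(("unet", "to_q_lora", "up", "weight"), rename_from_to))
# ('unet', 'to_q', 'lora', 'up', 'kernel')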
