Support other lora formats #136

Merged: 4 commits, Dec 11, 2024
15 changes: 8 additions & 7 deletions src/maxdiffusion/generate_sdxl.py
@@ -225,6 +225,7 @@ def run(config):
pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False
)

# Load the UNet params from the Orbax checkpoint.
unet_params = load_params_from_path(
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "unet_state"
)
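For reference, load_params_from_path is the project's own checkpoint helper. A bare Orbax restore of a parameter PyTree looks roughly like the sketch below; the checkpoint path is hypothetical and the exact handler maxdiffusion uses may differ.

import orbax.checkpoint as ocp

# Restore a parameter PyTree from a checkpoint directory (hypothetical path).
checkpointer = ocp.PyTreeCheckpointer()
unet_params = checkpointer.restore("/tmp/checkpoints/unet_state")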
@@ -253,14 +254,14 @@ def run(config):
vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
pipeline, params, checkpoint_item_name="vae_state", is_training=False
)
text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
)

text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state(
pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False
)
with nn.intercept_methods(lora_interceptor):
text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
)

text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state(
pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False
)
states = {}
state_shardings = {}

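The change above wraps the text-encoder state creation in nn.intercept_methods so the LoRA interceptor also sees text-encoder module calls, not just the UNet. As a standalone illustration (not part of the PR, with a made-up module and interceptor), Flax's interception hook works roughly like this:

import jax
import jax.numpy as jnp
import flax.linen as nn

def print_interceptor(next_fn, args, kwargs, context):
    # Run the wrapped method first, then inspect the call site.
    out = next_fn(*args, **kwargs)
    if context.method_name == "__call__":
        print("intercepted module at path:", context.module.path)
    return out

class TinyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=4)(x)

model = TinyModel()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x)

# Every module method invoked under this context manager passes through the interceptor.
with nn.intercept_methods(print_interceptor):
    y = model.apply(variables, x)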
26 changes: 20 additions & 6 deletions src/maxdiffusion/loaders/lora_pipeline.py
@@ -121,21 +121,35 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):

def rename_for_interceptor(params_keys, network_alphas):
new_params_keys = []
new_network_alphas = {}
for layer_lora in params_keys:
if "lora" in layer_lora:
new_layer_lora = layer_lora[: layer_lora.index("lora")]
if new_layer_lora not in new_params_keys:
new_params_keys.append(new_layer_lora)
network_alpha = network_alphas[layer_lora]
del network_alphas[layer_lora]
network_alphas[new_layer_lora] = network_alpha
return new_params_keys, network_alphas
new_network_alphas[new_layer_lora] = network_alpha
return new_params_keys, new_network_alphas

@classmethod
def make_lora_interceptor(cls, params, rank, network_alphas):
# Only unet interceptor supported for now.
network_alphas_for_interceptor = {}

unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()
unet_lora_keys, network_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
network_alphas_for_interceptor.update(unet_alphas)

text_encoder_keys = flax.traverse_util.flatten_dict(params["text_encoder"]).keys()
text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas)
lora_keys.extend(text_encoder_keys)
network_alphas_for_interceptor.update(text_encoder_alphas)

if "text_encoder_2" in params.keys():
text_encoder_2_keys = flax.traverse_util.flatten_dict(params["text_encoder_2"]).keys()
text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas)
lora_keys.extend(text_encoder_2_keys)
network_alphas_for_interceptor.update(text_encoder_2_alphas)

def _intercept(next_fn, args, kwargs, context):
mod = context.module
@@ -146,8 +160,8 @@ def _intercept(next_fn, args, kwargs, context):
h = next_fn(*args, **kwargs)
if context.method_name == "__call__":
module_path = context.module.path
if module_path in unet_lora_keys:
lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas)
if module_path in lora_keys:
lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor)
return lora_layer(h, *args, **kwargs)
return h

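For context on the intercepted call: a LoRA wrapper such as the one produced by _get_lora_layer typically adds a scaled low-rank update on top of the base layer's output. A minimal sketch with hypothetical parameter names, using the common alpha / rank scaling (not the PR's exact implementation):

import jax.numpy as jnp

def apply_lora(base_out, x, lora_down, lora_up, rank, network_alpha=None):
    # base_out: output of the original projection for input x.
    # lora_down: (in_features, rank); lora_up: (rank, out_features). Hypothetical shapes.
    scale = (network_alpha / rank) if network_alpha is not None else 1.0
    return base_out + scale * (x @ lora_down @ lora_up)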
2 changes: 2 additions & 0 deletions src/maxdiffusion/maxdiffusion_utils.py
@@ -49,11 +49,13 @@ def _noop_interceptor(next_fn, args, kwargs, context):
if len(lora_config["lora_model_name_or_path"]) > 0:
# For now, only the first LoRA is supported. In the future, multiple LoRAs will be merged
# before being loaded.
# TODO: merge LoRAs here.
params, rank, network_alphas = pipeline.load_lora_weights(
lora_config["lora_model_name_or_path"][0],
weight_name=lora_config["weight_name"][0],
params=params,
adapter_name=lora_config["adapter_name"][0],
unet_config=pipeline.unet.config,
)
interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas)

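From the indexing above, the LoRA section of the config is assumed to hold parallel lists, of which only the first entry is used until merging is implemented. A hypothetical example (all values made up):

lora_config = {
    "lora_model_name_or_path": ["some-org/some-sdxl-lora"],  # hypothetical repo id or local path
    "weight_name": ["pytorch_lora_weights.safetensors"],  # hypothetical weight file
    "adapter_name": ["adapter_0"],  # hypothetical adapter label
}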
35 changes: 27 additions & 8 deletions src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -137,7 +137,15 @@ def create_flax_params_from_pytorch_state(
# Need to change some parameter names to match the Flax names
for pt_key, pt_tensor in pt_state_dict.items():
network_alpha_value = get_network_alpha_value(pt_key, network_alphas)
renamed_pt_key = rename_key(pt_key)

# Rename the text encoders' fc1 LoRA layers.
pt_key = pt_key.replace("lora_linear_layer", "lora")

# Only rename the unet keys; the text encoder keys are already correct.
if "unet" in pt_key:
renamed_pt_key = rename_key(pt_key)
else:
renamed_pt_key = pt_key
pt_tuple_key = tuple(renamed_pt_key.split("."))
# conv
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
@@ -147,13 +155,24 @@
flax_tensor = pt_tensor
else:
flax_key_list = [*pt_tuple_key]
for rename_from, rename_to in (
("to_k_lora", ("to_k", "lora")),
("to_q_lora", ("to_q", "lora")),
("to_v_lora", ("to_v", "lora")),
("to_out_lora", ("to_out_0", "lora")),
("weight", "kernel"),
):
if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key:
rename_from_to = (
("to_k_lora", ("k_proj", "lora")),
("to_q_lora", ("q_proj", "lora")),
("to_v_lora", ("v_proj", "lora")),
("to_out_lora", ("out_proj", "lora")),
("weight", "kernel"),
)
# the unet
else:
rename_from_to = (
("to_k_lora", ("to_k", "lora")),
("to_q_lora", ("to_q", "lora")),
("to_v_lora", ("to_v", "lora")),
("to_out_lora", ("to_out_0", "lora")),
("weight", "kernel"),
)
for rename_from, rename_to in rename_from_to:
tmp = []
for s in flax_key_list:
if s == rename_from:
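To make the renaming concrete, the helper below mirrors the replacement loop in spirit (it is illustrative only, and the sample key is hypothetical rather than taken from a real checkpoint):

def rename_tuple(key_tuple, rename_from_to):
    # Replace each matching segment, expanding tuple replacements in place.
    for rename_from, rename_to in rename_from_to:
        tmp = []
        for s in key_tuple:
            if s == rename_from:
                if isinstance(rename_to, tuple):
                    tmp.extend(rename_to)
                else:
                    tmp.append(rename_to)
            else:
                tmp.append(s)
        key_tuple = tuple(tmp)
    return key_tuple

# Hypothetical text-encoder key, already split on ".":
pt_tuple_key = ("text_encoder", "text_model", "encoder", "layers", "0",
                "self_attn", "to_q_lora", "up", "weight")
text_encoder_renames = (
    ("to_k_lora", ("k_proj", "lora")),
    ("to_q_lora", ("q_proj", "lora")),
    ("to_v_lora", ("v_proj", "lora")),
    ("to_out_lora", ("out_proj", "lora")),
    ("weight", "kernel"),
)
print(rename_tuple(pt_tuple_key, text_encoder_renames))
# -> ('text_encoder', 'text_model', 'encoder', 'layers', '0',
#     'self_attn', 'q_proj', 'lora', 'up', 'kernel')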