Skip to content

Commit

Permalink
don't rename key for text encoders so that pytree matches original.
Browse files Browse the repository at this point in the history
  • Loading branch information
entrpn committed Dec 10, 2024
1 parent f75de2f commit 718141f
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/maxdiffusion/models/modeling_flax_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ def create_flax_params_from_pytorch_state(
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
network_alpha_value = get_network_alpha_value(pt_key, network_alphas)
renamed_pt_key = rename_key(pt_key)

# only rename the unet keys, text encoders are already correct.
if "unet" in pt_key:
renamed_pt_key = rename_key(pt_key)

pt_tuple_key = tuple(renamed_pt_key.split("."))
# conv
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
Expand Down

0 comments on commit 718141f

Please sign in to comment.