diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
index 5f02ec8..9da8646 100644
--- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
+++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -137,7 +137,15 @@ def create_flax_params_from_pytorch_state(
   # Need to change some parameters name to match Flax names
   for pt_key, pt_tensor in pt_state_dict.items():
     network_alpha_value = get_network_alpha_value(pt_key, network_alphas)
-    renamed_pt_key = rename_key(pt_key)
+
+    # Only rename the unet keys; text encoder keys are already correct and
+    # are used unchanged. The else branch keeps renamed_pt_key bound on every
+    # iteration instead of unbound or stale from a previous unet key.
+    if "unet" in pt_key:
+      renamed_pt_key = rename_key(pt_key)
+    else:
+      renamed_pt_key = pt_key
+
     pt_tuple_key = tuple(renamed_pt_key.split("."))
     # conv
     if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: