diff --git a/convert_hf_model.py b/convert_hf_model.py
index e45240e..12fc0d8 100644
--- a/convert_hf_model.py
+++ b/convert_hf_model.py
@@ -88,6 +88,25 @@ def unpermute(tensor):
             .reshape(embedding_size, embedding_size)
         )
 
+    def unpermute_attention_key_matrices(tensor):
+        if n_attention_heads == n_attention_query_groups:
+            return unpermute(tensor)
+        else:
+            key_value_size = (
+                embedding_size // n_attention_heads * n_attention_query_groups
+            )
+
+            return (
+                tensor.view(
+                    n_attention_query_groups,
+                    2,
+                    key_value_size // n_attention_query_groups // 2,
+                    embedding_size,
+                )
+                .transpose(1, 2)
+                .reshape(key_value_size, embedding_size)
+            )
+
     # attention_query_matrices
     for layer in range(n_layers):
         serialize_f32(
@@ -99,7 +118,9 @@ def unpermute(tensor):
     for layer in range(n_layers):
         serialize_f32(
            output_file,
-            unpermute(hf_state_dict[f"model.layers.{layer}.self_attn.k_proj.weight"]),
+            unpermute_attention_key_matrices(
+                hf_state_dict[f"model.layers.{layer}.self_attn.k_proj.weight"]
+            ),
         )
 
     # attention_value_matrices
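The reason the key matrices need their own unpermute is that, for models with grouped-query attention, k_proj.weight has shape (key_value_size, embedding_size) rather than (embedding_size, embedding_size), so the existing unpermute's view would no longer fit. Below is a minimal, self-contained sketch (not part of the diff) that exercises the reshape/transpose logic on made-up dimensions; the "repermuted" step assumes the per-head interleaving permutation that the Hugging Face Llama conversion script is understood to apply, and round-trips back to the original tensor.

# Sketch only: hypothetical dimensions, not taken from convert_hf_model.py.
import torch

embedding_size = 4096          # hypothetical hidden size
n_attention_heads = 32         # hypothetical number of query heads
n_attention_query_groups = 8   # hypothetical number of key/value heads (GQA)

head_size = embedding_size // n_attention_heads
key_value_size = head_size * n_attention_query_groups  # 1024 in this example

k_proj = torch.randn(key_value_size, embedding_size)

# Same operation as unpermute_attention_key_matrices in the diff:
# split each key/value head into its two rotary halves, then de-interleave.
unpermuted = (
    k_proj.view(
        n_attention_query_groups,
        2,
        key_value_size // n_attention_query_groups // 2,
        embedding_size,
    )
    .transpose(1, 2)
    .reshape(key_value_size, embedding_size)
)

# Re-applying the assumed Hugging Face permutation (interleave the two
# rotary halves within each key/value head) should reproduce k_proj exactly.
repermuted = (
    unpermuted.view(
        n_attention_query_groups,
        key_value_size // n_attention_query_groups // 2,
        2,
        embedding_size,
    )
    .transpose(1, 2)
    .reshape(key_value_size, embedding_size)
)

assert torch.equal(repermuted, k_proj)

When n_attention_heads equals n_attention_query_groups (plain multi-head attention), key_value_size equals embedding_size, so the new function simply falls back to the existing unpermute, as the first branch in the diff shows.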