integ.py (forked from facebookresearch/fairseq2)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict

from fairseq2.models.utils.checkpoint import convert_model_state_dict


def convert_to_reference_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a fairseq2 LLaMA checkpoint to the reference format."""
    # Some checkpoints record the key under which the state dict is stored;
    # fall back to the default "model" key if none is present.
    try:
        model_key = checkpoint["model_key"]
    except KeyError:
        model_key = "model"

    state_dict = checkpoint[model_key]

    # Regex map from fairseq2 parameter names to the reference LLaMA names,
    # e.g. "decoder.layers.0.self_attn.q_proj.weight" becomes
    # "layers.0.attention.wq.weight".
    key_map = {
        # fmt: off
        r"^decoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":      r"layers.\1.attention.wq.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":      r"layers.\1.attention.wk.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":      r"layers.\1.attention.wv.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.output_proj\.": r"layers.\1.attention.wo.",
        r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":   r"layers.\1.attention_norm.",
        r"^decoder\.layers\.([0-9]+)\.ffn\.gate_proj\.":         r"layers.\1.feed_forward.w1.",
        r"^decoder\.layers\.([0-9]+)\.ffn\.output_proj\.":       r"layers.\1.feed_forward.w2.",
        r"^decoder\.layers\.([0-9]+)\.ffn\.inner_proj\.":        r"layers.\1.feed_forward.w3.",
        r"^decoder\.layers\.([0-9]+)\.ffn_layer_norm\.":         r"layers.\1.ffn_norm.",
        r"^decoder\.layer_norm\.":                               r"norm.",
        r"^decoder_frontend\.embed\.":                           r"tok_embeddings.",
        r"^final_proj\.":                                        r"output.",
        # fmt: on
    }

    return convert_model_state_dict(state_dict, key_map)
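

# Minimal usage sketch, not part of the original module: load a fairseq2
# LLaMA training checkpoint, convert its state dict, and save it under the
# reference naming convention. The file paths and the `__main__` guard are
# hypothetical; only `convert_to_reference_checkpoint` comes from this file.
if __name__ == "__main__":
    import torch

    # Hypothetical input path to a fairseq2 checkpoint file.
    checkpoint = torch.load("checkpoint.pt", map_location="cpu")

    reference_state_dict = convert_to_reference_checkpoint(checkpoint)

    # Hypothetical output path in the reference (consolidated) format.
    torch.save(reference_state_dict, "consolidated.00.pth")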