forked from NVIDIA/Megatron-LM
Commit: Merge branch 'fy/hf2megatron' of ssh://git.sankuai.com/~fengyu05/megatron-deepspeed into fy/hf2megatron
Showing 4 changed files with 373 additions and 0 deletions.
@@ -0,0 +1,111 @@  (new file: utilities to merge sharded Meta-llama and Huggingface llama checkpoints into a single state dict)
import os
import re
from pathlib import Path
from typing import Optional
from collections import OrderedDict

import torch
from tqdm.auto import tqdm
from transformers import LlamaForCausalLM, AutoTokenizer


# hidden size (embedding dim) for each llama model scale
scale2emb = {
    '7B': 4096,
    '13B': 5120,
    '30B': 6656,
    '65B': 8192,
    '70B': 8192,
}


# dimension along which each parameter type is sharded across the
# consolidated.*.pth files (None means the weight is replicated, not sharded)
key_to_dim = {
    "w1": 0,
    "w2": -1,
    "w3": 0,
    "wo": -1,
    "wq": 0,
    "wk": 0,
    "wv": 0,
    "output": 0,
    "tok_embeddings": -1,
    "ffn_norm": None,
    "attention_norm": None,
    "norm": None,
    "rope": None,
}


def init_merged_ckpt(pth_00, num_pth=8, emb_dim=8192):
    # Allocate full-size tensors and copy in shard 0; the remaining shards
    # are filled in by merge_meta_llama.
    merged_ckpt = OrderedDict()
    for parameter_name, parameter in pth_00.items():
        short_name = parameter_name.split(".")[-2]
        if key_to_dim[short_name] is None:
            merged_ckpt[parameter_name] = parameter
            del parameter
        elif key_to_dim[short_name] == 0:
            size = parameter.shape[0]
            merged_param_shape = [parameter.shape[0] * num_pth, parameter.shape[1]]
            merged_ckpt[parameter_name] = torch.zeros(merged_param_shape)
            merged_ckpt[parameter_name][0:size, :] = parameter
            del parameter
        elif key_to_dim[short_name] == -1:
            size = parameter.shape[-1]
            merged_param_shape = [parameter.shape[0], parameter.shape[1] * num_pth]
            merged_ckpt[parameter_name] = torch.zeros(merged_param_shape)
            merged_ckpt[parameter_name][:, 0:size] = parameter
            del parameter
    return merged_ckpt


def merge_meta_llama(size: int, root_dir: Path):
    paths = sorted(path for path in root_dir.iterdir()
                   if re.match(r"^consolidated\.[0-9]+\.pth$", path.name))
    if len(paths) == 1:  # no sharded checkpoints, return everything
        return torch.load(paths[0], map_location=torch.device("cpu"))

    num_pth = len(paths)
    for i, ckpt_path in enumerate(tqdm(paths, desc="Merging llama")):
        llama_config = torch.load(ckpt_path, map_location=torch.device('cpu'))
        if i == 0:
            merged_ckpt = init_merged_ckpt(llama_config, num_pth=num_pth,
                                           emb_dim=scale2emb[f"{size}B"])
        else:
            for parameter_name, parameter in llama_config.items():
                short_name = parameter_name.split(".")[-2]
                if key_to_dim[short_name] == 0:
                    size = parameter.shape[0]
                    merged_param_shape = [parameter.shape[0] * num_pth, parameter.shape[1]]
                    merged_ckpt[parameter_name][size * i: size * (i + 1), :] = parameter
                    del parameter
                if key_to_dim[short_name] == -1:
                    size = parameter.shape[-1]
                    merged_param_shape = [parameter.shape[0], parameter.shape[1] * num_pth]
                    merged_ckpt[parameter_name][:, size * i: size * (i + 1)] = parameter
                    del parameter
        del llama_config
    return merged_ckpt


def merge_hf_llama(cache_dir: Optional[Path] = None):
    # assert version == 2, "Only llama v2 available using huggingface"
    model = LlamaForCausalLM.from_pretrained(cache_dir, cache_dir=cache_dir,
                                             local_files_only=True, use_safetensors=False)
    weights = model.state_dict()
    # rename Huggingface parameter names to the meta-llama naming scheme
    weights["tok_embeddings.weight"] = weights.pop("model.embed_tokens.weight")
    weights["norm.weight"] = weights.pop("model.norm.weight")
    weights["output.weight"] = weights.pop("lm_head.weight")
    for key in list(weights.keys()):
        if rmatch := re.match(r"^model\.(layers\.[0-9]+\.)(.+)(\.weight)$", key):
            new_key = {
                "self_attn.q_proj": "attention.wq",
                "self_attn.k_proj": "attention.wk",
                "self_attn.v_proj": "attention.wv",
                "self_attn.o_proj": "attention.wo",
                "mlp.gate_proj": "feed_forward.w1",
                "mlp.down_proj": "feed_forward.w2",
                "mlp.up_proj": "feed_forward.w3",
                "input_layernorm": "attention_norm",
                "post_attention_layernorm": "ffn_norm"
            }[rmatch.group(2)]
            weights[rmatch.group(1) + new_key + rmatch.group(3)] = weights.pop(key)
    return weights, model.config
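
A minimal usage sketch for these helpers (the paths below are hypothetical; the module name merge_llama is taken from the converter's import further down). merge_meta_llama expects the directory holding the consolidated.*.pth shards of a Meta-format checkpoint, while merge_hf_llama expects a local Huggingface model directory.

from pathlib import Path
from merge_llama import merge_meta_llama, merge_hf_llama

# Hypothetical local paths -- adjust to wherever the checkpoints actually live.
meta_dir = Path("/data/llama-2-70b")    # contains consolidated.00.pth ... consolidated.07.pth
hf_dir = Path("/data/llama-2-7b-hf")    # contains config.json + pytorch_model*.bin

# Merge the Meta-format shards of a 70B model into a single state dict.
meta_weights = merge_meta_llama(70, meta_dir)

# Load a Huggingface llama and rename its weights to the meta-llama scheme.
hf_weights, hf_config = merge_hf_llama(cache_dir=hf_dir)
print(hf_weights["tok_embeddings.weight"].shape, hf_config.num_hidden_layers)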
@@ -0,0 +1,81 @@  (new file: QKV permutation helper plus a script to re-permute existing Megatron checkpoints)
import re
import sys
import os
import shutil
from pathlib import Path
from argparse import ArgumentParser

import torch
from tqdm.auto import tqdm


def permute_qkv(qkv_w: torch.Tensor, dim: int, n_heads: int,
                n_heads_kv: int, revert: bool = False) -> torch.Tensor:
    # The fused QKV weight is split into groups of (n_qs_per_kv query heads,
    # one k head, one v head), each head spanning head_dim rows. permute()
    # reorders the rows within each q/k head (v is left unchanged);
    # revert=True applies the inverse reordering.

    def permute(x):
        if revert:
            return x.view(head_dim//2, 2, dim).transpose(0, 1).reshape(head_dim, dim)
        return x.view(2, head_dim//2, dim).transpose(0, 1).reshape(head_dim, dim)

    head_dim = dim//n_heads
    n_qs_per_kv = n_heads//n_heads_kv
    n_groups = qkv_w.size(0)//head_dim//(n_qs_per_kv + 2)
    groups = torch.chunk(qkv_w, n_groups, dim=0)
    new = []
    for group in groups:
        *qs, k, v = torch.split(group, head_dim, dim=0)
        assert len(qs) == n_qs_per_kv, f"{len(qs)}, {n_qs_per_kv}"
        new += list(map(permute, qs)) + [permute(k), v]
    return torch.cat(new, dim=0)


def update_checkpoint(input_dir: Path, output_dir: Path, overwrite_ok: bool = False):
    # make sure megatron is importable
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))

    # prepare output dir
    if output_dir.exists():
        if not overwrite_ok:
            raise FileExistsError(f"Output directory {output_dir} already exists")
        print(f"Removing {output_dir}")
        shutil.rmtree(output_dir)
    output_dir.mkdir(exist_ok=True)

    # determine release
    with open(input_dir/"latest_checkpointed_iteration.txt") as f:
        it = f.read()
    print("Updating weights of iteration", it)
    with open(output_dir/"latest_checkpointed_iteration.txt", "w+") as f:
        f.write(it)
    (output_dir/it).mkdir()

    # convert weights
    for fname in tqdm(list((input_dir/it).iterdir())):
        checkpoint = torch.load(fname/"model_optim_rng.pt")
        args = checkpoint["args"]
        args = (args.hidden_size, args.num_attention_heads,
                args.num_attention_heads_kv)
        if "transformer" in checkpoint["model"]["language_model"]:
            key = "transformer"
            attn_key = "attention"
        else:
            key = "encoder"
            attn_key = "self_attention"
        states = checkpoint["model"]["language_model"][key]
        for name, weight in states.items():
            if re.match(rf"^layers\.[0-9]+\.{attn_key}\.query_key_value\.weight$", name):
                states[name] = permute_qkv(weight, *args)
        (output_dir/it/fname.stem).mkdir()
        torch.save(checkpoint, output_dir/it/fname.stem/"model_optim_rng.pt")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-dir", type=Path)
    parser.add_argument("--output-dir", type=Path)
    parser.add_argument("--overwrite-ok", action="store_true")
    args = parser.parse_args()
    update_checkpoint(args.input_dir, args.output_dir, args.overwrite_ok)
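
A quick sanity check of permute_qkv on a toy configuration (sizes chosen purely for illustration; the module name permute_qkv comes from the converter's import below): applying the permutation and then calling it again with revert=True should reproduce the original fused QKV matrix exactly.

import torch
from permute_qkv import permute_qkv

# Toy sizes: hidden dim 8, 4 query heads, 2 KV heads -> head_dim = 2 and the
# fused QKV weight has (n_heads + 2 * n_heads_kv) * head_dim = 16 rows.
dim, n_heads, n_heads_kv = 8, 4, 2
head_dim = dim // n_heads
qkv = torch.randn((n_heads + 2 * n_heads_kv) * head_dim, dim)

permuted = permute_qkv(qkv, dim, n_heads, n_heads_kv)
restored = permute_qkv(permuted, dim, n_heads, n_heads_kv, revert=True)
assert torch.equal(restored, qkv)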
@@ -0,0 +1,160 @@  (new file: Huggingface llama to Megatron checkpoint converter)
import os
import sys
import shutil
from pathlib import Path
from typing import Optional
from argparse import ArgumentParser, Namespace

import torch
from tqdm.auto import trange
from transformers import AutoModelForCausalLM, LlamaTokenizer
from transformers import LlamaConfig

from permute_qkv import permute_qkv
from merge_llama import merge_hf_llama


def llama_to_megatron(weights: dict, llama_config: LlamaConfig = None) -> dict:
    def permute(qkv_w):
        return permute_qkv(qkv_w, hidden, n_heads, n_kv_heads)

    def rearrange_qkv(wq, wk, wv):
        wq = torch.split(wq, n_hidden_per_head, dim=0)
        wk = torch.split(wk, n_hidden_per_head, dim=0)
        wv = torch.split(wv, n_hidden_per_head, dim=0)
        assert len(wq) == n_heads
        assert len(wk) == n_kv_heads
        assert len(wv) == n_kv_heads
        n_qs_per_kv = n_heads//n_kv_heads
        w_qkv = []
        for i in range(n_kv_heads):
            w_qkv += [wq[i*n_qs_per_kv + j] for j in range(n_qs_per_kv)]
            w_qkv += [wk[i], wv[i]]
        return permute(torch.concat(w_qkv))

    # config
    n_layer = llama_config.num_hidden_layers
    hidden = llama_config.hidden_size
    n_heads = llama_config.num_attention_heads
    n_hidden_per_head = hidden//n_heads
    n_kv_heads = llama_config.num_key_value_heads
    # weights independent of layers
    embedding = {"word_embeddings": {"weight": weights["tok_embeddings.weight"]}}
    transformer = {"final_layernorm.weight": weights["norm.weight"]}
    lm_head = weights["output.weight"]
    # get all the other weights
    for layer in trange(n_layer, desc="Converting weights"):
        prefix = f"layers.{layer}"
        # identical weights
        transformer[f"{prefix}.attention.dense.weight"] = \
            weights[f"{prefix}.attention.wo.weight"]
        transformer[f"{prefix}.post_attention_layernorm.weight"] = \
            weights[f"{prefix}.ffn_norm.weight"]
        transformer[f"{prefix}.input_layernorm.weight"] = \
            weights[f"{prefix}.attention_norm.weight"]
        transformer[f"{prefix}.mlp.dense_4h_to_h.weight"] = \
            weights[f"{prefix}.feed_forward.w2.weight"]
        # concatenate up, gate mlp weights
        transformer[f"{prefix}.mlp.dense_h_to_4h.weight"] = torch.concat([
            weights[f"{prefix}.feed_forward.w3.weight"],
            weights[f"{prefix}.feed_forward.w1.weight"]
        ])
        # finally, qkv requires serious manipulation to get right
        transformer[f"{prefix}.attention.query_key_value.weight"] = rearrange_qkv(
            weights[f"{prefix}.attention.wq.weight"],
            weights[f"{prefix}.attention.wk.weight"],
            weights[f"{prefix}.attention.wv.weight"]
        )

        # release references to original weights (free mem)
        del weights[f"{prefix}.feed_forward.w3.weight"]
        del weights[f"{prefix}.feed_forward.w1.weight"]
        del weights[f"{prefix}.attention.wq.weight"]
        del weights[f"{prefix}.attention.wk.weight"]
        del weights[f"{prefix}.attention.wv.weight"]

    return {"embedding": embedding, "encoder": transformer,
            "lm_head": lm_head}


def main(out: Optional[Path] = None,
         cache_dir: Optional[Path] = None, megatron_path: Optional[Path] = None):

    if megatron_path:
        print("Adding megatron to sys.path")
        sys.path.append(str(megatron_path))
    # get weights from the specified directory
    print("Getting llama...")
    hf_weights, llama_config = merge_hf_llama(cache_dir)

    # convert state dict to be megatron-compatible
    megatron_weights = llama_to_megatron(hf_weights, llama_config=llama_config)

    # set args
    # llama1, llama2
    args = {"num_layers": llama_config.num_hidden_layers,
            "hidden_size": llama_config.hidden_size,
            "num_attention_heads": llama_config.num_attention_heads,
            "ffn_hidden_size": llama_config.intermediate_size,
            "num_key_value_heads": llama_config.num_key_value_heads,
            "parallel_attn": False,
            "make_vocab_size_divisible_by": 1,
            "glu_activation": "swiglu",
            # use max_length rather than max_position_embeddings, see
            # https://github.com/lm-sys/FastChat/issues/2046#issuecomment-1645265800
            "max_position_embeddings": llama_config.max_length,
            "seq_length": llama_config.max_length,
            "layernorm_epsilon": llama_config.rms_norm_eps,
            # llama args
            "padded_vocab_size": llama_config.vocab_size,
            "tokenizer_type": "GPTSentencePieceTokenizer",
            "no-query-key-layer-scaling": True,
            "attention-dropout": 0,
            "hidden-dropout": 0,
            "use-rotary-position-embeddings": True,
            "untie-embeddings-and-output-weights": True,
            "swiglu": True,
            "normalization": "rmsnorm",
            "disable-bias-linear": True,
            "add_position_embedding": False,
            "add_bias_linear": False,
            }
    if llama_config.num_key_value_heads:
        args.update({"num_attention_heads_kv": llama_config.num_key_value_heads})

    args.update({
        "tensor_model_parallel_size": 1,
        "pipeline_model_parallel_size": 1,
        "iteration": 0,
        "bias_gelu_fusion": False,
        "bias_dropout_fusion": False,
    })

    # save converted weights in the specified output directory
    (out/"release"/"mp_rank_00").mkdir(parents=True)
    with open(out/"latest_checkpointed_iteration.txt", "w+") as f:
        f.write("release")
    final_dict = {"iteration": 'release', "model": {"language_model": megatron_weights},
                  "checkpoint_version": 3.0, "args": Namespace(**args)}
    torch.save(final_dict, out/"release"/"mp_rank_00"/"model_optim_rng.pt")
    print("Saved weights in", out)

    tokenizer = LlamaTokenizer.from_pretrained(
        cache_dir, cache_dir=cache_dir, local_files_only=True,
    )
    token_path = out/"tokenizer.model"
    vocab_file = tokenizer.vocab_file
    shutil.copy(vocab_file, token_path)
    print("Saved tokenizer.model in", token_path)
    print("Done")


if __name__ == "__main__":
    parser = ArgumentParser(description="Convert Huggingface llama weights to "
                                        "megatron-compatible weights")
    parser.add_argument("--out", type=Path,
                        help="Directory to store the megatron weights (as checkpoint)")
    parser.add_argument("--cache-dir", type=Path,
                        help=("Directory to store the huggingface weights, or "
                              "in case of the llama model, where to look for "
                              "the consolidated.xx.pth"))
    parser.add_argument("--megatron-path", type=Path, default=None,
                        help="Path where to find megatron code")
    args = parser.parse_args()

    main(args.out, args.cache_dir, args.megatron_path)
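
For reference, a sketch of driving the converter programmatically, equivalent to passing --out, --cache-dir and --megatron-path on the command line. All paths below are hypothetical, and the Huggingface weights must already be available locally, since merge_hf_llama loads them with local_files_only=True.

from pathlib import Path

# Hypothetical locations -- adjust as needed; main() is the entry point defined above.
out_dir = Path("/data/megatron/llama-2-7b")      # receives release/mp_rank_00/model_optim_rng.pt and tokenizer.model
hf_cache = Path("/data/llama-2-7b-hf")           # local Huggingface llama weights and tokenizer
megatron_repo = Path("/opt/megatron-deepspeed")  # only needed if megatron is not already importable

main(out_dir, hf_cache, megatron_repo)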