modify checkpoint conversion tool to accommodate LLaMA2 70B model.
PiperOrigin-RevId: 570536326
Change-Id: I0c42a29725a9a37936e6aaac98a59181ca4a2110
Sax Authors authored and copybara-github committed Oct 3, 2023
1 parent 88bdf18 commit b0bb24b
Showing 1 changed file with 78 additions and 32 deletions.
110 changes: 78 additions & 32 deletions saxml/tools/convert_llama_ckpt.py
@@ -15,10 +15,15 @@
Usage:
# Get LLaMA pytorch_vars from Meta
# Example cmd:
-python3 -m convert_llama_ckpt --base llama_7b --pax pax_7b
+python3 -m convert_llama_ckpt --base llama_7b --pax pax_7b --model-size 7b
+# For a large model (e.g. the 70B model), this script requires a large-memory VM.
+# The script loads and saves the weights in a single pass.
+# To fit in less memory, modify convert() to load/save weights in multiple passes.
+# In each pass, load and save partial weights (a subset of all weight variables).
"""
# pylint: disable=g-line-too-long
import argparse
@@ -34,15 +39,39 @@

import torch

-num_layers = 32
-num_heads = 32
-dims_per_head = 128
-vocab = 32000
-num_gpus = 1


-def convert(base_model_path, pax_model_path):
+MODEL_PARAMS_DICT = {
+    '70b': {
+        'num_layers': 80,
+        'num_heads': 64,
+        'num_kv_heads': 8,
+        'dims_per_head': 128,
+        'vocab': 32000,
+        'num_gpus': 1,
+        'combined_qkv': False,
+    },
+    '7b': {
+        'num_layers': 32,
+        'num_heads': 32,
+        'num_kv_heads': 32,
+        'dims_per_head': 128,
+        'vocab': 32000,
+        'num_gpus': 1,
+        'combined_qkv': True,
+    },
+}


+def convert(base_model_path, pax_model_path, model_size):
"""Convert from vicuna to pax."""
+model_params = MODEL_PARAMS_DICT[model_size]
+num_layers = model_params['num_layers']
+num_heads = model_params['num_heads']
+dims_per_head = model_params['dims_per_head']
+num_kv_heads = model_params['num_kv_heads']
+vocab = model_params['vocab']
+combined_qkv = model_params['combined_qkv']
+num_gpus = model_params['num_gpus']

print(f'Loading the base model from {base_model_path}')
ckpt_paths = sorted(pathlib.Path(base_model_path).glob('*.pth'))
pytorch_vars = {}
@@ -55,12 +84,12 @@ def convert(base_model_path, pax_model_path):
jax_weights = {
'lm': {
'embedding_lookup': {
-'emb_var': np.concatenate([var['tok_embeddings.weight'].numpy() for var in pytorch_vars], axis=1)[:vocab,:]
+'emb_var': np.concatenate([var['tok_embeddings.weight'].type(torch.float16).numpy() for var in pytorch_vars], axis=1)[:vocab,:]
},
'softmax': {
'logits_ffn': {
'linear': {
-'w': np.concatenate([var['output.weight'].numpy() for var in pytorch_vars], axis=0).transpose()[:, :vocab]
+'w': np.concatenate([var['output.weight'].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()[:, :vocab]
}
}
},
@@ -70,54 +99,68 @@ def convert(base_model_path, pax_model_path):
'transformer': {}
}
}

for layer_idx in range(num_layers):
-wq = np.concatenate([var['layers.%d.attention.wq.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=0).transpose()
-wk = np.concatenate([var['layers.%d.attention.wk.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=0).transpose()
-wv = np.concatenate([var['layers.%d.attention.wv.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=0).transpose()
-wc = np.stack([wq, wk, wv], axis=0)
-wc = np.reshape(wc, [3, num_heads * dims_per_head, num_heads, dims_per_head])
+wq = np.concatenate([var['layers.%d.attention.wq.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()
+wk = np.concatenate([var['layers.%d.attention.wk.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()
+wv = np.concatenate([var['layers.%d.attention.wv.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()
+if combined_qkv:
+  wc = np.stack([wq, wk, wv], axis=0)
+  wc = np.reshape(wc, [3, num_heads * dims_per_head, num_heads, dims_per_head])
+else:
+  wq = np.reshape(wq, [num_heads * dims_per_head, num_heads, dims_per_head])
+  wk = np.reshape(wk, [num_heads * dims_per_head, num_kv_heads, dims_per_head])
+  wv = np.reshape(wv, [num_heads * dims_per_head, num_kv_heads, dims_per_head])

w_post = np.concatenate(
[
-var['layers.%d.attention.wo.weight' % (layer_idx)].numpy()
+var['layers.%d.attention.wo.weight' % (layer_idx)].type(torch.float16).numpy()
for var in pytorch_vars
],
axis=1,
)
w_post = np.reshape(w_post, [num_heads * dims_per_head, num_heads, dims_per_head])

+if combined_qkv:
+  attention_weights = {
+      'self_attention': {'combined_qkv': {'w': wc}, 'post': {'w': w_post}}
+  }
+else:
+  attention_weights = {
+      'self_attention': {
+          'query': {'w': wq},
+          'key': {'w': wk},
+          'value': {'w': wv},
+          'post': {'w': w_post},
+      },
+  }

layer_weight = {
-'self_attention': {
-    'combined_qkv': {
-        'w': wc
-    },
-    'post': {
-        'w': w_post
-    }
-},
'ff_layer': {
'ffn_layer1_gate': {
'linear': {
-'w': np.concatenate([var['layers.%d.feed_forward.w1.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=0).transpose()
+'w': np.concatenate([var['layers.%d.feed_forward.w1.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()
}
},
'ffn_layer1': {
'linear': {
-'w': np.concatenate([var['layers.%d.feed_forward.w3.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=0).transpose()
+'w': np.concatenate([var['layers.%d.feed_forward.w3.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()
}
},
'ffn_layer2': {
'linear': {
-'w': np.concatenate([var['layers.%d.feed_forward.w2.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=1).transpose()
+'w': np.concatenate([var['layers.%d.feed_forward.w2.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=1).transpose()
}
},
'layer_norm': {
-'scale': pytorch_vars[0]['layers.%d.ffn_norm.weight' % (layer_idx)].numpy()
+'scale': pytorch_vars[0]['layers.%d.ffn_norm.weight' % (layer_idx)].type(torch.float16).numpy()
}
},
'layer_norm': {
-'scale': pytorch_vars[0]['layers.%d.attention_norm.weight' % (layer_idx)].numpy()
+'scale': pytorch_vars[0]['layers.%d.attention_norm.weight' % (layer_idx)].type(torch.float16).numpy()
}
}
+layer_weight.update(attention_weights)
jax_weights['lm']['transformer']['x_layers_%d' % layer_idx] = layer_weight

print(f'Saving the pax model to {pax_model_path}')
@@ -150,6 +193,9 @@ def identity(x):
parser = argparse.ArgumentParser()
parser.add_argument('--base-model-path', type=str, required=True)
parser.add_argument('--pax-model-path', type=str, required=True)
+parser.add_argument('--model-size', type=str, required=True)
args = parser.parse_args()

-convert(args.base_model_path, args.pax_model_path)
+if args.model_size not in MODEL_PARAMS_DICT:
+  raise NotImplementedError
+convert(args.base_model_path, args.pax_model_path, args.model_size)
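A note on the new combined_qkv flag: the 70b entry sets it to False because LLaMA2 70B uses grouped-query attention, where the key and value projections have only num_kv_heads = 8 heads and so cannot be stacked with the 64-head query projection into one combined QKV tensor. The following is a purely illustrative shape check using the dimensions assumed from MODEL_PARAMS_DICT['70b']; it is not part of this commit.

# Illustrative shape check for the grouped-query attention (70B) branch;
# dimensions taken from MODEL_PARAMS_DICT['70b'], code is a sketch only.
import numpy as np

num_heads, num_kv_heads, dims_per_head = 64, 8, 128
model_dim = num_heads * dims_per_head  # 8192

# After concatenating the checkpoint shards and transposing, the projection
# matrices have these shapes:
wq = np.zeros((model_dim, num_heads * dims_per_head), dtype=np.float16)
wk = np.zeros((model_dim, num_kv_heads * dims_per_head), dtype=np.float16)
wv = np.zeros((model_dim, num_kv_heads * dims_per_head), dtype=np.float16)

# The query keeps 64 heads while key/value reshape to only 8 KV heads,
# which is why the else branch reshapes them separately.
wq = np.reshape(wq, [model_dim, num_heads, dims_per_head])     # (8192, 64, 128)
wk = np.reshape(wk, [model_dim, num_kv_heads, dims_per_head])  # (8192, 8, 128)
wv = np.reshape(wv, [model_dim, num_kv_heads, dims_per_head])  # (8192, 8, 128)
print(wq.shape, wk.shape, wv.shape)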
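The updated docstring suggests converting in multiple passes when host memory is limited. Below is a minimal sketch of that idea; the helper name convert_in_passes, the layers_per_pass grouping, and the filtering by layer-name prefix are assumptions for illustration, not part of this commit.

# Hypothetical sketch of the multi-pass loading suggested in the docstring.
import pathlib

import torch


def convert_in_passes(base_model_path, num_layers=80, layers_per_pass=10):
  """Loads and converts only a slice of the layers per pass to cap host memory."""
  ckpt_paths = sorted(pathlib.Path(base_model_path).glob('*.pth'))
  for start in range(0, num_layers, layers_per_pass):
    prefixes = tuple(
        'layers.%d.' % i
        for i in range(start, min(start + layers_per_pass, num_layers))
    )
    partial_vars = []
    for path in ckpt_paths:
      shard = torch.load(path, map_location='cpu')
      # Keep only the tensors for this pass's layers; drop the rest right away.
      partial_vars.append(
          {k: v for k, v in shard.items() if k.startswith(prefixes)}
      )
      del shard
    # ...build and save the pax weights for layers [start, start + layers_per_pass)...
    # Non-layer weights (tok_embeddings, output, final norm) would get their own pass.

Each pass then keeps only a fraction of the checkpoint tensors resident, at the cost of re-reading every shard once per pass.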
