From b0bb24bff4b4fbddc7a2e28d0315e58fe18c8d74 Mon Sep 17 00:00:00 2001
From: Sax Authors
Date: Tue, 3 Oct 2023 16:56:45 -0700
Subject: [PATCH] modify checkpoint conversion tool to accommodate LLaMA2 70B
 model.

PiperOrigin-RevId: 570536326
Change-Id: I0c42a29725a9a37936e6aaac98a59181ca4a2110
---
 saxml/tools/convert_llama_ckpt.py | 110 +++++++++++++++++++++---------
 1 file changed, 78 insertions(+), 32 deletions(-)

diff --git a/saxml/tools/convert_llama_ckpt.py b/saxml/tools/convert_llama_ckpt.py
index f25bf878..27760753 100644
--- a/saxml/tools/convert_llama_ckpt.py
+++ b/saxml/tools/convert_llama_ckpt.py
@@ -15,10 +15,15 @@
 
 Usage:
 
-# Get LLaMA pytorch_vars from Meta
+# Get LLaMA pytorch_vars from Meta
 
 # Example cmd:
-python3 -m convert_llama_ckpt --base llama_7b --pax pax_7b
+python3 -m convert_llama_ckpt --base llama_7b --pax pax_7b --model-size 7b
+
+# For a large model (e.g. the 70B model), this script requires a VM with a large amount of memory.
+# The script loads and saves all weights in a single pass.
+# To fit in less memory, modify convert() to load/save weights in multiple passes.
+# In each pass, load and save only a subset of the weight variables.
 """
 # pylint: disable=g-line-too-long
 import argparse
@@ -34,15 +39,39 @@
 import torch
 
 
-num_layers = 32
-num_heads = 32
-dims_per_head = 128
-vocab = 32000
-num_gpus = 1
-
-
-def convert(base_model_path, pax_model_path):
+MODEL_PARAMS_DICT = {
+    '70b': {
+        'num_layers': 80,
+        'num_heads': 64,
+        'num_kv_heads': 8,
+        'dims_per_head': 128,
+        'vocab': 32000,
+        'num_gpus': 1,
+        'combined_qkv': False,
+    },
+    '7b': {
+        'num_layers': 32,
+        'num_heads': 32,
+        'num_kv_heads': 32,
+        'dims_per_head': 128,
+        'vocab': 32000,
+        'num_gpus': 1,
+        'combined_qkv': True,
+    },
+}
+
+
+def convert(base_model_path, pax_model_path, model_size):
   """Convert from vicuna to pax."""
+  model_params = MODEL_PARAMS_DICT[model_size]
+  num_layers = model_params['num_layers']
+  num_heads = model_params['num_heads']
+  dims_per_head = model_params['dims_per_head']
+  num_kv_heads = model_params['num_kv_heads']
+  vocab = model_params['vocab']
+  combined_qkv = model_params['combined_qkv']
+  num_gpus = model_params['num_gpus']
+
   print(f'Loading the base model from {base_model_path}')
   ckpt_paths = sorted(pathlib.Path(base_model_path).glob('*.pth'))
   pytorch_vars = {}
@@ -55,12 +84,12 @@ def convert(base_model_path, pax_model_path):
   jax_weights = {
       'lm': {
          'embedding_lookup': {
-              'emb_var': np.concatenate([var['tok_embeddings.weight'].numpy() for var in pytorch_vars], axis=1)[:vocab,:]
+              'emb_var': np.concatenate([var['tok_embeddings.weight'].type(torch.float16).numpy() for var in pytorch_vars], axis=1)[:vocab,:]
          },
          'softmax': {
              'logits_ffn': {
                  'linear': {
-                      'w': np.concatenate([var['output.weight'].numpy() for var in pytorch_vars], axis=0).transpose()[:, :vocab]
+                      'w': np.concatenate([var['output.weight'].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()[:, :vocab]
                  }
              }
          },
@@ -70,54 +99,68 @@ def convert(base_model_path, pax_model_path):
          'transformer': {}
      }
  }
+
  for layer_idx in range(num_layers):
-    wq = np.concatenate([var['layers.%d.attention.wq.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=0).transpose()
-    wk = np.concatenate([var['layers.%d.attention.wk.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=0).transpose()
-    wv = np.concatenate([var['layers.%d.attention.wv.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=0).transpose()
-    wc = np.stack([wq, wk, wv], axis=0)
-    wc = np.reshape(wc, [3, num_heads * dims_per_head, num_heads, dims_per_head])
+    wq = np.concatenate([var['layers.%d.attention.wq.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()
+    wk = np.concatenate([var['layers.%d.attention.wk.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()
+    wv = np.concatenate([var['layers.%d.attention.wv.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()
+    if combined_qkv:
+      wc = np.stack([wq, wk, wv], axis=0)
+      wc = np.reshape(wc, [3, num_heads * dims_per_head, num_heads, dims_per_head])
+    else:
+      wq = np.reshape(wq, [num_heads * dims_per_head, num_heads, dims_per_head])
+      wk = np.reshape(wk, [num_heads * dims_per_head, num_kv_heads, dims_per_head])
+      wv = np.reshape(wv, [num_heads * dims_per_head, num_kv_heads, dims_per_head])
+
    w_post = np.concatenate(
        [
-            var['layers.%d.attention.wo.weight' % (layer_idx)].numpy()
+            var['layers.%d.attention.wo.weight' % (layer_idx)].type(torch.float16).numpy()
            for var in pytorch_vars
        ],
        axis=1,
    )
    w_post = np.reshape(w_post, [num_heads * dims_per_head, num_heads, dims_per_head])
+
+    if combined_qkv:
+      attention_weights = {
+          'self_attention': {'combined_qkv': {'w': wc}, 'post': {'w': w_post}}
+      }
+    else:
+      attention_weights = {
+          'self_attention': {
+              'query': {'w': wq},
+              'key': {'w': wk},
+              'value': {'w': wv},
+              'post': {'w': w_post},
+          },
+      }
+
    layer_weight = {
-        'self_attention': {
-            'combined_qkv': {
-                'w': wc
-            },
-            'post': {
-                'w': w_post
-            }
-        },
        'ff_layer': {
            'ffn_layer1_gate': {
                'linear': {
-                    'w': np.concatenate([var['layers.%d.feed_forward.w1.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=0).transpose()
+                    'w': np.concatenate([var['layers.%d.feed_forward.w1.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()
                }
            },
            'ffn_layer1': {
                'linear': {
-                    'w': np.concatenate([var['layers.%d.feed_forward.w3.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=0).transpose()
+                    'w': np.concatenate([var['layers.%d.feed_forward.w3.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=0).transpose()
                }
            },
            'ffn_layer2': {
                'linear': {
-                    'w': np.concatenate([var['layers.%d.feed_forward.w2.weight' % (layer_idx)].numpy() for var in pytorch_vars], axis=1).transpose()
+                    'w': np.concatenate([var['layers.%d.feed_forward.w2.weight' % (layer_idx)].type(torch.float16).numpy() for var in pytorch_vars], axis=1).transpose()
                }
            },
            'layer_norm': {
-                'scale': pytorch_vars[0]['layers.%d.ffn_norm.weight' % (layer_idx)].numpy()
+                'scale': pytorch_vars[0]['layers.%d.ffn_norm.weight' % (layer_idx)].type(torch.float16).numpy()
            }
        },
        'layer_norm': {
-            'scale': pytorch_vars[0]['layers.%d.attention_norm.weight' % (layer_idx)].numpy()
+            'scale': pytorch_vars[0]['layers.%d.attention_norm.weight' % (layer_idx)].type(torch.float16).numpy()
        }
    }
+    layer_weight.update(attention_weights)
    jax_weights['lm']['transformer']['x_layers_%d' % layer_idx] = layer_weight
 
  print(f'Saving the pax model to {pax_model_path}')
@@ -150,6 +193,9 @@ def identity(x):
  parser = argparse.ArgumentParser()
  parser.add_argument('--base-model-path', type=str, required=True)
  parser.add_argument('--pax-model-path', type=str, required=True)
+  parser.add_argument('--model-size', type=str, required=True)
  args = parser.parse_args()
-  convert(args.base_model_path, args.pax_model_path)
+  if args.model_size not in MODEL_PARAMS_DICT:
+    raise NotImplementedError
+  convert(args.base_model_path, args.pax_model_path, args.model_size)
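Note for reviewers: the reason the 70B path needs separate query/key/value variables is that LLaMA2 70B uses grouped-query attention, so K and V have fewer heads than Q and can no longer be stacked into a single combined_qkv tensor. The standalone sketch below mirrors the combined-vs-separate reshape logic from the patch in isolation; the head counts and dimensions are small illustrative stand-ins, not the real LLaMA values.

import numpy as np

# Illustrative stand-ins, not the real model dimensions:
# 4 query heads sharing 2 KV heads, 8 dims per head.
num_heads, num_kv_heads, dims_per_head = 4, 2, 8
model_dim = num_heads * dims_per_head

# Projections already transposed to [model_dim, output_dim], as in convert().
wq = np.zeros([model_dim, num_heads * dims_per_head], dtype=np.float16)
wk = np.zeros([model_dim, num_kv_heads * dims_per_head], dtype=np.float16)
wv = np.zeros([model_dim, num_kv_heads * dims_per_head], dtype=np.float16)

combined_qkv = num_heads == num_kv_heads  # True for 7B, False for 70B

if combined_qkv:
  # Multi-head attention: Q, K, V have identical shapes and can be stacked
  # into a single [3, model_dim, num_heads, dims_per_head] tensor.
  wc = np.stack([wq, wk, wv], axis=0)
  wc = np.reshape(wc, [3, model_dim, num_heads, dims_per_head])
  print('combined_qkv:', wc.shape)
else:
  # Grouped-query attention: K and V carry only num_kv_heads heads,
  # so they stay separate variables with their own shapes.
  wq = np.reshape(wq, [model_dim, num_heads, dims_per_head])
  wk = np.reshape(wk, [model_dim, num_kv_heads, dims_per_head])
  wv = np.reshape(wv, [model_dim, num_kv_heads, dims_per_head])
  print('separate q/k/v:', wq.shape, wk.shape, wv.shape)

With the patch applied, a 70B conversion would be invoked with the new flag, for example (paths are placeholders):

python3 -m convert_llama_ckpt --base-model-path /path/to/llama2_70b --pax-model-path /path/to/pax_70b --model-size 70b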