[sax/llama_conversion] Add llama3_405b_mp8 to convert_llama_ckpt.py
Llama 3.1 405B ships with both 8- and 16-KV-head variants. This adds an option to convert the 8-KV-head checkpoint.

PiperOrigin-RevId: 683730513
Change-Id: I42ee4f650069566a8b79f592b903ac13a6cc315a
rdzhabarov authored and copybara-github committed Oct 8, 2024
1 parent 26631a9 commit 3f085ae
1 changed file: saxml/tools/convert_llama_ckpt.py (10 additions, 1 deletion)
```diff
@@ -40,7 +40,7 @@
 import torch
 
 MODEL_PARAMS_DICT = {
-    'llama3_405b': {
+    'llama3_405b_mp16': {
         'num_layers': 126,
         'num_heads': 128,
         'num_kv_heads': 16,
@@ -49,6 +49,15 @@
         'num_gpus': 1,
         'combined_qkv': False,
     },
+    'llama3_405b_mp8': {
+        'num_layers': 126,
+        'num_heads': 128,
+        'num_kv_heads': 8,
+        'dims_per_head': 128,
+        'vocab': 128256,
+        'num_gpus': 1,
+        'combined_qkv': False,
+    },
     'llama3_70b': {
         'num_layers': 80,
         'num_heads': 64,
```
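The two 405B entries differ only in `num_kv_heads`. Under grouped-query attention, each KV head is shared by `num_heads / num_kv_heads` query heads, so the 8-KV-head checkpoint packs twice as many query heads per KV head as the 16-KV-head one. A minimal sketch of that arithmetic, assuming the dict entries from the diff above (the helper function is hypothetical, not part of convert_llama_ckpt.py):

```python
# Mirrors the two Llama 3.1 405B entries added/renamed in the diff above.
MODEL_PARAMS_DICT = {
    'llama3_405b_mp16': {'num_heads': 128, 'num_kv_heads': 16, 'dims_per_head': 128},
    'llama3_405b_mp8': {'num_heads': 128, 'num_kv_heads': 8, 'dims_per_head': 128},
}

def query_heads_per_kv_head(model_name: str) -> int:
    """Number of query heads sharing each KV head under grouped-query attention."""
    params = MODEL_PARAMS_DICT[model_name]
    return params['num_heads'] // params['num_kv_heads']

print(query_heads_per_kv_head('llama3_405b_mp16'))  # 8
print(query_heads_per_kv_head('llama3_405b_mp8'))   # 16
```

Since total attention width (`num_heads * dims_per_head`) is unchanged, only the size of the K/V projection weights differs between the two checkpoints, which is why conversion needs to know which variant it is reading.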
