Skip to content

Commit

Permalink
Generate safe checkpoints for llama3 70B
Browse files Browse the repository at this point in the history
  • Loading branch information
satyaog committed Aug 27, 2024
1 parent 0aa000c commit 656634d
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 92 deletions.
2 changes: 1 addition & 1 deletion benchmarks/llm/benchfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Llm(Package):
prepare_script = "prepare.py"

async def install(self):
llama3_dir = (XPath(__file__).resolve().parent / "llama3")
llama3_dir = XPath(__file__).resolve().parent
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir = XPath(tmp_dir)
tmp_dir.clone_subtree(
Expand Down
93 changes: 63 additions & 30 deletions benchmarks/llm/configs/llama3_70B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,40 +31,73 @@ shuffle: True
model:
_component_: torchtune.models.llama3_1.llama3_1_70b

safetensors: true
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
checkpoint_files: [
model-00001-of-00030.safetensors,
model-00002-of-00030.safetensors,
model-00003-of-00030.safetensors,
model-00004-of-00030.safetensors,
model-00005-of-00030.safetensors,
model-00006-of-00030.safetensors,
model-00007-of-00030.safetensors,
model-00008-of-00030.safetensors,
model-00009-of-00030.safetensors,
model-00010-of-00030.safetensors,
model-00011-of-00030.safetensors,
model-00012-of-00030.safetensors,
model-00013-of-00030.safetensors,
model-00014-of-00030.safetensors,
model-00015-of-00030.safetensors,
model-00016-of-00030.safetensors,
model-00017-of-00030.safetensors,
model-00018-of-00030.safetensors,
model-00019-of-00030.safetensors,
model-00020-of-00030.safetensors,
model-00021-of-00030.safetensors,
model-00022-of-00030.safetensors,
model-00023-of-00030.safetensors,
model-00024-of-00030.safetensors,
model-00025-of-00030.safetensors,
model-00026-of-00030.safetensors,
model-00027-of-00030.safetensors,
model-00028-of-00030.safetensors,
model-00029-of-00030.safetensors,
model-00030-of-00030.safetensors,
model-00001-of-00062.safetensors,
model-00002-of-00062.safetensors,
model-00003-of-00062.safetensors,
model-00004-of-00062.safetensors,
model-00005-of-00062.safetensors,
model-00006-of-00062.safetensors,
model-00007-of-00062.safetensors,
model-00008-of-00062.safetensors,
model-00009-of-00062.safetensors,
model-00010-of-00062.safetensors,
model-00011-of-00062.safetensors,
model-00012-of-00062.safetensors,
model-00013-of-00062.safetensors,
model-00014-of-00062.safetensors,
model-00015-of-00062.safetensors,
model-00016-of-00062.safetensors,
model-00017-of-00062.safetensors,
model-00018-of-00062.safetensors,
model-00019-of-00062.safetensors,
model-00020-of-00062.safetensors,
model-00021-of-00062.safetensors,
model-00022-of-00062.safetensors,
model-00023-of-00062.safetensors,
model-00024-of-00062.safetensors,
model-00025-of-00062.safetensors,
model-00026-of-00062.safetensors,
model-00027-of-00062.safetensors,
model-00028-of-00062.safetensors,
model-00029-of-00062.safetensors,
model-00030-of-00062.safetensors,
model-00031-of-00062.safetensors,
model-00032-of-00062.safetensors,
model-00033-of-00062.safetensors,
model-00034-of-00062.safetensors,
model-00035-of-00062.safetensors,
model-00036-of-00062.safetensors,
model-00037-of-00062.safetensors,
model-00038-of-00062.safetensors,
model-00039-of-00062.safetensors,
model-00040-of-00062.safetensors,
model-00041-of-00062.safetensors,
model-00042-of-00062.safetensors,
model-00043-of-00062.safetensors,
model-00044-of-00062.safetensors,
model-00045-of-00062.safetensors,
model-00046-of-00062.safetensors,
model-00047-of-00062.safetensors,
model-00048-of-00062.safetensors,
model-00049-of-00062.safetensors,
model-00050-of-00062.safetensors,
model-00051-of-00062.safetensors,
model-00052-of-00062.safetensors,
model-00053-of-00062.safetensors,
model-00054-of-00062.safetensors,
model-00055-of-00062.safetensors,
model-00056-of-00062.safetensors,
model-00057-of-00062.safetensors,
model-00058-of-00062.safetensors,
model-00059-of-00062.safetensors,
model-00060-of-00062.safetensors,
model-00061-of-00062.safetensors,
model-00062-of-00062.safetensors,
]
recipe_checkpoint: null
output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
Expand Down
93 changes: 62 additions & 31 deletions benchmarks/llm/configs/llama3_70B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,73 @@ tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model


safetensors: true
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
checkpoint_files: [
model-00001-of-00030.safetensors,
model-00002-of-00030.safetensors,
model-00003-of-00030.safetensors,
model-00004-of-00030.safetensors,
model-00005-of-00030.safetensors,
model-00006-of-00030.safetensors,
model-00007-of-00030.safetensors,
model-00008-of-00030.safetensors,
model-00009-of-00030.safetensors,
model-00010-of-00030.safetensors,
model-00011-of-00030.safetensors,
model-00012-of-00030.safetensors,
model-00013-of-00030.safetensors,
model-00014-of-00030.safetensors,
model-00015-of-00030.safetensors,
model-00016-of-00030.safetensors,
model-00017-of-00030.safetensors,
model-00018-of-00030.safetensors,
model-00019-of-00030.safetensors,
model-00020-of-00030.safetensors,
model-00021-of-00030.safetensors,
model-00022-of-00030.safetensors,
model-00023-of-00030.safetensors,
model-00024-of-00030.safetensors,
model-00025-of-00030.safetensors,
model-00026-of-00030.safetensors,
model-00027-of-00030.safetensors,
model-00028-of-00030.safetensors,
model-00029-of-00030.safetensors,
model-00030-of-00030.safetensors,
model-00001-of-00062.safetensors,
model-00002-of-00062.safetensors,
model-00003-of-00062.safetensors,
model-00004-of-00062.safetensors,
model-00005-of-00062.safetensors,
model-00006-of-00062.safetensors,
model-00007-of-00062.safetensors,
model-00008-of-00062.safetensors,
model-00009-of-00062.safetensors,
model-00010-of-00062.safetensors,
model-00011-of-00062.safetensors,
model-00012-of-00062.safetensors,
model-00013-of-00062.safetensors,
model-00014-of-00062.safetensors,
model-00015-of-00062.safetensors,
model-00016-of-00062.safetensors,
model-00017-of-00062.safetensors,
model-00018-of-00062.safetensors,
model-00019-of-00062.safetensors,
model-00020-of-00062.safetensors,
model-00021-of-00062.safetensors,
model-00022-of-00062.safetensors,
model-00023-of-00062.safetensors,
model-00024-of-00062.safetensors,
model-00025-of-00062.safetensors,
model-00026-of-00062.safetensors,
model-00027-of-00062.safetensors,
model-00028-of-00062.safetensors,
model-00029-of-00062.safetensors,
model-00030-of-00062.safetensors,
model-00031-of-00062.safetensors,
model-00032-of-00062.safetensors,
model-00033-of-00062.safetensors,
model-00034-of-00062.safetensors,
model-00035-of-00062.safetensors,
model-00036-of-00062.safetensors,
model-00037-of-00062.safetensors,
model-00038-of-00062.safetensors,
model-00039-of-00062.safetensors,
model-00040-of-00062.safetensors,
model-00041-of-00062.safetensors,
model-00042-of-00062.safetensors,
model-00043-of-00062.safetensors,
model-00044-of-00062.safetensors,
model-00045-of-00062.safetensors,
model-00046-of-00062.safetensors,
model-00047-of-00062.safetensors,
model-00048-of-00062.safetensors,
model-00049-of-00062.safetensors,
model-00050-of-00062.safetensors,
model-00051-of-00062.safetensors,
model-00052-of-00062.safetensors,
model-00053-of-00062.safetensors,
model-00054-of-00062.safetensors,
model-00055-of-00062.safetensors,
model-00056-of-00062.safetensors,
model-00057-of-00062.safetensors,
model-00058-of-00062.safetensors,
model-00059-of-00062.safetensors,
model-00060-of-00062.safetensors,
model-00061-of-00062.safetensors,
model-00062-of-00062.safetensors,
]
recipe_checkpoint: null
output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
27 changes: 18 additions & 9 deletions benchmarks/llm/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from pathlib import Path
import time

import llama3.llama.model
import llama.model
import fairscale.nn.model_parallel
from omegaconf import OmegaConf
from argklass import ArgumentParser
import torch
import torch.distributed
from torchtune._cli.tune import TuneCLIParser
from transformers import LlamaConfig, LlamaForCausalLM

from benchmate.ux import long_action

Expand All @@ -24,7 +26,7 @@ class Arguments:


@dataclass
class ModelArgs(llama3.llama.model.ModelArgs):
class ModelArgs(llama.model.ModelArgs):
use_scaled_rope: bool = True


Expand Down Expand Up @@ -60,7 +62,7 @@ def generate_model(
time.sleep(0.1)
conn.recv()
params = json.loads(params_path.read_text())
model = llama3.llama.model.Transformer(ModelArgs(**params))
model = llama.model.Transformer(ModelArgs(**params))
torch.save(model.state_dict(), params_path.with_name(f"consolidated.{rank:02}.pth"))
except Exception as e:
conn.send(e)
Expand Down Expand Up @@ -104,9 +106,6 @@ def main():

ignore_patterns = ["*.safetensors", "original/consolidated.*.pth"]

if config.get("safetensors", False):
ignore_patterns = ["original/consolidated.*.pth"]

download_args = [
"download",
repo_id,
Expand All @@ -120,7 +119,7 @@ def main():
[]
)
]

if hf_token is not None:
download_args.extend([
"--hf-token",
Expand All @@ -133,9 +132,19 @@ def main():
args = parser.parse_args(download_args)
parser.run(args)

if not config.get("safetensors", False):
if config.get("safetensors", False):
params_path = args.output_dir / "config.json"
model = LlamaForCausalLM(LlamaConfig(**json.loads(params_path.read_text())))
# Avoid saving this as part of the config.
del model.config._name_or_path
model.config.torch_dtype = torch.float16
model.save_pretrained(str(args.output_dir), safe_serialization=True)

else:
# Note that at the time of writing torchtune doesn't support multi-*.pth
# files loading
params_path = next(args.output_dir.glob("**/params.json"))
model_parallel_size = 8 if repo_id.split("-")[-1].lower() == "70b" else 1
model_parallel_size = len(config["checkpointer"]["checkpoint_files"])
pipes = [multiprocessing.Pipe() for _ in range(model_parallel_size)]
processes = [
multiprocessing.Process(
Expand Down
5 changes: 4 additions & 1 deletion benchmarks/llm/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ torch
PyYAML
argklass

-r llama3/requirements.txt
# Prepare
accelerate
transformers
-r requirements.txt
File renamed without changes.
Loading

0 comments on commit 656634d

Please sign in to comment.