From 0b1320a835bca1a878fc831f41b98283c193be8b Mon Sep 17 00:00:00 2001 From: Satya Ortiz-Gagne Date: Thu, 5 Sep 2024 10:11:56 -0400 Subject: [PATCH] Generate safe checkpoints for llama3 70B --- benchmarks/llm/benchfile.py | 2 +- benchmarks/llm/configs/llama3_70B_full.yaml | 93 +++++++++++++------ benchmarks/llm/configs/llama3_70B_lora.yaml | 93 ++++++++++++------- benchmarks/llm/{llama3 => }/llama/__init__.py | 0 .../llm/{llama3 => }/llama/generation.py | 0 benchmarks/llm/{llama3 => }/llama/model.py | 0 .../llm/{llama3 => }/llama/test_tokenizer.py | 0 .../llm/{llama3 => }/llama/tokenizer.py | 0 benchmarks/llm/prepare.py | 29 ++++-- benchmarks/llm/requirements.in | 5 +- benchmarks/llm/{llama3 => }/requirements.txt | 0 config/base.yaml | 7 +- 12 files changed, 150 insertions(+), 79 deletions(-) rename benchmarks/llm/{llama3 => }/llama/__init__.py (100%) rename benchmarks/llm/{llama3 => }/llama/generation.py (100%) rename benchmarks/llm/{llama3 => }/llama/model.py (100%) rename benchmarks/llm/{llama3 => }/llama/test_tokenizer.py (100%) rename benchmarks/llm/{llama3 => }/llama/tokenizer.py (100%) rename benchmarks/llm/{llama3 => }/requirements.txt (100%) diff --git a/benchmarks/llm/benchfile.py b/benchmarks/llm/benchfile.py index a1acbb4bb..ca269c368 100644 --- a/benchmarks/llm/benchfile.py +++ b/benchmarks/llm/benchfile.py @@ -40,7 +40,7 @@ class Llm(Package): prepare_script = "prepare.py" async def install(self): - llama3_dir = (XPath(__file__).resolve().parent / "llama3") + llama3_dir = XPath(__file__).resolve().parent with tempfile.TemporaryDirectory() as tmp_dir: tmp_dir = XPath(tmp_dir) tmp_dir.clone_subtree( diff --git a/benchmarks/llm/configs/llama3_70B_full.yaml b/benchmarks/llm/configs/llama3_70B_full.yaml index bac2d9709..3eb270e62 100644 --- a/benchmarks/llm/configs/llama3_70B_full.yaml +++ b/benchmarks/llm/configs/llama3_70B_full.yaml @@ -31,40 +31,73 @@ shuffle: True model: _component_: torchtune.models.llama3_1.llama3_1_70b +safetensors: true checkpointer: _component_: torchtune.utils.FullModelHFCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ checkpoint_files: [ - model-00001-of-00030.safetensors, - model-00002-of-00030.safetensors, - model-00003-of-00030.safetensors, - model-00004-of-00030.safetensors, - model-00005-of-00030.safetensors, - model-00006-of-00030.safetensors, - model-00007-of-00030.safetensors, - model-00008-of-00030.safetensors, - model-00009-of-00030.safetensors, - model-00010-of-00030.safetensors, - model-00011-of-00030.safetensors, - model-00012-of-00030.safetensors, - model-00013-of-00030.safetensors, - model-00014-of-00030.safetensors, - model-00015-of-00030.safetensors, - model-00016-of-00030.safetensors, - model-00017-of-00030.safetensors, - model-00018-of-00030.safetensors, - model-00019-of-00030.safetensors, - model-00020-of-00030.safetensors, - model-00021-of-00030.safetensors, - model-00022-of-00030.safetensors, - model-00023-of-00030.safetensors, - model-00024-of-00030.safetensors, - model-00025-of-00030.safetensors, - model-00026-of-00030.safetensors, - model-00027-of-00030.safetensors, - model-00028-of-00030.safetensors, - model-00029-of-00030.safetensors, - model-00030-of-00030.safetensors, + model-00001-of-00062.safetensors, + model-00002-of-00062.safetensors, + model-00003-of-00062.safetensors, + model-00004-of-00062.safetensors, + model-00005-of-00062.safetensors, + model-00006-of-00062.safetensors, + model-00007-of-00062.safetensors, + model-00008-of-00062.safetensors, + model-00009-of-00062.safetensors, + model-00010-of-00062.safetensors, + model-00011-of-00062.safetensors, + model-00012-of-00062.safetensors, + model-00013-of-00062.safetensors, + model-00014-of-00062.safetensors, + model-00015-of-00062.safetensors, + model-00016-of-00062.safetensors, + model-00017-of-00062.safetensors, + model-00018-of-00062.safetensors, + model-00019-of-00062.safetensors, + model-00020-of-00062.safetensors, + model-00021-of-00062.safetensors, + model-00022-of-00062.safetensors, + model-00023-of-00062.safetensors, + model-00024-of-00062.safetensors, + model-00025-of-00062.safetensors, + model-00026-of-00062.safetensors, + model-00027-of-00062.safetensors, + model-00028-of-00062.safetensors, + model-00029-of-00062.safetensors, + model-00030-of-00062.safetensors, + model-00031-of-00062.safetensors, + model-00032-of-00062.safetensors, + model-00033-of-00062.safetensors, + model-00034-of-00062.safetensors, + model-00035-of-00062.safetensors, + model-00036-of-00062.safetensors, + model-00037-of-00062.safetensors, + model-00038-of-00062.safetensors, + model-00039-of-00062.safetensors, + model-00040-of-00062.safetensors, + model-00041-of-00062.safetensors, + model-00042-of-00062.safetensors, + model-00043-of-00062.safetensors, + model-00044-of-00062.safetensors, + model-00045-of-00062.safetensors, + model-00046-of-00062.safetensors, + model-00047-of-00062.safetensors, + model-00048-of-00062.safetensors, + model-00049-of-00062.safetensors, + model-00050-of-00062.safetensors, + model-00051-of-00062.safetensors, + model-00052-of-00062.safetensors, + model-00053-of-00062.safetensors, + model-00054-of-00062.safetensors, + model-00055-of-00062.safetensors, + model-00056-of-00062.safetensors, + model-00057-of-00062.safetensors, + model-00058-of-00062.safetensors, + model-00059-of-00062.safetensors, + model-00060-of-00062.safetensors, + model-00061-of-00062.safetensors, + model-00062-of-00062.safetensors, ] recipe_checkpoint: null output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ diff --git a/benchmarks/llm/configs/llama3_70B_lora.yaml b/benchmarks/llm/configs/llama3_70B_lora.yaml index 040eb571e..5934c65b8 100644 --- a/benchmarks/llm/configs/llama3_70B_lora.yaml +++ b/benchmarks/llm/configs/llama3_70B_lora.yaml @@ -21,42 +21,73 @@ tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model - safetensors: true checkpointer: _component_: torchtune.utils.FullModelHFCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ checkpoint_files: [ - model-00001-of-00030.safetensors, - model-00002-of-00030.safetensors, - model-00003-of-00030.safetensors, - model-00004-of-00030.safetensors, - model-00005-of-00030.safetensors, - model-00006-of-00030.safetensors, - model-00007-of-00030.safetensors, - model-00008-of-00030.safetensors, - model-00009-of-00030.safetensors, - model-00010-of-00030.safetensors, - model-00011-of-00030.safetensors, - model-00012-of-00030.safetensors, - model-00013-of-00030.safetensors, - model-00014-of-00030.safetensors, - model-00015-of-00030.safetensors, - model-00016-of-00030.safetensors, - model-00017-of-00030.safetensors, - model-00018-of-00030.safetensors, - model-00019-of-00030.safetensors, - model-00020-of-00030.safetensors, - model-00021-of-00030.safetensors, - model-00022-of-00030.safetensors, - model-00023-of-00030.safetensors, - model-00024-of-00030.safetensors, - model-00025-of-00030.safetensors, - model-00026-of-00030.safetensors, - model-00027-of-00030.safetensors, - model-00028-of-00030.safetensors, - model-00029-of-00030.safetensors, - model-00030-of-00030.safetensors, + model-00001-of-00062.safetensors, + model-00002-of-00062.safetensors, + model-00003-of-00062.safetensors, + model-00004-of-00062.safetensors, + model-00005-of-00062.safetensors, + model-00006-of-00062.safetensors, + model-00007-of-00062.safetensors, + model-00008-of-00062.safetensors, + model-00009-of-00062.safetensors, + model-00010-of-00062.safetensors, + model-00011-of-00062.safetensors, + model-00012-of-00062.safetensors, + model-00013-of-00062.safetensors, + model-00014-of-00062.safetensors, + model-00015-of-00062.safetensors, + model-00016-of-00062.safetensors, + model-00017-of-00062.safetensors, + model-00018-of-00062.safetensors, + model-00019-of-00062.safetensors, + model-00020-of-00062.safetensors, + model-00021-of-00062.safetensors, + model-00022-of-00062.safetensors, + model-00023-of-00062.safetensors, + model-00024-of-00062.safetensors, + model-00025-of-00062.safetensors, + model-00026-of-00062.safetensors, + model-00027-of-00062.safetensors, + model-00028-of-00062.safetensors, + model-00029-of-00062.safetensors, + model-00030-of-00062.safetensors, + model-00031-of-00062.safetensors, + model-00032-of-00062.safetensors, + model-00033-of-00062.safetensors, + model-00034-of-00062.safetensors, + model-00035-of-00062.safetensors, + model-00036-of-00062.safetensors, + model-00037-of-00062.safetensors, + model-00038-of-00062.safetensors, + model-00039-of-00062.safetensors, + model-00040-of-00062.safetensors, + model-00041-of-00062.safetensors, + model-00042-of-00062.safetensors, + model-00043-of-00062.safetensors, + model-00044-of-00062.safetensors, + model-00045-of-00062.safetensors, + model-00046-of-00062.safetensors, + model-00047-of-00062.safetensors, + model-00048-of-00062.safetensors, + model-00049-of-00062.safetensors, + model-00050-of-00062.safetensors, + model-00051-of-00062.safetensors, + model-00052-of-00062.safetensors, + model-00053-of-00062.safetensors, + model-00054-of-00062.safetensors, + model-00055-of-00062.safetensors, + model-00056-of-00062.safetensors, + model-00057-of-00062.safetensors, + model-00058-of-00062.safetensors, + model-00059-of-00062.safetensors, + model-00060-of-00062.safetensors, + model-00061-of-00062.safetensors, + model-00062-of-00062.safetensors, ] recipe_checkpoint: null output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ diff --git a/benchmarks/llm/llama3/llama/__init__.py b/benchmarks/llm/llama/__init__.py similarity index 100% rename from benchmarks/llm/llama3/llama/__init__.py rename to benchmarks/llm/llama/__init__.py diff --git a/benchmarks/llm/llama3/llama/generation.py b/benchmarks/llm/llama/generation.py similarity index 100% rename from benchmarks/llm/llama3/llama/generation.py rename to benchmarks/llm/llama/generation.py diff --git a/benchmarks/llm/llama3/llama/model.py b/benchmarks/llm/llama/model.py similarity index 100% rename from benchmarks/llm/llama3/llama/model.py rename to benchmarks/llm/llama/model.py diff --git a/benchmarks/llm/llama3/llama/test_tokenizer.py b/benchmarks/llm/llama/test_tokenizer.py similarity index 100% rename from benchmarks/llm/llama3/llama/test_tokenizer.py rename to benchmarks/llm/llama/test_tokenizer.py diff --git a/benchmarks/llm/llama3/llama/tokenizer.py b/benchmarks/llm/llama/tokenizer.py similarity index 100% rename from benchmarks/llm/llama3/llama/tokenizer.py rename to benchmarks/llm/llama/tokenizer.py diff --git a/benchmarks/llm/prepare.py b/benchmarks/llm/prepare.py index da088f7bc..9c64ac8fe 100755 --- a/benchmarks/llm/prepare.py +++ b/benchmarks/llm/prepare.py @@ -7,12 +7,14 @@ from pathlib import Path import time -import llama3.llama.model +import llama.model import fairscale.nn.model_parallel from omegaconf import OmegaConf from argklass import ArgumentParser +import torch import torch.distributed from torchtune._cli.tune import TuneCLIParser +from transformers import LlamaConfig, LlamaForCausalLM from benchmate.ux import long_action @@ -24,7 +26,7 @@ class Arguments: @dataclass -class ModelArgs(llama3.llama.model.ModelArgs): +class ModelArgs(llama.model.ModelArgs): use_scaled_rope: bool = True @@ -60,7 +62,7 @@ def generate_model( time.sleep(0.1) conn.recv() params = json.loads(params_path.read_text()) - model = llama3.llama.model.Transformer(ModelArgs(**params)) + model = llama.model.Transformer(ModelArgs(**params)) torch.save(model.state_dict(), params_path.with_name(f"consolidated.{rank:02}.pth")) except Exception as e: conn.send(e) @@ -102,10 +104,7 @@ def main(): hf_token = os.getenv("HUGGING_FACE_TOKEN", None) output_dir = config["checkpointer"]["output_dir"] - ignore_patterns = ["*.safetensors", "original/consolidated.*.pth"] - - if config.get("safetensors", False): - ignore_pattern = "*consolidated.*.pth" + ignore_patterns = ["*.safetensors", "*consolidated.*.pth"] download_args = [ "download", @@ -120,7 +119,7 @@ def main(): [] ) ] - + if hf_token is not None: download_args.extend([ "--hf-token", @@ -133,9 +132,19 @@ def main(): args = parser.parse_args(download_args) parser.run(args) - if not config.get("safetensors", False): + if config.get("safetensors", False): + params_path = args.output_dir / "config.json" + model = LlamaForCausalLM(LlamaConfig(**json.loads(params_path.read_text()))) + # Avoid saving this as part of the config. + del model.config._name_or_path + model.config.torch_dtype = torch.float16 + model.save_pretrained(str(args.output_dir), safe_serialization=True) + + else: + # Note that at the time of writing torchtune doesn't support multi-*.pth + # files loading params_path = next(args.output_dir.glob("**/params.json")) - model_parallel_size = 8 if repo_id.split("-")[-1].lower() == "70b" else 1 + model_parallel_size = len(config["checkpointer"]["checkpoint_files"]) pipes = [multiprocessing.Pipe() for _ in range(model_parallel_size)] processes = [ multiprocessing.Process( diff --git a/benchmarks/llm/requirements.in b/benchmarks/llm/requirements.in index d9142ee87..bbe85dec2 100644 --- a/benchmarks/llm/requirements.in +++ b/benchmarks/llm/requirements.in @@ -4,4 +4,7 @@ torch PyYAML argklass --r llama3/requirements.txt +# Prepare +accelerate +transformers +-r requirements.txt diff --git a/benchmarks/llm/llama3/requirements.txt b/benchmarks/llm/requirements.txt similarity index 100% rename from benchmarks/llm/llama3/requirements.txt rename to benchmarks/llm/requirements.txt diff --git a/config/base.yaml b/config/base.yaml index 64a441b97..694309288 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -508,7 +508,6 @@ _llm: max_duration: 1200 num_machines: 1 - model_parallel_size: 1 inherits: _defaults definition: ../benchmarks/llm install_group: torch @@ -594,13 +593,12 @@ llm-lora-mp-gpus: tokenizer.path={milabench_data}/llama3_70B/original/tokenizer.model: true checkpointer.checkpoint_dir={milabench_data}/llama3_70B: true checkpointer.output_dir={milabench_data}/llama3_70B/: true + safetensors=true: true metric_logger.log_dir={milabench_extra}/metrics: true repo_id="meta-llama/Meta-Llama-3.1-70B": true batch_size=8: true gradient_accumulation_steps=1: true - model_parallel_size: 8 - llm-full-mp-gpus: inherits: _llm plan: @@ -621,8 +619,6 @@ llm-full-mp-gpus: batch_size=2: true gradient_accumulation_steps=1: true - model_parallel_size: 8 - llm-full-mp-nodes: tags: - multinode @@ -647,7 +643,6 @@ llm-full-mp-nodes: gradient_accumulation_steps=1: true num_machines: 2 - model_parallel_size: 8 requires_capabilities: - "len(nodes) >= ${num_machines}"