diff --git a/benchmarks/llm/benchfile.py b/benchmarks/llm/benchfile.py index 44337daa9..9e2934fb5 100644 --- a/benchmarks/llm/benchfile.py +++ b/benchmarks/llm/benchfile.py @@ -29,7 +29,7 @@ class Llm(Package): prepare_script = "prepare.py" async def install(self): - llama3_dir = (XPath(__file__).resolve().parent / "llama3") + llama3_dir = XPath(__file__).resolve().parent with tempfile.TemporaryDirectory() as tmp_dir: tmp_dir = XPath(tmp_dir) tmp_dir.clone_subtree( diff --git a/benchmarks/llm/configs/llama3_70B_full.yaml b/benchmarks/llm/configs/llama3_70B_full.yaml index bac2d9709..3eb270e62 100644 --- a/benchmarks/llm/configs/llama3_70B_full.yaml +++ b/benchmarks/llm/configs/llama3_70B_full.yaml @@ -31,40 +31,73 @@ shuffle: True model: _component_: torchtune.models.llama3_1.llama3_1_70b +safetensors: true checkpointer: _component_: torchtune.utils.FullModelHFCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ checkpoint_files: [ - model-00001-of-00030.safetensors, - model-00002-of-00030.safetensors, - model-00003-of-00030.safetensors, - model-00004-of-00030.safetensors, - model-00005-of-00030.safetensors, - model-00006-of-00030.safetensors, - model-00007-of-00030.safetensors, - model-00008-of-00030.safetensors, - model-00009-of-00030.safetensors, - model-00010-of-00030.safetensors, - model-00011-of-00030.safetensors, - model-00012-of-00030.safetensors, - model-00013-of-00030.safetensors, - model-00014-of-00030.safetensors, - model-00015-of-00030.safetensors, - model-00016-of-00030.safetensors, - model-00017-of-00030.safetensors, - model-00018-of-00030.safetensors, - model-00019-of-00030.safetensors, - model-00020-of-00030.safetensors, - model-00021-of-00030.safetensors, - model-00022-of-00030.safetensors, - model-00023-of-00030.safetensors, - model-00024-of-00030.safetensors, - model-00025-of-00030.safetensors, - model-00026-of-00030.safetensors, - model-00027-of-00030.safetensors, - model-00028-of-00030.safetensors, - model-00029-of-00030.safetensors, - model-00030-of-00030.safetensors, + model-00001-of-00062.safetensors, + model-00002-of-00062.safetensors, + model-00003-of-00062.safetensors, + model-00004-of-00062.safetensors, + model-00005-of-00062.safetensors, + model-00006-of-00062.safetensors, + model-00007-of-00062.safetensors, + model-00008-of-00062.safetensors, + model-00009-of-00062.safetensors, + model-00010-of-00062.safetensors, + model-00011-of-00062.safetensors, + model-00012-of-00062.safetensors, + model-00013-of-00062.safetensors, + model-00014-of-00062.safetensors, + model-00015-of-00062.safetensors, + model-00016-of-00062.safetensors, + model-00017-of-00062.safetensors, + model-00018-of-00062.safetensors, + model-00019-of-00062.safetensors, + model-00020-of-00062.safetensors, + model-00021-of-00062.safetensors, + model-00022-of-00062.safetensors, + model-00023-of-00062.safetensors, + model-00024-of-00062.safetensors, + model-00025-of-00062.safetensors, + model-00026-of-00062.safetensors, + model-00027-of-00062.safetensors, + model-00028-of-00062.safetensors, + model-00029-of-00062.safetensors, + model-00030-of-00062.safetensors, + model-00031-of-00062.safetensors, + model-00032-of-00062.safetensors, + model-00033-of-00062.safetensors, + model-00034-of-00062.safetensors, + model-00035-of-00062.safetensors, + model-00036-of-00062.safetensors, + model-00037-of-00062.safetensors, + model-00038-of-00062.safetensors, + model-00039-of-00062.safetensors, + model-00040-of-00062.safetensors, + model-00041-of-00062.safetensors, + model-00042-of-00062.safetensors, + model-00043-of-00062.safetensors, + model-00044-of-00062.safetensors, + model-00045-of-00062.safetensors, + model-00046-of-00062.safetensors, + model-00047-of-00062.safetensors, + model-00048-of-00062.safetensors, + model-00049-of-00062.safetensors, + model-00050-of-00062.safetensors, + model-00051-of-00062.safetensors, + model-00052-of-00062.safetensors, + model-00053-of-00062.safetensors, + model-00054-of-00062.safetensors, + model-00055-of-00062.safetensors, + model-00056-of-00062.safetensors, + model-00057-of-00062.safetensors, + model-00058-of-00062.safetensors, + model-00059-of-00062.safetensors, + model-00060-of-00062.safetensors, + model-00061-of-00062.safetensors, + model-00062-of-00062.safetensors, ] recipe_checkpoint: null output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ diff --git a/benchmarks/llm/configs/llama3_70B_lora.yaml b/benchmarks/llm/configs/llama3_70B_lora.yaml index 040eb571e..5934c65b8 100644 --- a/benchmarks/llm/configs/llama3_70B_lora.yaml +++ b/benchmarks/llm/configs/llama3_70B_lora.yaml @@ -21,42 +21,73 @@ tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model - safetensors: true checkpointer: _component_: torchtune.utils.FullModelHFCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ checkpoint_files: [ - model-00001-of-00030.safetensors, - model-00002-of-00030.safetensors, - model-00003-of-00030.safetensors, - model-00004-of-00030.safetensors, - model-00005-of-00030.safetensors, - model-00006-of-00030.safetensors, - model-00007-of-00030.safetensors, - model-00008-of-00030.safetensors, - model-00009-of-00030.safetensors, - model-00010-of-00030.safetensors, - model-00011-of-00030.safetensors, - model-00012-of-00030.safetensors, - model-00013-of-00030.safetensors, - model-00014-of-00030.safetensors, - model-00015-of-00030.safetensors, - model-00016-of-00030.safetensors, - model-00017-of-00030.safetensors, - model-00018-of-00030.safetensors, - model-00019-of-00030.safetensors, - model-00020-of-00030.safetensors, - model-00021-of-00030.safetensors, - model-00022-of-00030.safetensors, - model-00023-of-00030.safetensors, - model-00024-of-00030.safetensors, - model-00025-of-00030.safetensors, - model-00026-of-00030.safetensors, - model-00027-of-00030.safetensors, - model-00028-of-00030.safetensors, - model-00029-of-00030.safetensors, - model-00030-of-00030.safetensors, + model-00001-of-00062.safetensors, + model-00002-of-00062.safetensors, + model-00003-of-00062.safetensors, + model-00004-of-00062.safetensors, + model-00005-of-00062.safetensors, + model-00006-of-00062.safetensors, + model-00007-of-00062.safetensors, + model-00008-of-00062.safetensors, + model-00009-of-00062.safetensors, + model-00010-of-00062.safetensors, + model-00011-of-00062.safetensors, + model-00012-of-00062.safetensors, + model-00013-of-00062.safetensors, + model-00014-of-00062.safetensors, + model-00015-of-00062.safetensors, + model-00016-of-00062.safetensors, + model-00017-of-00062.safetensors, + model-00018-of-00062.safetensors, + model-00019-of-00062.safetensors, + model-00020-of-00062.safetensors, + model-00021-of-00062.safetensors, + model-00022-of-00062.safetensors, + model-00023-of-00062.safetensors, + model-00024-of-00062.safetensors, + model-00025-of-00062.safetensors, + model-00026-of-00062.safetensors, + model-00027-of-00062.safetensors, + model-00028-of-00062.safetensors, + model-00029-of-00062.safetensors, + model-00030-of-00062.safetensors, + model-00031-of-00062.safetensors, + model-00032-of-00062.safetensors, + model-00033-of-00062.safetensors, + model-00034-of-00062.safetensors, + model-00035-of-00062.safetensors, + model-00036-of-00062.safetensors, + model-00037-of-00062.safetensors, + model-00038-of-00062.safetensors, + model-00039-of-00062.safetensors, + model-00040-of-00062.safetensors, + model-00041-of-00062.safetensors, + model-00042-of-00062.safetensors, + model-00043-of-00062.safetensors, + model-00044-of-00062.safetensors, + model-00045-of-00062.safetensors, + model-00046-of-00062.safetensors, + model-00047-of-00062.safetensors, + model-00048-of-00062.safetensors, + model-00049-of-00062.safetensors, + model-00050-of-00062.safetensors, + model-00051-of-00062.safetensors, + model-00052-of-00062.safetensors, + model-00053-of-00062.safetensors, + model-00054-of-00062.safetensors, + model-00055-of-00062.safetensors, + model-00056-of-00062.safetensors, + model-00057-of-00062.safetensors, + model-00058-of-00062.safetensors, + model-00059-of-00062.safetensors, + model-00060-of-00062.safetensors, + model-00061-of-00062.safetensors, + model-00062-of-00062.safetensors, ] recipe_checkpoint: null output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ diff --git a/benchmarks/llm/llama3/llama/__init__.py b/benchmarks/llm/llama/__init__.py similarity index 100% rename from benchmarks/llm/llama3/llama/__init__.py rename to benchmarks/llm/llama/__init__.py diff --git a/benchmarks/llm/llama3/llama/generation.py b/benchmarks/llm/llama/generation.py similarity index 100% rename from benchmarks/llm/llama3/llama/generation.py rename to benchmarks/llm/llama/generation.py diff --git a/benchmarks/llm/llama3/llama/model.py b/benchmarks/llm/llama/model.py similarity index 100% rename from benchmarks/llm/llama3/llama/model.py rename to benchmarks/llm/llama/model.py diff --git a/benchmarks/llm/llama3/llama/test_tokenizer.py b/benchmarks/llm/llama/test_tokenizer.py similarity index 100% rename from benchmarks/llm/llama3/llama/test_tokenizer.py rename to benchmarks/llm/llama/test_tokenizer.py diff --git a/benchmarks/llm/llama3/llama/tokenizer.py b/benchmarks/llm/llama/tokenizer.py similarity index 100% rename from benchmarks/llm/llama3/llama/tokenizer.py rename to benchmarks/llm/llama/tokenizer.py diff --git a/benchmarks/llm/prepare.py b/benchmarks/llm/prepare.py index b7bb490b2..631f3dd36 100755 --- a/benchmarks/llm/prepare.py +++ b/benchmarks/llm/prepare.py @@ -7,12 +7,14 @@ from pathlib import Path import time -import llama3.llama.model +import llama.model import fairscale.nn.model_parallel from omegaconf import OmegaConf from argklass import ArgumentParser +import torch import torch.distributed from torchtune._cli.tune import TuneCLIParser +from transformers import LlamaConfig, LlamaForCausalLM from benchmate.ux import long_action @@ -24,7 +26,7 @@ class Arguments: @dataclass -class ModelArgs(llama3.llama.model.ModelArgs): +class ModelArgs(llama.model.ModelArgs): use_scaled_rope: bool = True @@ -60,7 +62,7 @@ def generate_model( time.sleep(0.1) conn.recv() params = json.loads(params_path.read_text()) - model = llama3.llama.model.Transformer(ModelArgs(**params)) + model = llama.model.Transformer(ModelArgs(**params)) torch.save(model.state_dict(), params_path.with_name(f"consolidated.{rank:02}.pth")) except Exception as e: conn.send(e) @@ -104,9 +106,6 @@ def main(): ignore_patterns = ["*.safetensors", "original/consolidated.*.pth"] - if config.get("safetensors", False): - ignore_patterns = ["original/consolidated.*.pth"] - download_args = [ "download", repo_id, @@ -120,7 +119,7 @@ def main(): [] ) ] - + if hf_token is not None: download_args.extend([ "--hf-token", @@ -133,9 +132,19 @@ def main(): args = parser.parse_args(download_args) parser.run(args) - if not config.get("safetensors", False): + if config.get("safetensors", False): + params_path = args.output_dir / "config.json" + model = LlamaForCausalLM(LlamaConfig(**json.loads(params_path.read_text()))) + # Avoid saving this as part of the config. + del model.config._name_or_path + model.config.torch_dtype = torch.float16 + model.save_pretrained(str(args.output_dir), safe_serialization=True) + + else: + # Note that at the time of writing torchtune doesn't support multi-*.pth + # files loading params_path = next(args.output_dir.glob("**/params.json")) - model_parallel_size = 8 if repo_id.split("-")[-1].lower() == "70b" else 1 + model_parallel_size = len(config["checkpointer"]["checkpoint_files"]) pipes = [multiprocessing.Pipe() for _ in range(model_parallel_size)] processes = [ multiprocessing.Process( diff --git a/benchmarks/llm/requirements.in b/benchmarks/llm/requirements.in index cac409865..a1830ac75 100644 --- a/benchmarks/llm/requirements.in +++ b/benchmarks/llm/requirements.in @@ -4,4 +4,7 @@ torch PyYAML argklass --r llama3/requirements.txt +# Prepare +accelerate +transformers +-r requirements.txt diff --git a/benchmarks/llm/llama3/requirements.txt b/benchmarks/llm/requirements.txt similarity index 100% rename from benchmarks/llm/llama3/requirements.txt rename to benchmarks/llm/requirements.txt diff --git a/config/base.yaml b/config/base.yaml index 336940ca7..9e56c2eef 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -26,7 +26,6 @@ _torchvision: --loader: pytorch --data: "{milabench_data}/FakeImageNet" - _torchvision_ddp: inherits: _defaults definition: ../benchmarks/torchvision_ddp @@ -112,7 +111,6 @@ _timm: --dataset: "FakeImageNet" --workers: "auto({n_worker}, 8)" - _accelerate_opt: inherits: _defaults tags: @@ -149,7 +147,6 @@ _accelerate_opt: use_deepspeed: true num_machines: 1 - fp16: inherits: _flops @@ -389,7 +386,6 @@ brax: --num-minibatches: 32 --num-envs: 8192 - _diffusion: inherits: _defaults definition: ../benchmarks/diffusion @@ -501,12 +497,10 @@ _llm: max_duration: 1200 num_machines: 1 - model_parallel_size: 1 inherits: _defaults definition: ../benchmarks/llm install_group: torch - llm-lora-single: inherits: _llm plan: @@ -525,7 +519,6 @@ llm-lora-single: batch_size=8: true gradient_accumulation_steps=8: true - llm-lora-ddp-gpus: inherits: _llm plan: @@ -545,7 +538,6 @@ llm-lora-ddp-gpus: batch_size=8: true gradient_accumulation_steps=8: true - llm-lora-ddp-nodes: inherits: _llm plan: @@ -569,7 +561,6 @@ llm-lora-ddp-nodes: requires_capabilities: - "len(nodes) >= ${num_machines}" - llm-lora-mp-gpus: inherits: _llm plan: @@ -584,14 +575,12 @@ llm-lora-mp-gpus: tokenizer.path={milabench_data}/llama3_70B/original/tokenizer.model: true checkpointer.checkpoint_dir={milabench_data}/llama3_70B: true checkpointer.output_dir={milabench_data}/llama3_70B/: true + safetensors=true: true metric_logger.log_dir={milabench_extra}/metrics: true repo_id="meta-llama/Meta-Llama-3.1-70B": true batch_size=8: true gradient_accumulation_steps=1: true - model_parallel_size: 8 - - llm-full-mp-gpus: inherits: _llm plan: @@ -606,15 +595,12 @@ llm-full-mp-gpus: tokenizer.path={milabench_data}/llama3_70B/original/tokenizer.model: true checkpointer.checkpoint_dir={milabench_data}/llama3_70B: true checkpointer.output_dir={milabench_data}/llama3_70B/: true + safetensors=true: true metric_logger.log_dir={milabench_extra}/metrics: true repo_id="meta-llama/Meta-Llama-3.1-70B": true - safetensors=true: true batch_size=2: true gradient_accumulation_steps=1: true - model_parallel_size: 8 - - llm-full-mp-nodes: inherits: _llm plan: @@ -629,15 +615,12 @@ llm-full-mp-nodes: tokenizer.path={milabench_data}/llama3_70B/original/tokenizer.model: true checkpointer.checkpoint_dir={milabench_data}/llama3_70B: true checkpointer.output_dir={milabench_data}/llama3_70B/: true + safetensors=true: true metric_logger.log_dir={milabench_extra}/metrics: true repo_id="meta-llama/Meta-Llama-3.1-70B": true - safetensors=true: true batch_size=2: true gradient_accumulation_steps=1: true num_machines: 2 - model_parallel_size: 8 requires_capabilities: - "len(nodes) >= ${num_machines}" - -