From ea44ea63be161bea2dd22c6dd23b1386474f09a7 Mon Sep 17 00:00:00 2001 From: satyaog Date: Thu, 5 Sep 2024 12:03:19 -0400 Subject: [PATCH] Generate llama instead of downloading it (#250) * Generate llama instead of downloading it * Generate safe checkpoints for llama3 70B * Fix llm requirements * rename huggingface token to MILABENCH_* to automatically forward the env var to a remote in such cases --- .github/workflows/tests_unit.yml | 2 +- benchmarks/llm/benchfile.py | 24 ++ benchmarks/llm/configs/llama3_70B_full.yaml | 93 +++-- benchmarks/llm/configs/llama3_70B_lora.yaml | 93 +++-- benchmarks/llm/llama/__init__.py | 6 + benchmarks/llm/llama/generation.py | 365 ++++++++++++++++++++ benchmarks/llm/llama/model.py | 302 ++++++++++++++++ benchmarks/llm/llama/test_tokenizer.py | 88 +++++ benchmarks/llm/llama/tokenizer.py | 229 ++++++++++++ benchmarks/llm/prepare.py | 115 +++++- benchmarks/llm/requirements.cuda.txt | 46 +++ benchmarks/llm/requirements.in | 5 + benchmarks/llm/requirements.txt | 4 + config/base.yaml | 3 +- 14 files changed, 1302 insertions(+), 73 deletions(-) create mode 100644 benchmarks/llm/llama/__init__.py create mode 100644 benchmarks/llm/llama/generation.py create mode 100644 benchmarks/llm/llama/model.py create mode 100644 benchmarks/llm/llama/test_tokenizer.py create mode 100644 benchmarks/llm/llama/tokenizer.py create mode 100644 benchmarks/llm/requirements.txt diff --git a/.github/workflows/tests_unit.yml b/.github/workflows/tests_unit.yml index 90d6f4831..28262cf16 100644 --- a/.github/workflows/tests_unit.yml +++ b/.github/workflows/tests_unit.yml @@ -74,7 +74,7 @@ jobs: - name: tests env: - HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN}} + MILABENCH_HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN}} run: | source $(poetry env info -p)/bin/activate coverage run --source=milabench -m pytest --ignore=tests/integration tests/ diff --git a/benchmarks/llm/benchfile.py b/benchmarks/llm/benchfile.py index 6f8cadeee..ca269c368 100644 --- a/benchmarks/llm/benchfile.py +++ b/benchmarks/llm/benchfile.py @@ -1,3 +1,5 @@ +import tempfile +from milabench.fs import XPath from milabench.pack import Package @@ -38,6 +40,28 @@ class Llm(Package): prepare_script = "prepare.py" async def install(self): + llama3_dir = XPath(__file__).resolve().parent + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_dir = XPath(tmp_dir) + tmp_dir.clone_subtree( + "https://github.com/meta-llama/llama3.git", + "11817d47e1ba7a4959b025eb1ca308572e0e3963", + ) + tmp_dir.merge_into( + llama3_dir, + manifest="\n".join( + [ + "/llama/", + "/requirements.txt", + ] + ) + ) + # Fix conflict with tiktoken. As we only need llama/model.py, we don't + # need to care about a compatible tiktoken for the llama3 module + requirements = (llama3_dir / "requirements.txt").read_text().splitlines() + requirements = [l for l in requirements if not l.startswith("tiktoken==")] + (llama3_dir / "requirements.txt").write_text("\n".join(requirements)) + await super().install() # super() call installs the requirements def build_run_plan(self): diff --git a/benchmarks/llm/configs/llama3_70B_full.yaml b/benchmarks/llm/configs/llama3_70B_full.yaml index bac2d9709..3eb270e62 100644 --- a/benchmarks/llm/configs/llama3_70B_full.yaml +++ b/benchmarks/llm/configs/llama3_70B_full.yaml @@ -31,40 +31,73 @@ shuffle: True model: _component_: torchtune.models.llama3_1.llama3_1_70b +safetensors: true checkpointer: _component_: torchtune.utils.FullModelHFCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ checkpoint_files: [ - model-00001-of-00030.safetensors, - model-00002-of-00030.safetensors, - model-00003-of-00030.safetensors, - model-00004-of-00030.safetensors, - model-00005-of-00030.safetensors, - model-00006-of-00030.safetensors, - model-00007-of-00030.safetensors, - model-00008-of-00030.safetensors, - model-00009-of-00030.safetensors, - model-00010-of-00030.safetensors, - model-00011-of-00030.safetensors, - model-00012-of-00030.safetensors, - model-00013-of-00030.safetensors, - model-00014-of-00030.safetensors, - model-00015-of-00030.safetensors, - model-00016-of-00030.safetensors, - model-00017-of-00030.safetensors, - model-00018-of-00030.safetensors, - model-00019-of-00030.safetensors, - model-00020-of-00030.safetensors, - model-00021-of-00030.safetensors, - model-00022-of-00030.safetensors, - model-00023-of-00030.safetensors, - model-00024-of-00030.safetensors, - model-00025-of-00030.safetensors, - model-00026-of-00030.safetensors, - model-00027-of-00030.safetensors, - model-00028-of-00030.safetensors, - model-00029-of-00030.safetensors, - model-00030-of-00030.safetensors, + model-00001-of-00062.safetensors, + model-00002-of-00062.safetensors, + model-00003-of-00062.safetensors, + model-00004-of-00062.safetensors, + model-00005-of-00062.safetensors, + model-00006-of-00062.safetensors, + model-00007-of-00062.safetensors, + model-00008-of-00062.safetensors, + model-00009-of-00062.safetensors, + model-00010-of-00062.safetensors, + model-00011-of-00062.safetensors, + model-00012-of-00062.safetensors, + model-00013-of-00062.safetensors, + model-00014-of-00062.safetensors, + model-00015-of-00062.safetensors, + model-00016-of-00062.safetensors, + model-00017-of-00062.safetensors, + model-00018-of-00062.safetensors, + model-00019-of-00062.safetensors, + model-00020-of-00062.safetensors, + model-00021-of-00062.safetensors, + model-00022-of-00062.safetensors, + model-00023-of-00062.safetensors, + model-00024-of-00062.safetensors, + model-00025-of-00062.safetensors, + model-00026-of-00062.safetensors, + model-00027-of-00062.safetensors, + model-00028-of-00062.safetensors, + model-00029-of-00062.safetensors, + model-00030-of-00062.safetensors, + model-00031-of-00062.safetensors, + model-00032-of-00062.safetensors, + model-00033-of-00062.safetensors, + model-00034-of-00062.safetensors, + model-00035-of-00062.safetensors, + model-00036-of-00062.safetensors, + model-00037-of-00062.safetensors, + model-00038-of-00062.safetensors, + model-00039-of-00062.safetensors, + model-00040-of-00062.safetensors, + model-00041-of-00062.safetensors, + model-00042-of-00062.safetensors, + model-00043-of-00062.safetensors, + model-00044-of-00062.safetensors, + model-00045-of-00062.safetensors, + model-00046-of-00062.safetensors, + model-00047-of-00062.safetensors, + model-00048-of-00062.safetensors, + model-00049-of-00062.safetensors, + model-00050-of-00062.safetensors, + model-00051-of-00062.safetensors, + model-00052-of-00062.safetensors, + model-00053-of-00062.safetensors, + model-00054-of-00062.safetensors, + model-00055-of-00062.safetensors, + model-00056-of-00062.safetensors, + model-00057-of-00062.safetensors, + model-00058-of-00062.safetensors, + model-00059-of-00062.safetensors, + model-00060-of-00062.safetensors, + model-00061-of-00062.safetensors, + model-00062-of-00062.safetensors, ] recipe_checkpoint: null output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ diff --git a/benchmarks/llm/configs/llama3_70B_lora.yaml b/benchmarks/llm/configs/llama3_70B_lora.yaml index 040eb571e..5934c65b8 100644 --- a/benchmarks/llm/configs/llama3_70B_lora.yaml +++ b/benchmarks/llm/configs/llama3_70B_lora.yaml @@ -21,42 +21,73 @@ tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model - safetensors: true checkpointer: _component_: torchtune.utils.FullModelHFCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ checkpoint_files: [ - model-00001-of-00030.safetensors, - model-00002-of-00030.safetensors, - model-00003-of-00030.safetensors, - model-00004-of-00030.safetensors, - model-00005-of-00030.safetensors, - model-00006-of-00030.safetensors, - model-00007-of-00030.safetensors, - model-00008-of-00030.safetensors, - model-00009-of-00030.safetensors, - model-00010-of-00030.safetensors, - model-00011-of-00030.safetensors, - model-00012-of-00030.safetensors, - model-00013-of-00030.safetensors, - model-00014-of-00030.safetensors, - model-00015-of-00030.safetensors, - model-00016-of-00030.safetensors, - model-00017-of-00030.safetensors, - model-00018-of-00030.safetensors, - model-00019-of-00030.safetensors, - model-00020-of-00030.safetensors, - model-00021-of-00030.safetensors, - model-00022-of-00030.safetensors, - model-00023-of-00030.safetensors, - model-00024-of-00030.safetensors, - model-00025-of-00030.safetensors, - model-00026-of-00030.safetensors, - model-00027-of-00030.safetensors, - model-00028-of-00030.safetensors, - model-00029-of-00030.safetensors, - model-00030-of-00030.safetensors, + model-00001-of-00062.safetensors, + model-00002-of-00062.safetensors, + model-00003-of-00062.safetensors, + model-00004-of-00062.safetensors, + model-00005-of-00062.safetensors, + model-00006-of-00062.safetensors, + model-00007-of-00062.safetensors, + model-00008-of-00062.safetensors, + model-00009-of-00062.safetensors, + model-00010-of-00062.safetensors, + model-00011-of-00062.safetensors, + model-00012-of-00062.safetensors, + model-00013-of-00062.safetensors, + model-00014-of-00062.safetensors, + model-00015-of-00062.safetensors, + model-00016-of-00062.safetensors, + model-00017-of-00062.safetensors, + model-00018-of-00062.safetensors, + model-00019-of-00062.safetensors, + model-00020-of-00062.safetensors, + model-00021-of-00062.safetensors, + model-00022-of-00062.safetensors, + model-00023-of-00062.safetensors, + model-00024-of-00062.safetensors, + model-00025-of-00062.safetensors, + model-00026-of-00062.safetensors, + model-00027-of-00062.safetensors, + model-00028-of-00062.safetensors, + model-00029-of-00062.safetensors, + model-00030-of-00062.safetensors, + model-00031-of-00062.safetensors, + model-00032-of-00062.safetensors, + model-00033-of-00062.safetensors, + model-00034-of-00062.safetensors, + model-00035-of-00062.safetensors, + model-00036-of-00062.safetensors, + model-00037-of-00062.safetensors, + model-00038-of-00062.safetensors, + model-00039-of-00062.safetensors, + model-00040-of-00062.safetensors, + model-00041-of-00062.safetensors, + model-00042-of-00062.safetensors, + model-00043-of-00062.safetensors, + model-00044-of-00062.safetensors, + model-00045-of-00062.safetensors, + model-00046-of-00062.safetensors, + model-00047-of-00062.safetensors, + model-00048-of-00062.safetensors, + model-00049-of-00062.safetensors, + model-00050-of-00062.safetensors, + model-00051-of-00062.safetensors, + model-00052-of-00062.safetensors, + model-00053-of-00062.safetensors, + model-00054-of-00062.safetensors, + model-00055-of-00062.safetensors, + model-00056-of-00062.safetensors, + model-00057-of-00062.safetensors, + model-00058-of-00062.safetensors, + model-00059-of-00062.safetensors, + model-00060-of-00062.safetensors, + model-00061-of-00062.safetensors, + model-00062-of-00062.safetensors, ] recipe_checkpoint: null output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ diff --git a/benchmarks/llm/llama/__init__.py b/benchmarks/llm/llama/__init__.py new file mode 100644 index 000000000..2a460b68d --- /dev/null +++ b/benchmarks/llm/llama/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +from .generation import Llama +from .model import ModelArgs, Transformer +from .tokenizer import Dialog, Tokenizer diff --git a/benchmarks/llm/llama/generation.py b/benchmarks/llm/llama/generation.py new file mode 100644 index 000000000..96be4b291 --- /dev/null +++ b/benchmarks/llm/llama/generation.py @@ -0,0 +1,365 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import json +import os +import sys +import time +from pathlib import Path +from typing import List, Optional, Tuple, TypedDict + +import torch +import torch.nn.functional as F +from fairscale.nn.model_parallel.initialize import ( + get_model_parallel_rank, + initialize_model_parallel, + model_parallel_is_initialized, +) + +from llama.model import ModelArgs, Transformer +from llama.tokenizer import ChatFormat, Dialog, Message, Tokenizer + + +class CompletionPrediction(TypedDict, total=False): + generation: str + tokens: List[str] # not required + logprobs: List[float] # not required + + +class ChatPrediction(TypedDict, total=False): + generation: Message + tokens: List[str] # not required + logprobs: List[float] # not required + + +class Llama: + @staticmethod + def build( + ckpt_dir: str, + tokenizer_path: str, + max_seq_len: int, + max_batch_size: int, + model_parallel_size: Optional[int] = None, + seed: int = 1, + ) -> "Llama": + """ + Build a Llama instance by initializing and loading a model checkpoint. + + Args: + ckpt_dir (str): Path to the directory containing checkpoint files. + tokenizer_path (str): Path to the tokenizer file. + max_seq_len (int): Maximum sequence length for input text. + max_batch_size (int): Maximum batch size for inference. + model_parallel_size (Optional[int], optional): Number of model parallel processes. + If not provided, it's determined from the environment. Defaults to None. + + Returns: + Llama: An instance of the Llama class with the loaded model and tokenizer. + + Raises: + AssertionError: If there are no checkpoint files in the specified directory, + or if the model parallel size does not match the number of checkpoint files. + + Note: + This method initializes the distributed process group, sets the device to CUDA, + and loads the pre-trained model and tokenizer. + """ + assert 1 <= max_seq_len <= 8192, f"max_seq_len must be between 1 and 8192, got {max_seq_len}." + assert os.path.isdir(ckpt_dir), f"Checkpoint directory '{ckpt_dir}' does not exist." + assert os.path.isfile(tokenizer_path), f"Tokenizer file '{tokenizer_path}' does not exist." + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group("nccl") + if not model_parallel_is_initialized(): + if model_parallel_size is None: + model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) + initialize_model_parallel(model_parallel_size) + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + # seed must be the same in all processes + torch.manual_seed(seed) + + if local_rank > 0: + sys.stdout = open(os.devnull, "w") + + start_time = time.time() + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + assert model_parallel_size == len( + checkpoints + ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" + ckpt_path = checkpoints[get_model_parallel_rank()] + checkpoint = torch.load(ckpt_path, map_location="cpu") + with open(Path(ckpt_dir) / "params.json", "r") as f: + params = json.loads(f.read()) + + model_args: ModelArgs = ModelArgs( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + **params, + ) + tokenizer = Tokenizer(model_path=tokenizer_path) + assert model_args.vocab_size == tokenizer.n_words + if torch.cuda.is_bf16_supported(): + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + else: + torch.set_default_tensor_type(torch.cuda.HalfTensor) + model = Transformer(model_args) + model.load_state_dict(checkpoint, strict=False) + print(f"Loaded in {time.time() - start_time:.2f} seconds") + + return Llama(model, tokenizer) + + def __init__(self, model: Transformer, tokenizer: Tokenizer): + self.model = model + self.tokenizer = tokenizer + self.formatter = ChatFormat(tokenizer) + + @torch.inference_mode() + def generate( + self, + prompt_tokens: List[List[int]], + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + logprobs: bool = False, + echo: bool = False, + ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: + """ + Generate text sequences based on provided prompts using the language generation model. + + Args: + prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers. + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities. + + Note: + This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. + + """ + params = self.model.params + bsz = len(prompt_tokens) + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= params.max_seq_len + total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) + + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device="cuda") + input_text_mask = tokens != pad_id + if min_prompt_len == total_len: + logits = self.model.forward(tokens, prev_pos) + token_logprobs = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens, + reduction="none", + ignore_index=pad_id, + ) + + stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)) + + for cur_pos in range(min_prompt_len, total_len): + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens[:, prev_pos + 1 : cur_pos + 1], + reduction="none", + ignore_index=pad_id, + ) + eos_reached |= (~input_text_mask[:, cur_pos]) & ( + torch.isin(next_token, stop_tokens) + ) + prev_pos = cur_pos + if all(eos_reached): + break + + if logprobs: + token_logprobs = token_logprobs.tolist() + out_tokens, out_logprobs = [], [] + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = 0 if echo else len(prompt_tokens[i]) + toks = toks[start : len(prompt_tokens[i]) + max_gen_len] + probs = None + if logprobs: + probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] + # cut to after eos tok if any + for stop_token in self.tokenizer.stop_tokens: + try: + eos_idx = toks.index(stop_token) + toks = toks[:eos_idx] + probs = probs[:eos_idx] if logprobs else None + except ValueError: + pass + out_tokens.append(toks) + out_logprobs.append(probs) + return (out_tokens, out_logprobs if logprobs else None) + + def text_completion( + self, + prompts: List[str], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + ) -> List[CompletionPrediction]: + """ + Perform text completion for a list of prompts using the language generation model. + + Args: + prompts (List[str]): List of text prompts for completion. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + List[CompletionPrediction]: List of completion predictions, each containing the generated text completion. + + Note: + This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. + + """ + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + generation_tokens, generation_logprobs = self.generate( + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + echo=echo, + ) + if logprobs: + return [ + { + "generation": self.tokenizer.decode(t), + "tokens": [self.tokenizer.decode([x]) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ] + return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens] + + def chat_completion( + self, + dialogs: List[Dialog], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + ) -> List[ChatPrediction]: + """ + Generate assistant responses for a list of conversational dialogs using the language generation model. + + Args: + dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + + Returns: + List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. + + Note: + This method generates assistant responses for the provided conversational dialogs. + It employs nucleus sampling to introduce controlled randomness in text generation. + If logprobs is True, token log probabilities are computed for each generated token. + """ + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + + prompt_tokens = [ + self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs + ] + generation_tokens, generation_logprobs = self.generate( + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + ) + if logprobs: + return [ + { + "generation": { + "role": "assistant", + "content": self.tokenizer.decode(t), + }, + "tokens": [self.tokenizer.decode([x]) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ] + return [ + { + "generation": { + "role": "assistant", + "content": self.tokenizer.decode(t), + }, + } + for t in generation_tokens + ] + + +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token diff --git a/benchmarks/llm/llama/model.py b/benchmarks/llm/llama/model.py new file mode 100644 index 000000000..e388c0387 --- /dev/null +++ b/benchmarks/llm/llama/model.py @@ -0,0 +1,302 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import fairscale.nn.model_parallel.initialize as fs_init +import torch +import torch.nn.functional as F +from fairscale.nn.model_parallel.layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) +from torch import nn + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 500000 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + model_parallel_size = fs_init.get_model_parallel_world_size() + self.n_local_heads = args.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = ColumnParallelLinear( + args.dim, + args.n_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wk = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wv = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wo = RowParallelLinear( + args.n_heads * self.head_dim, + args.dim, + bias=False, + input_is_parallel=True, + init_method=lambda x: x, + ) + + self.cache_k = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + self.cache_v = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + bsz, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk + self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv + + keys = self.cache_k[:bsz, : start_pos + seqlen] + values = self.cache_v[:bsz, : start_pos + seqlen] + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv( + keys, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = repeat_kv( + values, self.n_rep + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + values = values.transpose( + 1, 2 + ) # (bs, n_local_heads, cache_len + seqlen, head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ColumnParallelLinear( + dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x + ) + self.w2 = RowParallelLinear( + hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x + ) + self.w3 = ColumnParallelLinear( + dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x + ) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = VocabParallelEmbedding( + params.vocab_size, params.dim, init_method=lambda x: x + ) + + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.layers.append(TransformerBlock(layer_id, params)) + + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = ColumnParallelLinear( + params.dim, params.vocab_size, bias=False, init_method=lambda x: x + ) + + self.freqs_cis = precompute_freqs_cis( + params.dim // params.n_heads, + params.max_seq_len * 2, + params.rope_theta, + ) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + + mask = None + if seqlen > 1: + mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device) + + mask = torch.triu(mask, diagonal=1) + + # When performing key-value caching, we compute the attention scores + # only for the new sequence. Thus, the matrix of scores is of size + # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for + # j > cache_len + i, since row i corresponds to token cache_len + i. + mask = torch.hstack( + [torch.zeros((seqlen, start_pos), device=tokens.device), mask] + ).type_as(h) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h).float() + return output diff --git a/benchmarks/llm/llama/test_tokenizer.py b/benchmarks/llm/llama/test_tokenizer.py new file mode 100644 index 000000000..5c2a0749b --- /dev/null +++ b/benchmarks/llm/llama/test_tokenizer.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import os +from unittest import TestCase +from llama.tokenizer import ChatFormat, Tokenizer + +# TOKENIZER_PATH= python -m unittest llama/test_tokenizer.py + +class TokenizerTests(TestCase): + def setUp(self): + self.tokenizer = Tokenizer(os.environ["TOKENIZER_PATH"]) + self.format = ChatFormat(self.tokenizer) + + def test_special_tokens(self): + self.assertEqual( + self.tokenizer.special_tokens["<|begin_of_text|>"], + 128000, + ) + + def test_encode(self): + self.assertEqual( + self.tokenizer.encode( + "This is a test sentence.", + bos=True, + eos=True + ), + [128000, 2028, 374, 264, 1296, 11914, 13, 128001], + ) + + def test_decode(self): + self.assertEqual( + self.tokenizer.decode( + [128000, 2028, 374, 264, 1296, 11914, 13, 128001], + ), + "<|begin_of_text|>This is a test sentence.<|end_of_text|>", + ) + + def test_encode_message(self): + message = { + "role": "user", + "content": "This is a test sentence.", + } + self.assertEqual( + self.format.encode_message(message), + [ + 128006, # <|start_header_id|> + 882, # "user" + 128007, # <|end_header_id|> + 271, # "\n\n" + 2028, 374, 264, 1296, 11914, 13, # This is a test sentence. + 128009, # <|eot_id|> + ] + ) + + def test_encode_dialog(self): + dialog = [ + { + "role": "system", + "content": "This is a test sentence.", + }, + { + "role": "user", + "content": "This is a response.", + } + ] + self.assertEqual( + self.format.encode_dialog_prompt(dialog), + [ + 128000, # <|begin_of_text|> + 128006, # <|start_header_id|> + 9125, # "system" + 128007, # <|end_header_id|> + 271, # "\n\n" + 2028, 374, 264, 1296, 11914, 13, # "This is a test sentence." + 128009, # <|eot_id|> + 128006, # <|start_header_id|> + 882, # "user" + 128007, # <|end_header_id|> + 271, # "\n\n" + 2028, 374, 264, 2077, 13, # "This is a response.", + 128009, # <|eot_id|> + 128006, # <|start_header_id|> + 78191, # "assistant" + 128007, # <|end_header_id|> + 271, # "\n\n" + ] + ) diff --git a/benchmarks/llm/llama/tokenizer.py b/benchmarks/llm/llama/tokenizer.py new file mode 100644 index 000000000..e691beb6a --- /dev/null +++ b/benchmarks/llm/llama/tokenizer.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import os +from logging import getLogger +from pathlib import Path +from typing import ( + AbstractSet, + cast, + Collection, + Dict, + Iterator, + List, + Literal, + Sequence, + TypedDict, + Union, +) + +import tiktoken +from tiktoken.load import load_tiktoken_bpe + + +logger = getLogger(__name__) + + +Role = Literal["system", "user", "assistant"] + + +class Message(TypedDict): + role: Role + content: str + + +Dialog = Sequence[Message] + + +class Tokenizer: + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + def __init__(self, model_path: str): + """ + Initializes the Tokenizer with a Tiktoken model. + + Args: + model_path (str): The path to the Tiktoken model file. + """ + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [ + f"<|reserved_special_token_{i}|>" + for i in range(5, self.num_reserved_special_tokens - 5) + ] + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + logger.info(f"Reloaded tiktoken model from {model_path}") + + self.n_words: int = self.model.n_vocab + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.pad_id: int = -1 + self.stop_tokens = { + self.special_tokens["<|end_of_text|>"], + self.special_tokens["<|eot_id|>"], + } + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = (), + ) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_tokens ("all"|set[str]): allowed special tokens in string + disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. + """ + assert type(s) is str + + # The tiktoken tokenizer can handle <=400k chars without + # pyo3_runtime.PanicException. + TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + + # https://github.com/openai/tiktoken/issues/195 + # Here we iterate over subsequences and split if we exceed the limit + # of max consecutive non-whitespace or whitespace characters. + MAX_NO_WHITESPACES_CHARS = 25_000 + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: List[int] = [] + for substr in substrs: + t.extend( + self.model.encode( + substr, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. + return self.model.decode(cast(List[int], t)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces( + s: str, max_consecutive_slice_len: int + ) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] + + +class ChatFormat: + def __init__(self, tokenizer: Tokenizer): + self.tokenizer = tokenizer + + def encode_header(self, message: Message) -> List[int]: + tokens = [] + tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) + tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False)) + tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"]) + tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False)) + return tokens + + def encode_message(self, message: Message) -> List[int]: + tokens = self.encode_header(message) + tokens.extend( + self.tokenizer.encode(message["content"].strip(), bos=False, eos=False) + ) + tokens.append(self.tokenizer.special_tokens["<|eot_id|>"]) + return tokens + + def encode_dialog_prompt(self, dialog: Dialog) -> List[int]: + tokens = [] + tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"]) + for message in dialog: + tokens.extend(self.encode_message(message)) + # Add the start of an assistant message for the model to complete. + tokens.extend(self.encode_header({"role": "assistant", "content": ""})) + return tokens diff --git a/benchmarks/llm/prepare.py b/benchmarks/llm/prepare.py index ca5ba2d3b..221162ffa 100755 --- a/benchmarks/llm/prepare.py +++ b/benchmarks/llm/prepare.py @@ -1,11 +1,20 @@ #!/usr/bin/env python import argparse from dataclasses import dataclass +import json +import multiprocessing import os +from pathlib import Path +import time +import llama.model +import fairscale.nn.model_parallel from omegaconf import OmegaConf from argklass import ArgumentParser +import torch +import torch.distributed from torchtune._cli.tune import TuneCLIParser +from transformers import LlamaConfig, LlamaForCausalLM from benchmate.ux import long_action @@ -16,10 +25,57 @@ class Arguments: config: str = None +@dataclass +class ModelArgs(llama.model.ModelArgs): + use_scaled_rope: bool = True + + class MyParser(TuneCLIParser): def parse_args(self, args=None) -> argparse.Namespace: """Parse CLI arguments""" - return self._parser.parse_args(args) + parsed_args = self._parser.parse_args(args) + # Workaround to send a list to of ignore_patterns as self._parser does + # not support a list in input + parser = argparse.ArgumentParser() + parser.add_argument( + "--ignore-patterns", + type=str, + action='append', + ) + ignore_patterns_args, _ = parser.parse_known_args(args) + if ignore_patterns_args.ignore_patterns: + parsed_args.ignore_patterns = ignore_patterns_args.ignore_patterns + return parsed_args + + +def generate_model( + conn:multiprocessing.connection.Connection, + params_path:Path, + rank=0, + model_parallel_size=1 + ): + try: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + torch.distributed.init_process_group(rank=rank, world_size=model_parallel_size) + fairscale.nn.model_parallel.initialize.initialize_model_parallel(model_parallel_size) + + conn.send(os.getpid()) + while not conn.poll(): + time.sleep(0.1) + conn.recv() + + params = json.loads(params_path.read_text()) + model = llama.model.Transformer(ModelArgs(**params)) + torch.save(model.state_dict(), params_path.with_name(f"consolidated.{rank:02}.pth")) + + except Exception as e: + conn.send(e) + raise + + finally: + conn.close() def load_model(recipe, cfg): @@ -52,22 +108,25 @@ def main(): config = OmegaConf.merge(base, cli) repo_id = config["repo_id"] - hf_token = os.getenv("HUGGING_FACE_TOKEN", None) + hf_token = os.getenv("MILABENCH_HF_TOKEN", None) output_dir = config["checkpointer"]["output_dir"] - ignore_pattern = "*.safetensors" - if config.get("safetensors", False): - ignore_pattern = "*consolidated.*.pth" + ignore_patterns = ["*.safetensors", "*consolidated.*.pth"] download_args = [ "download", repo_id, "--output-dir", output_dir, - "--ignore-patterns", - ignore_pattern + *sum( + [ + ["--ignore-patterns", ignore_pattern] + for ignore_pattern in ignore_patterns + ], + [] + ) ] - + if hf_token is not None: download_args.extend([ "--hf-token", @@ -75,11 +134,49 @@ def main(): ]) else: print("No HF token found...") - + parser = MyParser() args = parser.parse_args(download_args) parser.run(args) + if config.get("safetensors", False): + params_path = args.output_dir / "config.json" + model = LlamaForCausalLM(LlamaConfig(**json.loads(params_path.read_text()))) + # Avoid saving this as part of the config. + del model.config._name_or_path + model.config.torch_dtype = torch.float16 + model.save_pretrained(str(args.output_dir), safe_serialization=True) + + else: + # Note that at the time of writing torchtune doesn't support multi-*.pth + # files loading + params_path = next(args.output_dir.glob("**/params.json")) + model_parallel_size = len(config["checkpointer"]["checkpoint_files"]) + pipes = [multiprocessing.Pipe() for _ in range(model_parallel_size)] + processes = [ + multiprocessing.Process( + target=generate_model, + args=[conn, params_path, rank, model_parallel_size] + ) + for rank, (_, conn) in enumerate(pipes) + ] + # Init torch.distributed process_group and fairscale model parallel in + # each fake workers + [p.start() for p in processes] + pids = set() + for (conn, _) in pipes: + while not conn.poll(): + time.sleep(0.1) + pid = conn.recv() + if isinstance(pid, Exception): + raise pid + pids.add(pid) + assert len(pids) == model_parallel_size + # Generate each chunk of the model one by one + for p, (conn, _) in zip(processes, pipes): + conn.send(True) + p.join() + if "qlora" in config.get("model", {}).get("_component_", ""): load_model(args.recipe, config) diff --git a/benchmarks/llm/requirements.cuda.txt b/benchmarks/llm/requirements.cuda.txt index 94afa483c..bac59e5f0 100644 --- a/benchmarks/llm/requirements.cuda.txt +++ b/benchmarks/llm/requirements.cuda.txt @@ -9,6 +9,10 @@ --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --trusted-host pypi.ngc.nvidia.com +accelerate==0.33.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.in aiohappyeyeballs==2.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt @@ -45,6 +49,7 @@ attrs==24.2.0 blobfile==3.0.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.txt # torchtune certifi==2024.8.30 # via @@ -71,6 +76,10 @@ executing==1.2.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname +fairscale==0.4.13 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.txt filelock==3.15.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt @@ -78,7 +87,12 @@ filelock==3.15.4 # datasets # huggingface-hub # torch + # transformers # triton +fire==0.6.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.txt frozenlist==1.4.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt @@ -102,8 +116,11 @@ hjson==3.1.0 huggingface-hub==0.24.6 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # accelerate # datasets + # tokenizers # torchtune + # transformers idna==3.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt @@ -174,7 +191,9 @@ networkx==3.3 numpy==1.26.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # accelerate # datasets + # fairscale # jax # jaxlib # ml-dtypes @@ -183,6 +202,7 @@ numpy==1.26.4 # pyarrow # scipy # torchtune + # transformers nvidia-cublas-cu12==12.1.3.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt @@ -268,8 +288,10 @@ ovld==0.3.9 packaging==24.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # accelerate # datasets # huggingface-hub + # transformers pandas==2.2.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt @@ -277,6 +299,7 @@ pandas==2.2.2 psutil==5.9.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # accelerate # voir ptera==1.4.1 # via @@ -306,9 +329,11 @@ pyyaml==6.0.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/llm/requirements.in + # accelerate # datasets # huggingface-hub # omegaconf + # transformers reactivex==4.0.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt @@ -317,12 +342,14 @@ regex==2024.7.24 # via # -c .pin/../.pin/constraints-cuda-torch.txt # tiktoken + # transformers requests==2.32.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets # huggingface-hub # tiktoken + # transformers rich==13.8.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt @@ -330,7 +357,9 @@ rich==13.8.0 safetensors==0.4.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # accelerate # torchtune + # transformers scipy==1.14.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt @@ -344,19 +373,31 @@ six==1.16.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens + # fire # python-dateutil sympy==1.13.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch +termcolor==2.4.0 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # fire tiktoken==0.7.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torchtune +tokenizers==0.19.1 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # transformers torch==2.4.0+cu121 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/llm/requirements.in + # -r benchmarks/llm/requirements.txt + # accelerate + # fairscale torchao==0.3.1+cu121 # via # -c .pin/../.pin/constraints-cuda-torch.txt @@ -371,6 +412,11 @@ tqdm==4.66.5 # datasets # huggingface-hub # torchtune + # transformers +transformers==4.44.2 + # via + # -c .pin/../.pin/constraints-cuda-torch.txt + # -r benchmarks/llm/requirements.in triton==3.0.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt diff --git a/benchmarks/llm/requirements.in b/benchmarks/llm/requirements.in index 62ec2a995..bbe85dec2 100644 --- a/benchmarks/llm/requirements.in +++ b/benchmarks/llm/requirements.in @@ -3,3 +3,8 @@ torchtune torch PyYAML argklass + +# Prepare +accelerate +transformers +-r requirements.txt diff --git a/benchmarks/llm/requirements.txt b/benchmarks/llm/requirements.txt new file mode 100644 index 000000000..df593f573 --- /dev/null +++ b/benchmarks/llm/requirements.txt @@ -0,0 +1,4 @@ +torch +fairscale +fire +blobfile \ No newline at end of file diff --git a/config/base.yaml b/config/base.yaml index f3f0a8ae8..09cc11c85 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -602,12 +602,12 @@ llm-lora-mp-gpus: tokenizer.path={milabench_data}/llama3_70B/original/tokenizer.model: true checkpointer.checkpoint_dir={milabench_data}/llama3_70B: true checkpointer.output_dir={milabench_data}/llama3_70B/: true + safetensors=true: true metric_logger.log_dir={milabench_extra}/metrics: true repo_id="meta-llama/Meta-Llama-3.1-70B": true batch_size=8: true gradient_accumulation_steps=1: true - llm-full-mp-gpus: inherits: _llm plan: @@ -628,7 +628,6 @@ llm-full-mp-gpus: batch_size=2: true gradient_accumulation_steps=1: true - llm-full-mp-nodes: tags: - multinode