Generate llama instead of downloading it

satyaog committed Aug 13, 2024
1 parent 1b80d4f commit 258f5a2
Showing 1 changed file with 47 additions and 6 deletions.
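In short: instead of downloading the original consolidated.*.pth checkpoint shards from the Hub, the download now ignores them ("original/consolidated.*.pth" joins the ignore patterns) and a new generate_model() helper rebuilds equivalent shards locally from the repo's params.json, presumably because the benchmark only needs weights of the right shape, not pretrained values.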
53 changes: 47 additions & 6 deletions benchmarks/llm/prepare.py
@@ -1,10 +1,15 @@
 #!/usr/bin/env python
 import argparse
+from dataclasses import dataclass
+import json
 import os
 from pathlib import Path
 
+import llama.model
+import fairscale.nn.model_parallel
 from omegaconf import OmegaConf
 from argklass import ArgumentParser
+import torch.distributed
 from torchtune._cli.tune import TuneCLIParser
 
 from benchmate.ux import long_action
@@ -16,10 +21,35 @@ class Arguments:
     config: str = None
 
 
+@dataclass
+class ModelArgs(llama.model.ModelArgs):
+    use_scaled_rope: bool = True
+
 class MyParser(TuneCLIParser):
     def parse_args(self, args=None) -> argparse.Namespace:
         """Parse CLI arguments"""
-        return self._parser.parse_args(args)
+        parsed_args = self._parser.parse_args(args)
+        # Workaround to send a list of ignore_patterns, as self._parser does
+        # not support a list as input
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "--ignore-patterns",
+            type=str,
+            action='append',
+        )
+        ignore_patterns_args, _ = parser.parse_known_args(args)
+        if ignore_patterns_args.ignore_patterns:
+            parsed_args.ignore_patterns = ignore_patterns_args.ignore_patterns
+        return parsed_args
+
+
+def generate_model(model_parallel_size, params_path: Path):
+    params = json.loads(params_path.read_text())
+    torch.distributed.init_process_group(rank=0, world_size=1)
+    fairscale.nn.model_parallel.initialize.initialize_model_parallel(model_parallel_size)
+    for i in range(model_parallel_size):
+        model = llama.model.Transformer(ModelArgs(**params))
+        torch.save(model.state_dict(), params_path.with_name(f"consolidated.{i:02}.pth"))
 
 
 def load_model(recipe, cfg):
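Note (not part of the commit): the parse_args workaround above relies on two stock argparse behaviors. action='append' collects every occurrence of a flag into one list, and parse_known_args() passes over arguments it does not recognize, leaving them to the torchtune parser. A minimal sketch with made-up values:

    import argparse

    # Standalone demo of the 'append' pattern used in MyParser above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ignore-patterns", type=str, action="append")

    args, unknown = parser.parse_known_args(
        ["download", "some/repo",
         "--ignore-patterns", "*.safetensors",
         "--ignore-patterns", "original/consolidated.*.pth"])
    print(args.ignore_patterns)  # ['*.safetensors', 'original/consolidated.*.pth']
    print(unknown)               # ['download', 'some/repo']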
@@ -54,18 +84,24 @@ def main():
     repo_id = config["repo_id"]
     hf_token = os.getenv("HUGGING_FACE_TOKEN", None)
     output_dir = config["checkpointer"]["output_dir"]
-    ignore_pattern = "*.safetensors"
+
+    ignore_patterns = ["*.safetensors", "original/consolidated.*.pth"]
 
     if config.get("safetensors", False):
-        ignore_pattern = "consolidated.*.pth"
+        ignore_patterns = ["original/consolidated.*.pth"]
 
     download_args = [
         "download",
         repo_id,
         "--output-dir",
         output_dir,
-        "--ignore-patterns",
-        ignore_pattern
+        *sum(
+            [
+                ["--ignore-patterns", ignore_pattern]
+                for ignore_pattern in ignore_patterns
+            ],
+            []
+        )
     ]
 
     if hf_token is not None:
@@ -75,11 +111,16 @@ def main():
         ])
     else:
         print("No HF token found...")
 
     parser = MyParser()
     args = parser.parse_args(download_args)
     parser.run(args)
 
+    if not config.get("safetensors", False):
+        params_path = next(args.output_dir.glob("**/params.json"))
+        model_parallel_size = 8 if repo_id.split("-")[-1].lower() == "70b" else 1
+        generate_model(model_parallel_size, params_path)
+
     if "qlora" in config.get("model", {}).get("_component_", ""):
         load_model(args.recipe, config)
 
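Note (not part of the commit): the *sum(...) expression in download_args is a list-flattening idiom; sum() with [] as its start value concatenates the two-element sub-lists into alternating flag/value pairs:

    # Equivalent standalone example of the flattening above.
    pairs = [["--ignore-patterns", p]
             for p in ["*.safetensors", "original/consolidated.*.pth"]]
    flat = sum(pairs, [])
    # ['--ignore-patterns', '*.safetensors',
    #  '--ignore-patterns', 'original/consolidated.*.pth']

Also worth noting: for 70B checkpoints, generate_model() writes 8 model-parallel shards (consolidated.00.pth through consolidated.07.pth), mirroring the sharding of the original release; smaller models get a single consolidated.00.pth.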
