diff --git a/convert_model/convert_model.sh b/convert_model/convert_model.sh
new file mode 100644
index 0000000..171b939
--- /dev/null
+++ b/convert_model/convert_model.sh
@@ -0,0 +1,4 @@
+python scripts/convert_lit_checkpoint.py \
+    --out_dir ./out_20240217/tinyllama_1b \
+    --checkpoint_name iter-2860000-ckpt.pth \
+    --model_name csg-tiny-1B
diff --git a/convert_model/convert_model_hf.sh b/convert_model/convert_model_hf.sh
new file mode 100644
index 0000000..4551504
--- /dev/null
+++ b/convert_model/convert_model_hf.sh
@@ -0,0 +1,3 @@
+python scripts/convert_hf_checkpoint.py \
+    --checkpoint_dir /data/models/csg-tiny-1B/csg-tiny-1B-480K \
+    --model_name tiny_LLaMA_1b
diff --git a/convert_model/convert_model_lit.sh b/convert_model/convert_model_lit.sh
new file mode 100644
index 0000000..fc2b65b
--- /dev/null
+++ b/convert_model/convert_model_lit.sh
@@ -0,0 +1,4 @@
+python scripts/convert_lit_checkpoint.py \
+    --out_dir /data/train/csg_tiny_train/out_20240217/tinyllama_1b \
+    --checkpoint_name iter-4780000-ckpt.pth \
+    --model_name csg-tiny-1B
diff --git a/convert_model/pytorch-to-safetensor-converter/README.md b/convert_model/pytorch-to-safetensor-converter/README.md
new file mode 100644
index 0000000..2d1ed0d
--- /dev/null
+++ b/convert_model/pytorch-to-safetensor-converter/README.md
@@ -0,0 +1,41 @@
+# Pytorch to Safetensor Converter
+
+---
+
+
+A simple converter that converts PyTorch .bin tensor files (usually named "pytorch_model.bin" or "pytorch_model-xxxx-of-xxxx.bin") to safetensors files. Why?
+
+~~because it's cool!~~
+
+Because the safetensors format decreases the loading time of large LLMs and is currently supported in [oobabooga's text-generation-webui](https://github.com/oobabooga/text-generation-webui). It also supports in-place loading, which effectively reduces the memory required to load an LLM.
+
+Note: Most of the code originates from [Convert to Safetensors - a Hugging Face Space by safetensors](https://huggingface.co/spaces/safetensors/convert), and it cannot handle files that are not named "pytorch_model.bin" or "pytorch_model-xxxx-of-xxxx.bin".
+
+### Limitations:
+
+The program requires **a lot** of memory. Specifically, your free memory should be **at least** twice the size of your largest ".bin" file; otherwise the program will run out of memory and start using swap... and that would be **slow!**
+
+This program **will not** re-shard (i.e. split up) the model; you'll need to do that yourself with other tools.
+
+### Usage:
+
+After installing Python (3.10.x is suggested), clone the repository, ``cd`` into it, and install the dependencies:
+
+```
+git clone https://github.com/Silver267/pytorch-to-safetensor-converter.git
+cd pytorch-to-safetensor-converter
+pip install -r requirements.txt
+```
+
+Copy **all content** of your model's folder into this repository, then run:
+
+```
+python convert_to_safetensor.py
+```
+Follow the instructions in the program. Remember to use the **full path** to the model directory (something like ``E:\models\xxx-fp16`` that contains all the model files). Wait a while, and you're good to go. The program automatically copies all other files to your destination folder, enjoy!
+
+### Precision stuff
+If your original model is fp32, don't forget to edit ``"torch_dtype": "float32",`` to ``"torch_dtype": "float16",`` in ``config.json``.
+#### Note that this operation might (on rare occasions) cause the LLM to output NaN during inference, since it decreases the precision to fp16.
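+As a quick sanity check after converting (a minimal sketch; ``model.safetensors`` is the output name the converter uses for a single-file model, so adjust the path to your destination folder):
+
+```
+from safetensors import safe_open
+
+# Inspect the converted file: each tensor should report torch.float16 after the .half() cast.
+with safe_open("model.safetensors", framework="pt", device="cpu") as f:
+    for name in f.keys():
+        print(name, f.get_tensor(name).dtype)
+```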
+If you're worried about that, simply edit the line ``loaded = {k: v.contiguous().half() for k, v in loaded.items()}`` in ``convert_to_safetensor.py`` into ``loaded = {k: v.contiguous() for k, v in loaded.items()}`` and you'll have a full precision model. diff --git a/convert_model/pytorch-to-safetensor-converter/convert_to_safetensor.py b/convert_model/pytorch-to-safetensor-converter/convert_to_safetensor.py new file mode 100644 index 0000000..e60e789 --- /dev/null +++ b/convert_model/pytorch-to-safetensor-converter/convert_to_safetensor.py @@ -0,0 +1,103 @@ +import json +import os +import shutil +import torch +from collections import defaultdict +from safetensors.torch import load_file, save_file +from tqdm import tqdm + +def shared_pointers(tensors): + ptrs = defaultdict(list) + for k, v in tensors.items(): + ptrs[v.data_ptr()].append(k) + return [names for names in ptrs.values() if len(names) > 1] + +def check_file_size(sf_filename, pt_filename): + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError(f"File size difference exceeds 1% between {sf_filename} and {pt_filename}") + +def convert_file(pt_filename, sf_filename, copy_add_data=True): + source_folder = os.path.dirname(pt_filename) + dest_folder = os.path.dirname(sf_filename) + loaded = torch.load(pt_filename, map_location="cpu") + loaded = loaded.get("state_dict", loaded) + shared = shared_pointers(loaded) + + for shared_weights in shared: + for name in shared_weights[1:]: + loaded.pop(name) + + loaded = {k: v.contiguous().half() for k, v in loaded.items()} + + os.makedirs(dest_folder, exist_ok=True) + save_file(loaded, sf_filename, metadata={"format": "pt"}) + check_file_size(sf_filename, pt_filename) + if copy_add_data: + copy_additional_files(source_folder, dest_folder) + + reloaded = load_file(sf_filename) + for k, v in loaded.items(): + if not torch.equal(v, reloaded[k]): + raise RuntimeError(f"Mismatch in tensors for key {k}.") + +def rename(pt_filename): + return pt_filename.replace("pytorch_model", "model").replace(".bin", ".safetensors") + +def copy_additional_files(source_folder, dest_folder): + for file in os.listdir(source_folder): + file_path = os.path.join(source_folder, file) + if os.path.isfile(file_path) and not (file.endswith('.bin') or file.endswith('.py')): + shutil.copy(file_path, dest_folder) + +def find_index_file(source_folder): + for file in os.listdir(source_folder): + if file.endswith('.bin.index.json'): + return file + return None + +def convert_files(source_folder, dest_folder, delete_old): + index_file = find_index_file(source_folder) + if not index_file: + raise RuntimeError("Index file not found. 
Please ensure the correct folder is specified.") + + index_file = os.path.join(source_folder, index_file) + with open(index_file) as f: + index_data = json.load(f) + + for pt_filename in tqdm(set(index_data["weight_map"].values())): + full_pt_filename = os.path.join(source_folder, pt_filename) + sf_filename = os.path.join(dest_folder, rename(pt_filename)) + convert_file(full_pt_filename, sf_filename, copy_add_data=False) + if delete_old: + os.remove(full_pt_filename) + + copy_additional_files(source_folder, dest_folder) + + index_path = os.path.join(dest_folder, "model.safetensors.index.json") + with open(index_path, "w") as f: + new_map = {k: rename(v) for k, v in index_data["weight_map"].items()} + json.dump({**index_data, "weight_map": new_map}, f, indent=4) + +def main(): + script_dir = os.path.dirname(os.path.realpath(__file__)) + + source_folder = input("Source folder for PyTorch files (leave blank for script's directory): ").strip() or script_dir + dest_folder = input("Destination folder for SafeTensors files (leave blank for default): ").strip() + + if not dest_folder: + model_name = os.path.basename(os.path.normpath(source_folder)) + dest_folder = os.path.join(source_folder, model_name + "_safetensors") + + delete_old = input("Delete old PyTorch files? (Y/N): ").strip().upper() == 'Y' + + if "pytorch_model.bin" in os.listdir(source_folder): + convert_file(os.path.join(source_folder, "pytorch_model.bin"), os.path.join(dest_folder, "model.safetensors"), copy_add_data=True) + if delete_old: + os.remove(os.path.join(source_folder, "pytorch_model.bin")) + else: + convert_files(source_folder, dest_folder, delete_old) + +if __name__ == "__main__": + main() diff --git a/convert_model/pytorch-to-safetensor-converter/requirements.txt b/convert_model/pytorch-to-safetensor-converter/requirements.txt new file mode 100644 index 0000000..6cc0b32 --- /dev/null +++ b/convert_model/pytorch-to-safetensor-converter/requirements.txt @@ -0,0 +1,4 @@ +safetensors +torch +tqdm +numpy diff --git a/eval/eval_checkpoint.py b/eval/eval_checkpoint.py new file mode 100644 index 0000000..ff94ae8 --- /dev/null +++ b/eval/eval_checkpoint.py @@ -0,0 +1,307 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
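+# Checkpoint sweep evaluation: the __main__ block below globs `iter-*` checkpoints from the
+# hard-coded `checkpoint_dir`, loads each one into a "csg-wukong-1B" GPT via Lightning Fabric,
+# runs the configured lm-eval-harness tasks, and writes one result JSON per checkpoint derived
+# from `save_filepath`. Edit the paths in __main__ before running the script directly.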
+ +import json +import sys +from pathlib import Path +from typing import Dict, List, Literal, Optional, Union + +import lightning as L +import torch +from lightning.fabric.plugins import BitsandbytesPrecision +from lm_eval import base, evaluator, tasks +from lm_eval.base import BaseLM +from multiprocessing import Pool + + +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import GPT, Config, Tokenizer +from lit_gpt.generate.base import generate +from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint +import glob +import os +import numpy as np +import time +from multiprocessing import Process, cpu_count + + +class EvalHarnessBase(BaseLM): + # Credits: + # https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py + def __init__(self, fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, batch_size: int): + super().__init__() + self.fabric = fabric + self.model = model + self.tokenizer = tokenizer + self.batch_size_per_gpu = batch_size + # with fabric.init_tensor(): + # model.set_kv_cache(batch_size=batch_size) + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config=None): + kwargs = {el.split("=")[0]: el.split("=")[1] for el in arg_string.split(",")} + return cls(**kwargs, **additional_config) + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_id + + @property + def max_length(self): + # return self.model.max_seq_length + return self.model.config.block_size + + @property + def vocab_size(self): + return self.tokenizer.vocab_size + + @property + def max_gen_toks(self): + return 256 + + @property + def batch_size(self): + return self.batch_size_per_gpu * self.fabric.world_size + + @property + def device(self): + return self.fabric.device + + def tok_encode(self, string: str) -> List[int]: + return self.tokenizer.encode(string, bos=False, eos=False).tolist() + + def tok_decode(self, tokens: List[int]) -> str: + t = torch.tensor(tokens) + return self.tokenizer.decode(t) + + @torch.inference_mode() + def _model_call(self, inps): + return self.model(inps) + + @torch.inference_mode() + def _model_generate(self, context, max_length, eos_token_id) -> torch.Tensor: + # this only supports batch size 1 + assert context.shape[0] == 1 + out = generate(self.model, context[0], max_length, eos_id=eos_token_id) + for block in self.model.transformer.h: + block.attn.kv_cache.reset_parameters() + return out.unsqueeze(0) + + @torch.inference_mode() + def run_eval_single( + self, eval_tasks: List[str], num_fewshot: int, limit: Optional[int], bootstrap_iters: int, no_cache: bool + ) -> Dict: + + print("global_rank:", self.fabric.global_rank) + # Returns a list containing all values of the task registry that + # match at least one of the patterns + import fnmatch + + def pattern_match(patterns, source_list): + task_names = set() + for pattern in patterns: + for matching in fnmatch.filter(source_list, pattern): + task_names.add(matching) + return list(task_names) + + eval_tasks = pattern_match(eval_tasks, tasks.ALL_TASKS) + print(f"Found tasks: {eval_tasks}") + + # **HACK INCOMING**: + # first get task dict on local main rank + # the tasks are downloaded *as they are initialized*, and the downloads don't like multithreading. + # so we download them once on the local main rank, wait, and then initialize them on all other ranks, which *should* load from the cache. 
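+        # In this single-process variant the rank-0 download + barrier dance is disabled (see the
+        # commented-out lines below) and the task dict is fetched unconditionally; `run_eval`
+        # further down keeps the full multi-rank pattern:
+        #   if self.fabric.local_rank == 0:
+        #       tasks.get_task_dict(eval_tasks)   # download once on the local main rank
+        #   self.fabric.barrier()                 # wait for the download to finish
+        #   tasks.get_task_dict(eval_tasks)       # every rank then loads from the cache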
+ # if self.fabric.local_rank == 0: + tasks.get_task_dict(eval_tasks) + # torch barrier + # self.fabric.barrier() + # tasks.get_task_dict(eval_tasks) + + lm = self + if not no_cache: + lm = base.CachingLM(lm, "lm_cache/litgpt.db") + results = evaluator.evaluate( + lm=lm, + task_dict=tasks.get_task_dict(eval_tasks), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + ) + results["config"] = dict( + model=self.model.config.name, + batch_size=self.batch_size, + device=str(self.device), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + no_cache=no_cache, + ) + return results + + @torch.inference_mode() + def run_eval( + self, eval_tasks: List[str], num_fewshot: int, limit: Optional[int], bootstrap_iters: int, no_cache: bool + ) -> Dict: + # Returns a list containing all values of the task registry that + # match at least one of the patterns + import fnmatch + + def pattern_match(patterns, source_list): + task_names = set() + for pattern in patterns: + for matching in fnmatch.filter(source_list, pattern): + task_names.add(matching) + return list(task_names) + + eval_tasks = pattern_match(eval_tasks, tasks.ALL_TASKS) + print(f"Found tasks: {eval_tasks}") + + # **HACK INCOMING**: + # first get task dict on local main rank + # the tasks are downloaded *as they are initialized*, and the downloads don't like multithreading. + # so we download them once on the local main rank, wait, and then initialize them on all other ranks, which *should* load from the cache. + if self.fabric.local_rank == 0: + tasks.get_task_dict(eval_tasks) + # torch barrier + self.fabric.barrier() + tasks.get_task_dict(eval_tasks) + + lm = self + if not no_cache: + lm = base.CachingLM(lm, "lm_cache/litgpt.db") + results = evaluator.evaluate( + lm=lm, + task_dict=tasks.get_task_dict(eval_tasks), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + ) + results["config"] = dict( + model=self.model.config.name, + batch_size=self.batch_size, + device=str(self.device), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + no_cache=no_cache, + ) + return results + + +@torch.inference_mode() +def run_eval_harness( + checkpoint_dir: Path, + tokenizer_dir: Path, + precision: Optional[str] = None, + quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, + eval_tasks: List[str] = ["arc_challenge", "piqa", "hellaswag", "hendrycksTest-*"], + save_filepath: Optional[Path] = None, + num_fewshot: int = 0, + limit: Optional[int] = None, + bootstrap_iters: int = 100000, + no_cache: bool = True, + devices: Union[int, list] = 1, +): + if precision is None: + precision = get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError("Quantization and mixed precision is not supported.") + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + fabric = L.Fabric(devices=devices, precision=precision, plugins=plugins) + + # check_valid_checkpoint_dir(checkpoint_dir) + tokenizer = Tokenizer(tokenizer_dir) + + model_name = "csg-wukong-1B" + config = Config.from_name(model_name) + + # config = Config.from_name(checkpoint_dir / "model_config.yaml") + # config = Config.from_file(checkpoint_dir / "model_config.yaml") + + # checkpoint_path = checkpoint_dir / "lit_model.pth" + # for i, 
checkpoint_path in enumerate(checkpoint_dir): + checkpoint_path = checkpoint_dir + print("checkpoint_path", checkpoint_path) + print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) + with fabric.init_module(empty_init=False): + model = GPT(config) + + model.eval() + + # model = fabric.setup(model) + model = fabric.setup_module(model) + # model = fabric.load(resume, state) + + # load_checkpoint(fabric, model, checkpoint_path) + fabric.load(checkpoint_path, {"model": model}) + + eval_harness = EvalHarnessBase(fabric, model, tokenizer, 128) + + results = eval_harness.run_eval(eval_tasks, num_fewshot, limit, bootstrap_iters, no_cache) + print(results) + if save_filepath is None: + # print(results) + pass + else: + print(f"Saving results to {str(save_filepath)!r}") + save_filepath.parent.mkdir(parents=True, exist_ok=True) + data = json.dumps(results) + with open(save_filepath, "w") as fw: + fw.write(data) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + # CLI(run_eval_harness) + + checkpoint_dir = Path("/data/train/csg-tiny-1B/out_20240217/tinyllama_1b") + tokenizer_dir = Path("/data/datasets/tokenizers/Llama2Tokenizer") + precision = "bf16-true" + eval_tasks = ["hellaswag", "openbookqa", "winogrande", "arc_challenge", "arc_easy", "boolq", "piqa"] + # eval_tasks = ["hellaswag"] + + save_filepath = "/data/train/csg-tiny-1B/out_20240217/result" + filenames = sorted(glob.glob(str(checkpoint_dir / f"{'iter-'}*"))) + # filenames = ["/data/train/csg-tiny-1B/out_20240217/tinyllama_1b/iter-920000-ckpt.pth"] + # filenames = sorted(glob.glob(checkpoint_dir + "/*", recursive=True), reverse=True) + # only retrain subsets that follow the prefix in filenames_subset + print("filenames:", filenames) + + # num_processes = 8 + # chunked_filenames = np.array_split(filenames, num_processes) + batch_size = 128 + processes = [] + start_time = time.time() + + # async def + for i, filename in enumerate(filenames): + # for j in range(i, i+8): + # subset = filenames[i] + # print("subset:", subset) + print("filename:", filename) + if 1: + save_filepath_now = save_filepath + filename.split("/data/train/csg-tiny-1B/out_20240217/tinyllama_1b")[-1].split(".pth")[0] + ".json" + print("save_filepath_now", save_filepath_now) + # os.environ["CUDA_VISIBLE_DEVICES"] = "1" + run_eval_harness(checkpoint_dir=filename, tokenizer_dir=tokenizer_dir, precision=precision, + eval_tasks=eval_tasks, save_filepath=Path(save_filepath_now), devices=[2]) + else: + continue + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + # print("checkpoint_dir:", checkpoint_dir) + # # output_dir = Path("/data_share/train/tinyllama") + + # run_eval_harness(checkpoint_dir=checkpoint_dir, tokenizer_dir=tokenizer_dir, + # precision=precision, eval_tasks=eval_tasks, + # save_filepath=save_filepath, batch_size=128) diff --git a/eval/lm_eval_harness.py b/eval/lm_eval_harness.py new file mode 100644 index 0000000..90c9672 --- /dev/null +++ b/eval/lm_eval_harness.py @@ -0,0 +1,264 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
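+# Single-checkpoint evaluation: loads `lit_model.pth` from a converted checkpoint directory,
+# runs the lm-eval-harness tasks configured in __main__, and writes the results to `eval.json`
+# inside that same directory.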
+ +import json +import sys +from pathlib import Path +from typing import Dict, List, Literal, Optional + +import lightning as L +import torch +from lightning.fabric.plugins import BitsandbytesPrecision +from lm_eval import base, evaluator, tasks +from lm_eval.base import BaseLM + +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import GPT, Config, Tokenizer +from lit_gpt.generate.base import generate +from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint + + +class EvalHarnessBase(BaseLM): + # Credits: + # https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py + def __init__(self, fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, batch_size: int): + super().__init__() + self.fabric = fabric + self.model = model + self.tokenizer = tokenizer + self.batch_size_per_gpu = batch_size + # with fabric.init_tensor(): + # model.set_kv_cache(batch_size=batch_size) + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config=None): + kwargs = {el.split("=")[0]: el.split("=")[1] for el in arg_string.split(",")} + return cls(**kwargs, **additional_config) + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_id + + @property + def max_length(self): + # return self.model.max_seq_length + return self.model.config.block_size + + @property + def vocab_size(self): + return self.tokenizer.vocab_size + + @property + def max_gen_toks(self): + return 256 + + @property + def batch_size(self): + return self.batch_size_per_gpu * self.fabric.world_size + + @property + def device(self): + return self.fabric.device + + def tok_encode(self, string: str) -> List[int]: + return self.tokenizer.encode(string, bos=False, eos=False).tolist() + + def tok_decode(self, tokens: List[int]) -> str: + t = torch.tensor(tokens) + return self.tokenizer.decode(t) + + @torch.inference_mode() + def _model_call(self, inps): + return self.model(inps) + + @torch.inference_mode() + def _model_generate(self, context, max_length, eos_token_id) -> torch.Tensor: + # this only supports batch size 1 + assert context.shape[0] == 1 + out = generate(self.model, context[0], max_length, eos_id=eos_token_id) + for block in self.model.transformer.h: + block.attn.kv_cache.reset_parameters() + return out.unsqueeze(0) + + @torch.inference_mode() + def run_eval_single( + self, eval_tasks: List[str], num_fewshot: int, limit: Optional[int], bootstrap_iters: int, no_cache: bool + ) -> Dict: + + print("global_rank:", self.fabric.global_rank) + # Returns a list containing all values of the task registry that + # match at least one of the patterns + import fnmatch + + def pattern_match(patterns, source_list): + task_names = set() + for pattern in patterns: + for matching in fnmatch.filter(source_list, pattern): + task_names.add(matching) + return list(task_names) + + eval_tasks = pattern_match(eval_tasks, tasks.ALL_TASKS) + print(f"Found tasks: {eval_tasks}") + + # **HACK INCOMING**: + # first get task dict on local main rank + # the tasks are downloaded *as they are initialized*, and the downloads don't like multithreading. + # so we download them once on the local main rank, wait, and then initialize them on all other ranks, which *should* load from the cache. 
+ # if self.fabric.local_rank == 0: + tasks.get_task_dict(eval_tasks) + # torch barrier + # self.fabric.barrier() + # tasks.get_task_dict(eval_tasks) + + lm = self + if not no_cache: + lm = base.CachingLM(lm, "lm_cache/litgpt.db") + results = evaluator.evaluate( + lm=lm, + task_dict=tasks.get_task_dict(eval_tasks), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + ) + results["config"] = dict( + model=self.model.config.name, + batch_size=self.batch_size, + device=str(self.device), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + no_cache=no_cache, + ) + return results + + @torch.inference_mode() + def run_eval( + self, eval_tasks: List[str], num_fewshot: int, limit: Optional[int], bootstrap_iters: int, no_cache: bool + ) -> Dict: + # Returns a list containing all values of the task registry that + # match at least one of the patterns + import fnmatch + + def pattern_match(patterns, source_list): + task_names = set() + for pattern in patterns: + for matching in fnmatch.filter(source_list, pattern): + task_names.add(matching) + return list(task_names) + + eval_tasks = pattern_match(eval_tasks, tasks.ALL_TASKS) + print(f"Found tasks: {eval_tasks}") + + # **HACK INCOMING**: + # first get task dict on local main rank + # the tasks are downloaded *as they are initialized*, and the downloads don't like multithreading. + # so we download them once on the local main rank, wait, and then initialize them on all other ranks, which *should* load from the cache. + if self.fabric.local_rank == 0: + tasks.get_task_dict(eval_tasks) + # torch barrier + self.fabric.barrier() + tasks.get_task_dict(eval_tasks) + + lm = self + if not no_cache: + lm = base.CachingLM(lm, "lm_cache/litgpt.db") + results = evaluator.evaluate( + lm=lm, + task_dict=tasks.get_task_dict(eval_tasks), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + write_out=True, + output_base_path="/data/project/TinyLlama/eval/out/" + ) + results["config"] = dict( + model=self.model.config.name, + batch_size=self.batch_size, + device=str(self.device), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + no_cache=no_cache, + ) + return results + + +@torch.inference_mode() +def run_eval_harness( + checkpoint_dir: Path, + precision: Optional[str] = None, + quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, + eval_tasks: List[str] = ["arc_challenge", "piqa", "hellaswag", "hendrycksTest-*"], + save_filepath: Optional[Path] = None, + num_fewshot: int = 0, + limit: Optional[int] = None, + bootstrap_iters: int = 100000, + no_cache: bool = True, + devices: int = 1, +): + if precision is None: + precision = get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError("Quantization and mixed precision is not supported.") + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=devices, precision=precision, plugins=plugins) + + check_valid_checkpoint_dir(checkpoint_dir) + tokenizer = Tokenizer(checkpoint_dir) + + model_name = "csg-wukong-1B" + config = Config.from_name(model_name) + + # config = Config.from_name(checkpoint_dir / "model_config.yaml") + # config = Config.from_file(checkpoint_dir / "model_config.yaml") + + checkpoint_path = 
checkpoint_dir / "lit_model.pth" + print("checkpoint_path", checkpoint_path) + print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) + with fabric.init_module(empty_init=False): + model = GPT(config) + + model.eval() + + # model = fabric.setup(model) + model = fabric.setup_module(model) + # model = fabric.load(resume, state) + + # load_checkpoint(fabric, model, checkpoint_path) + fabric.load(checkpoint_path, {"model": model}) + + eval_harness = EvalHarnessBase(fabric, model, tokenizer, 1) + + results = eval_harness.run_eval(eval_tasks, num_fewshot, limit, bootstrap_iters, no_cache) + print(results) + if save_filepath is None: + # print(results) + pass + else: + print(f"Saving results to {str(save_filepath)!r}") + save_filepath.parent.mkdir(parents=True, exist_ok=True) + data = json.dumps(results) + with open(save_filepath, "w") as fw: + fw.write(data) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + # CLI(run_eval_harness) + checkpoint_dir = Path("/data/models/csg-tiny-1B/csg-tiny-1B-955K") + precision = "bf16-true" + # eval_tasks = ["hellaswag", "openbookqa", "winogrande", "arc_challenge", "arc_easy", "boolq", "piqa"] + eval_tasks = ["piqa"] + + save_filepath = checkpoint_dir / "eval.json" + run_eval_harness(checkpoint_dir=checkpoint_dir, precision=precision, + eval_tasks=eval_tasks, save_filepath=save_filepath) diff --git a/lit_gpt/__init__.py b/lit_gpt/__init__.py index a15c7f4..fde55a0 100644 --- a/lit_gpt/__init__.py +++ b/lit_gpt/__init__.py @@ -1,5 +1,9 @@ +import logging +import re + from lit_gpt.model import GPT from lit_gpt.config import Config +from lit_gpt.prompts import PromptStyle from lit_gpt.tokenizer import Tokenizer from lit_gpt.fused_cross_entropy import FusedCrossEntropyLoss from lightning_utilities.core.imports import RequirementCache @@ -16,5 +20,10 @@ f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}" ) +# Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632 +pattern = re.compile(".*Profiler function .* will be ignored") +logging.getLogger("torch._dynamo.variables.torch").addFilter(lambda record: not pattern.search(record.getMessage())) + +# __all__ = ["GPT", "Config", "Tokenizer"] +__all__ = ["GPT", "Config", "Tokenizer", "PromptStyle"] -__all__ = ["GPT", "Config", "Tokenizer"] diff --git a/lit_gpt/config.py b/lit_gpt/config.py index 1345eb6..10a26dc 100644 --- a/lit_gpt/config.py +++ b/lit_gpt/config.py @@ -113,6 +113,11 @@ def norm_class(self) -> Type: # EleutherAI Pythia #################### pythia = [ + + #https://huggingface.co/EleutherAI/pythia-14m/blob/main/config.json + dict(org="EleutherAI", name="pythia-14m", block_size=512, n_layer=6, n_embd=128, n_head=4, padding_multiple=128), + #https://huggingface.co/EleutherAI/pythia-31m/blob/main/config.json + dict(org="EleutherAI", name="pythia-31m", block_size=1024, n_layer=6, n_embd=256, n_head=8, padding_multiple=128), # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json dict(org="EleutherAI", name="pythia-70m", block_size=2048, n_layer=6, n_embd=512, n_head=8, padding_multiple=128), # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json @@ -303,6 +308,260 @@ def norm_class(self) -> Type: ] configs.extend(tiny_LLaMA) +############################# +# OpenCSG Algo Team +############################# +csg = [ + # https://huggingface.co/opencsg + dict( + org="OpenCSG", + name="csg-wukong-1B", + block_size=2048, + vocab_size=32000, + 
padding_multiple=64, + n_layer=22, + n_head=32, + n_embd=2048, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, #Llama 2 use 1e-5. Llama 1 use 1e-6 + _mlp_class="LLaMAMLP", + intermediate_size=5632, + n_query_groups=4, + ), + + dict( + org="OpenCSG", + name="csg-wukong-1B-deepseek", + block_size=2048, + vocab_size=102400, + padding_multiple=64, + n_layer=22, + n_head=32, + n_embd=2048, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, #Llama 2 use 1e-5. Llama 1 use 1e-6 + _mlp_class="LLaMAMLP", + intermediate_size=5632, + n_query_groups=4, + ), + + dict( + org="OpenCSG", + name="csg-wukong-1B-qwew", + block_size=2048, + vocab_size=151936, + padding_multiple=64, + n_layer=22, + n_head=32, + n_embd=2048, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, #Llama 2 use 1e-5. Llama 1 use 1e-6 + _mlp_class="LLaMAMLP", + intermediate_size=5632, + n_query_groups=4, + ), + + dict( + org="OpenCSG", + name="csg-wukong-1B-yi", + block_size=2048, + vocab_size=64000, + padding_multiple=64, + n_layer=22, + n_head=32, + n_embd=2048, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, #Llama 2 use 1e-5. Llama 1 use 1e-6 + _mlp_class="LLaMAMLP", + intermediate_size=5632, + n_query_groups=4, + ), + + dict( + org="OpenCSG", + name="csg_tiny_120M", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=12, + n_head=12, + n_embd=768, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=2048, + n_query_groups=1, + ), + + dict( + org="OpenCSG", + name="csg_tiny_120M-v2", + block_size=2048, + vocab_size=102400, + padding_multiple=64, + n_layer=12, + n_head=12, + n_embd=768, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=2048, + n_query_groups=1, + ), + + dict( + org="OpenCSG", + name="csg_tiny_120M-v3", + block_size=2048, + vocab_size=151936, + padding_multiple=64, + n_layer=12, + n_head=12, + n_embd=768, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=2048, + n_query_groups=1, + ), + dict( + org="OpenCSG", + name="csg_tiny_120M-v4", + block_size=2048, + vocab_size=64000, + padding_multiple=64, + n_layer=12, + n_head=12, + n_embd=768, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=2048, + n_query_groups=1, + ), + + dict( + org="OpenCSG", + name="csg_tiny_10M_llama", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=6, + n_head=4, + n_embd=128, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=768, + n_query_groups=1, + ), + + + dict( + org="OpenCSG", + name="csg_tiny_10M_yi", + block_size=2048, + vocab_size=64000, + padding_multiple=64, + n_layer=6, + n_head=4, + n_embd=72, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=576, + n_query_groups=1, + ), + + dict( + org="OpenCSG", + 
name="csg_tiny_30M_yi", + block_size=2048, + vocab_size=64000, + padding_multiple=64, + n_layer=6, + n_head=4, + n_embd=200, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=1048, + n_query_groups=1, + ), + + dict( + org="OpenCSG", + name="csg_tiny_70M_yi", + block_size=2048, + vocab_size=64000, + padding_multiple=64, + n_layer=8, + n_head=6, + n_embd=512, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=1536, + n_query_groups=1, + ), + + dict( + org="OpenCSG", + name="csg_tiny_88M_yi", + block_size=2048, + vocab_size=64000, + padding_multiple=64, + n_layer=6, + n_head=4, + n_embd=512, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="FusedRMSNorm", + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=2048, + n_query_groups=1, + ) + +] +configs.extend(csg) ############################# # OpenLM Research Open LLaMA diff --git a/lit_gpt/generate/__init__.py b/lit_gpt/generate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lit_gpt/generate/adapter.py b/lit_gpt/generate/adapter.py new file mode 100644 index 0000000..e6cf02d --- /dev/null +++ b/lit_gpt/generate/adapter.py @@ -0,0 +1,117 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import sys +import time +from pathlib import Path +from typing import Literal, Optional + +import lightning as L +import torch +from lightning.fabric.plugins import BitsandbytesPrecision + +from lit_gpt import PromptStyle, Tokenizer +from lit_gpt.adapter import GPT, Config +from lit_gpt.generate.base import generate +from lit_gpt.prompts import has_prompt_style, load_prompt_style +from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, lazy_load + + +def main( + prompt: str = "What food do llamas eat?", + input: str = "", + adapter_path: Path = Path("out/finetune/adapter/final/lit_model.pth.adapter"), + checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, + max_new_tokens: int = 100, + top_k: Optional[int] = 200, + temperature: float = 0.8, + precision: Optional[str] = None, +) -> None: + """Generates a response based on a given instruction and an optional input. This script will only work with + checkpoints from the instruction-tuned adapter model. See ``litgpt.finetune.adapter``. + + Args: + prompt: The prompt/instruction (Alpaca style). + input: Optional input (Alpaca style). + adapter_path: Path to the checkpoint with trained adapter weights, which are the output of + ``litgpt.finetune.adapter``. + checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + - bnb.int8: 8-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + precision: Indicates the Fabric precision setting to use. 
+ """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError("Quantization and mixed precision is not supported.") + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) + fabric.launch() + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_file(checkpoint_dir / "model_config.yaml") + + checkpoint_path = checkpoint_dir / "lit_model.pth" + + tokenizer = Tokenizer(checkpoint_dir) + prompt_style = ( + load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) + ) + + prompt = prompt_style.apply(prompt, input=input) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) + t0 = time.perf_counter() + with fabric.init_module(empty_init=True): + model = GPT(config) + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = max_returned_tokens + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + t0 = time.perf_counter() + checkpoint = lazy_load(checkpoint_path) + adapter_checkpoint = lazy_load(adapter_path) + checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint)) + model.load_state_dict(checkpoint) + fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + + model = fabric.setup(model) + + L.seed_everything(1234) + t0 = time.perf_counter() + y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id) + t = time.perf_counter() - t0 + + output = tokenizer.decode(y) + output = output.split("### Response:")[1].strip() + fabric.print(output) + + tokens_generated = y.size(0) - prompt_length + fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(main) diff --git a/lit_gpt/generate/adapter_v2.py b/lit_gpt/generate/adapter_v2.py new file mode 100644 index 0000000..e5bacb2 --- /dev/null +++ b/lit_gpt/generate/adapter_v2.py @@ -0,0 +1,117 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
+ +import sys +import time +from pathlib import Path +from typing import Literal, Optional + +import lightning as L +import torch +from lightning.fabric.plugins import BitsandbytesPrecision + +from lit_gpt import PromptStyle, Tokenizer +from lit_gpt.adapter_v2 import GPT, Config +from lit_gpt.generate.base import generate +from lit_gpt.prompts import has_prompt_style, load_prompt_style +from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, lazy_load + + +def main( + prompt: str = "What food do llamas eat?", + input: str = "", + adapter_path: Path = Path("out/finetune/adapter-v2/final/lit_model.pth.adapter_v2"), + checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, + max_new_tokens: int = 100, + top_k: Optional[int] = 200, + temperature: float = 0.8, + precision: Optional[str] = None, +) -> None: + """Generates a response based on a given instruction and an optional input. This script will only work with + checkpoints from the instruction-tuned adapter v2 model. See ``litgpt.finetune.adapter_v2``. + + Args: + prompt: The prompt/instruction (Alpaca style). + input: Optional input (Alpaca style). + adapter_path: Path to the checkpoint with trained adapter weights, which are the output of + ``litgpt.finetune.adapter_v2``. + checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + - bnb.int8: 8-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + precision: Indicates the Fabric precision setting to use. 
+ """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError("Quantization and mixed precision is not supported.") + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) + fabric.launch() + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_file(checkpoint_dir / "model_config.yaml") + + checkpoint_path = checkpoint_dir / "lit_model.pth" + + tokenizer = Tokenizer(checkpoint_dir) + prompt_style = ( + load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) + ) + + prompt = prompt_style.apply(prompt, input=input) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) + t0 = time.perf_counter() + with fabric.init_module(empty_init=True): + model = GPT(config) + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = max_returned_tokens + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + t0 = time.perf_counter() + checkpoint = lazy_load(checkpoint_path) + adapter_checkpoint = lazy_load(adapter_path) + checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint)) + model.load_state_dict(checkpoint) + fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + + model = fabric.setup(model) + + L.seed_everything(1234) + t0 = time.perf_counter() + y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id) + t = time.perf_counter() - t0 + + output = tokenizer.decode(y) + output = output.split("### Response:")[1].strip() + fabric.print(output) + + tokens_generated = y.size(0) - prompt_length + fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(main) diff --git a/lit_gpt/generate/base.py b/lit_gpt/generate/base.py new file mode 100644 index 0000000..b67e444 --- /dev/null +++ b/lit_gpt/generate/base.py @@ -0,0 +1,196 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
+ +import sys +import time +from pathlib import Path +from typing import Any, Literal, Optional + +import lightning as L +import torch +import torch._dynamo.config +import torch._inductor.config +from lightning.fabric.plugins import BitsandbytesPrecision + +from lit_gpt import GPT, Config, PromptStyle, Tokenizer +from lit_gpt.prompts import has_prompt_style, load_prompt_style +from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint + + +def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor: + if torch._dynamo.is_compiling(): + # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly + distribution = torch.empty_like(probs).exponential_(1) + return torch.argmax(probs / distribution, dim=-1, keepdim=True) + return torch.multinomial(probs, num_samples=1) + + +def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor: + logits = logits[0, -1] + # optionally crop the logits to only the top k options + if top_k is not None: + v, i = torch.topk(logits, min(top_k, logits.size(-1))) + # do not use `torch.where` as in nanogpt because it will repeat top-k collisions + logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v) + # optionally scale the logits and sample from a probability distribution + if temperature > 0.0: + probs = torch.nn.functional.softmax(logits / temperature, dim=-1) + return multinomial_num_samples_1(probs) + return torch.argmax(logits, dim=-1, keepdim=True) + + +def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor: + logits = model(x, input_pos) + next = sample(logits, **kwargs) + return next.to(dtype=x.dtype) + + +@torch.inference_mode() +def generate( + model: GPT, + prompt: torch.Tensor, + max_returned_tokens: int, + *, + temperature: float = 1.0, + top_k: Optional[int] = None, + eos_id: Optional[int] = None, +) -> torch.Tensor: + """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + + The implementation of this function is modified from A. Karpathy's nanoGPT. + + Args: + model: The model to use. + prompt: Tensor of shape (T) with indices of the prompt sequence. + max_returned_tokens: The maximum number of tokens to return (given plus generated). + temperature: Scales the predicted logits by 1 / temperature. + top_k: If specified, only sample among the tokens with the k highest probabilities. + eos_id: If specified, stop generating any more token once the token is triggered. + """ + T = prompt.size(0) + assert max_returned_tokens > T + if model.max_seq_length < max_returned_tokens - 1: + # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a + # data dependency on the `input_pos` tensor and impact model compilation. 
Since this setting is uncommon, we do + # not support it to avoid negatively impacting the overall speed + raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}") + + device = prompt.device + print("generate device.......................:", device) + tokens = [prompt] + input_pos = torch.tensor([T], device=device) + token = next_token( + model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k + ).clone() + tokens.append(token) + for _ in range(2, max_returned_tokens - T + 1): + token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k).clone() + tokens.append(token) + if token == eos_id: + break + input_pos = input_pos.add_(1) + return torch.cat(tokens) + + +@torch.inference_mode() +def main( + prompt: str = "What food do llamas eat?", + *, + num_samples: int = 1, + max_new_tokens: int = 50, + top_k: Optional[int] = 200, + temperature: float = 0.8, + checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, + precision: Optional[str] = None, + compile: bool = False, +) -> None: + """Generates text samples based on a pre-trained model and tokenizer. + + Args: + prompt: The prompt string to use for generating the samples. + num_samples: The number of text samples to generate. + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + checkpoint_dir: The checkpoint directory to load. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + - bnb.int8: 8-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md + precision: Indicates the Fabric precision setting to use. + compile: Whether to compile the model. 
+ """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError("Quantization and mixed precision is not supported.") + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_file(checkpoint_dir / "model_config.yaml") + + checkpoint_path = checkpoint_dir / "lit_model.pth" + + tokenizer = Tokenizer(checkpoint_dir) + prompt_style = ( + load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) + ) + + prompt = prompt_style.apply(prompt) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) + t0 = time.perf_counter() + with fabric.init_module(empty_init=True): + model = GPT(config) + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = max_returned_tokens + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + if compile: + torch._dynamo.config.automatic_dynamic_shapes = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.coordinate_descent_tuning = True + global next_token + next_token = torch.compile(next_token, mode="reduce-overhead") + + model = fabric.setup_module(model) + + t0 = time.perf_counter() + load_checkpoint(fabric, model, checkpoint_path) + fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + + L.seed_everything(1234) + for i in range(num_samples): + t0 = time.perf_counter() + y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id) + t = time.perf_counter() - t0 + for block in model.transformer.h: + block.attn.kv_cache.reset_parameters() + fabric.print(tokenizer.decode(y)) + tokens_generated = y.size(0) - prompt_length + fabric.print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr + ) + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(main) diff --git a/lit_gpt/generate/full.py b/lit_gpt/generate/full.py new file mode 100644 index 0000000..e602e6e --- /dev/null +++ b/lit_gpt/generate/full.py @@ -0,0 +1,113 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
+ +import sys +import time +from pathlib import Path +from typing import Literal, Optional + +import lightning as L +import torch +from lightning.fabric.plugins import BitsandbytesPrecision + +from litgpt import GPT, Config, PromptStyle, Tokenizer +from litgpt.generate.base import generate +from litgpt.prompts import has_prompt_style, load_prompt_style +from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint + + +def main( + prompt: str = "What food do llamas eat?", + input: str = "", + finetuned_path: Path = Path("out/full/alpaca/lit_model_finetuned.pth"), + checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, + max_new_tokens: int = 100, + top_k: Optional[int] = 200, + temperature: float = 0.8, + precision: Optional[str] = None, +) -> None: + """Generates a response based on a given instruction and an optional input. This script will only work with + checkpoints from the instruction-tuned GPT model. See ``litgpt.finetune.full``. + + Args: + prompt: The prompt/instruction (Alpaca style). + input: Optional input (Alpaca style). + finetuned_path: Path to the checkpoint with trained weights, which are the output of + ``litgpt.finetune.full``. + checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + - bnb.int8: 8-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + precision: Indicates the Fabric precision setting to use. 
+ """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError("Quantization and mixed precision is not supported.") + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) + fabric.launch() + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_file(checkpoint_dir / "model_config.yaml") + + checkpoint_path = finetuned_path + + tokenizer = Tokenizer(checkpoint_dir) + prompt_style = ( + load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) + ) + + prompt = prompt_style.apply(prompt, input=input) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) + t0 = time.perf_counter() + with fabric.init_module(empty_init=True): + model = GPT(config) + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = max_returned_tokens + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + model = fabric.setup(model) + + t0 = time.perf_counter() + load_checkpoint(fabric, model, checkpoint_path) + fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + + L.seed_everything(1234) + t0 = time.perf_counter() + y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id) + t = time.perf_counter() - t0 + + output = tokenizer.decode(y) + output = output.split("### Response:")[1].strip() + fabric.print(output) + + tokens_generated = y.size(0) - prompt_length + fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(main) diff --git a/lit_gpt/generate/sequentially.py b/lit_gpt/generate/sequentially.py new file mode 100644 index 0000000..cce1d8d --- /dev/null +++ b/lit_gpt/generate/sequentially.py @@ -0,0 +1,226 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
+ +import itertools +import logging +import re +import sys +import time +from collections import OrderedDict +from functools import partial +from pathlib import Path +from typing import Literal, Optional + +import lightning as L +import torch +from lightning.fabric.accelerators import CUDAAccelerator +from lightning.fabric.plugins import BitsandbytesPrecision +from lightning.fabric.utilities.init import _materialize_meta_tensors +from typing_extensions import Type + +import litgpt.generate.base as generate_base +from litgpt import GPT, Config, Tokenizer +from litgpt.model import Block, build_mask_cache +from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision + + +@torch.inference_mode() +def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int): + if model.config.n_layer % devices: + # TODO: support smarter partitioning schemes + raise NotImplementedError( + f"Only balanced partitioning is implemented: n_layer={model.config.n_layer}, devices {devices}" + ) + layers_per_rank = model.config.n_layer // devices + # dictates where each block should be instantiated + mapping = layer_to_device(model, chunk_on=Block, chunk_size=layers_per_rank) + + # materialize each block on the appropriate device + for path, target_index in mapping.items(): + submodule = model.get_submodule(path) + target_device = torch.device(root.type, target_index) + print(f"Moving {path!r} to {target_device}", file=sys.stderr) + # submodules loaded by the checkpoint will be on CPU (if no quantization). move them + replace_device(submodule, replace=torch.device("cpu"), by=target_device) + # in case the checkpoint was partial, materialize leftover metas + _materialize_meta_tensors(submodule, target_device) + # and build the kv cache + submodule.attn.kv_cache = submodule.attn.build_kv_cache(1, max_seq_length, model.cos.size(-1), target_device) + # rebuild odd ends + with root: + model.max_seq_length = max_seq_length + # the rope cache which is on meta device + model.cos, model.sin = model.rope_cache() + # the mask cache which cannot be created with `set_kv_cache` because that will set it for all layers + model.mask_cache = build_mask_cache(max_seq_length) + # and everything that is not a block in the root + _materialize_meta_tensors(model, root) + replace_device(model, replace=torch.device("cpu"), by=root) + + if devices > 1: + # install hooks to move layer inputs/output between devices + for layer_num, (path, target_index) in enumerate(mapping.items()): + submodule = model.get_submodule(path) + if layer_num >= layers_per_rank: + # we need to move the block input on the boundaries between devices + # and also on every non-root device because the RoPE and mask cache is shared + # TODO: the second case could be optimized and then we would only need this hook for + # `layer_num in [layers_per_rank * i - 1 for i in range(1, devices + 1)]` + target_device = torch.device(root.type, target_index) + submodule.register_forward_pre_hook(partial(move_block_input, target_device)) + if layer_num == model.config.n_layer - 1: + submodule.register_forward_hook(partial(move_block_output, root)) + + return model + + +def layer_to_device( + module: torch.nn.Module, chunk_on: Type[torch.nn.Module], chunk_size: int +) -> "OrderedDict[str, int]": + """Create a mapping from layer (block) to device.""" + # this assumes that the definition order is the same as the execution order + hits = [name for name, submodule in module.named_modules() if isinstance(submodule, chunk_on)] + return 
OrderedDict((name, i // chunk_size) for i, name in enumerate(hits)) + + +def move_block_input(device: torch.device, module: torch.nn.Module, ins): + """``forward_pre_hook`` to move a Block's input before forward.""" + # during inference, none of the inputs are None: x, cos, sin, mask, input_pos + return tuple(t.to(device) for t in ins) + + +def move_block_output(device: torch.device, module: torch.nn.Module, ins, outs) -> torch.Tensor: + """``forward_hook`` to move a Block's output after forward.""" + return outs.to(device) + + +def replace_device(module: torch.nn.Module, replace: torch.device, by: torch.device) -> torch.nn.Module: + for name, submodule in module.named_modules(): + tensors = dict( + itertools.chain(submodule.named_parameters(recurse=False), submodule.named_buffers(recurse=False)) + ) + if not tensors: + continue + devices = {t.device for t in tensors.values()} + if len(devices) != 1: + # since this is using `submodule.to`, different devices in the same submodule is a problem + path_to_device = {f"{name}.{p}": t.device for p, t in tensors.items()} + raise ValueError(f"Found multiple devices: {path_to_device}") + if devices.pop() == replace: + submodule.to(by) + return module + + +@torch.inference_mode() +def main( + prompt: str = "What food do llamas eat?", + *, + num_samples: int = 1, + max_new_tokens: int = 50, + top_k: Optional[int] = 200, + temperature: float = 0.8, + checkpoint_dir: Path = Path("checkpoints/mistralai/Mistral-7B-Instruct-v0.1"), + quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"]] = None, + precision: Optional[str] = None, + compile: bool = False, +) -> None: + """Generates text samples based on a pre-trained model and tokenizer. + + Args: + prompt: The prompt string to use for generating the samples. + num_samples: The number of text samples to generate. + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + checkpoint_dir: The checkpoint directory to load. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md + precision: Indicates the Fabric precision setting to use. + compile: Whether to compile the model. 
+ """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None: + if compile: + raise NotImplementedError # untested + if "mixed" in precision: + raise ValueError("Quantization and mixed precision is not supported.") + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + logging.getLogger("lightning.fabric.plugins.precision.bitsandbytes").setLevel(logging.DEBUG) + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=1, precision=precision, accelerator="cuda", plugins=plugins) + + total_devices = CUDAAccelerator.auto_device_count() + print(f"Using {total_devices} devices", file=sys.stderr) + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_file(checkpoint_dir / "model_config.yaml") + + checkpoint_path = checkpoint_dir / "lit_model.pth" + + tokenizer = Tokenizer(checkpoint_dir) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) + t0 = time.perf_counter() + # cannot use `init_module` because if bitsandbytes is used, the Linear layers will be replaced + # which means that the weights will get quantized on cuda:0 on checkpoint load. we need to load and then convert + # still, use init_tensor for the precision + with fabric.init_tensor(), torch.device("meta"): + model = GPT(config) + print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + + t0 = time.perf_counter() + state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu") + # TODO: this assumes that the model fits on CPU. 
Use lazy_load and make the materialization checkpoint aware + model.load_state_dict(state_dict, assign=True) + print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + + model = fabric.setup_module(model, move_to_device=False) + + t0 = time.perf_counter() + model = sequential(model, fabric.device, max_returned_tokens, total_devices) + print(f"Time to sequential-ize the model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + + if compile: + # TODO: raises an internal compile AssertionError caused by fabric.strategy.precision.forward_context + raise NotImplementedError + # silence developer warning on nightly builds + # https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/ir.py#L4166 + pattern = re.compile(".*DeviceCopy in input program.*") + logging.getLogger("torch._inductor.utils").addFilter(lambda record: not pattern.search(record.getMessage())) + torch._dynamo.config.automatic_dynamic_shapes = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.coordinate_descent_tuning = True + # cannot use cudagraphs because it doesn't support multiple device indices + # https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/compile_fx.py#L371-L375 + generate_base.next_token = torch.compile(generate_base.next_token) + + L.seed_everything(1234) + for i in range(num_samples): + t0 = time.perf_counter() + y = generate_base.generate( + model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id + ) + t = time.perf_counter() - t0 + for block in model.transformer.h: + block.attn.kv_cache.reset_parameters() + print(tokenizer.decode(y)) + tokens_generated = y.size(0) - prompt_length + print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr + ) + print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(main) diff --git a/lit_gpt/generate/tp.py b/lit_gpt/generate/tp.py new file mode 100644 index 0000000..3c6c8da --- /dev/null +++ b/lit_gpt/generate/tp.py @@ -0,0 +1,220 @@ +"""Tensor-parallel implementation adapted from https://github.com/pytorch-labs/gpt-fast/blob/14df27/tp.py""" + +import logging +import sys +import time +from functools import partial +from pathlib import Path +from typing import Literal, Optional, Union + +import lightning as L +import torch +import torch._dynamo.config +import torch._inductor.config +from lightning.fabric.plugins import BitsandbytesPrecision +from lightning.fabric.utilities import rank_zero_only +from torch.distributed._functional_collectives import all_reduce + +import litgpt.generate.base as generate_base +from litgpt import GPT, Config, Tokenizer +from litgpt.model import CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE +from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision + + +def tensor_parallel_linear(fabric: L.Fabric, linear: torch.nn.Linear, style: str) -> None: + world_size = fabric.world_size + dim, attr = {"colwise": (0, "out_features"), "rowwise": (1, "in_features")}[style] + size = getattr(linear, attr) + if size % world_size != 0: + raise ValueError( + f"This linear's {attr} value ({size}) is not evenly divisible by the world size ({world_size})" + ) + + shard = torch.tensor_split(linear.weight, world_size, dim=dim)[fabric.global_rank] + # overwrite `.data` instead of recreating the 
parameter for quantization (bitsandbytes) support. + # the bitsandbytes linear classes use custom `torch.nn.Parameter` subclasses + linear.weight.data = shard + setattr(linear, attr, shard.size(dim)) + + if linear.bias is not None and dim == 0: + shard = torch.tensor_split(linear.bias, world_size)[fabric.global_rank] + linear.bias = torch.nn.Parameter(shard, requires_grad=linear.bias.requires_grad) + + +def tensor_parallel_mlp(fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMAMoE]) -> None: + if isinstance(mlp, LLaMAMLP): + tensor_parallel_linear(fabric, mlp.fc_1, "colwise") + tensor_parallel_linear(fabric, mlp.fc_2, "colwise") + tensor_parallel_linear(fabric, mlp.proj, "rowwise") + mlp.register_forward_hook(partial(all_reduce_output, fabric.world_size)) + elif isinstance(mlp, GptNeoxMLP): + tensor_parallel_linear(fabric, mlp.fc, "colwise") + tensor_parallel_linear(fabric, mlp.proj, "rowwise") + mlp.register_forward_hook(partial(all_reduce_output, fabric.world_size)) + elif isinstance(mlp, LLaMAMoE): + # we use expert slicing across ranks, alternatively, we could create a expert parallelism group + # when the number of experts is a multiple of the world size + for expert in mlp.experts: + tensor_parallel_mlp(fabric, expert) + else: + raise NotImplementedError + + +def tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None: + tensor_parallel_linear(fabric, attn.attn, "colwise") + tensor_parallel_linear(fabric, attn.proj, "rowwise") + attn.register_forward_hook(partial(all_reduce_output, fabric.world_size)) + + +def all_reduce_output(world_size: int, module: torch.nn.Module, ins, outs) -> torch.Tensor: + return all_reduce(outs, "sum", list(range(world_size))) + + +def tensor_parallel(fabric: L.Fabric, model: GPT) -> GPT: + for block in model.transformer.h: + tensor_parallel_mlp(fabric, block.mlp) + tensor_parallel_attn(fabric, block.attn) + + # update the config values to the shard sizes + # this is only relevant for `tensor_parallel_attn`, but it needs to run only once + world_size = fabric.world_size + attrs = ["n_head", "n_embd", "n_query_groups"] + for attr in attrs: + size = getattr(model.config, attr) + if size % world_size != 0: + raise ValueError(f"This {attr} value ({size}) is not evenly divisible by the world size ({world_size})") + setattr(model.config, attr, size // world_size) + + return model + + +@torch.inference_mode() +def main( + prompt: str = "What food do llamas eat?", + *, + num_samples: int = 1, + max_new_tokens: int = 50, + top_k: Optional[int] = 200, + temperature: float = 0.8, + checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"]] = None, + precision: Optional[str] = None, + compile: bool = False, +) -> None: + """Generates text samples based on a pre-trained model and tokenizer. + + Args: + prompt: The prompt string to use for generating the samples. + num_samples: The number of text samples to generate. + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + checkpoint_dir: The checkpoint directory to load. 
+ quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md + precision: Indicates the Fabric precision setting to use. + compile: Whether to compile the model. + """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None: + if compile: + raise NotImplementedError # untested + if "mixed" in precision: + raise ValueError("Quantization and mixed precision is not supported.") + dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] + bnb_logger = logging.getLogger("lightning.fabric.plugins.precision.bitsandbytes") + bnb_logger.setLevel(logging.DEBUG) + bnb_logger.debug = rank_zero_only(bnb_logger.debug) + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + # set "ddp" as the strategy for the launching functionality, but there's no data-parallelism + fabric = L.Fabric(devices="auto", strategy="ddp", precision=precision, plugins=plugins) + fabric.launch() + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_file(checkpoint_dir / "model_config.yaml") + + model_file = "lit_model.pth" + checkpoint_path = checkpoint_dir / model_file + + tokenizer = Tokenizer(checkpoint_dir) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) + t0 = time.perf_counter() + # cannot use `init_module` because if bitsandbytes is used, the Linear layers will be replaced + # which means that the weights will get quantized on cuda:0 on checkpoint load. 
we need to load and then convert + # still, use init_tensor for the precision + with fabric.init_tensor(), torch.device("meta"): + model = GPT(config) + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + + # sequentially do: load the checkpoint on CPU -> quantize -> apply tp -> move to device + # so that the CPU RAM doesn't OOM with larger models + for rank in range(fabric.world_size): + if fabric.global_rank == rank: + t0 = time.perf_counter() + state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu") + model.load_state_dict(state_dict, assign=True) + print(f"[{rank}] Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + + # cannot use `.setup_module` because it will wrap with DDP + model = fabric._precision.convert_module(model) + + t0 = time.perf_counter() + model = tensor_parallel(fabric, model) + print( + f"[{rank}] Time to tensor-parallelize the model: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = max_returned_tokens + # the rope cache which is on meta device + model.cos, model.sin = model.rope_cache() + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + t0 = time.perf_counter() + model = fabric.to_device(model) + print(f"[{rank}] Time to move the model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) + fabric.barrier() + + if compile: + torch._dynamo.config.automatic_dynamic_shapes = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.coordinate_descent_tuning = True + generate_base.next_token = torch.compile(generate_base.next_token, mode="reduce-overhead") + + L.seed_everything(1234) + for i in range(num_samples): + t0 = time.perf_counter() + y = generate_base.generate( + model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id + ) + t = time.perf_counter() - t0 + for block in model.transformer.h: + block.attn.kv_cache.reset_parameters() + fabric.print(tokenizer.decode(y)) + tokens_generated = y.size(0) - prompt_length + fabric.print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr + ) + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(main) diff --git a/lit_gpt/packed_dataset.py b/lit_gpt/packed_dataset.py index 1b4b7dc..832bf92 100644 --- a/lit_gpt/packed_dataset.py +++ b/lit_gpt/packed_dataset.py @@ -232,4 +232,4 @@ def __init__(self, datasets, seed, weights): def __next__(self): (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1) - return next(dataset) + return next(dataset) \ No newline at end of file diff --git a/lit_gpt/prompts.py b/lit_gpt/prompts.py new file mode 100644 index 0000000..00f7cd9 --- /dev/null +++ b/lit_gpt/prompts.py @@ -0,0 +1,358 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
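Editor's sketch (not part of the patch), before the prompt-style definitions below: the weight sharding performed by `tensor_parallel_linear` in `tp.py` above. "colwise" splits the output dimension of a Linear weight across ranks (used for the attention and MLP input projections), "rowwise" splits the input dimension (used for the output projections). This runs locally with a pretend world size of 2 and no distributed setup; the layer sizes are illustrative.

```python
import torch

world_size, rank = 2, 0                      # pretend we are rank 0 of 2
linear = torch.nn.Linear(8, 16, bias=False)  # weight shape is [out_features, in_features] = [16, 8]

# "colwise": shard the output dimension (dim 0), as done for attn/fc layers above
col_shard = torch.tensor_split(linear.weight, world_size, dim=0)[rank]
# "rowwise": shard the input dimension (dim 1), as done for the proj layers above
row_shard = torch.tensor_split(linear.weight, world_size, dim=1)[rank]

print(col_shard.shape, row_shard.shape)  # torch.Size([8, 8]) torch.Size([16, 4])
```

After a rowwise projection each rank holds only a partial sum over its slice of the input features, which is why `tp.py` registers `all_reduce_output` as a forward hook on the attention and MLP modules to sum the partial results across ranks.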
+import importlib +import re +from abc import abstractmethod +from json import dumps +from pathlib import Path +from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union + +import yaml + +from lit_gpt.config import Config + +if TYPE_CHECKING: + from lit_gpt import Tokenizer + + +class PromptStyle: + """Base interface for prompt styles.""" + + @abstractmethod + def apply(self, prompt: str, **kwargs: str) -> str: + return prompt + + def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: + return ([tokenizer.eos_id],) + + @classmethod + def from_name(cls, name: str) -> "PromptStyle": + return prompt_styles[name]() + + @classmethod + def from_config(cls, config: Config) -> "PromptStyle": + return model_name_to_prompt_style(config.name) + + +class Default(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return prompt + + def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: + return ([tokenizer.eos_id],) + + +class Alpaca(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + if kwargs.get("input"): + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{prompt}\n\n### Input:\n{kwargs['input']}\n\n### Response:\n" + ) + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{prompt}\n\n### Response:\n" + ) + + +class FLAN(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{prompt}\n\n### Response:\n" + ) + + +class Longform(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. 
" + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{prompt}\n\n### Response:\n" + ) + + +class StableLMAlpha(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return ( + "<|SYSTEM|># StableLM Tuned (Alpha version)\n- StableLM is a helpful and harmless open-source AI language" + " model developed by StabilityAI.\n- StableLM is excited to be able to help the user, but will refuse to do" + " anything that could be considered harmful to the user.\n- StableLM is more than just an information" + " source, StableLM is also able to write poetry, short stories, and make jokes.\n- StableLM will refuse to" + f" participate in anything that could harm a human.<|USER|>{prompt}<|ASSISTANT|>" + ) + + def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: + return ( + [tokenizer.eos_id], + [tokenizer.token_to_id("<|SYSTEM|>")], + [tokenizer.token_to_id("<|ASSISTANT|>")], + [tokenizer.token_to_id("<|USER|>")], + ) + + +class StableLMZephyr(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return f"<|user|>\n{prompt}<|endoftext|>\n<|assistant|>\n" + + +class TogetherComputerChat(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return f": {prompt}\n:" + + def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: + lt, gt = tokenizer.token_to_id("<"), tokenizer.token_to_id(">:") + return ( + [tokenizer.eos_id], + # annoyingly, there's no single stop token for these + [lt, tokenizer.token_to_id("human"), gt], + [lt, tokenizer.token_to_id("bot"), gt], + ) + + +class TogetherComputerInstruct(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return f"Q: {prompt}\nA:" + + def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: + colon = tokenizer.token_to_id(":") + return ( + [tokenizer.eos_id], + # annoyingly, there's no single stop token for these + [tokenizer.token_to_id("Q"), colon], + [tokenizer.token_to_id("Question")], + [tokenizer.token_to_id("A"), colon], + [tokenizer.token_to_id("Label"), colon], + [187, 187], # '\n', '\n' + [535], # '\n\n' + [2756], # '\n\n\n' + ) + + +class Falcon(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + # First line could be modified. AFAIK Falcon doesn't impose a specific system prompt + # The instruction to not prefix its replies doesn't work always, but better than nothing + # I've also tried just "{prompt}\n" but the model seems to ramble more often + return f"Do not prefix your replies with 'Bot: '\nUser: {prompt}\n" + + def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: + return ( + [tokenizer.eos_id], + # the model rarely emits the eos token and instead outputs newlines, but we cannot use them + # to stop or else things like code generation wouldn't work + [tokenizer.token_to_id("User"), tokenizer.token_to_id(":")], + [193, tokenizer.token_to_id("User")], # 193: '\n' + ) + + +class Vicuna(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + # https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template + return ( + "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, " + f"detailed, and polite answers to the user's questions. 
USER: {prompt} ASSISTANT:" + ) + + +class Llama2FunctionCalling(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + # Has to be before the llama config + b_func, e_func = "", "\n\n" + b_inst, e_inst = "[INST]", "[/INST]" + b_sys, e_sys = "<>\n", "\n<>\n\n" + # This is an example for how to format functions for the model + function_metadata = { + "function": "search_bing", + "description": ( + "Search the web for content on Bing. This allows users to search online/the internet/the web for" + " content." + ), + "arguments": [{"name": "query", "type": "string", "description": "The search query string"}], + } + + system_prompt = ( + "You are a helpful, respectful and honest assistant. Always answer as helpfully as" + "possible. Your only response should be JSON formatted functions" + ) + # replace the curly braces with double curly braces to escape them + function_list = dumps(function_metadata).replace("{", "{{").replace("}", "}}") + return ( + f"{b_func}{function_list.strip()}{e_func}{b_inst}{b_sys}" + f"{system_prompt.strip()}" + f"{e_sys}{prompt}{e_inst}\n\n" + ) + + +class Llama2(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + b_inst, e_inst = "[INST]", "[/INST]" + b_sys, e_sys = "<>\n", "\n<>\n\n" + return ( + f"{b_inst} {b_sys}You are a helpful, respectful and honest assistant. Always answer as helpfully as" + " possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist," + " toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and" + " positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why" + " instead of answering something not correct. If you don't know the answer to a question, please don't" + f" share false information.{e_sys} {prompt} {e_inst} " + ) + + +class FreeWilly2(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return ( + "### System:\nThis is a system prompt, please behave and help the user.\n\n" + "### User:\n" + f"{prompt}\n\n" + "### Assistant:\n" + ) + + +class Platypus(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return f"### Instruction:\n\n{prompt}\n\n### Response:\n" + + +class NousResearch(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return f"### Instruction:\n{prompt}\n\n### Response:\n" + + +class StableCode(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return f"###Instruction\n{prompt}###Response\n" + + +class CodeLlama(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + # for CodeLLama, we don't set a default system prompt, but it is supported: + # https://huggingface.co/blog/codellama#conversational-instructions + # Mistral does not: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format + b_inst, e_inst = "[INST]", "[/INST]" + return f"{b_inst} {prompt} {e_inst}" + + +class Phi1(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return f"{prompt}\n\nAnswer:" + + def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: + return ( + [tokenizer.eos_id], + [tokenizer.token_to_id("Answer"), tokenizer.token_to_id(":")], + [198, tokenizer.token_to_id("Answer"), tokenizer.token_to_id(":")], + # the model rarely emits the eos token and instead outputs newlines, but we cannot use them + # to stop or else things like code generation wouldn't work + # [198, 198], # '\n', '\n' + ) + + +class Phi2(PromptStyle): + def apply(self, prompt: 
str, **kwargs: str) -> str: + return f"Instruct:{prompt}\nOutput:" + + +class TinyLlama(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return ( + "<|system|>\n" + "You are a friendly chatbot who always gives helpful, detailed, and polite answers.\n" + "<|user|>\n" + f"{prompt}\n" + "<|assistant|>\n" + ) + + +class Gemma(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return f"user\n{prompt}\nmodel\n" + + +# Maps prompt style names to PromptStyle classes +prompt_styles: Dict[str, Type[PromptStyle]] = { + # Dataset-specific prompt styles + "alpaca": Alpaca, + "flan": FLAN, + "longform": Longform, + # Model-specific prompt styles + "stablelm-alpha": StableLMAlpha, + "stablelm-zephyr": StableLMZephyr, + "togethercomputer-chat": TogetherComputerChat, + "togethercomputer-instruct": TogetherComputerInstruct, + "falcon": Falcon, + "vicuna": Vicuna, + "llama2-function-calling": Llama2FunctionCalling, + "llama2": Llama2, + "freewilly2": FreeWilly2, + "platypus": Platypus, + "nous-research": NousResearch, + "stablecode": StableCode, + "codellama": CodeLlama, + "phi-1": Phi1, + "phi-2": Phi2, + "tinyllama": TinyLlama, + "gemma": Gemma, +} + + +def model_name_to_prompt_style(model_name: str) -> PromptStyle: + if re.search(r"stablelm-tuned-alpha", model_name): + return StableLMAlpha() + if re.search(r"stablelm-zephyr-3b", model_name): + return StableLMZephyr() + if re.search("stablecode-instruct", model_name): + return StableCode() + if re.search(r"RedPajama-INCITE.*-Chat", model_name): + return TogetherComputerChat() + if re.search(r"RedPajama-INCITE.*-Instruct", model_name): + return TogetherComputerInstruct() + if re.search(r"falcon.*-instruct", model_name): + return Falcon() + if re.search(r"vicuna|longchat", model_name): + return Vicuna() + if re.search("Llama-2-7b-chat-hf-function-calling-v2", model_name): + return Llama2FunctionCalling() + if re.search("Llama-2.*-chat", model_name): + return Llama2() + if re.search("FreeWilly2", model_name): + return FreeWilly2() + if re.search("Platypus", model_name): + return Platypus() + if re.search("Nous-Hermes", model_name): + return NousResearch() + if re.search("CodeLlama|Mistral.*Instruct", model_name): + return CodeLlama() + if re.search("phi-1", model_name): + return Phi1() + if re.search("phi-2", model_name): + return Phi2() + if re.search(r"tiny-llama.*chat", model_name): + return TinyLlama() + if re.search(r"Gemma.*-it", model_name): + return Gemma() + return Default() + + +def save_prompt_style(style: Union[str, PromptStyle], checkpoint_dir: Path) -> None: + style = PromptStyle.from_name(style) if isinstance(style, str) else style + cls = type(style) + # Allow saving the full module path for user-defined prompt classes + config = {"class_path": f"{cls.__module__}.{cls.__name__}"} + with open(checkpoint_dir / "prompt_style.yaml", "w") as file: + yaml.dump(config, file) + + +def load_prompt_style(checkpoint_dir: Path) -> PromptStyle: + with open(checkpoint_dir / "prompt_style.yaml", "r") as file: + config = yaml.safe_load(file) + # Support loading the full module path for user-defined prompt classes + full_module_path, cls_name = config["class_path"].rsplit(".", 1) + module = importlib.import_module(full_module_path) + cls = getattr(module, cls_name) + return cls() + + +def has_prompt_style(checkpoint_dir: Path) -> bool: + return (checkpoint_dir / "prompt_style.yaml").is_file() diff --git a/lit_gpt/speed_monitor.py b/lit_gpt/speed_monitor.py index 178e0d7..f1e73aa 100644 --- a/lit_gpt/speed_monitor.py +++ 
b/lit_gpt/speed_monitor.py @@ -74,6 +74,7 @@ def get_flops_available(device: torch.device, precision: str) -> Optional[float]: if device.type == "cuda": device_name = torch.cuda.get_device_name(device).lower() + print("**********************device_name*******************:", device_name) if "h100" in device_name and "hbm3" in device_name: device_name = "h100-sxm" elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name): diff --git a/lit_gpt/tokenizer.py b/lit_gpt/tokenizer.py index 1c652be..497e1d2 100644 --- a/lit_gpt/tokenizer.py +++ b/lit_gpt/tokenizer.py @@ -23,15 +23,23 @@ def __init__(self, checkpoint_dir: Path) -> None: with open(checkpoint_dir / "tokenizer_config.json") as fp: config = json.load(fp) bos_token = config.get("bos_token") - self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None - self.eos_id = self.token_to_id(config["eos_token"]) + eos_token = config.get("eos_token") + if type(bos_token) is not dict: + self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None # general qwen + self.eos_id = self.token_to_id(eos_token) if eos_token is not None else None # general qwen + + else: + self.bos_id = self.token_to_id(bos_token["content"]) if bos_token is not None else None # deepseek + self.eos_id = self.token_to_id(eos_token["content"]) if eos_token is not None else None # deepseek + else: raise NotImplementedError @property def vocab_size(self) -> int: if self.backend == "huggingface": - return self.processor.get_vocab_size(with_added_tokens=False) + # return self.processor.get_vocab_size(with_added_tokens=False) + return self.processor.get_vocab_size(with_added_tokens=True) if self.backend == "sentencepiece": return self.processor.vocab_size() raise RuntimeError @@ -48,12 +56,12 @@ def token_to_id(self, token: str) -> int: return id_ def encode( - self, - string: str, - device: Optional[torch.device] = None, - bos: bool = False, - eos: bool = True, - max_length: int = -1, + self, + string: str, + device: Optional[torch.device] = None, + bos: bool = False, + eos: bool = True, + max_length: int = -1, ) -> torch.Tensor: if self.backend == "huggingface": tokens = self.processor.encode(string).ids @@ -75,3 +83,13 @@ def encode( def decode(self, tensor: torch.Tensor) -> str: tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() return self.processor.decode(tokens) + + +if __name__ == "__main__": + tokenizer_path = "/Users/peiji/opencsg/llm_train/data/deepseekLlamaTokenizer" + t = Tokenizer(Path(tokenizer_path)) + print(t.bos_id) + print(t.vocab_size) + + + diff --git a/lit_gpt/utils.py b/lit_gpt/utils.py index d1d7bc6..33d19c3 100644 --- a/lit_gpt/utils.py +++ b/lit_gpt/utils.py @@ -8,14 +8,17 @@ from io import BytesIO from pathlib import Path from types import MethodType -from typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, Union +from typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, Union, Literal import torch import torch.nn as nn import torch.utils._device -from lightning.fabric.loggers import CSVLogger +from lightning.fabric.loggers import CSVLogger, TensorBoardLogger +from lightning.pytorch.loggers import WandbLogger +from lightning.fabric.strategies import FSDPStrategy +# from lightning.fabric.utilities.load import _lazy_load as lazy_load from torch.serialization import normalize_storage_type - +import lightning as L def find_multiple(n: int, k: int) -> int: assert k > 0 @@ -233,7 +236,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def 
check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: files = { "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), - "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), + # "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or ( checkpoint_dir / "tokenizer.model" ).is_file(), @@ -404,6 +407,26 @@ def __exit__(self, type, value, traceback): self.zipfile.write_end_of_file() +def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None: + if isinstance(fabric.strategy, FSDPStrategy): + fabric.load_raw(checkpoint_path, model, strict=strict) + else: + state_dict = lazy_load(checkpoint_path) + state_dict = state_dict.get("model", state_dict) + model.load_state_dict(state_dict, strict=strict) + + +def CLI(*args: Any, **kwargs: Any) -> Any: + from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options + + set_docstring_parse_options(attribute_docstrings=True) + set_config_read_mode(urls_enabled=True) + + kwargs.setdefault("as_positional", False) + + return CLI(*args, **kwargs) + + T = TypeVar("T") @@ -503,3 +526,19 @@ def get_default_supported_precision(training: bool, tpu: bool = False) -> str: if not torch.cuda.is_available() or torch.cuda.is_bf16_supported(): return "bf16-mixed" if training else "bf16-true" return "16-mixed" if training else "16-true" + +def choose_logger( + logger_name: Literal["csv", "tensorboard", "wandb"], + out_dir: Path, + name: str, + log_interval: int = 1, + resume: Optional[bool] = None, + **kwargs: Any, +): + if logger_name == "csv": + return CSVLogger(root_dir=(out_dir / "logs"), name="csv", flush_logs_every_n_steps=log_interval, **kwargs) + if logger_name == "tensorboard": + return TensorBoardLogger(root_dir=(out_dir / "logs"), name="tensorboard", **kwargs) + if logger_name == "wandb": + return WandbLogger(project=name, resume=resume, **kwargs) + raise ValueError(f"`--logger_name={logger_name}` is not a valid option. 
Choose from 'csv', 'tensorboard', 'wandb'.") diff --git a/pretrain/csg_tiny_10m_llama.py b/pretrain/csg_tiny_10m_llama.py new file mode 100644 index 0000000..fb0f922 --- /dev/null +++ b/pretrain/csg_tiny_10m_llama.py @@ -0,0 +1,436 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union, Dict +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually + +from eval.lm_eval_harness import EvalHarnessBase +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, \ + lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss, Tokenizer +import random + +model_name = "csg_tiny_10M_llama" +name = "csg_tiny_10M_llama" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +# micro_batch_size = 8 +micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +# eval_iters = 100 +eval_iters = 10 +save_step_interval = 5000 +# eval_step_interval = 5000 +eval_step_interval = 10 + + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_slim", 0.693584), + ("train_star", 0.306416), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="csg_tiny_10M_llama") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, + tokenizer_dir: Union[bool, Path] = Path("/data/models/csg-tiny-1B/csg-tiny-1B-955K"), +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
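# (editor's note, not part of the original patch) "auto" lets Fabric pick up every local
# XLA core, and sync_module_states=False skips broadcasting the rank-0 weights -- safe here
# because every process seeds identically via fabric.seed_everything(3407) below.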
+ devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + # fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume, tokenizer_dir) + + +def main(fabric, train_data_dir, val_data_dir, resume, tokenizer_dir): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights, n_layer=config.n_layer)) + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume: + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume, tokenizer_dir) + fabric.print(f"Training time: {(time.perf_counter() - train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume, tokenizer_dir): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. 
Will trigger fusedRMSNorm error + # measured_flops = measure_flops(meta_model, x) + # fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for train_data in train_dataloader: + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0: model.config.block_size].contiguous() + targets = train_data[:, 1: model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss=loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + eval_metrics = evaluation(fabric, model, tokenizer_dir) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + + + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + + # val loss + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * ( + state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + + # val ppl + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * ( + state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + + # ["hellaswag", "openbookqa", "winogrande", "arc_easy", "arc_challenge", "boolq", "piqa"] + + # for m in ["openbookqa"]: + for m in ["hellaswag", "openbookqa", "winogrande", "arc_easy", "arc_challenge", "boolq", "piqa"]: + val_m = eval_metrics["results"][m].get("acc_norm", -1) + if val_m == -1: + val_m = eval_metrics["results"][m].get("acc", -1) + fabric.log_dict({("metric/" + "val_" + m): val_m, "total_tokens": model.config.block_size * ( + state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + + fabric.barrier() + + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0: model.config.block_size].contiguous() + targets = val_data[:, 1: model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +@torch.no_grad() +def evaluation(fabric: L.Fabric, model: torch.nn.Module, tokenizer_dir) -> Dict: + fabric.print("evaluating............") + model.eval() + tokenizer = Tokenizer(tokenizer_dir) + # if fabric.local_rank == 0: + # if fabric.global_rank == 1: + # print("---------------evaluation devices---------------:", fabric.device) + eval_harness = EvalHarnessBase(fabric, model, tokenizer, 1) + eval_tasks = ["hellaswag", "openbookqa", "winogrande", "arc_easy", "arc_challenge", "boolq", "piqa"] + # eval_tasks = ["openbookqa"] + num_fewshot: int = 0 + limit: Optional[int] = None + bootstrap_iters: int = 100000 + no_cache: bool = True + # if fabric.global_rank == 0: + eval_res = eval_harness.run_eval(eval_tasks, num_fewshot, limit, bootstrap_iters, no_cache) + # fabric.barrier() + # eval_res = fabric.broadcast(eval_res) + # print("---------------evaluation devices---------------:", fabric.device) + # 
fabric.barrier() + return eval_res + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed + fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." + ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/csg_tiny_10m_yi.py b/pretrain/csg_tiny_10m_yi.py new file mode 100644 index 0000000..7d5caa1 --- /dev/null +++ b/pretrain/csg_tiny_10m_yi.py @@ -0,0 +1,393 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() 
+sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "csg_tiny_10M_yi" +name = "csg_tiny_10M_yi" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +# micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_slim", 0.693584), + ("train_star", 0.306416), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="test_yi_tokenizer_10m_skypile") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
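# (editor's note, not part of the original patch) On CUDA the else-branch below is taken
# instead: FSDP wraps each transformer Block separately (auto_wrap_policy={Block}) and
# state_dict_type="full" gathers one consolidated state dict when fabric.save() is called.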
+ devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. 
Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/csg_tiny_10m_yi_skypile.py b/pretrain/csg_tiny_10m_yi_skypile.py new file mode 100644 index 0000000..a513f8a --- /dev/null +++ b/pretrain/csg_tiny_10m_yi_skypile.py @@ -0,0 +1,395 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "csg_tiny_10M_yi" +name = "csg_tiny_10M_yi" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +# micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + 
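+# Worked example of the derived quantities computed below, using the values above
+# (a sanity check only, not extra configuration):
+#   batch_size                  = 512 // 8 = 64 sequences per device per optimizer step
+#   gradient_accumulation_steps = 64 // 8  = 8 micro-batches accumulated per step
+#   warmup_iters                = 2000 * 8 = 16000 micro-batch iterations of linear warmup
+# so every optimizer step consumes global_batch_size = 512 sequences across the 8 devices.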
+weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + # ("train_slim", 0.693584), + # ("train_star", 0.306416), + ("data_skypile", 1.0) + +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="test_yi_tokenizer_10m_skypile") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. + devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = 
{"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. 
+ if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = 
chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." + ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/csg_tiny_120m_deepseek.py b/pretrain/csg_tiny_120m_deepseek.py new file mode 100644 index 0000000..e1d0d11 --- /dev/null +++ b/pretrain/csg_tiny_120m_deepseek.py @@ -0,0 +1,393 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import 
FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "csg_tiny_120M-v2" +name = "csg_tiny_120M-v2" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +# micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_slim", 0.693584), + ("train_star", 0.306416), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="test_deepseek_tokenizer_120m-2") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
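+            # `sync_module_states=False` skips broadcasting rank-0 weights during setup;
+            # every rank initializes with the same fixed seed (`fabric.seed_everything(3407)`
+            # in main()), so the replicas are expected to start out identical anyway.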
+ devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. 
Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/csg_tiny_120m_yi.py b/pretrain/csg_tiny_120m_yi.py new file mode 100644 index 0000000..5319168 --- /dev/null +++ b/pretrain/csg_tiny_120m_yi.py @@ -0,0 +1,393 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "csg_tiny_120M-v4" +name = "csg_tiny_120M-v4" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +# micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + +weight_decay = 
1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_slim", 0.693584), + ("train_star", 0.306416), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="test_yi_tokenizer_120m") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. + devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": 
hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. 
" + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/csg_tiny_1B_llama_gitlab.py b/pretrain/csg_tiny_1B_llama_gitlab.py new file mode 100644 index 0000000..6bb4080 --- /dev/null +++ b/pretrain/csg_tiny_1B_llama_gitlab.py @@ -0,0 +1,390 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "csg-tiny-1B" +name = "csg-tiny-1B-llama-gitlab" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +# micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 
5000 + + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_gitlab", 1.0), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="csg-tiny-1B-llama_gitlab") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. + devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, 
"iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. 
" + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/csg_tiny_1B_llama_wanjuancc.py b/pretrain/csg_tiny_1B_llama_wanjuancc.py new file mode 100644 index 0000000..87b8859 --- /dev/null +++ b/pretrain/csg_tiny_1B_llama_wanjuancc.py @@ -0,0 +1,430 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union, Dict +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually + +from eval.lm_eval_harness import EvalHarnessBase +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, \ + lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss, Tokenizer +import random + +model_name = "csg-tiny-1B" +name = "csg_tiny_1B_wanjuancc" +out_dir = Path("/data/train/csg_tiny_traindata/out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +# micro_batch_size = 8 +micro_batch_size = 16 +max_step = 715256 * 2 
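+# Worked example of the schedule arithmetic derived below (a sketch for orientation only,
+# using the values defined in this file): global_batch_size = 512 split over num_of_devices = 8
+# gives a per-device batch_size of 64; with micro_batch_size = 16 that yields
+# gradient_accumulation_steps = 4, so the optimizer steps once every 4 micro-batches,
+# warmup_iters = 2000 * 4 = 8000, and max_iters = 715256 * 2 * 4 = 5,722,048 iterations.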
+warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +# eval_iters = 10 +save_step_interval = 5000 +eval_step_interval = 5000 +# eval_step_interval = 10 + + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + # ("train_slim", 0.693584), + # ("train_star", 0.306416), + ("train_wanjuancc", 1.0) +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="csg_tiny_1B_llama_wanjuancc_val_slim") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, + tokenizer_dir: Union[bool, Path] = Path("/data/models/csg-tiny-1B/csg-tiny-1B-955K"), +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. + devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + # fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume, tokenizer_dir) + + +def main(fabric, train_data_dir, val_data_dir, resume, tokenizer_dir): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights, n_layer=config.n_layer)) + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + 
model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume: + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume, tokenizer_dir) + fabric.print(f"Training time: {(time.perf_counter() - train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume, tokenizer_dir): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. Will trigger fusedRMSNorm error + # measured_flops = measure_flops(meta_model, x) + # fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for train_data in train_dataloader: + # resume loader state. This is not elegant but it works. Should rewrite it in the future. 
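+        # On resume, fabric.load() restores state["iter_num"] but not the dataloader position,
+        # so the block below replays the stream: it consumes and discards batches, counting
+        # curr_iter up from 0, until it has skipped initial_iter batches, then flips `resume`
+        # off and resumes training. Simple, but every skipped batch is still read and decoded.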
+ if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0: model.config.block_size].contiguous() + targets = train_data[:, 1: model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss=loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + eval_metrics = evaluation(fabric, model, tokenizer_dir) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + + + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + + # val loss + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * ( + state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + + # val ppl + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * ( + state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + + # ["hellaswag", "openbookqa", "winogrande", "arc_easy", "arc_challenge", "boolq", "piqa"] + + # for m in ["openbookqa"]: + for m in ["hellaswag", "openbookqa", "winogrande", "arc_easy", "arc_challenge", "boolq", "piqa"]: + val_m = eval_metrics["results"][m].get("acc_norm", -1) + if val_m == -1: + val_m = eval_metrics["results"][m].get("acc", -1) + fabric.log_dict({("metric/" + "val_" + m): val_m, "total_tokens": model.config.block_size * ( + state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + + fabric.barrier() + + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / 
f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0: model.config.block_size].contiguous() + targets = val_data[:, 1: model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +@torch.no_grad() +def evaluation(fabric: L.Fabric, model: torch.nn.Module, tokenizer_dir) -> Dict: + fabric.print("evaluating............") + model.eval() + tokenizer = Tokenizer(tokenizer_dir) + + eval_harness = EvalHarnessBase(fabric, model, tokenizer, 1) + eval_tasks = ["hellaswag", "openbookqa", "winogrande", "arc_easy", "arc_challenge", "boolq", "piqa"] + num_fewshot: int = 0 + limit: Optional[int] = None + bootstrap_iters: int = 100000 + no_cache: bool = True + eval_res = eval_harness.run_eval(eval_tasks, num_fewshot, limit, bootstrap_iters, no_cache) + + return eval_res + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed + fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/csg_tiny_1B_qwen.py b/pretrain/csg_tiny_1B_qwen.py new file mode 100644 index 0000000..53ec607 --- /dev/null +++ b/pretrain/csg_tiny_1B_qwen.py @@ -0,0 +1,392 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "csg-tiny-1B-V3" +name = "csg-tiny-1B-V3" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +# micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + +weight_decay = 1e-1 
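+# Note (descriptive, not prescriptive): the optimizer settings below mirror the AdamW
+# configuration commonly used for LLaMA-style pretraining, betas (0.9, 0.95), weight decay 0.1,
+# and gradient clipping at 1.0. With micro_batch_size = 8 in this script, the per-device
+# batch_size of 64 works out to gradient_accumulation_steps = 64 // 8 = 8.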
+beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_slim", 0.693584), + ("train_star", 0.306416), + # ("validation", 1), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger() + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. + devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if 
resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. 
" + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/csg_tiny_1B_yi_gitlab.py b/pretrain/csg_tiny_1B_yi_gitlab.py new file mode 100644 index 0000000..a5c71b1 --- /dev/null +++ b/pretrain/csg_tiny_1B_yi_gitlab.py @@ -0,0 +1,390 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "csg-tiny-1B-yi" +name = "csg-tiny-1B-yi-gitlab" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +# micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + 
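+# The learning rate follows the cosine-with-warmup schedule implemented in get_lr() further
+# down: linear warmup from 0 to learning_rate (4e-4) over warmup_iters, then cosine decay to
+# min_lr (4e-5) at lr_decay_iters. A quick sanity check (hedged sketch, intended to be run by
+# hand in a REPL after importing this module, not as part of training):
+#
+#     for it in (0, warmup_iters, (warmup_iters + lr_decay_iters) // 2, lr_decay_iters):
+#         print(it, get_lr(it))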
+weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_gitlab", 1.0), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="test_yi_tokenizer_yi_gitlab") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. + devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 
0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. 
" + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + print("------------------------", filenames) + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/csg_tiny_1B_yi_skypile.py b/pretrain/csg_tiny_1B_yi_skypile.py new file mode 100644 index 0000000..aabafa8 --- /dev/null +++ b/pretrain/csg_tiny_1B_yi_skypile.py @@ -0,0 +1,393 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "csg-tiny-1B-yi" +name = "csg-tiny-1B-V4-skypile" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +# micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 
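+# Evaluation and checkpointing are both keyed to optimizer steps (state["step_count"]), not to
+# micro-batch iterations: every eval_step_interval (5000) steps the train loop runs validate()
+# over eval_iters (100) validation batches, and every save_step_interval (5000) steps it saves
+# the full training state (model, optimizer, hparams, counters) to out_dir as
+# iter-{iter_num:06d}-ckpt.pth.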
+ + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + # ("train_slim", 0.693584), + # ("train_star", 0.306416), + ("data_skypile", 1.0) + +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="test_yi_tokenizer_1b_skypile") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. + devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = 
{"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. 
+ if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = 
chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." + ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/csg_tiny_1B_yi_wudao.py b/pretrain/csg_tiny_1B_yi_wudao.py new file mode 100644 index 0000000..aa8d6b9 --- /dev/null +++ b/pretrain/csg_tiny_1B_yi_wudao.py @@ -0,0 +1,391 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import 
FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "csg-tiny-1B-yi" +name = "csg-tiny-1B-V4-wudao" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +# micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_wudao", 1.0) + +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="test_yi_tokenizer_1b_wudao") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
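+            # Editor's note on the CUDA branch below (not part of the original change): FSDP wraps
+            # each transformer `Block` as its own shard unit via `auto_wrap_policy={Block}`, and
+            # `state_dict_type="full"` should make `fabric.save` write one consolidated checkpoint
+            # rather than per-rank shards. Activation checkpointing is disabled here; if memory is
+            # tight, a variant along the lines of
+            #   FSDPStrategy(auto_wrap_policy={Block}, activation_checkpointing_policy={Block}, ...)
+            # may trade extra recompute for lower activation memory, depending on the Lightning
+            # version in use.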
+ devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. 
Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/csg_tiny_30m_yi.py b/pretrain/csg_tiny_30m_yi.py new file mode 100644 index 0000000..84947fd --- /dev/null +++ b/pretrain/csg_tiny_30m_yi.py @@ -0,0 +1,393 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "csg_tiny_30M_yi" +name = "csg_tiny_30M_yi" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +# micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + +weight_decay = 1e-1 
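+# Editor's note: with the defaults above, the derived values computed a few lines below work out
+# to batch_size = 512 // 8 = 64 sequences per device and gradient_accumulation_steps = 64 // 8 = 8,
+# so every optimizer step still consumes the full global batch of 512 sequences
+# (512 * block_size tokens, block_size being taken from the model config at runtime), and
+# warmup_iters = 2000 * 8 = 16000 micro-batch iterations.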
+beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_slim", 0.693584), + ("train_star", 0.306416), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="test_yi_tokenizer_30m") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. + devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, 
"iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. 
" + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/pretrain_continue_cosmopedia.sh b/pretrain/pretrain_continue_cosmopedia.sh new file mode 100644 index 0000000..a3b60b9 --- /dev/null +++ b/pretrain/pretrain_continue_cosmopedia.sh @@ -0,0 +1,7 @@ +fabric run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=2 \ + pretrain/tinyllama_continue_train.py --devices 8 --train_data_dir /data/datasets/processed/llama2_tokenize_data/cosmopedia_train --val_data_dir /data/datasets/processed/llama2_tokenize_data/slimpajama_validation --resume /data/train/csg-tiny-1B/iter-220000-ckpt.pth \ No newline at end of file diff --git a/pretrain/pythia_14m_yi.py b/pretrain/pythia_14m_yi.py new file mode 100644 index 0000000..84c36fa --- /dev/null +++ b/pretrain/pythia_14m_yi.py @@ -0,0 +1,393 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, 
measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "pythia-14m" +name = "pythia-14m" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +# micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_slim", 0.693584), + ("train_star", 0.306416), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger(project="pretrain-csg-tiny", name="test_yi_tokenizer_pythia_14m") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
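+            # Editor's note (applies to this setup() as a whole): `fabric.launch` is commented out
+            # further down and `main()` is called directly, so the script expects an external
+            # launcher to have spawned one process per device, e.g. mirroring
+            # pretrain_continue_cosmopedia.sh (single node, data directory hypothetical):
+            #   fabric run model --accelerator=cuda --devices=8 \
+            #       pretrain/pythia_14m_yi.py --devices 8 --train_data_dir <tokenized_data_dir>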
+ devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. 
Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/tinyllama.py b/pretrain/tinyllama.py index 668a88b..fd1a303 100644 --- a/pretrain/tinyllama.py +++ b/pretrain/tinyllama.py @@ -18,11 +18,14 @@ from lit_gpt.packed_dataset import CombinedDataset, PackedDataset from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor from lit_gpt.speed_monitor import estimate_flops, measure_flops -from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger#, lazy_load from pytorch_lightning.loggers import WandbLogger from lit_gpt import FusedCrossEntropyLoss + +# from lm_evaluation_harness.lm_eval import run_eval_harness import random + model_name = "tiny_LLaMA_1b" name = "tinyllama_1b" out_dir = Path("out") / name @@ -31,7 +34,8 @@ num_of_devices = 8 global_batch_size = 512 learning_rate = 4e-4 -micro_batch_size = 8 +# micro_batch_size = 8 +micro_batch_size = 16 max_step = 715256 * 2 warmup_steps = 2000 log_step_interval = 10 @@ -151,7 +155,7 @@ def main(fabric, train_data_dir, val_data_dir, resume): if resume is True: resume = sorted(out_dir.glob("*.pth"))[-1] - if resume : + if resume: fabric.print(f"Resuming training from {resume}") fabric.load(resume, state) @@ -243,7 +247,6 @@ def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): # print days as well f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" ) - monitor.on_train_batch_end( state["iter_num"] * micro_batch_size, t1 - total_t0, @@ -254,9 +257,6 @@ def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): lengths=total_lengths, train_loss = loss.item() ) - - - if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: @@ -293,7 +293,8 @@ def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoade losses[k] = loss.item() out = losses.mean() - + # + # run_eval_harness model.train() return out @@ -303,6 +304,7 @@ def create_dataloader( ) -> DataLoader: datasets = [] data_config = train_data_config if split == "train" else val_data_config + # train_data_config = [("train_slim", 0.693584),("train_star", 0.306416),] for prefix, _ in data_config: filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) random.seed(seed) diff --git a/pretrain/tinyllama_continue_train.py b/pretrain/tinyllama_continue_train.py new file mode 100644 index 0000000..4b88fd6 --- /dev/null +++ b/pretrain/tinyllama_continue_train.py @@ -0,0 +1,435 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union, Dict +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from eval.lm_eval_harness import EvalHarnessBase +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss, Tokenizer +import random + + +model_name = "csg-wukong-1B" +name = "csg-wukong-1B" +# model_name = "tiny_LLaMA_1b" +# name = "tiny_LLaMA_1b" +out_dir = Path("/data/train") / name +checkpoint_path = "/data/models/csg-tiny-1B/csg-tiny-1B-1195k/iter-4780000-ckpt.pth" +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 2e-4 +min_lr = 2e-5 +micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. 
+train_data_config = [ + ("train", 1), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger=WandbLogger(project="continue-pretrain-csg-tiny", name="csg-tiny-1B-1195K-cosmopedia") + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, + tokenizer_dir: Union[bool, Path] = "/data/models/csg-tiny-1B/csg-tiny-1B-1195k", +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. + devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + # fabric.launch(main, train_data_dir, val_data_dir, resume, tokenizer_dir) + main(fabric, train_data_dir, val_data_dir, resume, tokenizer_dir) + + +def main(fabric, train_data_dir, val_data_dir, resume, tokenizer_dir): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + + + model = fabric.setup(model) + # fabric.load_raw(checkpoint_path, model, strict=True) + if not resume: + fabric.load(checkpoint_path, {"model": model}) + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # import bitsandbytes as bnb + # optimizer = bnb.optim.AdamW8bit( + # model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2) + # ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume: + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + 
train(fabric, state, train_dataloader, val_dataloader, monitor, resume, tokenizer_dir) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume, tokenizer_dir): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for curr_iter, train_data in enumerate(train_dataloader, initial_iter): + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" + ) + + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + eval_metrics = evaluation(fabric, model, tokenizer_dir) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + + + for m in ["hellaswag", "openbookqa", "winogrande", "arc_easy", "arc_challenge", "boolq", "piqa"]: + val_m = eval_metrics["results"][m].get("acc_norm", -1) + if val_m == -1: + val_m = eval_metrics["results"][m].get("acc", -1) + fabric.log_dict({("eval/" + "val_" + m): val_m, "total_tokens": model.config.block_size * ( + state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + + fabric.barrier() + + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +@torch.no_grad() +def evaluation(fabric: L.Fabric, model: torch.nn.Module, tokenizer_dir) -> Dict: + fabric.print("evaluating............") + model.eval() + tokenizer = Tokenizer(tokenizer_dir) + + eval_harness = EvalHarnessBase(fabric, model, tokenizer, 1) + eval_tasks = ["hellaswag", "openbookqa", "winogrande", "arc_easy", "arc_challenge", "boolq", "piqa"] + num_fewshot: int = 0 + limit: Optional[int] = None + bootstrap_iters: int = 100000 + no_cache: bool = True + eval_res = eval_harness.run_eval(eval_tasks, num_fewshot, limit, bootstrap_iters, no_cache) + + return eval_res + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, + shuffle: bool = True, seed: int = 12345, split="train") -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + n_chunks = 8 + if len(filenames) < 8: + n_chunks = 1 + else: + n_chunks = 8 + 
dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=n_chunks, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." + ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain/tinyllama_monitor.py b/pretrain/tinyllama_monitor.py new file mode 100644 index 0000000..9305a8e --- /dev/null +++ b/pretrain/tinyllama_monitor.py @@ -0,0 +1,393 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import 
chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "tiny_LLaMA_1b" +name = "tinyllama_1b" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +# micro_batch_size = 8 +micro_batch_size = 16 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-5 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_slim", 0.693584), + ("train_star", 0.306416), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger() + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
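+            # "auto" lets Fabric pick up every XLA device visible on this host.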
+ devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=3407, + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=False): + model = GPT(config) + model.apply(partial(model._init_weights, n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. 
Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for train_data in train_dataloader: + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" + ) + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + state["step_count"], + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed+fabric.global_rank, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/pretrain_scripts/pretrain.sh b/pretrain_scripts/pretrain.sh new file mode 100644 index 0000000..15afd62 --- /dev/null +++ b/pretrain_scripts/pretrain.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=1 \ + pretrain/tinyllama.py --devices 8 --train_data_dir /data/datasets/processed/deepseek_llama2_tokenize_data/slim_star_combined --val_data_dir /data/datasets/processed/deepseek_llama2_tokenize_data/slim_star_combined_validation diff --git a/pretrain_scripts/pretrain_continue_cosmopedia.sh b/pretrain_scripts/pretrain_continue_cosmopedia.sh new file mode 100644 index 0000000..4539858 --- /dev/null +++ b/pretrain_scripts/pretrain_continue_cosmopedia.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=1 \ + pretrain/tinyllama_continue_train.py --devices 8 --train_data_dir /data/datasets/processed/llama2_tokenize_data/cosmopedia_train --val_data_dir /data/datasets/processed/llama2_tokenize_data/slimpajama_validation \ No newline at end of file diff --git a/pretrain_scripts/pretrain_csg_tiny_10m_llama.sh b/pretrain_scripts/pretrain_csg_tiny_10m_llama.sh new file mode 100644 index 0000000..d7ee2f5 --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_10m_llama.sh @@ -0,0 +1,7 @@ +fabric run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=1 \ + pretrain/csg_tiny_10m_llama.py --devices 8 --train_data_dir 
/data/datasets/processed/llama2_tokenize_data/slim_star_combined --val_data_dir /data/datasets/processed/llama2_tokenize_data/slimpajama_validation diff --git a/pretrain_scripts/pretrain_csg_tiny_10m_yi.sh b/pretrain_scripts/pretrain_csg_tiny_10m_yi.sh new file mode 100644 index 0000000..4292b8e --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_10m_yi.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=1 \ + pretrain/csg_tiny_10m_yi.py --devices 8 --train_data_dir /data/datasets_bak/processed/yi_llama2_tokenize_data/slim_star_combined --val_data_dir /data/datasets_bak/processed/yi_llama2_tokenize_data/slimpajama_validation \ No newline at end of file diff --git a/pretrain_scripts/pretrain_csg_tiny_10m_yi_skypile.sh b/pretrain_scripts/pretrain_csg_tiny_10m_yi_skypile.sh new file mode 100644 index 0000000..ecbe6ec --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_10m_yi_skypile.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=1 \ + pretrain/csg_tiny_10m_yi_skypile.py --devices 8 --train_data_dir /data/datasets_bak/processed/yi_llama2_tokenize_data/skypile_train --resume /data/project/TinyLlama/out/csg_tiny_10M_yi/iter-320000-ckpt.pth diff --git a/pretrain_scripts/pretrain_csg_tiny_120m_deepseek.sh b/pretrain_scripts/pretrain_csg_tiny_120m_deepseek.sh new file mode 100644 index 0000000..33e801f --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_120m_deepseek.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=1 \ + pretrain/csg_tiny_120m_deepseek.py --devices 8 --train_data_dir /data/datasets_bak/processed/deepseek_llama2_tokenize_data/slim_star_combined --val_data_dir /data/datasets_bak/processed/deepseek_llama2_tokenize_data/slimpajama_validation diff --git a/pretrain_scripts/pretrain_csg_tiny_1B.sh b/pretrain_scripts/pretrain_csg_tiny_1B.sh new file mode 100644 index 0000000..fc00a51 --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_1B.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=1 \ + pretrain/csg_tiny_1B.py --devices 8 --train_data_dir /data/datasets/processed/deepseek_llama2_tokenize_data/slim_star_combined --val_data_dir /data/datasets/processed/deepseek_llama2_tokenize_data/slim_star_combined_validation diff --git a/pretrain_scripts/pretrain_csg_tiny_1B_llama_gitlab.sh b/pretrain_scripts/pretrain_csg_tiny_1B_llama_gitlab.sh new file mode 100644 index 0000000..51800fe --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_1B_llama_gitlab.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=2 \ + pretrain/csg_tiny_1B_llama_gitlab.py --devices 8 --train_data_dir /data/datasets/processed/llama2_tokenize_data/gitlab_train diff --git a/pretrain_scripts/pretrain_csg_tiny_1B_llama_wanjuancc.sh b/pretrain_scripts/pretrain_csg_tiny_1B_llama_wanjuancc.sh new file mode 100644 index 0000000..1b56fd5 --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_1B_llama_wanjuancc.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=1 \ + pretrain/csg_tiny_1B_llama_wanjuancc.py --devices 8 --train_data_dir 
/data/datasets/processed/llama2_tokenize_data/wanjuancc_train --val_data_dir /data/datasets/processed/llama2_tokenize_data/slimpajama_validation diff --git a/pretrain_scripts/pretrain_csg_tiny_1B_qwen.sh b/pretrain_scripts/pretrain_csg_tiny_1B_qwen.sh new file mode 100644 index 0000000..d6e1103 --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_1B_qwen.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=1 \ + pretrain/csg_tiny_1B_qwen.py --devices 8 --train_data_dir /data/datasets/processed/qwen_llama2_tokenize_data/slim_star_combined --val_data_dir /data/datasets/processed/qwen_llama2_tokenize_data/slimpajama_validation diff --git a/pretrain_scripts/pretrain_csg_tiny_1B_yi_gitlab.sh b/pretrain_scripts/pretrain_csg_tiny_1B_yi_gitlab.sh new file mode 100644 index 0000000..b2c96c6 --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_1B_yi_gitlab.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=1 \ + pretrain/csg_tiny_1B_yi_gitlab.py --devices 8 --train_data_dir /data/datasets_bak/processed/yi_llama2_tokenize_data/gitlab_train \ No newline at end of file diff --git a/pretrain_scripts/pretrain_csg_tiny_1B_yi_skypile.sh b/pretrain_scripts/pretrain_csg_tiny_1B_yi_skypile.sh new file mode 100644 index 0000000..7cfd94a --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_1B_yi_skypile.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=2 \ + pretrain/csg_tiny_1B_yi_skypile.py --devices 8 --train_data_dir /data/datasets_bak/processed/yi_llama2_tokenize_data/skypile_train diff --git a/pretrain_scripts/pretrain_csg_tiny_1B_yi_wudao.sh b/pretrain_scripts/pretrain_csg_tiny_1B_yi_wudao.sh new file mode 100644 index 0000000..76795e1 --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_1B_yi_wudao.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=2 \ + pretrain/csg_tiny_1B_yi_wudao.py --devices 8 --train_data_dir /data/datasets_bak/processed/yi_llama2_tokenize_data/wudao_train diff --git a/pretrain_scripts/pretrain_csg_tiny_30m_yi.sh b/pretrain_scripts/pretrain_csg_tiny_30m_yi.sh new file mode 100644 index 0000000..ac494e0 --- /dev/null +++ b/pretrain_scripts/pretrain_csg_tiny_30m_yi.sh @@ -0,0 +1,7 @@ +lightning run model \ + --node-rank=0 \ + --main-address=192.168.48.36 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=1 \ + pretrain/csg_tiny_30m_yi.py --devices 8 --train_data_dir /data/datasets_bak/processed/yi_llama2_tokenize_data/slim_star_combined --val_data_dir /data/datasets_bak/processed/yi_llama2_tokenize_data/slimpajama_validation \ No newline at end of file diff --git a/scripts/convert_lit_checkpoint.py b/scripts/convert_lit_checkpoint.py index 6b41b8c..1d70a0e 100644 --- a/scripts/convert_lit_checkpoint.py +++ b/scripts/convert_lit_checkpoint.py @@ -295,6 +295,8 @@ def convert_lit_checkpoint(*, # ("lit_model_finetuned.pth", "lit_model_lora_finetuned.pth", "lit_model_adapter_finetuned.pth"") pth_file = out_dir / checkpoint_name bin_file = pth_file.with_suffix(".bin") + # bin_file = pth_file.with_suffix(".pth") + with incremental_save(bin_file) as saver: with contextlib.ExitStack() as stack: diff --git a/scripts/prepare_cosmopedia.py b/scripts/prepare_cosmopedia.py new file mode 100644 index 
0000000..6af2e00 --- /dev/null +++ b/scripts/prepare_cosmopedia.py @@ -0,0 +1,110 @@ +import json +import glob +import os +from pathlib import Path +import sys +from typing import List +import numpy as np +from tqdm import tqdm +from multiprocessing import Process, cpu_count +from datasets import load_dataset + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Tokenizer + +# Filename for cosmopedia +cosmopedia_sets = { + "train_auto_math_text": "data/auto_math_text/*", + "train": "data/khanacademy/*", + "train": "data/openstax/*", + "train": "data/stanford/*", + "train": "data/stories/*", + "train": "data/web_samples_v1/*", + "train": "data/wikihow/*" + +} + + +def prepare_full( + source_path: Path, + tokenizer_path: Path, + destination_path: Path, + chunk_size: int, + split: str="train", + filenames_subset: List[str] = None, + process_id: int = 0 +) -> None: + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(tokenizer_path) + + # Use the provided filenames_subset or default to all filenames + filenames = filenames_subset + + if not filenames: + raise RuntimeError( + f"No files matching {cosmopedia_sets[split]} found at {source_path}. \n" + "Make sure you download the data..." + ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=f"{split}_{process_id}", # Use process_id to differentiate builders + chunk_size=chunk_size, + sep_token=tokenizer.bos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for filepath in filenames: + print(f"Processing {filepath}") + ds = load_dataset("parquet", data_files={"train": filepath}, split="train", streaming=True) + for row in tqdm(iter(ds)): + text = row["prompt"] + row["text"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details + # builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), + destination_path: Path = Path("data/red_pajama_sample"), + chunk_size: int = 2049 * 1024, + split: str="train", + percentage: float = 1.0, +) -> None: + import time + + filenames = glob.glob(os.path.join(source_path, cosmopedia_sets[split]), recursive=True) + filenames = filenames[:int(len(filenames) * percentage)] + + num_processes = 2 + chunked_filenames = np.array_split(filenames, num_processes) + + processes = [] + start_time = time.time() + + for i, subset in enumerate(chunked_filenames): + p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) + processes.append(p) + p.start() + + for p in processes: + p.join() + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + from jsonargparse import CLI + CLI(prepare) \ No newline at end of file diff --git a/scripts/prepare_general.py b/scripts/prepare_general.py new file mode 100644 index 0000000..e89b78b --- /dev/null +++ b/scripts/prepare_general.py @@ -0,0 +1,110 @@ +import json +import glob +import os +from pathlib import Path +import sys +from typing import List +import numpy as np +from tqdm 
import tqdm +from multiprocessing import Process, cpu_count +from datasets import load_dataset + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Tokenizer + +prefex = "general" +# Filename for general +general_sets = { + "train": "*.json", +} + + +def prepare_full( + source_path: Path, + tokenizer_path: Path, + destination_path: Path, + chunk_size: int, + split: str="train", + filenames_subset: List[str] = None, + process_id: int = 0 +) -> None: + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(tokenizer_path) + + # Use the provided filenames_subset or default to all filenames + filenames = filenames_subset + + if not filenames: + raise RuntimeError( + f"No files matching {general_sets[split]} found at {source_path}. \n" + "Make sure you download the data..." + ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=f"{split}_{prefex}_{process_id}", # Use process_id to differentiate builders + chunk_size=chunk_size, + sep_token=tokenizer.bos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for filepath in filenames: + print(f"Processing {filepath}") + ds = load_dataset("json", data_files={"train": filepath}, split="train", streaming=True) + for row in tqdm(iter(ds)): + if "boolq_write_out_info.json" not in filepath: + text = row["prompt_" + str(row["truth"])] + else: + yesorno={" yes": "0", " no": "1"} + text = row["prompt_" + yesorno[row["truth"]]] + # print("text:", text) + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details + # builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), + destination_path: Path = Path("data/red_pajama_sample"), + # chunk_size: int = 2049 * 1024, + chunk_size: int = 2049 * 64, + split: str="train", + percentage: float = 1.0, +) -> None: + import time + + filenames = glob.glob(os.path.join(source_path, general_sets[split]), recursive=True) + filenames = filenames[:int(len(filenames) * percentage)] + + num_processes = min(len(filenames), cpu_count()) + chunked_filenames = np.array_split(filenames, num_processes) + + processes = [] + start_time = time.time() + # print("filenames:", chunked_filenames) + for i, subset in enumerate(chunked_filenames): + p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) + processes.append(p) + p.start() + + for p in processes: + p.join() + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + from jsonargparse import CLI + CLI(prepare) \ No newline at end of file diff --git a/scripts/prepare_gitlab.py b/scripts/prepare_gitlab.py new file mode 100644 index 0000000..75976b3 --- /dev/null +++ b/scripts/prepare_gitlab.py @@ -0,0 +1,104 @@ +import json +import glob +import os +from pathlib import Path +import sys +from typing import List +import numpy as np +from tqdm import tqdm +from multiprocessing import Process, cpu_count + +# support running without installing as a package +wd = 
Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Tokenizer + +# Filename for gitlab +gitlab_sets = { + "train": "*", +} + + +def prepare_full( + source_path: Path, + tokenizer_path: Path, + destination_path: Path, + chunk_size: int, + split: str="train", + filenames_subset: List[str] = None, + process_id: int = 0 +) -> None: + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(tokenizer_path) + + # Use the provided filenames_subset or default to all filenames + filenames = filenames_subset + # print("------------", filenames) + + # if not filenames: + # raise RuntimeError( + # f"No files matching {gitlab_sets[split]} found at {source_path}. \n" + # "Make sure you download the data..." + # ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=f"{split}_gitlab_{process_id}", # Use process_id to differentiate builders + chunk_size=chunk_size, + sep_token=tokenizer.bos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for filepath in filenames: + print(f"Processing {filepath}") + with open(filepath, encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details + # builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), + destination_path: Path = Path("data/red_pajama_sample"), + chunk_size: int = 2049 * 1024, + split: str="train", + percentage: float = 1.0, +) -> None: + import time + + filenames = glob.glob(os.path.join(source_path, gitlab_sets[split]), recursive=True) + filenames = filenames[:int(len(filenames) * percentage)] + print("----------------------", filenames) + + num_processes = cpu_count() + chunked_filenames = np.array_split(filenames, num_processes) + + processes = [] + start_time = time.time() + + for i, subset in enumerate(chunked_filenames): + p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) + processes.append(p) + p.start() + + for p in processes: + p.join() + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + from jsonargparse import CLI + CLI(prepare) \ No newline at end of file diff --git a/scripts/prepare_skypile.py b/scripts/prepare_skypile.py new file mode 100644 index 0000000..cc4a108 --- /dev/null +++ b/scripts/prepare_skypile.py @@ -0,0 +1,106 @@ +import json +import glob +import os +from pathlib import Path +import sys +from typing import List +import numpy as np +from tqdm import tqdm +from multiprocessing import Process, cpu_count + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Tokenizer + +# Filename for SlimPajama +skypile_sets = { + # "train": "train/chunk*/*", + # "validation": "validation/chunk*/*", + # "test": "test/chunk*/*", + "data": "data/*" +} + + +def prepare_full( + source_path: Path, + tokenizer_path: Path, + destination_path: Path, + chunk_size: int, + split: 
str="data", + filenames_subset: List[str] = None, + process_id: int = 0 +) -> None: + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(tokenizer_path) + + # Use the provided filenames_subset or default to all filenames + filenames = filenames_subset + + if not filenames: + raise RuntimeError( + f"No files matching {skypile_sets[split]} found at {source_path}. \n" + "Make sure you download the data..." + ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=f"{split}_skypile_{process_id}", # Use process_id to differentiate builders + chunk_size=chunk_size, + sep_token=tokenizer.bos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for filepath in filenames: + print(f"Processing {filepath}") + with open(filepath, encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details + # builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), + destination_path: Path = Path("data/red_pajama_sample"), + chunk_size: int = 2049 * 1024, + split: str="train", + percentage: float = 1.0, +) -> None: + import time + + filenames = glob.glob(os.path.join(source_path, skypile_sets[split]), recursive=True) + filenames = filenames[:int(len(filenames) * percentage)] + print("the number of files:", len(filenames)) + + num_processes = cpu_count() + chunked_filenames = np.array_split(filenames, num_processes) + + processes = [] + start_time = time.time() + + for i, subset in enumerate(chunked_filenames): + p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) + processes.append(p) + p.start() + + for p in processes: + p.join() + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + from jsonargparse import CLI + CLI(prepare) \ No newline at end of file diff --git a/scripts/prepare_starcoder.py b/scripts/prepare_starcoder.py index 7ef9c68..c009462 100644 --- a/scripts/prepare_starcoder.py +++ b/scripts/prepare_starcoder.py @@ -82,7 +82,8 @@ def prepare( if filenames_subset: filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])] filenames = filenames[:int(len(filenames) * percentage)] - num_processes = 64 + + num_processes = min(len(filenames), cpu_count()) chunked_filenames = np.array_split(filenames, num_processes) processes = [] @@ -102,4 +103,4 @@ def prepare( if __name__ == "__main__": from jsonargparse import CLI - CLI(prepare) + CLI(prepare) \ No newline at end of file diff --git a/scripts/prepare_wanjuancc.py b/scripts/prepare_wanjuancc.py new file mode 100644 index 0000000..dc14894 --- /dev/null +++ b/scripts/prepare_wanjuancc.py @@ -0,0 +1,114 @@ +import json +import glob +import os +from pathlib import Path +import sys +from typing import List +import numpy as np +from tqdm import tqdm +import tarfile +from multiprocessing import Process, cpu_count + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from 
lit_gpt import Tokenizer + +# Filename for wanjuancc +wanjuancc_sets = { + "train": "raw/*", +} + + +def prepare_full( + source_path: Path, + tokenizer_path: Path, + destination_path: Path, + chunk_size: int, + split: str="train", + filenames_subset: List[str] = None, + process_id: int = 0 +) -> None: + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(tokenizer_path) + + # Use the provided filenames_subset or default to all filenames + filenames = filenames_subset + + if not filenames: + raise RuntimeError( + f"No files matching {wanjuancc_sets[split]} found at {source_path}. \n" + "Make sure you download the data..." + ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=f"{split}_wanjuancc_{process_id}", # Use process_id to differentiate builders + chunk_size=chunk_size, + sep_token=tokenizer.bos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for filepath in filenames: + print(f"Processing {filepath}") + with tarfile.open(filepath, 'r:gz') as tar: + # 遍历 tar 中的每个成员 + for member in tar.getmembers(): + #提取文件对象 + print("member:", member) + file = tar.extractfile(member) + print("file:", file) + # 逐行解析 JSON 数据 + for row in (file.readlines()): + text = json.loads(row.decode("utf-8"))["content"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details + # builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), + destination_path: Path = Path("data/red_pajama_sample"), + chunk_size: int = 2049 * 1024, + split: str="train", + percentage: float = 1.0, +) -> None: + import time + + filenames = glob.glob(os.path.join(source_path, wanjuancc_sets[split]), recursive=True) + filenames = filenames[:int(len(filenames) * percentage)] + print("count of filenames:", len(filenames)) + + num_processes = min(len(filenames), cpu_count()) + # num_processes = 1 + # num_processes = cpu_count() + # num_processes = 61 + chunked_filenames = np.array_split(filenames, num_processes) + + processes = [] + start_time = time.time() + + for i, subset in tqdm(enumerate(chunked_filenames)): + p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) + processes.append(p) + p.start() + + for p in processes: + p.join() + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + from jsonargparse import CLI + CLI(prepare) diff --git a/scripts/prepare_wudao.py b/scripts/prepare_wudao.py new file mode 100644 index 0000000..258a676 --- /dev/null +++ b/scripts/prepare_wudao.py @@ -0,0 +1,104 @@ +import json +import glob +import os +from pathlib import Path +import sys +from typing import List +import numpy as np +from tqdm import tqdm +from multiprocessing import Process, cpu_count + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Tokenizer + +# Filename for wudao +wudao_sets = { + "train": "./*" +} + + +def prepare_full( + source_path: Path, + tokenizer_path: Path, + destination_path: Path, + chunk_size: int, + split: 
str="train", + filenames_subset: List[str] = None, + process_id: int = 0 +) -> None: + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(tokenizer_path) + + # Use the provided filenames_subset or default to all filenames + filenames = filenames_subset + # print("filenames:", filenames) + + if not filenames: + raise RuntimeError( + f"No files matching {wudao_sets[split]} found at {source_path}. \n" + "Make sure you download the data..." + ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=f"{split}_wudao_{process_id}", # Use process_id to differentiate builders + chunk_size=chunk_size, + sep_token=tokenizer.bos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for filepath in filenames: + print(f"Processing {filepath}") + with open(filepath, encoding="utf-8") as f: + texts = json.load(f) + for row in tqdm(texts): + text_ids = tokenizer.encode(row["content"]) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details + # builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), + destination_path: Path = Path("data/red_pajama_sample"), + chunk_size: int = 2049 * 1024, + split: str="train", + percentage: float = 1.0, +) -> None: + import time + + filenames = glob.glob(os.path.join(source_path, wudao_sets[split]), recursive=True) + filenames = filenames[:int(len(filenames) * percentage)] + print("the number of files:", len(filenames)) + + num_processes = cpu_count() + chunked_filenames = np.array_split(filenames, num_processes) + + processes = [] + start_time = time.time() + + for i, subset in enumerate(chunked_filenames): + p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) + processes.append(p) + p.start() + + for p in processes: + p.join() + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + from jsonargparse import CLI + CLI(prepare) \ No newline at end of file diff --git a/tests/test.ipynb b/tests/test.ipynb new file mode 100644 index 0000000..8169537 --- /dev/null +++ b/tests/test.ipynb @@ -0,0 +1,166 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import jsonlines\n", + "import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "filepath = \"/data/datasets/Skywork/SkyPile-150B/data/2020-40_zh_head_0000.jsonl\"" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "file = open(filepath)\n", + "lines = file.readlines()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "468550\n" + ] + } + ], + "source": [ + "print(len(lines))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\"text\": \"会计人员信息登记和变更,是会计人员顺利开展继续教育、会计专业技术资格考试报名、高级会计专业技术资格评审、参加高端会计人才选拔等的重要基础信息。广大会计人员一定要高度重视信息登记和变更工作,保证信息的完整性和准确性。 
信息采集内容是按照全国会计人员信息管理要求列示的最基本信息,旨在为会计专业技术资格考试、高级(正高级)会计师评审、继续教育学习、会计类培训报名、会计人才选拔等工作提供信息检索和核查,实现会计资格考试报名和继续教育等“零跑腿”服务。 会计人员范围是指根据《中华人民共和国会计法》的规定,在国家机关、社会团体、企业、事业单位和其他组织中从事会计核算、实行会计监督等会计工作的人员。山东省内的会计人员,具体包括:具有会计专业技术(含初级、中级、高级、正高级)资格的人员;不具有会计专业技术资格但从事会计工作的人员。 从事会计工作的人员指从事下列具体工作的人员: (一)出纳; (二)稽核; (三)资产、负债和所有者权益(净资产)的核算; (四)收入、费用(支出)的核算; (五)财务成果(政府预算执行结果)的核算; (六)财务会计报告(决算报告)编制; (七)会计监督; (八)会计机构内会计档案管理; (九)其他会计工作。 担任单位会计机构负责人(会计主管人员)、总会计师的人员,属于会计人员。 会计人员信息采集分为集中采集和长期采集阶段, 集中采集阶段:2019年4月20日至6月30日。 长期采集阶段:会计人员可以继续长期进行信息采集。 会计人员应根据个人信息变化情况,及时登录系统对个人信息进行更新,山东省财政厅将对连续3年个人信息(包括继续教育)未更新的人员进行清理。 信息采集采用会计人员网上填报本人信息、上传相关资料和财政部门审核确认的方式。 (一)会计人员登陆山东省财政厅网站([URL],点击右下角“山东会计管理”专题页面,“进入“会计人员信息采集”入口,根据说明和提示注册并填报相关信息。 (二)各级财政部门会计管理人员登陆会计人员信息采集系统管理端,根据申报人员填报的信息和提交的证明材料进行审核。 (三)信息采集遵循属地化原则,会计人员在工作单位所在地采集信息。 请广大会计从业人员务必确保信息的真实性、准确性、完整性,避免损害您的个人信用和从业资格。\\n\"}\n", + "\n" + ] + } + ], + "source": [ + "line = lines[7]\n", + "print(line)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "会计人员信息登记和变更,是会计人员顺利开展继续教育、会计专业技术资格考试报名、高级会计专业技术资格评审、参加高端会计人才选拔等的重要基础信息。广大会计人员一定要高度重视信息登记和变更工作,保证信息的完整性和准确性。 信息采集内容是按照全国会计人员信息管理要求列示的最基本信息,旨在为会计专业技术资格考试、高级(正高级)会计师评审、继续教育学习、会计类培训报名、会计人才选拔等工作提供信息检索和核查,实现会计资格考试报名和继续教育等“零跑腿”服务。 会计人员范围是指根据《中华人民共和国会计法》的规定,在国家机关、社会团体、企业、事业单位和其他组织中从事会计核算、实行会计监督等会计工作的人员。山东省内的会计人员,具体包括:具有会计专业技术(含初级、中级、高级、正高级)资格的人员;不具有会计专业技术资格但从事会计工作的人员。 从事会计工作的人员指从事下列具体工作的人员: (一)出纳; (二)稽核; (三)资产、负债和所有者权益(净资产)的核算; (四)收入、费用(支出)的核算; (五)财务成果(政府预算执行结果)的核算; (六)财务会计报告(决算报告)编制; (七)会计监督; (八)会计机构内会计档案管理; (九)其他会计工作。 担任单位会计机构负责人(会计主管人员)、总会计师的人员,属于会计人员。 会计人员信息采集分为集中采集和长期采集阶段, 集中采集阶段:2019年4月20日至6月30日。 长期采集阶段:会计人员可以继续长期进行信息采集。 会计人员应根据个人信息变化情况,及时登录系统对个人信息进行更新,山东省财政厅将对连续3年个人信息(包括继续教育)未更新的人员进行清理。 信息采集采用会计人员网上填报本人信息、上传相关资料和财政部门审核确认的方式。 (一)会计人员登陆山东省财政厅网站([URL],点击右下角“山东会计管理”专题页面,“进入“会计人员信息采集”入口,根据说明和提示注册并填报相关信息。 (二)各级财政部门会计管理人员登陆会计人员信息采集系统管理端,根据申报人员填报的信息和提交的证明材料进行审核。 (三)信息采集遵循属地化原则,会计人员在工作单位所在地采集信息。 请广大会计从业人员务必确保信息的真实性、准确性、完整性,避免损害您的个人信用和从业资格。\n", + "\n", + "1\n" + ] + }, + { + "data": { + "text/plain": [ + "str" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = json.loads(line)[\"text\"]\n", + "print(text)\n", + "print(len(text[0]))\n", + "type(text)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "module 'jsonlines' has no attribute 'load'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/data/project/TinyLlama/test.ipynb 单元格 3\u001b[0m line \u001b[0;36m3\n\u001b[1;32m 1\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39m(filepath, \u001b[39m\"\u001b[39m\u001b[39mrb\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mas\u001b[39;00m f:\n\u001b[1;32m 2\u001b[0m \u001b[39mfor\u001b[39;00m row \u001b[39min\u001b[39;00m f:\n\u001b[0;32m----> 3\u001b[0m text \u001b[39m=\u001b[39m jsonlines\u001b[39m.\u001b[39;49mload(line)\n\u001b[1;32m 4\u001b[0m \u001b[39mprint\u001b[39m(text)\n", + "\u001b[0;31mAttributeError\u001b[0m: module 'jsonlines' has no attribute 'load'" + ] + } + ], + "source": [ + "with open(filepath, \"rb\") as f:\n", + " for row in f:\n", + " text = jsonlines.loads(line)\n", + " print(text)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + 
"source": [ + "import jsonlines" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cu118", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tokenize_scripts/deepseek_llama2_tokenize_train_slimpajamash.sh b/tokenize_scripts/deepseek_llama2_tokenize_train_slimpajamash.sh new file mode 100644 index 0000000..d75673c --- /dev/null +++ b/tokenize_scripts/deepseek_llama2_tokenize_train_slimpajamash.sh @@ -0,0 +1,5 @@ +python scripts/prepare_slimpajama.py \ + --source_path /data/datasets/cerebras/SlimPajama-627B \ + --tokenizer_path /data/datasets/tokenizers/deepseekLlamaTokenizer \ + --destination_path /data/datasets_bak/processed/deepseek_llama2_tokenize_data/slimpajama_train \ + --split train --percentage 1.0 \ No newline at end of file diff --git a/tokenize_scripts/deepseek_llama2_tokenize_train_starcoderdata.sh b/tokenize_scripts/deepseek_llama2_tokenize_train_starcoderdata.sh new file mode 100644 index 0000000..6e6c682 --- /dev/null +++ b/tokenize_scripts/deepseek_llama2_tokenize_train_starcoderdata.sh @@ -0,0 +1,5 @@ +python scripts/prepare_starcoder.py \ + --source_path /data/datasets/starcoderdata \ + --tokenizer_path /data/datasets/tokenizers/deepseekLlamaTokenizer \ + --destination_path /data/datasets_bak/processed/deepseek_llama2_tokenize_data/starcoder_train \ + --split train --percentage 1.0 diff --git a/tokenize_scripts/deepseek_llama2_tokenize_validation_slimpajama.sh b/tokenize_scripts/deepseek_llama2_tokenize_validation_slimpajama.sh new file mode 100644 index 0000000..9fb6be7 --- /dev/null +++ b/tokenize_scripts/deepseek_llama2_tokenize_validation_slimpajama.sh @@ -0,0 +1,5 @@ +python scripts/prepare_slimpajama.py \ + --source_path /data/datasets/cerebras/SlimPajama-627B \ + --tokenizer_path /data/datasets/tokenizers/deepseekLlamaTokenizer \ + --destination_path /data/datasets_bak/processed/deepseek_llama2_tokenize_data/slimpajama_validation \ + --split validation --percentage 1.0 \ No newline at end of file diff --git a/tokenize_scripts/llama2_tokenize_train_cosmopedia.sh b/tokenize_scripts/llama2_tokenize_train_cosmopedia.sh new file mode 100644 index 0000000..c9950fa --- /dev/null +++ b/tokenize_scripts/llama2_tokenize_train_cosmopedia.sh @@ -0,0 +1,5 @@ +python scripts/prepare_cosmopedia.py \ + --source_path /data/datasets/HuggingFaceTB/cosmopedia \ + --tokenizer_path /data/datasets/tokenizers/Llama2Tokenizer \ + --destination_path /data/datasets/processed/llama2_tokenize_data/cosmopedia/wikihow \ + --split train --percentage 1.0 \ No newline at end of file diff --git a/tokenize_scripts/llama2_tokenize_train_gitlab.sh b/tokenize_scripts/llama2_tokenize_train_gitlab.sh new file mode 100644 index 0000000..65c9456 --- /dev/null +++ b/tokenize_scripts/llama2_tokenize_train_gitlab.sh @@ -0,0 +1,5 @@ +python scripts/prepare_gitlab.py \ + --source_path /data/datasets/gitlab_model/train \ + --tokenizer_path /data/datasets/tokenizers/Llama2Tokenizer \ + --destination_path /data/datasets/processed/llama2_tokenize_data/gitlab_train \ + --split train --percentage 1.0 \ No newline at end of file diff --git 
a/tokenize_scripts/llama2_tokenize_train_wanjuancc.sh b/tokenize_scripts/llama2_tokenize_train_wanjuancc.sh new file mode 100644 index 0000000..cf237cc --- /dev/null +++ b/tokenize_scripts/llama2_tokenize_train_wanjuancc.sh @@ -0,0 +1,5 @@ +python scripts/prepare_wanjuancc.py \ + --source_path /data/datasets_bak/opendatalab/OpenDataLab___WanJuanCC \ + --tokenizer_path /data/datasets/tokenizers/Llama2Tokenizer \ + --destination_path /data/datasets/processed/llama2_tokenize_data/wanjuancc_train \ + --split train --percentage 1.0 diff --git a/tokenize_scripts/qwen_llama2_tokenize_train_slimpajama.sh b/tokenize_scripts/qwen_llama2_tokenize_train_slimpajama.sh new file mode 100644 index 0000000..6940598 --- /dev/null +++ b/tokenize_scripts/qwen_llama2_tokenize_train_slimpajama.sh @@ -0,0 +1,5 @@ +python scripts/prepare_slimpajama.py \ + --source_path /data/datasets/cerebras/SlimPajama-627B \ + --tokenizer_path /data/datasets/tokenizers/qwenLlamaTokenizer \ + --destination_path /data/datasets/processed/qwen_llama2_tokenize_data/slimpajama_train \ + --split train --percentage 1.0 \ No newline at end of file diff --git a/tokenize_scripts/qwen_llama2_tokenize_train_starcoderdata.sh b/tokenize_scripts/qwen_llama2_tokenize_train_starcoderdata.sh new file mode 100644 index 0000000..8656360 --- /dev/null +++ b/tokenize_scripts/qwen_llama2_tokenize_train_starcoderdata.sh @@ -0,0 +1,5 @@ +python scripts/prepare_starcoder.py \ + --source_path /data/datasets/starcoderdata \ + --tokenizer_path /data/datasets/tokenizers/qwenLlamaTokenizer \ + --destination_path /data/datasets/processed/qwen_llama2_tokenize_data/starcoder_train \ + --split train --percentage 1.0 diff --git a/tokenize_scripts/qwen_llama2_tokenize_validation_slimpajama.sh b/tokenize_scripts/qwen_llama2_tokenize_validation_slimpajama.sh new file mode 100644 index 0000000..d4acfa5 --- /dev/null +++ b/tokenize_scripts/qwen_llama2_tokenize_validation_slimpajama.sh @@ -0,0 +1,5 @@ +python scripts/prepare_slimpajama.py \ + --source_path /data/datasets/cerebras/SlimPajama-627B \ + --tokenizer_path /data/datasets/tokenizers/qwenLlamaTokenizer \ + --destination_path /data/datasets/processed/qwen_llama2_tokenize_data/slimpajama_validation \ + --split validation --percentage 1.0 \ No newline at end of file diff --git a/tokenize_scripts/yi_llama2_tokenize_train_gitlab.sh b/tokenize_scripts/yi_llama2_tokenize_train_gitlab.sh new file mode 100644 index 0000000..8a4dddf --- /dev/null +++ b/tokenize_scripts/yi_llama2_tokenize_train_gitlab.sh @@ -0,0 +1,5 @@ +python scripts/prepare_gitlab.py \ + --source_path /data/datasets/gitlab_model/train \ + --tokenizer_path /data/datasets/tokenizers/yiLlamaTokenizer \ + --destination_path /data/datasets_bak/processed/yi_llama2_tokenize_data/gitlab_train \ + --split train --percentage 1.0 \ No newline at end of file diff --git a/tokenize_scripts/yi_llama2_tokenize_train_skypile.sh b/tokenize_scripts/yi_llama2_tokenize_train_skypile.sh new file mode 100644 index 0000000..4765c68 --- /dev/null +++ b/tokenize_scripts/yi_llama2_tokenize_train_skypile.sh @@ -0,0 +1,5 @@ +python scripts/prepare_skypile.py \ + --source_path /data/datasets/Skywork/SkyPile-150B \ + --tokenizer_path /data/datasets/tokenizers/yiLlamaTokenizer \ + --destination_path /data/datasets_bak/processed/yi_llama2_tokenize_data/skypile_train \ + --split data --percentage 1.0 \ No newline at end of file diff --git a/tokenize_scripts/yi_llama2_tokenize_train_slimpajama.sh b/tokenize_scripts/yi_llama2_tokenize_train_slimpajama.sh new file mode 100644 index 
0000000..cb64579 --- /dev/null +++ b/tokenize_scripts/yi_llama2_tokenize_train_slimpajama.sh @@ -0,0 +1,5 @@ +python scripts/prepare_slimpajama.py \ + --source_path /data/datasets/cerebras/SlimPajama-627B \ + --tokenizer_path /data/datasets/tokenizers/yiLlamaTokenizer \ + --destination_path /data/datasets_bak/processed/yi_llama2_tokenize_data/slimpajama_train \ + --split train --percentage 1.0 \ No newline at end of file diff --git a/tokenize_scripts/yi_llama2_tokenize_train_starcoderdata.sh b/tokenize_scripts/yi_llama2_tokenize_train_starcoderdata.sh new file mode 100644 index 0000000..77944b2 --- /dev/null +++ b/tokenize_scripts/yi_llama2_tokenize_train_starcoderdata.sh @@ -0,0 +1,5 @@ +python scripts/prepare_starcoder.py \ + --source_path /data/datasets/starcoderdata \ + --tokenizer_path /data/datasets/tokenizers/yiLlamaTokenizer \ + --destination_path /data/datasets_bak/processed/yi_llama2_tokenize_data/starcoder_train \ + --split train --percentage 1.0 diff --git a/tokenize_scripts/yi_llama2_tokenize_train_wudao.sh b/tokenize_scripts/yi_llama2_tokenize_train_wudao.sh new file mode 100644 index 0000000..d4231d5 --- /dev/null +++ b/tokenize_scripts/yi_llama2_tokenize_train_wudao.sh @@ -0,0 +1,5 @@ +python scripts/prepare_wudao.py \ + --source_path /data/datasets/WuDaoCorpus2.0_base_200G \ + --tokenizer_path /data/datasets/tokenizers/yiLlamaTokenizer \ + --destination_path /data/datasets_bak/processed/yi_llama2_tokenize_data/wudao_train \ + --split train --percentage 1.0 \ No newline at end of file diff --git a/tokenize_scripts/yi_llama2_tokenize_validation_slimpajama.sh b/tokenize_scripts/yi_llama2_tokenize_validation_slimpajama.sh new file mode 100644 index 0000000..61e4e62 --- /dev/null +++ b/tokenize_scripts/yi_llama2_tokenize_validation_slimpajama.sh @@ -0,0 +1,5 @@ +python scripts/prepare_slimpajama.py \ + --source_path /data/datasets/cerebras/SlimPajama-627B \ + --tokenizer_path /data/datasets/tokenizers/yiLlamaTokenizer \ + --destination_path /data/datasets_bak/processed/yi_llama2_tokenize_data/slimpajama_validation \ + --split validation --percentage 1.0 \ No newline at end of file
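After running one of the tokenize scripts above, it is worth confirming that the destination directory actually contains readable packed chunks before launching a long pretraining run. The snippet below is a minimal sketch, not part of the diff: it assumes the repo's ``lit_gpt`` package is importable (e.g. run from the repo root), reuses the destination and tokenizer paths from the last script as placeholders, and guesses ``block_size = 2048 + 1`` and the ``validation_*`` filename prefix from the defaults used elsewhere in this diff.

```
# sanity-check a packed dataset directory (illustrative sketch; adjust paths to your run)
import glob
from pathlib import Path

from lit_gpt import Tokenizer
from lit_gpt.packed_dataset import PackedDataset

# Placeholder paths copied from the tokenize script above; change them to match your run.
destination_path = Path("/data/datasets_bak/processed/yi_llama2_tokenize_data/slimpajama_validation")
tokenizer_path = Path("/data/datasets/tokenizers/yiLlamaTokenizer")
block_size = 2048 + 1  # assumption: matches chunk_size = 2049 * 1024 and the +1 used for next-token targets

filenames = sorted(glob.glob(str(destination_path / "validation_*")))
assert filenames, f"no packed chunks found under {destination_path}"

dataset = PackedDataset(
    filenames,
    n_chunks=1,        # a single-chunk buffer is enough for a spot check
    block_size=block_size,
    shuffle=False,
    num_processes=1,
    process_rank=0,
)
block = next(iter(dataset))        # one contiguous block of token ids
print(block.shape, block.dtype)

tokenizer = Tokenizer(tokenizer_path)
print(tokenizer.decode(block[:64]))  # assumes Tokenizer exposes decode(); print the raw ids otherwise
```

If the shape equals ``block_size`` and the decoded text looks sane, the same directory can be passed as ``--train_data_dir`` or ``--val_data_dir`` to the pretrain scripts above.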