From bbddba865f0674238d99c14c505c3778cddc647c Mon Sep 17 00:00:00 2001
From: JaiDhyani
Date: Thu, 7 Mar 2024 08:18:42 -0800
Subject: [PATCH] abstract away llama2 specifics to enable mamba implementation

---
 src/delphi/train/architectures.py | 77 +++++++++++++++++++++++++++++++
 src/delphi/train/gigaconfig.py    |  5 ++
 src/delphi/train/utils.py         | 51 ++++++--------------
 3 files changed, 97 insertions(+), 36 deletions(-)
 create mode 100644 src/delphi/train/architectures.py

diff --git a/src/delphi/train/architectures.py b/src/delphi/train/architectures.py
new file mode 100644
index 00000000..83453151
--- /dev/null
+++ b/src/delphi/train/architectures.py
@@ -0,0 +1,77 @@
+import os
+from dataclasses import fields
+
+import torch
+from llama2c import model_export
+from llama2c.model import ModelArgs as Llama2ModelArgs
+from llama2c.model import Transformer as Llama2Model
+
+
+class ModelTypes:
+    LLAMA2C = "llama2c"
+    MAMBA = "mamba"
+
+
+args_to_load_from_checkpoint = {
+    ModelTypes.LLAMA2C: [
+        "dim",
+        "n_layers",
+        "n_heads",
+        "n_kv_heads",
+        "vocab_size",
+        "multiple_of",
+        "max_seq_len",
+    ],
+    ModelTypes.MAMBA: [
+        "n_layers",
+        "model_dim",
+    ],
+}
+
+
+def initialize_model(**model_args) -> torch.nn.Module:
+    if model_args["architecture"] == ModelTypes.LLAMA2C:
+        llama_model_args = model_args.copy()
+        # filter model_args for fields in Llama2ModelArgs
+        llama2_arg_names = {f.name for f in fields(Llama2ModelArgs)}
+        llama2_args = {
+            k: v for k, v in llama_model_args.items() if k in llama2_arg_names
+        }
+        return Llama2Model(Llama2ModelArgs(**llama2_args))
+    else:
+        raise NotImplementedError(
+            f"Architecture {model_args['architecture']} not yet implemented"
+        )
+
+
+def load_model(model_args, checkpoint) -> torch.nn.Module:
+    arch = model_args["architecture"]
+    checkpoint_model_args = checkpoint["model_args"]
+    for k in args_to_load_from_checkpoint[arch]:
+        model_args[k] = checkpoint_model_args[k]
+    if arch == ModelTypes.LLAMA2C:
+        # create the model
+        gptconf = Llama2ModelArgs(**model_args)
+        model = Llama2Model(gptconf)
+        state_dict = checkpoint["model"]
+        # fix the keys of the state dictionary :(
+        # honestly no idea how checkpoints sometimes get this prefix, have to debug more
+        unwanted_prefix = "_orig_mod."
+        for k, v in list(state_dict.items()):
+            if k.startswith(unwanted_prefix):
+                state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
+        model.load_state_dict(state_dict)
+        return model
+    else:
+        raise NotImplementedError(f"Architecture {arch} not yet implemented")
+
+
+def export_model(model, model_architecture, output_path):
+    if model_architecture == ModelTypes.LLAMA2C:
+        model_export(
+            model,
+            output_path,
+            version=0,
+        )
+    else:
+        raise NotImplementedError("only llama2c model export is supported for now")
diff --git a/src/delphi/train/gigaconfig.py b/src/delphi/train/gigaconfig.py
index 024d82b1..095157b2 100644
--- a/src/delphi/train/gigaconfig.py
+++ b/src/delphi/train/gigaconfig.py
@@ -1,6 +1,8 @@
 from dataclasses import dataclass
 from datetime import datetime
 
+from delphi.train.architectures import ModelTypes
+
 
 @dataclass
 class GigaConfig:
@@ -9,6 +11,9 @@ class GigaConfig:
     into several smaller configs.
     """
 
+    # model architecture
+    architecture = ModelTypes.LLAMA2C
+
     # I/O
     out_dir: str = "out"
     eval_interval: int = 2000
diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py
index 523bc7e8..a7381a99 100644
--- a/src/delphi/train/utils.py
+++ b/src/delphi/train/utils.py
@@ -6,15 +6,18 @@
 from typing import Any, cast
 
 import torch
-from llama2c import model_export
-from llama2c.model import ModelArgs as Llama2ModelArgs
-from llama2c.model import Transformer as Llama2Model
 from torch import Tensor
 from torch.optim import AdamW
 from torch.utils.data import DataLoader, Dataset
 
 from delphi import constants
 from delphi.eval.utils import load_delphi_dataset
+from delphi.train.architectures import (
+    ModelTypes,
+    export_model,
+    initialize_model,
+    load_model,
+)
 from delphi.train.gigaconfig import GigaConfig
 from delphi.train.tokenized_chunks_dataset import TokenizedChunksDataset
 
@@ -37,43 +40,16 @@ def get_device() -> str:
 @dataclass
 class ModelMidTrain:
     # hack for packing the values touched by resume_model in a single object
-    model: Llama2Model
+    model: torch.nn.Module
     iter_num: int
     best_val_loss: float
     checkpoint: Any
 
 
-def initialize_model(**model_args) -> Llama2Model:
-    return Llama2Model(Llama2ModelArgs(**model_args))
-
-
 def resume_model(resume_from_path: Path, device: str, **model_args) -> ModelMidTrain:
     ckpt_path = resume_from_path / "ckpt.pt"
     checkpoint = torch.load(ckpt_path, map_location=device)
-    checkpoint_model_args = checkpoint["model_args"]
-    # force these config attributes to be equal otherwise we can't even resume training
-    # the rest of the attributes (e.g. dropout) can stay as desired from command line
-    for k in [
-        "dim",
-        "n_layers",
-        "n_heads",
-        "n_kv_heads",
-        "vocab_size",
-        "multiple_of",
-        "max_seq_len",
-    ]:
-        model_args[k] = checkpoint_model_args[k]
-    # create the model
-    gptconf = Llama2ModelArgs(**model_args)
-    model = Llama2Model(gptconf)
-    state_dict = checkpoint["model"]
-    # fix the keys of the state dictionary :(
-    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
-    unwanted_prefix = "_orig_mod."
-    for k, v in list(state_dict.items()):
-        if k.startswith(unwanted_prefix):
-            state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
-    model.load_state_dict(state_dict)
+    model = load_model(model_args, checkpoint)
     iter_num = checkpoint["iter_num"]
     best_val_loss = checkpoint["best_val_loss"]
     return ModelMidTrain(
@@ -85,7 +61,7 @@ def resume_model(resume_from_path: Path, device: str, **model_args) -> ModelMidT
 
 
 def get_optimizer(
-    model: Llama2Model,
+    model: torch.nn.Module,
     config: GigaConfig,
     device: str,
     checkpoint=None,
@@ -104,7 +80,7 @@ def get_optimizer(
 
 @torch.no_grad()
 def estimate_loss(
-    model: Llama2Model,
+    model: torch.nn.Module,
     eval_iters: int,
     batch_size: int,
     split_to_ds: dict[str, Dataset],
@@ -193,8 +169,10 @@ def save_checkpoint_if_needed(eval_data: EvalData):
     }
     print(f"saving checkpoint to {eval_data.config.out_dir}")
    torch.save(checkpoint, os.path.join(eval_data.config.out_dir, "ckpt.pt"))
-    model_export(
-        eval_data.model, os.path.join(eval_data.config.out_dir, "model.bin"), version=0
+    export_model(
+        eval_data.model,
+        eval_data.model_args["architecture"],
+        os.path.join(eval_data.config.out_dir, "model.bin"),
     )
 
 
@@ -211,6 +189,7 @@ def load_model_training_state(config: GigaConfig, device: str) -> ModelTrainingS
     iter_num = 0
     best_val_loss = 1e9
     model_args = dict(
+        architecture=config.architecture,
         dim=config.dim,
         n_layers=config.n_layers,
         n_heads=config.n_heads,