From 1fb39c876d17b4234e9181a12aee88324edcb18c Mon Sep 17 00:00:00 2001
From: SrGonao
Date: Fri, 8 Mar 2024 21:07:16 +0100
Subject: [PATCH] Adding mamba implementation

---
 scripts/sample_mamba.json         | 33 +++++++++++++++
 scripts/train.py                  |  9 ++--
 src/delphi/llama2.py              |  1 -
 src/delphi/mamba.py               | 11 +++--
 src/delphi/train/architectures.py | 20 ++++++++-
 src/delphi/train/mamba.py         |  4 ++
 src/delphi/train/training_old.py  | 69 +++++++++++++++++++++----------
 7 files changed, 115 insertions(+), 32 deletions(-)
 create mode 100644 scripts/sample_mamba.json

diff --git a/scripts/sample_mamba.json b/scripts/sample_mamba.json
new file mode 100644
index 00000000..d64524c0
--- /dev/null
+++ b/scripts/sample_mamba.json
@@ -0,0 +1,33 @@
+{
+    "out_dir": "out",
+    "eval_interval": 500,
+    "log_interval": 1,
+    "eval_iters": 10,
+    "eval_only": false,
+    "architecture": "ModelTypes.MAMBA",
+    "always_save_checkpoint": false,
+    "init_from": "scratch",
+    "wandb_log": true,
+    "wandb_entity": "g-spaulo",
+    "wandb_project": "delphi",
+    "wandb_run_name": "2024_03_07_17_43_09",
+    "batch_size": 64,
+    "max_seq_len": 512,
+    "vocab_size": 4096,
+    "dim": 48,
+    "n_layers": 2,
+    "multiple_of": 32,
+    "dropout": 0.0,
+    "gradient_accumulation_steps": 4,
+    "learning_rate": 0.0005,
+    "max_epochs": 2,
+    "weight_decay": 0.1,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "grad_clip": 1.0,
+    "decay_lr": true,
+    "warmup_iters": 1000,
+    "min_lr": 0.0,
+    "train_sample_limit": 256,
+    "val_sample_limit": -1
+}
\ No newline at end of file
diff --git a/scripts/train.py b/scripts/train.py
index 03ffc77b..cab173bf 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1,7 +1,8 @@
-from delphi.train.training import DDP,TrainingConfig, model_initialization, train_loop
-from delphi.train.utils import load_config
 from argparse import ArgumentParser
 
+from delphi.train.training import DDP, TrainingConfig, model_initialization, train_loop
+from delphi.train.utils import load_config
+
 
 def main():
     parser = ArgumentParser()
@@ -11,5 +12,5 @@ def main():
 
     config = load_config(args.config)
     TrainConf = TrainingConfig(config)
-    model,model_args = model_initialization(config)
-    train_loop(model, TrainConf)
\ No newline at end of file
+    model, model_args = model_initialization(config)
+    train_loop(model, TrainConf)
diff --git a/src/delphi/llama2.py b/src/delphi/llama2.py
index ca204298..5d8e8625 100644
--- a/src/delphi/llama2.py
+++ b/src/delphi/llama2.py
@@ -9,6 +9,5 @@ class LLaMA2Args(ModelArgs):
 
 
 class LLaMA2(Transformer):
-
     def __init__(self, params) -> None:
         super().__init__(params)
diff --git a/src/delphi/mamba.py b/src/delphi/mamba.py
index 32257779..87710f6f 100644
--- a/src/delphi/mamba.py
+++ b/src/delphi/mamba.py
@@ -1,7 +1,9 @@
 from dataclasses import dataclass
+
+import torch.nn.functional as F
 from mamba_ssm.models.config_mamba import MambaConfig
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
-import torch.nn.functional as F
 
+
 @dataclass
 class MambaArgs(MambaConfig):
@@ -9,7 +11,6 @@ class MambaArgs(MambaConfig):
 
 
 class Mamba(MambaLMHeadModel):
-
     def __init__(self, params) -> None:
         super().__init__(params)
 
@@ -20,6 +21,8 @@ def forward(self, input_ids, target_ids=None):
         """
         hidden_states = self.backbone(input_ids)
         logits = self.lm_head(hidden_states)
-        self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_ids.view(-1), ignore_index=-1)
-
+        self.last_loss = F.cross_entropy(
+            logits.view(-1, logits.size(-1)), target_ids.view(-1), ignore_index=-1
+        )
+
         return logits
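Illustrative usage sketch (not part of the diff): how the Mamba wrapper added in src/delphi/mamba.py above is expected to be used, assuming mamba_ssm is installed and a CUDA device is available for its fused kernels. The dimensions and the shifted targets are examples only; note that target_ids must be supplied, since forward() computes the loss unconditionally and stashes it in .last_loss, mirroring the llama2c Transformer.

    import torch

    from delphi.mamba import Mamba, MambaArgs

    # MambaArgs subclasses mamba_ssm's MambaConfig, so d_model / n_layer /
    # vocab_size are set as plain attributes, as architectures.py does below.
    config = MambaArgs()
    config.d_model = 48
    config.n_layer = 2
    config.vocab_size = 4096

    model = Mamba(config).cuda()
    x = torch.randint(0, config.vocab_size, (4, 512), device="cuda")  # token batch
    y = torch.roll(x, shifts=-1, dims=1)  # illustrative next-token targets
    logits = model(x, target_ids=y)       # (batch, seq_len, vocab_size)
    loss = model.last_loss                # cross-entropy computed inside forward()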
diff --git a/src/delphi/train/architectures.py b/src/delphi/train/architectures.py
index 134bfde3..93f68d3b 100644
--- a/src/delphi/train/architectures.py
+++ b/src/delphi/train/architectures.py
@@ -5,6 +5,8 @@
 from llama2c.model import ModelArgs as Llama2ModelArgs
 from llama2c.model import Transformer as Llama2Model
 
+from delphi.train.mamba import Mamba, MambaArgs
+
 
 class ModelTypes:
     LLAMA2C = "llama2c"
@@ -35,6 +37,12 @@ def initialize_model(**model_args) -> torch.nn.Module:
         llama2_arg_names = {f.name for f in fields(Llama2ModelArgs)}
         llama2_args = {k: v for k, v in model_args.items() if k in llama2_arg_names}
         return Llama2Model(Llama2ModelArgs(**llama2_args))
+    elif model_args["architecture"] == ModelTypes.MAMBA:
+        config = MambaArgs()
+        config.d_model = model_args["model_dim"]
+        config.vocab_size = model_args["vocab_size"]
+        config.n_layer = model_args["n_layers"]
+        return Mamba(config)
     else:
         raise NotImplementedError(
             f"Architecture {model_args['architecture']} not yet implemented"
@@ -59,8 +67,16 @@ def load_model(model_args, checkpoint) -> torch.nn.Module:
                 state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
         model.load_state_dict(state_dict)
         return model
-    else:
-        raise NotImplementedError(f"Architecture {arch} not yet implemented")
+    if arch == ModelTypes.MAMBA:
+        config = MambaArgs()
+        config.d_model = model_args["model_dim"]
+        config.vocab_size = model_args["vocab_size"]
+        config.n_layer = model_args["n_layers"]
+        state_dict = checkpoint["model"]
+        model = Mamba(config)
+        model.load_state_dict(state_dict)
+        return model
+    raise NotImplementedError(f"Architecture {arch} not yet implemented")
 
 
 def export_model(model, model_architecture, output_path):
diff --git a/src/delphi/train/mamba.py b/src/delphi/train/mamba.py
index 04e8154e..eb222e43 100644
--- a/src/delphi/train/mamba.py
+++ b/src/delphi/train/mamba.py
@@ -30,3 +30,7 @@ def forward(
         )
 
         return logits
+
+    def estimate_mfu(self, fwdbwd_per_iter, dt):
+        """MFU estimation is not implemented for Mamba; return 0 as a placeholder."""
+        return 0
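Illustrative sketch (not part of the diff): the checkpoint round trip the MAMBA branches of initialize_model() / load_model() above are meant to support. The "model_dim" / "vocab_size" / "n_layers" keys and the {"model": state_dict, ...} checkpoint layout follow what the code above reads; the file name and the "model_args" entry are illustrative, and mamba_ssm must be installed.

    import torch

    from delphi.train.architectures import ModelTypes, initialize_model, load_model

    model_args = {
        "architecture": ModelTypes.MAMBA,
        "model_dim": 48,
        "vocab_size": 4096,
        "n_layers": 2,
    }

    model = initialize_model(**model_args)  # builds a MambaArgs config, returns Mamba
    torch.save({"model": model.state_dict(), "model_args": model_args}, "ckpt.pt")

    checkpoint = torch.load("ckpt.pt")
    restored = load_model(model_args, checkpoint)  # Mamba module with weights restored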
diff --git a/src/delphi/train/training_old.py b/src/delphi/train/training_old.py
index d2d5e0e1..e1ec0661 100644
--- a/src/delphi/train/training_old.py
+++ b/src/delphi/train/training_old.py
@@ -24,13 +24,11 @@
 from functools import partial
 
 import torch
-from torch.distributed import destroy_process_group, init_process_group
-from torch.nn.parallel import DistributedDataParallel as DDP
-
 from llama2 import LLaMA2, LLaMA2Args
-from llama2c import model_export, Task
-
+from llama2c import Task, model_export
 from shuffle import shuffle_epoch
+from torch.distributed import destroy_process_group, init_process_group
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 # -----------------------------------------------------------------------------
 # I/O
@@ -48,8 +46,10 @@
 # data
 batch_size = 128  # if gradient_accumulation_steps > 1, this is the micro-batch size
 max_seq_len = 256
-vocab_source = "llama2" # llama2|custom; use Lllama 2 vocab from Meta, or custom trained
-vocab_size = 32000 # the Llama 2 tokenizer has 32K tokens
+vocab_source = (
+    "llama2"  # llama2|custom; use Lllama 2 vocab from Meta, or custom trained
+)
+vocab_size = 32000  # the Llama 2 tokenizer has 32K tokens
 # model
 dim = 288
 n_layers = 6
@@ -69,7 +69,9 @@
 decay_lr = True  # whether to decay the learning rate
 warmup_iters = 1000  # how many steps to warm up for
 # system
-device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
+device = (
+    "cuda"  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
+)
 dtype = "bfloat16"  # float32|bfloat16|float16
 compile = True  # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
@@ -90,7 +92,9 @@
 
 # validating checks
 assert vocab_source in ["llama2", "custom"]
-assert vocab_source == "custom" or vocab_size == 32000, "The vocab from Meta has 32K tokens"
+assert (
+    vocab_source == "custom" or vocab_size == 32000
+), "The vocab from Meta has 32K tokens"
 
 # various inits, derived attributes, I/O setup
 seed = 1337
@@ -113,10 +117,14 @@
     master_process = True
     seed_offset = 0
     ddp_world_size = 1
-tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
+tokens_per_iter = (
+    gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
+)
 if master_process:
     print(f"tokens per iteration will be: {tokens_per_iter:,}")
-    print(f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len")
+    print(
+        f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len"
+    )
 
 if master_process:
     os.makedirs(out_dir, exist_ok=True)
@@ -125,7 +133,11 @@
 torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
 device_type = "cuda" if "cuda" in device else "cpu"  # for later use in torch.autocast
 # note: float16 data type will automatically use a GradScaler
-ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
+ptdtype = {
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+    "float16": torch.float16,
+}[dtype]
 ctx = (
     nullcontext()
     if device_type == "cpu"
@@ -141,7 +153,7 @@
     vocab_source=vocab_source,
     device=device,
     num_workers=0,
-    seed=seed
+    seed=seed,
 )
 
 # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
@@ -172,7 +184,15 @@
     checkpoint_model_args = checkpoint["model_args"]
     # force these config attributes to be equal otherwise we can't even resume training
     # the rest of the attributes (e.g. dropout) can stay as desired from command line
-    for k in ["dim", "n_layers", "n_heads", "n_kv_heads", "vocab_size", "multiple_of", "max_seq_len"]:
+    for k in [
+        "dim",
+        "n_layers",
+        "n_heads",
+        "n_kv_heads",
+        "vocab_size",
+        "multiple_of",
+        "max_seq_len",
+    ]:
         model_args[k] = checkpoint_model_args[k]
     # create the model
     gptconf = LLaMA2Args(**model_args)
@@ -193,7 +213,9 @@
 scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
 
 # optimizer
-optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
+optimizer = model.configure_optimizers(
+    weight_decay, learning_rate, (beta1, beta2), device_type
+)
 if init_from == "resume" and "optimizer" in checkpoint:
     optimizer.load_state_dict(checkpoint["optimizer"])
 checkpoint = None  # free up memory
@@ -212,6 +234,7 @@
     model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
     model = DDP(model, device_ids=[ddp_local_rank])
 
+
 # helps estimate an arbitrarily accurate loss over either split using many batches
 @torch.no_grad()
 def estimate_loss():
@@ -230,6 +253,7 @@ def estimate_loss():
     model.train()
     return out
 
+
 # learning rate decay scheduler (cosine with warmup)
 def get_lr(it):
     # 1) linear warmup for warmup_iters steps
@@ -245,13 +269,11 @@ def get_lr(it):
     return min_lr + coeff * (learning_rate - min_lr)
 
 
-
 # logging
 if wandb_log and master_process:
     import wandb
 
-    wandb.init(project=wandb_project, name=wandb_run_name, config=config)
-
+    wandb.init(project=wandb_project, name=wandb_run_name, config=config)
 
 
 # training loop
@@ -270,7 +292,9 @@ def get_lr(it):
     # evaluate the loss on train/val sets and write checkpoints
     if iter_num % eval_interval == 0 and master_process:
         losses = estimate_loss()
-        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+        print(
+            f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
+        )
         if wandb_log:
             try:
                 wandb.log(
@@ -281,7 +305,8 @@ def get_lr(it):
                         "loss/val": losses["val"],
                         "lr": lr,
                         "mfu": running_mfu * 100,  # convert to percentage
-                    }, step = iter_num
+                    },
+                    step=iter_num,
                 )
             except Exception as e:
                 print(f"logging to wandb failed: {e}")
@@ -310,7 +335,9 @@ def get_lr(it):
             # the official way to do this is with model.no_sync() context manager, but
             # I really dislike that this bloats the code and forces us to repeat code
             # looking at the source of that context manager, it just toggles this variable
-            model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
+            model.require_backward_grad_sync = (
+                micro_step == gradient_accumulation_steps - 1
+            )
         with ctx:
             logits = model(X, Y)
             loss = raw_model.last_loss