
Commit

Adding mamba implementation
SrGonao committed Mar 8, 2024
1 parent cb653cb commit 1fb39c8
Showing 7 changed files with 115 additions and 32 deletions.
33 changes: 33 additions & 0 deletions scripts/sample_mamba.json
@@ -0,0 +1,33 @@
{
"out_dir": "out",
"eval_interval": 500,
"log_interval": 1,
"eval_iters": 10,
"eval_only": false,
"architecture": "ModelTypes.MAMBA",
"always_save_checkpoint": false,
"init_from": "scratch",
"wandb_log": true,
"wandb_entity": "g-spaulo",
"wandb_project": "delphi",
"wandb_run_name": "2024_03_07_17_43_09",
"batch_size": 64,
"max_seq_len": 512,
"vocab_size": 4096,
"dim": 48,
"n_layers": 2,
"multiple_of": 32,
"dropout": 0.0,
"gradient_accumulation_steps": 4,
"learning_rate": 0.0005,
"max_epochs": 2,
"weight_decay": 0.1,
"beta1": 0.9,
"beta2": 0.95,
"grad_clip": 1.0,
"decay_lr": true,
"warmup_iters": 1000,
"min_lr": 0.0,
"train_sample_limit": 256,
"val_sample_limit": -1
}
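
A quick way to sanity-check this config before training: the sketch below assumes load_config in scripts/train.py amounts to a plain JSON read (the real helper may do more) and only inspects the fields the Mamba path cares about; it is illustrative, not part of this commit.

import json

# Illustrative sketch: read the sample config and print the Mamba-relevant knobs.
with open("scripts/sample_mamba.json") as f:
    config = json.load(f)

print(config["architecture"])                                    # "ModelTypes.MAMBA"
print(config["dim"], config["n_layers"], config["vocab_size"])   # 48 2 4096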
9 changes: 5 additions & 4 deletions scripts/train.py
@@ -1,7 +1,8 @@
from delphi.train.training import DDP,TrainingConfig, model_initialization, train_loop
from delphi.train.utils import load_config
from argparse import ArgumentParser

from delphi.train.training import DDP, TrainingConfig, model_initialization, train_loop
from delphi.train.utils import load_config


def main():
parser = ArgumentParser()
@@ -11,5 +12,5 @@ def main():

config = load_config(args.config)
TrainConf = TrainingConfig(config)
model,model_args = model_initialization(config)
train_loop(model, TrainConf)
model, model_args = model_initialization(config)
train_loop(model, TrainConf)
1 change: 0 additions & 1 deletion src/delphi/llama2.py
@@ -9,6 +9,5 @@ class LLaMA2Args(ModelArgs):


class LLaMA2(Transformer):

def __init__(self, params) -> None:
super().__init__(params)
11 changes: 7 additions & 4 deletions src/delphi/mamba.py
@@ -1,15 +1,16 @@
from dataclasses import dataclass

import torch.nn.functional as F
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import torch.nn.functional as F


@dataclass
class MambaArgs(MambaConfig):
pass


class Mamba(MambaLMHeadModel):

def __init__(self, params) -> None:
super().__init__(params)

@@ -20,6 +21,8 @@ def forward(self, input_ids, target_ids=None):
"""
hidden_states = self.backbone(input_ids)
logits = self.lm_head(hidden_states)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_ids.view(-1), ignore_index=-1)

self.last_loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), target_ids.view(-1), ignore_index=-1
)

return logits
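
A usage sketch for the wrapper above, assuming mamba_ssm and its CUDA kernels are installed and a GPU is available (MambaLMHeadModel's defaults rely on them); the sizes mirror scripts/sample_mamba.json and are illustrative only.

import torch

from delphi.mamba import Mamba, MambaArgs

# Illustrative sketch: build a tiny model with the dimensions from sample_mamba.json.
config = MambaArgs()
config.d_model = 48
config.n_layer = 2
config.vocab_size = 4096

model = Mamba(config).to("cuda")
input_ids = torch.randint(0, 4096, (2, 16), device="cuda")  # (batch, seq_len)
target_ids = input_ids.clone()

logits = model(input_ids, target_ids)  # forward() also stores model.last_loss
print(logits.shape, model.last_loss.item())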
20 changes: 18 additions & 2 deletions src/delphi/train/architectures.py
@@ -5,6 +5,8 @@
from llama2c.model import ModelArgs as Llama2ModelArgs
from llama2c.model import Transformer as Llama2Model

from delphi.train.mamba import Mamba, MambaArgs


class ModelTypes:
LLAMA2C = "llama2c"
@@ -35,6 +37,12 @@ def initialize_model(**model_args) -> torch.nn.Module:
llama2_arg_names = {f.name for f in fields(Llama2ModelArgs)}
llama2_args = {k: v for k, v in model_args.items() if k in llama2_arg_names}
return Llama2Model(Llama2ModelArgs(**llama2_args))
elif model_args["architecture"] == ModelTypes.MAMBA:
config = MambaArgs()
config.d_model = model_args["model_dim"]
config.vocab_size = model_args["vocab_size"]
config.n_layer = model_args["n_layers"]
return Mamba(config)
else:
raise NotImplementedError(
f"Architecture {model_args['architecture']} not yet implemented"
@@ -59,8 +67,16 @@ def load_model(model_args, checkpoint) -> torch.nn.Module:
state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
model.load_state_dict(state_dict)
return model
else:
raise NotImplementedError(f"Architecture {arch} not yet implemented")
if arch == ModelTypes.MAMBA:
config = MambaArgs()
config.d_model = model_args["model_dim"]
config.vocab_size = model_args["vocab_size"]
config.n_layer = model_args["n_layers"]
state_dict = checkpoint["model"]
model = Mamba(config)
model.load_state_dict(state_dict)
return model
raise NotImplementedError(f"Architecture {arch} not yet implemented")


def export_model(model, model_architecture, output_path):
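
The two MAMBA branches above can be exercised together with an in-memory round trip. This is a sketch under stated assumptions: ModelTypes.MAMBA is defined alongside LLAMA2C, load_model derives arch from model_args["architecture"], and model_args carries the model_dim / vocab_size / n_layers keys the branches read; none of that is shown in full in this diff.

from delphi.train.architectures import ModelTypes, initialize_model, load_model

# Illustrative sketch: hypothetical model_args mirroring the keys the MAMBA branches read.
model_args = {
    "architecture": ModelTypes.MAMBA,
    "model_dim": 48,
    "vocab_size": 4096,
    "n_layers": 2,
}

model = initialize_model(**model_args)
checkpoint = {"model": model.state_dict(), "model_args": model_args}
restored = load_model(model_args, checkpoint)  # rebuilds the config, then load_state_dict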
4 changes: 4 additions & 0 deletions src/delphi/train/mamba.py
@@ -30,3 +30,7 @@ def forward(
)

return logits

def estimate_mfu(self, fwdbwd_per_iter, dt):
"""I don't want to implement this"""
return 0
69 changes: 48 additions & 21 deletions src/delphi/train/training_old.py
@@ -24,13 +24,11 @@
from functools import partial

import torch
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

from llama2 import LLaMA2, LLaMA2Args
from llama2c import model_export, Task

from llama2c import Task, model_export
from shuffle import shuffle_epoch
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

# -----------------------------------------------------------------------------
# I/O
@@ -48,8 +46,10 @@
# data
batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
max_seq_len = 256
vocab_source = "llama2" # llama2|custom; use Lllama 2 vocab from Meta, or custom trained
vocab_size = 32000 # the Llama 2 tokenizer has 32K tokens
vocab_source = (
"llama2" # llama2|custom; use Lllama 2 vocab from Meta, or custom trained
)
vocab_size = 32000 # the Llama 2 tokenizer has 32K tokens
# model
dim = 288
n_layers = 6
@@ -69,7 +69,9 @@
decay_lr = True # whether to decay the learning rate
warmup_iters = 1000 # how many steps to warm up for
# system
device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
device = (
"cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
)
dtype = "bfloat16" # float32|bfloat16|float16
compile = True # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
@@ -90,7 +92,9 @@

# validating checks
assert vocab_source in ["llama2", "custom"]
assert vocab_source == "custom" or vocab_size == 32000, "The vocab from Meta has 32K tokens"
assert (
vocab_source == "custom" or vocab_size == 32000
), "The vocab from Meta has 32K tokens"

# various inits, derived attributes, I/O setup
seed = 1337
@@ -113,10 +117,14 @@
master_process = True
seed_offset = 0
ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
tokens_per_iter = (
gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
)
if master_process:
print(f"tokens per iteration will be: {tokens_per_iter:,}")
print(f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len")
print(
f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len"
)

if master_process:
os.makedirs(out_dir, exist_ok=True)
@@ -125,7 +133,11 @@
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
ptdtype = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}[dtype]
ctx = (
nullcontext()
if device_type == "cpu"
@@ -141,7 +153,7 @@
vocab_source=vocab_source,
device=device,
num_workers=0,
seed=seed
seed=seed,
)

# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
@@ -172,7 +184,15 @@
checkpoint_model_args = checkpoint["model_args"]
# force these config attributes to be equal otherwise we can't even resume training
# the rest of the attributes (e.g. dropout) can stay as desired from command line
for k in ["dim", "n_layers", "n_heads", "n_kv_heads", "vocab_size", "multiple_of", "max_seq_len"]:
for k in [
"dim",
"n_layers",
"n_heads",
"n_kv_heads",
"vocab_size",
"multiple_of",
"max_seq_len",
]:
model_args[k] = checkpoint_model_args[k]
# create the model
gptconf = LLaMA2Args(**model_args)
@@ -193,7 +213,9 @@
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
optimizer = model.configure_optimizers(
weight_decay, learning_rate, (beta1, beta2), device_type
)
if init_from == "resume" and "optimizer" in checkpoint:
optimizer.load_state_dict(checkpoint["optimizer"])
checkpoint = None # free up memory
@@ -212,6 +234,7 @@
model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
model = DDP(model, device_ids=[ddp_local_rank])


# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
@@ -230,6 +253,7 @@ def estimate_loss():
model.train()
return out


# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
# 1) linear warmup for warmup_iters steps
@@ -245,13 +269,11 @@ def get_lr(it):
return min_lr + coeff * (learning_rate - min_lr)



# logging
if wandb_log and master_process:
import wandb
wandb.init(project=wandb_project, name=wandb_run_name, config=config)


wandb.init(project=wandb_project, name=wandb_run_name, config=config)


# training loop
@@ -270,7 +292,9 @@ def get_lr(it):
# evaluate the loss on train/val sets and write checkpoints
if iter_num % eval_interval == 0 and master_process:
losses = estimate_loss()
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
print(
f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
)
if wandb_log:
try:
wandb.log(
@@ -281,7 +305,8 @@ def get_lr(it):
"loss/val": losses["val"],
"lr": lr,
"mfu": running_mfu * 100, # convert to percentage
}, step = iter_num
},
step=iter_num,
)
except Exception as e:
print(f"logging to wandb failed: {e}")
@@ -310,7 +335,9 @@ def get_lr(it):
# the official way to do this is with model.no_sync() context manager, but
# I really dislike that this bloats the code and forces us to repeat code
# looking at the source of that context manager, it just toggles this variable
model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
model.require_backward_grad_sync = (
micro_step == gradient_accumulation_steps - 1
)
with ctx:
logits = model(X, Y)
loss = raw_model.last_loss
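
For comparison with the require_backward_grad_sync toggle kept in the hunk above, here is the same gradient-accumulation loop written with DDP's documented no_sync() context manager. It is a sketch that reuses names defined elsewhere in training_old.py (model, ctx, scaler, raw_model, gradient_accumulation_steps) and assumes model is DDP-wrapped; next_batch() is a hypothetical stand-in for however the script fetches the next (X, Y) micro-batch.

from contextlib import nullcontext

# Illustrative sketch: suppress the gradient all-reduce on every micro-step except the last.
for micro_step in range(gradient_accumulation_steps):
    X, Y = next_batch()
    is_last = micro_step == gradient_accumulation_steps - 1
    sync_ctx = nullcontext() if is_last else model.no_sync()
    with sync_ctx:
        with ctx:  # autocast (or nullcontext on CPU), as set up earlier in the script
            logits = model(X, Y)
            loss = raw_model.last_loss / gradient_accumulation_steps
        scaler.scale(loss).backward()  # no_sync() must be active during backward()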
