
Commit

fixing black and isort
SrGonao committed Mar 2, 2024
1 parent 23db4ca commit 3c4800e
Showing 3 changed files with 18 additions and 15 deletions.
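
Note: the churn below is what isort (import ordering) and black (code style) produce when run over the source tree. A minimal sketch using the two tools' public Python APIs, applied to a made-up snippet with default settings (the repository's actual formatter configuration is not shown in this commit):

# Illustration only: format a made-up snippet the way isort and black would.
import black
import isort

source = "from typing import Optional\nimport torch\ndef f(x:int)->int: return x\n"

sorted_source = isort.code(source)  # group and order the imports
formatted = black.format_str(sorted_source, mode=black.Mode())  # normalize spacing and wrapping
print(formatted)

On disk, the equivalent is typically running "python -m isort ." followed by "python -m black ." from the repository root.
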
5 changes: 5 additions & 0 deletions .gitignore
@@ -6,6 +6,11 @@ __pycache__/
# C extensions
*.so

+bin
+include
+lib64
+pyvenv.cfg
+
# Distribution / packaging
.Python
build/
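
The newly ignored paths (bin, include, lib64, pyvenv.cfg) are the directories and config file a Python virtual environment drops into its target directory, which in this case was evidently the repository root itself. An illustrative sketch with the standard-library venv module (the target name is arbitrary and nothing here comes from the commit):

# Illustration only: a virtual environment leaves bin/, include/, lib*/
# (lib64 is often a symlink on Linux) and pyvenv.cfg inside its target directory.
import venv

venv.create("demo-env", with_pip=True)
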
11 changes: 7 additions & 4 deletions src/delphi/train/mamba.py
@@ -1,21 +1,24 @@
from dataclasses import dataclass
+from typing import Optional
+
+import torch
import torch.nn.functional as F
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
-import torch
-from typing import Optional


@dataclass
class MambaArgs(MambaConfig):
    pass


class Mamba(MambaLMHeadModel):
-    def __init__(self, params:MambaArgs) -> None:
+    def __init__(self, params: MambaArgs) -> None:
        super().__init__(params)

-    def forward(self, input_ids:torch.Tensor, target_ids:Optional[torch.Tensor] = None)-> torch.Tensor:
+    def forward(
+        self, input_ids: torch.Tensor, target_ids: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
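
A hypothetical usage sketch of the wrapper reformatted above: the import path mirrors the file location in this diff, the MambaConfig field names (d_model, n_layer, vocab_size) come from the mamba_ssm package, and the sizes are placeholders rather than values from this repository; actually running the forward pass requires mamba_ssm's CUDA kernels.

# Hypothetical usage only; sizes are placeholders, not repository settings.
import torch

from delphi.train.mamba import Mamba, MambaArgs

args = MambaArgs(d_model=256, n_layer=4, vocab_size=4096)
model = Mamba(args)

input_ids = torch.randint(0, args.vocab_size, (1, 128))
output = model(input_ids)  # forward(input_ids, target_ids=None); the body is truncated in the diff above
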
17 changes: 6 additions & 11 deletions src/delphi/train/training.py
@@ -12,8 +12,9 @@
import torch
from tqdm import tqdm

-from llama2c.model import ModelArgs as Llama2ModelArgs, Transformer as Llama2Model
from llama2c import Task, model_export
+from llama2c.model import ModelArgs as Llama2ModelArgs
+from llama2c.model import Transformer as Llama2Model

# -----------------------------------------------------------------------------
# I/O
@@ -103,9 +104,7 @@

# various inits, derived attributes, I/O setup
seed = 1337
-tokens_per_iter = (
-    gradient_accumulation_steps * batch_size * max_seq_len
-)
+tokens_per_iter = gradient_accumulation_steps * batch_size * max_seq_len
print(f"tokens per iteration will be: {tokens_per_iter:,}")
print(
    f"breaks down as: {gradient_accumulation_steps} grad accum steps * {batch_size} batch size * {max_seq_len} max seq len"
@@ -200,6 +199,7 @@

# wrap model into DDP container

+
# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
@@ -289,16 +289,13 @@ def get_lr(it):
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
-                model_export(
-                    model, os.path.join(out_dir, "model.bin"), version=0
-                )
+                model_export(model, os.path.join(out_dir, "model.bin"), version=0)
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
-
        logits = model(X, Y)
        loss = model.last_loss
        loss = loss / gradient_accumulation_steps
@@ -322,9 +319,7 @@ def get_lr(it):
        # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5:  # let the training loop settle a bit
-            mfu = model.estimate_mfu(
-                batch_size * gradient_accumulation_steps, dt
-            )
+            mfu = model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = (
                mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
            )
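
The reformatted lines above sit inside llama2.c-style training logic: the loss is divided by gradient_accumulation_steps inside the micro-step loop, multiplied back for logging, and the MFU estimate is smoothed with an exponential moving average. A self-contained toy sketch of that pattern (model, data, and metric are stand-ins, not code from training.py):

# Toy sketch of the accumulation/logging pattern; nothing here is verbatim from training.py.
import torch
import torch.nn.functional as F

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

gradient_accumulation_steps = 4
batch_size = 8
max_seq_len = 32
tokens_per_iter = gradient_accumulation_steps * batch_size * max_seq_len  # 1,024, computed as in the diff

running_mfu = -1.0
for micro_step in range(gradient_accumulation_steps):
    X, Y = torch.randn(batch_size, 16), torch.randn(batch_size, 1)
    loss = F.mse_loss(model(X), Y)
    loss = loss / gradient_accumulation_steps  # scale so the accumulated grads match one large batch
    loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

lossf = loss.item() * gradient_accumulation_steps  # undo the division for logging
mfu = lossf  # stand-in for model.estimate_mfu(...) in the real script
running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
print(f"tokens per iteration: {tokens_per_iter:,}  running metric: {running_mfu:.4f}")
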
