Skip to content

Commit

Permalink
bughunt
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidhyani committed Mar 8, 2024
1 parent a8f7143 commit e038f31
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
7 changes: 1 addition & 6 deletions src/delphi/train/train_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,7 @@ def train_step(
f"gradient accumulation steps: {config.gradient_accumulation_steps}, "
f"num_steps: {iteration_params.num_steps}, iter_num: {model_training_state.iter_num}"
)
for micro_step in range(
min(
config.gradient_accumulation_steps,
iteration_params.num_steps - model_training_state.iter_num + 1,
)
):
for micro_step in range(config.gradient_accumulation_steps):
logits = model(X, Y)
loss = model.last_loss / config.gradient_accumulation_steps
# immediately async prefetch next batch while model is doing the forward pass on the GPU
Expand Down
27 changes: 19 additions & 8 deletions src/delphi/train/training.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import time
from dataclasses import fields
from typing import cast

import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from delphi.train import wandb_utils
from delphi.train.gigaconfig import GigaConfig, debug_config
from delphi.train.gigaconfig import GigaConfig
from delphi.train.iteration_params import set_iteration_params
from delphi.train.shuffle import shuffle_list
from delphi.train.train_step import train_step
Expand All @@ -31,9 +31,12 @@ def run_training(config: GigaConfig):

# load data
print("Loading data...")
train_ds = load_delphi_training_dataset("train", limit=config.train_sample_limit)
validation_ds = load_delphi_training_dataset(
"validation", limit=config.val_sample_limit
train_ds = cast(
Dataset, load_delphi_training_dataset("train", limit=config.train_sample_limit)
)
validation_ds = cast(
Dataset,
load_delphi_training_dataset("validation", limit=config.val_sample_limit),
)

# derive iteration params (num_batches, num_steps, etc)
Expand All @@ -59,9 +62,17 @@ def run_training(config: GigaConfig):
print("Starting training...")
for epoch in range(config.max_epochs):
sampler = shuffle_list(
list(range(len(train_ds))), seed=config.batch_ordering_seed + epoch
list(range(len(train_ds))), seed=config.batch_ordering_seed + epoch # type: ignore
)
train_batch_iter = iter(
DataLoader(
train_ds,
batch_size=config.batch_size,
sampler=sampler,
pin_memory=True,
drop_last=True,
)
)
train_batch_iter = iter(DataLoader(train_ds, batch_size=config.batch_size, sampler=sampler)) # type: ignore
for _ in tqdm(range(iteration_params.num_steps)):
breaknow = train_step(
model_training_state,
Expand Down

0 comments on commit e038f31

Please sign in to comment.