diff --git a/src/delphi/train/train_step.py b/src/delphi/train/train_step.py
index 0fb55ff1..5956116c 100644
--- a/src/delphi/train/train_step.py
+++ b/src/delphi/train/train_step.py
@@ -77,12 +77,7 @@ def train_step(
         f"gradient accumulation steps: {config.gradient_accumulation_steps}, "
         f"num_steps: {iteration_params.num_steps}, iter_num: {model_training_state.iter_num}"
     )
-    for micro_step in range(
-        min(
-            config.gradient_accumulation_steps,
-            iteration_params.num_steps - model_training_state.iter_num + 1,
-        )
-    ):
+    for micro_step in range(config.gradient_accumulation_steps):
         logits = model(X, Y)
         loss = model.last_loss / config.gradient_accumulation_steps
         # immediately async prefetch next batch while model is doing the forward pass on the GPU
diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py
index 96a1f358..583ad351 100644
--- a/src/delphi/train/training.py
+++ b/src/delphi/train/training.py
@@ -1,13 +1,13 @@
 import os
-import time
 from dataclasses import fields
+from typing import cast
 
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
 
 from delphi.train import wandb_utils
-from delphi.train.gigaconfig import GigaConfig, debug_config
+from delphi.train.gigaconfig import GigaConfig
 from delphi.train.iteration_params import set_iteration_params
 from delphi.train.shuffle import shuffle_list
 from delphi.train.train_step import train_step
@@ -31,9 +31,12 @@ def run_training(config: GigaConfig):
 
     # load data
     print("Loading data...")
-    train_ds = load_delphi_training_dataset("train", limit=config.train_sample_limit)
-    validation_ds = load_delphi_training_dataset(
-        "validation", limit=config.val_sample_limit
-    )
+    train_ds = cast(
+        Dataset, load_delphi_training_dataset("train", limit=config.train_sample_limit)
+    )
+    validation_ds = cast(
+        Dataset,
+        load_delphi_training_dataset("validation", limit=config.val_sample_limit),
+    )
 
     # derive iteration params (num_batches, num_steps, etc)
@@ -59,9 +62,17 @@ def run_training(config: GigaConfig):
     print("Starting training...")
     for epoch in range(config.max_epochs):
         sampler = shuffle_list(
-            list(range(len(train_ds))), seed=config.batch_ordering_seed + epoch
+            list(range(len(train_ds))), seed=config.batch_ordering_seed + epoch  # type: ignore
+        )
+        train_batch_iter = iter(
+            DataLoader(
+                train_ds,
+                batch_size=config.batch_size,
+                sampler=sampler,
+                pin_memory=True,
+                drop_last=True,
+            )
         )
-        train_batch_iter = iter(DataLoader(train_ds, batch_size=config.batch_size, sampler=sampler))  # type: ignore
         for _ in tqdm(range(iteration_params.num_steps)):
             breaknow = train_step(
                 model_training_state,