diff --git a/src/delphi/train/checkpoint_step.py b/src/delphi/train/checkpoint_step.py index afd84702..f26b2617 100644 --- a/src/delphi/train/checkpoint_step.py +++ b/src/delphi/train/checkpoint_step.py @@ -38,7 +38,6 @@ def log_and_save_checkpoint( batch_size=config.batch_size, split_to_ds={"train": train_ds, "val": validation_ds}, device=run_context.device, - epoch=mts.epoch, feature_name=config.dataset.feature, ) logging.info( diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py index 7db8dfef..e1d65adc 100644 --- a/src/delphi/train/training.py +++ b/src/delphi/train/training.py @@ -8,13 +8,14 @@ from tqdm import tqdm from transformers import AutoTokenizer +from delphi.train.shuffle import shuffle_epoch + from .checkpoint_step import log_and_save_checkpoint, should_save_checkpoint from .config import TrainingConfig from .run_context import RunContext from .train_step import train_step from .utils import ( ModelTrainingState, - get_indices_for_epoch, initialize_model_training_state, set_lr, setup_determinism, @@ -70,12 +71,10 @@ def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext # training loop logging.info("Starting training...") for epoch in range(config.max_epochs): - logging.info(f"Epoch: {epoch} / {config.max_epochs - 1}") - train_data_indices = get_indices_for_epoch( - dataset_size=len(train_ds), - batch_size=config.batch_size, - epoch=epoch, - ordering_seed=config.batch_ordering_seed, + logging.info(f"Epoch: {epoch+1} / {config.max_epochs}") + train_data_indices = list(range(len(train_ds))) + shuffle_epoch( + train_data_indices, seed=config.batch_ordering_seed, epoch_nr=epoch ) model_training_state.epoch = epoch for step in tqdm(range(steps_per_epoch)): diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py index e0d5777d..4eadb2c6 100644 --- a/src/delphi/train/utils.py +++ b/src/delphi/train/utils.py @@ -131,15 +131,6 @@ def initialize_model_training_state( ) -def get_indices_for_epoch( - dataset_size: int, batch_size: int, epoch: int, ordering_seed: int -) -> list[int]: - """ """ - indices = list(range(dataset_size)) - shuffle_list(indices, seed=ordering_seed + epoch) - return indices - - def gen_minibatches( dataset: Dataset, batch_size: int, @@ -168,19 +159,13 @@ def estimate_loss( batch_size: int, split_to_ds: dict[str, Dataset], device: torch.device, - epoch: int, feature_name: str, ) -> dict[str, float]: """helps estimate an arbitrarily accurate loss over either split using many batches""" out = {} model.eval() for split, ds in split_to_ds.items(): - indices = get_indices_for_epoch( - dataset_size=len(ds), - batch_size=batch_size, - epoch=epoch, - ordering_seed=1234, - ) + indices = list(range(len(ds))) eval_iters = min(eval_iters, len(ds) // batch_size) losses = torch.zeros(eval_iters) # keep on CPU minibatches = gen_minibatches(