use shuffle_epoch
jettjaniak committed May 22, 2024
1 parent 3cb823f commit 12e38cc
Showing 3 changed files with 7 additions and 24 deletions.
src/delphi/train/checkpoint_step.py (1 change: 0 additions & 1 deletion)

@@ -38,7 +38,6 @@ def log_and_save_checkpoint(
         batch_size=config.batch_size,
         split_to_ds={"train": train_ds, "val": validation_ds},
         device=run_context.device,
-        epoch=mts.epoch,
         feature_name=config.dataset.feature,
     )
     logging.info(
src/delphi/train/training.py (13 changes: 6 additions & 7 deletions)

@@ -8,13 +8,14 @@
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
+from delphi.train.shuffle import shuffle_epoch
+
 from .checkpoint_step import log_and_save_checkpoint, should_save_checkpoint
 from .config import TrainingConfig
 from .run_context import RunContext
 from .train_step import train_step
 from .utils import (
     ModelTrainingState,
-    get_indices_for_epoch,
     initialize_model_training_state,
     set_lr,
     setup_determinism,
@@ -70,12 +71,10 @@ def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext
     # training loop
     logging.info("Starting training...")
     for epoch in range(config.max_epochs):
-        logging.info(f"Epoch: {epoch} / {config.max_epochs - 1}")
-        train_data_indices = get_indices_for_epoch(
-            dataset_size=len(train_ds),
-            batch_size=config.batch_size,
-            epoch=epoch,
-            ordering_seed=config.batch_ordering_seed,
+        logging.info(f"Epoch: {epoch+1} / {config.max_epochs}")
+        train_data_indices = list(range(len(train_ds)))
+        shuffle_epoch(
+            train_data_indices, seed=config.batch_ordering_seed, epoch_nr=epoch
         )
         model_training_state.epoch = epoch
         for step in tqdm(range(steps_per_epoch)):
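
Note: this commit shows only the import and call site of shuffle_epoch, not its body. Judging from that signature and from the helper deleted in utils.py below (which reshuffled via shuffle_list(indices, seed=ordering_seed + epoch)), a minimal sketch of what it might do follows; the shuffle_list helper and the seed derivation are assumptions carried over from the removed code, not the actual contents of delphi/train/shuffle.py.

import random


def shuffle_list(in_out: list, seed: int) -> None:
    # In-place seeded shuffle (assumed helper; the name comes from the
    # removed get_indices_for_epoch below, not from this commit).
    random.Random(seed).shuffle(in_out)


def shuffle_epoch(samples: list, seed: int, epoch_nr: int) -> None:
    # Deterministic per-epoch shuffle: the same (seed, epoch_nr) pair
    # always yields the same permutation. The seed derivation mirrors
    # the deleted helper's `ordering_seed + epoch`; the real
    # delphi/train/shuffle.py may combine them differently.
    shuffle_list(samples, seed=seed + epoch_nr)


# Usage, matching the call site in run_training:
train_data_indices = list(range(8))
shuffle_epoch(train_data_indices, seed=42, epoch_nr=0)
print(train_data_indices)  # same permutation on every run for epoch 0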
src/delphi/train/utils.py (17 changes: 1 addition & 16 deletions)

@@ -131,15 +131,6 @@ def initialize_model_training_state(
     )
 
 
-def get_indices_for_epoch(
-    dataset_size: int, batch_size: int, epoch: int, ordering_seed: int
-) -> list[int]:
-    """ """
-    indices = list(range(dataset_size))
-    shuffle_list(indices, seed=ordering_seed + epoch)
-    return indices
-
-
 def gen_minibatches(
     dataset: Dataset,
     batch_size: int,
@@ -168,19 +159,13 @@ def estimate_loss(
     batch_size: int,
     split_to_ds: dict[str, Dataset],
     device: torch.device,
-    epoch: int,
     feature_name: str,
 ) -> dict[str, float]:
     """helps estimate an arbitrarily accurate loss over either split using many batches"""
     out = {}
     model.eval()
     for split, ds in split_to_ds.items():
-        indices = get_indices_for_epoch(
-            dataset_size=len(ds),
-            batch_size=batch_size,
-            epoch=epoch,
-            ordering_seed=1234,
-        )
+        indices = list(range(len(ds)))
         eval_iters = min(eval_iters, len(ds) // batch_size)
         losses = torch.zeros(eval_iters)  # keep on CPU
         minibatches = gen_minibatches(
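
A side effect of dropping the epoch parameter: estimate_loss now walks each split in stored order, so every checkpoint is evaluated on the same leading eval_iters batches rather than on an epoch-dependent reshuffle (the old call seeded with ordering_seed=1234 plus the epoch number). A standalone toy illustration of the new indexing, with made-up sizes and no dependence on the repository's types:

# Toy numbers for illustration; not the repository's API.
dataset_size, batch_size, eval_iters = 10, 3, 5
indices = list(range(dataset_size))                       # no per-epoch shuffle
eval_iters = min(eval_iters, dataset_size // batch_size)  # cap at 3 full batches
batches = [indices[i * batch_size : (i + 1) * batch_size] for i in range(eval_iters)]
print(batches)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]] at every checkpoint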
