From 678c131bce9cb310d49542b21772a5ff166bb6c8 Mon Sep 17 00:00:00 2001 From: Jett Date: Wed, 15 May 2024 12:45:16 +0200 Subject: [PATCH] removed get_xy_batches, simplified tests --- src/delphi/train/utils.py | 34 ++------ tests/train/test_train_step.py | 143 +++++++++++---------------------- 2 files changed, 52 insertions(+), 125 deletions(-) diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py index 3bcdb958..2201e759 100644 --- a/src/delphi/train/utils.py +++ b/src/delphi/train/utils.py @@ -3,7 +3,7 @@ import math import os import time -from collections.abc import Generator +from collections.abc import Iterator from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Any, Type, cast @@ -139,22 +139,6 @@ def get_indices_for_epoch( return indices -def get_xy_batch( - dataset: Dataset, - indices: list[int], - batch_size: int, - batch_num: int, - feature_name: str, - device: torch.device, -) -> torch.Tensor: - """Get a batch of data from a dataset given a batch number and indices""" - start = batch_num * batch_size - end = (batch_num + 1) * batch_size - batch_indices = indices[start:end] - data = dataset[batch_indices][feature_name].to(device) - return data - - def gen_minibatches( dataset: Dataset, batch_size: int, @@ -163,21 +147,17 @@ def gen_minibatches( indices: list[int], device: torch.device, feature_name: str, -) -> Generator[torch.Tensor, None, None]: +) -> Iterator[torch.Tensor]: """ Generate minibatches from a dataset given a step and indices """ minibatch_size = batch_size // num_minibatches first_minibatch_num = num_minibatches * step - for i in range(num_minibatches): - yield get_xy_batch( - dataset=dataset, - indices=indices, - batch_num=first_minibatch_num + i, - batch_size=minibatch_size, - feature_name=feature_name, - device=device, - ) + for batch_num in range(first_minibatch_num, first_minibatch_num + num_minibatches): + start = batch_num * minibatch_size + end = (batch_num + 1) * minibatch_size + batch_indices = indices[start:end] + yield dataset[batch_indices][feature_name].to(device) @torch.no_grad() diff --git a/tests/train/test_train_step.py b/tests/train/test_train_step.py index 599c1b1f..b27eb504 100644 --- a/tests/train/test_train_step.py +++ b/tests/train/test_train_step.py @@ -5,6 +5,7 @@ import torch from datasets import Dataset from jaxtyping import Float +from transformers import PreTrainedModel from delphi.constants import TEST_CONFIGS_DIR from delphi.train.config import TrainingConfig @@ -12,7 +13,7 @@ from delphi.train.train_step import accumulate_gradients, train_step from delphi.train.utils import ( ModelTrainingState, - get_xy_batch, + gen_minibatches, init_model, setup_determinism, ) @@ -90,69 +91,44 @@ def test_basic_reproducibility(dataset, model): ).all() +def get_grads(model: PreTrainedModel) -> Float[torch.Tensor, "grads"]: + grads = [ + param.grad.flatten() for param in model.parameters() if param.grad is not None + ] + return torch.cat(grads) + + def test_accumulate_gradients_accumulates(dataset, model): """ check that gradient accumulation works as expected and doesn't reset on each microstep """ - # setup - indices_set_a = [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9], - ] - # different batch but idential last batch; - indices_set_a = [1, 2, 3, 4, 5, 6, 7, 8, 9] - # different batch but idential last batch (with batches of 3); - # this should result in a different accumulated gradient - indices_set_b = [7, 8, 9, 7, 8, 9, 7, 8, 9] batch_size = 3 - num_batches = len(indices_set_a) // batch_size - - batches_a = [ - get_xy_batch( - dataset=dataset, - indices=indices_set_a, - batch_size=3, - batch_num=microstep, - feature_name="tokens", - device=torch.device("cpu"), - ) - for microstep in range(num_batches) - ] - batches_b = [ - get_xy_batch( - dataset=dataset, - indices=indices_set_b, - batch_size=3, - batch_num=microstep, - feature_name="tokens", - device=torch.device("cpu"), - ) - for microstep in range(num_batches) - ] + num_batches = 3 + # first 2 mini-batches different, last mini-batch the same + indices_set_a = [1, 2, 3] + [4, 5, 6] + [7, 8, 9] + indices_set_b = [7, 8, 9] * 3 + + kwargs = dict( + dataset=dataset, + batch_size=batch_size, + num_minibatches=num_batches, + step=0, + device=torch.device("cpu"), + feature_name="tokens", + ) + batches_a = gen_minibatches(indices=indices_set_a, **kwargs) # type: ignore + batches_b = gen_minibatches(indices=indices_set_b, **kwargs) # type: ignore # accumulate - _total_loss = accumulate_gradients(model, batches_a, len(batches_a)) - - grads_a = torch.cat( - [ - param.grad.clone().detach().flatten() - for param in model.parameters() - if param.grad is not None - ] - ) + _total_loss = accumulate_gradients(model, batches_a, num_batches) + + grads_a = get_grads(model) # reset grad on model model.zero_grad() - _total_loss = accumulate_gradients(model, batches_b, len(batches_b)) - grads_b = torch.cat( - [ - param.grad.clone().detach().flatten() - for param in model.parameters() - if param.grad is not None - ] - ) + _total_loss = accumulate_gradients(model, batches_b, num_batches) + grads_b = get_grads(model) # test assert not torch.isclose(grads_a, grads_b).all() @@ -163,59 +139,30 @@ def test_accumulate_gradients_consistent(dataset, model): Validate that the gradients are consistent when the same batch is passed to accumulate_gradients """ # setup - indices_set = [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9], - ] - indices_set = list(range(1, 10)) num_batches = 3 batch_size = 3 - batches_a = [ - get_xy_batch( - dataset=dataset, - indices=indices_set, - batch_size=batch_size, - batch_num=microstep, - feature_name="tokens", - device=torch.device("cpu"), - ) - for microstep in range(num_batches) - ] - batches_aa = [ - get_xy_batch( - dataset=dataset, - indices=indices_set, - batch_size=batch_size, - batch_num=microstep, - feature_name="tokens", - device=torch.device("cpu"), - ) - for microstep in range(num_batches) - ] + kwargs = dict( + indices=list(range(1, 10)), + dataset=dataset, + batch_size=batch_size, + num_minibatches=num_batches, + step=0, + device=torch.device("cpu"), + feature_name="tokens", + ) + batches_a = gen_minibatches(**kwargs) # type: ignore + batches_aa = gen_minibatches(**kwargs) # type: ignore # accumulate - total_loss = accumulate_gradients(model, batches_a, num_batches) - - grads_a = torch.cat( - [ - param.grad.clone().detach().flatten() - for param in model.parameters() - if param.grad is not None - ] - ) + _total_loss = accumulate_gradients(model, batches_a, num_batches) + + grads_a = get_grads(model) # reset grad on model model.zero_grad() - total_loss = accumulate_gradients(model, batches_aa, num_batches) - grads_aa = torch.cat( - [ - param.grad.clone().detach().flatten() - for param in model.parameters() - if param.grad is not None - ] - ) + _total_loss = accumulate_gradients(model, batches_aa, num_batches) + grads_aa = get_grads(model) # test assert torch.isclose(grads_a, grads_aa).all()