From 48587dafc4bcc630c4b8bec9e1f2d4460ac153e9 Mon Sep 17 00:00:00 2001
From: Jett
Date: Wed, 15 May 2024 18:11:00 +0200
Subject: [PATCH] basic performance test

---
 tests/train/test_train_step.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/tests/train/test_train_step.py b/tests/train/test_train_step.py
index b27eb504..1a7db8cb 100644
--- a/tests/train/test_train_step.py
+++ b/tests/train/test_train_step.py
@@ -8,6 +8,7 @@ from transformers import PreTrainedModel
 
 from delphi.constants import TEST_CONFIGS_DIR
+from delphi.eval.utils import get_all_and_next_logprobs
 from delphi.train.config import TrainingConfig
 from delphi.train.config.utils import build_config_from_files_and_overrides
 from delphi.train.train_step import accumulate_gradients, train_step
 
@@ -91,6 +92,35 @@ def test_basic_reproducibility(dataset, model):
     ).all()
 
 
+def test_performance(dataset, model):
+    """check that predictions improve with training"""
+    # setup
+    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+    model_training_state = ModelTrainingState(
+        model=model,
+        optimizer=optimizer,
+        iter_num=0,
+        epoch=0,
+        step=0,
+        train_loss=0.0,
+        lr=1e-3,
+        last_training_step_time=0.0,
+    )
+    device = torch.device("cpu")
+    indices = list(range(len(dataset)))
+
+    next_logprobs_before = get_all_and_next_logprobs(model, dataset["tokens"])[1]
+
+    train_step(
+        model_training_state, dataset, load_test_config("debug"), device, indices
+    )
+
+    next_logprobs_after = get_all_and_next_logprobs(model, dataset["tokens"])[1]
+    # should generally increase with training
+    frac_increased = (next_logprobs_after > next_logprobs_before).float().mean().item()
+    assert frac_increased > 0.95
+
+
 def get_grads(model: PreTrainedModel) -> Float[torch.Tensor, "grads"]:
     grads = [
         param.grad.flatten() for param in model.parameters() if param.grad is not None
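
Note on the check above: the test uses only the second return value of
get_all_and_next_logprobs (from delphi.eval.utils), i.e. the log-probability
the model assigns to each actual next token. A minimal sketch of what that
computation is assumed to look like -- the function name, shapes, and model
call here are illustrative, not the library's actual implementation:

    import torch
    import torch.nn.functional as F

    def next_token_logprobs_sketch(model, tokens):
        # tokens: (batch, seq_len) integer token ids
        logits = model(tokens).logits             # (batch, seq_len, vocab)
        logprobs = F.log_softmax(logits, dim=-1)
        # log-probability assigned to the token that actually follows
        # each position
        next_logprobs = logprobs[:, :-1].gather(
            -1, tokens[:, 1:].unsqueeze(-1)
        ).squeeze(-1)                             # (batch, seq_len - 1)
        return logprobs, next_logprobs

After one SGD step on the same batch, these per-token log-probabilities should
rise for nearly every position, which is what the frac_increased > 0.95
assertion encodes.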