basic performance test
jettjaniak committed May 20, 2024
1 parent 2fa7396 commit 48587da
Showing 1 changed file with 30 additions and 0 deletions: tests/train/test_train_step.py
@@ -8,6 +8,7 @@
from transformers import PreTrainedModel

from delphi.constants import TEST_CONFIGS_DIR
from delphi.eval.utils import get_all_and_next_logprobs
from delphi.train.config import TrainingConfig
from delphi.train.config.utils import build_config_from_files_and_overrides
from delphi.train.train_step import accumulate_gradients, train_step
@@ -91,6 +92,35 @@ def test_basic_reproducibility(dataset, model):
    ).all()


def test_performance(dataset, model):
    """check that predictions improve with training"""
    # setup
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    model_training_state = ModelTrainingState(
        model=model,
        optimizer=optimizer,
        iter_num=0,
        epoch=0,
        step=0,
        train_loss=0.0,
        lr=1e-3,
        last_training_step_time=0.0,
    )
    device = torch.device("cpu")
    indices = list(range(len(dataset)))

    next_logprobs_before = get_all_and_next_logprobs(model, dataset["tokens"])[1]

    train_step(
        model_training_state, dataset, load_test_config("debug"), device, indices
    )

    next_logprobs_after = get_all_and_next_logprobs(model, dataset["tokens"])[1]
    # next-token log probabilities should generally increase with training
    frac_increased = (next_logprobs_after > next_logprobs_before).float().mean().item()
    assert frac_increased > 0.95


def get_grads(model: PreTrainedModel) -> Float[torch.Tensor, "grads"]:
    grads = [
        param.grad.flatten() for param in model.parameters() if param.grad is not None
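A note on the logprob comparison in test_performance: get_all_and_next_logprobs is indexed with [1], which suggests it returns a pair whose second element is the log probability the model assigns to each token that actually comes next. A minimal sketch of such a helper for a generic Hugging Face causal LM, assuming that reading (a hypothetical stand-in, not delphi's actual implementation):

import torch
import torch.nn.functional as F

def all_and_next_logprobs(model, input_ids):
    # Hypothetical stand-in for delphi.eval.utils.get_all_and_next_logprobs.
    # Log-softmax over the vocabulary at every position: (batch, seq, vocab).
    logprobs = F.log_softmax(model(input_ids).logits, dim=-1)
    # At each position, pick out the log probability of the token that
    # actually follows, giving a (batch, seq - 1) tensor.
    next_logprobs = torch.gather(
        logprobs[:, :-1], -1, input_ids[:, 1:].unsqueeze(-1)
    ).squeeze(-1)
    return logprobs, next_logprobs

Under that reading, the test asserts that one train_step raises the model's log probability of the true next token at more than 95% of positions; for example, if only 3 of 4 positions improve, frac_increased is 0.75 and the assert fails. Assuming a standard pytest layout, the new test could be run in isolation with pytest tests/train/test_train_step.py::test_performance.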
