Add a new, simpler test suite (that works!)
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jul 16, 2024
1 parent 808e011 commit 280bbf3
Showing 5 changed files with 370 additions and 81 deletions.
176 changes: 176 additions & 0 deletions project/algorithms/testsuites/algo_test.py
@@ -0,0 +1,176 @@
import copy
import inspect
from abc import ABC
from logging import getLogger as get_logger
from typing import Generic, TypeVar, get_args

import lightning
import pytest
import torch
from lightning import LightningDataModule, LightningModule

from project.algorithms.example import ExampleAlgorithm
from project.configs.config import Config
from project.experiment import instantiate_algorithm
from project.utils.testutils import (
ParametrizedFixture,
fork_rng,
get_all_configs_in_group_with_target,
run_for_all_subclasses_of,
seeded_rng,
)
from project.utils.types import PyTree, is_sequence_of
from project.utils.types.protocols import DataModule, Module

logger = get_logger(__name__)

AlgorithmType = TypeVar("AlgorithmType", bound=LightningModule)


@pytest.mark.incremental
class LearningAlgorithmTests(Generic[AlgorithmType], ABC):
algorithm_name: ParametrizedFixture
# algorithm_name = algorithm_name
# datamodule_name = datamodule_name
# network_name = network_name

# network = staticmethod(network)
# trainer = staticmethod(trainer)
# datamodule = staticmethod(datamodule)

def __init_subclass__(cls) -> None:
super().__init_subclass__()
algorithm_under_test = _get_algorithm_class_from_generic_arg(cls)
# find all algorithm configs that create algorithms of this type.
configs_for_this_algorithm = get_all_configs_in_group_with_target(
"algorithm", algorithm_under_test
)
cls.algorithm_name = ParametrizedFixture(
configs_for_this_algorithm,
name="algorithm_name",
scope="session",
ids=str,
)
# TODO: Could also add a parametrize_when_used mark to parametrize the datamodule, network,
# etc, based on the type annotations of the algorithm constructor? For example, if an algo
# shows that it accepts any LightningDataModule, then parametrize it with all the datamodules,
# but if the algo says it only works with ImageNet, then parametrize with all the configs
# that have the ImageNet datamodule as their target (or a subclass of ImageNetDataModule).

def get_input_from_batch(self, batch: PyTree[torch.Tensor]):
"""Extracts the model input from a batch of data coming from the dataloader.
Overwrite this if your batches are not tuples of tensors (i.e. if your algorithm isn't a
simple supervised learning algorithm like the example).
"""
# By default, assume that the batch is a tuple of tensors.
if isinstance(batch, torch.Tensor):
return batch
if not is_sequence_of(batch, torch.Tensor):
raise NotImplementedError(
"The basic test suite assumes that a batch is a tuple of tensors, as in the"
f"supervised learning example, but the batch from the datamodule "
f"is of type {type(batch)}. You need to override this method in your test class "
"for the rest of the built-in tests to work correctly."
)
assert len(batch) >= 1
input = batch[0]
return input
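
    # NOTE: A hypothetical subclass for an algorithm whose batches are dictionaries could
    # override the method above like so (the "input" key is only an example, not a real
    # convention of this repo):
    #
    #     def get_input_from_batch(self, batch: dict[str, torch.Tensor]):
    #         return batch["input"]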

def test_initialization_is_deterministic(
self,
experiment_config: Config,
datamodule: DataModule,
network: torch.nn.Module,
seed: int,
):
with seeded_rng(seed):
algorithm_1 = instantiate_algorithm(experiment_config, datamodule, network)

with seeded_rng(seed):
algorithm_2 = instantiate_algorithm(experiment_config, datamodule, network)

torch.testing.assert_close(algorithm_1.state_dict(), algorithm_2.state_dict())
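
    # `seeded_rng` is imported from `project.utils.testutils` and its implementation isn't
    # part of this diff. A minimal sketch of such a helper (an assumption, not necessarily
    # what the real one does) could look like:
    #
    #     import contextlib
    #     import random
    #     import numpy as np
    #
    #     @contextlib.contextmanager
    #     def seeded_rng(seed: int):
    #         """Seeds the random/numpy/torch RNGs, restoring their previous state on exit."""
    #         random_state = random.getstate()
    #         numpy_state = np.random.get_state()
    #         with torch.random.fork_rng():  # saves/restores the torch RNG state.
    #             random.seed(seed)
    #             np.random.seed(seed)
    #             torch.manual_seed(seed)
    #             yield
    #         random.setstate(random_state)
    #         np.random.set_state(numpy_state)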

def test_forward_pass_is_deterministic(
self, training_batch: tuple[torch.Tensor, ...], network: Module, seed: int
):
x = self.get_input_from_batch(training_batch)
with fork_rng():
out1 = network(x)
with fork_rng():
out2 = network(x)
torch.testing.assert_close(out1, out2)
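
    # `fork_rng` likewise comes from `project.utils.testutils`. Judging from its usage here,
    # a thin wrapper around `torch.random.fork_rng` would suffice (again a sketch, not the
    # actual implementation):
    #
    #     import contextlib
    #
    #     @contextlib.contextmanager
    #     def fork_rng():
    #         """Runs the block under a copy of the torch RNG state, restoring it on exit."""
    #         with torch.random.fork_rng():
    #             yield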

@pytest.mark.timeout(10)
def test_backward_pass_is_deterministic(
self,
datamodule: LightningDataModule,
algorithm: LightningModule,
seed: int,
accelerator: str,
):
class GetGradientsCallback(lightning.Callback):
def __init__(self):
super().__init__()
self.grads: dict[str, torch.Tensor | None] = {}

def on_after_backward(
self, trainer: lightning.Trainer, pl_module: LightningModule
) -> None:
super().on_after_backward(trainer, pl_module)
if self.grads:
return # already collected the gradients.

for name, param in pl_module.named_parameters():
self.grads[name] = copy.deepcopy(param.grad)

algorithm_1 = copy.deepcopy(algorithm)
algorithm_2 = copy.deepcopy(algorithm)

with seeded_rng(seed):
gradients_callback = GetGradientsCallback()
trainer = lightning.Trainer(
accelerator=accelerator,
callbacks=[gradients_callback],
fast_dev_run=True,
enable_checkpointing=False,
deterministic=True,
)
trainer.fit(algorithm_1, datamodule=datamodule)
gradients_1 = gradients_callback.grads

with seeded_rng(seed):
gradients_callback = GetGradientsCallback()
            trainer = lightning.Trainer(
                accelerator=accelerator,
                callbacks=[gradients_callback],
                fast_dev_run=True,
                enable_checkpointing=False,
                deterministic=True,
            )
trainer.fit(algorithm_2, datamodule=datamodule)
gradients_2 = gradients_callback.grads

torch.testing.assert_close(gradients_1, gradients_2)


def _get_algorithm_class_from_generic_arg(
cls: type[LearningAlgorithmTests[AlgorithmType]],
) -> type[AlgorithmType]:
"""Retrieves the class under test from the class definition (without having to set a class
attribute."""
class_under_test = get_args(cls.__orig_bases__[0])[0] # type: ignore
if inspect.isclass(class_under_test) and issubclass(class_under_test, LightningModule):
return class_under_test # type: ignore

# todo: Check if the class under test is a TypeVar, if so, check its bound.
raise RuntimeError(
"Your test class needs to pass the class under test to the generic base class.\n"
"for example: `class TestMyAlgorithm(AlgorithmTests[MyAlgorithm]):`\n"
f"(Got {class_under_test})"
)


# @parametrize_when_used(network_name, ["fcnet", "resnet18"])
@run_for_all_subclasses_of("network", torch.nn.Module)
class TestExampleAlgo(LearningAlgorithmTests[ExampleAlgorithm]):
"""Tests for the `ExampleAlgorithm`."""
4 changes: 2 additions & 2 deletions project/algorithms/testsuites/algorithm_tests.py
@@ -41,7 +41,7 @@
fork_rng,
get_all_datamodule_names_params,
get_all_network_names,
get_type_for_config_name,
get_target_of_config,
)
from project.utils.types import is_sequence_of
from project.utils.types.protocols import DataModule
@@ -540,7 +540,7 @@ def _skip_if_unsupported(
if not unsupported_types and not supported_types:
return

config_type: type = get_type_for_config_name(group, config_name, _cs=cs)
config_type: type = get_target_of_config(group, config_name, _cs=cs)
if not inspect.isclass(config_type):
config_return_type = typing.get_type_hints(config_type).get("return")
if config_return_type and inspect.isclass(config_return_type):
4 changes: 2 additions & 2 deletions project/configs/config_test.py
@@ -10,12 +10,12 @@

from project.configs.algorithm.lr_scheduler import get_all_scheduler_configs
from project.configs.algorithm.optimizer import get_all_optimizer_configs
from project.utils.testutils import seeded
from project.utils.testutils import seeded_rng


@pytest.fixture(scope="session")
def net(device: torch.device):
with seeded(123):
with seeded_rng(123):
net = torch.nn.Linear(10, 1).to(device)
return net

62 changes: 11 additions & 51 deletions project/conftest.py
@@ -42,6 +42,7 @@
)
from project.utils.hydra_utils import resolve_dictconfig
from project.utils.testutils import (
PARAM_WHEN_USED_MARK_NAME,
default_marks_for_config_combinations,
default_marks_for_config_name,
fork_rng,
@@ -527,6 +528,7 @@ def network(

@pytest.fixture(scope="function")
def algorithm(experiment_config: Config, datamodule: DataModule, network: nn.Module):
"""Fixture that creates an "algorithm" (LightningModule)."""
return instantiate_algorithm(experiment_config, datamodule=datamodule, network=network)


@@ -567,6 +569,8 @@ def make_torch_deterministic():


def pytest_runtest_makereport(item, call):
"""Used to setup the `pytest.mark.incremental` mark, as described in [this page](https://docs.pytest.org/en/7.1.x/example/simple.html#incremental-testing-test-steps)."""

if "incremental" in item.keywords:
# incremental marker is used
if call.excinfo is not None:
@@ -586,6 +590,7 @@ def pytest_runtest_makereport(item, call):


def pytest_runtest_setup(item):
"""Used to setup the `pytest.mark.incremental` mark, as described in [this page](https://docs.pytest.org/en/7.1.x/example/simple.html#incremental-testing-test-steps)."""
if "incremental" in item.keywords:
# retrieve the class name of the test
cls_name = str(item.cls)
@@ -602,60 +607,11 @@
pytest.xfail(f"previous test failed ({test_name})")


PARAM_WHEN_USED_MARK_NAME = "parametrize_when_used"


def parametrize_when_used(
arg_name_or_fixture: str | typing.Callable, values: list
) -> pytest.MarkDecorator:
"""Fixture that applies `pytest.mark.parametrize` only when the argument is used (directly or
indirectly).
When `pytest.mark.parametrize` is applied to a class, all test methods in that class need to
use the parametrized argument, otherwise an error is raised. This function exists to work around
this and allows writing test methods that don't use the parametrized argument.
For example, this works, but would not be possible with `pytest.mark.parametrize`:
```python
import pytest
@parametrize_when_used("value", [1, 2, 3])
class TestFoo:
def test_foo(self, value):
...
def test_bar(self, value):
...
def test_something_else(self): # This will cause an error!
pass
```
    Parameters
    ----------
    arg_name_or_fixture: The name of the argument to parametrize, or a fixture to parametrize \
        indirectly.
    values: The values to be used to parametrize the test.

    Returns
    -------
    A `pytest.MarkDecorator` that parametrizes the test with the given values only when the
    argument is used (directly or indirectly) by the test.
"""
indirect = not isinstance(arg_name_or_fixture, str)
arg_name = (
arg_name_or_fixture
if isinstance(arg_name_or_fixture, str)
else arg_name_or_fixture.__name__
)
mark_fn = getattr(pytest.mark, PARAM_WHEN_USED_MARK_NAME)
return mark_fn(arg_name, values, indirect=indirect)


def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
"""Allows one to define custom parametrization schemes or extensions.
This is used to implement the `parametrize_when_used` mark, which allows one to parametrize an argument when it is used.
See
https://docs.pytest.org/en/7.1.x/how-to/parametrize.html#how-to-parametrize-fixtures-and-test-functions
"""
@@ -705,6 +661,10 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:

for arg_name, arg_values in args_to_parametrized_values.items():
# Test uses that argument, parametrize it.

# remove duplicates and order the parameters deterministically.
arg_values = sorted(set(arg_values), key=str)

# TODO: unsure what mark to pass here, if there were multiple marks for the same argument..
marker = args_to_be_parametrized_markers[arg_name][-1]
indirect = marker.kwargs.get("indirect", False)
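        # What presumably follows (a sketch; the actual call is outside this hunk): apply the
        # parametrization only if the test actually uses the argument, e.g. roughly:
        #
        #     if arg_name in metafunc.fixturenames:
        #         metafunc.parametrize(arg_name, arg_values, indirect=indirect)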
