Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Remove redundant tests #92

Merged
merged 3 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 7 additions & 125 deletions project/algorithms/testsuites/lightning_module_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,110 +51,10 @@ def forward_pass(self, algorithm: LightningModule, input: PyTree[torch.Tensor]):
return algorithm(**input)
return algorithm(input)

def test_initialization_is_deterministic(
    self,
    experiment_config: Config,
    datamodule: lightning.LightningDataModule | None,
    seed: int,
    trainer: lightning.Trainer,
    device: torch.device,
):
    """Checks that the weights initialization is consistent given a random seed.

    Instantiates the algorithm twice from the same config and seed and asserts
    that both end up with identical weights.
    """

    def _create_algorithm() -> lightning.LightningModule:
        # Fork the RNG state (CPU + all visible CUDA devices) and re-seed, so
        # each call starts from exactly the same generator state.
        with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
            torch.random.manual_seed(seed)
            algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule)
            assert isinstance(algorithm, lightning.LightningModule)

            with trainer.init_module(), device:
                # A bit hacky, but we have to do this because the lightningmodule
                # isn't associated with a Trainer.
                algorithm._device = device
                algorithm.configure_model()
        return algorithm

    algorithm_1 = _create_algorithm()
    algorithm_2 = _create_algorithm()

    # Two identically-seeded initializations must produce identical weights.
    torch.testing.assert_close(algorithm_1.state_dict(), algorithm_2.state_dict())

def test_forward_pass_is_deterministic(
    self, forward_pass_input: Any, algorithm: AlgorithmType, seed: int
):
    """Checks that the forward pass output is consistent given a random seed and a
    given input."""

    def _seeded_forward_pass():
        # Fork the RNG state (CPU + all visible CUDA devices) and re-seed, so
        # both passes start from exactly the same generator state.
        with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
            torch.random.manual_seed(seed)
            return self.forward_pass(algorithm, forward_pass_input)

    out1 = _seeded_forward_pass()
    out2 = _seeded_forward_pass()

    torch.testing.assert_close(out1, out2)

# @pytest.mark.timeout(10)
# @pytest.mark.timeout(10)
def test_backward_pass_is_deterministic(
    self,
    datamodule: LightningDataModule,
    algorithm: AlgorithmType,
    seed: int,
    accelerator: str,
    devices: int | list[int] | Literal["auto"],
    tmp_path: Path,
):
    """Check that the backward pass is reproducible given the same input, weights, and random
    seed.

    Runs one training step twice on independent copies of the algorithm and
    asserts that the batch, gradients, and training-step outputs all match.
    """

    def _train_one_step(algo, run_name: str):
        """Run a single seeded training step; return (batch, grads, outputs)."""
        # Fork the RNG state (CPU + all visible CUDA devices) and re-seed, so
        # both runs start from exactly the same generator state.
        with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
            torch.random.manual_seed(seed)
            gradients_callback = GetStuffFromFirstTrainingStep()
            self.do_one_step_of_training(
                algo,
                datamodule,
                accelerator=accelerator,
                devices=devices,
                callbacks=[gradients_callback],
                tmp_path=tmp_path / run_name,
            )
        return gradients_callback.batch, gradients_callback.grads, gradients_callback.outputs

    # Deep-copy so the two runs don't share (and mutate) the same parameters.
    batch_1, gradients_1, training_step_outputs_1 = _train_one_step(
        copy.deepcopy(algorithm), "run1"
    )
    batch_2, gradients_2, training_step_outputs_2 = _train_one_step(
        copy.deepcopy(algorithm), "run2"
    )

    torch.testing.assert_close(batch_1, batch_2)
    torch.testing.assert_close(gradients_1, gradients_2)
    torch.testing.assert_close(training_step_outputs_1, training_step_outputs_2)

def test_initialization_is_reproducible(
self,
experiment_config: Config,
datamodule: lightning.LightningDataModule,
datamodule: lightning.LightningDataModule | None,
seed: int,
tensor_regression: TensorRegressionFixture,
trainer: lightning.Trainer,
Expand All @@ -165,14 +65,15 @@ def test_initialization_is_reproducible(
torch.random.manual_seed(seed)
algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
assert isinstance(algorithm, lightning.LightningModule)
# A bit hacky, but we have to do this because the lightningmodule isn't associated
# with a Trainer here.
with trainer.init_module(), device:
# A bit hacky, but we have to do this because the lightningmodule isn't associated
# with a Trainer.
algorithm._device = device
algorithm.configure_model()

tensor_regression.check(
algorithm.state_dict(),
# todo: is this necessary? Shouldn't the weights be the same on CPU and GPU?
# Save the regression files on a different subfolder for each device (cpu / cuda)
additional_label=next(algorithm.parameters()).device.type,
include_gpu_name_in_stats=False,
Expand Down Expand Up @@ -236,33 +137,14 @@ def test_backward_pass_is_reproducible(
"grads": gradients_callback.grads,
"outputs": outputs,
},
default_tolerance={"rtol": 1e-5, "atol": 1e-6}, # some tolerance for the jax example.
# todo: this tolerance was mainly added for the jax example.
default_tolerance={"rtol": 1e-5, "atol": 1e-6}, # some tolerance
# todo: check if this actually differs between cpu / gpu.
# Save the regression files on a different subfolder for each device (cpu / cuda)
additional_label=accelerator if accelerator not in ["auto", "gpu"] else None,
include_gpu_name_in_stats=False,
)

def __init_subclass__(cls) -> None:
    """Hook invoked whenever a concrete test suite subclasses this base class.

    Currently only delegates to the parent hook; the fixture-parametrization
    logic that used to live here as commented-out code has been removed.
    """
    super().__init_subclass__()

# TODO: Could also add a parametrize_when_used mark to parametrize the datamodule, network,
# etc, based on the type annotations of the algorithm constructor? For example, if an algo
# shows that it accepts any LightningDataModule, then parametrize it with all the datamodules,
# but if the algo says it only works with ImageNet, then parametrize with all the configs
# that have the ImageNet datamodule as their target (or a subclass of ImageNetDataModule).

@pytest.fixture(scope="session")
def forward_pass_input(self, training_batch: PyTree[torch.Tensor], device: torch.device):
"""Extracts the model input from a batch of data coming from the dataloader.
Expand Down
7 changes: 2 additions & 5 deletions project/datamodules/datamodules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from project.utils.typing_utils import is_sequence_of


# @use_overrides(["datamodule.num_workers=0"])
# @pytest.mark.timeout(25, func_only=True)
@pytest.mark.slow
@pytest.mark.parametrize(
"stage",
Expand All @@ -47,9 +45,8 @@ def test_first_batch(
stage: RunningStage,
datadir: Path,
):
# todo: skip this test if the dataset isn't already downloaded (for example on the GitHub CI).

# TODO: This causes hanging issues when tests fail, since dataloader workers aren't cleaned up.
# Note: using dataloader workers in tests can cause issues, since if a test fails, dataloader
# workers aren't always cleaned up properly.
if isinstance(datamodule, VisionDataModule) or hasattr(datamodule, "num_workers"):
datamodule.num_workers = 0 # type: ignore

Expand Down
6 changes: 6 additions & 0 deletions project/datamodules/text/text_classification_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import huggingface_hub.errors
import lightning
import pytest

Expand Down Expand Up @@ -61,6 +62,11 @@ def prepared_datamodule(
datamodule.working_path = _slurm_tmpdir_before


@pytest.mark.xfail(
raises=huggingface_hub.errors.HfHubHTTPError,
strict=False,
reason="Can sometimes get 'Too many requests for url'",
)
@pytest.mark.parametrize(datamodule.__name__, datamodule_configs, indirect=True)
def test_dataset_location(
prepared_datamodule: TextClassificationDataModule,
Expand Down