Skip to content

Commit

Permalink
Fix bug with default device and configure_model
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Nov 28, 2024
1 parent 5a2ee40 commit 284011c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
14 changes: 8 additions & 6 deletions project/algorithms/testsuites/lightning_module_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_initialization_is_deterministic(
datamodule: lightning.LightningDataModule | None,
seed: int,
trainer: lightning.Trainer,
device: torch.device,
):
"""Checks that the weights initialization is consistent given the same random seed."""

Expand All @@ -65,21 +66,21 @@ def test_initialization_is_deterministic(
algorithm_1 = instantiate_algorithm(experiment_config.algorithm, datamodule)
assert isinstance(algorithm_1, lightning.LightningModule)

with trainer.init_module():
with trainer.init_module(), device:
# A bit hacky, but we have to do this because the lightningmodule isn't associated
# with a Trainer.
algorithm_1._device = torch.get_default_device()
algorithm_1._device = device
algorithm_1.configure_model()

with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
torch.random.manual_seed(seed)
algorithm_2 = instantiate_algorithm(experiment_config.algorithm, datamodule)
assert isinstance(algorithm_2, lightning.LightningModule)

with trainer.init_module():
with trainer.init_module(), device:
# A bit hacky, but we have to do this because the lightningmodule isn't associated
# with a Trainer.
algorithm_2._device = torch.get_default_device()
algorithm_2._device = device
algorithm_2.configure_model()

torch.testing.assert_close(algorithm_1.state_dict(), algorithm_2.state_dict())
Expand Down Expand Up @@ -157,16 +158,17 @@ def test_initialization_is_reproducible(
seed: int,
tensor_regression: TensorRegressionFixture,
trainer: lightning.Trainer,
device: torch.device,
):
"""Check that the network initialization is reproducible given the same random seed."""
with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
torch.random.manual_seed(seed)
algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
assert isinstance(algorithm, lightning.LightningModule)
with trainer.init_module():
with trainer.init_module(), device:
# A bit hacky, but we have to do this because the lightningmodule isn't associated
# with a Trainer.
algorithm._device = torch.get_default_device()
algorithm._device = device
algorithm.configure_model()

tensor_regression.check(
Expand Down
5 changes: 3 additions & 2 deletions project/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,16 @@ def algorithm(
datamodule: lightning.LightningDataModule | None,
trainer: lightning.Trainer | JaxTrainer,
seed: int,
device: torch.device,
):
"""Fixture that creates the "algorithm" (a
[LightningModule][lightning.pytorch.core.module.LightningModule])."""
algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
if isinstance(trainer, lightning.Trainer) and isinstance(algorithm, lightning.LightningModule):
with trainer.init_module():
with trainer.init_module(), device:
# A bit hacky, but we have to do this because the lightningmodule isn't associated
# with a Trainer.
algorithm._device = torch.get_default_device()
algorithm._device = device
algorithm.configure_model()
return algorithm

Expand Down

0 comments on commit 284011c

Please sign in to comment.