From 284011c35c27943907df58a631d4b4ca57b32b2d Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 28 Nov 2024 10:29:58 -0500 Subject: [PATCH] Fix bug with default device and configure_model Signed-off-by: Fabrice Normandin --- .../testsuites/lightning_module_tests.py | 14 ++++++++------ project/conftest.py | 5 +++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/project/algorithms/testsuites/lightning_module_tests.py b/project/algorithms/testsuites/lightning_module_tests.py index 792468f1..17290827 100644 --- a/project/algorithms/testsuites/lightning_module_tests.py +++ b/project/algorithms/testsuites/lightning_module_tests.py @@ -57,6 +57,7 @@ def test_initialization_is_deterministic( datamodule: lightning.LightningDataModule | None, seed: int, trainer: lightning.Trainer, + device: torch.device, ): """Checks that the weights initialization is consistent given the a random seed.""" @@ -65,10 +66,10 @@ def test_initialization_is_deterministic( algorithm_1 = instantiate_algorithm(experiment_config.algorithm, datamodule) assert isinstance(algorithm_1, lightning.LightningModule) - with trainer.init_module(): + with trainer.init_module(), device: # A bit hacky, but we have to do this because the lightningmodule isn't associated # with a Trainer. - algorithm_1._device = torch.get_default_device() + algorithm_1._device = device algorithm_1.configure_model() with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))): @@ -76,10 +77,10 @@ def test_initialization_is_deterministic( algorithm_2 = instantiate_algorithm(experiment_config.algorithm, datamodule) assert isinstance(algorithm_2, lightning.LightningModule) - with trainer.init_module(): + with trainer.init_module(), device: # A bit hacky, but we have to do this because the lightningmodule isn't associated # with a Trainer. - algorithm_2._device = torch.get_default_device() + algorithm_2._device = device algorithm_2.configure_model() torch.testing.assert_close(algorithm_1.state_dict(), algorithm_2.state_dict()) @@ -157,16 +158,17 @@ def test_initialization_is_reproducible( seed: int, tensor_regression: TensorRegressionFixture, trainer: lightning.Trainer, + device: torch.device, ): """Check that the network initialization is reproducible given the same random seed.""" with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))): torch.random.manual_seed(seed) algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule) assert isinstance(algorithm, lightning.LightningModule) - with trainer.init_module(): + with trainer.init_module(), device: # A bit hacky, but we have to do this because the lightningmodule isn't associated # with a Trainer. - algorithm._device = torch.get_default_device() + algorithm._device = device algorithm.configure_model() tensor_regression.check( diff --git a/project/conftest.py b/project/conftest.py index 62b69887..6e3d0393 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -328,15 +328,16 @@ def algorithm( datamodule: lightning.LightningDataModule | None, trainer: lightning.Trainer | JaxTrainer, seed: int, + device: torch.device, ): """Fixture that creates the "algorithm" (a [LightningModule][lightning.pytorch.core.module.LightningModule]).""" algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule) if isinstance(trainer, lightning.Trainer) and isinstance(algorithm, lightning.LightningModule): - with trainer.init_module(): + with trainer.init_module(), device: # A bit hacky, but we have to do this because the lightningmodule isn't associated # with a Trainer. - algorithm._device = torch.get_default_device() + algorithm._device = device algorithm.configure_model() return algorithm