diff --git a/project/algorithms/testsuites/algorithm_tests.py b/project/algorithms/testsuites/algorithm_tests.py index 38013f77..70bd3471 100644 --- a/project/algorithms/testsuites/algorithm_tests.py +++ b/project/algorithms/testsuites/algorithm_tests.py @@ -65,6 +65,9 @@ def test_initialization_is_deterministic( assert isinstance(algorithm_1, lightning.LightningModule) with trainer.init_module(): + # A bit hacky, but we have to do this because the lightningmodule isn't associated + # with a Trainer. + algorithm_1._device = torch.get_default_device() algorithm_1.configure_model() with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))): @@ -73,6 +76,9 @@ def test_initialization_is_deterministic( assert isinstance(algorithm_2, lightning.LightningModule) with trainer.init_module(): + # A bit hacky, but we have to do this because the lightningmodule isn't associated + # with a Trainer. + algorithm_2._device = torch.get_default_device() algorithm_2.configure_model() torch.testing.assert_close(algorithm_1.state_dict(), algorithm_2.state_dict())