diff --git a/project/algorithms/llm_finetuning_test.py b/project/algorithms/llm_finetuning_test.py index f0d25bd3..ef80dedb 100644 --- a/project/algorithms/llm_finetuning_test.py +++ b/project/algorithms/llm_finetuning_test.py @@ -122,6 +122,7 @@ def test_initialization_is_reproducible( seed: int, tensor_regression: TensorRegressionFixture, trainer: lightning.Trainer, + device: torch.device, ): super().test_initialization_is_reproducible( experiment_config=experiment_config, @@ -129,6 +130,7 @@ def test_initialization_is_reproducible( seed=seed, tensor_regression=tensor_regression, trainer=trainer, + device=device, ) @pytest.mark.xfail(