diff --git a/test/test_core.py b/test/test_core.py index 0c28cc4..d18eaca 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -101,14 +101,18 @@ def test_functions(self) -> None: with pytest.raises(RuntimeError): cuda.requires_cuda_support() + def _check_cudnn(self, val: bool) -> None: + torch.backends.cudnn.benchmark = val + assert torch.backends.cudnn.benchmark == val + with cuda.DisableCuDNNBenchmarkContext(): + assert not torch.backends.cudnn.benchmark + + assert torch.backends.cudnn.benchmark == val + def test_disable_cudnn_context(self) -> None: if torch.cuda.is_available(): - torch.backends.cudnn.benchmark = True - - assert torch.backends.cudnn.benchmark - with cuda.DisableCuDNNBenchmarkContext(): - assert not torch.backends.cudnn.benchmark - assert torch.backends.cudnn.benchmark + self._check_cudnn(True) + self._check_cudnn(False) @dataclasses.dataclass