diff --git a/project/conftest.py b/project/conftest.py index 9052ddcd..94881655 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -522,17 +522,8 @@ def network( if any(torch.nn.parameter.is_lazy(p) for p in network.parameters()): # a bit ugly, but we need to initialize any lazy weights before we pass the network # to the tests. - try: - _ = network(input) - except RuntimeError as err: - # TODO: Investigate the false positives with example_from_config, resnets, cifar10 - logger.error(f"Error when running the network: {err}") - request.node.add_marker( - pytest.mark.xfail( - raises=RuntimeError, - reason="Network doesn't seem to be compatible this dataset.", - ) - ) + # TODO: Investigate the false positives with example_from_config, resnets, cifar10 + _ = network(input) return network diff --git a/project/main_test.py b/project/main_test.py index 58532813..6f974688 100644 --- a/project/main_test.py +++ b/project/main_test.py @@ -26,9 +26,9 @@ def test_jax_can_use_the_GPU(): device = jax.numpy.zeros(1).devices().pop() if shutil.which("nvidia-smi"): - assert device.type == "GPU" + assert str(device) == "cuda:0" else: - assert device.type == "CPU" + assert str(device) == "cpu" def test_torch_can_use_the_GPU():