diff --git a/pkgs/development/python-modules/torch/bin.nix b/pkgs/development/python-modules/torch/bin.nix index e2899c081e08b..ec6843deac98c 100644 --- a/pkgs/development/python-modules/torch/bin.nix +++ b/pkgs/development/python-modules/torch/bin.nix @@ -121,7 +121,10 @@ buildPythonPackage { pythonImportsCheck = [ "torch" ]; - passthru.gpuChecks.cudaAvailable = callPackage ./test-cuda.nix { torch = torch-bin; }; + passthru.tests = callPackage ./tests.nix { + torchWithCuda = torch-bin; + torchWithRocm = torch-bin; + }; meta = { description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration"; diff --git a/pkgs/development/python-modules/torch/gpu-checks.nix b/pkgs/development/python-modules/torch/gpu-checks.nix deleted file mode 100644 index 55a4b45f71522..0000000000000 --- a/pkgs/development/python-modules/torch/gpu-checks.nix +++ /dev/null @@ -1,40 +0,0 @@ -{ - lib, - torchWithCuda, - torchWithRocm, - callPackage, -}: - -let - accelAvailable = - { - feature, - versionAttr, - torch, - cudaPackages, - }: - cudaPackages.writeGpuTestPython - { - inherit feature; - libraries = [ torch ]; - name = "${feature}Available"; - } - '' - import torch - message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}" - assert torch.cuda.is_available() and torch.version.${versionAttr}, message - print(message) - ''; -in -{ - tester-cudaAvailable = callPackage accelAvailable { - feature = "cuda"; - versionAttr = "cuda"; - torch = torchWithCuda; - }; - tester-rocmAvailable = callPackage accelAvailable { - feature = "rocm"; - versionAttr = "hip"; - torch = torchWithRocm; - }; -} diff --git a/pkgs/development/python-modules/torch/mk-runtime-check.nix b/pkgs/development/python-modules/torch/mk-runtime-check.nix new file mode 100644 index 0000000000000..14560b06f87ce --- /dev/null +++ b/pkgs/development/python-modules/torch/mk-runtime-check.nix @@ -0,0 +1,19 @@ +{ + cudaPackages, + feature, + torch, + versionAttr, +}: + +cudaPackages.writeGpuTestPython + { + inherit feature; + libraries = [ torch ]; + name = "${feature}Available"; + } + '' + import torch + message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}" + assert torch.cuda.is_available() and torch.version.${versionAttr}, message + print(message) + '' diff --git a/pkgs/development/python-modules/torch/tests.nix b/pkgs/development/python-modules/torch/tests.nix index 5a46d0886868c..76b901cbcea91 100644 --- a/pkgs/development/python-modules/torch/tests.nix +++ b/pkgs/development/python-modules/torch/tests.nix @@ -1,3 +1,21 @@ -{ callPackage }: +{ + callPackage, + torchWithCuda, + torchWithRocm, +}: -callPackage ./gpu-checks.nix { } +{ + # To perform the runtime check use either + # `nix run .#python3Packages.torch.tests.tester-cudaAvailable` (outside the sandbox), or + # `nix build .#python3Packages.torch.tests.tester-cudaAvailable.gpuCheck` (in a relaxed sandbox) + tester-cudaAvailable = callPackage ./mk-runtime-check.nix { + feature = "cuda"; + versionAttr = "cuda"; + torch = torchWithCuda; + }; + tester-rocmAvailable = callPackage ./mk-runtime-check.nix { + feature = "rocm"; + versionAttr = "hip"; + torch = torchWithRocm; + }; +}