From c9d7bacf70acbccd4321127ff76a19475797e9df Mon Sep 17 00:00:00 2001 From: "Benedikt J. Daurer" Date: Tue, 5 Mar 2024 09:20:28 +0000 Subject: [PATCH] update dependencies and better error handling (#540) --- cufft/dependencies.yml | 6 +++++- cufft/extensions.py | 5 +++++ cufft/setup.py | 13 ++++++++++--- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/cufft/dependencies.yml b/cufft/dependencies.yml index 48f17a1e7..2b081a06a 100644 --- a/cufft/dependencies.yml +++ b/cufft/dependencies.yml @@ -1,10 +1,14 @@ name: ptypy_cufft channels: - conda-forge + - nvidia dependencies: - python - cmake>=3.8.0 - pybind11 - compilers - - cudatoolkit-dev + - cuda-nvcc + - cuda-cudart-dev + - libcufft-dev + - libcufft-static - pip \ No newline at end of file diff --git a/cufft/extensions.py b/cufft/extensions.py index 545b43d04..4d8b7f598 100644 --- a/cufft/extensions.py +++ b/cufft/extensions.py @@ -40,6 +40,11 @@ def locate_cuda(): cudaconfig = {'home': home, 'nvcc': nvcc, 'include': os.path.join(home, 'include'), 'lib64': os.path.join(home, 'lib64')} + + # If lib64 does not exist, try lib instead (as common in conda env) + if not os.path.exists(cudaconfig['lib64']): + cudaconfig['lib64'] = os.path.join(home, 'lib') + for k, v in cudaconfig.items(): if not os.path.exists(v): raise EnvironmentError('The CUDA %s path could not be located in %s' % (k, v)) diff --git a/cufft/setup.py b/cufft/setup.py index 5108ebf32..a8ef7c61e 100644 --- a/cufft/setup.py +++ b/cufft/setup.py @@ -24,12 +24,19 @@ ) cmdclass = {"build_ext": CustomBuildExt} EXTBUILD_MESSAGE = "The filtered cufft extension has been successfully installed.\n" -except: +except EnvironmentError as e: EXTBUILD_MESSAGE = '*' * 75 + "\n" EXTBUILD_MESSAGE += "Could not install the filtered cufft extension.\n" - EXTBUILD_MESSAGE += "Make sure to have CUDA >= 10 and pybind11 installed.\n" + EXTBUILD_MESSAGE += "Make sure to have CUDA >= 10 installed.\n" EXTBUILD_MESSAGE += '*' * 75 + "\n" - + EXTBUILD_MESSAGE += 'Error message: ' + str(e) +except ImportError as e: + EXTBUILD_MESSAGE = '*' * 75 + "\n" + EXTBUILD_MESSAGE += "Could not install the filtered cufft extension.\n" + EXTBUILD_MESSAGE += "Make sure to have pybind11 installed.\n" + EXTBUILD_MESSAGE += '*' * 75 + "\n" + EXTBUILD_MESSAGE += 'Error message: ' + str(e) + exclude_packages = [] package_list = setuptools.find_packages(exclude=exclude_packages) setup(