From 953e3e31f88a0c228976826abe3b1f90bb756866 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Mon, 13 Nov 2023 14:06:46 -0800 Subject: [PATCH] [op-builder] use unique exceptions for cuda issues (#4653) Another follow-up to https://github.com/microsoft/DeepSpeed/commit/4f7dd7214b1d81dbbdff826015a67accc10390d2 based on offline discussion w. @tjruwase /cc @loadams Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- op_builder/builder.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index 79692ce05878..3613791c938d 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -35,10 +35,19 @@ TORCH_MINOR = int(torch.__version__.split('.')[1]) +class MissingCUDAException(Exception): + pass + + +class CUDAMismatchException(Exception): + pass + + def installed_cuda_version(name=""): import torch.utils.cpp_extension cuda_home = torch.utils.cpp_extension.CUDA_HOME - assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)" + if cuda_home is None: + raise MissingCUDAException("CUDA_HOME does not exist, unable to compile CUDA op(s)") # Ensure there is not a cuda version mismatch between torch and nvcc compiler output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True) output_split = output.split() @@ -89,9 +98,10 @@ def assert_no_cuda_mismatch(name=""): "Detected `DS_SKIP_CUDA_CHECK=1`: Allowing this combination of CUDA, but it may result in unexpected behavior." ) return True - raise Exception(f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " - f"version torch was compiled with {torch.version.cuda}, unable to compile " - "cuda/cpp extensions without a matching cuda version.") + raise CUDAMismatchException( + f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda}, unable to compile " + "cuda/cpp extensions without a matching cuda version.") return True @@ -339,7 +349,7 @@ def is_cuda_enable(self): try: assert_no_cuda_mismatch(self.name) return '-D__ENABLE_CUDA__' - except BaseException: + except MissingCUDAException: print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " "only cpu ops can be compiled!") return '-D__DISABLE_CUDA__' @@ -601,7 +611,7 @@ def builder(self): if not self.is_rocm_pytorch(): assert_no_cuda_mismatch(self.name) self.build_for_cpu = False - except BaseException: + except MissingCUDAException: self.build_for_cpu = True if self.build_for_cpu: