diff --git a/python/cudf/cudf/utils/_numba.py b/python/cudf/cudf/utils/_numba.py index 574170d28c6..11987741005 100644 --- a/python/cudf/cudf/utils/_numba.py +++ b/python/cudf/cudf/utils/_numba.py @@ -1,11 +1,14 @@ # Copyright (c) 2023-2025, NVIDIA CORPORATION. +from __future__ import annotations import glob import os import sys from functools import lru_cache +import numba from numba import config as numba_config +from packaging import version # Use an lru_cache with a single value to allow a delayed import of @@ -133,18 +136,25 @@ def _setup_numba(): numba_config.CUDA_ENABLE_PYNVJITLINK = True +# Avoids using contextlib.contextmanager due to additional overhead class _CUDFNumbaConfig: - def __enter__(self): + def __enter__(self) -> None: self.CUDA_LOW_OCCUPANCY_WARNINGS = ( numba_config.CUDA_LOW_OCCUPANCY_WARNINGS ) numba_config.CUDA_LOW_OCCUPANCY_WARNINGS = 0 - self.CAPTURED_ERRORS = numba_config.CAPTURED_ERRORS - numba_config.CAPTURED_ERRORS = "new_style" + self.is_numba_lt_061 = version.parse( + numba.__version__ + ) < version.parse("0.61") - def __exit__(self, exc_type, exc_value, traceback): + if self.is_numba_lt_061: + self.CAPTURED_ERRORS = numba_config.CAPTURED_ERRORS + numba_config.CAPTURED_ERRORS = "new_style" + + def __exit__(self, exc_type, exc_value, traceback) -> None: numba_config.CUDA_LOW_OCCUPANCY_WARNINGS = ( self.CUDA_LOW_OCCUPANCY_WARNINGS ) - numba_config.CAPTURED_ERRORS = self.CAPTURED_ERRORS + if self.is_numba_lt_061: + numba_config.CAPTURED_ERRORS = self.CAPTURED_ERRORS