diff --git a/cuda_core/cuda/core/experimental/_launcher.py b/cuda_core/cuda/core/experimental/_launcher.py index 4951358d..b1223365 100644 --- a/cuda_core/cuda/core/experimental/_launcher.py +++ b/cuda_core/cuda/core/experimental/_launcher.py @@ -76,8 +76,11 @@ def launch(kernel, config, *kernel_args): kernel_args = ParamHolder(kernel_args) args_ptr = kernel_args.ptr - driver_ver = handle_return(cuda.cuDriverGetVersion()) - if driver_ver >= 12000: + # Note: CUkernel can still be launched via the old cuLaunchKernel. We check ._backend + # here not because of the CUfunction/CUkernel difference (which depends on whether the + # "old" or "new" module loading APIs are in use), but only as a proxy to check if + # both binding & driver versions support the "Ex" API, which is more feature rich. + if kernel._backend == "new": drv_cfg = cuda.CUlaunchConfig() drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block @@ -86,7 +89,7 @@ def launch(kernel, config, *kernel_args): drv_cfg.numAttrs = 0 # TODO handle_return(cuda.cuLaunchKernelEx( drv_cfg, int(kernel._handle), args_ptr, 0)) - else: + else: # "old" backend # TODO: check if config has any unsupported attrs handle_return(cuda.cuLaunchKernel( int(kernel._handle), diff --git a/cuda_core/cuda/core/experimental/_module.py b/cuda_core/cuda/core/experimental/_module.py index a51ab24f..e5d0808f 100644 --- a/cuda_core/cuda/core/experimental/_module.py +++ b/cuda_core/cuda/core/experimental/_module.py @@ -32,18 +32,19 @@ class Kernel: - __slots__ = ("_handle", "_module",) + __slots__ = ("_handle", "_module", "_backend") def __init__(self): raise NotImplementedError("directly constructing a Kernel instance is not supported") @staticmethod - def _from_obj(obj, mod): + def _from_obj(obj, mod, backend): assert isinstance(obj, _kernel_ctypes) assert isinstance(mod, ObjectCode) ker = Kernel.__new__(Kernel) ker._handle = obj ker._module = mod + ker._backend = backend return ker # TODO: implement from_handle() @@ -51,7 +52,7 @@ def _from_obj(obj, mod): class ObjectCode: - __slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map") + __slots__ = ("_handle", "_code_type", "_module", "_loader", "_loader_backend", "_sym_map") _supported_code_type = ("cubin", "ptx", "fatbin") def __init__(self, module, code_type, jit_options=None, *, @@ -62,6 +63,7 @@ def __init__(self, module, code_type, jit_options=None, *, backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000) else "old" self._loader = _backend[backend] + self._loader_backend = backend if isinstance(module, str): if driver_ver < 12000 and jit_options is not None: @@ -94,6 +96,6 @@ def get_kernel(self, name): except KeyError: name = name.encode() data = handle_return(self._loader["kernel"](self._handle, name)) - return Kernel._from_obj(data, self) + return Kernel._from_obj(data, self, self._loader_backend) # TODO: implement from_handle()