Skip to content

Commit

Permalink
simplify & fix module/library handling
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Oct 26, 2024
1 parent b319731 commit b64f337
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions cuda_core/cuda/core/experimental/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"kernel": cuda.cuModuleGetFunction,
},
}
_kernel_ctypes = [cuda.CUfunction]

# binding availability depends on cuda-python version
py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0])
Expand All @@ -25,8 +24,10 @@
"data": cuda.cuLibraryLoadData,
"kernel": cuda.cuLibraryGetKernel,
}
_kernel_ctypes.append(cuda.CUkernel)
_kernel_ctypes = tuple(_kernel_ctypes)
_kernel_ctypes = (cuda.CUfunction, cuda.CUkernel)
else:
_kernel_ctypes = (cuda.CUfunction,)
driver_ver = handle_return(cuda.cuDriverGetVersion())


class Kernel:
Expand All @@ -45,6 +46,8 @@ def _from_obj(obj, mod):
ker._module = mod
return ker

# TODO: implement from_handle()


class ObjectCode:

Expand All @@ -57,11 +60,8 @@ def __init__(self, module, code_type, jit_options=None, *,
raise ValueError
self._handle = None

driver_ver = handle_return(cuda.cuDriverGetVersion())
if py_major_ver >= 12 and driver_ver >= 12000:
self._loader = _backend["new"]
else:
self._loader = _backend["old"]
backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000) else "old"
self._loader = _backend[backend]

if isinstance(module, str):
if driver_ver < 12000 and jit_options is not None:
Expand All @@ -72,11 +72,11 @@ def __init__(self, module, code_type, jit_options=None, *,
assert isinstance(module, bytes)
if jit_options is None:
jit_options = {}
if driver_ver >= 12000:
if backend == "new":
args = (module, list(jit_options.keys()), list(jit_options.values()), len(jit_options),
# TODO: support library options
[], [], 0)
else:
else: # "old" backend
args = (module, len(jit_options), list(jit_options.keys()), list(jit_options.values()))
self._handle = handle_return(self._loader["data"](*args))

Expand All @@ -95,3 +95,5 @@ def get_kernel(self, name):
name = name.encode()
data = handle_return(self._loader["kernel"](self._handle, name))
return Kernel._from_obj(data, self)

# TODO: implement from_handle()

0 comments on commit b64f337

Please sign in to comment.