Skip to content

Commit

Permalink
Remove cuda-python from project (#245)
Browse files Browse the repository at this point in the history
Remove cuda-python and use CuPy APIs instead

---------

Co-authored-by: Changho Hwang <[email protected]>
  • Loading branch information
Binyang2014 and chhwang authored Feb 13, 2024
1 parent d97fef4 commit 5971508
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 76 deletions.
114 changes: 53 additions & 61 deletions python/mscclpp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import tempfile
from typing import Any, Type

from cuda import cuda, nvrtc, cudart
import cupy as cp
import numpy as np

Expand All @@ -22,62 +21,42 @@
torchTensor = Type[Any]


def _check_cuda_errors(result):
if result[0].value:
raise RuntimeError(f"CUDA error code={result[0].value}({_cuda_get_error(result[0])})")
if len(result) == 1:
return None
elif len(result) == 2:
return result[1]
else:
return result[1:]


def _cuda_get_error(error):
if isinstance(error, cuda.CUresult):
err, name = cuda.cuGetErrorName(error)
return name if err == cuda.CUresult.CUDA_SUCCESS else "<unknown>"
elif isinstance(error, cudart.cudaError_t):
return cudart.cudaGetErrorName(error)[1]
elif isinstance(error, nvrtc.nvrtcResult):
return nvrtc.nvrtcGetErrorString(error)[1]
else:
raise RuntimeError("Unknown error type: {}".format(error))


class Kernel:
def __init__(self, ptx: bytes, kernel_name: str, device_id: int):
self._context = _check_cuda_errors(cuda.cuCtxGetCurrent())
assert self._context is not None
self._module = _check_cuda_errors(cuda.cuModuleLoadData(ptx))
self._kernel = _check_cuda_errors(cuda.cuModuleGetFunction(self._module, kernel_name.encode()))
CU_LAUNCH_PARAM_BUFFER_POINTER = 0x01
CU_LAUNCH_PARAM_BUFFER_SIZE = 0x02
CU_LAUNCH_PARAM_END = 0x00 if not cp.cuda.runtime.is_hip else 0x03

def __init__(self, ptx: bytes, kernel_name: str):
self._module = cp.cuda.driver.moduleLoadData(ptx)
self._kernel = cp.cuda.driver.moduleGetFunction(self._module, kernel_name)

def launch_kernel(
self,
params: bytes,
nblocks: int,
nthreads: int,
shared: int,
stream: Type[cuda.CUstream] or Type[cudart.cudaStream_t],
stream: Type[cp.cuda.Stream] or Type[None],
):
buffer = (ctypes.c_byte * len(params)).from_buffer_copy(params)
buffer_size = ctypes.c_size_t(len(params))
config = np.array(
[
cuda.CU_LAUNCH_PARAM_BUFFER_POINTER,
Kernel.CU_LAUNCH_PARAM_BUFFER_POINTER,
ctypes.addressof(buffer),
cuda.CU_LAUNCH_PARAM_BUFFER_SIZE,
Kernel.CU_LAUNCH_PARAM_BUFFER_SIZE,
ctypes.addressof(buffer_size),
cuda.CU_LAUNCH_PARAM_END,
Kernel.CU_LAUNCH_PARAM_END,
],
dtype=np.uint64,
)
_check_cuda_errors(
cuda.cuLaunchKernel(self._kernel, nblocks, 1, 1, nthreads, 1, 1, shared, stream, 0, config.ctypes.data)
cuda_stream = stream.ptr if stream else 0
cp.cuda.driver.launchKernel(
self._kernel, nblocks, 1, 1, nthreads, 1, 1, shared, cuda_stream, 0, config.ctypes.data
)

def __del__(self):
cuda.cuModuleUnload(self._module)
cp.cuda.driver.moduleUnload(self._module)


class KernelBuilder:
Expand All @@ -96,35 +75,48 @@ def __init__(self, file: str, kernel_name: str, file_dir: str = None, macro_dict
self.macros = None
if file_dir:
self.macros = ["-D{}={}".format(macro, value) for macro, value in macro_dict.items()]
device_id = cp.cuda.Device().id
ptx = self._compile_cuda(os.path.join(self._current_file_dir, file), f"{kernel_name}.ptx", device_id)
self._kernel = Kernel(ptx, kernel_name, device_id)
ptx = self._compile_cuda(os.path.join(self._current_file_dir, file), f"{kernel_name}.ptx")
self._kernel = Kernel(ptx, kernel_name)
self.kernel_map[kernel_key] = self._kernel

def _compile_cuda(self, source_file, output_file, device_id, std_version="c++17"):
def _compile_cuda(self, source_file, output_file, std_version="c++17"):
mscclpp_home = os.environ.get("MSCCLPP_HOME", "/usr/local/mscclpp")
include_dir = os.path.join(mscclpp_home, "include")
major = _check_cuda_errors(
cudart.cudaDeviceGetAttribute(cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, device_id)
)
minor = _check_cuda_errors(
cudart.cudaDeviceGetAttribute(cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, device_id)
)
cuda_home = os.environ.get("CUDA_HOME")
nvcc = os.path.join(cuda_home, "bin/nvcc") if cuda_home else "nvcc"
command = [
nvcc,
f"-std={std_version}",
"-ptx",
"-Xcompiler",
"-Wall,-Wextra",
f"-I{include_dir}",
f"{source_file}",
f"--gpu-architecture=compute_{major}{minor}",
f"--gpu-code=sm_{major}{minor},compute_{major}{minor}",
"-o",
f"{self._tempdir.name}/{output_file}",
]
if not cp.cuda.runtime.is_hip:
compute_capability = cp.cuda.Device().compute_capability
cuda_home = os.environ.get("CUDA_HOME")
nvcc = os.path.join(cuda_home, "bin/nvcc") if cuda_home else "nvcc"
command = [
nvcc,
f"-std={std_version}",
"-ptx",
"-Xcompiler",
"-Wall,-Wextra",
f"-I{include_dir}",
f"{source_file}",
f"--gpu-architecture=compute_{compute_capability}",
f"--gpu-code=sm_{compute_capability},compute_{compute_capability}",
"-o",
f"{self._tempdir.name}/{output_file}",
]
else:
# the gcn arch name is like "gfx942:sramecc+:xnack-"
gcn_arch = (
cp.cuda.runtime.getDeviceProperties(cp.cuda.Device().id)["gcnArchName"].decode("utf-8").split(":")[0]
)
rocm_home = os.environ.get("ROCM_HOME")
hipcc = os.path.join(rocm_home, "bin/hipcc") if rocm_home else "hipcc"
command = [
hipcc,
f"-std={std_version}",
"--genco",
"-D__HIP_PLATFORM_AMD__",
f"--offload-arch={gcn_arch}",
f"-I{include_dir}",
f"{source_file}",
"-o",
f"{self._tempdir.name}/{output_file}",
]
if self.macros:
command += self.macros
try:
Expand Down
4 changes: 4 additions & 0 deletions python/mscclpp_benchmark/allreduce.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#else
#include <cuda_fp16.h>
#endif

#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/nvls_device.hpp>
Expand Down
4 changes: 2 additions & 2 deletions python/mscclpp_benchmark/allreduce_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def check_correctness(memory, func, niter=100):
for p in range(niter):
memory[:] = cp.ones(memory.shape).astype(data_type) * (p * MPI.COMM_WORLD.size + MPI.COMM_WORLD.rank)
cp.cuda.runtime.deviceSynchronize()
output_memory = func(0)
output_memory = func(None)
cp.cuda.runtime.deviceSynchronize()
expected = cp.zeros_like(memory)
for i in range(MPI.COMM_WORLD.size):
Expand All @@ -110,7 +110,7 @@ def bench_time(niter: int, func):
with stream:
stream.begin_capture()
for i in range(niter):
func(stream.ptr)
func(stream)
graph = stream.end_capture()

# now run a warm up round
Expand Down
20 changes: 10 additions & 10 deletions python/mscclpp_benchmark/mscclpp_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def __init__(

self.set_params(nblocks, block_size, read_only)

def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
def __call__(self, stream):
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream)
return self.memory

def set_params(self, nblocks, block_size, read_only):
Expand Down Expand Up @@ -131,8 +131,8 @@ def __init__(

self.set_params(nblocks, block_size)

def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
def __call__(self, stream):
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream)
return self.memory_out

def set_params(self, nblocks, block_size):
Expand Down Expand Up @@ -201,8 +201,8 @@ def __init__(

self.set_params(nblocks, block_size)

def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, 24, 1024, 0, stream_ptr)
def __call__(self, stream):
self.kernel.launch_kernel(self.params, 24, 1024, 0, stream)
return self.memory

def set_params(self, nblocks, block_size):
Expand Down Expand Up @@ -295,8 +295,8 @@ def __init__(

self.set_params(nblocks, block_size, pipeline_depth)

def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
def __call__(self, stream):
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream)
return self.memory

def set_params(self, nblocks, block_size, pipeline_depth):
Expand Down Expand Up @@ -388,8 +388,8 @@ def __init__(

self.set_params(nblocks, block_size)

def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
def __call__(self, stream):
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream)
return self.memory_out

def set_params(self, nblocks, block_size):
Expand Down
3 changes: 2 additions & 1 deletion python/mscclpp_benchmark/nccl_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def __init__(self, nccl_comm: nccl.NcclCommunicator, memory: cp.ndarray):
else:
raise RuntimeError("Make sure that the data type is mapped to the correct NCCL data type")

def __call__(self, stream_ptr):
def __call__(self, stream):
stream_ptr = stream.ptr if stream else 0
self.nccl_comm.allReduce(
self.memory.data.ptr, self.memory.data.ptr, self.memory.size, self.nccl_dtype, nccl.NCCL_SUM, stream_ptr
)
Expand Down
1 change: 0 additions & 1 deletion python/requirements_cu11.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
mpi4py
cupy-cuda11x
prettytable
cuda-python
netifaces
pytest
numpy
Expand Down
1 change: 0 additions & 1 deletion python/requirements_cu12.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
mpi4py
cupy-cuda12x
prettytable
cuda-python
netifaces
pytest
numpy
Expand Down

0 comments on commit 5971508

Please sign in to comment.