diff --git a/include/mscclpp/executor.hpp b/include/mscclpp/executor.hpp index b964c5625..e994548e4 100644 --- a/include/mscclpp/executor.hpp +++ b/include/mscclpp/executor.hpp @@ -6,11 +6,18 @@ #include #include -#include #include namespace mscclpp { +enum class DataType { + INT32, + UINT32, + FLOAT16, + FLOAT32, + BFLOAT16, +}; + enum class PacketType { LL8, LL16, diff --git a/include/mscclpp/gpu.hpp b/include/mscclpp/gpu.hpp index 059f579c0..e860c7b93 100644 --- a/include/mscclpp/gpu.hpp +++ b/include/mscclpp/gpu.hpp @@ -103,13 +103,6 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri #define CUDA_NVLS_SUPPORTED 0 #endif // !defined(__HIP_PLATFORM_AMD__) -// Fabric -#if !defined(__HIP_PLATFORM_AMD__) -#define CUDA_FABRIC_SUPPORTED ((CUDART_VERSION >= 12040)) -#else // !defined(__HIP_PLATFORM_AMD__) -#define CUDA_FABRIC_SUPPORTED 0 -#endif // !defined(__HIP_PLATFORM_AMD__) - // GPU sync threads #if defined(__HIP_PLATFORM_AMD__) #define __syncshm() asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier"); diff --git a/include/mscclpp/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp index 85e6bd812..10dea3e59 100644 --- a/include/mscclpp/gpu_data_types.hpp +++ b/include/mscclpp/gpu_data_types.hpp @@ -29,16 +29,4 @@ using __bfloat162 = __nv_bfloat162; #endif -namespace mscclpp { - -enum class DataType { - INT32, - UINT32, - FLOAT16, - FLOAT32, - BFLOAT16, -}; - -} // namespace mscclpp - #endif // MSCCLPP_GPU_DATA_TYPES_HPP_ diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp index 4baf8af2d..cbfa882de 100644 --- a/include/mscclpp/gpu_utils.hpp +++ b/include/mscclpp/gpu_utils.hpp @@ -96,7 +96,7 @@ T* cudaCalloc(size_t nelem) { return ptr; } -#if (CUDA_FABRIC_SUPPORTED) +#if (CUDA_NVLS_SUPPORTED) template T* cudaPhysicalCalloc(size_t nbytes, size_t gran) { AvoidCudaGraphCaptureGuard cgcGuard; @@ -246,7 +246,7 @@ std::shared_ptr allocSharedCuda(size_t count = 1) { return detail::safeAlloc, CudaDeleter, std::shared_ptr>(count); } -#if (CUDA_FABRIC_SUPPORTED) +#if (CUDA_NVLS_SUPPORTED) static inline size_t getMulticastGranularity(size_t size, CUmulticastGranularity_flags granFlag) { size_t gran = 0; int numDevices = 0; @@ -270,9 +270,9 @@ static inline size_t getMulticastGranularity(size_t size, CUmulticastGranularity /// @return A std::shared_ptr to the allocated memory. template std::shared_ptr allocSharedPhysicalCuda([[maybe_unused]] size_t count, [[maybe_unused]] size_t gran = 0) { -#if (CUDA_FABRIC_SUPPORTED) - if (!isFabricSupported()) { - throw Error("Only suupport GPU with Fabric support", ErrorCode::InvalidUsage); +#if (CUDA_NVLS_SUPPORTED) + if (!isNvlsSupported()) { + throw Error("Only support GPU with NVLS support", ErrorCode::InvalidUsage); } if (count == 0) { return nullptr; @@ -387,9 +387,9 @@ UniqueCudaHostPtr makeUniqueCudaHost(size_t count) { /// @return A std::unique_ptr to the allocated memory. 
template std::unique_ptr allocUniquePhysicalCuda([[maybe_unused]] size_t count, [[maybe_unused]] size_t gran = 0) { -#if (CUDA_FABRIC_SUPPORTED) - if (!isFabricSupported()) { - throw Error("Only suupport GPU with Fabric support", ErrorCode::InvalidUsage); +#if (CUDA_NVLS_SUPPORTED) + if (!isNvlsSupported()) { + throw Error("Only support GPU with NVLS support", ErrorCode::InvalidUsage); } if (count == 0) { return nullptr; diff --git a/include/mscclpp/nvls.hpp b/include/mscclpp/nvls.hpp index c6d96d2a4..126afa6e5 100644 --- a/include/mscclpp/nvls.hpp +++ b/include/mscclpp/nvls.hpp @@ -26,25 +26,20 @@ class NvlsConnection { struct DeviceMulticastPointer { private: void* devicePtr_; - std::shared_ptr ptr_; std::shared_ptr mcPtr_; size_t bufferSize_; public: using DeviceHandle = DeviceMulticastPointerDeviceHandle; - DeviceMulticastPointer(std::shared_ptr ptr, std::shared_ptr mcPtr, size_t bufferSize) - : devicePtr_(ptr.get()), ptr_(ptr), mcPtr_(mcPtr), bufferSize_(bufferSize) {} DeviceMulticastPointer(void* devicePtr, std::shared_ptr mcPtr, size_t bufferSize) - : devicePtr_(devicePtr), ptr_(nullptr), mcPtr_(mcPtr), bufferSize_(bufferSize) {} + : devicePtr_(devicePtr), mcPtr_(mcPtr), bufferSize_(bufferSize) {} DeviceHandle deviceHandle(); void* getDevicePtr(); friend class NvlsConnection; }; - std::shared_ptr allocateAndBindCuda(size_t size); - - std::shared_ptr bindAllocatedCuda(CUdeviceptr devicePtr, size_t size); + DeviceMulticastPointer bindAllocatedMemory(CUdeviceptr devicePtr, size_t size); size_t getMultiCastMinGranularity(); diff --git a/include/mscclpp/utils.hpp b/include/mscclpp/utils.hpp index 6c31482b3..80b3bf39d 100644 --- a/include/mscclpp/utils.hpp +++ b/include/mscclpp/utils.hpp @@ -38,7 +38,6 @@ struct ScopedTimer : public Timer { std::string getHostName(int maxlen, const char delim); bool isNvlsSupported(); -bool isFabricSupported(); } // namespace mscclpp diff --git a/python/mscclpp/gpu_utils_py.cpp b/python/mscclpp/gpu_utils_py.cpp index c019ca377..32c578fb7 100644 --- a/python/mscclpp/gpu_utils_py.cpp +++ b/python/mscclpp/gpu_utils_py.cpp @@ -10,45 +10,21 @@ using namespace mscclpp; class PyCudaMemory { public: - PyCudaMemory(size_t size, DataType dtype) : size_(size), dtype_(dtype) { ptr_ = allocSharedPhysicalCuda(size); } + PyCudaMemory(size_t size) : size_(size) { ptr_ = allocSharedPhysicalCuda(size); } uintptr_t getPtr() const { return (uintptr_t)(ptr_.get()); } size_t size() const { return size_; } - DataType dtype() const { return dtype_; } private: std::shared_ptr ptr_; size_t size_; - DataType dtype_; }; -std::shared_ptr allocSharedPhysicalCudaDispatcher(size_t count, DataType dtype) { - size_t size = 0; - switch (dtype) { - case DataType::FLOAT32: - size = count * sizeof(float); - break; - case DataType::FLOAT16: - size = count * sizeof(__half); - break; - case DataType::BFLOAT16: - size = count * sizeof(__bfloat16); - break; - case DataType::INT32: - size = count * sizeof(int); - break; - default: - throw std::runtime_error("Unsupported data type."); - } - - return std::make_shared(size, dtype); -} - void register_gpu_utils(nb::module_& m) { nb::class_(m, "PyCudaMemory") - .def(nb::init(), nb::arg("size"), nb::arg("dtype")) + .def(nb::init(), nb::arg("size")) .def("get_ptr", &PyCudaMemory::getPtr, "Get the raw pointer") - .def("size", &PyCudaMemory::size, "Get the size of the allocated memory") - .def("dtype", &PyCudaMemory::dtype, "Get the data type of the memory"); - m.def("alloc_shared_physical_cuda", &allocSharedPhysicalCudaDispatcher, nb::arg("count"),
nb::arg("dtype")); + .def("size", &PyCudaMemory::size, "Get the size of the allocated memory"); + m.def( + "alloc_shared_physical_cuda", [](size_t size) { return std::make_shared(size); }, nb::arg("size")); } diff --git a/python/mscclpp/nvls_py.cpp b/python/mscclpp/nvls_py.cpp index 819a7c6b0..91b966bd8 100644 --- a/python/mscclpp/nvls_py.cpp +++ b/python/mscclpp/nvls_py.cpp @@ -30,7 +30,7 @@ void register_nvls(nb::module_& m) { }); nb::class_(m, "NvlsConnection") - .def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda) + .def("bind_allocated_memory", &NvlsConnection::bindAllocatedMemory, nb::arg("devicePtr"), nb::arg("size")) .def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity); m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("allRanks"), diff --git a/python/mscclpp_benchmark/mscclpp_op.py b/python/mscclpp_benchmark/mscclpp_op.py index 706107bef..d161f439e 100644 --- a/python/mscclpp_benchmark/mscclpp_op.py +++ b/python/mscclpp_benchmark/mscclpp_op.py @@ -1,7 +1,7 @@ import os import cupy as cp import ctypes -from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore +from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore, alloc_shared_physical_cuda import mscclpp.comm as mscclpp_comm from mscclpp.utils import KernelBuilder, pack @@ -443,10 +443,13 @@ def __init__( self.nvls_connection = group.make_connection(all_ranks, Transport.Nvls) min_gran = self.nvls_connection.get_multicast_min_granularity() aligned_buffer_size = int(((buffer_size + min_gran - 1) // min_gran) * min_gran) - self.nvls_mem_handle = self.nvls_connection.allocate_bind_memory( + buffer_raw = alloc_shared_physical_cuda(aligned_buffer_size) + self.nvls_mem_handle = self.nvls_connection.bind_allocated_memory( + buffer_raw.get_ptr(), aligned_buffer_size ) # just using recommended size for now self.memory_ptr = self.nvls_mem_handle.get_device_ptr() + self.buffer_raw = buffer_raw self.cp_memory_ptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(self.memory_ptr, aligned_buffer_size, None), 0) self.memory = cp.ndarray(nelem, memory_dtype, self.cp_memory_ptr) diff --git a/python/test/executor_test.py b/python/test/executor_test.py index 9d9b3a708..29ad804fc 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -12,18 +12,33 @@ is_nvls_supported, ) import mscclpp.comm as mscclpp_comm +from mscclpp.utils import KernelBuilder, pack import os +import struct import cupy as cp from mpi4py import MPI -def bench_time(niters: int, ngraphIters: int, func): - # capture cuda graph for niters of the kernel launch +def parse_dtype(dtype_str): + """Convert a human-readable data type string to a numpy data type.""" + dtype_str = dtype_str.strip().lower() + if dtype_str == "float16": + return cp.float16 + elif dtype_str == "float32": + return cp.float32 + elif dtype_str == "int32": + return cp.int32 + else: + raise ValueError(f"Unknown data type: {dtype_str}") + + +def bench_time(n_iters: int, n_graph_iters: int, func): + # capture cuda graph for n_iters of the kernel launch stream = cp.cuda.Stream(non_blocking=True) with stream: stream.begin_capture() - for i in range(niters): + for i in range(n_iters): func(stream) graph = stream.end_capture() @@ -35,12 +50,58 @@ def bench_time(niters: int, ngraphIters: int, func): end = cp.cuda.Event() start.record(stream) - for _ in range(ngraphIters): + for _ in range(n_graph_iters): graph.launch(stream) end.record(stream) end.synchronize() - return 
cp.cuda.get_elapsed_time(start, end) / niters * 1000.0 / ngraphIters + return cp.cuda.get_elapsed_time(start, end) / n_iters * 1000.0 / n_graph_iters + + +def bench_correctness( + execution_plan_name: str, + input_buf: cp.ndarray, + result_buf: cp.ndarray, + test_buf: cp.ndarray, + dtype_str: str, + rank: int, + num_ranks: int, + n_iters: int, + func, +): + type_size = cp.dtype(parse_dtype(dtype_str)).itemsize + + fill_data_kernel_name = "fill_data_%s" % dtype_str + if "allgather" in execution_plan_name: + coll = "all_gather" + else: + coll = "all_reduce" + test_data_kernel_name = "test_data_%s_%s" % (coll, dtype_str) + + file_dir = os.path.dirname(os.path.abspath(__file__)) + fill_data_kernel = KernelBuilder( + file="executor_test_verifier.cu", kernel_name=fill_data_kernel_name, file_dir=file_dir + ).get_compiled_kernel() + test_data_kernel = KernelBuilder( + file="executor_test_verifier.cu", kernel_name=test_data_kernel_name, file_dir=file_dir + ).get_compiled_kernel() + nblocks = 64 + nthreads = 1024 + + stream = cp.cuda.Stream(non_blocking=True) + with stream: + stream.begin_capture() + for i in range(n_iters): + fill_data_params = pack(input_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(rank, i) + fill_data_kernel.launch_kernel(fill_data_params, nblocks, nthreads, 0, stream) + func(stream) + test_data_params = ( + pack(result_buf, test_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(num_ranks, i) + ) + test_data_kernel.launch_kernel(test_data_params, nblocks, nthreads, 0, stream) + graph = stream.end_capture() + graph.launch(stream) + stream.synchronize() def parse_size(size_str): @@ -55,19 +116,6 @@ def parse_size(size_str): return int(size_str) -def parse_dtype(dtype_str): - """Convert a human-readable data type string to a numpy data type.""" - dtype_str = dtype_str.strip().lower() - if dtype_str == "float16": - return cp.float16 - elif dtype_str == "float32": - return cp.float32 - elif dtype_str == "int32": - return cp.int32 - else: - raise ValueError(f"Unknown data type: {dtype_str}") - - def dtype_to_mscclpp_dtype(dtype): if dtype == cp.float16: return DataType.float16 @@ -79,25 +127,9 @@ def dtype_to_mscclpp_dtype(dtype): raise ValueError(f"Unknown data type: {dtype}") -def determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name): - if "allgather" in execution_plan_name and in_place: - return recvbuf - else: - return sendbuf - - -def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name): - if "allgather" in execution_plan_name: - return recvbuf - elif in_place: - return sendbuf - else: - return recvbuf - - def allocate_buffer(nelems, dtype): if is_nvls_supported: - buffer_raw = alloc_shared_physical_cuda(nelems, dtype=dtype_to_mscclpp_dtype(dtype)) + buffer_raw = alloc_shared_physical_cuda(nelems * cp.dtype(dtype).itemsize) buffer_ptr = cp.cuda.MemoryPointer( cp.cuda.UnownedMemory(buffer_raw.get_ptr(), buffer_raw.size(), buffer_raw), 0 ) @@ -107,14 +139,47 @@ def allocate_buffer(nelems, dtype): return cp.zeros(nelems, dtype=dtype) +def build_bufs( + execution_plan_name: str, + size: int, + in_place: bool, + dtype: cp.dtype, + rank: int, + num_ranks: int, +): + type_size = cp.dtype(dtype).itemsize + assert (size % type_size) == 0, "size %d not multiple of type size %d" % (size, type_size) + nelems = size // type_size + + if "allgather" in execution_plan_name: + assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks) + nelems_input = nelems if in_place else nelems // num_ranks + 
else: + nelems_input = nelems + nelems_output = nelems + + result_buf = allocate_buffer(nelems_output, dtype=dtype) + if in_place: + if "allgather" in execution_plan_name: + input_buf = cp.split(result_buf, num_ranks)[rank] + else: + input_buf = result_buf + else: + input_buf = allocate_buffer(nelems_input, dtype=dtype) + test_buf = cp.zeros(nelems_output, dtype=dtype) + + return input_buf, result_buf, test_buf + + def main( execution_plan_name: str, execution_plan_path: str, size: int, in_place: bool = True, - dtype: cp.dtype = cp.float16, + dtype_str: str = "float16", packet_type: PacketType = PacketType.LL16, - seed: int = 42, + n_iters: int = 10, + n_graph_iters: int = 10, ): mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD) cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use() @@ -124,56 +189,49 @@ def main( npkit.init(mscclpp_group.my_rank) execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path) - cp.random.seed(seed) - nelems = size // cp.dtype(dtype).itemsize - buffer = cp.random.random(nelems * mscclpp_group.nranks, dtype=cp.float32).astype(dtype) - sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size) - sendbuf = allocate_buffer(nelems, dtype) - sendbuf[:] = sub_arrays[MPI.COMM_WORLD.rank] - - if "allgather" in execution_plan_name: - recvbuf = allocate_buffer(nelems * mscclpp_group.nranks, dtype) - if in_place: - for i in range(nelems): - recvbuf[mscclpp_group.my_rank * nelems + i] = sendbuf[i] - expected = buffer - else: - recvbuf = allocate_buffer(nelems, dtype) - expected = cp.zeros_like(sendbuf, dtype=dtype) - for i in range(mscclpp_group.nranks): - expected += sub_arrays[i] - mscclpp_group.barrier() + dtype = parse_dtype(dtype_str) + input_buf, result_buf, test_buf = build_bufs( + execution_plan_name, + size, + in_place, + dtype, + mscclpp_group.my_rank, + mscclpp_group.nranks, + ) executor_func = lambda stream: executor.execute( - MPI.COMM_WORLD.rank, - determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr, - determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr, - determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes, - determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes, + mscclpp_group.my_rank, + input_buf.data.ptr, + result_buf.data.ptr, + input_buf.nbytes, + result_buf.nbytes, dtype_to_mscclpp_dtype(dtype), execution_plan, stream.ptr, packet_type, ) - # check correctness - stream = cp.cuda.Stream(non_blocking=True) - executor_func(stream) - stream.synchronize() - assert cp.allclose( - determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name), - expected, - atol=1e-2 * mscclpp_group.nranks, + mscclpp_group.barrier() + bench_correctness( + execution_plan_name, + input_buf, + result_buf, + test_buf, + dtype_str, + mscclpp_group.my_rank, + mscclpp_group.nranks, + n_iters, + executor_func, ) mscclpp_group.barrier() - execution_time = bench_time(10, 10, executor_func) + execution_time = bench_time(n_iters, n_graph_iters, executor_func) if npkit_dump_dir is not None: npkit.dump(npkit_dump_dir) npkit.shutdown() print( - f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, " - f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} " + f"Rank: {mscclpp_group.my_rank} Execution time: {execution_time} us, " + f"data size: {result_buf.nbytes} bytes data type: {dtype().dtype.name} " f"packet type: {packet_type}" ) executor = None @@ -188,7 +246,8 @@ def main( parser.add_argument("--in_place", 
action="store_true", help="flag to define an in-place operation", default=True) parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32") parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16") - parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_iters", type=int, default=10) + parser.add_argument("--n_graph_iters", type=int, default=10) args = parser.parse_args() packet_type = PacketType.LL16 @@ -196,13 +255,13 @@ def main( packet_type = PacketType.LL8 buffer_size = parse_size(args.size) - dtype = parse_dtype(args.dtype) main( args.execution_plan_name, args.execution_plan_path, buffer_size, args.in_place, - dtype, + args.dtype, packet_type, - args.seed, + args.n_iters, + args.n_graph_iters, ) diff --git a/python/test/executor_test_verifier.cu b/python/test/executor_test_verifier.cu new file mode 100644 index 000000000..d99b93611 --- /dev/null +++ b/python/test/executor_test_verifier.cu @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include + +#if defined(__HIP_PLATFORM_AMD__) +#include +#else +#include +#endif + +// Numerical Recipes ranqd1, Chapter 7.1, §An Even Quicker Generator, Eq. 7.1.6 +// parameters from Knuth and H. W. Lewis +static __device__ unsigned int ranqd1(unsigned int seed) { + const unsigned int a = 1664525; + const unsigned int c = 1013904223; + return a * seed + c; +} + +// fill/test kernel pairs must have the same thread block size to +// match their random number series. + +#define FILL_DATA(FuncNameType, DataType) \ + extern "C" __global__ void __launch_bounds__(1024, 1) \ + fill_data_##FuncNameType(DataType* input_buf, size_t num_elems, int rank, int seq) { \ + unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + seed = ranqd1(seed); \ + input_buf[i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \ + } \ + } + +FILL_DATA(float16, __half) +FILL_DATA(float32, float) +FILL_DATA(int32, int) + +#define TEST_DATA_ALL_GATHER(FuncNameType, DataType) \ + extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_gather_##FuncNameType( \ + DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int seq) { \ + for (int rank = 0; rank < num_ranks; rank++) { \ + size_t rank_offset = rank * num_elems; \ + unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + seed = ranqd1(seed); \ + test_buf[rank_offset + i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \ + assert(result_buf[i] == test_buf[i]); \ + } \ + } \ + } + +TEST_DATA_ALL_GATHER(float16, __half) +TEST_DATA_ALL_GATHER(float32, float) +TEST_DATA_ALL_GATHER(int32, int) + +#define TEST_DATA_ALL_REDUCE(FuncNameType, DataType) \ + extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_reduce_##FuncNameType( \ + DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int seq) { \ + for (int rank = 0; rank < num_ranks; rank++) { \ + unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + if (rank == 0) { \ + test_buf[i] = 0; \ + } \ + seed = ranqd1(seed); \ + test_buf[i] += 
DataType(seed % blockDim.x) / DataType(blockDim.x); \ + } \ + } \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \ + } \ + } + +TEST_DATA_ALL_REDUCE(float16, __half) +TEST_DATA_ALL_REDUCE(float32, float) +TEST_DATA_ALL_REDUCE(int32, int) diff --git a/src/context.cc b/src/context.cc index f8bb3ec83..12b52b78a 100644 --- a/src/context.cc +++ b/src/context.cc @@ -11,7 +11,7 @@ namespace mscclpp { -Context::Impl::Impl() : ipcStream_(cudaStreamNonBlocking) {} +Context::Impl::Impl() {} IbCtx* Context::Impl::getIbContext(Transport ibTransport) { // Find IB context or create it @@ -43,7 +43,10 @@ MSCCLPP_API_CPP std::shared_ptr Context::connect(Endpoint localEndpo if (remoteEndpoint.transport() != Transport::CudaIpc) { throw mscclpp::Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage); } - conn = std::make_shared(localEndpoint, remoteEndpoint, pimpl_->ipcStream_); + if (!(pimpl_->ipcStream_)) { + pimpl_->ipcStream_ = std::make_shared(cudaStreamNonBlocking); + } + conn = std::make_shared(localEndpoint, remoteEndpoint, cudaStream_t(*(pimpl_->ipcStream_))); } else if (AllIBTransports.has(localEndpoint.transport())) { if (!AllIBTransports.has(remoteEndpoint.transport())) { throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage); diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc index 7482c922c..49ceddf0a 100644 --- a/src/executor/execution_plan.cc +++ b/src/executor/execution_plan.cc @@ -332,7 +332,20 @@ void ExecutionPlan::Impl::setupChannels(const json& gpus) { } } -void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffset, size_t constDstOffset) { +void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffset, size_t constDstOffset) { + auto getConstOffset = [&](BufferType type) -> size_t { + switch (type) { + case BufferType::INPUT: + return constSrcOffset; + case BufferType::OUTPUT: + return constDstOffset; + case BufferType::SCRATCH: + return 0; + default: + throw Error("Invalid buffer type", ErrorCode::ExecutorError); + } + }; + // setup threadblocks and operations for (const auto& gpu : gpus) { int rank = gpu["id"]; @@ -378,7 +391,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["i_cids"][i]["id"]]; operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["i_cids"][i]["off"]) + - (srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0); + getConstOffset(srcBufferType); chunkIndexes.push_back((uint32_t)op["i_cids"][i]["off"]); } } @@ -390,7 +403,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse for (int i = 0; i < operation.nInputs; i++) { operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcs"][i]["off"]) + - (operation.inputBufferType != BufferType::SCRATCH ? 
contsSrcOffset : 0); + getConstOffset(operation.inputBufferType); chunkIndexes.push_back((uint32_t)op["srcs"][i]["off"]); } } @@ -409,7 +422,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["o_cids"][i]["id"]]; operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["o_cids"][i]["off"]) + - (dstBufferType != BufferType::SCRATCH ? constDstOffset : 0); + getConstOffset(dstBufferType); chunkIndexes.push_back((uint32_t)op["o_cids"][i]["off"]); } } @@ -421,7 +434,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse for (int i = 0; i < operation.nOutputs; i++) { operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dsts"][i]["off"]) + - (operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0); + getConstOffset(operation.outputBufferType); chunkIndexes.push_back((uint32_t)op["dsts"][i]["off"]); } } diff --git a/src/executor/executor.cc b/src/executor/executor.cc index d8b295d72..ae34fa1bb 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -36,21 +36,62 @@ void* getBuffer(BufferType type, void* sendbuff, void* recvbuff, void* scratch) throw Error("Invalid buffer type", ErrorCode::ExecutorError); } }; + +struct DeviceExecutionPlanKey { + size_t inputMessageSize; + size_t outputMessageSize; + size_t constSrcOffset; + size_t constDstOffset; + + bool operator==(const DeviceExecutionPlanKey& other) const { + return inputMessageSize == other.inputMessageSize && outputMessageSize == other.outputMessageSize && + constSrcOffset == other.constSrcOffset && constDstOffset == other.constDstOffset; + } +}; + } // namespace mscclpp namespace std { + +// Refer https://www.boost.org/doc/libs/1_86_0/libs/container_hash/doc/html/hash.html#combine +template +inline void hash_combine(std::size_t& seed, const T& value) { + std::hash hasher; + seed ^= hasher(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + template <> struct hash> { std::size_t operator()(const std::pair& key) const { - return std::hash()(key.second) ^ std::hash()(static_cast(key.first)); + std::size_t seed = 0; + hash_combine(seed, static_cast(key.first)); + hash_combine(seed, key.second); + return seed; } }; template <> struct hash { std::size_t operator()(const mscclpp::ExecutionContextKey& key) const { - return std::hash()(key.sendBuff) ^ std::hash()(key.recvBuff) ^ std::hash()(key.sendBuffSize) ^ - std::hash()(key.recvBuffSize) ^ std::hash()(key.plan); + size_t seed = 0; + hash_combine(seed, key.sendBuff); + hash_combine(seed, key.recvBuff); + hash_combine(seed, key.sendBuffSize); + hash_combine(seed, key.recvBuffSize); + hash_combine(seed, key.plan); + return seed; + } +}; + +template <> +struct hash { + std::size_t operator()(const mscclpp::DeviceExecutionPlanKey& key) const { + std::size_t seed = 0; + hash_combine(seed, key.inputMessageSize); + hash_combine(seed, key.outputMessageSize); + hash_combine(seed, key.constSrcOffset); + hash_combine(seed, key.constDstOffset); + return seed; } }; } // namespace std @@ -77,11 +118,12 @@ struct ExecutionContext { std::vector smChannels; std::vector proxyChannels; std::vector nvlsChannels; - std::vector deviceExecutionPlans; + std::unordered_map> deviceExecutionPlans; + std::unordered_map> deviceExecutionPlansBuffers; std::shared_ptr scratchBuffer; size_t scratchBufferSize; - std::shared_ptr deviceExecutionPlansBuffer; int 
nthreadsPerBlock; + DeviceExecutionPlanKey currentDevicePlan; }; struct Executor::Impl { @@ -97,23 +139,32 @@ struct Executor::Impl { ~Impl() = default; ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t inputMessageSize, - size_t outputMessageSize, size_t contsSrcOffset, size_t constDstOffset, + size_t outputMessageSize, size_t constSrcOffset, size_t constDstOffset, size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan) { ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name}; + DeviceExecutionPlanKey devicePlanKey = {inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset}; if (this->contexts.find(key) != this->contexts.end()) { + auto& devicePlans = this->contexts[key].deviceExecutionPlans; + if (this->contexts[key].currentDevicePlan == devicePlanKey) { + return this->contexts[key]; + } else if (devicePlans.find(devicePlanKey) != devicePlans.end()) { + this->contexts[key].currentDevicePlan = devicePlanKey; + return this->contexts[key]; + } plan.impl_->operationsReset(); - plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset); - this->setupDeviceExecutionPlan(this->contexts[key], rank, plan); - this->contexts[key].deviceExecutionPlansBuffer = - allocExtSharedCuda(this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan)); - memcpyCuda(this->contexts[key].deviceExecutionPlansBuffer.get(), - (char*)this->contexts[key].deviceExecutionPlans.data(), - this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice); + plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset); + this->setupDeviceExecutionPlan(this->contexts[key], devicePlanKey, rank, plan); + this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey] = + allocExtSharedCuda(devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan)); + memcpyCuda(this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey].get(), + (char*)devicePlans[devicePlanKey].data(), + devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice); + this->contexts[key].currentDevicePlan = devicePlanKey; return this->contexts[key]; } plan.impl_->reset(); - plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset); + plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset); ExecutionContext context; size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize, recvBufferSize); @@ -131,11 +182,14 @@ struct Executor::Impl { this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); this->setupNvlsChannels(context, sendbuff, recvbuff, rank, plan); - this->setupDeviceExecutionPlan(context, rank, plan); - context.deviceExecutionPlansBuffer = - allocExtSharedCuda(context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan)); - memcpyCuda(context.deviceExecutionPlansBuffer.get(), (char*)context.deviceExecutionPlans.data(), - context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice); + this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan); + context.deviceExecutionPlansBuffers[devicePlanKey] = + allocExtSharedCuda(context.deviceExecutionPlans[devicePlanKey].size() * 
sizeof(DeviceExecutionPlan)); + memcpyCuda(context.deviceExecutionPlansBuffers[devicePlanKey].get(), + (char*)context.deviceExecutionPlans[devicePlanKey].data(), + context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan), + cudaMemcpyHostToDevice); + context.currentDevicePlan = devicePlanKey; context.proxyService->startProxy(); this->contexts.insert({key, context}); return context; @@ -302,13 +356,14 @@ struct Executor::Impl { std::shared_ptr nvlsConnection = context.nvlsConnections[i]; NvlsInfo info = nvlsInfos[i]; void* buffer = getBuffer(info.bufferType, sendbuff, recvbuff, context.scratchBuffer.get()); - std::shared_ptr nvlsPtr = nvlsConnection->bindAllocatedCuda((CUdeviceptr)buffer, info.bufferSize); - NvlsConnection::DeviceMulticastPointer deviceMulticastPointer(buffer, nvlsPtr, info.bufferSize); + NvlsConnection::DeviceMulticastPointer deviceMulticastPointer = + nvlsConnection->bindAllocatedMemory((CUdeviceptr)buffer, info.bufferSize); context.nvlsChannels.push_back(deviceMulticastPointer); } } - void setupDeviceExecutionPlan(ExecutionContext& context, int rank, const ExecutionPlan& plan) { + void setupDeviceExecutionPlan(ExecutionContext& context, const DeviceExecutionPlanKey& key, int rank, + const ExecutionPlan& plan) { std::vector deviceExecutionPlans; for (int threadblock = 0; threadblock < plan.impl_->getThreadblockCount(rank); threadblock++) { DeviceExecutionPlan deviceExecutionPlan = {}; @@ -333,13 +388,14 @@ struct Executor::Impl { } deviceExecutionPlans.push_back(deviceExecutionPlan); } - context.deviceExecutionPlans = std::move(deviceExecutionPlans); + context.deviceExecutionPlans[key] = std::move(deviceExecutionPlans); } void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType, cudaStream_t stream, PacketType packetType) { static uint32_t flag = 0; - int nthreadblocks = context.deviceExecutionPlans.size(); + DeviceExecutionPlanKey key = context.currentDevicePlan; + int nthreadblocks = context.deviceExecutionPlans[key].size(); #if defined(ENABLE_NPKIT) #if defined(__HIP_PLATFORM_AMD__) if (nthreadblocks > NPKIT_MAX_NUM_GPU_THREADBLOCKS) { @@ -357,13 +413,13 @@ struct Executor::Impl { case PacketType::LL16: ExecutionKernel::launchKernel( rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), - context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), + context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(), sharedMemSize, stream, ++flag); break; case PacketType::LL8: ExecutionKernel::launchKernel( rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), - context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), + context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(), sharedMemSize, stream, ++flag); break; default: diff --git a/src/include/context.hpp b/src/include/context.hpp index abb95b27d..e88c7e5fa 100644 --- a/src/include/context.hpp +++ b/src/include/context.hpp @@ -16,7 +16,7 @@ namespace mscclpp { struct Context::Impl { std::vector> connections_; std::unordered_map> ibContexts_; - CudaStreamWithFlags ipcStream_; + std::shared_ptr ipcStream_; CUmemGenericAllocationHandle mcHandle_; Impl(); diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index 2f7727636..f53add5a5 100644 --- 
a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -43,7 +43,6 @@ struct RegisteredMemory::Impl { size_t size; uint64_t hostHash; uint64_t pidHash; - bool isCuMemMapAlloc; TransportFlags transports; std::vector transportInfos; diff --git a/src/nvls.cc b/src/nvls.cc index 166531d3a..dbfae2e77 100644 --- a/src/nvls.cc +++ b/src/nvls.cc @@ -257,16 +257,9 @@ NvlsConnection::NvlsConnection(const std::vector& data) : pimpl_(std::make std::vector NvlsConnection::serialize() { return pimpl_->serialize(); } -std::shared_ptr NvlsConnection::allocateAndBindCuda(size_t size) { - CUmemGenericAllocationHandle handle; - auto ptr = allocSharedPhysicalCuda(size, pimpl_->getMinMcGran()); - MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, ptr.get())); - auto mcPtr = pimpl_->bindMemory(handle, size); - return std::make_shared(ptr, mcPtr, size); -} - -std::shared_ptr NvlsConnection::bindAllocatedCuda(CUdeviceptr devicePtr, size_t size) { - return pimpl_->bindMemory(devicePtr, size); +NvlsConnection::DeviceMulticastPointer NvlsConnection::bindAllocatedMemory(CUdeviceptr devicePtr, size_t size) { + auto mcPtr = pimpl_->bindMemory(devicePtr, size); + return DeviceMulticastPointer((void*)devicePtr, mcPtr, size); } NvlsConnection::DeviceMulticastPointer::DeviceHandle NvlsConnection::DeviceMulticastPointer::deviceHandle() { diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 2396cd94d..8e6157797 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -12,18 +12,7 @@ #include "utils_internal.hpp" namespace { -// Check if ptr is allocaed by cuMemMap -bool isCuMemMapAllocated(void* ptr) { - CUmemGenericAllocationHandle handle; - CUresult result = cuMemRetainAllocationHandle(&handle, ptr); - if (result != CUDA_SUCCESS) { - return false; - } - MSCCLPP_CUTHROW(cuMemRelease(handle)); - return true; -} - -#if (CUDA_FABRIC_SUPPORTED) +#if (CUDA_NVLS_SUPPORTED) // Get the recommended granularity for cuMemAddressReserve size_t getRecommendedGranularity() { size_t gran = 0; @@ -52,7 +41,6 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, TransportFlags transports, size(size), hostHash(getHostHash()), pidHash(getPidHash()), - isCuMemMapAlloc(false), transports(transports) { if (transports.has(Transport::CudaIpc)) { TransportInfo transportInfo; @@ -61,23 +49,12 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, TransportFlags transports, void* baseDataPtr; size_t baseDataSize; // dummy MSCCLPP_CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data)); - if (isCuMemMapAllocated(data)) { - this->isCuMemMapAlloc = true; - } - if (this->isCuMemMapAlloc) { -#if (CUDA_FABRIC_SUPPORTED) - if (isFabricSupported()) { - CUmemGenericAllocationHandle handle; - MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, baseDataPtr)); - MSCCLPP_CUTHROW( - cuMemExportToShareableHandle(transportInfo.shareableHandle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0)); - transportInfo.offsetFromBase = (char*)data - (char*)baseDataPtr; - } else { - throw Error("Fabric is not supported", ErrorCode::InvalidUsage); - } -#else - throw Error("Only support cuMemMap with CUDA 12.4 or later", ErrorCode::InvalidUsage); -#endif + if (isNvlsSupported()) { + CUmemGenericAllocationHandle handle; + MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, baseDataPtr)); + MSCCLPP_CUTHROW( + cuMemExportToShareableHandle(transportInfo.shareableHandle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0)); + transportInfo.offsetFromBase = (char*)data - (char*)baseDataPtr; } else { 
cudaIpcMemHandle_t handle; MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr)); @@ -128,8 +105,6 @@ MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() { std::copy_n(reinterpret_cast(&pimpl_->size), sizeof(pimpl_->size), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl_->hostHash), sizeof(pimpl_->hostHash), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl_->pidHash), sizeof(pimpl_->pidHash), std::back_inserter(result)); - std::copy_n(reinterpret_cast(&pimpl_->isCuMemMapAlloc), sizeof(pimpl_->isCuMemMapAlloc), - std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl_->transports), sizeof(pimpl_->transports), std::back_inserter(result)); if (pimpl_->transportInfos.size() > static_cast(std::numeric_limits::max())) { throw mscclpp::Error("Too many transport info entries", ErrorCode::InternalError); @@ -139,7 +114,7 @@ MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() { for (auto& entry : pimpl_->transportInfos) { std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); if (entry.transport == Transport::CudaIpc) { - if (pimpl_->isCuMemMapAlloc) { + if (isNvlsSupported()) { std::copy_n(reinterpret_cast(&entry.shareableHandle), sizeof(entry.shareableHandle), std::back_inserter(result)); std::copy_n(reinterpret_cast(&entry.offsetFromBase), sizeof(entry.offsetFromBase), @@ -173,8 +148,6 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { it += sizeof(this->hostHash); std::copy_n(it, sizeof(this->pidHash), reinterpret_cast(&this->pidHash)); it += sizeof(this->pidHash); - std::copy_n(it, sizeof(this->isCuMemMapAlloc), reinterpret_cast(&this->isCuMemMapAlloc)); - it += sizeof(this->isCuMemMapAlloc); std::copy_n(it, sizeof(this->transports), reinterpret_cast(&this->transports)); it += sizeof(this->transports); int8_t transportCount; @@ -185,7 +158,7 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast(&transportInfo.transport)); it += sizeof(transportInfo.transport); if (transportInfo.transport == Transport::CudaIpc) { - if (this->isCuMemMapAlloc) { + if (isNvlsSupported()) { std::copy_n(it, sizeof(transportInfo.shareableHandle), reinterpret_cast(&transportInfo.shareableHandle)); it += sizeof(transportInfo.shareableHandle); std::copy_n(it, sizeof(transportInfo.offsetFromBase), reinterpret_cast(&transportInfo.offsetFromBase)); @@ -219,22 +192,14 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { // The memory is local to the machine but not to the process, so we need to open the CUDA IPC handle auto entry = getTransportInfo(Transport::CudaIpc); void* base; - if (this->isCuMemMapAlloc) { -#if (CUDA_FABRIC_SUPPORTED) - if (isFabricSupported()) { - CUmemGenericAllocationHandle handle; - MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, CU_MEM_HANDLE_TYPE_FABRIC)); - size_t gran = getRecommendedGranularity(); - MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&base, this->size, gran, 0, 0)); - MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)base, this->size, 0, handle, 0)); - setReadWriteMemoryAccess(base, this->size); - this->data = static_cast(base) + entry.offsetFromBase; - } else { - throw Error("Fabric is not supported", ErrorCode::InvalidUsage); - } -#else - throw Error("Only support cuMemMap with CUDA 12.4 or later", ErrorCode::InvalidUsage); -#endif + if (isNvlsSupported()) { + CUmemGenericAllocationHandle handle; + 
MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, CU_MEM_HANDLE_TYPE_FABRIC)); + size_t gran = getRecommendedGranularity(); + MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&base, this->size, gran, 0, 0)); + MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)base, this->size, 0, handle, 0)); + setReadWriteMemoryAccess(base, this->size); + this->data = static_cast(base) + entry.offsetFromBase; } else { MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess)); this->data = static_cast(base) + entry.cudaIpcOffsetFromBase; @@ -250,19 +215,15 @@ RegisteredMemory::Impl::~Impl() { // Close the CUDA IPC handle if it was opened during deserialization if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash && getPidHash() != this->pidHash) { void* base = static_cast(data) - getTransportInfo(Transport::CudaIpc).cudaIpcOffsetFromBase; - if (this->isCuMemMapAlloc) { - if (isFabricSupported()) { - CUmemGenericAllocationHandle handle; - size_t size = 0; - MSCCLPP_CULOG_WARN(cuMemRetainAllocationHandle(&handle, base)); - MSCCLPP_CULOG_WARN(cuMemRelease(handle)); - MSCCLPP_CULOG_WARN(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)base)); - MSCCLPP_CULOG_WARN(cuMemUnmap((CUdeviceptr)base, size)); - MSCCLPP_CULOG_WARN(cuMemRelease(handle)); - MSCCLPP_CULOG_WARN(cuMemAddressFree((CUdeviceptr)base, size)); - } else { - WARN("Fabric is not supported, skipping cleanup"); - } + if (isNvlsSupported()) { + CUmemGenericAllocationHandle handle; + size_t size = 0; + MSCCLPP_CULOG_WARN(cuMemRetainAllocationHandle(&handle, base)); + MSCCLPP_CULOG_WARN(cuMemRelease(handle)); + MSCCLPP_CULOG_WARN(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)base)); + MSCCLPP_CULOG_WARN(cuMemUnmap((CUdeviceptr)base, size)); + MSCCLPP_CULOG_WARN(cuMemRelease(handle)); + MSCCLPP_CULOG_WARN(cuMemAddressFree((CUdeviceptr)base, size)); } else { cudaError_t err = cudaIpcCloseMemHandle(base); if (err != cudaSuccess) { diff --git a/src/utils.cc b/src/utils.cc index a6a007c84..fb470a4ab 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -70,7 +70,7 @@ std::string getHostName(int maxlen, const char delim) { bool isNvlsSupported() { [[maybe_unused]] static bool result = false; [[maybe_unused]] static bool isChecked = false; -#if (CUDA_FABRIC_SUPPORTED) +#if (CUDA_NVLS_SUPPORTED) if (!isChecked) { int isMulticastSupported; int isFabricSupported; @@ -85,20 +85,4 @@ bool isNvlsSupported() { return false; } -bool isFabricSupported() { - [[maybe_unused]] static bool result = false; - [[maybe_unused]] static bool isChecked = false; -#if (CUDA_FABRIC_SUPPORTED) - if (!isChecked) { - int isFabricSupported; - CUdevice dev; - MSCCLPP_CUTHROW(cuCtxGetDevice(&dev)); - MSCCLPP_CUTHROW(cuDeviceGetAttribute(&isFabricSupported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev)); - result = (isFabricSupported == 1); - } - return result; -#endif - return false; -} - } // namespace mscclpp diff --git a/test/executor_test.cc b/test/executor_test.cc index 28bf34828..3fc0b1e21 100644 --- a/test/executor_test.cc +++ b/test/executor_test.cc @@ -131,7 +131,7 @@ int main(int argc, char* argv[]) { } mscclpp::ExecutionPlan plan(executionPlanName, executionPlanPath); -#if (CUDA_FABRIC_SUPPORTED) +#if (CUDA_NVLS_SUPPORTED) std::shared_ptr sendbuff = mscclpp::allocSharedPhysicalCuda(bufferSize); #else std::shared_ptr sendbuff = mscclpp::allocExtSharedCuda(bufferSize);
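Migration note on the NVLS allocation change: allocateAndBindCuda / allocate_bind_memory is removed, so callers now allocate physical memory themselves and hand the raw pointer to bindAllocatedMemory (bind_allocated_memory in Python). The sketch below mirrors the updated flow in python/mscclpp_benchmark/mscclpp_op.py; it assumes an MPI launch with one GPU per rank, and every name that is not an mscclpp call shown in this diff (buffer_size, all_ranks, and so on) is illustrative only.

    import cupy as cp
    from mpi4py import MPI
    from mscclpp import Transport, alloc_shared_physical_cuda
    import mscclpp.comm as mscclpp_comm

    group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
    cp.cuda.Device(group.my_rank % group.nranks_per_node).use()
    all_ranks = list(range(group.nranks))
    nvls_connection = group.make_connection(all_ranks, Transport.Nvls)

    buffer_size = 1 << 20
    # Round the requested size up to the multicast granularity before allocating.
    min_gran = nvls_connection.get_multicast_min_granularity()
    aligned_size = ((buffer_size + min_gran - 1) // min_gran) * min_gran
    # alloc_shared_physical_cuda now takes a byte count; the dtype argument is gone.
    buffer_raw = alloc_shared_physical_cuda(aligned_size)
    # Bind the caller-owned allocation to the multicast group.
    mem_handle = nvls_connection.bind_allocated_memory(buffer_raw.get_ptr(), aligned_size)
    device_ptr = mem_handle.get_device_ptr()

Because DeviceMulticastPointer no longer holds a shared_ptr to the unicast allocation, the caller must keep buffer_raw alive for as long as the binding is in use; the benchmark above stores it as self.buffer_raw for exactly that reason.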