From 49e66f548281f57206124702830f215e308daa69 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 1 Nov 2024 09:21:46 +0000 Subject: [PATCH] WIP --- include/mscclpp/nvls.hpp | 6 +----- src/executor/executor.cc | 2 +- src/nvls.cc | 43 ++++++++++++---------------------------- 3 files changed, 15 insertions(+), 36 deletions(-) diff --git a/include/mscclpp/nvls.hpp b/include/mscclpp/nvls.hpp index 8124fa3ed..c6d96d2a4 100644 --- a/include/mscclpp/nvls.hpp +++ b/include/mscclpp/nvls.hpp @@ -44,11 +44,7 @@ class NvlsConnection { std::shared_ptr allocateAndBindCuda(size_t size); - /// The \p handle to the allocation (its lifetime is managed by the caller) - /// and the \p size of the allocation. - std::shared_ptr bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size); - - std::shared_ptr bindAllocatedCudaWithPtr(CUdeviceptr devicePtr, size_t size); + std::shared_ptr bindAllocatedCuda(CUdeviceptr devicePtr, size_t size); size_t getMultiCastMinGranularity(); diff --git a/src/executor/executor.cc b/src/executor/executor.cc index 84eb8eb73..09c96d2d3 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -301,7 +301,7 @@ struct Executor::Impl { for (const NvlsInfo& info : nvlsInfos) { std::shared_ptr nvlsConnection = context.nvlsConnections[info]; void* buffer = getBuffer(info.bufferType, sendbuff, recvbuff, context.scratchBuffer.get()); - std::shared_ptr nvlsPtr = nvlsConnection->bindAllocatedCudaWithPtr((CUdeviceptr)buffer, info.bufferSize); + std::shared_ptr nvlsPtr = nvlsConnection->bindAllocatedCuda((CUdeviceptr)buffer, info.bufferSize); NvlsConnection::DeviceMulticastPointer deviceMulticastPointer(buffer, nvlsPtr, info.bufferSize); context.nvlsChannels.push_back(deviceMulticastPointer); } diff --git a/src/nvls.cc b/src/nvls.cc index bbb21a7c3..166531d3a 100644 --- a/src/nvls.cc +++ b/src/nvls.cc @@ -31,9 +31,7 @@ class NvlsConnection::Impl : public std::enable_shared_from_this bindMemory(CUmemGenericAllocationHandle memHandle, size_t devBuffSize); - std::shared_ptr bindMemoryWithPtr(CUdeviceptr devicePtr, size_t devBuffSize); - std::shared_ptr bindMemoryToMulticastHandle(size_t offset, size_t bufferSize); + std::shared_ptr bindMemory(CUdeviceptr devicePtr, size_t devBuffSize); private: friend class NvlsConnection; @@ -194,36 +192,26 @@ void NvlsConnection::Impl::freeBuffer(size_t offset, size_t size) noexcept { } } -std::shared_ptr NvlsConnection::Impl::bindMemory(CUmemGenericAllocationHandle memHandle, size_t devBuffSize) { - devBuffSize = ((devBuffSize + minMcGran_ - 1) / minMcGran_) * minMcGran_; - size_t offset = allocateBuffer(devBuffSize); - MSCCLPP_CUTHROW(cuMulticastBindMem(mcHandle_, offset /*mcOffset*/, memHandle, 0 /*memOffset*/, devBuffSize, 0)); - return bindMemoryToMulticastHandle(offset, devBuffSize); -} - -std::shared_ptr NvlsConnection::Impl::bindMemoryWithPtr(CUdeviceptr devicePtr, size_t devBuffSize) { +std::shared_ptr NvlsConnection::Impl::bindMemory(CUdeviceptr devicePtr, size_t devBuffSize) { devBuffSize = ((devBuffSize + minMcGran_ - 1) / minMcGran_) * minMcGran_; size_t offset = allocateBuffer(devBuffSize); MSCCLPP_CUTHROW(cuMulticastBindAddr(mcHandle_, offset /*mcOffset*/, devicePtr, devBuffSize, 0)); - return bindMemoryToMulticastHandle(offset, devBuffSize); -} -std::shared_ptr NvlsConnection::Impl::bindMemoryToMulticastHandle(size_t offset, size_t bufferSize) { char* mcPtr; - MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)(&mcPtr), bufferSize, minMcGran_, 0U, 0)); - MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)(mcPtr), bufferSize, 0, mcHandle_, 0)); - setReadWriteMemoryAccess(mcPtr, bufferSize); - INFO(MSCCLPP_COLL, "NVLS connection bound memory at offset %ld, size %ld", offset, bufferSize); + MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)(&mcPtr), devBuffSize, minMcGran_, 0U, 0)); + MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)(mcPtr), devBuffSize, 0, mcHandle_, 0)); + setReadWriteMemoryAccess(mcPtr, devBuffSize); + INFO(MSCCLPP_COLL, "NVLS connection bound memory at offset %ld, size %ld", offset, devBuffSize); auto deleter = [=, self = shared_from_this()](char* ptr) { int deviceId; CUdevice device; MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); MSCCLPP_CUTHROW(cuDeviceGet(&device, deviceId)); - MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, bufferSize)); - MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, bufferSize)); - MSCCLPP_CUTHROW(cuMulticastUnbind(mcHandle_, device, offset, bufferSize)); - self->freeBuffer(offset, bufferSize); + MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, devBuffSize)); + MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, devBuffSize)); + MSCCLPP_CUTHROW(cuMulticastUnbind(mcHandle_, device, offset, devBuffSize)); + self->freeBuffer(offset, devBuffSize); }; return std::shared_ptr(mcPtr, deleter); @@ -242,8 +230,7 @@ class NvlsConnection::Impl { std::vector serialize() { throw notSupportedError; } size_t allocateBuffer(size_t) { throw notSupportedError; } void freeBuffer(size_t, size_t) { throw notSupportedError; } - std::shared_ptr bindMemory(CUmemGenericAllocationHandle, size_t) { throw notSupportedError; } - std::shared_ptr bindMemoryWithPtr(CUdeviceptr, size_t) { throw notSupportedError; } + std::shared_ptr bindMemory(CUdeviceptr, size_t) { throw notSupportedError; } void addDevice(int) { throw notSupportedError; } size_t getMinMcGran() { throw notSupportedError; } @@ -278,12 +265,8 @@ std::shared_ptr NvlsConnection::allocate return std::make_shared(ptr, mcPtr, size); } -std::shared_ptr NvlsConnection::bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size) { - return pimpl_->bindMemory(memHandle, size); -} - -std::shared_ptr NvlsConnection::bindAllocatedCudaWithPtr(CUdeviceptr devicePtr, size_t size) { - return pimpl_->bindMemoryWithPtr(devicePtr, size); +std::shared_ptr NvlsConnection::bindAllocatedCuda(CUdeviceptr devicePtr, size_t size) { + return pimpl_->bindMemory(devicePtr, size); } NvlsConnection::DeviceMulticastPointer::DeviceHandle NvlsConnection::DeviceMulticastPointer::deviceHandle() {