Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Nov 1, 2024
1 parent 66d624f commit 49e66f5
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 36 deletions.
6 changes: 1 addition & 5 deletions include/mscclpp/nvls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@ class NvlsConnection {

std::shared_ptr<DeviceMulticastPointer> allocateAndBindCuda(size_t size);

/// Binds an existing allocation: \p memHandle is the handle to the allocation
/// (its lifetime is managed by the caller) and \p size is the size of the allocation.
std::shared_ptr<char> bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size);

std::shared_ptr<char> bindAllocatedCudaWithPtr(CUdeviceptr devicePtr, size_t size);
std::shared_ptr<char> bindAllocatedCuda(CUdeviceptr devicePtr, size_t size);

size_t getMultiCastMinGranularity();

Expand Down
2 changes: 1 addition & 1 deletion src/executor/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ struct Executor::Impl {
for (const NvlsInfo& info : nvlsInfos) {
std::shared_ptr<NvlsConnection> nvlsConnection = context.nvlsConnections[info];
void* buffer = getBuffer(info.bufferType, sendbuff, recvbuff, context.scratchBuffer.get());
std::shared_ptr<char> nvlsPtr = nvlsConnection->bindAllocatedCudaWithPtr((CUdeviceptr)buffer, info.bufferSize);
std::shared_ptr<char> nvlsPtr = nvlsConnection->bindAllocatedCuda((CUdeviceptr)buffer, info.bufferSize);
NvlsConnection::DeviceMulticastPointer deviceMulticastPointer(buffer, nvlsPtr, info.bufferSize);
context.nvlsChannels.push_back(deviceMulticastPointer);
}
Expand Down
43 changes: 13 additions & 30 deletions src/nvls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ class NvlsConnection::Impl : public std::enable_shared_from_this<NvlsConnection:
void addDevice(int cudaDeviceId);
size_t allocateBuffer(size_t size);
void freeBuffer(size_t offset, size_t size) noexcept;
std::shared_ptr<char> bindMemory(CUmemGenericAllocationHandle memHandle, size_t devBuffSize);
std::shared_ptr<char> bindMemoryWithPtr(CUdeviceptr devicePtr, size_t devBuffSize);
std::shared_ptr<char> bindMemoryToMulticastHandle(size_t offset, size_t bufferSize);
std::shared_ptr<char> bindMemory(CUdeviceptr devicePtr, size_t devBuffSize);

private:
friend class NvlsConnection;
Expand Down Expand Up @@ -194,36 +192,26 @@ void NvlsConnection::Impl::freeBuffer(size_t offset, size_t size) noexcept {
}
}

std::shared_ptr<char> NvlsConnection::Impl::bindMemory(CUmemGenericAllocationHandle memHandle, size_t devBuffSize) {
devBuffSize = ((devBuffSize + minMcGran_ - 1) / minMcGran_) * minMcGran_;
size_t offset = allocateBuffer(devBuffSize);
MSCCLPP_CUTHROW(cuMulticastBindMem(mcHandle_, offset /*mcOffset*/, memHandle, 0 /*memOffset*/, devBuffSize, 0));
return bindMemoryToMulticastHandle(offset, devBuffSize);
}

std::shared_ptr<char> NvlsConnection::Impl::bindMemoryWithPtr(CUdeviceptr devicePtr, size_t devBuffSize) {
std::shared_ptr<char> NvlsConnection::Impl::bindMemory(CUdeviceptr devicePtr, size_t devBuffSize) {
devBuffSize = ((devBuffSize + minMcGran_ - 1) / minMcGran_) * minMcGran_;
size_t offset = allocateBuffer(devBuffSize);
MSCCLPP_CUTHROW(cuMulticastBindAddr(mcHandle_, offset /*mcOffset*/, devicePtr, devBuffSize, 0));
return bindMemoryToMulticastHandle(offset, devBuffSize);
}

std::shared_ptr<char> NvlsConnection::Impl::bindMemoryToMulticastHandle(size_t offset, size_t bufferSize) {
char* mcPtr;
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)(&mcPtr), bufferSize, minMcGran_, 0U, 0));
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)(mcPtr), bufferSize, 0, mcHandle_, 0));
setReadWriteMemoryAccess(mcPtr, bufferSize);
INFO(MSCCLPP_COLL, "NVLS connection bound memory at offset %ld, size %ld", offset, bufferSize);
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)(&mcPtr), devBuffSize, minMcGran_, 0U, 0));
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)(mcPtr), devBuffSize, 0, mcHandle_, 0));
setReadWriteMemoryAccess(mcPtr, devBuffSize);
INFO(MSCCLPP_COLL, "NVLS connection bound memory at offset %ld, size %ld", offset, devBuffSize);

auto deleter = [=, self = shared_from_this()](char* ptr) {
int deviceId;
CUdevice device;
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId));
MSCCLPP_CUTHROW(cuDeviceGet(&device, deviceId));
MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, bufferSize));
MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, bufferSize));
MSCCLPP_CUTHROW(cuMulticastUnbind(mcHandle_, device, offset, bufferSize));
self->freeBuffer(offset, bufferSize);
MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, devBuffSize));
MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, devBuffSize));
MSCCLPP_CUTHROW(cuMulticastUnbind(mcHandle_, device, offset, devBuffSize));
self->freeBuffer(offset, devBuffSize);
};

return std::shared_ptr<char>(mcPtr, deleter);
Expand All @@ -242,8 +230,7 @@ class NvlsConnection::Impl {
std::vector<char> serialize() { throw notSupportedError; }
size_t allocateBuffer(size_t) { throw notSupportedError; }
void freeBuffer(size_t, size_t) { throw notSupportedError; }
std::shared_ptr<char> bindMemory(CUmemGenericAllocationHandle, size_t) { throw notSupportedError; }
std::shared_ptr<char> bindMemoryWithPtr(CUdeviceptr, size_t) { throw notSupportedError; }
std::shared_ptr<char> bindMemory(CUdeviceptr, size_t) { throw notSupportedError; }
void addDevice(int) { throw notSupportedError; }
size_t getMinMcGran() { throw notSupportedError; }

Expand Down Expand Up @@ -278,12 +265,8 @@ std::shared_ptr<NvlsConnection::DeviceMulticastPointer> NvlsConnection::allocate
return std::make_shared<DeviceMulticastPointer>(ptr, mcPtr, size);
}

/// Public entry point that forwards to the implementation's bindMemory() for an
/// allocation identified by a generic allocation handle.
/// \param memHandle Handle of the allocation to bind; passed through unchanged.
/// \param size Requested size of the binding; passed through unchanged.
/// \return The shared pointer produced by pimpl_->bindMemory().
std::shared_ptr<char> NvlsConnection::bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size) {
return pimpl_->bindMemory(memHandle, size);
}

std::shared_ptr<char> NvlsConnection::bindAllocatedCudaWithPtr(CUdeviceptr devicePtr, size_t size) {
return pimpl_->bindMemoryWithPtr(devicePtr, size);
std::shared_ptr<char> NvlsConnection::bindAllocatedCuda(CUdeviceptr devicePtr, size_t size) {
return pimpl_->bindMemory(devicePtr, size);
}

NvlsConnection::DeviceMulticastPointer::DeviceHandle NvlsConnection::DeviceMulticastPointer::deviceHandle() {
Expand Down

0 comments on commit 49e66f5

Please sign in to comment.