Skip to content

Commit

Permalink
NVLS support for NCCL API (#410)
Browse files Browse the repository at this point in the history
Co-authored-by: Qinghua Zhou <[email protected]>
Co-authored-by: Changho Hwang <[email protected]>
  • Loading branch information
3 people authored Dec 18, 2024
1 parent 863a599 commit fcb2e46
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 11 deletions.
9 changes: 9 additions & 0 deletions apps/nccl/include/nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ typedef struct ncclConfig_v21700 {
NCCL_CONFIG_UNDEF_INT /* splitShare */ \
}

/* NCCL malloc and free function for all types of NCCL optimizations
* (e.g. user buffer registration). The actual allocated size might
* be larger than requested due to granularity requirement. */
ncclResult_t ncclMemAlloc(void** ptr, size_t size);
ncclResult_t pncclMemAlloc(void** ptr, size_t size);

ncclResult_t ncclMemFree(void* ptr);
ncclResult_t pncclMemFree(void* ptr);

/* Return the NCCL_VERSION_CODE of the NCCL library in the supplied integer.
* This integer is coded with the MAJOR, MINOR and PATCH level of the
* NCCL library
Expand Down
62 changes: 61 additions & 1 deletion apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <mscclpp/executor.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>
#include <mscclpp/utils.hpp>
#include <sstream>
#include <unordered_map>
#include <vector>
Expand All @@ -33,6 +34,9 @@
// mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
// mscclpp::Transport::IB6, mscclpp::Transport::IB7};

// Declare the global map to store associations between raw pointer and shared pointer
static std::unordered_map<void*, std::shared_ptr<char>> ptrMap;

struct channelKey {
const void* buff;
size_t bytes;
Expand Down Expand Up @@ -113,7 +117,7 @@ static size_t ncclTypeSize(ncclDataType_t type) {
return 0;
}

double parseSize(const char* value) {
static double parseSize(const char* value) {
std::string valueStr(value);
std::istringstream iss(valueStr);
long long int units;
Expand Down Expand Up @@ -644,3 +648,59 @@ NCCL_API ncclResult_t ncclGroupEnd() {
// Do nothing
return ncclSuccess;
}

NCCL_API ncclResult_t ncclCommRegister(const ncclComm_t, void*, size_t, void**) {
// TODO: Implementation
return ncclSuccess;
}

NCCL_API ncclResult_t ncclCommDeregister(const ncclComm_t, void*) {
// TODO: Implementation
return ncclSuccess;
}

ncclResult_t ncclMemAlloc(void** ptr, size_t size) {
// Allocate memory using mscclpp::allocSharedPhysicalCuda
if (ptr == nullptr || size == 0) {
return ncclInvalidArgument;
}
std::shared_ptr<char> sharedPtr;
try {
if (mscclpp::isNvlsSupported()) {
sharedPtr = mscclpp::allocSharedPhysicalCuda<char>(size);
} else {
sharedPtr = mscclpp::allocExtSharedCuda<char>(size);
}
if (sharedPtr == nullptr) {
return ncclSystemError;
}
} catch (const mscclpp::Error& e) {
if (e.getErrorCode() == mscclpp::ErrorCode::InvalidUsage) {
return ncclInvalidUsage;
} else {
return ncclInternalError;
}
} catch (const mscclpp::CudaError& e) {
return ncclUnhandledCudaError;
} catch (const mscclpp::CuError& e) {
return ncclUnhandledCudaError;
} catch (const mscclpp::BaseError& e) {
return ncclInternalError;
}
ptrMap[sharedPtr.get()] = sharedPtr;

// Return the pointer
*ptr = sharedPtr.get();
return ncclSuccess;
}

ncclResult_t ncclMemFree(void* ptr) {
auto ptrIt = ptrMap.find(ptr);
if (ptrIt != ptrMap.end()) {
ptrMap.erase(ptrIt);
return ncclSuccess;
}

// Pointer not found
return ncclInvalidUsage;
}
14 changes: 12 additions & 2 deletions src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,17 @@ std::vector<ChannelInfo> ExecutionPlan::Impl::getUnpairedChannelInfos(int rank,
return unpaired;
}

std::vector<NvlsInfo> ExecutionPlan::Impl::getNvlsInfos(int rank) const { return this->nvlsInfos.at(rank); }
std::vector<NvlsInfo> ExecutionPlan::Impl::getNvlsInfos(int rank, size_t sendBuffserSize, size_t recvBufferSize) const {
if (sendBuffserSize == 0 && recvBufferSize == 0) {
return this->nvlsInfos.at(rank);
}
size_t chunkSize = this->getUpperBoundChunkSize(rank, sendBuffserSize, recvBufferSize);
std::vector<NvlsInfo> infos = this->nvlsInfos.at(rank);
for (auto& info : infos) {
info.bufferSize = info.bufferSize * chunkSize;
}
return infos;
}

std::vector<int> ExecutionPlan::Impl::getConnectedPeers(int rank) const {
std::set<int> peers;
Expand Down Expand Up @@ -272,7 +282,7 @@ void ExecutionPlan::Impl::parseChannels(
NvlsInfo info;
info.bufferType = convertToBufferType(channel["buff"]);
for (const auto& group : channel["rankGroups"]) {
info.bufferSize = (int)group["size"] * this->getUpperBoundChunkSize(rank, this->inputSize, this->outputSize);
info.bufferSize = (int)group["size"];
info.ranks.clear();
for (int rank : group["ranks"]) {
info.ranks.push_back(rank);
Expand Down
15 changes: 8 additions & 7 deletions src/executor/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,10 @@ struct Executor::Impl {
context.scratchBufferSize = scratchBufferSize;
context.proxyService = std::make_shared<ProxyService>();
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
this->setupConnections(context, rank, plan);
this->setupConnections(context, rank, plan, sendMemRange, recvMemRange);
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupNvlsChannels(context, sendbuff, recvbuff, rank, plan);
this->setupNvlsChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan);
context.deviceExecutionPlansBuffers[devicePlanKey] =
allocExtSharedCuda<char>(context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan));
Expand Down Expand Up @@ -214,7 +214,8 @@ struct Executor::Impl {
return flags;
};

void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan) {
void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan, size_t sendBufferSize,
size_t recvBufferSize) {
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers(rank);
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
for (int peer : connectedPeers) {
Expand All @@ -227,7 +228,7 @@ struct Executor::Impl {
context.connections[connectedPeers[i]] = connectionFutures[i].get();
}

std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank);
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank, sendBufferSize, recvBufferSize);
for (const NvlsInfo& info : nvlsInfos) {
std::shared_ptr<NvlsConnection> nvlsConnection =
mscclpp::connectNvlsCollective(this->comm, info.ranks, info.bufferSize);
Expand Down Expand Up @@ -351,9 +352,9 @@ struct Executor::Impl {
}
}

void setupNvlsChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, int rank,
const ExecutionPlan& plan) {
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank);
void setupNvlsChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize,
size_t recvBufferSize, int rank, const ExecutionPlan& plan) {
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank, sendBufferSize, recvBufferSize);
for (size_t i = 0; i < nvlsInfos.size(); i++) {
std::shared_ptr<NvlsConnection> nvlsConnection = context.nvlsConnections[i];
NvlsInfo info = nvlsInfos[i];
Expand Down
2 changes: 1 addition & 1 deletion src/include/execution_plan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct ExecutionPlan::Impl {
std::vector<ChannelInfo> getChannelInfos(int rank, BufferType bufferType) const;
std::vector<ChannelInfo> getChannelInfosByDstRank(int rank, BufferType bufferType) const;
std::vector<ChannelInfo> getUnpairedChannelInfos(int rank, int worldSize, ChannelType channelType);
std::vector<NvlsInfo> getNvlsInfos(int rank) const;
std::vector<NvlsInfo> getNvlsInfos(int rank, size_t sendBuffserSize = 0, size_t recvBufferSize = 0) const;
std::vector<int> getConnectedPeers(int rank) const;
std::vector<BufferType> getConnectedBufferTypes(int rank) const;
size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const;
Expand Down

0 comments on commit fcb2e46

Please sign in to comment.