Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NVLS support for NCCL API #410

Merged
merged 16 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions apps/nccl/include/nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ typedef struct ncclConfig_v21700 {
NCCL_CONFIG_UNDEF_INT /* splitShare */ \
}

/* NCCL malloc and free functions for all types of NCCL optimizations
* (e.g. user buffer registration). The actual allocated size might
* be larger than requested due to granularity requirements. */
ncclResult_t ncclMemAlloc(void** ptr, size_t size);
ncclResult_t pncclMemAlloc(void** ptr, size_t size);

ncclResult_t ncclMemFree(void* ptr);
chhwang marked this conversation as resolved.
Show resolved Hide resolved
ncclResult_t pncclMemFree(void* ptr);

/* Return the NCCL_VERSION_CODE of the NCCL library in the supplied integer.
* This integer is coded with the MAJOR, MINOR and PATCH level of the
* NCCL library
Expand Down
62 changes: 61 additions & 1 deletion apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <mscclpp/executor.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>
#include <mscclpp/utils.hpp>
#include <sstream>
#include <unordered_map>
#include <vector>
Expand All @@ -33,6 +34,9 @@
// mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
// mscclpp::Transport::IB6, mscclpp::Transport::IB7};

// Global map from each raw pointer returned by ncclMemAlloc to the shared pointer that owns the
// allocation; ncclMemFree erases the entry to drop that owning reference.
static std::unordered_map<void*, std::shared_ptr<char>> ptrMap;

struct channelKey {
const void* buff;
size_t bytes;
Expand Down Expand Up @@ -113,7 +117,7 @@ static size_t ncclTypeSize(ncclDataType_t type) {
return 0;
}

double parseSize(const char* value) {
static double parseSize(const char* value) {
std::string valueStr(value);
std::istringstream iss(valueStr);
long long int units;
Expand Down Expand Up @@ -644,3 +648,59 @@ NCCL_API ncclResult_t ncclGroupEnd() {
// Do nothing
return ncclSuccess;
}

// Placeholder for buffer registration: nothing is registered yet and the first three arguments
// are ignored.
//
// handle: optional output location. When non-null it is set to nullptr so that a caller who
//         receives ncclSuccess never reads an indeterminate handle value.
// Returns ncclSuccess unconditionally.
NCCL_API ncclResult_t ncclCommRegister(const ncclComm_t, void*, size_t, void** handle) {
  // TODO: Implementation
  if (handle != nullptr) {
    *handle = nullptr;
  }
  return ncclSuccess;
}

// Placeholder counterpart of ncclCommRegister: both arguments are ignored and nothing is
// released. Returns ncclSuccess unconditionally.
NCCL_API ncclResult_t ncclCommDeregister(const ncclComm_t, void*) {
  // TODO: Implementation
  return ncclSuccess;
}

// Allocates `size` bytes and hands the raw pointer back through `ptr`.
//
// The buffer comes from mscclpp::allocSharedPhysicalCuda when mscclpp::isNvlsSupported() is true
// and from mscclpp::allocExtSharedCuda otherwise. The owning shared pointer is stored in ptrMap,
// keyed by the raw pointer, so that ncclMemFree can drop it later.
//
// ptr:  output location for the allocated pointer; must not be null. Written only on success.
// size: number of bytes requested; must be non-zero.
//
// Returns:
//   ncclSuccess             - allocation succeeded and *ptr was written.
//   ncclInvalidArgument     - ptr is null or size is zero.
//   ncclSystemError         - the allocator returned a null pointer.
//   ncclInvalidUsage        - mscclpp::Error with ErrorCode::InvalidUsage was thrown.
//   ncclUnhandledCudaError  - mscclpp::CudaError or mscclpp::CuError was thrown.
//   ncclInternalError       - any other exception was thrown.
//
// NOTE(review): ptrMap is read and written without synchronization - confirm that callers
// serialize ncclMemAlloc / ncclMemFree.
ncclResult_t ncclMemAlloc(void** ptr, size_t size) {
  if (ptr == nullptr || size == 0) {
    return ncclInvalidArgument;
  }
  try {
    std::shared_ptr<char> sharedPtr;
    if (mscclpp::isNvlsSupported()) {
      sharedPtr = mscclpp::allocSharedPhysicalCuda<char>(size);
    } else {
      sharedPtr = mscclpp::allocExtSharedCuda<char>(size);
    }
    if (sharedPtr == nullptr) {
      return ncclSystemError;
    }
    // Inserting into the map can itself throw (node allocation), so the bookkeeping stays inside
    // the guarded region; *ptr is written only after the owner has been recorded.
    ptrMap[sharedPtr.get()] = sharedPtr;
    *ptr = sharedPtr.get();
    return ncclSuccess;
  } catch (const mscclpp::Error& e) {
    if (e.getErrorCode() == mscclpp::ErrorCode::InvalidUsage) {
      return ncclInvalidUsage;
    }
    return ncclInternalError;
  } catch (const mscclpp::CudaError&) {
    return ncclUnhandledCudaError;
  } catch (const mscclpp::CuError&) {
    return ncclUnhandledCudaError;
  } catch (const mscclpp::BaseError&) {
    return ncclInternalError;
  } catch (...) {
    // This API reports failures through its return code, so no exception is allowed to propagate
    // to the caller; anything unrecognized is reported as an internal error.
    return ncclInternalError;
  }
}

// Drops the ptrMap entry registered for `ptr`, releasing the shared pointer stored there.
//
// Returns ncclSuccess when an entry for `ptr` existed and was erased, and ncclInvalidUsage when
// `ptr` is not a key of ptrMap.
ncclResult_t ncclMemFree(void* ptr) {
  // erase-by-key reports how many entries were removed (0 or 1 for a unique-key map).
  const bool wasTracked = ptrMap.erase(ptr) != 0;
  return wasTracked ? ncclSuccess : ncclInvalidUsage;
}
14 changes: 12 additions & 2 deletions src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,17 @@ std::vector<ChannelInfo> ExecutionPlan::Impl::getUnpairedChannelInfos(int rank,
return unpaired;
}

std::vector<NvlsInfo> ExecutionPlan::Impl::getNvlsInfos(int rank) const { return this->nvlsInfos.at(rank); }
std::vector<NvlsInfo> ExecutionPlan::Impl::getNvlsInfos(int rank, size_t sendBuffserSize, size_t recvBufferSize) const {
if (sendBuffserSize == 0 && recvBufferSize == 0) {
return this->nvlsInfos.at(rank);
}
size_t chunkSize = this->getUpperBoundChunkSize(rank, sendBuffserSize, recvBufferSize);
std::vector<NvlsInfo> infos = this->nvlsInfos.at(rank);
for (auto& info : infos) {
info.bufferSize = info.bufferSize * chunkSize;
}
return infos;
}

std::vector<int> ExecutionPlan::Impl::getConnectedPeers(int rank) const {
std::set<int> peers;
Expand Down Expand Up @@ -272,7 +282,7 @@ void ExecutionPlan::Impl::parseChannels(
NvlsInfo info;
info.bufferType = convertToBufferType(channel["buff"]);
for (const auto& group : channel["rankGroups"]) {
info.bufferSize = (int)group["size"] * this->getUpperBoundChunkSize(rank, this->inputSize, this->outputSize);
info.bufferSize = (int)group["size"];
info.ranks.clear();
for (int rank : group["ranks"]) {
info.ranks.push_back(rank);
Expand Down
15 changes: 8 additions & 7 deletions src/executor/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,10 @@ struct Executor::Impl {
context.scratchBufferSize = scratchBufferSize;
context.proxyService = std::make_shared<ProxyService>();
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
this->setupConnections(context, rank, plan);
this->setupConnections(context, rank, plan, sendMemRange, recvMemRange);
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupNvlsChannels(context, sendbuff, recvbuff, rank, plan);
this->setupNvlsChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan);
context.deviceExecutionPlansBuffers[devicePlanKey] =
allocExtSharedCuda<char>(context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan));
Expand Down Expand Up @@ -214,7 +214,8 @@ struct Executor::Impl {
return flags;
};

void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan) {
void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan, size_t sendBufferSize,
size_t recvBufferSize) {
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers(rank);
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
for (int peer : connectedPeers) {
Expand All @@ -227,7 +228,7 @@ struct Executor::Impl {
context.connections[connectedPeers[i]] = connectionFutures[i].get();
}

std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank);
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank, sendBufferSize, recvBufferSize);
for (const NvlsInfo& info : nvlsInfos) {
std::shared_ptr<NvlsConnection> nvlsConnection =
mscclpp::connectNvlsCollective(this->comm, info.ranks, info.bufferSize);
Expand Down Expand Up @@ -351,9 +352,9 @@ struct Executor::Impl {
}
}

void setupNvlsChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, int rank,
const ExecutionPlan& plan) {
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank);
void setupNvlsChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize,
size_t recvBufferSize, int rank, const ExecutionPlan& plan) {
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank, sendBufferSize, recvBufferSize);
for (size_t i = 0; i < nvlsInfos.size(); i++) {
std::shared_ptr<NvlsConnection> nvlsConnection = context.nvlsConnections[i];
NvlsInfo info = nvlsInfos[i];
Expand Down
2 changes: 1 addition & 1 deletion src/include/execution_plan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct ExecutionPlan::Impl {
std::vector<ChannelInfo> getChannelInfos(int rank, BufferType bufferType) const;
std::vector<ChannelInfo> getChannelInfosByDstRank(int rank, BufferType bufferType) const;
std::vector<ChannelInfo> getUnpairedChannelInfos(int rank, int worldSize, ChannelType channelType);
std::vector<NvlsInfo> getNvlsInfos(int rank) const;
std::vector<NvlsInfo> getNvlsInfos(int rank, size_t sendBuffserSize = 0, size_t recvBufferSize = 0) const;
std::vector<int> getConnectedPeers(int rank) const;
std::vector<BufferType> getConnectedBufferTypes(int rank) const;
size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const;
Expand Down
Loading