Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AllGather Executor Support in NCCL Interface #393

Merged
merged 14 commits into from
Nov 28, 2024
Merged
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 103 additions & 34 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,16 @@ struct ncclComm {
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
std::shared_ptr<mscclpp::Executor> executor;
std::shared_ptr<mscclpp::ExecutionPlan> allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan,
allReduceOPPlan;
allReduceOPPlan, allGatherIPPlan, allGatherOPPlan, allGatherPacketIPPlan, allGatherPacketOPPlan;

std::unordered_map<channelKey, ChannelInfo> channelInInfos;
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;

size_t smallMessageSizeBoundary, largeMessageSizeBoundary;
size_t allReduceSmallMessageSizeBoundary, allReduceLargeMessageSizeBoundary;
size_t allGatherSmallMessageSizeBoundary, allGatherLargeMessageSizeBoundary;
uint32_t numScratchBuff;
uint32_t buffFlag;
};
Expand Down Expand Up @@ -279,6 +280,46 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
return ncclSuccess;
}

static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
size_t bytes = sendcount * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;

// Declarating variables
size_t recvBytes;
CUdeviceptr recvBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
channelKey recvKey{(void*)recvBasePtr, recvBytes};
int rank = comm->comm->bootstrap()->getRank();
int nRank = comm->comm->bootstrap()->getNranks();
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;

auto it = comm->channelOutInfos.find(recvKey);
if (it == comm->channelOutInfos.end()) {
std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
}

smChannels = it->second.smChannelDeviceHandles.get();
if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) {
CUDACHECK(allgather<false>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
} else {
CUDACHECK(allgather<true>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
}

return ncclSuccess;
}

NCCL_API ncclResult_t ncclGetVersion(int* version) {
if (version == nullptr) return ncclInvalidArgument;
*version = MSCCLPP_VERSION;
Expand Down Expand Up @@ -355,15 +396,38 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
commPtr->allReduceOPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE")));
if (getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"))
commPtr->smallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"));
commPtr->allReduceSmallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"));
else
commPtr->smallMessageSizeBoundary = 16 * (1 << 10);
commPtr->allReduceSmallMessageSizeBoundary = 16 * (1 << 10);
if (getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"))
commPtr->largeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"));
commPtr->allReduceLargeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"));
else
commPtr->allReduceLargeMessageSizeBoundary = 1 << 20;

if (getenv("ALLGATHERPKT_IP_JSON_FILE"))
commPtr->allGatherPacketIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_IP_JSON_FILE")));
if (getenv("ALLGATHERPKT_OP_JSON_FILE"))
commPtr->allGatherPacketOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_OP_JSON_FILE")));
if (getenv("ALLGATHER_IP_JSON_FILE"))
commPtr->allGatherIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_IP_JSON_FILE")));
if (getenv("ALLGATHER_OP_JSON_FILE"))
commPtr->allGatherOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_OP_JSON_FILE")));
if (getenv("ALLGATHER_SMALL_MSG_BOUNDARY"))
commPtr->allGatherSmallMessageSizeBoundary = parseSize(getenv("ALLGATHER_SMALL_MSG_BOUNDARY"));
else
commPtr->allGatherSmallMessageSizeBoundary = (1 << 10);
if (getenv("ALLGATHER_LARGE_MSG_BOUNDARY"))
commPtr->allGatherLargeMessageSizeBoundary = parseSize(getenv("ALLGATHER_LARGE_MSG_BOUNDARY"));
else
commPtr->largeMessageSizeBoundary = 1 << 20;
commPtr->allGatherLargeMessageSizeBoundary = 1 << 20;

if (commPtr->smallMessageSizeBoundary > commPtr->largeMessageSizeBoundary) return ncclInvalidArgument;

if (commPtr->allReduceSmallMessageSizeBoundary > commPtr->allReduceLargeMessageSizeBoundary) return ncclInvalidArgument;
if (commPtr->allGatherSmallMessageSizeBoundary > commPtr->allGatherLargeMessageSizeBoundary) return ncclInvalidArgument;

*comm = commPtr;
return ncclSuccess;
Expand Down Expand Up @@ -483,11 +547,11 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
size_t bytes = count * ncclTypeSize(datatype);
int rank = comm->comm->bootstrap()->getRank();

if (bytes < comm->smallMessageSizeBoundary) {
if (bytes < comm->allReduceSmallMessageSizeBoundary) {
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
} else {
std::shared_ptr<mscclpp::ExecutionPlan> plan;
if (bytes <= comm->largeMessageSizeBoundary)
if (bytes <= comm->allReduceLargeMessageSizeBoundary)
plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan;
else {
plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan;
Expand Down Expand Up @@ -533,36 +597,41 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
size_t bytes = sendcount * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;

// Declarating variables
size_t recvBytes;
CUdeviceptr recvBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
channelKey recvKey{(void*)recvBasePtr, recvBytes};
int rank = comm->comm->bootstrap()->getRank();
int nRank = comm->comm->bootstrap()->getNranks();
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;

auto it = comm->channelOutInfos.find(recvKey);
if (it == comm->channelOutInfos.end()) {
std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
if (bytes * nRank < comm->allGatherSmallMessageSizeBoundary)
return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream);

std::shared_ptr<mscclpp::ExecutionPlan> plan;
if (bytes * nRank <= comm->allGatherLargeMessageSizeBoundary)
plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherPacketIPPlan : comm->allGatherPacketOPPlan;
else {
plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherIPPlan : comm->allGatherOPPlan;
}

smChannels = it->second.smChannelDeviceHandles.get();
if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) {
CUDACHECK(allgather<false>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
} else {
CUDACHECK(allgather<true>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
if (plan == nullptr) return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream);

switch (datatype) {
case ncclFloat16:
comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes * nRank, mscclpp::DataType::FLOAT16,
*plan, stream);
break;
case ncclFloat32:
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes * nRank,
mscclpp::DataType::FLOAT32, *plan, stream);
break;
case ncclBfloat16:
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes * nRank,
mscclpp::DataType::BFLOAT16, *plan, stream);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes * nRank, mscclpp::DataType::UINT32,
*plan, stream);
break;
default:
return ncclInvalidArgument;
}

return ncclSuccess;
Expand Down
Loading