Skip to content

Commit

Permalink
NCCL API Executor Integration (#331)
Browse files Browse the repository at this point in the history
Co-authored-by: Changho Hwang <[email protected]>
  • Loading branch information
caiomcbr and chhwang authored Jul 25, 2024
1 parent f131fae commit 67eb9b0
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 81 deletions.
242 changes: 177 additions & 65 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
#include <algorithm>
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/executor.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>
#include <sstream>
#include <unordered_map>
#include <vector>

Expand Down Expand Up @@ -54,13 +56,17 @@ struct ncclComm {
std::shared_ptr<mscclpp::Communicator> comm;
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
std::shared_ptr<mscclpp::Executor> executor;
std::shared_ptr<mscclpp::ExecutionPlan> allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan,
allReduceOPPlan;

std::unordered_map<channelKey, ChannelInfo> channelInInfos;
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;

size_t smallMessageSizeBoundary, largeMessageSizeBoundary;
uint32_t numScratchBuff;
uint32_t buffFlag;
};
Expand Down Expand Up @@ -97,6 +103,43 @@ static size_t ncclTypeSize(ncclDataType_t type) {
return 0;
}

// Parses a human-readable size string such as "16", "64K", "1M" or "2G"
// (suffix case-insensitive, powers of 1024) into a byte count.
// Fractional values are accepted, e.g. "0.5K" -> 512.
// Returns the size in bytes, or -1.0 if the string is not a number or
// carries an unrecognized suffix.
double parseSize(const char* value) {
  std::string valueStr(value);
  std::istringstream iss(valueStr);
  long long int units;
  double size;
  char size_lit = 0;

  if (iss >> size) {
    iss >> std::ws;  // eat whitespace between the number and the suffix
    iss >> size_lit;
  } else {
    return -1.0;  // not a number at all
  }

  // std::isspace requires a value representable as unsigned char; passing a
  // plain (possibly negative) char is undefined behavior.
  if (size_lit != 0 && !std::isspace(static_cast<unsigned char>(size_lit))) {
    switch (size_lit) {
      case 'G':
      case 'g':
        units = 1024 * 1024 * 1024;
        break;
      case 'M':
      case 'm':
        units = 1024 * 1024;
        break;
      case 'K':
      case 'k':
        units = 1024;
        break;
      default:
        return -1.0;  // unknown suffix
    };
  } else {
    units = 1;  // bare number: plain bytes
  }
  return size * units;
}

static mscclpp::Transport getTransport(int, int) {
// if (rank / nRanksPerNode == peerRank / nRanksPerNode) {
// return mscclpp::Transport::CudaIpc;
Expand Down Expand Up @@ -151,6 +194,86 @@ static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> setupSmChannel
return ptr;
}

// Fallback allreduce path that runs the hand-written SM-channel kernels
// instead of the executor. Used for messages below the small-message boundary
// or when no execution plan was loaded. The ncclRedOp_t parameter is unnamed
// because the underlying kernels only implement sum.
// Returns ncclInvalidArgument on bad parameters or unsupported datatype,
// ncclSuccess otherwise.
static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
                                          ncclRedOp_t, ncclComm_t comm, cudaStream_t stream) {
  // Checking if the parameters are valid
  if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr)
    return ncclInvalidArgument;

  // Declaring variables
  size_t bytes = count * ncclTypeSize(datatype);
  size_t sendBytes, recvBytes;
  CUdeviceptr sendBasePtr, recvBasePtr;
  // Resolve the base allocation of each buffer so channels can be keyed and
  // cached per allocation rather than per call.
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)sendbuff));
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
  size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
  size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
  // Rotate through the scratch sub-buffers so back-to-back calls do not reuse
  // a region that may still be in flight.
  uint32_t scratchBuffIdx = (++(comm->buffFlag)) % comm->numScratchBuff;
  size_t offsetScratch = (SCRATCH_SIZE / comm->numScratchBuff) * scratchBuffIdx;
  int rank = comm->comm->bootstrap()->getRank();
  int nRanks = comm->comm->bootstrap()->getNranks();
  channelKey sendKey{(void*)sendBasePtr, sendBytes};
  channelKey recvKey{(void*)recvBasePtr, recvBytes};
  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
  mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;

  // Creating the channels (cached per base allocation)
  if (bytes <= comm->largeMessageSizeBoundary) {
    // Small/medium messages go through the scratch buffer only.
    auto sendIt = comm->channelScratchInfos.find(sendKey);
    if (sendIt == comm->channelScratchInfos.end()) {
      std::vector<mscclpp::SmChannel> channels =
          setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
      ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
      sendIt = comm->channelScratchInfos.emplace(sendKey, channelInfo).first;
    }

    smChannels = sendIt->second.smChannelDeviceHandles.get();
  } else {
    // Large messages additionally need output channels registered on the
    // receive buffer.
    std::vector<mscclpp::RegisteredMemory> remoteMemories;

    auto sendIt = comm->channelInInfos.find(sendKey);
    if (sendIt == comm->channelInInfos.end()) {
      std::vector<mscclpp::SmChannel> channels =
          setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
      ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
      sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first;
    }

    auto recvIt = comm->channelOutInfos.find(recvKey);
    if (recvIt == comm->channelOutInfos.end()) {
      remoteMemories =
          setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
      std::vector<mscclpp::SmChannel> outChannels =
          setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
      ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)};
      recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
    }

    smChannels = sendIt->second.smChannelDeviceHandles.get();
    smOutChannels = recvIt->second.smChannelDeviceHandles.get();
  }

  // Dispatch on datatype; use the cached rank/nRanks consistently instead of
  // re-querying the bootstrap per branch.
  switch (datatype) {
    case ncclFloat16:
      CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
                          offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, nRanks, count, stream));
      break;
    case ncclFloat32:
      CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
                          smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, nRanks, count,
                          stream));
      break;
    case ncclInt32:
    case ncclUint32:
      CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
                          offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, nRanks, count, stream));
      break;
    default:
      return ncclInvalidArgument;
  }
  return ncclSuccess;
}

NCCL_API ncclResult_t ncclGetVersion(int* version) {
if (version == nullptr) return ncclInvalidArgument;
*version = MSCCLPP_VERSION;
Expand Down Expand Up @@ -211,6 +334,30 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
commPtr->scratchBuff = mscclpp::allocExtSharedCuda<char>(SCRATCH_SIZE);
commPtr->remoteScratchRegMemories =
setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);

if (getenv("ALLREDUCEPKT_IP_JSON_FILE"))
commPtr->allReducePacketIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_IP_JSON_FILE")));
if (getenv("ALLREDUCEPKT_OP_JSON_FILE"))
commPtr->allReducePacketOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_OP_JSON_FILE")));
if (getenv("ALLREDUCE_IP_JSON_FILE"))
commPtr->allReduceIPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_IP_JSON_FILE")));
if (getenv("ALLREDUCE_OP_JSON_FILE"))
commPtr->allReduceOPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE")));
if (getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"))
commPtr->smallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"));
else
commPtr->smallMessageSizeBoundary = 16 * (1 << 10);
if (getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"))
commPtr->largeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"));
else
commPtr->largeMessageSizeBoundary = 1 << 20;

if (commPtr->smallMessageSizeBoundary > commPtr->largeMessageSizeBoundary) return ncclInvalidArgument;

*comm = commPtr;
return ncclSuccess;
Expand Down Expand Up @@ -321,82 +468,46 @@ NCCL_API ncclResult_t ncclBroadcast(const void*, void*, size_t, ncclDataType_t,
}

NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclRedOp_t, ncclComm_t comm, cudaStream_t stream) {
ncclRedOp_t reductionOperation, ncclComm_t comm, cudaStream_t stream) {
// Checking if the parameters are valids
if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr)
return ncclInvalidArgument;

// Declarating variables
size_t sendBytes, recvBytes;
CUdeviceptr sendBasePtr, recvBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)sendbuff));
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
uint32_t scratchBuffIdx = (++(comm->buffFlag)) % comm->numScratchBuff;
size_t offsetScratch = (SCRATCH_SIZE / comm->numScratchBuff) * scratchBuffIdx;
size_t bytes = count * ncclTypeSize(datatype);
int rank = comm->comm->bootstrap()->getRank();
channelKey sendKey{(void*)sendBasePtr, sendBytes};
channelKey recvKey{(void*)recvBasePtr, recvBytes};
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;

// Creating the channels
if (count * ncclTypeSize(datatype) <= (1 << 20)) {
auto sendIt = comm->channelScratchInfos.find(sendKey);
if (sendIt == comm->channelScratchInfos.end()) {
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
sendIt = comm->channelScratchInfos.emplace(sendKey, channelInfo).first;
}

smChannels = sendIt->second.smChannelDeviceHandles.get();
if (bytes < comm->smallMessageSizeBoundary) {
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
} else {
std::vector<mscclpp::RegisteredMemory> remoteMemories;

auto sendIt = comm->channelInInfos.find(sendKey);
if (sendIt == comm->channelInInfos.end()) {
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first;
std::shared_ptr<mscclpp::ExecutionPlan> plan;
if (bytes <= comm->largeMessageSizeBoundary)
plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan;
else
plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan;

if (plan == nullptr)
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);

switch (datatype) {
case ncclFloat16:
comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, 1024,
*plan, stream, mscclpp::PacketType::LL8);
break;
case ncclFloat32:
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32,
1024, *plan, stream, mscclpp::PacketType::LL8);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, 1024,
*plan, stream, mscclpp::PacketType::LL8);
break;
default:
return ncclInvalidArgument;
}

auto recvIt = comm->channelOutInfos.find(recvKey);
if (recvIt == comm->channelOutInfos.end()) {
remoteMemories =
setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> outChannels =
setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)};
recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
}

smChannels = sendIt->second.smChannelDeviceHandles.get();
smOutChannels = recvIt->second.smChannelDeviceHandles.get();
}

switch (datatype) {
case ncclFloat16:
CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclFloat32:
CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclInt32:
case ncclUint32:
CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
break;
default:
return ncclInvalidArgument;
}
return ncclSuccess;
}

Expand Down Expand Up @@ -442,6 +553,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
CUDACHECK(allgather<true>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
}

return ncclSuccess;
}

Expand Down
5 changes: 4 additions & 1 deletion python/test/test_mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,10 @@ def test_executor(mpi_group: MpiGroup, filename: str):
cp.random.seed(42)
buffer = cp.random.random(nelems).astype(cp.float16)
sub_arrays = cp.split(buffer, mpi_group.comm.size)
sendbuf = sub_arrays[mpi_group.comm.rank]
nelems_per_rank = int(nelems / mpi_group.comm.size)
sendbuf = cp.empty(nelems_per_rank).astype(cp.float16)
for i in range(nelems_per_rank):
sendbuf[i] = sub_arrays[mpi_group.comm.rank][i]
expected = cp.zeros_like(sendbuf)
for i in range(mpi_group.comm.size):
expected += sub_arrays[i]
Expand Down
Loading

0 comments on commit 67eb9b0

Please sign in to comment.