Commit

Merge branch 'main' of https://github.com/microsoft/mscclpp into caiorocha/reduce_scatter_test_support
caiomcbr committed Nov 28, 2024
2 parents 2e958d9 + d9c297b commit 3777481
Showing 9 changed files with 147 additions and 49 deletions.
140 changes: 105 additions & 35 deletions apps/nccl/src/nccl.cu
@@ -58,15 +58,16 @@ struct ncclComm {
   std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
   std::shared_ptr<mscclpp::Executor> executor;
   std::shared_ptr<mscclpp::ExecutionPlan> allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan,
-      allReduceOPPlan;
+      allReduceOPPlan, allGatherIPPlan, allGatherOPPlan, allGatherPacketIPPlan, allGatherPacketOPPlan;
 
   std::unordered_map<channelKey, ChannelInfo> channelInInfos;
   std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
   std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
   std::shared_ptr<char> scratchBuff;
   std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
 
-  size_t smallMessageSizeBoundary, largeMessageSizeBoundary;
+  size_t allReduceSmallMessageSizeBoundary, allReduceLargeMessageSizeBoundary;
+  size_t allGatherSmallMessageSizeBoundary, allGatherLargeMessageSizeBoundary;
   uint32_t numScratchBuff;
   uint32_t buffFlag;
 };
@@ -216,7 +217,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
   mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;
 
   // Creating the channels
-  if (count * ncclTypeSize(datatype) <= comm->largeMessageSizeBoundary) {
+  if (count * ncclTypeSize(datatype) <= (1 << 20)) {
     auto sendIt = comm->channelScratchInfos.find(sendKey);
     if (sendIt == comm->channelScratchInfos.end()) {
       std::vector<mscclpp::SmChannel> channels =
@@ -279,6 +280,46 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
   return ncclSuccess;
 }
 
+static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, size_t sendcount,
+                                          ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
+  size_t bytes = sendcount * ncclTypeSize(datatype);
+  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;
+
+  // Declaring variables
+  size_t recvBytes;
+  CUdeviceptr recvBasePtr;
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
+  size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
+  channelKey recvKey{(void*)recvBasePtr, recvBytes};
+  int rank = comm->comm->bootstrap()->getRank();
+  int nRank = comm->comm->bootstrap()->getNranks();
+  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
+
+  auto it = comm->channelOutInfos.find(recvKey);
+  if (it == comm->channelOutInfos.end()) {
+    std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
+        comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
+    std::vector<mscclpp::SmChannel> channels =
+        setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
+    std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
+    std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
+                   [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
+    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
+    it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
+  }
+
+  smChannels = it->second.smChannelDeviceHandles.get();
+  if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) {
+    CUDACHECK(allgather<false>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
+                               NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
+  } else {
+    CUDACHECK(allgather<true>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
+                              NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
+  }
+
+  return ncclSuccess;
+}
+
 NCCL_API ncclResult_t ncclGetVersion(int* version) {
   if (version == nullptr) return ncclInvalidArgument;
   *version = MSCCLPP_VERSION;
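
Note: the fallback uses NCCL's in-place AllGather convention — the operation is in-place when sendbuff is rank's own slot inside recvbuff, in which case allgather<false> can skip copying the local shard. A minimal sketch of that test (the helper name is illustrative, not part of this commit; note that the fallback offsets by sendcount elements, while the plan selection in ncclAllGather further below offsets by bytes):

    // Sketch: NCCL-style in-place AllGather detection. Hypothetical helper.
    #include <cstddef>

    bool isInPlaceAllGather(const void* sendbuff, const void* recvbuff,
                            int rank, size_t bytesPerRank) {
      // In-place: the rank contributes from its own slot of the gathered buffer.
      return static_cast<const char*>(sendbuff) ==
             static_cast<const char*>(recvbuff) + rank * bytesPerRank;
    }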
@@ -355,15 +396,39 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
   commPtr->allReduceOPPlan =
       std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE")));
   if (getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"))
-    commPtr->smallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"));
+    commPtr->allReduceSmallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"));
   else
-    commPtr->smallMessageSizeBoundary = 16 * (1 << 10);
+    commPtr->allReduceSmallMessageSizeBoundary = 16 * (1 << 10);
   if (getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"))
-    commPtr->largeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"));
+    commPtr->allReduceLargeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"));
+  else
+    commPtr->allReduceLargeMessageSizeBoundary = 1 << 20;
+
+  if (getenv("ALLGATHERPKT_IP_JSON_FILE"))
+    commPtr->allGatherPacketIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
+        mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_IP_JSON_FILE")));
+  if (getenv("ALLGATHERPKT_OP_JSON_FILE"))
+    commPtr->allGatherPacketOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
+        mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_OP_JSON_FILE")));
+  if (getenv("ALLGATHER_IP_JSON_FILE"))
+    commPtr->allGatherIPPlan =
+        std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_IP_JSON_FILE")));
+  if (getenv("ALLGATHER_OP_JSON_FILE"))
+    commPtr->allGatherOPPlan =
+        std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_OP_JSON_FILE")));
+  if (getenv("ALLGATHER_SMALL_MSG_BOUNDARY"))
+    commPtr->allGatherSmallMessageSizeBoundary = parseSize(getenv("ALLGATHER_SMALL_MSG_BOUNDARY"));
   else
-    commPtr->largeMessageSizeBoundary = 1 << 20;
+    commPtr->allGatherSmallMessageSizeBoundary = (1 << 10);
+  if (getenv("ALLGATHER_LARGE_MSG_BOUNDARY"))
+    commPtr->allGatherLargeMessageSizeBoundary = parseSize(getenv("ALLGATHER_LARGE_MSG_BOUNDARY"));
+  else
+    commPtr->allGatherLargeMessageSizeBoundary = 1 << 20;
 
-  if (commPtr->smallMessageSizeBoundary > commPtr->largeMessageSizeBoundary) return ncclInvalidArgument;
+  if (commPtr->allReduceSmallMessageSizeBoundary > commPtr->allReduceLargeMessageSizeBoundary)
+    return ncclInvalidArgument;
+  if (commPtr->allGatherSmallMessageSizeBoundary > commPtr->allGatherLargeMessageSizeBoundary)
+    return ncclInvalidArgument;
 
   *comm = commPtr;
   return ncclSuccess;
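
Note: ncclCommInitRank now optionally loads four AllGather execution plans and two size boundaries from the environment, falling back to built-in defaults (1 KiB and 1 MiB). A hedged configuration sketch — the JSON paths are hypothetical, and plain byte counts are assumed to be valid parseSize() input:

    // Sketch: pointing the communicator at custom AllGather plans. Must run
    // before ncclCommInitRank in this process. Paths are hypothetical.
    #include <cstdlib>

    void configureAllGatherPlans() {
      setenv("ALLGATHERPKT_IP_JSON_FILE", "/opt/plans/allgather_pkt_ip.json", 1);
      setenv("ALLGATHERPKT_OP_JSON_FILE", "/opt/plans/allgather_pkt_op.json", 1);
      setenv("ALLGATHER_IP_JSON_FILE", "/opt/plans/allgather_ip.json", 1);
      setenv("ALLGATHER_OP_JSON_FILE", "/opt/plans/allgather_op.json", 1);
      setenv("ALLGATHER_SMALL_MSG_BOUNDARY", "1024", 1);     // default: 1 KiB
      setenv("ALLGATHER_LARGE_MSG_BOUNDARY", "1048576", 1);  // default: 1 MiB
    }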
@@ -483,11 +548,11 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
   size_t bytes = count * ncclTypeSize(datatype);
   int rank = comm->comm->bootstrap()->getRank();
 
-  if (bytes < comm->smallMessageSizeBoundary) {
+  if (bytes < comm->allReduceSmallMessageSizeBoundary) {
     return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
   } else {
     std::shared_ptr<mscclpp::ExecutionPlan> plan;
-    if (bytes <= comm->largeMessageSizeBoundary)
+    if (bytes <= comm->allReduceLargeMessageSizeBoundary)
       plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan;
     else {
       plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan;
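
Note: with the renamed fields, ncclAllReduce keeps its three regimes; under the defaults (16 KiB and 1 MiB), a 4 KiB message takes the fallback, a 256 KiB message a packet plan, and a 4 MiB message a non-packet plan. A sketch of the same dispatch with hypothetical helper and enum names:

    // Sketch of the AllReduce size dispatch above; names are illustrative.
    #include <cstddef>

    enum class AllReducePath { Fallback, PacketPlan, NonPacketPlan };

    AllReducePath pickAllReducePath(size_t bytes,
                                    size_t smallBoundary = 16 * (1 << 10),
                                    size_t largeBoundary = 1 << 20) {
      if (bytes < smallBoundary) return AllReducePath::Fallback;     // e.g. 4 KiB
      if (bytes <= largeBoundary) return AllReducePath::PacketPlan;  // e.g. 256 KiB
      return AllReducePath::NonPacketPlan;                           // e.g. 4 MiB
    }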
@@ -533,36 +598,41 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
   size_t bytes = sendcount * ncclTypeSize(datatype);
   if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;
 
-  // Declarating variables
-  size_t recvBytes;
-  CUdeviceptr recvBasePtr;
-  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
-  size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
-  channelKey recvKey{(void*)recvBasePtr, recvBytes};
   int rank = comm->comm->bootstrap()->getRank();
   int nRank = comm->comm->bootstrap()->getNranks();
-  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
 
-  auto it = comm->channelOutInfos.find(recvKey);
-  if (it == comm->channelOutInfos.end()) {
-    std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
-        comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
-    std::vector<mscclpp::SmChannel> channels =
-        setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
-    std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
-    std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
-                   [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
-    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
-    it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
+  if (bytes * nRank < comm->allGatherSmallMessageSizeBoundary)
+    return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream);
+
+  std::shared_ptr<mscclpp::ExecutionPlan> plan;
+  if (bytes * nRank <= comm->allGatherLargeMessageSizeBoundary)
+    plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherPacketIPPlan : comm->allGatherPacketOPPlan;
+  else {
+    plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherIPPlan : comm->allGatherOPPlan;
   }
 
-  smChannels = it->second.smChannelDeviceHandles.get();
-  if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) {
-    CUDACHECK(allgather<false>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
-                               NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
-  } else {
-    CUDACHECK(allgather<true>((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank,
-                              NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
+  if (plan == nullptr) return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream);
+
+  switch (datatype) {
+    case ncclFloat16:
+      comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes * nRank, mscclpp::DataType::FLOAT16,
+                              *plan, stream);
+      break;
+    case ncclFloat32:
+      comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes * nRank,
+                              mscclpp::DataType::FLOAT32, *plan, stream);
+      break;
+    case ncclBfloat16:
+      comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes * nRank,
+                              mscclpp::DataType::BFLOAT16, *plan, stream);
+      break;
+    case ncclInt32:
+    case ncclUint32:
+      comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes * nRank, mscclpp::DataType::UINT32,
+                              *plan, stream);
+      break;
+    default:
+      return ncclInvalidArgument;
   }
 
   return ncclSuccess;
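
Note: unlike AllReduce, the AllGather dispatch compares the boundaries against the total gathered size, bytes * nRank, and falls back whenever the selected plan was never loaded from the environment. A usage sketch of the common in-place call (sizes hypothetical, error handling elided):

    // Sketch: an in-place ncclAllGather. sendbuff is rank's slot of recvbuff,
    // so the dispatch above picks the *IPPlan variants (or allgather<false>
    // in the fallback path).
    #include <cuda_runtime.h>

    ncclResult_t inPlaceAllGather(ncclComm_t comm, int rank, int nRank,
                                  cudaStream_t stream) {
      size_t sendcount = 1 << 18;  // 256 Ki floats contributed per rank
      float* recvbuff = nullptr;
      cudaMalloc(&recvbuff, nRank * sendcount * sizeof(float));
      float* sendbuff = recvbuff + rank * sendcount;  // rank's own slot
      return ncclAllGather(sendbuff, recvbuff, sendcount, ncclFloat32, comm, stream);
    }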
2 changes: 1 addition & 1 deletion include/mscclpp/npkit/npkit_event.hpp
@@ -13,6 +13,6 @@
 #define NPKIT_EVENT_EXECUTOR_INIT_EXIT 0x4
 
 #define NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY 0x5
-#define NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT 0x17
+#define NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT 0x18
 
 #endif
2 changes: 1 addition & 1 deletion python/mscclpp/comm.py
@@ -99,7 +99,7 @@ def make_connection(
         else:
             endpoint = endpoints
         if endpoint.transport == Transport.Nvls:
-            return connect_nvls_collective(self.communicator, all_ranks)
+            return connect_nvls_collective(self.communicator, all_ranks, 2**30)
         else:
             connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
         self.communicator.setup()
2 changes: 1 addition & 1 deletion python/mscclpp_benchmark/allreduce_bench.py
@@ -289,7 +289,7 @@ def get_netinterface_info():
 mscclpp_algbw = []
 nccl_algbw = []
 speed_ups = []
-end_range = 28 if is_nvls_supported() else 29
+end_range = 29
 for i in range(10, end_range):
     if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
         nelems = 2**i
8 changes: 8 additions & 0 deletions src/executor/execution_plan.cc
@@ -17,6 +17,8 @@ std::vector<T> filter(const std::vector<T>& vec, Predicate pred) {
 
 auto getOpType = [](const std::string& str) {
-  if (str == "barrier") {
+  if (str == "nop") {
+    return mscclpp::OperationType::NOP;
+  } else if (str == "barrier") {
     return mscclpp::OperationType::BARRIER;
   } else if (str == "put") {
     return mscclpp::OperationType::PUT;
@@ -456,6 +458,12 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
           operation.size =
               this->getNChunkSize(rank, this->inputSize, this->outputSize, (uint32_t)op["cnt"], chunkIndexes);
         }
+        if (op.contains("barrier_id")) {
+          operation.deviceSyncerIndex = op["barrier_id"];
+        }
+        if (op.contains("nthread_blocks")) {
+          operation.nThreadBlocks = op["nthread_blocks"];
+        }
         ops.push_back(operation);
       }
       this->operations[rank].push_back(ops);
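
Note: the two new branches read optional per-operation barrier fields from the plan JSON. A minimal sketch of an entry carrying them, parsed with the same nlohmann::json calls; only "barrier_id" and "nthread_blocks" come from this commit, and the sample values and surrounding schema are assumed:

    // Sketch: reading the optional barrier fields the way setupOperations does.
    #include <cstdint>
    #include <nlohmann/json.hpp>

    int main() {
      auto op = nlohmann::json::parse(R"({"barrier_id": 0, "nthread_blocks": 4})");
      uint32_t deviceSyncerIndex = 0, nThreadBlocks = 0;
      if (op.contains("barrier_id")) deviceSyncerIndex = op["barrier_id"];
      if (op.contains("nthread_blocks")) nThreadBlocks = op["nthread_blocks"];
      return (deviceSyncerIndex == 0 && nThreadBlocks == 4) ? 0 : 1;
    }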
20 changes: 15 additions & 5 deletions src/include/execution_common.hpp
@@ -30,6 +30,7 @@ enum class ChannelType : uint8_t {
 
 // NOTE(chhwang): any modification here requires corresponding updates in `tools/npkit/npkit_trace_generator.py`.
 enum class OperationType : uint8_t {
+  NOP,
   BARRIER,
   PUT,
   PUT_PACKET,
@@ -78,11 +79,20 @@ struct Operation {
     BufferType outputBufferType;
     uint8_t nvlsOutputIndex;
   };
-  uint32_t inputOffsets[MAX_CHANNEL_PER_OPERATION];
-  uint32_t outputOffsets[MAX_CHANNEL_PER_OPERATION];
-  uint32_t srcOffset;
-  uint32_t dstOffset;
-  uint32_t size;
+  union {
+    // For Barrier operation
+    struct {
+      uint32_t deviceSyncerIndex;
+      uint32_t nThreadBlocks;
+    };
+    struct {
+      uint32_t inputOffsets[MAX_CHANNEL_PER_OPERATION];
+      uint32_t outputOffsets[MAX_CHANNEL_PER_OPERATION];
+      uint32_t srcOffset;
+      uint32_t dstOffset;
+      uint32_t size;
+    };
+  };
 };
 
 // total size = 2304 + 6400 + 4 + 12(padding) = 8720 bytes
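
Note: folding the offsets into an anonymous union costs nothing — a barrier needs only a syncer index and a block count, so those two fields overlay the far larger offset arrays that only data-movement operations use, leaving sizeof(Operation) unchanged. A standalone sketch of the pattern (MAX_CHANNELS stands in for MAX_CHANNEL_PER_OPERATION):

    // Sketch of the union overlay above; a union is as large as its biggest
    // member, so the barrier fields add no size.
    #include <cstdint>
    #include <cstdio>

    constexpr int MAX_CHANNELS = 16;  // stand-in for MAX_CHANNEL_PER_OPERATION

    struct OpPayload {
      union {
        struct {  // used by BARRIER operations
          uint32_t deviceSyncerIndex;
          uint32_t nThreadBlocks;
        };
        struct {  // used by data-movement operations
          uint32_t inputOffsets[MAX_CHANNELS];
          uint32_t outputOffsets[MAX_CHANNELS];
          uint32_t srcOffset;
          uint32_t dstOffset;
          uint32_t size;
        };
      };
    };

    int main() { printf("%zu\n", sizeof(OpPayload)); }  // size of the larger member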
10 changes: 9 additions & 1 deletion src/include/execution_kernel.hpp
@@ -8,6 +8,7 @@
 #if defined(ENABLE_NPKIT)
 #include <mscclpp/npkit/npkit.hpp>
 #endif
+#include <mscclpp/concurrency_device.hpp>
 #include <mscclpp/packet_device.hpp>
 #include <mscclpp/proxy_channel.hpp>
 #include <mscclpp/sm_channel.hpp>
@@ -172,6 +173,9 @@ struct VectorType<float> {
 
 namespace mscclpp {
 
+#define MAX_DEVICE_SYNCERS 16
+__device__ DeviceSyncer deviceSyncers[MAX_DEVICE_SYNCERS];
+
 #if defined(MSCCLPP_DEVICE_COMPILE)
 
 template <typename T>
@@ -526,8 +530,12 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
                             event_buffer, &event_buffer_head);
 #endif
 
-    if (op.type == OperationType::BARRIER) {
+    if (op.type == OperationType::NOP) {
       __syncthreads();
+    } else if (op.type == OperationType::BARRIER) {
+      int nThreadBlocks = op.nThreadBlocks;
+      int syncStateIndex = op.deviceSyncerIndex;
+      deviceSyncers[syncStateIndex].sync(nThreadBlocks);
     } else if (op.type == OperationType::SIGNAL) {
       handleSignal(smChannels, proxyChannels, op.outputChannelIndexes, op.nOutputs, op.channelType);
     } else if (op.type == OperationType::WAIT) {
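
Note: NOP now takes over the old block-local __syncthreads() behavior, while BARRIER becomes a true cross-block rendezvous through one of the sixteen DeviceSyncer slots declared above. A minimal CUDA sketch of that primitive outside the executor (kernel and launch shape are illustrative):

    // Sketch: cross-block synchronization with mscclpp::DeviceSyncer, the
    // mechanism the new BARRIER branch uses.
    #include <mscclpp/concurrency_device.hpp>

    __device__ mscclpp::DeviceSyncer gSyncer;  // zero-initialized device state

    __global__ void twoPhase(int* data, int nBlocks) {
      data[blockIdx.x] = blockIdx.x;  // phase 1: each block writes its slot
      gSyncer.sync(nBlocks);          // all nBlocks blocks arrive before any proceeds
      // phase 2: safe to read slots written by other blocks
      int peer = (blockIdx.x + 1) % nBlocks;
      (void)data[peer];
    }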
11 changes: 6 additions & 5 deletions test/executor_test.cc
@@ -131,11 +131,12 @@ int main(int argc, char* argv[]) {
   }
 
   mscclpp::ExecutionPlan plan(executionPlanName, executionPlanPath);
-#if (CUDA_NVLS_SUPPORTED)
-  std::shared_ptr<char> sendbuff = mscclpp::allocSharedPhysicalCuda<char>(bufferSize);
-#else
-  std::shared_ptr<char> sendbuff = mscclpp::allocExtSharedCuda<char>(bufferSize);
-#endif
+  std::shared_ptr<char> sendbuff;
+  if (mscclpp::isNvlsSupported()) {
+    sendbuff = mscclpp::allocSharedPhysicalCuda<char>(bufferSize);
+  } else {
+    sendbuff = mscclpp::allocExtSharedCuda<char>(bufferSize);
+  }
   std::vector<int> dataHost(bufferSize / sizeof(int), rank);
   MSCCLPP_CUDATHROW(cudaMemcpy(sendbuff.get(), dataHost.data(), bufferSize, cudaMemcpyHostToDevice));
   double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, niters, ngraphIters, packetType);
1 change: 1 addition & 0 deletions tools/npkit/npkit_trace_generator.py
@@ -11,6 +11,7 @@
 def parse_npkit_event_header(npkit_event_header_path):
     npkit_event_def = {"id_to_type": {}, "type_to_id": {}}
     executor_ops = [
+        "NOP",
         "BARRIER",
         "PUT",
         "PUT_PACKET",
