From 593478e1b74d808bd2c43446ad64729b363472c3 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 25 Nov 2024 21:13:30 -0800 Subject: [PATCH 1/4] Add cross threadblock barrier (#383) --- include/mscclpp/npkit/npkit_event.hpp | 2 +- src/executor/execution_plan.cc | 8 ++++++++ src/include/execution_common.hpp | 20 +++++++++++++++----- src/include/execution_kernel.hpp | 10 +++++++++- test/executor_test.cc | 11 ++++++----- tools/npkit/npkit_trace_generator.py | 1 + 6 files changed, 40 insertions(+), 12 deletions(-) diff --git a/include/mscclpp/npkit/npkit_event.hpp b/include/mscclpp/npkit/npkit_event.hpp index 1a24b241f..cb1925626 100644 --- a/include/mscclpp/npkit/npkit_event.hpp +++ b/include/mscclpp/npkit/npkit_event.hpp @@ -13,6 +13,6 @@ #define NPKIT_EVENT_EXECUTOR_INIT_EXIT 0x4 #define NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY 0x5 -#define NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT 0x17 +#define NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT 0x18 #endif diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc index 49ceddf0a..20226b661 100644 --- a/src/executor/execution_plan.cc +++ b/src/executor/execution_plan.cc @@ -17,6 +17,8 @@ std::vector filter(const std::vector& vec, Predicate pred) { auto getOpType = [](const std::string& str) { if (str == "nop") { + return mscclpp::OperationType::NOP; + } else if (str == "barrier") { return mscclpp::OperationType::BARRIER; } else if (str == "put") { return mscclpp::OperationType::PUT; @@ -456,6 +458,12 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse operation.size = this->getNChunkSize(rank, this->inputSize, this->outputSize, (uint32_t)op["cnt"], chunkIndexes); } + if (op.contains("barrier_id")) { + operation.deviceSyncerIndex = op["barrier_id"]; + } + if (op.contains("nthread_blocks")) { + operation.nThreadBlocks = op["nthread_blocks"]; + } ops.push_back(operation); } this->operations[rank].push_back(ops); diff --git a/src/include/execution_common.hpp b/src/include/execution_common.hpp index f4f4fbd8c..d0d0dc30d 100644 --- a/src/include/execution_common.hpp +++ b/src/include/execution_common.hpp @@ -30,6 +30,7 @@ enum class ChannelType : uint8_t { // NOTE(chhwang): any modification here requires corresponding updates in `tools/npkit/npkit_trace_generator.py`. enum class OperationType : uint8_t { + NOP, BARRIER, PUT, PUT_PACKET, @@ -78,11 +79,20 @@ struct Operation { BufferType outputBufferType; uint8_t nvlsOutputIndex; }; - uint32_t inputOffsets[MAX_CHANNEL_PER_OPERATION]; - uint32_t outputOffsets[MAX_CHANNEL_PER_OPERATION]; - uint32_t srcOffset; - uint32_t dstOffset; - uint32_t size; + union { + // For Barrier operation + struct { + uint32_t deviceSyncerIndex; + uint32_t nThreadBlocks; + }; + struct { + uint32_t inputOffsets[MAX_CHANNEL_PER_OPERATION]; + uint32_t outputOffsets[MAX_CHANNEL_PER_OPERATION]; + uint32_t srcOffset; + uint32_t dstOffset; + uint32_t size; + }; + }; }; // total size = 2304 + 6400 + 4 + 12(padding) = 8720 bytes diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp index 1e9d6ac57..1b0490f91 100644 --- a/src/include/execution_kernel.hpp +++ b/src/include/execution_kernel.hpp @@ -8,6 +8,7 @@ #if defined(ENABLE_NPKIT) #include #endif +#include #include #include #include @@ -172,6 +173,9 @@ struct VectorType { namespace mscclpp { +#define MAX_DEVICE_SYNCERS 16 +__device__ DeviceSyncer deviceSyncers[MAX_DEVICE_SYNCERS]; + #if defined(MSCCLPP_DEVICE_COMPILE) template @@ -526,8 +530,12 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu event_buffer, &event_buffer_head); #endif - if (op.type == OperationType::BARRIER) { + if (op.type == OperationType::NOP) { __syncthreads(); + } else if (op.type == OperationType::BARRIER) { + int nThreadBlocks = op.nThreadBlocks; + int syncStateIndex = op.deviceSyncerIndex; + deviceSyncers[syncStateIndex].sync(nThreadBlocks); } else if (op.type == OperationType::SIGNAL) { handleSignal(smChannels, proxyChannels, op.outputChannelIndexes, op.nOutputs, op.channelType); } else if (op.type == OperationType::WAIT) { diff --git a/test/executor_test.cc b/test/executor_test.cc index 3fc0b1e21..e4ebcc972 100644 --- a/test/executor_test.cc +++ b/test/executor_test.cc @@ -131,11 +131,12 @@ int main(int argc, char* argv[]) { } mscclpp::ExecutionPlan plan(executionPlanName, executionPlanPath); -#if (CUDA_NVLS_SUPPORTED) - std::shared_ptr sendbuff = mscclpp::allocSharedPhysicalCuda(bufferSize); -#else - std::shared_ptr sendbuff = mscclpp::allocExtSharedCuda(bufferSize); -#endif + std::shared_ptr sendbuff; + if (mscclpp::isNvlsSupported()) { + sendbuff = mscclpp::allocSharedPhysicalCuda(bufferSize); + } else { + sendbuff = mscclpp::allocExtSharedCuda(bufferSize); + } std::vector dataHost(bufferSize / sizeof(int), rank); MSCCLPP_CUDATHROW(cudaMemcpy(sendbuff.get(), dataHost.data(), bufferSize, cudaMemcpyHostToDevice)); double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, niters, ngraphIters, packetType); diff --git a/tools/npkit/npkit_trace_generator.py b/tools/npkit/npkit_trace_generator.py index 9a5b88b44..31c2e1622 100644 --- a/tools/npkit/npkit_trace_generator.py +++ b/tools/npkit/npkit_trace_generator.py @@ -11,6 +11,7 @@ def parse_npkit_event_header(npkit_event_header_path): npkit_event_def = {"id_to_type": {}, "type_to_id": {}} executor_ops = [ + "NOP", "BARRIER", "PUT", "PUT_PACKET", From d9c297ba14dd219875dc4e691899c3668c86efbe Mon Sep 17 00:00:00 2001 From: Caio Rocha <164253795+caiomcbr@users.noreply.github.com> Date: Wed, 27 Nov 2024 17:05:51 -0800 Subject: [PATCH 2/4] AllGather Executor Support in NCCL Interface (#393) Co-authored-by: Ziyue Yang Co-authored-by: Changho Hwang Co-authored-by: Binyang Li --- apps/nccl/src/nccl.cu | 138 +++++++++++++++++++++++++++++++----------- 1 file changed, 104 insertions(+), 34 deletions(-) diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index 2bd2a4422..2b7e97360 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -58,7 +58,7 @@ struct ncclComm { std::vector> smSemaphores; std::shared_ptr executor; std::shared_ptr allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan, - allReduceOPPlan; + allReduceOPPlan, allGatherIPPlan, allGatherOPPlan, allGatherPacketIPPlan, allGatherPacketOPPlan; std::unordered_map channelInInfos; std::unordered_map channelOutInfos; @@ -66,7 +66,8 @@ struct ncclComm { std::shared_ptr scratchBuff; std::vector remoteScratchRegMemories; - size_t smallMessageSizeBoundary, largeMessageSizeBoundary; + size_t allReduceSmallMessageSizeBoundary, allReduceLargeMessageSizeBoundary; + size_t allGatherSmallMessageSizeBoundary, allGatherLargeMessageSizeBoundary; uint32_t numScratchBuff; uint32_t buffFlag; }; @@ -279,6 +280,46 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, return ncclSuccess; } +static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, size_t sendcount, + ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) { + size_t bytes = sendcount * ncclTypeSize(datatype); + if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument; + + // Declarating variables + size_t recvBytes; + CUdeviceptr recvBasePtr; + MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff)); + size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr; + channelKey recvKey{(void*)recvBasePtr, recvBytes}; + int rank = comm->comm->bootstrap()->getRank(); + int nRank = comm->comm->bootstrap()->getNranks(); + mscclpp::DeviceHandle* smChannels = nullptr; + + auto it = comm->channelOutInfos.find(recvKey); + if (it == comm->channelOutInfos.end()) { + std::vector remoteMemories = setupRemoteMemories( + comm->comm, rank, const_cast((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc); + std::vector channels = + setupSmChannels(comm, remoteMemories, const_cast((void*)recvBasePtr)); + std::vector> smChannelDeviceHandles; + std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles), + [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); }); + ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)}; + it = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + } + + smChannels = it->second.smChannelDeviceHandles.get(); + if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) { + CUDACHECK(allgather((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank, + NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream)); + } else { + CUDACHECK(allgather((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank, + NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream)); + } + + return ncclSuccess; +} + NCCL_API ncclResult_t ncclGetVersion(int* version) { if (version == nullptr) return ncclInvalidArgument; *version = MSCCLPP_VERSION; @@ -355,15 +396,39 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI commPtr->allReduceOPPlan = std::make_shared(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE"))); if (getenv("ALLREDUCE_SMALL_MSG_BOUNDARY")) - commPtr->smallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY")); + commPtr->allReduceSmallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY")); else - commPtr->smallMessageSizeBoundary = 16 * (1 << 10); + commPtr->allReduceSmallMessageSizeBoundary = 16 * (1 << 10); if (getenv("ALLREDUCE_LARGE_MSG_BOUNDARY")) - commPtr->largeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY")); + commPtr->allReduceLargeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY")); + else + commPtr->allReduceLargeMessageSizeBoundary = 1 << 20; + + if (getenv("ALLGATHERPKT_IP_JSON_FILE")) + commPtr->allGatherPacketIPPlan = std::make_shared( + mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_IP_JSON_FILE"))); + if (getenv("ALLGATHERPKT_OP_JSON_FILE")) + commPtr->allGatherPacketOPPlan = std::make_shared( + mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_OP_JSON_FILE"))); + if (getenv("ALLGATHER_IP_JSON_FILE")) + commPtr->allGatherIPPlan = + std::make_shared(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_IP_JSON_FILE"))); + if (getenv("ALLGATHER_OP_JSON_FILE")) + commPtr->allGatherOPPlan = + std::make_shared(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_OP_JSON_FILE"))); + if (getenv("ALLGATHER_SMALL_MSG_BOUNDARY")) + commPtr->allGatherSmallMessageSizeBoundary = parseSize(getenv("ALLGATHER_SMALL_MSG_BOUNDARY")); else - commPtr->largeMessageSizeBoundary = 1 << 20; + commPtr->allGatherSmallMessageSizeBoundary = (1 << 10); + if (getenv("ALLGATHER_LARGE_MSG_BOUNDARY")) + commPtr->allGatherLargeMessageSizeBoundary = parseSize(getenv("ALLGATHER_LARGE_MSG_BOUNDARY")); + else + commPtr->allGatherLargeMessageSizeBoundary = 1 << 20; - if (commPtr->smallMessageSizeBoundary > commPtr->largeMessageSizeBoundary) return ncclInvalidArgument; + if (commPtr->allReduceSmallMessageSizeBoundary > commPtr->allReduceLargeMessageSizeBoundary) + return ncclInvalidArgument; + if (commPtr->allGatherSmallMessageSizeBoundary > commPtr->allGatherLargeMessageSizeBoundary) + return ncclInvalidArgument; *comm = commPtr; return ncclSuccess; @@ -483,11 +548,11 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t size_t bytes = count * ncclTypeSize(datatype); int rank = comm->comm->bootstrap()->getRank(); - if (bytes < comm->smallMessageSizeBoundary) { + if (bytes < comm->allReduceSmallMessageSizeBoundary) { return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream); } else { std::shared_ptr plan; - if (bytes <= comm->largeMessageSizeBoundary) + if (bytes <= comm->allReduceLargeMessageSizeBoundary) plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan; else { plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan; @@ -533,36 +598,41 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t size_t bytes = sendcount * ncclTypeSize(datatype); if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument; - // Declarating variables - size_t recvBytes; - CUdeviceptr recvBasePtr; - MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff)); - size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr; - channelKey recvKey{(void*)recvBasePtr, recvBytes}; int rank = comm->comm->bootstrap()->getRank(); int nRank = comm->comm->bootstrap()->getNranks(); - mscclpp::DeviceHandle* smChannels = nullptr; - auto it = comm->channelOutInfos.find(recvKey); - if (it == comm->channelOutInfos.end()) { - std::vector remoteMemories = setupRemoteMemories( - comm->comm, rank, const_cast((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc); - std::vector channels = - setupSmChannels(comm, remoteMemories, const_cast((void*)recvBasePtr)); - std::vector> smChannelDeviceHandles; - std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles), - [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); }); - ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)}; - it = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + if (bytes * nRank < comm->allGatherSmallMessageSizeBoundary) + return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream); + + std::shared_ptr plan; + if (bytes * nRank <= comm->allGatherLargeMessageSizeBoundary) + plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherPacketIPPlan : comm->allGatherPacketOPPlan; + else { + plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherIPPlan : comm->allGatherOPPlan; } - smChannels = it->second.smChannelDeviceHandles.get(); - if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) { - CUDACHECK(allgather((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank, - NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream)); - } else { - CUDACHECK(allgather((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank, - NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream)); + if (plan == nullptr) return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream); + + switch (datatype) { + case ncclFloat16: + comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes * nRank, mscclpp::DataType::FLOAT16, + *plan, stream); + break; + case ncclFloat32: + comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes * nRank, + mscclpp::DataType::FLOAT32, *plan, stream); + break; + case ncclBfloat16: + comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes * nRank, + mscclpp::DataType::BFLOAT16, *plan, stream); + break; + case ncclInt32: + case ncclUint32: + comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes * nRank, mscclpp::DataType::UINT32, + *plan, stream); + break; + default: + return ncclInvalidArgument; } return ncclSuccess; From ff18bb8d0b589fe32e71b8df38794967ddba0ace Mon Sep 17 00:00:00 2001 From: Caio Rocha <164253795+caiomcbr@users.noreply.github.com> Date: Thu, 28 Nov 2024 09:19:30 -0800 Subject: [PATCH 3/4] Providing reduce-scatter test support (#390) --- python/test/executor_test.py | 6 ++++-- python/test/executor_test_verifier.cu | 30 +++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/python/test/executor_test.py b/python/test/executor_test.py index 60cf36b95..5197b79b9 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -74,6 +74,8 @@ def bench_correctness( fill_data_kernel_name = "fill_data_%s" % dtype_str if "allgather" in execution_plan_name: coll = "all_gather" + elif "reducescatter" in execution_plan_name: + coll = "reduce_scatter" else: coll = "all_reduce" test_data_kernel_name = "test_data_%s_%s" % (coll, dtype_str) @@ -96,7 +98,7 @@ def bench_correctness( fill_data_kernel.launch_kernel(fill_data_params, nblocks, nthreads, 0, stream) func(stream) test_data_params = ( - pack(result_buf, test_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(num_ranks, i) + pack(result_buf, test_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(num_ranks, rank, i) ) test_data_kernel.launch_kernel(test_data_params, nblocks, nthreads, 0, stream) graph = stream.end_capture() @@ -128,7 +130,7 @@ def dtype_to_mscclpp_dtype(dtype): def allocate_buffer(nelems, dtype): - if is_nvls_supported: + if is_nvls_supported(): buffer_raw = alloc_shared_physical_cuda(nelems * cp.dtype(dtype).itemsize) buffer_ptr = cp.cuda.MemoryPointer( cp.cuda.UnownedMemory(buffer_raw.get_ptr(), buffer_raw.size(), buffer_raw), 0 diff --git a/python/test/executor_test_verifier.cu b/python/test/executor_test_verifier.cu index a7a6ca31f..581396198 100644 --- a/python/test/executor_test_verifier.cu +++ b/python/test/executor_test_verifier.cu @@ -36,7 +36,7 @@ FILL_DATA(int32, int) #define TEST_DATA_ALL_GATHER(FuncNameType, DataType) \ extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_gather_##FuncNameType( \ - DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int seq) { \ + DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \ for (int rank = 0; rank < num_ranks; rank++) { \ size_t rank_offset = rank * num_elems; \ unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ @@ -54,7 +54,7 @@ TEST_DATA_ALL_GATHER(int32, int) #define TEST_DATA_ALL_REDUCE(FuncNameType, DataType) \ extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_reduce_##FuncNameType( \ - DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int seq) { \ + DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \ for (int rank = 0; rank < num_ranks; rank++) { \ unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ @@ -73,3 +73,29 @@ TEST_DATA_ALL_GATHER(int32, int) TEST_DATA_ALL_REDUCE(float16, __half) TEST_DATA_ALL_REDUCE(float32, float) TEST_DATA_ALL_REDUCE(int32, int) + +#define TEST_DATA_REDUCE_SCATTER(FuncNameType, DataType) \ + extern "C" __global__ void __launch_bounds__(1024, 1) test_data_reduce_scatter_##FuncNameType( \ + DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \ + int nem_elems_per_rank = num_elems / num_ranks; \ + int offset = nem_elems_per_rank * my_rank; \ + for (int rank = 0; rank < num_ranks; rank++) { \ + unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + if (rank == 0) { \ + test_buf[i] = 0; \ + } \ + seed = ranqd1(seed); \ + test_buf[i] += DataType(seed % blockDim.x) / DataType(blockDim.x); \ + } \ + } \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + if (i >= offset && i < offset + nem_elems_per_rank) { \ + assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \ + } \ + } \ + } + +TEST_DATA_REDUCE_SCATTER(float16, __half) +TEST_DATA_REDUCE_SCATTER(float32, float) +TEST_DATA_REDUCE_SCATTER(int32, int) \ No newline at end of file From 88d28e07a7f1781623acce14a8018c8ac1082be2 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 3 Dec 2024 14:39:20 -0800 Subject: [PATCH 4/4] Select algo according to json config (#396) The way to run nccl-test over mscclpp: mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=$(pwd)/build/apps/nccl/libmscclpp_nccl.so -x NCCL_DEBUG=WARN -x MSCCLPP_EXECUTION_PLAN_DIR=/execution-files /root/nccl-tests/build/all_reduce_perf -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 20 --- apps/nccl/src/nccl.cu | 160 +++++++++------------ include/mscclpp/executor.hpp | 8 +- python/mscclpp/executor_py.cpp | 6 +- python/test/executor_test.py | 22 ++- python/test/test_mscclpp.py | 2 +- src/executor/execution_plan.cc | 28 +++- src/include/execution_plan.hpp | 8 +- test/execution-files/allreduce.json | 2 +- test/execution-files/allreduce_nvls.json | 2 +- test/execution-files/allreduce_packet.json | 2 +- test/execution-files/sendrecv.json | 2 +- test/execution-files/sendrecv_packet.json | 2 +- test/executor_test.cc | 16 +-- test/mp_unit/executor_tests.cc | 2 +- 14 files changed, 137 insertions(+), 125 deletions(-) diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index 2b7e97360..cd75edfea 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -2,6 +2,7 @@ // Licensed under the MIT license. #include +#include #include #include #include @@ -38,6 +39,17 @@ struct channelKey { bool operator==(const channelKey& other) const { return buff == other.buff && bytes == other.bytes; } }; +struct planKey { + size_t minMessageSize; + size_t maxMessageSize; + bool isInPlace; +}; + +struct executionPlanInstance { + planKey key; + std::shared_ptr plan; +}; + namespace std { template <> struct hash { @@ -57,8 +69,7 @@ struct ncclComm { std::vector> connections; std::vector> smSemaphores; std::shared_ptr executor; - std::shared_ptr allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan, - allReduceOPPlan, allGatherIPPlan, allGatherOPPlan, allGatherPacketIPPlan, allGatherPacketOPPlan; + std::unordered_map> executionPlans; std::unordered_map channelInInfos; std::unordered_map channelOutInfos; @@ -66,8 +77,6 @@ struct ncclComm { std::shared_ptr scratchBuff; std::vector remoteScratchRegMemories; - size_t allReduceSmallMessageSizeBoundary, allReduceLargeMessageSizeBoundary; - size_t allGatherSmallMessageSizeBoundary, allGatherLargeMessageSizeBoundary; uint32_t numScratchBuff; uint32_t buffFlag; }; @@ -183,6 +192,13 @@ static std::vector setupSmChannels(ncclComm_t comm, return channels; } +static std::pair loadExecutionPlan(const std::string& filename) { + std::shared_ptr plan = std::make_shared(filename); + std::string collective = plan->collective(); + planKey key{plan->minMessageSize(), plan->maxMessageSize(), plan->isInPlace()}; + return std::make_pair(collective, executionPlanInstance{key, plan}); +} + static std::shared_ptr> setupSmChannelDeviceHandles( const std::vector& smChannels) { std::vector> smChannelDeviceHandles; @@ -383,52 +399,18 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc); commPtr->executor = std::make_shared(mscclppComm); - if (getenv("ALLREDUCEPKT_IP_JSON_FILE")) - commPtr->allReducePacketIPPlan = std::make_shared( - mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_IP_JSON_FILE"))); - if (getenv("ALLREDUCEPKT_OP_JSON_FILE")) - commPtr->allReducePacketOPPlan = std::make_shared( - mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_OP_JSON_FILE"))); - if (getenv("ALLREDUCE_IP_JSON_FILE")) - commPtr->allReduceIPPlan = - std::make_shared(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_IP_JSON_FILE"))); - if (getenv("ALLREDUCE_OP_JSON_FILE")) - commPtr->allReduceOPPlan = - std::make_shared(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE"))); - if (getenv("ALLREDUCE_SMALL_MSG_BOUNDARY")) - commPtr->allReduceSmallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY")); - else - commPtr->allReduceSmallMessageSizeBoundary = 16 * (1 << 10); - if (getenv("ALLREDUCE_LARGE_MSG_BOUNDARY")) - commPtr->allReduceLargeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY")); - else - commPtr->allReduceLargeMessageSizeBoundary = 1 << 20; - - if (getenv("ALLGATHERPKT_IP_JSON_FILE")) - commPtr->allGatherPacketIPPlan = std::make_shared( - mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_IP_JSON_FILE"))); - if (getenv("ALLGATHERPKT_OP_JSON_FILE")) - commPtr->allGatherPacketOPPlan = std::make_shared( - mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_OP_JSON_FILE"))); - if (getenv("ALLGATHER_IP_JSON_FILE")) - commPtr->allGatherIPPlan = - std::make_shared(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_IP_JSON_FILE"))); - if (getenv("ALLGATHER_OP_JSON_FILE")) - commPtr->allGatherOPPlan = - std::make_shared(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_OP_JSON_FILE"))); - if (getenv("ALLGATHER_SMALL_MSG_BOUNDARY")) - commPtr->allGatherSmallMessageSizeBoundary = parseSize(getenv("ALLGATHER_SMALL_MSG_BOUNDARY")); - else - commPtr->allGatherSmallMessageSizeBoundary = (1 << 10); - if (getenv("ALLGATHER_LARGE_MSG_BOUNDARY")) - commPtr->allGatherLargeMessageSizeBoundary = parseSize(getenv("ALLGATHER_LARGE_MSG_BOUNDARY")); - else - commPtr->allGatherLargeMessageSizeBoundary = 1 << 20; - - if (commPtr->allReduceSmallMessageSizeBoundary > commPtr->allReduceLargeMessageSizeBoundary) - return ncclInvalidArgument; - if (commPtr->allGatherSmallMessageSizeBoundary > commPtr->allGatherLargeMessageSizeBoundary) - return ncclInvalidArgument; + if (getenv("MSCCLPP_EXECUTION_PLAN_DIR")) { + std::string collectiveDir = getenv("MSCCLPP_EXECUTION_PLAN_DIR"); + if (!std::filesystem::is_directory(collectiveDir)) { + return ncclInvalidArgument; + } + for (const auto& entry : std::filesystem::directory_iterator(collectiveDir)) { + if (entry.is_regular_file()) { + auto plan = loadExecutionPlan(entry.path()); + commPtr->executionPlans[plan.first].push_back(plan.second); + } + } + } *comm = commPtr; return ncclSuccess; @@ -548,40 +530,39 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t size_t bytes = count * ncclTypeSize(datatype); int rank = comm->comm->bootstrap()->getRank(); - if (bytes < comm->allReduceSmallMessageSizeBoundary) { - return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream); - } else { - std::shared_ptr plan; - if (bytes <= comm->allReduceLargeMessageSizeBoundary) - plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan; - else { - plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan; + std::vector& plans = comm->executionPlans["allreduce"]; + std::shared_ptr plan; + bool inPlace = sendbuff == recvbuff; + for (const auto& p : plans) { + if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) { + plan = p.plan; + break; } + } - if (plan == nullptr) - return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream); + if (plan == nullptr) + return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream); - switch (datatype) { - case ncclFloat16: - comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan, - stream, mscclpp::PacketType::LL8); - break; - case ncclFloat32: - comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32, - *plan, stream, mscclpp::PacketType::LL8); - break; - case ncclBfloat16: - comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes, - mscclpp::DataType::BFLOAT16, *plan, stream, mscclpp::PacketType::LL8); - break; - case ncclInt32: - case ncclUint32: - comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan, - stream, mscclpp::PacketType::LL8); - break; - default: - return ncclInvalidArgument; - } + switch (datatype) { + case ncclFloat16: + comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan, + stream, mscclpp::PacketType::LL8); + break; + case ncclFloat32: + comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32, *plan, + stream, mscclpp::PacketType::LL8); + break; + case ncclBfloat16: + comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes, + mscclpp::DataType::BFLOAT16, *plan, stream, mscclpp::PacketType::LL8); + break; + case ncclInt32: + case ncclUint32: + comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan, + stream, mscclpp::PacketType::LL8); + break; + default: + return ncclInvalidArgument; } return ncclSuccess; @@ -601,16 +582,17 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t int rank = comm->comm->bootstrap()->getRank(); int nRank = comm->comm->bootstrap()->getNranks(); - if (bytes * nRank < comm->allGatherSmallMessageSizeBoundary) - return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream); - + std::vector& plans = comm->executionPlans["allgather"]; std::shared_ptr plan; - if (bytes * nRank <= comm->allGatherLargeMessageSizeBoundary) - plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherPacketIPPlan : comm->allGatherPacketOPPlan; - else { - plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherIPPlan : comm->allGatherOPPlan; + void* basePtr = (char*)sendbuff - rank * bytes; + bool inPlace = basePtr == recvbuff; + const size_t totalBytes = bytes * nRank; + for (const auto& p : plans) { + if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) { + plan = p.plan; + break; + } } - if (plan == nullptr) return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream); switch (datatype) { diff --git a/include/mscclpp/executor.hpp b/include/mscclpp/executor.hpp index e994548e4..5d76983e6 100644 --- a/include/mscclpp/executor.hpp +++ b/include/mscclpp/executor.hpp @@ -25,9 +25,15 @@ enum class PacketType { class ExecutionPlan { public: - ExecutionPlan(const std::string& name, const std::string& planPath); + ExecutionPlan(const std::string& planPath); ~ExecutionPlan() = default; + std::string name() const; + std::string collective() const; + size_t minMessageSize() const; + size_t maxMessageSize() const; + bool isInPlace() const; + private: struct Impl; std::shared_ptr impl_; diff --git a/python/mscclpp/executor_py.cpp b/python/mscclpp/executor_py.cpp index c550ecb00..d3add7194 100644 --- a/python/mscclpp/executor_py.cpp +++ b/python/mscclpp/executor_py.cpp @@ -22,7 +22,11 @@ void register_executor(nb::module_& m) { nb::enum_(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16); nb::class_(m, "ExecutionPlan") - .def(nb::init(), nb::arg("name"), nb::arg("planPath")); + .def(nb::init(), nb::arg("planPath")) + .def("name", &ExecutionPlan::name) + .def("collective", &ExecutionPlan::collective) + .def("min_message_size", &ExecutionPlan::minMessageSize) + .def("max_message_size", &ExecutionPlan::maxMessageSize); nb::class_(m, "Executor") .def(nb::init>(), nb::arg("comm")) diff --git a/python/test/executor_test.py b/python/test/executor_test.py index 5197b79b9..67e9929f1 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -59,7 +59,7 @@ def bench_time(n_iters: int, n_graph_iters: int, func): def bench_correctness( - execution_plan_name: str, + collective: str, input_buf: cp.ndarray, result_buf: cp.ndarray, test_buf: cp.ndarray, @@ -72,9 +72,9 @@ def bench_correctness( type_size = cp.dtype(parse_dtype(dtype_str)).itemsize fill_data_kernel_name = "fill_data_%s" % dtype_str - if "allgather" in execution_plan_name: + if "allgather" in collective: coll = "all_gather" - elif "reducescatter" in execution_plan_name: + elif "reducescatter" in collective: coll = "reduce_scatter" else: coll = "all_reduce" @@ -142,7 +142,7 @@ def allocate_buffer(nelems, dtype): def build_bufs( - execution_plan_name: str, + collective: str, size: int, in_place: bool, dtype: cp.dtype, @@ -153,7 +153,7 @@ def build_bufs( assert (size % type_size) == 0, "size %d not multiple of type size %d" % (size, type_size) nelems = size // type_size - if "allgather" in execution_plan_name: + if "allgather" in collective: assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks) nelems_input = nelems if in_place else nelems // num_ranks else: @@ -162,7 +162,7 @@ def build_bufs( result_buf = allocate_buffer(nelems_output, dtype=dtype) if in_place: - if "allgather" in execution_plan_name: + if "allgather" in collective: input_buf = cp.split(result_buf, num_ranks)[rank] else: input_buf = result_buf @@ -174,7 +174,6 @@ def build_bufs( def main( - execution_plan_name: str, execution_plan_path: str, size: int, in_place: bool = True, @@ -189,11 +188,12 @@ def main( npkit_dump_dir = os.getenv("NPKIT_DUMP_DIR") if npkit_dump_dir is not None: npkit.init(mscclpp_group.my_rank) - execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path) + execution_plan = ExecutionPlan(execution_plan_path) + collective = execution_plan.collective() dtype = parse_dtype(dtype_str) input_buf, result_buf, test_buf = build_bufs( - execution_plan_name, + collective, size, in_place, dtype, @@ -215,7 +215,7 @@ def main( mscclpp_group.barrier() bench_correctness( - execution_plan_name, + collective, input_buf, result_buf, test_buf, @@ -242,7 +242,6 @@ def main( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-n", "--execution_plan_name", type=str, required=True) parser.add_argument("-path", "--execution_plan_path", type=str, required=True) parser.add_argument("--size", type=str, required=True) parser.add_argument("--in_place", action="store_true", help="flag to define an in-place operation") @@ -258,7 +257,6 @@ def main( buffer_size = parse_size(args.size) main( - args.execution_plan_name, args.execution_plan_path, buffer_size, args.in_place, diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py index 1be0b1821..3e7fa90da 100644 --- a/python/test/test_mscclpp.py +++ b/python/test/test_mscclpp.py @@ -607,7 +607,7 @@ def test_executor(mpi_group: MpiGroup, filename: str): npkit_dump_dir = os.getenv("NPKIT_DUMP_DIR") if npkit_dump_dir is not None: npkit.init(mscclpp_group.my_rank) - execution_plan = ExecutionPlan("allreduce_pairs", os.path.join(project_dir, "test", "execution-files", filename)) + execution_plan = ExecutionPlan(os.path.join(project_dir, "test", "execution-files", filename)) nelems = 1024 * 1024 cp.random.seed(42) diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc index 20226b661..144fb4174 100644 --- a/src/executor/execution_plan.cc +++ b/src/executor/execution_plan.cc @@ -94,8 +94,15 @@ std::set groupChannelType{mscclpp::ChannelType::NVLS}; namespace mscclpp { using json = nlohmann::json; -ExecutionPlan::Impl::Impl(const std::string name, const std::string planPath) - : name(name), planPath(planPath), isUsingPacket(false) {} +ExecutionPlan::Impl::Impl(const std::string planPath) : planPath(planPath), isUsingPacket(false) { + std::ifstream file(this->planPath); + json obj = json::parse(file); + this->name = obj["name"]; + this->collective = obj["collective"]; + this->isInPlace = obj["inplace"]; + this->minMessageSize = obj.value("min_message_size", 0); + this->maxMessageSize = obj.value("max_message_size", std::numeric_limits::max()); +} std::vector ExecutionPlan::Impl::getChannelInfos(int rank, ChannelType channelType) const { auto pred = [channelType](const ChannelInfo& info) { return info.channelType == channelType; }; @@ -187,6 +194,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, if (this->name != obj["name"]) { throw Error("Plan name does not match", ErrorCode::ExecutorError); } + this->collective = obj["collective"]; std::string protocol = obj["protocol"]; if (protocol == "LL") { this->isUsingPacket = true; @@ -194,6 +202,9 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, this->inputSize = inputSize; this->outputSize = outputSize; this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024); + this->minMessageSize = obj.value("min_message_size", 0); + this->maxMessageSize = obj.value("max_message_size", std::numeric_limits::max()); + this->isInPlace = obj["inplace"]; const auto& gpus = obj["gpus"]; for (const auto& gpu : gpus) { @@ -549,7 +560,16 @@ void ExecutionPlan::Impl::reset() { void ExecutionPlan::Impl::operationsReset() { this->operations.clear(); } -ExecutionPlan::ExecutionPlan(const std::string& name, const std::string& planPath) - : impl_(std::make_shared(name, planPath)) {} +ExecutionPlan::ExecutionPlan(const std::string& planPath) : impl_(std::make_shared(planPath)) {} + +std::string ExecutionPlan::name() const { return this->impl_->name; } + +std::string ExecutionPlan::collective() const { return this->impl_->collective; } + +size_t ExecutionPlan::minMessageSize() const { return this->impl_->minMessageSize; } + +size_t ExecutionPlan::maxMessageSize() const { return this->impl_->maxMessageSize; } + +bool ExecutionPlan::isInPlace() const { return this->impl_->isInPlace; } } // namespace mscclpp diff --git a/src/include/execution_plan.hpp b/src/include/execution_plan.hpp index 07292d748..8d291f45f 100644 --- a/src/include/execution_plan.hpp +++ b/src/include/execution_plan.hpp @@ -62,7 +62,7 @@ struct ChannelInfo { struct ExecutionPlan::Impl { public: - Impl(const std::string name, const std::string planPath); + Impl(const std::string planPath); ~Impl() = default; std::vector getChannelInfos(int rank, ChannelType channelType) const; @@ -85,7 +85,8 @@ struct ExecutionPlan::Impl { void reset(); void operationsReset(); - const std::string name; + std::string name; + std::string collective; const std::string planPath; bool isUsingPacket; // operations for [rank][threadblock] = [operations] @@ -106,6 +107,9 @@ struct ExecutionPlan::Impl { size_t inputSize; size_t outputSize; int nThreadsPerBlock; + size_t minMessageSize; + size_t maxMessageSize; + bool isInPlace; private: std::pair calcSizePerRank(int rank, size_t inputSize, size_t outputSize) const; diff --git a/test/execution-files/allreduce.json b/test/execution-files/allreduce.json index afc921f4b..0a0a76590 100644 --- a/test/execution-files/allreduce.json +++ b/test/execution-files/allreduce.json @@ -1,6 +1,6 @@ { "name": "allreduce_pairs", - "colletive": "allreduce", + "collective": "allreduce", "protocol": "Simple", "inplace": true, "num_threads_per_block": 512, diff --git a/test/execution-files/allreduce_nvls.json b/test/execution-files/allreduce_nvls.json index 069b5df9d..e882d0285 100644 --- a/test/execution-files/allreduce_nvls.json +++ b/test/execution-files/allreduce_nvls.json @@ -1,6 +1,6 @@ { "name": "allreduce_nvls", - "colletive": "allreduce", + "collective": "allreduce", "protocol": "Simple", "inplace": true, "gpus": [ diff --git a/test/execution-files/allreduce_packet.json b/test/execution-files/allreduce_packet.json index d35a4e96b..545523447 100644 --- a/test/execution-files/allreduce_packet.json +++ b/test/execution-files/allreduce_packet.json @@ -1,6 +1,6 @@ { "name": "allreduce_pairs", - "colletive": "allreduce", + "collective": "allreduce", "protocol": "LL", "inplace": true, "num_threads_per_block": 768, diff --git a/test/execution-files/sendrecv.json b/test/execution-files/sendrecv.json index e84a06f7d..439d5f4f4 100644 --- a/test/execution-files/sendrecv.json +++ b/test/execution-files/sendrecv.json @@ -1,6 +1,6 @@ { "name": "send_recv", - "colletive": "sendrecv", + "collective": "sendrecv", "protocol": "Simple", "inplace": false, "gpus": [ diff --git a/test/execution-files/sendrecv_packet.json b/test/execution-files/sendrecv_packet.json index 3156b6191..5d63ad7c0 100644 --- a/test/execution-files/sendrecv_packet.json +++ b/test/execution-files/sendrecv_packet.json @@ -1,6 +1,6 @@ { "name": "send_recv", - "colletive": "sendrecv", + "collective": "sendrecv", "protocol": "LL", "inplace": false, "gpus": [ diff --git a/test/executor_test.cc b/test/executor_test.cc index e4ebcc972..68e8bfa32 100644 --- a/test/executor_test.cc +++ b/test/executor_test.cc @@ -89,9 +89,8 @@ double benchTime(int rank, std::shared_ptr bootstrap, std::s } int main(int argc, char* argv[]) { - if (argc != 6 && argc != 7) { + if (argc != 5 && argc != 6) { std::cerr << "Usage: " << argv[0] << " " - << " " << " " << " " << " " @@ -107,14 +106,13 @@ int main(int argc, char* argv[]) { MSCCLPP_CUDATHROW(cudaSetDevice(rank)); const size_t bufferSize = parseSize(argv[1]); - const std::string executionPlanName = argv[2]; - const std::string executionPlanPath = argv[3]; - const int niters = std::stoi(argv[4]); - const int ngraphIters = std::stoi(argv[5]); + const std::string executionPlanPath = argv[2]; + const int niters = std::stoi(argv[3]); + const int ngraphIters = std::stoi(argv[4]); const char* npkitDumpDir = getenv("NPKIT_DUMP_DIR"); mscclpp::PacketType packetType = mscclpp::PacketType::LL16; - if (argc == 7) { - packetType = parsePacketType(argv[6]); + if (argc == 6) { + packetType = parsePacketType(argv[5]); } std::shared_ptr bootstrap; @@ -130,7 +128,7 @@ int main(int argc, char* argv[]) { NpKit::Init(rank); } - mscclpp::ExecutionPlan plan(executionPlanName, executionPlanPath); + mscclpp::ExecutionPlan plan(executionPlanPath); std::shared_ptr sendbuff; if (mscclpp::isNvlsSupported()) { sendbuff = mscclpp::allocSharedPhysicalCuda(bufferSize); diff --git a/test/mp_unit/executor_tests.cc b/test/mp_unit/executor_tests.cc index 49952b6b4..116470dd1 100644 --- a/test/mp_unit/executor_tests.cc +++ b/test/mp_unit/executor_tests.cc @@ -55,7 +55,7 @@ TEST_F(ExecutorTest, TwoNodesAllreduce) { std::filesystem::path path = executablePath; std::filesystem::path executionFilesPath = path.parent_path().parent_path().parent_path() / "test/execution-files/allreduce.json"; - mscclpp::ExecutionPlan plan("allreduce_pairs", executionFilesPath.string()); + mscclpp::ExecutionPlan plan(executionFilesPath.string()); const int bufferSize = 1024 * 1024; std::shared_ptr sendbuff = mscclpp::allocExtSharedCuda(bufferSize); mscclpp::CudaStreamWithFlags stream(cudaStreamNonBlocking);