From b30bb260e31dd4b37cd05dca65ac89395a53f5c4 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 18 Sep 2024 17:21:47 -0700 Subject: [PATCH 01/12] Tune threads per block for mscclpp executor (#345) --- apps/nccl/src/nccl.cu | 15 +++++++------- include/mscclpp/executor.hpp | 2 +- python/mscclpp/executor_py.cpp | 7 +++---- python/test/executor_test.py | 17 +++++----------- python/test/test_mscclpp.py | 1 - src/executor/execution_plan.cc | 3 +++ src/executor/executor.cc | 15 +++++++------- src/include/execution_plan.hpp | 2 ++ test/execution-files/allreduce.json | 1 + test/execution-files/allreduce_packet.json | 1 + test/executor_test.cc | 23 ++++++++++------------ test/mp_unit/executor_tests.cc | 2 +- 12 files changed, 43 insertions(+), 46 deletions(-) diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index cb0e7d56e..ec130b0b1 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -488,29 +488,30 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t std::shared_ptr plan; if (bytes <= comm->largeMessageSizeBoundary) plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan; - else + else { plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan; + } if (plan == nullptr) return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream); switch (datatype) { case ncclFloat16: - comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, 1024, - *plan, stream, mscclpp::PacketType::LL8); + comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan, + stream, mscclpp::PacketType::LL8); break; case ncclFloat32: comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32, - 1024, *plan, stream, mscclpp::PacketType::LL8); + *plan, stream, mscclpp::PacketType::LL8); break; case ncclBfloat16: comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes, - mscclpp::DataType::BFLOAT16, 1024, *plan, stream, mscclpp::PacketType::LL8); + mscclpp::DataType::BFLOAT16, *plan, stream, mscclpp::PacketType::LL8); break; case ncclInt32: case ncclUint32: - comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, 1024, - *plan, stream, mscclpp::PacketType::LL8); + comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan, + stream, mscclpp::PacketType::LL8); break; default: return ncclInvalidArgument; diff --git a/include/mscclpp/executor.hpp b/include/mscclpp/executor.hpp index 52c3c6da9..e994548e4 100644 --- a/include/mscclpp/executor.hpp +++ b/include/mscclpp/executor.hpp @@ -43,7 +43,7 @@ class Executor { ~Executor(); void execute(int rank, void* sendbuff, void* recvBuff, size_t sendBuffSize, size_t recvBuffSize, DataType dataType, - int nthreads, const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType = PacketType::LL16); + const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType = PacketType::LL16); private: struct Impl; diff --git a/python/mscclpp/executor_py.cpp b/python/mscclpp/executor_py.cpp index dadbf40f6..c550ecb00 100644 --- a/python/mscclpp/executor_py.cpp +++ b/python/mscclpp/executor_py.cpp @@ -29,11 +29,10 @@ void register_executor(nb::module_& m) { .def( "execute", [](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize, - DataType dataType, int nthreads, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) { + DataType dataType, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) { self->execute(rank, reinterpret_cast(sendbuff), reinterpret_cast(recvBuff), sendBuffSize, - recvBuffSize, dataType, nthreads, plan, (cudaStream_t)stream, packetType); + recvBuffSize, dataType, plan, (cudaStream_t)stream, packetType); }, nb::arg("rank"), nb::arg("sendbuff"), nb::arg("recvBuff"), nb::arg("sendBuffSize"), nb::arg("recvBuffSize"), - nb::arg("dataType"), nb::arg("nthreads"), nb::arg("plan"), nb::arg("stream"), - nb::arg("packetType") = PacketType::LL16); + nb::arg("dataType"), nb::arg("plan"), nb::arg("stream"), nb::arg("packetType") = PacketType::LL16); } diff --git a/python/test/executor_test.py b/python/test/executor_test.py index 23c3ff483..53c11eb19 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -81,10 +81,9 @@ def main( execution_paln_name: str, execution_plan_path: str, size: int, - nthreads_per_block: int, dtype: cp.dtype = cp.float16, packet_type: PacketType = PacketType.LL16, - seed: int = 42, + seed: int = 42 + MPI.COMM_WORLD.rank, ): mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD) cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use() @@ -96,12 +95,9 @@ def main( cp.random.seed(seed) nelems = size // cp.dtype(dtype).itemsize - buffer = cp.random.random(nelems * mscclpp_group.nranks).astype(dtype) - sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size) - sendbuf = sub_arrays[MPI.COMM_WORLD.rank] - expected = cp.zeros_like(sendbuf) - for i in range(mscclpp_group.nranks): - expected += sub_arrays[i] + sendbuf = cp.random.random(nelems).astype(dtype) + expected = cp.asnumpy(sendbuf) + expected = MPI.COMM_WORLD.allreduce(expected, op=MPI.SUM) mscclpp_group.barrier() executor_func = lambda stream: executor.execute( @@ -111,7 +107,6 @@ def main( sendbuf.nbytes, sendbuf.nbytes, dtype_to_mscclpp_dtype(dtype), - nthreads_per_block, execution_plan, stream.ptr, packet_type, @@ -130,7 +125,7 @@ def main( print( f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, " f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} " - f"packet type: {packet_type} nthreads_per_block: {nthreads_per_block}" + f"packet type: {packet_type}" ) executor = None mscclpp_group = None @@ -141,7 +136,6 @@ def main( parser.add_argument("-n", "--execution_plan_name", type=str, required=True) parser.add_argument("-path", "--execution_plan_path", type=str, required=True) parser.add_argument("--size", type=str, required=True) - parser.add_argument("--nthreads_per_block", type=int, required=True) parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32") parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16") parser.add_argument("--seed", type=int, default=42) @@ -157,7 +151,6 @@ def main( args.execution_plan_name, args.execution_plan_path, buffer_size, - args.nthreads_per_block, dtype, packet_type, args.seed, diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py index 9535c869f..1be0b1821 100644 --- a/python/test/test_mscclpp.py +++ b/python/test/test_mscclpp.py @@ -630,7 +630,6 @@ def test_executor(mpi_group: MpiGroup, filename: str): sendbuf.nbytes, sendbuf.nbytes, DataType.float16, - 512, execution_plan, stream.ptr, ) diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc index 8ebf510e2..bda49791a 100644 --- a/src/executor/execution_plan.cc +++ b/src/executor/execution_plan.cc @@ -161,6 +161,8 @@ std::vector ExecutionPlan::Impl::getOperations(int rank, int threadbl int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->operations.at(rank).size(); } +int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; } + void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) { std::ifstream file(this->planPath); json obj = json::parse(file); @@ -171,6 +173,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOff if (protocol == "LL") { this->isUsingPacket = true; } + this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024); const auto& gpus = obj["gpus"]; for (const auto& gpu : gpus) { diff --git a/src/executor/executor.cc b/src/executor/executor.cc index f932316a2..f0da0e979 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -65,6 +65,7 @@ struct ExecutionContext { std::shared_ptr scratchBuffer; size_t scratchBufferSize; std::shared_ptr deviceExecutionPlansBuffer; + int nthreadsPerBlock; }; struct Executor::Impl { @@ -104,6 +105,7 @@ struct Executor::Impl { context.scratchBuffer = scratchBuffer; context.scratchBufferSize = scratchBufferSize; context.proxyService = std::make_shared(); + context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock(); this->setupConnections(context, rank, plan); this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, rank, plan); @@ -295,8 +297,8 @@ struct Executor::Impl { context.deviceExecutionPlans = std::move(deviceExecutionPlans); } - void launchKernel(ExecutionContext& context, int rank, int nthreadsPerBlock, void* sendbuff, void* recvbuff, - DataType dataType, cudaStream_t stream, PacketType packetType) { + void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType, + cudaStream_t stream, PacketType packetType) { static uint32_t flag = 0; int nthreadblocks = context.deviceExecutionPlans.size(); #if defined(ENABLE_NPKIT) @@ -315,13 +317,13 @@ struct Executor::Impl { switch (packetType) { case PacketType::LL16: ExecutionKernel::launchKernel( - rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), + rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), sharedMemSize, stream, ++flag); break; case PacketType::LL8: ExecutionKernel::launchKernel( - rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), + rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), sharedMemSize, stream, ++flag); break; @@ -334,7 +336,7 @@ struct Executor::Impl { Executor::Executor(std::shared_ptr comm) : impl_(std::make_unique(comm)) {} void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize, - [[maybe_unused]] size_t recvBuffSize, DataType dataType, int nthreads, const ExecutionPlan& plan, + [[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType) { size_t sendBytes, recvBytes; CUdeviceptr sendBasePtr, recvBasePtr; @@ -345,8 +347,7 @@ void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuff ExecutionContext context = this->impl_->setupExecutionContext( rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, offsetIn, offsetOut, sendBytes, recvBytes, plan); - // TODO(binyli): need to flush proxy channel here - this->impl_->launchKernel(context, rank, nthreads, sendbuff, recvbuff, dataType, stream, packetType); + this->impl_->launchKernel(context, rank, sendbuff, recvbuff, dataType, stream, packetType); } Executor::~Executor() = default; diff --git a/src/include/execution_plan.hpp b/src/include/execution_plan.hpp index 79ab4af0b..ce6ec8ca2 100644 --- a/src/include/execution_plan.hpp +++ b/src/include/execution_plan.hpp @@ -68,6 +68,7 @@ struct ExecutionPlan::Impl { size_t getScratchBufferSize(int rank, size_t inputSize) const; std::vector getOperations(int rank, int threadblock) const; int getThreadblockCount(int rank) const; + int getNThreadsPerBlock() const; void loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset); void lightLoadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset); @@ -93,6 +94,7 @@ struct ExecutionPlan::Impl { std::unordered_map scratchChunks; std::unordered_map chunkGroups; size_t inputSize; + int nThreadsPerBlock; private: size_t getOffset(int rank, size_t inputSize, uint32_t chunkIndex, uint32_t alignment = 16) const; diff --git a/test/execution-files/allreduce.json b/test/execution-files/allreduce.json index eb7e41940..afc921f4b 100644 --- a/test/execution-files/allreduce.json +++ b/test/execution-files/allreduce.json @@ -3,6 +3,7 @@ "colletive": "allreduce", "protocol": "Simple", "inplace": true, + "num_threads_per_block": 512, "gpus": [ { "id": 0, diff --git a/test/execution-files/allreduce_packet.json b/test/execution-files/allreduce_packet.json index b0df82c91..d35a4e96b 100644 --- a/test/execution-files/allreduce_packet.json +++ b/test/execution-files/allreduce_packet.json @@ -3,6 +3,7 @@ "colletive": "allreduce", "protocol": "LL", "inplace": true, + "num_threads_per_block": 768, "gpus": [ { "id": 0, diff --git a/test/executor_test.cc b/test/executor_test.cc index 33a795045..f85c38454 100644 --- a/test/executor_test.cc +++ b/test/executor_test.cc @@ -56,16 +56,16 @@ mscclpp::PacketType parsePacketType(const char* value) { } double benchTime(int rank, std::shared_ptr bootstrap, std::shared_ptr executor, - const mscclpp::ExecutionPlan& plan, std::shared_ptr sendbuff, size_t bufferSize, - int nthreadsPerBlock, int niters, int ngrapthIters, mscclpp::PacketType packetType) { + const mscclpp::ExecutionPlan& plan, std::shared_ptr sendbuff, size_t bufferSize, int niters, + int ngrapthIters, mscclpp::PacketType packetType) { mscclpp::CudaStreamWithFlags stream(cudaStreamNonBlocking); cudaGraph_t graph; cudaGraphExec_t graphExec; mscclpp::Timer timer; MSCCLPP_CUDATHROW(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal)); for (int i = 0; i < niters; i++) { - executor->execute(rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16, - nthreadsPerBlock, plan, stream, packetType); + executor->execute(rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16, plan, + stream, packetType); } MSCCLPP_CUDATHROW(cudaStreamEndCapture(stream, &graph)); MSCCLPP_CUDATHROW(cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0)); @@ -86,11 +86,10 @@ double benchTime(int rank, std::shared_ptr bootstrap, std::s } int main(int argc, char* argv[]) { - if (argc != 7 && argc != 8) { + if (argc != 6 && argc != 7) { std::cerr << "Usage: " << argv[0] << " " << " " << " " - << " " << " " << " " << " (optional) " << std::endl; @@ -107,13 +106,12 @@ int main(int argc, char* argv[]) { const size_t bufferSize = parseSize(argv[1]); const std::string executionPlanName = argv[2]; const std::string executionPlanPath = argv[3]; - const int nthreadsPerBlock = std::stoi(argv[4]); - const int niters = std::stoi(argv[5]); - const int ngraphIters = std::stoi(argv[6]); + const int niters = std::stoi(argv[4]); + const int ngraphIters = std::stoi(argv[5]); const char* npkitDumpDir = getenv("NPKIT_DUMP_DIR"); mscclpp::PacketType packetType = mscclpp::PacketType::LL16; - if (argc == 8) { - packetType = parsePacketType(argv[7]); + if (argc == 7) { + packetType = parsePacketType(argv[6]); } std::shared_ptr bootstrap; @@ -133,8 +131,7 @@ int main(int argc, char* argv[]) { std::shared_ptr sendbuff = mscclpp::allocExtSharedCuda(bufferSize); std::vector dataHost(bufferSize / sizeof(int), rank); MSCCLPP_CUDATHROW(cudaMemcpy(sendbuff.get(), dataHost.data(), bufferSize, cudaMemcpyHostToDevice)); - double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, nthreadsPerBlock, niters, - ngraphIters, packetType); + double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, niters, ngraphIters, packetType); if (npkitDumpDir != nullptr) { NpKit::Dump(npkitDumpDir); diff --git a/test/mp_unit/executor_tests.cc b/test/mp_unit/executor_tests.cc index 5baa2b67a..49952b6b4 100644 --- a/test/mp_unit/executor_tests.cc +++ b/test/mp_unit/executor_tests.cc @@ -59,7 +59,7 @@ TEST_F(ExecutorTest, TwoNodesAllreduce) { const int bufferSize = 1024 * 1024; std::shared_ptr sendbuff = mscclpp::allocExtSharedCuda(bufferSize); mscclpp::CudaStreamWithFlags stream(cudaStreamNonBlocking); - executor->execute(gEnv->rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16, 512, + executor->execute(gEnv->rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16, plan, stream); MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); } From 5c4e105814848f4cb66c56b3358a2720865158f2 Mon Sep 17 00:00:00 2001 From: Ziyue Yang Date: Thu, 19 Sep 2024 13:35:44 +0800 Subject: [PATCH 02/12] Fix NPKit exit event offset (#356) --- include/mscclpp/npkit/npkit_event.hpp | 2 +- src/include/execution_common.hpp | 4 ++-- tools/npkit/npkit_trace_generator.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/mscclpp/npkit/npkit_event.hpp b/include/mscclpp/npkit/npkit_event.hpp index da0206c0f..1a24b241f 100644 --- a/include/mscclpp/npkit/npkit_event.hpp +++ b/include/mscclpp/npkit/npkit_event.hpp @@ -13,6 +13,6 @@ #define NPKIT_EVENT_EXECUTOR_INIT_EXIT 0x4 #define NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY 0x5 -#define NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT 0x15 +#define NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT 0x17 #endif diff --git a/src/include/execution_common.hpp b/src/include/execution_common.hpp index 87d07f0b4..99bf36a4f 100644 --- a/src/include/execution_common.hpp +++ b/src/include/execution_common.hpp @@ -31,6 +31,8 @@ enum class OperationType : uint8_t { BARRIER, PUT, PUT_PACKET, + PUT_WITH_SIGNAL, + PUT_WITH_SIGNAL_AND_FLUSH, GET, COPY, COPY_PACKET, @@ -44,8 +46,6 @@ enum class OperationType : uint8_t { REDUCE_SEND_PACKET, READ_REDUCE_COPY, READ_REDUCE_COPY_SEND, - PUT_WITH_SIGNAL, - PUT_WITH_SIGNAL_AND_FLUSH, }; struct Channels { diff --git a/tools/npkit/npkit_trace_generator.py b/tools/npkit/npkit_trace_generator.py index 96224b35e..9a5b88b44 100644 --- a/tools/npkit/npkit_trace_generator.py +++ b/tools/npkit/npkit_trace_generator.py @@ -14,6 +14,8 @@ def parse_npkit_event_header(npkit_event_header_path): "BARRIER", "PUT", "PUT_PACKET", + "PUT_WITH_SIGNAL", + "PUT_WITH_SIGNAL_AND_FLUSH", "GET", "COPY", "COPY_PACKET", @@ -27,8 +29,6 @@ def parse_npkit_event_header(npkit_event_header_path): "REDUCE_SEND_PACKET", "READ_REDUCE_COPY", "READ_REDUCE_COPY_SEND", - "PUT_WITH_SIGNAL", - "PUT_WITH_SIGNAL_AND_FLUSH", ] executor_op_to_offset = {} for executor_op in executor_ops: From 74130c7c5eb9e1e95ac86cf26bb610ece2ec85a8 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 19 Sep 2024 00:13:11 -0700 Subject: [PATCH 03/12] Use IB transport flags only when an IB device exists (#355) --- test/mscclpp-test/common.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/mscclpp-test/common.cc b/test/mscclpp-test/common.cc index 9c52f9f4a..899823f7d 100644 --- a/test/mscclpp-test/common.cc +++ b/test/mscclpp-test/common.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -399,7 +400,8 @@ void BaseTestEngine::setupMeshConnectionsInternal( void BaseTestEngine::setupMeshConnections(std::vector>& proxyChannels, void* inputBuff, size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes, SetupChannelFunc setupChannel) { - const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum]; + mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc; + if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum]; mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports); mscclpp::RegisteredMemory outputBufRegMem; if (outputBuff) { @@ -429,7 +431,8 @@ void BaseTestEngine::setupMeshConnections(std::vector& smChannels, void* inputBuff, size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes, ChannelSemantic semantic, size_t nChannelPerConnection) { - const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum]; + mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc; + if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum]; mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports); mscclpp::RegisteredMemory getPacketBufRegMem; mscclpp::RegisteredMemory outputBufRegMem; @@ -469,7 +472,8 @@ void BaseTestEngine::setupMeshConnections(std::vector& smCha void* inputBuff, size_t inputBuffBytes, void* putPacketBuff, size_t putPacketBuffBytes, void* getPacketBuff, size_t getPacketBuffBytes, void* outputBuff, size_t outputBuffBytes) { - const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum]; + mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc; + if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum]; mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports); mscclpp::RegisteredMemory putPacketBufRegMem; mscclpp::RegisteredMemory getPacketBufRegMem; From 8a330f9135a527c170b6f8006151540fb44669e4 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Fri, 20 Sep 2024 10:57:02 -0700 Subject: [PATCH 04/12] Update ROCm CI (#357) Co-authored-by: Binyang Li --- .azure-pipelines/integration-test-rocm.yml | 1 - .github/workflows/codeql-analysis.yml | 58 +++++++++++++++++++--- README.md | 3 +- 3 files changed, 53 insertions(+), 9 deletions(-) diff --git a/.azure-pipelines/integration-test-rocm.yml b/.azure-pipelines/integration-test-rocm.yml index c098ab085..029259ffb 100644 --- a/.azure-pipelines/integration-test-rocm.yml +++ b/.azure-pipelines/integration-test-rocm.yml @@ -64,7 +64,6 @@ jobs: set -e git clone https://$(GIT_USER):$(GIT_PAT)@msazure.visualstudio.com/DefaultCollection/One/_git/azure-mscclpp cd azure-mscclpp - git checkout binyli/ci mkdir execution-files python3 algos/allreduce_mi300_packet.py 8 8 > execution-files/allreduce_mi300_packet.json python3 algos/allreduce_mi300_sm_mscclpp.py 8 8 > execution-files/allreduce_mi300_sm_mscclpp.json diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 7295171e9..73496445d 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -9,11 +9,11 @@ on: - cron: "30 1 * * 1" jobs: - analyze: - name: Analyze + analyze-cuda: + name: Analyze (CUDA) runs-on: 'ubuntu-latest' container: - image: ghcr.io/microsoft/mscclpp/mscclpp:base-dev-${{ matrix.cuda-version }} + image: ghcr.io/microsoft/mscclpp/mscclpp:base-dev-${{ matrix.version }} permissions: actions: read @@ -24,7 +24,7 @@ jobs: fail-fast: false matrix: language: [ 'cpp', 'python' ] - cuda-version: [ 'cuda11.8', 'cuda12.2' ] + version: [ 'cuda11.8', 'cuda12.2' ] steps: - name: Checkout repository @@ -45,10 +45,56 @@ jobs: - name: Build run: | - cmake -DBYPASS_GPU_CHECK=ON -DUSE_CUDA=ON . + rm -rf build && mkdir build && cd build + cmake -DBYPASS_GPU_CHECK=ON -DUSE_CUDA=ON .. make -j - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v2 with: - category: "/language:${{matrix.language}}/cuda-version:${{matrix.cuda-version}}" + category: "/language:${{matrix.language}}/version:${{matrix.version}}" + + analyze-rocm: + name: Analyze (ROCm) + runs-on: 'ubuntu-latest' + container: + image: ghcr.io/microsoft/mscclpp/mscclpp:base-dev-${{ matrix.version }} + + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'cpp', 'python' ] + version: [ 'rocm6.2' ] + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Check disk space + run: | + df -h + + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + + - name: Dubious ownership exception + run: | + git config --global --add safe.directory /__w/mscclpp/mscclpp + + - name: Build + run: | + rm -rf build && mkdir build && cd build + CXX=/opt/rocm/bin/hipcc cmake -DBYPASS_GPU_CHECK=ON -DUSE_ROCM=ON .. + make -j + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{matrix.language}}/version:${{matrix.version}}" diff --git a/README.md b/README.md index 9796179d3..cd8e80790 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,7 @@ |--------------------------|-------------------| | Unit Tests (CUDA) | [![Build Status](https://dev.azure.com/binyli/HPC/_apis/build/status%2Fmscclpp-ut?branchName=main)](https://dev.azure.com/binyli/HPC/_build/latest?definitionId=4&branchName=main) | | Integration Tests (CUDA) | [![Build Status](https://dev.azure.com/binyli/HPC/_apis/build/status%2Fmscclpp-test?branchName=main)](https://dev.azure.com/binyli/HPC/_build/latest?definitionId=3&branchName=main) | - -*NOTE (Nov 2023): Azure pipelines for ROCm will be added soon.* +| Integration Tests (ROCm) | [![Build Status](https://dev.azure.com/binyli/HPC/_apis/build/status%2Fmscclpp-test-rocm?branchName=main)](https://dev.azure.com/binyli/HPC/_build/latest?definitionId=7&branchName=main) | A GPU-driven communication stack for scalable AI applications. From 08a0cec2eb0d3d2c2abde2b8cb599a9180ccc72c Mon Sep 17 00:00:00 2001 From: Caio Rocha <164253795+caiomcbr@users.noreply.github.com> Date: Tue, 24 Sep 2024 23:01:41 -0700 Subject: [PATCH 05/12] Fixing RegisterMemory Allocation for ProxyChannels (#353) Co-authored-by: Binyang Li Co-authored-by: Changho Hwang --- python/test/executor_test.py | 34 +++++++++++++----- src/connection.cc | 17 +++++---- src/executor/executor.cc | 67 +++++++++++++++++++----------------- 3 files changed, 71 insertions(+), 47 deletions(-) diff --git a/python/test/executor_test.py b/python/test/executor_test.py index 53c11eb19..3e0c369d3 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -78,12 +78,13 @@ def dtype_to_mscclpp_dtype(dtype): def main( - execution_paln_name: str, + execution_plan_name: str, execution_plan_path: str, size: int, + in_place: bool = True, dtype: cp.dtype = cp.float16, packet_type: PacketType = PacketType.LL16, - seed: int = 42 + MPI.COMM_WORLD.rank, + seed: int = 42, ): mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD) cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use() @@ -91,21 +92,33 @@ def main( npkit_dump_dir = os.getenv("NPKIT_DUMP_DIR") if npkit_dump_dir is not None: npkit.init(mscclpp_group.my_rank) - execution_plan = ExecutionPlan(execution_paln_name, execution_plan_path) + execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path) cp.random.seed(seed) nelems = size // cp.dtype(dtype).itemsize - sendbuf = cp.random.random(nelems).astype(dtype) - expected = cp.asnumpy(sendbuf) - expected = MPI.COMM_WORLD.allreduce(expected, op=MPI.SUM) + buffer = cp.random.random(nelems * mscclpp_group.nranks, dtype=cp.float32).astype(dtype) + sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size) + sendbuf = cp.zeros(nelems, dtype=dtype) + for i in range(nelems): + sendbuf[i] = sub_arrays[MPI.COMM_WORLD.rank][i] + + if "allgather" in execution_plan_name: + recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype) + expected = buffer + else: + cp.random.seed(seed) + recvbuf = cp.zeros(nelems, dtype=dtype) + expected = cp.zeros_like(sendbuf, dtype=dtype) + for i in range(mscclpp_group.nranks): + expected += sub_arrays[i] mscclpp_group.barrier() executor_func = lambda stream: executor.execute( MPI.COMM_WORLD.rank, sendbuf.data.ptr, - sendbuf.data.ptr, - sendbuf.nbytes, + sendbuf.data.ptr if in_place else recvbuf.data.ptr, sendbuf.nbytes, + sendbuf.nbytes if in_place else recvbuf.nbytes, dtype_to_mscclpp_dtype(dtype), execution_plan, stream.ptr, @@ -115,7 +128,8 @@ def main( stream = cp.cuda.Stream(non_blocking=True) executor_func(stream) stream.synchronize() - assert cp.allclose(sendbuf, expected, atol=1e-2 * mscclpp_group.nranks) + + assert cp.allclose(sendbuf if in_place else recvbuf, expected, atol=1e-2 * mscclpp_group.nranks) mscclpp_group.barrier() execution_time = bench_time(100, 10, executor_func) @@ -136,6 +150,7 @@ def main( parser.add_argument("-n", "--execution_plan_name", type=str, required=True) parser.add_argument("-path", "--execution_plan_path", type=str, required=True) parser.add_argument("--size", type=str, required=True) + parser.add_argument("--in_place", action="store_true", help="flag to define an in-place operation") parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32") parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16") parser.add_argument("--seed", type=int, default=42) @@ -151,6 +166,7 @@ def main( args.execution_plan_name, args.execution_plan_path, buffer_size, + args.in_place, dtype, packet_type, args.seed, diff --git a/src/connection.cc b/src/connection.cc index 57e77b40b..79c4c9630 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -16,10 +16,13 @@ namespace mscclpp { -void validateTransport(RegisteredMemory mem, Transport transport) { +void validateTransport(RegisteredMemory mem, Transport transport, uint64_t offset = 0, uint64_t size = 0) { if (!mem.transports().has(transport)) { throw Error("RegisteredMemory does not support this transport", ErrorCode::InvalidUsage); } + if (offset + size > mem.size()) { + throw Error("RegisteredMemory out of bounds", ErrorCode::InvalidUsage); + } } // Connection @@ -59,8 +62,8 @@ Transport CudaIpcConnection::remoteTransport() { return Transport::CudaIpc; } void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { - validateTransport(dst, remoteTransport()); - validateTransport(src, transport()); + validateTransport(dst, remoteTransport(), dstOffset, size); + validateTransport(src, transport(), srcOffset, size); char* dstPtr = (char*)dst.data(); char* srcPtr = (char*)src.data(); @@ -115,8 +118,8 @@ Transport IBConnection::remoteTransport() { return remoteTransport_; } void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { - validateTransport(dst, remoteTransport()); - validateTransport(src, transport()); + validateTransport(dst, remoteTransport(), dstOffset, size); + validateTransport(src, transport(), srcOffset, size); auto dstTransportInfo = getImpl(dst)->getTransportInfo(remoteTransport()); if (dstTransportInfo.ibLocal) { @@ -231,8 +234,8 @@ Transport EthernetConnection::remoteTransport() { return Transport::Ethernet; } void EthernetConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { // Validating Transport Protocol - validateTransport(dst, remoteTransport()); - validateTransport(src, transport()); + validateTransport(dst, remoteTransport(), dstOffset, size); + validateTransport(src, transport(), srcOffset, size); // Initializing Variables char* srcPtr = reinterpret_cast(src.data()) + srcOffset / sizeof(char); diff --git a/src/executor/executor.cc b/src/executor/executor.cc index f0da0e979..54986d5d6 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -108,7 +108,7 @@ struct Executor::Impl { context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock(); this->setupConnections(context, rank, plan); this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); - this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, rank, plan); + this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); this->setupDeviceExecutionPlan(context, rank, plan); context.deviceExecutionPlansBuffer = allocExtSharedCuda(context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan)); @@ -119,6 +119,23 @@ struct Executor::Impl { return context; } + TransportFlags getTransportFlags(std::vector& infos, int rank) { + TransportFlags flags; + for (ChannelInfo& info : infos) { + if (info.channelType == ChannelType::SM) { + flags |= Transport::CudaIpc; + } else if (info.channelType == ChannelType::PROXY) { + for (int peer : info.connectedPeers) { + if (!inSameNode(rank, peer, this->nranksPerNode)) { + flags |= IBs[rank % this->nranksPerNode]; + } else + flags |= Transport::CudaIpc; + } + } + } + return flags; + }; + void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan) { std::vector connectedPeers = plan.impl_->getConnectedPeers(rank); std::vector>> connectionFutures; @@ -135,22 +152,6 @@ struct Executor::Impl { void setupRegisteredMemories(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize, size_t recvBufferSize, int rank, const ExecutionPlan& plan) { - auto getTransportFlags = [&](std::vector& infos, int rank) { - TransportFlags flags; - for (ChannelInfo& info : infos) { - if (info.channelType == ChannelType::SM) { - flags |= Transport::CudaIpc; - } else if (info.channelType == ChannelType::PROXY) { - for (int peer : info.connectedPeers) { - if (!inSameNode(rank, peer, this->nranksPerNode)) { - flags |= IBs[rank % this->nranksPerNode]; - } else - flags |= Transport::CudaIpc; - } - } - } - return flags; - }; auto getBufferInfo = [&](BufferType type) { switch (type) { case BufferType::INPUT: @@ -192,22 +193,12 @@ struct Executor::Impl { comm->setup(); for (size_t i = 0; i < remoteRegMemoryFutures.size(); i++) { context.registeredMemories[{bufferType, connectedPeers[i]}] = std::move(remoteRegMemoryFutures[i].get()); - CUdeviceptr myRegBaseAdr, peerRegBaseAdr; - size_t temp; - MSCCLPP_CUTHROW(cuMemGetAddressRange(&myRegBaseAdr, &temp, (CUdeviceptr)(char*)memory.data())); - MSCCLPP_CUTHROW(cuMemGetAddressRange( - &peerRegBaseAdr, &temp, - (CUdeviceptr)(char*)context.registeredMemories[{bufferType, connectedPeers[i]}].data())); - size_t myRegOffset = (char*)memory.data() - (char*)myRegBaseAdr; - size_t peerRegOffset = - (char*)context.registeredMemories[{bufferType, connectedPeers[i]}].data() - (char*)peerRegBaseAdr; - if (myRegOffset != peerRegOffset) throw Error("Divergent data offset between peers", ErrorCode::ExecutorError); } } } - void setupChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize, int rank, - const ExecutionPlan& plan) { + void setupChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize, + size_t recvBufferSize, int rank, const ExecutionPlan& plan) { const auto channelTypes = {ChannelType::SM, ChannelType::PROXY}; std::vector> smSemaphores; std::vector proxySemaphores; @@ -251,13 +242,27 @@ struct Executor::Impl { throw Error("Invalid buffer type", ErrorCode::ExecutorError); } }; + auto getBufferSize = [&](BufferType type) { + switch (type) { + case BufferType::INPUT: + return sendBufferSize; + case BufferType::OUTPUT: + return recvBufferSize; + case BufferType::SCRATCH: + return context.scratchBufferSize; + default: + throw Error("Invalid buffer type", ErrorCode::ExecutorError); + } + }; + for (ChannelType channelType : channelTypes) { std::vector channelInfos = plan.impl_->getChannelInfos(rank, channelType); int index = 0; for (ChannelInfo& info : channelInfos) { void* src = getBuffer(info.srcBufferType); - TransportFlags transport = context.registeredMemories.begin()->second.transports(); - RegisteredMemory localMemory = this->comm->registerMemory(src, sendBufferSize, transport); + size_t bufferSize = getBufferSize(info.srcBufferType); + TransportFlags transport = getTransportFlags(channelInfos, rank); + RegisteredMemory localMemory = this->comm->registerMemory(src, bufferSize, transport); for (int peer : info.connectedPeers) { if (channelType == ChannelType::SM) { context.smChannels.emplace_back(context.smSemaphores[index++], From e9294357c53424f6f4881d580bd698025e7850a8 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 16 Oct 2024 14:16:34 -0700 Subject: [PATCH 06/12] Fix NCCL API bugs (#363) --- apps/nccl/src/allgather.hpp | 7 +++++++ apps/nccl/src/allreduce.hpp | 9 +++++---- apps/nccl/src/common.hpp | 4 ++++ apps/nccl/src/nccl.cu | 9 +++++---- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/apps/nccl/src/allgather.hpp b/apps/nccl/src/allgather.hpp index 35c2b60c4..b4ee696e4 100644 --- a/apps/nccl/src/allgather.hpp +++ b/apps/nccl/src/allgather.hpp @@ -102,6 +102,13 @@ __global__ void __launch_bounds__(1024, 1) } } } + + deviceSyncer.sync(gridDim.x); + + if (threadIdx.x < nPeer) { + smChans[threadIdx.x].relaxedSignal(); + smChans[threadIdx.x].wait(); + } } template diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp index 1b85136ae..08cd46aca 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp @@ -14,8 +14,6 @@ #include "common.hpp" #include "gpu_data_types.hpp" -__device__ mscclpp::DeviceSyncer deviceSyncer; - template __forceinline__ __device__ To bit_cast(const From& src) { static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast"); @@ -236,14 +234,13 @@ __global__ void __launch_bounds__(1024, 1) blockDim.x * nBlocksPerPeer, flag); // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) { - uint32_t data = 0; + uint32_t data = src[idx]; for (int index = 0; index < nPeers; index++) { const int remoteRank = index < rank ? index : index + 1; mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nPktsPerRank; uint32_t val = dstPkt[idx].read(flag, -1); data = add_vectors(val, data); } - data = add_vectors(data, src[idx]); dst[idx] = data; mscclpp::LL8Packet packet; @@ -384,6 +381,10 @@ __global__ void __launch_bounds__(512, 1) } } } + if (threadIdx.x < static_cast(nPeer)) { + outChannels[threadIdx.x].signal(); + outChannels[threadIdx.x].wait(); + } } template diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp index 25c74e71b..bf512106c 100644 --- a/apps/nccl/src/common.hpp +++ b/apps/nccl/src/common.hpp @@ -4,6 +4,8 @@ #ifndef NCCL_COMMON_HPP_ #define NCCL_COMMON_HPP_ +#include + #if defined(__HIP_PLATFORM_AMD__) #define WARP_SIZE 64 #define __syncwarp() __builtin_amdgcn_wave_barrier() @@ -14,4 +16,6 @@ constexpr int NRANKS_PER_NODE = 8; constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB +__device__ mscclpp::DeviceSyncer deviceSyncer; + #endif // NCCL_COMMON_HPP_ diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index ec130b0b1..a414ffe85 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -293,9 +293,10 @@ NCCL_API ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId) { return ncclSuccess; } -NCCL_API ncclResult_t ncclCommInitRankConfig(ncclComm_t*, int, ncclUniqueId, int, ncclConfig_t*) { - // TODO: implement this function - return ncclInternalError; +NCCL_API ncclResult_t ncclCommInitRankConfig(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank, + ncclConfig_t*) { + // TODO: implement config + return ncclCommInitRank(comm, nranks, commId, rank); } NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) { @@ -419,7 +420,7 @@ NCCL_API const char* ncclGetErrorString(ncclResult_t result) { NCCL_API const char* ncclGetLastError(ncclComm_t) { // TODO: implement this function - return nullptr; + return ""; } NCCL_API ncclResult_t ncclCommGetAsyncError(ncclComm_t, ncclResult_t* asyncError) { From f8c0bcca2b428f75f3911bf53635241d89be4c46 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 16 Oct 2024 14:35:08 -0700 Subject: [PATCH 07/12] Perf optimization & support clipping (#364) Co-authored-by: Nusrat Islam --- apps/nccl/src/allgather.hpp | 12 ++- apps/nccl/src/allreduce.hpp | 121 +++++++++++++++++++++--------- apps/nccl/src/common.hpp | 2 + include/mscclpp/packet_device.hpp | 4 +- 4 files changed, 101 insertions(+), 38 deletions(-) diff --git a/apps/nccl/src/allgather.hpp b/apps/nccl/src/allgather.hpp index b4ee696e4..59aedbb49 100644 --- a/apps/nccl/src/allgather.hpp +++ b/apps/nccl/src/allgather.hpp @@ -115,8 +115,16 @@ template cudaError_t allgather(T* buff, [[maybe_unused]] T* scratch, [[maybe_unused]] T* resultBuff, mscclpp::DeviceHandle* smChannels, size_t channelOutOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems, cudaStream_t stream) { - allgather6<<<28, 1024, 0, stream>>>((void*)buff, smChannels, channelOutOffset, rank, worldSize, - nRanksPerNode, nelems * sizeof(T) / sizeof(int)); + int nBlocks = 28; + if (nelems <= 4096) { + nBlocks = 7; + } else if (nelems <= 32768) { + nBlocks = 14; + } else if (nelems >= 2097152) { + nBlocks = 35; + } + allgather6<<>>((void*)buff, smChannels, channelOutOffset, rank, worldSize, + nRanksPerNode, nelems * sizeof(T) / sizeof(int)); return cudaGetLastError(); } diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp index 08cd46aca..c4c1b1a5e 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp @@ -26,19 +26,57 @@ __forceinline__ __device__ To bit_cast(const From& src) { return u.t; } +template +__forceinline__ __device__ T clip(T val) { + return val; +} + +template <> +__forceinline__ __device__ __half clip(__half val) { + val = __hmax(val, bit_cast<__half, unsigned short>(0xfbff)); + val = __hmin(val, bit_cast<__half, unsigned short>(0x7bff)); + + return val; +} + +template <> +__forceinline__ __device__ __half2 clip(__half2 val) { + val.x = __hmax(val.x, bit_cast<__half, unsigned short>(0xfbff)); + val.x = __hmin(val.x, bit_cast<__half, unsigned short>(0x7bff)); + val.y = __hmax(val.y, bit_cast<__half, unsigned short>(0xfbff)); + val.y = __hmin(val.y, bit_cast<__half, unsigned short>(0x7bff)); + return val; +} + +template <> +__forceinline__ __device__ __bfloat16 clip(__bfloat16 val) { + val = __hmax(val, bit_cast<__bfloat16, unsigned short>(0xff80)); + val = __hmin(val, bit_cast<__bfloat16, unsigned short>(0x7f80)); + return val; +} + +template <> +__forceinline__ __device__ __bfloat162 clip(__bfloat162 val) { + val.x = __hmax(val.x, bit_cast<__bfloat16, unsigned short>(0xff80)); + val.x = __hmin(val.x, bit_cast<__bfloat16, unsigned short>(0x7f80)); + val.y = __hmax(val.y, bit_cast<__bfloat16, unsigned short>(0xff80)); + val.y = __hmin(val.y, bit_cast<__bfloat16, unsigned short>(0x7f80)); + return val; +} + template __forceinline__ __device__ T add_elements(T a, T b) { - return a + b; + return clip(a + b); } template <> __forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) { - return __hadd2(a, b); + return clip(__hadd2(a, b)); } template <> __forceinline__ __device__ __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) { - return __hadd2(a, b); + return clip(__hadd2(a, b)); } template @@ -203,11 +241,19 @@ __global__ void __launch_bounds__(1024, 1) size_t nelems, uint32_t flag) { // This version of allreduce only works for single nodes if (worldSize != nRanksPerNode) return; - nelems = nelems / (sizeof(int) / sizeof(T)); + + if (sizeof(T) == 2) + nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int); + else + nelems = nelems / (sizeof(int) / sizeof(T)); + const int nPeers = nRanksPerNode - 1; - const size_t nPkts = nelems; - const int nelemsPerRank = nelems / worldSize; - const int nPktsPerRank = nelemsPerRank; + const size_t nPkts = nelems / 2; + + int nelemsPerRank = nelems / worldSize; + if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T); + + const int nPktsPerRank = nelemsPerRank / 2; // thread block & channel info const int nBlocksPerPeer = gridDim.x / nPeers; const int localBlockIdx = blockIdx.x % nBlocksPerPeer; @@ -215,11 +261,12 @@ __global__ void __launch_bounds__(1024, 1) const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1; const int tid = threadIdx.x + localBlockIdx * blockDim.x; void* scratchBuff = (void*)((char*)scratch + channelScratchOffset); - size_t scratchOffset = channelScratchOffset + rank * nPktsPerRank * sizeof(mscclpp::LL8Packet); - size_t scratchResultOffset = channelScratchOffset + 2 * nPkts * sizeof(mscclpp::LL8Packet); + size_t scratchOffset = channelScratchOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket); + size_t scratchResultOffset = channelScratchOffset + 2 * nPkts * sizeof(mscclpp::LLPacket); size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int) + channelDataOffset; - uint32_t* src = (uint32_t*)((char*)buff + rank * nelemsPerRank * sizeof(int)); - uint32_t* dst = (uint32_t*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int)); + + uint2* src = (uint2*)((char*)buff + rank * nelemsPerRank * sizeof(int)); + uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int)); // Put channels into shared memory, read channel info from global memory is unexpectable slow. __shared__ mscclpp::DeviceHandle channels[NRANKS_PER_NODE - 1]; @@ -230,34 +277,40 @@ __global__ void __launch_bounds__(1024, 1) __syncwarp(); // step 1: write to scratch buffer - channels[peerIdx].putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, - blockDim.x * nBlocksPerPeer, flag); + channels[peerIdx].putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, + blockDim.x * nBlocksPerPeer, flag); // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) { - uint32_t data = src[idx]; - for (int index = 0; index < nPeers; index++) { + uint2 data = src[idx]; + for (int index = 0; index < NPEERS; index++) { const int remoteRank = index < rank ? index : index + 1; - mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nPktsPerRank; - uint32_t val = dstPkt[idx].read(flag, -1); - data = add_vectors(val, data); + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank; + uint2 val = dstPkt[idx].read(flag); + data.x = add_vectors(val.x, data.x); + data.y = add_vectors(val.y, data.y); } - dst[idx] = data; - mscclpp::LL8Packet packet; - packet.data = data; - packet.flag = flag; - size_t offset = scratchResultOffset / sizeof(mscclpp::LL8Packet) + (idx + rank * nPktsPerRank); - for (int index = 0; index < nPeers; index++) { + dst[idx].x = data.x; + dst[idx].y = data.y; + + mscclpp::LLPacket packet; + packet.data1 = data.x; + packet.flag1 = flag; + packet.data2 = data.y; + packet.flag2 = flag; + size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank); + for (int index = 0; index < NPEERS; index++) { channels[index].write(offset, packet); } } // step 3: get data result from scratch buffer - mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)((char*)scratch + scratchResultOffset); + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset); const int dstOffset = remoteRank * nPktsPerRank; - uint32_t* result = (uint32_t*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int)); + uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int)); for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) { - uint32_t data = dstPkt[idx + dstOffset].read(flag, -1); - result[idx] = data; + uint2 data = dstPkt[idx + dstOffset].read(flag, -1); + result[idx].x = data.x; + result[idx].y = data.y; } } @@ -316,7 +369,7 @@ __global__ void __launch_bounds__(512, 1) __syncthreads(); // Starts allgather for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { - for (int i = 0; i < nPeer; i++) { + for (int i = 0; i < NPEERS; i++) { const int peerIdx = (i + blockIdx.x) % nPeer; const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; @@ -333,13 +386,13 @@ __global__ void __launch_bounds__(512, 1) for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; - for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; data = add_vectors(val, data); } resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; - for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), data); } @@ -353,7 +406,7 @@ __global__ void __launch_bounds__(512, 1) } __syncthreads(); for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { - for (int i = 0; i < nPeer; i++) { + for (int i = 0; i < NPEERS; i++) { const int peerIdx = (i + blockIdx.x) % nPeer; const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; @@ -369,13 +422,13 @@ __global__ void __launch_bounds__(512, 1) for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; - for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; data = add_vectors(val, data); } resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; - for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), data); } diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp index bf512106c..015e0a2f6 100644 --- a/apps/nccl/src/common.hpp +++ b/apps/nccl/src/common.hpp @@ -14,6 +14,8 @@ #endif constexpr int NRANKS_PER_NODE = 8; +constexpr int NPEERS = 7; + constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB __device__ mscclpp::DeviceSyncer deviceSyncer; diff --git a/include/mscclpp/packet_device.hpp b/include/mscclpp/packet_device.hpp index dc3886528..d7f2ee8a0 100644 --- a/include/mscclpp/packet_device.hpp +++ b/include/mscclpp/packet_device.hpp @@ -49,8 +49,8 @@ union alignas(16) LL16Packet { #else // !defined(MSCCLPP_DEVICE_CUDA) uint4 reg = make_uint4(val1, flag, val2, flag); ulonglong2* p = reinterpret_cast(®); - atomicStore(&(raw_.x), p->x, memoryOrderRelaxed); - atomicStore(&(raw_.y), p->y, memoryOrderRelaxed); + __builtin_nontemporal_store(p->x, &(raw_.x)); + __builtin_nontemporal_store(p->y, &(raw_.y)); #endif } From 0c150e516653d485e03ac3d79e1f57e1508ccfea Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 17 Oct 2024 21:25:46 -0700 Subject: [PATCH 08/12] Fix copyright messages (#367) --- src/bootstrap/socket.cc | 8 +++----- src/debug.cc | 8 +++----- src/endpoint.cc | 3 +++ src/include/debug.h | 8 +++----- src/include/ibverbs_wrapper.hpp | 3 +++ src/include/utils_internal.hpp | 8 +++----- test/executor_test.cc | 3 +++ test/unit/utils_internal_tests.cc | 3 +++ 8 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/bootstrap/socket.cc b/src/bootstrap/socket.cc index 9e5913403..6377bc6de 100644 --- a/src/bootstrap/socket.cc +++ b/src/bootstrap/socket.cc @@ -1,8 +1,6 @@ -/************************************************************************* - * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ +// Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. +// Modifications Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. #include "socket.h" diff --git a/src/debug.cc b/src/debug.cc index aa97b09db..dea9ee713 100644 --- a/src/debug.cc +++ b/src/debug.cc @@ -1,8 +1,6 @@ -/************************************************************************* - * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ +// Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. +// Modifications Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. #include "debug.h" diff --git a/src/endpoint.cc b/src/endpoint.cc index 35817793f..015d51a60 100644 --- a/src/endpoint.cc +++ b/src/endpoint.cc @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + #include "endpoint.hpp" #include diff --git a/src/include/debug.h b/src/include/debug.h index c3cc9f36d..1abbad340 100644 --- a/src/include/debug.h +++ b/src/include/debug.h @@ -1,8 +1,6 @@ -/************************************************************************* - * Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ +// Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved. +// Modifications Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. #ifndef MSCCLPP_DEBUG_H_ #define MSCCLPP_DEBUG_H_ diff --git a/src/include/ibverbs_wrapper.hpp b/src/include/ibverbs_wrapper.hpp index e862cbea3..fe67268a8 100644 --- a/src/include/ibverbs_wrapper.hpp +++ b/src/include/ibverbs_wrapper.hpp @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + #ifndef MSCCLPP_IBVERBS_WRAPPER_HPP_ #define MSCCLPP_IBVERBS_WRAPPER_HPP_ diff --git a/src/include/utils_internal.hpp b/src/include/utils_internal.hpp index 52edd1889..3cd06bf52 100644 --- a/src/include/utils_internal.hpp +++ b/src/include/utils_internal.hpp @@ -1,8 +1,6 @@ -/************************************************************************* - * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ +// Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. +// Modifications Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. #ifndef MSCCLPP_UTILS_INTERNAL_HPP_ #define MSCCLPP_UTILS_INTERNAL_HPP_ diff --git a/test/executor_test.cc b/test/executor_test.cc index f85c38454..2f6d9cf5d 100644 --- a/test/executor_test.cc +++ b/test/executor_test.cc @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + #include #include diff --git a/test/unit/utils_internal_tests.cc b/test/unit/utils_internal_tests.cc index 6ae04561b..5479a681a 100644 --- a/test/unit/utils_internal_tests.cc +++ b/test/unit/utils_internal_tests.cc @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + #include #include From 4136153a7628f3862534c60e6fd09c29d9d09b9f Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 17 Oct 2024 23:08:31 -0700 Subject: [PATCH 09/12] [Doc] mscclpp docs (#348) Generate docs for mescclpp. Setup github action to auto-deploy github-page doc link here: https://microsoft.github.io/mscclpp --------- Co-authored-by: Changho Hwang Co-authored-by: Caio Rocha --- .github/workflows/doc-build.yaml | 34 ++++ .github/workflows/gh-pages.yml | 58 +++++++ .readthedocs.yaml | 36 ++++ docs/.gitignore | 1 - docs/README.md | 8 +- docs/api/index.rst | 5 + docs/conf.py | 3 +- docs/design/design.md | 157 ++++++++++++++++++ .../design/nccl-over-mscclpp.md | 31 +++- docs/figs/size_boundary_diagram.png | Bin 0 -> 21789 bytes docs/{ => getting-started}/quickstart.md | 16 +- .../tutorials/customized-proxy-service.md | 1 + docs/getting-started/tutorials/index.rst | 16 ++ .../tutorials/initialization.md | 71 ++++++++ docs/getting-started/tutorials/packet-api.md | 1 + .../tutorials/proxy-channel.md | 3 + docs/getting-started/tutorials/python-api.md | 92 ++++++++++ docs/getting-started/tutorials/sm-channel.md | 3 + docs/index.rst | 55 +++++- docs/{ => performance}/performance-ndmv4.md | 0 docs/requirements.txt | 3 + 21 files changed, 575 insertions(+), 19 deletions(-) create mode 100644 .github/workflows/doc-build.yaml create mode 100644 .github/workflows/gh-pages.yml create mode 100644 .readthedocs.yaml create mode 100644 docs/api/index.rst create mode 100644 docs/design/design.md rename apps/nccl/README.md => docs/design/nccl-over-mscclpp.md (55%) create mode 100644 docs/figs/size_boundary_diagram.png rename docs/{ => getting-started}/quickstart.md (90%) create mode 100644 docs/getting-started/tutorials/customized-proxy-service.md create mode 100644 docs/getting-started/tutorials/index.rst create mode 100644 docs/getting-started/tutorials/initialization.md create mode 100644 docs/getting-started/tutorials/packet-api.md create mode 100644 docs/getting-started/tutorials/proxy-channel.md create mode 100644 docs/getting-started/tutorials/python-api.md create mode 100644 docs/getting-started/tutorials/sm-channel.md rename docs/{ => performance}/performance-ndmv4.md (100%) create mode 100644 docs/requirements.txt diff --git a/.github/workflows/doc-build.yaml b/.github/workflows/doc-build.yaml new file mode 100644 index 000000000..78af009e8 --- /dev/null +++ b/.github/workflows/doc-build.yaml @@ -0,0 +1,34 @@ +name: Docs Build + +on: + pull_request: + branches: + - '**' + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y doxygen graphviz + pip install -r docs/requirements.txt + + - name: Build docs + run: | + cd docs + doxygen + make html + touch _build/html/.nojekyll diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml new file mode 100644 index 000000000..1c2645edf --- /dev/null +++ b/.github/workflows/gh-pages.yml @@ -0,0 +1,58 @@ +name: GitHub Pages + +on: + push: + branches: + - main + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. +# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. +concurrency: + group: "pages" + cancel-in-progress: false + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y doxygen graphviz + pip install -r docs/requirements.txt + - name: Build docs + run: | + cd docs + doxygen + make html + touch _build/html/.nojekyll + - name: Upload artifacts + uses: actions/upload-pages-artifact@v3 + with: + path: docs/_build/html + + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 000000000..11da73bb2 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,36 @@ +# Read the Docs configuration file for Sphinx projects +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version and other tools you might need +build: + os: ubuntu-22.04 + apt_packages: + - doxygen + tools: + python: "3.12" + jobs: + pre_build: + - cd docs && doxygen + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/conf.py + # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs + # builder: "dirhtml" + # Fail on all warnings to avoid broken references + # fail_on_warning: true + +# Optionally build your docs in additional formats such as PDF and ePub +# formats: +# - pdf +# - epub + +# Optional but recommended, declare the Python requirements required +# to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/requirements.txt diff --git a/docs/.gitignore b/docs/.gitignore index 00d9344fb..a69fac7ab 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,3 +1,2 @@ doxygen/ _build/ -sphinx/ diff --git a/docs/README.md b/docs/README.md index 2bb9c1efb..80d80b16c 100644 --- a/docs/README.md +++ b/docs/README.md @@ -3,13 +3,13 @@ 1. Install `doxygen`. ```bash - $ sudo apt-get install doxygen + $ sudo apt-get install doxygen graphviz ``` 2. Install Python packages below. If you install them on the user's local, you need to include `~/.local/bin` to `$PATH` (to use `sphinx-build`). ```bash - $ sudo python3 -m pip install sphinx sphinx_rtd_theme breathe + $ sudo python3 -m pip install -r ./requirements.txt ``` 3. Create Doxygen documents. @@ -21,7 +21,7 @@ 4. Create Sphinx documents. ```bash - $ sphinx-build -b html -Dbreathe_projects.mscclpp=$PWD/doxygen/xml $PWD $PWD/sphinx + $ make html ``` -5. Done. The HTML files will be on `sphinx/` directory. +5. Done. The HTML files will be on `_build/` directory. diff --git a/docs/api/index.rst b/docs/api/index.rst new file mode 100644 index 000000000..461a9fbdb --- /dev/null +++ b/docs/api/index.rst @@ -0,0 +1,5 @@ +API Reference +============= + +.. doxygennamespace:: mscclpp + :members: diff --git a/docs/conf.py b/docs/conf.py index 4a94b3aa5..4d3a91022 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,12 +14,13 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = ["breathe"] +extensions = ["breathe", "myst_parser"] templates_path = ["_templates"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # Breathe configuration +breathe_projects = {"mscclpp": "./doxygen/xml"} breathe_default_project = "mscclpp" # -- Options for HTML output ------------------------------------------------- diff --git a/docs/design/design.md b/docs/design/design.md new file mode 100644 index 000000000..82b6e0965 --- /dev/null +++ b/docs/design/design.md @@ -0,0 +1,157 @@ +# MSCCL++ Design Document +## Introduction +MSCCL++ redefines inter-GPU communication interfaces, thereby delivering a highly efficient and customizable communication stack for distributed GPU applications. Its design is specifically tailored to accommodate diverse performance optimization scenarios often encountered in state-of-the-art AI applications. The figure below provides a high-level overview of MSCCL++ abstractions in CUDA, C, and Python. + + +```{figure} ../figs/abstractions.png +:name: MSCCL++ Abstractions +:alt: MSCCL++ Abstractions +:align: center + +MSCCL++ Abstractions Overview +``` + +The followings highlight the key features of MSCCL++. +* **Light-weight and multi-layer abstractions.** MSCCL++ provides communication abstractions at lowest level close to hardware and at the highest level close to application API. The lowest level of abstraction is ultra light weight which enables a user to implement logics of data movement for a collective operation such as AllReduce inside a GPU kernel extremely efficiently without worrying about memory ordering of different ops. The modularity of MSCCL++ enables a user to construct the building blocks of MSCCL++ in a high level abstraction in Python and feed them to a CUDA kernel in order to facilitate the user's productivity. + +* **1-sided 0-copy synchronous and asynchronous abstracts.** MSCCL++ provides fine-grained synchronous and asynchronous 0-copy 1-sided abstracts for communication primitives such as `put()`, `get()`, `signal()`, `flush()`, and `wait()`. The 1-sided abstractions allows a user to asynchronously `put()` their data on the remote GPU as soon as it is ready without requiring the remote side to issue any receive instruction. This enables users to easily implement flexible communication logics, such as overlapping communication with computation, or implementing customized collective communication algorithms without worrying about potential deadlocks. Additionally, the 0-copy capability enables MSCCL++ to directly transfer data between user's buffers without using intermediate internal buffers which saves GPU bandwidth and memory capacity. + +* **Unified abstractions for different interconnection hardware.** MSCCL++ provides consistent abstractions regardless of the location of the remote GPU (either on the local node or on a remote node) or the underlying link (either NVLink/xGMI or InfiniBand). This simplifies the code for inter-GPU communication, which is often complex due to memory ordering of GPU/CPU read/writes and therefore, is error-prone. + +## Concepts + +To implement the list of features above, some concepts are introduced. +### Channel +MSCCL++ provides peer-to-peer communication methods between GPUs. A peer-to-peer connection between two GPUs is called a *Channel*. Channels are constructed by MSCCL++ host-side interfaces and copied to GPUs during initialization. Channels provide *GPU-side interfaces*, which means that all communication methods are defined as a device function to be called from a GPU kernel code. Following code shows the basic usage for channel, the `put()` method in the following code copies 1KB data from the local GPU to a remote GPU. +```cpp +__global__ void gpuKernel() { + ... + // Only one thread is needed for this method. + channel.put(/*dstOffset=*/ 0, /*srcOffset=*/ 0, /*size=*/ 1024); + ... +} +``` +MSCCL++ also provides efficient synchronization methods, `signal()`, `flush()`, and `wait()`. We will discuss these methods in the following sections. + +#### SmChannel & ProxyChannel +MSCCL++ delivers two types of channels, **ProxyChannel** and **SmChannel**. `ProxyChannel` provides (R)DMA-based data copy and synchronization methods. When called, these methods send/receive a signal to/from a host-side proxy (hence the name `ProxyChannel`), which will trigger (R)DMA (such as `cudaMemcpy*` or `ibv_post_send`) or issue synchronization methods (such as `cudaStreamSynchronize` or `ibv_poll_cq`). Since the key functionalities are run by the proxy, ProxyChannel requires only a single GPU thread to call its methods. See all `ProxyChannel` methods from [here](https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/proxy_channel_device.hpp). + +On the other hand, `SmChannel` provides memory-mapping-based copy and synchronization methods. When called, these methods will directly use GPU threads to read/write from/to the remote GPU's memory space. Comparing against ProxyChannel, SmChannel is especially performant for low-latency scenarios, while it may need many GPU threads to call copying methods at the same time to achieve high copying bandwidth. See all SmChannel methods from [here](https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/sm_channel_device.hpp). + +### Fifo & Trigger +One of the key features of MSCCL++ is to offload the communication logic from the GPU to the CPU. +To offload the communication logic from the GPU to the CPU, MSCCL++ introduces the concept of `Fifo` and `Trigger`. A Fifo is a circular buffer that shared between the GPU and the CPU. It is used to store `Trigger`. A `Trigger` is a signal that is sent from the GPU to the CPU to notify the CPU that there are commands in the Fifo that need to be processed. The CPU will then process the commands in the Fifo and send a signal back to the GPU to notify the GPU that the commands have been processed. The implementation details of Fifo and Trigger can be found in following sections. + +### ProxyService +Proxy service is a persistent service that resides in the CPU side. It functions as a polling service that receives the message `Trigger` from the GPU side and then transfers data according to the command. When we use `ProxyChannel` for communication, a `Trigger` is sent from the GPU side to the `ProxyService`. Then `ProxyService` will invoke `cudaMemcpy*` or `IB verbs` to transfer data to the targe device. + +## Implementation + +The core of MSCCL++ is implemented in C++ and CUDA. We offer both C++ and Python APIs for initializing communication channels. For interactions within the GPU kernel, we offer a collection of low-level device functions. Subsequent sections will delve into these interfaces and the methodology for transferring communication logic from the GPU to the CPU. + +### Interfaces +This section delivers a comprehensive overview of the MSCCL++ interfaces, encompassing both the setup and initialization of communication channels and the MSCCL++ kernel programming model. + +#### Communication setup and initialization APIs +MSCCL++ provides APIs in both C++ and Python for establishing communication channels, with further information available in the [Initialization](../getting-started/tutorials/initialization.md) section. Presently, it supports two types of transports: `cudaIPC` for `NVLink/xGMI`, and `IB` for `InfiniBand`. Users are empowered to select the connection type that best suits their hardware infrastructure. + +#### MSCCL++ kernel programming model +MSCCL++ offers one-sided communication methods directly callable from a GPU kernel, encompassing two primary API categories: data copy and synchronization. The data copy API features functions such as `put()`, `get()`, `read()`, and `write()`, while the synchronization API comprises `signal()`, `flush()`, and `wait()`. Demonstrated below, the basic utilization of the data copy API involves the `put()` method, which facilitates the transfer of 1KB of data from a local GPU to a remote GPU. Then send a signal to remote peer to notify the data is ready to use. To receive the data, the remote peer can call `wait()` method. +This operation is executed within a kernel launched with a single block. +```cpp +// Running on rank 0 +__device__ void gpuKernel(mscclpp::SmChannelDeviceHandle* smChannel) { + smChannel[0].put(/*dstOffset=*/ 0, /*srcOffset=*/ 0, /*size=*/ 1024, /*threadId*/ threadIdx.x, /*numThreads*/ blockDim.x); + __syncthreads(); + if (threadIdx.x == 0) { + smChannel[0].signal(); + } +} + +// Running on rank 1 +__device__ void gpuKernel(mscclpp::SmChannelDeviceHandle* smChannel) { + if (threadIdx.x == 0) { + smChannel[0].wait(); + } + __syncthreads(); + // Data is ready to use +} +``` + +Similar to the LL protocol offered by NCCL, MSCCL++ introduces a `Packet` structure designed to facilitate the transfer of both data and flags within a single instruction, proving particularly beneficial for applications where latency is a critical concern. The following code shows the basic usage of the `Packet` structure. The flag should be same for sender and receiver side. +```cpp +// Running on rank 0 +__device__ void gpuKernel(mscclpp::SmChannelDeviceHandle* smChans, int flag) { + smChans[0].putPackets(/*dstOffset=*/ 0, /*srcOffset=*/ 0, /*size=*/ 1024, /*threadId*/ threadIdx.x, /*numThreads*/ blockDim.x, + /*flag=*/ flag); +} + +// Running on rank 1 +__device__ void gpuKernel(mscclpp::SmChannelDeviceHandle* smChans, int flag) { + smChans[0].getPackets(/*dstOffset=*/ 0, /*srcOffset=*/ 0, /*size=*/ 1024, /*threadId*/ threadIdx.x, /*numThreads*/ blockDim.x, + /*flag=*/ flag); + // Data is ready to use +} +``` + +### The mechanism for offloading communication logic from the GPU to the CPU + +As mentioned in the previous section, the offloading of communication logic from the GPU to the CPU is accomplished through the `Fifo` and `Trigger` mechanism. + +The accompanying figure details the structure of `Tigger`, employing three bits to denote the operation type: `data transfer`, `signal`, and `flush`. The remaining fields specify the precise data locations for both local and remote buffers. + +``` +|-------------------|-------------------|-------------------|-----------------|-----------------|---------|-------------------|---------------| +| 32bit size | 32bit src offset | 32bit dst offset | 9bit src mem id | 9bit dst mem id | 3bit op | 10bit channel id | 1bit reserved | +|-------------------|-------------------|-------------------|-----------------|-----------------|---------|-------------------|---------------| +``` +
The proxy trigger format
+ +Page-locked memory is utilized for the `Fifo`, guaranteeing access by both the CPU and GPU. On the CPU side, a polling thread periodically checks the Fifo for new commands. Upon processing a command, it updates an incremented counter to signal to the GPU that the command has been executed. Users wishing to ensure a command has been processed can invoke `flush()`, which waits for the device-side counter to reflect this update. + +## Use Cases + +In this section, we will discuss several use cases that demonstrate the capabilities of MSCCL++. + +### Overlapping communication with computation + +MSCCL++ enables the offloading of communication logic from the GPU to the CPU, facilitating the overlapping of communication and computation processes. The code snippet provided illustrates this overlapping technique. In the depicted scenario, the GPU emits a signal to the CPU indicating readiness for data transfer. Subsequently, while the GPU continues to execute computation tasks, the CPU initiates the data transfer to the designated target device. +```cpp +__device__ void gpuKernel(mscclpp::SimpleProxyChannelDeviceHandle* proxyChannel) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + // Send a trigger to the CPU + if (tid == 0) { + proxyChannel[0].putWithSignal(/*dstOffset*/ 0, /*srcOffset*/ 0, /*size*/ 1024); + } + // Continue computation + matrixMul() + // ... +} +``` + +### Fusion of communication and computation + +Traditional communication libraries enforce a separation between communication and computation, creating a bottleneck where communication must await the completion of computation, especially when data dependencies exist. In contrast, MSCCL++ leverages its low-level premitives to facilitate the seamless integration of communication with computation. By segmenting the computation into tiles, MSCCL++ enables the simultaneous pipelining of computation and communication tasks. This approach not only mitigates the communication delay by overlapping processes but also significantly improves throughput by leveraging the low-level API for fine-grained control over the hardware, ensuring optimal efficiency. + +### Implementing customized collective communication algorithms + +MCSCL++ offers a low-level communication API, allowing users to design customized collective communication algorithms. The following code demonstrates how to implement a customized All2All algorithm using MSCCL++. +```cpp +using DeviceHandle = mscclpp::DeviceHandle; +__device__ void localAlltoall(DeviceHandle* proxyChans, int rank, + int nRanksPerNode, size_t nElements) { + int remoteRank = ((int)blockIdx.x < rank) ? blockIdx.x : blockIdx.x + 1; + for (int i = 1; i < nRanksPerNode; i++) { + DeviceHandle proxyChan = proxyChans[blockIdx.x]; + if (threadIdx.x == 0 && remoteRank % nRanksPerNode == (rank + i) % nRanksPerNode) { + proxyChan.putWithSignalAndFlush(rank * nElements * sizeof(int), remoteRank * nElements * sizeof(int), + nElements * sizeof(int)); + } + // wait for the data from GPU (rank-i) % nranksPerNode to arrive + if (threadIdx.x == 0 && remoteRank % nRanksPerNode == (rank - i + nRanksPerNode) % nRanksPerNode) { + proxyChan.wait(); + } + deviceSyncer.sync(nRanksPerNode - 1); + } +} +``` diff --git a/apps/nccl/README.md b/docs/design/nccl-over-mscclpp.md similarity index 55% rename from apps/nccl/README.md rename to docs/design/nccl-over-mscclpp.md index 364feeed2..ca362e9b5 100644 --- a/apps/nccl/README.md +++ b/docs/design/nccl-over-mscclpp.md @@ -1,6 +1,7 @@ -## NCCL Over MSCCL++ +# NCCL Over MSCCL++ -### Limitations +(limitations)= +## Limitations Current NCCL over MSCCL++ has a few limitations. @@ -8,7 +9,8 @@ Current NCCL over MSCCL++ has a few limitations. * Multi-node communication is not supported yet. * Currently, collective communication functions may not work correctly if the buffer address is differed from that of previous function calls while sharing the same base address (returned by [cuMemGetAddressRange](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g64fee5711274a2a0573a789c94d8299b)) with the previous address. This is because the current implementation performs zero-copy communication over user buffers, and it is difficult to efficiently inform all ranks if the buffer address dynamically changes. -### API Support Table +(api-support-table)= +## API Support Table The table below lists all NCCL APIs (v2.21). We may cover more APIs in the future. @@ -44,3 +46,26 @@ The table below lists all NCCL APIs (v2.21). We may cover more APIs in the futur | ncclRecv | X | | ncclRedOpCreatePreMulSum | X | | ncclRedOpDestroy | X | + +## Executor Support + +The executor is a versatile tool designed to specify how mscclpp executes algorithms. Currently, only the allReduce operation allows for algorithm customization. The following environment variables can be managed: + +- ALLREDUCEPKT_IP_JSON_FILE: Specifies the path to the JSON file that defines the algorithm for small-sized, in-place operations. +- ALLREDUCEPKT_OP_JSON_FILE: Specifies the path to the JSON file that defines the algorithm for small-sized, out-of-place operations. +- ALLREDUCE_IP_JSON_FILE: Specifies the path to the JSON file that defines the algorithm for larger-sized, in-place operations. +- ALLREDUCE_OP_JSON_FILE: Specifies the path to the JSON file that defines the algorithm for larger-sized, out-of-place operations. +- ALLREDUCE_SMALL_MSG_BOUNDARY: Defines the size threshold at which the algorithm will switch between fallback code and the customized algorithm for small messages. +- ALLREDUCE_LARGE_MSG_BOUNDARY: Defines the size threshold at which the algorithm will switch between the customized algorithm for small messages and that for larger messages. + +```{figure} ../figs/size_boundary_diagram.png +:name: MMSCCL++ Abstractions +:alt: MSCCL++ Abstractions +:align: center + +Decision Flowchart for Message Size-Based Algorithm Execution +``` + +This is an example of executing the interface with the executor: +``` bash +mpirun -np 8 -x ALLREDUCEPKT_IP_JSON_FILE=/root/azure-mscclpp/nccl/test/execution-files/allreducepacket.json -x ALLREDUCE_IP_JSON_FILE=/root/azure-mscclpp/nccl/test/execution-files/allreducesm.json -x ALLREDUCE_SMALL_MSG_BOUNDARY=16K -x ALLREDUCE_LARGE_MSG_BOUNDARY=1M ./apps/nccl/test/nccl_api_test diff --git a/docs/figs/size_boundary_diagram.png b/docs/figs/size_boundary_diagram.png new file mode 100644 index 0000000000000000000000000000000000000000..41e3a38cb68905ebb70fc8f0e0976bbf089fd4f9 GIT binary patch literal 21789 zcmeHv2|Uzm+c2U~q*X$wNDHz~_MK#DN}Ftnv75<^U1P1Z*h&$iDMY7Gw!vg{T9746 zV;$2}B8C}-F^uiIXGZ6A&U4=Lz3=nB@3Va0`Td?>{&W5B|8-yYwJ-O1u0U za&d79SsXMu%EiS);o{$`<=i<%H!l|tq>CR)5e^FgAn@MB0|sASV2ud#^7VC*RWVan z+z(#KpY(#eA%aoh(dHQVM|nR0s~UkP@Iph4^BAI`wi`eU&z@b@7F9z{-kXC2R3n1kw$LTn2)JsmQ9D%I-G= zZ&?3RHC|>4_yKnT1hJU~1$=v@Ty}Gg*2dP^q2Xa7tnQ{;0_kEVPWdUrx!oHt|O}NT_k3cUs z7-|_C2UP$9;Tzy}?zcc!1RM@?U6$-JMi(TKB~JDMmqG3b-{sRS1I~ei|7HaXVEY4C zIeW68C%xPPJinh}nLHN!Ak53d^ZU-q`PeT`gPeZT>Xf!2I~52*iT35WHlfY9$Mjzk2)-C!&MDgz9`o?Zbk+jB0i ztSy3p`T>|{fS>R3YtGleaspuAweh6nJV zTmu7GQgQKP@fCg+#rpHdQsPJ!bt%=_ID@BjsO7e6mwz)jX_1>nMcVG5{FQ~(Tk9@b*cvqB(nnXI3`(6_vm zEzASKa=OPYe*cEE{vgaZ2o~Vw3Y$;3$P$IVi!0{0R=;Px&JX{w!;UX7_$$f zvftDX;SGpbX3z5WfBfl>MEwWfv7-#o(*Ok*UoQ{f=UK=1g}JlX2cFIXJMmyEzrydB zslsxnEc$yo#AbYt?Kp_H=?5rDrAe_?#xF9`1jm=)QoL`=E`S0i1NSLon zfLG9OxBnxl{dp41$>D#^;Jb`rFl#{W$uKkz&K;GAyxE0*wp{rz8xtnZtFm7qBO zp5FgCnW`X@`@6IHWkLLb`}iv}2hNC- zNmvH{LjX}xTZaA1M34g<8h8lM-a`wjUdZK%0D!M{AZl{f6GBHpYtyV z{f}&&{9_#S|3XjY7f)|xc+}kVzvAitp61H`w*f%ycL6~4dui~m5CGWj=06w)EW4r~ z#x6hRmh89R2R)n@%O9}2b8JY3XHE+jmo%4!iQ%yjCn`y(ZS(PkrEf~7DHucRn0p`i zGZODV_=vh{clyMkp_i9GZPZ`V zur)W#E0H=SR4_w$6^Wl)x;q+7L*18y-!?T zuKc))n^(kuYx$o&vTbn~r-#!R6@KX`E*{Q5COo`>T0F}?Sl~??r3+3zIa09x2N-W4 z^ZWY0mi!a>0=13SM2+7Vtqn#Uo3jwTvumI}f$A?e#Hin%wlxLP%hUWm zRD|EjgnPt0JD2~o>TsXl@U*LK(&EA#V>k^TiJkk2wB+kLIqvV*^H!I7Dq!xS{ezME zx^@Z3!xo<~ftbo_KuWpaLP>C3SG&>rI8IsyqpDk9GqRpV-lTl|)iF|_HpAdR4<)k2gYa?^T3 zi~94L_GXpM=oN671bOi_8i6;XJ%_S})QK7#7BdXF7hyd1?Am zu-t(p44fxMz%WOd9UltXhDRF-ohwz#FKIx&8c>MnR_5jC+aoySPh4HA9lH+b#))(CpjXH#MoX zm4=R^u+(+w)BZ0X+8z2FL~cX*$L4J7-e=cQ;?Y`VTX!em{(SeC=6g(b*~Zt{n|oT? zCdF`m$J0~YE>A0>v*&MHbc^21!Cp-CTXjN(iPtQ83Co# zg|+kYr+vm(8ekW_x<=&+=MqQtyke>)8xq9=%Yv=L`#VI#hx)Tzg44oo`(gV_ii5Li z5j8lp(*t43L;bJ~9!kY+k&`6PN~+hYAm=9m*u#Qf{e1{%bYpPohlocxU#_bc_(;~C zfAz8O+eFQ1eiL7c9E1Ore=A~UO@Bm++Gf(CC`21k;p(R)dOk)A&2q^F=~3c4n@ji@ zh7HeR!yGx&)$ac5AC3R&^2*r0!cyVO4e z`*66z%emE)=f1#BucW(E`yWbHpt&1_&N27rplPQ1R(*=evu|G?C_FXmJsOxui>;#V z-G*C7pqJkXB3ajLRhj&AbtyI2~WTGH{N_#UKih?VZmrm?CcWm8Z}) zRGT?6%3F6+v7*(^ech%^xZJ1I^s>F=lZVNf;r$5A`N27OL<;=APr$?bKWPex5Ku9n z_k-YXfaN=nuu_Iqgl8TCHY02*w^8Zzvt0VBv@w-FK6)i=FJ5E|Lpp2R{-VuGKb(GD z&Uj^Z=;-5PP0Hts38=K|xApaIO~u7_EPmXLMUybJ%r|gG&{nJhRo`ZlXyQAasVyF3 zxNl$m#2F%m;$3L;K6n?I(PHfY#Uf0M>hF&qPiMAVMZ=p5g69sME*{uG80~Z-Z2l!< z!N>r&bV*Gu{|aENnxiyL$w^M3x8S}#4rWeU6LOtSz8!xpAZH)%ClC*&gnmqPXn>CMqNa@pc`|&U^|cIYM>h{2hVX5!vNzb z%zL5ZiB}BD`S>ZFa4GLNb`EH;(*T;kL%L0eJ0@(h?~W>a(}H_|f=dZzx7H0q!Dlkk zxE<^a6M6x_`C*L)nA1R*-sP*x4cRn4-wTkq{)OZ(p8TaJ{}kIX$c|uo&5n8K6OCVp z5}ZRm^7V#($%3U=YZwY~^J*BeTt!d89d;3DP{Zf`_Rj9LdNUZSyLyg#jH*`CU9PL> zX-;7l53JI++xtGf*qjL|e!t6-y3RoQL=?v63#aJBoaGs>VPyMaNkJ2>k&Cuwi4waZ zq=1yM(pJ5aGe7f$eu!S>wm3g~BtiYjc-cFfO^4w#Qg`5-fTZd-`pmvl;(iEa(?DxT;YaZ-_jPhsp0m= z*2Tf^j)wCd1&D2xpRVKh%GaxhA2T(GzB@*{3Vk*em5vnSnLn)&Q^}nXJhGjKf8!nl zsM5{M7o3v2K@Nk6O4kW3dA77TcfWLG*WL3a{ej(%y}~cL_@$wiz<(-aCLCfLfST!Z zr$_ja2bqmmH|=$sqhb+_BdBbCwbQ_v4iA0#%y#$`X^c;&QsoC8G3}to{XkrS1uB6i zT+tGMU7FHJeh#~;cVkK^+O*1?eW%MV_fEl$P`h;|^axIV}d1-RnE`G~A_-ZA8!11VW#N$rFi3TV6 zw8Uz?_H_9}_2~rL!Fv70r&gI73e45sQnL-_9XFdaK;~*19X?)jvgp}Fa_Ne@#(vcM2L*pJ%uonG>DG5$a6i-Oh^>X=1f}bG8yZ zS^WDWwA)>sc^fx9sUL5Fb!^WQ0Z>b&B;LF>QY!~${Gd*5Xq#!;j(OGD?DNQMj6hu_8 ztx;JO=s3+vrNoranmmZ<*#hVnJu{z5UP<7JhH}>-3MA%apm>Af8QFp~au5K%;5Aa( zc=330j0nr8eY5;xfWhSCF123yq?~g$38jRxKC?}#7 zAwapPzK)Mg%EkNh8MZ5UMF3Y;0j{uV>&CV`)Z|I$1GJ^DKD&xnq`?s!S$p>?6U_eT zwr`eDrN>v~`J*RT3kJLIYz6(gfi%4+12i9-GQ6Kn%esbrB4PI$97ME15kTXKC-+-u zO7mP4xRB@^6Xs%VV#t3}2`rGW9U_2aZA&dkoGE+Rl+pcA-iZe8(QT_uhhkVNfCMZY zvT5XACvPPkFu$I#+Go3<5#~H=!K`}zN`C3*GkI&Jq0Vf|iV>N%C|;G-7o6l(L|+N> zODE&F1|Y#lg)bYpa2R`P!x0glLDquCVoh}rL+X^!T+vKHHf3nyoeUsg=_#Y@RnkGA z=t2cBiY>mP|TJ0jDc;sn){0*^gCz8dHi(~d6CJdU8Lk5I2KSZtgOz(DP z&{6}p%Ukx!5Qo&e?%FRtyQ2(`dF+UNt6kBrM_UZ;Kj2?1Zuho5IN?o6_cw*x9+hEj0r&7_=cz2cE7dk<@U)-mkN1JFQ>+rX!l(i6o;L<fUmd zSu?A=uvCbBPf@nq5MJJ1t-Rg6w^AT;0Z$f;L|F4Grn>CqVI+GbC!E7_*s^eSaHYOwLv zk(z)vKOr8EdN_8d*WD^b?8Fj-paCdaVr!Wp34aeVcVW7-BsQ#X;$1wEHq}LwQlFYk z!(_V>U-7k7u0^7ro~jP;@(~K}!_XcD;hvw7Vmx^%?9)<_DdMJx9NIE;zw%3Vcq7PB zi39F06o@s$<@QN%cUvroM_m_*H~nIg1r83?mg`JXqq@2YWCyo$9aQ{J=vx`CjK)DI za3BtXOG~Si(HKW4RzIX%>JECkSH->g_BO~h{5N;9RduW7url?@w)KSZ$^qfoPolLw zqmSluBRnE!p3OI2-PWXJ6kENZUQ!eA3XoXkThltxqPKegT3&?^F9bz_KbL7>DhecMd%5MtM->9#1)p=|#NoCnt|vW%Q4Vmpx{2oSv}jot!^}mpS*SsM3m7iE{F} zai*tD^0aG{gZ?LY3{}=GY+q>&TDLti$==9AiBBc4(=Xq;#-T28>BOgww4m;?eG5;= z{ER{H)U8hiUHF02h=BD3`mz4!ADlj@4by-y+T$^cuUe3mbLF@Lv+x6-PNAKZszt4{ zdfUW=$tt3_MSNhAmpcUIJJBb(2e`H`S1Lf0^6l07gbl&d_8$1LFI%jnA}3oQ_K=_b zo2vQ|$maNJw3TX6zxWF^$U~Y{2c^>|tZ;R!N)D~!BBHD1ecc6u-%8>5uAfJlZI!OM z0u)+YspJJw1*eYFafmb*y&e($}5JMYClbV>eeb@TtE3ddd2C{uhcN6`@FUQ z1wZ*!VOL{^Qd6G>D-ff6pO#rVH?LtT;>OlHo-uMKkkl!($7SgPUu$&bETP@kMt14X zTG*9)A5Z!)R#I-yRDt&&VCBC#klGn{{Nx&nMX98wMzF%#O(0j8bSvcz47Z)h_6;Em~)>dJRqsAhaW(H$)eZjJu?Q{dv93#((L~3{kVZAE=L;Z z=$_0ccjXE)UHj^ftO%DQ3JpET_Fwy^SL)eN>};9xt~ZFgAtl`^_iR$}uXHhw=(tZM zZ))7PSfh>8#6N4or%oN;Qfv2iw}tx;XsE7wmhw_hz21{e~^VZWDNi zMV-$p3qtKu@jX5<1F87WJ@Ml7dm!lwy`SsIybtSh_98-DsCp*!mveVeZ{J7|&~&6n z@28mRE_Irpy1%M@byB*J4fA=&+MRQ;?ZUT)w&3mdzJpgMEDn#ANotWYFRK~q!i63p z9*}O56$^0DWID3eX8Z&c`!az@7NEGCi~7sG%HR5jl`!Gk*a(5jUVZ@Z9h|JC`-F_ zik1<;m_n~jp)?zIRK3=fe^j3oep#t#tS!!hcri~Rv#Kjjum<`$w8JYl3 z9S+~pxuDgLw!JZnnJz>n*)wl^?x3ycT>LuSQi9D~=sYQq0X^f6n+mDMJ480`-9}~3 zy1%d%qs3QiURrR08^z!rIMg$4kd;zESHhs^8MDvR%esH>j(0lGyeU&ufq}ubt@BU2 zUxt|S(ruhV(+Rt@X3;E9$IKU^P~RU-Eqm6?*zeu{Dkc3P!+Kwy9cZxAsPINKd25~! zWoYGTXGQ@Ul|+hTWGG`t-N|J*GJG}_e><$Qn;*>X;t;wOPQ_p%`m%j;n-}6Nus{Fq zvfSx{zVT%jW+hm9`ljxVe9?q~;i@nFTwG$&th)e7`g&oe_XT*oJP7E{inNuLs@G3C zh31{7`Z!-}R8{4stvE2fcu%yjEb^;-)2LWrkZHn{OipFDYDVTw8**TlYjDA5wvuIA zb&QIk5fS@UKTkuH@g7EFnJ;qVEC$@`2euvTjF|cQa&1$0?IPW2ZAzU&mCDbAESQQ^ z?$C;c$5MwUw*>-?6Y3i>w8GW&ixf#s(~^4Y`uF#Y_QZ`>JlYqz*jl*Ojoe%_`;}36 z1ly_$tS)V|N2XpupOyf@r!VDZN|CF{m0~Y+>P`AwkGp{S`Vl;BD$ERWVDvzSs)tQk zKnvsi75kvREBTdi*3|Uyxv(>fBpoGKnmD#uf%XwjCZ=~39l?5hB#HRBBouuUsEPvfIT7cJDRq!ocHB1P8q>OJBI>9;sfuL<`@0u_mD_JYO=7&SVl79$-ku z@GM@yI4#99z`b`i3yt#9zqYe{n)(p_%zK)GjYXvBtA!eL;S4RfcJE=m*<7$8v#wh4 zibZwZmaIw-$e13C*^!`%82L7RFE>vq40R?+lL3m;Rk-3l;%HA*1XHbQs9HQsXYwPS zkAh2c*qI{d+%-^mnEL+fXPj57cEMs4gQ3q>Dg)p~g+GG2h@ z+j>wduKTN2{5P%@|ApHFGBAi1f-s}>q z%|ao0c?Cf)XLF;q3;*Xmn1SGH3Khv7chM)Ndt8&O>fcHbg8Ti=tA;XsSiRQo-JB13 zX-+p-eIMGhNt-%Em9^mr=(3!dcb9Wr#k*$*K$GSebY0!2tD(ytk@e6iJ(el3>yC8n zs_^24)IHOQ3CNCB(4O%=R;@UD$0dC1!q8~Qv{L&n` z)!}Zl;X)1Leac}p#rB9xes-Am5Kk{?j9e8B8A|tX1Ksab@&YXwNjPGz9B~c7bI3&c z6zJ?()hF`bR6J8RenZyMt#)C$iG|3~vP~>7%dXq>n?1Y=He|%cv*`eHI492;T+dlr zV5ml1%ga~e)dEZjvTEQKk+)_w2XFBUQckxuT;Bl$blsYr{1SJwNe^_Glb>%o=^%q8 zN8oe-E7~Vz_MUj0yd&{|k+9aVbQR-$cxh_W zG{6uQ<+O8(JJnnYeJoLDqT$4qcrwq|-8|94pz|WhQ=-iyEBK0qK~R`-IvEcKC64mM zc;u2cQ)^+b^X#FniQaA8yw#v>{fRXvsFt(2tARDH-ktc3$QL4f#XuPJo+Lyxd3j_7 za>Z)xe(E{`h448Bw3{La5+IHD;P3cdFT}brTT=1Y-!KAZe>(oitsz;aW>4LW0=`gw9rsCTulc$w$`K=$tj3sD zM2dH|Jruh-47BM7L0=**n2d%zJRgz7#iPmZThZ&zYR*(dej5X6@=aYvzwKs)(>tX= z!`syoHb2xREx{ALhR3T8(f3k0;@zRAl038gpEV=U>Xbv!)-ME2LC~7ojUk6*fqpA; zb}+4p@X&TrN3Gt7JvilxRs=&`TvVgXI&k+%@2hUWFm$%bd67MYdG@VN+|Q%)(TPY( zEN51`0n9qAeFNWC`SUP`(n{ZISq98D4P_z}S=nnTC1I5B&VJTt!+->42!TGzlxqfN z(M$7R^h;g}YhTto=p7MoVgX1L$Le$swD(9X&(N=~mo5;Mwbtanx$_)o9K5OYANUNU z;Mr*)713@NxZrwle))=_fjVn?>?`Vh(Lj+$D^t(6=gF|-C@o1b8q8>d5K(XP)=RK% zaE*>u6WP&E*}4U zdAM#1rsu6eIbHla1?Q*o^>00OtvB;&Zv~!%)pO2Wp1jXCd+6-|4WDZ-=6XHuDu8Sa z_UT&>M;5@$HvvA3T6=-k?bqik);3%>-?sbi%&T3eVHp^`v#%}>lc%_CNwk}`HrK5o zq_QW~pufo(ukQrI_LPJkp6CR#=Ry?mmMi|-IWK%K@Ko{9fz{e4YEk4v=fBnOcmT9t zf$5zEE(jDPf`Pi0pP@Bp#&mFfgDkk9McNof3hXu=b)5y%GjE_EkJ^O^JZ!E;tMOFv z(t%?-#2UVVK29+(=hC(aq*DAC5T|-TdwSD7t(E# zfUaN%%{@G~SeIc->|vV@74V`QVPwMhV|u6#4N`cXYwpu^J(EU))j&sL7gsDKl}D7G3rd_YN8 zugpKd*;9+fq^SBLeu6b13cEQ6&&Fa6ut2Klnv6OYTe?3ldz1u#L>;d4Q=-8MK*|Z_ z3Ujv6U|9D?LZ@%W8+-)EfAMCw2xreT0K8vPdplGgJS_CZZ{^HIPXfaGE50;~gZny& zGdZY64&mVR-4!DR2Y7nHF1;;1D>z6>z%gKNY!!KAfgV^I!Q+cB_{t$p?+JQr~Mddpi2_LhWW;BfCQ<@E7i07#4pJ)D_xG!oz*thC+3lgGmC zk`LwJ{?%sz+;m@vcLE!?6~L`Z%qoQ!0NnS4TD>^9j{@9H?l~R&7XYtb_V=t_W;l!P z0$Vu{2m$Vx{8P(x12DJeSd4)qi*7$n3J13dz`emf3k81+a6h$fy~4rG(z2IkIgR`m z1OeTHJ!?6*KY?#wXWDM($z$Qp$v?xvtpY@=ngRwPbQUkM)B;Vm(oTXcTYOM&D}M&k zSqah0Ph=ejstvZJX2c8b5K%OMuEn|@<{-HWL@%mf6Uu)Rh(~GN-TR!2Ne$pvD}UV( zKf&(IiwgqC>0Xbf%jemJQQzlzQekK$hL|fA`shp*i7UHyY9$HPcixK zZbb?*!I9w#W@yevEO)h4255$fD(c2}Ok6yH8^DnYUl%C-ihTz-SAwUz&3g9!tbf3D zU0!;1`4&3c{8*SoJXyDu*0nTSaNGcku-~dWdy52M29W#dnid=$0CSA42CNUBwM@+3!Z<-CL`VU{u zOdtLj+rJ2KX&uHoRpifR`LjBB*Skj_oiMFc`Qlko%1GEd`pGiCp;hYi2U4wL_^>dw zPo>DRmpV|9g|=E)f-gNC^-dhmE-hX_4_4BzjY1wQj8PM=`FE46hNpyI5ytZS{8@c_a(&QBb3co&h@jxcOS=~A6CbK8{OqkSD0{a zbX$v;O|6ts9P8nJ8yuSq<{L63;U{ot@Cw_b_Lb4^NVP;mQccx4N$S^;)g!b;eI>oZN!ZJCAjQ%}z$rzO`IcmeL8q4F}^9!oJRakOlwR-Trg^&GsH|cJIpS zVRpTS=#V0Eh_?~v`t|OSH}^u+gO>bbeLEm3aL@>~o1%{gkm zIv)AK{3$tJeQKhf?T5-iDCvIiaH6n4QqblAO?)EDA|+hGxk0rcbK<-z#V)3LPo?wt zQY)pieSLN=W-!xl>fxn*8?NBE_t!H>TdG_}N6jOxBj#SE_^s%(?mnD$cYf^W;kriQ z7o;_mKE16H^!j~{9h04z$;!L31QN^^FBYbniJ)N5 z4nBIwOp+w1&yT9)2Gw82ty5p3MOeC0@)VV7OSWs_{XVaz;Qe9Z&VjFC$oj0wXO@JH zQ>8*|cU1+f47p&rmn6&w8s~ku_c}=!S~f;f=SG)&{jKr0Z+*l-S4dzAFXn^G#)F|p z*QQZCzPt@DT}aFh92!B+61)o2oi>CO-_^4&s-j(Ty4Sc+_u(Gj>hPB2Ak_jN`A#XJ z#4wvxW3>pi*g}1LrcQTj#U#4)lVYre;8%Ra2GKCIPYLN<=CoO>Fop89$$2Swua#8D z^TitG8F15e?&`=6_*|h=ja*p7%Y{pIjIp>1<9@8m1LsnWnleiTK9=v61nxuo)1xnM z@lcOd!D!6U$O&PJos`^TTPQ^LXQ3=P>0VU|_U$vnx0*}$J=3(^TW{=rU~_G8iLh5O zrBu#pY!kt875V60_$s=Rlz+%Z3Ho^(Nz~DehXO5H6JK7U3*HPM+v|i7Lx17qVl1 z3fh5Mv*AXC*n$1bE~ChR*#^-6y5jCytLyS?cnR zop^T7JH?GXrHD|zKy*}_kU4W}1NaX@!nd?xX}!Yg)7^{y#Rr-PT#jZrcA+|@g0IAh z6X-u5pY%QiX(pBy_#;x3bzf&aWz;CDCSi(-zm!p}xSXE3c_o25UZ5*8d9Pw-e^_U} zHcqm8W$4O4)MOXU<-*4;*H9Rw( zHmXQ|lJsSD=unFkyawGQ*Q~+5YEHI3ibE)TYV_`2S$()+YapZUHD2?ZP;F3geHj)z zm!woMS`jp3Gs!$cs>eF^<9Y+b7pBjGsb5&jFgE-^niyTsG2|gAXQf3zKCt$}qft0* zJ@;gN+}Igco>)7*533#Qz{uwOhKS}80mh6)w5DQ1l}i<=uRkaY8veO0w%s~*@ zulL=&_)F>%WpD_94#{g6|B|}=wdxoC{xkR$y|mg5ntxXFg}4*=|2yTfFtsr$GCC9U Fe*mOwHBkTn literal 0 HcmV?d00001 diff --git a/docs/quickstart.md b/docs/getting-started/quickstart.md similarity index 90% rename from docs/quickstart.md rename to docs/getting-started/quickstart.md index d09ce0baf..8c0982e3e 100644 --- a/docs/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -20,6 +20,17 @@ lsmod | grep nvidia_peermem ``` +## Build with Docker Images + +We provide docker images which package all prerequisites for MSCCL++. You can setup your dev environment with the following command. + +```bash +$ docker run -it --privileged --net=host --ipc=host --gpus all ghcr.io/microsoft/mscclpp/mscclpp:base-dev-cuda12.2 mscclpp-dev bash +``` + +See all available images [here](https://github.com/microsoft/mscclpp/pkgs/container/mscclpp%2Fmscclpp). + +(build-from-source)= ## Build from Source CMake 3.25 or later is required. @@ -54,6 +65,7 @@ $ make -j mscclpp mscclpp_static $ sudo make install/fast ``` +(install-from-source-python-module)= ## Install from Source (Python Module) Python 3.8 or later is required. @@ -101,7 +113,7 @@ $ mpirun -np 16 -npernode 8 -hostfile hostfile ./test/mp_unit_tests -ip_port 10. ### Python Benchmark -[Install the MSCCL++ Python package](https://github.com/microsoft/mscclpp/blob/chhwang/docs/docs/quickstart.md#install-from-source-python-module) and run our Python AllReduce benchmark as follows. It requires MPI on the system. +[Install the MSCCL++ Python package](#install-from-source-python-module) and run our Python AllReduce benchmark as follows. It requires MPI on the system. ```bash # Choose `requirements_*.txt` according to your CUDA/ROCm version. @@ -163,4 +175,4 @@ mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=$MSCCLPP_BUILD/app If MSCCL++ is built on AMD platforms, `libmscclpp_nccl.so` would replace the [RCCL](https://github.com/ROCm/rccl) library (i.e., `librccl.so`). -See limitations of the current NCCL over MSCCL++ from [here](../apps/nccl/README.md#limitations). +See limitations of the current NCCL over MSCCL++ from [here](../design/nccl-over-mscclpp.md#limitations). diff --git a/docs/getting-started/tutorials/customized-proxy-service.md b/docs/getting-started/tutorials/customized-proxy-service.md new file mode 100644 index 000000000..232f81066 --- /dev/null +++ b/docs/getting-started/tutorials/customized-proxy-service.md @@ -0,0 +1 @@ +# Customize the Proxy Service diff --git a/docs/getting-started/tutorials/index.rst b/docs/getting-started/tutorials/index.rst new file mode 100644 index 000000000..7ee91b194 --- /dev/null +++ b/docs/getting-started/tutorials/index.rst @@ -0,0 +1,16 @@ +Tutorials +---------- + +This tutorial section provides a step-by-step guide to help you get started with the C++/Python API. + +.. toctree:: + :maxdepth: 1 + :caption: Tutorials + :hidden: + + initialization + proxy-channel + sm-channel + packet-api + customized-proxy-service + python-api diff --git a/docs/getting-started/tutorials/initialization.md b/docs/getting-started/tutorials/initialization.md new file mode 100644 index 000000000..0bdd8ad45 --- /dev/null +++ b/docs/getting-started/tutorials/initialization.md @@ -0,0 +1,71 @@ +# Commnunication initialize with mscclpp API + +In this tutorial, you will write a simple program to initialize communication between eight GPUs using MSCCL++ C++ API. You will also learn how to use the Python API to initialize communication. + +## Prerequisites +A system with eight GPUs is required to run this tutorial. + +Also make sure that you have installed MSCCL++ on your system. If not, please follow the [quick start](../quickstart.md). + +## Initialize Communication with C++ API +We will setup a mesh topology with eight GPUs. Each GPU will be connected to its neighbors. The following code shows how to initialize communication with MSCCL++ C++ API. + +```cpp +#include +#include +#include + +#include +#include +#include + +template +using DeviceHandle = mscclpp::DeviceHandle; +__constant__ DeviceHandle constProxyChans[8]; + +void setupMeshTopology(int rank, int worldsize, void* data, size_t dataSize) { + std::string ip_port = "10.0.0.4:50000"; + auto bootstrap = std::make_shared(rank, worldsize); + bootstrap->initialize(ip_port); + mscclpp::Communicator comm(bootstrap); + mscclpp::ProxyService proxyService; + + std::vector semaphoreIds; + std::vector localMemories; + std::vector>> connections(world_size); + std::vector> remoteMemories; + + for (int r = 0; r < world_size; ++r) { + if (r == rank) continue; + mscclpp::Transport transport = mscclpp::Transport::CudaIpc; + // Connect with all other ranks + connections[r] = comm.connectOnSetup(r, 0, transport); + auto memory = comm.registerMemory(data, dataSize, mscclpp::Transport::CudaIpc | ibTransport); + localMemories.push_back(memory); + comm.sendMemoryOnSetup(memory, r, 0); + remoteMemories.push_back(comm.recvMemoryOnSetup(r, 0)); + } + + comm.setup(); + + for (int r = 0; r < world_size; ++r) { + if (r == rank) continue; + semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, connections[r].get())); + } + + comm.setup(); + + std::vector> proxyChannels; + for (size_t i = 0; i < semaphoreIds.size(); ++i) { + proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel( + proxyService.proxyChannel(semaphoreIds[i]), proxyService.addMemory(remoteMemories[i].get()), + proxyService.addMemory(localMemories[i])))); + } + + if (proxyChannels.size() > sizeof(constProxyChans) / sizeof(DeviceHandle)) { + std::runtime_error("unexpected error"); + } + CUDACHECK(cudaMemcpyToSymbol(constProxyChans, proxyChannels.data(), + sizeof(DeviceHandle) * proxyChannels.size())); +} +``` diff --git a/docs/getting-started/tutorials/packet-api.md b/docs/getting-started/tutorials/packet-api.md new file mode 100644 index 000000000..8f4ea7074 --- /dev/null +++ b/docs/getting-started/tutorials/packet-api.md @@ -0,0 +1 @@ +# Packet API for latency sensitive applications diff --git a/docs/getting-started/tutorials/proxy-channel.md b/docs/getting-started/tutorials/proxy-channel.md new file mode 100644 index 000000000..fec5c4cc0 --- /dev/null +++ b/docs/getting-started/tutorials/proxy-channel.md @@ -0,0 +1,3 @@ +# Offload commnunication to CPU with ProxyChannel + +TBU diff --git a/docs/getting-started/tutorials/python-api.md b/docs/getting-started/tutorials/python-api.md new file mode 100644 index 000000000..9e6c5627b --- /dev/null +++ b/docs/getting-started/tutorials/python-api.md @@ -0,0 +1,92 @@ +# Working with Python API + +We provide Python API which help to initialze and setup the channel easily. +In this tutorial, you will write a simple program to initialize communication between eight GPUs using MSCCL++ Python API. + +## Setup Channel with Python API + +We will setup a mesh topology with eight GPUs. Each GPU will be connected to its neighbors. The following code shows how to initialize communication with MSCCL++ Python API. +```python +from mpi4py import MPI +import cupy as cp + +from mscclpp import ( + ProxyService, + Transport, +) +import mscclpp.comm as mscclpp_comm + +def create_connection(group: mscclpp_comm.CommGroup, transport: str): + remote_nghrs = list(range(group.nranks)) + remote_nghrs.remove(group.my_rank) + if transport == "NVLink": + tran = Transport.CudaIpc + elif transport == "IB": + tran = group.my_ib_device(group.my_rank % 8) + else: + assert False + connections = group.make_connection(remote_nghrs, tran) + return connections + +if __name__ == "__main__": + mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD) + connections = create_connection(mscclpp_group, "NVLink") + nelems = 1024 + memory = cp.zeros(nelem, dtype=cp.int32) + proxy_service = ProxyService() + simple_channels = group.make_proxy_channels(proxy_service, memory, connections) + proxy_service.start_proxy() + mscclpp_group.barrier() + launch_kernel(mscclpp_group.my_rank, mscclpp_group.nranks, simple_channels, memory) + cp.cuda.runtime.deviceSynchronize() + mscclpp_group.barrier() +``` + +### Launch Kernel with Python API +We provide some Python utils to help you launch kernel via python. Here is a exampl. +```python +from mscclpp.utils import KernelBuilder, pack + +def launch_kernel(my_rank: int, nranks: int, simple_channels: List[SimpleProxyChannel], memory: cp.ndarray): + file_dir = os.path.dirname(os.path.abspath(__file__)) + kernel = KernelBuilder(file="test.cu", kernel_name="test", file_dir=file_dir).get_compiled_kernel() + params = b"" + first_arg = next(iter(simple_channels.values())) + size_of_channels = len(first_arg.device_handle().raw) + device_handles = [] + for rank in range(nranks): + if rank == my_rank: + device_handles.append( + bytes(size_of_channels) + ) # just zeros for semaphores that do not exist + else: + device_handles.append(simple_channels[rank].device_handle().raw) + # keep a reference to the device handles so that they don't get garbage collected + d_channels = cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8) + params = pack(d_channels, my_rank, nranks, memory.size) + + nblocks = 1 + nthreads = 512 + kernel.launch_kernel(params, nblocks, nthreads, 0, None) +``` + +The test kernel is defined in `test.cu` as follows: +```cuda +#include +#include + +// be careful about using channels[my_rank] as it is inavlie and it is there just for simplicity of indexing +extern "C" __global__ void __launch_bounds__(1024, 1) + simple_proxy_channel(mscclpp::SimpleProxyChannelDeviceHandle* channels, int my_rank, int nranks, + int num_elements) { + int tid = threadIdx.x; + int nthreads = blockDim.x; + uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks; + uint64_t my_offset = size_per_rank * my_rank; + __syncthreads(); + if (tid < nranks && tid != my_rank) { + channels[tid].putWithSignalAndFlush(my_offset, my_offset, size_per_rank); + channels[tid].wait(); + } +} +``` diff --git a/docs/getting-started/tutorials/sm-channel.md b/docs/getting-started/tutorials/sm-channel.md new file mode 100644 index 000000000..191e47b36 --- /dev/null +++ b/docs/getting-started/tutorials/sm-channel.md @@ -0,0 +1,3 @@ +# Using SmChannel for Intra-Node Communication + +TBU diff --git a/docs/index.rst b/docs/index.rst index ba060047c..dc5604364 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,11 +6,56 @@ Welcome to MSCCL++'s documentation! =================================== +MSCCL++ is a GPU-driven communication stack for scalable AI applications. It is designed to provide a high-performance, scalable, and customizable communication stack for distributed GPU applications. + +Getting Started +--------------- +- Follow the :doc:`quick start ` for your platform of choice. +- Take a look at the :doc:`tutorials ` to learn how to write your first mscclpp program. + +.. toctree:: + :maxdepth: 1 + :caption: Getting Started + :hidden: + + getting-started/quickstart + getting-started/tutorials/index + +Design +------- +- :doc:`Design ` doc for those who want to understand the internals of MSCCL++. +- :doc:`NCCL over MSCCL++ ` doc for those who want to understand how to use NCCL over MSCCL++. + +.. toctree:: + :maxdepth: 1 + :caption: Design + :hidden: + + design/design + design/nccl-over-mscclpp + +Performance +--------------- +- We evaluate the performance of MSCCL++ in A100 and H100. Here are some :doc:`performance results ` for all-reduce operations. + .. toctree:: - :maxdepth: 2 - :caption: Contents: + :maxdepth: 1 + :caption: Performance + :hidden: + + performance/performance-ndmv4 + +C++ API +--------------- +- :doc:`mscclpp ` +.. toctree:: + :maxdepth: 1 + :caption: C++ API + :hidden: + + api/index Indices and tables ================== @@ -18,9 +63,3 @@ Indices and tables * :ref:`genindex` * :ref:`modindex` * :ref:`search` - -Docs -==== - -.. doxygennamespace:: mscclpp - :members: diff --git a/docs/performance-ndmv4.md b/docs/performance/performance-ndmv4.md similarity index 100% rename from docs/performance-ndmv4.md rename to docs/performance/performance-ndmv4.md diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 000000000..82bb70d03 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,3 @@ +breathe +sphinx_rtd_theme +myst_parser From c6e06cfad7a49ad17ef1fff560498bf20ac8c427 Mon Sep 17 00:00:00 2001 From: Caio Rocha <164253795+caiomcbr@users.noreply.github.com> Date: Mon, 21 Oct 2024 05:45:56 -0700 Subject: [PATCH 10/12] Executor AllGather In-Place Support (#365) --- python/test/executor_test.py | 25 ++++++++--- src/executor/execution_plan.cc | 81 ++++++++++++++++++++++++---------- src/executor/executor.cc | 17 +++---- src/include/execution_plan.hpp | 13 +++--- 4 files changed, 95 insertions(+), 41 deletions(-) diff --git a/python/test/executor_test.py b/python/test/executor_test.py index 3e0c369d3..5fd59f2bb 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -77,6 +77,15 @@ def dtype_to_mscclpp_dtype(dtype): raise ValueError(f"Unknown data type: {dtype}") +def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name): + if "allgather" in execution_plan_name: + return recvbuf + elif in_place: + return sendbuf + else: + return recvbuf + + def main( execution_plan_name: str, execution_plan_path: str, @@ -104,9 +113,11 @@ def main( if "allgather" in execution_plan_name: recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype) + if in_place: + for i in range(nelems): + recvbuf[mscclpp_group.my_rank * nelems + i] = sendbuf[i] expected = buffer else: - cp.random.seed(seed) recvbuf = cp.zeros(nelems, dtype=dtype) expected = cp.zeros_like(sendbuf, dtype=dtype) for i in range(mscclpp_group.nranks): @@ -116,9 +127,9 @@ def main( executor_func = lambda stream: executor.execute( MPI.COMM_WORLD.rank, sendbuf.data.ptr, - sendbuf.data.ptr if in_place else recvbuf.data.ptr, + determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr, sendbuf.nbytes, - sendbuf.nbytes if in_place else recvbuf.nbytes, + determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes, dtype_to_mscclpp_dtype(dtype), execution_plan, stream.ptr, @@ -129,10 +140,14 @@ def main( executor_func(stream) stream.synchronize() - assert cp.allclose(sendbuf if in_place else recvbuf, expected, atol=1e-2 * mscclpp_group.nranks) + assert cp.allclose( + determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name), + expected, + atol=1e-2 * mscclpp_group.nranks, + ) mscclpp_group.barrier() - execution_time = bench_time(100, 10, executor_func) + execution_time = bench_time(10, 10, executor_func) if npkit_dump_dir is not None: npkit.dump(npkit_dump_dir) npkit.shutdown() diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc index bda49791a..09ebc6d8f 100644 --- a/src/executor/execution_plan.cc +++ b/src/executor/execution_plan.cc @@ -148,12 +148,19 @@ std::vector ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c } return std::vector(bufferTypes.begin(), bufferTypes.end()); } -size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize) const { +size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const { + size_t sizePerRank; + if (this->inputChunks.at(rank) != 0) + sizePerRank = inputSize / this->inputChunks.at(rank); + else if (this->outputChunks.at(rank) != 0) + sizePerRank = outputSize / this->outputChunks.at(rank); + else + throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError); + if (this->isUsingPacket) { - return inputSize / this->inputChunks.at(rank) * this->scratchChunks.at(rank) * 2 /* data + flag*/ * - 2 /*double buffer*/; + return sizePerRank * this->scratchChunks.at(rank) * 2 /* data + flag*/ * 2 /*double buffer*/; } - return inputSize / this->inputChunks.at(rank) * this->scratchChunks.at(rank); + return sizePerRank * this->scratchChunks.at(rank); } std::vector ExecutionPlan::Impl::getOperations(int rank, int threadblock) const { return this->operations.at(rank)[threadblock]; @@ -163,7 +170,8 @@ int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->oper int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; } -void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) { +void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, + size_t constDstOffset) { std::ifstream file(this->planPath); json obj = json::parse(file); if (this->name != obj["name"]) { @@ -186,10 +194,12 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOff this->setupChannels(gpus); this->inputSize = inputSize; + this->outputSize = outputSize; this->setupOperations(gpus, contsSrcOffset, constDstOffset); } -void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) { +void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, + size_t constDstOffset) { std::ifstream file(this->planPath); json obj = json::parse(file); if (this->name != obj["name"]) { @@ -210,6 +220,7 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t contsS } this->inputSize = inputSize; + this->outputSize = outputSize; this->setupOperations(gpus, contsSrcOffset, constDstOffset); } @@ -313,8 +324,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse // Get the relevant channel index in rank channelInfos operation.inputChannelIndexes[i] = channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["i_cids"][i]["id"]]; - operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["i_cids"][i]["off"]) + - (srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0); + operation.inputOffsets[i] = + this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["i_cids"][i]["off"]) + + (srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0); chunkIndexes.push_back((uint32_t)op["i_cids"][i]["off"]); } } @@ -323,8 +335,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse operation.nInputs = op["srcs"].size(); operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]); for (int i = 0; i < operation.nInputs; i++) { - operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["srcs"][i]["off"]) + - (operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0); + operation.inputOffsets[i] = + this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcs"][i]["off"]) + + (operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0); chunkIndexes.push_back((uint32_t)op["srcs"][i]["off"]); } } @@ -335,8 +348,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse BufferType dstBufferType = convertToBufferType(op["o_buff"]["dst"]); operation.outputChannelIndexes[i] = channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["o_cids"][i]["id"]]; - operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["o_cids"][i]["off"]) + - (dstBufferType != BufferType::SCRATCH ? constDstOffset : 0); + operation.outputOffsets[i] = + this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["o_cids"][i]["off"]) + + (dstBufferType != BufferType::SCRATCH ? constDstOffset : 0); chunkIndexes.push_back((uint32_t)op["o_cids"][i]["off"]); } } @@ -345,8 +359,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse operation.nOutputs = op["dsts"].size(); operation.outputBufferType = convertToBufferType(op["dsts"][0]["buff"]); for (int i = 0; i < operation.nOutputs; i++) { - operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["dsts"][i]["off"]) + - (operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0); + operation.outputOffsets[i] = + this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dsts"][i]["off"]) + + (operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0); chunkIndexes.push_back((uint32_t)op["dsts"][i]["off"]); } } @@ -354,18 +369,19 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse operation.srcBufferType = convertToBufferType(op["srcbuff"]); } if (op.contains("srcoff")) { - operation.srcOffset = this->getOffset(rank, this->inputSize, (uint32_t)op["srcoff"]); + operation.srcOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcoff"]); chunkIndexes.push_back((uint32_t)op["srcoff"]); } if (op.contains("dstbuff")) { operation.dstBufferType = convertToBufferType(op["dstbuff"]); } if (op.contains("dstoff")) { - operation.dstOffset = this->getOffset(rank, this->inputSize, (uint32_t)op["dstoff"]); + operation.dstOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dstoff"]); chunkIndexes.push_back((uint32_t)op["dstoff"]); } if (op.contains("cnt")) { - operation.size = this->getNChunkSize(rank, this->inputSize, (uint32_t)op["cnt"], chunkIndexes); + operation.size = + this->getNChunkSize(rank, this->inputSize, this->outputSize, (uint32_t)op["cnt"], chunkIndexes); } ops.push_back(operation); } @@ -374,14 +390,33 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse } } -size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, uint32_t chunkIndex, uint32_t alignment) const { +std::pair ExecutionPlan::Impl::calcSizePerRank(int rank, size_t inputSize, size_t outputSize) const { + std::pair sizePerRank; + if (this->inputChunks.at(rank) == 0 && this->outputChunks.at(rank) == 0) { + throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError); + } else if (this->inputChunks.at(rank) != 0 && this->outputChunks.at(rank) != 0) { + if (inputSize / this->inputChunks.at(rank) != outputSize / this->outputChunks.at(rank)) + throw mscclpp::Error("Size per chunks inconsistent", mscclpp::ErrorCode::ExecutorError); + else + sizePerRank = std::make_pair(inputSize, this->inputChunks.at(rank)); + } else if (this->inputChunks.at(rank) != 0) { + sizePerRank = std::make_pair(inputSize, this->inputChunks.at(rank)); + } else if (this->outputChunks.at(rank) != 0) { + sizePerRank = std::make_pair(outputSize, this->outputChunks.at(rank)); + } + return sizePerRank; +} + +size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, size_t outputSize, uint32_t chunkIndex, + uint32_t alignment) const { if (inputSize % alignment != 0) { throw Error("inputSize must be a multiple of alignment", ErrorCode::ExecutorError); } const int nGroups = this->chunkGroups.at(rank); - uint32_t nInputChunks = this->inputChunks.at(rank); - uint32_t nelems = inputSize / (alignment * sizeof(uint8_t)); + auto sizePerRank = calcSizePerRank(rank, inputSize, outputSize); + uint32_t nInputChunks = sizePerRank.second; + uint32_t nelems = sizePerRank.first / (alignment * sizeof(uint8_t)); if (nelems % nGroups != 0) { throw Error("Input size must be a multiple of nGroups", ErrorCode::ExecutorError); } @@ -397,12 +432,12 @@ size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, uint32_t chunk return static_cast(offset) * alignment; } -size_t ExecutionPlan::Impl::getNChunkSize(int rank, size_t inputSize, uint32_t nChunks, +size_t ExecutionPlan::Impl::getNChunkSize(int rank, size_t inputSize, size_t outputSize, uint32_t nChunks, const std::vector chunkIndexes) const { size_t nChunkSize = 0; for (uint32_t index : chunkIndexes) { - uint32_t beginOff = getOffset(rank, inputSize, index); - uint32_t endOff = getOffset(rank, inputSize, index + nChunks); + uint32_t beginOff = getOffset(rank, inputSize, outputSize, index); + uint32_t endOff = getOffset(rank, inputSize, outputSize, index + nChunks); if (nChunkSize == 0) { nChunkSize = endOff - beginOff; } else if (nChunkSize != endOff - beginOff) { diff --git a/src/executor/executor.cc b/src/executor/executor.cc index 54986d5d6..9de8a58be 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -80,13 +80,13 @@ struct Executor::Impl { } ~Impl() = default; - ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t messageSize, - size_t contsSrcOffset, size_t constDstOffset, size_t sendBufferSize, - size_t recvBufferSize, const ExecutionPlan& plan) { + ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t inputMessageSize, + size_t outputMessageSize, size_t contsSrcOffset, size_t constDstOffset, + size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan) { ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name}; if (this->contexts.find(key) != this->contexts.end()) { plan.impl_->operationsReset(); - plan.impl_->lightLoadExecutionPlan(messageSize, contsSrcOffset, constDstOffset); + plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset); this->setupDeviceExecutionPlan(this->contexts[key], rank, plan); this->contexts[key].deviceExecutionPlansBuffer = allocExtSharedCuda(this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan)); @@ -97,10 +97,10 @@ struct Executor::Impl { } plan.impl_->reset(); - plan.impl_->loadExecutionPlan(messageSize, contsSrcOffset, constDstOffset); + plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset); ExecutionContext context; - size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize); + size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize, recvBufferSize); std::shared_ptr scratchBuffer = allocExtSharedCuda(scratchBufferSize); context.scratchBuffer = scratchBuffer; context.scratchBufferSize = scratchBufferSize; @@ -350,8 +350,9 @@ void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuff size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr; size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr; - ExecutionContext context = this->impl_->setupExecutionContext( - rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, offsetIn, offsetOut, sendBytes, recvBytes, plan); + ExecutionContext context = + this->impl_->setupExecutionContext(rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, recvBuffSize, + offsetIn, offsetOut, sendBytes, recvBytes, plan); this->impl_->launchKernel(context, rank, sendbuff, recvbuff, dataType, stream, packetType); } diff --git a/src/include/execution_plan.hpp b/src/include/execution_plan.hpp index ce6ec8ca2..a44962782 100644 --- a/src/include/execution_plan.hpp +++ b/src/include/execution_plan.hpp @@ -65,13 +65,13 @@ struct ExecutionPlan::Impl { std::vector getUnpairedChannelInfos(int rank, int worldSize, ChannelType channelType); std::vector getConnectedPeers(int rank) const; std::vector getConnectedBufferTypes(int rank) const; - size_t getScratchBufferSize(int rank, size_t inputSize) const; + size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const; std::vector getOperations(int rank, int threadblock) const; int getThreadblockCount(int rank) const; int getNThreadsPerBlock() const; - void loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset); - void lightLoadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset); + void loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset); + void lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset); void setupChannels(const nlohmann::json& gpus); void setupOperations(const nlohmann::json& gpus, size_t contsSrcOffset, size_t constDstOffset); @@ -94,11 +94,14 @@ struct ExecutionPlan::Impl { std::unordered_map scratchChunks; std::unordered_map chunkGroups; size_t inputSize; + size_t outputSize; int nThreadsPerBlock; private: - size_t getOffset(int rank, size_t inputSize, uint32_t chunkIndex, uint32_t alignment = 16) const; - size_t getNChunkSize(int rank, size_t inputSize, uint32_t nChunks, const std::vector offsets) const; + std::pair calcSizePerRank(int rank, size_t inputSize, size_t outputSize) const; + size_t getOffset(int rank, size_t inputSize, size_t outputSize, uint32_t chunkIndex, uint32_t alignment = 16) const; + size_t getNChunkSize(int rank, size_t inputSize, size_t outputSize, uint32_t nChunks, + const std::vector offsets) const; }; } // namespace mscclpp From 582d386b3b009f202fd11a0e392b76485e8b3445 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 21 Oct 2024 19:59:15 -0700 Subject: [PATCH 11/12] Fix algo repo name (#369) Change algo repo name from azure-mscclpp to msccl-users Co-authored-by: Changho Hwang --- .azure-pipelines/integration-test-rocm.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.azure-pipelines/integration-test-rocm.yml b/.azure-pipelines/integration-test-rocm.yml index 029259ffb..3315dfa9a 100644 --- a/.azure-pipelines/integration-test-rocm.yml +++ b/.azure-pipelines/integration-test-rocm.yml @@ -62,8 +62,8 @@ jobs: targetType: 'inline' script: | set -e - git clone https://$(GIT_USER):$(GIT_PAT)@msazure.visualstudio.com/DefaultCollection/One/_git/azure-mscclpp - cd azure-mscclpp + git clone https://$(GIT_USER):$(GIT_PAT)@msazure.visualstudio.com/DefaultCollection/One/_git/msccl-users + cd msccl-users mkdir execution-files python3 algos/allreduce_mi300_packet.py 8 8 > execution-files/allreduce_mi300_packet.json python3 algos/allreduce_mi300_sm_mscclpp.py 8 8 > execution-files/allreduce_mi300_sm_mscclpp.json @@ -89,8 +89,8 @@ jobs: set -e export PATH=/usr/local/mpi/bin:$PATH sudo /usr/local/mpi/bin/mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=$(pwd)/build/apps/nccl/libmscclpp_nccl.so -x NCCL_DEBUG=WARN \ - -x ALLREDUCEPKT_IP_JSON_FILE=./azure-mscclpp/execution-files/allreduce_mi300_packet.json \ - -x ALLREDUCE_IP_JSON_FILE=./azure-mscclpp/execution-files/allreduce_mi300_sm_mscclpp.json \ + -x ALLREDUCEPKT_IP_JSON_FILE=./msccl-users/execution-files/allreduce_mi300_packet.json \ + -x ALLREDUCE_IP_JSON_FILE=./msccl-users/execution-files/allreduce_mi300_sm_mscclpp.json \ -x ALLREDUCE_SMALL_MSG_BOUNDARY=32K -x ALLREDUCE_LARGE_MSG_BOUNDARY=1M ./rccl-tests/build/all_reduce_perf \ -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 100 workingDirectory: '$(System.DefaultWorkingDirectory)' From b72decbfeb3dd6ee0ab980b0ae4d5752e63ac809 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 21 Oct 2024 21:51:28 -0700 Subject: [PATCH 12/12] Update docker image for cuda12.4 (#370) Update docker image for cuda12.4 Image pushed to registry --------- Co-authored-by: Changho Hwang --- docker/base-dev-x.dockerfile | 2 +- docker/base-x-rocm.dockerfile | 2 +- docker/base-x.dockerfile | 13 ++++++++----- docker/build.sh | 17 +++++++++++++++-- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/docker/base-dev-x.dockerfile b/docker/base-dev-x.dockerfile index d6236cd02..5aeaa4142 100644 --- a/docker/base-dev-x.dockerfile +++ b/docker/base-dev-x.dockerfile @@ -2,7 +2,7 @@ ARG BASE_IMAGE FROM ${BASE_IMAGE} LABEL maintainer="MSCCL++" -LABEL org.opencontainers.image.source https://github.com/microsoft/mscclpp +LABEL org.opencontainers.image.source=https://github.com/microsoft/mscclpp RUN apt-get update && \ apt-get install -y --no-install-recommends \ diff --git a/docker/base-x-rocm.dockerfile b/docker/base-x-rocm.dockerfile index 5865cc397..cffa68cc5 100644 --- a/docker/base-x-rocm.dockerfile +++ b/docker/base-x-rocm.dockerfile @@ -2,7 +2,7 @@ ARG BASE_IMAGE FROM ${BASE_IMAGE} LABEL maintainer="MSCCL++" -LABEL org.opencontainers.image.source https://github.com/microsoft/mscclpp +LABEL org.opencontainers.image.source=https://github.com/microsoft/mscclpp ENV DEBIAN_FRONTEND=noninteractive diff --git a/docker/base-x.dockerfile b/docker/base-x.dockerfile index a1ba20693..45f39c709 100644 --- a/docker/base-x.dockerfile +++ b/docker/base-x.dockerfile @@ -2,7 +2,7 @@ ARG BASE_IMAGE FROM ${BASE_IMAGE} LABEL maintainer="MSCCL++" -LABEL org.opencontainers.image.source https://github.com/microsoft/mscclpp +LABEL org.opencontainers.image.source=https://github.com/microsoft/mscclpp ENV DEBIAN_FRONTEND=noninteractive USER root @@ -17,6 +17,7 @@ RUN apt-get update && \ git \ libcap2 \ libnuma-dev \ + lsb-release \ openssh-client \ openssh-server \ python3-dev \ @@ -31,11 +32,13 @@ RUN apt-get update && \ rm -rf /var/lib/apt/lists/* /tmp/* # Install OFED -ENV OFED_VERSION=5.2-2.2.3.0 +ARG OFED_VERSION=5.2-2.2.3.0 RUN cd /tmp && \ - wget -q https://content.mellanox.com/ofed/MLNX_OFED-${OFED_VERSION}/MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \ - tar xzf MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \ - MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64/mlnxofedinstall --user-space-only --without-fw-update --force --all && \ + OS_VERSION=$(lsb_release -rs) && \ + OS_VERSION=ubuntu${OS_VERSION} && \ + wget -q https://content.mellanox.com/ofed/MLNX_OFED-${OFED_VERSION}/MLNX_OFED_LINUX-${OFED_VERSION}-${OS_VERSION}-x86_64.tgz && \ + tar xzf MLNX_OFED_LINUX-${OFED_VERSION}-${OS_VERSION}-x86_64.tgz && \ + MLNX_OFED_LINUX-${OFED_VERSION}-${OS_VERSION}-x86_64/mlnxofedinstall --user-space-only --without-fw-update --without-ucx-cuda --force --all && \ rm -rf /tmp/MLNX_OFED_LINUX-${OFED_VERSION}* # Install OpenMPI diff --git a/docker/build.sh b/docker/build.sh index d8af5f8fd..3e2169f68 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -8,6 +8,7 @@ baseImageTable=( ["cuda12.1"]="nvidia/cuda:12.1.1-devel-ubuntu20.04" ["cuda12.2"]="nvidia/cuda:12.2.2-devel-ubuntu20.04" ["cuda12.3"]="nvidia/cuda:12.3.2-devel-ubuntu20.04" + ["cuda12.4"]="nvidia/cuda:12.4.1-devel-ubuntu22.04" ["rocm6.2"]="rocm/rocm-terminal:6.2" ) @@ -20,11 +21,16 @@ extraLdPathTable=( ["rocm6.2"]="/opt/rocm/lib" ) +declare -A ofedVersionTable +ofedVersionTable=( + ["cuda12.4"]="23.07-0.5.1.2" +) + GHCR="ghcr.io/microsoft/mscclpp/mscclpp" TARGET=${1} print_usage() { - echo "Usage: $0 [cuda11.8|cuda12.1|cuda12.2|cuda12.3|rocm6.2]" + echo "Usage: $0 [cuda11.8|cuda12.1|cuda12.2|cuda12.3|cuda12.4|rocm6.2]" } if [[ ! -v "baseImageTable[${TARGET}]" ]]; then @@ -38,11 +44,18 @@ SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" cd ${SCRIPT_DIR}/.. +DEFAULT_OFED_VERSION="5.2-2.2.3.0" +OFED_VERSION=${ofedVersionTable[${TARGET}]} +if [[ -z ${OFED_VERSION} ]]; then + OFED_VERSION=${DEFAULT_OFED_VERSION} +fi + docker build -t ${GHCR}-common:base-${TARGET} \ -f docker/base-x.dockerfile \ --build-arg BASE_IMAGE=${baseImageTable[${TARGET}]} \ --build-arg EXTRA_LD_PATH=${extraLdPathTable[${TARGET}]} \ - --build-arg TARGET=${TARGET} . + --build-arg TARGET=${TARGET} \ + --build-arg OFED_VERSION=${OFED_VERSION} . if [[ ${TARGET} == rocm* ]]; then echo "Building ROCm base image..."