From 4b3e27c555e7e6d4be824e586f96cf06f0c5272b Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 21 Feb 2024 11:17:26 +0000 Subject: [PATCH] allreduce7 works --- apps/nccl/src/nccl.cu | 91 +++++++++++++++++++++++++++++ src/bootstrap/bootstrap.cc | 2 +- test/mscclpp-test/allreduce_test.cu | 10 +++- 3 files changed, 100 insertions(+), 3 deletions(-) diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index ec47362cf..39f08f411 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -96,6 +96,21 @@ __forceinline__ __device__ int add_vectors<__half>(int a, int b) { return add_vectors_helper<__half2>(a, b); } +template +__forceinline__ __device__ uint32_t add_vectors_helper(uint32_t a, uint32_t b) { + return bit_cast(add_elements(bit_cast(a), bit_cast(b))); +} + +template +__forceinline__ __device__ uint32_t add_vectors(uint32_t a, uint32_t b) { + return add_vectors_helper(a, b); +} + +template <> +__forceinline__ __device__ uint32_t add_vectors<__half>(uint32_t a, uint32_t b) { + return add_vectors_helper<__half2>(a, b); +} + template __forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem, int blockId, int nBlocks) { size_t nInt4 = nElem / 4; @@ -310,11 +325,86 @@ __global__ void allreduce1(T* src, T* dst, int rank, int nranks, size_t nelems) } } +template +__global__ void allreduce7(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPerNode, int worldSize, + size_t nelems) { + // This version of allreduce only works for single nodes + if (worldSize != nRanksPerNode) return; + nelems = nelems / (sizeof(int) / sizeof(T)); + const int nPeers = nRanksPerNode - 1; + const size_t nPkts = nelems; + const int nelemsPerRank = nelems / worldSize; + const int nPktsPerRank = nelemsPerRank; + // flag for packets. Initially 1 + const uint32_t flag = (uint32_t)globalFlag; + // thread block & channel info + const int nBlocksPerPeer = gridDim.x / nPeers; + const int localBlockIdx = blockIdx.x % nBlocksPerPeer; + const int peerIdx = blockIdx.x / nBlocksPerPeer; + const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1; + const int tid = threadIdx.x + localBlockIdx * blockDim.x; + // double buffering + size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LL8Packet); + void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset); + size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LL8Packet); + size_t scratchResultOffset = + (flag & 1) ? 2 * nPkts * sizeof(mscclpp::LL8Packet) : 3 * nPkts * sizeof(mscclpp::LL8Packet); + size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int); + uint32_t* src = (uint32_t*)((char*)buff + rank * nelemsPerRank * sizeof(int)); + uint32_t* dst = (uint32_t*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int)); + + // step 1: write to scratch buffer + constSmChannels[peerIdx].putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, + blockDim.x * nBlocksPerPeer, flag); + // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) { + uint32_t data = 0; + for (int index = 0; index < nPeers; index++) { + const int remoteRank = index < rank ? index : index + 1; + mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nPktsPerRank; + uint32_t val = dstPkt[idx].read(flag); + data = add_vectors(val, data); + } + data = add_vectors(data, src[idx]); + dst[idx] = data; + + mscclpp::LL8Packet packet; + packet.data = data; + packet.flag = flag; + size_t offset = scratchResultOffset / sizeof(mscclpp::LL8Packet) + (idx + rank * nPktsPerRank); + for (int index = 0; index < nPeers; index++) { + constSmChannels[index].write(offset, packet); + } + } + // step 3: get data result from scratch buffer + mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)((char*)scratch + scratchResultOffset); + const int dstOffset = remoteRank * nPktsPerRank; + uint32_t* result = (uint32_t*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int)); + for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) { + uint32_t data = dstPkt[idx + dstOffset].read(flag); + result[idx] = data; + } + if (threadIdx.x == 0 && blockIdx.x == 0) { + globalFlag += 1; + } +} + template cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPerNode, int worldSize, size_t nelems, cudaStream_t stream) { if (sizeof(T) * nelems <= (1 << 20)) { +#if defined(__HIP_PLATFORM_AMD__) + int nBlocks = 28; + int nThreadsPerBlock = 1024; + if (nelems >= 8192) { + nBlocks = 56; + nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024; + } + allreduce7<<>>(buff, scratch, resultBuff, rank, nRanksPerNode, worldSize, + nelems); +#else allreduce6<<<21, 512, 0, stream>>>(buff, scratch, resultBuff, rank, nRanksPerNode, worldSize, nelems); +#endif } else { allreduce1<<<24, 1024, 0, stream>>>(buff, resultBuff, rank, worldSize, nelems); } @@ -511,6 +601,7 @@ NCCL_API ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) { } NCCL_API ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* device) { + if (comm == nullptr || device == nullptr) return ncclInvalidArgument; *device = comm->comm->bootstrap()->getRank(); return ncclSuccess; } diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index 6ae8b2c10..c9cea10f4 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -187,7 +187,7 @@ void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t tim } if (!netInitialized) { - netInit("", "", netIfAddr_); + netInit(ipPortPair, interface, netIfAddr_); netInitialized = true; } diff --git a/test/mscclpp-test/allreduce_test.cu b/test/mscclpp-test/allreduce_test.cu index cbedcefd0..56de36159 100644 --- a/test/mscclpp-test/allreduce_test.cu +++ b/test/mscclpp-test/allreduce_test.cu @@ -1138,9 +1138,15 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) { tmpBuff = scratchPacketBuff; nThreadsPerBlock = 512; } else if (kernelNum == 7) { - nBlocks = 28; tmpBuff = scratchPacketBuff; - nThreadsPerBlock = 1024; + // tune the #blocks and #threads for MI300X + if (paramCount_ < 8192) { + nBlocks = 28; + nThreadsPerBlock = 1024; + } else { + nBlocks = 56; + nThreadsPerBlock = (paramCount_ <= 76800) ? 512 : 1024; + } } else { nBlocks = std::max(args.nRanksPerNode - 1, 1) * BLOCKS_PER_PEER; tmpBuff = scratchPacketBuff;