From d208486e5503efa41cc420f2ad286a05ad89242b Mon Sep 17 00:00:00 2001
From: Changho Hwang
Date: Mon, 15 Jan 2024 19:18:48 +0000
Subject: [PATCH] Add allgather5

---
 test/mscclpp-test/allgather_test.cu | 55 +++++++++++++++++++++++++++--
 test/mscclpp-test/common.cc         | 20 ++++++-----
 test/mscclpp-test/common.hpp        |  2 +-
 3 files changed, 65 insertions(+), 12 deletions(-)

diff --git a/test/mscclpp-test/allgather_test.cu b/test/mscclpp-test/allgather_test.cu
index 4b2eff78f..495236a88 100644
--- a/test/mscclpp-test/allgather_test.cu
+++ b/test/mscclpp-test/allgather_test.cu
@@ -23,7 +23,7 @@ using DeviceHandle = mscclpp::DeviceHandle<T>;
 __constant__ DeviceHandle<mscclpp::SimpleProxyChannel> constProxyChans[16];
 __constant__ DeviceHandle<mscclpp::ProxyChannel> constRawProxyChan[16];
 
-__constant__ DeviceHandle<mscclpp::SmChannel> constSmChans[8];
+__constant__ DeviceHandle<mscclpp::SmChannel> constSmChans[256];
 
 __global__ void allgather0(int rank, size_t nelemsPerGPU) {
   int warpId = threadIdx.x / WARP_SIZE;
@@ -288,6 +288,49 @@ __global__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t ne
                            nBlocksForLocalAllGather);
 }
 
+__global__ void allgather5(int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU) {
+  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  const int lid = tid % WARP_SIZE;
+  const int wid = tid / WARP_SIZE;
+  const int nWarp = blockDim.x * gridDim.x / WARP_SIZE;
+  const int nPeer = nRanksPerNode - 1;
+  const int chanOffset = nPeer * blockIdx.x;
+  auto smChans = constSmChans + chanOffset;
+
+  if (wid < nPeer) {
+    smChans[wid].signal();
+    smChans[wid].wait();
+  }
+  __syncthreads();
+  constexpr size_t unitBytesPerThread = 16;
+  constexpr size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE;
+  const size_t unitBytes = unitBytesPerWarp * nWarp;
+  const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
+  const size_t bytes = bytesPerGPU * nPeer;
+  const size_t nLoop = bytes / unitBytes;
+  for (size_t i = 0; i < nLoop; ++i) {
+    const size_t gWid = wid + i * nWarp;
+    const int peerIdx = gWid % nPeer;
+    const int remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
+    const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp;
+    smChans[peerIdx].get(offset, unitBytesPerWarp, lid, WARP_SIZE);
+  }
+
+  if (bytes % unitBytes > 0) {
+    const size_t gWid = wid + nLoop * nWarp;
+    const int peerIdx = gWid % nPeer;
+    const int remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
+    const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
+    const size_t offset = bytesPerGPU * remoteRankLocalIndex + offsetWithinRank;
+    const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU)
+                                   ? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0)
+                                   : unitBytesPerWarp;
+    if (remainBytes > 0) {
+      smChans[peerIdx].get(offset, remainBytes, lid, WARP_SIZE);
+    }
+  }
+}
+
 class AllGatherProxyService : public mscclpp::BaseProxyService {
  public:
   AllGatherProxyService(int worldSize, int rank, int cudaDevice);
@@ -387,6 +430,9 @@ void AllGatherTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
   if (kernelNum == 4) {
     nBlocks = 21;
     nThreads = 1024;
+  } else if (kernelNum == 5) {
+    nBlocks = 32;
+    nThreads = 1024;
   } else {
     nBlocks = 1;
     nThreads = WARP_SIZE * (worldSize - 1);
@@ -401,6 +447,8 @@ void AllGatherTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
     allgather3<<<nBlocks, nThreads, 0, stream>>>();
   } else if (kernelNum == 4) {
     allgather4<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_);
+  } else if (kernelNum == 5) {
+    allgather5<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_);
   }
 }
 
@@ -453,7 +501,8 @@ std::vector<KernelRestriction> AllGatherTestColl::getKernelRestrictions() {
           {1, "allgather1", false, 1, 4 * worldSize_},
           {2, "allgather2", true, 3, 4 * worldSize_},
           {3, "allgather3", true, 1, 4 * worldSize_},
-          {4, "allgather4", true, 3, 16 * worldSize_ /*use ulong2 to transfer data*/}};
+          {4, "allgather4", true, 3, 16 * worldSize_ /*use ulong2 to transfer data*/},
+          {5, "allgather5", false, 1, 16 * worldSize_ /*use ulong2 to transfer data*/}};
 }
 
 class AllGatherTestEngine : public BaseTestEngine {
@@ -494,7 +543,7 @@ void AllGatherTestEngine::setupConnections() {
   CUDATHROW(cudaMemcpyToSymbol(constProxyChans, devProxyChannels.data(),
                                sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>) * devProxyChannels.size()));
 
-  setupMeshConnections(smChannels_, sendBuff_.get(), args_.maxBytes);
+  setupMeshConnections(smChannels_, sendBuff_.get(), args_.maxBytes, nullptr, 0, ChannelSemantic::PUT, 32);
   std::vector<DeviceHandle<mscclpp::SmChannel>> smChannelHandles(smChannels_.size());
   if (smChannels_.size() > sizeof(constSmChans) / sizeof(DeviceHandle<mscclpp::SmChannel>)) {
     std::runtime_error("unexpected error");
diff --git a/test/mscclpp-test/common.cc b/test/mscclpp-test/common.cc
index c5653b3fc..fe07df511 100644
--- a/test/mscclpp-test/common.cc
+++ b/test/mscclpp-test/common.cc
@@ -428,7 +428,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Simp
 
 void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff,
                                           size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes,
-                                          ChannelSemantic semantic) {
+                                          ChannelSemantic semantic, size_t nChannelPerConnection) {
   const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum];
   mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports);
   mscclpp::RegisteredMemory getPacketBufRegMem;
@@ -443,19 +443,23 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
       (outputBuff && semantic == ChannelSemantic::PUT) ? outputBufRegMem : inputBufRegMem;
   setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);
 
-  std::unordered_map<size_t, std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
+  std::unordered_map<size_t, std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>>> smSemaphores;
   for (size_t cid = 0; cid < connections.size(); ++cid) {
     if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
-      smSemaphores.emplace(cid, std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, connections[cid]));
+      for (size_t i = 0; i < nChannelPerConnection; ++i) {
+        smSemaphores[cid].emplace_back(std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, connections[cid]));
+      }
     }
   }
   comm_->setup();
 
-  for (size_t cid = 0; cid < connections.size(); ++cid) {
-    if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
-      smChannels.emplace_back(smSemaphores[cid], remoteRegMemories[cid].get(),
-                              (outputBuff && semantic == ChannelSemantic::GET) ? outputBuff : inputBufRegMem.data(),
-                              nullptr);
+  for (size_t i = 0; i < nChannelPerConnection; ++i) {
+    for (size_t cid = 0; cid < connections.size(); ++cid) {
+      if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
+        smChannels.emplace_back(smSemaphores[cid][i], remoteRegMemories[cid].get(),
+                                (outputBuff && semantic == ChannelSemantic::GET) ? outputBuff : inputBufRegMem.data(),
+                                nullptr);
+      }
     }
   }
 }
diff --git a/test/mscclpp-test/common.hpp b/test/mscclpp-test/common.hpp
index 665ff9119..7e3e8c423 100644
--- a/test/mscclpp-test/common.hpp
+++ b/test/mscclpp-test/common.hpp
@@ -118,7 +118,7 @@ class BaseTestEngine {
                             SetupChannelFunc setupChannel = nullptr);
   void setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff, size_t inputBuffBytes,
                             void* outputBuff = nullptr, size_t outputBuffBytes = 0,
-                            ChannelSemantic semantic = ChannelSemantic::PUT);
+                            ChannelSemantic semantic = ChannelSemantic::PUT, size_t nChannelPerConnection = 1);
   void setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels,
                            std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>>& proxyChannels, void* inputBuff,
                             size_t inputBuffBytes, void* putPacketBuff = nullptr, size_t putPacketBuffBytes = 0,
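
For reference, the work split in allgather5 can be modeled on the host. The snippet below is not part of the patch; it is a sketch that reuses the kernel's index arithmetic under assumed values (WARP_SIZE = 32, the 32x1024 launch shape selected in runColl for kernel 5, 8 ranks per node, and an arbitrary 1 MiB per remote rank). Global warp IDs are assigned round-robin across peers, and successive 512-byte chunks of one peer's buffer go to warps that are nPeer apart; the remainder block then clamps the final partial chunk per peer.

// Host-side model of allgather5's warp-to-peer mapping (illustration only; all values
// below are assumptions for the sketch, not taken from the test's runtime configuration).
#include <cstddef>
#include <cstdio>

int main() {
  const int warpSize = 32;
  const int nBlocks = 32, nThreads = 1024;           // launch shape used for kernelNum == 5
  const int nWarp = nBlocks * nThreads / warpSize;   // warps in flight
  const int rank = 0, nRanksPerNode = 8;
  const int nPeer = nRanksPerNode - 1;

  const size_t unitBytesPerWarp = 16 * warpSize;     // 512 B copied by one warp per step
  const size_t unitBytes = unitBytesPerWarp * nWarp; // bytes covered by one full sweep
  const size_t bytesPerGPU = 1 << 20;                // example: 1 MiB to fetch per remote rank
  const size_t bytes = bytesPerGPU * nPeer;
  const size_t nLoop = bytes / unitBytes;

  // Print the assignment of the first few global warp IDs.
  for (size_t gWid = 0; gWid < 16; ++gWid) {
    const int peerIdx = gWid % nPeer;
    const int remoteRankLocalIndex = (peerIdx < rank) ? peerIdx : peerIdx + 1;
    const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp;
    std::printf("gWid %2zu -> peer %d (local rank %d), offset %zu, %zu bytes\n", gWid, peerIdx,
                remoteRankLocalIndex, offset, unitBytesPerWarp);
  }
  std::printf("full sweeps: %zu, tail bytes handled by the remainder block: %zu\n", nLoop, bytes % unitBytes);
  return 0;
}

The same arithmetic appears verbatim in the kernel; only the loop bounds and printing are added here so the distribution can be inspected without a GPU.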
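On the host side, the reworked setupMeshConnections builds nChannelPerConnection SmChannels per CUDA IPC connection, each with its own SmDevice2DeviceSemaphore, and appends them channel-index-major (outer loop over the channel index, inner loop over connections). For the single-node case this patch targets, every connection is CUDA IPC, so the handle for channel set b and peer p lands at index b * nPeer + p, which is exactly what allgather5's "constSmChans + nPeer * blockIdx.x" assumes; with 8 GPUs per node, 32 * 7 = 224 handles fit in the enlarged constSmChans[256]. The sketch below only illustrates that layout; the constants and names are local to the sketch.

// Illustration of the channel-index-major layout assumed above (not part of the patch).
#include <cstdio>

int main() {
  const int nRanksPerNode = 8, nPeer = nRanksPerNode - 1;
  const int nChannelPerConnection = 32;                    // one channel set per thread block
  const int totalHandles = nChannelPerConnection * nPeer;  // 224, fits constSmChans[256]
  for (int block = 0; block < 2; ++block) {
    for (int peer = 0; peer < nPeer; ++peer) {
      std::printf("block %d, peer %d -> constSmChans[%d]\n", block, peer, block * nPeer + peer);
    }
  }
  std::printf("total handles copied to constSmChans: %d\n", totalHandles);
  return 0;
}

Because each block gets its own semaphores, the per-block signal()/wait() handshake at the top of allgather5 does not interfere across blocks.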