From d92e248b9811da908315074e49f3b8719c07b14a Mon Sep 17 00:00:00 2001
From: Changho Hwang
Date: Thu, 22 Feb 2024 01:20:37 +0000
Subject: [PATCH] allgather works

---
 apps/nccl/rccl_test.py |  33 +++++++----
 apps/nccl/src/nccl.cu  | 128 +++++++++++++++++++++++++++++++++++++----
 2 files changed, 140 insertions(+), 21 deletions(-)

diff --git a/apps/nccl/rccl_test.py b/apps/nccl/rccl_test.py
index 9cdbed1ac..298d26eee 100644
--- a/apps/nccl/rccl_test.py
+++ b/apps/nccl/rccl_test.py
@@ -38,16 +38,29 @@
 a = torch.ones(shape_a, device="cuda")
 b = torch.ones(shape_b, device="cuda")
 c = torch.mm(a, b)
-
+
 print(c)
-
-nccl_op = nccl.NCCL_SUM
-group.allReduce(
-    sendbuf=c.data_ptr(),
-    recvbuf=c.data_ptr(),
-    count=c.nelement(),
+
+# nccl_op = nccl.NCCL_SUM
+# group.allReduce(
+#     sendbuf=c.data_ptr(),
+#     recvbuf=c.data_ptr(),
+#     count=c.nelement(),
+#     datatype=nccl.NCCL_FLOAT,
+#     op=nccl_op,
+#     stream=torch.cuda.current_stream().cuda_stream)
+
+# print(c)
+
+d = torch.ones((1024*1024,), device="cuda")
+e = torch.zeros((8*1024*1024,), device="cuda")
+# Pre-fill the local chunk: the allGather below only copies in the remote ranks' chunks.
+e[rank*1024*1024:(rank+1)*1024*1024] = d
+
+group.allGather(
+    sendbuf=d.data_ptr(),
+    recvbuf=e.data_ptr(),
+    count=d.nelement(),
     datatype=nccl.NCCL_FLOAT,
-    op=nccl_op,
     stream=torch.cuda.current_stream().cuda_stream)
-
-print(c)
+
+print(e)
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index 715211c7d..73d9d6b2d 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -21,6 +21,14 @@
   }                                                                   \
 } while (0)
 
+#define NUM_CHANNELS_PER_CONNECTION 32
+
+#if defined(__HIP_PLATFORM_AMD__)
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif
+
 template <typename To, typename From>
 __forceinline__ __device__ To bit_cast(const From& src) {
   static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
@@ -143,8 +151,8 @@ static const int scratchSize = 1024 * 1024 * 8;
 //                                               mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
 //                                               mscclpp::Transport::IB6, mscclpp::Transport::IB7};
 
-__constant__ mscclpp::DeviceHandle<mscclpp::SmChannel> constSmChannels[8];
-__constant__ mscclpp::DeviceHandle<mscclpp::SmChannel> constSmOutChannels[8];
+__constant__ mscclpp::DeviceHandle<mscclpp::SmChannel> constSmChannels[256];
+__constant__ mscclpp::DeviceHandle<mscclpp::SmChannel> constSmOutChannels[256];
 __device__ mscclpp::DeviceSyncer deviceSyncer;
 
 struct channelKey {
@@ -411,6 +419,76 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPe
   return cudaGetLastError();
 }
 
+__global__ void __launch_bounds__(1024, 1)
+    allgather5(size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
+  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+  const size_t lid = tid % WARP_SIZE;
+  const size_t wid = tid / WARP_SIZE;
+
+  const size_t nThread = blockDim.x * gridDim.x;
+  const size_t nWarp = nThread / WARP_SIZE;
+  const size_t nPeer = nRanksPerNode - 1;
+  // Each block uses its own group of nPeer channels.
+  const size_t chanOffset = nPeer * blockIdx.x;
+  auto smChans = constSmChannels + chanOffset;
+
+  if (wid < nPeer && lid == 0) {
+    smChans[wid].relaxedSignal();
+    smChans[wid].wait();
+  }
+  __syncthreads();
+
+  const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
+  const size_t bytes = bytesPerGPU * nPeer;
+  // Copy in 64B units per thread when there is enough data to keep every
+  // thread busy; otherwise fall back to 16B units.
+  size_t unitBytesPerThread;
+  if (bytes >= nThread * 64) {
+    unitBytesPerThread = 64;
+  } else {
+    unitBytesPerThread = 16;
+  }
+  const size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE;
+  const size_t unitBytes = unitBytesPerWarp * nWarp;
+  const size_t nLoop = bytes / unitBytes;
+
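+  // Work distribution: warps are assigned round-robin across the nPeer remote
+  // ranks, so warp w reads from peer (w % nPeer) at per-rank offset
+  // (w / nPeer) * unitBytesPerWarp. Only remote chunks are copied; the caller
+  // is responsible for placing the local rank's chunk into the output buffer
+  // (rccl_test.py stages it before calling allGather).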
+  if (nLoop > 0) {
+    // First loop unrolling
+    const size_t peerIdx = wid % nPeer;
+    const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
+    const size_t offset = bytesPerGPU * remoteRankLocalIndex + (wid / nPeer) * unitBytesPerWarp;
+    smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
+  }
+
+  for (size_t i = 1; i < nLoop; ++i) {
+    const size_t gWid = wid + i * nWarp;
+    const size_t peerIdx = gWid % nPeer;
+    const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
+    const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp;
+    smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
+  }
+
+  if (bytes % unitBytes > 0) {
+    // Tail: clamp the final unit to the bytes remaining in the peer's chunk.
+    const size_t gWid = wid + nLoop * nWarp;
+    const size_t peerIdx = gWid % nPeer;
+    const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
+    const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
+    const size_t offset = bytesPerGPU * remoteRankLocalIndex + offsetWithinRank;
+    const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU)
+                                   ? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0)
+                                   : unitBytesPerWarp;
+    if (remainBytes > 0) {
+      smChans[peerIdx].get<16, true>(offset, remainBytes, lid, WARP_SIZE);
+    }
+  }
+}
+
+template <typename T>
+cudaError_t allgather(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPerNode, int worldSize, size_t nelems,
+                      cudaStream_t stream) {
+  allgather5<<<24, 1024, 0, stream>>>(rank, worldSize, nRanksPerNode, nelems);
+  return cudaGetLastError();
+}
+
 static size_t ncclTypeSize(ncclDataType_t type) {
   switch (type) {
     case ncclInt8:
@@ -474,9 +552,12 @@ static std::vector<mscclpp::SmChannel> setupSmChannels(ncclComm_t comm,
                                                        void* src) {
   std::vector<mscclpp::SmChannel> channels;
   std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>>& smSemaphores = comm->smSemaphores;
-  for (size_t cid = 0; cid < comm->connections.size(); ++cid) {
-    if (comm->connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
-      channels.emplace_back(smSemaphores[cid], remoteMemories[cid], src, nullptr);
+  size_t nConnections = comm->connections.size();
+  for (size_t idx = 0; idx < NUM_CHANNELS_PER_CONNECTION; ++idx) {
+    for (size_t cid = 0; cid < nConnections; ++cid) {
+      if (comm->connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
+        channels.emplace_back(smSemaphores[idx * nConnections + cid], remoteMemories[cid], src, nullptr);
+      }
     }
   }
   return channels;
@@ -523,10 +604,12 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
                  [](const auto& future) { return future.get(); });
 
   std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
-  for (size_t cid = 0; cid < connections.size(); ++cid) {
-    if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
-      smSemaphores.emplace_back(
-          std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*(mscclppComm), connections[cid]));
+  for (size_t idx = 0; idx < NUM_CHANNELS_PER_CONNECTION; ++idx) {
+    for (size_t cid = 0; cid < connections.size(); ++cid) {
+      if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
+        smSemaphores.emplace_back(
+            std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*(mscclppComm), connections[cid]));
+      }
     }
   }
   mscclppComm->setup();
@@ -723,8 +806,31 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,
 
 NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
                                     ncclComm_t comm, cudaStream_t stream) {
-  // TODO: implement this function
-  return ncclInternalError;
+  size_t bytes = sendcount * ncclTypeSize(datatype);
+  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;
+
+  int rank = comm->comm->bootstrap()->getRank();
+  int nRank = comm->comm->bootstrap()->getNranks();
+  int nRanksPerNode = comm->comm->bootstrap()->getNranksPerNode();  // assumes Bootstrap exposes getNranksPerNode()
+  channelKey key{sendbuff, recvbuff, bytes};
+
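+  // Channels are cached per (sendbuff, recvbuff, bytes) tuple, so repeated
+  // calls on the same buffers reuse the channels built on the first call.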
+  auto it = comm->smChannels.find(key);
+  if (it == comm->smChannels.end()) {
+    std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
+        comm->comm, rank, const_cast<void*>(recvbuff), bytes * nRank, mscclpp::Transport::CudaIpc);
+    std::vector<mscclpp::SmChannel> channels = setupSmChannels(comm, remoteMemories, const_cast<void*>(recvbuff));
+    it = comm->smChannels.emplace(key, channels).first;
+  }
+  std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
+  std::transform(it->second.begin(), it->second.end(), std::back_inserter(smChannelDeviceHandles),
+                 [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
+  // TODO: if sendbuff and recvbuff don't change, we can avoid copying smChannelDeviceHandles to device
+  CUDACHECK(cudaMemcpyToSymbol(constSmChannels, smChannelDeviceHandles.data(),
+                               sizeof(mscclpp::DeviceHandle<mscclpp::SmChannel>) * smChannelDeviceHandles.size()));
+
+  CUDACHECK(allgather((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, rank, nRanksPerNode, nRank,
+                      bytes / sizeof(int), stream));
+  return ncclSuccess;
 }
 
 NCCL_API ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,