diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index 1cb7c71d6..e95b4fced 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include
 #include
 
 #include "nccl.h"
@@ -425,8 +426,9 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPe
   return cudaGetLastError();
 }
 
+template <bool IsOutOfPlace>
 __global__ void __launch_bounds__(1024, 1)
-    allgather5(size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
+    allgather6(void *sendbuff, size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
   const size_t nBlock = gridDim.x;
   if (blockIdx.x >= nBlock) return;
@@ -460,40 +462,61 @@ __global__ void __launch_bounds__(1024, 1)
   if (nLoop > 0) {
     // First loop unrolling
     const size_t peerIdx = wid % nPeer;
-    const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
-    const size_t offset = bytesPerGPU * remoteRankLocalIndex + (wid / nPeer) * unitBytesPerWarp;
-    smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
+    const size_t offset = bytesPerGPU * rank + (wid / nPeer) * unitBytesPerWarp;
+    if constexpr (IsOutOfPlace) {
+      char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
+      char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
+      char* buff = reinterpret_cast<char*>(sendbuff);
+      const size_t offsetWithinRank = (wid / nPeer) * unitBytesPerWarp;
+      smChans[peerIdx].copy<16, false>(src + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
+      smChans[peerIdx].copy<16, false>(dst + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
+    } else {
+      smChans[peerIdx].put<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
+    }
   }
 
   for (size_t i = 1; i < nLoop; ++i) {
     const size_t gWid = wid + i * nWarp;
     const size_t peerIdx = gWid % nPeer;
-    const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
-    const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp;
-    smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
+    const size_t offset = bytesPerGPU * rank + (gWid / nPeer) * unitBytesPerWarp;
+    if constexpr (IsOutOfPlace) {
+      char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
+      char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
+      char* buff = reinterpret_cast<char*>(sendbuff);
+      const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
+      smChans[peerIdx].copy<16, false>(src + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
+      smChans[peerIdx].copy<16, false>(dst + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
+    } else {
+      smChans[peerIdx].put<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
+    }
   }
 
   if (bytes % unitBytes > 0) {
     const size_t gWid = wid + nLoop * nWarp;
     const size_t peerIdx = gWid % nPeer;
-    const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
     const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
-    const size_t offset = bytesPerGPU * remoteRankLocalIndex + offsetWithinRank;
+    const size_t offset = bytesPerGPU * rank + offsetWithinRank;
     const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU)
                                    ? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0)
                                    : unitBytesPerWarp;
     if (remainBytes > 0) {
-      smChans[peerIdx].get<16, true>(offset, remainBytes, lid, WARP_SIZE);
+      if constexpr (IsOutOfPlace) {
+        char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
+        char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
+        char* buff = reinterpret_cast<char*>(sendbuff);
+        smChans[peerIdx].copy<16, true>(src + offset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE);
+        smChans[peerIdx].copy<16, true>(dst + offset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE);
+      } else {
+        smChans[peerIdx].put<16, true>(offset, remainBytes, lid, WARP_SIZE);
+      }
     }
   }
 }
 
-template <typename T>
+template <bool IsOutOfPlace, typename T>
 cudaError_t allgather(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPerNode, int worldSize, size_t nelems,
                       cudaStream_t stream) {
-  cudaError_t err = cudaMemcpyAsync(resultBuff + nelems * rank, buff, nelems * sizeof(T), cudaMemcpyDeviceToDevice, stream);
-  if (err != cudaSuccess) return err;
-  allgather5<<<24, 1024, 0, stream>>>(rank, worldSize, nRanksPerNode, nelems);
+  allgather6<IsOutOfPlace><<<24, 1024, 0, stream>>>((void*)buff, rank, worldSize, nRanksPerNode, nelems);
   return cudaGetLastError();
 }
 
@@ -857,8 +880,13 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
       constSmChannels, it->second.smChannelDeviceHandles.data(),
       sizeof(mscclpp::DeviceHandle<mscclpp::SmChannel>) * it->second.smChannelDeviceHandles.size(), 0,
       cudaMemcpyHostToDevice, stream));
-  CUDACHECK(allgather((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff,
-                      rank, nRanksPerNode, nRank, bytes / sizeof(int), stream));
+  if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) {
+    CUDACHECK(allgather<false>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff,
+                               rank, nRanksPerNode, nRank, bytes / sizeof(int), stream));
+  } else {
+    CUDACHECK(allgather<true>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff,
+                              rank, nRanksPerNode, nRank, bytes / sizeof(int), stream));
+  }
 
   return ncclSuccess;
 }
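
Reviewer note on the dispatch pattern: `allgather6` folds the in-place and out-of-place paths into a single kernel and selects between them with `if constexpr` on the `IsOutOfPlace` template parameter, so each instantiation compiles only the branch it uses and pays no runtime cost for the other. A minimal self-contained sketch of the same pattern (hypothetical kernel, not mscclpp code; needs `nvcc -std=c++17`):

```cu
#include <cstdio>

#include <cuda_runtime.h>

// Hypothetical kernel showing compile-time dispatch on a bool template
// parameter. The <false> instantiation contains no trace of the
// out-of-place branch, and vice versa.
template <bool IsOutOfPlace>
__global__ void doubleKernel(const int* src, int* dst, size_t n) {
  size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  if (i >= n) return;
  if constexpr (IsOutOfPlace) {
    dst[i] = 2 * src[i];  // out-of-place: read from a separate input buffer
  } else {
    dst[i] = 2 * dst[i];  // in-place: the data already lives in dst
  }
}

int main() {
  const size_t n = 1 << 20;
  int *src = nullptr, *dst = nullptr;
  cudaMalloc(&src, n * sizeof(int));
  cudaMalloc(&dst, n * sizeof(int));
  // A runtime pointer check picks the compile-time instantiation, mirroring
  // the branch added to ncclAllGather in this diff.
  if (src != dst) {
    doubleKernel<true><<<(unsigned)((n + 255) / 256), 256>>>(src, dst, n);
  } else {
    doubleKernel<false><<<(unsigned)((n + 255) / 256), 256>>>(src, dst, n);
  }
  cudaDeviceSynchronize();
  printf("%s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(src);
  cudaFree(dst);
  return 0;
}
```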
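The host-side branch relies on NCCL's documented in-place convention for allgather: the operation is in-place when `sendbuff == recvbuff + rank * sendcount`, i.e. each rank's contribution already sits in its slot of the receive buffer, so only the push to peers remains. A standalone sketch of that test (hypothetical helper; it assumes the send size is expressed in bytes):

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: true when an allgather is in-place under NCCL's
// convention. Compares byte addresses, with sendBytes assumed to be
// sendcount * element size.
static bool allgatherIsInPlace(const void* sendbuff, const void* recvbuff,
                               std::size_t sendBytes, int rank) {
  return static_cast<const char*>(sendbuff) ==
         static_cast<const char*>(recvbuff) + rank * sendBytes;
}

int main() {
  constexpr int rank = 2;
  constexpr std::size_t perRankBytes = 16;
  char recv[4 * perRankBytes] = {};  // 4 ranks, 16 bytes per rank
  char send[perRankBytes] = {};
  // In-place: this rank sends from its own slot inside the receive buffer,
  // which dispatches to allgather<false> above.
  assert(allgatherIsInPlace(recv + rank * perRankBytes, recv, perRankBytes, rank));
  // Out-of-place: a distinct send buffer dispatches to allgather<true>.
  assert(!allgatherIsInPlace(send, recv, perRankBytes, rank));
  return 0;
}
```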
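The warp partitioning itself is unchanged by this diff: a global warp id `gWid` maps to peer `gWid % nPeer` and to chunk `gWid / nPeer` within this rank's data, so consecutive warps fan out across all peers before advancing to the next chunk. A tiny host-side illustration of that mapping (illustrative only, same arithmetic as the kernel):

```cpp
#include <cstdio>

int main() {
  const unsigned nPeer = 7;  // e.g. 8 GPUs per node -> 7 peers
  for (unsigned gWid = 0; gWid < 16; ++gWid) {
    // Same index arithmetic as allgather6: peer first, then chunk.
    std::printf("warp %2u -> peer %u, chunk %u\n", gWid, gWid % nPeer, gWid / nPeer);
  }
  return 0;
}
```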