Skip to content

Commit

Permalink
fixed out-of-place allgather
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Feb 24, 2024
1 parent 1becb79 commit 2e0013f
Showing 1 changed file with 44 additions and 16 deletions.
60 changes: 44 additions & 16 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <mscclpp/core.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>
#include <unordered_map>
#include <vector>

#include "nccl.h"
Expand Down Expand Up @@ -425,8 +426,9 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPe
return cudaGetLastError();
}

template <bool IsOutOfPlace>
__global__ void __launch_bounds__(1024, 1)
allgather5(size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
allgather6(void *sendbuff, size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
const size_t nBlock = gridDim.x;
if (blockIdx.x >= nBlock) return;

Expand Down Expand Up @@ -460,40 +462,61 @@ __global__ void __launch_bounds__(1024, 1)
if (nLoop > 0) {
// First loop unrolling
const size_t peerIdx = wid % nPeer;
const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
const size_t offset = bytesPerGPU * remoteRankLocalIndex + (wid / nPeer) * unitBytesPerWarp;
smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
const size_t offset = bytesPerGPU * rank + (wid / nPeer) * unitBytesPerWarp;
if constexpr (IsOutOfPlace) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
char* buff = reinterpret_cast<char*>(sendbuff);
const size_t offsetWithinRank = (wid / nPeer) * unitBytesPerWarp;
smChans[peerIdx].copy<16, false>(src + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
smChans[peerIdx].copy<16, false>(dst + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
} else {
smChans[peerIdx].put<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
}
}

for (size_t i = 1; i < nLoop; ++i) {
const size_t gWid = wid + i * nWarp;
const size_t peerIdx = gWid % nPeer;
const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp;
smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
const size_t offset = bytesPerGPU * rank + (gWid / nPeer) * unitBytesPerWarp;
if constexpr (IsOutOfPlace) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
char* buff = reinterpret_cast<char*>(sendbuff);
const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
smChans[peerIdx].copy<16, false>(src + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
smChans[peerIdx].copy<16, false>(dst + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
} else {
smChans[peerIdx].put<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
}
}

if (bytes % unitBytes > 0) {
const size_t gWid = wid + nLoop * nWarp;
const size_t peerIdx = gWid % nPeer;
const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
const size_t offset = bytesPerGPU * remoteRankLocalIndex + offsetWithinRank;
const size_t offset = bytesPerGPU * rank + offsetWithinRank;
const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU)
? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0)
: unitBytesPerWarp;
if (remainBytes > 0) {
smChans[peerIdx].get<16, true>(offset, remainBytes, lid, WARP_SIZE);
if constexpr (IsOutOfPlace) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
char* buff = reinterpret_cast<char*>(sendbuff);
smChans[peerIdx].copy<16, true>(src + offset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE);
smChans[peerIdx].copy<16, true>(dst + offset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE);
} else {
smChans[peerIdx].put<16, true>(offset, remainBytes, lid, WARP_SIZE);
}
}
}
}

template <typename T>
template <bool IsOutOfPlace, typename T>
cudaError_t allgather(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPerNode, int worldSize, size_t nelems,
cudaStream_t stream) {
cudaError_t err = cudaMemcpyAsync(resultBuff + nelems * rank, buff, nelems * sizeof(T), cudaMemcpyDeviceToDevice, stream);
if (err != cudaSuccess) return err;
allgather5<<<24, 1024, 0, stream>>>(rank, worldSize, nRanksPerNode, nelems);
allgather6<IsOutOfPlace><<<24, 1024, 0, stream>>>((void*)buff, rank, worldSize, nRanksPerNode, nelems);
return cudaGetLastError();
}

Expand Down Expand Up @@ -857,8 +880,13 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
constSmChannels, it->second.smChannelDeviceHandles.data(),
sizeof(mscclpp::DeviceHandle<mscclpp::SmChannel>) * it->second.smChannelDeviceHandles.size(), 0,
cudaMemcpyHostToDevice, stream));
CUDACHECK(allgather((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff,
rank, nRanksPerNode, nRank, bytes / sizeof(int), stream));
if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) {
CUDACHECK(allgather<false>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff,
rank, nRanksPerNode, nRank, bytes / sizeof(int), stream));
} else {
CUDACHECK(allgather<true>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff,
rank, nRanksPerNode, nRank, bytes / sizeof(int), stream));
}
return ncclSuccess;
}

Expand Down

0 comments on commit 2e0013f

Please sign in to comment.