Commit d72f09f

allreduce for all message sizes
Binyang2014 committed Feb 28, 2024
1 parent f58b6f8 commit d72f09f
Showing 3 changed files with 72 additions and 27 deletions.
95 changes: 70 additions & 25 deletions apps/nccl/src/allreduce.hpp
@@ -355,11 +355,17 @@ __global__ void __launch_bounds__(1024, 1)
   size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x;
   size_t nInt4OfThisBlock = maxNInt4PerBlock;
   size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;
+  constexpr size_t nInt4PerChunk = 1024 * 256 / sizeof(int4); // 256KB
   if (blockIdx.x >= nNeededBlocks) {
     nInt4OfThisBlock = 0;
   } else if (blockIdx.x == nNeededBlocks - 1) {
     nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1);
   }
+  const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk;
+  const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk;
+  const size_t chunkSizePerRank = nNeededBlocks * nInt4PerChunk;
+  const size_t blockOffset = nInt4PerChunk * blockIdx.x;
+  const size_t scratchChunkRankOffset = chunkSizePerRank * rank;
 
   __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];
   __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> outChannels[NRANKS_PER_NODE - 1];
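The new constants split each block's share of the buffer into fixed 256KB chunks that are staged through scratch one at a time. A minimal host-side sketch of that arithmetic, assuming an even split across 32 blocks and 64 MiB per rank (the per-rank size, block count, and the maxNInt4PerBlock formula are illustrative assumptions, not values from this diff):

#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t kInt4Bytes = 16;                          // sizeof(int4) on the device
  constexpr size_t nInt4PerChunk = 1024 * 256 / kInt4Bytes;  // 256KB chunk, as in the kernel
  size_t nInt4PerRank = (64ull << 20) / kInt4Bytes;          // assumed 64 MiB per rank
  size_t nBlocks = 32;                                       // assumed grid size
  size_t maxNInt4PerBlock = (nInt4PerRank + nBlocks - 1) / nBlocks;
  size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;
  for (size_t b = 0; b < nNeededBlocks; b++) {
    size_t nInt4OfThisBlock =
        (b == nNeededBlocks - 1) ? nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1) : maxNInt4PerBlock;
    size_t nItrs = nInt4OfThisBlock / nInt4PerChunk;      // full chunks handled by the main loop below
    size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk;  // tail handled by the restNInt4 > 0 pass
    printf("block %zu: %zu full chunks + %zu leftover int4\n", b, nItrs, restNInt4);
  }
  return 0;
}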
@@ -370,33 +376,74 @@ __global__ void __launch_bounds__(1024, 1)
   }
   __syncwarp();
 
-  /// Starts allgather
-  for (size_t idx = offsetOfThisBlock + threadIdx.x; idx < offsetOfThisBlock + nInt4OfThisBlock; idx += blockDim.x) {
-    for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-      const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
-      int4 val = buff4[nInt4PerRank * remoteRank + idx];
-      channels[peerIdx].write(nInt4PerRank * rank + idx, val);
-    }
-  }
-
-  /// Starts reduce-scatter
-
-  if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
-    outChannels[threadIdx.x].relaxedSignal();
-    outChannels[threadIdx.x].wait();
-  }
-  __syncthreads();
-
-  for (size_t idx = offsetOfThisBlock + threadIdx.x; idx < offsetOfThisBlock + nInt4OfThisBlock; idx += blockDim.x) {
-    int4 data = buff4[nInt4PerRank * rank + idx];
-    for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-      const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
-      int4 val = scratch4[nInt4PerRank * remoteRank + idx];
-      data = add_vectors<T>(val, data);
-    }
-    resultBuff4[nInt4PerRank * rank + idx] = data;
-    for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-      outChannels[peerIdx].write(nInt4PerRank * rank + idx, data);
-    }
-  }
+  // we can use double buffering to hide synchronization overhead
+  for (size_t itr = 0; itr < nItrs; itr++) {
+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+      outChannels[threadIdx.x].relaxedSignal();
+      outChannels[threadIdx.x].wait();
+    }
+    __syncthreads();
+    // Starts allgather
+    for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
+      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+        int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock];
+        channels[peerIdx].write(scratchChunkRankOffset + blockOffset + idx, val);
+      }
+    }
+
+    /// Starts reduce-scatter
+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+      outChannels[threadIdx.x].relaxedSignal();
+      outChannels[threadIdx.x].wait();
+    }
+    __syncthreads();
+
+    for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
+      int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
+      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+        int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx];
+        data = add_vectors<T>(val, data);
+      }
+      resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
+      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock, data);
+      }
+    }
+    offsetOfThisBlock += nInt4PerChunk;
+  }
+  if (restNInt4 > 0) {
+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+      outChannels[threadIdx.x].relaxedSignal();
+      outChannels[threadIdx.x].wait();
+    }
+    __syncthreads();
+    for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
+      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+        int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock];
+        channels[peerIdx].write(scratchChunkRankOffset + blockOffset + idx, val);
+      }
+    }
+
+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+      outChannels[threadIdx.x].relaxedSignal();
+      outChannels[threadIdx.x].wait();
+    }
+    __syncthreads();
+
+    for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
+      int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
+      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+        int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx];
+        data = add_vectors<T>(val, data);
+      }
+      resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
+      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock, data);
+      }
+    }
+  }
 }
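For reference, the scratch index passed to channels[peerIdx].write() in the loops above decomposes into a per-rank region, a per-block chunk slot, and the element offset. A small sketch of that layout under assumed geometry (the rank, block, and element indices are arbitrary examples):

#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t nInt4PerChunk = 1024 * 256 / 16;         // 256KB of int4, as in the kernel
  size_t nNeededBlocks = 32;                                // assumed number of active blocks
  size_t chunkSizePerRank = nNeededBlocks * nInt4PerChunk;  // one chunk slot per block
  size_t rank = 3, blockId = 5, idx = 7;                    // arbitrary example indices
  size_t scratchChunkRankOffset = chunkSizePerRank * rank;  // this rank's region in the peer's scratch
  size_t blockOffset = nInt4PerChunk * blockId;             // this block's slot inside that region
  size_t dst = scratchChunkRankOffset + blockOffset + idx;  // int4 index written on every peer
  printf("int4 index %zu (byte offset %zu)\n", dst, dst * 16);
  return 0;
}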
@@ -416,13 +463,11 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
     }
     allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, rank, nRanksPerNode,
                                                          worldSize, nelems, flag++);
-  } else if (sizeof(T) * nelems <= (40 << 20)) {
+  } else {
     int nBlocks = 32;
     int nThreadsPerBlock = 512;
     allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels, rank, nRanksPerNode,
                                                          worldSize, nelems);
-  } else {
-    // TODO:
   }
 #else
   if (sizeof(T) * nelems <= (1 << 20)) {
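After this change the dispatch in the edited branch has only two cases: small messages keep allreduce7, and everything larger goes to the chunked allreduce8, replacing the old 40MB ceiling and the empty TODO branch. A simplified selector sketch; the 1MB cutoff is borrowed from the branch after #else purely for illustration, since the actual threshold of the edited branch sits above this hunk:

#include <cstddef>

enum class AllreduceKernel { kAllreduce7, kAllreduce8 };

// Illustrative sketch of the post-commit selection; the real code also picks
// nBlocks/nThreadsPerBlock and passes buffers, channels, rank, and stream.
template <typename T>
AllreduceKernel pickAllreduceKernel(size_t nelems) {
  const size_t bytes = sizeof(T) * nelems;
  return bytes <= (1 << 20) ? AllreduceKernel::kAllreduce7   // small messages
                            : AllreduceKernel::kAllreduce8;  // all larger messages
}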
2 changes: 1 addition & 1 deletion apps/nccl/src/common.hpp
@@ -12,6 +12,6 @@
 #endif
 
 constexpr int NRANKS_PER_NODE = 8;
-constexpr int SCRATCH_SIZE = 1024 * 1024 * 40;
+constexpr int SCRATCH_SIZE = 1024 * 1024 * 64; // 32 thread-blocks * 8 ranks * 256KB = 64MB
 
 #endif // NCCL_COMMON_HPP_
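The arithmetic in the new comment is consistent with the kernel's chunking; a compile-time restatement using the figures from the comment (32 thread-blocks, 8 ranks, 256KB chunks are taken from the comment, not computed from this header):

#include <cstddef>

constexpr size_t kChunkBytes = 256 * 1024;  // per-block chunk staged in scratch
constexpr size_t kBlocks = 32;              // thread-blocks used by allreduce8
constexpr size_t kRanks = 8;                // NRANKS_PER_NODE
static_assert(kBlocks * kRanks * kChunkBytes == 64ull * 1024 * 1024,
              "32 blocks * 8 ranks * 256KB == 64MB, matching the new SCRATCH_SIZE");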
2 changes: 1 addition & 1 deletion include/mscclpp/gpu.hpp
@@ -6,9 +6,9 @@
 
 #if defined(__HIP_PLATFORM_AMD__)
 
+#include <hip/hip_bf16.h>
 #include <hip/hip_fp16.h>
 #include <hip/hip_runtime.h>
-#include <hip/hip_bf16.h>
 
 using cudaError_t = hipError_t;
 using cudaGraph_t = hipGraph_t;
