Commit

update for h100 allreduce
Ubuntu authored and Ubuntu committed Sep 21, 2023
1 parent b3d0fdb commit 7fd63f8
Showing 2 changed files with 60 additions and 41 deletions.
1 change: 1 addition & 0 deletions include/mscclpp/packet.hpp
@@ -4,6 +4,7 @@
#ifndef MSCCLPP_PACKET_HPP_
#define MSCCLPP_PACKET_HPP_

+#include <stdint.h>
#include "poll.hpp"

namespace mscclpp {
100 changes: 59 additions & 41 deletions test/mscclpp-test/allreduce_test.cu
@@ -385,7 +385,7 @@ __device__ void reduceScatterSm(int* buff, int* scratch, int rank, int nRanksPer
  // step with its cross-node neighbor (same local rank number on the other node) via IB. Then performs a reduce
  // operation.
  int pipelineSize = 3;
-  float nBlocksForReduceScatterRatio = 0.8;
+  float nBlocksForReduceScatterRatio = 0.85;
  const size_t chunkSize = nelems / worldSize;
  const int peerRank = (rank + nRanksPerNode) % worldSize;
  int peerNodeId = peerRank / nRanksPerNode;
@@ -413,7 +413,7 @@ __device__ void reduceScatterSm(int* buff, int* scratch, int rank, int nRanksPer
    proxyChan.putWithSignal(offset, (chunkSize / pipelineSize * sizeof(int)));
  }
  localReduceScatterSm(buff, scratch, rank, nRanksPerNode, startChunkIndex, chunkSize / pipelineSize, chunkSize,
-                       2 * chunkSize / pipelineSize, nBlocksForReduceScatter);
+                       (pipelineSize - 1) * chunkSize / pipelineSize, nBlocksForReduceScatter);
  if (isComm) {
    proxyChan.wait();
  }
@@ -446,7 +446,7 @@ __device__ void reduceScatterSm(int* buff, int* scratch, int rank, int nRanksPer
  size_t offset = (rank * chunkSize + chunkSize / pipelineSize) * sizeof(int);
  int* dst = (int*)((char*)buff + offset);
  int* src = (int*)((char*)scratch + offset);
-  vectorSum(dst, src, 2 * chunkSize / pipelineSize);
+  vectorSum(dst, src, (pipelineSize - 1) * chunkSize / pipelineSize);
  if (isComm) {
    proxyChan.flush();
  }
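
The two hunks above generalize the same quantity: the reduce-scatter is pipelined in pipelineSize stages, handling chunkSize / pipelineSize elements first and the remaining (pipelineSize - 1) * chunkSize / pipelineSize afterwards, so the hard-coded 2 * chunkSize / pipelineSize only happened to be correct for pipelineSize == 3. A minimal host-side sketch of that arithmetic (illustrative only, not code from the repository):

#include <cstdio>

int main() {
  const int pipelineSize = 3;                 // as in reduceScatterSm
  const size_t chunkSize = 24 * 1024 * 1024;  // example per-rank element count
  size_t firstPiece = chunkSize / pipelineSize;
  size_t remainder = (pipelineSize - 1) * chunkSize / pipelineSize;
  // With pipelineSize == 3 the remainder equals the old hard-coded
  // 2 * chunkSize / pipelineSize, but the general form stays correct
  // if pipelineSize is ever changed.
  std::printf("first=%zu remainder=%zu sum=%zu chunk=%zu\n", firstPiece, remainder,
              firstPiece + remainder, chunkSize);
  return 0;
}
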
@@ -457,45 +457,63 @@ __device__ void localAllGatherSm(int rank, int nRanksPerNode, int startRankChunk
                                 uint64_t rankChunkSize, uint64_t size, size_t nBlocks) {
  if (nRanksPerNode == 1) return;
  if (blockIdx.x >= nBlocks) return;
-  const size_t nPeer = nRanksPerNode - 1;
-  const size_t peerIdx = blockIdx.x % nPeer;
-  const size_t nBlockForThisPeer = nBlocks / nPeer + (nBlocks % nPeer > peerIdx ? 1 : 0);
-  const size_t peerLocalBlockIdx = blockIdx.x / nPeer;
-  const size_t rankLocalIndex = rank % nRanksPerNode;
-  const int remoteRankLocalIndex = (peerIdx < rankLocalIndex ? peerIdx : peerIdx + 1);

-  // Split the data into chunks for aligned data access. Ignore the remainder here and let the last block handle it.
-  constexpr size_t chunkBytes = 128;  // heuristic value
-  const size_t nChunk = size / chunkBytes;
-  const size_t nMinChunkPerBlock = nChunk / nBlockForThisPeer;
-  const size_t nRemainderChunk = nChunk % nBlockForThisPeer;

-  // Distribute chunks to blocks
-  size_t nChunkForThisBlock;
-  size_t offsetForThisBlock;
-  if (peerLocalBlockIdx < nRemainderChunk) {
-    nChunkForThisBlock = nMinChunkPerBlock + 1;
-    offsetForThisBlock = (nMinChunkPerBlock + 1) * peerLocalBlockIdx;
-  } else {
-    nChunkForThisBlock = nMinChunkPerBlock;
-    offsetForThisBlock =
-        (nMinChunkPerBlock + 1) * nRemainderChunk + (peerLocalBlockIdx - nRemainderChunk) * nMinChunkPerBlock;
-  }
-  offsetForThisBlock *= chunkBytes;
+  // const size_t nPeer = nRanksPerNode - 1;
+  // const size_t peerIdx = blockIdx.x % nPeer;
+  // const size_t nBlockForThisPeer = nBlocks / nPeer + (nBlocks % nPeer > peerIdx ? 1 : 0);
+  // const size_t peerLocalBlockIdx = blockIdx.x / nPeer;
+  // const size_t rankLocalIndex = rank % nRanksPerNode;
+  // const int remoteRankLocalIndex = (peerIdx < rankLocalIndex ? peerIdx : peerIdx + 1);

+  // // Split the data into chunks for aligned data access. Ignore the remainder here and let the last block handle it.
+  // constexpr size_t chunkBytes = 128; // heuristic value
+  // const size_t nChunk = size / chunkBytes;
+  // const size_t nMinChunkPerBlock = nChunk / nBlockForThisPeer;
+  // const size_t nRemainderChunk = nChunk % nBlockForThisPeer;

+  // // Distribute chunks to blocks
+  // size_t nChunkForThisBlock;
+  // size_t offsetForThisBlock;
+  // if (peerLocalBlockIdx < nRemainderChunk) {
+  //   nChunkForThisBlock = nMinChunkPerBlock + 1;
+  //   offsetForThisBlock = (nMinChunkPerBlock + 1) * peerLocalBlockIdx;
+  // } else {
+  //   nChunkForThisBlock = nMinChunkPerBlock;
+  //   offsetForThisBlock =
+  //       (nMinChunkPerBlock + 1) * nRemainderChunk + (peerLocalBlockIdx - nRemainderChunk) * nMinChunkPerBlock;
+  // }
+  // offsetForThisBlock *= chunkBytes;

+  // // Calculate the size of the data for this block
+  // size_t sizeForThisBlock = nChunkForThisBlock * chunkBytes;
+  // const size_t lastChunkSize = size - nChunk * chunkBytes;
+  // if (lastChunkSize > 0 && peerLocalBlockIdx == nBlockForThisPeer - 1) {
+  //   sizeForThisBlock += lastChunkSize;
+  // }
+  // if (threadIdx.x == 0 && peerLocalBlockIdx == 0) {
+  //   constSmInPlaceChans[peerIdx].signal();
+  //   constSmInPlaceChans[peerIdx].wait();
+  // }
+  // allGatherDeviceSyncer.sync(nBlocks);
+  // size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk;
+  // constSmInPlaceChans[peerIdx].get(offset + offsetForThisBlock, sizeForThisBlock, threadIdx.x, blockDim.x);
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  const int nPeer = nRanksPerNode - 1;

-  // Calculate the size of the data for this block
-  size_t sizeForThisBlock = nChunkForThisBlock * chunkBytes;
-  const size_t lastChunkSize = size - nChunk * chunkBytes;
-  if (lastChunkSize > 0 && peerLocalBlockIdx == nBlockForThisPeer - 1) {
-    sizeForThisBlock += lastChunkSize;
+  if (tid < nPeer) {
+    constSmInPlaceChans[tid].signal();
  }
-  if (threadIdx.x == 0 && peerLocalBlockIdx == 0) {
-    constSmInPlaceChans[peerIdx].signal();
-    constSmInPlaceChans[peerIdx].wait();
+  int waitStart = nBlocks * blockDim.x - nPeer;
+  if (tid >= waitStart && tid < nBlocks * blockDim.x) {
+    constSmInPlaceChans[tid - waitStart].wait();
  }
  allGatherDeviceSyncer.sync(nBlocks);
-  size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk;
-  constSmInPlaceChans[peerIdx].get(offset + offsetForThisBlock, sizeForThisBlock, threadIdx.x, blockDim.x);
+  const size_t rankLocalIndex = rank % nRanksPerNode;
+  for (int i = 0; i < nPeer; ++i) {
+    int peerIdx = (i + rankLocalIndex) % nPeer;
+    const int remoteRankLocalIndex = (peerIdx < rankLocalIndex ? peerIdx : peerIdx + 1);
+    size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk;
+    constSmInPlaceChans[peerIdx].get(offset, size, tid, blockDim.x * nBlocks);
+  }
}

__device__ void localRingAllGatherSm(int rank, int nRanksPerNode, uint64_t size, size_t nBlocks) {
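
The rewrite above replaces the per-peer partitioning of blocks with a single flattened thread index: every thread in all nBlocks blocks signals or waits at most once and then helps copy every peer's data, striding by blockDim.x * nBlocks. A self-contained CUDA sketch of that flattened, grid-strided copy pattern (an illustration of the indexing only; it uses plain buffers rather than mscclpp channels):

#include <cstdio>
#include <cuda_runtime.h>

// One flat index per thread; the whole grid cooperates on a single copy,
// striding by the total number of launched threads.
__global__ void gridStridedCopy(const int* src, int* dst, size_t n) {
  size_t tid = threadIdx.x + (size_t)blockIdx.x * blockDim.x;
  size_t stride = (size_t)blockDim.x * gridDim.x;
  for (size_t i = tid; i < n; i += stride) dst[i] = src[i];
}

int main() {
  const size_t n = 1 << 20;
  int *src = nullptr, *dst = nullptr;
  cudaMallocManaged(&src, n * sizeof(int));
  cudaMallocManaged(&dst, n * sizeof(int));
  for (size_t i = 0; i < n; ++i) src[i] = (int)i;
  gridStridedCopy<<<64, 512>>>(src, dst, n);  // same grid shape the test now uses for kernel 4
  cudaDeviceSynchronize();
  std::printf("dst[12345] = %d\n", dst[12345]);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}
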
@@ -538,7 +556,7 @@ __device__ void allGatherSm(int rank, int worldSize, int nRanksPerNode, size_t n
  int peerNodeId = peerRank / nRanksPerNode;
  int peer = (peerRank < rank) ? peerRank : peerRank - 1;
  DeviceHandle<mscclpp::SimpleProxyChannel>& proxyChan = constDevSndRoundChans[peer];
-  const size_t nBlocksForLocalAllGather = gridDim.x / (nRanksPerNode - 1) * (nRanksPerNode - 1);
+  const size_t nBlocksForLocalAllGather = gridDim.x;// / (nRanksPerNode - 1) * (nRanksPerNode - 1);
  const size_t rankChunkSize = nelemsPerGPU * sizeof(int);
  const int startRankIndexInLocalNode = (rank / nRanksPerNode) * nRanksPerNode;
  const int startRankIndexInPeerNode = (peerRank / nRanksPerNode) * nRanksPerNode;
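
Before this change the grid handed to localAllGatherSm was rounded down to a multiple of (nRanksPerNode - 1), so blocks could sit idle whenever gridDim.x was not such a multiple; with every block now serving every peer, the whole grid is used. A quick arithmetic check (illustrative; the 8-GPU node size, hence 7 peers, is an assumption rather than something stated in the diff):

#include <cstdio>

int main() {
  const size_t nRanksPerNode = 8;   // assumed node size
  const size_t grids[] = {45, 64};  // kernel 4 grid sizes before/after this commit
  for (size_t gridDimX : grids) {
    size_t oldBlocks = gridDimX / (nRanksPerNode - 1) * (nRanksPerNode - 1);  // rounded down
    size_t newBlocks = gridDimX;                                              // full grid
    std::printf("gridDim.x=%zu -> old=%zu new=%zu\n", gridDimX, oldBlocks, newBlocks);
  }
  return 0;
}
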
Expand Down Expand Up @@ -964,7 +982,7 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
tmpBuff = scratchBuff;
nThreadsPerBlock = 1024;
} else if (kernelNum == 4) {
nBlocks = 45;
nBlocks = 64;
tmpBuff = scratchBuff;
nThreadsPerBlock = 512;
} else if (kernelNum == 5) {
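
With kernel 4 now launched as 64 blocks of 512 threads, the signal/wait indexing in the rewritten localAllGatherSm assigns the first nPeer global thread ids to signal() and the last nPeer to wait(). A small host-side check of that mapping (illustrative; nPeer = 7 assumes 8 ranks per node):

#include <cstdio>

int main() {
  const int nBlocks = 64, threadsPerBlock = 512;       // kernel 4 launch shape after this commit
  const int nPeer = 7;                                 // assumed: 8 ranks per node
  const int totalThreads = nBlocks * threadsPerBlock;  // 32768
  const int waitStart = totalThreads - nPeer;          // 32761
  const int sampleTids[] = {0, 6, 7, waitStart - 1, waitStart, totalThreads - 1};
  for (int tid : sampleTids) {
    int signals = (tid < nPeer) ? 1 : 0;
    int waits = (tid >= waitStart && tid < totalThreads) ? 1 : 0;
    std::printf("tid=%5d signal=%d wait=%d\n", tid, signals, waits);
  }
  return 0;
}
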
Expand Down Expand Up @@ -1050,7 +1068,7 @@ std::vector<KernelRestriction> AllReduceTestColl::getKernelRestrictions() {
4,
"allreduce4",
true,
3,
1,
.alignedBytes = 16 * worldSize_ /*use ulong2 to transfer data*/,
},
{5, "allreduce5", false, 1, .alignedBytes = 4 * worldSize_},
