Improve single node allreduce performance #169

Merged · 12 commits · Sep 13, 2023
WIP
Binyang2014 committed Aug 27, 2023
commit 6378cf0b411f7a479dd330da3b7336a6b1f72e85
26 changes: 26 additions & 0 deletions include/mscclpp/concurrency.hpp
@@ -48,6 +48,32 @@ struct DeviceSyncer {
     // the flag is flipped.
     __syncthreads();
   }
+
+  __forceinline__ __device__ void syncWithoutFence(int blockNum) {
+    int maxOldCnt = blockNum - 1;
+    if (blockNum == 1) {
+      __syncthreads();
+      return;
+    }
+    if (threadIdx.x == 0) {
+      int tmpIsAdd = isAdd_ ^ 1;
+      if (tmpIsAdd) {
+        if (atomicAdd(&count_, 1) == maxOldCnt) {
+          flag_ = 1;
+        }
+        POLL_MAYBE_JAILBREAK(!flag_, 1000000000);
+      } else {
+        if (atomicSub(&count_, 1) == 1) {
+          flag_ = 0;
+        }
+        POLL_MAYBE_JAILBREAK(flag_, 1000000000);
+      }
+      isAdd_ = tmpIsAdd;
+    }
+    // We need sync here because only a single thread is checking whether
+    // the flag is flipped.
+    __syncthreads();
+  }
 #endif
 
  private:
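
For orientation, here is a minimal usage sketch of the new barrier (a hypothetical kernel, not part of this diff; it assumes the DeviceSyncer object is zero-initialized device memory, which sync() already requires). syncWithoutFence() reuses the counter/flag protocol of sync(): thread 0 of each block alternately increments or decrements count_, the last block to arrive flips flag_, and the others spin via POLL_MAYBE_JAILBREAK. As the name suggests, it issues no memory fence itself, so the caller is responsible for ordering, for example by electing a single thread to call __threadfence_system():

#include <mscclpp/concurrency.hpp>

// Hypothetical example, not part of this PR.
__device__ mscclpp::DeviceSyncer syncer;  // must be zero-initialized before launch

__global__ void twoPhaseKernel(int* data) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  data[tid] = tid;  // phase 1: each block writes its slice of `data`
  // One elected thread makes the writes visible system-wide ...
  if (tid == gridDim.x * blockDim.x - 1) {
    __threadfence_system();
  }
  // ... and the grid barrier is taken without a per-thread fence.
  syncer.syncWithoutFence(gridDim.x);
  // Phase 2: any block may now read slices written by other blocks.
}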
79 changes: 51 additions & 28 deletions test/mscclpp-test/allreduce_test.cu
@@ -291,39 +291,34 @@ __device__ void localReduceScatterSm(int* buff, int* scratch, int rank, int nRanksPerNode,
 
   int4* buff4 = (int4*)buff;
 
-  for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) {
-    smChans[peerIdx].signalWithoutFence();
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < nPeer) {
+    smChans[tid].signalWithoutFence();
+    smChans[tid].wait();
   }
-  for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) {
-    smChans[peerIdx].wait();
+  if (tid == nBlocks * blockDim.x - 1) {
+    __threadfence_system();
   }
-  reduceScatterDeviceSyncer.sync(nBlocks);
+  reduceScatterDeviceSyncer.syncWithoutFence(nBlocks);
 
   const size_t nInt4 = nelems / 4;
-  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nInt4; idx += blockDim.x * nBlocks) {
-    int4 sum = make_int4(0, 0, 0, 0);
-
-    for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-      int4 val = smChans[(localRankIndexInNode + peerIdx) % nPeer].read<int4>(indexOffset4 + idx);
-      sum.w += val.w;
-      sum.x += val.x;
-      sum.y += val.y;
-      sum.z += val.z;
+  for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+    int4 val;
+    for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nInt4; idx += blockDim.x * nBlocks) {
+      val = smChans[(localRankIndexInNode + peerIdx) % nPeer].read<int4>(indexOffset4 + idx);
+      buff4[indexOffset4 + idx].w += val.w;
+      buff4[indexOffset4 + idx].x += val.x;
+      buff4[indexOffset4 + idx].y += val.y;
+      buff4[indexOffset4 + idx].z += val.z;
     }
-    buff4[indexOffset4 + idx].w += sum.w;
-    buff4[indexOffset4 + idx].x += sum.x;
-    buff4[indexOffset4 + idx].y += sum.y;
-    buff4[indexOffset4 + idx].z += sum.z;
   }
 
   const size_t nLastInts = nelems % 4;
-  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nLastInts; idx += blockDim.x * nBlocks) {
-    int sum = 0;
-    for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+  for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+    for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nLastInts; idx += blockDim.x * nBlocks) {
       int val = smChans[(localRankIndexInNode + peerIdx) % nPeer].read<int>(indexOffset + nInt4 * 4 + idx);
-      sum += val;
+      buff[indexOffset + nInt4 * 4 + idx] += val;
     }
-    buff[indexOffset + nInt4 * 4 + idx] += sum;
   }
 }
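
The kernel body is also interchanged from element-major to peer-major: the old loops read every peer for one element, accumulated into a register, and wrote buff4 once; the new loops stream one peer's entire chunk per pass and accumulate directly into buff4 with a read-modify-write each time. A simplified host-side analogue of the two orderings (illustrative only, not part of the diff):

#include <vector>

// Element-major (old shape): gather all peers for one element, write once.
void reduceElementMajor(const std::vector<std::vector<int>>& peers, std::vector<int>& out) {
  for (size_t idx = 0; idx < out.size(); ++idx) {
    int sum = 0;
    for (const auto& peer : peers) sum += peer[idx];  // one read per peer
    out[idx] += sum;                                  // single write
  }
}

// Peer-major (new shape): stream one peer's chunk at a time, accumulate in place.
void reducePeerMajor(const std::vector<std::vector<int>>& peers, std::vector<int>& out) {
  for (const auto& peer : peers)
    for (size_t idx = 0; idx < out.size(); ++idx)
      out[idx] += peer[idx];  // read-modify-write per pass
}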

@@ -413,6 +408,26 @@ __device__ void localAllGatherSm(int rank, int nRanksPerNode, int startRankChunkIndex,
                                  uint64_t rankChunkSize, uint64_t size, size_t nBlocks) {
   if (nRanksPerNode == 1) return;
   if (blockIdx.x >= nBlocks) return;
+
+  // int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  // const int nPeer = nRanksPerNode - 1;
+
+  // if (tid < nPeer) {
+  //   constSmInPlaceChans[tid].signalWithoutFence();
+  //   constSmInPlaceChans[tid].wait();
+  // }
+  // if (tid == nBlocks * blockDim.x - 1) {
+  //   __threadfence_system();
+  // }
+  // allGatherDeviceSyncer.syncWithoutFence(nBlocks);
+  // for (int i = 0; i < nPeer; ++i) {
+  //   int peerIdx = (i + rank) % nPeer;
+  //   const size_t rankLocalIndex = rank % nRanksPerNode;
+  //   const int remoteRankLocalIndex = (peerIdx < rankLocalIndex ? peerIdx : peerIdx + 1);
+  //   size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk;
+  //   constSmInPlaceChans[peerIdx].get(offset, size, tid, blockDim.x * nBlocks);
+  // }
+
   const size_t nPeer = nRanksPerNode - 1;
   const size_t peerIdx = blockIdx.x % nPeer;
   const size_t nBlockForThisPeer = nBlocks / nPeer + (nBlocks % nPeer > peerIdx ? 1 : 0);
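
The commented-out block above is an experimental variant of the local all-gather kept disabled in this WIP commit: all threads of the grid would cooperatively get() every peer's chunk after one elected-thread fence, instead of dedicating a subset of blocks to each peer as the live code below continues to do.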
@@ -449,6 +464,10 @@ __device__ void localAllGatherSm(int rank, int nRanksPerNode, int startRankChunkIndex,
     constSmInPlaceChans[peerIdx].signalWithoutFence();
     constSmInPlaceChans[peerIdx].wait();
   }
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid == nBlocks * blockDim.x - 1) {
+    __threadfence_system();
+  }
   allGatherDeviceSyncer.sync(nBlocks);
   size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk;
   constSmInPlaceChans[peerIdx].get(offset + offsetForThisBlock, sizeForThisBlock, threadIdx.x, blockDim.x);
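
This mirrors the reduce-scatter change: the last thread of the grid issues one __threadfence_system() ahead of the grid-wide barrier, so remote-visible writes are ordered once for everyone rather than per thread. Note that this call site still uses allGatherDeviceSyncer.sync(nBlocks) rather than the fence-free variant.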
@@ -477,7 +496,7 @@ __device__ void allGatherSm(int rank, int worldSize, int nRanksPerNode, size_t n
   const int startRankIndexInPeerNode = (peerRank / nRanksPerNode) * nRanksPerNode;
 
   if (peerNodeId == rank / nRanksPerNode) {
-    localAllGatherSm(rank, nRanksPerNode, 0, 0, rankChunkSize, rankChunkSize, nBlocksForLocalAllGather);
+    localAllGatherSm(rank, nRanksPerNode, 0, 0, rankChunkSize, rankChunkSize, gridDim.x);
     return;
   }
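
With the kernel now taking the single-node path (see allreduce4 below), the node-local all-gather is handed the whole grid, gridDim.x, instead of the smaller nBlocksForLocalAllGather block count.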

@@ -791,9 +810,13 @@
 
 __global__ void allreduce4(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize,
                            size_t nelems) {
-  reduceScatterSm(buff, scratch, rank, nRanksPerNode, worldSize, nelems);
-  deviceSyncer.sync(gridDim.x);
-  allGatherSm(rank, worldSize, nRanksPerNode, nelems / worldSize);
+  // reduceScatterSm(buff, scratch, rank, nRanksPerNode, worldSize, nelems);
+  // deviceSyncer.sync(gridDim.x);
+  // allGatherSm(rank, worldSize, nRanksPerNode, nelems / worldSize);
+  localReduceScatterSm(buff, scratch, rank, nRanksPerNode, 0, 0, nelems / worldSize, nelems / worldSize, gridDim.x);
+  deviceSyncer.syncWithoutFence(gridDim.x);
+  localAllGatherSm(rank, nRanksPerNode, 0, 0, nelems / worldSize * sizeof(int), nelems / worldSize * sizeof(int),
+                   gridDim.x);
 }
 
 class AllReduceTestColl : public BaseTestColl {
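
allreduce4 is thus rewired for the single-node case: the multi-node reduceScatterSm/allGatherSm pair is commented out in favor of localReduceScatterSm followed by localAllGatherSm, with only the fence-free device syncer between the two phases. Note the units: localReduceScatterSm takes element counts while localAllGatherSm takes byte sizes, hence the * sizeof(int). A worked example of the chunk arithmetic (hypothetical sizes, 4-byte int elements):

// Illustrative only; worldSize and nelems are assumed values.
constexpr size_t worldSize = 8;
constexpr size_t nelems = 1 << 20;                       // ints being allreduced
constexpr size_t chunkElems = nelems / worldSize;        // 131072 elements reduced per rank
constexpr size_t chunkBytes = chunkElems * sizeof(int);  // 524288 bytes all-gathered per rank
static_assert(chunkBytes == 524288, "chunk math");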
@@ -828,7 +851,7 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
     tmpBuff = scratchBuff;
     nThreadsPerBlock = 1024;
   } else if (kernelNum == 4) {
-    nBlocks = 45;
+    nBlocks = 49;
     tmpBuff = scratchBuff;
     nThreadsPerBlock = 512;
   } else {
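
One plausible reading of the 45 to 49 change, assuming 8 GPUs per node so that nPeer = 7 (the node size is not stated in the diff): localAllGatherSm maps blocks to peers via blockIdx.x % nPeer, so a grid size that is a multiple of 7 splits evenly across peers. A quick check of the split formula shown earlier:

#include <cstdio>
#include <initializer_list>

int main() {
  const int nPeer = 7;  // assumed: 8 ranks per node
  for (int nBlocks : {45, 49}) {
    printf("nBlocks=%d:", nBlocks);
    for (int peerIdx = 0; peerIdx < nPeer; ++peerIdx)
      printf(" %d", nBlocks / nPeer + (nBlocks % nPeer > peerIdx ? 1 : 0));
    printf("\n");
  }
  return 0;
}
// Prints: nBlocks=45: 7 7 7 6 6 6 6  (uneven)
//         nBlocks=49: 7 7 7 7 7 7 7  (even)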