allreduce7 works
Binyang2014 committed Feb 21, 2024
1 parent b8fa212 commit 4b3e27c
Showing 3 changed files with 100 additions and 3 deletions.
91 changes: 91 additions & 0 deletions apps/nccl/src/nccl.cu
@@ -96,6 +96,21 @@ __forceinline__ __device__ int add_vectors<__half>(int a, int b) {
return add_vectors_helper<__half2>(a, b);
}

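// Interpret a 32-bit word as packed elements of T and add the two words element-wise
// (the __half specialization below reduces two packed __half values per uint32_t).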
template <typename T>
__forceinline__ __device__ uint32_t add_vectors_helper(uint32_t a, uint32_t b) {
return bit_cast<uint32_t, T>(add_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
}

template <typename T>
__forceinline__ __device__ uint32_t add_vectors(uint32_t a, uint32_t b) {
return add_vectors_helper<T>(a, b);
}

template <>
__forceinline__ __device__ uint32_t add_vectors<__half>(uint32_t a, uint32_t b) {
return add_vectors_helper<__half2>(a, b);
}

template <typename T>
__forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem, int blockId, int nBlocks) {
size_t nInt4 = nElem / 4;
@@ -310,11 +325,86 @@ __global__ void allreduce1(T* src, T* dst, int rank, int nranks, size_t nelems)
}
}

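// Single-node allreduce using LL8 packets, in three steps: (1) each rank scatters, to
// every peer's scratch buffer, the slice of its input owned by that peer; (2) each rank
// reduces its own slice across all contributions plus its local data and broadcasts the
// reduced packets back to all peers; (3) each rank copies the peers' reduced slices from
// its scratch buffer into the result buffer.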
template <typename T>
__global__ void allreduce7(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPerNode, int worldSize,
size_t nelems) {
// This version of allreduce only works for a single node
if (worldSize != nRanksPerNode) return;
nelems = nelems / (sizeof(int) / sizeof(T));
const int nPeers = nRanksPerNode - 1;
const size_t nPkts = nelems;
const int nelemsPerRank = nelems / worldSize;
const int nPktsPerRank = nelemsPerRank;
// flag for packets. Initially 1
const uint32_t flag = (uint32_t)globalFlag;
// thread block & channel info
const int nBlocksPerPeer = gridDim.x / nPeers;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LL8Packet);
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LL8Packet);
size_t scratchResultOffset =
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LL8Packet) : 3 * nPkts * sizeof(mscclpp::LL8Packet);
size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int);
uint32_t* src = (uint32_t*)((char*)buff + rank * nelemsPerRank * sizeof(int));
uint32_t* dst = (uint32_t*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));

// step 1: write to scratch buffer
constSmChannels[peerIdx].putPackets<mscclpp::LL8Packet>(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid,
blockDim.x * nBlocksPerPeer, flag);
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
uint32_t data = 0;
for (int index = 0; index < nPeers; index++) {
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nPktsPerRank;
uint32_t val = dstPkt[idx].read(flag);
data = add_vectors<T>(val, data);
}
data = add_vectors<T>(data, src[idx]);
dst[idx] = data;

mscclpp::LL8Packet packet;
packet.data = data;
packet.flag = flag;
size_t offset = scratchResultOffset / sizeof(mscclpp::LL8Packet) + (idx + rank * nPktsPerRank);
for (int index = 0; index < nPeers; index++) {
constSmChannels[index].write(offset, packet);
}
}
// step 3: get data result from scratch buffer
mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)((char*)scratch + scratchResultOffset);
const int dstOffset = remoteRank * nPktsPerRank;
uint32_t* result = (uint32_t*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int));
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) {
uint32_t data = dstPkt[idx + dstOffset].read(flag);
result[idx] = data;
}
if (threadIdx.x == 0 && blockIdx.x == 0) {
globalFlag += 1;
}
}

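// Kernel selection: messages up to 1 MB use the packet kernels (allreduce7 on AMD builds,
// with the launch configuration tuned by element count; allreduce6 elsewhere), while
// larger messages fall back to allreduce1.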
template <typename T>
cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPerNode, int worldSize, size_t nelems,
cudaStream_t stream) {
if (sizeof(T) * nelems <= (1 << 20)) {
#if defined(__HIP_PLATFORM_AMD__)
int nBlocks = 28;
int nThreadsPerBlock = 1024;
if (nelems >= 8192) {
nBlocks = 56;
nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
}
allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, rank, nRanksPerNode, worldSize,
nelems);
#else
allreduce6<<<21, 512, 0, stream>>>(buff, scratch, resultBuff, rank, nRanksPerNode, worldSize, nelems);
#endif
} else {
allreduce1<<<24, 1024, 0, stream>>>(buff, resultBuff, rank, worldSize, nelems);
}
@@ -511,6 +601,7 @@ NCCL_API ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) {
}

NCCL_API ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* device) {
if (comm == nullptr || device == nullptr) return ncclInvalidArgument;
*device = comm->comm->bootstrap()->getRank();
return ncclSuccess;
}
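
For context, a minimal host-side sketch (not part of this commit) of how the new path would be exercised through the NCCL-compatible API; the half-precision buffers and the ncclFloat16/ncclSum choices are illustrative assumptions.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <nccl.h>

// Allreduce of `count` half-precision elements on an already initialized communicator.
// With count * sizeof(__half) <= 1 MB, the dispatch above selects allreduce7 on AMD
// builds (allreduce6 otherwise); larger messages fall back to allreduce1.
void smallAllReduce(ncclComm_t comm, cudaStream_t stream,
                    const __half* sendbuff, __half* recvbuff, size_t count) {
  ncclAllReduce(sendbuff, recvbuff, count, ncclFloat16, ncclSum, comm, stream);
  cudaStreamSynchronize(stream);
}
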
2 changes: 1 addition & 1 deletion src/bootstrap/bootstrap.cc
@@ -187,7 +187,7 @@ void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t tim
}

if (!netInitialized) {
netInit("", "", netIfAddr_);
netInit(ipPortPair, interface, netIfAddr_);
netInitialized = true;
}

10 changes: 8 additions & 2 deletions test/mscclpp-test/allreduce_test.cu
@@ -1138,9 +1138,15 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
tmpBuff = scratchPacketBuff;
nThreadsPerBlock = 512;
} else if (kernelNum == 7) {
nBlocks = 28;
tmpBuff = scratchPacketBuff;
nThreadsPerBlock = 1024;
// tune the #blocks and #threads for MI300X
if (paramCount_ < 8192) {
nBlocks = 28;
nThreadsPerBlock = 1024;
} else {
nBlocks = 56;
nThreadsPerBlock = (paramCount_ <= 76800) ? 512 : 1024;
}
} else {
nBlocks = std::max(args.nRanksPerNode - 1, 1) * BLOCKS_PER_PEER;
tmpBuff = scratchPacketBuff;
