Commit 0b4221d
allgather wip
chhwang committed Feb 22, 2024
1 parent 5bf684b commit 0b4221d
Showing 2 changed files with 132 additions and 14 deletions.
18 changes: 15 additions & 3 deletions apps/nccl/rccl_test.py
@@ -38,9 +38,9 @@
a = torch.ones(shape_a, device="cuda")
b = torch.ones(shape_b, device="cuda")
c = torch.mm(a, b)

print(c)

nccl_op = nccl.NCCL_SUM
group.allReduce(
sendbuf=c.data_ptr(),
@@ -49,5 +49,17 @@
datatype=nccl.NCCL_FLOAT,
op=nccl_op,
stream=torch.cuda.current_stream().cuda_stream)

print(c)

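# WIP: once allGather is hooked up, gather each rank's halved result d into e
# (the (8*M, N) shape below assumes 8 ranks) and print it as a sanity check.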
# d = torch.div(c, 2)
# e = torch.ones((8*M,N), device="cuda")

# group.allGather(
# sendbuf=d.data_ptr(),
# recvbuf=e.data_ptr(),
# count=d.nelement(),
# datatype=nccl.NCCL_FLOAT,
# stream=torch.cuda.current_stream().cuda_stream)

# print(e)
128 changes: 117 additions & 11 deletions apps/nccl/src/nccl.cu
@@ -21,6 +21,14 @@
} \
} while (0)

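// Multiple channels per connection let each thread block drive its own slice
// of channels/semaphores (see chanOffset in allgather5) without contention.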
#define NUM_CHANNELS_PER_CONNECTION 32

#if defined(__HIP_PLATFORM_AMD__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

template <typename To, typename From>
__forceinline__ __device__ To bit_cast(const From& src) {
static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
@@ -143,8 +151,8 @@ static const int scratchSize = 1024 * 1024 * 8;
// mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
// mscclpp::Transport::IB6, mscclpp::Transport::IB7};

-__constant__ mscclpp::DeviceHandle<mscclpp::SmChannel> constSmChannels[8];
-__constant__ mscclpp::DeviceHandle<mscclpp::SmChannel> constSmOutChannels[8];
__constant__ mscclpp::DeviceHandle<mscclpp::SmChannel> constSmChannels[256];
__constant__ mscclpp::DeviceHandle<mscclpp::SmChannel> constSmOutChannels[256];
__device__ mscclpp::DeviceSyncer deviceSyncer;

struct channelKey {
@@ -411,6 +419,76 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPe
return cudaGetLastError();
}

__global__ void __launch_bounds__(1024, 1)
allgather5(size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
const size_t nBlock = gridDim.x;

const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
const size_t lid = tid % WARP_SIZE;
const size_t wid = tid / WARP_SIZE;

const size_t nThread = blockDim.x * nBlock;
const size_t nWarp = nThread / WARP_SIZE;
const size_t nPeer = nRanksPerNode - 1;
const size_t chanOffset = nPeer * blockIdx.x;
auto smChans = constSmChannels + chanOffset;

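// Handshake: lane 0 of the first nPeer warps signals each peer and waits for
// the matching signal, so no rank reads a peer's buffer before it is ready.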
if (wid < nPeer && lid == 0) {
smChans[wid].relaxedSignal();
smChans[wid].wait();
}
__syncthreads();
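// NOTE: 4-byte elements are assumed throughout; the host wrapper casts buffers to int*.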
const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
const size_t bytes = bytesPerGPU * nPeer;
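// Copy 64 bytes per thread per iteration when there is enough data to keep
// every thread busy; otherwise fall back to 16-byte units.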
size_t unitBytesPerThread;
if (bytes >= nThread * 64) {
unitBytesPerThread = 64;
} else {
unitBytesPerThread = 16;
}
const size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE;
const size_t unitBytes = unitBytesPerWarp * nWarp;
const size_t nLoop = bytes / unitBytes;

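// Warps are assigned to peers round-robin: global warp gWid copies the
// (gWid / nPeer)-th unitBytesPerWarp-sized chunk of peer (gWid % nPeer)'s region.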
if (nLoop > 0) {
// First loop unrolling
const size_t peerIdx = wid % nPeer;
const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
const size_t offset = bytesPerGPU * remoteRankLocalIndex + (wid / nPeer) * unitBytesPerWarp;
smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
}

for (size_t i = 1; i < nLoop; ++i) {
const size_t gWid = wid + i * nWarp;
const size_t peerIdx = gWid % nPeer;
const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp;
smChans[peerIdx].get<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE);
}

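// Tail: copy whatever remains after the full unitBytes rounds, clamping the
// last chunk so reads stay within the peer's bytesPerGPU region.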
if (bytes % unitBytes > 0) {
const size_t gWid = wid + nLoop * nWarp;
const size_t peerIdx = gWid % nPeer;
const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
const size_t offset = bytesPerGPU * remoteRankLocalIndex + offsetWithinRank;
const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU)
? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0)
: unitBytesPerWarp;
if (remainBytes > 0) {
smChans[peerIdx].get<16, true>(offset, remainBytes, lid, WARP_SIZE);
}
}
}

template <typename T>
cudaError_t allgather(T* buff, T* scratch, T* resultBuff, int rank, int nRanksPerNode, int worldSize, size_t nelems,
cudaStream_t stream) {
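// buff, scratch, and resultBuff are unused in this WIP path; allgather5 pulls
// remote data directly through the channels in constSmChannels.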
allgather5<<<24, 1024, 0, stream>>>(rank, worldSize, nRanksPerNode, nelems);
return cudaGetLastError();
}

static size_t ncclTypeSize(ncclDataType_t type) {
switch (type) {
case ncclInt8:
@@ -474,9 +552,12 @@ static std::vector<mscclpp::SmChannel> setupSmChannels(ncclComm_t comm,
void* src) {
std::vector<mscclpp::SmChannel> channels;
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>>& smSemaphores = comm->smSemaphores;
-for (size_t cid = 0; cid < comm->connections.size(); ++cid) {
-if (comm->connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
-channels.emplace_back(smSemaphores[cid], remoteMemories[cid], src, nullptr);
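// Channels are laid out round-major: channel idx * nConnections + cid pairs the
// idx-th semaphore of connection cid, mirroring the creation order in ncclCommInitRank.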
size_t nConnections = comm->connections.size();
for (size_t idx = 0; idx < NUM_CHANNELS_PER_CONNECTION; ++idx) {
for (size_t cid = 0; cid < nConnections; ++cid) {
if (comm->connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
channels.emplace_back(smSemaphores[idx * nConnections + cid], remoteMemories[cid], src, nullptr);
}
}
}
return channels;
@@ -523,10 +604,12 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
[](const auto& future) { return future.get(); });

std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
-for (size_t cid = 0; cid < connections.size(); ++cid) {
-if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
-smSemaphores.emplace_back(
-std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*(mscclppComm), connections[cid]));
for (size_t idx = 0; idx < NUM_CHANNELS_PER_CONNECTION; ++idx) {
for (size_t cid = 0; cid < connections.size(); ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
smSemaphores.emplace_back(
std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*(mscclppComm), connections[cid]));
}
}
}
mscclppComm->setup();
@@ -723,8 +806,31 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,

NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
-// TODO: implement this function
-return ncclInternalError;
size_t bytes = sendcount * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;
int rank = comm->comm->bootstrap()->getRank();
int nRanksPerNode = comm->comm->bootstrap()->getNranksPerNode();  // assumes Bootstrap exposes getNranksPerNode()
channelKey key{sendbuff, recvbuff, bytes};

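// Channels are cached per (sendbuff, recvbuff, bytes) key so repeated calls on
// the same buffers skip memory registration and channel setup.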
auto it = comm->smChannels.find(key);
if (it == comm->smChannels.end()) {
std::vector<mscclpp::RegisteredMemory> remoteMemories =
setupRemoteMemories(comm->comm, rank, const_cast<void*>(sendbuff), bytes,
mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, remoteMemories, const_cast<void*>(sendbuff));
it = comm->smChannels.emplace(key, channels).first;
}
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
std::transform(it->second.begin(), it->second.end(), std::back_inserter(smChannelDeviceHandles),
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
// TODO: if sendbuff and recvbuff don't change, we can avoid copying smChannelDeviceHandles to device
CUDACHECK(cudaMemcpyToSymbol(constSmChannels, smChannelDeviceHandles.data(),
sizeof(mscclpp::DeviceHandle<mscclpp::SmChannel>) * smChannelDeviceHandles.size()));

CUDACHECK(allgather((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff,
rank, nRanksPerNode, comm->comm->bootstrap()->getNranks(),
sendcount, stream));
return ncclSuccess;
}

NCCL_API ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
