Add allgather5
chhwang committed Jan 15, 2024
1 parent c0fe31f commit d208486
Showing 3 changed files with 65 additions and 12 deletions.
55 changes: 52 additions & 3 deletions test/mscclpp-test/allgather_test.cu
@@ -23,7 +23,7 @@ using DeviceHandle = mscclpp::DeviceHandle<T>;
__constant__ DeviceHandle<mscclpp::SimpleProxyChannel> constProxyChans[16];
__constant__ DeviceHandle<mscclpp::ProxyChannel> constRawProxyChan[16];

__constant__ DeviceHandle<mscclpp::SmChannel> constSmChans[8];
__constant__ DeviceHandle<mscclpp::SmChannel> constSmChans[256];

__global__ void allgather0(int rank, size_t nelemsPerGPU) {
int warpId = threadIdx.x / WARP_SIZE;
@@ -288,6 +288,49 @@ __global__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t ne
nBlocksForLocalAllGather);
}

__global__ void allgather5(int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU) {
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int lid = tid % WARP_SIZE;
const int wid = tid / WARP_SIZE;
const int nWarp = blockDim.x * gridDim.x / WARP_SIZE;
const int nPeer = nRanksPerNode - 1;
const int chanOffset = nPeer * blockIdx.x;
auto smChans = constSmChans + chanOffset;

if (wid < nPeer) {
smChans[wid].signal();
smChans[wid].wait();
}
__syncthreads();
constexpr size_t unitBytesPerThread = 16;
constexpr size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE;
const size_t unitBytes = unitBytesPerWarp * nWarp;
const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
const size_t bytes = bytesPerGPU * nPeer;
const size_t nLoop = bytes / unitBytes;
for (size_t i = 0; i < nLoop; ++i) {
const size_t gWid = wid + i * nWarp;
const int peerIdx = gWid % nPeer;
const int remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp;
smChans[peerIdx].get(offset, unitBytesPerWarp, lid, WARP_SIZE);
}

if (bytes % unitBytes > 0) {
const size_t gWid = wid + nLoop * nWarp;
const int peerIdx = gWid % nPeer;
const int remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
const size_t offset = bytesPerGPU * remoteRankLocalIndex + offsetWithinRank;
const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU)
? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0)
: unitBytesPerWarp;
if (remainBytes > 0) {
smChans[peerIdx].get(offset, remainBytes, lid, WARP_SIZE);
}
}
}

class AllGatherProxyService : public mscclpp::BaseProxyService {
public:
AllGatherProxyService(int worldSize, int rank, int cudaDevice);
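
Note: a minimal host-side sketch of the index math allgather5 uses, assuming an 8-GPU node (nPeer = 7), the 32 x 1024 launch configured below for kernelNum == 5, and an arbitrary element count; every name and value here is illustrative, not part of the commit.

#include <cstddef>
#include <cstdio>

int main() {
  const int WARP_SIZE = 32;
  const int nBlocks = 32, nThreads = 1024;      // launch used for kernelNum == 5
  const int nRanksPerNode = 8, rank = 0;        // assumed node size and local rank
  const size_t nelemsPerGPU = 1 << 20;          // assumed int count per GPU

  const int nPeer = nRanksPerNode - 1;
  const int nWarp = nBlocks * nThreads / WARP_SIZE;   // 1024 warps across the grid
  const size_t unitBytesPerWarp = 16 * WARP_SIZE;     // 512 bytes moved per warp per step
  const size_t unitBytes = unitBytesPerWarp * nWarp;  // bytes covered by one full sweep
  const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
  const size_t bytes = bytesPerGPU * nPeer;

  // Each block addresses its own slice of constSmChans (chanOffset = nPeer * blockIdx.x),
  // and a global warp id gWid is split into (peer, chunk) exactly as in the kernel.
  for (size_t gWid = 0; gWid < 8; ++gWid) {
    const int peerIdx = gWid % nPeer;
    const int remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
    const size_t offset = bytesPerGPU * remoteRankLocalIndex + (gWid / nPeer) * unitBytesPerWarp;
    std::printf("gWid %zu -> peer %d (local rank %d), byte offset %zu\n", gWid, peerIdx,
                remoteRankLocalIndex, offset);
  }
  std::printf("full sweeps: %zu, tail bytes: %zu\n", bytes / unitBytes, bytes % unitBytes);
  return 0;
}
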
@@ -387,6 +430,9 @@ void AllGatherTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
if (kernelNum == 4) {
nBlocks = 21;
nThreads = 1024;
} else if (kernelNum == 5) {
nBlocks = 32;
nThreads = 1024;
} else {
nBlocks = 1;
nThreads = WARP_SIZE * (worldSize - 1);
@@ -401,6 +447,8 @@ void AllGatherTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
allgather3<<<nBlocks, nThreads, 0, stream>>>();
} else if (kernelNum == 4) {
allgather4<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_);
} else if (kernelNum == 5) {
allgather5<<<nBlocks, nThreads, 0, stream>>>(rank, worldSize, nRanksPerNode, paramCount_);
}
}

@@ -453,7 +501,8 @@ std::vector<KernelRestriction> AllGatherTestColl::getKernelRestrictions() {
{1, "allgather1", false, 1, 4 * worldSize_},
{2, "allgather2", true, 3, 4 * worldSize_},
{3, "allgather3", true, 1, 4 * worldSize_},
{4, "allgather4", true, 3, 16 * worldSize_ /*use ulong2 to transfer data*/}};
{4, "allgather4", true, 3, 16 * worldSize_ /*use ulong2 to transfer data*/},
{5, "allgather5", false, 1, 16 * worldSize_ /*use ulong2 to transfer data*/}};
}

class AllGatherTestEngine : public BaseTestEngine {
@@ -494,7 +543,7 @@ void AllGatherTestEngine::setupConnections() {
CUDATHROW(cudaMemcpyToSymbol(constProxyChans, devProxyChannels.data(),
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>) * devProxyChannels.size()));

setupMeshConnections(smChannels_, sendBuff_.get(), args_.maxBytes);
setupMeshConnections(smChannels_, sendBuff_.get(), args_.maxBytes, nullptr, 0, ChannelSemantic::PUT, 32);
std::vector<DeviceHandle<mscclpp::SmChannel>> smChannelHandles(smChannels_.size());
if (smChannels_.size() > sizeof(constSmChans) / sizeof(DeviceHandle<mscclpp::SmChannel>)) {
std::runtime_error("unexpected error");
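Note: the call above passes nChannelPerConnection = 32, and constSmChans was widened to 256 entries in this commit, so the runtime size check right above it has headroom. A quick back-of-the-envelope sketch; the per-node GPU count is an assumption, not something the diff states.

#include <cstdio>

int main() {
  // From the diff: constSmChans now holds 256 handles; setupConnections() asks for 32 channels per connection.
  const int maxSmChans = 256;
  const int channelsPerConnection = 32;
  // Assumption: at most 8 GPUs per node, i.e. at most 7 CUDA-IPC peers per rank.
  const int maxPeersPerNode = 7;

  const int needed = channelsPerConnection * maxPeersPerNode;  // 224 handles
  std::printf("%d of %d constSmChans slots used -> %s\n", needed, maxSmChans,
              needed <= maxSmChans ? "size check passes" : "size check would fail");
  return 0;
}
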
20 changes: 12 additions & 8 deletions test/mscclpp-test/common.cc
@@ -428,7 +428,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Simp

void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff,
size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes,
ChannelSemantic semantic) {
ChannelSemantic semantic, size_t nChannelPerConnection) {
const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum];
mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports);
mscclpp::RegisteredMemory getPacketBufRegMem;
@@ -443,19 +443,23 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmCha
(outputBuff && semantic == ChannelSemantic::PUT) ? outputBufRegMem : inputBufRegMem;
setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);

std::unordered_map<size_t, std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
std::unordered_map<size_t, std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>>> smSemaphores;
for (size_t cid = 0; cid < connections.size(); ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
smSemaphores.emplace(cid, std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, connections[cid]));
for (size_t i = 0; i < nChannelPerConnection; ++i) {
smSemaphores[cid].emplace_back(std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, connections[cid]));
}
}
}
comm_->setup();

for (size_t cid = 0; cid < connections.size(); ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
smChannels.emplace_back(smSemaphores[cid], remoteRegMemories[cid].get(),
(outputBuff && semantic == ChannelSemantic::GET) ? outputBuff : inputBufRegMem.data(),
nullptr);
for (size_t i = 0; i < nChannelPerConnection; ++i) {
for (size_t cid = 0; cid < connections.size(); ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
smChannels.emplace_back(smSemaphores[cid][i], remoteRegMemories[cid].get(),
(outputBuff && semantic == ChannelSemantic::GET) ? outputBuff : inputBufRegMem.data(),
nullptr);
}
}
}
}
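Note: with the loops reordered as above, smChannels is laid out by channel index first and connection second, i.e. entry k corresponds to (i = k / nConnections, cid = k % nConnections). That grouping is what allgather5's per-block channel offset (nPeer * blockIdx.x) relies on. A small standalone illustration of the ordering, with the connection count and nChannelPerConnection chosen only for the example:

#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  const size_t nConnections = 3;           // assumed: all CUDA-IPC for this illustration
  const size_t nChannelPerConnection = 2;  // assumed small value to keep the output short
  std::vector<std::pair<size_t, size_t>> smChannels;  // records (semaphore set i, connection cid)

  // Same loop nesting as the new setupMeshConnections body.
  for (size_t i = 0; i < nChannelPerConnection; ++i)
    for (size_t cid = 0; cid < nConnections; ++cid)
      smChannels.emplace_back(i, cid);

  // Entry k = i * nConnections + cid, so block b in the kernel can take its own
  // contiguous group of nConnections channels starting at b * nConnections.
  for (size_t k = 0; k < smChannels.size(); ++k)
    std::printf("smChannels[%zu] -> semaphore set %zu of connection %zu\n", k,
                smChannels[k].first, smChannels[k].second);
  return 0;
}
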
2 changes: 1 addition & 1 deletion test/mscclpp-test/common.hpp
@@ -118,7 +118,7 @@ class BaseTestEngine {
SetupChannelFunc setupChannel = nullptr);
void setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff, size_t inputBuffBytes,
void* outputBuff = nullptr, size_t outputBuffBytes = 0,
ChannelSemantic semantic = ChannelSemantic::PUT);
ChannelSemantic semantic = ChannelSemantic::PUT, size_t nChannelPerConnection = 1);
void setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels,
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>>& proxyChannels, void* inputBuff,
size_t inputBuffBytes, void* putPacketBuff = nullptr, size_t putPacketBuffBytes = 0,
