Skip to content

Commit

Permalink
Improve single node allreduce performance (#169)
Browse files Browse the repository at this point in the history
Improve allreduce performance on a single node.
New numbers:
|   n_ctx | size    |  target latency (us) | allreduce5 | allreduce6 |
|---------|---------|----------------|------------|------------|
|       1 | 24.0kB  |            7.7 |            |        7.23|
|       2 | 48.0kB  |            7.7 |            |        7.69|
|       4 | 96.0kB  |            8   |            |        8.34|
|       8 | 192.0kB |           12.6 |            |        9.75|
|      12 | 288.0kB |           13   |            |       11.34|
|      16 | 384.0kB |           13.3 |            |       12.99|
|     768 | 18.0MB  |          158.7 |       160.3|            |
|     896 | 21.0MB  |          184.5 |       183.8|            |
|    1024 | 24.0MB  |          209.5 |       207.5|            |
|    1152 | 27.0MB  |          234.3 |       231.9|            |
|    1280 | 30.0MB  |          260   |       255.6|            |
|    1408 | 33.0MB  |          284.9 |       278.7|            |
|    1536 | 36.0MB  |          310.3 |       302.0|            |
|    1664 | 39.0MB  |          336.2 |       325.3|            |
|    1792 | 42.0MB  |          361.4 |       348.8|            |
|    1920 | 45.0MB  |          384.6 |       372.2|            |
|    2048 | 48.0MB  |          409.1 |       395.4|            |

---------

Co-authored-by: Changho Hwang <[email protected]>
  • Loading branch information
Binyang2014 and chhwang authored Sep 13, 2023
1 parent 015e29c commit 952f2da
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 25 deletions.
2 changes: 2 additions & 0 deletions .azure-pipelines/integration-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ jobs:
mpirun -np 8 --bind-to numa -x MSCCLPP_DEBUG=WARN ./build/test/mscclpp-test/allreduce_test_perf -b 1K -e 1G -f 2 -k 2 -o output.jsonl
mpirun -np 8 --bind-to numa -x MSCCLPP_DEBUG=WARN ./build/test/mscclpp-test/allreduce_test_perf -b 1K -e 1G -f 2 -k 3 -o output.jsonl
mpirun -np 8 --bind-to numa -x MSCCLPP_DEBUG=WARN ./build/test/mscclpp-test/allreduce_test_perf -b 1K -e 1G -f 2 -k 4 -o output.jsonl
mpirun -np 8 --bind-to numa -x MSCCLPP_DEBUG=WARN ./build/test/mscclpp-test/allreduce_test_perf -b 12M -e 48M -i 3145728 -k 5 -o output.jsonl
mpirun -np 8 --bind-to numa -x MSCCLPP_DEBUG=WARN ./build/test/mscclpp-test/allreduce_test_perf -b 24K -e 768K -i 24576 -k 6 -w 100 -n 100 -o output.jsonl
workingDirectory: '$(System.DefaultWorkingDirectory)'

- task: Bash@3
Expand Down
4 changes: 4 additions & 0 deletions test/deploy/perf_ndmv4.jsonl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
{"name":"allreduce", "kernel":3, "ranks":8, "ranksPerNode":8, "algBw":139.08, "busBw":243.40, "size":1073741824, "time":7719.85, "target":"throughput"}
{"name":"allreduce", "kernel":4, "ranks":8, "ranksPerNode":8, "algBw":106.98, "busBw":187.22, "size":16777216, "time":156.81, "target":"throughput"}
{"name":"allreduce", "kernel":4, "ranks":8, "ranksPerNode":8, "algBw":116.24, "busBw":203.42, "size":33554432, "time":288.65, "target":"throughput"}
{"name":"allreduce", "kernel":5, "ranks":8, "ranksPerNode":8, "algBw":126.52,"busBw":221.418,"size":50331648, "time":397.79, "target":"throughput"}
{"name":"allreduce", "kernel":6, "ranks":8, "ranksPerNode":8, "algBw":3.3919,"busBw":5.9359, "size":24576, "time":7.24, "target":"latency"}
{"name":"allreduce", "kernel":6, "ranks":8, "ranksPerNode":8, "algBw":6.21, "busBw":10.87, "size":49152, "time":7.91, "target":"latency"}
{"name":"allreduce", "kernel":6, "ranks":8, "ranksPerNode":8, "algBw":8.90, "busBw":15.57, "size":73728, "time":8.28, "target":"latency"}
{"name":"allreduce", "kernel":4, "ranks":16,"ranksPerNode":8, "algBw":84.55, "busBw":158.53, "size":25165824, "time":297.64, "target":"throughput"}
{"name":"allreduce", "kernel":4, "ranks":16,"ranksPerNode":8, "algBw":99.43, "busBw":186.44, "size":50331648, "time":506.16, "target":"throughput"}
{"name":"allreduce", "kernel":4, "ranks":16,"ranksPerNode":8, "algBw":124.60, "busBw":233.64, "size":3221225472, "time":25850.67,"target":"throughput"}
Expand Down
223 changes: 198 additions & 25 deletions test/mscclpp-test/allreduce_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ __device__ void vectorSumSingleBlock(int* dst, int* src, size_t nElem) {
}

// Device-side synchronization objects. Callers invoke sync(n) with the number of
// thread blocks taking part. Separate instances are kept per phase so that one
// phase does not share barrier state with another.
// General-purpose syncer (used e.g. by allGatherSm and allreduce5 across gridDim.x blocks).
__device__ mscclpp::DeviceSyncer deviceSyncer;
// Syncer used by the local all-gather helpers (localAllGatherSm, localRingAllGatherSm).
__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer;
// Syncer used by the local reduce-scatter helpers (e.g. localReduceScatterSm2).
__device__ mscclpp::DeviceSyncer reduceScatterDeviceSyncer;
// NOTE(review): presumably used by the IB (inter-node) code path; its users are not visible here.
__device__ mscclpp::DeviceSyncer ibDeviceSyncer;

Expand Down Expand Up @@ -326,6 +327,51 @@ __device__ void localReduceScatterSm(int* buff, int* scratch, int rank, int nRan
}
}

// Intra-node reduce-scatter over SM channels, performed in place on `buff`.
//
// This rank owns the chunk starting at element (rank % nRanksPerNode) * chunkSize.
// For each of the nRanksPerNode - 1 peers it reads the peer's copy of that chunk
// through constSmOutOfPlaceGetChans and adds it element-wise into `buff`.
//
// Parameters:
//   buff          - local data buffer (ints); the owned chunk is accumulated in place.
//   scratch       - not used by this function.
//   rank          - global rank of the caller.
//   nRanksPerNode - number of ranks on this node; returns immediately when 1.
//   chunkSize     - distance, in ints, between the chunks owned by consecutive ranks.
//   nelems        - number of ints of the owned chunk to reduce.
//   nBlocks       - number of thread blocks taking part; blocks with
//                   blockIdx.x >= nBlocks return without doing anything.
//
// The bulk of the chunk is processed as int4 when the chunk start is a multiple
// of 4 ints; otherwise every element goes through the scalar loop. The int4 path
// assumes `buff` and the peers' buffers are 16-byte aligned - TODO confirm.
__device__ void localReduceScatterSm2(int* buff, int* scratch, int rank, int nRanksPerNode, size_t chunkSize,
                                      size_t nelems, int nBlocks) {
  if (nRanksPerNode == 1) return;
  if (blockIdx.x >= nBlocks) return;
  const int nPeer = nRanksPerNode - 1;
  DeviceHandle<mscclpp::SmChannel>* smChans = constSmOutOfPlaceGetChans;

  const size_t localRankIndexInNode = rank % nRanksPerNode;
  const size_t indexOffset = localRankIndexInNode * chunkSize;

  // Handshake with every peer: the first nPeer threads each signal one channel,
  // the last nPeer threads each wait on one channel, then all participating
  // blocks synchronize before any remote data is read.
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid < nPeer) {
    smChans[tid].signal();
  }
  const int waitStart = nBlocks * blockDim.x - nPeer;
  if (tid >= waitStart && tid < nBlocks * blockDim.x) {
    smChans[tid - waitStart].wait();
  }
  reduceScatterDeviceSyncer.sync(nBlocks);

  const size_t firstIdx = threadIdx.x + (size_t)blockIdx.x * blockDim.x;
  const size_t stride = (size_t)blockDim.x * nBlocks;

  // The int4 view is only valid when the chunk starts on a 4-int boundary.
  // Dividing an unaligned offset by 4 would truncate it and make the vector loop
  // touch elements outside this rank's chunk, so fall back to the scalar loop
  // for the whole chunk in that case.
  const bool chunkIsInt4Aligned = (indexOffset % 4 == 0);
  const size_t indexOffset4 = indexOffset / 4;
  const size_t nInt4 = chunkIsInt4Aligned ? nelems / 4 : 0;
  int4* buff4 = (int4*)buff;

  // Each element is always handled by the same thread in every peer iteration,
  // so the read-modify-write on `buff` does not race between threads.
  for (int index = 0; index < nPeer; ++index) {
    // Start from a rank-dependent peer so ranks do not all read the same peer first.
    const int peerIdx = (index + localRankIndexInNode) % nPeer;
    for (size_t idx = firstIdx; idx < nInt4; idx += stride) {
      const int4 val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
      buff4[indexOffset4 + idx].w += val.w;
      buff4[indexOffset4 + idx].x += val.x;
      buff4[indexOffset4 + idx].y += val.y;
      buff4[indexOffset4 + idx].z += val.z;
    }
  }

  // Scalar tail: the nelems % 4 leftover ints, or all of them when unaligned.
  const size_t tailStart = indexOffset + nInt4 * 4;
  const size_t nLastInts = nelems - nInt4 * 4;
  for (int index = 0; index < nPeer; ++index) {
    const int peerIdx = (localRankIndexInNode + index) % nPeer;
    for (size_t idx = firstIdx; idx < nLastInts; idx += stride) {
      const int val = smChans[peerIdx].read<int>(tailStart + idx);
      buff[tailStart + idx] += val;
    }
  }
}

__device__ void reduceScatterSm(int* buff, int* scratch, int rank, int nRanksPerNode, int worldSize,
size_t nelems // must be divisible by 3
) {
Expand All @@ -351,7 +397,7 @@ __device__ void reduceScatterSm(int* buff, int* scratch, int rank, int nRanksPer
int nBlocksRemain = gridDim.x - nBlocksForReduceScatter;
DeviceHandle<mscclpp::SimpleProxyChannel>& proxyChan = constDevFstRoundChans[peer];
if (peerNodeId == rank / nRanksPerNode) {
localReduceScatterSm(buff, scratch, rank, nRanksPerNode, 0, 0, chunkSize, chunkSize, nBlocksForReduceScatter);
localReduceScatterSm(buff, scratch, rank, nRanksPerNode, 0, 0, chunkSize, chunkSize, gridDim.x);
return;
}

Expand Down Expand Up @@ -448,11 +494,34 @@ __device__ void localAllGatherSm(int rank, int nRanksPerNode, int startRankChunk
constSmInPlaceChans[peerIdx].signal();
constSmInPlaceChans[peerIdx].wait();
}
deviceSyncer.sync(nBlocks);
allGatherDeviceSyncer.sync(nBlocks);
size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk;
constSmInPlaceChans[peerIdx].get(offset + offsetForThisBlock, sizeForThisBlock, threadIdx.x, blockDim.x);
}

// Intra-node all-gather over the in-place SM channels.
//
// After a signal/wait handshake with every peer, all threads of the first
// `nBlocks` blocks cooperatively fetch `size` bytes from each peer, placing the
// data at byte offset size * <peer rank index>.
//
// Parameters:
//   rank          - rank of the caller (used directly as the local index).
//   nRanksPerNode - number of ranks on this node; nothing is done when 1.
//   size          - number of bytes contributed by each rank.
//   nBlocks       - number of thread blocks taking part; blocks beyond it return.
__device__ void localRingAllGatherSm(int rank, int nRanksPerNode, uint64_t size, size_t nBlocks) {
  if (nRanksPerNode == 1 || blockIdx.x >= nBlocks) return;

  const int nPeer = nRanksPerNode - 1;
  const size_t nThreads = nBlocks * blockDim.x;
  int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // First nPeer threads: one signal per channel.
  if (tid < nPeer) {
    constSmInPlaceChans[tid].signal();
  }
  // Last nPeer threads: one wait per channel.
  int waitStart = nThreads - nPeer;
  if (tid >= waitStart && tid < nThreads) {
    constSmInPlaceChans[tid - waitStart].wait();
  }
  allGatherDeviceSyncer.sync(nBlocks);

  // Visit the peers starting from a rank-dependent channel.
  for (int step = 0; step < nPeer; ++step) {
    const int chanIdx = (step + rank) % nPeer;
    // Channel indices skip the local rank, so map back to the peer's rank index.
    const int srcRankIndex = (chanIdx < rank) ? chanIdx : chanIdx + 1;
    const size_t byteOffset = size * srcRankIndex;
    constSmInPlaceChans[chanIdx].get(byteOffset, size, tid, nThreads);
  }
}

// This is an allgather4 equivalent
__device__ void allGatherSm(int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU) {
// this allgather is a pipelined and hierarchical one and only works for two nodes
Expand All @@ -470,13 +539,13 @@ __device__ void allGatherSm(int rank, int worldSize, int nRanksPerNode, size_t n
int peerNodeId = peerRank / nRanksPerNode;
int peer = (peerRank < rank) ? peerRank : peerRank - 1;
DeviceHandle<mscclpp::SimpleProxyChannel>& proxyChan = constDevSndRoundChans[peer];
const size_t nBlocksForLocalAllGather = gridDim.x;
const size_t nBlocksForLocalAllGather = gridDim.x / (nRanksPerNode - 1) * (nRanksPerNode - 1);
const size_t rankChunkSize = nelemsPerGPU * sizeof(int);
const int startRankIndexInLocalNode = (rank / nRanksPerNode) * nRanksPerNode;
const int startRankIndexInPeerNode = (peerRank / nRanksPerNode) * nRanksPerNode;

if (peerNodeId == rank / nRanksPerNode) {
localAllGatherSm(rank, nRanksPerNode, 0, 0, rankChunkSize, rankChunkSize, nBlocksForLocalAllGather);
localAllGatherSm(rank, nRanksPerNode, 0, 0, rankChunkSize, rankChunkSize, gridDim.x);
return;
}

Expand All @@ -495,7 +564,7 @@ __device__ void allGatherSm(int rank, int worldSize, int nRanksPerNode, size_t n
proxyChan.wait();
proxyChan.flush();
}
deviceSyncer.sync(nBlocksForLocalAllGather);
deviceSyncer.sync(gridDim.x);
// Step 2
if (threadIdx.x == 0 && blockIdx.x == 0) {
proxyChan.putWithSignal(rank * nelemsPerGPU * sizeof(int) + step1Bytes, step2Bytes);
Expand All @@ -506,7 +575,7 @@ __device__ void allGatherSm(int rank, int worldSize, int nRanksPerNode, size_t n
proxyChan.wait();
proxyChan.flush();
}
deviceSyncer.sync(nBlocksForLocalAllGather);
deviceSyncer.sync(gridDim.x);
// Step 3
localAllGatherSm(rank, nRanksPerNode, startRankIndexInPeerNode, step1Bytes, rankChunkSize, step2Bytes,
nBlocksForLocalAllGather);
Expand Down Expand Up @@ -795,6 +864,75 @@ __global__ void allreduce4(int* buff, int* scratch, void* result, int rank, int
allGatherSm(rank, worldSize, nRanksPerNode, nelems / worldSize);
}

// Allreduce built from two phases run by every block of the grid:
//   1. localReduceScatterSm2 - each rank reduces its own nelems / worldSize chunk in `buff`;
//   2. localRingAllGatherSm  - each rank fetches the other ranks' reduced chunks.
// A grid-wide deviceSyncer.sync separates the two phases.
// `result` is not used by this kernel; `scratch` is only forwarded to the reduce-scatter.
__global__ void allreduce5(int* buff, int* scratch, void* result, int rank, int nRanksPerNode, int worldSize,
                           size_t nelems) {
  const size_t nelemsPerRank = nelems / worldSize;
  const size_t bytesPerRank = nelemsPerRank * sizeof(int);

  localReduceScatterSm2(buff, scratch, rank, nRanksPerNode, nelemsPerRank, nelemsPerRank, gridDim.x);
  deviceSyncer.sync(gridDim.x);
  localRingAllGatherSm(rank, nRanksPerNode, bytesPerRank, gridDim.x);
}

// Single-node, out-of-place allreduce based on LL packets.
//
// Each rank owns chunk `rank` (nelems / worldSize ints) of the result:
//   step 1: every rank sends, to each peer, its local copy of the chunk that
//           peer owns, as packets into the peer's scratch buffer;
//   step 2: every rank sums the packets it received with its own copy of its
//           chunk, stores the sum in `resultBuff` and writes it as packets into
//           the result area of every peer's scratch buffer;
//   step 3: every rank copies the other ranks' reduced chunks from its scratch
//           result area into `resultBuff`.
//
// Parameters:
//   buff          - local input (ints).
//   scratch       - local packet scratch buffer, split into 4 regions of nPkts
//                   packets: 2 for incoming data and 2 for incoming results,
//                   alternated on the parity of `globalFlag` (double buffering).
//   resultBuff    - local output buffer.
//   rank, nRanksPerNode, worldSize - the kernel returns immediately unless
//                   worldSize == nRanksPerNode.
//   nelems        - total number of ints to reduce.
//
// NOTE(review): blocks are split evenly between peers, so this assumes
// nRanksPerNode >= 2 and gridDim.x being a multiple of nRanksPerNode - 1;
// otherwise peerIdx can exceed the number of channels - confirm with callers.
// NOTE(review): each packet carries two 32-bit values, so nelems / worldSize is
// presumably expected to be even - confirm.
__global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank, int nRanksPerNode, int worldSize,
                           size_t nelems) {
  // This version of allreduce only works for single nodes
  if (worldSize != nRanksPerNode) return;
  const int nPeers = nRanksPerNode - 1;
  const size_t nPkts = nelems / 2;
  const int nelemsPerRank = nelems / worldSize;
  const int nPktsPerRank = nelemsPerRank / 2;
  // flag for packets. Initially 1
  const uint32_t flag = (uint32_t)globalFlag;
  // thread block & channel info: blocks are partitioned evenly between peers
  const int nBlocksPerPeer = gridDim.x / nPeers;
  const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
  const int peerIdx = blockIdx.x / nBlocksPerPeer;
  // channel indices skip the local rank, so map the channel back to the peer's rank
  const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
  DeviceHandle<mscclpp::SmChannel> smChan = constSmOutOfPlaceChans[peerIdx];
  const int tid = threadIdx.x + localBlockIdx * blockDim.x;
  // double buffering: odd flags use the first data/result regions, even flags the second
  const size_t pktRegionBytes = nPkts * sizeof(mscclpp::LLPacket);
  const size_t scratchBaseOffset = (flag & 1) ? 0 : pktRegionBytes;
  void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
  // slot `rank` inside the peer's data region
  const size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket);
  const size_t scratchResultOffset = (flag & 1) ? 2 * pktRegionBytes : 3 * pktRegionBytes;
  // The peer reduces chunk `remoteRank`, so that is the chunk it must receive from us.
  // (Sending our own chunk `rank` instead would make every rank sum unrelated chunks.)
  const size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int);
  // The chunk this rank reduces, in the input and in the output buffer.
  const size_t ownChunkOffset = rank * nelemsPerRank * sizeof(int);
  uint2* src = (uint2*)((char*)buff + ownChunkOffset);
  uint2* dst = (uint2*)((char*)resultBuff + ownChunkOffset);

  // step 1: write to scratch buffer
  smChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
  // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
    uint2 data = make_uint2(0, 0);
    for (int index = 0; index < nPeers; index++) {
      const int senderRank = index < rank ? index : index + 1;
      mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + senderRank * nPktsPerRank;
      uint2 val = dstPkt[idx].read(flag);
      data.x += val.x;
      data.y += val.y;
    }
    data.x += src[idx].x;
    data.y += src[idx].y;
    dst[idx].x = data.x;
    dst[idx].y = data.y;
    // publish the reduced value into slot `rank` of every peer's result region
    for (int index = 0; index < nPeers; index++) {
      mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)constSmOutOfPlaceChans[index].dst_ + scratchResultOffset);
      dstPkt[idx + rank * nPktsPerRank].write(data.x, data.y, flag);
    }
  }
  // step 3: get data result from scratch buffer
  mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
  const int dstOffset = remoteRank * nPktsPerRank;
  uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int));
  for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) {
    uint2 data = dstPkt[idx + dstOffset].read(flag);
    result[idx].x = data.x;
    result[idx].y = data.y;
  }
  // advance the flag once per launch so the next call uses the other scratch regions
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    globalFlag += 1;
  }
}

class AllReduceTestColl : public BaseTestColl {
public:
AllReduceTestColl() = default;
Expand Down Expand Up @@ -830,6 +968,14 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
nBlocks = 45;
tmpBuff = scratchBuff;
nThreadsPerBlock = 512;
} else if (kernelNum == 5) {
nBlocks = 24;
tmpBuff = scratchBuff;
nThreadsPerBlock = 1024;
} else if (kernelNum == 6) {
nBlocks = 21;
tmpBuff = scratchPacketBuff;
nThreadsPerBlock = 512;
} else {
nBlocks = std::max(args.nRanksPerNode - 1, 1) * BLOCKS_PER_PEER;
tmpBuff = scratchPacketBuff;
Expand All @@ -850,6 +996,13 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
else if (kernelNum == 4)
allreduce4<<<nBlocks, nThreadsPerBlock, 0, stream>>>((int*)inputBuff, (int*)tmpBuff, resultBuff, rank,
args.nRanksPerNode, worldSize, paramCount_);
else if (kernelNum == 5)
allreduce5<<<nBlocks, nThreadsPerBlock, 0, stream>>>((int*)inputBuff, (int*)tmpBuff, resultBuff, rank,
args.nRanksPerNode, worldSize, paramCount_);
else if (kernelNum == 6) {
allreduce6<<<nBlocks, nThreadsPerBlock, 0, stream>>>((int*)inputBuff, (int*)tmpBuff, resultBuff, rank,
args.nRanksPerNode, worldSize, paramCount_);
}
}

void AllReduceTestColl::initData(const TestArgs& args, std::vector<void*> sendBuff, void* expectedBuff) {
Expand Down Expand Up @@ -882,6 +1035,7 @@ void AllReduceTestColl::setupCollTest(size_t size) {
mscclpp::DeviceSyncer syncer = {};
uint64_t initFlag = 1;
CUDATHROW(cudaMemcpyToSymbol(deviceSyncer, &syncer, sizeof(mscclpp::DeviceSyncer)));
CUDATHROW(cudaMemcpyToSymbol(allGatherDeviceSyncer, &syncer, sizeof(mscclpp::DeviceSyncer)));
CUDATHROW(cudaMemcpyToSymbol(reduceScatterDeviceSyncer, &syncer, sizeof(mscclpp::DeviceSyncer)));
CUDATHROW(cudaMemcpyToSymbol(ibDeviceSyncer, &syncer, sizeof(mscclpp::DeviceSyncer)));
CUDATHROW(cudaMemcpyToSymbol(globalFlag, &initFlag, sizeof(uint64_t)));
Expand All @@ -899,7 +1053,9 @@ std::vector<KernelRestriction> AllReduceTestColl::getKernelRestrictions() {
true,
3,
.alignedBytes = 16 * worldSize_ /*use ulong2 to transfer data*/,
}};
},
{5, "allreduce5", false, 1, .alignedBytes = 4 * worldSize_},
{6, "allreduce6", false, 1, .alignedBytes = 4 * worldSize_}};
}

class AllReduceTestEngine : public BaseTestEngine {
Expand Down Expand Up @@ -936,9 +1092,9 @@ AllReduceTestEngine::AllReduceTestEngine(const TestArgs& args) : BaseTestEngine(
inPlace_ = isInPlace();
}

bool AllReduceTestEngine::isUsePacket() const { return (args_.kernelNum == 2); }
// Reports whether the selected kernel is one of the packet-based ones (kernel 2 or kernel 6).
bool AllReduceTestEngine::isUsePacket() const {
  const auto kernel = args_.kernelNum;
  return kernel == 2 || kernel == 6;
}

bool AllReduceTestEngine::isInPlace() const { return (args_.kernelNum != 2); }
// Reports whether the selected kernel runs in place; every kernel except 2 and 6 does.
bool AllReduceTestEngine::isInPlace() const {
  const auto kernel = args_.kernelNum;
  return !(kernel == 2 || kernel == 6);
}

void AllReduceTestEngine::allocateBuffer() {
inputBuff_ = mscclpp::allocSharedCuda<int>(args_.maxBytes / sizeof(int));
Expand All @@ -960,6 +1116,12 @@ void AllReduceTestEngine::allocateBuffer() {
getPacketBuff_ = mscclpp::allocSharedCuda<mscclpp::LLPacket>(packetBuffNelem);
putPacketBuff = putPacketBuff_.get();
getPacketBuff = getPacketBuff_.get();
} else if (args_.kernelNum == 6) {
const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t);
// 2x for double-buffering, scratchBuff used to store original data and reduced results
const size_t scratchBuffNelem = nPacket * 2 /*original data & reduced result */ * 2 /* double buffering*/;
scratchPacketBuff_ = mscclpp::allocSharedCuda<mscclpp::LLPacket>(scratchBuffNelem);
scratchPacketBuff = scratchPacketBuff_.get();
}

expectedBuff_ = std::shared_ptr<int[]>(new int[args_.maxBytes / sizeof(int)]);
Expand All @@ -975,22 +1137,33 @@ void AllReduceTestEngine::setupConnections() {
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;

const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t);
const size_t scratchPacketBuffBytes =
nPacket * std::max(args_.nRanksPerNode - 1, 1) * 2 * sizeof(mscclpp::LLPacket);
const size_t packetBuffBytes = nPacket * 2 * sizeof(mscclpp::LLPacket);
setupMeshConnections(smOutOfPlaceChannels_, proxyChannels, inputBuff_.get(), args_.maxBytes, putPacketBuff_.get(),
packetBuffBytes, getPacketBuff_.get(), packetBuffBytes, scratchPacketBuff_.get(),
scratchPacketBuffBytes);

assert(smOutOfPlaceChannels_.size() < sizeof(constSmOutOfPlaceChans) / sizeof(DeviceHandle<mscclpp::SmChannel>));
assert(proxyChannels.size() < sizeof(constDevFstRoundChans) / sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>));

std::vector<DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles(smOutOfPlaceChannels_.size());
getChannelDeviceHandle(smOutOfPlaceChannels_, smChannelDeviceHandles);
CUDATHROW(cudaMemcpyToSymbol(constSmOutOfPlaceChans, smChannelDeviceHandles.data(),
sizeof(DeviceHandle<mscclpp::SmChannel>) * smChannelDeviceHandles.size()));
CUDATHROW(cudaMemcpyToSymbol(constDevFstRoundChans, proxyChannels.data(),
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>) * proxyChannels.size()));
if (args_.kernelNum == 6) {
const size_t scratchPacketBuffBytes = nPacket * 2 * 2 * sizeof(mscclpp::LLPacket);
setupMeshConnections(smOutOfPlaceChannels_, inputBuff_.get(), args_.maxBytes, scratchPacketBuff_.get(),
scratchPacketBuffBytes);
std::vector<DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles(smOutOfPlaceChannels_.size());
getChannelDeviceHandle(smOutOfPlaceChannels_, smChannelDeviceHandles);
CUDATHROW(cudaMemcpyToSymbol(constSmOutOfPlaceChans, smChannelDeviceHandles.data(),
sizeof(DeviceHandle<mscclpp::SmChannel>) * smChannelDeviceHandles.size()));
}
if (args_.kernelNum == 2) {
const size_t scratchPacketBuffBytes =
nPacket * std::max(args_.nRanksPerNode - 1, 1) * 2 * sizeof(mscclpp::LLPacket);
const size_t packetBuffBytes = nPacket * 2 * sizeof(mscclpp::LLPacket);
setupMeshConnections(smOutOfPlaceChannels_, proxyChannels, inputBuff_.get(), args_.maxBytes, putPacketBuff_.get(),
packetBuffBytes, getPacketBuff_.get(), packetBuffBytes, scratchPacketBuff_.get(),
scratchPacketBuffBytes);

assert(smOutOfPlaceChannels_.size() < sizeof(constSmOutOfPlaceChans) / sizeof(DeviceHandle<mscclpp::SmChannel>));
assert(proxyChannels.size() < sizeof(constDevFstRoundChans) / sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>));

std::vector<DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles(smOutOfPlaceChannels_.size());
getChannelDeviceHandle(smOutOfPlaceChannels_, smChannelDeviceHandles);
CUDATHROW(cudaMemcpyToSymbol(constSmOutOfPlaceChans, smChannelDeviceHandles.data(),
sizeof(DeviceHandle<mscclpp::SmChannel>) * smChannelDeviceHandles.size()));
CUDATHROW(cudaMemcpyToSymbol(constDevFstRoundChans, proxyChannels.data(),
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>) * proxyChannels.size()));
}
} else {
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> fstRoundChannels;
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> sndRoundChannels;
Expand Down

0 comments on commit 952f2da

Please sign in to comment.