Skip to content

Commit

Permalink
Support executors to send packets over ProxyChannel (#344)
Browse files Browse the repository at this point in the history
Co-authored-by: Binyang Li <[email protected]>
  • Loading branch information
caiomcbr and Binyang2014 authored Aug 30, 2024
1 parent 1af62ea commit 4eca6f1
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 19 deletions.
2 changes: 2 additions & 0 deletions src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ auto getOpType = [](const std::string& str) {
return mscclpp::OperationType::REDUCE_SEND_PACKET;
} else if (str == "cpkt") {
return mscclpp::OperationType::COPY_PACKET;
} else if (str == "tpkt") {
return mscclpp::OperationType::TRANSFORM_TO_PACKET;
} else if (str == "rpkt") {
return mscclpp::OperationType::REDUCE_PACKET;
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/include/execution_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum class OperationType : uint8_t {
GET,
COPY,
COPY_PACKET,
COPY_TO_PACKET,
TRANSFORM_TO_PACKET,
SIGNAL,
WAIT,
FLUSH,
Expand Down
48 changes: 31 additions & 17 deletions src/include/execution_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,9 @@ MSCCLPP_DEVICE_INLINE void handlePut(DeviceHandle<SmChannel>* smChannel,
return;
}
if (chType == ChannelType::PROXY) {
for (int i = 0; i < count; i++) {
uint32_t dstOffset = dstOffsets[i];
uint32_t srcOffset = srcOffsets[i];
proxyChannels[dstChannelIndexes[i]].put(dstOffset, srcOffset, size);
int tid = threadIdx.x;
if (tid < count) {
proxyChannels[dstChannelIndexes[tid]].put(dstOffsets[tid], srcOffsets[tid], size);
}
}
}
Expand Down Expand Up @@ -254,12 +253,25 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf

// Sends packet data to `nDstChannels` destinations through either SM channels or proxy channels,
// selected by `chType`.
//
// The base offset is derived from the low bit of `flag`: an odd flag yields 0, an even flag yields
// scratchSize / 2.
//
// Parameters (as used by the body below):
//   scratchSize        size from which the half-way base offset is computed
//   smChannels         handles indexed by dstChannelIndexes[] when chType == ChannelType::SM
//   proxyChannels      handles indexed by dstChannelIndexes[] when chType == ChannelType::PROXY
//   dstChannelIndexes  per-destination index into the selected handle array
//   dstOffsets         per-destination offsets; doubled and added to the base offset in both branches
//   srcOffsets         per-destination offsets; passed as-is on the SM branch, doubled and added to the
//                      base offset on the PROXY branch
//   nDstChannels       number of destinations
//   size               forwarded unchanged to putPackets (SM) and doubled for put (PROXY)
//   flag               forwarded to putPackets; its low bit selects the base offset
//
// NOTE(review): this text comes from a diff view rendered without +/- markers. The two-line parameter
// list beginning with `uint8_t* dstChannelIndexes` and the `for` loop that immediately follows the
// scratchBaseOffset line look like the removed pre-change lines (the braces do not balance with them
// present) - confirm against the repository before compiling.
template <typename PacketType>
MSCCLPP_DEVICE_INLINE void handlePutPacket(size_t scratchSize, DeviceHandle<SmChannel>* smChannels,
uint8_t* dstChannelIndexes, uint32_t* dstOffsets, uint32_t* srcOffsets,
int nDstChannels, uint32_t size, uint32_t flag) {
DeviceHandle<SimpleProxyChannel>* proxyChannels, uint8_t* dstChannelIndexes,
uint32_t* dstOffsets, uint32_t* srcOffsets, int nDstChannels, uint32_t size,
ChannelType chType, uint32_t flag) {
// Odd flag -> base 0, even flag -> base scratchSize / 2.
const size_t scratchBaseOffset = flag & 0x1 ? 0 : scratchSize >> 1;
for (int index = 0; index < nDstChannels; ++index) {
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(scratchBaseOffset + dstOffsets[index] * 2,
srcOffsets[index], size, threadIdx.x, blockDim.x, flag);
// SM branch: every thread walks all destinations and passes its threadIdx.x / blockDim.x to putPackets.
if (chType == ChannelType::SM) {
for (int index = 0; index < nDstChannels; ++index) {
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(
scratchBaseOffset + dstOffsets[index] * 2, srcOffsets[index], size, threadIdx.x, blockDim.x, flag);
}
}
// PROXY branch: thread `tid` handles destination `tid`; threads with tid >= nDstChannels return early.
if (chType == ChannelType::PROXY) {
int tid = threadIdx.x;
if (tid >= nDstChannels) {
return;
}
// For proxy channel, we assume src and dst are in packet format
// NOTE(review): scratchBaseOffset is size_t but the sums are stored in uint32_t, so they would wrap
// if they exceed 32 bits - confirm the offsets stay in range.
uint32_t dstOffset = (dstOffsets[tid] << 1) + scratchBaseOffset;
uint32_t srcOffset = (srcOffsets[tid] << 1) + scratchBaseOffset;
// The size is doubled here, unlike the SM branch which forwards `size` unchanged.
proxyChannels[dstChannelIndexes[tid]].put(dstOffset, srcOffset, size << 1);
}
}

Expand Down Expand Up @@ -298,8 +310,8 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
template <typename PacketType>
MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize, uint32_t dstOffset,
uint32_t srcOffset, size_t size, uint32_t flag) {
const size_t outputScratchBaseOffset = flag & 0x1 ? 0 : srcSize >> 1;
PacketType* srcPackets = (PacketType*)((char*)src + outputScratchBaseOffset + 2 * srcOffset);
const size_t inputScratchBaseOffset = flag & 0x1 ? 0 : srcSize >> 1;
PacketType* srcPackets = (PacketType*)((char*)src + inputScratchBaseOffset + 2 * srcOffset);
PacketPayload<PacketType>* result = (PacketPayload<PacketType>*)((char*)dst + dstOffset);
size_t nPackets = size * 2 / sizeof(PacketType);
for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
Expand All @@ -309,8 +321,10 @@ MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize
}

// Writes `size` worth of data from `src` (at `srcOffset`) into `dst` via mscclpp::putPackets.
//
// The destination offset is rewritten before the call: it is doubled and then shifted by a base
// offset chosen from the low bit of `flag` (odd flag -> 0, even flag -> dstSize / 2).
// threadIdx.x and blockDim.x are forwarded to putPackets together with `flag`.
//
// NOTE(review): this text comes from a diff view rendered without +/- markers. The first signature
// (`handleCopyToPacket`, taking `srcSize`) looks like the removed pre-change declaration and the
// second (`handleTransformToPacket`, taking `dstSize`) like its replacement - confirm against the
// repository before compiling.
template <typename PacketType>
MSCCLPP_DEVICE_INLINE void handleCopyToPacket(void* dst, void* src, size_t srcSize, uint32_t dstOffset,
uint32_t srcOffset, size_t size, uint32_t flag) {
MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, size_t dstSize, uint32_t dstOffset,
uint32_t srcOffset, size_t size, uint32_t flag) {
// Odd flag -> base 0, even flag -> base dstSize / 2.
const size_t outputScratchBaseOffset = flag & 0x1 ? 0 : dstSize >> 1;
// NOTE(review): the doubling presumably reflects packets occupying twice the payload bytes - confirm.
// The size_t sum is stored back into the uint32_t parameter, so it would wrap past 32 bits.
dstOffset = dstOffset * 2 + outputScratchBaseOffset;
mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x, blockDim.x, flag);
}

Expand Down Expand Up @@ -444,8 +458,8 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs, op.nInputs,
op.size, false);
} else if (op.type == OperationType::PUT_PACKET) {
handlePutPacket<PacketType>(scratchSize, smChannels, op.outputChannelIndexes, op.outputOffsets, op.inputOffsets,
op.nOutputs, op.size, flag);
handlePutPacket<PacketType>(scratchSize, smChannels, proxyChannels, op.outputChannelIndexes, op.outputOffsets,
op.inputOffsets, op.nOutputs, op.size, op.channelType, flag);
} else if (op.type == OperationType::REDUCE_SEND_PACKET) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
Expand All @@ -462,10 +476,10 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleCopyPacket<PacketType>(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
} else if (op.type == OperationType::COPY_TO_PACKET) {
} else if (op.type == OperationType::TRANSFORM_TO_PACKET) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleCopyPacket<PacketType>(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
handleTransformToPacket<PacketType>(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
} else if (op.type == OperationType::REDUCE_SEND) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
Expand Down
180 changes: 180 additions & 0 deletions test/execution-files/sendrecv_packet.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
{
"name": "send_recv",
"colletive": "sendrecv",
"protocol": "LL",
"inplace": false,
"gpus": [
{
"id": 0,
"inputChunks": 1,
"outputChunks": 1,
"scratchChunks": 2,
"chunkGroups": 1,
"threadblocks": [
{
"id": 0,
"ops": [
{
"name": "tpkt",
"src": 0,
"srcbuff": "i",
"srcoff": 0,
"dst": 0,
"dstbuff": "s",
"dstoff": 0,
"ctype": "none",
"cnt": 1
},
{
"name": "nop",
"deps": [
{
"tb": 0,
"step": 0
}
]
},
{
"name": "ppkt",
"o_buff": {
"src": "s",
"dst": "s"
},
"o_cids": [
{
"id": 0,
"off": 1
}
],
"srcs": [
{
"buff": "s",
"off": 0
}
],
"ctype": "proxy",
"cnt": 1
},
{
"name": "cpkt",
"src": 0,
"srcbuff": "s",
"srcoff": 1,
"dst": 0,
"dstbuff": "o",
"dstoff": 0,
"ctype": "none",
"cnt": 1
}
],
"channels": [
{
"src": "s",
"dst": "s",
"ctype": "proxy",
"cids": [
0
]
}
]
}
],
"channels": [
{
"srcbuff": "s",
"dstbuff": "s",
"type": "proxy",
"connectedTo": [
1
]
}
]
},
{
"id": 1,
"inputChunks": 1,
"outputChunks": 1,
"scratchChunks": 2,
"chunkGroups": 1,
"threadblocks": [
{
"id": 0,
"ops": [
{
"name": "tpkt",
"src": 1,
"srcbuff": "i",
"srcoff": 0,
"dst": 1,
"dstbuff": "s",
"dstoff": 0,
"ctype": "none",
"cnt": 1
},
{
"name": "nop",
"deps": [
{
"tb": 0,
"step": 0
}
]
},
{
"name": "ppkt",
"o_buff": {
"src": "s",
"dst": "s"
},
"o_cids": [
{
"id": 0,
"off": 1
}
],
"srcs": [
{
"buff": "s",
"off": 0
}
],
"ctype": "proxy",
"cnt": 1
},
{
"name": "cpkt",
"src": 1,
"srcbuff": "s",
"srcoff": 1,
"dst": 1,
"dstbuff": "o",
"dstoff": 0,
"ctype": "none",
"cnt": 1
}
],
"channels": [
{
"src": "s",
"dst": "s",
"ctype": "proxy",
"cids": [
0
]
}
]
}
],
"channels": [
{
"srcbuff": "s",
"dstbuff": "s",
"type": "proxy",
"connectedTo": [
0
]
}
]
}
]
}
2 changes: 1 addition & 1 deletion tools/npkit/npkit_trace_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def parse_npkit_event_header(npkit_event_header_path):
"GET",
"COPY",
"COPY_PACKET",
"COPY_TO_PACKET",
"TRANSFORM_TO_PACKET",
"SIGNAL",
"WAIT",
"FLUSH",
Expand Down

0 comments on commit 4eca6f1

Please sign in to comment.