Skip to content

Commit

Permalink
Add proxy channel related operations (#351)
Browse files Browse the repository at this point in the history
Add Flush, PutWithSignal, and PutWithSignalAndFlush operations
  • Loading branch information
Binyang2014 authored Sep 15, 2024
1 parent 26a8753 commit 7bedb25
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ auto getOpType = [](const std::string& str) {
return mscclpp::OperationType::BARRIER;
} else if (str == "put") {
return mscclpp::OperationType::PUT;
} else if (str == "pws") {
return mscclpp::OperationType::PUT_WITH_SIGNAL;
} else if (str == "pwsf") {
return mscclpp::OperationType::PUT_WITH_SIGNAL_AND_FLUSH;
} else if (str == "get") {
return mscclpp::OperationType::GET;
} else if (str == "copy") {
Expand Down
2 changes: 2 additions & 0 deletions src/include/execution_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ enum class OperationType : uint8_t {
REDUCE_SEND_PACKET,
READ_REDUCE_COPY,
READ_REDUCE_COPY_SEND,
PUT_WITH_SIGNAL,
PUT_WITH_SIGNAL_AND_FLUSH,
};

struct Channels {
Expand Down
35 changes: 34 additions & 1 deletion src/include/execution_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ MSCCLPP_DEVICE_INLINE void handleWait(DeviceHandle<SmChannel>* smChannels,
}
}

// Issues flush() on a set of proxy channels, one channel per participating thread.
//
// proxyChannels  - array of proxy channel device handles to index into.
// channelIndexes - per-slot index into proxyChannels; slot i is handled by the thread with threadIdx.x == i.
// nChannels      - number of valid slots in channelIndexes.
//
// Threads whose threadIdx.x is not below nChannels do nothing.
MSCCLPP_DEVICE_INLINE void handleFlush(DeviceHandle<SimpleProxyChannel>* proxyChannels, uint8_t* channelIndexes,
                                       int nChannels) {
  const int slot = threadIdx.x;
  if (slot >= nChannels) {
    return;
  }
  proxyChannels[channelIndexes[slot]].flush();
}

MSCCLPP_DEVICE_INLINE void handleGet(DeviceHandle<SmChannel>* smChannel, uint8_t* srcChannelIndexes,
uint32_t* dstOffsets, uint32_t* srcOffsets, int count, uint32_t size) {
for (int i = 0; i < count; i++) {
Expand All @@ -194,6 +202,7 @@ MSCCLPP_DEVICE_INLINE void handleGet(DeviceHandle<SmChannel>* smChannel, uint8_t
}
}

template <bool PutWithSignal = false, bool PutWithSignalAndFlush = false>
MSCCLPP_DEVICE_INLINE void handlePut(DeviceHandle<SmChannel>* smChannel,
DeviceHandle<SimpleProxyChannel>* proxyChannels, uint8_t* dstChannelIndexes,
uint32_t* dstOffsets, uint32_t* srcOffsets, int count, uint32_t size,
Expand All @@ -209,7 +218,13 @@ MSCCLPP_DEVICE_INLINE void handlePut(DeviceHandle<SmChannel>* smChannel,
if (chType == ChannelType::PROXY) {
int tid = threadIdx.x;
if (tid < count) {
proxyChannels[dstChannelIndexes[tid]].put(dstOffsets[tid], srcOffsets[tid], size);
if constexpr (PutWithSignal) {
proxyChannels[dstChannelIndexes[tid]].putWithSignal(dstOffsets[tid], srcOffsets[tid], size);
} else if constexpr (PutWithSignalAndFlush) {
proxyChannels[dstChannelIndexes[tid]].putWithSignalAndFlush(dstOffsets[tid], srcOffsets[tid], size);
} else {
proxyChannels[dstChannelIndexes[tid]].put(dstOffsets[tid], srcOffsets[tid], size);
}
}
}
}
Expand Down Expand Up @@ -380,6 +395,12 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
}
}

// Copies `size` units from (src + srcOffset) to (dst + dstOffset) via Element::copy.
//
// dst, src             - base pointers of the destination and source buffers.
// dstOffset, srcOffset - byte offsets applied to the respective base pointer (char-pointer arithmetic).
// size                 - forwarded unchanged to Element::copy.
//
// threadIdx.x and blockDim.x are forwarded to Element::copy; presumably the threads of the block
// share the work there -- confirm against Element::copy.
MSCCLPP_DEVICE_INLINE void handleCopy(void* dst, void* src, uint32_t dstOffset, uint32_t srcOffset, size_t size) {
  char* const to = static_cast<char*>(dst) + dstOffset;
  char* const from = static_cast<char*>(src) + srcOffset;
  Element::copy(to, from, size, threadIdx.x, blockDim.x);
}

template <typename T, typename PacketType = LL16Packet>
__global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* input, T* output, T* scratch,
size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag
Expand Down Expand Up @@ -449,11 +470,23 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
handleSignal(smChannels, proxyChannels, op.outputChannelIndexes, op.nOutputs, op.channelType);
} else if (op.type == OperationType::WAIT) {
handleWait(smChannels, proxyChannels, op.inputChannelIndexes, op.nInputs, op.channelType);
} else if (op.type == OperationType::FLUSH) {
handleFlush(proxyChannels, op.outputChannelIndexes, op.nOutputs);
} else if (op.type == OperationType::PUT) {
handlePut(smChannels, proxyChannels, op.outputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs,
op.size, op.channelType);
} else if (op.type == OperationType::PUT_WITH_SIGNAL) {
handlePut<true>(smChannels, proxyChannels, op.outputChannelIndexes, op.outputOffsets, op.inputOffsets,
op.nOutputs, op.size, op.channelType);
} else if (op.type == OperationType::PUT_WITH_SIGNAL_AND_FLUSH) {
handlePut<false, true>(smChannels, proxyChannels, op.outputChannelIndexes, op.outputOffsets, op.inputOffsets,
op.nOutputs, op.size, op.channelType);
} else if (op.type == OperationType::GET) {
handleGet(smChannels, op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nInputs, op.size);
} else if (op.type == OperationType::COPY) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleCopy(dst, src, op.dstOffset, op.srcOffset, op.size);
} else if (op.type == OperationType::READ_REDUCE_COPY_SEND) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
Expand Down
188 changes: 188 additions & 0 deletions test/execution-files/sendrecv.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
{
"name": "send_recv",
"colletive": "sendrecv",
"protocol": "Simple",
"inplace": false,
"gpus": [
{
"id": 0,
"inputChunks": 1,
"outputChunks": 1,
"scratchChunks": 2,
"chunkGroups": 1,
"threadblocks": [
{
"id": 0,
"ops": [
{
"name": "pwsf",
"o_buff": {
"src": "i",
"dst": "s"
},
"o_cids": [
{
"id": 0,
"off": 1
}
],
"srcs": [
{
"buff": "i",
"off": 0
}
],
"ctype": "proxy",
"cnt": 1
},
{
"name": "wait",
"i_buff": {
"src": "i",
"dst": "s"
},
"i_cids": [
{
"id": 0,
"off": 0
}
],
"ctype": "proxy",
"cnt": 1
},
{
"name": "nop",
"deps": [
{
"tb": 0,
"step": 1
}
]
},
{
"name": "copy",
"src": 0,
"srcbuff": "s",
"srcoff": 1,
"dst": 0,
"dstbuff": "o",
"dstoff": 0,
"ctype": "none",
"cnt": 1
}
],
"channels": [
{
"src": "i",
"dst": "s",
"ctype": "proxy",
"cids": [
0
]
}
]
}
],
"channels": [
{
"srcbuff": "i",
"dstbuff": "s",
"type": "proxy",
"connectedTo": [
1
]
}
]
},
{
"id": 1,
"inputChunks": 1,
"outputChunks": 1,
"scratchChunks": 2,
"chunkGroups": 1,
"threadblocks": [
{
"id": 0,
"ops": [
{
"name": "pwsf",
"o_buff": {
"src": "i",
"dst": "s"
},
"o_cids": [
{
"id": 0,
"off": 1
}
],
"srcs": [
{
"buff": "i",
"off": 0
}
],
"ctype": "proxy",
"cnt": 1
},
{
"name": "wait",
"i_buff": {
"src": "i",
"dst": "s"
},
"i_cids": [
{
"id": 0,
"off": 0
}
],
"ctype": "proxy",
"cnt": 1
},
{
"name": "nop",
"deps": [
{
"tb": 0,
"step": 1
}
]
},
{
"name": "copy",
"src": 1,
"srcbuff": "s",
"srcoff": 1,
"dst": 1,
"dstbuff": "o",
"dstoff": 0,
"ctype": "none",
"cnt": 1
}
],
"channels": [
{
"src": "i",
"dst": "s",
"ctype": "proxy",
"cids": [
0
]
}
]
}
],
"channels": [
{
"srcbuff": "i",
"dstbuff": "s",
"type": "proxy",
"connectedTo": [
0
]
}
]
}
]
}
2 changes: 2 additions & 0 deletions tools/npkit/npkit_trace_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def parse_npkit_event_header(npkit_event_header_path):
"REDUCE_SEND_PACKET",
"READ_REDUCE_COPY",
"READ_REDUCE_COPY_SEND",
"PUT_WITH_SIGNAL",
"PUT_WITH_SIGNAL_AND_FLUSH",
]
executor_op_to_offset = {}
for executor_op in executor_ops:
Expand Down

0 comments on commit 7bedb25

Please sign in to comment.