Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Apr 17, 2024
1 parent 0c2b2c1 commit 8d1b644
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus) {
}
}

ExecutionPlan::ExecutionPlan(const std::string name, const std::string planPath)
ExecutionPlan::ExecutionPlan(const std::string& name, const std::string& planPath)
: impl_(std::make_shared<Impl>(name, planPath)) {}

} // namespace mscclpp
78 changes: 33 additions & 45 deletions src/include/execution_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,51 +283,39 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu

for (int i = 0; i < nOperations; i++) {
Operation& op = operations[i];
switch (op.type) {
case OperationType::BARRIER:
__syncthreads();
break;
case OperationType::SIGNAL:
handleSignal(tid, smChannels, proxyChannels, op.outputChannelIndexes, op.nOutputs, op.channelType);
break;
case OperationType::WAIT:
handleWait(tid, smChannels, proxyChannels, op.inputChannelIndexes, op.nInputs, op.channelType);
break;
case OperationType::GET:
handleGet(smChannels[op.inputChannelIndexes[0]], op.inputOffsets[0], op.dstOffset, op.size);
break;
case OperationType::READ_REDUCE_COPY_SEND:
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleReadReduceCopySend(dst, op.dstOffset, src, op.srcOffset, smChannels, op.outputChannelIndexes,
op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs,
op.nInputs, op.size);
break;
case OperationType::READ_REDUCE_COPY:
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleReadReduceCopySend(dst, op.dstOffset, src, op.srcOffset, smChannels, op.outputChannelIndexes,
op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs,
op.nInputs, op.size, false);
break;
case OperationType::PUT_PACKET:
handlePutPacket<PacketType>(op.srcOffset, scratchSize, smChannels, op.outputChannelIndexes, op.outputOffsets,
op.nOutputs, op.size, flag);
break;
case OperationType::REDUCE_SEND_PACKET:
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleReduceSendPacket<T, PacketType>(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize,
op.inputOffsets, op.nInputs, smChannels, op.outputChannelIndexes,
op.outputOffsets, op.nOutputs, op.size, flag);
break;
case OperationType::COPY_PACKET:
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleCopyPacket<PacketType>(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
break;
default:
break;
if (op.type == OperationType::BARRIER) {
__syncthreads();
} else if (op.type == OperationType::SIGNAL) {
handleSignal(tid, smChannels, proxyChannels, op.outputChannelIndexes, op.nOutputs, op.channelType);
} else if (op.type == OperationType::WAIT) {
handleWait(tid, smChannels, proxyChannels, op.inputChannelIndexes, op.nInputs, op.channelType);
} else if (op.type == OperationType::GET) {
handleGet(smChannels[op.inputChannelIndexes[0]], op.inputOffsets[0], op.dstOffset, op.size);
} else if (op.type == OperationType::READ_REDUCE_COPY_SEND) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleReadReduceCopySend(dst, op.dstOffset, src, op.srcOffset, smChannels, op.outputChannelIndexes,
op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs, op.nInputs,
op.size);
} else if (op.type == OperationType::READ_REDUCE_COPY) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleReadReduceCopySend(dst, op.dstOffset, src, op.srcOffset, smChannels, op.outputChannelIndexes,
op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs, op.nInputs,
op.size, false);
} else if (op.type == OperationType::PUT_PACKET) {
handlePutPacket<PacketType>(op.srcOffset, scratchSize, smChannels, op.outputChannelIndexes, op.outputOffsets,
op.nOutputs, op.size, flag);
} else if (op.type == OperationType::REDUCE_SEND_PACKET) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleReduceSendPacket<T, PacketType>(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize, op.inputOffsets,
op.nInputs, smChannels, op.outputChannelIndexes, op.outputOffsets,
op.nOutputs, op.size, flag);
} else if (op.type == OperationType::COPY_PACKET) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleCopyPacket<PacketType>(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
}
}
}
Expand Down

0 comments on commit 8d1b644

Please sign in to comment.