Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Apr 7, 2024
1 parent d52ef41 commit b03be9a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus) {
// will have either srcs or i_cids
if (op.contains("srcs")) {
operation.nInputs = op["srcs"].size();
operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]);
}
for (int i = 0; i < operation.nInputs; i++) {
operation.inputOffsets[i] = this->chunkSize * (int)op["srcs"][i]["off"];
Expand Down
5 changes: 4 additions & 1 deletion src/include/execution_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ struct Operation {
BufferType dstBufferType;
uint8_t nInputs;
uint8_t nOutputs;
uint8_t inputChannelIndexes[MAX_CHANNEL_PER_OPERATION];
union {
uint8_t inputChannelIndexes[MAX_CHANNEL_PER_OPERATION];
BufferType inputBufferType;
};
uint8_t outputChannelIndexes[MAX_CHANNEL_PER_OPERATION];
uint32_t inputOffsets[MAX_CHANNEL_PER_OPERATION];
uint32_t outputOffsets[MAX_CHANNEL_PER_OPERATION];
Expand Down
38 changes: 20 additions & 18 deletions src/include/execution_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,36 +211,36 @@ MSCCLPP_DEVICE_INLINE void handlePutPacket(uint32_t inputOffsetByBytes, DeviceHa
uint8_t* dstChannelIndexes, uint32_t* dstOffsets, int nDstChannels,
uint32_t size, uint32_t flag) {
for (int index = 0; index < nDstChannels; ++index) {
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(
dstOffsets[index] * sizeof(PacketType), inputOffsetByBytes, size, threadIdx.x, blockDim.x, flag);
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(dstOffsets[index] * 2, inputOffsetByBytes, size,
threadIdx.x, blockDim.x, flag);
}
}

template <typename T, typename PacketType>
MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* output, uint32_t outputOffsetByBytes, T* input,
uint32_t inputOffsetByBytes, DeviceHandle<SmChannel>* smChannels,
uint8_t* dstChannelIndexes, uint32_t* dstOffsets,
uint32_t* srcOffsets, int nDstChannels, int nSrcs, size_t size,
MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes,
T* inputBuff, uint32_t* inputOffsets, int nSrcs,
DeviceHandle<SmChannel>* smChannels, uint8_t* outputChannelIndexes,
uint32_t* outputOffsets, int nDstChannels, size_t size,
uint32_t flag) {
size_t nPackets = size * 2 / sizeof(PacketType);
const uint32_t srcOffset = inputOffsetByBytes / sizeof(PacketValType<PacketType>);
const uint32_t dstOffset = outputOffsetByBytes / sizeof(PacketValType<PacketType>);
PacketValType<PacketType>* src = (PacketValType<PacketType>*)input + srcOffset;
PacketValType<PacketType>* dst = (PacketValType<PacketType>*)output + dstOffset;
const uint32_t srcOffset = dstOffsetByBytes / sizeof(PacketValType<PacketType>);
const uint32_t dstOffset = dstOffsetByBytes / sizeof(PacketValType<PacketType>);
PacketValType<PacketType>* srcPacketValue = (PacketValType<PacketType>*)src + srcOffset;
PacketValType<PacketType>* dstPacketValue = (PacketValType<PacketType>*)dst + dstOffset;
for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
PacketValType<PacketType> data = {};
for (int index = 0; index < nSrcs; ++index) {
PacketType* pkt = (PacketType*)((char*)input + 2 * srcOffsets[index]);
PacketType* pkt = (PacketType*)((char*)inputBuff + 2 * inputOffsets[index]);
PacketValType<PacketType> val = pkt[idx].read(flag);
data = add_vectors<T>(data, val);
}
data = add_vectors<T>(data, src[idx]);
dst[idx] = data;
data = add_vectors<T>(data, srcPacketValue[idx]);
dstPacketValue[idx] = data;

PacketType pkt(data, flag);
for (int index = 0; index < nDstChannels; ++index) {
size_t offset = (dstOffsets[index] * 2) / sizeof(PacketType);
smChannels[dstChannelIndexes[index]].write(offset + idx, pkt);
size_t offset = (outputOffsets[index] * 2) / sizeof(PacketType);
smChannels[outputChannelIndexes[index]].write(offset + idx, pkt);
}
}
}
Expand Down Expand Up @@ -277,6 +277,7 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
DeviceHandle<SimpleProxyChannel>* proxyChannels = localPlan->channels.proxyChannels;
T* src = nullptr;
T* dst = nullptr;
T* tmp = nullptr;
for (int i = 0; i < localPlan->nOperations; i++) {
switch (operations[i].type) {
case OperationType::BARRIER:
Expand Down Expand Up @@ -317,10 +318,11 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
case OperationType::REDUCE_SEND_PACKET:
dst = getBuffer(input, output, scratch, operations[i].dstBufferType);
src = getBuffer(input, output, scratch, operations[i].srcBufferType);
handleReduceSendPacket<T, PacketType>(dst, operations[i].dstOffset, src, operations[i].srcOffset, smChannels,
tmp = getBuffer(input, output, scratch, operations[i].inputBufferType);
handleReduceSendPacket<T, PacketType>(dst, operations[i].dstOffset, src, operations[i].srcOffset, tmp,
operations[i].inputOffsets, operations[i].nInputs, smChannels,
operations[i].outputChannelIndexes, operations[i].outputOffsets,
operations[i].inputOffsets, operations[i].nOutputs, operations[i].nInputs,
operations[i].size, flag);
operations[i].nOutputs, operations[i].size, flag);
break;
case OperationType::COPY_PACKET:
dst = getBuffer(input, output, scratch, operations[i].dstBufferType);
Expand Down
2 changes: 1 addition & 1 deletion test/executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ int main() {
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
bootstrap->initialize(id);
// sleep 20s
std::this_thread::sleep_for(std::chrono::seconds(20));
// std::this_thread::sleep_for(std::chrono::seconds(20));
auto comm = std::make_shared<mscclpp::Communicator>(bootstrap);
CUDACHECK(cudaSetDevice(rank));

Expand Down

0 comments on commit b03be9a

Please sign in to comment.