From d6b03669070e0244280789ba9033831db0ce997d Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 17 Apr 2024 14:53:35 +0000 Subject: [PATCH] packet fix --- src/executor/execution_kernel.cu | 26 ++++++++------- src/executor/execution_plan.cc | 3 +- src/executor/executor.cc | 10 +++--- src/include/execution_kernel.hpp | 55 ++++++++++++++++---------------- 4 files changed, 50 insertions(+), 44 deletions(-) diff --git a/src/executor/execution_kernel.cu b/src/executor/execution_kernel.cu index 7aca5b1ed..4e96af9ab 100644 --- a/src/executor/execution_kernel.cu +++ b/src/executor/execution_kernel.cu @@ -8,33 +8,35 @@ namespace mscclpp { template void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch, - DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, - cudaStream_t stream, uint32_t flag) { + size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, + size_t sharedMemSize, cudaStream_t stream, uint32_t flag) { switch (dataType) { case DataType::INT32: executionKernel<<>>( - rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, plan, flag); + rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag); break; case DataType::UINT32: executionKernel<<>>( - rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, plan, flag); + rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag); break; case DataType::FLOAT16: - executionKernel<<>>(rank, (half*)src, (half*)dst, - (half*)scratch, plan, flag); + executionKernel<<>>( + rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag); break; case DataType::FLOAT32: - executionKernel<<>>(rank, (float*)src, (float*)dst, - (float*)scratch, plan, flag); + executionKernel<<>>( + rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag); break; } } template void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, - void* scratch, DataType dataType, DeviceExecutionPlan* plan, - size_t sharedMemSize, cudaStream_t stream, uint32_t flag); + void* scratch, size_t scratchSize, DataType dataType, + DeviceExecutionPlan* plan, size_t sharedMemSize, + cudaStream_t stream, uint32_t flag); template void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, - void* scratch, DataType dataType, DeviceExecutionPlan* plan, - size_t sharedMemSize, cudaStream_t stream, uint32_t flag); + void* scratch, size_t scratchSize, DataType dataType, + DeviceExecutionPlan* plan, size_t sharedMemSize, + cudaStream_t stream, uint32_t flag); } // namespace mscclpp #endif diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc index 5948d029d..341c9bd9f 100644 --- a/src/executor/execution_plan.cc +++ b/src/executor/execution_plan.cc @@ -109,7 +109,8 @@ std::vector ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c } size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize) const { if (this->isUsingPacket) { - return inputSize / this->inputChunks.at(rank) * this->scratchChunks.at(rank) * 2; + return inputSize / this->inputChunks.at(rank) * this->scratchChunks.at(rank) * 2 /* data + flag*/ * + 2 /*double buffer*/; } return inputSize / this->inputChunks.at(rank) * this->scratchChunks.at(rank); } diff --git a/src/executor/executor.cc b/src/executor/executor.cc index d4112f99b..5b1a827c6 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -263,13 +263,15 @@ struct Executor::Impl { switch (packetType) { case PacketType::LL16: ExecutionKernel::launchKernel( - rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), dataType, - (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), sharedMemSize, stream, ++flag); + rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), + context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), + sharedMemSize, stream, ++flag); break; case PacketType::LL8: ExecutionKernel::launchKernel( - rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), dataType, - (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), sharedMemSize, stream, ++flag); + rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), + context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), + sharedMemSize, stream, ++flag); break; default: throw std::runtime_error("Invalid packet type"); diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp index 9fadb2bc1..4cfd1698a 100644 --- a/src/include/execution_kernel.hpp +++ b/src/include/execution_kernel.hpp @@ -207,22 +207,24 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf } template -MSCCLPP_DEVICE_INLINE void handlePutPacket(uint32_t inputOffsetByBytes, DeviceHandle* smChannels, - uint8_t* dstChannelIndexes, uint32_t* dstOffsets, int nDstChannels, - uint32_t size, uint32_t flag) { +MSCCLPP_DEVICE_INLINE void handlePutPacket(uint32_t inputOffsetByBytes, size_t scratchSize, + DeviceHandle* smChannels, uint8_t* dstChannelIndexes, + uint32_t* dstOffsets, int nDstChannels, uint32_t size, uint32_t flag) { + const size_t scratchBaseOffset = flag & 0x1 ? 0 : scratchSize >> 1; for (int index = 0; index < nDstChannels; ++index) { - smChannels[dstChannelIndexes[index]].putPackets(dstOffsets[index] * 2, inputOffsetByBytes, size, - threadIdx.x, blockDim.x, flag); + smChannels[dstChannelIndexes[index]].putPackets( + scratchBaseOffset + dstOffsets[index] * 2, inputOffsetByBytes, size, threadIdx.x, blockDim.x, flag); } } template MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes, - T* inputBuff, uint32_t* inputOffsets, int nSrcs, + T* inputBuff, size_t inputBuffSize, uint32_t* inputOffsets, int nSrcs, DeviceHandle* smChannels, uint8_t* outputChannelIndexes, uint32_t* outputOffsets, int nDstChannels, size_t size, uint32_t flag) { size_t nPackets = size * 2 / sizeof(PacketType); + const size_t intputBaseOffset = flag & 0x1 ? 0 : inputBuffSize >> 1; const uint32_t srcOffset = srcOffsetByBytes / sizeof(PacketValType); const uint32_t dstOffset = dstOffsetByBytes / sizeof(PacketValType); PacketValType* srcPacketValue = (PacketValType*)src + srcOffset; @@ -230,7 +232,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) { PacketValType data = {}; for (int index = 0; index < nSrcs; ++index) { - PacketType* pkt = (PacketType*)((char*)inputBuff + 2 * inputOffsets[index]); + PacketType* pkt = (PacketType*)((char*)inputBuff + intputBaseOffset + 2 * inputOffsets[index]); PacketValType val = pkt[idx].read(flag); data = add_vectors(data, val); } @@ -239,16 +241,17 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy PacketType pkt(data, flag); for (int index = 0; index < nDstChannels; ++index) { - size_t offset = (outputOffsets[index] * 2) / sizeof(PacketType); + size_t offset = (intputBaseOffset + outputOffsets[index] * 2) / sizeof(PacketType); smChannels[outputChannelIndexes[index]].write(offset + idx, pkt); } } } template -MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, uint32_t dstOffset, uint32_t srcOffset, size_t size, - uint32_t flag) { - PacketType* srcPackets = (PacketType*)((char*)src + 2 * srcOffset); +MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize, uint32_t dstOffset, + uint32_t srcOffset, size_t size, uint32_t flag) { + const size_t outputScratchBaseOffset = flag & 0x1 ? 0 : srcSize >> 1; + PacketType* srcPackets = (PacketType*)((char*)src + outputScratchBaseOffset + 2 * srcOffset); PacketValType* result = (PacketValType*)((char*)dst + dstOffset); size_t nPackets = size * 2 / sizeof(PacketType); for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) { @@ -259,7 +262,7 @@ MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, uint32_t dstOf template __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* input, T* output, T* scratch, - DeviceExecutionPlan* plan, uint32_t flag) { + size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag) { extern __shared__ int4 sharedMem[]; int bid = blockIdx.x; int tid = threadIdx.x; @@ -279,7 +282,6 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu DeviceHandle* proxyChannels = localPlan->channels.proxyChannels; T* src = nullptr; T* dst = nullptr; - T* tmp = nullptr; for (int i = 0; i < nOperations; i++) { Operation* op = &operations[i]; switch (op->type) { @@ -310,21 +312,20 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu op->nInputs, op->size, false); break; case OperationType::PUT_PACKET: - handlePutPacket(op->srcOffset, smChannels, op->outputChannelIndexes, op->outputOffsets, + handlePutPacket(op->srcOffset, scratchSize, smChannels, op->outputChannelIndexes, op->outputOffsets, op->nOutputs, op->size, flag); break; case OperationType::REDUCE_SEND_PACKET: dst = getBuffer(input, output, scratch, op->dstBufferType); src = getBuffer(input, output, scratch, op->srcBufferType); - tmp = getBuffer(input, output, scratch, op->inputBufferType); - handleReduceSendPacket(dst, op->dstOffset, src, op->srcOffset, tmp, op->inputOffsets, - op->nInputs, smChannels, op->outputChannelIndexes, op->outputOffsets, - op->nOutputs, op->size, flag); + handleReduceSendPacket(dst, op->dstOffset, src, op->srcOffset, scratch, scratchSize, + op->inputOffsets, op->nInputs, smChannels, op->outputChannelIndexes, + op->outputOffsets, op->nOutputs, op->size, flag); break; case OperationType::COPY_PACKET: dst = getBuffer(input, output, scratch, op->dstBufferType); src = getBuffer(input, output, scratch, op->srcBufferType); - handleCopyPacket(dst, src, op->dstOffset, op->srcOffset, op->size, flag); + handleCopyPacket(dst, src, scratchSize, op->dstOffset, op->srcOffset, op->size, flag); break; default: break; @@ -338,32 +339,32 @@ class ExecutionKernel { #if defined(MSCCLPP_DEVICE_HIP) template static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch, - DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, cudaStream_t stream, - uint32_t flag = 0) { + size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, + cudaStream_t stream, uint32_t flag = 0) { switch (dataType) { case DataType::INT32: executionKernel<<>>( - rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, plan, flag); + rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag); break; case DataType::UINT32: executionKernel<<>>( - rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, plan, flag); + rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag); break; case DataType::FLOAT16: executionKernel<<>>( - rank, (half*)src, (half*)dst, (half*)scratch, plan, flag); + rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag); break; case DataType::FLOAT32: executionKernel<<>>( - rank, (float*)src, (float*)dst, (float*)scratch, plan, flag); + rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag); break; } } #else // !defined(MSCCLPP_DEVICE_HIP) template static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch, - DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, cudaStream_t stream, - uint32_t flag = 0); + size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, + cudaStream_t stream, uint32_t flag = 0); #endif // !defined(MSCCLPP_DEVICE_HIP) }; } // namespace mscclpp