Skip to content

Commit

Permalink
Add separate option for double scratch buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
yzygitzh committed Oct 22, 2024
1 parent b72decb commit 96a6d56
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 43 deletions.
12 changes: 10 additions & 2 deletions src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,14 @@ size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, siz
else
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);

size_t scratchBufferSize = sizePerRank * this->scratchChunks.at(rank);
if (this->isUsingPacket) {
return sizePerRank * this->scratchChunks.at(rank) * 2 /* data + flag*/ * 2 /*double buffer*/;
scratchBufferSize *= 2; // data + flag
}
return sizePerRank * this->scratchChunks.at(rank);
if (this->isUsingDoubleScratchBuffer) {
scratchBufferSize *= 2; // double buffer
}
return scratchBufferSize;
}
std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadblock) const {
return this->operations.at(rank)[threadblock];
Expand All @@ -170,6 +174,8 @@ int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->oper

int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; }

bool ExecutionPlan::Impl::getIsUsingDoubleScratchBuffer() const { return this->getIsUsingDoubleScratchBuffer; }

void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
size_t constDstOffset) {
std::ifstream file(this->planPath);
Expand All @@ -182,6 +188,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
this->isUsingPacket = true;
}
this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024);
this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
const auto& gpus = obj["gpus"];

for (const auto& gpu : gpus) {
Expand Down Expand Up @@ -209,6 +216,7 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t output
if (protocol == "LL") {
this->isUsingPacket = true;
}
this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
const auto& gpus = obj["gpus"];

for (const auto& gpu : gpus) {
Expand Down
26 changes: 18 additions & 8 deletions src/executor/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ struct ExecutionContext {
size_t scratchBufferSize;
std::shared_ptr<char> deviceExecutionPlansBuffer;
int nthreadsPerBlock;
bool isUsingDoubleScratchBuffer;
};

struct Executor::Impl {
Expand Down Expand Up @@ -106,6 +107,7 @@ struct Executor::Impl {
context.scratchBufferSize = scratchBufferSize;
context.proxyService = std::make_shared<ProxyService>();
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
context.isUsingDoubleScratchBuffer = plan.impl_->getIsUsingDoubleScratchBuffer();
this->setupConnections(context, rank, plan);
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
Expand Down Expand Up @@ -306,6 +308,14 @@ struct Executor::Impl {
cudaStream_t stream, PacketType packetType) {
static uint32_t flag = 0;
int nthreadblocks = context.deviceExecutionPlans.size();
char* kernelScratchBufferPtr = context.scratchBuffer.get();
size_t kernelScratchBufferSize = context.scratchBufferSize;
if (context.isUsingDoubleScratchBuffer) {
kernelScratchBufferSize /= 2;
if (flag % 2) {
kernelScratchBufferPtr += kernelScratchBufferSize;
}
}
#if defined(ENABLE_NPKIT)
#if defined(__HIP_PLATFORM_AMD__)
if (nthreadblocks > NPKIT_MAX_NUM_GPU_THREADBLOCKS) {
Expand All @@ -321,16 +331,16 @@ struct Executor::Impl {
#endif
switch (packetType) {
case PacketType::LL16:
ExecutionKernel::launchKernel<LL16Packet>(
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
sharedMemSize, stream, ++flag);
ExecutionKernel::launchKernel<LL16Packet>(rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff,
(void*)kernelScratchBufferPtr, kernelScratchBufferSize, dataType,
(DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
sharedMemSize, stream, ++flag);
break;
case PacketType::LL8:
ExecutionKernel::launchKernel<LL8Packet>(
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
sharedMemSize, stream, ++flag);
ExecutionKernel::launchKernel<LL8Packet>(rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff,
(void*)kernelScratchBufferPtr, kernelScratchBufferSize, dataType,
(DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
sharedMemSize, stream, ++flag);
break;
default:
throw Error("Invalid packet type", ErrorCode::ExecutorError);
Expand Down
62 changes: 29 additions & 33 deletions src/include/execution_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,15 +277,14 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf
}

template <typename PacketType>
MSCCLPP_DEVICE_INLINE void handlePutPacket(size_t scratchSize, DeviceHandle<SmChannel>* smChannels,
MSCCLPP_DEVICE_INLINE void handlePutPacket(DeviceHandle<SmChannel>* smChannels,
DeviceHandle<SimpleProxyChannel>* proxyChannels, uint8_t* dstChannelIndexes,
uint32_t* dstOffsets, uint32_t* srcOffsets, int nDstChannels, uint32_t size,
ChannelType chType, uint32_t flag) {
const size_t scratchBaseOffset = flag & 0x1 ? 0 : scratchSize >> 1;
if (chType == ChannelType::SM) {
for (int index = 0; index < nDstChannels; ++index) {
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(
scratchBaseOffset + dstOffsets[index] * 2, srcOffsets[index], size, threadIdx.x, blockDim.x, flag);
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(dstOffsets[index] * 2, srcOffsets[index], size,
threadIdx.x, blockDim.x, flag);
}
}
if (chType == ChannelType::PROXY) {
Expand All @@ -294,8 +293,8 @@ MSCCLPP_DEVICE_INLINE void handlePutPacket(size_t scratchSize, DeviceHandle<SmCh
return;
}
// For proxy channel, we assume src and dst are in packet format
uint32_t dstOffset = (dstOffsets[tid] << 1) + scratchBaseOffset;
uint32_t srcOffset = (srcOffsets[tid] << 1) + scratchBaseOffset;
uint32_t dstOffset = dstOffsets[tid] << 1;
uint32_t srcOffset = srcOffsets[tid] << 1;
proxyChannels[dstChannelIndexes[tid]].put(dstOffset, srcOffset, size << 1);
}
}
Expand All @@ -307,15 +306,14 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
uint32_t* outputOffsets, int nDstChannels, size_t size,
uint32_t flag) {
size_t nPackets = size * 2 / sizeof(PacketType);
const size_t intputBaseOffset = flag & 0x1 ? 0 : inputBuffSize >> 1;
const uint32_t srcOffset = srcOffsetByBytes / sizeof(PacketPayload<PacketType>);
const uint32_t dstOffset = dstOffsetByBytes / sizeof(PacketPayload<PacketType>);
PacketPayload<PacketType>* srcPacketPayload = (PacketPayload<PacketType>*)src + srcOffset;
PacketPayload<PacketType>* dstPacketPayload = (PacketPayload<PacketType>*)dst + dstOffset;
for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
PacketPayload<PacketType> data = {};
for (int index = 0; index < nSrcs; ++index) {
PacketType* pkt = (PacketType*)((char*)inputBuff + intputBaseOffset + 2 * inputOffsets[index]);
PacketType* pkt = (PacketType*)((char*)inputBuff + 2 * inputOffsets[index]);
PacketPayload<PacketType> val = pkt[idx].read(flag);
data = add_vectors<T>(data, val);
}
Expand All @@ -325,7 +323,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
if (SendToRemote) {
PacketType pkt(data, flag);
for (int index = 0; index < nDstChannels; ++index) {
size_t offset = (intputBaseOffset + outputOffsets[index] * 2) / sizeof(PacketType);
size_t offset = outputOffsets[index] * 2 / sizeof(PacketType);
smChannels[outputChannelIndexes[index]].write(offset + idx, pkt);
}
}
Expand All @@ -335,8 +333,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
template <typename PacketType>
MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize, uint32_t dstOffset,
uint32_t srcOffset, size_t size, uint32_t flag) {
const size_t inputScratchBaseOffset = flag & 0x1 ? 0 : srcSize >> 1;
PacketType* srcPackets = (PacketType*)((char*)src + inputScratchBaseOffset + 2 * srcOffset);
PacketType* srcPackets = (PacketType*)((char*)src + 2 * srcOffset);
PacketPayload<PacketType>* result = (PacketPayload<PacketType>*)((char*)dst + dstOffset);
size_t nPackets = size * 2 / sizeof(PacketType);
for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
Expand All @@ -348,8 +345,7 @@ MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize
template <typename PacketType>
MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, size_t dstSize, uint32_t dstOffset,
uint32_t srcOffset, size_t size, uint32_t flag) {
const size_t outputScratchBaseOffset = flag & 0x1 ? 0 : dstSize >> 1;
dstOffset = dstOffset * 2 + outputScratchBaseOffset;
dstOffset = dstOffset * 2;
mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x, blockDim.x, flag);
}

Expand Down Expand Up @@ -403,7 +399,7 @@ MSCCLPP_DEVICE_INLINE void handleCopy(void* dst, void* src, uint32_t dstOffset,

template <typename T, typename PacketType = LL16Packet>
__global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* input, T* output, T* scratch,
size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag
DeviceExecutionPlan* plan, uint32_t flag
#if defined(ENABLE_NPKIT)
,
NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) {
Expand Down Expand Up @@ -501,28 +497,28 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs, op.nInputs,
op.size, false);
} else if (op.type == OperationType::PUT_PACKET) {
handlePutPacket<PacketType>(scratchSize, smChannels, proxyChannels, op.outputChannelIndexes, op.outputOffsets,
op.inputOffsets, op.nOutputs, op.size, op.channelType, flag);
handlePutPacket<PacketType>(smChannels, proxyChannels, op.outputChannelIndexes, op.outputOffsets, op.inputOffsets,
op.nOutputs, op.size, op.channelType, flag);
} else if (op.type == OperationType::REDUCE_SEND_PACKET) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleReduceSendPacket<T, PacketType>(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize, op.inputOffsets,
op.nInputs, smChannels, op.outputChannelIndexes, op.outputOffsets,
op.nOutputs, op.size, flag);
handleReduceSendPacket<T, PacketType>(dst, op.dstOffset, src, op.srcOffset, scratch, op.inputOffsets, op.nInputs,
smChannels, op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size,
flag);
} else if (op.type == OperationType::REDUCE_PACKET) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleReduceSendPacket<T, PacketType, false>(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize,
op.inputOffsets, op.nInputs, smChannels, op.outputChannelIndexes,
op.outputOffsets, op.nOutputs, op.size, flag);
handleReduceSendPacket<T, PacketType, false>(dst, op.dstOffset, src, op.srcOffset, scratch, op.inputOffsets,
op.nInputs, smChannels, op.outputChannelIndexes, op.outputOffsets,
op.nOutputs, op.size, flag);
} else if (op.type == OperationType::COPY_PACKET) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleCopyPacket<PacketType>(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
handleCopyPacket<PacketType>(dst, src, op.dstOffset, op.srcOffset, op.size, flag);
} else if (op.type == OperationType::TRANSFORM_TO_PACKET) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
handleTransformToPacket<PacketType>(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
handleTransformToPacket<PacketType>(dst, src, op.dstOffset, op.srcOffset, op.size, flag);
} else if (op.type == OperationType::REDUCE_SEND) {
T* dst = getBuffer(input, output, scratch, op.dstBufferType);
T* src = getBuffer(input, output, scratch, op.srcBufferType);
Expand All @@ -548,12 +544,12 @@ class ExecutionKernel {
#if defined(MSCCLPP_DEVICE_HIP)
template <typename PacketType>
static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
cudaStream_t stream, uint32_t flag = 0) {
DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, cudaStream_t stream,
uint32_t flag = 0) {
switch (dataType) {
case DataType::INT32:
executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -563,7 +559,7 @@ class ExecutionKernel {
break;
case DataType::UINT32:
executionKernel<uint32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -573,7 +569,7 @@ class ExecutionKernel {
break;
case DataType::FLOAT16:
executionKernel<half, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
rank, (half*)src, (half*)dst, (half*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -583,7 +579,7 @@ class ExecutionKernel {
break;
case DataType::FLOAT32:
executionKernel<float, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag
rank, (float*)src, (float*)dst, (float*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -593,7 +589,7 @@ class ExecutionKernel {
break;
case DataType::BFLOAT16:
executionKernel<__bfloat16, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -606,8 +602,8 @@ class ExecutionKernel {
#else // !defined(MSCCLPP_DEVICE_HIP)
template <typename PacketType>
static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
cudaStream_t stream, uint32_t flag = 0);
DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, cudaStream_t stream,
uint32_t flag = 0);
#endif // !defined(MSCCLPP_DEVICE_HIP)
};
} // namespace mscclpp
Expand Down
2 changes: 2 additions & 0 deletions src/include/execution_plan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct ExecutionPlan::Impl {
std::vector<Operation> getOperations(int rank, int threadblock) const;
int getThreadblockCount(int rank) const;
int getNThreadsPerBlock() const;
bool getIsUsingDoubleScratchBuffer() const;

void loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset);
void lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset);
Expand Down Expand Up @@ -96,6 +97,7 @@ struct ExecutionPlan::Impl {
size_t inputSize;
size_t outputSize;
int nThreadsPerBlock;
bool isUsingDoubleScratchBuffer;

private:
std::pair<size_t, u_int32_t> calcSizePerRank(int rank, size_t inputSize, size_t outputSize) const;
Expand Down

0 comments on commit 96a6d56

Please sign in to comment.