diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc index e17e41878..22b8b85d5 100644 --- a/src/executor/execution_plan.cc +++ b/src/executor/execution_plan.cc @@ -168,8 +168,9 @@ std::vector ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize) const { size_t sizePerChunk = 0; - if (this->inputChunks.at(rank) != 0) - sizePerChunk = inputSize / this->inputChunks.at(rank); + size_t inputChunks = this->inputChunks.at(rank); + if (inputChunks != 0) + sizePerChunk = (inputSize + inputChunks - 1) / this->inputChunks.at(rank); else throw mscclpp::Error("Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError); diff --git a/src/executor/executor.cc b/src/executor/executor.cc index f960fa0f8..5cb958028 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -168,8 +168,7 @@ struct Executor::Impl { ExecutionContext context; size_t maxScratchBufferSize = plan.impl_->getMaxScratchBufferSize(rank); - size_t scratchBufferSize = - std::min(plan.impl_->getScratchBufferSize(rank, sendMemRange), maxScratchBufferSize); + size_t scratchBufferSize = std::min(plan.impl_->getScratchBufferSize(rank, sendMemRange), maxScratchBufferSize); std::shared_ptr scratchBuffer; if (isNvlsSupported()) { scratchBuffer = allocSharedPhysicalCuda(scratchBufferSize); diff --git a/src/include/execution_plan.hpp b/src/include/execution_plan.hpp index 3af585508..13815a69e 100644 --- a/src/include/execution_plan.hpp +++ b/src/include/execution_plan.hpp @@ -72,7 +72,7 @@ struct ExecutionPlan::Impl { std::vector getNvlsInfos(int rank) const; std::vector getConnectedPeers(int rank) const; std::vector getConnectedBufferTypes(int rank) const; - size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const; + size_t getScratchBufferSize(int rank, size_t inputSize) const; size_t getMaxScratchBufferSize(int rank) const; std::vector getOperations(int rank, int threadblock) const; int getThreadblockCount(int rank) const;