Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Dec 7, 2024
1 parent 8cb4269 commit 06f58e3
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,9 @@ std::vector<BufferType> ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c

size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize) const {
size_t sizePerChunk = 0;
- if (this->inputChunks.at(rank) != 0)
- sizePerChunk = inputSize / this->inputChunks.at(rank);
+ size_t inputChunks = this->inputChunks.at(rank);
+ if (inputChunks != 0)
+ sizePerChunk = (inputSize + inputChunks - 1) / this->inputChunks.at(rank);
else
throw mscclpp::Error("Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);

Expand Down
3 changes: 1 addition & 2 deletions src/executor/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ struct Executor::Impl {

ExecutionContext context;
size_t maxScratchBufferSize = plan.impl_->getMaxScratchBufferSize(rank);
- size_t scratchBufferSize =
- std::min(plan.impl_->getScratchBufferSize(rank, sendMemRange), maxScratchBufferSize);
+ size_t scratchBufferSize = std::min(plan.impl_->getScratchBufferSize(rank, sendMemRange), maxScratchBufferSize);
std::shared_ptr<char> scratchBuffer;
if (isNvlsSupported()) {
scratchBuffer = allocSharedPhysicalCuda<char>(scratchBufferSize);
Expand Down
2 changes: 1 addition & 1 deletion src/include/execution_plan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ struct ExecutionPlan::Impl {
std::vector<NvlsInfo> getNvlsInfos(int rank) const;
std::vector<int> getConnectedPeers(int rank) const;
std::vector<BufferType> getConnectedBufferTypes(int rank) const;
- size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const;
+ size_t getScratchBufferSize(int rank, size_t inputSize) const;
size_t getMaxScratchBufferSize(int rank) const;
std::vector<Operation> getOperations(int rank, int threadblock) const;
int getThreadblockCount(int rank) const;
Expand Down

0 comments on commit 06f58e3

Please sign in to comment.