diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index acb3e9139..d002ad735 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -41,8 +41,8 @@ struct channelKey { }; struct planKey { - uint64_t minMessageBytes; - uint64_t maxMessageBytes; + size_t minMessageSize; + size_t maxMessageSize; bool isInPlace; }; @@ -194,7 +194,10 @@ static std::vector setupSmChannels(ncclComm_t comm, } static std::pair load_execution_plan(const char* filename) { - return std::make_pair(std::string(filename), executionPlanInstance{}); + std::shared_ptr plan = std::make_shared(filename); + std::string collective = plan->collective(); + planKey key{plan->minMessageSize(), plan->maxMessageSize(), plan->isInPlace()}; + return std::make_pair(collective, executionPlanInstance{key, plan}); } static std::shared_ptr> setupSmChannelDeviceHandles( @@ -530,7 +533,7 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t std::shared_ptr plan; bool inPlace = sendbuff == recvbuff; for (const auto& p : plans) { - if (bytes >= p.key.minMessageBytes && bytes < p.key.maxMessageBytes && inPlace == p.key.isInPlace) { + if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) { plan = p.plan; break; } @@ -582,7 +585,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t std::shared_ptr plan; bool inPlace = sendbuff == recvbuff; for (const auto& p : plans) { - if (bytes >= p.key.minMessageBytes && bytes < p.key.maxMessageBytes && inPlace == p.key.isInPlace) { + if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) { plan = p.plan; break; } diff --git a/include/mscclpp/executor.hpp b/include/mscclpp/executor.hpp index c909d39ea..91261d16c 100644 --- a/include/mscclpp/executor.hpp +++ b/include/mscclpp/executor.hpp @@ -29,6 +29,12 @@ class ExecutionPlan { ExecutionPlan(const std::string& planPath); ~ExecutionPlan() = default; + std::string name() const; + std::string collective() const; + size_t minMessageSize() const; + size_t maxMessageSize() const; + bool isInPlace() const; + private: struct Impl; std::shared_ptr impl_; diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc index e9ad7b397..c8d9c5c17 100644 --- a/src/executor/execution_plan.cc +++ b/src/executor/execution_plan.cc @@ -560,4 +560,16 @@ void ExecutionPlan::Impl::operationsReset() { this->operations.clear(); } ExecutionPlan::ExecutionPlan(const std::string& name, const std::string& planPath) : impl_(std::make_shared(name, planPath)) {} +ExecutionPlan::ExecutionPlan(const std::string& planPath) : impl_(std::make_shared(planPath)) {} + +std::string ExecutionPlan::name() const { return this->impl_->name; } + +std::string ExecutionPlan::collective() const { return this->impl_->collective; } + +size_t ExecutionPlan::minMessageSize() const { return this->impl_->minMessageSize; } + +size_t ExecutionPlan::maxMessageSize() const { return this->impl_->maxMessageSize; } + +bool ExecutionPlan::isInPlace() const { return this->impl_->isInPlace; } + } // namespace mscclpp diff --git a/src/include/execution_plan.hpp b/src/include/execution_plan.hpp index 609bcc1d8..00a0c5e76 100644 --- a/src/include/execution_plan.hpp +++ b/src/include/execution_plan.hpp @@ -108,8 +108,8 @@ struct ExecutionPlan::Impl { size_t inputSize; size_t outputSize; int nThreadsPerBlock; - uint64_t minMessageSize; - uint64_t maxMessageSize; + size_t minMessageSize; + size_t maxMessageSize; bool isInPlace; private: