Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Dec 1, 2024
1 parent e316b82 commit 2b10c92
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 7 deletions.
13 changes: 8 additions & 5 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ struct channelKey {
};

// Key describing when a loaded execution plan applies. The plan-selection loops in this file accept a plan
// when min <= bytes < max and the caller's in-place flag equals isInPlace.
// NOTE(review): this text comes from a rendered diff without +/- markers. The uint64_t *Bytes fields look like
// the removed pre-commit lines and the size_t *Size fields like their replacements - confirm against the real file.
struct planKey {
uint64_t minMessageBytes;
uint64_t maxMessageBytes;
size_t minMessageSize;  // lower bound, compared with `bytes >=` by the selection loops
size_t maxMessageSize;  // upper bound, compared with `bytes <` by the selection loops
bool isInPlace;         // compared with (sendbuff == recvbuff) by the selection loops
};

Expand Down Expand Up @@ -194,7 +194,10 @@ static std::vector<mscclpp::SmChannel> setupSmChannels(ncclComm_t comm,
}

// Builds the (string, executionPlanInstance) pair for the execution plan file at `filename`.
// The statements after the first `return` construct an mscclpp::ExecutionPlan from `filename`, read its
// collective name, and pack its min/max message size and in-place flag into a planKey stored next to the plan.
static std::pair<std::string, executionPlanInstance> load_execution_plan(const char* filename) {
// NOTE(review): this return is unconditional, so as the text stands nothing below it executes. It appears to
// be the pre-commit line that the diff view shows without its "-" marker - confirm against the real file.
return std::make_pair(std::string(filename), executionPlanInstance{});
std::shared_ptr<mscclpp::ExecutionPlan> plan = std::make_shared<mscclpp::ExecutionPlan>(filename);
std::string collective = plan->collective();
// Aggregate-initialises planKey positionally from the plan's accessors.
planKey key{plan->minMessageSize(), plan->maxMessageSize(), plan->isInPlace()};
// The pair is keyed by the collective name reported by the plan, not by the file name.
return std::make_pair(collective, executionPlanInstance{key, plan});
}

static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> setupSmChannelDeviceHandles(
Expand Down Expand Up @@ -530,7 +533,7 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
std::shared_ptr<mscclpp::ExecutionPlan> plan;
bool inPlace = sendbuff == recvbuff;
for (const auto& p : plans) {
if (bytes >= p.key.minMessageBytes && bytes < p.key.maxMessageBytes && inPlace == p.key.isInPlace) {
if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
plan = p.plan;
break;
}
Expand Down Expand Up @@ -582,7 +585,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
std::shared_ptr<mscclpp::ExecutionPlan> plan;
bool inPlace = sendbuff == recvbuff;
for (const auto& p : plans) {
if (bytes >= p.key.minMessageBytes && bytes < p.key.maxMessageBytes && inPlace == p.key.isInPlace) {
if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
plan = p.plan;
break;
}
Expand Down
6 changes: 6 additions & 0 deletions include/mscclpp/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class ExecutionPlan {
ExecutionPlan(const std::string& planPath);
~ExecutionPlan() = default;

/// Returns a copy of the plan's name as stored in the implementation object.
std::string name() const;
/// Returns a copy of the collective identifier stored in the implementation object.
std::string collective() const;
/// Returns the minimum message size recorded for this plan.
size_t minMessageSize() const;
/// Returns the maximum message size recorded for this plan.
size_t maxMessageSize() const;
/// Returns the plan's in-place flag.
bool isInPlace() const;

private:
struct Impl;
std::shared_ptr<Impl> impl_;
Expand Down
12 changes: 12 additions & 0 deletions src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -560,4 +560,16 @@ void ExecutionPlan::Impl::operationsReset() { this->operations.clear(); }
ExecutionPlan::ExecutionPlan(const std::string& name, const std::string& planPath)
: impl_(std::make_shared<Impl>(name, planPath)) {}

ExecutionPlan::ExecutionPlan(const std::string& planPath) : impl_(std::make_shared<Impl>(planPath)) {}

std::string ExecutionPlan::name() const { return this->impl_->name; }

std::string ExecutionPlan::collective() const { return this->impl_->collective; }

size_t ExecutionPlan::minMessageSize() const { return this->impl_->minMessageSize; }

size_t ExecutionPlan::maxMessageSize() const { return this->impl_->maxMessageSize; }

bool ExecutionPlan::isInPlace() const { return this->impl_->isInPlace; }

} // namespace mscclpp
4 changes: 2 additions & 2 deletions src/include/execution_plan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ struct ExecutionPlan::Impl {
size_t inputSize;
size_t outputSize;
int nThreadsPerBlock;
uint64_t minMessageSize;
uint64_t maxMessageSize;
size_t minMessageSize;
size_t maxMessageSize;
bool isInPlace;

private:
Expand Down

0 comments on commit 2b10c92

Please sign in to comment.