diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index d002ad735..f95d2aa66 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -400,8 +400,8 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
   setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
   commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);
 
-  if (getenv("COMMUNICATION_COLLECTIVE_DIR")) {
-    std::string collectiveDir = getenv("COMMUNICATION_COLLECTIVE_DIR");
+  if (getenv("MSCCLPP_EXECUTION_PLAN_DIR")) {
+    std::string collectiveDir = getenv("MSCCLPP_EXECUTION_PLAN_DIR");
     for (const auto& entry : std::filesystem::directory_iterator(collectiveDir)) {
       if (entry.is_regular_file()) {
         std::string filename = entry.path().filename().string();
@@ -583,9 +583,10 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
 
   std::vector<executionPlanInstance>& plans = comm->executionPlans["allgather"];
   std::shared_ptr<mscclpp::ExecutionPlan> plan;
-  bool inPlace = sendbuff == recvbuff;
+  void* basePtr = (char*)sendbuff - rank * bytes;
+  bool inPlace = basePtr == recvbuff;
   for (const auto& p : plans) {
-    if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
+    if (bytes * nRank >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
       plan = p.plan;
       break;
     }
diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc
index c8d9c5c17..b5bc4bc99 100644
--- a/src/executor/execution_plan.cc
+++ b/src/executor/execution_plan.cc
@@ -97,7 +97,15 @@ using json = nlohmann::json;
 ExecutionPlan::Impl::Impl(const std::string name, const std::string planPath)
     : name(name), planPath(planPath), isUsingPacket(false) {}
 
-ExecutionPlan::Impl::Impl(const std::string planPath) : planPath(planPath), isUsingPacket(false) {}
+ExecutionPlan::Impl::Impl(const std::string planPath) : planPath(planPath), isUsingPacket(false) {
+  std::ifstream file(this->planPath);
+  json obj = json::parse(file);
+  this->name = obj["name"];
+  this->collective = obj["collective"];
+  this->isInPlace = obj["inplace"];
+  this->minMessageSize = obj.value("min_message_size", 0);
+  this->maxMessageSize = obj.value("max_message_size", std::numeric_limits<size_t>::max());
+}
 
 std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfos(int rank, ChannelType channelType) const {
   auto pred = [channelType](const ChannelInfo& info) { return info.channelType == channelType; };
@@ -186,9 +194,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
                                             size_t constDstOffset) {
   std::ifstream file(this->planPath);
   json obj = json::parse(file);
-  if (this->name.empty()) {
-    this->name = obj["name"];
-  } else if (this->name != obj["name"]) {
+  if (this->name != obj["name"]) {
     throw Error("Plan name does not match", ErrorCode::ExecutorError);
   }
   this->collective = obj["collective"];
@@ -201,7 +207,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
   this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024);
   this->minMessageSize = obj.value("min_message_size", 0);
   this->maxMessageSize = obj.value("max_message_size", std::numeric_limits<size_t>::max());
-  this->isInPlace = obj["in_place"];
+  this->isInPlace = obj["inplace"];
   const auto& gpus = obj["gpus"];
   for (const auto& gpu : gpus) {