Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Dec 1, 2024
1 parent d9c297b commit e316b82
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 91 deletions.
154 changes: 65 additions & 89 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the MIT license.

#include <algorithm>
#include <cstdlib>
#include <filesystem>
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/executor.hpp>
Expand Down Expand Up @@ -38,6 +40,17 @@ struct channelKey {
bool operator==(const channelKey& other) const { return buff == other.buff && bytes == other.bytes; }
};

struct planKey {
uint64_t minMessageBytes;
uint64_t maxMessageBytes;
bool isInPlace;
};

struct executionPlanInstance {
planKey key;
std::shared_ptr<mscclpp::ExecutionPlan> plan;
};

namespace std {
template <>
struct hash<channelKey> {
Expand All @@ -57,17 +70,14 @@ struct ncclComm {
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
std::shared_ptr<mscclpp::Executor> executor;
std::shared_ptr<mscclpp::ExecutionPlan> allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan,
allReduceOPPlan, allGatherIPPlan, allGatherOPPlan, allGatherPacketIPPlan, allGatherPacketOPPlan;
std::unordered_map<std::string, std::vector<executionPlanInstance>> executionPlans;

std::unordered_map<channelKey, ChannelInfo> channelInInfos;
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;

size_t allReduceSmallMessageSizeBoundary, allReduceLargeMessageSizeBoundary;
size_t allGatherSmallMessageSizeBoundary, allGatherLargeMessageSizeBoundary;
uint32_t numScratchBuff;
uint32_t buffFlag;
};
Expand Down Expand Up @@ -183,6 +193,10 @@ static std::vector<mscclpp::SmChannel> setupSmChannels(ncclComm_t comm,
return channels;
}

static std::pair<std::string, executionPlanInstance> load_execution_plan(const char* filename) {
return std::make_pair(std::string(filename), executionPlanInstance{});
}

static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> setupSmChannelDeviceHandles(
const std::vector<mscclpp::SmChannel>& smChannels) {
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
Expand Down Expand Up @@ -383,52 +397,16 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);

if (getenv("ALLREDUCEPKT_IP_JSON_FILE"))
commPtr->allReducePacketIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_IP_JSON_FILE")));
if (getenv("ALLREDUCEPKT_OP_JSON_FILE"))
commPtr->allReducePacketOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_OP_JSON_FILE")));
if (getenv("ALLREDUCE_IP_JSON_FILE"))
commPtr->allReduceIPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_IP_JSON_FILE")));
if (getenv("ALLREDUCE_OP_JSON_FILE"))
commPtr->allReduceOPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE")));
if (getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"))
commPtr->allReduceSmallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"));
else
commPtr->allReduceSmallMessageSizeBoundary = 16 * (1 << 10);
if (getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"))
commPtr->allReduceLargeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"));
else
commPtr->allReduceLargeMessageSizeBoundary = 1 << 20;

if (getenv("ALLGATHERPKT_IP_JSON_FILE"))
commPtr->allGatherPacketIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_IP_JSON_FILE")));
if (getenv("ALLGATHERPKT_OP_JSON_FILE"))
commPtr->allGatherPacketOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_OP_JSON_FILE")));
if (getenv("ALLGATHER_IP_JSON_FILE"))
commPtr->allGatherIPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_IP_JSON_FILE")));
if (getenv("ALLGATHER_OP_JSON_FILE"))
commPtr->allGatherOPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_OP_JSON_FILE")));
if (getenv("ALLGATHER_SMALL_MSG_BOUNDARY"))
commPtr->allGatherSmallMessageSizeBoundary = parseSize(getenv("ALLGATHER_SMALL_MSG_BOUNDARY"));
else
commPtr->allGatherSmallMessageSizeBoundary = (1 << 10);
if (getenv("ALLGATHER_LARGE_MSG_BOUNDARY"))
commPtr->allGatherLargeMessageSizeBoundary = parseSize(getenv("ALLGATHER_LARGE_MSG_BOUNDARY"));
else
commPtr->allGatherLargeMessageSizeBoundary = 1 << 20;

if (commPtr->allReduceSmallMessageSizeBoundary > commPtr->allReduceLargeMessageSizeBoundary)
return ncclInvalidArgument;
if (commPtr->allGatherSmallMessageSizeBoundary > commPtr->allGatherLargeMessageSizeBoundary)
return ncclInvalidArgument;
if (getenv("COMMUNICATION_COLLECTIVE_DIR")) {
std::string collectiveDir = getenv("COMMUNICATION_COLLECTIVE_DIR");
for (const auto& entry : std::filesystem::directory_iterator(collectiveDir)) {
if (entry.is_regular_file()) {
std::string filename = entry.path().filename().string();
auto plan = load_execution_plan(entry.path().c_str());
commPtr->executionPlans[plan.first].push_back(plan.second);
}
}
}

*comm = commPtr;
return ncclSuccess;
Expand Down Expand Up @@ -548,40 +526,39 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
size_t bytes = count * ncclTypeSize(datatype);
int rank = comm->comm->bootstrap()->getRank();

if (bytes < comm->allReduceSmallMessageSizeBoundary) {
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
} else {
std::shared_ptr<mscclpp::ExecutionPlan> plan;
if (bytes <= comm->allReduceLargeMessageSizeBoundary)
plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan;
else {
plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan;
std::vector<executionPlanInstance>& plans = comm->executionPlans["allreduce"];
std::shared_ptr<mscclpp::ExecutionPlan> plan;
bool inPlace = sendbuff == recvbuff;
for (const auto& p : plans) {
if (bytes >= p.key.minMessageBytes && bytes < p.key.maxMessageBytes && inPlace == p.key.isInPlace) {
plan = p.plan;
break;
}
}

if (plan == nullptr)
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
if (plan == nullptr)
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);

switch (datatype) {
case ncclFloat16:
comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan,
stream, mscclpp::PacketType::LL8);
break;
case ncclFloat32:
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32,
*plan, stream, mscclpp::PacketType::LL8);
break;
case ncclBfloat16:
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
mscclpp::DataType::BFLOAT16, *plan, stream, mscclpp::PacketType::LL8);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan,
stream, mscclpp::PacketType::LL8);
break;
default:
return ncclInvalidArgument;
}
switch (datatype) {
case ncclFloat16:
comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan,
stream, mscclpp::PacketType::LL8);
break;
case ncclFloat32:
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32, *plan,
stream, mscclpp::PacketType::LL8);
break;
case ncclBfloat16:
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
mscclpp::DataType::BFLOAT16, *plan, stream, mscclpp::PacketType::LL8);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan,
stream, mscclpp::PacketType::LL8);
break;
default:
return ncclInvalidArgument;
}

return ncclSuccess;
Expand All @@ -601,16 +578,15 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
int rank = comm->comm->bootstrap()->getRank();
int nRank = comm->comm->bootstrap()->getNranks();

if (bytes * nRank < comm->allGatherSmallMessageSizeBoundary)
return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream);

std::vector<executionPlanInstance>& plans = comm->executionPlans["allgather"];
std::shared_ptr<mscclpp::ExecutionPlan> plan;
if (bytes * nRank <= comm->allGatherLargeMessageSizeBoundary)
plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherPacketIPPlan : comm->allGatherPacketOPPlan;
else {
plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherIPPlan : comm->allGatherOPPlan;
bool inPlace = sendbuff == recvbuff;
for (const auto& p : plans) {
if (bytes >= p.key.minMessageBytes && bytes < p.key.maxMessageBytes && inPlace == p.key.isInPlace) {
plan = p.plan;
break;
}
}

if (plan == nullptr) return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream);

switch (datatype) {
Expand Down
1 change: 1 addition & 0 deletions include/mscclpp/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ enum class PacketType {
class ExecutionPlan {
public:
ExecutionPlan(const std::string& name, const std::string& planPath);
ExecutionPlan(const std::string& planPath);
~ExecutionPlan() = default;

private:
Expand Down
10 changes: 9 additions & 1 deletion src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ using json = nlohmann::json;
ExecutionPlan::Impl::Impl(const std::string name, const std::string planPath)
: name(name), planPath(planPath), isUsingPacket(false) {}

ExecutionPlan::Impl::Impl(const std::string planPath) : planPath(planPath), isUsingPacket(false) {}

std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfos(int rank, ChannelType channelType) const {
auto pred = [channelType](const ChannelInfo& info) { return info.channelType == channelType; };
return filter(this->channelInfos.at(rank), pred);
Expand Down Expand Up @@ -184,16 +186,22 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
size_t constDstOffset) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
if (this->name != obj["name"]) {
if (this->name.empty()) {
this->name = obj["name"];
} else if (this->name != obj["name"]) {
throw Error("Plan name does not match", ErrorCode::ExecutorError);
}
this->collective = obj["collective"];
std::string protocol = obj["protocol"];
if (protocol == "LL") {
this->isUsingPacket = true;
}
this->inputSize = inputSize;
this->outputSize = outputSize;
this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024);
this->minMessageSize = obj.value("min_message_size", 0);
this->maxMessageSize = obj.value("max_message_size", std::numeric_limits<uint64_t>::max());
this->isInPlace = obj["in_place"];
const auto& gpus = obj["gpus"];

for (const auto& gpu : gpus) {
Expand Down
7 changes: 6 additions & 1 deletion src/include/execution_plan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ struct ChannelInfo {
struct ExecutionPlan::Impl {
public:
Impl(const std::string name, const std::string planPath);
Impl(const std::string planPath);
~Impl() = default;

std::vector<ChannelInfo> getChannelInfos(int rank, ChannelType channelType) const;
Expand All @@ -85,7 +86,8 @@ struct ExecutionPlan::Impl {
void reset();
void operationsReset();

const std::string name;
std::string name;
std::string collective;
const std::string planPath;
bool isUsingPacket;
// operations for [rank][threadblock] = [operations]
Expand All @@ -106,6 +108,9 @@ struct ExecutionPlan::Impl {
size_t inputSize;
size_t outputSize;
int nThreadsPerBlock;
uint64_t minMessageSize;
uint64_t maxMessageSize;
bool isInPlace;

private:
std::pair<size_t, u_int32_t> calcSizePerRank(int rank, size_t inputSize, size_t outputSize) const;
Expand Down

0 comments on commit e316b82

Please sign in to comment.