Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Select algo according to json config #396

Merged
merged 10 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 72 additions & 89 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT license.

#include <algorithm>
#include <filesystem>
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/executor.hpp>
Expand Down Expand Up @@ -38,6 +39,17 @@ struct channelKey {
bool operator==(const channelKey& other) const { return buff == other.buff && bytes == other.bytes; }
};

struct planKey {
size_t minMessageSize;
size_t maxMessageSize;
bool isInPlace;
};

struct executionPlanInstance {
planKey key;
std::shared_ptr<mscclpp::ExecutionPlan> plan;
};

namespace std {
template <>
struct hash<channelKey> {
Expand All @@ -57,17 +69,14 @@ struct ncclComm {
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
std::shared_ptr<mscclpp::Executor> executor;
std::shared_ptr<mscclpp::ExecutionPlan> allReducePacketIPPlan, allReducePacketOPPlan, allReduceIPPlan,
allReduceOPPlan, allGatherIPPlan, allGatherOPPlan, allGatherPacketIPPlan, allGatherPacketOPPlan;
std::unordered_map<std::string, std::vector<executionPlanInstance>> executionPlans;

std::unordered_map<channelKey, ChannelInfo> channelInInfos;
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;

size_t allReduceSmallMessageSizeBoundary, allReduceLargeMessageSizeBoundary;
size_t allGatherSmallMessageSizeBoundary, allGatherLargeMessageSizeBoundary;
uint32_t numScratchBuff;
uint32_t buffFlag;
};
Expand Down Expand Up @@ -183,6 +192,13 @@ static std::vector<mscclpp::SmChannel> setupSmChannels(ncclComm_t comm,
return channels;
}

static std::pair<std::string, executionPlanInstance> loadExecutionPlan(const std::string& filename) {
std::shared_ptr<mscclpp::ExecutionPlan> plan = std::make_shared<mscclpp::ExecutionPlan>(filename);
std::string collective = plan->collective();
planKey key{plan->minMessageSize(), plan->maxMessageSize(), plan->isInPlace()};
return std::make_pair(collective, executionPlanInstance{key, plan});
}

static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> setupSmChannelDeviceHandles(
const std::vector<mscclpp::SmChannel>& smChannels) {
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
Expand Down Expand Up @@ -383,52 +399,19 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);

if (getenv("ALLREDUCEPKT_IP_JSON_FILE"))
commPtr->allReducePacketIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_IP_JSON_FILE")));
if (getenv("ALLREDUCEPKT_OP_JSON_FILE"))
commPtr->allReducePacketOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allreduce_packet", getenv("ALLREDUCEPKT_OP_JSON_FILE")));
if (getenv("ALLREDUCE_IP_JSON_FILE"))
commPtr->allReduceIPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_IP_JSON_FILE")));
if (getenv("ALLREDUCE_OP_JSON_FILE"))
commPtr->allReduceOPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allreduce", getenv("ALLREDUCE_OP_JSON_FILE")));
if (getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"))
commPtr->allReduceSmallMessageSizeBoundary = parseSize(getenv("ALLREDUCE_SMALL_MSG_BOUNDARY"));
else
commPtr->allReduceSmallMessageSizeBoundary = 16 * (1 << 10);
if (getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"))
commPtr->allReduceLargeMessageSizeBoundary = parseSize(getenv("ALLREDUCE_LARGE_MSG_BOUNDARY"));
else
commPtr->allReduceLargeMessageSizeBoundary = 1 << 20;

if (getenv("ALLGATHERPKT_IP_JSON_FILE"))
commPtr->allGatherPacketIPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_IP_JSON_FILE")));
if (getenv("ALLGATHERPKT_OP_JSON_FILE"))
commPtr->allGatherPacketOPPlan = std::make_shared<mscclpp::ExecutionPlan>(
mscclpp::ExecutionPlan("allgather_pkt", getenv("ALLGATHERPKT_OP_JSON_FILE")));
if (getenv("ALLGATHER_IP_JSON_FILE"))
commPtr->allGatherIPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_IP_JSON_FILE")));
if (getenv("ALLGATHER_OP_JSON_FILE"))
commPtr->allGatherOPPlan =
std::make_shared<mscclpp::ExecutionPlan>(mscclpp::ExecutionPlan("allgather", getenv("ALLGATHER_OP_JSON_FILE")));
if (getenv("ALLGATHER_SMALL_MSG_BOUNDARY"))
commPtr->allGatherSmallMessageSizeBoundary = parseSize(getenv("ALLGATHER_SMALL_MSG_BOUNDARY"));
else
commPtr->allGatherSmallMessageSizeBoundary = (1 << 10);
if (getenv("ALLGATHER_LARGE_MSG_BOUNDARY"))
commPtr->allGatherLargeMessageSizeBoundary = parseSize(getenv("ALLGATHER_LARGE_MSG_BOUNDARY"));
else
commPtr->allGatherLargeMessageSizeBoundary = 1 << 20;

if (commPtr->allReduceSmallMessageSizeBoundary > commPtr->allReduceLargeMessageSizeBoundary)
return ncclInvalidArgument;
if (commPtr->allGatherSmallMessageSizeBoundary > commPtr->allGatherLargeMessageSizeBoundary)
return ncclInvalidArgument;
if (getenv("MSCCLPP_EXECUTION_PLAN_DIR")) {
std::string collectiveDir = getenv("MSCCLPP_EXECUTION_PLAN_DIR");
if (!std::filesystem::is_directory(collectiveDir)) {
return ncclInvalidArgument;
}
for (const auto& entry : std::filesystem::directory_iterator(collectiveDir)) {
Binyang2014 marked this conversation as resolved.
Show resolved Hide resolved
if (entry.is_regular_file()) {
std::string filename = entry.path().filename().string();
Binyang2014 marked this conversation as resolved.
Show resolved Hide resolved
auto plan = loadExecutionPlan(entry.path());
commPtr->executionPlans[plan.first].push_back(plan.second);
}
}
}

*comm = commPtr;
return ncclSuccess;
Expand Down Expand Up @@ -548,40 +531,39 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
size_t bytes = count * ncclTypeSize(datatype);
int rank = comm->comm->bootstrap()->getRank();

if (bytes < comm->allReduceSmallMessageSizeBoundary) {
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
} else {
std::shared_ptr<mscclpp::ExecutionPlan> plan;
if (bytes <= comm->allReduceLargeMessageSizeBoundary)
plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan;
else {
plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan;
std::vector<executionPlanInstance>& plans = comm->executionPlans["allreduce"];
std::shared_ptr<mscclpp::ExecutionPlan> plan;
bool inPlace = sendbuff == recvbuff;
for (const auto& p : plans) {
if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
plan = p.plan;
break;
}
}

if (plan == nullptr)
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
if (plan == nullptr)
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);

switch (datatype) {
case ncclFloat16:
comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan,
stream, mscclpp::PacketType::LL8);
break;
case ncclFloat32:
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32,
*plan, stream, mscclpp::PacketType::LL8);
break;
case ncclBfloat16:
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
mscclpp::DataType::BFLOAT16, *plan, stream, mscclpp::PacketType::LL8);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan,
stream, mscclpp::PacketType::LL8);
break;
default:
return ncclInvalidArgument;
}
switch (datatype) {
case ncclFloat16:
comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan,
stream, mscclpp::PacketType::LL8);
break;
case ncclFloat32:
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32, *plan,
stream, mscclpp::PacketType::LL8);
break;
case ncclBfloat16:
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
mscclpp::DataType::BFLOAT16, *plan, stream, mscclpp::PacketType::LL8);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan,
stream, mscclpp::PacketType::LL8);
break;
default:
return ncclInvalidArgument;
}

return ncclSuccess;
Expand All @@ -601,16 +583,17 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
int rank = comm->comm->bootstrap()->getRank();
int nRank = comm->comm->bootstrap()->getNranks();

if (bytes * nRank < comm->allGatherSmallMessageSizeBoundary)
return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream);

std::vector<executionPlanInstance>& plans = comm->executionPlans["allgather"];
std::shared_ptr<mscclpp::ExecutionPlan> plan;
if (bytes * nRank <= comm->allGatherLargeMessageSizeBoundary)
plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherPacketIPPlan : comm->allGatherPacketOPPlan;
else {
plan = (sendbuff == (char*)recvbuff + rank * bytes) ? comm->allGatherIPPlan : comm->allGatherOPPlan;
void* basePtr = (char*)sendbuff - rank * bytes;
bool inPlace = basePtr == recvbuff;
const size_t totalBytes = bytes * nRank;
for (const auto& p : plans) {
if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
plan = p.plan;
break;
}
}

if (plan == nullptr) return ncclAllGatherFallback(sendbuff, recvbuff, sendcount, datatype, comm, stream);

switch (datatype) {
Expand Down
8 changes: 7 additions & 1 deletion include/mscclpp/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,15 @@ enum class PacketType {

class ExecutionPlan {
public:
ExecutionPlan(const std::string& name, const std::string& planPath);
ExecutionPlan(const std::string& planPath);
~ExecutionPlan() = default;

std::string name() const;
std::string collective() const;
size_t minMessageSize() const;
size_t maxMessageSize() const;
bool isInPlace() const;

private:
struct Impl;
std::shared_ptr<Impl> impl_;
Expand Down
6 changes: 5 additions & 1 deletion python/mscclpp/executor_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ void register_executor(nb::module_& m) {
nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);

nb::class_<ExecutionPlan>(m, "ExecutionPlan")
.def(nb::init<const std::string, const std::string>(), nb::arg("name"), nb::arg("planPath"));
.def(nb::init<const std::string>(), nb::arg("planPath"))
.def("name", &ExecutionPlan::name)
.def("collective", &ExecutionPlan::collective)
.def("min_message_size", &ExecutionPlan::minMessageSize)
.def("max_message_size", &ExecutionPlan::maxMessageSize);

nb::class_<Executor>(m, "Executor")
.def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
Expand Down
22 changes: 10 additions & 12 deletions python/test/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def bench_time(n_iters: int, n_graph_iters: int, func):


def bench_correctness(
execution_plan_name: str,
collective: str,
input_buf: cp.ndarray,
result_buf: cp.ndarray,
test_buf: cp.ndarray,
Expand All @@ -72,9 +72,9 @@ def bench_correctness(
type_size = cp.dtype(parse_dtype(dtype_str)).itemsize

fill_data_kernel_name = "fill_data_%s" % dtype_str
if "allgather" in execution_plan_name:
if "allgather" in collective:
coll = "all_gather"
elif "reducescatter" in execution_plan_name:
elif "reducescatter" in collective:
coll = "reduce_scatter"
else:
coll = "all_reduce"
Expand Down Expand Up @@ -142,7 +142,7 @@ def allocate_buffer(nelems, dtype):


def build_bufs(
execution_plan_name: str,
collective: str,
size: int,
in_place: bool,
dtype: cp.dtype,
Expand All @@ -153,7 +153,7 @@ def build_bufs(
assert (size % type_size) == 0, "size %d not multiple of type size %d" % (size, type_size)
nelems = size // type_size

if "allgather" in execution_plan_name:
if "allgather" in collective:
assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks)
nelems_input = nelems if in_place else nelems // num_ranks
else:
Expand All @@ -162,7 +162,7 @@ def build_bufs(

result_buf = allocate_buffer(nelems_output, dtype=dtype)
if in_place:
if "allgather" in execution_plan_name:
if "allgather" in collective:
input_buf = cp.split(result_buf, num_ranks)[rank]
else:
input_buf = result_buf
Expand All @@ -174,7 +174,6 @@ def build_bufs(


def main(
execution_plan_name: str,
execution_plan_path: str,
size: int,
in_place: bool = True,
Expand All @@ -189,11 +188,12 @@ def main(
npkit_dump_dir = os.getenv("NPKIT_DUMP_DIR")
if npkit_dump_dir is not None:
npkit.init(mscclpp_group.my_rank)
execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path)
execution_plan = ExecutionPlan(execution_plan_path)
collective = execution_plan.collective()

dtype = parse_dtype(dtype_str)
input_buf, result_buf, test_buf = build_bufs(
execution_plan_name,
collective,
size,
in_place,
dtype,
Expand All @@ -215,7 +215,7 @@ def main(

mscclpp_group.barrier()
bench_correctness(
execution_plan_name,
collective,
input_buf,
result_buf,
test_buf,
Expand All @@ -242,7 +242,6 @@ def main(

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--execution_plan_name", type=str, required=True)
parser.add_argument("-path", "--execution_plan_path", type=str, required=True)
parser.add_argument("--size", type=str, required=True)
parser.add_argument("--in_place", action="store_true", help="flag to define an in-place operation")
Expand All @@ -258,7 +257,6 @@ def main(

buffer_size = parse_size(args.size)
main(
args.execution_plan_name,
args.execution_plan_path,
buffer_size,
args.in_place,
Expand Down
Loading
Loading