diff --git a/CMakeLists.txt b/CMakeLists.txt index ccddb366b..865ab81f5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,13 +99,17 @@ find_package(IBVerbs REQUIRED) find_package(NUMA REQUIRED) find_package(Threads REQUIRED) +include(FetchContent) +FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) +FetchContent_MakeAvailable(json) + add_library(mscclpp_obj OBJECT) target_include_directories(mscclpp_obj SYSTEM PRIVATE ${GPU_INCLUDE_DIRS} ${IBVERBS_INCLUDE_DIRS} ${NUMA_INCLUDE_DIRS}) -target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads) +target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} nlohmann_json::nlohmann_json Threads::Threads) set_target_properties(mscclpp_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION}) if(USE_CUDA) target_compile_definitions(mscclpp_obj PRIVATE USE_CUDA) diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 1e0ec46ea..46c9af768 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -38,6 +38,7 @@ class Bootstrap { virtual ~Bootstrap() = default; virtual int getRank() = 0; virtual int getNranks() = 0; + virtual int getNranksPerNode() = 0; virtual void send(void* data, int size, int peer, int tag) = 0; virtual void recv(void* data, int size, int peer, int tag) = 0; virtual void allGather(void* allData, int size) = 0; @@ -83,6 +84,9 @@ class TcpBootstrap : public Bootstrap { /// Return the total number of ranks. int getNranks() override; + /// Return the total number of ranks per node. + int getNranksPerNode() override; + /// Send data to another process. /// /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size, @@ -760,6 +764,10 @@ DeviceHandle> deviceHandle(T&& t) { return t.deviceHandle(); } +/// Packet value type. 
+template <class T>
+using PacketPayload = typename T::Payload;
+
 }  // namespace mscclpp
 
 namespace std {
diff --git a/include/mscclpp/errors.hpp b/include/mscclpp/errors.hpp
index 4e90c8d84..8d3fde4d1 100644
--- a/include/mscclpp/errors.hpp
+++ b/include/mscclpp/errors.hpp
@@ -16,6 +16,7 @@ enum class ErrorCode {
   InvalidUsage,   // The function was used incorrectly.
   Timeout,        // The operation timed out.
   Aborted,        // The operation was aborted.
+  ExecutorError,  // An error occurred in the MSCCL++ executor.
 };
 
 /// Convert an error code to a string.
diff --git a/include/mscclpp/executor.hpp b/include/mscclpp/executor.hpp
new file mode 100644
index 000000000..23dc7cece
--- /dev/null
+++ b/include/mscclpp/executor.hpp
@@ -0,0 +1,71 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#ifndef MSCCLPP_EXECUTOR_HPP_
+#define MSCCLPP_EXECUTOR_HPP_
+
+#include <memory>
+#include <mscclpp/core.hpp>
+#include <string>
+
+namespace mscclpp {
+
+/// Element type of the buffers handed to Executor::execute().
+enum class DataType {
+  INT32,
+  UINT32,
+  FLOAT16,
+  FLOAT32,
+};
+
+/// Packet layout requested for an execution (presumably LL8Packet / LL16Packet).
+enum class PacketType {
+  LL8,
+  LL16,
+};
+
+/// Handle to an execution plan, identified by a name and the path of its plan file.
+class ExecutionPlan {
+ public:
+  /// @param name The plan name.
+  /// @param planPath Path of the file describing the plan.
+  ExecutionPlan(const std::string& name, const std::string& planPath);
+  ~ExecutionPlan() = default;
+
+ private:
+  struct Impl;
+  std::shared_ptr<Impl> impl_;
+
+  friend class Executor;
+};
+
+/// Runs an ExecutionPlan on top of a Communicator. Non-copyable.
+class Executor {
+ public:
+  /// @param comm The communicator used by the executor.
+  Executor(std::shared_ptr<Communicator> comm);
+  Executor(const Executor&) = delete;
+  Executor& operator=(const Executor&) = delete;
+  ~Executor();
+
+  /// Execute @p plan on the given buffers.
+  /// @param rank Rank of the calling process.
+  /// @param sendbuff Send buffer.
+  /// @param recvBuff Receive buffer.
+  /// @param sendBuffSize Size of the send buffer (the Python tests pass a byte count).
+  /// @param recvBuffSize Size of the receive buffer (same unit as @p sendBuffSize).
+  /// @param dataType Element type of the buffers.
+  /// @param nthreads Thread count forwarded to the kernel launch (presumably per thread block).
+  /// @param plan The plan to execute.
+  /// @param stream Stream the work is submitted to.
+  /// @param packetType Packet layout to use; defaults to LL16.
+  void execute(int rank, void* sendbuff, void* recvBuff, size_t sendBuffSize, size_t recvBuffSize, DataType dataType,
+               int nthreads, const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType = PacketType::LL16);
+
+ private:
+  struct Impl;
+  std::unique_ptr<Impl> impl_;
+};
+}  // namespace mscclpp
+
+#endif  // MSCCLPP_EXECUTOR_HPP_
diff --git a/include/mscclpp/packet_device.hpp b/include/mscclpp/packet_device.hpp
index 11f63b53f..8cff4c790 100644
--- a/include/mscclpp/packet_device.hpp
+++ b/include/mscclpp/packet_device.hpp
@@ -24,12 +24,20 @@ union
alignas(16) LL16Packet { uint32_t data2; uint32_t flag2; }; + using Payload = uint2; #if defined(MSCCLPP_DEVICE_COMPILE) ulonglong2 raw_; MSCCLPP_DEVICE_INLINE LL16Packet() {} + MSCCLPP_DEVICE_INLINE LL16Packet(uint2 val, uint32_t flag) { + data1 = val.x; + flag1 = flag; + data2 = val.y; + flag2 = flag; + } + /// Write 8 bytes of data to the packet. /// @param val1 The first 4-byte data to write. /// @param val2 The second 4-byte data to write. @@ -95,10 +103,17 @@ union alignas(8) LL8Packet { uint32_t flag; }; uint64_t raw_; + + using Payload = uint32_t; #if defined(MSCCLPP_DEVICE_COMPILE) MSCCLPP_DEVICE_INLINE LL8Packet() {} + MSCCLPP_DEVICE_INLINE LL8Packet(uint32_t val, uint32_t flag) { + this->data = val; + this->flag = flag; + } + MSCCLPP_DEVICE_INLINE void write(uint32_t val, uint32_t flag) { #if defined(MSCCLPP_DEVICE_CUDA) asm volatile("st.volatile.global.v2.u32 [%0], {%1,%2};" ::"l"(&raw_), "r"(val), "r"(flag)); diff --git a/python/mscclpp/__init__.py b/python/mscclpp/__init__.py index 8f013e080..0acc55fc5 100644 --- a/python/mscclpp/__init__.py +++ b/python/mscclpp/__init__.py @@ -19,6 +19,10 @@ TcpBootstrap, Transport, TransportFlags, + DataType, + Executor, + ExecutionPlan, + PacketType, version, is_nvls_supported, ) diff --git a/python/mscclpp/comm.py b/python/mscclpp/comm.py index b3cc51f12..1cf9ebb41 100644 --- a/python/mscclpp/comm.py +++ b/python/mscclpp/comm.py @@ -51,6 +51,7 @@ def __init__( self.communicator = Communicator(self.bootstrap) self.my_rank = self.bootstrap.get_rank() self.nranks = self.bootstrap.get_n_ranks() + self.nranks_per_node = self.bootstrap.get_n_ranks_per_node() def barrier(self): self.bootstrap.barrier() diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp index df90cc3cf..3b7f48023 100644 --- a/python/mscclpp/core_py.cpp +++ b/python/mscclpp/core_py.cpp @@ -20,6 +20,7 @@ extern void register_fifo(nb::module_& m); extern void register_semaphore(nb::module_& m); extern void register_utils(nb::module_& m); 
extern void register_numa(nb::module_& m); +extern void register_executor(nb::module_& m); template void def_nonblocking_future(nb::handle& m, const std::string& typestr) { @@ -35,6 +36,7 @@ void register_core(nb::module_& m) { nb::class_(m, "Bootstrap") .def("get_rank", &Bootstrap::getRank) .def("get_n_ranks", &Bootstrap::getNranks) + .def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode) .def( "send", [](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) { @@ -204,4 +206,5 @@ NB_MODULE(_mscclpp, m) { register_utils(m); register_core(m); register_numa(m); + register_executor(m); } diff --git a/python/mscclpp/error_py.cpp b/python/mscclpp/error_py.cpp index 18d4b834a..af78ac880 100644 --- a/python/mscclpp/error_py.cpp +++ b/python/mscclpp/error_py.cpp @@ -16,7 +16,8 @@ void register_error(nb::module_& m) { .value("RemoteError", ErrorCode::RemoteError) .value("InvalidUsage", ErrorCode::InvalidUsage) .value("Timeout", ErrorCode::Timeout) - .value("Aborted", ErrorCode::Aborted); + .value("Aborted", ErrorCode::Aborted) + .value("ExecutorError", ErrorCode::ExecutorError); nb::class_(m, "BaseError") .def(nb::init(), nb::arg("message"), nb::arg("errorCode")) diff --git a/python/mscclpp/executor.cpp b/python/mscclpp/executor.cpp new file mode 100644 index 000000000..9f58eac8f --- /dev/null +++ b/python/mscclpp/executor.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/shared_ptr.h>
+#include <nanobind/stl/string.h>
+
+#include <mscclpp/executor.hpp>
+#include <mscclpp/gpu.hpp>
+
+namespace nb = nanobind;
+using namespace mscclpp;
+
+void register_executor(nb::module_& m) {
+  nb::enum_<DataType>(m, "DataType")
+      .value("int32", DataType::INT32)
+      .value("uint32", DataType::UINT32)
+      .value("float16", DataType::FLOAT16)
+      .value("float32", DataType::FLOAT32);
+
+  nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
+
+  nb::class_<ExecutionPlan>(m, "ExecutionPlan")
+      .def(nb::init<const std::string&, const std::string&>(), nb::arg("name"), nb::arg("planPath"));
+
+  nb::class_<Executor>(m, "Executor")
+      .def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
+      .def(
+          "execute",
+          [](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
+             DataType dataType, int nthreads, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
+            self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
+                          recvBuffSize, dataType, nthreads, plan, reinterpret_cast<cudaStream_t>(stream), packetType);
+          },
+          nb::arg("rank"), nb::arg("sendbuff"), nb::arg("recvBuff"), nb::arg("sendBuffSize"), nb::arg("recvBuffSize"),
+          nb::arg("dataType"), nb::arg("nthreads"), nb::arg("plan"), nb::arg("stream"),
+          nb::arg("packetType") = PacketType::LL16);
+}
diff --git a/python/test/executor_test.py b/python/test/executor_test.py
new file mode 100644
index 000000000..b0e4342dd
--- /dev/null
+++ b/python/test/executor_test.py
@@ -0,0 +1,82 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from os import path
+from mscclpp import (
+    DataType,
+    Executor,
+    ExecutionPlan,
+)
+import mscclpp.comm as mscclpp_comm
+
+import cupy as cp
+from mpi4py import MPI
+
+# Repository root, derived from this file's location (<root>/python/test/executor_test.py).
+MSCCLPP_ROOT_PATH = path.dirname(path.dirname(path.dirname(path.abspath(__file__))))
+
+
+def bench_time(niters: int, ngraphIters: int, func):
+    """Return the average time of one `func(stream)` call in microseconds.
+
+    `func` is captured `niters` times into a CUDA graph; the graph is launched once as
+    warm-up and then `ngraphIters` times between two timing events.
+    """
+    # capture cuda graph for niters of the kernel launch
+    stream = cp.cuda.Stream(non_blocking=True)
+    with stream:
+        stream.begin_capture()
+        for _ in range(niters):
+            func(stream)
+        graph = stream.end_capture()
+
+    # now run a warm up round
+    graph.launch(stream)
+
+    # now run the benchmark and measure time
+    start = cp.cuda.Event()
+    end = cp.cuda.Event()
+
+    start.record(stream)
+    for _ in range(ngraphIters):
+        graph.launch(stream)
+    end.record(stream)
+    end.synchronize()
+
+    # get_elapsed_time() reports milliseconds; convert to microseconds per captured call
+    return cp.cuda.get_elapsed_time(start, end) / niters * 1000.0 / ngraphIters
+
+
+if __name__ == "__main__":
+    mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
+    cp.cuda.Device(MPI.COMM_WORLD.rank % mscclpp_group.nranks_per_node).use()
+    executor = Executor(mscclpp_group.communicator)
+    execution_plan = ExecutionPlan(
+        "allreduce_pairs", path.join(MSCCLPP_ROOT_PATH, "test", "execution-files", "allreduce.json")
+    )
+
+    nelems = 1024 * 1024
+    cp.random.seed(42)
+    buffer = cp.random.random(nelems).astype(cp.float16)
+    sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
+    sendbuf = sub_arrays[MPI.COMM_WORLD.rank]
+    mscclpp_group.barrier()
+
+    execution_time = bench_time(
+        100,
+        10,
+        lambda stream: executor.execute(
+            MPI.COMM_WORLD.rank,
+            sendbuf.data.ptr,
+            sendbuf.data.ptr,
+            sendbuf.nbytes,
+            sendbuf.nbytes,
+            DataType.float16,
+            512,
+            execution_plan,
+            stream.ptr,
+        ),
+    )
+    print(f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, data size: {sendbuf.nbytes} bytes")
+    executor = None
+    mscclpp_group = None
diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py
index 4b3cb6ebf..c6014b84e 100644
--- a/python/test/test_mscclpp.py
+++ b/python/test/test_mscclpp.py
@@ -12,7 +12,10 @@
 import pytest
 
 from mscclpp import (
+    DataType,
     EndpointConfig,
+    ExecutionPlan,
+    Executor,
     Fifo,
     Host2DeviceSemaphore,
     Host2HostSemaphore,
@@ -590,3 +593,39 @@ def test_nvls(mpi_group: MpiGroup):
     kernel()
     cp.cuda.runtime.deviceSynchronize()
     group.barrier()
+
+
+@parametrize_mpi_groups(2)
+@pytest.mark.parametrize("filename", ["allreduce.json", "allreduce_packet.json"])
+def test_executor(mpi_group: MpiGroup, filename: str):
+    if all_ranks_on_the_same_node(mpi_group) is False:
+        pytest.skip("algo not support cross node")
+    project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+    mscclpp_group = mscclpp_comm.CommGroup(mpi_group.comm)
+    executor = Executor(mscclpp_group.communicator)
+    execution_plan = ExecutionPlan("allreduce_pairs", os.path.join(project_dir, "test", "execution-files", filename))
+
+    nelems = 1024 * 1024
+    cp.random.seed(42)
+    buffer = cp.random.random(nelems).astype(cp.float16)
+    sub_arrays = cp.split(buffer, mpi_group.comm.size)
+    sendbuf = sub_arrays[mpi_group.comm.rank]
+    expected = cp.zeros_like(sendbuf)
+    for i in range(mpi_group.comm.size):
+        expected += sub_arrays[i]
+    mscclpp_group.barrier()
+
+    stream = cp.cuda.Stream(non_blocking=True)
+    executor.execute(
+        mpi_group.comm.rank,
+        sendbuf.data.ptr,
+        sendbuf.data.ptr,
+        sendbuf.nbytes,
+        sendbuf.nbytes,
+        DataType.float16,
+        512,
+        execution_plan,
+        stream.ptr,
+    )
+    stream.synchronize()
+    assert cp.allclose(sendbuf, expected, atol=1e-3 * mpi_group.comm.size)
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index cfbcc927a..45b4075d2 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc)
+file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc *.cu)
 target_sources(mscclpp_obj PRIVATE ${SOURCES})
 target_include_directories(mscclpp_obj PRIVATE include)
diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc
index c9cea10f4..d6e9a0dfb 100644
--- a/src/bootstrap/bootstrap.cc
+++ b/src/bootstrap/bootstrap.cc
@@ -81,6 +81,7 @@ class TcpBootstrap::Impl {
   UniqueId getUniqueId() const;
   int getRank();
   int getNranks();
+  int getNranksPerNode();
   void allGather(void* allData, int size);
   void send(void* data, int size, int peer, int tag);
   void recv(void* data, int size, int peer, int tag);
@@ -91,6 +92,7 @@ class TcpBootstrap::Impl {
   UniqueIdInternal uniqueId_;
   int rank_;
   int nRanks_;
+  int nRanksPerNode_;
   bool netInitialized;
   std::unique_ptr<Socket> listenSockRoot_;
   std::unique_ptr<Socket> listenSock_;
@@ -141,6 +143,7 @@ UniqueId TcpBootstrap::Impl::getUniqueId(const UniqueIdInternal& uniqueId) {
 TcpBootstrap::Impl::Impl(int rank, int nRanks)
     : rank_(rank),
       nRanks_(nRanks),
+      nRanksPerNode_(0),
       netInitialized(false),
       peerCommAddresses_(nRanks, SocketAddress()),
       barrierArr_(nRanks, 0),
@@ -418,6 +421,32 @@ void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) {
   TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_);
 }
 
+/// Count the ranks whose stored address equals this rank's address; such ranks are taken to share a node.
+/// The result is cached once it is positive.
+int TcpBootstrap::Impl::getNranksPerNode() {
+  if (nRanksPerNode_ > 0) return nRanksPerNode_;
+  const SocketAddress& self = peerCommAddresses_[rank_];
+  const bool useIpv4 = self.sa.sa_family == AF_INET;
+  int nRanksPerNode = 0;
+  for (int i = 0; i < nRanks_; i++) {
+    const SocketAddress& peer = peerCommAddresses_[i];
+    // Only compare addresses of the same family: reading an IPv6 sockaddr through the IPv4 view (or the reverse)
+    // would compare unrelated bytes.
+    if (peer.sa.sa_family != self.sa.sa_family) continue;
+    if (useIpv4) {
+      if (peer.sin.sin_addr.s_addr == self.sin.sin_addr.s_addr) {
+        nRanksPerNode++;
+      }
+    } else {
+      if (std::memcmp(&(peer.sin6.sin6_addr), &(self.sin6.sin6_addr), sizeof(in6_addr)) == 0) {
+        nRanksPerNode++;
+      }
+    }
+  }
+  nRanksPerNode_ = nRanksPerNode;
+  return nRanksPerNode_;
+}
+
 void TcpBootstrap::Impl::allGather(void* allData, int size) {
   char* data = static_cast<char*>(allData);
   int rank = rank_;
@@ -520,6 +549,8 @@ MSCCLPP_API_CPP int TcpBootstrap::getRank() { return pimpl_->getRank(); }
 
 MSCCLPP_API_CPP int TcpBootstrap::getNranks() { return pimpl_->getNranks(); }
 
+MSCCLPP_API_CPP int TcpBootstrap::getNranksPerNode() { return pimpl_->getNranksPerNode(); }
+
 MSCCLPP_API_CPP void TcpBootstrap::send(void* data, int size, int peer, int tag) {
   pimpl_->send(data, size, peer, tag);
 }
diff --git a/src/errors.cc b/src/errors.cc
index 537b3fc27..fbc7a5734 100644
--- a/src/errors.cc
+++ b/src/errors.cc
@@ -19,6 +19,10 @@ std::string errorToString(enum ErrorCode error) {
       return "InvalidUsage";
     case ErrorCode::Timeout:
       return "Timeout";
+    case ErrorCode::Aborted:
+      return "Aborted";
+    case ErrorCode::ExecutorError:
+      return "ExecutorError";
     default:
       return "UnknownError";
   }
diff --git a/src/executor/execution_kernel.cu b/src/executor/execution_kernel.cu
new file mode 100644
index 000000000..4e96af9ab
--- /dev/null
+++ b/src/executor/execution_kernel.cu
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include "execution_kernel.hpp"
+
+#if defined(MSCCLPP_DEVICE_CUDA)
+namespace mscclpp {
+
+template <typename PacketType>
+void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
+                                   size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan,
+                                   size_t sharedMemSize, cudaStream_t stream, uint32_t flag) {
+  switch (dataType) {
+    case DataType::INT32:
+      executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
+          rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag);
+      break;
+    case DataType::UINT32:
+      executionKernel<uint32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
+          rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag);
+      break;
+    case DataType::FLOAT16:
+      executionKernel<half, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
+          rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag);
+      break;
+    case DataType::FLOAT32:
+      executionKernel<float, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
+          rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag);
+      break;
+  }
+}
+
+template void ExecutionKernel::launchKernel<LL16Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
+                                                        void* scratch, size_t scratchSize, DataType dataType,
+                                                        DeviceExecutionPlan* plan, size_t sharedMemSize,
+                                                        cudaStream_t stream, uint32_t flag);
+template void ExecutionKernel::launchKernel<LL8Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
+                                                       void* scratch, size_t scratchSize, DataType dataType,
+                                                       DeviceExecutionPlan* plan, size_t sharedMemSize,
+                                                       cudaStream_t stream, uint32_t flag);
+}  // namespace mscclpp
+#endif
diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc
new file mode 100644
index 000000000..da7e135a7
--- /dev/null
+++ b/src/executor/execution_plan.cc
@@ -0,0 +1,318 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include "execution_plan.hpp"
+
+#include <fstream>
+#include <nlohmann/json.hpp>
+
+namespace {
+/// Return the elements of `vec` accepted by `pred`, in their original order.
+template <typename T, typename Predicate>
+std::vector<T> filter(const std::vector<T>& vec, Predicate pred) {
+  std::vector<T> filtered;
+  std::copy_if(vec.begin(), vec.end(), std::back_inserter(filtered), pred);
+  return filtered;
+}
+
+/// Translate an operation name from the plan file into an OperationType. Throws on an unknown name.
+auto getOpType = [](const std::string& str) {
+  if (str == "nop") {
+    return mscclpp::OperationType::BARRIER;
+  } else if (str == "put") {
+    return mscclpp::OperationType::PUT;
+  } else if (str == "get") {
+    return mscclpp::OperationType::GET;
+  } else if (str == "copy") {
+    return mscclpp::OperationType::COPY;
+  } else if (str == "signal") {
+    return mscclpp::OperationType::SIGNAL;
+  } else if (str == "wait") {
+    return mscclpp::OperationType::WAIT;
+  } else if (str == "flush") {
+    return mscclpp::OperationType::FLUSH;
+  } else if (str == "re") {
+    return mscclpp::OperationType::REDUCE;
+  } else if (str == "rs") {
+    return mscclpp::OperationType::REDUCE_SEND;
+  } else if (str == "rrc") {
+    return mscclpp::OperationType::READ_REDUCE_COPY;
+  } else if (str == "rrcs") {
+    return mscclpp::OperationType::READ_REDUCE_COPY_SEND;
+  } else if (str == "ppkt") {
+    return mscclpp::OperationType::PUT_PACKET;
+  } else if (str == "rspkt") {
+    return mscclpp::OperationType::REDUCE_SEND_PACKET;
+  } else if (str == "cpkt") {
+    return mscclpp::OperationType::COPY_PACKET;
+  } else {
+    throw mscclpp::Error("Invalid operation type", mscclpp::ErrorCode::ExecutorError);
+  }
+};
+
+/// Translate a buffer tag from the plan file ("i", "o" or "s") into a BufferType. Throws on an unknown tag.
+auto convertToBufferType = [](const std::string& str) {
+  if (str == "i") {
+    return mscclpp::BufferType::INPUT;
+  } else if (str == "o") {
+    return mscclpp::BufferType::OUTPUT;
+  } else if (str == "s") {
+    return mscclpp::BufferType::SCRATCH;
+  } else {
+    throw mscclpp::Error("Invalid buffer type", mscclpp::ErrorCode::ExecutorError);
+  }
+};
+
+/// Translate a channel type name from the plan file into a ChannelType. Throws on an unknown name.
+auto convertToChannelType = [](const std::string& str) {
+  if (str == "sm") {
+    return mscclpp::ChannelType::SM;
+  } else if (str == "proxy") {
+    return mscclpp::ChannelType::PROXY;
+  } else if (str == "none") {
+    return mscclpp::ChannelType::NONE;
+  } else {
+    throw mscclpp::Error("Invalid channel type", mscclpp::ErrorCode::ExecutorError);
+  }
+};
+
+}  // namespace
+
+namespace mscclpp {
+using json = nlohmann::json;
+
+ExecutionPlan::Impl::Impl(const std::string name, const std::string planPath)
+    : name(name), planPath(planPath), isUsingPacket(false) {}
+
+/// Return the channels of `rank` whose type is `channelType`.
+std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfos(int rank, ChannelType channelType) const {
+  auto pred = [channelType](const ChannelInfo& info) { return info.channelType == channelType; };
+  return filter(this->channelInfos.at(rank), pred);
+}
+
+/// Return the channels of `rank` whose destination buffer type is `dstBufferType`.
+std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfos(int rank, BufferType dstBufferType) const {
+  auto pred = [dstBufferType](const ChannelInfo& info) { return info.dstBufferType == dstBufferType; };
+  return filter(this->channelInfos.at(rank), pred);
+}
+
+/// Return the distinct peers that `rank` is connected to, in ascending order.
+std::vector<int> ExecutionPlan::Impl::getConnectedPeers(int rank) const {
+  std::set<int> peers;
+  for (const auto& info : this->channelInfos.at(rank)) {
+    for (int peer : info.connectedPeers) {
+      peers.insert(peer);
+    }
+  }
+  return std::vector<int>(peers.begin(), peers.end());
+}
+
+/// Return the distinct destination buffer types used by the channels of `rank`.
+std::vector<BufferType> ExecutionPlan::Impl::getConnectedBufferTypes(int rank) const {
+  std::set<BufferType> bufferTypes;
+  for (const auto& info : this->channelInfos.at(rank)) {
+    bufferTypes.insert(info.dstBufferType);
+  }
+  return std::vector<BufferType>(bufferTypes.begin(), bufferTypes.end());
+}
+
+/// Return the scratch buffer size of `rank` for an input of `inputSize`: the size of one input chunk times the
+/// number of scratch chunks, times four for packet plans (data + flag, double buffered).
+size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize) const {
+  if (this->isUsingPacket) {
+    return inputSize / this->inputChunks.at(rank) * this->scratchChunks.at(rank) * 2 /* data + flag*/ *
+           2 /*double buffer*/;
+  }
+  return inputSize / this->inputChunks.at(rank) * this->scratchChunks.at(rank);
+}
+
+/// Return a copy of the operations of thread block `threadblock` of `rank`.
+std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadblock) const {
+  return this->operations.at(rank)[threadblock];
+}
+
+int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->operations.at(rank).size(); }
+
+/// Parse the plan file and (re)build the per-rank channel and operation tables.
+/// @param rank The calling rank; its input chunk count determines the chunk size.
+/// @param inputSize Size of the input buffer; plan offsets and counts are scaled by inputSize / inputChunks.
+/// @throws Error (ExecutorError) if the file cannot be opened, the plan name does not match, or the plan has no
+///         input chunks for `rank`.
+void ExecutionPlan::Impl::loadExecutionPlan(int rank, size_t inputSize) {
+  std::ifstream file(this->planPath);
+  if (!file.is_open()) {
+    throw Error("Failed to open execution plan file: " + this->planPath, ErrorCode::ExecutorError);
+  }
+  json obj = json::parse(file);
+  if (this->name != obj["name"]) {
+    throw Error("Plan name does not match", ErrorCode::ExecutorError);
+  }
+
+  // The executor loads the plan again for every new execution context, and the setup helpers below only append.
+  // Drop the state of any previous load first so thread blocks and channels are not duplicated.
+  this->isUsingPacket = false;
+  this->inputChunks.clear();
+  this->outputChunks.clear();
+  this->scratchChunks.clear();
+  this->channelInfos.clear();
+  this->threadblockSMChannelMap.clear();
+  this->threadblockProxyChannelMap.clear();
+  this->operations.clear();
+
+  std::string protocol = obj["protocol"];
+  if (protocol == "LL") {
+    this->isUsingPacket = true;
+  }
+  const auto& gpus = obj["gpus"];
+
+  for (const auto& gpu : gpus) {
+    int gpuRank = gpu["id"];
+    this->inputChunks[gpuRank] = gpu["inputChunks"];
+    this->outputChunks[gpuRank] = gpu["outputChunks"];
+    this->scratchChunks[gpuRank] = gpu["scratchChunks"];
+  }
+  this->setupChannels(gpus);
+
+  // Look the rank up instead of using operator[], which would insert a zero entry and then divide by it.
+  const auto chunks = this->inputChunks.find(rank);
+  if (chunks == this->inputChunks.end() || chunks->second == 0) {
+    throw Error("Execution plan has no input chunks for rank " + std::to_string(rank), ErrorCode::ExecutorError);
+  }
+  this->chunkSize = inputSize / chunks->second;
+  this->setupOperations(gpus);
+}
+
+// Construct the channel info. Step 1. Flatten SM and PROXY channels into separate vectors.
+// Step 2. For each threadblock, construct a vector of channel indexes and keys.
+void ExecutionPlan::Impl::setupChannels(const json& gpus) {
+  for (const auto& gpu : gpus) {
+    int rank = gpu["id"];
+    std::vector<ChannelInfo> channelInfos;
+    for (const auto& channel : gpu["channels"]) {
+      ChannelInfo info;
+      info.srcBufferType = convertToBufferType(channel["srcbuff"]);
+      info.dstBufferType = convertToBufferType(channel["dstbuff"]);
+      info.channelType = convertToChannelType(channel["type"]);
+      for (const auto& peer : channel["connectedTo"]) {
+        info.connectedPeers.push_back(peer);
+      }
+      channelInfos.push_back(info);
+    }
+    this->channelInfos[rank] = channelInfos;
+  }
+
+  // setup threadblockChannelMap
+  for (const auto& gpu : gpus) {
+    int rank = gpu["id"];
+    auto channelTypes = {ChannelType::SM, ChannelType::PROXY};
+    std::unordered_map<ChannelKey, std::vector<int>> channelMap;
+    for (auto channelType : channelTypes) {
+      const std::vector<ChannelInfo> channelInfos = this->getChannelInfos(rank, channelType);
+      int index = 0;
+      for (const auto& info : channelInfos) {
+        ChannelKey key = {info.srcBufferType, info.dstBufferType, info.channelType};
+        for (size_t i = 0; i < info.connectedPeers.size(); i++) {
+          channelMap[key].push_back(index++);
+        }
+      }
+    }
+    int nthreadblocks = gpu["threadblocks"].size();
+    this->threadblockSMChannelMap[rank].resize(nthreadblocks);
+    this->threadblockProxyChannelMap[rank].resize(nthreadblocks);
+    for (const auto& threadblock : gpu["threadblocks"]) {
+      for (const auto& channel : threadblock["channels"]) {
+        ChannelType channelType = convertToChannelType(channel["ctype"]);
+        ChannelKey key = {convertToBufferType(channel["src"]), convertToBufferType(channel["dst"]), channelType};
+        for (int id : channel["cids"]) {
+          if (channelType == ChannelType::SM) {
+            this->threadblockSMChannelMap[rank][threadblock["id"]].emplace_back(channelMap[key][id], key);
+          } else if (channelType == ChannelType::PROXY) {
+            this->threadblockProxyChannelMap[rank][threadblock["id"]].emplace_back(channelMap[key][id], key);
+          }
+        }
+      }
+    }
+  }
+}
+
+/// Build the operation list of every thread block of every GPU. Offsets and counts are given in chunks in the plan
+/// and are scaled by `chunkSize` here.
+void ExecutionPlan::Impl::setupOperations(const json& gpus) {
+  // setup threadblocks and operations
+  for (const auto& gpu : gpus) {
+    int rank = gpu["id"];
+    for (const auto& threadblock : gpu["threadblocks"]) {
+      std::unordered_map<ChannelKey, std::vector<int>> channelIndexes;
+      std::vector<Operation> ops;
+      int threadblockId = threadblock["id"];
+      const auto& smChannels = this->threadblockSMChannelMap[rank][threadblockId];
+      const auto& proxyChannels = this->threadblockProxyChannelMap[rank][threadblockId];
+      for (size_t i = 0; i < smChannels.size(); i++) {
+        const auto& [_, key] = smChannels[i];
+        channelIndexes[key].push_back(i);
+      }
+      for (size_t i = 0; i < proxyChannels.size(); i++) {
+        const auto& [_, key] = proxyChannels[i];
+        channelIndexes[key].push_back(i);
+      }
+      for (const auto& op : threadblock["ops"]) {
+        Operation operation = {};
+        operation.type = static_cast<mscclpp::OperationType>(getOpType(op["name"]));
+        if (op.contains("ctype")) {
+          operation.channelType = convertToChannelType(op["ctype"]);
+        }
+        if (op.contains("i_cids")) {
+          operation.nInputs = op["i_cids"].size();
+          for (int i = 0; i < operation.nInputs; i++) {
+            BufferType srcBufferType = convertToBufferType(op["i_buff"]["src"]);
+            BufferType dstBufferType = convertToBufferType(op["i_buff"]["dst"]);
+            // Get the relevant channel index in rank channelInfos
+            operation.inputChannelIndexes[i] =
+                channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["i_cids"][i]["id"]];
+            operation.inputOffsets[i] = this->chunkSize * (int)op["i_cids"][i]["off"];
+          }
+        }
+        // will have either srcs or i_cids
+        if (op.contains("srcs")) {
+          operation.nInputs = op["srcs"].size();
+          operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]);
+          for (int i = 0; i < operation.nInputs; i++) {
+            operation.inputOffsets[i] = this->chunkSize * (int)op["srcs"][i]["off"];
+          }
+        }
+        if (op.contains("o_cids")) {
+          operation.nOutputs = op["o_cids"].size();
+          for (int i = 0; i < operation.nOutputs; i++) {
+            BufferType srcBufferType = convertToBufferType(op["o_buff"]["src"]);
+            BufferType dstBufferType = convertToBufferType(op["o_buff"]["dst"]);
+            operation.outputChannelIndexes[i] =
+                channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["o_cids"][i]["id"]];
+            operation.outputOffsets[i] = this->chunkSize * (int)op["o_cids"][i]["off"];
+          }
+        }
+        if (op.contains("srcbuff")) {
+          operation.srcBufferType = convertToBufferType(op["srcbuff"]);
+        }
+        if (op.contains("srcoff")) {
+          operation.srcOffset = (int)op["srcoff"] * this->chunkSize;
+        }
+        if (op.contains("dstbuff")) {
+          operation.dstBufferType = convertToBufferType(op["dstbuff"]);
+        }
+        if (op.contains("dstoff")) {
+          operation.dstOffset = (int)op["dstoff"] * this->chunkSize;
+        }
+        if (op.contains("cnt")) {
+          operation.size = this->chunkSize * (int)op["cnt"];
+        }
+        ops.push_back(operation);
+      }
+      this->operations[rank].push_back(ops);
+    }
+  }
+}
+
+ExecutionPlan::ExecutionPlan(const std::string& name, const std::string& planPath)
+    : impl_(std::make_shared<Impl>(name, planPath)) {}
+
+}  // namespace mscclpp
diff --git a/src/executor/executor.cc b/src/executor/executor.cc
new file mode 100644
index 
000000000..2f4fdd264
--- /dev/null
+++ b/src/executor/executor.cc
@@ -0,0 +1,316 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <mscclpp/executor.hpp>
+#include <mscclpp/proxy_channel.hpp>
+#include <mscclpp/sm_channel.hpp>
+#include <set>
+
+#include "execution_kernel.hpp"
+#include "execution_plan.hpp"
+
+namespace {
+/// Mix the hash of `value` into `seed` (boost::hash_combine recipe) so that equal fields do not cancel each other.
+template <typename T>
+void hashCombine(std::size_t& seed, const T& value) {
+  seed ^= std::hash<T>()(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}
+}  // namespace
+
+namespace mscclpp {
+/// Cache key of an execution context: the buffers, their sizes and the plan name.
+struct ExecutionContextKey {
+  void* sendBuff;
+  void* recvBuff;
+  size_t sendBuffSize;
+  size_t recvBuffSize;
+  std::string plan;
+
+  bool operator==(const ExecutionContextKey& other) const {
+    return sendBuff == other.sendBuff && recvBuff == other.recvBuff && sendBuffSize == other.sendBuffSize &&
+           recvBuffSize == other.recvBuffSize && plan == other.plan;
+  }
+};
+}  // namespace mscclpp
+
+namespace std {
+template <>
+struct hash<std::pair<mscclpp::BufferType, int>> {
+  std::size_t operator()(const std::pair<mscclpp::BufferType, int>& key) const {
+    std::size_t seed = 0;
+    hashCombine(seed, static_cast<int>(key.first));
+    hashCombine(seed, key.second);
+    return seed;
+  }
+};
+
+template <>
+struct hash<mscclpp::ExecutionContextKey> {
+  std::size_t operator()(const mscclpp::ExecutionContextKey& key) const {
+    std::size_t seed = 0;
+    hashCombine(seed, key.sendBuff);
+    hashCombine(seed, key.recvBuff);
+    hashCombine(seed, key.sendBuffSize);
+    hashCombine(seed, key.recvBuffSize);
+    hashCombine(seed, key.plan);
+    return seed;
+  }
+};
+}  // namespace std
+
+namespace {
+/// True when both ranks fall into the same block of `nranksPerNode` consecutive ranks.
+auto inSameNode = [](int rank1, int rank2, int nranksPerNode) {
+  return rank1 / nranksPerNode == rank2 / nranksPerNode;
+};
+
+/// The eight IB transports; callers index this with `rank % nranksPerNode`, which must stay below 8.
+static const mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2,
+                                         mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
+                                         mscclpp::Transport::IB6, mscclpp::Transport::IB7};
+}  // namespace
+
+namespace mscclpp {
+
+/// Connections, registered memories, channels and device-side plan data built for one ExecutionContextKey.
+struct ExecutionContext {
+  std::unordered_map<int, std::shared_ptr<Connection>> connections;
+  std::unordered_map<std::pair<BufferType, int>, mscclpp::RegisteredMemory> registeredMemories;
+  std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
+  std::vector<mscclpp::SemaphoreId> proxySemaphores;
+  std::vector<mscclpp::SmChannel> smChannels;
+  std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
+  std::vector<DeviceExecutionPlan> deviceExecutionPlans;
+  std::shared_ptr<char> scratchBuffer;
+  size_t scratchBufferSize;
+  std::shared_ptr<char> deviceExecutionPlansBuffer;
+};
+
+struct 
Executor::Impl { + int nranksPerNode; + std::shared_ptr comm; + std::shared_ptr proxyService; + std::unordered_map contexts; + + Impl(std::shared_ptr comm) : comm(comm) { + this->nranksPerNode = comm->bootstrap()->getNranksPerNode(); + this->proxyService = std::make_shared(); + } + ~Impl() = default; + + ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t sendBufferSize, + size_t recvBufferSize, const ExecutionPlan& plan) { + ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name}; + if (this->contexts.find(key) != this->contexts.end()) { + return this->contexts[key]; + } + plan.impl_->loadExecutionPlan(rank, sendBufferSize); + + ExecutionContext context; + size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize); + std::shared_ptr scratchBuffer = allocExtSharedCuda(scratchBufferSize); + context.scratchBuffer = scratchBuffer; + context.scratchBufferSize = scratchBufferSize; + this->setupConnections(context, rank, plan); + this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); + this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, rank, plan); + this->setupDeviceExecutionPlan(context, rank, plan); + context.deviceExecutionPlansBuffer = + allocExtSharedCuda(context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan)); + memcpyCuda(context.deviceExecutionPlansBuffer.get(), (char*)context.deviceExecutionPlans.data(), + context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice); + this->contexts.insert({key, context}); + return context; + } + + void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan) { + std::vector connectedPeers = plan.impl_->getConnectedPeers(rank); + std::vector>> connectionFutures; + for (int peer : connectedPeers) { + Transport transport = + inSameNode(rank, peer, this->nranksPerNode) ? 
Transport::CudaIpc : IBs[rank % this->nranksPerNode]; + connectionFutures.push_back(this->comm->connectOnSetup(peer, 0, transport)); + } + this->comm->setup(); + for (size_t i = 0; i < connectionFutures.size(); i++) { + context.connections[connectedPeers[i]] = connectionFutures[i].get(); + } + } + + void setupRegisteredMemories(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize, + size_t recvBufferSize, int rank, const ExecutionPlan& plan) { + auto getTransportFlags = [&](std::vector& infos, int rank) { + TransportFlags flags; + for (ChannelInfo& info : infos) { + if (info.channelType == ChannelType::SM) { + flags |= Transport::CudaIpc; + } else if (info.channelType == ChannelType::PROXY) { + for (int peer : info.connectedPeers) { + if (!inSameNode(rank, peer, this->nranksPerNode)) { + flags |= IBs[rank % this->nranksPerNode]; + } + } + } + } + return flags; + }; + auto getBufferInfo = [&](BufferType type) { + switch (type) { + case BufferType::INPUT: + return std::make_pair(sendbuff, sendBufferSize); + case BufferType::OUTPUT: + return std::make_pair(recvbuff, recvBufferSize); + case BufferType::SCRATCH: + return std::make_pair((void*)context.scratchBuffer.get(), context.scratchBufferSize); + default: + throw Error("Invalid buffer type", ErrorCode::ExecutorError); + } + }; + auto getConnectedPeers = [&](std::vector& infos) { + std::set peers; + for (ChannelInfo& info : infos) { + for (int peer : info.connectedPeers) { + peers.insert(peer); + } + } + return std::vector(peers.begin(), peers.end()); + }; + + std::vector bufferTypes = plan.impl_->getConnectedBufferTypes(rank); + for (BufferType bufferType : bufferTypes) { + std::vector channelInfos = plan.impl_->getChannelInfos(rank, bufferType); + TransportFlags transportFlags = getTransportFlags(channelInfos, rank); + RegisteredMemory memory = + this->comm->registerMemory(getBufferInfo(bufferType).first, getBufferInfo(bufferType).second, transportFlags); + std::vector connectedPeers = 
getConnectedPeers(channelInfos); + std::vector> remoteRegMemoryFutures; + for (int peer : connectedPeers) { + comm->sendMemoryOnSetup(memory, peer, 0); + remoteRegMemoryFutures.push_back(comm->recvMemoryOnSetup(peer, 0)); + } + comm->setup(); + for (size_t i = 0; i < remoteRegMemoryFutures.size(); i++) { + context.registeredMemories[{bufferType, connectedPeers[i]}] = std::move(remoteRegMemoryFutures[i].get()); + } + } + } + + void setupChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize, int rank, + const ExecutionPlan& plan) { + const auto channelTypes = {ChannelType::SM, ChannelType::PROXY}; + std::vector> smSemaphores; + std::vector proxySemaphores; + for (ChannelType channelType : channelTypes) { + std::vector channelInfos = plan.impl_->getChannelInfos(rank, channelType); + for (ChannelInfo& info : channelInfos) { + for (int peer : info.connectedPeers) { + if (channelType == ChannelType::SM) { + smSemaphores.push_back( + std::make_shared(*this->comm, context.connections.at(peer))); + } else if (channelType == ChannelType::PROXY) { + proxySemaphores.push_back( + this->proxyService->buildAndAddSemaphore(*this->comm, context.connections.at(peer))); + } + } + } + } + this->comm->setup(); + context.smSemaphores = std::move(smSemaphores); + context.proxySemaphores = std::move(proxySemaphores); + + auto getBuffer = [&](BufferType type) { + switch (type) { + case BufferType::INPUT: + return sendbuff; + case BufferType::OUTPUT: + return recvbuff; + case BufferType::SCRATCH: + return (void*)context.scratchBuffer.get(); + default: + throw Error("Invalid buffer type", ErrorCode::ExecutorError); + } + }; + for (ChannelType channelType : channelTypes) { + std::vector channelInfos = plan.impl_->getChannelInfos(rank, channelType); + int index = 0; + for (ChannelInfo& info : channelInfos) { + void* src = getBuffer(info.srcBufferType); + TransportFlags transport = context.registeredMemories.begin()->second.transports(); + RegisteredMemory 
localMemory = this->comm->registerMemory(src, sendBufferSize, transport); + for (int peer : info.connectedPeers) { + if (channelType == ChannelType::SM) { + context.smChannels.emplace_back(context.smSemaphores[index++], + context.registeredMemories[{info.dstBufferType, peer}], src, nullptr); + } else if (channelType == ChannelType::PROXY) { + context.proxyChannels.emplace_back( + this->proxyService->proxyChannel(context.proxySemaphores[index++]), + this->proxyService->addMemory(context.registeredMemories[{info.dstBufferType, peer}]), + this->proxyService->addMemory(localMemory)); + } + } + } + } + } + + void setupDeviceExecutionPlan(ExecutionContext& context, int rank, const ExecutionPlan& plan) { + std::vector deviceExecutionPlans; + for (int threadblock = 0; threadblock < plan.impl_->getThreadblockCount(rank); threadblock++) { + DeviceExecutionPlan deviceExecutionPlan = {}; + std::vector ops = plan.impl_->getOperations(rank, threadblock); + deviceExecutionPlan.nOperations = ops.size(); + deviceExecutionPlan.nSmChannels = plan.impl_->threadblockSMChannelMap.at(rank).at(threadblock).size(); + deviceExecutionPlan.nProxyChannels = plan.impl_->threadblockProxyChannelMap.at(rank).at(threadblock).size(); + int chanIndex = 0; + for (const auto& [index, _] : plan.impl_->threadblockSMChannelMap.at(rank).at(threadblock)) { + deviceExecutionPlan.channels.smChannels[chanIndex++] = mscclpp::deviceHandle(context.smChannels[index]); + } + chanIndex = 0; + for (const auto& [index, _] : plan.impl_->threadblockProxyChannelMap.at(rank).at(threadblock)) { + deviceExecutionPlan.channels.proxyChannels[chanIndex++] = mscclpp::deviceHandle(context.proxyChannels[index]); + } + for (size_t i = 0; i < ops.size(); i++) { + deviceExecutionPlan.operations[i] = ops[i]; + } + deviceExecutionPlans.push_back(deviceExecutionPlan); + } + context.deviceExecutionPlans = std::move(deviceExecutionPlans); + } + + void launchKernel(ExecutionContext& context, int rank, int nthreadsPerBlock, void* sendbuff, 
void* recvbuff, + DataType dataType, cudaStream_t stream, PacketType packetType) { + static uint32_t flag = 0; + int nthreadblocks = context.deviceExecutionPlans.size(); + size_t sharedMemSize = sizeof(DeviceExecutionPlan); + switch (packetType) { + case PacketType::LL16: + ExecutionKernel::launchKernel( + rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), + context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), + sharedMemSize, stream, ++flag); + break; + case PacketType::LL8: + ExecutionKernel::launchKernel( + rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), + context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), + sharedMemSize, stream, ++flag); + break; + default: + throw Error("Invalid packet type", ErrorCode::ExecutorError); + } + } +}; + +Executor::Executor(std::shared_ptr comm) : impl_(std::make_unique(comm)) {} + +void Executor::execute(int rank, void* sendbuff, void* recvBuff, size_t sendBuffSize, size_t recvBuffSize, + DataType dataType, int nthreads, const ExecutionPlan& plan, cudaStream_t stream, + PacketType packetType) { + ExecutionContext context = + this->impl_->setupExecutionContext(rank, sendbuff, recvBuff, sendBuffSize, recvBuffSize, plan); + // TODO(binyli): need to flush proxy channel here this->impl_->proxyService->startProxy(); + this->impl_->launchKernel(context, rank, nthreads, sendbuff, recvBuff, dataType, stream, packetType); +} + +Executor::~Executor() = default; + +} // namespace mscclpp diff --git a/src/include/execution_common.hpp b/src/include/execution_common.hpp new file mode 100644 index 000000000..38d624109 --- /dev/null +++ b/src/include/execution_common.hpp @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+
+#ifndef MSCCLPP_EXECUTION_COMMON_HPP_
+#define MSCCLPP_EXECUTION_COMMON_HPP_
+
+// NOTE(review): the targets of the two #include directives below are missing in this copy of the patch
+// (the text between angle brackets appears to have been lost during extraction). The same applies to the
+// template arguments of mscclpp::DeviceHandle in `Channels`. Restore them from the upstream file.
+#include
+#include
+
+namespace mscclpp {
+
+// Capacities of the fixed-size arrays declared in the structs below.
+constexpr int MAX_CHANNEL = 16;               // entries in each array of `Channels`
+constexpr int MAX_CHANNEL_PER_OPERATION = 8;  // channel indexes / offsets carried by one `Operation`
+constexpr int MAX_OPERATION = 64;             // operations stored in one `DeviceExecutionPlan`
+
+// Identifies which of the three buffers (input, output, scratch) an operation or channel refers to.
+// Stored in one byte.
+enum class BufferType : uint8_t {
+  INPUT,
+  OUTPUT,
+  SCRATCH,
+};
+
+// Kind of channel an operation uses. Stored in one byte.
+enum class ChannelType : uint8_t {
+  NONE,
+  SM,
+  PROXY,
+};
+
+// Operation codes that can appear in `Operation::type`. Stored in one byte.
+enum class OperationType : uint8_t {
+  BARRIER,
+  PUT,
+  PUT_PACKET,
+  GET,
+  COPY,
+  COPY_PACKET,
+  SIGNAL,
+  WAIT,
+  FLUSH,
+  REDUCE,
+  REDUCE_PACKET,
+  REDUCE_SEND,
+  REDUCE_SEND_PACKET,
+  READ_REDUCE_COPY,
+  READ_REDUCE_COPY_SEND,
+};
+
+// Device handles of the channels available to one plan: MAX_CHANNEL entries per array.
+struct Channels {
+  mscclpp::DeviceHandle smChannels[MAX_CHANNEL];
+  mscclpp::DeviceHandle proxyChannels[MAX_CHANNEL];
+};
+
+// One step of a plan. All fields are fixed-size so the struct can be copied as raw bytes.
+struct Operation {
+  OperationType type;
+  ChannelType channelType;
+  BufferType srcBufferType;
+  BufferType dstBufferType;
+  uint8_t nInputs;   // presumably the number of valid entries in the input arrays -- TODO confirm
+  uint8_t nOutputs;  // presumably the number of valid entries in the output arrays -- TODO confirm
+  // The two members of this anonymous union share storage: an operation carries either a list of
+  // input channel indexes or a single input buffer type, not both.
+  union {
+    uint8_t inputChannelIndexes[MAX_CHANNEL_PER_OPERATION];
+    BufferType inputBufferType;
+  };
+  uint8_t outputChannelIndexes[MAX_CHANNEL_PER_OPERATION];
+  uint32_t inputOffsets[MAX_CHANNEL_PER_OPERATION];
+  uint32_t outputOffsets[MAX_CHANNEL_PER_OPERATION];
+  uint32_t srcOffset;
+  uint32_t dstOffset;
+  uint32_t size;
+};
+
+// Size budget, using the author's per-member figures from the comments below:
+//   4 (three counters) + 4 (padding) + 1920 (channels) + 6400 (operations) = 8328 bytes,
+//   which aligned(16) rounds up to 8336 bytes, i.e. 8 bytes of tail padding.
+// The earlier form of this comment listed "12(align)", but those terms sum to 8340, not 8336.
+// NOTE(review): the 1920 and 4(padding) figures depend on the DeviceHandle types, which are not
+// visible here -- confirm the total with a static_assert on sizeof(DeviceExecutionPlan).
+struct __attribute__((aligned(16))) DeviceExecutionPlan {
+  uint8_t nSmChannels;                  // 1 bytes
+  uint8_t nProxyChannels;               // 1 bytes
+  uint16_t nOperations;                 // 2 bytes
+  Channels channels;                    // 1920 bytes
+  Operation operations[MAX_OPERATION];  // 64 * 100 = 6400 bytes
+};
+
+}  // namespace mscclpp
+
+#endif  // MSCCLPP_EXECUTION_COMMON_HPP_
diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp
new file mode 100644
index 000000000..08e8796a5
--- /dev/null
+++ b/src/include/execution_kernel.hpp
@@ -0,0 +1,359 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#ifndef MSCCLPP_EXECUTION_KERNEL_HPP_
+#define MSCCLPP_EXECUTION_KERNEL_HPP_
+
+// NOTE(review): in this copy of the patch the text between angle brackets appears to have been lost:
+// the #include targets, the parameter lists after `template`, explicit template arguments on calls,
+// and the <<<...>>> kernel launch configurations. Restore them from the upstream file.
+#include
+#include
+#include
+#include
+
+#include "execution_common.hpp"
+
+#if defined(MSCCLPP_DEVICE_COMPILE)
+#if defined(MSCCLPP_DEVICE_HIP)
+// Emits the instructions `s_waitcnt lgkmcnt(0)` and `s_barrier` as inline assembly. executionKernel
+// uses it on HIP where the non-HIP build calls __syncthreads(). The macro body already ends with ';'.
+#define __synclds() asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier");
+#endif  // defined(MSCCLPP_DEVICE_HIP)
+
+namespace {
+// Reinterprets the bytes of `src` as a value of type To by writing one union member and reading the
+// other. The two types must have the same size (checked at compile time).
+template
+MSCCLPP_DEVICE_INLINE To bit_cast(const From& src) {
+  static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
+
+  union {
+    From f;
+    To t;
+  } u;
+  u.f = src;
+  return u.t;
+}
+
+// Generic element addition: returns a + b.
+template
+MSCCLPP_DEVICE_INLINE T add_elements(T a, T b) {
+  return a + b;
+}
+
+// __half2 specialization: adds through the __hadd2 intrinsic instead of operator+.
+template <>
+MSCCLPP_DEVICE_INLINE __half2 add_elements(__half2 a, __half2 b) {
+  return __hadd2(a, b);
+}
+
+// Adds two int4 values component by component (w, x, y, z). Each component is bit-cast, added with
+// add_elements, and bit-cast back.
+template
+MSCCLPP_DEVICE_INLINE int4 add_vectors_helper(int4 a, int4 b) {
+  int4 ret;
+  ret.w = bit_cast(add_elements(bit_cast(a.w), bit_cast(b.w)));
+  ret.x = bit_cast(add_elements(bit_cast(a.x), bit_cast(b.x)));
+  ret.y = bit_cast(add_elements(bit_cast(a.y), bit_cast(b.y)));
+  ret.z = bit_cast(add_elements(bit_cast(a.z), bit_cast(b.z)));
+  return ret;
+}
+
+// int4 entry point: forwards to the helper.
+template
+MSCCLPP_DEVICE_INLINE int4 add_vectors(int4 a, int4 b) {
+  return add_vectors_helper(a, b);
+}
+
+// __half entry point: uses the __half2 helper instantiation.
+template <>
+MSCCLPP_DEVICE_INLINE int4 add_vectors<__half>(int4 a, int4 b) {
+  return add_vectors_helper<__half2>(a, b);
+}
+
+// Same scheme as the int4 overloads, for uint2 values (components x and y).
+template
+MSCCLPP_DEVICE_INLINE uint2 add_vectors_helper(uint2 a, uint2 b) {
+  uint2 ret;
+  ret.x = bit_cast(add_elements(bit_cast(a.x), bit_cast(b.x)));
+  ret.y = bit_cast(add_elements(bit_cast(a.y), bit_cast(b.y)));
+  return ret;
+}
+
+template
+MSCCLPP_DEVICE_INLINE uint2 add_vectors(uint2 a, uint2 b) {
+  return add_vectors_helper(a, b);
+}
+
+// Marked unused so a translation unit that never calls it does not trigger a warning.
+template <>
+MSCCLPP_DEVICE_INLINE __attribute__((unused)) uint2 add_vectors<__half>(uint2 a, uint2 b) {
+  return add_vectors_helper<__half2>(a, b);
+}
+
+// Same scheme for a single int.
+template
+MSCCLPP_DEVICE_INLINE int add_vectors_helper(int a, int b) {
+  return bit_cast(add_elements(bit_cast(a), bit_cast(b)));
+}
+
+template
+MSCCLPP_DEVICE_INLINE int add_vectors(int a, int b) {
+  return add_vectors_helper(a, b);
+}
+
+template <>
+MSCCLPP_DEVICE_INLINE __attribute__((unused)) int add_vectors<__half>(int a, int b) {
+  return add_vectors_helper<__half2>(a, b);
+}
+
+// Same scheme for a single uint32_t.
+template
+MSCCLPP_DEVICE_INLINE uint32_t add_vectors_helper(uint32_t a, uint32_t b) {
+  return bit_cast(add_elements(bit_cast(a), bit_cast(b)));
+}
+
+template
+MSCCLPP_DEVICE_INLINE uint32_t add_vectors(uint32_t a, uint32_t b) {
+  return add_vectors_helper(a, b);
+}
+
+template <>
+MSCCLPP_DEVICE_INLINE uint32_t add_vectors<__half>(uint32_t a, uint32_t b) {
+  return add_vectors_helper<__half2>(a, b);
+}
+
+}  // namespace
+#endif  // defined(MSCCLPP_DEVICE_COMPILE)
+
+namespace mscclpp {
+
+#if defined(MSCCLPP_DEVICE_COMPILE)
+
+// Returns `input`, `output` or `scratch` according to `bufferType`, or nullptr for any other value.
+template
+MSCCLPP_DEVICE_INLINE T* getBuffer(T* input, T* output, T* scratch, BufferType bufferType) {
+  if (bufferType == BufferType::INPUT) {
+    return input;
+  }
+  if (bufferType == BufferType::OUTPUT) {
+    return output;
+  }
+  if (bufferType == BufferType::SCRATCH) {
+    return scratch;
+  }
+  return nullptr;
+}
+
+// Thread `tid` (only when tid < nChannels) calls signal() on the channel selected by
+// channelIndex[tid]. The SM or PROXY array is chosen by chType; any other chType does nothing.
+MSCCLPP_DEVICE_INLINE void handleSignal(int tid, DeviceHandle* smChannels,
+                                        DeviceHandle* proxyChannels, uint8_t* channelIndex,
+                                        int nChannels, ChannelType chType) {
+  if (tid < nChannels) {
+    if (chType == ChannelType::SM) {
+      smChannels[channelIndex[tid]].signal();
+    }
+    if (chType == ChannelType::PROXY) {
+      proxyChannels[channelIndex[tid]].signal();
+    }
+  }
+}
+
+// Counterpart of handleSignal: thread `tid` (only when tid < nChannels) calls wait() on the channel
+// selected by channelIndexes[tid].
+MSCCLPP_DEVICE_INLINE void handleWait(int tid, DeviceHandle* smChannels,
+                                      DeviceHandle* proxyChannels, uint8_t* channelIndexes,
+                                      int nChannels, ChannelType chType) {
+  if (tid < nChannels) {
+    if (chType == ChannelType::SM) {
+      smChannels[channelIndexes[tid]].wait();
+    }
+    if (chType == ChannelType::PROXY) {
+      proxyChannels[channelIndexes[tid]].wait();
+    }
+  }
+}
+
+// Forwards to smChannel.get(). Note the argument order: the call passes dstOffset first and
+// srcOffset second, the reverse of this function's own parameter order.
+MSCCLPP_DEVICE_INLINE void handleGet(DeviceHandle& smChannel, uint32_t srcOffset, uint32_t dstOffset,
+                                     uint32_t size) {
+  smChannel.get(dstOffset, srcOffset, size, threadIdx.x, blockDim.x);
+}
+
+// For `size` bytes: reads local `input`, adds the values read from each of the nSrcChannels source
+// channels, stores the sum to `output`, and, when sendToRemote is true, writes the sum to each of the
+// nDstChannels destination channels. Offsets are given in bytes and converted to element indexes.
+// Each thread handles indexes threadIdx.x, threadIdx.x + blockDim.x, and so on.
+template
+MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOffsetByBytes, T* input,
+                                                    uint32_t inputOffsetByBytes, DeviceHandle* smChannels,
+                                                    uint8_t* dstChannelIndexes, uint8_t* srcChannelIndexes,
+                                                    uint32_t* dstOffsets, uint32_t* srcOffsets, int nDstChannels,
+                                                    int nSrcChannels, uint32_t size, bool sendToRemote = true) {
+  // Main part: whole int4 (16-byte) units. Byte offsets are truncated to int4 indexes.
+  const size_t nInt4 = size / sizeof(int4);
+  const size_t inputOffset4 = inputOffsetByBytes / sizeof(int4);
+  const size_t outputOffset4 = outputOffsetByBytes / sizeof(int4);
+  int4* input4 = (int4*)input;
+  int4* output4 = (int4*)output;
+  for (size_t idx = threadIdx.x; idx < nInt4; idx += blockDim.x) {
+    int4 tmp = input4[inputOffset4 + idx];
+    for (int index = 0; index < nSrcChannels; ++index) {
+      int4 val;
+      size_t srcOffset = srcOffsets[index] / sizeof(int4);
+      val = smChannels[srcChannelIndexes[index]].read(srcOffset + idx);
+      tmp = add_vectors(tmp, val);
+    }
+    output4[outputOffset4 + idx] = tmp;
+    if (sendToRemote) {
+      for (int index = 0; index < nDstChannels; ++index) {
+        size_t dstOffset = dstOffsets[index] / sizeof(int4);
+        smChannels[dstChannelIndexes[index]].write(dstOffset + idx, tmp);
+      }
+    }
+  }
+  // handle rest of data
+  // NOTE(review): in this tail loop `idx` is an absolute element index that already includes
+  // inputOffsetByBytes. It is used unchanged for output[idx] (outputOffsetByBytes is not applied) and
+  // is added on top of the per-channel src/dst offsets, whereas the int4 loop above uses a relative
+  // idx. Confirm this is intended when the input and output offsets differ.
+  size_t processed = nInt4 * sizeof(int4);
+  const size_t startIdx = (inputOffsetByBytes + processed) / sizeof(T);
+  const size_t endIdx = (inputOffsetByBytes + size) / sizeof(T);
+  for (size_t idx = threadIdx.x + startIdx; idx < endIdx; idx += blockDim.x) {
+    T tmp = input[idx];
+    for (int index = 0; index < nSrcChannels; ++index) {
+      size_t srcOffset = srcOffsets[index] / sizeof(T);
+      tmp += smChannels[srcChannelIndexes[index]].read(srcOffset + idx);
+    }
+    output[idx] = tmp;
+    if (sendToRemote) {
+      for (int index = 0; index < nDstChannels; ++index) {
+        size_t dstOffset = dstOffsets[index] / sizeof(T);
+        smChannels[dstChannelIndexes[index]].write(dstOffset + idx, tmp);
+      }
+    }
+  }
+}
+
+// Calls putPackets() on every destination channel. The destination base offset is 0 when `flag` is
+// odd and scratchSize / 2 when it is even, so consecutive flag values alternate between the two
+// halves of the scratch area. dstOffsets[index] is doubled before being added to that base.
+template
+MSCCLPP_DEVICE_INLINE void handlePutPacket(uint32_t inputOffsetByBytes, size_t scratchSize,
+                                           DeviceHandle* smChannels, uint8_t* dstChannelIndexes,
+                                           uint32_t* dstOffsets, int nDstChannels, uint32_t size, uint32_t flag) {
+  const size_t scratchBaseOffset = flag & 0x1 ? 0 : scratchSize >> 1;
+  for (int index = 0; index < nDstChannels; ++index) {
+    smChannels[dstChannelIndexes[index]].putPackets(
+        scratchBaseOffset + dstOffsets[index] * 2, inputOffsetByBytes, size, threadIdx.x, blockDim.x, flag);
+  }
+}
+
+// For each packet index: sums the payloads read from nSrcs packet arrays inside `inputBuff`, adds the
+// payload at the same index of `src`, stores the sum to `dst`, then wraps it in a packet carrying
+// `flag` and writes that packet to each of the nDstChannels output channels.
+template
+MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes,
+                                                  T* inputBuff, size_t inputBuffSize, uint32_t* inputOffsets, int nSrcs,
+                                                  DeviceHandle* smChannels, uint8_t* outputChannelIndexes,
+                                                  uint32_t* outputOffsets, int nDstChannels, size_t size,
+                                                  uint32_t flag) {
+  size_t nPackets = size * 2 / sizeof(PacketType);
+  // Half of inputBuff selected by the parity of flag (0 when odd, inputBuffSize / 2 when even).
+  // The same base is reused below for the offset written through the output channels.
+  const size_t intputBaseOffset = flag & 0x1 ? 0 : inputBuffSize >> 1;
+  const uint32_t srcOffset = srcOffsetByBytes / sizeof(PacketPayload);
+  const uint32_t dstOffset = dstOffsetByBytes / sizeof(PacketPayload);
+  PacketPayload* srcPacketPayload = (PacketPayload*)src + srcOffset;
+  PacketPayload* dstPacketPayload = (PacketPayload*)dst + dstOffset;
+  for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
+    PacketPayload data = {};
+    for (int index = 0; index < nSrcs; ++index) {
+      PacketType* pkt = (PacketType*)((char*)inputBuff + intputBaseOffset + 2 * inputOffsets[index]);
+      PacketPayload val = pkt[idx].read(flag);
+      data = add_vectors(data, val);
+    }
+    data = add_vectors(data, srcPacketPayload[idx]);
+    dstPacketPayload[idx] = data;
+
+    PacketType pkt(data, flag);
+    for (int index = 0; index < nDstChannels; ++index) {
+      size_t offset = (intputBaseOffset + outputOffsets[index] * 2) / sizeof(PacketType);
+      smChannels[outputChannelIndexes[index]].write(offset + idx, pkt);
+    }
+  }
+}
+
+// Reads packets from `src` (starting at the flag-selected half of srcSize plus 2 * srcOffset bytes)
+// with read(flag) and stores their payloads contiguously at `dst` + dstOffset bytes.
+template
+MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize, uint32_t dstOffset,
+                                            uint32_t srcOffset, size_t size, uint32_t flag) {
+  const size_t outputScratchBaseOffset = flag & 0x1 ? 0 : srcSize >> 1;
+  PacketType* srcPackets = (PacketType*)((char*)src + outputScratchBaseOffset + 2 * srcOffset);
+  PacketPayload* result = (PacketPayload*)((char*)dst + dstOffset);
+  size_t nPackets = size * 2 / sizeof(PacketType);
+  for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
+    PacketPayload data = srcPackets[idx].read(flag);
+    result[idx] = data;
+  }
+}
+
+// Kernel entry point. Thread block `blockIdx.x` executes the plan stored at plan[blockIdx.x]: it
+// copies that plan into dynamic shared memory, then runs the plan's operations in order.
+// `rank` is only passed through for debugging and is otherwise unused.
+template
+__global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* input, T* output, T* scratch,
+                                size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag) {
+  extern __shared__ int4 sharedMem[];
+  int bid = blockIdx.x;
+  int tid = threadIdx.x;
+  DeviceExecutionPlan* localPlan = plan + bid;
+  // Copy in int4 units: sizeof(DeviceExecutionPlan) / sizeof(int4) truncates, so the struct size is
+  // expected to be a multiple of sizeof(int4).
+  for (size_t i = tid; i < sizeof(DeviceExecutionPlan) / sizeof(int4); i += blockDim.x) {
+    sharedMem[i] = ((int4*)localPlan)[i];
+  }
+  // Barrier so every thread sees the complete copy before the plan is read.
+#if defined(MSCCLPP_DEVICE_HIP)
+  __synclds();
+#else   // !defined(MSCCLPP_DEVICE_HIP)
+  __syncthreads();
+#endif  // !defined(MSCCLPP_DEVICE_HIP)
+  localPlan = (DeviceExecutionPlan*)sharedMem;
+  int nOperations = localPlan->nOperations;
+  Operation* operations = localPlan->operations;
+  DeviceHandle* smChannels = localPlan->channels.smChannels;
+  DeviceHandle* proxyChannels = localPlan->channels.proxyChannels;
+
+  // Dispatch on the operation type. Types without a branch here fall through and are ignored.
+  for (int i = 0; i < nOperations; i++) {
+    Operation& op = operations[i];
+    if (op.type == OperationType::BARRIER) {
+      __syncthreads();
+    } else if (op.type == OperationType::SIGNAL) {
+      handleSignal(tid, smChannels, proxyChannels, op.outputChannelIndexes, op.nOutputs, op.channelType);
+    } else if (op.type == OperationType::WAIT) {
+      handleWait(tid, smChannels, proxyChannels, op.inputChannelIndexes, op.nInputs, op.channelType);
+    } else if (op.type == OperationType::GET) {
+      // Only the first input channel and the first input offset are used.
+      handleGet(smChannels[op.inputChannelIndexes[0]], op.inputOffsets[0], op.dstOffset, op.size);
+    } else if (op.type == OperationType::READ_REDUCE_COPY_SEND) {
+      T* dst = getBuffer(input, output, scratch, op.dstBufferType);
+      T* src = getBuffer(input, output, scratch, op.srcBufferType);
+      handleReadReduceCopySend(dst, op.dstOffset, src, op.srcOffset, smChannels, op.outputChannelIndexes,
+                               op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs, op.nInputs,
+                               op.size);
+    } else if (op.type == OperationType::READ_REDUCE_COPY) {
+      // Same handler with sendToRemote = false: reduce and store locally only.
+      T* dst = getBuffer(input, output, scratch, op.dstBufferType);
+      T* src = getBuffer(input, output, scratch, op.srcBufferType);
+      handleReadReduceCopySend(dst, op.dstOffset, src, op.srcOffset, smChannels, op.outputChannelIndexes,
+                               op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs, op.nInputs,
+                               op.size, false);
+    } else if (op.type == OperationType::PUT_PACKET) {
+      handlePutPacket(op.srcOffset, scratchSize, smChannels, op.outputChannelIndexes, op.outputOffsets,
+                      op.nOutputs, op.size, flag);
+    } else if (op.type == OperationType::REDUCE_SEND_PACKET) {
+      T* dst = getBuffer(input, output, scratch, op.dstBufferType);
+      T* src = getBuffer(input, output, scratch, op.srcBufferType);
+      handleReduceSendPacket(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize, op.inputOffsets,
+                             op.nInputs, smChannels, op.outputChannelIndexes, op.outputOffsets,
+                             op.nOutputs, op.size, flag);
+    } else if (op.type == OperationType::COPY_PACKET) {
+      T* dst = getBuffer(input, output, scratch, op.dstBufferType);
+      T* src = getBuffer(input, output, scratch, op.srcBufferType);
+      handleCopyPacket(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
+    }
+  }
+}
+#endif  // defined(MSCCLPP_DEVICE_COMPILE)
+
+// Host-side launcher for executionKernel. On HIP the launcher is defined inline here; otherwise only
+// the declaration is given and the definition lives elsewhere (not visible in this header).
+class ExecutionKernel {
+ public:
+#if defined(MSCCLPP_DEVICE_HIP)
+  // Launches executionKernel with the buffers cast to the element type selected by `dataType`.
+  // The switch has no default branch, so a dataType outside the four cases launches nothing.
+  template
+  static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
+                           size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
+                           cudaStream_t stream, uint32_t flag = 0) {
+    switch (dataType) {
+      case DataType::INT32:
+        executionKernel<<>>(
+            rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag);
+        break;
+      case DataType::UINT32:
+        executionKernel<<>>(
+            rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag);
+        break;
+      case DataType::FLOAT16:
+        executionKernel<<>>(
+            rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag);
+        break;
+      case DataType::FLOAT32:
+        executionKernel<<>>(
+            rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag);
+        break;
+    }
+  }
+#else   // !defined(MSCCLPP_DEVICE_HIP)
+  // Declaration only; same signature as the HIP version above.
+  template
+  static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
+                           size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
+                           cudaStream_t stream, uint32_t flag = 0);
+#endif  // !defined(MSCCLPP_DEVICE_HIP)
+};
+}  // namespace mscclpp
+
+#endif  // MSCCLPP_EXECUTION_KERNEL_HPP_
diff --git a/src/include/execution_plan.hpp b/src/include/execution_plan.hpp
new file mode 100644
index 000000000..924d1358c
--- /dev/null
+++ b/src/include/execution_plan.hpp
@@ -0,0 +1,81 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+ +#ifndef MSCCLPP_EXECUTOR_PLAN_HPP_ +#define MSCCLPP_EXECUTOR_PLAN_HPP_ + +#include +#include +#include +#include +#include + +#include "execution_common.hpp" + +namespace mscclpp { + +struct ChannelKey { + BufferType srcBufferType; + BufferType dstBufferType; + ChannelType channelType; + bool operator==(const ChannelKey& other) const { + return srcBufferType == other.srcBufferType && dstBufferType == other.dstBufferType && + channelType == other.channelType; + } +}; +} // namespace mscclpp + +namespace std { +template <> +struct hash { + std::size_t operator()(const mscclpp::ChannelKey& key) const { + return std::hash()(static_cast(key.srcBufferType)) ^ + std::hash()(static_cast(key.dstBufferType)) ^ std::hash()(static_cast(key.channelType)); + } +}; +} // namespace std + +namespace mscclpp { + +struct ChannelInfo { + BufferType srcBufferType; + BufferType dstBufferType; + ChannelType channelType; + std::vector connectedPeers; +}; + +struct ExecutionPlan::Impl { + public: + Impl(const std::string name, const std::string planPath); + ~Impl() = default; + + std::vector getChannelInfos(int rank, ChannelType channelType) const; + std::vector getChannelInfos(int rank, BufferType bufferType) const; + std::vector getConnectedPeers(int rank) const; + std::vector getConnectedBufferTypes(int rank) const; + size_t getScratchBufferSize(int rank, size_t inputSize) const; + std::vector getOperations(int rank, int threadblock) const; + int getThreadblockCount(int rank) const; + + void loadExecutionPlan(int rank, size_t inputSize); + void setupChannels(const nlohmann::json& gpus); + void setupOperations(const nlohmann::json& gpus); + + const std::string name; + const std::string planPath; + bool isUsingPacket; + // operations for [rank][threadblock] = [operations] + std::unordered_map>> operations; + std::unordered_map> channelInfos; + // threadblockChannelMap[rank][threadblock] = [channelIndex, channelKey] + std::unordered_map>>> threadblockSMChannelMap; + 
std::unordered_map>>> threadblockProxyChannelMap; + std::unordered_map inputChunks; + std::unordered_map outputChunks; + std::unordered_map scratchChunks; + size_t chunkSize; +}; + +} // namespace mscclpp + +#endif // MSCCLPP_EXECUTOR_PLAN_HPP_ diff --git a/test/execution-files/allreduce.json b/test/execution-files/allreduce.json new file mode 100644 index 000000000..739b8e6ab --- /dev/null +++ b/test/execution-files/allreduce.json @@ -0,0 +1,1042 @@ +{ + "name": "allreduce_pairs", + "colletive": "allreduce", + "protocol": "Simple", + "inplace": true, + "gpus": [ + { + "id": 0, + "inputChunks": 8, + "outputChunks": 0, + "scratchChunks": 0, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 2 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 0 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 1 + } + ] + }, + { + "name": "rrcs", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 0 + } + ], + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 0 + } + ], + "src": 0, + "srcbuff": "i", + "srcoff": 0, + "dst": 0, + "dstbuff": "i", + "dstoff": 0, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 3 + } + ] + }, + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 0 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 2 + } + ], + "ctype": "sm", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 0 + ] + } + ] + }, + { + "id": 1, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + 
"id": 0, + "off": 6 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 4 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 1, + "step": 1 + } + ] + }, + { + "name": "rrcs", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 4 + } + ], + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 4 + } + ], + "src": 0, + "srcbuff": "i", + "srcoff": 4, + "dst": 0, + "dstbuff": "i", + "dstoff": 4, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 1, + "step": 3 + } + ] + }, + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 4 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 6 + } + ], + "ctype": "sm", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 1 + ] + } + ] + }, + { + "id": 2, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 3 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 1 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 2, + "step": 1 + } + ] + }, + { + "name": "rrcs", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 1 + } + ], + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 1 + } + ], + "src": 0, + "srcbuff": "i", + "srcoff": 1, + "dst": 0, + "dstbuff": "i", + "dstoff": 1, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 2, + "step": 3 + } + ] + }, + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + 
"id": 0, + "off": 1 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 3 + } + ], + "ctype": "sm", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 2 + ] + } + ] + }, + { + "id": 3, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 7 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 5 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 3, + "step": 1 + } + ] + }, + { + "name": "rrcs", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 5 + } + ], + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 5 + } + ], + "src": 0, + "srcbuff": "i", + "srcoff": 5, + "dst": 0, + "dstbuff": "i", + "dstoff": 5, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 3, + "step": 3 + } + ] + }, + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 5 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 7 + } + ], + "ctype": "sm", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 3 + ] + } + ] + } + ], + "channels": [ + { + "srcbuff": "i", + "dstbuff": "i", + "type": "sm", + "connectedTo": [ + 1, + 1, + 1, + 1 + ] + } + ] + }, + { + "id": 1, + "inputChunks": 8, + "outputChunks": 0, + "scratchChunks": 0, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 0 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + 
}, + "i_cids": [ + { + "id": 0, + "off": 2 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 1 + } + ] + }, + { + "name": "rrcs", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 2 + } + ], + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 2 + } + ], + "src": 1, + "srcbuff": "i", + "srcoff": 2, + "dst": 1, + "dstbuff": "i", + "dstoff": 2, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 3 + } + ] + }, + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 2 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 0 + } + ], + "ctype": "sm", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 0 + ] + } + ] + }, + { + "id": 1, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 4 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 6 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 1, + "step": 1 + } + ] + }, + { + "name": "rrcs", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 6 + } + ], + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 6 + } + ], + "src": 1, + "srcbuff": "i", + "srcoff": 6, + "dst": 1, + "dstbuff": "i", + "dstoff": 6, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 1, + "step": 3 + } + ] + }, + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 6 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, 
+ "i_cids": [ + { + "id": 0, + "off": 4 + } + ], + "ctype": "sm", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 1 + ] + } + ] + }, + { + "id": 2, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 1 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 3 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 2, + "step": 1 + } + ] + }, + { + "name": "rrcs", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 3 + } + ], + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 3 + } + ], + "src": 1, + "srcbuff": "i", + "srcoff": 3, + "dst": 1, + "dstbuff": "i", + "dstoff": 3, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 2, + "step": 3 + } + ] + }, + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 3 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 1 + } + ], + "ctype": "sm", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 2 + ] + } + ] + }, + { + "id": 3, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 5 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 7 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 3, + "step": 1 + } + ] + }, + { + "name": "rrcs", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 7 + } + ], + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 
7 + } + ], + "src": 1, + "srcbuff": "i", + "srcoff": 7, + "dst": 1, + "dstbuff": "i", + "dstoff": 7, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 3, + "step": 3 + } + ] + }, + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 7 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 5 + } + ], + "ctype": "sm", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 3 + ] + } + ] + } + ], + "channels": [ + { + "srcbuff": "i", + "dstbuff": "i", + "type": "sm", + "connectedTo": [ + 0, + 0, + 0, + 0 + ] + } + ] + } + ] +} diff --git a/test/execution-files/allreduce_packet.json b/test/execution-files/allreduce_packet.json new file mode 100644 index 000000000..c01ae4fd9 --- /dev/null +++ b/test/execution-files/allreduce_packet.json @@ -0,0 +1,330 @@ +{ + "name": "allreduce_pairs", + "colletive": "allreduce", + "protocol": "LL", + "inplace": true, + "gpus": [ + { + "id": 0, + "inputChunks": 4, + "outputChunks": 0, + "scratchChunks": 8, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "ppkt", + "o_buff": { + "src": "i", + "dst": "s" + }, + "o_cids": [ + { + "id": 0, + "off": 0 + } + ], + "src": 0, + "srcbuff": "i", + "srcoff": 2, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "rspkt", + "o_buff": { + "src": "i", + "dst": "s" + }, + "o_cids": [ + { + "id": 0, + "off": 4 + } + ], + "src": 0, + "srcs": [ + { + "buff": "s", + "off": 2 + } + ], + "srcbuff": "i", + "srcoff": 0, + "dst": 0, + "dstbuff": "i", + "dstoff": 0, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "cpkt", + "src": 0, + "srcbuff": "s", + "srcoff": 6, + "dst": 0, + "dstbuff": "i", + "dstoff": 2, + "ctype": "none", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "s", + "ctype": "sm", + "cids": [ + 0 + ] + } + ] + }, + { + "id": 1, + "ops": [ + { + "name": "ppkt", + 
"o_buff": { + "src": "i", + "dst": "s" + }, + "o_cids": [ + { + "id": 0, + "off": 1 + } + ], + "src": 0, + "srcbuff": "i", + "srcoff": 3, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "rspkt", + "o_buff": { + "src": "i", + "dst": "s" + }, + "o_cids": [ + { + "id": 0, + "off": 5 + } + ], + "src": 0, + "srcs": [ + { + "buff": "s", + "off": 3 + } + ], + "srcbuff": "i", + "srcoff": 1, + "dst": 0, + "dstbuff": "i", + "dstoff": 1, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "cpkt", + "src": 0, + "srcbuff": "s", + "srcoff": 7, + "dst": 0, + "dstbuff": "i", + "dstoff": 3, + "ctype": "none", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "s", + "ctype": "sm", + "cids": [ + 1 + ] + } + ] + } + ], + "channels": [ + { + "srcbuff": "i", + "dstbuff": "s", + "type": "sm", + "connectedTo": [ + 1, + 1 + ] + } + ] + }, + { + "id": 1, + "inputChunks": 4, + "outputChunks": 0, + "scratchChunks": 8, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "ppkt", + "o_buff": { + "src": "i", + "dst": "s" + }, + "o_cids": [ + { + "id": 0, + "off": 2 + } + ], + "src": 1, + "srcbuff": "i", + "srcoff": 0, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "rspkt", + "o_buff": { + "src": "i", + "dst": "s" + }, + "o_cids": [ + { + "id": 0, + "off": 6 + } + ], + "src": 1, + "srcs": [ + { + "buff": "s", + "off": 0 + } + ], + "srcbuff": "i", + "srcoff": 2, + "dst": 1, + "dstbuff": "i", + "dstoff": 2, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "cpkt", + "src": 1, + "srcbuff": "s", + "srcoff": 4, + "dst": 1, + "dstbuff": "i", + "dstoff": 0, + "ctype": "none", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "s", + "ctype": "sm", + "cids": [ + 0 + ] + } + ] + }, + { + "id": 1, + "ops": [ + { + "name": "ppkt", + "o_buff": { + "src": "i", + "dst": "s" + }, + "o_cids": [ + { + "id": 0, + "off": 3 + } + ], + "src": 1, + "srcbuff": "i", + "srcoff": 1, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "rspkt", + "o_buff": { + "src": "i", + "dst": "s" + }, + "o_cids": [ + 
{ + "id": 0, + "off": 7 + } + ], + "src": 1, + "srcs": [ + { + "buff": "s", + "off": 1 + } + ], + "srcbuff": "i", + "srcoff": 3, + "dst": 1, + "dstbuff": "i", + "dstoff": 3, + "ctype": "sm", + "cnt": 1 + }, + { + "name": "cpkt", + "src": 1, + "srcbuff": "s", + "srcoff": 5, + "dst": 1, + "dstbuff": "i", + "dstoff": 1, + "ctype": "none", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "s", + "ctype": "sm", + "cids": [ + 1 + ] + } + ] + } + ], + "channels": [ + { + "srcbuff": "i", + "dstbuff": "s", + "type": "sm", + "connectedTo": [ + 0, + 0 + ] + } + ] + } + ] +} diff --git a/test/mp_unit/CMakeLists.txt b/test/mp_unit/CMakeLists.txt index dc388844f..8e37d2405 100644 --- a/test/mp_unit/CMakeLists.txt +++ b/test/mp_unit/CMakeLists.txt @@ -8,4 +8,5 @@ target_sources(mp_unit_tests PRIVATE communicator_tests.cu proxy_channel_tests.cu sm_channel_tests.cu + executor_tests.cc ) diff --git a/test/mp_unit/bootstrap_tests.cc b/test/mp_unit/bootstrap_tests.cc index 69e566dbd..65ec17027 100644 --- a/test/mp_unit/bootstrap_tests.cc +++ b/test/mp_unit/bootstrap_tests.cc @@ -120,6 +120,13 @@ class MPIBootstrap : public mscclpp::Bootstrap { MPI_Comm_size(MPI_COMM_WORLD, &worldSize); return worldSize; } + int getNranksPerNode() override { + MPI_Comm shmcomm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm); + int shmrank; + MPI_Comm_size(shmcomm, &shmrank); + return shmrank; + } void allGather(void* sendbuf, int size) override { MPI_Allgather(MPI_IN_PLACE, 0, MPI_BYTE, sendbuf, size, MPI_BYTE, MPI_COMM_WORLD); } diff --git a/test/mp_unit/executor_tests.cc b/test/mp_unit/executor_tests.cc new file mode 100644 index 000000000..6bfe09516 --- /dev/null +++ b/test/mp_unit/executor_tests.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#include + +#include + +#include "mp_unit_tests.hpp" + +namespace { +std::string getExecutablePath() { + char result[PATH_MAX]; + ssize_t count = readlink("/proc/self/exe", result, PATH_MAX); + if (count == -1) { + throw std::runtime_error("Failed to get executable path"); + } + return std::string(result, count); +} +} // namespace + +void ExecutorTest::SetUp() { + MultiProcessTest::SetUp(); + + MSCCLPP_CUDATHROW(cudaSetDevice(rankToLocalRank(gEnv->rank))); + std::shared_ptr bootstrap; + mscclpp::UniqueId id; + if (gEnv->rank < gEnv->worldSize) { + bootstrap = std::make_shared(gEnv->rank, gEnv->worldSize); + if (gEnv->rank == 0) id = bootstrap->createUniqueId(); + } + MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); + bootstrap->initialize(id); + std::shared_ptr communicator = std::make_shared(bootstrap); + executor = std::make_shared(communicator); +} + +void ExecutorTest::TearDown() { + executor.reset(); + MultiProcessTest::TearDown(); +} + +TEST_F(ExecutorTest, TwoNodesAllreduce) { + if (gEnv->worldSize != 2 || gEnv->nRanksPerNode != 2) { + GTEST_SKIP() << "This test requires world size to be 2 and ranks per node to be 2"; + return; + } + std::string executablePath = getExecutablePath(); + std::filesystem::path path = executablePath; + std::filesystem::path executionFilesPath = + path.parent_path().parent_path().parent_path() / "test/execution-files/allreduce.json"; + mscclpp::ExecutionPlan plan("allreduce_pairs", executionFilesPath.string()); + const int bufferSize = 1024 * 1024; + std::shared_ptr sendbuff = mscclpp::allocExtSharedCuda(bufferSize); + mscclpp::CudaStreamWithFlags stream(cudaStreamNonBlocking); + executor->execute(gEnv->rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16, 512, + plan, stream); + MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); +} diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index e934dee49..6cb159c67 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ 
b/test/mp_unit/mp_unit_tests.hpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -155,4 +156,11 @@ class SmChannelOneToOneTest : public CommunicatorTestBase { std::unordered_map> smSemaphores; }; +class ExecutorTest : public MultiProcessTest { + protected: + void SetUp() override; + void TearDown() override; + + std::shared_ptr executor; +}; #endif // MSCCLPP_MP_UNIT_TESTS_HPP_