Add executor to execute schedule-plan file (#283)
Add executor to execute the JSON schedule file generated by msccl-tools

---------

Co-authored-by: Changho Hwang <[email protected]>
Binyang2014 and chhwang authored Apr 18, 2024
1 parent 9406123 commit 64d837f
Showing 27 changed files with 2,857 additions and 3 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
@@ -99,13 +99,17 @@ find_package(IBVerbs REQUIRED)
find_package(NUMA REQUIRED)
find_package(Threads REQUIRED)

include(FetchContent)
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)

add_library(mscclpp_obj OBJECT)
target_include_directories(mscclpp_obj
SYSTEM PRIVATE
${GPU_INCLUDE_DIRS}
${IBVERBS_INCLUDE_DIRS}
${NUMA_INCLUDE_DIRS})
target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads)
target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} nlohmann_json::nlohmann_json Threads::Threads)
set_target_properties(mscclpp_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
if(USE_CUDA)
target_compile_definitions(mscclpp_obj PRIVATE USE_CUDA)
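
The nlohmann/json dependency added above is presumably what the new executor uses to parse the schedule-plan file. Below is a minimal sketch of loading a plan with this library; the helper name and the top-level "name" field are assumptions for illustration only, since the plan schema itself is not part of the hunks shown on this page.

#include <fstream>
#include <string>

#include <nlohmann/json.hpp>

// Hypothetical helper: reads a plan file and returns its "name" field.
// The field name is an assumption, not taken from this diff.
std::string readPlanName(const std::string& planPath) {
  std::ifstream file(planPath);
  nlohmann::json obj = nlohmann::json::parse(file);
  return obj["name"].get<std::string>();
}
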
8 changes: 8 additions & 0 deletions include/mscclpp/core.hpp
@@ -38,6 +38,7 @@ class Bootstrap {
virtual ~Bootstrap() = default;
virtual int getRank() = 0;
virtual int getNranks() = 0;
virtual int getNranksPerNode() = 0;
virtual void send(void* data, int size, int peer, int tag) = 0;
virtual void recv(void* data, int size, int peer, int tag) = 0;
virtual void allGather(void* allData, int size) = 0;
@@ -83,6 +84,9 @@ class TcpBootstrap : public Bootstrap {
/// Return the total number of ranks.
int getNranks() override;

/// Return the total number of ranks per node.
int getNranksPerNode() override;

/// Send data to another process.
///
/// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size,
@@ -760,6 +764,10 @@ DeviceHandle<std::remove_reference_t<T>> deviceHandle(T&& t) {
return t.deviceHandle();
}

/// Packet value type.
template <class T>
using PacketPayload = typename T::Payload;

} // namespace mscclpp

namespace std {
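
The new Bootstrap::getNranksPerNode() mirrors the GPU-selection pattern used by the Python tests later in this commit. A hedged C++ equivalent is sketched below; the bootstrap creation and initialization are assumed to have happened already, and the helper name is illustrative.

#include <memory>

#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>

// Bind this rank to a GPU on its node, as executor_test.py does on the Python side.
void selectLocalGpu(std::shared_ptr<mscclpp::Bootstrap> bootstrap) {
  int localRank = bootstrap->getRank() % bootstrap->getNranksPerNode();
  cudaSetDevice(localRank);
}
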
1 change: 1 addition & 0 deletions include/mscclpp/errors.hpp
@@ -16,6 +16,7 @@ enum class ErrorCode {
InvalidUsage, // The function was used incorrectly.
Timeout, // The operation timed out.
Aborted, // The operation was aborted.
ExecutorError, // An error occurred in the MSCCL++ executor.
};

/// Convert an error code to a string.
53 changes: 53 additions & 0 deletions include/mscclpp/executor.hpp
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#ifndef MSCCLPP_EXECUTOR_HPP_
#define MSCCLPP_EXECUTOR_HPP_

#include <memory>
#include <mscclpp/core.hpp>
#include <unordered_map>

namespace mscclpp {

enum class DataType {
INT32,
UINT32,
FLOAT16,
FLOAT32,
};

enum class PacketType {
LL8,
LL16,
};

class ExecutionPlan {
public:
ExecutionPlan(const std::string& name, const std::string& planPath);
~ExecutionPlan() = default;

private:
struct Impl;
std::shared_ptr<Impl> impl_;

friend class Executor;
};

class Executor {
public:
Executor(std::shared_ptr<Communicator> comm);
Executor(const Executor&) = delete;
Executor& operator=(const Executor&) = delete;
~Executor();

void execute(int rank, void* sendbuff, void* recvBuff, size_t sendBuffSize, size_t recvBuffSize, DataType dataType,
int nthreads, const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType = PacketType::LL16);

private:
struct Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace mscclpp

#endif // MSCCLPP_EXECUTOR_HPP_
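
For reference, a minimal C++ usage sketch of the header above. The plan name and file path mirror the Python tests further down; the communicator, device buffer, and stream are assumed to exist already, so this is illustrative rather than a verbatim part of the commit.

#include <memory>

#include <mscclpp/core.hpp>
#include <mscclpp/executor.hpp>
#include <mscclpp/gpu.hpp>

// In-place allreduce over a device buffer `buf` of `bytes` bytes, driven by a JSON plan.
void runPlan(std::shared_ptr<mscclpp::Communicator> comm, int rank, void* buf, size_t bytes, cudaStream_t stream) {
  mscclpp::Executor executor(comm);
  mscclpp::ExecutionPlan plan("allreduce_pairs", "test/execution-files/allreduce.json");
  executor.execute(rank, buf, buf, bytes, bytes, mscclpp::DataType::FLOAT16,
                   /*nthreads=*/512, plan, stream, mscclpp::PacketType::LL16);
  cudaStreamSynchronize(stream);
}
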
15 changes: 15 additions & 0 deletions include/mscclpp/packet_device.hpp
@@ -24,12 +24,20 @@ union alignas(16) LL16Packet {
uint32_t data2;
uint32_t flag2;
};
using Payload = uint2;

#if defined(MSCCLPP_DEVICE_COMPILE)
ulonglong2 raw_;

MSCCLPP_DEVICE_INLINE LL16Packet() {}

MSCCLPP_DEVICE_INLINE LL16Packet(uint2 val, uint32_t flag) {
data1 = val.x;
flag1 = flag;
data2 = val.y;
flag2 = flag;
}

/// Write 8 bytes of data to the packet.
/// @param val1 The first 4-byte data to write.
/// @param val2 The second 4-byte data to write.
@@ -95,10 +103,17 @@ union alignas(8) LL8Packet {
uint32_t flag;
};
uint64_t raw_;

using Payload = uint32_t;
#if defined(MSCCLPP_DEVICE_COMPILE)

MSCCLPP_DEVICE_INLINE LL8Packet() {}

MSCCLPP_DEVICE_INLINE LL8Packet(uint32_t val, uint32_t flag) {
this->data = val;
this->flag = flag;
}

MSCCLPP_DEVICE_INLINE void write(uint32_t val, uint32_t flag) {
#if defined(MSCCLPP_DEVICE_CUDA)
asm volatile("st.volatile.global.v2.u32 [%0], {%1,%2};" ::"l"(&raw_), "r"(val), "r"(flag));
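
The new Payload typedefs above pair with the PacketPayload<T> alias added to core.hpp earlier in this commit. A small compile-time check illustrating the relationship (illustrative only, not part of the diff):

#include <type_traits>

#include <mscclpp/core.hpp>
#include <mscclpp/packet_device.hpp>

// PacketPayload<T> resolves to T::Payload: uint2 for LL16 packets, uint32_t for LL8 packets.
static_assert(std::is_same_v<mscclpp::PacketPayload<mscclpp::LL16Packet>, uint2>);
static_assert(std::is_same_v<mscclpp::PacketPayload<mscclpp::LL8Packet>, uint32_t>);
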
4 changes: 4 additions & 0 deletions python/mscclpp/__init__.py
@@ -19,6 +19,10 @@
TcpBootstrap,
Transport,
TransportFlags,
DataType,
Executor,
ExecutionPlan,
PacketType,
version,
is_nvls_supported,
)
1 change: 1 addition & 0 deletions python/mscclpp/comm.py
@@ -51,6 +51,7 @@ def __init__(
self.communicator = Communicator(self.bootstrap)
self.my_rank = self.bootstrap.get_rank()
self.nranks = self.bootstrap.get_n_ranks()
self.nranks_per_node = self.bootstrap.get_n_ranks_per_node()

def barrier(self):
self.bootstrap.barrier()
3 changes: 3 additions & 0 deletions python/mscclpp/core_py.cpp
@@ -20,6 +20,7 @@ extern void register_fifo(nb::module_& m);
extern void register_semaphore(nb::module_& m);
extern void register_utils(nb::module_& m);
extern void register_numa(nb::module_& m);
extern void register_executor(nb::module_& m);

template <typename T>
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
@@ -35,6 +36,7 @@ void register_core(nb::module_& m) {
nb::class_<Bootstrap>(m, "Bootstrap")
.def("get_rank", &Bootstrap::getRank)
.def("get_n_ranks", &Bootstrap::getNranks)
.def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
.def(
"send",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
@@ -204,4 +206,5 @@ NB_MODULE(_mscclpp, m) {
register_utils(m);
register_core(m);
register_numa(m);
register_executor(m);
}
3 changes: 2 additions & 1 deletion python/mscclpp/error_py.cpp
@@ -16,7 +16,8 @@ void register_error(nb::module_& m) {
.value("RemoteError", ErrorCode::RemoteError)
.value("InvalidUsage", ErrorCode::InvalidUsage)
.value("Timeout", ErrorCode::Timeout)
.value("Aborted", ErrorCode::Aborted);
.value("Aborted", ErrorCode::Aborted)
.value("ExecutorError", ErrorCode::ExecutorError);

nb::class_<BaseError>(m, "BaseError")
.def(nb::init<std::string&, int>(), nb::arg("message"), nb::arg("errorCode"))
38 changes: 38 additions & 0 deletions python/mscclpp/executor.cpp
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <nanobind/nanobind.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>

#include <mscclpp/executor.hpp>
#include <mscclpp/gpu.hpp>

namespace nb = nanobind;
using namespace mscclpp;

void register_executor(nb::module_& m) {
nb::enum_<DataType>(m, "DataType")
.value("int32", DataType::INT32)
.value("uint32", DataType::UINT32)
.value("float16", DataType::FLOAT16)
.value("float32", DataType::FLOAT32);

nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);

nb::class_<ExecutionPlan>(m, "ExecutionPlan")
.def(nb::init<const std::string, const std::string>(), nb::arg("name"), nb::arg("planPath"));

nb::class_<Executor>(m, "Executor")
.def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
.def(
"execute",
[](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
DataType dataType, int nthreads, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
recvBuffSize, dataType, nthreads, plan, (cudaStream_t)stream, packetType);
},
nb::arg("rank"), nb::arg("sendbuff"), nb::arg("recvBuff"), nb::arg("sendBuffSize"), nb::arg("recvBuffSize"),
nb::arg("dataType"), nb::arg("nthreads"), nb::arg("plan"), nb::arg("stream"),
nb::arg("packetType") = PacketType::LL16);
}
75 changes: 75 additions & 0 deletions python/test/executor_test.py
@@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from os import path
from mscclpp import (
DataType,
Executor,
ExecutionPlan,
)
import mscclpp.comm as mscclpp_comm

import cupy as cp
from mpi4py import MPI

MSCCLPP_ROOT_PATH = "/root/mscclpp"


def bench_time(niters: int, ngraphIters: int, func):
# capture cuda graph for niters of the kernel launch
stream = cp.cuda.Stream(non_blocking=True)
with stream:
stream.begin_capture()
for i in range(niters):
func(stream)
graph = stream.end_capture()

# now run a warm up round
graph.launch(stream)

# now run the benchmark and measure time
start = cp.cuda.Event()
end = cp.cuda.Event()

start.record(stream)
for _ in range(ngraphIters):
graph.launch(stream)
end.record(stream)
end.synchronize()

return cp.cuda.get_elapsed_time(start, end) / niters * 1000.0 / ngraphIters


if __name__ == "__main__":
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
cp.cuda.Device(MPI.COMM_WORLD.rank % mscclpp_group.nranks_per_node).use()
executor = Executor(mscclpp_group.communicator)
execution_plan = ExecutionPlan(
"allreduce_pairs", path.join(MSCCLPP_ROOT_PATH, "test", "execution-files", "allreduce.json")
)

nelems = 1024 * 1024
cp.random.seed(42)
buffer = cp.random.random(nelems).astype(cp.float16)
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
sendbuf = sub_arrays[MPI.COMM_WORLD.rank]
mscclpp_group.barrier()

execution_time = bench_time(
100,
10,
lambda stream: executor.execute(
MPI.COMM_WORLD.rank,
sendbuf.data.ptr,
sendbuf.data.ptr,
sendbuf.nbytes,
sendbuf.nbytes,
DataType.float16,
512,
execution_plan,
stream.ptr,
),
)
print(f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, data size: {sendbuf.nbytes} bytes")
executor = None
mscclpp_group = None
39 changes: 39 additions & 0 deletions python/test/test_mscclpp.py
@@ -12,7 +12,10 @@
import pytest

from mscclpp import (
DataType,
EndpointConfig,
ExecutionPlan,
Executor,
Fifo,
Host2DeviceSemaphore,
Host2HostSemaphore,
@@ -590,3 +593,39 @@ def test_nvls(mpi_group: MpiGroup):
kernel()
cp.cuda.runtime.deviceSynchronize()
group.barrier()


@parametrize_mpi_groups(2)
@pytest.mark.parametrize("filename", ["allreduce.json", "allreduce_packet.json"])
def test_executor(mpi_group: MpiGroup, filename: str):
if all_ranks_on_the_same_node(mpi_group) is False:
pytest.skip("algorithm does not support cross-node execution")
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mscclpp_group = mscclpp_comm.CommGroup(mpi_group.comm)
executor = Executor(mscclpp_group.communicator)
execution_plan = ExecutionPlan("allreduce_pairs", os.path.join(project_dir, "test", "execution-files", filename))

nelems = 1024 * 1024
cp.random.seed(42)
buffer = cp.random.random(nelems).astype(cp.float16)
sub_arrays = cp.split(buffer, mpi_group.comm.size)
sendbuf = sub_arrays[mpi_group.comm.rank]
expected = cp.zeros_like(sendbuf)
for i in range(mpi_group.comm.size):
expected += sub_arrays[i]
mscclpp_group.barrier()

stream = cp.cuda.Stream(non_blocking=True)
executor.execute(
mpi_group.comm.rank,
sendbuf.data.ptr,
sendbuf.data.ptr,
sendbuf.nbytes,
sendbuf.nbytes,
DataType.float16,
512,
execution_plan,
stream.ptr,
)
stream.synchronize()
assert cp.allclose(sendbuf, expected, atol=1e-3 * mpi_group.comm.size)
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc)
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc *.cu)
target_sources(mscclpp_obj PRIVATE ${SOURCES})
target_include_directories(mscclpp_obj PRIVATE include)
(The remaining changed files in this commit are not shown here.)
