Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add executor to execute schedule-plan file #283

Merged
merged 54 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
20a5b6d
init
Binyang2014 Mar 28, 2024
abd4b3c
init
Binyang2014 Mar 29, 2024
c3e0e02
WIP
Binyang2014 Mar 29, 2024
91d4df2
WIP
Binyang2014 Mar 31, 2024
d6ec143
WIP
Binyang2014 Apr 1, 2024
3fdd602
WIP
Binyang2014 Apr 1, 2024
bf681b3
compiled
Binyang2014 Apr 1, 2024
edf93df
WIP
Binyang2014 Apr 1, 2024
c071e5d
add test file
Binyang2014 Apr 1, 2024
580e4a4
WIP
Binyang2014 Apr 1, 2024
d7026fb
WIP
Binyang2014 Apr 2, 2024
b34c9e8
WIP
Binyang2014 Apr 2, 2024
faef1e4
WIP
Binyang2014 Apr 2, 2024
a80bcee
build pass
Binyang2014 Apr 3, 2024
a1a11da
update struct
Binyang2014 Apr 3, 2024
4b5668c
fix
Binyang2014 Apr 3, 2024
d47ac65
try to launch kernel
Binyang2014 Apr 3, 2024
0b4c19a
WIP
Binyang2014 Apr 3, 2024
c14aac2
WIP
Binyang2014 Apr 4, 2024
d1c28bb
pass build
Binyang2014 Apr 4, 2024
36d31db
fix channel bugs
Binyang2014 Apr 4, 2024
37c2d7d
minor
Binyang2014 Apr 4, 2024
8c79780
add python binding
Binyang2014 Apr 4, 2024
71b6224
Fix
Binyang2014 Apr 5, 2024
2eb6426
update test json
Binyang2014 Apr 5, 2024
7e74ed8
for rocm
Binyang2014 Apr 5, 2024
7745c87
fix build
Binyang2014 Apr 5, 2024
bbf197d
update
Binyang2014 Apr 5, 2024
d38c9ed
build fix
Binyang2014 Apr 5, 2024
867101e
minor update
Binyang2014 Apr 5, 2024
6049e9e
more ops
Binyang2014 Apr 5, 2024
d97f312
WIP
Binyang2014 Apr 7, 2024
64106f1
WIP
Binyang2014 Apr 7, 2024
feaf058
update
Binyang2014 Apr 7, 2024
d52ef41
WIP
Binyang2014 Apr 7, 2024
b03be9a
WIP
Binyang2014 Apr 7, 2024
695ff94
update
Binyang2014 Apr 8, 2024
5f37c0a
lint
Binyang2014 Apr 8, 2024
78c5665
fix
Binyang2014 Apr 8, 2024
309f8f2
add test
Binyang2014 Apr 9, 2024
80a513e
minor update
Binyang2014 Apr 9, 2024
ab8d6d7
minor improve
Binyang2014 Apr 17, 2024
48d877c
WIP
Binyang2014 Apr 17, 2024
d6b0366
packet fix
Binyang2014 Apr 17, 2024
0c2b2c1
Minor updates
chhwang Apr 17, 2024
8d1b644
minor updates
chhwang Apr 17, 2024
c29df8e
Fix hang
Binyang2014 Apr 18, 2024
cccfd16
Merge branch 'binyli/executor' of https://github.com/microsoft/mscclp…
Binyang2014 Apr 18, 2024
4a73990
address comments
Binyang2014 Apr 18, 2024
d4671bc
address comments
Binyang2014 Apr 18, 2024
911e4c1
Merge branch 'main' into binyli/executor
Binyang2014 Apr 18, 2024
149eb41
address comments
Binyang2014 Apr 18, 2024
71e8b95
Merge branch 'binyli/executor' of https://github.com/microsoft/mscclp…
Binyang2014 Apr 18, 2024
c407f29
Python binding for ExecutorError
chhwang Apr 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,17 @@ find_package(IBVerbs REQUIRED)
find_package(NUMA REQUIRED)
find_package(Threads REQUIRED)

include(FetchContent)
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)

add_library(mscclpp_obj OBJECT)
target_include_directories(mscclpp_obj
SYSTEM PRIVATE
${GPU_INCLUDE_DIRS}
${IBVERBS_INCLUDE_DIRS}
${NUMA_INCLUDE_DIRS})
target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads)
target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} nlohmann_json::nlohmann_json Threads::Threads)
set_target_properties(mscclpp_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
if(USE_CUDA)
target_compile_definitions(mscclpp_obj PRIVATE USE_CUDA)
Expand Down
8 changes: 8 additions & 0 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Bootstrap {
virtual ~Bootstrap() = default;
virtual int getRank() = 0;
virtual int getNranks() = 0;
virtual int getNranksPerNode() = 0;
virtual void send(void* data, int size, int peer, int tag) = 0;
virtual void recv(void* data, int size, int peer, int tag) = 0;
virtual void allGather(void* allData, int size) = 0;
Expand Down Expand Up @@ -83,6 +84,9 @@ class TcpBootstrap : public Bootstrap {
/// Return the total number of ranks.
int getNranks() override;

/// Return the total number of ranks per node.
int getNranksPerNode() override;

/// Send data to another process.
///
/// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size,
Expand Down Expand Up @@ -760,6 +764,10 @@ DeviceHandle<std::remove_reference_t<T>> deviceHandle(T&& t) {
return t.deviceHandle();
}

/// Packet value type.
template <class T>
using PacketPayload = typename T::Payload;

} // namespace mscclpp

namespace std {
Expand Down
1 change: 1 addition & 0 deletions include/mscclpp/errors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ enum class ErrorCode {
InvalidUsage, // The function was used incorrectly.
Timeout, // The operation timed out.
Aborted, // The operation was aborted.
ExecutorError, // An error occurred in the MSCCL++ executor.
};

/// Convert an error code to a string.
Expand Down
53 changes: 53 additions & 0 deletions include/mscclpp/executor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#ifndef MSCCLPP_EXECUTOR_HPP_
#define MSCCLPP_EXECUTOR_HPP_

#include <memory>
#include <mscclpp/core.hpp>
#include <unordered_map>

namespace mscclpp {

enum class DataType {
INT32,
UINT32,
FLOAT16,
FLOAT32,
};

enum class PacketType {
LL8,
LL16,
};

class ExecutionPlan {
public:
ExecutionPlan(const std::string& name, const std::string& planPath);
~ExecutionPlan() = default;

private:
struct Impl;
std::shared_ptr<Impl> impl_;

friend class Executor;
};

class Executor {
public:
Executor(std::shared_ptr<Communicator> comm);
Executor(const Executor&) = delete;
Executor& operator=(const Executor&) = delete;
~Executor();

void execute(int rank, void* sendbuff, void* recvBuff, size_t sendBuffSize, size_t recvBuffSize, DataType dataType,
int nthreads, const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType = PacketType::LL16);

private:
struct Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace mscclpp

#endif // MSCCLPP_EXECUTOR_HPP_
15 changes: 15 additions & 0 deletions include/mscclpp/packet_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,20 @@ union alignas(16) LL16Packet {
uint32_t data2;
uint32_t flag2;
};
using Payload = uint2;

#if defined(MSCCLPP_DEVICE_COMPILE)
ulonglong2 raw_;

MSCCLPP_DEVICE_INLINE LL16Packet() {}

MSCCLPP_DEVICE_INLINE LL16Packet(uint2 val, uint32_t flag) {
data1 = val.x;
flag1 = flag;
data2 = val.y;
flag2 = flag;
}

/// Write 8 bytes of data to the packet.
/// @param val1 The first 4-byte data to write.
/// @param val2 The second 4-byte data to write.
Expand Down Expand Up @@ -95,10 +103,17 @@ union alignas(8) LL8Packet {
uint32_t flag;
};
uint64_t raw_;

using Payload = uint32_t;
#if defined(MSCCLPP_DEVICE_COMPILE)

MSCCLPP_DEVICE_INLINE LL8Packet() {}

MSCCLPP_DEVICE_INLINE LL8Packet(uint32_t val, uint32_t flag) {
this->data = val;
this->flag = flag;
}

MSCCLPP_DEVICE_INLINE void write(uint32_t val, uint32_t flag) {
#if defined(MSCCLPP_DEVICE_CUDA)
asm volatile("st.volatile.global.v2.u32 [%0], {%1,%2};" ::"l"(&raw_), "r"(val), "r"(flag));
Expand Down
4 changes: 4 additions & 0 deletions python/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
TcpBootstrap,
Transport,
TransportFlags,
DataType,
Executor,
ExecutionPlan,
PacketType,
version,
is_nvls_supported,
)
Expand Down
1 change: 1 addition & 0 deletions python/mscclpp/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
self.communicator = Communicator(self.bootstrap)
self.my_rank = self.bootstrap.get_rank()
self.nranks = self.bootstrap.get_n_ranks()
self.nranks_per_node = self.bootstrap.get_n_ranks_per_node()

def barrier(self):
self.bootstrap.barrier()
Expand Down
3 changes: 3 additions & 0 deletions python/mscclpp/core_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ extern void register_fifo(nb::module_& m);
extern void register_semaphore(nb::module_& m);
extern void register_utils(nb::module_& m);
extern void register_numa(nb::module_& m);
extern void register_executor(nb::module_& m);

template <typename T>
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
Expand All @@ -35,6 +36,7 @@ void register_core(nb::module_& m) {
nb::class_<Bootstrap>(m, "Bootstrap")
.def("get_rank", &Bootstrap::getRank)
.def("get_n_ranks", &Bootstrap::getNranks)
.def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
.def(
"send",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
Expand Down Expand Up @@ -204,4 +206,5 @@ NB_MODULE(_mscclpp, m) {
register_utils(m);
register_core(m);
register_numa(m);
register_executor(m);
}
3 changes: 2 additions & 1 deletion python/mscclpp/error_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ void register_error(nb::module_& m) {
.value("RemoteError", ErrorCode::RemoteError)
.value("InvalidUsage", ErrorCode::InvalidUsage)
.value("Timeout", ErrorCode::Timeout)
.value("Aborted", ErrorCode::Aborted);
.value("Aborted", ErrorCode::Aborted)
.value("ExecutorError", ErrorCode::ExecutorError);

nb::class_<BaseError>(m, "BaseError")
.def(nb::init<std::string&, int>(), nb::arg("message"), nb::arg("errorCode"))
Expand Down
38 changes: 38 additions & 0 deletions python/mscclpp/executor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <nanobind/nanobind.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>

#include <mscclpp/executor.hpp>
#include <mscclpp/gpu.hpp>

namespace nb = nanobind;
using namespace mscclpp;

void register_executor(nb::module_& m) {
nb::enum_<DataType>(m, "DataType")
.value("int32", DataType::INT32)
.value("uint32", DataType::UINT32)
.value("float16", DataType::FLOAT16)
.value("float32", DataType::FLOAT32);

nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);

nb::class_<ExecutionPlan>(m, "ExecutionPlan")
.def(nb::init<const std::string, const std::string>(), nb::arg("name"), nb::arg("planPath"));

nb::class_<Executor>(m, "Executor")
.def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
.def(
"execute",
[](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
DataType dataType, int nthreads, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
recvBuffSize, dataType, nthreads, plan, (cudaStream_t)stream, packetType);
},
nb::arg("rank"), nb::arg("sendbuff"), nb::arg("recvBuff"), nb::arg("sendBuffSize"), nb::arg("recvBuffSize"),
nb::arg("dataType"), nb::arg("nthreads"), nb::arg("plan"), nb::arg("stream"),
nb::arg("packetType") = PacketType::LL16);
}
75 changes: 75 additions & 0 deletions python/test/executor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from os import path
from mscclpp import (
DataType,
Executor,
ExecutionPlan,
)
import mscclpp.comm as mscclpp_comm

import cupy as cp
from mpi4py import MPI

MSCCLPP_ROOT_PATH = "/root/mscclpp"


def bench_time(niters: int, ngraphIters: int, func):
# capture cuda graph for niters of the kernel launch
stream = cp.cuda.Stream(non_blocking=True)
with stream:
stream.begin_capture()
for i in range(niters):
func(stream)
graph = stream.end_capture()

# now run a warm up round
graph.launch(stream)

# now run the benchmark and measure time
start = cp.cuda.Event()
end = cp.cuda.Event()

start.record(stream)
for _ in range(ngraphIters):
graph.launch(stream)
end.record(stream)
end.synchronize()

return cp.cuda.get_elapsed_time(start, end) / niters * 1000.0 / ngraphIters


if __name__ == "__main__":
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
cp.cuda.Device(MPI.COMM_WORLD.rank % mscclpp_group.nranks_per_node).use()
executor = Executor(mscclpp_group.communicator)
execution_plan = ExecutionPlan(
"allreduce_pairs", path.join(MSCCLPP_ROOT_PATH, "test", "execution-files", "allreduce.json")
)

nelems = 1024 * 1024
cp.random.seed(42)
buffer = cp.random.random(nelems).astype(cp.float16)
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
sendbuf = sub_arrays[MPI.COMM_WORLD.rank]
mscclpp_group.barrier()

execution_time = bench_time(
100,
10,
lambda stream: executor.execute(
MPI.COMM_WORLD.rank,
sendbuf.data.ptr,
sendbuf.data.ptr,
sendbuf.nbytes,
sendbuf.nbytes,
DataType.float16,
512,
execution_plan,
stream.ptr,
),
)
print(f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, data size: {sendbuf.nbytes} bytes")
executor = None
mscclpp_group = None
39 changes: 39 additions & 0 deletions python/test/test_mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
import pytest

from mscclpp import (
DataType,
EndpointConfig,
ExecutionPlan,
Executor,
Fifo,
Host2DeviceSemaphore,
Host2HostSemaphore,
Expand Down Expand Up @@ -590,3 +593,39 @@ def test_nvls(mpi_group: MpiGroup):
kernel()
cp.cuda.runtime.deviceSynchronize()
group.barrier()


@parametrize_mpi_groups(2)
@pytest.mark.parametrize("filename", ["allreduce.json", "allreduce_packet.json"])
def test_executor(mpi_group: MpiGroup, filename: str):
if all_ranks_on_the_same_node(mpi_group) is False:
pytest.skip("algo not support cross node")
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mscclpp_group = mscclpp_comm.CommGroup(mpi_group.comm)
executor = Executor(mscclpp_group.communicator)
execution_plan = ExecutionPlan("allreduce_pairs", os.path.join(project_dir, "test", "execution-files", filename))

nelems = 1024 * 1024
cp.random.seed(42)
buffer = cp.random.random(nelems).astype(cp.float16)
sub_arrays = cp.split(buffer, mpi_group.comm.size)
sendbuf = sub_arrays[mpi_group.comm.rank]
expected = cp.zeros_like(sendbuf)
for i in range(mpi_group.comm.size):
expected += sub_arrays[i]
mscclpp_group.barrier()

stream = cp.cuda.Stream(non_blocking=True)
executor.execute(
mpi_group.comm.rank,
sendbuf.data.ptr,
sendbuf.data.ptr,
sendbuf.nbytes,
sendbuf.nbytes,
DataType.float16,
512,
execution_plan,
stream.ptr,
)
stream.synchronize()
assert cp.allclose(sendbuf, expected, atol=1e-3 * mpi_group.comm.size)
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc)
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc *.cu)
target_sources(mscclpp_obj PRIVATE ${SOURCES})
target_include_directories(mscclpp_obj PRIVATE include)
Loading
Loading