From 5ba6ce00c71110a9245ef53efade76b4444626a9 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 26 Mar 2024 19:24:24 -0700 Subject: [PATCH] Fix bootstrapping mechanism (#278) Co-authored-by: Binyang Li Co-authored-by: Pashupati Kumar <74680231+pash-msft@users.noreply.github.com> --- CMakeLists.txt | 2 +- include/mscclpp/core.hpp | 8 +-- python/mscclpp/CMakeLists.txt | 4 +- python/mscclpp/core_py.cpp | 2 +- python/test/CMakeLists.txt | 4 +- src/bootstrap/bootstrap.cc | 86 +++++++++++++++++++++------------ src/bootstrap/socket.cc | 6 ++- src/fifo.cc | 1 + src/include/socket.h | 12 +++-- test/CMakeLists.txt | 2 +- test/mp_unit/bootstrap_tests.cc | 2 +- test/unit/socket_tests.cc | 2 +- 12 files changed, 81 insertions(+), 50 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 302febab7..4715ac0cc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -101,7 +101,7 @@ find_package(Threads REQUIRED) add_library(mscclpp_obj OBJECT) target_include_directories(mscclpp_obj - PRIVATE + SYSTEM PRIVATE ${GPU_INCLUDE_DIRS} ${IBVERBS_INCLUDE_DIRS} ${NUMA_INCLUDE_DIRS}) diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 02c277a3e..c2a4dff44 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -51,6 +51,10 @@ class Bootstrap { /// A native implementation of the bootstrap using TCP sockets. class TcpBootstrap : public Bootstrap { public: + /// Create a random unique ID. + /// @return The created unique ID. + static UniqueId createUniqueId(); + /// Constructor. /// @param rank The rank of the process. /// @param nRanks The total number of ranks. @@ -59,10 +63,6 @@ class TcpBootstrap : public Bootstrap { /// Destructor. ~TcpBootstrap(); - /// Create a random unique ID and store it in the @ref TcpBootstrap. - /// @return The created unique ID. - UniqueId createUniqueId(); - /// Return the unique ID stored in the @ref TcpBootstrap. /// @return The unique ID stored in the @ref TcpBootstrap. UniqueId getUniqueId() const; diff --git a/python/mscclpp/CMakeLists.txt b/python/mscclpp/CMakeLists.txt index 0fe510c80..bb9eadf32 100644 --- a/python/mscclpp/CMakeLists.txt +++ b/python/mscclpp/CMakeLists.txt @@ -9,6 +9,6 @@ FetchContent_MakeAvailable(nanobind) file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp) nanobind_add_module(mscclpp_py ${SOURCES}) set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp) -target_link_libraries(mscclpp_py PRIVATE ${GPU_LIBRARIES} mscclpp_static) -target_include_directories(mscclpp_py PRIVATE ${GPU_INCLUDE_DIRS}) +target_link_libraries(mscclpp_py PRIVATE mscclpp_static ${GPU_LIBRARIES}) +target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS}) install(TARGETS mscclpp_py LIBRARY DESTINATION .) diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp index 5fd4bd317..1a1cd2780 100644 --- a/python/mscclpp/core_py.cpp +++ b/python/mscclpp/core_py.cpp @@ -63,7 +63,7 @@ void register_core(nb::module_& m) { .def_static( "create", [](int rank, int nRanks) { return std::make_shared(rank, nRanks); }, nb::arg("rank"), nb::arg("nRanks")) - .def("create_unique_id", &TcpBootstrap::createUniqueId) + .def_static("create_unique_id", &TcpBootstrap::createUniqueId) .def("get_unique_id", &TcpBootstrap::getUniqueId) .def("initialize", static_cast(&TcpBootstrap::initialize), nb::call_guard(), nb::arg("uniqueId"), nb::arg("timeoutSec") = 30) diff --git a/python/test/CMakeLists.txt b/python/test/CMakeLists.txt index cf705841c..be62aea99 100644 --- a/python/test/CMakeLists.txt +++ b/python/test/CMakeLists.txt @@ -9,5 +9,5 @@ FetchContent_MakeAvailable(nanobind) file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp) nanobind_add_module(mscclpp_py_test ${SOURCES}) set_target_properties(mscclpp_py_test PROPERTIES OUTPUT_NAME _ext) -target_link_libraries(mscclpp_py_test PRIVATE ${GPU_LIBRARIES} mscclpp_static) -target_include_directories(mscclpp_py_test PRIVATE ${GPU_INCLUDE_DIRS}) +target_link_libraries(mscclpp_py_test PRIVATE mscclpp_static ${GPU_LIBRARIES}) +target_include_directories(mscclpp_py_test SYSTEM PRIVATE ${GPU_INCLUDE_DIRS}) diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index 00a58b992..c9cea10f4 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -70,12 +70,14 @@ static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is class TcpBootstrap::Impl { public: + static UniqueId createUniqueId(); + static UniqueId getUniqueId(const UniqueIdInternal& uniqueId); + Impl(int rank, int nRanks); ~Impl(); void initialize(const UniqueId& uniqueId, int64_t timeoutSec); void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec); void establishConnections(int64_t timeoutSec); - UniqueId createUniqueId(); UniqueId getUniqueId() const; int getRank(); int getNranks(); @@ -99,7 +101,6 @@ class TcpBootstrap::Impl { std::unique_ptr abortFlagStorage_; volatile uint32_t* abortFlag_; std::thread rootThread_; - char netIfName_[MAX_IF_NAME_SIZE + 1]; SocketAddress netIfAddr_; std::unordered_map, std::shared_ptr, PairHash> peerSendSockets_; std::unordered_map, std::shared_ptr, PairHash> peerRecvSockets_; @@ -110,15 +111,33 @@ class TcpBootstrap::Impl { std::shared_ptr getPeerSendSocket(int peer, int tag); std::shared_ptr getPeerRecvSocket(int peer, int tag); + static void assignPortToUniqueId(UniqueIdInternal& uniqueId); + static void netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr); + void bootstrapCreateRoot(); void bootstrapRoot(); void getRemoteAddresses(Socket* listenSock, std::vector& rankAddresses, std::vector& rankAddressesRoot, int& rank); void sendHandleToPeer(int peer, const std::vector& rankAddresses, const std::vector& rankAddressesRoot); - void netInit(std::string ipPortPair, std::string interface); }; +UniqueId TcpBootstrap::Impl::createUniqueId() { + UniqueIdInternal uniqueId; + SocketAddress netIfAddr; + netInit("", "", netIfAddr); + getRandomData(&uniqueId.magic, sizeof(uniqueId_.magic)); + std::memcpy(&uniqueId.addr, &netIfAddr, sizeof(SocketAddress)); + assignPortToUniqueId(uniqueId); + return getUniqueId(uniqueId); +} + +UniqueId TcpBootstrap::Impl::getUniqueId(const UniqueIdInternal& uniqueId) { + UniqueId ret; + std::memcpy(&ret, &uniqueId, sizeof(uniqueId)); + return ret; +} + TcpBootstrap::Impl::Impl(int rank, int nRanks) : rank_(rank), nRanks_(nRanks), @@ -128,29 +147,26 @@ TcpBootstrap::Impl::Impl(int rank, int nRanks) abortFlagStorage_(new uint32_t(0)), abortFlag_(abortFlagStorage_.get()) {} -UniqueId TcpBootstrap::Impl::getUniqueId() const { - UniqueId ret; - std::memcpy(&ret, &uniqueId_, sizeof(uniqueId_)); - return ret; -} - -UniqueId TcpBootstrap::Impl::createUniqueId() { - netInit("", ""); - getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic)); - std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress)); - bootstrapCreateRoot(); - return getUniqueId(); -} +UniqueId TcpBootstrap::Impl::getUniqueId() const { return getUniqueId(uniqueId_); } int TcpBootstrap::Impl::getRank() { return rank_; } int TcpBootstrap::Impl::getNranks() { return nRanks_; } void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec) { - netInit("", ""); + if (!netInitialized) { + netInit("", "", netIfAddr_); + netInitialized = true; + } std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_)); + if (rank_ == 0) { + bootstrapCreateRoot(); + } + char line[MAX_IF_NAME_SIZE + 1]; + SocketToString(&uniqueId_.addr, line); + INFO(MSCCLPP_INIT, "rank %d nranks %d - connecting to %s", rank_, nRanks_, line); establishConnections(timeoutSec); } @@ -170,7 +186,10 @@ void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t tim ipPortPair = ifIpPortTrio.substr(ipPortPair.find_first_of(':') + 1); } - netInit(ipPortPair, interface); + if (!netInitialized) { + netInit(ipPortPair, interface, netIfAddr_); + netInitialized = true; + } uniqueId_.magic = 0xdeadbeef; std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress)); @@ -230,9 +249,15 @@ void TcpBootstrap::Impl::sendHandleToPeer(int peer, const std::vector socket = std::make_unique(&uniqueId.addr, uniqueId.magic, SocketTypeBootstrap); + socket->bind(); + uniqueId.addr = socket->getAddr(); +} + void TcpBootstrap::Impl::bootstrapCreateRoot() { listenSockRoot_ = std::make_unique(&uniqueId_.addr, uniqueId_.magic, SocketTypeBootstrap, abortFlag_, 0); - listenSockRoot_->listen(); + listenSockRoot_->bindAndListen(); uniqueId_.addr = listenSockRoot_->getAddr(); rootThread_ = std::thread([this]() { @@ -279,34 +304,33 @@ void TcpBootstrap::Impl::bootstrapRoot() { TRACE(MSCCLPP_INIT, "DONE"); } -void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface) { - if (netInitialized) return; +void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr) { + char netIfName[MAX_IF_NAME_SIZE + 1]; if (!ipPortPair.empty()) { if (interface != "") { // we know the - int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1, interface.c_str()); + int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1, interface.c_str()); if (ret <= 0) throw Error("NET/Socket : No interface named " + interface + " found.", ErrorCode::InternalError); } else { // we do not know the try to match it next SocketAddress remoteAddr; SocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()); - if (FindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) { + if (FindInterfaceMatchSubnet(netIfName, &netIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) { throw Error("NET/Socket : No usable listening interface found", ErrorCode::InternalError); } } } else { - int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1); + int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1); if (ret <= 0) { throw Error("TcpBootstrap : no socket interface found", ErrorCode::InternalError); } } char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2]; - std::sprintf(line, " %s:", netIfName_); - SocketToString(&netIfAddr_, line + strlen(line)); + std::sprintf(line, " %s:", netIfName); + SocketToString(&netIfAddr, line + strlen(line)); INFO(MSCCLPP_INIT, "TcpBootstrap : Using%s", line); - netInitialized = true; } #define TIMEOUT(__exp) \ @@ -345,13 +369,13 @@ void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) { uint64_t magic = uniqueId_.magic; // Create socket for other ranks to contact me listenSock_ = std::make_unique(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_); - listenSock_->listen(); + listenSock_->bindAndListen(); info.extAddressListen = listenSock_->getAddr(); { // Create socket for root to contact me Socket lsock(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_); - lsock.listen(); + lsock.bindAndListen(); info.extAddressListenRoot = lsock.getAddr(); // stagger connection times to avoid an overload of the root @@ -486,9 +510,9 @@ void TcpBootstrap::Impl::close() { peerRecvSockets_.clear(); } -MSCCLPP_API_CPP TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique(rank, nRanks); } +MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return Impl::createUniqueId(); } -MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return pimpl_->createUniqueId(); } +MSCCLPP_API_CPP TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique(rank, nRanks); } MSCCLPP_API_CPP UniqueId TcpBootstrap::getUniqueId() const { return pimpl_->getUniqueId(); } diff --git a/src/bootstrap/socket.cc b/src/bootstrap/socket.cc index 2267af9b3..a79821f1b 100644 --- a/src/bootstrap/socket.cc +++ b/src/bootstrap/socket.cc @@ -390,7 +390,7 @@ Socket::Socket(const SocketAddress* addr, uint64_t magic, enum SocketType type, Socket::~Socket() { close(); } -void Socket::listen() { +void Socket::bind() { if (fd_ == -1) { throw Error("file descriptor is -1", ErrorCode::InvalidUsage); } @@ -433,7 +433,11 @@ void Socket::listen() { if (::getsockname(fd_, &addr_.sa, &size) != 0) { throw SysError("getsockname failed", errno); } + state_ = SocketStateBound; +} +void Socket::bindAndListen() { + bind(); #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN + 1]; TRACE(MSCCLPP_INIT | MSCCLPP_NET, "Listening on socket %s", SocketToString(&addr_, line)); diff --git a/src/fifo.cc b/src/fifo.cc index 4255bcdcd..592bf7d00 100644 --- a/src/fifo.cc +++ b/src/fifo.cc @@ -56,6 +56,7 @@ MSCCLPP_API_CPP void Fifo::pop() { MSCCLPP_API_CPP void Fifo::flushTail(bool sync) { // Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure that the fifo can // make progress even if there is no request mscclppSync. However, mscclppSync type is for flush request. + AvoidCudaGraphCaptureGuard cgcGuard; MSCCLPP_CUDATHROW(cudaMemcpyAsync(pimpl->tailReplica.get(), &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice, pimpl->stream)); if (sync) { diff --git a/src/include/socket.h b/src/include/socket.h index 9f043414e..ed125c990 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -35,10 +35,11 @@ enum SocketState { SocketStateConnecting = 4, SocketStateConnectPolling = 5, SocketStateConnected = 6, - SocketStateReady = 7, - SocketStateClosed = 8, - SocketStateError = 9, - SocketStateNum = 10 + SocketStateBound = 7, + SocketStateReady = 8, + SocketStateClosed = 9, + SocketStateError = 10, + SocketStateNum = 11 }; enum SocketType { @@ -62,7 +63,8 @@ class Socket { enum SocketType type = SocketTypeUnknown, volatile uint32_t* abortFlag = nullptr, int asyncFlag = 0); ~Socket(); - void listen(); + void bind(); + void bindAndListen(); void connect(int64_t timeout = -1); void accept(const Socket* listenSocket, int64_t timeout = -1); void send(void* ptr, int size); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0268af1c6..da47066ea 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -5,7 +5,7 @@ find_package(MPI) set(TEST_LIBS_COMMON mscclpp ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads) set(TEST_LIBS_GTEST GTest::gtest_main GTest::gmock_main) -set(TEST_INC_COMMON PRIVATE ${PROJECT_SOURCE_DIR}/include ${GPU_INCLUDE_DIRS}) +set(TEST_INC_COMMON PRIVATE ${PROJECT_SOURCE_DIR}/include SYSTEM PRIVATE ${GPU_INCLUDE_DIRS}) set(TEST_INC_INTERNAL PRIVATE ${PROJECT_SOURCE_DIR}/src/include) if(USE_ROCM) diff --git a/test/mp_unit/bootstrap_tests.cc b/test/mp_unit/bootstrap_tests.cc index 82120a1f7..69e566dbd 100644 --- a/test/mp_unit/bootstrap_tests.cc +++ b/test/mp_unit/bootstrap_tests.cc @@ -67,7 +67,7 @@ TEST_F(BootstrapTest, ResumeWithId) { // This test may take a few minutes. bootstrapTestTimer.set(300); - for (int i = 0; i < 3000; ++i) { + for (int i = 0; i < 10; ++i) { auto bootstrap = std::make_shared(gEnv->rank, gEnv->worldSize); mscclpp::UniqueId id; if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId(); diff --git a/test/unit/socket_tests.cc b/test/unit/socket_tests.cc index fe0a063e5..4fa8d3915 100644 --- a/test/unit/socket_tests.cc +++ b/test/unit/socket_tests.cc @@ -17,7 +17,7 @@ TEST(Socket, ListenAndConnect) { ASSERT_NO_THROW(mscclpp::SocketGetAddrFromString(&listenAddr, ipPortPair.c_str())); mscclpp::Socket listenSock(&listenAddr); - listenSock.listen(); + listenSock.bindAndListen(); std::thread clientThread([&listenAddr]() { mscclpp::Socket sock(&listenAddr);