Skip to content

Commit

Permalink
Fix bootstrapping mechanism (#278)
Browse files Browse the repository at this point in the history
Co-authored-by: Binyang Li <[email protected]>
Co-authored-by: Pashupati Kumar <[email protected]>
  • Loading branch information
3 people authored Mar 27, 2024
1 parent bc465ae commit 5ba6ce0
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 50 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ find_package(Threads REQUIRED)

add_library(mscclpp_obj OBJECT)
target_include_directories(mscclpp_obj
PRIVATE
SYSTEM PRIVATE
${GPU_INCLUDE_DIRS}
${IBVERBS_INCLUDE_DIRS}
${NUMA_INCLUDE_DIRS})
Expand Down
8 changes: 4 additions & 4 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class Bootstrap {
/// A native implementation of the bootstrap using TCP sockets.
class TcpBootstrap : public Bootstrap {
public:
/// Create a random unique ID.
/// @return The created unique ID.
static UniqueId createUniqueId();

/// Constructor.
/// @param rank The rank of the process.
/// @param nRanks The total number of ranks.
Expand All @@ -59,10 +63,6 @@ class TcpBootstrap : public Bootstrap {
/// Destructor.
~TcpBootstrap();

/// Create a random unique ID and store it in the @ref TcpBootstrap.
/// @return The created unique ID.
UniqueId createUniqueId();

/// Return the unique ID stored in the @ref TcpBootstrap.
/// @return The unique ID stored in the @ref TcpBootstrap.
UniqueId getUniqueId() const;
Expand Down
4 changes: 2 additions & 2 deletions python/mscclpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ FetchContent_MakeAvailable(nanobind)
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
nanobind_add_module(mscclpp_py ${SOURCES})
set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
target_link_libraries(mscclpp_py PRIVATE ${GPU_LIBRARIES} mscclpp_static)
target_include_directories(mscclpp_py PRIVATE ${GPU_INCLUDE_DIRS})
target_link_libraries(mscclpp_py PRIVATE mscclpp_static ${GPU_LIBRARIES})
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
install(TARGETS mscclpp_py LIBRARY DESTINATION .)
2 changes: 1 addition & 1 deletion python/mscclpp/core_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void register_core(nb::module_& m) {
.def_static(
"create", [](int rank, int nRanks) { return std::make_shared<TcpBootstrap>(rank, nRanks); }, nb::arg("rank"),
nb::arg("nRanks"))
.def("create_unique_id", &TcpBootstrap::createUniqueId)
.def_static("create_unique_id", &TcpBootstrap::createUniqueId)
.def("get_unique_id", &TcpBootstrap::getUniqueId)
.def("initialize", static_cast<void (TcpBootstrap::*)(UniqueId, int64_t)>(&TcpBootstrap::initialize),
nb::call_guard<nb::gil_scoped_release>(), nb::arg("uniqueId"), nb::arg("timeoutSec") = 30)
Expand Down
4 changes: 2 additions & 2 deletions python/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ FetchContent_MakeAvailable(nanobind)
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
nanobind_add_module(mscclpp_py_test ${SOURCES})
set_target_properties(mscclpp_py_test PROPERTIES OUTPUT_NAME _ext)
target_link_libraries(mscclpp_py_test PRIVATE ${GPU_LIBRARIES} mscclpp_static)
target_include_directories(mscclpp_py_test PRIVATE ${GPU_INCLUDE_DIRS})
target_link_libraries(mscclpp_py_test PRIVATE mscclpp_static ${GPU_LIBRARIES})
target_include_directories(mscclpp_py_test SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
86 changes: 55 additions & 31 deletions src/bootstrap/bootstrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@ static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is

class TcpBootstrap::Impl {
public:
static UniqueId createUniqueId();
static UniqueId getUniqueId(const UniqueIdInternal& uniqueId);

Impl(int rank, int nRanks);
~Impl();
void initialize(const UniqueId& uniqueId, int64_t timeoutSec);
void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec);
void establishConnections(int64_t timeoutSec);
UniqueId createUniqueId();
UniqueId getUniqueId() const;
int getRank();
int getNranks();
Expand All @@ -99,7 +101,6 @@ class TcpBootstrap::Impl {
std::unique_ptr<uint32_t> abortFlagStorage_;
volatile uint32_t* abortFlag_;
std::thread rootThread_;
char netIfName_[MAX_IF_NAME_SIZE + 1];
SocketAddress netIfAddr_;
std::unordered_map<std::pair<int, int>, std::shared_ptr<Socket>, PairHash> peerSendSockets_;
std::unordered_map<std::pair<int, int>, std::shared_ptr<Socket>, PairHash> peerRecvSockets_;
Expand All @@ -110,15 +111,33 @@ class TcpBootstrap::Impl {
std::shared_ptr<Socket> getPeerSendSocket(int peer, int tag);
std::shared_ptr<Socket> getPeerRecvSocket(int peer, int tag);

static void assignPortToUniqueId(UniqueIdInternal& uniqueId);
static void netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr);

void bootstrapCreateRoot();
void bootstrapRoot();
void getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
std::vector<SocketAddress>& rankAddressesRoot, int& rank);
void sendHandleToPeer(int peer, const std::vector<SocketAddress>& rankAddresses,
const std::vector<SocketAddress>& rankAddressesRoot);
void netInit(std::string ipPortPair, std::string interface);
};

/// Create a fresh unique ID without needing a TcpBootstrap instance.
///
/// Selects a local network interface through netInit(), fills the magic number with random data,
/// copies the interface address into the ID, and then lets assignPortToUniqueId() overwrite that
/// address with the one reported by a socket bound to it.
///
/// @return The newly created unique ID in its public form.
UniqueId TcpBootstrap::Impl::createUniqueId() {
  UniqueIdInternal uniqueId;
  SocketAddress netIfAddr;
  netInit("", "", netIfAddr);
  // Measure the local object that is actually being filled. This function is static, so it should
  // not name the per-instance member `uniqueId_`, even inside an unevaluated sizeof.
  getRandomData(&uniqueId.magic, sizeof(uniqueId.magic));
  std::memcpy(&uniqueId.addr, &netIfAddr, sizeof(SocketAddress));
  assignPortToUniqueId(uniqueId);
  return getUniqueId(uniqueId);
}

/// Convert an internal unique ID into the public UniqueId representation.
///
/// @param uniqueId The internal ID (magic number plus socket address) to export.
/// @return A UniqueId whose leading bytes are a byte-wise copy of `uniqueId`.
UniqueId TcpBootstrap::Impl::getUniqueId(const UniqueIdInternal& uniqueId) {
  // Value-initialize the result: UniqueIdInternal is only guaranteed to be no larger than UniqueId
  // (see the static_assert on the two sizes), so any trailing bytes not covered by the memcpy would
  // otherwise hold indeterminate data in an ID that is handed out to callers.
  UniqueId ret{};
  std::memcpy(&ret, &uniqueId, sizeof(uniqueId));
  return ret;
}

TcpBootstrap::Impl::Impl(int rank, int nRanks)
: rank_(rank),
nRanks_(nRanks),
Expand All @@ -128,29 +147,26 @@ TcpBootstrap::Impl::Impl(int rank, int nRanks)
abortFlagStorage_(new uint32_t(0)),
abortFlag_(abortFlagStorage_.get()) {}

UniqueId TcpBootstrap::Impl::getUniqueId() const {
UniqueId ret;
std::memcpy(&ret, &uniqueId_, sizeof(uniqueId_));
return ret;
}

UniqueId TcpBootstrap::Impl::createUniqueId() {
netInit("", "");
getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic));
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress));
bootstrapCreateRoot();
return getUniqueId();
}
/// Export the unique ID held in `uniqueId_` through the static conversion overload.
/// @return The public form of the stored unique ID.
UniqueId TcpBootstrap::Impl::getUniqueId() const {
  return getUniqueId(uniqueId_);
}

/// @return The rank value stored in `rank_` at construction.
int TcpBootstrap::Impl::getRank() {
  return rank_;
}

/// @return The total number of ranks stored in `nRanks_` at construction.
int TcpBootstrap::Impl::getNranks() {
  return nRanks_;
}

void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec) {
netInit("", "");
if (!netInitialized) {
netInit("", "", netIfAddr_);
netInitialized = true;
}

std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_));
if (rank_ == 0) {
bootstrapCreateRoot();
}

char line[MAX_IF_NAME_SIZE + 1];
SocketToString(&uniqueId_.addr, line);
INFO(MSCCLPP_INIT, "rank %d nranks %d - connecting to %s", rank_, nRanks_, line);
establishConnections(timeoutSec);
}

Expand All @@ -170,7 +186,10 @@ void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t tim
ipPortPair = ifIpPortTrio.substr(ipPortPair.find_first_of(':') + 1);
}

netInit(ipPortPair, interface);
if (!netInitialized) {
netInit(ipPortPair, interface, netIfAddr_);
netInitialized = true;
}

uniqueId_.magic = 0xdeadbeef;
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress));
Expand Down Expand Up @@ -230,9 +249,15 @@ void TcpBootstrap::Impl::sendHandleToPeer(int peer, const std::vector<SocketAddr
netSend(&sock, &rankAddresses[next], sizeof(SocketAddress));
}

/// Bind a short-lived bootstrap socket to the address carried by `uniqueId` and write the address
/// the socket reports after binding back into the ID.
///
/// @param uniqueId The internal unique ID whose `addr` field is read and then overwritten.
void TcpBootstrap::Impl::assignPortToUniqueId(UniqueIdInternal& uniqueId) {
  // The socket lives only for the duration of this call; its destructor closes it on return.
  Socket portProbe(&uniqueId.addr, uniqueId.magic, SocketTypeBootstrap);
  portProbe.bind();
  uniqueId.addr = portProbe.getAddr();
}

void TcpBootstrap::Impl::bootstrapCreateRoot() {
listenSockRoot_ = std::make_unique<Socket>(&uniqueId_.addr, uniqueId_.magic, SocketTypeBootstrap, abortFlag_, 0);
listenSockRoot_->listen();
listenSockRoot_->bindAndListen();
uniqueId_.addr = listenSockRoot_->getAddr();

rootThread_ = std::thread([this]() {
Expand Down Expand Up @@ -279,34 +304,33 @@ void TcpBootstrap::Impl::bootstrapRoot() {
TRACE(MSCCLPP_INIT, "DONE");
}

void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface) {
if (netInitialized) return;
void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr) {
char netIfName[MAX_IF_NAME_SIZE + 1];
if (!ipPortPair.empty()) {
if (interface != "") {
// we know the <interface>
int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1, interface.c_str());
int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1, interface.c_str());
if (ret <= 0) throw Error("NET/Socket : No interface named " + interface + " found.", ErrorCode::InternalError);
} else {
// we do not know the <interface> try to match it next
SocketAddress remoteAddr;
SocketGetAddrFromString(&remoteAddr, ipPortPair.c_str());
if (FindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
if (FindInterfaceMatchSubnet(netIfName, &netIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
throw Error("NET/Socket : No usable listening interface found", ErrorCode::InternalError);
}
}

} else {
int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1);
int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1);
if (ret <= 0) {
throw Error("TcpBootstrap : no socket interface found", ErrorCode::InternalError);
}
}

char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
std::sprintf(line, " %s:", netIfName_);
SocketToString(&netIfAddr_, line + strlen(line));
std::sprintf(line, " %s:", netIfName);
SocketToString(&netIfAddr, line + strlen(line));
INFO(MSCCLPP_INIT, "TcpBootstrap : Using%s", line);
netInitialized = true;
}

#define TIMEOUT(__exp) \
Expand Down Expand Up @@ -345,13 +369,13 @@ void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) {
uint64_t magic = uniqueId_.magic;
// Create socket for other ranks to contact me
listenSock_ = std::make_unique<Socket>(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_);
listenSock_->listen();
listenSock_->bindAndListen();
info.extAddressListen = listenSock_->getAddr();

{
// Create socket for root to contact me
Socket lsock(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_);
lsock.listen();
lsock.bindAndListen();
info.extAddressListenRoot = lsock.getAddr();

// stagger connection times to avoid an overload of the root
Expand Down Expand Up @@ -486,9 +510,9 @@ void TcpBootstrap::Impl::close() {
peerRecvSockets_.clear();
}

MSCCLPP_API_CPP TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique<Impl>(rank, nRanks); }
/// Public entry point for creating a unique ID; forwards to the static Impl::createUniqueId().
/// @return The created unique ID.
MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() {
  return Impl::createUniqueId();
}

MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return pimpl_->createUniqueId(); }
/// Construct the bootstrap by creating its implementation object.
/// @param rank The rank of the process.
/// @param nRanks The total number of ranks.
MSCCLPP_API_CPP TcpBootstrap::TcpBootstrap(int rank, int nRanks) {
  pimpl_ = std::make_unique<Impl>(rank, nRanks);
}

/// Return the unique ID stored in this TcpBootstrap; forwards to the implementation object.
/// @return The stored unique ID.
MSCCLPP_API_CPP UniqueId TcpBootstrap::getUniqueId() const {
  return pimpl_->getUniqueId();
}

Expand Down
6 changes: 5 additions & 1 deletion src/bootstrap/socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ Socket::Socket(const SocketAddress* addr, uint64_t magic, enum SocketType type,

Socket::~Socket() { close(); }

void Socket::listen() {
void Socket::bind() {
if (fd_ == -1) {
throw Error("file descriptor is -1", ErrorCode::InvalidUsage);
}
Expand Down Expand Up @@ -433,7 +433,11 @@ void Socket::listen() {
if (::getsockname(fd_, &addr_.sa, &size) != 0) {
throw SysError("getsockname failed", errno);
}
state_ = SocketStateBound;
}

void Socket::bindAndListen() {
bind();
#ifdef ENABLE_TRACE
char line[SOCKET_NAME_MAXLEN + 1];
TRACE(MSCCLPP_INIT | MSCCLPP_NET, "Listening on socket %s", SocketToString(&addr_, line));
Expand Down
1 change: 1 addition & 0 deletions src/fifo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ MSCCLPP_API_CPP void Fifo::pop() {
MSCCLPP_API_CPP void Fifo::flushTail(bool sync) {
// Flush the tail to device memory. This is triggered either periodically (every ProxyFlushPeriod), so that the
// fifo can make progress even if there is no mscclppSync request, or by a request of type mscclppSync (a flush request).
AvoidCudaGraphCaptureGuard cgcGuard;
MSCCLPP_CUDATHROW(cudaMemcpyAsync(pimpl->tailReplica.get(), &pimpl->hostTail, sizeof(uint64_t),
cudaMemcpyHostToDevice, pimpl->stream));
if (sync) {
Expand Down
12 changes: 7 additions & 5 deletions src/include/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ enum SocketState {
SocketStateConnecting = 4,
SocketStateConnectPolling = 5,
SocketStateConnected = 6,
SocketStateReady = 7,
SocketStateClosed = 8,
SocketStateError = 9,
SocketStateNum = 10
SocketStateBound = 7,
SocketStateReady = 8,
SocketStateClosed = 9,
SocketStateError = 10,
SocketStateNum = 11
};

enum SocketType {
Expand All @@ -62,7 +63,8 @@ class Socket {
enum SocketType type = SocketTypeUnknown, volatile uint32_t* abortFlag = nullptr, int asyncFlag = 0);
~Socket();

void listen();
void bind();
void bindAndListen();
void connect(int64_t timeout = -1);
void accept(const Socket* listenSocket, int64_t timeout = -1);
void send(void* ptr, int size);
Expand Down
2 changes: 1 addition & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ find_package(MPI)

set(TEST_LIBS_COMMON mscclpp ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads)
set(TEST_LIBS_GTEST GTest::gtest_main GTest::gmock_main)
set(TEST_INC_COMMON PRIVATE ${PROJECT_SOURCE_DIR}/include ${GPU_INCLUDE_DIRS})
set(TEST_INC_COMMON PRIVATE ${PROJECT_SOURCE_DIR}/include SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
set(TEST_INC_INTERNAL PRIVATE ${PROJECT_SOURCE_DIR}/src/include)

if(USE_ROCM)
Expand Down
2 changes: 1 addition & 1 deletion test/mp_unit/bootstrap_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ TEST_F(BootstrapTest, ResumeWithId) {
// This test may take a few minutes.
bootstrapTestTimer.set(300);

for (int i = 0; i < 3000; ++i) {
for (int i = 0; i < 10; ++i) {
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
mscclpp::UniqueId id;
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
Expand Down
2 changes: 1 addition & 1 deletion test/unit/socket_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ TEST(Socket, ListenAndConnect) {
ASSERT_NO_THROW(mscclpp::SocketGetAddrFromString(&listenAddr, ipPortPair.c_str()));

mscclpp::Socket listenSock(&listenAddr);
listenSock.listen();
listenSock.bindAndListen();

std::thread clientThread([&listenAddr]() {
mscclpp::Socket sock(&listenAddr);
Expand Down

0 comments on commit 5ba6ce0

Please sign in to comment.