diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d088fcd..3b33a6e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -93,6 +93,7 @@ install(TARGETS mscclpp_static # Tests if (BUILD_TESTS) + enable_testing() # Called here to allow ctest from the build directory add_subdirectory(test) endif() diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 9391106c..306398fb 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -113,9 +113,10 @@ class TcpBootstrap : public Bootstrap { void barrier() override; private: - /// Implementation class for @ref TcpBootstrap. - class Impl; - /// Pointer to the implementation class for @ref TcpBootstrap. + // The interal implementation. + struct Impl; + + // Pointer to the internal implementation. std::unique_ptr pimpl_; }; @@ -306,23 +307,15 @@ std::string getIBDeviceName(Transport ibTransport); /// @return The InfiniBand transport associated with the specified device name. Transport getIBTransportByDeviceName(const std::string& ibDeviceName); -class Communicator; +class Context; class Connection; -/// Represents a block of memory that has been registered to a @ref Communicator. +/// Represents a block of memory that has been registered to a @ref Context. class RegisteredMemory { - protected: - struct Impl; - public: /// Default constructor. RegisteredMemory() = default; - /// Constructor that takes a shared pointer to an implementation object. - /// - /// @param pimpl A shared pointer to an implementation object. - RegisteredMemory(std::shared_ptr pimpl); - /// Destructor. ~RegisteredMemory(); @@ -336,11 +329,6 @@ class RegisteredMemory { /// @return The size of the memory block. size_t size(); - /// Get the rank of the process that owns the memory block. - /// - /// @return The rank of the process that owns the memory block. - int rank(); - /// Get the transport flags associated with the memory block. /// /// @return The transport flags associated with the memory block. @@ -357,14 +345,54 @@ class RegisteredMemory { /// @return A deserialized RegisteredMemory object. static RegisteredMemory deserialize(const std::vector& data); + private: + // The interal implementation. + struct Impl; + + // Internal constructor. + RegisteredMemory(std::shared_ptr pimpl); + + // Pointer to the internal implementation. A shared_ptr is used since RegisteredMemory is immutable. + std::shared_ptr pimpl_; + + friend class Context; friend class Connection; - friend class IBConnection; - friend class Communicator; +}; + +/// Represents one end of a connection. +class Endpoint { + public: + /// Default constructor. + Endpoint() = default; + + /// Get the transport used. + /// + /// @return The transport used. + Transport transport(); + + /// Serialize the Endpoint object to a vector of characters. + /// + /// @return A vector of characters representing the serialized Endpoint object. + std::vector serialize(); + + /// Deserialize a Endpoint object from a vector of characters. + /// + /// @param data A vector of characters representing a serialized Endpoint object. + /// @return A deserialized Endpoint object. + static Endpoint deserialize(const std::vector& data); private: - // A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated - // lazily. - std::shared_ptr pimpl; + // The interal implementation. + struct Impl; + + // Internal constructor. + Endpoint(std::shared_ptr pimpl); + + // Pointer to the internal implementation. A shared_ptr is used since Endpoint is immutable. + std::shared_ptr pimpl_; + + friend class Context; + friend class Connection; }; /// Represents a connection between two processes. @@ -391,16 +419,6 @@ class Connection { /// Flush any pending writes to the remote process. virtual void flush(int64_t timeoutUsec = 3e7) = 0; - /// Get the rank of the remote process. - /// - /// @return The rank of the remote process. - virtual int remoteRank() = 0; - - /// Get the tag associated with the connection. - /// - /// @return The tag associated with the connection. - virtual int tag() = 0; - /// Get the transport used by the local process. /// /// @return The transport used by the local process. @@ -412,11 +430,89 @@ class Connection { virtual Transport remoteTransport() = 0; protected: - /// Get the implementation object associated with a @ref RegisteredMemory object. + // Internal methods for getting implementation pointers. + static std::shared_ptr getImpl(RegisteredMemory& memory); + static std::shared_ptr getImpl(Endpoint& memory); +}; + +/// Used to configure an endpoint. +struct EndpointConfig { + static const int DefaultMaxCqSize = 1024; + static const int DefaultMaxCqPollNum = 1; + static const int DefaultMaxSendWr = 8192; + static const int DefaultMaxWrPerSend = 64; + + Transport transport; + int ibMaxCqSize = DefaultMaxCqSize; + int ibMaxCqPollNum = DefaultMaxCqPollNum; + int ibMaxSendWr = DefaultMaxSendWr; + int ibMaxWrPerSend = DefaultMaxWrPerSend; + + /// Default constructor. Sets transport to Transport::Unknown. + EndpointConfig() : transport(Transport::Unknown) {} + + /// Constructor that takes a transport and sets the other fields to their default values. /// - /// @param memory The @ref RegisteredMemory object. - /// @return A shared pointer to the implementation object. - static std::shared_ptr getRegisteredMemoryImpl(RegisteredMemory& memory); + /// @param transport The transport to use. + EndpointConfig(Transport transport) : transport(transport) {} +}; + +/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases +/// where the process group abstraction offered by @ref Communicator is not suitable, e.g., ephemeral client-server +/// connections. Correct use of this class requires external synchronization when finalizing connections with the +/// @ref connect() method. +/// +/// As an example, a client-server scenario where the server will write to the client might proceed as follows: +/// 1. The client creates an endpoint with @ref createEndpoint() and sends it to the server. +/// 2. The server receives the client endpoint, creates its own endpoint with @ref createEndpoint(), sends it to the +/// client, and creates a connection with @ref connect(). +/// 4. The client receives the server endpoint, creates a connection with @ref connect() and sends a +/// @ref RegisteredMemory to the server. +/// 5. The server receives the @ref RegisteredMemory and writes to it using the previously created connection. +/// The client waiting to create a connection before sending the @ref RegisteredMemory ensures that the server can not +/// write to the @ref RegisteredMemory before the connection is established. +/// +/// While some transports may have more relaxed implementation behavior, this should not be relied upon. +class Context { + public: + /// Create a context. + Context(); + + /// Destroy the context. + ~Context(); + + /// Register a region of GPU memory for use in this context. + /// + /// @param ptr Base pointer to the memory. + /// @param size Size of the memory region in bytes. + /// @param transports Transport flags. + /// @return RegisteredMemory A handle to the buffer. + RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports); + + /// Create an endpoint for establishing connections. + /// + /// @param config The configuration for the endpoint. + /// @return The newly created endpoint. + Endpoint createEndpoint(EndpointConfig config); + + /// Establish a connection between two endpoints. While this method immediately returns a connection object, the + /// connection is only safe to use after the corresponding connection on the remote endpoint has been established. + /// This method must be called on both endpoints to establish a connection. + /// + /// @param localEndpoint The local endpoint. + /// @param remoteEndpoint The remote endpoint. + /// @return std::shared_ptr A shared pointer to the connection. + std::shared_ptr connect(Endpoint localEndpoint, Endpoint remoteEndpoint); + + private: + // The interal implementation. + struct Impl; + + // Pointer to the internal implementation. + std::unique_ptr pimpl_; + + friend class RegisteredMemory; + friend class Endpoint; }; /// A base class for objects that can be set up during @ref Communicator::setup(). @@ -482,14 +578,12 @@ class NonblockingFuture { /// 6. All done; use connections and registered memories to build channels. /// class Communicator { - protected: - struct Impl; - public: /// Initializes the communicator with a given bootstrap implementation. /// /// @param bootstrap An implementation of the Bootstrap that the communicator will use. - Communicator(std::shared_ptr bootstrap); + /// @param context An optional context to use for the communicator. If not provided, a new context will be created. + Communicator(std::shared_ptr bootstrap, std::shared_ptr context = nullptr); /// Destroy the communicator. ~Communicator(); @@ -499,7 +593,12 @@ class Communicator { /// @return std::shared_ptr The bootstrap held by this communicator. std::shared_ptr bootstrap(); - /// Register a region of GPU memory for use in this communicator. + /// Returns the context held by this communicator. + /// + /// @return std::shared_ptr The context held by this communicator. + std::shared_ptr context(); + + /// Register a region of GPU memory for use in this communicator's context. /// /// @param ptr Base pointer to the memory. /// @param size Size of the memory region in bytes. @@ -537,15 +636,22 @@ class Communicator { /// /// @param remoteRank The rank of the remote process. /// @param tag The tag of the connection for identifying it. - /// @param transport The type of transport to be used. - /// @param ibMaxCqSize The maximum number of completion queue entries for IB. Unused if transport is not IB. - /// @param ibMaxCqPollNum The maximum number of completion queue entries to poll for IB. Unused if transport is not - /// IB. - /// @param ibMaxSendWr The maximum number of outstanding send work requests for IB. Unused if transport is not IB. - /// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB. - /// @return std::shared_ptr A shared pointer to the connection. - std::shared_ptr connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024, - int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64); + /// @param config The configuration for the local endpoint. + /// @return NonblockingFuture>> A non-blocking future of shared pointer + /// to the connection. + NonblockingFuture> connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig); + + /// Get the remote rank a connection is connected to. + /// + /// @param connection The connection to get the remote rank for. + /// @return The remote rank the connection is connected to. + int remoteRankOf(const Connection& connection); + + /// Get the tag a connection was made with. + /// + /// @param connection The connection to get the tag for. + /// @return The tag the connection was made with. + int tagOf(const Connection& connection); /// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called. /// @@ -559,12 +665,12 @@ class Communicator { /// that have been registered after the (n-1)-th call. void setup(); - friend class RegisteredMemory::Impl; - friend class IBConnection; - private: - /// Unique pointer to the implementation of the Communicator class. - std::unique_ptr pimpl; + // The interal implementation. + struct Impl; + + // Pointer to the internal implementation. + std::unique_ptr pimpl_; }; /// A constant TransportFlags object representing no transports. diff --git a/python/examples/bootstrap.py b/python/examples/bootstrap.py index ca0a521c..71539e0b 100644 --- a/python/examples/bootstrap.py +++ b/python/examples/bootstrap.py @@ -47,14 +47,16 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service): remote_memories.append(remote_mem) comm.setup() + connections = [conn.get() for conn in connections] + # Create simple proxy channels for i, conn in enumerate(connections): proxy_channel = mscclpp.SimpleProxyChannel( - proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(conn)), + proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)), proxy_service.add_memory(remote_memories[i].get()), proxy_service.add_memory(reg_mem), ) - simple_proxy_channels.append(mscclpp.device_handle(proxy_channel)) + simple_proxy_channels.append(proxy_channel.device_handle()) comm.setup() # Create sm channels @@ -66,7 +68,7 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service): for i, conn in enumerate(sm_semaphores): sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr) sm_channels.append(sm_chan) - return simple_proxy_channels, [mscclpp.device_handle(sm_chan) for sm_chan in sm_channels] + return simple_proxy_channels, [sm_chan.device_handle() for sm_chan in sm_channels] def run(rank, args): diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp index 21659786..60ceb96c 100644 --- a/python/mscclpp/core_py.cpp +++ b/python/mscclpp/core_py.cpp @@ -107,7 +107,6 @@ void register_core(nb::module_& m) { .def(nb::init<>()) .def("data", &RegisteredMemory::data) .def("size", &RegisteredMemory::size) - .def("rank", &RegisteredMemory::rank) .def("transports", &RegisteredMemory::transports) .def("serialize", &RegisteredMemory::serialize) .def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data")); @@ -122,16 +121,42 @@ void register_core(nb::module_& m) { }, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue")) .def("flush", &Connection::flush, nb::arg("timeoutUsec") = (int64_t)3e7) - .def("remote_rank", &Connection::remoteRank) - .def("tag", &Connection::tag) .def("transport", &Connection::transport) .def("remote_transport", &Connection::remoteTransport); + nb::class_(m, "Endpoint") + .def("transport", &Endpoint::transport) + .def("serialize", &Endpoint::serialize) + .def_static("deserialize", &Endpoint::deserialize, nb::arg("data")); + + nb::class_(m, "EndpointConfig") + .def(nb::init<>()) + .def(nb::init_implicit(), nb::arg("transport")) + .def_rw("transport", &EndpointConfig::transport) + .def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize) + .def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum) + .def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr) + .def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend); + + nb::class_(m, "Context") + .def(nb::init<>()) + .def( + "register_memory", + [](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) { + return self->registerMemory((void*)ptr, size, transports); + }, + nb::arg("ptr"), nb::arg("size"), nb::arg("transports")) + .def("create_endpoint", &Context::createEndpoint, nb::arg("config")) + .def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint")); + def_nonblocking_future(m, "RegisteredMemory"); + def_nonblocking_future>(m, "shared_ptr_Connection"); nb::class_(m, "Communicator") - .def(nb::init>(), nb::arg("bootstrap")) + .def(nb::init, std::shared_ptr>(), nb::arg("bootstrap"), + nb::arg("context") = nullptr) .def("bootstrap", &Communicator::bootstrap) + .def("context", &Communicator::context) .def( "register_memory", [](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) { @@ -142,8 +167,9 @@ void register_core(nb::module_& m) { nb::arg("tag")) .def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag")) .def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"), - nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1, - nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64) + nb::arg("localConfig")) + .def("remote_rank_of", &Communicator::remoteRankOf) + .def("tag_of", &Communicator::tagOf) .def("setup", &Communicator::setup); } diff --git a/python/test/mscclpp_group.py b/python/test/mscclpp_group.py index b2444107..7a7c7b01 100644 --- a/python/test/mscclpp_group.py +++ b/python/test/mscclpp_group.py @@ -78,6 +78,7 @@ def make_connection(self, remote_ranks: list[int], transport: Transport) -> dict for rank in remote_ranks: connections[rank] = self.communicator.connect_on_setup(rank, 0, transport) self.communicator.setup() + connections = {rank: connections[rank].get() for rank in connections} return connections def register_tensor_with_connections( diff --git a/src/communicator.cc b/src/communicator.cc index cc032355..4415b395 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -3,58 +3,30 @@ #include "communicator.hpp" -#include -#include -#include - #include "api.h" -#include "connection.hpp" -#include "debug.h" -#include "registered_memory.hpp" -#include "utils_internal.hpp" namespace mscclpp { -Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_(bootstrap) { - rankToHash_.resize(bootstrap->getNranks()); - auto hostHash = getHostHash(); - INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash); - rankToHash_[bootstrap->getRank()] = hostHash; - bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t)); - - MSCCLPP_CUDATHROW(cudaStreamCreateWithFlags(&ipcStream_, cudaStreamNonBlocking)); -} - -Communicator::Impl::~Impl() { - ibContexts_.clear(); - - cudaStreamDestroy(ipcStream_); -} - -IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) { - // Find IB context or create it - auto it = ibContexts_.find(ibTransport); - if (it == ibContexts_.end()) { - auto ibDev = getIBDeviceName(ibTransport); - ibContexts_[ibTransport] = std::make_unique(ibDev); - return ibContexts_[ibTransport].get(); +Communicator::Impl::Impl(std::shared_ptr bootstrap, std::shared_ptr context) + : bootstrap_(bootstrap) { + if (!context) { + context_ = std::make_shared(); } else { - return it->second.get(); + context_ = context; } } -cudaStream_t Communicator::Impl::getIpcStream() { return ipcStream_; } - MSCCLPP_API_CPP Communicator::~Communicator() = default; -MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap) - : pimpl(std::make_unique(bootstrap)) {} +MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap, std::shared_ptr context) + : pimpl_(std::make_unique(bootstrap, context)) {} + +MSCCLPP_API_CPP std::shared_ptr Communicator::bootstrap() { return pimpl_->bootstrap_; } -MSCCLPP_API_CPP std::shared_ptr Communicator::bootstrap() { return pimpl->bootstrap_; } +MSCCLPP_API_CPP std::shared_ptr Communicator::context() { return pimpl_->context_; } MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) { - return RegisteredMemory( - std::make_shared(ptr, size, pimpl->bootstrap_->getRank(), transports, *pimpl)); + return context()->registerMemory(ptr, size, transports); } struct MemorySender : public Setuppable { @@ -94,53 +66,62 @@ MSCCLPP_API_CPP NonblockingFuture Communicator::recvMemoryOnSe return NonblockingFuture(memoryReceiver->memoryPromise_.get_future()); } -MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int remoteRank, int tag, Transport transport, - int ibMaxCqSize /*=1024*/, - int ibMaxCqPollNum /*=1*/, - int ibMaxSendWr /*=8192*/, - int ibMaxWrPerSend /*=64*/) { - std::shared_ptr conn; - if (transport == Transport::CudaIpc) { - // sanity check: make sure the IPC connection is being made within a node - if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[pimpl->bootstrap_->getRank()]) { - std::stringstream ss; - ss << "Cuda IPC connection can only be made within a node: " << remoteRank << "(" << std::hex - << pimpl->rankToHash_[remoteRank] << ") != " << pimpl->bootstrap_->getRank() << "(" << std::hex - << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")"; - throw mscclpp::Error(ss.str(), ErrorCode::InvalidUsage); - } - auto cudaIpcConn = std::make_shared(remoteRank, tag, pimpl->getIpcStream()); - conn = cudaIpcConn; - INFO(MSCCLPP_P2P, "Cuda IPC connection between rank %d(%lx) and remoteRank %d(%lx) created", - pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], remoteRank, - pimpl->rankToHash_[remoteRank]); - } else if (AllIBTransports.has(transport)) { - auto ibConn = std::make_shared(remoteRank, tag, transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr, - ibMaxWrPerSend, *pimpl); - conn = ibConn; - INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created", - pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], - getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]); - } else { - throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError); +struct Communicator::Impl::Connector : public Setuppable { + Connector(Communicator& comm, Communicator::Impl& commImpl_, int remoteRank, int tag, EndpointConfig localConfig) + : comm_(comm), + commImpl_(commImpl_), + remoteRank_(remoteRank), + tag_(tag), + localEndpoint_(comm.context()->createEndpoint(localConfig)) {} + + void beginSetup(std::shared_ptr bootstrap) override { + bootstrap->send(localEndpoint_.serialize(), remoteRank_, tag_); + } + + void endSetup(std::shared_ptr bootstrap) override { + std::vector data; + bootstrap->recv(data, remoteRank_, tag_); + auto remoteEndpoint = Endpoint::deserialize(data); + auto connection = comm_.context()->connect(localEndpoint_, remoteEndpoint); + commImpl_.connectionInfos_[connection.get()] = {remoteRank_, tag_}; + connectionPromise_.set_value(connection); } - pimpl->connections_.push_back(conn); - onSetup(conn); - return conn; + + std::promise> connectionPromise_; + Communicator& comm_; + Communicator::Impl& commImpl_; + int remoteRank_; + int tag_; + Endpoint localEndpoint_; +}; + +MSCCLPP_API_CPP NonblockingFuture> Communicator::connectOnSetup( + int remoteRank, int tag, EndpointConfig localConfig) { + auto connector = std::make_shared(*this, *pimpl_, remoteRank, tag, localConfig); + onSetup(connector); + return NonblockingFuture>(connector->connectionPromise_.get_future()); +} + +MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) { + return pimpl_->connectionInfos_.at(&connection).remoteRank; +} + +MSCCLPP_API_CPP int Communicator::tagOf(const Connection& connection) { + return pimpl_->connectionInfos_.at(&connection).tag; } MSCCLPP_API_CPP void Communicator::onSetup(std::shared_ptr setuppable) { - pimpl->toSetup_.push_back(setuppable); + pimpl_->toSetup_.push_back(setuppable); } MSCCLPP_API_CPP void Communicator::setup() { - for (auto& setuppable : pimpl->toSetup_) { - setuppable->beginSetup(pimpl->bootstrap_); + for (auto& setuppable : pimpl_->toSetup_) { + setuppable->beginSetup(pimpl_->bootstrap_); } - for (auto& setuppable : pimpl->toSetup_) { - setuppable->endSetup(pimpl->bootstrap_); + for (auto& setuppable : pimpl_->toSetup_) { + setuppable->endSetup(pimpl_->bootstrap_); } - pimpl->toSetup_.clear(); + pimpl_->toSetup_.clear(); } } // namespace mscclpp diff --git a/src/connection.cc b/src/connection.cc index 112e1178..d47cf3b7 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -4,8 +4,10 @@ #include "connection.hpp" #include +#include #include "debug.h" +#include "endpoint.hpp" #include "infiniband/verbs.h" #include "npkit/npkit.h" #include "registered_memory.hpp" @@ -20,24 +22,29 @@ void validateTransport(RegisteredMemory mem, Transport transport) { // Connection -std::shared_ptr Connection::getRegisteredMemoryImpl(RegisteredMemory& memory) { - return memory.pimpl; -} - -// ConnectionBase - -ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {} +std::shared_ptr Connection::getImpl(RegisteredMemory& memory) { return memory.pimpl_; } -int ConnectionBase::remoteRank() { return remoteRank_; } - -int ConnectionBase::tag() { return tag_; } +std::shared_ptr Connection::getImpl(Endpoint& memory) { return memory.pimpl_; } // CudaIpcConnection -CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag, cudaStream_t stream) - : ConnectionBase(remoteRank, tag), stream_(stream) {} - -CudaIpcConnection::~CudaIpcConnection() {} +CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream) + : stream_(stream) { + if (localEndpoint.transport() != Transport::CudaIpc) { + throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage); + } + if (remoteEndpoint.transport() != Transport::CudaIpc) { + throw mscclpp::Error("Cuda IPC connection can only be made to a Cuda IPC endpoint", ErrorCode::InvalidUsage); + } + // sanity check: make sure the IPC connection is being made within a node + if (getImpl(remoteEndpoint)->hostHash_ != getImpl(localEndpoint)->hostHash_) { + std::stringstream ss; + ss << "Cuda IPC connection can only be made within a node: " << std::hex << getImpl(remoteEndpoint)->hostHash_ + << " != " << std::hex << getImpl(localEndpoint)->hostHash_; + throw mscclpp::Error(ss.str(), ErrorCode::InvalidUsage); + } + INFO(MSCCLPP_P2P, "Cuda IPC connection created"); +} Transport CudaIpcConnection::transport() { return Transport::CudaIpc; } @@ -77,27 +84,23 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) { AvoidCudaGraphCaptureGuard guard; MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream_)); // npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT); - INFO(MSCCLPP_P2P, "CudaIpcConnection flushing connection to remote rank %d", remoteRank()); + INFO(MSCCLPP_P2P, "CudaIpcConnection flushing connection"); } // IBConnection -IBConnection::IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr, - int maxWrPerSend, Communicator::Impl& commImpl) - : ConnectionBase(remoteRank, tag), - transport_(transport), - remoteTransport_(Transport::Unknown), +IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context) + : transport_(localEndpoint.transport()), + remoteTransport_(remoteEndpoint.transport()), numSignaledSends(0), dummyAtomicSource_(std::make_unique(0)) { - qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend); - dummyAtomicSourceMem_ = RegisteredMemory(std::make_shared( - dummyAtomicSource_.get(), sizeof(uint64_t), commImpl.bootstrap_->getRank(), transport, commImpl)); - validateTransport(dummyAtomicSourceMem_, transport); - dstTransportInfo_ = getRegisteredMemoryImpl(dummyAtomicSourceMem_)->getTransportInfo(transport); - - if (!dstTransportInfo_.ibLocal) { - throw Error("dummyAtomicSource_ is remote, which is not supported", ErrorCode::InternalError); - } + qp = getImpl(localEndpoint)->ibQp_; + qp->rtr(getImpl(remoteEndpoint)->ibQpInfo_); + qp->rts(); + dummyAtomicSourceMem_ = context.registerMemory(dummyAtomicSource_.get(), sizeof(uint64_t), transport_); + validateTransport(dummyAtomicSourceMem_, transport_); + dstTransportInfo_ = getImpl(dummyAtomicSourceMem_)->getTransportInfo(transport_); + INFO(MSCCLPP_NET, "IB connection via %s created", getIBDeviceName(transport_).c_str()); } Transport IBConnection::transport() { return transport_; } @@ -109,11 +112,11 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem validateTransport(dst, remoteTransport()); validateTransport(src, transport()); - auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport()); + auto dstTransportInfo = getImpl(dst)->getTransportInfo(remoteTransport()); if (dstTransportInfo.ibLocal) { throw Error("dst is local, which is not supported", ErrorCode::InvalidUsage); } - auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(transport()); + auto srcTransportInfo = getImpl(src)->getTransportInfo(transport()); if (!srcTransportInfo.ibLocal) { throw Error("src is remote, which is not supported", ErrorCode::InvalidUsage); } @@ -133,7 +136,7 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem void IBConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) { validateTransport(dst, remoteTransport()); - auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport()); + auto dstTransportInfo = getImpl(dst)->getTransportInfo(remoteTransport()); if (dstTransportInfo.ibLocal) { throw Error("dst is local, which is not supported", ErrorCode::InvalidUsage); } @@ -173,31 +176,8 @@ void IBConnection::flush(int64_t timeoutUsec) { } } } - INFO(MSCCLPP_NET, "IBConnection flushing connection to remote rank %d", remoteRank()); + INFO(MSCCLPP_NET, "IBConnection flushing connection"); // npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); } -void IBConnection::beginSetup(std::shared_ptr bootstrap) { - std::vector ibQpTransport; - std::copy_n(reinterpret_cast(&qp->getInfo()), sizeof(qp->getInfo()), std::back_inserter(ibQpTransport)); - std::copy_n(reinterpret_cast(&transport_), sizeof(transport_), std::back_inserter(ibQpTransport)); - - bootstrap->send(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag()); -} - -void IBConnection::endSetup(std::shared_ptr bootstrap) { - std::vector ibQpTransport(sizeof(IbQpInfo) + sizeof(Transport)); - bootstrap->recv(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag()); - - IbQpInfo qpInfo; - auto it = ibQpTransport.begin(); - std::copy_n(it, sizeof(qpInfo), reinterpret_cast(&qpInfo)); - it += sizeof(qpInfo); - std::copy_n(it, sizeof(remoteTransport_), reinterpret_cast(&remoteTransport_)); - it += sizeof(qpInfo); - - qp->rtr(qpInfo); - qp->rts(); -} - } // namespace mscclpp diff --git a/src/context.cc b/src/context.cc new file mode 100644 index 00000000..d04a8e32 --- /dev/null +++ b/src/context.cc @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "context.hpp" + +#include "api.h" +#include "connection.hpp" +#include "debug.h" +#include "endpoint.hpp" +#include "registered_memory.hpp" + +namespace mscclpp { + +Context::Impl::Impl() : ipcStream_(cudaStreamNonBlocking) {} + +IbCtx* Context::Impl::getIbContext(Transport ibTransport) { + // Find IB context or create it + auto it = ibContexts_.find(ibTransport); + if (it == ibContexts_.end()) { + auto ibDev = getIBDeviceName(ibTransport); + ibContexts_[ibTransport] = std::make_unique(ibDev); + return ibContexts_[ibTransport].get(); + } else { + return it->second.get(); + } +} + +MSCCLPP_API_CPP Context::Context() : pimpl_(std::make_unique()) {} + +MSCCLPP_API_CPP Context::~Context() = default; + +MSCCLPP_API_CPP RegisteredMemory Context::registerMemory(void* ptr, size_t size, TransportFlags transports) { + return RegisteredMemory(std::make_shared(ptr, size, transports, *pimpl_)); +} + +MSCCLPP_API_CPP Endpoint Context::createEndpoint(EndpointConfig config) { + return Endpoint(std::make_shared(config, *pimpl_)); +} + +MSCCLPP_API_CPP std::shared_ptr Context::connect(Endpoint localEndpoint, Endpoint remoteEndpoint) { + std::shared_ptr conn; + if (localEndpoint.transport() == Transport::CudaIpc) { + if (remoteEndpoint.transport() != Transport::CudaIpc) { + throw mscclpp::Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage); + } + conn = std::make_shared(localEndpoint, remoteEndpoint, pimpl_->ipcStream_); + } else if (AllIBTransports.has(localEndpoint.transport())) { + if (!AllIBTransports.has(remoteEndpoint.transport())) { + throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage); + } + conn = std::make_shared(localEndpoint, remoteEndpoint, *this); + } else { + throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError); + } + pimpl_->connections_.push_back(conn); + return conn; +} + +} // namespace mscclpp diff --git a/src/endpoint.cc b/src/endpoint.cc new file mode 100644 index 00000000..dbc77389 --- /dev/null +++ b/src/endpoint.cc @@ -0,0 +1,52 @@ +#include "endpoint.hpp" + +#include + +#include "api.h" +#include "context.hpp" +#include "utils_internal.hpp" + +namespace mscclpp { + +Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl) + : transport_(config.transport), hostHash_(getHostHash()) { + if (AllIBTransports.has(transport_)) { + ibLocal_ = true; + ibQp_ = contextImpl.getIbContext(transport_) + ->createQp(config.ibMaxCqSize, config.ibMaxCqPollNum, config.ibMaxSendWr, 0, config.ibMaxWrPerSend); + ibQpInfo_ = ibQp_->getInfo(); + } +} + +MSCCLPP_API_CPP Transport Endpoint::transport() { return pimpl_->transport_; } + +MSCCLPP_API_CPP std::vector Endpoint::serialize() { + std::vector data; + std::copy_n(reinterpret_cast(&pimpl_->transport_), sizeof(pimpl_->transport_), std::back_inserter(data)); + std::copy_n(reinterpret_cast(&pimpl_->hostHash_), sizeof(pimpl_->hostHash_), std::back_inserter(data)); + if (AllIBTransports.has(pimpl_->transport_)) { + std::copy_n(reinterpret_cast(&pimpl_->ibQpInfo_), sizeof(pimpl_->ibQpInfo_), std::back_inserter(data)); + } + return data; +} + +MSCCLPP_API_CPP Endpoint Endpoint::deserialize(const std::vector& data) { + return Endpoint(std::make_shared(data)); +} + +Endpoint::Impl::Impl(const std::vector& serialization) { + auto it = serialization.begin(); + std::copy_n(it, sizeof(transport_), reinterpret_cast(&transport_)); + it += sizeof(transport_); + std::copy_n(it, sizeof(hostHash_), reinterpret_cast(&hostHash_)); + it += sizeof(hostHash_); + if (AllIBTransports.has(transport_)) { + ibLocal_ = false; + std::copy_n(it, sizeof(ibQpInfo_), reinterpret_cast(&ibQpInfo_)); + it += sizeof(ibQpInfo_); + } +} + +MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr pimpl) : pimpl_(pimpl) {} + +} // namespace mscclpp \ No newline at end of file diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 858a77ce..0f868b14 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -4,33 +4,29 @@ #ifndef MSCCL_COMMUNICATOR_HPP_ #define MSCCL_COMMUNICATOR_HPP_ -#include - #include #include -#include #include - -#include "ib.hpp" +#include namespace mscclpp { class ConnectionBase; +struct ConnectionInfo { + int remoteRank; + int tag; +}; + struct Communicator::Impl { - std::vector> connections_; - std::vector> toSetup_; - std::unordered_map> ibContexts_; - cudaStream_t ipcStream_; std::shared_ptr bootstrap_; - std::vector rankToHash_; - - Impl(std::shared_ptr bootstrap); + std::shared_ptr context_; + std::unordered_map connectionInfos_; + std::vector> toSetup_; - ~Impl(); + Impl(std::shared_ptr bootstrap, std::shared_ptr context); - IbCtx* getIbContext(Transport ibTransport); - cudaStream_t getIpcStream(); + struct Connector; }; } // namespace mscclpp diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 0475691c..d073d96b 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -9,31 +9,17 @@ #include #include "communicator.hpp" +#include "context.hpp" #include "ib.hpp" #include "registered_memory.hpp" namespace mscclpp { -// TODO: Add functionality to these classes for Communicator to do connectionSetup - -class ConnectionBase : public Connection, public Setuppable { - int remoteRank_; - int tag_; - - public: - ConnectionBase(int remoteRank, int tag); - - int remoteRank() override; - int tag() override; -}; - -class CudaIpcConnection : public ConnectionBase { +class CudaIpcConnection : public Connection { cudaStream_t stream_; public: - CudaIpcConnection(int remoteRank, int tag, cudaStream_t stream); - - ~CudaIpcConnection(); + CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream); Transport transport() override; @@ -46,7 +32,7 @@ class CudaIpcConnection : public ConnectionBase { void flush(int64_t timeoutUsec) override; }; -class IBConnection : public ConnectionBase { +class IBConnection : public Connection { Transport transport_; Transport remoteTransport_; IbQp* qp; @@ -56,8 +42,7 @@ class IBConnection : public ConnectionBase { mscclpp::TransportInfo dstTransportInfo_; public: - IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr, - int maxWrPerSend, Communicator::Impl& commImpl); + IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context); Transport transport() override; @@ -68,10 +53,6 @@ class IBConnection : public ConnectionBase { void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override; void flush(int64_t timeoutUsec) override; - - void beginSetup(std::shared_ptr bootstrap) override; - - void endSetup(std::shared_ptr bootstrap) override; }; } // namespace mscclpp diff --git a/src/include/context.hpp b/src/include/context.hpp new file mode 100644 index 00000000..11cc98d7 --- /dev/null +++ b/src/include/context.hpp @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef MSCCL_CONTEXT_HPP_ +#define MSCCL_CONTEXT_HPP_ + +#include +#include +#include +#include + +#include "ib.hpp" + +namespace mscclpp { + +struct Context::Impl { + std::vector> connections_; + std::unordered_map> ibContexts_; + CudaStreamWithFlags ipcStream_; + + Impl(); + + IbCtx* getIbContext(Transport ibTransport); +}; + +} // namespace mscclpp + +#endif // MSCCL_CONTEXT_HPP_ diff --git a/src/include/endpoint.hpp b/src/include/endpoint.hpp new file mode 100644 index 00000000..f246012c --- /dev/null +++ b/src/include/endpoint.hpp @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef MSCCL_ENDPOINT_HPP_ +#define MSCCL_ENDPOINT_HPP_ + +#include +#include + +#include "ib.hpp" + +namespace mscclpp { + +struct Endpoint::Impl { + Impl(EndpointConfig config, Context::Impl& contextImpl); + Impl(const std::vector& serialization); + + Transport transport_; + uint64_t hostHash_; + + // The following are only used for IB and are undefined for other transports. + bool ibLocal_; + IbQp* ibQp_; + IbQpInfo ibQpInfo_; +}; + +} // namespace mscclpp + +#endif // MSCCL_ENDPOINT_HPP_ diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index 627960a8..3804bfd6 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -32,15 +32,18 @@ struct TransportInfo { }; struct RegisteredMemory::Impl { + // This is the data pointer returned by RegisteredMemory::data(), which may be different from the original data + // pointer for deserialized remote memory. void* data; + // This is the original data pointer the RegisteredMemory was created with. + void* originalDataPtr; size_t size; - int rank; - bool isRemote; uint64_t hostHash; + uint64_t pidHash; TransportFlags transports; std::vector transportInfos; - Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl); + Impl(void* data, size_t size, TransportFlags transports, Context::Impl& contextImpl); /// Constructs a RegisteredMemory::Impl from a vector of data. The constructor should only be used for the remote /// memory. Impl(const std::vector& data); diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 39a5ebb6..9c35e144 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -9,17 +9,18 @@ #include #include "api.h" +#include "context.hpp" #include "debug.h" #include "utils_internal.hpp" namespace mscclpp { -RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) +RegisteredMemory::Impl::Impl(void* data, size_t size, TransportFlags transports, Context::Impl& contextImpl) : data(data), + originalDataPtr(data), size(size), - rank(rank), - isRemote(false), - hostHash(commImpl.rankToHash_.at(rank)), + hostHash(getHostHash()), + pidHash(getPidHash()), transports(transports) { if (transports.has(Transport::CudaIpc)) { TransportInfo transportInfo; @@ -39,7 +40,7 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t auto addIb = [&](Transport ibTransport) { TransportInfo transportInfo; transportInfo.transport = ibTransport; - const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size); + const IbMr* mr = contextImpl.getIbContext(ibTransport)->registerMr(data, size); transportInfo.ibMr = mr; transportInfo.ibLocal = true; transportInfo.ibMrInfo = mr->getInfo(); @@ -57,30 +58,30 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t } } -MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : pimpl(pimpl) {} +MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : pimpl_(pimpl) {} MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default; -MSCCLPP_API_CPP void* RegisteredMemory::data() const { return pimpl->data; } +MSCCLPP_API_CPP void* RegisteredMemory::data() const { return pimpl_->data; } -MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl->size; } +MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl_->size; } -MSCCLPP_API_CPP int RegisteredMemory::rank() { return pimpl->rank; } - -MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl->transports; } +MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl_->transports; } MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() { std::vector result; - std::copy_n(reinterpret_cast(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result)); - std::copy_n(reinterpret_cast(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result)); - std::copy_n(reinterpret_cast(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result)); - std::copy_n(reinterpret_cast(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result)); - if (pimpl->transportInfos.size() > static_cast(std::numeric_limits::max())) { + std::copy_n(reinterpret_cast(&pimpl_->originalDataPtr), sizeof(pimpl_->originalDataPtr), + std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl_->size), sizeof(pimpl_->size), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl_->hostHash), sizeof(pimpl_->hostHash), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl_->pidHash), sizeof(pimpl_->pidHash), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl_->transports), sizeof(pimpl_->transports), std::back_inserter(result)); + if (pimpl_->transportInfos.size() > static_cast(std::numeric_limits::max())) { throw mscclpp::Error("Too many transport info entries", ErrorCode::InternalError); } - int8_t transportCount = pimpl->transportInfos.size(); + int8_t transportCount = pimpl_->transportInfos.size(); std::copy_n(reinterpret_cast(&transportCount), sizeof(transportCount), std::back_inserter(result)); - for (auto& entry : pimpl->transportInfos) { + for (auto& entry : pimpl_->transportInfos) { std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); if (entry.transport == Transport::CudaIpc) { std::copy_n(reinterpret_cast(&entry.cudaIpcBaseHandle), sizeof(entry.cudaIpcBaseHandle), @@ -102,12 +103,14 @@ MSCCLPP_API_CPP RegisteredMemory RegisteredMemory::deserialize(const std::vector RegisteredMemory::Impl::Impl(const std::vector& serialization) { auto it = serialization.begin(); + std::copy_n(it, sizeof(this->originalDataPtr), reinterpret_cast(&this->originalDataPtr)); + it += sizeof(this->originalDataPtr); std::copy_n(it, sizeof(this->size), reinterpret_cast(&this->size)); it += sizeof(this->size); - std::copy_n(it, sizeof(this->rank), reinterpret_cast(&this->rank)); - it += sizeof(this->rank); std::copy_n(it, sizeof(this->hostHash), reinterpret_cast(&this->hostHash)); it += sizeof(this->hostHash); + std::copy_n(it, sizeof(this->pidHash), reinterpret_cast(&this->pidHash)); + it += sizeof(this->pidHash); std::copy_n(it, sizeof(this->transports), reinterpret_cast(&this->transports)); it += sizeof(this->transports); int8_t transportCount; @@ -137,28 +140,33 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { throw mscclpp::Error("Serialization failed", ErrorCode::InternalError); } - if (transports.has(Transport::CudaIpc)) { - uint64_t localHostHash = getHostHash(); - if (localHostHash == this->hostHash) { - auto entry = getTransportInfo(Transport::CudaIpc); - void* base; - MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess)); - data = static_cast(base) + entry.cudaIpcOffsetFromBase; - INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", data); - } + // Next decide how to set this->data + if (getHostHash() == this->hostHash && getPidHash() == this->pidHash) { + // The memory is local to the process, so originalDataPtr is valid as is + this->data = this->originalDataPtr; + } else if (transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash) { + // The memory is local to the machine but not to the process, so we need to open the CUDA IPC handle + auto entry = getTransportInfo(Transport::CudaIpc); + void* base; + MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess)); + this->data = static_cast(base) + entry.cudaIpcOffsetFromBase; + INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", this->data); + } else { + // No valid data pointer can be set + this->data = nullptr; } - this->isRemote = true; } RegisteredMemory::Impl::~Impl() { - uint64_t localHostHash = getHostHash(); - if (this->isRemote && localHostHash == this->hostHash && transports.has(Transport::CudaIpc)) { + // Close the CUDA IPC handle if it was opened during deserialization + if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash && getPidHash() != this->pidHash) { void* base = static_cast(data) - getTransportInfo(Transport::CudaIpc).cudaIpcOffsetFromBase; cudaError_t err = cudaIpcCloseMemHandle(base); if (err != cudaSuccess) { - WARN("Failed to close cuda IPC handle: %s", cudaGetErrorString(err)); + WARN("Failed to close CUDA IPC handle at pointer %p: %s", base, cudaGetErrorString(err)); + } else { + INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", base); } - INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", base); data = nullptr; } } diff --git a/src/semaphore.cc b/src/semaphore.cc index 0b5269ca..921d32aa 100644 --- a/src/semaphore.cc +++ b/src/semaphore.cc @@ -12,8 +12,10 @@ static NonblockingFuture setupInboundSemaphoreId(Communicator& void* localInboundSemaphoreId) { auto localInboundSemaphoreIdsRegMem = communicator.registerMemory(localInboundSemaphoreId, sizeof(uint64_t), connection->transport()); - communicator.sendMemoryOnSetup(localInboundSemaphoreIdsRegMem, connection->remoteRank(), connection->tag()); - return communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag()); + int remoteRank = communicator.remoteRankOf(*connection); + int tag = communicator.tagOf(*connection); + communicator.sendMemoryOnSetup(localInboundSemaphoreIdsRegMem, remoteRank, tag); + return communicator.recvMemoryOnSetup(remoteRank, tag); } MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(Communicator& communicator, @@ -69,12 +71,12 @@ MSCCLPP_API_CPP SmDevice2DeviceSemaphore::SmDevice2DeviceSemaphore(Communicator& remoteInboundSemaphoreIdsRegMem_ = setupInboundSemaphoreId(communicator, connection.get(), localInboundSemaphore_.get()); INFO(MSCCLPP_INIT, "Creating a direct semaphore for CudaIPC transport from %d to %d", - communicator.bootstrap()->getRank(), connection->remoteRank()); + communicator.bootstrap()->getRank(), communicator.remoteRankOf(*connection)); isRemoteInboundSemaphoreIdSet_ = true; } else if (AllIBTransports.has(connection->transport())) { // We don't need to really with any of the IB transports, since the values will be local INFO(MSCCLPP_INIT, "Creating a direct semaphore for IB transport from %d to %d", - communicator.bootstrap()->getRank(), connection->remoteRank()); + communicator.bootstrap()->getRank(), communicator.remoteRankOf(*connection)); isRemoteInboundSemaphoreIdSet_ = false; } } diff --git a/src/utils_internal.cc b/src/utils_internal.cc index 1e3d49d1..37b96bbe 100644 --- a/src/utils_internal.cc +++ b/src/utils_internal.cc @@ -104,7 +104,7 @@ uint64_t getHostHash(void) { * * $$ $(readlink /proc/self/ns/pid) */ -uint64_t getPidHash(void) { +uint64_t computePidHash(void) { char pname[1024]; // Start off with our pid ($$) sprintf(pname, "%ld", (long)getpid()); @@ -118,6 +118,11 @@ uint64_t getPidHash(void) { return getHash(pname, strlen(pname)); } +uint64_t getPidHash(void) { + thread_local std::unique_ptr pidHash = std::make_unique(computePidHash()); + return *pidHash; +} + int parseStringList(const char* string, netIf* ifList, int maxList) { if (!string) return 0; diff --git a/test/allgather_test_cpp.cu b/test/allgather_test_cpp.cu index 9bf81b2b..951614db 100644 --- a/test/allgather_test_cpp.cu +++ b/test/allgather_test_cpp.cu @@ -215,6 +215,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr); std::vector semaphoreIds; std::vector localMemories; + std::vector>> connections(world_size); std::vector> remoteMemories; for (int r = 0; r < world_size; ++r) { @@ -226,7 +227,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co transport = ibTransport; } // Connect with all other ranks - semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, comm.connectOnSetup(r, 0, transport))); + connections[r] = comm.connectOnSetup(r, 0, transport); auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport); localMemories.push_back(memory); comm.sendMemoryOnSetup(memory, r, 0); @@ -235,6 +236,13 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co comm.setup(); + for (int r = 0; r < world_size; ++r) { + if (r == rank) continue; + semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, connections[r].get())); + } + + comm.setup(); + std::vector> proxyChannels; for (size_t i = 0; i < semaphoreIds.size(); ++i) { proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel( diff --git a/test/allgather_test_host_offloading.cu b/test/allgather_test_host_offloading.cu index d3e725f4..a71b4854 100644 --- a/test/allgather_test_host_offloading.cu +++ b/test/allgather_test_host_offloading.cu @@ -116,6 +116,7 @@ class MyProxyService { int cudaNum = rankToLocalRank(rank); std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum); mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr); + std::vector>> connectionsFuture(world_size); std::vector> remoteMemoriesFuture(world_size); localMemory_ = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport); @@ -133,14 +134,7 @@ class MyProxyService { transport = ibTransport; } // Connect with all other ranks - connections_[r] = comm.connectOnSetup(r, 0, transport); - if (rankToNode(r) == thisNode) { - hostSemaphores_.emplace_back(nullptr); - } else { - hostSemaphores_.emplace_back(std::make_shared(comm, connections_[r])); - } - deviceSemaphores1_.emplace_back(std::make_shared(comm, connections_[r])); - deviceSemaphores2_.emplace_back(std::make_shared(comm, connections_[r])); + connectionsFuture[r] = comm.connectOnSetup(r, 0, transport); comm.sendMemoryOnSetup(localMemory_, r, 0); remoteMemoriesFuture[r] = comm.recvMemoryOnSetup(r, 0); @@ -152,8 +146,18 @@ class MyProxyService { if (r == rank) { continue; } + connections_[r] = connectionsFuture[r].get(); + if (rankToNode(r) == thisNode) { + hostSemaphores_.emplace_back(nullptr); + } else { + hostSemaphores_.emplace_back(std::make_shared(comm, connections_[r])); + } + deviceSemaphores1_.emplace_back(std::make_shared(comm, connections_[r])); + deviceSemaphores2_.emplace_back(std::make_shared(comm, connections_[r])); remoteMemories_[r] = remoteMemoriesFuture[r].get(); } + + comm.setup(); } void bindThread() { diff --git a/test/mp_unit/communicator_tests.cu b/test/mp_unit/communicator_tests.cu index 829403b9..e3709357 100644 --- a/test/mp_unit/communicator_tests.cu +++ b/test/mp_unit/communicator_tests.cu @@ -43,16 +43,22 @@ void CommunicatorTestBase::TearDown() { void CommunicatorTestBase::setNumRanksToUse(int num) { numRanksToUse = num; } void CommunicatorTestBase::connectMesh(bool useIbOnly) { + std::vector>> connectionFutures(numRanksToUse); for (int i = 0; i < numRanksToUse; i++) { if (i != gEnv->rank) { if ((rankToNode(i) == rankToNode(gEnv->rank)) && !useIbOnly) { - connections[i] = communicator->connectOnSetup(i, 0, mscclpp::Transport::CudaIpc); + connectionFutures[i] = communicator->connectOnSetup(i, 0, mscclpp::Transport::CudaIpc); } else { - connections[i] = communicator->connectOnSetup(i, 0, ibTransport); + connectionFutures[i] = communicator->connectOnSetup(i, 0, ibTransport); } } } communicator->setup(); + for (int i = 0; i < numRanksToUse; i++) { + if (i != gEnv->rank) { + connections[i] = connectionFutures[i].get(); + } + } } // Register a local memory and receive corresponding remote memories diff --git a/test/mp_unit/proxy_channel_tests.cu b/test/mp_unit/proxy_channel_tests.cu index 5537fe01..23e450ce 100644 --- a/test/mp_unit/proxy_channel_tests.cu +++ b/test/mp_unit/proxy_channel_tests.cu @@ -22,6 +22,9 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vector>> connectionFutures(worldSize); + std::vector> remoteMemFutures(worldSize); + mscclpp::RegisteredMemory sendBufRegMem = communicator->registerMemory(sendBuff, sendBuffBytes, transport); mscclpp::RegisteredMemory recvBufRegMem; if (!isInPlace) { @@ -32,29 +35,33 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vector conn; if ((rankToNode(r) == rankToNode(gEnv->rank)) && !useIbOnly) { - conn = communicator->connectOnSetup(r, 0, mscclpp::Transport::CudaIpc); + connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::CudaIpc); } else { - conn = communicator->connectOnSetup(r, 0, ibTransport); + connectionFutures[r] = communicator->connectOnSetup(r, 0, ibTransport); } - connections[r] = conn; if (isInPlace) { communicator->sendMemoryOnSetup(sendBufRegMem, r, 0); } else { communicator->sendMemoryOnSetup(recvBufRegMem, r, 0); } - auto remoteMemory = communicator->recvMemoryOnSetup(r, 0); + remoteMemFutures[r] = communicator->recvMemoryOnSetup(r, 0); + } - communicator->setup(); + communicator->setup(); - mscclpp::SemaphoreId cid = proxyService->buildAndAddSemaphore(*communicator, conn); - communicator->setup(); + for (int r = 0; r < worldSize; r++) { + if (r == rank) { + continue; + } + mscclpp::SemaphoreId cid = proxyService->buildAndAddSemaphore(*communicator, connectionFutures[r].get()); - proxyChannels.emplace_back(proxyService->proxyChannel(cid), proxyService->addMemory(remoteMemory.get()), + proxyChannels.emplace_back(proxyService->proxyChannel(cid), proxyService->addMemory(remoteMemFutures[r].get()), proxyService->addMemory(sendBufRegMem)); } + + communicator->setup(); } __constant__ DeviceHandle gChannelOneToOneTestConstProxyChans; diff --git a/test/mp_unit/sm_channel_tests.cu b/test/mp_unit/sm_channel_tests.cu index 21d9571a..37b3ce63 100644 --- a/test/mp_unit/sm_channel_tests.cu +++ b/test/mp_unit/sm_channel_tests.cu @@ -24,6 +24,9 @@ void SmChannelOneToOneTest::setupMeshConnections(std::vector const bool isInPlace = (outputBuff == nullptr); mscclpp::TransportFlags transport = mscclpp::Transport::CudaIpc | ibTransport; + std::vector>> connectionFutures(worldSize); + std::vector> remoteMemFutures(worldSize); + mscclpp::RegisteredMemory inputBufRegMem = communicator->registerMemory(inputBuff, inputBuffBytes, transport); mscclpp::RegisteredMemory outputBufRegMem; if (!isInPlace) { @@ -34,30 +37,35 @@ void SmChannelOneToOneTest::setupMeshConnections(std::vector if (r == rank) { continue; } - std::shared_ptr conn; if (rankToNode(r) == rankToNode(gEnv->rank)) { - conn = communicator->connectOnSetup(r, 0, mscclpp::Transport::CudaIpc); + connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::CudaIpc); } else { - conn = communicator->connectOnSetup(r, 0, ibTransport); + connectionFutures[r] = communicator->connectOnSetup(r, 0, ibTransport); } - connections[r] = conn; if (isInPlace) { communicator->sendMemoryOnSetup(inputBufRegMem, r, 0); } else { communicator->sendMemoryOnSetup(outputBufRegMem, r, 0); } - auto remoteMemory = communicator->recvMemoryOnSetup(r, 0); + remoteMemFutures[r] = communicator->recvMemoryOnSetup(r, 0); + } - communicator->setup(); + communicator->setup(); - smSemaphores[r] = std::make_shared(*communicator, conn); + for (int r = 0; r < worldSize; r++) { + if (r == rank) { + continue; + } + connections[r] = connectionFutures[r].get(); - communicator->setup(); + smSemaphores[r] = std::make_shared(*communicator, connections[r]); - smChannels.emplace_back(smSemaphores[r], remoteMemory.get(), inputBufRegMem.data(), + smChannels.emplace_back(smSemaphores[r], remoteMemFutures[r].get(), inputBufRegMem.data(), (isInPlace ? nullptr : outputBufRegMem.data())); } + + communicator->setup(); } __constant__ DeviceHandle gChannelOneToOneTestConstSmChans; diff --git a/test/mscclpp-test/common.cc b/test/mscclpp-test/common.cc index e8053104..e47ccfbf 100644 --- a/test/mscclpp-test/common.cc +++ b/test/mscclpp-test/common.cc @@ -369,6 +369,7 @@ void BaseTestEngine::setupMeshConnectionsInternal( const int nRanksPerNode = args_.nRanksPerNode; const int thisNode = rank / nRanksPerNode; const mscclpp::Transport ibTransport = IBs[args_.gpuNum]; + std::vector>> connectionFutures; auto rankToNode = [&](int rank) { return rank / nRanksPerNode; }; for (int r = 0; r < worldSize; r++) { @@ -383,13 +384,16 @@ void BaseTestEngine::setupMeshConnectionsInternal( transport = ibTransport; } // Connect with all other ranks - connections.push_back(comm_->connectOnSetup(r, 0, transport)); + connectionFutures.push_back(comm_->connectOnSetup(r, 0, transport)); } comm_->sendMemoryOnSetup(localRegMemory, r, 0); auto remoteMemory = comm_->recvMemoryOnSetup(r, 0); remoteRegMemories.push_back(remoteMemory); } comm_->setup(); + std::transform( + connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections), + [](const mscclpp::NonblockingFuture>& future) { return future.get(); }); } // Create mesh connections between all ranks. If recvBuff is nullptr, assume in-place. diff --git a/test/mscclpp-test/sendrecv_test.cu b/test/mscclpp-test/sendrecv_test.cu index 53aecbe6..d0922014 100644 --- a/test/mscclpp-test/sendrecv_test.cu +++ b/test/mscclpp-test/sendrecv_test.cu @@ -155,16 +155,18 @@ void SendRecvTestEngine::setupConnections() { std::vector> smSemaphores; - auto sendConn = + auto sendConnFuture = comm_->connectOnSetup(sendToRank, 0, getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice)); - smSemaphores.push_back(std::make_shared(*comm_, sendConn)); if (recvFromRank != sendToRank) { - auto recvConn = + auto recvConnFuture = comm_->connectOnSetup(recvFromRank, 0, getTransport(args_.rank, recvFromRank, args_.nRanksPerNode, ibDevice)); - smSemaphores.push_back(std::make_shared(*comm_, recvConn)); + comm_->setup(); + smSemaphores.push_back(std::make_shared(*comm_, sendConnFuture.get())); + smSemaphores.push_back(std::make_shared(*comm_, recvConnFuture.get())); } else { - // reuse the send channel if worldSize is 2 - smSemaphores.push_back(smSemaphores[0]); + comm_->setup(); + smSemaphores.push_back(std::make_shared(*comm_, sendConnFuture.get())); + smSemaphores.push_back(smSemaphores[0]); // reuse the send channel if worldSize is 2 } comm_->setup(); diff --git a/test/unit/core_tests.cc b/test/unit/core_tests.cc index 55a8ed10..90da5dd7 100644 --- a/test/unit/core_tests.cc +++ b/test/unit/core_tests.cc @@ -10,6 +10,7 @@ class LocalCommunicatorTest : public ::testing::Test { protected: void SetUp() override { bootstrap = std::make_shared(0, 1); + bootstrap->initialize(bootstrap->createUniqueId()); comm = std::make_shared(bootstrap); } @@ -36,18 +37,17 @@ TEST_F(LocalCommunicatorTest, RegisterMemory) { auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports); EXPECT_EQ(memory.data(), &dummy); EXPECT_EQ(memory.size(), sizeof(dummy)); - EXPECT_EQ(memory.rank(), 0); EXPECT_EQ(memory.transports(), mscclpp::NoTransports); } -// TEST_F(LocalCommunicatorTest, SendMemoryToSelf) { -// int dummy[42]; -// auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports); -// comm->sendMemoryOnSetup(memory, 0, 0); -// auto memoryFuture = comm->recvMemoryOnSetup(0, 0); -// comm->setup(); -// auto sameMemory = memoryFuture.get(); -// EXPECT_EQ(sameMemory.size(), memory.size()); -// EXPECT_EQ(sameMemory.rank(), memory.rank()); -// EXPECT_EQ(sameMemory.transports(), memory.transports()); -// } +TEST_F(LocalCommunicatorTest, SendMemoryToSelf) { + int dummy[42]; + auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports); + comm->sendMemoryOnSetup(memory, 0, 0); + auto memoryFuture = comm->recvMemoryOnSetup(0, 0); + comm->setup(); + auto sameMemory = memoryFuture.get(); + EXPECT_EQ(sameMemory.data(), memory.data()); + EXPECT_EQ(sameMemory.size(), memory.size()); + EXPECT_EQ(sameMemory.transports(), memory.transports()); +}