From cc08b232c211afcbdc16667aa0bcc0b05f75b819 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 31 Aug 2023 17:45:58 +0000 Subject: [PATCH] Get rid of comm.setup() --- include/mscclpp/core.hpp | 105 ++++++------------------- include/mscclpp/semaphore.hpp | 2 +- python/mscclpp/core_py.cpp | 33 ++++---- src/bootstrap/bootstrap.cc | 100 +++++++++++++++++++---- src/communicator.cc | 100 +++++------------------ src/endpoint.cc | 4 +- src/registered_memory.cc | 6 +- src/semaphore.cc | 14 ++-- test/allgather_test_cpp.cu | 14 ++-- test/allgather_test_host_offloading.cu | 14 ++-- test/mp_unit/bootstrap_tests.cc | 44 ++++++----- test/mp_unit/communicator_tests.cu | 20 ++--- test/mp_unit/ib_tests.cu | 4 +- test/mp_unit/proxy_channel_tests.cu | 22 +++--- test/mp_unit/sm_channel_tests.cu | 22 +++--- test/mscclpp-test/allgather_test.cu | 3 +- test/mscclpp-test/common.cc | 33 +++----- test/mscclpp-test/common.hpp | 14 ++-- test/mscclpp-test/sendrecv_test.cu | 14 ++-- test/unit/core_tests.cc | 8 -- 20 files changed, 243 insertions(+), 333 deletions(-) diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index ea7b1460..664b84ca 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -29,15 +29,15 @@ class Bootstrap { public: Bootstrap(){}; virtual ~Bootstrap() = default; - virtual int getRank() = 0; - virtual int getNranks() = 0; + virtual int rank() = 0; + virtual int size() = 0; virtual void send(void* data, int size, int peer, int tag) = 0; - virtual void recv(void* data, int size, int peer, int tag) = 0; + [[nodiscard]] virtual std::future recv(void* data, int size, int peer, int tag) = 0; virtual void allGather(void* allData, int size) = 0; virtual void barrier() = 0; void send(const std::vector& data, int peer, int tag); - void recv(std::vector& data, int peer, int tag); + std::future> recv(int peer, int tag); }; /// A native implementation of the bootstrap using TCP sockets. @@ -70,10 +70,10 @@ class TcpBootstrap : public Bootstrap { void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec = 30); /// Return the rank of the process. - int getRank() override; + int rank() override; /// Return the total number of ranks. - int getNranks() override; + int size() override; /// Send data to another process. /// @@ -95,7 +95,8 @@ class TcpBootstrap : public Bootstrap { /// @param size The size of the data to receive. /// @param peer The rank of the process to receive the data from. /// @param tag The tag to receive the data with. - void recv(void* data, int size, int peer, int tag) override; + /// @return A future that will be ready when the data has been received. + [[nodiscard]] std::future recv(void* data, int size, int peer, int tag) override; /// Gather data from all processes. /// @@ -324,17 +325,17 @@ class RegisteredMemory { /// Get the size of the memory block. /// /// @return The size of the memory block. - size_t size(); + size_t size() const; /// Get the transport flags associated with the memory block. /// /// @return The transport flags associated with the memory block. - TransportFlags transports(); + TransportFlags transports() const; /// Serialize the RegisteredMemory object to a vector of characters. /// /// @return A vector of characters representing the serialized RegisteredMemory object. - std::vector serialize(); + std::vector serialize() const; /// Deserialize a RegisteredMemory object from a vector of characters. /// @@ -365,12 +366,12 @@ class Endpoint { /// Get the transport used. /// /// @return The transport used. - Transport transport(); + Transport transport() const; /// Serialize the Endpoint object to a vector of characters. /// /// @return A vector of characters representing the serialized Endpoint object. - std::vector serialize(); + std::vector serialize() const; /// Deserialize a Endpoint object from a vector of characters. /// @@ -527,50 +528,14 @@ struct Setuppable { virtual void endSetup(std::shared_ptr bootstrap); }; -/// A non-blocking future that can be used to check if a value is ready and retrieve it. -template -class NonblockingFuture { - std::shared_future future; - - public: - /// Default constructor. - NonblockingFuture() = default; - - /// Constructor that takes a shared future and moves it into the NonblockingFuture. - /// - /// @param future The shared future to move. - NonblockingFuture(std::shared_future&& future) : future(std::move(future)) {} - - /// Copy constructor. - /// - /// @param other The @ref NonblockingFuture to copy. - NonblockingFuture(const NonblockingFuture& other) = default; - - /// Check if the value is ready to be retrieved. - /// - /// @return True if the value is ready, false otherwise. - bool ready() const { return future.wait_for(std::chrono::seconds(0)) == std::future_status::ready; } - - /// Get the value. - /// - /// @return The value. - /// - /// @throws Error if the value is not ready. - T get() const { - if (!ready()) throw Error("NonblockingFuture::get() called before ready", ErrorCode::InvalidUsage); - return future.get(); - } -}; - /// A class that sets up all registered memories and connections between processes. /// /// A typical way to use this class: -/// 1. Call @ref connectOnSetup() to declare connections between the calling process with other processes. +/// 1. Call @ref connect() to declare connections between the calling process with other processes. /// 2. Call @ref registerMemory() to register memory regions that will be used for communication. -/// 3. Call @ref sendMemoryOnSetup() or @ref recvMemoryOnSetup() to send/receive registered memory regions to/from +/// 3. Call @ref sendMemory() or @ref recvMemory() to send/receive registered memory regions to/from /// other processes. -/// 4. Call @ref setup() to set up all registered memories and connections declared in the previous steps. -/// 5. Call @ref NonblockingFuture::get() to get the registered memory regions received from other +/// 5. Call @ref std::future::get() to get the registered memory regions received from other /// processes. /// 6. All done; use connections and registered memories to build channels. /// @@ -603,30 +568,23 @@ class Communicator { /// @return RegisteredMemory A handle to the buffer. RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports); - /// Send information of a registered memory to the remote side on setup. - /// - /// This function registers a send to a remote process that will happen by a following call of @ref setup(). The send - /// will carry information about a registered memory on the local process. + /// Send information of a registered memory to the remote side. /// /// @param memory The registered memory buffer to send information about. /// @param remoteRank The rank of the remote process. /// @param tag The tag to use for identifying the send. - void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag); + void sendMemory(RegisteredMemory memory, int remoteRank, int tag); - /// Receive memory on setup. - /// - /// This function registers a receive from a remote process that will happen by a following call of @ref setup(). The - /// receive will carry information about a registered memory on the remote process. + /// Receive memory. /// /// @param remoteRank The rank of the remote process. /// @param tag The tag to use for identifying the receive. - /// @return NonblockingFuture A non-blocking future of registered memory. - NonblockingFuture recvMemoryOnSetup(int remoteRank, int tag); + /// @return std::future A future of registered memory. + std::future recvMemory(int remoteRank, int tag); - /// Connect to a remote rank on setup. + /// Connect to a remote rank. /// - /// This function only prepares metadata for connection. The actual connection is made by a following call of - /// @ref setup(). Note that this function is two-way and a connection from rank `i` to remote rank `j` needs + /// Note that this function is two-way and a connection from rank `i` to remote rank `j` needs /// to have a counterpart from rank `j` to rank `i`. Note that with IB, buffers are registered at a page level and if /// a buffer is spread through multiple pages and do not fully utilize all of them, IB's QP has to register for all /// involved pages. This potentially has security risks if the connection's accesses are given to a malicious process. @@ -634,9 +592,8 @@ class Communicator { /// @param remoteRank The rank of the remote process. /// @param tag The tag of the connection for identifying it. /// @param config The configuration for the local endpoint. - /// @return NonblockingFuture>> A non-blocking future of shared pointer - /// to the connection. - NonblockingFuture> connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig); + /// @return std::future> A future of shared pointer to the connection. + std::future> connect(int remoteRank, int tag, EndpointConfig localConfig); /// Get the remote rank a connection is connected to. /// @@ -650,18 +607,6 @@ class Communicator { /// @return The tag the connection was made with. int tagOf(const Connection& connection); - /// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called. - /// - /// @param setuppable A shared pointer to the Setuppable object. - void onSetup(std::shared_ptr setuppable); - - /// Setup all objects that have registered for setup. - /// - /// This includes previous calls of @ref sendMemoryOnSetup(), @ref recvMemoryOnSetup(), @ref connectOnSetup(), and - /// @ref onSetup(). It is allowed to call this function multiple times, where the n-th call will only setup objects - /// that have been registered after the (n-1)-th call. - void setup(); - private: // The interal implementation. struct Impl; diff --git a/include/mscclpp/semaphore.hpp b/include/mscclpp/semaphore.hpp index 9f73082e..028daeae 100644 --- a/include/mscclpp/semaphore.hpp +++ b/include/mscclpp/semaphore.hpp @@ -30,7 +30,7 @@ template