Skip to content

Commit

Permalink
Merge branch 'main' into v0.3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang authored Sep 6, 2023
2 parents afd4eb4 + 828be48 commit 1e2210a
Show file tree
Hide file tree
Showing 25 changed files with 626 additions and 327 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ install(TARGETS mscclpp_static

# Tests
if (BUILD_TESTS)
enable_testing() # Called here to allow ctest from the build directory
add_subdirectory(test)
endif()

Expand Down
218 changes: 162 additions & 56 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ class TcpBootstrap : public Bootstrap {
void barrier() override;

private:
/// Implementation class for @ref TcpBootstrap.
class Impl;
/// Pointer to the implementation class for @ref TcpBootstrap.
// The interal implementation.
struct Impl;

// Pointer to the internal implementation.
std::unique_ptr<Impl> pimpl_;
};

Expand Down Expand Up @@ -306,23 +307,15 @@ std::string getIBDeviceName(Transport ibTransport);
/// @return The InfiniBand transport associated with the specified device name.
Transport getIBTransportByDeviceName(const std::string& ibDeviceName);

class Communicator;
class Context;
class Connection;

/// Represents a block of memory that has been registered to a @ref Communicator.
/// Represents a block of memory that has been registered to a @ref Context.
class RegisteredMemory {
protected:
struct Impl;

public:
/// Default constructor.
RegisteredMemory() = default;

/// Constructor that takes a shared pointer to an implementation object.
///
/// @param pimpl A shared pointer to an implementation object.
RegisteredMemory(std::shared_ptr<Impl> pimpl);

/// Destructor.
~RegisteredMemory();

Expand All @@ -336,11 +329,6 @@ class RegisteredMemory {
/// @return The size of the memory block.
size_t size();

/// Get the rank of the process that owns the memory block.
///
/// @return The rank of the process that owns the memory block.
int rank();

/// Get the transport flags associated with the memory block.
///
/// @return The transport flags associated with the memory block.
Expand All @@ -357,14 +345,54 @@ class RegisteredMemory {
/// @return A deserialized RegisteredMemory object.
static RegisteredMemory deserialize(const std::vector<char>& data);

private:
// The interal implementation.
struct Impl;

// Internal constructor.
RegisteredMemory(std::shared_ptr<Impl> pimpl);

// Pointer to the internal implementation. A shared_ptr is used since RegisteredMemory is immutable.
std::shared_ptr<Impl> pimpl_;

friend class Context;
friend class Connection;
friend class IBConnection;
friend class Communicator;
};

/// Represents one end of a connection.
class Endpoint {
public:
/// Default constructor.
Endpoint() = default;

/// Get the transport used.
///
/// @return The transport used.
Transport transport();

/// Serialize the Endpoint object to a vector of characters.
///
/// @return A vector of characters representing the serialized Endpoint object.
std::vector<char> serialize();

/// Deserialize a Endpoint object from a vector of characters.
///
/// @param data A vector of characters representing a serialized Endpoint object.
/// @return A deserialized Endpoint object.
static Endpoint deserialize(const std::vector<char>& data);

private:
// A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated
// lazily.
std::shared_ptr<Impl> pimpl;
// The interal implementation.
struct Impl;

// Internal constructor.
Endpoint(std::shared_ptr<Impl> pimpl);

// Pointer to the internal implementation. A shared_ptr is used since Endpoint is immutable.
std::shared_ptr<Impl> pimpl_;

friend class Context;
friend class Connection;
};

/// Represents a connection between two processes.
Expand All @@ -391,16 +419,6 @@ class Connection {
/// Flush any pending writes to the remote process.
virtual void flush(int64_t timeoutUsec = 3e7) = 0;

/// Get the rank of the remote process.
///
/// @return The rank of the remote process.
virtual int remoteRank() = 0;

/// Get the tag associated with the connection.
///
/// @return The tag associated with the connection.
virtual int tag() = 0;

/// Get the transport used by the local process.
///
/// @return The transport used by the local process.
Expand All @@ -412,11 +430,89 @@ class Connection {
virtual Transport remoteTransport() = 0;

protected:
/// Get the implementation object associated with a @ref RegisteredMemory object.
// Internal methods for getting implementation pointers.
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
};

/// Used to configure an endpoint.
struct EndpointConfig {
static const int DefaultMaxCqSize = 1024;
static const int DefaultMaxCqPollNum = 1;
static const int DefaultMaxSendWr = 8192;
static const int DefaultMaxWrPerSend = 64;

Transport transport;
int ibMaxCqSize = DefaultMaxCqSize;
int ibMaxCqPollNum = DefaultMaxCqPollNum;
int ibMaxSendWr = DefaultMaxSendWr;
int ibMaxWrPerSend = DefaultMaxWrPerSend;

/// Default constructor. Sets transport to Transport::Unknown.
EndpointConfig() : transport(Transport::Unknown) {}

/// Constructor that takes a transport and sets the other fields to their default values.
///
/// @param memory The @ref RegisteredMemory object.
/// @return A shared pointer to the implementation object.
static std::shared_ptr<RegisteredMemory::Impl> getRegisteredMemoryImpl(RegisteredMemory& memory);
/// @param transport The transport to use.
EndpointConfig(Transport transport) : transport(transport) {}
};

/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
/// where the process group abstraction offered by @ref Communicator is not suitable, e.g., ephemeral client-server
/// connections. Correct use of this class requires external synchronization when finalizing connections with the
/// @ref connect() method.
///
/// As an example, a client-server scenario where the server will write to the client might proceed as follows:
/// 1. The client creates an endpoint with @ref createEndpoint() and sends it to the server.
/// 2. The server receives the client endpoint, creates its own endpoint with @ref createEndpoint(), sends it to the
/// client, and creates a connection with @ref connect().
/// 4. The client receives the server endpoint, creates a connection with @ref connect() and sends a
/// @ref RegisteredMemory to the server.
/// 5. The server receives the @ref RegisteredMemory and writes to it using the previously created connection.
/// The client waiting to create a connection before sending the @ref RegisteredMemory ensures that the server can not
/// write to the @ref RegisteredMemory before the connection is established.
///
/// While some transports may have more relaxed implementation behavior, this should not be relied upon.
class Context {
public:
/// Create a context.
Context();

/// Destroy the context.
~Context();

/// Register a region of GPU memory for use in this context.
///
/// @param ptr Base pointer to the memory.
/// @param size Size of the memory region in bytes.
/// @param transports Transport flags.
/// @return RegisteredMemory A handle to the buffer.
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);

/// Create an endpoint for establishing connections.
///
/// @param config The configuration for the endpoint.
/// @return The newly created endpoint.
Endpoint createEndpoint(EndpointConfig config);

/// Establish a connection between two endpoints. While this method immediately returns a connection object, the
/// connection is only safe to use after the corresponding connection on the remote endpoint has been established.
/// This method must be called on both endpoints to establish a connection.
///
/// @param localEndpoint The local endpoint.
/// @param remoteEndpoint The remote endpoint.
/// @return std::shared_ptr<Connection> A shared pointer to the connection.
std::shared_ptr<Connection> connect(Endpoint localEndpoint, Endpoint remoteEndpoint);

private:
// The interal implementation.
struct Impl;

// Pointer to the internal implementation.
std::unique_ptr<Impl> pimpl_;

friend class RegisteredMemory;
friend class Endpoint;
};

/// A base class for objects that can be set up during @ref Communicator::setup().
Expand Down Expand Up @@ -482,14 +578,12 @@ class NonblockingFuture {
/// 6. All done; use connections and registered memories to build channels.
///
class Communicator {
protected:
struct Impl;

public:
/// Initializes the communicator with a given bootstrap implementation.
///
/// @param bootstrap An implementation of the Bootstrap that the communicator will use.
Communicator(std::shared_ptr<Bootstrap> bootstrap);
/// @param context An optional context to use for the communicator. If not provided, a new context will be created.
Communicator(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context = nullptr);

/// Destroy the communicator.
~Communicator();
Expand All @@ -499,7 +593,12 @@ class Communicator {
/// @return std::shared_ptr<Bootstrap> The bootstrap held by this communicator.
std::shared_ptr<Bootstrap> bootstrap();

/// Register a region of GPU memory for use in this communicator.
/// Returns the context held by this communicator.
///
/// @return std::shared_ptr<Context> The context held by this communicator.
std::shared_ptr<Context> context();

/// Register a region of GPU memory for use in this communicator's context.
///
/// @param ptr Base pointer to the memory.
/// @param size Size of the memory region in bytes.
Expand Down Expand Up @@ -537,15 +636,22 @@ class Communicator {
///
/// @param remoteRank The rank of the remote process.
/// @param tag The tag of the connection for identifying it.
/// @param transport The type of transport to be used.
/// @param ibMaxCqSize The maximum number of completion queue entries for IB. Unused if transport is not IB.
/// @param ibMaxCqPollNum The maximum number of completion queue entries to poll for IB. Unused if transport is not
/// IB.
/// @param ibMaxSendWr The maximum number of outstanding send work requests for IB. Unused if transport is not IB.
/// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB.
/// @return std::shared_ptr<Connection> A shared pointer to the connection.
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024,
int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64);
/// @param config The configuration for the local endpoint.
/// @return NonblockingFuture<NonblockingFuture<std::shared_ptr<Connection>>> A non-blocking future of shared pointer
/// to the connection.
NonblockingFuture<std::shared_ptr<Connection>> connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig);

/// Get the remote rank a connection is connected to.
///
/// @param connection The connection to get the remote rank for.
/// @return The remote rank the connection is connected to.
int remoteRankOf(const Connection& connection);

/// Get the tag a connection was made with.
///
/// @param connection The connection to get the tag for.
/// @return The tag the connection was made with.
int tagOf(const Connection& connection);

/// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
///
Expand All @@ -559,12 +665,12 @@ class Communicator {
/// that have been registered after the (n-1)-th call.
void setup();

friend class RegisteredMemory::Impl;
friend class IBConnection;

private:
/// Unique pointer to the implementation of the Communicator class.
std::unique_ptr<Impl> pimpl;
// The interal implementation.
struct Impl;

// Pointer to the internal implementation.
std::unique_ptr<Impl> pimpl_;
};

/// A constant TransportFlags object representing no transports.
Expand Down
8 changes: 5 additions & 3 deletions python/examples/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
remote_memories.append(remote_mem)
comm.setup()

connections = [conn.get() for conn in connections]

# Create simple proxy channels
for i, conn in enumerate(connections):
proxy_channel = mscclpp.SimpleProxyChannel(
proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(conn)),
proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)),
proxy_service.add_memory(remote_memories[i].get()),
proxy_service.add_memory(reg_mem),
)
simple_proxy_channels.append(mscclpp.device_handle(proxy_channel))
simple_proxy_channels.append(proxy_channel.device_handle())
comm.setup()

# Create sm channels
Expand All @@ -66,7 +68,7 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
for i, conn in enumerate(sm_semaphores):
sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr)
sm_channels.append(sm_chan)
return simple_proxy_channels, [mscclpp.device_handle(sm_chan) for sm_chan in sm_channels]
return simple_proxy_channels, [sm_chan.device_handle() for sm_chan in sm_channels]


def run(rank, args):
Expand Down
Loading

0 comments on commit 1e2210a

Please sign in to comment.