From d4ede480f48fc28cd1b930dd3439dc47390fd953 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 25 Apr 2024 11:06:43 -0700 Subject: [PATCH] Ethernet support (#284) Co-authored-by: Binyang Li Co-authored-by: Caio Rocha --- include/mscclpp/core.hpp | 36 ++++--- src/bootstrap/socket.cc | 32 ++++++ src/connection.cc | 145 ++++++++++++++++++++++++++++ src/context.cc | 6 ++ src/core.cc | 2 +- src/endpoint.cc | 19 ++++ src/include/connection.hpp | 33 +++++++ src/include/endpoint.hpp | 9 ++ src/include/socket.h | 1 + src/registered_memory.cc | 2 + test/mp_unit/communicator_tests.cu | 12 ++- test/mp_unit/mp_unit_tests.hpp | 18 +++- test/mp_unit/proxy_channel_tests.cu | 78 ++++++++++----- 13 files changed, 341 insertions(+), 52 deletions(-) diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index ffbde7bf8..033e95952 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -130,25 +130,26 @@ class TcpBootstrap : public Bootstrap { /// Enumerates the available transport types. enum class Transport { - Unknown, // Unknown transport type. - CudaIpc, // CUDA IPC transport type. - Nvls, // NVLS transport type. - IB0, // InfiniBand device 0 transport type. - IB1, // InfiniBand device 1 transport type. - IB2, // InfiniBand device 2 transport type. - IB3, // InfiniBand device 3 transport type. - IB4, // InfiniBand device 4 transport type. - IB5, // InfiniBand device 5 transport type. - IB6, // InfiniBand device 6 transport type. - IB7, // InfiniBand device 7 transport type. - NumTransports // The number of transports. + Unknown, // Unknown transport type. + CudaIpc, // CUDA IPC transport type. + Nvls, // NVLS transport type. + IB0, // InfiniBand device 0 transport type. + IB1, // InfiniBand device 1 transport type. + IB2, // InfiniBand device 2 transport type. + IB3, // InfiniBand device 3 transport type. + IB4, // InfiniBand device 4 transport type. + IB5, // InfiniBand device 5 transport type. + IB6, // InfiniBand device 6 transport type. + IB7, // InfiniBand device 7 transport type. + Ethernet, // Ethernet transport type. + NumTransports, // The number of transports. }; -const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", - "IB3", "IB4", "IB5", "IB6", "IB7", "NUM"}; +const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2", "IB3", + "IB4", "IB5", "IB6", "IB7", "ETH", "NUM"}; namespace detail { -const size_t TransportFlagsSize = 11; +const size_t TransportFlagsSize = 12; static_assert(TransportFlagsSize == static_cast(Transport::NumTransports), "TransportFlagsSize must match the number of transports"); /// Bitset for storing transport flags. @@ -336,6 +337,11 @@ class RegisteredMemory { /// @return A pointer to the memory block. void* data() const; + /// Get a pointer to the original memory block. + /// + /// @return A pointer to the original memory block. + void* originalDataPtr() const; + /// Get the size of the memory block. /// /// @return The size of the memory block. diff --git a/src/bootstrap/socket.cc b/src/bootstrap/socket.cc index a79821f1b..9e5913403 100644 --- a/src/bootstrap/socket.cc +++ b/src/bootstrap/socket.cc @@ -543,6 +543,38 @@ void Socket::recv(void* ptr, int size) { socketWait(MSCCLPP_SOCKET_RECV, ptr, size, &offset); } +void Socket::recvUntilEnd(void* ptr, int size, int* closed) { + int offset = 0; + *closed = 0; + if (state_ != SocketStateReady) { + std::stringstream ss; + ss << "socket state (" << state_ << ") is not ready in recvUntilEnd"; + throw Error(ss.str(), ErrorCode::InternalError); + } + + int bytes = 0; + char* data = (char*)ptr; + + do { + bytes = ::recv(fd_, data + (offset), size - (offset), 0); + if (bytes == 0) { + *closed = 1; + return; + } + if (bytes == -1) { + if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN && state_ != SocketStateClosed) { + throw SysError("recv until end failed", errno); + } else { + bytes = 0; + } + } + (offset) += bytes; + if (abortFlag_ && *abortFlag_ != 0) { + throw Error("aborted", ErrorCode::Aborted); + } + } while (bytes > 0 && (offset) < size); +} + void Socket::close() { if (fd_ >= 0) ::close(fd_); state_ = SocketStateClosed; diff --git a/src/connection.cc b/src/connection.cc index 65b76b33f..6e01367f6 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -5,6 +5,7 @@ #include #include +#include #include "debug.h" #include "endpoint.hpp" @@ -180,4 +181,148 @@ void IBConnection::flush(int64_t timeoutUsec) { // npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); } +// EthernetConnection + +EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize, + uint64_t recvBufferSize) + : abortFlag_(0), sendBufferSize_(sendBufferSize), recvBufferSize_(recvBufferSize) { + // Validating Transport Protocol + if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) { + throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage); + } + + // Instanciating Buffers + sendBuffer_.resize(sendBufferSize_); + recvBuffer_.resize(recvBufferSize_); + + // Creating Thread to Accept the Connection + auto parameter = (getImpl(localEndpoint)->socket_).get(); + std::thread t([this, parameter]() { + recvSocket_ = std::make_unique(nullptr, MSCCLPP_SOCKET_MAGIC, SocketTypeUnknown, abortFlag_); + recvSocket_->accept(parameter); + }); + + // Starting Connection + sendSocket_ = std::make_unique(&(getImpl(remoteEndpoint)->socketAddress_), MSCCLPP_SOCKET_MAGIC, + SocketTypeBootstrap, abortFlag_); + sendSocket_->connect(); + + // Ensure the Connection was Established + t.join(); + + // Starting Thread to Receive Messages + threadRecvMessages_ = std::thread(&EthernetConnection::recvMessages, this); + + INFO(MSCCLPP_NET, "Ethernet connection created"); +} + +EthernetConnection::~EthernetConnection() { + sendSocket_->close(); + recvSocket_->close(); + threadRecvMessages_.join(); +} + +Transport EthernetConnection::transport() { return Transport::Ethernet; } + +Transport EthernetConnection::remoteTransport() { return Transport::Ethernet; } + +void EthernetConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) { + // Validating Transport Protocol + validateTransport(dst, remoteTransport()); + validateTransport(src, transport()); + + // Initializing Variables + char* srcPtr = reinterpret_cast(src.data()) + srcOffset / sizeof(char); + char* dstPtr = reinterpret_cast(dst.originalDataPtr()) + dstOffset / sizeof(char); + uint64_t sentDataSize = 0; + uint64_t headerSize = 0; + + // Copying Meta Data to Send Buffer + char* dstPtrBytes = reinterpret_cast(&dstPtr); + std::copy(dstPtrBytes, dstPtrBytes + sizeof(dstPtr), sendBuffer_.data() + headerSize / sizeof(char)); + headerSize += sizeof(dstPtr); + char* sizeBytes = reinterpret_cast(&size); + std::copy(sizeBytes, sizeBytes + sizeof(size), sendBuffer_.data() + headerSize / sizeof(char)); + headerSize += sizeof(size); + + // Getting Data From GPU and Sending Message + while (sentDataSize < size) { + uint64_t dataSize = + std::min(sendBufferSize_ - headerSize / sizeof(char), (size - sentDataSize) / sizeof(char)) * sizeof(char); + uint64_t messageSize = dataSize + headerSize; + mscclpp::memcpyCuda(sendBuffer_.data() + headerSize / sizeof(char), + (char*)srcPtr + (sentDataSize / sizeof(char)), dataSize, cudaMemcpyDeviceToHost); + sendSocket_->send(sendBuffer_.data(), messageSize); + sentDataSize += messageSize; + headerSize = 0; + } + + INFO(MSCCLPP_NET, "EthernetConnection write: from %p to %p, size %lu", srcPtr, dstPtr, size); +} + +void EthernetConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) { + // Validating Transport Protocol + validateTransport(dst, remoteTransport()); + + // Initializing Variables + uint64_t oldValue = *src; + uint64_t* dstPtr = reinterpret_cast(reinterpret_cast(dst.originalDataPtr()) + dstOffset); + uint64_t dataSize = sizeof(uint64_t); + uint64_t messageSize = 0; + *src = newValue; + + // Copying Data to Send Buffer + char* dstPtrBytes = reinterpret_cast(&dstPtr); + std::copy(dstPtrBytes, dstPtrBytes + sizeof(dstPtr), sendBuffer_.data() + messageSize / sizeof(char)); + messageSize += sizeof(dstPtr); + char* sizeBytes = reinterpret_cast(&dataSize); + std::copy(sizeBytes, sizeBytes + sizeof(dataSize), sendBuffer_.data() + messageSize / sizeof(char)); + messageSize += sizeof(dataSize); + char* dataBytes = reinterpret_cast(src); + std::copy(dataBytes, dataBytes + dataSize, sendBuffer_.data() + messageSize / sizeof(char)); + messageSize += dataSize; + + // Sending Message + sendSocket_->send(sendBuffer_.data(), messageSize); + + INFO(MSCCLPP_NET, "EthernetConnection atomic write: from %p to %p, %lu -> %lu", src, dstPtr + dstOffset, oldValue, + newValue); +} + +void EthernetConnection::flush(int64_t timeoutUsec) { INFO(MSCCLPP_NET, "EthernetConnection flushing connection"); } + +void EthernetConnection::recvMessages() { + // Declarating Variables + char* ptr; + uint64_t size; + uint64_t recvSize; + int closed = 0; + bool received = true; + + // Receiving Messages Until Connection is Closed + while (recvSocket_->getState() != SocketStateClosed) { + // Receiving Data Address + if (closed == 0) recvSocket_->recvUntilEnd(&ptr, sizeof(char*), &closed); + received &= !closed; + + // Receiving data size + if (closed == 0) recvSocket_->recvUntilEnd(&size, sizeof(uint64_t), &closed); + received &= !closed; + + // Receiving Data and Copying Data yo GPU + recvSize = 0; + while (recvSize < size && closed == 0) { + uint64_t messageSize = std::min(recvBufferSize_, (size - recvSize) / sizeof(char)) * sizeof(char); + recvSocket_->recvUntilEnd(recvBuffer_.data(), messageSize, &closed); + received &= !closed; + + if (received) + mscclpp::memcpyCuda((char*)ptr + (recvSize / sizeof(char)), recvBuffer_.data(), messageSize, + cudaMemcpyHostToDevice); + recvSize += messageSize; + } + } +} + } // namespace mscclpp diff --git a/src/context.cc b/src/context.cc index d04a8e32c..f8bb3ec83 100644 --- a/src/context.cc +++ b/src/context.cc @@ -49,9 +49,15 @@ MSCCLPP_API_CPP std::shared_ptr Context::connect(Endpoint localEndpo throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage); } conn = std::make_shared(localEndpoint, remoteEndpoint, *this); + } else if (localEndpoint.transport() == Transport::Ethernet) { + if (remoteEndpoint.transport() != Transport::Ethernet) { + throw mscclpp::Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage); + } + conn = std::make_shared(localEndpoint, remoteEndpoint); } else { throw mscclpp::Error("Unsupported transport", ErrorCode::InternalError); } + pimpl_->connections_.push_back(conn); return conn; } diff --git a/src/core.cc b/src/core.cc index 4d89250d0..32b67c3d2 100644 --- a/src/core.cc +++ b/src/core.cc @@ -87,7 +87,7 @@ const TransportFlags NoTransports = TransportFlags(); const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 | Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7; -const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc; +const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc | Transport::Ethernet; void Setuppable::beginSetup(std::shared_ptr) {} diff --git a/src/endpoint.cc b/src/endpoint.cc index dbc773898..68c2726de 100644 --- a/src/endpoint.cc +++ b/src/endpoint.cc @@ -4,6 +4,7 @@ #include "api.h" #include "context.hpp" +#include "socket.h" #include "utils_internal.hpp" namespace mscclpp { @@ -15,6 +16,16 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl) ibQp_ = contextImpl.getIbContext(transport_) ->createQp(config.ibMaxCqSize, config.ibMaxCqPollNum, config.ibMaxSendWr, 0, config.ibMaxWrPerSend); ibQpInfo_ = ibQp_->getInfo(); + } else if (transport_ == Transport::Ethernet) { + // Configuring Ethernet Interfaces + abortFlag_ = 0; + int ret = FindInterfaces(netIfName_, &socketAddress_, MAX_IF_NAME_SIZE, 1, ""); + if (ret <= 0) throw Error("NET/Socket", ErrorCode::InternalError); + + // Starting Server Socket + socket_ = std::make_unique(&socketAddress_, MSCCLPP_SOCKET_MAGIC, SocketTypeBootstrap, abortFlag_); + socket_->bindAndListen(); + socketAddress_ = socket_->getAddr(); } } @@ -27,6 +38,10 @@ MSCCLPP_API_CPP std::vector Endpoint::serialize() { if (AllIBTransports.has(pimpl_->transport_)) { std::copy_n(reinterpret_cast(&pimpl_->ibQpInfo_), sizeof(pimpl_->ibQpInfo_), std::back_inserter(data)); } + if ((pimpl_->transport_) == Transport::Ethernet) { + std::copy_n(reinterpret_cast(&pimpl_->socketAddress_), sizeof(pimpl_->socketAddress_), + std::back_inserter(data)); + } return data; } @@ -45,6 +60,10 @@ Endpoint::Impl::Impl(const std::vector& serialization) { std::copy_n(it, sizeof(ibQpInfo_), reinterpret_cast(&ibQpInfo_)); it += sizeof(ibQpInfo_); } + if (transport_ == Transport::Ethernet) { + std::copy_n(it, sizeof(socketAddress_), reinterpret_cast(&socketAddress_)); + it += sizeof(socketAddress_); + } } MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr pimpl) : pimpl_(pimpl) {} diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 47b154758..283bb8d07 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -11,6 +11,7 @@ #include "context.hpp" #include "ib.hpp" #include "registered_memory.hpp" +#include "socket.h" namespace mscclpp { @@ -53,6 +54,38 @@ class IBConnection : public Connection { void flush(int64_t timeoutUsec) override; }; +class EthernetConnection : public Connection { + std::unique_ptr sendSocket_; + std::unique_ptr recvSocket_; + std::thread threadRecvMessages_; + volatile uint32_t* abortFlag_; + const uint64_t sendBufferSize_; + const uint64_t recvBufferSize_; + std::vector sendBuffer_; + std::vector recvBuffer_; + + public: + EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize = 256 * 1024 * 1024, + uint64_t recvBufferSize = 256 * 1024 * 1024); + + ~EthernetConnection(); + + Transport transport() override; + + Transport remoteTransport() override; + + void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) override; + void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override; + + void flush(int64_t timeoutUsec) override; + + private: + void recvMessages(); + + void sendMessage(); +}; + } // namespace mscclpp #endif // MSCCLPP_CONNECTION_HPP_ diff --git a/src/include/endpoint.hpp b/src/include/endpoint.hpp index 311fa9982..734a6c1bd 100644 --- a/src/include/endpoint.hpp +++ b/src/include/endpoint.hpp @@ -8,6 +8,9 @@ #include #include "ib.hpp" +#include "socket.h" + +#define MAX_IF_NAME_SIZE 16 namespace mscclpp { @@ -22,6 +25,12 @@ struct Endpoint::Impl { bool ibLocal_; IbQp* ibQp_; IbQpInfo ibQpInfo_; + + // The following are only used for Ethernet and are undefined for other transports. + std::unique_ptr socket_; + SocketAddress socketAddress_; + volatile uint32_t* abortFlag_; + char netIfName_[MAX_IF_NAME_SIZE + 1]; }; } // namespace mscclpp diff --git a/src/include/socket.h b/src/include/socket.h index ed125c990..77cdfa61a 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -69,6 +69,7 @@ class Socket { void accept(const Socket* listenSocket, int64_t timeout = -1); void send(void* ptr, int size); void recv(void* ptr, int size); + void recvUntilEnd(void* ptr, int size, int* closed); void close(); int getFd() const { return fd_; } diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 6d5fd79f5..0702c497b 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -62,6 +62,8 @@ MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default; MSCCLPP_API_CPP void* RegisteredMemory::data() const { return pimpl_->data; } +MSCCLPP_API_CPP void* RegisteredMemory::originalDataPtr() const { return pimpl_->originalDataPtr; } + MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl_->size; } MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl_->transports; } diff --git a/test/mp_unit/communicator_tests.cu b/test/mp_unit/communicator_tests.cu index 30727667d..adb6b5df6 100644 --- a/test/mp_unit/communicator_tests.cu +++ b/test/mp_unit/communicator_tests.cu @@ -42,14 +42,16 @@ void CommunicatorTestBase::TearDown() { void CommunicatorTestBase::setNumRanksToUse(int num) { numRanksToUse = num; } -void CommunicatorTestBase::connectMesh(bool useIbOnly) { +void CommunicatorTestBase::connectMesh(bool useIpc, bool useIb, bool useEthernet) { std::vector>> connectionFutures(numRanksToUse); for (int i = 0; i < numRanksToUse; i++) { if (i != gEnv->rank) { - if ((rankToNode(i) == rankToNode(gEnv->rank)) && !useIbOnly) { + if ((rankToNode(i) == rankToNode(gEnv->rank)) && useIpc) { connectionFutures[i] = communicator->connectOnSetup(i, 0, mscclpp::Transport::CudaIpc); - } else { + } else if (useIb) { connectionFutures[i] = communicator->connectOnSetup(i, 0, ibTransport); + } else if (useEthernet) { + connectionFutures[i] = communicator->connectOnSetup(i, 0, mscclpp::Transport::Ethernet); } } } @@ -97,7 +99,7 @@ void CommunicatorTest::SetUp() { ASSERT_EQ((deviceBufferSize / sizeof(int)) % gEnv->worldSize, 0); - connectMesh(); + connectMesh(true, true, false); devicePtr.resize(numBuffers); localMemory.resize(numBuffers); @@ -281,4 +283,4 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) { ASSERT_TRUE(testWriteCorrectness()); communicator->bootstrap()->barrier(); -} +} \ No newline at end of file diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index 6cb159c67..e13a05104 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -93,7 +93,7 @@ class CommunicatorTestBase : public MultiProcessTest { void TearDown() override; void setNumRanksToUse(int num); - void connectMesh(bool useIbOnly = false); + void connectMesh(bool useIpc = true, bool useIb = true, bool useEthernet = false); // Register a local memory and receive corresponding remote memories void registerMemoryPairs(void* buff, size_t buffSize, mscclpp::TransportFlags transport, int tag, @@ -130,13 +130,21 @@ using DeviceHandle = mscclpp::DeviceHandle; class ProxyChannelOneToOneTest : public CommunicatorTestBase { protected: + struct PingPongTestParams { + bool useIPC; + bool useIB; + bool useEthernet; + bool waitWithPoll; + }; + void SetUp() override; void TearDown() override; - void setupMeshConnections(std::vector& proxyChannels, bool useIbOnly, void* sendBuff, - size_t sendBuffBytes, void* recvBuff = nullptr, size_t recvBuffBytes = 0); - void testPingPong(bool useIbOnly, bool waitWithPoll); - void testPingPongPerf(bool useIbOnly, bool waitWithPoll); + void setupMeshConnections(std::vector& proxyChannels, bool useIPC, bool useIb, + bool useEthernet, void* sendBuff, size_t sendBuffBytes, void* recvBuff = nullptr, + size_t recvBuffBytes = 0); + void testPingPong(PingPongTestParams params); + void testPingPongPerf(PingPongTestParams params); void testPacketPingPong(bool useIbOnly); void testPacketPingPongPerf(bool useIbOnly); diff --git a/test/mp_unit/proxy_channel_tests.cu b/test/mp_unit/proxy_channel_tests.cu index 796a565d4..75858b631 100644 --- a/test/mp_unit/proxy_channel_tests.cu +++ b/test/mp_unit/proxy_channel_tests.cu @@ -16,12 +16,16 @@ void ProxyChannelOneToOneTest::SetUp() { void ProxyChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); } void ProxyChannelOneToOneTest::setupMeshConnections(std::vector& proxyChannels, - bool useIbOnly, void* sendBuff, size_t sendBuffBytes, - void* recvBuff, size_t recvBuffBytes) { + bool useIPC, bool useIb, bool useEthernet, void* sendBuff, + size_t sendBuffBytes, void* recvBuff, size_t recvBuffBytes) { const int rank = communicator->bootstrap()->getRank(); const int worldSize = communicator->bootstrap()->getNranks(); const bool isInPlace = (recvBuff == nullptr); - mscclpp::TransportFlags transport = (useIbOnly) ? ibTransport : (mscclpp::Transport::CudaIpc | ibTransport); + mscclpp::TransportFlags transport; + + if (useIPC) transport |= mscclpp::Transport::CudaIpc; + if (useIb) transport |= ibTransport; + if (useEthernet) transport |= mscclpp::Transport::Ethernet; std::vector>> connectionFutures(worldSize); std::vector> remoteMemFutures(worldSize); @@ -36,10 +40,12 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vectorrank)) && !useIbOnly) { + if ((rankToNode(r) == rankToNode(gEnv->rank)) && useIPC) { connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::CudaIpc); - } else { + } else if (useIb) { connectionFutures[r] = communicator->connectOnSetup(r, 0, ibTransport); + } else if (useEthernet) { + connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::Ethernet); } if (isInPlace) { @@ -145,14 +151,14 @@ __global__ void kernelProxyPingPong(int* buff, int rank, int nElem, bool waitWit } } -void ProxyChannelOneToOneTest::testPingPong(bool useIbOnly, bool waitWithPoll) { +void ProxyChannelOneToOneTest::testPingPong(PingPongTestParams params) { if (gEnv->rank >= numRanksToUse) return; const int nElem = 4 * 1024 * 1024; std::vector proxyChannels; std::shared_ptr buff = mscclpp::allocExtSharedCuda(nElem); - setupMeshConnections(proxyChannels, useIbOnly, buff.get(), nElem * sizeof(int)); + setupMeshConnections(proxyChannels, params.useIPC, params.useIB, params.useEthernet, buff.get(), nElem * sizeof(int)); std::vector> proxyChannelHandles; for (auto& ch : proxyChannels) proxyChannelHandles.push_back(ch.deviceHandle()); @@ -167,22 +173,22 @@ void ProxyChannelOneToOneTest::testPingPong(bool useIbOnly, bool waitWithPoll) { const int nTries = 1000; - kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, waitWithPoll, nTries, ret.get()); + kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, params.waitWithPoll, nTries, ret.get()); MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); EXPECT_EQ(*ret, 0); - kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, waitWithPoll, nTries, ret.get()); + kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, params.waitWithPoll, nTries, ret.get()); MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); EXPECT_EQ(*ret, 0); - kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, waitWithPoll, nTries, ret.get()); + kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024 * 1024, params.waitWithPoll, nTries, ret.get()); MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); EXPECT_EQ(*ret, 0); - kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, waitWithPoll, nTries, ret.get()); + kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 4 * 1024 * 1024, params.waitWithPoll, nTries, ret.get()); MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); EXPECT_EQ(*ret, 0); @@ -190,14 +196,14 @@ void ProxyChannelOneToOneTest::testPingPong(bool useIbOnly, bool waitWithPoll) { proxyService->stopProxy(); } -void ProxyChannelOneToOneTest::testPingPongPerf(bool useIbOnly, bool waitWithPoll) { +void ProxyChannelOneToOneTest::testPingPongPerf(PingPongTestParams params) { if (gEnv->rank >= numRanksToUse) return; const int nElem = 4 * 1024 * 1024; std::vector proxyChannels; std::shared_ptr buff = mscclpp::allocExtSharedCuda(nElem); - setupMeshConnections(proxyChannels, useIbOnly, buff.get(), nElem * sizeof(int)); + setupMeshConnections(proxyChannels, params.useIPC, params.useIB, params.useEthernet, buff.get(), nElem * sizeof(int)); std::vector> proxyChannelHandles; for (auto& ch : proxyChannels) proxyChannelHandles.push_back(ch.deviceHandle()); @@ -212,17 +218,17 @@ void ProxyChannelOneToOneTest::testPingPongPerf(bool useIbOnly, bool waitWithPol auto* testInfo = ::testing::UnitTest::GetInstance()->current_test_info(); const std::string testName = std::string(testInfo->test_suite_name()) + "." + std::string(testInfo->name()); - const int nTries = 1000000; + const int nTries = 1000; // Warm-up - kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, waitWithPoll, nTries, ret.get()); + kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, params.waitWithPoll, nTries, ret.get()); MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); communicator->bootstrap()->barrier(); // Measure latency mscclpp::Timer timer; - kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, waitWithPoll, nTries, ret.get()); + kernelProxyPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, params.waitWithPoll, nTries, ret.get()); MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); communicator->bootstrap()->barrier(); @@ -234,17 +240,37 @@ void ProxyChannelOneToOneTest::testPingPongPerf(bool useIbOnly, bool waitWithPol proxyService->stopProxy(); } -TEST_F(ProxyChannelOneToOneTest, PingPong) { testPingPong(false, false); } +TEST_F(ProxyChannelOneToOneTest, PingPong) { + testPingPong(PingPongTestParams{.useIPC = true, .useIB = true, .useEthernet = false, .waitWithPoll = false}); +} + +TEST_F(ProxyChannelOneToOneTest, PingPongIb) { + testPingPong(PingPongTestParams{.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false}); +} + +TEST_F(ProxyChannelOneToOneTest, PingPongEthernet) { + testPingPong(PingPongTestParams{.useIPC = false, .useIB = false, .useEthernet = true, .waitWithPoll = false}); +} -TEST_F(ProxyChannelOneToOneTest, PingPongIb) { testPingPong(true, false); } +TEST_F(ProxyChannelOneToOneTest, PingPongWithPoll) { + testPingPong(PingPongTestParams{.useIPC = true, .useIB = true, .useEthernet = false, .waitWithPoll = true}); +} -TEST_F(ProxyChannelOneToOneTest, PingPongWithPoll) { testPingPong(false, true); } +TEST_F(ProxyChannelOneToOneTest, PingPongIbWithPoll) { + testPingPong(PingPongTestParams{.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = true}); +} -TEST_F(ProxyChannelOneToOneTest, PingPongIbWithPoll) { testPingPong(true, true); } +TEST_F(ProxyChannelOneToOneTest, PingPongPerf) { + testPingPongPerf(PingPongTestParams{.useIPC = true, .useIB = true, .useEthernet = false, .waitWithPoll = false}); +} -TEST_F(ProxyChannelOneToOneTest, PingPongPerf) { testPingPongPerf(false, false); } +TEST_F(ProxyChannelOneToOneTest, PingPongPerfIb) { + testPingPongPerf(PingPongTestParams{.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false}); +} -TEST_F(ProxyChannelOneToOneTest, PingPongPerfIb) { testPingPongPerf(true, false); } +TEST_F(ProxyChannelOneToOneTest, PingPongPerfEthernet) { + testPingPongPerf(PingPongTestParams{.useIPC = false, .useIB = false, .useEthernet = true, .waitWithPoll = false}); +} __device__ mscclpp::DeviceSyncer gChannelOneToOneTestProxyChansSyncer; @@ -324,8 +350,8 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) { auto putPacketBuffer = mscclpp::allocExtSharedCuda(nPacket); auto getPacketBuffer = mscclpp::allocExtSharedCuda(nPacket); - setupMeshConnections(proxyChannels, useIbOnly, putPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket), - getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket)); + setupMeshConnections(proxyChannels, !useIbOnly, true, false, putPacketBuffer.get(), + nPacket * sizeof(mscclpp::LLPacket), getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket)); ASSERT_EQ(proxyChannels.size(), 1); @@ -391,8 +417,8 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) { auto putPacketBuffer = mscclpp::allocExtSharedCuda(nPacket); auto getPacketBuffer = mscclpp::allocExtSharedCuda(nPacket); - setupMeshConnections(proxyChannels, useIbOnly, putPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket), - getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket)); + setupMeshConnections(proxyChannels, !useIbOnly, true, false, putPacketBuffer.get(), + nPacket * sizeof(mscclpp::LLPacket), getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket)); ASSERT_EQ(proxyChannels.size(), 1);