From 234e2abb4e38f3d6be5b9cbcc86ac6bf27f7490f Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Tue, 6 Oct 2020 01:56:26 -0700 Subject: [PATCH 01/15] Fixed compilation erros in src/Perf.h --- src/Perf.cc | 2 -- src/Perf.h | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/Perf.cc b/src/Perf.cc index 155802c..faf154a 100644 --- a/src/Perf.cc +++ b/src/Perf.cc @@ -15,8 +15,6 @@ #include "Perf.h" -#include - #include #include diff --git a/src/Perf.h b/src/Perf.h index bf9668d..2349b01 100644 --- a/src/Perf.h +++ b/src/Perf.h @@ -17,7 +17,7 @@ #define HOMA_PERF_H #include -#include +#include #include From 9d4e95e090cd06f6c6ce026c961ba7696bf00e02 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Fri, 17 Jul 2020 15:09:34 -0700 Subject: [PATCH 02/15] Changes to the Driver interface to get Homa ready for Shenango integration Details: - Disable DpdkDriver in CMakeLists.txt temporarily - Remove unused method Driver::Packet::getMaxPayloadSize() - Remove param `driver` in handleXXXPacket - Change Driver::Packet to become a POD struct - Change Driver::Packet::{address,priority} into params in Driver::sendPacket - Remove the opaque Driver::Address - Use IP packets as the common interface between transport and driver - Extend Homa packet headers to include L4 src/dst port numbers - Use SocketAddress (i.e., ip + port) as opposed to Driver::Address to identify the src/dst address of a message --- CMakeLists.txt | 71 +++--- include/Homa/Driver.h | 177 ++++++--------- include/Homa/Drivers/Fake/FakeDriver.h | 52 ++--- include/Homa/Homa.h | 19 +- include/Homa/Util.h | 6 + src/ControlPacket.h | 8 +- src/Drivers/DPDK/DpdkDriverImpl.h | 6 - src/Drivers/Fake/FakeAddressTest.cc | 75 ------- src/Drivers/Fake/FakeDriver.cc | 84 ++------ src/Drivers/Fake/FakeDriverTest.cc | 61 ++---- src/Mock/MockDriver.h | 42 ++-- src/Mock/MockPolicy.h | 4 +- src/Mock/MockReceiver.h | 17 +- src/Mock/MockSender.h | 22 +- src/Policy.cc | 8 +- src/Policy.h | 6 +- src/PolicyTest.cc | 2 +- src/Protocol.h | 16 +- src/Receiver.cc | 47 ++-- src/Receiver.h | 13 +- src/ReceiverTest.cc | 162 +++++++------- src/Sender.cc | 63 ++---- src/Sender.h | 30 +-- src/SenderTest.cc | 286 +++++++++++++------------ src/TransportImpl.cc | 99 +++++---- src/TransportImpl.h | 6 +- src/TransportImplTest.cc | 41 ++-- test/system_test.cc | 15 +- 28 files changed, 607 insertions(+), 831 deletions(-) delete mode 100644 src/Drivers/Fake/FakeAddressTest.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a6f9c1..2f82962 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/modules) find_package(Doxygen OPTIONAL_COMPONENTS dot mscgen dia) # Network Interface library (https://www.dpdk.org/) -find_package(Dpdk REQUIRED) +# find_package(Dpdk REQUIRED) # Source control tool; needed to download external libraries. find_package(Git REQUIRED) @@ -135,34 +135,34 @@ target_compile_options(FakeDriver ) ## lib DpdkDriver ############################################################## -add_library(DpdkDriver - src/Drivers/DPDK/DpdkDriver.cc - src/Drivers/DPDK/DpdkDriverImpl.cc - src/Drivers/DPDK/MacAddress.cc -) -add_library(Homa::DpdkDriver ALIAS DpdkDriver) -target_include_directories(DpdkDriver - PUBLIC - $ - $ - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/src -) -target_link_libraries(DpdkDriver - PRIVATE - Dpdk::Dpdk - PUBLIC - Homa -) -target_compile_features(DpdkDriver - PUBLIC - cxx_std_11 -) -target_compile_options(DpdkDriver - PRIVATE - -Wall - -Wextra -) +#add_library(DpdkDriver +# src/Drivers/DPDK/DpdkDriver.cc +# src/Drivers/DPDK/DpdkDriverImpl.cc +# src/Drivers/DPDK/MacAddress.cc +#) +#add_library(Homa::DpdkDriver ALIAS DpdkDriver) +#target_include_directories(DpdkDriver +# PUBLIC +# $ +# $ +# PRIVATE +# ${CMAKE_CURRENT_SOURCE_DIR}/src +#) +#target_link_libraries(DpdkDriver +# PRIVATE +# Dpdk::Dpdk +# PUBLIC +# Homa +#) +#target_compile_features(DpdkDriver +# PUBLIC +# cxx_std_11 +#) +#target_compile_options(DpdkDriver +# PRIVATE +# -Wall +# -Wextra +#) ################################################################################ ## Tests ####################################################################### @@ -195,7 +195,8 @@ endif() ## Install & Export ############################################################ ################################################################################ -install(TARGETS Homa DpdkDriver FakeDriver EXPORT HomaTargets +#install(TARGETS Homa DpdkDriver FakeDriver EXPORT HomaTargets +install(TARGETS Homa FakeDriver EXPORT HomaTargets LIBRARY DESTINATION lib ARCHIVE DESTINATION lib RUNTIME DESTINATION bin @@ -274,11 +275,11 @@ target_sources(unit_test target_link_libraries(unit_test FakeDriver) #DPDK Tests -target_sources(unit_test - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/src/Drivers/DPDK/MacAddressTest.cc -) -target_link_libraries(unit_test DpdkDriver) +#target_sources(unit_test +# PUBLIC +# ${CMAKE_CURRENT_SOURCE_DIR}/src/Drivers/DPDK/MacAddressTest.cc +#) +#target_link_libraries(unit_test DpdkDriver) target_link_libraries(unit_test gmock_main) # -fno-access-control allows access to private members for testing diff --git a/include/Homa/Driver.h b/include/Homa/Driver.h index ecfe666..d510046 100644 --- a/include/Homa/Driver.h +++ b/include/Homa/Driver.h @@ -22,6 +22,28 @@ namespace Homa { +/// IPv4 address in host byte order. +using IpAddress = uint32_t; + +/** + * Represents a packet of data that can be send or is received over the network. + * A Packet logically contains only the transport-layer (L4) Homa header in + * addition to application data. + * + * This struct specifies the minimal object layout of a packet that the core + * Homa protocol depends on (e.g., Homa::Core::{Sender, Receiver}); this is + * useful for applications that only want to use the transport layer of this + * library and have their own infrastructures for sending and receiving packets. + */ +struct PacketSpec { + /// Pointer to an array of bytes containing the payload of this Packet. + /// This array is valid until the Packet is released back to the Driver. + void* payload; + + /// Number of bytes in the payload. + int32_t length; +} __attribute__((packed)); + /** * Used by Homa::Transport to send and receive unreliable datagrams. Provides * the interface to which all Driver implementations must conform. @@ -31,133 +53,46 @@ namespace Homa { class Driver { public: /** - * Represents a Network address. + * Represents a packet that can be send or is received over the network. * - * Each Address representation is specific to the Driver instance that - * returned the it; they cannot be use interchangeably between different - * Driver instances. - */ - using Address = uint64_t; - - /** - * Used to hold a driver's serialized byte-format for a network address. + * The layout of this struct has two parts: the first part is essentially + * a copy of PacketSpec, while the second part contains members specific + * to our driver implementation. * - * Each driver may define its own byte-format so long as fits within the - * bytes array. + * @sa Homa::PacketSpec */ - struct WireFormatAddress { - uint8_t type; ///< Can be used to distinguish between different wire - ///< address formats. - uint8_t bytes[19]; ///< Holds an Address's serialized byte-format. - } __attribute__((packed)); + struct Packet final { + // === PacketSpec definitions === + // The order and types of the following members must match those in + // PacketSpec precisely. - /** - * Represents a packet of data that can be send or is received over the - * network. A Packet logically contains only the payload and not any Driver - * specific headers. - * - * A Packet may be Driver specific and should not used interchangeably - * between Driver instances or implementations. - * - * This class is NOT thread-safe but the Transport and Driver's use of - * Packet objects should be allow the Transport and the Driver to execute on - * different threads. - */ - class Packet { - public: - /// Packet's source or destination. When sending a Packet, the address - /// field will contain the destination Address. When receiving a Packet, - /// address field will contain the source Address. - Address address; - - /// Packet's network priority (send only); the lowest possible priority - /// is 0. The highest priority is positive number defined by the - /// Driver; the highest priority can be queried by calling the method - /// getHighestPacketPriority(). - int priority; - - /// Pointer to an array of bytes containing the payload of this Packet. - /// This array is valid until the Packet is released back to the Driver. - void* const payload; + /// See Homa::PacketSpec::payload. + void* payload; - /// Number of bytes in the payload. - int length; + /// See Homa::PacketSpec::length + int32_t length; - /// Return the maximum number of bytes the payload can hold. - virtual int getMaxPayloadSize() = 0; + // === Extended definitions === + // The following members are specific to the driver framework bundled + // in this library. Therefore, these members must *NOT* appear in the + // core components of Homa transport; they are only used in a few + // places to facilitate the glue code between transport and driver. - protected: - /** - * Construct a Packet. - */ - explicit Packet(void* payload, int length = 0) - : address() - , priority(0) - , payload(payload) - , length(length) - {} + /// Packet's source IpAddress. Only meaningful when this packet is an + /// incoming packet. + IpAddress sourceIp; + } __attribute__((packed)); - // DISALLOW_COPY_AND_ASSIGN - Packet(const Packet&) = delete; - Packet& operator=(const Packet&) = delete; - }; + // Static checks to enforce the object layout compatibility between + // Driver::Packet and PacketSpec. + static_assert(offsetof(Packet, payload) == offsetof(PacketSpec, payload)); + static_assert(offsetof(Packet, length) == offsetof(PacketSpec, length)); /** * Driver destructor. */ virtual ~Driver() = default; - /** - * Return a Driver specific network address for the given string - * representation of the address. - * - * @param addressString - * The string representation of the address to return. The address - * string format can be Driver specific. - * - * @return - * Address that can be the source or destination of a Packet. - * - * @throw BadAddress - * _addressString_ is malformed. - */ - virtual Address getAddress(std::string const* const addressString) = 0; - - /** - * Return a Driver specific network address for the given serialized - * byte-format of the address. - * - * @param wireAddress - * The serialized byte-format of the address to be returned. The - * format can be Driver specific. - * - * @return - * Address that can be the source or destination of a Packet. - * - * @throw BadAddress - * _rawAddress_ is malformed. - */ - virtual Address getAddress(WireFormatAddress const* const wireAddress) = 0; - - /** - * Return the string representation of a network address. - * - * @param address - * Address whose string representation should be returned. - */ - virtual std::string addressToString(const Address address) = 0; - - /** - * Serialize a network address into its Raw byte format. - * - * @param address - * Address to be serialized. - * @param[out] wireAddress - * WireFormatAddress object to which the Address is serialized. - */ - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) = 0; - /** * Allocate a new Packet object from the Driver's pool of resources. The * caller must eventually release the packet by passing it to a call to @@ -187,8 +122,16 @@ class Driver { * * @param packet * Packet to be sent over the network. + * @param destination + * IP address of the packet destination. + * @param priority + * Packet's network priority; the lowest possible priority is 0. + * The highest priority is positive number defined by the Driver; + * the highest priority can be queried by calling the method + * getHighestPacketPriority(). */ - virtual void sendPacket(Packet* packet) = 0; + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority) = 0; /** * Request that the Driver enter the "corked" mode where outbound packets @@ -273,10 +216,10 @@ class Driver { virtual uint32_t getBandwidth() = 0; /** - * Return this Driver's local network Address which it uses as the source - * Address for outgoing packets. + * Return this Driver's local IP address which it uses as the source + * address for outgoing packets. */ - virtual Address getLocalAddress() = 0; + virtual IpAddress getLocalAddress() = 0; /** * Return the number of bytes that have been passed to the Driver through diff --git a/include/Homa/Drivers/Fake/FakeDriver.h b/include/Homa/Drivers/Fake/FakeDriver.h index 8413778..04ce8c0 100644 --- a/include/Homa/Drivers/Fake/FakeDriver.h +++ b/include/Homa/Drivers/Fake/FakeDriver.h @@ -34,7 +34,7 @@ const int NUM_PRIORITIES = 8; /// Maximum number of bytes a packet can hold. const uint32_t MAX_PAYLOAD_SIZE = 1500; -/// A set of methods to contol the underlying FakeNetwork's behavior. +/// A set of methods to control the underlying FakeNetwork's behavior. namespace FakeNetworkConfig { /** * Configure the FakeNetwork to drop packets at the specified loss rate. @@ -51,43 +51,34 @@ void setPacketLossRate(double lossRate); * * @sa Driver::Packet */ -class FakePacket : public Driver::Packet { - public: +struct FakePacket { + /// C-style "inheritance"; used to maintain the base struct as a POD type. + Driver::Packet base; + + /// Raw storage for this packets payload. + char buf[MAX_PAYLOAD_SIZE]; + /** * FakePacket constructor. - * - * @param maxPayloadSize - * The maximum number of bytes this packet can hold. */ explicit FakePacket() - : Packet(buf, 0) + : base{.payload = buf, + .length = 0, + .sourceIp = 0} + , buf() {} /** * Copy constructor. */ FakePacket(const FakePacket& other) - : Packet(buf, other.length) - { - address = other.address; - priority = other.priority; - memcpy(buf, other.buf, MAX_PAYLOAD_SIZE); - } - - virtual ~FakePacket() {} - - /// see Driver::Packet::getMaxPayloadSize() - virtual int getMaxPayloadSize() + : base{.payload = buf, + .length = other.base.length, + .sourceIp = 0} + , buf() { - return MAX_PAYLOAD_SIZE; + memcpy(base.payload, other.base.payload, MAX_PAYLOAD_SIZE); } - - private: - /// Raw storage for this packets payload. - char buf[MAX_PAYLOAD_SIZE]; - - // Disable Assignment - FakePacket& operator=(const FakePacket&) = delete; }; /// Holds the incoming packets for a particular driver. @@ -117,20 +108,15 @@ class FakeDriver : public Driver { */ virtual ~FakeDriver(); - virtual Address getAddress(std::string const* const addressString); - virtual Address getAddress(WireFormatAddress const* const wireAddress); - virtual std::string addressToString(const Address address); - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); virtual Packet* allocPacket(); - virtual void sendPacket(Packet* packet); + virtual void sendPacket(Packet* packet, IpAddress destination, int priority); virtual uint32_t receivePackets(uint32_t maxPackets, Packet* receivedPackets[]); virtual void releasePackets(Packet* packets[], uint16_t numPackets); virtual int getHighestPacketPriority(); virtual uint32_t getMaxPayloadSize(); virtual uint32_t getBandwidth(); - virtual Address getLocalAddress(); + virtual IpAddress getLocalAddress(); virtual uint32_t getQueuedBytes(); private: diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index dec090c..aba9073 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -37,6 +37,17 @@ namespace Homa { template using unique_ptr = std::unique_ptr; +/** + * Represents a socket address to (from) which we can send (receive) messages. + */ +struct SocketAddress { + /// IPv4 address in host byte order. + IpAddress ip; + + /// Port number in host byte order. + uint16_t port; +}; + /** * Represents an array of bytes that has been received over the network. * @@ -220,11 +231,11 @@ class OutMessage { * Send this message to the destination. * * @param destination - * Address of the transport to which this message will be sent. + * Network address to which this message will be sent. * @param options * Flags to request non-default sending behavior. */ - virtual void send(Driver::Address destination, + virtual void send(SocketAddress destination, Options options = Options::NONE) = 0; protected: @@ -265,10 +276,12 @@ class Transport { /** * Allocate Message that can be sent with this Transport. * + * @param sourcePort + * Port number of the socket from which the message will be sent. * @return * A pointer to the allocated message. */ - virtual Homa::unique_ptr alloc() = 0; + virtual Homa::unique_ptr alloc(uint16_t sourcePort) = 0; /** * Check for and return a Message sent to this Transport if available. diff --git a/include/Homa/Util.h b/include/Homa/Util.h index 121bb44..30a3548 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -21,6 +21,12 @@ #include #include +/// Cast a member of a structure out to the containing structure. +#define container_of(ptr, type, member) ({ \ + const typeof( ((type *)0)->member ) \ + *__mptr = (ptr); \ + (type *)( (char *)__mptr - offsetof(type,member) );}) + namespace Homa { namespace Util { diff --git a/src/ControlPacket.h b/src/ControlPacket.h index a8da070..bc53f10 100644 --- a/src/ControlPacket.h +++ b/src/ControlPacket.h @@ -31,21 +31,19 @@ namespace ControlPacket { * @param driver * Driver with which to send the packet. * @param address - * Destination address for the packet to be sent. + * Destination IP address for the packet to be sent. * @param args * Arguments to PacketHeaderType's constructor. */ template void -send(Driver* driver, Driver::Address address, Args&&... args) +send(Driver* driver, IpAddress address, Args&&... args) { Driver::Packet* packet = driver->allocPacket(); new (packet->payload) PacketHeaderType(static_cast(args)...); packet->length = sizeof(PacketHeaderType); - packet->address = address; - packet->priority = driver->getHighestPacketPriority(); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, address, driver->getHighestPacketPriority()); driver->releasePackets(&packet, 1); } diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 289e83f..9b77383 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -109,12 +109,6 @@ class DpdkDriver::Impl { explicit Packet(struct rte_mbuf* mbuf, void* data); explicit Packet(OverflowBuffer* overflowBuf); - /// see Driver::Packet::getMaxPayloadSize() - virtual int getMaxPayloadSize() - { - return MAX_PAYLOAD_SIZE; - } - /// Used to indicate whether the packet is backed by an DPDK mbuf or a /// driver-level OverflowBuffer. enum BufferType { MBUF, OVERFLOW_BUF } bufType; ///< Packet BufferType. diff --git a/src/Drivers/Fake/FakeAddressTest.cc b/src/Drivers/Fake/FakeAddressTest.cc deleted file mode 100644 index 67cef78..0000000 --- a/src/Drivers/Fake/FakeAddressTest.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2019, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#include - -#include "FakeAddress.h" - -#include "../RawAddressType.h" - -namespace Homa { -namespace Drivers { -namespace Fake { -namespace { - -TEST(FakeAddressTest, constructor_id) -{ - FakeAddress address(42); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_str) -{ - FakeAddress address("42"); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_str_bad) -{ - EXPECT_THROW(FakeAddress address("D42"), BadAddress); -} - -TEST(FakeAddressTest, constructor_raw) -{ - Driver::Address::Raw raw; - raw.type = RawAddressType::FAKE; - *reinterpret_cast(raw.bytes) = 42; - - FakeAddress address(&raw); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_raw_bad) -{ - Driver::Address::Raw raw; - raw.type = !RawAddressType::FAKE; - - EXPECT_THROW(FakeAddress address(&raw), BadAddress); -} - -TEST(FakeAddressTest, toString) -{ - // tested sufficiently in constructor tests -} - -TEST(FakeAddressTest, toAddressId) -{ - EXPECT_THROW(FakeAddress::toAddressId("D42"), BadAddress); -} - -} // namespace -} // namespace Fake -} // namespace Drivers -} // namespace Homa diff --git a/src/Drivers/Fake/FakeDriver.cc b/src/Drivers/Fake/FakeDriver.cc index 6200a49..b6355cc 100644 --- a/src/Drivers/Fake/FakeDriver.cc +++ b/src/Drivers/Fake/FakeDriver.cc @@ -72,8 +72,8 @@ static class FakeNetwork { } /// Deliver the provide packet to the specified destination. - void sendPacket(FakePacket* packet, Driver::Address src, - Driver::Address dst) + void sendPacket(FakePacket* packet, int priority, IpAddress src, + IpAddress dst) { FakeNIC* nic = nullptr; { @@ -92,10 +92,10 @@ static class FakeNetwork { assert(nic != nullptr); std::lock_guard lock_nic(nic->mutex, std::adopt_lock); FakePacket* dstPacket = new FakePacket(*packet); - dstPacket->address = src; - assert(dstPacket->priority < NUM_PRIORITIES); - assert(dstPacket->priority >= 0); - nic->priorityQueue.at(dstPacket->priority).push_back(dstPacket); + dstPacket->base.sourceIp = src; + assert(priority < NUM_PRIORITIES); + assert(priority >= 0); + nic->priorityQueue.at(priority).push_back(dstPacket); } void setPacketLossRate(double lossRate) @@ -115,10 +115,9 @@ static class FakeNetwork { std::mutex mutex; /// Holds all the packets being sent through the fake network. - std::unordered_map network; + std::unordered_map network; - /// The FakeAddress identifier for the next FakeDriver that "connects" to - /// the FakeNetwork. + /// Identifier for the next FakeDriver that "connects" to the FakeNetwork. std::atomic nextAddressId; /// Rate at which packets should be dropped when sent over this network. @@ -177,53 +176,6 @@ FakeDriver::~FakeDriver() fakeNetwork.deregisterNIC(localAddressId); } -/** - * See Driver::getAddress() - */ -Driver::Address -FakeDriver::getAddress(std::string const* const addressString) -{ - char* end; - uint64_t address = std::strtoul(addressString->c_str(), &end, 10); - if (address == 0) { - throw BadAddress(HERE_STR, StringUtil::format("Bad address string: %s", - addressString->c_str())); - } - return address; -} - -/** - * See Driver::getAddress() - */ -Driver::Address -FakeDriver::getAddress(Driver::WireFormatAddress const* const wireAddress) -{ - const Address* address = - reinterpret_cast(wireAddress->bytes); - return *address; -} - -/** - * See Driver::addressToString() - */ -std::string -FakeDriver::addressToString(const Address address) -{ - char buf[21]; - snprintf(buf, sizeof(buf), "%lu", address); - return buf; -} - -/** - * See Driver::addressToWireFormat() - */ -void -FakeDriver::addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) -{ - new (reinterpret_cast(wireAddress->bytes)) Address(address); -} - /** * See Driver::allocPacket() */ @@ -231,19 +183,19 @@ Driver::Packet* FakeDriver::allocPacket() { FakePacket* packet = new FakePacket(); - return packet; + return &packet->base; } /** * See Driver::sendPacket() */ void -FakeDriver::sendPacket(Packet* packet) +FakeDriver::sendPacket(Packet* packet, IpAddress destination, int priority) { - FakePacket* srcPacket = static_cast(packet); - Address srcAddress = getLocalAddress(); - Address dstAddress = srcPacket->address; - fakeNetwork.sendPacket(srcPacket, srcAddress, dstAddress); + FakePacket* srcPacket = container_of(packet, FakePacket, base); + IpAddress srcAddress = getLocalAddress(); + IpAddress dstAddress = destination; + fakeNetwork.sendPacket(srcPacket, priority, srcAddress, dstAddress); queueEstimator.signalBytesSent(packet->length); } @@ -257,8 +209,9 @@ FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) uint32_t numReceived = 0; for (int i = NUM_PRIORITIES - 1; i >= 0; --i) { while (numReceived < maxPackets && !nic.priorityQueue.at(i).empty()) { - receivedPackets[numReceived] = nic.priorityQueue.at(i).front(); + FakePacket* fakePacket = nic.priorityQueue.at(i).front(); nic.priorityQueue.at(i).pop_front(); + receivedPackets[numReceived] = &fakePacket->base; numReceived++; } } @@ -272,8 +225,7 @@ void FakeDriver::releasePackets(Packet* packets[], uint16_t numPackets) { for (uint16_t i = 0; i < numPackets; ++i) { - FakePacket* packet = static_cast(packets[i]); - delete packet; + delete container_of(packets[i], FakePacket, base); } } @@ -308,7 +260,7 @@ FakeDriver::getBandwidth() /** * See Driver::getLocalAddress() */ -Driver::Address +IpAddress FakeDriver::getLocalAddress() { return localAddressId; diff --git a/src/Drivers/Fake/FakeDriverTest.cc b/src/Drivers/Fake/FakeDriverTest.cc index e410119..2390abf 100644 --- a/src/Drivers/Fake/FakeDriverTest.cc +++ b/src/Drivers/Fake/FakeDriverTest.cc @@ -18,7 +18,6 @@ #include -#include "../RawAddressType.h" #include "StringUtil.h" namespace Homa { @@ -34,46 +33,12 @@ TEST(FakeDriverTest, constructor) EXPECT_EQ(nextAddressId, driver.localAddressId); } -TEST(FakeDriverTest, getAddress_string) -{ - FakeDriver driver; - std::string addressStr("42"); - Driver::Address address = driver.getAddress(&addressStr); - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, getAddress_wireformat) -{ - FakeDriver driver; - Driver::WireFormatAddress wireformatAddress; - wireformatAddress.type = RawAddressType::FAKE; - *reinterpret_cast(wireformatAddress.bytes) = 42; - Driver::Address address = driver.getAddress(&wireformatAddress); - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, addressToString) -{ - FakeDriver driver; - Driver::Address address = 42; - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, addressToWireFormat) -{ - FakeDriver driver; - Driver::WireFormatAddress wireformatAddress; - driver.addressToWireFormat(42, &wireformatAddress); - EXPECT_EQ("42", - driver.addressToString(driver.getAddress(&wireformatAddress))); -} - TEST(FakeDriverTest, allocPacket) { FakeDriver driver; Driver::Packet* packet = driver.allocPacket(); // allocPacket doesn't do much so we just need to make sure we can call it. - delete packet; + delete container_of(packet, FakePacket, base); } TEST(FakeDriverTest, sendPackets) @@ -82,13 +47,14 @@ TEST(FakeDriverTest, sendPackets) FakeDriver driver2; Driver::Packet* packets[4]; + IpAddress destinations[4]; + int prio[4]; for (int i = 0; i < 4; ++i) { packets[i] = driver1.allocPacket(); - packets[i]->address = driver2.getLocalAddress(); - packets[i]->priority = i; + destinations[i] = driver2.getLocalAddress(); + prio[i] = i; } - std::string addressStr("42"); - packets[2]->address = driver1.getAddress(&addressStr); + destinations[2] = IpAddress(42); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(0).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(1).size()); @@ -99,7 +65,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); - driver1.sendPacket(packets[0]); + driver1.sendPacket(packets[0], destinations[0], prio[0]); EXPECT_EQ(1U, driver2.nic.priorityQueue.at(0).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(1).size()); @@ -110,13 +76,12 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); { - Driver::Packet* packet = static_cast( - driver2.nic.priorityQueue.at(0).front()); - EXPECT_EQ(driver1.getLocalAddress(), packet->address); + Driver::Packet* packet = &driver2.nic.priorityQueue.at(0).front()->base; + EXPECT_EQ(driver1.getLocalAddress(), packet->sourceIp); } for (int i = 0; i < 4; ++i) { - driver1.sendPacket(packets[i]); + driver1.sendPacket(packets[i], destinations[i], prio[i]); } EXPECT_EQ(2U, driver2.nic.priorityQueue.at(0).size()); @@ -128,7 +93,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); - delete packets[2]; + delete container_of(packets[2], FakePacket, base); } TEST(FakeDriverTest, receivePackets) @@ -235,10 +200,8 @@ TEST(FakeDriverTest, getBandwidth) TEST(FakeDriverTest, getLocalAddress) { uint64_t nextAddressId = FakeDriver().localAddressId + 1; - std::string addressStr = StringUtil::format("%lu", nextAddressId); - FakeDriver driver; - EXPECT_EQ(driver.getAddress(&addressStr), driver.getLocalAddress()); + EXPECT_EQ(nextAddressId, driver.getLocalAddress()); } } // namespace diff --git a/src/Mock/MockDriver.h b/src/Mock/MockDriver.h index 6cc5ea7..35fd731 100644 --- a/src/Mock/MockDriver.h +++ b/src/Mock/MockDriver.h @@ -35,32 +35,22 @@ class MockDriver : public Driver { * * @sa Driver::Packet. */ - class MockPacket : public Driver::Packet { - public: - MockPacket(void* payload, uint16_t length = 0) - : Packet(payload, length) - {} - - MOCK_METHOD0(getMaxPayloadSize, int()); - }; - - MOCK_METHOD1(getAddress, Address(std::string const* const addressString)); - MOCK_METHOD1(getAddress, - Address(WireFormatAddress const* const wireAddress)); - MOCK_METHOD1(addressToString, std::string(Address address)); - MOCK_METHOD2(addressToWireFormat, - void(Address address, WireFormatAddress* wireAddress)); - MOCK_METHOD0(allocPacket, Packet*()); - MOCK_METHOD1(sendPacket, void(Packet* packet)); - MOCK_METHOD0(flushPackets, void()); - MOCK_METHOD2(receivePackets, - uint32_t(uint32_t maxPackets, Packet* receivedPackets[])); - MOCK_METHOD2(releasePackets, void(Packet* packets[], uint16_t numPackets)); - MOCK_METHOD0(getHighestPacketPriority, int()); - MOCK_METHOD0(getMaxPayloadSize, uint32_t()); - MOCK_METHOD0(getBandwidth, uint32_t()); - MOCK_METHOD0(getLocalAddress, Address()); - MOCK_METHOD0(getQueuedBytes, uint32_t()); + using MockPacket = Driver::Packet; + + MOCK_METHOD(Packet*, allocPacket, (), (override)); + MOCK_METHOD(void, sendPacket, + (Packet* packet, IpAddress destination, int priority), + (override)); + MOCK_METHOD(void, flushPackets, ()); + MOCK_METHOD(uint32_t, receivePackets, + (uint32_t maxPackets, Packet* receivedPackets[]), (override)); + MOCK_METHOD(void, releasePackets, (Packet* packets[], uint16_t numPackets), + (override)); + MOCK_METHOD(int, getHighestPacketPriority, (), (override)); + MOCK_METHOD(uint32_t, getMaxPayloadSize, (), (override)); + MOCK_METHOD(uint32_t, getBandwidth, (), (override)); + MOCK_METHOD(IpAddress, getLocalAddress, (), (override)); + MOCK_METHOD(uint32_t, getQueuedBytes, (), (override)); }; } // namespace Mock diff --git a/src/Mock/MockPolicy.h b/src/Mock/MockPolicy.h index 0595f25..52cb2a5 100644 --- a/src/Mock/MockPolicy.h +++ b/src/Mock/MockPolicy.h @@ -36,10 +36,10 @@ class MockPolicyManager : public Core::Policy::Manager { MOCK_METHOD0(getResendPriority, int()); MOCK_METHOD0(getScheduledPolicy, Core::Policy::Scheduled()); MOCK_METHOD2(getUnscheduledPolicy, - Core::Policy::Unscheduled(const Driver::Address destination, + Core::Policy::Unscheduled(const IpAddress destination, const uint32_t messageLength)); MOCK_METHOD3(signalNewMessage, - void(const Driver::Address source, uint8_t policyVersion, + void(const IpAddress source, uint8_t policyVersion, uint32_t messageLength)); MOCK_METHOD0(poll, void()); }; diff --git a/src/Mock/MockReceiver.h b/src/Mock/MockReceiver.h index fc0fa13..75eea2c 100644 --- a/src/Mock/MockReceiver.h +++ b/src/Mock/MockReceiver.h @@ -36,15 +36,14 @@ class MockReceiver : public Core::Receiver { : Receiver(driver, nullptr, messageTimeoutCycles, resendIntervalCycles) {} - MOCK_METHOD2(handleDataPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleBusyPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handlePingPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD0(receiveMessage, Homa::InMessage*()); - MOCK_METHOD0(poll, void()); - MOCK_METHOD0(checkTimeouts, uint64_t()); + MOCK_METHOD(void, handleDataPacket, + (Driver::Packet* packet, IpAddress sourceIp), (override)); + MOCK_METHOD(void, handleBusyPacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, handlePingPacket, + (Driver::Packet* packet, IpAddress sourceIp), (override)); + MOCK_METHOD(Homa::InMessage*, receiveMessage, (), (override)); + MOCK_METHOD(void, poll, (), (override)); + MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); }; } // namespace Mock diff --git a/src/Mock/MockSender.h b/src/Mock/MockSender.h index b67152b..4a8bd27 100644 --- a/src/Mock/MockSender.h +++ b/src/Mock/MockSender.h @@ -37,19 +37,15 @@ class MockSender : public Core::Sender { pingIntervalCycles) {} - MOCK_METHOD0(allocMessage, Homa::OutMessage*()); - MOCK_METHOD2(handleDonePacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleGrantPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleResendPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleUnknownPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleErrorPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD0(poll, void()); - MOCK_METHOD0(checkTimeouts, uint64_t()); + MOCK_METHOD(Homa::OutMessage*, allocMessage, (uint16_t sport), (override)); + MOCK_METHOD(void, handleDonePacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, handleGrantPacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, handleResendPacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, handleUnknownPacket, (Driver::Packet* packet), + (override)); + MOCK_METHOD(void, handleErrorPacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, poll, (), (override)); + MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); }; } // namespace Mock diff --git a/src/Policy.cc b/src/Policy.cc index cf0e62e..12e7e16 100644 --- a/src/Policy.cc +++ b/src/Policy.cc @@ -97,14 +97,14 @@ Manager::getScheduledPolicy() * unilaterally "granted" (unscheduled) bytes for a new Message to be sent. * * @param destination - * The policy for the Transport at this Address will be returned. + * The policy for the Transport at this IpAddress will be returned. * @param messageLength * The policy for message containing this many bytes will be returned. * * @sa Policy::Unscheduled */ Unscheduled -Manager::getUnscheduledPolicy(const Driver::Address destination, +Manager::getUnscheduledPolicy(const IpAddress destination, const uint32_t messageLength) { SpinLock::Lock lock(mutex); @@ -140,14 +140,14 @@ Manager::getUnscheduledPolicy(const Driver::Address destination, * Called by the Receiver when a new Message has started to arrive. * * @param source - * Address of the Transport from which the new Message was received. + * IpAddress of the Transport from which the new Message was received. * @param policyVersion * Version of the policy the Sender used when sending the Message. * @param messageLength * Number of bytes the new incoming Message contains. */ void -Manager::signalNewMessage(const Driver::Address source, uint8_t policyVersion, +Manager::signalNewMessage(const IpAddress source, uint8_t policyVersion, uint32_t messageLength) { SpinLock::Lock lock(mutex); diff --git a/src/Policy.h b/src/Policy.h index c32bf66..6c80c90 100644 --- a/src/Policy.h +++ b/src/Policy.h @@ -75,9 +75,9 @@ class Manager { virtual ~Manager() = default; virtual int getResendPriority(); virtual Scheduled getScheduledPolicy(); - virtual Unscheduled getUnscheduledPolicy(const Driver::Address destination, + virtual Unscheduled getUnscheduledPolicy(const IpAddress destination, const uint32_t messageLength); - virtual void signalNewMessage(const Driver::Address source, + virtual void signalNewMessage(const IpAddress source, uint8_t policyVersion, uint32_t messageLength); virtual void poll(); @@ -107,7 +107,7 @@ class Manager { /// The scheduled policy for the Transport that owns this Policy::Manager. Scheduled localScheduledPolicy; /// Collection of the known Policies for each peered Homa::Transport; - std::unordered_map peerPolicies; + std::unordered_map peerPolicies; /// Number of bytes that can be transmitted in one round-trip-time. const uint32_t RTT_BYTES; /// The highest network packet priority that the driver supports. diff --git a/src/PolicyTest.cc b/src/PolicyTest.cc index ee0dde5..88cdd45 100644 --- a/src/PolicyTest.cc +++ b/src/PolicyTest.cc @@ -59,7 +59,7 @@ TEST(PolicyManagerTest, getUnscheduledPolicy) EXPECT_CALL(mockDriver, getBandwidth).WillOnce(Return(8000)); EXPECT_CALL(mockDriver, getHighestPacketPriority).WillOnce(Return(7)); Policy::Manager manager(&mockDriver); - Driver::Address dest(22); + IpAddress dest(22); { Policy::Unscheduled policy = manager.getUnscheduledPolicy(dest, 1); diff --git a/src/Protocol.h b/src/Protocol.h index f83725e..25471bb 100644 --- a/src/Protocol.h +++ b/src/Protocol.h @@ -122,16 +122,20 @@ struct HeaderPrefix { /** * Describes the wire format for header fields that are common to all packet - * types. + * types. Note: the first 4 bytes are identical for TCP, UDP, and Homa. */ struct CommonHeader { + uint16_t sport, dport;///< Transport layer (L4) source and destination ports + ///< in network byte order; only used by DataHeader. HeaderPrefix prefix; ///< Common to all versions of the protocol. uint8_t opcode; ///< One of the values of Opcode. MessageId messageId; ///< RemoteOp/Message associated with this packet. /// CommonHeader constructor. CommonHeader(Opcode opcode, MessageId messageId) - : prefix(1) + : sport(0) + , dport(0) + , prefix(1) , opcode(opcode) , messageId(messageId) {} @@ -157,14 +161,18 @@ struct DataHeader { // starting at the offset corresponding to the given packet index. /// DataHeader constructor. - DataHeader(MessageId messageId, uint32_t totalLength, uint8_t policyVersion, + DataHeader(uint16_t sport, uint16_t dport, MessageId messageId, + uint32_t totalLength, uint8_t policyVersion, uint16_t unscheduledIndexLimit, uint16_t index) : common(Opcode::DATA, messageId) , totalLength(totalLength) , policyVersion(policyVersion) , unscheduledIndexLimit(unscheduledIndexLimit) , index(index) - {} + { + common.sport = htobe16(sport); + common.dport = htobe16(dport); + } } __attribute__((packed)); /** diff --git a/src/Receiver.cc b/src/Receiver.cc index 25e0619..d499a61 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -82,11 +82,11 @@ Receiver::~Receiver() * * @param packet * The incoming packet to be processed. - * @param driver - * The driver from which the packet was received. + * @param sourceIp + * Source IP address of the packet. */ void -Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) +Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) { Protocol::Packet::DataHeader* header = static_cast(packet->payload); @@ -102,14 +102,18 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) int numUnscheduledPackets = header->unscheduledIndexLimit; { SpinLock::Lock lock_allocator(messageAllocator.mutex); + SocketAddress srcAddress = { + .ip = sourceIp, + .port = be16toh(header->common.sport) + }; message = messageAllocator.pool.construct( this, driver, dataHeaderLength, messageLength, id, - packet->address, numUnscheduledPackets); + srcAddress, numUnscheduledPackets); } bucket->messages.push_back(&message->bucketNode); - policyManager->signalNewMessage(message->source, header->policyVersion, - header->totalLength); + policyManager->signalNewMessage(message->source.ip, + header->policyVersion, header->totalLength); if (message->scheduled) { // Message needs to be scheduled. @@ -121,7 +125,8 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) // Things that must be true (sanity check) assert(id == message->id); assert(message->driver == driver); - assert(message->source == packet->address); + assert(message->source.ip == sourceIp); + assert(message->source.port == be16toh(header->common.sport)); assert(message->messageLength == Util::downCast(header->totalLength)); // Add the packet @@ -169,11 +174,9 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) * * @param packet * The incoming BUSY packet to be processed. - * @param driver - * The driver from which the BUSY packet was received. */ void -Receiver::handleBusyPacket(Driver::Packet* packet, Driver* driver) +Receiver::handleBusyPacket(Driver::Packet* packet) { Protocol::Packet::BusyHeader* header = static_cast(packet->payload); @@ -198,11 +201,11 @@ Receiver::handleBusyPacket(Driver::Packet* packet, Driver* driver) * * @param packet * The incoming PING packet to be processed. - * @param driver - * The driver from which the PING packet was received. + * @param sourceIp + * Source IP address of the packet. */ void -Receiver::handlePingPacket(Driver::Packet* packet, Driver* driver) +Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) { Protocol::Packet::PingHeader* header = static_cast(packet->payload); @@ -236,13 +239,13 @@ Receiver::handlePingPacket(Driver::Packet* packet, Driver* driver) Perf::counters.tx_grant_pkts.add(1); ControlPacket::send( - driver, message->source, message->id, bytesGranted, priority); + driver, message->source.ip, message->id, bytesGranted, priority); } else { // We are here because we have no knowledge of the message the Sender is // asking about. Reply UNKNOWN so the Sender can react accordingly. Perf::counters.tx_unknown_pkts.add(1); ControlPacket::send( - driver, packet->address, id); + driver, sourceIp, id); } driver->releasePackets(&packet, 1); } @@ -346,7 +349,7 @@ Receiver::Message::acknowledge() const MessageBucket* bucket = receiver->messageBuckets.getBucket(id); SpinLock::Lock lock(bucket->mutex); Perf::counters.tx_done_pkts.add(1); - ControlPacket::send(driver, source, id); + ControlPacket::send(driver, source.ip, id); } /** @@ -367,7 +370,7 @@ Receiver::Message::fail() const MessageBucket* bucket = receiver->messageBuckets.getBucket(id); SpinLock::Lock lock(bucket->mutex); Perf::counters.tx_error_pkts.add(1); - ControlPacket::send(driver, source, id); + ControlPacket::send(driver, source.ip, id); } /** @@ -678,7 +681,7 @@ Receiver::checkResendTimeouts() SpinLock::Lock lock_scheduler(schedulerMutex); Perf::counters.tx_resend_pkts.add(1); ControlPacket::send( - message->driver, message->source, message->id, + message->driver, message->source.ip, message->id, Util::downCast(index), Util::downCast(num), message->scheduledMessageInfo.priority); @@ -691,7 +694,7 @@ Receiver::checkResendTimeouts() SpinLock::Lock lock_scheduler(schedulerMutex); Perf::counters.tx_resend_pkts.add(1); ControlPacket::send( - message->driver, message->source, message->id, + message->driver, message->source.ip, message->id, Util::downCast(index), Util::downCast(num), message->scheduledMessageInfo.priority); @@ -748,7 +751,7 @@ Receiver::trySendGrants() ScheduledMessageInfo* info = &message->scheduledMessageInfo; // Access message const variables without message mutex. const Protocol::MessageId id = message->id; - const Driver::Address source = message->source; + const IpAddress sourceIp = message->source.ip; // Recalculate message priority info->priority = @@ -765,7 +768,7 @@ Receiver::trySendGrants() info->bytesGranted = newGrantLimit; Perf::counters.tx_grant_pkts.add(1); ControlPacket::send( - driver, source, id, + driver, sourceIp, id, Util::downCast(info->bytesGranted), info->priority); } @@ -806,7 +809,7 @@ Receiver::schedule(Receiver::Message* message, const SpinLock::Lock& lock) { (void)lock; ScheduledMessageInfo* info = &message->scheduledMessageInfo; - Peer* peer = &peerTable[message->source]; + Peer* peer = &peerTable[message->source.ip]; // Insert the Message peer->scheduledMessages.push_front(&info->scheduledMessageNode); Intrusive::deprioritize(&peer->scheduledMessages, diff --git a/src/Receiver.h b/src/Receiver.h index 444e1aa..c97c462 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -30,6 +30,7 @@ #include "Protocol.h" #include "SpinLock.h" #include "Timeout.h" +#include "Util.h" namespace Homa { namespace Core { @@ -46,9 +47,9 @@ class Receiver { uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles); virtual ~Receiver(); - virtual void handleDataPacket(Driver::Packet* packet, Driver* driver); - virtual void handleBusyPacket(Driver::Packet* packet, Driver* driver); - virtual void handlePingPacket(Driver::Packet* packet, Driver* driver); + virtual void handleDataPacket(Driver::Packet* packet, IpAddress sourceIp); + virtual void handleBusyPacket(Driver::Packet* packet); + virtual void handlePingPacket(Driver::Packet* packet, IpAddress sourceIp); virtual Homa::InMessage* receiveMessage(); virtual void poll(); virtual uint64_t checkTimeouts(); @@ -132,7 +133,7 @@ class Receiver { explicit Message(Receiver* receiver, Driver* driver, size_t packetHeaderLength, size_t messageLength, - Protocol::MessageId id, Driver::Address source, + Protocol::MessageId id, SocketAddress source, int numUnscheduledPackets) : receiver(receiver) , driver(driver) @@ -195,7 +196,7 @@ class Receiver { const Protocol::MessageId id; /// Contains source address this message. - const Driver::Address source; + const SocketAddress source; /// Number of bytes at the beginning of each Packet that should be /// reserved for the Homa transport header. @@ -473,7 +474,7 @@ class Receiver { /// Collection of all peers; used for fast access. Access is protected by /// the schedulerMutex. - std::unordered_map peerTable; + std::unordered_map peerTable; /// List of peers with inbound messages that require grants to complete. /// Access is protected by the schedulerMutex. diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index a49aee2..213e2bd 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -41,7 +41,7 @@ class ReceiverTest : public ::testing::Test { public: ReceiverTest() : mockDriver() - , mockPacket(&payload) + , mockPacket {&payload} , mockPolicyManager(&mockDriver) , payload() , receiver() @@ -68,7 +68,7 @@ class ReceiverTest : public ::testing::Test { static const uint64_t resendIntervalCycles = 100; NiceMock mockDriver; - NiceMock mockPacket; + Homa::Mock::MockDriver::MockPacket mockPacket; NiceMock mockPolicyManager; char payload[1028]; Receiver* receiver; @@ -105,21 +105,21 @@ TEST_F(ReceiverTest, handleDataPacket) header->totalLength = totalMessageLength; header->policyVersion = policyVersion; header->unscheduledIndexLimit = 1; - mockPacket.address = Driver::Address(22); + mockPacket.sourceIp = IpAddress(22); // ------------------------------------------------------------------------- // Receive packet[1]. New message. header->index = 1; mockPacket.length = HEADER_SIZE + 1000; EXPECT_CALL(mockPolicyManager, - signalNewMessage(Eq(mockPacket.address), Eq(policyVersion), + signalNewMessage(Eq(mockPacket.sourceIp), Eq(policyVersion), Eq(totalMessageLength))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- { @@ -148,7 +148,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- EXPECT_EQ(1U, message->numPackets); @@ -162,7 +162,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- EXPECT_EQ(2U, message->numPackets); @@ -177,7 +177,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- EXPECT_EQ(3U, message->numPackets); @@ -192,7 +192,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- EXPECT_EQ(4U, message->numPackets); @@ -207,7 +207,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); // --------- Mock::VerifyAndClearExpectations(&mockDriver); @@ -217,7 +217,7 @@ TEST_F(ReceiverTest, handleBusyPacket_basic) { Protocol::MessageId id(42, 32); Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(0), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{0, 60001}, 0); Receiver::MessageBucket* bucket = receiver->messageBuckets.getBucket(id); bucket->messages.push_back(&message->bucketNode); @@ -228,7 +228,7 @@ TEST_F(ReceiverTest, handleBusyPacket_basic) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - receiver->handleBusyPacket(&mockPacket, &mockDriver); + receiver->handleBusyPacket(&mockPacket); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->resendTimeout.expirationCycleTime); @@ -245,15 +245,15 @@ TEST_F(ReceiverTest, handleBusyPacket_unknown) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - receiver->handleBusyPacket(&mockPacket, &mockDriver); + receiver->handleBusyPacket(&mockPacket); } TEST_F(ReceiverTest, handlePingPacket_basic) { Protocol::MessageId id(42, 32); - Driver::Address mockAddress = 22; + IpAddress mockAddress = 22; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 20000, id, mockAddress, 0); + receiver, &mockDriver, 0, 20000, id, SocketAddress{mockAddress, 0}, 0); ASSERT_TRUE(message->scheduled); Receiver::ScheduledMessageInfo* info = &message->scheduledMessageInfo; info->bytesGranted = 500; @@ -263,25 +263,25 @@ TEST_F(ReceiverTest, handlePingPacket_basic) bucket->messages.push_back(&message->bucketNode); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket(pingPayload); - pingPacket.address = mockAddress; + Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; + pingPacket.sourceIp = mockAddress; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(mockAddress), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, &mockDriver); + receiver->handlePingPacket(&pingPacket, pingPacket.sourceIp); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(0U, message->resendTimeout.expirationCycleTime); - EXPECT_EQ(mockAddress, mockPacket.address); Protocol::Packet::GrantHeader* header = (Protocol::Packet::GrantHeader*)payload; EXPECT_EQ(Protocol::Packet::GRANT, header->common.opcode); @@ -295,22 +295,23 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) Protocol::MessageId id(42, 32); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket(pingPayload); - pingPacket.address = (Driver::Address)22; + Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; + IpAddress mockAddress = (IpAddress)22; + pingPacket.sourceIp = mockAddress; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(mockAddress), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, &mockDriver); + receiver->handlePingPacket(&pingPacket, pingPacket.sourceIp); - EXPECT_EQ(pingPacket.address, mockPacket.address); Protocol::Packet::UnknownHeader* header = (Protocol::Packet::UnknownHeader*)payload; EXPECT_EQ(Protocol::Packet::UNKNOWN, header->common.opcode); @@ -321,10 +322,10 @@ TEST_F(ReceiverTest, receiveMessage) { Receiver::Message* msg0 = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, Protocol::MessageId(42, 0), - Driver::Address(22), 0); + SocketAddress{22, 60001}, 0); Receiver::Message* msg1 = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, Protocol::MessageId(42, 0), - Driver::Address(22), 0); + SocketAddress{22, 60001}, 0); receiver->receivedMessages.queue.push_back(&msg0->receivedMessageNode); receiver->receivedMessages.queue.push_back(&msg1->receivedMessageNode); @@ -349,7 +350,7 @@ TEST_F(ReceiverTest, poll) TEST_F(ReceiverTest, checkTimeouts) { Receiver::Message message(receiver, &mockDriver, 0, 0, - Protocol::MessageId(0, 0), Driver::Address(0), 0); + Protocol::MessageId(0, 0), SocketAddress{0, 60001}, 0); Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); bucket->resendTimeouts.setTimeout(&message.resendTimeout); bucket->messageTimeouts.setTimeout(&message.messageTimeout); @@ -373,7 +374,7 @@ TEST_F(ReceiverTest, Message_destructor_basic) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); const uint16_t NUM_PKTS = 5; @@ -392,7 +393,7 @@ TEST_F(ReceiverTest, Message_destructor_holes) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); const uint16_t NUM_PKTS = 4; @@ -414,10 +415,11 @@ TEST_F(ReceiverTest, Message_acknowledge) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket( + Eq(&mockPacket), Eq(message->source.ip), _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -428,14 +430,13 @@ TEST_F(ReceiverTest, Message_acknowledge) EXPECT_EQ(Protocol::Packet::DONE, header->opcode); EXPECT_EQ(id, header->messageId); EXPECT_EQ(sizeof(Protocol::Packet::DoneHeader), mockPacket.length); - EXPECT_EQ(message->source, mockPacket.address); } TEST_F(ReceiverTest, Message_dropped) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->state = Receiver::Message::State::IN_PROGRESS; @@ -450,10 +451,11 @@ TEST_F(ReceiverTest, Message_fail) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket( + Eq(&mockPacket), Eq(message->source.ip), _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -464,7 +466,6 @@ TEST_F(ReceiverTest, Message_fail) EXPECT_EQ(Protocol::Packet::ERROR, header->opcode); EXPECT_EQ(id, header->messageId); EXPECT_EQ(sizeof(Protocol::Packet::ErrorHeader), mockPacket.length); - EXPECT_EQ(message->source, mockPacket.address); } TEST_F(ReceiverTest, Message_get_basic) @@ -472,10 +473,10 @@ TEST_F(ReceiverTest, Message_get_basic) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; char source[] = "Hello, world!"; message->setPacket(0, &packet0); @@ -499,10 +500,10 @@ TEST_F(ReceiverTest, Message_get_offsetTooLarge) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; message->setPacket(0, &packet0); message->setPacket(1, &packet1); @@ -525,10 +526,10 @@ TEST_F(ReceiverTest, Message_get_missingPacket) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; char source[] = "Hello,"; message->setPacket(0, &packet0); @@ -557,7 +558,7 @@ TEST_F(ReceiverTest, Message_length) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->messageLength = 200; message->start = 20; EXPECT_EQ(180U, message->length()); @@ -567,7 +568,7 @@ TEST_F(ReceiverTest, Message_strip) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->messageLength = 30; message->start = 0; @@ -589,7 +590,7 @@ TEST_F(ReceiverTest, Message_getPacket) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); Driver::Packet* packet = (Driver::Packet*)42; message->packets[0] = packet; @@ -605,7 +606,7 @@ TEST_F(ReceiverTest, Message_setPacket) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); Driver::Packet* packet = (Driver::Packet*)42; EXPECT_FALSE(message->occupied.test(0)); @@ -626,12 +627,12 @@ TEST_F(ReceiverTest, MessageBucket_findMessage) Protocol::MessageId id0 = {42, 0}; Receiver::Message* msg0 = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id0, 0, - 0); + receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id0, + SocketAddress{0, 60001}, 0); Protocol::MessageId id1 = {42, 1}; Receiver::Message* msg1 = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id1, - Driver::Address(0), 0); + SocketAddress{0, 60001}, 0); Protocol::MessageId id_none = {42, 42}; bucket->messages.push_back(&msg0->bucketNode); @@ -659,7 +660,7 @@ TEST_F(ReceiverTest, dropMessage) SpinLock::Lock dummy(dummyMutex); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 1000, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 1000, id, SocketAddress{22, 60001}, 0); ASSERT_TRUE(message->scheduled); Receiver::MessageBucket* bucket = receiver->messageBuckets.getBucket(id); @@ -670,7 +671,7 @@ TEST_F(ReceiverTest, dropMessage) EXPECT_EQ(1U, receiver->messageAllocator.pool.outstandingObjects); EXPECT_EQ(message, bucket->findMessage(id, dummy)); - EXPECT_EQ(&receiver->peerTable[message->source], + EXPECT_EQ(&receiver->peerTable[message->source.ip], message->scheduledMessageInfo.peer); EXPECT_FALSE(bucket->messageTimeouts.list.empty()); EXPECT_FALSE(bucket->resendTimeouts.list.empty()); @@ -693,7 +694,7 @@ TEST_F(ReceiverTest, checkMessageTimeouts_basic) Protocol::MessageId id = {42, 10 + i}; op[i] = reinterpret_cast(i); message[i] = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 1000, id, 0, 0); + receiver, &mockDriver, 0, 1000, id, SocketAddress{0, 60001}, 0); bucket->messages.push_back(&message[i]->bucketNode); bucket->messageTimeouts.setTimeout(&message[i]->messageTimeout); bucket->resendTimeouts.setTimeout(&message[i]->resendTimeout); @@ -767,7 +768,7 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) for (uint64_t i = 0; i < 3; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 10000, id, Driver::Address(22), 5); + receiver, &mockDriver, 0, 10000, id, SocketAddress{22, 60001}, 5); bucket->resendTimeouts.setTimeout(&message[i]->resendTimeout); } @@ -803,14 +804,16 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) char buf1[1024]; char buf2[1024]; - Homa::Mock::MockDriver::MockPacket mockResendPacket1(buf1); - Homa::Mock::MockDriver::MockPacket mockResendPacket2(buf2); + Homa::Mock::MockDriver::MockPacket mockResendPacket1 {buf1}; + Homa::Mock::MockDriver::MockPacket mockResendPacket2 {buf2}; EXPECT_CALL(mockDriver, allocPacket()) .WillOnce(Return(&mockResendPacket1)) .WillOnce(Return(&mockResendPacket2)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket1))).Times(1); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket2))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket1), + Eq(message[0]->source.ip), _)).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket2), + Eq(message[0]->source.ip), _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket1), Eq(1))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket2), Eq(1))) @@ -830,7 +833,6 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) EXPECT_EQ(2U, header1->index); EXPECT_EQ(4U, header1->num); EXPECT_EQ(sizeof(Protocol::Packet::ResendHeader), mockResendPacket1.length); - EXPECT_EQ(message[0]->source, mockResendPacket1.address); Protocol::Packet::ResendHeader* header2 = static_cast(mockResendPacket2.payload); EXPECT_EQ(Protocol::Packet::RESEND, header2->common.opcode); @@ -838,7 +840,6 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) EXPECT_EQ(8U, header2->index); EXPECT_EQ(2U, header2->num); EXPECT_EQ(sizeof(Protocol::Packet::ResendHeader), mockResendPacket2.length); - EXPECT_EQ(message[0]->source, mockResendPacket2.address); // Message[1]: Blocked on grants EXPECT_EQ(10100, message[1]->resendTimeout.expirationCycleTime); @@ -867,7 +868,8 @@ TEST_F(ReceiverTest, trySendGrants) Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - 10000 * (i + 1), id, Driver::Address(100 + i), 10 * (i + 1)); + 10000 * (i + 1), id, SocketAddress{IpAddress(100 + i), 60001}, + 10 * (i + 1)); { SpinLock::Lock lock_scheduler(receiver->schedulerMutex); receiver->schedule(message[i], lock_scheduler); @@ -894,7 +896,7 @@ TEST_F(ReceiverTest, trySendGrants) EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -920,7 +922,7 @@ TEST_F(ReceiverTest, trySendGrants) EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -941,7 +943,7 @@ TEST_F(ReceiverTest, trySendGrants) policy.maxScheduledBytes = 10000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(_)).Times(0); + EXPECT_CALL(mockDriver, sendPacket(_, _, _)).Times(0); receiver->trySendGrants(); @@ -960,7 +962,7 @@ TEST_F(ReceiverTest, trySendGrants) policy.maxScheduledBytes = 10000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(_)).Times(0); + EXPECT_CALL(mockDriver, sendPacket(_, _, _)).Times(0); receiver->trySendGrants(); @@ -975,13 +977,13 @@ TEST_F(ReceiverTest, schedule) { Receiver::Message* message[4]; Receiver::ScheduledMessageInfo* info[4]; - Driver::Address address[4] = {22, 33, 33, 22}; + IpAddress address[4] = {22, 33, 33, 22}; int messageLength[4] = {2000, 3000, 1000, 4000}; for (uint64_t i = 0; i < 4; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - messageLength[i], id, address[i], 0); + messageLength[i], id, SocketAddress{address[i], 60001}, 0); info[i] = &message[i]->scheduledMessageInfo; } @@ -1043,19 +1045,19 @@ TEST_F(ReceiverTest, unschedule) int messageLength[5] = {10, 20, 30, 10, 20}; for (uint64_t i = 0; i < 5; ++i) { Protocol::MessageId id = {42, 10 + i}; - Driver::Address source = Driver::Address((i / 3) + 10); + IpAddress source = IpAddress((i / 3) + 10); message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - messageLength[i], id, source, 0); + messageLength[i], id, SocketAddress{source, 60001}, 0); info[i] = &message[i]->scheduledMessageInfo; receiver->schedule(message[i], lock); } - ASSERT_EQ(Driver::Address(10), message[0]->source); - ASSERT_EQ(Driver::Address(10), message[1]->source); - ASSERT_EQ(Driver::Address(10), message[2]->source); - ASSERT_EQ(Driver::Address(11), message[3]->source); - ASSERT_EQ(Driver::Address(11), message[4]->source); + ASSERT_EQ(IpAddress(10), message[0]->source.ip); + ASSERT_EQ(IpAddress(10), message[1]->source.ip); + ASSERT_EQ(IpAddress(10), message[2]->source.ip); + ASSERT_EQ(IpAddress(11), message[3]->source.ip); + ASSERT_EQ(IpAddress(11), message[4]->source.ip); ASSERT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); ASSERT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); @@ -1128,15 +1130,15 @@ TEST_F(ReceiverTest, updateSchedule) for (uint64_t i = 0; i < 3; ++i) { Protocol::MessageId id = {42, 10 + i}; int messageLength = 10 * (i + 1); - Driver::Address source = Driver::Address(((i + 1) / 2) + 10); + IpAddress source = IpAddress(((i + 1) / 2) + 10); other[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - 10 * (i + 1), id, source, 0); + 10 * (i + 1), id, SocketAddress{source, 60001}, 0); receiver->schedule(other[i], lock); } Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 100, - Protocol::MessageId(42, 1), Driver::Address(11), 0); + Protocol::MessageId(42, 1), SocketAddress{11, 60001}, 0); receiver->schedule(message, lock); ASSERT_EQ(&receiver->peerTable.at(10), other[0]->scheduledMessageInfo.peer); ASSERT_EQ(&receiver->peerTable.at(11), other[1]->scheduledMessageInfo.peer); diff --git a/src/Sender.cc b/src/Sender.cc index c2d0c3f..b993b6b 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -67,10 +67,10 @@ Sender::~Sender() {} * Allocate an OutMessage that can be sent with this Sender. */ Homa::OutMessage* -Sender::allocMessage() +Sender::allocMessage(uint16_t sourcePort) { SpinLock::Lock lock_allocator(messageAllocator.mutex); - return messageAllocator.pool.construct(this, driver); + return messageAllocator.pool.construct(this, sourcePort); } /** @@ -78,12 +78,9 @@ Sender::allocMessage() * * @param packet * Incoming DONE packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) +Sender::handleDonePacket(Driver::Packet* packet) { Protocol::Packet::DoneHeader* header = static_cast(packet->payload); @@ -152,12 +149,9 @@ Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming RESEND packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) +Sender::handleResendPacket(Driver::Packet* packet) { Protocol::Packet::ResendHeader* header = static_cast(packet->payload); @@ -222,7 +216,7 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) // when it's ready. Perf::counters.tx_busy_pkts.add(1); ControlPacket::send( - driver, info->destination, info->id); + driver, info->destination.ip, info->id); } else { // There are some packets to resend but only resend packets that have // already been sent. @@ -230,11 +224,10 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) int resendPriority = policyManager->getResendPriority(); for (uint16_t i = index; i < resendEnd; ++i) { Driver::Packet* packet = info->packets->getPacket(i); - packet->priority = resendPriority; // Packets will be sent at the priority their original priority. Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, message->destination.ip, resendPriority); } } @@ -246,12 +239,9 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming GRANT packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleGrantPacket(Driver::Packet* packet, Driver* driver) +Sender::handleGrantPacket(Driver::Packet* packet) { Protocol::Packet::GrantHeader* header = static_cast(packet->payload); @@ -310,12 +300,9 @@ Sender::handleGrantPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming UNKNOWN packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) +Sender::handleUnknownPacket(Driver::Packet* packet) { Protocol::Packet::UnknownHeader* header = static_cast(packet->payload); @@ -376,7 +363,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) // Get the current policy for unscheduled bytes. Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( - message->destination, message->messageLength); + message->destination.ip, message->messageLength); int unscheduledIndexLimit = ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / message->PACKET_DATA_LENGTH); @@ -401,10 +388,10 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) // If there is only one packet in the message, send it right away. Driver::Packet* dataPacket = message->getPacket(0); assert(dataPacket != nullptr); - dataPacket->priority = policy.priority; Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(dataPacket->length); - driver->sendPacket(dataPacket); + driver->sendPacket(dataPacket, message->destination.ip, + policy.priority); message->state.store(OutMessage::Status::SENT); } else { // Otherwise, queue the message to be sent in SRPT order. @@ -413,7 +400,8 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) // Some of these values should still be set from when the message // was first queued. assert(info->id == message->id); - assert(info->destination == message->destination); + assert(!memcmp(&info->destination, &message->destination, + sizeof(info->destination))); assert(info->packets == message); // Some values need to be updated info->unsentBytes = message->messageLength; @@ -439,12 +427,9 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming ERROR packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleErrorPacket(Driver::Packet* packet, Driver* driver) +Sender::handleErrorPacket(Driver::Packet* packet) { Protocol::Packet::ErrorHeader* header = static_cast(packet->payload); @@ -697,7 +682,7 @@ Sender::Message::reserve(size_t count) * @copydoc Homa::OutMessage::send() */ void -Sender::Message::send(Driver::Address destination, +Sender::Message::send(SocketAddress destination, Sender::Message::Options options) { sender->sendMessage(this, destination, options); @@ -758,7 +743,7 @@ Sender::Message::getOrAllocPacket(size_t index) * @sa dropMessage() */ void -Sender::sendMessage(Sender::Message* message, Driver::Address destination, +Sender::sendMessage(Sender::Message* message, SocketAddress destination, Sender::Message::Options options) { // Prepare the message @@ -767,7 +752,7 @@ Sender::sendMessage(Sender::Message* message, Driver::Address destination, Protocol::MessageId id(transportId, nextMessageSequenceNumber++); Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( - destination, message->messageLength); + destination.ip, message->messageLength); int unscheduledPacketLimit = ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / message->PACKET_DATA_LENGTH); @@ -789,10 +774,10 @@ Sender::sendMessage(Sender::Message* message, Driver::Address destination, i * message->PACKET_DATA_LENGTH); } - packet->address = message->destination; new (packet->payload) Protocol::Packet::DataHeader( - message->id, Util::downCast(message->messageLength), - policy.version, Util::downCast(unscheduledPacketLimit), + message->source.port, destination.port, message->id, + Util::downCast(message->messageLength), policy.version, + Util::downCast(unscheduledPacketLimit), Util::downCast(i)); actualMessageLen += (packet->length - message->TRANSPORT_HEADER_LENGTH); } @@ -816,10 +801,9 @@ Sender::sendMessage(Sender::Message* message, Driver::Address destination, // If there is only one packet in the message, send it right away. Driver::Packet* packet = message->getPacket(0); assert(packet != nullptr); - packet->priority = policy.priority; Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, message->destination.ip, policy.priority); message->state.store(OutMessage::Status::SENT); } else { // Otherwise, queue the message to be sent in SRPT order. @@ -984,7 +968,7 @@ Sender::checkPingTimeouts() // the receiver to ensure it still knows about this Message. Perf::counters.tx_ping_pkts.add(1); ControlPacket::send( - message->driver, message->destination, message->id); + message->driver, message->destination.ip, message->id); } globalNextTimeout = std::min(globalNextTimeout, nextTimeout); } @@ -1038,10 +1022,9 @@ Sender::trySend() break; } // ... if not, send away! - packet->priority = info->priority; Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, message.destination.ip, info->priority); int packetDataBytes = packet->length - info->packets->TRANSPORT_HEADER_LENGTH; assert(info->unsentBytes >= packetDataBytes); diff --git a/src/Sender.h b/src/Sender.h index 471925a..faa5dee 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -46,12 +46,12 @@ class Sender { uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles); virtual ~Sender(); - virtual Homa::OutMessage* allocMessage(); - virtual void handleDonePacket(Driver::Packet* packet, Driver* driver); - virtual void handleResendPacket(Driver::Packet* packet, Driver* driver); - virtual void handleGrantPacket(Driver::Packet* packet, Driver* driver); - virtual void handleUnknownPacket(Driver::Packet* packet, Driver* driver); - virtual void handleErrorPacket(Driver::Packet* packet, Driver* driver); + virtual Homa::OutMessage* allocMessage(uint16_t sourcePort); + virtual void handleDonePacket(Driver::Packet* packet); + virtual void handleResendPacket(Driver::Packet* packet); + virtual void handleGrantPacket(Driver::Packet* packet); + virtual void handleUnknownPacket(Driver::Packet* packet); + virtual void handleErrorPacket(Driver::Packet* packet); virtual void poll(); virtual uint64_t checkTimeouts(); @@ -96,7 +96,7 @@ class Sender { Protocol::MessageId id; /// Contains destination address this message. - Driver::Address destination; + SocketAddress destination; /// Handle to the queue Message for access to the packets that will /// be sent. This member documents that the packets are logically owned @@ -131,13 +131,14 @@ class Sender { /** * Construct an Message. */ - explicit Message(Sender* sender, Driver* driver) + explicit Message(Sender* sender, uint16_t sourcePort) : sender(sender) - , driver(driver) + , driver(sender->driver) , TRANSPORT_HEADER_LENGTH(sizeof(Protocol::Packet::DataHeader)) , PACKET_DATA_LENGTH(driver->getMaxPayloadSize() - TRANSPORT_HEADER_LENGTH) , id(0, 0) + , source{driver->getLocalAddress(), sourcePort} , destination() , options(Options::NONE) , start(0) @@ -161,7 +162,7 @@ class Sender { virtual void prepend(const void* source, size_t count); virtual void release(); virtual void reserve(size_t count); - virtual void send(Driver::Address destination, + virtual void send(SocketAddress destination, Options options = Options::NONE); private: @@ -188,8 +189,11 @@ class Sender { /// Contains the unique identifier for this message. Protocol::MessageId id; - /// Contains destination address this message. - Driver::Address destination; + /// Contains source address of this message. + SocketAddress source; + + /// Contains destination address of this message. + SocketAddress destination; /// Contains flags for any requested optional send behavior. Options options; @@ -384,7 +388,7 @@ class Sender { Protocol::MessageId::Hasher hasher; }; - void sendMessage(Sender::Message* message, Driver::Address destination, + void sendMessage(Sender::Message* message, SocketAddress destination, Message::Options options = Message::Options::NONE); void cancelMessage(Sender::Message* message); void dropMessage(Sender::Message* message); diff --git a/src/SenderTest.cc b/src/SenderTest.cc index fdae6ab..244a7c9 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -36,13 +36,13 @@ class SenderTest : public ::testing::Test { public: SenderTest() : mockDriver() - , mockPacket(&payload) + , mockPacket {&payload} , mockPolicyManager(&mockDriver) , sender() , savedLogPolicy(Debug::getLogPolicy()) { ON_CALL(mockDriver, getBandwidth).WillByDefault(Return(8000)); - ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1027)); + ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1031)); ON_CALL(mockDriver, getQueuedBytes).WillByDefault(Return(0)); Debug::setLogPolicy( Debug::logPolicyFromString("src/ObjectPool@SILENT")); @@ -59,7 +59,7 @@ class SenderTest : public ::testing::Test { } NiceMock mockDriver; - NiceMock mockPacket; + Homa::Mock::MockDriver::MockPacket mockPacket; NiceMock mockPolicyManager; char payload[1028]; Sender* sender; @@ -124,7 +124,7 @@ TEST_F(SenderTest, allocMessage) { EXPECT_EQ(0U, sender->messageAllocator.pool.outstandingObjects); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); } @@ -132,7 +132,7 @@ TEST_F(SenderTest, handleDonePacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_NE(Homa::OutMessage::Status::COMPLETED, message->state); Protocol::Packet::DoneHeader* header = @@ -143,7 +143,7 @@ TEST_F(SenderTest, handleDonePacket_basic) .Times(2); // No message. - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_NE(Homa::OutMessage::Status::COMPLETED, message->state); @@ -151,7 +151,7 @@ TEST_F(SenderTest, handleDonePacket_basic) message->state = Homa::OutMessage::Status::SENT; // Normal expected behavior. - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(nullptr, message->messageTimeout.node.list); EXPECT_EQ(nullptr, message->pingTimeout.node.list); @@ -162,7 +162,7 @@ TEST_F(SenderTest, handleDonePacket_CANCELED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::CANCELED; @@ -173,14 +173,14 @@ TEST_F(SenderTest, handleDonePacket_CANCELED) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); } TEST_F(SenderTest, handleDonePacket_COMPLETED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::COMPLETED; @@ -194,7 +194,7 @@ TEST_F(SenderTest, handleDonePacket_COMPLETED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -211,7 +211,7 @@ TEST_F(SenderTest, handleDonePacket_FAILED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::FAILED; @@ -225,7 +225,7 @@ TEST_F(SenderTest, handleDonePacket_FAILED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -244,7 +244,7 @@ TEST_F(SenderTest, handleDonePacket_IN_PROGRESS) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -258,7 +258,7 @@ TEST_F(SenderTest, handleDonePacket_IN_PROGRESS) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -277,7 +277,7 @@ TEST_F(SenderTest, handleDonePacket_NO_STARTED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::NOT_STARTED; @@ -291,7 +291,7 @@ TEST_F(SenderTest, handleDonePacket_NO_STARTED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -310,10 +310,12 @@ TEST_F(SenderTest, handleResendPacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); std::vector packets; + std::vector priorities; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket(payload)); + packets.push_back(new Homa::Mock::MockDriver::MockPacket {payload}); + priorities.push_back(0); setMessagePacket(message, i, packets[i]); } SenderTest::addMessage(sender, id, message, true, 5); @@ -331,22 +333,24 @@ TEST_F(SenderTest, handleResendPacket_basic) resendHdr->priority = 4; EXPECT_CALL(mockPolicyManager, getResendPriority).WillOnce(Return(7)); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[3]))).Times(1); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(packets[3]), _, _)).WillOnce( + [&priorities] (auto _1, auto _2, int p) { priorities[3] = p; }); + EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]), _, _)).WillOnce( + [&priorities] (auto _1, auto _2, int p) { priorities[4] = p; }); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(5U, info->packetsSent); EXPECT_EQ(8U, info->packetsGranted); EXPECT_EQ(4, info->priority); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); - EXPECT_EQ(0, packets[2]->priority); - EXPECT_EQ(7, packets[3]->priority); - EXPECT_EQ(7, packets[4]->priority); - EXPECT_EQ(0, packets[5]->priority); + EXPECT_EQ(0, priorities[2]); + EXPECT_EQ(7, priorities[3]); + EXPECT_EQ(7, priorities[4]); + EXPECT_EQ(0, priorities[5]); EXPECT_TRUE(sender->sendReady.load()); for (int i = 0; i < 10; ++i) { @@ -366,7 +370,7 @@ TEST_F(SenderTest, handleResendPacket_staleResend) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); } TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) @@ -374,10 +378,10 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload); + new Homa::Mock::MockDriver::MockPacket {payload}; setMessagePacket(message, 0, packet); Protocol::Packet::ResendHeader* resendHdr = @@ -393,7 +397,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -414,10 +418,10 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); std::vector packets; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket(payload)); + packets.push_back(new Homa::Mock::MockDriver::MockPacket {payload}); setMessagePacket(message, i, packets[i]); } SenderTest::addMessage(sender, id, message, true, 5); @@ -440,7 +444,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -464,9 +468,9 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); char data[1028]; - Homa::Mock::MockDriver::MockPacket dataPacket(data); + Homa::Mock::MockDriver::MockPacket dataPacket {data}; for (int i = 0; i < 10; ++i) { setMessagePacket(message, i, &dataPacket); } @@ -484,18 +488,18 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) // Expect the BUSY control packet. char busy[1028]; - Homa::Mock::MockDriver::MockPacket busyPacket(busy); + Homa::Mock::MockDriver::MockPacket busyPacket {busy}; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&busyPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&busyPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&busyPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&busyPacket), Eq(1))) .Times(1); // Expect no data to be sent but the RESEND packet to be release. - EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket))).Times(0); + EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket), _, _)).Times(0); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(5U, info->packetsSent); EXPECT_EQ(8U, info->packetsGranted); @@ -511,7 +515,7 @@ TEST_F(SenderTest, handleGrantPacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -530,7 +534,7 @@ TEST_F(SenderTest, handleGrantPacket_basic) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(7, info->packetsGranted); EXPECT_EQ(6, info->priority); @@ -543,7 +547,7 @@ TEST_F(SenderTest, handleGrantPacket_excessiveGrant) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -565,7 +569,7 @@ TEST_F(SenderTest, handleGrantPacket_excessiveGrant) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -590,7 +594,7 @@ TEST_F(SenderTest, handleGrantPacket_staleGrant) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -608,7 +612,7 @@ TEST_F(SenderTest, handleGrantPacket_staleGrant) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(5, info->packetsGranted); EXPECT_EQ(2, info->priority); @@ -628,23 +632,23 @@ TEST_F(SenderTest, handleGrantPacket_dropGrant) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); } TEST_F(SenderTest, handleUnknownPacket_basic) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policyOld = {1, 2000, 1}; Core::Policy::Unscheduled policyNew = {2, 3000, 2}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); std::vector packets; char payload[5][1028]; for (int i = 0; i < 5; ++i) { Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload[i]); + new Homa::Mock::MockDriver::MockPacket {payload[i]}; Protocol::Packet::DataHeader* header = static_cast(packet->payload); header->policyVersion = policyOld.version; @@ -674,12 +678,12 @@ TEST_F(SenderTest, handleUnknownPacket_basic) EXPECT_CALL( mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(message->messageLength))) + getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) .WillOnce(Return(policyNew)); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); for (int i = 0; i < 3; ++i) { @@ -706,13 +710,13 @@ TEST_F(SenderTest, handleUnknownPacket_basic) TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policyOld = {1, 2000, 1}; Core::Policy::Unscheduled policyNew = {2, 3000, 2}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); - Homa::Mock::MockDriver::MockPacket dataPacket(payload); + dynamic_cast(sender->allocMessage(0)); + Homa::Mock::MockDriver::MockPacket dataPacket {payload}; Protocol::Packet::DataHeader* dataHeader = static_cast(dataPacket.payload); dataHeader->policyVersion = policyOld.version; @@ -733,13 +737,13 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) EXPECT_CALL( mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(message->messageLength))) + getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) .WillOnce(Return(policyNew)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_EQ(policyNew.version, dataHeader->policyVersion); @@ -754,7 +758,7 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = @@ -764,7 +768,7 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) char payload[5][1028]; for (int i = 0; i < 5; ++i) { Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload[i]); + new Homa::Mock::MockDriver::MockPacket {payload[i]}; packets.push_back(packet); setMessagePacket(message, i, packet); } @@ -806,14 +810,14 @@ TEST_F(SenderTest, handleUnknownPacket_no_message) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); } TEST_F(SenderTest, handleUnknownPacket_done) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::COMPLETED); EXPECT_EQ(0U, message->messageTimeout.expirationCycleTime); @@ -826,7 +830,7 @@ TEST_F(SenderTest, handleUnknownPacket_done) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::COMPLETED, message->state); EXPECT_EQ(0U, message->messageTimeout.expirationCycleTime); @@ -838,7 +842,7 @@ TEST_F(SenderTest, handleErrorPacket_basic) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); bucket->messageTimeouts.setTimeout(&message->messageTimeout); bucket->pingTimeouts.setTimeout(&message->pingTimeout); @@ -851,7 +855,7 @@ TEST_F(SenderTest, handleErrorPacket_basic) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(nullptr, message->messageTimeout.node.list); EXPECT_EQ(nullptr, message->pingTimeout.node.list); @@ -863,7 +867,7 @@ TEST_F(SenderTest, handleErrorPacket_CANCELED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::CANCELED); @@ -874,7 +878,7 @@ TEST_F(SenderTest, handleErrorPacket_CANCELED) EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::CANCELED, message->state); } @@ -884,7 +888,7 @@ TEST_F(SenderTest, handleErrorPacket_NOT_STARTED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::NOT_STARTED); @@ -898,7 +902,7 @@ TEST_F(SenderTest, handleErrorPacket_NOT_STARTED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -920,7 +924,7 @@ TEST_F(SenderTest, handleErrorPacket_IN_PROGRESS) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::IN_PROGRESS); @@ -934,7 +938,7 @@ TEST_F(SenderTest, handleErrorPacket_IN_PROGRESS) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -956,7 +960,7 @@ TEST_F(SenderTest, handleErrorPacket_COMPLETED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::COMPLETED); @@ -970,7 +974,7 @@ TEST_F(SenderTest, handleErrorPacket_COMPLETED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -992,7 +996,7 @@ TEST_F(SenderTest, handleErrorPacket_FAILED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::FAILED); @@ -1006,7 +1010,7 @@ TEST_F(SenderTest, handleErrorPacket_FAILED) VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -1029,7 +1033,7 @@ TEST_F(SenderTest, handleErrorPacket_noMessage) header->common.messageId = id; EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); } TEST_F(SenderTest, poll) @@ -1040,7 +1044,7 @@ TEST_F(SenderTest, poll) TEST_F(SenderTest, checkTimeouts) { - Sender::Message message(sender, &mockDriver); + Sender::Message message(sender, 0); Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); bucket->pingTimeouts.setTimeout(&message.pingTimeout); bucket->messageTimeouts.setTimeout(&message.messageTimeout); @@ -1065,7 +1069,7 @@ TEST_F(SenderTest, Message_destructor) const int MAX_RAW_PACKET_LENGTH = 2000; ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message* msg = new Sender::Message(sender, &mockDriver); + Sender::Message* msg = new Sender::Message(sender, 0); const uint16_t NUM_PKTS = 5; @@ -1086,10 +1090,10 @@ TEST_F(SenderTest, Message_append_basic) ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + MAX_RAW_PACKET_LENGTH); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1126,10 +1130,10 @@ TEST_F(SenderTest, Message_append_truncated) ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + MAX_RAW_PACKET_LENGTH); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1155,7 +1159,7 @@ TEST_F(SenderTest, Message_append_truncated) EXPECT_STREQ("append", m.function); EXPECT_EQ(int(Debug::LogLevel::WARNING), m.logLevel); EXPECT_EQ( - "Max message size limit (2020352B) reached; 7 of 14 bytes appended", + "Max message size limit (2016256B) reached; 7 of 14 bytes appended", m.message); Debug::setLogHandler(std::function()); @@ -1183,10 +1187,10 @@ TEST_F(SenderTest, Message_length) TEST_F(SenderTest, Message_prepend) { ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1218,10 +1222,10 @@ TEST_F(SenderTest, Message_release) TEST_F(SenderTest, Message_reserve) { - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1259,7 +1263,7 @@ TEST_F(SenderTest, Message_send) TEST_F(SenderTest, Message_getPacket) { - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); Driver::Packet* packet = (Driver::Packet*)42; msg.packets[0] = packet; @@ -1273,10 +1277,10 @@ TEST_F(SenderTest, Message_getPacket) TEST_F(SenderTest, Message_getOrAllocPacket) { // TODO(cstlee): cleanup - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; EXPECT_FALSE(msg.occupied.test(0)); EXPECT_EQ(0U, msg.numPackets); @@ -1298,9 +1302,9 @@ TEST_F(SenderTest, MessageBucket_findMessage) Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); Sender::Message* msg0 = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::Message* msg1 = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); msg0->id = {42, 0}; msg1->id = {42, 1}; Protocol::MessageId id_none = {42, 42}; @@ -1329,35 +1333,42 @@ TEST_F(SenderTest, sendMessage_basic) { Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; + uint16_t sport = 0; + uint16_t dport = 60001; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(sport)); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); setMessagePacket(message, 0, &mockPacket); message->messageLength = 420; mockPacket.length = message->messageLength + message->TRANSPORT_HEADER_LENGTH; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, dport}; Core::Policy::Unscheduled policy = {1, 3000, 2}; EXPECT_FALSE(bucket->messages.contains(&message->bucketNode)); EXPECT_CALL(mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(420))) + getUnscheduledPolicy(Eq(destination.ip), Eq(420))) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + int mockPriority = 0; + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(destination.ip), _)) + .WillOnce([&mockPriority] (auto _1, auto _2, int p){mockPriority = p;}); sender->sendMessage(message, destination, Sender::Message::Options::NO_RETRY); // Check Message metadata EXPECT_EQ(id, message->id); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(Sender::Message::Options::NO_RETRY, message->options); // Check packet metadata Protocol::Packet::DataHeader* header = static_cast(mockPacket.payload); + EXPECT_EQ(htobe16(sport), header->common.sport); + EXPECT_EQ(htobe16(dport), header->common.dport); EXPECT_EQ(id, header->common.messageId); EXPECT_EQ(420U, header->totalLength); EXPECT_EQ(policy.version, header->policyVersion); @@ -1370,8 +1381,7 @@ TEST_F(SenderTest, sendMessage_basic) EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); // Check sent packet metadata - EXPECT_EQ(22U, (uint64_t)mockPacket.address); - EXPECT_EQ(policy.priority, mockPacket.priority); + EXPECT_EQ(policy.priority, mockPriority); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_FALSE(sender->sendReady.load()); @@ -1381,48 +1391,48 @@ TEST_F(SenderTest, sendMessage_multipacket) { char payload0[1027]; char payload1[1027]; - NiceMock packet0(payload0); - NiceMock packet1(payload1); + Homa::Mock::MockDriver::MockPacket packet0 {payload0}; + Homa::Mock::MockDriver::MockPacket packet1 {payload1}; Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); setMessagePacket(message, 0, &packet0); setMessagePacket(message, 1, &packet1); message->messageLength = 1420; - packet0.length = 1000 + 27; - packet1.length = 420 + 27; - Driver::Address destination = (Driver::Address)22; + packet0.length = 1000 + 31; + packet1.length = 420 + 31; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policy = {1, 1000, 2}; - EXPECT_EQ(27U, sizeof(Protocol::Packet::DataHeader)); + EXPECT_EQ(31U, sizeof(Protocol::Packet::DataHeader)); EXPECT_EQ(1000U, message->PACKET_DATA_LENGTH); EXPECT_CALL(mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(1420))) + getUnscheduledPolicy(Eq(destination.ip), Eq(1420))) .WillOnce(Return(policy)); sender->sendMessage(message, destination); // Check Message metadata EXPECT_EQ(id, message->id); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); // Check packet metadata Protocol::Packet::DataHeader* header = nullptr; // Packet0 - EXPECT_EQ(22U, (uint64_t)packet0.address); header = static_cast(packet0.payload); EXPECT_EQ(message->id, header->common.messageId); EXPECT_EQ(message->messageLength, header->totalLength); // Packet1 - EXPECT_EQ(22U, (uint64_t)packet1.address); header = static_cast(packet1.payload); EXPECT_EQ(message->id, header->common.messageId); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(message->messageLength, header->totalLength); // Check Sender metadata @@ -1441,13 +1451,13 @@ TEST_F(SenderTest, sendMessage_missingPacket) Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); setMessagePacket(message, 1, &mockPacket); Core::Policy::Unscheduled policy = {1, 1000, 2}; ON_CALL(mockPolicyManager, getUnscheduledPolicy(_, _)) .WillByDefault(Return(policy)); - EXPECT_DEATH(sender->sendMessage(message, Driver::Address()), + EXPECT_DEATH(sender->sendMessage(message, SocketAddress{0, 0}), ".*Incomplete message with id \\(22:1\\); missing packet at " "offset 0; this shouldn't happen.*"); } @@ -1457,17 +1467,17 @@ TEST_F(SenderTest, sendMessage_unscheduledLimit) Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); for (int i = 0; i < 9; ++i) { setMessagePacket(message, i, &mockPacket); } message->messageLength = 9000; mockPacket.length = 1000 + sizeof(Protocol::Packet::DataHeader); - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policy = {1, 4500, 2}; EXPECT_EQ(9U, message->numPackets); EXPECT_EQ(1000U, message->PACKET_DATA_LENGTH); - EXPECT_CALL(mockPolicyManager, getUnscheduledPolicy(destination, 9000)) + EXPECT_CALL(mockPolicyManager, getUnscheduledPolicy(destination.ip, 9000)) .WillOnce(Return(policy)); sender->sendMessage(message, destination); @@ -1481,7 +1491,7 @@ TEST_F(SenderTest, cancelMessage) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->pingTimeouts.setTimeout(&message->pingTimeout); @@ -1505,7 +1515,7 @@ TEST_F(SenderTest, cancelMessage) TEST_F(SenderTest, dropMessage) { Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); sender->dropMessage(message); @@ -1518,7 +1528,7 @@ TEST_F(SenderTest, checkMessageTimeouts_basic) Sender::Message* message[4]; for (uint64_t i = 0; i < 4; ++i) { Protocol::MessageId id = {42, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); + message[i] = dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message[i]); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->messageTimeouts.setTimeout(&message[i]->messageTimeout); @@ -1581,7 +1591,7 @@ TEST_F(SenderTest, checkPingTimeouts_basic) Sender::Message* message[5]; for (uint64_t i = 0; i < 5; ++i) { Protocol::MessageId id = {42, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); + message[i] = dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message[i]); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->pingTimeouts.setTimeout(&message[i]->pingTimeout); @@ -1606,7 +1616,7 @@ TEST_F(SenderTest, checkPingTimeouts_basic) EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -1645,7 +1655,7 @@ TEST_F(SenderTest, trySend_basic) { Protocol::MessageId id = {42, 10}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; SenderTest::addMessage(sender, id, message, true, 3); Homa::Mock::MockDriver::MockPacket* packet[5]; @@ -1653,7 +1663,7 @@ TEST_F(SenderTest, trySend_basic) const uint32_t PACKET_DATA_SIZE = PACKET_SIZE - message->TRANSPORT_HEADER_LENGTH; for (int i = 0; i < 5; ++i) { - packet[i] = new Homa::Mock::MockDriver::MockPacket(payload); + packet[i] = new Homa::Mock::MockDriver::MockPacket {payload}; packet[i]->length = PACKET_SIZE; setMessagePacket(message, i, packet[i]); info->unsentBytes += PACKET_DATA_SIZE; @@ -1668,8 +1678,8 @@ TEST_F(SenderTest, trySend_basic) EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); // 3 granted packets; 2 will send; queue limit reached. - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]), _, _)); sender->trySend(); // < test call EXPECT_TRUE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); @@ -1681,7 +1691,7 @@ TEST_F(SenderTest, trySend_basic) Mock::VerifyAndClearExpectations(&mockDriver); // 1 packet to be sent; grant limit reached. - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]), _, _)); sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); @@ -1708,8 +1718,8 @@ TEST_F(SenderTest, trySend_basic) // 2 more granted packets; will finish. info->packetsGranted = 5; sender->sendReady = true; - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[3]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[4]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[3]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[4]), _, _)); sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); @@ -1732,10 +1742,10 @@ TEST_F(SenderTest, trySend_multipleMessages) Homa::Mock::MockDriver::MockPacket* packet[3]; for (uint64_t i = 0; i < 3; ++i) { Protocol::MessageId id = {22, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); + message[i] = dynamic_cast(sender->allocMessage(0)); info[i] = &message[i]->queuedMessageInfo; SenderTest::addMessage(sender, id, message[i], true, 1); - packet[i] = new Homa::Mock::MockDriver::MockPacket(payload); + packet[i] = new Homa::Mock::MockDriver::MockPacket {payload}; packet[i]->length = sender->driver->getMaxPayloadSize() / 4; setMessagePacket(message[i], 0, packet[i]); info[i]->unsentBytes += @@ -1758,9 +1768,9 @@ TEST_F(SenderTest, trySend_multipleMessages) EXPECT_EQ(1, info[2]->packetsGranted); info[2]->packetsSent = 0; - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]))); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]), _, _)); sender->trySend(); @@ -1779,7 +1789,7 @@ TEST_F(SenderTest, trySend_alreadyRunning) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; SenderTest::addMessage(sender, id, message, true, 1); setMessagePacket(message, 0, &mockPacket); diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index 310e099..d4ebc70 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -94,61 +94,66 @@ void TransportImpl::processPackets() { // Keep track of time spent doing active processing versus idle. - Perf::Timer activityTimer; - activityTimer.split(); - uint64_t activeTime = 0; - uint64_t idleTime = 0; + uint64_t cycles = PerfUtils::Cycles::rdtsc(); const int MAX_BURST = 32; Driver::Packet* packets[MAX_BURST]; int numPackets = driver->receivePackets(MAX_BURST, packets); for (int i = 0; i < numPackets; ++i) { Driver::Packet* packet = packets[i]; - assert(packet->length >= - Util::downCast(sizeof(Protocol::Packet::CommonHeader))); - Perf::counters.rx_bytes.add(packet->length); - Protocol::Packet::CommonHeader* header = - static_cast(packet->payload); - switch (header->opcode) { - case Protocol::Packet::DATA: - Perf::counters.rx_data_pkts.add(1); - receiver->handleDataPacket(packet, driver); - break; - case Protocol::Packet::GRANT: - Perf::counters.rx_grant_pkts.add(1); - sender->handleGrantPacket(packet, driver); - break; - case Protocol::Packet::DONE: - Perf::counters.rx_done_pkts.add(1); - sender->handleDonePacket(packet, driver); - break; - case Protocol::Packet::RESEND: - Perf::counters.rx_resend_pkts.add(1); - sender->handleResendPacket(packet, driver); - break; - case Protocol::Packet::BUSY: - Perf::counters.rx_busy_pkts.add(1); - receiver->handleBusyPacket(packet, driver); - break; - case Protocol::Packet::PING: - Perf::counters.rx_ping_pkts.add(1); - receiver->handlePingPacket(packet, driver); - break; - case Protocol::Packet::UNKNOWN: - Perf::counters.rx_unknown_pkts.add(1); - sender->handleUnknownPacket(packet, driver); - break; - case Protocol::Packet::ERROR: - Perf::counters.rx_error_pkts.add(1); - sender->handleErrorPacket(packet, driver); - break; - } - activeTime += activityTimer.split(); + processPacket(packet, packet->sourceIp); + } + + cycles = PerfUtils::Cycles::rdtsc() - cycles; + if (numPackets > 0) { + Perf::counters.active_cycles.add(cycles); + } else { + Perf::counters.idle_cycles.add(cycles); } - idleTime += activityTimer.split(); +} - Perf::counters.active_cycles.add(activeTime); - Perf::counters.idle_cycles.add(idleTime); +void +TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) +{ + assert(packet->length >= + Util::downCast(sizeof(Protocol::Packet::CommonHeader))); + Perf::counters.rx_bytes.add(packet->length); + Protocol::Packet::CommonHeader* header = + static_cast(packet->payload); + switch (header->opcode) { + case Protocol::Packet::DATA: + Perf::counters.rx_data_pkts.add(1); + receiver->handleDataPacket(packet, sourceIp); + break; + case Protocol::Packet::GRANT: + Perf::counters.rx_grant_pkts.add(1); + sender->handleGrantPacket(packet); + break; + case Protocol::Packet::DONE: + Perf::counters.rx_done_pkts.add(1); + sender->handleDonePacket(packet); + break; + case Protocol::Packet::RESEND: + Perf::counters.rx_resend_pkts.add(1); + sender->handleResendPacket(packet); + break; + case Protocol::Packet::BUSY: + Perf::counters.rx_busy_pkts.add(1); + receiver->handleBusyPacket(packet); + break; + case Protocol::Packet::PING: + Perf::counters.rx_ping_pkts.add(1); + receiver->handlePingPacket(packet, sourceIp); + break; + case Protocol::Packet::UNKNOWN: + Perf::counters.rx_unknown_pkts.add(1); + sender->handleUnknownPacket(packet); + break; + case Protocol::Packet::ERROR: + Perf::counters.rx_error_pkts.add(1); + sender->handleErrorPacket(packet); + break; + } } } // namespace Core diff --git a/src/TransportImpl.h b/src/TransportImpl.h index 2d559be..ad46f99 100644 --- a/src/TransportImpl.h +++ b/src/TransportImpl.h @@ -47,9 +47,10 @@ class TransportImpl : public Transport { ~TransportImpl(); /// See Homa::Transport::alloc() - virtual Homa::unique_ptr alloc() + virtual Homa::unique_ptr alloc(uint16_t sourcePort) { - return Homa::unique_ptr(sender->allocMessage()); + Homa::OutMessage* outMessage = sender->allocMessage(sourcePort); + return Homa::unique_ptr(outMessage); } /// See Homa::Transport::receive() @@ -74,6 +75,7 @@ class TransportImpl : public Transport { private: void processPackets(); + void processPacket(Driver::Packet* packet, IpAddress source); /// Unique identifier for this transport. const std::atomic transportId; diff --git a/src/TransportImplTest.cc b/src/TransportImplTest.cc index 0e0ab60..c69a36a 100644 --- a/src/TransportImplTest.cc +++ b/src/TransportImplTest.cc @@ -27,6 +27,7 @@ namespace Homa { namespace Core { namespace { +using ::testing::_; using ::testing::DoAll; using ::testing::Eq; using ::testing::NiceMock; @@ -101,68 +102,60 @@ TEST_F(TransportImplTest, processPackets) Homa::Driver::Packet* packets[8]; // Set DATA packet - Homa::Mock::MockDriver::MockPacket dataPacket(payload[0], 1024); + Homa::Mock::MockDriver::MockPacket dataPacket {payload[0], 1024}; static_cast(dataPacket.payload) ->common.opcode = Protocol::Packet::DATA; packets[0] = &dataPacket; - EXPECT_CALL(*mockReceiver, - handleDataPacket(Eq(&dataPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockReceiver, handleDataPacket(Eq(&dataPacket), _)); // Set GRANT packet - Homa::Mock::MockDriver::MockPacket grantPacket(payload[1], 1024); + Homa::Mock::MockDriver::MockPacket grantPacket {payload[1], 1024}; static_cast(grantPacket.payload) ->common.opcode = Protocol::Packet::GRANT; packets[1] = &grantPacket; - EXPECT_CALL(*mockSender, - handleGrantPacket(Eq(&grantPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleGrantPacket(Eq(&grantPacket))); // Set DONE packet - Homa::Mock::MockDriver::MockPacket donePacket(payload[2], 1024); + Homa::Mock::MockDriver::MockPacket donePacket {payload[2], 1024}; static_cast(donePacket.payload) ->common.opcode = Protocol::Packet::DONE; packets[2] = &donePacket; - EXPECT_CALL(*mockSender, - handleDonePacket(Eq(&donePacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleDonePacket(Eq(&donePacket))); // Set RESEND packet - Homa::Mock::MockDriver::MockPacket resendPacket(payload[3], 1024); + Homa::Mock::MockDriver::MockPacket resendPacket {payload[3], 1024}; static_cast(resendPacket.payload) ->common.opcode = Protocol::Packet::RESEND; packets[3] = &resendPacket; - EXPECT_CALL(*mockSender, - handleResendPacket(Eq(&resendPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleResendPacket(Eq(&resendPacket))); // Set BUSY packet - Homa::Mock::MockDriver::MockPacket busyPacket(payload[4], 1024); + Homa::Mock::MockDriver::MockPacket busyPacket {payload[4], 1024}; static_cast(busyPacket.payload) ->common.opcode = Protocol::Packet::BUSY; packets[4] = &busyPacket; - EXPECT_CALL(*mockReceiver, - handleBusyPacket(Eq(&busyPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockReceiver, handleBusyPacket(Eq(&busyPacket))); // Set PING packet - Homa::Mock::MockDriver::MockPacket pingPacket(payload[5], 1024); + Homa::Mock::MockDriver::MockPacket pingPacket {payload[5], 1024}; static_cast(pingPacket.payload) ->common.opcode = Protocol::Packet::PING; packets[5] = &pingPacket; - EXPECT_CALL(*mockReceiver, - handlePingPacket(Eq(&pingPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockReceiver, handlePingPacket(Eq(&pingPacket), _)); // Set UNKNOWN packet - Homa::Mock::MockDriver::MockPacket unknownPacket(payload[6], 1024); + Homa::Mock::MockDriver::MockPacket unknownPacket {payload[6], 1024}; static_cast(unknownPacket.payload) ->common.opcode = Protocol::Packet::UNKNOWN; packets[6] = &unknownPacket; - EXPECT_CALL(*mockSender, - handleUnknownPacket(Eq(&unknownPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleUnknownPacket(Eq(&unknownPacket))); // Set ERROR packet - Homa::Mock::MockDriver::MockPacket errorPacket(payload[7], 1024); + Homa::Mock::MockDriver::MockPacket errorPacket {payload[7], 1024}; static_cast(errorPacket.payload) ->common.opcode = Protocol::Packet::ERROR; packets[7] = &errorPacket; - EXPECT_CALL(*mockSender, - handleErrorPacket(Eq(&errorPacket), Eq(&mockDriver))); + EXPECT_CALL(*mockSender, handleErrorPacket(Eq(&errorPacket))); EXPECT_CALL(mockDriver, receivePackets) .WillOnce(DoAll(SetArrayArgument<1>(packets, packets + 8), Return(8))); diff --git a/test/system_test.cc b/test/system_test.cc index 8e43238..266d842 100644 --- a/test/system_test.cc +++ b/test/system_test.cc @@ -70,7 +70,7 @@ struct Node { }; void -serverMain(Node* server, std::vector addresses) +serverMain(Node* server, std::vector addresses) { while (true) { if (server->run.load() == false) { @@ -101,7 +101,7 @@ serverMain(Node* server, std::vector addresses) * Number of Op that failed. */ int -clientMain(int count, int size, std::vector addresses) +clientMain(int count, int size, std::vector addresses) { std::random_device rd; std::mt19937 gen(rd()); @@ -119,9 +119,9 @@ clientMain(int count, int size, std::vector addresses) payload[i] = randData(gen); } - std::string destAddress = addresses[randAddr(gen)]; + Homa::IpAddress destAddress = addresses[randAddr(gen)]; - Homa::unique_ptr message = client.transport->alloc(); + Homa::unique_ptr message = client.transport->alloc(0); { MessageHeader header; header.id = id; @@ -133,7 +133,7 @@ clientMain(int count, int size, std::vector addresses) << std::endl; } } - message->send(client.driver.getAddress(&destAddress)); + message->send(Homa::SocketAddress{destAddress, 60001}); while (1) { Homa::OutMessage::Status status = message->getStatus(); @@ -185,12 +185,11 @@ main(int argc, char* argv[]) Homa::Drivers::Fake::FakeNetworkConfig::setPacketLossRate(packetLossRate); uint64_t nextServerId = 101; - std::vector addresses; + std::vector addresses; std::vector servers; for (int i = 0; i < numServers; ++i) { Node* server = new Node(nextServerId++); - addresses.emplace_back(std::string( - server->driver.addressToString(server->driver.getLocalAddress()))); + addresses.emplace_back(server->driver.getLocalAddress()); servers.push_back(server); } From a305d46158b52eb61a79b6b55cc563b5cd084c7c Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Mon, 20 Jul 2020 22:00:51 -0700 Subject: [PATCH 03/15] Update DpdkDriver to use the new API introduced in the previous commit - Reenable DPDK in CMakeLists.txt - Initialize an ARP table at driver startup using the content of /proc/net/arp - Select eth port via the symbolic name of the network interface (e.g., eno1d1) (the current implementation uses ioctl to obtain the IP and MAC addresses of a network interface) - Add a system test for DPDK driver: test/dpdk_test.cc --- CMakeLists.txt | 71 +++++----- include/Homa/Drivers/DPDK/DpdkDriver.h | 31 ++--- include/Homa/Util.h | 2 + src/Drivers/DPDK/DpdkDriver.cc | 47 ++----- src/Drivers/DPDK/DpdkDriverImpl.cc | 180 ++++++++++++++++--------- src/Drivers/DPDK/DpdkDriverImpl.h | 53 +++++--- src/Drivers/DPDK/MacAddress.cc | 53 -------- src/Drivers/DPDK/MacAddress.h | 13 +- src/Drivers/DPDK/MacAddressTest.cc | 40 ------ src/Drivers/RawAddressType.h | 38 ------ src/Util.cc | 16 +++ test/CMakeLists.txt | 12 ++ test/Output.h | 105 +++++++++++++++ test/dpdk_test.cc | 88 ++++++++++++ 14 files changed, 435 insertions(+), 314 deletions(-) delete mode 100644 src/Drivers/RawAddressType.h create mode 100644 test/Output.h create mode 100644 test/dpdk_test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 2f82962..4a6f9c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/modules) find_package(Doxygen OPTIONAL_COMPONENTS dot mscgen dia) # Network Interface library (https://www.dpdk.org/) -# find_package(Dpdk REQUIRED) +find_package(Dpdk REQUIRED) # Source control tool; needed to download external libraries. find_package(Git REQUIRED) @@ -135,34 +135,34 @@ target_compile_options(FakeDriver ) ## lib DpdkDriver ############################################################## -#add_library(DpdkDriver -# src/Drivers/DPDK/DpdkDriver.cc -# src/Drivers/DPDK/DpdkDriverImpl.cc -# src/Drivers/DPDK/MacAddress.cc -#) -#add_library(Homa::DpdkDriver ALIAS DpdkDriver) -#target_include_directories(DpdkDriver -# PUBLIC -# $ -# $ -# PRIVATE -# ${CMAKE_CURRENT_SOURCE_DIR}/src -#) -#target_link_libraries(DpdkDriver -# PRIVATE -# Dpdk::Dpdk -# PUBLIC -# Homa -#) -#target_compile_features(DpdkDriver -# PUBLIC -# cxx_std_11 -#) -#target_compile_options(DpdkDriver -# PRIVATE -# -Wall -# -Wextra -#) +add_library(DpdkDriver + src/Drivers/DPDK/DpdkDriver.cc + src/Drivers/DPDK/DpdkDriverImpl.cc + src/Drivers/DPDK/MacAddress.cc +) +add_library(Homa::DpdkDriver ALIAS DpdkDriver) +target_include_directories(DpdkDriver + PUBLIC + $ + $ + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src +) +target_link_libraries(DpdkDriver + PRIVATE + Dpdk::Dpdk + PUBLIC + Homa +) +target_compile_features(DpdkDriver + PUBLIC + cxx_std_11 +) +target_compile_options(DpdkDriver + PRIVATE + -Wall + -Wextra +) ################################################################################ ## Tests ####################################################################### @@ -195,8 +195,7 @@ endif() ## Install & Export ############################################################ ################################################################################ -#install(TARGETS Homa DpdkDriver FakeDriver EXPORT HomaTargets -install(TARGETS Homa FakeDriver EXPORT HomaTargets +install(TARGETS Homa DpdkDriver FakeDriver EXPORT HomaTargets LIBRARY DESTINATION lib ARCHIVE DESTINATION lib RUNTIME DESTINATION bin @@ -275,11 +274,11 @@ target_sources(unit_test target_link_libraries(unit_test FakeDriver) #DPDK Tests -#target_sources(unit_test -# PUBLIC -# ${CMAKE_CURRENT_SOURCE_DIR}/src/Drivers/DPDK/MacAddressTest.cc -#) -#target_link_libraries(unit_test DpdkDriver) +target_sources(unit_test + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/src/Drivers/DPDK/MacAddressTest.cc +) +target_link_libraries(unit_test DpdkDriver) target_link_libraries(unit_test gmock_main) # -fno-access-control allows access to private members for testing diff --git a/include/Homa/Drivers/DPDK/DpdkDriver.h b/include/Homa/Drivers/DPDK/DpdkDriver.h index dafb05f..010d59b 100644 --- a/include/Homa/Drivers/DPDK/DpdkDriver.h +++ b/include/Homa/Drivers/DPDK/DpdkDriver.h @@ -53,14 +53,14 @@ class DpdkDriver : public Driver { * has exclusive access to DPDK. Note: This call will initialize the DPDK * EAL with default values. * - * @param port - * Selects which physical port to use for communication. + * @param ifname + * Selects which network interface to use for communication. * @param config * Optional configuration parameters (see Config). * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, const Config* const config = nullptr); + DpdkDriver(const char* ifname, const Config* const config = nullptr); /** * Construct a DpdkDriver and initialize the DPDK EAL using the provided @@ -75,7 +75,7 @@ class DpdkDriver : public Driver { * overriding the default affinity set by rte_eal_init(). * * @param port - * Selects which physical port to use for communication. + * Selects which network interface to use for communication. * @param argc * Parameter passed to rte_eal_init(). * @param argv @@ -85,7 +85,7 @@ class DpdkDriver : public Driver { * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, int argc, char* argv[], + DpdkDriver(const char* ifname, int argc, char* argv[], const Config* const config = nullptr); /// Used to signal to the DpdkDriver constructor that the DPDK EAL should @@ -101,7 +101,7 @@ class DpdkDriver : public Driver { * called before calling this constructor. * * @param port - * Selects which physical port to use for communication. + * Selects which network interface to use for communication. * @param _ * Parameter is used only to define this constructors alternate * signature. @@ -110,29 +110,20 @@ class DpdkDriver : public Driver { * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, NoEalInit _, const Config* const config = nullptr); + DpdkDriver(const char* ifname, NoEalInit _, + const Config* const config = nullptr); /** * DpdkDriver Destructor. */ virtual ~DpdkDriver(); - /// See Driver::getAddress() - virtual Address getAddress(std::string const* const addressString); - virtual Address getAddress(WireFormatAddress const* const wireAddress); - - /// See Driver::addressToString() - virtual std::string addressToString(const Address address); - - /// See Driver::addressToWireFormat() - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); - /// See Driver::allocPacket() virtual Packet* allocPacket(); /// See Driver::sendPacket() - virtual void sendPacket(Packet* packet); + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority); /// See Driver::cork() virtual void cork(); @@ -157,7 +148,7 @@ class DpdkDriver : public Driver { virtual uint32_t getBandwidth(); /// See Driver::getLocalAddress() - virtual Driver::Address getLocalAddress(); + virtual IpAddress getLocalAddress(); /// See Driver::getQueuedBytes(); virtual uint32_t getQueuedBytes(); diff --git a/include/Homa/Util.h b/include/Homa/Util.h index 30a3548..a57a386 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -57,6 +57,8 @@ downCast(const Large& large) std::string demangle(const char* name); std::string hexDump(const void* buf, uint64_t bytes); +std::string ipToString(uint32_t ip); +uint32_t stringToIp(const char* ip); /** * This class is used to temporarily release lock in a safe fashion. Creating diff --git a/src/Drivers/DPDK/DpdkDriver.cc b/src/Drivers/DPDK/DpdkDriver.cc index c536159..1500c26 100644 --- a/src/Drivers/DPDK/DpdkDriver.cc +++ b/src/Drivers/DPDK/DpdkDriver.cc @@ -21,50 +21,21 @@ namespace Homa { namespace Drivers { namespace DPDK { -DpdkDriver::DpdkDriver(int port, const Config* const config) - : pImpl(new Impl(port, config)) +DpdkDriver::DpdkDriver(const char* ifname, const Config* const config) + : pImpl(new Impl(ifname, config)) {} -DpdkDriver::DpdkDriver(int port, int argc, char* argv[], +DpdkDriver::DpdkDriver(const char* ifname, int argc, char* argv[], const Config* const config) - : pImpl(new Impl(port, argc, argv, config)) + : pImpl(new Impl(ifname, argc, argv, config)) {} -DpdkDriver::DpdkDriver(int port, NoEalInit _, const Config* const config) - : pImpl(new Impl(port, _, config)) +DpdkDriver::DpdkDriver(const char* ifname, NoEalInit _, const Config* const config) + : pImpl(new Impl(ifname, _, config)) {} DpdkDriver::~DpdkDriver() = default; -/// See Driver::getAddress() -Driver::Address -DpdkDriver::getAddress(std::string const* const addressString) -{ - return pImpl->getAddress(addressString); -} - -/// See Driver::getAddress() -Driver::Address -DpdkDriver::getAddress(WireFormatAddress const* const wireAddress) -{ - return pImpl->getAddress(wireAddress); -} - -/// See Driver::addressToString() -std::string -DpdkDriver::addressToString(const Address address) -{ - return pImpl->addressToString(address); -} - -/// See Driver::addressToWireFormat() -void -DpdkDriver::addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) -{ - pImpl->addressToWireFormat(address, wireAddress); -} - /// See Driver::allocPacket() Driver::Packet* DpdkDriver::allocPacket() @@ -74,9 +45,9 @@ DpdkDriver::allocPacket() /// See Driver::sendPacket() void -DpdkDriver::sendPacket(Packet* packet) +DpdkDriver::sendPacket(Packet* packet, IpAddress destination, int priority) { - return pImpl->sendPacket(packet); + return pImpl->sendPacket(packet, destination, priority); } /// See Driver::cork() @@ -128,7 +99,7 @@ DpdkDriver::getBandwidth() } /// See Driver::getLocalAddress() -Driver::Address +IpAddress DpdkDriver::getLocalAddress() { return pImpl->getLocalAddress(); diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index e658ccb..ec4c58d 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -17,13 +17,19 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +#include +#include +#include +#include +#include + #include "DpdkDriverImpl.h" #include -#include #include "CodeLocation.h" #include "StringUtil.h" +#include "Homa/Util.h" namespace Homa { @@ -45,7 +51,7 @@ const char* default_eal_argv[] = {"homa", NULL}; * Memory location in the mbuf where the packet data should be stored. */ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) - : Driver::Packet(data, 0) + : base {.payload = data, .length = 0, .sourceIp = 0} , bufType(MBUF) , bufRef() { @@ -59,7 +65,7 @@ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) * Overflow buffer that holds this packet. */ DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) - : Driver::Packet(overflowBuf->data, 0) + : base {.payload = overflowBuf->data, .length = 0, .sourceIp = 0} , bufType(OVERFLOW_BUF) , bufRef() { @@ -69,17 +75,21 @@ DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, const Config* const config) - : Impl(port, default_eal_argc, const_cast(default_eal_argv), config) +DpdkDriver::Impl::Impl(const char* ifname, const Config* const config) + : Impl(ifname, default_eal_argc, const_cast(default_eal_argv), + config) {} /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, int argc, char* argv[], +DpdkDriver::Impl::Impl(const char* ifname, int argc, char* argv[], const Config* const config) - : port(port) - , localMac(Driver::Address(0)) + : ifname(ifname) + , port() + , arpTable() + , localIp() + , localMac("00:00:00:00:00:00") , HIGHEST_PACKET_PRIORITY( (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 @@ -124,10 +134,14 @@ DpdkDriver::Impl::Impl(int port, int argc, char* argv[], /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, __attribute__((__unused__)) NoEalInit _, +DpdkDriver::Impl::Impl(const char* ifname, + __attribute__((__unused__)) NoEalInit _, const Config* const config) - : port(port) - , localMac(Driver::Address(0)) + : ifname(ifname) + , port() + , arpTable() + , localIp() + , localMac("00:00:00:00:00:00") , HIGHEST_PACKET_PRIORITY( (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 @@ -159,37 +173,8 @@ DpdkDriver::Impl::~Impl() rte_mempool_free(mbufPool); } -// See Driver::getAddress() -Driver::Address -DpdkDriver::Impl::getAddress(std::string const* const addressString) -{ - return MacAddress(addressString->c_str()).toAddress(); -} - -// See Driver::getAddress() -Driver::Address -DpdkDriver::Impl::getAddress(Driver::WireFormatAddress const* const wireAddress) -{ - return MacAddress(wireAddress).toAddress(); -} - -/// See Driver::addressToString() -std::string -DpdkDriver::Impl::addressToString(const Driver::Address address) -{ - return MacAddress(address).toString(); -} - -/// See Driver::addressToWireFormat() -void -DpdkDriver::Impl::addressToWireFormat(const Driver::Address address, - Driver::WireFormatAddress* wireAddress) -{ - MacAddress(address).toWireFormat(wireAddress); -} - // See Driver::allocPacket() -DpdkDriver::Impl::Packet* +Driver::Packet* DpdkDriver::Impl::allocPacket() { DpdkDriver::Impl::Packet* packet = _allocMbufPacket(); @@ -199,15 +184,17 @@ DpdkDriver::Impl::allocPacket() packet = packetPool.construct(buf); NOTICE("OverflowBuffer used."); } - return packet; + return &packet->base; } // See Driver::sendPacket() void -DpdkDriver::Impl::sendPacket(Driver::Packet* packet) +DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, + int priority) { + ; DpdkDriver::Impl::Packet* pkt = - static_cast(packet); + container_of(packet, DpdkDriver::Impl::Packet, base); struct rte_mbuf* mbuf = nullptr; // If the packet is held in an Overflow buffer, we need to copy it out // into a new mbuf. @@ -223,15 +210,15 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) numMbufsAvail, numMbufsInUse); return; } - char* buf = rte_pktmbuf_append( - mbuf, Homa::Util::downCast(PACKET_HDR_LEN + pkt->length)); + char* buf = rte_pktmbuf_append(mbuf, + Homa::Util::downCast(PACKET_HDR_LEN + pkt->base.length)); if (unlikely(NULL == buf)) { WARNING("rte_pktmbuf_append call failed; dropping packet"); rte_pktmbuf_free(mbuf); return; } char* data = buf + PACKET_HDR_LEN; - rte_memcpy(data, pkt->payload, pkt->length); + rte_memcpy(data, pkt->base.payload, pkt->base.length); } else { mbuf = pkt->bufRef.mbuf; @@ -246,9 +233,14 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) // Fill out the destination and source MAC addresses plus the Ethernet // frame type (i.e., IEEE 802.1Q VLAN tagging). - MacAddress macAddr(pkt->address); + auto it = arpTable.find(destination); + if (it == arpTable.end()) { + WARNING("Failed to find ARP record for packet; dropping packet"); + return; + } + MacAddress& destMac = it->second; struct ether_hdr* ethHdr = rte_pktmbuf_mtod(mbuf, struct ether_hdr*); - rte_memcpy(ðHdr->d_addr, macAddr.address, ETHER_ADDR_LEN); + rte_memcpy(ðHdr->d_addr, destMac.address, ETHER_ADDR_LEN); rte_memcpy(ðHdr->s_addr, localMac.address, ETHER_ADDR_LEN); ethHdr->ether_type = rte_cpu_to_be_16(ETHER_TYPE_VLAN); @@ -256,13 +248,16 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) // encapsulated frame (DEI and VLAN ID are not relevant and trivially // set to 0). struct vlan_hdr* vlanHdr = reinterpret_cast(ethHdr + 1); - vlanHdr->vlan_tci = rte_cpu_to_be_16(PRIORITY_TO_PCP[pkt->priority]); + vlanHdr->vlan_tci = rte_cpu_to_be_16(PRIORITY_TO_PCP[priority]); vlanHdr->eth_proto = rte_cpu_to_be_16(EthPayloadType::HOMA); + // Store our local IP address right before the payload. + *rte_pktmbuf_mtod_offset(mbuf, uint32_t*, PACKET_HDR_LEN - 4) = localIp; + // In the normal case, we pre-allocate a pakcet's mbuf with enough // storage to hold the MAX_PAYLOAD_SIZE. If the actual payload is // smaller, trim the mbuf to size to avoid sending unecessary bits. - uint32_t actualLength = PACKET_HDR_LEN + pkt->length; + uint32_t actualLength = PACKET_HDR_LEN + pkt->base.length; uint32_t mbufDataLength = rte_pktmbuf_pkt_len(mbuf); if (actualLength < mbufDataLength) { if (rte_pktmbuf_trim(mbuf, mbufDataLength - actualLength) < 0) { @@ -274,7 +269,7 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) } // loopback if src mac == dst mac - if (localMac.toAddress() == pkt->address) { + if (localMac == destMac) { struct rte_mbuf* mbuf_clone = rte_pktmbuf_clone(mbuf, mbufPool); if (unlikely(mbuf_clone == NULL)) { WARNING("Failed to clone packet for loopback; dropping packet"); @@ -390,6 +385,9 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, } } + uint32_t srcIp = *rte_pktmbuf_mtod_offset(m, uint32_t*, headerLength); + headerLength += sizeof(srcIp); + payload += sizeof(srcIp); assert(rte_pktmbuf_pkt_len(m) >= headerLength); uint32_t length = rte_pktmbuf_pkt_len(m) - headerLength; assert(length <= MAX_PAYLOAD_SIZE); @@ -399,10 +397,10 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, SpinLock::Lock lock(packetLock); packet = packetPool.construct(m, payload); } - packet->address = MacAddress(ethHdr->s_addr.addr_bytes).toAddress(); - packet->length = length; + packet->base.length = length; + packet->base.sourceIp = srcIp; - receivedPackets[numPacketsReceived++] = packet; + receivedPackets[numPacketsReceived++] = &packet->base; } return numPacketsReceived; @@ -415,7 +413,7 @@ DpdkDriver::Impl::releasePackets(Driver::Packet* packets[], uint16_t numPackets) for (uint16_t i = 0; i < numPackets; ++i) { SpinLock::Lock lock(packetLock); DpdkDriver::Impl::Packet* packet = - static_cast(packets[i]); + container_of(packets[i], DpdkDriver::Impl::Packet, base); if (likely(packet->bufType == DpdkDriver::Impl::Packet::MBUF)) { rte_pktmbuf_free(packet->bufRef.mbuf); } else { @@ -447,10 +445,10 @@ DpdkDriver::Impl::getBandwidth() } // See Driver::getLocalAddress() -Driver::Address +IpAddress DpdkDriver::Impl::getLocalAddress() { - return localMac.toAddress(); + return localIp; } // See Driver::getQueuedBytes(); @@ -490,11 +488,71 @@ DpdkDriver::Impl::_eal_init(int argc, char* argv[]) void DpdkDriver::Impl::_init() { - struct ether_addr mac; struct rte_eth_conf portConf; int ret; uint16_t mtu; + // Populate the ARP table with records in /proc/net/arp (inspired by + // net-tools/arp.c) + std::ifstream input("/proc/net/arp"); + for (std::string line; getline(input, line);) { + char ip[100]; + char hwa[100]; + char mask[100]; + char dev[100]; + int type, flags; + int cols = sscanf(line.c_str(), "%s 0x%x 0x%x %99s %99s %99s\n", + ip, &type, &flags, hwa, mask, dev); + if (cols != 6) continue; + arpTable.emplace(Homa::Util::stringToIp(ip), hwa); + } + + // Use ioctl to obtain the IP and MAC addresses of the network interface. + struct ifreq ifr; + ifname.copy(ifr.ifr_name, ifname.length()); + ifr.ifr_name[ifname.length() + 1] = 0; + if (ifname.length() >= sizeof(ifr.ifr_name)) { + throw DriverInitFailure(HERE_STR, + StringUtil::format("Interface name %s too long", ifname.c_str())); + } + + int fd = socket(AF_INET, SOCK_DGRAM, 0); + if (fd == -1) { + throw DriverInitFailure(HERE_STR, + StringUtil::format("Failed to create socket: %s", strerror(errno))); + } + + if (ioctl(fd, SIOCGIFADDR, &ifr) == -1) { + char* error = strerror(errno); + close(fd); + throw DriverInitFailure(HERE_STR, + StringUtil::format("Failed to obtain IP address: %s", error)); + } + localIp = be32toh(((struct sockaddr_in*) &ifr.ifr_addr)->sin_addr.s_addr); + + if (ioctl(fd, SIOCGIFHWADDR, &ifr) == -1) { + char* error = strerror(errno); + close(fd); + throw DriverInitFailure(HERE_STR, + StringUtil::format("Failed to obtain MAC address: %s", error)); + } + close(fd); + memcpy(localMac.address, ifr.ifr_hwaddr.sa_data, 6); + + // Iterate over ethernet devices to locate the port identifier. + int p; + RTE_ETH_FOREACH_DEV(p) { + struct ether_addr mac; + rte_eth_macaddr_get(p, &mac); + if (MacAddress(mac.addr_bytes) == localMac) { + port = p; + break; + } + } + NOTICE("Using interface %s, ip %s, mac %s, port %u", + ifname.c_str(), Homa::Util::ipToString(localIp).c_str(), + localMac.toString().c_str(), port); + std::string poolName = StringUtil::format("homa_mbuf_pool_%u", port); std::string ringName = StringUtil::format("homa_loopback_ring_%u", port); @@ -518,10 +576,6 @@ DpdkDriver::Impl::_init() StringUtil::format("Ethernet port %u doesn't exist", port)); } - // Read the MAC address from the NIC via DPDK. - rte_eth_macaddr_get(port, &mac); - new (const_cast(&localMac)) MacAddress(mac.addr_bytes); - // configure some default NIC port parameters memset(&portConf, 0, sizeof(portConf)); portConf.rxmode.max_rx_pkt_len = ETHER_MAX_VLAN_FRAME_LEN; diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 9b77383..4ed3406 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -28,6 +28,7 @@ #include #include +#include #include "MacAddress.h" #include "ObjectPool.h" @@ -65,8 +66,13 @@ const uint16_t MAX_PKT_BURST = 32; /// field defined in the VLAN tag to specify the packet priority. const uint32_t VLAN_TAG_LEN = 4; -// Size of Ethernet header including VLAN tag, in bytes. -const uint32_t PACKET_HDR_LEN = ETHER_HDR_LEN + VLAN_TAG_LEN; +/// Strictly speaking, this DPDK driver is supposed to send/receive IP packets; +/// however, it currently only records the source IP address right after the +/// Ethernet header for simplicity. +const uint32_t IP_HDR_LEN = sizeof(IpAddress); + +// Size of Ethernet header including VLAN tag plus IP header, in bytes. +const uint32_t PACKET_HDR_LEN = ETHER_HDR_LEN + VLAN_TAG_LEN + IP_HDR_LEN; // The MTU (Maximum Transmission Unit) size of an Ethernet frame, which is the // maximum size of the packet an Ethernet frame can carry in its payload. This @@ -104,11 +110,13 @@ class DpdkDriver::Impl { * Dpdk specific Packet object used to track a its lifetime and * contents. */ - class Packet : public Driver::Packet { - public: + struct Packet { explicit Packet(struct rte_mbuf* mbuf, void* data); explicit Packet(OverflowBuffer* overflowBuf); + /// C-style "inheritance" + Driver::Packet base; + /// Used to indicate whether the packet is backed by an DPDK mbuf or a /// driver-level OverflowBuffer. enum BufferType { MBUF, OVERFLOW_BUF } bufType; ///< Packet BufferType. @@ -122,26 +130,18 @@ class DpdkDriver::Impl { /// The memory location of this packet's header. The header should be /// PACKET_HDR_LEN in length. void* header; - - private: - Packet(const Packet&) = delete; - Packet& operator=(const Packet&) = delete; }; - Impl(int port, const Config* const config = nullptr); - Impl(int port, int argc, char* argv[], + Impl(const char* ifname, const Config* const config = nullptr); + Impl(const char* ifname, int argc, char* argv[], const Config* const config = nullptr); - Impl(int port, NoEalInit _, const Config* const config = nullptr); + Impl(const char* ifname, NoEalInit _, const Config* const config = nullptr); virtual ~Impl(); // Interface Methods - Driver::Address getAddress(std::string const* const addressString); - Driver::Address getAddress(WireFormatAddress const* const wireAddress); - std::string addressToString(const Address address); - void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); - Packet* allocPacket(); - void sendPacket(Driver::Packet* packet); + Driver::Packet* allocPacket(); + void sendPacket(Driver::Packet* packet, IpAddress destination, + int priority); void cork(); void uncork(); uint32_t receivePackets(uint32_t maxPackets, @@ -150,7 +150,7 @@ class DpdkDriver::Impl { int getHighestPacketPriority(); uint32_t getMaxPayloadSize(); uint32_t getBandwidth(); - Driver::Address getLocalAddress(); + IpAddress getLocalAddress(); uint32_t getQueuedBytes(); private: @@ -163,12 +163,21 @@ class DpdkDriver::Impl { static void txBurstErrorCallback(struct rte_mbuf* pkts[], uint16_t unsent, void* userdata); + /// Name of the Linux network interface to be used by DPDK. + std::string ifname; + /// Stores the NIC's physical port id addressed by the instantiated /// driver. - const uint16_t port; + uint16_t port; + + /// Address resolution table that translates IP addresses to MAC addresses. + std::unordered_map arpTable; + + /// Stores the IpAddress of the driver. + IpAddress localIp; - /// Stores the address of the NIC (either native or set by override). - const MacAddress localMac; + /// Stores the HW address of the NIC (either native or set by override). + MacAddress localMac; /// Stores the driver's maximum network packet priority (either default or /// set by override). diff --git a/src/Drivers/DPDK/MacAddress.cc b/src/Drivers/DPDK/MacAddress.cc index 0178851..63149fa 100644 --- a/src/Drivers/DPDK/MacAddress.cc +++ b/src/Drivers/DPDK/MacAddress.cc @@ -18,7 +18,6 @@ #include "StringUtil.h" #include "../../CodeLocation.h" -#include "../RawAddressType.h" namespace Homa { namespace Drivers { @@ -55,33 +54,6 @@ MacAddress::MacAddress(const char* macStr) address[i] = Util::downCast(bytes[i]); } -/** - * Create a new address from a given address in its raw byte format. - * @param raw - * The raw bytes format. - * - * @sa Driver::Address::Raw - */ -MacAddress::MacAddress(const Driver::WireFormatAddress* const wireAddress) -{ - if (wireAddress->type != RawAddressType::MAC) { - throw BadAddress(HERE_STR, "Bad address: Raw format is not type MAC"); - } - static_assert(sizeof(wireAddress->bytes) >= 6); - memcpy(address, wireAddress->bytes, 6); -} - -/** - * Create a new address given the Driver::Address representation. - * - * @param addr - * The Driver::Address representation of an address. - */ -MacAddress::MacAddress(const Driver::Address addr) -{ - memcpy(address, &addr, 6); -} - /** * Return the string representation of this address. */ @@ -94,31 +66,6 @@ MacAddress::toString() const return buf; } -/** - * Serialized this address into a wire format. - * - * @param[out] wireAddress - * WireFormatAddress object to which the this address is serialized. - */ -void -MacAddress::toWireFormat(Driver::WireFormatAddress* wireAddress) const -{ - static_assert(sizeof(wireAddress->bytes) >= 6); - memcpy(wireAddress->bytes, address, 6); - wireAddress->type = RawAddressType::MAC; -} - -/** - * Return a Driver::Address representation of this address. - */ -Driver::Address -MacAddress::toAddress() const -{ - Driver::Address addr = 0; - memcpy(&addr, address, 6); - return addr; -} - /** * @return * True if the MacAddress consists of all zero bytes, false if not. diff --git a/src/Drivers/DPDK/MacAddress.h b/src/Drivers/DPDK/MacAddress.h index 1106eec..148f2ce 100644 --- a/src/Drivers/DPDK/MacAddress.h +++ b/src/Drivers/DPDK/MacAddress.h @@ -28,14 +28,19 @@ namespace DPDK { struct MacAddress { explicit MacAddress(const uint8_t raw[6]); explicit MacAddress(const char* macStr); - explicit MacAddress(const Driver::WireFormatAddress* const wireAddress); - explicit MacAddress(const Driver::Address addr); MacAddress(const MacAddress&) = default; std::string toString() const; - void toWireFormat(Driver::WireFormatAddress* wireAddress) const; - Driver::Address toAddress() const; bool isNull() const; + /** + * Equality function for MacAddress, for use in std::unordered_maps etc. + */ + bool operator==(const MacAddress& other) const + { + return (*(uint32_t*)(address + 0) == *(uint32_t*)(other.address + 0)) && + (*(uint16_t*)(address + 4) == *(uint16_t*)(other.address + 4)); + } + /// The raw bytes of the MAC address. uint8_t address[6]; }; diff --git a/src/Drivers/DPDK/MacAddressTest.cc b/src/Drivers/DPDK/MacAddressTest.cc index 329c309..7587a16 100644 --- a/src/Drivers/DPDK/MacAddressTest.cc +++ b/src/Drivers/DPDK/MacAddressTest.cc @@ -15,8 +15,6 @@ #include "MacAddress.h" -#include "../RawAddressType.h" - #include namespace Homa { @@ -35,26 +33,6 @@ TEST(MacAddressTest, constructorString) EXPECT_EQ("de:ad:be:ef:98:76", MacAddress("de:ad:be:ef:98:76").toString()); } -TEST(MacAddressTest, constructorWireFormatAddress) -{ - uint8_t bytes[] = {0xde, 0xad, 0xbe, 0xef, 0x98, 0x76}; - Driver::WireFormatAddress wireformatAddress; - wireformatAddress.type = RawAddressType::MAC; - memcpy(wireformatAddress.bytes, bytes, 6); - EXPECT_EQ("de:ad:be:ef:98:76", MacAddress(&wireformatAddress).toString()); - - wireformatAddress.type = RawAddressType::FAKE; - EXPECT_THROW(MacAddress address(&wireformatAddress), BadAddress); -} - -TEST(MacAddressTest, constructorAddress) -{ - uint8_t raw[] = {0xde, 0xad, 0xbe, 0xef, 0x98, 0x76}; - MacAddress(raw).toString(); - Driver::Address addr = MacAddress("de:ad:be:ef:98:76").toAddress(); - EXPECT_EQ("de:ad:be:ef:98:76", MacAddress(addr).toString()); -} - TEST(MacAddressTest, construct_DefaultCopy) { MacAddress source("de:ad:be:ef:98:76"); @@ -67,24 +45,6 @@ TEST(MacAddressTest, toString) // tested sufficiently in constructor tests } -TEST(MacAddressTest, toWireFormat) -{ - Driver::WireFormatAddress wireformatAddress; - MacAddress("de:ad:be:ef:98:76").toWireFormat(&wireformatAddress); - EXPECT_EQ(RawAddressType::MAC, wireformatAddress.type); - EXPECT_EQ(0xde, wireformatAddress.bytes[0]); - EXPECT_EQ(0xad, wireformatAddress.bytes[1]); - EXPECT_EQ(0xbe, wireformatAddress.bytes[2]); - EXPECT_EQ(0xef, wireformatAddress.bytes[3]); - EXPECT_EQ(0x98, wireformatAddress.bytes[4]); - EXPECT_EQ(0x76, wireformatAddress.bytes[5]); -} - -TEST(MacAddressTest, toAddress) -{ - // Tested in constructorAddress -} - TEST(MacAddressTest, isNull) { uint8_t rawNull[] = {0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; diff --git a/src/Drivers/RawAddressType.h b/src/Drivers/RawAddressType.h deleted file mode 100644 index 1def76d..0000000 --- a/src/Drivers/RawAddressType.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2019, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#ifndef HOMA_DRIVERS_RAWADDRESSTYPE_H -#define HOMA_DRIVERS_RAWADDRESSTYPE_H - -namespace Homa { -namespace Drivers { - -/** - * Identifies a particular raw serialized byte-format for a Driver::Address - * supported by this project. The types are enumerated here in one place to - * ensure drivers do have overlapping type identifiers. New drivers that wish - * to claim a type id should add an entry to this enum. - * - * @sa Driver::Address::Raw - */ -enum RawAddressType { - FAKE = 0, - MAC = 1, -}; - -} // namespace Drivers -} // namespace Homa - -#endif // HOMA_DRIVERS_RAWADDRESSTYPE_H diff --git a/src/Util.cc b/src/Util.cc index 90ee9f4..fe73752 100644 --- a/src/Util.cc +++ b/src/Util.cc @@ -100,5 +100,21 @@ hexDump(const void* buf, uint64_t bytes) return output.str(); } +std::string +ipToString(uint32_t ip) +{ + return StringUtil::format("%d.%d.%d.%d", + (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); +} + +uint32_t +stringToIp(const char* ipStr) +{ + unsigned int bytes[4]; + sscanf(ipStr, "%u.%u.%u.%u", &bytes[0], &bytes[1], &bytes[2], &bytes[3]); + return (bytes[0] << 24) | (bytes[1] << 16) | (bytes[2] << 8) | bytes[3]; +} + + } // namespace Util } // namespace Homa diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 403a340..e01e945 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -34,6 +34,18 @@ target_link_libraries(system_test docopt ) +## dpdk_test ################################################################# + +add_executable(dpdk_test + dpdk_test.cc +) +target_link_libraries(dpdk_test + PRIVATE + Homa::DpdkDriver + docopt + PerfUtils +) + ## Perf ######################################################################## add_executable(Perf diff --git a/test/Output.h b/test/Output.h new file mode 100644 index 0000000..bea8f8b --- /dev/null +++ b/test/Output.h @@ -0,0 +1,105 @@ +#pragma once + +#include +#include +#include +#include + +namespace Output { + +using Latency = std::chrono::duration; + +struct TimeDist { + Latency min; // Fastest time seen (seconds). + Latency p50; // Median time per operation (seconds). + Latency p90; // 90th percentile time/op (seconds). + Latency p99; // 99th percentile time/op (seconds). + Latency p999; // 99.9th percentile time/op (seconds). +}; + +std::string +format(const std::string& format, ...) +{ + va_list args; + va_start(args, format); + size_t len = std::vsnprintf(NULL, 0, format.c_str(), args); + va_end(args); + std::vector vec(len + 1); + va_start(args, format); + std::vsnprintf(&vec[0], len + 1, format.c_str(), args); + va_end(args); + return &vec[0]; +} + +std::string +formatTime(Latency seconds) +{ + if (seconds < std::chrono::duration(1)) { + return format( + "%5.1f ns", + std::chrono::duration(seconds).count()); + } else if (seconds < std::chrono::duration(1)) { + return format( + "%5.1f us", + std::chrono::duration(seconds).count()); + } else if (seconds < std::chrono::duration(1)) { + return format( + "%5.2f ms", + std::chrono::duration(seconds).count()); + } else { + return format("%5.2f s ", seconds.count()); + } +} + +std::string +basicHeader() +{ + return "median min p90 p99 p999 description"; +} + +std::string +basic(std::vector& times, const std::string description) +{ + int count = times.size(); + std::sort(times.begin(), times.end()); + + TimeDist dist; + + dist.min = times[0]; + int index = count / 2; + if (index < count) { + dist.p50 = times.at(index); + } else { + dist.p50 = dist.min; + } + index = count - (count + 5) / 10; + if (index < count) { + dist.p90 = times.at(index); + } else { + dist.p90 = dist.p50; + } + index = count - (count + 50) / 100; + if (index < count) { + dist.p99 = times.at(index); + } else { + dist.p99 = dist.p90; + } + index = count - (count + 500) / 1000; + if (index < count) { + dist.p999 = times.at(index); + } else { + dist.p999 = dist.p99; + } + + std::string output = ""; + output += format("%9s", formatTime(dist.p50).c_str()); + output += format(" %9s", formatTime(dist.min).c_str()); + output += format(" %9s", formatTime(dist.p90).c_str()); + output += format(" %9s", formatTime(dist.p99).c_str()); + output += format(" %9s", formatTime(dist.p999).c_str()); + output += " "; + output += description; + return output; +} + +} // namespace Output diff --git a/test/dpdk_test.cc b/test/dpdk_test.cc new file mode 100644 index 0000000..ebf9ba2 --- /dev/null +++ b/test/dpdk_test.cc @@ -0,0 +1,88 @@ +#include +#include + +#include +#include +#include +#include + +#include "Output.h" + +static const char USAGE[] = R"(DPDK Driver Test. + + Usage: + dpdk_test [options] (--server | ) + + Options: + -h --help Show this screen. + --version Show version. + --timetrace Enable TimeTrace output [default: false]. +)"; + +int +main(int argc, char* argv[]) +{ + std::map args = + docopt::docopt(USAGE, {argv + 1, argv + argc}, + true, // show help if requested + "DPDK Driver Test"); // version string + + std::string iface = args[""].asString(); + bool isServer = args["--server"].asBool(); + std::string server_ip_string; + if (!isServer) { + server_ip_string = args[""].asString(); + } + + Homa::Drivers::DPDK::DpdkDriver driver(iface.c_str()); + + if (isServer) { + std::cout << Homa::Util::ipToString(driver.getLocalAddress()) + << std::endl; + while (true) { + Homa::Driver::Packet* incoming[10]; + uint32_t receivedPackets; + do { + receivedPackets = driver.receivePackets(10, incoming); + } while (receivedPackets == 0); + Homa::Driver::Packet* pong = driver.allocPacket(); + pong->length = 100; + driver.sendPacket(pong, incoming[0]->sourceIp, 0); + driver.releasePackets(incoming, receivedPackets); + driver.releasePackets(&pong, 1); + } + } else { + Homa::IpAddress server_ip = + Homa::Util::stringToIp(server_ip_string.c_str()); + std::vector times; + for (int i = 0; i < 100000; ++i) { + uint64_t start = PerfUtils::Cycles::rdtsc(); + PerfUtils::TimeTrace::record(start, "START"); + Homa::Driver::Packet* ping = driver.allocPacket(); + PerfUtils::TimeTrace::record("allocPacket"); + ping->length = 100; + PerfUtils::TimeTrace::record("set ping args"); + driver.sendPacket(ping, server_ip, 0); + PerfUtils::TimeTrace::record("sendPacket"); + driver.releasePackets(&ping, 1); + PerfUtils::TimeTrace::record("releasePacket"); + Homa::Driver::Packet* incoming[10]; + uint32_t receivedPackets; + do { + receivedPackets = driver.receivePackets(10, incoming); + PerfUtils::TimeTrace::record("receivePackets"); + } while (receivedPackets == 0); + driver.releasePackets(incoming, receivedPackets); + PerfUtils::TimeTrace::record("releasePacket"); + uint64_t stop = PerfUtils::Cycles::rdtsc(); + times.emplace_back(PerfUtils::Cycles::toSeconds(stop - start)); + } + if (args["--timetrace"].asBool()) { + PerfUtils::TimeTrace::print(); + } + std::cout << Output::basicHeader() << std::endl; + std::cout << Output::basic(times, "DpdkDriver Ping-Pong") << std::endl; + } + + return 0; +} \ No newline at end of file From 23e6c28cd70570a6d7bb3f37c67b8d02ebd22f36 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Thu, 6 Aug 2020 23:04:20 -0700 Subject: [PATCH 04/15] Improvements based on code review discussions with Collin - Change IpAddress from typedef to a POD type to provide some type safety - Add a third argument to Driver::receivePackets() to hold source addresses of ingress packets when the method returns - Eliminate Driver::Packet (use Homa::PacketSpec instead) - Move L4 header fields sport/dport into header prefix --- CMakeLists.txt | 3 +- include/Homa/Driver.h | 89 ++++++++++-------- include/Homa/Drivers/DPDK/DpdkDriver.h | 3 +- include/Homa/Drivers/Fake/FakeDriver.h | 16 ++-- include/Homa/Util.h | 11 +-- src/Driver.cc | 38 ++++++++ src/Drivers/DPDK/DpdkDriver.cc | 5 +- src/Drivers/DPDK/DpdkDriverImpl.cc | 23 +++-- src/Drivers/DPDK/DpdkDriverImpl.h | 5 +- src/Drivers/Fake/FakeDriver.cc | 28 +++--- src/Drivers/Fake/FakeDriverTest.cc | 23 ++--- src/Mock/MockDriver.h | 3 +- src/Policy.h | 3 +- src/PolicyTest.cc | 2 +- src/Protocol.h | 24 ++--- src/Receiver.cc | 4 +- src/Receiver.h | 2 +- src/ReceiverTest.cc | 123 +++++++++++++------------ src/SenderTest.cc | 4 +- src/TransportImpl.cc | 6 +- src/Util.cc | 16 ---- test/Output.h | 15 +++ test/dpdk_test.cc | 27 +++++- 23 files changed, 283 insertions(+), 190 deletions(-) create mode 100644 src/Driver.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a6f9c1..f5cb6ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required (VERSION 3.11) -project(Homa VERSION 0.1.1.0 LANGUAGES CXX) +project(Homa VERSION 0.1.2.0 LANGUAGES CXX) ################################################################################ ## Dependency Configuration #################################################### @@ -74,6 +74,7 @@ endif() add_library(Homa src/CodeLocation.cc src/Debug.cc + src/Driver.cc src/Homa.cc src/Perf.cc src/Policy.cc diff --git a/include/Homa/Driver.h b/include/Homa/Driver.h index d510046..a5cc855 100644 --- a/include/Homa/Driver.h +++ b/include/Homa/Driver.h @@ -22,8 +22,49 @@ namespace Homa { -/// IPv4 address in host byte order. -using IpAddress = uint32_t; +/** + * A simple wrapper struct around an IP address in binary format. + * + * This struct is meant to provide some type-safety when manipulating IP + * addresses. In order to avoid any runtime overhead, this struct contains + * nothing more than the IP address, so it is trivially copyable. + */ +struct IpAddress final { + /// IPv4 address in host byte order. + uint32_t addr; + + /** + * Unbox the IP address in binary format. + */ + explicit operator uint32_t() + { + return addr; + } + + /** + * Equality function for IpAddress, for use in std::unordered_maps etc. + */ + bool operator==(const IpAddress& other) const + { + return addr == other.addr; + } + + /** + * This class computes a hash of an IpAddress, so that IpAddress can be used + * as keys in unordered_maps. + */ + struct Hasher { + /// Return a "hash" of the given IpAddress. + std::size_t operator()(const IpAddress& address) const + { + return std::hash{}(address.addr); + } + }; + + static std::string toString(IpAddress address); + static IpAddress fromString(const char* addressStr); +}; +static_assert(std::is_trivially_copyable()); /** * Represents a packet of data that can be send or is received over the network. @@ -43,6 +84,7 @@ struct PacketSpec { /// Number of bytes in the payload. int32_t length; } __attribute__((packed)); +static_assert(std::is_trivial()); /** * Used by Homa::Transport to send and receive unreliable datagrams. Provides @@ -52,41 +94,8 @@ struct PacketSpec { */ class Driver { public: - /** - * Represents a packet that can be send or is received over the network. - * - * The layout of this struct has two parts: the first part is essentially - * a copy of PacketSpec, while the second part contains members specific - * to our driver implementation. - * - * @sa Homa::PacketSpec - */ - struct Packet final { - // === PacketSpec definitions === - // The order and types of the following members must match those in - // PacketSpec precisely. - - /// See Homa::PacketSpec::payload. - void* payload; - - /// See Homa::PacketSpec::length - int32_t length; - - // === Extended definitions === - // The following members are specific to the driver framework bundled - // in this library. Therefore, these members must *NOT* appear in the - // core components of Homa transport; they are only used in a few - // places to facilitate the glue code between transport and driver. - - /// Packet's source IpAddress. Only meaningful when this packet is an - /// incoming packet. - IpAddress sourceIp; - } __attribute__((packed)); - - // Static checks to enforce the object layout compatibility between - // Driver::Packet and PacketSpec. - static_assert(offsetof(Packet, payload) == offsetof(PacketSpec, payload)); - static_assert(offsetof(Packet, length) == offsetof(PacketSpec, length)); + /// Import PacketSpec into the Driver namespace. + using Packet = PacketSpec; /** * Driver destructor. @@ -164,6 +173,9 @@ class Driver { * this method. * @param[out] receivedPackets * Received packets are appended to this array in order of arrival. + * @param[out] sourceAddresses + * Source IP addresses of the received packets are appended to this + * array in order of arrival. * * @return * Number of Packet objects being returned. @@ -171,7 +183,8 @@ class Driver { * @sa Driver::releasePackets() */ virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]) = 0; + Packet* receivedPackets[], + IpAddress sourceAddresses[]) = 0; /** * Release a collection of Packet objects back to the Driver. Every diff --git a/include/Homa/Drivers/DPDK/DpdkDriver.h b/include/Homa/Drivers/DPDK/DpdkDriver.h index 010d59b..f15d575 100644 --- a/include/Homa/Drivers/DPDK/DpdkDriver.h +++ b/include/Homa/Drivers/DPDK/DpdkDriver.h @@ -133,7 +133,8 @@ class DpdkDriver : public Driver { /// See Driver::receivePackets() virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]); + Packet* receivedPackets[], + IpAddress sourceAddresses[]); /// See Driver::releasePackets() virtual void releasePackets(Packet* packets[], uint16_t numPackets); diff --git a/include/Homa/Drivers/Fake/FakeDriver.h b/include/Homa/Drivers/Fake/FakeDriver.h index 04ce8c0..dd01261 100644 --- a/include/Homa/Drivers/Fake/FakeDriver.h +++ b/include/Homa/Drivers/Fake/FakeDriver.h @@ -58,14 +58,17 @@ struct FakePacket { /// Raw storage for this packets payload. char buf[MAX_PAYLOAD_SIZE]; + /// Source IpAddress of the packet. + IpAddress sourceIp; + /** * FakePacket constructor. */ explicit FakePacket() : base{.payload = buf, - .length = 0, - .sourceIp = 0} + .length = 0} , buf() + , sourceIp() {} /** @@ -73,9 +76,9 @@ struct FakePacket { */ FakePacket(const FakePacket& other) : base{.payload = buf, - .length = other.base.length, - .sourceIp = 0} + .length = other.base.length} , buf() + , sourceIp() { memcpy(base.payload, other.base.payload, MAX_PAYLOAD_SIZE); } @@ -111,7 +114,8 @@ class FakeDriver : public Driver { virtual Packet* allocPacket(); virtual void sendPacket(Packet* packet, IpAddress destination, int priority); virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]); + Packet* receivedPackets[], + IpAddress sourceAddresses[]); virtual void releasePackets(Packet* packets[], uint16_t numPackets); virtual int getHighestPacketPriority(); virtual uint32_t getMaxPayloadSize(); @@ -121,7 +125,7 @@ class FakeDriver : public Driver { private: /// Identifier for this driver on the fake network. - uint64_t localAddressId; + uint32_t localAddressId; /// Holds the incoming packets for this driver. FakeNIC nic; diff --git a/include/Homa/Util.h b/include/Homa/Util.h index a57a386..ba757e6 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -22,10 +22,11 @@ #include /// Cast a member of a structure out to the containing structure. -#define container_of(ptr, type, member) ({ \ - const typeof( ((type *)0)->member ) \ - *__mptr = (ptr); \ - (type *)( (char *)__mptr - offsetof(type,member) );}) +template +P* container_of(M* ptr, const M P::*member) +{ + return (P*)((char*) ptr - (size_t) &(reinterpret_cast(0)->*member)); +} namespace Homa { namespace Util { @@ -57,8 +58,6 @@ downCast(const Large& large) std::string demangle(const char* name); std::string hexDump(const void* buf, uint64_t bytes); -std::string ipToString(uint32_t ip); -uint32_t stringToIp(const char* ip); /** * This class is used to temporarily release lock in a safe fashion. Creating diff --git a/src/Driver.cc b/src/Driver.cc new file mode 100644 index 0000000..c7d61cb --- /dev/null +++ b/src/Driver.cc @@ -0,0 +1,38 @@ +/* Copyright (c) 2018-2019, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include + +#include "StringUtil.h" + +namespace Homa { + +std::string +IpAddress::toString(IpAddress address) +{ + uint32_t ip = address.addr; + return StringUtil::format("%d.%d.%d.%d", + (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); +} + +IpAddress +IpAddress::fromString(const char* addressStr) +{ + unsigned int b0, b1, b2, b3; + sscanf(addressStr, "%u.%u.%u.%u", &b0, &b1, &b2, &b3); + return IpAddress{(b0 << 24u) | (b1 << 16u) | (b2 << 8u) | b3}; +} + +} // namespace Homa diff --git a/src/Drivers/DPDK/DpdkDriver.cc b/src/Drivers/DPDK/DpdkDriver.cc index 1500c26..3c8833a 100644 --- a/src/Drivers/DPDK/DpdkDriver.cc +++ b/src/Drivers/DPDK/DpdkDriver.cc @@ -66,9 +66,10 @@ DpdkDriver::uncork() /// See Driver::receivePackets() uint32_t -DpdkDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) +DpdkDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[], + IpAddress sourceAddresses[]) { - return pImpl->receivePackets(maxPackets, receivedPackets); + return pImpl->receivePackets(maxPackets, receivedPackets, sourceAddresses); } /// See Driver::releasePackets() void diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index ec4c58d..42a2340 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -51,7 +51,8 @@ const char* default_eal_argv[] = {"homa", NULL}; * Memory location in the mbuf where the packet data should be stored. */ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) - : base {.payload = data, .length = 0, .sourceIp = 0} + : base {.payload = data, + .length = 0} , bufType(MBUF) , bufRef() { @@ -65,7 +66,8 @@ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) * Overflow buffer that holds this packet. */ DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) - : base {.payload = overflowBuf->data, .length = 0, .sourceIp = 0} + : base {.payload = overflowBuf->data, + .length = 0} , bufType(OVERFLOW_BUF) , bufRef() { @@ -252,7 +254,8 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, vlanHdr->eth_proto = rte_cpu_to_be_16(EthPayloadType::HOMA); // Store our local IP address right before the payload. - *rte_pktmbuf_mtod_offset(mbuf, uint32_t*, PACKET_HDR_LEN - 4) = localIp; + *rte_pktmbuf_mtod_offset(mbuf, uint32_t*, PACKET_HDR_LEN - 4) = + (uint32_t)localIp; // In the normal case, we pre-allocate a pakcet's mbuf with enough // storage to hold the MAX_PAYLOAD_SIZE. If the actual payload is @@ -322,7 +325,8 @@ DpdkDriver::Impl::uncork() // See Driver::receivePackets() uint32_t DpdkDriver::Impl::receivePackets(uint32_t maxPackets, - Driver::Packet* receivedPackets[]) + Driver::Packet* receivedPackets[], + IpAddress sourceAddresses[]) { uint32_t numPacketsReceived = 0; @@ -398,9 +402,10 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, packet = packetPool.construct(m, payload); } packet->base.length = length; - packet->base.sourceIp = srcIp; - receivedPackets[numPacketsReceived++] = &packet->base; + receivedPackets[numPacketsReceived] = &packet->base; + sourceAddresses[numPacketsReceived] = {srcIp}; + ++numPacketsReceived; } return numPacketsReceived; @@ -504,7 +509,7 @@ DpdkDriver::Impl::_init() int cols = sscanf(line.c_str(), "%s 0x%x 0x%x %99s %99s %99s\n", ip, &type, &flags, hwa, mask, dev); if (cols != 6) continue; - arpTable.emplace(Homa::Util::stringToIp(ip), hwa); + arpTable.emplace(IpAddress::fromString(ip), hwa); } // Use ioctl to obtain the IP and MAC addresses of the network interface. @@ -528,7 +533,7 @@ DpdkDriver::Impl::_init() throw DriverInitFailure(HERE_STR, StringUtil::format("Failed to obtain IP address: %s", error)); } - localIp = be32toh(((struct sockaddr_in*) &ifr.ifr_addr)->sin_addr.s_addr); + localIp = {be32toh(((struct sockaddr_in*) &ifr.ifr_addr)->sin_addr.s_addr)}; if (ioctl(fd, SIOCGIFHWADDR, &ifr) == -1) { char* error = strerror(errno); @@ -550,7 +555,7 @@ DpdkDriver::Impl::_init() } } NOTICE("Using interface %s, ip %s, mac %s, port %u", - ifname.c_str(), Homa::Util::ipToString(localIp).c_str(), + ifname.c_str(), IpAddress::toString(localIp).c_str(), localMac.toString().c_str(), port); std::string poolName = StringUtil::format("homa_mbuf_pool_%u", port); diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 4ed3406..4d664fb 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -145,7 +145,8 @@ class DpdkDriver::Impl { void cork(); void uncork(); uint32_t receivePackets(uint32_t maxPackets, - Driver::Packet* receivedPackets[]); + Driver::Packet* receivedPackets[], + IpAddress sourceAddresses[]); void releasePackets(Driver::Packet* packets[], uint16_t numPackets); int getHighestPacketPriority(); uint32_t getMaxPayloadSize(); @@ -171,7 +172,7 @@ class DpdkDriver::Impl { uint16_t port; /// Address resolution table that translates IP addresses to MAC addresses. - std::unordered_map arpTable; + std::unordered_map arpTable; /// Stores the IpAddress of the driver. IpAddress localIp; diff --git a/src/Drivers/Fake/FakeDriver.cc b/src/Drivers/Fake/FakeDriver.cc index b6355cc..5cbafb8 100644 --- a/src/Drivers/Fake/FakeDriver.cc +++ b/src/Drivers/Fake/FakeDriver.cc @@ -56,19 +56,21 @@ static class FakeNetwork { /// Register the FakeNIC so it can receive packets. Returns the newly /// registered FakeNIC's addressId. - uint64_t registerNIC(FakeNIC* nic) + uint32_t registerNIC(FakeNIC* nic) { std::lock_guard lock(mutex); - uint64_t addressId = nextAddressId.fetch_add(1); - network.insert({addressId, nic}); + uint32_t addressId = nextAddressId.fetch_add(1); + IpAddress ipAddress{addressId}; + network.insert({ipAddress, nic}); return addressId; } /// Remove the FakeNIC from the network. - void deregisterNIC(uint64_t addressId) + void deregisterNIC(uint32_t addressId) { std::lock_guard lock(mutex); - network.erase(addressId); + IpAddress ipAddress{addressId}; + network.erase(ipAddress); } /// Deliver the provide packet to the specified destination. @@ -92,7 +94,7 @@ static class FakeNetwork { assert(nic != nullptr); std::lock_guard lock_nic(nic->mutex, std::adopt_lock); FakePacket* dstPacket = new FakePacket(*packet); - dstPacket->base.sourceIp = src; + dstPacket->sourceIp = src; assert(priority < NUM_PRIORITIES); assert(priority >= 0); nic->priorityQueue.at(priority).push_back(dstPacket); @@ -115,10 +117,10 @@ static class FakeNetwork { std::mutex mutex; /// Holds all the packets being sent through the fake network. - std::unordered_map network; + std::unordered_map network; /// Identifier for the next FakeDriver that "connects" to the FakeNetwork. - std::atomic nextAddressId; + std::atomic nextAddressId; /// Rate at which packets should be dropped when sent over this network. double packetLossRate; @@ -192,7 +194,7 @@ FakeDriver::allocPacket() void FakeDriver::sendPacket(Packet* packet, IpAddress destination, int priority) { - FakePacket* srcPacket = container_of(packet, FakePacket, base); + FakePacket* srcPacket = container_of(packet, &FakePacket::base); IpAddress srcAddress = getLocalAddress(); IpAddress dstAddress = destination; fakeNetwork.sendPacket(srcPacket, priority, srcAddress, dstAddress); @@ -203,7 +205,8 @@ FakeDriver::sendPacket(Packet* packet, IpAddress destination, int priority) * See Driver::receivePackets() */ uint32_t -FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) +FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[], + IpAddress sourceAddresses[]) { std::lock_guard lock_nic(nic.mutex); uint32_t numReceived = 0; @@ -212,6 +215,7 @@ FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) FakePacket* fakePacket = nic.priorityQueue.at(i).front(); nic.priorityQueue.at(i).pop_front(); receivedPackets[numReceived] = &fakePacket->base; + sourceAddresses[numReceived] = fakePacket->sourceIp; numReceived++; } } @@ -225,7 +229,7 @@ void FakeDriver::releasePackets(Packet* packets[], uint16_t numPackets) { for (uint16_t i = 0; i < numPackets; ++i) { - delete container_of(packets[i], FakePacket, base); + delete container_of(packets[i], &FakePacket::base); } } @@ -263,7 +267,7 @@ FakeDriver::getBandwidth() IpAddress FakeDriver::getLocalAddress() { - return localAddressId; + return IpAddress{localAddressId}; } /** diff --git a/src/Drivers/Fake/FakeDriverTest.cc b/src/Drivers/Fake/FakeDriverTest.cc index 2390abf..43802ae 100644 --- a/src/Drivers/Fake/FakeDriverTest.cc +++ b/src/Drivers/Fake/FakeDriverTest.cc @@ -27,7 +27,7 @@ namespace { TEST(FakeDriverTest, constructor) { - uint64_t nextAddressId = FakeDriver().localAddressId + 1; + uint32_t nextAddressId = FakeDriver().localAddressId + 1; FakeDriver driver; EXPECT_EQ(nextAddressId, driver.localAddressId); @@ -38,7 +38,7 @@ TEST(FakeDriverTest, allocPacket) FakeDriver driver; Driver::Packet* packet = driver.allocPacket(); // allocPacket doesn't do much so we just need to make sure we can call it. - delete container_of(packet, FakePacket, base); + delete container_of(packet, &FakePacket::base); } TEST(FakeDriverTest, sendPackets) @@ -54,7 +54,7 @@ TEST(FakeDriverTest, sendPackets) destinations[i] = driver2.getLocalAddress(); prio[i] = i; } - destinations[2] = IpAddress(42); + destinations[2] = IpAddress{42}; EXPECT_EQ(0U, driver2.nic.priorityQueue.at(0).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(1).size()); @@ -76,7 +76,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); { - Driver::Packet* packet = &driver2.nic.priorityQueue.at(0).front()->base; + FakePacket* packet = driver2.nic.priorityQueue.at(0).front(); EXPECT_EQ(driver1.getLocalAddress(), packet->sourceIp); } @@ -93,7 +93,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); - delete container_of(packets[2], FakePacket, base); + delete container_of(packets[2], &FakePacket::base); } TEST(FakeDriverTest, receivePackets) @@ -102,6 +102,7 @@ TEST(FakeDriverTest, receivePackets) FakeDriver driver; Driver::Packet* packets[4]; + IpAddress srcAddrs[4]; // 3 packets at priority 7 for (int i = 0; i < 3; ++i) @@ -123,7 +124,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(3U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(4U, driver.receivePackets(4, packets)); + EXPECT_EQ(4U, driver.receivePackets(4, packets, srcAddrs)); driver.releasePackets(packets, 4); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -135,7 +136,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(1U, driver.receivePackets(1, packets)); + EXPECT_EQ(1U, driver.receivePackets(1, packets, srcAddrs)); driver.releasePackets(packets, 1); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -158,7 +159,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(1U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(1U, driver.receivePackets(1, packets)); + EXPECT_EQ(1U, driver.receivePackets(1, packets, srcAddrs)); driver.releasePackets(packets, 1); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -170,7 +171,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(3U, driver.receivePackets(4, packets)); + EXPECT_EQ(3U, driver.receivePackets(4, packets, srcAddrs)); driver.releasePackets(packets, 3); } @@ -199,9 +200,9 @@ TEST(FakeDriverTest, getBandwidth) TEST(FakeDriverTest, getLocalAddress) { - uint64_t nextAddressId = FakeDriver().localAddressId + 1; + uint32_t nextAddressId = FakeDriver().localAddressId + 1; FakeDriver driver; - EXPECT_EQ(nextAddressId, driver.getLocalAddress()); + EXPECT_EQ(nextAddressId, (uint32_t)driver.getLocalAddress()); } } // namespace diff --git a/src/Mock/MockDriver.h b/src/Mock/MockDriver.h index 35fd731..4080882 100644 --- a/src/Mock/MockDriver.h +++ b/src/Mock/MockDriver.h @@ -43,7 +43,8 @@ class MockDriver : public Driver { (override)); MOCK_METHOD(void, flushPackets, ()); MOCK_METHOD(uint32_t, receivePackets, - (uint32_t maxPackets, Packet* receivedPackets[]), (override)); + (uint32_t maxPackets, Packet* receivedPackets[], + IpAddress sourceAddresses[]), (override)); MOCK_METHOD(void, releasePackets, (Packet* packets[], uint16_t numPackets), (override)); MOCK_METHOD(int, getHighestPacketPriority, (), (override)); diff --git a/src/Policy.h b/src/Policy.h index 6c80c90..5339f32 100644 --- a/src/Policy.h +++ b/src/Policy.h @@ -107,7 +107,8 @@ class Manager { /// The scheduled policy for the Transport that owns this Policy::Manager. Scheduled localScheduledPolicy; /// Collection of the known Policies for each peered Homa::Transport; - std::unordered_map peerPolicies; + std::unordered_map + peerPolicies; /// Number of bytes that can be transmitted in one round-trip-time. const uint32_t RTT_BYTES; /// The highest network packet priority that the driver supports. diff --git a/src/PolicyTest.cc b/src/PolicyTest.cc index 88cdd45..4f23806 100644 --- a/src/PolicyTest.cc +++ b/src/PolicyTest.cc @@ -59,7 +59,7 @@ TEST(PolicyManagerTest, getUnscheduledPolicy) EXPECT_CALL(mockDriver, getBandwidth).WillOnce(Return(8000)); EXPECT_CALL(mockDriver, getHighestPacketPriority).WillOnce(Return(7)); Policy::Manager manager(&mockDriver); - IpAddress dest(22); + IpAddress dest{22}; { Policy::Unscheduled policy = manager.getUnscheduledPolicy(dest, 1); diff --git a/src/Protocol.h b/src/Protocol.h index 25471bb..ef2c723 100644 --- a/src/Protocol.h +++ b/src/Protocol.h @@ -104,38 +104,40 @@ enum Opcode { /** * This is the first part of the Homa packet header and is common to all - * versions of the protocol. The struct contains version information about the + * versions of the protocol. The first four bytes of the header store the source + * and destination ports, which is common for many transport layer protocols + * (e.g., TCP, UDP, etc.) The struct also contains version information about the * protocol used in the encompassing packet. The Transport should always send * this prefix and can always expect it when receiving a Homa packet. The prefix * is separated into its own struct because the Transport may need to know the * protocol version before interpreting the rest of the packet. */ struct HeaderPrefix { + uint16_t sport, dport;///< Transport layer (L4) source and destination ports + ///< in network byte order; only used by DataHeader. uint8_t version; ///< The version of the protocol being used by this ///< packet. /// HeaderPrefix constructor. - HeaderPrefix(uint8_t version) - : version(version) + HeaderPrefix(uint16_t sport, uint16_t dport, uint8_t version) + : sport(sport) + , dport(dport) + , version(version) {} } __attribute__((packed)); /** * Describes the wire format for header fields that are common to all packet - * types. Note: the first 4 bytes are identical for TCP, UDP, and Homa. + * types. */ struct CommonHeader { - uint16_t sport, dport;///< Transport layer (L4) source and destination ports - ///< in network byte order; only used by DataHeader. HeaderPrefix prefix; ///< Common to all versions of the protocol. uint8_t opcode; ///< One of the values of Opcode. MessageId messageId; ///< RemoteOp/Message associated with this packet. /// CommonHeader constructor. CommonHeader(Opcode opcode, MessageId messageId) - : sport(0) - , dport(0) - , prefix(1) + : prefix(0, 0, 1) , opcode(opcode) , messageId(messageId) {} @@ -170,8 +172,8 @@ struct DataHeader { , unscheduledIndexLimit(unscheduledIndexLimit) , index(index) { - common.sport = htobe16(sport); - common.dport = htobe16(dport); + common.prefix.sport = htobe16(sport); + common.prefix.dport = htobe16(dport); } } __attribute__((packed)); diff --git a/src/Receiver.cc b/src/Receiver.cc index d499a61..d007087 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -104,7 +104,7 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) SpinLock::Lock lock_allocator(messageAllocator.mutex); SocketAddress srcAddress = { .ip = sourceIp, - .port = be16toh(header->common.sport) + .port = be16toh(header->common.prefix.sport) }; message = messageAllocator.pool.construct( this, driver, dataHeaderLength, messageLength, id, @@ -126,7 +126,7 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) assert(id == message->id); assert(message->driver == driver); assert(message->source.ip == sourceIp); - assert(message->source.port == be16toh(header->common.sport)); + assert(message->source.port == be16toh(header->common.prefix.sport)); assert(message->messageLength == Util::downCast(header->totalLength)); // Add the packet diff --git a/src/Receiver.h b/src/Receiver.h index c97c462..65e65ff 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -474,7 +474,7 @@ class Receiver { /// Collection of all peers; used for fast access. Access is protected by /// the schedulerMutex. - std::unordered_map peerTable; + std::unordered_map peerTable; /// List of peers with inbound messages that require grants to complete. /// Access is protected by the schedulerMutex. diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index 213e2bd..da9e0bc 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -37,6 +37,9 @@ using ::testing::NiceMock; using ::testing::Pointee; using ::testing::Return; +/// Helper macro to construct an IpAddress from a numeric number. +#define IP(x) IpAddress{x} + class ReceiverTest : public ::testing::Test { public: ReceiverTest() @@ -105,21 +108,21 @@ TEST_F(ReceiverTest, handleDataPacket) header->totalLength = totalMessageLength; header->policyVersion = policyVersion; header->unscheduledIndexLimit = 1; - mockPacket.sourceIp = IpAddress(22); + IpAddress sourceIp{22}; // ------------------------------------------------------------------------- // Receive packet[1]. New message. header->index = 1; mockPacket.length = HEADER_SIZE + 1000; EXPECT_CALL(mockPolicyManager, - signalNewMessage(Eq(mockPacket.sourceIp), Eq(policyVersion), + signalNewMessage(Eq(sourceIp), Eq(policyVersion), Eq(totalMessageLength))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- { @@ -148,7 +151,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(1U, message->numPackets); @@ -162,7 +165,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(2U, message->numPackets); @@ -177,7 +180,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(3U, message->numPackets); @@ -192,7 +195,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(4U, message->numPackets); @@ -207,7 +210,7 @@ TEST_F(ReceiverTest, handleDataPacket) .Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, mockPacket.sourceIp); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- Mock::VerifyAndClearExpectations(&mockDriver); @@ -251,7 +254,7 @@ TEST_F(ReceiverTest, handleBusyPacket_unknown) TEST_F(ReceiverTest, handlePingPacket_basic) { Protocol::MessageId id(42, 32); - IpAddress mockAddress = 22; + IpAddress mockAddress{22}; Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 20000, id, SocketAddress{mockAddress, 0}, 0); ASSERT_TRUE(message->scheduled); @@ -264,7 +267,7 @@ TEST_F(ReceiverTest, handlePingPacket_basic) char pingPayload[1028]; Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; - pingPacket.sourceIp = mockAddress; + IpAddress sourceIp = mockAddress; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; @@ -277,7 +280,7 @@ TEST_F(ReceiverTest, handlePingPacket_basic) EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, pingPacket.sourceIp); + receiver->handlePingPacket(&pingPacket, sourceIp); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(0U, message->resendTimeout.expirationCycleTime); @@ -296,8 +299,7 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) char pingPayload[1028]; Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; - IpAddress mockAddress = (IpAddress)22; - pingPacket.sourceIp = mockAddress; + IpAddress mockAddress{22}; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; @@ -310,7 +312,7 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, pingPacket.sourceIp); + receiver->handlePingPacket(&pingPacket, mockAddress); Protocol::Packet::UnknownHeader* header = (Protocol::Packet::UnknownHeader*)payload; @@ -864,11 +866,11 @@ TEST_F(ReceiverTest, trySendGrants) { Receiver::Message* message[4]; Receiver::ScheduledMessageInfo* info[4]; - for (uint64_t i = 0; i < 4; ++i) { + for (uint32_t i = 0; i < 4; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - 10000 * (i + 1), id, SocketAddress{IpAddress(100 + i), 60001}, + 10000 * (i + 1), id, SocketAddress{IP(100 + i), 60001}, 10 * (i + 1)); { SpinLock::Lock lock_scheduler(receiver->schedulerMutex); @@ -996,7 +998,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[0], lock); - EXPECT_EQ(&receiver->peerTable.at(22), info[0]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(22)), info[0]->peer); EXPECT_EQ(message[0], &info[0]->peer->scheduledMessages.front()); EXPECT_EQ(info[0]->peer, &receiver->scheduledPeers.front()); @@ -1008,7 +1010,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[1], lock); - EXPECT_EQ(&receiver->peerTable.at(33), info[1]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(33)), info[1]->peer); EXPECT_EQ(message[1], &info[1]->peer->scheduledMessages.front()); EXPECT_EQ(info[1]->peer, &receiver->scheduledPeers.back()); @@ -1020,7 +1022,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[2], lock); - EXPECT_EQ(&receiver->peerTable.at(33), info[2]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(33)), info[2]->peer); EXPECT_EQ(message[2], &info[2]->peer->scheduledMessages.front()); EXPECT_EQ(info[2]->peer, &receiver->scheduledPeers.front()); @@ -1032,7 +1034,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[3], lock); - EXPECT_EQ(&receiver->peerTable.at(22), info[3]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(22)), info[3]->peer); EXPECT_EQ(message[3], &info[3]->peer->scheduledMessages.back()); EXPECT_EQ(info[3]->peer, &receiver->scheduledPeers.back()); } @@ -1043,23 +1045,24 @@ TEST_F(ReceiverTest, unschedule) Receiver::ScheduledMessageInfo* info[5]; SpinLock::Lock lock(receiver->schedulerMutex); int messageLength[5] = {10, 20, 30, 10, 20}; - for (uint64_t i = 0; i < 5; ++i) { + for (uint32_t i = 0; i < 5; ++i) { Protocol::MessageId id = {42, 10 + i}; - IpAddress source = IpAddress((i / 3) + 10); + IpAddress source = IP((i / 3) + 10); message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), messageLength[i], id, SocketAddress{source, 60001}, 0); info[i] = &message[i]->scheduledMessageInfo; receiver->schedule(message[i], lock); } + auto& scheduledPeers = receiver->scheduledPeers; - ASSERT_EQ(IpAddress(10), message[0]->source.ip); - ASSERT_EQ(IpAddress(10), message[1]->source.ip); - ASSERT_EQ(IpAddress(10), message[2]->source.ip); - ASSERT_EQ(IpAddress(11), message[3]->source.ip); - ASSERT_EQ(IpAddress(11), message[4]->source.ip); - ASSERT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - ASSERT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + ASSERT_EQ(IP(10), message[0]->source.ip); + ASSERT_EQ(IP(10), message[1]->source.ip); + ASSERT_EQ(IP(10), message[2]->source.ip); + ASSERT_EQ(IP(11), message[3]->source.ip); + ASSERT_EQ(IP(11), message[4]->source.ip); + ASSERT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + ASSERT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); // <10>: [0](10) -> [1](20) -> [2](30) // <11>: [3](10) -> [4](20) @@ -1077,10 +1080,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[4], lock); EXPECT_EQ(nullptr, info[4]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(3U, receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(3U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[1]; peer in correct position. @@ -1090,10 +1093,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[1], lock); EXPECT_EQ(nullptr, info[1]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(2U, receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(2U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[0]; peer needs to be reordered. @@ -1103,10 +1106,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[0], lock); EXPECT_EQ(nullptr, info[0]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(11)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(10)); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(10).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[3]; peer needs to be removed. @@ -1115,10 +1118,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[3], lock); EXPECT_EQ(nullptr, info[3]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(10)); - EXPECT_EQ(1U, receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(0U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(1U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(0U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); } TEST_F(ReceiverTest, updateSchedule) @@ -1127,10 +1130,10 @@ TEST_F(ReceiverTest, updateSchedule) // 11 : [20][30] SpinLock::Lock lock(receiver->schedulerMutex); Receiver::Message* other[3]; - for (uint64_t i = 0; i < 3; ++i) { + for (uint32_t i = 0; i < 3; ++i) { Protocol::MessageId id = {42, 10 + i}; int messageLength = 10 * (i + 1); - IpAddress source = IpAddress(((i + 1) / 2) + 10); + IpAddress source = IP(((i + 1) / 2) + 10); other[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 10 * (i + 1), id, SocketAddress{source, 60001}, 0); @@ -1140,12 +1143,13 @@ TEST_F(ReceiverTest, updateSchedule) receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 100, Protocol::MessageId(42, 1), SocketAddress{11, 60001}, 0); receiver->schedule(message, lock); - ASSERT_EQ(&receiver->peerTable.at(10), other[0]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), other[1]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), other[2]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), message->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - ASSERT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + auto& peerTable = receiver->peerTable; + ASSERT_EQ(&peerTable.at(IP(10)), other[0]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), other[1]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), other[2]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), message->scheduledMessageInfo.peer); + ASSERT_EQ(&receiver->scheduledPeers.front(), &peerTable.at(IP(10))); + ASSERT_EQ(&receiver->scheduledPeers.back(), &peerTable.at(IP(11))); //-------------------------------------------------------------------------- // Move message up within peer. @@ -1155,11 +1159,12 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + auto& scheduledPeers = receiver->scheduledPeers; + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); Receiver::Peer* peer = &receiver->scheduledPeers.back(); auto it = peer->scheduledMessages.begin(); EXPECT_TRUE( - std::next(receiver->peerTable.at(11).scheduledMessages.begin()) == + std::next(receiver->peerTable.at(IP(11)).scheduledMessages.begin()) == message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); @@ -1171,8 +1176,8 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(receiver->peerTable.at(11).scheduledMessages.begin(), + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(receiver->peerTable.at(IP(11)).scheduledMessages.begin(), message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); @@ -1184,8 +1189,8 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(11)); - EXPECT_EQ(receiver->peerTable.at(11).scheduledMessages.begin(), + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(receiver->peerTable.at(IP(11)).scheduledMessages.begin(), message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); } diff --git a/src/SenderTest.cc b/src/SenderTest.cc index 244a7c9..dfb216e 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -1367,8 +1367,8 @@ TEST_F(SenderTest, sendMessage_basic) // Check packet metadata Protocol::Packet::DataHeader* header = static_cast(mockPacket.payload); - EXPECT_EQ(htobe16(sport), header->common.sport); - EXPECT_EQ(htobe16(dport), header->common.dport); + EXPECT_EQ(htobe16(sport), header->common.prefix.sport); + EXPECT_EQ(htobe16(dport), header->common.prefix.dport); EXPECT_EQ(id, header->common.messageId); EXPECT_EQ(420U, header->totalLength); EXPECT_EQ(policy.version, header->policyVersion); diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index d4ebc70..a380944 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -98,10 +98,10 @@ TransportImpl::processPackets() const int MAX_BURST = 32; Driver::Packet* packets[MAX_BURST]; - int numPackets = driver->receivePackets(MAX_BURST, packets); + IpAddress srcAddrs[MAX_BURST]; + int numPackets = driver->receivePackets(MAX_BURST, packets, srcAddrs); for (int i = 0; i < numPackets; ++i) { - Driver::Packet* packet = packets[i]; - processPacket(packet, packet->sourceIp); + processPacket(packets[i], srcAddrs[i]); } cycles = PerfUtils::Cycles::rdtsc() - cycles; diff --git a/src/Util.cc b/src/Util.cc index fe73752..90ee9f4 100644 --- a/src/Util.cc +++ b/src/Util.cc @@ -100,21 +100,5 @@ hexDump(const void* buf, uint64_t bytes) return output.str(); } -std::string -ipToString(uint32_t ip) -{ - return StringUtil::format("%d.%d.%d.%d", - (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); -} - -uint32_t -stringToIp(const char* ipStr) -{ - unsigned int bytes[4]; - sscanf(ipStr, "%u.%u.%u.%u", &bytes[0], &bytes[1], &bytes[2], &bytes[3]); - return (bytes[0] << 24) | (bytes[1] << 16) | (bytes[2] << 8) | bytes[3]; -} - - } // namespace Util } // namespace Homa diff --git a/test/Output.h b/test/Output.h index bea8f8b..de8d740 100644 --- a/test/Output.h +++ b/test/Output.h @@ -1,3 +1,18 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + #pragma once #include diff --git a/test/dpdk_test.cc b/test/dpdk_test.cc index ebf9ba2..4ca1a82 100644 --- a/test/dpdk_test.cc +++ b/test/dpdk_test.cc @@ -1,3 +1,18 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + #include #include @@ -37,23 +52,24 @@ main(int argc, char* argv[]) Homa::Drivers::DPDK::DpdkDriver driver(iface.c_str()); if (isServer) { - std::cout << Homa::Util::ipToString(driver.getLocalAddress()) + std::cout << Homa::IpAddress::toString(driver.getLocalAddress()) << std::endl; while (true) { Homa::Driver::Packet* incoming[10]; + Homa::IpAddress srcAddrs[10]; uint32_t receivedPackets; do { - receivedPackets = driver.receivePackets(10, incoming); + receivedPackets = driver.receivePackets(10, incoming, srcAddrs); } while (receivedPackets == 0); Homa::Driver::Packet* pong = driver.allocPacket(); pong->length = 100; - driver.sendPacket(pong, incoming[0]->sourceIp, 0); + driver.sendPacket(pong, srcAddrs[0], 0); driver.releasePackets(incoming, receivedPackets); driver.releasePackets(&pong, 1); } } else { Homa::IpAddress server_ip = - Homa::Util::stringToIp(server_ip_string.c_str()); + Homa::IpAddress::fromString(server_ip_string.c_str()); std::vector times; for (int i = 0; i < 100000; ++i) { uint64_t start = PerfUtils::Cycles::rdtsc(); @@ -67,9 +83,10 @@ main(int argc, char* argv[]) driver.releasePackets(&ping, 1); PerfUtils::TimeTrace::record("releasePacket"); Homa::Driver::Packet* incoming[10]; + Homa::IpAddress srcAddrs[10]; uint32_t receivedPackets; do { - receivedPackets = driver.receivePackets(10, incoming); + receivedPackets = driver.receivePackets(10, incoming, srcAddrs); PerfUtils::TimeTrace::record("receivePackets"); } while (receivedPackets == 0); driver.releasePackets(incoming, receivedPackets); From c1308dc5666e2862c5f14b778602a12c5f83d8bf Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Mon, 5 Oct 2020 23:48:31 -0700 Subject: [PATCH 05/15] fixed copyrights and formatting issues --- include/Homa/Driver.h | 8 ++-- include/Homa/Drivers/Fake/FakeDriver.h | 9 ++-- include/Homa/Util.h | 9 ++-- src/ControlPacket.h | 2 +- src/Driver.cc | 6 +-- src/Drivers/DPDK/DpdkDriver.cc | 3 +- src/Drivers/DPDK/DpdkDriverImpl.cc | 45 +++++++++++--------- src/Drivers/DPDK/MacAddress.cc | 2 +- src/Drivers/DPDK/MacAddress.h | 2 +- src/Drivers/DPDK/MacAddressTest.cc | 2 +- src/Drivers/Fake/FakeDriver.cc | 2 +- src/Drivers/Fake/FakeDriverTest.cc | 2 +- src/Mock/MockDriver.h | 7 +-- src/Mock/MockPolicy.h | 2 +- src/Mock/MockReceiver.h | 8 ++-- src/Mock/MockSender.h | 13 +++--- src/Policy.h | 3 +- src/PolicyTest.cc | 2 +- src/Protocol.h | 5 ++- src/Receiver.cc | 16 +++---- src/ReceiverTest.cc | 47 +++++++++++--------- src/Sender.cc | 4 +- src/SenderTest.cc | 59 ++++++++++++++------------ src/TransportImplTest.cc | 16 +++---- test/system_test.cc | 2 +- 25 files changed, 147 insertions(+), 129 deletions(-) diff --git a/include/Homa/Driver.h b/include/Homa/Driver.h index a5cc855..67df785 100644 --- a/include/Homa/Driver.h +++ b/include/Homa/Driver.h @@ -83,7 +83,7 @@ struct PacketSpec { /// Number of bytes in the payload. int32_t length; -} __attribute__((packed)); +} __attribute__((packed)); static_assert(std::is_trivial()); /** @@ -140,7 +140,7 @@ class Driver { * getHighestPacketPriority(). */ virtual void sendPacket(Packet* packet, IpAddress destination, - int priority) = 0; + int priority) = 0; /** * Request that the Driver enter the "corked" mode where outbound packets @@ -170,12 +170,12 @@ class Driver { * * @param maxPackets * The maximum number of Packet objects that should be returned by - * this method. + * this method. * @param[out] receivedPackets * Received packets are appended to this array in order of arrival. * @param[out] sourceAddresses * Source IP addresses of the received packets are appended to this - * array in order of arrival. + * array in order of arrival. * * @return * Number of Packet objects being returned. diff --git a/include/Homa/Drivers/Fake/FakeDriver.h b/include/Homa/Drivers/Fake/FakeDriver.h index dd01261..5f54586 100644 --- a/include/Homa/Drivers/Fake/FakeDriver.h +++ b/include/Homa/Drivers/Fake/FakeDriver.h @@ -65,8 +65,7 @@ struct FakePacket { * FakePacket constructor. */ explicit FakePacket() - : base{.payload = buf, - .length = 0} + : base{.payload = buf, .length = 0} , buf() , sourceIp() {} @@ -75,8 +74,7 @@ struct FakePacket { * Copy constructor. */ FakePacket(const FakePacket& other) - : base{.payload = buf, - .length = other.base.length} + : base{.payload = buf, .length = other.base.length} , buf() , sourceIp() { @@ -112,7 +110,8 @@ class FakeDriver : public Driver { virtual ~FakeDriver(); virtual Packet* allocPacket(); - virtual void sendPacket(Packet* packet, IpAddress destination, int priority); + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority); virtual uint32_t receivePackets(uint32_t maxPackets, Packet* receivedPackets[], IpAddress sourceAddresses[]); diff --git a/include/Homa/Util.h b/include/Homa/Util.h index ba757e6..4f17acc 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2009-2018, Stanford University +/* Copyright (c) 2009-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -22,10 +22,11 @@ #include /// Cast a member of a structure out to the containing structure. -template -P* container_of(M* ptr, const M P::*member) +template +P* +container_of(M* ptr, const M P::*member) { - return (P*)((char*) ptr - (size_t) &(reinterpret_cast(0)->*member)); + return (P*)((char*)ptr - (size_t) & (reinterpret_cast(0)->*member)); } namespace Homa { diff --git a/src/ControlPacket.h b/src/ControlPacket.h index bc53f10..17310af 100644 --- a/src/ControlPacket.h +++ b/src/ControlPacket.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Driver.cc b/src/Driver.cc index c7d61cb..b29c828 100644 --- a/src/Driver.cc +++ b/src/Driver.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2018-2019, Stanford University +/* Copyright (c) 2018-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -23,8 +23,8 @@ std::string IpAddress::toString(IpAddress address) { uint32_t ip = address.addr; - return StringUtil::format("%d.%d.%d.%d", - (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); + return StringUtil::format("%d.%d.%d.%d", (ip >> 24) & 0xff, + (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); } IpAddress diff --git a/src/Drivers/DPDK/DpdkDriver.cc b/src/Drivers/DPDK/DpdkDriver.cc index 3c8833a..c27d1df 100644 --- a/src/Drivers/DPDK/DpdkDriver.cc +++ b/src/Drivers/DPDK/DpdkDriver.cc @@ -30,7 +30,8 @@ DpdkDriver::DpdkDriver(const char* ifname, int argc, char* argv[], : pImpl(new Impl(ifname, argc, argv, config)) {} -DpdkDriver::DpdkDriver(const char* ifname, NoEalInit _, const Config* const config) +DpdkDriver::DpdkDriver(const char* ifname, NoEalInit _, + const Config* const config) : pImpl(new Impl(ifname, _, config)) {} diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index 42a2340..e9fef18 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -17,19 +17,19 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#include -#include #include #include +#include #include +#include #include "DpdkDriverImpl.h" #include #include "CodeLocation.h" -#include "StringUtil.h" #include "Homa/Util.h" +#include "StringUtil.h" namespace Homa { @@ -51,8 +51,7 @@ const char* default_eal_argv[] = {"homa", NULL}; * Memory location in the mbuf where the packet data should be stored. */ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) - : base {.payload = data, - .length = 0} + : base{.payload = data, .length = 0} , bufType(MBUF) , bufRef() { @@ -66,8 +65,7 @@ DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) * Overflow buffer that holds this packet. */ DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) - : base {.payload = overflowBuf->data, - .length = 0} + : base{.payload = overflowBuf->data, .length = 0} , bufType(OVERFLOW_BUF) , bufRef() { @@ -212,7 +210,8 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, numMbufsAvail, numMbufsInUse); return; } - char* buf = rte_pktmbuf_append(mbuf, + char* buf = rte_pktmbuf_append( + mbuf, Homa::Util::downCast(PACKET_HDR_LEN + pkt->base.length)); if (unlikely(NULL == buf)) { WARNING("rte_pktmbuf_append call failed; dropping packet"); @@ -506,9 +505,10 @@ DpdkDriver::Impl::_init() char mask[100]; char dev[100]; int type, flags; - int cols = sscanf(line.c_str(), "%s 0x%x 0x%x %99s %99s %99s\n", - ip, &type, &flags, hwa, mask, dev); - if (cols != 6) continue; + int cols = sscanf(line.c_str(), "%s 0x%x 0x%x %99s %99s %99s\n", ip, + &type, &flags, hwa, mask, dev); + if (cols != 6) + continue; arpTable.emplace(IpAddress::fromString(ip), hwa); } @@ -517,28 +517,32 @@ DpdkDriver::Impl::_init() ifname.copy(ifr.ifr_name, ifname.length()); ifr.ifr_name[ifname.length() + 1] = 0; if (ifname.length() >= sizeof(ifr.ifr_name)) { - throw DriverInitFailure(HERE_STR, + throw DriverInitFailure( + HERE_STR, StringUtil::format("Interface name %s too long", ifname.c_str())); } int fd = socket(AF_INET, SOCK_DGRAM, 0); if (fd == -1) { - throw DriverInitFailure(HERE_STR, + throw DriverInitFailure( + HERE_STR, StringUtil::format("Failed to create socket: %s", strerror(errno))); } if (ioctl(fd, SIOCGIFADDR, &ifr) == -1) { char* error = strerror(errno); close(fd); - throw DriverInitFailure(HERE_STR, + throw DriverInitFailure( + HERE_STR, StringUtil::format("Failed to obtain IP address: %s", error)); } - localIp = {be32toh(((struct sockaddr_in*) &ifr.ifr_addr)->sin_addr.s_addr)}; + localIp = {be32toh(((struct sockaddr_in*)&ifr.ifr_addr)->sin_addr.s_addr)}; if (ioctl(fd, SIOCGIFHWADDR, &ifr) == -1) { char* error = strerror(errno); close(fd); - throw DriverInitFailure(HERE_STR, + throw DriverInitFailure( + HERE_STR, StringUtil::format("Failed to obtain MAC address: %s", error)); } close(fd); @@ -546,7 +550,8 @@ DpdkDriver::Impl::_init() // Iterate over ethernet devices to locate the port identifier. int p; - RTE_ETH_FOREACH_DEV(p) { + RTE_ETH_FOREACH_DEV(p) + { struct ether_addr mac; rte_eth_macaddr_get(p, &mac); if (MacAddress(mac.addr_bytes) == localMac) { @@ -554,9 +559,9 @@ DpdkDriver::Impl::_init() break; } } - NOTICE("Using interface %s, ip %s, mac %s, port %u", - ifname.c_str(), IpAddress::toString(localIp).c_str(), - localMac.toString().c_str(), port); + NOTICE("Using interface %s, ip %s, mac %s, port %u", ifname.c_str(), + IpAddress::toString(localIp).c_str(), localMac.toString().c_str(), + port); std::string poolName = StringUtil::format("homa_mbuf_pool_%u", port); std::string ringName = StringUtil::format("homa_loopback_ring_%u", port); diff --git a/src/Drivers/DPDK/MacAddress.cc b/src/Drivers/DPDK/MacAddress.cc index 63149fa..e47f27a 100644 --- a/src/Drivers/DPDK/MacAddress.cc +++ b/src/Drivers/DPDK/MacAddress.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Drivers/DPDK/MacAddress.h b/src/Drivers/DPDK/MacAddress.h index 148f2ce..33f47a5 100644 --- a/src/Drivers/DPDK/MacAddress.h +++ b/src/Drivers/DPDK/MacAddress.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Drivers/DPDK/MacAddressTest.cc b/src/Drivers/DPDK/MacAddressTest.cc index 7587a16..9b8b8ae 100644 --- a/src/Drivers/DPDK/MacAddressTest.cc +++ b/src/Drivers/DPDK/MacAddressTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Drivers/Fake/FakeDriver.cc b/src/Drivers/Fake/FakeDriver.cc index 5cbafb8..26cb102 100644 --- a/src/Drivers/Fake/FakeDriver.cc +++ b/src/Drivers/Fake/FakeDriver.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Drivers/Fake/FakeDriverTest.cc b/src/Drivers/Fake/FakeDriverTest.cc index 43802ae..cd64917 100644 --- a/src/Drivers/Fake/FakeDriverTest.cc +++ b/src/Drivers/Fake/FakeDriverTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Mock/MockDriver.h b/src/Mock/MockDriver.h index 4080882..9ea6ffe 100644 --- a/src/Mock/MockDriver.h +++ b/src/Mock/MockDriver.h @@ -39,13 +39,14 @@ class MockDriver : public Driver { MOCK_METHOD(Packet*, allocPacket, (), (override)); MOCK_METHOD(void, sendPacket, - (Packet* packet, IpAddress destination, int priority), + (Packet * packet, IpAddress destination, int priority), (override)); MOCK_METHOD(void, flushPackets, ()); MOCK_METHOD(uint32_t, receivePackets, (uint32_t maxPackets, Packet* receivedPackets[], - IpAddress sourceAddresses[]), (override)); - MOCK_METHOD(void, releasePackets, (Packet* packets[], uint16_t numPackets), + IpAddress sourceAddresses[]), + (override)); + MOCK_METHOD(void, releasePackets, (Packet * packets[], uint16_t numPackets), (override)); MOCK_METHOD(int, getHighestPacketPriority, (), (override)); MOCK_METHOD(uint32_t, getMaxPayloadSize, (), (override)); diff --git a/src/Mock/MockPolicy.h b/src/Mock/MockPolicy.h index 52cb2a5..32e7be8 100644 --- a/src/Mock/MockPolicy.h +++ b/src/Mock/MockPolicy.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Mock/MockReceiver.h b/src/Mock/MockReceiver.h index 75eea2c..61c21ce 100644 --- a/src/Mock/MockReceiver.h +++ b/src/Mock/MockReceiver.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -37,10 +37,10 @@ class MockReceiver : public Core::Receiver { {} MOCK_METHOD(void, handleDataPacket, - (Driver::Packet* packet, IpAddress sourceIp), (override)); - MOCK_METHOD(void, handleBusyPacket, (Driver::Packet* packet), (override)); + (Driver::Packet * packet, IpAddress sourceIp), (override)); + MOCK_METHOD(void, handleBusyPacket, (Driver::Packet * packet), (override)); MOCK_METHOD(void, handlePingPacket, - (Driver::Packet* packet, IpAddress sourceIp), (override)); + (Driver::Packet * packet, IpAddress sourceIp), (override)); MOCK_METHOD(Homa::InMessage*, receiveMessage, (), (override)); MOCK_METHOD(void, poll, (), (override)); MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); diff --git a/src/Mock/MockSender.h b/src/Mock/MockSender.h index 4a8bd27..cb29c90 100644 --- a/src/Mock/MockSender.h +++ b/src/Mock/MockSender.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -38,12 +38,13 @@ class MockSender : public Core::Sender { {} MOCK_METHOD(Homa::OutMessage*, allocMessage, (uint16_t sport), (override)); - MOCK_METHOD(void, handleDonePacket, (Driver::Packet* packet), (override)); - MOCK_METHOD(void, handleGrantPacket, (Driver::Packet* packet), (override)); - MOCK_METHOD(void, handleResendPacket, (Driver::Packet* packet), (override)); - MOCK_METHOD(void, handleUnknownPacket, (Driver::Packet* packet), + MOCK_METHOD(void, handleDonePacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(void, handleGrantPacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(void, handleResendPacket, (Driver::Packet * packet), (override)); - MOCK_METHOD(void, handleErrorPacket, (Driver::Packet* packet), (override)); + MOCK_METHOD(void, handleUnknownPacket, (Driver::Packet * packet), + (override)); + MOCK_METHOD(void, handleErrorPacket, (Driver::Packet * packet), (override)); MOCK_METHOD(void, poll, (), (override)); MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); }; diff --git a/src/Policy.h b/src/Policy.h index 5339f32..0be1eb2 100644 --- a/src/Policy.h +++ b/src/Policy.h @@ -77,8 +77,7 @@ class Manager { virtual Scheduled getScheduledPolicy(); virtual Unscheduled getUnscheduledPolicy(const IpAddress destination, const uint32_t messageLength); - virtual void signalNewMessage(const IpAddress source, - uint8_t policyVersion, + virtual void signalNewMessage(const IpAddress source, uint8_t policyVersion, uint32_t messageLength); virtual void poll(); diff --git a/src/PolicyTest.cc b/src/PolicyTest.cc index 4f23806..44b8829 100644 --- a/src/PolicyTest.cc +++ b/src/PolicyTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Protocol.h b/src/Protocol.h index ef2c723..55a34ac 100644 --- a/src/Protocol.h +++ b/src/Protocol.h @@ -113,8 +113,9 @@ enum Opcode { * protocol version before interpreting the rest of the packet. */ struct HeaderPrefix { - uint16_t sport, dport;///< Transport layer (L4) source and destination ports - ///< in network byte order; only used by DataHeader. + uint16_t sport, + dport; ///< Transport layer (L4) source and destination ports + ///< in network byte order; only used by DataHeader. uint8_t version; ///< The version of the protocol being used by this ///< packet. diff --git a/src/Receiver.cc b/src/Receiver.cc index d007087..c850d07 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -103,17 +103,15 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) { SpinLock::Lock lock_allocator(messageAllocator.mutex); SocketAddress srcAddress = { - .ip = sourceIp, - .port = be16toh(header->common.prefix.sport) - }; + .ip = sourceIp, .port = be16toh(header->common.prefix.sport)}; message = messageAllocator.pool.construct( - this, driver, dataHeaderLength, messageLength, id, - srcAddress, numUnscheduledPackets); + this, driver, dataHeaderLength, messageLength, id, srcAddress, + numUnscheduledPackets); } bucket->messages.push_back(&message->bucketNode); - policyManager->signalNewMessage(message->source.ip, - header->policyVersion, header->totalLength); + policyManager->signalNewMessage( + message->source.ip, header->policyVersion, header->totalLength); if (message->scheduled) { // Message needs to be scheduled. @@ -244,8 +242,8 @@ Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) // We are here because we have no knowledge of the message the Sender is // asking about. Reply UNKNOWN so the Sender can react accordingly. Perf::counters.tx_unknown_pkts.add(1); - ControlPacket::send( - driver, sourceIp, id); + ControlPacket::send(driver, sourceIp, + id); } driver->releasePackets(&packet, 1); } diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index da9e0bc..bfccc39 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -38,13 +38,17 @@ using ::testing::Pointee; using ::testing::Return; /// Helper macro to construct an IpAddress from a numeric number. -#define IP(x) IpAddress{x} +#define IP(x) \ + IpAddress \ + { \ + x \ + } class ReceiverTest : public ::testing::Test { public: ReceiverTest() : mockDriver() - , mockPacket {&payload} + , mockPacket{&payload} , mockPolicyManager(&mockDriver) , payload() , receiver() @@ -266,7 +270,7 @@ TEST_F(ReceiverTest, handlePingPacket_basic) bucket->messages.push_back(&message->bucketNode); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; + Homa::Mock::MockDriver::MockPacket pingPacket{pingPayload}; IpAddress sourceIp = mockAddress; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; @@ -298,7 +302,7 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) Protocol::MessageId id(42, 32); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket {pingPayload}; + Homa::Mock::MockDriver::MockPacket pingPacket{pingPayload}; IpAddress mockAddress{22}; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; @@ -352,7 +356,8 @@ TEST_F(ReceiverTest, poll) TEST_F(ReceiverTest, checkTimeouts) { Receiver::Message message(receiver, &mockDriver, 0, 0, - Protocol::MessageId(0, 0), SocketAddress{0, 60001}, 0); + Protocol::MessageId(0, 0), + SocketAddress{0, 60001}, 0); Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); bucket->resendTimeouts.setTimeout(&message.resendTimeout); bucket->messageTimeouts.setTimeout(&message.messageTimeout); @@ -420,8 +425,9 @@ TEST_F(ReceiverTest, Message_acknowledge) receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket( - Eq(&mockPacket), Eq(message->source.ip), _)).Times(1); + EXPECT_CALL(mockDriver, + sendPacket(Eq(&mockPacket), Eq(message->source.ip), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -456,8 +462,9 @@ TEST_F(ReceiverTest, Message_fail) receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket( - Eq(&mockPacket), Eq(message->source.ip), _)).Times(1); + EXPECT_CALL(mockDriver, + sendPacket(Eq(&mockPacket), Eq(message->source.ip), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -477,8 +484,8 @@ TEST_F(ReceiverTest, Message_get_basic) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; char source[] = "Hello, world!"; message->setPacket(0, &packet0); @@ -504,8 +511,8 @@ TEST_F(ReceiverTest, Message_get_offsetTooLarge) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; message->setPacket(0, &packet0); message->setPacket(1, &packet1); @@ -530,8 +537,8 @@ TEST_F(ReceiverTest, Message_get_missingPacket) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; char source[] = "Hello,"; message->setPacket(0, &packet0); @@ -806,16 +813,18 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) char buf1[1024]; char buf2[1024]; - Homa::Mock::MockDriver::MockPacket mockResendPacket1 {buf1}; - Homa::Mock::MockDriver::MockPacket mockResendPacket2 {buf2}; + Homa::Mock::MockDriver::MockPacket mockResendPacket1{buf1}; + Homa::Mock::MockDriver::MockPacket mockResendPacket2{buf2}; EXPECT_CALL(mockDriver, allocPacket()) .WillOnce(Return(&mockResendPacket1)) .WillOnce(Return(&mockResendPacket2)); EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket1), - Eq(message[0]->source.ip), _)).Times(1); + Eq(message[0]->source.ip), _)) + .Times(1); EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket2), - Eq(message[0]->source.ip), _)).Times(1); + Eq(message[0]->source.ip), _)) + .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket1), Eq(1))) .Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket2), Eq(1))) diff --git a/src/Sender.cc b/src/Sender.cc index b993b6b..ea75bf4 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -391,7 +391,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet) Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(dataPacket->length); driver->sendPacket(dataPacket, message->destination.ip, - policy.priority); + policy.priority); message->state.store(OutMessage::Status::SENT); } else { // Otherwise, queue the message to be sent in SRPT order. @@ -401,7 +401,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet) // was first queued. assert(info->id == message->id); assert(!memcmp(&info->destination, &message->destination, - sizeof(info->destination))); + sizeof(info->destination))); assert(info->packets == message); // Some values need to be updated info->unsentBytes = message->messageLength; diff --git a/src/SenderTest.cc b/src/SenderTest.cc index dfb216e..8085c82 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -36,7 +36,7 @@ class SenderTest : public ::testing::Test { public: SenderTest() : mockDriver() - , mockPacket {&payload} + , mockPacket{&payload} , mockPolicyManager(&mockDriver) , sender() , savedLogPolicy(Debug::getLogPolicy()) @@ -314,7 +314,7 @@ TEST_F(SenderTest, handleResendPacket_basic) std::vector packets; std::vector priorities; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket {payload}); + packets.push_back(new Homa::Mock::MockDriver::MockPacket{payload}); priorities.push_back(0); setMessagePacket(message, i, packets[i]); } @@ -333,10 +333,12 @@ TEST_F(SenderTest, handleResendPacket_basic) resendHdr->priority = 4; EXPECT_CALL(mockPolicyManager, getResendPriority).WillOnce(Return(7)); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[3]), _, _)).WillOnce( - [&priorities] (auto _1, auto _2, int p) { priorities[3] = p; }); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]), _, _)).WillOnce( - [&priorities] (auto _1, auto _2, int p) { priorities[4] = p; }); + EXPECT_CALL(mockDriver, sendPacket(Eq(packets[3]), _, _)) + .WillOnce( + [&priorities](auto _1, auto _2, int p) { priorities[3] = p; }); + EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]), _, _)) + .WillOnce( + [&priorities](auto _1, auto _2, int p) { priorities[4] = p; }); EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) .Times(1); @@ -381,7 +383,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket {payload}; + new Homa::Mock::MockDriver::MockPacket{payload}; setMessagePacket(message, 0, packet); Protocol::Packet::ResendHeader* resendHdr = @@ -421,7 +423,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) dynamic_cast(sender->allocMessage(0)); std::vector packets; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket {payload}); + packets.push_back(new Homa::Mock::MockDriver::MockPacket{payload}); setMessagePacket(message, i, packets[i]); } SenderTest::addMessage(sender, id, message, true, 5); @@ -470,7 +472,7 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) Sender::Message* message = dynamic_cast(sender->allocMessage(0)); char data[1028]; - Homa::Mock::MockDriver::MockPacket dataPacket {data}; + Homa::Mock::MockDriver::MockPacket dataPacket{data}; for (int i = 0; i < 10; ++i) { setMessagePacket(message, i, &dataPacket); } @@ -488,7 +490,7 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) // Expect the BUSY control packet. char busy[1028]; - Homa::Mock::MockDriver::MockPacket busyPacket {busy}; + Homa::Mock::MockDriver::MockPacket busyPacket{busy}; EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&busyPacket)); EXPECT_CALL(mockDriver, sendPacket(Eq(&busyPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(Pointee(&busyPacket), Eq(1))) @@ -648,7 +650,7 @@ TEST_F(SenderTest, handleUnknownPacket_basic) char payload[5][1028]; for (int i = 0; i < 5; ++i) { Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket {payload[i]}; + new Homa::Mock::MockDriver::MockPacket{payload[i]}; Protocol::Packet::DataHeader* header = static_cast(packet->payload); header->policyVersion = policyOld.version; @@ -716,7 +718,7 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) Sender::Message* message = dynamic_cast(sender->allocMessage(0)); - Homa::Mock::MockDriver::MockPacket dataPacket {payload}; + Homa::Mock::MockDriver::MockPacket dataPacket{payload}; Protocol::Packet::DataHeader* dataHeader = static_cast(dataPacket.payload); dataHeader->policyVersion = policyOld.version; @@ -768,7 +770,7 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) char payload[5][1028]; for (int i = 0; i < 5; ++i) { Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket {payload[i]}; + new Homa::Mock::MockDriver::MockPacket{payload[i]}; packets.push_back(packet); setMessagePacket(message, i, packet); } @@ -1092,8 +1094,8 @@ TEST_F(SenderTest, Message_append_basic) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + MAX_RAW_PACKET_LENGTH}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1132,8 +1134,8 @@ TEST_F(SenderTest, Message_append_truncated) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + MAX_RAW_PACKET_LENGTH}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1189,8 +1191,8 @@ TEST_F(SenderTest, Message_prepend) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1224,8 +1226,8 @@ TEST_F(SenderTest, Message_reserve) { Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1279,8 +1281,8 @@ TEST_F(SenderTest, Message_getOrAllocPacket) // TODO(cstlee): cleanup Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0 {buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1 {buf + 2048}; + Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; + Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; EXPECT_FALSE(msg.occupied.test(0)); EXPECT_EQ(0U, msg.numPackets); @@ -1353,7 +1355,8 @@ TEST_F(SenderTest, sendMessage_basic) .WillOnce(Return(policy)); int mockPriority = 0; EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(destination.ip), _)) - .WillOnce([&mockPriority] (auto _1, auto _2, int p){mockPriority = p;}); + .WillOnce( + [&mockPriority](auto _1, auto _2, int p) { mockPriority = p; }); sender->sendMessage(message, destination, Sender::Message::Options::NO_RETRY); @@ -1391,8 +1394,8 @@ TEST_F(SenderTest, sendMessage_multipacket) { char payload0[1027]; char payload1[1027]; - Homa::Mock::MockDriver::MockPacket packet0 {payload0}; - Homa::Mock::MockDriver::MockPacket packet1 {payload1}; + Homa::Mock::MockDriver::MockPacket packet0{payload0}; + Homa::Mock::MockDriver::MockPacket packet1{payload1}; Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = @@ -1663,7 +1666,7 @@ TEST_F(SenderTest, trySend_basic) const uint32_t PACKET_DATA_SIZE = PACKET_SIZE - message->TRANSPORT_HEADER_LENGTH; for (int i = 0; i < 5; ++i) { - packet[i] = new Homa::Mock::MockDriver::MockPacket {payload}; + packet[i] = new Homa::Mock::MockDriver::MockPacket{payload}; packet[i]->length = PACKET_SIZE; setMessagePacket(message, i, packet[i]); info->unsentBytes += PACKET_DATA_SIZE; @@ -1745,7 +1748,7 @@ TEST_F(SenderTest, trySend_multipleMessages) message[i] = dynamic_cast(sender->allocMessage(0)); info[i] = &message[i]->queuedMessageInfo; SenderTest::addMessage(sender, id, message[i], true, 1); - packet[i] = new Homa::Mock::MockDriver::MockPacket {payload}; + packet[i] = new Homa::Mock::MockDriver::MockPacket{payload}; packet[i]->length = sender->driver->getMaxPayloadSize() / 4; setMessagePacket(message[i], 0, packet[i]); info[i]->unsentBytes += diff --git a/src/TransportImplTest.cc b/src/TransportImplTest.cc index c69a36a..a0f66c6 100644 --- a/src/TransportImplTest.cc +++ b/src/TransportImplTest.cc @@ -102,56 +102,56 @@ TEST_F(TransportImplTest, processPackets) Homa::Driver::Packet* packets[8]; // Set DATA packet - Homa::Mock::MockDriver::MockPacket dataPacket {payload[0], 1024}; + Homa::Mock::MockDriver::MockPacket dataPacket{payload[0], 1024}; static_cast(dataPacket.payload) ->common.opcode = Protocol::Packet::DATA; packets[0] = &dataPacket; EXPECT_CALL(*mockReceiver, handleDataPacket(Eq(&dataPacket), _)); // Set GRANT packet - Homa::Mock::MockDriver::MockPacket grantPacket {payload[1], 1024}; + Homa::Mock::MockDriver::MockPacket grantPacket{payload[1], 1024}; static_cast(grantPacket.payload) ->common.opcode = Protocol::Packet::GRANT; packets[1] = &grantPacket; EXPECT_CALL(*mockSender, handleGrantPacket(Eq(&grantPacket))); // Set DONE packet - Homa::Mock::MockDriver::MockPacket donePacket {payload[2], 1024}; + Homa::Mock::MockDriver::MockPacket donePacket{payload[2], 1024}; static_cast(donePacket.payload) ->common.opcode = Protocol::Packet::DONE; packets[2] = &donePacket; EXPECT_CALL(*mockSender, handleDonePacket(Eq(&donePacket))); // Set RESEND packet - Homa::Mock::MockDriver::MockPacket resendPacket {payload[3], 1024}; + Homa::Mock::MockDriver::MockPacket resendPacket{payload[3], 1024}; static_cast(resendPacket.payload) ->common.opcode = Protocol::Packet::RESEND; packets[3] = &resendPacket; EXPECT_CALL(*mockSender, handleResendPacket(Eq(&resendPacket))); // Set BUSY packet - Homa::Mock::MockDriver::MockPacket busyPacket {payload[4], 1024}; + Homa::Mock::MockDriver::MockPacket busyPacket{payload[4], 1024}; static_cast(busyPacket.payload) ->common.opcode = Protocol::Packet::BUSY; packets[4] = &busyPacket; EXPECT_CALL(*mockReceiver, handleBusyPacket(Eq(&busyPacket))); // Set PING packet - Homa::Mock::MockDriver::MockPacket pingPacket {payload[5], 1024}; + Homa::Mock::MockDriver::MockPacket pingPacket{payload[5], 1024}; static_cast(pingPacket.payload) ->common.opcode = Protocol::Packet::PING; packets[5] = &pingPacket; EXPECT_CALL(*mockReceiver, handlePingPacket(Eq(&pingPacket), _)); // Set UNKNOWN packet - Homa::Mock::MockDriver::MockPacket unknownPacket {payload[6], 1024}; + Homa::Mock::MockDriver::MockPacket unknownPacket{payload[6], 1024}; static_cast(unknownPacket.payload) ->common.opcode = Protocol::Packet::UNKNOWN; packets[6] = &unknownPacket; EXPECT_CALL(*mockSender, handleUnknownPacket(Eq(&unknownPacket))); // Set ERROR packet - Homa::Mock::MockDriver::MockPacket errorPacket {payload[7], 1024}; + Homa::Mock::MockDriver::MockPacket errorPacket{payload[7], 1024}; static_cast(errorPacket.payload) ->common.opcode = Protocol::Packet::ERROR; packets[7] = &errorPacket; diff --git a/test/system_test.cc b/test/system_test.cc index 266d842..88b3814 100644 --- a/test/system_test.cc +++ b/test/system_test.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above From a5fc9252d543d7f2a0b2f52c1fa916df6cac644a Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Sun, 9 Aug 2020 02:37:59 -0700 Subject: [PATCH 06/15] More changes to enable Shenango integration. The biggest challenge is to move away from the poll-based execution model previously embedded in the implementation. For example, Shenango can't afford to drive the execution of the transport by calling Transport::poll in a busy loop. Also, Shenango needs to allow users to block on a socket waiting for incoming messages. Another refactoring that results in small code changes all over the place is the Driver::Packet interface. The old interface cannot prevent the driver from modifying the payload (e.g., prepend L3 headers). This will lead to corrupted message when the packet needs to be retransmitted. List of major changes: - CHoma: provide C-bindings of the Homa APIs (Shenango is written in C) - Shenango: implement Shenango Driver, MailboxDir, and Mailbox - TransportPoller: extract poll-based execution model out of the Transport - SimpleMailboxDir: a simple reference implementation for Homa::MailboxDir - Driver::Packet: a new packet interface to eliminate the awkward PacketSpec; this is used to provide an immutable view of the packet to the transport (driver can prepend headers to the payload w/o affecting the transport) - Sender: add a couple of user-defined callbacks (Shenango currently relies on them to wake up blocking threads) - Finally, bring unit tests up-to-date. --- CMakeLists.txt | 4 + include/Homa/Bindings/CHoma.h | 238 +++++++++++++++ include/Homa/Driver.h | 46 ++- include/Homa/Drivers/DPDK/DpdkDriver.h | 6 +- include/Homa/Drivers/Fake/FakeDriver.h | 32 +- include/Homa/Homa.h | 386 ++++++++++++++++++++++--- include/Homa/OutMessageStatus.h | 32 ++ include/Homa/Shenango.h | 77 +++++ include/Homa/Util.h | 8 - include/Homa/Utils/SimpleMailboxDir.h | 61 ++++ include/Homa/Utils/TransportPoller.h | 55 ++++ src/CHoma.cc | 228 +++++++++++++++ src/ControlPacket.h | 10 +- src/Drivers/DPDK/DpdkDriver.cc | 6 +- src/Drivers/DPDK/DpdkDriverImpl.cc | 96 +++--- src/Drivers/DPDK/DpdkDriverImpl.h | 27 +- src/Drivers/Fake/FakeDriver.cc | 17 +- src/Drivers/Fake/FakeDriverTest.cc | 15 +- src/Homa.cc | 8 +- src/Mock/MockDriver.h | 27 +- src/Mock/MockReceiver.h | 6 +- src/Mock/MockSender.h | 2 +- src/Receiver.cc | 167 ++++++----- src/Receiver.h | 55 ++-- src/ReceiverTest.cc | 203 +++++++------ src/Sender.cc | 208 ++++++++----- src/Sender.h | 43 ++- src/SenderTest.cc | 383 ++++++++++++------------ src/Shenango.cc | 272 +++++++++++++++++ src/SimpleMailboxDir.cc | 171 +++++++++++ src/TransportImpl.cc | 186 ++++++++---- src/TransportImpl.h | 95 ++++-- src/TransportImplTest.cc | 132 +++++---- src/TransportPoller.cc | 85 ++++++ test/dpdk_test.cc | 24 +- test/system_test.cc | 36 ++- 36 files changed, 2605 insertions(+), 842 deletions(-) create mode 100644 include/Homa/Bindings/CHoma.h create mode 100644 include/Homa/OutMessageStatus.h create mode 100644 include/Homa/Shenango.h create mode 100644 include/Homa/Utils/SimpleMailboxDir.h create mode 100644 include/Homa/Utils/TransportPoller.h create mode 100644 src/CHoma.cc create mode 100644 src/Shenango.cc create mode 100644 src/SimpleMailboxDir.cc create mode 100644 src/TransportPoller.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index f5cb6ef..4174fc1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,6 +73,7 @@ endif() ## lib Homa #################################################################### add_library(Homa src/CodeLocation.cc + src/CHoma.cc src/Debug.cc src/Driver.cc src/Homa.cc @@ -80,9 +81,12 @@ add_library(Homa src/Policy.cc src/Receiver.cc src/Sender.cc + src/Shenango.cc + src/SimpleMailboxDir.cc src/StringUtil.cc src/ThreadId.cc src/TransportImpl.cc + src/TransportPoller.cc src/Util.cc ) add_library(Homa::Homa ALIAS Homa) diff --git a/include/Homa/Bindings/CHoma.h b/include/Homa/Bindings/CHoma.h new file mode 100644 index 0000000..253afea --- /dev/null +++ b/include/Homa/Bindings/CHoma.h @@ -0,0 +1,238 @@ +/* Copyright (c) 2020 Stanford University + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR(S) DISCLAIM ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL AUTHORS BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +/** + * @file CHoma.h + * + * Contains C-bindings for the Homa Transport API. + */ + +#pragma once + +#include "Homa/OutMessageStatus.h" + +#ifdef __cplusplus +#include +#include +extern "C" { +#else +#include +#include +#endif + +/** + * Define handle types for various Homa objects. + * + * A handle type is essentially a thin wrapper around an opaque pointer. + * Compared to generic pointers, using handle types in the C API enables + * some type safety. + */ +#define DEFINE_HOMA_OBJ_HANDLE(x) \ + typedef struct { \ + void* p; \ + } homa_##x; + +DEFINE_HOMA_OBJ_HANDLE(driver) /* Homa::Driver */ +DEFINE_HOMA_OBJ_HANDLE(inmsg) /* Homa::InMessage */ +DEFINE_HOMA_OBJ_HANDLE(outmsg) /* Homa::OutMessage */ +DEFINE_HOMA_OBJ_HANDLE(mailbox) /* Homa::Mailbox */ +DEFINE_HOMA_OBJ_HANDLE(mailbox_dir) /* Homa::MailboxDir */ +DEFINE_HOMA_OBJ_HANDLE(sk) /* Homa::Socket */ +DEFINE_HOMA_OBJ_HANDLE(trans) /* Homa::Transport */ + +/* ============================ */ +/* Homa::InMessage API */ +/* ============================ */ + +/** + * homa_inmsg_ack - C-binding for Homa::InMessage::acknowledge + */ +extern void homa_inmsg_ack(homa_inmsg in_msg); + +/** + * homa_inmsg_dropped - C-binding for Homa::InMessage::dropped + */ +extern bool homa_inmsg_dropped(homa_inmsg in_msg); + +/** + * homa_inmsg_fail - C-binding for Homa::InMessage::fail + */ +extern void homa_inmsg_fail(homa_inmsg in_msg); + +/** + * homa_inmsg_get - C-binding for Homa::InMessage::get + */ +extern size_t homa_inmsg_get(homa_inmsg in_msg, size_t ofs, void* dst, + size_t len); + +/** + * homa_inmsg_src_addr - C-binding for Homa::InMessage::getSourceAddress + */ +extern void homa_inmsg_src_addr(homa_inmsg in_msg, uint32_t* ip, + uint16_t* port); + +/** + * homa_inmsg_len - C-binding for Homa::InMessage::length + */ +extern size_t homa_inmsg_len(homa_inmsg in_msg); + +/** + * homa_inmsg_release - C-binding for Homa::InMessage::release + */ +extern void homa_inmsg_release(homa_inmsg in_msg); + +/** + * homa_inmsg_strip - C-binding for Homa::InMessage::strip + */ +extern void homa_inmsg_strip(homa_inmsg in_msg, size_t n); + +/* ============================ */ +/* Homa::OutMessage API */ +/* ============================ */ + +/** + * homa_outmsg_append - C-binding for Homa::OutMessage::append + */ +extern void homa_outmsg_append(homa_outmsg out_msg, const void* buf, + size_t len); + +/** + * homa_outmsg_cancel - C-binding for Homa::OutMessage::cancel + */ +extern void homa_outmsg_cancel(homa_outmsg out_msg); + +/** + * homa_outmsg_status - C-binding for Homa::OutMessage::getStatus + */ +extern int homa_outmsg_status(homa_outmsg out_msg); + +/** + * homa_outmsg_prepend - C-binding for Homa::OutMessage::prepend + */ +extern void homa_outmsg_prepend(homa_outmsg out_msg, const void* buf, + size_t len); + +/** + * homa_outmsg_reserve - C-binding for Homa::OutMessage::reserve + */ +extern void homa_outmsg_reserve(homa_outmsg out_msg, size_t n); + +/** + * homa_outmsg_send - C-binding for Homa::OutMessage::send + */ +extern void homa_outmsg_send(homa_outmsg out_msg, uint32_t ip, uint16_t port); + +/** + * homa_outmsg_register_cb - C-binding for + * Homa::OutMessage::registerCallbackEndState + */ +extern void homa_outmsg_register_cb_end_state(homa_outmsg out_msg, + void (*cb)(void*), void* data); + +/** + * homa_outmsg_release - C-binding for Homa::OutMessage::release + */ +extern void homa_outmsg_release(homa_outmsg out_msg); + +/* ============================ */ +/* Homa::Socket API */ +/* ============================ */ + +/** + * homa_sk_alloc - C-binding for Homa::Socket::alloc + */ +extern homa_outmsg homa_sk_alloc(homa_sk sk); + +/** + * homa_sk_receive - C-binding for Homa::Socket::receive + */ +extern homa_inmsg homa_sk_receive(homa_sk sk, bool blocking); + +/** + * homa_sk_shutdown - C-binding for Homa::Socket::shutdown + */ +extern void homa_sk_shutdown(homa_sk sk); + +/** + * homa_sk_is_shutdown - C-binding for Homa::Socket::isShutdown + */ +extern bool homa_sk_is_shutdown(homa_sk sk); + +/** + * homa_sk_local_addr - C-binding for Homa::Socket::getLocalAddress + */ +extern void homa_sk_local_addr(homa_sk sk, uint32_t* ip, uint16_t* port); + +/** + * homa_sk_close - C-binding for Homa::Socket::close + */ +extern void homa_sk_close(homa_sk sk); + +/* ============================ */ +/* Homa::Transport API */ +/* ============================ */ + +/** + * homa_trans_create - C-binding for Homa::Transport::create + */ +extern homa_trans homa_trans_create(homa_driver drv, homa_mailbox_dir dir, + uint64_t id); + +/** + * homa_trans_free - C-binding for Homa::Transport::free + */ +extern void homa_trans_free(homa_trans trans); + +/** + * homa_trans_open - C-binding for Homa::Transport::open + */ +extern homa_sk homa_trans_open(homa_trans trans, uint16_t port); + +/** + * homa_trans_check_timeouts - C-binding for Homa::Transport::checkTimeouts + */ +extern uint64_t homa_trans_check_timeouts(homa_trans trans); + +/** + * homa_trans_id - C-binding for Homa::Transport::getId + */ +extern uint64_t homa_trans_id(homa_trans trans); + +/** + * homa_trans_proc - C-binding for Homa::Transport::processPacket + */ +extern void homa_trans_proc(homa_trans trans, uintptr_t desc, void* payload, + int32_t len, uint32_t src_ip); + +/** + * homa_trans_try_send - C-binding for + * Homa::Transport::registerCallbackSendReady + */ +extern void homa_trans_register_cb_send_ready(homa_trans trans, + void (*cb)(void*), void* data); + +/** + * homa_trans_try_send - C-binding for Homa::Transport::trySend + */ +extern bool homa_trans_try_send(homa_trans trans, uint64_t* wait_until); + +/** + * homa_trans_try_grant - C-binding for Homa::Transport::trySendGrants + */ +extern bool homa_trans_try_grant(homa_trans trans); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/include/Homa/Driver.h b/include/Homa/Driver.h index 67df785..602f78e 100644 --- a/include/Homa/Driver.h +++ b/include/Homa/Driver.h @@ -19,6 +19,7 @@ #include #include "Homa/Exception.h" +#include "Homa/OutMessageStatus.h" namespace Homa { @@ -66,26 +67,6 @@ struct IpAddress final { }; static_assert(std::is_trivially_copyable()); -/** - * Represents a packet of data that can be send or is received over the network. - * A Packet logically contains only the transport-layer (L4) Homa header in - * addition to application data. - * - * This struct specifies the minimal object layout of a packet that the core - * Homa protocol depends on (e.g., Homa::Core::{Sender, Receiver}); this is - * useful for applications that only want to use the transport layer of this - * library and have their own infrastructures for sending and receiving packets. - */ -struct PacketSpec { - /// Pointer to an array of bytes containing the payload of this Packet. - /// This array is valid until the Packet is released back to the Driver. - void* payload; - - /// Number of bytes in the payload. - int32_t length; -} __attribute__((packed)); -static_assert(std::is_trivial()); - /** * Used by Homa::Transport to send and receive unreliable datagrams. Provides * the interface to which all Driver implementations must conform. @@ -94,8 +75,23 @@ static_assert(std::is_trivial()); */ class Driver { public: - /// Import PacketSpec into the Driver namespace. - using Packet = PacketSpec; + /** + * Describes a packet of data that can be send or is received over the + * network. A Packet logically contains only the transport-layer (L4) Homa + * header in addition to application data. + */ + struct Packet { + /// Unique identifier of this Packet within the Driver. This descriptor + /// is entirely opaque to the transport. + uintptr_t descriptor; + + /// Pointer to an array of bytes containing the payload of this Packet. + /// This array is valid until the Packet is released back to the Driver. + void* payload; + + /// Number of bytes in the payload. + int32_t length; + }; /** * Driver destructor. @@ -107,7 +103,7 @@ class Driver { * caller must eventually release the packet by passing it to a call to * releasePacket(). */ - virtual Packet* allocPacket() = 0; + virtual Packet allocPacket() = 0; /** * Send a packet over the network. @@ -183,7 +179,7 @@ class Driver { * @sa Driver::releasePackets() */ virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[], + Packet receivedPackets[], IpAddress sourceAddresses[]) = 0; /** @@ -201,7 +197,7 @@ class Driver { * @param numPackets * Number of Packet objects in _packets_. */ - virtual void releasePackets(Packet* packets[], uint16_t numPackets) = 0; + virtual void releasePackets(Packet packets[], uint16_t numPackets) = 0; /** * Returns the highest packet priority level this Driver supports (0 is diff --git a/include/Homa/Drivers/DPDK/DpdkDriver.h b/include/Homa/Drivers/DPDK/DpdkDriver.h index f15d575..fbb3e2c 100644 --- a/include/Homa/Drivers/DPDK/DpdkDriver.h +++ b/include/Homa/Drivers/DPDK/DpdkDriver.h @@ -119,7 +119,7 @@ class DpdkDriver : public Driver { virtual ~DpdkDriver(); /// See Driver::allocPacket() - virtual Packet* allocPacket(); + virtual Packet allocPacket(); /// See Driver::sendPacket() virtual void sendPacket(Packet* packet, IpAddress destination, @@ -133,11 +133,11 @@ class DpdkDriver : public Driver { /// See Driver::receivePackets() virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[], + Packet receivedPackets[], IpAddress sourceAddresses[]); /// See Driver::releasePackets() - virtual void releasePackets(Packet* packets[], uint16_t numPackets); + virtual void releasePackets(Packet packets[], uint16_t numPackets); /// See Driver::getHighestPacketPriority() virtual int getHighestPacketPriority(); diff --git a/include/Homa/Drivers/Fake/FakeDriver.h b/include/Homa/Drivers/Fake/FakeDriver.h index 5f54586..80e06bb 100644 --- a/include/Homa/Drivers/Fake/FakeDriver.h +++ b/include/Homa/Drivers/Fake/FakeDriver.h @@ -52,12 +52,12 @@ void setPacketLossRate(double lossRate); * @sa Driver::Packet */ struct FakePacket { - /// C-style "inheritance"; used to maintain the base struct as a POD type. - Driver::Packet base; - /// Raw storage for this packets payload. char buf[MAX_PAYLOAD_SIZE]; + /// Number of bytes in the payload. + int length; + /// Source IpAddress of the packet. IpAddress sourceIp; @@ -65,8 +65,8 @@ struct FakePacket { * FakePacket constructor. */ explicit FakePacket() - : base{.payload = buf, .length = 0} - , buf() + : buf() + , length() , sourceIp() {} @@ -74,11 +74,21 @@ struct FakePacket { * Copy constructor. */ FakePacket(const FakePacket& other) - : base{.payload = buf, .length = other.base.length} - , buf() + : buf() + , length(other.length) , sourceIp() { - memcpy(base.payload, other.base.payload, MAX_PAYLOAD_SIZE); + memcpy(buf, other.buf, MAX_PAYLOAD_SIZE); + } + + /** + * Convert this FakePacket to a generic Driver packet representation. + */ + Driver::Packet toPacket() + { + Driver::Packet packet = { + .descriptor = (uintptr_t)this, .payload = buf, .length = length}; + return packet; } }; @@ -109,13 +119,13 @@ class FakeDriver : public Driver { */ virtual ~FakeDriver(); - virtual Packet* allocPacket(); + virtual Packet allocPacket(); virtual void sendPacket(Packet* packet, IpAddress destination, int priority); virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[], + Packet receivedPackets[], IpAddress sourceAddresses[]); - virtual void releasePackets(Packet* packets[], uint16_t numPackets); + virtual void releasePackets(Packet packets[], uint16_t numPackets); virtual int getHighestPacketPriority(); virtual uint32_t getMaxPayloadSize(); virtual uint32_t getBandwidth(); diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index aba9073..9311559 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -24,19 +24,28 @@ #define HOMA_INCLUDE_HOMA_HOMA_H #include - -#include -#include -#include +#include namespace Homa { /** * Shorthand for an std::unique_ptr with a customized deleter. + * + * This is used to implement the "borrow" semantics for interface classes like + * InMessage, OutMessage, and Socket; that is, a user can obtain pointers to + * these objects and use them to make function calls, but the user must always + * return the objects back to the transport library eventually because the user + * has no idea how to destruct the objects or reclaim memory. */ template using unique_ptr = std::unique_ptr; +/** + * Shorthand for user-defined callback functions which are used by the transport + * library to notify users of certain events. + */ +using Callback = std::function; + /** * Represents a socket address to (from) which we can send (receive) messages. */ @@ -103,6 +112,11 @@ class InMessage { virtual size_t get(size_t offset, void* destination, size_t count) const = 0; + /** + * Return the remote address from which this Message is sent. + */ + virtual SocketAddress getSourceAddress() const = 0; + /** * Return the number of bytes this Message contains. */ @@ -117,6 +131,12 @@ class InMessage { virtual void strip(size_t count) = 0; protected: + /** + * Use protected destructor to prevent users from calling delete on pointers + * to this interface. + */ + ~InMessage() = default; + /** * Signal that this message is no longer needed. The caller should not * access this message following this call. @@ -134,14 +154,7 @@ class OutMessage { /** * Defines the possible states of an OutMessage. */ - enum class Status { - NOT_STARTED, //< The sending of this message has not started. - IN_PROGRESS, //< The message is in the process of being sent. - CANCELED, //< The message was canceled while still IN_PROGRESS. - SENT, //< The message has been completely sent. - COMPLETED, //< The message has been received and processed. - FAILED, //< The message failed to be delivered and processed. - }; + using Status = OutMessageStatus; /** * Options with which an OutMessage can be sent. @@ -210,6 +223,15 @@ class OutMessage { */ virtual void prepend(const void* source, size_t count) = 0; + /** + * Register a callback function to be invoked when the status of this + * message reaches an end state. + * + * @param func + * The function object to invoke. + */ + virtual void registerCallbackEndState(Callback func) = 0; + /** * Reserve a number of bytes at the beginning of the Message. * @@ -239,6 +261,12 @@ class OutMessage { Options options = Options::NONE) = 0; protected: + /** + * Use protected destructor to prevent users from calling delete on pointers + * to this interface. + */ + ~OutMessage() = default; + /** * Signal that this message is no longer needed. The caller should not * access this message following this call. @@ -246,69 +274,355 @@ class OutMessage { virtual void release() = 0; }; +/** + * Represents a location which can hold incoming messages temporarily before + * they are consumed by high-level applications. + * + * Despite a one-to-one relationship between Mailbox and Socket, this class + * is decoupled from Socket for three reasons: + *
    + *
  • Abstract out the interaction with the user's thread scheduler: e.g., + * a user system may want to block on receive until a message is delivered; + *
  • Abstract out the mechanism for high-performance message dispatch: e.g., + * a user system may choose to implement the message receive queue with a + * concurrent MPMC queue as opposed to a linked-list protected by a mutex; + *
  • Abstract out the mechanism for safe memory reclamation of the receive + * queue: e.g., RCU is a well-known solution, reference counting is another. + *
+ * + * Note: methods in this class are NOT meant to be called by user applications + * directly; instead, they are defined by user applications and called by the + * Homa transport library. + * + * This class is thread-safe. + * + * @sa MailboxDir + */ +class Mailbox { + public: + /** + * Destructor. + */ + virtual ~Mailbox() = default; + + /** + * Signal that the caller will not access the mailbox after this call. + * A mailbox will only be destroyed if it's removed from the directory + * and closed by all openers. + * + * Not meant to be called by users. + * + * @sa MailboxDir::open() + */ + virtual void close() = 0; + + /** + * Used by a transport to deliver an ingress message to this mailbox. + * + * Not meant to be called by users. + * + * @param message + * An ingress message just completed by the transport. + */ + virtual void deliver(InMessage* message) = 0; + + /** + * Retrieve a message currently stored in the mailbox. + * + * Not meant to be called by users; use Socket::receive() instead. + * + * @param blocking + * When set to true, this method should not return until a message + * arrives or the corresponding socket is shut down. + * @return + * A message previously delivered to this mailbox, if the mailbox is + * not empty; nullptr, otherwise. + * + * @sa Socket::receive() + */ + virtual InMessage* retrieve(bool blocking) = 0; + + /** + * Invoked when the corresponding socket of the mailbox is shut down. + * All pending retrieve() requests must return immediately. + */ + virtual void socketShutdown() = 0; +}; + +/** + * Provides a means to keep track of the mailboxes that are currently in use + * by Homa sockets. + * + * This class is separated out from Transport to allow users to 1) use their + * own data structures to store the map from port numbers to mailboxes, and + * 2) apply their own mechanisms to perform synchronization (e.g., hash map + * with fine-grained locks, RCU to delay mailbox destruction, etc). + * + * Similar to Mailbox, methods in this class are NOT meant to be called by + * user applications. + * + * This class is thread-safe. + */ +class MailboxDir { + public: + /** + * Destructor. + */ + virtual ~MailboxDir() = default; + + /** + * Allocate a new mailbox in the directory. + * + * @param port + * Port number which identifies the mailbox. + * @return + * Pointer to the new Mailbox on success; nullptr, if the port number + * is already in use. + */ + virtual Mailbox* alloc(uint16_t port) = 0; + + /** + * Find and open the mailbox that matches the given port number. Once a + * mailbox is opened, it's guaranteed to remain usable even if someone else + * removes it from the directory. + * + * @param port + * Port number which identifies the mailbox. + * @return + * Pointer to the opened mailbox on success; nullptr, if the desired + * mailbox doesn't exist. + */ + virtual Mailbox* open(uint16_t port) = 0; + + /** + * Remove the mailbox that matches the given port number. + * + * @param port + * Port number of the mailbox that will be removed. + * @return + * True on success; false, if the desired mailbox doesn't exist. + */ + virtual bool remove(uint16_t port) = 0; +}; + +/** + * Connection-less socket that can be used to send and receive Homa messages. + * + * This class is thread-safe. + */ +class Socket { + public: + using Address = SocketAddress; + + /** + * Custom deleter for use with Homa::unique_ptr. + */ + struct Deleter { + void operator()(Socket* socket) + { + socket->close(); + } + }; + + /** + * Allocate Message that can be sent with this Socket. + * + * @param sourcePort + * Port number of the socket from which the message will be sent. + * @return + * A pointer to the allocated message or nullptr if the socket has + * been shut down. + */ + virtual Homa::unique_ptr alloc() = 0; + + /** + * Check for and return a Message sent to this Socket if available. + * + * @param blocking + * When set to true, this method should not return until a message + * arrives or the socket is shut down. + * @return + * Pointer to the received message, if any. Otherwise, nullptr is + * returned if no message has been delivered or the socket has been + * shut down. + */ + virtual Homa::unique_ptr receive(bool blocking) = 0; + + /** + * Disable the socket. Once a socket is shut down, all ongoing/subsequent + * requests on the socket will return a failure. + * + * When multiple threads are working on a socket, this method can be used + * to notify other threads to drop their references to this socket so that + * the caller can safely close() the socket. + */ + virtual void shutdown() = 0; + + /** + * Check if the Socket has been shut down. + */ + virtual bool isShutdown() const = 0; + + /** + * Return the local IP address and port number of this Socket. + */ + virtual Socket::Address getLocalAddress() const = 0; + + protected: + /** + * Use protected destructor to prevent users from calling delete on pointers + * to this interface. + */ + ~Socket() = default; + + /** + * Signal that this Socket is no longer needed. No one should access this + * socket after this call. + * + * Note: outgoing messages already allocated from this socket will not be + * affected. + */ + virtual void close() = 0; +}; + /** * Provides a means of communicating across the network using the Homa protocol. * - * The transport is used to send and receive messages across the network using - * the RemoteOp and ServerOp abstractions. The execution of the transport is - * driven through repeated calls to the Transport::poll() method; the transport - * will not make any progress otherwise. + * The execution of the transport is driven through repeated calls to methods + * like checkTimeouts(), processPacket(), trySend(), and trySendGrants(); the + * transport will not make any progress otherwise. * * This class is thread-safe. */ class Transport { public: + /** + * Custom deleter for use with std::unique_ptr. + */ + struct Deleter { + void operator()(Transport* transport) + { + transport->free(); + } + }; + /** * Return a new instance of a Homa-based transport. * - * The caller is responsible for calling free() on the returned pointer. - * * @param driver * Driver with which this transport should send and receive packets. + * @param mailboxDir + * Mailbox directory with which this transport should decide where + * to deliver a message. * @param transportId * This transport's unique identifier in the group of transports among * which this transport will communicate. * @return * Pointer to the new transport instance. */ - static Transport* create(Driver* driver, uint64_t transportId); + static Homa::unique_ptr create(Driver* driver, + MailboxDir* mailboxDir, + uint64_t transportId); /** - * Allocate Message that can be sent with this Transport. + * Create a socket that can be used to send and receive messages. * - * @param sourcePort - * Port number of the socket from which the message will be sent. + * @param port + * The port number allocated to the socket. * @return - * A pointer to the allocated message. + * Pointer to the new socket, if the port number is not in use; + * nullptr, otherwise. */ - virtual Homa::unique_ptr alloc(uint16_t sourcePort) = 0; + virtual Homa::unique_ptr open(uint16_t port) = 0; /** - * Check for and return a Message sent to this Transport if available. + * Return the driver that this transport uses to send and receive packets. + */ + virtual Driver* getDriver() = 0; + + /** + * Return this transport's unique identifier. + */ + virtual uint64_t getId() = 0; + + /** + * Process any timeouts that have expired. + * + * This method must be called periodically to ensure timely handling of + * expired timeouts. * * @return - * Pointer to the received message, if any. Otherwise, nullptr is - * returned if no message has been delivered. + * The rdtsc cycle time when this method should be called again. */ - virtual Homa::unique_ptr receive() = 0; + virtual uint64_t checkTimeouts() = 0; /** - * Make incremental progress performing all Transport functionality. + * Handle an ingress packet by running it through the transport protocol + * stack. * - * This method MUST be called for the Transport to make progress and should - * be called frequently to ensure timely progress. + * @param packet + * The ingress packet. + * @param source + * IpAddress of the socket from which the packet is sent. */ - virtual void poll() = 0; + virtual void processPacket(Driver::Packet* packet, IpAddress source) = 0; /** - * Return the driver that this transport uses to send and receive packets. + * Register a callback function to be invoked when some packets just became + * ready to be sent (and there was none before). + * + * This callback allows the transport library to notify the users that + * trySend() should be invoked again as soon as possible. For example, + * the callback can be used to implement wakeup signals for the thread + * that is responsible for calling trySend(), if this thread decides to + * sleep when there is no packets to send. + * + * @param func + * The function object to invoke. */ - virtual Driver* getDriver() = 0; + virtual void registerCallbackSendReady(Callback func) = 0; /** - * Return this transport's unique identifier. + * Attempt to send out packets for any messages with unscheduled/granted + * bytes in a way that limits queue buildup in the NIC. + * + * This method must be called eagerly to allow the Transport to make + * progress toward sending outgoing messages. + * + * @param[out] waitUntil + * The rdtsc cycle time when this method should be called again + * (this allows the NIC to drain its transmit queue). Only set + * when this method returns true. + * @return + * True if more packets are ready to be transmitted when the method + * returns; false, otherwise. */ - virtual uint64_t getId() = 0; + virtual bool trySend(uint64_t* waitUntil) = 0; + + /** + * Attempt to grant to incoming messages according to the Homa protocol. + * + * This method must be called eagerly to allow the Transport to make + * progress toward receiving incoming messages. + * + * @return + * True if the method has found some messages to grant; false, + * otherwise. + */ + virtual bool trySendGrants() = 0; + + protected: + /** + * Use protected destructor to prevent users from calling delete on pointers + * to this interface. + */ + ~Transport() = default; + + /** + * Free this transport instance. No one should not access this transport + * following this call. + */ + virtual void free() = 0; }; /** diff --git a/include/Homa/OutMessageStatus.h b/include/Homa/OutMessageStatus.h new file mode 100644 index 0000000..9957e7a --- /dev/null +++ b/include/Homa/OutMessageStatus.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2020 Stanford University + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR(S) DISCLAIM ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL AUTHORS BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#pragma once + +/** + * Defines the possible states of an OutMessage. + */ +#ifdef __cplusplus +enum class OutMessageStatus : int { +#else +enum homa_outmsg_status { +#endif + NOT_STARTED, //< The sending of this message has not started. + IN_PROGRESS, //< The message is in the process of being sent. + CANCELED, //< The message was canceled while still IN_PROGRESS. + SENT, //< The message has been completely sent. + COMPLETED, //< The message has been received and processed. + FAILED, //< The message failed to be delivered and processed. +}; diff --git a/include/Homa/Shenango.h b/include/Homa/Shenango.h new file mode 100644 index 0000000..0d5bfea --- /dev/null +++ b/include/Homa/Shenango.h @@ -0,0 +1,77 @@ +/* Copyright (c) 2020 Stanford University + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR(S) DISCLAIM ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL AUTHORS BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +/** + * @file Shenango.h + * + * Contains the glue code for Homa-Shenango integration. This is the only + * header Shenango needs to include in order to use Homa transport. + * + * Shenango is an experimental operating system that aims to provide low tail + * latency and high CPU efficiency simultaneously for servers in datacenters. + * See for more information. + * + * This file follows the Shenango coding style. + */ + +#pragma once + +#include "Bindings/CHoma.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * homa_driver_create - creates a shim driver that translates Homa::Driver + * operations to Shenango functions + * @proto: protocol number reserved for Homa transport protocol + * @local_ip: local IP address of the driver + * @max_payload: maximum number of bytes carried by the packet payload + * @link_speed: effective network bandwidth, in Mbits/second + * + * Returns a handle to the driver created. + */ +extern homa_driver homa_driver_create(uint8_t proto, uint32_t local_ip, + uint32_t max_payload, + uint32_t link_speed); + +/** + * homa_driver_free - frees a shim driver created earlier with + * @homa_driver_create. + * @param drv: the driver to free + */ +extern void homa_driver_free(homa_driver drv); + +/** + * homa_mb_dir_create - creates a shim mailbox directory that translates + * Homa::Mailbox operations to Shenango functions + * @proto: protocol number reserved for Homa transport protocol + * @local_ip: local IP address of the driver + * + * Returns a handle to the mailbox created. + */ +extern homa_mailbox_dir homa_mb_dir_create(uint8_t proto, uint32_t local_ip); + +/** + * homa_mb_dir_free - frees a shim mailbox directory created earlier with + * @homa_mb_dir_create. + * @param mailbox_dir: the mailbox directory to free + */ +extern void homa_mb_dir_free(homa_mailbox_dir mailbox_dir); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/include/Homa/Util.h b/include/Homa/Util.h index 4f17acc..0f58cb7 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -21,14 +21,6 @@ #include #include -/// Cast a member of a structure out to the containing structure. -template -P* -container_of(M* ptr, const M P::*member) -{ - return (P*)((char*)ptr - (size_t) & (reinterpret_cast(0)->*member)); -} - namespace Homa { namespace Util { diff --git a/include/Homa/Utils/SimpleMailboxDir.h b/include/Homa/Utils/SimpleMailboxDir.h new file mode 100644 index 0000000..78f3314 --- /dev/null +++ b/include/Homa/Utils/SimpleMailboxDir.h @@ -0,0 +1,61 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +/** + * @file Homa/Utils/SimpleMailboxDir.h + * + * Contains a simple reference implementation for the pluggable mailbox + * directory in the Homa transport library. A mailbox directory is essential + * to get a working transport but it's not central to the Homa protocol. + * + * Users may choose to use this reference implementation for starter, or define + * their own implementation for best performance. + */ + +#pragma once + +#include +#include + +namespace Homa { + +/// Forward declaration +class SpinLock; +class MailboxImpl; + +/** + * A simple reference implementation of Homa::MailboxDir. + * + * This class relies on a monitor-style lock to protect the hash table that + * maps port numbers to mailboxes and uses reference-counting for safe + * reclamation of removed mailboxes. + */ +class SimpleMailboxDir final : public MailboxDir { + public: + explicit SimpleMailboxDir(); + ~SimpleMailboxDir() override; + Mailbox* alloc(uint16_t port) override; + Mailbox* open(uint16_t port) override; + bool remove(uint16_t port) override; + + private: + /// Monitor-style lock. + std::unique_ptr mutex; + + /// Hash table that maps port numbers to mailboxes. + std::unordered_map map; +}; + +} // namespace Homa diff --git a/include/Homa/Utils/TransportPoller.h b/include/Homa/Utils/TransportPoller.h new file mode 100644 index 0000000..097600c --- /dev/null +++ b/include/Homa/Utils/TransportPoller.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#pragma once + +#include + +namespace Homa { + +/// Forward declaration. +class Transport; + +/** + * Provides a means to drive the execution of a transport through repeated + * calls to the poll() method. + * + * This class demonstrates a simple way to invoke the Homa::Transport APIs + * in a poll-based programming style. In practice, users will often need to + * invoke the Transport APIs in ways that fit their systems better. The Homa- + * Shenango integration provides a concrete example. + * + * This class is thread-safe; although calling poll() from multiple threads + * provides no performance benefit. + * + * @sa Homa/Shenango.h + */ +class TransportPoller { + public: + explicit TransportPoller(Transport* transport); + ~TransportPoller() = default; + void poll(); + + private: + void processPackets(); + + /// Transport instance whose execution is driven by this poller. + Transport* const transport; + + /// Caches the next cycle time that timeouts will need to rechecked. + std::atomic nextTimeoutCycles; +}; + +} // namespace Homa \ No newline at end of file diff --git a/src/CHoma.cc b/src/CHoma.cc new file mode 100644 index 0000000..d2a0d54 --- /dev/null +++ b/src/CHoma.cc @@ -0,0 +1,228 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "Homa/Bindings/CHoma.h" +#include "Homa/Homa.h" + +using namespace Homa; + +/// Shorthand for converting C-style Homa object handle types back to C++ types. +#define deref(T, x) (*static_cast(x.p)) + +void +homa_inmsg_ack(homa_inmsg in_msg) +{ + deref(InMessage, in_msg).acknowledge(); +} + +bool +homa_inmsg_dropped(homa_inmsg in_msg) +{ + return deref(InMessage, in_msg).dropped(); +} + +void +homa_inmsg_fail(homa_inmsg in_msg) +{ + deref(InMessage, in_msg).fail(); +} + +size_t +homa_inmsg_get(homa_inmsg in_msg, size_t ofs, void* dst, size_t len) +{ + return deref(InMessage, in_msg).get(ofs, dst, len); +} + +void +homa_inmsg_src_addr(homa_inmsg in_msg, uint32_t* ip, uint16_t* port) +{ + SocketAddress src = deref(InMessage, in_msg).getSourceAddress(); + *ip = (uint32_t)src.ip; + *port = src.port; +} + +size_t +homa_inmsg_len(homa_inmsg in_msg) +{ + return deref(InMessage, in_msg).length(); +} + +void +homa_inmsg_release(homa_inmsg in_msg) +{ + InMessage::Deleter deleter; + deleter(&deref(InMessage, in_msg)); +} + +void +homa_inmsg_strip(homa_inmsg in_msg, size_t n) +{ + deref(InMessage, in_msg).strip(n); +} + +void +homa_outmsg_append(homa_outmsg out_msg, const void* buf, size_t len) +{ + deref(OutMessage, out_msg).append(buf, len); +} + +void +homa_outmsg_cancel(homa_outmsg out_msg) +{ + deref(OutMessage, out_msg).cancel(); +} + +int +homa_outmsg_status(homa_outmsg out_msg) +{ + return int(deref(OutMessage, out_msg).getStatus()); +} + +void +homa_outmsg_prepend(homa_outmsg out_msg, const void* buf, size_t len) +{ + deref(OutMessage, out_msg).prepend(buf, len); +} + +void +homa_outmsg_reserve(homa_outmsg out_msg, size_t n) +{ + deref(OutMessage, out_msg).reserve(n); +} + +void +homa_outmsg_send(homa_outmsg out_msg, uint32_t ip, uint16_t port) +{ + deref(OutMessage, out_msg).send({IpAddress{ip}, port}); +} + +void +homa_outmsg_register_cb_end_state(homa_outmsg out_msg, void (*cb)(void*), + void* data) +{ + std::function func = std::bind(cb, data); + deref(OutMessage, out_msg).registerCallbackEndState(func); +} + +void +homa_outmsg_release(homa_outmsg out_msg) +{ + OutMessage::Deleter deleter; + deleter(&deref(OutMessage, out_msg)); +} + +homa_outmsg +homa_sk_alloc(homa_sk sk) +{ + unique_ptr out_msg = deref(Socket, sk).alloc(); + return homa_outmsg{out_msg.release()}; +} + +homa_inmsg +homa_sk_receive(homa_sk sk, bool blocking) +{ + unique_ptr in_msg = deref(Socket, sk).receive(blocking); + return homa_inmsg{in_msg.release()}; +} + +void +homa_sk_shutdown(homa_sk sk) +{ + deref(Socket, sk).shutdown(); +} + +bool +homa_sk_is_shutdown(homa_sk sk) +{ + return deref(Socket, sk).isShutdown(); +} + +void +homa_sk_local_addr(homa_sk sk, uint32_t* ip, uint16_t* port) +{ + SocketAddress addr = deref(Socket, sk).getLocalAddress(); + *ip = (uint32_t)addr.ip; + *port = addr.port; +} + +void +homa_sk_close(homa_sk sk) +{ + Socket::Deleter deleter; + deleter(&deref(Socket, sk)); +} + +homa_trans +homa_trans_create(homa_driver drv, homa_mailbox_dir dir, uint64_t id) +{ + unique_ptr trans = + Transport::create(&deref(Driver, drv), &deref(MailboxDir, dir), id); + return homa_trans{trans.release()}; +} + +void +homa_trans_free(homa_trans trans) +{ + Transport::Deleter deleter; + deleter(&deref(Transport, trans)); +} + +homa_sk +homa_trans_open(homa_trans trans, uint16_t port) +{ + unique_ptr sk = deref(Transport, trans).open(port); + return homa_sk{sk.release()}; +} + +uint64_t +homa_trans_check_timeouts(homa_trans trans) +{ + return deref(Transport, trans).checkTimeouts(); +} + +uint64_t +homa_trans_id(homa_trans trans) +{ + return deref(Transport, trans).getId(); +} + +void +homa_trans_proc(homa_trans trans, uintptr_t desc, void* payload, int32_t len, + uint32_t src_ip) +{ + Driver::Packet packet = { + .descriptor = desc, .payload = payload, .length = len}; + deref(Transport, trans).processPacket(&packet, IpAddress{src_ip}); +} + +void +homa_trans_register_cb_send_ready(homa_trans trans, void (*cb)(void*), + void* data) +{ + std::function func = std::bind(cb, data); + deref(Transport, trans).registerCallbackSendReady(func); +} + +bool +homa_trans_try_send(homa_trans trans, uint64_t* wait_until) +{ + return deref(Transport, trans).trySend(wait_until); +} + +bool +homa_trans_try_grant(homa_trans trans) +{ + return deref(Transport, trans).trySendGrants(); +} diff --git a/src/ControlPacket.h b/src/ControlPacket.h index 17310af..f8d71c9 100644 --- a/src/ControlPacket.h +++ b/src/ControlPacket.h @@ -39,11 +39,11 @@ template void send(Driver* driver, IpAddress address, Args&&... args) { - Driver::Packet* packet = driver->allocPacket(); - new (packet->payload) PacketHeaderType(static_cast(args)...); - packet->length = sizeof(PacketHeaderType); - Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet, address, driver->getHighestPacketPriority()); + Driver::Packet packet = driver->allocPacket(); + new (packet.payload) PacketHeaderType(static_cast(args)...); + packet.length = sizeof(PacketHeaderType); + Perf::counters.tx_bytes.add(packet.length); + driver->sendPacket(&packet, address, driver->getHighestPacketPriority()); driver->releasePackets(&packet, 1); } diff --git a/src/Drivers/DPDK/DpdkDriver.cc b/src/Drivers/DPDK/DpdkDriver.cc index c27d1df..a6cc48d 100644 --- a/src/Drivers/DPDK/DpdkDriver.cc +++ b/src/Drivers/DPDK/DpdkDriver.cc @@ -38,7 +38,7 @@ DpdkDriver::DpdkDriver(const char* ifname, NoEalInit _, DpdkDriver::~DpdkDriver() = default; /// See Driver::allocPacket() -Driver::Packet* +Driver::Packet DpdkDriver::allocPacket() { return pImpl->allocPacket(); @@ -67,14 +67,14 @@ DpdkDriver::uncork() /// See Driver::receivePackets() uint32_t -DpdkDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[], +DpdkDriver::receivePackets(uint32_t maxPackets, Packet receivedPackets[], IpAddress sourceAddresses[]) { return pImpl->receivePackets(maxPackets, receivedPackets, sourceAddresses); } /// See Driver::releasePackets() void -DpdkDriver::releasePackets(Packet* packets[], uint16_t numPackets) +DpdkDriver::releasePackets(Packet packets[], uint16_t numPackets) { pImpl->releasePackets(packets, numPackets); } diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index e9fef18..18daf1c 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -43,33 +43,43 @@ const int default_eal_argc = 1; const char* default_eal_argv[] = {"homa", NULL}; /** - * Construct a DPDK Packet backed by a DPDK mbuf. + * Construct a DPDK PacketBuf backed by a DPDK mbuf. * * @param mbuf * Pointer to the DPDK mbuf that holds this packet. * @param data * Memory location in the mbuf where the packet data should be stored. */ -DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) - : base{.payload = data, .length = 0} +DpdkDriver::Impl::PacketBuf::PacketBuf(struct rte_mbuf* mbuf, void* data) + : payload(data) , bufType(MBUF) - , bufRef() -{ - bufRef.mbuf = mbuf; -} + , bufRef{.mbuf = mbuf} +{} /** - * Construct a DPDK Packet backed by an OverflowBuffer. + * Construct a DPDK PacketBuf backed by an OverflowBuffer. * * @param overflowBuf * Overflow buffer that holds this packet. */ -DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) - : base{.payload = overflowBuf->data, .length = 0} +DpdkDriver::Impl::PacketBuf::PacketBuf(OverflowBuffer* overflowBuf) + : payload(overflowBuf->data) , bufType(OVERFLOW_BUF) - , bufRef() + , bufRef{.overflowBuf = overflowBuf} +{} + +/** + * Convert this DPDK PacketBuf into the generic Driver::Packet representation. + * + * @param length + * Number of bytes used in the payload buffer. + */ +Driver::Packet +DpdkDriver::Impl::PacketBuf::toPacket(int length) { - bufRef.overflowBuf = overflowBuf; + Driver::Packet packet = { + .descriptor = (uintptr_t)this, .payload = payload, .length = length}; + return packet; } /** @@ -174,17 +184,17 @@ DpdkDriver::Impl::~Impl() } // See Driver::allocPacket() -Driver::Packet* +Driver::Packet DpdkDriver::Impl::allocPacket() { - DpdkDriver::Impl::Packet* packet = _allocMbufPacket(); - if (unlikely(packet == nullptr)) { + PacketBuf* packetBuf = _allocMbufPacket(); + if (unlikely(packetBuf == nullptr)) { SpinLock::Lock lock(packetLock); OverflowBuffer* buf = overflowBufferPool.construct(); - packet = packetPool.construct(buf); + packetBuf = packetPool.construct(buf); NOTICE("OverflowBuffer used."); } - return &packet->base; + return packetBuf->toPacket(0); } // See Driver::sendPacket() @@ -192,15 +202,13 @@ void DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, int priority) { - ; - DpdkDriver::Impl::Packet* pkt = - container_of(packet, DpdkDriver::Impl::Packet, base); - struct rte_mbuf* mbuf = nullptr; + auto* packetBuf = (PacketBuf*)packet->descriptor; // If the packet is held in an Overflow buffer, we need to copy it out // into a new mbuf. - if (unlikely(pkt->bufType == DpdkDriver::Impl::Packet::OVERFLOW_BUF)) { + struct rte_mbuf* mbuf = nullptr; + if (unlikely(packetBuf->bufType == PacketBuf::OVERFLOW_BUF)) { mbuf = rte_pktmbuf_alloc(mbufPool); - if (unlikely(NULL == mbuf)) { + if (unlikely(nullptr == mbuf)) { uint32_t numMbufsAvail = rte_mempool_avail_count(mbufPool); uint32_t numMbufsInUse = rte_mempool_in_use_count(mbufPool); WARNING( @@ -212,16 +220,16 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, } char* buf = rte_pktmbuf_append( mbuf, - Homa::Util::downCast(PACKET_HDR_LEN + pkt->base.length)); + Homa::Util::downCast(PACKET_HDR_LEN + packet->length)); if (unlikely(NULL == buf)) { WARNING("rte_pktmbuf_append call failed; dropping packet"); rte_pktmbuf_free(mbuf); return; } char* data = buf + PACKET_HDR_LEN; - rte_memcpy(data, pkt->base.payload, pkt->base.length); + rte_memcpy(data, packetBuf->payload, packet->length); } else { - mbuf = pkt->bufRef.mbuf; + mbuf = packetBuf->bufRef.mbuf; // If the mbuf is still transmitting from a previous call to send, // we don't want to modify the buffer when the send is occuring. @@ -259,7 +267,7 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, // In the normal case, we pre-allocate a pakcet's mbuf with enough // storage to hold the MAX_PAYLOAD_SIZE. If the actual payload is // smaller, trim the mbuf to size to avoid sending unecessary bits. - uint32_t actualLength = PACKET_HDR_LEN + pkt->base.length; + uint32_t actualLength = PACKET_HDR_LEN + packet->length; uint32_t mbufDataLength = rte_pktmbuf_pkt_len(mbuf); if (actualLength < mbufDataLength) { if (rte_pktmbuf_trim(mbuf, mbufDataLength - actualLength) < 0) { @@ -286,7 +294,7 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, // If the packet is held in an mbuf, retain access to it so that the // processing of sending the mbuf won't free it. - if (likely(pkt->bufType == DpdkDriver::Impl::Packet::MBUF)) { + if (likely(packetBuf->bufType == PacketBuf::MBUF)) { rte_pktmbuf_refcnt_update(mbuf, 1); } @@ -324,7 +332,7 @@ DpdkDriver::Impl::uncork() // See Driver::receivePackets() uint32_t DpdkDriver::Impl::receivePackets(uint32_t maxPackets, - Driver::Packet* receivedPackets[], + Driver::Packet receivedPackets[], IpAddress sourceAddresses[]) { uint32_t numPacketsReceived = 0; @@ -395,14 +403,13 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, uint32_t length = rte_pktmbuf_pkt_len(m) - headerLength; assert(length <= MAX_PAYLOAD_SIZE); - DpdkDriver::Impl::Packet* packet = nullptr; + PacketBuf* packetBuf = nullptr; { SpinLock::Lock lock(packetLock); - packet = packetPool.construct(m, payload); + packetBuf = packetPool.construct(m, payload); } - packet->base.length = length; - receivedPackets[numPacketsReceived] = &packet->base; + receivedPackets[numPacketsReceived] = packetBuf->toPacket(length); sourceAddresses[numPacketsReceived] = {srcIp}; ++numPacketsReceived; } @@ -412,18 +419,17 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, // See Driver::releasePackets() void -DpdkDriver::Impl::releasePackets(Driver::Packet* packets[], uint16_t numPackets) +DpdkDriver::Impl::releasePackets(Driver::Packet packets[], uint16_t numPackets) { for (uint16_t i = 0; i < numPackets; ++i) { SpinLock::Lock lock(packetLock); - DpdkDriver::Impl::Packet* packet = - container_of(packets[i], DpdkDriver::Impl::Packet, base); - if (likely(packet->bufType == DpdkDriver::Impl::Packet::MBUF)) { - rte_pktmbuf_free(packet->bufRef.mbuf); + auto* packetBuf = (PacketBuf*)packets[i].descriptor; + if (likely(packetBuf->bufType == PacketBuf::MBUF)) { + rte_pktmbuf_free(packetBuf->bufRef.mbuf); } else { - overflowBufferPool.destroy(packet->bufRef.overflowBuf); + overflowBufferPool.destroy(packetBuf->bufRef.overflowBuf); } - packetPool.destroy(packet); + packetPool.destroy(packetBuf); } } @@ -715,10 +721,9 @@ DpdkDriver::Impl::_init() * The newly allocated Dpdk Packet; nullptr if the mbuf allocation * failed. */ -DpdkDriver::Impl::Packet* +DpdkDriver::Impl::PacketBuf* DpdkDriver::Impl::_allocMbufPacket() { - DpdkDriver::Impl::Packet* packet = nullptr; uint32_t numMbufsAvail = rte_mempool_avail_count(mbufPool); if (unlikely(numMbufsAvail <= NB_MBUF_RESERVED)) { uint32_t numMbufsInUse = rte_mempool_in_use_count(mbufPool); @@ -752,11 +757,8 @@ DpdkDriver::Impl::_allocMbufPacket() } // Perform packet operations with the lock held. - { - SpinLock::Lock _(packetLock); - packet = packetPool.construct(mbuf, buf + PACKET_HDR_LEN); - } - return packet; + SpinLock::Lock _(packetLock); + return packetPool.construct(mbuf, buf + PACKET_HDR_LEN); } /** diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 4d664fb..819feb9 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -107,15 +107,16 @@ struct OverflowBuffer { class DpdkDriver::Impl { public: /** - * Dpdk specific Packet object used to track a its lifetime and + * DPDK specific Packet object used to track a its lifetime and * contents. */ - struct Packet { - explicit Packet(struct rte_mbuf* mbuf, void* data); - explicit Packet(OverflowBuffer* overflowBuf); + struct PacketBuf { + explicit PacketBuf(struct rte_mbuf* mbuf, void* data); + explicit PacketBuf(OverflowBuffer* overflowBuf); + Driver::Packet toPacket(int length); - /// C-style "inheritance" - Driver::Packet base; + /// Memory location where the packet data should be stored. + void* const payload; /// Used to indicate whether the packet is backed by an DPDK mbuf or a /// driver-level OverflowBuffer. @@ -126,10 +127,6 @@ class DpdkDriver::Impl { struct rte_mbuf* mbuf; OverflowBuffer* overflowBuf; } bufRef; - - /// The memory location of this packet's header. The header should be - /// PACKET_HDR_LEN in length. - void* header; }; Impl(const char* ifname, const Config* const config = nullptr); @@ -139,15 +136,15 @@ class DpdkDriver::Impl { virtual ~Impl(); // Interface Methods - Driver::Packet* allocPacket(); + Driver::Packet allocPacket(); void sendPacket(Driver::Packet* packet, IpAddress destination, int priority); void cork(); void uncork(); uint32_t receivePackets(uint32_t maxPackets, - Driver::Packet* receivedPackets[], + Driver::Packet receivedPackets[], IpAddress sourceAddresses[]); - void releasePackets(Driver::Packet* packets[], uint16_t numPackets); + void releasePackets(Driver::Packet packets[], uint16_t numPackets); int getHighestPacketPriority(); uint32_t getMaxPayloadSize(); uint32_t getBandwidth(); @@ -157,7 +154,7 @@ class DpdkDriver::Impl { private: void _eal_init(int argc, char* argv[]); void _init(); - Packet* _allocMbufPacket(); + PacketBuf* _allocMbufPacket(); static uint16_t txBurstCallback(uint16_t port_id, uint16_t queue, struct rte_mbuf* pkts[], uint16_t nb_pkts, void* user_param); @@ -189,7 +186,7 @@ class DpdkDriver::Impl { /// Provides memory allocation for the DPDK specific implementation of a /// Driver Packet. - ObjectPool packetPool; + ObjectPool packetPool; /// Provides memory allocation for packet storage when mbuf are running out. ObjectPool overflowBufferPool; diff --git a/src/Drivers/Fake/FakeDriver.cc b/src/Drivers/Fake/FakeDriver.cc index 26cb102..16fa12e 100644 --- a/src/Drivers/Fake/FakeDriver.cc +++ b/src/Drivers/Fake/FakeDriver.cc @@ -181,11 +181,11 @@ FakeDriver::~FakeDriver() /** * See Driver::allocPacket() */ -Driver::Packet* +Driver::Packet FakeDriver::allocPacket() { - FakePacket* packet = new FakePacket(); - return &packet->base; + FakePacket* fakePacket = new FakePacket(); + return fakePacket->toPacket(); } /** @@ -194,7 +194,8 @@ FakeDriver::allocPacket() void FakeDriver::sendPacket(Packet* packet, IpAddress destination, int priority) { - FakePacket* srcPacket = container_of(packet, &FakePacket::base); + FakePacket* srcPacket = (FakePacket*)packet->descriptor; + srcPacket->length = packet->length; IpAddress srcAddress = getLocalAddress(); IpAddress dstAddress = destination; fakeNetwork.sendPacket(srcPacket, priority, srcAddress, dstAddress); @@ -205,7 +206,7 @@ FakeDriver::sendPacket(Packet* packet, IpAddress destination, int priority) * See Driver::receivePackets() */ uint32_t -FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[], +FakeDriver::receivePackets(uint32_t maxPackets, Packet receivedPackets[], IpAddress sourceAddresses[]) { std::lock_guard lock_nic(nic.mutex); @@ -214,7 +215,7 @@ FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[], while (numReceived < maxPackets && !nic.priorityQueue.at(i).empty()) { FakePacket* fakePacket = nic.priorityQueue.at(i).front(); nic.priorityQueue.at(i).pop_front(); - receivedPackets[numReceived] = &fakePacket->base; + receivedPackets[numReceived] = fakePacket->toPacket(); sourceAddresses[numReceived] = fakePacket->sourceIp; numReceived++; } @@ -226,10 +227,10 @@ FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[], * See Driver::releasePackets() */ void -FakeDriver::releasePackets(Packet* packets[], uint16_t numPackets) +FakeDriver::releasePackets(Packet packets[], uint16_t numPackets) { for (uint16_t i = 0; i < numPackets; ++i) { - delete container_of(packets[i], &FakePacket::base); + delete (FakePacket*)packets[i].descriptor; } } diff --git a/src/Drivers/Fake/FakeDriverTest.cc b/src/Drivers/Fake/FakeDriverTest.cc index cd64917..b645d68 100644 --- a/src/Drivers/Fake/FakeDriverTest.cc +++ b/src/Drivers/Fake/FakeDriverTest.cc @@ -35,10 +35,9 @@ TEST(FakeDriverTest, constructor) TEST(FakeDriverTest, allocPacket) { - FakeDriver driver; - Driver::Packet* packet = driver.allocPacket(); // allocPacket doesn't do much so we just need to make sure we can call it. - delete container_of(packet, &FakePacket::base); + FakeDriver driver; + Driver::Packet packet = driver.allocPacket(); } TEST(FakeDriverTest, sendPackets) @@ -46,7 +45,7 @@ TEST(FakeDriverTest, sendPackets) FakeDriver driver1; FakeDriver driver2; - Driver::Packet* packets[4]; + Driver::Packet packets[4]; IpAddress destinations[4]; int prio[4]; for (int i = 0; i < 4; ++i) { @@ -65,7 +64,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); - driver1.sendPacket(packets[0], destinations[0], prio[0]); + driver1.sendPacket(&packets[0], destinations[0], prio[0]); EXPECT_EQ(1U, driver2.nic.priorityQueue.at(0).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(1).size()); @@ -81,7 +80,7 @@ TEST(FakeDriverTest, sendPackets) } for (int i = 0; i < 4; ++i) { - driver1.sendPacket(packets[i], destinations[i], prio[i]); + driver1.sendPacket(&packets[i], destinations[i], prio[i]); } EXPECT_EQ(2U, driver2.nic.priorityQueue.at(0).size()); @@ -92,8 +91,6 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(5).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); - - delete container_of(packets[2], &FakePacket::base); } TEST(FakeDriverTest, receivePackets) @@ -101,7 +98,7 @@ TEST(FakeDriverTest, receivePackets) std::string addressStr("42"); FakeDriver driver; - Driver::Packet* packets[4]; + Driver::Packet packets[4]; IpAddress srcAddrs[4]; // 3 packets at priority 7 diff --git a/src/Homa.cc b/src/Homa.cc index e03b55e..b72a4bf 100644 --- a/src/Homa.cc +++ b/src/Homa.cc @@ -19,10 +19,12 @@ namespace Homa { -Transport* -Transport::create(Driver* driver, uint64_t transportId) +Homa::unique_ptr +Transport::create(Driver* driver, MailboxDir* mailboxDir, uint64_t transportId) { - return new Core::TransportImpl(driver, transportId); + Transport* transport = + new Core::TransportImpl(driver, mailboxDir, transportId); + return Homa::unique_ptr(transport); } } // namespace Homa diff --git a/src/Mock/MockDriver.h b/src/Mock/MockDriver.h index 9ea6ffe..dfb6ec2 100644 --- a/src/Mock/MockDriver.h +++ b/src/Mock/MockDriver.h @@ -31,22 +31,35 @@ namespace Mock { class MockDriver : public Driver { public: /** - * Used in unit tests to mock calls to Driver::Packet. - * - * @sa Driver::Packet. + * Used in unit tests to mock driver-specific packet buffers. */ - using MockPacket = Driver::Packet; + struct PacketBuf { + /// External buffer which stores the packet data. + void* buffer; - MOCK_METHOD(Packet*, allocPacket, (), (override)); + /** + * Convert this packet buffer to the generic Driver::Packet + * representation. + */ + Driver::Packet toPacket(int length = 0) + { + Driver::Packet packet = {.descriptor = (uintptr_t)this, + .payload = buffer, + .length = length}; + return packet; + } + }; + + MOCK_METHOD(Packet, allocPacket, (), (override)); MOCK_METHOD(void, sendPacket, (Packet * packet, IpAddress destination, int priority), (override)); MOCK_METHOD(void, flushPackets, ()); MOCK_METHOD(uint32_t, receivePackets, - (uint32_t maxPackets, Packet* receivedPackets[], + (uint32_t maxPackets, Packet receivedPackets[], IpAddress sourceAddresses[]), (override)); - MOCK_METHOD(void, releasePackets, (Packet * packets[], uint16_t numPackets), + MOCK_METHOD(void, releasePackets, (Packet packets[], uint16_t numPackets), (override)); MOCK_METHOD(int, getHighestPacketPriority, (), (override)); MOCK_METHOD(uint32_t, getMaxPayloadSize, (), (override)); diff --git a/src/Mock/MockReceiver.h b/src/Mock/MockReceiver.h index 61c21ce..e8e4e1d 100644 --- a/src/Mock/MockReceiver.h +++ b/src/Mock/MockReceiver.h @@ -33,7 +33,8 @@ class MockReceiver : public Core::Receiver { public: MockReceiver(Driver* driver, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles) - : Receiver(driver, nullptr, messageTimeoutCycles, resendIntervalCycles) + : Receiver(driver, nullptr, nullptr, messageTimeoutCycles, + resendIntervalCycles) {} MOCK_METHOD(void, handleDataPacket, @@ -41,9 +42,8 @@ class MockReceiver : public Core::Receiver { MOCK_METHOD(void, handleBusyPacket, (Driver::Packet * packet), (override)); MOCK_METHOD(void, handlePingPacket, (Driver::Packet * packet, IpAddress sourceIp), (override)); - MOCK_METHOD(Homa::InMessage*, receiveMessage, (), (override)); - MOCK_METHOD(void, poll, (), (override)); MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); + MOCK_METHOD(bool, trySendGrants, (), (override)); }; } // namespace Mock diff --git a/src/Mock/MockSender.h b/src/Mock/MockSender.h index cb29c90..0da8388 100644 --- a/src/Mock/MockSender.h +++ b/src/Mock/MockSender.h @@ -45,8 +45,8 @@ class MockSender : public Core::Sender { MOCK_METHOD(void, handleUnknownPacket, (Driver::Packet * packet), (override)); MOCK_METHOD(void, handleErrorPacket, (Driver::Packet * packet), (override)); - MOCK_METHOD(void, poll, (), (override)); MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); + MOCK_METHOD(bool, trySend, (uint64_t*), (override)); }; } // namespace Mock diff --git a/src/Receiver.cc b/src/Receiver.cc index c850d07..bbc685f 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -18,6 +18,7 @@ #include #include "Perf.h" +#include "Tub.h" #include "Util.h" namespace Homa { @@ -28,6 +29,8 @@ namespace Core { * * @param driver * The driver used to send and receive packets. + * @param mailboxDir + * The mailbox directory used to lookup message destination. * @param policyManager * Provides information about the grant and network priority policies. * @param messageTimeoutCycles @@ -37,15 +40,16 @@ namespace Core { * Number of cycles of inactivity to wait between requesting retransmission * of un-received parts of a message. */ -Receiver::Receiver(Driver* driver, Policy::Manager* policyManager, +Receiver::Receiver(Driver* driver, MailboxDir* mailboxDir, + Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles) : driver(driver) , policyManager(policyManager) + , mailboxDir(mailboxDir) , messageBuckets(messageTimeoutCycles, resendIntervalCycles) , schedulerMutex() , scheduledPeers() - , receivedMessages() - , granting() + , dontNeedGrants() , messageAllocator() {} @@ -57,8 +61,6 @@ Receiver::~Receiver() schedulerMutex.lock(); scheduledPeers.clear(); peerTable.clear(); - receivedMessages.mutex.lock(); - receivedMessages.queue.clear(); for (auto it = messageBuckets.buckets.begin(); it != messageBuckets.buckets.end(); ++it) { MessageBucket* bucket = *it; @@ -94,8 +96,9 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) Protocol::MessageId id = header->common.messageId; MessageBucket* bucket = messageBuckets.getBucket(id); - SpinLock::Lock lock_bucket(bucket->mutex); - Message* message = bucket->findMessage(id, lock_bucket); + Tub lock_bucket; + lock_bucket.construct(bucket->mutex); + Message* message = bucket->findMessage(id, *lock_bucket); if (message == nullptr) { // New message int messageLength = header->totalLength; @@ -143,6 +146,10 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) info->bytesRemaining -= packetDataBytes; updateSchedule(message, lock_scheduler); } + + // Non-duplicate DATA packets from scheduled messages can change + // the state of scheduledPeers; time to run trySendGrants() again + signalNeedGrants(lock_scheduler); } // Receiving a new packet means the message is still active so it @@ -155,16 +162,23 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) bucket->resendTimeouts.setTimeout(&message->resendTimeout); } else { // All message packets have been received. - message->state.store(Message::State::COMPLETED); + message->setState(Message::State::COMPLETED); bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - SpinLock::Lock lock_received_messages(receivedMessages.mutex); - receivedMessages.queue.push_back(&message->receivedMessageNode); + uint16_t dport = be16toh(header->common.prefix.dport); + Mailbox* mailbox = mailboxDir->open(dport); + if (mailbox) { + mailbox->deliver(message); + mailbox->close(); + } else { + lock_bucket.destroy(); + ERROR("Unable to deliver the message; message dropped"); + dropMessage(message); + } } } else { // must be a duplicate packet; drop packet. - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); } - return; } /** @@ -187,11 +201,11 @@ Receiver::handleBusyPacket(Driver::Packet* packet) // Sender has replied BUSY to our RESEND request; consider this message // still active. bucket->messageTimeouts.setTimeout(&message->messageTimeout); - if (message->state == Message::State::IN_PROGRESS) { + if (message->getState() == Message::State::IN_PROGRESS) { bucket->resendTimeouts.setTimeout(&message->resendTimeout); } } - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); } /** @@ -245,41 +259,7 @@ Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) ControlPacket::send(driver, sourceIp, id); } - driver->releasePackets(&packet, 1); -} - -/** - * Return a handle to a new received Message. - * - * The Transport should regularly call this method to insure incoming messages - * are processed. - * - * @return - * A new Message which has been received, if available; otherwise, nullptr. - * - * @sa dropMessage() - */ -Homa::InMessage* -Receiver::receiveMessage() -{ - SpinLock::Lock lock_received_messages(receivedMessages.mutex); - Message* message = nullptr; - if (!receivedMessages.queue.empty()) { - message = &receivedMessages.queue.front(); - receivedMessages.queue.pop_front(); - } - return message; -} - -/** - * Allow the Receiver to make progress toward receiving incoming messages. - * - * This method must be called eagerly to ensure messages are received. - */ -void -Receiver::poll() -{ - trySendGrants(); + driver->releasePackets(packet, 1); } /** @@ -294,16 +274,13 @@ Receiver::poll() uint64_t Receiver::checkTimeouts() { - uint64_t nextTimeout; - // Ping Timeout - nextTimeout = checkResendTimeouts(); + uint64_t resendTimeout = checkResendTimeouts(); // Message Timeout uint64_t messageTimeout = checkMessageTimeouts(); - nextTimeout = nextTimeout < messageTimeout ? nextTimeout : messageTimeout; - return nextTimeout; + return std::min(resendTimeout, messageTimeout); } /** @@ -356,7 +333,7 @@ Receiver::Message::acknowledge() const bool Receiver::Message::dropped() const { - return state.load() == State::DROPPED; + return getState() == State::DROPPED; } /** @@ -398,7 +375,7 @@ Receiver::Message::get(size_t offset, void* destination, size_t count) const while (bytesCopied < _count) { uint32_t bytesToCopy = std::min(_count - bytesCopied, PACKET_DATA_LENGTH - packetOffset); - Driver::Packet* packet = getPacket(packetIndex); + const Driver::Packet* packet = getPacket(packetIndex); if (packet != nullptr) { char* source = static_cast(packet->payload); source += packetOffset + TRANSPORT_HEADER_LENGTH; @@ -416,6 +393,15 @@ Receiver::Message::get(size_t offset, void* destination, size_t count) const return bytesCopied; } +/** + * @copydoc Homa::InMessage::getSourceAddress() + */ +SocketAddress +Receiver::Message::getSourceAddress() const +{ + return source; +} + /** * @copydoc Homa::InMessage::length() */ @@ -452,11 +438,11 @@ Receiver::Message::release() * @return * Pointer to a Packet at the given index if it exists; nullptr otherwise. */ -Driver::Packet* +const Driver::Packet* Receiver::Message::getPacket(size_t index) const { if (occupied.test(index)) { - return packets[index]; + return &packets[index]; } return nullptr; } @@ -472,7 +458,7 @@ Receiver::Message::getPacket(size_t index) const * The Packet's index in the array of packets that form the message. * "packet index = "packet message offset" / PACKET_DATA_LENGTH * @param packet - * The packet pointer that should be stored. + * The packet that should be stored. * @return * True if the packet was stored; false if a packet already exists (the new * packet is not stored). @@ -483,12 +469,31 @@ Receiver::Message::setPacket(size_t index, Driver::Packet* packet) if (occupied.test(index)) { return false; } - packets[index] = packet; + packets[index] = *packet; occupied.set(index); numPackets++; return true; } +/** + * Clear the atomic _dontNeedGrants_ flag to indicate that trySendGrants() + * needs to run again. This method is called when the state of active messages + * in Receiver::scheduledPeers might have changed. + * + * Note: we require the caller to hold the schedulerMutex during this call + * because it becomes much easier to reason about the interaction between + * the atomic flag and the mutex this way (and it's essentially free). + * + * @param lockHeld + * Reminder to hold the Receiver::schedulerMutex during this call. + */ +void +Receiver::signalNeedGrants(const SpinLock::Lock& lockHeld) +{ + (void)lockHeld; + dontNeedGrants.clear(std::memory_order_release); +} + /** * Inform the Receiver that an Message returned by receiveMessage() is not * needed and can be dropped. @@ -564,7 +569,7 @@ Receiver::checkMessageTimeouts() bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - if (message->state == Message::State::IN_PROGRESS) { + if (message->getState() == Message::State::IN_PROGRESS) { // Message timed out before being fully received; drop the // message. @@ -587,7 +592,7 @@ Receiver::checkMessageTimeouts() } else { // Message timed out but we already made it available to the // Transport; let the Transport know. - message->state.store(Message::State::DROPPED); + message->setState(Message::State::DROPPED); } } globalNextTimeout = std::min(globalNextTimeout, nextTimeout); @@ -629,7 +634,7 @@ Receiver::checkResendTimeouts() } // Found expired timeout. - assert(message->state == Message::State::IN_PROGRESS); + assert(message->getState() == Message::State::IN_PROGRESS); bucket->resendTimeouts.setTimeout(&message->resendTimeout); // This Receiver expected to have heard from the Sender within the @@ -705,23 +710,30 @@ Receiver::checkResendTimeouts() /** * Send GRANTs to incoming Message according to the Receiver's policy. + * + * This method must be called eagerly to allow the Receiver to make progress + * toward receiving incoming messages. + * + * @return + * True if the method has found some messages to grant; false, otherwise. */ -void +bool Receiver::trySendGrants() { uint64_t start_tsc = PerfUtils::Cycles::rdtsc(); - bool idle = true; - // Skip scheduling if another poller is already working on it. - if (granting.test_and_set()) { - return; + // Fast path: skip if no message is waiting for grants + bool needGrants = !dontNeedGrants.test_and_set(); + if (!needGrants) { + return false; } + /* It's possible to have a benign race-condition here when another thread + * acquires the schedulerMutex before us and sets _dontNeedGrants_ back to + * false via signalNeedGrants. As a result, _dontNeedGrants_ will stay false + * when the method returns although all messages have been granted. + */ SpinLock::Lock lock(schedulerMutex); - if (scheduledPeers.empty()) { - granting.clear(); - return; - } /* The overall goal is to grant up to policy.degreeOvercommitment number of * scheduled messages simultaneously. Each of these messages should always @@ -743,6 +755,7 @@ Receiver::trySendGrants() auto it = scheduledPeers.begin(); int slot = 0; + bool foundWork = false; while (it != scheduledPeers.end() && slot < policy.degreeOvercommitment) { assert(!it->scheduledMessages.empty()); Message* message = &it->scheduledMessages.front(); @@ -758,7 +771,6 @@ Receiver::trySendGrants() // Send a GRANT if there are too few bytes granted and unreceived. int receivedBytes = info->messageLength - info->bytesRemaining; if (info->bytesGranted - receivedBytes < policy.minScheduledBytes) { - idle = false; // Calculate new grant limit int newGrantLimit = std::min( receivedBytes + policy.maxScheduledBytes, info->messageLength); @@ -768,6 +780,7 @@ Receiver::trySendGrants() ControlPacket::send( driver, sourceIp, id, Util::downCast(info->bytesGranted), info->priority); + foundWork = true; } // Update the iterator first since calling unschedule() may cause the @@ -782,14 +795,13 @@ Receiver::trySendGrants() ++slot; } - granting.clear(); - uint64_t elapsed_cycles = PerfUtils::Cycles::rdtsc() - start_tsc; - if (!idle) { + if (foundWork) { Perf::counters.active_cycles.add(elapsed_cycles); } else { Perf::counters.idle_cycles.add(elapsed_cycles); } + return foundWork; } /** @@ -872,6 +884,9 @@ Receiver::unschedule(Receiver::Message* message, const SpinLock::Lock& lock) Intrusive::deprioritize(&scheduledPeers, &peer->scheduledPeerNode, comp); } + + // scheduledPeers has been updated; time to run trySendGrants() again + signalNeedGrants(lock); } /** diff --git a/src/Receiver.h b/src/Receiver.h index 65e65ff..78f6f0f 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -43,16 +44,16 @@ namespace Core { */ class Receiver { public: - explicit Receiver(Driver* driver, Policy::Manager* policyManager, + explicit Receiver(Driver* driver, MailboxDir* mailboxDir, + Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles); virtual ~Receiver(); virtual void handleDataPacket(Driver::Packet* packet, IpAddress sourceIp); virtual void handleBusyPacket(Driver::Packet* packet); virtual void handlePingPacket(Driver::Packet* packet, IpAddress sourceIp); - virtual Homa::InMessage* receiveMessage(); - virtual void poll(); virtual uint64_t checkTimeouts(); + virtual bool trySendGrants(); private: // Forward declaration @@ -118,7 +119,7 @@ class Receiver { * Represents an incoming message that is being assembled or being processed * by the application. */ - class Message : public Homa::InMessage { + class Message final : public Homa::InMessage { public: /** * Defines the possible states of this Message. @@ -154,7 +155,6 @@ class Receiver { // construction. See Message::occupied. , state(Message::State::IN_PROGRESS) , bucketNode(this) - , receivedMessageNode(this) , messageTimeout(this) , resendTimeout(this) , scheduledMessageInfo(this, messageLength) @@ -166,23 +166,35 @@ class Receiver { virtual void fail() const; virtual size_t get(size_t offset, void* destination, size_t count) const; + virtual SocketAddress getSourceAddress() const; virtual size_t length() const; virtual void strip(size_t count); virtual void release(); + private: /** * Return the current state of this message. */ State getState() const { - return state.load(); + return state.load(std::memory_order_acquire); + } + + /** + * Change the current state of this message. + * + * @param newState + * The new state of the message + */ + void setState(State newState) + { + state.store(newState, std::memory_order_release); } - private: /// Define the maximum number of packets that a message can hold. static const int MAX_MESSAGE_PACKETS = 1024; - Driver::Packet* getPacket(size_t index) const; + const Driver::Packet* getPacket(size_t index) const; bool setPacket(size_t index, Driver::Packet* packet); /// The Receiver responsible for this message. @@ -230,7 +242,7 @@ class Receiver { /// Collection of Packet objects that make up this context's Message. /// These Packets will be released when this context is destroyed. - Driver::Packet* packets[MAX_MESSAGE_PACKETS]; + Driver::Packet packets[MAX_MESSAGE_PACKETS]; /// This message's current state. std::atomic state; @@ -240,10 +252,6 @@ class Receiver { /// is protected by the associated MessageBucket::mutex; Intrusive::List::Node bucketNode; - /// Intrusive structure used by the Receiver to keep track of this - /// message when it has been completely received. - Intrusive::List::Node receivedMessageNode; - /// Intrusive structure used by the Receiver to keep track when the /// receiving of this message should be considered failed. Timeout messageTimeout; @@ -450,10 +458,10 @@ class Receiver { Intrusive::List::Node scheduledPeerNode; }; + void signalNeedGrants(const SpinLock::Lock& lockHeld); void dropMessage(Receiver::Message* message); uint64_t checkMessageTimeouts(); uint64_t checkResendTimeouts(); - void trySendGrants(); void schedule(Message* message, const SpinLock::Lock& lock); void unschedule(Message* message, const SpinLock::Lock& lock); void updateSchedule(Message* message, const SpinLock::Lock& lock); @@ -465,6 +473,9 @@ class Receiver { /// Provider of network packet priority and grant policy decisions. Policy::Manager* const policyManager; + /// Records where to deliver the messages when they are completed. + MailboxDir* const mailboxDir; + /// Tracks the set of inbound messages being received by this Receiver. MessageBucketMap messageBuckets; @@ -480,18 +491,10 @@ class Receiver { /// Access is protected by the schedulerMutex. Intrusive::List scheduledPeers; - /// Message objects to be processed by the transport. - struct { - /// Protects the receivedMessage.queue - SpinLock mutex; - /// List of completely received messages. - Intrusive::List queue; - } receivedMessages; - - /// True if the Receiver is executing trySendGrants(); false, otherwise. - /// Used to prevent concurrent calls to trySendGrants() from blocking on - /// each other. - std::atomic_flag granting = ATOMIC_FLAG_INIT; + /// Hint whether there MIGHT be messages that need to be granted. Encoded + /// into an atomic bool so that checking if there is work to do can be done + /// efficiently without acquiring the schedulerMutex first. + std::atomic_flag dontNeedGrants; /// Used to allocate Message objects. struct { diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index bfccc39..aaf2ef8 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include @@ -34,7 +35,6 @@ using ::testing::InSequence; using ::testing::Matcher; using ::testing::Mock; using ::testing::NiceMock; -using ::testing::Pointee; using ::testing::Return; /// Helper macro to construct an IpAddress from a numeric number. @@ -44,21 +44,42 @@ using ::testing::Return; x \ } +/** + * Defines a matcher EqPacket(p) to match two Driver::Packet* by their + * underlying packet buffer descriptors. + */ +MATCHER_P(EqPacket, p, "") +{ + return arg->descriptor == p->descriptor; +} + +/** + * Defines a matcher EqPacketLen(p) to match a Driver::Packet* by its length. + */ +MATCHER_P(EqPacketLen, length, "") +{ + return arg->length == length; +} + class ReceiverTest : public ::testing::Test { public: ReceiverTest() : mockDriver() - , mockPacket{&payload} + , mockPacket() , mockPolicyManager(&mockDriver) + , mailboxDir() + , mailbox(mailboxDir.alloc(60001)) , payload() + , packetBuf{&payload} , receiver() , savedLogPolicy(Debug::getLogPolicy()) { ON_CALL(mockDriver, getBandwidth).WillByDefault(Return(8000)); ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1027)); + mockPacket = packetBuf.toPacket(); Debug::setLogPolicy( Debug::logPolicyFromString("src/ObjectPool@SILENT")); - receiver = new Receiver(&mockDriver, &mockPolicyManager, + receiver = new Receiver(&mockDriver, &mailboxDir, &mockPolicyManager, messageTimeoutCycles, resendIntervalCycles); PerfUtils::Cycles::mockTscValue = 10000; } @@ -75,9 +96,12 @@ class ReceiverTest : public ::testing::Test { static const uint64_t resendIntervalCycles = 100; NiceMock mockDriver; - Homa::Mock::MockDriver::MockPacket mockPacket; + Driver::Packet mockPacket; NiceMock mockPolicyManager; + SimpleMailboxDir mailboxDir; + Mailbox* mailbox; char payload[1028]; + Homa::Mock::MockDriver::PacketBuf packetBuf; Receiver* receiver; std::vector> savedLogPolicy; }; @@ -105,13 +129,10 @@ TEST_F(ReceiverTest, handleDataPacket) Receiver::ScheduledMessageInfo* info = nullptr; Receiver::MessageBucket* bucket = receiver->messageBuckets.getBucket(id); + new (mockPacket.payload) Protocol::Packet::DataHeader( + 0, 60001, id, totalMessageLength, policyVersion, 1, 0); Protocol::Packet::DataHeader* header = static_cast(mockPacket.payload); - header->common.opcode = Protocol::Packet::DATA; - header->common.messageId = id; - header->totalLength = totalMessageLength; - header->policyVersion = policyVersion; - header->unscheduledIndexLimit = 1; IpAddress sourceIp{22}; // ------------------------------------------------------------------------- @@ -122,8 +143,7 @@ TEST_F(ReceiverTest, handleDataPacket) signalNewMessage(Eq(sourceIp), Eq(policyVersion), Eq(totalMessageLength))) .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(0); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(0); // TEST CALL receiver->handleDataPacket(&mockPacket, sourceIp); @@ -151,8 +171,7 @@ TEST_F(ReceiverTest, handleDataPacket) // ------------------------------------------------------------------------- // Receive packet[1]. Duplicate. - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); // TEST CALL receiver->handleDataPacket(&mockPacket, sourceIp); @@ -165,8 +184,7 @@ TEST_F(ReceiverTest, handleDataPacket) // Receive packet[2]. header->index = 2; mockPacket.length = HEADER_SIZE + 1000; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(0); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(0); // TEST CALL receiver->handleDataPacket(&mockPacket, sourceIp); @@ -180,8 +198,7 @@ TEST_F(ReceiverTest, handleDataPacket) // Receive packet[3]. header->index = 3; mockPacket.length = HEADER_SIZE + 500; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(0); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(0); // TEST CALL receiver->handleDataPacket(&mockPacket, sourceIp); @@ -195,8 +212,7 @@ TEST_F(ReceiverTest, handleDataPacket) // Receive packet[0]. Finished. header->index = 0; mockPacket.length = HEADER_SIZE + 1000; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(0); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(0); // TEST CALL receiver->handleDataPacket(&mockPacket, sourceIp); @@ -205,13 +221,12 @@ TEST_F(ReceiverTest, handleDataPacket) EXPECT_EQ(4U, message->numPackets); EXPECT_EQ(0U, info->bytesRemaining); EXPECT_EQ(Receiver::Message::State::COMPLETED, message->state); - EXPECT_EQ(message, &receiver->receivedMessages.queue.back()); + EXPECT_EQ(message, mailbox->retrieve(false)); Mock::VerifyAndClearExpectations(&mockDriver); // ------------------------------------------------------------------------- // Receive packet[0]. Already finished. - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); // TEST CALL receiver->handleDataPacket(&mockPacket, sourceIp); @@ -232,8 +247,7 @@ TEST_F(ReceiverTest, handleBusyPacket_basic) (Protocol::Packet::BusyHeader*)mockPacket.payload; busyHeader->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); receiver->handleBusyPacket(&mockPacket); @@ -249,8 +263,7 @@ TEST_F(ReceiverTest, handleBusyPacket_unknown) (Protocol::Packet::BusyHeader*)mockPacket.payload; busyHeader->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); receiver->handleBusyPacket(&mockPacket); } @@ -270,18 +283,20 @@ TEST_F(ReceiverTest, handlePingPacket_basic) bucket->messages.push_back(&message->bucketNode); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket{pingPayload}; + Homa::Mock::MockDriver::PacketBuf pingPacketBuf{pingPayload}; + Driver::Packet pingPacket = pingPacketBuf.toPacket(); IpAddress sourceIp = mockAddress; Protocol::Packet::PingHeader* pingHeader = - (Protocol::Packet::PingHeader*)pingPacket.payload; + (Protocol::Packet::PingHeader*)pingPacketBuf.buffer; pingHeader->common.messageId = id; - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(mockAddress), _)) + EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, + sendPacket(EqPacket(&mockPacket), Eq(mockAddress), _)) .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&pingPacket), Eq(1))) .Times(1); receiver->handlePingPacket(&pingPacket, sourceIp); @@ -302,18 +317,20 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) Protocol::MessageId id(42, 32); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket{pingPayload}; + Homa::Mock::MockDriver::PacketBuf pingPacketBuf{pingPayload}; + Driver::Packet pingPacket = pingPacketBuf.toPacket(); IpAddress mockAddress{22}; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(mockAddress), _)) + EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, + sendPacket(EqPacket(&mockPacket), Eq(mockAddress), _)) .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&pingPacket), Eq(1))) .Times(1); receiver->handlePingPacket(&pingPacket, mockAddress); @@ -324,35 +341,6 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) EXPECT_EQ(id, header->common.messageId); } -TEST_F(ReceiverTest, receiveMessage) -{ - Receiver::Message* msg0 = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, Protocol::MessageId(42, 0), - SocketAddress{22, 60001}, 0); - Receiver::Message* msg1 = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, Protocol::MessageId(42, 0), - SocketAddress{22, 60001}, 0); - - receiver->receivedMessages.queue.push_back(&msg0->receivedMessageNode); - receiver->receivedMessages.queue.push_back(&msg1->receivedMessageNode); - EXPECT_FALSE(receiver->receivedMessages.queue.empty()); - - EXPECT_EQ(msg0, receiver->receiveMessage()); - EXPECT_FALSE(receiver->receivedMessages.queue.empty()); - - EXPECT_EQ(msg1, receiver->receiveMessage()); - EXPECT_TRUE(receiver->receivedMessages.queue.empty()); - - EXPECT_EQ(nullptr, receiver->receiveMessage()); - EXPECT_TRUE(receiver->receivedMessages.queue.empty()); -} - -TEST_F(ReceiverTest, poll) -{ - // Nothing to test - receiver->poll(); -} - TEST_F(ReceiverTest, checkTimeouts) { Receiver::Message message(receiver, &mockDriver, 0, 0, @@ -424,11 +412,12 @@ TEST_F(ReceiverTest, Message_acknowledge) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); + EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(mockPacket)); EXPECT_CALL(mockDriver, - sendPacket(Eq(&mockPacket), Eq(message->source.ip), _)) + sendPacket(EqPacketLen(sizeof(Protocol::Packet::DoneHeader)), + Eq(message->source.ip), _)) .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); message->acknowledge(); @@ -437,7 +426,6 @@ TEST_F(ReceiverTest, Message_acknowledge) static_cast(mockPacket.payload); EXPECT_EQ(Protocol::Packet::DONE, header->opcode); EXPECT_EQ(id, header->messageId); - EXPECT_EQ(sizeof(Protocol::Packet::DoneHeader), mockPacket.length); } TEST_F(ReceiverTest, Message_dropped) @@ -461,11 +449,12 @@ TEST_F(ReceiverTest, Message_fail) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); + EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(mockPacket)); EXPECT_CALL(mockDriver, - sendPacket(Eq(&mockPacket), Eq(message->source.ip), _)) + sendPacket(EqPacketLen(sizeof(Protocol::Packet::ErrorHeader)), + Eq(message->source.ip), _)) .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); message->fail(); @@ -474,7 +463,6 @@ TEST_F(ReceiverTest, Message_fail) static_cast(mockPacket.payload); EXPECT_EQ(Protocol::Packet::ERROR, header->opcode); EXPECT_EQ(id, header->messageId); - EXPECT_EQ(sizeof(Protocol::Packet::ErrorHeader), mockPacket.length); } TEST_F(ReceiverTest, Message_get_basic) @@ -484,8 +472,10 @@ TEST_F(ReceiverTest, Message_get_basic) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); char source[] = "Hello, world!"; message->setPacket(0, &packet0); @@ -511,8 +501,10 @@ TEST_F(ReceiverTest, Message_get_offsetTooLarge) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); message->setPacket(0, &packet0); message->setPacket(1, &packet1); @@ -537,8 +529,10 @@ TEST_F(ReceiverTest, Message_get_missingPacket) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); char source[] = "Hello,"; message->setPacket(0, &packet0); @@ -601,8 +595,8 @@ TEST_F(ReceiverTest, Message_getPacket) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); - Driver::Packet* packet = (Driver::Packet*)42; - message->packets[0] = packet; + message->packets[0] = {}; + Driver::Packet* packet = &message->packets[0]; EXPECT_EQ(nullptr, message->getPacket(0)); @@ -616,14 +610,15 @@ TEST_F(ReceiverTest, Message_setPacket) Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); - Driver::Packet* packet = (Driver::Packet*)42; + Driver::Packet* packet = &mockPacket; EXPECT_FALSE(message->occupied.test(0)); EXPECT_EQ(0U, message->numPackets); EXPECT_TRUE(message->setPacket(0, packet)); - EXPECT_EQ(packet, message->packets[0]); + EXPECT_EQ(packet->descriptor, message->packets[0].descriptor); + EXPECT_EQ(packet->payload, message->packets[0].payload); EXPECT_TRUE(message->occupied.test(0)); EXPECT_EQ(1U, message->numPackets); @@ -813,21 +808,21 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) char buf1[1024]; char buf2[1024]; - Homa::Mock::MockDriver::MockPacket mockResendPacket1{buf1}; - Homa::Mock::MockDriver::MockPacket mockResendPacket2{buf2}; + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf1}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf2}; + Driver::Packet mockResendPacket1 = packetBuf0.toPacket(); + Driver::Packet mockResendPacket2 = packetBuf1.toPacket(); + const size_t RESEND_HEADER_LEN = sizeof(Protocol::Packet::ResendHeader); EXPECT_CALL(mockDriver, allocPacket()) - .WillOnce(Return(&mockResendPacket1)) - .WillOnce(Return(&mockResendPacket2)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket1), + .WillOnce(Return(mockResendPacket1)) + .WillOnce(Return(mockResendPacket2)); + EXPECT_CALL(mockDriver, sendPacket(EqPacketLen(RESEND_HEADER_LEN), Eq(message[0]->source.ip), _)) + .Times(2); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockResendPacket1), Eq(1))) .Times(1); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket2), - Eq(message[0]->source.ip), _)) - .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket1), Eq(1))) - .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket2), Eq(1))) + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockResendPacket2), Eq(1))) .Times(1); // TEST CALL @@ -843,14 +838,12 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) EXPECT_EQ(message[0]->id, header1->common.messageId); EXPECT_EQ(2U, header1->index); EXPECT_EQ(4U, header1->num); - EXPECT_EQ(sizeof(Protocol::Packet::ResendHeader), mockResendPacket1.length); Protocol::Packet::ResendHeader* header2 = static_cast(mockResendPacket2.payload); EXPECT_EQ(Protocol::Packet::RESEND, header2->common.opcode); EXPECT_EQ(message[0]->id, header2->common.messageId); EXPECT_EQ(8U, header2->index); EXPECT_EQ(2U, header2->num); - EXPECT_EQ(sizeof(Protocol::Packet::ResendHeader), mockResendPacket2.length); // Message[1]: Blocked on grants EXPECT_EQ(10100, message[1]->resendTimeout.expirationCycleTime); @@ -906,11 +899,12 @@ TEST_F(ReceiverTest, trySendGrants) info[0]->bytesRemaining -= 1000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&mockPacket), _, _)).Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); + receiver->dontNeedGrants.clear(); receiver->trySendGrants(); EXPECT_EQ(1, info[0]->priority); @@ -932,11 +926,12 @@ TEST_F(ReceiverTest, trySendGrants) info[1]->bytesRemaining -= 1000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&mockPacket), _, _)).Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); + receiver->dontNeedGrants.clear(); receiver->trySendGrants(); EXPECT_EQ(0, info[1]->priority); @@ -956,6 +951,7 @@ TEST_F(ReceiverTest, trySendGrants) .WillOnce(Return(policy)); EXPECT_CALL(mockDriver, sendPacket(_, _, _)).Times(0); + receiver->dontNeedGrants.clear(); receiver->trySendGrants(); EXPECT_EQ(1, info[1]->priority); @@ -975,6 +971,7 @@ TEST_F(ReceiverTest, trySendGrants) .WillOnce(Return(policy)); EXPECT_CALL(mockDriver, sendPacket(_, _, _)).Times(0); + receiver->dontNeedGrants.clear(); receiver->trySendGrants(); EXPECT_EQ(2, info[1]->priority); diff --git a/src/Sender.cc b/src/Sender.cc index ea75bf4..6b25f8d 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -50,11 +50,13 @@ Sender::Sender(uint64_t transportId, Driver* driver, , policyManager(policyManager) , nextMessageSequenceNumber(1) , DRIVER_QUEUED_BYTE_LIMIT(2 * driver->getMaxPayloadSize()) + , DRIVER_CYCLES_TO_DRAIN_1MB(PerfUtils::Cycles::fromSeconds(1) * 8 / + driver->getBandwidth()) , messageBuckets(messageTimeoutCycles, pingIntervalCycles) , queueMutex() - , sendQueue() - , sending() , sendReady(false) + , notifySendReady() + , sendQueue() , messageAllocator() {} @@ -92,7 +94,7 @@ Sender::handleDonePacket(Driver::Packet* packet) if (message == nullptr) { // No message for this DONE packet; must be old. Just drop it. - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } @@ -103,7 +105,7 @@ Sender::handleDonePacket(Driver::Packet* packet) // Expected behavior bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::COMPLETED); + message->setStatus(OutMessage::Status::COMPLETED); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the DONE. @@ -141,7 +143,7 @@ Sender::handleDonePacket(Driver::Packet* packet) break; } - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); } /** @@ -167,7 +169,7 @@ Sender::handleResendPacket(Driver::Packet* packet) if (message == nullptr) { // No message for this RESEND; RESEND must be old. Just ignore it; this // case should be pretty rare and the Receiver will timeout eventually. - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } else if (message->numPackets < 2) { // We should never get a RESEND for a single packet message. Just @@ -176,7 +178,7 @@ Sender::handleResendPacket(Driver::Packet* packet) "Message (%lu, %lu) with only 1 packet received unexpected RESEND " "request; peer Transport may be confused.", msgId.transportId, msgId.sequence); - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } @@ -195,7 +197,7 @@ Sender::handleResendPacket(Driver::Packet* packet) "may be confused.", msgId.transportId, msgId.sequence, index, resendEnd, info->packets->numPackets); - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } @@ -206,7 +208,7 @@ Sender::handleResendPacket(Driver::Packet* packet) // will never be overridden since the resend index will not exceed the // preset packetsGranted. info->priority = header->priority; - sendReady.store(true); + signalPacerThread(lock_queue); } if (index >= info->packetsSent) { @@ -231,7 +233,7 @@ Sender::handleResendPacket(Driver::Packet* packet) } } - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); } /** @@ -252,14 +254,14 @@ Sender::handleGrantPacket(Driver::Packet* packet) Message* message = bucket->findMessage(msgId, lock); if (message == nullptr) { // No message for this grant; grant must be old. Just drop it. - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } bucket->messageTimeouts.setTimeout(&message->messageTimeout); bucket->pingTimeouts.setTimeout(&message->pingTimeout); - if (message->state.load() == OutMessage::Status::IN_PROGRESS) { + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -288,11 +290,11 @@ Sender::handleGrantPacket(Driver::Packet* packet) // limit will never be overridden since the incomingGrantIndex will // not exceed the preset packetsGranted. info->priority = header->priority; - sendReady.store(true); + signalPacerThread(lock_queue); } } - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); } /** @@ -314,7 +316,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet) if (message == nullptr) { // No message was found. Just drop the packet. - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } @@ -334,7 +336,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet) if (message->numPackets > 1) { SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; - if (message->state == OutMessage::Status::IN_PROGRESS) { + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { assert(sendQueue.contains(&info->sendQueueNode)); sendQueue.remove(&info->sendQueueNode); } @@ -343,7 +345,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet) bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::FAILED); + message->setStatus(OutMessage::Status::FAILED); } else { // Message isn't done yet so we will restart sending the message. @@ -352,14 +354,14 @@ Sender::handleUnknownPacket(Driver::Packet* packet) if (message->numPackets > 1) { SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; - if (message->state == OutMessage::Status::IN_PROGRESS) { + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { assert(sendQueue.contains(&info->sendQueueNode)); sendQueue.remove(&info->sendQueueNode); } assert(!sendQueue.contains(&info->sendQueueNode)); } - message->state.store(OutMessage::Status::IN_PROGRESS); + message->setStatus(OutMessage::Status::IN_PROGRESS); // Get the current policy for unscheduled bytes. Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( @@ -392,7 +394,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet) Perf::counters.tx_bytes.add(dataPacket->length); driver->sendPacket(dataPacket, message->destination.ip, policy.priority); - message->state.store(OutMessage::Status::SENT); + message->setStatus(OutMessage::Status::SENT); } else { // Otherwise, queue the message to be sent in SRPT order. SpinLock::Lock lock_queue(queueMutex); @@ -415,11 +417,11 @@ Sender::handleUnknownPacket(Driver::Packet* packet) Intrusive::deprioritize( &sendQueue, &info->sendQueueNode, QueuedMessageInfo::ComparePriority()); - sendReady.store(true); + signalPacerThread(lock_queue); } } - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); } /** @@ -440,7 +442,7 @@ Sender::handleErrorPacket(Driver::Packet* packet) Message* message = bucket->findMessage(msgId, lock); if (message == nullptr) { // No message for this ERROR packet; must be old. Just drop it. - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } @@ -450,7 +452,7 @@ Sender::handleErrorPacket(Driver::Packet* packet) // Message was sent and a failure notification was received. bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::FAILED); + message->setStatus(OutMessage::Status::FAILED); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the ERROR. @@ -488,18 +490,7 @@ Sender::handleErrorPacket(Driver::Packet* packet) break; } - driver->releasePackets(&packet, 1); -} - -/** - * Allow the Sender to make progress toward sending outgoing messages. - * - * This method must be called eagerly to ensure messages are sent. - */ -void -Sender::poll() -{ - trySend(); + driver->releasePackets(packet, 1); } /** @@ -588,7 +579,29 @@ Sender::Message::cancel() OutMessage::Status Sender::Message::getStatus() const { - return state.load(); + return state.load(std::memory_order_acquire); +} + +/** + * Change the current state of this message and invoke callback if necessary. + * + * @param newStatus + * The new state of the message + */ +void +Sender::Message::setStatus(OutMessage::Status newStatus) +{ + state.store(newStatus, std::memory_order_release); + if (notifyEndState) { + switch (newStatus) { + case OutMessage::Status::CANCELED: + case OutMessage::Status::COMPLETED: + case OutMessage::Status::FAILED: + notifyEndState(); + default: + break; + } + } } /** @@ -630,6 +643,15 @@ Sender::Message::prepend(const void* source, size_t count) } } +/** + * @copydoc Homa::OutMessage::registerCallbackEndState() + */ +void +Sender::Message::registerCallbackEndState(Callback func) +{ + notifyEndState = std::move(func); +} + /** * @copydoc Homa::OutMessage::release() */ @@ -698,10 +720,10 @@ Sender::Message::send(SocketAddress destination, * Pointer to a Packet at the given index if it exists; nullptr otherwise. */ Driver::Packet* -Sender::Message::getPacket(size_t index) const +Sender::Message::getPacket(size_t index) { if (occupied.test(index)) { - return packets[index]; + return &packets[index]; } return nullptr; } @@ -725,9 +747,9 @@ Sender::Message::getOrAllocPacket(size_t index) numPackets++; // TODO(cstlee): A Message probably shouldn't be in charge of setting // the packet length. - packets[index]->length = TRANSPORT_HEADER_LENGTH; + packets[index].length = TRANSPORT_HEADER_LENGTH; } - return packets[index]; + return &packets[index]; } /** @@ -760,7 +782,7 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, message->id = id; message->destination = destination; message->options = options; - message->state.store(OutMessage::Status::IN_PROGRESS); + message->setStatus(OutMessage::Status::IN_PROGRESS); int actualMessageLen = 0; // fill out metadata. @@ -804,7 +826,7 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); driver->sendPacket(packet, message->destination.ip, policy.priority); - message->state.store(OutMessage::Status::SENT); + message->setStatus(OutMessage::Status::SENT); } else { // Otherwise, queue the message to be sent in SRPT order. SpinLock::Lock lock_queue(queueMutex); @@ -821,7 +843,7 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, sendQueue.push_front(&info->sendQueueNode); Intrusive::deprioritize(&sendQueue, &info->sendQueueNode, QueuedMessageInfo::ComparePriority()); - sendReady.store(true); + signalPacerThread(lock_queue); } } @@ -840,19 +862,21 @@ Sender::cancelMessage(Sender::Message* message) if (bucket->messages.contains(&message->bucketNode)) { bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - if (message->numPackets > 1 && - message->state == OutMessage::Status::IN_PROGRESS) { - // Check to see if the message needs to be dequeued. + + // Check to see if the message needs to be dequeued. In order to reduce + // cache misses related to queueMutex, check the status first to avoid + // unnecessary locking. + OutMessage::Status status = message->getStatus(); + if ((status == OutMessage::Status::IN_PROGRESS) || + (status == OutMessage::Status::FAILED)) { SpinLock::Lock lock_queue(queueMutex); - // Recheck state with lock in case it change right before this. - if (message->state == OutMessage::Status::IN_PROGRESS) { - QueuedMessageInfo* info = &message->queuedMessageInfo; - assert(sendQueue.contains(&info->sendQueueNode)); + QueuedMessageInfo* info = &message->queuedMessageInfo; + if (sendQueue.contains(&info->sendQueueNode)) { sendQueue.remove(&info->sendQueueNode); } } bucket->messages.remove(&message->bucketNode); - message->state.store(OutMessage::Status::CANCELED); + message->setStatus(OutMessage::Status::CANCELED); } } @@ -902,13 +926,13 @@ Sender::checkMessageTimeouts() break; } // Found expired timeout. - if (message->state != OutMessage::Status::COMPLETED) { - message->state.store(OutMessage::Status::FAILED); + if (message->getStatus() != OutMessage::Status::COMPLETED) { + message->setStatus(OutMessage::Status::FAILED); // A sent NO_KEEP_ALIVE message should never reach this state // since the shorter ping timeout should have already canceled // the message timeout. assert( - !((message->state == OutMessage::Status::SENT) && + !((message->getStatus() == OutMessage::Status::SENT) && (message->options & OutMessage::Options::NO_KEEP_ALIVE))); } bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); @@ -951,12 +975,12 @@ Sender::checkPingTimeouts() break; } // Found expired timeout. - if (message->state == OutMessage::Status::COMPLETED || - message->state == OutMessage::Status::FAILED) { + if (message->getStatus() == OutMessage::Status::COMPLETED || + message->getStatus() == OutMessage::Status::FAILED) { bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); continue; } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE && - message->state == OutMessage::Status::SENT) { + message->getStatus() == OutMessage::Status::SENT) { bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); continue; @@ -975,22 +999,59 @@ Sender::checkPingTimeouts() return globalNextTimeout; } +/// See Homa::Transport::registerCallbackSendReady() +void +Sender::registerCallbackSendReady(Callback func) +{ + notifySendReady = std::move(func); +} + /** - * Send out packets for any messages with unscheduled/granted bytes. + * Attempt to wake up the pacer thread that is responsible for calling trySend() + * repeatedly, if it's currently blocked waiting for packets to become ready to + * be sent. + * + * This method is called when new GRANTs arrive, when new outgoing messages + * appear, and when retransmission is requested. + * + * @param lockHeld + * Reminder to hold the Sender::queueMutex during this call. */ void -Sender::trySend() +Sender::signalPacerThread(const SpinLock::Lock& lockHeld) +{ + (void)lockHeld; + sendReady = true; + if (notifySendReady) { + notifySendReady(); + } +} + +/** + * Attempt to send out packets for any messages with unscheduled/granted bytes + * in a way that limits queue buildup in the NIC. + * + * This method must be called eagerly to allow the Sender to make progress + * toward sending outgoing messages. + * + * @param[out] waitUntil + * Time to wait before next call, in microseconds, in order to allow + * the NIC transmit queue to drain. Only set when this method returns + * true. + * @return + * True if more packets are ready to be transmitted when the method + * returns; false, otherwise. + */ +bool +Sender::trySend(uint64_t* waitUntil) { uint64_t start_tsc = PerfUtils::Cycles::rdtsc(); bool idle = true; + // Skip when there are no messages to send. + SpinLock::UniqueLock lock_queue(queueMutex); if (!sendReady) { - return; - } - - // Skip sending if another thread is already working on it. - if (sending.test_and_set()) { - return; + return false; } /* The goal is to send out packets for messages that have bytes that have @@ -999,7 +1060,6 @@ Sender::trySend() * Each time this method is called we will try to send enough packet to keep * the NIC busy but not too many as to cause excessive queue in the NIC. */ - SpinLock::UniqueLock lock_queue(queueMutex); uint32_t queuedBytesEstimate = driver->getQueuedBytes(); // Optimistically assume we will finish sending every granted packet this // round; we will set again sendReady if it turns out we don't finish. @@ -1007,7 +1067,7 @@ Sender::trySend() auto it = sendQueue.begin(); while (it != sendQueue.end()) { Message& message = *it; - assert(message.state.load() == OutMessage::Status::IN_PROGRESS); + assert(message.getStatus() == OutMessage::Status::IN_PROGRESS); QueuedMessageInfo* info = &message.queuedMessageInfo; assert(info->packetsGranted <= info->packets->numPackets); while (info->packetsSent < info->packetsGranted) { @@ -1038,26 +1098,32 @@ Sender::trySend() } if (info->packetsSent >= info->packets->numPackets) { // We have finished sending the message. - message.state.store(OutMessage::Status::SENT); + message.setStatus(OutMessage::Status::SENT); it = sendQueue.remove(it); } else if (info->packetsSent >= info->packetsGranted) { // We have sent every granted packet. ++it; } else { - // We hit the DRIVER_QUEUED_BYTES_LIMIT; stop sending for now. + // We hit the DRIVER_QUEUED_BYTE_LIMIT; stop sending for now. // We didn't finish sending all granted packets. sendReady = true; + // Compute how much time the driver needs to drain its queue, + // then schedule to wake up a bit earlier to avoid blowing bubbles. + static const uint64_t us = PerfUtils::Cycles::fromMicroseconds(1); + *waitUntil = + PerfUtils::Cycles::rdtsc() - 1 * us + + queuedBytesEstimate * DRIVER_CYCLES_TO_DRAIN_1MB / 1000000; break; } } - sending.clear(); uint64_t elapsed_cycles = PerfUtils::Cycles::rdtsc() - start_tsc; if (!idle) { Perf::counters.active_cycles.add(elapsed_cycles); } else { Perf::counters.idle_cycles.add(elapsed_cycles); } + return sendReady; } } // namespace Core diff --git a/src/Sender.h b/src/Sender.h index faa5dee..b8e56df 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -21,7 +21,7 @@ #include #include -#include +#include #include "Intrusive.h" #include "ObjectPool.h" @@ -52,8 +52,9 @@ class Sender { virtual void handleGrantPacket(Driver::Packet* packet); virtual void handleUnknownPacket(Driver::Packet* packet); virtual void handleErrorPacket(Driver::Packet* packet); - virtual void poll(); virtual uint64_t checkTimeouts(); + virtual void registerCallbackSendReady(Callback func); + virtual bool trySend(uint64_t* waitUntil); private: /// Forward declarations @@ -126,7 +127,7 @@ class Sender { * Sender::Message objects are contained in the Transport::Op but should * only be accessed by the Sender. */ - class Message : public Homa::OutMessage { + class Message final : public Homa::OutMessage { public: /** * Construct an Message. @@ -148,6 +149,7 @@ class Sender { // packets is not initialized to reduce the work done during // construction. See Message::occupied. , state(Status::NOT_STARTED) + , notifyEndState() , bucketNode(this) , messageTimeout(this) , pingTimeout(this) @@ -160,16 +162,19 @@ class Sender { virtual Status getStatus() const; virtual size_t length() const; virtual void prepend(const void* source, size_t count); + virtual void registerCallbackEndState(Callback func); virtual void release(); virtual void reserve(size_t count); virtual void send(SocketAddress destination, Options options = Options::NONE); private: + void setStatus(Status newStatus); + /// Define the maximum number of packets that a message can hold. static const size_t MAX_MESSAGE_PACKETS = 1024; - Driver::Packet* getPacket(size_t index) const; + Driver::Packet* getPacket(size_t index); Driver::Packet* getOrAllocPacket(size_t index); /// The Sender responsible for sending this message. @@ -213,11 +218,14 @@ class Sender { /// Collection of Packet objects that make up this context's Message. /// These Packets will be released when this context is destroyed. - Driver::Packet* packets[MAX_MESSAGE_PACKETS]; + Driver::Packet packets[MAX_MESSAGE_PACKETS]; /// This message's current state. std::atomic state; + /// Callback function to invoke when _state_ reaches an end state. + Callback notifyEndState; + /// Intrusive structure used by the Sender to hold on to this Message /// in one of the Sender's MessageBuckets. Access to this structure /// is protected by the associated MessageBucket::mutex; @@ -390,11 +398,11 @@ class Sender { void sendMessage(Sender::Message* message, SocketAddress destination, Message::Options options = Message::Options::NONE); + void signalPacerThread(const SpinLock::Lock& lockHeld); void cancelMessage(Sender::Message* message); void dropMessage(Sender::Message* message); uint64_t checkMessageTimeouts(); uint64_t checkPingTimeouts(); - void trySend(); /// Transport identifier. const uint64_t transportId; @@ -412,24 +420,27 @@ class Sender { /// The maximum number of bytes that should be queued in the Driver. const uint32_t DRIVER_QUEUED_BYTE_LIMIT; + /// Rdtsc cycles for the Driver to drain one MB of data at line rate. + const uint32_t DRIVER_CYCLES_TO_DRAIN_1MB; + /// Tracks all outbound messages being sent by the Sender. MessageBucketMap messageBuckets; - /// Protects the readyQueue. + /// Protects the sendQueue and sendReady. SpinLock queueMutex; - /// A list of outbound messages that have unsent packets. Messages are kept - /// in order of priority. - Intrusive::List sendQueue; - - /// True if the Sender is currently executing trySend(); false, otherwise. - /// Use to prevent concurrent trySend() calls from blocking on each other. - std::atomic_flag sending = ATOMIC_FLAG_INIT; - /// Hint whether there are messages ready to be sent (i.e. there are granted /// messages in the sendQueue. Encoded into a single bool so that checking /// if there is work to do is more efficient. - std::atomic sendReady; + bool sendReady; + + /// Callback function to be invoked when _sendReady_ flips from false to + /// true. + Callback notifySendReady; + + /// A list of outbound messages that have unsent packets. Messages are kept + /// in order of priority. + Intrusive::List sendQueue; /// Used to allocate Message objects. struct { diff --git a/src/SenderTest.cc b/src/SenderTest.cc index 8085c82..96b49c7 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -29,18 +29,29 @@ using ::testing::_; using ::testing::Eq; using ::testing::Mock; using ::testing::NiceMock; -using ::testing::Pointee; using ::testing::Return; +/** + * Defines a matcher EqPacket(p) to match two Driver::Packet* by their + * underlying packet buffer descriptors. + */ +MATCHER_P(EqPacket, p, "") +{ + return arg->descriptor == p->descriptor; +} + class SenderTest : public ::testing::Test { public: SenderTest() : mockDriver() - , mockPacket{&payload} + , mockPacket() , mockPolicyManager(&mockDriver) + , payload() + , packetBuf{&payload} , sender() , savedLogPolicy(Debug::getLogPolicy()) { + mockPacket = packetBuf.toPacket(); ON_CALL(mockDriver, getBandwidth).WillByDefault(Return(8000)); ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1031)); ON_CALL(mockDriver, getQueuedBytes).WillByDefault(Return(0)); @@ -59,9 +70,10 @@ class SenderTest : public ::testing::Test { } NiceMock mockDriver; - Homa::Mock::MockDriver::MockPacket mockPacket; + Driver::Packet mockPacket; NiceMock mockPolicyManager; char payload[1028]; + Homa::Mock::MockDriver::PacketBuf packetBuf; Sender* sender; std::vector> savedLogPolicy; @@ -96,7 +108,7 @@ class SenderTest : public ::testing::Test { } static bool setMessagePacket(Sender::Message* message, int index, - Driver::Packet* packet) + Driver::Packet packet) { if (message->occupied.test(index)) { return false; @@ -139,8 +151,7 @@ TEST_F(SenderTest, handleDonePacket_basic) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(2); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(2); // No message. sender->handleDonePacket(&mockPacket); @@ -170,8 +181,7 @@ TEST_F(SenderTest, handleDonePacket_CANCELED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleDonePacket(&mockPacket); } @@ -188,8 +198,7 @@ TEST_F(SenderTest, handleDonePacket_COMPLETED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); @@ -219,8 +228,7 @@ TEST_F(SenderTest, handleDonePacket_FAILED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); @@ -252,8 +260,7 @@ TEST_F(SenderTest, handleDonePacket_IN_PROGRESS) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); @@ -285,8 +292,7 @@ TEST_F(SenderTest, handleDonePacket_NO_STARTED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); @@ -311,10 +317,11 @@ TEST_F(SenderTest, handleResendPacket_basic) Protocol::MessageId id = {42, 1}; Sender::Message* message = dynamic_cast(sender->allocMessage(0)); - std::vector packets; + std::vector packets; std::vector priorities; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket{payload}); + auto* packetBuf = new Homa::Mock::MockDriver::PacketBuf{payload}; + packets.push_back(packetBuf->toPacket()); priorities.push_back(0); setMessagePacket(message, i, packets[i]); } @@ -333,14 +340,13 @@ TEST_F(SenderTest, handleResendPacket_basic) resendHdr->priority = 4; EXPECT_CALL(mockPolicyManager, getResendPriority).WillOnce(Return(7)); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[3]), _, _)) + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packets[3]), _, _)) .WillOnce( [&priorities](auto _1, auto _2, int p) { priorities[3] = p; }); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]), _, _)) + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packets[4]), _, _)) .WillOnce( [&priorities](auto _1, auto _2, int p) { priorities[4] = p; }); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleResendPacket(&mockPacket); @@ -353,10 +359,11 @@ TEST_F(SenderTest, handleResendPacket_basic) EXPECT_EQ(7, priorities[3]); EXPECT_EQ(7, priorities[4]); EXPECT_EQ(0, priorities[5]); - EXPECT_TRUE(sender->sendReady.load()); + EXPECT_TRUE(sender->sendReady); for (int i = 0; i < 10; ++i) { - delete packets[i]; + uintptr_t packetBuf = packets[i].descriptor; + delete (Homa::Mock::MockDriver::PacketBuf*)packetBuf; } } @@ -369,8 +376,7 @@ TEST_F(SenderTest, handleResendPacket_staleResend) resendHdr->index = 3; resendHdr->num = 5; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleResendPacket(&mockPacket); } @@ -382,8 +388,9 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) Sender::Message* message = dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); - Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket{payload}; + Homa::Mock::MockDriver::PacketBuf* packetBuf = + new Homa::Mock::MockDriver::PacketBuf{payload}; + Driver::Packet packet = packetBuf->toPacket(); setMessagePacket(message, 0, packet); Protocol::Packet::ResendHeader* resendHdr = @@ -393,8 +400,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) resendHdr->num = 5; resendHdr->priority = 4; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); @@ -413,7 +419,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) Debug::setLogHandler(std::function()); - delete packet; + delete packetBuf; } TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) @@ -421,9 +427,10 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) Protocol::MessageId id = {42, 1}; Sender::Message* message = dynamic_cast(sender->allocMessage(0)); - std::vector packets; + std::vector packets; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket{payload}); + auto* packetBuf = new Homa::Mock::MockDriver::PacketBuf{payload}; + packets.push_back(packetBuf->toPacket()); setMessagePacket(message, i, packets[i]); } SenderTest::addMessage(sender, id, message, true, 5); @@ -440,8 +447,7 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) resendHdr->num = 5; resendHdr->priority = 4; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); @@ -462,7 +468,8 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) Debug::setLogHandler(std::function()); for (int i = 0; i < 10; ++i) { - delete packets[i]; + uintptr_t packetBuf = packets[i].descriptor; + delete (Homa::Mock::MockDriver::PacketBuf*)packetBuf; } } @@ -472,9 +479,10 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) Sender::Message* message = dynamic_cast(sender->allocMessage(0)); char data[1028]; - Homa::Mock::MockDriver::MockPacket dataPacket{data}; + Homa::Mock::MockDriver::PacketBuf dataPacketBuf{data}; + Driver::Packet dataPacket = dataPacketBuf.toPacket(); for (int i = 0; i < 10; ++i) { - setMessagePacket(message, i, &dataPacket); + setMessagePacket(message, i, dataPacket); } SenderTest::addMessage(sender, id, message, true, 5); Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -490,16 +498,16 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) // Expect the BUSY control packet. char busy[1028]; - Homa::Mock::MockDriver::MockPacket busyPacket{busy}; - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&busyPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&busyPacket), _, _)).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&busyPacket), Eq(1))) + Homa::Mock::MockDriver::PacketBuf busyPacketBuf{busy}; + Driver::Packet busyPacket = busyPacketBuf.toPacket(); + EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(busyPacket)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&busyPacket), _, _)).Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&busyPacket), Eq(1))) .Times(1); // Expect no data to be sent but the RESEND packet to be release. - EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket), _, _)).Times(0); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&dataPacket), _, _)).Times(0); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleResendPacket(&mockPacket); @@ -533,8 +541,7 @@ TEST_F(SenderTest, handleGrantPacket_basic) header->byteLimit = 7000; header->priority = 6; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleGrantPacket(&mockPacket); @@ -542,7 +549,7 @@ TEST_F(SenderTest, handleGrantPacket_basic) EXPECT_EQ(6, info->priority); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); - EXPECT_TRUE(sender->sendReady.load()); + EXPECT_TRUE(sender->sendReady); } TEST_F(SenderTest, handleGrantPacket_excessiveGrant) @@ -565,8 +572,7 @@ TEST_F(SenderTest, handleGrantPacket_excessiveGrant) header->byteLimit = 11000; header->priority = 6; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); @@ -589,7 +595,7 @@ TEST_F(SenderTest, handleGrantPacket_excessiveGrant) EXPECT_EQ(6, info->priority); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); - EXPECT_TRUE(sender->sendReady.load()); + EXPECT_TRUE(sender->sendReady); } TEST_F(SenderTest, handleGrantPacket_staleGrant) @@ -611,8 +617,7 @@ TEST_F(SenderTest, handleGrantPacket_staleGrant) header->byteLimit = 4000; header->priority = 6; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleGrantPacket(&mockPacket); @@ -620,7 +625,7 @@ TEST_F(SenderTest, handleGrantPacket_staleGrant) EXPECT_EQ(2, info->priority); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); - EXPECT_FALSE(sender->sendReady.load()); + EXPECT_FALSE(sender->sendReady); } TEST_F(SenderTest, handleGrantPacket_dropGrant) @@ -631,8 +636,7 @@ TEST_F(SenderTest, handleGrantPacket_dropGrant) header->common.messageId = id; header->byteLimit = 4000; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleGrantPacket(&mockPacket); } @@ -646,13 +650,14 @@ TEST_F(SenderTest, handleUnknownPacket_basic) Sender::Message* message = dynamic_cast(sender->allocMessage(0)); - std::vector packets; + std::vector packets; char payload[5][1028]; for (int i = 0; i < 5; ++i) { - Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket{payload[i]}; + Homa::Mock::MockDriver::PacketBuf* packetBuf = + new Homa::Mock::MockDriver::PacketBuf{payload[i]}; + Driver::Packet packet = packetBuf->toPacket(); Protocol::Packet::DataHeader* header = - static_cast(packet->payload); + static_cast(packet.payload); header->policyVersion = policyOld.version; header->unscheduledIndexLimit = 2; packets.push_back(packet); @@ -682,16 +687,14 @@ TEST_F(SenderTest, handleUnknownPacket_basic) mockPolicyManager, getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) .WillOnce(Return(policyNew)); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); for (int i = 0; i < 3; ++i) { - Homa::Mock::MockDriver::MockPacket* packet = packets[i]; Protocol::Packet::DataHeader* header = - static_cast(packet->payload); + static_cast(packets[i].payload); EXPECT_EQ(policyNew.version, header->policyVersion); EXPECT_EQ(3U, header->unscheduledIndexLimit); } @@ -702,10 +705,11 @@ TEST_F(SenderTest, handleUnknownPacket_basic) EXPECT_EQ(policyNew.priority, info->priority); EXPECT_EQ(0U, info->packetsSent); EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); - EXPECT_TRUE(sender->sendReady.load()); + EXPECT_TRUE(sender->sendReady); for (int i = 0; i < 5; ++i) { - delete packets[i]; + uintptr_t packetBuf = packets[i].descriptor; + delete (Homa::Mock::MockDriver::PacketBuf*)packetBuf; } } @@ -718,12 +722,13 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) Sender::Message* message = dynamic_cast(sender->allocMessage(0)); - Homa::Mock::MockDriver::MockPacket dataPacket{payload}; + Homa::Mock::MockDriver::PacketBuf dataPacketBuf{payload}; + Driver::Packet dataPacket = dataPacketBuf.toPacket(); Protocol::Packet::DataHeader* dataHeader = static_cast(dataPacket.payload); dataHeader->policyVersion = policyOld.version; dataHeader->unscheduledIndexLimit = 2; - setMessagePacket(message, 0, &dataPacket); + setMessagePacket(message, 0, dataPacket); message->destination = destination; message->messageLength = 500; message->state.store(Homa::OutMessage::Status::SENT); @@ -741,9 +746,8 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) mockPolicyManager, getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) .WillOnce(Return(policyNew)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket), _, _)).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(message->packets), _, _)).Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleUnknownPacket(&mockPacket); @@ -754,7 +758,7 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); EXPECT_FALSE( sender->sendQueue.contains(&message->queuedMessageInfo.sendQueueNode)); - EXPECT_FALSE(sender->sendReady.load()); + EXPECT_FALSE(sender->sendReady); } TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) @@ -764,13 +768,14 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); message->options = OutMessage::Options::NO_RETRY; - std::vector packets; + std::vector packets; char payload[5][1028]; for (int i = 0; i < 5; ++i) { - Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket{payload[i]}; + Homa::Mock::MockDriver::PacketBuf* packetBuf = + new Homa::Mock::MockDriver::PacketBuf{payload[i]}; + Driver::Packet packet = packetBuf->toPacket(); packets.push_back(packet); setMessagePacket(message, i, packet); } @@ -787,10 +792,9 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_FALSE( sender->sendQueue.contains(&message->queuedMessageInfo.sendQueueNode)); @@ -798,7 +802,7 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) EXPECT_EQ(nullptr, message->pingTimeout.node.list); EXPECT_EQ(Homa::OutMessage::Status::FAILED, message->state); EXPECT_EQ(Homa::OutMessage::Status::FAILED, message->state); - EXPECT_FALSE(sender->sendReady.load()); + EXPECT_FALSE(sender->sendReady); } TEST_F(SenderTest, handleUnknownPacket_no_message) @@ -809,8 +813,7 @@ TEST_F(SenderTest, handleUnknownPacket_no_message) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleUnknownPacket(&mockPacket); } @@ -829,8 +832,7 @@ TEST_F(SenderTest, handleUnknownPacket_done) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleUnknownPacket(&mockPacket); @@ -854,8 +856,7 @@ TEST_F(SenderTest, handleErrorPacket_basic) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleErrorPacket(&mockPacket); @@ -877,8 +878,7 @@ TEST_F(SenderTest, handleErrorPacket_CANCELED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleErrorPacket(&mockPacket); @@ -898,8 +898,7 @@ TEST_F(SenderTest, handleErrorPacket_NOT_STARTED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); @@ -934,8 +933,7 @@ TEST_F(SenderTest, handleErrorPacket_IN_PROGRESS) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); @@ -970,8 +968,7 @@ TEST_F(SenderTest, handleErrorPacket_COMPLETED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); @@ -1006,8 +1003,7 @@ TEST_F(SenderTest, handleErrorPacket_FAILED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); @@ -1033,17 +1029,10 @@ TEST_F(SenderTest, handleErrorPacket_noMessage) Protocol::Packet::ErrorHeader* header = static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); sender->handleErrorPacket(&mockPacket); } -TEST_F(SenderTest, poll) -{ - // Nothing to test. - sender->poll(); -} - TEST_F(SenderTest, checkTimeouts) { Sender::Message message(sender, 0); @@ -1094,8 +1083,8 @@ TEST_F(SenderTest, Message_append_basic) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1{buf + MAX_RAW_PACKET_LENGTH}; + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1104,19 +1093,19 @@ TEST_F(SenderTest, Message_append_basic) TRANSPORT_HEADER_LENGTH + PACKET_DATA_LENGTH); char source[] = "Hello, world!"; - setMessagePacket(&msg, 0, &packet0); - packet0.length = MAX_RAW_PACKET_LENGTH - 7; + setMessagePacket(&msg, 0, packetBuf0.toPacket(MAX_RAW_PACKET_LENGTH - 7)); msg.messageLength = PACKET_DATA_LENGTH - 7; - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&packet1)); + EXPECT_CALL(mockDriver, allocPacket) + .WillOnce(Return(packetBuf1.toPacket())); msg.append(source, 14); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.messageLength); EXPECT_EQ(2U, msg.numPackets); - EXPECT_TRUE(msg.packets[1] == &packet1); - EXPECT_EQ(MAX_RAW_PACKET_LENGTH, packet0.length); - EXPECT_EQ(TRANSPORT_HEADER_LENGTH + 7, packet1.length); + EXPECT_EQ(msg.packets[1].payload, packetBuf1.buffer); + EXPECT_EQ(MAX_RAW_PACKET_LENGTH, msg.packets[0].length); + EXPECT_EQ(TRANSPORT_HEADER_LENGTH + 7, msg.packets[1].length); EXPECT_TRUE(std::memcmp(buf + MAX_RAW_PACKET_LENGTH - 7, source, 7) == 0); EXPECT_TRUE( std::memcmp(buf + MAX_RAW_PACKET_LENGTH + TRANSPORT_HEADER_LENGTH, @@ -1133,16 +1122,13 @@ TEST_F(SenderTest, Message_append_truncated) ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); Sender::Message msg(sender, 0); - char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1{buf + MAX_RAW_PACKET_LENGTH}; - - const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; - const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; + char buf[MAX_RAW_PACKET_LENGTH]; + Homa::Mock::MockDriver::PacketBuf packetBuf{buf}; char source[] = "Hello, world!"; - setMessagePacket(&msg, msg.MAX_MESSAGE_PACKETS - 1, &packet0); - packet0.length = msg.TRANSPORT_HEADER_LENGTH + msg.PACKET_DATA_LENGTH - 7; + setMessagePacket(&msg, msg.MAX_MESSAGE_PACKETS - 1, packetBuf.toPacket()); + Driver::Packet& packet = msg.packets[msg.MAX_MESSAGE_PACKETS - 1]; + packet.length = msg.TRANSPORT_HEADER_LENGTH + msg.PACKET_DATA_LENGTH - 7; msg.messageLength = msg.PACKET_DATA_LENGTH * msg.MAX_MESSAGE_PACKETS - 7; EXPECT_EQ(1U, msg.numPackets); @@ -1152,7 +1138,7 @@ TEST_F(SenderTest, Message_append_truncated) msg.messageLength); EXPECT_EQ(1U, msg.numPackets); EXPECT_EQ(msg.TRANSPORT_HEADER_LENGTH + msg.PACKET_DATA_LENGTH, - packet0.length); + packet.length); EXPECT_TRUE(std::memcmp(buf + MAX_RAW_PACKET_LENGTH - 7, source, 7) == 0); EXPECT_EQ(1U, handler.messages.size()); @@ -1180,7 +1166,7 @@ TEST_F(SenderTest, Message_getStatus) TEST_F(SenderTest, Message_length) { ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); msg.messageLength = 200; msg.start = 20; EXPECT_EQ(180U, msg.length()); @@ -1191,14 +1177,16 @@ TEST_F(SenderTest, Message_prepend) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; EXPECT_CALL(mockDriver, allocPacket) - .WillOnce(Return(&packet0)) - .WillOnce(Return(&packet1)); + .WillOnce(Return(packet0)) + .WillOnce(Return(packet1)); msg.reserve(PACKET_DATA_LENGTH + 7); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.start); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.messageLength); @@ -1226,8 +1214,8 @@ TEST_F(SenderTest, Message_reserve) { Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1236,26 +1224,28 @@ TEST_F(SenderTest, Message_reserve) EXPECT_EQ(0U, msg.messageLength); EXPECT_EQ(0U, msg.numPackets); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&packet0)); + EXPECT_CALL(mockDriver, allocPacket) + .WillOnce(Return(packetBuf0.toPacket())); msg.reserve(PACKET_DATA_LENGTH - 7); EXPECT_EQ(PACKET_DATA_LENGTH - 7, msg.start); EXPECT_EQ(PACKET_DATA_LENGTH - 7, msg.messageLength); EXPECT_EQ(1U, msg.numPackets); - EXPECT_EQ(&packet0, msg.getPacket(0)); - EXPECT_EQ(TRANSPORT_HEADER_LENGTH + PACKET_DATA_LENGTH - 7, packet0.length); + EXPECT_EQ(TRANSPORT_HEADER_LENGTH + PACKET_DATA_LENGTH - 7, + msg.packets[0].length); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&packet1)); + EXPECT_CALL(mockDriver, allocPacket) + .WillOnce(Return(packetBuf1.toPacket())); msg.reserve(14); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.start); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.messageLength); EXPECT_EQ(2U, msg.numPackets); - EXPECT_EQ(TRANSPORT_HEADER_LENGTH + PACKET_DATA_LENGTH, packet0.length); - EXPECT_EQ(&packet1, msg.getPacket(1)); - EXPECT_EQ(TRANSPORT_HEADER_LENGTH + 7, packet1.length); + EXPECT_EQ(TRANSPORT_HEADER_LENGTH + PACKET_DATA_LENGTH, + msg.packets[0].length); + EXPECT_EQ(TRANSPORT_HEADER_LENGTH + 7, msg.packets[1].length); } TEST_F(SenderTest, Message_send) @@ -1266,8 +1256,8 @@ TEST_F(SenderTest, Message_send) TEST_F(SenderTest, Message_getPacket) { Sender::Message msg(sender, 0); - Driver::Packet* packet = (Driver::Packet*)42; - msg.packets[0] = packet; + msg.packets[0] = {}; + Driver::Packet* packet = &msg.packets[0]; EXPECT_EQ(nullptr, msg.getPacket(0)); @@ -1281,19 +1271,21 @@ TEST_F(SenderTest, Message_getOrAllocPacket) // TODO(cstlee): cleanup Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0{buf + 0}; - Homa::Mock::MockDriver::MockPacket packet1{buf + 2048}; + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); EXPECT_FALSE(msg.occupied.test(0)); EXPECT_EQ(0U, msg.numPackets); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&packet0)); + EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(packet0)); - EXPECT_TRUE(&packet0 == msg.getOrAllocPacket(0)); + EXPECT_EQ(packet0.descriptor, msg.getOrAllocPacket(0)->descriptor); EXPECT_TRUE(msg.occupied.test(0)); EXPECT_EQ(1U, msg.numPackets); - EXPECT_TRUE(&packet0 == msg.getOrAllocPacket(0)); + EXPECT_EQ(packet0.descriptor, msg.getOrAllocPacket(0)->descriptor); EXPECT_TRUE(msg.occupied.test(0)); EXPECT_EQ(1U, msg.numPackets); @@ -1341,7 +1333,8 @@ TEST_F(SenderTest, sendMessage_basic) dynamic_cast(sender->allocMessage(sport)); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); - setMessagePacket(message, 0, &mockPacket); + setMessagePacket(message, 0, mockPacket); + Driver::Packet& mockPacket = message->packets[0]; message->messageLength = 420; mockPacket.length = message->messageLength + message->TRANSPORT_HEADER_LENGTH; @@ -1387,23 +1380,26 @@ TEST_F(SenderTest, sendMessage_basic) EXPECT_EQ(policy.priority, mockPriority); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); - EXPECT_FALSE(sender->sendReady.load()); + EXPECT_FALSE(sender->sendReady); } TEST_F(SenderTest, sendMessage_multipacket) { char payload0[1027]; char payload1[1027]; - Homa::Mock::MockDriver::MockPacket packet0{payload0}; - Homa::Mock::MockDriver::MockPacket packet1{payload1}; + Homa::Mock::MockDriver::PacketBuf packetBuf0{payload0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{payload1}; Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = dynamic_cast(sender->allocMessage(0)); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); - setMessagePacket(message, 0, &packet0); - setMessagePacket(message, 1, &packet1); + setMessagePacket(message, 0, packetBuf0.toPacket()); + setMessagePacket(message, 1, packetBuf1.toPacket()); + Driver::Packet& packet0 = message->packets[0]; + Driver::Packet& packet1 = message->packets[1]; + message->messageLength = 1420; packet0.length = 1000 + 31; packet1.length = 420 + 31; @@ -1446,7 +1442,7 @@ TEST_F(SenderTest, sendMessage_multipacket) // Check sendQueue metadata Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); - EXPECT_TRUE(sender->sendReady.load()); + EXPECT_TRUE(sender->sendReady); } TEST_F(SenderTest, sendMessage_missingPacket) @@ -1455,7 +1451,7 @@ TEST_F(SenderTest, sendMessage_missingPacket) sender->nextMessageSequenceNumber}; Sender::Message* message = dynamic_cast(sender->allocMessage(0)); - setMessagePacket(message, 1, &mockPacket); + setMessagePacket(message, 1, mockPacket); Core::Policy::Unscheduled policy = {1, 1000, 2}; ON_CALL(mockPolicyManager, getUnscheduledPolicy(_, _)) .WillByDefault(Return(policy)); @@ -1472,10 +1468,10 @@ TEST_F(SenderTest, sendMessage_unscheduledLimit) Sender::Message* message = dynamic_cast(sender->allocMessage(0)); for (int i = 0; i < 9; ++i) { - setMessagePacket(message, i, &mockPacket); + mockPacket.length = 1000 + sizeof(Protocol::Packet::DataHeader); + setMessagePacket(message, i, mockPacket); } message->messageLength = 9000; - mockPacket.length = 1000 + sizeof(Protocol::Packet::DataHeader); SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policy = {1, 4500, 2}; EXPECT_EQ(9U, message->numPackets); @@ -1618,9 +1614,9 @@ TEST_F(SenderTest, checkPingTimeouts_basic) EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), _, _)).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&mockPacket), _, _)).Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); uint64_t nextTimeout = sender->checkPingTimeouts(); @@ -1661,13 +1657,15 @@ TEST_F(SenderTest, trySend_basic) dynamic_cast(sender->allocMessage(0)); Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; SenderTest::addMessage(sender, id, message, true, 3); - Homa::Mock::MockDriver::MockPacket* packet[5]; + Driver::Packet packet[5]; + uint64_t waitUntil; const uint32_t PACKET_SIZE = sender->driver->getMaxPayloadSize(); const uint32_t PACKET_DATA_SIZE = PACKET_SIZE - message->TRANSPORT_HEADER_LENGTH; for (int i = 0; i < 5; ++i) { - packet[i] = new Homa::Mock::MockDriver::MockPacket{payload}; - packet[i]->length = PACKET_SIZE; + auto* packetBuf = new Homa::Mock::MockDriver::PacketBuf{payload}; + packet[i] = packetBuf->toPacket(); + packet[i].length = PACKET_SIZE; setMessagePacket(message, i, packet[i]); info->unsentBytes += PACKET_DATA_SIZE; } @@ -1681,9 +1679,9 @@ TEST_F(SenderTest, trySend_basic) EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); // 3 granted packets; 2 will send; queue limit reached. - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]), _, _)); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]), _, _)); - sender->trySend(); // < test call + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[0]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[1]), _, _)); + sender->trySend(&waitUntil); // < test call EXPECT_TRUE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); EXPECT_EQ(3U, info->packetsGranted); @@ -1694,8 +1692,8 @@ TEST_F(SenderTest, trySend_basic) Mock::VerifyAndClearExpectations(&mockDriver); // 1 packet to be sent; grant limit reached. - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]), _, _)); - sender->trySend(); // < test call + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[2]), _, _)); + sender->trySend(&waitUntil); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); EXPECT_EQ(3U, info->packetsGranted); @@ -1708,7 +1706,7 @@ TEST_F(SenderTest, trySend_basic) // No additional grants; spurious ready hint. EXPECT_CALL(mockDriver, sendPacket).Times(0); sender->sendReady = true; - sender->trySend(); // < test call + sender->trySend(&waitUntil); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); EXPECT_EQ(3U, info->packetsGranted); @@ -1721,9 +1719,9 @@ TEST_F(SenderTest, trySend_basic) // 2 more granted packets; will finish. info->packetsGranted = 5; sender->sendReady = true; - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[3]), _, _)); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[4]), _, _)); - sender->trySend(); // < test call + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[3]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[4]), _, _)); + sender->trySend(&waitUntil); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_EQ(5U, info->packetsGranted); @@ -1734,7 +1732,8 @@ TEST_F(SenderTest, trySend_basic) Mock::VerifyAndClearExpectations(&mockDriver); for (int i = 0; i < 5; ++i) { - delete packet[i]; + uintptr_t packetBuf = packet[i].descriptor; + delete (Homa::Mock::MockDriver::PacketBuf*)packetBuf; } } @@ -1742,17 +1741,18 @@ TEST_F(SenderTest, trySend_multipleMessages) { Sender::Message* message[3]; Sender::QueuedMessageInfo* info[3]; - Homa::Mock::MockDriver::MockPacket* packet[3]; + Driver::Packet packet[3]; for (uint64_t i = 0; i < 3; ++i) { Protocol::MessageId id = {22, 10 + i}; message[i] = dynamic_cast(sender->allocMessage(0)); info[i] = &message[i]->queuedMessageInfo; SenderTest::addMessage(sender, id, message[i], true, 1); - packet[i] = new Homa::Mock::MockDriver::MockPacket{payload}; - packet[i]->length = sender->driver->getMaxPayloadSize() / 4; + auto* packetBuf = new Homa::Mock::MockDriver::PacketBuf{payload}; + packet[i] = packetBuf->toPacket(); + packet[i].length = sender->driver->getMaxPayloadSize() / 4; setMessagePacket(message[i], 0, packet[i]); info[i]->unsentBytes += - (packet[i]->length - message[i]->TRANSPORT_HEADER_LENGTH); + (packet[i].length - message[i]->TRANSPORT_HEADER_LENGTH); message[i]->state = Homa::OutMessage::Status::IN_PROGRESS; } sender->sendReady = true; @@ -1764,19 +1764,21 @@ TEST_F(SenderTest, trySend_multipleMessages) // Message 1: Will reach grant limit EXPECT_EQ(1, info[1]->packetsGranted); info[1]->packetsSent = 0; - setMessagePacket(message[1], 1, nullptr); + setMessagePacket(message[1], 1, {}); EXPECT_EQ(2, message[1]->numPackets); // Message 2: Will finish EXPECT_EQ(1, info[2]->packetsGranted); info[2]->packetsSent = 0; - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]), _, _)); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]), _, _)); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[0]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[1]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[2]), _, _)); - sender->trySend(); + uint64_t waitUntil; + bool sendReady = sender->trySend(&waitUntil); + EXPECT_FALSE(sendReady); EXPECT_EQ(1U, info[0]->packetsSent); EXPECT_EQ(Homa::OutMessage::Status::SENT, message[0]->state); EXPECT_FALSE(sender->sendQueue.contains(&info[0]->sendQueueNode)); @@ -1788,33 +1790,14 @@ TEST_F(SenderTest, trySend_multipleMessages) EXPECT_FALSE(sender->sendQueue.contains(&info[2]->sendQueueNode)); } -TEST_F(SenderTest, trySend_alreadyRunning) -{ - Protocol::MessageId id = {42, 1}; - Sender::Message* message = - dynamic_cast(sender->allocMessage(0)); - Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; - SenderTest::addMessage(sender, id, message, true, 1); - setMessagePacket(message, 0, &mockPacket); - message->messageLength = 1000; - EXPECT_EQ(1U, message->numPackets); - EXPECT_EQ(1, info->packetsGranted); - EXPECT_EQ(0, info->packetsSent); - - sender->sending.test_and_set(); - - EXPECT_CALL(mockDriver, sendPacket).Times(0); - - sender->trySend(); - - EXPECT_EQ(0, info->packetsSent); -} - TEST_F(SenderTest, trySend_nothingToSend) { EXPECT_TRUE(sender->sendQueue.empty()); EXPECT_CALL(mockDriver, sendPacket).Times(0); - sender->trySend(); + uint64_t waitUntil = 0; + bool sendReady = sender->trySend(&waitUntil); + EXPECT_FALSE(sendReady); + EXPECT_EQ(waitUntil, 0); } } // namespace diff --git a/src/Shenango.cc b/src/Shenango.cc new file mode 100644 index 0000000..a57794e --- /dev/null +++ b/src/Shenango.cc @@ -0,0 +1,272 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "Homa/Shenango.h" +#include "Debug.h" +#include "Homa/Homa.h" + +using namespace Homa; + +/** + * Shorthand for declaring "extern" function pointers to Shenango functions. + * These functions pointers will be initialized on the Shenango side in homa.c. + */ +#define DECLARE_SHENANGO_FUNC(ReturnType, MethodName, ...) \ + extern ReturnType (*shenango_##MethodName)(__VA_ARGS__); + +/** + * Fast thread-local slab-based memory allocation. + */ +DECLARE_SHENANGO_FUNC(void*, smalloc, size_t) +DECLARE_SHENANGO_FUNC(void, sfree, void*) + +/** + * Protect RCU read-side critical sections. + */ +DECLARE_SHENANGO_FUNC(void, rcu_read_lock) +DECLARE_SHENANGO_FUNC(void, rcu_read_unlock) + +/** + * Allocate a Shenango mbuf struct to hold an egress Homa packet. + */ +DECLARE_SHENANGO_FUNC(void*, homa_tx_alloc_mbuf, void**) + +/** + * Free a packet buffer allocated earlier. + */ +DECLARE_SHENANGO_FUNC(void, mbuf_free, void*) + +/** + * Transmit an IP packet using Shenango's driver stack. + */ +DECLARE_SHENANGO_FUNC(int, homa_tx_ip, uintptr_t, void*, int32_t, uint8_t, + uint32_t, uint8_t) + +/** + * Deliver an ingress message to a homa socket in Shenango. + */ +DECLARE_SHENANGO_FUNC(void, homa_mb_deliver, void*, homa_inmsg) + +/** + * Return the number of bytes queued up in the transmit queue. + */ +DECLARE_SHENANGO_FUNC(uint32_t, homa_queued_bytes) + +/** + * Find a socket that matches the 5-tuple. + */ +DECLARE_SHENANGO_FUNC(void*, trans_table_lookup, uint8_t, SocketAddress, + SocketAddress) + +/** + * A simple shim driver that translates Driver operations to Shenango + * functions. + */ +class ShenangoDriver final : public Driver { + public: + explicit ShenangoDriver(uint8_t proto, uint32_t local_ip, + uint32_t max_payload, uint32_t link_speed) + : Driver() + , proto(proto) + , local_ip{local_ip} + , max_payload(max_payload) + , link_speed(link_speed) + {} + + Packet allocPacket() override + { + void* payload; + void* mbuf = shenango_homa_tx_alloc_mbuf(&payload); + return Packet{(uintptr_t)mbuf, payload, 0}; + } + + void sendPacket(Packet* packet, IpAddress destination, + int priority) override + { + shenango_homa_tx_ip(packet->descriptor, packet->payload, packet->length, + proto, (uint32_t)destination, (uint8_t)priority); + } + + uint32_t receivePackets(uint32_t maxPackets, Packet receivedPackets[], + IpAddress sourceAddresses[]) override + { + (void)maxPackets; + (void)receivedPackets; + (void)sourceAddresses; + PANIC("receivePackets must not be called when used with Shenango"); + return 0; + } + + void releasePackets(Packet packets[], uint16_t numPackets) override + { + for (uint16_t i = 0; i < numPackets; i++) { + shenango_mbuf_free((void*)packets[i].descriptor); + } + } + + uint32_t getMaxPayloadSize() override + { + return max_payload; + } + + uint32_t getBandwidth() override + { + return link_speed; + } + + IpAddress getLocalAddress() override + { + return local_ip; + } + + uint32_t getQueuedBytes() override + { + return shenango_homa_queued_bytes(); + } + + private: + /// Protocol number reserved for Homa; defined as IPPROTO_HOMA in Shenango. + const uint8_t proto; + + /// Local IP address of the driver. + const IpAddress local_ip; + + /// # bytes in a payload + const uint32_t max_payload; + + /// Effective network bandwidth, in Mbits/second. + const uint32_t link_speed; +}; + +homa_driver +homa_driver_create(uint8_t proto, uint32_t local_ip, uint32_t max_payload, + uint32_t link_speed) +{ + void* driver = new ShenangoDriver(proto, local_ip, max_payload, link_speed); + return homa_driver{driver}; +} + +void +homa_driver_free(homa_driver drv) +{ + delete static_cast(drv.p); +} + +/** + * An almost trivial implementation of Mailbox. This class is essentially + * a wrapper around a socket table entry in Shenango (i.e., struct trans_entry). + * + */ +class ShenangoMailbox final : public Mailbox { + public: + explicit ShenangoMailbox(void* trans_entry) + : trans_entry(trans_entry) + {} + + ~ShenangoMailbox() override = default; + + void close() override + { + this->~ShenangoMailbox(); + shenango_sfree(this); + shenango_rcu_read_unlock(); + } + + void deliver(InMessage* message) override + { + shenango_homa_mb_deliver(trans_entry, homa_inmsg{message}); + } + + InMessage* retrieve(bool blocking) override + { + (void)blocking; + PANIC("Shenango should never call Homa::Socket::receive"); + } + + void socketShutdown() override + { + PANIC("Shenango should never call Homa::Socket::shutdown"); + } + + private: + /// An opaque pointer to "struct trans_entry" in Shenango. + void* const trans_entry; +}; + +/** + * An almost trivial implementation of MailboxDir that uses Shenango's RCU + * mechanism to prevent a mailbox from being destroyed until all readers have + * closed it. + * + * Note: Shenango doesn't use Homa::Socket to receive messages, so the only + * method that has a meaningful implementation is open(). + */ +class ShenangoMailboxDir final : MailboxDir { + public: + explicit ShenangoMailboxDir(uint8_t proto, uint32_t local_ip) + : proto(proto) + , local_ip{local_ip} + {} + + ~ShenangoMailboxDir() override = default; + + Mailbox* alloc(uint16_t port) override + { + // Shenango doesn't rely on Homa::Socket to receive messages, + // so there is no need to assign a real mailbox to SocketImpl. + static ShenangoMailbox dummyMailbox(nullptr); + (void)port; + return &dummyMailbox; + } + + Mailbox* open(uint16_t port) override + { + SocketAddress laddr = {local_ip, port}; + shenango_rcu_read_lock(); + void* trans_entry = shenango_trans_table_lookup(proto, laddr, {}); + if (!trans_entry) { + return nullptr; + } + void* backing = shenango_smalloc(sizeof(ShenangoMailbox)); + return new (backing) ShenangoMailbox(trans_entry); + } + + bool remove(uint16_t port) override + { + // Nothing to do; Shenango is responsible for taking care of freeing + // the resources related to homa sockets. + (void)port; + return true; + } + + /// Protocol number reserved for Homa; defined as IPPROTO_HOMA in Shenango. + const uint8_t proto; + + /// Local IP address of the transport. + const IpAddress local_ip; +}; + +homa_mailbox_dir +homa_mb_dir_create(uint8_t proto, uint32_t local_ip) +{ + void* dir = new ShenangoMailboxDir(proto, local_ip); + return homa_mailbox_dir{dir}; +} + +void +homa_mb_dir_free(homa_mailbox_dir mailbox_dir) +{ + delete static_cast(mailbox_dir.p); +} diff --git a/src/SimpleMailboxDir.cc b/src/SimpleMailboxDir.cc new file mode 100644 index 0000000..e74ebd3 --- /dev/null +++ b/src/SimpleMailboxDir.cc @@ -0,0 +1,171 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include +#include +#include "SpinLock.h" + +namespace Homa { + +/** + * A simple reference implementation of Homa::Mailbox that uses polling to + * detect incoming messages. + */ +class MailboxImpl : public Mailbox { + public: + explicit MailboxImpl(); + ~MailboxImpl() override; + void close() override; + void deliver(InMessage* message) override; + InMessage* retrieve(bool blocking) override; + void socketShutdown() override; + + /// Protects the queue + SpinLock mutex; + + /// Keeps track of the number of calls to open() without paired close(). + /// It's initialized to one because, intuitively, a Socket must keep its + /// mailbox "open" in order to retrieve incoming messages. + std::atomic openers; + + /// Has the corresponding socket been shut down? + std::atomic shutdown; + + /// List of completely received messages. + std::list queue; +}; + +MailboxImpl::MailboxImpl() + : mutex() + , openers(1) + , shutdown() + , queue() +{} + +MailboxImpl::~MailboxImpl() +{ + while (!queue.empty()) { + InMessage* message = queue.front(); + queue.pop_front(); + Homa::unique_ptr deleter(message); + } +} + +/// See Homa::Mailbox::close() +void +MailboxImpl::close() +{ + if (openers.fetch_sub(1, std::memory_order_release) == 1) { + std::atomic_thread_fence(std::memory_order_acquire); + + // MailboxImpl was instantiated via "new" in SimpleMailboxDir::alloc. + delete this; + } +} + +/// See Homa::Mailbox::deliver() +void +MailboxImpl::deliver(InMessage* message) +{ + SpinLock::Lock _(mutex); + queue.push_back(message); +} + +/// See Homa::Mailbox::retrieve() +InMessage* +MailboxImpl::retrieve(bool blocking) +{ + InMessage* message = nullptr; + do { + SpinLock::Lock _(mutex); + if (!queue.empty()) { + message = queue.front(); + queue.pop_front(); + } + } while (blocking && !shutdown.load(std::memory_order_relaxed)); + return message; +} + +/// See Homa::Mailbox::socketShutdown() +void +MailboxImpl::socketShutdown() +{ + shutdown.store(true); +} + +SimpleMailboxDir::SimpleMailboxDir() + : mutex(new SpinLock()) + , map() +{} + +SimpleMailboxDir::~SimpleMailboxDir() +{ + for (auto entry : map) { + MailboxImpl* mailbox = entry.second; + mailbox->close(); + } +} + +Mailbox* +SimpleMailboxDir::alloc(uint16_t port) +{ + MailboxImpl* mailbox = nullptr; + SpinLock::Lock _(*mutex); + auto it = map.find(port); + if (it == map.end()) { + mailbox = new MailboxImpl(); + map[port] = mailbox; + } + return mailbox; +} + +Mailbox* +SimpleMailboxDir::open(uint16_t port) +{ + MailboxImpl* mailbox = nullptr; + { + // Look up the mailbox + SpinLock::Lock _(*mutex); + auto it = map.find(port); + if (it != map.end()) { + mailbox = it->second; + } + } + + // Increment the reference count of the mailbox. + if (mailbox) { + mailbox->openers.fetch_add(1, std::memory_order_relaxed); + } + return mailbox; +} + +bool +SimpleMailboxDir::remove(uint16_t port) +{ + MailboxImpl* mailbox; + { + SpinLock::Lock _(*mutex); + auto it = map.find(port); + if (it == map.end()) { + return false; + } + mailbox = it->second; + map.erase(it); + } + mailbox->close(); + return true; +} + +} // namespace Homa diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index a380944..16fd42b 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -14,17 +14,11 @@ */ #include "TransportImpl.h" - -#include -#include -#include - #include "Cycles.h" #include "Perf.h" #include "Protocol.h" -namespace Homa { -namespace Core { +namespace Homa::Core { // Basic timeout unit. const uint64_t BASE_TIMEOUT_US = 2000; @@ -36,15 +30,18 @@ const uint64_t PING_INTERVAL_US = 3 * BASE_TIMEOUT_US; const uint64_t RESEND_INTERVAL_US = BASE_TIMEOUT_US; /** - * Construct an instances of a Homa-based transport. + * Construct an instance of a Homa-based transport. * * @param driver * Driver with which this transport should send and receive packets. + * @param mailboxDir + * Mailbox directory with which this transport should deliver messages. * @param transportId * This transport's unique identifier in the group of transports among * which this transport will communicate. */ -TransportImpl::TransportImpl(Driver* driver, uint64_t transportId) +TransportImpl::TransportImpl(Driver* driver, MailboxDir* mailboxDir, + uint64_t transportId) : transportId(transportId) , driver(driver) , policyManager(new Policy::Manager(driver)) @@ -52,10 +49,24 @@ TransportImpl::TransportImpl(Driver* driver, uint64_t transportId) PerfUtils::Cycles::fromMicroseconds(MESSAGE_TIMEOUT_US), PerfUtils::Cycles::fromMicroseconds(PING_INTERVAL_US))) , receiver( - new Receiver(driver, policyManager.get(), + new Receiver(driver, mailboxDir, policyManager.get(), PerfUtils::Cycles::fromMicroseconds(MESSAGE_TIMEOUT_US), PerfUtils::Cycles::fromMicroseconds(RESEND_INTERVAL_US))) - , nextTimeoutCycles(0) + , mailboxDir(mailboxDir) +{} + +/** + * Construct an instance of a Homa-based transport for unit testing. + */ +TransportImpl::TransportImpl(Driver* driver, MailboxDir* mailboxDir, + Sender* sender, Receiver* receiver, + uint64_t transportId) + : transportId(transportId) + , driver(driver) + , policyManager(new Policy::Manager(driver)) + , sender(sender) + , receiver(receiver) + , mailboxDir(mailboxDir) {} /** @@ -63,55 +74,41 @@ TransportImpl::TransportImpl(Driver* driver, uint64_t transportId) */ TransportImpl::~TransportImpl() = default; -/// See Homa::Transport::poll() +/// See Homa::Transport::free() void -TransportImpl::poll() +TransportImpl::free() { - // Receive and dispatch incoming packets. - processPackets(); - - // Allow sender and receiver to make incremental progress. - sender->poll(); - receiver->poll(); - - if (PerfUtils::Cycles::rdtsc() >= nextTimeoutCycles.load()) { - uint64_t requestedTimeoutCycles; - requestedTimeoutCycles = sender->checkTimeouts(); - nextTimeoutCycles.store(requestedTimeoutCycles); - requestedTimeoutCycles = receiver->checkTimeouts(); - if (nextTimeoutCycles.load() > requestedTimeoutCycles) { - nextTimeoutCycles.store(requestedTimeoutCycles); - } - } + // We simply call "delete this" here because the only way to instantiate + // a Core::TransportImpl instance is via "new" in Transport::create(). + // An alternative would be to provide a static free() method that takes + // a pointer to Transport, the downside of this approach is that we must + // cast the argument to TransportImpl* because polymorphic deletion is + // disabled on the Transport interface. + delete this; } -/** - * Helper method which receives a burst of incoming packets and process them - * through the transport protocol. Pulled out of TransportImpl::poll() to - * simplify unit testing. - */ -void -TransportImpl::processPackets() +/// See Homa::Transport::open() +Homa::unique_ptr +TransportImpl::open(uint16_t port) { - // Keep track of time spent doing active processing versus idle. - uint64_t cycles = PerfUtils::Cycles::rdtsc(); - - const int MAX_BURST = 32; - Driver::Packet* packets[MAX_BURST]; - IpAddress srcAddrs[MAX_BURST]; - int numPackets = driver->receivePackets(MAX_BURST, packets, srcAddrs); - for (int i = 0; i < numPackets; ++i) { - processPacket(packets[i], srcAddrs[i]); + Mailbox* mailbox = mailboxDir->alloc(port); + if (!mailbox) { + return nullptr; } + SocketImpl* socket = new SocketImpl(this, port, mailbox); + return Homa::unique_ptr(socket); +} - cycles = PerfUtils::Cycles::rdtsc() - cycles; - if (numPackets > 0) { - Perf::counters.active_cycles.add(cycles); - } else { - Perf::counters.idle_cycles.add(cycles); - } +/// See Homa::Transport::checkTimeouts() +uint64_t +TransportImpl::checkTimeouts() +{ + uint64_t requestedTimeoutCycles = + std::min(sender->checkTimeouts(), receiver->checkTimeouts()); + return requestedTimeoutCycles; } +/// See Homa::Transport::processPacket() void TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) { @@ -156,5 +153,88 @@ TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) } } -} // namespace Core -} // namespace Homa +/// See Homa::Transport::registerCallbackSendReady() +void +TransportImpl::registerCallbackSendReady(Callback func) +{ + sender->registerCallbackSendReady(func); +} + +/// See Homa::Transport::trySend() +bool +TransportImpl::trySend(uint64_t* waitUntil) +{ + return sender->trySend(waitUntil); +} + +/// See Homa::Transport::trySendGrants() +bool +TransportImpl::trySendGrants() +{ + return receiver->trySendGrants(); +} + +/** + * Construct an instance of a Homa socket. + * + * @param transport + * Transport that owns the socket. + * @param port + * Local port number of the socket. + * @param mailbox + * Mailbox assigned to this socket. + */ +TransportImpl::SocketImpl::SocketImpl(TransportImpl* transport, uint16_t port, + Mailbox* mailbox) + : Socket() + , disabled() + , localAddress{transport->getDriver()->getLocalAddress(), port} + , mailbox(mailbox) + , transport(transport) +{} + +/// See Homa::Socket::alloc() +unique_ptr +TransportImpl::SocketImpl::alloc() +{ + if (isShutdown()) { + return nullptr; + } + OutMessage* outMessage = transport->sender->allocMessage(localAddress.port); + return unique_ptr(outMessage); +} + +/// See Homa::Socket::close() +void +TransportImpl::SocketImpl::close() +{ + bool success = transport->mailboxDir->remove(localAddress.port); + if (!success) { + ERROR("Failed to remove mailbox (port = %u)", localAddress.port); + } + + // Destruct the socket (the mailbox may be still in use). + // Note: it's actually legal to say "delete this" from a member function: + // https://isocpp.org/wiki/faq/freestore-mgmt#delete-this + delete this; +} + +/// See Homa::Socket::receive() +unique_ptr +TransportImpl::SocketImpl::receive(bool blocking) +{ + if (isShutdown()) { + return nullptr; + } + return unique_ptr(mailbox->retrieve(blocking)); +} + +/// See Homa::Socket::shutdown() +void +TransportImpl::SocketImpl::shutdown() +{ + disabled.store(true); + mailbox->socketShutdown(); +} + +} // namespace Homa::Core diff --git a/src/TransportImpl.h b/src/TransportImpl.h index ad46f99..d0e06b0 100644 --- a/src/TransportImpl.h +++ b/src/TransportImpl.h @@ -20,10 +20,6 @@ #include #include -#include -#include -#include -#include #include "ObjectPool.h" #include "Policy.h" @@ -34,32 +30,26 @@ /** * Homa */ -namespace Homa { -namespace Core { +namespace Homa::Core { /** * Internal implementation of Homa::Transport. - * */ -class TransportImpl : public Transport { +class TransportImpl final : public Transport { public: - explicit TransportImpl(Driver* driver, uint64_t transportId); + explicit TransportImpl(Driver* driver, MailboxDir* mailboxDir, + uint64_t transportId); + explicit TransportImpl(Driver* driver, MailboxDir* mailboxDir, + Sender* sender, Receiver* receiver, + uint64_t transportId); ~TransportImpl(); - - /// See Homa::Transport::alloc() - virtual Homa::unique_ptr alloc(uint16_t sourcePort) - { - Homa::OutMessage* outMessage = sender->allocMessage(sourcePort); - return Homa::unique_ptr(outMessage); - } - - /// See Homa::Transport::receive() - virtual Homa::unique_ptr receive() - { - return Homa::unique_ptr(receiver->receiveMessage()); - } - - virtual void poll(); + void free() override; + Homa::unique_ptr open(uint16_t port) override; + uint64_t checkTimeouts() override; + void processPacket(Driver::Packet* packet, IpAddress source) override; + void registerCallbackSendReady(Callback func) override; + bool trySend(uint64_t* waitUntil) override; + bool trySendGrants() override; /// See Homa::Transport::getDriver() virtual Driver* getDriver() @@ -73,14 +63,55 @@ class TransportImpl : public Transport { return transportId; } - private: - void processPackets(); - void processPacket(Driver::Packet* packet, IpAddress source); + /** + * Internal implementation of Homa::Socket. + * + * @sa + * TransportImpl::socketMap + */ + class SocketImpl final : public Socket { + public: + explicit SocketImpl(TransportImpl* transport, uint16_t port, + Mailbox* mailbox); + virtual ~SocketImpl() = default; + + Homa::unique_ptr alloc() override; + void close() override; + Homa::unique_ptr receive(bool blocking) override; + void shutdown() override; + + /// See Homa::Socket::isShutdown() + bool isShutdown() const override + { + return disabled.load(std::memory_order_relaxed); + } + + /// See Homa::Socket::getLocalAddress() + Address getLocalAddress() const override + { + return localAddress; + } + + private: + /// Has the socket been shut down? + std::atomic disabled; + + /// Local address of the socket. + Address localAddress; + + /// Mailbox assigned to this socket. Not owned by this class. + Mailbox* mailbox; + + /// Transport that owns this socket. + TransportImpl* transport; + }; + private: /// Unique identifier for this transport. - const std::atomic transportId; + const uint64_t transportId; /// Driver from which this transport will send and receive packets. + /// Not owned by this class. Driver* const driver; /// Module which manages the network packet priority policy. @@ -92,11 +123,11 @@ class TransportImpl : public Transport { /// Module which receives packets and forms them into messages. std::unique_ptr receiver; - /// Caches the next cycle time that timeouts will need to rechecked. - std::atomic nextTimeoutCycles; + /// Module which keeps track of mailboxes currently in use. Not owned by + /// this class (we don't even know whether it's instantiated by "new"). + MailboxDir* const mailboxDir; }; -} // namespace Core -} // namespace Homa +} // namespace Homa::Core #endif // HOMA_CORE_TRANSPORT_H diff --git a/src/TransportImplTest.cc b/src/TransportImplTest.cc index a0f66c6..cc8b887 100644 --- a/src/TransportImplTest.cc +++ b/src/TransportImplTest.cc @@ -16,6 +16,7 @@ #include #include +#include "Homa/Utils/TransportPoller.h" #include "Mock/MockDriver.h" #include "Mock/MockReceiver.h" #include "Mock/MockSender.h" @@ -34,133 +35,156 @@ using ::testing::NiceMock; using ::testing::Return; using ::testing::SetArrayArgument; +/** + * Defines a matcher EqPacket(p) to match two Driver::Packet* by their + * underlying packet buffer descriptors. + */ +MATCHER_P(EqPacket, p, "") +{ + return arg->descriptor == p->descriptor; +} + class TransportImplTest : public ::testing::Test { public: TransportImplTest() - : mockDriver() - , transport(new TransportImpl(&mockDriver, 22)) - , mockSender( - new NiceMock(22, &mockDriver, 0, 0)) - , mockReceiver( - new NiceMock(&mockDriver, 0, 0)) + : mockDriver(allocMockDriver()) + , mockSender(new NiceMock(22, mockDriver, 0, 0)) + , mockReceiver(new NiceMock(mockDriver, 0, 0)) + , transport(new TransportImpl(mockDriver, nullptr, mockSender, + mockReceiver, 22)) + , poller(transport) { - transport->sender.reset(mockSender); - transport->receiver.reset(mockReceiver); - ON_CALL(mockDriver, getBandwidth).WillByDefault(Return(8000)); - ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1024)); PerfUtils::Cycles::mockTscValue = 10000; } ~TransportImplTest() { delete transport; + delete mockDriver; PerfUtils::Cycles::mockTscValue = 0; } - NiceMock mockDriver; - TransportImpl* transport; + NiceMock* allocMockDriver() + { + auto driver = new NiceMock(); + ON_CALL(*driver, getBandwidth).WillByDefault(Return(8000)); + ON_CALL(*driver, getMaxPayloadSize).WillByDefault(Return(1024)); + return driver; + } + + NiceMock* mockDriver; NiceMock* mockSender; NiceMock* mockReceiver; + TransportImpl* transport; + TransportPoller poller; }; TEST_F(TransportImplTest, poll) { - EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); - EXPECT_CALL(*mockSender, poll).Times(1); - EXPECT_CALL(*mockReceiver, poll).Times(1); + EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); + EXPECT_CALL(*mockSender, trySend).Times(1); + EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10000)); EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); - transport->poll(); + poller.poll(); - EXPECT_EQ(10000U, transport->nextTimeoutCycles); + EXPECT_EQ(10000U, poller.nextTimeoutCycles); - EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); - EXPECT_CALL(*mockSender, poll).Times(1); - EXPECT_CALL(*mockReceiver, poll).Times(1); + EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); + EXPECT_CALL(*mockSender, trySend).Times(1); + EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10200)); EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); - transport->poll(); + poller.poll(); - EXPECT_EQ(10100U, transport->nextTimeoutCycles); + EXPECT_EQ(10100U, poller.nextTimeoutCycles); - EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); - EXPECT_CALL(*mockSender, poll).Times(1); - EXPECT_CALL(*mockReceiver, poll).Times(1); + EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); + EXPECT_CALL(*mockSender, trySend).Times(1); + EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); EXPECT_CALL(*mockSender, checkTimeouts).Times(0); EXPECT_CALL(*mockReceiver, checkTimeouts).Times(0); - transport->poll(); + poller.poll(); - EXPECT_EQ(10100U, transport->nextTimeoutCycles); + EXPECT_EQ(10100U, poller.nextTimeoutCycles); } TEST_F(TransportImplTest, processPackets) { char payload[8][1024]; - Homa::Driver::Packet* packets[8]; + Homa::Driver::Packet packets[8]; // Set DATA packet - Homa::Mock::MockDriver::MockPacket dataPacket{payload[0], 1024}; + Homa::Mock::MockDriver::PacketBuf dataPacketBuf{payload[0]}; + Driver::Packet dataPacket = dataPacketBuf.toPacket(1024); static_cast(dataPacket.payload) ->common.opcode = Protocol::Packet::DATA; - packets[0] = &dataPacket; - EXPECT_CALL(*mockReceiver, handleDataPacket(Eq(&dataPacket), _)); + packets[0] = dataPacket; + EXPECT_CALL(*mockReceiver, handleDataPacket(EqPacket(&packets[0]), _)); // Set GRANT packet - Homa::Mock::MockDriver::MockPacket grantPacket{payload[1], 1024}; + Homa::Mock::MockDriver::PacketBuf grantPacketBuf{payload[1]}; + Driver::Packet grantPacket = grantPacketBuf.toPacket(1024); static_cast(grantPacket.payload) ->common.opcode = Protocol::Packet::GRANT; - packets[1] = &grantPacket; - EXPECT_CALL(*mockSender, handleGrantPacket(Eq(&grantPacket))); + packets[1] = grantPacket; + EXPECT_CALL(*mockSender, handleGrantPacket(EqPacket(&packets[1]))); // Set DONE packet - Homa::Mock::MockDriver::MockPacket donePacket{payload[2], 1024}; + Homa::Mock::MockDriver::PacketBuf donePacketBuf{payload[2]}; + Driver::Packet donePacket = donePacketBuf.toPacket(1024); static_cast(donePacket.payload) ->common.opcode = Protocol::Packet::DONE; - packets[2] = &donePacket; - EXPECT_CALL(*mockSender, handleDonePacket(Eq(&donePacket))); + packets[2] = donePacket; + EXPECT_CALL(*mockSender, handleDonePacket(EqPacket(&packets[2]))); // Set RESEND packet - Homa::Mock::MockDriver::MockPacket resendPacket{payload[3], 1024}; + Homa::Mock::MockDriver::PacketBuf resendPacketBuf{payload[3]}; + Driver::Packet resendPacket = resendPacketBuf.toPacket(1024); static_cast(resendPacket.payload) ->common.opcode = Protocol::Packet::RESEND; - packets[3] = &resendPacket; - EXPECT_CALL(*mockSender, handleResendPacket(Eq(&resendPacket))); + packets[3] = resendPacket; + EXPECT_CALL(*mockSender, handleResendPacket(EqPacket(&packets[3]))); // Set BUSY packet - Homa::Mock::MockDriver::MockPacket busyPacket{payload[4], 1024}; + Homa::Mock::MockDriver::PacketBuf busyPacketBuf{payload[4]}; + Driver::Packet busyPacket = busyPacketBuf.toPacket(1024); static_cast(busyPacket.payload) ->common.opcode = Protocol::Packet::BUSY; - packets[4] = &busyPacket; - EXPECT_CALL(*mockReceiver, handleBusyPacket(Eq(&busyPacket))); + packets[4] = busyPacket; + EXPECT_CALL(*mockReceiver, handleBusyPacket(EqPacket(&packets[4]))); // Set PING packet - Homa::Mock::MockDriver::MockPacket pingPacket{payload[5], 1024}; + Homa::Mock::MockDriver::PacketBuf pingPacketBuf{payload[5]}; + Driver::Packet pingPacket = pingPacketBuf.toPacket(1024); static_cast(pingPacket.payload) ->common.opcode = Protocol::Packet::PING; - packets[5] = &pingPacket; - EXPECT_CALL(*mockReceiver, handlePingPacket(Eq(&pingPacket), _)); + packets[5] = pingPacket; + EXPECT_CALL(*mockReceiver, handlePingPacket(EqPacket(&packets[5]), _)); // Set UNKNOWN packet - Homa::Mock::MockDriver::MockPacket unknownPacket{payload[6], 1024}; + Homa::Mock::MockDriver::PacketBuf unknownPacketBuf{payload[6]}; + Driver::Packet unknownPacket = unknownPacketBuf.toPacket(1024); static_cast(unknownPacket.payload) ->common.opcode = Protocol::Packet::UNKNOWN; - packets[6] = &unknownPacket; - EXPECT_CALL(*mockSender, handleUnknownPacket(Eq(&unknownPacket))); + packets[6] = unknownPacket; + EXPECT_CALL(*mockSender, handleUnknownPacket(EqPacket(&packets[6]))); // Set ERROR packet - Homa::Mock::MockDriver::MockPacket errorPacket{payload[7], 1024}; + Homa::Mock::MockDriver::PacketBuf errorPacketBuf{payload[7]}; + Driver::Packet errorPacket = errorPacketBuf.toPacket(1024); static_cast(errorPacket.payload) ->common.opcode = Protocol::Packet::ERROR; - packets[7] = &errorPacket; - EXPECT_CALL(*mockSender, handleErrorPacket(Eq(&errorPacket))); + packets[7] = errorPacket; + EXPECT_CALL(*mockSender, handleErrorPacket(EqPacket(&packets[7]))); - EXPECT_CALL(mockDriver, receivePackets) + EXPECT_CALL(*mockDriver, receivePackets) .WillOnce(DoAll(SetArrayArgument<1>(packets, packets + 8), Return(8))); - transport->processPackets(); + poller.processPackets(); } } // namespace diff --git a/src/TransportPoller.cc b/src/TransportPoller.cc new file mode 100644 index 0000000..0ff09b5 --- /dev/null +++ b/src/TransportPoller.cc @@ -0,0 +1,85 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "Homa/Utils/TransportPoller.h" +#include +#include "Homa/Homa.h" +#include "Perf.h" + +namespace Homa { + +/** + * Transport poller constructor. + * + * @param transport + * Transport instance driven by this poller. + */ +TransportPoller::TransportPoller(Transport* transport) + : transport(transport) + , nextTimeoutCycles(0) +{} + +/** + * Make incremental progress performing all Transport functionality. + * + * This method MUST be called for the Transport to make progress and should + * be called frequently to ensure timely progress. + */ +void +TransportPoller::poll() +{ + // Receive and dispatch incoming packets. + processPackets(); + + // Allow sender and receiver to make incremental progress. + uint64_t waitUntil; + transport->trySend(&waitUntil); + transport->trySendGrants(); + + if (PerfUtils::Cycles::rdtsc() >= nextTimeoutCycles.load()) { + uint64_t requestedTimeoutCycles = transport->checkTimeouts(); + nextTimeoutCycles.store(requestedTimeoutCycles); + } +} + +/** + * Helper method which receives a burst of incoming packets and process them + * through the transport protocol. Pulled out of TransportPoller::poll() to + * simplify unit testing. + */ +void +TransportPoller::processPackets() +{ + // Keep track of time spent doing active processing versus idle. + uint64_t cycles = PerfUtils::Cycles::rdtsc(); + + const int MAX_BURST = 32; + Driver::Packet packets[MAX_BURST]; + IpAddress srcAddrs[MAX_BURST]; + Driver* driver = transport->getDriver(); + int numPackets = driver->receivePackets(MAX_BURST, packets, srcAddrs); + for (int i = 0; i < numPackets; ++i) { + transport->processPacket(&packets[i], srcAddrs[i]); + } + + cycles = PerfUtils::Cycles::rdtsc() - cycles; + if (numPackets > 0) { + Perf::counters.active_cycles.add(cycles); + } else { + Perf::counters.idle_cycles.add(cycles); + } +} + +} // namespace Homa diff --git a/test/dpdk_test.cc b/test/dpdk_test.cc index 4ca1a82..f1972a9 100644 --- a/test/dpdk_test.cc +++ b/test/dpdk_test.cc @@ -16,8 +16,8 @@ #include #include -#include #include +#include #include #include @@ -38,9 +38,9 @@ int main(int argc, char* argv[]) { std::map args = - docopt::docopt(USAGE, {argv + 1, argv + argc}, - true, // show help if requested - "DPDK Driver Test"); // version string + docopt::docopt(USAGE, {argv + 1, argv + argc}, + true, // show help if requested + "DPDK Driver Test"); // version string std::string iface = args[""].asString(); bool isServer = args["--server"].asBool(); @@ -55,15 +55,15 @@ main(int argc, char* argv[]) std::cout << Homa::IpAddress::toString(driver.getLocalAddress()) << std::endl; while (true) { - Homa::Driver::Packet* incoming[10]; + Homa::Driver::Packet incoming[10]; Homa::IpAddress srcAddrs[10]; uint32_t receivedPackets; do { receivedPackets = driver.receivePackets(10, incoming, srcAddrs); } while (receivedPackets == 0); - Homa::Driver::Packet* pong = driver.allocPacket(); - pong->length = 100; - driver.sendPacket(pong, srcAddrs[0], 0); + Homa::Driver::Packet pong = driver.allocPacket(); + pong.length = 100; + driver.sendPacket(&pong, srcAddrs[0], 0); driver.releasePackets(incoming, receivedPackets); driver.releasePackets(&pong, 1); } @@ -74,15 +74,15 @@ main(int argc, char* argv[]) for (int i = 0; i < 100000; ++i) { uint64_t start = PerfUtils::Cycles::rdtsc(); PerfUtils::TimeTrace::record(start, "START"); - Homa::Driver::Packet* ping = driver.allocPacket(); + Homa::Driver::Packet ping = driver.allocPacket(); PerfUtils::TimeTrace::record("allocPacket"); - ping->length = 100; + ping.length = 100; PerfUtils::TimeTrace::record("set ping args"); - driver.sendPacket(ping, server_ip, 0); + driver.sendPacket(&ping, server_ip, 0); PerfUtils::TimeTrace::record("sendPacket"); driver.releasePackets(&ping, 1); PerfUtils::TimeTrace::record("releasePacket"); - Homa::Driver::Packet* incoming[10]; + Homa::Driver::Packet incoming[10]; Homa::IpAddress srcAddrs[10]; uint32_t receivedPackets; do { diff --git a/test/system_test.cc b/test/system_test.cc index 88b3814..b84b35c 100644 --- a/test/system_test.cc +++ b/test/system_test.cc @@ -16,11 +16,11 @@ #include #include #include -#include +#include +#include #include #include -#include #include #include #include @@ -47,6 +47,7 @@ static const char USAGE[] = R"(Homa System Test. bool _PRINT_CLIENT_ = false; bool _PRINT_SERVER_ = false; +static const uint16_t SERVER_PORT = 60001; struct MessageHeader { uint64_t id; @@ -57,28 +58,33 @@ struct Node { explicit Node(uint64_t id) : id(id) , driver() - , transport(Homa::Transport::create(&driver, id)) + , mailboxDir() + , transport(Homa::Transport::create(&driver, &mailboxDir, id)) , thread() , run(false) + , serverSocket(transport->open(SERVER_PORT)) {} const uint64_t id; Homa::Drivers::Fake::FakeDriver driver; - Homa::Transport* transport; + Homa::SimpleMailboxDir mailboxDir; + Homa::unique_ptr transport; std::thread thread; std::atomic run; + Homa::unique_ptr serverSocket; }; void serverMain(Node* server, std::vector addresses) { + Homa::TransportPoller poller(server->transport.get()); while (true) { if (server->run.load() == false) { break; } Homa::unique_ptr message = - server->transport->receive(); + server->serverSocket->receive(false); if (message) { MessageHeader header; @@ -92,7 +98,7 @@ serverMain(Node* server, std::vector addresses) } message->acknowledge(); } - server->transport->poll(); + poller.poll(); } } @@ -112,16 +118,18 @@ clientMain(int count, int size, std::vector addresses) int numFailed = 0; Node client(1); + Homa::TransportPoller poller(client.transport.get()); + Homa::unique_ptr clientSocket = client.transport->open(0); for (int i = 0; i < count; ++i) { uint64_t id = nextId++; char payload[size]; - for (int i = 0; i < size; ++i) { - payload[i] = randData(gen); + for (char& byte : payload) { + byte = randData(gen); } Homa::IpAddress destAddress = addresses[randAddr(gen)]; - Homa::unique_ptr message = client.transport->alloc(0); + Homa::unique_ptr message = clientSocket->alloc(); { MessageHeader header; header.id = id; @@ -133,7 +141,7 @@ clientMain(int count, int size, std::vector addresses) << std::endl; } } - message->send(Homa::SocketAddress{destAddress, 60001}); + message->send(Homa::SocketAddress{destAddress, SERVER_PORT}); while (1) { Homa::OutMessage::Status status = message->getStatus(); @@ -143,7 +151,7 @@ clientMain(int count, int size, std::vector addresses) numFailed++; break; } - client.transport->poll(); + poller.poll(); } } return numFailed; @@ -193,16 +201,14 @@ main(int argc, char* argv[]) servers.push_back(server); } - for (auto it = servers.begin(); it != servers.end(); ++it) { - Node* server = *it; + for (auto server : servers) { server->run = true; server->thread = std::move(std::thread(&serverMain, server, addresses)); } int numFails = clientMain(numTests, numBytes, addresses); - for (auto it = servers.begin(); it != servers.end(); ++it) { - Node* server = *it; + for (auto server : servers) { server->run = false; server->thread.join(); delete server; From af2c685dccd80b053a0968bdf3cae8066928f15b Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Wed, 7 Oct 2020 22:54:59 -0700 Subject: [PATCH 07/15] Simplified the Mailbox/MailboxDir API ... by combining Mailbox::close(), Mailbox::deliver() and MailboxDir::open() into MailboxDir::deliver(). --- include/Homa/Homa.h | 33 ++++-------------- include/Homa/Utils/SimpleMailboxDir.h | 4 +-- src/Receiver.cc | 7 ++-- src/Shenango.cc | 46 ++++++------------------ src/SimpleMailboxDir.cc | 50 ++++++++++++++++++++------- 5 files changed, 59 insertions(+), 81 deletions(-) diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index 9311559..c82a1b3 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -305,27 +305,6 @@ class Mailbox { */ virtual ~Mailbox() = default; - /** - * Signal that the caller will not access the mailbox after this call. - * A mailbox will only be destroyed if it's removed from the directory - * and closed by all openers. - * - * Not meant to be called by users. - * - * @sa MailboxDir::open() - */ - virtual void close() = 0; - - /** - * Used by a transport to deliver an ingress message to this mailbox. - * - * Not meant to be called by users. - * - * @param message - * An ingress message just completed by the transport. - */ - virtual void deliver(InMessage* message) = 0; - /** * Retrieve a message currently stored in the mailbox. * @@ -382,17 +361,19 @@ class MailboxDir { virtual Mailbox* alloc(uint16_t port) = 0; /** - * Find and open the mailbox that matches the given port number. Once a - * mailbox is opened, it's guaranteed to remain usable even if someone else - * removes it from the directory. + * Used by a transport to deliver an ingress message to a mailbox. + * + * Not meant to be called by users. * * @param port * Port number which identifies the mailbox. + * @param message + * An ingress message just completed by the transport. * @return - * Pointer to the opened mailbox on success; nullptr, if the desired + * True if the message is delivered successfully; false, if the target * mailbox doesn't exist. */ - virtual Mailbox* open(uint16_t port) = 0; + virtual bool deliver(uint16_t port, InMessage* message) = 0; /** * Remove the mailbox that matches the given port number. diff --git a/include/Homa/Utils/SimpleMailboxDir.h b/include/Homa/Utils/SimpleMailboxDir.h index 78f3314..75aca98 100644 --- a/include/Homa/Utils/SimpleMailboxDir.h +++ b/include/Homa/Utils/SimpleMailboxDir.h @@ -47,14 +47,14 @@ class SimpleMailboxDir final : public MailboxDir { explicit SimpleMailboxDir(); ~SimpleMailboxDir() override; Mailbox* alloc(uint16_t port) override; - Mailbox* open(uint16_t port) override; + bool deliver(uint16_t port, Homa::InMessage* message) override; bool remove(uint16_t port) override; private: /// Monitor-style lock. std::unique_ptr mutex; - /// Hash table that maps port numbers to mailboxes. + /// Hash table that maps port numbers to mailboxes. Protected by mutex. std::unordered_map map; }; diff --git a/src/Receiver.cc b/src/Receiver.cc index bbc685f..12dbf2f 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -165,11 +165,8 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) message->setState(Message::State::COMPLETED); bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); uint16_t dport = be16toh(header->common.prefix.dport); - Mailbox* mailbox = mailboxDir->open(dport); - if (mailbox) { - mailbox->deliver(message); - mailbox->close(); - } else { + bool success = mailboxDir->deliver(dport, message); + if (!success) { lock_bucket.destroy(); ERROR("Unable to deliver the message; message dropped"); dropMessage(message); diff --git a/src/Shenango.cc b/src/Shenango.cc index a57794e..ccb1664 100644 --- a/src/Shenango.cc +++ b/src/Shenango.cc @@ -26,12 +26,6 @@ using namespace Homa; #define DECLARE_SHENANGO_FUNC(ReturnType, MethodName, ...) \ extern ReturnType (*shenango_##MethodName)(__VA_ARGS__); -/** - * Fast thread-local slab-based memory allocation. - */ -DECLARE_SHENANGO_FUNC(void*, smalloc, size_t) -DECLARE_SHENANGO_FUNC(void, sfree, void*) - /** * Protect RCU read-side critical sections. */ @@ -165,30 +159,13 @@ homa_driver_free(homa_driver drv) } /** - * An almost trivial implementation of Mailbox. This class is essentially - * a wrapper around a socket table entry in Shenango (i.e., struct trans_entry). - * + * A trivial implementation of Mailbox for catching errors. */ class ShenangoMailbox final : public Mailbox { public: - explicit ShenangoMailbox(void* trans_entry) - : trans_entry(trans_entry) - {} - + explicit ShenangoMailbox() = default; ~ShenangoMailbox() override = default; - void close() override - { - this->~ShenangoMailbox(); - shenango_sfree(this); - shenango_rcu_read_unlock(); - } - - void deliver(InMessage* message) override - { - shenango_homa_mb_deliver(trans_entry, homa_inmsg{message}); - } - InMessage* retrieve(bool blocking) override { (void)blocking; @@ -199,10 +176,6 @@ class ShenangoMailbox final : public Mailbox { { PANIC("Shenango should never call Homa::Socket::shutdown"); } - - private: - /// An opaque pointer to "struct trans_entry" in Shenango. - void* const trans_entry; }; /** @@ -226,21 +199,22 @@ class ShenangoMailboxDir final : MailboxDir { { // Shenango doesn't rely on Homa::Socket to receive messages, // so there is no need to assign a real mailbox to SocketImpl. - static ShenangoMailbox dummyMailbox(nullptr); + static ShenangoMailbox dummyMailbox; (void)port; return &dummyMailbox; } - Mailbox* open(uint16_t port) override + bool deliver(uint16_t port, InMessage* message) override { - SocketAddress laddr = {local_ip, port}; + // The socket table in Shenango is protected by an RCU. shenango_rcu_read_lock(); + SocketAddress laddr = {local_ip, port}; void* trans_entry = shenango_trans_table_lookup(proto, laddr, {}); - if (!trans_entry) { - return nullptr; + if (trans_entry) { + shenango_homa_mb_deliver(trans_entry, homa_inmsg{message}); } - void* backing = shenango_smalloc(sizeof(ShenangoMailbox)); - return new (backing) ShenangoMailbox(trans_entry); + shenango_rcu_read_unlock(); + return trans_entry != nullptr; } bool remove(uint16_t port) override diff --git a/src/SimpleMailboxDir.cc b/src/SimpleMailboxDir.cc index e74ebd3..dc14f39 100644 --- a/src/SimpleMailboxDir.cc +++ b/src/SimpleMailboxDir.cc @@ -27,8 +27,9 @@ class MailboxImpl : public Mailbox { public: explicit MailboxImpl(); ~MailboxImpl() override; - void close() override; - void deliver(InMessage* message) override; + void open(); + void close(); + void deliver(InMessage* message); InMessage* retrieve(bool blocking) override; void socketShutdown() override; @@ -63,7 +64,24 @@ MailboxImpl::~MailboxImpl() } } -/// See Homa::Mailbox::close() +/** + * Signal that the caller will be accessing the mailbox until close() is called. + * Once a mailbox is opened, it's guaranteed to remain usable even if someone + * else removes it from the directory. + */ +void +MailboxImpl::open() +{ + // Increment the reference count of the mailbox, so this mailbox won't be + // deleted even if it's removed from the hash table. + openers.fetch_add(1, std::memory_order_relaxed); +} + +/** + * Signal that the caller will not access the mailbox after this call. + * A mailbox will only be destroyed if it's removed from the directory + * and closed by all openers. + */ void MailboxImpl::close() { @@ -75,7 +93,12 @@ MailboxImpl::close() } } -/// See Homa::Mailbox::deliver() +/** + * Deliver an ingress message to this mailbox. + * + * @param message + * An ingress message just completed by the transport. + */ void MailboxImpl::deliver(InMessage* message) { @@ -131,24 +154,27 @@ SimpleMailboxDir::alloc(uint16_t port) return mailbox; } -Mailbox* -SimpleMailboxDir::open(uint16_t port) +bool +SimpleMailboxDir::deliver(uint16_t port, Homa::InMessage* message) { + // Find the mailbox. MailboxImpl* mailbox = nullptr; { - // Look up the mailbox SpinLock::Lock _(*mutex); auto it = map.find(port); if (it != map.end()) { mailbox = it->second; } + if (mailbox == nullptr) { + return false; + } } - // Increment the reference count of the mailbox. - if (mailbox) { - mailbox->openers.fetch_add(1, std::memory_order_relaxed); - } - return mailbox; + // Deliver the message. + mailbox->open(); + mailbox->deliver(message); + mailbox->close(); + return true; } bool From 91db428f9fb9e52082080ad2d02b2106fac0bae7 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Fri, 9 Oct 2020 01:07:14 -0700 Subject: [PATCH 08/15] wip: reduce more apis --- CMakeLists.txt | 1 - include/Homa/Bindings/CHoma.h | 66 +------ include/Homa/Homa.h | 247 ++++---------------------- include/Homa/Shenango.h | 20 ++- include/Homa/Utils/SimpleMailboxDir.h | 61 ------- src/CHoma.cc | 69 +------ src/Homa.cc | 4 +- src/Mock/MockSender.h | 2 +- src/Receiver.cc | 12 +- src/Receiver.h | 8 +- src/ReceiverTest.cc | 31 +++- src/Sender.cc | 36 +--- src/Sender.h | 15 +- src/SenderTest.cc | 16 +- src/Shenango.cc | 70 +++----- src/SimpleMailboxDir.cc | 197 -------------------- src/TransportImpl.cc | 100 ++--------- src/TransportImpl.h | 60 +------ test/system_test.cc | 37 ++-- 19 files changed, 189 insertions(+), 863 deletions(-) delete mode 100644 include/Homa/Utils/SimpleMailboxDir.h delete mode 100644 src/SimpleMailboxDir.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 4174fc1..775db86 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,7 +82,6 @@ add_library(Homa src/Receiver.cc src/Sender.cc src/Shenango.cc - src/SimpleMailboxDir.cc src/StringUtil.cc src/ThreadId.cc src/TransportImpl.cc diff --git a/include/Homa/Bindings/CHoma.h b/include/Homa/Bindings/CHoma.h index 253afea..d15f4ea 100644 --- a/include/Homa/Bindings/CHoma.h +++ b/include/Homa/Bindings/CHoma.h @@ -44,13 +44,11 @@ extern "C" { void* p; \ } homa_##x; -DEFINE_HOMA_OBJ_HANDLE(driver) /* Homa::Driver */ -DEFINE_HOMA_OBJ_HANDLE(inmsg) /* Homa::InMessage */ -DEFINE_HOMA_OBJ_HANDLE(outmsg) /* Homa::OutMessage */ -DEFINE_HOMA_OBJ_HANDLE(mailbox) /* Homa::Mailbox */ -DEFINE_HOMA_OBJ_HANDLE(mailbox_dir) /* Homa::MailboxDir */ -DEFINE_HOMA_OBJ_HANDLE(sk) /* Homa::Socket */ -DEFINE_HOMA_OBJ_HANDLE(trans) /* Homa::Transport */ +DEFINE_HOMA_OBJ_HANDLE(callbacks) /* Homa::Callbacks */ +DEFINE_HOMA_OBJ_HANDLE(driver) /* Homa::Driver */ +DEFINE_HOMA_OBJ_HANDLE(inmsg) /* Homa::InMessage */ +DEFINE_HOMA_OBJ_HANDLE(outmsg) /* Homa::OutMessage */ +DEFINE_HOMA_OBJ_HANDLE(trans) /* Homa::Transport */ /* ============================ */ /* Homa::InMessage API */ @@ -134,52 +132,11 @@ extern void homa_outmsg_reserve(homa_outmsg out_msg, size_t n); */ extern void homa_outmsg_send(homa_outmsg out_msg, uint32_t ip, uint16_t port); -/** - * homa_outmsg_register_cb - C-binding for - * Homa::OutMessage::registerCallbackEndState - */ -extern void homa_outmsg_register_cb_end_state(homa_outmsg out_msg, - void (*cb)(void*), void* data); - /** * homa_outmsg_release - C-binding for Homa::OutMessage::release */ extern void homa_outmsg_release(homa_outmsg out_msg); -/* ============================ */ -/* Homa::Socket API */ -/* ============================ */ - -/** - * homa_sk_alloc - C-binding for Homa::Socket::alloc - */ -extern homa_outmsg homa_sk_alloc(homa_sk sk); - -/** - * homa_sk_receive - C-binding for Homa::Socket::receive - */ -extern homa_inmsg homa_sk_receive(homa_sk sk, bool blocking); - -/** - * homa_sk_shutdown - C-binding for Homa::Socket::shutdown - */ -extern void homa_sk_shutdown(homa_sk sk); - -/** - * homa_sk_is_shutdown - C-binding for Homa::Socket::isShutdown - */ -extern bool homa_sk_is_shutdown(homa_sk sk); - -/** - * homa_sk_local_addr - C-binding for Homa::Socket::getLocalAddress - */ -extern void homa_sk_local_addr(homa_sk sk, uint32_t* ip, uint16_t* port); - -/** - * homa_sk_close - C-binding for Homa::Socket::close - */ -extern void homa_sk_close(homa_sk sk); - /* ============================ */ /* Homa::Transport API */ /* ============================ */ @@ -187,7 +144,7 @@ extern void homa_sk_close(homa_sk sk); /** * homa_trans_create - C-binding for Homa::Transport::create */ -extern homa_trans homa_trans_create(homa_driver drv, homa_mailbox_dir dir, +extern homa_trans homa_trans_create(homa_driver drv, homa_callbacks cbs, uint64_t id); /** @@ -196,9 +153,9 @@ extern homa_trans homa_trans_create(homa_driver drv, homa_mailbox_dir dir, extern void homa_trans_free(homa_trans trans); /** - * homa_trans_open - C-binding for Homa::Transport::open + * homa_trans_alloc - C-binding for Homa::Transport::alloc */ -extern homa_sk homa_trans_open(homa_trans trans, uint16_t port); +extern homa_outmsg homa_trans_alloc(homa_trans trans, uint16_t port); /** * homa_trans_check_timeouts - C-binding for Homa::Transport::checkTimeouts @@ -216,13 +173,6 @@ extern uint64_t homa_trans_id(homa_trans trans); extern void homa_trans_proc(homa_trans trans, uintptr_t desc, void* payload, int32_t len, uint32_t src_ip); -/** - * homa_trans_try_send - C-binding for - * Homa::Transport::registerCallbackSendReady - */ -extern void homa_trans_register_cb_send_ready(homa_trans trans, - void (*cb)(void*), void* data); - /** * homa_trans_try_send - C-binding for Homa::Transport::trySend */ diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index c82a1b3..8d7253a 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -40,12 +40,6 @@ namespace Homa { template using unique_ptr = std::unique_ptr; -/** - * Shorthand for user-defined callback functions which are used by the transport - * library to notify users of certain events. - */ -using Callback = std::function; - /** * Represents a socket address to (from) which we can send (receive) messages. */ @@ -223,15 +217,6 @@ class OutMessage { */ virtual void prepend(const void* source, size_t count) = 0; - /** - * Register a callback function to be invoked when the status of this - * message reaches an end state. - * - * @param func - * The function object to invoke. - */ - virtual void registerCallbackEndState(Callback func) = 0; - /** * Reserve a number of bytes at the beginning of the Message. * @@ -259,6 +244,7 @@ class OutMessage { */ virtual void send(SocketAddress destination, Options options = Options::NONE) = 0; + // FIXME: this is problematic; we can't really call send a second time... protected: /** @@ -275,195 +261,52 @@ class OutMessage { }; /** - * Represents a location which can hold incoming messages temporarily before - * they are consumed by high-level applications. - * - * Despite a one-to-one relationship between Mailbox and Socket, this class - * is decoupled from Socket for three reasons: - *
    - *
  • Abstract out the interaction with the user's thread scheduler: e.g., - * a user system may want to block on receive until a message is delivered; - *
  • Abstract out the mechanism for high-performance message dispatch: e.g., - * a user system may choose to implement the message receive queue with a - * concurrent MPMC queue as opposed to a linked-list protected by a mutex; - *
  • Abstract out the mechanism for safe memory reclamation of the receive - * queue: e.g., RCU is a well-known solution, reference counting is another. - *
- * - * Note: methods in this class are NOT meant to be called by user applications - * directly; instead, they are defined by user applications and called by the - * Homa transport library. - * - * This class is thread-safe. - * - * @sa MailboxDir - */ -class Mailbox { - public: - /** - * Destructor. - */ - virtual ~Mailbox() = default; - - /** - * Retrieve a message currently stored in the mailbox. - * - * Not meant to be called by users; use Socket::receive() instead. - * - * @param blocking - * When set to true, this method should not return until a message - * arrives or the corresponding socket is shut down. - * @return - * A message previously delivered to this mailbox, if the mailbox is - * not empty; nullptr, otherwise. - * - * @sa Socket::receive() - */ - virtual InMessage* retrieve(bool blocking) = 0; - - /** - * Invoked when the corresponding socket of the mailbox is shut down. - * All pending retrieve() requests must return immediately. - */ - virtual void socketShutdown() = 0; -}; - -/** - * Provides a means to keep track of the mailboxes that are currently in use - * by Homa sockets. - * - * This class is separated out from Transport to allow users to 1) use their - * own data structures to store the map from port numbers to mailboxes, and - * 2) apply their own mechanisms to perform synchronization (e.g., hash map - * with fine-grained locks, RCU to delay mailbox destruction, etc). - * - * Similar to Mailbox, methods in this class are NOT meant to be called by - * user applications. - * - * This class is thread-safe. + * Collection of user-defined transport callbacks. */ -class MailboxDir { +class Callbacks { public: /** * Destructor. */ - virtual ~MailboxDir() = default; - - /** - * Allocate a new mailbox in the directory. - * - * @param port - * Port number which identifies the mailbox. - * @return - * Pointer to the new Mailbox on success; nullptr, if the port number - * is already in use. - */ - virtual Mailbox* alloc(uint16_t port) = 0; + virtual ~Callbacks() = default; /** - * Used by a transport to deliver an ingress message to a mailbox. + * Invoked when an incoming message arrives and needs to dispatched to its + * destination in the user application for processing. * - * Not meant to be called by users. + * Here are a few example use cases of this callback: + *
    + *
  • Interaction with the user's thread scheduler: e.g., an application + * may want to block on receive until a message is delivered, so this method + * can be used to wake up blocking threads. + *
  • High-performance message dispatch: e.g., an application may choose + * to implement the message receive queue with a concurrent MPMC queue as + * opposed to a linked-list protected by a mutex; + *
  • Lightweight synchronization: e.g., the socket table that maps from + * port numbers to sockets is a read-mostly data structure, so lookup + * operations can benefit from synchronization schemes such as RCU. + *
* * @param port - * Port number which identifies the mailbox. + * Destination port number of the message. * @param message - * An ingress message just completed by the transport. + * Incoming message to dispatch. * @return - * True if the message is delivered successfully; false, if the target - * mailbox doesn't exist. + * True if the message is delivered successfully; false, otherwise. */ virtual bool deliver(uint16_t port, InMessage* message) = 0; /** - * Remove the mailbox that matches the given port number. + * Invoked when some packets just became ready to be sent (and there was + * none before). * - * @param port - * Port number of the mailbox that will be removed. - * @return - * True on success; false, if the desired mailbox doesn't exist. - */ - virtual bool remove(uint16_t port) = 0; -}; - -/** - * Connection-less socket that can be used to send and receive Homa messages. - * - * This class is thread-safe. - */ -class Socket { - public: - using Address = SocketAddress; - - /** - * Custom deleter for use with Homa::unique_ptr. - */ - struct Deleter { - void operator()(Socket* socket) - { - socket->close(); - } - }; - - /** - * Allocate Message that can be sent with this Socket. - * - * @param sourcePort - * Port number of the socket from which the message will be sent. - * @return - * A pointer to the allocated message or nullptr if the socket has - * been shut down. - */ - virtual Homa::unique_ptr alloc() = 0; - - /** - * Check for and return a Message sent to this Socket if available. - * - * @param blocking - * When set to true, this method should not return until a message - * arrives or the socket is shut down. - * @return - * Pointer to the received message, if any. Otherwise, nullptr is - * returned if no message has been delivered or the socket has been - * shut down. - */ - virtual Homa::unique_ptr receive(bool blocking) = 0; - - /** - * Disable the socket. Once a socket is shut down, all ongoing/subsequent - * requests on the socket will return a failure. - * - * When multiple threads are working on a socket, this method can be used - * to notify other threads to drop their references to this socket so that - * the caller can safely close() the socket. - */ - virtual void shutdown() = 0; - - /** - * Check if the Socket has been shut down. - */ - virtual bool isShutdown() const = 0; - - /** - * Return the local IP address and port number of this Socket. - */ - virtual Socket::Address getLocalAddress() const = 0; - - protected: - /** - * Use protected destructor to prevent users from calling delete on pointers - * to this interface. - */ - ~Socket() = default; - - /** - * Signal that this Socket is no longer needed. No one should access this - * socket after this call. - * - * Note: outgoing messages already allocated from this socket will not be - * affected. + * This callback allows the transport library to notify the users that + * trySend() should be invoked again as soon as possible. For example, + * the callback can be used to implement wakeup signals for the thread + * that is responsible for calling trySend(), if this thread decides to + * sleep when there is no packets to send. */ - virtual void close() = 0; + virtual void notifySendReady() {} }; /** @@ -492,9 +335,9 @@ class Transport { * * @param driver * Driver with which this transport should send and receive packets. - * @param mailboxDir - * Mailbox directory with which this transport should decide where - * to deliver a message. + * @param callbacks + * Collection of user-defined callbacks to customize the behavior of + * the transport. * @param transportId * This transport's unique identifier in the group of transports among * which this transport will communicate. @@ -502,19 +345,18 @@ class Transport { * Pointer to the new transport instance. */ static Homa::unique_ptr create(Driver* driver, - MailboxDir* mailboxDir, + Callbacks* callbacks, uint64_t transportId); /** - * Create a socket that can be used to send and receive messages. + * Allocate Message that can be sent with this Transport. * * @param port - * The port number allocated to the socket. + * Port number of the socket from which the message will be sent. * @return - * Pointer to the new socket, if the port number is not in use; - * nullptr, otherwise. + * A pointer to the allocated message. */ - virtual Homa::unique_ptr open(uint16_t port) = 0; + virtual Homa::unique_ptr alloc(uint16_t port) = 0; /** * Return the driver that this transport uses to send and receive packets. @@ -548,21 +390,6 @@ class Transport { */ virtual void processPacket(Driver::Packet* packet, IpAddress source) = 0; - /** - * Register a callback function to be invoked when some packets just became - * ready to be sent (and there was none before). - * - * This callback allows the transport library to notify the users that - * trySend() should be invoked again as soon as possible. For example, - * the callback can be used to implement wakeup signals for the thread - * that is responsible for calling trySend(), if this thread decides to - * sleep when there is no packets to send. - * - * @param func - * The function object to invoke. - */ - virtual void registerCallbackSendReady(Callback func) = 0; - /** * Attempt to send out packets for any messages with unscheduled/granted * bytes in a way that limits queue buildup in the NIC. diff --git a/include/Homa/Shenango.h b/include/Homa/Shenango.h index 0d5bfea..f424357 100644 --- a/include/Homa/Shenango.h +++ b/include/Homa/Shenango.h @@ -56,21 +56,25 @@ extern homa_driver homa_driver_create(uint8_t proto, uint32_t local_ip, extern void homa_driver_free(homa_driver drv); /** - * homa_mb_dir_create - creates a shim mailbox directory that translates - * Homa::Mailbox operations to Shenango functions + * homa_callbacks_create - creates a collection of the Shenango-defined + * callbacks for the transport. * @proto: protocol number reserved for Homa transport protocol * @local_ip: local IP address of the driver + * @cb_send_ready: callback function to invoke in Callbacks::notifySendReady + * @cb_data: input data for @cb_send_ready * - * Returns a handle to the mailbox created. + * Returns a handle to the callbacks created. */ -extern homa_mailbox_dir homa_mb_dir_create(uint8_t proto, uint32_t local_ip); +extern homa_callbacks homa_callbacks_create(uint8_t proto, uint32_t local_ip, + void (*cb_send_ready)(void*), + void* cb_data); /** - * homa_mb_dir_free - frees a shim mailbox directory created earlier with - * @homa_mb_dir_create. - * @param mailbox_dir: the mailbox directory to free + * homa_callbacks_free - frees the Callbacks object created earlier with + * @homa_callbacks_create. + * @param cbs: the callbacks to free */ -extern void homa_mb_dir_free(homa_mailbox_dir mailbox_dir); +extern void homa_callbacks_free(homa_callbacks cbs); #ifdef __cplusplus } diff --git a/include/Homa/Utils/SimpleMailboxDir.h b/include/Homa/Utils/SimpleMailboxDir.h deleted file mode 100644 index 75aca98..0000000 --- a/include/Homa/Utils/SimpleMailboxDir.h +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright (c) 2020, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -/** - * @file Homa/Utils/SimpleMailboxDir.h - * - * Contains a simple reference implementation for the pluggable mailbox - * directory in the Homa transport library. A mailbox directory is essential - * to get a working transport but it's not central to the Homa protocol. - * - * Users may choose to use this reference implementation for starter, or define - * their own implementation for best performance. - */ - -#pragma once - -#include -#include - -namespace Homa { - -/// Forward declaration -class SpinLock; -class MailboxImpl; - -/** - * A simple reference implementation of Homa::MailboxDir. - * - * This class relies on a monitor-style lock to protect the hash table that - * maps port numbers to mailboxes and uses reference-counting for safe - * reclamation of removed mailboxes. - */ -class SimpleMailboxDir final : public MailboxDir { - public: - explicit SimpleMailboxDir(); - ~SimpleMailboxDir() override; - Mailbox* alloc(uint16_t port) override; - bool deliver(uint16_t port, Homa::InMessage* message) override; - bool remove(uint16_t port) override; - - private: - /// Monitor-style lock. - std::unique_ptr mutex; - - /// Hash table that maps port numbers to mailboxes. Protected by mutex. - std::unordered_map map; -}; - -} // namespace Homa diff --git a/src/CHoma.cc b/src/CHoma.cc index d2a0d54..91ef3b3 100644 --- a/src/CHoma.cc +++ b/src/CHoma.cc @@ -108,14 +108,6 @@ homa_outmsg_send(homa_outmsg out_msg, uint32_t ip, uint16_t port) deref(OutMessage, out_msg).send({IpAddress{ip}, port}); } -void -homa_outmsg_register_cb_end_state(homa_outmsg out_msg, void (*cb)(void*), - void* data) -{ - std::function func = std::bind(cb, data); - deref(OutMessage, out_msg).registerCallbackEndState(func); -} - void homa_outmsg_release(homa_outmsg out_msg) { @@ -123,52 +115,11 @@ homa_outmsg_release(homa_outmsg out_msg) deleter(&deref(OutMessage, out_msg)); } -homa_outmsg -homa_sk_alloc(homa_sk sk) -{ - unique_ptr out_msg = deref(Socket, sk).alloc(); - return homa_outmsg{out_msg.release()}; -} - -homa_inmsg -homa_sk_receive(homa_sk sk, bool blocking) -{ - unique_ptr in_msg = deref(Socket, sk).receive(blocking); - return homa_inmsg{in_msg.release()}; -} - -void -homa_sk_shutdown(homa_sk sk) -{ - deref(Socket, sk).shutdown(); -} - -bool -homa_sk_is_shutdown(homa_sk sk) -{ - return deref(Socket, sk).isShutdown(); -} - -void -homa_sk_local_addr(homa_sk sk, uint32_t* ip, uint16_t* port) -{ - SocketAddress addr = deref(Socket, sk).getLocalAddress(); - *ip = (uint32_t)addr.ip; - *port = addr.port; -} - -void -homa_sk_close(homa_sk sk) -{ - Socket::Deleter deleter; - deleter(&deref(Socket, sk)); -} - homa_trans -homa_trans_create(homa_driver drv, homa_mailbox_dir dir, uint64_t id) +homa_trans_create(homa_driver drv, homa_callbacks cbs, uint64_t id) { unique_ptr trans = - Transport::create(&deref(Driver, drv), &deref(MailboxDir, dir), id); + Transport::create(&deref(Driver, drv), &deref(Callbacks, cbs), id); return homa_trans{trans.release()}; } @@ -179,11 +130,11 @@ homa_trans_free(homa_trans trans) deleter(&deref(Transport, trans)); } -homa_sk -homa_trans_open(homa_trans trans, uint16_t port) +homa_outmsg +homa_trans_alloc(homa_trans trans, uint16_t port) { - unique_ptr sk = deref(Transport, trans).open(port); - return homa_sk{sk.release()}; + unique_ptr out_msg = deref(Transport, trans).alloc(port); + return homa_outmsg{out_msg.release()}; } uint64_t @@ -207,14 +158,6 @@ homa_trans_proc(homa_trans trans, uintptr_t desc, void* payload, int32_t len, deref(Transport, trans).processPacket(&packet, IpAddress{src_ip}); } -void -homa_trans_register_cb_send_ready(homa_trans trans, void (*cb)(void*), - void* data) -{ - std::function func = std::bind(cb, data); - deref(Transport, trans).registerCallbackSendReady(func); -} - bool homa_trans_try_send(homa_trans trans, uint64_t* wait_until) { diff --git a/src/Homa.cc b/src/Homa.cc index b72a4bf..0f9a716 100644 --- a/src/Homa.cc +++ b/src/Homa.cc @@ -20,10 +20,10 @@ namespace Homa { Homa::unique_ptr -Transport::create(Driver* driver, MailboxDir* mailboxDir, uint64_t transportId) +Transport::create(Driver* driver, Callbacks* callbacks, uint64_t transportId) { Transport* transport = - new Core::TransportImpl(driver, mailboxDir, transportId); + new Core::TransportImpl(driver, callbacks, transportId); return Homa::unique_ptr(transport); } diff --git a/src/Mock/MockSender.h b/src/Mock/MockSender.h index 0da8388..91fd17f 100644 --- a/src/Mock/MockSender.h +++ b/src/Mock/MockSender.h @@ -33,7 +33,7 @@ class MockSender : public Core::Sender { public: MockSender(uint64_t transportId, Driver* driver, uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles) - : Sender(transportId, driver, nullptr, messageTimeoutCycles, + : Sender(transportId, driver, nullptr, nullptr, messageTimeoutCycles, pingIntervalCycles) {} diff --git a/src/Receiver.cc b/src/Receiver.cc index 12dbf2f..2db8ac1 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -29,8 +29,8 @@ namespace Core { * * @param driver * The driver used to send and receive packets. - * @param mailboxDir - * The mailbox directory used to lookup message destination. + * @param callbacks + * User-defined transport callbacks. * @param policyManager * Provides information about the grant and network priority policies. * @param messageTimeoutCycles @@ -40,12 +40,12 @@ namespace Core { * Number of cycles of inactivity to wait between requesting retransmission * of un-received parts of a message. */ -Receiver::Receiver(Driver* driver, MailboxDir* mailboxDir, +Receiver::Receiver(Driver* driver, Callbacks* callbacks, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles) - : driver(driver) + : callbacks(callbacks) + , driver(driver) , policyManager(policyManager) - , mailboxDir(mailboxDir) , messageBuckets(messageTimeoutCycles, resendIntervalCycles) , schedulerMutex() , scheduledPeers() @@ -165,7 +165,7 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) message->setState(Message::State::COMPLETED); bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); uint16_t dport = be16toh(header->common.prefix.dport); - bool success = mailboxDir->deliver(dport, message); + bool success = callbacks->deliver(dport, message); if (!success) { lock_bucket.destroy(); ERROR("Unable to deliver the message; message dropped"); diff --git a/src/Receiver.h b/src/Receiver.h index 78f6f0f..3f96bc7 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -44,7 +44,7 @@ namespace Core { */ class Receiver { public: - explicit Receiver(Driver* driver, MailboxDir* mailboxDir, + explicit Receiver(Driver* driver, Callbacks* callbacks, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles); @@ -466,6 +466,9 @@ class Receiver { void unschedule(Message* message, const SpinLock::Lock& lock); void updateSchedule(Message* message, const SpinLock::Lock& lock); + /// User-defined transport callbacks. Not owned by this class. + Callbacks* const callbacks; + /// Driver with which all packets will be sent and received. This driver /// is chosen by the Transport that owns this Sender. Driver* const driver; @@ -473,9 +476,6 @@ class Receiver { /// Provider of network packet priority and grant policy decisions. Policy::Manager* const policyManager; - /// Records where to deliver the messages when they are completed. - MailboxDir* const mailboxDir; - /// Tracks the set of inbound messages being received by this Receiver. MessageBucketMap messageBuckets; diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index aaf2ef8..2fbf8a9 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -15,7 +15,6 @@ #include #include -#include #include #include @@ -61,14 +60,31 @@ MATCHER_P(EqPacketLen, length, "") return arg->length == length; } +class MockCallbacks : public Callbacks { + public: + explicit MockCallbacks() + : receivedMessage() + {} + + bool deliver(uint16_t port, Homa::InMessage* message) override + { + if (port != 60001) { + return false; + } + receivedMessage = message; + return true; + } + + Homa::InMessage* receivedMessage; +}; + class ReceiverTest : public ::testing::Test { public: ReceiverTest() - : mockDriver() + : mockCallbacks() + , mockDriver() , mockPacket() , mockPolicyManager(&mockDriver) - , mailboxDir() - , mailbox(mailboxDir.alloc(60001)) , payload() , packetBuf{&payload} , receiver() @@ -79,7 +95,7 @@ class ReceiverTest : public ::testing::Test { mockPacket = packetBuf.toPacket(); Debug::setLogPolicy( Debug::logPolicyFromString("src/ObjectPool@SILENT")); - receiver = new Receiver(&mockDriver, &mailboxDir, &mockPolicyManager, + receiver = new Receiver(&mockDriver, &mockCallbacks, &mockPolicyManager, messageTimeoutCycles, resendIntervalCycles); PerfUtils::Cycles::mockTscValue = 10000; } @@ -95,11 +111,10 @@ class ReceiverTest : public ::testing::Test { static const uint64_t messageTimeoutCycles = 1000; static const uint64_t resendIntervalCycles = 100; + MockCallbacks mockCallbacks; NiceMock mockDriver; Driver::Packet mockPacket; NiceMock mockPolicyManager; - SimpleMailboxDir mailboxDir; - Mailbox* mailbox; char payload[1028]; Homa::Mock::MockDriver::PacketBuf packetBuf; Receiver* receiver; @@ -221,7 +236,7 @@ TEST_F(ReceiverTest, handleDataPacket) EXPECT_EQ(4U, message->numPackets); EXPECT_EQ(0U, info->bytesRemaining); EXPECT_EQ(Receiver::Message::State::COMPLETED, message->state); - EXPECT_EQ(message, mailbox->retrieve(false)); + EXPECT_EQ(message, mockCallbacks.receivedMessage); Mock::VerifyAndClearExpectations(&mockDriver); // ------------------------------------------------------------------------- diff --git a/src/Sender.cc b/src/Sender.cc index 6b25f8d..b3f9fef 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -33,6 +33,8 @@ namespace Core { * Unique identifier for the Transport that owns this Sender. * @param driver * The driver used to send and receive packets. + * @param callbacks + * Collections of user-defined transport callbacks. * @param policyManager * Provides information about the network packet priority policies. * @param messageTimeoutCycles @@ -42,10 +44,11 @@ namespace Core { * Number of cycles of inactivity to wait between checking on the liveness * of an Sender::Message. */ -Sender::Sender(uint64_t transportId, Driver* driver, +Sender::Sender(uint64_t transportId, Driver* driver, Callbacks* callbacks, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles) : transportId(transportId) + , callbacks(callbacks) , driver(driver) , policyManager(policyManager) , nextMessageSequenceNumber(1) @@ -55,7 +58,6 @@ Sender::Sender(uint64_t transportId, Driver* driver, , messageBuckets(messageTimeoutCycles, pingIntervalCycles) , queueMutex() , sendReady(false) - , notifySendReady() , sendQueue() , messageAllocator() {} @@ -592,16 +594,6 @@ void Sender::Message::setStatus(OutMessage::Status newStatus) { state.store(newStatus, std::memory_order_release); - if (notifyEndState) { - switch (newStatus) { - case OutMessage::Status::CANCELED: - case OutMessage::Status::COMPLETED: - case OutMessage::Status::FAILED: - notifyEndState(); - default: - break; - } - } } /** @@ -643,15 +635,6 @@ Sender::Message::prepend(const void* source, size_t count) } } -/** - * @copydoc Homa::OutMessage::registerCallbackEndState() - */ -void -Sender::Message::registerCallbackEndState(Callback func) -{ - notifyEndState = std::move(func); -} - /** * @copydoc Homa::OutMessage::release() */ @@ -999,13 +982,6 @@ Sender::checkPingTimeouts() return globalNextTimeout; } -/// See Homa::Transport::registerCallbackSendReady() -void -Sender::registerCallbackSendReady(Callback func) -{ - notifySendReady = std::move(func); -} - /** * Attempt to wake up the pacer thread that is responsible for calling trySend() * repeatedly, if it's currently blocked waiting for packets to become ready to @@ -1022,9 +998,7 @@ Sender::signalPacerThread(const SpinLock::Lock& lockHeld) { (void)lockHeld; sendReady = true; - if (notifySendReady) { - notifySendReady(); - } + callbacks->notifySendReady(); } /** diff --git a/src/Sender.h b/src/Sender.h index b8e56df..7e7fe64 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -41,7 +41,7 @@ namespace Core { */ class Sender { public: - explicit Sender(uint64_t transportId, Driver* driver, + explicit Sender(uint64_t transportId, Driver* driver, Callbacks* callbacks, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles); virtual ~Sender(); @@ -53,7 +53,6 @@ class Sender { virtual void handleUnknownPacket(Driver::Packet* packet); virtual void handleErrorPacket(Driver::Packet* packet); virtual uint64_t checkTimeouts(); - virtual void registerCallbackSendReady(Callback func); virtual bool trySend(uint64_t* waitUntil); private: @@ -149,7 +148,6 @@ class Sender { // packets is not initialized to reduce the work done during // construction. See Message::occupied. , state(Status::NOT_STARTED) - , notifyEndState() , bucketNode(this) , messageTimeout(this) , pingTimeout(this) @@ -162,7 +160,6 @@ class Sender { virtual Status getStatus() const; virtual size_t length() const; virtual void prepend(const void* source, size_t count); - virtual void registerCallbackEndState(Callback func); virtual void release(); virtual void reserve(size_t count); virtual void send(SocketAddress destination, @@ -223,9 +220,6 @@ class Sender { /// This message's current state. std::atomic state; - /// Callback function to invoke when _state_ reaches an end state. - Callback notifyEndState; - /// Intrusive structure used by the Sender to hold on to this Message /// in one of the Sender's MessageBuckets. Access to this structure /// is protected by the associated MessageBucket::mutex; @@ -407,6 +401,9 @@ class Sender { /// Transport identifier. const uint64_t transportId; + /// User-defined transport callbacks; not owned by this class. + Callbacks* const callbacks; + /// Driver with which all packets will be sent and received. This driver /// is chosen by the Transport that owns this Sender. Driver* const driver; @@ -434,10 +431,6 @@ class Sender { /// if there is work to do is more efficient. bool sendReady; - /// Callback function to be invoked when _sendReady_ flips from false to - /// true. - Callback notifySendReady; - /// A list of outbound messages that have unsent packets. Messages are kept /// in order of priority. Intrusive::List sendQueue; diff --git a/src/SenderTest.cc b/src/SenderTest.cc index 96b49c7..3c31969 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -40,10 +40,21 @@ MATCHER_P(EqPacket, p, "") return arg->descriptor == p->descriptor; } +class MockCallbacks : public Callbacks { + public: + explicit MockCallbacks() = default; + + bool deliver(uint16_t port, Homa::InMessage* message) override + { + return true; + } +}; + class SenderTest : public ::testing::Test { public: SenderTest() - : mockDriver() + : mockCallbacks() + , mockDriver() , mockPacket() , mockPolicyManager(&mockDriver) , payload() @@ -57,7 +68,7 @@ class SenderTest : public ::testing::Test { ON_CALL(mockDriver, getQueuedBytes).WillByDefault(Return(0)); Debug::setLogPolicy( Debug::logPolicyFromString("src/ObjectPool@SILENT")); - sender = new Sender(22, &mockDriver, &mockPolicyManager, + sender = new Sender(22, &mockDriver, &mockCallbacks, &mockPolicyManager, messageTimeoutCycles, pingIntervalCycles); PerfUtils::Cycles::mockTscValue = 10000; } @@ -69,6 +80,7 @@ class SenderTest : public ::testing::Test { PerfUtils::Cycles::mockTscValue = 0; } + MockCallbacks mockCallbacks; NiceMock mockDriver; Driver::Packet mockPacket; NiceMock mockPolicyManager; diff --git a/src/Shenango.cc b/src/Shenango.cc index ccb1664..40205c8 100644 --- a/src/Shenango.cc +++ b/src/Shenango.cc @@ -14,6 +14,8 @@ */ #include "Homa/Shenango.h" + +#include #include "Debug.h" #include "Homa/Homa.h" @@ -159,50 +161,18 @@ homa_driver_free(homa_driver drv) } /** - * A trivial implementation of Mailbox for catching errors. + * Shenango-defined callback functions for the transport. */ -class ShenangoMailbox final : public Mailbox { +class ShenangoCallbacks final : Callbacks { public: - explicit ShenangoMailbox() = default; - ~ShenangoMailbox() override = default; - - InMessage* retrieve(bool blocking) override - { - (void)blocking; - PANIC("Shenango should never call Homa::Socket::receive"); - } - - void socketShutdown() override - { - PANIC("Shenango should never call Homa::Socket::shutdown"); - } -}; - -/** - * An almost trivial implementation of MailboxDir that uses Shenango's RCU - * mechanism to prevent a mailbox from being destroyed until all readers have - * closed it. - * - * Note: Shenango doesn't use Homa::Socket to receive messages, so the only - * method that has a meaningful implementation is open(). - */ -class ShenangoMailboxDir final : MailboxDir { - public: - explicit ShenangoMailboxDir(uint8_t proto, uint32_t local_ip) + explicit ShenangoCallbacks(uint8_t proto, uint32_t local_ip, + std::function notify_send_ready) : proto(proto) , local_ip{local_ip} + , notify_send_ready(std::move(notify_send_ready)) {} - ~ShenangoMailboxDir() override = default; - - Mailbox* alloc(uint16_t port) override - { - // Shenango doesn't rely on Homa::Socket to receive messages, - // so there is no need to assign a real mailbox to SocketImpl. - static ShenangoMailbox dummyMailbox; - (void)port; - return &dummyMailbox; - } + ~ShenangoCallbacks() override = default; bool deliver(uint16_t port, InMessage* message) override { @@ -217,12 +187,9 @@ class ShenangoMailboxDir final : MailboxDir { return trans_entry != nullptr; } - bool remove(uint16_t port) override + void notifySendReady() override { - // Nothing to do; Shenango is responsible for taking care of freeing - // the resources related to homa sockets. - (void)port; - return true; + notify_send_ready(); } /// Protocol number reserved for Homa; defined as IPPROTO_HOMA in Shenango. @@ -230,17 +197,22 @@ class ShenangoMailboxDir final : MailboxDir { /// Local IP address of the transport. const IpAddress local_ip; + + /// Callback function for notifySendReady(). + std::function notify_send_ready; }; -homa_mailbox_dir -homa_mb_dir_create(uint8_t proto, uint32_t local_ip) +homa_callbacks +homa_callbacks_create(uint8_t proto, uint32_t local_ip, + void (*cb_send_ready)(void*), void* cb_data) { - void* dir = new ShenangoMailboxDir(proto, local_ip); - return homa_mailbox_dir{dir}; + void* cbs = new ShenangoCallbacks(proto, local_ip, + std::bind(cb_send_ready, cb_data)); + return homa_callbacks{cbs}; } void -homa_mb_dir_free(homa_mailbox_dir mailbox_dir) +homa_callbacks_free(homa_callbacks cbs) { - delete static_cast(mailbox_dir.p); + delete static_cast(cbs.p); } diff --git a/src/SimpleMailboxDir.cc b/src/SimpleMailboxDir.cc deleted file mode 100644 index dc14f39..0000000 --- a/src/SimpleMailboxDir.cc +++ /dev/null @@ -1,197 +0,0 @@ -/* Copyright (c) 2020, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#include -#include -#include "SpinLock.h" - -namespace Homa { - -/** - * A simple reference implementation of Homa::Mailbox that uses polling to - * detect incoming messages. - */ -class MailboxImpl : public Mailbox { - public: - explicit MailboxImpl(); - ~MailboxImpl() override; - void open(); - void close(); - void deliver(InMessage* message); - InMessage* retrieve(bool blocking) override; - void socketShutdown() override; - - /// Protects the queue - SpinLock mutex; - - /// Keeps track of the number of calls to open() without paired close(). - /// It's initialized to one because, intuitively, a Socket must keep its - /// mailbox "open" in order to retrieve incoming messages. - std::atomic openers; - - /// Has the corresponding socket been shut down? - std::atomic shutdown; - - /// List of completely received messages. - std::list queue; -}; - -MailboxImpl::MailboxImpl() - : mutex() - , openers(1) - , shutdown() - , queue() -{} - -MailboxImpl::~MailboxImpl() -{ - while (!queue.empty()) { - InMessage* message = queue.front(); - queue.pop_front(); - Homa::unique_ptr deleter(message); - } -} - -/** - * Signal that the caller will be accessing the mailbox until close() is called. - * Once a mailbox is opened, it's guaranteed to remain usable even if someone - * else removes it from the directory. - */ -void -MailboxImpl::open() -{ - // Increment the reference count of the mailbox, so this mailbox won't be - // deleted even if it's removed from the hash table. - openers.fetch_add(1, std::memory_order_relaxed); -} - -/** - * Signal that the caller will not access the mailbox after this call. - * A mailbox will only be destroyed if it's removed from the directory - * and closed by all openers. - */ -void -MailboxImpl::close() -{ - if (openers.fetch_sub(1, std::memory_order_release) == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - - // MailboxImpl was instantiated via "new" in SimpleMailboxDir::alloc. - delete this; - } -} - -/** - * Deliver an ingress message to this mailbox. - * - * @param message - * An ingress message just completed by the transport. - */ -void -MailboxImpl::deliver(InMessage* message) -{ - SpinLock::Lock _(mutex); - queue.push_back(message); -} - -/// See Homa::Mailbox::retrieve() -InMessage* -MailboxImpl::retrieve(bool blocking) -{ - InMessage* message = nullptr; - do { - SpinLock::Lock _(mutex); - if (!queue.empty()) { - message = queue.front(); - queue.pop_front(); - } - } while (blocking && !shutdown.load(std::memory_order_relaxed)); - return message; -} - -/// See Homa::Mailbox::socketShutdown() -void -MailboxImpl::socketShutdown() -{ - shutdown.store(true); -} - -SimpleMailboxDir::SimpleMailboxDir() - : mutex(new SpinLock()) - , map() -{} - -SimpleMailboxDir::~SimpleMailboxDir() -{ - for (auto entry : map) { - MailboxImpl* mailbox = entry.second; - mailbox->close(); - } -} - -Mailbox* -SimpleMailboxDir::alloc(uint16_t port) -{ - MailboxImpl* mailbox = nullptr; - SpinLock::Lock _(*mutex); - auto it = map.find(port); - if (it == map.end()) { - mailbox = new MailboxImpl(); - map[port] = mailbox; - } - return mailbox; -} - -bool -SimpleMailboxDir::deliver(uint16_t port, Homa::InMessage* message) -{ - // Find the mailbox. - MailboxImpl* mailbox = nullptr; - { - SpinLock::Lock _(*mutex); - auto it = map.find(port); - if (it != map.end()) { - mailbox = it->second; - } - if (mailbox == nullptr) { - return false; - } - } - - // Deliver the message. - mailbox->open(); - mailbox->deliver(message); - mailbox->close(); - return true; -} - -bool -SimpleMailboxDir::remove(uint16_t port) -{ - MailboxImpl* mailbox; - { - SpinLock::Lock _(*mutex); - auto it = map.find(port); - if (it == map.end()) { - return false; - } - mailbox = it->second; - map.erase(it); - } - mailbox->close(); - return true; -} - -} // namespace Homa diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index 16fd42b..b2df218 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -34,39 +34,39 @@ const uint64_t RESEND_INTERVAL_US = BASE_TIMEOUT_US; * * @param driver * Driver with which this transport should send and receive packets. - * @param mailboxDir - * Mailbox directory with which this transport should deliver messages. + * @param callbacks + * User-defined transport callbacks. * @param transportId * This transport's unique identifier in the group of transports among * which this transport will communicate. */ -TransportImpl::TransportImpl(Driver* driver, MailboxDir* mailboxDir, +TransportImpl::TransportImpl(Driver* driver, Callbacks* callbacks, uint64_t transportId) : transportId(transportId) + , callbacks(callbacks) , driver(driver) , policyManager(new Policy::Manager(driver)) - , sender(new Sender(transportId, driver, policyManager.get(), + , sender(new Sender(transportId, driver, callbacks, policyManager.get(), PerfUtils::Cycles::fromMicroseconds(MESSAGE_TIMEOUT_US), PerfUtils::Cycles::fromMicroseconds(PING_INTERVAL_US))) , receiver( - new Receiver(driver, mailboxDir, policyManager.get(), + new Receiver(driver, callbacks, policyManager.get(), PerfUtils::Cycles::fromMicroseconds(MESSAGE_TIMEOUT_US), PerfUtils::Cycles::fromMicroseconds(RESEND_INTERVAL_US))) - , mailboxDir(mailboxDir) {} /** * Construct an instance of a Homa-based transport for unit testing. */ -TransportImpl::TransportImpl(Driver* driver, MailboxDir* mailboxDir, +TransportImpl::TransportImpl(Driver* driver, Callbacks* callbacks, Sender* sender, Receiver* receiver, uint64_t transportId) : transportId(transportId) + , callbacks(callbacks) , driver(driver) , policyManager(new Policy::Manager(driver)) , sender(sender) , receiver(receiver) - , mailboxDir(mailboxDir) {} /** @@ -87,16 +87,12 @@ TransportImpl::free() delete this; } -/// See Homa::Transport::open() -Homa::unique_ptr -TransportImpl::open(uint16_t port) +/// See Homa::Transport::alloc() +Homa::unique_ptr +TransportImpl::alloc(uint16_t port) { - Mailbox* mailbox = mailboxDir->alloc(port); - if (!mailbox) { - return nullptr; - } - SocketImpl* socket = new SocketImpl(this, port, mailbox); - return Homa::unique_ptr(socket); + OutMessage* outMessage = sender->allocMessage(port); + return unique_ptr(outMessage); } /// See Homa::Transport::checkTimeouts() @@ -153,13 +149,6 @@ TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) } } -/// See Homa::Transport::registerCallbackSendReady() -void -TransportImpl::registerCallbackSendReady(Callback func) -{ - sender->registerCallbackSendReady(func); -} - /// See Homa::Transport::trySend() bool TransportImpl::trySend(uint64_t* waitUntil) @@ -174,67 +163,4 @@ TransportImpl::trySendGrants() return receiver->trySendGrants(); } -/** - * Construct an instance of a Homa socket. - * - * @param transport - * Transport that owns the socket. - * @param port - * Local port number of the socket. - * @param mailbox - * Mailbox assigned to this socket. - */ -TransportImpl::SocketImpl::SocketImpl(TransportImpl* transport, uint16_t port, - Mailbox* mailbox) - : Socket() - , disabled() - , localAddress{transport->getDriver()->getLocalAddress(), port} - , mailbox(mailbox) - , transport(transport) -{} - -/// See Homa::Socket::alloc() -unique_ptr -TransportImpl::SocketImpl::alloc() -{ - if (isShutdown()) { - return nullptr; - } - OutMessage* outMessage = transport->sender->allocMessage(localAddress.port); - return unique_ptr(outMessage); -} - -/// See Homa::Socket::close() -void -TransportImpl::SocketImpl::close() -{ - bool success = transport->mailboxDir->remove(localAddress.port); - if (!success) { - ERROR("Failed to remove mailbox (port = %u)", localAddress.port); - } - - // Destruct the socket (the mailbox may be still in use). - // Note: it's actually legal to say "delete this" from a member function: - // https://isocpp.org/wiki/faq/freestore-mgmt#delete-this - delete this; -} - -/// See Homa::Socket::receive() -unique_ptr -TransportImpl::SocketImpl::receive(bool blocking) -{ - if (isShutdown()) { - return nullptr; - } - return unique_ptr(mailbox->retrieve(blocking)); -} - -/// See Homa::Socket::shutdown() -void -TransportImpl::SocketImpl::shutdown() -{ - disabled.store(true); - mailbox->socketShutdown(); -} - } // namespace Homa::Core diff --git a/src/TransportImpl.h b/src/TransportImpl.h index d0e06b0..ccbd637 100644 --- a/src/TransportImpl.h +++ b/src/TransportImpl.h @@ -37,17 +37,15 @@ namespace Homa::Core { */ class TransportImpl final : public Transport { public: - explicit TransportImpl(Driver* driver, MailboxDir* mailboxDir, - uint64_t transportId); - explicit TransportImpl(Driver* driver, MailboxDir* mailboxDir, - Sender* sender, Receiver* receiver, + explicit TransportImpl(Driver* driver, Callbacks* callbacks, uint64_t transportId); + explicit TransportImpl(Driver* driver, Callbacks* callbacks, Sender* sender, + Receiver* receiver, uint64_t transportId); ~TransportImpl(); void free() override; - Homa::unique_ptr open(uint16_t port) override; + Homa::unique_ptr alloc(uint16_t port) override; uint64_t checkTimeouts() override; void processPacket(Driver::Packet* packet, IpAddress source) override; - void registerCallbackSendReady(Callback func) override; bool trySend(uint64_t* waitUntil) override; bool trySendGrants() override; @@ -63,53 +61,13 @@ class TransportImpl final : public Transport { return transportId; } - /** - * Internal implementation of Homa::Socket. - * - * @sa - * TransportImpl::socketMap - */ - class SocketImpl final : public Socket { - public: - explicit SocketImpl(TransportImpl* transport, uint16_t port, - Mailbox* mailbox); - virtual ~SocketImpl() = default; - - Homa::unique_ptr alloc() override; - void close() override; - Homa::unique_ptr receive(bool blocking) override; - void shutdown() override; - - /// See Homa::Socket::isShutdown() - bool isShutdown() const override - { - return disabled.load(std::memory_order_relaxed); - } - - /// See Homa::Socket::getLocalAddress() - Address getLocalAddress() const override - { - return localAddress; - } - - private: - /// Has the socket been shut down? - std::atomic disabled; - - /// Local address of the socket. - Address localAddress; - - /// Mailbox assigned to this socket. Not owned by this class. - Mailbox* mailbox; - - /// Transport that owns this socket. - TransportImpl* transport; - }; - private: /// Unique identifier for this transport. const uint64_t transportId; + /// User-defined transport callbacks. Not owned by this class. + Callbacks* const callbacks; + /// Driver from which this transport will send and receive packets. /// Not owned by this class. Driver* const driver; @@ -122,10 +80,6 @@ class TransportImpl final : public Transport { /// Module which receives packets and forms them into messages. std::unique_ptr receiver; - - /// Module which keeps track of mailboxes currently in use. Not owned by - /// this class (we don't even know whether it's instantiated by "new"). - MailboxDir* const mailboxDir; }; } // namespace Homa::Core diff --git a/test/system_test.cc b/test/system_test.cc index b84b35c..316b2ca 100644 --- a/test/system_test.cc +++ b/test/system_test.cc @@ -16,7 +16,6 @@ #include #include #include -#include #include #include @@ -55,23 +54,41 @@ struct MessageHeader { } __attribute__((packed)); struct Node { + class Callbacks : public Homa::Callbacks { + public: + explicit Callbacks(std::vector& receiveQueue) + : receiveQueue(receiveQueue) + {} + + bool deliver(uint16_t port, Homa::InMessage* message) override + { + if (port != SERVER_PORT) { + return false; + } + receiveQueue.push_back(message); + return true; + } + + std::vector& receiveQueue; + }; + explicit Node(uint64_t id) : id(id) + , callbacks(receiveQueue) , driver() - , mailboxDir() - , transport(Homa::Transport::create(&driver, &mailboxDir, id)) + , transport(Homa::Transport::create(&driver, &callbacks, id)) , thread() + , receiveQueue() , run(false) - , serverSocket(transport->open(SERVER_PORT)) {} const uint64_t id; + Callbacks callbacks; Homa::Drivers::Fake::FakeDriver driver; - Homa::SimpleMailboxDir mailboxDir; Homa::unique_ptr transport; std::thread thread; + std::vector receiveQueue; std::atomic run; - Homa::unique_ptr serverSocket; }; void @@ -83,9 +100,8 @@ serverMain(Node* server, std::vector addresses) break; } - Homa::unique_ptr message = - server->serverSocket->receive(false); - + Homa::unique_ptr message(server->receiveQueue.back()); + server->receiveQueue.pop_back(); if (message) { MessageHeader header; message->get(0, &header, sizeof(MessageHeader)); @@ -119,7 +135,6 @@ clientMain(int count, int size, std::vector addresses) Node client(1); Homa::TransportPoller poller(client.transport.get()); - Homa::unique_ptr clientSocket = client.transport->open(0); for (int i = 0; i < count; ++i) { uint64_t id = nextId++; char payload[size]; @@ -129,7 +144,7 @@ clientMain(int count, int size, std::vector addresses) Homa::IpAddress destAddress = addresses[randAddr(gen)]; - Homa::unique_ptr message = clientSocket->alloc(); + Homa::unique_ptr message = client.transport->alloc(0); { MessageHeader header; header.id = id; From a0838b31cfbfb2daa43760741a243efaa69f39f4 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Tue, 13 Oct 2020 00:08:02 -0700 Subject: [PATCH 09/15] major refactoring: splitted low-level and high-level transport API --- CMakeLists.txt | 4 +- include/Homa/Bindings/CHoma.h | 23 ++- include/Homa/Core/Transport.h | 162 +++++++++++++++++ include/Homa/Homa.h | 132 +------------- include/Homa/Shenango.h | 81 --------- include/Homa/Transports/PollModeTransport.h | 61 +++++++ include/Homa/Transports/Shenango.h | 63 +++++++ include/Homa/Utils/TransportPoller.h | 55 ------ src/CHoma.cc | 13 +- src/Homa.cc | 30 ---- src/PollModeTransportImpl.cc | 140 +++++++++++++++ src/PollModeTransportImpl.h | 83 +++++++++ src/PollModeTransportImplTest.cc | 190 ++++++++++++++++++++ src/Receiver.cc | 2 +- src/Receiver.h | 6 +- src/ReceiverTest.cc | 2 +- src/Sender.cc | 6 +- src/Sender.h | 7 +- src/SenderTest.cc | 2 +- src/Shenango.cc | 129 +++++++------ src/TransportImpl.cc | 13 +- src/TransportImpl.h | 8 +- src/TransportImplTest.cc | 125 +------------ src/TransportPoller.cc | 85 --------- test/system_test.cc | 38 +--- 25 files changed, 828 insertions(+), 632 deletions(-) create mode 100644 include/Homa/Core/Transport.h delete mode 100644 include/Homa/Shenango.h create mode 100644 include/Homa/Transports/PollModeTransport.h create mode 100644 include/Homa/Transports/Shenango.h delete mode 100644 include/Homa/Utils/TransportPoller.h delete mode 100644 src/Homa.cc create mode 100644 src/PollModeTransportImpl.cc create mode 100644 src/PollModeTransportImpl.h create mode 100644 src/PollModeTransportImplTest.cc delete mode 100644 src/TransportPoller.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 775db86..bbbda14 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,16 +76,15 @@ add_library(Homa src/CHoma.cc src/Debug.cc src/Driver.cc - src/Homa.cc src/Perf.cc src/Policy.cc + src/PollModeTransportImpl.cc src/Receiver.cc src/Sender.cc src/Shenango.cc src/StringUtil.cc src/ThreadId.cc src/TransportImpl.cc - src/TransportPoller.cc src/Util.cc ) add_library(Homa::Homa ALIAS Homa) @@ -257,6 +256,7 @@ add_executable(unit_test src/IntrusiveTest.cc src/ObjectPoolTest.cc src/PolicyTest.cc + src/PollModeTransportImplTest.cc src/ReceiverTest.cc src/SenderTest.cc src/SpinLockTest.cc diff --git a/include/Homa/Bindings/CHoma.h b/include/Homa/Bindings/CHoma.h index d15f4ea..79e6534 100644 --- a/include/Homa/Bindings/CHoma.h +++ b/include/Homa/Bindings/CHoma.h @@ -142,44 +142,49 @@ extern void homa_outmsg_release(homa_outmsg out_msg); /* ============================ */ /** - * homa_trans_create - C-binding for Homa::Transport::create + * homa_trans_create - C-binding for Homa::TransportBase::create */ extern homa_trans homa_trans_create(homa_driver drv, homa_callbacks cbs, uint64_t id); /** - * homa_trans_free - C-binding for Homa::Transport::free + * homa_trans_free - C-binding for Homa::TransportBase::free */ extern void homa_trans_free(homa_trans trans); /** - * homa_trans_alloc - C-binding for Homa::Transport::alloc + * homa_trans_alloc - C-binding for Homa::TransportBase::alloc */ extern homa_outmsg homa_trans_alloc(homa_trans trans, uint16_t port); /** - * homa_trans_check_timeouts - C-binding for Homa::Transport::checkTimeouts + * homa_trans_get_drv - C-binding for Homa::TransportBase::getDriver */ -extern uint64_t homa_trans_check_timeouts(homa_trans trans); +extern homa_driver homa_trans_get_drv(homa_trans trans); /** - * homa_trans_id - C-binding for Homa::Transport::getId + * homa_trans_id - C-binding for Homa::TransportBase::getId */ extern uint64_t homa_trans_id(homa_trans trans); /** - * homa_trans_proc - C-binding for Homa::Transport::processPacket + * homa_trans_check_timeouts - C-binding for Core::Transport::checkTimeouts + */ +extern uint64_t homa_trans_check_timeouts(homa_trans trans); + +/** + * homa_trans_proc - C-binding for Core::Transport::processPacket */ extern void homa_trans_proc(homa_trans trans, uintptr_t desc, void* payload, int32_t len, uint32_t src_ip); /** - * homa_trans_try_send - C-binding for Homa::Transport::trySend + * homa_trans_try_send - C-binding for Core::Transport::trySend */ extern bool homa_trans_try_send(homa_trans trans, uint64_t* wait_until); /** - * homa_trans_try_grant - C-binding for Homa::Transport::trySendGrants + * homa_trans_try_grant - C-binding for Core::Transport::trySendGrants */ extern bool homa_trans_try_grant(homa_trans trans); diff --git a/include/Homa/Core/Transport.h b/include/Homa/Core/Transport.h new file mode 100644 index 0000000..b042cf0 --- /dev/null +++ b/include/Homa/Core/Transport.h @@ -0,0 +1,162 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +/** + * @file Homa/Core/Transport.h + * + * Contains the low-level Homa Transport API. Advanced users of the Homa + * Transport library should include this header. + */ + +#pragma once + +#include + +namespace Homa::Core { + +/** + * Minimal set of low-level API that can be used to create Homa-based transports + * for different runtime environments (e.g. polling, kernel threading, + * green threads, etc). + * + * The execution of a transport is driven through repeated calls to methods + * like checkTimeouts(), processPacket(), trySend(), and trySendGrants(); the + * transport will not make any progress otherwise. Advanced users can compose + * these methods in a way that suits them best. + * + * This class is thread-safe. + */ +class Transport : public TransportBase { + public: + /** + * Collection of user-defined transport callbacks. + */ + class Callbacks { + public: + /** + * Destructor. + */ + virtual ~Callbacks() = default; + + /** + * Invoked when an incoming message arrives and needs to dispatched to + * its destination in the user application for processing. + * + * Here are a few example use cases of this callback: + *
    + *
  • Interaction with the user's thread scheduler: e.g., an + * application may want to block on receive until a message is + * delivered, so this method can be used to wake up blocking threads. + *
  • High-performance message dispatch: e.g., an application may + * choose to implement the message receive queue with a concurrent MPMC + * queue as opposed to a linked-list protected by a mutex;
  • + * Lightweight synchronization: e.g., the socket table that maps from + * port numbers to sockets is a read-mostly data structure, so lookup + * operations can benefit from synchronization schemes such as RCU. + *
+ * + * @param port + * Destination port number of the message. + * @param message + * Incoming message to dispatch. + * @return + * True if the message is delivered successfully; false, otherwise. + */ + virtual bool deliver(uint16_t port, InMessage* message) = 0; + + /** + * Invoked when some packets just became ready to be sent (and there was + * none before). + * + * This callback allows the transport library to notify the users that + * trySend() should be invoked again as soon as possible. For example, + * the callback can be used to implement wakeup signals for the thread + * that is responsible for calling trySend(), if this thread decides to + * sleep when there is no packets to send. + */ + virtual void notifySendReady() {} + }; + + /** + * Return a new instance of a Homa-based transport. + * + * @param driver + * Driver with which this transport should send and receive packets. + * @param callbacks + * Collection of user-defined callbacks to customize the behavior of + * the transport. + * @param transportId + * This transport's unique identifier in the group of transports among + * which this transport will communicate. + * @return + * Pointer to the new transport instance. + */ + static Homa::unique_ptr create(Driver* driver, + Callbacks* callbacks, + uint64_t transportId); + + /** + * Process any timeouts that have expired. + * + * This method must be called periodically to ensure timely handling of + * expired timeouts. + * + * @return + * The rdtsc cycle time when this method should be called again. + */ + virtual uint64_t checkTimeouts() = 0; + + /** + * Handle an ingress packet by running it through the transport protocol + * stack. + * + * @param packet + * The ingress packet. + * @param source + * IpAddress of the socket from which the packet is sent. + */ + virtual void processPacket(Driver::Packet* packet, IpAddress source) = 0; + + /** + * Attempt to send out packets for any messages with unscheduled/granted + * bytes in a way that limits queue buildup in the NIC. + * + * This method must be called eagerly to allow the Transport to make + * progress toward sending outgoing messages. + * + * @param[out] waitUntil + * The rdtsc cycle time when this method should be called again + * (this allows the NIC to drain its transmit queue). Only set + * when this method returns true. + * @return + * True if more packets are ready to be transmitted when the method + * returns; false, otherwise. + */ + virtual bool trySend(uint64_t* waitUntil) = 0; + + /** + * Attempt to grant to incoming messages according to the Homa protocol. + * + * This method must be called eagerly to allow the Transport to make + * progress toward receiving incoming messages. + * + * @return + * True if the method has found some messages to grant; false, + * otherwise. + */ + virtual bool trySendGrants() = 0; +}; + +} // namespace Homa::Core diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index 8d7253a..50ebfa7 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -24,7 +24,6 @@ #define HOMA_INCLUDE_HOMA_HOMA_H #include -#include namespace Homa { @@ -261,93 +260,23 @@ class OutMessage { }; /** - * Collection of user-defined transport callbacks. - */ -class Callbacks { - public: - /** - * Destructor. - */ - virtual ~Callbacks() = default; - - /** - * Invoked when an incoming message arrives and needs to dispatched to its - * destination in the user application for processing. - * - * Here are a few example use cases of this callback: - *
    - *
  • Interaction with the user's thread scheduler: e.g., an application - * may want to block on receive until a message is delivered, so this method - * can be used to wake up blocking threads. - *
  • High-performance message dispatch: e.g., an application may choose - * to implement the message receive queue with a concurrent MPMC queue as - * opposed to a linked-list protected by a mutex; - *
  • Lightweight synchronization: e.g., the socket table that maps from - * port numbers to sockets is a read-mostly data structure, so lookup - * operations can benefit from synchronization schemes such as RCU. - *
- * - * @param port - * Destination port number of the message. - * @param message - * Incoming message to dispatch. - * @return - * True if the message is delivered successfully; false, otherwise. - */ - virtual bool deliver(uint16_t port, InMessage* message) = 0; - - /** - * Invoked when some packets just became ready to be sent (and there was - * none before). - * - * This callback allows the transport library to notify the users that - * trySend() should be invoked again as soon as possible. For example, - * the callback can be used to implement wakeup signals for the thread - * that is responsible for calling trySend(), if this thread decides to - * sleep when there is no packets to send. - */ - virtual void notifySendReady() {} -}; - -/** - * Provides a means of communicating across the network using the Homa protocol. - * - * The execution of the transport is driven through repeated calls to methods - * like checkTimeouts(), processPacket(), trySend(), and trySendGrants(); the - * transport will not make any progress otherwise. + * Basic transport API that are shared between the low-level and high-level + * transport interfaces. * * This class is thread-safe. */ -class Transport { +class TransportBase { public: /** * Custom deleter for use with std::unique_ptr. */ struct Deleter { - void operator()(Transport* transport) + void operator()(TransportBase* transport) { transport->free(); } }; - /** - * Return a new instance of a Homa-based transport. - * - * @param driver - * Driver with which this transport should send and receive packets. - * @param callbacks - * Collection of user-defined callbacks to customize the behavior of - * the transport. - * @param transportId - * This transport's unique identifier in the group of transports among - * which this transport will communicate. - * @return - * Pointer to the new transport instance. - */ - static Homa::unique_ptr create(Driver* driver, - Callbacks* callbacks, - uint64_t transportId); - /** * Allocate Message that can be sent with this Transport. * @@ -368,63 +297,12 @@ class Transport { */ virtual uint64_t getId() = 0; - /** - * Process any timeouts that have expired. - * - * This method must be called periodically to ensure timely handling of - * expired timeouts. - * - * @return - * The rdtsc cycle time when this method should be called again. - */ - virtual uint64_t checkTimeouts() = 0; - - /** - * Handle an ingress packet by running it through the transport protocol - * stack. - * - * @param packet - * The ingress packet. - * @param source - * IpAddress of the socket from which the packet is sent. - */ - virtual void processPacket(Driver::Packet* packet, IpAddress source) = 0; - - /** - * Attempt to send out packets for any messages with unscheduled/granted - * bytes in a way that limits queue buildup in the NIC. - * - * This method must be called eagerly to allow the Transport to make - * progress toward sending outgoing messages. - * - * @param[out] waitUntil - * The rdtsc cycle time when this method should be called again - * (this allows the NIC to drain its transmit queue). Only set - * when this method returns true. - * @return - * True if more packets are ready to be transmitted when the method - * returns; false, otherwise. - */ - virtual bool trySend(uint64_t* waitUntil) = 0; - - /** - * Attempt to grant to incoming messages according to the Homa protocol. - * - * This method must be called eagerly to allow the Transport to make - * progress toward receiving incoming messages. - * - * @return - * True if the method has found some messages to grant; false, - * otherwise. - */ - virtual bool trySendGrants() = 0; - protected: /** * Use protected destructor to prevent users from calling delete on pointers * to this interface. */ - ~Transport() = default; + ~TransportBase() = default; /** * Free this transport instance. No one should not access this transport diff --git a/include/Homa/Shenango.h b/include/Homa/Shenango.h deleted file mode 100644 index f424357..0000000 --- a/include/Homa/Shenango.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright (c) 2020 Stanford University - * - * Permission to use, copy, modify, and distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR(S) DISCLAIM ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL AUTHORS BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -/** - * @file Shenango.h - * - * Contains the glue code for Homa-Shenango integration. This is the only - * header Shenango needs to include in order to use Homa transport. - * - * Shenango is an experimental operating system that aims to provide low tail - * latency and high CPU efficiency simultaneously for servers in datacenters. - * See for more information. - * - * This file follows the Shenango coding style. - */ - -#pragma once - -#include "Bindings/CHoma.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/** - * homa_driver_create - creates a shim driver that translates Homa::Driver - * operations to Shenango functions - * @proto: protocol number reserved for Homa transport protocol - * @local_ip: local IP address of the driver - * @max_payload: maximum number of bytes carried by the packet payload - * @link_speed: effective network bandwidth, in Mbits/second - * - * Returns a handle to the driver created. - */ -extern homa_driver homa_driver_create(uint8_t proto, uint32_t local_ip, - uint32_t max_payload, - uint32_t link_speed); - -/** - * homa_driver_free - frees a shim driver created earlier with - * @homa_driver_create. - * @param drv: the driver to free - */ -extern void homa_driver_free(homa_driver drv); - -/** - * homa_callbacks_create - creates a collection of the Shenango-defined - * callbacks for the transport. - * @proto: protocol number reserved for Homa transport protocol - * @local_ip: local IP address of the driver - * @cb_send_ready: callback function to invoke in Callbacks::notifySendReady - * @cb_data: input data for @cb_send_ready - * - * Returns a handle to the callbacks created. - */ -extern homa_callbacks homa_callbacks_create(uint8_t proto, uint32_t local_ip, - void (*cb_send_ready)(void*), - void* cb_data); - -/** - * homa_callbacks_free - frees the Callbacks object created earlier with - * @homa_callbacks_create. - * @param cbs: the callbacks to free - */ -extern void homa_callbacks_free(homa_callbacks cbs); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/include/Homa/Transports/PollModeTransport.h b/include/Homa/Transports/PollModeTransport.h new file mode 100644 index 0000000..7bc224e --- /dev/null +++ b/include/Homa/Transports/PollModeTransport.h @@ -0,0 +1,61 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#pragma once + +#include + +namespace Homa { + +/** + * A polling-based Homa transport implementation. + */ +class PollModeTransport : public TransportBase { + public: + /** + * Return a new instance of a polling-based Homa transport. + * + * @param driver + * Driver with which this transport should send and receive packets. + * @param transportId + * This transport's unique identifier in the group of transports among + * which this transport will communicate. + * @return + * Pointer to the new transport instance. + */ + static Homa::unique_ptr create(Driver* driver, + uint64_t transportId); + + /** + * Make incremental progress performing all Transport functionality. + * + * This method MUST be called for the Transport to make progress and should + * be called frequently to ensure timely progress. + */ + virtual void poll() = 0; + + /** + * Check for and return a Message sent to this Socket if available. + * + * @param blocking + * When set to true, this method should not return until a message + * arrives or the socket is shut down. + * @return + * Pointer to the received message, if any; otherwise, nullptr. + */ + virtual Homa::unique_ptr receive() = 0; +}; + +} // namespace Homa \ No newline at end of file diff --git a/include/Homa/Transports/Shenango.h b/include/Homa/Transports/Shenango.h new file mode 100644 index 0000000..369ea2c --- /dev/null +++ b/include/Homa/Transports/Shenango.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2020 Stanford University + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR(S) DISCLAIM ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL AUTHORS BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +/** + * @file Homa/Transports/Shenango.h + * + * Contains the glue code for Homa-Shenango integration. This is the only + * header Shenango needs to include in order to use Homa transport. + * + * Shenango is an experimental operating system that aims to provide low tail + * latency and high CPU efficiency simultaneously for servers in datacenters. + * See for more information. + * + * This file follows the Shenango coding style. + */ + +#pragma once + +#include "Homa/Bindings/CHoma.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * homa_create_shenango_trans - creates a transport instance that can be used by + * Shenango to send and receive messages. + * @id: Unique identifier for this transport instance + * @proto: Protocol number reserved for Homa transport protocol + * @local_ip: Local IP address of the driver + * @max_payload: Maximum number of bytes carried by the packet payload + * @link_speed: Effective network bandwidth, in Mbits/second + * @cb_send_ready: Callback function to invoke in Callbacks::notifySendReady + * @cb_data: Input data for @cb_send_ready + * + * Returns a handle to the callbacks created. + */ +extern homa_trans homa_create_shenango_trans(uint64_t id, + uint8_t proto, uint32_t local_ip, uint32_t max_payload, uint32_t link_speed, + void (*cb_send_ready)(void*), void* cb_data); + +/** + * homa_free_shenango_trans - frees a transport created earlier with + * @homa_create_shenango_trans. + * @param trans: the transport to free + */ +extern void homa_free_shenango_trans(homa_trans trans); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/include/Homa/Utils/TransportPoller.h b/include/Homa/Utils/TransportPoller.h deleted file mode 100644 index 097600c..0000000 --- a/include/Homa/Utils/TransportPoller.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright (c) 2020, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#pragma once - -#include - -namespace Homa { - -/// Forward declaration. -class Transport; - -/** - * Provides a means to drive the execution of a transport through repeated - * calls to the poll() method. - * - * This class demonstrates a simple way to invoke the Homa::Transport APIs - * in a poll-based programming style. In practice, users will often need to - * invoke the Transport APIs in ways that fit their systems better. The Homa- - * Shenango integration provides a concrete example. - * - * This class is thread-safe; although calling poll() from multiple threads - * provides no performance benefit. - * - * @sa Homa/Shenango.h - */ -class TransportPoller { - public: - explicit TransportPoller(Transport* transport); - ~TransportPoller() = default; - void poll(); - - private: - void processPackets(); - - /// Transport instance whose execution is driven by this poller. - Transport* const transport; - - /// Caches the next cycle time that timeouts will need to rechecked. - std::atomic nextTimeoutCycles; -}; - -} // namespace Homa \ No newline at end of file diff --git a/src/CHoma.cc b/src/CHoma.cc index 91ef3b3..7fa88a3 100644 --- a/src/CHoma.cc +++ b/src/CHoma.cc @@ -14,9 +14,10 @@ */ #include "Homa/Bindings/CHoma.h" -#include "Homa/Homa.h" +#include "Homa/Core/Transport.h" using namespace Homa; +using Transport = Core::Transport; /// Shorthand for converting C-style Homa object handle types back to C++ types. #define deref(T, x) (*static_cast(x.p)) @@ -118,8 +119,8 @@ homa_outmsg_release(homa_outmsg out_msg) homa_trans homa_trans_create(homa_driver drv, homa_callbacks cbs, uint64_t id) { - unique_ptr trans = - Transport::create(&deref(Driver, drv), &deref(Callbacks, cbs), id); + unique_ptr trans = Transport::create( + &deref(Driver, drv), &deref(Transport::Callbacks, cbs), id); return homa_trans{trans.release()}; } @@ -149,6 +150,12 @@ homa_trans_id(homa_trans trans) return deref(Transport, trans).getId(); } +homa_driver homa_trans_get_drv(homa_trans trans) +{ + Driver* drv = deref(Transport, trans).getDriver(); + return homa_driver{drv}; +} + void homa_trans_proc(homa_trans trans, uintptr_t desc, void* payload, int32_t len, uint32_t src_ip) diff --git a/src/Homa.cc b/src/Homa.cc deleted file mode 100644 index 0f9a716..0000000 --- a/src/Homa.cc +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright (c) 2018-2019, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#include - -#include "TransportImpl.h" - -namespace Homa { - -Homa::unique_ptr -Transport::create(Driver* driver, Callbacks* callbacks, uint64_t transportId) -{ - Transport* transport = - new Core::TransportImpl(driver, callbacks, transportId); - return Homa::unique_ptr(transport); -} - -} // namespace Homa diff --git a/src/PollModeTransportImpl.cc b/src/PollModeTransportImpl.cc new file mode 100644 index 0000000..58c03ec --- /dev/null +++ b/src/PollModeTransportImpl.cc @@ -0,0 +1,140 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "PollModeTransportImpl.h" + +namespace Homa { + +Homa::unique_ptr +PollModeTransport::create(Driver* driver, uint64_t transportId) +{ + return Homa::unique_ptr( + new PollModeTransportImpl(driver, transportId)); +} + +/** + * Constructor. + * + * @param driver + * Driver with which this transport should send and receive packets. + * @param transportId + * This transport's unique identifier in the group of transports among + * which this transport will communicate. + */ +PollModeTransportImpl::PollModeTransportImpl(Driver* driver, + uint64_t transportId) + : callbacks(this) + , core(driver, &callbacks, transportId) + , nextTimeoutCycles(0) +{} + +/** + * Construct for unit testing. + */ +PollModeTransportImpl::PollModeTransportImpl(Driver* driver, + Core::Sender* sender, + Core::Receiver* receiver, + uint64_t transportId) + : callbacks(this) + , core(driver, &callbacks, sender, receiver, transportId) + , nextTimeoutCycles(0) +{} + +/// See Homa::PollModeTransport::alloc() +Homa::unique_ptr +PollModeTransportImpl::alloc(uint16_t port) +{ + return core.alloc(port); +} + +/// See Homa::PollModeTransport::free() +void +PollModeTransportImpl::free() +{ + // This instance must be allocated via new from PollModeTransport::create(). + delete this; +} + +/// See Homa::PollModeTransport::getDriver() +Driver* +PollModeTransportImpl::getDriver() +{ + return core.getDriver(); +} + +/// See Homa::PollModeTransport::getId() +uint64_t +PollModeTransportImpl::getId() +{ + return core.getId(); +} + +void +PollModeTransportImpl::poll() +{ + // Receive and dispatch incoming packets. + processPackets(); + + // Allow sender and receiver to make incremental progress. + uint64_t waitUntil; + core.trySend(&waitUntil); + core.trySendGrants(); + + if (PerfUtils::Cycles::rdtsc() >= nextTimeoutCycles.load()) { + uint64_t requestedTimeoutCycles = core.checkTimeouts(); + nextTimeoutCycles.store(requestedTimeoutCycles); + } +} + +/// See Homa::PollModeTransport::receive +Homa::unique_ptr +PollModeTransportImpl::receive() +{ + if (receiveQueue.empty()) { + return nullptr; + } + InMessage* message = receiveQueue.back(); + receiveQueue.pop_back(); + return Homa::unique_ptr(message); +} + +/** + * Helper method which receives a burst of incoming packets and process them + * through the transport protocol. Pulled out of PollModeTransportImpl::poll() + * to simplify unit testing. + */ +void +PollModeTransportImpl::processPackets() +{ + // Keep track of time spent doing active processing versus idle. + uint64_t cycles = PerfUtils::Cycles::rdtsc(); + + const int MAX_BURST = 32; + Driver::Packet packets[MAX_BURST]; + IpAddress srcAddrs[MAX_BURST]; + int numPackets = getDriver()->receivePackets(MAX_BURST, packets, srcAddrs); + for (int i = 0; i < numPackets; ++i) { + core.processPacket(&packets[i], srcAddrs[i]); + } + + cycles = PerfUtils::Cycles::rdtsc() - cycles; + if (numPackets > 0) { + Perf::counters.active_cycles.add(cycles); + } else { + Perf::counters.idle_cycles.add(cycles); + } +} + +} // namespace Homa diff --git a/src/PollModeTransportImpl.h b/src/PollModeTransportImpl.h new file mode 100644 index 0000000..0740b2d --- /dev/null +++ b/src/PollModeTransportImpl.h @@ -0,0 +1,83 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#pragma once + +#include +#include +#include "TransportImpl.h" + +namespace Homa { + +/** + * Internal implementation of Homa::PollModeTransport. + */ +class PollModeTransportImpl final : public PollModeTransport { + public: + explicit PollModeTransportImpl(Driver* driver, uint64_t transportId); + explicit PollModeTransportImpl(Driver* driver, Core::Sender* sender, + Core::Receiver* receiver, + uint64_t transportId); + virtual ~PollModeTransportImpl() = default; + Homa::unique_ptr alloc(uint16_t port) override; + void free() override; + Driver* getDriver() override; + uint64_t getId() override; + void poll() override; + Homa::unique_ptr receive() override; + + private: + /** + * Callbacks defined for the polling-based transport implementation. + */ + class PollModeCallbacks : public Core::Transport::Callbacks { + public: + explicit PollModeCallbacks(PollModeTransportImpl* owner) + : owner(owner) + {} + + ~PollModeCallbacks() override = default; + + bool deliver(uint16_t port, InMessage* message) override + { + (void)port; + SpinLock::Lock _(owner->mutex); + owner->receiveQueue.push_back(message); + return true; + } + + private: + PollModeTransportImpl* owner; + }; + + void processPackets(); + + /// Transport callbacks. + PollModeCallbacks callbacks; + + /// Core transport instance. + Core::TransportImpl core; + + /// Caches the next cycle time that timeouts will need to rechecked. + std::atomic nextTimeoutCycles; + + /// Monitor-style lock which protects the receive queue. + SpinLock mutex; + + /// Queue of completed incoming messages. + std::vector receiveQueue; +}; + +} // namespace Homa \ No newline at end of file diff --git a/src/PollModeTransportImplTest.cc b/src/PollModeTransportImplTest.cc new file mode 100644 index 0000000..286174c --- /dev/null +++ b/src/PollModeTransportImplTest.cc @@ -0,0 +1,190 @@ +/* Copyright (c) 2018-2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include +#include + +#include "Mock/MockDriver.h" +#include "Mock/MockReceiver.h" +#include "Mock/MockSender.h" +#include "PollModeTransportImpl.h" +#include "Protocol.h" +#include "TransportImpl.h" +#include "Tub.h" + +namespace Homa { +namespace Core { +namespace { + +using ::testing::_; +using ::testing::DoAll; +using ::testing::Eq; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SetArrayArgument; + +/** + * Defines a matcher EqPacket(p) to match two Driver::Packet* by their + * underlying packet buffer descriptors. + */ +MATCHER_P(EqPacket, p, "") +{ + return arg->descriptor == p->descriptor; +} + +class PollModeTransportImplTest : public ::testing::Test { + public: + PollModeTransportImplTest() + : mockDriver(allocMockDriver()) + , mockSender(new NiceMock(22, mockDriver, 0, 0)) + , mockReceiver(new NiceMock(mockDriver, 0, 0)) + , transport(new PollModeTransportImpl(mockDriver, mockSender, + mockReceiver, 22)) + { + PerfUtils::Cycles::mockTscValue = 10000; + } + + ~PollModeTransportImplTest() + { + delete transport; + delete mockDriver; + PerfUtils::Cycles::mockTscValue = 0; + } + + NiceMock* allocMockDriver() + { + auto driver = new NiceMock(); + ON_CALL(*driver, getBandwidth).WillByDefault(Return(8000)); + ON_CALL(*driver, getMaxPayloadSize).WillByDefault(Return(1024)); + return driver; + } + + NiceMock* mockDriver; + NiceMock* mockSender; + NiceMock* mockReceiver; + PollModeTransportImpl* transport; +}; + +TEST_F(PollModeTransportImplTest, poll) +{ + EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); + EXPECT_CALL(*mockSender, trySend).Times(1); + EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); + EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10000)); + EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); + + transport->poll(); + + EXPECT_EQ(10000U, transport->nextTimeoutCycles); + + EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); + EXPECT_CALL(*mockSender, trySend).Times(1); + EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); + EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10200)); + EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); + + transport->poll(); + + EXPECT_EQ(10100U, transport->nextTimeoutCycles); + + EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); + EXPECT_CALL(*mockSender, trySend).Times(1); + EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); + EXPECT_CALL(*mockSender, checkTimeouts).Times(0); + EXPECT_CALL(*mockReceiver, checkTimeouts).Times(0); + + transport->poll(); + + EXPECT_EQ(10100U, transport->nextTimeoutCycles); +} + +TEST_F(PollModeTransportImplTest, processPackets) +{ + char payload[8][1024]; + Homa::Driver::Packet packets[8]; + + // Set DATA packet + Homa::Mock::MockDriver::PacketBuf dataPacketBuf{payload[0]}; + Driver::Packet dataPacket = dataPacketBuf.toPacket(1024); + static_cast(dataPacket.payload) + ->common.opcode = Protocol::Packet::DATA; + packets[0] = dataPacket; + EXPECT_CALL(*mockReceiver, handleDataPacket(EqPacket(&packets[0]), _)); + + // Set GRANT packet + Homa::Mock::MockDriver::PacketBuf grantPacketBuf{payload[1]}; + Driver::Packet grantPacket = grantPacketBuf.toPacket(1024); + static_cast(grantPacket.payload) + ->common.opcode = Protocol::Packet::GRANT; + packets[1] = grantPacket; + EXPECT_CALL(*mockSender, handleGrantPacket(EqPacket(&packets[1]))); + + // Set DONE packet + Homa::Mock::MockDriver::PacketBuf donePacketBuf{payload[2]}; + Driver::Packet donePacket = donePacketBuf.toPacket(1024); + static_cast(donePacket.payload) + ->common.opcode = Protocol::Packet::DONE; + packets[2] = donePacket; + EXPECT_CALL(*mockSender, handleDonePacket(EqPacket(&packets[2]))); + + // Set RESEND packet + Homa::Mock::MockDriver::PacketBuf resendPacketBuf{payload[3]}; + Driver::Packet resendPacket = resendPacketBuf.toPacket(1024); + static_cast(resendPacket.payload) + ->common.opcode = Protocol::Packet::RESEND; + packets[3] = resendPacket; + EXPECT_CALL(*mockSender, handleResendPacket(EqPacket(&packets[3]))); + + // Set BUSY packet + Homa::Mock::MockDriver::PacketBuf busyPacketBuf{payload[4]}; + Driver::Packet busyPacket = busyPacketBuf.toPacket(1024); + static_cast(busyPacket.payload) + ->common.opcode = Protocol::Packet::BUSY; + packets[4] = busyPacket; + EXPECT_CALL(*mockReceiver, handleBusyPacket(EqPacket(&packets[4]))); + + // Set PING packet + Homa::Mock::MockDriver::PacketBuf pingPacketBuf{payload[5]}; + Driver::Packet pingPacket = pingPacketBuf.toPacket(1024); + static_cast(pingPacket.payload) + ->common.opcode = Protocol::Packet::PING; + packets[5] = pingPacket; + EXPECT_CALL(*mockReceiver, handlePingPacket(EqPacket(&packets[5]), _)); + + // Set UNKNOWN packet + Homa::Mock::MockDriver::PacketBuf unknownPacketBuf{payload[6]}; + Driver::Packet unknownPacket = unknownPacketBuf.toPacket(1024); + static_cast(unknownPacket.payload) + ->common.opcode = Protocol::Packet::UNKNOWN; + packets[6] = unknownPacket; + EXPECT_CALL(*mockSender, handleUnknownPacket(EqPacket(&packets[6]))); + + // Set ERROR packet + Homa::Mock::MockDriver::PacketBuf errorPacketBuf{payload[7]}; + Driver::Packet errorPacket = errorPacketBuf.toPacket(1024); + static_cast(errorPacket.payload) + ->common.opcode = Protocol::Packet::ERROR; + packets[7] = errorPacket; + EXPECT_CALL(*mockSender, handleErrorPacket(EqPacket(&packets[7]))); + + EXPECT_CALL(*mockDriver, receivePackets) + .WillOnce(DoAll(SetArrayArgument<1>(packets, packets + 8), Return(8))); + + transport->processPackets(); +} + +} // namespace +} // namespace Core +} // namespace Homa diff --git a/src/Receiver.cc b/src/Receiver.cc index 2db8ac1..689d7df 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -40,7 +40,7 @@ namespace Core { * Number of cycles of inactivity to wait between requesting retransmission * of un-received parts of a message. */ -Receiver::Receiver(Driver* driver, Callbacks* callbacks, +Receiver::Receiver(Driver* driver, Transport::Callbacks* callbacks, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles) : callbacks(callbacks) diff --git a/src/Receiver.h b/src/Receiver.h index 3f96bc7..1a425f1 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -16,8 +16,8 @@ #ifndef HOMA_CORE_RECEIVER_H #define HOMA_CORE_RECEIVER_H +#include #include -#include #include #include @@ -44,7 +44,7 @@ namespace Core { */ class Receiver { public: - explicit Receiver(Driver* driver, Callbacks* callbacks, + explicit Receiver(Driver* driver, Transport::Callbacks* callbacks, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles); @@ -467,7 +467,7 @@ class Receiver { void updateSchedule(Message* message, const SpinLock::Lock& lock); /// User-defined transport callbacks. Not owned by this class. - Callbacks* const callbacks; + Transport::Callbacks* const callbacks; /// Driver with which all packets will be sent and received. This driver /// is chosen by the Transport that owns this Sender. diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index 2fbf8a9..745076d 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -60,7 +60,7 @@ MATCHER_P(EqPacketLen, length, "") return arg->length == length; } -class MockCallbacks : public Callbacks { +class MockCallbacks : public Transport::Callbacks { public: explicit MockCallbacks() : receivedMessage() diff --git a/src/Sender.cc b/src/Sender.cc index b3f9fef..9eac00c 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -44,9 +44,9 @@ namespace Core { * Number of cycles of inactivity to wait between checking on the liveness * of an Sender::Message. */ -Sender::Sender(uint64_t transportId, Driver* driver, Callbacks* callbacks, - Policy::Manager* policyManager, uint64_t messageTimeoutCycles, - uint64_t pingIntervalCycles) +Sender::Sender(uint64_t transportId, Driver* driver, + Transport::Callbacks* callbacks, Policy::Manager* policyManager, + uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles) : transportId(transportId) , callbacks(callbacks) , driver(driver) diff --git a/src/Sender.h b/src/Sender.h index 7e7fe64..d637260 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -16,8 +16,8 @@ #ifndef HOMA_CORE_SENDER_H #define HOMA_CORE_SENDER_H +#include #include -#include #include #include @@ -41,7 +41,8 @@ namespace Core { */ class Sender { public: - explicit Sender(uint64_t transportId, Driver* driver, Callbacks* callbacks, + explicit Sender(uint64_t transportId, Driver* driver, + Transport::Callbacks* callbacks, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles); virtual ~Sender(); @@ -402,7 +403,7 @@ class Sender { const uint64_t transportId; /// User-defined transport callbacks; not owned by this class. - Callbacks* const callbacks; + Transport::Callbacks* const callbacks; /// Driver with which all packets will be sent and received. This driver /// is chosen by the Transport that owns this Sender. diff --git a/src/SenderTest.cc b/src/SenderTest.cc index 3c31969..fee2342 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -40,7 +40,7 @@ MATCHER_P(EqPacket, p, "") return arg->descriptor == p->descriptor; } -class MockCallbacks : public Callbacks { +class MockCallbacks : public Transport::Callbacks { public: explicit MockCallbacks() = default; diff --git a/src/Shenango.cc b/src/Shenango.cc index 40205c8..0c99b77 100644 --- a/src/Shenango.cc +++ b/src/Shenango.cc @@ -13,11 +13,11 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#include "Homa/Shenango.h" +#include "Homa/Transports/Shenango.h" #include #include "Debug.h" -#include "Homa/Homa.h" +#include "Homa/Core/Transport.h" using namespace Homa; @@ -66,6 +66,48 @@ DECLARE_SHENANGO_FUNC(uint32_t, homa_queued_bytes) DECLARE_SHENANGO_FUNC(void*, trans_table_lookup, uint8_t, SocketAddress, SocketAddress) +/** + * Callback functions specialized for the Shenango runtime. + */ +class ShenangoCallbacks final : Core::Transport::Callbacks { + public: + explicit ShenangoCallbacks(uint8_t proto, uint32_t local_ip, + std::function notify_send_ready) + : proto(proto) + , local_ip{local_ip} + , notify_send_ready(std::move(notify_send_ready)) + {} + + ~ShenangoCallbacks() override = default; + + bool deliver(uint16_t port, InMessage* message) override + { + // The socket table in Shenango is protected by an RCU. + shenango_rcu_read_lock(); + SocketAddress laddr = {local_ip, port}; + void* trans_entry = shenango_trans_table_lookup(proto, laddr, {}); + if (trans_entry) { + shenango_homa_mb_deliver(trans_entry, homa_inmsg{message}); + } + shenango_rcu_read_unlock(); + return trans_entry != nullptr; + } + + void notifySendReady() override + { + notify_send_ready(); + } + + /// Protocol number reserved for Homa; defined as IPPROTO_HOMA in Shenango. + const uint8_t proto; + + /// Local IP address of the transport. + const IpAddress local_ip; + + /// Callback function for notifySendReady(). + std::function notify_send_ready; +}; + /** * A simple shim driver that translates Driver operations to Shenango * functions. @@ -79,8 +121,11 @@ class ShenangoDriver final : public Driver { , local_ip{local_ip} , max_payload(max_payload) , link_speed(link_speed) + , callbacks() {} + ~ShenangoDriver() override = default; + Packet allocPacket() override { void* payload; @@ -132,7 +177,6 @@ class ShenangoDriver final : public Driver { return shenango_homa_queued_bytes(); } - private: /// Protocol number reserved for Homa; defined as IPPROTO_HOMA in Shenango. const uint8_t proto; @@ -144,75 +188,28 @@ class ShenangoDriver final : public Driver { /// Effective network bandwidth, in Mbits/second. const uint32_t link_speed; -}; - -homa_driver -homa_driver_create(uint8_t proto, uint32_t local_ip, uint32_t max_payload, - uint32_t link_speed) -{ - void* driver = new ShenangoDriver(proto, local_ip, max_payload, link_speed); - return homa_driver{driver}; -} -void -homa_driver_free(homa_driver drv) -{ - delete static_cast(drv.p); -} - -/** - * Shenango-defined callback functions for the transport. - */ -class ShenangoCallbacks final : Callbacks { - public: - explicit ShenangoCallbacks(uint8_t proto, uint32_t local_ip, - std::function notify_send_ready) - : proto(proto) - , local_ip{local_ip} - , notify_send_ready(std::move(notify_send_ready)) - {} - - ~ShenangoCallbacks() override = default; - - bool deliver(uint16_t port, InMessage* message) override - { - // The socket table in Shenango is protected by an RCU. - shenango_rcu_read_lock(); - SocketAddress laddr = {local_ip, port}; - void* trans_entry = shenango_trans_table_lookup(proto, laddr, {}); - if (trans_entry) { - shenango_homa_mb_deliver(trans_entry, homa_inmsg{message}); - } - shenango_rcu_read_unlock(); - return trans_entry != nullptr; - } - - void notifySendReady() override - { - notify_send_ready(); - } - - /// Protocol number reserved for Homa; defined as IPPROTO_HOMA in Shenango. - const uint8_t proto; - - /// Local IP address of the transport. - const IpAddress local_ip; - - /// Callback function for notifySendReady(). - std::function notify_send_ready; + /// Callback object. Piggybacked here to allow automatic destruction. + std::unique_ptr callbacks; }; -homa_callbacks -homa_callbacks_create(uint8_t proto, uint32_t local_ip, - void (*cb_send_ready)(void*), void* cb_data) +homa_trans +homa_create_shenango_trans(uint64_t id, uint8_t proto, uint32_t local_ip, + uint32_t max_payload, uint32_t link_speed, + void (*cb_send_ready)(void*), void* cb_data) { - void* cbs = new ShenangoCallbacks(proto, local_ip, - std::bind(cb_send_ready, cb_data)); - return homa_callbacks{cbs}; + ShenangoCallbacks* callbacks = new ShenangoCallbacks( + proto, local_ip, std::bind(cb_send_ready, cb_data)); + ShenangoDriver* drv = + new ShenangoDriver(proto, local_ip, max_payload, link_speed); + drv->callbacks.reset(callbacks); + return homa_trans_create(homa_driver{drv}, homa_callbacks{callbacks}, id); } void -homa_callbacks_free(homa_callbacks cbs) +homa_free_shenango_trans(homa_trans trans) { - delete static_cast(cbs.p); + homa_driver drv = homa_trans_get_drv(trans); + homa_trans_free(trans); + delete static_cast(drv.p); } diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index b2df218..dd1ceff 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -29,6 +29,14 @@ const uint64_t PING_INTERVAL_US = 3 * BASE_TIMEOUT_US; /// Microseconds to wait before performing retires on inbound messages. const uint64_t RESEND_INTERVAL_US = BASE_TIMEOUT_US; +Homa::unique_ptr +Transport::create(Driver* driver, Callbacks* callbacks, uint64_t transportId) +{ + Transport* transport = + new Core::TransportImpl(driver, callbacks, transportId); + return Homa::unique_ptr(transport); +} + /** * Construct an instance of a Homa-based transport. * @@ -69,11 +77,6 @@ TransportImpl::TransportImpl(Driver* driver, Callbacks* callbacks, , receiver(receiver) {} -/** - * TransportImpl Destructor. - */ -TransportImpl::~TransportImpl() = default; - /// See Homa::Transport::free() void TransportImpl::free() diff --git a/src/TransportImpl.h b/src/TransportImpl.h index ccbd637..f35d274 100644 --- a/src/TransportImpl.h +++ b/src/TransportImpl.h @@ -16,7 +16,7 @@ #ifndef HOMA_CORE_TRANSPORT_H #define HOMA_CORE_TRANSPORT_H -#include +#include #include #include @@ -41,7 +41,7 @@ class TransportImpl final : public Transport { uint64_t transportId); explicit TransportImpl(Driver* driver, Callbacks* callbacks, Sender* sender, Receiver* receiver, uint64_t transportId); - ~TransportImpl(); + virtual ~TransportImpl() = default; void free() override; Homa::unique_ptr alloc(uint16_t port) override; uint64_t checkTimeouts() override; @@ -50,13 +50,13 @@ class TransportImpl final : public Transport { bool trySendGrants() override; /// See Homa::Transport::getDriver() - virtual Driver* getDriver() + Driver* getDriver() override { return driver; } /// See Homa::Transport::getId() - virtual uint64_t getId() + uint64_t getId() override { return transportId; } diff --git a/src/TransportImplTest.cc b/src/TransportImplTest.cc index cc8b887..f3706f5 100644 --- a/src/TransportImplTest.cc +++ b/src/TransportImplTest.cc @@ -16,7 +16,6 @@ #include #include -#include "Homa/Utils/TransportPoller.h" #include "Mock/MockDriver.h" #include "Mock/MockReceiver.h" #include "Mock/MockSender.h" @@ -35,31 +34,19 @@ using ::testing::NiceMock; using ::testing::Return; using ::testing::SetArrayArgument; -/** - * Defines a matcher EqPacket(p) to match two Driver::Packet* by their - * underlying packet buffer descriptors. - */ -MATCHER_P(EqPacket, p, "") -{ - return arg->descriptor == p->descriptor; -} - class TransportImplTest : public ::testing::Test { public: TransportImplTest() : mockDriver(allocMockDriver()) , mockSender(new NiceMock(22, mockDriver, 0, 0)) , mockReceiver(new NiceMock(mockDriver, 0, 0)) - , transport(new TransportImpl(mockDriver, nullptr, mockSender, - mockReceiver, 22)) - , poller(transport) + , transport(mockDriver, nullptr, mockSender, mockReceiver, 22) { PerfUtils::Cycles::mockTscValue = 10000; } ~TransportImplTest() { - delete transport; delete mockDriver; PerfUtils::Cycles::mockTscValue = 0; } @@ -75,116 +62,12 @@ class TransportImplTest : public ::testing::Test { NiceMock* mockDriver; NiceMock* mockSender; NiceMock* mockReceiver; - TransportImpl* transport; - TransportPoller poller; + TransportImpl transport; }; -TEST_F(TransportImplTest, poll) +TEST_F(TransportImplTest, processPacket) { - EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); - EXPECT_CALL(*mockSender, trySend).Times(1); - EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); - EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10000)); - EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); - - poller.poll(); - - EXPECT_EQ(10000U, poller.nextTimeoutCycles); - - EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); - EXPECT_CALL(*mockSender, trySend).Times(1); - EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); - EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10200)); - EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); - - poller.poll(); - - EXPECT_EQ(10100U, poller.nextTimeoutCycles); - - EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); - EXPECT_CALL(*mockSender, trySend).Times(1); - EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); - EXPECT_CALL(*mockSender, checkTimeouts).Times(0); - EXPECT_CALL(*mockReceiver, checkTimeouts).Times(0); - - poller.poll(); - - EXPECT_EQ(10100U, poller.nextTimeoutCycles); -} - -TEST_F(TransportImplTest, processPackets) -{ - char payload[8][1024]; - Homa::Driver::Packet packets[8]; - - // Set DATA packet - Homa::Mock::MockDriver::PacketBuf dataPacketBuf{payload[0]}; - Driver::Packet dataPacket = dataPacketBuf.toPacket(1024); - static_cast(dataPacket.payload) - ->common.opcode = Protocol::Packet::DATA; - packets[0] = dataPacket; - EXPECT_CALL(*mockReceiver, handleDataPacket(EqPacket(&packets[0]), _)); - - // Set GRANT packet - Homa::Mock::MockDriver::PacketBuf grantPacketBuf{payload[1]}; - Driver::Packet grantPacket = grantPacketBuf.toPacket(1024); - static_cast(grantPacket.payload) - ->common.opcode = Protocol::Packet::GRANT; - packets[1] = grantPacket; - EXPECT_CALL(*mockSender, handleGrantPacket(EqPacket(&packets[1]))); - - // Set DONE packet - Homa::Mock::MockDriver::PacketBuf donePacketBuf{payload[2]}; - Driver::Packet donePacket = donePacketBuf.toPacket(1024); - static_cast(donePacket.payload) - ->common.opcode = Protocol::Packet::DONE; - packets[2] = donePacket; - EXPECT_CALL(*mockSender, handleDonePacket(EqPacket(&packets[2]))); - - // Set RESEND packet - Homa::Mock::MockDriver::PacketBuf resendPacketBuf{payload[3]}; - Driver::Packet resendPacket = resendPacketBuf.toPacket(1024); - static_cast(resendPacket.payload) - ->common.opcode = Protocol::Packet::RESEND; - packets[3] = resendPacket; - EXPECT_CALL(*mockSender, handleResendPacket(EqPacket(&packets[3]))); - - // Set BUSY packet - Homa::Mock::MockDriver::PacketBuf busyPacketBuf{payload[4]}; - Driver::Packet busyPacket = busyPacketBuf.toPacket(1024); - static_cast(busyPacket.payload) - ->common.opcode = Protocol::Packet::BUSY; - packets[4] = busyPacket; - EXPECT_CALL(*mockReceiver, handleBusyPacket(EqPacket(&packets[4]))); - - // Set PING packet - Homa::Mock::MockDriver::PacketBuf pingPacketBuf{payload[5]}; - Driver::Packet pingPacket = pingPacketBuf.toPacket(1024); - static_cast(pingPacket.payload) - ->common.opcode = Protocol::Packet::PING; - packets[5] = pingPacket; - EXPECT_CALL(*mockReceiver, handlePingPacket(EqPacket(&packets[5]), _)); - - // Set UNKNOWN packet - Homa::Mock::MockDriver::PacketBuf unknownPacketBuf{payload[6]}; - Driver::Packet unknownPacket = unknownPacketBuf.toPacket(1024); - static_cast(unknownPacket.payload) - ->common.opcode = Protocol::Packet::UNKNOWN; - packets[6] = unknownPacket; - EXPECT_CALL(*mockSender, handleUnknownPacket(EqPacket(&packets[6]))); - - // Set ERROR packet - Homa::Mock::MockDriver::PacketBuf errorPacketBuf{payload[7]}; - Driver::Packet errorPacket = errorPacketBuf.toPacket(1024); - static_cast(errorPacket.payload) - ->common.opcode = Protocol::Packet::ERROR; - packets[7] = errorPacket; - EXPECT_CALL(*mockSender, handleErrorPacket(EqPacket(&packets[7]))); - - EXPECT_CALL(*mockDriver, receivePackets) - .WillOnce(DoAll(SetArrayArgument<1>(packets, packets + 8), Return(8))); - - poller.processPackets(); + // tested sufficiently in PollModeTransportImpl tests } } // namespace diff --git a/src/TransportPoller.cc b/src/TransportPoller.cc deleted file mode 100644 index 0ff09b5..0000000 --- a/src/TransportPoller.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright (c) 2020, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#include "Homa/Utils/TransportPoller.h" -#include -#include "Homa/Homa.h" -#include "Perf.h" - -namespace Homa { - -/** - * Transport poller constructor. - * - * @param transport - * Transport instance driven by this poller. - */ -TransportPoller::TransportPoller(Transport* transport) - : transport(transport) - , nextTimeoutCycles(0) -{} - -/** - * Make incremental progress performing all Transport functionality. - * - * This method MUST be called for the Transport to make progress and should - * be called frequently to ensure timely progress. - */ -void -TransportPoller::poll() -{ - // Receive and dispatch incoming packets. - processPackets(); - - // Allow sender and receiver to make incremental progress. - uint64_t waitUntil; - transport->trySend(&waitUntil); - transport->trySendGrants(); - - if (PerfUtils::Cycles::rdtsc() >= nextTimeoutCycles.load()) { - uint64_t requestedTimeoutCycles = transport->checkTimeouts(); - nextTimeoutCycles.store(requestedTimeoutCycles); - } -} - -/** - * Helper method which receives a burst of incoming packets and process them - * through the transport protocol. Pulled out of TransportPoller::poll() to - * simplify unit testing. - */ -void -TransportPoller::processPackets() -{ - // Keep track of time spent doing active processing versus idle. - uint64_t cycles = PerfUtils::Cycles::rdtsc(); - - const int MAX_BURST = 32; - Driver::Packet packets[MAX_BURST]; - IpAddress srcAddrs[MAX_BURST]; - Driver* driver = transport->getDriver(); - int numPackets = driver->receivePackets(MAX_BURST, packets, srcAddrs); - for (int i = 0; i < numPackets; ++i) { - transport->processPacket(&packets[i], srcAddrs[i]); - } - - cycles = PerfUtils::Cycles::rdtsc() - cycles; - if (numPackets > 0) { - Perf::counters.active_cycles.add(cycles); - } else { - Perf::counters.idle_cycles.add(cycles); - } -} - -} // namespace Homa diff --git a/test/system_test.cc b/test/system_test.cc index 316b2ca..1e5b8c7 100644 --- a/test/system_test.cc +++ b/test/system_test.cc @@ -15,8 +15,7 @@ #include #include -#include -#include +#include #include #include @@ -54,54 +53,30 @@ struct MessageHeader { } __attribute__((packed)); struct Node { - class Callbacks : public Homa::Callbacks { - public: - explicit Callbacks(std::vector& receiveQueue) - : receiveQueue(receiveQueue) - {} - - bool deliver(uint16_t port, Homa::InMessage* message) override - { - if (port != SERVER_PORT) { - return false; - } - receiveQueue.push_back(message); - return true; - } - - std::vector& receiveQueue; - }; - explicit Node(uint64_t id) : id(id) - , callbacks(receiveQueue) , driver() - , transport(Homa::Transport::create(&driver, &callbacks, id)) + , transport(Homa::PollModeTransport::create(&driver, id)) , thread() - , receiveQueue() , run(false) {} const uint64_t id; - Callbacks callbacks; Homa::Drivers::Fake::FakeDriver driver; - Homa::unique_ptr transport; + Homa::unique_ptr transport; std::thread thread; - std::vector receiveQueue; std::atomic run; }; void serverMain(Node* server, std::vector addresses) { - Homa::TransportPoller poller(server->transport.get()); while (true) { if (server->run.load() == false) { break; } - Homa::unique_ptr message(server->receiveQueue.back()); - server->receiveQueue.pop_back(); + Homa::unique_ptr message(server->transport->receive()); if (message) { MessageHeader header; message->get(0, &header, sizeof(MessageHeader)); @@ -114,7 +89,7 @@ serverMain(Node* server, std::vector addresses) } message->acknowledge(); } - poller.poll(); + server->transport->poll(); } } @@ -134,7 +109,6 @@ clientMain(int count, int size, std::vector addresses) int numFailed = 0; Node client(1); - Homa::TransportPoller poller(client.transport.get()); for (int i = 0; i < count; ++i) { uint64_t id = nextId++; char payload[size]; @@ -166,7 +140,7 @@ clientMain(int count, int size, std::vector addresses) numFailed++; break; } - poller.poll(); + client.transport->poll(); } } return numFailed; From d2eaf6560b2b59aac0369fd81401dc582005b240 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Wed, 21 Oct 2020 00:36:49 -0700 Subject: [PATCH 10/15] some easy fixes based on code review --- include/Homa/Bindings/CHoma.h | 17 ++++++++++---- include/Homa/Core/Transport.h | 26 +++++++++++++-------- include/Homa/Driver.h | 5 ++-- include/Homa/Homa.h | 3 ++- include/Homa/OutMessageStatus.h | 5 +++- include/Homa/Transports/PollModeTransport.h | 12 +++++----- include/Homa/Transports/Shenango.h | 9 ++++--- src/PollModeTransportImpl.h | 7 ++++-- src/Protocol.h | 8 +++---- src/TransportImpl.cc | 13 ++++++----- src/TransportImpl.h | 6 ++--- test/Output.h | 5 +++- 12 files changed, 70 insertions(+), 46 deletions(-) diff --git a/include/Homa/Bindings/CHoma.h b/include/Homa/Bindings/CHoma.h index 79e6534..7a483b4 100644 --- a/include/Homa/Bindings/CHoma.h +++ b/include/Homa/Bindings/CHoma.h @@ -19,7 +19,8 @@ * Contains C-bindings for the Homa Transport API. */ -#pragma once +#ifndef HOMA_INCLUDE_HOMA_BINDINGS_CHOMA_H +#define HOMA_INCLUDE_HOMA_BINDINGS_CHOMA_H #include "Homa/OutMessageStatus.h" @@ -137,9 +138,9 @@ extern void homa_outmsg_send(homa_outmsg out_msg, uint32_t ip, uint16_t port); */ extern void homa_outmsg_release(homa_outmsg out_msg); -/* ============================ */ -/* Homa::Transport API */ -/* ============================ */ +/* ================================ */ +/* Homa::TransportBase API */ +/* ================================ */ /** * homa_trans_create - C-binding for Homa::TransportBase::create @@ -167,6 +168,10 @@ extern homa_driver homa_trans_get_drv(homa_trans trans); */ extern uint64_t homa_trans_id(homa_trans trans); +/* ================================ */ +/* Homa::Core::Transport API */ +/* ================================ */ + /** * homa_trans_check_timeouts - C-binding for Core::Transport::checkTimeouts */ @@ -190,4 +195,6 @@ extern bool homa_trans_try_grant(homa_trans trans); #ifdef __cplusplus } -#endif \ No newline at end of file +#endif + +#endif // HOMA_INCLUDE_HOMA_BINDINGS_CHOMA_H \ No newline at end of file diff --git a/include/Homa/Core/Transport.h b/include/Homa/Core/Transport.h index b042cf0..7b7292a 100644 --- a/include/Homa/Core/Transport.h +++ b/include/Homa/Core/Transport.h @@ -20,7 +20,8 @@ * Transport library should include this header. */ -#pragma once +#ifndef HOMA_INCLUDE_HOMA_CORE_TRANSPORT_H +#define HOMA_INCLUDE_HOMA_CORE_TRANSPORT_H #include @@ -56,15 +57,18 @@ class Transport : public TransportBase { * * Here are a few example use cases of this callback: *
    - *
  • Interaction with the user's thread scheduler: e.g., an - * application may want to block on receive until a message is - * delivered, so this method can be used to wake up blocking threads. - *
  • High-performance message dispatch: e.g., an application may - * choose to implement the message receive queue with a concurrent MPMC - * queue as opposed to a linked-list protected by a mutex;
  • - * Lightweight synchronization: e.g., the socket table that maps from - * port numbers to sockets is a read-mostly data structure, so lookup - * operations can benefit from synchronization schemes such as RCU. + *
  • + * Interaction with the user's thread scheduler: e.g., an application + * may want to block on receive until a message is delivered, so this + * method can be used to wake up blocking threads. + *
  • + * High-performance message dispatch: e.g., an application may choose + * to implement the message receive queue with a concurrent MPMC queue + * as opposed to a linked-list protected by a mutex; + *
  • + * Lightweight synchronization: e.g., the socket table that maps port + * numbers to sockets is a read-mostly data structure, so lookup + * operations can benefit from synchronization schemes such as RCU. *
* * @param port @@ -160,3 +164,5 @@ class Transport : public TransportBase { }; } // namespace Homa::Core + +#endif // HOMA_INCLUDE_HOMA_CORE_TRANSPORT_H \ No newline at end of file diff --git a/include/Homa/Driver.h b/include/Homa/Driver.h index 602f78e..173fe5a 100644 --- a/include/Homa/Driver.h +++ b/include/Homa/Driver.h @@ -19,7 +19,6 @@ #include #include "Homa/Exception.h" -#include "Homa/OutMessageStatus.h" namespace Homa { @@ -68,8 +67,8 @@ struct IpAddress final { static_assert(std::is_trivially_copyable()); /** - * Used by Homa::Transport to send and receive unreliable datagrams. Provides - * the interface to which all Driver implementations must conform. + * Used by Homa::Core::Transport to send and receive unreliable datagrams. + * Provides the interface to which all Driver implementations must conform. * * Implementations of this class should be thread-safe. */ diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index 50ebfa7..74e4228 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -24,6 +24,7 @@ #define HOMA_INCLUDE_HOMA_HOMA_H #include +#include namespace Homa { @@ -31,7 +32,7 @@ namespace Homa { * Shorthand for an std::unique_ptr with a customized deleter. * * This is used to implement the "borrow" semantics for interface classes like - * InMessage, OutMessage, and Socket; that is, a user can obtain pointers to + * InMessage, OutMessage, and Transport; that is, a user can obtain pointers to * these objects and use them to make function calls, but the user must always * return the objects back to the transport library eventually because the user * has no idea how to destruct the objects or reclaim memory. diff --git a/include/Homa/OutMessageStatus.h b/include/Homa/OutMessageStatus.h index 9957e7a..941ce02 100644 --- a/include/Homa/OutMessageStatus.h +++ b/include/Homa/OutMessageStatus.h @@ -13,7 +13,8 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#pragma once +#ifndef HOMA_INCLUDE_HOMA_OUTMESSAGESTATUS_H +#define HOMA_INCLUDE_HOMA_OUTMESSAGESTATUS_H /** * Defines the possible states of an OutMessage. @@ -30,3 +31,5 @@ enum homa_outmsg_status { COMPLETED, //< The message has been received and processed. FAILED, //< The message failed to be delivered and processed. }; + +#endif // HOMA_INCLUDE_HOMA_OUTMESSAGESTATUS_H \ No newline at end of file diff --git a/include/Homa/Transports/PollModeTransport.h b/include/Homa/Transports/PollModeTransport.h index 7bc224e..262c3d7 100644 --- a/include/Homa/Transports/PollModeTransport.h +++ b/include/Homa/Transports/PollModeTransport.h @@ -13,7 +13,8 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#pragma once +#ifndef HOMA_INCLUDE_HOMA_TRANSPORTS_POLLMODETRANSPORT_H +#define HOMA_INCLUDE_HOMA_TRANSPORTS_POLLMODETRANSPORT_H #include @@ -47,15 +48,14 @@ class PollModeTransport : public TransportBase { virtual void poll() = 0; /** - * Check for and return a Message sent to this Socket if available. + * Check for and return a Message sent to this transport if available. * - * @param blocking - * When set to true, this method should not return until a message - * arrives or the socket is shut down. * @return * Pointer to the received message, if any; otherwise, nullptr. */ virtual Homa::unique_ptr receive() = 0; }; -} // namespace Homa \ No newline at end of file +} // namespace Homa + +#endif // HOMA_INCLUDE_HOMA_TRANSPORTS_POLLMODETRANSPORT_H \ No newline at end of file diff --git a/include/Homa/Transports/Shenango.h b/include/Homa/Transports/Shenango.h index 369ea2c..0e7c930 100644 --- a/include/Homa/Transports/Shenango.h +++ b/include/Homa/Transports/Shenango.h @@ -26,9 +26,10 @@ * This file follows the Shenango coding style. */ -#pragma once +#ifndef HOMA_INCLUDE_HOMA_TRANSPORTS_SHENANGO_H +#define HOMA_INCLUDE_HOMA_TRANSPORTS_SHENANGO_H -#include "Homa/Bindings/CHoma.h" +#include #ifdef __cplusplus extern "C" { @@ -60,4 +61,6 @@ extern void homa_free_shenango_trans(homa_trans trans); #ifdef __cplusplus } -#endif \ No newline at end of file +#endif + +#endif // HOMA_INCLUDE_HOMA_TRANSPORTS_SHENANGO_H \ No newline at end of file diff --git a/src/PollModeTransportImpl.h b/src/PollModeTransportImpl.h index 0740b2d..ba0473e 100644 --- a/src/PollModeTransportImpl.h +++ b/src/PollModeTransportImpl.h @@ -13,7 +13,8 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#pragma once +#ifndef HOMA_POLLMODETRANSPORT_H +#define HOMA_POLLMODETRANSPORT_H #include #include @@ -80,4 +81,6 @@ class PollModeTransportImpl final : public PollModeTransport { std::vector receiveQueue; }; -} // namespace Homa \ No newline at end of file +} // namespace Homa + +#endif // HOMA_POLLMODETRANSPORT_H \ No newline at end of file diff --git a/src/Protocol.h b/src/Protocol.h index 55a34ac..fe2cec0 100644 --- a/src/Protocol.h +++ b/src/Protocol.h @@ -113,11 +113,9 @@ enum Opcode { * protocol version before interpreting the rest of the packet. */ struct HeaderPrefix { - uint16_t sport, - dport; ///< Transport layer (L4) source and destination ports - ///< in network byte order; only used by DataHeader. - uint8_t version; ///< The version of the protocol being used by this - ///< packet. + uint16_t sport; ///< Transport layer (L4) source and destination ports + uint16_t dport; ///< in network byte order; only used by DataHeader. + uint8_t version; ///< The version of the protocol being used by this packet /// HeaderPrefix constructor. HeaderPrefix(uint16_t sport, uint16_t dport, uint8_t version) diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index dd1ceff..2138de2 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -29,6 +29,7 @@ const uint64_t PING_INTERVAL_US = 3 * BASE_TIMEOUT_US; /// Microseconds to wait before performing retires on inbound messages. const uint64_t RESEND_INTERVAL_US = BASE_TIMEOUT_US; +/// See Homa::Core::Transport::create() Homa::unique_ptr Transport::create(Driver* driver, Callbacks* callbacks, uint64_t transportId) { @@ -77,7 +78,7 @@ TransportImpl::TransportImpl(Driver* driver, Callbacks* callbacks, , receiver(receiver) {} -/// See Homa::Transport::free() +/// See Homa::TransportBase::free() void TransportImpl::free() { @@ -90,7 +91,7 @@ TransportImpl::free() delete this; } -/// See Homa::Transport::alloc() +/// See Homa::TransportBase::alloc() Homa::unique_ptr TransportImpl::alloc(uint16_t port) { @@ -98,7 +99,7 @@ TransportImpl::alloc(uint16_t port) return unique_ptr(outMessage); } -/// See Homa::Transport::checkTimeouts() +/// See Homa::Core::Transport::checkTimeouts() uint64_t TransportImpl::checkTimeouts() { @@ -107,7 +108,7 @@ TransportImpl::checkTimeouts() return requestedTimeoutCycles; } -/// See Homa::Transport::processPacket() +/// See Homa::Core::Transport::processPacket() void TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) { @@ -152,14 +153,14 @@ TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) } } -/// See Homa::Transport::trySend() +/// See Homa::Core::Transport::trySend() bool TransportImpl::trySend(uint64_t* waitUntil) { return sender->trySend(waitUntil); } -/// See Homa::Transport::trySendGrants() +/// See Homa::Core::Transport::trySendGrants() bool TransportImpl::trySendGrants() { diff --git a/src/TransportImpl.h b/src/TransportImpl.h index f35d274..9e0ea03 100644 --- a/src/TransportImpl.h +++ b/src/TransportImpl.h @@ -33,7 +33,7 @@ namespace Homa::Core { /** - * Internal implementation of Homa::Transport. + * Internal implementation of Homa::Core::Transport. */ class TransportImpl final : public Transport { public: @@ -49,13 +49,13 @@ class TransportImpl final : public Transport { bool trySend(uint64_t* waitUntil) override; bool trySendGrants() override; - /// See Homa::Transport::getDriver() + /// See Homa::Core::Transport::getDriver() Driver* getDriver() override { return driver; } - /// See Homa::Transport::getId() + /// See Homa::TransportBase::getId() uint64_t getId() override { return transportId; diff --git a/test/Output.h b/test/Output.h index de8d740..467280e 100644 --- a/test/Output.h +++ b/test/Output.h @@ -13,7 +13,8 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ -#pragma once +#ifndef HOMA_TEST_OUTPUT_H +#define HOMA_TEST_OUTPUT_H #include #include @@ -118,3 +119,5 @@ basic(std::vector& times, const std::string description) } } // namespace Output + +#endif // HOMA_TEST_OUTPUT_H \ No newline at end of file From 0c0aab2ea27b9532f559521d01728c543a3095db Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Wed, 21 Oct 2020 15:57:55 -0700 Subject: [PATCH 11/15] minor --- include/Homa/Core/Transport.h | 9 ++++++++- include/Homa/Homa.h | 5 ----- src/PollModeTransportImpl.cc | 10 ++-------- src/PollModeTransportImpl.h | 1 - 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/include/Homa/Core/Transport.h b/include/Homa/Core/Transport.h index 7b7292a..b591b93 100644 --- a/include/Homa/Core/Transport.h +++ b/include/Homa/Core/Transport.h @@ -122,6 +122,11 @@ class Transport : public TransportBase { */ virtual uint64_t checkTimeouts() = 0; + /** + * Return the driver that this transport uses to send and receive packets. + */ + virtual Driver* getDriver() = 0; + /** * Handle an ingress packet by running it through the transport protocol * stack. @@ -154,7 +159,9 @@ class Transport : public TransportBase { * Attempt to grant to incoming messages according to the Homa protocol. * * This method must be called eagerly to allow the Transport to make - * progress toward receiving incoming messages. + * progress toward receiving incoming messages. For example, a user may + * invoke this method every time the transport finishes processing a batch + * of incoming packets. * * @return * True if the method has found some messages to grant; false, diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index 74e4228..ff9e8b9 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -288,11 +288,6 @@ class TransportBase { */ virtual Homa::unique_ptr alloc(uint16_t port) = 0; - /** - * Return the driver that this transport uses to send and receive packets. - */ - virtual Driver* getDriver() = 0; - /** * Return this transport's unique identifier. */ diff --git a/src/PollModeTransportImpl.cc b/src/PollModeTransportImpl.cc index 58c03ec..d2e4e6d 100644 --- a/src/PollModeTransportImpl.cc +++ b/src/PollModeTransportImpl.cc @@ -67,13 +67,6 @@ PollModeTransportImpl::free() delete this; } -/// See Homa::PollModeTransport::getDriver() -Driver* -PollModeTransportImpl::getDriver() -{ - return core.getDriver(); -} - /// See Homa::PollModeTransport::getId() uint64_t PollModeTransportImpl::getId() @@ -124,7 +117,8 @@ PollModeTransportImpl::processPackets() const int MAX_BURST = 32; Driver::Packet packets[MAX_BURST]; IpAddress srcAddrs[MAX_BURST]; - int numPackets = getDriver()->receivePackets(MAX_BURST, packets, srcAddrs); + Driver* driver = core.getDriver(); + int numPackets = driver->receivePackets(MAX_BURST, packets, srcAddrs); for (int i = 0; i < numPackets; ++i) { core.processPacket(&packets[i], srcAddrs[i]); } diff --git a/src/PollModeTransportImpl.h b/src/PollModeTransportImpl.h index ba0473e..f1a201d 100644 --- a/src/PollModeTransportImpl.h +++ b/src/PollModeTransportImpl.h @@ -34,7 +34,6 @@ class PollModeTransportImpl final : public PollModeTransport { virtual ~PollModeTransportImpl() = default; Homa::unique_ptr alloc(uint16_t port) override; void free() override; - Driver* getDriver() override; uint64_t getId() override; void poll() override; Homa::unique_ptr receive() override; From 7486abf0d41010bd76a4a6fe0ab3e57d5bca3ee8 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Wed, 21 Oct 2020 22:35:03 -0700 Subject: [PATCH 12/15] more --- CMakeLists.txt | 8 ++-- include/Homa/Core/Transport.h | 11 ++--- include/Homa/Transports/Shenango.h | 2 - src/{ => Bindings}/CHoma.cc | 0 src/Receiver.cc | 2 +- src/Sender.cc | 42 +++++++------------ src/Sender.h | 7 ++-- src/TransportImpl.cc | 6 +-- src/TransportImpl.h | 2 +- src/{ => Transports}/PollModeTransportImpl.cc | 3 +- src/{ => Transports}/PollModeTransportImpl.h | 2 +- .../PollModeTransportImplTest.cc | 0 src/{ => Transports}/Shenango.cc | 2 +- 13 files changed, 34 insertions(+), 53 deletions(-) rename src/{ => Bindings}/CHoma.cc (100%) rename src/{ => Transports}/PollModeTransportImpl.cc (98%) rename src/{ => Transports}/PollModeTransportImpl.h (98%) rename src/{ => Transports}/PollModeTransportImplTest.cc (100%) rename src/{ => Transports}/Shenango.cc (99%) diff --git a/CMakeLists.txt b/CMakeLists.txt index bbbda14..7df3cd3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,19 +72,19 @@ endif() ## lib Homa #################################################################### add_library(Homa + src/Bindings/CHoma.cc src/CodeLocation.cc - src/CHoma.cc src/Debug.cc src/Driver.cc src/Perf.cc src/Policy.cc - src/PollModeTransportImpl.cc src/Receiver.cc src/Sender.cc - src/Shenango.cc src/StringUtil.cc src/ThreadId.cc src/TransportImpl.cc + src/Transports/PollModeTransportImpl.cc + src/Transports/Shenango.cc src/Util.cc ) add_library(Homa::Homa ALIAS Homa) @@ -256,7 +256,6 @@ add_executable(unit_test src/IntrusiveTest.cc src/ObjectPoolTest.cc src/PolicyTest.cc - src/PollModeTransportImplTest.cc src/ReceiverTest.cc src/SenderTest.cc src/SpinLockTest.cc @@ -265,6 +264,7 @@ add_executable(unit_test src/ThreadIdTest.cc src/TimeoutTest.cc src/TransportImplTest.cc + src/Transports/PollModeTransportImplTest.cc src/TubTest.cc src/UtilTest.cc ) diff --git a/include/Homa/Core/Transport.h b/include/Homa/Core/Transport.h index b591b93..a3b6b4d 100644 --- a/include/Homa/Core/Transport.h +++ b/include/Homa/Core/Transport.h @@ -145,15 +145,12 @@ class Transport : public TransportBase { * This method must be called eagerly to allow the Transport to make * progress toward sending outgoing messages. * - * @param[out] waitUntil - * The rdtsc cycle time when this method should be called again - * (this allows the NIC to drain its transmit queue). Only set - * when this method returns true. * @return - * True if more packets are ready to be transmitted when the method - * returns; false, otherwise. + * The rdtsc cycle time when this method should be called again to + * transmit the rest of the packets (this allows the NIC to drain its + * transmit queue first), or zero if there is no more packets to send. */ - virtual bool trySend(uint64_t* waitUntil) = 0; + virtual uint64_t trySend() = 0; /** * Attempt to grant to incoming messages according to the Homa protocol. diff --git a/include/Homa/Transports/Shenango.h b/include/Homa/Transports/Shenango.h index 0e7c930..40d7b5d 100644 --- a/include/Homa/Transports/Shenango.h +++ b/include/Homa/Transports/Shenango.h @@ -22,8 +22,6 @@ * Shenango is an experimental operating system that aims to provide low tail * latency and high CPU efficiency simultaneously for servers in datacenters. * See for more information. - * - * This file follows the Shenango coding style. */ #ifndef HOMA_INCLUDE_HOMA_TRANSPORTS_SHENANGO_H diff --git a/src/CHoma.cc b/src/Bindings/CHoma.cc similarity index 100% rename from src/CHoma.cc rename to src/Bindings/CHoma.cc diff --git a/src/Receiver.cc b/src/Receiver.cc index 689d7df..c099eb3 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -720,7 +720,7 @@ Receiver::trySendGrants() uint64_t start_tsc = PerfUtils::Cycles::rdtsc(); // Fast path: skip if no message is waiting for grants - bool needGrants = !dontNeedGrants.test_and_set(); + bool needGrants = !dontNeedGrants.test_and_set(std::memory_order_acquire); if (!needGrants) { return false; } diff --git a/src/Sender.cc b/src/Sender.cc index 9eac00c..768b2eb 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -210,7 +210,7 @@ Sender::handleResendPacket(Driver::Packet* packet) // will never be overridden since the resend index will not exceed the // preset packetsGranted. info->priority = header->priority; - signalPacerThread(lock_queue); + signalSendReady(lock_queue); } if (index >= info->packetsSent) { @@ -292,7 +292,7 @@ Sender::handleGrantPacket(Driver::Packet* packet) // limit will never be overridden since the incomingGrantIndex will // not exceed the preset packetsGranted. info->priority = header->priority; - signalPacerThread(lock_queue); + signalSendReady(lock_queue); } } @@ -419,7 +419,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet) Intrusive::deprioritize( &sendQueue, &info->sendQueueNode, QueuedMessageInfo::ComparePriority()); - signalPacerThread(lock_queue); + signalSendReady(lock_queue); } } @@ -826,7 +826,7 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, sendQueue.push_front(&info->sendQueueNode); Intrusive::deprioritize(&sendQueue, &info->sendQueueNode, QueuedMessageInfo::ComparePriority()); - signalPacerThread(lock_queue); + signalSendReady(lock_queue); } } @@ -983,9 +983,8 @@ Sender::checkPingTimeouts() } /** - * Attempt to wake up the pacer thread that is responsible for calling trySend() - * repeatedly, if it's currently blocked waiting for packets to become ready to - * be sent. + * Signal the thread which is responsible for calling trySend() that some + * packets just become ready to be sent. * * This method is called when new GRANTs arrive, when new outgoing messages * appear, and when retransmission is requested. @@ -994,30 +993,16 @@ Sender::checkPingTimeouts() * Reminder to hold the Sender::queueMutex during this call. */ void -Sender::signalPacerThread(const SpinLock::Lock& lockHeld) +Sender::signalSendReady(const SpinLock::Lock& lockHeld) { (void)lockHeld; sendReady = true; callbacks->notifySendReady(); } -/** - * Attempt to send out packets for any messages with unscheduled/granted bytes - * in a way that limits queue buildup in the NIC. - * - * This method must be called eagerly to allow the Sender to make progress - * toward sending outgoing messages. - * - * @param[out] waitUntil - * Time to wait before next call, in microseconds, in order to allow - * the NIC transmit queue to drain. Only set when this method returns - * true. - * @return - * True if more packets are ready to be transmitted when the method - * returns; false, otherwise. - */ -bool -Sender::trySend(uint64_t* waitUntil) +/// See Homa::Core::Transport::trySend() +uint64_t +Sender::trySend() { uint64_t start_tsc = PerfUtils::Cycles::rdtsc(); bool idle = true; @@ -1025,7 +1010,7 @@ Sender::trySend(uint64_t* waitUntil) // Skip when there are no messages to send. SpinLock::UniqueLock lock_queue(queueMutex); if (!sendReady) { - return false; + return 0; } /* The goal is to send out packets for messages that have bytes that have @@ -1038,6 +1023,7 @@ Sender::trySend(uint64_t* waitUntil) // Optimistically assume we will finish sending every granted packet this // round; we will set again sendReady if it turns out we don't finish. sendReady = false; + uint64_t waitUntil = 0; auto it = sendQueue.begin(); while (it != sendQueue.end()) { Message& message = *it; @@ -1084,7 +1070,7 @@ Sender::trySend(uint64_t* waitUntil) // Compute how much time the driver needs to drain its queue, // then schedule to wake up a bit earlier to avoid blowing bubbles. static const uint64_t us = PerfUtils::Cycles::fromMicroseconds(1); - *waitUntil = + waitUntil = PerfUtils::Cycles::rdtsc() - 1 * us + queuedBytesEstimate * DRIVER_CYCLES_TO_DRAIN_1MB / 1000000; break; @@ -1097,7 +1083,7 @@ Sender::trySend(uint64_t* waitUntil) } else { Perf::counters.idle_cycles.add(elapsed_cycles); } - return sendReady; + return waitUntil; } } // namespace Core diff --git a/src/Sender.h b/src/Sender.h index d637260..6c6f5b4 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -54,7 +54,7 @@ class Sender { virtual void handleUnknownPacket(Driver::Packet* packet); virtual void handleErrorPacket(Driver::Packet* packet); virtual uint64_t checkTimeouts(); - virtual bool trySend(uint64_t* waitUntil); + virtual uint64_t trySend(); private: /// Forward declarations @@ -393,7 +393,7 @@ class Sender { void sendMessage(Sender::Message* message, SocketAddress destination, Message::Options options = Message::Options::NONE); - void signalPacerThread(const SpinLock::Lock& lockHeld); + void signalSendReady(const SpinLock::Lock& lockHeld); void cancelMessage(Sender::Message* message); void dropMessage(Sender::Message* message); uint64_t checkMessageTimeouts(); @@ -429,7 +429,8 @@ class Sender { /// Hint whether there are messages ready to be sent (i.e. there are granted /// messages in the sendQueue. Encoded into a single bool so that checking - /// if there is work to do is more efficient. + /// if there is work to do is more efficient. Access to this field is + /// protected by queueMutex. bool sendReady; /// A list of outbound messages that have unsent packets. Messages are kept diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index 2138de2..b249b22 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -154,10 +154,10 @@ TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) } /// See Homa::Core::Transport::trySend() -bool -TransportImpl::trySend(uint64_t* waitUntil) +uint64_t +TransportImpl::trySend() { - return sender->trySend(waitUntil); + return sender->trySend(); } /// See Homa::Core::Transport::trySendGrants() diff --git a/src/TransportImpl.h b/src/TransportImpl.h index 9e0ea03..f083375 100644 --- a/src/TransportImpl.h +++ b/src/TransportImpl.h @@ -46,7 +46,7 @@ class TransportImpl final : public Transport { Homa::unique_ptr alloc(uint16_t port) override; uint64_t checkTimeouts() override; void processPacket(Driver::Packet* packet, IpAddress source) override; - bool trySend(uint64_t* waitUntil) override; + uint64_t trySend() override; bool trySendGrants() override; /// See Homa::Core::Transport::getDriver() diff --git a/src/PollModeTransportImpl.cc b/src/Transports/PollModeTransportImpl.cc similarity index 98% rename from src/PollModeTransportImpl.cc rename to src/Transports/PollModeTransportImpl.cc index d2e4e6d..f2c013c 100644 --- a/src/PollModeTransportImpl.cc +++ b/src/Transports/PollModeTransportImpl.cc @@ -81,8 +81,7 @@ PollModeTransportImpl::poll() processPackets(); // Allow sender and receiver to make incremental progress. - uint64_t waitUntil; - core.trySend(&waitUntil); + core.trySend(); core.trySendGrants(); if (PerfUtils::Cycles::rdtsc() >= nextTimeoutCycles.load()) { diff --git a/src/PollModeTransportImpl.h b/src/Transports/PollModeTransportImpl.h similarity index 98% rename from src/PollModeTransportImpl.h rename to src/Transports/PollModeTransportImpl.h index f1a201d..798eecd 100644 --- a/src/PollModeTransportImpl.h +++ b/src/Transports/PollModeTransportImpl.h @@ -18,7 +18,7 @@ #include #include -#include "TransportImpl.h" +#include "../TransportImpl.h" namespace Homa { diff --git a/src/PollModeTransportImplTest.cc b/src/Transports/PollModeTransportImplTest.cc similarity index 100% rename from src/PollModeTransportImplTest.cc rename to src/Transports/PollModeTransportImplTest.cc diff --git a/src/Shenango.cc b/src/Transports/Shenango.cc similarity index 99% rename from src/Shenango.cc rename to src/Transports/Shenango.cc index 0c99b77..423a782 100644 --- a/src/Shenango.cc +++ b/src/Transports/Shenango.cc @@ -16,7 +16,7 @@ #include "Homa/Transports/Shenango.h" #include -#include "Debug.h" +#include "../Debug.h" #include "Homa/Core/Transport.h" using namespace Homa; From 2ddefe8cae0ca67948712b637f5e0c9e1021c8af Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Wed, 21 Oct 2020 22:35:47 -0700 Subject: [PATCH 13/15] signalPacerThread & Callbacks::deliver --- include/Homa/Core/Transport.h | 3 +- src/Receiver.cc | 6 ++- src/ReceiverTest.cc | 8 ++-- src/Sender.cc | 57 +++++++++++++++---------- src/Sender.h | 12 +++--- src/Transports/PollModeTransportImpl.cc | 4 +- src/Transports/PollModeTransportImpl.h | 7 +-- src/Transports/Shenango.cc | 5 ++- 8 files changed, 60 insertions(+), 42 deletions(-) diff --git a/include/Homa/Core/Transport.h b/include/Homa/Core/Transport.h index a3b6b4d..f7b06d8 100644 --- a/include/Homa/Core/Transport.h +++ b/include/Homa/Core/Transport.h @@ -78,7 +78,8 @@ class Transport : public TransportBase { * @return * True if the message is delivered successfully; false, otherwise. */ - virtual bool deliver(uint16_t port, InMessage* message) = 0; + virtual bool deliver(uint16_t port, + Homa::unique_ptr message) = 0; /** * Invoked when some packets just became ready to be sent (and there was diff --git a/src/Receiver.cc b/src/Receiver.cc index c099eb3..2c266ba 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -164,10 +164,12 @@ Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) // All message packets have been received. message->setState(Message::State::COMPLETED); bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); + lock_bucket.destroy(); + uint16_t dport = be16toh(header->common.prefix.dport); - bool success = callbacks->deliver(dport, message); + bool success = + callbacks->deliver(dport, Homa::unique_ptr(message)); if (!success) { - lock_bucket.destroy(); ERROR("Unable to deliver the message; message dropped"); dropMessage(message); } diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index 745076d..0450437 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -66,16 +66,16 @@ class MockCallbacks : public Transport::Callbacks { : receivedMessage() {} - bool deliver(uint16_t port, Homa::InMessage* message) override + bool deliver(uint16_t port, Homa::unique_ptr message) override { if (port != 60001) { return false; } - receivedMessage = message; + receivedMessage = std::move(message); return true; } - Homa::InMessage* receivedMessage; + Homa::unique_ptr receivedMessage; }; class ReceiverTest : public ::testing::Test { @@ -236,7 +236,7 @@ TEST_F(ReceiverTest, handleDataPacket) EXPECT_EQ(4U, message->numPackets); EXPECT_EQ(0U, info->bytesRemaining); EXPECT_EQ(Receiver::Message::State::COMPLETED, message->state); - EXPECT_EQ(message, mockCallbacks.receivedMessage); + EXPECT_EQ(message, mockCallbacks.receivedMessage.get()); Mock::VerifyAndClearExpectations(&mockDriver); // ------------------------------------------------------------------------- diff --git a/src/Sender.cc b/src/Sender.cc index 768b2eb..38deb27 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -187,7 +187,8 @@ Sender::handleResendPacket(Driver::Packet* packet) bucket->messageTimeouts.setTimeout(&message->messageTimeout); bucket->pingTimeouts.setTimeout(&message->pingTimeout); - SpinLock::Lock lock_queue(queueMutex); + bool notifySendReady = false; + SpinLock::UniqueLock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; // Check if RESEND request is out of range. @@ -210,7 +211,8 @@ Sender::handleResendPacket(Driver::Packet* packet) // will never be overridden since the resend index will not exceed the // preset packetsGranted. info->priority = header->priority; - signalSendReady(lock_queue); + sendReady = true; + notifySendReady = true; } if (index >= info->packetsSent) { @@ -235,6 +237,12 @@ Sender::handleResendPacket(Driver::Packet* packet) } } + // Only invoke the callback after unlocking queueMutex. + lock_queue.unlock(); + if (notifySendReady) { + callbacks->notifySendReady(); + } + driver->releasePackets(packet, 1); } @@ -263,6 +271,7 @@ Sender::handleGrantPacket(Driver::Packet* packet) bucket->messageTimeouts.setTimeout(&message->messageTimeout); bucket->pingTimeouts.setTimeout(&message->pingTimeout); + bool notifySendReady = false; if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -292,10 +301,16 @@ Sender::handleGrantPacket(Driver::Packet* packet) // limit will never be overridden since the incomingGrantIndex will // not exceed the preset packetsGranted. info->priority = header->priority; - signalSendReady(lock_queue); + sendReady = true; + notifySendReady = true; } } + // Only invoke the callback after unlocking queueMutex. + if (notifySendReady) { + callbacks->notifySendReady(); + } + driver->releasePackets(packet, 1); } @@ -388,6 +403,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet) bucket->pingTimeouts.setTimeout(&message->pingTimeout); assert(message->numPackets > 0); + bool notifySendReady = false; if (message->numPackets == 1) { // If there is only one packet in the message, send it right away. Driver::Packet* dataPacket = message->getPacket(0); @@ -419,7 +435,13 @@ Sender::handleUnknownPacket(Driver::Packet* packet) Intrusive::deprioritize( &sendQueue, &info->sendQueueNode, QueuedMessageInfo::ComparePriority()); - signalSendReady(lock_queue); + sendReady = true; + notifySendReady = true; + } + + // Only invoke the callback after unlocking queueMutex. + if (notifySendReady) { + callbacks->notifySendReady(); } } @@ -801,6 +823,7 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, bucket->messageTimeouts.setTimeout(&message->messageTimeout); bucket->pingTimeouts.setTimeout(&message->pingTimeout); + bool notifySendReady = false; assert(message->numPackets > 0); if (message->numPackets == 1) { // If there is only one packet in the message, send it right away. @@ -826,7 +849,13 @@ Sender::sendMessage(Sender::Message* message, SocketAddress destination, sendQueue.push_front(&info->sendQueueNode); Intrusive::deprioritize(&sendQueue, &info->sendQueueNode, QueuedMessageInfo::ComparePriority()); - signalSendReady(lock_queue); + sendReady = true; + notifySendReady = true; + } + + // Only invoke the callback after unlocking queueMutex. + if (notifySendReady) { + callbacks->notifySendReady(); } } @@ -982,24 +1011,6 @@ Sender::checkPingTimeouts() return globalNextTimeout; } -/** - * Signal the thread which is responsible for calling trySend() that some - * packets just become ready to be sent. - * - * This method is called when new GRANTs arrive, when new outgoing messages - * appear, and when retransmission is requested. - * - * @param lockHeld - * Reminder to hold the Sender::queueMutex during this call. - */ -void -Sender::signalSendReady(const SpinLock::Lock& lockHeld) -{ - (void)lockHeld; - sendReady = true; - callbacks->notifySendReady(); -} - /// See Homa::Core::Transport::trySend() uint64_t Sender::trySend() diff --git a/src/Sender.h b/src/Sender.h index 6c6f5b4..22d5864 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -393,7 +393,6 @@ class Sender { void sendMessage(Sender::Message* message, SocketAddress destination, Message::Options options = Message::Options::NONE); - void signalSendReady(const SpinLock::Lock& lockHeld); void cancelMessage(Sender::Message* message); void dropMessage(Sender::Message* message); uint64_t checkMessageTimeouts(); @@ -402,7 +401,8 @@ class Sender { /// Transport identifier. const uint64_t transportId; - /// User-defined transport callbacks; not owned by this class. + /// User-defined transport callbacks; not owned by this class. As a general + /// rule, one should not hold any locks when invoking a callback. Transport::Callbacks* const callbacks; /// Driver with which all packets will be sent and received. This driver @@ -427,9 +427,11 @@ class Sender { /// Protects the sendQueue and sendReady. SpinLock queueMutex; - /// Hint whether there are messages ready to be sent (i.e. there are granted - /// messages in the sendQueue. Encoded into a single bool so that checking - /// if there is work to do is more efficient. Access to this field is + /// Hint whether there are messages ready to be sent (i.e. granted messages + /// in the sendQueue). Encoded into a single bool so that checking if there + /// is work to do is more efficient. This bool can be cleared by trySend() + /// and set to true when new GRANTs arrive, when new outgoing messages + /// appear, and when retransmission is requested. Access to this field is /// protected by queueMutex. bool sendReady; diff --git a/src/Transports/PollModeTransportImpl.cc b/src/Transports/PollModeTransportImpl.cc index f2c013c..890e9ff 100644 --- a/src/Transports/PollModeTransportImpl.cc +++ b/src/Transports/PollModeTransportImpl.cc @@ -97,9 +97,9 @@ PollModeTransportImpl::receive() if (receiveQueue.empty()) { return nullptr; } - InMessage* message = receiveQueue.back(); + Homa::unique_ptr message = std::move(receiveQueue.back()); receiveQueue.pop_back(); - return Homa::unique_ptr(message); + return message; } /** diff --git a/src/Transports/PollModeTransportImpl.h b/src/Transports/PollModeTransportImpl.h index 798eecd..d614c0c 100644 --- a/src/Transports/PollModeTransportImpl.h +++ b/src/Transports/PollModeTransportImpl.h @@ -50,11 +50,12 @@ class PollModeTransportImpl final : public PollModeTransport { ~PollModeCallbacks() override = default; - bool deliver(uint16_t port, InMessage* message) override + bool deliver(uint16_t port, + Homa::unique_ptr message) override { (void)port; SpinLock::Lock _(owner->mutex); - owner->receiveQueue.push_back(message); + owner->receiveQueue.push_back(std::move(message)); return true; } @@ -77,7 +78,7 @@ class PollModeTransportImpl final : public PollModeTransport { SpinLock mutex; /// Queue of completed incoming messages. - std::vector receiveQueue; + std::vector> receiveQueue; }; } // namespace Homa diff --git a/src/Transports/Shenango.cc b/src/Transports/Shenango.cc index 423a782..96d3320 100644 --- a/src/Transports/Shenango.cc +++ b/src/Transports/Shenango.cc @@ -80,14 +80,15 @@ class ShenangoCallbacks final : Core::Transport::Callbacks { ~ShenangoCallbacks() override = default; - bool deliver(uint16_t port, InMessage* message) override + bool deliver(uint16_t port, Homa::unique_ptr message) override { // The socket table in Shenango is protected by an RCU. shenango_rcu_read_lock(); SocketAddress laddr = {local_ip, port}; void* trans_entry = shenango_trans_table_lookup(proto, laddr, {}); if (trans_entry) { - shenango_homa_mb_deliver(trans_entry, homa_inmsg{message}); + shenango_homa_mb_deliver(trans_entry, + homa_inmsg{message.release()}); } shenango_rcu_read_unlock(); return trans_entry != nullptr; From a447cc1a5d23428cf3a7915174b19ce36fa8c8c3 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Wed, 21 Oct 2020 22:43:51 -0700 Subject: [PATCH 14/15] change signature of Driver::allocPacket --- include/Homa/Driver.h | 6 +++++- include/Homa/Drivers/DPDK/DpdkDriver.h | 2 +- include/Homa/Drivers/Fake/FakeDriver.h | 2 +- src/ControlPacket.h | 3 ++- src/Drivers/DPDK/DpdkDriver.cc | 6 +++--- src/Drivers/DPDK/DpdkDriverImpl.cc | 6 +++--- src/Drivers/DPDK/DpdkDriverImpl.h | 2 +- src/Drivers/Fake/FakeDriver.cc | 4 ++-- src/Drivers/Fake/FakeDriverTest.cc | 5 +++-- src/Sender.cc | 2 +- src/Transports/Shenango.cc | 2 +- test/dpdk_test.cc | 6 ++++-- 12 files changed, 27 insertions(+), 19 deletions(-) diff --git a/include/Homa/Driver.h b/include/Homa/Driver.h index 173fe5a..ef32e99 100644 --- a/include/Homa/Driver.h +++ b/include/Homa/Driver.h @@ -101,8 +101,12 @@ class Driver { * Allocate a new Packet object from the Driver's pool of resources. The * caller must eventually release the packet by passing it to a call to * releasePacket(). + * + * @param[out] packet + * Set to the description of the allocated packet when the method + * returns. */ - virtual Packet allocPacket() = 0; + virtual void allocPacket(Packet* packet) = 0; /** * Send a packet over the network. diff --git a/include/Homa/Drivers/DPDK/DpdkDriver.h b/include/Homa/Drivers/DPDK/DpdkDriver.h index fbb3e2c..fd2dd85 100644 --- a/include/Homa/Drivers/DPDK/DpdkDriver.h +++ b/include/Homa/Drivers/DPDK/DpdkDriver.h @@ -119,7 +119,7 @@ class DpdkDriver : public Driver { virtual ~DpdkDriver(); /// See Driver::allocPacket() - virtual Packet allocPacket(); + virtual void allocPacket(Packet* packet); /// See Driver::sendPacket() virtual void sendPacket(Packet* packet, IpAddress destination, diff --git a/include/Homa/Drivers/Fake/FakeDriver.h b/include/Homa/Drivers/Fake/FakeDriver.h index 80e06bb..bba1fc2 100644 --- a/include/Homa/Drivers/Fake/FakeDriver.h +++ b/include/Homa/Drivers/Fake/FakeDriver.h @@ -119,7 +119,7 @@ class FakeDriver : public Driver { */ virtual ~FakeDriver(); - virtual Packet allocPacket(); + virtual void allocPacket(Packet* packet); virtual void sendPacket(Packet* packet, IpAddress destination, int priority); virtual uint32_t receivePackets(uint32_t maxPackets, diff --git a/src/ControlPacket.h b/src/ControlPacket.h index f8d71c9..9c557ef 100644 --- a/src/ControlPacket.h +++ b/src/ControlPacket.h @@ -39,7 +39,8 @@ template void send(Driver* driver, IpAddress address, Args&&... args) { - Driver::Packet packet = driver->allocPacket(); + Driver::Packet packet; + driver->allocPacket(&packet); new (packet.payload) PacketHeaderType(static_cast(args)...); packet.length = sizeof(PacketHeaderType); Perf::counters.tx_bytes.add(packet.length); diff --git a/src/Drivers/DPDK/DpdkDriver.cc b/src/Drivers/DPDK/DpdkDriver.cc index a6cc48d..16a7016 100644 --- a/src/Drivers/DPDK/DpdkDriver.cc +++ b/src/Drivers/DPDK/DpdkDriver.cc @@ -38,10 +38,10 @@ DpdkDriver::DpdkDriver(const char* ifname, NoEalInit _, DpdkDriver::~DpdkDriver() = default; /// See Driver::allocPacket() -Driver::Packet -DpdkDriver::allocPacket() +void +DpdkDriver::allocPacket(Packet* packet) { - return pImpl->allocPacket(); + return pImpl->allocPacket(packet); } /// See Driver::sendPacket() diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index 18daf1c..a797f18 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -184,8 +184,8 @@ DpdkDriver::Impl::~Impl() } // See Driver::allocPacket() -Driver::Packet -DpdkDriver::Impl::allocPacket() +void +DpdkDriver::Impl::allocPacket(Driver::Packet* packet) { PacketBuf* packetBuf = _allocMbufPacket(); if (unlikely(packetBuf == nullptr)) { @@ -194,7 +194,7 @@ DpdkDriver::Impl::allocPacket() packetBuf = packetPool.construct(buf); NOTICE("OverflowBuffer used."); } - return packetBuf->toPacket(0); + *packet = packetBuf->toPacket(0); } // See Driver::sendPacket() diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 819feb9..ac9188c 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -136,7 +136,7 @@ class DpdkDriver::Impl { virtual ~Impl(); // Interface Methods - Driver::Packet allocPacket(); + void allocPacket(Driver::Packet* packet); void sendPacket(Driver::Packet* packet, IpAddress destination, int priority); void cork(); diff --git a/src/Drivers/Fake/FakeDriver.cc b/src/Drivers/Fake/FakeDriver.cc index 16fa12e..80a8339 100644 --- a/src/Drivers/Fake/FakeDriver.cc +++ b/src/Drivers/Fake/FakeDriver.cc @@ -181,8 +181,8 @@ FakeDriver::~FakeDriver() /** * See Driver::allocPacket() */ -Driver::Packet -FakeDriver::allocPacket() +void +FakeDriver::allocPacket(Packet* packet) { FakePacket* fakePacket = new FakePacket(); return fakePacket->toPacket(); diff --git a/src/Drivers/Fake/FakeDriverTest.cc b/src/Drivers/Fake/FakeDriverTest.cc index b645d68..2f02f07 100644 --- a/src/Drivers/Fake/FakeDriverTest.cc +++ b/src/Drivers/Fake/FakeDriverTest.cc @@ -37,7 +37,8 @@ TEST(FakeDriverTest, allocPacket) { // allocPacket doesn't do much so we just need to make sure we can call it. FakeDriver driver; - Driver::Packet packet = driver.allocPacket(); + Driver::Packet packet; + driver.allocPacket(&packet); } TEST(FakeDriverTest, sendPackets) @@ -49,7 +50,7 @@ TEST(FakeDriverTest, sendPackets) IpAddress destinations[4]; int prio[4]; for (int i = 0; i < 4; ++i) { - packets[i] = driver1.allocPacket(); + driver1.allocPacket(&packets[i]); destinations[i] = driver2.getLocalAddress(); prio[i] = i; } diff --git a/src/Sender.cc b/src/Sender.cc index 38deb27..43c7600 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -747,7 +747,7 @@ Driver::Packet* Sender::Message::getOrAllocPacket(size_t index) { if (!occupied.test(index)) { - packets[index] = driver->allocPacket(); + driver->allocPacket(&packets[index]); occupied.set(index); numPackets++; // TODO(cstlee): A Message probably shouldn't be in charge of setting diff --git a/src/Transports/Shenango.cc b/src/Transports/Shenango.cc index 96d3320..a5a8a2d 100644 --- a/src/Transports/Shenango.cc +++ b/src/Transports/Shenango.cc @@ -127,7 +127,7 @@ class ShenangoDriver final : public Driver { ~ShenangoDriver() override = default; - Packet allocPacket() override + void allocPacket(Packet* packet) override { void* payload; void* mbuf = shenango_homa_tx_alloc_mbuf(&payload); diff --git a/test/dpdk_test.cc b/test/dpdk_test.cc index f1972a9..38c9bcc 100644 --- a/test/dpdk_test.cc +++ b/test/dpdk_test.cc @@ -61,7 +61,8 @@ main(int argc, char* argv[]) do { receivedPackets = driver.receivePackets(10, incoming, srcAddrs); } while (receivedPackets == 0); - Homa::Driver::Packet pong = driver.allocPacket(); + Homa::Driver::Packet pong; + driver.allocPacket(&pong); pong.length = 100; driver.sendPacket(&pong, srcAddrs[0], 0); driver.releasePackets(incoming, receivedPackets); @@ -74,7 +75,8 @@ main(int argc, char* argv[]) for (int i = 0; i < 100000; ++i) { uint64_t start = PerfUtils::Cycles::rdtsc(); PerfUtils::TimeTrace::record(start, "START"); - Homa::Driver::Packet ping = driver.allocPacket(); + Homa::Driver::Packet ping; + driver.allocPacket(&ping); PerfUtils::TimeTrace::record("allocPacket"); ping.length = 100; PerfUtils::TimeTrace::record("set ping args"); From 204d8cc20946dc4f1ec372b79dbca694eff59bd9 Mon Sep 17 00:00:00 2001 From: Yilong Li Date: Thu, 22 Oct 2020 00:23:47 -0700 Subject: [PATCH 15/15] fixed compilation errors & passed all unit tests --- include/Homa/Bindings/CHoma.h | 2 +- src/Bindings/CHoma.cc | 6 ++--- src/Drivers/Fake/FakeDriver.cc | 2 +- src/Mock/MockDriver.h | 2 +- src/Mock/MockSender.h | 2 +- src/ReceiverTest.cc | 30 ++++++++++++--------- src/SenderTest.cc | 49 ++++++++++++++++++---------------- src/Transports/Shenango.cc | 6 ++--- 8 files changed, 54 insertions(+), 45 deletions(-) diff --git a/include/Homa/Bindings/CHoma.h b/include/Homa/Bindings/CHoma.h index 7a483b4..bd307cb 100644 --- a/include/Homa/Bindings/CHoma.h +++ b/include/Homa/Bindings/CHoma.h @@ -186,7 +186,7 @@ extern void homa_trans_proc(homa_trans trans, uintptr_t desc, void* payload, /** * homa_trans_try_send - C-binding for Core::Transport::trySend */ -extern bool homa_trans_try_send(homa_trans trans, uint64_t* wait_until); +extern uint64_t homa_trans_try_send(homa_trans trans); /** * homa_trans_try_grant - C-binding for Core::Transport::trySendGrants diff --git a/src/Bindings/CHoma.cc b/src/Bindings/CHoma.cc index 7fa88a3..b75971e 100644 --- a/src/Bindings/CHoma.cc +++ b/src/Bindings/CHoma.cc @@ -165,10 +165,10 @@ homa_trans_proc(homa_trans trans, uintptr_t desc, void* payload, int32_t len, deref(Transport, trans).processPacket(&packet, IpAddress{src_ip}); } -bool -homa_trans_try_send(homa_trans trans, uint64_t* wait_until) +uint64_t +homa_trans_try_send(homa_trans trans) { - return deref(Transport, trans).trySend(wait_until); + return deref(Transport, trans).trySend(); } bool diff --git a/src/Drivers/Fake/FakeDriver.cc b/src/Drivers/Fake/FakeDriver.cc index 80a8339..10c5e0c 100644 --- a/src/Drivers/Fake/FakeDriver.cc +++ b/src/Drivers/Fake/FakeDriver.cc @@ -185,7 +185,7 @@ void FakeDriver::allocPacket(Packet* packet) { FakePacket* fakePacket = new FakePacket(); - return fakePacket->toPacket(); + *packet = fakePacket->toPacket(); } /** diff --git a/src/Mock/MockDriver.h b/src/Mock/MockDriver.h index dfb6ec2..5a29bc9 100644 --- a/src/Mock/MockDriver.h +++ b/src/Mock/MockDriver.h @@ -50,7 +50,7 @@ class MockDriver : public Driver { } }; - MOCK_METHOD(Packet, allocPacket, (), (override)); + MOCK_METHOD(void, allocPacket, (Packet * packet), (override)); MOCK_METHOD(void, sendPacket, (Packet * packet, IpAddress destination, int priority), (override)); diff --git a/src/Mock/MockSender.h b/src/Mock/MockSender.h index 91fd17f..faa1291 100644 --- a/src/Mock/MockSender.h +++ b/src/Mock/MockSender.h @@ -46,7 +46,7 @@ class MockSender : public Core::Sender { (override)); MOCK_METHOD(void, handleErrorPacket, (Driver::Packet * packet), (override)); MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); - MOCK_METHOD(bool, trySend, (uint64_t*), (override)); + MOCK_METHOD(uint64_t, trySend, (), (override)); }; } // namespace Mock diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index 0450437..0db4d1d 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -71,11 +71,11 @@ class MockCallbacks : public Transport::Callbacks { if (port != 60001) { return false; } - receivedMessage = std::move(message); + receivedMessage = message.release(); return true; } - Homa::unique_ptr receivedMessage; + InMessage* receivedMessage; }; class ReceiverTest : public ::testing::Test { @@ -236,7 +236,7 @@ TEST_F(ReceiverTest, handleDataPacket) EXPECT_EQ(4U, message->numPackets); EXPECT_EQ(0U, info->bytesRemaining); EXPECT_EQ(Receiver::Message::State::COMPLETED, message->state); - EXPECT_EQ(message, mockCallbacks.receivedMessage.get()); + EXPECT_EQ(message, mockCallbacks.receivedMessage); Mock::VerifyAndClearExpectations(&mockDriver); // ------------------------------------------------------------------------- @@ -305,7 +305,8 @@ TEST_F(ReceiverTest, handlePingPacket_basic) (Protocol::Packet::PingHeader*)pingPacketBuf.buffer; pingHeader->common.messageId = id; - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); EXPECT_CALL(mockDriver, sendPacket(EqPacket(&mockPacket), Eq(mockAddress), _)) .Times(1); @@ -339,7 +340,8 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); EXPECT_CALL(mockDriver, sendPacket(EqPacket(&mockPacket), Eq(mockAddress), _)) .Times(1); @@ -427,7 +429,8 @@ TEST_F(ReceiverTest, Message_acknowledge) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); EXPECT_CALL(mockDriver, sendPacket(EqPacketLen(sizeof(Protocol::Packet::DoneHeader)), Eq(message->source.ip), _)) @@ -464,7 +467,8 @@ TEST_F(ReceiverTest, Message_fail) Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); EXPECT_CALL(mockDriver, sendPacket(EqPacketLen(sizeof(Protocol::Packet::ErrorHeader)), Eq(message->source.ip), _)) @@ -829,9 +833,9 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) Driver::Packet mockResendPacket2 = packetBuf1.toPacket(); const size_t RESEND_HEADER_LEN = sizeof(Protocol::Packet::ResendHeader); - EXPECT_CALL(mockDriver, allocPacket()) - .WillOnce(Return(mockResendPacket1)) - .WillOnce(Return(mockResendPacket2)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&mockResendPacket1](auto p) { *p = mockResendPacket1; }) + .WillOnce([&mockResendPacket2](auto p) { *p = mockResendPacket2; }); EXPECT_CALL(mockDriver, sendPacket(EqPacketLen(RESEND_HEADER_LEN), Eq(message[0]->source.ip), _)) .Times(2); @@ -914,7 +918,8 @@ TEST_F(ReceiverTest, trySendGrants) info[0]->bytesRemaining -= 1000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); EXPECT_CALL(mockDriver, sendPacket(EqPacket(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); @@ -941,7 +946,8 @@ TEST_F(ReceiverTest, trySendGrants) info[1]->bytesRemaining -= 1000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); EXPECT_CALL(mockDriver, sendPacket(EqPacket(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); diff --git a/src/SenderTest.cc b/src/SenderTest.cc index fee2342..2c56a16 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -44,7 +44,7 @@ class MockCallbacks : public Transport::Callbacks { public: explicit MockCallbacks() = default; - bool deliver(uint16_t port, Homa::InMessage* message) override + bool deliver(uint16_t port, Homa::unique_ptr message) override { return true; } @@ -512,7 +512,8 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) char busy[1028]; Homa::Mock::MockDriver::PacketBuf busyPacketBuf{busy}; Driver::Packet busyPacket = busyPacketBuf.toPacket(); - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(busyPacket)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&busyPacket](Driver::Packet* p) { *p = busyPacket; }); EXPECT_CALL(mockDriver, sendPacket(EqPacket(&busyPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(EqPacket(&busyPacket), Eq(1))) .Times(1); @@ -1108,8 +1109,9 @@ TEST_F(SenderTest, Message_append_basic) setMessagePacket(&msg, 0, packetBuf0.toPacket(MAX_RAW_PACKET_LENGTH - 7)); msg.messageLength = PACKET_DATA_LENGTH - 7; - EXPECT_CALL(mockDriver, allocPacket) - .WillOnce(Return(packetBuf1.toPacket())); + Driver::Packet packet1 = packetBuf1.toPacket(); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&packet1](Driver::Packet* packet) { *packet = packet1; }); msg.append(source, 14); @@ -1196,9 +1198,9 @@ TEST_F(SenderTest, Message_prepend) const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; - EXPECT_CALL(mockDriver, allocPacket) - .WillOnce(Return(packet0)) - .WillOnce(Return(packet1)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&packet0](Driver::Packet* packet) { *packet = packet0; }) + .WillOnce([&packet1](Driver::Packet* packet) { *packet = packet1; }); msg.reserve(PACKET_DATA_LENGTH + 7); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.start); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.messageLength); @@ -1228,6 +1230,8 @@ TEST_F(SenderTest, Message_reserve) char buf[4096]; Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1236,8 +1240,8 @@ TEST_F(SenderTest, Message_reserve) EXPECT_EQ(0U, msg.messageLength); EXPECT_EQ(0U, msg.numPackets); - EXPECT_CALL(mockDriver, allocPacket) - .WillOnce(Return(packetBuf0.toPacket())); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&packet0](Driver::Packet* packet) { *packet = packet0; }); msg.reserve(PACKET_DATA_LENGTH - 7); @@ -1247,8 +1251,8 @@ TEST_F(SenderTest, Message_reserve) EXPECT_EQ(TRANSPORT_HEADER_LENGTH + PACKET_DATA_LENGTH - 7, msg.packets[0].length); - EXPECT_CALL(mockDriver, allocPacket) - .WillOnce(Return(packetBuf1.toPacket())); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&packet1](Driver::Packet* packet) { *packet = packet1; }); msg.reserve(14); @@ -1290,7 +1294,8 @@ TEST_F(SenderTest, Message_getOrAllocPacket) EXPECT_FALSE(msg.occupied.test(0)); EXPECT_EQ(0U, msg.numPackets); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(packet0)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&packet0](Driver::Packet* packet) { *packet = packet0; }); EXPECT_EQ(packet0.descriptor, msg.getOrAllocPacket(0)->descriptor); @@ -1626,7 +1631,8 @@ TEST_F(SenderTest, checkPingTimeouts_basic) EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(mockPacket)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); EXPECT_CALL(mockDriver, sendPacket(EqPacket(&mockPacket), _, _)).Times(1); EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); @@ -1693,7 +1699,7 @@ TEST_F(SenderTest, trySend_basic) // 3 granted packets; 2 will send; queue limit reached. EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[0]), _, _)); EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[1]), _, _)); - sender->trySend(&waitUntil); // < test call + waitUntil = sender->trySend(); // < test call EXPECT_TRUE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); EXPECT_EQ(3U, info->packetsGranted); @@ -1705,7 +1711,7 @@ TEST_F(SenderTest, trySend_basic) // 1 packet to be sent; grant limit reached. EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[2]), _, _)); - sender->trySend(&waitUntil); // < test call + waitUntil = sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); EXPECT_EQ(3U, info->packetsGranted); @@ -1718,7 +1724,7 @@ TEST_F(SenderTest, trySend_basic) // No additional grants; spurious ready hint. EXPECT_CALL(mockDriver, sendPacket).Times(0); sender->sendReady = true; - sender->trySend(&waitUntil); // < test call + waitUntil = sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); EXPECT_EQ(3U, info->packetsGranted); @@ -1733,7 +1739,7 @@ TEST_F(SenderTest, trySend_basic) sender->sendReady = true; EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[3]), _, _)); EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[4]), _, _)); - sender->trySend(&waitUntil); // < test call + waitUntil = sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_EQ(5U, info->packetsGranted); @@ -1787,10 +1793,9 @@ TEST_F(SenderTest, trySend_multipleMessages) EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[1]), _, _)); EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[2]), _, _)); - uint64_t waitUntil; - bool sendReady = sender->trySend(&waitUntil); + uint64_t waitUntil = sender->trySend(); - EXPECT_FALSE(sendReady); + EXPECT_EQ(waitUntil, 0); EXPECT_EQ(1U, info[0]->packetsSent); EXPECT_EQ(Homa::OutMessage::Status::SENT, message[0]->state); EXPECT_FALSE(sender->sendQueue.contains(&info[0]->sendQueueNode)); @@ -1806,9 +1811,7 @@ TEST_F(SenderTest, trySend_nothingToSend) { EXPECT_TRUE(sender->sendQueue.empty()); EXPECT_CALL(mockDriver, sendPacket).Times(0); - uint64_t waitUntil = 0; - bool sendReady = sender->trySend(&waitUntil); - EXPECT_FALSE(sendReady); + uint64_t waitUntil = sender->trySend(); EXPECT_EQ(waitUntil, 0); } diff --git a/src/Transports/Shenango.cc b/src/Transports/Shenango.cc index a5a8a2d..b7d05bd 100644 --- a/src/Transports/Shenango.cc +++ b/src/Transports/Shenango.cc @@ -129,9 +129,9 @@ class ShenangoDriver final : public Driver { void allocPacket(Packet* packet) override { - void* payload; - void* mbuf = shenango_homa_tx_alloc_mbuf(&payload); - return Packet{(uintptr_t)mbuf, payload, 0}; + void* mbuf = shenango_homa_tx_alloc_mbuf(&packet->payload); + packet->descriptor = reinterpret_cast(mbuf); + packet->length = 0; } void sendPacket(Packet* packet, IpAddress destination,