diff --git a/lib/basic_traffic/basic_traffic.cc b/lib/basic_traffic/basic_traffic.cc index fafb3892..e39f65df 100644 --- a/lib/basic_traffic/basic_traffic.cc +++ b/lib/basic_traffic/basic_traffic.cc @@ -299,14 +299,7 @@ absl::StatusOr> SendTraffic( absl::flat_hash_map sent_packets; std::vector> received_packets; { - ASSIGN_OR_RETURN( - auto finalizer, - testbed.ControlDevice().CollectPackets( - [&received_packets](absl::string_view control_interface, - absl::string_view packet_string) { - received_packets.push_back(std::make_tuple( - std::string(control_interface), std::string(packet_string))); - })); + ASSIGN_OR_RETURN(auto finalizer, testbed.ControlDevice().CollectPackets()); LOG(INFO) << "Starting to send traffic."; absl::Time start_time = absl::Now(); @@ -326,7 +319,14 @@ absl::StatusOr> SendTraffic( last_sent_time = absl::Now(); } } - absl::SleepFor(kPassthroughWaitTime); + + RETURN_IF_ERROR(finalizer->HandlePacketsFor( + kPassthroughWaitTime, + [&received_packets](absl::string_view control_interface, + absl::string_view packet_string) { + received_packets.push_back(std::make_tuple( + std::string(control_interface), std::string(packet_string))); + })); } LOG(INFO) << "Traffic sending complete."; diff --git a/lib/p4rt/BUILD.bazel b/lib/p4rt/BUILD.bazel index 63786d85..c699523e 100644 --- a/lib/p4rt/BUILD.bazel +++ b/lib/p4rt/BUILD.bazel @@ -39,7 +39,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", - "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/strings", ], ) diff --git a/lib/p4rt/packet_listener.cc b/lib/p4rt/packet_listener.cc index 781dca9a..c7bcb28e 100644 --- a/lib/p4rt/packet_listener.cc +++ b/lib/p4rt/packet_listener.cc @@ -19,7 +19,9 @@ #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "glog/logging.h" +#include "gutil/status.h" #include "gutil/testing.h" #include "lib/p4rt/p4rt_programming_context.h" #include "p4/v1/p4runtime.pb.h" @@ -29,7 +31,6 @@ #include "sai_p4/instantiations/google/instantiations.h" #include "sai_p4/instantiations/google/sai_p4info.h" #include "sai_p4/instantiations/google/sai_pd.pb.h" -#include "thinkit/control_device.h" namespace pins_test { @@ -37,38 +38,41 @@ PacketListener::PacketListener( pdpi::P4RuntimeSession* session, P4rtProgrammingContext context, sai::Instantiation instantiation, const absl::flat_hash_map* - interface_port_id_to_name, - thinkit::PacketCallback callback, std::function on_finish) + interface_port_id_to_name) : session_(session), context_(std::move(context)), - receive_packet_thread_([this, instantiation, interface_port_id_to_name, - callback = std::move(callback)]() { - p4::v1::StreamMessageResponse pi_response; - while (session_->StreamChannelRead(pi_response)) { - sai::StreamMessageResponse pd_response; - if (!pdpi::PiStreamMessageResponseToPd( - sai::GetIrP4Info(instantiation), pi_response, &pd_response) - .ok()) { - LOG(ERROR) << "Failed to convert PI stream message response to PD."; - return; - } - if (!pd_response.has_packet()) { - LOG(ERROR) << "PD response has no packet."; - return; - } - std::string port_id = pd_response.packet().metadata().ingress_port(); + instantiation_(instantiation), + interface_port_id_to_name_(*interface_port_id_to_name) {} - auto port_name = interface_port_id_to_name->find(port_id); - if (port_name == interface_port_id_to_name->end()) { - LOG(WARNING) << port_id << " not found."; - return; - } - LOG_EVERY_N(INFO, 1000) - << "Packet received (Count: " << google::COUNTER << ")."; +absl::Status PacketListener::HandlePacketsFor( + absl::Duration duration, thinkit::PacketCallback callback) { + ASSIGN_OR_RETURN(std::vector messages, + session_->GetAllStreamMessagesFor(duration)); + for (const auto& pi_response : messages) { + sai::StreamMessageResponse pd_response; + if (!pdpi::PiStreamMessageResponseToPd(sai::GetIrP4Info(instantiation_), + pi_response, &pd_response) + .ok()) { + LOG(ERROR) << "Failed to convert PI stream message response to PD."; + continue; + } + if (!pd_response.has_packet()) { + LOG(ERROR) << "PD response has no packet."; + continue; + } + std::string port_id = pd_response.packet().metadata().ingress_port(); - callback(port_name->second, pd_response.packet().payload()); - } - }), - on_finish_(std::move(on_finish)) {} + auto port_name = interface_port_id_to_name_.find(port_id); + if (port_name == interface_port_id_to_name_.end()) { + LOG(WARNING) << port_id << " not found."; + continue; + } + LOG_EVERY_N(INFO, 1000) + << "Packet received (Count: " << google::COUNTER << ")."; + callback(port_name->second, pd_response.packet().payload()); + } + + return absl::OkStatus(); +} } // namespace pins_test diff --git a/lib/p4rt/packet_listener.h b/lib/p4rt/packet_listener.h index 90d5cc33..be1f9a9b 100644 --- a/lib/p4rt/packet_listener.h +++ b/lib/p4rt/packet_listener.h @@ -16,7 +16,6 @@ #include #include -#include // NOLINT #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" @@ -45,28 +44,24 @@ class PacketListener : public thinkit::PacketGenerationFinalizer { P4rtProgrammingContext context, sai::Instantiation instantiation, const absl::flat_hash_map* - interface_port_id_to_name, - thinkit::PacketCallback callback, - std::function on_finish); + interface_port_id_to_name); + + absl::Status HandlePacketsFor(absl::Duration duration, + thinkit::PacketCallback callback_) override; ~PacketListener() { absl::Status status = context_.Revert(); if (!status.ok()) { LOG(WARNING) << "Failed to revert packet listening flows: " << status; } - status = session_->Finish(); - if (!status.ok()) { - LOG(WARNING) << "P4RuntimeSession finished abnormally: " << status; - } - receive_packet_thread_.join(); - on_finish_(); } private: pdpi::P4RuntimeSession* session_; P4rtProgrammingContext context_; - std::thread receive_packet_thread_; - std::function on_finish_; + sai::Instantiation instantiation_; + const absl::flat_hash_map& + interface_port_id_to_name_; }; } // namespace pins_test diff --git a/lib/pins_control_device.cc b/lib/pins_control_device.cc index 17f2fffd..c7b136e7 100644 --- a/lib/pins_control_device.cc +++ b/lib/pins_control_device.cc @@ -121,7 +121,7 @@ absl::StatusOr PinsControlDevice::Create( } absl::StatusOr> -PinsControlDevice::CollectPackets(thinkit::PacketCallback callback) { +PinsControlDevice::CollectPackets() { if (control_session_ == nullptr) { return absl::InternalError( "No P4RuntimeSession exists; Likely failed to establish another " @@ -148,19 +148,10 @@ PinsControlDevice::CollectPackets(thinkit::PacketCallback callback) { } })pb"))); RETURN_IF_ERROR(context.SendWriteRequest(punt_all_request)); + return absl::make_unique( control_session_.get(), std::move(context), - sai::Instantiation::kMiddleblock, &interface_port_id_to_name_, - std::move(callback), /*on_finish=*/[this]() { - // After the packet listener is finished and destroyed the old session, - // try to replace it with a new session. - auto session = pdpi::P4RuntimeSession::Create(*sut_); - if (!session.ok()) { - LOG(ERROR) << "Failed to establish another P4RuntimeSession:" - << session.status(); - } - control_session_ = std::move(session).value_or(nullptr); - }); + sai::Instantiation::kMiddleblock, &interface_port_id_to_name_); } absl::Status PinsControlDevice::SendPacket( diff --git a/lib/pins_control_device.h b/lib/pins_control_device.h index a77d6e82..d58c1c34 100644 --- a/lib/pins_control_device.h +++ b/lib/pins_control_device.h @@ -56,7 +56,7 @@ class PinsControlDevice : public thinkit::ControlDevice { absl::flat_hash_map interface_name_to_port_id); absl::StatusOr> - CollectPackets(thinkit::PacketCallback callback) override; + CollectPackets() override; absl::Status SendPacket(absl::string_view interface, absl::string_view packet, std::optional packet_delay) override; diff --git a/thinkit/BUILD.bazel b/thinkit/BUILD.bazel index f22a422f..f9fe7025 100644 --- a/thinkit/BUILD.bazel +++ b/thinkit/BUILD.bazel @@ -361,4 +361,8 @@ cc_test( cc_library( name = "packet_generation_finalizer", hdrs = ["packet_generation_finalizer.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + ], ) diff --git a/thinkit/control_device.h b/thinkit/control_device.h index c1ba9f19..7869e2ba 100644 --- a/thinkit/control_device.h +++ b/thinkit/control_device.h @@ -48,12 +48,6 @@ enum class RebootType { kCold, }; -// Callback when a packet is received, first parameter which is control -// interface port it was received on and second parameter is the raw byte string -// of the packet. -using PacketCallback = - std::function; - // A `ControlDevice` represents any device or devices that can at the very // least send and receive packets over their interfaces. It may be able to get // and set link state, as well as perform various other operations like link @@ -62,11 +56,10 @@ class ControlDevice { public: virtual ~ControlDevice() {} - // Starts collecting packets, calling `callback` whenever a packet is - // received. This continues until the `PacketGenerationFinalizer` goes out of - // scope. + // Starts collecting packets. This continues until the + // `PacketGenerationFinalizer` goes out of scope. virtual absl::StatusOr> - CollectPackets(PacketCallback callback) = 0; + CollectPackets() = 0; absl::Status SendPacket(absl::string_view interface, absl::string_view packet) { diff --git a/thinkit/mock_control_device.h b/thinkit/mock_control_device.h index 70d4ea6d..23b5fe59 100644 --- a/thinkit/mock_control_device.h +++ b/thinkit/mock_control_device.h @@ -37,7 +37,7 @@ class MockControlDevice : public ControlDevice { public: MOCK_METHOD( absl::StatusOr>, - CollectPackets, (PacketCallback callback), (override)); + CollectPackets, (), (override)); MOCK_METHOD(absl::Status, SendPacket, (absl::string_view interface, absl::string_view packet, std::optional packet_delay), diff --git a/thinkit/packet_generation_finalizer.h b/thinkit/packet_generation_finalizer.h index 35912f08..055eeb03 100644 --- a/thinkit/packet_generation_finalizer.h +++ b/thinkit/packet_generation_finalizer.h @@ -16,17 +16,25 @@ #define PINS_THINKIT_PACKET_GENERATION_FINALIZER_H_ +#include "absl/status/status.h" +#include "absl/time/time.h" + namespace thinkit { +// Callback when a packet is received, first parameter which is control +// interface port it was received on and second parameter is the raw byte string +// of the packet. +using PacketCallback = + std::function; // PacketGenerationFinalizer will stop listening for packets when it goes out of // scope. class PacketGenerationFinalizer { public: - virtual ~PacketGenerationFinalizer() = 0; + virtual absl::Status HandlePacketsFor(absl::Duration duration, + PacketCallback handler) = 0; + virtual ~PacketGenerationFinalizer() = default; }; -inline PacketGenerationFinalizer::~PacketGenerationFinalizer() {} - } // namespace thinkit #endif // PINS_THINKIT_PACKET_GENERATION_FINALIZER_H_