From 1840fd247685c6ed2a857e1099b998383125bbff Mon Sep 17 00:00:00 2001 From: Dominik Charousset Date: Sun, 24 Sep 2023 11:44:57 +0200 Subject: [PATCH] Avoid implicit conversions in Zeek utility classes --- doc/_examples/ping.cc | 11 +- doc/_examples/pong.cc | 11 +- include/broker/data.hh | 167 ++++++++++++++ include/broker/endpoint.hh | 17 ++ include/broker/message.hh | 6 + include/broker/zeek.hh | 342 +++++++++++++++------------- src/data.cc | 20 ++ src/endpoint.cc | 10 + tests/benchmark/broker-benchmark.cc | 12 +- 9 files changed, 430 insertions(+), 166 deletions(-) diff --git a/doc/_examples/ping.cc b/doc/_examples/ping.cc index 56fbeb58..3565e8c3 100644 --- a/doc/_examples/ping.cc +++ b/doc/_examples/ping.cc @@ -25,12 +25,15 @@ int main() { // Do five ping / pong. for (int n = 0; n < 5; n++) { // Send event "ping(n)". - zeek::Event ping("ping", {n}); - ep.publish("/topic/test", ping); + ep.publish("/topic/test", zeek::Event{"ping", {n}}); // Wait for "pong" reply event. auto msg = sub.get(); - zeek::Event pong(move_data(msg)); - std::cout << "received " << pong.name() << pong.args() << std::endl; + auto pong = zeek::Event{std::move(msg)}; + if (pong.valid()) + std::cout << "received " << pong.name() << pong.args() << std::endl; + else + std::cout << "received invalid pong message: " << to_string(pong) + << std::endl; } } diff --git a/doc/_examples/pong.cc b/doc/_examples/pong.cc index 6e6c528d..f003f21e 100644 --- a/doc/_examples/pong.cc +++ b/doc/_examples/pong.cc @@ -26,11 +26,14 @@ int main() { for (int n = 0; n < 5; n++) { // Wait for a "ping" event. auto msg = sub.get(); - zeek::Event ping(move_data(msg)); - std::cout << "received " << ping.name() << ping.args() << std::endl; + auto ping = zeek::Event{std::move(msg)}; + if (ping.valid()) + std::cout << "received " << ping.name() << ping.args() << std::endl; + else + std::cout << "received invalid ping message: " << to_string(ping) + << std::endl; // Send event "pong" response. - zeek::Event pong("pong", {n}); - ep.publish("/topic/test", pong); + ep.publish("/topic/test", zeek::Event{"pong", {n}}); } } diff --git a/include/broker/data.hh b/include/broker/data.hh index 2653eaba..ed520aa5 100644 --- a/include/broker/data.hh +++ b/include/broker/data.hh @@ -138,6 +138,8 @@ public: return *this; } + // -- properties ------------------------------------------------------------- + /// Returns a string representation of the stored type. const char* get_type_name() const; @@ -156,6 +158,171 @@ public: return data_; } + /// Checks whether this view contains the `nil` value. + bool is_none() const noexcept { + return get_type() == type::none; + } + + /// Checks whether this view contains a boolean. + bool is_boolean() const noexcept { + return get_type() == type::boolean; + } + + /// Checks whether this view contains a count. + bool is_count() const noexcept { + return get_type() == type::count; + } + + /// Checks whether this view contains a integer. + bool is_integer() const noexcept { + return get_type() == type::integer; + } + + /// Checks whether this view contains a real. + bool is_real() const noexcept { + return get_type() == type::real; + } + + /// Checks whether this view contains a count. + bool is_string() const noexcept { + return get_type() == type::string; + } + + /// Checks whether this view contains a count. + bool is_address() const noexcept { + return get_type() == type::address; + } + + /// Checks whether this view contains a count. + bool is_subnet() const noexcept { + return get_type() == type::subnet; + } + + /// Checks whether this view contains a count. + bool is_port() const noexcept { + return get_type() == type::port; + } + + /// Checks whether this view contains a count. + bool is_timestamp() const noexcept { + return get_type() == type::timestamp; + } + + /// Checks whether this view contains a count. + bool is_timespan() const noexcept { + return get_type() == type::timespan; + } + + /// Checks whether this view contains a count. + bool is_enum_value() const noexcept { + return get_type() == type::enum_value; + } + + /// Checks whether this view contains a set. + bool is_set() const noexcept { + return get_type() == type::set; + } + + /// Checks whether this view contains a table. + bool is_table() const noexcept { + return get_type() == type::table; + } + + /// Checks whether this view contains a list. + bool is_list() const noexcept { + return get_type() == type::vector; + } + + // -- conversions ------------------------------------------------------------ + + /// Retrieves the @c boolean value or returns @p fallback if this object does + /// not contain a @c boolean. + bool to_boolean(bool fallback = false) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c count value or returns @p fallback if this object does + /// not contain a @c count. + count to_count(count fallback = 0) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c integer value or returns @p fallback if this object does + /// not contain a @c integer. + integer to_integer(integer fallback = 0) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c real value or returns @p fallback if this object does + /// not contain a @c real. + real to_real(real fallback = 0) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the string value or returns an empty string if this object does + /// not contain a string. + std::string_view to_string() const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return std::string_view{}; + } + + /// Retrieves the @c address value or returns @p fallback if this object does + /// not contain a @c address. + address to_address(const address& fallback = {}) const noexcept { + if (auto* val = std::get_if
(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c subnet value or returns @p fallback if this object does + /// not contain a @c subnet. + subnet to_subnet(const subnet& fallback = {}) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c port value or returns @p fallback if this object does + /// not contain a @c port. + port to_port(port fallback = {}) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c timestamp value or returns @p fallback if this object + /// does not contain a @c timestamp. + timestamp to_timestamp(timestamp fallback = {}) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c timespan value or returns @p fallback if this object does + /// not contain a @c timespan. + timespan to_timespan(timespan fallback = {}) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the enum_value value or returns @p fallback if this object does + /// not contain a enum_value. + const enum_value& to_enum_value() const noexcept; + + /// Converts the stored data as a list (`vector`). If the stored data is + /// not a list, the result is an empty list. + [[nodiscard]] const vector& to_list() const; + private: data_variant data_; }; diff --git a/include/broker/endpoint.hh b/include/broker/endpoint.hh index f6e3ad9f..6b09360c 100644 --- a/include/broker/endpoint.hh +++ b/include/broker/endpoint.hh @@ -39,6 +39,12 @@ struct endpoint_context; } // namespace broker::internal +namespace broker::zeek { + +class Message; + +} // namespace broker::zeek + namespace broker { /// The main publish/subscribe abstraction. Endpoints can *peer* with each @@ -258,6 +264,17 @@ public: /// @param d The message data. void publish(const endpoint_info& dst, topic t, data d); + /// Publishes a message. + /// @param t The topic of the message. + /// @param d The message data. + void publish(std::string_view t, zeek::Message&& d); + + /// Publishes a message to a specific peer endpoint only. + /// @param dst The destination endpoint. + /// @param t The topic of the message. + /// @param d The message data. + void publish(const endpoint_info& dst, std::string_view t, zeek::Message&& d); + /// Publishes a message as vector. /// @param t The topic of the messages. /// @param xs The contents of the messages. diff --git a/include/broker/message.hh b/include/broker/message.hh index 092c96af..f5126573 100644 --- a/include/broker/message.hh +++ b/include/broker/message.hh @@ -212,6 +212,12 @@ inline const topic& get_topic(const data_message& x) { return get<0>(x); } +/// Retrieves the topic from a ::command_message as a string. +/// @relates data_message +inline std::string get_topic_str(const data_message& x) { + return get_topic(x).string(); +} + /// Retrieves the topic from a ::command_message. /// @relates data_message inline const topic& get_topic(const command_message& x) { diff --git a/include/broker/zeek.hh b/include/broker/zeek.hh index 516f7ba0..96df23ec 100644 --- a/include/broker/zeek.hh +++ b/include/broker/zeek.hh @@ -6,6 +6,7 @@ #include "broker/data.hh" #include "broker/detail/assert.hh" +#include "broker/message.hh" namespace broker::zeek { @@ -34,18 +35,7 @@ public: }; Type type() const { - if (as_vector().size() < 2) - return Type::Invalid; - - auto cp = get_if(&as_vector()[1]); - - if (!cp) - return Type::Invalid; - - if (*cp > Type::MAX) - return Type::Invalid; - - return Type(*cp); + return type(data_); } data&& move_data() { @@ -68,39 +58,43 @@ public: return get(data_); } - operator data() const { - return as_data(); - } - static Type type(const data& msg) { - auto vp = get_if(&msg); + constexpr auto max_tag = static_cast(Type::MAX); + if (auto&& elements = msg.to_list(); elements.size() >= 2) { + auto tag = elements[1].to_count(max_tag); + if (tag < max_tag) + return static_cast(tag); + } + return Type::Invalid; + } - if (!vp) - return Type::Invalid; + static Type type(const data_message& msg) { + return type(get_data(msg)); + } - auto& v = *vp; +protected: + explicit Message(Type type, vector content) + : data_(vector{ProtocolVersion, count(type), std::move(content)}) {} - if (v.size() < 2) - return Type::Invalid; + explicit Message(data msg) : data_(std::move(msg)) {} - auto cp = get_if(&v[1]); + Message() = default; - if (!cp) - return Type::Invalid; + Message(Message&&) = default; - if (*cp > Type::MAX) - return Type::Invalid; + data data_; +}; - return Type(*cp); - } +/// Represents an invalid message. +class Invalid : public Message { +public: + Invalid() = default; -protected: - Message(Type type, vector content) - : data_(vector{ProtocolVersion, count(type), std::move(content)}) {} + explicit Invalid(data msg) : Message(std::move(msg)) {} - Message(data msg) : data_(std::move(msg)) {} + explicit Invalid(data_message msg) : Invalid(broker::move_data(msg)) {} - data data_; + explicit Invalid(Message&& msg) : Message(std::move(msg)) {} }; /// Support iteration with structured binding. @@ -196,7 +190,9 @@ public: : Message(Message::Type::Event, {std::move(name), std::move(args), std::move(metadata)}) {} - Event(data msg) : Message(std::move(msg)) {} + explicit Event(data msg) : Message(std::move(msg)) {} + + explicit Event(data_message msg) : Event(broker::move_data(msg)) {} const std::string& name() const { return get(get(as_vector()[2])[0]); @@ -230,27 +226,12 @@ public: } bool valid() const { - if (as_vector().size() < 3) - return false; - - auto vp = get_if(&(as_vector()[2])); - - if (!vp) - return false; - - auto& v = *vp; - - if (v.size() < 2) - return false; - - auto name_ptr = get_if(&v[0]); - - if (!name_ptr) + auto&& outer = data_.to_list(); + if (outer.size() < 3) return false; - auto args_ptr = get_if(&v[1]); - - if (!args_ptr) + auto&& items = outer[2].to_list(); + if (items.size() < 2 || !items[0].is_string() || !items[1].is_list()) return false; // Optional event metadata verification. @@ -258,32 +239,23 @@ public: // Verify the third element if it exists is a vector> // and type and further check that the NetworkTimestamp metadata has the // right type because we know down here what to expect. - if (v.size() > 2) { - auto md_ptr = get_if(&v[2]); - if (!md_ptr) + if (items.size() > 2) { + auto&& meta_field = items[2]; + if (!meta_field.is_list()) return false; - for (const auto& mde : *md_ptr) { - auto mdev_ptr = get_if(mde); - if (!mdev_ptr) - return false; + for (const auto& field : meta_field.to_list()) { + auto&& kvp = field.to_list(); - if (mdev_ptr->size() != 2) - return false; - - const auto& mdev = *mdev_ptr; - - auto mde_key_ptr = get_if(mdev[0]); - if (!mde_key_ptr) + // Must be two elements: key and value. + if (kvp.size() != 2 || !kvp[0].is_count()) return false; + // If we have a NetworkTimestamp key, the value must be a timestamp. constexpr auto net_ts_key = static_cast(MetadataType::NetworkTimestamp); - if (*mde_key_ptr == net_ts_key) { - auto mde_val_ptr = get_if(mdev[1]); - if (!mde_val_ptr) - return false; - } + if (kvp[0].to_count() == net_ts_key && !kvp[1].is_timestamp()) + return false; } } @@ -291,34 +263,6 @@ public: } }; -/// A batch of other messages. -class Batch : public Message { -public: - Batch(vector msgs) : Message(Message::Type::Batch, std::move(msgs)) {} - - Batch(data msg) : Message(std::move(msg)) {} - - const vector& batch() const { - return get(as_vector()[2]); - } - - vector& batch() { - return get(as_vector()[2]); - } - - bool valid() const { - if (as_vector().size() < 3) - return false; - - auto vp = get_if(&(as_vector()[2])); - - if (!vp) - return false; - - return true; - } -}; - /// A Zeek log-create message. Note that at the moment this should be used /// only by Zeek itself as the arguments aren't pulbically defined. class LogCreate : public Message { @@ -329,7 +273,9 @@ public: {std::move(stream_id), std::move(writer_id), std::move(writer_info), std::move(fields_data)}) {} - LogCreate(data msg) : Message(std::move(msg)) {} + explicit LogCreate(data msg) : Message(std::move(msg)) {} + + explicit LogCreate(data_message msg) : LogCreate(broker::move_data(msg)) {} const enum_value& stream_id() const { return get(get(as_vector()[2])[0]); @@ -364,26 +310,15 @@ public: } bool valid() const { - if (as_vector().size() < 3) + auto&& outer = data_.to_list(); + if (outer.size() < 3) return false; - - auto vp = get_if(&(as_vector()[2])); - - if (!vp) - return false; - - auto& v = *vp; - - if (v.size() < 4) - return false; - - if (!get_if(&v[0])) - return false; - - if (!get_if(&v[1])) - return false; - - return true; + auto&& inner = outer[2].to_list(); + return inner.size() >= 4 // + && inner[0].is_enum_value() // + && inner[1].is_enum_value() // + && inner[2].is_list() // + && inner[3].is_list(); } }; @@ -397,7 +332,9 @@ public: {std::move(stream_id), std::move(writer_id), std::move(path), std::move(serial_data)}) {} - LogWrite(data msg) : Message(std::move(msg)) {} + explicit LogWrite(data msg) : Message(std::move(msg)) {} + + explicit LogWrite(data_message msg) : LogWrite(broker::move_data(msg)) {} const enum_value& stream_id() const { return get(get(as_vector()[2])[0]); @@ -431,27 +368,20 @@ public: return get(as_vector()[2])[3]; } - bool valid() const { - if (as_vector().size() < 3) - return false; - - auto vp = get_if(&(as_vector()[2])); - - if (!vp) - return false; - - auto& v = *vp; - - if (v.size() < 4) - return false; - - if (!get_if(&v[0])) - return false; + std::string_view serial_data_str() const { + return get(serial_data()); + } - if (!get_if(&v[1])) + bool valid() const { + auto&& outer = data_.to_list(); + if (outer.size() < 3) return false; - - return true; + auto&& inner = outer[2].to_list(); + return inner.size() >= 4 // + && inner[0].is_enum_value() // + && inner[1].is_enum_value() // + && inner[2].is_list() // + && inner[3].is_list(); } }; @@ -461,7 +391,10 @@ public: : Message(Message::Type::IdentifierUpdate, {std::move(id_name), std::move(id_value)}) {} - IdentifierUpdate(data msg) : Message(std::move(msg)) {} + explicit IdentifierUpdate(data msg) : Message(std::move(msg)) {} + + explicit IdentifierUpdate(data_message msg) + : IdentifierUpdate(broker::move_data(msg)) {} const std::string& id_name() const { return get(get(as_vector()[2])[0]); @@ -480,24 +413,129 @@ public: } bool valid() const { - if (as_vector().size() < 3) + auto&& outer = data_.to_list(); + if (outer.size() < 3) return false; + auto&& inner = outer[2].to_list(); + return inner.size() >= 2 && inner[0].is_string(); + } +}; - auto vp = get_if(&(as_vector()[2])); +class BatchBuilder; - if (!vp) - return false; +/// A batch of other messages. +class Batch : public Message { +public: + explicit Batch(vector msgs) + : Message(Message::Type::Batch, std::move(msgs)) {} - auto& v = *vp; + explicit Batch(data msg) : Message(std::move(msg)) {} - if (v.size() < 2) - return false; + explicit Batch(data_message msg) : Batch(broker::move_data(msg)) {} - if (!get_if(&v[0])) - return false; + size_t size() const noexcept { + return impl_ ? impl_->size() : 0; + } - return true; + bool empty() const noexcept { + return size() == 0; + } + + bool valid() const { + return impl_ != nullptr; + } + + template + auto for_each(F&& f) { + if (!impl_) + return; + for (auto& x : *impl_) + std::visit(f, x); + } + + template + auto for_each(F&& f) const { + if (!impl_) + return; + for (const auto& x : *impl_) + std::visit(f, x); } + +private: + using VarMsg = + std::variant; + + using Content = std::vector; + + std::shared_ptr impl_; }; +class BatchBuilder { +public: + void add(Message&& msg) { + inner_.emplace_back(msg.move_data()); + } + + bool empty() const noexcept { + return inner_.empty(); + } + + Batch build() { + vector tmp; + tmp.swap(inner_); + inner_.reserve(tmp.size()); + auto result = Batch{tmp}; + return result; + } + +private: + vector inner_; +}; + +template +auto visit_as_message(F&& f, broker::data_message msg) { + auto do_visit = [&f](auto& tmp) { + if (tmp.valid()) + return f(tmp); + Invalid fallback{std::move(tmp)}; + return f(fallback); + }; + switch (Message::type(msg)) { + default: { + Invalid tmp{msg}; + return f(tmp); + } + case Message::Type::Event: { + Event tmp{std::move(msg)}; + return do_visit(tmp); + } + case Message::Type::LogCreate: { + LogCreate tmp{std::move(msg)}; + return do_visit(tmp); + } + case Message::Type::LogWrite: { + LogWrite tmp{std::move(msg)}; + return do_visit(tmp); + } + case Message::Type::IdentifierUpdate: { + IdentifierUpdate tmp{std::move(msg)}; + return do_visit(tmp); + } + case Message::Type::Batch: { + Batch tmp{std::move(msg)}; + return do_visit(tmp); + } + } +} + } // namespace broker::zeek + +namespace broker { + +inline std::string to_string(const zeek::Message& msg) { + return to_string(msg.as_data()); +} + +} // namespace broker diff --git a/src/data.cc b/src/data.cc index 3d5e67d9..827579f8 100644 --- a/src/data.cc +++ b/src/data.cc @@ -93,6 +93,26 @@ const char* data::get_type_name() const { namespace { +vector empty_vector; + +enum_value empty_enum_value; + +} // namespace + +const enum_value& data::to_enum_value() const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return empty_enum_value; +} + +const vector& data::to_list() const { + if (auto ptr = std::get_if(&data_)) + return *ptr; + return empty_vector; +} + +namespace { + template void container_convert(Container& c, std::string& str, char left, char right) { constexpr auto* delim = ", "; diff --git a/src/endpoint.cc b/src/endpoint.cc index 2db0e387..9d203d8b 100644 --- a/src/endpoint.cc +++ b/src/endpoint.cc @@ -19,6 +19,7 @@ #include "broker/status_subscriber.hh" #include "broker/subscriber.hh" #include "broker/timeout.hh" +#include "broker/zeek.hh" #include #include @@ -821,6 +822,15 @@ void endpoint::publish(const endpoint_info& dst, topic t, data d) { make_data_message(std::move(t), std::move(d)), dst); } +void endpoint::publish(std::string_view t, zeek::Message&& d) { + publish(topic{std::string{t}}, d.move_data()); +} + +void endpoint::publish(const endpoint_info& dst, std::string_view t, + zeek::Message&& d) { + publish(dst, topic{std::string{t}}, d.move_data()); +} + void endpoint::publish(data_message x) { BROKER_INFO("publishing" << x); caf::anon_send(native(core_), atom::publish_v, std::move(x)); diff --git a/tests/benchmark/broker-benchmark.cc b/tests/benchmark/broker-benchmark.cc index 5360e55b..1910b0d9 100644 --- a/tests/benchmark/broker-benchmark.cc +++ b/tests/benchmark/broker-benchmark.cc @@ -149,8 +149,8 @@ void send_batch(endpoint& ep, publisher& p) { auto name = "event_" + std::to_string(event_type); vector batch; for (int i = 0; i < batch_size; i++) { - auto ev = zeek::Event(std::string(name), createEventArgs()); - batch.emplace_back(std::move(ev)); + auto ev = zeek::Event{std::string(name), createEventArgs()}; + batch.emplace_back(ev.move_data()); } total_sent += batch.size(); p.publish(std::move(batch)); @@ -215,7 +215,7 @@ void receivedStats(endpoint& ep, const data& x) { if (max_received && total_recv > max_received) { zeek::Event ev("quit_benchmark", std::vector{}); - ep.publish("/benchmark/terminate", ev); + ep.publish("/benchmark/terminate", std::move(ev)); std::this_thread::sleep_for(2s); // Give clients a bit. exit(0); } @@ -276,8 +276,8 @@ void client_loop(endpoint& ep, bool verbose, status_subscriber& ss) { // Pull: generate random events. for (size_t i = 0; i < hint; ++i) { auto name = "event_" + std::to_string(event_type); - out.emplace_back("/benchmark/events", - zeek::Event(std::move(name), createEventArgs())); + auto ev = zeek::Event{std::string(name), createEventArgs()}; + out.emplace_back("/benchmark/events", ev.move_data()); } }, [] { @@ -347,7 +347,7 @@ void server_mode(endpoint& ep, bool verbose, const std::string& iface, // Count number of events (counts each element in a batch as one event). if (zeek::Message::type(msg) == zeek::Message::Type::Batch) { zeek::Batch batch(std::move(msg)); - num_events += batch.batch().size(); + num_events += batch.size(); } else { ++num_events; }