From f53c7681f846c6c47cbc167196a5250f575eec90 Mon Sep 17 00:00:00 2001 From: Dominik Charousset Date: Sun, 24 Sep 2023 11:44:57 +0200 Subject: [PATCH 1/2] Avoid implicit conversions in Zeek utility classes --- CMakeLists.txt | 1 + doc/_examples/ping.cc | 11 +- doc/_examples/pong.cc | 11 +- include/broker/data.hh | 167 ++++++++++++++ include/broker/endpoint.hh | 17 ++ include/broker/message.hh | 6 + include/broker/zeek.hh | 339 +++++++++++++++------------- src/data.cc | 20 ++ src/endpoint.cc | 10 + src/zeek.cc | 64 ++++++ tests/benchmark/broker-benchmark.cc | 12 +- 11 files changed, 492 insertions(+), 166 deletions(-) create mode 100644 src/zeek.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 652ae2cc..a348a161 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -360,6 +360,7 @@ set(BROKER_SRC src/topic.cc src/version.cc src/worker.cc + src/zeek.cc ) if (ENABLE_SHARED) diff --git a/doc/_examples/ping.cc b/doc/_examples/ping.cc index 56fbeb58..3565e8c3 100644 --- a/doc/_examples/ping.cc +++ b/doc/_examples/ping.cc @@ -25,12 +25,15 @@ int main() { // Do five ping / pong. for (int n = 0; n < 5; n++) { // Send event "ping(n)". - zeek::Event ping("ping", {n}); - ep.publish("/topic/test", ping); + ep.publish("/topic/test", zeek::Event{"ping", {n}}); // Wait for "pong" reply event. auto msg = sub.get(); - zeek::Event pong(move_data(msg)); - std::cout << "received " << pong.name() << pong.args() << std::endl; + auto pong = zeek::Event{std::move(msg)}; + if (pong.valid()) + std::cout << "received " << pong.name() << pong.args() << std::endl; + else + std::cout << "received invalid pong message: " << to_string(pong) + << std::endl; } } diff --git a/doc/_examples/pong.cc b/doc/_examples/pong.cc index 6e6c528d..f003f21e 100644 --- a/doc/_examples/pong.cc +++ b/doc/_examples/pong.cc @@ -26,11 +26,14 @@ int main() { for (int n = 0; n < 5; n++) { // Wait for a "ping" event. auto msg = sub.get(); - zeek::Event ping(move_data(msg)); - std::cout << "received " << ping.name() << ping.args() << std::endl; + auto ping = zeek::Event{std::move(msg)}; + if (ping.valid()) + std::cout << "received " << ping.name() << ping.args() << std::endl; + else + std::cout << "received invalid ping message: " << to_string(ping) + << std::endl; // Send event "pong" response. - zeek::Event pong("pong", {n}); - ep.publish("/topic/test", pong); + ep.publish("/topic/test", zeek::Event{"pong", {n}}); } } diff --git a/include/broker/data.hh b/include/broker/data.hh index 2653eaba..ed520aa5 100644 --- a/include/broker/data.hh +++ b/include/broker/data.hh @@ -138,6 +138,8 @@ public: return *this; } + // -- properties ------------------------------------------------------------- + /// Returns a string representation of the stored type. const char* get_type_name() const; @@ -156,6 +158,171 @@ public: return data_; } + /// Checks whether this view contains the `nil` value. + bool is_none() const noexcept { + return get_type() == type::none; + } + + /// Checks whether this view contains a boolean. + bool is_boolean() const noexcept { + return get_type() == type::boolean; + } + + /// Checks whether this view contains a count. + bool is_count() const noexcept { + return get_type() == type::count; + } + + /// Checks whether this view contains a integer. + bool is_integer() const noexcept { + return get_type() == type::integer; + } + + /// Checks whether this view contains a real. + bool is_real() const noexcept { + return get_type() == type::real; + } + + /// Checks whether this view contains a count. + bool is_string() const noexcept { + return get_type() == type::string; + } + + /// Checks whether this view contains a count. + bool is_address() const noexcept { + return get_type() == type::address; + } + + /// Checks whether this view contains a count. + bool is_subnet() const noexcept { + return get_type() == type::subnet; + } + + /// Checks whether this view contains a count. + bool is_port() const noexcept { + return get_type() == type::port; + } + + /// Checks whether this view contains a count. + bool is_timestamp() const noexcept { + return get_type() == type::timestamp; + } + + /// Checks whether this view contains a count. + bool is_timespan() const noexcept { + return get_type() == type::timespan; + } + + /// Checks whether this view contains a count. + bool is_enum_value() const noexcept { + return get_type() == type::enum_value; + } + + /// Checks whether this view contains a set. + bool is_set() const noexcept { + return get_type() == type::set; + } + + /// Checks whether this view contains a table. + bool is_table() const noexcept { + return get_type() == type::table; + } + + /// Checks whether this view contains a list. + bool is_list() const noexcept { + return get_type() == type::vector; + } + + // -- conversions ------------------------------------------------------------ + + /// Retrieves the @c boolean value or returns @p fallback if this object does + /// not contain a @c boolean. + bool to_boolean(bool fallback = false) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c count value or returns @p fallback if this object does + /// not contain a @c count. + count to_count(count fallback = 0) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c integer value or returns @p fallback if this object does + /// not contain a @c integer. + integer to_integer(integer fallback = 0) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c real value or returns @p fallback if this object does + /// not contain a @c real. + real to_real(real fallback = 0) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the string value or returns an empty string if this object does + /// not contain a string. + std::string_view to_string() const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return std::string_view{}; + } + + /// Retrieves the @c address value or returns @p fallback if this object does + /// not contain a @c address. + address to_address(const address& fallback = {}) const noexcept { + if (auto* val = std::get_if
(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c subnet value or returns @p fallback if this object does + /// not contain a @c subnet. + subnet to_subnet(const subnet& fallback = {}) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c port value or returns @p fallback if this object does + /// not contain a @c port. + port to_port(port fallback = {}) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c timestamp value or returns @p fallback if this object + /// does not contain a @c timestamp. + timestamp to_timestamp(timestamp fallback = {}) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the @c timespan value or returns @p fallback if this object does + /// not contain a @c timespan. + timespan to_timespan(timespan fallback = {}) const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return fallback; + } + + /// Retrieves the enum_value value or returns @p fallback if this object does + /// not contain a enum_value. + const enum_value& to_enum_value() const noexcept; + + /// Converts the stored data as a list (`vector`). If the stored data is + /// not a list, the result is an empty list. + [[nodiscard]] const vector& to_list() const; + private: data_variant data_; }; diff --git a/include/broker/endpoint.hh b/include/broker/endpoint.hh index f6e3ad9f..6b09360c 100644 --- a/include/broker/endpoint.hh +++ b/include/broker/endpoint.hh @@ -39,6 +39,12 @@ struct endpoint_context; } // namespace broker::internal +namespace broker::zeek { + +class Message; + +} // namespace broker::zeek + namespace broker { /// The main publish/subscribe abstraction. Endpoints can *peer* with each @@ -258,6 +264,17 @@ public: /// @param d The message data. void publish(const endpoint_info& dst, topic t, data d); + /// Publishes a message. + /// @param t The topic of the message. + /// @param d The message data. + void publish(std::string_view t, zeek::Message&& d); + + /// Publishes a message to a specific peer endpoint only. + /// @param dst The destination endpoint. + /// @param t The topic of the message. + /// @param d The message data. + void publish(const endpoint_info& dst, std::string_view t, zeek::Message&& d); + /// Publishes a message as vector. /// @param t The topic of the messages. /// @param xs The contents of the messages. diff --git a/include/broker/message.hh b/include/broker/message.hh index 092c96af..f5126573 100644 --- a/include/broker/message.hh +++ b/include/broker/message.hh @@ -212,6 +212,12 @@ inline const topic& get_topic(const data_message& x) { return get<0>(x); } +/// Retrieves the topic from a ::command_message as a string. +/// @relates data_message +inline std::string get_topic_str(const data_message& x) { + return get_topic(x).string(); +} + /// Retrieves the topic from a ::command_message. /// @relates data_message inline const topic& get_topic(const command_message& x) { diff --git a/include/broker/zeek.hh b/include/broker/zeek.hh index 516f7ba0..7590f5c2 100644 --- a/include/broker/zeek.hh +++ b/include/broker/zeek.hh @@ -6,6 +6,7 @@ #include "broker/data.hh" #include "broker/detail/assert.hh" +#include "broker/message.hh" namespace broker::zeek { @@ -33,19 +34,10 @@ public: MAX = Batch, }; - Type type() const { - if (as_vector().size() < 2) - return Type::Invalid; - - auto cp = get_if(&as_vector()[1]); - - if (!cp) - return Type::Invalid; - - if (*cp > Type::MAX) - return Type::Invalid; + virtual ~Message(); - return Type(*cp); + Type type() const { + return type(data_); } data&& move_data() { @@ -68,39 +60,45 @@ public: return get(data_); } - operator data() const { - return as_data(); - } - static Type type(const data& msg) { - auto vp = get_if(&msg); + constexpr auto max_tag = static_cast(Type::MAX); + auto&& elements = msg.to_list(); + if (elements.size() >= 2) { + auto tag = elements[1].to_count(max_tag + 1); + if (tag <= max_tag) { + return static_cast(tag); + } + } + return Type::Invalid; + } - if (!vp) - return Type::Invalid; + static Type type(const data_message& msg) { + return type(get_data(msg)); + } - auto& v = *vp; +protected: + explicit Message(Type type, vector content) + : data_(vector{ProtocolVersion, count(type), std::move(content)}) {} - if (v.size() < 2) - return Type::Invalid; + explicit Message(data msg) : data_(std::move(msg)) {} - auto cp = get_if(&v[1]); + Message() = default; - if (!cp) - return Type::Invalid; + Message(Message&&) = default; - if (*cp > Type::MAX) - return Type::Invalid; + data data_; +}; - return Type(*cp); - } +/// Represents an invalid message. +class Invalid : public Message { +public: + Invalid() = default; -protected: - Message(Type type, vector content) - : data_(vector{ProtocolVersion, count(type), std::move(content)}) {} + explicit Invalid(data msg) : Message(std::move(msg)) {} - Message(data msg) : data_(std::move(msg)) {} + explicit Invalid(data_message msg) : Invalid(broker::move_data(msg)) {} - data data_; + explicit Invalid(Message&& msg) : Message(std::move(msg)) {} }; /// Support iteration with structured binding. @@ -196,7 +194,9 @@ public: : Message(Message::Type::Event, {std::move(name), std::move(args), std::move(metadata)}) {} - Event(data msg) : Message(std::move(msg)) {} + explicit Event(data msg) : Message(std::move(msg)) {} + + explicit Event(data_message msg) : Event(broker::move_data(msg)) {} const std::string& name() const { return get(get(as_vector()[2])[0]); @@ -230,27 +230,12 @@ public: } bool valid() const { - if (as_vector().size() < 3) - return false; - - auto vp = get_if(&(as_vector()[2])); - - if (!vp) + auto&& outer = data_.to_list(); + if (outer.size() < 3) return false; - auto& v = *vp; - - if (v.size() < 2) - return false; - - auto name_ptr = get_if(&v[0]); - - if (!name_ptr) - return false; - - auto args_ptr = get_if(&v[1]); - - if (!args_ptr) + auto&& items = outer[2].to_list(); + if (items.size() < 2 || !items[0].is_string() || !items[1].is_list()) return false; // Optional event metadata verification. @@ -258,32 +243,23 @@ public: // Verify the third element if it exists is a vector> // and type and further check that the NetworkTimestamp metadata has the // right type because we know down here what to expect. - if (v.size() > 2) { - auto md_ptr = get_if(&v[2]); - if (!md_ptr) + if (items.size() > 2) { + auto&& meta_field = items[2]; + if (!meta_field.is_list()) return false; - for (const auto& mde : *md_ptr) { - auto mdev_ptr = get_if(mde); - if (!mdev_ptr) - return false; - - if (mdev_ptr->size() != 2) - return false; - - const auto& mdev = *mdev_ptr; + for (const auto& field : meta_field.to_list()) { + auto&& kvp = field.to_list(); - auto mde_key_ptr = get_if(mdev[0]); - if (!mde_key_ptr) + // Must be two elements: key and value. + if (kvp.size() != 2 || !kvp[0].is_count()) return false; + // If we have a NetworkTimestamp key, the value must be a timestamp. constexpr auto net_ts_key = static_cast(MetadataType::NetworkTimestamp); - if (*mde_key_ptr == net_ts_key) { - auto mde_val_ptr = get_if(mdev[1]); - if (!mde_val_ptr) - return false; - } + if (kvp[0].to_count() == net_ts_key && !kvp[1].is_timestamp()) + return false; } } @@ -291,34 +267,6 @@ public: } }; -/// A batch of other messages. -class Batch : public Message { -public: - Batch(vector msgs) : Message(Message::Type::Batch, std::move(msgs)) {} - - Batch(data msg) : Message(std::move(msg)) {} - - const vector& batch() const { - return get(as_vector()[2]); - } - - vector& batch() { - return get(as_vector()[2]); - } - - bool valid() const { - if (as_vector().size() < 3) - return false; - - auto vp = get_if(&(as_vector()[2])); - - if (!vp) - return false; - - return true; - } -}; - /// A Zeek log-create message. Note that at the moment this should be used /// only by Zeek itself as the arguments aren't pulbically defined. class LogCreate : public Message { @@ -329,7 +277,9 @@ public: {std::move(stream_id), std::move(writer_id), std::move(writer_info), std::move(fields_data)}) {} - LogCreate(data msg) : Message(std::move(msg)) {} + explicit LogCreate(data msg) : Message(std::move(msg)) {} + + explicit LogCreate(data_message msg) : LogCreate(broker::move_data(msg)) {} const enum_value& stream_id() const { return get(get(as_vector()[2])[0]); @@ -364,26 +314,13 @@ public: } bool valid() const { - if (as_vector().size() < 3) - return false; - - auto vp = get_if(&(as_vector()[2])); - - if (!vp) - return false; - - auto& v = *vp; - - if (v.size() < 4) - return false; - - if (!get_if(&v[0])) + auto&& outer = data_.to_list(); + if (outer.size() < 3) return false; - - if (!get_if(&v[1])) - return false; - - return true; + auto&& inner = outer[2].to_list(); + return inner.size() >= 4 // + && inner[0].is_enum_value() // + && inner[1].is_enum_value(); } }; @@ -397,7 +334,9 @@ public: {std::move(stream_id), std::move(writer_id), std::move(path), std::move(serial_data)}) {} - LogWrite(data msg) : Message(std::move(msg)) {} + explicit LogWrite(data msg) : Message(std::move(msg)) {} + + explicit LogWrite(data_message msg) : LogWrite(broker::move_data(msg)) {} const enum_value& stream_id() const { return get(get(as_vector()[2])[0]); @@ -423,6 +362,10 @@ public: return get(as_vector()[2])[2]; }; + std::string_view path_str() { + return get(path()); + }; + const data& serial_data() const { return get(as_vector()[2])[3]; } @@ -431,27 +374,20 @@ public: return get(as_vector()[2])[3]; } - bool valid() const { - if (as_vector().size() < 3) - return false; - - auto vp = get_if(&(as_vector()[2])); - - if (!vp) - return false; - - auto& v = *vp; - - if (v.size() < 4) - return false; - - if (!get_if(&v[0])) - return false; + std::string_view serial_data_str() const { + return get(serial_data()); + } - if (!get_if(&v[1])) + bool valid() const { + auto&& outer = data_.to_list(); + if (outer.size() < 3) return false; - - return true; + auto&& inner = outer[2].to_list(); + return inner.size() >= 4 // + && inner[0].is_enum_value() // + && inner[1].is_enum_value() // + && inner[2].is_string() // + && inner[3].is_string(); } }; @@ -461,7 +397,10 @@ public: : Message(Message::Type::IdentifierUpdate, {std::move(id_name), std::move(id_value)}) {} - IdentifierUpdate(data msg) : Message(std::move(msg)) {} + explicit IdentifierUpdate(data msg) : Message(std::move(msg)) {} + + explicit IdentifierUpdate(data_message msg) + : IdentifierUpdate(broker::move_data(msg)) {} const std::string& id_name() const { return get(get(as_vector()[2])[0]); @@ -480,24 +419,120 @@ public: } bool valid() const { - if (as_vector().size() < 3) + auto&& outer = data_.to_list(); + if (outer.size() < 3) return false; + auto&& inner = outer[2].to_list(); + return inner.size() >= 2 && inner[0].is_string(); + } +}; - auto vp = get_if(&(as_vector()[2])); +class BatchBuilder; - if (!vp) - return false; +/// A batch of other messages. +class Batch : public Message { +public: + explicit Batch(data msg); - auto& v = *vp; + explicit Batch(data_message msg) : Batch(broker::move_data(msg)) {} - if (v.size() < 2) - return false; + size_t size() const noexcept { + return impl_ ? impl_->size() : 0; + } - if (!get_if(&v[0])) - return false; + bool empty() const noexcept { + return size() == 0; + } - return true; + bool valid() const { + return impl_ != nullptr; + } + + template + auto for_each(F&& f) { + if (!impl_) + return; + for (auto& x : *impl_) + std::visit(f, x); } + + template + auto for_each(F&& f) const { + if (!impl_) + return; + for (const auto& x : *impl_) + std::visit(f, x); + } + +private: + using VarMsg = + std::variant; + + using Content = std::vector; + + std::shared_ptr impl_; }; +class BatchBuilder { +public: + void add(Message&& msg) { + inner_.emplace_back(msg.move_data()); + } + + bool empty() const noexcept { + return inner_.empty(); + } + + Batch build(); + +private: + vector inner_; +}; + +template +auto visit_as_message(F&& f, broker::data_message msg) { + auto do_visit = [&f](auto& tmp) { + if (tmp.valid()) + return f(tmp); + Invalid fallback{std::move(tmp)}; + return f(fallback); + }; + switch (Message::type(msg)) { + default: { + Invalid tmp{std::move(msg)}; + return f(tmp); + } + case Message::Type::Event: { + Event tmp{std::move(msg)}; + return do_visit(tmp); + } + case Message::Type::LogCreate: { + LogCreate tmp{std::move(msg)}; + return do_visit(tmp); + } + case Message::Type::LogWrite: { + LogWrite tmp{std::move(msg)}; + return do_visit(tmp); + } + case Message::Type::IdentifierUpdate: { + IdentifierUpdate tmp{std::move(msg)}; + return do_visit(tmp); + } + case Message::Type::Batch: { + Batch tmp{std::move(msg)}; + return do_visit(tmp); + } + } +} + } // namespace broker::zeek + +namespace broker { + +inline std::string to_string(const zeek::Message& msg) { + return to_string(msg.as_data()); +} + +} // namespace broker diff --git a/src/data.cc b/src/data.cc index 3d5e67d9..827579f8 100644 --- a/src/data.cc +++ b/src/data.cc @@ -93,6 +93,26 @@ const char* data::get_type_name() const { namespace { +vector empty_vector; + +enum_value empty_enum_value; + +} // namespace + +const enum_value& data::to_enum_value() const noexcept { + if (auto* val = std::get_if(&data_)) + return *val; + return empty_enum_value; +} + +const vector& data::to_list() const { + if (auto ptr = std::get_if(&data_)) + return *ptr; + return empty_vector; +} + +namespace { + template void container_convert(Container& c, std::string& str, char left, char right) { constexpr auto* delim = ", "; diff --git a/src/endpoint.cc b/src/endpoint.cc index 2db0e387..9d203d8b 100644 --- a/src/endpoint.cc +++ b/src/endpoint.cc @@ -19,6 +19,7 @@ #include "broker/status_subscriber.hh" #include "broker/subscriber.hh" #include "broker/timeout.hh" +#include "broker/zeek.hh" #include #include @@ -821,6 +822,15 @@ void endpoint::publish(const endpoint_info& dst, topic t, data d) { make_data_message(std::move(t), std::move(d)), dst); } +void endpoint::publish(std::string_view t, zeek::Message&& d) { + publish(topic{std::string{t}}, d.move_data()); +} + +void endpoint::publish(const endpoint_info& dst, std::string_view t, + zeek::Message&& d) { + publish(dst, topic{std::string{t}}, d.move_data()); +} + void endpoint::publish(data_message x) { BROKER_INFO("publishing" << x); caf::anon_send(native(core_), atom::publish_v, std::move(x)); diff --git a/src/zeek.cc b/src/zeek.cc new file mode 100644 index 00000000..fd735470 --- /dev/null +++ b/src/zeek.cc @@ -0,0 +1,64 @@ +#include "broker/zeek.hh" + +namespace broker::zeek { + +Message::~Message() {} + +Batch::Batch(data elements) : Message(std::move(elements)) { + if (!holds_alternative(data_)) + return; + auto& outer = as_vector(); + if (!holds_alternative(outer[2])) + return; + auto& items = get(outer[2]); + auto tmp = std::make_shared(); + tmp->reserve(items.size()); + auto append = [&tmp](auto&& msg) { + if (!msg.valid()) + return false; + tmp->emplace_back(std::forward(msg)); + return true; + }; + for (auto&& item : items) { + switch (Message::type(item)) { + case Message::Type::Event: + if (!append(zeek::Event{item})) + return; + break; + case Message::Type::LogCreate: + if (!append(zeek::LogCreate{item})) + return; + break; + case Message::Type::LogWrite: + if (!append(zeek::LogWrite{item})) + return; + break; + case Message::Type::IdentifierUpdate: + if (!append(zeek::IdentifierUpdate{item})) + return; + break; + case Message::Type::Batch: + if (!append(zeek::Batch{item})) + return; + break; + default: + return; + } + } + impl_ = std::move(tmp); +} + +Batch BatchBuilder::build() { + vector tmp; + tmp.swap(inner_); + inner_.reserve(tmp.size()); + vector outer; + outer.reserve(3); + outer.emplace_back(ProtocolVersion); + outer.emplace_back(static_cast(Message::Type::Batch)); + outer.emplace_back(std::move(tmp)); + auto result = Batch{data{std::move(outer)}}; + return result; +} + +} // namespace broker::zeek diff --git a/tests/benchmark/broker-benchmark.cc b/tests/benchmark/broker-benchmark.cc index 5360e55b..1910b0d9 100644 --- a/tests/benchmark/broker-benchmark.cc +++ b/tests/benchmark/broker-benchmark.cc @@ -149,8 +149,8 @@ void send_batch(endpoint& ep, publisher& p) { auto name = "event_" + std::to_string(event_type); vector batch; for (int i = 0; i < batch_size; i++) { - auto ev = zeek::Event(std::string(name), createEventArgs()); - batch.emplace_back(std::move(ev)); + auto ev = zeek::Event{std::string(name), createEventArgs()}; + batch.emplace_back(ev.move_data()); } total_sent += batch.size(); p.publish(std::move(batch)); @@ -215,7 +215,7 @@ void receivedStats(endpoint& ep, const data& x) { if (max_received && total_recv > max_received) { zeek::Event ev("quit_benchmark", std::vector{}); - ep.publish("/benchmark/terminate", ev); + ep.publish("/benchmark/terminate", std::move(ev)); std::this_thread::sleep_for(2s); // Give clients a bit. exit(0); } @@ -276,8 +276,8 @@ void client_loop(endpoint& ep, bool verbose, status_subscriber& ss) { // Pull: generate random events. for (size_t i = 0; i < hint; ++i) { auto name = "event_" + std::to_string(event_type); - out.emplace_back("/benchmark/events", - zeek::Event(std::move(name), createEventArgs())); + auto ev = zeek::Event{std::string(name), createEventArgs()}; + out.emplace_back("/benchmark/events", ev.move_data()); } }, [] { @@ -347,7 +347,7 @@ void server_mode(endpoint& ep, bool verbose, const std::string& iface, // Count number of events (counts each element in a batch as one event). if (zeek::Message::type(msg) == zeek::Message::Type::Batch) { zeek::Batch batch(std::move(msg)); - num_events += batch.batch().size(); + num_events += batch.size(); } else { ++num_events; } From e4b0cfc6713bf2bffb96e6fa2ca9c17f9436b8d7 Mon Sep 17 00:00:00 2001 From: Dominik Charousset Date: Thu, 12 Oct 2023 17:27:16 +0200 Subject: [PATCH 2/2] Replace magic numbers with named indexes --- include/broker/zeek.hh | 227 +++++++++++++++++++++++++++++++---------- src/zeek.cc | 10 +- 2 files changed, 174 insertions(+), 63 deletions(-) diff --git a/include/broker/zeek.hh b/include/broker/zeek.hh index 7590f5c2..716dae7a 100644 --- a/include/broker/zeek.hh +++ b/include/broker/zeek.hh @@ -24,6 +24,19 @@ enum class MetadataType : uint8_t { /// Generic Zeek-level message. class Message { public: + /// The index of the version field in the message. + static constexpr size_t version_index = 0; + + /// The index of the type field in the message. + static constexpr size_t type_index = 1; + + /// The index of the content field in the message. The type of the content + /// depends on the sub-type of the message. + static constexpr size_t content_index = 2; + + /// The number of top-level fields in the message. + static constexpr size_t num_top_level_fields = 3; + enum Type { Invalid = 0, Event = 1, @@ -34,6 +47,8 @@ public: MAX = Batch, }; + static constexpr auto max_tag = static_cast(Type::MAX); + virtual ~Message(); Type type() const { @@ -61,10 +76,9 @@ public: } static Type type(const data& msg) { - constexpr auto max_tag = static_cast(Type::MAX); auto&& elements = msg.to_list(); - if (elements.size() >= 2) { - auto tag = elements[1].to_count(max_tag + 1); + if (elements.size() >= num_top_level_fields) { + auto tag = elements[type_index].to_count(); if (tag <= max_tag) { return static_cast(tag); } @@ -77,6 +91,29 @@ public: } protected: + bool validate_outer_fields(Type tag) const { + auto&& outer = data_.to_list(); + if (outer.size() < num_top_level_fields) + return false; + + return outer[version_index].to_count() == ProtocolVersion + && outer[type_index].to_count() == static_cast(tag) + && outer[content_index].is_list(); + } + + /// Returns the content of the message, i.e., the fields for the sub-type. + /// @pre validate_outer_fields(tag) + const vector& sub_fields() const { + auto&& outer = data_.to_list(); + return outer[content_index].to_list(); + } + + /// @copydoc sub_fields() + vector& sub_fields() { + auto& outer = as_vector(); + return get(outer[content_index]); + } + explicit Message(Type type, vector content) : data_(vector{ProtocolVersion, count(type), std::move(content)}) {} @@ -181,6 +218,18 @@ private: /// A Zeek event. class Event : public Message { public: + /// The index of the event name field. + static constexpr size_t name_index = 0; + + /// The index of the event arguments field. + static constexpr size_t args_index = 1; + + /// The index of the optional metadata field. + static constexpr size_t metadata_index = 2; + + /// The minimum number of fields in a valid event. + static constexpr size_t min_fields = 2; + Event(std::string name, vector args) : Message(Message::Type::Event, {std::move(name), std::move(args)}) {} @@ -199,17 +248,19 @@ public: explicit Event(data_message msg) : Event(broker::move_data(msg)) {} const std::string& name() const { - return get(get(as_vector()[2])[0]); + auto&& fields = sub_fields(); + return get(fields[name_index]); } std::string& name() { - return get(get(as_vector()[2])[0]); + auto&& fields = sub_fields(); + return get(fields[name_index]); } MetadataWrapper metadata() const { - if (const auto* ev_vec_ptr = get_if(as_vector()[2]); - ev_vec_ptr && ev_vec_ptr->size() >= 3) - return MetadataWrapper{get_if((*ev_vec_ptr)[2])}; + auto&& fields = sub_fields(); + if (fields.size() > metadata_index) + return MetadataWrapper{get_if(fields[metadata_index])}; return MetadataWrapper{nullptr}; } @@ -222,20 +273,23 @@ public: } const vector& args() const { - return get(get(as_vector()[2])[1]); + auto&& fields = sub_fields(); + return get(fields[args_index]); } vector& args() { - return get(get(as_vector()[2])[1]); + auto&& fields = sub_fields(); + return get(fields[args_index]); } bool valid() const { - auto&& outer = data_.to_list(); - if (outer.size() < 3) + if (!validate_outer_fields(Type::Event)) return false; - auto&& items = outer[2].to_list(); - if (items.size() < 2 || !items[0].is_string() || !items[1].is_list()) + auto&& fields = sub_fields(); + + if (fields.size() < min_fields || !fields[name_index].is_string() + || !fields[args_index].is_list()) return false; // Optional event metadata verification. @@ -243,8 +297,8 @@ public: // Verify the third element if it exists is a vector> // and type and further check that the NetworkTimestamp metadata has the // right type because we know down here what to expect. - if (items.size() > 2) { - auto&& meta_field = items[2]; + if (fields.size() > metadata_index) { + auto&& meta_field = fields[metadata_index]; if (!meta_field.is_list()) return false; @@ -271,6 +325,21 @@ public: /// only by Zeek itself as the arguments aren't pulbically defined. class LogCreate : public Message { public: + /// The index of the stream ID field. + static constexpr size_t stream_id_index = 0; + + /// The index of the writer ID field. + static constexpr size_t writer_id_index = 1; + + /// The index of the writer info field. + static constexpr size_t writer_info_index = 2; + + /// The index of the fields data field. + static constexpr size_t fields_data_index = 3; + + /// The minimum number of fields in a valid log-create message. + static constexpr size_t min_fields = 4; + LogCreate(enum_value stream_id, enum_value writer_id, data writer_info, data fields_data) : Message(Message::Type::LogCreate, @@ -282,45 +351,53 @@ public: explicit LogCreate(data_message msg) : LogCreate(broker::move_data(msg)) {} const enum_value& stream_id() const { - return get(get(as_vector()[2])[0]); + auto&& fields = sub_fields(); + return get(fields[stream_id_index]); } enum_value& stream_id() { - return get(get(as_vector()[2])[0]); + auto&& fields = sub_fields(); + return get(fields[stream_id_index]); } const enum_value& writer_id() const { - return get(get(as_vector()[2])[1]); + auto&& fields = sub_fields(); + return get(fields[writer_id_index]); } enum_value& writer_id() { - return get(get(as_vector()[2])[1]); + auto&& fields = sub_fields(); + return get(fields[writer_id_index]); } const data& writer_info() const { - return get(as_vector()[2])[2]; + auto&& fields = sub_fields(); + return fields[writer_info_index]; } data& writer_info() { - return get(as_vector()[2])[2]; + auto&& fields = sub_fields(); + return fields[writer_info_index]; } const data& fields_data() const { - return get(as_vector()[2])[3]; + auto&& fields = sub_fields(); + return fields[fields_data_index]; } data& fields_data() { - return get(as_vector()[2])[3]; + auto&& fields = sub_fields(); + return fields[fields_data_index]; } bool valid() const { - auto&& outer = data_.to_list(); - if (outer.size() < 3) + if (!validate_outer_fields(Type::LogCreate)) return false; - auto&& inner = outer[2].to_list(); - return inner.size() >= 4 // - && inner[0].is_enum_value() // - && inner[1].is_enum_value(); + + auto&& fields = sub_fields(); + return fields.size() >= min_fields + && fields[stream_id_index].is_enum_value() + && fields[writer_id_index].is_enum_value(); } }; @@ -328,6 +405,21 @@ public: /// by Zeek itself as the arguments aren't publicly defined. class LogWrite : public Message { public: + /// The index of the stream ID field. + static constexpr size_t stream_id_index = 0; + + /// The index of the writer ID field. + static constexpr size_t writer_id_index = 1; + + /// The index of the path field. + static constexpr size_t path_index = 2; + + /// The index of the serial data field. + static constexpr size_t serial_data_index = 3; + + /// The minimum number of fields in a valid log-create message. + static constexpr size_t min_fields = 4; + LogWrite(enum_value stream_id, enum_value writer_id, data path, data serial_data) : Message(Message::Type::LogWrite, @@ -339,60 +431,79 @@ public: explicit LogWrite(data_message msg) : LogWrite(broker::move_data(msg)) {} const enum_value& stream_id() const { - return get(get(as_vector()[2])[0]); + auto&& fields = sub_fields(); + return get(fields[stream_id_index]); } enum_value& stream_id() { - return get(get(as_vector()[2])[0]); + auto&& fields = sub_fields(); + return get(fields[stream_id_index]); } const enum_value& writer_id() const { - return get(get(as_vector()[2])[1]); + auto&& fields = sub_fields(); + return get(fields[writer_id_index]); } enum_value& writer_id() { - return get(get(as_vector()[2])[1]); + auto&& fields = sub_fields(); + return get(fields[writer_id_index]); } const data& path() const { - return get(as_vector()[2])[2]; + auto&& fields = sub_fields(); + return fields[path_index]; } data& path() { - return get(as_vector()[2])[2]; + auto&& fields = sub_fields(); + return fields[path_index]; }; std::string_view path_str() { - return get(path()); + auto&& fields = sub_fields(); + return fields[path_index].to_string(); }; const data& serial_data() const { - return get(as_vector()[2])[3]; + auto&& fields = sub_fields(); + return fields[serial_data_index]; } data& serial_data() { - return get(as_vector()[2])[3]; + auto&& fields = sub_fields(); + return fields[serial_data_index]; } std::string_view serial_data_str() const { - return get(serial_data()); + auto&& fields = sub_fields(); + return fields[serial_data_index].to_string(); } bool valid() const { - auto&& outer = data_.to_list(); - if (outer.size() < 3) + if (!validate_outer_fields(Type::LogWrite)) return false; - auto&& inner = outer[2].to_list(); - return inner.size() >= 4 // - && inner[0].is_enum_value() // - && inner[1].is_enum_value() // - && inner[2].is_string() // - && inner[3].is_string(); + + auto&& fields = sub_fields(); + return fields.size() >= min_fields + && fields[stream_id_index].is_enum_value() + && fields[writer_id_index].is_enum_value() + && fields[path_index].is_string() + && fields[serial_data_index].is_string(); } }; class IdentifierUpdate : public Message { public: + /// The index of the ID name field. + static constexpr size_t id_name_index = 0; + + /// The index of the ID value field. + static constexpr size_t id_value_index = 1; + + /// The minimum number of fields in a valid identifier-update message. + static constexpr size_t min_fields = 2; + IdentifierUpdate(std::string id_name, data id_value) : Message(Message::Type::IdentifierUpdate, {std::move(id_name), std::move(id_value)}) {} @@ -403,27 +514,31 @@ public: : IdentifierUpdate(broker::move_data(msg)) {} const std::string& id_name() const { - return get(get(as_vector()[2])[0]); + auto&& fields = sub_fields(); + return get(fields[id_name_index]); } std::string& id_name() { - return get(get(as_vector()[2])[0]); + auto&& fields = sub_fields(); + return get(fields[id_name_index]); } const data& id_value() const { - return get(as_vector()[2])[1]; + auto&& fields = sub_fields(); + return fields[id_value_index]; } data& id_value() { - return get(as_vector()[2])[1]; + auto&& fields = sub_fields(); + return fields[id_value_index]; } bool valid() const { - auto&& outer = data_.to_list(); - if (outer.size() < 3) + if (!validate_outer_fields(Type::IdentifierUpdate)) return false; - auto&& inner = outer[2].to_list(); - return inner.size() >= 2 && inner[0].is_string(); + + auto&& fields = sub_fields(); + return fields.size() >= min_fields && fields[id_name_index].is_string(); } }; diff --git a/src/zeek.cc b/src/zeek.cc index fd735470..95f12e95 100644 --- a/src/zeek.cc +++ b/src/zeek.cc @@ -5,12 +5,9 @@ namespace broker::zeek { Message::~Message() {} Batch::Batch(data elements) : Message(std::move(elements)) { - if (!holds_alternative(data_)) + if (!validate_outer_fields(Type::Batch)) return; - auto& outer = as_vector(); - if (!holds_alternative(outer[2])) - return; - auto& items = get(outer[2]); + auto&& items = sub_fields(); // Each field in the content is a message. auto tmp = std::make_shared(); tmp->reserve(items.size()); auto append = [&tmp](auto&& msg) { @@ -57,8 +54,7 @@ Batch BatchBuilder::build() { outer.emplace_back(ProtocolVersion); outer.emplace_back(static_cast(Message::Type::Batch)); outer.emplace_back(std::move(tmp)); - auto result = Batch{data{std::move(outer)}}; - return result; + return Batch{data{std::move(outer)}}; } } // namespace broker::zeek