diff --git a/components/core/src/clp_s/Utils.cpp b/components/core/src/clp_s/Utils.cpp index b00121f40..976c3ee14 100644 --- a/components/core/src/clp_s/Utils.cpp +++ b/components/core/src/clp_s/Utils.cpp @@ -590,11 +590,7 @@ bool StringUtils::unescape_kql_internal( escaped = false; switch (value[i]) { case '\\': - if (is_value) { - unescaped.append("\\\\"); - } else { - unescaped.push_back('\\'); - } + unescaped.append("\\\\"); break; case '"': unescaped.push_back('"'); @@ -621,11 +617,24 @@ bool StringUtils::unescape_kql_internal( } auto four_byte_hex = std::string_view{value}.substr(i + 1, 4); - if (false == convert_four_byte_hex_to_utf8(four_byte_hex, unescaped)) { + i += 4; + + std::string tmp; + if (false == convert_four_byte_hex_to_utf8(four_byte_hex, tmp)) { return false; } - i += 4; - continue; + + // Make sure unicode escape sequences are always treated as literal characters + if ("\\" == tmp) { + unescaped.append("\\\\"); + } else if ("?" == tmp && is_value) { + unescaped.append("\\?"); + } else if ("*" == tmp) { + unescaped.append("\\*"); + } else { + unescaped.append(tmp); + } + break; } case '{': unescaped.push_back('{'); diff --git a/components/core/src/clp_s/clp-s.cpp b/components/core/src/clp_s/clp-s.cpp index b76683caf..d5015f6f7 100644 --- a/components/core/src/clp_s/clp-s.cpp +++ b/components/core/src/clp_s/clp-s.cpp @@ -197,7 +197,7 @@ bool search_archive( SPDLOG_ERROR("Can not tokenize invalid column: \"{}\"", column); return false; } - projection->add_column(ColumnDescriptor::create(descriptor_tokens)); + projection->add_column(ColumnDescriptor::create_from_escaped_tokens(descriptor_tokens)); } } catch (clp_s::TraceableException& e) { SPDLOG_ERROR("{}", e.what()); diff --git a/components/core/src/clp_s/search/AddTimestampConditions.cpp b/components/core/src/clp_s/search/AddTimestampConditions.cpp index 5af7ef44a..addf7069f 100644 --- a/components/core/src/clp_s/search/AddTimestampConditions.cpp +++ b/components/core/src/clp_s/search/AddTimestampConditions.cpp @@ -18,7 +18,8 @@ std::shared_ptr AddTimestampConditions::run(std::shared_ptr const& descriptors) { DescriptorList list; for (std::string const& descriptor : descriptors) { - list.push_back(DescriptorToken(descriptor)); + list.push_back(DescriptorToken::create_descriptor_from_escaped_token(descriptor)); } return list; } @@ -15,7 +15,7 @@ void ColumnDescriptor::check_and_set_unresolved_descriptor_flag() { m_unresolved_descriptors = false; m_pure_wildcard = m_descriptors.size() == 1 && m_descriptors[0].wildcard(); for (auto const& token : m_descriptors) { - if (token.wildcard() || token.regex()) { + if (token.wildcard()) { m_unresolved_descriptors = true; break; } @@ -24,7 +24,7 @@ void ColumnDescriptor::check_and_set_unresolved_descriptor_flag() { ColumnDescriptor::ColumnDescriptor(std::string const& descriptor) { m_flags = cAllTypes; - m_descriptors.emplace_back(descriptor); + m_descriptors.emplace_back(DescriptorToken::create_descriptor_from_escaped_token(descriptor)); check_and_set_unresolved_descriptor_flag(); if (is_unresolved_descriptor()) { simplify_descriptor_wildcards(); @@ -49,17 +49,21 @@ ColumnDescriptor::ColumnDescriptor(DescriptorList const& descriptors) { } } -std::shared_ptr ColumnDescriptor::create(std::string const& descriptor) { - return std::shared_ptr(new ColumnDescriptor(descriptor)); +std::shared_ptr ColumnDescriptor::create_from_escaped_token( + std::string const& token +) { + return std::shared_ptr(new ColumnDescriptor(token)); } -std::shared_ptr ColumnDescriptor::create( - std::vector const& descriptors +std::shared_ptr ColumnDescriptor::create_from_escaped_tokens( + std::vector const& tokens ) { - return std::shared_ptr(new ColumnDescriptor(descriptors)); + return std::shared_ptr(new ColumnDescriptor(tokens)); } -std::shared_ptr ColumnDescriptor::create(DescriptorList const& descriptors) { +std::shared_ptr ColumnDescriptor::create_from_descriptors( + DescriptorList const& descriptors +) { return std::shared_ptr(new ColumnDescriptor(descriptors)); } diff --git a/components/core/src/clp_s/search/ColumnDescriptor.hpp b/components/core/src/clp_s/search/ColumnDescriptor.hpp index ea1cfd7ec..6ade4728e 100644 --- a/components/core/src/clp_s/search/ColumnDescriptor.hpp +++ b/components/core/src/clp_s/search/ColumnDescriptor.hpp @@ -7,6 +7,7 @@ #include #include +#include "../TraceableException.hpp" #include "Literal.hpp" namespace clp_s::search { @@ -15,27 +16,32 @@ namespace clp_s::search { */ class DescriptorToken { public: + class OperationFailed : public TraceableException { + public: + // Constructors + OperationFailed(ErrorCode error_code, char const* const filename, int line_number) + : TraceableException(error_code, filename, line_number) {} + }; + // Constructors DescriptorToken() = default; /** - * Initialize the token from a string and set flags based on whether the token contains - * wildcards - * @param token the string to initialize the token from + * Creates a DescriptorToken from an escaped token string. The escape sequences '\\' and '\*' + * are supported in order to distinguish the literal '*' from the '*' used to match hierarchies + * of keys. + * @param token */ - explicit DescriptorToken(std::string_view const token) - : m_token(token), - m_wildcard(false), - m_regex(false) { - if (token == "*") { - m_wildcard = true; - } + static DescriptorToken create_descriptor_from_escaped_token(std::string_view const token) { + return DescriptorToken{token, false}; + } - for (char c : token) { - if (c == '*') { - m_regex = true; - } - } + /** + * Creates a DescriptorToken from a literal token string. The token is copied verbatim, and is + * never treated as a wildcard. + */ + static DescriptorToken create_descriptor_from_literal_token(std::string_view const token) { + return DescriptorToken{token, true}; } /** @@ -44,13 +50,6 @@ class DescriptorToken { */ bool wildcard() const { return m_wildcard; } - /** - * Whether the descriptor contains a wildcard somewhere - * TODO: Not currently used, and regex isn't currently supported - * @return true if the descriptor contains a wildcard - */ - bool regex() const { return m_regex; } - /** * Get a reference to the underlying token string * @return a reference to the underlying string @@ -62,14 +61,47 @@ class DescriptorToken { * @return Whether this token is equal to the given token */ bool operator==(DescriptorToken const& rhs) const { - // Note: we only need to compare the m_token field because m_regex and m_wildcard are - // derived from m_token. - return m_token == rhs.m_token; + return m_token == rhs.m_token && m_wildcard == rhs.m_wildcard; } private: + /** + * Initialize the token from a string and set flags based on whether the token contains + * wildcards + * @param token the string to initialize the token from + * @param bool true if the string should be interpreted as literal, and false + */ + explicit DescriptorToken(std::string_view const token, bool is_literal) : m_wildcard(false) { + if (is_literal) { + m_token = token; + return; + } + + if (token == "*") { + m_wildcard = true; + } + + bool escaped{false}; + for (size_t i = 0; i < token.size(); ++i) { + if (false == escaped) { + if ('\\' == token[i]) { + escaped = true; + } else { + m_token.push_back(token[i]); + } + continue; + } else { + m_token.push_back(token[i]); + escaped = false; + } + } + + if (escaped) { + throw OperationFailed(ErrorCodeBadParam, __FILENAME__, __LINE__); + } + } + bool m_wildcard{}; - bool m_regex{}; std::string m_token; }; @@ -92,9 +124,27 @@ class ColumnDescriptor : public Literal { * @param descriptor(s) the token or list of tokens making up the descriptor * @return A ColumnDescriptor */ - static std::shared_ptr create(std::string const& descriptor); - static std::shared_ptr create(std::vector const& descriptors); - static std::shared_ptr create(DescriptorList const& descriptors); + static std::shared_ptr create_from_escaped_token(std::string const& token); + static std::shared_ptr create_from_escaped_tokens( + std::vector const& tokens + ); + static std::shared_ptr create_from_descriptors( + DescriptorList const& descriptors + ); + + /** + * Insert an entire DescriptorList into this ColumnDescriptor at before given position. + * @param pos an iterator to the position inside of the internal descriptor list to insert + * before. + * @param source the list of descriptors to insert + */ + void insert(DescriptorList::iterator pos, DescriptorList const& source) { + m_descriptors.insert(pos, source.begin(), source.end()); + check_and_set_unresolved_descriptor_flag(); + if (is_unresolved_descriptor()) { + simplify_descriptor_wildcards(); + } + } /** * Deep copy of this ColumnDescriptor diff --git a/components/core/src/clp_s/search/SchemaMatch.cpp b/components/core/src/clp_s/search/SchemaMatch.cpp index 203634000..6b5f2ce8c 100644 --- a/components/core/src/clp_s/search/SchemaMatch.cpp +++ b/components/core/src/clp_s/search/SchemaMatch.cpp @@ -77,12 +77,15 @@ std::shared_ptr SchemaMatch::populate_column_mapping(std::shared_ptr auto literal_type = node_to_literal_type(node->get_type()); DescriptorList descriptors; while (node->get_id() != m_tree->get_object_subtree_node_id()) { - // may have to explicitly mark non-regex - descriptors.emplace_back(node->get_key_name()); + descriptors.emplace_back( + DescriptorToken::create_descriptor_from_literal_token( + node->get_key_name() + ) + ); node = &m_tree->get_node(node->get_parent_id()); } std::reverse(descriptors.begin(), descriptors.end()); - auto resolved_column = ColumnDescriptor::create(descriptors); + auto resolved_column = ColumnDescriptor::create_from_descriptors(descriptors); resolved_column->set_matching_type(literal_type); *it = resolved_column; cur->copy_append(possibilities.get()); diff --git a/components/core/src/clp_s/search/StringLiteral.hpp b/components/core/src/clp_s/search/StringLiteral.hpp index 4ac6b9f2f..b741541e9 100644 --- a/components/core/src/clp_s/search/StringLiteral.hpp +++ b/components/core/src/clp_s/search/StringLiteral.hpp @@ -4,6 +4,7 @@ #include #include +#include "../Utils.hpp" #include "Literal.hpp" namespace clp_s::search { @@ -69,18 +70,8 @@ class StringLiteral : public Literal { } // If '?' and '*' are not escaped, we add LiteralType::ClpStringT to m_string_type - bool escape = false; - for (char const c : m_v) { - if ('\\' == c) { - escape = !escape; - } else if ('?' == c || '*' == c) { - if (false == escape) { - m_string_type |= LiteralType::ClpStringT; - break; - } - } else { - escape = false; - } + if (StringUtils::has_unescaped_wildcards(m_v)) { + m_string_type |= LiteralType::ClpStringT; } } }; diff --git a/components/core/src/clp_s/search/kql/kql.cpp b/components/core/src/clp_s/search/kql/kql.cpp index 34085d1f9..098886671 100644 --- a/components/core/src/clp_s/search/kql/kql.cpp +++ b/components/core/src/clp_s/search/kql/kql.cpp @@ -55,7 +55,7 @@ class ParseTreeVisitor : public KqlBaseVisitor { private: static void prepend_column(std::shared_ptr const& desc, DescriptorList const& prefix) { - desc->get_descriptor_list().insert(desc->descriptor_begin(), prefix.begin(), prefix.end()); + desc->insert(desc->get_descriptor_list().begin(), prefix); } void prepend_column(std::shared_ptr const& expr, DescriptorList const& prefix) { @@ -126,7 +126,7 @@ class ParseTreeVisitor : public KqlBaseVisitor { return nullptr; } - return ColumnDescriptor::create(descriptor_tokens); + return ColumnDescriptor::create_from_escaped_tokens(descriptor_tokens); } std::any visitNestedQuery(KqlParser::NestedQueryContext* ctx) override { @@ -202,7 +202,7 @@ class ParseTreeVisitor : public KqlBaseVisitor { std::any visitValue_expression(KqlParser::Value_expressionContext* ctx) override { auto lit = unquote_literal(ctx->LITERAL()->getText()); - auto descriptor = ColumnDescriptor::create("*"); + auto descriptor = ColumnDescriptor::create_from_escaped_token("*"); return FilterExpr::create(descriptor, FilterOperation::EQ, lit); } @@ -222,7 +222,7 @@ class ParseTreeVisitor : public KqlBaseVisitor { base = OrExpr::create(); } - auto empty_descriptor = ColumnDescriptor::create(DescriptorList()); + auto empty_descriptor = ColumnDescriptor::create_from_descriptors(DescriptorList()); for (auto token : ctx->literals) { auto literal = unquote_literal(token->getText()); auto expr = FilterExpr::create(