diff --git a/src/search/ir.h b/src/search/ir.h index 3ba980dab4e..c7aec26ba80 100644 --- a/src/search/ir.h +++ b/src/search/ir.h @@ -265,24 +265,19 @@ struct VectorRangeExpr : BoolAtomExpr { }; struct VectorKnnExpr : BoolAtomExpr { - // TODO: Support pre-filter for hybrid query std::unique_ptr field; - std::unique_ptr k; std::unique_ptr vector; + size_t k; - VectorKnnExpr(std::unique_ptr &&field, std::unique_ptr &&k, - std::unique_ptr &&vector) - : field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {} + VectorKnnExpr(std::unique_ptr &&field, std::unique_ptr &&vector, size_t k) + : field(std::move(field)), vector(std::move(vector)), k(k) {} std::string_view Name() const override { return "VectorKnnExpr"; } - std::string Dump() const override { - return fmt::format("KNN k={}, {} <-> {}", k->Dump(), field->Dump(), vector->Dump()); - } + std::string Dump() const override { return fmt::format("KNN k={}, {} <-> {}", k, field->Dump(), vector->Dump()); } std::unique_ptr Clone() const override { return std::make_unique(Node::MustAs(field->Clone()), - Node::MustAs(k->Clone()), - Node::MustAs(vector->Clone())); + Node::MustAs(vector->Clone()), k); } }; @@ -425,6 +420,10 @@ struct SortByClause : Node { std::unique_ptr Clone() const override { return std::make_unique(order, Node::MustAs(field->Clone())); } + + std::unique_ptr TakeFieldRef() { return std::move(field); } + + std::unique_ptr TakeVectorLiteral() { return std::move(vector); } }; struct SelectClause : Node { diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h index 2068a45a4f4..e783ca8f486 100644 --- a/src/search/ir_pass.h +++ b/src/search/ir_pass.h @@ -59,6 +59,12 @@ struct Visitor : Pass { return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { @@ -69,6 +75,10 @@ struct Visitor : Pass { return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { @@ -125,6 +135,8 @@ struct Visitor : Pass { virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + virtual std::unique_ptr Visit(std::unique_ptr node) { node->field = VisitAs(std::move(node->field)); node->num = VisitAs(std::move(node->num)); @@ -137,6 +149,19 @@ struct Visitor : Pass { return node; } + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->field = VisitAs(std::move(node->field)); + node->vector = VisitAs(std::move(node->vector)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->field = VisitAs(std::move(node->field)); + node->range = VisitAs(std::move(node->range)); + node->vector = VisitAs(std::move(node->vector)); + return node; + } + virtual std::unique_ptr Visit(std::unique_ptr node) { for (auto &n : node->inners) { n = TransformAs(std::move(n)); @@ -173,6 +198,10 @@ struct Visitor : Pass { virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + virtual std::unique_ptr Visit(std::unique_ptr node) { node->source = TransformAs(std::move(node->source)); node->filter_expr = TransformAs(std::move(node->filter_expr)); diff --git a/src/search/ir_plan.h b/src/search/ir_plan.h index 94e8b589c60..8743a827339 100644 --- a/src/search/ir_plan.h +++ b/src/search/ir_plan.h @@ -99,7 +99,7 @@ struct TagFieldScan : FieldScan { struct HnswVectorFieldKnnScan : FieldScan { kqir::NumericArray vector; - uint16_t k; + uint32_t k; HnswVectorFieldKnnScan(std::unique_ptr field, kqir::NumericArray vector, uint16_t k) : FieldScan(std::move(field)), vector(std::move(vector)), k(k) {} diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h index 43d722b4d0b..e79a19327af 100644 --- a/src/search/ir_sema_checker.h +++ b/src/search/ir_sema_checker.h @@ -129,9 +129,6 @@ struct SemaChecker { return {Status::NotOK, fmt::format("field `{}` is marked as NOINDEX and cannot be used for KNN search", v->field->name)}; } - if (v->k->val <= 0) { - return {Status::NotOK, fmt::format("KNN search parameter `k` must be greater than 0")}; - } auto meta = v->field->info->MetadataAs(); if (v->vector->values.size() != meta->dim) { return {Status::NotOK, diff --git a/src/search/passes/cost_model.h b/src/search/passes/cost_model.h index 86e0e3a58e5..960708d740c 100644 --- a/src/search/passes/cost_model.h +++ b/src/search/passes/cost_model.h @@ -36,6 +36,12 @@ struct CostModel { if (auto v = dynamic_cast(node)) { return Visit(v); } + if (auto v = dynamic_cast(node)) { + return Visit(v); + } + if (auto v = dynamic_cast(node)) { + return Visit(v); + } if (auto v = dynamic_cast(node)) { return Visit(v); } @@ -74,6 +80,10 @@ struct CostModel { static size_t Visit(const TagFieldScan *node) { return 10; } + static size_t Visit(const HnswVectorFieldKnnScan *node) { return 3; } + + static size_t Visit(const HnswVectorFieldRangeScan *node) { return 4; } + static size_t Visit(const Filter *node) { return Transform(node->source.get()) + 1; } static size_t Visit(const Merge *node) { diff --git a/src/search/passes/index_selection.h b/src/search/passes/index_selection.h index e60287d4d01..09e1bcb34f5 100644 --- a/src/search/passes/index_selection.h +++ b/src/search/passes/index_selection.h @@ -112,6 +112,12 @@ struct IndexSelection : Visitor { if (auto v = dynamic_cast(node)) { return VisitExpr(v); } + if (auto v = dynamic_cast(node)) { + return VisitExpr(v); + } + if (auto v = dynamic_cast(node)) { + return VisitExpr(v); + } if (auto v = dynamic_cast(node)) { return VisitExpr(v); } @@ -153,6 +159,23 @@ struct IndexSelection : Visitor { return MakeFullIndexFilter(node); } + std::unique_ptr VisitExpr(VectorRangeExpr *node) const { + if (node->field->info->HasIndex()) { + return std::make_unique(node->field->CloneAs(), node->vector->values, + node->range->val); + } + + return MakeFullIndexFilter(node); + } + + std::unique_ptr VisitExpr(VectorKnnExpr *node) const { + if (node->field->info->HasIndex()) { + return std::make_unique(node->field->CloneAs(), node->vector->values, node->k); + } + + return MakeFullIndexFilter(node); + } + template std::unique_ptr VisitExprImpl(Expr *node) { struct AggregatedNodes { diff --git a/src/search/passes/manager.h b/src/search/passes/manager.h index 57f317d213c..d2f07bb7f6a 100644 --- a/src/search/passes/manager.h +++ b/src/search/passes/manager.h @@ -35,6 +35,7 @@ #include "search/passes/simplify_and_or_expr.h" #include "search/passes/simplify_boolean.h" #include "search/passes/sort_limit_fuse.h" +#include "search/passes/transfer_to_knn.h" #include "type_util.h" namespace kqir { @@ -86,7 +87,8 @@ struct PassManager { } static PassSequence ExprPasses() { - return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, SimplifyAndOrExpr{}); + return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, SimplifyAndOrExpr{}, + TransferSortByToKnnExpr{}, SimplifyAndOrExpr{}); } static PassSequence NumericPasses() { return Create(IntervalAnalysis{true}, SimplifyAndOrExpr{}, SimplifyBoolean{}); } static PassSequence PlanPasses() { return Create(LowerToPlan{}, IndexSelection{}, SortLimitFuse{}); } diff --git a/src/search/passes/transfer_to_knn.h b/src/search/passes/transfer_to_knn.h new file mode 100644 index 00000000000..aeee776350b --- /dev/null +++ b/src/search/passes/transfer_to_knn.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include + +#include "search/ir.h" +#include "search/ir_pass.h" +#include "search/ir_plan.h" + +namespace kqir { + +struct TransferSortByToKnnExpr : Visitor { + std::unique_ptr Visit(std::unique_ptr node) override { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + if (node->sort_by && node->sort_by->IsVectorField() && node->limit) { + std::vector> exprs; + auto knn_expr = std::make_unique(Node::MustAs(node->sort_by->TakeFieldRef()), + Node::MustAs(node->sort_by->TakeVectorLiteral()), + node->limit->Count()); + if (auto b = Node::As(std::move(node->query_expr))) { + if (b->val) exprs.push_back(std::move(knn_expr)); + } else { + exprs.push_back(std::move(node->query_expr)); + exprs.push_back(std::move(knn_expr)); + } + + if (exprs.empty()) { + node->query_expr = std::make_unique(false); + } else if (exprs.size() == 1) { + node->query_expr = std::move(exprs[0]); + } else { + node->query_expr = std::make_unique(std::move(exprs)); + } + node->sort_by.reset(); + node->limit.reset(); + } + + return node; + } +}; + +} // namespace kqir diff --git a/src/search/redis_query_parser.h b/src/search/redis_query_parser.h index 5b0f172c763..8596a977e78 100644 --- a/src/search/redis_query_parser.h +++ b/src/search/redis_query_parser.h @@ -43,13 +43,14 @@ struct Tag : sor {}; struct TagList : seq, WSPad, star, WSPad>>, one<'}'>> {}; struct NumberOrParam : sor {}; +struct UintOrParam : sor {}; struct Inf : seq>, string<'i', 'n', 'f'>> {}; struct ExclusiveNumber : seq, NumberOrParam> {}; struct NumericRangePart : sor {}; struct NumericRange : seq, WSPad, WSPad, one<']'>> {}; -struct KnnSearch : seq, WSPad, WSPad, WSPad, WSPad, one<']'>> {}; +struct KnnSearch : seq, WSPad, WSPad, WSPad, WSPad, one<']'>> {}; struct VectorRange : seq, WSPad, WSPad, WSPad, one<']'>> {}; struct FieldQuery : seq, one<':'>, WSPad>> {}; diff --git a/src/search/redis_query_transformer.h b/src/search/redis_query_transformer.h index c81230e4ebf..9329261e70a 100644 --- a/src/search/redis_query_transformer.h +++ b/src/search/redis_query_transformer.h @@ -36,7 +36,7 @@ namespace ir = kqir; template using TreeSelector = parse_tree::selector< - Rule, parse_tree::store_content::on, + Rule, parse_tree::store_content::on, parse_tree::remove_content::on>; @@ -168,9 +168,15 @@ struct Transformer : ir::TreeTransformer { const auto& knn_search = node->children[2]; CHECK(knn_search->children.size() == 4); + size_t k = 0; + if (Is(knn_search->children[1])) { + k = *ParseInt(knn_search->children[1]->string()); + } else { + k = *ParseInt(GET_OR_RET(GetParam(node))); + } + return std::make_unique(std::make_unique(knn_search->children[2]->string()), - GET_OR_RET(number_or_param(knn_search->children[1])), - GET_OR_RET(Transform2Vector(knn_search->children[3]))); + GET_OR_RET(Transform2Vector(knn_search->children[3])), k); } else if (Is(node)) { std::vector> exprs; diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h index 01705107776..49d04307ea8 100644 --- a/src/search/sql_transformer.h +++ b/src/search/sql_transformer.h @@ -118,7 +118,6 @@ struct Transformer : ir::TreeTransformer { return {Status::NotOK, "the left and right side of numeric comparison should be an identifier and a number"}; } } else if (Is(node)) { - // TODO(Beihao): Handle distance metrics for operator CHECK(node->children.size() == 2); const auto& vector_comp_expr = node->children[0]; CHECK(vector_comp_expr->children.size() == 3); diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc index 81ed49e8b94..ee28504bd1b 100644 --- a/tests/cppunit/ir_pass_test.cc +++ b/tests/cppunit/ir_pass_test.cc @@ -111,6 +111,15 @@ TEST(IRPassTest, Manager) { "select * from a where (and x <= 1, y >= 2, z != 3)"); } +TEST(IRPassTest, TransferSortByToKnnExpr) { + TransferSortByToKnnExpr tsbtke; + + ASSERT_EQ(tsbtke.Transform(*Parse("select a from b order by embedding <-> [3.6] limit 5"))->Dump(), + "select a from b where KNN k=5, embedding <-> [3.600000]"); + ASSERT_EQ(tsbtke.Transform(*Parse("select a from b where c = 1 order by embedding <-> [3,1,2] limit 5"))->Dump(), + "select a from b where (and c = 1, KNN k=5, embedding <-> [3.000000, 1.000000, 2.000000])"); +} + TEST(IRPassTest, LowerToPlan) { LowerToPlan ltp; @@ -123,6 +132,8 @@ TEST(IRPassTest, LowerToPlan) { "project a: (sort d, asc: (filter c = 1: full-scan b))"); ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 limit 1"))->Dump(), "project a: (limit 0, 1: (filter c = 1: full-scan b))"); + ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 and d = 2 order by e limit 1"))->Dump(), + "project a: (limit 0, 1: (sort e, asc: (filter (and c = 1, d = 2): full-scan b)))"); ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d limit 1"))->Dump(), "project a: (limit 0, 1: (sort d, asc: (filter c = 1: full-scan b)))"); } @@ -176,12 +187,28 @@ static IndexMap MakeIndexMap() { auto f4 = FieldInfo("n2", std::make_unique()); auto f5 = FieldInfo("n3", std::make_unique()); f5.metadata->noindex = true; + + auto hnsw_field_meta = std::make_unique(); + hnsw_field_meta->vector_type = redis::VectorType::FLOAT64; + hnsw_field_meta->dim = 3; + hnsw_field_meta->distance_metric = redis::DistanceMetric::L2; + auto f6 = FieldInfo("v1", std::move(hnsw_field_meta)); + + hnsw_field_meta = std::make_unique(); + hnsw_field_meta->vector_type = redis::VectorType::FLOAT64; + hnsw_field_meta->dim = 3; + hnsw_field_meta->distance_metric = redis::DistanceMetric::L2; + auto f7 = FieldInfo("v2", std::move(hnsw_field_meta)); + f7.metadata->noindex = true; + auto ia = std::make_unique("ia", redis::IndexMetadata(), ""); ia->Add(std::move(f1)); ia->Add(std::move(f2)); ia->Add(std::move(f3)); ia->Add(std::move(f4)); ia->Add(std::move(f5)); + ia->Add(std::move(f6)); + ia->Add(std::move(f7)); IndexMap res; res.Insert(std::move(ia)); @@ -238,6 +265,26 @@ TEST(IRPassTest, IndexSelection) { "project *: (filter t2 hastag \"a\": tag-scan t1, a)"); ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where t2 hastag \"a\""))->Dump(), "project *: (filter t2 hastag \"a\": full-scan ia)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where v1 <-> [3,1,2] < 5"))->Dump(), + "project *: hnsw-vector-range-scan v1, [3.000000, 1.000000, 2.000000], 5"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by v1 <-> [3,1,2] limit 5"))->Dump(), + "project *: hnsw-vector-knn-scan v1, [3.000000, 1.000000, 2.000000], 5"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where v2 <-> [3,1,2] < 5"))->Dump(), + "project *: (filter v2 <-> [3.000000, 1.000000, 2.000000] < 5: full-scan ia)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1 and v1 <-> [3,1,2] < 5"))->Dump(), + "project *: (filter n1 >= 1: hnsw-vector-range-scan v1, [3.000000, 1.000000, 2.000000], 5)"); + ASSERT_EQ( + PassManager::Execute(passes, ParseS(sc, "select * from ia where v1 <-> [3,1,2] < 5 and t1 hastag \"a\""))->Dump(), + "project *: (filter t1 hastag \"a\": hnsw-vector-range-scan v1, [3.000000, 1.000000, 2.000000], 5)"); + ASSERT_EQ( + PassManager::Execute(passes, ParseS(sc, "select * from ia where t1 hastag \"a\" order by v1 <-> [3,1,2] limit 5")) + ->Dump(), + "project *: (filter t1 hastag \"a\": hnsw-vector-knn-scan v1, [3.000000, 1.000000, 2.000000], 5)"); + ASSERT_EQ(PassManager::Execute( + passes, ParseS(sc, "select * from ia where v1 <-> [3,1,2] < 2 order by v1 <-> [3,1,2] limit 5")) + ->Dump(), + "project *: (filter v1 <-> [3.000000, 1.000000, 2.000000] < 2: hnsw-vector-knn-scan v1, [3.000000, " + "1.000000, 2.000000], 5)"); ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 2 or n1 < 1"))->Dump(), "project *: (merge numeric-scan n1, [-inf, 1), asc, numeric-scan n1, [2, inf), asc)"); diff --git a/tests/cppunit/redis_query_parser_test.cc b/tests/cppunit/redis_query_parser_test.cc index 4fc25e49db2..c0178e45a53 100644 --- a/tests/cppunit/redis_query_parser_test.cc +++ b/tests/cppunit/redis_query_parser_test.cc @@ -115,6 +115,7 @@ TEST(RedisQueryParserTest, Vector) { AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}})); AssertSyntaxError(Parse("[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}})); AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}})); + AssertSyntaxError(Parse("* =>[KNN -1 @vector $BLOB]", {{"BLOB", vec_str}})); AssertSyntaxError(Parse("*=>[KNN 5 $vector_blob_param]", {{"vector_blob_param", vec_str}})); AssertIR(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", vec_str}}),