Skip to content

Commit

Permalink
Add executor and parser for hybrid query
Browse files Browse the repository at this point in the history
  • Loading branch information
Beihao-Zhou committed Aug 4, 2024
1 parent 470a6f3 commit d812e60
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 9 deletions.
22 changes: 22 additions & 0 deletions src/search/executors/filter_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <variant>

#include "parse_util.h"
#include "search/hnsw_indexer.h"
#include "search/ir.h"
#include "search/plan_executor.h"
#include "search/search_encoding.h"
Expand All @@ -44,6 +45,9 @@ struct QueryExprEvaluator {
if (auto v = dynamic_cast<NotExpr *>(e)) {
return Visit(v);
}
if (auto v = dynamic_cast<VectorRangeExpr *>(e)) {
return Visit(v);
}
if (auto v = dynamic_cast<NumericCompareExpr *>(e)) {
return Visit(v);
}
Expand Down Expand Up @@ -112,6 +116,24 @@ struct QueryExprEvaluator {
__builtin_unreachable();
}
}

StatusOr<bool> Visit(VectorRangeExpr *v) const {
auto val = GET_OR_RET(ctx->Retrieve(row, v->field->info));

CHECK(val.Is<kqir::NumericArray>());
auto l_values = val.Get<kqir::NumericArray>();
auto r_values = v->vector->values;
auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();

redis::VectorItem left, right;
GET_OR_RET(redis::VectorItem::Create({}, l_values, meta, &left));
GET_OR_RET(redis::VectorItem::Create({}, r_values, meta, &right));

auto dist = GET_OR_RET(redis::ComputeSimilarity(left, right));
auto effective_range = v->range->val * (1 + meta->epsilon);

return (dist >= -abs(effective_range) && dist <= abs(effective_range));
}
};

struct FilterExecutor : ExecutorNode {
Expand Down
2 changes: 1 addition & 1 deletion src/search/redis_query_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct AndExprP : sor<AndExpr, BooleanExpr> {};
struct OrExpr : seq<AndExprP, plus<seq<one<'|'>, AndExprP>>> {};
struct OrExprP : sor<OrExpr, AndExprP> {};

struct PrefilterExpr : seq<WSPad<BooleanExpr>, ArrowOp, WSPad<KnnSearch>> {};
struct PrefilterExpr : seq<WSPad<OrExprP>, ArrowOp, WSPad<KnnSearch>> {};

struct QueryP : sor<PrefilterExpr, OrExprP> {};

Expand Down
11 changes: 7 additions & 4 deletions src/search/redis_query_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ struct Transformer : ir::TreeTransformer {
} else if (Is<PrefilterExpr>(node)) {
CHECK(node->children.size() == 3);

// TODO(Beihao): Support Hybrid Query
// const auto& prefilter = node->children[0];
const auto& knn_search = node->children[2];
CHECK(knn_search->children.size() == 4);

Expand All @@ -175,9 +173,14 @@ struct Transformer : ir::TreeTransformer {
k = *ParseInt(GET_OR_RET(GetParam(node)));
}

return std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
GET_OR_RET(Transform2Vector(knn_search->children[3])), k);
auto knn_expr = std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
GET_OR_RET(Transform2Vector(knn_search->children[3])), k);

std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
exprs.push_back(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
exprs.push_back(std::move(knn_expr));

return Node::Create<ir::AndExpr>(std::move(exprs));
} else if (Is<AndExpr>(node)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;

Expand Down
36 changes: 36 additions & 0 deletions tests/cppunit/plan_executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ static auto FieldI(const std::string& f) -> const FieldInfo* { return &IndexI()-

static auto N(double n) { return MakeValue<Numeric>(n); }
static auto T(const std::string& v) { return MakeValue<StringArray>(util::Split(v, ",")); }
static auto V(const std::vector<double>& vals) { return MakeValue<NumericArray>(vals); }

TEST(PlanExecutorTest, TopNSort) {
std::vector<ExecutorNode::RowType> data{
Expand Down Expand Up @@ -201,6 +202,41 @@ TEST(PlanExecutorTest, Filter) {
ASSERT_EQ(NextRow(ctx).key, "f");
ASSERT_EQ(ctx.Next().GetValue(), exe_end);
}

data = {{"a", {{FieldI("f4"), V({1, 2, 3})}}, IndexI()}, {"b", {{FieldI("f4"), V({9, 10, 11})}}, IndexI()},
{"c", {{FieldI("f4"), V({4, 5, 6})}}, IndexI()}, {"d", {{FieldI("f4"), V({1, 2, 3})}}, IndexI()},
{"e", {{FieldI("f4"), V({2, 3, 4})}}, IndexI()}, {"f", {{FieldI("f4"), V({12, 13, 14})}}, IndexI()},
{"g", {{FieldI("f4"), V({1, 2, 3})}}, IndexI()}};
{
auto field = std::make_unique<FieldRef>("f4", FieldI("f4"));
std::vector<double> vector = {11, 12, 13};
auto op = std::make_unique<Filter>(
std::make_unique<Mock>(data),
std::make_unique<VectorRangeExpr>(field->CloneAs<FieldRef>(), std::make_unique<NumericLiteral>(4),
std::make_unique<VectorLiteral>(std::move(vector))));

auto ctx = ExecutorContext(op.get());
ASSERT_EQ(NextRow(ctx).key, "b");
ASSERT_EQ(NextRow(ctx).key, "f");
ASSERT_EQ(ctx.Next().GetValue(), exe_end);
}

{
auto field = std::make_unique<FieldRef>("f4", FieldI("f4"));
std::vector<double> vector = {2, 3, 4};
auto op = std::make_unique<Filter>(
std::make_unique<Mock>(data),
std::make_unique<VectorRangeExpr>(field->CloneAs<FieldRef>(), std::make_unique<NumericLiteral>(5),
std::make_unique<VectorLiteral>(std::move(vector))));

auto ctx = ExecutorContext(op.get());
ASSERT_EQ(NextRow(ctx).key, "a");
ASSERT_EQ(NextRow(ctx).key, "c");
ASSERT_EQ(NextRow(ctx).key, "d");
ASSERT_EQ(NextRow(ctx).key, "e");
ASSERT_EQ(NextRow(ctx).key, "g");
ASSERT_EQ(ctx.Next().GetValue(), exe_end);
}
}

TEST(PlanExecutorTest, Limit) {
Expand Down
8 changes: 4 additions & 4 deletions tests/cppunit/redis_query_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ TEST(RedisQueryParserTest, Vector) {
AssertIR(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", vec_str}}),
"field <-> [1.000000, 2.000000, 3.000000] < 10");
AssertIR(Parse("*=>[KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}),
"KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]");
"(and true, KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000])");
AssertIR(Parse("(*) => [KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}),
"KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]");
"(and true, KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000])");
AssertIR(Parse("(@a:[1 2]) => [KNN 8 @vec_embedding $blob]", {{"blob", vec_str}}),
"KNN k=8, vec_embedding <-> [1.000000, 2.000000, 3.000000]");
"(and (and a >= 1, a <= 2), KNN k=8, vec_embedding <-> [1.000000, 2.000000, 3.000000])");
AssertIR(Parse("* =>[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}}),
"KNN k=5, vector <-> [1.000000, 2.000000, 3.000000]");
"(and true, KNN k=5, vector <-> [1.000000, 2.000000, 3.000000])");

vec_str = vec_str.substr(0, 3);
ASSERT_EQ(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", vec_str}}).Msg(),
Expand Down

0 comments on commit d812e60

Please sign in to comment.