Skip to content

Commit

Permalink
Add optimization pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Beihao-Zhou committed Aug 4, 2024
1 parent 0f5f18e commit b7fcd21
Show file tree
Hide file tree
Showing 13 changed files with 196 additions and 20 deletions.
19 changes: 9 additions & 10 deletions src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,24 +265,19 @@ struct VectorRangeExpr : BoolAtomExpr {
};

struct VectorKnnExpr : BoolAtomExpr {
// TODO: Support pre-filter for hybrid query
std::unique_ptr<FieldRef> field;
std::unique_ptr<NumericLiteral> k;
std::unique_ptr<VectorLiteral> vector;
size_t k;

VectorKnnExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&k,
std::unique_ptr<VectorLiteral> &&vector)
: field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {}
VectorKnnExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<VectorLiteral> &&vector, size_t k)
: field(std::move(field)), vector(std::move(vector)), k(k) {}

std::string_view Name() const override { return "VectorKnnExpr"; }
std::string Dump() const override {
return fmt::format("KNN k={}, {} <-> {}", k->Dump(), field->Dump(), vector->Dump());
}
std::string Dump() const override { return fmt::format("KNN k={}, {} <-> {}", k, field->Dump(), vector->Dump()); }

std::unique_ptr<Node> Clone() const override {
return std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(field->Clone()),
Node::MustAs<NumericLiteral>(k->Clone()),
Node::MustAs<VectorLiteral>(vector->Clone()));
Node::MustAs<VectorLiteral>(vector->Clone()), k);
}
};

Expand Down Expand Up @@ -425,6 +420,10 @@ struct SortByClause : Node {
// Builds a new SortByClause from `order` and a clone of `field`, returned as a generic Node.
// NOTE(review): the `vector` member (the one handed out by TakeVectorLiteral) is not forwarded to the
// new clause here — confirm whether dropping it on clone is intended.
std::unique_ptr<Node> Clone() const override {
  auto cloned_field = Node::MustAs<FieldRef>(field->Clone());
  return std::make_unique<SortByClause>(order, std::move(cloned_field));
}

// Hands ownership of `field` to the caller; this clause's `field` is left empty afterwards.
std::unique_ptr<FieldRef> TakeFieldRef() {
  auto taken = std::move(field);
  return taken;
}

// Hands ownership of `vector` to the caller; this clause's `vector` is left empty afterwards.
std::unique_ptr<VectorLiteral> TakeVectorLiteral() {
  auto taken = std::move(vector);
  return taken;
}
};

struct SelectClause : Node {
Expand Down
29 changes: 29 additions & 0 deletions src/search/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ struct Visitor : Pass {
return Visit(std::move(v));
} else if (auto v = Node::As<TagContainExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<VectorLiteral>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<VectorKnnExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<VectorRangeExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<StringLiteral>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<BoolLiteral>(std::move(node))) {
Expand All @@ -69,6 +75,10 @@ struct Visitor : Pass {
return Visit(std::move(v));
} else if (auto v = Node::As<TagFieldScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<HnswVectorFieldRangeScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<HnswVectorFieldKnnScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Filter>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Limit>(std::move(node))) {
Expand Down Expand Up @@ -125,6 +135,8 @@ struct Visitor : Pass {

// Default handler for NumericLiteral: returns the node unchanged. Virtual, so a pass may override it.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericLiteral> node) { return node; }

// Default handler for VectorLiteral: returns the node unchanged. Virtual, so a pass may override it.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorLiteral> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericCompareExpr> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->num = VisitAs<NumericLiteral>(std::move(node->num));
Expand All @@ -137,6 +149,19 @@ struct Visitor : Pass {
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorKnnExpr> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->vector = VisitAs<VectorLiteral>(std::move(node->vector));
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorRangeExpr> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->range = VisitAs<NumericLiteral>(std::move(node->range));
node->vector = VisitAs<VectorLiteral>(std::move(node->vector));
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) {
for (auto &n : node->inners) {
n = TransformAs<QueryExpr>(std::move(n));
Expand Down Expand Up @@ -173,6 +198,10 @@ struct Visitor : Pass {

// Default handler for TagFieldScan: returns the plan operator unchanged. Virtual, so a pass may override it.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<TagFieldScan> node) { return node; }

// Default handler for HnswVectorFieldRangeScan: returns the plan operator unchanged. Virtual, so a pass may
// override it.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<HnswVectorFieldRangeScan> node) { return node; }

// Default handler for HnswVectorFieldKnnScan: returns the plan operator unchanged. Virtual, so a pass may
// override it.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<HnswVectorFieldKnnScan> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<Filter> node) {
node->source = TransformAs<PlanOperator>(std::move(node->source));
node->filter_expr = TransformAs<QueryExpr>(std::move(node->filter_expr));
Expand Down
2 changes: 1 addition & 1 deletion src/search/ir_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct TagFieldScan : FieldScan {

struct HnswVectorFieldKnnScan : FieldScan {
kqir::NumericArray vector;
uint16_t k;
uint32_t k;

// Constructs a KNN scan over `field` with the query `vector` and the KNN parameter `k`.
//
// `k` is accepted as uint32_t so that it matches the width of the `k` member it initialises. A narrower
// parameter type would silently truncate larger values handed in by callers (for example a size_t `k`
// taken from a VectorKnnExpr) before they ever reach the member.
HnswVectorFieldKnnScan(std::unique_ptr<FieldRef> field, kqir::NumericArray vector, uint32_t k)
    : FieldScan(std::move(field)), vector(std::move(vector)), k(k) {}
Expand Down
3 changes: 0 additions & 3 deletions src/search/ir_sema_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@ struct SemaChecker {
return {Status::NotOK,
fmt::format("field `{}` is marked as NOINDEX and cannot be used for KNN search", v->field->name)};
}
if (v->k->val <= 0) {
return {Status::NotOK, fmt::format("KNN search parameter `k` must be greater than 0")};
}
auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
Expand Down
10 changes: 10 additions & 0 deletions src/search/passes/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ struct CostModel {
if (auto v = dynamic_cast<const FullIndexScan *>(node)) {
return Visit(v);
}
if (auto v = dynamic_cast<const HnswVectorFieldKnnScan *>(node)) {
return Visit(v);
}
if (auto v = dynamic_cast<const HnswVectorFieldRangeScan *>(node)) {
return Visit(v);
}
if (auto v = dynamic_cast<const NumericFieldScan *>(node)) {
return Visit(v);
}
Expand Down Expand Up @@ -74,6 +80,10 @@ struct CostModel {

// Cost of a TagFieldScan: the fixed constant 10, independent of the node's contents.
static size_t Visit(const TagFieldScan *node) { return 10; }

// Cost of an HnswVectorFieldKnnScan: the fixed constant 3, independent of the node's contents.
// NOTE(review): presumably chosen so that a KNN scan ranks below the tag scan (10) — confirm how
// costs are compared by the caller.
static size_t Visit(const HnswVectorFieldKnnScan *node) { return 3; }

// Cost of an HnswVectorFieldRangeScan: the fixed constant 4, independent of the node's contents.
static size_t Visit(const HnswVectorFieldRangeScan *node) { return 4; }

// Cost of a Filter: the cost of its source operator (computed via Transform) plus one.
static size_t Visit(const Filter *node) {
  const auto source_cost = Transform(node->source.get());
  return source_cost + 1;
}

static size_t Visit(const Merge *node) {
Expand Down
23 changes: 23 additions & 0 deletions src/search/passes/index_selection.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ struct IndexSelection : Visitor {
if (auto v = dynamic_cast<OrExpr *>(node)) {
return VisitExpr(v);
}
if (auto v = dynamic_cast<VectorKnnExpr *>(node)) {
return VisitExpr(v);
}
if (auto v = dynamic_cast<VectorRangeExpr *>(node)) {
return VisitExpr(v);
}
if (auto v = dynamic_cast<NumericCompareExpr *>(node)) {
return VisitExpr(v);
}
Expand Down Expand Up @@ -153,6 +159,23 @@ struct IndexSelection : Visitor {
return MakeFullIndexFilter(node);
}

// Chooses the plan operator for a vector range predicate.
// When the referenced field reports HasIndex(), an HnswVectorFieldRangeScan is built from a clone of the
// field reference, the query vector's values and the range value; otherwise the predicate falls back
// to MakeFullIndexFilter.
std::unique_ptr<PlanOperator> VisitExpr(VectorRangeExpr *node) const {
  if (!node->field->info->HasIndex()) {
    return MakeFullIndexFilter(node);
  }

  return std::make_unique<HnswVectorFieldRangeScan>(node->field->CloneAs<FieldRef>(), node->vector->values,
                                                    node->range->val);
}

// Chooses the plan operator for a KNN predicate.
// When the referenced field reports HasIndex(), an HnswVectorFieldKnnScan is built from a clone of the
// field reference, the query vector's values and `k` (forwarded as-is); otherwise the predicate falls
// back to MakeFullIndexFilter.
std::unique_ptr<PlanOperator> VisitExpr(VectorKnnExpr *node) const {
  if (!node->field->info->HasIndex()) {
    return MakeFullIndexFilter(node);
  }

  return std::make_unique<HnswVectorFieldKnnScan>(node->field->CloneAs<FieldRef>(), node->vector->values, node->k);
}

template <typename Expr>
std::unique_ptr<PlanOperator> VisitExprImpl(Expr *node) {
struct AggregatedNodes {
Expand Down
4 changes: 3 additions & 1 deletion src/search/passes/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "search/passes/simplify_and_or_expr.h"
#include "search/passes/simplify_boolean.h"
#include "search/passes/sort_limit_fuse.h"
#include "search/passes/transfer_to_knn.h"
#include "type_util.h"

namespace kqir {
Expand Down Expand Up @@ -86,7 +87,8 @@ struct PassManager {
}

static PassSequence ExprPasses() {
return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, SimplifyAndOrExpr{});
return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, SimplifyAndOrExpr{},
TransferSortByToKnnExpr{}, SimplifyAndOrExpr{});
}
// Pass sequence built via Create from, in order: IntervalAnalysis{true}, SimplifyAndOrExpr, SimplifyBoolean.
static PassSequence NumericPasses() { return Create(IntervalAnalysis{true}, SimplifyAndOrExpr{}, SimplifyBoolean{}); }
// Pass sequence built via Create from, in order: LowerToPlan, IndexSelection, SortLimitFuse.
static PassSequence PlanPasses() { return Create(LowerToPlan{}, IndexSelection{}, SortLimitFuse{}); }
Expand Down
62 changes: 62 additions & 0 deletions src/search/passes/transfer_to_knn.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#pragma once

#include <memory>

#include "search/ir.h"
#include "search/ir_pass.h"
#include "search/ir_plan.h"

namespace kqir {

struct TransferSortByToKnnExpr : Visitor {
std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) override {
node = Node::MustAs<SearchExpr>(Visitor::Visit(std::move(node)));

if (node->sort_by && node->sort_by->IsVectorField() && node->limit) {
std::vector<std::unique_ptr<QueryExpr>> exprs;
auto knn_expr = std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(node->sort_by->TakeFieldRef()),
Node::MustAs<VectorLiteral>(node->sort_by->TakeVectorLiteral()),
node->limit->Count());
if (auto b = Node::As<BoolLiteral>(std::move(node->query_expr))) {
if (b->val) exprs.push_back(std::move(knn_expr));
} else {
exprs.push_back(std::move(node->query_expr));
exprs.push_back(std::move(knn_expr));
}

if (exprs.empty()) {
node->query_expr = std::make_unique<BoolLiteral>(false);
} else if (exprs.size() == 1) {
node->query_expr = std::move(exprs[0]);
} else {
node->query_expr = std::make_unique<AndExpr>(std::move(exprs));
}
node->sort_by.reset();
node->limit.reset();
}

return node;
}
};

} // namespace kqir
3 changes: 2 additions & 1 deletion src/search/redis_query_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ struct Tag : sor<Identifier, StringL, Param> {};
struct TagList : seq<one<'{'>, WSPad<Tag>, star<seq<one<'|'>, WSPad<Tag>>>, one<'}'>> {};

// Grammar rule: either a Number or a Param.
struct NumberOrParam : sor<Number, Param> {};
// Grammar rule: either an UnsignedInteger or a Param.
struct UintOrParam : sor<UnsignedInteger, Param> {};

// Grammar rule: an optional '+' or '-' sign followed by the literal text "inf".
struct Inf : seq<opt<one<'+', '-'>>, string<'i', 'n', 'f'>> {};
// Grammar rule: '(' followed by a number or parameter.
// NOTE(review): presumably the '(' prefix marks an exclusive range bound, as the name suggests — confirm.
struct ExclusiveNumber : seq<one<'('>, NumberOrParam> {};
// Grammar rule: one endpoint of a numeric range — Inf, ExclusiveNumber or NumberOrParam, tried in that order.
struct NumericRangePart : sor<Inf, ExclusiveNumber, NumberOrParam> {};
// Grammar rule: '[' followed by two WSPad-wrapped range endpoints and a closing ']'.
struct NumericRange : seq<one<'['>, WSPad<NumericRangePart>, WSPad<NumericRangePart>, one<']'>> {};

struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<NumberOrParam>, WSPad<Field>, WSPad<Param>, one<']'>> {};
struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<UintOrParam>, WSPad<Field>, WSPad<Param>, one<']'>> {};
// Grammar rule: '[' VectorRangeToken, a number-or-parameter, a Param, then ']' (each inner part WSPad-wrapped).
struct VectorRange : seq<one<'['>, WSPad<VectorRangeToken>, WSPad<NumberOrParam>, WSPad<Param>, one<']'>> {};

// Grammar rule: a Field, a ':' and then one of VectorRange, TagList or NumericRange, tried in that order.
struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<VectorRange, TagList, NumericRange>>> {};
Expand Down
12 changes: 9 additions & 3 deletions src/search/redis_query_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace ir = kqir;

template <typename Rule>
using TreeSelector = parse_tree::selector<
Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, Inf>,
Rule, parse_tree::store_content::on<Number, UnsignedInteger, StringL, Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, VectorRange, ExclusiveNumber, FieldQuery, NotExpr, AndExpr,
OrExpr, PrefilterExpr, KnnSearch, Wildcard, VectorRangeToken, KnnToken, ArrowOp>>;

Expand Down Expand Up @@ -168,9 +168,15 @@ struct Transformer : ir::TreeTransformer {
const auto& knn_search = node->children[2];
CHECK(knn_search->children.size() == 4);

size_t k = 0;
if (Is<UnsignedInteger>(knn_search->children[1])) {
k = *ParseInt(knn_search->children[1]->string());
} else {
k = *ParseInt(GET_OR_RET(GetParam(node)));
}

return std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
GET_OR_RET(number_or_param(knn_search->children[1])),
GET_OR_RET(Transform2Vector(knn_search->children[3])));
GET_OR_RET(Transform2Vector(knn_search->children[3])), k);

} else if (Is<AndExpr>(node)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
Expand Down
1 change: 0 additions & 1 deletion src/search/sql_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ struct Transformer : ir::TreeTransformer {
return {Status::NotOK, "the left and right side of numeric comparison should be an identifier and a number"};
}
} else if (Is<VectorRangeExpr>(node)) {
// TODO(Beihao): Handle distance metrics for operator
CHECK(node->children.size() == 2);
const auto& vector_comp_expr = node->children[0];
CHECK(vector_comp_expr->children.size() == 3);
Expand Down
47 changes: 47 additions & 0 deletions tests/cppunit/ir_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ TEST(IRPassTest, Manager) {
"select * from a where (and x <= 1, y >= 2, z != 3)");
}

// Checks that TransferSortByToKnnExpr folds "order by <vector> limit n" into a KNN predicate,
// both without a WHERE clause and combined with an existing one.
TEST(IRPassTest, TransferSortByToKnnExpr) {
  TransferSortByToKnnExpr pass;

  // No WHERE clause: the KNN predicate becomes the whole query expression.
  auto knn_only = pass.Transform(*Parse("select a from b order by embedding <-> [3.6] limit 5"));
  ASSERT_EQ(knn_only->Dump(), "select a from b where KNN k=5, embedding <-> [3.600000]");

  // Existing WHERE clause: the KNN predicate is and-ed after the original condition.
  auto knn_with_filter =
      pass.Transform(*Parse("select a from b where c = 1 order by embedding <-> [3,1,2] limit 5"));
  ASSERT_EQ(knn_with_filter->Dump(),
            "select a from b where (and c = 1, KNN k=5, embedding <-> [3.000000, 1.000000, 2.000000])");
}

TEST(IRPassTest, LowerToPlan) {
LowerToPlan ltp;

Expand All @@ -123,6 +132,8 @@ TEST(IRPassTest, LowerToPlan) {
"project a: (sort d, asc: (filter c = 1: full-scan b))");
ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 limit 1"))->Dump(),
"project a: (limit 0, 1: (filter c = 1: full-scan b))");
ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 and d = 2 order by e limit 1"))->Dump(),
"project a: (limit 0, 1: (sort e, asc: (filter (and c = 1, d = 2): full-scan b)))");
ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d limit 1"))->Dump(),
"project a: (limit 0, 1: (sort d, asc: (filter c = 1: full-scan b)))");
}
Expand Down Expand Up @@ -176,12 +187,28 @@ static IndexMap MakeIndexMap() {
auto f4 = FieldInfo("n2", std::make_unique<redis::NumericFieldMetadata>());
auto f5 = FieldInfo("n3", std::make_unique<redis::NumericFieldMetadata>());
f5.metadata->noindex = true;

auto hnsw_field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
hnsw_field_meta->vector_type = redis::VectorType::FLOAT64;
hnsw_field_meta->dim = 3;
hnsw_field_meta->distance_metric = redis::DistanceMetric::L2;
auto f6 = FieldInfo("v1", std::move(hnsw_field_meta));

hnsw_field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
hnsw_field_meta->vector_type = redis::VectorType::FLOAT64;
hnsw_field_meta->dim = 3;
hnsw_field_meta->distance_metric = redis::DistanceMetric::L2;
auto f7 = FieldInfo("v2", std::move(hnsw_field_meta));
f7.metadata->noindex = true;

auto ia = std::make_unique<IndexInfo>("ia", redis::IndexMetadata(), "");
ia->Add(std::move(f1));
ia->Add(std::move(f2));
ia->Add(std::move(f3));
ia->Add(std::move(f4));
ia->Add(std::move(f5));
ia->Add(std::move(f6));
ia->Add(std::move(f7));

IndexMap res;
res.Insert(std::move(ia));
Expand Down Expand Up @@ -238,6 +265,26 @@ TEST(IRPassTest, IndexSelection) {
"project *: (filter t2 hastag \"a\": tag-scan t1, a)");
ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where t2 hastag \"a\""))->Dump(),
"project *: (filter t2 hastag \"a\": full-scan ia)");
ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where v1 <-> [3,1,2] < 5"))->Dump(),
"project *: hnsw-vector-range-scan v1, [3.000000, 1.000000, 2.000000], 5");
ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by v1 <-> [3,1,2] limit 5"))->Dump(),
"project *: hnsw-vector-knn-scan v1, [3.000000, 1.000000, 2.000000], 5");
ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where v2 <-> [3,1,2] < 5"))->Dump(),
"project *: (filter v2 <-> [3.000000, 1.000000, 2.000000] < 5: full-scan ia)");
ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1 and v1 <-> [3,1,2] < 5"))->Dump(),
"project *: (filter n1 >= 1: hnsw-vector-range-scan v1, [3.000000, 1.000000, 2.000000], 5)");
ASSERT_EQ(
PassManager::Execute(passes, ParseS(sc, "select * from ia where v1 <-> [3,1,2] < 5 and t1 hastag \"a\""))->Dump(),
"project *: (filter t1 hastag \"a\": hnsw-vector-range-scan v1, [3.000000, 1.000000, 2.000000], 5)");
ASSERT_EQ(
PassManager::Execute(passes, ParseS(sc, "select * from ia where t1 hastag \"a\" order by v1 <-> [3,1,2] limit 5"))
->Dump(),
"project *: (filter t1 hastag \"a\": hnsw-vector-knn-scan v1, [3.000000, 1.000000, 2.000000], 5)");
ASSERT_EQ(PassManager::Execute(
passes, ParseS(sc, "select * from ia where v1 <-> [3,1,2] < 2 order by v1 <-> [3,1,2] limit 5"))
->Dump(),
"project *: (filter v1 <-> [3.000000, 1.000000, 2.000000] < 2: hnsw-vector-knn-scan v1, [3.000000, "
"1.000000, 2.000000], 5)");

ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 2 or n1 < 1"))->Dump(),
"project *: (merge numeric-scan n1, [-inf, 1), asc, numeric-scan n1, [2, inf), asc)");
Expand Down
Loading

0 comments on commit b7fcd21

Please sign in to comment.