Skip to content

Commit

Permalink
Remove frontend transform
Browse files (browse the repository at this point in the history)
  • Loading branch information
Beihao-Zhou committed Jul 31, 2024
1 parent 0529a98 commit 5d87aa2
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 30 deletions.
7 changes: 6 additions & 1 deletion src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,12 @@ struct SortByClause : Node {
bool IsVectorField() const { return vector != nullptr; }

std::string_view Name() const override { return "SortByClause"; }
std::string Dump() const override { return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); }
// Renders this clause as text. A vector sort is printed as the field and the
// query vector joined by `<->`; any other sort is printed as the field
// followed by its order.
std::string Dump() const override {
  const auto field_text = field->Dump();
  return IsVectorField() ? fmt::format("sortby {} <-> {}", field_text, vector->Dump())
                         : fmt::format("sortby {}, {}", field_text, OrderToString(order));
}
std::string Content() const override { return OrderToString(order); }

NodeIterator ChildBegin() override { return NodeIterator(field.get()); };
Expand Down
30 changes: 16 additions & 14 deletions src/search/ir_sema_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ struct SemaChecker {
GET_OR_RET(Check(v->query_expr.get()));
if (v->limit) GET_OR_RET(Check(v->limit.get()));
if (v->sort_by) GET_OR_RET(Check(v->sort_by.get()));
if (v->sort_by && v->sort_by->IsVectorField() && !v->limit) {
return {Status::NotOK, "invalid knn search clause without limit"};
}
} else {
return {Status::NotOK, fmt::format("index `{}` not found", index_name)};
}
Expand All @@ -60,8 +63,21 @@ struct SemaChecker {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
} else if (!iter->second.IsSortable()) {
return {Status::NotOK, fmt::format("field `{}` is not sortable", v->field->name)};
} else if (auto meta = iter->second.MetadataAs<redis::HnswVectorFieldMetadata>();
(meta != nullptr) != v->IsVectorField()) {
std::string not_str = meta ? "" : "not ";
return {Status::NotOK,
fmt::format("field `{}` is {}a vector field according to metadata and does {}expect a vector parameter",
v->field->name, not_str, not_str)};
} else {
v->field->info = &iter->second;
if (v->IsVectorField()) {
auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)};
}
}
}
} else if (auto v = dynamic_cast<AndExpr *>(node)) {
for (const auto &n : v->inners) {
Expand Down Expand Up @@ -97,20 +113,6 @@ struct SemaChecker {
} else {
v->field->info = &iter->second;
}
} else if (auto v = dynamic_cast<VectorKnnExpr *>(node)) {
if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
} else if (!iter->second.MetadataAs<redis::HnswVectorFieldMetadata>()) {
return {Status::NotOK, fmt::format("field `{}` is not a vector field", v->field->name)};
} else {
v->field->info = &iter->second;

auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)};
}
}
} else if (auto v = dynamic_cast<VectorRangeExpr *>(node)) {
if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
Expand Down
2 changes: 2 additions & 0 deletions src/search/search_encoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ struct HnswVectorFieldMetadata : IndexFieldMetadata {

HnswVectorFieldMetadata() : IndexFieldMetadata(IndexFieldType::VECTOR) {}

// Always reports this field as sortable.
bool IsSortable() const override {
  return true;
}

void Encode(std::string *dst) const override {
IndexFieldMetadata::Encode(dst);
PutFixed8(dst, uint8_t(vector_type));
Expand Down
12 changes: 0 additions & 12 deletions src/search/sql_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,6 @@ struct Transformer : ir::TreeTransformer {
query_expr = std::make_unique<BoolLiteral>(true);
}

if (sort_by && sort_by->IsVectorField()) {
if (!limit) {
return {Status::NotOK, "invalid knn search clause without limit"};
}
CHECK(limit->Offset() == 0);
query_expr = std::make_unique<VectorKnnExpr>(std::move(sort_by->TakeFieldRef()),
std::make_unique<NumericLiteral>(limit->Count()),
std::move(sort_by->TakeVectorLiteral()));
sort_by.reset();
limit.reset();
}

return Node::Create<ir::SearchExpr>(std::move(index), std::move(query_expr), std::move(limit), std::move(sort_by),
std::move(select));
} else if (IsRoot(node)) {
Expand Down
7 changes: 6 additions & 1 deletion tests/cppunit/ir_sema_checker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,12 @@ TEST(SemaCheckerTest, Simple) {
"vector should be of size `3` for field `f4`");
ASSERT_EQ(checker.Check(Parse("select f4 from ia where f4 <-> [3.6,4.7,5.6] < -5")->get()).Msg(),
"range cannot be a negative number for l2 distance metric");
ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f4 limit 5")->get()).Msg(), "field `f4` is not sortable");
ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f4 limit 5")->get()).Msg(),
"field `f4` is a vector field according to metadata and does expect a vector parameter");
ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f1 <-> [3.6,4.7,5.6] limit 5")->get()).Msg(),
"field `f1` is not sortable");
ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f2 <-> [3.6,4.7,5.6] limit 5")->get()).Msg(),
"field `f2` is not a vector field according to metadata and does not expect a vector parameter");
}

{
Expand Down
3 changes: 1 addition & 2 deletions tests/cppunit/sql_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,11 @@ TEST(SQLParserTest, Vector) {
AssertSyntaxError(Parse("select a from b where embedding <-> [] < 5"));
AssertSyntaxError(Parse("select a from b order by embedding <-> @vec limit 5", {{"vec", "[3.6,7.8]"}}));
AssertSyntaxError(Parse("select a from b where embedding <#> [3,1,2] < 5"));
ASSERT_EQ(Parse("select a from b order by embedding <-> [1,2,3]").Msg(), "invalid knn search clause without limit");

AssertIR(Parse("select a from b where embedding <-> [3,1,2] < 5"),
"select a from b where embedding <-> [3.000000, 1.000000, 2.000000] < 5");
AssertIR(Parse("select a from b where embedding <-> [0.5,0.5] < 10 and c > 100"),
"select a from b where (and embedding <-> [0.500000, 0.500000] < 10, c > 100)");
AssertIR(Parse("select a from b order by embedding <-> [3.6] limit 5"),
"select a from b where embedding <-> [3.600000] knn 5");
"select a from b where true sortby embedding <-> [3.600000] limit 0, 5");
}

0 comments on commit 5d87aa2

Please sign in to comment.