diff --git a/src/search/ir.h b/src/search/ir.h index e33af8196f4..f82c5df269d 100644 --- a/src/search/ir.h +++ b/src/search/ir.h @@ -411,7 +411,12 @@ struct SortByClause : Node { bool IsVectorField() const { return vector != nullptr; } std::string_view Name() const override { return "SortByClause"; } - std::string Dump() const override { return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); } + std::string Dump() const override { + if (!IsVectorField()) { + return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); + } + return fmt::format("sortby {} <-> {}", field->Dump(), vector->Dump()); + } std::string Content() const override { return OrderToString(order); } NodeIterator ChildBegin() override { return NodeIterator(field.get()); }; diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h index 42d141f8e28..28ab8a4c051 100644 --- a/src/search/ir_sema_checker.h +++ b/src/search/ir_sema_checker.h @@ -50,6 +50,9 @@ struct SemaChecker { GET_OR_RET(Check(v->query_expr.get())); if (v->limit) GET_OR_RET(Check(v->limit.get())); if (v->sort_by) GET_OR_RET(Check(v->sort_by.get())); + if (v->sort_by && v->sort_by->IsVectorField() && !v->limit) { + return {Status::NotOK, "invalid knn search clause without limit"}; + } } else { return {Status::NotOK, fmt::format("index `{}` not found", index_name)}; } @@ -60,8 +63,21 @@ struct SemaChecker { return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)}; } else if (!iter->second.IsSortable()) { return {Status::NotOK, fmt::format("field `{}` is not sortable", v->field->name)}; + } else if (auto meta = iter->second.MetadataAs(); + (meta != nullptr) != v->IsVectorField()) { + std::string not_str = meta ? "" : "not "; + return {Status::NotOK, + fmt::format("field `{}` is {}a vector field according to metadata and does {}expect a vector parameter", + v->field->name, not_str, not_str)}; } else { v->field->info = &iter->second; + if (v->IsVectorField()) { + auto meta = v->field->info->MetadataAs(); + if (v->vector->values.size() != meta->dim) { + return {Status::NotOK, + fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)}; + } + } } } else if (auto v = dynamic_cast(node)) { for (const auto &n : v->inners) { @@ -97,20 +113,6 @@ struct SemaChecker { } else { v->field->info = &iter->second; } - } else if (auto v = dynamic_cast(node)) { - if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) { - return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)}; - } else if (!iter->second.MetadataAs()) { - return {Status::NotOK, fmt::format("field `{}` is not a vector field", v->field->name)}; - } else { - v->field->info = &iter->second; - - auto meta = v->field->info->MetadataAs(); - if (v->vector->values.size() != meta->dim) { - return {Status::NotOK, - fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)}; - } - } } else if (auto v = dynamic_cast(node)) { if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) { return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)}; diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h index 2fbbde8c21e..26b442ca32c 100644 --- a/src/search/search_encoding.h +++ b/src/search/search_encoding.h @@ -373,6 +373,8 @@ struct HnswVectorFieldMetadata : IndexFieldMetadata { HnswVectorFieldMetadata() : IndexFieldMetadata(IndexFieldType::VECTOR) {} + bool IsSortable() const override { return true; } + void Encode(std::string *dst) const override { IndexFieldMetadata::Encode(dst); PutFixed8(dst, uint8_t(vector_type)); diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h index ba9d692c276..01705107776 100644 --- a/src/search/sql_transformer.h +++ b/src/search/sql_transformer.h @@ -220,18 +220,6 @@ struct Transformer : ir::TreeTransformer { query_expr = std::make_unique(true); } - if (sort_by && sort_by->IsVectorField()) { - if (!limit) { - return {Status::NotOK, "invalid knn search clause without limit"}; - } - CHECK(limit->Offset() == 0); - query_expr = std::make_unique(std::move(sort_by->TakeFieldRef()), - std::make_unique(limit->Count()), - std::move(sort_by->TakeVectorLiteral())); - sort_by.reset(); - limit.reset(); - } - return Node::Create(std::move(index), std::move(query_expr), std::move(limit), std::move(sort_by), std::move(select)); } else if (IsRoot(node)) { diff --git a/tests/cppunit/ir_sema_checker_test.cc b/tests/cppunit/ir_sema_checker_test.cc index 4823e6b1029..92d733dae7a 100644 --- a/tests/cppunit/ir_sema_checker_test.cc +++ b/tests/cppunit/ir_sema_checker_test.cc @@ -80,7 +80,12 @@ TEST(SemaCheckerTest, Simple) { "vector should be of size `3` for field `f4`"); ASSERT_EQ(checker.Check(Parse("select f4 from ia where f4 <-> [3.6,4.7,5.6] < -5")->get()).Msg(), "range cannot be a negative number for l2 distance metric"); - ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f4 limit 5")->get()).Msg(), "field `f4` is not sortable"); + ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f4 limit 5")->get()).Msg(), + "field `f4` is a vector field according to metadata and does expect a vector parameter"); + ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f1 <-> [3.6,4.7,5.6] limit 5")->get()).Msg(), + "field `f1` is not sortable"); + ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f2 <-> [3.6,4.7,5.6] limit 5")->get()).Msg(), + "field `f2` is not a vector field according to metadata and does not expect a vector parameter"); } { diff --git a/tests/cppunit/sql_parser_test.cc b/tests/cppunit/sql_parser_test.cc index 1e5f45f30db..102a3b70253 100644 --- a/tests/cppunit/sql_parser_test.cc +++ b/tests/cppunit/sql_parser_test.cc @@ -159,12 +159,11 @@ TEST(SQLParserTest, Vector) { AssertSyntaxError(Parse("select a from b where embedding <-> [] < 5")); AssertSyntaxError(Parse("select a from b order by embedding <-> @vec limit 5", {{"vec", "[3.6,7.8]"}})); AssertSyntaxError(Parse("select a from b where embedding <#> [3,1,2] < 5")); - ASSERT_EQ(Parse("select a from b order by embedding <-> [1,2,3]").Msg(), "invalid knn search clause without limit"); AssertIR(Parse("select a from b where embedding <-> [3,1,2] < 5"), "select a from b where embedding <-> [3.000000, 1.000000, 2.000000] < 5"); AssertIR(Parse("select a from b where embedding <-> [0.5,0.5] < 10 and c > 100"), "select a from b where (and embedding <-> [0.500000, 0.500000] < 10, c > 100)"); AssertIR(Parse("select a from b order by embedding <-> [3.6] limit 5"), - "select a from b where embedding <-> [3.600000] knn 5"); + "select a from b where true sortby embedding <-> [3.600000] limit 0, 5"); }