Skip to content

Commit

Permalink
Merge pull request #27 from Maxxen/dev
Browse files Browse the repository at this point in the history
Fix non-euclidean distance metrics + more optimizer rules
  • Loading branch information
Maxxen authored Sep 5, 2024
2 parents 3e192f2 + 58aa5df commit 77739ea
Show file tree
Hide file tree
Showing 15 changed files with 544 additions and 105 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ CREATE INDEX my_hnsw_cosine_index ON my_vector_table USING HNSW (vec) WITH (metr

The following table shows the supported distance metrics and their corresponding DuckDB functions

| Description | Metric | Function |
| --- | --- | --- |
| Euclidean distance | `l2sq` | `array_distance` |
| Cosine similarity | `cosine` | `array_cosine_similarity` |
| Inner product | `ip` | `array_inner_product` |
| Description | Metric | Function |
| --- | --- |--------------------------------|
| Euclidean distance | `l2sq` | `array_distance` |
| Cosine similarity | `cosine` | `array_cosine_distance` |
| Inner product | `ip` | `array_negative_inner_product` |

## Inserts, Updates, Deletes and Re-Compaction

Expand Down
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 475 files
2 changes: 1 addition & 1 deletion extension-ci-tools
5 changes: 4 additions & 1 deletion src/hnsw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ set(EXTENSION_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_pragmas.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_plan_index_create.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_plan_index_scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_topk_operator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_optimize_topk.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_optimize_expr.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_optimize_scan.cpp
PARENT_SCOPE
)
51 changes: 41 additions & 10 deletions src/hnsw/hnsw_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "duckdb/common/serializer/binary_serializer.hpp"
#include "duckdb/execution/index/fixed_size_allocator.hpp"
#include "duckdb/storage/table/scan_state.hpp"
#include "duckdb/planner/operator/logical_get.hpp"
#include "hnsw/hnsw.hpp"

namespace duckdb {
Expand Down Expand Up @@ -227,23 +228,19 @@ string HNSWIndex::GetMetric() const {
}
}

bool HNSWIndex::IsDistanceFunction(const string &distance_function_name) {
auto accepted_functions = {"array_distance", "array_cosine_similarity", "array_inner_product"};
return std::find(accepted_functions.begin(), accepted_functions.end(), distance_function_name) !=
accepted_functions.end();
}

bool HNSWIndex::MatchesDistanceFunction(const string &distance_function_name) const {
if (distance_function_name == "array_distance" &&
bool HNSWIndex::MatchesDistanceFunction(const string &name) const {
if ((name == "array_distance" || name == "<->") &&
index.metric().metric_kind() == unum::usearch::metric_kind_t::l2sq_k) {
// Note: usearch uses l2sq, for their metric, but its functionally equivalent to sqrt(l2sq)
return true;
}
if (distance_function_name == "array_cosine_similarity" &&
if ((name == "array_cosine_distance" || name == "<=>") &&
index.metric().metric_kind() == unum::usearch::metric_kind_t::cos_k) {
return true;
}
if (distance_function_name == "array_inner_product" &&
if ((name == "array_negative_inner_product" || name == "<#>") &&
index.metric().metric_kind() == unum::usearch::metric_kind_t::ip_k) {
// Note: usearch uses (1.0 - ip) for their metric, but its functionally equivalent to (-ip)
return true;
}
return false;
Expand Down Expand Up @@ -536,6 +533,40 @@ void HNSWIndex::VerifyAllocations(IndexLock &state) {
throw NotImplementedException("HNSWIndex::VerifyAllocations() not implemented");
}

//------------------------------------------------------------------------------
// Can rewrite index expression?
//------------------------------------------------------------------------------
static void RewriteIndexExpression(const Index &index, LogicalGet &get, Expression &expr, bool &rewrite_possible,
bool &any_column_ref) {
if (expr.type == ExpressionType::BOUND_COLUMN_REF) {
any_column_ref = true;
auto &bound_colref = expr.Cast<BoundColumnRefExpression>();
// bound column ref: rewrite to fit in the current set of bound column ids
bound_colref.binding.table_index = get.table_index;
auto &column_ids = index.GetColumnIds();
auto &get_column_ids = get.GetColumnIds();
column_t referenced_column = column_ids[bound_colref.binding.column_index];
// search for the referenced column in the set of column_ids
for (idx_t i = 0; i < get_column_ids.size(); i++) {
if (get_column_ids[i] == referenced_column) {
bound_colref.binding.column_index = i;
return;
}
}
// column id not found in bound columns in the LogicalGet: rewrite not possible
rewrite_possible = false;
}
ExpressionIterator::EnumerateChildren(
expr, [&](Expression &child) { RewriteIndexExpression(index, get, child, rewrite_possible, any_column_ref); });
}

bool HNSWIndex::CanRewriteIndexExpression(LogicalGet &get, Expression &column_ref) const {
bool rewrite_possible = true;
bool any_column_ref = false;
RewriteIndexExpression(*this, get, column_ref, rewrite_possible, any_column_ref);
return any_column_ref && rewrite_possible;
}

//------------------------------------------------------------------------------
// Register Index Type
//------------------------------------------------------------------------------
Expand Down
25 changes: 11 additions & 14 deletions src/hnsw/hnsw_index_physical_create.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

namespace duckdb {

PhysicalCreateHNSWIndex::PhysicalCreateHNSWIndex(LogicalOperator &op, TableCatalogEntry &table,
PhysicalCreateHNSWIndex::PhysicalCreateHNSWIndex(LogicalOperator &op, TableCatalogEntry &table_p,
const vector<column_t> &column_ids, unique_ptr<CreateIndexInfo> info,
vector<unique_ptr<Expression>> unbound_expressions,
idx_t estimated_cardinality)
// Declare this operators as a EXTENSION operator
: PhysicalOperator(PhysicalOperatorType::EXTENSION, op.types, estimated_cardinality),
table(table.Cast<DuckTableEntry>()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)),
table(table_p.Cast<DuckTableEntry>()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)),
sorted(false) {

// convert virtual column ids to storage column ids
Expand All @@ -34,7 +34,8 @@ PhysicalCreateHNSWIndex::PhysicalCreateHNSWIndex(LogicalOperator &op, TableCatal
//-------------------------------------------------------------
class CreateHNSWIndexGlobalState final : public GlobalSinkState {
public:
CreateHNSWIndexGlobalState(const PhysicalOperator &op_p) : op(op_p) {}
CreateHNSWIndexGlobalState(const PhysicalOperator &op_p) : op(op_p) {
}

const PhysicalOperator &op;
//! Global index to be added to the table
Expand Down Expand Up @@ -262,21 +263,17 @@ class HNSWIndexConstructionEvent final : public BasePipelineEvent {
// Create the index entry in the catalog
auto &schema = table.schema;
info.column_ids = storage_ids;
const auto index_entry = schema.CreateIndex(*gstate.context, info, table).get();
if (!index_entry) {
D_ASSERT(info.on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT);
// index already exists, but error ignored because of IF NOT EXISTS
// return SinkFinalizeType::READY;
return;

if (schema.GetEntry(schema.GetCatalogTransaction(*gstate.context), CatalogType::INDEX_ENTRY, info.index_name)) {
if (info.on_conflict != OnCreateConflict::IGNORE_ON_CONFLICT) {
throw CatalogException("Index with name \"%s\" already exists", info.index_name);
}
}

// Get the entry as a DuckIndexEntry
const auto index_entry = schema.CreateIndex(schema.GetCatalogTransaction(*gstate.context), info, table).get();
D_ASSERT(index_entry);
auto &duck_index = index_entry->Cast<DuckIndexEntry>();
duck_index.initial_index_size = gstate.global_index->Cast<BoundIndex>().GetInMemorySize();
duck_index.info = make_uniq<IndexDataTableInfo>(storage.GetDataTableInfo(), duck_index.name);
for (auto &parsed_expr : info.parsed_expressions) {
duck_index.parsed_expressions.push_back(parsed_expr->Copy());
}

// Finally add it to storage
storage.AddIndex(std::move(gstate.global_index));
Expand Down
98 changes: 98 additions & 0 deletions src/hnsw/hnsw_optimize_expr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckdb/planner/expression_iterator.hpp"
#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
#include "duckdb/optimizer/column_binding_replacer.hpp"
#include "duckdb/optimizer/optimizer.hpp"

#include "hnsw/hnsw.hpp"
#include "hnsw/hnsw_index.hpp"

namespace duckdb {

//------------------------------------------------------------------------------
// Rewrite rules
//------------------------------------------------------------------------------
// This optimizer rewrites expressions of the form:
// (1.0 - array_cosine_similarity) => (array_cosine_distance)
// (-array_inner_product) => (array_negative_inner_product)

class CosineDistanceRule final : public Rule {
public:
explicit CosineDistanceRule(ExpressionRewriter &rewriter);
unique_ptr<Expression> Apply(LogicalOperator &op, vector<reference<Expression>> &bindings, bool &changes_made,
bool is_root) override;
};

CosineDistanceRule::CosineDistanceRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
auto func = make_uniq<FunctionExpressionMatcher>();
func->matchers.push_back(make_uniq<ExpressionMatcher>());
func->matchers.push_back(make_uniq<ExpressionMatcher>());
func->policy = SetMatcher::Policy::UNORDERED;
func->function = make_uniq<SpecificFunctionMatcher>("array_cosine_similarity");

auto op = make_uniq<FunctionExpressionMatcher>();
op->matchers.push_back(make_uniq<ConstantExpressionMatcher>());
op->matchers[0]->type = make_uniq<SpecificTypeMatcher>(LogicalType::FLOAT);
op->matchers.push_back(std::move(func));
op->policy = SetMatcher::Policy::ORDERED;
op->function = make_uniq<SpecificFunctionMatcher>("-");
op->type = make_uniq<SpecificTypeMatcher>(LogicalType::FLOAT);

root = std::move(op);
}

unique_ptr<Expression> CosineDistanceRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings,
bool &changes_made, bool is_root) {
// auto &root_expr = bindings[0].get().Cast<BoundFunctionExpression>();
const auto &const_expr = bindings[1].get().Cast<BoundConstantExpression>();
auto &similarity_expr = bindings[2].get().Cast<BoundFunctionExpression>();

if (!const_expr.value.IsNull() && const_expr.value.GetValue<float>() == 1.0) {
// Create the new array_cosine_distance function
vector<unique_ptr<Expression>> args;
vector<LogicalType> arg_types;
arg_types.push_back(similarity_expr.children[0]->return_type);
arg_types.push_back(similarity_expr.children[1]->return_type);
args.push_back(std::move(similarity_expr.children[0]));
args.push_back(std::move(similarity_expr.children[1]));

auto &context = GetContext();
auto func_entry = Catalog::GetEntry<ScalarFunctionCatalogEntry>(context, "", "", "array_cosine_distance",
OnEntryNotFound::RETURN_NULL);

if (!func_entry) {
return nullptr;
}

changes_made = true;
auto func = func_entry->functions.GetFunctionByArguments(context, arg_types);
return make_uniq<BoundFunctionExpression>(similarity_expr.return_type, func, std::move(args), nullptr);
}
return nullptr;
}

//------------------------------------------------------------------------------
// Optimizer
//------------------------------------------------------------------------------
class HNSWExprOptimizer : public OptimizerExtension {
public:
HNSWExprOptimizer() {
optimize_function = Optimize;
}

static void Optimize(OptimizerExtensionInput &input, unique_ptr<LogicalOperator> &plan) {
ExpressionRewriter rewriter(input.context);
rewriter.rules.push_back(make_uniq<CosineDistanceRule>(rewriter));
rewriter.VisitOperator(*plan);
}
};

void HNSWModule::RegisterExprOptimizer(DatabaseInstance &db) {
// Register the TopKOptimizer
db.config.optimizer_extensions.push_back(HNSWExprOptimizer());
}

} // namespace duckdb
Loading

0 comments on commit 77739ea

Please sign in to comment.