Skip to content

Commit

Permalink
fix lateral join group indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxxen committed Oct 21, 2024
1 parent 1f52e72 commit b38c209
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/hnsw/hnsw_optimize_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ OperatorResultType PhysicalHNSWIndexJoin::Execute(ExecutionContext &context, Dat
for (idx_t batch_idx = 0; batch_idx < batch_count; batch_idx++, state.input_idx++) {

// Get the next batch
const auto rhs_vector_data = rhs_vector_ptr + batch_idx * rhs_vector_size;
const auto rhs_vector_data = rhs_vector_ptr + state.input_idx * rhs_vector_size;

// Scan the index for row ids
const auto match_count = hnsw_index.ExecuteMultiScan(*state.index_state, rhs_vector_data, limit);
for (idx_t i = 0; i < match_count; i++) {
state.match_sel.set_index(output_idx, batch_idx);
state.match_sel.set_index(output_idx, state.input_idx);
row_number_vector[output_idx] = i + 1; // Note: 1-indexed!
output_idx++;
}
Expand Down
114 changes: 114 additions & 0 deletions test/sql/hnsw/hnsw_lateral_join_group_large.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
require vss

statement ok
SELECT setseed(0.1337);

statement ok
CREATE TABLE queries (id INT, embedding FLOAT[3]);

statement ok
INSERT INTO queries SELECT i, [random(), random(), random()]::FLOAT[3] FROM range(1, 1000) as r(i);

statement ok
CREATE TABLE items (id INT, embedding FLOAT[3]);

statement ok
INSERT INTO items SELECT i, [random(), random(), random()]::FLOAT[3] FROM range(1, 1000) as r(i);


# Sanity check, total cardinality
query I
SELECT COUNT(*)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
);
----
2997

query I rowsort result_total
SELECT COUNT(*)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
);
----

# Sanity check, groups of 3
query I rowsort result_count
SELECT count(*) FROM (
SELECT queries.id as id, any_value(nbr)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
) GROUP BY queries.id
)
----


query II rowsort result_scan
SELECT queries.id as id, list(nbr ORDER BY nbr) result_scan
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
) GROUP BY queries.id
----


# Now create an index
statement ok
CREATE INDEX items_embedding_idx ON items USING hnsw(embedding);

query I rowsort result_total
SELECT COUNT(*)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
);
----

query I rowsort result_count
SELECT count(*) FROM (
SELECT queries.id as id, any_value(nbr)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
) GROUP BY queries.id
)
----

query II rowsort result_scan
SELECT queries.id as id, list(nbr ORDER BY nbr)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
) GROUP BY queries.id
----

0 comments on commit b38c209

Please sign in to comment.