Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lateral join optimizer for larger than vector size inputs, allow any table function for outer get #31

Merged
merged 2 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/hnsw/hnsw_optimize_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ OperatorResultType PhysicalHNSWIndexJoin::Execute(ExecutionContext &context, Dat
for (idx_t batch_idx = 0; batch_idx < batch_count; batch_idx++, state.input_idx++) {

// Get the next batch
const auto rhs_vector_data = rhs_vector_ptr + batch_idx * rhs_vector_size;
const auto rhs_vector_data = rhs_vector_ptr + state.input_idx * rhs_vector_size;

// Scan the index for row ids
const auto match_count = hnsw_index.ExecuteMultiScan(*state.index_state, rhs_vector_data, limit);
for (idx_t i = 0; i < match_count; i++) {
state.match_sel.set_index(output_idx, batch_idx);
state.match_sel.set_index(output_idx, state.input_idx);
row_number_vector[output_idx] = i + 1; // Note: 1-indexed!
output_idx++;
}
Expand Down Expand Up @@ -337,9 +337,6 @@ bool HNSWIndexJoinOptimizer::TryOptimize(Binder &binder, ClientContext &context,
MATCH_OPERATOR(delim_join.children[1], LOGICAL_GET, 0);
auto outer_get_ptr = &delim_join.children[1];
auto &outer_get = (*outer_get_ptr)->Cast<LogicalGet>();
if (outer_get.function.name != "seq_scan") {
return false;
}

// branch
// There might not be a projection here if we keep the distance function.
Expand Down
114 changes: 114 additions & 0 deletions test/sql/hnsw/hnsw_lateral_join_group_large.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# This test needs the vss extension (it provides the hnsw index type used further down).
require vss

# Fix the seed so the random() embeddings generated below are reproducible.
statement ok
SELECT setseed(0.1337);

# Outer side of the lateral join: one 3-dimensional query vector per id.
statement ok
CREATE TABLE queries (id INT, embedding FLOAT[3]);

# range(1, 1000) yields ids 1..999, i.e. 999 query rows.
statement ok
INSERT INTO queries SELECT i, [random(), random(), random()]::FLOAT[3] FROM range(1, 1000) as r(i);

# Inner side of the lateral join: the vectors that are searched for neighbours.
statement ok
CREATE TABLE items (id INT, embedding FLOAT[3]);

# Likewise 999 item rows with ids 1..999.
statement ok
INSERT INTO items SELECT i, [random(), random(), random()]::FLOAT[3] FROM range(1, 1000) as r(i);


# Sanity check, total cardinality
# Each of the 999 query rows is joined with its 3 closest items (ORDER BY dist LIMIT 3),
# so the expected row count is 999 * 3 = 2997. No index exists on items yet.
query I
SELECT COUNT(*)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
);
----
2997

# Same count again, this time recorded under the label result_total with no expected
# rows listed. The label is reused by the identical query after the index is created,
# presumably so the test runner can check that both runs produce the same result.
query I rowsort result_total
SELECT COUNT(*)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
);
----

# Sanity check, groups of 3
# The inner query groups the joined rows by queries.id; the outer count(*) therefore
# returns the number of groups (one per query id that produced at least one match),
# not the size of each group. Recorded under the label result_count.
query I rowsort result_count
SELECT count(*) FROM (
SELECT queries.id as id, any_value(nbr)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
) GROUP BY queries.id
)
----


# For every query id, collect the ids of its 3 closest items as a list sorted by item id.
# This is the baseline computed without an index, recorded under the label result_scan.
# NOTE(review): the list column carries the alias result_scan here but has no alias in
# the post-index variant of this query; this looks cosmetic - confirm the runner compares
# values only.
query II rowsort result_scan
SELECT queries.id as id, list(nbr ORDER BY nbr) result_scan
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
) GROUP BY queries.id
----


# Now create an index
# The three labelled queries above (result_total, result_count, result_scan) are repeated
# below once the hnsw index on items.embedding exists.
statement ok
CREATE INDEX items_embedding_idx ON items USING hnsw(embedding);

# Total row count again, now that items has an hnsw index. Shares the label result_total
# with the pre-index run of the same query.
query I rowsort result_total
SELECT COUNT(*)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
);
----

# Number of per-query groups again, now that items has an hnsw index. Shares the label
# result_count with the pre-index run of the same query.
query I rowsort result_count
SELECT count(*) FROM (
SELECT queries.id as id, any_value(nbr)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
) GROUP BY queries.id
)
----

# Per-query neighbour lists again, now that items has an hnsw index. Shares the label
# result_scan with the pre-index run, so the neighbour ids found for each query id are
# expected to be the same as in that run.
query II rowsort result_scan
SELECT queries.id as id, list(nbr ORDER BY nbr)
FROM queries, LATERAL (
SELECT
items.id as nbr,
array_distance(items.embedding, queries.embedding) as dist
FROM items
ORDER BY dist
LIMIT 3
) GROUP BY queries.id
----
Loading