diff --git a/src/hnsw/hnsw_optimize_join.cpp b/src/hnsw/hnsw_optimize_join.cpp index 1b28ac5..8a6e283 100644 --- a/src/hnsw/hnsw_optimize_join.cpp +++ b/src/hnsw/hnsw_optimize_join.cpp @@ -137,12 +137,12 @@ OperatorResultType PhysicalHNSWIndexJoin::Execute(ExecutionContext &context, Dat for (idx_t batch_idx = 0; batch_idx < batch_count; batch_idx++, state.input_idx++) { // Get the next batch - const auto rhs_vector_data = rhs_vector_ptr + batch_idx * rhs_vector_size; + const auto rhs_vector_data = rhs_vector_ptr + state.input_idx * rhs_vector_size; // Scan the index for row ids const auto match_count = hnsw_index.ExecuteMultiScan(*state.index_state, rhs_vector_data, limit); for (idx_t i = 0; i < match_count; i++) { - state.match_sel.set_index(output_idx, batch_idx); + state.match_sel.set_index(output_idx, state.input_idx); row_number_vector[output_idx] = i + 1; // Note: 1-indexed! output_idx++; } @@ -337,9 +337,6 @@ bool HNSWIndexJoinOptimizer::TryOptimize(Binder &binder, ClientContext &context, MATCH_OPERATOR(delim_join.children[1], LOGICAL_GET, 0); auto outer_get_ptr = &delim_join.children[1]; auto &outer_get = (*outer_get_ptr)->Cast(); - if (outer_get.function.name != "seq_scan") { - return false; - } // branch // There might not be a projection here if we keep the distance function. diff --git a/test/sql/hnsw/hnsw_lateral_join_group_large.test b/test/sql/hnsw/hnsw_lateral_join_group_large.test new file mode 100644 index 0000000..4c49794 --- /dev/null +++ b/test/sql/hnsw/hnsw_lateral_join_group_large.test @@ -0,0 +1,114 @@ +require vss + +statement ok +SELECT setseed(0.1337); + +statement ok +CREATE TABLE queries (id INT, embedding FLOAT[3]); + +statement ok +INSERT INTO queries SELECT i, [random(), random(), random()]::FLOAT[3] FROM range(1, 1000) as r(i); + +statement ok +CREATE TABLE items (id INT, embedding FLOAT[3]); + +statement ok +INSERT INTO items SELECT i, [random(), random(), random()]::FLOAT[3] FROM range(1, 1000) as r(i); + + +# Sanity check, total cardinality +query I +SELECT COUNT(*) +FROM queries, LATERAL ( + SELECT + items.id as nbr, + array_distance(items.embedding, queries.embedding) as dist + FROM items + ORDER BY dist + LIMIT 3 +); +---- +2997 + +query I rowsort result_total +SELECT COUNT(*) +FROM queries, LATERAL ( + SELECT + items.id as nbr, + array_distance(items.embedding, queries.embedding) as dist + FROM items + ORDER BY dist + LIMIT 3 +); +---- + +# Sanity check, groups of 3 +query I rowsort result_count +SELECT count(*) FROM ( + SELECT queries.id as id, any_value(nbr) + FROM queries, LATERAL ( + SELECT + items.id as nbr, + array_distance(items.embedding, queries.embedding) as dist + FROM items + ORDER BY dist + LIMIT 3 + ) GROUP BY queries.id +) +---- + + +query II rowsort result_scan +SELECT queries.id as id, list(nbr ORDER BY nbr) result_scan + FROM queries, LATERAL ( + SELECT + items.id as nbr, + array_distance(items.embedding, queries.embedding) as dist + FROM items + ORDER BY dist + LIMIT 3 + ) GROUP BY queries.id +---- + + +# Now create an index +statement ok +CREATE INDEX items_embedding_idx ON items USING hnsw(embedding); + +query I rowsort result_total +SELECT COUNT(*) +FROM queries, LATERAL ( + SELECT + items.id as nbr, + array_distance(items.embedding, queries.embedding) as dist + FROM items + ORDER BY dist + LIMIT 3 +); +---- + +query I rowsort result_count +SELECT count(*) FROM ( + SELECT queries.id as id, any_value(nbr) + FROM queries, LATERAL ( + SELECT + items.id as nbr, + array_distance(items.embedding, queries.embedding) as dist + FROM items + ORDER BY dist + LIMIT 3 + ) GROUP BY queries.id +) +---- + +query II rowsort result_scan +SELECT queries.id as id, list(nbr ORDER BY nbr) + FROM queries, LATERAL ( + SELECT + items.id as nbr, + array_distance(items.embedding, queries.embedding) as dist + FROM items + ORDER BY dist + LIMIT 3 + ) GROUP BY queries.id +----