Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for propagating filters from compound to inner retrievers #117914

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/117914.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117914
summary: Fix for propagating filters from compound to inner retrievers
area: Ranking
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.TransportMultiSearchAction;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.rest.RestStatus;
Expand All @@ -46,6 +47,8 @@
*/
public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilder<T>> extends RetrieverBuilder {

public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
pmpailis marked this conversation as resolved.
Show resolved Hide resolved

/**
 * Pairs an inner {@link RetrieverBuilder} with a {@link SearchSourceBuilder}.
 * <p>
 * The {@code source} component may be {@code null}; entries are created that way
 * when a retriever has just been rewritten.
 * NOTE(review): presumably {@code source} is the search source built for {@code retriever} - confirm.
 */
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}

protected final int rankWindowSize;
Expand All @@ -64,9 +67,9 @@ public T addChild(RetrieverBuilder retrieverBuilder) {

/**
* Returns a clone of the original retriever, replacing the sub-retrievers with
* the provided {@code newChildRetrievers}.
* the provided {@code newChildRetrievers} and the filters with the {@code newPreFilterQueryBuilders}.
*/
protected abstract T clone(List<RetrieverSource> newChildRetrievers);
protected abstract T clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders);

/**
* Combines the provided {@code rankResults} to return the final top documents.
Expand All @@ -85,13 +88,25 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
}

// Rewrite prefilters
boolean hasChanged = false;
// We eagerly rewrite prefilters, because some of the innerRetrievers
// could be compound too, so we want to propagate all the necessary filter information to them
// and have it available as part of their own rewrite step
var newPreFilters = rewritePreFilters(ctx);
hasChanged |= newPreFilters != preFilterQueryBuilders;
if (newPreFilters != preFilterQueryBuilders) {
return clone(innerRetrievers, newPreFilters);
}

boolean hasChanged = false;
// Rewrite retriever sources
List<RetrieverSource> newRetrievers = new ArrayList<>();
for (var entry : innerRetrievers) {
// we propagate the filters only for compound retrievers as they won't be attached through
// the createSearchSourceBuilder.
// We could remove this check, but we would end up adding the same filters
// multiple times in case an inner retriever rewrites itself, when we re-enter to rewrite
if (entry.retriever.isCompound() && false == preFilterQueryBuilders.isEmpty()) {
entry.retriever.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
}
RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
if (newRetriever != entry.retriever) {
newRetrievers.add(new RetrieverSource(newRetriever, null));
Expand All @@ -106,7 +121,7 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
}
}
if (hasChanged) {
return clone(newRetrievers);
return clone(newRetrievers, newPreFilters);
}

// execute searches
Expand Down Expand Up @@ -166,12 +181,7 @@ public void onFailure(Exception e) {
});
});

return new RankDocsRetrieverBuilder(
rankWindowSize,
newRetrievers.stream().map(s -> s.retriever).toList(),
results::get,
newPreFilters
);
return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
ll.onResponse(null);
}));
});
var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null);
return rewritten;
return new KnnRetrieverBuilder(this, () -> toSet.get(), null);
}
return super.rewrite(ctx);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
final List<RetrieverBuilder> sources;
final Supplier<RankDoc[]> rankDocs;

public RankDocsRetrieverBuilder(
int rankWindowSize,
List<RetrieverBuilder> sources,
Supplier<RankDoc[]> rankDocs,
List<QueryBuilder> preFilterQueryBuilders
) {
public RankDocsRetrieverBuilder(int rankWindowSize, List<RetrieverBuilder> sources, Supplier<RankDoc[]> rankDocs) {
this.rankWindowSize = rankWindowSize;
this.rankDocs = rankDocs;
if (sources == null || sources.isEmpty()) {
throw new IllegalArgumentException("sources must not be null or empty");
}
this.sources = sources;
this.preFilterQueryBuilders = preFilterQueryBuilders;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prefilters were not actually used anywhere for the RankDocsRetrieverBuilder as they had already been accounted for when computing the parent results.

}

@Override
Expand Down Expand Up @@ -73,10 +67,6 @@ private boolean sourceShouldRewrite(QueryRewriteContext ctx) throws IOException
@Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
assert false == sourceShouldRewrite(ctx) : "retriever sources should be rewritten first";
var rewrittenFilters = rewritePreFilters(ctx);
if (rewrittenFilters != preFilterQueryBuilders) {
return new RankDocsRetrieverBuilder(rankWindowSize, sources, rankDocs, rewrittenFilters);
}
return this;
}

Expand All @@ -94,7 +84,7 @@ public QueryBuilder topDocsQuery() {
boolQuery.should(query);
}
}
// ignore prefilters of this level, they are already propagated to children
// ignore prefilters of this level, they were already propagated to children
return boolQuery;
}

Expand Down Expand Up @@ -133,7 +123,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
} else {
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
}
// ignore prefilters of this level, they are already propagated to children
// ignore prefilters of this level, they were already propagated to children
searchSourceBuilder.query(rankQuery);
if (sourceHasMinScore()) {
searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,7 @@ private List<QueryBuilder> preFilters(QueryRewriteContext queryRewriteContext) t
}

private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException {
return new RankDocsRetrieverBuilder(
randomIntBetween(1, 100),
innerRetrievers(queryRewriteContext),
rankDocsSupplier(),
preFilters(queryRewriteContext)
);
return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(queryRewriteContext), rankDocsSupplier());
}

public void testExtractToSearchSourceBuilder() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
/**
* A SearchPlugin to exercise query vector builder
*/
class TestQueryVectorBuilderPlugin implements SearchPlugin {
public class TestQueryVectorBuilderPlugin implements SearchPlugin {

static class TestQueryVectorBuilder implements QueryVectorBuilder {
public static class TestQueryVectorBuilder implements QueryVectorBuilder {
private static final String NAME = "test_query_vector_builder";

private static final ParseField QUERY_VECTOR = new ParseField("query_vector");
Expand All @@ -47,11 +47,11 @@ static class TestQueryVectorBuilder implements QueryVectorBuilder {

private List<Float> vectorToBuild;

TestQueryVectorBuilder(List<Float> vectorToBuild) {
public TestQueryVectorBuilder(List<Float> vectorToBuild) {
this.vectorToBuild = vectorToBuild;
}

TestQueryVectorBuilder(float[] expected) {
public TestQueryVectorBuilder(float[] expected) {
this.vectorToBuild = new ArrayList<>(expected.length);
for (float f : expected) {
vectorToBuild.add(f);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.search.retriever;

import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;

Expand All @@ -23,16 +24,17 @@ public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder<TestC
public static final String NAME = "test_compound_retriever_builder";

public TestCompoundRetrieverBuilder(int rankWindowSize) {
this(new ArrayList<>(), rankWindowSize);
this(new ArrayList<>(), rankWindowSize, new ArrayList<>());
}

TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize) {
TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, List<QueryBuilder> preFilterQueryBuilders) {
super(childRetrievers, rankWindowSize);
this.preFilterQueryBuilders = preFilterQueryBuilders;
}

@Override
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize);
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize, newPreFilterQueryBuilders);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,14 @@ public QueryRuleRetrieverBuilder(
Map<String, Object> matchCriteria,
List<RetrieverSource> retrieverSource,
int rankWindowSize,
String retrieverName
String retrieverName,
List<QueryBuilder> preFilterQueryBuilders
) {
super(retrieverSource, rankWindowSize);
this.rulesetIds = rulesetIds;
this.matchCriteria = matchCriteria;
this.retrieverName = retrieverName;
this.preFilterQueryBuilders = preFilterQueryBuilders;
}

@Override
Expand Down Expand Up @@ -156,8 +158,15 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
}

@Override
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newChildRetrievers, rankWindowSize, retrieverName);
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
return new QueryRuleRetrieverBuilder(
rulesetIds,
matchCriteria,
newChildRetrievers,
rankWindowSize,
retrieverName,
newPreFilterQueryBuilders
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ public TextSimilarityRankRetrieverBuilder(
}

@Override
protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
protected TextSimilarityRankRetrieverBuilder clone(
List<RetrieverSource> newChildRetrievers,
List<QueryBuilder> newPreFilterQueryBuilders
) {
return new TextSimilarityRankRetrieverBuilder(
newChildRetrievers,
inferenceId,
Expand All @@ -138,7 +141,7 @@ protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChil
rankWindowSize,
minScore,
retrieverName,
preFilterQueryBuilders
newPreFilterQueryBuilders
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand All @@ -33,6 +34,7 @@
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
Expand All @@ -48,6 +50,8 @@
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -743,6 +747,43 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
}

public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() {
final int rankWindowSize = 100;
final int rankConstant = 10;
SearchSourceBuilder source = new SearchSourceBuilder();
// this will retriever all but 7 only due to top-level filter
StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
// this would have retrieved 7 and 6, but due to parent level filter, will retriever 6 and 3 instead
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
"vector",
null,
new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 7 }),
pmpailis marked this conversation as resolved.
Show resolved Hide resolved
1,
2,
null
);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null),
new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null)
),
rankWindowSize,
rankConstant
)
);
source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().mustNot(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
pmpailis marked this conversation as resolved.
Show resolved Hide resolved
source.size(10);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertNotNull(resp.getHits().getTotalHits());
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_6"));
assertThat(Arrays.stream(resp.getHits().getHits()).map(SearchHit::getId).toList(), not(contains("doc_7")));
});
}

public void testRewriteOnce() {
final float[] vector = new float[] { 1 };
AtomicInteger numAsyncCalls = new AtomicInteger();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import java.util.Set;

import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED;

/**
Expand All @@ -23,4 +24,9 @@ public class RRFFeatures implements FeatureSpecification {
public Set<NodeFeature> getFeatures() {
return Set.of(RRFRetrieverBuilder.RRF_RETRIEVER_SUPPORTED, RRF_RETRIEVER_COMPOSITION_SUPPORTED);
}

/**
 * Returns the test features declared by this specification.
 *
 * @return an immutable set whose only element is {@code INNER_RETRIEVERS_FILTER_SUPPORT}
 */
@Override
public Set<NodeFeature> getTestFeatures() {
    final Set<NodeFeature> testFeatures = Set.of(INNER_RETRIEVERS_FILTER_SUPPORT);
    return testFeatures;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
Expand Down Expand Up @@ -108,8 +109,10 @@ public String getName() {
}

@Override
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers) {
return new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
return clone;
}

@Override
Expand Down
Loading