Skip to content

Commit

Permalink
Added new ApproximateableQuery
Browse files Browse the repository at this point in the history
Signed-off-by: Harsha Vamsi Kalluri <[email protected]>
  • Loading branch information
harshavamsi committed Aug 20, 2024
1 parent 06ce11d commit bda095b
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ public Query rangeQuery(
}
DateMathParser parser = forcedDateParser == null ? dateMathParser : forcedDateParser;
return dateRangeQuery(lowerTerm, upperTerm, includeLower, includeUpper, timeZone, parser, context, resolution, (l, u) -> {
Query pointRangeQuery = createPointRangeQuery(l, u);
Query pointRangeQuery = isSearchable() ? createPointRangeQuery(l, u) : null;
Query dvQuery = hasDocValues() ? SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u) : null;
if (isSearchable() && hasDocValues()) {
Query query = new IndexOrDocValuesQuery(pointRangeQuery, dvQuery);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
Expand All @@ -25,6 +25,9 @@
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.IntsRef;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.SortOrder;

import java.io.IOException;
Expand All @@ -35,7 +38,7 @@
* An approximate-able version of {@link PointRangeQuery}. It creates an instance of {@link PointRangeQuery} but short-circuits the intersect logic
* after {@code size} is hit
*/
public abstract class ApproximatePointRangeQuery extends Query {
public abstract class ApproximatePointRangeQuery extends ApproximateableQuery {
private int size;

private SortOrder sortOrder;
Expand Down Expand Up @@ -419,6 +422,29 @@ public boolean isCacheable(LeafReaderContext ctx) {
};
}

@Override
public boolean canApproximate(SearchContext context) {
if (!(context.query() instanceof IndexOrDocValuesQuery
&& ((IndexOrDocValuesQuery) context.query()).getIndexQuery() instanceof ApproximateScoreQuery
&& ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) context.query()).getIndexQuery())
.getOriginalQuery() instanceof PointRangeQuery)) {
return false;
}
ApproximateScoreQuery query = ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) context.query()).getIndexQuery());
((ApproximatePointRangeQuery) query.getApproximationQuery()).setSize(Math.max(context.size(), context.trackTotalHitsUpTo()));
if (context.request() != null && context.request().source() != null) {
FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(context.request().source());
if (primarySortField != null
&& primarySortField.missing() == null
&& primarySortField.getFieldName().equals(((RangeQueryBuilder) context.request().source().query()).fieldName())) {
if (primarySortField.order() == SortOrder.DESC) {
((ApproximatePointRangeQuery) query.getApproximationQuery()).setSortOrder(SortOrder.DESC);
}
}
}
return true;
}

@Override
public final int hashCode() {
return pointRangeQuery.hashCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;

Expand All @@ -29,17 +30,20 @@
*/
public final class ApproximateScoreQuery extends Query {

private final Query originalQuery, approximationQuery;
private final Query originalQuery;
private final ApproximateableQuery approximationQuery;

private Weight originalQueryWeight, approximationQueryWeight;

public ApproximateScoreQuery(Query originalQuery, Query approximationQuery) {
private SearchContext context;

public ApproximateScoreQuery(Query originalQuery, ApproximateableQuery approximationQuery) {
this(originalQuery, approximationQuery, null, null);
}

public ApproximateScoreQuery(
Query originalQuery,
Query approximationQuery,
ApproximateableQuery approximationQuery,
Weight originalQueryWeight,
Weight approximationQueryWeight
) {
Expand All @@ -49,15 +53,11 @@ public ApproximateScoreQuery(
this.approximationQueryWeight = approximationQueryWeight;
}

public Weight getApproximationQueryWeight() {
return approximationQueryWeight;
}

public Query getOriginalQuery() {
return originalQuery;
}

public Query getApproximationQuery() {
public ApproximateableQuery getApproximationQuery() {
return approximationQuery;
}

Expand Down Expand Up @@ -88,8 +88,9 @@ public ScorerSupplier scorerSupplier(LeafReaderContext leafReaderContext) throws
return new ScorerSupplier() {
@Override
public Scorer get(long l) throws IOException {
// TODO: we need to figure out how to compute the cost of running two different queries, by default return the
// original query's scoreSupplier
if (approximationQuery.canApproximate(context)) {
return originalQueryScoreSupplier.get(l);
}
return originalQueryScoreSupplier.get(l);
}

Expand All @@ -116,6 +117,10 @@ public boolean isCacheable(LeafReaderContext leafReaderContext) {
};
}

public void setContext(SearchContext context) {
this.context = context;
};

@Override
public String toString(String s) {
return "ApproximateScoreQuery(originalQuery="
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.approximate;

import org.apache.lucene.search.Query;
import org.opensearch.search.internal.SearchContext;

public abstract class ApproximateableQuery extends Query {

protected abstract boolean canApproximate(SearchContext context);

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryCache;
import org.apache.lucene.search.QueryCachingPolicy;
Expand All @@ -69,10 +68,8 @@
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.SearchService;
import org.opensearch.search.approximate.ApproximatePointRangeQuery;
import org.opensearch.search.approximate.ApproximateScoreQuery;
import org.opensearch.search.dfs.AggregatedDfs;
import org.opensearch.search.profile.ContextualProfileBreakdown;
Expand All @@ -84,7 +81,6 @@
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.MinAndMax;
import org.opensearch.search.sort.SortOrder;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -338,22 +334,10 @@ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collecto
// catch early terminated exception and rethrow?
Bits liveDocs = ctx.reader().getLiveDocs();
BitSet liveDocsBitSet = getSparseBitSetOrNull(liveDocs);
if (isApproximateableRangeQuery()) {
if (searchContext.query() instanceof IndexOrDocValuesQuery
&& ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery() instanceof ApproximateScoreQuery) {
ApproximateScoreQuery query = ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery());
((ApproximatePointRangeQuery) query.getApproximationQuery()).setSize(
Math.max(searchContext.size(), searchContext.trackTotalHitsUpTo())
);
if (searchContext.request() != null && searchContext.request().source() != null) {
FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(searchContext.request().source());
if (primarySortField != null
&& primarySortField.missing() == null
&& primarySortField.getFieldName().equals(((RangeQueryBuilder) searchContext.request().source().query()).fieldName())) {
if (primarySortField.order() == SortOrder.DESC) {
((ApproximatePointRangeQuery) query.getApproximationQuery()).setSortOrder(SortOrder.DESC);
}
}
}
weight = query.getApproximationQueryWeight();
query.setContext(searchContext);
}
if (liveDocsBitSet == null) {
BulkScorer bulkScorer = weight.bulkScorer(ctx);
Expand Down Expand Up @@ -446,13 +430,6 @@ private static BitSet getSparseBitSetOrNull(Bits liveDocs) {

}

private boolean isApproximateableRangeQuery() {
return searchContext.query() instanceof IndexOrDocValuesQuery
&& ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery() instanceof ApproximateScoreQuery
&& ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery())
.getOriginalQuery() instanceof PointRangeQuery;
}

static void intersectScorerAndBitSet(Scorer scorer, BitSet acceptDocs, LeafCollector collector, Runnable checkCancelled)
throws IOException {
collector.setScorer(scorer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ protected String toString(int dimension, byte[] value) {
}
};

Query approximateQuery = new ApproximatePointRangeQuery(
ApproximateableQuery approximateQuery = new ApproximatePointRangeQuery(
"test-index",
pack(new long[] { l }).bytes,
pack(new long[] { u }).bytes,
Expand Down

0 comments on commit bda095b

Please sign in to comment.