Skip to content

Commit

Permalink
Addressing PR comments
Browse files Browse the repository at this point in the history
Signed-off-by: Harsha Vamsi Kalluri <[email protected]>
  • Loading branch information
harshavamsi committed Aug 29, 2024
1 parent 3304633 commit 3d63ea0
Show file tree
Hide file tree
Showing 14 changed files with 101 additions and 217 deletions.
10 changes: 10 additions & 0 deletions server/src/main/java/org/opensearch/common/util/FeatureFlags.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,16 @@ public class FeatureFlags {
Property.NodeScope
);

/**
* Gates the functionality of ApproximatePointRangeQuery where we approximate query results.
* The flag is exposed as a node-scoped boolean setting that defaults to {@code false},
* so approximation stays disabled unless the setting is explicitly turned on.
*/
public static final String APPROXIMATE_POINT_RANGE_QUERY = "opensearch.experimental.feature.approximate_point_range_query.enabled";
public static final Setting<Boolean> APPROXIMATE_POINT_RANGE_QUERY_SETTING = Setting.boolSetting(
APPROXIMATE_POINT_RANGE_QUERY,
false,
Property.NodeScope
);

private static final List<Setting<Boolean>> ALL_FEATURE_FLAG_SETTINGS = List.of(
REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING,
EXTENSIONS_SETTING,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSortSortedNumericDocValuesRangeQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.opensearch.OpenSearchParseException;
import org.opensearch.Version;
Expand Down Expand Up @@ -111,6 +111,21 @@ public static DateFormatter getDefaultDateTimeFormatter() {
: LEGACY_DEFAULT_DATE_TIME_FORMATTER;
}

/**
 * Builds the combined index/doc-values range query for a field, honouring the
 * {@code APPROXIMATE_POINT_RANGE_QUERY} feature flag.
 *
 * <p>With the flag disabled, the result is a plain {@link IndexOrDocValuesQuery} over the two
 * supplied queries. With the flag enabled, the result is an {@link ApproximateIndexOrDocValuesQuery}
 * that additionally carries an {@link ApproximatePointRangeQuery} built from {@code name},
 * {@code l} and {@code u}.
 *
 * @param pointRangeQuery the points-based range query to use as the index query
 * @param dvQuery         the doc-values-based range query
 * @param name            the field name used for the approximate point range query
 * @param l               the lower bound, packed as a single {@code long} dimension
 * @param u               the upper bound, packed as a single {@code long} dimension
 * @return the query to execute for the range
 */
public static Query getDefaultQuery(Query pointRangeQuery, Query dvQuery, String name, long l, long u) {
    if (FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY_SETTING) == false) {
        return new IndexOrDocValuesQuery(pointRangeQuery, dvQuery);
    }

    final byte[] lowerPoint = pack(new long[] { l }).bytes;
    final byte[] upperPoint = pack(new long[] { u }).bytes;
    // Each bound is packed from a one-element array, so there is exactly one dimension.
    final int numDims = 1;

    final ApproximatePointRangeQuery approximation = new ApproximatePointRangeQuery(name, lowerPoint, upperPoint, numDims) {
        @Override
        protected String toString(int dimension, byte[] value) {
            // Render the bound as a plain long rather than raw bytes.
            return Long.toString(LongPoint.decodeDimension(value, 0));
        }
    };
    return new ApproximateIndexOrDocValuesQuery(pointRangeQuery, approximation, dvQuery);
}

/**
* Resolution of the date time
*
Expand Down Expand Up @@ -466,24 +481,10 @@ public Query rangeQuery(
}
DateMathParser parser = forcedDateParser == null ? dateMathParser : forcedDateParser;
return dateRangeQuery(lowerTerm, upperTerm, includeLower, includeUpper, timeZone, parser, context, resolution, (l, u) -> {
Query pointRangeQuery = isSearchable() ? createPointRangeQuery(l, u) : null;
Query pointRangeQuery = isSearchable() ? LongPoint.newRangeQuery(name(), l, u) : null;
Query dvQuery = hasDocValues() ? SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u) : null;
if (isSearchable() && hasDocValues()) {
Query query = new ApproximateIndexOrDocValuesQuery(
pointRangeQuery,
new ApproximatePointRangeQuery(
name(),
pack(new long[] { l }).bytes,
pack(new long[] { u }).bytes,
new long[] { l }.length
) {
@Override
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
},
dvQuery
);
Query query = getDefaultQuery(pointRangeQuery, dvQuery, name(), l, u);
if (context.indexSortedOnField(name())) {
query = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, query);
}
Expand All @@ -499,14 +500,6 @@ protected String toString(int dimension, byte[] value) {
});
}

/**
 * Creates the points-based range query for this field between {@code l} and {@code u}.
 *
 * @param l the lower bound, packed as a single {@code long} dimension
 * @param u the upper bound, packed as a single {@code long} dimension
 * @return a {@link PointRangeQuery} on {@code name()} whose bounds are rendered as plain long
 *         values by its {@code toString(int, byte[])}
 */
private Query createPointRangeQuery(long l, long u) {
    final String field = name();
    final byte[] lowerPoint = pack(new long[] { l }).bytes;
    final byte[] upperPoint = pack(new long[] { u }).bytes;
    // Each bound is packed from a one-element array, so the dimension count is always 1.
    return new PointRangeQuery(field, lowerPoint, upperPoint, 1) {
        @Override
        protected String toString(int dimension, byte[] value) {
            // Render the bound as a plain long rather than raw bytes.
            return Long.toString(LongPoint.decodeDimension(value, 0));
        }
    };
}

public static Query dateRangeQuery(
Object lowerTerm,
Object upperTerm,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.index.query.DateRangeIncludingNowQuery;
import org.opensearch.search.approximate.ApproximateIndexOrDocValuesQuery;
import org.opensearch.search.approximate.ApproximatePointRangeQuery;
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
Expand Down Expand Up @@ -56,6 +55,7 @@ private Helper() {}
queryWrappers.put(FunctionScoreQuery.class, q -> ((FunctionScoreQuery) q).getSubQuery());
queryWrappers.put(DateRangeIncludingNowQuery.class, q -> ((DateRangeIncludingNowQuery) q).getQuery());
queryWrappers.put(IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery());
queryWrappers.put(ApproximateIndexOrDocValuesQuery.class, q -> ((ApproximateIndexOrDocValuesQuery) q).getOriginalQuery());
}

/**
Expand Down Expand Up @@ -125,19 +125,6 @@ public static long[] getDateHistoAggBounds(final SearchContext context, final St
final long[] indexBounds = getShardBounds(leaves, fieldName);
if (indexBounds == null) return null;
return getBoundsWithRangeQuery(prq, fieldName, indexBounds);
} else if (cq instanceof ApproximateIndexOrDocValuesQuery) {
final ApproximateIndexOrDocValuesQuery aiodvq = (ApproximateIndexOrDocValuesQuery) cq;
final long[] indexBounds = getShardBounds(leaves, fieldName);
if (indexBounds == null) return null;
if ((aiodvq.getApproximationQuery() instanceof ApproximatePointRangeQuery)) {
ApproximatePointRangeQuery aprq = (ApproximatePointRangeQuery) aiodvq.getApproximationQuery();
if (aprq.canApproximate(context)) {
return getBoundsWithRangeQuery(aprq.pointRangeQuery, fieldName, indexBounds);
}
final IndexOrDocValuesQuery iodvq = (IndexOrDocValuesQuery) aiodvq.getOriginalQuery();
final PointRangeQuery prq = (PointRangeQuery) iodvq.getIndexQuery();
return getBoundsWithRangeQuery(prq, fieldName, indexBounds);
}
} else if (cq instanceof MatchAllDocsQuery) {
return getShardBounds(leaves, fieldName);
} else if (cq instanceof FieldExistsQuery) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@

/**
* A wrapper around {@link IndexOrDocValuesQuery} that can be used to run approximate queries.
* It delegates to either {@link ApproximateableQuery} or {@link IndexOrDocValuesQuery} based on whether the query can be approximated or not.
* @see ApproximateableQuery
* It delegates to either {@link ApproximateQuery} or {@link IndexOrDocValuesQuery} based on whether the query can be approximated or not.
* @see ApproximateQuery
*/
public final class ApproximateIndexOrDocValuesQuery extends ApproximateScoreQuery {

private final ApproximateableQuery approximateIndexQuery;
private final ApproximateQuery approximateIndexQuery;
private final IndexOrDocValuesQuery indexOrDocValuesQuery;

public ApproximateIndexOrDocValuesQuery(Query indexQuery, ApproximateableQuery approximateIndexQuery, Query dvQuery) {
public ApproximateIndexOrDocValuesQuery(Query indexQuery, ApproximateQuery approximateIndexQuery, Query dvQuery) {
super(new IndexOrDocValuesQuery(indexQuery, dvQuery), approximateIndexQuery);
this.approximateIndexQuery = approximateIndexQuery;
this.indexOrDocValuesQuery = new IndexOrDocValuesQuery(indexQuery, dvQuery);
Expand Down Expand Up @@ -67,9 +67,11 @@ public int hashCode() {

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
if (approximateIndexQuery.canApproximate(this.getContext())) {
return approximateIndexQuery.createWeight(searcher, scoreMode, boost);
// resolvedQuery is null when setContext has not been called (some internal tests may call createWeight without
// setting the context); in that case fall back to the IndexOrDocValuesQuery's weight
if (this.resolvedQuery == null) {
return indexOrDocValuesQuery.createWeight(searcher, scoreMode, boost);
}
return indexOrDocValuesQuery.createWeight(searcher, scoreMode, boost);
return this.resolvedQuery.createWeight(searcher, scoreMode, boost);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,11 @@
* An approximate-able version of {@link PointRangeQuery}. It creates an instance of {@link PointRangeQuery} but short-circuits the intersect logic
* after {@code size} is hit
*/
public abstract class ApproximatePointRangeQuery extends ApproximateableQuery {
public abstract class ApproximatePointRangeQuery extends ApproximateQuery {
private int size;

private SortOrder sortOrder;

private long[] docCount = { 0 };

public final PointRangeQuery pointRangeQuery;

protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) {
Expand Down Expand Up @@ -134,7 +132,7 @@ private PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue
}
}

public PointValues.IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) {
public PointValues.IntersectVisitor getIntersectVisitor(DocIdSetBuilder result, long[] docCount) {
return new PointValues.IntersectVisitor() {

DocIdSetBuilder.BulkAdder adder;
Expand Down Expand Up @@ -217,18 +215,21 @@ private boolean checkValidPointValues(PointValues values) throws IOException {
return true;
}

private void intersectLeft(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor) throws IOException {
intersectLeft(visitor, pointTree);
private void intersectLeft(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount)
throws IOException {
intersectLeft(visitor, pointTree, docCount);
assert pointTree.moveToParent() == false;
}

private void intersectRight(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor) throws IOException {
intersectRight(visitor, pointTree);
private void intersectRight(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount)
throws IOException {
intersectRight(visitor, pointTree, docCount);
assert pointTree.moveToParent() == false;
}

// custom intersect visitor to walk the left of the tree
public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree) throws IOException {
public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount)
throws IOException {
PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
if (docCount[0] > size) {
return;
Expand All @@ -242,7 +243,7 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin
// we have sufficient doc count. We first move down and then move to the left child
if (pointTree.moveToChild() && docCount[0] < size) {
do {
intersectLeft(visitor, pointTree);
intersectLeft(visitor, pointTree, docCount);
} while (pointTree.moveToSibling() && docCount[0] < size);
pointTree.moveToParent();
} else {
Expand All @@ -257,7 +258,7 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin
// through and do full filtering:
if (pointTree.moveToChild() && docCount[0] < size) {
do {
intersectLeft(visitor, pointTree);
intersectLeft(visitor, pointTree, docCount);
} while (pointTree.moveToSibling() && docCount[0] < size);
pointTree.moveToParent();
} else {
Expand All @@ -275,7 +276,8 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin
}

// custom intersect visitor to walk the right of tree
public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree) throws IOException {
public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount)
throws IOException {
PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
if (docCount[0] > size) {
return;
Expand All @@ -288,13 +290,13 @@ public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.Poi
case CELL_INSIDE_QUERY:
// If the cell is fully inside, we keep moving right as long as the point tree size is over our size requirement
if (pointTree.size() > size && docCount[0] < size && moveRight(pointTree)) {
intersectRight(visitor, pointTree);
intersectRight(visitor, pointTree, docCount);
pointTree.moveToParent();
}
// if the point tree size is no longer over, we have to go back one level to where it was still over and then intersect left
else if (pointTree.size() <= size && docCount[0] < size) {
pointTree.moveToParent();
intersectLeft(visitor, pointTree);
intersectLeft(visitor, pointTree, docCount);
}
// if we've reached a leaf, it means our size is under the size of the leaf, so we can just collect all docIDs
else {
Expand All @@ -307,13 +309,13 @@ else if (pointTree.size() <= size && docCount[0] < size) {
case CELL_CROSSES_QUERY:
// If the cell crosses the query, we keep moving right as long as the point tree size is over our size requirement
if (pointTree.size() > size && docCount[0] < size && moveRight(pointTree)) {
intersectRight(visitor, pointTree);
intersectRight(visitor, pointTree, docCount);
pointTree.moveToParent();
}
// if the point tree size is no longer over, we have to go back one level to where it was still over and then intersect left
else if (pointTree.size() <= size && docCount[0] < size) {
pointTree.moveToParent();
intersectLeft(visitor, pointTree);
intersectLeft(visitor, pointTree, docCount);
}
// if we've reached a leaf, it means our size is under the size of the leaf, so we can just collect all doc values
else {
Expand All @@ -335,6 +337,7 @@ public boolean moveRight(PointValues.PointTree pointTree) throws IOException {
@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
LeafReader reader = context.reader();
long[] docCount = { 0 };

PointValues values = reader.getPointValues(pointRangeQuery.getField());
if (checkValidPointValues(values) == false) {
Expand All @@ -348,12 +351,12 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
return new ScorerSupplier() {

final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, pointRangeQuery.getField());
final PointValues.IntersectVisitor visitor = getIntersectVisitor(result);
final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount);
long cost = -1;

@Override
public Scorer get(long leadCost) throws IOException {
intersectLeft(values.getPointTree(), visitor);
intersectLeft(values.getPointTree(), visitor, docCount);
DocIdSetIterator iterator = result.build().iterator();
return new ConstantScoreScorer(weight, score(), scoreMode, iterator);
}
Expand All @@ -376,12 +379,12 @@ public long cost() {
return new ScorerSupplier() {

final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, pointRangeQuery.getField());
final PointValues.IntersectVisitor visitor = getIntersectVisitor(result);
final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount);
long cost = -1;

@Override
public Scorer get(long leadCost) throws IOException {
intersectRight(values.getPointTree(), visitor);
intersectRight(values.getPointTree(), visitor, docCount);
DocIdSetIterator iterator = result.build().iterator();
return new ConstantScoreScorer(weight, score(), scoreMode, iterator);
}
Expand Down Expand Up @@ -432,7 +435,7 @@ public boolean canApproximate(SearchContext context) {
if (!(context.query() instanceof ApproximateIndexOrDocValuesQuery)) {
return false;
}
this.setSize(Math.max(context.size(), context.trackTotalHitsUpTo()));
this.setSize(Math.max(context.from() + context.size(), context.trackTotalHitsUpTo()));
if (context.request() != null && context.request().source() != null) {
FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(context.request().source());
if (primarySortField != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
/**
* Abstract class that can be inherited by queries that can be approximated. Queries should implement {@link #canApproximate(SearchContext)} to specify conditions on when they can be approximated
*/
public abstract class ApproximateableQuery extends Query {
public abstract class ApproximateQuery extends Query {

protected abstract boolean canApproximate(SearchContext context);

Expand Down
Loading

0 comments on commit 3d63ea0

Please sign in to comment.