diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConstantScoreWeight.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConstantScoreWeight.java deleted file mode 100644 index aec84499a9f7b..0000000000000 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConstantScoreWeight.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.approximate; - -import org.apache.lucene.index.PointValues; -import org.apache.lucene.search.ConstantScoreWeight; -import org.apache.lucene.search.Query; -import org.apache.lucene.util.DocIdSetBuilder; - -import java.io.IOException; - -public abstract class ApproximateConstantScoreWeight extends ConstantScoreWeight { - - protected ApproximateConstantScoreWeight(Query query, float score) { - super(query, score); - } - - protected abstract long intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree) throws IOException; - - protected abstract long intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree) throws IOException; - - protected abstract PointValues.IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) throws IOException; - -} diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 49949e935e7e7..55dfa9953293d 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -12,6 +12,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PointValues; import 
org.apache.lucene.search.ConstantScoreScorer; +import org.apache.lucene.search.ConstantScoreWeight; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.PointRangeQuery; @@ -44,11 +45,11 @@ public abstract class ApproximatePointRangeQuery extends Query { private final PointRangeQuery pointRangeQuery; protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) { - this(field, lowerPoint, upperPoint, numDims, 10_000, SortOrder.ASC); + this(field, lowerPoint, upperPoint, numDims, 10_000, null); } protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims, int size) { - this(field, lowerPoint, upperPoint, numDims, size, SortOrder.ASC); + this(field, lowerPoint, upperPoint, numDims, size, null); } protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims, int size, SortOrder sortOrder) { @@ -84,17 +85,30 @@ public void visit(QueryVisitor visitor) { } @Override - public final ApproximateConstantScoreWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + public final ConstantScoreWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { Weight pointRangeQueryWeight = pointRangeQuery.createWeight(searcher, scoreMode, boost); - return new ApproximateConstantScoreWeight(this, boost) { + return new ConstantScoreWeight(this, boost) { private final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(pointRangeQuery.getBytesPerDim()); + // we pull this from PointRangeQuery since it is final private boolean matches(byte[] packedValue) { - return relate(packedValue, packedValue) != PointValues.Relation.CELL_OUTSIDE_QUERY; + for (int dim = 0; dim < pointRangeQuery.getNumDims(); dim++) { + int offset = dim * pointRangeQuery.getBytesPerDim(); + if 
(comparator.compare(packedValue, offset, pointRangeQuery.getLowerPoint(), offset) < 0) { + // Doc's value is too low, in this dimension + return false; + } + if (comparator.compare(packedValue, offset, pointRangeQuery.getUpperPoint(), offset) > 0) { + // Doc's value is too high, in this dimension + return false; + } + } + return true; } + // we pull this from PointRangeQuery since it is final private PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { boolean crosses = false; @@ -171,6 +185,7 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue }; } + // we pull this from PointRangeQuery since it is final private boolean checkValidPointValues(PointValues values) throws IOException { if (values == null) { // No docs in this segment/field indexed any points @@ -211,10 +226,10 @@ private void intersectRight(PointValues.PointTree pointTree, PointValues.Interse } // custom intersect visitor to walk the left of the tree - public long intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree) throws IOException { + public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree) throws IOException { PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - if (docCount[0] >= size) { - return 0; + if (docCount[0] > size) { + return; } switch (r) { case CELL_OUTSIDE_QUERY: @@ -223,27 +238,25 @@ public long intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin case CELL_INSIDE_QUERY: // If the cell is fully inside, we keep moving to child until we reach a point where we can no longer move or when // we have sufficient doc count. 
We first move down and then move to the left child - if (pointTree.moveToChild()) { + if (pointTree.moveToChild() && docCount[0] < size) { do { - docCount[0] += intersectLeft(visitor, pointTree); - } while (pointTree.moveToSibling() && docCount[0] <= size); + intersectLeft(visitor, pointTree); + } while (pointTree.moveToSibling() && docCount[0] < size); pointTree.moveToParent(); } else { // we're at the leaf node, if we're under the size, visit all the docIds in this node. if (docCount[0] < size) { pointTree.visitDocIDs(visitor); - docCount[0] += pointTree.size(); - return docCount[0]; - } else break; + } } break; case CELL_CROSSES_QUERY: // The cell crosses the shape boundary, or the cell fully contains the query, so we fall // through and do full filtering: - if (pointTree.moveToChild()) { + if (pointTree.moveToChild() && docCount[0] < size) { do { - docCount[0] += intersectLeft(visitor, pointTree); - } while (pointTree.moveToSibling() && docCount[0] <= size); + intersectLeft(visitor, pointTree); + } while (pointTree.moveToSibling() && docCount[0] < size); pointTree.moveToParent(); } else { // TODO: we can assert that the first value here in fact matches what the pointTree @@ -251,66 +264,70 @@ public long intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin // Leaf node; scan and filter all points in this block: if (docCount[0] < size) { pointTree.visitDocValues(visitor); - } else break; + } } break; default: throw new IllegalArgumentException("Unreachable code"); } - // docCount can be updated by the local visitor so we ensure that we return docCount after pointTree.visitDocValues(visitor) - return docCount[0] > 0 ? 
docCount[0] : 0; } // custom intersect visitor to walk the right of tree - public long intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree) throws IOException { + public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree) throws IOException { PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - if (docCount[0] >= size) { - return 0; + if (docCount[0] > size) { + return; } switch (r) { case CELL_OUTSIDE_QUERY: // This cell is fully outside the query shape: stop recursing break; + case CELL_INSIDE_QUERY: - // If the cell is fully inside, we keep moving to child until we reach a point where we can no longer move or when - // we have sufficient doc count. We first move down and then move right - if (pointTree.moveToChild()) { - while (pointTree.moveToSibling() && docCount[0] <= size) { - docCount[0] += intersectRight(visitor, pointTree); - } + // If the cell is fully inside, we keep moving right as long as the point tree size is over our size requirement + if (pointTree.size() > size && docCount[0] < size && moveRight(pointTree)) { + intersectRight(visitor, pointTree); pointTree.moveToParent(); - } else { - // we're at the leaf node, if we're under the size, visit all the docIds in this node. 
- if (docCount[0] <= size) { + } + // if point tree size is no longer over, we have to go back one level where it still was over and then intersect left + else if (pointTree.size() <= size && docCount[0] < size) { + pointTree.moveToParent(); + intersectLeft(visitor, pointTree); + } + // if we've reached leaf, it means our size is under the size of the leaf, we can just collect all docIDs + else { + // Leaf node; scan and filter all points in this block: + if (docCount[0] < size) { pointTree.visitDocIDs(visitor); - docCount[0] += pointTree.size(); - return docCount[0]; - } else break; + } } break; case CELL_CROSSES_QUERY: - // The cell crosses the shape boundary, or the cell fully contains the query, so we fall - // through and do full filtering: - if (pointTree.moveToChild()) { - do { - docCount[0] += intersectRight(visitor, pointTree); - } while (pointTree.moveToSibling() && docCount[0] <= size); + // If the cell crosses the query boundary, we keep moving right as long as the point tree size is over our size requirement + if (pointTree.size() > size && docCount[0] < size && moveRight(pointTree)) { + intersectRight(visitor, pointTree); pointTree.moveToParent(); - } else { - // TODO: we can assert that the first value here in fact matches what the pointTree - // claimed? 
+ } + // if point tree size is no longer over, we have to go back one level where it still was over and then intersect left + else if (pointTree.size() <= size && docCount[0] < size) { + pointTree.moveToParent(); + intersectLeft(visitor, pointTree); + } + // if we've reached leaf, it means our size is under the size of the leaf, we can just collect all doc values + else { // Leaf node; scan and filter all points in this block: - if (docCount[0] <= size) { + if (docCount[0] < size) { pointTree.visitDocValues(visitor); - } else break; + } } break; default: throw new IllegalArgumentException("Unreachable code"); } - // docCount can be updated by the local visitor, so we ensure that we return docCount after - // pointTree.visitDocValues(visitor) - return docCount[0] > 0 ? docCount[0] : 0; + } + + public boolean moveRight(PointValues.PointTree pointTree) throws IOException { + return pointTree.moveToChild() && pointTree.moveToSibling(); } @Override @@ -321,58 +338,11 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti if (checkValidPointValues(values) == false) { return null; } - - if (values.getDocCount() == 0) { - return null; - } else { - final byte[] fieldPackedLower = values.getMinPackedValue(); - final byte[] fieldPackedUpper = values.getMaxPackedValue(); - for (int i = 0; i < pointRangeQuery.getNumDims(); ++i) { - int offset = i * pointRangeQuery.getBytesPerDim(); - if (comparator.compare(pointRangeQuery.getLowerPoint(), offset, fieldPackedUpper, offset) > 0 - || comparator.compare(pointRangeQuery.getUpperPoint(), offset, fieldPackedLower, offset) < 0) { - // If this query is a required clause of a boolean query, then returning null here - // will help make sure that we don't call ScorerSupplier#get on other required clauses - // of the same boolean query, which is an expensive operation for some queries (e.g. - // multi-term queries). 
- return null; - } - } - } - - boolean allDocsMatch; - if (values.getDocCount() == reader.maxDoc()) { - final byte[] fieldPackedLower = values.getMinPackedValue(); - final byte[] fieldPackedUpper = values.getMaxPackedValue(); - allDocsMatch = true; - for (int i = 0; i < pointRangeQuery.getNumDims(); ++i) { - int offset = i * pointRangeQuery.getBytesPerDim(); - if (comparator.compare(pointRangeQuery.getLowerPoint(), offset, fieldPackedLower, offset) > 0 - || comparator.compare(pointRangeQuery.getUpperPoint(), offset, fieldPackedUpper, offset) < 0) { - allDocsMatch = false; - break; - } - } - } else { - allDocsMatch = false; - } - final Weight weight = this; - if (allDocsMatch) { - // all docs have a value and all points are within bounds, so everything matches - return new ScorerSupplier() { - @Override - public Scorer get(long leadCost) { - return new ConstantScoreScorer(weight, score(), scoreMode, DocIdSetIterator.all(reader.maxDoc())); - } - - @Override - public long cost() { - return reader.maxDoc(); - } - }; + if (size > values.size()) { + return pointRangeQueryWeight.scorerSupplier(context); } else { - if (sortOrder.equals(SortOrder.ASC)) { + if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { return new ScorerSupplier() { final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, pointRangeQuery.getField()); @@ -396,30 +366,35 @@ public long cost() { return cost; } }; - } - return new ScorerSupplier() { + } else { + // we need to fetch size + deleted docs since the collector will prune away deleted docs resulting in fewer results + // than expected + final int deletedDocs = reader.numDeletedDocs(); + size += deletedDocs; + return new ScorerSupplier() { - final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, pointRangeQuery.getField()); - final PointValues.IntersectVisitor visitor = getIntersectVisitor(result); - long cost = -1; + final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, 
pointRangeQuery.getField()); + final PointValues.IntersectVisitor visitor = getIntersectVisitor(result); + long cost = -1; - @Override - public Scorer get(long leadCost) throws IOException { - intersectRight(values.getPointTree(), visitor); - DocIdSetIterator iterator = result.build().iterator(); - return new ConstantScoreScorer(weight, score(), scoreMode, iterator); - } + @Override + public Scorer get(long leadCost) throws IOException { + intersectRight(values.getPointTree(), visitor); + DocIdSetIterator iterator = result.build().iterator(); + return new ConstantScoreScorer(weight, score(), scoreMode, iterator); + } - @Override - public long cost() { - if (cost == -1) { - // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; + @Override + public long cost() { + if (cost == -1) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; + } + return cost; } - return cost; - } - }; + }; + } } } @@ -503,7 +478,7 @@ public final String toString(String field) { * {@link #toString()}. 
* * @param dimension dimension of the particular value - * @param value single value, never null + * @param value single value, never null * @return human readable value for debugging */ protected abstract String toString(int dimension, byte[] value); diff --git a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java index 7c23440868912..ddc4961bffe3a 100644 --- a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java +++ b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java @@ -69,6 +69,7 @@ import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchService; import org.opensearch.search.approximate.ApproximatePointRangeQuery; @@ -334,10 +335,14 @@ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collecto BitSet liveDocsBitSet = getSparseBitSetOrNull(liveDocs); if (isApproximateableRangeQuery()) { ApproximateScoreQuery query = ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery()); - if (searchContext.size() > 10_000) ((ApproximatePointRangeQuery) query.getApproximationQuery()).setSize(searchContext.size()); + ((ApproximatePointRangeQuery) query.getApproximationQuery()).setSize( + Math.max(searchContext.size(), searchContext.trackTotalHitsUpTo()) + ); if (searchContext.request() != null && searchContext.request().source() != null) { FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(searchContext.request().source()); - if (primarySortField != null && primarySortField.missing() == null) { + if (primarySortField != null + && primarySortField.missing() == null + && 
primarySortField.getFieldName().equals(((RangeQueryBuilder) searchContext.request().source().query()).fieldName())) { if (primarySortField.order() == SortOrder.DESC) { ((ApproximatePointRangeQuery) query.getApproximationQuery()).setSortOrder(SortOrder.DESC); }