diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index cbd0522c72f58..f540c5d421b5f 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -22,9 +22,7 @@ import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.DocIdSetBuilder; -import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.IntsRef; import java.io.IOException; @@ -44,6 +42,8 @@ public abstract class ApproximatePointRangeQuery extends Query { private int size; + private long[] docCount = { 0 }; + protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) { this(field, lowerPoint, upperPoint, numDims, 10_000); } @@ -70,6 +70,7 @@ protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upp this.lowerPoint = lowerPoint; this.upperPoint = upperPoint; + this.size = size; } public int getSize() { @@ -147,7 +148,10 @@ public void grow(int count) { @Override public void visit(int docID) { - adder.add(docID); + if (docCount[0] <= size) { + adder.add(docID); + docCount[0]++; + } } @Override @@ -183,63 +187,6 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue }; } - /** - * Create a visitor that clears documents that do NOT match the range. - */ - private PointValues.IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) { - return new PointValues.IntersectVisitor() { - @Override - public void visit(int docID) { - result.clear(docID); - cost[0]--; - } - - @Override - public void visit(DocIdSetIterator iterator) throws IOException { - result.andNot(iterator); - cost[0] = Math.max(0, cost[0] - iterator.cost()); - } - - @Override - public void visit(IntsRef ref) { - for (int i = ref.offset; i < ref.offset + ref.length; i++) { - result.clear(ref.ints[i]); - } - cost[0] -= ref.length; - } - - @Override - public void visit(int docID, byte[] packedValue) { - if (matches(packedValue) == false) { - visit(docID); - } - } - - @Override - public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { - if (matches(packedValue) == false) { - visit(iterator); - } - } - - @Override - public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - PointValues.Relation relation = relate(minPackedValue, maxPackedValue); - switch (relation) { - case CELL_INSIDE_QUERY: - // all points match, skip this subtree - return PointValues.Relation.CELL_OUTSIDE_QUERY; - case CELL_OUTSIDE_QUERY: - // none of the points match, clear all documents - return PointValues.Relation.CELL_INSIDE_QUERY; - case CELL_CROSSES_QUERY: - default: - return relation; - } - } - }; - } - private boolean checkValidPointValues(PointValues values) throws IOException { if (values == null) { // No docs in this segment/field indexed any points @@ -270,14 +217,13 @@ private boolean checkValidPointValues(PointValues values) throws IOException { } private void intersect(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, int count) throws IOException { - intersect(visitor, pointTree, count, 0); + intersect(visitor, pointTree, count); assert pointTree.moveToParent() == false; } - private long intersect(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, int count, long docCount) - throws IOException { + private long intersect(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, int count) throws IOException { PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - if (docCount >= count) { + if (docCount[0] >= count) { return 0; } switch (r) { @@ -294,14 +240,14 @@ private long intersect(PointValues.IntersectVisitor visitor, PointValues.PointTr // through and do full filtering: if (pointTree.moveToChild()) { do { - docCount += intersect(visitor, pointTree, count, docCount); - } while (pointTree.moveToSibling() && docCount <= count); + docCount[0] += intersect(visitor, pointTree, count); + } while (pointTree.moveToSibling() && docCount[0] <= count); pointTree.moveToParent(); } else { // TODO: we can assert that the first value here in fact matches what the pointTree // claimed? // Leaf node; scan and filter all points in this block: - if (docCount <= count) { + if (docCount[0] <= count) { pointTree.visitDocValues(visitor); } else break; } @@ -309,7 +255,8 @@ private long intersect(PointValues.IntersectVisitor visitor, PointValues.PointTr default: throw new IllegalArgumentException("Unreachable code"); } - return 0; + // docCount can be updated by the local visitor so we ensure that we return docCount after pointTree.visitDocValues(visitor) + return docCount[0] > 0 ? docCount[0] : 0; } @Override @@ -379,20 +326,6 @@ public long cost() { @Override public Scorer get(long leadCost) throws IOException { - if (values.getDocCount() == reader.maxDoc() - && values.getDocCount() == values.size() - && cost() > reader.maxDoc() / 2) { - // If all docs have exactly one value and the cost is greater - // than half the leaf size then maybe we can make things faster - // by computing the set of documents that do NOT match the range - final FixedBitSet result = new FixedBitSet(reader.maxDoc()); - result.set(0, reader.maxDoc()); - long[] cost = new long[] { reader.maxDoc() }; - intersect(values.getPointTree(), getInverseIntersectVisitor(result, cost), size); - final DocIdSetIterator iterator = new BitSetIterator(result, cost[0]); - return new ConstantScoreScorer(weight, score(), scoreMode, iterator); - } - intersect(values.getPointTree(), visitor, size); DocIdSetIterator iterator = result.build().iterator(); return new ConstantScoreScorer(weight, score(), scoreMode, iterator); diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java new file mode 100644 index 0000000000000..1fdc5c07be49f --- /dev/null +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java @@ -0,0 +1,171 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import com.carrotsearch.randomizedtesting.generators.RandomNumbers; + +import org.apache.lucene.analysis.core.WhitespaceAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.PointRangeQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.apache.lucene.document.LongPoint.pack; + +public class ApproximatePointRangeQueryTests extends OpenSearchTestCase { + + public void testApproximateRangeEqualsActualRange() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + for (int i = 0; i < 100; i++) { + int numPoints = RandomNumbers.randomIntBetween(random(), 1, 10); + Document doc = new Document(); + for (int j = 0; j < numPoints; j++) { + for (int v = 0; v < dims; v++) { + scratch[v] = RandomNumbers.randomLongBetween(random(), 0, 100); + } + doc.add(new LongPoint("point", scratch)); + } + iw.addDocument(doc); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = RandomNumbers.randomLongBetween(random(), -100, 200); + long upper = lower + RandomNumbers.randomLongBetween(random(), 0, 100); + Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + Query query = new PointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + TopDocs topDocs1 = searcher.search(query, 10); + assertEquals(topDocs.totalHits, topDocs1.totalHits); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeWithSize() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + for (int i = 0; i < 100; i++) { + int numPoints = RandomNumbers.randomIntBetween(random(), 1, 10); + Document doc = new Document(); + for (int j = 0; j < numPoints; j++) { + for (int v = 0; v < dims; v++) { + scratch[v] = RandomNumbers.randomLongBetween(random(), 0, 100); + } + doc.add(new LongPoint("point", scratch)); + } + iw.addDocument(doc); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = RandomNumbers.randomLongBetween(random(), -100, 200); + long upper = lower + RandomNumbers.randomLongBetween(random(), 0, 100); + Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims, 100) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + Query query = new PointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + TopDocs topDocs1 = searcher.search(query, 10); + assertEquals(topDocs.totalHits, topDocs1.totalHits); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeShortCircuit() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + int numPoints = 1000; + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + for (int v = 0; v < dims; v++) { + scratch[v] = i; + } + doc.add(new LongPoint("point", scratch)); + iw.addDocument(doc); + if (i % 10 == 0) iw.flush(); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = 0; + long upper = 100; + Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims, 10) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + Query query = new PointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + TopDocs topDocs1 = searcher.search(query, 10); + + // since we short-circuit from the approx range at the end of size these will not be equal + assertNotEquals(topDocs.totalHits, topDocs1.totalHits); + assertEquals(topDocs.totalHits, new TotalHits(11, TotalHits.Relation.EQUAL_TO)); + assertEquals(topDocs1.totalHits, new TotalHits(101, TotalHits.Relation.EQUAL_TO)); + + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + +} diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximateableQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximateableQueryTests.java index 90f71f9ff7d3d..15e1c764c206e 100644 --- a/server/src/test/java/org/opensearch/search/approximate/ApproximateableQueryTests.java +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximateableQueryTests.java @@ -78,6 +78,5 @@ protected String toString(int dimension, byte[] value) { } } } - } }