diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index cee8bc43d7ffd..8076da6ab970b 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -146,10 +146,11 @@ public void grow(int count) { public void visit(int docID) { // it is possible that size < 1024 and docCount < size but we will continue to count through all the 1024 docs // and collect less, but it won't hurt performance - if (docCount[0] < size) { - adder.add(docID); - docCount[0]++; + if (docCount[0] >= size) { + return; } + adder.add(docID); + docCount[0]++; } @Override @@ -231,7 +232,7 @@ private void intersectRight(PointValues.PointTree pointTree, PointValues.Interse public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) throws IOException { PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - if (docCount[0] > size) { + if (docCount[0] >= size) { return; } switch (r) { @@ -279,7 +280,7 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) throws IOException { PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - if (docCount[0] > size) { + if (docCount[0] >= size) { return; } switch (r) { @@ -435,6 +436,10 @@ public boolean canApproximate(SearchContext context) { if (!(context.query() instanceof ApproximateIndexOrDocValuesQuery)) { return false; } + // size 0 could be set for caching + if (context.from() + context.size() == 0) { + this.setSize(10_000); + } this.setSize(Math.max(context.from() + context.size(), context.trackTotalHitsUpTo())); if (context.request() != null && context.request().source() != null) { FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(context.request().source()); diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java index dd683d28f00f7..4b41079de18ef 100644 --- a/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java @@ -126,9 +126,9 @@ public void testApproximateRangeWithSizeUnderDefault() throws IOException { } doc.add(new LongPoint("point", scratch)); iw.addDocument(doc); - if (i % 15 == 0) iw.flush(); } iw.flush(); + iw.forceMerge(1); try (IndexReader reader = iw.getReader()) { try { long lower = 0; @@ -166,6 +166,7 @@ public void testApproximateRangeWithSizeOverDefault() throws IOException { iw.addDocument(doc); } iw.flush(); + iw.forceMerge(1); try (IndexReader reader = iw.getReader()) { try { long lower = 0; @@ -183,7 +184,7 @@ protected String toString(int dimension, byte[] value) { }; IndexSearcher searcher = new IndexSearcher(reader); TopDocs topDocs = searcher.search(approximateQuery, 11000); - assertEquals(topDocs.totalHits, new TotalHits(11001, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)); + assertEquals(topDocs.totalHits, new TotalHits(11000, TotalHits.Relation.EQUAL_TO)); } catch (IOException e) { throw new RuntimeException(e); } @@ -210,6 +211,7 @@ public void testApproximateRangeShortCircuit() throws IOException { if (i % 10 == 0) iw.flush(); } iw.flush(); + iw.forceMerge(1); try (IndexReader reader = iw.getReader()) { try { long lower = 0; @@ -255,6 +257,7 @@ public void testApproximateRangeShortCircuitAscSort() throws IOException { iw.addDocument(doc); } iw.flush(); + iw.forceMerge(1); try (IndexReader reader = iw.getReader()) { try { long lower = 0; @@ -281,12 +284,12 @@ protected String toString(int dimension, byte[] value) { assertNotEquals(topDocs.totalHits, topDocs1.totalHits); assertEquals(topDocs.totalHits, new TotalHits(10, TotalHits.Relation.EQUAL_TO)); assertEquals(topDocs1.totalHits, new TotalHits(21, TotalHits.Relation.EQUAL_TO)); - assertEquals(topDocs.scoreDocs[0].doc, 0); - assertEquals(topDocs.scoreDocs[1].doc, 1); - assertEquals(topDocs.scoreDocs[2].doc, 2); - assertEquals(topDocs.scoreDocs[3].doc, 3); - assertEquals(topDocs.scoreDocs[4].doc, 4); - assertEquals(topDocs.scoreDocs[5].doc, 5); + assertEquals(topDocs.scoreDocs[0].doc, topDocs1.scoreDocs[0].doc); + assertEquals(topDocs.scoreDocs[1].doc, topDocs1.scoreDocs[1].doc); + assertEquals(topDocs.scoreDocs[2].doc, topDocs1.scoreDocs[2].doc); + assertEquals(topDocs.scoreDocs[3].doc, topDocs1.scoreDocs[3].doc); + assertEquals(topDocs.scoreDocs[4].doc, topDocs1.scoreDocs[4].doc); + assertEquals(topDocs.scoreDocs[5].doc, topDocs1.scoreDocs[5].doc); } catch (IOException e) { throw new RuntimeException(e);