Skip to content

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
Signed-off-by: Harsha Vamsi Kalluri <[email protected]>
  • Loading branch information
harshavamsi committed Jun 10, 2024
1 parent e92bd90 commit 5728c31
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IntsRef;

import java.io.IOException;
Expand All @@ -44,6 +42,8 @@ public abstract class ApproximatePointRangeQuery extends Query {

private int size;

private long[] docCount = { 0 };

protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) {
this(field, lowerPoint, upperPoint, numDims, 10_000);
}
Expand All @@ -70,6 +70,7 @@ protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upp

this.lowerPoint = lowerPoint;
this.upperPoint = upperPoint;
this.size = size;
}

public int getSize() {
Expand Down Expand Up @@ -147,7 +148,10 @@ public void grow(int count) {

@Override
public void visit(int docID) {
adder.add(docID);
if (docCount[0] <= size) {
adder.add(docID);
docCount[0]++;
}
}

@Override
Expand Down Expand Up @@ -183,63 +187,6 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue
};
}

/**
 * Builds an "inverse" visitor: rather than collecting matching documents, it removes
 * every visited document from {@code result} and lowers the running estimate held in
 * {@code cost[0]} accordingly.
 *
 * @param result bit set from which visited documents are cleared
 * @param cost   single-element array whose first slot is decremented as documents are cleared
 * @return a visitor whose {@code compare} swaps the inside/outside relations so that the
 *         tree walk visits the documents lying outside the range
 */
private PointValues.IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) {
    return new PointValues.IntersectVisitor() {
        @Override
        public void visit(int docID) {
            result.clear(docID);
            cost[0]--;
        }

        @Override
        public void visit(DocIdSetIterator iterator) throws IOException {
            result.andNot(iterator);
            // The estimate is clamped so it never drops below zero.
            cost[0] = Math.max(0, cost[0] - iterator.cost());
        }

        @Override
        public void visit(IntsRef ref) {
            final int end = ref.offset + ref.length;
            int idx = ref.offset;
            while (idx < end) {
                result.clear(ref.ints[idx]);
                idx++;
            }
            cost[0] -= ref.length;
        }

        @Override
        public void visit(int docID, byte[] packedValue) {
            if (matches(packedValue)) {
                // Value is inside the range: the document stays in the result.
                return;
            }
            visit(docID);
        }

        @Override
        public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
            if (matches(packedValue)) {
                // Value is inside the range: these documents stay in the result.
                return;
            }
            visit(iterator);
        }

        @Override
        public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
            final PointValues.Relation actual = relate(minPackedValue, maxPackedValue);
            switch (actual) {
                case CELL_INSIDE_QUERY:
                    // Every point in the cell matches, so nothing has to be cleared: skip the subtree.
                    return PointValues.Relation.CELL_OUTSIDE_QUERY;
                case CELL_OUTSIDE_QUERY:
                    // No point in the cell matches, so every document in it gets cleared.
                    return PointValues.Relation.CELL_INSIDE_QUERY;
                default:
                    // CELL_CROSSES_QUERY (and anything else) is passed through unchanged.
                    return actual;
            }
        }
    };
}

private boolean checkValidPointValues(PointValues values) throws IOException {
if (values == null) {
// No docs in this segment/field indexed any points
Expand Down Expand Up @@ -270,14 +217,13 @@ private boolean checkValidPointValues(PointValues values) throws IOException {
}

private void intersect(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, int count) throws IOException {
intersect(visitor, pointTree, count, 0);
intersect(visitor, pointTree, count);
assert pointTree.moveToParent() == false;
}

private long intersect(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, int count, long docCount)
throws IOException {
private long intersect(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, int count) throws IOException {
PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
if (docCount >= count) {
if (docCount[0] >= count) {
return 0;
}
switch (r) {
Expand All @@ -294,22 +240,23 @@ private long intersect(PointValues.IntersectVisitor visitor, PointValues.PointTr
// through and do full filtering:
if (pointTree.moveToChild()) {
do {
docCount += intersect(visitor, pointTree, count, docCount);
} while (pointTree.moveToSibling() && docCount <= count);
docCount[0] += intersect(visitor, pointTree, count);
} while (pointTree.moveToSibling() && docCount[0] <= count);
pointTree.moveToParent();
} else {
// TODO: we can assert that the first value here in fact matches what the pointTree
// claimed?
// Leaf node; scan and filter all points in this block:
if (docCount <= count) {
if (docCount[0] <= count) {
pointTree.visitDocValues(visitor);
} else break;
}
break;
default:
throw new IllegalArgumentException("Unreachable code");
}
return 0;
// docCount can be updated by the local visitor so we ensure that we return docCount after pointTree.visitDocValues(visitor)
return docCount[0] > 0 ? docCount[0] : 0;
}

@Override
Expand Down Expand Up @@ -379,20 +326,6 @@ public long cost() {

@Override
public Scorer get(long leadCost) throws IOException {
if (values.getDocCount() == reader.maxDoc()
&& values.getDocCount() == values.size()
&& cost() > reader.maxDoc() / 2) {
// If all docs have exactly one value and the cost is greater
// than half the leaf size then maybe we can make things faster
// by computing the set of documents that do NOT match the range
final FixedBitSet result = new FixedBitSet(reader.maxDoc());
result.set(0, reader.maxDoc());
long[] cost = new long[] { reader.maxDoc() };
intersect(values.getPointTree(), getInverseIntersectVisitor(result, cost), size);
final DocIdSetIterator iterator = new BitSetIterator(result, cost[0]);
return new ConstantScoreScorer(weight, score(), scoreMode, iterator);
}

intersect(values.getPointTree(), visitor, size);
DocIdSetIterator iterator = result.build().iterator();
return new ConstantScoreScorer(weight, score(), scoreMode, iterator);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.approximate;

import com.carrotsearch.randomizedtesting.generators.RandomNumbers;

import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;

import static org.apache.lucene.document.LongPoint.pack;

/**
 * Tests for {@link ApproximatePointRangeQuery}. Each test indexes single-dimension
 * {@code long} points, runs the approximate query and Lucene's exact
 * {@link PointRangeQuery} over the same bounds, and compares the reported total hits.
 */
public class ApproximatePointRangeQueryTests extends OpenSearchTestCase {

    /** Name of the point field indexed and queried by every test. */
    private static final String FIELD = "point";

    /** All tests index and query one-dimensional points. */
    private static final int DIMS = 1;

    /** Number of top hits requested from each search. */
    private static final int TOP_N = 10;

    /**
     * With the constructor that takes no explicit size, the approximate query must report
     * the same total hits as the exact query on a 100-document index.
     */
    public void testApproximateRangeEqualsActualRange() throws IOException {
        try (
            Directory directory = newDirectory();
            RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())
        ) {
            indexRandomMultiValuedDocs(iw, 100);
            iw.flush();
            try (IndexReader reader = iw.getReader()) {
                long lower = RandomNumbers.randomLongBetween(random(), -100, 200);
                long upper = lower + RandomNumbers.randomLongBetween(random(), 0, 100);
                assertSameTotalHits(reader, approximateQuery(lower, upper), exactQuery(lower, upper));
            }
        }
    }

    /**
     * With an explicit size of 100 on a 100-document index, the approximate query must
     * report the same total hits as the exact query.
     */
    public void testApproximateRangeWithSize() throws IOException {
        try (
            Directory directory = newDirectory();
            RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())
        ) {
            indexRandomMultiValuedDocs(iw, 100);
            iw.flush();
            try (IndexReader reader = iw.getReader()) {
                long lower = RandomNumbers.randomLongBetween(random(), -100, 200);
                long upper = lower + RandomNumbers.randomLongBetween(random(), 0, 100);
                assertSameTotalHits(reader, approximateQuery(lower, upper, 100), exactQuery(lower, upper));
            }
        }
    }

    /**
     * With a size smaller than the number of matching documents, the approximate query
     * stops early and therefore reports fewer hits than the exact query.
     */
    public void testApproximateRangeShortCircuit() throws IOException {
        try (
            Directory directory = newDirectory();
            RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())
        ) {
            int numPoints = 1000;
            for (int i = 0; i < numPoints; i++) {
                Document doc = new Document();
                doc.add(new LongPoint(FIELD, (long) i));
                iw.addDocument(doc);
                // Flush every tenth document, as the original test did.
                if (i % 10 == 0) {
                    iw.flush();
                }
            }
            iw.flush();
            try (IndexReader reader = iw.getReader()) {
                long lower = 0;
                long upper = 100;
                IndexSearcher searcher = new IndexSearcher(reader);
                TopDocs approximate = searcher.search(approximateQuery(lower, upper, 10), TOP_N);
                TopDocs exact = searcher.search(exactQuery(lower, upper), TOP_N);

                // since we short-circuit from the approx range at the end of size these will not be equal
                assertNotEquals(exact.totalHits, approximate.totalHits);
                // NOTE(review): 11 hits for a size of 10 presumably reflects an inclusive size check in the
                // approximate query's visitor - confirm against ApproximatePointRangeQuery.
                assertEquals(new TotalHits(11, TotalHits.Relation.EQUAL_TO), approximate.totalHits);
                assertEquals(new TotalHits(101, TotalHits.Relation.EQUAL_TO), exact.totalHits);
            }
        }
    }

    /**
     * Indexes {@code numDocs} documents, each carrying between 1 and 10 random point values
     * in the range [0, 100].
     */
    private static void indexRandomMultiValuedDocs(RandomIndexWriter iw, int numDocs) throws IOException {
        for (int i = 0; i < numDocs; i++) {
            int numPoints = RandomNumbers.randomIntBetween(random(), 1, 10);
            Document doc = new Document();
            for (int j = 0; j < numPoints; j++) {
                doc.add(new LongPoint(FIELD, RandomNumbers.randomLongBetween(random(), 0, 100)));
            }
            iw.addDocument(doc);
        }
    }

    /**
     * Runs both queries against {@code reader} and asserts that the approximate query reports
     * the same total hits as the exact one (the exact query is the expected value).
     */
    private static void assertSameTotalHits(IndexReader reader, Query approximateQuery, Query exactQuery) throws IOException {
        IndexSearcher searcher = new IndexSearcher(reader);
        TopDocs approximate = searcher.search(approximateQuery, TOP_N);
        TopDocs exact = searcher.search(exactQuery, TOP_N);
        assertEquals(exact.totalHits, approximate.totalHits);
    }

    /** Approximate range query over {@code [lower, upper]} built with the constructor that takes no size. */
    private static Query approximateQuery(long lower, long upper) {
        return new ApproximatePointRangeQuery(FIELD, pack(lower).bytes, pack(upper).bytes, DIMS) {
            @Override
            protected String toString(int dimension, byte[] value) {
                return Long.toString(LongPoint.decodeDimension(value, 0));
            }
        };
    }

    /** Approximate range query over {@code [lower, upper]} built with an explicit {@code size}. */
    private static Query approximateQuery(long lower, long upper, int size) {
        return new ApproximatePointRangeQuery(FIELD, pack(lower).bytes, pack(upper).bytes, DIMS, size) {
            @Override
            protected String toString(int dimension, byte[] value) {
                return Long.toString(LongPoint.decodeDimension(value, 0));
            }
        };
    }

    /** Lucene's exact range query over {@code [lower, upper]}, used as the reference result. */
    private static Query exactQuery(long lower, long upper) {
        return new PointRangeQuery(FIELD, pack(lower).bytes, pack(upper).bytes, DIMS) {
            @Override
            protected String toString(int dimension, byte[] value) {
                return Long.toString(LongPoint.decodeDimension(value, 0));
            }
        };
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,5 @@ protected String toString(int dimension, byte[] value) {
}
}
}

}
}

0 comments on commit 5728c31

Please sign in to comment.