From 2a6d7bbe9c367248deeea31cad77960ff65032e9 Mon Sep 17 00:00:00 2001 From: Harsha Vamsi Kalluri Date: Wed, 21 Aug 2024 15:07:21 -0700 Subject: [PATCH] Adding approximate point range query Signed-off-by: Harsha Vamsi Kalluri --- CHANGELOG.md | 41 +- .../test/search/370_approximate_range.yml | 72 ++ .../index/mapper/DateFieldMapper.java | 50 +- .../ApproximatePointRangeQuery.java | 622 ++++++++---------- .../approximate/ApproximateScoreQuery.java | 155 +++++ .../approximate/ApproximateableQuery.java | 147 +---- .../search/approximate/package-info.java | 12 + .../search/internal/ContextIndexSearcher.java | 33 +- .../index/mapper/DateFieldTypeTests.java | 101 ++- ...angeFieldQueryStringQueryBuilderTests.java | 29 +- .../index/mapper/RangeFieldTypeTests.java | 41 ++ .../query/QueryStringQueryBuilderTests.java | 28 +- .../index/query/RangeQueryBuilderTests.java | 142 +++- .../ApproximatePointRangeQueryTests.java | 306 +++++++++ .../ApproximateScoreQueryTests.java | 82 +++ 15 files changed, 1254 insertions(+), 607 deletions(-) create mode 100644 rest-api-spec/src/main/resources/rest-api-spec/test/search/370_approximate_range.yml create mode 100644 server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java create mode 100644 server/src/main/java/org/opensearch/search/approximate/package-info.java create mode 100644 server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java create mode 100644 server/src/test/java/org/opensearch/search/approximate/ApproximateScoreQueryTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bff49af99473..b0134da787aa2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,35 +5,18 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x] ### Added -- Constant Keyword Field ([#12285](https://github.com/opensearch-project/OpenSearch/pull/12285)) -- Convert ingest processor supports ip type ([#12818](https://github.com/opensearch-project/OpenSearch/pull/12818)) -- Add a counter to node stat api to track shard going from idle to non-idle ([#12768](https://github.com/opensearch-project/OpenSearch/pull/12768)) -- Allow setting KEYSTORE_PASSWORD through env variable ([#12865](https://github.com/opensearch-project/OpenSearch/pull/12865)) -- [Concurrent Segment Search] Perform buildAggregation concurrently and support Composite Aggregations ([#12697](https://github.com/opensearch-project/OpenSearch/pull/12697)) -- [Concurrent Segment Search] Disable concurrent segment search for system indices and throttled requests ([#12954](https://github.com/opensearch-project/OpenSearch/pull/12954)) -- Rename ingest processor supports overriding target field if exists ([#12990](https://github.com/opensearch-project/OpenSearch/pull/12990)) -- [Tiered Caching] Make took time caching policy setting dynamic ([#13063](https://github.com/opensearch-project/OpenSearch/pull/13063)) -- Derived fields support to derive field values at query time without indexing ([#12569](https://github.com/opensearch-project/OpenSearch/pull/12569)) -- Detect breaking changes on pull requests ([#9044](https://github.com/opensearch-project/OpenSearch/pull/9044)) -- Add cluster primary balance contraint for rebalancing with buffer ([#12656](https://github.com/opensearch-project/OpenSearch/pull/12656)) -- [Remote Store] Make translog transfer timeout configurable ([#12704](https://github.com/opensearch-project/OpenSearch/pull/12704)) -- Reject Resize index requests (i.e, split, shrink and clone), While DocRep to 
SegRep migration is in progress.([#12686](https://github.com/opensearch-project/OpenSearch/pull/12686)) -- Add support for more than one protocol for transport ([#12967](https://github.com/opensearch-project/OpenSearch/pull/12967)) -- [Tiered Caching] Add dimension-based stats to ICache implementations. ([#12531](https://github.com/opensearch-project/OpenSearch/pull/12531)) -- Add changes for overriding remote store and replication settings during snapshot restore. ([#11868](https://github.com/opensearch-project/OpenSearch/pull/11868)) -- Add an individual setting of rate limiter for segment replication ([#12959](https://github.com/opensearch-project/OpenSearch/pull/12959)) -- [Tiered Caching] Expose new cache stats API ([#13237](https://github.com/opensearch-project/OpenSearch/pull/13237)) -- [Streaming Indexing] Ensure support of the new transport by security plugin ([#13174](https://github.com/opensearch-project/OpenSearch/pull/13174)) -- Add cluster setting to dynamically configure the buckets for filter rewrite optimization. ([#13179](https://github.com/opensearch-project/OpenSearch/pull/13179)) -- [Tiered Caching] Gate new stats logic behind FeatureFlags.PLUGGABLE_CACHE ([#13238](https://github.com/opensearch-project/OpenSearch/pull/13238)) -- [Tiered Caching] Add a dynamic setting to disable/enable disk cache. ([#13373](https://github.com/opensearch-project/OpenSearch/pull/13373)) -- [Remote Store] Add capability of doing refresh as determined by the translog ([#12992](https://github.com/opensearch-project/OpenSearch/pull/12992)) -- [Batch Ingestion] Add `batch_size` to `_bulk` API. ([#12457](https://github.com/opensearch-project/OpenSearch/issues/12457)) -- [Tiered caching] Make Indices Request Cache Stale Key Mgmt Threshold setting dynamic ([#12941](https://github.com/opensearch-project/OpenSearch/pull/12941)) -- Batch mode for async fetching shard information in GatewayAllocator for unassigned shards ([#8746](https://github.com/opensearch-project/OpenSearch/pull/8746)) -- [Remote Store] Add settings for remote path type and hash algorithm ([#13225](https://github.com/opensearch-project/OpenSearch/pull/13225)) -- [Remote Store] Upload remote paths during remote enabled index creation ([#13386](https://github.com/opensearch-project/OpenSearch/pull/13386)) -- [Search Pipeline] Handle default pipeline for multiple indices ([#13276](https://github.com/opensearch-project/OpenSearch/pull/13276)) +- Add fingerprint ingest processor ([#13724](https://github.com/opensearch-project/OpenSearch/pull/13724)) +- [Remote Store] Rate limiter for remote store low priority uploads ([#14374](https://github.com/opensearch-project/OpenSearch/pull/14374/)) +- Apply the date histogram rewrite optimization to range aggregation ([#13865](https://github.com/opensearch-project/OpenSearch/pull/13865)) +- [Writable Warm] Add composite directory implementation and integrate it with FileCache ([12782](https://github.com/opensearch-project/OpenSearch/pull/12782)) +- [Workload Management] Add QueryGroup schema ([13669](https://github.com/opensearch-project/OpenSearch/pull/13669)) +- Add batching supported processor base type AbstractBatchingProcessor ([#14554](https://github.com/opensearch-project/OpenSearch/pull/14554)) +- Fix race condition while parsing derived fields from search definition ([14445](https://github.com/opensearch-project/OpenSearch/pull/14445)) +- Add `strict_allow_templates` dynamic mapping option ([#14555](https://github.com/opensearch-project/OpenSearch/pull/14555)) +- Add allowlist setting 
for ingest-common and search-pipeline-common processors ([#14439](https://github.com/opensearch-project/OpenSearch/issues/14439)) +- Create SystemIndexRegistry with helper method matchesSystemIndex ([#14415](https://github.com/opensearch-project/OpenSearch/pull/14415)) +- Print reason why parent task was cancelled ([#14604](https://github.com/opensearch-project/OpenSearch/issues/14604)) +- [Range Queries] Add new approximateable query framework to short-circuit range queries ([#13788](https://github.com/opensearch-project/OpenSearch/pull/13788)) ### Dependencies - Bump `org.apache.commons:commons-configuration2` from 2.10.0 to 2.10.1 ([#12896](https://github.com/opensearch-project/OpenSearch/pull/12896)) diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search/370_approximate_range.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search/370_approximate_range.yml new file mode 100644 index 0000000000000..ba896dfcad506 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search/370_approximate_range.yml @@ -0,0 +1,72 @@ +--- +"search with approximate range": + - do: + indices.create: + index: test + body: + mappings: + properties: + date: + type: date + index: true + doc_values: true + + - do: + bulk: + index: test + refresh: true + body: + - '{"index": {"_index": "test", "_id": "1" }}' + - '{ "date": "2018-10-29T12:12:12.987Z" }' + - '{ "index": { "_index": "test", "_id": "2" }}' + - '{ "date": "2020-10-29T12:12:12.987Z" }' + - '{ "index": { "_index": "test", "_id": "3" } }' + - '{ "date": "2024-10-29T12:12:12.987Z" }' + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + query: + range: { + date: { + gte: "2018-10-29T12:12:12.987Z" + }, + } + + - match: { hits.total: 3 } + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + sort: [{ date: asc }] + query: + range: { + date: { + gte: "2018-10-29T12:12:12.987Z" + }, + } + + + - match: { hits.total: 3 } + - match: { hits.hits.0._id: "1" } + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + sort: [{ date: desc }] + query: + range: { + date: { + gte: "2018-10-29T12:12:12.987Z", + lte: "2020-10-29T12:12:12.987Z" + }, + } + + - match: { hits.total: 2 } + - match: { hits.hits.0._id: "2" } diff --git a/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java index 7c60c9869770a..ba40c88475f7f 100644 --- a/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java @@ -63,7 +63,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.search.DocValueFormat; import org.opensearch.search.approximate.ApproximatePointRangeQuery; -import org.opensearch.search.approximate.ApproximateableQuery; +import org.opensearch.search.approximate.ApproximateScoreQuery; import org.opensearch.search.lookup.SearchLookup; import java.io.IOException; @@ -82,8 +82,8 @@ import java.util.function.LongSupplier; import java.util.function.Supplier; -import static org.apache.lucene.document.LongPoint.pack; import static org.opensearch.common.time.DateUtils.toLong; +import static org.apache.lucene.document.LongPoint.pack; /** * A {@link FieldMapper} for dates. 
@@ -461,32 +461,52 @@ public Query rangeQuery( @Nullable DateMathParser forcedDateParser, QueryShardContext context ) { - failIfNotIndexed(); + failIfNotIndexedAndNoDocValues(); if (relation == ShapeRelation.DISJOINT) { throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] does not support DISJOINT ranges"); } DateMathParser parser = forcedDateParser == null ? dateMathParser : forcedDateParser; return dateRangeQuery(lowerTerm, upperTerm, includeLower, includeUpper, timeZone, parser, context, resolution, (l, u) -> { - Query query = new ApproximateableQuery(new PointRangeQuery(name(), pack(new long[]{l}).bytes, pack(new long[]{u}).bytes, new long[]{l}.length) { + Query pointRangeQuery = isSearchable() ? createPointRangeQuery(l, u) : null; + Query dvQuery = hasDocValues() ? SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u) : null; + if (isSearchable() && hasDocValues()) { + Query query = new IndexOrDocValuesQuery(pointRangeQuery, dvQuery); + + if (context.indexSortedOnField(name())) { + query = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, query); + } + return query; + } + if (hasDocValues()) { + Query query = SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u); + if (context.indexSortedOnField(name())) { + query = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, query); + } + return query; + } + return pointRangeQuery; + }); + } + + private Query createPointRangeQuery(long l, long u) { + return new ApproximateScoreQuery( + new PointRangeQuery(name(), pack(new long[] { l }).bytes, pack(new long[] { u }).bytes, new long[] { l }.length) { protected String toString(int dimension, byte[] value) { return Long.toString(LongPoint.decodeDimension(value, 0)); } - }, new ApproximatePointRangeQuery(name(), pack(new long[]{l}).bytes, pack(new long[]{u}).bytes, new long[]{l}.length) { + }, + new ApproximatePointRangeQuery( + name(), + pack(new long[] { l }).bytes, + pack(new long[] { u }).bytes, + new long[] { l }.length + ) { @Override protected String toString(int dimension, byte[] value) { return Long.toString(LongPoint.decodeDimension(value, 0)); } - }); - if (hasDocValues()) { - Query dvQuery = SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u); - query = new IndexOrDocValuesQuery(query, dvQuery); - - if (context.indexSortedOnField(name())) { - query = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, query); - } } - return query; - }); + ); } public static Query dateRangeQuery( diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 4527a066c1efa..4870a378c34b1 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -14,99 +14,96 @@ import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexOrDocValuesQuery; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; +import org.apache.lucene.search.PointRangeQuery; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.apache.lucene.util.ArrayUtil; 
-import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.DocIdSetBuilder; -import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.IntsRef; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; import java.io.IOException; import java.util.Arrays; import java.util.Objects; -import java.util.function.BiFunction; -import java.util.function.Predicate; -import static org.apache.lucene.search.PointRangeQuery.checkArgs; +/** + * An approximate-able version of {@link PointRangeQuery}. It creates an instance of {@link PointRangeQuery} but short-circuits the intersect logic + * after {@code size} is hit + */ +public abstract class ApproximatePointRangeQuery extends ApproximateableQuery { + private int size; -public abstract class ApproximatePointRangeQuery extends Query { - final String field; - final int numDims; - final int bytesPerDim; - final byte[] lowerPoint; - final byte[] upperPoint; + private SortOrder sortOrder; - private int size; + private long[] docCount = { 0 }; + + private final PointRangeQuery pointRangeQuery; protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) { - this(field, lowerPoint, upperPoint, numDims, 10_000); + this(field, lowerPoint, upperPoint, numDims, 10_000, null); } - protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims, int size){ - checkArgs(field, lowerPoint, upperPoint); - this.field = field; - if (numDims <= 0) { - throw new IllegalArgumentException("numDims must be positive, got " + numDims); - } - if (lowerPoint.length == 0) { - throw new IllegalArgumentException("lowerPoint has length of zero"); - } - if (lowerPoint.length % numDims != 0) { - throw new IllegalArgumentException("lowerPoint is not a fixed multiple of numDims"); - } - if (lowerPoint.length != upperPoint.length) { - throw new IllegalArgumentException( - "lowerPoint has length=" - + lowerPoint.length - + " but upperPoint has different length=" - + upperPoint.length); - } - this.numDims = numDims; - this.bytesPerDim = lowerPoint.length / numDims; + protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims, int size) { + this(field, lowerPoint, upperPoint, numDims, size, null); + } - this.lowerPoint = lowerPoint; - this.upperPoint = upperPoint; + protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims, int size, SortOrder sortOrder) { + this.size = size; + this.sortOrder = sortOrder; + this.pointRangeQuery = new PointRangeQuery(field, lowerPoint, upperPoint, numDims) { + @Override + protected String toString(int dimension, byte[] value) { + return super.toString(field); + } + }; } - public int getSize(){ + public int getSize() { return this.size; } - public void setSize(int size){ + public void setSize(int size) { this.size = size; } + public SortOrder getSortOrder() { + return this.sortOrder; + } + + public void setSortOrder(SortOrder sortOrder) { + this.sortOrder = sortOrder; + } + @Override public void visit(QueryVisitor visitor) { - if (visitor.acceptField(field)) { - visitor.visitLeaf(this); - } + pointRangeQuery.visit(visitor); } @Override - public final Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - - // We don't use RandomAccessWeight here: it's no good to approximate 
with "match all docs". - // This is an inverted structure and should be used in the first pass: + public final ConstantScoreWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + Weight pointRangeQueryWeight = pointRangeQuery.createWeight(searcher, scoreMode, boost); return new ConstantScoreWeight(this, boost) { - private final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(bytesPerDim); + private final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(pointRangeQuery.getBytesPerDim()); + // we pull this from PointRangeQuery since it is final private boolean matches(byte[] packedValue) { - for (int dim = 0; dim < numDims; dim++) { - int offset = dim * bytesPerDim; - if (comparator.compare(packedValue, offset, lowerPoint, offset) < 0) { + for (int dim = 0; dim < pointRangeQuery.getNumDims(); dim++) { + int offset = dim * pointRangeQuery.getBytesPerDim(); + if (comparator.compare(packedValue, offset, pointRangeQuery.getLowerPoint(), offset) < 0) { // Doc's value is too low, in this dimension return false; } - if (comparator.compare(packedValue, offset, upperPoint, offset) > 0) { + if (comparator.compare(packedValue, offset, pointRangeQuery.getUpperPoint(), offset) > 0) { // Doc's value is too high, in this dimension return false; } @@ -114,21 +111,21 @@ private boolean matches(byte[] packedValue) { return true; } + // we pull this from PointRangeQuery since it is final private PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { boolean crosses = false; - for (int dim = 0; dim < numDims; dim++) { - int offset = dim * bytesPerDim; + for (int dim = 0; dim < pointRangeQuery.getNumDims(); dim++) { + int offset = dim * pointRangeQuery.getBytesPerDim(); - if (comparator.compare(minPackedValue, offset, upperPoint, offset) > 0 - || comparator.compare(maxPackedValue, offset, lowerPoint, offset) < 0) { + if (comparator.compare(minPackedValue, offset, pointRangeQuery.getUpperPoint(), offset) > 0 + || comparator.compare(maxPackedValue, offset, pointRangeQuery.getLowerPoint(), offset) < 0) { return PointValues.Relation.CELL_OUTSIDE_QUERY; } - crosses |= - comparator.compare(minPackedValue, offset, lowerPoint, offset) < 0 - || comparator.compare(maxPackedValue, offset, upperPoint, offset) > 0; + crosses |= comparator.compare(minPackedValue, offset, pointRangeQuery.getLowerPoint(), offset) < 0 + || comparator.compare(maxPackedValue, offset, pointRangeQuery.getUpperPoint(), offset) > 0; } if (crosses) { @@ -138,7 +135,7 @@ private PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue } } - private PointValues.IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) { + public PointValues.IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) { return new PointValues.IntersectVisitor() { DocIdSetBuilder.BulkAdder adder; @@ -150,7 +147,12 @@ public void grow(int count) { @Override public void visit(int docID) { - adder.add(docID); + // it is possible that size < 1024 and docCount < size but we will continue to count through all the 1024 docs + // and collect less, but it won't hurt performance + if (docCount[0] < size) { + adder.add(docID); + docCount[0]++; + } } @Override @@ -160,8 +162,8 @@ public void visit(DocIdSetIterator iterator) throws IOException { @Override public void visit(IntsRef ref) { - for (int i = ref.offset; i < ref.offset + ref.length; i++) { - adder.add(ref.ints[i]); + for (int i = 0; i < ref.length; i++) { + adder.add(ref.ints[ref.offset + 
i]); } } @@ -186,230 +188,216 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue }; } - /** - * Create a visitor that clears documents that do NOT match the range. - */ - private PointValues.IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) { - return new PointValues.IntersectVisitor() { - @Override - public void visit(int docID) { - result.clear(docID); - cost[0]--; - } - - @Override - public void visit(DocIdSetIterator iterator) throws IOException { - result.andNot(iterator); - cost[0] = Math.max(0, cost[0] - iterator.cost()); - } - - @Override - public void visit(IntsRef ref) { - for (int i = ref.offset; i < ref.offset + ref.length; i++) { - result.clear(ref.ints[i]); - } - cost[0] -= ref.length; - } - - @Override - public void visit(int docID, byte[] packedValue) { - if (matches(packedValue) == false) { - visit(docID); - } - } - - @Override - public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { - if (matches(packedValue) == false) { - visit(iterator); - } - } - - @Override - public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - PointValues.Relation relation = relate(minPackedValue, maxPackedValue); - switch (relation) { - case CELL_INSIDE_QUERY: - // all points match, skip this subtree - return PointValues.Relation.CELL_OUTSIDE_QUERY; - case CELL_OUTSIDE_QUERY: - // none of the points match, clear all documents - return PointValues.Relation.CELL_INSIDE_QUERY; - case CELL_CROSSES_QUERY: - default: - return relation; - } - } - }; - } - + // we pull this from PointRangeQuery since it is final private boolean checkValidPointValues(PointValues values) throws IOException { if (values == null) { // No docs in this segment/field indexed any points return false; } - if (values.getNumIndexDimensions() != numDims) { + if (values.getNumIndexDimensions() != pointRangeQuery.getNumDims()) { throw new IllegalArgumentException( "field=\"" - + field + + pointRangeQuery.getField() + "\" was indexed with numIndexDimensions=" + values.getNumIndexDimensions() + " but this query has numDims=" - + numDims); + + pointRangeQuery.getNumDims() + ); } - if (bytesPerDim != values.getBytesPerDimension()) { + if (pointRangeQuery.getBytesPerDim() != values.getBytesPerDimension()) { throw new IllegalArgumentException( "field=\"" - + field + + pointRangeQuery.getField() + "\" was indexed with bytesPerDim=" + values.getBytesPerDimension() + " but this query has bytesPerDim=" - + bytesPerDim); + + pointRangeQuery.getBytesPerDim() + ); } return true; } - private void intersect(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, int count) throws IOException { - intersect(visitor, pointTree, count, 0); + private void intersectLeft(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor) throws IOException { + intersectLeft(visitor, pointTree); + assert pointTree.moveToParent() == false; + } + + private void intersectRight(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor) throws IOException { + intersectRight(visitor, pointTree); assert pointTree.moveToParent() == false; } - private long intersect(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, int count, long docCount) throws IOException { + // custom intersect visitor to walk the left of the tree + public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree) throws IOException { PointValues.Relation r = 
visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - if (docCount >= count) { - return 0; + if (docCount[0] > size) { + return; } switch (r) { case CELL_OUTSIDE_QUERY: // This cell is fully outside the query shape: stop recursing break; case CELL_INSIDE_QUERY: - // This cell is fully inside the query shape: recursively add all points in this cell - // without filtering - pointTree.visitDocIDs(visitor); - return pointTree.size(); + // If the cell is fully inside, we keep moving to child until we reach a point where we can no longer move or when + // we have sufficient doc count. We first move down and then move to the left child + if (pointTree.moveToChild() && docCount[0] < size) { + do { + intersectLeft(visitor, pointTree); + } while (pointTree.moveToSibling() && docCount[0] < size); + pointTree.moveToParent(); + } else { + // we're at the leaf node; if we're under the size, visit all the docIds in this node. + if (docCount[0] < size) { + pointTree.visitDocIDs(visitor); + } + } + break; case CELL_CROSSES_QUERY: // The cell crosses the shape boundary, or the cell fully contains the query, so we fall // through and do full filtering: - if (pointTree.moveToChild()) { + if (pointTree.moveToChild() && docCount[0] < size) { do { - docCount += intersect(visitor, pointTree, count, docCount); - } while (pointTree.moveToSibling() && docCount <= count); + intersectLeft(visitor, pointTree); + } while (pointTree.moveToSibling() && docCount[0] < size); pointTree.moveToParent(); } else { // TODO: we can assert that the first value here in fact matches what the pointTree // claimed? // Leaf node; scan and filter all points in this block: - if (docCount <= count) { + if (docCount[0] < size) { pointTree.visitDocValues(visitor); } - else break; } break; default: throw new IllegalArgumentException("Unreachable code"); } - return 0; + } + + // custom intersect visitor to walk the right of the tree + public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree) throws IOException { + PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + if (docCount[0] > size) { + return; + } + switch (r) { + case CELL_OUTSIDE_QUERY: + // This cell is fully outside the query shape: stop recursing + break; + + case CELL_INSIDE_QUERY: + // If the cell is fully inside, we keep moving right as long as the point tree size is over our size requirement + if (pointTree.size() > size && docCount[0] < size && moveRight(pointTree)) { + intersectRight(visitor, pointTree); + pointTree.moveToParent(); + } + // if the point tree size is no longer over our size requirement, we have to go back one level to where it still was over and then intersect left + else if (pointTree.size() <= size && docCount[0] < size) { + pointTree.moveToParent(); + intersectLeft(visitor, pointTree); + } + // if we've reached a leaf, it means our size is under the size of the leaf, so we can just collect all docIDs + else { + // Leaf node; scan and filter all points in this block: + if (docCount[0] < size) { + pointTree.visitDocIDs(visitor); + } + } + break; + case CELL_CROSSES_QUERY: + // Even when the cell crosses the query, we keep moving right as long as the point tree size is over our size requirement + if (pointTree.size() > size && docCount[0] < size && moveRight(pointTree)) { + intersectRight(visitor, pointTree); + pointTree.moveToParent(); + } + // if the point tree size is no longer over our size requirement, we have to go back one level to where it still was over and then intersect left + else if (pointTree.size() <= 
size && docCount[0] < size) { + pointTree.moveToParent(); + intersectLeft(visitor, pointTree); + } + // if we've reached a leaf, it means our size is under the size of the leaf, so we can just collect all doc values + else { + // Leaf node; scan and filter all points in this block: + if (docCount[0] < size) { + pointTree.visitDocValues(visitor); + } + } + break; + default: + throw new IllegalArgumentException("Unreachable code"); + } + } + + public boolean moveRight(PointValues.PointTree pointTree) throws IOException { + return pointTree.moveToChild() && pointTree.moveToSibling(); } @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { LeafReader reader = context.reader(); - PointValues values = (PointValues) reader.getPointValues(field); + PointValues values = reader.getPointValues(pointRangeQuery.getField()); if (checkValidPointValues(values) == false) { return null; } - - if (values.getDocCount() == 0) { - return null; - } else { - final byte[] fieldPackedLower = values.getMinPackedValue(); - final byte[] fieldPackedUpper = values.getMaxPackedValue(); - for (int i = 0; i < numDims; ++i) { - int offset = i * bytesPerDim; - if (comparator.compare(lowerPoint, offset, fieldPackedUpper, offset) > 0 - || comparator.compare(upperPoint, offset, fieldPackedLower, offset) < 0) { - // If this query is a required clause of a boolean query, then returning null here - // will help make sure that we don't call ScorerSupplier#get on other required clauses - // of the same boolean query, which is an expensive operation for some queries (e.g. - // multi-term queries). - return null; - } - } - } - - boolean allDocsMatch; - if (values.getDocCount() == reader.maxDoc()) { - final byte[] fieldPackedLower = values.getMinPackedValue(); - final byte[] fieldPackedUpper = values.getMaxPackedValue(); - allDocsMatch = true; - for (int i = 0; i < numDims; ++i) { - int offset = i * bytesPerDim; - if (comparator.compare(lowerPoint, offset, fieldPackedLower, offset) > 0 - || comparator.compare(upperPoint, offset, fieldPackedUpper, offset) < 0) { - allDocsMatch = false; - break; - } - } + final Weight weight = this; + if (size > values.size()) { + return pointRangeQueryWeight.scorerSupplier(context); } else { - allDocsMatch = false; - } + if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { + return new ScorerSupplier() { - final Weight weight = this; - if (allDocsMatch) { - // all docs have a value and all points are within bounds, so everything matches - return new ScorerSupplier() { - @Override - public Scorer get(long leadCost) { - return new ConstantScoreScorer( - weight, score(), scoreMode, DocIdSetIterator.all(reader.maxDoc())); - } + final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, pointRangeQuery.getField()); + final PointValues.IntersectVisitor visitor = getIntersectVisitor(result); + long cost = -1; - @Override - public long cost() { - return reader.maxDoc(); - } - }; - } else { - return new ScorerSupplier() { - - final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field); - final PointValues.IntersectVisitor visitor = getIntersectVisitor(result); - long cost = -1; - - @Override - public Scorer get(long leadCost) throws IOException { - if (values.getDocCount() == reader.maxDoc() - && values.getDocCount() == values.size() - && cost() > reader.maxDoc() / 2) { - // If all docs have exactly one value and the cost is greater - // than half the leaf size then maybe we can make things faster - // by computing the set of 
documents that do NOT match the range - final FixedBitSet result = new FixedBitSet(reader.maxDoc()); - result.set(0, reader.maxDoc()); - long[] cost = new long[]{reader.maxDoc()}; - intersect(values.getPointTree(), getInverseIntersectVisitor(result, cost), size); - final DocIdSetIterator iterator = new BitSetIterator(result, cost[0]); + @Override + public Scorer get(long leadCost) throws IOException { + intersectLeft(values.getPointTree(), visitor); + DocIdSetIterator iterator = result.build().iterator(); return new ConstantScoreScorer(weight, score(), scoreMode, iterator); } - intersect(values.getPointTree(), visitor, size); - DocIdSetIterator iterator = result.build().iterator(); - return new ConstantScoreScorer(weight, score(), scoreMode, iterator); - } + @Override + public long cost() { + if (cost == -1) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; + } + return cost; + } + }; + } else { + // we need to fetch size + deleted docs since the collector will prune away deleted docs resulting in fewer results + // than expected + final int deletedDocs = reader.numDeletedDocs(); + size += deletedDocs; + return new ScorerSupplier() { + + final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, pointRangeQuery.getField()); + final PointValues.IntersectVisitor visitor = getIntersectVisitor(result); + long cost = -1; + + @Override + public Scorer get(long leadCost) throws IOException { + intersectRight(values.getPointTree(), visitor); + DocIdSetIterator iterator = result.build().iterator(); + return new ConstantScoreScorer(weight, score(), scoreMode, iterator); + } - @Override - public long cost() { - if (cost == -1) { - // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; + @Override + public long cost() { + if (cost == -1) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; + } + return cost; } - return cost; - } - }; + }; + } } } @@ -424,109 +412,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { @Override public int count(LeafReaderContext context) throws IOException { - LeafReader reader = context.reader(); - - PointValues values = (PointValues) reader.getPointValues(field); - if (checkValidPointValues(values) == false) { - return 0; - } - - if (reader.hasDeletions() == false) { - if (relate(values.getMinPackedValue(), values.getMaxPackedValue()) - == PointValues.Relation.CELL_INSIDE_QUERY) { - return values.getDocCount(); - } - // only 1D: we have the guarantee that it will actually run fast since there are at most 2 - // crossing leaves. - // docCount == size : counting according number of points in leaf node, so must be - // single-valued. - if (numDims == 1 && values.getDocCount() == values.size()) { - return (int) pointCount((PointValues.PointTree) values.getPointTree(), this::relate, this::matches); - } - } - return super.count(context); - } - - /** - * Finds the number of points matching the provided range conditions. Using this method is - * faster than calling {@link PointValues#intersect(PointValues.IntersectVisitor)} to get the count of - * intersecting points. This method does not enforce live documents, therefore it should only - * be used when there are no deleted documents. 
- * - * @param pointTree start node of the count operation - * @param nodeComparator comparator to be used for checking whether the internal node is - * inside the range - * @param leafComparator comparator to be used for checking whether the leaf node is inside - * the range - * @return count of points that match the range - */ - private long pointCount( - PointValues.PointTree pointTree, - BiFunction nodeComparator, - Predicate leafComparator) - throws IOException { - final long[] matchingNodeCount = {0}; - // create a custom IntersectVisitor that records the number of leafNodes that matched - final PointValues.IntersectVisitor visitor = - new PointValues.IntersectVisitor() { - @Override - public void visit(int docID) { - // this branch should be unreachable - throw new UnsupportedOperationException( - "This IntersectVisitor does not perform any actions on a " - + "docID=" - + docID - + " node being visited"); - } - - @Override - public void visit(int docID, byte[] packedValue) { - if (leafComparator.test(packedValue)) { - matchingNodeCount[0]++; - } - } - - @Override - public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - return nodeComparator.apply(minPackedValue, maxPackedValue); - } - }; - pointCount(visitor, pointTree, matchingNodeCount); - return matchingNodeCount[0]; - } - - private void pointCount( - PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] matchingNodeCount) - throws IOException { - PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - switch (r) { - case CELL_OUTSIDE_QUERY: - // This cell is fully outside the query shape: return 0 as the count of its nodes - return; - case CELL_INSIDE_QUERY: - // This cell is fully inside the query shape: return the size of the entire node as the - // count - matchingNodeCount[0] += pointTree.size(); - return; - case CELL_CROSSES_QUERY: - /* - The cell crosses the shape boundary, or the cell fully contains the query, so we fall - through and do full counting. - */ - if (pointTree.moveToChild()) { - do { - pointCount(visitor, pointTree, matchingNodeCount); - } while (pointTree.moveToSibling()); - pointTree.moveToParent(); - } else { - // we have reached a leaf node here. 
- pointTree.visitDocValues(visitor); - // leaf node count is saved in the matchingNodeCount array by the visitor - } - return; - default: - throw new IllegalArgumentException("Unreachable code"); - } + return pointRangeQueryWeight.count(context); } @Override @@ -536,35 +422,35 @@ public boolean isCacheable(LeafReaderContext ctx) { }; } - public String getField() { - return field; - } - - public int getNumDims() { - return numDims; - } - - public int getBytesPerDim() { - return bytesPerDim; - } - - public byte[] getLowerPoint() { - return lowerPoint.clone(); - } - - public byte[] getUpperPoint() { - return upperPoint.clone(); + @Override + public boolean canApproximate(SearchContext context) { + if (context == null) { + return false; + } + if (!(context.query() instanceof IndexOrDocValuesQuery + && ((IndexOrDocValuesQuery) context.query()).getIndexQuery() instanceof ApproximateScoreQuery + && ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) context.query()).getIndexQuery()) + .getOriginalQuery() instanceof PointRangeQuery)) { + return false; + } + ApproximateScoreQuery query = ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) context.query()).getIndexQuery()); + ((ApproximatePointRangeQuery) query.getApproximationQuery()).setSize(Math.max(context.size(), context.trackTotalHitsUpTo())); + if (context.request() != null && context.request().source() != null) { + FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(context.request().source()); + if (primarySortField != null + && primarySortField.missing() == null + && primarySortField.getFieldName().equals(((RangeQueryBuilder) context.request().source().query()).fieldName())) { + if (primarySortField.order() == SortOrder.DESC) { + ((ApproximatePointRangeQuery) query.getApproximationQuery()).setSortOrder(SortOrder.DESC); + } + } + } + return true; } @Override public final int hashCode() { - int hash = classHash(); - hash = 31 * hash + field.hashCode(); - hash = 31 * hash + Arrays.hashCode(lowerPoint); - hash = 31 * hash + Arrays.hashCode(upperPoint); - hash = 31 * hash + numDims; - hash = 31 * hash + Objects.hashCode(bytesPerDim); - return hash; + return pointRangeQuery.hashCode(); } @Override @@ -573,42 +459,56 @@ public final boolean equals(Object o) { } private boolean equalsTo(ApproximatePointRangeQuery other) { - return Objects.equals(field, other.getField()) - && numDims == other.getNumDims() - && bytesPerDim == other.getBytesPerDim() - && Arrays.equals(lowerPoint, other.getLowerPoint()) - && Arrays.equals(upperPoint, other.getUpperPoint()); + return Objects.equals(pointRangeQuery.getField(), other.pointRangeQuery.getField()) + && pointRangeQuery.getNumDims() == other.pointRangeQuery.getNumDims() + && pointRangeQuery.getBytesPerDim() == other.pointRangeQuery.getBytesPerDim() + && Arrays.equals(pointRangeQuery.getLowerPoint(), other.pointRangeQuery.getLowerPoint()) + && Arrays.equals(pointRangeQuery.getUpperPoint(), other.pointRangeQuery.getUpperPoint()); } @Override public final String toString(String field) { final StringBuilder sb = new StringBuilder(); - if (this.field.equals(field) == false) { - sb.append(this.field); + if (pointRangeQuery.getField().equals(field) == false) { + sb.append(pointRangeQuery.getField()); sb.append(':'); } // print ourselves as "range per dimension" - for (int i = 0; i < numDims; i++) { + for (int i = 0; i < pointRangeQuery.getNumDims(); i++) { if (i > 0) { sb.append(','); } - int startOffset = bytesPerDim * i; + int startOffset = pointRangeQuery.getBytesPerDim() * i; 
sb.append('['); sb.append( toString( - i, ArrayUtil.copyOfSubArray(lowerPoint, startOffset, startOffset + bytesPerDim))); + i, + ArrayUtil.copyOfSubArray(pointRangeQuery.getLowerPoint(), startOffset, startOffset + pointRangeQuery.getBytesPerDim()) + ) + ); sb.append(" TO "); sb.append( toString( - i, ArrayUtil.copyOfSubArray(upperPoint, startOffset, startOffset + bytesPerDim))); + i, + ArrayUtil.copyOfSubArray(pointRangeQuery.getUpperPoint(), startOffset, startOffset + pointRangeQuery.getBytesPerDim()) + ) + ); sb.append(']'); } return sb.toString(); } + /** + * Returns a string of a single value in a human-readable format for debugging. This is used by + * {@link #toString()}. + * + * @param dimension dimension of the particular value + * @param value single value, never null + * @return human readable value for debugging + */ protected abstract String toString(int dimension, byte[] value); } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java new file mode 100644 index 0000000000000..fad46712136a8 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java @@ -0,0 +1,155 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Matches; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.Weight; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; + +/** + * Wrapper around a query that can be approximated. + * + * This class is heavily inspired by {@link org.apache.lucene.search.IndexOrDocValuesQuery}. It acts as a wrapper that consumes two queries, a regular query and an approximate version of the same. By default, it executes the regular query and returns {@link Weight#scorer} for the original query. At run-time, depending on certain constraints, we can re-write the {@code Weight} to use the approximate weight instead. 
+ */ +public final class ApproximateScoreQuery extends Query { + + private final Query originalQuery; + private final ApproximateableQuery approximationQuery; + + private Weight originalQueryWeight, approximationQueryWeight; + + private SearchContext context; + + public ApproximateScoreQuery(Query originalQuery, ApproximateableQuery approximationQuery) { + this(originalQuery, approximationQuery, null, null); + } + + public ApproximateScoreQuery( + Query originalQuery, + ApproximateableQuery approximationQuery, + Weight originalQueryWeight, + Weight approximationQueryWeight + ) { + this.originalQuery = originalQuery; + this.approximationQuery = approximationQuery; + this.originalQueryWeight = originalQueryWeight; + this.approximationQueryWeight = approximationQueryWeight; + } + + public Query getOriginalQuery() { + return originalQuery; + } + + public ApproximateableQuery getApproximationQuery() { + return approximationQuery; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + originalQueryWeight = originalQuery.createWeight(searcher, scoreMode, boost); + approximationQueryWeight = approximationQuery.createWeight(searcher, scoreMode, boost); + + return new Weight(this) { + @Override + public Explanation explain(LeafReaderContext leafReaderContext, int doc) throws IOException { + return originalQueryWeight.explain(leafReaderContext, doc); + } + + @Override + public Matches matches(LeafReaderContext leafReaderContext, int doc) throws IOException { + return originalQueryWeight.matches(leafReaderContext, doc); + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext leafReaderContext) throws IOException { + final ScorerSupplier originalQueryScoreSupplier = originalQueryWeight.scorerSupplier(leafReaderContext); + final ScorerSupplier approximationQueryScoreSupplier = approximationQueryWeight.scorerSupplier(leafReaderContext); + if (originalQueryScoreSupplier == null || approximationQueryScoreSupplier == null) { + return null; + } + + return new ScorerSupplier() { + @Override + public Scorer get(long l) throws IOException { + if (approximationQuery.canApproximate(context)) { + return approximationQueryScoreSupplier.get(l); + } + return originalQueryScoreSupplier.get(l); + } + + @Override + public long cost() { + return originalQueryScoreSupplier.cost(); + } + }; + } + + @Override + public Scorer scorer(LeafReaderContext leafReaderContext) throws IOException { + ScorerSupplier scorerSupplier = scorerSupplier(leafReaderContext); + if (scorerSupplier == null) { + return null; + } + return scorerSupplier.get(Long.MAX_VALUE); + } + + @Override + public boolean isCacheable(LeafReaderContext leafReaderContext) { + return originalQueryWeight.isCacheable(leafReaderContext); + } + }; + } + + public void setContext(SearchContext context) { + this.context = context; + }; + + @Override + public String toString(String s) { + return "ApproximateScoreQuery(originalQuery=" + + originalQuery.toString() + + ", approximationQuery=" + + approximationQuery.toString() + + ")"; + } + + @Override + public void visit(QueryVisitor queryVisitor) { + QueryVisitor v = queryVisitor.getSubVisitor(BooleanClause.Occur.MUST, this); + originalQuery.visit(v); + approximationQuery.visit(v); + } + + @Override + public boolean equals(Object o) { + if (!sameClassAs(o)) { + return false; + } + return true; + } + + @Override + public int hashCode() { + int h = classHash(); + h = 31 * h + originalQuery.hashCode(); + h = 31 * h + 
approximationQuery.hashCode(); + return h; + } +} diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateableQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateableQuery.java index 9a709161c3dee..d65b844bfe70f 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateableQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateableQuery.java @@ -8,151 +8,14 @@ package org.opensearch.search.approximate; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Matches; import org.apache.lucene.search.Query; -import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.ScorerSupplier; -import org.apache.lucene.search.Weight; - -import java.io.IOException; +import org.opensearch.search.internal.SearchContext; /** - * Base class for a query that can be approximated - */ -public final class ApproximateableQuery extends Query { - - private final Query originalQuery, approximationQuery; - - private Weight originalQueryWeight, approximationQueryWeight; - - public ApproximateableQuery(Query originalQuery, Query approximationQuery) { - this(originalQuery, approximationQuery, null, null); - } - - public ApproximateableQuery(Query originalQuery, Query approximationQuery, Weight originalQueryWeight, Weight approximationQueryWeight){ - this.originalQuery = originalQuery; - this.approximationQuery = approximationQuery; - this.originalQueryWeight = originalQueryWeight; - this.approximationQueryWeight = approximationQueryWeight; - } - - public void setOriginalQueryWeight(Weight originalQueryWeight){ - this.originalQueryWeight = originalQueryWeight; - } - - public void setApproximationQueryWeight(Weight approximationQueryWeight){ - this.approximationQueryWeight = approximationQueryWeight; - } - - public Weight getOriginalQueryWeight() { - return originalQueryWeight; - } - - public Weight getApproximationQueryWeight() { - return approximationQueryWeight; - } - - public Query getOriginalQuery() { - return originalQuery; - } - - public Query getApproximationQuery() { - return approximationQuery; - } - - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - final Weight originalQueryWeight = originalQuery.createWeight(searcher, scoreMode, boost); - setOriginalQueryWeight(originalQueryWeight); - final Weight approximationQueryWeight = approximationQuery.createWeight(searcher, scoreMode, boost); - setApproximationQueryWeight(approximationQueryWeight); - - return new Weight(this) { - @Override - public Explanation explain(LeafReaderContext leafReaderContext, int doc) throws IOException { - return originalQueryWeight.explain(leafReaderContext, doc); - } - - @Override - public Matches matches(LeafReaderContext leafReaderContext, int doc) throws IOException { - return originalQueryWeight.matches(leafReaderContext, doc); - } - - @Override - public ScorerSupplier scorerSupplier(LeafReaderContext leafReaderContext) throws IOException { - final ScorerSupplier originalQueryScoreSupplier = originalQueryWeight.scorerSupplier(leafReaderContext); - final ScorerSupplier approximationQueryScoreSupplier = approximationQueryWeight.scorerSupplier(leafReaderContext); - if (originalQueryScoreSupplier == 
null || approximationQueryScoreSupplier == null) { - return null; - } - - return new ScorerSupplier() { - @Override - public Scorer get(long l) throws IOException { - // TODO: we need to figure out how to compute the cost of running two different queries, by default return the original query's scoreSupplier - return originalQueryScoreSupplier.get(l); - } - - @Override - public long cost() { - return originalQueryScoreSupplier.cost(); - } - }; - } - - @Override - public Scorer scorer(LeafReaderContext leafReaderContext) throws IOException { - ScorerSupplier scorerSupplier = scorerSupplier(leafReaderContext); - if (scorerSupplier == null) { - return null; - } - return scorerSupplier.get(Long.MAX_VALUE); - } - - @Override - public boolean isCacheable(LeafReaderContext leafReaderContext) { - return originalQueryWeight.isCacheable(leafReaderContext); - } - }; - } - - - @Override - public String toString(String s) { - return "ApproximateableQuery(originalQuery=" - + originalQuery.toString() - + ", approximationQuery=" - + approximationQuery.toString() - + ")"; - } - - @Override - public void visit(QueryVisitor queryVisitor) { - QueryVisitor v = queryVisitor.getSubVisitor(BooleanClause.Occur.MUST, this); - originalQuery.visit(v); - approximationQuery.visit(v); - } + * Abstract class that can be inherited by queries that can be approximated. Queries should implement {@link #canApproximate(SearchContext)} to specify conditions on when they can be approximated +*/ +public abstract class ApproximateableQuery extends Query { - @Override - public boolean equals(Object o) { - if(!sameClassAs(o)){ - return false; - } - ApproximateableQuery that = (ApproximateableQuery) o; - return originalQuery.equals(that.originalQuery) && approximationQuery.equals(that.approximationQuery); - } + protected abstract boolean canApproximate(SearchContext context); - @Override - public int hashCode() { - int h = classHash(); - h = 31 * h + originalQuery.hashCode(); - h = 31 * h + approximationQuery.hashCode(); - return h; - } } diff --git a/server/src/main/java/org/opensearch/search/approximate/package-info.java b/server/src/main/java/org/opensearch/search/approximate/package-info.java new file mode 100644 index 0000000000000..1a09183c7d9fa --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +/** + * Approximation query framework to approximate commonly used queries + */ +package org.opensearch.search.approximate; diff --git a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java index 76165c6441e59..ccf9cc26299b8 100644 --- a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java +++ b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java @@ -34,7 +34,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.lucene.document.LongPoint; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; @@ -50,14 +49,12 @@ import org.apache.lucene.search.IndexOrDocValuesQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.LeafCollector; -import org.apache.lucene.search.PointRangeQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryCache; import org.apache.lucene.search.QueryCachingPolicy; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.TotalHits; @@ -71,12 +68,9 @@ import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; -import org.opensearch.index.mapper.DateFieldMapper; -import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchService; -import org.opensearch.search.approximate.ApproximatePointRangeQuery; -import org.opensearch.search.approximate.ApproximateableQuery; +import org.opensearch.search.approximate.ApproximateScoreQuery; import org.opensearch.search.dfs.AggregatedDfs; import org.opensearch.search.profile.ContextualProfileBreakdown; import org.opensearch.search.profile.Timer; @@ -329,11 +323,10 @@ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collecto // catch early terminated exception and rethrow? 
Bits liveDocs = ctx.reader().getLiveDocs(); BitSet liveDocsBitSet = getSparseBitSetOrNull(liveDocs); - if(isApproximateableRangeQuery()){ - ApproximateableQuery query = ((ApproximateableQuery) ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery()); - if (searchContext.size() > 10_000) - ((ApproximatePointRangeQuery) query.getApproximationQuery()).setSize(searchContext.size()); - weight = query.getApproximationQueryWeight(); + if (searchContext.query() instanceof IndexOrDocValuesQuery + && ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery() instanceof ApproximateScoreQuery) { + ApproximateScoreQuery query = ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery()); + query.setContext(searchContext); } if (liveDocsBitSet == null) { BulkScorer bulkScorer = weight.bulkScorer(ctx); @@ -426,22 +419,6 @@ private static BitSet getSparseBitSetOrNull(Bits liveDocs) { } - private boolean isApproximateableRangeQuery(){ - boolean isTopLevelRangeQuery = searchContext.query() instanceof IndexOrDocValuesQuery && - ((ApproximateableQuery) ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery()).getOriginalQuery() instanceof PointRangeQuery; - - boolean hasSort = false; - - if (searchContext.request() != null && searchContext.request().source() != null ) { - FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(searchContext.request().source()); - if (primarySortField != null && primarySortField.missing() == null && Objects.equals(searchContext.trackTotalHitsUpTo(), SearchContext.TRACK_TOTAL_HITS_DISABLED)) { - hasSort = true; - } - } - - return isTopLevelRangeQuery && !hasSort; - } - static void intersectScorerAndBitSet(Scorer scorer, BitSet acceptDocs, LeafCollector collector, Runnable checkCancelled) throws IOException { collector.setScorer(scorer); diff --git a/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java index ab53ae81ab0ce..47253c3403f92 100644 --- a/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java @@ -44,6 +44,7 @@ import org.apache.lucene.search.IndexOrDocValuesQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSortSortedNumericDocValuesRangeQuery; +import org.apache.lucene.search.PointRangeQuery; import org.apache.lucene.search.Query; import org.apache.lucene.store.Directory; import org.opensearch.Version; @@ -65,12 +66,16 @@ import org.opensearch.index.query.DateRangeIncludingNowQuery; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.search.approximate.ApproximatePointRangeQuery; +import org.opensearch.search.approximate.ApproximateScoreQuery; import org.joda.time.DateTimeZone; import java.io.IOException; import java.time.ZoneOffset; import java.util.Collections; +import static org.apache.lucene.document.LongPoint.pack; + public class DateFieldTypeTests extends FieldTypeTestCase { private static final long nowInMillis = 0; @@ -207,7 +212,29 @@ public void testTermQuery() { String date = "2015-10-12T14:10:55"; long instant = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date)).toInstant().toEpochMilli(); Query expected = new IndexOrDocValuesQuery( - LongPoint.newRangeQuery("field", instant, instant + 999), + new ApproximateScoreQuery( + new PointRangeQuery( + 
"field", + pack(new long[] { instant }).bytes, + pack(new long[] { instant + 999 }).bytes, + new long[] { instant }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + new ApproximatePointRangeQuery( + "field", + pack(new long[] { instant }).bytes, + pack(new long[] { instant + 999 }).bytes, + new long[] { instant }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + } + ), SortedNumericDocValuesField.newSlowRangeQuery("field", instant, instant + 999) ); assertEquals(expected, ft.termQuery(date, context)); @@ -257,7 +284,29 @@ public void testRangeQuery() throws IOException { long instant1 = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date1)).toInstant().toEpochMilli(); long instant2 = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date2)).toInstant().toEpochMilli() + 999; Query expected = new IndexOrDocValuesQuery( - LongPoint.newRangeQuery("field", instant1, instant2), + new ApproximateScoreQuery( + new PointRangeQuery( + "field", + pack(new long[] { instant1 }).bytes, + pack(new long[] { instant2 }).bytes, + new long[] { instant1 }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + new ApproximatePointRangeQuery( + "field", + pack(new long[] { instant1 }).bytes, + pack(new long[] { instant2 }).bytes, + new long[] { instant1 }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + } + ), SortedNumericDocValuesField.newSlowRangeQuery("field", instant1, instant2) ); assertEquals( @@ -269,7 +318,29 @@ public void testRangeQuery() throws IOException { instant2 = instant1 + 100; expected = new DateRangeIncludingNowQuery( new IndexOrDocValuesQuery( - LongPoint.newRangeQuery("field", instant1, instant2), + new ApproximateScoreQuery( + new PointRangeQuery( + "field", + pack(new long[] { instant1 }).bytes, + pack(new long[] { instant2 }).bytes, + new long[] { instant1 }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + new ApproximatePointRangeQuery( + "field", + pack(new long[] { instant1 }).bytes, + pack(new long[] { instant2 }).bytes, + new long[] { instant1 }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + } + ), SortedNumericDocValuesField.newSlowRangeQuery("field", instant1, instant2) ) ); @@ -329,7 +400,29 @@ public void testRangeQueryWithIndexSort() { long instant1 = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date1)).toInstant().toEpochMilli(); long instant2 = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date2)).toInstant().toEpochMilli() + 999; - Query pointQuery = LongPoint.newRangeQuery("field", instant1, instant2); + Query pointQuery = new ApproximateScoreQuery( + new PointRangeQuery( + "field", + pack(new long[] { instant1 }).bytes, + pack(new long[] { instant2 }).bytes, + new long[] { instant1 }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + new ApproximatePointRangeQuery( + "field", + pack(new long[] { instant1 
}).bytes, + pack(new long[] { instant2 }).bytes, + new long[] { instant1 }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + } + ); Query dvQuery = SortedNumericDocValuesField.newSlowRangeQuery("field", instant1, instant2); Query expected = new IndexSortSortedNumericDocValuesRangeQuery( "field", diff --git a/server/src/test/java/org/opensearch/index/mapper/RangeFieldQueryStringQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/mapper/RangeFieldQueryStringQueryBuilderTests.java index 9dea7e13ac45e..496ab4370b399 100644 --- a/server/src/test/java/org/opensearch/index/mapper/RangeFieldQueryStringQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/RangeFieldQueryStringQueryBuilderTests.java @@ -49,6 +49,8 @@ import org.opensearch.common.time.DateMathParser; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryStringQueryBuilder; +import org.opensearch.search.approximate.ApproximatePointRangeQuery; +import org.opensearch.search.approximate.ApproximateScoreQuery; import org.opensearch.test.AbstractQueryTestCase; import java.io.IOException; @@ -56,6 +58,7 @@ import static org.hamcrest.Matchers.either; import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.apache.lucene.document.LongPoint.pack; public class RangeFieldQueryStringQueryBuilderTests extends AbstractQueryTestCase { @@ -173,10 +176,28 @@ public void testDateRangeQuery() throws Exception { DateFieldMapper.DateFieldType dateType = (DateFieldMapper.DateFieldType) context.fieldMapper(DATE_FIELD_NAME); parser = dateType.dateMathParser; Query queryOnDateField = new QueryStringQueryBuilder(DATE_FIELD_NAME + ":[2010-01-01 TO 2018-01-01]").toQuery(createShardContext()); - Query controlQuery = LongPoint.newRangeQuery( - DATE_FIELD_NAME, - new long[] { parser.parse(lowerBoundExact, () -> 0).toEpochMilli() }, - new long[] { parser.parse(upperBoundExact, () -> 0).toEpochMilli() } + Query controlQuery = new ApproximateScoreQuery( + new PointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { parser.parse(lowerBoundExact, () -> 0).toEpochMilli() }).bytes, + pack(new long[] { parser.parse(upperBoundExact, () -> 0).toEpochMilli() }).bytes, + new long[] { parser.parse(lowerBoundExact, () -> 0).toEpochMilli() }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { parser.parse(lowerBoundExact, () -> 0).toEpochMilli() }).bytes, + pack(new long[] { parser.parse(upperBoundExact, () -> 0).toEpochMilli() }).bytes, + new long[] { parser.parse(lowerBoundExact, () -> 0).toEpochMilli() }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + } ); Query controlDv = SortedNumericDocValuesField.newSlowRangeQuery( diff --git a/server/src/test/java/org/opensearch/index/mapper/RangeFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/RangeFieldTypeTests.java index 00b48240d0567..4e02410a37d83 100644 --- a/server/src/test/java/org/opensearch/index/mapper/RangeFieldTypeTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/RangeFieldTypeTests.java @@ -51,11 +51,13 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.time.DateFormatter; import 
org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.DateFieldMapper.DateFieldType; import org.opensearch.index.mapper.RangeFieldMapper.RangeFieldType; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; +import org.opensearch.search.approximate.ApproximateScoreQuery; import org.opensearch.test.IndexSettingsModule; import org.joda.time.DateTime; import org.junit.Before; @@ -249,6 +251,45 @@ private QueryShardContext createContext() { ); } + public void testDateRangeQueryUsingMappingFormatLegacy() { + assumeThat("Using legacy datetime format as default", FeatureFlags.isEnabled(FeatureFlags.DATETIME_FORMATTER_CACHING), is(false)); + + QueryShardContext context = createContext(); + RangeFieldType strict = new RangeFieldType("field", RangeFieldMapper.Defaults.DATE_FORMATTER); + // don't use DISJOINT here because it doesn't work on date fields which we want to compare bounds with + ShapeRelation relation = randomValueOtherThan(ShapeRelation.DISJOINT, () -> randomFrom(ShapeRelation.values())); + + // dates will break the default format, month/day of month is turned around in the format + final String from = "2016-15-06T15:29:50+08:00"; + final String to = "2016-16-06T15:29:50+08:00"; + + OpenSearchParseException ex = expectThrows( + OpenSearchParseException.class, + () -> strict.rangeQuery(from, to, true, true, relation, null, null, context) + ); + assertThat( + ex.getMessage(), + containsString("failed to parse date field [2016-15-06T15:29:50+08:00] with format [strict_date_optional_time||epoch_millis]") + ); + + // setting mapping format which is compatible with those dates + final DateFormatter formatter = DateFormatter.forPattern("yyyy-dd-MM'T'HH:mm:ssZZZZZ"); + assertEquals(1465975790000L, formatter.parseMillis(from)); + assertEquals(1466062190000L, formatter.parseMillis(to)); + + RangeFieldType fieldType = new RangeFieldType("field", formatter); + final Query query = fieldType.rangeQuery(from, to, true, true, relation, null, fieldType.dateMathParser(), context); + assertEquals("field:", ((IndexOrDocValuesQuery) query).getIndexQuery().toString()); + + // compare lower and upper bounds with what we would get on a `date` field + DateFieldType dateFieldType = new DateFieldType("field", DateFieldMapper.Resolution.MILLISECONDS, formatter); + final Query queryOnDateField = dateFieldType.rangeQuery(from, to, true, true, relation, null, fieldType.dateMathParser(), context); + assertEquals( + "field:[1465975790000 TO 1466062190999]", + ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) queryOnDateField).getIndexQuery()).getOriginalQuery().toString() + ); + } + public void testDateRangeQueryUsingMappingFormat() { QueryShardContext context = createContext(); RangeFieldType strict = new RangeFieldType("field", RangeFieldMapper.Defaults.DATE_FORMATTER); diff --git a/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java index af4a34aa98116..53a4fd9509102 100644 --- a/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java @@ -53,6 +53,7 @@ import org.apache.lucene.search.MultiTermQuery; import org.apache.lucene.search.NormsFieldExistsQuery; import org.apache.lucene.search.PhraseQuery; +import 
org.apache.lucene.search.PointRangeQuery; import org.apache.lucene.search.PrefixQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.RegexpQuery; @@ -76,6 +77,8 @@ import org.opensearch.index.mapper.FieldNamesFieldMapper; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.search.QueryStringQueryParser; +import org.opensearch.search.approximate.ApproximatePointRangeQuery; +import org.opensearch.search.approximate.ApproximateScoreQuery; import org.opensearch.test.AbstractQueryTestCase; import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; @@ -98,6 +101,7 @@ import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; +import static org.apache.lucene.document.LongPoint.pack; public class QueryStringQueryBuilderTests extends AbstractQueryTestCase { @@ -863,7 +867,29 @@ public void testToQueryDateWithTimeZone() throws Exception { } private IndexOrDocValuesQuery calculateExpectedDateQuery(long lower, long upper) { - Query query = LongPoint.newRangeQuery(DATE_FIELD_NAME, lower, upper); + Query query = new ApproximateScoreQuery( + new PointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + } + ); Query dv = SortedNumericDocValuesField.newSlowRangeQuery(DATE_FIELD_NAME, lower, upper); return new IndexOrDocValuesQuery(query, dv); } diff --git a/server/src/test/java/org/opensearch/index/query/RangeQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/RangeQueryBuilderTests.java index e72be29b85b63..4dbd04f25a1b0 100644 --- a/server/src/test/java/org/opensearch/index/query/RangeQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/query/RangeQueryBuilderTests.java @@ -53,6 +53,8 @@ import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.MappedFieldType.Relation; import org.opensearch.index.mapper.MapperService; +import org.opensearch.search.approximate.ApproximatePointRangeQuery; +import org.opensearch.search.approximate.ApproximateScoreQuery; import org.opensearch.test.AbstractQueryTestCase; import org.joda.time.DateTime; import org.joda.time.chrono.ISOChronology; @@ -68,6 +70,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.sameInstance; +import static org.apache.lucene.document.LongPoint.pack; public class RangeQueryBuilderTests extends AbstractQueryTestCase { @Override @@ -185,7 +188,11 @@ protected void doAssertLuceneQuery(RangeQueryBuilder queryBuilder, Query query, } else if (expectedFieldName.equals(DATE_FIELD_NAME)) { assertThat(query, instanceOf(IndexOrDocValuesQuery.class)); query = ((IndexOrDocValuesQuery) query).getIndexQuery(); - assertThat(query, instanceOf(PointRangeQuery.class)); + assertThat(query, instanceOf(ApproximateScoreQuery.class)); + Query originalQuery = ((ApproximateScoreQuery) query).getOriginalQuery(); + assertThat(originalQuery, instanceOf(PointRangeQuery.class)); + 
Query approximateQuery = ((ApproximateScoreQuery) query).getApproximationQuery(); + assertThat(approximateQuery, instanceOf(ApproximatePointRangeQuery.class)); MapperService mapperService = context.getMapperService(); MappedFieldType mappedFieldType = mapperService.fieldType(expectedFieldName); final Long fromInMillis; @@ -234,7 +241,32 @@ protected void doAssertLuceneQuery(RangeQueryBuilder queryBuilder, Query query, maxLong--; } } - assertEquals(LongPoint.newRangeQuery(DATE_FIELD_NAME, minLong, maxLong), query); + assertEquals( + new ApproximateScoreQuery( + new PointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { minLong }).bytes, + pack(new long[] { maxLong }).bytes, + new long[] { minLong }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { minLong }).bytes, + pack(new long[] { maxLong }).bytes, + new long[] { minLong }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + } + ), + query + ); } else if (expectedFieldName.equals(INT_FIELD_NAME)) { assertThat(query, instanceOf(IndexOrDocValuesQuery.class)); query = ((IndexOrDocValuesQuery) query).getIndexQuery(); @@ -301,13 +333,36 @@ public void testDateRangeQueryFormat() throws IOException { Query parsedQuery = parseQuery(query).toQuery(createShardContext()); assertThat(parsedQuery, instanceOf(IndexOrDocValuesQuery.class)); parsedQuery = ((IndexOrDocValuesQuery) parsedQuery).getIndexQuery(); - assertThat(parsedQuery, instanceOf(PointRangeQuery.class)); - + assertThat(parsedQuery, instanceOf(ApproximateScoreQuery.class)); + Query originalQuery = ((ApproximateScoreQuery) parsedQuery).getOriginalQuery(); + assertThat(originalQuery, instanceOf(PointRangeQuery.class)); + Query approximateQuery = ((ApproximateScoreQuery) parsedQuery).getApproximationQuery(); + assertThat(approximateQuery, instanceOf(ApproximatePointRangeQuery.class)); + long lower = DateTime.parse("2012-01-01T00:00:00.000+00").getMillis(); + long upper = DateTime.parse("2030-01-01T00:00:00.000+00").getMillis() - 1; assertEquals( - LongPoint.newRangeQuery( - DATE_FIELD_NAME, - DateTime.parse("2012-01-01T00:00:00.000+00").getMillis(), - DateTime.parse("2030-01-01T00:00:00.000+00").getMillis() - 1 + new ApproximateScoreQuery( + new PointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + } ), parsedQuery ); @@ -341,14 +396,35 @@ public void testDateRangeBoundaries() throws IOException { Query parsedQuery = parseQuery(query).toQuery(createShardContext()); assertThat(parsedQuery, instanceOf(IndexOrDocValuesQuery.class)); parsedQuery = ((IndexOrDocValuesQuery) parsedQuery).getIndexQuery(); - assertThat(parsedQuery, instanceOf(PointRangeQuery.class)); + assertThat(parsedQuery, instanceOf(ApproximateScoreQuery.class)); + + long lower = DateTime.parse("2014-11-01T00:00:00.000+00").getMillis(); + long upper = 
DateTime.parse("2014-12-08T23:59:59.999+00").getMillis(); assertEquals( - LongPoint.newRangeQuery( - DATE_FIELD_NAME, - DateTime.parse("2014-11-01T00:00:00.000+00").getMillis(), - DateTime.parse("2014-12-08T23:59:59.999+00").getMillis() - ), - parsedQuery + new ApproximateScoreQuery( + new PointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + } + ).toString(), + parsedQuery.toString() ); query = "{\n" @@ -364,14 +440,34 @@ public void testDateRangeBoundaries() throws IOException { parsedQuery = parseQuery(query).toQuery(createShardContext()); assertThat(parsedQuery, instanceOf(IndexOrDocValuesQuery.class)); parsedQuery = ((IndexOrDocValuesQuery) parsedQuery).getIndexQuery(); - assertThat(parsedQuery, instanceOf(PointRangeQuery.class)); + assertThat(parsedQuery, instanceOf(ApproximateScoreQuery.class)); + lower = DateTime.parse("2014-11-30T23:59:59.999+00").getMillis() + 1; + upper = DateTime.parse("2014-12-08T00:00:00.000+00").getMillis() - 1; assertEquals( - LongPoint.newRangeQuery( - DATE_FIELD_NAME, - DateTime.parse("2014-11-30T23:59:59.999+00").getMillis() + 1, - DateTime.parse("2014-12-08T00:00:00.000+00").getMillis() - 1 - ), - parsedQuery + new ApproximateScoreQuery( + new PointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + } + ).toString(), + parsedQuery.toString() ); } @@ -393,7 +489,7 @@ public void testDateRangeQueryTimezone() throws IOException { parsedQuery = ((DateRangeIncludingNowQuery) parsedQuery).getQuery(); assertThat(parsedQuery, instanceOf(IndexOrDocValuesQuery.class)); parsedQuery = ((IndexOrDocValuesQuery) parsedQuery).getIndexQuery(); - assertThat(parsedQuery, instanceOf(PointRangeQuery.class)); + assertThat(parsedQuery, instanceOf(ApproximateScoreQuery.class)); // TODO what else can we assert query = "{\n" diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java new file mode 100644 index 0000000000000..6460d69aea4d4 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java @@ -0,0 +1,306 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.approximate; + +import com.carrotsearch.randomizedtesting.generators.RandomNumbers; + +import org.apache.lucene.analysis.core.WhitespaceAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.PointRangeQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.apache.lucene.document.LongPoint.pack; + +public class ApproximatePointRangeQueryTests extends OpenSearchTestCase { + + public void testApproximateRangeEqualsActualRange() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + for (int i = 0; i < 100; i++) { + int numPoints = RandomNumbers.randomIntBetween(random(), 1, 10); + Document doc = new Document(); + for (int j = 0; j < numPoints; j++) { + for (int v = 0; v < dims; v++) { + scratch[v] = RandomNumbers.randomLongBetween(random(), 0, 100); + } + doc.add(new LongPoint("point", scratch)); + } + iw.addDocument(doc); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = RandomNumbers.randomLongBetween(random(), -100, 200); + long upper = lower + RandomNumbers.randomLongBetween(random(), 0, 100); + Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + Query query = new PointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + TopDocs topDocs1 = searcher.search(query, 10); + assertEquals(topDocs.totalHits, topDocs1.totalHits); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeWithDefaultSize() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + int numPoints = 1000; + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + for (int v = 0; v < dims; v++) { + scratch[v] = i; + } + doc.add(new LongPoint("point", scratch)); + iw.addDocument(doc); + if (i % 15 == 0) iw.flush(); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = 0; + long upper = 1000; + Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + assertEquals(topDocs.totalHits, new 
TotalHits(1000, TotalHits.Relation.EQUAL_TO)); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeWithSizeUnderDefault() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + int numPoints = 1000; + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + for (int v = 0; v < dims; v++) { + scratch[v] = i; + } + doc.add(new LongPoint("point", scratch)); + iw.addDocument(doc); + if (i % 15 == 0) iw.flush(); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = 0; + long upper = 45; + Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims, 10) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + assertEquals(topDocs.totalHits, new TotalHits(10, TotalHits.Relation.EQUAL_TO)); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeWithSizeOverDefault() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + int numPoints = 15000; + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + for (int v = 0; v < dims; v++) { + scratch[v] = i; + } + doc.add(new LongPoint("point", scratch)); + iw.addDocument(doc); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = 0; + long upper = 12000; + Query approximateQuery = new ApproximatePointRangeQuery( + "point", + pack(lower).bytes, + pack(upper).bytes, + dims, + 11_000 + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 11000); + assertEquals(topDocs.totalHits, new TotalHits(11001, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeShortCircuit() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + int numPoints = 1000; + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + for (int v = 0; v < dims; v++) { + scratch[v] = i; + } + doc.add(new LongPoint("point", scratch)); + iw.addDocument(doc); + if (i % 10 == 0) iw.flush(); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = 0; + long upper = 100; + Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims, 10) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + Query query = new PointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims) { + protected String toString(int dimension, byte[] value) { + return 
Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + TopDocs topDocs1 = searcher.search(query, 10); + + // since we short-circuit from the approx range at the end of size these will not be equal + assertNotEquals(topDocs.totalHits, topDocs1.totalHits); + assertEquals(topDocs.totalHits, new TotalHits(10, TotalHits.Relation.EQUAL_TO)); + assertEquals(topDocs1.totalHits, new TotalHits(101, TotalHits.Relation.EQUAL_TO)); + + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeShortCircuitAscSort() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + int numPoints = 1000; + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + for (int v = 0; v < dims; v++) { + scratch[v] = i; + } + doc.add(new LongPoint("point", scratch)); + iw.addDocument(doc); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = 0; + long upper = 20; + Query approximateQuery = new ApproximatePointRangeQuery( + "point", + pack(lower).bytes, + pack(upper).bytes, + dims, + 10, + SortOrder.ASC + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + Query query = new PointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + TopDocs topDocs1 = searcher.search(query, 10); + + // since we short-circuit from the approx range at the end of size these will not be equal + assertNotEquals(topDocs.totalHits, topDocs1.totalHits); + assertEquals(topDocs.totalHits, new TotalHits(10, TotalHits.Relation.EQUAL_TO)); + assertEquals(topDocs1.totalHits, new TotalHits(21, TotalHits.Relation.EQUAL_TO)); + assertEquals(topDocs.scoreDocs[0].doc, 0); + assertEquals(topDocs.scoreDocs[1].doc, 1); + assertEquals(topDocs.scoreDocs[2].doc, 2); + assertEquals(topDocs.scoreDocs[3].doc, 3); + assertEquals(topDocs.scoreDocs[4].doc, 4); + assertEquals(topDocs.scoreDocs[5].doc, 5); + + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } +} diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximateScoreQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximateScoreQueryTests.java new file mode 100644 index 0000000000000..83fa4a500de76 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximateScoreQueryTests.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.analysis.core.WhitespaceAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.PointRangeQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.apache.lucene.document.LongPoint.pack; + +public class ApproximateScoreQueryTests extends OpenSearchTestCase { + + public void testApproximationScoreSupplier() throws IOException { + long l = Long.MIN_VALUE; + long u = Long.MAX_VALUE; + Query originalQuery = new PointRangeQuery( + "test-index", + pack(new long[] { l }).bytes, + pack(new long[] { u }).bytes, + new long[] { l }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + + ApproximateableQuery approximateQuery = new ApproximatePointRangeQuery( + "test-index", + pack(new long[] { l }).bytes, + pack(new long[] { u }).bytes, + new long[] { l }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + + ApproximateScoreQuery query = new ApproximateScoreQuery(originalQuery, approximateQuery); + + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + Document document = new Document(); + document.add(new LongPoint("testPoint", Long.MIN_VALUE)); + iw.addDocument(document); + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + IndexSearcher searcher = new IndexSearcher(reader); + searcher.search(query, 10); + Weight weight = query.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0F); + Scorer scorer = weight.scorer(reader.leaves().get(0)); + assertEquals( + scorer, + originalQuery.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0F).scorer(searcher.getLeafContexts().get(0)) + ); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } +}
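

Usage note (appended commentary, not part of the patch): the sketch below shows how the pieces introduced above fit together, using only the constructors exercised in DateFieldTypeTests and ApproximateScoreQueryTests. The class names under org.opensearch.search.approximate come from this patch; everything else (the ApproximateRangeSketch class, the "timestamp" field, the in-memory index setup, and the document counts) is illustrative and assumed rather than taken from the change. Per the new tests, the approximate side may stop collecting once it has seen enough matches for the requested size (the removed ContextIndexSearcher code used a 10,000 threshold), so the reported total hits can be a lower bound instead of an exact count; ApproximateScoreQueryTests also suggests that when no search context has been attached, the wrapper falls back to the exact PointRangeQuery.

import org.apache.lucene.document.Document;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import org.opensearch.search.approximate.ApproximatePointRangeQuery;
import org.opensearch.search.approximate.ApproximateScoreQuery;
import org.opensearch.search.approximate.ApproximateableQuery;

import static org.apache.lucene.document.LongPoint.pack;

public class ApproximateRangeSketch {
    public static void main(String[] args) throws Exception {
        try (Directory dir = new ByteBuffersDirectory(); IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
            // Index a long point plus a doc-values copy, the same shape DateFieldMapper produces for a date field.
            for (long i = 0; i < 20_000; i++) {
                Document doc = new Document();
                doc.add(new LongPoint("timestamp", i));
                doc.add(new SortedNumericDocValuesField("timestamp", i));
                writer.addDocument(doc);
            }
            writer.commit();

            long lower = 1_000;
            long upper = 15_000;

            // Exact side: the PointRangeQuery a date range query would normally build.
            Query exact = new PointRangeQuery("timestamp", pack(lower).bytes, pack(upper).bytes, 1) {
                @Override
                protected String toString(int dimension, byte[] value) {
                    return Long.toString(LongPoint.decodeDimension(value, 0));
                }
            };

            // Approximate side: same bounds, but allowed to stop early once enough matches are collected.
            ApproximateableQuery approximate = new ApproximatePointRangeQuery("timestamp", pack(lower).bytes, pack(upper).bytes, 1) {
                @Override
                protected String toString(int dimension, byte[] value) {
                    return Long.toString(LongPoint.decodeDimension(value, 0));
                }
            };

            // ApproximateScoreQuery chooses between the two; ContextIndexSearcher hands it the SearchContext
            // so it can inspect size, sort, and track_total_hits before deciding.
            Query indexQuery = new ApproximateScoreQuery(exact, approximate);

            // DateFieldMapper pairs the index query with a doc-values query, as asserted in DateFieldTypeTests.
            Query rangeQuery = new IndexOrDocValuesQuery(indexQuery, SortedNumericDocValuesField.newSlowRangeQuery("timestamp", lower, upper));

            try (DirectoryReader reader = DirectoryReader.open(writer)) {
                IndexSearcher searcher = new IndexSearcher(reader);
                TopDocs topDocs = searcher.search(rangeQuery, 10);
                // With the approximation in effect, the reported total may be a lower bound rather than an exact count.
                System.out.println(topDocs.totalHits);
            }
        }
    }
}

The short-circuit behaviour this sketch relies on is the one exercised in testApproximateRangeShortCircuit and testApproximateRangeShortCircuitAscSort above: the approximate query reports only as many hits as the configured size, while the plain PointRangeQuery keeps counting every match.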