From 0f3b7a22e318ed099e976fd334b6c1f005ae43b9 Mon Sep 17 00:00:00 2001
From: Harsha Vamsi Kalluri
Date: Tue, 17 Sep 2024 17:46:08 -0700
Subject: [PATCH] Initial collector implementation

Signed-off-by: Harsha Vamsi Kalluri
---
 server/build.gradle                                |  41 ++++
 .../search/DefaultSearchContext.java               |  12 +
 .../internal/FilteredSearchContext.java            |  11 +
 .../search/internal/SearchContext.java             |   7 +
 .../search/query/ArrowCollector.java               | 214 +++++++++++++++++
 .../search/query/ArrowCollectorContext.java        |  41 +++
 .../search/query/FieldTypeDefinition.java          |  13 ++
 .../search/query/ProjectionField.java              |  23 ++
 .../opensearch/search/query/QueryPhase.java        | 208 +++++++++++++++++-
 .../search/query/QueryPhaseSearcher.java           |  14 ++
 .../search/query/ArrowCollectorTests.java          | 145 ++++++++++++
 .../opensearch/test/TestSearchContext.java         |  12 +
 12 files changed, 740 insertions(+), 1 deletion(-)
 create mode 100644 server/src/main/java/org/opensearch/search/query/ArrowCollector.java
 create mode 100644 server/src/main/java/org/opensearch/search/query/ArrowCollectorContext.java
 create mode 100644 server/src/main/java/org/opensearch/search/query/FieldTypeDefinition.java
 create mode 100644 server/src/main/java/org/opensearch/search/query/ProjectionField.java
 create mode 100644 server/src/test/java/org/opensearch/search/query/ArrowCollectorTests.java

diff --git a/server/build.gradle b/server/build.gradle
index 0cc42ad690eab..1aa5ed10266d2 100644
--- a/server/build.gradle
+++ b/server/build.gradle
@@ -129,6 +129,47 @@ dependencies {
 
   // https://mvnrepository.com/artifact/org.roaringbitmap/RoaringBitmap
   implementation 'org.roaringbitmap:RoaringBitmap:1.2.1'
 
+  api group: 'com.google.code.findbugs', name: 'jsr305', version: '3.0.2'
+  api 'org.slf4j:slf4j-api:1.7.36'
+  api group: 'org.apache.arrow', name: 'arrow-memory-netty-buffer-patch', version: '17.0.0'
+  api group: 'org.apache.arrow', name: 'arrow-vector', version: '17.0.0'
+  api 'org.apache.arrow:arrow-memory-core:17.0.0'
+  api group: 'org.apache.arrow', name: 'arrow-memory-netty', version: '17.0.0'
+  api 'org.apache.arrow:arrow-format:17.0.0'
+  api 'org.apache.arrow:arrow-flight:17.0.0'
+  api 'org.apache.arrow:flight-core:17.0.0'
+  // api 'org.apache.arrow:flight-grpc:17.0.0'
+  api 'io.grpc:grpc-api:1.57.2'
+  api 'io.grpc:grpc-netty:1.63.0'
+  // api 'io.grpc:grpc-java:1.57.2'
+  api 'io.grpc:grpc-core:1.63.0'
+  api 'io.grpc:grpc-stub:1.63.0'
+  // api 'io.grpc:grpc-all:1.63.0'
+  api 'io.grpc:grpc-protobuf:1.63.0'
+  api 'io.grpc:grpc-protobuf-lite:1.63.0'
+  api 'io.grpc:grpc-all:1.57.2'
+  api "io.netty:netty-buffer:${versions.netty}"
+  api "io.netty:netty-codec:${versions.netty}"
+  api "io.netty:netty-codec-http:${versions.netty}"
+  api "io.netty:netty-codec-http2:${versions.netty}"
+  api "io.netty:netty-common:${versions.netty}"
+  api "io.netty:netty-handler:${versions.netty}"
+  api "io.netty:netty-resolver:${versions.netty}"
+  api "io.netty:netty-transport:${versions.netty}"
+  api "io.netty:netty-transport-native-unix-common:${versions.netty}"
+  runtimeOnly 'io.perfmark:perfmark-api:0.27.0'
+  // runtimeOnly('com.google.guava:guava:32.1.1-jre')
+  runtimeOnly "com.google.guava:failureaccess:1.0.1"
+
+  api 'com.google.flatbuffers:flatbuffers-java:2.0.0'
+  api 'org.apache.parquet:parquet-arrow:1.13.1'
+
+  api 'com.fasterxml.jackson.core:jackson-databind:2.17.2'
+  api 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.17.2'
+  api 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.17.2'
+  api 'com.fasterxml.jackson.core:jackson-annotations:2.17.2'
+  // api 'org.apache.arrow:arrow-compression:13.0.0'
+
   testImplementation(project(":test:framework")) {
     // tests use the locally compiled version of server
     exclude group: 'org.opensearch', module: 'server'
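
Dependency note: the Arrow/gRPC/Netty pile above is consumed only by the new collector, and mixing `grpc-all:1.57.2` with the individual 1.63.0 gRPC artifacts risks classpath conflicts that should be reconciled before this lands. A minimal allocator round-trip is a quick way to verify that the 17.0.0 Arrow artifacts resolve together (illustrative snippet, not part of the patch):

```java
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;

public class ArrowSmokeTest {
    public static void main(String[] args) {
        // allocate, write, read, release: exercises arrow-memory-core and arrow-vector
        try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
             IntVector vector = new IntVector("smoke", allocator)) {
            vector.allocateNew(4);
            vector.setSafe(0, 42);
            vector.setValueCount(1);
            System.out.println(vector.get(0)); // prints 42
        }
    }
}
```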
diff --git a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java
index 74a7482d975df..9a6dc6fa3cb17 100644
--- a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java
+++ b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java
@@ -92,6 +92,7 @@
 import org.opensearch.search.internal.ShardSearchContextId;
 import org.opensearch.search.internal.ShardSearchRequest;
 import org.opensearch.search.profile.Profilers;
+import org.opensearch.search.query.ArrowCollector;
 import org.opensearch.search.query.QueryPhaseExecutionException;
 import org.opensearch.search.query.QuerySearchResult;
 import org.opensearch.search.query.ReduceableSearchResult;
@@ -199,6 +200,7 @@ final class DefaultSearchContext extends SearchContext {
     private List<RescoreContext> rescore;
     private Profilers profilers;
     private BucketCollectorProcessor bucketCollectorProcessor = NO_OP_BUCKET_COLLECTOR_PROCESSOR;
+    private ArrowCollector arrowCollector = NO_OP_ARROW_COLLECTOR;
     private final Map<String, SearchExtBuilder> searchExtBuilders = new HashMap<>();
     private final Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers = new HashMap<>();
     private final QueryShardContext queryShardContext;
@@ -1052,6 +1054,16 @@ public BucketCollectorProcessor bucketCollectorProcessor() {
         return bucketCollectorProcessor;
     }
 
+    @Override
+    public ArrowCollector getArrowCollector() {
+        return arrowCollector;
+    }
+
+    @Override
+    public void setArrowCollector(ArrowCollector arrowCollector) {
+        this.arrowCollector = arrowCollector;
+    }
+
     /**
      * Evaluate the concurrentSearchMode based on cluster and index settings if concurrent segment search
      * should be used for this request context
diff --git a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java
index 3a3b46366a6d2..2f39672aed1ac 100644
--- a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java
+++ b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java
@@ -63,6 +63,7 @@
 import org.opensearch.search.fetch.subphase.ScriptFieldsContext;
 import org.opensearch.search.fetch.subphase.highlight.SearchHighlightContext;
 import org.opensearch.search.profile.Profilers;
+import org.opensearch.search.query.ArrowCollector;
 import org.opensearch.search.query.QuerySearchResult;
 import org.opensearch.search.query.ReduceableSearchResult;
 import org.opensearch.search.rescore.RescoreContext;
@@ -568,6 +569,16 @@ public BucketCollectorProcessor bucketCollectorProcessor() {
         return in.bucketCollectorProcessor();
     }
 
+    @Override
+    public void setArrowCollector(ArrowCollector arrowCollector) {
+        in.setArrowCollector(arrowCollector);
+    }
+
+    @Override
+    public ArrowCollector getArrowCollector() {
+        return in.getArrowCollector();
+    }
+
     @Override
     public boolean shouldUseConcurrentSearch() {
         return in.shouldUseConcurrentSearch();
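
The new accessors mirror `bucketCollectorProcessor`: the query phase installs the collector on the context, and a downstream consumer reads the batch back. Roughly (illustrative fragment; `queryCollector` and `searchContext` are as in `QueryPhase.searchWithCollector` further down):

```java
// The query phase installs the collector on the context...
if (queryCollector instanceof ArrowCollector) {
    searchContext.setArrowCollector((ArrowCollector) queryCollector);
}
// ...and a downstream consumer later reads the collected batch back:
VectorSchemaRoot batch = searchContext.getArrowCollector().getRootVector();
```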
diff --git a/server/src/main/java/org/opensearch/search/internal/SearchContext.java b/server/src/main/java/org/opensearch/search/internal/SearchContext.java
index 5357206e8c117..9f6c5a4fb307d 100644
--- a/server/src/main/java/org/opensearch/search/internal/SearchContext.java
+++ b/server/src/main/java/org/opensearch/search/internal/SearchContext.java
@@ -72,6 +72,7 @@
 import org.opensearch.search.fetch.subphase.ScriptFieldsContext;
 import org.opensearch.search.fetch.subphase.highlight.SearchHighlightContext;
 import org.opensearch.search.profile.Profilers;
+import org.opensearch.search.query.ArrowCollector;
 import org.opensearch.search.query.QuerySearchResult;
 import org.opensearch.search.query.ReduceableSearchResult;
 import org.opensearch.search.rescore.RescoreContext;
@@ -121,6 +122,8 @@ public List<InternalAggregation> toInternalAggregations(Collection<Collector> collectors
         }
     };
 
+    public static final ArrowCollector NO_OP_ARROW_COLLECTOR = new ArrowCollector();
+
     private final List<Releasable> releasables = new CopyOnWriteArrayList<>();
     private final AtomicBoolean closed = new AtomicBoolean(false);
     private InnerHitsContext innerHitsContext;
@@ -515,6 +518,10 @@ public String toString() {
 
     public abstract BucketCollectorProcessor bucketCollectorProcessor();
 
+    public abstract ArrowCollector getArrowCollector();
+
+    public abstract void setArrowCollector(ArrowCollector arrowCollector);
+
     public abstract int getTargetMaxSliceCount();
 
     public abstract boolean shouldUseTimeSeriesDescSortOptimization();
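
`NO_OP_ARROW_COLLECTOR` is just an `ArrowCollector` built with an empty projection list, so contexts that never opt in produce an empty schema and skip all per-field work. A hypothetical sanity check (`leafContext` assumed to be any `LeafReaderContext`):

```java
// Hypothetical check of the no-op default.
ArrowCollector noOp = new ArrowCollector();                    // empty projection list
LeafCollector leaf = noOp.getLeafCollector(leafContext);       // builds no vectors
assert noOp.getRootVector().getSchema().getFields().isEmpty(); // empty schema
```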
diff --git a/server/src/main/java/org/opensearch/search/query/ArrowCollector.java b/server/src/main/java/org/opensearch/search/query/ArrowCollector.java
new file mode 100644
index 0000000000000..c9f94139ec23d
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/query/ArrowCollector.java
@@ -0,0 +1,214 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.query;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.Float2Vector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.UInt8Vector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.NumericDocValues;
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Scorable;
+import org.apache.lucene.search.ScoreMode;
+import org.opensearch.common.annotation.ExperimentalApi;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Collects numeric doc values for the requested projection fields into Arrow
+ * vectors instead of score-ordered top docs. Experimental; numeric types only.
+ */
+@ExperimentalApi
+public class ArrowCollector implements Collector {
+
+    private static final int BATCH_SIZE = 1000;
+
+    BufferAllocator allocator;
+    Schema schema;
+    List<ProjectionField> projectionFields;
+    VectorSchemaRoot root;
+
+    // LinkedHashMap keeps the field and vector lists aligned in projection order
+    private final Map<String, FieldVector> vectors = new LinkedHashMap<>();
+    // running row index, shared across segments so batches accumulate correctly
+    private int rowIndex = 0;
+
+    public ArrowCollector() {
+        this(new ArrayList<>());
+    }
+
+    public ArrowCollector(List<ProjectionField> projectionFields) {
+        // super(delegateCollector);
+        allocator = new RootAllocator();
+        this.projectionFields = projectionFields;
+    }
+
+    @Override
+    public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
+        // LeafCollector innerLeafCollector = this.in.getLeafCollector(context);
+        if (root == null) {
+            createVectors();
+        }
+        // doc values iterators are segment-local and must be re-resolved per leaf
+        Map<String, NumericDocValues> iterators = new HashMap<>();
+        for (ProjectionField field : projectionFields) {
+            iterators.put(field.fieldName, context.reader().getNumericDocValues(field.fieldName));
+        }
+        return new LeafCollector() {
+            @Override
+            public void setScorer(Scorable scorable) throws IOException {
+                // innerLeafCollector.setScorer(scorable);
+            }
+
+            @Override
+            public void collect(int docId) throws IOException {
+                // innerLeafCollector.collect(docId);
+                for (String field : iterators.keySet()) {
+                    NumericDocValues iterator = iterators.get(field);
+                    if (iterator == null) {
+                        // no numeric doc values for this field in this segment
+                        continue;
+                    }
+                    if (iterator.advanceExact(docId)) {
+                        FieldVector fieldVector = vectors.get(field);
+                        // TODO: dispatch on the remaining vector types; only 64-bit
+                        // integer projections are written so far
+                        if (fieldVector instanceof BigIntVector) {
+                            // setSafe grows the buffers beyond BATCH_SIZE when needed
+                            ((BigIntVector) fieldVector).setSafe(rowIndex, iterator.longValue());
+                        }
+                    }
+                }
+                // one row per collected document
+                rowIndex++;
+            }
+
+            @Override
+            public void finish() throws IOException {
+                // innerLeafCollector.finish();
+                root.setRowCount(rowIndex);
+            }
+        };
+    }
+
+    private void createVectors() {
+        Map<String, Field> arrowFields = new LinkedHashMap<>();
+        projectionFields.forEach(field -> {
+            switch (field.type) {
+                case INT:
+                    Field intField = new Field(field.fieldName, FieldType.nullable(new ArrowType.Int(32, true)), null);
+                    IntVector intVector = new IntVector(intField, allocator);
+                    intVector.allocateNew(BATCH_SIZE);
+                    vectors.put(field.fieldName, intVector);
+                    arrowFields.put(field.fieldName, intField);
+                    break;
+                case BOOLEAN:
+                    // TODO: booleans need a BitVector-backed write path; skipped for
+                    // now so the schema and vector lists stay aligned
+                    break;
+                case DATE:
+                case DATE_NANOSECONDS:
+                case LONG:
+                    Field longField = new Field(field.fieldName, FieldType.nullable(new ArrowType.Int(64, true)), null);
+                    BigIntVector bigIntVector = new BigIntVector(longField, allocator);
+                    bigIntVector.allocateNew(BATCH_SIZE);
+                    vectors.put(field.fieldName, bigIntVector);
+                    arrowFields.put(field.fieldName, longField);
+                    break;
+                case UNSIGNED_LONG:
+                    Field unsignedLongField = new Field(field.fieldName, FieldType.nullable(new ArrowType.Int(64, false)), null);
+                    UInt8Vector uInt8Vector = new UInt8Vector(unsignedLongField, allocator);
+                    uInt8Vector.allocateNew(BATCH_SIZE);
+                    vectors.put(field.fieldName, uInt8Vector);
+                    arrowFields.put(field.fieldName, unsignedLongField);
+                    break;
+                case HALF_FLOAT:
+                    Field halfFloatField = new Field(
+                        field.fieldName,
+                        FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.HALF)),
+                        null
+                    );
+                    Float2Vector float2Vector = new Float2Vector(halfFloatField, allocator);
+                    float2Vector.allocateNew(BATCH_SIZE);
+                    vectors.put(field.fieldName, float2Vector);
+                    arrowFields.put(field.fieldName, halfFloatField);
+                    break;
+                case FLOAT:
+                    Field floatField = new Field(
+                        field.fieldName,
+                        FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)),
+                        null
+                    );
+                    Float4Vector float4Vector = new Float4Vector(floatField, allocator);
+                    float4Vector.allocateNew(BATCH_SIZE);
+                    vectors.put(field.fieldName, float4Vector);
+                    arrowFields.put(field.fieldName, floatField);
+                    break;
+                case DOUBLE:
+                    Field doubleField = new Field(
+                        field.fieldName,
+                        FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+                        null
+                    );
+                    Float8Vector float8Vector = new Float8Vector(doubleField, allocator);
+                    float8Vector.allocateNew(BATCH_SIZE);
+                    vectors.put(field.fieldName, float8Vector);
+                    arrowFields.put(field.fieldName, doubleField);
+                    break;
+                case SHORT:
+                    Field shortField = new Field(field.fieldName, FieldType.nullable(new ArrowType.Int(16, true)), null);
+                    SmallIntVector smallIntVector = new SmallIntVector(shortField, allocator);
+                    smallIntVector.allocateNew(BATCH_SIZE);
+                    vectors.put(field.fieldName, smallIntVector);
+                    arrowFields.put(field.fieldName, shortField);
+                    break;
+                case BYTE:
+                    Field byteField = new Field(field.fieldName, FieldType.nullable(new ArrowType.Int(8, true)), null);
+                    TinyIntVector tinyIntVector = new TinyIntVector(byteField, allocator);
+                    tinyIntVector.allocateNew(BATCH_SIZE);
+                    vectors.put(field.fieldName, tinyIntVector);
+                    arrowFields.put(field.fieldName, byteField);
+                    break;
+                default:
+                    throw new UnsupportedOperationException("Field type not supported: " + field.type);
+            }
+        });
+        schema = new Schema(arrowFields.values());
+        root = new VectorSchemaRoot(new ArrayList<>(arrowFields.values()), new ArrayList<>(vectors.values()));
+    }
+
+    @Override
+    public ScoreMode scoreMode() {
+        return ScoreMode.COMPLETE_NO_SCORES;
+    }
+
+    public VectorSchemaRoot getRootVector() {
+        return root;
+    }
+}
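
For reviewers new to Arrow, the collector's write path reduces to an allocate/setSafe/setValueCount pattern. A standalone sketch (not patch code) that also shows growth past the initial capacity and null validity:

```java
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;

// The write pattern ArrowCollector uses, in isolation.
try (RootAllocator allocator = new RootAllocator();
     BigIntVector vector = new BigIntVector("longs", allocator)) {
    vector.allocateNew(1000);              // initial capacity, like BATCH_SIZE
    for (int row = 0; row < 1500; row++) {
        vector.setSafe(row, row * 2L);     // setSafe reallocates past the initial capacity
    }
    vector.setValueCount(1500);            // what root.setRowCount does per vector
    System.out.println(vector.get(1499));  // 2998
    System.out.println(vector.isNull(42)); // false; unset slots would read as null
}
```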
diff --git a/server/src/main/java/org/opensearch/search/query/ArrowCollectorContext.java b/server/src/main/java/org/opensearch/search/query/ArrowCollectorContext.java
new file mode 100644
index 0000000000000..bd220e1191035
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/query/ArrowCollectorContext.java
@@ -0,0 +1,41 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.query;
+
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * A {@link QueryCollectorContext} that installs an {@link ArrowCollector} for the stream search path.
+ */
+public class ArrowCollectorContext extends QueryCollectorContext {
+
+    List<ProjectionField> projectionFields;
+
+    ArrowCollectorContext(String profilerName, List<ProjectionField> projectionFields) {
+        super(profilerName);
+        this.projectionFields = projectionFields;
+    }
+
+    @Override
+    Collector create(Collector in) throws IOException {
+        return new ArrowCollector(projectionFields);
+    }
+
+    @Override
+    CollectorManager<? extends Collector, ReduceableSearchResult> createManager(
+        CollectorManager<? extends Collector, ReduceableSearchResult> in
+    ) throws IOException {
+        // TODO: concurrent segment search is not supported yet, hence no manager
+        return null;
+    }
+}
diff --git a/server/src/main/java/org/opensearch/search/query/FieldTypeDefinition.java b/server/src/main/java/org/opensearch/search/query/FieldTypeDefinition.java
new file mode 100644
index 0000000000000..65c34066af701
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/query/FieldTypeDefinition.java
@@ -0,0 +1,13 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.query;
+
+public enum FieldTypeDefinition {
+
+}
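
Because `createManager` returns null, the concurrent segment search path is deliberately not wired up. If it were, the manager would have to look roughly like the hypothetical fragment below; the hard part, merging per-slice `VectorSchemaRoot`s in `reduce`, is the open design question:

```java
// Hypothetical sketch only; not part of the patch.
CollectorManager<ArrowCollector, ReduceableSearchResult> manager = new CollectorManager<>() {
    @Override
    public ArrowCollector newCollector() {
        return new ArrowCollector(projectionFields); // projectionFields from the context
    }

    @Override
    public ReduceableSearchResult reduce(Collection<ArrowCollector> collectors) throws IOException {
        // merging Arrow batches across slices is unresolved, so this stays unimplemented
        throw new UnsupportedOperationException("concurrent segment search not supported");
    }
};
```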
diff --git a/server/src/main/java/org/opensearch/search/query/ProjectionField.java b/server/src/main/java/org/opensearch/search/query/ProjectionField.java
new file mode 100644
index 0000000000000..e4492dd66a1ae
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/query/ProjectionField.java
@@ -0,0 +1,23 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.query;
+
+import org.opensearch.common.annotation.ExperimentalApi;
+import org.opensearch.index.fielddata.IndexNumericFieldData;
+
+@ExperimentalApi
+public class ProjectionField {
+    String fieldName;
+    IndexNumericFieldData.NumericType type;
+
+    public ProjectionField(IndexNumericFieldData.NumericType type, String fieldName) {
+        this.type = type;
+        this.fieldName = fieldName;
+    }
+}
diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhase.java b/server/src/main/java/org/opensearch/search/query/QueryPhase.java
index 55b7c0bc5178d..2dcfcff8fd46a 100644
--- a/server/src/main/java/org/opensearch/search/query/QueryPhase.java
+++ b/server/src/main/java/org/opensearch/search/query/QueryPhase.java
@@ -72,6 +72,7 @@
 import org.opensearch.threadpool.ThreadPool;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -79,6 +80,7 @@
 import java.util.concurrent.ExecutorService;
 import java.util.stream.Collectors;
 
+import static org.opensearch.search.profile.query.CollectorResult.REASON_SEARCH_TOP_HITS;
 import static org.opensearch.search.query.QueryCollectorContext.createEarlyTerminationCollectorContext;
 import static org.opensearch.search.query.QueryCollectorContext.createFilteredCollectorContext;
 import static org.opensearch.search.query.QueryCollectorContext.createMinScoreCollectorContext;
@@ -97,6 +99,7 @@ public class QueryPhase {
     // TODO: remove this property
     public static final boolean SYS_PROP_REWRITE_SORT = Booleans.parseBoolean(System.getProperty("opensearch.search.rewrite_sort", "true"));
     public static final QueryPhaseSearcher DEFAULT_QUERY_PHASE_SEARCHER = new DefaultQueryPhaseSearcher();
+    public static final QueryPhaseSearcher STREAM_QUERY_PHASE_SEARCHER = new StreamQueryPhaseSearcher();
     private final QueryPhaseSearcher queryPhaseSearcher;
     private final SuggestProcessor suggestProcessor;
     private final RescoreProcessor rescoreProcessor;
@@ -183,12 +186,153 @@ static boolean executeInternal(SearchContext searchContext) throws QueryPhaseExe
         return executeInternal(searchContext, QueryPhase.DEFAULT_QUERY_PHASE_SEARCHER);
     }
 
+    public static boolean executeStreamInternal(
+        SearchContext searchContext,
+        QueryPhaseSearcher queryPhaseSearcher,
+        List<ProjectionField> projectionFields
+    ) {
+        return executeInternal(searchContext, queryPhaseSearcher, projectionFields);
+    }
+
+    /**
+     * Variant of {@link #executeInternal(SearchContext, QueryPhaseSearcher)} that also
+     * carries the projection fields for the streaming (Arrow) collector path.
+     * @return whether the rescoring phase should be executed
+     *
+     * TODO: refactor to share the body with the non-streaming variant
+     */
+    public static boolean executeInternal(
+        SearchContext searchContext,
+        QueryPhaseSearcher queryPhaseSearcher,
+        List<ProjectionField> projectionFields
+    ) throws QueryPhaseExecutionException {
+        final ContextIndexSearcher searcher = searchContext.searcher();
+        final IndexReader reader = searcher.getIndexReader();
+        QuerySearchResult queryResult = searchContext.queryResult();
+        queryResult.searchTimedOut(false);
+        try {
+            queryResult.from(searchContext.from());
+            queryResult.size(searchContext.size());
+            Query query = searchContext.query();
+            assert query == searcher.rewrite(query); // already rewritten
+
+            final ScrollContext scrollContext = searchContext.scrollContext();
+            if (scrollContext != null) {
+                if (scrollContext.totalHits == null) {
+                    // first round
+                    assert scrollContext.lastEmittedDoc == null;
+                    // there is not much that we can optimize here since we want to collect all
+                    // documents in order to get the total number of hits
+
+                } else {
+                    final ScoreDoc after = scrollContext.lastEmittedDoc;
+                    if (canEarlyTerminate(reader, searchContext.sort())) {
+                        // now this gets interesting: since the search sort is a prefix of the index sort, we can directly
+                        // skip to the desired doc
+                        if (after != null) {
+                            query = new BooleanQuery.Builder().add(query, BooleanClause.Occur.MUST)
+                                .add(new SearchAfterSortedDocQuery(searchContext.sort().sort, (FieldDoc) after), BooleanClause.Occur.FILTER)
+                                .build();
+                        }
+                    }
+                }
+            }
+
+            final LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
+            // whether the chain contains a collector that filters documents
+            boolean hasFilterCollector = false;
+            if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER) {
+                // add terminate_after before the filter collectors
+                // it will only be applied on documents accepted by these filter collectors
+                collectors.add(createEarlyTerminationCollectorContext(searchContext.terminateAfter()));
+                // this collector can filter documents during the collection
+                hasFilterCollector = true;
+            }
+            if (searchContext.parsedPostFilter() != null) {
+                // add post filters before aggregations
+                // it will only be applied to top hits
+                collectors.add(createFilteredCollectorContext(searcher, searchContext.parsedPostFilter().query()));
+                // this collector can filter documents during the collection
+                hasFilterCollector = true;
+            }
+
+            // plug in additional collectors, like aggregations except global aggregations
+            final List<CollectorManager<?, ReduceableSearchResult>> managersExceptGlobalAgg = searchContext
+                .queryCollectorManagers()
+                .entrySet()
+                .stream()
+                .filter(entry -> !(entry.getKey().equals(GlobalAggCollectorManager.class)))
+                .map(Map.Entry::getValue)
+                .collect(Collectors.toList());
+            if (managersExceptGlobalAgg.isEmpty() == false) {
+                collectors.add(createMultiCollectorContext(managersExceptGlobalAgg));
+            }
+
+            if (searchContext.minimumScore() != null) {
+                // apply the minimum score after multi collector so we filter aggs as well
+                collectors.add(createMinScoreCollectorContext(searchContext.minimumScore()));
+                // this collector can filter documents during the collection
+                hasFilterCollector = true;
+            }
+
+            boolean timeoutSet = scrollContext == null
+                && searchContext.timeout() != null
+                && searchContext.timeout().equals(SearchService.NO_TIMEOUT) == false;
+
+            final Runnable timeoutRunnable;
+            if (timeoutSet) {
+                timeoutRunnable = searcher.addQueryCancellation(createQueryTimeoutChecker(searchContext));
+            } else {
+                timeoutRunnable = null;
+            }
+
+            if (searchContext.lowLevelCancellation()) {
+                searcher.addQueryCancellation(() -> {
+                    SearchShardTask task = searchContext.getTask();
+                    if (task != null && task.isCancelled()) {
+                        throw new TaskCancelledException("cancelled task with reason: " + task.getReasonCancelled());
+                    }
+                });
+            }
+
+            try {
+                boolean shouldRescore = queryPhaseSearcher.searchWith(
+                    searchContext,
+                    searcher,
+                    query,
+                    collectors,
+                    projectionFields,
+                    hasFilterCollector,
+                    timeoutSet
+                );
+
+                ExecutorService executor = searchContext.indexShard().getThreadPool().executor(ThreadPool.Names.SEARCH);
+                if (executor instanceof EWMATrackingThreadPoolExecutor) {
+                    final EWMATrackingThreadPoolExecutor rExecutor = (EWMATrackingThreadPoolExecutor) executor;
+                    queryResult.nodeQueueSize(rExecutor.getCurrentQueueSize());
+                    queryResult.serviceTimeEWMA((long) rExecutor.getTaskExecutionEWMA());
+                }
+
+                return shouldRescore;
+            } finally {
+                // Search phase has finished, no longer need to check for timeout
+                // otherwise aggregation phase might get cancelled.
+                if (timeoutRunnable != null) {
+                    searcher.removeQueryCancellation(timeoutRunnable);
+                }
+            }
+        } catch (Exception e) {
+            throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Failed to execute main query", e);
+        }
+    }
+
     /**
      * In a package-private method so that it can be tested without having to
      * wire everything (mapperService, etc.)
      * @return whether the rescoring phase should be executed
      */
-    static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher queryPhaseSearcher) throws QueryPhaseExecutionException {
+    public static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher queryPhaseSearcher)
+        throws QueryPhaseExecutionException {
         final ContextIndexSearcher searcher = searchContext.searcher();
         final IndexReader reader = searcher.getIndexReader();
         QuerySearchResult queryResult = searchContext.queryResult();
@@ -350,6 +494,9 @@ private static boolean searchWithCollector(
         } else {
             queryCollector = QueryCollectorContext.createQueryCollector(collectors);
         }
+        if (queryCollector instanceof ArrowCollector) {
+            searchContext.setArrowCollector((ArrowCollector) queryCollector);
+        }
         QuerySearchResult queryResult = searchContext.queryResult();
         try {
             searcher.search(query, queryCollector);
@@ -470,4 +617,63 @@ protected boolean searchWithCollector(
             );
         }
     }
+
+    /**
+     * {@link QueryPhaseSearcher} for the streaming path; collects projected fields into Arrow vectors.
+     *
+     * @opensearch.internal
+     */
+    public static class StreamQueryPhaseSearcher implements QueryPhaseSearcher {
+
+        /**
+         * Please use {@link QueryPhase#STREAM_QUERY_PHASE_SEARCHER}
+         */
+        protected StreamQueryPhaseSearcher() {}
+
+        @Override
+        public boolean searchWith(
+            SearchContext searchContext,
+            ContextIndexSearcher searcher,
+            Query query,
+            LinkedList<QueryCollectorContext> collectors,
+            boolean hasFilterCollector,
+            boolean hasTimeout
+        ) throws IOException {
+            return searchWith(searchContext, searcher, query, collectors, new ArrayList<>(), hasFilterCollector, hasTimeout);
+        }
+
+        @Override
+        public boolean searchWith(
+            SearchContext searchContext,
+            ContextIndexSearcher searcher,
+            Query query,
+            LinkedList<QueryCollectorContext> collectors,
+            List<ProjectionField> projectionFields,
+            boolean hasFilterCollector,
+            boolean hasTimeout
+        ) throws IOException {
+            return searchWithCollector(searchContext, searcher, query, collectors, projectionFields, hasFilterCollector, hasTimeout);
+        }
+
+        protected boolean searchWithCollector(
+            SearchContext searchContext,
+            ContextIndexSearcher searcher,
+            Query query,
+            LinkedList<QueryCollectorContext> collectors,
+            List<ProjectionField> projectionFields,
+            boolean hasFilterCollector,
+            boolean hasTimeout
+        ) throws IOException {
+            final ArrowCollectorContext arrowCollectorContext = new ArrowCollectorContext(REASON_SEARCH_TOP_HITS, projectionFields);
+            return QueryPhase.searchWithCollector(
+                searchContext,
+                searcher,
+                query,
+                collectors,
+                arrowCollectorContext,
+                hasFilterCollector,
+                hasTimeout
+            );
+        }
+    }
 }
diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhaseSearcher.java b/server/src/main/java/org/opensearch/search/query/QueryPhaseSearcher.java
index 38e45a5212c81..d0e5b4c63e836 100644
--- a/server/src/main/java/org/opensearch/search/query/QueryPhaseSearcher.java
+++ b/server/src/main/java/org/opensearch/search/query/QueryPhaseSearcher.java
@@ -18,6 +18,7 @@
 
 import java.io.IOException;
 import java.util.LinkedList;
+import java.util.List;
 
 /**
  * The extension point which allows to plug in custom search implementation to be
@@ -45,6 +46,19 @@ boolean searchWith(
         boolean hasTimeout
     ) throws IOException;
 
+    default boolean searchWith(
+        SearchContext searchContext,
+        ContextIndexSearcher searcher,
+        Query query,
+        LinkedList<QueryCollectorContext> collectors,
+        List<ProjectionField> projectionFields,
+        boolean hasFilterCollector,
+        boolean hasTimeout
+    ) throws IOException {
+        // by default, ignore the projections and fall back to the standard search path
+        return searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
+    }
+
     /**
      * {@link AggregationProcessor} to use to setup and post process aggregation related collectors during search request
      * @param searchContext search context
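
End to end, the stream path is driven like this (sketch mirroring the test below; "timestamp" is a made-up field name and `searchContext` is any context backed by an index shard):

```java
// Illustrative: project one numeric field and read the resulting Arrow batch.
List<ProjectionField> projections = new ArrayList<>();
projections.add(new ProjectionField(IndexNumericFieldData.NumericType.LONG, "timestamp"));
QueryPhase.executeStreamInternal(searchContext, QueryPhase.STREAM_QUERY_PHASE_SEARCHER, projections);
VectorSchemaRoot batch = searchContext.getArrowCollector().getRootVector();
```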
diff --git a/server/src/test/java/org/opensearch/search/query/ArrowCollectorTests.java b/server/src/test/java/org/opensearch/search/query/ArrowCollectorTests.java
new file mode 100644
index 0000000000000..59dfeaa6a9a3f
--- /dev/null
+++ b/server/src/test/java/org/opensearch/search/query/ArrowCollectorTests.java
@@ -0,0 +1,145 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.query;
+
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.LongPoint;
+import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.opensearch.action.search.SearchShardTask;
+import org.opensearch.index.fielddata.IndexNumericFieldData;
+import org.opensearch.index.query.ParsedQuery;
+import org.opensearch.index.shard.IndexShard;
+import org.opensearch.index.shard.IndexShardTestCase;
+import org.opensearch.index.shard.SearchOperationListener;
+import org.opensearch.search.internal.ContextIndexSearcher;
+import org.opensearch.search.internal.SearchContext;
+import org.opensearch.test.TestSearchContext;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class ArrowCollectorTests extends IndexShardTestCase {
+    private IndexShard indexShard;
+    private final QueryPhaseSearcher queryPhaseSearcher;
+
+    @ParametersFactory
+    public static Collection<Object[]> concurrency() {
+        return Collections.singletonList(new Object[] { 0, QueryPhase.DEFAULT_QUERY_PHASE_SEARCHER });
+    }
+
+    public ArrowCollectorTests(int concurrency, QueryPhaseSearcher queryPhaseSearcher) {
+        this.queryPhaseSearcher = queryPhaseSearcher;
+    }
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        indexShard = newShard(true);
+    }
+
+    @Override
+    public void tearDown() throws Exception {
+        super.tearDown();
+        closeShards(indexShard);
+    }
+
+    public void testArrow() throws Exception {
+        Directory dir = newDirectory();
+        IndexWriterConfig iwc = newIndexWriterConfig();
+        RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
+
+        // FlightService flightService = new FlightService();
+        final int numDocs = scaledRandomIntBetween(100, 200);
+        IndexReader reader = null;
+        try {
+            for (int i = 0; i < numDocs; i++) {
+                Document doc = new Document();
+                doc.add(new LongPoint("longpoint", i));
+                doc.add(new NumericDocValuesField("longpoint", i));
+                w.addDocument(doc);
+            }
+            w.close();
+            reader = DirectoryReader.open(dir);
+            // flightService.start();
+            TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, null), null);
+            context.setSize(1000);
+            context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
+            context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
+            List<ProjectionField> projectionFields = new ArrayList<>();
+            projectionFields.add(new ProjectionField(IndexNumericFieldData.NumericType.LONG, "longpoint"));
+            QueryPhase.executeStreamInternal(context.withCleanQueryResult(), QueryPhase.STREAM_QUERY_PHASE_SEARCHER, projectionFields);
+            VectorSchemaRoot vectorSchemaRoot = context.getArrowCollector().getRootVector();
+            System.out.println(vectorSchemaRoot.getSchema());
+            Field longPoint = vectorSchemaRoot.getSchema().findField("longpoint");
+            assertEquals(new Field("longpoint", FieldType.nullable(new ArrowType.Int(64, true)), null), longPoint);
+            BigIntVector vector = (BigIntVector) vectorSchemaRoot.getVector("longpoint");
+            assertEquals(numDocs, vector.getValueCount());
+            // assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) 9));
+            // FlightStream flightStream = flightService.getFlightClient().getStream(new Ticket("id1".getBytes(StandardCharsets.UTF_8)));
+            // System.out.println(flightStream.getSchema());
+            // System.out.println(flightStream.next());
+            // System.out.println(flightStream.getRoot().contentToTSVString());
+            // System.out.println(flightStream.getRoot().getRowCount());
+            // System.out.println(flightStream.next());
+            // flightStream.close();
+        } finally {
+            if (reader != null) reader.close();
+            dir.close();
+            // flightService.stop();
+            // flightService.close();
+        }
+    }
+
+    private static ContextIndexSearcher newContextSearcher(IndexReader reader, ExecutorService executor) throws IOException {
+        SearchContext searchContext = mock(SearchContext.class);
+        IndexShard indexShard = mock(IndexShard.class);
+        when(searchContext.indexShard()).thenReturn(indexShard);
+        SearchOperationListener searchOperationListener = new SearchOperationListener() {
+        };
+        when(indexShard.getSearchOperationListener()).thenReturn(searchOperationListener);
+        when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
+        when(searchContext.shouldUseConcurrentSearch()).thenReturn(executor != null);
+        if (executor != null) {
+            when(searchContext.getTargetMaxSliceCount()).thenReturn(randomIntBetween(0, 2));
+        } else {
+            when(searchContext.getTargetMaxSliceCount()).thenThrow(IllegalStateException.class);
+        }
+        return new ContextIndexSearcher(
+            reader,
+            IndexSearcher.getDefaultSimilarity(),
+            IndexSearcher.getDefaultQueryCache(),
+            IndexSearcher.getDefaultQueryCachingPolicy(),
+            true,
+            executor,
+            searchContext
+        );
+    }
+}
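
Reading the collected batch back out of the test's `VectorSchemaRoot` looks like this (sketch; `contentToTSVString` is handy for debugging, as the commented Flight code above also hints):

```java
// Sketch: consuming the batch the test produces. Slots can be null for docs
// without a value, so check validity before reading.
VectorSchemaRoot root = context.getArrowCollector().getRootVector();
BigIntVector longs = (BigIntVector) root.getVector("longpoint");
for (int row = 0; row < root.getRowCount(); row++) {
    if (longs.isNull(row) == false) {
        long value = longs.get(row);
        // ... hand the value to a consumer, e.g. an Arrow Flight stream ...
    }
}
System.out.println(root.contentToTSVString()); // tabular debug dump
```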
{ + return arrowCollector; + } + @Override public int getTargetMaxSliceCount() { assert concurrentSegmentSearchEnabled == true : "Please use concurrent search before fetching maxSliceCount";