From 81f1a4a829167cc5d6e7db8a97a4627f78cc59ff Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Sun, 23 Apr 2023 17:37:41 -0700 Subject: [PATCH] Initial code for adding the SearchPhaseInjectorProcessor interface in Search Pipeline Signed-off-by: Navneet Verma --- CHANGELOG.md | 1 + .../search/AbstractSearchAsyncAction.java | 10 +- .../search/ArraySearchPhaseResults.java | 2 +- .../search/CanMatchPreFilterSearchPhase.java | 7 +- .../action/search/DfsQueryPhase.java | 2 +- .../action/search/ExpandSearchPhase.java | 2 +- .../action/search/FetchSearchPhase.java | 2 +- .../SearchDfsQueryThenFetchAsyncAction.java | 7 +- .../opensearch/action/search/SearchPhase.java | 23 ++- .../action/search/SearchPhaseContext.java | 2 +- .../action/search/SearchPhaseResults.java | 4 +- .../SearchQueryThenFetchAsyncAction.java | 7 +- .../search/SearchScrollAsyncAction.java | 2 +- ...SearchScrollQueryThenFetchAsyncAction.java | 2 +- .../action/search/TransportSearchAction.java | 12 +- .../opensearch/search/pipeline/Pipeline.java | 49 ++++- .../SearchPhaseInjectorProcessor.java | 37 ++++ .../pipeline/SearchPipelineService.java | 21 +++ .../AbstractSearchAsyncActionTests.java | 7 +- .../CanMatchPreFilterSearchPhaseTests.java | 23 ++- .../action/search/SearchAsyncActionTests.java | 20 +- .../SearchQueryThenFetchAsyncActionTests.java | 3 +- .../pipeline/SearchPipelineServiceTests.java | 177 +++++++++++++++++- 23 files changed, 377 insertions(+), 45 deletions(-) create mode 100644 server/src/main/java/org/opensearch/search/pipeline/SearchPhaseInjectorProcessor.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 79ee28aaf7d20..6ca7c802593a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,6 +81,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x] ### Added - [Extensions] Moving Extensions APIs to protobuf serialization. 
([#6960](https://github.com/opensearch-project/OpenSearch/pull/6960)) +- [SearchPipeline] Initial code for adding the SearchPhaseInjectorProcessor interface in Search Pipeline. ([#7283](https://github.com/opensearch-project/OpenSearch/pull/7283)) ### Dependencies diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 9a94737c84385..17a793740e0c4 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -57,6 +57,7 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.transport.Transport; import java.util.ArrayDeque; @@ -116,6 +117,7 @@ abstract class AbstractSearchAsyncAction exten private final boolean throttleConcurrentRequests; private final List releasables = new ArrayList<>(); + private final SearchPipelineService searchPipelineService; AbstractSearchAsyncAction( String name, @@ -134,7 +136,8 @@ abstract class AbstractSearchAsyncAction exten SearchTask task, SearchPhaseResults resultConsumer, int maxConcurrentRequestsPerNode, - SearchResponse.Clusters clusters + SearchResponse.Clusters clusters, + SearchPipelineService searchPipelineService ) { super(name); final List toSkipIterators = new ArrayList<>(); @@ -170,6 +173,7 @@ abstract class AbstractSearchAsyncAction exten this.indexRoutings = indexRoutings; this.results = resultConsumer; this.clusters = clusters; + this.searchPipelineService = searchPipelineService; } @Override @@ -696,7 +700,9 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) { * @see #onShardResult(SearchPhaseResult, SearchShardIterator) */ final void onPhaseDone() { 
// as a tribute to @kimchy aka. finishHim() - executeNextPhase(this, getNextPhase(results, this)); + final SearchPhase nextPhase = getNextPhase(results, this); + searchPipelineService.transformSearchPhase(results, this, this.getSearchPhaseName(), nextPhase.getSearchPhaseName()); + executeNextPhase(this, nextPhase); } @Override diff --git a/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java b/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java index 61c81e6cda97a..653b0e8aedb9d 100644 --- a/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java +++ b/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java @@ -66,7 +66,7 @@ boolean hasResult(int shardIndex) { } @Override - AtomicArray getAtomicArray() { + public AtomicArray getAtomicArray() { return results; } } diff --git a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java index ec4d45a0a7124..8b153392b03af 100644 --- a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java @@ -41,6 +41,7 @@ import org.opensearch.search.SearchShardTarget; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.MinAndMax; import org.opensearch.search.sort.SortOrder; @@ -90,7 +91,8 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction, SearchPhase> phaseFactory, - SearchResponse.Clusters clusters + SearchResponse.Clusters clusters, + SearchPipelineService searchPipelineService ) { // We set max concurrent shard requests to the number of shards so no throttling happens for 
can_match requests super( @@ -110,7 +112,8 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction, SearchPhase> nextPhaseFactory, SearchPhaseContext context ) { - super("dfs_query"); + super(SearchPhaseName.DFS_QUERY.name()); this.progressListener = context.getTask().getProgressListener(); this.queryResult = queryResult; this.searchResults = searchResults; diff --git a/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java b/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java index cdefe7c2c1712..88d02f5e3504b 100644 --- a/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java @@ -62,7 +62,7 @@ final class ExpandSearchPhase extends SearchPhase { private final AtomicArray queryResults; ExpandSearchPhase(SearchPhaseContext context, InternalSearchResponse searchResponse, AtomicArray queryResults) { - super("expand"); + super(SearchPhaseName.EXPAND.name()); this.context = context; this.searchResponse = searchResponse; this.queryResults = queryResults; diff --git a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java index 31ec896856ce6..1aab8855d1757 100644 --- a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java @@ -92,7 +92,7 @@ final class FetchSearchPhase extends SearchPhase { SearchPhaseContext context, BiFunction, SearchPhase> nextPhaseFactory ) { - super("fetch"); + super(SearchPhaseName.FETCH.name()); if (context.getNumShards() != resultConsumer.getNumShards()) { throw new IllegalStateException( "number of shards must match the length of the query results but doesn't:" diff --git a/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java 
b/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 71a986c0e15f7..422c10e222c2a 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -41,6 +41,7 @@ import org.opensearch.search.dfs.AggregatedDfs; import org.opensearch.search.dfs.DfsSearchResult; import org.opensearch.search.internal.AliasFilter; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.transport.Transport; import java.util.List; @@ -76,7 +77,8 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction final TransportSearchAction.SearchTimeProvider timeProvider, final ClusterState clusterState, final SearchTask task, - SearchResponse.Clusters clusters + SearchResponse.Clusters clusters, + SearchPipelineService searchPipelineService ) { super( "dfs", @@ -95,7 +97,8 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction task, new ArraySearchPhaseResults<>(shardsIts.size()), request.getMaxConcurrentShardRequests(), - clusters + clusters, + searchPipelineService ); this.queryPhaseResultConsumer = queryPhaseResultConsumer; this.searchPhaseController = searchPhaseController; diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhase.java b/server/src/main/java/org/opensearch/action/search/SearchPhase.java index 50f0940754078..e15151758113b 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhase.java @@ -41,7 +41,7 @@ * * @opensearch.internal */ -abstract class SearchPhase implements CheckedRunnable { +public abstract class SearchPhase implements CheckedRunnable { private final String name; protected SearchPhase(String name) { @@ -54,4 +54,25 @@ protected SearchPhase(String name) { public String getName() { return name; } + + 
public SearchPhaseName getSearchPhaseName() { + return SearchPhaseName.valueOf(name); + } + + /** + * Enum for different Search Phases in OpenSearch + * @opensearch.internal + */ + public enum SearchPhaseName { + QUERY("query"), + FETCH("fetch"), + DFS_QUERY("dfs_query"), + EXPAND("expand"); + + private final String name; + + SearchPhaseName(final String name) { + this.name = name; + } + } } diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java index be364fbcb9c84..4ffd5521793f6 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java @@ -50,7 +50,7 @@ * * @opensearch.internal */ -interface SearchPhaseContext extends Executor { +public interface SearchPhaseContext extends Executor { // TODO maybe we can make this concrete later - for now we just implement this in the base class for all initial phases /** diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java index 1baea0e721c44..cae439bc46b4b 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java @@ -42,7 +42,7 @@ * * @opensearch.internal */ -abstract class SearchPhaseResults { +public abstract class SearchPhaseResults { private final int numShards; SearchPhaseResults(int numShards) { @@ -75,7 +75,7 @@ final int getNumShards() { void consumeShardFailure(int shardIndex) {} - AtomicArray getAtomicArray() { + public AtomicArray getAtomicArray() { throw new UnsupportedOperationException(); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java index 
1ead14aac6b51..2aaa1d788c5bc 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -42,6 +42,7 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.transport.Transport; @@ -81,7 +82,8 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction fetchResults ) { - return new SearchPhase("fetch") { + return new SearchPhase(SearchPhase.SearchPhaseName.FETCH.name()) { @Override public void run() throws IOException { sendResponse(queryPhase, fetchResults); diff --git a/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java index 4119cb1cf28a0..b74509bd65f77 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java @@ -92,7 +92,7 @@ protected void executeInitialPhase( @Override protected SearchPhase moveToNextPhase(BiFunction clusterNodeLookup) { - return new SearchPhase("fetch") { + return new SearchPhase(SearchPhase.SearchPhaseName.FETCH.name()) { @Override public void run() { final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedScrollQueryPhase( diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 221022fcfea80..88cebaa3edd89 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ 
b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -345,7 +345,8 @@ public AbstractSearchAsyncAction asyncSearchAction( task, new ArraySearchPhaseResults<>(shardsIts.size()), searchRequest.getMaxConcurrentShardRequests(), - clusters + clusters, + searchPipelineService ) { @Override protected void executePhaseOnShard( @@ -1161,7 +1162,8 @@ public void run() { } }; }, - clusters + clusters, + searchPipelineService ); } else { final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults( @@ -1191,7 +1193,8 @@ public void run() { timeProvider, clusterState, task, - clusters + clusters, + searchPipelineService ); break; case QUERY_THEN_FETCH: @@ -1211,7 +1214,8 @@ public void run() { timeProvider, clusterState, task, - clusters + clusters, + searchPipelineService ); break; default: diff --git a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java index dd7e88c86f1e6..736ad3cec5bf6 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java +++ b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java @@ -9,10 +9,14 @@ package org.opensearch.search.pipeline; import org.opensearch.OpenSearchParseException; +import org.opensearch.action.search.SearchPhase; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.common.Nullable; import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.search.SearchPhaseResult; import java.util.ArrayList; import java.util.Arrays; @@ -30,6 +34,7 @@ class Pipeline { public static final String REQUEST_PROCESSORS_KEY = "request_processors"; public static final String RESPONSE_PROCESSORS_KEY = "response_processors"; + public static final String PHASE_PROCESSORS_KEY = 
"phase_injector_processors"; private final String id; private final String description; private final Integer version; @@ -39,18 +44,22 @@ class Pipeline { private final List searchRequestProcessors; private final List searchResponseProcessors; + private final List searchPhaseInjectorProcessors; + Pipeline( String id, @Nullable String description, @Nullable Integer version, List requestProcessors, - List responseProcessors + List responseProcessors, + List phaseInjectorProcessors ) { this.id = id; this.description = description; this.version = version; this.searchRequestProcessors = requestProcessors; this.searchResponseProcessors = responseProcessors; + this.searchPhaseInjectorProcessors = phaseInjectorProcessors; } public static Pipeline create(String id, Map config, Map processorFactories) @@ -69,11 +78,18 @@ public static Pipeline create(String id, Map config, Map> phaseProcessorConfigs = ConfigurationUtils.readOptionalList(null, null, config, PHASE_PROCESSORS_KEY); List responseProcessors = readProcessors( SearchResponseProcessor.class, processorFactories, responseProcessorConfigs ); + final List phaseProcessors = readProcessors( + SearchPhaseInjectorProcessor.class, + processorFactories, + phaseProcessorConfigs + ); if (config.isEmpty() == false) { throw new OpenSearchParseException( "pipeline [" @@ -82,7 +98,7 @@ public static Pipeline create(String id, Map config, Map List readProcessors( } List flattenAllProcessors() { - List allProcessors = new ArrayList<>(searchRequestProcessors.size() + searchResponseProcessors.size()); + List allProcessors = new ArrayList<>( + searchRequestProcessors.size() + searchResponseProcessors.size() + searchPhaseInjectorProcessors.size() + ); allProcessors.addAll(searchRequestProcessors); + allProcessors.addAll(searchPhaseInjectorProcessors); allProcessors.addAll(searchResponseProcessors); return allProcessors; } @@ -142,6 +161,10 @@ List getSearchResponseProcessors() { return searchResponseProcessors; } + List 
getSearchPhaseInjectorProcessors() { + return searchPhaseInjectorProcessors; + } + SearchRequest transformRequest(SearchRequest request) throws Exception { for (SearchRequestProcessor searchRequestProcessor : searchRequestProcessors) { request = searchRequestProcessor.processRequest(request); @@ -159,4 +182,24 @@ SearchResponse transformResponse(SearchRequest request, SearchResponse response) throw new SearchPipelineProcessingException(e); } } + + SearchPhaseResults runSearchPhaseTransformer( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext context, + SearchPhase.SearchPhaseName currentPhase, + SearchPhase.SearchPhaseName nextPhase + ) throws SearchPipelineProcessingException { + + try { + for (SearchPhaseInjectorProcessor searchPhaseInjectorProcessor : searchPhaseInjectorProcessors) { + if (currentPhase == searchPhaseInjectorProcessor.getBeforePhase() + && nextPhase == searchPhaseInjectorProcessor.getAfterPhase()) { + searchPhaseResult = searchPhaseInjectorProcessor.execute(searchPhaseResult, context); + } + } + return searchPhaseResult; + } catch (Exception e) { + throw new SearchPipelineProcessingException(e); + } + } } diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseInjectorProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseInjectorProcessor.java new file mode 100644 index 0000000000000..0d4f7950596cf --- /dev/null +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseInjectorProcessor.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.pipeline; + +import org.opensearch.action.search.SearchPhase; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.search.SearchPhaseResult; + +/** + * Creates a processor that runs between Phases of the Search. + */ +public interface SearchPhaseInjectorProcessor extends Processor { + SearchPhaseResults execute( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ); + + /** + * The phase that must have completed before this processor can start executing. + * @return {@link SearchPhase.SearchPhaseName} + */ + SearchPhase.SearchPhaseName getBeforePhase(); + + /** + * The phase that runs after this processor has finished executing. + * @return {@link SearchPhase.SearchPhaseName} + */ + SearchPhase.SearchPhaseName getAfterPhase(); + +} diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java index a77523649ec53..73e959ce981e6 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java @@ -16,6 +16,9 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.search.DeleteSearchPipelineRequest; import org.opensearch.action.search.PutSearchPipelineRequest; +import org.opensearch.action.search.SearchPhase; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.master.AcknowledgedResponse; @@ -45,6 +48,7 @@ import org.opensearch.node.ReportingService; import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.script.ScriptService; +import 
org.opensearch.search.SearchPhaseResult; import org.opensearch.threadpool.ThreadPool; import java.util.ArrayList; @@ -366,6 +370,23 @@ public SearchResponse transformResponse(SearchRequest request, SearchResponse se return searchResponse; } + public SearchPhaseResults transformSearchPhase( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final SearchPhase.SearchPhaseName currentPhase, + final SearchPhase.SearchPhaseName nextPhase + ) { + final String pipelineId = searchPhaseContext.getRequest().pipeline(); + if (pipelineId != null) { + PipelineHolder pipeline = pipelines.get(pipelineId); + if (pipeline == null) { + throw new IllegalArgumentException("Pipeline " + pipelineId + " is not defined"); + } + return pipeline.pipeline.runSearchPhaseTransformer(searchPhaseResult, searchPhaseContext, currentPhase, nextPhase); + } + return searchPhaseResult; + } + Map getProcessorFactories() { return processorFactories; } diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index ad2657517df9a..49798854d3bc8 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -34,6 +34,7 @@ import org.junit.After; import org.junit.Before; +import org.mockito.Mock; import org.opensearch.action.ActionListener; import org.opensearch.action.OriginalIndices; import org.opensearch.action.support.IndicesOptions; @@ -52,6 +53,7 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; import 
org.opensearch.transport.Transport; @@ -83,6 +85,8 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase { private final List> resolvedNodes = new ArrayList<>(); private final Set releasedContexts = new CopyOnWriteArraySet<>(); private ExecutorService executor; + @Mock + private SearchPipelineService searchPipelineService; @Before @Override @@ -161,7 +165,8 @@ private AbstractSearchAsyncAction createAction( null, results, request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { @Override protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { diff --git a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java index cb2ac328002c5..4beabebf8c001 100644 --- a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -32,6 +32,7 @@ package org.opensearch.action.search; import org.apache.lucene.util.BytesRef; +import org.mockito.Mock; import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.OriginalIndices; @@ -47,6 +48,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.search.sort.MinAndMax; import org.opensearch.search.sort.SortBuilders; import org.opensearch.search.sort.SortOrder; @@ -72,6 +74,9 @@ public class CanMatchPreFilterSearchPhaseTests extends OpenSearchTestCase { + @Mock + private SearchPipelineService searchPipelineService; + public void testFilterShards() throws InterruptedException { final 
TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( @@ -136,7 +141,8 @@ public void run() throws IOException { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ); canMatchPhase.start(); @@ -227,7 +233,8 @@ public void run() throws IOException { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ); canMatchPhase.start(); @@ -317,7 +324,8 @@ public void sendCanMatch( null, new ArraySearchPhaseResults<>(iter.size()), randomIntBetween(1, 32), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { @Override @@ -344,7 +352,8 @@ protected void executePhaseOnShard( } } }, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ); canMatchPhase.start(); @@ -428,7 +437,8 @@ public void run() { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ); canMatchPhase.start(); @@ -527,7 +537,8 @@ public void run() { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ); canMatchPhase.start(); diff --git a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java index 277f2f1dee0bf..2fb7c5cb82d7f 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java @@ -31,6 +31,7 @@ package org.opensearch.action.search; +import org.mockito.Mock; import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.OriginalIndices; @@ -50,6 +51,7 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import 
org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportException; @@ -78,6 +80,9 @@ public class SearchAsyncActionTests extends OpenSearchTestCase { + @Mock + private SearchPipelineService searchPipelineService; + public void testSkipSearchShards() throws InterruptedException { SearchRequest request = new SearchRequest(); request.allowPartialSearchResults(true); @@ -135,7 +140,8 @@ public void testSkipSearchShards() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { @Override @@ -253,7 +259,8 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { @Override @@ -370,7 +377,8 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { TestSearchResponse response = new TestSearchResponse(); @@ -492,7 +500,8 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { TestSearchResponse response = new TestSearchResponse(); @@ -605,7 +614,8 @@ public void testAllowPartialResults() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), 
request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + searchPipelineService ) { @Override diff --git a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 4a23c4ec18951..aedf61b146a60 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -214,7 +214,8 @@ public void sendExecuteQuery( timeProvider, null, task, - SearchResponse.Clusters.EMPTY + SearchResponse.Clusters.EMPTY, + null ) { @Override protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index 239ec79b91082..fb2238743d203 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -10,13 +10,22 @@ import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.opensearch.OpenSearchParseException; import org.opensearch.ResourceNotFoundException; import org.opensearch.Version; import org.opensearch.action.search.DeleteSearchPipelineRequest; +import org.opensearch.action.search.MockSearchPhaseContext; import org.opensearch.action.search.PutSearchPipelineRequest; +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhase; +import org.opensearch.action.search.SearchPhaseContext; +import 
org.opensearch.action.search.SearchPhaseController; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.action.search.SearchProgressListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; @@ -27,8 +36,11 @@ import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.breaker.CircuitBreaker; +import org.opensearch.common.breaker.NoopCircuitBreaker; import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.xcontent.XContentType; @@ -37,7 +49,10 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; +import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.InternalAggregationTestCase; import org.opensearch.test.MockLogAppender; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -191,6 +206,41 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp } } + private static class FakeSearchPhaseInjectorProcessor extends FakeProcessor implements SearchPhaseInjectorProcessor { + private Consumer querySearchResultConsumer; + + public FakeSearchPhaseInjectorProcessor( + String type, + String tag, + String description, + Consumer querySearchResultConsumer + ) { + super(type, tag, description); + this.querySearchResultConsumer = querySearchResultConsumer; + } + + @Override + public 
SearchPhaseResults execute( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext + ) { + List resultAtomicArray = searchPhaseResult.getAtomicArray().asList(); + // updating the maxScore + resultAtomicArray.forEach(querySearchResultConsumer); + return searchPhaseResult; + } + + @Override + public SearchPhase.SearchPhaseName getBeforePhase() { + return SearchPhase.SearchPhaseName.QUERY; + } + + @Override + public SearchPhase.SearchPhaseName getAfterPhase() { + return SearchPhase.SearchPhaseName.FETCH; + } + } + private SearchPipelineService createWithProcessors() { Map processors = new HashMap<>(); processors.put("scale_request_size", (processorFactories, tag, description, config) -> { @@ -206,6 +256,13 @@ private SearchPipelineService createWithProcessors() { float score = ((Number) config.remove("score")).floatValue(); return new FakeResponseProcessor("fixed_score", tag, description, rsp -> rsp.getHits().forEach(h -> h.score(score))); }); + + processors.put("max_score", (processorFactories, tag, description, config) -> { + final float finalScore = config.containsKey("score") ? 
((Number) config.remove("score")).floatValue() : 100f; + final Consumer querySearchResultConsumer = (result) -> result.queryResult().topDocs().maxScore = finalScore; + return new FakeSearchPhaseInjectorProcessor("max_score", tag, description, querySearchResultConsumer); + }); + return createWithProcessors(processors); } @@ -253,7 +310,8 @@ public void testUpdatePipelines() { new BytesArray( "{ " + "\"request_processors\" : [ { \"scale_request_size\": { \"scale\" : 2 } } ], " - + "\"response_processors\" : [ { \"fixed_score\" : { \"score\" : 1.0 } } ]" + + "\"response_processors\" : [ { \"fixed_score\" : { \"score\" : 1.0 } } ]," + + "\"phase_injector_processors\" : [ { \"max_score\" : { \"score\": 100 } } ]" + "}" ), XContentType.JSON @@ -271,6 +329,11 @@ public void testUpdatePipelines() { "scale_request_size", searchPipelineService.getPipelines().get("_id").pipeline.getSearchRequestProcessors().get(0).getType() ); + assertEquals(1, searchPipelineService.getPipelines().get("_id").pipeline.getSearchPhaseInjectorProcessors().size()); + assertEquals( + "max_score", + searchPipelineService.getPipelines().get("_id").pipeline.getSearchPhaseInjectorProcessors().get(0).getType() + ); assertEquals(1, searchPipelineService.getPipelines().get("_id").pipeline.getSearchResponseProcessors().size()); assertEquals( "fixed_score", @@ -308,6 +371,7 @@ public void testPutPipeline() { assertEquals("empty pipeline", pipeline.pipeline.getDescription()); assertEquals(0, pipeline.pipeline.getSearchRequestProcessors().size()); assertEquals(0, pipeline.pipeline.getSearchResponseProcessors().size()); + assertEquals(0, pipeline.pipeline.getSearchPhaseInjectorProcessors().size()); } public void testPutInvalidPipeline() throws IllegalAccessException { @@ -505,6 +569,93 @@ public void testTransformResponse() throws Exception { ); } + public void testTransformSearchPhase() { + SearchPipelineService searchPipelineService = createWithProcessors(); + SearchPipelineMetadata metadata = new 
SearchPipelineMetadata( + Map.of( + "p1", + new PipelineConfiguration( + "p1", + new BytesArray("{\"phase_injector_processors\" : [ { \"max_score\" : { } } ]}"), + XContentType.JSON + ) + ) + ); + ClusterState clusterState = ClusterState.builder(new ClusterName("_name")).build(); + ClusterState previousState = clusterState; + clusterState = ClusterState.builder(clusterState) + .metadata(Metadata.builder().putCustom(SearchPipelineMetadata.TYPE, metadata)) + .build(); + searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState)); + SearchPhaseController controller = new SearchPhaseController( + writableRegistry(), + s -> InternalAggregationTestCase.emptyReduceContextBuilder() + ); + SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(10); + QueryPhaseResultConsumer searchPhaseResults = new QueryPhaseResultConsumer( + searchPhaseContext.getRequest(), + OpenSearchExecutors.newDirectExecutorService(), + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + controller, + SearchProgressListener.NOOP, + writableRegistry(), + 2, + exc -> {} + ); + + final QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.setShardIndex(1); + querySearchResult.topDocs(new TopDocsAndMaxScore(new TopDocs(null, new ScoreDoc[1]), 1f), null); + searchPhaseResults.consumeResult(querySearchResult, () -> {}); + + // First try without specifying a pipeline, which should be a no-op.
+ SearchPhaseResults notTransformedSearchPhaseResults = searchPipelineService.transformSearchPhase( + searchPhaseResults, + searchPhaseContext, + SearchPhase.SearchPhaseName.QUERY, + SearchPhase.SearchPhaseName.FETCH + ); + assertSame(searchPhaseResults, notTransformedSearchPhaseResults); + + // Now set the pipeline as p1 + searchPhaseContext.getRequest().pipeline("p1"); + + SearchPhaseResults transformed = searchPipelineService.transformSearchPhase( + searchPhaseResults, + searchPhaseContext, + SearchPhase.SearchPhaseName.QUERY, + SearchPhase.SearchPhaseName.FETCH + ); + + List resultAtomicArray = transformed.getAtomicArray().asList(); + assertEquals(1, resultAtomicArray.size()); + // every result should carry the maxScore set by the processor + for (SearchPhaseResult result : resultAtomicArray) { + assertEquals(100f, result.queryResult().topDocs().maxScore, 0); + } + + // Check the processor doesn't run between other phases + SearchPhaseResults notTransformed = searchPipelineService.transformSearchPhase( + searchPhaseResults, + searchPhaseContext, + SearchPhase.SearchPhaseName.DFS_QUERY, + SearchPhase.SearchPhaseName.QUERY + ); + + assertSame(searchPhaseResults, notTransformed); + + searchPhaseContext.getRequest().pipeline("p2"); + expectThrows( + IllegalArgumentException.class, + () -> searchPipelineService.transformSearchPhase( + searchPhaseResults, + searchPhaseContext, + SearchPhase.SearchPhaseName.QUERY, + SearchPhase.SearchPhaseName.FETCH + ) + ); + } + public void testGetPipelines() { // assertEquals(0, SearchPipelineService.innerGetPipelines(null, "p1").size()); @@ -522,16 +673,23 @@ public void testGetPipelines() { "p2", new BytesArray("{\"response_processors\" : [ { \"fixed_score\": { \"score\" : 2 } } ] }"), XContentType.JSON + ), + "p3", + new PipelineConfiguration( + "p3", + new BytesArray("{\"phase_injector_processors\" : [ { \"max_score\" : { } } ]}"), + XContentType.JSON ) ) ); // Return all when no ids specified List pipelines = SearchPipelineService.innerGetPipelines(metadata); -
assertEquals(2, pipelines.size()); + assertEquals(3, pipelines.size()); pipelines.sort(Comparator.comparing(PipelineConfiguration::getId)); assertEquals("p1", pipelines.get(0).getId()); assertEquals("p2", pipelines.get(1).getId()); + assertEquals("p3", pipelines.get(2).getId()); // Get specific pipeline pipelines = SearchPipelineService.innerGetPipelines(metadata, "p1"); @@ -547,17 +705,19 @@ public void testGetPipelines() { // Match all pipelines = SearchPipelineService.innerGetPipelines(metadata, "*"); - assertEquals(2, pipelines.size()); + assertEquals(3, pipelines.size()); pipelines.sort(Comparator.comparing(PipelineConfiguration::getId)); assertEquals("p1", pipelines.get(0).getId()); assertEquals("p2", pipelines.get(1).getId()); + assertEquals("p3", pipelines.get(2).getId()); // Match prefix pipelines = SearchPipelineService.innerGetPipelines(metadata, "p*"); - assertEquals(2, pipelines.size()); + assertEquals(3, pipelines.size()); pipelines.sort(Comparator.comparing(PipelineConfiguration::getId)); assertEquals("p1", pipelines.get(0).getId()); assertEquals("p2", pipelines.get(1).getId()); + assertEquals("p3", pipelines.get(2).getId()); } public void testValidatePipeline() throws Exception { @@ -565,6 +725,7 @@ public void testValidatePipeline() throws Exception { ProcessorInfo reqProcessor = new ProcessorInfo("scale_request_size"); ProcessorInfo rspProcessor = new ProcessorInfo("fixed_score"); + ProcessorInfo injProcessor = new ProcessorInfo("max_score"); DiscoveryNode n1 = new DiscoveryNode("n1", buildNewFakeTransportAddress(), Version.CURRENT); DiscoveryNode n2 = new DiscoveryNode("n2", buildNewFakeTransportAddress(), Version.CURRENT); PutSearchPipelineRequest putRequest = new PutSearchPipelineRequest( @@ -572,7 +733,8 @@ public void testValidatePipeline() throws Exception { new BytesArray( "{" + "\"request_processors\": [{ \"scale_request_size\": { \"scale\" : 2 } }]," - + "\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]" + + 
"\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]," + + "\"phase_injector_processors\" : [ { \"max_score\" : { } } ]" + "}" ), XContentType.JSON @@ -618,9 +780,9 @@ public void testValidatePipeline() throws Exception { searchPipelineService.validatePipeline( Map.of( n1, - new SearchPipelineInfo(List.of(reqProcessor, rspProcessor)), + new SearchPipelineInfo(List.of(reqProcessor, rspProcessor, injProcessor)), n2, - new SearchPipelineInfo(List.of(reqProcessor, rspProcessor)) + new SearchPipelineInfo(List.of(reqProcessor, rspProcessor, injProcessor)) ), putRequest ); @@ -631,5 +793,6 @@ public void testInfo() { SearchPipelineInfo info = searchPipelineService.info(); assertTrue(info.containsProcessor("scale_request_size")); assertTrue(info.containsProcessor("fixed_score")); + assertTrue(info.containsProcessor("max_score")); } }