From 3ba3f7efcc4f2877ee0d3a0596f96da04776ec0a Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Wed, 28 Jun 2023 18:01:17 -0700 Subject: [PATCH] Adding the SearchPhaseResultsProcessor interface in Search Pipeline (#7283) * Initial code for adding the SearchPhaseInjectorProcessor interface in Search Pipeline Signed-off-by: Navneet Verma * Pass PipelinedRequest to SearchAsyncActions We should resolve a search pipeline once at the start of a search request and then propagate that pipeline through the async actions. When completing a search phase, we will then use that pipeline to inject behavior (if applicable). Signed-off-by: Michael Froh * Renamed SearchPhaseInjectorProcessor to SearchPhaseResultsProcessor and fixed the comments Signed-off-by: Navneet Verma * Make PipelinedSearchRequest extend SearchRequest Rather than wrapping a SearchRequest in a PipelinedSearchRequest, changes are less intrusive if we say that a PipelinedSearchRequest "is a" SearchRequest. Signed-off-by: Michael Froh * Revert code change from merge conflict Signed-off-by: Michael Froh * Updated the changelog with more appropiate wording for the change. Signed-off-by: Navneet Verma * Fixed Typos in the code Signed-off-by: Navneet Verma * Fixing comments relating to return of SearchPhaseResults from processor Signed-off-by: Navneet Verma * Moved SearchPhaseName enum in separate class and fixed comments. Signed-off-by: Navneet Verma * Resolve remaining merge conflict Signed-off-by: Michael Froh --------- Signed-off-by: Navneet Verma Signed-off-by: Michael Froh Co-authored-by: Michael Froh Co-authored-by: Andrew Ross Signed-off-by: Shivansh Arora --- CHANGELOG.md | 1 + .../search_pipeline/50_script_processor.yml | 2 +- .../search/AbstractSearchAsyncAction.java | 7 +- .../search/ArraySearchPhaseResults.java | 2 +- .../search/CanMatchPreFilterSearchPhase.java | 2 +- .../action/search/DfsQueryPhase.java | 2 +- .../action/search/ExpandSearchPhase.java | 2 +- .../action/search/FetchSearchPhase.java | 2 +- .../opensearch/action/search/SearchPhase.java | 10 + .../action/search/SearchPhaseContext.java | 2 +- .../action/search/SearchPhaseName.java | 31 +++ .../action/search/SearchPhaseResults.java | 10 +- .../search/SearchScrollAsyncAction.java | 2 +- ...SearchScrollQueryThenFetchAsyncAction.java | 2 +- .../action/search/TransportSearchAction.java | 7 +- .../plugins/SearchPipelinePlugin.java | 12 + .../opensearch/search/pipeline/Pipeline.java | 33 ++- .../search/pipeline/PipelineWithMetrics.java | 24 +- .../search/pipeline/PipelinedRequest.java | 19 +- .../pipeline/SearchPhaseResultsProcessor.java | 47 ++++ .../pipeline/SearchPipelineService.java | 10 +- .../pipeline/SearchPipelineServiceTests.java | 219 ++++++++++++++++-- 22 files changed, 402 insertions(+), 46 deletions(-) create mode 100644 server/src/main/java/org/opensearch/action/search/SearchPhaseName.java create mode 100644 server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java diff --git a/CHANGELOG.md b/CHANGELOG.md index faa64a162cbd8..93b2296475fd1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,6 +80,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x] ### Added +- [SearchPipeline] Add new search pipeline processor type, SearchPhaseResultsProcessor, that can modify the result of one search phase before starting the next phase.([#7283](https://github.com/opensearch-project/OpenSearch/pull/7283)) - Add task cancellation monitoring service ([#7642](https://github.com/opensearch-project/OpenSearch/pull/7642)) - Add TokenManager Interface ([#7452](https://github.com/opensearch-project/OpenSearch/pull/7452)) - Add Remote store as a segment replication source ([#7653](https://github.com/opensearch-project/OpenSearch/pull/7653)) diff --git a/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/50_script_processor.yml b/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/50_script_processor.yml index 9b2dc0c41ff31..9d855e8a1861a 100644 --- a/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/50_script_processor.yml +++ b/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/50_script_processor.yml @@ -39,7 +39,7 @@ teardown: { "script" : { "lang" : "painless", - "source" : "ctx._source['size'] += 10; ctx._source['from'] -= 1; ctx._source['explain'] = !ctx._source['explain']; ctx._source['version'] = !ctx._source['version']; ctx._source['seq_no_primary_term'] = !ctx._source['seq_no_primary_term']; ctx._source['track_scores'] = !ctx._source['track_scores']; ctx._source['track_total_hits'] = 1; ctx._source['min_score'] -= 0.9; ctx._source['terminate_after'] += 2; ctx._source['profile'] = !ctx._source['profile'];" + "source" : "ctx._source['size'] += 10; ctx._source['from'] = ctx._source['from'] <= 0 ? ctx._source['from'] : ctx._source['from'] - 1 ; ctx._source['explain'] = !ctx._source['explain']; ctx._source['version'] = !ctx._source['version']; ctx._source['seq_no_primary_term'] = !ctx._source['seq_no_primary_term']; ctx._source['track_scores'] = !ctx._source['track_scores']; ctx._source['track_total_hits'] = 1; ctx._source['min_score'] -= 0.9; ctx._source['terminate_after'] += 2; ctx._source['profile'] = !ctx._source['profile'];" } } ] diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 969e0edbbc9d6..48fac9e8c8d38 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -57,6 +57,7 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.transport.Transport; import java.util.ArrayDeque; @@ -696,7 +697,11 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) { * @see #onShardResult(SearchPhaseResult, SearchShardIterator) */ final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim() - executeNextPhase(this, getNextPhase(results, this)); + final SearchPhase nextPhase = getNextPhase(results, this); + if (request instanceof PipelinedRequest && nextPhase != null) { + ((PipelinedRequest) request).transformSearchPhaseResults(results, this, this.getName(), nextPhase.getName()); + } + executeNextPhase(this, nextPhase); } @Override diff --git a/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java b/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java index 61c81e6cda97a..653b0e8aedb9d 100644 --- a/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java +++ b/server/src/main/java/org/opensearch/action/search/ArraySearchPhaseResults.java @@ -66,7 +66,7 @@ boolean hasResult(int shardIndex) { } @Override - AtomicArray getAtomicArray() { + public AtomicArray getAtomicArray() { return results; } } diff --git a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java index ec4d45a0a7124..c026c72f77f00 100644 --- a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java @@ -94,7 +94,7 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction, SearchPhase> nextPhaseFactory, SearchPhaseContext context ) { - super("dfs_query"); + super(SearchPhaseName.DFS_QUERY.getName()); this.progressListener = context.getTask().getProgressListener(); this.queryResult = queryResult; this.searchResults = searchResults; diff --git a/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java b/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java index cdefe7c2c1712..618a5620ce093 100644 --- a/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/ExpandSearchPhase.java @@ -62,7 +62,7 @@ final class ExpandSearchPhase extends SearchPhase { private final AtomicArray queryResults; ExpandSearchPhase(SearchPhaseContext context, InternalSearchResponse searchResponse, AtomicArray queryResults) { - super("expand"); + super(SearchPhaseName.EXPAND.getName()); this.context = context; this.searchResponse = searchResponse; this.queryResults = queryResults; diff --git a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java index 31ec896856ce6..85a3d140977bb 100644 --- a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java @@ -92,7 +92,7 @@ final class FetchSearchPhase extends SearchPhase { SearchPhaseContext context, BiFunction, SearchPhase> nextPhaseFactory ) { - super("fetch"); + super(SearchPhaseName.FETCH.getName()); if (context.getNumShards() != resultConsumer.getNumShards()) { throw new IllegalStateException( "number of shards must match the length of the query results but doesn't:" diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhase.java b/server/src/main/java/org/opensearch/action/search/SearchPhase.java index 50f0940754078..50b0cd8e01c1d 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhase.java @@ -34,6 +34,7 @@ import org.opensearch.common.CheckedRunnable; import java.io.IOException; +import java.util.Locale; import java.util.Objects; /** @@ -54,4 +55,13 @@ protected SearchPhase(String name) { public String getName() { return name; } + + /** + * Returns the SearchPhase name as {@link SearchPhaseName}. Exception will come if SearchPhase name is not defined + * in {@link SearchPhaseName} + * @return {@link SearchPhaseName} + */ + public SearchPhaseName getSearchPhaseName() { + return SearchPhaseName.valueOf(name.toUpperCase(Locale.ROOT)); + } } diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java index 04b481249520b..018035f21179b 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java @@ -50,7 +50,7 @@ * * @opensearch.internal */ -interface SearchPhaseContext extends Executor { +public interface SearchPhaseContext extends Executor { // TODO maybe we can make this concrete later - for now we just implement this in the base class for all initial phases /** diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java new file mode 100644 index 0000000000000..b6f842cf2cce1 --- /dev/null +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java @@ -0,0 +1,31 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.search; + +/** + * Enum for different Search Phases in OpenSearch + * @opensearch.internal + */ +public enum SearchPhaseName { + QUERY("query"), + FETCH("fetch"), + DFS_QUERY("dfs_query"), + EXPAND("expand"), + CAN_MATCH("can_match"); + + private final String name; + + SearchPhaseName(final String name) { + this.name = name; + } + + public String getName() { + return name; + } +} diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java index 1baea0e721c44..2e6068b1ecddc 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseResults.java @@ -42,7 +42,7 @@ * * @opensearch.internal */ -abstract class SearchPhaseResults { +public abstract class SearchPhaseResults { private final int numShards; SearchPhaseResults(int numShards) { @@ -75,7 +75,13 @@ final int getNumShards() { void consumeShardFailure(int shardIndex) {} - AtomicArray getAtomicArray() { + /** + * Returns an {@link AtomicArray} of {@link Result}, which are nothing but the SearchPhaseResults + * for shards. The {@link Result} are of type {@link SearchPhaseResult} + * + * @return an {@link AtomicArray} of {@link Result} + */ + public AtomicArray getAtomicArray() { throw new UnsupportedOperationException(); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchScrollAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchScrollAsyncAction.java index 0b477624b15cc..899c7a3c1dabd 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchScrollAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchScrollAsyncAction.java @@ -266,7 +266,7 @@ protected SearchPhase sendResponsePhase( SearchPhaseController.ReducedQueryPhase queryPhase, final AtomicArray fetchResults ) { - return new SearchPhase("fetch") { + return new SearchPhase(SearchPhaseName.FETCH.getName()) { @Override public void run() throws IOException { sendResponse(queryPhase, fetchResults); diff --git a/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java index 4119cb1cf28a0..9c0721ef63ea6 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchScrollQueryThenFetchAsyncAction.java @@ -92,7 +92,7 @@ protected void executeInitialPhase( @Override protected SearchPhase moveToNextPhase(BiFunction clusterNodeLookup) { - return new SearchPhase("fetch") { + return new SearchPhase(SearchPhaseName.FETCH.getName()) { @Override public void run() { final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedScrollQueryPhase( diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 69f529fe1d00c..df2170cbe2af1 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -390,13 +390,12 @@ private void executeRequest( relativeStartNanos, System::nanoTime ); - SearchRequest searchRequest; + PipelinedRequest searchRequest; ActionListener listener; try { - PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(originalSearchRequest); - searchRequest = pipelinedRequest.transformedRequest(); + searchRequest = searchPipelineService.resolvePipeline(originalSearchRequest); listener = ActionListener.wrap( - r -> originalListener.onResponse(pipelinedRequest.transformResponse(r)), + r -> originalListener.onResponse(searchRequest.transformResponse(r)), originalListener::onFailure ); } catch (Exception e) { diff --git a/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java b/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java index b8ceddecd3d20..3d76bab93a60c 100644 --- a/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java +++ b/server/src/main/java/org/opensearch/plugins/SearchPipelinePlugin.java @@ -9,6 +9,7 @@ package org.opensearch.plugins; import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.search.pipeline.SearchResponseProcessor; @@ -42,4 +43,15 @@ default Map> getRequestProcess default Map> getResponseProcessors(Processor.Parameters parameters) { return Collections.emptyMap(); } + + /** + * Returns additional search pipeline search phase results processor types added by this plugin. + * + * The key of the returned {@link Map} is the unique name for the processor which is specified + * in pipeline configurations, and the value is a {@link org.opensearch.search.pipeline.Processor.Factory} + * to create the processor from a given pipeline configuration. + */ + default Map> getSearchPhaseResultsProcessors(Processor.Parameters parameters) { + return Collections.emptyMap(); + } } diff --git a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java index 6f44daf48ed21..92826eee5a4f4 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java +++ b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java @@ -8,6 +8,8 @@ package org.opensearch.search.pipeline; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.common.Nullable; @@ -15,6 +17,7 @@ import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.search.SearchPhaseResult; import java.util.Collections; import java.util.List; @@ -28,6 +31,7 @@ class Pipeline { public static final String REQUEST_PROCESSORS_KEY = "request_processors"; public static final String RESPONSE_PROCESSORS_KEY = "response_processors"; + public static final String PHASE_PROCESSORS_KEY = "phase_results_processors"; private final String id; private final String description; private final Integer version; @@ -36,7 +40,7 @@ class Pipeline { // Then these can be CompoundProcessors instead of lists. private final List searchRequestProcessors; private final List searchResponseProcessors; - + private final List searchPhaseResultsProcessors; private final NamedWriteableRegistry namedWriteableRegistry; private final LongSupplier relativeTimeSupplier; @@ -46,6 +50,7 @@ class Pipeline { @Nullable Integer version, List requestProcessors, List responseProcessors, + List phaseResultsProcessors, NamedWriteableRegistry namedWriteableRegistry, LongSupplier relativeTimeSupplier ) { @@ -54,6 +59,7 @@ class Pipeline { this.version = version; this.searchRequestProcessors = Collections.unmodifiableList(requestProcessors); this.searchResponseProcessors = Collections.unmodifiableList(responseProcessors); + this.searchPhaseResultsProcessors = Collections.unmodifiableList(phaseResultsProcessors); this.namedWriteableRegistry = namedWriteableRegistry; this.relativeTimeSupplier = relativeTimeSupplier; } @@ -78,6 +84,10 @@ List getSearchResponseProcessors() { return searchResponseProcessors; } + List getSearchPhaseResultsProcessors() { + return searchPhaseResultsProcessors; + } + protected void beforeTransformRequest() {} protected void afterTransformRequest(long timeInNanos) {} @@ -168,14 +178,33 @@ SearchResponse transformResponse(SearchRequest request, SearchResponse response) return response; } + void runSearchPhaseResultsTransformer( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext context, + String currentPhase, + String nextPhase + ) throws SearchPipelineProcessingException { + + try { + for (SearchPhaseResultsProcessor searchPhaseResultsProcessor : searchPhaseResultsProcessors) { + if (currentPhase.equals(searchPhaseResultsProcessor.getBeforePhase().getName()) + && nextPhase.equals(searchPhaseResultsProcessor.getAfterPhase().getName())) { + searchPhaseResultsProcessor.process(searchPhaseResult, context); + } + } + } catch (RuntimeException e) { + throw new SearchPipelineProcessingException(e); + } + } + static final Pipeline NO_OP_PIPELINE = new Pipeline( SearchPipelineService.NOOP_PIPELINE_ID, "Pipeline that does not transform anything", 0, Collections.emptyList(), Collections.emptyList(), + Collections.emptyList(), null, () -> 0L ); - } diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelineWithMetrics.java b/server/src/main/java/org/opensearch/search/pipeline/PipelineWithMetrics.java index 662473f190006..612e979e56070 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelineWithMetrics.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelineWithMetrics.java @@ -43,12 +43,22 @@ class PipelineWithMetrics extends Pipeline { Integer version, List requestProcessors, List responseProcessors, + List phaseResultsProcessors, NamedWriteableRegistry namedWriteableRegistry, OperationMetrics totalRequestMetrics, OperationMetrics totalResponseMetrics, LongSupplier relativeTimeSupplier ) { - super(id, description, version, requestProcessors, responseProcessors, namedWriteableRegistry, relativeTimeSupplier); + super( + id, + description, + version, + requestProcessors, + responseProcessors, + phaseResultsProcessors, + namedWriteableRegistry, + relativeTimeSupplier + ); this.totalRequestMetrics = totalRequestMetrics; this.totalResponseMetrics = totalResponseMetrics; for (Processor requestProcessor : getSearchRequestProcessors()) { @@ -64,6 +74,7 @@ static PipelineWithMetrics create( Map config, Map> requestProcessorFactories, Map> responseProcessorFactories, + Map> phaseResultsProcessorFactories, NamedWriteableRegistry namedWriteableRegistry, OperationMetrics totalRequestProcessingMetrics, OperationMetrics totalResponseProcessingMetrics @@ -79,6 +90,16 @@ static PipelineWithMetrics create( RESPONSE_PROCESSORS_KEY ); List responseProcessors = readProcessors(responseProcessorFactories, responseProcessorConfigs); + List> phaseResultsProcessorConfigs = ConfigurationUtils.readOptionalList( + null, + null, + config, + PHASE_PROCESSORS_KEY + ); + List phaseResultsProcessors = readProcessors( + phaseResultsProcessorFactories, + phaseResultsProcessorConfigs + ); if (config.isEmpty() == false) { throw new OpenSearchParseException( "pipeline [" @@ -93,6 +114,7 @@ static PipelineWithMetrics create( version, requestProcessors, responseProcessors, + phaseResultsProcessors, namedWriteableRegistry, totalRequestProcessingMetrics, totalResponseProcessingMetrics, diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java index 0cfff013f4021..5a7539808c127 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java @@ -8,29 +8,36 @@ package org.opensearch.search.pipeline; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.search.SearchPhaseResult; /** * Groups a search pipeline based on a request and the request after being transformed by the pipeline. * * @opensearch.internal */ -public final class PipelinedRequest { +public final class PipelinedRequest extends SearchRequest { private final Pipeline pipeline; - private final SearchRequest transformedRequest; PipelinedRequest(Pipeline pipeline, SearchRequest transformedRequest) { + super(transformedRequest); this.pipeline = pipeline; - this.transformedRequest = transformedRequest; } public SearchResponse transformResponse(SearchResponse response) { - return pipeline.transformResponse(transformedRequest, response); + return pipeline.transformResponse(this, response); } - public SearchRequest transformedRequest() { - return transformedRequest; + public void transformSearchPhaseResults( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final String currentPhase, + final String nextPhase + ) { + pipeline.runSearchPhaseResultsTransformer(searchPhaseResult, searchPhaseContext, currentPhase, nextPhase); } // Visible for testing diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java new file mode 100644 index 0000000000000..772dc8758bace --- /dev/null +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline; + +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; + +/** + * Creates a processor that runs between Phases of the Search. + * @opensearch.api + */ +public interface SearchPhaseResultsProcessor extends Processor { + + /** + * Processes the {@link SearchPhaseResults} obtained from a SearchPhase which will be returned to next + * SearchPhase. + * @param searchPhaseResult {@link SearchPhaseResults} + * @param searchPhaseContext {@link SearchContext} + * @param {@link SearchPhaseResult} + */ + void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ); + + /** + * The phase which should have run before, this processor can start executing. + * @return {@link SearchPhaseName} + */ + SearchPhaseName getBeforePhase(); + + /** + * The phase which should run after, this processor execution. + * @return {@link SearchPhaseName} + */ + SearchPhaseName getAfterPhase(); + +} diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java index 434c8fbfacc74..70dc8546a077f 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java @@ -73,6 +73,7 @@ public class SearchPipelineService implements ClusterStateApplier, ReportingServ private final ScriptService scriptService; private final Map> requestProcessorFactories; private final Map> responseProcessorFactories; + private final Map> phaseInjectorProcessorFactories; private volatile Map pipelines = Collections.emptyMap(); private final ThreadPool threadPool; private final List> searchPipelineClusterStateListeners = new CopyOnWriteArrayList<>(); @@ -116,6 +117,10 @@ public SearchPipelineService( ); this.requestProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getRequestProcessors(parameters)); this.responseProcessorFactories = processorFactories(searchPipelinePlugins, p -> p.getResponseProcessors(parameters)); + this.phaseInjectorProcessorFactories = processorFactories( + searchPipelinePlugins, + p -> p.getSearchPhaseResultsProcessors(parameters) + ); putPipelineTaskKey = clusterService.registerClusterManagerTask(ClusterManagerTaskKeys.PUT_SEARCH_PIPELINE_KEY, true); deletePipelineTaskKey = clusterService.registerClusterManagerTask(ClusterManagerTaskKeys.DELETE_SEARCH_PIPELINE_KEY, true); this.isEnabled = isEnabled; @@ -181,6 +186,7 @@ void innerUpdatePipelines(SearchPipelineMetadata newSearchPipelineMetadata) { newConfiguration.getConfigAsMap(), requestProcessorFactories, responseProcessorFactories, + phaseInjectorProcessorFactories, namedWriteableRegistry, totalRequestProcessingMetrics, totalResponseProcessingMetrics @@ -280,6 +286,7 @@ void validatePipeline(Map searchPipelineInfos pipelineConfig, requestProcessorFactories, responseProcessorFactories, + phaseInjectorProcessorFactories, namedWriteableRegistry, new OperationMetrics(), // Use ephemeral metrics for validation new OperationMetrics() @@ -359,7 +366,7 @@ static ClusterState innerDelete(DeleteSearchPipelineRequest request, ClusterStat return newState.build(); } - public PipelinedRequest resolvePipeline(SearchRequest searchRequest) throws Exception { + public PipelinedRequest resolvePipeline(SearchRequest searchRequest) { Pipeline pipeline = Pipeline.NO_OP_PIPELINE; if (isEnabled == false) { @@ -378,6 +385,7 @@ public PipelinedRequest resolvePipeline(SearchRequest searchRequest) throws Exce searchRequest.source().searchPipelineSource(), requestProcessorFactories, responseProcessorFactories, + phaseInjectorProcessorFactories, namedWriteableRegistry, totalRequestProcessingMetrics, totalResponseProcessingMetrics diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index 219dddff40b35..2ac0b2136ddd9 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -10,13 +10,22 @@ import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.opensearch.OpenSearchParseException; import org.opensearch.ResourceNotFoundException; import org.opensearch.Version; import org.opensearch.action.search.DeleteSearchPipelineRequest; +import org.opensearch.action.search.MockSearchPhaseContext; import org.opensearch.action.search.PutSearchPipelineRequest; +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseController; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.action.search.SearchProgressListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; @@ -28,10 +37,14 @@ import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.breaker.CircuitBreaker; +import org.opensearch.common.breaker.NoopCircuitBreaker; import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.metrics.OperationStats; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.AtomicArray; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.IndexSettings; @@ -40,7 +53,10 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; +import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.InternalAggregationTestCase; import org.opensearch.test.MockLogAppender; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -68,6 +84,13 @@ public Map> getRequestProcesso public Map> getResponseProcessors(Processor.Parameters parameters) { return Map.of("bar", (factories, tag, description, config) -> null); } + + @Override + public Map> getSearchPhaseResultsProcessors( + Processor.Parameters parameters + ) { + return Map.of("zoe", (factories, tag, description, config) -> null); + } }; private ThreadPool threadPool; @@ -178,13 +201,13 @@ public void testResolveIndexDefaultPipeline() throws Exception { SearchRequest searchRequest = new SearchRequest("my_index").source(SearchSourceBuilder.searchSource().size(5)); PipelinedRequest pipelinedRequest = service.resolvePipeline(searchRequest); assertEquals("p1", pipelinedRequest.getPipeline().getId()); - assertEquals(10, pipelinedRequest.transformedRequest().source().size()); + assertEquals(10, pipelinedRequest.source().size()); // Bypass the default pipeline searchRequest.pipeline("_none"); pipelinedRequest = service.resolvePipeline(searchRequest); assertEquals("_none", pipelinedRequest.getPipeline().getId()); - assertEquals(5, pipelinedRequest.transformedRequest().source().size()); + assertEquals(5, pipelinedRequest.source().size()); } private static abstract class FakeProcessor implements Processor { @@ -244,6 +267,40 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp } } + private static class FakeSearchPhaseResultsProcessor extends FakeProcessor implements SearchPhaseResultsProcessor { + private Consumer querySearchResultConsumer; + + public FakeSearchPhaseResultsProcessor( + String type, + String tag, + String description, + Consumer querySearchResultConsumer + ) { + super(type, tag, description); + this.querySearchResultConsumer = querySearchResultConsumer; + } + + @Override + public void process( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext + ) { + List resultAtomicArray = searchPhaseResult.getAtomicArray().asList(); + // updating the maxScore + resultAtomicArray.forEach(querySearchResultConsumer); + } + + @Override + public SearchPhaseName getBeforePhase() { + return SearchPhaseName.QUERY; + } + + @Override + public SearchPhaseName getAfterPhase() { + return SearchPhaseName.FETCH; + } + } + private SearchPipelineService createWithProcessors() { Map> requestProcessors = new HashMap<>(); requestProcessors.put("scale_request_size", (processorFactories, tag, description, config) -> { @@ -260,7 +317,15 @@ private SearchPipelineService createWithProcessors() { float score = ((Number) config.remove("score")).floatValue(); return new FakeResponseProcessor("fixed_score", tag, description, rsp -> rsp.getHits().forEach(h -> h.score(score))); }); - return createWithProcessors(requestProcessors, responseProcessors); + + Map> searchPhaseProcessors = new HashMap<>(); + searchPhaseProcessors.put("max_score", (processorFactories, tag, description, config) -> { + final float finalScore = config.containsKey("score") ? ((Number) config.remove("score")).floatValue() : 100f; + final Consumer querySearchResultConsumer = (result) -> result.queryResult().topDocs().maxScore = finalScore; + return new FakeSearchPhaseResultsProcessor("max_score", tag, description, querySearchResultConsumer); + }); + + return createWithProcessors(requestProcessors, responseProcessors, searchPhaseProcessors); } @Override @@ -271,7 +336,8 @@ protected NamedWriteableRegistry writableRegistry() { private SearchPipelineService createWithProcessors( Map> requestProcessors, - Map> responseProcessors + Map> responseProcessors, + Map> phaseProcessors ) { Client client = mock(Client.class); ThreadPool threadPool = mock(ThreadPool.class); @@ -296,6 +362,14 @@ public Map> getRequestProcesso public Map> getResponseProcessors(Processor.Parameters parameters) { return responseProcessors; } + + @Override + public Map> getSearchPhaseResultsProcessors( + Processor.Parameters parameters + ) { + return phaseProcessors; + } + }), client, true @@ -314,7 +388,8 @@ public void testUpdatePipelines() { new BytesArray( "{ " + "\"request_processors\" : [ { \"scale_request_size\": { \"scale\" : 2 } } ], " - + "\"response_processors\" : [ { \"fixed_score\" : { \"score\" : 1.0 } } ]" + + "\"response_processors\" : [ { \"fixed_score\" : { \"score\" : 1.0 } } ]," + + "\"phase_results_processors\" : [ { \"max_score\" : { \"score\": 100 } } ]" + "}" ), XContentType.JSON @@ -332,6 +407,11 @@ public void testUpdatePipelines() { "scale_request_size", searchPipelineService.getPipelines().get("_id").pipeline.getSearchRequestProcessors().get(0).getType() ); + assertEquals(1, searchPipelineService.getPipelines().get("_id").pipeline.getSearchPhaseResultsProcessors().size()); + assertEquals( + "max_score", + searchPipelineService.getPipelines().get("_id").pipeline.getSearchPhaseResultsProcessors().get(0).getType() + ); assertEquals(1, searchPipelineService.getPipelines().get("_id").pipeline.getSearchResponseProcessors().size()); assertEquals( "fixed_score", @@ -369,6 +449,7 @@ public void testPutPipeline() { assertEquals("empty pipeline", pipeline.pipeline.getDescription()); assertEquals(0, pipeline.pipeline.getSearchRequestProcessors().size()); assertEquals(0, pipeline.pipeline.getSearchResponseProcessors().size()); + assertEquals(0, pipeline.pipeline.getSearchPhaseResultsProcessors().size()); } public void testPutInvalidPipeline() throws IllegalAccessException { @@ -506,17 +587,14 @@ public void testTransformRequest() throws Exception { SearchRequest request = new SearchRequest("_index").source(sourceBuilder).pipeline("p1"); PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(request); - SearchRequest transformedRequest = pipelinedRequest.transformedRequest(); - assertEquals(2 * size, transformedRequest.source().size()); + assertEquals(2 * size, pipelinedRequest.source().size()); assertEquals(size, request.source().size()); // This request doesn't specify a pipeline, it doesn't get transformed. request = new SearchRequest("_index").source(sourceBuilder); pipelinedRequest = searchPipelineService.resolvePipeline(request); - SearchRequest notTransformedRequest = pipelinedRequest.transformedRequest(); - assertEquals(size, notTransformedRequest.source().size()); - assertSame(request, notTransformedRequest); + assertEquals(size, pipelinedRequest.source().size()); } public void testTransformResponse() throws Exception { @@ -565,6 +643,89 @@ public void testTransformResponse() throws Exception { } } + public void testTransformSearchPhase() { + SearchPipelineService searchPipelineService = createWithProcessors(); + SearchPipelineMetadata metadata = new SearchPipelineMetadata( + Map.of( + "p1", + new PipelineConfiguration( + "p1", + new BytesArray("{\"phase_results_processors\" : [ { \"max_score\" : { } } ]}"), + XContentType.JSON + ) + ) + ); + ClusterState clusterState = ClusterState.builder(new ClusterName("_name")).build(); + ClusterState previousState = clusterState; + clusterState = ClusterState.builder(clusterState) + .metadata(Metadata.builder().putCustom(SearchPipelineMetadata.TYPE, metadata)) + .build(); + searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState)); + SearchPhaseController controller = new SearchPhaseController( + writableRegistry(), + s -> InternalAggregationTestCase.emptyReduceContextBuilder() + ); + SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(10); + QueryPhaseResultConsumer searchPhaseResults = new QueryPhaseResultConsumer( + searchPhaseContext.getRequest(), + OpenSearchExecutors.newDirectExecutorService(), + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + controller, + SearchProgressListener.NOOP, + writableRegistry(), + 2, + exc -> {} + ); + + final QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.setShardIndex(1); + querySearchResult.topDocs(new TopDocsAndMaxScore(new TopDocs(null, new ScoreDoc[1]), 1f), null); + searchPhaseResults.consumeResult(querySearchResult, () -> {}); + + // First try without specifying a pipeline, which should be a no-op. + SearchRequest searchRequest = new SearchRequest(); + PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); + AtomicArray notTransformedSearchPhaseResults = searchPhaseResults.getAtomicArray(); + pipelinedRequest.transformSearchPhaseResults( + searchPhaseResults, + searchPhaseContext, + SearchPhaseName.QUERY.getName(), + SearchPhaseName.FETCH.getName() + ); + assertSame(searchPhaseResults.getAtomicArray(), notTransformedSearchPhaseResults); + + // Now set the pipeline as p1 + searchRequest = new SearchRequest().pipeline("p1"); + pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); + + pipelinedRequest.transformSearchPhaseResults( + searchPhaseResults, + searchPhaseContext, + SearchPhaseName.QUERY.getName(), + SearchPhaseName.FETCH.getName() + ); + + List resultAtomicArray = searchPhaseResults.getAtomicArray().asList(); + assertEquals(1, resultAtomicArray.size()); + // updating the maxScore + for (SearchPhaseResult result : resultAtomicArray) { + assertEquals(100f, result.queryResult().topDocs().maxScore, 0); + } + + // Check Processor doesn't run for between other phases + searchRequest = new SearchRequest().pipeline("p1"); + pipelinedRequest = searchPipelineService.resolvePipeline(searchRequest); + AtomicArray notTransformedSearchPhaseResult = searchPhaseResults.getAtomicArray(); + pipelinedRequest.transformSearchPhaseResults( + searchPhaseResults, + searchPhaseContext, + SearchPhaseName.DFS_QUERY.getName(), + SearchPhaseName.QUERY.getName() + ); + + assertSame(searchPhaseResults.getAtomicArray(), notTransformedSearchPhaseResult); + } + public void testGetPipelines() { // assertEquals(0, SearchPipelineService.innerGetPipelines(null, "p1").size()); @@ -582,16 +743,23 @@ public void testGetPipelines() { "p2", new BytesArray("{\"response_processors\" : [ { \"fixed_score\": { \"score\" : 2 } } ] }"), XContentType.JSON + ), + "p3", + new PipelineConfiguration( + "p3", + new BytesArray("{\"phase_results_processors\" : [ { \"max_score\" : { } } ]}"), + XContentType.JSON ) ) ); // Return all when no ids specified List pipelines = SearchPipelineService.innerGetPipelines(metadata); - assertEquals(2, pipelines.size()); + assertEquals(3, pipelines.size()); pipelines.sort(Comparator.comparing(PipelineConfiguration::getId)); assertEquals("p1", pipelines.get(0).getId()); assertEquals("p2", pipelines.get(1).getId()); + assertEquals("p3", pipelines.get(2).getId()); // Get specific pipeline pipelines = SearchPipelineService.innerGetPipelines(metadata, "p1"); @@ -607,17 +775,19 @@ public void testGetPipelines() { // Match all pipelines = SearchPipelineService.innerGetPipelines(metadata, "*"); - assertEquals(2, pipelines.size()); + assertEquals(3, pipelines.size()); pipelines.sort(Comparator.comparing(PipelineConfiguration::getId)); assertEquals("p1", pipelines.get(0).getId()); assertEquals("p2", pipelines.get(1).getId()); + assertEquals("p3", pipelines.get(2).getId()); // Match prefix pipelines = SearchPipelineService.innerGetPipelines(metadata, "p*"); - assertEquals(2, pipelines.size()); + assertEquals(3, pipelines.size()); pipelines.sort(Comparator.comparing(PipelineConfiguration::getId)); assertEquals("p1", pipelines.get(0).getId()); assertEquals("p2", pipelines.get(1).getId()); + assertEquals("p3", pipelines.get(2).getId()); } public void testValidatePipeline() throws Exception { @@ -625,6 +795,7 @@ public void testValidatePipeline() throws Exception { ProcessorInfo reqProcessor = new ProcessorInfo("scale_request_size"); ProcessorInfo rspProcessor = new ProcessorInfo("fixed_score"); + ProcessorInfo injProcessor = new ProcessorInfo("max_score"); DiscoveryNode n1 = new DiscoveryNode("n1", buildNewFakeTransportAddress(), Version.CURRENT); DiscoveryNode n2 = new DiscoveryNode("n2", buildNewFakeTransportAddress(), Version.CURRENT); PutSearchPipelineRequest putRequest = new PutSearchPipelineRequest( @@ -632,7 +803,8 @@ public void testValidatePipeline() throws Exception { new BytesArray( "{" + "\"request_processors\": [{ \"scale_request_size\": { \"scale\" : 2 } }]," - + "\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]" + + "\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]," + + "\"phase_results_processors\" : [ { \"max_score\" : { } } ]" + "}" ), XContentType.JSON @@ -699,8 +871,7 @@ public void testInlinePipeline() throws Exception { assertEquals(1, pipeline.getSearchResponseProcessors().size()); // Verify that pipeline transforms request - SearchRequest transformedRequest = pipelinedRequest.transformedRequest(); - assertEquals(200, transformedRequest.source().size()); + assertEquals(200, pipelinedRequest.source().size()); int size = 10; SearchHit[] hits = new SearchHit[size]; @@ -730,7 +901,7 @@ public void testExceptionOnPipelineCreation() { "bad_factory", (pf, t, f, c) -> { throw new RuntimeException(); } ); - SearchPipelineService searchPipelineService = createWithProcessors(badFactory, Collections.emptyMap()); + SearchPipelineService searchPipelineService = createWithProcessors(badFactory, Collections.emptyMap(), Collections.emptyMap()); Map pipelineSourceMap = new HashMap<>(); pipelineSourceMap.put(Pipeline.REQUEST_PROCESSORS_KEY, List.of(Map.of("bad_factory", Collections.emptyMap()))); @@ -752,7 +923,11 @@ public void testExceptionOnRequestProcessing() { (pf, t, f, c) -> throwingRequestProcessor ); - SearchPipelineService searchPipelineService = createWithProcessors(throwingRequestProcessorFactory, Collections.emptyMap()); + SearchPipelineService searchPipelineService = createWithProcessors( + throwingRequestProcessorFactory, + Collections.emptyMap(), + Collections.emptyMap() + ); Map pipelineSourceMap = new HashMap<>(); pipelineSourceMap.put(Pipeline.REQUEST_PROCESSORS_KEY, List.of(Map.of("throwing_request", Collections.emptyMap()))); @@ -773,7 +948,11 @@ public void testExceptionOnResponseProcessing() throws Exception { (pf, t, f, c) -> throwingResponseProcessor ); - SearchPipelineService searchPipelineService = createWithProcessors(Collections.emptyMap(), throwingResponseProcessorFactory); + SearchPipelineService searchPipelineService = createWithProcessors( + Collections.emptyMap(), + throwingResponseProcessorFactory, + Collections.emptyMap() + ); Map pipelineSourceMap = new HashMap<>(); pipelineSourceMap.put(Pipeline.RESPONSE_PROCESSORS_KEY, List.of(Map.of("throwing_response", Collections.emptyMap()))); @@ -807,7 +986,7 @@ public void testStats() throws Exception { "throwing_response", (pf, t, f, c) -> throwingResponseProcessor ); - SearchPipelineService searchPipelineService = createWithProcessors(requestProcessors, responseProcessors); + SearchPipelineService searchPipelineService = createWithProcessors(requestProcessors, responseProcessors, Collections.emptyMap()); SearchPipelineMetadata metadata = new SearchPipelineMetadata( Map.of(