diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index e9d18e1b6c694..fb478c4860e67 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -53,12 +53,12 @@ import org.opensearch.index.shard.ShardId; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.pipeline.PipelinedRequest; -import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.transport.Transport; import java.util.ArrayDeque; @@ -90,7 +90,7 @@ abstract class AbstractSearchAsyncAction exten private final SearchTransportService searchTransportService; private final Executor executor; private final ActionListener listener; - private final SearchRequest request; + private final PipelinedRequest request; /** * Used by subclasses to resolve node ids to DiscoveryNodes. **/ @@ -118,7 +118,6 @@ abstract class AbstractSearchAsyncAction exten private final boolean throttleConcurrentRequests; private final List releasables = new ArrayList<>(); - private final SearchPipelineService searchPipelineService; AbstractSearchAsyncAction( String name, @@ -129,7 +128,7 @@ abstract class AbstractSearchAsyncAction exten Map concreteIndexBoosts, Map> indexRoutings, Executor executor, - SearchRequest request, + PipelinedRequest request, ActionListener listener, GroupShardsIterator shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, @@ -137,8 +136,7 @@ abstract class AbstractSearchAsyncAction exten SearchTask task, SearchPhaseResults resultConsumer, int maxConcurrentRequestsPerNode, - SearchResponse.Clusters clusters, - SearchPipelineService searchPipelineService + SearchResponse.Clusters clusters ) { super(name); final List toSkipIterators = new ArrayList<>(); @@ -174,7 +172,6 @@ abstract class AbstractSearchAsyncAction exten this.indexRoutings = indexRoutings; this.results = resultConsumer; this.clusters = clusters; - this.searchPipelineService = searchPipelineService; } @Override @@ -200,9 +197,10 @@ public final void start() { if (getNumShards() == 0) { // no search shards to search on, bail with empty response // (it happens with search across _all with no indices around and consistent with broadcast operations) - int trackTotalHitsUpTo = request.source() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO - : request.source().trackTotalHitsUpTo() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO - : request.source().trackTotalHitsUpTo(); + SearchSourceBuilder source = request.transformedRequest().source(); + int trackTotalHitsUpTo = source == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO + : source.trackTotalHitsUpTo() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO + : source.trackTotalHitsUpTo(); // total hits is null in the response if the tracking of total hits is disabled boolean withTotalHits = trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED; listener.onResponse( @@ -229,9 +227,10 @@ public final void run() { assert iterator.skip(); skipShard(iterator); } + SearchRequest searchRequest = request.transformedRequest(); if (shardsIts.size() > 0) { - assert request.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults"; - if (request.allowPartialSearchResults() == false) { + assert searchRequest.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults"; + if (searchRequest.allowPartialSearchResults() == false) { final StringBuilder missingShards = new StringBuilder(); // Fail-fast verification of all shards being available for (int index = 0; index < shardsIts.size(); index++) { @@ -376,7 +375,7 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha logger.debug(() -> new ParameterizedMessage("All shards failed for phase: [{}]", getName()), cause); onPhaseFailure(currentPhase, "all shards failed", cause); } else { - Boolean allowPartialResults = request.allowPartialSearchResults(); + Boolean allowPartialResults = request.transformedRequest().allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; if (allowPartialResults == false && successfulOps.get() != getNumShards()) { // check if there are actual failures in the atomic array since @@ -612,7 +611,7 @@ public final SearchTask getTask() { @Override public final SearchRequest getRequest() { - return request; + return request.transformedRequest(); } protected final SearchResponse buildSearchResponse( @@ -643,19 +642,22 @@ boolean buildPointInTimeFromSearchResults() { @Override public void sendSearchResponse(InternalSearchResponse internalSearchResponse, AtomicArray queryResults) { ShardSearchFailure[] failures = buildShardFailures(); - Boolean allowPartialResults = request.allowPartialSearchResults(); + Boolean allowPartialResults = request.transformedRequest().allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; if (allowPartialResults == false && failures.length > 0) { raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures)); } else { final Version minNodeVersion = clusterState.nodes().getMinNodeVersion(); - final String scrollId = request.scroll() != null ? TransportSearchHelper.buildScrollId(queryResults, minNodeVersion) : null; + final String scrollId = request.transformedRequest().scroll() != null + ? TransportSearchHelper.buildScrollId(queryResults, minNodeVersion) + : null; final String searchContextId; if (buildPointInTimeFromSearchResults()) { searchContextId = SearchContextId.encode(queryResults.asList(), aliasFilter, minNodeVersion); } else { - if (request.source() != null && request.source().pointInTimeBuilder() != null) { - searchContextId = request.source().pointInTimeBuilder().getId(); + SearchSourceBuilder source = request.transformedRequest().source(); + if (source != null && source.pointInTimeBuilder() != null) { + searchContextId = source.pointInTimeBuilder().getId(); } else { searchContextId = null; } @@ -677,7 +679,7 @@ public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) */ private void raisePhaseFailure(SearchPhaseExecutionException exception) { // we don't release persistent readers (point in time). - if (request.pointInTimeBuilder() == null) { + if (request.transformedRequest().pointInTimeBuilder() == null) { results.getSuccessfulResults().forEach((entry) -> { if (entry.getContextId() != null) { try { @@ -705,9 +707,7 @@ final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim() // From src files the next phase is never null, but from tests this is a possibility. Hence, making sure that // tests pass, we need to do null check on next phase. if (nextPhase != null) { - - final PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(this.getRequest()); - pipelinedRequest.transformSearchPhase(results, this, this.getName(), nextPhase.getName()); + request.transformSearchPhase(results, this, this.getName(), nextPhase.getName()); } executeNextPhase(this, nextPhase); } @@ -741,7 +741,7 @@ public final ShardSearchRequest buildShardSearchRequest(SearchShardIterator shar final String[] routings = indexRoutings.getOrDefault(indexName, Collections.emptySet()).toArray(new String[0]); ShardSearchRequest shardRequest = new ShardSearchRequest( shardIt.getOriginalIndices(), - request, + request.transformedRequest(), shardIt.shardId(), getNumShards(), filter, diff --git a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java index 8b153392b03af..4226a814e096b 100644 --- a/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/CanMatchPreFilterSearchPhase.java @@ -41,7 +41,7 @@ import org.opensearch.search.SearchShardTarget; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; -import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.MinAndMax; import org.opensearch.search.sort.SortOrder; @@ -84,15 +84,14 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction concreteIndexBoosts, Map> indexRoutings, Executor executor, - SearchRequest request, + PipelinedRequest request, ActionListener listener, GroupShardsIterator shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState, SearchTask task, Function, SearchPhase> phaseFactory, - SearchResponse.Clusters clusters, - SearchPipelineService searchPipelineService + SearchResponse.Clusters clusters ) { // We set max concurrent shard requests to the number of shards so no throttling happens for can_match requests super( @@ -112,8 +111,7 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, final ClusterState clusterState, final SearchTask task, - SearchResponse.Clusters clusters, - SearchPipelineService searchPipelineService + SearchResponse.Clusters clusters ) { super( "dfs", @@ -96,14 +95,13 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()), - request.getMaxConcurrentShardRequests(), - clusters, - searchPipelineService + request.transformedRequest().getMaxConcurrentShardRequests(), + clusters ); this.queryPhaseResultConsumer = queryPhaseResultConsumer; this.searchPhaseController = searchPhaseController; SearchProgressListener progressListener = task.getProgressListener(); - SearchSourceBuilder sourceBuilder = request.source(); + SearchSourceBuilder sourceBuilder = request.transformedRequest().source(); progressListener.notifyListShards( SearchProgressListener.buildSearchShards(this.shardsIts), SearchProgressListener.buildSearchShards(toSkipShardsIts), diff --git a/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java index 2aaa1d788c5bc..6e72dabfb439b 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -39,10 +39,11 @@ import org.opensearch.cluster.routing.GroupShardsIterator; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.transport.Transport; @@ -76,14 +77,13 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState, SearchTask task, - SearchResponse.Clusters clusters, - SearchPipelineService searchPipelineService + SearchResponse.Clusters clusters ) { super( "query", @@ -101,12 +101,11 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction 0; + SearchSourceBuilder source = request.transformedRequest().source(); + boolean hasFetchPhase = source == null ? true : source.size() > 0; progressListener.notifyListShards( SearchProgressListener.buildSearchShards(this.shardsIts), SearchProgressListener.buildSearchShards(toSkipShardsIts), diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index fe7fc2d7ee383..ef83f0450b21a 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -315,7 +315,7 @@ public void executeRequest( @Override public AbstractSearchAsyncAction asyncSearchAction( SearchTask task, - SearchRequest searchRequest, + PipelinedRequest pipelinedRequest, Executor executor, GroupShardsIterator shardsIts, SearchTimeProvider timeProvider, @@ -338,16 +338,15 @@ public AbstractSearchAsyncAction asyncSearchAction( concreteIndexBoosts, indexRoutings, executor, - searchRequest, + pipelinedRequest, listener, shardsIts, timeProvider, clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()), - searchRequest.getMaxConcurrentShardRequests(), - clusters, - searchPipelineService + pipelinedRequest.transformedRequest().getMaxConcurrentShardRequests(), + clusters ) { @Override protected void executePhaseOnShard( @@ -391,11 +390,10 @@ private void executeRequest( relativeStartNanos, System::nanoTime ); - SearchRequest searchRequest; + PipelinedRequest pipelinedRequest; ActionListener listener; try { - PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(originalSearchRequest); - searchRequest = pipelinedRequest.transformedRequest(); + pipelinedRequest = searchPipelineService.resolvePipeline(originalSearchRequest); listener = ActionListener.wrap( r -> originalListener.onResponse(pipelinedRequest.transformResponse(r)), originalListener::onFailure @@ -404,6 +402,7 @@ private void executeRequest( originalListener.onFailure(e); return; } + SearchRequest searchRequest = pipelinedRequest.transformedRequest(); ActionListener rewriteListener = ActionListener.wrap(source -> { if (source != searchRequest.source()) { @@ -430,7 +429,7 @@ private void executeRequest( executeLocalSearch( task, timeProvider, - searchRequest, + pipelinedRequest, localIndices, clusterState, listener, @@ -440,7 +439,7 @@ private void executeRequest( } else { if (shouldMinimizeRoundtrips(searchRequest)) { ccsRemoteReduce( - searchRequest, + pipelinedRequest, localIndices, remoteClusterIndices, timeProvider, @@ -497,7 +496,7 @@ private void executeRequest( executeSearch( (SearchTask) task, timeProvider, - searchRequest, + pipelinedRequest, localIndices, remoteShardIterators, clusterNodeLookup, @@ -545,7 +544,7 @@ static boolean shouldMinimizeRoundtrips(SearchRequest searchRequest) { } static void ccsRemoteReduce( - SearchRequest searchRequest, + PipelinedRequest pipelinedRequest, OriginalIndices localIndices, Map remoteIndices, SearchTimeProvider timeProvider, @@ -553,7 +552,7 @@ static void ccsRemoteReduce( RemoteClusterService remoteClusterService, ThreadPool threadPool, ActionListener listener, - BiConsumer> localSearchConsumer + BiConsumer> localSearchConsumer ) { if (localIndices == null && remoteIndices.size() == 1) { @@ -564,7 +563,7 @@ static void ccsRemoteReduce( boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias); OriginalIndices indices = entry.getValue(); SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest( - searchRequest, + pipelinedRequest.transformedRequest(), indices.indices(), clusterAlias, timeProvider.getAbsoluteStartMillis(), @@ -613,7 +612,7 @@ public void onFailure(Exception e) { }); } else { SearchResponseMerger searchResponseMerger = createSearchResponseMerger( - searchRequest.source(), + pipelinedRequest.transformedRequest().source(), timeProvider, aggReduceContextBuilder ); @@ -626,7 +625,7 @@ public void onFailure(Exception e) { boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias); OriginalIndices indices = entry.getValue(); SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest( - searchRequest, + pipelinedRequest.transformedRequest(), indices.indices(), clusterAlias, timeProvider.getAbsoluteStartMillis(), @@ -657,13 +656,14 @@ public void onFailure(Exception e) { listener ); SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest( - searchRequest, + pipelinedRequest.transformedRequest(), localIndices.indices(), RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, timeProvider.getAbsoluteStartMillis(), false ); - localSearchConsumer.accept(ccsLocalSearchRequest, ccsListener); + + localSearchConsumer.accept(pipelinedRequest.replaceRequest(ccsLocalSearchRequest), ccsListener); } } } @@ -779,7 +779,7 @@ SearchResponse createFinalResponse() { private void executeLocalSearch( Task task, SearchTimeProvider timeProvider, - SearchRequest searchRequest, + PipelinedRequest pipelinedRequest, OriginalIndices localIndices, ClusterState clusterState, ActionListener listener, @@ -789,7 +789,7 @@ private void executeLocalSearch( executeSearch( (SearchTask) task, timeProvider, - searchRequest, + pipelinedRequest, localIndices, Collections.emptyList(), (clusterName, nodeId) -> null, @@ -907,7 +907,7 @@ private Index[] resolveLocalIndices(OriginalIndices localIndices, ClusterState c private void executeSearch( SearchTask task, SearchTimeProvider timeProvider, - SearchRequest searchRequest, + PipelinedRequest pipelinedRequest, OriginalIndices localIndices, List remoteShardIterators, BiFunction remoteConnections, @@ -929,6 +929,7 @@ private void executeSearch( final Map> indexRoutings; final String[] concreteLocalIndices; + final SearchRequest searchRequest = pipelinedRequest.transformedRequest(); if (searchContext != null) { assert searchRequest.pointInTimeBuilder() != null; aliasFilter = searchContext.aliasFilter(); @@ -1010,7 +1011,7 @@ private void executeSearch( ); searchAsyncActionProvider.asyncSearchAction( task, - searchRequest, + pipelinedRequest, asyncSearchExecutor, shardIterators, timeProvider, @@ -1093,7 +1094,7 @@ static GroupShardsIterator mergeShardsIterators( interface SearchAsyncActionProvider { AbstractSearchAsyncAction asyncSearchAction( SearchTask task, - SearchRequest searchRequest, + PipelinedRequest searchRequest, Executor executor, GroupShardsIterator shardIterators, SearchTimeProvider timeProvider, @@ -1111,7 +1112,7 @@ AbstractSearchAsyncAction asyncSearchAction( private AbstractSearchAsyncAction searchAsyncAction( SearchTask task, - SearchRequest searchRequest, + PipelinedRequest pipelinedRequest, Executor executor, GroupShardsIterator shardIterators, SearchTimeProvider timeProvider, @@ -1134,7 +1135,7 @@ private AbstractSearchAsyncAction searchAsyncAction concreteIndexBoosts, indexRoutings, executor, - searchRequest, + pipelinedRequest, listener, shardIterators, timeProvider, @@ -1143,7 +1144,7 @@ private AbstractSearchAsyncAction searchAsyncAction (iter) -> { AbstractSearchAsyncAction action = searchAsyncAction( task, - searchRequest, + pipelinedRequest, executor, iter, timeProvider, @@ -1164,10 +1165,10 @@ public void run() { } }; }, - clusters, - searchPipelineService + clusters ); } else { + final SearchRequest searchRequest = pipelinedRequest.transformedRequest(); final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults( executor, circuitBreaker, @@ -1189,14 +1190,13 @@ public void run() { searchPhaseController, executor, queryResultConsumer, - searchRequest, + pipelinedRequest, listener, shardIterators, timeProvider, clusterState, task, - clusters, - searchPipelineService + clusters ); break; case QUERY_THEN_FETCH: @@ -1210,14 +1210,13 @@ public void run() { searchPhaseController, executor, queryResultConsumer, - searchRequest, + pipelinedRequest, listener, shardIterators, timeProvider, clusterState, task, - clusters, - searchPipelineService + clusters ); break; default: diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java index e45b510d7c760..966d6ba5a3e9b 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java @@ -49,4 +49,21 @@ public SearchPhaseResults transformSe Pipeline getPipeline() { return pipeline; } + + /** + * Wraps a search request with a no-op pipeline. Useful for testing. + * + * @param searchRequest the original search request + * @return a search request associated with a pipeline that does nothing + */ + public static PipelinedRequest wrapSearchRequest(SearchRequest searchRequest) { + return new PipelinedRequest(Pipeline.NO_OP_PIPELINE, searchRequest); + } + + /** + * Wraps the given search request with this request's pipeline. + */ + public PipelinedRequest replaceRequest(SearchRequest searchRequest) { + return new PipelinedRequest(pipeline, searchRequest); + } } diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index f55ce93d019fe..206b8a571bb5b 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -52,7 +52,7 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; @@ -84,7 +84,6 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase { private final List> resolvedNodes = new ArrayList<>(); private final Set releasedContexts = new CopyOnWriteArraySet<>(); private ExecutorService executor; - private SearchPipelineService searchPipelineService; @Before @Override @@ -155,7 +154,7 @@ private AbstractSearchAsyncAction createAction( Collections.singletonMap("foo", 2.0f), Collections.singletonMap("name", Sets.newHashSet("bar", "baz")), executor, - request, + PipelinedRequest.wrapSearchRequest(request), listener, new GroupShardsIterator<>(Arrays.asList(shards)), timeProvider, @@ -163,8 +162,7 @@ private AbstractSearchAsyncAction createAction( null, results, request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { diff --git a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java index 9876dbdf6f90b..1a743716025d6 100644 --- a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -32,7 +32,6 @@ package org.opensearch.action.search; import org.apache.lucene.util.BytesRef; -import org.mockito.Mockito; import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.OriginalIndices; @@ -48,7 +47,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.sort.MinAndMax; import org.opensearch.search.sort.SortBuilders; import org.opensearch.search.sort.SortOrder; @@ -74,8 +73,6 @@ public class CanMatchPreFilterSearchPhaseTests extends OpenSearchTestCase { - private final SearchPipelineService searchPipelineService = Mockito.mock(SearchPipelineService.class); - public void testFilterShards() throws InterruptedException { final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( @@ -127,7 +124,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, @@ -140,8 +137,7 @@ public void run() throws IOException { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ); canMatchPhase.start(); @@ -219,7 +215,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, @@ -232,8 +228,7 @@ public void run() throws IOException { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ); canMatchPhase.start(); @@ -301,7 +296,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, @@ -315,7 +310,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), executor, - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), responseListener, iter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -323,8 +318,7 @@ public void sendCanMatch( null, new ArraySearchPhaseResults<>(iter.size()), randomIntBetween(1, 32), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override @@ -351,8 +345,7 @@ protected void executePhaseOnShard( } } }, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ); canMatchPhase.start(); @@ -423,7 +416,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, @@ -436,8 +429,7 @@ public void run() { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ); canMatchPhase.start(); @@ -523,7 +515,7 @@ public void sendCanMatch( Collections.emptyMap(), Collections.emptyMap(), OpenSearchExecutors.newDirectExecutorService(), - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, @@ -536,8 +528,7 @@ public void run() { latch.countDown(); } }, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ); canMatchPhase.start(); diff --git a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java index 521778dcbf171..53131d884a60a 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java @@ -51,6 +51,7 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.pipeline.SearchPipelineService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; @@ -131,7 +132,7 @@ public void testSkipSearchShards() throws InterruptedException { Collections.emptyMap(), Collections.emptyMap(), null, - request, + PipelinedRequest.wrapSearchRequest(request), responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -139,8 +140,7 @@ public void testSkipSearchShards() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override @@ -250,7 +250,7 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { Collections.emptyMap(), Collections.emptyMap(), null, - request, + PipelinedRequest.wrapSearchRequest(request), responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -258,8 +258,7 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override @@ -368,7 +367,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI Collections.emptyMap(), Collections.emptyMap(), executor, - request, + PipelinedRequest.wrapSearchRequest(request), responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -376,8 +375,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { TestSearchResponse response = new TestSearchResponse(); @@ -491,7 +489,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI Collections.emptyMap(), Collections.emptyMap(), executor, - request, + PipelinedRequest.wrapSearchRequest(request), responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -499,8 +497,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { TestSearchResponse response = new TestSearchResponse(); @@ -605,7 +602,7 @@ public void testAllowPartialResults() throws InterruptedException { Collections.emptyMap(), Collections.emptyMap(), null, - request, + PipelinedRequest.wrapSearchRequest(request), responseListener, shardsIter, new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), @@ -613,8 +610,7 @@ public void testAllowPartialResults() throws InterruptedException { null, new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override diff --git a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 3649ee554c197..e1bf9244b3a6b 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -37,7 +37,6 @@ import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.grouping.CollapseTopFieldDocs; -import org.mockito.Mockito; import org.opensearch.Version; import org.opensearch.action.OriginalIndices; import org.opensearch.cluster.node.DiscoveryNode; @@ -57,7 +56,7 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.sort.SortBuilders; import org.opensearch.test.OpenSearchTestCase; @@ -77,7 +76,6 @@ import static org.hamcrest.Matchers.instanceOf; public class SearchQueryThenFetchAsyncActionTests extends OpenSearchTestCase { - private final SearchPipelineService searchPipelineService = Mockito.mock(SearchPipelineService.class); public void testBottomFieldSort() throws Exception { testCase(false, false); @@ -212,14 +210,13 @@ public void sendExecuteQuery( controller, executor, resultConsumer, - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), null, shardsIter, timeProvider, null, task, - SearchResponse.Clusters.EMPTY, - searchPipelineService + SearchResponse.Clusters.EMPTY ) { @Override protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { diff --git a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java index 51d9a06c9ac43..96ffb016604f9 100644 --- a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java @@ -74,6 +74,7 @@ import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.PipelinedRequest; import org.opensearch.search.sort.SortBuilders; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.transport.MockTransportService; @@ -458,14 +459,14 @@ public void testCCSRemoteReduceMergeFails() throws Exception { SearchRequest searchRequest = new SearchRequest(); searchRequest.preference("null_target"); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -478,8 +479,8 @@ public void testCCSRemoteReduceMergeFails() throws Exception { if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -514,14 +515,14 @@ public void testCCSRemoteReduce() throws Exception { { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -534,8 +535,8 @@ public void testCCSRemoteReduce() throws Exception { if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -551,14 +552,14 @@ public void testCCSRemoteReduce() throws Exception { SearchRequest searchRequest = new SearchRequest(); searchRequest.preference("index_not_found"); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -571,8 +572,8 @@ public void testCCSRemoteReduce() throws Exception { if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -609,14 +610,14 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference failure = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(r -> fail("no response expected"), failure::set), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -629,8 +630,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -649,14 +650,14 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -669,8 +670,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); } @@ -700,14 +701,14 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti { SearchRequest searchRequest = new SearchRequest(); final CountDownLatch latch = new CountDownLatch(1); - SetOnce>> setOnce = new SetOnce<>(); + SetOnce>> setOnce = new SetOnce<>(); AtomicReference response = new AtomicReference<>(); LatchedActionListener listener = new LatchedActionListener<>( ActionListener.wrap(response::set, e -> fail("no failures expected")), latch ); TransportSearchAction.ccsRemoteReduce( - searchRequest, + PipelinedRequest.wrapSearchRequest(searchRequest), localIndices, remoteIndicesByCluster, timeProvider, @@ -720,8 +721,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti if (localIndices == null) { assertNull(setOnce.get()); } else { - Tuple> tuple = setOnce.get(); - assertEquals("", tuple.v1().getLocalClusterAlias()); + Tuple> tuple = setOnce.get(); + assertEquals("", tuple.v1().transformedRequest().getLocalClusterAlias()); assertThat(tuple.v2(), instanceOf(TransportSearchAction.CCSActionListener.class)); tuple.v2().onResponse(emptySearchResponse()); }