Skip to content

Commit

Permalink
Pass PipelinedRequest to SearchAsyncActions
Browse files Browse the repository at this point in the history
We should resolve a search pipeline once at the start of a search
request and then propagate that pipeline through the async actions.

When completing a search phase, we will then use that pipeline to inject
behavior (if applicable).

Signed-off-by: Michael Froh <[email protected]>
  • Loading branch information
msfroh committed Jun 28, 2023
1 parent 4de7035 commit b4bbec6
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@
import org.opensearch.index.shard.ShardId;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.pipeline.PipelinedRequest;
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.transport.Transport;

import java.util.ArrayDeque;
Expand Down Expand Up @@ -90,7 +90,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
private final SearchTransportService searchTransportService;
private final Executor executor;
private final ActionListener<SearchResponse> listener;
private final SearchRequest request;
private final PipelinedRequest request;
/**
* Used by subclasses to resolve node ids to DiscoveryNodes.
**/
Expand Down Expand Up @@ -118,7 +118,6 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
private final boolean throttleConcurrentRequests;

private final List<Releasable> releasables = new ArrayList<>();
private final SearchPipelineService searchPipelineService;

AbstractSearchAsyncAction(
String name,
Expand All @@ -129,16 +128,15 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
Map<String, Float> concreteIndexBoosts,
Map<String, Set<String>> indexRoutings,
Executor executor,
SearchRequest request,
PipelinedRequest request,
ActionListener<SearchResponse> listener,
GroupShardsIterator<SearchShardIterator> shardsIts,
TransportSearchAction.SearchTimeProvider timeProvider,
ClusterState clusterState,
SearchTask task,
SearchPhaseResults<Result> resultConsumer,
int maxConcurrentRequestsPerNode,
SearchResponse.Clusters clusters,
SearchPipelineService searchPipelineService
SearchResponse.Clusters clusters
) {
super(name);
final List<SearchShardIterator> toSkipIterators = new ArrayList<>();
Expand Down Expand Up @@ -174,7 +172,6 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
this.indexRoutings = indexRoutings;
this.results = resultConsumer;
this.clusters = clusters;
this.searchPipelineService = searchPipelineService;
}

@Override
Expand All @@ -200,9 +197,10 @@ public final void start() {
if (getNumShards() == 0) {
// no search shards to search on, bail with empty response
// (it happens with search across _all with no indices around and consistent with broadcast operations)
int trackTotalHitsUpTo = request.source() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO
: request.source().trackTotalHitsUpTo() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO
: request.source().trackTotalHitsUpTo();
SearchSourceBuilder source = request.transformedRequest().source();
int trackTotalHitsUpTo = source == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO
: source.trackTotalHitsUpTo() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO
: source.trackTotalHitsUpTo();
// total hits is null in the response if the tracking of total hits is disabled
boolean withTotalHits = trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED;
listener.onResponse(
Expand All @@ -229,9 +227,10 @@ public final void run() {
assert iterator.skip();
skipShard(iterator);
}
SearchRequest searchRequest = request.transformedRequest();
if (shardsIts.size() > 0) {
assert request.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults";
if (request.allowPartialSearchResults() == false) {
assert searchRequest.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults";
if (searchRequest.allowPartialSearchResults() == false) {
final StringBuilder missingShards = new StringBuilder();
// Fail-fast verification of all shards being available
for (int index = 0; index < shardsIts.size(); index++) {
Expand Down Expand Up @@ -376,7 +375,7 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha
logger.debug(() -> new ParameterizedMessage("All shards failed for phase: [{}]", getName()), cause);
onPhaseFailure(currentPhase, "all shards failed", cause);
} else {
Boolean allowPartialResults = request.allowPartialSearchResults();
Boolean allowPartialResults = request.transformedRequest().allowPartialSearchResults();
assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults";
if (allowPartialResults == false && successfulOps.get() != getNumShards()) {
// check if there are actual failures in the atomic array since
Expand Down Expand Up @@ -612,7 +611,7 @@ public final SearchTask getTask() {

@Override
public final SearchRequest getRequest() {
return request;
return request.transformedRequest();
}

protected final SearchResponse buildSearchResponse(
Expand Down Expand Up @@ -643,19 +642,22 @@ boolean buildPointInTimeFromSearchResults() {
@Override
public void sendSearchResponse(InternalSearchResponse internalSearchResponse, AtomicArray<SearchPhaseResult> queryResults) {
ShardSearchFailure[] failures = buildShardFailures();
Boolean allowPartialResults = request.allowPartialSearchResults();
Boolean allowPartialResults = request.transformedRequest().allowPartialSearchResults();
assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults";
if (allowPartialResults == false && failures.length > 0) {
raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures));
} else {
final Version minNodeVersion = clusterState.nodes().getMinNodeVersion();
final String scrollId = request.scroll() != null ? TransportSearchHelper.buildScrollId(queryResults, minNodeVersion) : null;
final String scrollId = request.transformedRequest().scroll() != null
? TransportSearchHelper.buildScrollId(queryResults, minNodeVersion)
: null;
final String searchContextId;
if (buildPointInTimeFromSearchResults()) {
searchContextId = SearchContextId.encode(queryResults.asList(), aliasFilter, minNodeVersion);
} else {
if (request.source() != null && request.source().pointInTimeBuilder() != null) {
searchContextId = request.source().pointInTimeBuilder().getId();
SearchSourceBuilder source = request.transformedRequest().source();
if (source != null && source.pointInTimeBuilder() != null) {
searchContextId = source.pointInTimeBuilder().getId();
} else {
searchContextId = null;
}
Expand All @@ -677,7 +679,7 @@ public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause)
*/
private void raisePhaseFailure(SearchPhaseExecutionException exception) {
// we don't release persistent readers (point in time).
if (request.pointInTimeBuilder() == null) {
if (request.transformedRequest().pointInTimeBuilder() == null) {
results.getSuccessfulResults().forEach((entry) -> {
if (entry.getContextId() != null) {
try {
Expand Down Expand Up @@ -705,9 +707,7 @@ final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim()
// From src files the next phase is never null, but from tests this is a possibility. Hence, making sure that
// tests pass, we need to do null check on next phase.
if (nextPhase != null) {

final PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(this.getRequest());
pipelinedRequest.transformSearchPhase(results, this, this.getName(), nextPhase.getName());
request.transformSearchPhase(results, this, this.getName(), nextPhase.getName());
}
executeNextPhase(this, nextPhase);
}
Expand Down Expand Up @@ -741,7 +741,7 @@ public final ShardSearchRequest buildShardSearchRequest(SearchShardIterator shar
final String[] routings = indexRoutings.getOrDefault(indexName, Collections.emptySet()).toArray(new String[0]);
ShardSearchRequest shardRequest = new ShardSearchRequest(
shardIt.getOriginalIndices(),
request,
request.transformedRequest(),
shardIt.shardId(),
getNumShards(),
filter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.search.pipeline.PipelinedRequest;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.MinAndMax;
import org.opensearch.search.sort.SortOrder;
Expand Down Expand Up @@ -84,15 +84,14 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
Map<String, Float> concreteIndexBoosts,
Map<String, Set<String>> indexRoutings,
Executor executor,
SearchRequest request,
PipelinedRequest request,
ActionListener<SearchResponse> listener,
GroupShardsIterator<SearchShardIterator> shardsIts,
TransportSearchAction.SearchTimeProvider timeProvider,
ClusterState clusterState,
SearchTask task,
Function<GroupShardsIterator<SearchShardIterator>, SearchPhase> phaseFactory,
SearchResponse.Clusters clusters,
SearchPipelineService searchPipelineService
SearchResponse.Clusters clusters
) {
// We set max concurrent shard requests to the number of shards so no throttling happens for can_match requests
super(
Expand All @@ -112,8 +111,7 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
task,
new CanMatchSearchPhaseResults(shardsIts.size()),
shardsIts.size(),
clusters,
searchPipelineService
clusters
);
this.phaseFactory = phaseFactory;
this.shardsIts = shardsIts;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import org.opensearch.search.dfs.AggregatedDfs;
import org.opensearch.search.dfs.DfsSearchResult;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.search.pipeline.PipelinedRequest;
import org.opensearch.transport.Transport;

import java.util.List;
Expand Down Expand Up @@ -71,14 +71,13 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
final SearchPhaseController searchPhaseController,
final Executor executor,
final QueryPhaseResultConsumer queryPhaseResultConsumer,
final SearchRequest request,
final PipelinedRequest request,
final ActionListener<SearchResponse> listener,
final GroupShardsIterator<SearchShardIterator> shardsIts,
final TransportSearchAction.SearchTimeProvider timeProvider,
final ClusterState clusterState,
final SearchTask task,
SearchResponse.Clusters clusters,
SearchPipelineService searchPipelineService
SearchResponse.Clusters clusters
) {
super(
"dfs",
Expand All @@ -96,14 +95,13 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
clusterState,
task,
new ArraySearchPhaseResults<>(shardsIts.size()),
request.getMaxConcurrentShardRequests(),
clusters,
searchPipelineService
request.transformedRequest().getMaxConcurrentShardRequests(),
clusters
);
this.queryPhaseResultConsumer = queryPhaseResultConsumer;
this.searchPhaseController = searchPhaseController;
SearchProgressListener progressListener = task.getProgressListener();
SearchSourceBuilder sourceBuilder = request.source();
SearchSourceBuilder sourceBuilder = request.transformedRequest().source();
progressListener.notifyListShards(
SearchProgressListener.buildSearchShards(this.shardsIts),
SearchProgressListener.buildSearchShards(toSkipShardsIts),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@
import org.opensearch.cluster.routing.GroupShardsIterator;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.search.pipeline.PipelinedRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.transport.Transport;

Expand Down Expand Up @@ -76,14 +77,13 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
final SearchPhaseController searchPhaseController,
final Executor executor,
final QueryPhaseResultConsumer resultConsumer,
final SearchRequest request,
final PipelinedRequest request,
final ActionListener<SearchResponse> listener,
final GroupShardsIterator<SearchShardIterator> shardsIts,
final TransportSearchAction.SearchTimeProvider timeProvider,
ClusterState clusterState,
SearchTask task,
SearchResponse.Clusters clusters,
SearchPipelineService searchPipelineService
SearchResponse.Clusters clusters
) {
super(
"query",
Expand All @@ -101,20 +101,20 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
clusterState,
task,
resultConsumer,
request.getMaxConcurrentShardRequests(),
clusters,
searchPipelineService
request.transformedRequest().getMaxConcurrentShardRequests(),
clusters
);
this.topDocsSize = SearchPhaseController.getTopDocsSize(request);
this.trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo();
this.topDocsSize = SearchPhaseController.getTopDocsSize(request.transformedRequest());
this.trackTotalHitsUpTo = request.transformedRequest().resolveTrackTotalHitsUpTo();
this.searchPhaseController = searchPhaseController;
this.progressListener = task.getProgressListener();

// register the release of the query consumer to free up the circuit breaker memory
// at the end of the search
addReleasable(resultConsumer);

boolean hasFetchPhase = request.source() == null ? true : request.source().size() > 0;
SearchSourceBuilder source = request.transformedRequest().source();
boolean hasFetchPhase = source == null ? true : source.size() > 0;
progressListener.notifyListShards(
SearchProgressListener.buildSearchShards(this.shardsIts),
SearchProgressListener.buildSearchShards(toSkipShardsIts),
Expand Down
Loading

0 comments on commit b4bbec6

Please sign in to comment.