Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the SearchPhaseResultsProcessor interface in Search Pipeline #7283

Merged
merged 13 commits into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x]
### Added
- [SearchPipeline] Add new search pipeline processor type, SearchPhaseResultsProcessor, that can modify the result of one search phase before starting the next phase. ([#7283](https://github.com/opensearch-project/OpenSearch/pull/7283))
- Add task cancellation monitoring service ([#7642](https://github.com/opensearch-project/OpenSearch/pull/7642))
- Add TokenManager Interface ([#7452](https://github.com/opensearch-project/OpenSearch/pull/7452))
- Add Remote store as a segment replication source ([#7653](https://github.com/opensearch-project/OpenSearch/pull/7653))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ teardown:
{
"script" : {
"lang" : "painless",
"source" : "ctx._source['size'] += 10; ctx._source['from'] -= 1; ctx._source['explain'] = !ctx._source['explain']; ctx._source['version'] = !ctx._source['version']; ctx._source['seq_no_primary_term'] = !ctx._source['seq_no_primary_term']; ctx._source['track_scores'] = !ctx._source['track_scores']; ctx._source['track_total_hits'] = 1; ctx._source['min_score'] -= 0.9; ctx._source['terminate_after'] += 2; ctx._source['profile'] = !ctx._source['profile'];"
"source" : "ctx._source['size'] += 10; ctx._source['from'] = ctx._source['from'] <= 0 ? ctx._source['from'] : ctx._source['from'] - 1 ; ctx._source['explain'] = !ctx._source['explain']; ctx._source['version'] = !ctx._source['version']; ctx._source['seq_no_primary_term'] = !ctx._source['seq_no_primary_term']; ctx._source['track_scores'] = !ctx._source['track_scores']; ctx._source['track_total_hits'] = 1; ctx._source['min_score'] -= 0.9; ctx._source['terminate_after'] += 2; ctx._source['profile'] = !ctx._source['profile'];"
}
}
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.pipeline.PipelinedRequest;
import org.opensearch.transport.Transport;

import java.util.ArrayDeque;
Expand Down Expand Up @@ -696,7 +697,11 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) {
* @see #onShardResult(SearchPhaseResult, SearchShardIterator)
*/
final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim()
    final SearchPhase next = getNextPhase(results, this);
    // Give a pipelined request the chance to transform the just-finished
    // phase's per-shard results before the follow-up phase consumes them.
    if (next != null && request instanceof PipelinedRequest) {
        final PipelinedRequest pipelined = (PipelinedRequest) request;
        pipelined.transformSearchPhaseResults(results, this, getName(), next.getName());
    }
    executeNextPhase(this, next);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ boolean hasResult(int shardIndex) {
}

@Override
AtomicArray<Result> getAtomicArray() {
public AtomicArray<Result> getAtomicArray() {
return results;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
) {
// We set max concurrent shard requests to the number of shards so no throttling happens for can_match requests
super(
"can_match",
SearchPhaseName.CAN_MATCH.getName(),
logger,
searchTransportService,
nodeIdToConnection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ final class DfsQueryPhase extends SearchPhase {
Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
SearchPhaseContext context
) {
super("dfs_query");
super(SearchPhaseName.DFS_QUERY.getName());
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
this.progressListener = context.getTask().getProgressListener();
this.queryResult = queryResult;
this.searchResults = searchResults;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ final class ExpandSearchPhase extends SearchPhase {
private final AtomicArray<SearchPhaseResult> queryResults;

ExpandSearchPhase(SearchPhaseContext context, InternalSearchResponse searchResponse, AtomicArray<SearchPhaseResult> queryResults) {
super("expand");
super(SearchPhaseName.EXPAND.getName());
this.context = context;
this.searchResponse = searchResponse;
this.queryResults = queryResults;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ final class FetchSearchPhase extends SearchPhase {
SearchPhaseContext context,
BiFunction<InternalSearchResponse, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory
) {
super("fetch");
super(SearchPhaseName.FETCH.getName());
if (context.getNumShards() != resultConsumer.getNumShards()) {
throw new IllegalStateException(
"number of shards must match the length of the query results but doesn't:"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.common.CheckedRunnable;

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;

/**
Expand All @@ -54,4 +55,13 @@ protected SearchPhase(String name) {
public String getName() {
return name;
}

/**
 * Resolves this phase's string name to its {@link SearchPhaseName} constant.
 *
 * @return the {@link SearchPhaseName} matching this phase's name
 * @throws IllegalArgumentException if the phase name is not declared in {@link SearchPhaseName}
 */
public SearchPhaseName getSearchPhaseName() {
    final String constantName = name.toUpperCase(Locale.ROOT);
    return SearchPhaseName.valueOf(constantName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
*
* @opensearch.internal
*/
interface SearchPhaseContext extends Executor {
public interface SearchPhaseContext extends Executor {
// TODO maybe we can make this concrete later - for now we just implement this in the base class for all initial phases

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.action.search;

/**
* Enum for different Search Phases in OpenSearch
* @opensearch.internal
*/
public enum SearchPhaseName {
    QUERY("query"),
    FETCH("fetch"),
    DFS_QUERY("dfs_query"),
    EXPAND("expand"),
    CAN_MATCH("can_match");

    /** Lower-case phase name as used by the search phase implementations. */
    private final String phaseName;

    SearchPhaseName(final String phaseName) {
        this.phaseName = phaseName;
    }

    /**
     * Returns the lower-case name of this search phase.
     *
     * @return the phase name string
     */
    public String getName() {
        return phaseName;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
*
* @opensearch.internal
*/
abstract class SearchPhaseResults<Result extends SearchPhaseResult> {
public abstract class SearchPhaseResults<Result extends SearchPhaseResult> {
private final int numShards;

SearchPhaseResults(int numShards) {
Expand Down Expand Up @@ -75,7 +75,13 @@ final int getNumShards() {

void consumeShardFailure(int shardIndex) {}

AtomicArray<Result> getAtomicArray() {
/**
* Returns an {@link AtomicArray} of {@link Result}, which are nothing but the SearchPhaseResults
* for shards. The {@link Result} are of type {@link SearchPhaseResult}
*
* @return an {@link AtomicArray} of {@link Result}
*/
public AtomicArray<Result> getAtomicArray() {
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ protected SearchPhase sendResponsePhase(
SearchPhaseController.ReducedQueryPhase queryPhase,
final AtomicArray<? extends SearchPhaseResult> fetchResults
) {
return new SearchPhase("fetch") {
return new SearchPhase(SearchPhaseName.FETCH.getName()) {
@Override
public void run() throws IOException {
sendResponse(queryPhase, fetchResults);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ protected void executeInitialPhase(

@Override
protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
return new SearchPhase("fetch") {
return new SearchPhase(SearchPhaseName.FETCH.getName()) {
@Override
public void run() {
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedScrollQueryPhase(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,12 @@ private void executeRequest(
relativeStartNanos,
System::nanoTime
);
SearchRequest searchRequest;
PipelinedRequest searchRequest;
ActionListener<SearchResponse> listener;
try {
PipelinedRequest pipelinedRequest = searchPipelineService.resolvePipeline(originalSearchRequest);
searchRequest = pipelinedRequest.transformedRequest();
searchRequest = searchPipelineService.resolvePipeline(originalSearchRequest);
listener = ActionListener.wrap(
r -> originalListener.onResponse(pipelinedRequest.transformResponse(r)),
r -> originalListener.onResponse(searchRequest.transformResponse(r)),
originalListener::onFailure
);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.plugins;

import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

Expand Down Expand Up @@ -42,4 +43,15 @@ default Map<String, Processor.Factory<SearchRequestProcessor>> getRequestProcess
default Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Processor.Parameters parameters) {
return Collections.emptyMap();
}

/**
 * Returns additional search pipeline search phase results processor types added by this plugin.
 *
 * The key of the returned {@link Map} is the unique name for the processor which is specified
 * in pipeline configurations, and the value is a {@link org.opensearch.search.pipeline.Processor.Factory}
 * to create the processor from a given pipeline configuration.
 *
 * @param parameters shared services and configuration available for constructing processors
 * @return map from processor type name to its factory; empty by default so
 *         plugins that add no phase-results processors need not override this
 */
default Map<String, Processor.Factory<SearchPhaseResultsProcessor>> getSearchPhaseResultsProcessors(Processor.Parameters parameters) {
    return Collections.emptyMap();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
package org.opensearch.search.pipeline;

import org.opensearch.OpenSearchParseException;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.Nullable;
Expand All @@ -17,6 +19,7 @@
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.SearchPhaseResult;

import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -35,6 +38,7 @@ class Pipeline {

public static final String REQUEST_PROCESSORS_KEY = "request_processors";
public static final String RESPONSE_PROCESSORS_KEY = "response_processors";
public static final String PHASE_PROCESSORS_KEY = "phase_results_processors";
private final String id;
private final String description;
private final Integer version;
Expand All @@ -43,22 +47,24 @@ class Pipeline {
// Then these can be CompoundProcessors instead of lists.
private final List<SearchRequestProcessor> searchRequestProcessors;
private final List<SearchResponseProcessor> searchResponseProcessors;

private final NamedWriteableRegistry namedWriteableRegistry;
private final List<SearchPhaseResultsProcessor> searchPhaseResultsProcessors;

/**
 * Builds a pipeline from already-resolved processor chains; use the static
 * factory to create a pipeline from a raw configuration map.
 *
 * @param id unique pipeline id
 * @param description optional human-readable description
 * @param version optional pipeline version
 * @param requestProcessors processors applied to the search request
 * @param responseProcessors processors applied to the search response
 * @param phaseResultsProcessors processors applied between search phases
 * @param namedWriteableRegistry registry used for (de)serializing pipeline state
 */
private Pipeline(
    String id,
    @Nullable String description,
    @Nullable Integer version,
    List<SearchRequestProcessor> requestProcessors,
    List<SearchResponseProcessor> responseProcessors,
    List<SearchPhaseResultsProcessor> phaseResultsProcessors,
    NamedWriteableRegistry namedWriteableRegistry
) {
    this.id = id;
    this.description = description;
    this.version = version;
    this.searchRequestProcessors = requestProcessors;
    this.searchResponseProcessors = responseProcessors;
    this.searchPhaseResultsProcessors = phaseResultsProcessors;
    this.namedWriteableRegistry = namedWriteableRegistry;
}

Expand All @@ -67,6 +73,7 @@ static Pipeline create(
Map<String, Object> config,
Map<String, Processor.Factory<SearchRequestProcessor>> requestProcessorFactories,
Map<String, Processor.Factory<SearchResponseProcessor>> responseProcessorFactories,
Map<String, Processor.Factory<SearchPhaseResultsProcessor>> phaseResultsProcessorFactories,
NamedWriteableRegistry namedWriteableRegistry
) throws Exception {
String description = ConfigurationUtils.readOptionalStringProperty(null, null, config, DESCRIPTION_KEY);
Expand All @@ -79,7 +86,19 @@ static Pipeline create(
config,
RESPONSE_PROCESSORS_KEY
);
List<SearchResponseProcessor> responseProcessors = readProcessors(responseProcessorFactories, responseProcessorConfigs);

final List<Map<String, Object>> phaseProcessorConfigs = ConfigurationUtils.readOptionalList(
null,
null,
config,
PHASE_PROCESSORS_KEY
);
final List<SearchResponseProcessor> responseProcessors = readProcessors(responseProcessorFactories, responseProcessorConfigs);
final List<SearchPhaseResultsProcessor> phaseResultsProcessors = readProcessors(
phaseResultsProcessorFactories,
phaseProcessorConfigs
);

if (config.isEmpty() == false) {
throw new OpenSearchParseException(
"pipeline ["
Expand All @@ -88,7 +107,15 @@ static Pipeline create(
+ Arrays.toString(config.keySet().toArray())
);
}
return new Pipeline(id, description, version, requestProcessors, responseProcessors, namedWriteableRegistry);
return new Pipeline(
id,
description,
version,
requestProcessors,
responseProcessors,
phaseResultsProcessors,
namedWriteableRegistry
);
}

private static <T extends Processor> List<T> readProcessors(
Expand Down Expand Up @@ -134,6 +161,10 @@ List<SearchResponseProcessor> getSearchResponseProcessors() {
return searchResponseProcessors;
}

/**
 * Returns the processors this pipeline runs between search phases.
 */
List<SearchPhaseResultsProcessor> getSearchPhaseResultsProcessors() {
    return searchPhaseResultsProcessors;
}

SearchRequest transformRequest(SearchRequest request) throws Exception {
if (searchRequestProcessors.isEmpty() == false) {
try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) {
Expand Down Expand Up @@ -168,6 +199,26 @@ SearchResponse transformResponse(SearchRequest request, SearchResponse response)
0,
Collections.emptyList(),
Collections.emptyList(),
Collections.emptyList(),
null
);

/**
 * Runs every registered {@link SearchPhaseResultsProcessor} whose configured
 * before/after phase pair matches the transition from {@code currentPhase}
 * to {@code nextPhase}, letting it transform the shard-level results.
 *
 * @param searchPhaseResult results produced by the phase that just completed
 * @param context the context of the executing search
 * @param currentPhase name of the phase that just finished
 * @param nextPhase name of the phase about to start
 * @throws SearchPipelineProcessingException if a processor throws a runtime exception
 */
<Result extends SearchPhaseResult> void runSearchPhaseResultsTransformer(
    SearchPhaseResults<Result> searchPhaseResult,
    SearchPhaseContext context,
    String currentPhase,
    String nextPhase
) throws SearchPipelineProcessingException {
    try {
        for (final SearchPhaseResultsProcessor processor : searchPhaseResultsProcessors) {
            // Skip processors not registered for this exact phase transition.
            if (currentPhase.equals(processor.getBeforePhase().getName()) == false) {
                continue;
            }
            if (nextPhase.equals(processor.getAfterPhase().getName()) == false) {
                continue;
            }
            processor.process(searchPhaseResult, context);
        }
    } catch (RuntimeException e) {
        throw new SearchPipelineProcessingException(e);
    }
}
}
Loading