Skip to content

Commit

Permalink
Working end-to-end _search_protobuf API with requests, responses and …
Browse files Browse the repository at this point in the history
…node-to-node communication in protobuf

Signed-off-by: Vacha Shah <[email protected]>
  • Loading branch information
VachaShah committed Nov 18, 2023
1 parent 3153068 commit fc414ed
Show file tree
Hide file tree
Showing 21 changed files with 432 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ private boolean isCollapseRequest() {

@Override
public void run() {
System.out.println("ExpandSearchPhase run");
System.out.println("Is collapse request: " + isCollapseRequest());
if (isCollapseRequest() && searchResponse.hits().getHits().length > 0) {
SearchRequest searchRequest = context.getRequest();
CollapseBuilder collapseBuilder = searchRequest.source().collapse();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ private GroupShardsIterator<SearchShardIterator> getIterator(
// to produce a valid search result with all the aggs etc.
possibleMatches.set(0);
}
SearchSourceBuilder source = getRequest().source();
SearchSourceBuilder source = getProtobufRequest().source();
int i = 0;
for (SearchShardIterator iter : shardsIts) {
if (possibleMatches.get(i++)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.action.search;

import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.InternalSearchResponse;

/**
* This search phase is an optional phase that will be executed once all hits are fetched from the shards that executes
* field-collapsing on the inner hits. This phase only executes if field collapsing is requested in the search request and otherwise
* forwards to the next phase immediately.
*
* @opensearch.internal
*/
final class ProtobufExpandSearchPhase extends SearchPhase {
private final SearchPhaseContext context;
private final InternalSearchResponse searchResponse;
private final AtomicArray<SearchPhaseResult> queryResults;

ProtobufExpandSearchPhase(SearchPhaseContext context, InternalSearchResponse searchResponse, AtomicArray<SearchPhaseResult> queryResults) {
super(SearchPhaseName.EXPAND.getName());
this.context = context;
this.searchResponse = searchResponse;
this.queryResults = queryResults;
}

@Override
public void run() {
context.sendSearchResponse(searchResponse, queryResults);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ final class ProtobufFetchSearchPhase extends SearchPhase {
searchPhaseController,
aggregatedDfs,
context,
(response, queryPhaseResults) -> new ExpandSearchPhase(context, response, queryPhaseResults)
(response, queryPhaseResults) -> new ProtobufExpandSearchPhase(context, response, queryPhaseResults)
);
}

Expand Down Expand Up @@ -108,7 +108,7 @@ public void onFailure(Exception e) {

private void innerRun() throws Exception {
final int numShards = context.getNumShards();
final boolean isScrollSearch = context.getRequest().scroll() != null;
final boolean isScrollSearch = context.getProtobufRequest().scroll() != null;
final List<SearchPhaseResult> phaseResults = queryResults.asList();
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = resultConsumer.reduce();
final boolean queryAndFetchOptimization = queryResults.length() == 1;
Expand Down Expand Up @@ -249,7 +249,7 @@ public void onFailure(Exception e) {
private void releaseIrrelevantSearchContext(QuerySearchResult queryResult) {
// we only release search context that we did not fetch from, if we are not scrolling
// or using a PIT and if it has at least one hit that didn't make it to the global topDocs
if (queryResult.hasSearchContext() && context.getRequest().scroll() == null && context.getRequest().pointInTimeBuilder() == null) {
if (queryResult.hasSearchContext() && context.getProtobufRequest().scroll() == null && context.getProtobufRequest().pointInTimeBuilder() == null) {
try {
SearchShardTarget searchShardTarget = queryResult.getSearchShardTarget();
Transport.Connection connection = context.getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId());
Expand All @@ -267,7 +267,7 @@ private void moveToNextPhase(
AtomicArray<? extends SearchPhaseResult> fetchResultsArr
) {
final InternalSearchResponse internalResponse = searchPhaseController.merge(
context.getRequest().scroll() != null,
context.getProtobufRequest().scroll() != null,
reducedQueryPhase,
fetchResultsArr.asList(),
fetchResultsArr::get
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ protected void onShardResult(SearchPhaseResult result, SearchShardIterator shard
QuerySearchResult queryResult = result.queryResult();
if (queryResult.isNull() == false
// disable sort optims for scroll requests because they keep track of the last bottom doc locally (per shard)
&& getRequest().scroll() == null
&& getProtobufRequest().scroll() == null
&& queryResult.topDocs() != null
&& queryResult.topDocs().topDocs.getClass() == TopFieldDocs.class) {
TopFieldDocs topDocs = (TopFieldDocs) queryResult.topDocs().topDocs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,14 @@ public static BiFunction<Transport.Connection, SearchActionListener, ActionListe
@Override
public void onResponse(SearchPhaseResult response) {
if (response instanceof QueryFetchSearchResult) {
response.queryResult().getShardSearchRequest().setOutboundNetworkTime(0);
response.queryResult().getShardSearchRequest().setInboundNetworkTime(0);
if (response.queryResult().getShardSearchRequest() != null) {
response.queryResult().getShardSearchRequest().setOutboundNetworkTime(0);
response.queryResult().getShardSearchRequest().setInboundNetworkTime(0);
} else if (response.queryResult().getProtobufShardSearchRequest() != null) {
response.queryResult().getProtobufShardSearchRequest().setOutboundNetworkTime(0);
response.queryResult().getProtobufShardSearchRequest().setInboundNetworkTime(0);
}

}
QuerySearchResult queryResult = response.queryResult();
if (response.getShardSearchRequest() != null) {
Expand All @@ -86,6 +92,18 @@ public void onResponse(SearchPhaseResult response) {
response.getShardSearchRequest().setOutboundNetworkTime(0);
response.getShardSearchRequest().setInboundNetworkTime(0);
}
} else if (response.getProtobufShardSearchRequest() != null) {
if (response.remoteAddress() != null) {
// update outbound network time for request sent over network for shard requests
response.getProtobufShardSearchRequest()
.setOutboundNetworkTime(
Math.max(0, System.currentTimeMillis() - response.getShardSearchRequest().getOutboundNetworkTime())
);
} else {
// reset inbound and outbound network time to 0 for local request for shard requests
response.getProtobufShardSearchRequest().setOutboundNetworkTime(0);
response.getProtobufShardSearchRequest().setInboundNetworkTime(0);
}
}
if (nodeId != null && queryResult != null) {
final long serviceTimeEWMA = queryResult.serviceTimeEWMA();
Expand Down
3 changes: 2 additions & 1 deletion server/src/main/java/org/opensearch/search/SearchHits.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.rest.action.search.RestSearchAction;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
Expand All @@ -59,7 +60,7 @@
*
* @opensearch.internal
*/
public final class SearchHits implements Writeable, ToXContentFragment, Iterable<SearchHit> {
public final class SearchHits implements Writeable, ToXContentFragment, Iterable<SearchHit>, Serializable {
public static SearchHits empty() {
return empty(true);
}
Expand Down
14 changes: 13 additions & 1 deletion server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ private SearchPhaseResult executeQueryPhaseProtobuf(ProtobufShardSearchRequest r
afterQueryTime = executor.success();
}
if (request.numberOfShards() == 1) {
return executeFetchPhase(readerContext, context, afterQueryTime);
return executeFetchPhaseProtobuf(readerContext, context, afterQueryTime);
} else {
// Pass the rescoreDocIds to the queryResult to send them the coordinating node and receive them back in the fetch phase.
// We also pass the rescoreDocIds to the LegacyReaderContext in case the search state needs to stay in the data node.
Expand Down Expand Up @@ -733,6 +733,18 @@ private QueryFetchSearchResult executeFetchPhase(ReaderContext reader, SearchCon
return new QueryFetchSearchResult(context.queryResult(), context.fetchResult());
}

private QueryFetchSearchResult executeFetchPhaseProtobuf(ReaderContext reader, SearchContext context, long afterQueryTime) {
try (SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(context, true, afterQueryTime)) {
shortcutDocIdsToLoad(context);
fetchPhase.execute(context);
if (reader.singleSession()) {
freeReaderContext(reader.id());
}
executor.success();
}
return new QueryFetchSearchResult(context.queryResult(), context.fetchResult());
}

public void executeQueryPhase(
InternalScrollSearchRequest request,
SearchShardTask task,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,13 @@
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.server.proto.FetchSearchResultProto;
import org.opensearch.server.proto.ShardSearchRequestProto;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.OutputStream;

/**
* Result from a fetch
Expand All @@ -55,6 +60,8 @@ public final class FetchSearchResult extends SearchPhaseResult {
// client side counter
private transient int counter;

private FetchSearchResultProto.FetchSearchResult fetchSearchResultProto;

public FetchSearchResult() {}

public FetchSearchResult(StreamInput in) throws IOException {
Expand All @@ -65,13 +72,23 @@ public FetchSearchResult(StreamInput in) throws IOException {

public FetchSearchResult(byte[] in) throws IOException {
super(in);
contextId = null;
hits = null;
this.fetchSearchResultProto = FetchSearchResultProto.FetchSearchResult.parseFrom(in);
contextId = new ShardSearchContextId(this.fetchSearchResultProto.getContextId().getSessionId(), this.fetchSearchResultProto.getContextId().getId());
ByteArrayInputStream stream = new ByteArrayInputStream(this.fetchSearchResultProto.getHits().toByteArray());
try (ObjectInputStream is = new ObjectInputStream(stream)) {
hits = (SearchHits) is.readObject();
} catch (ClassNotFoundException | IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}

public FetchSearchResult(ShardSearchContextId id, SearchShardTarget shardTarget) {
this.contextId = id;
setSearchShardTarget(shardTarget);
this.fetchSearchResultProto = FetchSearchResultProto.FetchSearchResult.newBuilder()
.setContextId(ShardSearchRequestProto.ShardSearchContextId.newBuilder().setSessionId(id.getSessionId()).setId(id.getId()).build())
.build();
}

@Override
Expand Down Expand Up @@ -114,4 +131,17 @@ public void writeTo(StreamOutput out) throws IOException {
contextId.writeTo(out);
hits.writeTo(out);
}

@Override
public void writeTo(OutputStream out) throws IOException {
out.write(fetchSearchResultProto.toByteArray());
}

public FetchSearchResultProto.FetchSearchResult response() {
return this.fetchSearchResultProto;
}

public FetchSearchResult(FetchSearchResultProto.FetchSearchResult fetchSearchResult) {
this.fetchSearchResultProto = fetchSearchResult;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.server.proto.QueryFetchSearchResultProto;

import java.io.IOException;
import java.io.OutputStream;

/**
* Query fetch result
Expand All @@ -52,6 +54,8 @@ public final class QueryFetchSearchResult extends SearchPhaseResult {
private final QuerySearchResult queryResult;
private final FetchSearchResult fetchResult;

private QueryFetchSearchResultProto.QueryFetchSearchResult queryFetchSearchResult;

public QueryFetchSearchResult(StreamInput in) throws IOException {
super(in);
queryResult = new QuerySearchResult(in);
Expand All @@ -60,13 +64,22 @@ public QueryFetchSearchResult(StreamInput in) throws IOException {

public QueryFetchSearchResult(byte[] in) throws IOException {
super(in);
queryResult = null;
fetchResult = null;
this.queryFetchSearchResult = QueryFetchSearchResultProto.QueryFetchSearchResult.parseFrom(in);
queryResult = new QuerySearchResult(in);
fetchResult = new FetchSearchResult(in);
}

public QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult fetchResult) {
this.queryResult = queryResult;
this.fetchResult = fetchResult;
System.out.println("QueryResult: " + queryResult);
System.out.println("FetchResult: " + fetchResult);
if (queryResult.response() != null && fetchResult.response() != null) {
this.queryFetchSearchResult = QueryFetchSearchResultProto.QueryFetchSearchResult.newBuilder()
.setQueryResult(queryResult.response())
.setFetchResult(fetchResult.response())
.build();
}
}

@Override
Expand Down Expand Up @@ -108,4 +121,19 @@ public void writeTo(StreamOutput out) throws IOException {
queryResult.writeTo(out);
fetchResult.writeTo(out);
}

@Override
public void writeTo(OutputStream out) throws IOException {
out.write(queryFetchSearchResult.toByteArray());
}

public QueryFetchSearchResultProto.QueryFetchSearchResult response() {
return this.queryFetchSearchResult;
}

public QueryFetchSearchResult(QueryFetchSearchResultProto.QueryFetchSearchResult queryFetchSearchResult) {
this.queryFetchSearchResult = queryFetchSearchResult;
this.queryResult = new QuerySearchResult(queryFetchSearchResult.getQueryResult());
this.fetchResult = new FetchSearchResult(queryFetchSearchResult.getFetchResult());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ private ProtobufShardSearchRequest(
// this.keepAlive = keepAlive;
// assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive;

ShardSearchRequestProto.ShardSearchRequest.OriginalIndices originalIndicesProto = ShardSearchRequestProto.ShardSearchRequest.OriginalIndices.newBuilder()
ShardSearchRequestProto.OriginalIndices originalIndicesProto = ShardSearchRequestProto.OriginalIndices.newBuilder()
.addAllIndices(Arrays.stream(originalIndices.indices()).collect(Collectors.toList()))
.setIndicesOptions(ShardSearchRequestProto.ShardSearchRequest.OriginalIndices.IndicesOptions.newBuilder()
.setIndicesOptions(ShardSearchRequestProto.OriginalIndices.IndicesOptions.newBuilder()
.setIgnoreUnavailable(originalIndices.indicesOptions().ignoreUnavailable())
.setAllowNoIndices(originalIndices.indicesOptions().allowNoIndices())
.setExpandWildcardsOpen(originalIndices.indicesOptions().expandWildcardsOpen())
Expand All @@ -230,14 +230,14 @@ private ProtobufShardSearchRequest(
.setIgnoreThrottled(originalIndices.indicesOptions().ignoreThrottled())
.build())
.build();
ShardSearchRequestProto.ShardSearchRequest.ShardId shardIdProto = ShardSearchRequestProto.ShardSearchRequest.ShardId.newBuilder()
ShardSearchRequestProto.ShardId shardIdProto = ShardSearchRequestProto.ShardId.newBuilder()
.setShardId(shardId.getId())
.setHashCode(shardId.hashCode())
.setIndexName(shardId.getIndexName())
.setIndexUUID(shardId.getIndex().getUUID())
.build();

ShardSearchRequestProto.ShardSearchRequest.ShardSearchContextId.Builder shardSearchContextId = ShardSearchRequestProto.ShardSearchRequest.ShardSearchContextId.newBuilder();
ShardSearchRequestProto.ShardSearchContextId.Builder shardSearchContextId = ShardSearchRequestProto.ShardSearchContextId.newBuilder();
System.out.println("Reader id: " + readerId);
if (readerId != null) {
shardSearchContextId.setSessionId(readerId.getSessionId());
Expand Down
Loading

0 comments on commit fc414ed

Please sign in to comment.