dummy implementation of join on coordinator
Signed-off-by: bowenlan-amzn <[email protected]>
bowenlan-amzn committed Oct 14, 2024
1 parent 130d554 commit 49b1284
Showing 8 changed files with 318 additions and 34 deletions.
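In outline: shard responses now carry Arrow flight tickets (StreamSearchResult), a new ArrowCollector streams projected keyword fields into Arrow vectors batch by batch, and the coordinator's StreamAsyncAction registers a placeholder join stream with the StreamManager, returning that stream's ticket in the search response in place of reduced hits.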
@@ -32,32 +32,31 @@

package org.opensearch.action.search;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.TopFieldDocs;
import org.opensearch.arrow.ArrowStreamProvider;
import org.opensearch.arrow.StreamManager;
import org.opensearch.arrow.StreamTicket;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.routing.GroupShardsIterator;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.builder.Join;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.stream.OSTicket;
import org.opensearch.search.stream.StreamSearchResult;
import org.opensearch.search.suggest.Suggest;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.transport.Transport;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -66,14 +65,44 @@
import java.util.function.BiFunction;

/**
* Async transport action for query then fetch
* Stream at coordinator layer
*
* @opensearch.internal
*/
class StreamAsyncAction extends SearchQueryThenFetchAsyncAction {

public StreamAsyncAction(Logger logger, SearchTransportService searchTransportService, BiFunction<String, String, Transport.Connection> nodeIdToConnection, Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts, Map<String, Set<String>> indexRoutings, SearchPhaseController searchPhaseController, Executor executor, QueryPhaseResultConsumer resultConsumer, SearchRequest request, ActionListener<SearchResponse> listener, GroupShardsIterator<SearchShardIterator> shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters, SearchRequestContext searchRequestContext, Tracer tracer) {
super(logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, resultConsumer, request, listener, shardsIts, timeProvider, clusterState, task, clusters, searchRequestContext, tracer);
private final StreamManager streamManager;
private final Join join;

public StreamAsyncAction(
Logger logger,
SearchTransportService searchTransportService,
BiFunction<String, String, Transport.Connection> nodeIdToConnection,
Map<String, AliasFilter> aliasFilter,
Map<String, Float> concreteIndexBoosts,
Map<String, Set<String>> indexRoutings,
SearchPhaseController searchPhaseController,
Executor executor,
QueryPhaseResultConsumer resultConsumer,
SearchRequest request,
ActionListener<SearchResponse> listener,
GroupShardsIterator<SearchShardIterator> shardsIts,
TransportSearchAction.SearchTimeProvider timeProvider,
ClusterState clusterState,
SearchTask task,
SearchResponse.Clusters clusters,
SearchRequestContext searchRequestContext,
Tracer tracer,
StreamManager streamManager
) {
super(
logger, searchTransportService, nodeIdToConnection, aliasFilter,
concreteIndexBoosts, indexRoutings, searchPhaseController, executor,
resultConsumer, request, listener, shardsIts, timeProvider,
clusterState, task, clusters, searchRequestContext, tracer
);
this.streamManager = streamManager;
this.join = searchRequestContext.getRequest().source().getJoin();
}

@Override
@@ -82,7 +111,8 @@ protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> r
}

class StreamSearchReducePhase extends SearchPhase {
private SearchPhaseContext context;
private final SearchPhaseContext context;

protected StreamSearchReducePhase(String name, SearchPhaseContext context) {
super(name);
this.context = context;
@@ -92,24 +122,62 @@ protected StreamSearchReducePhase(String name, SearchPhaseContext context) {
public void run() {
context.execute(new StreamReduceAction(context, this));
}
};
}

class StreamReduceAction extends AbstractRunnable {
private SearchPhaseContext context;
private final SearchPhaseContext context;
private final SearchPhase phase;

StreamReduceAction(SearchPhaseContext context, SearchPhase phase) {
this.context = context;
this.phase = phase;
}

@Override
protected void doRun() throws Exception {
// Collect the flight tickets returned by each shard; each ticket identifies
// a stream that can be materialized via streamManager.getVectorSchemaRoot(osTicket).
List<OSTicket> tickets = new ArrayList<>();
for (SearchPhaseResult entry : results.getAtomicArray().asList()) {
if (entry instanceof StreamSearchResult) {
tickets.addAll(((StreamSearchResult) entry).getFlightTickets());
}
}
InternalSearchResponse internalSearchResponse = new InternalSearchResponse(SearchHits.empty(), null, null, null, false, false, 1, Collections.emptyList(), tickets);

// Each ticket identifies the shard/table it came from, along with its schema.

// Based on the search request, the join is performed using these tickets:
// it operates on two indexes using the join condition, and the Join clause
// should already carry, or at least reference, the schema of both sides.

StreamTicket joinResult = streamManager.registerStream((allocator) -> new ArrowStreamProvider.Task() {
@Override
public VectorSchemaRoot init(BufferAllocator allocator) {
IntVector docIDVector = new IntVector("docID", allocator);
FieldVector[] vectors = new FieldVector[]{
docIDVector
};
VectorSchemaRoot root = new VectorSchemaRoot(Arrays.asList(vectors));
return root;
}

@Override
public void run(VectorSchemaRoot root, ArrowStreamProvider.FlushSignal flushSignal) {
// TODO perform the actual join; for now emit ten dummy docIDs
IntVector docIDVector = (IntVector) root.getVector("docID");
for (int i = 0; i < 10; i++) {
docIDVector.setSafe(i, i); // setSafe grows the vector's buffer on demand
}
root.setRowCount(10);
flushSignal.awaitConsumption();
}
});

InternalSearchResponse internalSearchResponse = new InternalSearchResponse(SearchHits.empty(), null, null, null, false, false, 1, Collections.emptyList(), List.of(new OSTicket(joinResult.getBytes(), null)));
context.sendSearchResponse(internalSearchResponse, results.getAtomicArray());
}

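The registered task above streams ten hard-coded docIDs as a stand-in for the real join. For illustration, a minimal sketch of what the TODO body might grow into, assuming the two sides have already been materialized into VectorSchemaRoots (leftRoot, rightRoot), the join condition is equality on a single integer key, and init() had set up leftDocID/rightDocID output vectors in place of the single docID column — all assumptions for this sketch, not part of the commit:

    // Build a hash table on the left side's key, then probe it with the right side.
    Map<Integer, List<Integer>> buildSide = new HashMap<>();
    IntVector leftKeys = (IntVector) leftRoot.getVector("joinKey");
    for (int i = 0; i < leftRoot.getRowCount(); i++) {
        if (!leftKeys.isNull(i)) {
            buildSide.computeIfAbsent(leftKeys.get(i), k -> new ArrayList<>()).add(i);
        }
    }
    IntVector rightKeys = (IntVector) rightRoot.getVector("joinKey");
    IntVector leftDocId = (IntVector) root.getVector("leftDocID");
    IntVector rightDocId = (IntVector) root.getVector("rightDocID");
    int row = 0;
    for (int j = 0; j < rightRoot.getRowCount(); j++) {
        if (rightKeys.isNull(j)) continue;
        List<Integer> matches = buildSide.get(rightKeys.get(j));
        if (matches == null) continue;
        for (int leftRow : matches) {
            leftDocId.setSafe(row, leftRow); // emit one matched (left, right) row pair
            rightDocId.setSafe(row, j);
            row++;
        }
    }
    root.setRowCount(row);
    flushSignal.awaitConsumption();

A hash join keeps the build side fully in memory, which fits a coordinator-side reduce; batching the probe side through flushSignal would bound the output buffer the same way ArrowCollector batches shard-side projections.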
@@ -124,7 +124,8 @@
import java.util.stream.StreamSupport;

import static org.opensearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;
import static org.opensearch.action.search.SearchType.*;
import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
import static org.opensearch.action.search.SearchType.QUERY_THEN_FETCH;
import static org.opensearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort;

/**
@@ -1324,7 +1325,8 @@ private AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction
task,
clusters,
searchRequestContext,
tracer
tracer,
searchService.getStreamManager()
);
break;
default:
112 changes: 112 additions & 0 deletions server/src/main/java/org/opensearch/arrow/query/ArrowCollector.java
@@ -0,0 +1,112 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.arrow.query;

import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.FilterCollector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BytesRef;
import org.opensearch.arrow.ArrowStreamProvider;
import org.opensearch.common.annotation.ExperimentalApi;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@ExperimentalApi
public class ArrowCollector extends FilterCollector {

List<ProjectionField> projectionFields;
private final VectorSchemaRoot root;
private final ArrowStreamProvider.FlushSignal flushSignal;
private final int batchSize;

public ArrowCollector(
Collector in,
List<ProjectionField> projectionFields,
VectorSchemaRoot root,
int batchSize,
ArrowStreamProvider.FlushSignal flushSignal
) {
super(in);
this.projectionFields = projectionFields;
this.root = root;
this.batchSize = batchSize;
this.flushSignal = flushSignal;
}

@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {

Map<String, SortedSetDocValues> fieldValues = new HashMap<>();
projectionFields.forEach(field -> {
try {
SortedSetDocValues dv = context.reader().getSortedSetDocValues(field.fieldName);
fieldValues.put(field.fieldName, dv);
} catch (IOException e) {
throw new RuntimeException(e);
}
});

final int[] currentRow = {0};
return new LeafCollector() {

@Override
public void collect(int docId) throws IOException {
// innerLeafCollector.collect(docId);

// read from the lucene field values
for (Map.Entry<String, SortedSetDocValues> entry : fieldValues.entrySet()) {
String field = entry.getKey();
SortedSetDocValues dv = entry.getValue();
VarCharVector vector = (VarCharVector) root.getVector(field);

if (dv.advanceExact(docId)) {
// read this doc's first ordinal and resolve it to the keyword term
BytesRef keyword = dv.lookupOrd(dv.nextOrd());
vector.setSafe(currentRow[0], keyword.bytes, keyword.offset, keyword.length);
}
}

currentRow[0]++;
if (currentRow[0] >= batchSize) {
root.setRowCount(batchSize);
flushSignal.awaitConsumption();
currentRow[0] = 0;
}
}

@Override
public void finish() throws IOException {
if (currentRow[0] > 0) {
root.setRowCount(currentRow[0]);
flushSignal.awaitConsumption();
currentRow[0] = 0;
}
}

@Override
public void setScorer(Scorable scorable) throws IOException {
// innerLeafCollector.setScorer(scorable);
}
};
}

@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE_NO_SCORES;
}

@Override
public void setWeight(Weight weight) {
if (this.in != null) {
this.in.setWeight(weight);
}
}
}
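A minimal usage sketch for the collector, assuming a single projected keyword field "category"; the searcher, query, and flushSignal are presumed to exist in the calling context, and all names are illustrative:

    BufferAllocator allocator = new RootAllocator();
    // One VarCharVector per projected field, named to match ProjectionField.fieldName,
    // since getLeafCollector() looks vectors up by name and casts them to VarCharVector.
    VarCharVector category = new VarCharVector("category", allocator);
    VectorSchemaRoot root = new VectorSchemaRoot(List.<FieldVector>of(category));
    List<ProjectionField> fields = List.of(new ProjectionField("category", "keyword"));
    ArrowCollector collector = new ArrowCollector(null, fields, root, 1000, flushSignal);
    searcher.search(query, collector);
    // Each hit appends the doc's "category" term; every 1000 rows the batch is
    // flushed through flushSignal.awaitConsumption(), and finish() flushes the rest.

Passing null as the wrapped collector is tolerated here because setWeight() already null-checks in and getLeafCollector() never delegates to it.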
@@ -0,0 +1,23 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.arrow.query;

import org.opensearch.common.annotation.ExperimentalApi;

@ExperimentalApi
public class ProjectionField {
public String fieldName;
String type;

public ProjectionField(String fieldName, String type) {
this.fieldName = fieldName;
this.type = type;
}
}
@@ -140,6 +140,13 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
};
}

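/** Returns a copy of {@code originalArray} with {@code newString} appended (equivalent to Arrays.copyOf plus a tail assignment). */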
static String[] addString(String[] originalArray, String newString) {
String[] newArray = new String[originalArray.length + 1];
System.arraycopy(originalArray, 0, newArray, 0, originalArray.length);
newArray[newArray.length - 1] = newString;
return newArray;
}

/**
* Parses the rest request on top of the SearchRequest, preserving values that are not overridden by the rest request.
*
@@ -163,6 +170,10 @@ public static void parseSearchRequest(
searchRequest.source().parseXContent(requestContentParser, true);
}

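// If the request body declares a join, fold the join's index into the request's
// target indices so shard resolution covers both sides (e.g. a search on "logs"
// joined with "users" would target ["logs", "users"]; names illustrative).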
if (searchRequest.source().getJoin() != null) {
searchRequest.indices(addString(searchRequest.indices(), searchRequest.source().getJoin().getIndex()));
}

final int batchedReduceSize = request.paramAsInt("batched_reduce_size", searchRequest.getBatchedReduceSize());
searchRequest.setBatchedReduceSize(batchedReduceSize);
if (request.hasParam("pre_filter_shard_size")) {
@@ -404,6 +404,10 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
private final TaskResourceTrackingService taskResourceTrackingService;
private final StreamManager streamManager;

public StreamManager getStreamManager() {
return streamManager;
}

public SearchService(
ClusterService clusterService,
IndicesService indicesService,
@@ -1761,7 +1765,7 @@ public CanMatchResponse canMatch(ShardSearchRequest request) throws IOException
}

private CanMatchResponse canMatch(ShardSearchRequest request, boolean checkRefreshPending) throws IOException {
assert request.searchType() == SearchType.QUERY_THEN_FETCH : "unexpected search type: " + request.searchType();
assert request.searchType() == SearchType.QUERY_THEN_FETCH || request.searchType() == SearchType.STREAM : "unexpected search type: " + request.searchType();
final ReaderContext readerContext = request.readerId() != null ? findReaderContext(request.readerId(), request) : null;
final Releasable markAsUsed = readerContext != null ? readerContext.markAsUsed(getKeepAlive(request)) : () -> {};
try (Releasable ignored = markAsUsed) {
Expand Down