Skip to content

Commit

Permalink
Migrate max workflow search to sdkClient
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 22, 2024
1 parent 03da56c commit d55e694
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import org.apache.logging.log4j.message.ParameterizedMessageFactory;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
Expand All @@ -26,7 +26,6 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand All @@ -42,6 +41,7 @@
import org.opensearch.plugins.PluginsService;
import org.opensearch.remote.metadata.client.GetDataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
Expand Down Expand Up @@ -256,6 +256,7 @@ private void createExecute(WorkflowRequest request, User user, String tenantId,
checkMaxWorkflows(
flowFrameworkSettings.getRequestTimeout(),
flowFrameworkSettings.getMaxWorkflows(),
tenantId,
ActionListener.wrap(max -> {
if (FALSE.equals(max)) {
String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage(
Expand Down Expand Up @@ -500,24 +501,36 @@ private void handleFullDocUpdate(WorkflowRequest request, Template template, Act
* @param maxWorkflow max workflows
* @param internalListener listener for search request
*/
void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, ActionListener<Boolean> internalListener) {
if (!flowFrameworkIndicesHandler.doesIndexExist(CommonValue.GLOBAL_CONTEXT_INDEX)) {
void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, String tenantId, ActionListener<Boolean> internalListener) {
if (!flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX)) {
internalListener.onResponse(true);
} else {
QueryBuilder query = QueryBuilders.matchAllQuery();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeOut);

SearchRequest searchRequest = new SearchRequest(CommonValue.GLOBAL_CONTEXT_INDEX).source(searchSourceBuilder);
SearchDataObjectRequest searchRequest = SearchDataObjectRequest.builder()
.indices(GLOBAL_CONTEXT_INDEX)
.searchSourceBuilder(searchSourceBuilder)
.tenantId(tenantId)
.build();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
logger.info("Querying existing workflows to count the max");
client.search(searchRequest, ActionListener.wrap(searchResponse -> {
context.restore();
internalListener.onResponse(searchResponse.getHits().getTotalHits().value < maxWorkflow);
}, exception -> {
String errorMessage = "Unable to fetch the workflows";
logger.error(errorMessage, exception);
internalListener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)));
}));
sdkClient.searchDataObjectAsync(searchRequest, client.threadPool().executor(WORKFLOW_THREAD_POOL))
.whenComplete((r, throwable) -> {
if (throwable == null) {
context.restore();
try {
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
internalListener.onResponse(searchResponse.getHits().getTotalHits().value < maxWorkflow);
} catch (Exception e) {
logger.error("Failed to parse workflow searchResponse", e);
internalListener.onFailure(e);
}
} else {
Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable);
String errorMessage = "Unable to fetch the workflows";
logger.error(errorMessage, exception);
internalListener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)));
}
});
} catch (Exception e) {
String errorMessage = "Unable to fetch the workflows";
logger.error(errorMessage, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.update.UpdateResponse;
Expand Down Expand Up @@ -48,6 +50,7 @@
import org.opensearch.remote.metadata.client.impl.SdkClientFactory;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ScalingExecutorBuilder;
Expand Down Expand Up @@ -269,23 +272,23 @@ public void testValidation_Failed() throws Exception {
verify(listener, times(1)).onFailure(any());
}

public void testMaxWorkflow() {
public void testMaxWorkflow() throws InterruptedException {
when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(true);

@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false);

doAnswer(invocation -> {
ActionListener<SearchResponse> searchListener = invocation.getArgument(1);
SearchResponse searchResponse = mock(SearchResponse.class);
SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(3, TotalHits.Relation.EQUAL_TO), 1.0f);
when(searchResponse.getHits()).thenReturn(searchHits);
searchListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), any());
SearchResponse searchResponse = generateEmptySearchResponseWithHitCount(3);
PlainActionFuture<SearchResponse> future = PlainActionFuture.newFuture();
future.onResponse(searchResponse);
when(client.search(any(SearchRequest.class))).thenReturn(future);

CountDownLatch latch = new CountDownLatch(1);
LatchedActionListener<WorkflowResponse> latchedActionListener = new LatchedActionListener<>(listener, latch);
createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, latchedActionListener);
latch.await(1, TimeUnit.SECONDS);

createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals(("Maximum workflows limit reached: 2"), exceptionCaptor.getValue().getMessage());
Expand All @@ -305,7 +308,7 @@ public void onFailure(Exception e) {
fail("Should call onResponse");
}
};
createWorkflowTransportAction.checkMaxWorkflows(new TimeValue(10, TimeUnit.SECONDS), 10, listener);
createWorkflowTransportAction.checkMaxWorkflows(new TimeValue(10, TimeUnit.SECONDS), 10, "tenant-id", listener);
}

public void testFailedToCreateNewWorkflow() {
Expand All @@ -318,7 +321,7 @@ public void testFailedToCreateNewWorkflow() {
ActionListener<Boolean> checkMaxWorkflowListener = invocation.getArgument(2);
checkMaxWorkflowListener.onResponse(true);
return null;
}).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any());
}).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), anyString(), any());

// Bypass initializeConfigIndex and force onResponse
doAnswer(invocation -> {
Expand Down Expand Up @@ -349,7 +352,7 @@ public void testCreateNewWorkflow() {
ActionListener<Boolean> checkMaxWorkflowListener = invocation.getArgument(2);
checkMaxWorkflowListener.onResponse(true);
return null;
}).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any());
}).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), anyString(), any());

// Bypass initializeConfigIndex and force onResponse
doAnswer(invocation -> {
Expand Down Expand Up @@ -414,7 +417,7 @@ public void testCreateWithUserAndFilterOn() {
ActionListener<Boolean> checkMaxWorkflowListener = invocation.getArgument(2);
checkMaxWorkflowListener.onResponse(true);
return null;
}).when(createWorkflowTransportAction1).checkMaxWorkflows(any(TimeValue.class), anyInt(), any());
}).when(createWorkflowTransportAction1).checkMaxWorkflows(any(TimeValue.class), anyInt(), anyString(), any());

// Bypass initializeConfigIndex and force onResponse
doAnswer(invocation -> {
Expand Down Expand Up @@ -763,7 +766,7 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc
ActionListener<Boolean> checkMaxWorkflowListener = invocation.getArgument(2);
checkMaxWorkflowListener.onResponse(true);
return null;
}).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any());
}).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), anyString(), any());

// Bypass initializeConfigIndex and force onResponse
doAnswer(invocation -> {
Expand Down Expand Up @@ -823,7 +826,7 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning()
ActionListener<Boolean> checkMaxWorkflowListener = invocation.getArgument(2);
checkMaxWorkflowListener.onResponse(true);
return null;
}).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any());
}).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), anyString(), any());

// Bypass initializeConfigIndex and force onResponse
doAnswer(invocation -> {
Expand Down Expand Up @@ -913,4 +916,34 @@ private Template generateValidTemplate() {

return validTemplate;
}

/**
* Generates a parseable SearchResponse with a hit count but no hits (size=0)
* @param hitCount number of hits
* @return a parseable SearchResponse
*/
private SearchResponse generateEmptySearchResponseWithHitCount(int hitCount) {
SearchHit[] hits = new SearchHit[0];
SearchHits searchHits = new SearchHits(hits, new TotalHits(hitCount, TotalHits.Relation.EQUAL_TO), 1.0f);
SearchResponseSections searchSections = new SearchResponseSections(
searchHits,
InternalAggregations.EMPTY,
null,
true,
false,
null,
1
);
SearchResponse searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);
return searchResponse;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ public void testProvisionWorkflow() throws IOException, InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
LatchedActionListener<WorkflowResponse> latchedActionListener = new LatchedActionListener<>(listener, latch);
provisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, latchedActionListener);
latch.await(2, TimeUnit.SECONDS);
latch.await(5, TimeUnit.SECONDS);

ArgumentCaptor<WorkflowResponse> responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class);
verify(listener, times(1)).onResponse(responseCaptor.capture());
Expand Down

0 comments on commit d55e694

Please sign in to comment.