diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 0e9e128896..0639768354 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -19,6 +19,7 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.asyncquery.model.QueryState; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; @@ -116,7 +117,11 @@ public String cancelQuery(String queryId, AsyncQueryRequestContext asyncQueryReq Optional asyncQueryJobMetadata = asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (asyncQueryJobMetadata.isPresent()) { - return sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get(), asyncQueryRequestContext); + String result = + sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get(), asyncQueryRequestContext); + asyncQueryJobMetadataStorageService.updateState( + asyncQueryJobMetadata.get(), QueryState.CANCELLED, asyncQueryRequestContext); + return result; } throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java index b4e94c984d..86e925f58f 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryJobMetadataStorageService.java @@ -10,6 +10,7 @@ import java.util.Optional; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.asyncquery.model.QueryState; public interface AsyncQueryJobMetadataStorageService { @@ -17,5 +18,10 @@ void storeJobMetadata( AsyncQueryJobMetadata asyncQueryJobMetadata, AsyncQueryRequestContext asyncQueryRequestContext); + void updateState( + AsyncQueryJobMetadata asyncQueryJobMetadata, + QueryState newState, + AsyncQueryRequestContext asyncQueryRequestContext); + Optional getJobMetadata(String jobId); } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 73850db83d..3177c335d9 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.asyncquery; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; @@ -33,6 +34,7 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.asyncquery.model.QueryState; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.config.SparkSubmitParameterModifier; @@ -109,7 +111,7 @@ void testCreateAsyncQuery() { .getSparkExecutionEngineConfig(asyncQueryRequestContext); verify(sparkQueryDispatcher, times(1)) .dispatch(expectedDispatchQueryRequest, asyncQueryRequestContext); - Assertions.assertEquals(QUERY_ID, createAsyncQueryResponse.getQueryId()); + assertEquals(QUERY_ID, createAsyncQueryResponse.getQueryId()); } @Test @@ -153,8 +155,7 @@ void testGetAsyncQueryResultsWithJobNotFoundException() { AsyncQueryNotFoundException.class, () -> jobExecutorService.getAsyncQueryResults(EMR_JOB_ID, asyncQueryRequestContext)); - Assertions.assertEquals( - "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); + assertEquals("QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); verifyNoInteractions(sparkQueryDispatcher); verifyNoInteractions(sparkExecutionEngineConfigSupplier); } @@ -174,7 +175,7 @@ void testGetAsyncQueryResultsWithInProgressJob() { Assertions.assertNull(asyncQueryExecutionResponse.getResults()); Assertions.assertNull(asyncQueryExecutionResponse.getSchema()); - Assertions.assertEquals("PENDING", asyncQueryExecutionResponse.getStatus()); + assertEquals("PENDING", asyncQueryExecutionResponse.getStatus()); verifyNoInteractions(sparkExecutionEngineConfigSupplier); } @@ -191,11 +192,10 @@ void testGetAsyncQueryResultsWithSuccessJob() throws IOException { AsyncQueryExecutionResponse asyncQueryExecutionResponse = jobExecutorService.getAsyncQueryResults(EMR_JOB_ID, asyncQueryRequestContext); - Assertions.assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus()); - Assertions.assertEquals(1, asyncQueryExecutionResponse.getSchema().getColumns().size()); - Assertions.assertEquals( - "1", asyncQueryExecutionResponse.getSchema().getColumns().get(0).getName()); - Assertions.assertEquals( + assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus()); + assertEquals(1, asyncQueryExecutionResponse.getSchema().getColumns().size()); + assertEquals("1", asyncQueryExecutionResponse.getSchema().getColumns().get(0).getName()); + assertEquals( 1, ((HashMap) asyncQueryExecutionResponse.getResults().get(0).value()) .get("1")); @@ -212,8 +212,7 @@ void testCancelJobWithJobNotFound() { AsyncQueryNotFoundException.class, () -> jobExecutorService.cancelQuery(EMR_JOB_ID, asyncQueryRequestContext)); - Assertions.assertEquals( - "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); + assertEquals("QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); verifyNoInteractions(sparkQueryDispatcher); verifyNoInteractions(sparkExecutionEngineConfigSupplier); } @@ -227,7 +226,9 @@ void testCancelJob() { String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID, asyncQueryRequestContext); - Assertions.assertEquals(EMR_JOB_ID, jobId); + assertEquals(EMR_JOB_ID, jobId); + verify(asyncQueryJobMetadataStorageService) + .updateState(any(), eq(QueryState.CANCELLED), eq(asyncQueryRequestContext)); verifyNoInteractions(sparkExecutionEngineConfigSupplier); } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageService.java b/async-query/src/main/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageService.java index 4847c8e00f..eb377a5cff 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageService.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/asyncquery/OpenSearchAsyncQueryJobMetadataStorageService.java @@ -12,6 +12,7 @@ import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.asyncquery.model.QueryState; import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer; @@ -39,6 +40,14 @@ public void storeJobMetadata( OpenSearchStateStoreUtil.getIndexName(asyncQueryJobMetadata.getDatasourceName())); } + @Override + public void updateState( + AsyncQueryJobMetadata asyncQueryJobMetadata, + QueryState newState, + AsyncQueryRequestContext asyncQueryRequestContext) { + // NoOp since AsyncQueryJobMetadata record does not store state now + } + private String mapIdToDocumentId(String id) { return "qid" + id; }