diff --git a/spark/src/main/antlr/SqlBaseParser.g4 b/spark/src/main/antlr/SqlBaseParser.g4 index 6a6d39e96c..77a9108e06 100644 --- a/spark/src/main/antlr/SqlBaseParser.g4 +++ b/spark/src/main/antlr/SqlBaseParser.g4 @@ -967,7 +967,6 @@ primaryExpression | qualifiedName DOT ASTERISK #star | LEFT_PAREN namedExpression (COMMA namedExpression)+ RIGHT_PAREN #rowConstructor | LEFT_PAREN query RIGHT_PAREN #subqueryExpression - | IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN #identifierClause | functionName LEFT_PAREN (setQuantifier? argument+=functionArgument (COMMA argument+=functionArgument)*)? RIGHT_PAREN (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? @@ -1196,6 +1195,7 @@ qualifiedNameList functionName : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN + | identFunc=IDENTIFIER_KW // IDENTIFIER itself is also a valid function name. | qualifiedName | FILTER | LEFT diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java index c4382239a1..f57c8facee 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java @@ -7,12 +7,14 @@ import java.util.Map; import lombok.Data; +import lombok.EqualsAndHashCode; /** * This POJO carries all the fields required for emr serverless job submission. Used as model in * {@link EMRServerlessClient} interface. */ @Data +@EqualsAndHashCode public class StartJobRequest { public static final Long DEFAULT_JOB_TIMEOUT = 120L; diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index ab9761da36..ae274492b2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -41,6 +41,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.master.AcknowledgedResponse; @@ -78,6 +80,8 @@ public class SparkQueryDispatcherTest { private SparkQueryDispatcher sparkQueryDispatcher; + @Captor ArgumentCaptor startJobRequestArgumentCaptor; + @BeforeEach void setUp() { sparkQueryDispatcher = @@ -125,23 +129,24 @@ void testDispatchSelectQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }), + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -183,24 +188,25 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "basicauth", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); - put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString( + "basicauth", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); + put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password"); + } + }), + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -240,22 +246,23 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "noauth", - new HashMap<>() { - { - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString( + "noauth", + new HashMap<>() { + { + } + }), + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -302,24 +309,25 @@ void testDispatchIndexQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), - tags, - true, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + withStructuredStreaming( + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + })), + tags, + true, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -361,23 +369,24 @@ void testDispatchWithPPLQuery() { LangType.PPL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }), + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -419,23 +428,24 @@ void testDispatchQueryWithoutATableAndDataSourceName() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }), + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -483,24 +493,25 @@ void testDispatchIndexQueryWithoutADatasourceName() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), - tags, - true, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + withStructuredStreaming( + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + })), + tags, + true, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -905,8 +916,8 @@ private String constructExpectedSparkSubmitParameterString( + " --conf" + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegatingSessionCatalog " - + authParamConfigBuilder - + " --conf spark.flint.datasource.name=my_glue "; + + " --conf spark.flint.datasource.name=my_glue " + + authParamConfigBuilder; } private String withStructuredStreaming(String parameters) {