From eafbeb4f2ae8ec9abf5525348ef34c847c2f27fa Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Mon, 16 Oct 2023 17:49:43 -0700 Subject: [PATCH] address comments Signed-off-by: Peng Huo --- .../execution/session/InteractiveSession.java | 9 ++- .../spark/execution/session/SessionState.java | 4 ++ .../spark/execution/statement/Statement.java | 5 +- .../execution/statement/StatementModel.java | 26 +++++++- .../execution/statement/StatementState.java | 1 + .../execution/statement/StatementTest.java | 66 ++++++++++++++++++- 6 files changed, 104 insertions(+), 7 deletions(-) diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 101cc7f5f1..e33ef4245a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.execution.session; import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession; +import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE; import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId; import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession; import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; @@ -73,11 +74,13 @@ public StatementId submit(QueryRequest request) { throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { sessionModel = model.get(); - if (sessionModel.getSessionState() == SessionState.RUNNING) { + if (!END_STATE.contains(sessionModel.getSessionState())) { StatementId statementId = newStatementId(); Statement st = Statement.builder() .sessionId(sessionId) + .applicationId(sessionModel.getApplicationId()) + .jobId(sessionModel.getJobId()) .stateStore(stateStore) .statementId(statementId) .langType(LangType.SQL) @@ -89,7 +92,7 @@ public StatementId submit(QueryRequest request) { } else { String errMsg = String.format( - "can't submit statement, session should in running state, " + "can't submit statement, session should not be in end state, " + "current session state is: %s", sessionModel.getSessionState().getSessionState()); LOG.debug(errMsg); @@ -106,6 +109,8 @@ public Optional get(StatementId stID) { model -> Statement.builder() .sessionId(sessionId) + .applicationId(model.getApplicationId()) + .jobId(model.getJobId()) .statementId(model.getStatementId()) .langType(model.getLangType()) .query(model.getQuery()) diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java index 509d5105e9..a4da957f12 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java @@ -5,7 +5,9 @@ package org.opensearch.sql.spark.execution.session; +import com.google.common.collect.ImmutableList; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -17,6 +19,8 @@ public enum SessionState { DEAD("dead"), FAIL("fail"); + public static List END_STATE = ImmutableList.of(DEAD, FAIL); + private final String sessionState; SessionState(String sessionState) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java index 4c54393379..8fcedb5fca 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -28,6 +28,8 @@ public class Statement { private static final Logger LOG = LogManager.getLogger(); private final SessionId sessionId; + private final String applicationId; + private final String jobId; private final StatementId statementId; private final LangType langType; private final String query; @@ -39,7 +41,8 @@ public class Statement { /** Open a statement. */ public void open() { try { - statementModel = submitStatement(sessionId, statementId, langType, query, queryId); + statementModel = + submitStatement(sessionId, applicationId, jobId, statementId, langType, query, queryId); statementModel = createStatement(stateStore).apply(statementModel); } catch (VersionConflictEngineException e) { String errorMsg = "statement already exist. " + statementId; diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java index b57868964e..c7f681c541 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java @@ -5,6 +5,8 @@ package org.opensearch.sql.spark.execution.statement; +import static org.opensearch.sql.spark.execution.session.SessionModel.APPLICATION_ID; +import static org.opensearch.sql.spark.execution.session.SessionModel.JOB_ID; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; import java.io.IOException; @@ -40,6 +42,8 @@ public class StatementModel extends StateModel { private final StatementState statementState; private final StatementId statementId; private final SessionId sessionId; + private final String applicationId; + private final String jobId; private final LangType langType; private final String query; private final String queryId; @@ -58,6 +62,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(STATEMENT_STATE, statementState.getState()) .field(STATEMENT_ID, statementId.getId()) .field(SESSION_ID, sessionId.getSessionId()) + .field(APPLICATION_ID, applicationId) + .field(JOB_ID, jobId) .field(LANG, langType.getText()) .field(QUERY, query) .field(QUERY_ID, queryId) @@ -73,6 +79,8 @@ public static StatementModel copy(StatementModel copy, long seqNo, long primaryT .statementState(copy.statementState) .statementId(copy.statementId) .sessionId(copy.sessionId) + .applicationId(copy.applicationId) + .jobId(copy.jobId) .langType(copy.langType) .query(copy.query) .queryId(copy.queryId) @@ -90,6 +98,8 @@ public static StatementModel copyWithState( .statementState(state) .statementId(copy.statementId) .sessionId(copy.sessionId) + .applicationId(copy.applicationId) + .jobId(copy.jobId) .langType(copy.langType) .query(copy.query) .queryId(copy.queryId) @@ -124,6 +134,12 @@ public static StatementModel fromXContent(XContentParser parser, long seqNo, lon case SESSION_ID: builder.sessionId(new SessionId(parser.text())); break; + case APPLICATION_ID: + builder.applicationId(parser.text()); + break; + case JOB_ID: + builder.jobId(parser.text()); + break; case LANG: builder.langType(LangType.fromString(parser.text())); break; @@ -147,12 +163,20 @@ public static StatementModel fromXContent(XContentParser parser, long seqNo, lon } public static StatementModel submitStatement( - SessionId sid, StatementId statementId, LangType langType, String query, String queryId) { + SessionId sid, + String applicationId, + String jobId, + StatementId statementId, + LangType langType, + String query, + String queryId) { return builder() .version("1.0") .statementState(WAITING) .statementId(statementId) .sessionId(sid) + .applicationId(applicationId) + .jobId(jobId) .langType(langType) .query(query) .queryId(queryId) diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java index 87ad6b11ae..33f7f5e831 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java @@ -10,6 +10,7 @@ import java.util.stream.Collectors; import lombok.Getter; +/** {@link Statement} State. */ @Getter public enum StatementState { WAITING("waiting"), diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index b0bc84219b..331955e14e 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -57,6 +57,8 @@ public void openThenCancelStatement() { Statement st = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) .query("query") @@ -80,6 +82,8 @@ public void openFailedBecauseConflict() { Statement st = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) .query("query") @@ -92,6 +96,8 @@ public void openFailedBecauseConflict() { Statement dupSt = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(new StatementId("statementId")) .langType(LangType.SQL) .query("query") @@ -108,6 +114,8 @@ public void cancelNotExistStatement() { Statement st = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(stId) .langType(LangType.SQL) .query("query") @@ -130,6 +138,8 @@ public void cancelFailedBecauseOfConflict() { Statement st = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(stId) .langType(LangType.SQL) .query("query") @@ -157,6 +167,8 @@ public void cancelRunningStatementFailed() { Statement st = Statement.builder() .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") .statementId(stId) .langType(LangType.SQL) .query("query") @@ -195,21 +207,69 @@ public void submitStatementInRunningSession() { } @Test - public void failToSubmitStatementInStartingSession() { + public void submitStatementInNotStartedState() { Session session = new SessionManager(stateStore, emrsClient) .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + assertFalse(statementId.getId().isEmpty()); + } + + @Test + public void failToSubmitStatementInDeadState() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertEquals( + "can't submit statement, session should not be in end state, current session state is:" + + " dead", + exception.getMessage()); + } + + @Test + public void failToSubmitStatementInFailState() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL); + IllegalStateException exception = assertThrows( IllegalStateException.class, () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); assertEquals( - "can't submit statement, session should in running state, current session state is:" - + " not_started", + "can't submit statement, session should not be in end state, current session state is:" + + " fail", exception.getMessage()); } + @Test + public void newStatementFieldAssert() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + Optional statement = session.get(statementId); + + assertTrue(statement.isPresent()); + assertEquals(session.getSessionId(), statement.get().getSessionId()); + assertEquals("appId", statement.get().getApplicationId()); + assertEquals("jobId", statement.get().getJobId()); + assertEquals(statementId, statement.get().getStatementId()); + assertEquals(WAITING, statement.get().getStatementState()); + assertEquals(LangType.SQL, statement.get().getLangType()); + assertEquals("select 1", statement.get().getQuery()); + } + @Test public void failToSubmitStatementInDeletedSession() { Session session =