Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Peng Huo <[email protected]>
  • Loading branch information
penghuo committed Oct 17, 2023
1 parent ff02f28 commit eafbeb4
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.sql.spark.execution.session;

import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession;
import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE;
import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId;
import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession;
import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession;
Expand Down Expand Up @@ -73,11 +74,13 @@ public StatementId submit(QueryRequest request) {
throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId());
} else {
sessionModel = model.get();
if (sessionModel.getSessionState() == SessionState.RUNNING) {
if (!END_STATE.contains(sessionModel.getSessionState())) {
StatementId statementId = newStatementId();
Statement st =
Statement.builder()
.sessionId(sessionId)
.applicationId(sessionModel.getApplicationId())
.jobId(sessionModel.getJobId())
.stateStore(stateStore)
.statementId(statementId)
.langType(LangType.SQL)
Expand All @@ -89,7 +92,7 @@ public StatementId submit(QueryRequest request) {
} else {
String errMsg =
String.format(
"can't submit statement, session should in running state, "
"can't submit statement, session should not be in end state, "
+ "current session state is: %s",
sessionModel.getSessionState().getSessionState());
LOG.debug(errMsg);
Expand All @@ -106,6 +109,8 @@ public Optional<Statement> get(StatementId stID) {
model ->
Statement.builder()
.sessionId(sessionId)
.applicationId(model.getApplicationId())
.jobId(model.getJobId())
.statementId(model.getStatementId())
.langType(model.getLangType())
.query(model.getQuery())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

package org.opensearch.sql.spark.execution.session;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Getter;
Expand All @@ -17,6 +19,8 @@ public enum SessionState {
DEAD("dead"),
FAIL("fail");

public static List<SessionState> END_STATE = ImmutableList.of(DEAD, FAIL);

private final String sessionState;

SessionState(String sessionState) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public class Statement {
private static final Logger LOG = LogManager.getLogger();

private final SessionId sessionId;
private final String applicationId;
private final String jobId;
private final StatementId statementId;
private final LangType langType;
private final String query;
Expand All @@ -39,7 +41,8 @@ public class Statement {
/** Open a statement. */
public void open() {
try {
statementModel = submitStatement(sessionId, statementId, langType, query, queryId);
statementModel =
submitStatement(sessionId, applicationId, jobId, statementId, langType, query, queryId);
statementModel = createStatement(stateStore).apply(statementModel);
} catch (VersionConflictEngineException e) {
String errorMsg = "statement already exist. " + statementId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.spark.execution.statement;

import static org.opensearch.sql.spark.execution.session.SessionModel.APPLICATION_ID;
import static org.opensearch.sql.spark.execution.session.SessionModel.JOB_ID;
import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING;

import java.io.IOException;
Expand Down Expand Up @@ -40,6 +42,8 @@ public class StatementModel extends StateModel {
private final StatementState statementState;
private final StatementId statementId;
private final SessionId sessionId;
private final String applicationId;
private final String jobId;
private final LangType langType;
private final String query;
private final String queryId;
Expand All @@ -58,6 +62,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
.field(STATEMENT_STATE, statementState.getState())
.field(STATEMENT_ID, statementId.getId())
.field(SESSION_ID, sessionId.getSessionId())
.field(APPLICATION_ID, applicationId)
.field(JOB_ID, jobId)
.field(LANG, langType.getText())
.field(QUERY, query)
.field(QUERY_ID, queryId)
Expand All @@ -73,6 +79,8 @@ public static StatementModel copy(StatementModel copy, long seqNo, long primaryT
.statementState(copy.statementState)
.statementId(copy.statementId)
.sessionId(copy.sessionId)
.applicationId(copy.applicationId)
.jobId(copy.jobId)
.langType(copy.langType)
.query(copy.query)
.queryId(copy.queryId)
Expand All @@ -90,6 +98,8 @@ public static StatementModel copyWithState(
.statementState(state)
.statementId(copy.statementId)
.sessionId(copy.sessionId)
.applicationId(copy.applicationId)
.jobId(copy.jobId)
.langType(copy.langType)
.query(copy.query)
.queryId(copy.queryId)
Expand Down Expand Up @@ -124,6 +134,12 @@ public static StatementModel fromXContent(XContentParser parser, long seqNo, lon
case SESSION_ID:
builder.sessionId(new SessionId(parser.text()));
break;
case APPLICATION_ID:
builder.applicationId(parser.text());
break;
case JOB_ID:
builder.jobId(parser.text());
break;
case LANG:
builder.langType(LangType.fromString(parser.text()));
break;
Expand All @@ -147,12 +163,20 @@ public static StatementModel fromXContent(XContentParser parser, long seqNo, lon
}

public static StatementModel submitStatement(
SessionId sid, StatementId statementId, LangType langType, String query, String queryId) {
SessionId sid,
String applicationId,
String jobId,
StatementId statementId,
LangType langType,
String query,
String queryId) {
return builder()
.version("1.0")
.statementState(WAITING)
.statementId(statementId)
.sessionId(sid)
.applicationId(applicationId)
.jobId(jobId)
.langType(langType)
.query(query)
.queryId(queryId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.stream.Collectors;
import lombok.Getter;

/** {@link Statement} State. */
@Getter
public enum StatementState {
WAITING("waiting"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ public void openThenCancelStatement() {
Statement st =
Statement.builder()
.sessionId(new SessionId("sessionId"))
.applicationId("appId")
.jobId("jobId")
.statementId(new StatementId("statementId"))
.langType(LangType.SQL)
.query("query")
Expand All @@ -80,6 +82,8 @@ public void openFailedBecauseConflict() {
Statement st =
Statement.builder()
.sessionId(new SessionId("sessionId"))
.applicationId("appId")
.jobId("jobId")
.statementId(new StatementId("statementId"))
.langType(LangType.SQL)
.query("query")
Expand All @@ -92,6 +96,8 @@ public void openFailedBecauseConflict() {
Statement dupSt =
Statement.builder()
.sessionId(new SessionId("sessionId"))
.applicationId("appId")
.jobId("jobId")
.statementId(new StatementId("statementId"))
.langType(LangType.SQL)
.query("query")
Expand All @@ -108,6 +114,8 @@ public void cancelNotExistStatement() {
Statement st =
Statement.builder()
.sessionId(new SessionId("sessionId"))
.applicationId("appId")
.jobId("jobId")
.statementId(stId)
.langType(LangType.SQL)
.query("query")
Expand All @@ -130,6 +138,8 @@ public void cancelFailedBecauseOfConflict() {
Statement st =
Statement.builder()
.sessionId(new SessionId("sessionId"))
.applicationId("appId")
.jobId("jobId")
.statementId(stId)
.langType(LangType.SQL)
.query("query")
Expand Down Expand Up @@ -157,6 +167,8 @@ public void cancelRunningStatementFailed() {
Statement st =
Statement.builder()
.sessionId(new SessionId("sessionId"))
.applicationId("appId")
.jobId("jobId")
.statementId(stId)
.langType(LangType.SQL)
.query("query")
Expand Down Expand Up @@ -195,21 +207,69 @@ public void submitStatementInRunningSession() {
}

@Test
public void failToSubmitStatementInStartingSession() {
public void submitStatementInNotStartedState() {
Session session =
new SessionManager(stateStore, emrsClient)
.createSession(new CreateSessionRequest(startJobRequest, "datasource"));

StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
assertFalse(statementId.getId().isEmpty());
}

@Test
public void failToSubmitStatementInDeadState() {
Session session =
new SessionManager(stateStore, emrsClient)
.createSession(new CreateSessionRequest(startJobRequest, "datasource"));

updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD);

IllegalStateException exception =
assertThrows(
IllegalStateException.class,
() -> session.submit(new QueryRequest(LangType.SQL, "select 1")));
assertEquals(
"can't submit statement, session should not be in end state, current session state is:"
+ " dead",
exception.getMessage());
}

@Test
public void failToSubmitStatementInFailState() {
Session session =
new SessionManager(stateStore, emrsClient)
.createSession(new CreateSessionRequest(startJobRequest, "datasource"));

updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL);

IllegalStateException exception =
assertThrows(
IllegalStateException.class,
() -> session.submit(new QueryRequest(LangType.SQL, "select 1")));
assertEquals(
"can't submit statement, session should in running state, current session state is:"
+ " not_started",
"can't submit statement, session should not be in end state, current session state is:"
+ " fail",
exception.getMessage());
}

@Test
public void newStatementFieldAssert() {
Session session =
new SessionManager(stateStore, emrsClient)
.createSession(new CreateSessionRequest(startJobRequest, "datasource"));
StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
Optional<Statement> statement = session.get(statementId);

assertTrue(statement.isPresent());
assertEquals(session.getSessionId(), statement.get().getSessionId());
assertEquals("appId", statement.get().getApplicationId());
assertEquals("jobId", statement.get().getJobId());
assertEquals(statementId, statement.get().getStatementId());
assertEquals(WAITING, statement.get().getStatementState());
assertEquals(LangType.SQL, statement.get().getLangType());
assertEquals("select 1", statement.get().getQuery());
}

@Test
public void failToSubmitStatementInDeletedSession() {
Session session =
Expand Down

0 comments on commit eafbeb4

Please sign in to comment.