diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java index 817bc332afa369c..31363c0d37cefb5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java @@ -24,7 +24,6 @@ import org.apache.doris.common.util.Util; import org.apache.doris.mysql.MysqlCommand; import org.apache.doris.qe.ConnectContext; -import org.apache.doris.service.arrowflight.results.FlightSqlChannel; import org.apache.doris.service.arrowflight.results.FlightSqlResultCacheEntry; import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager; @@ -83,12 +82,10 @@ public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable private final BufferAllocator rootAllocator = new RootAllocator(); private final SqlInfoBuilder sqlInfoBuilder; private final FlightSessionsManager flightSessionsManager; - private final FlightSqlChannel flightSqlChannel; public DorisFlightSqlProducer(final Location location, FlightSessionsManager flightSessionsManager) { this.location = location; this.flightSessionsManager = flightSessionsManager; - flightSqlChannel = new FlightSqlChannel(); sqlInfoBuilder = new SqlInfoBuilder(); sqlInfoBuilder.withFlightSqlServerName("DorisFE") .withFlightSqlServerVersion("1.0") @@ -140,12 +137,13 @@ public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, fi flightSQLConnectProcessor.handleQuery(query); if (connectContext.isWaitSyncQueryResult()) { - final ByteString handle = ByteString.copyFromUtf8(DebugUtil.printId(connectContext.queryId())); + final ByteString handle = ByteString.copyFromUtf8( + context.peerIdentity() + ":" + DebugUtil.printId(connectContext.queryId())); TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder() .setStatementHandle(handle).build(); return getFlightInfoForSchema(ticketStatement, descriptor, - flightSqlChannel.getResultSet(DebugUtil.printId(connectContext.queryId())) - .getVectorSchemaRoot().getSchema()); + connectContext.getFlightSqlChannel().getResultSet(DebugUtil.printId(connectContext.queryId())) + .getVectorSchemaRoot().getSchema()); } else { final ByteString handle = ByteString.copyFromUtf8( DebugUtil.printId(connectContext.getFinstId()) + ":" + query); @@ -170,10 +168,10 @@ public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, fi String errMsg = "get flight info statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage( e) + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: " + connectContext.getState().getErrorMessage(); - LOG.warn(errMsg, e); // context query state + LOG.warn(errMsg, e); throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException(); } - LOG.warn("get flight info statement failed, " + e.getMessage(), e); // context query state + LOG.warn("get flight info statement failed, " + e.getMessage(), e); throw CallStatus.INTERNAL.withDescription(Util.getRootCauseMessage(e)).withCause(e).toRuntimeException(); } } @@ -195,18 +193,34 @@ public SchemaResult getSchemaStatement(final CommandStatementQuery command, fina @Override public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context, final ServerStreamListener listener) { + ConnectContext connectContext = null; final String handle = ticketStatementQuery.getStatementHandle().toStringUtf8(); - final FlightSqlResultCacheEntry flightSqlResultCacheEntry = - Objects.requireNonNull(flightSqlChannel.getResultSet(handle)); - try (final VectorSchemaRoot vectorSchemaRoot = flightSqlResultCacheEntry.getVectorSchemaRoot()) { + String[] handleParts = handle.split(":"); + String executedPeerIdentity = handleParts[0]; + String queryId = handleParts[1]; + try { + connectContext = flightSessionsManager.getConnectContext(executedPeerIdentity); + final FlightSqlResultCacheEntry flightSqlResultCacheEntry = + Objects.requireNonNull(connectContext.getFlightSqlChannel().getResultSet(queryId)); + final VectorSchemaRoot vectorSchemaRoot = flightSqlResultCacheEntry.getVectorSchemaRoot(); listener.start(vectorSchemaRoot); listener.putNext(); } catch (Exception e) { - LOG.warn("Failed to getStreamStatement, " + e.getMessage(), e); listener.error(e); + if (null != connectContext) { + String errMsg = "get stream statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage( + e) + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: " + + connectContext.getState().getErrorMessage(); + LOG.warn(errMsg, e); + throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException(); + } + LOG.warn("get stream statement failed, " + e.getMessage(), e); + throw CallStatus.INTERNAL.withDescription(Util.getRootCauseMessage(e)).withCause(e).toRuntimeException(); } finally { listener.completed(); - flightSqlChannel.invalidate(handle); + if (null != connectContext) { + connectContext.getFlightSqlChannel().invalidate(queryId); + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlChannel.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlChannel.java index c95fcc2e254a594..4e55e089a7840b0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlChannel.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlChannel.java @@ -35,7 +35,6 @@ import org.apache.arrow.vector.types.pojo.ArrowType.Utf8; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.arrow.vector.types.pojo.Schema; import org.jetbrains.annotations.NotNull; import java.util.ArrayList; @@ -80,7 +79,6 @@ public void addResultSet(String queryId, String runningQuery, ResultSet resultSe varCharVector.setValueCount(resultData.size()); dataFields.add(varCharVector); } - Schema schema = new Schema(schemaFields); for (int i = 0; i < resultData.size(); i++) { List row = resultData.get(i);