diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/DummyMysqlChannel.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/DummyMysqlChannel.java index 05b72552f96ed16..4a2d0a7e013c06a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/DummyMysqlChannel.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/DummyMysqlChannel.java @@ -32,21 +32,25 @@ public DummyMysqlChannel() { } public void setSequenceId(int sequenceId) { + throwIfArrowFlightSql(); this.sequenceId = sequenceId; } @Override public String getRemoteIp() { + throwIfArrowFlightSql(); return ""; } @Override public String getRemoteHostPortString() { + throwIfArrowFlightSql(); return ""; } @Override public void close() { + throwIfArrowFlightSql(); } @Override @@ -56,26 +60,32 @@ protected int readAll(ByteBuffer dstBuf, boolean isHeader) throws IOException { @Override public ByteBuffer fetchOnePacket() throws IOException { + throwIfArrowFlightSql(); return ByteBuffer.allocate(0); } @Override public void flush() throws IOException { + throwIfArrowFlightSql(); } @Override public void sendOnePacket(ByteBuffer packet) throws IOException { + throwIfArrowFlightSql(); } @Override public void sendAndFlush(ByteBuffer packet) throws IOException { + throwIfArrowFlightSql(); } @Override public void reset() { + throwIfArrowFlightSql(); } public MysqlSerializer getSerializer() { + throwIfArrowFlightSql(); return serializer; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java index 5eaee47fa4b3774..e299c1d74151c89 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; +import java.util.Arrays; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLException; @@ -78,16 +79,19 @@ public class MysqlChannel { // mysql flag CLIENT_DEPRECATE_EOF private boolean clientDeprecatedEOF; + protected boolean isFlightSql = false; protected MysqlChannel() { // For DummyMysqlChannel } public void setClientDeprecatedEOF() { + throwIfArrowFlightSql(); clientDeprecatedEOF = true; } public boolean clientDeprecatedEOF() { + throwIfArrowFlightSql(); return clientDeprecatedEOF; } @@ -116,6 +120,7 @@ public MysqlChannel(StreamConnection connection) { } public void initSslBuffer() { + throwIfArrowFlightSql(); // allocate buffer when needed. this.remainingBuffer = ByteBuffer.allocate(16 * 1024); this.remainingBuffer.flip(); @@ -124,20 +129,28 @@ public void initSslBuffer() { } public void setSequenceId(int sequenceId) { + throwIfArrowFlightSql(); this.sequenceId = sequenceId; } + public void setIsFlightSql(boolean isFlightSql) { + this.isFlightSql = isFlightSql; + } + public String getRemoteIp() { + throwIfArrowFlightSql(); return remoteIp; } public void setSslEngine(SSLEngine sslEngine) { + throwIfArrowFlightSql(); this.sslEngine = sslEngine; decryptAppData = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize() * 2); encryptNetData = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize() * 2); } public void setSslMode(boolean sslMode) { + throwIfArrowFlightSql(); isSslMode = sslMode; if (isSslMode) { // channel in ssl mode means handshake phase has finished. @@ -146,6 +159,7 @@ public void setSslMode(boolean sslMode) { } public void setSslHandshaking(boolean sslHandshaking) { + throwIfArrowFlightSql(); isSslHandshaking = sslHandshaking; } @@ -173,6 +187,7 @@ private void accSequenceId() { // Close channel public void close() { + throwIfArrowFlightSql(); try { conn.close(); } catch (IOException e) { @@ -237,6 +252,7 @@ protected void decryptData(ByteBuffer dstBuf, boolean isHeader) throws SSLExcept // NOTE: all of the following code is assumed that the channel is in block mode. // if in handshaking mode we return a packet with header otherwise without header. public ByteBuffer fetchOnePacket() throws IOException { + throwIfArrowFlightSql(); int readLen; ByteBuffer result = defaultBuffer; result.clear(); @@ -390,6 +406,7 @@ protected ByteBuffer encryptData(ByteBuffer dstBuf) throws SSLException { } public void flush() throws IOException { + throwIfArrowFlightSql(); if (null == sendBuffer || sendBuffer.position() == 0) { // Nothing to send return; @@ -437,6 +454,7 @@ private void writeBuffer(ByteBuffer buffer, boolean isSsl) throws IOException { } public void sendOnePacket(ByteBuffer packet) throws IOException { + throwIfArrowFlightSql(); // handshake in packet with header and has encrypted, need to send in ssl format // ssl mode in packet no header and no encrypted, need to encrypted and add header and send in ssl format int bufLen; @@ -464,12 +482,14 @@ public void sendOnePacket(ByteBuffer packet) throws IOException { } public void sendAndFlush(ByteBuffer packet) throws IOException { + throwIfArrowFlightSql(); sendOnePacket(packet); flush(); } // Call this function before send query before public void reset() { + throwIfArrowFlightSql(); isSend = false; if (null != sendBuffer) { sendBuffer.clear(); @@ -477,27 +497,33 @@ public void reset() { } public boolean isSend() { + throwIfArrowFlightSql(); return isSend; } public String getRemoteHostPortString() { + throwIfArrowFlightSql(); return remoteHostPortString; } public void startAcceptQuery(ConnectContext connectContext, ConnectProcessor connectProcessor) { + throwIfArrowFlightSql(); conn.getSourceChannel().setReadListener(new ReadListener(connectContext, connectProcessor)); conn.getSourceChannel().resumeReads(); } public void suspendAcceptQuery() { + throwIfArrowFlightSql(); conn.getSourceChannel().suspendReads(); } public void resumeAcceptQuery() { + throwIfArrowFlightSql(); conn.getSourceChannel().resumeReads(); } public void stopAcceptQuery() throws IOException { + throwIfArrowFlightSql(); conn.getSourceChannel().shutdownReads(); } @@ -552,4 +578,15 @@ private boolean handleUnwrapResult(SSLEngineResult sslEngineResult) { } } + protected void throwIfArrowFlightSql() { + if (isFlightSql) { + try { + throw new RuntimeException("Arrow flight sql unexpected use of MysqlChannel"); + } catch (Exception e) { + LOG.warn(e.getMessage() + Arrays.toString(e.getStackTrace())); + // throw e; + } + } + } + } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java index 2f89bcaeb697864..20b4a077b0292a2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java @@ -40,7 +40,7 @@ import org.apache.doris.nereids.stats.StatsErrorEstimator; import org.apache.doris.plugin.AuditEvent.AuditEventBuilder; import org.apache.doris.resource.Tag; -import org.apache.doris.service.arrowflight.MysqlChannelToFlightSql; +import org.apache.doris.service.arrowflight.FlightSqlChannel; import org.apache.doris.statistics.ColumnStatistic; import org.apache.doris.statistics.Histogram; import org.apache.doris.system.Backend; @@ -51,6 +51,7 @@ import org.apache.doris.transaction.TransactionEntry; import org.apache.doris.transaction.TransactionStatus; +import com.google.common.base.Preconditions; import com.google.common.base.Strings; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -104,6 +105,7 @@ public enum ConnectType { protected volatile String peerIdentity; // mysql net protected volatile MysqlChannel mysqlChannel; + protected volatile FlightSqlChannel flightSqlChannel; // state protected volatile QueryState state; protected volatile long returnRows; @@ -351,7 +353,9 @@ public ConnectContext(String peerIdentity) { returnRows = 0; isKilled = false; sessionVariable = VariableMgr.newSessionVariable(); - mysqlChannel = new MysqlChannelToFlightSql(); + mysqlChannel = new DummyMysqlChannel(); + mysqlChannel.setIsFlightSql(true); + flightSqlChannel = new FlightSqlChannel(); command = MysqlCommand.COM_SLEEP; if (Config.use_fuzzy_session_variable) { sessionVariable.initFuzzyModeVariables(); @@ -603,11 +607,16 @@ public MysqlChannel getMysqlChannel() { return mysqlChannel; } + public FlightSqlChannel getFlightSqlChannel() { + Preconditions.checkState(connectType.equals(ConnectType.ARROW_FLIGHT_SQL)); + return flightSqlChannel; + } + public String getClientIP() { if (connectType.equals(ConnectType.MYSQL)) { return mysqlChannel.getRemoteHostPortString(); } else { - return "0.0.0.0"; // TODO + return flightSqlChannel.getRemoteHostPortString(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java index e20e5e6e7cf8715..ee5fd09911d8fa9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java @@ -428,7 +428,7 @@ public boolean isAnalyzeStmt() { * isValuesOrConstantSelect: when this interface return true, original string is truncated at 1024 * * @return parsed and analyzed statement for Stale planner. - * an unresolved LogicalPlan wrapped with a LogicalPlanAdapter for Nereids. + * an unresolved LogicalPlan wrapped with a LogicalPlanAdapter for Nereids. */ public StatementBase getParsedStmt() { return parsedStmt; @@ -597,7 +597,7 @@ private void parseByNereids() { } if (statements.size() <= originStmt.idx) { throw new ParseException("Nereids parse failed. Parser get " + statements.size() + " statements," - + " but we need at least " + originStmt.idx + " statements."); + + " but we need at least " + originStmt.idx + " statements."); } parsedStmt = statements.get(originStmt.idx); } @@ -636,7 +636,7 @@ private void handleQueryWithRetry(TUniqueId queryId) throws Exception { try { for (int i = 0; i < retryTime; i++) { try { - //reset query id for each retry + // reset query id for each retry if (i > 0) { UUID uuid = UUID.randomUUID(); TUniqueId newQueryId = new TUniqueId(uuid.getMostSignificantBits(), @@ -1363,7 +1363,7 @@ private void handleCacheStmt(CacheAnalyzer cacheAnalyzer, MysqlChannel channel) // Process a select statement. private void handleQueryStmt() throws Exception { LOG.info("Handling query {} with query id {}", - originStmt.originStmt, DebugUtil.printId(context.queryId)); + originStmt.originStmt, DebugUtil.printId(context.queryId)); // Every time set no send flag and clean all data in buffer context.getMysqlChannel().reset(); Queriable queryStmt = (Queriable) parsedStmt; @@ -2102,10 +2102,10 @@ private void handleSwitchStmt() throws AnalysisException { private void handlePrepareStmt() throws Exception { // register prepareStmt LOG.debug("add prepared statement {}, isBinaryProtocol {}", - prepareStmt.getName(), prepareStmt.isBinaryProtocol()); + prepareStmt.getName(), prepareStmt.isBinaryProtocol()); context.addPreparedStmt(prepareStmt.getName(), new PrepareStmtContext(prepareStmt, - context, planner, analyzer, prepareStmt.getName())); + context, planner, analyzer, prepareStmt.getName())); if (prepareStmt.isBinaryProtocol()) { sendStmtPrepareOK(); } @@ -2223,24 +2223,33 @@ private void sendFields(List colNames, List types) throws IOExcept } public void sendResultSet(ResultSet resultSet) throws IOException { - context.updateReturnRows(resultSet.getResultRows().size()); - // Send meta data. - sendMetaData(resultSet.getMetaData()); + if (context.getConnectType().equals(ConnectType.MYSQL)) { + context.updateReturnRows(resultSet.getResultRows().size()); + // Send meta data. + sendMetaData(resultSet.getMetaData()); - // Send result set. - for (List row : resultSet.getResultRows()) { - serializer.reset(); - for (String item : row) { - if (item == null || item.equals(FeConstants.null_string)) { - serializer.writeNull(); - } else { - serializer.writeLenEncodedString(item); + // Send result set. + for (List row : resultSet.getResultRows()) { + serializer.reset(); + for (String item : row) { + if (item == null || item.equals(FeConstants.null_string)) { + serializer.writeNull(); + } else { + serializer.writeLenEncodedString(item); + } } + context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer()); } - context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer()); - } - context.getState().setEof(); + context.getState().setEof(); + } else if (context.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) { + context.updateReturnRows(resultSet.getResultRows().size()); + context.getFlightSqlChannel() + .addResultSet(DebugUtil.printId(context.queryId()), context.getRunningQuery(), resultSet); + context.getState().setEof(); + } else { + LOG.error("sendResultSet error connect type"); + } } // Process show statement diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/FrontendServiceImpl.java b/fe/fe-core/src/main/java/org/apache/doris/service/FrontendServiceImpl.java index accedda8c2c051a..6fa8b98c8d7a088 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/FrontendServiceImpl.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/FrontendServiceImpl.java @@ -94,7 +94,7 @@ import org.apache.doris.qe.QueryState; import org.apache.doris.qe.StmtExecutor; import org.apache.doris.qe.VariableMgr; -import org.apache.doris.service.arrowflight.FlightSQLConnectProcessor; +import org.apache.doris.service.arrowflight.FlightSqlConnectProcessor; import org.apache.doris.statistics.ColumnStatistic; import org.apache.doris.statistics.ResultRow; import org.apache.doris.statistics.StatisticsCacheKey; @@ -1112,7 +1112,7 @@ public TMasterOpResult forward(TMasterOpRequest params) throws TException { if (context.getConnectType().equals(ConnectType.MYSQL)) { processor = new MysqlConnectProcessor(context); } else if (context.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) { - processor = new FlightSQLConnectProcessor(context); + processor = new FlightSqlConnectProcessor(context); } else { LOG.warn("unknown ConnectType: {}", context.getConnectType()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java index 66fd75610631f03..7b9ab13cc585758 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java @@ -20,7 +20,6 @@ package org.apache.doris.service.arrowflight; -import org.apache.doris.catalog.Env; import org.apache.doris.common.util.DebugUtil; import org.apache.doris.common.util.Util; import org.apache.doris.mysql.MysqlCommand; @@ -30,7 +29,9 @@ import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Message; -import static java.util.Arrays.asList; +import org.apache.arrow.adapter.jdbc.ArrowVectorIterator; +import org.apache.arrow.adapter.jdbc.JdbcToArrow; +import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils; import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.Criteria; import org.apache.arrow.flight.FlightDescriptor; @@ -65,26 +66,34 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; -import org.apache.arrow.vector.types.pojo.ArrowType.Bool; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import java.io.IOException; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Calendar; import java.util.Collections; import java.util.List; +import java.util.Objects; public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable { private static final Logger LOG = LogManager.getLogger(DorisFlightSqlProducer.class); + private static final Calendar DEFAULT_CALENDAR = JdbcToArrowUtils.getUtcCalendar(); private final Location location; private final BufferAllocator rootAllocator = new RootAllocator(); private final SqlInfoBuilder sqlInfoBuilder; private final FlightSessionsManager flightSessionsManager; + private final FlightSqlChannel flightSqlChannel; public DorisFlightSqlProducer(final Location location, FlightSessionsManager flightSessionsManager) { this.location = location; this.flightSessionsManager = flightSessionsManager; + flightSqlChannel = new FlightSqlChannel(); sqlInfoBuilder = new SqlInfoBuilder(); sqlInfoBuilder.withFlightSqlServerName("DorisFE") .withFlightSqlServerVersion("1.0") @@ -99,9 +108,24 @@ public DorisFlightSqlProducer(final Location location, FlightSessionsManager fli } @Override - public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context, - final ServerStreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("getStreamPreparedStatement unimplemented").toRuntimeException(); + public void close() throws Exception { + AutoCloseables.close(rootAllocator); + } + + @Override + public void listFlights(CallContext context, Criteria criteria, StreamListener listener) { + throw CallStatus.UNIMPLEMENTED.withDescription("listFlights unimplemented").toRuntimeException(); + } + + @Override + public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { + throw CallStatus.UNIMPLEMENTED.withDescription("doExchange unimplemented").toRuntimeException(); + } + + @Override + public void createPreparedStatement(final ActionCreatePreparedStatementRequest request, final CallContext context, + final StreamListener listener) { + throw CallStatus.UNIMPLEMENTED.withDescription("createPreparedStatement unimplemented").toRuntimeException(); } @Override @@ -117,41 +141,35 @@ public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, fi try { connectContext = flightSessionsManager.getConnectContext(context.peerIdentity()); final String query = request.getQuery(); - final FlightSQLConnectProcessor flightSQLConnectProcessor = new FlightSQLConnectProcessor(connectContext); + final FlightSqlConnectProcessor flightSQLConnectProcessor = new FlightSqlConnectProcessor(connectContext); flightSQLConnectProcessor.handleQuery(query); - - Ticket ticket; - Location location; - Schema schema; if (connectContext.isWaitSyncQueryResult()) { + final ByteString handle = ByteString.copyFromUtf8(DebugUtil.printId(connectContext.queryId())); TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder() - .setStatementHandle(ByteString.copyFromUtf8(query)).build(); - ticket = new Ticket(Any.pack(ticketStatement).toByteArray()); - location = Location.forGrpcInsecure(Env.getCurrentEnv().getSelfNode().getHost(), - Env.getCurrentEnv().getSelfNode().getPort()); - schema = new Schema(asList( - new Field("i", new FieldType(true, new Bool(), null, null), null) - )); + .setStatementHandle(handle).build(); + return getFlightInfoForSchema(ticketStatement, descriptor, + jdbcToArrowSchema(flightSqlChannel.getResultSet(DebugUtil.printId(connectContext.queryId())) + .getResultSet() + .getMetaData(), DEFAULT_CALENDAR)); } else { - TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder() - .setStatementHandle(ByteString.copyFromUtf8( - DebugUtil.printId(connectContext.getFinstId()) + ":" + query)).build(); - ticket = new Ticket(Any.pack(ticketStatement).toByteArray()); - // TODO Support multiple endpoints. - location = Location.forGrpcInsecure(connectContext.getResultFlightServerAddr().hostname, - connectContext.getResultFlightServerAddr().port); - - schema = flightSQLConnectProcessor.fetchArrowFlightSchema(5000); + final ByteString handle = ByteString.copyFromUtf8( + DebugUtil.printId(connectContext.getFinstId()) + ":" + query); + Schema schema = flightSQLConnectProcessor.fetchArrowFlightSchema(5000); if (schema == null) { throw CallStatus.INTERNAL.withDescription("fetch arrow flight schema is null").toRuntimeException(); } + TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder() + .setStatementHandle(handle).build(); + Ticket ticket = new Ticket(Any.pack(ticketStatement).toByteArray()); + // TODO Support multiple endpoints. + Location location = Location.forGrpcInsecure(connectContext.getResultFlightServerAddr().hostname, + connectContext.getResultFlightServerAddr().port); + List endpoints = Collections.singletonList(new FlightEndpoint(ticket, location)); + // TODO Set in BE callback after query end, Client will not callback. + connectContext.setCommand(MysqlCommand.COM_SLEEP); + return new FlightInfo(schema, descriptor, endpoints, -1, -1); } - - List endpoints = Collections.singletonList(new FlightEndpoint(ticket, location)); - // TODO Set in BE callback after query end, Client will not callback. - connectContext.setCommand(MysqlCommand.COM_SLEEP); - return new FlightInfo(schema, descriptor, endpoints, -1, -1); } catch (Exception e) { if (null != connectContext) { connectContext.setCommand(MysqlCommand.COM_SLEEP); @@ -176,24 +194,40 @@ public SchemaResult getSchemaStatement(final CommandStatementQuery command, fina } @Override - public void close() throws Exception { - AutoCloseables.close(rootAllocator); - } - - @Override - public void listFlights(CallContext context, Criteria criteria, StreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("listFlights unimplemented").toRuntimeException(); - } + public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context, + final ServerStreamListener listener) { + final String handle = ticketStatementQuery.getStatementHandle().toStringUtf8(); + final FlightSqlResultSet flightSqlResultSet = + Objects.requireNonNull(flightSqlChannel.getResultSet(handle)); + try (final ResultSet resultSet = flightSqlResultSet.getResultSet()) { + final Schema schema = JdbcToArrowUtils.jdbcToArrowSchema(resultSet.getMetaData(), DEFAULT_CALENDAR); + try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) { + final VectorLoader loader = new VectorLoader(vectorSchemaRoot); + listener.start(vectorSchemaRoot); + + final ArrowVectorIterator iterator = JdbcToArrow.sqlToArrowVectorIterator(resultSet, rootAllocator); + while (iterator.hasNext()) { + final VectorUnloader unloader = new VectorUnloader(iterator.next()); + loader.load(unloader.getRecordBatch()); + listener.putNext(); + vectorSchemaRoot.clear(); + } - @Override - public void createPreparedStatement(final ActionCreatePreparedStatementRequest request, final CallContext context, - final StreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("createPreparedStatement unimplemented").toRuntimeException(); + listener.putNext(); + } + } catch (SQLException | IOException e) { + LOG.warn("Failed to getStreamStatement, " + e.getMessage(), e); + listener.error(e); + } finally { + listener.completed(); + flightSqlChannel.invalidate(handle); + } } @Override - public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { - throw CallStatus.UNIMPLEMENTED.withDescription("doExchange unimplemented").toRuntimeException(); + public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context, + final ServerStreamListener listener) { + throw CallStatus.UNIMPLEMENTED.withDescription("getStreamPreparedStatement unimplemented").toRuntimeException(); } @Override @@ -310,24 +344,12 @@ public FlightInfo getFlightInfoExportedKeys(final CommandGetExportedKeys request return getFlightInfoForSchema(request, descriptor, Schemas.GET_EXPORTED_KEYS_SCHEMA); } - @Override - public void getStreamExportedKeys(final CommandGetExportedKeys command, final CallContext context, - final ServerStreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("getStreamExportedKeys unimplemented").toRuntimeException(); - } - @Override public FlightInfo getFlightInfoImportedKeys(final CommandGetImportedKeys request, final CallContext context, final FlightDescriptor descriptor) { return getFlightInfoForSchema(request, descriptor, Schemas.GET_IMPORTED_KEYS_SCHEMA); } - @Override - public void getStreamImportedKeys(final CommandGetImportedKeys command, final CallContext context, - final ServerStreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("getStreamImportedKeys unimplemented").toRuntimeException(); - } - @Override public FlightInfo getFlightInfoCrossReference(CommandGetCrossReference request, CallContext context, FlightDescriptor descriptor) { @@ -335,15 +357,21 @@ public FlightInfo getFlightInfoCrossReference(CommandGetCrossReference request, } @Override - public void getStreamCrossReference(CommandGetCrossReference command, CallContext context, - ServerStreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("getStreamCrossReference unimplemented").toRuntimeException(); + public void getStreamExportedKeys(final CommandGetExportedKeys command, final CallContext context, + final ServerStreamListener listener) { + throw CallStatus.UNIMPLEMENTED.withDescription("getStreamExportedKeys unimplemented").toRuntimeException(); } @Override - public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context, + public void getStreamImportedKeys(final CommandGetImportedKeys command, final CallContext context, final ServerStreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("getStreamStatement unimplemented").toRuntimeException(); + throw CallStatus.UNIMPLEMENTED.withDescription("getStreamImportedKeys unimplemented").toRuntimeException(); + } + + @Override + public void getStreamCrossReference(CommandGetCrossReference command, CallContext context, + ServerStreamListener listener) { + throw CallStatus.UNIMPLEMENTED.withDescription("getStreamCrossReference unimplemented").toRuntimeException(); } private FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor, diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSQLConnectProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSQLConnectProcessor.java index 53bf2f08601d97e..df837ab7e144148 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSQLConnectProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSQLConnectProcessor.java @@ -55,10 +55,10 @@ /** * Process one mysql connection, receive one packet, process, send one packet. */ -public class FlightSQLConnectProcessor extends ConnectProcessor implements AutoCloseable { - private static final Logger LOG = LogManager.getLogger(FlightSQLConnectProcessor.class); +public class FlightSqlConnectProcessor extends ConnectProcessor implements AutoCloseable { + private static final Logger LOG = LogManager.getLogger(FlightSqlConnectProcessor.class); - public FlightSQLConnectProcessor(ConnectContext context) { + public FlightSqlConnectProcessor(ConnectContext context) { super(context); connectType = ConnectType.ARROW_FLIGHT_SQL; context.setThreadLocalInfo(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlChannel.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlChannel.java new file mode 100644 index 000000000000000..9dbb99ffbec1687 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlChannel.java @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.service.arrowflight; + +import org.apache.doris.qe.ResultSet; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.RemovalListener; +import com.google.common.cache.RemovalNotification; +import org.apache.arrow.util.AutoCloseables; +import org.jetbrains.annotations.NotNull; + +import java.util.concurrent.TimeUnit; + +public class FlightSqlChannel { + private final Cache resultCache; + + public FlightSqlChannel() { + resultCache = + CacheBuilder.newBuilder() + .maximumSize(100) + .expireAfterWrite(10, TimeUnit.MINUTES) + .removalListener(new ResultRemovalListener()) + .build(); + } + + // TODO + public String getRemoteIp() { + return "0.0.0.0"; + } + + // TODO + public String getRemoteHostPortString() { + return "0.0.0.0:0"; + } + + public void addResultSet(String queryId, String runningQuery, ResultSet resultSet) { + // connectcontext里存一个队列,把 relustSet 转成 arrow 后存到 队列里,然后 flight 哪里拿 + // handshake in packet with header and has encrypted, need to send in ssl format + // ssl mode in packet no header and no encrypted, need to encrypted and add header and send in ssl format + final FlightSqlResultSet flightSqlResultSet = new FlightSqlResultSet((java.sql.ResultSet) resultSet, + runningQuery); + resultCache.put(queryId, flightSqlResultSet); + } + + public FlightSqlResultSet getResultSet(String queryId) { + return resultCache.getIfPresent(queryId); + } + + public void invalidate(String handle) { + resultCache.invalidate(handle); + } + + private static class ResultRemovalListener implements RemovalListener { + @Override + public void onRemoval(@NotNull final RemovalNotification notification) { + try { + AutoCloseables.close(notification.getValue()); + } catch (final Exception e) { + // swallow + } + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlResultSet.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlResultSet.java new file mode 100644 index 000000000000000..4f55a643c8c7b9d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlResultSet.java @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.service.arrowflight; + +import java.sql.ResultSet; +import java.util.Objects; + + +public final class FlightSqlResultSet implements AutoCloseable { + + private final ResultSet result; + private final String query; + + public FlightSqlResultSet(final ResultSet result, final String query) { + this.result = Objects.requireNonNull(result, "result cannot be null."); + this.query = query; + } + + public ResultSet getResultSet() { + return result; + } + + public String getQuery() { + return query; + } + + @Override + public void close() throws Exception { + } + + @Override + public boolean equals(final Object other) { + if (this == other) { + return true; + } + if (!(other instanceof ResultSet)) { + return false; + } + final ResultSet that = (ResultSet) other; + return result.equals(that); + } + + @Override + public int hashCode() { + return Objects.hash(result); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/MysqlChannelToFlightSql.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/MysqlChannelToFlightSql.java deleted file mode 100644 index f61cdff583dc0ac..000000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/MysqlChannelToFlightSql.java +++ /dev/null @@ -1,74 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.service.arrowflight; - -import org.apache.doris.mysql.MysqlChannel; -import org.apache.doris.mysql.MysqlSerializer; - -import java.io.IOException; -import java.nio.ByteBuffer; - -public class MysqlChannelToFlightSql extends MysqlChannel { - public MysqlChannelToFlightSql() { - this.serializer = MysqlSerializer.newInstance(); - } - - @Override - public String getRemoteIp() { - return ""; - } - - @Override - public String getRemoteHostPortString() { - return ""; - } - - @Override - public void close() { - } - - @Override - protected int readAll(ByteBuffer dstBuf, boolean isHeader) throws IOException { - return 0; - } - - @Override - public ByteBuffer fetchOnePacket() throws IOException { - return ByteBuffer.allocate(0); - } - - @Override - public void flush() throws IOException { - } - - @Override - public void sendOnePacket(ByteBuffer packet) throws IOException { - } - - @Override - public void sendAndFlush(ByteBuffer packet) throws IOException { - } - - @Override - public void reset() { - } - - public MysqlSerializer getSerializer() { - return serializer; - } -}