diff --git a/be/src/runtime/result_buffer_mgr.cpp b/be/src/runtime/result_buffer_mgr.cpp index 9dbe228bcd98905..a2009c5ec3c970c 100644 --- a/be/src/runtime/result_buffer_mgr.cpp +++ b/be/src/runtime/result_buffer_mgr.cpp @@ -109,21 +109,21 @@ std::shared_ptr ResultBufferMgr::find_control_block(const TU return std::shared_ptr(); } -void ResultBufferMgr::register_row_descriptor(const TUniqueId& query_id, - const RowDescriptor& row_desc) { - std::unique_lock wlock(_row_descriptor_map_lock); - _row_descriptor_map.insert(std::make_pair(query_id, row_desc)); +void ResultBufferMgr::register_arrow_schema(const TUniqueId& query_id, + const std::shared_ptr& arrow_schema) { + std::unique_lock wlock(_arrow_schema_map_lock); + _arrow_schema_map.insert(std::make_pair(query_id, arrow_schema)); } -RowDescriptor ResultBufferMgr::find_row_descriptor(const TUniqueId& query_id) { - std::shared_lock rlock(_row_descriptor_map_lock); - RowDescriptorMap::iterator iter = _row_descriptor_map.find(query_id); +std::shared_ptr ResultBufferMgr::find_arrow_schema(const TUniqueId& query_id) { + std::shared_lock rlock(_arrow_schema_map_lock); + auto iter = _arrow_schema_map.find(query_id); - if (_row_descriptor_map.end() != iter) { + if (_arrow_schema_map.end() != iter) { return iter->second; } - return RowDescriptor(); + return nullptr; } void ResultBufferMgr::fetch_data(const PUniqueId& finst_id, GetResultBatchCtx* ctx) { @@ -162,11 +162,11 @@ Status ResultBufferMgr::cancel(const TUniqueId& query_id) { } { - std::unique_lock wlock(_row_descriptor_map_lock); - RowDescriptorMap::iterator row_desc_iter = _row_descriptor_map.find(query_id); + std::unique_lock wlock(_arrow_schema_map_lock); + auto arrow_schema_iter = _arrow_schema_map.find(query_id); - if (_row_descriptor_map.end() != row_desc_iter) { - _row_descriptor_map.erase(row_desc_iter); + if (_arrow_schema_map.end() != arrow_schema_iter) { + _arrow_schema_map.erase(arrow_schema_iter); } } diff --git a/be/src/runtime/result_buffer_mgr.h b/be/src/runtime/result_buffer_mgr.h index 4e5cd38a7264b7a..7995496cbf9c6d5 100644 --- a/be/src/runtime/result_buffer_mgr.h +++ b/be/src/runtime/result_buffer_mgr.h @@ -29,12 +29,12 @@ #include "common/status.h" #include "gutil/ref_counted.h" -#include "runtime/descriptors.h" #include "util/countdown_latch.h" #include "util/hash_util.hpp" namespace arrow { class RecordBatch; +class Schema; } // namespace arrow namespace doris { @@ -66,8 +66,9 @@ class ResultBufferMgr { // fetch data result to Arrow Flight Server Status fetch_arrow_data(const TUniqueId& finst_id, std::shared_ptr* result); - void register_row_descriptor(const TUniqueId& query_id, const RowDescriptor& row_desc); - RowDescriptor find_row_descriptor(const TUniqueId& query_id); + void register_arrow_schema(const TUniqueId& query_id, + const std::shared_ptr& arrow_schema); + std::shared_ptr find_arrow_schema(const TUniqueId& query_id); // cancel Status cancel(const TUniqueId& fragment_id); @@ -78,7 +79,7 @@ class ResultBufferMgr { private: using BufferMap = std::unordered_map>; using TimeoutMap = std::map>; - using RowDescriptorMap = std::unordered_map; + using ArrowSchemaMap = std::unordered_map>; std::shared_ptr find_control_block(const TUniqueId& query_id); @@ -90,10 +91,10 @@ class ResultBufferMgr { std::shared_mutex _buffer_map_lock; // buffer block map BufferMap _buffer_map; - // lock for descriptor map - std::shared_mutex _row_descriptor_map_lock; + // lock for arrow schema map + std::shared_mutex _arrow_schema_map_lock; // for arrow flight - RowDescriptorMap _row_descriptor_map; + ArrowSchemaMap _arrow_schema_map; // lock for timeout map std::mutex _timeout_lock; diff --git a/be/src/service/arrow_flight/arrow_flight_batch_reader.cpp b/be/src/service/arrow_flight/arrow_flight_batch_reader.cpp index 8a0f1e67859494c..553ca96b3748367 100644 --- a/be/src/service/arrow_flight/arrow_flight_batch_reader.cpp +++ b/be/src/service/arrow_flight/arrow_flight_batch_reader.cpp @@ -40,17 +40,11 @@ arrow::Result> ArrowFlightBatchReader::C const std::shared_ptr& statement_) { // Make sure that FE send the fragment to BE and creates the BufferControlBlock before returning ticket // to the ADBC client, so that the row_descriptor and control block can be found. - RowDescriptor row_desc = - ExecEnv::GetInstance()->result_mgr()->find_row_descriptor(statement_->query_id); - if (row_desc.equals(RowDescriptor())) { + auto schema = ExecEnv::GetInstance()->result_mgr()->find_arrow_schema(statement_->query_id); + if (schema == nullptr) { ARROW_RETURN_NOT_OK(arrow::Status::Invalid(fmt::format( - "Schema RowDescriptor Not Found, queryid: {}", print_id(statement_->query_id)))); - } - std::shared_ptr schema; - auto st = convert_to_arrow_schema(row_desc, &schema); - if (UNLIKELY(!st.ok())) { - LOG(WARNING) << st.to_string(); - ARROW_RETURN_NOT_OK(to_arrow_status(st)); + "not found arrow flight schema, maybe query has been canceled, queryid: {}", + print_id(statement_->query_id)))); } std::shared_ptr result(new ArrowFlightBatchReader(statement_, schema)); return result; diff --git a/be/src/service/internal_service.cpp b/be/src/service/internal_service.cpp index f80f4ddb5e0a9e5..91250e50f9a3073 100644 --- a/be/src/service/internal_service.cpp +++ b/be/src/service/internal_service.cpp @@ -707,23 +707,19 @@ void PInternalServiceImpl::fetch_arrow_flight_schema(google::protobuf::RpcContro google::protobuf::Closure* done) { bool ret = _light_work_pool.try_offer([request, result, done]() { brpc::ClosureGuard closure_guard(done); - RowDescriptor row_desc = ExecEnv::GetInstance()->result_mgr()->find_row_descriptor( - UniqueId(request->finst_id()).to_thrift()); - if (row_desc.equals(RowDescriptor())) { - auto st = Status::NotFound("not found row descriptor"); - st.to_protobuf(result->mutable_status()); - return; - } - - std::shared_ptr schema; - auto st = convert_to_arrow_schema(row_desc, &schema); - if (UNLIKELY(!st.ok())) { + std::shared_ptr schema = + ExecEnv::GetInstance()->result_mgr()->find_arrow_schema( + UniqueId(request->finst_id()).to_thrift()); + if (schema == nullptr) { + LOG(INFO) << "not found arrow flight schema, maybe query has been canceled"; + auto st = Status::NotFound( + "not found arrow flight schema, maybe query has been canceled"); st.to_protobuf(result->mutable_status()); return; } std::string schema_str; - st = serialize_arrow_schema(row_desc, &schema, &schema_str); + auto st = serialize_arrow_schema(&schema, &schema_str); if (st.ok()) { result->set_schema(std::move(schema_str)); } diff --git a/be/src/util/arrow/row_batch.cpp b/be/src/util/arrow/row_batch.cpp index 6a44da2ec6b642f..6662f2e0ba7aee7 100644 --- a/be/src/util/arrow/row_batch.cpp +++ b/be/src/util/arrow/row_batch.cpp @@ -39,6 +39,8 @@ #include "runtime/types.h" #include "util/arrow/block_convertor.h" #include "vec/core/block.h" +#include "vec/exprs/vexpr.h" +#include "vec/exprs/vexpr_context.h" namespace doris { @@ -163,6 +165,22 @@ Status convert_to_arrow_schema(const RowDescriptor& row_desc, return Status::OK(); } +Status convert_expr_ctxs_arrow_schema(const vectorized::VExprContextSPtrs& output_vexpr_ctxs, + std::shared_ptr* result) { + std::vector> fields; + for (auto expr_ctx : output_vexpr_ctxs) { + std::shared_ptr arrow_type; + auto root_expr = expr_ctx->root(); + RETURN_IF_ERROR(convert_to_arrow_type(root_expr->type(), &arrow_type)); + auto field_name = root_expr->is_slot_ref() ? root_expr->expr_name() + : root_expr->data_type()->get_name(); + fields.push_back( + std::make_shared(field_name, arrow_type, root_expr->is_nullable())); + } + *result = arrow::schema(std::move(fields)); + return Status::OK(); +} + Status serialize_record_batch(const arrow::RecordBatch& record_batch, std::string* result) { // create sink memory buffer outputstream with the computed capacity int64_t capacity; @@ -206,15 +224,13 @@ Status serialize_record_batch(const arrow::RecordBatch& record_batch, std::strin return Status::OK(); } -Status serialize_arrow_schema(RowDescriptor row_desc, std::shared_ptr* schema, - std::string* result) { - std::vector slots; - for (auto tuple_desc : row_desc.tuple_descriptors()) { - slots.insert(slots.end(), tuple_desc->slots().begin(), tuple_desc->slots().end()); +Status serialize_arrow_schema(std::shared_ptr* schema, std::string* result) { + auto make_empty_result = arrow::RecordBatch::MakeEmpty(*schema); + if (!make_empty_result.ok()) { + return Status::InternalError("serialize_arrow_schema failed, reason: {}", + make_empty_result.status().ToString()); } - auto block = vectorized::Block(slots, 0); - std::shared_ptr batch; - RETURN_IF_ERROR(convert_to_arrow_batch(block, *schema, arrow::default_memory_pool(), &batch)); + auto batch = make_empty_result.ValueOrDie(); return serialize_record_batch(*batch, result); } diff --git a/be/src/util/arrow/row_batch.h b/be/src/util/arrow/row_batch.h index 1bd408754f1b58e..ddffc3324d34512 100644 --- a/be/src/util/arrow/row_batch.h +++ b/be/src/util/arrow/row_batch.h @@ -23,6 +23,7 @@ #include "common/status.h" #include "runtime/types.h" #include "vec/core/block.h" +#include "vec/exprs/vexpr_fwd.h" // This file will convert Doris RowBatch to/from Arrow's RecordBatch // RowBatch is used by Doris query engine to exchange data between @@ -49,9 +50,11 @@ Status convert_to_arrow_schema(const RowDescriptor& row_desc, Status convert_block_arrow_schema(const vectorized::Block& block, std::shared_ptr* result); +Status convert_expr_ctxs_arrow_schema(const vectorized::VExprContextSPtrs& output_vexpr_ctxs, + std::shared_ptr* result); + Status serialize_record_batch(const arrow::RecordBatch& record_batch, std::string* result); -Status serialize_arrow_schema(RowDescriptor row_desc, std::shared_ptr* schema, - std::string* result); +Status serialize_arrow_schema(std::shared_ptr* schema, std::string* result); } // namespace doris diff --git a/be/src/vec/sink/varrow_flight_result_writer.cpp b/be/src/vec/sink/varrow_flight_result_writer.cpp index 4a71e10df426a61..771040bfb8b4779 100644 --- a/be/src/vec/sink/varrow_flight_result_writer.cpp +++ b/be/src/vec/sink/varrow_flight_result_writer.cpp @@ -27,14 +27,13 @@ namespace doris { namespace vectorized { -VArrowFlightResultWriter::VArrowFlightResultWriter(BufferControlBlock* sinker, - const VExprContextSPtrs& output_vexpr_ctxs, - RuntimeProfile* parent_profile, - const RowDescriptor& row_desc) +VArrowFlightResultWriter::VArrowFlightResultWriter( + BufferControlBlock* sinker, const VExprContextSPtrs& output_vexpr_ctxs, + RuntimeProfile* parent_profile, const std::shared_ptr& arrow_schema) : _sinker(sinker), _output_vexpr_ctxs(output_vexpr_ctxs), _parent_profile(parent_profile), - _row_desc(row_desc) {} + _arrow_schema(arrow_schema) {} Status VArrowFlightResultWriter::init(RuntimeState* state) { _init_profile(); @@ -42,8 +41,6 @@ Status VArrowFlightResultWriter::init(RuntimeState* state) { return Status::InternalError("sinker is NULL pointer."); } _is_dry_run = state->query_options().dry_run_query; - // generate the arrow schema - RETURN_IF_ERROR(convert_to_arrow_schema(_row_desc, &_arrow_schema)); return Status::OK(); } @@ -100,7 +97,7 @@ bool VArrowFlightResultWriter::can_sink() { return _sinker->can_sink(); } -Status VArrowFlightResultWriter::close(Status) { +Status VArrowFlightResultWriter::close(Status st) { COUNTER_SET(_sent_rows_counter, _written_rows); COUNTER_UPDATE(_bytes_sent_counter, _bytes_sent); return Status::OK(); diff --git a/be/src/vec/sink/varrow_flight_result_writer.h b/be/src/vec/sink/varrow_flight_result_writer.h index 7aa8ec6824a7e62..02faebfddb3ad5e 100644 --- a/be/src/vec/sink/varrow_flight_result_writer.h +++ b/be/src/vec/sink/varrow_flight_result_writer.h @@ -31,7 +31,6 @@ namespace doris { class BufferControlBlock; class RuntimeState; -class RowDescriptor; namespace vectorized { class Block; @@ -39,7 +38,8 @@ class Block; class VArrowFlightResultWriter final : public ResultWriter { public: VArrowFlightResultWriter(BufferControlBlock* sinker, const VExprContextSPtrs& output_vexpr_ctxs, - RuntimeProfile* parent_profile, const RowDescriptor& row_desc); + RuntimeProfile* parent_profile, + const std::shared_ptr& arrow_schema); Status init(RuntimeState* state) override; @@ -72,7 +72,6 @@ class VArrowFlightResultWriter final : public ResultWriter { uint64_t _bytes_sent = 0; - const RowDescriptor& _row_desc; std::shared_ptr _arrow_schema; }; } // namespace vectorized diff --git a/be/src/vec/sink/vmemory_scratch_sink.cpp b/be/src/vec/sink/vmemory_scratch_sink.cpp index f9192d5c79f140b..d4f0d4521c04b26 100644 --- a/be/src/vec/sink/vmemory_scratch_sink.cpp +++ b/be/src/vec/sink/vmemory_scratch_sink.cpp @@ -56,8 +56,6 @@ Status MemoryScratchSink::_prepare_vexpr(RuntimeState* state) { RETURN_IF_ERROR(VExpr::create_expr_trees(_t_output_expr, _output_vexpr_ctxs)); // Prepare the exprs to run. RETURN_IF_ERROR(VExpr::prepare(_output_vexpr_ctxs, state, _row_desc)); - // generate the arrow schema - RETURN_IF_ERROR(convert_to_arrow_schema(_row_desc, &_arrow_schema)); return Status::OK(); } diff --git a/be/src/vec/sink/vmemory_scratch_sink.h b/be/src/vec/sink/vmemory_scratch_sink.h index e91d130547acca9..c9a6922336ce299 100644 --- a/be/src/vec/sink/vmemory_scratch_sink.h +++ b/be/src/vec/sink/vmemory_scratch_sink.h @@ -65,8 +65,6 @@ class MemoryScratchSink final : public DataSink { private: Status _prepare_vexpr(RuntimeState* state); - std::shared_ptr _arrow_schema; - BlockQueueSharedPtr _queue; // Owned by the RuntimeState. diff --git a/be/src/vec/sink/vresult_sink.cpp b/be/src/vec/sink/vresult_sink.cpp index b3a2d3bae7f7b9a..d5ca67b76c79856 100644 --- a/be/src/vec/sink/vresult_sink.cpp +++ b/be/src/vec/sink/vresult_sink.cpp @@ -33,6 +33,7 @@ #include "runtime/exec_env.h" #include "runtime/result_buffer_mgr.h" #include "runtime/runtime_state.h" +#include "util/arrow/row_batch.h" #include "util/runtime_profile.h" #include "util/telemetry/telemetry.h" #include "vec/exprs/vexpr.h" @@ -98,12 +99,15 @@ Status VResultSink::prepare(RuntimeState* state) { _writer.reset(new (std::nothrow) VMysqlResultWriter(_sender.get(), _output_vexpr_ctxs, _profile)); break; - case TResultSinkType::ARROW_FLIGHT_PROTOCAL: - state->exec_env()->result_mgr()->register_row_descriptor(state->fragment_instance_id(), - _row_desc); + case TResultSinkType::ARROW_FLIGHT_PROTOCAL: { + std::shared_ptr arrow_schema; + RETURN_IF_ERROR(convert_expr_ctxs_arrow_schema(_output_vexpr_ctxs, &arrow_schema)); + state->exec_env()->result_mgr()->register_arrow_schema(state->fragment_instance_id(), + arrow_schema); _writer.reset(new (std::nothrow) VArrowFlightResultWriter(_sender.get(), _output_vexpr_ctxs, - _profile, _row_desc)); + _profile, arrow_schema)); break; + } default: return Status::InternalError("Unknown result sink type"); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/ErrorCode.java b/fe/fe-core/src/main/java/org/apache/doris/common/ErrorCode.java index 5a4806b12ec3a96..27e61ff87e05887 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/common/ErrorCode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/common/ErrorCode.java @@ -1204,7 +1204,10 @@ public enum ErrorCode { "the auto increment must be BIGINT type."), ERR_AUTO_INCREMENT_COLUMN_IN_AGGREGATE_TABLE(5096, new byte[]{'4', '2', '0', '0', '0'}, - "the auto increment is only supported in duplicate table and unique table."); + "the auto increment is only supported in duplicate table and unique table."), + + ERR_ARROW_FLIGHT_SQL_MUST_ONLY_RESULT_STMT(5097, new byte[]{'4', '2', '0', '0', '0'}, + "There can only be one stmt that returns the result and it is at the end."); // This is error code private final int code; diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/AcceptListener.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/AcceptListener.java index 1bde95c165073ca..67a84bfb8a2fc78 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/AcceptListener.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/AcceptListener.java @@ -22,6 +22,7 @@ import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.ConnectProcessor; import org.apache.doris.qe.ConnectScheduler; +import org.apache.doris.qe.MysqlConnectProcessor; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -81,7 +82,7 @@ public void handleEvent(AcceptingChannel channel) { context.getEnv().getAuth().getQueryTimeout(context.getQualifiedUser())); context.setUserInsertTimeout( context.getEnv().getAuth().getInsertTimeout(context.getQualifiedUser())); - ConnectProcessor processor = new ConnectProcessor(context); + ConnectProcessor processor = new MysqlConnectProcessor(context); context.startAcceptQuery(processor); } catch (AfterConnectedException e) { // do not need to print log for this kind of exception. diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlCommand.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlCommand.java index f8a03029d5a383e..f1f1a44313114d5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlCommand.java @@ -23,6 +23,7 @@ import java.util.Map; // MySQL protocol text command +// Reused by arrow flight protocol public enum MysqlCommand { COM_SLEEP("Sleep", 0), COM_QUIT("Quit", 1), diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java index 15359faacaf650c..c75ed5a326faaa9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java @@ -17,6 +17,7 @@ package org.apache.doris.qe; +import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.UserIdentity; import org.apache.doris.catalog.DatabaseIf; import org.apache.doris.catalog.Env; @@ -39,10 +40,12 @@ import org.apache.doris.nereids.stats.StatsErrorEstimator; import org.apache.doris.plugin.AuditEvent.AuditEventBuilder; import org.apache.doris.resource.Tag; +import org.apache.doris.service.arrowflight.results.FlightSqlChannel; import org.apache.doris.statistics.ColumnStatistic; import org.apache.doris.statistics.Histogram; import org.apache.doris.system.Backend; import org.apache.doris.task.LoadTaskInfo; +import org.apache.doris.thrift.TNetworkAddress; import org.apache.doris.thrift.TResultSinkType; import org.apache.doris.thrift.TUniqueId; import org.apache.doris.transaction.TransactionEntry; @@ -59,6 +62,7 @@ import org.xnio.StreamConnection; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -77,7 +81,7 @@ public class ConnectContext { public enum ConnectType { MYSQL, - ARROW_FLIGHT + ARROW_FLIGHT_SQL } protected volatile ConnectType connectType; @@ -96,8 +100,15 @@ public enum ConnectType { protected volatile int connectionId; // Timestamp when the connection is make protected volatile long loginTime; - // arrow flight token + // for arrow flight + protected volatile FlightSqlChannel flightSqlChannel; protected volatile String peerIdentity; + private String runningQuery; + private TNetworkAddress resultFlightServerAddr; + private TNetworkAddress resultInternalServiceAddr; + private ArrayList resultOutputExprs; + private TUniqueId finstId; + private boolean returnResultFromLocal = true; // mysql net protected volatile MysqlChannel mysqlChannel; // state @@ -190,7 +201,7 @@ public enum ConnectType { private TResultSinkType resultSinkType = TResultSinkType.MYSQL_PROTOCAL; - //internal call like `insert overwrite` need skipAuth + // internal call like `insert overwrite` need skipAuth // For example, `insert overwrite` only requires load permission, // but the internal implementation will call the logic of `AlterTable`. // In this case, `skipAuth` needs to be set to `true` to skip the permission check of `AlterTable` @@ -286,41 +297,30 @@ public ConnectType getConnectType() { return connectType; } - public ConnectContext() { - this((StreamConnection) null); - } - - public ConnectContext(String peerIdentity) { - this.connectType = ConnectType.ARROW_FLIGHT; - this.peerIdentity = peerIdentity; + public void init() { state = new QueryState(); returnRows = 0; isKilled = false; sessionVariable = VariableMgr.newSessionVariable(); - mysqlChannel = new DummyMysqlChannel(); command = MysqlCommand.COM_SLEEP; if (Config.use_fuzzy_session_variable) { sessionVariable.initFuzzyModeVariables(); } - setResultSinkType(TResultSinkType.ARROW_FLIGHT_PROTOCAL); + } + + public ConnectContext() { + this((StreamConnection) null); } public ConnectContext(StreamConnection connection) { connectType = ConnectType.MYSQL; - state = new QueryState(); - returnRows = 0; serverCapability = MysqlCapability.DEFAULT_CAPABILITY; - isKilled = false; if (connection != null) { mysqlChannel = new MysqlChannel(connection); } else { mysqlChannel = new DummyMysqlChannel(); } - sessionVariable = VariableMgr.newSessionVariable(); - command = MysqlCommand.COM_SLEEP; - if (Config.use_fuzzy_session_variable) { - sessionVariable.initFuzzyModeVariables(); - } + init(); } public boolean isTxnModel() { @@ -541,14 +541,70 @@ public void resetLoginTime() { this.loginTime = System.currentTimeMillis(); } + public void setRunningQuery(String runningQuery) { + this.runningQuery = runningQuery; + } + + public String getRunningQuery() { + return runningQuery; + } + + public void setResultFlightServerAddr(TNetworkAddress resultFlightServerAddr) { + this.resultFlightServerAddr = resultFlightServerAddr; + } + + public TNetworkAddress getResultFlightServerAddr() { + return resultFlightServerAddr; + } + + public void setResultInternalServiceAddr(TNetworkAddress resultInternalServiceAddr) { + this.resultInternalServiceAddr = resultInternalServiceAddr; + } + + public TNetworkAddress getResultInternalServiceAddr() { + return resultInternalServiceAddr; + } + + public void setResultOutputExprs(ArrayList resultOutputExprs) { + this.resultOutputExprs = resultOutputExprs; + } + + public ArrayList getResultOutputExprs() { + return resultOutputExprs; + } + + public void setFinstId(TUniqueId finstId) { + this.finstId = finstId; + } + + public TUniqueId getFinstId() { + return finstId; + } + + public void setReturnResultFromLocal(boolean returnResultFromLocal) { + this.returnResultFromLocal = returnResultFromLocal; + } + + public boolean isReturnResultFromLocal() { + return returnResultFromLocal; + } + public String getPeerIdentity() { return peerIdentity; } + public FlightSqlChannel getFlightSqlChannel() { + throw new RuntimeException("getFlightSqlChannel not in flight sql connection"); + } + public MysqlChannel getMysqlChannel() { return mysqlChannel; } + public String getClientIP() { + return getMysqlChannel().getRemoteHostPortString(); + } + public QueryState getState() { return state; } @@ -620,10 +676,14 @@ public StmtExecutor getExecutor() { return executor; } - public void cleanup() { + protected void closeChannel() { if (mysqlChannel != null) { mysqlChannel.close(); } + } + + public void cleanup() { + closeChannel(); threadLocalInfo.remove(); returnRows = 0; } @@ -701,27 +761,17 @@ public void setResultSinkType(TResultSinkType resultSinkType) { } public String getRemoteHostPortString() { - if (connectType.equals(ConnectType.MYSQL)) { - return getMysqlChannel().getRemoteHostPortString(); - } else if (connectType.equals(ConnectType.ARROW_FLIGHT)) { - // TODO Get flight client IP:Port - return peerIdentity; - } - return ""; + return getMysqlChannel().getRemoteHostPortString(); } // kill operation with no protect. public void kill(boolean killConnection) { - LOG.warn("kill query from {}, kill connection: {}", getRemoteHostPortString(), killConnection); + LOG.warn("kill query from {}, kill mysql connection: {}", getRemoteHostPortString(), killConnection); if (killConnection) { isKilled = true; - if (connectType.equals(ConnectType.MYSQL)) { - // Close channel to break connection with client - getMysqlChannel().close(); - } else if (connectType.equals(ConnectType.ARROW_FLIGHT)) { - connectScheduler.unregisterConnection(this); - } + // Close channel to break connection with client + closeChannel(); } // Now, cancel running query. cancelQuery(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java index 5640a8c034c74fd..3885af944b2ff9a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java @@ -17,11 +17,8 @@ package org.apache.doris.qe; -import org.apache.doris.analysis.ExecuteStmt; import org.apache.doris.analysis.InsertStmt; import org.apache.doris.analysis.KillStmt; -import org.apache.doris.analysis.LiteralExpr; -import org.apache.doris.analysis.NullLiteral; import org.apache.doris.analysis.QueryStmt; import org.apache.doris.analysis.SqlParser; import org.apache.doris.analysis.SqlScanner; @@ -35,7 +32,7 @@ import org.apache.doris.common.AnalysisException; import org.apache.doris.common.DdlException; import org.apache.doris.common.ErrorCode; -import org.apache.doris.common.ErrorReport; +import org.apache.doris.common.NotImplementedException; import org.apache.doris.common.UserException; import org.apache.doris.common.telemetry.Telemetry; import org.apache.doris.common.util.DebugUtil; @@ -47,7 +44,6 @@ import org.apache.doris.mysql.MysqlChannel; import org.apache.doris.mysql.MysqlCommand; import org.apache.doris.mysql.MysqlPacket; -import org.apache.doris.mysql.MysqlProto; import org.apache.doris.mysql.MysqlSerializer; import org.apache.doris.mysql.MysqlServerStatusFlag; import org.apache.doris.nereids.exceptions.NotSupportedException; @@ -61,6 +57,7 @@ import org.apache.doris.thrift.TMasterOpResult; import org.apache.doris.thrift.TUniqueId; +import com.google.common.base.Preconditions; import com.google.common.base.Strings; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanKind; @@ -74,9 +71,6 @@ import java.io.IOException; import java.io.StringReader; import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.channels.AsynchronousCloseException; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -84,11 +78,16 @@ import java.util.UUID; /** - * Process one mysql connection, receive one packet, process, send one packet. + * Process one connection, the life cycle is the same as connection */ -public class ConnectProcessor { +public abstract class ConnectProcessor { + public enum ConnectType { + MYSQL, + ARROW_FLIGHT_SQL + } + private static final Logger LOG = LogManager.getLogger(ConnectProcessor.class); - private static final TextMapGetter> getter = + protected static final TextMapGetter> getter = new TextMapGetter>() { @Override public Iterable keys(Map carrier) { @@ -103,17 +102,17 @@ public String get(Map carrier, String key) { return ""; } }; - private final ConnectContext ctx; - private ByteBuffer packetBuf; - private StmtExecutor executor = null; + protected final ConnectContext ctx; + protected StmtExecutor executor = null; + protected ConnectType connectType; + protected ArrayList returnResultFromRemoteExecutor = new ArrayList<>(); public ConnectProcessor(ConnectContext context) { this.ctx = context; } - // COM_INIT_DB: change current database of this session. - private void handleInitDb() { - String fullDbName = new String(packetBuf.array(), 1, packetBuf.limit() - 1); + // change current database of this session. + protected void handleInitDb(String fullDbName) { if (Strings.isNullOrEmpty(ctx.getClusterName())) { ctx.getState().setError(ErrorCode.ERR_CLUSTER_NAME_NULL, "Please enter cluster"); return; @@ -160,24 +159,22 @@ private void handleInitDb() { ctx.getState().setOk(); } - // COM_QUIT: set killed flag and then return OK packet. - private void handleQuit() { + // set killed flag + protected void handleQuit() { ctx.setKilled(); ctx.getState().setOk(); } - // process COM_PING statement, do nothing, just return one OK packet. - private void handlePing() { + // do nothing + protected void handlePing() { ctx.getState().setOk(); } - private void handleStmtReset() { + protected void handleStmtReset() { ctx.getState().setOk(); } - private void handleStmtClose() { - packetBuf = packetBuf.order(ByteOrder.LITTLE_ENDIAN); - int stmtId = packetBuf.getInt(); + protected void handleStmtClose(int stmtId) { LOG.debug("close stmt id: {}", stmtId); ConnectContext.get().removePrepareStmt(String.valueOf(stmtId)); // No response packet is sent back to the client, see @@ -185,119 +182,27 @@ private void handleStmtClose() { ctx.getState().setNoop(); } - private void debugPacket() { - byte[] bytes = packetBuf.array(); - StringBuilder printB = new StringBuilder(); - for (byte b : bytes) { - if (Character.isLetterOrDigit((char) b & 0xFF)) { - char x = (char) b; - printB.append(x); - } else { - printB.append("0x" + Integer.toHexString(b & 0xFF)); - } - printB.append(" "); - } - LOG.debug("debug packet {}", printB.toString().substring(0, 200)); - } - - private static boolean isNull(byte[] bitmap, int position) { + protected static boolean isNull(byte[] bitmap, int position) { return (bitmap[position / 8] & (1 << (position & 7))) != 0; } - // process COM_EXECUTE, parse binary row data - // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html - private void handleExecute() { - // debugPacket(); - packetBuf = packetBuf.order(ByteOrder.LITTLE_ENDIAN); - // parse stmt_id, flags, params - int stmtId = packetBuf.getInt(); - // flag - packetBuf.get(); - // iteration_count always 1, - packetBuf.getInt(); - LOG.debug("execute prepared statement {}", stmtId); - PrepareStmtContext prepareCtx = ctx.getPreparedStmt(String.valueOf(stmtId)); - if (prepareCtx == null) { - LOG.debug("No such statement in context, stmtId:{}", stmtId); - ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR, - "msg: Not supported such prepared statement"); - return; - } - ctx.setStartTime(); - if (prepareCtx.stmt.getInnerStmt() instanceof QueryStmt) { - ctx.getState().setIsQuery(true); - } - prepareCtx.stmt.setIsPrepared(); - int paramCount = prepareCtx.stmt.getParmCount(); - // null bitmap - byte[] nullbitmapData = new byte[(paramCount + 7) / 8]; - packetBuf.get(nullbitmapData); - String stmtStr = ""; - try { - // new_params_bind_flag - if ((int) packetBuf.get() != 0) { - // parse params's types - for (int i = 0; i < paramCount; ++i) { - int typeCode = packetBuf.getChar(); - LOG.debug("code {}", typeCode); - prepareCtx.stmt.placeholders().get(i).setTypeCode(typeCode); - } - } - List realValueExprs = new ArrayList<>(); - // parse param data - for (int i = 0; i < paramCount; ++i) { - if (isNull(nullbitmapData, i)) { - realValueExprs.add(new NullLiteral()); - continue; - } - LiteralExpr l = prepareCtx.stmt.placeholders().get(i).createLiteralFromType(); - l.setupParamFromBinary(packetBuf); - realValueExprs.add(l); - } - ExecuteStmt executeStmt = new ExecuteStmt(String.valueOf(stmtId), realValueExprs); - // TODO set real origin statement - executeStmt.setOrigStmt(new OriginStatement("null", 0)); - executeStmt.setUserInfo(ctx.getCurrentUserIdentity()); - LOG.debug("executeStmt {}", executeStmt); - executor = new StmtExecutor(ctx, executeStmt); - ctx.setExecutor(executor); - executor.execute(); - stmtStr = executeStmt.toSql(); - } catch (Throwable e) { - // Catch all throwable. - // If reach here, maybe palo bug. - LOG.warn("Process one query failed because unknown reason: ", e); - ctx.getState().setError(ErrorCode.ERR_UNKNOWN_ERROR, - e.getClass().getSimpleName() + ", msg: " + e.getMessage()); - } - auditAfterExec(stmtStr, prepareCtx.stmt.getInnerStmt(), null, false); - } - - private void auditAfterExec(String origStmt, StatementBase parsedStmt, - Data.PQueryStatistics statistics, boolean printFuzzyVariables) { + protected void auditAfterExec(String origStmt, StatementBase parsedStmt, + Data.PQueryStatistics statistics, boolean printFuzzyVariables) { AuditLogHelper.logAuditLog(ctx, origStmt, parsedStmt, statistics, printFuzzyVariables); } - // Process COM_QUERY statement, // only throw an exception when there is a problem interacting with the requesting client - private void handleQuery(MysqlCommand mysqlCommand) { + protected void handleQuery(MysqlCommand mysqlCommand, String originStmt) { if (MetricRepo.isInit) { MetricRepo.COUNTER_REQUEST_ALL.increase(1L); } - // convert statement to Java string - byte[] bytes = packetBuf.array(); - int ending = packetBuf.limit() - 1; - while (ending >= 1 && bytes[ending] == '\0') { - ending--; - } - String originStmt = new String(bytes, 1, ending, StandardCharsets.UTF_8); String sqlHash = DigestUtils.md5Hex(originStmt); ctx.setSqlHash(sqlHash); ctx.getAuditEventBuilder().reset(); ctx.getAuditEventBuilder() .setTimestamp(System.currentTimeMillis()) - .setClientIp(ctx.getMysqlChannel().getRemoteHostPortString()) + .setClientIp(ctx.getClientIP()) .setUser(ClusterNamespace.getNameFromFullName(ctx.getQualifiedUser())) .setSqlHash(ctx.getSqlHash()); @@ -356,10 +261,25 @@ private void handleQuery(MysqlCommand mysqlCommand) { try { executor.execute(); - if (i != stmts.size() - 1) { - ctx.getState().serverStatus |= MysqlServerStatusFlag.SERVER_MORE_RESULTS_EXISTS; - if (ctx.getState().getStateType() != MysqlStateType.ERR) { - finalizeCommand(); + if (connectType.equals(ConnectType.MYSQL)) { + if (i != stmts.size() - 1) { + ctx.getState().serverStatus |= MysqlServerStatusFlag.SERVER_MORE_RESULTS_EXISTS; + if (ctx.getState().getStateType() != MysqlStateType.ERR) { + finalizeCommand(); + } + } + } else if (connectType.equals(ConnectType.ARROW_FLIGHT_SQL)) { + if (!ctx.isReturnResultFromLocal()) { + returnResultFromRemoteExecutor.add(executor); + } + Preconditions.checkState(ctx.getFlightSqlChannel().resultNum() <= 1); + if (ctx.getFlightSqlChannel().resultNum() == 1 && i != stmts.size() - 1) { + String errMsg = "Only be one stmt that returns the result and it is at the end. stmts.size(): " + + stmts.size(); + LOG.warn(errMsg); + ctx.getState().setError(ErrorCode.ERR_ARROW_FLIGHT_SQL_MUST_ONLY_RESULT_STMT, errMsg); + ctx.getState().setErrType(QueryState.ErrType.OTHER_ERR); + break; } } auditAfterExec(auditStmt, executor.getParsedStmt(), executor.getQueryStatisticsForAuditLog(), true); @@ -381,8 +301,8 @@ private void handleQuery(MysqlCommand mysqlCommand) { } // Use a handler for exception to avoid big try catch block which is a little hard to understand - private void handleQueryException(Throwable throwable, String origStmt, - StatementBase parsedStmt, Data.PQueryStatistics statistics) { + protected void handleQueryException(Throwable throwable, String origStmt, + StatementBase parsedStmt, Data.PQueryStatistics statistics) { if (ctx.getMinidump() != null) { MinidumpUtils.saveMinidumpString(ctx.getMinidump(), DebugUtil.printId(ctx.queryId())); } @@ -415,7 +335,7 @@ private void handleQueryException(Throwable throwable, String origStmt, } // analyze the origin stmt and return multi-statements - private List parse(String originStmt) throws AnalysisException, DdlException { + protected List parse(String originStmt) throws AnalysisException, DdlException { LOG.debug("the originStmts are: {}", originStmt); // Parse statement with parser generated by CUP&FLEX SqlScanner input = new SqlScanner(new StringReader(originStmt), ctx.getSessionVariable().getSqlMode()); @@ -443,9 +363,8 @@ private List parse(String originStmt) throws AnalysisException, D // Get the column definitions of a table @SuppressWarnings("rawtypes") - private void handleFieldList() throws IOException { + protected void handleFieldList(String tableName) { // Already get command code. - String tableName = new String(MysqlProto.readNulTerminateString(packetBuf), StandardCharsets.UTF_8); if (Strings.isNullOrEmpty(tableName)) { ctx.getState().setError(ErrorCode.ERR_UNKNOWN_TABLE, "Empty tableName"); return; @@ -463,18 +382,21 @@ private void handleFieldList() throws IOException { table.readLock(); try { - MysqlChannel channel = ctx.getMysqlChannel(); - MysqlSerializer serializer = channel.getSerializer(); - - // Send fields - // NOTE: Field list doesn't send number of fields - List baseSchema = table.getBaseSchema(); - for (Column column : baseSchema) { - serializer.reset(); - serializer.writeField(db.getFullName(), table.getName(), column, true); - channel.sendOnePacket(serializer.toByteBuffer()); + if (connectType.equals(ConnectType.MYSQL)) { + MysqlChannel channel = ctx.getMysqlChannel(); + MysqlSerializer serializer = channel.getSerializer(); + + // Send fields + // NOTE: Field list doesn't send number of fields + List baseSchema = table.getBaseSchema(); + for (Column column : baseSchema) { + serializer.reset(); + serializer.writeField(db.getFullName(), table.getName(), column, true); + channel.sendOnePacket(serializer.toByteBuffer()); + } + } else if (connectType.equals(ConnectType.ARROW_FLIGHT_SQL)) { + // TODO } - } catch (Throwable throwable) { handleQueryException(throwable, "", null, null); } finally { @@ -483,62 +405,9 @@ private void handleFieldList() throws IOException { ctx.getState().setEof(); } - private void dispatch() throws IOException { - int code = packetBuf.get(); - MysqlCommand command = MysqlCommand.fromCode(code); - if (command == null) { - ErrorReport.report(ErrorCode.ERR_UNKNOWN_COM_ERROR); - ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR, "Unknown command(" + code + ")"); - LOG.warn("Unknown command(" + code + ")"); - return; - } - LOG.debug("handle command {}", command); - ctx.setCommand(command); - ctx.setStartTime(); - - switch (command) { - case COM_INIT_DB: - handleInitDb(); - break; - case COM_QUIT: - handleQuit(); - break; - case COM_QUERY: - case COM_STMT_PREPARE: - ctx.initTracer("trace"); - Span rootSpan = ctx.getTracer().spanBuilder("handleQuery").setNoParent().startSpan(); - try (Scope scope = rootSpan.makeCurrent()) { - handleQuery(command); - } catch (Exception e) { - rootSpan.recordException(e); - throw e; - } finally { - rootSpan.end(); - } - break; - case COM_STMT_EXECUTE: - handleExecute(); - break; - case COM_FIELD_LIST: - handleFieldList(); - break; - case COM_PING: - handlePing(); - break; - case COM_STMT_RESET: - handleStmtReset(); - break; - case COM_STMT_CLOSE: - handleStmtClose(); - break; - default: - ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR, "Unsupported command(" + command + ")"); - LOG.warn("Unsupported command(" + command + ")"); - break; - } - } - - private ByteBuffer getResultPacket() { + // only Mysql protocol + protected ByteBuffer getResultPacket() { + Preconditions.checkState(connectType.equals(ConnectType.MYSQL)); MysqlPacket packet = ctx.getState().toResponsePacket(); if (packet == null) { // possible two cases: @@ -555,7 +424,9 @@ private ByteBuffer getResultPacket() { // When any request is completed, it will generally need to send a response packet to the client // This method is used to send a response packet to the client - private void finalizeCommand() throws IOException { + // only Mysql protocol + public void finalizeCommand() throws IOException { + Preconditions.checkState(connectType.equals(ConnectType.MYSQL)); ByteBuffer packet; if (executor != null && executor.isForwardToMaster() && ctx.getState().getStateType() != QueryState.MysqlStateType.ERR) { @@ -736,47 +607,9 @@ public TMasterOpResult proxyExecute(TMasterOpRequest request) { return result; } - // Process a MySQL request - public void processOnce() throws IOException { - // set status of query to OK. - ctx.getState().reset(); - executor = null; - - // reset sequence id of MySQL protocol - final MysqlChannel channel = ctx.getMysqlChannel(); - channel.setSequenceId(0); - // read packet from channel - try { - packetBuf = channel.fetchOnePacket(); - if (packetBuf == null) { - LOG.warn("Null packet received from network. remote: {}", channel.getRemoteHostPortString()); - throw new IOException("Error happened when receiving packet."); - } - } catch (AsynchronousCloseException e) { - // when this happened, timeout checker close this channel - // killed flag in ctx has been already set, just return - return; - } - - // dispatch - dispatch(); - // finalize - finalizeCommand(); - - ctx.setCommand(MysqlCommand.COM_SLEEP); - } - - public void loop() { - while (!ctx.isKilled()) { - try { - processOnce(); - } catch (Exception e) { - // TODO(zhaochun): something wrong - LOG.warn("Exception happened in one session(" + ctx + ").", e); - ctx.setKilled(); - break; - } - } + // only Mysql protocol + public void processOnce() throws IOException, NotImplementedException { + throw new NotImplementedException("Not Impl processOnce"); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectScheduler.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectScheduler.java index 70bfd7e2d8cdd97..5be4c330e0aec2f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectScheduler.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectScheduler.java @@ -102,7 +102,7 @@ public boolean registerConnection(ConnectContext ctx) { return false; } connectionMap.put(ctx.getConnectionId(), ctx); - if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT)) { + if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) { flightToken2ConnectionId.put(ctx.getPeerIdentity(), ctx.getConnectionId()); } return true; @@ -116,7 +116,7 @@ public void unregisterConnection(ConnectContext ctx) { conns.decrementAndGet(); } numberConnection.decrementAndGet(); - if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT)) { + if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) { flightToken2ConnectionId.remove(ctx.getPeerIdentity()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java index 752dedf80154957..9aaca2f8e2cdbb2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java @@ -19,7 +19,6 @@ import org.apache.doris.analysis.Analyzer; import org.apache.doris.analysis.DescriptorTable; -import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.StorageBackend; import org.apache.doris.catalog.Env; import org.apache.doris.catalog.FsBroker; @@ -68,6 +67,7 @@ import org.apache.doris.proto.InternalService.PExecPlanFragmentStartRequest; import org.apache.doris.proto.Types; import org.apache.doris.proto.Types.PUniqueId; +import org.apache.doris.qe.ConnectContext.ConnectType; import org.apache.doris.qe.QueryStatisticsItem.FragmentInstanceInfo; import org.apache.doris.rpc.BackendServiceProxy; import org.apache.doris.rpc.RpcException; @@ -210,11 +210,6 @@ public class Coordinator implements CoordInterface { private final List needCheckBackendExecStates = Lists.newArrayList(); private final List needCheckPipelineExecContexts = Lists.newArrayList(); private ResultReceiver receiver; - private TNetworkAddress resultFlightServerAddr; - private TNetworkAddress resultInternalServiceAddr; - private ArrayList resultOutputExprs; - - private TUniqueId finstId; private final List scanNodes; private int scanRangeNum = 0; // number of instances of this query, equals to @@ -283,22 +278,6 @@ public ExecutionProfile getExecutionProfile() { return executionProfile; } - public TNetworkAddress getResultFlightServerAddr() { - return resultFlightServerAddr; - } - - public TNetworkAddress getResultInternalServiceAddr() { - return resultInternalServiceAddr; - } - - public ArrayList getResultOutputExprs() { - return resultOutputExprs; - } - - public TUniqueId getFinstId() { - return finstId; - } - // True if all scan node are ExternalScanNode. private boolean isAllExternalScan = true; @@ -631,10 +610,14 @@ public void exec() throws Exception { TNetworkAddress execBeAddr = topParams.instanceExecParams.get(0).host; receiver = new ResultReceiver(queryId, topParams.instanceExecParams.get(0).instanceId, addressToBackendID.get(execBeAddr), toBrpcHost(execBeAddr), this.timeoutDeadline); - finstId = topParams.instanceExecParams.get(0).instanceId; - resultFlightServerAddr = toArrowFlightHost(execBeAddr); - resultInternalServiceAddr = toBrpcHost(execBeAddr); - resultOutputExprs = fragments.get(0).getOutputExprs(); + + if (!context.isReturnResultFromLocal()) { + Preconditions.checkState(context.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)); + context.setFinstId(topParams.instanceExecParams.get(0).instanceId); + context.setResultFlightServerAddr(toArrowFlightHost(execBeAddr)); + context.setResultInternalServiceAddr(toBrpcHost(execBeAddr)); + context.setResultOutputExprs(fragments.get(0).getOutputExprs()); + } if (LOG.isDebugEnabled()) { LOG.debug("dispatch result sink of query {} to {}", DebugUtil.printId(queryId), topParams.instanceExecParams.get(0).host); diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java new file mode 100644 index 000000000000000..63781d2addaff56 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java @@ -0,0 +1,272 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.qe; + +import org.apache.doris.analysis.ExecuteStmt; +import org.apache.doris.analysis.LiteralExpr; +import org.apache.doris.analysis.NullLiteral; +import org.apache.doris.analysis.QueryStmt; +import org.apache.doris.common.ErrorCode; +import org.apache.doris.common.ErrorReport; +import org.apache.doris.mysql.MysqlChannel; +import org.apache.doris.mysql.MysqlCommand; +import org.apache.doris.mysql.MysqlProto; + +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.context.Scope; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.AsynchronousCloseException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +/** + * Process one mysql connection, receive one packet, process, send one packet. + */ +public class MysqlConnectProcessor extends ConnectProcessor { + private static final Logger LOG = LogManager.getLogger(MysqlConnectProcessor.class); + + private ByteBuffer packetBuf; + + public MysqlConnectProcessor(ConnectContext context) { + super(context); + connectType = ConnectType.MYSQL; + } + + // COM_INIT_DB: change current database of this session. + private void handleInitDb() { + String fullDbName = new String(packetBuf.array(), 1, packetBuf.limit() - 1); + handleInitDb(fullDbName); + } + + private void handleStmtClose() { + packetBuf = packetBuf.order(ByteOrder.LITTLE_ENDIAN); + int stmtId = packetBuf.getInt(); + handleStmtClose(stmtId); + } + + private void debugPacket() { + byte[] bytes = packetBuf.array(); + StringBuilder printB = new StringBuilder(); + for (byte b : bytes) { + if (Character.isLetterOrDigit((char) b & 0xFF)) { + char x = (char) b; + printB.append(x); + } else { + printB.append("0x" + Integer.toHexString(b & 0xFF)); + } + printB.append(" "); + } + LOG.debug("debug packet {}", printB.toString().substring(0, 200)); + } + + // process COM_EXECUTE, parse binary row data + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html + private void handleExecute() { + // debugPacket(); + packetBuf = packetBuf.order(ByteOrder.LITTLE_ENDIAN); + // parse stmt_id, flags, params + int stmtId = packetBuf.getInt(); + // flag + packetBuf.get(); + // iteration_count always 1, + packetBuf.getInt(); + LOG.debug("execute prepared statement {}", stmtId); + PrepareStmtContext prepareCtx = ctx.getPreparedStmt(String.valueOf(stmtId)); + if (prepareCtx == null) { + LOG.debug("No such statement in context, stmtId:{}", stmtId); + ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR, + "msg: Not supported such prepared statement"); + return; + } + ctx.setStartTime(); + if (prepareCtx.stmt.getInnerStmt() instanceof QueryStmt) { + ctx.getState().setIsQuery(true); + } + prepareCtx.stmt.setIsPrepared(); + int paramCount = prepareCtx.stmt.getParmCount(); + // null bitmap + byte[] nullbitmapData = new byte[(paramCount + 7) / 8]; + packetBuf.get(nullbitmapData); + String stmtStr = ""; + try { + // new_params_bind_flag + if ((int) packetBuf.get() != 0) { + // parse params's types + for (int i = 0; i < paramCount; ++i) { + int typeCode = packetBuf.getChar(); + LOG.debug("code {}", typeCode); + prepareCtx.stmt.placeholders().get(i).setTypeCode(typeCode); + } + } + List realValueExprs = new ArrayList<>(); + // parse param data + for (int i = 0; i < paramCount; ++i) { + if (isNull(nullbitmapData, i)) { + realValueExprs.add(new NullLiteral()); + continue; + } + LiteralExpr l = prepareCtx.stmt.placeholders().get(i).createLiteralFromType(); + l.setupParamFromBinary(packetBuf); + realValueExprs.add(l); + } + ExecuteStmt executeStmt = new ExecuteStmt(String.valueOf(stmtId), realValueExprs); + // TODO set real origin statement + executeStmt.setOrigStmt(new OriginStatement("null", 0)); + executeStmt.setUserInfo(ctx.getCurrentUserIdentity()); + LOG.debug("executeStmt {}", executeStmt); + executor = new StmtExecutor(ctx, executeStmt); + ctx.setExecutor(executor); + executor.execute(); + stmtStr = executeStmt.toSql(); + } catch (Throwable e) { + // Catch all throwable. + // If reach here, maybe palo bug. + LOG.warn("Process one query failed because unknown reason: ", e); + ctx.getState().setError(ErrorCode.ERR_UNKNOWN_ERROR, + e.getClass().getSimpleName() + ", msg: " + e.getMessage()); + } + auditAfterExec(stmtStr, prepareCtx.stmt.getInnerStmt(), null, false); + } + + // Process COM_QUERY statement, + private void handleQuery(MysqlCommand mysqlCommand) { + // convert statement to Java string + byte[] bytes = packetBuf.array(); + int ending = packetBuf.limit() - 1; + while (ending >= 1 && bytes[ending] == '\0') { + ending--; + } + String originStmt = new String(bytes, 1, ending, StandardCharsets.UTF_8); + + handleQuery(mysqlCommand, originStmt); + } + + private void dispatch() throws IOException { + int code = packetBuf.get(); + MysqlCommand command = MysqlCommand.fromCode(code); + if (command == null) { + ErrorReport.report(ErrorCode.ERR_UNKNOWN_COM_ERROR); + ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR, "Unknown command(" + code + ")"); + LOG.warn("Unknown command(" + code + ")"); + return; + } + LOG.debug("handle command {}", command); + ctx.setCommand(command); + ctx.setStartTime(); + + switch (command) { + case COM_INIT_DB: + handleInitDb(); + break; + case COM_QUIT: + // COM_QUIT: set killed flag and then return OK packet. + handleQuit(); + break; + case COM_QUERY: + case COM_STMT_PREPARE: + // Process COM_QUERY statement, + ctx.initTracer("trace"); + Span rootSpan = ctx.getTracer().spanBuilder("handleQuery").setNoParent().startSpan(); + try (Scope scope = rootSpan.makeCurrent()) { + handleQuery(command); + } catch (Exception e) { + rootSpan.recordException(e); + throw e; + } finally { + rootSpan.end(); + } + break; + case COM_STMT_EXECUTE: + handleExecute(); + break; + case COM_FIELD_LIST: + handleFieldList(); + break; + case COM_PING: + // process COM_PING statement, do nothing, just return one OK packet. + handlePing(); + break; + case COM_STMT_RESET: + handleStmtReset(); + break; + case COM_STMT_CLOSE: + handleStmtClose(); + break; + default: + ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR, "Unsupported command(" + command + ")"); + LOG.warn("Unsupported command(" + command + ")"); + break; + } + } + + private void handleFieldList() { + String tableName = new String(MysqlProto.readNulTerminateString(packetBuf), StandardCharsets.UTF_8); + handleFieldList(tableName); + } + + // Process a MySQL request + public void processOnce() throws IOException { + // set status of query to OK. + ctx.getState().reset(); + executor = null; + + // reset sequence id of MySQL protocol + final MysqlChannel channel = ctx.getMysqlChannel(); + channel.setSequenceId(0); + // read packet from channel + try { + packetBuf = channel.fetchOnePacket(); + if (packetBuf == null) { + LOG.warn("Null packet received from network. remote: {}", channel.getRemoteHostPortString()); + throw new IOException("Error happened when receiving packet."); + } + } catch (AsynchronousCloseException e) { + // when this happened, timeout checker close this channel + // killed flag in ctx has been already set, just return + return; + } + + // dispatch + dispatch(); + // finalize + finalizeCommand(); + + ctx.setCommand(MysqlCommand.COM_SLEEP); + } + + public void loop() { + while (!ctx.isKilled()) { + try { + processOnce(); + } catch (Exception e) { + // TODO(zhaochun): something wrong + LOG.warn("Exception happened in one session(" + ctx + ").", e); + ctx.setKilled(); + break; + } + } + } +} + + diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/QueryState.java b/fe/fe-core/src/main/java/org/apache/doris/qe/QueryState.java index 3619a15876bc7f8..a5f52f26288c44a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/QueryState.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/QueryState.java @@ -25,6 +25,7 @@ // query state used to record state of query, maybe query status is better public class QueryState { + // Reused by arrow flight protocol public enum MysqlStateType { NOOP, // send nothing to remote OK, // send OK packet to remote diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java index 09afe499435c342..853793cf8f3be08 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java @@ -136,6 +136,7 @@ import org.apache.doris.proto.InternalService.PGroupCommitInsertResponse; import org.apache.doris.proto.Types; import org.apache.doris.qe.CommonResultSet.CommonResultSetMetaData; +import org.apache.doris.qe.ConnectContext.ConnectType; import org.apache.doris.qe.QueryState.MysqlStateType; import org.apache.doris.qe.cache.Cache; import org.apache.doris.qe.cache.CacheAnalyzer; @@ -147,7 +148,6 @@ import org.apache.doris.rpc.BackendServiceProxy; import org.apache.doris.rpc.RpcException; import org.apache.doris.service.FrontendOptions; -import org.apache.doris.service.arrowflight.FlightStatementExecutor; import org.apache.doris.statistics.ResultRow; import org.apache.doris.statistics.util.InternalQueryBuffer; import org.apache.doris.system.Backend; @@ -242,6 +242,7 @@ public class StmtExecutor { // this constructor is mainly for proxy public StmtExecutor(ConnectContext context, OriginStatement originStmt, boolean isProxy) { + Preconditions.checkState(context.getConnectType().equals(ConnectType.MYSQL)); this.context = context; this.originStmt = originStmt; this.serializer = context.getMysqlChannel().getSerializer(); @@ -262,7 +263,11 @@ public StmtExecutor(ConnectContext ctx, StatementBase parsedStmt) { this.context = ctx; this.parsedStmt = parsedStmt; this.originStmt = parsedStmt.getOrigStmt(); - this.serializer = context.getMysqlChannel().getSerializer(); + if (context.getConnectType() == ConnectType.MYSQL) { + this.serializer = context.getMysqlChannel().getSerializer(); + } else { + this.serializer = null; + } this.isProxy = false; if (parsedStmt instanceof LogicalPlanAdapter) { this.statementContext = ((LogicalPlanAdapter) parsedStmt).getStatementContext(); @@ -428,7 +433,7 @@ public boolean isAnalyzeStmt() { * isValuesOrConstantSelect: when this interface return true, original string is truncated at 1024 * * @return parsed and analyzed statement for Stale planner. - * an unresolved LogicalPlan wrapped with a LogicalPlanAdapter for Nereids. + * an unresolved LogicalPlan wrapped with a LogicalPlanAdapter for Nereids. */ public StatementBase getParsedStmt() { return parsedStmt; @@ -444,13 +449,16 @@ public void execute() throws Exception { public void execute(TUniqueId queryId) throws Exception { SessionVariable sessionVariable = context.getSessionVariable(); Span executeSpan = context.getTracer().spanBuilder("execute").setParent(Context.current()).startSpan(); + if (context.getConnectType() == ConnectType.ARROW_FLIGHT_SQL) { + context.setReturnResultFromLocal(true); + } try (Scope scope = executeSpan.makeCurrent()) { if (parsedStmt instanceof LogicalPlanAdapter || (parsedStmt == null && sessionVariable.isEnableNereidsPlanner())) { try { executeByNereids(queryId); } catch (NereidsException | ParseException e) { - if (context.getMinidump() != null) { + if (context.getMinidump() != null && context.getMinidump().toString(4) != null) { MinidumpUtils.saveMinidumpString(context.getMinidump(), DebugUtil.printId(context.queryId())); } // try to fall back to legacy planner @@ -597,12 +605,23 @@ private void parseByNereids() { } if (statements.size() <= originStmt.idx) { throw new ParseException("Nereids parse failed. Parser get " + statements.size() + " statements," - + " but we need at least " + originStmt.idx + " statements."); + + " but we need at least " + originStmt.idx + " statements."); } parsedStmt = statements.get(originStmt.idx); } + public void finalizeQuery() { + // The final profile report occurs after be returns the query data, and the profile cannot be + // received after unregisterQuery(), causing the instance profile to be lost, so we should wait + // for the profile before unregisterQuery(). + updateProfile(true); + QeProcessorImpl.INSTANCE.unregisterQuery(context.queryId()); + } + private void handleQueryWithRetry(TUniqueId queryId) throws Exception { + if (context.getConnectType() == ConnectType.ARROW_FLIGHT_SQL) { + context.setReturnResultFromLocal(false); + } // queue query here syncJournalIfNeeded(); QueueOfferToken offerRet = null; @@ -628,7 +647,7 @@ private void handleQueryWithRetry(TUniqueId queryId) throws Exception { try { for (int i = 0; i < retryTime; i++) { try { - //reset query id for each retry + // reset query id for each retry if (i > 0) { UUID uuid = UUID.randomUUID(); TUniqueId newQueryId = new TUniqueId(uuid.getMostSignificantBits(), @@ -643,17 +662,15 @@ private void handleQueryWithRetry(TUniqueId queryId) throws Exception { if (i == retryTime - 1) { throw e; } - if (!context.getMysqlChannel().isSend()) { + if (context.getConnectType().equals(ConnectType.MYSQL) && !context.getMysqlChannel().isSend()) { LOG.warn("retry {} times. stmt: {}", (i + 1), parsedStmt.getOrigStmt().originStmt); } else { throw e; } } finally { - // The final profile report occurs after be returns the query data, and the profile cannot be - // received after unregisterQuery(), causing the instance profile to be lost, so we should wait - // for the profile before unregisterQuery(). - updateProfile(true); - QeProcessorImpl.INSTANCE.unregisterQuery(context.queryId()); + if (context.isReturnResultFromLocal()) { + finalizeQuery(); + } } } } finally { @@ -1355,9 +1372,11 @@ private void handleCacheStmt(CacheAnalyzer cacheAnalyzer, MysqlChannel channel) // Process a select statement. private void handleQueryStmt() throws Exception { LOG.info("Handling query {} with query id {}", - originStmt.originStmt, DebugUtil.printId(context.queryId)); - // Every time set no send flag and clean all data in buffer - context.getMysqlChannel().reset(); + originStmt.originStmt, DebugUtil.printId(context.queryId)); + if (context.getConnectType() == ConnectType.MYSQL) { + // Every time set no send flag and clean all data in buffer + context.getMysqlChannel().reset(); + } Queriable queryStmt = (Queriable) parsedStmt; QueryDetail queryDetail = new QueryDetail(context.getStartTime(), @@ -1384,12 +1403,16 @@ private void handleQueryStmt() throws Exception { return; } - MysqlChannel channel = context.getMysqlChannel(); + MysqlChannel channel = null; + if (context.getConnectType().equals(ConnectType.MYSQL)) { + channel = context.getMysqlChannel(); + } boolean isOutfileQuery = queryStmt.hasOutFileClause(); // Sql and PartitionCache CacheAnalyzer cacheAnalyzer = new CacheAnalyzer(context, parsedStmt, planner); - if (cacheAnalyzer.enableCache() && !isOutfileQuery + // TODO support arrow flight sql + if (context.getConnectType().equals(ConnectType.MYSQL) && cacheAnalyzer.enableCache() && !isOutfileQuery && context.getSessionVariable().getSqlSelectLimit() < 0 && context.getSessionVariable().getDefaultOrderByLimit() < 0) { if (queryStmt instanceof QueryStmt || queryStmt instanceof LogicalPlanAdapter) { @@ -1399,6 +1422,7 @@ private void handleQueryStmt() throws Exception { } } + // TODO support arrow flight sql // handle select .. from xx limit 0 if (parsedStmt instanceof SelectStmt) { SelectStmt parsedSelectStmt = (SelectStmt) parsedStmt; @@ -1463,6 +1487,22 @@ private void sendResult(boolean isOutfileQuery, boolean isSendFields, Queriable } } + if (context.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) { + Preconditions.checkState(!context.isReturnResultFromLocal()); + profile.getSummaryProfile().setTempStartTime(); + if (coordBase.getInstanceTotalNum() > 1 && LOG.isDebugEnabled()) { + try { + LOG.debug("Finish to execute fragment. user: {}, db: {}, sql: {}, fragment instance num: {}", + context.getQualifiedUser(), context.getDatabase(), + parsedStmt.getOrigStmt().originStmt.replace("\n", " "), + coordBase.getInstanceTotalNum()); + } catch (Exception e) { + LOG.warn("Fail to print fragment concurrency for Query.", e); + } + } + return; + } + Span fetchResultSpan = context.getTracer().spanBuilder("fetch result").setParent(Context.current()).startSpan(); try (Scope scope = fetchResultSpan.makeCurrent()) { while (true) { @@ -1567,8 +1607,10 @@ private TWaitingTxnStatusResult getWaitingTxnStatus(TWaitingTxnStatusRequest req } private void handleTransactionStmt() throws Exception { - // Every time set no send flag and clean all data in buffer - context.getMysqlChannel().reset(); + if (context.getConnectType() == ConnectType.MYSQL) { + // Every time set no send flag and clean all data in buffer + context.getMysqlChannel().reset(); + } context.getState().setOk(0, 0, ""); // create plan if (context.getTxnEntry() != null && context.getTxnEntry().getRowsInTransaction() == 0 @@ -1768,8 +1810,8 @@ private void beginTxn(String dbName, String tblName) throws UserException, TExce // Process an insert statement. private void handleInsertStmt() throws Exception { - // Every time set no send flag and clean all data in buffer - if (context.getMysqlChannel() != null) { + if (context.getConnectType() == ConnectType.MYSQL) { + // Every time set no send flag and clean all data in buffer context.getMysqlChannel().reset(); } InsertStmt insertStmt = (InsertStmt) parsedStmt; @@ -1985,8 +2027,7 @@ private void handleInsertStmt() throws Exception { */ throwable = t; } finally { - updateProfile(true); - QeProcessorImpl.INSTANCE.unregisterQuery(context.queryId()); + finalizeQuery(); } // Go here, which means: @@ -2055,7 +2096,9 @@ private void handleExternalInsertStmt() { } private void handleUnsupportedStmt() { - context.getMysqlChannel().reset(); + if (context.getConnectType() == ConnectType.MYSQL) { + context.getMysqlChannel().reset(); + } // do nothing context.getState().setOk(); } @@ -2080,10 +2123,10 @@ private void handleSwitchStmt() throws AnalysisException { private void handlePrepareStmt() throws Exception { // register prepareStmt LOG.debug("add prepared statement {}, isBinaryProtocol {}", - prepareStmt.getName(), prepareStmt.isBinaryProtocol()); + prepareStmt.getName(), prepareStmt.isBinaryProtocol()); context.addPreparedStmt(prepareStmt.getName(), new PrepareStmtContext(prepareStmt, - context, planner, analyzer, prepareStmt.getName())); + context, planner, analyzer, prepareStmt.getName())); if (prepareStmt.isBinaryProtocol()) { sendStmtPrepareOK(); } @@ -2110,6 +2153,7 @@ private void handleUseStmt() throws AnalysisException { } private void sendMetaData(ResultSetMetaData metaData) throws IOException { + Preconditions.checkState(context.getConnectType() == ConnectType.MYSQL); // sends how many columns serializer.reset(); serializer.writeVInt(metaData.getColumnCount()); @@ -2133,6 +2177,7 @@ private List exprToStringType(List exprs) { } private void sendStmtPrepareOK() throws IOException { + Preconditions.checkState(context.getConnectType() == ConnectType.MYSQL); // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response serializer.reset(); // 0x00 OK @@ -2170,6 +2215,7 @@ private void sendStmtPrepareOK() throws IOException { } private void sendFields(List colNames, List types) throws IOException { + Preconditions.checkState(context.getConnectType() == ConnectType.MYSQL); // sends how many columns serializer.reset(); serializer.writeVInt(colNames.size()); @@ -2201,24 +2247,33 @@ private void sendFields(List colNames, List types) throws IOExcept } public void sendResultSet(ResultSet resultSet) throws IOException { - context.updateReturnRows(resultSet.getResultRows().size()); - // Send meta data. - sendMetaData(resultSet.getMetaData()); + if (context.getConnectType().equals(ConnectType.MYSQL)) { + context.updateReturnRows(resultSet.getResultRows().size()); + // Send meta data. + sendMetaData(resultSet.getMetaData()); - // Send result set. - for (List row : resultSet.getResultRows()) { - serializer.reset(); - for (String item : row) { - if (item == null || item.equals(FeConstants.null_string)) { - serializer.writeNull(); - } else { - serializer.writeLenEncodedString(item); + // Send result set. + for (List row : resultSet.getResultRows()) { + serializer.reset(); + for (String item : row) { + if (item == null || item.equals(FeConstants.null_string)) { + serializer.writeNull(); + } else { + serializer.writeLenEncodedString(item); + } } + context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer()); } - context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer()); - } - context.getState().setEof(); + context.getState().setEof(); + } else if (context.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) { + context.updateReturnRows(resultSet.getResultRows().size()); + context.getFlightSqlChannel() + .addResult(DebugUtil.printId(context.queryId()), context.getRunningQuery(), resultSet); + context.getState().setEof(); + } else { + LOG.error("sendResultSet error connect type"); + } } // Process show statement @@ -2244,6 +2299,7 @@ private void handleLockTablesStmt() { } public void handleExplainStmt(String result, boolean isNereids) throws IOException { + // TODO support arrow flight sql ShowResultSetMetaData metaData = ShowResultSetMetaData.builder() .addColumn(new Column("Explain String" + (isNereids ? "(Nereids Planner)" : "(Old Planner)"), ScalarType.createVarchar(20))) @@ -2682,64 +2738,6 @@ public List executeInternalQuery() { } } - public void executeArrowFlightQuery(FlightStatementExecutor flightStatementExecutor) { - LOG.debug("ARROW FLIGHT QUERY: " + originStmt.toString()); - try { - try { - if (ConnectContext.get() != null - && ConnectContext.get().getSessionVariable().isEnableNereidsPlanner()) { - try { - parseByNereids(); - Preconditions.checkState(parsedStmt instanceof LogicalPlanAdapter, - "Nereids only process LogicalPlanAdapter," - + " but parsedStmt is " + parsedStmt.getClass().getName()); - context.getState().setNereids(true); - context.getState().setIsQuery(true); - planner = new NereidsPlanner(statementContext); - planner.plan(parsedStmt, context.getSessionVariable().toThrift()); - } catch (Exception e) { - LOG.warn("fall back to legacy planner, because: {}", e.getMessage(), e); - parsedStmt = null; - context.getState().setNereids(false); - analyzer = new Analyzer(context.getEnv(), context); - analyze(context.getSessionVariable().toThrift()); - } - } else { - analyzer = new Analyzer(context.getEnv(), context); - analyze(context.getSessionVariable().toThrift()); - } - } catch (Exception e) { - throw new RuntimeException("Failed to execute Arrow Flight SQL. " + Util.getRootCauseMessage(e), e); - } - coord = new Coordinator(context, analyzer, planner, context.getStatsErrorEstimator()); - profile.addExecutionProfile(coord.getExecutionProfile()); - try { - QeProcessorImpl.INSTANCE.registerQuery(context.queryId(), - new QeProcessorImpl.QueryInfo(context, originStmt.originStmt, coord)); - } catch (UserException e) { - throw new RuntimeException("Failed to execute Arrow Flight SQL. " + Util.getRootCauseMessage(e), e); - } - - Span queryScheduleSpan = context.getTracer() - .spanBuilder("Arrow Flight SQL schedule").setParent(Context.current()).startSpan(); - try (Scope scope = queryScheduleSpan.makeCurrent()) { - coord.exec(); - } catch (Exception e) { - queryScheduleSpan.recordException(e); - LOG.warn("Failed to coord exec Arrow Flight SQL, because: {}", e.getMessage(), e); - throw new RuntimeException(e.getMessage() + Util.getRootCauseMessage(e), e); - } finally { - queryScheduleSpan.end(); - } - } finally { - QeProcessorImpl.INSTANCE.unregisterQuery(context.queryId()); // TODO for query profile - } - flightStatementExecutor.setFinstId(coord.getFinstId()); - flightStatementExecutor.setResultFlightServerAddr(coord.getResultFlightServerAddr()); - flightStatementExecutor.setResultInternalServiceAddr(coord.getResultInternalServiceAddr()); - flightStatementExecutor.setResultOutputExprs(coord.getResultOutputExprs()); - } - private List convertResultBatchToResultRows(TResultBatch batch) { List columns = parsedStmt.getColLabels(); List resultRows = new ArrayList<>(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/FrontendServiceImpl.java b/fe/fe-core/src/main/java/org/apache/doris/service/FrontendServiceImpl.java index e6a883e07c728d0..09f2c0be5a3b7b8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/FrontendServiceImpl.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/FrontendServiceImpl.java @@ -83,15 +83,18 @@ import org.apache.doris.planner.OlapTableSink; import org.apache.doris.planner.StreamLoadPlanner; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.ConnectContext.ConnectType; import org.apache.doris.qe.ConnectProcessor; import org.apache.doris.qe.Coordinator; import org.apache.doris.qe.DdlExecutor; import org.apache.doris.qe.MasterCatalogExecutor; +import org.apache.doris.qe.MysqlConnectProcessor; import org.apache.doris.qe.OriginStatement; import org.apache.doris.qe.QeProcessorImpl; import org.apache.doris.qe.QueryState; import org.apache.doris.qe.StmtExecutor; import org.apache.doris.qe.VariableMgr; +import org.apache.doris.service.arrowflight.FlightSqlConnectProcessor; import org.apache.doris.statistics.ColumnStatistic; import org.apache.doris.statistics.ResultRow; import org.apache.doris.statistics.StatisticsCacheKey; @@ -1104,7 +1107,16 @@ public TMasterOpResult forward(TMasterOpRequest params) throws TException { ConnectContext context = new ConnectContext(); // Set current connected FE to the client address, so that we can know where this request come from. context.setCurrentConnectedFEIp(params.getClientNodeHost()); - ConnectProcessor processor = new ConnectProcessor(context); + + ConnectProcessor processor = null; + if (context.getConnectType().equals(ConnectType.MYSQL)) { + processor = new MysqlConnectProcessor(context); + } else if (context.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) { + processor = new FlightSqlConnectProcessor(context); + } else { + throw new TException("unknown ConnectType: " + context.getConnectType()); + } + TMasterOpResult result = processor.proxyExecute(params); if (QueryState.MysqlStateType.ERR.name().equalsIgnoreCase(result.getStatus())) { context.getState().setError(result.getStatus()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java index 0e73fbb2ad69bdb..d2f8b46b8936839 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/DorisFlightSqlProducer.java @@ -24,8 +24,11 @@ import org.apache.doris.common.util.Util; import org.apache.doris.mysql.MysqlCommand; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.QueryState.MysqlStateType; +import org.apache.doris.service.arrowflight.results.FlightSqlResultCacheEntry; import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager; +import com.google.common.base.Preconditions; import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Message; @@ -63,12 +66,15 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import java.util.Collections; import java.util.List; +import java.util.Objects; +import java.util.UUID; public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable { private static final Logger LOG = LogManager.getLogger(DorisFlightSqlProducer.class); @@ -111,33 +117,72 @@ public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, fi ConnectContext connectContext = null; try { connectContext = flightSessionsManager.getConnectContext(context.peerIdentity()); - // Only for ConnectContext check timeout. - connectContext.setCommand(MysqlCommand.COM_QUERY); + // After the previous query was executed, there was no getStreamStatement to take away the result. + connectContext.getFlightSqlChannel().reset(); final String query = request.getQuery(); - final FlightStatementExecutor flightStatementExecutor = new FlightStatementExecutor(query, connectContext); - - flightStatementExecutor.executeQuery(); - - TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder() - .setStatementHandle(ByteString.copyFromUtf8( - DebugUtil.printId(flightStatementExecutor.getFinstId()) + ":" + query)).build(); - final Ticket ticket = new Ticket(Any.pack(ticketStatement).toByteArray()); - // TODO Support multiple endpoints. - Location location = Location.forGrpcInsecure(flightStatementExecutor.getResultFlightServerAddr().hostname, - flightStatementExecutor.getResultFlightServerAddr().port); - List endpoints = Collections.singletonList(new FlightEndpoint(ticket, location)); - - Schema schema; - schema = flightStatementExecutor.fetchArrowFlightSchema(5000); - if (schema == null) { - throw CallStatus.INTERNAL.withDescription("fetch arrow flight schema is null").toRuntimeException(); + final FlightSqlConnectProcessor flightSQLConnectProcessor = new FlightSqlConnectProcessor(connectContext); + + flightSQLConnectProcessor.handleQuery(query); + if (connectContext.getState().getStateType() == MysqlStateType.ERR) { + throw new RuntimeException("after handleQuery"); + } + + if (connectContext.isReturnResultFromLocal()) { + // set/use etc. stmt returns an OK result by default. + if (connectContext.getFlightSqlChannel().resultNum() == 0) { + // a random query id and add empty results + String queryId = UUID.randomUUID().toString(); + connectContext.getFlightSqlChannel().addEmptyResult(queryId, query); + + final ByteString handle = ByteString.copyFromUtf8(context.peerIdentity() + ":" + queryId); + TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder().setStatementHandle(handle) + .build(); + return getFlightInfoForSchema(ticketStatement, descriptor, + connectContext.getFlightSqlChannel().getResult(queryId).getVectorSchemaRoot().getSchema()); + } + + // A Flight Sql request can only contain one statement that returns result, + // otherwise expected thrown exception during execution. + Preconditions.checkState(connectContext.getFlightSqlChannel().resultNum() == 1); + + // The tokens used for authentication between getStreamStatement and getFlightInfoStatement + // are different. So put the peerIdentity into the ticket and then getStreamStatement is used to find + // the correct ConnectContext. + // queryId is used to find query results. + final ByteString handle = ByteString.copyFromUtf8( + context.peerIdentity() + ":" + DebugUtil.printId(connectContext.queryId())); + TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder().setStatementHandle(handle) + .build(); + return getFlightInfoForSchema(ticketStatement, descriptor, + connectContext.getFlightSqlChannel().getResult(DebugUtil.printId(connectContext.queryId())) + .getVectorSchemaRoot().getSchema()); + } else { + // Now only query stmt will pull results from BE. + final ByteString handle = ByteString.copyFromUtf8( + DebugUtil.printId(connectContext.getFinstId()) + ":" + query); + Schema schema = flightSQLConnectProcessor.fetchArrowFlightSchema(5000); + if (schema == null) { + throw CallStatus.INTERNAL.withDescription("fetch arrow flight schema is null").toRuntimeException(); + } + TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder().setStatementHandle(handle) + .build(); + Ticket ticket = new Ticket(Any.pack(ticketStatement).toByteArray()); + // TODO Support multiple endpoints. + Location location = Location.forGrpcInsecure(connectContext.getResultFlightServerAddr().hostname, + connectContext.getResultFlightServerAddr().port); + List endpoints = Collections.singletonList(new FlightEndpoint(ticket, location)); + // TODO Set in BE callback after query end, Client will not callback. + connectContext.setCommand(MysqlCommand.COM_SLEEP); + return new FlightInfo(schema, descriptor, endpoints, -1, -1); } - // TODO Set in BE callback after query end, Client client will not callback by default. - connectContext.setCommand(MysqlCommand.COM_SLEEP); - return new FlightInfo(schema, descriptor, endpoints, -1, -1); } catch (Exception e) { if (null != connectContext) { connectContext.setCommand(MysqlCommand.COM_SLEEP); + String errMsg = "get flight info statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage( + e) + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: " + + connectContext.getState().getErrorMessage(); + LOG.warn(errMsg, e); + throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException(); } LOG.warn("get flight info statement failed, " + e.getMessage(), e); throw CallStatus.INTERNAL.withDescription(Util.getRootCauseMessage(e)).withCause(e).toRuntimeException(); @@ -146,8 +191,7 @@ public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, fi @Override public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command, - final CallContext context, - final FlightDescriptor descriptor) { + final CallContext context, final FlightDescriptor descriptor) { throw CallStatus.UNIMPLEMENTED.withDescription("getFlightInfoPreparedStatement unimplemented") .toRuntimeException(); } @@ -158,6 +202,42 @@ public SchemaResult getSchemaStatement(final CommandStatementQuery command, fina throw CallStatus.UNIMPLEMENTED.withDescription("getSchemaStatement unimplemented").toRuntimeException(); } + @Override + public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context, + final ServerStreamListener listener) { + ConnectContext connectContext = null; + final String handle = ticketStatementQuery.getStatementHandle().toStringUtf8(); + String[] handleParts = handle.split(":"); + String executedPeerIdentity = handleParts[0]; + String queryId = handleParts[1]; + try { + // The tokens used for authentication between getStreamStatement and getFlightInfoStatement are different. + connectContext = flightSessionsManager.getConnectContext(executedPeerIdentity); + final FlightSqlResultCacheEntry flightSqlResultCacheEntry = Objects.requireNonNull( + connectContext.getFlightSqlChannel().getResult(queryId)); + final VectorSchemaRoot vectorSchemaRoot = flightSqlResultCacheEntry.getVectorSchemaRoot(); + listener.start(vectorSchemaRoot); + listener.putNext(); + } catch (Exception e) { + listener.error(e); + if (null != connectContext) { + String errMsg = "get stream statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e) + + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: " + + connectContext.getState().getErrorMessage(); + LOG.warn(errMsg, e); + throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException(); + } + LOG.warn("get stream statement failed, " + e.getMessage(), e); + throw CallStatus.INTERNAL.withDescription(Util.getRootCauseMessage(e)).withCause(e).toRuntimeException(); + } finally { + listener.completed(); + if (null != connectContext) { + // The result has been sent, delete it. + connectContext.getFlightSqlChannel().invalidate(queryId); + } + } + } + @Override public void close() throws Exception { AutoCloseables.close(rootAllocator); @@ -180,8 +260,7 @@ public void doExchange(CallContext context, FlightStream reader, ServerStreamLis } @Override - public Runnable acceptPutStatement(CommandStatementUpdate command, - CallContext context, FlightStream flightStream, + public Runnable acceptPutStatement(CommandStatementUpdate command, CallContext context, FlightStream flightStream, StreamListener ackStream) { throw CallStatus.UNIMPLEMENTED.withDescription("acceptPutStatement unimplemented").toRuntimeException(); } @@ -219,8 +298,7 @@ public FlightInfo getFlightInfoTypeInfo(CommandGetXdbcTypeInfo request, CallCont } @Override - public void getStreamTypeInfo(CommandGetXdbcTypeInfo request, CallContext context, - ServerStreamListener listener) { + public void getStreamTypeInfo(CommandGetXdbcTypeInfo request, CallContext context, ServerStreamListener listener) { throw CallStatus.UNIMPLEMENTED.withDescription("getStreamTypeInfo unimplemented").toRuntimeException(); } @@ -323,12 +401,6 @@ public void getStreamCrossReference(CommandGetCrossReference command, CallContex throw CallStatus.UNIMPLEMENTED.withDescription("getStreamCrossReference unimplemented").toRuntimeException(); } - @Override - public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context, - final ServerStreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("getStreamStatement unimplemented").toRuntimeException(); - } - private FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor, final Schema schema) { final Ticket ticket = new Ticket(Any.pack(request).toByteArray()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightStatementExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlConnectProcessor.java similarity index 65% rename from fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightStatementExecutor.java rename to fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlConnectProcessor.java index 8c9cdf124f3485f..ef5b53c2d1ff757 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightStatementExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/FlightSqlConnectProcessor.java @@ -14,18 +14,19 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -// This file is copied from -// https://github.com/apache/arrow/blob/main/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/StatementContext.java -// and modified by Doris package org.apache.doris.service.arrowflight; import org.apache.doris.analysis.Expr; +import org.apache.doris.common.ErrorCode; +import org.apache.doris.common.ErrorReport; import org.apache.doris.common.Status; import org.apache.doris.common.util.DebugUtil; +import org.apache.doris.mysql.MysqlCommand; import org.apache.doris.proto.InternalService; import org.apache.doris.proto.Types; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.ConnectProcessor; import org.apache.doris.qe.StmtExecutor; import org.apache.doris.rpc.BackendServiceProxy; import org.apache.doris.rpc.RpcException; @@ -33,112 +34,84 @@ import org.apache.doris.thrift.TStatusCode; import org.apache.doris.thrift.TUniqueId; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.context.Scope; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import java.io.ByteArrayInputStream; import java.util.ArrayList; import java.util.List; -import java.util.Objects; -import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -public final class FlightStatementExecutor implements AutoCloseable { - private ConnectContext connectContext; - private final String query; - private TUniqueId queryId; - private TUniqueId finstId; - private TNetworkAddress resultFlightServerAddr; - private TNetworkAddress resultInternalServiceAddr; - private ArrayList resultOutputExprs; - - public FlightStatementExecutor(final String query, ConnectContext connectContext) { - this.query = query; - this.connectContext = connectContext; - connectContext.setThreadLocalInfo(); - } - - public void setQueryId(TUniqueId queryId) { - this.queryId = queryId; - } - - public void setFinstId(TUniqueId finstId) { - this.finstId = finstId; - } - - public void setResultFlightServerAddr(TNetworkAddress resultFlightServerAddr) { - this.resultFlightServerAddr = resultFlightServerAddr; - } - - public void setResultInternalServiceAddr(TNetworkAddress resultInternalServiceAddr) { - this.resultInternalServiceAddr = resultInternalServiceAddr; - } - - public void setResultOutputExprs(ArrayList resultOutputExprs) { - this.resultOutputExprs = resultOutputExprs; - } - - public String getQuery() { - return query; - } - - public TUniqueId getQueryId() { - return queryId; - } - - public TUniqueId getFinstId() { - return finstId; - } - - public TNetworkAddress getResultFlightServerAddr() { - return resultFlightServerAddr; - } - - public TNetworkAddress getResultInternalServiceAddr() { - return resultInternalServiceAddr; - } - - public ArrayList getResultOutputExprs() { - return resultOutputExprs; - } - - @Override - public boolean equals(final Object other) { - if (!(other instanceof FlightStatementExecutor)) { - return false; +/** + * Process one flgiht sql connection. + */ +public class FlightSqlConnectProcessor extends ConnectProcessor implements AutoCloseable { + private static final Logger LOG = LogManager.getLogger(FlightSqlConnectProcessor.class); + + public FlightSqlConnectProcessor(ConnectContext context) { + super(context); + connectType = ConnectType.ARROW_FLIGHT_SQL; + context.setThreadLocalInfo(); + context.setReturnResultFromLocal(true); + } + + public void prepare(MysqlCommand command) { + // set status of query to OK. + ctx.getState().reset(); + executor = null; + + if (command == null) { + ErrorReport.report(ErrorCode.ERR_UNKNOWN_COM_ERROR); + ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR, "Unknown command(" + command.toString() + ")"); + LOG.warn("Unknown command(" + command + ")"); + return; } - return this == other; + LOG.debug("arrow flight sql handle command {}", command); + ctx.setCommand(command); + ctx.setStartTime(); } - @Override - public int hashCode() { - return Objects.hash(this); - } + public void handleQuery(String query) { + MysqlCommand command = MysqlCommand.COM_QUERY; + prepare(command); - public void executeQuery() { - try { - UUID uuid = UUID.randomUUID(); - TUniqueId queryId = new TUniqueId(uuid.getMostSignificantBits(), uuid.getLeastSignificantBits()); - setQueryId(queryId); - connectContext.setQueryId(queryId); - StmtExecutor stmtExecutor = new StmtExecutor(connectContext, getQuery()); - connectContext.setExecutor(stmtExecutor); - stmtExecutor.executeArrowFlightQuery(this); + ctx.setRunningQuery(query); + ctx.initTracer("trace"); + Span rootSpan = ctx.getTracer().spanBuilder("handleQuery").setNoParent().startSpan(); + try (Scope scope = rootSpan.makeCurrent()) { + handleQuery(command, query); } catch (Exception e) { - throw new RuntimeException("Failed to coord exec", e); + rootSpan.recordException(e); + throw e; + } finally { + rootSpan.end(); } } + // TODO + // private void handleInitDb() { + // handleInitDb(fullDbName); + // } + + // TODO + // private void handleFieldList() { + // handleFieldList(tableName); + // } + public Schema fetchArrowFlightSchema(int timeoutMs) { - TNetworkAddress address = getResultInternalServiceAddr(); - TUniqueId tid = getFinstId(); - ArrayList resultOutputExprs = getResultOutputExprs(); + TNetworkAddress address = ctx.getResultInternalServiceAddr(); + TUniqueId tid = ctx.getFinstId(); + ArrayList resultOutputExprs = ctx.getResultOutputExprs(); Types.PUniqueId finstId = Types.PUniqueId.newBuilder().setHi(tid.hi).setLo(tid.lo).build(); try { InternalService.PFetchArrowFlightSchemaRequest request = @@ -156,7 +129,7 @@ public Schema fetchArrowFlightSchema(int timeoutMs) { } TStatusCode code = TStatusCode.findByValue(pResult.getStatus().getStatusCode()); if (code != TStatusCode.OK) { - Status status = null; + Status status = new Status(); status.setPstatus(pResult.getStatus()); throw new RuntimeException(String.format("fetch arrow flight schema failed, finstId: %s, errmsg: %s", DebugUtil.printId(tid), status)); @@ -204,6 +177,14 @@ public Schema fetchArrowFlightSchema(int timeoutMs) { @Override public void close() throws Exception { + ctx.setCommand(MysqlCommand.COM_SLEEP); + // TODO support query profile + for (StmtExecutor asynExecutor : returnResultFromRemoteExecutor) { + asynExecutor.finalizeQuery(); + } + returnResultFromRemoteExecutor.clear(); ConnectContext.remove(); } } + + diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlChannel.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlChannel.java new file mode 100644 index 000000000000000..094d713d5b68104 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlChannel.java @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.service.arrowflight.results; + +import org.apache.doris.catalog.Column; +import org.apache.doris.common.FeConstants; +import org.apache.doris.qe.ResultSet; +import org.apache.doris.qe.ResultSetMetaData; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.RemovalListener; +import com.google.common.cache.RemovalNotification; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType.Utf8; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.jetbrains.annotations.NotNull; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +public class FlightSqlChannel { + private final Cache resultCache; + private final BufferAllocator allocator; + + public FlightSqlChannel() { + resultCache = + CacheBuilder.newBuilder() + .maximumSize(100) + .expireAfterWrite(10, TimeUnit.MINUTES) + .removalListener(new ResultRemovalListener()) + .build(); + allocator = new RootAllocator(Long.MAX_VALUE); + } + + // TODO + public String getRemoteIp() { + return "0.0.0.0"; + } + + // TODO + public String getRemoteHostPortString() { + return "0.0.0.0:0"; + } + + public void addResult(String queryId, String runningQuery, ResultSet resultSet) { + List schemaFields = new ArrayList<>(); + List dataFields = new ArrayList<>(); + List> resultData = resultSet.getResultRows(); + ResultSetMetaData metaData = resultSet.getMetaData(); + + // TODO: only support varchar type + for (Column col : metaData.getColumns()) { + schemaFields.add(new Field(col.getName(), FieldType.nullable(new Utf8()), null)); + VarCharVector varCharVector = new VarCharVector(col.getName(), allocator); + varCharVector.allocateNew(); + varCharVector.setValueCount(resultData.size()); + dataFields.add(varCharVector); + } + + for (int i = 0; i < resultData.size(); i++) { + List row = resultData.get(i); + for (int j = 0; j < row.size(); j++) { + String item = row.get(j); + if (item == null || item.equals(FeConstants.null_string)) { + dataFields.get(j).setNull(i); + } else { + ((VarCharVector) dataFields.get(j)).setSafe(i, item.getBytes()); + } + } + } + VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(schemaFields, dataFields); + final FlightSqlResultCacheEntry flightSqlResultCacheEntry = new FlightSqlResultCacheEntry(vectorSchemaRoot, + runningQuery); + resultCache.put(queryId, flightSqlResultCacheEntry); + } + + public void addEmptyResult(String queryId, String query) { + List schemaFields = new ArrayList<>(); + List dataFields = new ArrayList<>(); + schemaFields.add(new Field("StatusResult", FieldType.nullable(new Utf8()), null)); + VarCharVector varCharVector = new VarCharVector("StatusResult", allocator); + varCharVector.allocateNew(); + varCharVector.setValueCount(1); + varCharVector.setSafe(0, "OK".getBytes()); + dataFields.add(varCharVector); + + VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(schemaFields, dataFields); + final FlightSqlResultCacheEntry flightSqlResultCacheEntry = new FlightSqlResultCacheEntry(vectorSchemaRoot, + query); + resultCache.put(queryId, flightSqlResultCacheEntry); + } + + public FlightSqlResultCacheEntry getResult(String queryId) { + return resultCache.getIfPresent(queryId); + } + + public void invalidate(String handle) { + resultCache.invalidate(handle); + } + + public long resultNum() { + return resultCache.size(); + } + + public void reset() { + resultCache.invalidateAll(); + } + + public void close() { + reset(); + } + + private static class ResultRemovalListener implements RemovalListener { + @Override + public void onRemoval(@NotNull final RemovalNotification notification) { + try { + AutoCloseables.close(notification.getValue()); + } catch (final Exception e) { + // swallow + } + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlResultCacheEntry.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlResultCacheEntry.java new file mode 100644 index 000000000000000..12ce04ca8ed842b --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/results/FlightSqlResultCacheEntry.java @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.service.arrowflight.results; + +import org.apache.arrow.vector.VectorSchemaRoot; + +import java.util.Objects; + + +public final class FlightSqlResultCacheEntry implements AutoCloseable { + + private final VectorSchemaRoot vectorSchemaRoot; + private final String query; + + public FlightSqlResultCacheEntry(final VectorSchemaRoot vectorSchemaRoot, final String query) { + this.vectorSchemaRoot = Objects.requireNonNull(vectorSchemaRoot, "result cannot be null."); + this.query = query; + } + + public VectorSchemaRoot getVectorSchemaRoot() { + return vectorSchemaRoot; + } + + public String getQuery() { + return query; + } + + @Override + public void close() throws Exception { + vectorSchemaRoot.clear(); + } + + @Override + public boolean equals(final Object other) { + if (this == other) { + return true; + } + if (!(other instanceof VectorSchemaRoot)) { + return false; + } + final VectorSchemaRoot that = (VectorSchemaRoot) other; + return vectorSchemaRoot.equals(that); + } + + @Override + public int hashCode() { + return Objects.hash(vectorSchemaRoot); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsManager.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsManager.java index ed01098c6756521..f850384d4ed96ca 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsManager.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsManager.java @@ -49,8 +49,8 @@ public interface FlightSessionsManager { */ ConnectContext createConnectContext(String peerIdentity); - public static ConnectContext buildConnectContext(String peerIdentity, UserIdentity userIdentity, String remoteIP) { - ConnectContext connectContext = new ConnectContext(peerIdentity); + static ConnectContext buildConnectContext(String peerIdentity, UserIdentity userIdentity, String remoteIP) { + ConnectContext connectContext = new FlightSqlConnectContext(peerIdentity); connectContext.setEnv(Env.getCurrentEnv()); connectContext.setStartTime(); connectContext.setCluster(SystemInfoService.DEFAULT_CLUSTER); diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsWithTokenManager.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsWithTokenManager.java index ce12f610ea27a04..e1866b094b26410 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsWithTokenManager.java +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSessionsWithTokenManager.java @@ -58,6 +58,7 @@ public ConnectContext createConnectContext(String peerIdentity) { if (flightTokenDetails.getCreatedSession()) { return null; } + flightTokenDetails.setCreatedSession(true); return FlightSessionsManager.buildConnectContext(peerIdentity, flightTokenDetails.getUserIdentity(), flightTokenDetails.getRemoteIp()); } catch (IllegalArgumentException e) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSqlConnectContext.java b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSqlConnectContext.java new file mode 100644 index 000000000000000..615a7f66ddcfa7a --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/service/arrowflight/sessions/FlightSqlConnectContext.java @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.service.arrowflight.sessions; + +import org.apache.doris.mysql.MysqlChannel; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.ConnectProcessor; +import org.apache.doris.service.arrowflight.results.FlightSqlChannel; +import org.apache.doris.thrift.TResultSinkType; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.io.IOException; + +public class FlightSqlConnectContext extends ConnectContext { + private static final Logger LOG = LogManager.getLogger(FlightSqlConnectContext.class); + + public FlightSqlConnectContext(String peerIdentity) { + this.connectType = ConnectType.ARROW_FLIGHT_SQL; + this.peerIdentity = peerIdentity; + mysqlChannel = null; // Use of MysqlChannel is not expected + flightSqlChannel = new FlightSqlChannel(); + setResultSinkType(TResultSinkType.ARROW_FLIGHT_PROTOCAL); + init(); + } + + @Override + public FlightSqlChannel getFlightSqlChannel() { + return flightSqlChannel; + } + + @Override + public MysqlChannel getMysqlChannel() { + throw new RuntimeException("getMysqlChannel not in mysql connection"); + } + + @Override + public String getClientIP() { + return flightSqlChannel.getRemoteHostPortString(); + } + + @Override + protected void closeChannel() { + if (flightSqlChannel != null) { + flightSqlChannel.close(); + } + } + + // kill operation with no protect. + @Override + public void kill(boolean killConnection) { + LOG.warn("kill query from {}, kill flight sql connection: {}", getRemoteHostPortString(), killConnection); + + if (killConnection) { + isKilled = true; + closeChannel(); + connectScheduler.unregisterConnection(this); + } + // Now, cancel running query. + cancelQuery(); + } + + @Override + public String getRemoteHostPortString() { + return getFlightSqlChannel().getRemoteHostPortString(); + } + + @Override + public void startAcceptQuery(ConnectProcessor connectProcessor) { + throw new RuntimeException("Flight Sql Not impl startAcceptQuery"); + } + + @Override + public void suspendAcceptQuery() { + throw new RuntimeException("Flight Sql Not impl suspendAcceptQuery"); + } + + @Override + public void resumeAcceptQuery() { + throw new RuntimeException("Flight Sql Not impl resumeAcceptQuery"); + } + + @Override + public void stopAcceptQuery() throws IOException { + throw new RuntimeException("Flight Sql Not impl stopAcceptQuery"); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java b/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java index 507102fb0d258d6..dcc0d85cb479a56 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java @@ -38,6 +38,7 @@ import org.apache.doris.mysql.MysqlChannel; import org.apache.doris.mysql.MysqlSerializer; import org.apache.doris.planner.OriginalPlanner; +import org.apache.doris.qe.ConnectContext.ConnectType; import org.apache.doris.rewrite.ExprRewriter; import org.apache.doris.service.FrontendOptions; import org.apache.doris.thrift.TQueryOptions; @@ -380,6 +381,10 @@ public void testKillOtherFail(@Mocked KillStmt killStmt, @Mocked SqlParser parse killCtx.kill(true); minTimes = 0; + killCtx.getConnectType(); + minTimes = 0; + result = ConnectType.MYSQL; + ConnectContext.get(); minTimes = 0; result = ctx; @@ -437,6 +442,10 @@ public void testKillOther(@Mocked KillStmt killStmt, @Mocked SqlParser parser, killCtx.kill(true); minTimes = 0; + killCtx.getConnectType(); + minTimes = 0; + result = ConnectType.MYSQL; + ConnectContext.get(); minTimes = 0; result = ctx;