diff --git a/include/common/expr_value.h b/include/common/expr_value.h index 4e9eae98..c5d2226f 100644 --- a/include/common/expr_value.h +++ b/include/common/expr_value.h @@ -167,6 +167,7 @@ struct ExprValue { case pb::HLL: case pb::HEX: case pb::TDIGEST: + case pb::JSON: str_val = value.string_val(); break; case pb::BITMAP: { @@ -190,6 +191,7 @@ struct ExprValue { float_precision_len = -1; str_val = value_str; if (primitive_type == pb::STRING + || primitive_type == pb::JSON || primitive_type == pb::HEX || primitive_type == pb::BITMAP || primitive_type == pb::HLL @@ -323,6 +325,7 @@ struct ExprValue { case pb::STRING: case pb::HEX: case pb::BITMAP: + case pb::JSON: case pb::TDIGEST: value->set_string_val(str_val); break; @@ -422,6 +425,7 @@ struct ExprValue { case pb::STRING: case pb::HEX: case pb::HLL: + case pb::JSON: case pb::TDIGEST: return str_val.length(); case pb::DATETIME: @@ -522,6 +526,9 @@ struct ExprValue { case pb::STRING: str_val = get_string(); break; + case pb::JSON: + str_val = get_string(); + break; case pb::BITMAP: { _u.bitmap = new(std::nothrow) Roaring(); if (str_val.size() > 0) { @@ -576,6 +583,7 @@ struct ExprValue { butil::MurmurHash3_x64_128(&_u, 8, seed, out); return out[0]; case pb::STRING: + case pb::JSON: case pb::HEX: { butil::MurmurHash3_x64_128(str_val.c_str(), str_val.size(), seed, out); return out[0]; @@ -627,6 +635,7 @@ struct ExprValue { } case pb::STRING: case pb::HEX: + case pb::JSON: case pb::HLL: case pb::TDIGEST: return str_val; @@ -741,6 +750,7 @@ struct ExprValue { (_u.double_val < other._u.double_val ? -1 : 0); case pb::STRING: case pb::HEX: + case pb::JSON: return str_val.compare(other.str_val); case pb::NULL_TYPE: return -1; @@ -795,7 +805,7 @@ struct ExprValue { } bool is_string() const { - return type == pb::STRING || type == pb::HEX || type == pb::BITMAP || type == pb::HLL || type == pb::TDIGEST; + return type == pb::STRING || type == pb::HEX || type == pb::BITMAP || type == pb::HLL || type == pb::TDIGEST || type == pb::JSON; } bool is_double() const { @@ -934,7 +944,7 @@ struct ExprValue { struct HashFunction { size_t operator()(const ExprValue& ev) const { - if (ev.type == pb::STRING || ev.type == pb::HEX) { + if (ev.type == pb::STRING || ev.type == pb::HEX || ev.type == pb::JSON) { return ev.hash(); } return ev._u.uint64_val; diff --git a/include/common/type_utils.h b/include/common/type_utils.h index 414f158c..b167d48a 100644 --- a/include/common/type_utils.h +++ b/include/common/type_utils.h @@ -360,6 +360,8 @@ inline uint8_t to_mysql_type(pb::PrimitiveType type) { case pb::BITMAP: case pb::TDIGEST: return MYSQL_TYPE_STRING; + case pb::JSON: + return MYSQL_TYPE_JSON; default: return MYSQL_TYPE_STRING; } @@ -399,6 +401,8 @@ inline std::string to_mysql_type_string(pb::PrimitiveType type) { case pb::BITMAP: case pb::TDIGEST: return "binary"; + case pb::JSON: + return "json"; default: return "text"; } diff --git a/include/exec/exec_node.h b/include/exec/exec_node.h index 606b487a..e4ffddea 100644 --- a/include/exec/exec_node.h +++ b/include/exec/exec_node.h @@ -303,7 +303,22 @@ class ExecNode { bool is_get_keypoint() { return _is_get_keypoint; } + bool set_has_optimized(bool has_optimized) { + _has_optimized = has_optimized; + } + bool has_optimized() { + if (_has_optimized) { + return true; + } + for (auto child : _children) { + if (child->has_optimized()) { + return true; + } + } + return false; + } protected: + bool _has_optimized = false; int64_t _limit = -1; int64_t _num_rows_returned = 0; bool _is_explain = false; diff --git a/include/exec/sort_node.h b/include/exec/sort_node.h index 57128088..4d375f4c 100644 --- a/include/exec/sort_node.h +++ b/include/exec/sort_node.h @@ -51,6 +51,9 @@ class SortNode : public ExecNode { for (auto expr : _slot_order_exprs) { ExprNode::create_pb_expr(sort_node->add_slot_order_exprs(), expr); } + if (_limit != -1) { + pb_node->set_limit(_limit); + } } void transfer_fetcher_pb(pb::FetcherNode* pb_fetcher) { diff --git a/include/expr/expr_node.h b/include/expr/expr_node.h index 53bcf96d..b5854caf 100644 --- a/include/expr/expr_node.h +++ b/include/expr/expr_node.h @@ -134,6 +134,7 @@ class ExprNode { // optimize or node to in node static void or_node_optimize(ExprNode** expr_node); + static bool like_node_optimize(ExprNode** root, std::vector& new_exprs); bool has_same_children(); bool is_vaild_or_optimize_tree(int32_t level, std::unordered_set* tuple_set); static int change_or_node_to_in(ExprNode** expr_node); diff --git a/include/expr/internal_functions.h b/include/expr/internal_functions.h index 95d869c4..c07b9b81 100644 --- a/include/expr/internal_functions.h +++ b/include/expr/internal_functions.h @@ -40,6 +40,7 @@ ExprValue pi(const std::vector& input); ExprValue greatest(const std::vector& input); ExprValue least(const std::vector& input); ExprValue pow(const std::vector& input); +ExprValue bit_count(const std::vector& input); //string functions ExprValue length(const std::vector& input); ExprValue bit_length(const std::vector& input); @@ -66,6 +67,11 @@ ExprValue lpad(const std::vector& input); ExprValue rpad(const std::vector& input); ExprValue instr(const std::vector& input); ExprValue json_extract(const std::vector& input); +ExprValue json_extract1(const std::vector& input); +ExprValue json_type(const std::vector& input); +ExprValue json_array(const std::vector& input); +ExprValue json_object(const std::vector& input); +ExprValue json_valid(const std::vector& input); ExprValue export_set(const std::vector& input); ExprValue to_base64(const std::vector& input); ExprValue from_base64(const std::vector& input); diff --git a/include/logical_plan/select_planner.h b/include/logical_plan/select_planner.h index 05183267..bdd4a375 100644 --- a/include/logical_plan/select_planner.h +++ b/include/logical_plan/select_planner.h @@ -64,6 +64,8 @@ class SelectPlanner : public LogicalPlanner { int parse_limit(); int subquery_rewrite(); + + int minmax_remove(); bool is_full_export(); diff --git a/include/runtime/runtime_state.h b/include/runtime/runtime_state.h index e284c618..f4f3e604 100644 --- a/include/runtime/runtime_state.h +++ b/include/runtime/runtime_state.h @@ -469,6 +469,7 @@ class RuntimeState { int range_count_limit = 0; int64_t _sql_exec_timeout = -1; bool _is_ddl_work = false; + bool must_have_one = false; private: bool _is_inited = false; bool _is_cancelled = false; diff --git a/include/runtime/sorter.h b/include/runtime/sorter.h index 7435343a..0a0b168f 100644 --- a/include/runtime/sorter.h +++ b/include/runtime/sorter.h @@ -26,13 +26,13 @@ class Sorter { public: Sorter(MemRowCompare* comp) : _comp(comp), _idx(0) { } - void add_batch(std::shared_ptr& batch) { + virtual void add_batch(std::shared_ptr& batch) { batch->reset(); _min_heap.push_back(batch); } - void sort(); - void merge_sort(); - int get_next(RowBatch* batch, bool* eos); + virtual void sort(); + virtual void merge_sort(); + virtual int get_next(RowBatch* batch, bool* eos); size_t batch_size() { return _min_heap.size(); @@ -40,10 +40,11 @@ class Sorter { private: void multi_sort(); void make_heap(); - void shiftdown(size_t index); + virtual void shiftdown(size_t index); -private: +protected: MemRowCompare* _comp; +private: std::vector> _min_heap; size_t _idx; }; diff --git a/include/runtime/topn_sorter.h b/include/runtime/topn_sorter.h new file mode 100644 index 00000000..e8fdc4d9 --- /dev/null +++ b/include/runtime/topn_sorter.h @@ -0,0 +1,52 @@ +// Copyright (c) 2018-present Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "common.h" +#include "row_batch.h" +#include "mem_row_compare.h" +#include "sorter.h" + +namespace baikaldb { +//对每个batch并行的做sort后,再用heap做归并 + +struct TopNHeapItem { + std::unique_ptr row; + int64_t idx; +}; + +class TopNSorter : public Sorter { +public: + TopNSorter(MemRowCompare* comp, int64_t limit) : Sorter(comp), _limit(limit) { + } + virtual void add_batch(std::shared_ptr& batch); + virtual void sort(); + virtual void merge_sort(){} + virtual int get_next(RowBatch* batch, bool* eos); +private: + virtual void shiftdown(size_t index); + virtual void shiftup(size_t index); + +private: + std::vector _mem_row_heap; + int64_t _limit = -1; + int64_t _current_count = 0; + int64_t _current_idx = 0; +}; +} + +/* vim: set ts=4 sw=4 sts=4 tw=100 */ diff --git a/include/session/user_info.h b/include/session/user_info.h index f08233cb..e0317af0 100644 --- a/include/session/user_info.h +++ b/include/session/user_info.h @@ -81,15 +81,7 @@ struct UserInfo { ~UserInfo() {} - bool is_exceed_quota() { - if (query_cost.get_time() > 1000000) { - query_cost.reset(); - query_count = 0; - return false; - } - return query_count++ > query_quota; - } - + bool is_exceed_quota(); bool connection_inc() { bool res = false; std::lock_guard guard(conn_mutex); diff --git a/include/sqlparser/sql_lex.l b/include/sqlparser/sql_lex.l index 7ca3a386..97ce4758 100644 --- a/include/sqlparser/sql_lex.l +++ b/include/sqlparser/sql_lex.l @@ -445,6 +445,8 @@ VAR_SAMP { un_reserved_keyword(yylval, yyscanner, parser); return VAR_SAMP; } \|\| { return OR; } \<\< { return LS_OP; } \>\> { return RS_OP; } +\-\> { return JS_OP; } +\-\>\> { return JS_OP1; } [0-9]+ { //integer diff --git a/include/sqlparser/sql_parse.y b/include/sqlparser/sql_parse.y index 83a2e00c..c7429a9e 100644 --- a/include/sqlparser/sql_parse.y +++ b/include/sqlparser/sql_parse.y @@ -481,7 +481,7 @@ extern int sql_error(YYLTYPE* yylloc, yyscan_t yyscanner, SqlParser* parser, con VAR_SAMP USER_AGG -%token EQ_OP ASSIGN_OP MOD_OP GE_OP GT_OP LE_OP LT_OP NE_OP AND_OP OR_OP NOT_OP LS_OP RS_OP CHINESE_DOT +%token EQ_OP ASSIGN_OP MOD_OP GE_OP GT_OP LE_OP LT_OP NE_OP AND_OP OR_OP NOT_OP LS_OP RS_OP CHINESE_DOT JS_OP JS_OP1 %token IDENT %token STRING_LIT INTEGER_LIT DECIMAL_LIT PLACE_HOLDER_LIT @@ -760,7 +760,8 @@ extern int sql_error(YYLTYPE* yylloc, yyscan_t yyscanner, SqlParser* parser, con %left EQ_OP NE_OP GE_OP GT_OP LE_OP LT_OP IS LIKE IN %left '|' %left '&' -%left LS_OP RS_OP +%left JS_OP1 +%left LS_OP RS_OP JS_OP %left '+' '-' %left '*' '/' MOD_OP MOD %left '^' @@ -1866,6 +1867,38 @@ SelectField: select_field->as_name = $5; $$ = select_field; } + | ColumnName JS_OP STRING_LIT { + SelectField* select_field = new_node(SelectField); + FuncExpr* fun = new_node(FuncExpr); + fun->fn_name = "json_extract1"; + fun->children.push_back($1, parser->arena); + fun->children.push_back($3, parser->arena); + select_field->expr = fun; + parser::String t1, t2; + t1 = "->\""; + t2 = "\""; + select_field->org_name = ((ColumnName*) $1)->name; + select_field->org_name.append("->\"", parser->arena); + select_field->org_name.append(((LiteralExpr*)$3)->_u.str_val.c_str(), parser->arena); + select_field->org_name.append("\"", parser->arena); + $$ = select_field; + } + | ColumnName JS_OP1 STRING_LIT { + SelectField* select_field = new_node(SelectField); + FuncExpr* fun = new_node(FuncExpr); + fun->fn_name = "json_extract"; + fun->children.push_back($1, parser->arena); + fun->children.push_back($3, parser->arena); + select_field->expr = fun; + parser::String t1, t2; + t1 = "->\""; + t2 = "\""; + select_field->org_name = ((ColumnName*) $1)->name; + select_field->org_name.append("->\"", parser->arena); + select_field->org_name.append(((LiteralExpr*)$3)->_u.str_val.c_str(), parser->arena); + select_field->org_name.append("\"", parser->arena); + $$ = select_field; + } ; FieldAsNameOpt: /* EMPTY */ diff --git a/proto/common.proto b/proto/common.proto index c269d97b..962aab1e 100755 --- a/proto/common.proto +++ b/proto/common.proto @@ -64,7 +64,8 @@ enum PrimitiveType { HEX = 20; BITMAP = 21; TDIGEST = 22; - MAXVALUE_TYPE = 23; + JSON = 23; + MAXVALUE_TYPE = 24; }; enum SchemaType { @@ -146,4 +147,4 @@ message ExprValue { optional float float_val = 7; optional double double_val = 8; optional bytes string_val = 9; -}; \ No newline at end of file +}; diff --git a/src/common/common.cpp b/src/common/common.cpp index 7a467e98..51dd89a7 100644 --- a/src/common/common.cpp +++ b/src/common/common.cpp @@ -479,6 +479,7 @@ int primitive_to_proto_type(pb::PrimitiveType type) { { pb::BOOL, FieldDescriptorProto::TYPE_BOOL}, { pb::BITMAP, FieldDescriptorProto::TYPE_BYTES}, { pb::TDIGEST, FieldDescriptorProto::TYPE_BYTES}, + { pb::JSON, FieldDescriptorProto::TYPE_BYTES}, { pb::NULL_TYPE, FieldDescriptorProto::TYPE_BOOL} }; if (_mysql_pb_type_mapping.count(type) == 0) { diff --git a/src/common/expr_value.cpp b/src/common/expr_value.cpp index 36c916c6..b4a27ebb 100644 --- a/src/common/expr_value.cpp +++ b/src/common/expr_value.cpp @@ -62,39 +62,40 @@ SerializeStatus ExprValue::serialize_to_mysql_text_packet(char* buf, size_t size } case pb::FLOAT: { size_t body_len = 0; - char tmp_buf[100] = {0}; - if (float_precision_len == -1) { - body_len = snprintf(tmp_buf, sizeof(tmp_buf), "%.6g", _u.float_val); - } else { - std::string format= "%." + std::to_string(float_precision_len) + "f"; - body_len = snprintf(tmp_buf, sizeof(tmp_buf), format.c_str(), _u.float_val); + std::ostringstream oss; + if (float_precision_len != -1) { + oss << std::fixed << std::setprecision(float_precision_len); } - + oss << _u.float_val; + std::string tmp_str = oss.str(); + body_len = tmp_str.length(); len = body_len + 1; if (len > size) { return STMPS_NEED_RESIZE; } // byte_array_append_length_coded_binary(body_len < 251LL) buf[0] = (uint8_t)(body_len & 0xff); - memcpy(buf + 1, tmp_buf, body_len); + memcpy(buf + 1, tmp_str.c_str(), body_len); return STMPS_SUCCESS; } case pb::DOUBLE: { size_t body_len = 0; - char tmp_buf[100] = {0}; - if (float_precision_len == -1) { - body_len = snprintf(tmp_buf, sizeof(tmp_buf), "%.12g", _u.double_val); + std::ostringstream oss; + if (float_precision_len != -1) { + oss << std::fixed << std::setprecision(float_precision_len); } else { - std::string format= "%." + std::to_string(float_precision_len) + "lf"; - body_len = snprintf(tmp_buf, sizeof(tmp_buf), format.c_str(), _u.double_val); + oss << std::setprecision(12); } + oss << _u.double_val; + std::string tmp_str = oss.str(); + body_len = tmp_str.length(); len = body_len + 1; if (len > size) { return STMPS_NEED_RESIZE; } // byte_array_append_length_coded_binary(body_len < 251LL) buf[0] = (uint8_t)(body_len & 0xff); - memcpy(buf + 1, tmp_buf, body_len); + memcpy(buf + 1, tmp_str.c_str(), body_len); return STMPS_SUCCESS; } case pb::HLL: { diff --git a/src/common/schema_factory.cpp b/src/common/schema_factory.cpp index 7100f705..1926433d 100644 --- a/src/common/schema_factory.cpp +++ b/src/common/schema_factory.cpp @@ -507,7 +507,7 @@ int SchemaFactory::update_table_internal(SchemaMapping& background, const pb::Sc field_info.default_expr_value.cast_to(field_info.type); } if (field_info.type == pb::STRING || field_info.type == pb::HLL - || field_info.type == pb::BITMAP || field_info.type == pb::TDIGEST) { + || field_info.type == pb::BITMAP || field_info.type == pb::TDIGEST || field_info.type == pb::JSON) { field_info.size = -1; } else { field_info.size = get_num_size(field_info.type); diff --git a/src/engine/transaction.cpp b/src/engine/transaction.cpp index 2582d4f9..632dfe5b 100644 --- a/src/engine/transaction.cpp +++ b/src/engine/transaction.cpp @@ -1444,7 +1444,9 @@ int Transaction::put_primary_columns(const TableKey& primary_key, SmartRecord re return -1; } if (_is_separate) { - add_kvop_put(key.data(), value, _write_ttl_timestamp_us, false); + // add_kvop_put(key.data(), value, _write_ttl_timestamp_us, false); + // 列不加ttl + add_kvop_put(key.data(), value, -1, false); } } return 0; diff --git a/src/exec/agg_node.cpp b/src/exec/agg_node.cpp index c4f01c54..6a71bd20 100644 --- a/src/exec/agg_node.cpp +++ b/src/exec/agg_node.cpp @@ -122,6 +122,13 @@ int AggNode::open(RuntimeState* state) { } _mem_row_desc = state->mem_row_desc(); + bool use_limit = false; + if (_limit > 0 && _agg_fn_calls.size() == 0) { + // case: select distinct f from test limit 100; + // 没有聚合函数时,可以使用limit + use_limit = true; + } + TimeCost cost; int64_t agg_time = 0; int64_t scan_time = 0; @@ -154,6 +161,9 @@ int AggNode::open(RuntimeState* state) { DB_WARNING_STATE(state, "memory limit exceeded"); return -1; } + if (use_limit && _hash_map.size() >= _limit) { + eos = true; + } // 对于用order by分组的特殊优化 //if (_agg_tuple_id == -1 && _limit != -1 && (int64_t)_hash_map.size() >= _limit) { // break; @@ -168,7 +178,7 @@ int AggNode::open(RuntimeState* state) { ExecNode* packet = get_parent_node(pb::PACKET_NODE); // baikaldb才有packet_node;只在baikaldb上产生数据 // TODB:join和子查询后续如果要完全推到store运行得注意 - if (packet != nullptr) { + if (packet != nullptr || _is_merger) { std::unique_ptr row = _mem_row_desc->fetch_mem_row(); uint8_t null_flag = 0; MutTableKey key; diff --git a/src/exec/fetcher_store.cpp b/src/exec/fetcher_store.cpp index 65e26e16..e17847bc 100755 --- a/src/exec/fetcher_store.cpp +++ b/src/exec/fetcher_store.cpp @@ -1087,16 +1087,16 @@ ErrorType FetcherStore::process_binlog_start(RuntimeState* state, pb::OpType op_ need_send_rollback = false; return; } - write_binlog_param.txn_id = state->txn_id; - write_binlog_param.log_id = log_id; - write_binlog_param.primary_region_id = client_conn->primary_region_id; - write_binlog_param.global_conn_id = client_conn->get_global_conn_id(); - write_binlog_param.username = client_conn->user_info->username; - write_binlog_param.ip = client_conn->ip; - write_binlog_param.client_conn = client_conn; - write_binlog_param.fetcher_store = this; binlog_ctx->set_start_ts(timestamp); } + write_binlog_param.txn_id = state->txn_id; + write_binlog_param.log_id = log_id; + write_binlog_param.primary_region_id = client_conn->primary_region_id; + write_binlog_param.global_conn_id = client_conn->get_global_conn_id(); + write_binlog_param.username = client_conn->user_info->username; + write_binlog_param.ip = client_conn->ip; + write_binlog_param.client_conn = client_conn; + write_binlog_param.fetcher_store = this; write_binlog_param.op_type = op_type; auto ret = binlog_ctx->write_binlog(&write_binlog_param); if (ret != E_OK) { diff --git a/src/exec/filter_node.cpp b/src/exec/filter_node.cpp index afccc865..4138d0ce 100644 --- a/src/exec/filter_node.cpp +++ b/src/exec/filter_node.cpp @@ -342,8 +342,8 @@ int FilterNode::expr_optimize(QueryContext* ctx) { DB_WARNING("ExecNode::optimize fail, ret:%d", ret); return ret; } - // sign => pred - std::map pred_map; + + std::vector like2range; for (auto& expr : _conjuncts) { //类型推导 ret = expr->expr_optimize(); @@ -352,7 +352,18 @@ int FilterNode::expr_optimize(QueryContext* ctx) { return ret; } ExprNode::or_node_optimize(&expr); - + bool t = ExprNode::like_node_optimize(&expr, like2range); + if (t) { + _has_optimized = true; + } + } + for (auto expr : like2range) { + _conjuncts.push_back(expr); + } + + // sign => pred + std::map pred_map; + for (auto& expr : _conjuncts) { //非bool型表达式判断 if (expr->col_type() != pb::BOOL) { ExprNode::_s_non_boolean_sql_cnts << 1; @@ -448,7 +459,7 @@ int FilterNode::expr_optimize(QueryContext* ctx) { bool all_const = true; for (uint32_t i = 1; i < expr->children_size(); i++) { // place holder被替换会导致下一次exec参数对不上 - if (!expr->children(i)->is_constant() || expr->children(i)->has_place_holder()) { + if (!expr->children(i)->is_constant()) { all_const = false; break; } @@ -497,6 +508,7 @@ int FilterNode::expr_optimize(QueryContext* ctx) { if (cut_preds.count(expr) == 1) { ExprNode::destroy_tree(expr); iter = _conjuncts.erase(iter); + _has_optimized = true; continue; } if (expr->is_constant()) { diff --git a/src/exec/select_manager_node.cpp b/src/exec/select_manager_node.cpp index 9dd39f7f..11ef1000 100755 --- a/src/exec/select_manager_node.cpp +++ b/src/exec/select_manager_node.cpp @@ -140,6 +140,14 @@ int SelectManagerNode::get_next(RuntimeState* state, RowBatch* batch, bool* eos) *eos = true; _num_rows_returned = _limit; return 0; + } else if (*eos == true) { + if (state->must_have_one && _num_rows_returned == 0) { + // 生成null返回 + std::unique_ptr row = state->mem_row_desc()->fetch_mem_row(); + batch->move_row(std::move(row)); + _num_rows_returned = 1; + return 0; + } } return 0; } diff --git a/src/exec/sort_node.cpp b/src/exec/sort_node.cpp index 322ff3a0..454a9624 100644 --- a/src/exec/sort_node.cpp +++ b/src/exec/sort_node.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "sort_node.h" +#include "topn_sorter.h" #include "runtime_state.h" #include "query_context.h" @@ -157,7 +158,11 @@ int SortNode::open(RuntimeState* state) { _mem_row_desc = state->mem_row_desc(); _mem_row_compare = std::make_shared( _slot_order_exprs, _is_asc, _is_null_first); - _sorter = std::make_shared(_mem_row_compare.get()); + if (_limit == -1) { + _sorter = std::make_shared(_mem_row_compare.get()); + } else { + _sorter = std::make_shared(_mem_row_compare.get(), _limit); + } bool eos = false; int count = 0; diff --git a/src/expr/expr_node.cpp b/src/expr/expr_node.cpp index 9ecdbc91..da291d44 100644 --- a/src/expr/expr_node.cpp +++ b/src/expr/expr_node.cpp @@ -341,6 +341,96 @@ void ExprNode::or_node_optimize(ExprNode** root) { return; } + +// 返回true代表进行了转换,需要标记 +bool ExprNode::like_node_optimize(ExprNode** root, std::vector& new_exprs) { + if (*root == nullptr) { + return false; + } + if ((*root)->node_type() != pb::LIKE_PREDICATE) { + return false; + } + auto expr = *root; + SlotRef* slot = (SlotRef*)expr->children(0); + if (slot->col_type() != pb::STRING) { + return false; + } + if (expr->children(1)->is_constant()) { + expr->children(1)->open(); + } else { + return false; + } + bool is_eq = false; + bool is_prefix = false; + ExprValue prefix_value(pb::STRING); + static_cast(expr)->hit_index(&is_eq, &is_prefix, &(prefix_value.str_val)); + std::string old_val = expr->children(1)->get_value(nullptr).get_string(); + if (!is_prefix || old_val.length() > prefix_value.str_val.length() + 1) { + return false; + } + if (is_eq) { + ScalarFnCall * eqexpr = new ScalarFnCall(); + SlotRef *sloteq = slot->clone(); + Literal *eqval = new Literal(prefix_value); + pb::ExprNode node; + node.set_node_type(pb::FUNCTION_CALL); + node.set_col_type(pb::BOOL); + pb::Function* func = node.mutable_fn(); + func->set_name("eq_string_string"); + func->set_fn_op(parser::FT_EQ); + eqexpr->init(node); + eqexpr->set_is_constant(false); + eqexpr->add_child(sloteq); + eqexpr->add_child(eqval); + *root = eqexpr; + ExprNode::destroy_tree(expr); + return true; + } else if (is_prefix) { + ScalarFnCall *geexpr = new ScalarFnCall(); + SlotRef *slotge = slot->clone(); + Literal *geval = new Literal(prefix_value); + pb::ExprNode node; + node.set_node_type(pb::FUNCTION_CALL); + node.set_col_type(pb::BOOL); + pb::Function* func = node.mutable_fn(); + func->set_name("ge_string_string"); + func->set_fn_op(parser::FT_GE); + geexpr->init(node); + geexpr->set_is_constant(false); + geexpr->add_child(slotge); + geexpr->add_child(geval); + *root = geexpr; + + ScalarFnCall *ltexpr = new ScalarFnCall(); + SlotRef* ltslot = slot->clone(); + ExprValue end_val = prefix_value; + int i = end_val.str_val.length() - 1; + for (; i >= 0; i --) { + uint8_t c = end_val.str_val[i]; + if (c == 255) { + continue; + } + end_val.str_val[i] = char(c + 1); + break; + } + end_val.str_val = end_val.str_val.substr(0, i + 1); + Literal *ltval = new Literal(end_val); + pb::ExprNode ltnode; + ltnode.set_node_type(pb::FUNCTION_CALL); + ltnode.set_col_type(pb::BOOL); + func = ltnode.mutable_fn(); + func->set_name("lt_string_string"); + func->set_fn_op(parser::FT_LT); + ltexpr->init(ltnode); + ltexpr->set_is_constant(false); + ltexpr->add_child(ltslot); + ltexpr->add_child(ltval); + new_exprs.push_back(ltexpr); + ExprNode::destroy_tree(expr); + return true; + } +} + int ExprNode::create_expr_node(const pb::ExprNode& node, ExprNode** expr_node) { switch (node.node_type()) { case pb::SLOT_REF: diff --git a/src/expr/fn_manager.cpp b/src/expr/fn_manager.cpp index b9bd0e61..e0280eff 100644 --- a/src/expr/fn_manager.cpp +++ b/src/expr/fn_manager.cpp @@ -124,6 +124,7 @@ void FunctionManager::register_operators() { register_object_ret("least", least, pb::DOUBLE); register_object_ret("ceil", ceil, pb::INT64); register_object_ret("ceiling", ceil, pb::INT64); + register_object_ret("bit_count", bit_count, pb::INT64); // str funcs register_object_ret("length", length, pb::INT64); @@ -154,6 +155,11 @@ void FunctionManager::register_operators() { register_object_ret("rpad", rpad, pb::STRING); register_object_ret("instr", instr, pb::INT32); register_object_ret("json_extract", json_extract, pb::STRING); + register_object_ret("json_extract1", json_extract1, pb::STRING); + register_object_ret("json_type", json_type, pb::STRING); + register_object_ret("json_array", json_array, pb::STRING); + register_object_ret("json_object", json_object, pb::STRING); + register_object_ret("json_valid", json_valid, pb::BOOL); register_object_ret("export_set", export_set, pb::STRING); register_object_ret("to_base64", to_base64, pb::STRING); register_object_ret("from_base64", from_base64, pb::STRING); diff --git a/src/expr/internal_functions.cpp b/src/expr/internal_functions.cpp index 9ed3b151..05fb2ddb 100644 --- a/src/expr/internal_functions.cpp +++ b/src/expr/internal_functions.cpp @@ -15,6 +15,7 @@ #include "internal_functions.h" #include #include +#include #include #include "hll_common.h" #include "datetime.h" @@ -328,6 +329,21 @@ ExprValue bit_length(const std::vector& input) { tmp._u.uint32_val = input[0].get_string().size() * 8; return tmp; } +ExprValue bit_count(const std::vector& input) { + if (input.size() != 1 || input[0].is_null()) { + return ExprValue::Null(); + } + ExprValue tmp = input[0]; + tmp.cast_to(pb::UINT64); + ExprValue res(pb::INT64); + while (tmp._u.uint64_val) { + if (tmp._u.uint64_val & 1) { + res._u.int64_val += 1; + } + tmp._u.uint64_val >>= 1; + } + return res; +} ExprValue lower(const std::vector& input) { if (input.size() == 0 || input[0].is_null()) { @@ -833,6 +849,158 @@ ExprValue json_extract(const std::vector& input) { return tmp; } +ExprValue json_extract1(const std::vector& input) { + if (input.size() != 2) { + return ExprValue::Null(); + } + + for (auto s : input) { + if (s.is_null()) { + return ExprValue::Null(); + } + } + std::string json_str = input[0].get_string(); + std::string path = input[1].get_string(); + if (path.length() > 0 && path[0] == '$') { + path.erase(path.begin()); + } else { + return ExprValue::Null(); + } + std::replace(path.begin(), path.end(), '.', '/'); + std::replace(path.begin(), path.end(), '[', '/'); + path.erase(std::remove(path.begin(), path.end(), ']'), path.end()); + + rapidjson::Document doc; + try { + doc.Parse<0>(json_str.c_str()); + if (doc.HasParseError()) { + rapidjson::ParseErrorCode code = doc.GetParseError(); + DB_WARNING("parse json_str error [code:%d][%s]", code, json_str.c_str()); + return ExprValue::Null(); + } + + } catch (...) { + DB_WARNING("parse json_str error [%s]", json_str.c_str()); + return ExprValue::Null(); + } + rapidjson::Pointer pointer(path.c_str()); + if (!pointer.IsValid()) { + DB_WARNING("invalid path: [%s]", path.c_str()); + return ExprValue::Null(); + } + + const rapidjson::Value *pValue = rapidjson::GetValueByPointer(doc, pointer); + if (pValue == nullptr) { + DB_WARNING("the path: [%s] does not exist in doc [%s]", path.c_str(), json_str.c_str()); + return ExprValue::Null(); + } + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + // TODO type on fly + ExprValue tmp(pb::STRING); + /* + if (pValue->IsString()) { + tmp.str_val = pValue->GetString(); + } else if (pValue->IsInt()) { + tmp.str_val = std::to_string(pValue->GetInt()); + } else if (pValue->IsInt64()) { + tmp.str_val = std::to_string(pValue->GetInt64()); + } else if (pValue->IsUint()) { + tmp.str_val = std::to_string(pValue->GetUint()); + } else if (pValue->IsUint64()) { + tmp.str_val = std::to_string(pValue->GetUint64()); + } else if (pValue->IsDouble()) { + tmp.str_val = std::to_string(pValue->GetDouble()); + } else if (pValue->IsFloat()) { + tmp.str_val = std::to_string(pValue->GetFloat()); + } else if (pValue->IsBool()) { + tmp.str_val = std::to_string(pValue->GetBool()); + } + */ + pValue->Accept(writer); + tmp.str_val = buffer.GetString(); + return tmp; +} + +ExprValue json_type(const std::vector& input) { + if (input.size() != 1) { + return ExprValue::Null(); + } + ExprValue res(pb::STRING); + if (input[0].is_int()) { + res.str_val = "INTEGER"; + } else if (input[0].is_double()) { + res.str_val = "DOUBLE"; + } else if (input[0].is_bool()) { + res.str_val = "BOOLEAN"; + } else if (input[0].is_null()) { + res.str_val = "NULL"; + } else if (input[0].is_string()) { + rapidjson::Document root; + root.Parse<0>(input[0].str_val.c_str()); + if (root.IsObject()) { + res.str_val = "OBJECT"; + } else if (root.IsArray()) { + res.str_val = "ARRAY"; + } else { + res.str_val = "STRING"; + } + } else { + return ExprValue::Null(); + } + return res; +} + +ExprValue json_array(const std::vector& input) { + if (input.size() < 1) { + return ExprValue::Null(); + } + rapidjson::Document list; + list.SetArray(); + for (size_t i = 0; i < input.size(); i ++) { + list.PushBack(rapidjson::StringRef(input[i].get_string().c_str()), list.GetAllocator()); + } + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + list.Accept(writer); + ExprValue res(pb::STRING); + res.str_val = buffer.GetString(); + return res; +} + +ExprValue json_object(const std::vector& input) { + if (input.size() < 1 || input.size() & 1) { + return ExprValue::Null(); + } + rapidjson::Document obj; + obj.SetObject(); + // TODO 相同的key会重复 + for (size_t i = 0; i < input.size() ; i += 2) { + obj.AddMember(rapidjson::StringRef(input[i].get_string().c_str()), rapidjson::StringRef(input[i + 1].get_string().c_str()), obj.GetAllocator()); + } + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + obj.Accept(writer); + ExprValue res(pb::STRING); + res.str_val = buffer.GetString(); + return res; +} + +ExprValue json_valid(const std::vector& input) { + if (input.size() != 1) { + return ExprValue::Null(); + } + if (input[0].type != pb::JSON && input[0].type != pb::STRING) { + return ExprValue::Null(); + } + rapidjson::Document obj; + obj.Parse<0>(input[0].str_val.c_str()); + if (obj.HasParseError()) { + return ExprValue::False(); + } + return ExprValue::True(); +} + ExprValue substring_index(const std::vector& input) { if (input.size() != 3) { return ExprValue::Null(); diff --git a/src/logical_plan/ddl_planner.cpp b/src/logical_plan/ddl_planner.cpp index bfb5b1c4..643755e8 100644 --- a/src/logical_plan/ddl_planner.cpp +++ b/src/logical_plan/ddl_planner.cpp @@ -1899,6 +1899,12 @@ pb::PrimitiveType DDLPlanner::to_baikal_type(parser::FieldType* field_type) { case parser::MYSQL_TYPE_TDIGEST: { return pb::TDIGEST; } break; + case parser::MYSQL_TYPE_BIT: { + return pb::INT64; + } break; + case parser::MYSQL_TYPE_JSON: { + return pb::JSON; + } break; default : { DB_WARNING("unsupported item type: %d", field_type->type); return pb::INVALID_TYPE; diff --git a/src/logical_plan/logical_planner.cpp b/src/logical_plan/logical_planner.cpp index bf4785c6..ed776853 100644 --- a/src/logical_plan/logical_planner.cpp +++ b/src/logical_plan/logical_planner.cpp @@ -2814,6 +2814,7 @@ int LogicalPlanner::create_sort_node() { sort->add_is_null_first(_order_ascs[idx]); } sort->set_tuple_id(_order_tuple_id); + return 0; } int LogicalPlanner::create_join_and_scan_nodes(JoinMemTmp* join_root, ApplyMemTmp* apply_root) { @@ -2926,7 +2927,7 @@ int LogicalPlanner::create_scan_nodes() { } void LogicalPlanner::set_dml_txn_state(int64_t table_id) { - if (_ctx->is_explain || table_id == -1) { + if (_ctx->is_explain) { return; } auto client = _ctx->client_conn; diff --git a/src/logical_plan/prepare_planner.cpp b/src/logical_plan/prepare_planner.cpp index 5f21794c..1141c013 100644 --- a/src/logical_plan/prepare_planner.cpp +++ b/src/logical_plan/prepare_planner.cpp @@ -173,6 +173,7 @@ int PreparePlanner::stmt_prepare(const std::string& stmt_name, const std::string prepare_ctx->client_conn = client; prepare_ctx->get_runtime_state()->set_client_conn(client); prepare_ctx->sql = stmt_sql; + prepare_ctx->is_full_export = false; std::unique_ptr planner; switch (prepare_ctx->stmt_type) { @@ -248,7 +249,7 @@ int PreparePlanner::stmt_execute(const std::string& stmt_name, std::vectorstat_info.table = prepare_ctx->stat_info.table; _ctx->stat_info.sample_sql << prepare_ctx->stat_info.sample_sql.str(); _ctx->stat_info.sign = prepare_ctx->stat_info.sign; - _ctx->is_full_export = prepare_ctx->is_full_export; + _ctx->is_full_export = false; _ctx->debug_region_id = prepare_ctx->debug_region_id; _ctx->execute_global_flow = prepare_ctx->execute_global_flow; if (params.size() != prepare_ctx->placeholders.size()) { @@ -266,9 +267,11 @@ int PreparePlanner::stmt_execute(const std::string& stmt_name, std::vectorplaceholders; // TODO dml的plan复用 - if (!prepare_ctx->is_select || prepare_ctx->sub_query_plans.size() > 0) { + if (!prepare_ctx->is_select || prepare_ctx->sub_query_plans.size() > 0 || (prepare_ctx->root != nullptr && prepare_ctx->root->has_optimized())) { // enable_2pc=true or table has global index need generate txn_id - set_dml_txn_state(prepare_ctx->prepared_table_id); + if (!prepare_ctx->is_select && prepare_ctx->prepared_table_id != -1) { + set_dml_txn_state(prepare_ctx->prepared_table_id); + } _ctx->plan.CopyFrom(prepare_ctx->plan); int ret = set_dml_local_index_binlog(prepare_ctx->prepared_table_id); if (ret < 0) { diff --git a/src/logical_plan/select_planner.cpp b/src/logical_plan/select_planner.cpp index 0053382c..e02a65fc 100644 --- a/src/logical_plan/select_planner.cpp +++ b/src/logical_plan/select_planner.cpp @@ -96,6 +96,10 @@ int SelectPlanner::plan() { return -1; } + if (0 != minmax_remove()) { + return -1; + } + if (_ctx->is_base_subscribe) { if (0 != get_base_subscribe_scan_ref_slot()) { return -1; @@ -150,6 +154,9 @@ bool SelectPlanner::is_full_export() { if (_ctx->explain_type != EXPLAIN_NULL) { return false; } + if (_ctx->is_prepared) { + return false; + } if (_ctx->debug_region_id != -1) { return false; } @@ -345,6 +352,62 @@ void SelectPlanner::get_slot_column_mapping() { } } +int SelectPlanner::minmax_remove() { + if (!_distinct_agg_funcs.empty() || ! _group_exprs.empty()) { + return 0; + } + if (_select_exprs.size() != 1 || _select_exprs[0].nodes(0).node_type() != pb::AGG_EXPR) { + return 0; + } + if (_group_slots.size() != 0 || _order_exprs.size() != 0 || _group_exprs.size() != 0) { + return 0; + } + pb::Expr select_expr = _select_exprs[0]; + if (select_expr.nodes_size() != 2) { + return 0; + } + std::string fn_name = select_expr.nodes(0).fn().name(); + if (fn_name != "max" && fn_name != "min") { + return 0; + } + pb::ExprNode slot = select_expr.nodes(1); + if (slot.node_type() != pb::SLOT_REF) { + return 0; + } + _select_exprs.clear(); + _group_exprs.clear(); + _agg_funcs.clear(); + pb::Expr new_select; + new_select.set_database(select_expr.database()); + new_select.set_table(select_expr.table()); + auto add_node = new_select.add_nodes(); + *add_node = slot; + _select_exprs.push_back(new_select); + pb::Expr order_expr; + order_expr.set_database(select_expr.database()); + order_expr.set_table(select_expr.table()); + add_node = order_expr.add_nodes(); + *add_node = slot; + _order_exprs.push_back(order_expr); + if (fn_name == "max") { + _order_ascs.push_back(false); + } else { + _order_ascs.push_back(true); + } + _ctx->get_runtime_state()->must_have_one = true; + _limit_offset.clear_nodes(); + auto offset = _limit_offset.add_nodes(); + offset->mutable_derive_node()->set_int_val(0); + offset->set_node_type(pb::INT_LITERAL); + offset->set_col_type(pb::INT64); + _limit_count.clear_nodes(); + auto limit = _limit_count.add_nodes(); + limit->mutable_derive_node()->set_int_val(1); + limit->set_node_type(pb::INT_LITERAL); + limit->set_col_type(pb::INT64); + return 0; +} + int SelectPlanner::subquery_rewrite() { if (!_ctx->expr_params.is_expr_subquery) { return 0; @@ -430,7 +493,8 @@ void SelectPlanner::create_dual_scan_node() { } int SelectPlanner::create_limit_node() { - if (_select->limit == nullptr) { +// if (_select->limit == nullptr && + if (_limit_offset.nodes_size() == 0) { return 0; } pb::PlanNode* limit_node = _ctx->add_plan_node(); @@ -456,30 +520,16 @@ int SelectPlanner::create_limit_node() { int SelectPlanner::create_agg_node() { if (_select->select_opt != nullptr && _select->select_opt->distinct == true) { - // select distinct ()xxx, xxx from xx.xx (no group by) - if (!_agg_funcs.empty() || !_distinct_agg_funcs.empty() || !_group_exprs.empty()) { + // select distinct ()xxx, xxx from xx.xx + if (_agg_funcs.empty() && _distinct_agg_funcs.empty() && _group_exprs.empty()) { + //如果没有agg和group by, 将select列加入到group by中 + for (uint32_t idx = 0; idx < _select_exprs.size(); ++idx) { + _group_exprs.push_back(_select_exprs[idx]); + } + } else if(!_group_exprs.empty()) { DB_WARNING("distinct query doesnot support group by"); return -1; } - pb::PlanNode* agg_node = _ctx->add_plan_node(); - agg_node->set_node_type(pb::AGG_NODE); - agg_node->set_limit(-1); - agg_node->set_is_explain(_ctx->is_explain); - agg_node->set_num_children(1); //TODO - pb::DerivePlanNode* derive = agg_node->mutable_derive_node(); - pb::AggNode* agg = derive->mutable_agg_node(); - - for (uint32_t idx = 0; idx < _select_exprs.size(); ++idx) { - pb::Expr* expr = agg->add_group_exprs(); - expr->CopyFrom(_select_exprs[idx]); -// if (_select_exprs[idx].nodes_size() != 1) { -// DB_WARNING("invalid distinct expr"); -// return -1; -// } -// expr->add_nodes()->CopyFrom(_select_exprs[idx].nodes(0)); - } - agg->set_agg_tuple_id(-1); - return 0; } if (_agg_funcs.empty() && _distinct_agg_funcs.empty() && _group_exprs.empty()) { return 0; diff --git a/src/physical_plan/limit_calc.cpp b/src/physical_plan/limit_calc.cpp index 5e1f6abb..e59b5e7a 100644 --- a/src/physical_plan/limit_calc.cpp +++ b/src/physical_plan/limit_calc.cpp @@ -15,6 +15,7 @@ #include "limit_calc.h" #include "join_node.h" #include "filter_node.h" +#include "agg_node.h" namespace baikaldb { int LimitCalc::analyze(QueryContext* ctx) { @@ -41,9 +42,16 @@ void LimitCalc::_analyze_limit(QueryContext* ctx, ExecNode* node, int64_t limit) return; } } + // case: select distinct f from test limit 10; + // 没有agg_fn时, 在agg_node的open阶段可以使用limit。 + case pb::MERGE_AGG_NODE: + if (static_cast(node)->mutable_agg_fn_calls()->empty()) { + break; + } else { + return; + } case pb::HAVING_FILTER_NODE: case pb::SORT_NODE: - case pb::MERGE_AGG_NODE: case pb::AGG_NODE: return; default: diff --git a/src/protocol/network_server.cpp b/src/protocol/network_server.cpp index cc131dcd..b398d67f 100644 --- a/src/protocol/network_server.cpp +++ b/src/protocol/network_server.cpp @@ -832,6 +832,7 @@ static void on_health_check_done(pb::StoreRes* response, brpc::Controller* cntl, std::unique_ptr response_guard(response); std::unique_ptr cntl_guard(cntl); pb::Status new_status = pb::NORMAL; + old_status = SchemaFactory::get_instance()->get_instance_status(addr).status; if (cntl->Failed()) { if (cntl->ErrorCode() == brpc::ERPCTIMEDOUT || cntl->ErrorCode() == ETIMEDOUT) { diff --git a/src/protocol/state_machine.cpp b/src/protocol/state_machine.cpp index 1c9ac13c..292e9b26 100644 --- a/src/protocol/state_machine.cpp +++ b/src/protocol/state_machine.cpp @@ -24,7 +24,6 @@ namespace baikaldb { DEFINE_int32(max_connections_per_user, 4000, "default user max connections"); -DEFINE_int32(query_quota_per_user, 3000, "default user query quota by 1 second"); DEFINE_string(log_plat_name, "test", "plat name for print log, distinguish monitor"); DEFINE_int64(baikal_max_allowed_packet, 268435456LL, "The largest possible packet : 256M"); DEFINE_int32(query_cache_timeout_s, 10, "query cache timeout(s)"); @@ -610,10 +609,10 @@ int StateMachine::_auth_read(SmartSocket sock) { //use default max_connection sock->user_info->max_connection = FLAGS_max_connections_per_user; } - if (sock->user_info->query_quota == 0) { - //use default query_quota - sock->user_info->query_quota = FLAGS_query_quota_per_user; - } + // if (sock->user_info->query_quota == 0) { + // //use default query_quota + // sock->user_info->query_quota = FLAGS_query_quota_per_user; + // } // Get password. if ((unsigned int)(sock->packet_len + PACKET_HEADER_LEN) < off + 1) { DB_FATAL_CLIENT(sock, "packet_len=%d + 4 <= off=%d + 1", diff --git a/src/runtime/topn_sorter.cpp b/src/runtime/topn_sorter.cpp new file mode 100644 index 00000000..8b192c73 --- /dev/null +++ b/src/runtime/topn_sorter.cpp @@ -0,0 +1,114 @@ +// Copyright (c) 2018-present Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "topn_sorter.h" + +namespace baikaldb { + +void TopNSorter::add_batch(std::shared_ptr& batch){ + while (!batch->is_traverse_over()) { + _current_idx ++; + if (_current_count < _limit) { + _mem_row_heap.push_back(TopNHeapItem{std::move(batch->get_row()), _current_idx}); + _current_count ++; + if (!_comp->need_not_compare()) { + shiftup(_current_count - 1); + } + } else { + auto& row = batch->get_row(); + if (!_comp->need_not_compare()) { + if (_comp->less(row.get(), _mem_row_heap[0].row.get())) { + _mem_row_heap[0] = TopNHeapItem{std::move(row), _current_idx}; + shiftdown(0); + } + } + } + batch->next(); + } +} + +void TopNSorter::sort() { + _current_idx = 0; + if (_comp->need_not_compare()) { + return; + } + auto compare_func = [&](const TopNHeapItem& left, const TopNHeapItem& right) { + auto comp = _comp->compare(left.row.get(), right.row.get()); + if (comp < 0) { + return true; + } else if (comp == 0 && left.idx < right.idx) { + return true; + } + return false; + }; + std::sort(_mem_row_heap.begin(), _mem_row_heap.end(), compare_func); +} + +int TopNSorter::get_next(RowBatch* batch, bool* eos) { + while (1) { + if (batch->is_full()) { + return 0; + } + if (_current_idx >= _mem_row_heap.size()) { + *eos = true; + return 0; + } + batch->move_row(std::move(_mem_row_heap[_current_idx].row)); + _current_idx ++; + } + return 0; +} + +void TopNSorter::shiftdown(size_t index) { + size_t left_index = index * 2 + 1; + size_t right_index = left_index + 1; + if (left_index >= _current_count) { + return; + } + size_t min_index = index; + if (left_index < _current_count) { + int64_t com = _comp->compare(_mem_row_heap[left_index].row.get(), + _mem_row_heap[min_index].row.get()); + if (com > 0) { + min_index = left_index; + } + } + if (right_index < _current_count) { + int64_t com = _comp->compare(_mem_row_heap[right_index].row.get(), + _mem_row_heap[min_index].row.get()); + if (com > 0) { + min_index = right_index; + } + } + if (min_index != index) { + std::iter_swap(_mem_row_heap.begin() + min_index, _mem_row_heap.begin() + index); + shiftdown(min_index); + } +} + +void TopNSorter::shiftup(size_t index) { + if (index == 0) { + return; + } + size_t parent = (index - 1) / 2; + auto com = _comp->compare(_mem_row_heap[index].row.get(), _mem_row_heap[parent].row.get()); + if (com > 0) { + std::iter_swap(_mem_row_heap.begin() + index, _mem_row_heap.begin() + parent); + shiftup(parent); + } +} + +} + +/* vim: set ts=4 sw=4 sts=4 tw=100 */ diff --git a/src/session/user_info.cpp b/src/session/user_info.cpp new file mode 100644 index 00000000..4640d68f --- /dev/null +++ b/src/session/user_info.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2018-present Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "user_info.h" + +namespace baikaldb { +DEFINE_int32(query_quota_per_user, 3000, "default user query quota by 1 second"); +BRPC_VALIDATE_GFLAG(query_quota_per_user, brpc::PassValidate); + +bool UserInfo::is_exceed_quota() { + if (query_cost.get_time() > 1000000) { + query_cost.reset(); + query_count = 0; + return false; + } + int32_t quota = query_quota; + if (quota == 0) { + quota = FLAGS_query_quota_per_user; + } + return query_count++ > quota; +} + +} // namespace baikaldb