From d27c98abc49d6432a007b4e3087fd3a4fccb6144 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=8B=87?= Date: Fri, 10 May 2024 14:43:23 +0800 Subject: [PATCH] fix: precision --- include/common/expr_value.h | 9 ++++++++- include/common/type_utils.h | 12 ++++++++++++ include/expr/expr_node.h | 7 +++++++ include/expr/slot_ref.h | 6 +++++- src/common/expr_value.cpp | 15 +++++++++++++-- src/exec/packet_node.cpp | 14 ++++++++++++++ src/store/region.cpp | 4 ++-- 7 files changed, 61 insertions(+), 6 deletions(-) diff --git a/include/common/expr_value.h b/include/common/expr_value.h index 707a17804..4e9eae986 100644 --- a/include/common/expr_value.h +++ b/include/common/expr_value.h @@ -610,12 +610,19 @@ struct ExprValue { return std::to_string(_u.uint64_val); case pb::FLOAT: { std::ostringstream oss; + if (float_precision_len != -1) { + oss << std::fixed << std::setprecision(float_precision_len); + } oss << _u.float_val; return oss.str(); } case pb::DOUBLE: { std::ostringstream oss; - oss << std::setprecision(15) << _u.double_val; + if (float_precision_len != -1) { + oss << std::fixed << std::setprecision(float_precision_len) << _u.double_val; + } else { + oss << std::setprecision(15) << _u.double_val; + } return oss.str(); } case pb::STRING: diff --git a/include/common/type_utils.h b/include/common/type_utils.h index 8e7c17d66..ca2605d97 100644 --- a/include/common/type_utils.h +++ b/include/common/type_utils.h @@ -313,6 +313,18 @@ inline bool is_binary(uint32_t flag) { } } +inline bool is_precision(pb::PrimitiveType type) { + switch (type) { + case pb::FLOAT: + case pb::DOUBLE: + case pb::DATETIME: + return true; + default: + return false; + } +} + + inline uint8_t to_mysql_type(pb::PrimitiveType type) { switch (type) { case pb::BOOL: diff --git a/include/expr/expr_node.h b/include/expr/expr_node.h index cd3260d0f..53bcf96df 100644 --- a/include/expr/expr_node.h +++ b/include/expr/expr_node.h @@ -247,6 +247,12 @@ class ExprNode { void set_col_flag(uint32_t col_flag) { _col_flag = col_flag; } + void set_float_precision_len(int32_t float_precision_len) { + _float_precision_len = float_precision_len; + } + int32_t float_precision_len() { + return _float_precision_len; + } void flatten_or_expr(std::vector* or_exprs) { if (node_type() != pb::OR_PREDICATE) { @@ -297,6 +303,7 @@ class ExprNode { bool _replace_agg_to_slot = true; int32_t _tuple_id = -1; int32_t _slot_id = -1; + int32_t _float_precision_len = -1; bool is_logical_and_or_not(); public: static int create_expr_node(const pb::ExprNode& node, ExprNode** expr_node); diff --git a/include/expr/slot_ref.h b/include/expr/slot_ref.h index 8799e2204..57f717fb2 100644 --- a/include/expr/slot_ref.h +++ b/include/expr/slot_ref.h @@ -32,7 +32,11 @@ class SlotRef : public ExprNode { if (row == NULL) { return ExprValue::Null(); } - return row->get_value(_tuple_id, _slot_id).cast_to(_col_type); + ExprValue v = row->get_value(_tuple_id, _slot_id).cast_to(_col_type); + if (_float_precision_len != -1) { + v.set_precision_len(_float_precision_len); + } + return v; } virtual ExprValue get_value(const ExprValue& value) { return value; diff --git a/src/common/expr_value.cpp b/src/common/expr_value.cpp index 0b79ce316..36c916c61 100644 --- a/src/common/expr_value.cpp +++ b/src/common/expr_value.cpp @@ -63,7 +63,13 @@ SerializeStatus ExprValue::serialize_to_mysql_text_packet(char* buf, size_t size case pb::FLOAT: { size_t body_len = 0; char tmp_buf[100] = {0}; - body_len = snprintf(tmp_buf, sizeof(tmp_buf), "%.6g", _u.float_val); + if (float_precision_len == -1) { + body_len = snprintf(tmp_buf, sizeof(tmp_buf), "%.6g", _u.float_val); + } else { + std::string format= "%." + std::to_string(float_precision_len) + "f"; + body_len = snprintf(tmp_buf, sizeof(tmp_buf), format.c_str(), _u.float_val); + } + len = body_len + 1; if (len > size) { return STMPS_NEED_RESIZE; @@ -76,7 +82,12 @@ SerializeStatus ExprValue::serialize_to_mysql_text_packet(char* buf, size_t size case pb::DOUBLE: { size_t body_len = 0; char tmp_buf[100] = {0}; - body_len = snprintf(tmp_buf, sizeof(tmp_buf), "%.12g", _u.double_val); + if (float_precision_len == -1) { + body_len = snprintf(tmp_buf, sizeof(tmp_buf), "%.12g", _u.double_val); + } else { + std::string format= "%." + std::to_string(float_precision_len) + "lf"; + body_len = snprintf(tmp_buf, sizeof(tmp_buf), format.c_str(), _u.double_val); + } len = body_len + 1; if (len > size) { return STMPS_NEED_RESIZE; diff --git a/src/exec/packet_node.cpp b/src/exec/packet_node.cpp index 67155a418..b9d83200e 100644 --- a/src/exec/packet_node.cpp +++ b/src/exec/packet_node.cpp @@ -459,6 +459,20 @@ int PacketNode::open(RuntimeState* state) { DB_WARNING("Expr::open fail:%d", ret); return ret; } + // 设置返回精度 + if (expr->is_slot_ref() && is_precision(expr->col_type())) { + SlotRef* slot_ref = static_cast(expr); + int64_t table_id = state->get_tuple_desc(slot_ref->tuple_id())->table_id(); + auto table_info = SchemaFactory::get_instance()->get_table_info_ptr(table_id); + if (table_info == nullptr) { + continue; + } + auto field_info = table_info->get_field_ptr(slot_ref->field_id()); + if (field_info == nullptr) { + continue; + } + slot_ref->set_float_precision_len(field_info->float_precision_len); + } } if (state->is_expr_subquery()) { return fatch_expr_subquery_results(state); diff --git a/src/store/region.cpp b/src/store/region.cpp index 7be92ee90..6fe54ea52 100644 --- a/src/store/region.cpp +++ b/src/store/region.cpp @@ -3258,9 +3258,9 @@ void Region::do_apply(int64_t term, int64_t index, const pb::StoreReq& request, if (res.has_scan_rows()) { ((DMLClosure*)done)->response->set_scan_rows(res.scan_rows()); } - if (res.has_read_disk_size()) { + if (res.has_read_disk_size()) { ((DMLClosure*)done)->response->set_read_disk_size(res.read_disk_size()); - } + } if (res.has_filter_rows()) { ((DMLClosure*)done)->response->set_filter_rows(res.filter_rows()); }