From 6d559f1511a1a6e6b1f0f7870cc8340101aa4c27 Mon Sep 17 00:00:00 2001 From: Mryange Date: Wed, 18 Dec 2024 16:46:46 +0800 Subject: [PATCH] upd --- be/src/vec/exprs/vectorized_fn_call.cpp | 5 +- be/src/vec/exprs/vexpr.cpp | 3 +- be/src/vec/functions/function.cpp | 61 +++++++++---------------- be/src/vec/functions/function.h | 10 ++-- be/src/vec/functions/round.h | 14 ++++-- 5 files changed, 43 insertions(+), 50 deletions(-) diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp b/be/src/vec/exprs/vectorized_fn_call.cpp index 43aa24090393a15..56ba87d8f808d03 100644 --- a/be/src/vec/exprs/vectorized_fn_call.cpp +++ b/be/src/vec/exprs/vectorized_fn_call.cpp @@ -59,7 +59,10 @@ Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, ColumnsWithTypeAndName argument_template; argument_template.reserve(_children.size()); for (auto child : _children) { - argument_template.emplace_back(nullptr, child->data_type(), child->expr_name()); + std::shared_ptr const_col; + RETURN_IF_ERROR(child->get_const_col(context, &const_col)); + argument_template.emplace_back(const_col ? const_col->column_ptr : nullptr, + child->data_type(), child->expr_name()); } _expr_name = fmt::format("VectorizedFnCall[{}](arguments={},return={})", _fn.name.function_name, diff --git a/be/src/vec/exprs/vexpr.cpp b/be/src/vec/exprs/vexpr.cpp index 7cfb96e77118fb4..47b6100662f4ead 100644 --- a/be/src/vec/exprs/vexpr.cpp +++ b/be/src/vec/exprs/vexpr.cpp @@ -532,8 +532,7 @@ Status VExpr::get_const_col(VExprContext* context, return Status::OK(); } - if (_constant_col != nullptr) { - DCHECK(column_wrapper != nullptr); + if (_constant_col != nullptr && column_wrapper != nullptr) { *column_wrapper = _constant_col; return Status::OK(); } diff --git a/be/src/vec/functions/function.cpp b/be/src/vec/functions/function.cpp index 851e430d2f04075..c7db34aa9ed93ff 100644 --- a/be/src/vec/functions/function.cpp +++ b/be/src/vec/functions/function.cpp @@ -296,45 +296,28 @@ DataTypePtr FunctionBuilderImpl::get_return_type(const ColumnsWithTypeAndName& a bool FunctionBuilderImpl::is_date_or_datetime_or_decimal( const DataTypePtr& return_type, const DataTypePtr& func_return_type) const { - return (is_date_or_datetime(return_type->is_nullable() - ? ((DataTypeNullable*)return_type.get())->get_nested_type() - : return_type) && - is_date_or_datetime( - func_return_type->is_nullable() - ? ((DataTypeNullable*)func_return_type.get())->get_nested_type() - : func_return_type)) || - (is_date_v2_or_datetime_v2( - return_type->is_nullable() - ? ((DataTypeNullable*)return_type.get())->get_nested_type() - : return_type) && - is_date_v2_or_datetime_v2( - func_return_type->is_nullable() - ? ((DataTypeNullable*)func_return_type.get())->get_nested_type() - : func_return_type)) || - // For some date functions such as str_to_date(string, string), return_type will - // be datetimev2 if users enable datev2 but get_return_type(arguments) will still - // return datetime. We need keep backward compatibility here. - (is_date_v2_or_datetime_v2( - return_type->is_nullable() - ? ((DataTypeNullable*)return_type.get())->get_nested_type() - : return_type) && - is_date_or_datetime( - func_return_type->is_nullable() - ? ((DataTypeNullable*)func_return_type.get())->get_nested_type() - : func_return_type)) || - (is_date_or_datetime(return_type->is_nullable() - ? ((DataTypeNullable*)return_type.get())->get_nested_type() - : return_type) && - is_date_v2_or_datetime_v2( - func_return_type->is_nullable() - ? ((DataTypeNullable*)func_return_type.get())->get_nested_type() - : func_return_type)) || - (is_decimal(return_type->is_nullable() - ? ((DataTypeNullable*)return_type.get())->get_nested_type() - : return_type) && - is_decimal(func_return_type->is_nullable() - ? ((DataTypeNullable*)func_return_type.get())->get_nested_type() - : func_return_type)); + auto expect_return_type = remove_nullable(return_type); + auto real_return_type = remove_nullable(func_return_type); + + auto check_date_and_datetime = [&]() -> bool { + return (is_date_or_datetime(expect_return_type) && is_date_or_datetime(real_return_type)) || + (is_date_v2_or_datetime_v2(expect_return_type) && + is_date_v2_or_datetime_v2(real_return_type)); + }; + + // It is required that both types are either Decimal32, Decimal64, or Decimal128, + // and the scale must be the same. Due to differences between FE and BE code, + // no requirements are currently enforced for precision. + + auto check_decimal = [&]() -> bool { + if (is_decimal(expect_return_type) && is_decimal(real_return_type)) { + return (expect_return_type->get_type_id() == real_return_type->get_type_id()) && + (expect_return_type->get_scale() == real_return_type->get_scale()); + } + return false; + }; + + return check_date_and_datetime() || check_decimal(); } bool FunctionBuilderImpl::is_array_nested_type_date_or_datetime_or_decimal( diff --git a/be/src/vec/functions/function.h b/be/src/vec/functions/function.h index 92282c483948e64..31ca167de9c908a 100644 --- a/be/src/vec/functions/function.h +++ b/be/src/vec/functions/function.h @@ -274,11 +274,11 @@ class FunctionBuilderImpl : public IFunctionBuilder { is_nothing(((DataTypeNullable*)func_return_type.get())->get_nested_type())) || is_date_or_datetime_or_decimal(return_type, func_return_type) || is_array_nested_type_date_or_datetime_or_decimal(return_type, func_return_type))) { - LOG_WARNING( - "function return type check failed, function_name={}, " - "expect_return_type={}, real_return_type={}, input_arguments={}", - get_name(), return_type->get_name(), func_return_type->get_name(), - get_types_string(arguments)); + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "function return type check failed, function_name={}, " + "expect_return_type={}, real_return_type={}, input_arguments={}", + get_name(), return_type->get_name(), + func_return_type->get_name(), get_types_string(arguments)); return nullptr; } return build_impl(arguments, return_type); diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h index 3f4f9c60fcbe3df..41e91dd1591ed14 100644 --- a/be/src/vec/functions/round.h +++ b/be/src/vec/functions/round.h @@ -692,15 +692,23 @@ class FunctionRounding : public IFunction { } /// Get result types by argument types. If the function does not apply to these arguments, throw an exception. - DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) const override { if ((arguments.empty()) || (arguments.size() > 2)) { throw doris::Exception( ErrorCode::INVALID_ARGUMENT, "Number of arguments for function {}, doesn't match: should be 1 or 2. ", get_name()); } - - return arguments[0]; + // Keep consistent with FE's roundRule. + if (arguments.size() == 1 || + !WhichDataType {arguments[0].type->get_type_id()}.is_decimal()) { + return arguments[0].type; + } else { + return vectorized::create_decimal(arguments[0].type->get_precision(), + std::min(arguments[1].column->get_int(0), + (Int64)arguments[0].type->get_scale()), + false); + } } static Status get_scale_arg(const ColumnWithTypeAndName& arguments, Int16* scale) {