From 76a3ce9e1144453def5953f879d9c0690d23036d Mon Sep 17 00:00:00 2001
From: DuckDB Labs GitHub Bot
Date: Mon, 21 Oct 2024 00:34:53 +0000
Subject: [PATCH] Update vendored DuckDB sources to 12c4a24f

---
 .../src/function/scalar/string/concat.cpp     | 269 ++++++++----------
 .../function/table/version/pragma_version.cpp |   6 +-
 2 files changed, 121 insertions(+), 154 deletions(-)

diff --git a/src/duckdb/src/function/scalar/string/concat.cpp b/src/duckdb/src/function/scalar/string/concat.cpp
index 18619a5b..0f1d3509 100644
--- a/src/duckdb/src/function/scalar/string/concat.cpp
+++ b/src/duckdb/src/function/scalar/string/concat.cpp
@@ -45,7 +45,7 @@ static void StringConcatFunction(DataChunk &args, ExpressionState &state, Vector
 	vector<idx_t> result_lengths(args.size(), 0);
 	for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) {
 		auto &input = args.data[col_idx];
-		D_ASSERT(input.GetType().id() == LogicalTypeId::VARCHAR);
+		D_ASSERT(input.GetType().InternalType() == PhysicalType::VARCHAR);
 		if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) {
 			if (ConstantVector::IsNull(input)) {
 				// constant null, skip
@@ -143,68 +143,60 @@ static void ConcatOperator(DataChunk &args, ExpressionState &state, Vector &resu
 	});
 }
 
+struct ListConcatInputData {
+	ListConcatInputData(Vector &input, Vector &child_vec) : input(input), child_vec(child_vec) {
+	}
+
+	UnifiedVectorFormat vdata;
+	Vector &input;
+	Vector &child_vec;
+	UnifiedVectorFormat child_vdata;
+	const list_entry_t *input_entries = nullptr;
+};
+
 static void ListConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) {
-	D_ASSERT(args.ColumnCount() == 2);
 	auto count = args.size();
 
-	Vector &lhs = args.data[0];
-	Vector &rhs = args.data[1];
-	if (lhs.GetType().id() == LogicalTypeId::SQLNULL) {
-		result.Reference(rhs);
-		return;
-	}
-	if (rhs.GetType().id() == LogicalTypeId::SQLNULL) {
-		result.Reference(lhs);
-		return;
-	}
-
-	UnifiedVectorFormat lhs_data;
-	UnifiedVectorFormat rhs_data;
-	lhs.ToUnifiedFormat(count, lhs_data);
-	rhs.ToUnifiedFormat(count, rhs_data);
-	auto lhs_entries = UnifiedVectorFormat::GetData<list_entry_t>(lhs_data);
-	auto rhs_entries = UnifiedVectorFormat::GetData<list_entry_t>(rhs_data);
-
-	auto lhs_list_size = ListVector::GetListSize(lhs);
-	auto rhs_list_size = ListVector::GetListSize(rhs);
-	auto &lhs_child = ListVector::GetEntry(lhs);
-	auto &rhs_child = ListVector::GetEntry(rhs);
-	UnifiedVectorFormat lhs_child_data;
-	UnifiedVectorFormat rhs_child_data;
-	lhs_child.ToUnifiedFormat(lhs_list_size, lhs_child_data);
-	rhs_child.ToUnifiedFormat(rhs_list_size, rhs_child_data);
-
-	result.SetVectorType(VectorType::FLAT_VECTOR);
 	auto result_entries = FlatVector::GetData<list_entry_t>(result);
-	auto &result_validity = FlatVector::Validity(result);
+	vector<ListConcatInputData> input_data;
+	for (auto &input : args.data) {
+		if (input.GetType().id() == LogicalTypeId::SQLNULL) {
+			// ignore NULL values
+			continue;
+		}
+
+		auto &child_vec = ListVector::GetEntry(input);
+		ListConcatInputData data(input, child_vec);
+		input.ToUnifiedFormat(count, data.vdata);
+
+		data.input_entries = UnifiedVectorFormat::GetData<list_entry_t>(data.vdata);
+		auto list_size = ListVector::GetListSize(input);
+
+		child_vec.ToUnifiedFormat(list_size, data.child_vdata);
+
+		input_data.push_back(std::move(data));
+	}
 
 	idx_t offset = 0;
 	for (idx_t i = 0; i < count; i++) {
-		auto lhs_list_index = lhs_data.sel->get_index(i);
-		auto rhs_list_index = rhs_data.sel->get_index(i);
-		if (!lhs_data.validity.RowIsValid(lhs_list_index) && !rhs_data.validity.RowIsValid(rhs_list_index)) {
-			result_validity.SetInvalid(i);
-			continue;
-		}
-		result_entries[i].offset = offset;
-		result_entries[i].length = 0;
-		if (lhs_data.validity.RowIsValid(lhs_list_index)) {
-			const auto &lhs_entry = lhs_entries[lhs_list_index];
-			result_entries[i].length += lhs_entry.length;
-			ListVector::Append(result, lhs_child, *lhs_child_data.sel, lhs_entry.offset + lhs_entry.length,
-			                   lhs_entry.offset);
-		}
-		if (rhs_data.validity.RowIsValid(rhs_list_index)) {
-			const auto &rhs_entry = rhs_entries[rhs_list_index];
-			result_entries[i].length += rhs_entry.length;
-			ListVector::Append(result, rhs_child, *rhs_child_data.sel, rhs_entry.offset + rhs_entry.length,
-			                   rhs_entry.offset);
+		auto &result_entry = result_entries[i];
+		result_entry.offset = offset;
+		result_entry.length = 0;
+		for (auto &data : input_data) {
+			auto list_index = data.vdata.sel->get_index(i);
+			if (!data.vdata.validity.RowIsValid(list_index)) {
+				continue;
+			}
+			const auto &list_entry = data.input_entries[list_index];
+			result_entry.length += list_entry.length;
+			ListVector::Append(result, data.child_vec, *data.child_vdata.sel, list_entry.offset + list_entry.length,
+			                   list_entry.offset);
 		}
-		offset += result_entries[i].length;
+		offset += result_entry.length;
 	}
-	D_ASSERT(ListVector::GetListSize(result) == offset);
+	ListVector::SetListSize(result, offset);
 
-	if (lhs.GetVectorType() == VectorType::CONSTANT_VECTOR && rhs.GetVectorType() == VectorType::CONSTANT_VECTOR) {
+	if (args.AllConstant()) {
 		result.SetVectorType(VectorType::CONSTANT_VECTOR);
 	}
 }
@@ -235,128 +227,103 @@ static void SetArgumentType(ScalarFunction &bound_function, const LogicalType &t
 	bound_function.return_type = type;
 }
 
-static void HandleArrayBinding(ClientContext &context, vector<unique_ptr<Expression>> &arguments) {
-	if (arguments[1]->return_type.id() != LogicalTypeId::ARRAY &&
-	    arguments[1]->return_type.id() != LogicalTypeId::SQLNULL) {
-		throw BinderException("Cannot concatenate types %s and %s", arguments[0]->return_type.ToString(),
-		                      arguments[1]->return_type.ToString());
-	}
-
-	// if either argument is an array, we cast it to a list
-	arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0]));
-	arguments[1] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[1]));
-}
-
-static unique_ptr<FunctionData> HandleListBinding(ClientContext &context, ScalarFunction &bound_function,
-                                                  vector<unique_ptr<Expression>> &arguments, bool is_operator) {
-	// list_concat only accepts two arguments
-	D_ASSERT(arguments.size() == 2);
-
-	auto &lhs = arguments[0]->return_type;
-	auto &rhs = arguments[1]->return_type;
-
-	if (lhs.id() == LogicalTypeId::UNKNOWN || rhs.id() == LogicalTypeId::UNKNOWN) {
-		throw ParameterNotResolvedException();
-	} else if (lhs.id() == LogicalTypeId::SQLNULL || rhs.id() == LogicalTypeId::SQLNULL) {
-		// we mimic postgres behaviour: list_concat(NULL, my_list) = my_list
-		auto return_type = rhs.id() == LogicalTypeId::SQLNULL ? lhs : rhs;
-		SetArgumentType(bound_function, return_type, is_operator);
-		return make_uniq<VariableReturnBindData>(bound_function.return_type, is_operator);
-	}
-	if (lhs.id() != LogicalTypeId::LIST || rhs.id() != LogicalTypeId::LIST) {
-		throw BinderException("Cannot concatenate types %s and %s", lhs.ToString(), rhs.ToString());
-	}
-
-	// Resolve list type
+static unique_ptr<FunctionData> BindListConcat(ClientContext &context, ScalarFunction &bound_function,
+                                               vector<unique_ptr<Expression>> &arguments, bool is_operator) {
 	LogicalType child_type = LogicalType::SQLNULL;
-	for (const auto &argument : arguments) {
-		auto &next_type = ListType::GetChildType(argument->return_type);
+	bool all_null = true;
+	for (auto &arg : arguments) {
+		auto &return_type = arg->return_type;
+		if (return_type == LogicalTypeId::SQLNULL) {
+			// we mimic postgres behaviour: list_concat(NULL, my_list) = my_list
+			continue;
+		}
+		all_null = false;
+		LogicalType next_type = LogicalTypeId::INVALID;
+		switch (return_type.id()) {
+		case LogicalTypeId::UNKNOWN:
+			throw ParameterNotResolvedException();
+		case LogicalTypeId::LIST:
+			next_type = ListType::GetChildType(return_type);
+			break;
+		case LogicalTypeId::ARRAY:
+			next_type = ArrayType::GetChildType(return_type);
+			break;
+		default: {
+			string type_list;
+			for (idx_t arg_idx = 0; arg_idx < arguments.size(); arg_idx++) {
+				if (!type_list.empty()) {
+					if (arg_idx + 1 == arguments.size()) {
+						// last argument
+						type_list += " and ";
+					} else {
+						type_list += ", ";
+					}
+				}
+				type_list += arguments[arg_idx]->return_type.ToString();
+			}
+			throw BinderException(*arg, "Cannot concatenate types %s - an explicit cast is required", type_list);
+		}
+		}
 		if (!LogicalType::TryGetMaxLogicalType(context, child_type, next_type, child_type)) {
-			throw BinderException("Cannot concatenate lists of types %s[] and %s[] - an explicit cast is required",
+			throw BinderException(*arg,
+			                      "Cannot concatenate lists of types %s[] and %s[] - an explicit cast is required",
 			                      child_type.ToString(), next_type.ToString());
 		}
 	}
+	if (all_null) {
+		// all arguments are NULL
+		SetArgumentType(bound_function, LogicalTypeId::SQLNULL, is_operator);
+		return make_uniq<VariableReturnBindData>(bound_function.return_type, is_operator);
+	}
 	auto list_type = LogicalType::LIST(child_type);
 
 	SetArgumentType(bound_function, list_type, is_operator);
 	return make_uniq<VariableReturnBindData>(bound_function.return_type, is_operator);
 }
 
-static void FindFirstTwoArguments(vector<unique_ptr<Expression>> &arguments, LogicalTypeId &first_arg,
-                                  LogicalTypeId &second_arg) {
-	first_arg = arguments[0]->return_type.id();
-	second_arg = first_arg;
-	if (arguments.size() > 1) {
-		second_arg = arguments[1]->return_type.id();
+static unique_ptr<FunctionData> BindConcatFunctionInternal(ClientContext &context, ScalarFunction &bound_function,
+                                                           vector<unique_ptr<Expression>> &arguments,
+                                                           bool is_operator) {
+	bool list_concat = false;
+	// blob concat is only supported for the concat operator - regular concat converts to varchar
+	bool all_blob = is_operator ? true : false;
+	for (auto &arg : arguments) {
+		if (arg->return_type.id() == LogicalTypeId::UNKNOWN) {
+			throw ParameterNotResolvedException();
+		}
+		if (arg->return_type.id() == LogicalTypeId::LIST || arg->return_type.id() == LogicalTypeId::ARRAY) {
+			list_concat = true;
+		}
+		if (arg->return_type.id() != LogicalTypeId::BLOB) {
+			all_blob = false;
+		}
 	}
+	if (list_concat) {
+		return BindListConcat(context, bound_function, arguments, is_operator);
+	}
+	auto return_type = all_blob ? LogicalType::BLOB : LogicalType::VARCHAR;
+
+	// we can now assume that the input is a string or castable to a string
+	SetArgumentType(bound_function, return_type, is_operator);
+	return make_uniq<VariableReturnBindData>(bound_function.return_type, is_operator);
 }
 
 static unique_ptr<FunctionData> BindConcatFunction(ClientContext &context, ScalarFunction &bound_function,
                                                    vector<unique_ptr<Expression>> &arguments) {
-	LogicalTypeId first_arg;
-	LogicalTypeId second_arg;
-	FindFirstTwoArguments(arguments, first_arg, second_arg);
-
-	if (arguments.size() > 2 && (first_arg == LogicalTypeId::ARRAY || first_arg == LogicalTypeId::LIST)) {
-		throw BinderException("list_concat only accepts two arguments");
-	}
-
-	if (first_arg == LogicalTypeId::ARRAY || second_arg == LogicalTypeId::ARRAY) {
-		HandleArrayBinding(context, arguments);
-		FindFirstTwoArguments(arguments, first_arg, second_arg);
-	}
-
-	if (first_arg == LogicalTypeId::LIST || second_arg == LogicalTypeId::LIST) {
-		return HandleListBinding(context, bound_function, arguments, false);
-	}
-
-	// we can now assume that the input is a string or castable to a string
-	SetArgumentType(bound_function, LogicalType::VARCHAR, false);
-	return make_uniq<VariableReturnBindData>(bound_function.return_type, false);
+	return BindConcatFunctionInternal(context, bound_function, arguments, false);
 }
 
 static unique_ptr<FunctionData> BindConcatOperator(ClientContext &context, ScalarFunction &bound_function,
                                                    vector<unique_ptr<Expression>> &arguments) {
-	D_ASSERT(arguments.size() == 2);
-
-	LogicalTypeId lhs;
-	LogicalTypeId rhs;
-	FindFirstTwoArguments(arguments, lhs, rhs);
-
-	if (lhs == LogicalTypeId::UNKNOWN || rhs == LogicalTypeId::UNKNOWN) {
-		throw ParameterNotResolvedException();
-	}
-	if (lhs == LogicalTypeId::ARRAY || rhs == LogicalTypeId::ARRAY) {
-		HandleArrayBinding(context, arguments);
-		FindFirstTwoArguments(arguments, lhs, rhs);
-	}
-
-	if (lhs == LogicalTypeId::LIST || rhs == LogicalTypeId::LIST) {
-		return HandleListBinding(context, bound_function, arguments, true);
-	}
-
-	LogicalType return_type;
-	if (lhs == LogicalTypeId::BLOB && rhs == LogicalTypeId::BLOB) {
-		return_type = LogicalType::BLOB;
-	} else {
-		return_type = LogicalType::VARCHAR;
-	}
-
-	// we can now assume that the input is a string or castable to a string
-	SetArgumentType(bound_function, return_type, true);
-	return make_uniq<VariableReturnBindData>(bound_function.return_type, true);
+	return BindConcatFunctionInternal(context, bound_function, arguments, true);
 }
 
 static unique_ptr<BaseStatistics> ListConcatStats(ClientContext &context, FunctionStatisticsInput &input) {
 	auto &child_stats = input.child_stats;
-	D_ASSERT(child_stats.size() == 2);
-
-	auto &left_stats = child_stats[0];
-	auto &right_stats = child_stats[1];
-
-	auto stats = left_stats.ToUnique();
-	stats->Merge(right_stats);
-
+	auto stats = child_stats[0].ToUnique();
+	for (idx_t i = 1; i < child_stats.size(); i++) {
+		stats->Merge(child_stats[i]);
+	}
 	return stats;
 }
 
diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp
index fb75309a..9e968bde 100644
--- a/src/duckdb/src/function/table/version/pragma_version.cpp
+++ b/src/duckdb/src/function/table/version/pragma_version.cpp
@@ -1,5 +1,5 @@
 #ifndef DUCKDB_PATCH_VERSION
-#define DUCKDB_PATCH_VERSION "3-dev35"
+#define DUCKDB_PATCH_VERSION "3-dev38"
 #endif
 #ifndef DUCKDB_MINOR_VERSION
 #define DUCKDB_MINOR_VERSION 1
@@ -8,10 +8,10 @@
 #define DUCKDB_MAJOR_VERSION 1
 #endif
 #ifndef DUCKDB_VERSION
-#define DUCKDB_VERSION "v1.1.3-dev35"
+#define DUCKDB_VERSION "v1.1.3-dev38"
 #endif
 #ifndef DUCKDB_SOURCE_ID
-#define DUCKDB_SOURCE_ID "0446ab42e9"
+#define DUCKDB_SOURCE_ID "52b43b1660"
 #endif
 #include "duckdb/function/table/system_functions.hpp"
 #include "duckdb/main/database.hpp"