diff --git a/binding.gyp b/binding.gyp index 4bde6709..d3e31e68 100644 --- a/binding.gyp +++ b/binding.gyp @@ -326,7 +326,30 @@ "src/duckdb/extension/json/json_serializer.cpp", "src/duckdb/extension/json/json_deserializer.cpp", "src/duckdb/extension/json/serialize_json.cpp", - "src/duckdb/ub_extension_json_json_functions.cpp" + "src/duckdb/ub_extension_json_json_functions.cpp", + "src/duckdb/extension/core_functions/function_list.cpp", + "src/duckdb/extension/core_functions/core_functions_extension.cpp", + "src/duckdb/extension/core_functions/lambda_functions.cpp", + "src/duckdb/ub_extension_core_functions_aggregate_holistic.cpp", + "src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp", + "src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp", + "src/duckdb/ub_extension_core_functions_aggregate_regression.cpp", + "src/duckdb/ub_extension_core_functions_aggregate_nested.cpp", + "src/duckdb/ub_extension_core_functions_scalar_bit.cpp", + "src/duckdb/ub_extension_core_functions_scalar_operators.cpp", + "src/duckdb/ub_extension_core_functions_scalar_array.cpp", + "src/duckdb/ub_extension_core_functions_scalar_date.cpp", + "src/duckdb/ub_extension_core_functions_scalar_enum.cpp", + "src/duckdb/ub_extension_core_functions_scalar_math.cpp", + "src/duckdb/ub_extension_core_functions_scalar_struct.cpp", + "src/duckdb/ub_extension_core_functions_scalar_map.cpp", + "src/duckdb/ub_extension_core_functions_scalar_list.cpp", + "src/duckdb/ub_extension_core_functions_scalar_union.cpp", + "src/duckdb/ub_extension_core_functions_scalar_generic.cpp", + "src/duckdb/ub_extension_core_functions_scalar_string.cpp", + "src/duckdb/ub_extension_core_functions_scalar_random.cpp", + "src/duckdb/ub_extension_core_functions_scalar_blob.cpp", + "src/duckdb/ub_extension_core_functions_scalar_debug.cpp" ], "include_dirs": [ " +struct AvgState { + uint64_t count; + T value; + + void Initialize() { + this->count = 0; + } + + void Combine(const AvgState &other) { + this->count += other.count; + this->value += other.value; + } +}; + +struct KahanAvgState { + uint64_t count; + double value; + double err; + + void Initialize() { + this->count = 0; + this->err = 0.0; + } + + void Combine(const KahanAvgState &other) { + this->count += other.count; + KahanAddInternal(other.value, this->value, this->err); + KahanAddInternal(other.err, this->value, this->err); + } +}; + +struct AverageDecimalBindData : public FunctionData { + explicit AverageDecimalBindData(double scale) : scale(scale) { + } + + double scale; + +public: + unique_ptr Copy() const override { + return make_uniq(scale); + }; + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return scale == other.scale; + } +}; + +struct AverageSetOperation { + template + static void Initialize(STATE &state) { + state.Initialize(); + } + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.Combine(source); + } + template + static void AddValues(STATE &state, idx_t count) { + state.count += count; + } +}; + +template +static T GetAverageDivident(uint64_t count, optional_ptr bind_data) { + T divident = T(count); + if (bind_data) { + auto &avg_bind_data = bind_data->Cast(); + divident *= avg_bind_data.scale; + } + return divident; +} + +struct IntegerAverageOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + double divident 
= GetAverageDivident(state.count, finalize_data.input.bind_data); + target = double(state.value) / divident; + } + } +}; + +struct IntegerAverageOperationHugeint : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + long double divident = GetAverageDivident(state.count, finalize_data.input.bind_data); + target = Hugeint::Cast(state.value) / divident; + } + } +}; + +struct HugeintAverageOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + long double divident = GetAverageDivident(state.count, finalize_data.input.bind_data); + target = Hugeint::Cast(state.value) / divident; + } + } +}; + +struct NumericAverageOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = state.value / state.count; + } + } +}; + +struct KahanAverageOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = (state.value / state.count) + (state.err / state.count); + } + } +}; + +AggregateFunction GetAverageAggregate(PhysicalType type) { + switch (type) { + case PhysicalType::INT16: { + return AggregateFunction::UnaryAggregate, int16_t, double, IntegerAverageOperation>( + LogicalType::SMALLINT, LogicalType::DOUBLE); + } + case PhysicalType::INT32: { + return AggregateFunction::UnaryAggregate, int32_t, double, IntegerAverageOperationHugeint>( + LogicalType::INTEGER, LogicalType::DOUBLE); + } + case PhysicalType::INT64: { + return AggregateFunction::UnaryAggregate, int64_t, double, IntegerAverageOperationHugeint>( + LogicalType::BIGINT, LogicalType::DOUBLE); + } + case PhysicalType::INT128: { + return AggregateFunction::UnaryAggregate, hugeint_t, double, HugeintAverageOperation>( + LogicalType::HUGEINT, LogicalType::DOUBLE); + } + default: + throw InternalException("Unimplemented average aggregate"); + } +} + +unique_ptr BindDecimalAvg(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto decimal_type = arguments[0]->return_type; + function = GetAverageAggregate(decimal_type.InternalType()); + function.name = "avg"; + function.arguments[0] = decimal_type; + function.return_type = LogicalType::DOUBLE; + return make_uniq( + Hugeint::Cast(Hugeint::POWERS_OF_TEN[DecimalType::GetScale(decimal_type)])); +} + +AggregateFunctionSet AvgFun::GetFunctions() { + AggregateFunctionSet avg; + + avg.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, + nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, + BindDecimalAvg)); + avg.AddFunction(GetAverageAggregate(PhysicalType::INT16)); + avg.AddFunction(GetAverageAggregate(PhysicalType::INT32)); + avg.AddFunction(GetAverageAggregate(PhysicalType::INT64)); + avg.AddFunction(GetAverageAggregate(PhysicalType::INT128)); + avg.AddFunction(AggregateFunction::UnaryAggregate, double, double, NumericAverageOperation>( + LogicalType::DOUBLE, LogicalType::DOUBLE)); + return avg; +} + +AggregateFunction FAvgFun::GetFunction() { + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + 
LogicalType::DOUBLE);
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp
new file mode 100644
index 00000000..bf53a5ad
--- /dev/null
+++ b/src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp
@@ -0,0 +1,13 @@
+#include "core_functions/aggregate/algebraic_functions.hpp"
+#include "core_functions/aggregate/algebraic/covar.hpp"
+#include "core_functions/aggregate/algebraic/stddev.hpp"
+#include "core_functions/aggregate/algebraic/corr.hpp"
+#include "duckdb/function/function_set.hpp"
+
+namespace duckdb {
+
+AggregateFunction CorrFun::GetFunction() {
+	return AggregateFunction::BinaryAggregate<CorrState, double, double, double, CorrOperation>(
+	    LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
+}
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp
new file mode 100644
index 00000000..fddb9ed2
--- /dev/null
+++ b/src/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp
@@ -0,0 +1,17 @@
+#include "core_functions/aggregate/algebraic_functions.hpp"
+#include "duckdb/common/types/null_value.hpp"
+#include "core_functions/aggregate/algebraic/covar.hpp"
+
+namespace duckdb {
+
+AggregateFunction CovarPopFun::GetFunction() {
+	return AggregateFunction::BinaryAggregate<CovarState, double, double, double, CovarPopOperation>(
+	    LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
+}
+
+AggregateFunction CovarSampFun::GetFunction() {
+	return AggregateFunction::BinaryAggregate<CovarState, double, double, double, CovarSampOperation>(
+	    LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp
new file mode 100644
index 00000000..e9d14ee2
--- /dev/null
+++ b/src/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp
@@ -0,0 +1,34 @@
+#include "core_functions/aggregate/algebraic_functions.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/function/function_set.hpp"
+#include "core_functions/aggregate/algebraic/stddev.hpp"
+#include <cmath>
+
+namespace duckdb {
+
+AggregateFunction StdDevSampFun::GetFunction() {
+	return AggregateFunction::UnaryAggregate<StddevState, double, double, STDDevSampOperation>(LogicalType::DOUBLE,
+	                                                                                           LogicalType::DOUBLE);
+}
+
+AggregateFunction StdDevPopFun::GetFunction() {
+	return AggregateFunction::UnaryAggregate<StddevState, double, double, STDDevPopOperation>(LogicalType::DOUBLE,
+	                                                                                          LogicalType::DOUBLE);
+}
+
+AggregateFunction VarPopFun::GetFunction() {
+	return AggregateFunction::UnaryAggregate<StddevState, double, double, VarPopOperation>(LogicalType::DOUBLE,
+	                                                                                       LogicalType::DOUBLE);
+}
+
+AggregateFunction VarSampFun::GetFunction() {
+	return AggregateFunction::UnaryAggregate<StddevState, double, double, VarSampOperation>(LogicalType::DOUBLE,
+	                                                                                        LogicalType::DOUBLE);
+}
+
+AggregateFunction StandardErrorOfTheMeanFun::GetFunction() {
+	return AggregateFunction::UnaryAggregate<StddevState, double, double, StandardErrorOfTheMeanOperation>(
+	    LogicalType::DOUBLE, LogicalType::DOUBLE);
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp
new file mode 100644
index 00000000..37f05b20
--- /dev/null
+++ b/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp
@@ -0,0 +1,99 @@
+#include "duckdb/common/exception.hpp"
+#include "duckdb/common/types/hash.hpp"
+#include "duckdb/common/types/hyperloglog.hpp"
+#include "core_functions/aggregate/distributive_functions.hpp"
+#include "duckdb/function/function_set.hpp"
+#include 
"duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "hyperloglog.hpp" + +namespace duckdb { + +// Algorithms from +// "New cardinality estimation algorithms for HyperLogLog sketches" +// Otmar Ertl, arXiv:1702.01284 +struct ApproxDistinctCountState { + HyperLogLog hll; +}; + +struct ApproxCountDistinctFunction { + template + static void Initialize(STATE &state) { + new (&state) STATE(); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.hll.Merge(source.hll); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + target = UnsafeNumericCast(state.hll.Count()); + } + + static bool IgnoreNull() { + return true; + } +}; + +static void ApproxCountDistinctSimpleUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, + data_ptr_t state, idx_t count) { + D_ASSERT(input_count == 1); + auto &input = inputs[0]; + + if (count > STANDARD_VECTOR_SIZE) { + throw InternalException("ApproxCountDistinct - count must be at most vector size"); + } + Vector hash_vec(LogicalType::HASH, count); + VectorOperations::Hash(input, hash_vec, count); + + auto agg_state = reinterpret_cast(state); + agg_state->hll.Update(input, hash_vec, count); +} + +static void ApproxCountDistinctUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, + Vector &state_vector, idx_t count) { + D_ASSERT(input_count == 1); + auto &input = inputs[0]; + UnifiedVectorFormat idata; + input.ToUnifiedFormat(count, idata); + + if (count > STANDARD_VECTOR_SIZE) { + throw InternalException("ApproxCountDistinct - count must be at most vector size"); + } + Vector hash_vec(LogicalType::HASH, count); + VectorOperations::Hash(input, hash_vec, count); + + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + const auto states = UnifiedVectorFormat::GetDataNoConst(sdata); + + UnifiedVectorFormat hdata; + hash_vec.ToUnifiedFormat(count, hdata); + const auto *hashes = UnifiedVectorFormat::GetData(hdata); + for (idx_t i = 0; i < count; i++) { + if (idata.validity.RowIsValid(idata.sel->get_index(i))) { + auto agg_state = states[sdata.sel->get_index(i)]; + const auto hash = hashes[hdata.sel->get_index(i)]; + agg_state->hll.InsertElement(hash); + } + } +} + +AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) { + auto fun = AggregateFunction( + {input_type}, LogicalTypeId::BIGINT, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + ApproxCountDistinctUpdateFunction, + AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, + ApproxCountDistinctSimpleUpdateFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +AggregateFunction ApproxCountDistinctFun::GetFunction() { + return GetApproxCountDistinctFunction(LogicalType::ANY); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp new file mode 100644 index 00000000..63c112b3 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp @@ -0,0 +1,742 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include 
"duckdb/function/function_set.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/function/create_sort_key.hpp" +#include "duckdb/function/aggregate/minmax_n_helpers.hpp" + +namespace duckdb { + +struct ArgMinMaxStateBase { + ArgMinMaxStateBase() : is_initialized(false), arg_null(false) { + } + + template + static inline void CreateValue(T &value) { + } + + template + static inline void DestroyValue(T &value) { + } + + template + static inline void AssignValue(T &target, T new_value) { + target = new_value; + } + + template + static inline void ReadValue(Vector &result, T &arg, T &target) { + target = arg; + } + + bool is_initialized; + bool arg_null; +}; + +// Out-of-line specialisations +template <> +void ArgMinMaxStateBase::CreateValue(string_t &value) { + value = string_t(uint32_t(0)); +} + +template <> +void ArgMinMaxStateBase::DestroyValue(string_t &value) { + if (!value.IsInlined()) { + delete[] value.GetData(); + } +} + +template <> +void ArgMinMaxStateBase::AssignValue(string_t &target, string_t new_value) { + DestroyValue(target); + if (new_value.IsInlined()) { + target = new_value; + } else { + // non-inlined string, need to allocate space for it + auto len = new_value.GetSize(); + auto ptr = new char[len]; + memcpy(ptr, new_value.GetData(), len); + + target = string_t(ptr, UnsafeNumericCast(len)); + } +} + +template <> +void ArgMinMaxStateBase::ReadValue(Vector &result, string_t &arg, string_t &target) { + target = StringVector::AddStringOrBlob(result, arg); +} + +template +struct ArgMinMaxState : public ArgMinMaxStateBase { + using ARG_TYPE = A; + using BY_TYPE = B; + + ARG_TYPE arg; + BY_TYPE value; + + ArgMinMaxState() { + CreateValue(arg); + CreateValue(value); + } + + ~ArgMinMaxState() { + if (is_initialized) { + DestroyValue(arg); + DestroyValue(value); + is_initialized = false; + } + } +}; + +template +struct ArgMinMaxBase { + template + static void Initialize(STATE &state) { + new (&state) STATE; + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + state.~STATE(); + } + + template + static void Assign(STATE &state, const A_TYPE &x, const B_TYPE &y, const bool x_null) { + if (IGNORE_NULL) { + STATE::template AssignValue(state.arg, x); + STATE::template AssignValue(state.value, y); + } else { + state.arg_null = x_null; + if (!state.arg_null) { + STATE::template AssignValue(state.arg, x); + } + STATE::template AssignValue(state.value, y); + } + } + + template + static void Operation(STATE &state, const A_TYPE &x, const B_TYPE &y, AggregateBinaryInput &binary) { + if (!state.is_initialized) { + if (IGNORE_NULL || binary.right_mask.RowIsValid(binary.ridx)) { + Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx)); + state.is_initialized = true; + } + } else { + OP::template Execute(state, x, y, binary); + } + } + + template + static void Execute(STATE &state, A_TYPE x_data, B_TYPE y_data, AggregateBinaryInput &binary) { + if ((IGNORE_NULL || binary.right_mask.RowIsValid(binary.ridx)) && COMPARATOR::Operation(y_data, state.value)) { + Assign(state, x_data, y_data, !binary.left_mask.RowIsValid(binary.lidx)); + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.is_initialized) { + return; + } + if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { + Assign(target, source.arg, 
source.value, source.arg_null); + target.is_initialized = true; + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.is_initialized || state.arg_null) { + finalize_data.ReturnNull(); + } else { + STATE::template ReadValue(finalize_data.result, state.arg, target); + } + } + + static bool IgnoreNull() { + return IGNORE_NULL; + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) { + ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type); + } + function.arguments[0] = arguments[0]->return_type; + function.return_type = arguments[0]->return_type; + return nullptr; + } +}; + +struct SpecializedGenericArgMinMaxState { + static bool CreateExtraState(idx_t count) { + // nop extra state + return false; + } + + static void PrepareData(Vector &by, idx_t count, bool &, UnifiedVectorFormat &result) { + by.ToUnifiedFormat(count, result); + } +}; + +template +struct GenericArgMinMaxState { + static Vector CreateExtraState(idx_t count) { + return Vector(LogicalType::BLOB, count); + } + + static void PrepareData(Vector &by, idx_t count, Vector &extra_state, UnifiedVectorFormat &result) { + OrderModifiers modifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST); + CreateSortKeyHelpers::CreateSortKeyWithValidity(by, extra_state, modifiers, count); + extra_state.ToUnifiedFormat(count, result); + } +}; + +template +struct VectorArgMinMaxBase : ArgMinMaxBase { + template + static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { + auto &arg = inputs[0]; + UnifiedVectorFormat adata; + arg.ToUnifiedFormat(count, adata); + + using ARG_TYPE = typename STATE::ARG_TYPE; + using BY_TYPE = typename STATE::BY_TYPE; + auto &by = inputs[1]; + UnifiedVectorFormat bdata; + auto extra_state = UPDATE_TYPE::CreateExtraState(count); + UPDATE_TYPE::PrepareData(by, count, extra_state, bdata); + const auto bys = UnifiedVectorFormat::GetData(bdata); + + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + + STATE *last_state = nullptr; + sel_t assign_sel[STANDARD_VECTOR_SIZE]; + idx_t assign_count = 0; + + auto states = UnifiedVectorFormat::GetData(sdata); + for (idx_t i = 0; i < count; i++) { + const auto bidx = bdata.sel->get_index(i); + if (!bdata.validity.RowIsValid(bidx)) { + continue; + } + const auto bval = bys[bidx]; + + const auto aidx = adata.sel->get_index(i); + const auto arg_null = !adata.validity.RowIsValid(aidx); + if (IGNORE_NULL && arg_null) { + continue; + } + + const auto sidx = sdata.sel->get_index(i); + auto &state = *states[sidx]; + if (!state.is_initialized || COMPARATOR::template Operation(bval, state.value)) { + STATE::template AssignValue(state.value, bval); + state.arg_null = arg_null; + // micro-adaptivity: it is common we overwrite the same state repeatedly + // e.g. 
when running arg_max(val, ts) and ts is sorted in ascending order + // this check essentially says: + // "if we are overriding the same state as the last row, the last write was pointless" + // hence we skip the last write altogether + if (!arg_null) { + if (&state == last_state) { + assign_count--; + } + assign_sel[assign_count++] = UnsafeNumericCast(i); + last_state = &state; + } + state.is_initialized = true; + } + } + if (assign_count == 0) { + // no need to assign anything: nothing left to do + return; + } + Vector sort_key(LogicalType::BLOB); + auto modifiers = OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST); + // slice with a selection vector and generate sort keys + SelectionVector sel(assign_sel); + Vector sliced_input(arg, sel, assign_count); + CreateSortKeyHelpers::CreateSortKey(sliced_input, assign_count, modifiers, sort_key); + auto sort_key_data = FlatVector::GetData(sort_key); + + // now assign sort keys + for (idx_t i = 0; i < assign_count; i++) { + const auto sidx = sdata.sel->get_index(sel.get_index(i)); + auto &state = *states[sidx]; + STATE::template AssignValue(state.arg, sort_key_data[i]); + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.is_initialized) { + return; + } + if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { + STATE::template AssignValue(target.value, source.value); + target.arg_null = source.arg_null; + if (!target.arg_null) { + STATE::template AssignValue(target.arg, source.arg); + } + target.is_initialized = true; + } + } + + template + static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { + if (!state.is_initialized || state.arg_null) { + finalize_data.ReturnNull(); + } else { + CreateSortKeyHelpers::DecodeSortKey(state.arg, finalize_data.result, finalize_data.result_idx, + OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST)); + } + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) { + ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type); + } + function.arguments[0] = arguments[0]->return_type; + function.return_type = arguments[0]->return_type; + return nullptr; + } +}; + +template +AggregateFunction GetGenericArgMinMaxFunction() { + using STATE = ArgMinMaxState; + return AggregateFunction( + {LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, OP::template Update, + AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, + AggregateFunction::StateDestroy); +} + +template +AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { +#ifndef DUCKDB_SMALLER_BINARY + using STATE = ArgMinMaxState; + return AggregateFunction({type, by_type}, type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + OP::template Update, AggregateFunction::StateCombine, + AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, + AggregateFunction::StateDestroy); +#else + auto function = GetGenericArgMinMaxFunction(); + function.arguments = {type, by_type}; + function.return_type = type; + return function; +#endif +} + +#ifndef DUCKDB_SMALLER_BINARY +template +AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { + switch (by_type.InternalType()) { + case 
PhysicalType::INT32: + return GetVectorArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::INT64: + return GetVectorArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::INT128: + return GetVectorArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::DOUBLE: + return GetVectorArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::VARCHAR: + return GetVectorArgMinMaxFunctionInternal(by_type, type); + default: + throw InternalException("Unimplemented arg_min/arg_max aggregate"); + } +} +#endif + +static const vector ArgMaxByTypes() { + vector types = {LogicalType::INTEGER, LogicalType::BIGINT, LogicalType::HUGEINT, + LogicalType::DOUBLE, LogicalType::VARCHAR, LogicalType::DATE, + LogicalType::TIMESTAMP, LogicalType::TIMESTAMP_TZ, LogicalType::BLOB}; + return types; +} + +template +void AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { + auto by_types = ArgMaxByTypes(); + for (const auto &by_type : by_types) { +#ifndef DUCKDB_SMALLER_BINARY + fun.AddFunction(GetVectorArgMinMaxFunctionBy(by_type, type)); +#else + fun.AddFunction(GetVectorArgMinMaxFunctionInternal(by_type, type)); +#endif + } +} + +template +AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { +#ifndef DUCKDB_SMALLER_BINARY + using STATE = ArgMinMaxState; + auto function = + AggregateFunction::BinaryAggregate( + type, by_type, type); + if (type.InternalType() == PhysicalType::VARCHAR || by_type.InternalType() == PhysicalType::VARCHAR) { + function.destructor = AggregateFunction::StateDestroy; + } + function.bind = OP::Bind; +#else + auto function = GetGenericArgMinMaxFunction(); + function.arguments = {type, by_type}; + function.return_type = type; +#endif + return function; +} + +#ifndef DUCKDB_SMALLER_BINARY +template +AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { + switch (by_type.InternalType()) { + case PhysicalType::INT32: + return GetArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::INT64: + return GetArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::INT128: + return GetArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::DOUBLE: + return GetArgMinMaxFunctionInternal(by_type, type); + case PhysicalType::VARCHAR: + return GetArgMinMaxFunctionInternal(by_type, type); + default: + throw InternalException("Unimplemented arg_min/arg_max by aggregate"); + } +} +#endif + +template +void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { + auto by_types = ArgMaxByTypes(); + for (const auto &by_type : by_types) { +#ifndef DUCKDB_SMALLER_BINARY + fun.AddFunction(GetArgMinMaxFunctionBy(by_type, type)); +#else + fun.AddFunction(GetArgMinMaxFunctionInternal(by_type, type)); +#endif + } +} + +template +static AggregateFunction GetDecimalArgMinMaxFunction(const LogicalType &by_type, const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::DECIMAL); +#ifndef DUCKDB_SMALLER_BINARY + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetArgMinMaxFunctionBy(by_type, type); + case PhysicalType::INT32: + return GetArgMinMaxFunctionBy(by_type, type); + case PhysicalType::INT64: + return GetArgMinMaxFunctionBy(by_type, type); + default: + return GetArgMinMaxFunctionBy(by_type, type); + } +#else + return GetArgMinMaxFunctionInternal(by_type, type); +#endif +} + +template +static unique_ptr BindDecimalArgMinMax(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto 
decimal_type = arguments[0]->return_type; + auto by_type = arguments[1]->return_type; + + // To avoid a combinatorial explosion, cast the ordering argument to one from the list + auto by_types = ArgMaxByTypes(); + idx_t best_target = DConstants::INVALID_INDEX; + int64_t lowest_cost = NumericLimits::Maximum(); + for (idx_t i = 0; i < by_types.size(); ++i) { + // Before falling back to casting, check for a physical type match for the by_type + if (by_types[i].InternalType() == by_type.InternalType()) { + lowest_cost = 0; + best_target = DConstants::INVALID_INDEX; + break; + } + + auto cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(by_type, by_types[i]); + if (cast_cost < 0) { + continue; + } + if (cast_cost < lowest_cost) { + best_target = i; + } + } + + if (best_target != DConstants::INVALID_INDEX) { + by_type = by_types[best_target]; + } + + auto name = std::move(function.name); + function = GetDecimalArgMinMaxFunction(by_type, decimal_type); + function.name = std::move(name); + function.return_type = decimal_type; + return nullptr; +} + +template +void AddDecimalArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &by_type) { + fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, BindDecimalArgMinMax)); +} + +template +void AddGenericArgMinMaxFunction(AggregateFunctionSet &fun) { + fun.AddFunction(GetGenericArgMinMaxFunction()); +} + +template +static void AddArgMinMaxFunctions(AggregateFunctionSet &fun) { + using GENERIC_VECTOR_OP = VectorArgMinMaxBase>; +#ifndef DUCKDB_SMALLER_BINARY + using OP = ArgMinMaxBase; + using VECTOR_OP = VectorArgMinMaxBase; +#else + using OP = GENERIC_VECTOR_OP; + using VECTOR_OP = GENERIC_VECTOR_OP; +#endif + AddArgMinMaxFunctionBy(fun, LogicalType::INTEGER); + AddArgMinMaxFunctionBy(fun, LogicalType::BIGINT); + AddArgMinMaxFunctionBy(fun, LogicalType::DOUBLE); + AddArgMinMaxFunctionBy(fun, LogicalType::VARCHAR); + AddArgMinMaxFunctionBy(fun, LogicalType::DATE); + AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP); + AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP_TZ); + AddArgMinMaxFunctionBy(fun, LogicalType::BLOB); + + auto by_types = ArgMaxByTypes(); + for (const auto &by_type : by_types) { + AddDecimalArgMinMaxFunctionBy(fun, by_type); + } + + AddVectorArgMinMaxFunctionBy(fun, LogicalType::ANY); + + // we always use LessThan when using sort keys because the ORDER_TYPE takes care of selecting the lowest or highest + AddGenericArgMinMaxFunction(fun); +} + +//------------------------------------------------------------------------------ +// ArgMinMax(N) Function +//------------------------------------------------------------------------------ +//------------------------------------------------------------------------------ +// State +//------------------------------------------------------------------------------ + +template +class ArgMinMaxNState { +public: + using VAL_TYPE = A; + using ARG_TYPE = B; + + using V = typename VAL_TYPE::TYPE; + using K = typename ARG_TYPE::TYPE; + + BinaryAggregateHeap heap; + + bool is_initialized = false; + void Initialize(idx_t nval) { + heap.Initialize(nval); + is_initialized = true; + } +}; + +//------------------------------------------------------------------------------ +// Operation +//------------------------------------------------------------------------------ +template +static void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, + 
idx_t count) { + + auto &val_vector = inputs[0]; + auto &arg_vector = inputs[1]; + auto &n_vector = inputs[2]; + + UnifiedVectorFormat val_format; + UnifiedVectorFormat arg_format; + UnifiedVectorFormat n_format; + UnifiedVectorFormat state_format; + + auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); + auto arg_extra_state = STATE::ARG_TYPE::CreateExtraState(arg_vector, count); + + STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format); + STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format); + + n_vector.ToUnifiedFormat(count, n_format); + state_vector.ToUnifiedFormat(count, state_format); + + auto states = UnifiedVectorFormat::GetData(state_format); + + for (idx_t i = 0; i < count; i++) { + const auto arg_idx = arg_format.sel->get_index(i); + const auto val_idx = val_format.sel->get_index(i); + if (!arg_format.validity.RowIsValid(arg_idx) || !val_format.validity.RowIsValid(val_idx)) { + continue; + } + const auto state_idx = state_format.sel->get_index(i); + auto &state = *states[state_idx]; + + // Initialize the heap if necessary and add the input to the heap + if (!state.is_initialized) { + static constexpr int64_t MAX_N = 1000000; + const auto nidx = n_format.sel->get_index(i); + if (!n_format.validity.RowIsValid(nidx)) { + throw InvalidInputException("Invalid input for arg_min/arg_max: n value cannot be NULL"); + } + const auto nval = UnifiedVectorFormat::GetData(n_format)[nidx]; + if (nval <= 0) { + throw InvalidInputException("Invalid input for arg_min/arg_max: n value must be > 0"); + } + if (nval >= MAX_N) { + throw InvalidInputException("Invalid input for arg_min/arg_max: n value must be < %d", MAX_N); + } + state.Initialize(UnsafeNumericCast(nval)); + } + + // Now add the input to the heap + auto arg_val = STATE::ARG_TYPE::Create(arg_format, arg_idx); + auto val_val = STATE::VAL_TYPE::Create(val_format, val_idx); + + state.heap.Insert(aggr_input.allocator, arg_val, val_val); + } +} + +//------------------------------------------------------------------------------ +// Bind +//------------------------------------------------------------------------------ +template +static void SpecializeArgMinMaxNFunction(AggregateFunction &function) { + using STATE = ArgMinMaxNState; + using OP = MinMaxNOperation; + + function.state_size = AggregateFunction::StateSize; + function.initialize = AggregateFunction::StateInitialize; + function.combine = AggregateFunction::StateCombine; + function.destructor = AggregateFunction::StateDestroy; + + function.finalize = MinMaxNOperation::Finalize; + function.update = ArgMinMaxNUpdate; +} + +template +static void SpecializeArgMinMaxNFunction(PhysicalType arg_type, AggregateFunction &function) { + switch (arg_type) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::VARCHAR: + SpecializeArgMinMaxNFunction(function); + break; + case PhysicalType::INT32: + SpecializeArgMinMaxNFunction, COMPARATOR>(function); + break; + case PhysicalType::INT64: + SpecializeArgMinMaxNFunction, COMPARATOR>(function); + break; + case PhysicalType::FLOAT: + SpecializeArgMinMaxNFunction, COMPARATOR>(function); + break; + case PhysicalType::DOUBLE: + SpecializeArgMinMaxNFunction, COMPARATOR>(function); + break; +#endif + default: + SpecializeArgMinMaxNFunction(function); + break; + } +} + +template +static void SpecializeArgMinMaxNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) { + switch (val_type) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::VARCHAR: + 
SpecializeArgMinMaxNFunction(arg_type, function); + break; + case PhysicalType::INT32: + SpecializeArgMinMaxNFunction, COMPARATOR>(arg_type, function); + break; + case PhysicalType::INT64: + SpecializeArgMinMaxNFunction, COMPARATOR>(arg_type, function); + break; + case PhysicalType::FLOAT: + SpecializeArgMinMaxNFunction, COMPARATOR>(arg_type, function); + break; + case PhysicalType::DOUBLE: + SpecializeArgMinMaxNFunction, COMPARATOR>(arg_type, function); + break; +#endif + default: + SpecializeArgMinMaxNFunction(arg_type, function); + break; + } +} + +template +unique_ptr ArgMinMaxNBind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + for (auto &arg : arguments) { + if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + } + + const auto val_type = arguments[0]->return_type.InternalType(); + const auto arg_type = arguments[1]->return_type.InternalType(); + + // Specialize the function based on the input types + SpecializeArgMinMaxNFunction(val_type, arg_type, function); + + function.return_type = LogicalType::LIST(arguments[0]->return_type); + return nullptr; +} + +template +static void AddArgMinMaxNFunction(AggregateFunctionSet &set) { + AggregateFunction function({LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::ANY), nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, ArgMinMaxNBind); + + return set.AddFunction(function); +} + +//------------------------------------------------------------------------------ +// Function Registration +//------------------------------------------------------------------------------ + +AggregateFunctionSet ArgMinFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun); + AddArgMinMaxNFunction(fun); + return fun; +} + +AggregateFunctionSet ArgMaxFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun); + AddArgMinMaxNFunction(fun); + return fun; +} + +AggregateFunctionSet ArgMinNullFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun); + return fun; +} + +AggregateFunctionSet ArgMaxNullFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun); + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp new file mode 100644 index 00000000..241d2569 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp @@ -0,0 +1,231 @@ +#include "core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/aggregate_executor.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/cast_helpers.hpp" + +namespace duckdb { + +template +struct BitState { + using TYPE = T; + bool is_set; + T value; +}; + +template +static AggregateFunction GetBitfieldUnaryAggregate(LogicalType type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return AggregateFunction::UnaryAggregate, int8_t, int8_t, OP>(type, type); + case LogicalTypeId::SMALLINT: + return AggregateFunction::UnaryAggregate, int16_t, int16_t, OP>(type, type); + case LogicalTypeId::INTEGER: + return AggregateFunction::UnaryAggregate, int32_t, int32_t, OP>(type, type); + case LogicalTypeId::BIGINT: + return 
AggregateFunction::UnaryAggregate, int64_t, int64_t, OP>(type, type); + case LogicalTypeId::HUGEINT: + return AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, OP>(type, type); + case LogicalTypeId::UTINYINT: + return AggregateFunction::UnaryAggregate, uint8_t, uint8_t, OP>(type, type); + case LogicalTypeId::USMALLINT: + return AggregateFunction::UnaryAggregate, uint16_t, uint16_t, OP>(type, type); + case LogicalTypeId::UINTEGER: + return AggregateFunction::UnaryAggregate, uint32_t, uint32_t, OP>(type, type); + case LogicalTypeId::UBIGINT: + return AggregateFunction::UnaryAggregate, uint64_t, uint64_t, OP>(type, type); + case LogicalTypeId::UHUGEINT: + return AggregateFunction::UnaryAggregate, uhugeint_t, uhugeint_t, OP>(type, type); + default: + throw InternalException("Unimplemented bitfield type for unary aggregate"); + } +} + +struct BitwiseOperation { + template + static void Initialize(STATE &state) { + // If there are no matching rows, returns a null value. + state.is_set = false; + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + if (!state.is_set) { + OP::template Assign(state, input); + state.is_set = true; + } else { + OP::template Execute(state, input); + } + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + OP::template Operation(state, input, unary_input); + } + + template + static void Assign(STATE &state, INPUT_TYPE input) { + state.value = typename STATE::TYPE(input); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.is_set) { + // source is NULL, nothing to do. + return; + } + if (!target.is_set) { + // target is NULL, use source value directly. 
+			OP::template Assign<typename STATE::TYPE, STATE>(target, source.value);
+			target.is_set = true;
+		} else {
+			OP::template Execute<typename STATE::TYPE, STATE>(target, source.value);
+		}
+	}
+
+	template <class T, class STATE>
+	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
+		if (!state.is_set) {
+			finalize_data.ReturnNull();
+		} else {
+			target = T(state.value);
+		}
+	}
+
+	static bool IgnoreNull() {
+		return true;
+	}
+};
+
+struct BitAndOperation : public BitwiseOperation {
+	template <class INPUT_TYPE, class STATE>
+	static void Execute(STATE &state, INPUT_TYPE input) {
+		state.value &= typename STATE::TYPE(input);
+	}
+};
+
+struct BitOrOperation : public BitwiseOperation {
+	template <class INPUT_TYPE, class STATE>
+	static void Execute(STATE &state, INPUT_TYPE input) {
+		state.value |= typename STATE::TYPE(input);
+	}
+};
+
+struct BitXorOperation : public BitwiseOperation {
+	template <class INPUT_TYPE, class STATE>
+	static void Execute(STATE &state, INPUT_TYPE input) {
+		state.value ^= typename STATE::TYPE(input);
+	}
+
+	template <class INPUT_TYPE, class STATE, class OP>
+	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
+	                              idx_t count) {
+		for (idx_t i = 0; i < count; i++) {
+			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
+		}
+	}
+};
+
+struct BitStringBitwiseOperation : public BitwiseOperation {
+	template <class STATE>
+	static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
+		if (state.is_set && !state.value.IsInlined()) {
+			delete[] state.value.GetData();
+		}
+	}
+
+	template <class INPUT_TYPE, class STATE>
+	static void Assign(STATE &state, INPUT_TYPE input) {
+		D_ASSERT(state.is_set == false);
+		if (input.IsInlined()) {
+			state.value = input;
+		} else { // non-inlined string, need to allocate space for it
+			auto len = input.GetSize();
+			auto ptr = new char[len];
+			memcpy(ptr, input.GetData(), len);
+
+			state.value = string_t(ptr, UnsafeNumericCast<uint32_t>(len));
+		}
+	}
+
+	template <class T, class STATE>
+	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
+		if (!state.is_set) {
+			finalize_data.ReturnNull();
+		} else {
+			target = finalize_data.ReturnString(state.value);
+		}
+	}
+};
+
+struct BitStringAndOperation : public BitStringBitwiseOperation {
+
+	template <class INPUT_TYPE, class STATE>
+	static void Execute(STATE &state, INPUT_TYPE input) {
+		Bit::BitwiseAnd(input, state.value, state.value);
+	}
+};
+
+struct BitStringOrOperation : public BitStringBitwiseOperation {
+
+	template <class INPUT_TYPE, class STATE>
+	static void Execute(STATE &state, INPUT_TYPE input) {
+		Bit::BitwiseOr(input, state.value, state.value);
+	}
+};
+
+struct BitStringXorOperation : public BitStringBitwiseOperation {
+	template <class INPUT_TYPE, class STATE>
+	static void Execute(STATE &state, INPUT_TYPE input) {
+		Bit::BitwiseXor(input, state.value, state.value);
+	}
+
+	template <class INPUT_TYPE, class STATE, class OP>
+	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
+	                              idx_t count) {
+		for (idx_t i = 0; i < count; i++) {
+			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
+		}
+	}
+};
+
+AggregateFunctionSet BitAndFun::GetFunctions() {
+	AggregateFunctionSet bit_and;
+	for (auto &type : LogicalType::Integral()) {
+		bit_and.AddFunction(GetBitfieldUnaryAggregate<BitAndOperation>(type));
+	}
+
+	bit_and.AddFunction(
+	    AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringAndOperation>(
+	        LogicalType::BIT, LogicalType::BIT));
+	return bit_and;
+}
+
+AggregateFunctionSet BitOrFun::GetFunctions() {
+	AggregateFunctionSet bit_or;
+	for (auto &type : LogicalType::Integral()) {
+		bit_or.AddFunction(GetBitfieldUnaryAggregate<BitOrOperation>(type));
+	}
+	bit_or.AddFunction(
+	    AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringOrOperation>(
+	        LogicalType::BIT, LogicalType::BIT));
+	return bit_or;
+}
+
+AggregateFunctionSet 
BitXorFun::GetFunctions() { + AggregateFunctionSet bit_xor; + for (auto &type : LogicalType::Integral()) { + bit_xor.AddFunction(GetBitfieldUnaryAggregate(type)); + } + bit_xor.AddFunction( + AggregateFunction::UnaryAggregateDestructor, string_t, string_t, BitStringXorOperation>( + LogicalType::BIT, LogicalType::BIT)); + return bit_xor; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp new file mode 100644 index 00000000..c9e39983 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp @@ -0,0 +1,320 @@ +#include "core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/vector_operations/aggregate_executor.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/uhugeint.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/serializer.hpp" + +namespace duckdb { + +template +struct BitAggState { + bool is_set; + string_t value; + INPUT_TYPE min; + INPUT_TYPE max; +}; + +struct BitstringAggBindData : public FunctionData { + Value min; + Value max; + + BitstringAggBindData() { + } + + BitstringAggBindData(Value min, Value max) : min(std::move(min)), max(std::move(max)) { + } + + unique_ptr Copy() const override { + return make_uniq(*this); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + if (min.IsNull() && other.min.IsNull() && max.IsNull() && other.max.IsNull()) { + return true; + } + if (Value::NotDistinctFrom(min, other.min) && Value::NotDistinctFrom(max, other.max)) { + return true; + } + return false; + } + + static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty(100, "min", bind_data.min); + serializer.WriteProperty(101, "max", bind_data.max); + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &) { + Value min; + Value max; + deserializer.ReadProperty(100, "min", min); + deserializer.ReadProperty(101, "max", max); + return make_uniq(min, max); + } +}; + +struct BitStringAggOperation { + static constexpr const idx_t MAX_BIT_RANGE = 1000000000; // for now capped at 1 billion bits + + template + static void Initialize(STATE &state) { + state.is_set = false; + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + auto &bind_agg_data = unary_input.input.bind_data->template Cast(); + if (!state.is_set) { + if (bind_agg_data.min.IsNull() || bind_agg_data.max.IsNull()) { + throw BinderException( + "Could not retrieve required statistics. 
Alternatively, try by providing the statistics " + "explicitly: BITSTRING_AGG(col, min, max) "); + } + state.min = bind_agg_data.min.GetValue(); + state.max = bind_agg_data.max.GetValue(); + if (state.min > state.max) { + throw InvalidInputException("Invalid explicit bitstring range: Minimum (%s) > maximum (%s)", + NumericHelper::ToString(state.min), NumericHelper::ToString(state.max)); + } + idx_t bit_range = + GetRange(bind_agg_data.min.GetValue(), bind_agg_data.max.GetValue()); + if (bit_range > MAX_BIT_RANGE) { + throw OutOfRangeException( + "The range between min and max value (%s <-> %s) is too large for bitstring aggregation", + NumericHelper::ToString(state.min), NumericHelper::ToString(state.max)); + } + idx_t len = Bit::ComputeBitstringLen(bit_range); + auto target = len > string_t::INLINE_LENGTH ? string_t(new char[len], UnsafeNumericCast(len)) + : string_t(UnsafeNumericCast(len)); + Bit::SetEmptyBitString(target, bit_range); + + state.value = target; + state.is_set = true; + } + if (input >= state.min && input <= state.max) { + Execute(state, input, bind_agg_data.min.GetValue()); + } else { + throw OutOfRangeException("Value %s is outside of provided min and max range (%s <-> %s)", + NumericHelper::ToString(input), NumericHelper::ToString(state.min), + NumericHelper::ToString(state.max)); + } + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + OP::template Operation(state, input, unary_input); + } + + template + static idx_t GetRange(INPUT_TYPE min, INPUT_TYPE max) { + if (min > max) { + throw InvalidInputException("Invalid explicit bitstring range: Minimum (%d) > maximum (%d)", min, max); + } + INPUT_TYPE result; + if (!TrySubtractOperator::Operation(max, min, result)) { + return NumericLimits::Maximum(); + } + auto val = NumericCast(result); + if (val == NumericLimits::Maximum()) { + return val; + } + return val + 1; + } + + template + static void Execute(STATE &state, INPUT_TYPE input, INPUT_TYPE min) { + Bit::SetBit(state.value, UnsafeNumericCast(input - min), 1); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.is_set) { + return; + } + if (!target.is_set) { + Assign(target, source.value); + target.is_set = true; + target.min = source.min; + target.max = source.max; + } else { + Bit::BitwiseOr(source.value, target.value, target.value); + } + } + + template + static void Assign(STATE &state, INPUT_TYPE input) { + D_ASSERT(state.is_set == false); + if (input.IsInlined()) { + state.value = input; + } else { // non-inlined string, need to allocate space for it + auto len = input.GetSize(); + auto ptr = new char[len]; + memcpy(ptr, input.GetData(), len); + state.value = string_t(ptr, UnsafeNumericCast(len)); + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.is_set) { + finalize_data.ReturnNull(); + } else { + target = StringVector::AddStringOrBlob(finalize_data.result, state.value); + } + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.is_set && !state.value.IsInlined()) { + delete[] state.value.GetData(); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +template <> +void BitStringAggOperation::Execute(BitAggState &state, hugeint_t input, hugeint_t min) { + idx_t val; + if (Hugeint::TryCast(input - min, val)) { + Bit::SetBit(state.value, val, 1); + } else { + throw 
OutOfRangeException("Range too large for bitstring aggregation"); + } +} + +template <> +idx_t BitStringAggOperation::GetRange(hugeint_t min, hugeint_t max) { + hugeint_t result; + if (!TrySubtractOperator::Operation(max, min, result)) { + return NumericLimits::Maximum(); + } + idx_t range; + if (!Hugeint::TryCast(result + 1, range) || result == NumericLimits::Maximum()) { + return NumericLimits::Maximum(); + } + return range; +} + +template <> +void BitStringAggOperation::Execute(BitAggState &state, uhugeint_t input, uhugeint_t min) { + idx_t val; + if (Uhugeint::TryCast(input - min, val)) { + Bit::SetBit(state.value, val, 1); + } else { + throw OutOfRangeException("Range too large for bitstring aggregation"); + } +} + +template <> +idx_t BitStringAggOperation::GetRange(uhugeint_t min, uhugeint_t max) { + uhugeint_t result; + if (!TrySubtractOperator::Operation(max, min, result)) { + return NumericLimits::Maximum(); + } + idx_t range; + if (!Uhugeint::TryCast(result + 1, range) || result == NumericLimits::Maximum()) { + return NumericLimits::Maximum(); + } + return range; +} + +unique_ptr BitstringPropagateStats(ClientContext &context, BoundAggregateExpression &expr, + AggregateStatisticsInput &input) { + + if (NumericStats::HasMinMax(input.child_stats[0])) { + auto &bind_agg_data = input.bind_data->Cast(); + bind_agg_data.min = NumericStats::Min(input.child_stats[0]); + bind_agg_data.max = NumericStats::Max(input.child_stats[0]); + } + return nullptr; +} + +unique_ptr BindBitstringAgg(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + if (arguments.size() == 3) { + if (!arguments[1]->IsFoldable() || !arguments[2]->IsFoldable()) { + throw BinderException("bitstring_agg requires a constant min and max argument"); + } + auto min = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + auto max = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); + Function::EraseArgument(function, arguments, 2); + Function::EraseArgument(function, arguments, 1); + return make_uniq(min, max); + } + return make_uniq(); +} + +template +static void BindBitString(AggregateFunctionSet &bitstring_agg, const LogicalTypeId &type) { + auto function = + AggregateFunction::UnaryAggregateDestructor, TYPE, string_t, BitStringAggOperation>( + type, LogicalType::BIT); + function.bind = BindBitstringAgg; // create new a 'BitstringAggBindData' + function.serialize = BitstringAggBindData::Serialize; + function.deserialize = BitstringAggBindData::Deserialize; + function.statistics = BitstringPropagateStats; // stores min and max from column stats in BitstringAggBindData + bitstring_agg.AddFunction(function); // uses the BitstringAggBindData to access statistics for creating bitstring + function.arguments = {type, type, type}; + function.statistics = nullptr; // min and max are provided as arguments + bitstring_agg.AddFunction(function); +} + +void GetBitStringAggregate(const LogicalType &type, AggregateFunctionSet &bitstring_agg) { + switch (type.id()) { + case LogicalType::TINYINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::SMALLINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::INTEGER: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::BIGINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::HUGEINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::UTINYINT: { + return BindBitString(bitstring_agg, type.id()); + } + case 
LogicalType::USMALLINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::UINTEGER: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::UBIGINT: { + return BindBitString(bitstring_agg, type.id()); + } + case LogicalType::UHUGEINT: { + return BindBitString(bitstring_agg, type.id()); + } + default: + throw InternalException("Unimplemented bitstring aggregate"); + } +} + +AggregateFunctionSet BitstringAggFun::GetFunctions() { + AggregateFunctionSet bitstring_agg("bitstring_agg"); + for (auto &type : LogicalType::Integral()) { + GetBitStringAggregate(type, bitstring_agg); + } + return bitstring_agg; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp new file mode 100644 index 00000000..9b781f84 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp @@ -0,0 +1,110 @@ +#include "core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct BoolState { + bool empty; + bool val; +}; + +struct BoolAndFunFunction { + template + static void Initialize(STATE &state) { + state.val = true; + state.empty = true; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.val = target.val && source.val; + target.empty = target.empty && source.empty; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.empty) { + finalize_data.ReturnNull(); + return; + } + target = state.val; + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + state.empty = false; + state.val = input && state.val; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + static bool IgnoreNull() { + return true; + } +}; + +struct BoolOrFunFunction { + template + static void Initialize(STATE &state) { + state.val = false; + state.empty = true; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.val = target.val || source.val; + target.empty = target.empty && source.empty; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.empty) { + finalize_data.ReturnNull(); + return; + } + target = state.val; + } + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + state.empty = false; + state.val = input || state.val; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction BoolOrFun::GetFunction() { + auto fun = AggregateFunction::UnaryAggregate( + LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.distinct_dependent = 
AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT; + return fun; +} + +AggregateFunction BoolAndFun::GetFunction() { + auto fun = AggregateFunction::UnaryAggregate( + LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp new file mode 100644 index 00000000..4f9f6f30 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp @@ -0,0 +1,113 @@ +#include "core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/algorithm.hpp" + +namespace duckdb { + +struct KurtosisState { + idx_t n; + double sum; + double sum_sqr; + double sum_cub; + double sum_four; +}; + +struct KurtosisFlagBiasCorrection {}; + +struct KurtosisFlagNoBiasCorrection {}; + +template +struct KurtosisOperation { + template + static void Initialize(STATE &state) { + state.n = 0; + state.sum = state.sum_sqr = state.sum_cub = state.sum_four = 0.0; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + state.n++; + state.sum += input; + state.sum_sqr += pow(input, 2); + state.sum_cub += pow(input, 3); + state.sum_four += pow(input, 4); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (source.n == 0) { + return; + } + target.n += source.n; + target.sum += source.sum; + target.sum_sqr += source.sum_sqr; + target.sum_cub += source.sum_cub; + target.sum_four += source.sum_four; + } + + template + static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) { + auto n = (double)state.n; + if (n <= 1) { + finalize_data.ReturnNull(); + return; + } + if (std::is_same::value && n <= 3) { + finalize_data.ReturnNull(); + return; + } + double temp = 1 / n; + //! 
This is necessary on 32-bit Linux builds + long double temp_aux = 1 / n; + if (state.sum_sqr - state.sum * state.sum * temp == 0 || + state.sum_sqr - state.sum * state.sum * temp_aux == 0) { + finalize_data.ReturnNull(); + return; + } + double m4 = + temp * (state.sum_four - 4 * state.sum_cub * state.sum * temp + + 6 * state.sum_sqr * state.sum * state.sum * temp * temp - 3 * pow(state.sum, 4) * pow(temp, 3)); + + double m2 = temp * (state.sum_sqr - state.sum * state.sum * temp); + if (m2 <= 0) { // m2 shouldn't be below 0 but floating points are weird + finalize_data.ReturnNull(); + return; + } + if (std::is_same::value) { + target = m4 / (m2 * m2) - 3; + } else { + target = (n - 1) * ((n + 1) * m4 / (m2 * m2) - 3 * (n - 1)) / ((n - 2) * (n - 3)); + } + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("Kurtosis is out of range!"); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction KurtosisFun::GetFunction() { + return AggregateFunction::UnaryAggregate>(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +AggregateFunction KurtosisPopFun::GetFunction() { + return AggregateFunction::UnaryAggregate>(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/product.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/product.cpp new file mode 100644 index 00000000..324893f6 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/distributive/product.cpp @@ -0,0 +1,61 @@ +#include "core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct ProductState { + bool empty; + double val; +}; + +struct ProductFunction { + template + static void Initialize(STATE &state) { + state.val = 1; + state.empty = true; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.val *= source.val; + target.empty = target.empty && source.empty; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.empty) { + finalize_data.ReturnNull(); + return; + } + target = state.val; + } + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + if (state.empty) { + state.empty = false; + } + state.val *= input; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction ProductFun::GetFunction() { + return AggregateFunction::UnaryAggregate( + LogicalType(LogicalTypeId::DOUBLE), LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/skew.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/skew.cpp new file mode 100644 index 00000000..12f23761 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/distributive/skew.cpp @@ -0,0 +1,86 @@ +#include "core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include 
"duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/algorithm.hpp" + +namespace duckdb { + +struct SkewState { + size_t n; + double sum; + double sum_sqr; + double sum_cub; +}; + +struct SkewnessOperation { + template + static void Initialize(STATE &state) { + state.n = 0; + state.sum = state.sum_sqr = state.sum_cub = 0; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + state.n++; + state.sum += input; + state.sum_sqr += pow(input, 2); + state.sum_cub += pow(input, 3); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (source.n == 0) { + return; + } + + target.n += source.n; + target.sum += source.sum; + target.sum_sqr += source.sum_sqr; + target.sum_cub += source.sum_cub; + } + + template + static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) { + if (state.n <= 2) { + finalize_data.ReturnNull(); + return; + } + double n = state.n; + double temp = 1 / n; + auto p = std::pow(temp * (state.sum_sqr - state.sum * state.sum * temp), 3); + if (p < 0) { + p = 0; // Shouldn't be below 0 but floating points are weird + } + double div = std::sqrt(p); + if (div == 0) { + target = NAN; + return; + } + double temp1 = std::sqrt(n * (n - 1)) / (n - 2); + target = temp1 * temp * + (state.sum_cub - 3 * state.sum_sqr * state.sum * temp + 2 * pow(state.sum, 3) * temp * temp) / div; + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("SKEW is out of range!"); + } + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction SkewnessFun::GetFunction() { + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp new file mode 100644 index 00000000..b694a236 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp @@ -0,0 +1,175 @@ +#include "core_functions/aggregate/distributive_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/null_value.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +struct StringAggState { + idx_t size; + idx_t alloc_size; + char *dataptr; +}; + +struct StringAggBindData : public FunctionData { + explicit StringAggBindData(string sep_p) : sep(std::move(sep_p)) { + } + + string sep; + + unique_ptr Copy() const override { + return make_uniq(sep); + } + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return sep == other.sep; + } +}; + +struct StringAggFunction { + template + static void Initialize(STATE &state) { + state.dataptr = nullptr; + state.alloc_size = 0; + state.size = 0; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.dataptr) { + 
finalize_data.ReturnNull(); + } else { + target = StringVector::AddString(finalize_data.result, state.dataptr, state.size); + } + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.dataptr) { + delete[] state.dataptr; + } + } + + static bool IgnoreNull() { + return true; + } + + static inline void PerformOperation(StringAggState &state, const char *str, const char *sep, idx_t str_size, + idx_t sep_size) { + if (!state.dataptr) { + // first iteration: allocate space for the string and copy it into the state + state.alloc_size = MaxValue(8, NextPowerOfTwo(str_size)); + state.dataptr = new char[state.alloc_size]; + state.size = str_size; + memcpy(state.dataptr, str, str_size); + } else { + // subsequent iteration: first check if we have space to place the string and separator + idx_t required_size = state.size + str_size + sep_size; + if (required_size > state.alloc_size) { + // no space! allocate extra space + while (state.alloc_size < required_size) { + state.alloc_size *= 2; + } + auto new_data = new char[state.alloc_size]; + memcpy(new_data, state.dataptr, state.size); + delete[] state.dataptr; + state.dataptr = new_data; + } + // copy the separator + memcpy(state.dataptr + state.size, sep, sep_size); + state.size += sep_size; + // copy the string + memcpy(state.dataptr + state.size, str, str_size); + state.size += str_size; + } + } + + static inline void PerformOperation(StringAggState &state, string_t str, optional_ptr data_p) { + auto &data = data_p->Cast(); + PerformOperation(state, str.GetData(), data.sep.c_str(), str.GetSize(), data.sep.size()); + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + PerformOperation(state, input, unary_input.input.bind_data); + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + if (!source.dataptr) { + // source is not set: skip combining + return; + } + PerformOperation(target, string_t(source.dataptr, UnsafeNumericCast(source.size)), + aggr_input_data.bind_data); + } +}; + +unique_ptr StringAggBind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + if (arguments.size() == 1) { + // single argument: default to comma + return make_uniq(","); + } + D_ASSERT(arguments.size() == 2); + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw BinderException("Separator argument to StringAgg must be a constant"); + } + auto separator_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + string separator_string = ","; + if (separator_val.IsNull()) { + arguments[0] = make_uniq(Value(LogicalType::VARCHAR)); + } else { + separator_string = separator_val.ToString(); + } + Function::EraseArgument(function, arguments, arguments.size() - 1); + return make_uniq(std::move(separator_string)); +} + +static void StringAggSerialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto bind_data = bind_data_p->Cast(); + serializer.WriteProperty(100, "separator", bind_data.sep); +} + +unique_ptr StringAggDeserialize(Deserializer &deserializer, AggregateFunction &bound_function) { + auto sep = deserializer.ReadProperty(100, 
"separator"); + return make_uniq(std::move(sep)); +} + +AggregateFunctionSet StringAggFun::GetFunctions() { + AggregateFunctionSet string_agg; + AggregateFunction string_agg_param( + {LogicalType::ANY_PARAMS(LogicalType::VARCHAR)}, LogicalType::VARCHAR, + AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + AggregateFunction::UnaryScatterUpdate, + AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, + AggregateFunction::UnaryUpdate, StringAggBind, + AggregateFunction::StateDestroy); + string_agg_param.serialize = StringAggSerialize; + string_agg_param.deserialize = StringAggDeserialize; + string_agg.AddFunction(string_agg_param); + string_agg_param.arguments.emplace_back(LogicalType::VARCHAR); + string_agg.AddFunction(string_agg_param); + return string_agg; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp new file mode 100644 index 00000000..be37d5df --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp @@ -0,0 +1,245 @@ +#include "core_functions/aggregate/distributive_functions.hpp" +#include "core_functions/aggregate/sum_helpers.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +struct SumSetOperation { + template + static void Initialize(STATE &state) { + state.Initialize(); + } + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target.Combine(source); + } + template + static void AddValues(STATE &state, idx_t count) { + state.isset = true; + } +}; + +struct IntegerSumOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.isset) { + finalize_data.ReturnNull(); + } else { + target = Hugeint::Convert(state.value); + } + } +}; + +struct SumToHugeintOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.isset) { + finalize_data.ReturnNull(); + } else { + target = state.value; + } + } +}; + +template +struct DoubleSumOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.isset) { + finalize_data.ReturnNull(); + } else { + target = state.value; + } + } +}; + +using NumericSumOperation = DoubleSumOperation; +using KahanSumOperation = DoubleSumOperation; + +struct HugeintSumOperation : public BaseSumOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.isset) { + finalize_data.ReturnNull(); + } else { + target = state.value; + } + } +}; + +unique_ptr SumNoOverflowBind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + throw BinderException("sum_no_overflow is for internal use only!"); +} + +void SumNoOverflowSerialize(Serializer &serializer, const optional_ptr bind_data, + const AggregateFunction &function) { + return; +} + +unique_ptr SumNoOverflowDeserialize(Deserializer &deserializer, AggregateFunction &function) { + function.return_type = deserializer.Get(); + return nullptr; +} + +AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) { + switch (type) { + case PhysicalType::INT32: { + auto 
function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, IntegerSumOperation>( + LogicalType::INTEGER, LogicalType::HUGEINT); + function.name = "sum_no_overflow"; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + function.bind = SumNoOverflowBind; + function.serialize = SumNoOverflowSerialize; + function.deserialize = SumNoOverflowDeserialize; + return function; + } + case PhysicalType::INT64: { + auto function = AggregateFunction::UnaryAggregate, int64_t, hugeint_t, IntegerSumOperation>( + LogicalType::BIGINT, LogicalType::HUGEINT); + function.name = "sum_no_overflow"; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + function.bind = SumNoOverflowBind; + function.serialize = SumNoOverflowSerialize; + function.deserialize = SumNoOverflowDeserialize; + return function; + } + default: + throw BinderException("Unsupported internal type for sum_no_overflow"); + } +} + +AggregateFunction GetSumAggregateNoOverflowDecimal() { + AggregateFunction aggr({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, nullptr, + nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, SumNoOverflowBind); + aggr.serialize = SumNoOverflowSerialize; + aggr.deserialize = SumNoOverflowDeserialize; + return aggr; +} + +unique_ptr SumPropagateStats(ClientContext &context, BoundAggregateExpression &expr, + AggregateStatisticsInput &input) { + if (input.node_stats && input.node_stats->has_max_cardinality) { + auto &numeric_stats = input.child_stats[0]; + if (!NumericStats::HasMinMax(numeric_stats)) { + return nullptr; + } + auto internal_type = numeric_stats.GetType().InternalType(); + hugeint_t max_negative; + hugeint_t max_positive; + switch (internal_type) { + case PhysicalType::INT32: + max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe(); + max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe(); + break; + case PhysicalType::INT64: + max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe(); + max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe(); + break; + default: + throw InternalException("Unsupported type for propagate sum stats"); + } + auto max_sum_negative = max_negative * Hugeint::Convert(input.node_stats->max_cardinality); + auto max_sum_positive = max_positive * Hugeint::Convert(input.node_stats->max_cardinality); + if (max_sum_positive >= NumericLimits::Maximum() || + max_sum_negative <= NumericLimits::Minimum()) { + // sum can potentially exceed int64_t bounds: use hugeint sum + return nullptr; + } + // total sum is guaranteed to fit in a single int64: use int64 sum instead of hugeint sum + expr.function = GetSumAggregateNoOverflow(internal_type); + } + return nullptr; +} + +AggregateFunction GetSumAggregate(PhysicalType type) { + switch (type) { + case PhysicalType::BOOL: { + auto function = AggregateFunction::UnaryAggregate, bool, hugeint_t, IntegerSumOperation>( + LogicalType::BOOLEAN, LogicalType::HUGEINT); + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return function; + } + case PhysicalType::INT16: { + auto function = AggregateFunction::UnaryAggregate, int16_t, hugeint_t, IntegerSumOperation>( + LogicalType::SMALLINT, LogicalType::HUGEINT); + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return function; + } + + case PhysicalType::INT32: { + auto function = + AggregateFunction::UnaryAggregate, int32_t, hugeint_t, SumToHugeintOperation>( + LogicalType::INTEGER, LogicalType::HUGEINT); + function.statistics = 
SumPropagateStats; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return function; + } + case PhysicalType::INT64: { + auto function = + AggregateFunction::UnaryAggregate, int64_t, hugeint_t, SumToHugeintOperation>( + LogicalType::BIGINT, LogicalType::HUGEINT); + function.statistics = SumPropagateStats; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return function; + } + case PhysicalType::INT128: { + auto function = + AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, HugeintSumOperation>( + LogicalType::HUGEINT, LogicalType::HUGEINT); + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return function; + } + default: + throw InternalException("Unimplemented sum aggregate"); + } +} + +unique_ptr BindDecimalSum(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto decimal_type = arguments[0]->return_type; + function = GetSumAggregate(decimal_type.InternalType()); + function.name = "sum"; + function.arguments[0] = decimal_type; + function.return_type = LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type)); + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return nullptr; +} + +AggregateFunctionSet SumFun::GetFunctions() { + AggregateFunctionSet sum; + // decimal + sum.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, + nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, + BindDecimalSum)); + sum.AddFunction(GetSumAggregate(PhysicalType::BOOL)); + sum.AddFunction(GetSumAggregate(PhysicalType::INT16)); + sum.AddFunction(GetSumAggregate(PhysicalType::INT32)); + sum.AddFunction(GetSumAggregate(PhysicalType::INT64)); + sum.AddFunction(GetSumAggregate(PhysicalType::INT128)); + sum.AddFunction(AggregateFunction::UnaryAggregate, double, double, NumericSumOperation>( + LogicalType::DOUBLE, LogicalType::DOUBLE)); + return sum; +} + +AggregateFunction CountIfFun::GetFunction() { + return GetSumAggregate(PhysicalType::BOOL); +} + +AggregateFunctionSet SumNoOverflowFun::GetFunctions() { + AggregateFunctionSet sum_no_overflow; + sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT32)); + sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT64)); + sum_no_overflow.AddFunction(GetSumAggregateNoOverflowDecimal()); + return sum_no_overflow; +} + +AggregateFunction KahanSumFun::GetFunction() { + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp new file mode 100644 index 00000000..4eb2b9d3 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp @@ -0,0 +1,413 @@ +#include "core_functions/aggregate/histogram_helpers.hpp" +#include "core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/function/aggregate/sort_key_helpers.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/string_map_set.hpp" +#include "duckdb/common/printer.hpp" + +namespace duckdb { + +struct ApproxTopKString { + ApproxTopKString() : str(UINT32_C(0)), hash(0) { + } + ApproxTopKString(string_t str_p, hash_t hash_p) : str(str_p), hash(hash_p) { + } + + string_t str; + hash_t hash; +}; + +struct ApproxTopKHash { + std::size_t operator()(const ApproxTopKString 
&k) const { + return k.hash; + } +}; + +struct ApproxTopKEquality { + bool operator()(const ApproxTopKString &a, const ApproxTopKString &b) const { + return Equals::Operation(a.str, b.str); + } +}; + +template +using approx_topk_map_t = unordered_map; + +// approx top k algorithm based on "A parallel space saving algorithm for frequent items and the Hurwitz zeta +// distribution" arxiv link - https://arxiv.org/pdf/1401.0702 +// together with the filter extension (Filtered Space-Saving) from "Estimating Top-k Destinations in Data Streams" +struct ApproxTopKValue { + //! The counter + idx_t count = 0; + //! Index in the values array + idx_t index = 0; + //! The string value + ApproxTopKString str_val; + //! Allocated data + char *dataptr = nullptr; + uint32_t size = 0; + uint32_t capacity = 0; +}; + +struct InternalApproxTopKState { + // the top-k data structure has two components + // a list of k values sorted on "count" (i.e. values[0] has the lowest count) + // a lookup map: string_t -> idx in "values" array + unsafe_unique_array stored_values; + unsafe_vector> values; + approx_topk_map_t> lookup_map; + unsafe_vector filter; + idx_t k = 0; + idx_t capacity = 0; + idx_t filter_mask; + + void Initialize(idx_t kval) { + static constexpr idx_t MONITORED_VALUES_RATIO = 3; + static constexpr idx_t FILTER_RATIO = 8; + + D_ASSERT(values.empty()); + D_ASSERT(lookup_map.empty()); + k = kval; + capacity = kval * MONITORED_VALUES_RATIO; + stored_values = make_unsafe_uniq_array_uninitialized(capacity); + values.reserve(capacity); + + // we scale the filter based on the amount of values we are monitoring + idx_t filter_size = NextPowerOfTwo(capacity * FILTER_RATIO); + filter_mask = filter_size - 1; + filter.resize(filter_size); + } + + static void CopyValue(ApproxTopKValue &value, const ApproxTopKString &input, AggregateInputData &input_data) { + value.str_val.hash = input.hash; + if (input.str.IsInlined()) { + // no need to copy + value.str_val = input; + return; + } + value.size = UnsafeNumericCast(input.str.GetSize()); + if (value.size > value.capacity) { + // need to re-allocate for this value + value.capacity = UnsafeNumericCast(NextPowerOfTwo(value.size)); + value.dataptr = char_ptr_cast(input_data.allocator.Allocate(value.capacity)); + } + // copy over the data + memcpy(value.dataptr, input.str.GetData(), value.size); + value.str_val.str = string_t(value.dataptr, value.size); + } + + void InsertOrReplaceEntry(const ApproxTopKString &input, AggregateInputData &aggr_input, idx_t increment = 1) { + if (values.size() < capacity) { + D_ASSERT(increment > 0); + // we can always add this entry + auto &val = stored_values[values.size()]; + val.index = values.size(); + values.push_back(val); + } + auto &value = values.back().get(); + if (value.count > 0) { + // the capacity is reached - we need to replace an entry + + // we use the filter as an early out + // based on the hash - we find a slot in the filter + // instead of monitoring the value immediately, we add to the slot in the filter + // ONLY when the value in the filter exceeds the current min value, we start monitoring the value + // this speeds up the algorithm as switching monitor values means we need to erase/insert in the hash table + auto &filter_value = filter[input.hash & filter_mask]; + if (filter_value + increment < value.count) { + // if the filter has a lower count than the current min count + // we can skip adding this entry (for now) + filter_value += increment; + return; + } + // the filter exceeds the min value - start monitoring 
this value + // erase the existing entry from the map + // and set the filter for the minimum value back to the current minimum value + filter[value.str_val.hash & filter_mask] = value.count; + lookup_map.erase(value.str_val); + } + CopyValue(value, input, aggr_input); + lookup_map.insert(make_pair(value.str_val, reference(value))); + IncrementCount(value, increment); + } + + void IncrementCount(ApproxTopKValue &value, idx_t increment = 1) { + value.count += increment; + // maintain sortedness of "values" + // swap while we have a higher count than the preceding entry + while (value.index > 0 && values[value.index].get().count > values[value.index - 1].get().count) { + // swap the elements around + auto &left = values[value.index]; + auto &right = values[value.index - 1]; + std::swap(left.get().index, right.get().index); + std::swap(left, right); + } + } + + void Verify() const { +#ifdef DEBUG + if (values.empty()) { + D_ASSERT(lookup_map.empty()); + return; + } + D_ASSERT(values.size() <= capacity); + for (idx_t k = 0; k < values.size(); k++) { + auto &val = values[k].get(); + D_ASSERT(val.count > 0); + // verify the map entry exists + auto entry = lookup_map.find(val.str_val); + D_ASSERT(entry != lookup_map.end()); + // verify the index is correct + D_ASSERT(val.index == k); + if (k > 0) { + // sortedness + D_ASSERT(val.count <= values[k - 1].get().count); + } + } + // verify lookup map does not contain extra entries + D_ASSERT(lookup_map.size() == values.size()); +#endif + } +}; + +struct ApproxTopKState { + InternalApproxTopKState *state; + + InternalApproxTopKState &GetState() { + if (!state) { + state = new InternalApproxTopKState(); + } + return *state; + } + + const InternalApproxTopKState &GetState() const { + if (!state) { + throw InternalException("No state available"); + } + return *state; + } +}; + +struct ApproxTopKOperation { + template + static void Initialize(STATE &state) { + state.state = nullptr; + } + + template + static void Operation(STATE &aggr_state, const TYPE &input, AggregateInputData &aggr_input, Vector &top_k_vector, + idx_t offset, idx_t count) { + auto &state = aggr_state.GetState(); + if (state.values.empty()) { + static constexpr int64_t MAX_APPROX_K = 1000000; + // not initialized yet - initialize the K value and set all counters to 0 + UnifiedVectorFormat kdata; + top_k_vector.ToUnifiedFormat(count, kdata); + auto kidx = kdata.sel->get_index(offset); + if (!kdata.validity.RowIsValid(kidx)) { + throw InvalidInputException("Invalid input for approx_top_k: k value cannot be NULL"); + } + auto kval = UnifiedVectorFormat::GetData(kdata)[kidx]; + if (kval <= 0) { + throw InvalidInputException("Invalid input for approx_top_k: k value must be > 0"); + } + if (kval >= MAX_APPROX_K) { + throw InvalidInputException("Invalid input for approx_top_k: k value must be < %d", MAX_APPROX_K); + } + state.Initialize(UnsafeNumericCast(kval)); + } + ApproxTopKString topk_string(input, Hash(input)); + auto entry = state.lookup_map.find(topk_string); + if (entry != state.lookup_map.end()) { + // the input is monitored - increment the count + state.IncrementCount(entry->second.get()); + } else { + // the input is not monitored - replace the minimum entry with the current entry and increment + state.InsertOrReplaceEntry(topk_string, aggr_input); + } + } + + template + static void Combine(const STATE &aggr_source, STATE &aggr_target, AggregateInputData &aggr_input) { + if (!aggr_source.state) { + // source state is empty + return; + } + auto &source = aggr_source.GetState(); + auto &target = 
aggr_target.GetState(); + if (source.values.empty()) { + // source is empty + return; + } + source.Verify(); + auto min_source = source.values.back().get().count; + idx_t min_target; + if (target.values.empty()) { + min_target = 0; + target.Initialize(source.k); + } else { + if (source.k != target.k) { + throw NotImplementedException("Approx Top K - cannot combine approx_top_k with different k values. " + "K values must be the same for all entries within the same group"); + } + min_target = target.values.back().get().count; + } + // for all entries in target + // check if they are tracked in source + // if they are - add the tracked count + // if they are not - add the minimum count + for (idx_t target_idx = 0; target_idx < target.values.size(); target_idx++) { + auto &val = target.values[target_idx].get(); + auto source_entry = source.lookup_map.find(val.str_val); + idx_t increment = min_source; + if (source_entry != source.lookup_map.end()) { + increment = source_entry->second.get().count; + } + if (increment == 0) { + continue; + } + target.IncrementCount(val, increment); + } + // now for each entry in source, if it is not tracked by the target, add it with the target minimum + for (auto &source_entry : source.values) { + auto &source_val = source_entry.get(); + auto target_entry = target.lookup_map.find(source_val.str_val); + if (target_entry != target.lookup_map.end()) { + // already tracked - no need to add anything + continue; + } + auto new_count = source_val.count + min_target; + idx_t increment; + if (target.values.size() >= target.capacity) { + idx_t current_min = target.values.empty() ? 0 : target.values.back().get().count; + D_ASSERT(target.values.size() == target.capacity); + // target already has capacity values + // check if we should insert this entry + if (new_count <= current_min) { + // if not, we can skip this entry + continue; + } + increment = new_count - current_min; + } else { + // target does not have capacity entries yet + // just add this entry with the full count + increment = new_count; + } + target.InsertOrReplaceEntry(source_val.str_val, aggr_input, increment); + } + // copy over the filter + D_ASSERT(source.filter.size() == target.filter.size()); + for (idx_t filter_idx = 0; filter_idx < source.filter.size(); filter_idx++) { + target.filter[filter_idx] += source.filter[filter_idx]; + } + target.Verify(); + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + delete state.state; + } + + static bool IgnoreNull() { + return true; + } +}; + +template +static void ApproxTopKUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, + idx_t count) { + using STATE = ApproxTopKState; + auto &input = inputs[0]; + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + + auto &top_k_vector = inputs[1]; + + auto extra_state = OP::CreateExtraState(count); + UnifiedVectorFormat input_data; + OP::PrepareData(input, count, extra_state, input_data); + + auto states = UnifiedVectorFormat::GetData(sdata); + auto data = UnifiedVectorFormat::GetData(input_data); + for (idx_t i = 0; i < count; i++) { + auto idx = input_data.sel->get_index(i); + if (!input_data.validity.RowIsValid(idx)) { + continue; + } + auto &state = *states[sdata.sel->get_index(i)]; + ApproxTopKOperation::Operation(state, data[idx], aggr_input, top_k_vector, i, count); + } +} + +template +static void ApproxTopKFinalize(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, idx_t offset) { +
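// Illustrative walk-through of the finalize below (example numbers, not from the source): the first loop only sizes the child list by summing min(values.size(), k) per state; the second loop then emits up to k monitored values per group in descending count order, e.g. with k = 2 and monitored counts {a: 5, b: 3, c: 1} the emitted list entry is [a, b]. +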
UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states = UnifiedVectorFormat::GetData(sdata); + + auto &mask = FlatVector::Validity(result); + auto old_len = ListVector::GetListSize(result); + idx_t new_entries = 0; + // figure out how much space we need + for (idx_t i = 0; i < count; i++) { + auto &state = states[sdata.sel->get_index(i)]->GetState(); + if (state.values.empty()) { + continue; + } + // get up to k values for each state + // this can be less if fewer unique values were found + new_entries += MinValue(state.values.size(), state.k); + } + // reserve space in the list vector + ListVector::Reserve(result, old_len + new_entries); + auto list_entries = FlatVector::GetData(result); + auto &child_data = ListVector::GetEntry(result); + + idx_t current_offset = old_len; + for (idx_t i = 0; i < count; i++) { + const auto rid = i + offset; + auto &state = states[sdata.sel->get_index(i)]->GetState(); + if (state.values.empty()) { + mask.SetInvalid(rid); + continue; + } + auto &list_entry = list_entries[rid]; + list_entry.offset = current_offset; + for (idx_t val_idx = 0; val_idx < MinValue(state.values.size(), state.k); val_idx++) { + auto &val = state.values[val_idx].get(); + D_ASSERT(val.count > 0); + OP::template HistogramFinalize(val.str_val.str, child_data, current_offset); + current_offset++; + } + list_entry.length = current_offset - list_entry.offset; + } + D_ASSERT(current_offset == old_len + new_entries); + ListVector::SetListSize(result, current_offset); + result.Verify(count); +} + +unique_ptr ApproxTopKBind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + for (auto &arg : arguments) { + if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + } + if (arguments[0]->return_type.id() == LogicalTypeId::VARCHAR) { + function.update = ApproxTopKUpdate; + function.finalize = ApproxTopKFinalize; + } + function.return_type = LogicalType::LIST(arguments[0]->return_type); + return nullptr; +} + +AggregateFunction ApproxTopKFun::GetFunction() { + using STATE = ApproxTopKState; + using OP = ApproxTopKOperation; + return AggregateFunction("approx_top_k", {LogicalTypeId::ANY, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::ANY), AggregateFunction::StateSize, + AggregateFunction::StateInitialize, ApproxTopKUpdate, + AggregateFunction::StateCombine, ApproxTopKFinalize, nullptr, ApproxTopKBind, + AggregateFunction::StateDestroy); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp new file mode 100644 index 00000000..23d2cf47 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp @@ -0,0 +1,444 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "core_functions/aggregate/holistic_functions.hpp" +#include "t_digest.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +#include +#include +#include + +namespace duckdb { + +struct ApproxQuantileState { + duckdb_tdigest::TDigest *h; + idx_t pos; +}; + +struct ApproximateQuantileBindData : public FunctionData { + ApproximateQuantileBindData() { + } + explicit ApproximateQuantileBindData(float quantile_p) : quantiles(1, quantile_p) { + } + + explicit 
ApproximateQuantileBindData(vector quantiles_p) : quantiles(std::move(quantiles_p)) { + } + + unique_ptr Copy() const override { + return make_uniq(quantiles); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return quantiles == other.quantiles; + } + + static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty(100, "quantiles", bind_data.quantiles); + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto result = make_uniq(); + deserializer.ReadProperty(100, "quantiles", result->quantiles); + return std::move(result); + } + + vector quantiles; +}; + +struct ApproxQuantileOperation { + using SAVE_TYPE = duckdb_tdigest::Value; + + template + static void Initialize(STATE &state) { + state.pos = 0; + state.h = nullptr; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + auto val = Cast::template Operation(input); + if (!Value::DoubleIsFinite(val)) { + return; + } + if (!state.h) { + state.h = new duckdb_tdigest::TDigest(100); + } + state.h->add(val); + state.pos++; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (source.pos == 0) { + return; + } + D_ASSERT(source.h); + if (!target.h) { + target.h = new duckdb_tdigest::TDigest(100); + } + target.h->merge(source.h); + target.pos += source.pos; + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.h) { + delete state.h; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +struct ApproxQuantileScalarOperation : public ApproxQuantileOperation { + template + static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) { + if (state.pos == 0) { + finalize_data.ReturnNull(); + return; + } + D_ASSERT(state.h); + D_ASSERT(finalize_data.input.bind_data); + state.h->compress(); + auto &bind_data = finalize_data.input.bind_data->template Cast(); + D_ASSERT(bind_data.quantiles.size() == 1); + // The result is approximate, so clamp instead of overflowing. 
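+ // Illustrative example (values assumed, not taken from the source): for a SMALLINT target, a t-digest estimate of 32769.4 does not fit in int16_t, so TryCast fails and the result is pinned to the type maximum; an estimate below the type minimum is pinned to the minimum instead.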
+ const auto source = state.h->quantile(bind_data.quantiles[0]); + if (TryCast::Operation(source, target, false)) { + return; + } else if (source < 0) { + target = NumericLimits::Minimum(); + } else { + target = NumericLimits::Maximum(); + } + } +}; + +static AggregateFunction GetApproximateQuantileAggregateFunction(const LogicalType &type) { + // Not binary comparable + if (type == LogicalType::TIME_TZ) { + return AggregateFunction::UnaryAggregateDestructor(type, type); + } + switch (type.InternalType()) { + case PhysicalType::INT8: + return AggregateFunction::UnaryAggregateDestructor(type, type); + case PhysicalType::INT16: + return AggregateFunction::UnaryAggregateDestructor(type, type); + case PhysicalType::INT32: + return AggregateFunction::UnaryAggregateDestructor(type, type); + case PhysicalType::INT64: + return AggregateFunction::UnaryAggregateDestructor(type, type); + case PhysicalType::INT128: + return AggregateFunction::UnaryAggregateDestructor(type, type); + case PhysicalType::FLOAT: + return AggregateFunction::UnaryAggregateDestructor(type, type); + case PhysicalType::DOUBLE: + return AggregateFunction::UnaryAggregateDestructor(type, type); + default: + throw InternalException("Unimplemented quantile aggregate"); + } +} + +static AggregateFunction GetApproximateQuantileDecimalAggregateFunction(const LogicalType &type) { + switch (type.InternalType()) { + case PhysicalType::INT8: + return GetApproximateQuantileAggregateFunction(LogicalType::TINYINT); + case PhysicalType::INT16: + return GetApproximateQuantileAggregateFunction(LogicalType::SMALLINT); + case PhysicalType::INT32: + return GetApproximateQuantileAggregateFunction(LogicalType::INTEGER); + case PhysicalType::INT64: + return GetApproximateQuantileAggregateFunction(LogicalType::BIGINT); + case PhysicalType::INT128: + return GetApproximateQuantileAggregateFunction(LogicalType::HUGEINT); + default: + throw InternalException("Unimplemented quantile decimal aggregate"); + } +} + +static float CheckApproxQuantile(const Value &quantile_val) { + if (quantile_val.IsNull()) { + throw BinderException("APPROXIMATE QUANTILE parameter cannot be NULL"); + } + auto quantile = quantile_val.GetValue(); + if (quantile < 0 || quantile > 1) { + throw BinderException("APPROXIMATE QUANTILE can only take parameters in range [0, 1]"); + } + + return quantile; +} + +unique_ptr BindApproxQuantile(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw BinderException("APPROXIMATE QUANTILE can only take constant quantile parameters"); + } + Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + if (quantile_val.IsNull()) { + throw BinderException("APPROXIMATE QUANTILE parameter list cannot be NULL"); + } + + vector quantiles; + switch (quantile_val.type().id()) { + case LogicalTypeId::LIST: + for (const auto &element_val : ListValue::GetChildren(quantile_val)) { + quantiles.push_back(CheckApproxQuantile(element_val)); + } + break; + case LogicalTypeId::ARRAY: + for (const auto &element_val : ArrayValue::GetChildren(quantile_val)) { + quantiles.push_back(CheckApproxQuantile(element_val)); + } + break; + default: + quantiles.push_back(CheckApproxQuantile(quantile_val)); + break; + } + + // remove the quantile argument so we can use the unary aggregate + Function::EraseArgument(function, arguments, arguments.size() - 1); + return make_uniq(quantiles); +} + +AggregateFunction 
ApproxQuantileDecimalFunction(const LogicalType &type) { + auto function = GetApproximateQuantileDecimalAggregateFunction(type); + function.name = "approx_quantile"; + function.serialize = ApproximateQuantileBindData::Serialize; + function.deserialize = ApproximateQuantileBindData::Deserialize; + return function; +} + +unique_ptr BindApproxQuantileDecimal(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto bind_data = BindApproxQuantile(context, function, arguments); + function = ApproxQuantileDecimalFunction(arguments[0]->return_type); + return bind_data; +} + +AggregateFunction GetApproximateQuantileAggregate(const LogicalType &type) { + auto fun = GetApproximateQuantileAggregateFunction(type); + fun.bind = BindApproxQuantile; + fun.serialize = ApproximateQuantileBindData::Serialize; + fun.deserialize = ApproximateQuantileBindData::Deserialize; + // temporarily push an argument so we can bind the actual quantile + fun.arguments.emplace_back(LogicalType::FLOAT); + return fun; +} + +template +struct ApproxQuantileListOperation : public ApproxQuantileOperation { + + template + static void Finalize(STATE &state, RESULT_TYPE &target, AggregateFinalizeData &finalize_data) { + if (state.pos == 0) { + finalize_data.ReturnNull(); + return; + } + + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->template Cast(); + + auto &result = ListVector::GetEntry(finalize_data.result); + auto ridx = ListVector::GetListSize(finalize_data.result); + ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); + auto rdata = FlatVector::GetData(result); + + D_ASSERT(state.h); + state.h->compress(); + + auto &entry = target; + entry.offset = ridx; + entry.length = bind_data.quantiles.size(); + for (size_t q = 0; q < entry.length; ++q) { + const auto &quantile = bind_data.quantiles[q]; + rdata[ridx + q] = Cast::template Operation(state.h->quantile(quantile)); + } + + ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); + } +}; + +template +static AggregateFunction ApproxQuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { + LogicalType result_type = LogicalType::LIST(child_type); + return AggregateFunction( + {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, + nullptr, AggregateFunction::StateDestroy); +} + +template +AggregateFunction GetTypedApproxQuantileListAggregateFunction(const LogicalType &type) { + using STATE = ApproxQuantileState; + using OP = ApproxQuantileListOperation; + auto fun = ApproxQuantileListAggregate(type, type); + fun.serialize = ApproximateQuantileBindData::Serialize; + fun.deserialize = ApproximateQuantileBindData::Deserialize; + return fun; +} + +AggregateFunction GetApproxQuantileListAggregateFunction(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::SMALLINT: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::INTEGER: + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::BIGINT: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return GetTypedApproxQuantileListAggregateFunction(type); + case 
LogicalTypeId::TIME_TZ: + // Not binary comparable + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::HUGEINT: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::FLOAT: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::DOUBLE: + return GetTypedApproxQuantileListAggregateFunction(type); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetTypedApproxQuantileListAggregateFunction(type); + case PhysicalType::INT32: + return GetTypedApproxQuantileListAggregateFunction(type); + case PhysicalType::INT64: + return GetTypedApproxQuantileListAggregateFunction(type); + case PhysicalType::INT128: + return GetTypedApproxQuantileListAggregateFunction(type); + default: + throw NotImplementedException("Unimplemented approximate quantile list decimal aggregate"); + } + default: + throw NotImplementedException("Unimplemented approximate quantile list aggregate"); + } +} + +AggregateFunction ApproxQuantileDecimalListFunction(const LogicalType &type) { + auto function = GetApproxQuantileListAggregateFunction(type); + function.name = "approx_quantile"; + function.serialize = ApproximateQuantileBindData::Serialize; + function.deserialize = ApproximateQuantileBindData::Deserialize; + return function; +} + +unique_ptr BindApproxQuantileDecimalList(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + auto bind_data = BindApproxQuantile(context, function, arguments); + function = ApproxQuantileDecimalListFunction(arguments[0]->return_type); + return bind_data; +} + +AggregateFunction GetApproxQuantileListAggregate(const LogicalType &type) { + auto fun = GetApproxQuantileListAggregateFunction(type); + fun.bind = BindApproxQuantile; + fun.serialize = ApproximateQuantileBindData::Serialize; + fun.deserialize = ApproximateQuantileBindData::Deserialize; + // temporarily push an argument so we can bind the actual quantile + auto list_of_float = LogicalType::LIST(LogicalType::FLOAT); + fun.arguments.push_back(list_of_float); + return fun; +} + +unique_ptr ApproxQuantileDecimalDeserialize(Deserializer &deserializer, AggregateFunction &function) { + auto bind_data = ApproximateQuantileBindData::Deserialize(deserializer, function); + auto &return_type = deserializer.Get(); + if (return_type.id() == LogicalTypeId::LIST) { + function = ApproxQuantileDecimalListFunction(function.arguments[0]); + } else { + function = ApproxQuantileDecimalFunction(function.arguments[0]); + } + return bind_data; +} + +AggregateFunction GetApproxQuantileDecimal() { + // stub function - the actual function is set during bind or deserialize + AggregateFunction fun({LogicalTypeId::DECIMAL, LogicalType::FLOAT}, LogicalTypeId::DECIMAL, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, BindApproxQuantileDecimal); + fun.serialize = ApproximateQuantileBindData::Serialize; + fun.deserialize = ApproxQuantileDecimalDeserialize; + return fun; +} + +AggregateFunction GetApproxQuantileDecimalList() { + // stub function - the actual function is set during bind or deserialize + AggregateFunction fun({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::FLOAT)}, + LogicalType::LIST(LogicalTypeId::DECIMAL), nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, BindApproxQuantileDecimalList); + fun.serialize = ApproximateQuantileBindData::Serialize; + fun.deserialize = ApproxQuantileDecimalDeserialize; + return fun; +} + +AggregateFunctionSet ApproxQuantileFun::GetFunctions() { + 
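// Illustrative SQL usage (assumed, not part of this diff): the scalar overloads registered below serve approx_quantile(x, 0.5), while the LIST variants serve approx_quantile(x, [0.25, 0.5, 0.75]) and return one estimate per requested quantile. +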
AggregateFunctionSet approx_quantile; + approx_quantile.AddFunction(GetApproxQuantileDecimal()); + + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::SMALLINT)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::INTEGER)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::BIGINT)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::HUGEINT)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::DOUBLE)); + + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::DATE)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIME)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIME_TZ)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIMESTAMP)); + approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIMESTAMP_TZ)); + + // List variants + approx_quantile.AddFunction(GetApproxQuantileDecimalList()); + + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::TINYINT)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::SMALLINT)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::INTEGER)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::BIGINT)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::HUGEINT)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::FLOAT)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::DOUBLE)); + + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::DATE)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIME)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIME_TZ)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIMESTAMP)); + approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIMESTAMP_TZ)); + + return approx_quantile; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp new file mode 100644 index 00000000..dedb7429 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp @@ -0,0 +1,345 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/abs.hpp" +#include "core_functions/aggregate/quantile_state.hpp" + +namespace duckdb { + +struct FrameSet { + inline explicit FrameSet(const SubFrames &frames_p) : frames(frames_p) { + } + + inline idx_t Size() const { + idx_t result = 0; + for (const auto &frame : frames) { + result += frame.end - frame.start; + } + + return result; + } + + inline bool Contains(idx_t i) const { + for (idx_t f = 0; f < frames.size(); ++f) { + const auto &frame = frames[f]; + if (frame.start <= i && i < frame.end) { + return true; + } + } + return false; + } + const SubFrames &frames; +}; + +struct QuantileReuseUpdater { + idx_t *index; + idx_t j; + + inline QuantileReuseUpdater(idx_t *index, idx_t j) : index(index), j(j) { + } + + inline void Neither(idx_t begin, idx_t end) { + } + + inline void Left(idx_t begin, idx_t end) { + } + + inline void Right(idx_t begin, idx_t end) { + for (; 
begin < end; ++begin) { + index[j++] = begin; + } + } + + inline void Both(idx_t begin, idx_t end) { + } +}; + +void ReuseIndexes(idx_t *index, const SubFrames &currs, const SubFrames &prevs) { + + // Copy overlapping indices by scanning the previous set and copying down into holes. + // We copy instead of leaving gaps in case there are fewer values in the current frame. + FrameSet prev_set(prevs); + FrameSet curr_set(currs); + const auto prev_count = prev_set.Size(); + idx_t j = 0; + for (idx_t p = 0; p < prev_count; ++p) { + auto idx = index[p]; + + // Shift down into any hole + if (j != p) { + index[j] = idx; + } + + // Skip overlapping values + if (curr_set.Contains(idx)) { + ++j; + } + } + + // Insert new indices + if (j > 0) { + QuantileReuseUpdater updater(index, j); + AggregateExecutor::IntersectFrames(prevs, currs, updater); + } else { + // No overlap: overwrite with new values + for (const auto &curr : currs) { + for (auto idx = curr.start; idx < curr.end; ++idx) { + index[j++] = idx; + } + } + } +} + +//===--------------------------------------------------------------------===// +// Median Absolute Deviation +//===--------------------------------------------------------------------===// +template +struct MadAccessor { + using INPUT_TYPE = T; + using RESULT_TYPE = R; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const RESULT_TYPE delta = input - UnsafeNumericCast(median); + return TryAbsOperator::Operation(delta); + } +}; + +// hugeint_t - double => undefined +template <> +struct MadAccessor { + using INPUT_TYPE = hugeint_t; + using RESULT_TYPE = double; + using MEDIAN_TYPE = double; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto delta = Hugeint::Cast(input) - median; + return TryAbsOperator::Operation(delta); + } +}; + +// date_t - timestamp_t => interval_t +template <> +struct MadAccessor { + using INPUT_TYPE = date_t; + using RESULT_TYPE = interval_t; + using MEDIAN_TYPE = timestamp_t; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto dt = Cast::Operation(input); + const auto delta = dt - median; + return Interval::FromMicro(TryAbsOperator::Operation(delta)); + } +}; + +// timestamp_t - timestamp_t => int64_t +template <> +struct MadAccessor { + using INPUT_TYPE = timestamp_t; + using RESULT_TYPE = interval_t; + using MEDIAN_TYPE = timestamp_t; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto delta = input - median; + return Interval::FromMicro(TryAbsOperator::Operation(delta)); + } +}; + +// dtime_t - dtime_t => int64_t +template <> +struct MadAccessor { + using INPUT_TYPE = dtime_t; + using RESULT_TYPE = interval_t; + using MEDIAN_TYPE = dtime_t; + const MEDIAN_TYPE &median; + explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { + } + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + const auto delta = input - median; + return Interval::FromMicro(TryAbsOperator::Operation(delta)); + } +}; + +template +struct MedianAbsoluteDeviationOperation : QuantileOperation { + template + static void Finalize(STATE &state, T 
&target, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; + } + using INPUT_TYPE = typename STATE::InputType; + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->Cast(); + D_ASSERT(bind_data.quantiles.size() == 1); + const auto &q = bind_data.quantiles[0]; + Interpolator interp(q, state.v.size(), false); + const auto med = interp.template Operation(state.v.data(), finalize_data.result); + + MadAccessor accessor(med); + target = interp.template Operation(state.v.data(), finalize_data.result, accessor); + } + + template + static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, + const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result, + idx_t ridx) { + auto &state = *reinterpret_cast(l_state); + auto gstate = reinterpret_cast(g_state); + + auto &data = state.GetOrCreateWindowCursor(partition); + const auto &fmask = partition.filter_mask; + + auto rdata = FlatVector::GetData(result); + + QuantileIncluded included(fmask, data); + const auto n = FrameSize(included, frames); + + if (!n) { + auto &rmask = FlatVector::Validity(result); + rmask.Set(ridx, false); + return; + } + + // Compute the median + D_ASSERT(aggr_input_data.bind_data); + auto &bind_data = aggr_input_data.bind_data->Cast(); + + D_ASSERT(bind_data.quantiles.size() == 1); + const auto &quantile = bind_data.quantiles[0]; + auto &window_state = state.GetOrCreateWindowState(); + MEDIAN_TYPE med; + if (gstate && gstate->HasTree()) { + med = gstate->GetWindowState().template WindowScalar(data, frames, n, result, quantile); + } else { + window_state.UpdateSkip(data, frames, included); + med = window_state.template WindowScalar(data, frames, n, result, quantile); + } + + // Lazily initialise frame state + window_state.SetCount(frames.back().end - frames.front().start); + auto index2 = window_state.m.data(); + D_ASSERT(index2); + + // The replacement trick does not work on the second index because if + // the median has changed, the previous order is not correct. + // It is probably close, however, and so reuse is helpful. 
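// [Editor's note, not part of the upstream patch] To summarize the approach: the
// windowed MAD is computed in two quantile passes over the same frame. The first
// pass finds the frame's median; the second finds the median of |x - median| by
// composing the MadAccessor above with an indirection through the reordered index
// array. A minimal standalone sketch of the same two-step idea (hypothetical
// helper names, doubles only, no windowing or interpolation):
//
//     double Median(std::vector<double> v) {
//         auto mid = v.begin() + static_cast<std::ptrdiff_t>(v.size() / 2);
//         std::nth_element(v.begin(), mid, v.end());
//         return *mid; // the real code interpolates between adjacent values
//     }
//     double Mad(const std::vector<double> &v) {
//         const double med = Median(v);
//         std::vector<double> dev;
//         dev.reserve(v.size());
//         for (auto x : v) {
//             dev.push_back(std::abs(x - med));
//         }
//         return Median(dev);
//     }
//
// The index reuse below only skips re-sorting work; correctness comes from the
// partition and interpolation steps, which are redone for every frame.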
+ auto &prevs = window_state.prevs; + ReuseIndexes(index2, frames, prevs); + std::partition(index2, index2 + window_state.count, included); + + Interpolator interp(quantile, n, false); + + // Compute mad from the second index + using ID = QuantileIndirect; + ID indirect(data); + + using MAD = MadAccessor; + MAD mad(med); + + using MadIndirect = QuantileComposed; + MadIndirect mad_indirect(mad, indirect); + rdata[ridx] = interp.template Operation(index2, result, mad_indirect); + + // Prev is used by both skip lists and increments + prevs = frames; + } +}; + +unique_ptr BindMAD(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + return make_uniq(Value::DECIMAL(int16_t(5), 2, 1)); +} + +template +AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const LogicalType &input_type, + const LogicalType &target_type) { + using STATE = QuantileState; + using OP = MedianAbsoluteDeviationOperation; + auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); + fun.bind = BindMAD; + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; +#ifndef DUCKDB_SMALLER_BINARY + fun.window = OP::template Window; + fun.window_init = OP::template WindowInit; +#endif + return fun; +} + +AggregateFunction GetMedianAbsoluteDeviationAggregateFunctionInternal(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::FLOAT: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case LogicalTypeId::DOUBLE: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case PhysicalType::INT32: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case PhysicalType::INT64: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + case PhysicalType::INT128: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); + default: + throw NotImplementedException("Unimplemented Median Absolute Deviation DECIMAL aggregate"); + } + break; + + case LogicalTypeId::DATE: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, + LogicalType::INTERVAL); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + return GetTypedMedianAbsoluteDeviationAggregateFunction( + type, LogicalType::INTERVAL); + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return GetTypedMedianAbsoluteDeviationAggregateFunction(type, + LogicalType::INTERVAL); + + default: + throw NotImplementedException("Unimplemented Median Absolute Deviation aggregate"); + } +} + +AggregateFunction GetMedianAbsoluteDeviationAggregateFunction(const LogicalType &type) { + auto result = GetMedianAbsoluteDeviationAggregateFunctionInternal(type); + result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; + return result; +} + +unique_ptr BindMedianAbsoluteDeviationDecimal(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetMedianAbsoluteDeviationAggregateFunction(arguments[0]->return_type); + function.name = "mad"; + function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return BindMAD(context, function, arguments); +} + +AggregateFunctionSet MadFun::GetFunctions() { + AggregateFunctionSet mad("mad"); + mad.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, 
BindMedianAbsoluteDeviationDecimal)); + + const vector<LogicalType> MAD_TYPES = {LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE, + LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, + LogicalType::TIME_TZ}; + for (const auto &type : MAD_TYPES) { + mad.AddFunction(GetMedianAbsoluteDeviationAggregateFunction(type)); + } + return mad; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp new file mode 100644 index 00000000..8c35fc8c --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp @@ -0,0 +1,573 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/uhugeint.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "core_functions/aggregate/distributive_functions.hpp" +#include "core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/owning_string_map.hpp" +#include "duckdb/function/create_sort_key.hpp" +#include "duckdb/function/aggregate/sort_key_helpers.hpp" +#include "duckdb/common/algorithm.hpp" +#include <functional> + +// MODE( <expr1> ) +// Returns the most frequent value for the values within expr1. +// NULL values are ignored. If all the values are NULL, or there are 0 rows, then the function returns NULL. + +namespace std {} // namespace std + +namespace duckdb { + +struct ModeAttr { + ModeAttr() : count(0), first_row(std::numeric_limits<idx_t>::max()) { + } + size_t count; + idx_t first_row; +}; + +template <typename KEY_TYPE> +struct ModeStandard { + using MAP_TYPE = unordered_map<KEY_TYPE, ModeAttr>; + + static MAP_TYPE *CreateEmpty(ArenaAllocator &) { + return new MAP_TYPE(); + } + static MAP_TYPE *CreateEmpty(Allocator &) { + return new MAP_TYPE(); + } + + template <class INPUT_TYPE, class RESULT_TYPE> + static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) { + return RESULT_TYPE(input); + } +}; + +struct ModeString { + using MAP_TYPE = OwningStringMap<ModeAttr>; + + static MAP_TYPE *CreateEmpty(ArenaAllocator &allocator) { + return new MAP_TYPE(allocator); + } + static MAP_TYPE *CreateEmpty(Allocator &allocator) { + return new MAP_TYPE(allocator); + } + + template <class INPUT_TYPE, class RESULT_TYPE> + static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) { + return StringVector::AddStringOrBlob(result, input); + } +}; + +template <typename KEY_TYPE, typename TYPE_OP> +struct ModeState { + using Counts = typename TYPE_OP::MAP_TYPE; + + ModeState() { + } + + SubFrames prevs; + Counts *frequency_map = nullptr; + KEY_TYPE *mode = nullptr; + size_t nonzero = 0; + bool valid = false; + size_t count = 0; + + //! The collection being read + const ColumnDataCollection *inputs; + //! The state used for reading the collection on this thread + ColumnDataScanState *scan = nullptr; + //! The data chunk paged into + DataChunk page; + //! The data pointer + const KEY_TYPE *data = nullptr; + //!
The validity mask + const ValidityMask *validity = nullptr; + + ~ModeState() { + if (frequency_map) { + delete frequency_map; + } + if (mode) { + delete mode; + } + if (scan) { + delete scan; + } + } + + void InitializePage(const WindowPartitionInput &partition) { + if (!scan) { + scan = new ColumnDataScanState(); + } + if (page.ColumnCount() == 0) { + D_ASSERT(partition.inputs); + inputs = partition.inputs; + D_ASSERT(partition.column_ids.size() == 1); + inputs->InitializeScan(*scan, partition.column_ids); + inputs->InitializeScanChunk(*scan, page); + } + } + + inline sel_t RowOffset(idx_t row_idx) const { + D_ASSERT(RowIsVisible(row_idx)); + return UnsafeNumericCast(row_idx - scan->current_row_index); + } + + inline bool RowIsVisible(idx_t row_idx) const { + return (row_idx < scan->next_row_index && scan->current_row_index <= row_idx); + } + + inline idx_t Seek(idx_t row_idx) { + if (!RowIsVisible(row_idx)) { + D_ASSERT(inputs); + inputs->Seek(row_idx, *scan, page); + data = FlatVector::GetData(page.data[0]); + validity = &FlatVector::Validity(page.data[0]); + } + return RowOffset(row_idx); + } + + inline const KEY_TYPE &GetCell(idx_t row_idx) { + const auto offset = Seek(row_idx); + return data[offset]; + } + + inline bool RowIsValid(idx_t row_idx) { + const auto offset = Seek(row_idx); + return validity->RowIsValid(offset); + } + + void Reset() { + if (frequency_map) { + frequency_map->clear(); + } + nonzero = 0; + count = 0; + valid = false; + } + + void ModeAdd(idx_t row) { + const auto &key = GetCell(row); + auto &attr = (*frequency_map)[key]; + auto new_count = (attr.count += 1); + if (new_count == 1) { + ++nonzero; + attr.first_row = row; + } else { + attr.first_row = MinValue(row, attr.first_row); + } + if (new_count > count) { + valid = true; + count = new_count; + if (mode) { + *mode = key; + } else { + mode = new KEY_TYPE(key); + } + } + } + + void ModeRm(idx_t frame) { + const auto &key = GetCell(frame); + auto &attr = (*frequency_map)[key]; + auto old_count = attr.count; + nonzero -= size_t(old_count == 1); + + attr.count -= 1; + if (count == old_count && key == *mode) { + valid = false; + } + } + + typename Counts::const_iterator Scan() const { + //! 
Initialize control variables to first variable of the frequency map + auto highest_frequency = frequency_map->begin(); + for (auto i = highest_frequency; i != frequency_map->end(); ++i) { + // Tie break with the lowest insert position + if (i->second.count > highest_frequency->second.count || + (i->second.count == highest_frequency->second.count && + i->second.first_row < highest_frequency->second.first_row)) { + highest_frequency = i; + } + } + return highest_frequency; + } +}; + +template +struct ModeIncluded { + inline explicit ModeIncluded(const ValidityMask &fmask_p, STATE &state) : fmask(fmask_p), state(state) { + } + + inline bool operator()(const idx_t &idx) const { + return fmask.RowIsValid(idx) && state.RowIsValid(idx); + } + const ValidityMask &fmask; + STATE &state; +}; + +template +struct BaseModeFunction { + template + static void Initialize(STATE &state) { + new (&state) STATE(); + } + + template + static void Execute(STATE &state, const INPUT_TYPE &key, AggregateInputData &input_data) { + if (!state.frequency_map) { + state.frequency_map = TYPE_OP::CreateEmpty(input_data.allocator); + } + auto &i = (*state.frequency_map)[key]; + ++i.count; + i.first_row = MinValue(i.first_row, state.count); + ++state.count; + } + + template + static void Operation(STATE &state, const INPUT_TYPE &key, AggregateUnaryInput &aggr_input) { + Execute(state, key, aggr_input.input); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.frequency_map) { + return; + } + if (!target.frequency_map) { + // Copy - don't destroy! Otherwise windowing will break. + target.frequency_map = new typename STATE::Counts(*source.frequency_map); + target.count = source.count; + return; + } + for (auto &val : *source.frequency_map) { + auto &i = (*target.frequency_map)[val.first]; + i.count += val.second.count; + i.first_row = MinValue(i.first_row, val.second.first_row); + } + target.count += source.count; + } + + static bool IgnoreNull() { + return true; + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + state.~STATE(); + } +}; + +template +struct TypedModeFunction : BaseModeFunction { + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &key, AggregateUnaryInput &aggr_input, idx_t count) { + if (!state.frequency_map) { + state.frequency_map = TYPE_OP::CreateEmpty(aggr_input.input.allocator); + } + auto &i = (*state.frequency_map)[key]; + i.count += count; + i.first_row = MinValue(i.first_row, state.count); + state.count += count; + } +}; + +template +struct ModeFunction : TypedModeFunction { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.frequency_map) { + finalize_data.ReturnNull(); + return; + } + auto highest_frequency = state.Scan(); + if (highest_frequency != state.frequency_map->end()) { + target = TYPE_OP::template Assign(finalize_data.result, highest_frequency->first); + } else { + finalize_data.ReturnNull(); + } + } + + template + struct UpdateWindowState { + STATE &state; + ModeIncluded &included; + + inline UpdateWindowState(STATE &state, ModeIncluded &included) : state(state), included(included) { + } + + inline void Neither(idx_t begin, idx_t end) { + } + + inline void Left(idx_t begin, idx_t end) { + for (; begin < end; ++begin) { + if (included(begin)) { + state.ModeRm(begin); + } + } + } + + inline void Right(idx_t begin, idx_t end) { + for (; begin < end; ++begin) { + if (included(begin)) { + 
state.ModeAdd(begin); + } + } + } + + inline void Both(idx_t begin, idx_t end) { + } + }; + + template + static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, + const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result, + idx_t rid) { + auto &state = *reinterpret_cast(l_state); + + state.InitializePage(partition); + const auto &fmask = partition.filter_mask; + + auto rdata = FlatVector::GetData(result); + auto &rmask = FlatVector::Validity(result); + auto &prevs = state.prevs; + if (prevs.empty()) { + prevs.resize(1); + } + + ModeIncluded included(fmask, state); + + if (!state.frequency_map) { + state.frequency_map = TYPE_OP::CreateEmpty(Allocator::DefaultAllocator()); + } + const size_t tau_inverse = 4; // tau==0.25 + if (state.nonzero <= (state.frequency_map->size() / tau_inverse) || prevs.back().end <= frames.front().start || + frames.back().end <= prevs.front().start) { + state.Reset(); + // for f ∈ F do + for (const auto &frame : frames) { + for (auto i = frame.start; i < frame.end; ++i) { + if (included(i)) { + state.ModeAdd(i); + } + } + } + } else { + using Updater = UpdateWindowState; + Updater updater(state, included); + AggregateExecutor::IntersectFrames(prevs, frames, updater); + } + + if (!state.valid) { + // Rescan + auto highest_frequency = state.Scan(); + if (highest_frequency != state.frequency_map->end()) { + *(state.mode) = highest_frequency->first; + state.count = highest_frequency->second.count; + state.valid = (state.count > 0); + } + } + + if (state.valid) { + rdata[rid] = TYPE_OP::template Assign(result, *state.mode); + } else { + rmask.Set(rid, false); + } + + prevs = frames; + } +}; + +template +struct ModeFallbackFunction : BaseModeFunction { + template + static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { + if (!state.frequency_map) { + finalize_data.ReturnNull(); + return; + } + auto highest_frequency = state.Scan(); + if (highest_frequency != state.frequency_map->end()) { + CreateSortKeyHelpers::DecodeSortKey(highest_frequency->first, finalize_data.result, + finalize_data.result_idx, + OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); + } else { + finalize_data.ReturnNull(); + } + } +}; + +AggregateFunction GetFallbackModeFunction(const LogicalType &type) { + using STATE = ModeState; + using OP = ModeFallbackFunction; + AggregateFunction aggr({type}, type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + AggregateSortKeyHelpers::UnaryUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateVoidFinalize, nullptr); + aggr.destructor = AggregateFunction::StateDestroy; + return aggr; +} + +template > +AggregateFunction GetTypedModeFunction(const LogicalType &type) { + using STATE = ModeState; + using OP = ModeFunction; + auto func = + AggregateFunction::UnaryAggregateDestructor( + type, type); + func.window = OP::template Window; + return func; +} + +AggregateFunction GetModeAggregate(const LogicalType &type) { + switch (type.InternalType()) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::INT8: + return GetTypedModeFunction(type); + case PhysicalType::UINT8: + return GetTypedModeFunction(type); + case PhysicalType::INT16: + return GetTypedModeFunction(type); + case PhysicalType::UINT16: + return GetTypedModeFunction(type); + case PhysicalType::INT32: + return GetTypedModeFunction(type); + case PhysicalType::UINT32: + return GetTypedModeFunction(type); + case PhysicalType::INT64: + return GetTypedModeFunction(type); + 
case PhysicalType::UINT64: + return GetTypedModeFunction(type); + case PhysicalType::INT128: + return GetTypedModeFunction(type); + case PhysicalType::UINT128: + return GetTypedModeFunction(type); + case PhysicalType::FLOAT: + return GetTypedModeFunction(type); + case PhysicalType::DOUBLE: + return GetTypedModeFunction(type); + case PhysicalType::INTERVAL: + return GetTypedModeFunction(type); + case PhysicalType::VARCHAR: + return GetTypedModeFunction(type); +#endif + default: + return GetFallbackModeFunction(type); + } +} + +unique_ptr BindModeAggregate(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetModeAggregate(arguments[0]->return_type); + function.name = "mode"; + return nullptr; +} + +AggregateFunctionSet ModeFun::GetFunctions() { + AggregateFunctionSet mode("mode"); + mode.AddFunction(AggregateFunction({LogicalTypeId::ANY}, LogicalTypeId::ANY, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, BindModeAggregate)); + return mode; +} + +//===--------------------------------------------------------------------===// +// Entropy +//===--------------------------------------------------------------------===// +template +static double FinalizeEntropy(STATE &state) { + if (!state.frequency_map) { + return 0; + } + double count = static_cast(state.count); + double entropy = 0; + for (auto &val : *state.frequency_map) { + double val_sec = static_cast(val.second.count); + entropy += (val_sec / count) * log2(count / val_sec); + } + return entropy; +} + +template +struct EntropyFunction : TypedModeFunction { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + target = FinalizeEntropy(state); + } +}; + +template +struct EntropyFallbackFunction : BaseModeFunction { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + target = FinalizeEntropy(state); + } +}; + +template > +AggregateFunction GetTypedEntropyFunction(const LogicalType &type) { + using STATE = ModeState; + using OP = EntropyFunction; + auto func = + AggregateFunction::UnaryAggregateDestructor( + type, LogicalType::DOUBLE); + func.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return func; +} + +AggregateFunction GetFallbackEntropyFunction(const LogicalType &type) { + using STATE = ModeState; + using OP = EntropyFallbackFunction; + AggregateFunction func({type}, LogicalType::DOUBLE, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + AggregateSortKeyHelpers::UnaryUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, nullptr); + func.destructor = AggregateFunction::StateDestroy; + func.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return func; +} + +AggregateFunction GetEntropyFunction(const LogicalType &type) { + switch (type.InternalType()) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::UINT16: + return GetTypedEntropyFunction(type); + case PhysicalType::UINT32: + return GetTypedEntropyFunction(type); + case PhysicalType::UINT64: + return GetTypedEntropyFunction(type); + case PhysicalType::INT16: + return GetTypedEntropyFunction(type); + case PhysicalType::INT32: + return GetTypedEntropyFunction(type); + case PhysicalType::INT64: + return GetTypedEntropyFunction(type); + case PhysicalType::FLOAT: + return GetTypedEntropyFunction(type); + case PhysicalType::DOUBLE: + return GetTypedEntropyFunction(type); + case PhysicalType::VARCHAR: + return GetTypedEntropyFunction(type); +#endif + default: + return 
GetFallbackEntropyFunction(type); + } +} + +unique_ptr BindEntropyAggregate(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetEntropyFunction(arguments[0]->return_type); + function.name = "entropy"; + return nullptr; +} + +AggregateFunctionSet EntropyFun::GetFunctions() { + AggregateFunctionSet entropy("entropy"); + entropy.AddFunction(AggregateFunction({LogicalTypeId::ANY}, LogicalType::DOUBLE, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, BindEntropyAggregate)); + return entropy; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp new file mode 100644 index 00000000..98ca4d5b --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp @@ -0,0 +1,873 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/common/enums/quantile_enum.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/abs.hpp" +#include "core_functions/aggregate/quantile_state.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/queue.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/function/aggregate/sort_key_helpers.hpp" + +namespace duckdb { + +template +struct IndirectLess { + inline explicit IndirectLess(const INPUT_TYPE *inputs_p) : inputs(inputs_p) { + } + + inline bool operator()(const idx_t &lhi, const idx_t &rhi) const { + return inputs[lhi] < inputs[rhi]; + } + + const INPUT_TYPE *inputs; +}; + +template +static inline T QuantileAbs(const T &t) { + return AbsOperator::Operation(t); +} + +template <> +inline Value QuantileAbs(const Value &v) { + const auto &type = v.type(); + switch (type.id()) { + case LogicalTypeId::DECIMAL: { + const auto integral = IntegralValue::Get(v); + const auto width = DecimalType::GetWidth(type); + const auto scale = DecimalType::GetScale(type); + switch (type.InternalType()) { + case PhysicalType::INT16: + return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); + case PhysicalType::INT32: + return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); + case PhysicalType::INT64: + return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); + case PhysicalType::INT128: + return Value::DECIMAL(QuantileAbs(integral), width, scale); + default: + throw InternalException("Unknown DECIMAL type"); + } + } + default: + return Value::DOUBLE(QuantileAbs(v.GetValue())); + } +} + +//===--------------------------------------------------------------------===// +// Quantile Bind Data +//===--------------------------------------------------------------------===// +QuantileBindData::QuantileBindData() { +} + +QuantileBindData::QuantileBindData(const Value &quantile_p) + : quantiles(1, QuantileValue(QuantileAbs(quantile_p))), order(1, 0), desc(quantile_p < 0) { +} + +QuantileBindData::QuantileBindData(const vector &quantiles_p) { + vector normalised; + size_t pos = 0; + size_t neg = 0; + for (idx_t i = 0; i < quantiles_p.size(); ++i) { + const auto &q = quantiles_p[i]; + pos += (q > 0); + neg += (q < 0); + normalised.emplace_back(QuantileAbs(q)); + order.push_back(i); + } + if (pos && neg) { + throw BinderException("QUANTILE parameters must have consistent signs"); + } + desc = (neg > 0); 
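// [Editor's note, not part of the upstream patch] Sign handling above: a negative
// parameter requests the quantile of the descending sort, so only the absolute
// values are stored and `desc` records the direction. For example,
// quantile_cont(x, -0.25) is the 0.25 quantile of x sorted descending, which lands
// on the same position as the 0.75 quantile ascending. The sort below orders the
// quantile *indices* by ascending quantile value, e.g. for input {0.9, 0.1, 0.5}
// the quantiles vector keeps the input order {0.9, 0.1, 0.5} while order becomes
// {1, 2, 0}, so the list finalizer can walk the sorted data left to right and
// reuse each interpolation's lower bound as the starting point for the next one.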
+ + IndirectLess lt(normalised.data()); + std::sort(order.begin(), order.end(), lt); + + for (const auto &q : normalised) { + quantiles.emplace_back(QuantileValue(q)); + } +} + +QuantileBindData::QuantileBindData(const QuantileBindData &other) : order(other.order), desc(other.desc) { + for (const auto &q : other.quantiles) { + quantiles.emplace_back(q); + } +} + +unique_ptr QuantileBindData::Copy() const { + return make_uniq(*this); +} + +bool QuantileBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return desc == other.desc && quantiles == other.quantiles && order == other.order; +} + +void QuantileBindData::Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto &bind_data = bind_data_p->Cast(); + vector raw; + for (const auto &q : bind_data.quantiles) { + raw.emplace_back(q.val); + } + serializer.WriteProperty(100, "quantiles", raw); + serializer.WriteProperty(101, "order", bind_data.order); + serializer.WriteProperty(102, "desc", bind_data.desc); +} + +unique_ptr QuantileBindData::Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto result = make_uniq(); + vector raw; + deserializer.ReadProperty(100, "quantiles", raw); + deserializer.ReadProperty(101, "order", result->order); + deserializer.ReadProperty(102, "desc", result->desc); + QuantileSerializationType deserialization_type; + deserializer.ReadPropertyWithExplicitDefault(103, "quantile_type", deserialization_type, + QuantileSerializationType::NON_DECIMAL); + + if (deserialization_type != QuantileSerializationType::NON_DECIMAL) { + deserializer.ReadDeletedProperty(104, "logical_type"); + } + + for (const auto &r : raw) { + result->quantiles.emplace_back(QuantileValue(r)); + } + return std::move(result); +} + +//===--------------------------------------------------------------------===// +// Cast Interpolation +//===--------------------------------------------------------------------===// +template <> +interval_t CastInterpolation::Cast(const dtime_t &src, Vector &result) { + return {0, 0, src.micros}; +} + +template <> +double CastInterpolation::Interpolate(const double &lo, const double d, const double &hi) { + return lo * (1.0 - d) + hi * d; +} + +template <> +dtime_t CastInterpolation::Interpolate(const dtime_t &lo, const double d, const dtime_t &hi) { + return dtime_t(std::llround(static_cast(lo.micros) * (1.0 - d) + static_cast(hi.micros) * d)); +} + +template <> +timestamp_t CastInterpolation::Interpolate(const timestamp_t &lo, const double d, const timestamp_t &hi) { + return timestamp_t(std::llround(static_cast(lo.value) * (1.0 - d) + static_cast(hi.value) * d)); +} + +template <> +hugeint_t CastInterpolation::Interpolate(const hugeint_t &lo, const double d, const hugeint_t &hi) { + return Hugeint::Convert(Interpolate(Hugeint::Cast(lo), d, Hugeint::Cast(hi))); +} + +static interval_t MultiplyByDouble(const interval_t &i, const double &d) { // NOLINT + D_ASSERT(d >= 0 && d <= 1); + return Interval::FromMicro(std::llround(static_cast(Interval::GetMicro(i)) * d)); +} + +inline interval_t operator+(const interval_t &lhs, const interval_t &rhs) { + return Interval::FromMicro(Interval::GetMicro(lhs) + Interval::GetMicro(rhs)); +} + +inline interval_t operator-(const interval_t &lhs, const interval_t &rhs) { + return Interval::FromMicro(Interval::GetMicro(lhs) - Interval::GetMicro(rhs)); +} + +template <> +interval_t CastInterpolation::Interpolate(const interval_t &lo, const double d, const interval_t &hi) { + 
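// [Editor's note, not part of the upstream patch] All the Interpolate
// specializations above implement the same linear blend, result = lo + (hi - lo) * d
// with d in [0, 1]; this one just routes the arithmetic through microseconds via
// MultiplyByDouble, since interval_t has no native multiplication by double. A
// quick worked example: for lo = 1 day, hi = 3 days, d = 0.5, GetMicro gives
// delta = 172,800,000,000 us, MultiplyByDouble halves that to 86,400,000,000 us,
// and adding lo back yields 2 days.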
const interval_t delta = hi - lo; + return lo + MultiplyByDouble(delta, d); +} + +template <> +string_t CastInterpolation::Cast(const string_t &src, Vector &result) { + return StringVector::AddStringOrBlob(result, src); +} + +//===--------------------------------------------------------------------===// +// Scalar Quantile +//===--------------------------------------------------------------------===// +template +struct QuantileScalarOperation : public QuantileOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; + } + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->Cast(); + D_ASSERT(bind_data.quantiles.size() == 1); + Interpolator interp(bind_data.quantiles[0], state.v.size(), bind_data.desc); + target = interp.template Operation(state.v.data(), finalize_data.result); + } + + template + static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, + const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result, + idx_t ridx) { + auto &state = *reinterpret_cast(l_state); + auto gstate = reinterpret_cast(g_state); + + auto &data = state.GetOrCreateWindowCursor(partition); + const auto &fmask = partition.filter_mask; + + QuantileIncluded included(fmask, data); + const auto n = FrameSize(included, frames); + + D_ASSERT(aggr_input_data.bind_data); + auto &bind_data = aggr_input_data.bind_data->Cast(); + + auto rdata = FlatVector::GetData(result); + auto &rmask = FlatVector::Validity(result); + + if (!n) { + rmask.Set(ridx, false); + return; + } + + const auto &quantile = bind_data.quantiles[0]; + if (gstate && gstate->HasTree()) { + rdata[ridx] = gstate->GetWindowState().template WindowScalar(data, frames, n, result, + quantile); + } else { + auto &window_state = state.GetOrCreateWindowState(); + + // Update the skip list + window_state.UpdateSkip(data, frames, included); + + // Find the position(s) needed + rdata[ridx] = window_state.template WindowScalar(data, frames, n, result, quantile); + + // Save the previous state for next time + window_state.prevs = frames; + } + } +}; + +struct QuantileScalarFallback : QuantileOperation { + template + static void Execute(STATE &state, const INPUT_TYPE &key, AggregateInputData &input_data) { + state.AddElement(key, input_data); + } + + template + static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; + } + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->Cast(); + D_ASSERT(bind_data.quantiles.size() == 1); + Interpolator interp(bind_data.quantiles[0], state.v.size(), bind_data.desc); + auto interpolation_result = interp.InterpolateInternal(state.v.data()); + CreateSortKeyHelpers::DecodeSortKey(interpolation_result, finalize_data.result, finalize_data.result_idx, + OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); + } +}; + +//===--------------------------------------------------------------------===// +// Quantile List +//===--------------------------------------------------------------------===// +template +struct QuantileListOperation : QuantileOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; + } + + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = 
finalize_data.input.bind_data->Cast(); + + auto &result = ListVector::GetEntry(finalize_data.result); + auto ridx = ListVector::GetListSize(finalize_data.result); + ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); + auto rdata = FlatVector::GetData(result); + + auto v_t = state.v.data(); + D_ASSERT(v_t); + + auto &entry = target; + entry.offset = ridx; + idx_t lower = 0; + for (const auto &q : bind_data.order) { + const auto &quantile = bind_data.quantiles[q]; + Interpolator interp(quantile, state.v.size(), bind_data.desc); + interp.begin = lower; + rdata[ridx + q] = interp.template Operation(v_t, result); + lower = interp.FRN; + } + entry.length = bind_data.quantiles.size(); + + ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); + } + + template + static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, + const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &list, + idx_t lidx) { + auto &state = *reinterpret_cast(l_state); + auto gstate = reinterpret_cast(g_state); + + auto &data = state.GetOrCreateWindowCursor(partition); + const auto &fmask = partition.filter_mask; + + D_ASSERT(aggr_input_data.bind_data); + auto &bind_data = aggr_input_data.bind_data->Cast(); + + QuantileIncluded included(fmask, data); + const auto n = FrameSize(included, frames); + + // Result is a constant LIST with a fixed length + if (!n) { + auto &lmask = FlatVector::Validity(list); + lmask.Set(lidx, false); + return; + } + + if (gstate && gstate->HasTree()) { + gstate->GetWindowState().template WindowList(data, frames, n, list, lidx, bind_data); + } else { + auto &window_state = state.GetOrCreateWindowState(); + window_state.UpdateSkip(data, frames, included); + window_state.template WindowList(data, frames, n, list, lidx, bind_data); + window_state.prevs = frames; + } + } +}; + +struct QuantileListFallback : QuantileOperation { + template + static void Execute(STATE &state, const INPUT_TYPE &key, AggregateInputData &input_data) { + state.AddElement(key, input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.v.empty()) { + finalize_data.ReturnNull(); + return; + } + + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->Cast(); + + auto &result = ListVector::GetEntry(finalize_data.result); + auto ridx = ListVector::GetListSize(finalize_data.result); + ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); + + D_ASSERT(state.v.data()); + + auto &entry = target; + entry.offset = ridx; + idx_t lower = 0; + for (const auto &q : bind_data.order) { + const auto &quantile = bind_data.quantiles[q]; + Interpolator interp(quantile, state.v.size(), bind_data.desc); + interp.begin = lower; + auto interpolation_result = interp.InterpolateInternal(state.v.data()); + CreateSortKeyHelpers::DecodeSortKey(interpolation_result, result, ridx + q, + OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); + lower = interp.FRN; + } + entry.length = bind_data.quantiles.size(); + + ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); + } +}; + +//===--------------------------------------------------------------------===// +// Discrete Quantiles +//===--------------------------------------------------------------------===// +template +AggregateFunction GetDiscreteQuantileTemplated(const LogicalType &type) { + switch (type.InternalType()) { +#ifndef 
DUCKDB_SMALLER_BINARY + case PhysicalType::INT8: + return OP::template GetFunction(type); + case PhysicalType::INT16: + return OP::template GetFunction(type); + case PhysicalType::INT32: + return OP::template GetFunction(type); + case PhysicalType::INT64: + return OP::template GetFunction(type); + case PhysicalType::INT128: + return OP::template GetFunction(type); + case PhysicalType::FLOAT: + return OP::template GetFunction(type); + case PhysicalType::DOUBLE: + return OP::template GetFunction(type); + case PhysicalType::INTERVAL: + return OP::template GetFunction(type); + case PhysicalType::VARCHAR: + return OP::template GetFunction(type); +#endif + default: + return OP::GetFallback(type); + } +} + +struct ScalarDiscreteQuantile { + template + static AggregateFunction GetFunction(const LogicalType &type) { + using STATE = QuantileState; + using OP = QuantileScalarOperation; + auto fun = AggregateFunction::UnaryAggregateDestructor(type, type); +#ifndef DUCKDB_SMALLER_BINARY + fun.window = OP::Window; + fun.window_init = OP::WindowInit; +#endif + return fun; + } + + static AggregateFunction GetFallback(const LogicalType &type) { + using STATE = QuantileState; + using OP = QuantileScalarFallback; + + AggregateFunction fun({type}, type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + AggregateSortKeyHelpers::UnaryUpdate, + AggregateFunction::StateCombine, + AggregateFunction::StateVoidFinalize, nullptr, nullptr, + AggregateFunction::StateDestroy); + return fun; + } +}; + +template +static AggregateFunction QuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { // NOLINT + LogicalType result_type = LogicalType::LIST(child_type); + return AggregateFunction( + {input_type}, result_type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, + nullptr, AggregateFunction::StateDestroy); +} + +struct ListDiscreteQuantile { + template + static AggregateFunction GetFunction(const LogicalType &type) { + using STATE = QuantileState; + using OP = QuantileListOperation; + auto fun = QuantileListAggregate(type, type); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; +#ifndef DUCKDB_SMALLER_BINARY + fun.window = OP::template Window; + fun.window_init = OP::template WindowInit; +#endif + return fun; + } + + static AggregateFunction GetFallback(const LogicalType &type) { + using STATE = QuantileState; + using OP = QuantileListFallback; + + AggregateFunction fun({type}, LogicalType::LIST(type), AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + AggregateSortKeyHelpers::UnaryUpdate, + AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, nullptr, nullptr, + AggregateFunction::StateDestroy); + return fun; + } +}; + +AggregateFunction GetDiscreteQuantile(const LogicalType &type) { + return GetDiscreteQuantileTemplated(type); +} + +AggregateFunction GetDiscreteQuantileList(const LogicalType &type) { + return GetDiscreteQuantileTemplated(type); +} + +//===--------------------------------------------------------------------===// +// Continuous Quantiles +//===--------------------------------------------------------------------===// +template +AggregateFunction GetContinuousQuantileTemplated(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return OP::template GetFunction(type, LogicalType::DOUBLE); + case 
LogicalTypeId::SMALLINT: + return OP::template GetFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::SQLNULL: + case LogicalTypeId::INTEGER: + return OP::template GetFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::BIGINT: + return OP::template GetFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::HUGEINT: + return OP::template GetFunction(type, LogicalType::DOUBLE); + case LogicalTypeId::FLOAT: + return OP::template GetFunction(type, type); + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::UHUGEINT: + case LogicalTypeId::DOUBLE: + return OP::template GetFunction(LogicalType::DOUBLE, LogicalType::DOUBLE); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return OP::template GetFunction(type, type); + case PhysicalType::INT32: + return OP::template GetFunction(type, type); + case PhysicalType::INT64: + return OP::template GetFunction(type, type); + case PhysicalType::INT128: + return OP::template GetFunction(type, type); + default: + throw NotImplementedException("Unimplemented continuous quantile DECIMAL aggregate"); + } + case LogicalTypeId::DATE: + return OP::template GetFunction(type, LogicalType::TIMESTAMP); + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + return OP::template GetFunction(type, type); + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return OP::template GetFunction(type, type); + default: + throw NotImplementedException("Unimplemented continuous quantile aggregate"); + } +} + +struct ScalarContinuousQuantile { + template + static AggregateFunction GetFunction(const LogicalType &input_type, const LogicalType &target_type) { + using STATE = QuantileState; + using OP = QuantileScalarOperation; + auto fun = + AggregateFunction::UnaryAggregateDestructor(input_type, target_type); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; +#ifndef DUCKDB_SMALLER_BINARY + fun.window = OP::template Window; + fun.window_init = OP::template WindowInit; +#endif + return fun; + } +}; + +struct ListContinuousQuantile { + template + static AggregateFunction GetFunction(const LogicalType &input_type, const LogicalType &target_type) { + using STATE = QuantileState; + using OP = QuantileListOperation; + auto fun = QuantileListAggregate(input_type, target_type); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; +#ifndef DUCKDB_SMALLER_BINARY + fun.window = OP::template Window; + fun.window_init = OP::template WindowInit; +#endif + return fun; + } +}; + +AggregateFunction GetContinuousQuantile(const LogicalType &type) { + return GetContinuousQuantileTemplated(type); +} + +AggregateFunction GetContinuousQuantileList(const LogicalType &type) { + return GetContinuousQuantileTemplated(type); +} + +//===--------------------------------------------------------------------===// +// Quantile binding +//===--------------------------------------------------------------------===// +static const Value &CheckQuantile(const Value &quantile_val) { + if (quantile_val.IsNull()) { + throw BinderException("QUANTILE parameter cannot be NULL"); + } + auto quantile = quantile_val.GetValue(); + if (quantile < -1 || quantile > 1) { + throw BinderException("QUANTILE can only take parameters in the range [-1, 1]"); + } + if (Value::IsNan(quantile)) { + throw 
BinderException("QUANTILE parameter cannot be NaN"); + } + + return quantile_val; +} + +unique_ptr BindQuantile(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + if (arguments.size() < 2) { + throw BinderException("QUANTILE requires a range argument between [0, 1]"); + } + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw BinderException("QUANTILE can only take constant parameters"); + } + Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + if (quantile_val.IsNull()) { + throw BinderException("QUANTILE argument must not be NULL"); + } + vector quantiles; + switch (quantile_val.type().id()) { + case LogicalTypeId::LIST: + for (const auto &element_val : ListValue::GetChildren(quantile_val)) { + quantiles.push_back(CheckQuantile(element_val)); + } + break; + case LogicalTypeId::ARRAY: + for (const auto &element_val : ArrayValue::GetChildren(quantile_val)) { + quantiles.push_back(CheckQuantile(element_val)); + } + break; + default: + quantiles.push_back(CheckQuantile(quantile_val)); + break; + } + + Function::EraseArgument(function, arguments, arguments.size() - 1); + return make_uniq(quantiles); +} + +//===--------------------------------------------------------------------===// +// Function definitions +//===--------------------------------------------------------------------===// +static bool CanInterpolate(const LogicalType &type) { + if (type.HasAlias()) { + return false; + } + switch (type.id()) { + case LogicalTypeId::DECIMAL: + case LogicalTypeId::SQLNULL: + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::BIGINT: + case LogicalTypeId::UHUGEINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DATE: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return true; + default: + return false; + } +} + +struct MedianFunction { + static AggregateFunction GetAggregate(const LogicalType &type) { + auto fun = CanInterpolate(type) ? 
GetContinuousQuantile(type) : GetDiscreteQuantile(type); + fun.name = "median"; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = Deserialize; + return fun; + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto bind_data = QuantileBindData::Deserialize(deserializer, function); + + auto &input_type = function.arguments[0]; + function = GetAggregate(input_type); + return bind_data; + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetAggregate(arguments[0]->return_type); + return make_uniq(Value::DECIMAL(int16_t(5), 2, 1)); + } +}; + +struct DiscreteQuantileListFunction { + static AggregateFunction GetAggregate(const LogicalType &type) { + auto fun = GetDiscreteQuantileList(type); + fun.name = "quantile_disc"; + fun.bind = Bind; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = Deserialize; + // temporarily push an argument so we can bind the actual quantile + fun.arguments.emplace_back(LogicalType::LIST(LogicalType::DOUBLE)); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto bind_data = QuantileBindData::Deserialize(deserializer, function); + + auto &input_type = function.arguments[0]; + function = GetAggregate(input_type); + return bind_data; + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetAggregate(arguments[0]->return_type); + return BindQuantile(context, function, arguments); + } +}; + +struct DiscreteQuantileFunction { + static AggregateFunction GetAggregate(const LogicalType &type) { + auto fun = GetDiscreteQuantile(type); + fun.name = "quantile_disc"; + fun.bind = Bind; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = Deserialize; + // temporarily push an argument so we can bind the actual quantile + fun.arguments.emplace_back(LogicalType::DOUBLE); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto bind_data = QuantileBindData::Deserialize(deserializer, function); + auto &quantile_data = bind_data->Cast(); + + auto &input_type = function.arguments[0]; + if (quantile_data.quantiles.size() == 1) { + function = GetAggregate(input_type); + } else { + function = DiscreteQuantileListFunction::GetAggregate(input_type); + } + return bind_data; + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetAggregate(arguments[0]->return_type); + return BindQuantile(context, function, arguments); + } +}; + +struct ContinuousQuantileFunction { + static AggregateFunction GetAggregate(const LogicalType &type) { + auto fun = GetContinuousQuantile(type); + fun.name = "quantile_cont"; + fun.bind = Bind; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = Deserialize; + // temporarily push an argument so we can bind the actual quantile + fun.arguments.emplace_back(LogicalType::DOUBLE); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto bind_data = QuantileBindData::Deserialize(deserializer, function); + + auto &input_type = function.arguments[0]; + function = 
GetAggregate(input_type); + return bind_data; + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetAggregate(function.arguments[0].id() == LogicalTypeId::DECIMAL ? arguments[0]->return_type + : function.arguments[0]); + return BindQuantile(context, function, arguments); + } +}; + +struct ContinuousQuantileListFunction { + static AggregateFunction GetAggregate(const LogicalType &type) { + auto fun = GetContinuousQuantileList(type); + fun.name = "quantile_cont"; + fun.bind = Bind; + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = Deserialize; + // temporarily push an argument so we can bind the actual quantile + auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); + fun.arguments.push_back(list_of_double); + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto bind_data = QuantileBindData::Deserialize(deserializer, function); + + auto &input_type = function.arguments[0]; + function = GetAggregate(input_type); + return bind_data; + } + + static unique_ptr Bind(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + function = GetAggregate(function.arguments[0].id() == LogicalTypeId::DECIMAL ? arguments[0]->return_type + : function.arguments[0]); + return BindQuantile(context, function, arguments); + } +}; + +template +AggregateFunction EmptyQuantileFunction(LogicalType input, LogicalType result, const LogicalType &extra_arg) { + AggregateFunction fun({std::move(input)}, std::move(result), nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + OP::Bind); + if (extra_arg.id() != LogicalTypeId::INVALID) { + fun.arguments.push_back(extra_arg); + } + fun.serialize = QuantileBindData::Serialize; + fun.deserialize = OP::Deserialize; + fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + return fun; +} + +AggregateFunctionSet MedianFun::GetFunctions() { + AggregateFunctionSet set("median"); + set.AddFunction(EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, LogicalTypeId::INVALID)); + return set; +} + +AggregateFunctionSet QuantileDiscFun::GetFunctions() { + AggregateFunctionSet set("quantile_disc"); + set.AddFunction( + EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, LogicalType::DOUBLE)); + set.AddFunction(EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, + LogicalType::LIST(LogicalType::DOUBLE))); + // this function is here for deserialization - it cannot be called by users + set.AddFunction( + EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, LogicalType::INVALID)); + return set; +} + +vector GetContinuousQuantileTypes() { + return {LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, LogicalType::BIGINT, + LogicalType::HUGEINT, LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE, + LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ}; +} + +AggregateFunctionSet QuantileContFun::GetFunctions() { + AggregateFunctionSet quantile_cont("quantile_cont"); + quantile_cont.AddFunction(EmptyQuantileFunction( + LogicalTypeId::DECIMAL, LogicalTypeId::DECIMAL, LogicalType::DOUBLE)); + quantile_cont.AddFunction(EmptyQuantileFunction( + LogicalTypeId::DECIMAL, LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE))); + for (const auto &type : GetContinuousQuantileTypes()) { + quantile_cont.AddFunction(EmptyQuantileFunction(type, type, 
LogicalType::DOUBLE)); + quantile_cont.AddFunction( + EmptyQuantileFunction(type, type, LogicalType::LIST(LogicalType::DOUBLE))); + } + return quantile_cont; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp new file mode 100644 index 00000000..8c332500 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp @@ -0,0 +1,449 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/reservoir_sample.hpp" +#include "core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/common/queue.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +#include +#include + +namespace duckdb { + +template +struct ReservoirQuantileState { + T *v; + idx_t len; + idx_t pos; + BaseReservoirSampling *r_samp; + + void Resize(idx_t new_len) { + if (new_len <= len) { + return; + } + T *old_v = v; + v = (T *)realloc(v, new_len * sizeof(T)); + if (!v) { + free(old_v); + throw InternalException("Memory allocation failure"); + } + len = new_len; + } + + void ReplaceElement(T &input) { + v[r_samp->min_weighted_entry_index] = input; + r_samp->ReplaceElement(); + } + + void FillReservoir(idx_t sample_size, T element) { + if (pos < sample_size) { + v[pos++] = element; + r_samp->InitializeReservoirWeights(pos, len); + } else { + D_ASSERT(r_samp->next_index_to_sample >= r_samp->num_entries_to_skip_b4_next_sample); + if (r_samp->next_index_to_sample == r_samp->num_entries_to_skip_b4_next_sample) { + ReplaceElement(element); + } + } + } +}; + +struct ReservoirQuantileBindData : public FunctionData { + ReservoirQuantileBindData() { + } + ReservoirQuantileBindData(double quantile_p, idx_t sample_size_p) + : quantiles(1, quantile_p), sample_size(sample_size_p) { + } + + ReservoirQuantileBindData(vector quantiles_p, idx_t sample_size_p) + : quantiles(std::move(quantiles_p)), sample_size(sample_size_p) { + } + + unique_ptr Copy() const override { + return make_uniq(quantiles, sample_size); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return quantiles == other.quantiles && sample_size == other.sample_size; + } + + static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty(100, "quantiles", bind_data.quantiles); + serializer.WriteProperty(101, "sample_size", bind_data.sample_size); + } + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { + auto result = make_uniq(); + deserializer.ReadProperty(100, "quantiles", result->quantiles); + deserializer.ReadProperty(101, "sample_size", result->sample_size); + return std::move(result); + } + + vector quantiles; + idx_t sample_size; +}; + +struct ReservoirQuantileOperation { + template + static void Initialize(STATE &state) { + state.v = nullptr; + state.len = 0; + state.pos = 0; + state.r_samp = nullptr; + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { + auto &bind_data = 
unary_input.input.bind_data->template Cast(); + if (state.pos == 0) { + state.Resize(bind_data.sample_size); + } + if (!state.r_samp) { + state.r_samp = new BaseReservoirSampling(); + } + D_ASSERT(state.v); + state.FillReservoir(bind_data.sample_size, input); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (source.pos == 0) { + return; + } + if (target.pos == 0) { + target.Resize(source.len); + } + if (!target.r_samp) { + target.r_samp = new BaseReservoirSampling(); + } + for (idx_t src_idx = 0; src_idx < source.pos; src_idx++) { + target.FillReservoir(target.len, source.v[src_idx]); + } + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + if (state.v) { + free(state.v); + state.v = nullptr; + } + if (state.r_samp) { + delete state.r_samp; + state.r_samp = nullptr; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +struct ReservoirQuantileScalarOperation : public ReservoirQuantileOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.pos == 0) { + finalize_data.ReturnNull(); + return; + } + D_ASSERT(state.v); + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->template Cast(); + auto v_t = state.v; + D_ASSERT(bind_data.quantiles.size() == 1); + auto offset = (idx_t)((double)(state.pos - 1) * bind_data.quantiles[0]); + std::nth_element(v_t, v_t + offset, v_t + state.pos); + target = v_t[offset]; + } +}; + +AggregateFunction GetReservoirQuantileAggregateFunction(PhysicalType type) { + switch (type) { + case PhysicalType::INT8: + return AggregateFunction::UnaryAggregateDestructor, int8_t, int8_t, + ReservoirQuantileScalarOperation>(LogicalType::TINYINT, + LogicalType::TINYINT); + + case PhysicalType::INT16: + return AggregateFunction::UnaryAggregateDestructor, int16_t, int16_t, + ReservoirQuantileScalarOperation>(LogicalType::SMALLINT, + LogicalType::SMALLINT); + + case PhysicalType::INT32: + return AggregateFunction::UnaryAggregateDestructor, int32_t, int32_t, + ReservoirQuantileScalarOperation>(LogicalType::INTEGER, + LogicalType::INTEGER); + + case PhysicalType::INT64: + return AggregateFunction::UnaryAggregateDestructor, int64_t, int64_t, + ReservoirQuantileScalarOperation>(LogicalType::BIGINT, + LogicalType::BIGINT); + + case PhysicalType::INT128: + return AggregateFunction::UnaryAggregateDestructor, hugeint_t, hugeint_t, + ReservoirQuantileScalarOperation>(LogicalType::HUGEINT, + LogicalType::HUGEINT); + case PhysicalType::FLOAT: + return AggregateFunction::UnaryAggregateDestructor, float, float, + ReservoirQuantileScalarOperation>(LogicalType::FLOAT, + LogicalType::FLOAT); + case PhysicalType::DOUBLE: + return AggregateFunction::UnaryAggregateDestructor, double, double, + ReservoirQuantileScalarOperation>(LogicalType::DOUBLE, + LogicalType::DOUBLE); + default: + throw InternalException("Unimplemented reservoir quantile aggregate"); + } +} + +template +struct ReservoirQuantileListOperation : public ReservoirQuantileOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.pos == 0) { + finalize_data.ReturnNull(); + return; + } + + D_ASSERT(finalize_data.input.bind_data); + auto &bind_data = finalize_data.input.bind_data->template Cast(); + + auto &result = ListVector::GetEntry(finalize_data.result); + auto ridx = ListVector::GetListSize(finalize_data.result); + ListVector::Reserve(finalize_data.result, 
ridx + bind_data.quantiles.size()); + auto rdata = FlatVector::GetData(result); + + auto v_t = state.v; + D_ASSERT(v_t); + + auto &entry = target; + entry.offset = ridx; + entry.length = bind_data.quantiles.size(); + for (size_t q = 0; q < entry.length; ++q) { + const auto &quantile = bind_data.quantiles[q]; + auto offset = (idx_t)((double)(state.pos - 1) * quantile); + std::nth_element(v_t, v_t + offset, v_t + state.pos); + rdata[ridx + q] = v_t[offset]; + } + + ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); + } +}; + +template +static AggregateFunction ReservoirQuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { + LogicalType result_type = LogicalType::LIST(child_type); + return AggregateFunction( + {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, + AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, + AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, + nullptr, AggregateFunction::StateDestroy); +} + +template +AggregateFunction GetTypedReservoirQuantileListAggregateFunction(const LogicalType &type) { + using STATE = ReservoirQuantileState; + using OP = ReservoirQuantileListOperation; + auto fun = ReservoirQuantileListAggregate(type, type); + return fun; +} + +AggregateFunction GetReservoirQuantileListAggregateFunction(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::SMALLINT: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::INTEGER: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::BIGINT: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::HUGEINT: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::FLOAT: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::DOUBLE: + return GetTypedReservoirQuantileListAggregateFunction(type); + case LogicalTypeId::DECIMAL: + switch (type.InternalType()) { + case PhysicalType::INT16: + return GetTypedReservoirQuantileListAggregateFunction(type); + case PhysicalType::INT32: + return GetTypedReservoirQuantileListAggregateFunction(type); + case PhysicalType::INT64: + return GetTypedReservoirQuantileListAggregateFunction(type); + case PhysicalType::INT128: + return GetTypedReservoirQuantileListAggregateFunction(type); + default: + throw NotImplementedException("Unimplemented reservoir quantile list aggregate"); + } + default: + // TODO: Add quantitative temporal types + throw NotImplementedException("Unimplemented reservoir quantile list aggregate"); + } +} + +static double CheckReservoirQuantile(const Value &quantile_val) { + if (quantile_val.IsNull()) { + throw BinderException("RESERVOIR_QUANTILE QUANTILE parameter cannot be NULL"); + } + auto quantile = quantile_val.GetValue(); + if (quantile < 0 || quantile > 1) { + throw BinderException("RESERVOIR_QUANTILE can only take parameters in the range [0, 1]"); + } + return quantile; +} + +unique_ptr BindReservoirQuantile(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + D_ASSERT(arguments.size() >= 2); + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw BinderException("RESERVOIR_QUANTILE can only take constant quantile parameters"); + } + Value quantile_val = 
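// For reference: the list variant above reuses one sample buffer for every requested
// quantile. Repeated std::nth_element calls are valid because each call only permutes
// the buffer and leaves the element at the requested offset exact. The same loop,
// reduced to plain vectors (illustrative, not this file's API):
#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<double> SampleQuantiles(std::vector<double> sample, const std::vector<double> &quantiles) {
	std::vector<double> result;
	result.reserve(quantiles.size());
	for (double q : quantiles) { // non-empty sample assumed
		auto offset = static_cast<std::ptrdiff_t>(static_cast<double>(sample.size() - 1) * q);
		std::nth_element(sample.begin(), sample.begin() + offset, sample.end());
		result.push_back(sample[static_cast<size_t>(offset)]);
	}
	return result;
}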
ExpressionExecutor::EvaluateScalar(context, *arguments[1]);
+	vector<double> quantiles;
+	if (quantile_val.type().id() != LogicalTypeId::LIST) {
+		quantiles.push_back(CheckReservoirQuantile(quantile_val));
+	} else {
+		for (const auto &element_val : ListValue::GetChildren(quantile_val)) {
+			quantiles.push_back(CheckReservoirQuantile(element_val));
+		}
+	}
+
+	if (arguments.size() == 2) {
+		// remove the quantile argument so we can use the unary aggregate
+		if (function.arguments.size() == 2) {
+			Function::EraseArgument(function, arguments, arguments.size() - 1);
+		} else {
+			arguments.pop_back();
+		}
+		return make_uniq<ReservoirQuantileBindData>(quantiles, 8192ULL);
+	}
+	if (!arguments[2]->IsFoldable()) {
+		throw BinderException("RESERVOIR_QUANTILE can only take constant sample size parameters");
+	}
+	Value sample_size_val = ExpressionExecutor::EvaluateScalar(context, *arguments[2]);
+	if (sample_size_val.IsNull()) {
+		throw BinderException("Size of the RESERVOIR_QUANTILE sample cannot be NULL");
+	}
+	auto sample_size = sample_size_val.GetValue<int64_t>();
+	if (sample_size <= 0) {
+		throw BinderException("Size of the RESERVOIR_QUANTILE sample must be bigger than 0");
+	}
+
+	// remove the quantile arguments so we can use the unary aggregate
+	if (function.arguments.size() == arguments.size()) {
+		Function::EraseArgument(function, arguments, arguments.size() - 1);
+		Function::EraseArgument(function, arguments, arguments.size() - 1);
+	} else {
+		arguments.pop_back();
+		arguments.pop_back();
+	}
+	return make_uniq<ReservoirQuantileBindData>(quantiles, NumericCast<idx_t>(sample_size));
+}
+
+unique_ptr<FunctionData> BindReservoirQuantileDecimal(ClientContext &context, AggregateFunction &function,
+                                                      vector<unique_ptr<Expression>> &arguments) {
+	function = GetReservoirQuantileAggregateFunction(arguments[0]->return_type.InternalType());
+	auto bind_data = BindReservoirQuantile(context, function, arguments);
+	function.name = "reservoir_quantile";
+	function.serialize = ReservoirQuantileBindData::Serialize;
+	function.deserialize = ReservoirQuantileBindData::Deserialize;
+	return bind_data;
+}
+
+AggregateFunction GetReservoirQuantileAggregate(PhysicalType type) {
+	auto fun = GetReservoirQuantileAggregateFunction(type);
+	fun.bind = BindReservoirQuantile;
+	fun.serialize = ReservoirQuantileBindData::Serialize;
+	fun.deserialize = ReservoirQuantileBindData::Deserialize;
+	// temporarily push an argument so we can bind the actual quantile
+	fun.arguments.emplace_back(LogicalType::DOUBLE);
+	return fun;
+}
+
+unique_ptr<FunctionData> BindReservoirQuantileDecimalList(ClientContext &context, AggregateFunction &function,
+                                                          vector<unique_ptr<Expression>> &arguments) {
+	function = GetReservoirQuantileListAggregateFunction(arguments[0]->return_type);
+	auto bind_data = BindReservoirQuantile(context, function, arguments);
+	function.serialize = ReservoirQuantileBindData::Serialize;
+	function.deserialize = ReservoirQuantileBindData::Deserialize;
+	function.name = "reservoir_quantile";
+	return bind_data;
+}
+
+AggregateFunction GetReservoirQuantileListAggregate(const LogicalType &type) {
+	auto fun = GetReservoirQuantileListAggregateFunction(type);
+	fun.bind = BindReservoirQuantile;
+	fun.serialize = ReservoirQuantileBindData::Serialize;
+	fun.deserialize = ReservoirQuantileBindData::Deserialize;
+	// temporarily push an argument so we can bind the actual quantile
+	auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE);
+	fun.arguments.push_back(list_of_double);
+	return fun;
+}
+
+static void DefineReservoirQuantile(AggregateFunctionSet &set, const LogicalType &type) {
+	// Four versions: type, scalar/list[,
count] + auto fun = GetReservoirQuantileAggregate(type.InternalType()); + set.AddFunction(fun); + + fun.arguments.emplace_back(LogicalType::INTEGER); + set.AddFunction(fun); + + // List variants + fun = GetReservoirQuantileListAggregate(type); + set.AddFunction(fun); + + fun.arguments.emplace_back(LogicalType::INTEGER); + set.AddFunction(fun); +} + +static void GetReservoirQuantileDecimalFunction(AggregateFunctionSet &set, const vector &arguments, + const LogicalType &return_value) { + AggregateFunction fun(arguments, return_value, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + BindReservoirQuantileDecimal); + fun.serialize = ReservoirQuantileBindData::Serialize; + fun.deserialize = ReservoirQuantileBindData::Deserialize; + set.AddFunction(fun); + + fun.arguments.emplace_back(LogicalType::INTEGER); + set.AddFunction(fun); +} + +AggregateFunctionSet ReservoirQuantileFun::GetFunctions() { + AggregateFunctionSet reservoir_quantile; + + // DECIMAL + GetReservoirQuantileDecimalFunction(reservoir_quantile, {LogicalTypeId::DECIMAL, LogicalType::DOUBLE}, + LogicalTypeId::DECIMAL); + GetReservoirQuantileDecimalFunction(reservoir_quantile, + {LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE)}, + LogicalType::LIST(LogicalTypeId::DECIMAL)); + + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::TINYINT); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::SMALLINT); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::INTEGER); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::BIGINT); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::HUGEINT); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::FLOAT); + DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::DOUBLE); + return reservoir_quantile; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp b/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp new file mode 100644 index 00000000..fc184b8b --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp @@ -0,0 +1,410 @@ +#include "duckdb/function/scalar/nested_functions.hpp" +#include "core_functions/aggregate/nested_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/common/types/vector.hpp" +#include "core_functions/aggregate/histogram_helpers.hpp" +#include "core_functions/scalar/generic_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/algorithm.hpp" + +namespace duckdb { + +template +struct HistogramBinState { + using TYPE = T; + + unsafe_vector *bin_boundaries; + unsafe_vector *counts; + + void Initialize() { + bin_boundaries = nullptr; + counts = nullptr; + } + + void Destroy() { + if (bin_boundaries) { + delete bin_boundaries; + bin_boundaries = nullptr; + } + if (counts) { + delete counts; + counts = nullptr; + } + } + + bool IsSet() { + return bin_boundaries; + } + + template + void InitializeBins(Vector &bin_vector, idx_t count, idx_t pos, AggregateInputData &aggr_input) { + bin_boundaries = new unsafe_vector(); + counts = new unsafe_vector(); + UnifiedVectorFormat bin_data; + bin_vector.ToUnifiedFormat(count, bin_data); + auto bin_counts = UnifiedVectorFormat::GetData(bin_data); + auto bin_index = bin_data.sel->get_index(pos); + auto bin_list = bin_counts[bin_index]; + if (!bin_data.validity.RowIsValid(bin_index)) { + throw BinderException("Histogram bin list cannot be NULL"); + } + + 
auto &bin_child = ListVector::GetEntry(bin_vector); + auto bin_count = ListVector::GetListSize(bin_vector); + UnifiedVectorFormat bin_child_data; + auto extra_state = OP::CreateExtraState(bin_count); + OP::PrepareData(bin_child, bin_count, extra_state, bin_child_data); + + bin_boundaries->reserve(bin_list.length); + for (idx_t i = 0; i < bin_list.length; i++) { + auto bin_child_idx = bin_child_data.sel->get_index(bin_list.offset + i); + if (!bin_child_data.validity.RowIsValid(bin_child_idx)) { + throw BinderException("Histogram bin entry cannot be NULL"); + } + bin_boundaries->push_back(OP::template ExtractValue(bin_child_data, bin_list.offset + i, aggr_input)); + } + // sort the bin boundaries + std::sort(bin_boundaries->begin(), bin_boundaries->end()); + // ensure there are no duplicate bin boundaries + for (idx_t i = 1; i < bin_boundaries->size(); i++) { + if (Equals::Operation((*bin_boundaries)[i - 1], (*bin_boundaries)[i])) { + bin_boundaries->erase_at(i); + i--; + } + } + + counts->resize(bin_list.length + 1); + } +}; + +struct HistogramBinFunction { + template + static void Initialize(STATE &state) { + state.Initialize(); + } + + template + static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { + state.Destroy(); + } + + static bool IgnoreNull() { + return true; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { + if (!source.bin_boundaries) { + // nothing to combine + return; + } + if (!target.bin_boundaries) { + // target does not have bin boundaries - copy everything over + target.bin_boundaries = new unsafe_vector(); + target.counts = new unsafe_vector(); + *target.bin_boundaries = *source.bin_boundaries; + *target.counts = *source.counts; + } else { + // both source and target have bin boundaries + if (*target.bin_boundaries != *source.bin_boundaries) { + throw NotImplementedException( + "Histogram - cannot combine histograms with different bin boundaries. 
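// For reference: InitializeBins above sorts the user-provided boundaries and drops
// duplicates so each bin has a distinct upper edge, then sizes the counts with one
// extra overflow slot. A sketch of that preparation step using std::unique/erase in
// place of the manual erase_at loop (illustrative, not this file's API):
#include <algorithm>
#include <cstdint>
#include <vector>

void PrepareBoundaries(std::vector<double> &boundaries, std::vector<uint64_t> &counts) {
	std::sort(boundaries.begin(), boundaries.end());
	boundaries.erase(std::unique(boundaries.begin(), boundaries.end()), boundaries.end());
	// one extra slot at the end counts values larger than every boundary
	counts.assign(boundaries.size() + 1, 0);
}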
" + "Bin boundaries must be the same for all histograms within the same group"); + } + if (target.counts->size() != source.counts->size()) { + throw InternalException("Histogram combine - bin boundaries are the same but counts are different"); + } + D_ASSERT(target.counts->size() == source.counts->size()); + for (idx_t bin_idx = 0; bin_idx < target.counts->size(); bin_idx++) { + (*target.counts)[bin_idx] += (*source.counts)[bin_idx]; + } + } + } +}; + +struct HistogramRange { + static constexpr bool EXACT = false; + + template + static idx_t GetBin(T value, const unsafe_vector &bin_boundaries) { + auto entry = std::lower_bound(bin_boundaries.begin(), bin_boundaries.end(), value); + return UnsafeNumericCast(entry - bin_boundaries.begin()); + } +}; + +struct HistogramExact { + static constexpr bool EXACT = true; + + template + static idx_t GetBin(T value, const unsafe_vector &bin_boundaries) { + auto entry = std::lower_bound(bin_boundaries.begin(), bin_boundaries.end(), value); + if (entry == bin_boundaries.end() || !(*entry == value)) { + // entry not found - return last bucket + return bin_boundaries.size(); + } + return UnsafeNumericCast(entry - bin_boundaries.begin()); + } +}; + +template +static void HistogramBinUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, + Vector &state_vector, idx_t count) { + auto &input = inputs[0]; + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + + auto &bin_vector = inputs[1]; + + auto extra_state = OP::CreateExtraState(count); + UnifiedVectorFormat input_data; + OP::PrepareData(input, count, extra_state, input_data); + + auto states = UnifiedVectorFormat::GetData *>(sdata); + auto data = UnifiedVectorFormat::GetData(input_data); + for (idx_t i = 0; i < count; i++) { + auto idx = input_data.sel->get_index(i); + if (!input_data.validity.RowIsValid(idx)) { + continue; + } + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.IsSet()) { + state.template InitializeBins(bin_vector, count, i, aggr_input); + } + auto bin_entry = HIST::template GetBin(data[idx], *state.bin_boundaries); + ++(*state.counts)[bin_entry]; + } +} + +static bool SupportsOtherBucket(const LogicalType &type) { + if (type.HasAlias()) { + return false; + } + switch (type.id()) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DECIMAL: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::UHUGEINT: + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + case LogicalTypeId::DATE: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::VARCHAR: + case LogicalTypeId::BLOB: + case LogicalTypeId::STRUCT: + case LogicalTypeId::LIST: + return true; + default: + return false; + } +} +static Value OtherBucketValue(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::DECIMAL: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::UHUGEINT: + case 
LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + return Value::MaximumValue(type); + case LogicalTypeId::DATE: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + return Value::Infinity(type); + case LogicalTypeId::VARCHAR: + return Value(""); + case LogicalTypeId::BLOB: + return Value::BLOB(""); + case LogicalTypeId::STRUCT: { + // for structs we can set all child members to NULL + auto &child_types = StructType::GetChildTypes(type); + child_list_t child_list; + for (auto &child_type : child_types) { + child_list.push_back(make_pair(child_type.first, Value(child_type.second))); + } + return Value::STRUCT(std::move(child_list)); + } + case LogicalTypeId::LIST: + return Value::LIST(ListType::GetChildType(type), vector()); + default: + throw InternalException("Unsupported type for other bucket"); + } +} + +static void IsHistogramOtherBinFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input_type = args.data[0].GetType(); + if (!SupportsOtherBucket(input_type)) { + result.Reference(Value::BOOLEAN(false)); + return; + } + auto v = OtherBucketValue(input_type); + Vector ref(v); + VectorOperations::NotDistinctFrom(args.data[0], ref, result, args.size()); +} + +template +static void HistogramBinFinalizeFunction(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, + idx_t offset) { + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states = UnifiedVectorFormat::GetData *>(sdata); + + auto &mask = FlatVector::Validity(result); + auto old_len = ListVector::GetListSize(result); + idx_t new_entries = 0; + bool supports_other_bucket = SupportsOtherBucket(MapType::KeyType(result.GetType())); + // figure out how much space we need + for (idx_t i = 0; i < count; i++) { + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.bin_boundaries) { + continue; + } + new_entries += state.bin_boundaries->size(); + if (state.counts->back() > 0 && supports_other_bucket) { + // overflow bucket has entries + new_entries++; + } + } + // reserve space in the list vector + ListVector::Reserve(result, old_len + new_entries); + auto &keys = MapVector::GetKeys(result); + auto &values = MapVector::GetValues(result); + auto list_entries = FlatVector::GetData(result); + auto count_entries = FlatVector::GetData(values); + + idx_t current_offset = old_len; + for (idx_t i = 0; i < count; i++) { + const auto rid = i + offset; + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.bin_boundaries) { + mask.SetInvalid(rid); + continue; + } + + auto &list_entry = list_entries[rid]; + list_entry.offset = current_offset; + for (idx_t bin_idx = 0; bin_idx < state.bin_boundaries->size(); bin_idx++) { + OP::template HistogramFinalize((*state.bin_boundaries)[bin_idx], keys, current_offset); + count_entries[current_offset] = (*state.counts)[bin_idx]; + current_offset++; + } + if (state.counts->back() > 0 && supports_other_bucket) { + // add overflow bucket ("others") + // set bin boundary to NULL for overflow bucket + keys.SetValue(current_offset, OtherBucketValue(keys.GetType())); + count_entries[current_offset] = state.counts->back(); + current_offset++; + } + list_entry.length = current_offset - list_entry.offset; + } + D_ASSERT(current_offset == old_len + new_entries); + ListVector::SetListSize(result, current_offset); + result.Verify(count); +} + +template 
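// For reference: both bin strategies above resolve a value to a bin index with a single
// binary search over the sorted boundaries, treating index boundaries.size() as the
// overflow ("other") bin. A sketch with illustrative names:
#include <algorithm>
#include <cstddef>
#include <vector>

size_t FindBin(double value, const std::vector<double> &boundaries) {
	// first boundary >= value; values beyond every boundary land in the overflow bin
	auto it = std::lower_bound(boundaries.begin(), boundaries.end(), value);
	return static_cast<size_t>(it - boundaries.begin());
}
// e.g. boundaries {10, 20, 30}: 5 -> bin 0, 15 -> bin 1, 30 -> bin 2, 31 -> bin 3 (overflow)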
+static AggregateFunction GetHistogramBinFunction(const LogicalType &type) { + using STATE_TYPE = HistogramBinState; + + const char *function_name = HIST::EXACT ? "histogram_exact" : "histogram"; + + auto struct_type = LogicalType::MAP(type, LogicalType::UBIGINT); + return AggregateFunction( + function_name, {type, LogicalType::LIST(type)}, struct_type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, HistogramBinUpdateFunction, + AggregateFunction::StateCombine, HistogramBinFinalizeFunction, nullptr, + nullptr, AggregateFunction::StateDestroy); +} + +template +AggregateFunction GetHistogramBinFunction(const LogicalType &type) { + if (type.id() == LogicalTypeId::DECIMAL) { + return GetHistogramBinFunction(LogicalType::DOUBLE); + } + switch (type.InternalType()) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::BOOL: + return GetHistogramBinFunction(type); + case PhysicalType::UINT8: + return GetHistogramBinFunction(type); + case PhysicalType::UINT16: + return GetHistogramBinFunction(type); + case PhysicalType::UINT32: + return GetHistogramBinFunction(type); + case PhysicalType::UINT64: + return GetHistogramBinFunction(type); + case PhysicalType::INT8: + return GetHistogramBinFunction(type); + case PhysicalType::INT16: + return GetHistogramBinFunction(type); + case PhysicalType::INT32: + return GetHistogramBinFunction(type); + case PhysicalType::INT64: + return GetHistogramBinFunction(type); + case PhysicalType::FLOAT: + return GetHistogramBinFunction(type); + case PhysicalType::DOUBLE: + return GetHistogramBinFunction(type); + case PhysicalType::VARCHAR: + return GetHistogramBinFunction(type); +#endif + default: + return GetHistogramBinFunction(type); + } +} + +template +unique_ptr HistogramBinBindFunction(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + for (auto &arg : arguments) { + if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + } + + function = GetHistogramBinFunction(arguments[0]->return_type); + return nullptr; +} + +AggregateFunction HistogramFun::BinnedHistogramFunction() { + return AggregateFunction("histogram", {LogicalType::ANY, LogicalType::LIST(LogicalType::ANY)}, LogicalTypeId::MAP, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + HistogramBinBindFunction, nullptr); +} + +AggregateFunction HistogramExactFun::GetFunction() { + return AggregateFunction("histogram_exact", {LogicalType::ANY, LogicalType::LIST(LogicalType::ANY)}, + LogicalTypeId::MAP, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + HistogramBinBindFunction, nullptr); +} + +ScalarFunction IsHistogramOtherBinFun::GetFunction() { + return ScalarFunction("is_histogram_other_bin", {LogicalType::ANY}, LogicalType::BOOLEAN, + IsHistogramOtherBinFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp b/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp new file mode 100644 index 00000000..8a736f23 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp @@ -0,0 +1,236 @@ +#include "duckdb/function/scalar/nested_functions.hpp" +#include "core_functions/aggregate/nested_functions.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/string_map_set.hpp" +#include "core_functions/aggregate/histogram_helpers.hpp" +#include "duckdb/common/owning_string_map.hpp" + +namespace duckdb { + +template +struct HistogramFunction { + template + static void Initialize(STATE &state) { + 
state.hist = nullptr; + } + + template + static void Destroy(STATE &state, AggregateInputData &) { + if (state.hist) { + delete state.hist; + } + } + + static bool IgnoreNull() { + return true; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { + if (!source.hist) { + return; + } + if (!target.hist) { + target.hist = MAP_TYPE::CreateEmpty(input_data.allocator); + } + for (auto &entry : *source.hist) { + (*target.hist)[entry.first] += entry.second; + } + } +}; + +template +struct DefaultMapType { + using MAP_TYPE = TYPE; + + static TYPE *CreateEmpty(ArenaAllocator &) { + return new TYPE(); + } +}; + +template +struct StringMapType { + using MAP_TYPE = TYPE; + + static TYPE *CreateEmpty(ArenaAllocator &allocator) { + return new TYPE(allocator); + } +}; + +template +static void HistogramUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, + Vector &state_vector, idx_t count) { + + D_ASSERT(input_count == 1); + + auto &input = inputs[0]; + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + + auto extra_state = OP::CreateExtraState(count); + UnifiedVectorFormat input_data; + OP::PrepareData(input, count, extra_state, input_data); + + auto states = UnifiedVectorFormat::GetData *>(sdata); + auto input_values = UnifiedVectorFormat::GetData(input_data); + for (idx_t i = 0; i < count; i++) { + auto idx = input_data.sel->get_index(i); + if (!input_data.validity.RowIsValid(idx)) { + continue; + } + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.hist) { + state.hist = MAP_TYPE::CreateEmpty(aggr_input.allocator); + } + auto &input_value = input_values[idx]; + ++(*state.hist)[input_value]; + } +} + +template +static void HistogramFinalizeFunction(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, + idx_t offset) { + using HIST_STATE = HistogramAggState; + + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states = UnifiedVectorFormat::GetData(sdata); + + auto &mask = FlatVector::Validity(result); + auto old_len = ListVector::GetListSize(result); + idx_t new_entries = 0; + // figure out how much space we need + for (idx_t i = 0; i < count; i++) { + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.hist) { + continue; + } + new_entries += state.hist->size(); + } + // reserve space in the list vector + ListVector::Reserve(result, old_len + new_entries); + auto &keys = MapVector::GetKeys(result); + auto &values = MapVector::GetValues(result); + auto list_entries = FlatVector::GetData(result); + auto count_entries = FlatVector::GetData(values); + + idx_t current_offset = old_len; + for (idx_t i = 0; i < count; i++) { + const auto rid = i + offset; + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.hist) { + mask.SetInvalid(rid); + continue; + } + + auto &list_entry = list_entries[rid]; + list_entry.offset = current_offset; + for (auto &entry : *state.hist) { + OP::template HistogramFinalize(entry.first, keys, current_offset); + count_entries[current_offset] = entry.second; + current_offset++; + } + list_entry.length = current_offset - list_entry.offset; + } + D_ASSERT(current_offset == old_len + new_entries); + ListVector::SetListSize(result, current_offset); + result.Verify(count); +} + +template +static AggregateFunction GetHistogramFunction(const LogicalType &type) { + using STATE_TYPE = HistogramAggState; + using HIST_FUNC = HistogramFunction; + + auto struct_type = LogicalType::MAP(type, 
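// For reference: the update loop above is plain frequency counting -- one map per
// aggregate state, created lazily on the first non-NULL row. The same pattern in
// isolation, with std::map standing in for DuckDB's arena-backed maps:
#include <cstdint>
#include <map>
#include <string>
#include <vector>

std::map<std::string, uint64_t> BuildHistogram(const std::vector<std::string> &values) {
	std::map<std::string, uint64_t> hist; // ordered map => keys come out sorted
	for (const auto &v : values) {
		++hist[v]; // operator[] value-initializes a missing count to 0
	}
	return hist;
}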
LogicalType::UBIGINT); + return AggregateFunction( + "histogram", {type}, struct_type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, HistogramUpdateFunction, + AggregateFunction::StateCombine, HistogramFinalizeFunction, nullptr, + nullptr, AggregateFunction::StateDestroy); +} + +template +AggregateFunction GetMapTypeInternal(const LogicalType &type) { + return GetHistogramFunction(type); +} + +template +AggregateFunction GetMapType(const LogicalType &type) { + if (IS_ORDERED) { + return GetMapTypeInternal>>(type); + } + return GetMapTypeInternal>>(type); +} + +template +AggregateFunction GetStringMapType(const LogicalType &type) { + if (IS_ORDERED) { + return GetMapTypeInternal>>(type); + } else { + return GetMapTypeInternal>>(type); + } +} + +template +AggregateFunction GetHistogramFunction(const LogicalType &type) { + switch (type.InternalType()) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::BOOL: + return GetMapType(type); + case PhysicalType::UINT8: + return GetMapType(type); + case PhysicalType::UINT16: + return GetMapType(type); + case PhysicalType::UINT32: + return GetMapType(type); + case PhysicalType::UINT64: + return GetMapType(type); + case PhysicalType::INT8: + return GetMapType(type); + case PhysicalType::INT16: + return GetMapType(type); + case PhysicalType::INT32: + return GetMapType(type); + case PhysicalType::INT64: + return GetMapType(type); + case PhysicalType::FLOAT: + return GetMapType(type); + case PhysicalType::DOUBLE: + return GetMapType(type); + case PhysicalType::VARCHAR: + return GetStringMapType(type); +#endif + default: + return GetStringMapType(type); + } +} + +template +unique_ptr HistogramBindFunction(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + + D_ASSERT(arguments.size() == 1); + + if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + function = GetHistogramFunction(arguments[0]->return_type); + return make_uniq(function.return_type); +} + +AggregateFunctionSet HistogramFun::GetFunctions() { + AggregateFunctionSet fun; + AggregateFunction histogram_function("histogram", {LogicalType::ANY}, LogicalTypeId::MAP, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, HistogramBindFunction, nullptr); + fun.AddFunction(HistogramFun::BinnedHistogramFunction()); + fun.AddFunction(histogram_function); + return fun; +} + +AggregateFunction HistogramFun::GetHistogramUnorderedMap(LogicalType &type) { + return AggregateFunction("histogram", {LogicalType::ANY}, LogicalTypeId::MAP, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, HistogramBindFunction, nullptr); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp new file mode 100644 index 00000000..7b23987d --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp @@ -0,0 +1,212 @@ +#include "duckdb/common/pair.hpp" +#include "duckdb/common/types/list_segment.hpp" +#include "core_functions/aggregate/nested_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" + +namespace duckdb { + +struct ListBindData : public FunctionData { + explicit ListBindData(const LogicalType &stype_p); + ~ListBindData() override; + + LogicalType stype; + ListSegmentFunctions functions; + + unique_ptr Copy() const override { + return make_uniq(stype); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + 
return stype == other.stype; + } +}; + +ListBindData::ListBindData(const LogicalType &stype_p) : stype(stype_p) { + // always unnest once because the result vector is of type LIST + auto type = ListType::GetChildType(stype_p); + GetSegmentDataFunctions(functions, type); +} + +ListBindData::~ListBindData() { +} + +struct ListAggState { + LinkedList linked_list; +}; + +struct ListFunction { + template + static void Initialize(STATE &state) { + state.linked_list.total_capacity = 0; + state.linked_list.first_segment = nullptr; + state.linked_list.last_segment = nullptr; + } + static bool IgnoreNull() { + return false; + } +}; + +static void ListUpdateFunction(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, + Vector &state_vector, idx_t count) { + + D_ASSERT(input_count == 1); + auto &input = inputs[0]; + RecursiveUnifiedVectorFormat input_data; + Vector::RecursiveToUnifiedFormat(input, count, input_data); + + UnifiedVectorFormat states_data; + state_vector.ToUnifiedFormat(count, states_data); + auto states = UnifiedVectorFormat::GetData(states_data); + + auto &list_bind_data = aggr_input_data.bind_data->Cast(); + + for (idx_t i = 0; i < count; i++) { + auto &state = *states[states_data.sel->get_index(i)]; + aggr_input_data.allocator.AlignNext(); + list_bind_data.functions.AppendRow(aggr_input_data.allocator, state.linked_list, input_data, i); + } +} + +static void ListAbsorbFunction(Vector &states_vector, Vector &combined, AggregateInputData &aggr_input_data, + idx_t count) { + D_ASSERT(aggr_input_data.combine_type == AggregateCombineType::ALLOW_DESTRUCTIVE); + + UnifiedVectorFormat states_data; + states_vector.ToUnifiedFormat(count, states_data); + auto states_ptr = UnifiedVectorFormat::GetData(states_data); + + auto combined_ptr = FlatVector::GetData(combined); + for (idx_t i = 0; i < count; i++) { + + auto &state = *states_ptr[states_data.sel->get_index(i)]; + if (state.linked_list.total_capacity == 0) { + // NULL, no need to append + // this can happen when adding a FILTER to the grouping, e.g., + // LIST(i) FILTER (WHERE i <> 3) + continue; + } + + if (combined_ptr[i]->linked_list.total_capacity == 0) { + combined_ptr[i]->linked_list = state.linked_list; + continue; + } + + // append the linked list + combined_ptr[i]->linked_list.last_segment->next = state.linked_list.first_segment; + combined_ptr[i]->linked_list.last_segment = state.linked_list.last_segment; + combined_ptr[i]->linked_list.total_capacity += state.linked_list.total_capacity; + } +} + +static void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_data, Vector &result, idx_t count, + idx_t offset) { + + UnifiedVectorFormat states_data; + states_vector.ToUnifiedFormat(count, states_data); + auto states = UnifiedVectorFormat::GetData(states_data); + + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + + auto &mask = FlatVector::Validity(result); + auto result_data = FlatVector::GetData(result); + size_t total_len = ListVector::GetListSize(result); + + auto &list_bind_data = aggr_input_data.bind_data->Cast(); + + // first iterate over all entries and set up the list entries, and get the newly required total length + for (idx_t i = 0; i < count; i++) { + + auto &state = *states[states_data.sel->get_index(i)]; + const auto rid = i + offset; + result_data[rid].offset = total_len; + if (state.linked_list.total_capacity == 0) { + mask.SetInvalid(rid); + result_data[rid].length = 0; + continue; + } + + // set the length and offset of this list in the result vector + auto total_capacity 
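// For reference: the destructive "absorb" path above concatenates two states by pointer
// surgery on the segment list -- O(1), no row copying. A sketch of that splice under
// simplified types (illustrative, not this file's ListSegment machinery):
#include <cstddef>

struct Segment {
	Segment *next;
};
struct SegmentList {
	size_t total_capacity = 0;
	Segment *first_segment = nullptr;
	Segment *last_segment = nullptr;
};

void Absorb(SegmentList &target, SegmentList &source) {
	if (source.total_capacity == 0) {
		return; // empty source state (e.g. every row filtered out)
	}
	if (target.total_capacity == 0) {
		target = source; // target empty: steal the whole list
	} else {
		target.last_segment->next = source.first_segment;
		target.last_segment = source.last_segment;
		target.total_capacity += source.total_capacity;
	}
	source = SegmentList{}; // destructive: the source must not be finalized afterwards
}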
= state.linked_list.total_capacity; + result_data[rid].length = total_capacity; + total_len += total_capacity; + } + + // reserve capacity, then iterate over all entries again and copy over the data to the child vector + ListVector::Reserve(result, total_len); + auto &result_child = ListVector::GetEntry(result); + for (idx_t i = 0; i < count; i++) { + + auto &state = *states[states_data.sel->get_index(i)]; + const auto rid = i + offset; + if (state.linked_list.total_capacity == 0) { + continue; + } + + idx_t current_offset = result_data[rid].offset; + list_bind_data.functions.BuildListVector(state.linked_list, result_child, current_offset); + } + + ListVector::SetListSize(result, total_len); +} + +static void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &aggr_input_data, + idx_t count) { + + // Can we use destructive combining? + if (aggr_input_data.combine_type == AggregateCombineType::ALLOW_DESTRUCTIVE) { + ListAbsorbFunction(states_vector, combined, aggr_input_data, count); + return; + } + + UnifiedVectorFormat states_data; + states_vector.ToUnifiedFormat(count, states_data); + auto states_ptr = UnifiedVectorFormat::GetData(states_data); + auto combined_ptr = FlatVector::GetData(combined); + + auto &list_bind_data = aggr_input_data.bind_data->Cast(); + auto result_type = ListType::GetChildType(list_bind_data.stype); + + for (idx_t i = 0; i < count; i++) { + auto &source = *states_ptr[states_data.sel->get_index(i)]; + auto &target = *combined_ptr[i]; + + const auto entry_count = source.linked_list.total_capacity; + Vector input(result_type, source.linked_list.total_capacity); + list_bind_data.functions.BuildListVector(source.linked_list, input, 0); + + RecursiveUnifiedVectorFormat input_data; + Vector::RecursiveToUnifiedFormat(input, entry_count, input_data); + + for (idx_t entry_idx = 0; entry_idx < entry_count; ++entry_idx) { + aggr_input_data.allocator.AlignNext(); + list_bind_data.functions.AppendRow(aggr_input_data.allocator, target.linked_list, input_data, entry_idx); + } + } +} + +unique_ptr ListBindFunction(ClientContext &context, AggregateFunction &function, + vector> &arguments) { + D_ASSERT(arguments.size() == 1); + D_ASSERT(function.arguments.size() == 1); + + if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { + function.arguments[0] = LogicalTypeId::UNKNOWN; + function.return_type = LogicalType::SQLNULL; + return nullptr; + } + + function.return_type = LogicalType::LIST(arguments[0]->return_type); + return make_uniq(function.return_type); +} + +AggregateFunction ListFun::GetFunction() { + auto func = + AggregateFunction({LogicalType::ANY}, LogicalTypeId::LIST, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, ListUpdateFunction, + ListCombineFunction, ListFinalize, nullptr, ListBindFunction, nullptr, nullptr, nullptr); + + return func; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_avg.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_avg.cpp new file mode 100644 index 00000000..b4b43af2 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_avg.cpp @@ -0,0 +1,64 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "core_functions/aggregate/regression_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { +struct RegrState { + double sum; + size_t 
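// For reference: ListFinalize above materializes in two passes -- compute per-row
// offsets and the total first, reserve once, then copy -- so the child vector is never
// repeatedly resized. The shape of that pattern, reduced to plain vectors:
#include <cstddef>
#include <vector>

struct ListEntrySketch {
	size_t offset;
	size_t length;
};

void TwoPassGather(const std::vector<std::vector<int>> &lists, std::vector<ListEntrySketch> &entries,
                   std::vector<int> &child) {
	size_t total = 0;
	entries.clear();
	for (const auto &l : lists) { // pass 1: offsets and total size
		entries.push_back({total, l.size()});
		total += l.size();
	}
	child.clear();
	child.reserve(total); // single allocation
	for (const auto &l : lists) { // pass 2: copy the data
		child.insert(child.end(), l.begin(), l.end());
	}
}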
count;
+};
+
+struct RegrAvgFunction {
+	template <class STATE>
+	static void Initialize(STATE &state) {
+		state.sum = 0;
+		state.count = 0;
+	}
+
+	template <class STATE, class OP>
+	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
+		target.sum += source.sum;
+		target.count += source.count;
+	}
+
+	template <class T, class STATE>
+	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
+		if (state.count == 0) {
+			finalize_data.ReturnNull();
+		} else {
+			target = state.sum / (double)state.count;
+		}
+	}
+	static bool IgnoreNull() {
+		return true;
+	}
+};
+struct RegrAvgXFunction : RegrAvgFunction {
+	template <class A_TYPE, class B_TYPE, class STATE, class OP>
+	static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) {
+		state.sum += x;
+		state.count++;
+	}
+};
+
+struct RegrAvgYFunction : RegrAvgFunction {
+	template <class A_TYPE, class B_TYPE, class STATE, class OP>
+	static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) {
+		state.sum += y;
+		state.count++;
+	}
+};
+
+AggregateFunction RegrAvgxFun::GetFunction() {
+	return AggregateFunction::BinaryAggregate<RegrState, double, double, double, RegrAvgXFunction>(
+	    LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
+}
+
+AggregateFunction RegrAvgyFun::GetFunction() {
+	return AggregateFunction::BinaryAggregate<RegrState, double, double, double, RegrAvgYFunction>(
+	    LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp
new file mode 100644
index 00000000..9215fcfb
--- /dev/null
+++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp
@@ -0,0 +1,18 @@
+#include "duckdb/common/exception.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "core_functions/aggregate/regression_functions.hpp"
+#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
+#include "core_functions/aggregate/regression/regr_count.hpp"
+#include "duckdb/function/function_set.hpp"
+
+namespace duckdb {
+
+AggregateFunction RegrCountFun::GetFunction() {
+	auto regr_count = AggregateFunction::BinaryAggregate<size_t, double, double, uint32_t, RegrCountFunction>(
+	    LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::UINTEGER);
+	regr_count.name = "regr_count";
+	regr_count.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+	return regr_count;
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_intercept.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_intercept.cpp
new file mode 100644
index 00000000..e727d266
--- /dev/null
+++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_intercept.cpp
@@ -0,0 +1,67 @@
+//!
AVG(y)-REGR_SLOPE(y,x)*AVG(x) + +#include "core_functions/aggregate/regression_functions.hpp" +#include "core_functions/aggregate/regression/regr_slope.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct RegrInterceptState { + size_t count; + double sum_x; + double sum_y; + RegrSlopeState slope; +}; + +struct RegrInterceptOperation { + template + static void Initialize(STATE &state) { + state.count = 0; + state.sum_x = 0; + state.sum_y = 0; + RegrSlopeOperation::Initialize(state.slope); + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + state.count++; + state.sum_x += x; + state.sum_y += y; + RegrSlopeOperation::Operation(state.slope, y, x, idata); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + target.count += source.count; + target.sum_x += source.sum_x; + target.sum_y += source.sum_y; + RegrSlopeOperation::Combine(source.slope, target.slope, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + return; + } + RegrSlopeOperation::Finalize(state.slope, target, finalize_data); + if (Value::IsNan(target)) { + finalize_data.ReturnNull(); + return; + } + auto x_avg = state.sum_x / state.count; + auto y_avg = state.sum_y / state.count; + target = y_avg - target * x_avg; + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction RegrInterceptFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_r2.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_r2.cpp new file mode 100644 index 00000000..ba89a8a6 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_r2.cpp @@ -0,0 +1,72 @@ +// REGR_R2(y, x) +// Returns the coefficient of determination for non-null pairs in a group. 
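// For reference: the intercept Finalize above composes reusable parts -- the fitted
// line passes through (avg(x), avg(y)), so a = avg(y) - slope * avg(x). A naive-moments
// sketch of the same arithmetic (DuckDB's states use numerically stabler recurrences
// than these raw sums; names are illustrative):
#include <cstddef>

struct InterceptSketch {
	double sum_x = 0, sum_y = 0, sum_xx = 0, sum_xy = 0;
	size_t count = 0;

	void Update(double y, double x) {
		sum_x += x;
		sum_y += y;
		sum_xx += x * x;
		sum_xy += x * y;
		count++;
	}

	bool Finalize(double &intercept) const {
		if (count == 0) {
			return false; // NULL on an empty group
		}
		double n = static_cast<double>(count);
		double var_x = sum_xx / n - (sum_x / n) * (sum_x / n);  // VAR_POP(x)
		double cov_xy = sum_xy / n - (sum_x / n) * (sum_y / n); // COVAR_POP(x, y)
		if (var_x == 0) {
			return false; // slope, and thus intercept, undefined
		}
		double slope = cov_xy / var_x; // REGR_SLOPE(y, x)
		intercept = sum_y / n - slope * (sum_x / n);
		return true;
	}
};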
+// It is computed for non-null pairs using the following formula: +// null if var_pop(x) = 0, else +// 1 if var_pop(y) = 0 and var_pop(x) <> 0, else +// power(corr(y,x), 2) + +#include "core_functions/aggregate/algebraic/corr.hpp" +#include "duckdb/function/function_set.hpp" +#include "core_functions/aggregate/regression_functions.hpp" + +namespace duckdb { +struct RegrR2State { + CorrState corr; + StddevState var_pop_x; + StddevState var_pop_y; +}; + +struct RegrR2Operation { + template + static void Initialize(STATE &state) { + CorrOperation::Initialize(state.corr); + STDDevBaseOperation::Initialize(state.var_pop_x); + STDDevBaseOperation::Initialize(state.var_pop_y); + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + CorrOperation::Operation(state.corr, y, x, idata); + STDDevBaseOperation::Execute(state.var_pop_x, x); + STDDevBaseOperation::Execute(state.var_pop_y, y); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + CorrOperation::Combine(source.corr, target.corr, aggr_input_data); + STDDevBaseOperation::Combine(source.var_pop_x, target.var_pop_x, aggr_input_data); + STDDevBaseOperation::Combine(source.var_pop_y, target.var_pop_y, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + auto var_pop_x = state.var_pop_x.count > 1 ? (state.var_pop_x.dsquared / state.var_pop_x.count) : 0; + if (!Value::DoubleIsFinite(var_pop_x)) { + throw OutOfRangeException("VARPOP(X) is out of range!"); + } + if (var_pop_x == 0) { + finalize_data.ReturnNull(); + return; + } + auto var_pop_y = state.var_pop_y.count > 1 ? (state.var_pop_y.dsquared / state.var_pop_y.count) : 0; + if (!Value::DoubleIsFinite(var_pop_y)) { + throw OutOfRangeException("VARPOP(Y) is out of range!"); + } + if (var_pop_y == 0) { + target = 1; + return; + } + CorrOperation::Finalize(state.corr, target, finalize_data); + target = pow(target, 2); + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction RegrR2Fun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_slope.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_slope.cpp new file mode 100644 index 00000000..c5859399 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_slope.cpp @@ -0,0 +1,20 @@ +// REGR_SLOPE(y, x) +// Returns the slope of the linear regression line for non-null pairs in a group. +// It is computed for non-null pairs using the following formula: +// COVAR_POP(x,y) / VAR_POP(x) + +//! Input : Any numeric type +//! 
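// For reference: the three-way rule implemented above, in isolation. Inputs are assumed
// to be the already-computed population variances and Pearson correlation:
#include <cmath>

bool RSquared(double var_pop_x, double var_pop_y, double corr, double &out) {
	if (var_pop_x == 0) {
		return false; // NULL: the regression is undefined without variance in x
	}
	if (var_pop_y == 0) {
		out = 1; // a horizontal line fits constant y perfectly
		return true;
	}
	out = std::pow(corr, 2); // otherwise: squared correlation
	return true;
}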
Output : Double + +#include "core_functions/aggregate/regression/regr_slope.hpp" +#include "duckdb/function/function_set.hpp" +#include "core_functions/aggregate/regression_functions.hpp" + +namespace duckdb { + +AggregateFunction RegrSlopeFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_sxx_syy.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_sxx_syy.cpp new file mode 100644 index 00000000..72202c2b --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_sxx_syy.cpp @@ -0,0 +1,75 @@ +// REGR_SXX(y, x) +// Returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs. +// REGR_SYY(y, x) +// Returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs. + +#include "core_functions/aggregate/regression/regr_count.hpp" +#include "duckdb/function/function_set.hpp" +#include "core_functions/aggregate/regression_functions.hpp" + +namespace duckdb { + +struct RegrSState { + size_t count; + StddevState var_pop; +}; + +struct RegrBaseOperation { + template + static void Initialize(STATE &state) { + RegrCountFunction::Initialize(state.count); + STDDevBaseOperation::Initialize(state.var_pop); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + RegrCountFunction::Combine(source.count, target.count, aggr_input_data); + STDDevBaseOperation::Combine(source.var_pop, target.var_pop, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.var_pop.count == 0) { + finalize_data.ReturnNull(); + return; + } + auto var_pop = state.var_pop.count > 1 ? (state.var_pop.dsquared / state.var_pop.count) : 0; + if (!Value::DoubleIsFinite(var_pop)) { + throw OutOfRangeException("VARPOP is out of range!"); + } + RegrCountFunction::Finalize(state.count, target, finalize_data); + target *= var_pop; + } + + static bool IgnoreNull() { + return true; + } +}; + +struct RegrSXXOperation : RegrBaseOperation { + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + RegrCountFunction::Operation(state.count, y, x, idata); + STDDevBaseOperation::Execute(state.var_pop, x); + } +}; + +struct RegrSYYOperation : RegrBaseOperation { + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + RegrCountFunction::Operation(state.count, y, x, idata); + STDDevBaseOperation::Execute(state.var_pop, y); + } +}; + +AggregateFunction RegrSXXFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +AggregateFunction RegrSYYFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_sxy.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_sxy.cpp new file mode 100644 index 00000000..1ab726e8 --- /dev/null +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_sxy.cpp @@ -0,0 +1,53 @@ +// REGR_SXY(y, x) +// Returns REGR_COUNT(expr1, expr2) * COVAR_POP(expr1, expr2) for non-null pairs. 
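// For reference: the REGR_S** aggregates rely on the identity
// n * VAR_POP(x) = sum_i (x_i - avg(x))^2, and the covariance analogue for SXY.
// A naive-moments sketch of REGR_SXX for doubles (illustrative, not this file's states):
#include <cstddef>
#include <vector>

bool RegrSXXSketch(const std::vector<double> &xs, double &out) {
	if (xs.empty()) {
		return false; // NULL on an empty group
	}
	double n = static_cast<double>(xs.size());
	double sum = 0, sum_sq = 0;
	for (double x : xs) {
		sum += x;
		sum_sq += x * x;
	}
	double var_pop = sum_sq / n - (sum / n) * (sum / n);
	out = n * var_pop; // == the sum of squared deviations from the mean
	return true;
}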
+ +#include "core_functions/aggregate/regression/regr_count.hpp" +#include "core_functions/aggregate/algebraic/covar.hpp" +#include "core_functions/aggregate/regression_functions.hpp" +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct RegrSXyState { + size_t count; + CovarState cov_pop; +}; + +struct RegrSXYOperation { + template + static void Initialize(STATE &state) { + RegrCountFunction::Initialize(state.count); + CovarOperation::Initialize(state.cov_pop); + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + RegrCountFunction::Operation(state.count, y, x, idata); + CovarOperation::Operation(state.cov_pop, y, x, idata); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + CovarOperation::Combine(source.cov_pop, target.cov_pop, aggr_input_data); + RegrCountFunction::Combine(source.count, target.count, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + CovarPopOperation::Finalize(state.cov_pop, target, finalize_data); + auto cov_pop = target; + RegrCountFunction::Finalize(state.count, target, finalize_data); + target *= cov_pop; + } + + static bool IgnoreNull() { + return true; + } +}; + +AggregateFunction RegrSXYFun::GetFunction() { + return AggregateFunction::BinaryAggregate( + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/core_functions_extension.cpp b/src/duckdb/extension/core_functions/core_functions_extension.cpp new file mode 100644 index 00000000..8bf09b80 --- /dev/null +++ b/src/duckdb/extension/core_functions/core_functions_extension.cpp @@ -0,0 +1,85 @@ +#define DUCKDB_EXTENSION_MAIN +#include "core_functions_extension.hpp" + +#include "core_functions/function_list.hpp" +#include "duckdb/main/extension_util.hpp" +#include "duckdb/function/register_function_list_helper.hpp" +#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" + +namespace duckdb { + +template +static void FillExtraInfo(const StaticFunctionDefinition &function, T &info) { + info.internal = true; + FillFunctionDescriptions(function, info); + info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; +} + +void LoadInternal(DuckDB &db) { + auto functions = StaticFunctionDefinition::GetFunctionList(); + for (idx_t i = 0; functions[i].name; i++) { + auto &function = functions[i]; + if (function.get_function || function.get_function_set) { + // scalar function + ScalarFunctionSet result; + if (function.get_function) { + result.AddFunction(function.get_function()); + } else { + result = function.get_function_set(); + } + result.name = function.name; + CreateScalarFunctionInfo info(result); + FillExtraInfo(function, info); + ExtensionUtil::RegisterFunction(*db.instance, std::move(info)); + } else if (function.get_aggregate_function || function.get_aggregate_function_set) { + // aggregate function + AggregateFunctionSet result; + if (function.get_aggregate_function) { + result.AddFunction(function.get_aggregate_function()); + } else { + result = function.get_aggregate_function_set(); + } + result.name = function.name; + CreateAggregateFunctionInfo info(result); + FillExtraInfo(function, info); + ExtensionUtil::RegisterFunction(*db.instance, std::move(info)); + } else { + throw InternalException("Do not know how to 
register function of this type"); + } + } +} + +void CoreFunctionsExtension::Load(DuckDB &db) { + LoadInternal(db); +} + +std::string CoreFunctionsExtension::Name() { + return "core_functions"; +} + +std::string CoreFunctionsExtension::Version() const { +#ifdef EXT_VERSION_CORE_FUNCTIONS + return EXT_VERSION_CORE_FUNCTIONS; +#else + return ""; +#endif +} + +} // namespace duckdb + +extern "C" { + +DUCKDB_EXTENSION_API void core_functions_init(duckdb::DatabaseInstance &db) { + duckdb::DuckDB db_wrapper(db); + duckdb::LoadInternal(db_wrapper); +} + +DUCKDB_EXTENSION_API const char *core_functions_version() { + return duckdb::DuckDB::LibraryVersion(); +} +} + +#ifndef DUCKDB_EXTENSION_MAIN +#error DUCKDB_EXTENSION_MAIN not defined +#endif diff --git a/src/duckdb/extension/core_functions/function_list.cpp b/src/duckdb/extension/core_functions/function_list.cpp new file mode 100644 index 00000000..53d96feb --- /dev/null +++ b/src/duckdb/extension/core_functions/function_list.cpp @@ -0,0 +1,407 @@ +#include "core_functions/function_list.hpp" +#include "core_functions/aggregate/algebraic_functions.hpp" +#include "core_functions/aggregate/distributive_functions.hpp" +#include "core_functions/aggregate/holistic_functions.hpp" +#include "core_functions/aggregate/nested_functions.hpp" +#include "core_functions/aggregate/regression_functions.hpp" +#include "core_functions/scalar/bit_functions.hpp" +#include "core_functions/scalar/blob_functions.hpp" +#include "core_functions/scalar/date_functions.hpp" +#include "core_functions/scalar/enum_functions.hpp" +#include "core_functions/scalar/generic_functions.hpp" +#include "core_functions/scalar/list_functions.hpp" +#include "core_functions/scalar/map_functions.hpp" +#include "core_functions/scalar/math_functions.hpp" +#include "core_functions/scalar/operators_functions.hpp" +#include "core_functions/scalar/random_functions.hpp" +#include "core_functions/scalar/secret_functions.hpp" +#include "core_functions/scalar/string_functions.hpp" +#include "core_functions/scalar/struct_functions.hpp" +#include "core_functions/scalar/union_functions.hpp" +#include "core_functions/scalar/array_functions.hpp" +#include "core_functions/scalar/debug_functions.hpp" + +namespace duckdb { + +// Scalar Function +#define DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _NAME) \ + { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::GetFunction, nullptr, nullptr, nullptr } +#define DUCKDB_SCALAR_FUNCTION(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _PARAM::Name) +#define DUCKDB_SCALAR_FUNCTION_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name) +// Scalar Function Set +#define DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _NAME) \ + { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, _PARAM::GetFunctions, nullptr, nullptr } +#define DUCKDB_SCALAR_FUNCTION_SET(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _PARAM::Name) +#define DUCKDB_SCALAR_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name) +// Aggregate Function +#define DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _NAME) \ + { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, nullptr, _PARAM::GetFunction, nullptr } +#define DUCKDB_AGGREGATE_FUNCTION(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _PARAM::Name) +#define DUCKDB_AGGREGATE_FUNCTION_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name) +// Aggregate Function Set +#define DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _NAME) \ + { _NAME, 
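// For reference: LoadInternal above walks a static, sentinel-terminated table and picks
// the registration path from whichever factory pointer is populated; FINAL_FUNCTION
// below is the all-null sentinel. The control flow in miniature (hypothetical types,
// not DuckDB's):
#include <cstddef>

struct FunctionDefSketch {
	const char *name;     // nullptr marks the terminating sentinel
	void (*get_scalar)(); // exactly one factory pointer is set per real entry
	void (*get_aggregate)();
};

void RegisterAll(const FunctionDefSketch *defs) {
	for (size_t i = 0; defs[i].name; i++) {
		if (defs[i].get_scalar) {
			defs[i].get_scalar();
		} else if (defs[i].get_aggregate) {
			defs[i].get_aggregate();
		}
		// an entry with neither factory set is rejected upstream as an internal error
	}
}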
_PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, nullptr, nullptr, _PARAM::GetFunctions } +#define DUCKDB_AGGREGATE_FUNCTION_SET(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _PARAM::Name) +#define DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name) +#define FINAL_FUNCTION \ + { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr } + +// this list is generated by scripts/generate_functions.py +static const StaticFunctionDefinition core_functions[] = { + DUCKDB_SCALAR_FUNCTION(FactorialOperatorFun), + DUCKDB_SCALAR_FUNCTION_SET(BitwiseAndFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListHasAnyFunAlias), + DUCKDB_SCALAR_FUNCTION(PowOperatorFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListDistanceFunAlias), + DUCKDB_SCALAR_FUNCTION_SET(LeftShiftFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListCosineDistanceFunAlias), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListHasAllFunAlias2), + DUCKDB_SCALAR_FUNCTION_SET(RightShiftFun), + DUCKDB_SCALAR_FUNCTION_SET(AbsOperatorFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListHasAllFunAlias), + DUCKDB_SCALAR_FUNCTION_ALIAS(PowOperatorFunAlias), + DUCKDB_SCALAR_FUNCTION(StartsWithOperatorFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(AbsFun), + DUCKDB_SCALAR_FUNCTION(AcosFun), + DUCKDB_SCALAR_FUNCTION(AcoshFun), + DUCKDB_SCALAR_FUNCTION_SET(AgeFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(AggregateFun), + DUCKDB_SCALAR_FUNCTION(AliasFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ApplyFun), + DUCKDB_AGGREGATE_FUNCTION(ApproxCountDistinctFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ApproxQuantileFun), + DUCKDB_AGGREGATE_FUNCTION(ApproxTopKFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxNullFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinNullFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgmaxFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgminFun), + DUCKDB_AGGREGATE_FUNCTION_ALIAS(ArrayAggFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayAggrFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayAggregateFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayApplyFun), + DUCKDB_SCALAR_FUNCTION_SET(ArrayCosineDistanceFun), + DUCKDB_SCALAR_FUNCTION_SET(ArrayCosineSimilarityFun), + DUCKDB_SCALAR_FUNCTION_SET(ArrayCrossProductFun), + DUCKDB_SCALAR_FUNCTION_SET(ArrayDistanceFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayDistinctFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayDotProductFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayFilterFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayGradeUpFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayHasAllFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayHasAnyFun), + DUCKDB_SCALAR_FUNCTION_SET(ArrayInnerProductFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayNegativeDotProductFun), + DUCKDB_SCALAR_FUNCTION_SET(ArrayNegativeInnerProductFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayReduceFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayReverseSortFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArraySliceFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArraySortFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayTransformFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayUniqueFun), + DUCKDB_SCALAR_FUNCTION(ArrayValueFun), + DUCKDB_SCALAR_FUNCTION(ASCIIFun), + DUCKDB_SCALAR_FUNCTION(AsinFun), + DUCKDB_SCALAR_FUNCTION(AsinhFun), + DUCKDB_SCALAR_FUNCTION(AtanFun), + DUCKDB_SCALAR_FUNCTION(Atan2Fun), + DUCKDB_SCALAR_FUNCTION(AtanhFun), + DUCKDB_AGGREGATE_FUNCTION_SET(AvgFun), + DUCKDB_SCALAR_FUNCTION_SET(BarFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(Base64Fun), + DUCKDB_SCALAR_FUNCTION_SET(BinFun), + 
DUCKDB_AGGREGATE_FUNCTION_SET(BitAndFun), + DUCKDB_SCALAR_FUNCTION_SET(BitCountFun), + DUCKDB_AGGREGATE_FUNCTION_SET(BitOrFun), + DUCKDB_SCALAR_FUNCTION(BitPositionFun), + DUCKDB_AGGREGATE_FUNCTION_SET(BitXorFun), + DUCKDB_SCALAR_FUNCTION_SET(BitStringFun), + DUCKDB_AGGREGATE_FUNCTION_SET(BitstringAggFun), + DUCKDB_AGGREGATE_FUNCTION(BoolAndFun), + DUCKDB_AGGREGATE_FUNCTION(BoolOrFun), + DUCKDB_SCALAR_FUNCTION(CanCastImplicitlyFun), + DUCKDB_SCALAR_FUNCTION(CardinalityFun), + DUCKDB_SCALAR_FUNCTION(CbrtFun), + DUCKDB_SCALAR_FUNCTION_SET(CeilFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(CeilingFun), + DUCKDB_SCALAR_FUNCTION_SET(CenturyFun), + DUCKDB_SCALAR_FUNCTION(ChrFun), + DUCKDB_AGGREGATE_FUNCTION(CorrFun), + DUCKDB_SCALAR_FUNCTION(CosFun), + DUCKDB_SCALAR_FUNCTION(CoshFun), + DUCKDB_SCALAR_FUNCTION(CotFun), + DUCKDB_AGGREGATE_FUNCTION(CountIfFun), + DUCKDB_AGGREGATE_FUNCTION_ALIAS(CountifFun), + DUCKDB_AGGREGATE_FUNCTION(CovarPopFun), + DUCKDB_AGGREGATE_FUNCTION(CovarSampFun), + DUCKDB_SCALAR_FUNCTION(CurrentDatabaseFun), + DUCKDB_SCALAR_FUNCTION(CurrentQueryFun), + DUCKDB_SCALAR_FUNCTION(CurrentSchemaFun), + DUCKDB_SCALAR_FUNCTION(CurrentSchemasFun), + DUCKDB_SCALAR_FUNCTION(CurrentSettingFun), + DUCKDB_SCALAR_FUNCTION(DamerauLevenshteinFun), + DUCKDB_SCALAR_FUNCTION_SET(DateDiffFun), + DUCKDB_SCALAR_FUNCTION_SET(DatePartFun), + DUCKDB_SCALAR_FUNCTION_SET(DateSubFun), + DUCKDB_SCALAR_FUNCTION_SET(DateTruncFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatediffFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatepartFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatesubFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatetruncFun), + DUCKDB_SCALAR_FUNCTION_SET(DayFun), + DUCKDB_SCALAR_FUNCTION_SET(DayNameFun), + DUCKDB_SCALAR_FUNCTION_SET(DayOfMonthFun), + DUCKDB_SCALAR_FUNCTION_SET(DayOfWeekFun), + DUCKDB_SCALAR_FUNCTION_SET(DayOfYearFun), + DUCKDB_SCALAR_FUNCTION_SET(DecadeFun), + DUCKDB_SCALAR_FUNCTION(DecodeFun), + DUCKDB_SCALAR_FUNCTION(DegreesFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(Editdist3Fun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ElementAtFun), + DUCKDB_SCALAR_FUNCTION(EncodeFun), + DUCKDB_AGGREGATE_FUNCTION_SET(EntropyFun), + DUCKDB_SCALAR_FUNCTION(EnumCodeFun), + DUCKDB_SCALAR_FUNCTION(EnumFirstFun), + DUCKDB_SCALAR_FUNCTION(EnumLastFun), + DUCKDB_SCALAR_FUNCTION(EnumRangeFun), + DUCKDB_SCALAR_FUNCTION(EnumRangeBoundaryFun), + DUCKDB_SCALAR_FUNCTION_SET(EpochFun), + DUCKDB_SCALAR_FUNCTION_SET(EpochMsFun), + DUCKDB_SCALAR_FUNCTION_SET(EpochNsFun), + DUCKDB_SCALAR_FUNCTION_SET(EpochUsFun), + DUCKDB_SCALAR_FUNCTION_SET(EquiWidthBinsFun), + DUCKDB_SCALAR_FUNCTION_SET(EraFun), + DUCKDB_SCALAR_FUNCTION(EvenFun), + DUCKDB_SCALAR_FUNCTION(ExpFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(FactorialFun), + DUCKDB_AGGREGATE_FUNCTION(FAvgFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(FilterFun), + DUCKDB_SCALAR_FUNCTION(ListFlattenFun), + DUCKDB_SCALAR_FUNCTION_SET(FloorFun), + DUCKDB_SCALAR_FUNCTION(FormatFun), + DUCKDB_SCALAR_FUNCTION(FormatreadabledecimalsizeFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(FormatreadablesizeFun), + DUCKDB_SCALAR_FUNCTION(FormatBytesFun), + DUCKDB_SCALAR_FUNCTION(FromBase64Fun), + DUCKDB_SCALAR_FUNCTION_ALIAS(FromBinaryFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(FromHexFun), + DUCKDB_AGGREGATE_FUNCTION_ALIAS(FsumFun), + DUCKDB_SCALAR_FUNCTION(GammaFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(GcdFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(GenRandomUuidFun), + DUCKDB_SCALAR_FUNCTION_SET(GenerateSeriesFun), + DUCKDB_SCALAR_FUNCTION(GetBitFun), + DUCKDB_SCALAR_FUNCTION(GetCurrentTimestampFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(GradeUpFun), + 
DUCKDB_SCALAR_FUNCTION_SET(GreatestFun), + DUCKDB_SCALAR_FUNCTION_SET(GreatestCommonDivisorFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(GroupConcatFun), + DUCKDB_SCALAR_FUNCTION(HammingFun), + DUCKDB_SCALAR_FUNCTION(HashFun), + DUCKDB_SCALAR_FUNCTION_SET(HexFun), + DUCKDB_AGGREGATE_FUNCTION_SET(HistogramFun), + DUCKDB_AGGREGATE_FUNCTION(HistogramExactFun), + DUCKDB_SCALAR_FUNCTION_SET(HoursFun), + DUCKDB_SCALAR_FUNCTION(InSearchPathFun), + DUCKDB_SCALAR_FUNCTION(InstrFun), + DUCKDB_SCALAR_FUNCTION(IsHistogramOtherBinFun), + DUCKDB_SCALAR_FUNCTION_SET(IsFiniteFun), + DUCKDB_SCALAR_FUNCTION_SET(IsInfiniteFun), + DUCKDB_SCALAR_FUNCTION_SET(IsNanFun), + DUCKDB_SCALAR_FUNCTION_SET(ISODayOfWeekFun), + DUCKDB_SCALAR_FUNCTION_SET(ISOYearFun), + DUCKDB_SCALAR_FUNCTION(JaccardFun), + DUCKDB_SCALAR_FUNCTION_SET(JaroSimilarityFun), + DUCKDB_SCALAR_FUNCTION_SET(JaroWinklerSimilarityFun), + DUCKDB_SCALAR_FUNCTION_SET(JulianDayFun), + DUCKDB_AGGREGATE_FUNCTION(KahanSumFun), + DUCKDB_AGGREGATE_FUNCTION(KurtosisFun), + DUCKDB_AGGREGATE_FUNCTION(KurtosisPopFun), + DUCKDB_SCALAR_FUNCTION_SET(LastDayFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(LcmFun), + DUCKDB_SCALAR_FUNCTION_SET(LeastFun), + DUCKDB_SCALAR_FUNCTION_SET(LeastCommonMultipleFun), + DUCKDB_SCALAR_FUNCTION(LeftFun), + DUCKDB_SCALAR_FUNCTION(LeftGraphemeFun), + DUCKDB_SCALAR_FUNCTION(LevenshteinFun), + DUCKDB_SCALAR_FUNCTION(LogGammaFun), + DUCKDB_AGGREGATE_FUNCTION(ListFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListAggrFun), + DUCKDB_SCALAR_FUNCTION(ListAggregateFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListApplyFun), + DUCKDB_SCALAR_FUNCTION_SET(ListCosineDistanceFun), + DUCKDB_SCALAR_FUNCTION_SET(ListCosineSimilarityFun), + DUCKDB_SCALAR_FUNCTION_SET(ListDistanceFun), + DUCKDB_SCALAR_FUNCTION(ListDistinctFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListDotProductFun), + DUCKDB_SCALAR_FUNCTION(ListFilterFun), + DUCKDB_SCALAR_FUNCTION_SET(ListGradeUpFun), + DUCKDB_SCALAR_FUNCTION(ListHasAllFun), + DUCKDB_SCALAR_FUNCTION(ListHasAnyFun), + DUCKDB_SCALAR_FUNCTION_SET(ListInnerProductFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListNegativeDotProductFun), + DUCKDB_SCALAR_FUNCTION_SET(ListNegativeInnerProductFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ListPackFun), + DUCKDB_SCALAR_FUNCTION(ListReduceFun), + DUCKDB_SCALAR_FUNCTION_SET(ListReverseSortFun), + DUCKDB_SCALAR_FUNCTION_SET(ListSliceFun), + DUCKDB_SCALAR_FUNCTION_SET(ListSortFun), + DUCKDB_SCALAR_FUNCTION(ListTransformFun), + DUCKDB_SCALAR_FUNCTION(ListUniqueFun), + DUCKDB_SCALAR_FUNCTION(ListValueFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ListaggFun), + DUCKDB_SCALAR_FUNCTION(LnFun), + DUCKDB_SCALAR_FUNCTION_SET(LogFun), + DUCKDB_SCALAR_FUNCTION(Log10Fun), + DUCKDB_SCALAR_FUNCTION(Log2Fun), + DUCKDB_SCALAR_FUNCTION(LpadFun), + DUCKDB_SCALAR_FUNCTION_SET(LtrimFun), + DUCKDB_AGGREGATE_FUNCTION_SET(MadFun), + DUCKDB_SCALAR_FUNCTION_SET(MakeDateFun), + DUCKDB_SCALAR_FUNCTION(MakeTimeFun), + DUCKDB_SCALAR_FUNCTION_SET(MakeTimestampFun), + DUCKDB_SCALAR_FUNCTION_SET(MakeTimestampNsFun), + DUCKDB_SCALAR_FUNCTION(MapFun), + DUCKDB_SCALAR_FUNCTION(MapConcatFun), + DUCKDB_SCALAR_FUNCTION(MapEntriesFun), + DUCKDB_SCALAR_FUNCTION(MapExtractFun), + DUCKDB_SCALAR_FUNCTION(MapFromEntriesFun), + DUCKDB_SCALAR_FUNCTION(MapKeysFun), + DUCKDB_SCALAR_FUNCTION(MapValuesFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MaxByFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MeanFun), + DUCKDB_AGGREGATE_FUNCTION_SET(MedianFun), + DUCKDB_SCALAR_FUNCTION_SET(MicrosecondsFun), + DUCKDB_SCALAR_FUNCTION_SET(MillenniumFun), + 
DUCKDB_SCALAR_FUNCTION_SET(MillisecondsFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MinByFun), + DUCKDB_SCALAR_FUNCTION_SET(MinutesFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(MismatchesFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ModeFun), + DUCKDB_SCALAR_FUNCTION_SET(MonthFun), + DUCKDB_SCALAR_FUNCTION_SET(MonthNameFun), + DUCKDB_SCALAR_FUNCTION_SET(NanosecondsFun), + DUCKDB_SCALAR_FUNCTION_SET(NextAfterFun), + DUCKDB_SCALAR_FUNCTION(NormalizedIntervalFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(NowFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(OrdFun), + DUCKDB_SCALAR_FUNCTION_SET(ParseDirnameFun), + DUCKDB_SCALAR_FUNCTION_SET(ParseDirpathFun), + DUCKDB_SCALAR_FUNCTION_SET(ParseFilenameFun), + DUCKDB_SCALAR_FUNCTION_SET(ParsePathFun), + DUCKDB_SCALAR_FUNCTION(PiFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(PositionFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(PowFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(PowerFun), + DUCKDB_SCALAR_FUNCTION(PrintfFun), + DUCKDB_AGGREGATE_FUNCTION(ProductFun), + DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(QuantileFun), + DUCKDB_AGGREGATE_FUNCTION_SET(QuantileContFun), + DUCKDB_AGGREGATE_FUNCTION_SET(QuantileDiscFun), + DUCKDB_SCALAR_FUNCTION_SET(QuarterFun), + DUCKDB_SCALAR_FUNCTION(RadiansFun), + DUCKDB_SCALAR_FUNCTION(RandomFun), + DUCKDB_SCALAR_FUNCTION_SET(ListRangeFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ReduceFun), + DUCKDB_AGGREGATE_FUNCTION(RegrAvgxFun), + DUCKDB_AGGREGATE_FUNCTION(RegrAvgyFun), + DUCKDB_AGGREGATE_FUNCTION(RegrCountFun), + DUCKDB_AGGREGATE_FUNCTION(RegrInterceptFun), + DUCKDB_AGGREGATE_FUNCTION(RegrR2Fun), + DUCKDB_AGGREGATE_FUNCTION(RegrSlopeFun), + DUCKDB_AGGREGATE_FUNCTION(RegrSXXFun), + DUCKDB_AGGREGATE_FUNCTION(RegrSXYFun), + DUCKDB_AGGREGATE_FUNCTION(RegrSYYFun), + DUCKDB_SCALAR_FUNCTION_SET(RepeatFun), + DUCKDB_SCALAR_FUNCTION(ReplaceFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ReservoirQuantileFun), + DUCKDB_SCALAR_FUNCTION(ReverseFun), + DUCKDB_SCALAR_FUNCTION(RightFun), + DUCKDB_SCALAR_FUNCTION(RightGraphemeFun), + DUCKDB_SCALAR_FUNCTION_SET(RoundFun), + DUCKDB_SCALAR_FUNCTION(RpadFun), + DUCKDB_SCALAR_FUNCTION_SET(RtrimFun), + DUCKDB_SCALAR_FUNCTION_SET(SecondsFun), + DUCKDB_AGGREGATE_FUNCTION(StandardErrorOfTheMeanFun), + DUCKDB_SCALAR_FUNCTION(SetBitFun), + DUCKDB_SCALAR_FUNCTION(SetseedFun), + DUCKDB_SCALAR_FUNCTION_SET(SignFun), + DUCKDB_SCALAR_FUNCTION_SET(SignBitFun), + DUCKDB_SCALAR_FUNCTION(SinFun), + DUCKDB_SCALAR_FUNCTION(SinhFun), + DUCKDB_AGGREGATE_FUNCTION(SkewnessFun), + DUCKDB_SCALAR_FUNCTION(SqrtFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StartsWithFun), + DUCKDB_SCALAR_FUNCTION(StatsFun), + DUCKDB_AGGREGATE_FUNCTION_ALIAS(StddevFun), + DUCKDB_AGGREGATE_FUNCTION(StdDevPopFun), + DUCKDB_AGGREGATE_FUNCTION(StdDevSampFun), + DUCKDB_AGGREGATE_FUNCTION_SET(StringAggFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StrposFun), + DUCKDB_SCALAR_FUNCTION(StructInsertFun), + DUCKDB_AGGREGATE_FUNCTION_SET(SumFun), + DUCKDB_AGGREGATE_FUNCTION_SET(SumNoOverflowFun), + DUCKDB_AGGREGATE_FUNCTION_ALIAS(SumkahanFun), + DUCKDB_SCALAR_FUNCTION(TanFun), + DUCKDB_SCALAR_FUNCTION(TanhFun), + DUCKDB_SCALAR_FUNCTION_SET(TimeBucketFun), + DUCKDB_SCALAR_FUNCTION(TimeTZSortKeyFun), + DUCKDB_SCALAR_FUNCTION_SET(TimezoneFun), + DUCKDB_SCALAR_FUNCTION_SET(TimezoneHourFun), + DUCKDB_SCALAR_FUNCTION_SET(TimezoneMinuteFun), + DUCKDB_SCALAR_FUNCTION_SET(ToBaseFun), + DUCKDB_SCALAR_FUNCTION(ToBase64Fun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ToBinaryFun), + DUCKDB_SCALAR_FUNCTION(ToCenturiesFun), + DUCKDB_SCALAR_FUNCTION(ToDaysFun), + DUCKDB_SCALAR_FUNCTION(ToDecadesFun), + DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ToHexFun), + 
DUCKDB_SCALAR_FUNCTION(ToHoursFun), + DUCKDB_SCALAR_FUNCTION(ToMicrosecondsFun), + DUCKDB_SCALAR_FUNCTION(ToMillenniaFun), + DUCKDB_SCALAR_FUNCTION(ToMillisecondsFun), + DUCKDB_SCALAR_FUNCTION(ToMinutesFun), + DUCKDB_SCALAR_FUNCTION(ToMonthsFun), + DUCKDB_SCALAR_FUNCTION(ToQuartersFun), + DUCKDB_SCALAR_FUNCTION(ToSecondsFun), + DUCKDB_SCALAR_FUNCTION(ToTimestampFun), + DUCKDB_SCALAR_FUNCTION(ToWeeksFun), + DUCKDB_SCALAR_FUNCTION(ToYearsFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(TransactionTimestampFun), + DUCKDB_SCALAR_FUNCTION(TranslateFun), + DUCKDB_SCALAR_FUNCTION_SET(TrimFun), + DUCKDB_SCALAR_FUNCTION_SET(TruncFun), + DUCKDB_SCALAR_FUNCTION(CurrentTransactionIdFun), + DUCKDB_SCALAR_FUNCTION(TypeOfFun), + DUCKDB_SCALAR_FUNCTION(UnbinFun), + DUCKDB_SCALAR_FUNCTION(UnhexFun), + DUCKDB_SCALAR_FUNCTION(UnicodeFun), + DUCKDB_SCALAR_FUNCTION(UnionExtractFun), + DUCKDB_SCALAR_FUNCTION(UnionTagFun), + DUCKDB_SCALAR_FUNCTION(UnionValueFun), + DUCKDB_SCALAR_FUNCTION(UnpivotListFun), + DUCKDB_SCALAR_FUNCTION(UrlDecodeFun), + DUCKDB_SCALAR_FUNCTION(UrlEncodeFun), + DUCKDB_SCALAR_FUNCTION(UUIDFun), + DUCKDB_AGGREGATE_FUNCTION(VarPopFun), + DUCKDB_AGGREGATE_FUNCTION(VarSampFun), + DUCKDB_AGGREGATE_FUNCTION_ALIAS(VarianceFun), + DUCKDB_SCALAR_FUNCTION(VectorTypeFun), + DUCKDB_SCALAR_FUNCTION(VersionFun), + DUCKDB_SCALAR_FUNCTION_SET(WeekFun), + DUCKDB_SCALAR_FUNCTION_SET(WeekDayFun), + DUCKDB_SCALAR_FUNCTION_SET(WeekOfYearFun), + DUCKDB_SCALAR_FUNCTION_SET(BitwiseXorFun), + DUCKDB_SCALAR_FUNCTION_SET(YearFun), + DUCKDB_SCALAR_FUNCTION_SET(YearWeekFun), + DUCKDB_SCALAR_FUNCTION_SET(BitwiseOrFun), + DUCKDB_SCALAR_FUNCTION_SET(BitwiseNotFun), + FINAL_FUNCTION +}; + +const StaticFunctionDefinition *StaticFunctionDefinition::GetFunctionList() { + return core_functions; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/corr.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/corr.hpp new file mode 100644 index 00000000..05cdfb14 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/corr.hpp @@ -0,0 +1,70 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/algebraic/corr.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/aggregate_function.hpp" +#include "core_functions/aggregate/algebraic/covar.hpp" +#include "core_functions/aggregate/algebraic/stddev.hpp" + +namespace duckdb { + +struct CorrState { + CovarState cov_pop; + StddevState dev_pop_x; + StddevState dev_pop_y; +}; + +// Returns the correlation coefficient for non-null pairs in a group. 
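+// A single CorrState runs the covariance state and both standard deviation states over the +// same (y, x) pairs, so the correlation coefficient comes out of one pass over the data: +// Finalize below divides the population covariance by the product of the two population +// standard deviations, returning NULL for empty input and NaN when either deviation is zero.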
+// CORR(y, x) = COVAR_POP(y, x) / (STDDEV_POP(x) * STDDEV_POP(y)) +struct CorrOperation { + template + static void Initialize(STATE &state) { + CovarOperation::Initialize(state.cov_pop); + STDDevBaseOperation::Initialize(state.dev_pop_x); + STDDevBaseOperation::Initialize(state.dev_pop_y); + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + CovarOperation::Operation(state.cov_pop, y, x, idata); + STDDevBaseOperation::Execute(state.dev_pop_x, x); + STDDevBaseOperation::Execute(state.dev_pop_y, y); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + CovarOperation::Combine(source.cov_pop, target.cov_pop, aggr_input_data); + STDDevBaseOperation::Combine(source.dev_pop_x, target.dev_pop_x, aggr_input_data); + STDDevBaseOperation::Combine(source.dev_pop_y, target.dev_pop_y, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.cov_pop.count == 0 || state.dev_pop_x.count == 0 || state.dev_pop_y.count == 0) { + finalize_data.ReturnNull(); + } else { + auto cov = state.cov_pop.co_moment / state.cov_pop.count; + auto std_x = state.dev_pop_x.count > 1 ? sqrt(state.dev_pop_x.dsquared / state.dev_pop_x.count) : 0; + if (!Value::DoubleIsFinite(std_x)) { + throw OutOfRangeException("STDDEV_POP for X is out of range!"); + } + auto std_y = state.dev_pop_y.count > 1 ? sqrt(state.dev_pop_y.dsquared / state.dev_pop_y.count) : 0; + if (!Value::DoubleIsFinite(std_y)) { + throw OutOfRangeException("STDDEV_POP for Y is out of range!"); + } + target = std_x * std_y != 0 ? cov / (std_x * std_y) : NAN; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/covar.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/covar.hpp new file mode 100644 index 00000000..1908dfad --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/covar.hpp @@ -0,0 +1,101 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/algebraic/covar.hpp +// +// +//===----------------------------------------------------------------------===// +// COVAR_POP(y,x) + +#pragma once + +#include "duckdb/function/aggregate_function.hpp" + +namespace duckdb { + +struct CovarState { + uint64_t count; + double meanx; + double meany; + double co_moment; +}; + +struct CovarOperation { + template + static void Initialize(STATE &state) { + state.count = 0; + state.meanx = 0; + state.meany = 0; + state.co_moment = 0; + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + // update running mean and d^2 + const double n = static_cast(++(state.count)); + + const double dx = (x - state.meanx); + const double meanx = state.meanx + dx / n; + + const double dy = (y - state.meany); + const double meany = state.meany + dy / n; + + // Schubert and Gertz SSDBM 2018 (4.3) + const double C = state.co_moment + dx * (y - meany); + + state.meanx = meanx; + state.meany = meany; + state.co_moment = C; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (target.count == 0) { + target = source; + } else if (source.count > 0) { + const auto count = target.count + source.count; + 
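// Merging two partially aggregated states (Schubert and Gertz SSDBM 2018, equation 21): + // the means combine as a count-weighted average, and the co-moment picks up a correction + // term deltax * deltay * n_source * n_target / n_total for the shift between the two + // partial means. + 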
D_ASSERT(count >= target.count); // This is a check that we are not overflowing + const auto target_count = static_cast(target.count); + const auto source_count = static_cast(source.count); + const auto total_count = static_cast(count); + const auto meanx = (source_count * source.meanx + target_count * target.meanx) / total_count; + const auto meany = (source_count * source.meany + target_count * target.meany) / total_count; + + // Schubert and Gertz SSDBM 2018, equation 21 + const auto deltax = target.meanx - source.meanx; + const auto deltay = target.meany - source.meany; + target.co_moment = + source.co_moment + target.co_moment + deltax * deltay * source_count * target_count / total_count; + target.meanx = meanx; + target.meany = meany; + target.count = count; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +struct CovarPopOperation : public CovarOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = state.co_moment / state.count; + } + } +}; + +struct CovarSampOperation : public CovarOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count < 2) { + finalize_data.ReturnNull(); + } else { + target = state.co_moment / (state.count - 1); + } + } +}; +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp new file mode 100644 index 00000000..bdcafae9 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp @@ -0,0 +1,151 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/algebraic/stddev.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/aggregate_function.hpp" +#include + +namespace duckdb { + +struct StddevState { + uint64_t count; // n + double mean; // M1 + double dsquared; // M2 +}; + +// Streaming approximate standard deviation using Welford's +// method, DOI: 10.2307/1266577 +struct STDDevBaseOperation { + template + static void Initialize(STATE &state) { + state.count = 0; + state.mean = 0; + state.dsquared = 0; + } + + template + static void Execute(STATE &state, const INPUT_TYPE &input) { + // update running mean and d^2 + state.count++; + const double mean_differential = (input - state.mean) / state.count; + const double new_mean = state.mean + mean_differential; + const double dsquared_increment = (input - new_mean) * (input - state.mean); + const double new_dsquared = state.dsquared + dsquared_increment; + + state.mean = new_mean; + state.dsquared = new_dsquared; + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + Execute(state, input); + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (target.count == 0) { + target = source; + } else if (source.count > 0) { + const auto count = target.count + source.count; + D_ASSERT(count >= target.count); // This is a check that we are not overflowing
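+ // Same pairwise merge as the covariance combine above (cf. Chan et al.'s parallel + // variance update): M2_total = M2_source + M2_target + delta^2 * n_source * n_target / n_total, + // where delta is the difference between the two partial means. This is what keeps the + // single-pass Welford states mergeable across threads.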
+ const double target_count = static_cast(target.count); + const double source_count = static_cast(source.count); + const double total_count = static_cast(count); + const auto mean = (source_count * source.mean + target_count * target.mean) / total_count; + const auto delta = source.mean - target.mean; + target.dsquared = + source.dsquared + target.dsquared + delta * delta * source_count * target_count / total_count; + target.mean = mean; + target.count = count; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +struct VarSampOperation : public STDDevBaseOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count <= 1) { + finalize_data.ReturnNull(); + } else { + target = state.dsquared / (state.count - 1); + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("VARSAMP is out of range!"); + } + } + } +}; + +struct VarPopOperation : public STDDevBaseOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = state.count > 1 ? (state.dsquared / state.count) : 0; + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("VARPOP is out of range!"); + } + } + } +}; + +struct STDDevSampOperation : public STDDevBaseOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count <= 1) { + finalize_data.ReturnNull(); + } else { + target = sqrt(state.dsquared / (state.count - 1)); + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("STDDEV_SAMP is out of range!"); + } + } + } +}; + +struct STDDevPopOperation : public STDDevBaseOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = state.count > 1 ? 
sqrt(state.dsquared / state.count) : 0; + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("STDDEV_POP is out of range!"); + } + } + } +}; + +struct StandardErrorOfTheMeanOperation : public STDDevBaseOperation { + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.count == 0) { + finalize_data.ReturnNull(); + } else { + target = sqrt(state.dsquared / state.count) / sqrt((state.count)); + if (!Value::DoubleIsFinite(target)) { + throw OutOfRangeException("SEM is out of range!"); + } + } + } +}; +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic_functions.hpp new file mode 100644 index 00000000..da08c769 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic_functions.hpp @@ -0,0 +1,126 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/aggregate/algebraic_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct AvgFun { + static constexpr const char *Name = "avg"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Calculates the average value for all tuples in x."; + static constexpr const char *Example = "SUM(x) / COUNT(*)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct MeanFun { + using ALIAS = AvgFun; + + static constexpr const char *Name = "mean"; +}; + +struct CorrFun { + static constexpr const char *Name = "corr"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the correlation coefficient for non-null pairs in a group."; + static constexpr const char *Example = "COVAR_POP(y, x) / (STDDEV_POP(x) * STDDEV_POP(y))"; + + static AggregateFunction GetFunction(); +}; + +struct CovarPopFun { + static constexpr const char *Name = "covar_pop"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the population covariance of input values."; + static constexpr const char *Example = "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / COUNT(*)"; + + static AggregateFunction GetFunction(); +}; + +struct CovarSampFun { + static constexpr const char *Name = "covar_samp"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the sample covariance for non-null pairs in a group."; + static constexpr const char *Example = "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / (COUNT(*) - 1)"; + + static AggregateFunction GetFunction(); +}; + +struct FAvgFun { + static constexpr const char *Name = "favg"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Calculates the average using a more accurate floating point summation (Kahan Sum)"; + static constexpr const char *Example = "favg(A)"; + + static AggregateFunction GetFunction(); +}; + +struct StandardErrorOfTheMeanFun { + static constexpr const char *Name = "sem"; + static constexpr const char *Parameters = "x"; + 
static constexpr const char *Description = "Returns the standard error of the mean"; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct StdDevPopFun { + static constexpr const char *Name = "stddev_pop"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the population standard deviation."; + static constexpr const char *Example = "sqrt(var_pop(x))"; + + static AggregateFunction GetFunction(); +}; + +struct StdDevSampFun { + static constexpr const char *Name = "stddev_samp"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the sample standard deviation"; + static constexpr const char *Example = "sqrt(var_samp(x))"; + + static AggregateFunction GetFunction(); +}; + +struct StddevFun { + using ALIAS = StdDevSampFun; + + static constexpr const char *Name = "stddev"; +}; + +struct VarPopFun { + static constexpr const char *Name = "var_pop"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the population variance."; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct VarSampFun { + static constexpr const char *Name = "var_samp"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the sample variance of all input values."; + static constexpr const char *Example = "(SUM(x^2) - SUM(x)^2 / COUNT(x)) / (COUNT(x) - 1)"; + + static AggregateFunction GetFunction(); +}; + +struct VarianceFun { + using ALIAS = VarSampFun; + + static constexpr const char *Name = "variance"; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp new file mode 100644 index 00000000..50c0197a --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp @@ -0,0 +1,261 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/aggregate/distributive_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct ApproxCountDistinctFun { + static constexpr const char *Name = "approx_count_distinct"; + static constexpr const char *Parameters = "any"; + static constexpr const char *Description = "Computes the approximate count of distinct elements using HyperLogLog."; + static constexpr const char *Example = "approx_count_distinct(A)"; + + static AggregateFunction GetFunction(); +}; + +struct ArgMinFun { + static constexpr const char *Name = "arg_min"; + static constexpr const char *Parameters = "arg,val"; + static constexpr const char *Description = "Finds the row with the minimum val. 
Calculates the non-NULL arg expression at that row."; + static constexpr const char *Example = "arg_min(A,B)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ArgminFun { + using ALIAS = ArgMinFun; + + static constexpr const char *Name = "argmin"; +}; + +struct MinByFun { + using ALIAS = ArgMinFun; + + static constexpr const char *Name = "min_by"; +}; + +struct ArgMinNullFun { + static constexpr const char *Name = "arg_min_null"; + static constexpr const char *Parameters = "arg,val"; + static constexpr const char *Description = "Finds the row with the minimum val. Calculates the arg expression at that row."; + static constexpr const char *Example = "arg_min_null(A,B)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ArgMaxFun { + static constexpr const char *Name = "arg_max"; + static constexpr const char *Parameters = "arg,val"; + static constexpr const char *Description = "Finds the row with the maximum val. Calculates the non-NULL arg expression at that row."; + static constexpr const char *Example = "arg_max(A,B)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ArgmaxFun { + using ALIAS = ArgMaxFun; + + static constexpr const char *Name = "argmax"; +}; + +struct MaxByFun { + using ALIAS = ArgMaxFun; + + static constexpr const char *Name = "max_by"; +}; + +struct ArgMaxNullFun { + static constexpr const char *Name = "arg_max_null"; + static constexpr const char *Parameters = "arg,val"; + static constexpr const char *Description = "Finds the row with the maximum val. Calculates the arg expression at that row."; + static constexpr const char *Example = "arg_max_null(A,B)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct BitAndFun { + static constexpr const char *Name = "bit_and"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns the bitwise AND of all bits in a given expression."; + static constexpr const char *Example = "bit_and(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct BitOrFun { + static constexpr const char *Name = "bit_or"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns the bitwise OR of all bits in a given expression."; + static constexpr const char *Example = "bit_or(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct BitXorFun { + static constexpr const char *Name = "bit_xor"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns the bitwise XOR of all bits in a given expression."; + static constexpr const char *Example = "bit_xor(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct BitstringAggFun { + static constexpr const char *Name = "bitstring_agg"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns a bitstring with bits set for each distinct value."; + static constexpr const char *Example = "bitstring_agg(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct BoolAndFun { + static constexpr const char *Name = "bool_and"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns TRUE if every input value is TRUE, otherwise FALSE."; + static constexpr const char *Example = "bool_and(A)"; + + static AggregateFunction GetFunction(); +}; + +struct BoolOrFun { + static constexpr const char *Name = "bool_or"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description 
= "Returns TRUE if any input value is TRUE, otherwise FALSE."; + static constexpr const char *Example = "bool_or(A)"; + + static AggregateFunction GetFunction(); +}; + +struct CountIfFun { + static constexpr const char *Name = "count_if"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Counts the total number of TRUE values for a boolean column"; + static constexpr const char *Example = "count_if(A)"; + + static AggregateFunction GetFunction(); +}; + +struct CountifFun { + using ALIAS = CountIfFun; + + static constexpr const char *Name = "countif"; +}; + +struct EntropyFun { + static constexpr const char *Name = "entropy"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the log-2 entropy of count input-values."; + static constexpr const char *Example = ""; + + static AggregateFunctionSet GetFunctions(); +}; + +struct KahanSumFun { + static constexpr const char *Name = "kahan_sum"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Calculates the sum using a more accurate floating point summation (Kahan Sum)."; + static constexpr const char *Example = "kahan_sum(A)"; + + static AggregateFunction GetFunction(); +}; + +struct FsumFun { + using ALIAS = KahanSumFun; + + static constexpr const char *Name = "fsum"; +}; + +struct SumkahanFun { + using ALIAS = KahanSumFun; + + static constexpr const char *Name = "sumkahan"; +}; + +struct KurtosisFun { + static constexpr const char *Name = "kurtosis"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the excess kurtosis (Fisher’s definition) of all input values, with a bias correction according to the sample size"; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct KurtosisPopFun { + static constexpr const char *Name = "kurtosis_pop"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the excess kurtosis (Fisher’s definition) of all input values, without bias correction"; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct ProductFun { + static constexpr const char *Name = "product"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Calculates the product of all tuples in arg."; + static constexpr const char *Example = "product(A)"; + + static AggregateFunction GetFunction(); +}; + +struct SkewnessFun { + static constexpr const char *Name = "skewness"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the skewness of all input values."; + static constexpr const char *Example = "skewness(A)"; + + static AggregateFunction GetFunction(); +}; + +struct StringAggFun { + static constexpr const char *Name = "string_agg"; + static constexpr const char *Parameters = "str,arg"; + static constexpr const char *Description = "Concatenates the column string values with an optional separator."; + static constexpr const char *Example = "string_agg(A, '-')"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct GroupConcatFun { + using ALIAS = StringAggFun; + + static constexpr const char *Name = "group_concat"; +}; + +struct ListaggFun { + using ALIAS = StringAggFun; + + static constexpr const char *Name = "listagg"; +}; + +struct SumFun { + static constexpr const char *Name = "sum"; + static constexpr const 
char *Parameters = "arg"; + static constexpr const char *Description = "Calculates the sum value for all tuples in arg."; + static constexpr const char *Example = "sum(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct SumNoOverflowFun { + static constexpr const char *Name = "sum_no_overflow"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Internal only. Calculates the sum value for all tuples in arg without overflow checks."; + static constexpr const char *Example = "sum_no_overflow(A)"; + + static AggregateFunctionSet GetFunctions(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/histogram_helpers.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/histogram_helpers.hpp new file mode 100644 index 00000000..7d73a3ca --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/histogram_helpers.hpp @@ -0,0 +1,99 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/histogram_helpers.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/function/create_sort_key.hpp" + +namespace duckdb { + +struct HistogramFunctor { + template + static void HistogramFinalize(T value, Vector &result, idx_t offset) { + FlatVector::GetData(result)[offset] = value; + } + + static bool CreateExtraState(idx_t count) { + return false; + } + + static void PrepareData(Vector &input, idx_t count, bool &, UnifiedVectorFormat &result) { + input.ToUnifiedFormat(count, result); + } + + template + static T ExtractValue(UnifiedVectorFormat &bin_data, idx_t offset, AggregateInputData &) { + return UnifiedVectorFormat::GetData(bin_data)[bin_data.sel->get_index(offset)]; + } + + static bool RequiresExtract() { + return false; + } +}; + +struct HistogramStringFunctorBase { + template + static T ExtractValue(UnifiedVectorFormat &bin_data, idx_t offset, AggregateInputData &aggr_input) { + auto &input_str = UnifiedVectorFormat::GetData(bin_data)[bin_data.sel->get_index(offset)]; + if (input_str.IsInlined()) { + // inlined strings can be inserted directly + return input_str; + } + // if the string is not inlined we need to allocate space for it + auto input_str_size = UnsafeNumericCast(input_str.GetSize()); + auto string_memory = aggr_input.allocator.Allocate(input_str_size); + // copy over the string + memcpy(string_memory, input_str.GetData(), input_str_size); + // now insert it into the histogram + string_t histogram_str(char_ptr_cast(string_memory), input_str_size); + return histogram_str; + } + + static bool RequiresExtract() { + return true; + } +}; + +struct HistogramStringFunctor : HistogramStringFunctorBase { + template + static void HistogramFinalize(T value, Vector &result, idx_t offset) { + FlatVector::GetData(result)[offset] = StringVector::AddStringOrBlob(result, value); + } + + static bool CreateExtraState(idx_t count) { + return false; + } + + static void PrepareData(Vector &input, idx_t count, bool &, UnifiedVectorFormat &result) { + input.ToUnifiedFormat(count, result); + } +}; + +struct HistogramGenericFunctor : HistogramStringFunctorBase { + template + static void HistogramFinalize(T value, Vector &result, idx_t offset) { + CreateSortKeyHelpers::DecodeSortKey(value, result, offset, + OrderModifiers(OrderType::ASCENDING, 
OrderByNullType::NULLS_LAST)); + } + + static Vector CreateExtraState(idx_t count) { + return Vector(LogicalType::BLOB, count); + } + + static void PrepareData(Vector &input, idx_t count, Vector &extra_state, UnifiedVectorFormat &result) { + OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + CreateSortKeyHelpers::CreateSortKey(input, count, modifiers, extra_state); + input.Flatten(count); + extra_state.Flatten(count); + FlatVector::Validity(extra_state).Initialize(FlatVector::Validity(input)); + extra_state.ToUnifiedFormat(count, result); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/holistic_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/holistic_functions.hpp new file mode 100644 index 00000000..f8b96a16 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/holistic_functions.hpp @@ -0,0 +1,96 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/aggregate/holistic_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct ApproxQuantileFun { + static constexpr const char *Name = "approx_quantile"; + static constexpr const char *Parameters = "x,pos"; + static constexpr const char *Description = "Computes the approximate quantile using T-Digest."; + static constexpr const char *Example = "approx_quantile(x, 0.5)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct MadFun { + static constexpr const char *Name = "mad"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the median absolute deviation for the values within x. NULL values are ignored. Temporal types return a positive INTERVAL."; + static constexpr const char *Example = "mad(x)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct MedianFun { + static constexpr const char *Name = "median"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the middle value of the set. NULL values are ignored. For even value counts, quantitative values are averaged and ordinal values return the lower value."; + static constexpr const char *Example = "median(x)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ModeFun { + static constexpr const char *Name = "mode"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the most frequent value for the values within x. NULL values are ignored."; + static constexpr const char *Example = ""; + + static AggregateFunctionSet GetFunctions(); +}; + +struct QuantileDiscFun { + static constexpr const char *Name = "quantile_disc"; + static constexpr const char *Parameters = "x,pos"; + static constexpr const char *Description = "Returns the exact quantile number between 0 and 1. 
If pos is a LIST of FLOATs, then the result is a LIST of the corresponding exact quantiles."; + static constexpr const char *Example = "quantile_disc(x, 0.5)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct QuantileFun { + using ALIAS = QuantileDiscFun; + + static constexpr const char *Name = "quantile"; +}; + +struct QuantileContFun { + static constexpr const char *Name = "quantile_cont"; + static constexpr const char *Parameters = "x,pos"; + static constexpr const char *Description = "Returns the interpolated quantile number between 0 and 1. If pos is a LIST of FLOATs, then the result is a LIST of the corresponding interpolated quantiles."; + static constexpr const char *Example = "quantile_cont(x, 0.5)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ReservoirQuantileFun { + static constexpr const char *Name = "reservoir_quantile"; + static constexpr const char *Parameters = "x,quantile,sample_size"; + static constexpr const char *Description = "Gives the approximate quantile using reservoir sampling; the sample size is optional and defaults to 8192."; + static constexpr const char *Example = "reservoir_quantile(A,0.5,1024)"; + + static AggregateFunctionSet GetFunctions(); +}; + +struct ApproxTopKFun { + static constexpr const char *Name = "approx_top_k"; + static constexpr const char *Parameters = "val,k"; + static constexpr const char *Description = "Finds the k approximately most occurring values in the data set."; + static constexpr const char *Example = "approx_top_k(x, 5)"; + + static AggregateFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/nested_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/nested_functions.hpp new file mode 100644 index 00000000..eb83e5e1 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/nested_functions.hpp @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/aggregate/nested_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct HistogramFun { + static constexpr const char *Name = "histogram"; + static constexpr const char *Parameters = "arg"; + static constexpr const char *Description = "Returns a LIST of STRUCTs with the fields bucket and count."; + static constexpr const char *Example = "histogram(A)"; + + static AggregateFunctionSet GetFunctions(); + static AggregateFunction GetHistogramUnorderedMap(LogicalType &type); + static AggregateFunction BinnedHistogramFunction(); +}; + +struct HistogramExactFun { + static constexpr const char *Name = "histogram_exact"; + static constexpr const char *Parameters = "arg,bins"; + static constexpr const char *Description = "Returns a LIST of STRUCTs with the fields bucket and count matching the buckets exactly."; + static constexpr const char *Example = "histogram_exact(A, [0, 1, 2])"; + + static AggregateFunction GetFunction(); +}; + +struct ListFun { + static constexpr const char *Name = "list"; + static constexpr const char *Parameters = "arg"; + static constexpr const 
char *Description = "Returns a LIST containing all the values of a column."; + static constexpr const char *Example = "list(A)"; + + static AggregateFunction GetFunction(); +}; + +struct ArrayAggFun { + using ALIAS = ListFun; + + static constexpr const char *Name = "array_agg"; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_helpers.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_helpers.hpp new file mode 100644 index 00000000..253657f5 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_helpers.hpp @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/quantile_helpers.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/quantile_enum.hpp" +#include "core_functions/aggregate/holistic_functions.hpp" + +namespace duckdb { + +// Avoid using naked Values in inner loops... +struct QuantileValue { + explicit QuantileValue(const Value &v) : val(v), dbl(v.GetValue()) { + const auto &type = val.type(); + switch (type.id()) { + case LogicalTypeId::DECIMAL: { + integral = IntegralValue::Get(v); + scaling = Hugeint::POWERS_OF_TEN[DecimalType::GetScale(type)]; + break; + } + default: + break; + } + } + + Value val; + + // DOUBLE + double dbl; + + // DECIMAL + hugeint_t integral; + hugeint_t scaling; + + inline bool operator==(const QuantileValue &other) const { + return val == other.val; + } +}; + +struct QuantileBindData : public FunctionData { + QuantileBindData(); + explicit QuantileBindData(const Value &quantile_p); + explicit QuantileBindData(const vector &quantiles_p); + QuantileBindData(const QuantileBindData &other); + + unique_ptr Copy() const override; + bool Equals(const FunctionData &other_p) const override; + + static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const AggregateFunction &function); + + static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function); + + vector quantiles; + vector order; + bool desc; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp new file mode 100644 index 00000000..a330c0a4 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp @@ -0,0 +1,431 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/quantile_sort_tree.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/common/types/row/row_layout.hpp" +#include "core_functions/aggregate/quantile_helpers.hpp" +#include "duckdb/execution/merge_sort_tree.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/multiply.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/function/window/window_index_tree.hpp" +#include +#include +#include +#include + +namespace duckdb { + +// Paged access +template +struct QuantileCursor { + 
explicit QuantileCursor(const WindowPartitionInput &partition) : inputs(*partition.inputs) { + D_ASSERT(partition.column_ids.size() == 1); + inputs.InitializeScan(scan, partition.column_ids); + inputs.InitializeScanChunk(scan, page); + + D_ASSERT(partition.all_valid.size() == 1); + all_valid = partition.all_valid[0]; + } + + inline sel_t RowOffset(idx_t row_idx) const { + D_ASSERT(RowIsVisible(row_idx)); + return UnsafeNumericCast(row_idx - scan.current_row_index); + } + + inline bool RowIsVisible(idx_t row_idx) const { + return (row_idx < scan.next_row_index && scan.current_row_index <= row_idx); + } + + inline idx_t Seek(idx_t row_idx) { + if (!RowIsVisible(row_idx)) { + inputs.Seek(row_idx, scan, page); + data = FlatVector::GetData(page.data[0]); + validity = &FlatVector::Validity(page.data[0]); + } + return RowOffset(row_idx); + } + + inline const INPUT_TYPE &operator[](idx_t row_idx) { + const auto offset = Seek(row_idx); + return data[offset]; + } + + inline bool RowIsValid(idx_t row_idx) { + const auto offset = Seek(row_idx); + return validity->RowIsValid(offset); + } + + inline bool AllValid() { + return all_valid; + } + + //! Windowed paging + const ColumnDataCollection &inputs; + //! The state used for reading the collection on this thread + ColumnDataScanState scan; + //! The data chunk paged into + DataChunk page; + //! The data pointer + const INPUT_TYPE *data = nullptr; + //! The validity mask + const ValidityMask *validity = nullptr; + //! Paged chunks do not track this, but it is really necessary for performance + bool all_valid; +}; + +// Direct access +template +struct QuantileDirect { + using INPUT_TYPE = T; + using RESULT_TYPE = T; + + inline const INPUT_TYPE &operator()(const INPUT_TYPE &x) const { + return x; + } +}; + +// Indirect access +template +struct QuantileIndirect { + using INPUT_TYPE = idx_t; + using RESULT_TYPE = T; + using CURSOR = QuantileCursor; + CURSOR &data; + + explicit QuantileIndirect(CURSOR &data_p) : data(data_p) { + } + + inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { + return data[input]; + } +}; + +// Composed access +template +struct QuantileComposed { + using INPUT_TYPE = typename INNER::INPUT_TYPE; + using RESULT_TYPE = typename OUTER::RESULT_TYPE; + + const OUTER &outer; + const INNER &inner; + + explicit QuantileComposed(const OUTER &outer_p, const INNER &inner_p) : outer(outer_p), inner(inner_p) { + } + + inline RESULT_TYPE operator()(const idx_t &input) const { + return outer(inner(input)); + } +}; + +// Accessed comparison +template +struct QuantileCompare { + using INPUT_TYPE = typename ACCESSOR::INPUT_TYPE; + const ACCESSOR &accessor_l; + const ACCESSOR &accessor_r; + const bool desc; + + // Single cursor for linear operations + explicit QuantileCompare(const ACCESSOR &accessor, bool desc_p) + : accessor_l(accessor), accessor_r(accessor), desc(desc_p) { + } + + // Independent cursors for sorting + explicit QuantileCompare(const ACCESSOR &accessor_l, const ACCESSOR &accessor_r, bool desc_p) + : accessor_l(accessor_l), accessor_r(accessor_r), desc(desc_p) { + } + + inline bool operator()(const INPUT_TYPE &lhs, const INPUT_TYPE &rhs) const { + const auto lval = accessor_l(lhs); + const auto rval = accessor_r(rhs); + + return desc ? 
(rval < lval) : (lval < rval); + } +}; + +struct CastInterpolation { + template + static inline TARGET_TYPE Cast(const INPUT_TYPE &src, Vector &result) { + return Cast::Operation(src); + } + template + static inline TARGET_TYPE Interpolate(const TARGET_TYPE &lo, const double d, const TARGET_TYPE &hi) { + const auto delta = hi - lo; + return LossyNumericCast(lo + delta * d); + } +}; + +template <> +interval_t CastInterpolation::Cast(const dtime_t &src, Vector &result); +template <> +double CastInterpolation::Interpolate(const double &lo, const double d, const double &hi); +template <> +dtime_t CastInterpolation::Interpolate(const dtime_t &lo, const double d, const dtime_t &hi); +template <> +timestamp_t CastInterpolation::Interpolate(const timestamp_t &lo, const double d, const timestamp_t &hi); +template <> +hugeint_t CastInterpolation::Interpolate(const hugeint_t &lo, const double d, const hugeint_t &hi); +template <> +interval_t CastInterpolation::Interpolate(const interval_t &lo, const double d, const interval_t &hi); +template <> +string_t CastInterpolation::Cast(const string_t &src, Vector &result); + +// Continuous interpolation +template +struct Interpolator { + Interpolator(const QuantileValue &q, const idx_t n_p, const bool desc_p) + : desc(desc_p), RN((double)(n_p - 1) * q.dbl), FRN(ExactNumericCast(floor(RN))), + CRN(ExactNumericCast(ceil(RN))), begin(0), end(n_p) { + } + + template > + TARGET_TYPE Interpolate(INPUT_TYPE lidx, INPUT_TYPE hidx, Vector &result, const ACCESSOR &accessor) const { + using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; + if (lidx == hidx) { + return CastInterpolation::Cast(accessor(lidx), result); + } else { + auto lo = CastInterpolation::Cast(accessor(lidx), result); + auto hi = CastInterpolation::Cast(accessor(hidx), result); + return CastInterpolation::Interpolate(lo, RN - FRN, hi); + } + } + + template > + TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { + using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; + QuantileCompare comp(accessor, desc); + if (CRN == FRN) { + std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); + return CastInterpolation::Cast(accessor(v_t[FRN]), result); + } else { + std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); + std::nth_element(v_t + FRN, v_t + CRN, v_t + end, comp); + auto lo = CastInterpolation::Cast(accessor(v_t[FRN]), result); + auto hi = CastInterpolation::Cast(accessor(v_t[CRN]), result); + return CastInterpolation::Interpolate(lo, RN - FRN, hi); + } + } + + template + inline TARGET_TYPE Extract(const INPUT_TYPE *dest, Vector &result) const { + if (CRN == FRN) { + return CastInterpolation::Cast(dest[0], result); + } else { + auto lo = CastInterpolation::Cast(dest[0], result); + auto hi = CastInterpolation::Cast(dest[1], result); + return CastInterpolation::Interpolate(lo, RN - FRN, hi); + } + } + + const bool desc; + const double RN; + const idx_t FRN; + const idx_t CRN; + + idx_t begin; + idx_t end; +}; + +// Discrete "interpolation" +template <> +struct Interpolator { + static inline idx_t Index(const QuantileValue &q, const idx_t n) { + idx_t floored; + switch (q.val.type().id()) { + case LogicalTypeId::DECIMAL: { + // Integer arithmetic for accuracy + const auto integral = q.integral; + const auto scaling = q.scaling; + const auto scaled_q = + DecimalMultiplyOverflowCheck::Operation(Hugeint::Convert(n), integral); + const auto scaled_n = + DecimalMultiplyOverflowCheck::Operation(Hugeint::Convert(n), scaling); 
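+ // Both branches compute floored = floor(n - q * n); Index then returns + // MaxValue(1, n - floored) - 1, i.e. ceil(q * n) - 1: the 0-based position of the + // discrete quantile. E.g. n = 10, q = 0.5 gives floored = 5 and index 4, the 5th + // smallest value. For DECIMAL quantiles the subtraction is done in scaled hugeint + // arithmetic so the index is exact rather than subject to double rounding.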
Cast::Operation((scaled_n - scaled_q) / scaling); + break; + } + default: + const auto scaled_q = double(n) * q.dbl; + floored = LossyNumericCast(floor(double(n) - scaled_q)); + break; + } + + return MaxValue(1, n - floored) - 1; + } + + Interpolator(const QuantileValue &q, const idx_t n_p, bool desc_p) + : desc(desc_p), FRN(Index(q, n_p)), CRN(FRN), begin(0), end(n_p) { + } + + template > + TARGET_TYPE Interpolate(INPUT_TYPE lidx, INPUT_TYPE hidx, Vector &result, const ACCESSOR &accessor) const { + using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; + return CastInterpolation::Cast(accessor(lidx), result); + } + + template > + typename ACCESSOR::RESULT_TYPE InterpolateInternal(INPUT_TYPE *v_t, const ACCESSOR &accessor = ACCESSOR()) const { + QuantileCompare comp(accessor, desc); + std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); + return accessor(v_t[FRN]); + } + + template > + TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { + using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; + return CastInterpolation::Cast(InterpolateInternal(v_t, accessor), result); + } + + template + TARGET_TYPE Extract(const INPUT_TYPE *dest, Vector &result) const { + return CastInterpolation::Cast(dest[0], result); + } + + const bool desc; + const idx_t FRN; + const idx_t CRN; + + idx_t begin; + idx_t end; +}; + +template +struct QuantileIncluded { + using CURSOR_TYPE = QuantileCursor; + + inline explicit QuantileIncluded(const ValidityMask &fmask_p, CURSOR_TYPE &dmask_p) + : fmask(fmask_p), dmask(dmask_p) { + } + + inline bool operator()(const idx_t &idx) { + return fmask.RowIsValid(idx) && dmask.RowIsValid(idx); + } + + inline bool AllValid() { + return fmask.AllValid() && dmask.AllValid(); + } + + const ValidityMask &fmask; + CURSOR_TYPE &dmask; +}; + +struct QuantileSortTree { + + unique_ptr index_tree; + + QuantileSortTree(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition) { + // TODO: Two pass parallel sorting using Build + auto &inputs = *partition.inputs; + ColumnDataScanState scan; + DataChunk sort; + inputs.InitializeScan(scan, partition.column_ids); + inputs.InitializeScanChunk(scan, sort); + + // Sort on the single argument + auto &bind_data = aggr_input_data.bind_data->Cast(); + auto order_expr = make_uniq(Value(sort.GetTypes()[0])); + auto order_type = bind_data.desc ? 
OrderType::DESCENDING : OrderType::ASCENDING; + BoundOrderModifier order_bys; + order_bys.orders.emplace_back(BoundOrderByNode(order_type, OrderByNullType::NULLS_LAST, std::move(order_expr))); + vector sort_idx(1, 0); + const auto count = partition.count; + + index_tree = make_uniq(partition.context, order_bys, sort_idx, count); + auto index_state = index_tree->GetLocalState(); + auto &local_state = index_state->Cast(); + + // Build the indirection array by scanning the valid indices + const auto &filter_mask = partition.filter_mask; + SelectionVector filter_sel(STANDARD_VECTOR_SIZE); + while (inputs.Scan(scan, sort)) { + const auto row_idx = scan.current_row_index; + if (!filter_mask.AllValid() || !partition.all_valid[0]) { + auto &key = sort.data[0]; + auto &validity = FlatVector::Validity(key); + idx_t filtered = 0; + for (sel_t i = 0; i < sort.size(); ++i) { + if (filter_mask.RowIsValid(i + row_idx) && validity.RowIsValid(i)) { + filter_sel[filtered++] = i; + } + } + local_state.SinkChunk(sort, row_idx, filter_sel, filtered); + } else { + local_state.SinkChunk(sort, row_idx, nullptr, 0); + } + } + local_state.Sort(); + } + + inline idx_t SelectNth(const SubFrames &frames, size_t n) const { + return index_tree->SelectNth(frames, n); + } + + template + RESULT_TYPE WindowScalar(QuantileCursor &data, const SubFrames &frames, const idx_t n, Vector &result, + const QuantileValue &q) { + D_ASSERT(n > 0); + + // Thread safe and idempotent. + index_tree->Build(); + + // Find the interpolated indices within the frame + Interpolator interp(q, n, false); + const auto lo_data = SelectNth(frames, interp.FRN); + auto hi_data = lo_data; + if (interp.CRN != interp.FRN) { + hi_data = SelectNth(frames, interp.CRN); + } + + // Interpolate indirectly + using ID = QuantileIndirect; + ID indirect(data); + return interp.template Interpolate(lo_data, hi_data, result, indirect); + } + + template + void WindowList(QuantileCursor &data, const SubFrames &frames, const idx_t n, Vector &list, + const idx_t lidx, const QuantileBindData &bind_data) { + D_ASSERT(n > 0); + + // Thread safe and idempotent.
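+ // [Editor's sketch, not part of the upstream patch] Worked example of the
+ // continuous interpolator defined above: with n = 10 frame rows and q = 0.25,
+ // RN = (10 - 1) * 0.25 = 2.25, so FRN = floor(RN) = 2 and CRN = ceil(RN) = 3.
+ // SelectNth() maps those ranks to row indices inside the frames, and the
+ // result is lo + (RN - FRN) * (hi - lo) = lo + 0.25 * (hi - lo).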
+ index_tree->Build(); + + // Result is a constant LIST with a fixed length + auto ldata = FlatVector::GetData(list); + auto &lentry = ldata[lidx]; + lentry.offset = ListVector::GetListSize(list); + lentry.length = bind_data.quantiles.size(); + + ListVector::Reserve(list, lentry.offset + lentry.length); + ListVector::SetListSize(list, lentry.offset + lentry.length); + auto &result = ListVector::GetEntry(list); + auto rdata = FlatVector::GetData(result); + + using ID = QuantileIndirect; + ID indirect(data); + for (const auto &q : bind_data.order) { + const auto &quantile = bind_data.quantiles[q]; + Interpolator interp(quantile, n, false); + + const auto lo_data = SelectNth(frames, interp.FRN); + auto hi_data = lo_data; + if (interp.CRN != interp.FRN) { + hi_data = SelectNth(frames, interp.CRN); + } + + // Interpolate indirectly + rdata[lentry.offset + q] = + interp.template Interpolate(lo_data, hi_data, result, indirect); + } + } +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp new file mode 100644 index 00000000..00f4baf7 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp @@ -0,0 +1,307 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/quantile_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "core_functions/aggregate/quantile_sort_tree.hpp" +#include "SkipList.h" + +namespace duckdb { + +struct QuantileOperation { + template + static void Initialize(STATE &state) { + new (&state) STATE(); + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &aggr_input) { + state.AddElement(input, aggr_input.input); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (source.v.empty()) { + return; + } + target.v.insert(target.v.end(), source.v.begin(), source.v.end()); + } + + template + static void Destroy(STATE &state, AggregateInputData &) { + state.~STATE(); + } + + static bool IgnoreNull() { + return true; + } + + template + static void WindowInit(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, + data_ptr_t g_state) { + D_ASSERT(partition.inputs); + + const auto &stats = partition.stats; + + // If frames overlap significantly, then use local skip lists. 
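+ // [Editor's sketch, assuming stats[0] and stats[1] hold the min/max statistics
+ // of the frame begin and end expressions] For example, if all frame begins lie
+ // in [0, 40] and all frame ends in [60, 100], then overlap = 60 - 40 = 20 and
+ // cover = 100 - 0 = 100, so ratio = 0.2 <= 0.75 and the merge sort tree below
+ // is built; a ratio above 0.75 keeps the cheaper per-frame skip lists instead.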
+ if (stats[0].end <= stats[1].begin) { + // Frames can overlap + const auto overlap = double(stats[1].begin - stats[0].end); + const auto cover = double(stats[1].end - stats[0].begin); + const auto ratio = overlap / cover; + if (ratio > .75) { + return; + } + } + + // Build the tree + auto &state = *reinterpret_cast(g_state); + auto &window_state = state.GetOrCreateWindowState(); + window_state.qst = make_uniq(aggr_input_data, partition); + } + + template + static idx_t FrameSize(QuantileIncluded &included, const SubFrames &frames) { + // Count the number of valid values + idx_t n = 0; + if (included.AllValid()) { + for (const auto &frame : frames) { + n += frame.end - frame.start; + } + } else { + // NULLs or FILTERed values, + for (const auto &frame : frames) { + for (auto i = frame.start; i < frame.end; ++i) { + n += included(i); + } + } + } + + return n; + } +}; + +template +struct SkipLess { + inline bool operator()(const T &lhi, const T &rhi) const { + return lhi.second < rhi.second; + } +}; + +template +struct WindowQuantileState { + // Windowed Quantile merge sort trees + unique_ptr qst; + + // Windowed Quantile skip lists + using SkipType = pair; + using SkipListType = duckdb_skiplistlib::skip_list::HeadNode>; + SubFrames prevs; + unique_ptr s; + mutable vector skips; + + // Windowed MAD indirection + idx_t count; + vector m; + + using IncludedType = QuantileIncluded; + using CursorType = QuantileCursor; + + WindowQuantileState() : count(0) { + } + + inline void SetCount(size_t count_p) { + count = count_p; + if (count >= m.size()) { + m.resize(count); + } + } + + inline SkipListType &GetSkipList(bool reset = false) { + if (reset || !s) { + s.reset(); + s = make_uniq(); + } + return *s; + } + + struct SkipListUpdater { + SkipListType &skip; + CursorType &data; + IncludedType &included; + + inline SkipListUpdater(SkipListType &skip, CursorType &data, IncludedType &included) + : skip(skip), data(data), included(included) { + } + + inline void Neither(idx_t begin, idx_t end) { + } + + inline void Left(idx_t begin, idx_t end) { + for (; begin < end; ++begin) { + if (included(begin)) { + skip.remove(SkipType(begin, data[begin])); + } + } + } + + inline void Right(idx_t begin, idx_t end) { + for (; begin < end; ++begin) { + if (included(begin)) { + skip.insert(SkipType(begin, data[begin])); + } + } + } + + inline void Both(idx_t begin, idx_t end) { + } + }; + + void UpdateSkip(CursorType &data, const SubFrames &frames, IncludedType &included) { + // No overlap, or no data + if (!s || prevs.back().end <= frames.front().start || frames.back().end <= prevs.front().start) { + auto &skip = GetSkipList(true); + for (const auto &frame : frames) { + for (auto i = frame.start; i < frame.end; ++i) { + if (included(i)) { + skip.insert(SkipType(i, data[i])); + } + } + } + } else { + auto &skip = GetSkipList(); + SkipListUpdater updater(skip, data, included); + AggregateExecutor::IntersectFrames(prevs, frames, updater); + } + } + + bool HasTree() const { + return qst.get(); + } + + template + RESULT_TYPE WindowScalar(CursorType &data, const SubFrames &frames, const idx_t n, Vector &result, + const QuantileValue &q) const { + D_ASSERT(n > 0); + if (qst) { + return qst->WindowScalar(data, frames, n, result, q); + } else if (s) { + // Find the position(s) needed + try { + Interpolator interp(q, s->size(), false); + s->at(interp.FRN, interp.CRN - interp.FRN + 1, skips); + array dest; + dest[0] = skips[0].second; + if (skips.size() > 1) { + dest[1] = skips[1].second; + } + return interp.template 
Extract(dest.data(), result); + } catch (const duckdb_skiplistlib::skip_list::IndexError &idx_err) { + throw InternalException(idx_err.message()); + } + } else { + throw InternalException("No accelerator for scalar QUANTILE"); + } + } + + template + void WindowList(CursorType &data, const SubFrames &frames, const idx_t n, Vector &list, const idx_t lidx, + const QuantileBindData &bind_data) const { + D_ASSERT(n > 0); + // Result is a constant LIST with a fixed length + auto ldata = FlatVector::GetData(list); + auto &lentry = ldata[lidx]; + lentry.offset = ListVector::GetListSize(list); + lentry.length = bind_data.quantiles.size(); + + ListVector::Reserve(list, lentry.offset + lentry.length); + ListVector::SetListSize(list, lentry.offset + lentry.length); + auto &result = ListVector::GetEntry(list); + auto rdata = FlatVector::GetData(result); + + for (const auto &q : bind_data.order) { + const auto &quantile = bind_data.quantiles[q]; + rdata[lentry.offset + q] = WindowScalar(data, frames, n, result, quantile); + } + } +}; + +struct QuantileStandardType { + template + static T Operation(T input, AggregateInputData &) { + return input; + } +}; + +struct QuantileStringType { + template + static T Operation(T input, AggregateInputData &input_data) { + if (input.IsInlined()) { + return input; + } + auto string_data = input_data.allocator.Allocate(input.GetSize()); + memcpy(string_data, input.GetData(), input.GetSize()); + return string_t(char_ptr_cast(string_data), UnsafeNumericCast(input.GetSize())); + } +}; + +template +struct QuantileState { + using InputType = INPUT_TYPE; + using CursorType = QuantileCursor; + + // Regular aggregation + vector v; + + // Window Quantile State + unique_ptr> window_state; + unique_ptr window_cursor; + + void AddElement(INPUT_TYPE element, AggregateInputData &aggr_input) { + v.emplace_back(TYPE_OP::Operation(element, aggr_input)); + } + + bool HasTree() const { + return window_state && window_state->HasTree(); + } + WindowQuantileState &GetOrCreateWindowState() { + if (!window_state) { + window_state = make_uniq>(); + } + return *window_state; + } + WindowQuantileState &GetWindowState() { + return *window_state; + } + const WindowQuantileState &GetWindowState() const { + return *window_state; + } + + CursorType &GetOrCreateWindowCursor(const WindowPartitionInput &partition) { + if (!window_cursor) { + window_cursor = make_uniq(partition); + } + return *window_cursor; + } + CursorType &GetWindowCursor() { + return *window_cursor; + } + const CursorType &GetWindowCursor() const { + return *window_cursor; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_count.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_count.hpp new file mode 100644 index 00000000..40366ef6 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_count.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/regression/regr_count.hpp +// +// +//===----------------------------------------------------------------------===// +// REGR_COUNT(y, x) + +#pragma once + +#include "duckdb/function/aggregate_function.hpp" +#include "core_functions/aggregate/algebraic/covar.hpp" +#include "core_functions/aggregate/algebraic/stddev.hpp" + +namespace duckdb { + +struct RegrCountFunction { + template + static void Initialize(STATE &state) { + 
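+ // [Editor's note] The state is just a running count of pairs, so initialization
+ // only zeroes the counter. E.g. a hypothetical SELECT regr_count(y, x) FROM t
+ // returns the number of rows where both y and x are non-NULL, because
+ // IgnoreNull() below makes the executor skip NULL pairs before Operation().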
state = 0; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + target += source; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + target = static_cast(state); + } + static bool IgnoreNull() { + return true; + } + template + static void Operation(STATE &state, const A_TYPE &, const B_TYPE &, AggregateBinaryInput &) { + state += 1; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_slope.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_slope.hpp new file mode 100644 index 00000000..d89af040 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_slope.hpp @@ -0,0 +1,57 @@ +// REGR_SLOPE(y, x) +// Returns the slope of the linear regression line for non-null pairs in a group. +// It is computed for non-null pairs using the following formula: +// COVAR_POP(x,y) / VAR_POP(x) + +//! Input : Any numeric type +//! Output : Double + +#pragma once +#include "core_functions/aggregate/algebraic/stddev.hpp" +#include "core_functions/aggregate/algebraic/covar.hpp" + +namespace duckdb { + +struct RegrSlopeState { + CovarState cov_pop; + StddevState var_pop; +}; + +struct RegrSlopeOperation { + template + static void Initialize(STATE &state) { + CovarOperation::Initialize(state.cov_pop); + STDDevBaseOperation::Initialize(state.var_pop); + } + + template + static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { + CovarOperation::Operation(state.cov_pop, y, x, idata); + STDDevBaseOperation::Execute(state.var_pop, x); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + CovarOperation::Combine(source.cov_pop, target.cov_pop, aggr_input_data); + STDDevBaseOperation::Combine(source.var_pop, target.var_pop, aggr_input_data); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.cov_pop.count == 0 || state.var_pop.count == 0) { + finalize_data.ReturnNull(); + } else { + auto cov = state.cov_pop.co_moment / state.cov_pop.count; + auto var_pop = state.var_pop.count > 1 ? (state.var_pop.dsquared / state.var_pop.count) : 0; + if (!Value::DoubleIsFinite(var_pop)) { + throw OutOfRangeException("VARPOP is out of range!"); + } + target = var_pop != 0 ? 
cov / var_pop : NAN; + } + } + + static bool IgnoreNull() { + return true; + } +}; +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression_functions.hpp new file mode 100644 index 00000000..e82b9fdf --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression_functions.hpp @@ -0,0 +1,99 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/aggregate/regression_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct RegrAvgxFun { + static constexpr const char *Name = "regr_avgx"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the average of the independent variable for non-null pairs in a group, where x is the independent variable and y is the dependent variable."; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct RegrAvgyFun { + static constexpr const char *Name = "regr_avgy"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the average of the dependent variable for non-null pairs in a group, where x is the independent variable and y is the dependent variable."; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct RegrCountFun { + static constexpr const char *Name = "regr_count"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the number of non-null number pairs in a group."; + static constexpr const char *Example = "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / COUNT(*)"; + + static AggregateFunction GetFunction(); +}; + +struct RegrInterceptFun { + static constexpr const char *Name = "regr_intercept"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the intercept of the univariate linear regression line for non-null pairs in a group."; + static constexpr const char *Example = "AVG(y)-REGR_SLOPE(y,x)*AVG(x)"; + + static AggregateFunction GetFunction(); +}; + +struct RegrR2Fun { + static constexpr const char *Name = "regr_r2"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the coefficient of determination for non-null pairs in a group."; + static constexpr const char *Example = ""; + + static AggregateFunction GetFunction(); +}; + +struct RegrSlopeFun { + static constexpr const char *Name = "regr_slope"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the slope of the linear regression line for non-null pairs in a group."; + static constexpr const char *Example = "COVAR_POP(x,y) / VAR_POP(x)"; + + static AggregateFunction GetFunction(); +}; + +struct RegrSXXFun { + static constexpr const char *Name = "regr_sxx"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = ""; + static constexpr const char 
*Example = "REGR_COUNT(y, x) * VAR_POP(x)"; + + static AggregateFunction GetFunction(); +}; + +struct RegrSXYFun { + static constexpr const char *Name = "regr_sxy"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Returns the population covariance of input values"; + static constexpr const char *Example = "REGR_COUNT(y, x) * COVAR_POP(y, x)"; + + static AggregateFunction GetFunction(); +}; + +struct RegrSYYFun { + static constexpr const char *Name = "regr_syy"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = ""; + static constexpr const char *Example = "REGR_COUNT(y, x) * VAR_POP(y)"; + + static AggregateFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp new file mode 100644 index 00000000..562f61ad --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp @@ -0,0 +1,175 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/aggregate/sum_helpers.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +static inline void KahanAddInternal(double input, double &summed, double &err) { + double diff = input - err; + double newval = summed + diff; + err = (newval - summed) - diff; + summed = newval; +} + +template +struct SumState { + bool isset; + T value; + + void Initialize() { + this->isset = false; + } + + void Combine(const SumState &other) { + this->isset = other.isset || this->isset; + this->value += other.value; + } +}; + +struct KahanSumState { + bool isset; + double value; + double err; + + void Initialize() { + this->isset = false; + this->err = 0.0; + } + + void Combine(const KahanSumState &other) { + this->isset = other.isset || this->isset; + KahanAddInternal(other.value, this->value, this->err); + KahanAddInternal(other.err, this->value, this->err); + } +}; + +struct RegularAdd { + template + static void AddNumber(STATE &state, T input) { + state.value += input; + } + + template + static void AddConstant(STATE &state, T input, idx_t count) { + state.value += input * int64_t(count); + } +}; + +struct HugeintAdd { + template + static void AddNumber(STATE &state, T input) { + state.value = Hugeint::Add(state.value, input); + } + + template + static void AddConstant(STATE &state, T input, idx_t count) { + AddNumber(state, Hugeint::Multiply(input, UnsafeNumericCast(count))); + } +}; + +struct KahanAdd { + template + static void AddNumber(STATE &state, T input) { + KahanAddInternal(input, state.value, state.err); + } + + template + static void AddConstant(STATE &state, T input, idx_t count) { + KahanAddInternal(input * count, state.value, state.err); + } +}; + +struct AddToHugeint { + static void AddValue(hugeint_t &result, uint64_t value, int positive) { + // integer summation taken from Tim Gubner et al. 
- Efficient Query Processing + // with Optimistically Compressed Hash Tables & Strings in the USSR + + // add the value to the lower part of the hugeint + result.lower += value; + // now handle overflows + int overflow = result.lower < value; + // we consider two situations: + // (1) input[idx] is positive, and current value is lower than value: overflow + // (2) input[idx] is negative, and current value is higher than value: underflow + if (!(overflow ^ positive)) { + // in the case of an overflow or underflow we either increment or decrement the upper base + // positive: +1, negative: -1 + result.upper += -1 + 2 * positive; + } + } + + template + static void AddNumber(STATE &state, T input) { + AddValue(state.value, uint64_t(input), input >= 0); + } + + template + static void AddConstant(STATE &state, T input, idx_t count) { + // add a constant X number of times + // fast path: check if value * count fits into a uint64_t + // note that we check if value * VECTOR_SIZE fits in a uint64_t to avoid having to actually do a division + // this is still a pretty high number (18014398509481984) so most positive numbers will fit + if (input >= 0 && uint64_t(input) < (NumericLimits::Maximum() / STANDARD_VECTOR_SIZE)) { + // if it does just multiply it and add the value + uint64_t value = uint64_t(input) * count; + AddValue(state.value, value, 1); + } else { + // if it doesn't fit we have two choices + // either we loop over count and add the values individually + // or we convert to a hugeint and multiply the hugeint + // the problem is that hugeint multiplication is expensive + // hence we switch here: with a low count we do the loop + // with a high count we do the hugeint multiplication + if (count < 8) { + for (idx_t i = 0; i < count; i++) { + AddValue(state.value, uint64_t(input), input >= 0); + } + } else { + hugeint_t addition = hugeint_t(input) * Hugeint::Convert(count); + state.value += addition; + } + } + } +}; + +template +struct BaseSumOperation { + template + static void Initialize(STATE &state) { + state.value = 0; + STATEOP::template Initialize(state); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { + STATEOP::template Combine(source, target, aggr_input_data); + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + STATEOP::template AddValues(state, 1); + ADDOP::template AddNumber(state, input); + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &, idx_t count) { + STATEOP::template AddValues(state, count); + ADDOP::template AddConstant(state, input, count); + } + + static bool IgnoreNull() { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp b/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp new file mode 100644 index 00000000..dd6e2915 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp @@ -0,0 +1,107 @@ +#pragma once +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/algorithm.hpp" +#include + +namespace duckdb { + +//------------------------------------------------------------------------- +// Folding Operations +//------------------------------------------------------------------------- +struct InnerProductOp { + static constexpr bool ALLOW_EMPTY = true; + + template + static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, 
const idx_t count) { + + TYPE result = 0; + + auto lhs_ptr = lhs_data; + auto rhs_ptr = rhs_data; + + for (idx_t i = 0; i < count; i++) { + const auto x = *lhs_ptr++; + const auto y = *rhs_ptr++; + result += x * y; + } + + return result; + } +}; + +struct NegativeInnerProductOp { + static constexpr bool ALLOW_EMPTY = true; + + template + static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { + return -InnerProductOp::Operation(lhs_data, rhs_data, count); + } +}; + +struct CosineSimilarityOp { + static constexpr bool ALLOW_EMPTY = false; + + template + static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { + + TYPE distance = 0; + TYPE norm_l = 0; + TYPE norm_r = 0; + + auto l_ptr = lhs_data; + auto r_ptr = rhs_data; + + for (idx_t i = 0; i < count; i++) { + const auto x = *l_ptr++; + const auto y = *r_ptr++; + distance += x * y; + norm_l += x * x; + norm_r += y * y; + } + + auto similarity = distance / std::sqrt(norm_l * norm_r); + return std::max(static_cast(-1.0), std::min(similarity, static_cast(1.0))); + } +}; + +struct CosineDistanceOp { + static constexpr bool ALLOW_EMPTY = false; + + template + static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { + return static_cast(1.0) - CosineSimilarityOp::Operation(lhs_data, rhs_data, count); + } +}; + +struct DistanceSquaredOp { + static constexpr bool ALLOW_EMPTY = true; + + template + static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { + + TYPE distance = 0; + + auto l_ptr = lhs_data; + auto r_ptr = rhs_data; + + for (idx_t i = 0; i < count; i++) { + const auto x = *l_ptr++; + const auto y = *r_ptr++; + const auto diff = x - y; + distance += diff * diff; + } + + return distance; + } +}; + +struct DistanceOp { + static constexpr bool ALLOW_EMPTY = true; + + template + static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { + return std::sqrt(DistanceSquaredOp::Operation(lhs_data, rhs_data, count)); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/function_list.hpp b/src/duckdb/extension/core_functions/include/core_functions/function_list.hpp new file mode 100644 index 00000000..024ca49f --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/function_list.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/function_list.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.hpp" + +namespace duckdb { + +typedef ScalarFunction (*get_scalar_function_t)(); +typedef ScalarFunctionSet (*get_scalar_function_set_t)(); +typedef AggregateFunction (*get_aggregate_function_t)(); +typedef AggregateFunctionSet (*get_aggregate_function_set_t)(); + +struct StaticFunctionDefinition { + const char *name; + const char *parameters; + const char *description; + const char *example; + get_scalar_function_t get_function; + get_scalar_function_set_t get_function_set; + get_aggregate_function_t get_aggregate_function; + get_aggregate_function_set_t get_aggregate_function_set; + + static const StaticFunctionDefinition *GetFunctionList(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/array_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/array_functions.hpp new file mode 
100644 index 00000000..561643be --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/array_functions.hpp @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/array_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct ArrayValueFun { + static constexpr const char *Name = "array_value"; + static constexpr const char *Parameters = "any,..."; + static constexpr const char *Description = "Create an ARRAY containing the argument values."; + static constexpr const char *Example = "array_value(4, 5, 6)"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayCrossProductFun { + static constexpr const char *Name = "array_cross_product"; + static constexpr const char *Parameters = "array, array"; + static constexpr const char *Description = "Compute the cross product of two arrays of size 3. The array elements can not be NULL."; + static constexpr const char *Example = "array_cross_product([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArrayCosineSimilarityFun { + static constexpr const char *Name = "array_cosine_similarity"; + static constexpr const char *Parameters = "array1,array2"; + static constexpr const char *Description = "Compute the cosine similarity between two arrays of the same size. The array elements can not be NULL. The arrays can have any size as long as the size is the same for both arguments."; + static constexpr const char *Example = "array_cosine_similarity([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArrayCosineDistanceFun { + static constexpr const char *Name = "array_cosine_distance"; + static constexpr const char *Parameters = "array1,array2"; + static constexpr const char *Description = "Compute the cosine distance between two arrays of the same size. The array elements can not be NULL. The arrays can have any size as long as the size is the same for both arguments."; + static constexpr const char *Example = "array_cosine_distance([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArrayDistanceFun { + static constexpr const char *Name = "array_distance"; + static constexpr const char *Parameters = "array1,array2"; + static constexpr const char *Description = "Compute the distance between two arrays of the same size. The array elements can not be NULL. The arrays can have any size as long as the size is the same for both arguments."; + static constexpr const char *Example = "array_distance([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArrayInnerProductFun { + static constexpr const char *Name = "array_inner_product"; + static constexpr const char *Parameters = "array1,array2"; + static constexpr const char *Description = "Compute the inner product between two arrays of the same size. The array elements can not be NULL. 
The arrays can have any size as long as the size is the same for both arguments."; + static constexpr const char *Example = "array_inner_product([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArrayDotProductFun { + using ALIAS = ArrayInnerProductFun; + + static constexpr const char *Name = "array_dot_product"; +}; + +struct ArrayNegativeInnerProductFun { + static constexpr const char *Name = "array_negative_inner_product"; + static constexpr const char *Parameters = "array1,array2"; + static constexpr const char *Description = "Compute the negative inner product between two arrays of the same size. The array elements can not be NULL. The arrays can have any size as long as the size is the same for both arguments."; + static constexpr const char *Example = "array_negative_inner_product([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArrayNegativeDotProductFun { + using ALIAS = ArrayNegativeInnerProductFun; + + static constexpr const char *Name = "array_negative_dot_product"; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/bit_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/bit_functions.hpp new file mode 100644 index 00000000..e01a2fc5 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/bit_functions.hpp @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/bit_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct GetBitFun { + static constexpr const char *Name = "get_bit"; + static constexpr const char *Parameters = "bitstring,index"; + static constexpr const char *Description = "Extracts the nth bit from bitstring; the first (leftmost) bit is indexed 0"; + static constexpr const char *Example = "get_bit('0110010'::BIT, 2)"; + + static ScalarFunction GetFunction(); +}; + +struct SetBitFun { + static constexpr const char *Name = "set_bit"; + static constexpr const char *Parameters = "bitstring,index,new_value"; + static constexpr const char *Description = "Sets the nth bit in bitstring to newvalue; the first (leftmost) bit is indexed 0. Returns a new bitstring"; + static constexpr const char *Example = "set_bit('0110010'::BIT, 2, 0)"; + + static ScalarFunction GetFunction(); +}; + +struct BitPositionFun { + static constexpr const char *Name = "bit_position"; + static constexpr const char *Parameters = "substring,bitstring"; + static constexpr const char *Description = "Returns first starting index of the specified substring within bits, or zero if it is not present. 
The first (leftmost) bit is indexed 1"; + static constexpr const char *Example = "bit_position('010'::BIT, '1110101'::BIT)"; + + static ScalarFunction GetFunction(); +}; + +struct BitStringFun { + static constexpr const char *Name = "bitstring"; + static constexpr const char *Parameters = "bitstring,length"; + static constexpr const char *Description = "Pads the bitstring until the specified length"; + static constexpr const char *Example = "bitstring('1010'::BIT, 7)"; + + static ScalarFunctionSet GetFunctions(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/blob_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/blob_functions.hpp new file mode 100644 index 00000000..051e212c --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/blob_functions.hpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/blob_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct DecodeFun { + static constexpr const char *Name = "decode"; + static constexpr const char *Parameters = "blob"; + static constexpr const char *Description = "Convert blob to varchar. Fails if blob is not valid utf-8"; + static constexpr const char *Example = "decode('\\xC3\\xBC'::BLOB)"; + + static ScalarFunction GetFunction(); +}; + +struct EncodeFun { + static constexpr const char *Name = "encode"; + static constexpr const char *Parameters = "string"; + static constexpr const char *Description = "Convert varchar to blob. 
Converts utf-8 characters into literal encoding"; + static constexpr const char *Example = "encode('my_string_with_ü')"; + + static ScalarFunction GetFunction(); +}; + +struct FromBase64Fun { + static constexpr const char *Name = "from_base64"; + static constexpr const char *Parameters = "string"; + static constexpr const char *Description = "Convert a base64 encoded string to a character string"; + static constexpr const char *Example = "from_base64('QQ==')"; + + static ScalarFunction GetFunction(); +}; + +struct ToBase64Fun { + static constexpr const char *Name = "to_base64"; + static constexpr const char *Parameters = "blob"; + static constexpr const char *Description = "Convert a blob to a base64 encoded string"; + static constexpr const char *Example = "base64('A'::blob)"; + + static ScalarFunction GetFunction(); +}; + +struct Base64Fun { + using ALIAS = ToBase64Fun; + + static constexpr const char *Name = "base64"; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/date_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/date_functions.hpp new file mode 100644 index 00000000..7256502a --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/date_functions.hpp @@ -0,0 +1,603 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/date_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct AgeFun { + static constexpr const char *Name = "age"; + static constexpr const char *Parameters = "timestamp,timestamp"; + static constexpr const char *Description = "Subtract arguments, resulting in the time difference between the two timestamps"; + static constexpr const char *Example = "age(TIMESTAMP '2001-04-10', TIMESTAMP '1992-09-20')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct CenturyFun { + static constexpr const char *Name = "century"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the century component from a date or timestamp"; + static constexpr const char *Example = "century(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DateDiffFun { + static constexpr const char *Name = "date_diff"; + static constexpr const char *Parameters = "part,startdate,enddate"; + static constexpr const char *Description = "The number of partition boundaries between the timestamps"; + static constexpr const char *Example = "date_diff('hour', TIMESTAMPTZ '1992-09-30 23:59:59', TIMESTAMPTZ '1992-10-01 01:58:00')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DatediffFun { + using ALIAS = DateDiffFun; + + static constexpr const char *Name = "datediff"; +}; + +struct DatePartFun { + static constexpr const char *Name = "date_part"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Get subfield (equivalent to extract)"; + static constexpr const char *Example = "date_part('minute', TIMESTAMP '1992-09-20 20:38:40')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DatepartFun { + 
using ALIAS = DatePartFun; + + static constexpr const char *Name = "datepart"; +}; + +struct DateSubFun { + static constexpr const char *Name = "date_sub"; + static constexpr const char *Parameters = "part,startdate,enddate"; + static constexpr const char *Description = "The number of complete partitions between the timestamps"; + static constexpr const char *Example = "date_sub('hour', TIMESTAMPTZ '1992-09-30 23:59:59', TIMESTAMPTZ '1992-10-01 01:58:00')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DatesubFun { + using ALIAS = DateSubFun; + + static constexpr const char *Name = "datesub"; +}; + +struct DateTruncFun { + static constexpr const char *Name = "date_trunc"; + static constexpr const char *Parameters = "part,timestamp"; + static constexpr const char *Description = "Truncate to specified precision"; + static constexpr const char *Example = "date_trunc('hour', TIMESTAMPTZ '1992-09-20 20:38:40')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DatetruncFun { + using ALIAS = DateTruncFun; + + static constexpr const char *Name = "datetrunc"; +}; + +struct DayFun { + static constexpr const char *Name = "day"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the day component from a date or timestamp"; + static constexpr const char *Example = "day(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DayNameFun { + static constexpr const char *Name = "dayname"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "The (English) name of the weekday"; + static constexpr const char *Example = "dayname(TIMESTAMP '1992-03-22')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DayOfMonthFun { + static constexpr const char *Name = "dayofmonth"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the dayofmonth component from a date or timestamp"; + static constexpr const char *Example = "dayofmonth(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DayOfWeekFun { + static constexpr const char *Name = "dayofweek"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the dayofweek component from a date or timestamp"; + static constexpr const char *Example = "dayofweek(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DayOfYearFun { + static constexpr const char *Name = "dayofyear"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the dayofyear component from a date or timestamp"; + static constexpr const char *Example = "dayofyear(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct DecadeFun { + static constexpr const char *Name = "decade"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the decade component from a date or timestamp"; + static constexpr const char *Example = "decade(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct EpochFun { + static constexpr const char *Name = "epoch"; + static constexpr const char *Parameters = "temporal"; + static constexpr const char *Description = "Extract the epoch component from a temporal type"; + static constexpr const char *Example = "epoch(timestamp '2021-08-03 
11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct EpochMsFun { + static constexpr const char *Name = "epoch_ms"; + static constexpr const char *Parameters = "temporal"; + static constexpr const char *Description = "Extract the epoch component in milliseconds from a temporal type"; + static constexpr const char *Example = "epoch_ms(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct EpochUsFun { + static constexpr const char *Name = "epoch_us"; + static constexpr const char *Parameters = "temporal"; + static constexpr const char *Description = "Extract the epoch component in microseconds from a temporal type"; + static constexpr const char *Example = "epoch_us(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct EpochNsFun { + static constexpr const char *Name = "epoch_ns"; + static constexpr const char *Parameters = "temporal"; + static constexpr const char *Description = "Extract the epoch component in nanoseconds from a temporal type"; + static constexpr const char *Example = "epoch_ns(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct EraFun { + static constexpr const char *Name = "era"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the era component from a date or timestamp"; + static constexpr const char *Example = "era(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct GetCurrentTimestampFun { + static constexpr const char *Name = "get_current_timestamp"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the current timestamp"; + static constexpr const char *Example = "get_current_timestamp()"; + + static ScalarFunction GetFunction(); +}; + +struct NowFun { + using ALIAS = GetCurrentTimestampFun; + + static constexpr const char *Name = "now"; +}; + +struct TransactionTimestampFun { + using ALIAS = GetCurrentTimestampFun; + + static constexpr const char *Name = "transaction_timestamp"; +}; + +struct HoursFun { + static constexpr const char *Name = "hour"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the hour component from a date or timestamp"; + static constexpr const char *Example = "hour(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ISODayOfWeekFun { + static constexpr const char *Name = "isodow"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the isodow component from a date or timestamp"; + static constexpr const char *Example = "isodow(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ISOYearFun { + static constexpr const char *Name = "isoyear"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the isoyear component from a date or timestamp"; + static constexpr const char *Example = "isoyear(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct JulianDayFun { + static constexpr const char *Name = "julian"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the Julian Day number from a date or timestamp"; + static constexpr const char *Example = "julian(timestamp '2006-01-01 
12:00')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct LastDayFun { + static constexpr const char *Name = "last_day"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Returns the last day of the month"; + static constexpr const char *Example = "last_day(TIMESTAMP '1992-03-22 01:02:03.1234')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MakeDateFun { + static constexpr const char *Name = "make_date"; + static constexpr const char *Parameters = "year,month,day\1date-struct::STRUCT(year BIGINT, month BIGINT, day BIGINT)"; + static constexpr const char *Description = "The date for the given parts\1The date for the given struct."; + static constexpr const char *Example = "make_date(1992, 9, 20)\1make_date({'year': 2024, 'month': 11, 'day': 14})"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MakeTimeFun { + static constexpr const char *Name = "make_time"; + static constexpr const char *Parameters = "hour,minute,seconds"; + static constexpr const char *Description = "The time for the given parts"; + static constexpr const char *Example = "make_time(13, 34, 27.123456)"; + + static ScalarFunction GetFunction(); +}; + +struct MakeTimestampFun { + static constexpr const char *Name = "make_timestamp"; + static constexpr const char *Parameters = "year,month,day,hour,minute,seconds"; + static constexpr const char *Description = "The timestamp for the given parts"; + static constexpr const char *Example = "make_timestamp(1992, 9, 20, 13, 34, 27.123456)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MakeTimestampNsFun { + static constexpr const char *Name = "make_timestamp_ns"; + static constexpr const char *Parameters = "nanos"; + static constexpr const char *Description = "The timestamp for the given nanoseconds since epoch"; + static constexpr const char *Example = "make_timestamp(1732117793000000000)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MicrosecondsFun { + static constexpr const char *Name = "microsecond"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the microsecond component from a date or timestamp"; + static constexpr const char *Example = "microsecond(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MillenniumFun { + static constexpr const char *Name = "millennium"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the millennium component from a date or timestamp"; + static constexpr const char *Example = "millennium(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MillisecondsFun { + static constexpr const char *Name = "millisecond"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the millisecond component from a date or timestamp"; + static constexpr const char *Example = "millisecond(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MinutesFun { + static constexpr const char *Name = "minute"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the minute component from a date or timestamp"; + static constexpr const char *Example = "minute(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MonthFun { + static constexpr const char *Name = "month"; + 
static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the month component from a date or timestamp"; + static constexpr const char *Example = "month(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct MonthNameFun { + static constexpr const char *Name = "monthname"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "The (English) name of the month"; + static constexpr const char *Example = "monthname(TIMESTAMP '1992-09-20')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct NanosecondsFun { + static constexpr const char *Name = "nanosecond"; + static constexpr const char *Parameters = "tsns"; + static constexpr const char *Description = "Extract the nanosecond component from a date or timestamp"; + static constexpr const char *Example = "nanosecond(timestamp_ns '2021-08-03 11:59:44.123456789') => 44123456789"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct NormalizedIntervalFun { + static constexpr const char *Name = "normalized_interval"; + static constexpr const char *Parameters = "interval"; + static constexpr const char *Description = "Normalizes an INTERVAL to an equivalent interval"; + static constexpr const char *Example = "normalized_interval(INTERVAL '30 days')"; + + static ScalarFunction GetFunction(); +}; + +struct QuarterFun { + static constexpr const char *Name = "quarter"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the quarter component from a date or timestamp"; + static constexpr const char *Example = "quarter(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct SecondsFun { + static constexpr const char *Name = "second"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the second component from a date or timestamp"; + static constexpr const char *Example = "second(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct TimeBucketFun { + static constexpr const char *Name = "time_bucket"; + static constexpr const char *Parameters = "bucket_width,timestamp,origin"; + static constexpr const char *Description = "Truncate TIMESTAMPTZ by the specified interval bucket_width. Buckets are aligned relative to origin TIMESTAMPTZ. 
The origin defaults to 2000-01-03 00:00:00+00 for buckets that do not include a month or year interval, and to 2000-01-01 00:00:00+00 for month and year buckets"; + static constexpr const char *Example = "time_bucket(INTERVAL '2 weeks', TIMESTAMP '1992-04-20 15:26:00-07', TIMESTAMP '1992-04-01 00:00:00-07')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct TimezoneFun { + static constexpr const char *Name = "timezone"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the timezone component from a date or timestamp"; + static constexpr const char *Example = "timezone(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct TimezoneHourFun { + static constexpr const char *Name = "timezone_hour"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the timezone_hour component from a date or timestamp"; + static constexpr const char *Example = "timezone_hour(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct TimezoneMinuteFun { + static constexpr const char *Name = "timezone_minute"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the timezone_minute component from a date or timestamp"; + static constexpr const char *Example = "timezone_minute(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct TimeTZSortKeyFun { + static constexpr const char *Name = "timetz_byte_comparable"; + static constexpr const char *Parameters = "time_tz"; + static constexpr const char *Description = "Converts a TIME WITH TIME ZONE to an integer sort key"; + static constexpr const char *Example = "timetz_byte_comparable('18:18:16.21-07:00'::TIME_TZ)"; + + static ScalarFunction GetFunction(); +}; + +struct ToCenturiesFun { + static constexpr const char *Name = "to_centuries"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a century interval"; + static constexpr const char *Example = "to_centuries(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToDaysFun { + static constexpr const char *Name = "to_days"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a day interval"; + static constexpr const char *Example = "to_days(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToDecadesFun { + static constexpr const char *Name = "to_decades"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a decade interval"; + static constexpr const char *Example = "to_decades(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToHoursFun { + static constexpr const char *Name = "to_hours"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct an hour interval"; + static constexpr const char *Example = "to_hours(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToMicrosecondsFun { + static constexpr const char *Name = "to_microseconds"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a microsecond interval"; + static constexpr const char *Example = "to_microseconds(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToMillenniaFun { + static constexpr const char *Name = "to_millennia";
static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a millennium interval"; + static constexpr const char *Example = "to_millennia(1)"; + + static ScalarFunction GetFunction(); +}; + +struct ToMillisecondsFun { + static constexpr const char *Name = "to_milliseconds"; + static constexpr const char *Parameters = "double"; + static constexpr const char *Description = "Construct a millisecond interval"; + static constexpr const char *Example = "to_milliseconds(5.5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToMinutesFun { + static constexpr const char *Name = "to_minutes"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a minute interval"; + static constexpr const char *Example = "to_minutes(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToMonthsFun { + static constexpr const char *Name = "to_months"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a month interval"; + static constexpr const char *Example = "to_months(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToQuartersFun { + static constexpr const char *Name = "to_quarters"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a quarter interval"; + static constexpr const char *Example = "to_quarters(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToSecondsFun { + static constexpr const char *Name = "to_seconds"; + static constexpr const char *Parameters = "double"; + static constexpr const char *Description = "Construct a second interval"; + static constexpr const char *Example = "to_seconds(5.5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToTimestampFun { + static constexpr const char *Name = "to_timestamp"; + static constexpr const char *Parameters = "sec"; + static constexpr const char *Description = "Converts seconds since epoch to a timestamp with time zone"; + static constexpr const char *Example = "to_timestamp(1284352323.5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToWeeksFun { + static constexpr const char *Name = "to_weeks"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a week interval"; + static constexpr const char *Example = "to_weeks(5)"; + + static ScalarFunction GetFunction(); +}; + +struct ToYearsFun { + static constexpr const char *Name = "to_years"; + static constexpr const char *Parameters = "integer"; + static constexpr const char *Description = "Construct a year interval"; + static constexpr const char *Example = "to_years(5)"; + + static ScalarFunction GetFunction(); +}; + +struct WeekFun { + static constexpr const char *Name = "week"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the week component from a date or timestamp"; + static constexpr const char *Example = "week(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct WeekDayFun { + static constexpr const char *Name = "weekday"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the weekday component from a date or timestamp"; + static constexpr const char *Example = "weekday(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct WeekOfYearFun { + static constexpr const char *Name 
= "weekofyear"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the weekofyear component from a date or timestamp"; + static constexpr const char *Example = "weekofyear(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct YearFun { + static constexpr const char *Name = "year"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the year component from a date or timestamp"; + static constexpr const char *Example = "year(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct YearWeekFun { + static constexpr const char *Name = "yearweek"; + static constexpr const char *Parameters = "ts"; + static constexpr const char *Description = "Extract the yearweek component from a date or timestamp"; + static constexpr const char *Example = "yearweek(timestamp '2021-08-03 11:59:44.123456')"; + + static ScalarFunctionSet GetFunctions(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/debug_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/debug_functions.hpp new file mode 100644 index 00000000..ce4debc6 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/debug_functions.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/debug_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct VectorTypeFun { + static constexpr const char *Name = "vector_type"; + static constexpr const char *Parameters = "col"; + static constexpr const char *Description = "Returns the VectorType of a given column"; + static constexpr const char *Example = "vector_type(col)"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/enum_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/enum_functions.hpp new file mode 100644 index 00000000..73791f8a --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/enum_functions.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/enum_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct EnumFirstFun { + static constexpr const char *Name = "enum_first"; + static constexpr const char *Parameters = "enum"; + static constexpr const char *Description = "Returns the first value of the input enum type"; + static constexpr const char *Example = "enum_first(NULL::mood)"; + + static ScalarFunction GetFunction(); +}; + +struct 
EnumLastFun { + static constexpr const char *Name = "enum_last"; + static constexpr const char *Parameters = "enum"; + static constexpr const char *Description = "Returns the last value of the input enum type"; + static constexpr const char *Example = "enum_last(NULL::mood)"; + + static ScalarFunction GetFunction(); +}; + +struct EnumCodeFun { + static constexpr const char *Name = "enum_code"; + static constexpr const char *Parameters = "enum"; + static constexpr const char *Description = "Returns the numeric value backing the given enum value"; + static constexpr const char *Example = "enum_code('happy'::mood)"; + + static ScalarFunction GetFunction(); +}; + +struct EnumRangeFun { + static constexpr const char *Name = "enum_range"; + static constexpr const char *Parameters = "enum"; + static constexpr const char *Description = "Returns all values of the input enum type as an array"; + static constexpr const char *Example = "enum_range(NULL::mood)"; + + static ScalarFunction GetFunction(); +}; + +struct EnumRangeBoundaryFun { + static constexpr const char *Name = "enum_range_boundary"; + static constexpr const char *Parameters = "start,end"; + static constexpr const char *Description = "Returns the range between the two given enum values as an array. The values must be of the same enum type. When the first parameter is NULL, the result starts with the first value of the enum type. When the second parameter is NULL, the result ends with the last value of the enum type"; + static constexpr const char *Example = "enum_range_boundary(NULL, 'happy'::mood)"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/generic_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/generic_functions.hpp new file mode 100644 index 00000000..d874e72a --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/generic_functions.hpp @@ -0,0 +1,171 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/generic_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct AliasFun { + static constexpr const char *Name = "alias"; + static constexpr const char *Parameters = "expr"; + static constexpr const char *Description = "Returns the name of a given expression"; + static constexpr const char *Example = "alias(42 + 1)"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentSettingFun { + static constexpr const char *Name = "current_setting"; + static constexpr const char *Parameters = "setting_name"; + static constexpr const char *Description = "Returns the current value of the configuration setting"; + static constexpr const char *Example = "current_setting('access_mode')"; + + static ScalarFunction GetFunction(); +}; + +struct HashFun { + static constexpr const char *Name = "hash"; + static constexpr const char *Parameters = "param"; + static constexpr const char *Description = "Returns an integer with the hash of the value. 
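Every header in this diff follows the same generated shape: a descriptor struct with static constexpr metadata (Name, Parameters, Description, Example) plus a GetFunction or GetFunctions factory. A standalone sketch of how a registry could lift that compile-time metadata into a runtime table; the FunctionDescription type and Describe helper are assumptions for illustration, not the actual function_list code:

```cpp
#include <iostream>
#include <vector>

struct FunctionDescription {
	const char *name;
	const char *parameters;
	const char *description;
	const char *example;
};

// Any struct following the generated shape can be described generically.
template <class FUN>
FunctionDescription Describe() {
	return {FUN::Name, FUN::Parameters, FUN::Description, FUN::Example};
}

// Local mirror of one generated descriptor, copied from the header above.
struct EnumFirstFun {
	static constexpr const char *Name = "enum_first";
	static constexpr const char *Parameters = "enum";
	static constexpr const char *Description = "Returns the first value of the input enum type";
	static constexpr const char *Example = "enum_first(NULL::mood)";
};

int main() {
	std::vector<FunctionDescription> table = {Describe<EnumFirstFun>()};
	for (const auto &f : table) {
		std::cout << f.name << "(" << f.parameters << "): " << f.description << "\n";
	}
}
```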
Note that this is not a cryptographic hash"; + static constexpr const char *Example = "hash('🦆')"; + + static ScalarFunction GetFunction(); +}; + +struct LeastFun { + static constexpr const char *Name = "least"; + static constexpr const char *Parameters = "arg1, arg2, ..."; + static constexpr const char *Description = "Returns the lowest value of the set of input parameters"; + static constexpr const char *Example = "least(42, 84)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct GreatestFun { + static constexpr const char *Name = "greatest"; + static constexpr const char *Parameters = "arg1, arg2, ..."; + static constexpr const char *Description = "Returns the highest value of the set of input parameters"; + static constexpr const char *Example = "greatest(42, 84)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct StatsFun { + static constexpr const char *Name = "stats"; + static constexpr const char *Parameters = "expression"; + static constexpr const char *Description = "Returns a string with statistics about the expression. Expression can be a column, constant, or SQL expression"; + static constexpr const char *Example = "stats(5)"; + + static ScalarFunction GetFunction(); +}; + +struct TypeOfFun { + static constexpr const char *Name = "typeof"; + static constexpr const char *Parameters = "expression"; + static constexpr const char *Description = "Returns the name of the data type of the result of the expression"; + static constexpr const char *Example = "typeof('abc')"; + + static ScalarFunction GetFunction(); +}; + +struct CanCastImplicitlyFun { + static constexpr const char *Name = "can_cast_implicitly"; + static constexpr const char *Parameters = "source_type,target_type"; + static constexpr const char *Description = "Whether or not we can implicitly cast from the source type to the target type"; + static constexpr const char *Example = "can_cast_implicitly(NULL::INTEGER, NULL::BIGINT)"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentQueryFun { + static constexpr const char *Name = "current_query"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the current query as a string"; + static constexpr const char *Example = "current_query()"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentSchemaFun { + static constexpr const char *Name = "current_schema"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the name of the currently active schema. Default is main"; + static constexpr const char *Example = "current_schema()"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentSchemasFun { + static constexpr const char *Name = "current_schemas"; + static constexpr const char *Parameters = "include_implicit"; + static constexpr const char *Description = "Returns a list of schemas. 
Pass a parameter of True to include implicit schemas"; + static constexpr const char *Example = "current_schemas(true)"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentDatabaseFun { + static constexpr const char *Name = "current_database"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the name of the currently active database"; + static constexpr const char *Example = "current_database()"; + + static ScalarFunction GetFunction(); +}; + +struct InSearchPathFun { + static constexpr const char *Name = "in_search_path"; + static constexpr const char *Parameters = "database_name,schema_name"; + static constexpr const char *Description = "Returns whether or not the database/schema are in the search path"; + static constexpr const char *Example = "in_search_path('memory', 'main')"; + + static ScalarFunction GetFunction(); +}; + +struct CurrentTransactionIdFun { + static constexpr const char *Name = "txid_current"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the current transaction’s ID (a BIGINT). It will assign a new one if the current transaction does not have one already"; + static constexpr const char *Example = "txid_current()"; + + static ScalarFunction GetFunction(); +}; + +struct VersionFun { + static constexpr const char *Name = "version"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the currently active version of DuckDB in this format: v0.3.2 "; + static constexpr const char *Example = "version()"; + + static ScalarFunction GetFunction(); +}; + +struct EquiWidthBinsFun { + static constexpr const char *Name = "equi_width_bins"; + static constexpr const char *Parameters = "min,max,bin_count,nice_rounding"; + static constexpr const char *Description = "Generates bin_count equi-width bins between the min and max. 
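A worked sketch of the equi-width binning arithmetic that equi_width_bins describes, under the assumption that the returned values are the upper boundaries of each bin; the nice_rounding behaviour is omitted:

```cpp
#include <iostream>
#include <vector>

// Illustrative only: bin_count equal-width bins between min and max.
std::vector<double> EquiWidthBins(double min, double max, int bin_count) {
	std::vector<double> boundaries;
	double width = (max - min) / bin_count;
	for (int i = 1; i <= bin_count; i++) {
		boundaries.push_back(min + i * width); // upper edge of bin i
	}
	return boundaries;
}

int main() {
	// equi_width_bins(0, 10, 2, ...) -> upper edges 5 and 10
	for (double b : EquiWidthBins(0, 10, 2)) {
		std::cout << b << "\n";
	}
}
```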
If enabled, nice_rounding makes the numbers more readable/less jagged"; + static constexpr const char *Example = "equi_width_bins(0, 10, 2, true)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct IsHistogramOtherBinFun { + static constexpr const char *Name = "is_histogram_other_bin"; + static constexpr const char *Parameters = "val"; + static constexpr const char *Description = "Whether or not the provided value is the histogram \"other\" bin (used for values not belonging to any provided bin)"; + static constexpr const char *Example = "is_histogram_other_bin(v)"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/list_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/list_functions.hpp new file mode 100644 index 00000000..2b9318b4 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/list_functions.hpp @@ -0,0 +1,390 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/list_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct ListFlattenFun { + static constexpr const char *Name = "flatten"; + static constexpr const char *Parameters = "nested_list"; + static constexpr const char *Description = "Flatten a nested list by one level"; + static constexpr const char *Example = "flatten([[1, 2, 3], [4, 5]])"; + + static ScalarFunction GetFunction(); +}; + +struct ListAggregateFun { + static constexpr const char *Name = "list_aggregate"; + static constexpr const char *Parameters = "list,name"; + static constexpr const char *Description = "Executes the aggregate function name on the elements of list"; + static constexpr const char *Example = "list_aggregate([1, 2, NULL], 'min')"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayAggregateFun { + using ALIAS = ListAggregateFun; + + static constexpr const char *Name = "array_aggregate"; +}; + +struct ListAggrFun { + using ALIAS = ListAggregateFun; + + static constexpr const char *Name = "list_aggr"; +}; + +struct ArrayAggrFun { + using ALIAS = ListAggregateFun; + + static constexpr const char *Name = "array_aggr"; +}; + +struct AggregateFun { + using ALIAS = ListAggregateFun; + + static constexpr const char *Name = "aggregate"; +}; + +struct ListDistinctFun { + static constexpr const char *Name = "list_distinct"; + static constexpr const char *Parameters = "list"; + static constexpr const char *Description = "Removes all duplicates and NULLs from a list. 
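The alias structs above (array_aggregate, list_aggr, and so on) carry only a Name plus an ALIAS typedef pointing at the implementing struct. A compile-time sketch of how such aliases could be resolved; the ResolveAlias helper is an assumption for illustration, not DuckDB's actual registration mechanism:

```cpp
#include <iostream>
#include <type_traits>

// Primary template: a struct without an ALIAS member is its own implementation.
template <class FUN, class = void>
struct ResolveAlias {
	using type = FUN;
};
// Specialization: follow the ALIAS typedef to the implementing struct.
template <class FUN>
struct ResolveAlias<FUN, std::void_t<typename FUN::ALIAS>> {
	using type = typename FUN::ALIAS;
};

struct ListAggregateFun {
	static constexpr const char *Name = "list_aggregate";
	static const char *Impl() { return "list_aggregate implementation"; }
};
struct ArrayAggregateFun {
	using ALIAS = ListAggregateFun;
	static constexpr const char *Name = "array_aggregate";
};

int main() {
	// Registered under the alias name, backed by the target's implementation.
	std::cout << ArrayAggregateFun::Name << " -> "
	          << ResolveAlias<ArrayAggregateFun>::type::Impl() << "\n";
}
```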
Does not preserve the original order"; + static constexpr const char *Example = "list_distinct([1, 1, NULL, -3, 1, 5])"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayDistinctFun { + using ALIAS = ListDistinctFun; + + static constexpr const char *Name = "array_distinct"; +}; + +struct ListUniqueFun { + static constexpr const char *Name = "list_unique"; + static constexpr const char *Parameters = "list"; + static constexpr const char *Description = "Counts the unique elements of a list"; + static constexpr const char *Example = "list_unique([1, 1, NULL, -3, 1, 5])"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayUniqueFun { + using ALIAS = ListUniqueFun; + + static constexpr const char *Name = "array_unique"; +}; + +struct ListValueFun { + static constexpr const char *Name = "list_value"; + static constexpr const char *Parameters = "any,..."; + static constexpr const char *Description = "Create a LIST containing the argument values"; + static constexpr const char *Example = "list_value(4, 5, 6)"; + + static ScalarFunction GetFunction(); +}; + +struct ListPackFun { + using ALIAS = ListValueFun; + + static constexpr const char *Name = "list_pack"; +}; + +struct ListSliceFun { + static constexpr const char *Name = "list_slice"; + static constexpr const char *Parameters = "list,begin,end\1list,begin,end,step"; + static constexpr const char *Description = "Extract a sublist using slice conventions. Negative values are accepted.\1list_slice with added step feature."; + static constexpr const char *Example = "list_slice([4, 5, 6], 2, 3)\2array_slice('DuckDB', 3, 4)\2array_slice('DuckDB', 3, NULL)\2array_slice('DuckDB', 0, -3)\1list_slice([4, 5, 6], 1, 3, 2)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArraySliceFun { + using ALIAS = ListSliceFun; + + static constexpr const char *Name = "array_slice"; +}; + +struct ListSortFun { + static constexpr const char *Name = "list_sort"; + static constexpr const char *Parameters = "list"; + static constexpr const char *Description = "Sorts the elements of the list"; + static constexpr const char *Example = "list_sort([3, 6, 1, 2])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArraySortFun { + using ALIAS = ListSortFun; + + static constexpr const char *Name = "array_sort"; +}; + +struct ListGradeUpFun { + static constexpr const char *Name = "list_grade_up"; + static constexpr const char *Parameters = "list"; + static constexpr const char *Description = "Returns the indexes that would sort the list, rather than the sorted values."; + static constexpr const char *Example = "list_grade_up([3, 6, 1, 2])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArrayGradeUpFun { + using ALIAS = ListGradeUpFun; + + static constexpr const char *Name = "array_grade_up"; +}; + +struct GradeUpFun { + using ALIAS = ListGradeUpFun; + + static constexpr const char *Name = "grade_up"; +}; + +struct ListReverseSortFun { + static constexpr const char *Name = "list_reverse_sort"; + static constexpr const char *Parameters = "list"; + static constexpr const char *Description = "Sorts the elements of the list in reverse order"; + static constexpr const char *Example = "list_reverse_sort([3, 6, 1, 2])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ArrayReverseSortFun { + using ALIAS = ListReverseSortFun; + + static constexpr const char *Name = "array_reverse_sort"; +}; + +struct ListTransformFun { + static constexpr const char *Name = "list_transform"; + static constexpr const char *Parameters = "list,lambda"; + static constexpr const 
char *Description = "Returns a list that is the result of applying the lambda function to each element of the input list. See the Lambda Functions section for more details"; + static constexpr const char *Example = "list_transform([1, 2, 3], x -> x + 1)"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayTransformFun { + using ALIAS = ListTransformFun; + + static constexpr const char *Name = "array_transform"; +}; + +struct ListApplyFun { + using ALIAS = ListTransformFun; + + static constexpr const char *Name = "list_apply"; +}; + +struct ArrayApplyFun { + using ALIAS = ListTransformFun; + + static constexpr const char *Name = "array_apply"; +}; + +struct ApplyFun { + using ALIAS = ListTransformFun; + + static constexpr const char *Name = "apply"; +}; + +struct ListFilterFun { + static constexpr const char *Name = "list_filter"; + static constexpr const char *Parameters = "list,lambda"; + static constexpr const char *Description = "Constructs a list from those elements of the input list for which the lambda function returns true"; + static constexpr const char *Example = "list_filter([3, 4, 5], x -> x > 4)"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayFilterFun { + using ALIAS = ListFilterFun; + + static constexpr const char *Name = "array_filter"; +}; + +struct FilterFun { + using ALIAS = ListFilterFun; + + static constexpr const char *Name = "filter"; +}; + +struct ListReduceFun { + static constexpr const char *Name = "list_reduce"; + static constexpr const char *Parameters = "list,lambda"; + static constexpr const char *Description = "Returns a single value that is the result of applying the lambda function to each element of the input list, starting with the first element and then repeatedly applying the lambda function to the result of the previous application and the next element of the list."; + static constexpr const char *Example = "list_reduce([1, 2, 3], (x, y) -> x + y)"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayReduceFun { + using ALIAS = ListReduceFun; + + static constexpr const char *Name = "array_reduce"; +}; + +struct ReduceFun { + using ALIAS = ListReduceFun; + + static constexpr const char *Name = "reduce"; +}; + +struct GenerateSeriesFun { + static constexpr const char *Name = "generate_series"; + static constexpr const char *Parameters = "start,stop,step"; + static constexpr const char *Description = "Create a list of values between start and stop - the stop parameter is inclusive"; + static constexpr const char *Example = "generate_series(2, 5, 3)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListRangeFun { + static constexpr const char *Name = "range"; + static constexpr const char *Parameters = "start,stop,step"; + static constexpr const char *Description = "Create a list of values between start and stop - the stop parameter is exclusive"; + static constexpr const char *Example = "range(2, 5, 3)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListCosineDistanceFun { + static constexpr const char *Name = "list_cosine_distance"; + static constexpr const char *Parameters = "list1,list2"; + static constexpr const char *Description = "Compute the cosine distance between two lists"; + static constexpr const char *Example = "list_cosine_distance([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListCosineDistanceFunAlias { + using ALIAS = ListCosineDistanceFun; + + static constexpr const char *Name = "<=>"; +}; + +struct ListCosineSimilarityFun { + static constexpr 
const char *Name = "list_cosine_similarity"; + static constexpr const char *Parameters = "list1,list2"; + static constexpr const char *Description = "Compute the cosine similarity between two lists"; + static constexpr const char *Example = "list_cosine_similarity([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListDistanceFun { + static constexpr const char *Name = "list_distance"; + static constexpr const char *Parameters = "list1,list2"; + static constexpr const char *Description = "Compute the distance between two lists"; + static constexpr const char *Example = "list_distance([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListDistanceFunAlias { + using ALIAS = ListDistanceFun; + + static constexpr const char *Name = "<->"; +}; + +struct ListInnerProductFun { + static constexpr const char *Name = "list_inner_product"; + static constexpr const char *Parameters = "list1,list2"; + static constexpr const char *Description = "Compute the inner product between two lists"; + static constexpr const char *Example = "list_inner_product([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListDotProductFun { + using ALIAS = ListInnerProductFun; + + static constexpr const char *Name = "list_dot_product"; +}; + +struct ListNegativeInnerProductFun { + static constexpr const char *Name = "list_negative_inner_product"; + static constexpr const char *Parameters = "list1,list2"; + static constexpr const char *Description = "Compute the negative inner product between two lists"; + static constexpr const char *Example = "list_negative_inner_product([1, 2, 3], [1, 2, 3])"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ListNegativeDotProductFun { + using ALIAS = ListNegativeInnerProductFun; + + static constexpr const char *Name = "list_negative_dot_product"; +}; + +struct UnpivotListFun { + static constexpr const char *Name = "unpivot_list"; + static constexpr const char *Parameters = "any,..."; + static constexpr const char *Description = "Identical to list_value, but generated as part of unpivot for better error messages"; + static constexpr const char *Example = "unpivot_list(4, 5, 6)"; + + static ScalarFunction GetFunction(); +}; + +struct ListHasAnyFun { + static constexpr const char *Name = "list_has_any"; + static constexpr const char *Parameters = "l1, l2"; + static constexpr const char *Description = "Returns true if the lists have any element in common. NULLs are ignored."; + static constexpr const char *Example = "list_has_any([1, 2, 3], [2, 3, 4])"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayHasAnyFun { + using ALIAS = ListHasAnyFun; + + static constexpr const char *Name = "array_has_any"; +}; + +struct ListHasAnyFunAlias { + using ALIAS = ListHasAnyFun; + + static constexpr const char *Name = "&&"; +}; + +struct ListHasAllFun { + static constexpr const char *Name = "list_has_all"; + static constexpr const char *Parameters = "l1, l2"; + static constexpr const char *Description = "Returns true if all elements of l2 are in l1. 
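The list_cosine_similarity and list_cosine_distance functions above are related: cosine distance is conventionally 1 minus cosine similarity. A self-contained sketch of that arithmetic (illustrative only, not DuckDB's vectorized implementation):

```cpp
#include <cmath>
#include <iostream>
#include <vector>

double CosineSimilarity(const std::vector<double> &a, const std::vector<double> &b) {
	double dot = 0, norm_a = 0, norm_b = 0;
	for (size_t i = 0; i < a.size(); i++) { // assumes equal lengths
		dot += a[i] * b[i];
		norm_a += a[i] * a[i];
		norm_b += b[i] * b[i];
	}
	return dot / (std::sqrt(norm_a) * std::sqrt(norm_b));
}

int main() {
	std::vector<double> v = {1, 2, 3};
	std::cout << "similarity: " << CosineSimilarity(v, v) << "\n";       // 1
	std::cout << "distance:   " << 1.0 - CosineSimilarity(v, v) << "\n"; // 0
}
```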
NULLs are ignored."; + static constexpr const char *Example = "list_has_all([1, 2, 3], [2, 3])"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayHasAllFun { + using ALIAS = ListHasAllFun; + + static constexpr const char *Name = "array_has_all"; +}; + +struct ListHasAllFunAlias { + using ALIAS = ListHasAllFun; + + static constexpr const char *Name = "@>"; +}; + +struct ListHasAllFunAlias2 { + using ALIAS = ListHasAllFun; + + static constexpr const char *Name = "<@"; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/map_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/map_functions.hpp new file mode 100644 index 00000000..0998a315 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/map_functions.hpp @@ -0,0 +1,96 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/map_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct CardinalityFun { + static constexpr const char *Name = "cardinality"; + static constexpr const char *Parameters = "map"; + static constexpr const char *Description = "Returns the size of the map (or the number of entries in the map)"; + static constexpr const char *Example = "cardinality( map([4, 2], ['a', 'b']) );"; + + static ScalarFunction GetFunction(); +}; + +struct MapFun { + static constexpr const char *Name = "map"; + static constexpr const char *Parameters = "keys,values"; + static constexpr const char *Description = "Creates a map from a set of keys and values"; + static constexpr const char *Example = "map(['key1', 'key2'], ['val1', 'val2'])"; + + static ScalarFunction GetFunction(); +}; + +struct MapEntriesFun { + static constexpr const char *Name = "map_entries"; + static constexpr const char *Parameters = "map"; + static constexpr const char *Description = "Returns the map entries as a list of keys/values"; + static constexpr const char *Example = "map_entries(map(['key'], ['val']))"; + + static ScalarFunction GetFunction(); +}; + +struct MapExtractFun { + static constexpr const char *Name = "map_extract"; + static constexpr const char *Parameters = "map,key"; + static constexpr const char *Description = "Returns a list containing the value for a given key or an empty list if the key is not contained in the map. 
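A sketch of the map_extract contract described here, where the result is a list that is empty when the key is absent and holds the matching value when present; illustrative semantics only, using standard containers:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

std::vector<std::string> MapExtract(const std::map<std::string, std::string> &m,
                                    const std::string &key) {
	auto it = m.find(key);
	if (it == m.end()) {
		return {}; // key not contained in the map: empty list
	}
	return {it->second};
}

int main() {
	std::map<std::string, std::string> m = {{"key", "val"}};
	std::cout << MapExtract(m, "key").size() << "\n";   // 1 ("val")
	std::cout << MapExtract(m, "other").size() << "\n"; // 0 (empty list)
}
```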
The type of the key provided in the second parameter must match the type of the map’s keys, otherwise an error is thrown"; + static constexpr const char *Example = "map_extract(map(['key'], ['val']), 'key')"; + + static ScalarFunction GetFunction(); +}; + +struct ElementAtFun { + using ALIAS = MapExtractFun; + + static constexpr const char *Name = "element_at"; +}; + +struct MapFromEntriesFun { + static constexpr const char *Name = "map_from_entries"; + static constexpr const char *Parameters = "map"; + static constexpr const char *Description = "Returns a map created from the entries of the array"; + static constexpr const char *Example = "map_from_entries([{k: 5, v: 'val1'}, {k: 3, v: 'val2'}]);"; + + static ScalarFunction GetFunction(); +}; + +struct MapConcatFun { + static constexpr const char *Name = "map_concat"; + static constexpr const char *Parameters = "any,..."; + static constexpr const char *Description = "Returns a map created from merging the input maps, on key collision the value is taken from the last map with that key"; + static constexpr const char *Example = "map_concat(map([1,2], ['a', 'b']), map([2,3], ['c', 'd']));"; + + static ScalarFunction GetFunction(); +}; + +struct MapKeysFun { + static constexpr const char *Name = "map_keys"; + static constexpr const char *Parameters = "map"; + static constexpr const char *Description = "Returns the keys of a map as a list"; + static constexpr const char *Example = "map_keys(map(['key'], ['val']))"; + + static ScalarFunction GetFunction(); +}; + +struct MapValuesFun { + static constexpr const char *Name = "map_values"; + static constexpr const char *Parameters = "map"; + static constexpr const char *Description = "Returns the values of a map as a list"; + static constexpr const char *Example = "map_values(map(['key'], ['val']))"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/math_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/math_functions.hpp new file mode 100644 index 00000000..7b8e2bef --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/math_functions.hpp @@ -0,0 +1,453 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/math_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct AbsOperatorFun { + static constexpr const char *Name = "@"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Absolute value"; + static constexpr const char *Example = "abs(-17.4)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct AbsFun { + using ALIAS = AbsOperatorFun; + + static constexpr const char *Name = "abs"; +}; + +struct PowOperatorFun { + static constexpr const char *Name = "**"; + static constexpr const char *Parameters = "x,y"; + static constexpr const char *Description = "Computes x to the power of y"; + static constexpr const char *Example = "pow(2, 3)"; + + static ScalarFunction GetFunction(); +}; + +struct PowFun { + using ALIAS = PowOperatorFun; + + static constexpr const char 
*Name = "pow"; +}; + +struct PowerFun { + using ALIAS = PowOperatorFun; + + static constexpr const char *Name = "power"; +}; + +struct PowOperatorFunAlias { + using ALIAS = PowOperatorFun; + + static constexpr const char *Name = "^"; +}; + +struct FactorialOperatorFun { + static constexpr const char *Name = "!__postfix"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Factorial of x. Computes the product of the current integer and all integers below it"; + static constexpr const char *Example = "4!"; + + static ScalarFunction GetFunction(); +}; + +struct FactorialFun { + using ALIAS = FactorialOperatorFun; + + static constexpr const char *Name = "factorial"; +}; + +struct AcosFun { + static constexpr const char *Name = "acos"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the arccosine of x"; + static constexpr const char *Example = "acos(0.5)"; + + static ScalarFunction GetFunction(); +}; + +struct AsinFun { + static constexpr const char *Name = "asin"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the arcsine of x"; + static constexpr const char *Example = "asin(0.5)"; + + static ScalarFunction GetFunction(); +}; + +struct AtanFun { + static constexpr const char *Name = "atan"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the arctangent of x"; + static constexpr const char *Example = "atan(0.5)"; + + static ScalarFunction GetFunction(); +}; + +struct Atan2Fun { + static constexpr const char *Name = "atan2"; + static constexpr const char *Parameters = "y,x"; + static constexpr const char *Description = "Computes the arctangent (y, x)"; + static constexpr const char *Example = "atan2(1.0, 0.0)"; + + static ScalarFunction GetFunction(); +}; + +struct BitCountFun { + static constexpr const char *Name = "bit_count"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the number of bits that are set"; + static constexpr const char *Example = "bit_count(31)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct CbrtFun { + static constexpr const char *Name = "cbrt"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the cube root of x"; + static constexpr const char *Example = "cbrt(8)"; + + static ScalarFunction GetFunction(); +}; + +struct CeilFun { + static constexpr const char *Name = "ceil"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Rounds the number up"; + static constexpr const char *Example = "ceil(17.4)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct CeilingFun { + using ALIAS = CeilFun; + + static constexpr const char *Name = "ceiling"; +}; + +struct CosFun { + static constexpr const char *Name = "cos"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the cos of x"; + static constexpr const char *Example = "cos(90)"; + + static ScalarFunction GetFunction(); +}; + +struct CotFun { + static constexpr const char *Name = "cot"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the cotangent of x"; + static constexpr const char *Example = "cot(0.5)"; + + static ScalarFunction GetFunction(); +}; + +struct DegreesFun { + static constexpr const char *Name = "degrees"; + static constexpr const 
char *Parameters = "x"; + static constexpr const char *Description = "Converts radians to degrees"; + static constexpr const char *Example = "degrees(pi())"; + + static ScalarFunction GetFunction(); +}; + +struct EvenFun { + static constexpr const char *Name = "even"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Rounds x to the next even number by rounding away from zero"; + static constexpr const char *Example = "even(2.9)"; + + static ScalarFunction GetFunction(); +}; + +struct ExpFun { + static constexpr const char *Name = "exp"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes e to the power of x"; + static constexpr const char *Example = "exp(1)"; + + static ScalarFunction GetFunction(); +}; + +struct FloorFun { + static constexpr const char *Name = "floor"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Rounds the number down"; + static constexpr const char *Example = "floor(17.4)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct IsFiniteFun { + static constexpr const char *Name = "isfinite"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns true if the floating point value is finite, false otherwise"; + static constexpr const char *Example = "isfinite(5.5)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct IsInfiniteFun { + static constexpr const char *Name = "isinf"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns true if the floating point value is infinite, false otherwise"; + static constexpr const char *Example = "isinf('Infinity'::float)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct IsNanFun { + static constexpr const char *Name = "isnan"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns true if the floating point value is not a number, false otherwise"; + static constexpr const char *Example = "isnan('NaN'::FLOAT)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct GammaFun { + static constexpr const char *Name = "gamma"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Interpolation of (x-1) factorial (so decimal inputs are allowed)"; + static constexpr const char *Example = "gamma(5.5)"; + + static ScalarFunction GetFunction(); +}; + +struct GreatestCommonDivisorFun { + static constexpr const char *Name = "greatest_common_divisor"; + static constexpr const char *Parameters = "x,y"; + static constexpr const char *Description = "Computes the greatest common divisor of x and y"; + static constexpr const char *Example = "greatest_common_divisor(42, 57)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct GcdFun { + using ALIAS = GreatestCommonDivisorFun; + + static constexpr const char *Name = "gcd"; +}; + +struct LeastCommonMultipleFun { + static constexpr const char *Name = "least_common_multiple"; + static constexpr const char *Parameters = "x,y"; + static constexpr const char *Description = "Computes the least common multiple of x and y"; + static constexpr const char *Example = "least_common_multiple(42, 57)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct LcmFun { + using ALIAS = LeastCommonMultipleFun; + + static constexpr const char *Name = "lcm"; +}; + +struct LogGammaFun { + static constexpr const char *Name = "lgamma"; + static constexpr const char 
*Parameters = "x"; + static constexpr const char *Description = "Computes the log of the gamma function"; + static constexpr const char *Example = "lgamma(2)"; + + static ScalarFunction GetFunction(); +}; + +struct LnFun { + static constexpr const char *Name = "ln"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the natural logarithm of x"; + static constexpr const char *Example = "ln(2)"; + + static ScalarFunction GetFunction(); +}; + +struct Log2Fun { + static constexpr const char *Name = "log2"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the 2-log of x"; + static constexpr const char *Example = "log2(8)"; + + static ScalarFunction GetFunction(); +}; + +struct Log10Fun { + static constexpr const char *Name = "log10"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the 10-log of x"; + static constexpr const char *Example = "log10(1000)"; + + static ScalarFunction GetFunction(); +}; + +struct LogFun { + static constexpr const char *Name = "log"; + static constexpr const char *Parameters = "b, x"; + static constexpr const char *Description = "Computes the logarithm of x to base b. b may be omitted, in which case it defaults to 10"; + static constexpr const char *Example = "log(2, 64)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct NextAfterFun { + static constexpr const char *Name = "nextafter"; + static constexpr const char *Parameters = "x, y"; + static constexpr const char *Description = "Returns the next floating point value after x in the direction of y"; + static constexpr const char *Example = "nextafter(1::float, 2::float)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct PiFun { + static constexpr const char *Name = "pi"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns the value of pi"; + static constexpr const char *Example = "pi()"; + + static ScalarFunction GetFunction(); +}; + +struct RadiansFun { + static constexpr const char *Name = "radians"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Converts degrees to radians"; + static constexpr const char *Example = "radians(90)"; + + static ScalarFunction GetFunction(); +}; + +struct RoundFun { + static constexpr const char *Name = "round"; + static constexpr const char *Parameters = "x,precision"; + static constexpr const char *Description = "Rounds x to precision decimal places"; + static constexpr const char *Example = "round(42.4332, 2)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct SignFun { + static constexpr const char *Name = "sign"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the sign of x as -1, 0 or 1"; + static constexpr const char *Example = "sign(-349)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct SignBitFun { + static constexpr const char *Name = "signbit"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns whether the signbit is set or not"; + static constexpr const char *Example = "signbit(-0.0)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct SinFun { + static constexpr const char *Name = "sin"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the sin of x"; + static constexpr const char *Example = "sin(90)"; + + static 
ScalarFunction GetFunction(); +}; + +struct SqrtFun { + static constexpr const char *Name = "sqrt"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Returns the square root of x"; + static constexpr const char *Example = "sqrt(4)"; + + static ScalarFunction GetFunction(); +}; + +struct TanFun { + static constexpr const char *Name = "tan"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the tan of x"; + static constexpr const char *Example = "tan(90)"; + + static ScalarFunction GetFunction(); +}; + +struct TruncFun { + static constexpr const char *Name = "trunc"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Truncates the number"; + static constexpr const char *Example = "trunc(17.4)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct CoshFun { + static constexpr const char *Name = "cosh"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the hyperbolic cos of x"; + static constexpr const char *Example = "cosh(1)"; + + static ScalarFunction GetFunction(); +}; + +struct SinhFun { + static constexpr const char *Name = "sinh"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the hyperbolic sin of x"; + static constexpr const char *Example = "sinh(1)"; + + static ScalarFunction GetFunction(); +}; + +struct TanhFun { + static constexpr const char *Name = "tanh"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the hyperbolic tan of x"; + static constexpr const char *Example = "tanh(1)"; + + static ScalarFunction GetFunction(); +}; + +struct AcoshFun { + static constexpr const char *Name = "acosh"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the inverse hyperbolic cos of x"; + static constexpr const char *Example = "acosh(2.3)"; + + static ScalarFunction GetFunction(); +}; + +struct AsinhFun { + static constexpr const char *Name = "asinh"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the inverse hyperbolic sin of x"; + static constexpr const char *Example = "asinh(0.5)"; + + static ScalarFunction GetFunction(); +}; + +struct AtanhFun { + static constexpr const char *Name = "atanh"; + static constexpr const char *Parameters = "x"; + static constexpr const char *Description = "Computes the inverse hyperbolic tan of x"; + static constexpr const char *Example = "atanh(0.5)"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/operators_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/operators_functions.hpp new file mode 100644 index 00000000..3bbfc565 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/operators_functions.hpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/operators_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + 
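The gamma and lgamma descriptions in the math header above can be checked directly against the C++ standard library: std::tgamma interpolates (x-1)! and std::lgamma is its natural logarithm:

```cpp
#include <cmath>
#include <iostream>

int main() {
	std::cout << std::tgamma(5.0) << "\n"; // 24, i.e. 4!
	std::cout << std::lgamma(2.0) << "\n"; // log(gamma(2)) = log(1) = 0
	std::cout << std::tgamma(5.5) << "\n"; // decimal inputs are allowed
}
```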
+#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct BitwiseAndFun { + static constexpr const char *Name = "&"; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Bitwise AND"; + static constexpr const char *Example = "91 & 15"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct BitwiseOrFun { + static constexpr const char *Name = "|"; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Bitwise OR"; + static constexpr const char *Example = "32 | 3"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct BitwiseNotFun { + static constexpr const char *Name = "~"; + static constexpr const char *Parameters = "input"; + static constexpr const char *Description = "Bitwise NOT"; + static constexpr const char *Example = "~15"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct LeftShiftFun { + static constexpr const char *Name = "<<"; + static constexpr const char *Parameters = "input"; + static constexpr const char *Description = "Bitwise shift left"; + static constexpr const char *Example = "1 << 4"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct RightShiftFun { + static constexpr const char *Name = ">>"; + static constexpr const char *Parameters = "input"; + static constexpr const char *Description = "Bitwise shift right"; + static constexpr const char *Example = "8 >> 2"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct BitwiseXorFun { + static constexpr const char *Name = "xor"; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Bitwise XOR"; + static constexpr const char *Example = "xor(17, 5)"; + + static ScalarFunctionSet GetFunctions(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/random_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/random_functions.hpp new file mode 100644 index 00000000..1002f0e4 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/random_functions.hpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/random_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct RandomFun { + static constexpr const char *Name = "random"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns a random number between 0 and 1"; + static constexpr const char *Example = "random()"; + + static ScalarFunction GetFunction(); +}; + +struct SetseedFun { + static constexpr const char *Name = "setseed"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Sets the seed to be used for the random function"; + static constexpr const char *Example = "setseed(0.42)"; + + static ScalarFunction GetFunction(); +}; + +struct UUIDFun { + static constexpr const char *Name = "uuid"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = "Returns a random UUID similar to 
this: eeccb8c5-9943-b2bb-bb5e-222f4e14b687"; + static constexpr const char *Example = "uuid()"; + + static ScalarFunction GetFunction(); +}; + +struct GenRandomUuidFun { + using ALIAS = UUIDFun; + + static constexpr const char *Name = "gen_random_uuid"; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/secret_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/secret_functions.hpp new file mode 100644 index 00000000..17e5614e --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/secret_functions.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/core_functions/scalar/secret_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct WhichSecretFun { + static constexpr const char *Name = "which_secret"; + static constexpr const char *Parameters = "path,type"; + static constexpr const char *Description = "Print out the name of the secret that will be used for reading a path"; + static constexpr const char *Example = "which_secret('s3://some/authenticated/path.csv', 's3')"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/string_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/string_functions.hpp new file mode 100644 index 00000000..6a6db36d --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/string_functions.hpp @@ -0,0 +1,444 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/string_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct StartsWithOperatorFun { + static constexpr const char *Name = "^@"; + static constexpr const char *Parameters = "string,search_string"; + static constexpr const char *Description = "Returns true if string begins with search_string"; + static constexpr const char *Example = "starts_with('abc','a')"; + + static ScalarFunction GetFunction(); +}; + +struct StartsWithFun { + using ALIAS = StartsWithOperatorFun; + + static constexpr const char *Name = "starts_with"; +}; + +struct ASCIIFun { + static constexpr const char *Name = "ascii"; + static constexpr const char *Parameters = "string"; + static constexpr const char *Description = "Returns an integer that represents the Unicode code point of the first character of the string"; + static constexpr const char *Example = "ascii('Ω')"; + + static ScalarFunction GetFunction(); +}; + +struct BarFun { + static constexpr const char *Name = "bar"; + static constexpr const char *Parameters = "x,min,max,width"; + static constexpr const char *Description = "Draws a band 
whose width is proportional to (x - min) and equal to width characters when x = max. width defaults to 80"; + static constexpr const char *Example = "bar(5, 0, 20, 10)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct BinFun { + static constexpr const char *Name = "bin"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Converts the value to binary representation"; + static constexpr const char *Example = "bin(42)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ToBinaryFun { + using ALIAS = BinFun; + + static constexpr const char *Name = "to_binary"; +}; + +struct ChrFun { + static constexpr const char *Name = "chr"; + static constexpr const char *Parameters = "code_point"; + static constexpr const char *Description = "Returns the character corresponding to the ASCII code value or Unicode code point"; + static constexpr const char *Example = "chr(65)"; + + static ScalarFunction GetFunction(); +}; + +struct DamerauLevenshteinFun { + static constexpr const char *Name = "damerau_levenshtein"; + static constexpr const char *Parameters = "str1,str2"; + static constexpr const char *Description = "Extension of Levenshtein distance to also include transposition of adjacent characters as an allowed edit operation. In other words, the minimum number of edit operations (insertions, deletions, substitutions or transpositions) required to change one string to another. Different case is considered different"; + static constexpr const char *Example = "damerau_levenshtein('hello', 'world')"; + + static ScalarFunction GetFunction(); +}; + +struct FormatFun { + static constexpr const char *Name = "format"; + static constexpr const char *Parameters = "format,parameters..."; + static constexpr const char *Description = "Formats a string using fmt syntax"; + static constexpr const char *Example = "format('Benchmark \"{}\" took {} seconds', 'CSV', 42)"; + + static ScalarFunction GetFunction(); +}; + +struct FormatBytesFun { + static constexpr const char *Name = "format_bytes"; + static constexpr const char *Parameters = "bytes"; + static constexpr const char *Description = "Converts bytes to a human-readable presentation (e.g. 16000 -> 15.6 KiB)"; + static constexpr const char *Example = "format_bytes(1000 * 16)"; + + static ScalarFunction GetFunction(); +}; + +struct FormatreadablesizeFun { + using ALIAS = FormatBytesFun; + + static constexpr const char *Name = "formatReadableSize"; +}; + +struct FormatreadabledecimalsizeFun { + static constexpr const char *Name = "formatReadableDecimalSize"; + static constexpr const char *Parameters = "bytes"; + static constexpr const char *Description = "Converts bytes to a human-readable presentation (e.g. 16000 -> 16.0 KB)"; + static constexpr const char *Example = "formatReadableDecimalSize(1000 * 16)"; + + static ScalarFunction GetFunction(); +}; + +struct HammingFun { + static constexpr const char *Name = "hamming"; + static constexpr const char *Parameters = "str1,str2"; + static constexpr const char *Description = "The number of positions with different characters for 2 strings of equal length. 
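A worked sketch of the Hamming distance just described: count the positions at which two equal-length strings differ, case-sensitively (illustrative only):

```cpp
#include <iostream>
#include <stdexcept>
#include <string>

int Hamming(const std::string &a, const std::string &b) {
	if (a.size() != b.size()) {
		throw std::invalid_argument("hamming requires strings of equal length");
	}
	int mismatches = 0;
	for (size_t i = 0; i < a.size(); i++) {
		mismatches += (a[i] != b[i]) ? 1 : 0;
	}
	return mismatches;
}

int main() {
	std::cout << Hamming("duck", "luck") << "\n"; // 1
}
```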
Different case is considered different"; + static constexpr const char *Example = "hamming('duck','luck')"; + + static ScalarFunction GetFunction(); +}; + +struct MismatchesFun { + using ALIAS = HammingFun; + + static constexpr const char *Name = "mismatches"; +}; + +struct HexFun { + static constexpr const char *Name = "hex"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Converts the value to hexadecimal representation"; + static constexpr const char *Example = "hex(42)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ToHexFun { + using ALIAS = HexFun; + + static constexpr const char *Name = "to_hex"; +}; + +struct InstrFun { + static constexpr const char *Name = "instr"; + static constexpr const char *Parameters = "haystack,needle"; + static constexpr const char *Description = "Returns location of first occurrence of needle in haystack, counting from 1. Returns 0 if no match found"; + static constexpr const char *Example = "instr('test test','es')"; + + static ScalarFunction GetFunction(); +}; + +struct StrposFun { + using ALIAS = InstrFun; + + static constexpr const char *Name = "strpos"; +}; + +struct PositionFun { + using ALIAS = InstrFun; + + static constexpr const char *Name = "position"; +}; + +struct JaccardFun { + static constexpr const char *Name = "jaccard"; + static constexpr const char *Parameters = "str1,str2"; + static constexpr const char *Description = "The Jaccard similarity between two strings. Different case is considered different. Returns a number between 0 and 1"; + static constexpr const char *Example = "jaccard('duck','luck')"; + + static ScalarFunction GetFunction(); +}; + +struct JaroSimilarityFun { + static constexpr const char *Name = "jaro_similarity"; + static constexpr const char *Parameters = "str1,str2,score_cutoff"; + static constexpr const char *Description = "The Jaro similarity between two strings. Different case is considered different. Returns a number between 0 and 1"; + static constexpr const char *Example = "jaro_similarity('duck', 'duckdb', 0.5)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct JaroWinklerSimilarityFun { + static constexpr const char *Name = "jaro_winkler_similarity"; + static constexpr const char *Parameters = "str1,str2,score_cutoff"; + static constexpr const char *Description = "The Jaro-Winkler similarity between two strings. Different case is considered different. 
Returns a number between 0 and 1"; + static constexpr const char *Example = "jaro_winkler_similarity('duck', 'duckdb', 0.5)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct LeftFun { + static constexpr const char *Name = "left"; + static constexpr const char *Parameters = "string,count"; + static constexpr const char *Description = "Extract the left-most count characters"; + static constexpr const char *Example = "left('Hello🦆', 2)"; + + static ScalarFunction GetFunction(); +}; + +struct LeftGraphemeFun { + static constexpr const char *Name = "left_grapheme"; + static constexpr const char *Parameters = "string,count"; + static constexpr const char *Description = "Extract the left-most count grapheme clusters"; + static constexpr const char *Example = "left_grapheme('🤦🏼‍♂️🤦🏽‍♀️', 1)"; + + static ScalarFunction GetFunction(); +}; + +struct LevenshteinFun { + static constexpr const char *Name = "levenshtein"; + static constexpr const char *Parameters = "str1,str2"; + static constexpr const char *Description = "The minimum number of single-character edits (insertions, deletions or substitutions) required to change one string to the other. Different case is considered different"; + static constexpr const char *Example = "levenshtein('duck','db')"; + + static ScalarFunction GetFunction(); +}; + +struct Editdist3Fun { + using ALIAS = LevenshteinFun; + + static constexpr const char *Name = "editdist3"; +}; + +struct LpadFun { + static constexpr const char *Name = "lpad"; + static constexpr const char *Parameters = "string,count,character"; + static constexpr const char *Description = "Pads the string with the character from the left until it has count characters"; + static constexpr const char *Example = "lpad('hello', 10, '>')"; + + static ScalarFunction GetFunction(); +}; + +struct LtrimFun { + static constexpr const char *Name = "ltrim"; + static constexpr const char *Parameters = "string,characters"; + static constexpr const char *Description = "Removes any occurrences of any of the characters from the left side of the string"; + static constexpr const char *Example = "ltrim('>>>>test<<', '><')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ParseDirnameFun { + static constexpr const char *Name = "parse_dirname"; + static constexpr const char *Parameters = "string,separator"; + static constexpr const char *Description = "Returns the top-level directory name. separator options: system, both_slash (default), forward_slash, backslash"; + static constexpr const char *Example = "parse_dirname('path/to/file.csv', 'system')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ParseDirpathFun { + static constexpr const char *Name = "parse_dirpath"; + static constexpr const char *Parameters = "string,separator"; + static constexpr const char *Description = "Returns the head of the path similarly to Python's os.path.dirname. separator options: system, both_slash (default), forward_slash, backslash"; + static constexpr const char *Example = "parse_dirpath('path/to/file.csv', 'system')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ParseFilenameFun { + static constexpr const char *Name = "parse_filename"; + static constexpr const char *Parameters = "string,trim_extension,separator"; + static constexpr const char *Description = "Returns the last component of the path similarly to Python's os.path.basename. If trim_extension is true, the file extension will be removed (it defaults to false). 
separator options: system, both_slash (default), forward_slash, backslash"; + static constexpr const char *Example = "parse_filename('path/to/file.csv', true, 'forward_slash')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ParsePathFun { + static constexpr const char *Name = "parse_path"; + static constexpr const char *Parameters = "string,separator"; + static constexpr const char *Description = "Returns a list of the components (directories and filename) in the path similarly to Python's pathlib.PurePath::parts. separator options: system, both_slash (default), forward_slash, backslash"; + static constexpr const char *Example = "parse_path('path/to/file.csv', 'system')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct PrintfFun { + static constexpr const char *Name = "printf"; + static constexpr const char *Parameters = "format,parameters..."; + static constexpr const char *Description = "Formats a string using printf syntax"; + static constexpr const char *Example = "printf('Benchmark \"%s\" took %d seconds', 'CSV', 42)"; + + static ScalarFunction GetFunction(); +}; + +struct RepeatFun { + static constexpr const char *Name = "repeat"; + static constexpr const char *Parameters = "string,count"; + static constexpr const char *Description = "Repeats the string count number of times"; + static constexpr const char *Example = "repeat('A', 5)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct ReplaceFun { + static constexpr const char *Name = "replace"; + static constexpr const char *Parameters = "string,source,target"; + static constexpr const char *Description = "Replaces any occurrences of the source with target in string"; + static constexpr const char *Example = "replace('hello', 'l', '-')"; + + static ScalarFunction GetFunction(); +}; + +struct ReverseFun { + static constexpr const char *Name = "reverse"; + static constexpr const char *Parameters = "string"; + static constexpr const char *Description = "Reverses the string"; + static constexpr const char *Example = "reverse('hello')"; + + static ScalarFunction GetFunction(); +}; + +struct RightFun { + static constexpr const char *Name = "right"; + static constexpr const char *Parameters = "string,count"; + static constexpr const char *Description = "Extract the right-most count characters"; + static constexpr const char *Example = "right('Hello🦆', 3)"; + + static ScalarFunction GetFunction(); +}; + +struct RightGraphemeFun { + static constexpr const char *Name = "right_grapheme"; + static constexpr const char *Parameters = "string,count"; + static constexpr const char *Description = "Extract the right-most count grapheme clusters"; + static constexpr const char *Example = "right_grapheme('🤦🏼‍♂️🤦🏽‍♀️', 1)"; + + static ScalarFunction GetFunction(); +}; + +struct RpadFun { + static constexpr const char *Name = "rpad"; + static constexpr const char *Parameters = "string,count,character"; + static constexpr const char *Description = "Pads the string with the character from the right until it has count characters"; + static constexpr const char *Example = "rpad('hello', 10, '<')"; + + static ScalarFunction GetFunction(); +}; + +struct RtrimFun { + static constexpr const char *Name = "rtrim"; + static constexpr const char *Parameters = "string,characters"; + static constexpr const char *Description = "Removes any occurrences of any of the characters from the right side of the string"; + static constexpr const char *Example = "rtrim('>>>>test<<', '><')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct 
TranslateFun { + static constexpr const char *Name = "translate"; + static constexpr const char *Parameters = "string,from,to"; + static constexpr const char *Description = "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted"; + static constexpr const char *Example = "translate('12345', '143', 'ax')"; + + static ScalarFunction GetFunction(); +}; + +struct TrimFun { + static constexpr const char *Name = "trim"; + static constexpr const char *Parameters = "string::VARCHAR\1string::VARCHAR,characters::VARCHAR"; + static constexpr const char *Description = "Removes any spaces from either side of the string.\1Removes any occurrences of any of the characters from either side of the string"; + static constexpr const char *Example = "trim(' test ')\1trim('>>>>test<<', '><')"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct UnbinFun { + static constexpr const char *Name = "unbin"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Converts a value from binary representation to a blob"; + static constexpr const char *Example = "unbin('0110')"; + + static ScalarFunction GetFunction(); +}; + +struct FromBinaryFun { + using ALIAS = UnbinFun; + + static constexpr const char *Name = "from_binary"; +}; + +struct UnhexFun { + static constexpr const char *Name = "unhex"; + static constexpr const char *Parameters = "value"; + static constexpr const char *Description = "Converts a value from hexadecimal representation to a blob"; + static constexpr const char *Example = "unhex('2A')"; + + static ScalarFunction GetFunction(); +}; + +struct FromHexFun { + using ALIAS = UnhexFun; + + static constexpr const char *Name = "from_hex"; +}; + +struct UnicodeFun { + static constexpr const char *Name = "unicode"; + static constexpr const char *Parameters = "str"; + static constexpr const char *Description = "Returns the unicode codepoint of the first character of the string"; + static constexpr const char *Example = "unicode('ü')"; + + static ScalarFunction GetFunction(); +}; + +struct OrdFun { + using ALIAS = UnicodeFun; + + static constexpr const char *Name = "ord"; +}; + +struct ToBaseFun { + static constexpr const char *Name = "to_base"; + static constexpr const char *Parameters = "number,radix,min_length"; + static constexpr const char *Description = "Converts a value to a string in the given base radix, optionally padding with leading zeros to the minimum length"; + static constexpr const char *Example = "to_base(42, 16)"; + + static ScalarFunctionSet GetFunctions(); +}; + +struct UrlEncodeFun { + static constexpr const char *Name = "url_encode"; + static constexpr const char *Parameters = "input"; + static constexpr const char *Description = "Escapes the input string by encoding it so that it can be included in a URL query parameter."; + static constexpr const char *Example = "url_encode('this string has/ special+ characters>')"; + + static ScalarFunction GetFunction(); +}; + +struct UrlDecodeFun { + static constexpr const char *Name = "url_decode"; + static constexpr const char *Parameters = "input"; + static constexpr const char *Description = "Unescapes the URL encoded input."; + static constexpr const char *Example = "url_decode('this%20string%20is%2BFencoded')"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git 
a/src/duckdb/extension/core_functions/include/core_functions/scalar/struct_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/struct_functions.hpp new file mode 100644 index 00000000..f921bf43 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/struct_functions.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/struct_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct StructInsertFun { + static constexpr const char *Name = "struct_insert"; + static constexpr const char *Parameters = "struct,any"; + static constexpr const char *Description = "Adds field(s)/value(s) to an existing STRUCT with the argument values. The entry name(s) will be the bound variable name(s)"; + static constexpr const char *Example = "struct_insert({'a': 1}, b := 2)"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/union_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/union_functions.hpp new file mode 100644 index 00000000..766c12e8 --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/union_functions.hpp @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions/scalar/union_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct UnionExtractFun { + static constexpr const char *Name = "union_extract"; + static constexpr const char *Parameters = "union,tag"; + static constexpr const char *Description = "Extract the value with the named tags from the union. NULL if the tag is not currently selected"; + static constexpr const char *Example = "union_extract(s, 'k')"; + + static ScalarFunction GetFunction(); +}; + +struct UnionTagFun { + static constexpr const char *Name = "union_tag"; + static constexpr const char *Parameters = "union"; + static constexpr const char *Description = "Retrieve the currently selected tag of the union as an ENUM"; + static constexpr const char *Example = "union_tag(union_value(k := 'foo'))"; + + static ScalarFunction GetFunction(); +}; + +struct UnionValueFun { + static constexpr const char *Name = "union_value"; + static constexpr const char *Parameters = "tag"; + static constexpr const char *Description = "Create a single member UNION containing the argument value. 
The tag of the value will be the bound variable name"; + static constexpr const char *Example = "union_value(k := 'hello')"; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions_extension.hpp b/src/duckdb/extension/core_functions/include/core_functions_extension.hpp new file mode 100644 index 00000000..e877860f --- /dev/null +++ b/src/duckdb/extension/core_functions/include/core_functions_extension.hpp @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// core_functions_extension.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.hpp" + +namespace duckdb { + +class CoreFunctionsExtension : public Extension { +public: + void Load(DuckDB &db) override; + std::string Name() override; + std::string Version() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/lambda_functions.cpp b/src/duckdb/extension/core_functions/lambda_functions.cpp new file mode 100644 index 00000000..b5549914 --- /dev/null +++ b/src/duckdb/extension/core_functions/lambda_functions.cpp @@ -0,0 +1,414 @@ +#include "duckdb/function/lambda_functions.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Helper functions +//===--------------------------------------------------------------------===// + +//! LambdaExecuteInfo holds information for executing the lambda expression on an input chunk and +//! a resulting lambda chunk. +struct LambdaExecuteInfo { + LambdaExecuteInfo(ClientContext &context, const Expression &lambda_expr, const DataChunk &args, + const bool has_index, const Vector &child_vector) + : has_index(has_index) { + + expr_executor = make_uniq(context, lambda_expr); + + // get the input types for the input chunk + vector input_types; + if (has_index) { + input_types.push_back(LogicalType::BIGINT); + } + input_types.push_back(child_vector.GetType()); + for (idx_t i = 1; i < args.ColumnCount(); i++) { + input_types.push_back(args.data[i].GetType()); + } + + // get the result types + vector result_types {lambda_expr.return_type}; + + // initialize the data chunks + input_chunk.InitializeEmpty(input_types); + lambda_chunk.Initialize(Allocator::DefaultAllocator(), result_types); + }; + + //! The expression executor that executes the lambda expression + unique_ptr expr_executor; + //! The input chunk on which we execute the lambda expression + DataChunk input_chunk; + //! The chunk holding the result of executing the lambda expression + DataChunk lambda_chunk; + //! True, if this lambda expression expects an index vector in the input chunk + bool has_index; +}; + +//! A helper struct with information that is specific to the list_filter function +struct ListFilterInfo { + //! The new list lengths after filtering out elements + vector entry_lengths; + //! The length of the current list + idx_t length = 0; + //! The offset of the current list + idx_t offset = 0; + //! The current row index + idx_t row_idx = 0; + //! The length of the source list + idx_t src_length = 0; +}; + +//! 
ListTransformFunctor contains list_transform specific functionality +struct ListTransformFunctor { + static void ReserveNewLengths(vector &, const idx_t) { + // NOP + } + static void PushEmptyList(vector &) { + // NOP + } + //! Sets the list entries of the result vector + static void SetResultEntry(list_entry_t *result_entries, idx_t &offset, const list_entry_t &entry, + const idx_t row_idx, vector &) { + result_entries[row_idx].offset = offset; + result_entries[row_idx].length = entry.length; + offset += entry.length; + } + //! Appends the lambda vector to the result's child vector + static void AppendResult(Vector &result, Vector &lambda_vector, const idx_t elem_cnt, list_entry_t *, + ListFilterInfo &, LambdaExecuteInfo &) { + ListVector::Append(result, lambda_vector, elem_cnt, 0); + } +}; + +//! ListFilterFunctor contains list_filter specific functionality +struct ListFilterFunctor { + //! Initializes the entry_lengths vector + static void ReserveNewLengths(vector &entry_lengths, const idx_t row_count) { + entry_lengths.reserve(row_count); + } + //! Pushes an empty list to the entry_lengths vector + static void PushEmptyList(vector &entry_lengths) { + entry_lengths.emplace_back(0); + } + //! Pushes the length of the original list to the entry_lengths vector + static void SetResultEntry(list_entry_t *, idx_t &, const list_entry_t &entry, const idx_t, + vector &entry_lengths) { + entry_lengths.push_back(entry.length); + } + //! Uses the lambda vector to filter the incoming list and to append the filtered list to the result vector + static void AppendResult(Vector &result, Vector &lambda_vector, const idx_t elem_cnt, list_entry_t *result_entries, + ListFilterInfo &info, LambdaExecuteInfo &execute_info) { + + idx_t count = 0; + SelectionVector sel(elem_cnt); + UnifiedVectorFormat lambda_data; + lambda_vector.ToUnifiedFormat(elem_cnt, lambda_data); + + auto lambda_values = UnifiedVectorFormat::GetData(lambda_data); + auto &lambda_validity = lambda_data.validity; + + // compute the new lengths and offsets, and create a selection vector + for (idx_t i = 0; i < elem_cnt; i++) { + auto entry_idx = lambda_data.sel->get_index(i); + + // set length and offset of empty lists + while (info.row_idx < info.entry_lengths.size() && !info.entry_lengths[info.row_idx]) { + result_entries[info.row_idx].offset = info.offset; + result_entries[info.row_idx].length = 0; + info.row_idx++; + } + + // found a true value + if (lambda_validity.RowIsValid(entry_idx) && lambda_values[entry_idx]) { + sel.set_index(count++, i); + info.length++; + } + + info.src_length++; + + // we traversed the entire source list + if (info.entry_lengths[info.row_idx] == info.src_length) { + // set the offset and length of the result entry + result_entries[info.row_idx].offset = info.offset; + result_entries[info.row_idx].length = info.length; + + // reset all other fields + info.offset += info.length; + info.row_idx++; + info.length = 0; + info.src_length = 0; + } + } + + // set length and offset of all remaining empty lists + while (info.row_idx < info.entry_lengths.size() && !info.entry_lengths[info.row_idx]) { + result_entries[info.row_idx].offset = info.offset; + result_entries[info.row_idx].length = 0; + info.row_idx++; + } + + // slice the input chunk's corresponding vector to get the new lists + // and append them to the result + idx_t source_list_idx = execute_info.has_index ? 
1 : 0; + Vector result_lists(execute_info.input_chunk.data[source_list_idx], sel, count); + ListVector::Append(result, result_lists, count, 0); + } +}; + +vector LambdaFunctions::GetColumnInfo(DataChunk &args, const idx_t row_count) { + vector data; + // skip the input list and then insert all remaining input vectors + for (idx_t i = 1; i < args.ColumnCount(); i++) { + data.emplace_back(args.data[i]); + args.data[i].ToUnifiedFormat(row_count, data.back().format); + } + return data; +} + +vector> +LambdaFunctions::GetMutableColumnInfo(vector &data) { + vector> inconstant_info; + for (auto &entry : data) { + if (entry.vector.get().GetVectorType() != VectorType::CONSTANT_VECTOR) { + inconstant_info.push_back(entry); + } + } + return inconstant_info; +} + +void ExecuteExpression(const idx_t elem_cnt, const LambdaFunctions::ColumnInfo &column_info, + const vector &column_infos, const Vector &index_vector, + LambdaExecuteInfo &info) { + + info.input_chunk.SetCardinality(elem_cnt); + info.lambda_chunk.SetCardinality(elem_cnt); + + // slice the child vector + Vector slice(column_info.vector, column_info.sel, elem_cnt); + + // reference the child vector (and the index vector) + if (info.has_index) { + info.input_chunk.data[0].Reference(index_vector); + info.input_chunk.data[1].Reference(slice); + } else { + info.input_chunk.data[0].Reference(slice); + } + idx_t slice_offset = info.has_index ? 2 : 1; + + // (slice and) reference the other columns + vector slices; + for (idx_t i = 0; i < column_infos.size(); i++) { + + if (column_infos[i].vector.get().GetVectorType() == VectorType::CONSTANT_VECTOR) { + // only reference constant vectors + info.input_chunk.data[i + slice_offset].Reference(column_infos[i].vector); + + } else { + // slice inconstant vectors + slices.emplace_back(column_infos[i].vector, column_infos[i].sel, elem_cnt); + info.input_chunk.data[i + slice_offset].Reference(slices.back()); + } + } + + // execute the lambda expression + info.expr_executor->Execute(info.input_chunk, info.lambda_chunk); +} + +//===--------------------------------------------------------------------===// +// ListLambdaBindData +//===--------------------------------------------------------------------===// + +unique_ptr ListLambdaBindData::Copy() const { + auto lambda_expr_copy = lambda_expr ? 
lambda_expr->Copy() : nullptr; + return make_uniq(return_type, std::move(lambda_expr_copy), has_index); +} + +bool ListLambdaBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return Expression::Equals(lambda_expr, other.lambda_expr) && return_type == other.return_type && + has_index == other.has_index; +} + +void ListLambdaBindData::Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const ScalarFunction &) { + auto &bind_data = bind_data_p->Cast(); + serializer.WriteProperty(100, "return_type", bind_data.return_type); + serializer.WritePropertyWithDefault(101, "lambda_expr", bind_data.lambda_expr, unique_ptr()); + serializer.WriteProperty(102, "has_index", bind_data.has_index); +} + +unique_ptr ListLambdaBindData::Deserialize(Deserializer &deserializer, ScalarFunction &) { + auto return_type = deserializer.ReadProperty(100, "return_type"); + auto lambda_expr = deserializer.ReadPropertyWithExplicitDefault>(101, "lambda_expr", + unique_ptr()); + auto has_index = deserializer.ReadProperty(102, "has_index"); + return make_uniq(return_type, std::move(lambda_expr), has_index); +} + +//===--------------------------------------------------------------------===// +// LambdaFunctions +//===--------------------------------------------------------------------===// + +LogicalType LambdaFunctions::BindBinaryLambda(const idx_t parameter_idx, const LogicalType &list_child_type) { + switch (parameter_idx) { + case 0: + return list_child_type; + case 1: + return LogicalType::BIGINT; + default: + throw BinderException("This lambda function only supports up to two lambda parameters!"); + } +} + +LogicalType LambdaFunctions::BindTernaryLambda(const idx_t parameter_idx, const LogicalType &list_child_type) { + switch (parameter_idx) { + case 0: + return list_child_type; + case 1: + return list_child_type; + case 2: + return LogicalType::BIGINT; + default: + throw BinderException("This lambda function only supports up to three lambda parameters!"); + } +} + +template +void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &result) { + + bool result_is_null = false; + LambdaFunctions::LambdaInfo info(args, state, result, result_is_null); + if (result_is_null) { + return; + } + + auto result_entries = FlatVector::GetData(result); + auto mutable_column_infos = LambdaFunctions::GetMutableColumnInfo(info.column_infos); + + // special-handling for the child_vector + auto child_vector_size = ListVector::GetListSize(args.data[0]); + LambdaFunctions::ColumnInfo child_info(*info.child_vector); + info.child_vector->ToUnifiedFormat(child_vector_size, child_info.format); + + // get the expression executor + LambdaExecuteInfo execute_info(state.GetContext(), *info.lambda_expr, args, info.has_index, *info.child_vector); + + // get list_filter specific info + ListFilterInfo list_filter_info; + FUNCTION_FUNCTOR::ReserveNewLengths(list_filter_info.entry_lengths, info.row_count); + + // additional index vector + Vector index_vector(LogicalType::BIGINT); + + // loop over the child entries and create chunks to be executed by the expression executor + idx_t elem_cnt = 0; + idx_t offset = 0; + for (idx_t row_idx = 0; row_idx < info.row_count; row_idx++) { + + auto list_idx = info.list_column_format.sel->get_index(row_idx); + const auto &list_entry = info.list_entries[list_idx]; + + // set the result to NULL for this row + if (!info.list_column_format.validity.RowIsValid(list_idx)) { + info.result_validity->SetInvalid(row_idx); + 
FUNCTION_FUNCTOR::PushEmptyList(list_filter_info.entry_lengths); + continue; + } + + FUNCTION_FUNCTOR::SetResultEntry(result_entries, offset, list_entry, row_idx, list_filter_info.entry_lengths); + + // empty list, nothing to execute + if (list_entry.length == 0) { + continue; + } + + // iterate the elements of the current list and create the corresponding selection vectors + for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { + + // reached STANDARD_VECTOR_SIZE elements + if (elem_cnt == STANDARD_VECTOR_SIZE) { + + execute_info.lambda_chunk.Reset(); + ExecuteExpression(elem_cnt, child_info, info.column_infos, index_vector, execute_info); + auto &lambda_vector = execute_info.lambda_chunk.data[0]; + + FUNCTION_FUNCTOR::AppendResult(result, lambda_vector, elem_cnt, result_entries, list_filter_info, + execute_info); + elem_cnt = 0; + } + + // FIXME: reuse same selection vector for inconstant rows + // adjust indexes for slicing + child_info.sel.set_index(elem_cnt, list_entry.offset + child_idx); + for (auto &entry : mutable_column_infos) { + entry.get().sel.set_index(elem_cnt, row_idx); + } + + // set the index vector + if (info.has_index) { + index_vector.SetValue(elem_cnt, Value::BIGINT(NumericCast(child_idx + 1))); + } + + elem_cnt++; + } + } + + execute_info.lambda_chunk.Reset(); + ExecuteExpression(elem_cnt, child_info, info.column_infos, index_vector, execute_info); + auto &lambda_vector = execute_info.lambda_chunk.data[0]; + + FUNCTION_FUNCTOR::AppendResult(result, lambda_vector, elem_cnt, result_entries, list_filter_info, execute_info); + + if (info.is_all_constant && !info.is_volatile) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +unique_ptr LambdaFunctions::ListLambdaPrepareBind(vector> &arguments, + ClientContext &context, + ScalarFunction &bound_function) { + // NULL list parameter + if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) { + bound_function.arguments[0] = LogicalType::SQLNULL; + bound_function.return_type = LogicalType::SQLNULL; + return make_uniq(bound_function.return_type, nullptr); + } + // prepared statements + if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + D_ASSERT(arguments[0]->return_type.id() == LogicalTypeId::LIST); + return nullptr; +} + +unique_ptr LambdaFunctions::ListLambdaBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments, + const bool has_index) { + unique_ptr bind_data = ListLambdaPrepareBind(arguments, context, bound_function); + if (bind_data) { + return bind_data; + } + + // get the lambda expression and put it in the bind info + auto &bound_lambda_expr = arguments[1]->Cast(); + auto lambda_expr = std::move(bound_lambda_expr.lambda_expr); + + return make_uniq(bound_function.return_type, std::move(lambda_expr), has_index); +} + +void LambdaFunctions::ListTransformFunction(DataChunk &args, ExpressionState &state, Vector &result) { + ExecuteLambda(args, state, result); +} + +void LambdaFunctions::ListFilterFunction(DataChunk &args, ExpressionState &state, Vector &result) { + ExecuteLambda(args, state, result); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp b/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp new file mode 100644 index 00000000..af7d0ee0 --- /dev/null +++ 
b/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp @@ -0,0 +1,280 @@ +#include "core_functions/scalar/array_functions.hpp" +#include "core_functions/array_kernels.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +static unique_ptr ArrayGenericBinaryBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + const auto lhs_is_param = arguments[0]->HasParameter(); + const auto rhs_is_param = arguments[1]->HasParameter(); + + if (lhs_is_param && rhs_is_param) { + throw ParameterNotResolvedException(); + } + + const auto &lhs_type = arguments[0]->return_type; + const auto &rhs_type = arguments[1]->return_type; + + bound_function.arguments[0] = lhs_is_param ? rhs_type : lhs_type; + bound_function.arguments[1] = rhs_is_param ? lhs_type : rhs_type; + + if (bound_function.arguments[0].id() != LogicalTypeId::ARRAY || + bound_function.arguments[1].id() != LogicalTypeId::ARRAY) { + throw InvalidInputException( + StringUtil::Format("%s: Arguments must be arrays of FLOAT or DOUBLE", bound_function.name)); + } + + const auto lhs_size = ArrayType::GetSize(bound_function.arguments[0]); + const auto rhs_size = ArrayType::GetSize(bound_function.arguments[1]); + + if (lhs_size != rhs_size) { + throw BinderException("%s: Array arguments must be of the same size", bound_function.name); + } + + const auto &lhs_element_type = ArrayType::GetChildType(bound_function.arguments[0]); + const auto &rhs_element_type = ArrayType::GetChildType(bound_function.arguments[1]); + + // Resolve common type + LogicalType common_type; + if (!LogicalType::TryGetMaxLogicalType(context, lhs_element_type, rhs_element_type, common_type)) { + throw BinderException("%s: Cannot infer common element type (left = '%s', right = '%s')", bound_function.name, + lhs_element_type.ToString(), rhs_element_type.ToString()); + } + + // Ensure it is float or double + if (common_type.id() != LogicalTypeId::FLOAT && common_type.id() != LogicalTypeId::DOUBLE) { + throw BinderException("%s: Arguments must be arrays of FLOAT or DOUBLE", bound_function.name); + } + + // The important part is just that we resolve the size of the input arrays + bound_function.arguments[0] = LogicalType::ARRAY(common_type, lhs_size); + bound_function.arguments[1] = LogicalType::ARRAY(common_type, rhs_size); + + return nullptr; +} + +//------------------------------------------------------------------------------ +// Element-wise combine functions +//------------------------------------------------------------------------------ +// Given two arrays of the same size, combine their elements into a single array +// of the same size as the input arrays. 
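+// Illustration: for the 3-element cross product implemented below,
+// [1, 2, 3] x [4, 5, 6] = [2*6 - 3*5, 3*4 - 1*6, 1*5 - 2*4] = [-3, 6, -3].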
+ +struct CrossProductOp { + template + static void Operation(const TYPE *lhs_data, const TYPE *rhs_data, TYPE *res_data, idx_t size) { + D_ASSERT(size == 3); + + auto lx = lhs_data[0]; + auto ly = lhs_data[1]; + auto lz = lhs_data[2]; + + auto rx = rhs_data[0]; + auto ry = rhs_data[1]; + auto rz = rhs_data[2]; + + res_data[0] = ly * rz - lz * ry; + res_data[1] = lz * rx - lx * rz; + res_data[2] = lx * ry - ly * rx; + } +}; + +template +static void ArrayFixedCombine(DataChunk &args, ExpressionState &state, Vector &result) { + const auto &lstate = state.Cast(); + const auto &expr = lstate.expr.Cast(); + const auto &func_name = expr.function.name; + + const auto count = args.size(); + auto &lhs_child = ArrayVector::GetEntry(args.data[0]); + auto &rhs_child = ArrayVector::GetEntry(args.data[1]); + auto &res_child = ArrayVector::GetEntry(result); + + const auto &lhs_child_validity = FlatVector::Validity(lhs_child); + const auto &rhs_child_validity = FlatVector::Validity(rhs_child); + + UnifiedVectorFormat lhs_format; + UnifiedVectorFormat rhs_format; + + args.data[0].ToUnifiedFormat(count, lhs_format); + args.data[1].ToUnifiedFormat(count, rhs_format); + + auto lhs_data = FlatVector::GetData(lhs_child); + auto rhs_data = FlatVector::GetData(rhs_child); + auto res_data = FlatVector::GetData(res_child); + + for (idx_t i = 0; i < count; i++) { + const auto lhs_idx = lhs_format.sel->get_index(i); + const auto rhs_idx = rhs_format.sel->get_index(i); + + if (!lhs_format.validity.RowIsValid(lhs_idx) || !rhs_format.validity.RowIsValid(rhs_idx)) { + FlatVector::SetNull(result, i, true); + continue; + } + + const auto left_offset = lhs_idx * N; + if (!lhs_child_validity.CheckAllValid(left_offset + N, left_offset)) { + throw InvalidInputException(StringUtil::Format("%s: left argument can not contain NULL values", func_name)); + } + + const auto right_offset = rhs_idx * N; + if (!rhs_child_validity.CheckAllValid(right_offset + N, right_offset)) { + throw InvalidInputException( + StringUtil::Format("%s: right argument can not contain NULL values", func_name)); + } + const auto result_offset = i * N; + + const auto lhs_data_ptr = lhs_data + left_offset; + const auto rhs_data_ptr = rhs_data + right_offset; + const auto res_data_ptr = res_data + result_offset; + + OP::Operation(lhs_data_ptr, rhs_data_ptr, res_data_ptr, N); + } + + if (count == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +//------------------------------------------------------------------------------ +// Generic "fold" function +//------------------------------------------------------------------------------ +// Given two arrays, combine and reduce their elements into a single scalar value. 
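+// To make the OP contract concrete, here is a minimal sketch of a fold kernel in
+// the shape this function expects (illustrative only; the real kernels, e.g. the
+// distance and inner-product ops, are assumed to live in
+// core_functions/array_kernels.hpp included above):
+//
+//   struct InnerProductOp {
+//       template <class TYPE>
+//       static TYPE Operation(const TYPE *lhs, const TYPE *rhs, const idx_t count) {
+//           TYPE sum = 0;
+//           for (idx_t i = 0; i < count; i++) {
+//               sum += lhs[i] * rhs[i]; // combine element-wise, reduce to a scalar
+//           }
+//           return sum;
+//       }
+//   };
+//
+// For example, Operation on [1, 2, 3] and [4, 5, 6] returns 1*4 + 2*5 + 3*6 = 32,
+// which ArrayGenericFold stores in res_data[i] for the corresponding row.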
+ +template +static void ArrayGenericFold(DataChunk &args, ExpressionState &state, Vector &result) { + const auto &lstate = state.Cast(); + const auto &expr = lstate.expr.Cast(); + const auto &func_name = expr.function.name; + + const auto count = args.size(); + auto &lhs_child = ArrayVector::GetEntry(args.data[0]); + auto &rhs_child = ArrayVector::GetEntry(args.data[1]); + + const auto &lhs_child_validity = FlatVector::Validity(lhs_child); + const auto &rhs_child_validity = FlatVector::Validity(rhs_child); + + UnifiedVectorFormat lhs_format; + UnifiedVectorFormat rhs_format; + + args.data[0].ToUnifiedFormat(count, lhs_format); + args.data[1].ToUnifiedFormat(count, rhs_format); + + auto lhs_data = FlatVector::GetData(lhs_child); + auto rhs_data = FlatVector::GetData(rhs_child); + auto res_data = FlatVector::GetData(result); + + const auto array_size = ArrayType::GetSize(args.data[0].GetType()); + D_ASSERT(array_size == ArrayType::GetSize(args.data[1].GetType())); + + for (idx_t i = 0; i < count; i++) { + const auto lhs_idx = lhs_format.sel->get_index(i); + const auto rhs_idx = rhs_format.sel->get_index(i); + + if (!lhs_format.validity.RowIsValid(lhs_idx) || !rhs_format.validity.RowIsValid(rhs_idx)) { + FlatVector::SetNull(result, i, true); + continue; + } + + const auto left_offset = lhs_idx * array_size; + if (!lhs_child_validity.CheckAllValid(left_offset + array_size, left_offset)) { + throw InvalidInputException(StringUtil::Format("%s: left argument can not contain NULL values", func_name)); + } + + const auto right_offset = rhs_idx * array_size; + if (!rhs_child_validity.CheckAllValid(right_offset + array_size, right_offset)) { + throw InvalidInputException( + StringUtil::Format("%s: right argument can not contain NULL values", func_name)); + } + + const auto lhs_data_ptr = lhs_data + left_offset; + const auto rhs_data_ptr = rhs_data + right_offset; + + res_data[i] = OP::Operation(lhs_data_ptr, rhs_data_ptr, array_size); + } + + if (count == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +//------------------------------------------------------------------------------ +// Function Registration +//------------------------------------------------------------------------------ +// Note: In the future we could add a wrapper with a non-type template parameter to specialize for specific array sizes +// e.g. 256, 512, 1024, 2048 etc. which may allow the compiler to vectorize the loop better. Perhaps something for an +// extension. 
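+// A possible shape for that wrapper (hypothetical sketch, not implemented in this
+// file): fixing the element count as a non-type template parameter turns the
+// kernel's trip count into a compile-time constant, which lets the compiler
+// unroll and vectorize the loop.
+//
+//   template <class OP, idx_t FIXED_SIZE>
+//   struct FixedSizeOp {
+//       template <class TYPE>
+//       static TYPE Operation(const TYPE *lhs, const TYPE *rhs, idx_t /*size*/) {
+//           // forward with a constant size so the inlined kernel loop can unroll
+//           return OP::template Operation<TYPE>(lhs, rhs, FIXED_SIZE);
+//       }
+//   };
+//
+// Binding would then dispatch on ArrayType::GetSize(), e.g. selecting
+// FixedSizeOp<InnerProductOp, 256> for 256-element arrays and falling back to the
+// generic path for other sizes.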
+ +template +static void AddArrayFoldFunction(ScalarFunctionSet &set, const LogicalType &type) { + const auto array = LogicalType::ARRAY(type, optional_idx()); + if (type.id() == LogicalTypeId::FLOAT) { + ScalarFunction function({array, array}, type, ArrayGenericFold, ArrayGenericBinaryBind); + BaseScalarFunction::SetReturnsError(function); + set.AddFunction(function); + } else if (type.id() == LogicalTypeId::DOUBLE) { + ScalarFunction function({array, array}, type, ArrayGenericFold, ArrayGenericBinaryBind); + BaseScalarFunction::SetReturnsError(function); + set.AddFunction(function); + } else { + throw NotImplementedException("Array function not implemented for type %s", type.ToString()); + } +} + +ScalarFunctionSet ArrayDistanceFun::GetFunctions() { + ScalarFunctionSet set("array_distance"); + for (auto &type : LogicalType::Real()) { + AddArrayFoldFunction(set, type); + } + return set; +} + +ScalarFunctionSet ArrayInnerProductFun::GetFunctions() { + ScalarFunctionSet set("array_inner_product"); + for (auto &type : LogicalType::Real()) { + AddArrayFoldFunction(set, type); + } + return set; +} + +ScalarFunctionSet ArrayNegativeInnerProductFun::GetFunctions() { + ScalarFunctionSet set("array_negative_inner_product"); + for (auto &type : LogicalType::Real()) { + AddArrayFoldFunction(set, type); + } + return set; +} + +ScalarFunctionSet ArrayCosineSimilarityFun::GetFunctions() { + ScalarFunctionSet set("array_cosine_similarity"); + for (auto &type : LogicalType::Real()) { + AddArrayFoldFunction(set, type); + } + return set; +} + +ScalarFunctionSet ArrayCosineDistanceFun::GetFunctions() { + ScalarFunctionSet set("array_cosine_distance"); + for (auto &type : LogicalType::Real()) { + AddArrayFoldFunction(set, type); + } + return set; +} + +ScalarFunctionSet ArrayCrossProductFun::GetFunctions() { + ScalarFunctionSet set("array_cross_product"); + + auto float_array = LogicalType::ARRAY(LogicalType::FLOAT, 3); + auto double_array = LogicalType::ARRAY(LogicalType::DOUBLE, 3); + set.AddFunction( + ScalarFunction({float_array, float_array}, float_array, ArrayFixedCombine)); + set.AddFunction( + ScalarFunction({double_array, double_array}, double_array, ArrayFixedCombine)); + for (auto &func : set.functions) { + BaseScalarFunction::SetReturnsError(func); + } + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/array/array_value.cpp b/src/duckdb/extension/core_functions/scalar/array/array_value.cpp new file mode 100644 index 00000000..e7f715f7 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/array/array_value.cpp @@ -0,0 +1,87 @@ +#include "core_functions/scalar/array_functions.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/storage/statistics/array_stats.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +static void ArrayValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto array_type = result.GetType(); + + D_ASSERT(array_type.id() == LogicalTypeId::ARRAY); + D_ASSERT(args.ColumnCount() == ArrayType::GetSize(array_type)); + + auto &child_type = ArrayType::GetChildType(array_type); + + result.SetVectorType(VectorType::CONSTANT_VECTOR); + for (idx_t i = 0; i < args.ColumnCount(); i++) { + if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::FLAT_VECTOR); + } + } + + auto num_rows = args.size(); + auto num_columns = args.ColumnCount(); + + auto &child = ArrayVector::GetEntry(result); + + if 
(num_columns > 1) { + // Ensure that the child has a validity mask of the correct size + // The SetValue call below expects the validity mask to be initialized + auto &child_validity = FlatVector::Validity(child); + child_validity.Resize(num_rows * num_columns); + } + + for (idx_t i = 0; i < num_rows; i++) { + for (idx_t j = 0; j < num_columns; j++) { + auto val = args.GetValue(j, i).DefaultCastAs(child_type); + child.SetValue((i * num_columns) + j, val); + } + } + + result.Verify(args.size()); +} + +static unique_ptr ArrayValueBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments.empty()) { + throw InvalidInputException("array_value requires at least one argument"); + } + + // construct return type + LogicalType child_type = arguments[0]->return_type; + for (idx_t i = 1; i < arguments.size(); i++) { + child_type = LogicalType::MaxLogicalType(context, child_type, arguments[i]->return_type); + } + + if (arguments.size() > ArrayType::MAX_ARRAY_SIZE) { + throw OutOfRangeException("Array size exceeds maximum allowed size"); + } + + // this is more for completeness reasons + bound_function.varargs = child_type; + bound_function.return_type = LogicalType::ARRAY(child_type, arguments.size()); + return make_uniq(bound_function.return_type); +} + +unique_ptr ArrayValueStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + auto list_stats = ArrayStats::CreateEmpty(expr.return_type); + auto &list_child_stats = ArrayStats::GetChildStats(list_stats); + for (idx_t i = 0; i < child_stats.size(); i++) { + list_child_stats.Merge(child_stats[i]); + } + return list_stats.ToUnique(); +} + +ScalarFunction ArrayValueFun::GetFunction() { + // the arguments and return types are actually set in the binder function + ScalarFunction fun("array_value", {}, LogicalTypeId::ARRAY, ArrayValueFunction, ArrayValueBind, nullptr, + ArrayValueStats); + fun.varargs = LogicalType::ANY; + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp b/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp new file mode 100644 index 00000000..0dbcb8eb --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp @@ -0,0 +1,122 @@ +#include "core_functions/scalar/bit_functions.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/cast_helpers.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// BitStringFunction +//===--------------------------------------------------------------------===// +template +static void BitStringFunction(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t n) { + if (n < 0) { + throw InvalidInputException("The bitstring length cannot be negative"); + } + idx_t input_length; + if (FROM_STRING) { + input_length = input.GetSize(); + } else { + input_length = Bit::BitLength(input); + } + if (idx_t(n) < input_length) { + throw InvalidInputException("Length must be equal or larger than input string"); + } + idx_t len; + if (FROM_STRING) { + Bit::TryGetBitStringSize(input, len, nullptr); // string verification + } + + len = Bit::ComputeBitstringLen(UnsafeNumericCast(n)); + string_t target = StringVector::EmptyString(result, len); + if (FROM_STRING) { + 
Bit::BitString(input, UnsafeNumericCast(n), target); + } else { + Bit::ExtendBitString(input, UnsafeNumericCast(n), target); + } + target.Finalize(); + return target; + }); +} + +ScalarFunctionSet BitStringFun::GetFunctions() { + ScalarFunctionSet bitstring; + bitstring.AddFunction( + ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction)); + bitstring.AddFunction( + ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction)); + for (auto &func : bitstring.functions) { + BaseScalarFunction::SetReturnsError(func); + } + return bitstring; +} + +//===--------------------------------------------------------------------===// +// get_bit +//===--------------------------------------------------------------------===// +struct GetBitOperator { + template + static inline TR Operation(TA input, TB n) { + if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { + throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), + NumericHelper::ToString(Bit::BitLength(input) - 1)); + } + return UnsafeNumericCast(Bit::GetBit(input, UnsafeNumericCast(n))); + } +}; + +ScalarFunction GetBitFun::GetFunction() { + ScalarFunction func({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::INTEGER, + ScalarFunction::BinaryFunction); + BaseScalarFunction::SetReturnsError(func); + return func; +} + +//===--------------------------------------------------------------------===// +// set_bit +//===--------------------------------------------------------------------===// +static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &result) { + TernaryExecutor::Execute( + args.data[0], args.data[1], args.data[2], result, args.size(), + [&](string_t input, int32_t n, int32_t new_value) { + if (new_value != 0 && new_value != 1) { + throw InvalidInputException("The new bit must be 1 or 0"); + } + if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { + throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), + NumericHelper::ToString(Bit::BitLength(input) - 1)); + } + string_t target = StringVector::EmptyString(result, input.GetSize()); + memcpy(target.GetDataWriteable(), input.GetData(), input.GetSize()); + Bit::SetBit(target, UnsafeNumericCast(n), UnsafeNumericCast(new_value)); + return target; + }); +} + +ScalarFunction SetBitFun::GetFunction() { + ScalarFunction function({LogicalType::BIT, LogicalType::INTEGER, LogicalType::INTEGER}, LogicalType::BIT, + SetBitOperation); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// bit_position +//===--------------------------------------------------------------------===// +struct BitPositionOperator { + template + static inline TR Operation(TA substring, TB input) { + if (substring.GetSize() > input.GetSize()) { + return 0; + } + return UnsafeNumericCast(Bit::BitPosition(substring, input)); + } +}; + +ScalarFunction BitPositionFun::GetFunction() { + return ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::INTEGER, + ScalarFunction::BinaryFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/blob/base64.cpp b/src/duckdb/extension/core_functions/scalar/blob/base64.cpp new file mode 100644 index 00000000..fb903fa8 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/blob/base64.cpp @@ -0,0 +1,47 @@ +#include "core_functions/scalar/blob_functions.hpp" 
+#include "duckdb/common/types/blob.hpp" + +namespace duckdb { + +struct Base64EncodeOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto result_str = StringVector::EmptyString(result, Blob::ToBase64Size(input)); + Blob::ToBase64(input, result_str.GetDataWriteable()); + result_str.Finalize(); + return result_str; + } +}; + +struct Base64DecodeOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto result_size = Blob::FromBase64Size(input); + auto result_blob = StringVector::EmptyString(result, result_size); + Blob::FromBase64(input, data_ptr_cast(result_blob.GetDataWriteable()), result_size); + result_blob.Finalize(); + return result_blob; + } +}; + +static void Base64EncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // decode is also a nop cast, but requires verification if the provided string is actually + UnaryExecutor::ExecuteString(args.data[0], result, args.size()); +} + +static void Base64DecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // decode is also a nop cast, but requires verification if the provided string is actually + UnaryExecutor::ExecuteString(args.data[0], result, args.size()); +} + +ScalarFunction ToBase64Fun::GetFunction() { + return ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, Base64EncodeFunction); +} + +ScalarFunction FromBase64Fun::GetFunction() { + ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, Base64DecodeFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/blob/encode.cpp b/src/duckdb/extension/core_functions/scalar/blob/encode.cpp new file mode 100644 index 00000000..66cedb0b --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/blob/encode.cpp @@ -0,0 +1,42 @@ +#include "core_functions/scalar/blob_functions.hpp" +#include "utf8proc_wrapper.hpp" +#include "duckdb/common/exception/conversion_exception.hpp" + +namespace duckdb { + +static void EncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // encode is essentially a nop cast from varchar to blob + // we only need to reinterpret the data using the blob type + result.Reinterpret(args.data[0]); +} + +struct BlobDecodeOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input) { + auto input_data = input.GetData(); + auto input_length = input.GetSize(); + if (Utf8Proc::Analyze(input_data, input_length) == UnicodeType::INVALID) { + throw ConversionException( + "Failure in decode: could not convert blob to UTF8 string, the blob contained invalid UTF8 characters"); + } + return input; + } +}; + +static void DecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // decode is also a nop cast, but requires verification if the provided string is actually + UnaryExecutor::Execute(args.data[0], result, args.size()); + StringVector::AddHeapReference(result, args.data[0]); +} + +ScalarFunction EncodeFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::BLOB, EncodeFunction); +} + +ScalarFunction DecodeFun::GetFunction() { + ScalarFunction function({LogicalType::BLOB}, LogicalType::VARCHAR, DecodeFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/age.cpp b/src/duckdb/extension/core_functions/scalar/date/age.cpp new file mode 100644 index 
00000000..cf7281f0 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/date/age.cpp @@ -0,0 +1,55 @@ +#include "core_functions/scalar/date_functions.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "duckdb/transaction/meta_transaction.hpp" + +namespace duckdb { + +static void AgeFunctionStandard(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 1); + // Subtract argument from current_date (at midnight) + // Theoretically, this should be TZ-sensitive, but since we have to be able to handle + // plain TZ when ICU is not loaded, we implement this in UTC (like everything else) + // To get the PG behaviour, we overload these functions in ICU for TSTZ arguments. + auto current_date = Timestamp::FromDatetime( + Timestamp::GetDate(MetaTransaction::Get(state.GetContext()).start_timestamp), dtime_t(0)); + + UnaryExecutor::ExecuteWithNulls(input.data[0], result, input.size(), + [&](timestamp_t input, ValidityMask &mask, idx_t idx) { + if (Timestamp::IsFinite(input)) { + return Interval::GetAge(current_date, input); + } else { + mask.SetInvalid(idx); + return interval_t(); + } + }); +} + +static void AgeFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 2); + + BinaryExecutor::ExecuteWithNulls( + input.data[0], input.data[1], result, input.size(), + [&](timestamp_t input1, timestamp_t input2, ValidityMask &mask, idx_t idx) { + if (Timestamp::IsFinite(input1) && Timestamp::IsFinite(input2)) { + return Interval::GetAge(input1, input2); + } else { + mask.SetInvalid(idx); + return interval_t(); + } + }); +} + +ScalarFunctionSet AgeFun::GetFunctions() { + ScalarFunctionSet age("age"); + age.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::INTERVAL, AgeFunctionStandard)); + age.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, LogicalType::INTERVAL, AgeFunction)); + return age; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/current.cpp b/src/duckdb/extension/core_functions/scalar/date/current.cpp new file mode 100644 index 00000000..3d25ee80 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/date/current.cpp @@ -0,0 +1,29 @@ +#include "core_functions/scalar/date_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/transaction/meta_transaction.hpp" + +namespace duckdb { + +static timestamp_t GetTransactionTimestamp(ExpressionState &state) { + return MetaTransaction::Get(state.GetContext()).start_timestamp; +} + +static void CurrentTimestampFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 0); + auto ts = GetTransactionTimestamp(state); + auto val = Value::TIMESTAMPTZ(timestamp_tz_t(ts)); + result.Reference(val); +} + +ScalarFunction GetCurrentTimestampFun::GetFunction() { + ScalarFunction current_timestamp({}, LogicalType::TIMESTAMP_TZ, CurrentTimestampFunction); + 
+	current_timestamp.stability = FunctionStability::CONSISTENT_WITHIN_QUERY;
+	return current_timestamp;
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/date/date_diff.cpp b/src/duckdb/extension/core_functions/scalar/date/date_diff.cpp
new file mode 100644
index 00000000..c0e4ba1d
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/date/date_diff.cpp
@@ -0,0 +1,454 @@
+#include "core_functions/scalar/date_functions.hpp"
+#include "duckdb/common/enums/date_part_specifier.hpp"
+#include "duckdb/common/exception.hpp"
+#include "duckdb/common/operator/subtract.hpp"
+#include "duckdb/common/types/date.hpp"
+#include "duckdb/common/types/interval.hpp"
+#include "duckdb/common/types/time.hpp"
+#include "duckdb/common/types/timestamp.hpp"
+#include "duckdb/common/vector_operations/ternary_executor.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/common/string_util.hpp"
+
+namespace duckdb {
+
+// This function is an implementation of the "period-crossing" date difference function from T-SQL
+// https://docs.microsoft.com/en-us/sql/t-sql/functions/datediff-transact-sql?view=sql-server-ver15
+struct DateDiff {
+	template <typename TA, typename TB, typename TR, typename OP>
+	static inline void BinaryExecute(Vector &left, Vector &right, Vector &result, idx_t count) {
+		BinaryExecutor::ExecuteWithNulls<TA, TB, TR>(
+		    left, right, result, count, [&](TA startdate, TB enddate, ValidityMask &mask, idx_t idx) {
+			    if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) {
+				    return OP::template Operation<TA, TB, TR>(startdate, enddate);
+			    } else {
+				    mask.SetInvalid(idx);
+				    return TR();
+			    }
+		    });
+	}
+
+	// We need to truncate down, not towards 0: e.g. Truncate(-1, 1000) must be -1 rather than 0,
+	// so that a value just before the epoch still lands in the previous bucket
+	static inline int64_t Truncate(int64_t value, int64_t units) {
+		return (value + (value < 0)) / units - (value < 0);
+	}
+	static inline int64_t Diff(int64_t start, int64_t end, int64_t units) {
+		return Truncate(end, units) - Truncate(start, units);
+	}
+
+	struct YearOperator {
+		template <class TA, class TB, class TR>
+		static inline TR Operation(TA startdate, TB enddate) {
+			return Date::ExtractYear(enddate) - Date::ExtractYear(startdate);
+		}
+	};
+
+	struct MonthOperator {
+		template <class TA, class TB, class TR>
+		static inline TR Operation(TA startdate, TB enddate) {
+			int32_t start_year, start_month, start_day;
+			Date::Convert(startdate, start_year, start_month, start_day);
+			int32_t end_year, end_month, end_day;
+			Date::Convert(enddate, end_year, end_month, end_day);
+
+			return (end_year * 12 + end_month - 1) - (start_year * 12 + start_month - 1);
+		}
+	};
+
+	struct DayOperator {
+		template <class TA, class TB, class TR>
+		static inline TR Operation(TA startdate, TB enddate) {
+			return TR(Date::EpochDays(enddate)) - TR(Date::EpochDays(startdate));
+		}
+	};
+
+	struct DecadeOperator {
+		template <class TA, class TB, class TR>
+		static inline TR Operation(TA startdate, TB enddate) {
+			return Date::ExtractYear(enddate) / 10 - Date::ExtractYear(startdate) / 10;
+		}
+	};
+
+	struct CenturyOperator {
+		template <class TA, class TB, class TR>
+		static inline TR Operation(TA startdate, TB enddate) {
+			return Date::ExtractYear(enddate) / 100 - Date::ExtractYear(startdate) / 100;
+		}
+	};
+
+	struct MilleniumOperator {
+		template <class TA, class TB, class TR>
+		static inline TR Operation(TA startdate, TB enddate) {
+			return Date::ExtractYear(enddate) / 1000 - Date::ExtractYear(startdate) / 1000;
+		}
+	};
+
+	struct QuarterOperator {
+		template <class TA, class TB, class TR>
+		static inline TR Operation(TA startdate, TB enddate) {
+			int32_t start_year, start_month, start_day;
+			Date::Convert(startdate, start_year, start_month, start_day);
+			int32_t end_year, end_month, end_day;
+			Date::Convert(enddate, end_year, end_month, end_day);
+
+			return (end_year * 12 +
end_month - 1) / Interval::MONTHS_PER_QUARTER - + (start_year * 12 + start_month - 1) / Interval::MONTHS_PER_QUARTER; + } + }; + + struct WeekOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + // Weeks do not count Monday crossings, just distance + return (enddate.days - startdate.days) / Interval::DAYS_PER_WEEK; + } + }; + + struct ISOYearOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::ExtractISOYearNumber(enddate) - Date::ExtractISOYearNumber(startdate); + } + }; + + struct MicrosecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::EpochMicroseconds(enddate) - Date::EpochMicroseconds(startdate); + } + }; + + struct MillisecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::EpochMicroseconds(enddate) / Interval::MICROS_PER_MSEC - + Date::EpochMicroseconds(startdate) / Interval::MICROS_PER_MSEC; + } + }; + + struct SecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::Epoch(enddate) - Date::Epoch(startdate); + } + }; + + struct MinutesOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::Epoch(enddate) / Interval::SECS_PER_MINUTE - + Date::Epoch(startdate) / Interval::SECS_PER_MINUTE; + } + }; + + struct HoursOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return Date::Epoch(enddate) / Interval::SECS_PER_HOUR - Date::Epoch(startdate) / Interval::SECS_PER_HOUR; + } + }; +}; + +// TIMESTAMP specialisations +template <> +int64_t DateDiff::YearOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return YearOperator::Operation(Timestamp::GetDate(startdate), Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::MonthOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return MonthOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::DayOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return DayOperator::Operation(Timestamp::GetDate(startdate), Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::DecadeOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return DecadeOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::CenturyOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return CenturyOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::MilleniumOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return MilleniumOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::QuarterOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return QuarterOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::WeekOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return WeekOperator::Operation(Timestamp::GetDate(startdate), Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::ISOYearOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + return ISOYearOperator::Operation(Timestamp::GetDate(startdate), + Timestamp::GetDate(enddate)); +} + +template <> +int64_t DateDiff::MicrosecondsOperator::Operation(timestamp_t 
startdate, timestamp_t enddate) { + const auto start = Timestamp::GetEpochMicroSeconds(startdate); + const auto end = Timestamp::GetEpochMicroSeconds(enddate); + return SubtractOperatorOverflowCheck::Operation(end, start); +} + +template <> +int64_t DateDiff::MillisecondsOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + D_ASSERT(Timestamp::IsFinite(startdate)); + D_ASSERT(Timestamp::IsFinite(enddate)); + return Diff(startdate.value, enddate.value, Interval::MICROS_PER_MSEC); +} + +template <> +int64_t DateDiff::SecondsOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + D_ASSERT(Timestamp::IsFinite(startdate)); + D_ASSERT(Timestamp::IsFinite(enddate)); + return Diff(startdate.value, enddate.value, Interval::MICROS_PER_SEC); +} + +template <> +int64_t DateDiff::MinutesOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + D_ASSERT(Timestamp::IsFinite(startdate)); + D_ASSERT(Timestamp::IsFinite(enddate)); + return Diff(startdate.value, enddate.value, Interval::MICROS_PER_MINUTE); +} + +template <> +int64_t DateDiff::HoursOperator::Operation(timestamp_t startdate, timestamp_t enddate) { + D_ASSERT(Timestamp::IsFinite(startdate)); + D_ASSERT(Timestamp::IsFinite(enddate)); + return Diff(startdate.value, enddate.value, Interval::MICROS_PER_HOUR); +} + +// TIME specialisations +template <> +int64_t DateDiff::YearOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"year\" not recognized"); +} + +template <> +int64_t DateDiff::MonthOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"month\" not recognized"); +} + +template <> +int64_t DateDiff::DayOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"day\" not recognized"); +} + +template <> +int64_t DateDiff::DecadeOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"decade\" not recognized"); +} + +template <> +int64_t DateDiff::CenturyOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"century\" not recognized"); +} + +template <> +int64_t DateDiff::MilleniumOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"millennium\" not recognized"); +} + +template <> +int64_t DateDiff::QuarterOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"quarter\" not recognized"); +} + +template <> +int64_t DateDiff::WeekOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"week\" not recognized"); +} + +template <> +int64_t DateDiff::ISOYearOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"isoyear\" not recognized"); +} + +template <> +int64_t DateDiff::MicrosecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros - startdate.micros; +} + +template <> +int64_t DateDiff::MillisecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros / Interval::MICROS_PER_MSEC - startdate.micros / Interval::MICROS_PER_MSEC; +} + +template <> +int64_t DateDiff::SecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros / Interval::MICROS_PER_SEC - startdate.micros / Interval::MICROS_PER_SEC; +} + +template <> +int64_t 
DateDiff::MinutesOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros / Interval::MICROS_PER_MINUTE - startdate.micros / Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DateDiff::HoursOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros / Interval::MICROS_PER_HOUR - startdate.micros / Interval::MICROS_PER_HOUR; +} + +template +static int64_t DifferenceDates(DatePartSpecifier type, TA startdate, TB enddate) { + switch (type) { + case DatePartSpecifier::YEAR: + return DateDiff::YearOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MONTH: + return DateDiff::MonthOperator::template Operation(startdate, enddate); + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + return DateDiff::DayOperator::template Operation(startdate, enddate); + case DatePartSpecifier::DECADE: + return DateDiff::DecadeOperator::template Operation(startdate, enddate); + case DatePartSpecifier::CENTURY: + return DateDiff::CenturyOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MILLENNIUM: + return DateDiff::MilleniumOperator::template Operation(startdate, enddate); + case DatePartSpecifier::QUARTER: + return DateDiff::QuarterOperator::template Operation(startdate, enddate); + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + return DateDiff::WeekOperator::template Operation(startdate, enddate); + case DatePartSpecifier::ISOYEAR: + return DateDiff::ISOYearOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MICROSECONDS: + return DateDiff::MicrosecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MILLISECONDS: + return DateDiff::MillisecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::SECOND: + case DatePartSpecifier::EPOCH: + return DateDiff::SecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MINUTE: + return DateDiff::MinutesOperator::template Operation(startdate, enddate); + case DatePartSpecifier::HOUR: + return DateDiff::HoursOperator::template Operation(startdate, enddate); + default: + throw NotImplementedException("Specifier type not implemented for DATEDIFF"); + } +} + +struct DateDiffTernaryOperator { + template + static inline TR Operation(TS part, TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { + if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { + return DifferenceDates(GetDatePartSpecifier(part.GetString()), startdate, enddate); + } else { + mask.SetInvalid(idx); + return TR(); + } + } +}; + +template +static void DateDiffBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, Vector &result, idx_t count) { + switch (type) { + case DatePartSpecifier::YEAR: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MONTH: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::DECADE: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::CENTURY: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MILLENNIUM: + DateDiff::BinaryExecute(left, right, result, 
count); + break; + case DatePartSpecifier::QUARTER: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::ISOYEAR: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MICROSECONDS: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MILLISECONDS: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::SECOND: + case DatePartSpecifier::EPOCH: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MINUTE: + DateDiff::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::HOUR: + DateDiff::BinaryExecute(left, right, result, count); + break; + default: + throw NotImplementedException("Specifier type not implemented for DATEDIFF"); + } +} + +template +static void DateDiffFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3); + auto &part_arg = args.data[0]; + auto &start_arg = args.data[1]; + auto &end_arg = args.data[2]; + + if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { + // Common case of constant part. + if (ConstantVector::IsNull(part_arg)) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); + DateDiffBinaryExecutor(type, start_arg, end_arg, result, args.size()); + } + } else { + TernaryExecutor::ExecuteWithNulls( + part_arg, start_arg, end_arg, result, args.size(), + DateDiffTernaryOperator::Operation); + } +} + +ScalarFunctionSet DateDiffFun::GetFunctions() { + ScalarFunctionSet date_diff("date_diff"); + date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE, LogicalType::DATE}, + LogicalType::BIGINT, DateDiffFunction)); + date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, + LogicalType::BIGINT, DateDiffFunction)); + date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME, LogicalType::TIME}, + LogicalType::BIGINT, DateDiffFunction)); + return date_diff; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp new file mode 100644 index 00000000..1aeb4550 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp @@ -0,0 +1,2263 @@ +#include "core_functions/scalar/date_functions.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/exception/conversion_exception.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/date_lookup_cache.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +DatePartSpecifier GetDateTypePartSpecifier(const string 
&specifier, LogicalType &type) { + const auto part = GetDatePartSpecifier(specifier); + switch (type.id()) { + case LogicalType::TIMESTAMP: + case LogicalType::TIMESTAMP_TZ: + return part; + case LogicalType::DATE: + switch (part) { + case DatePartSpecifier::YEAR: + case DatePartSpecifier::MONTH: + case DatePartSpecifier::DAY: + case DatePartSpecifier::DECADE: + case DatePartSpecifier::CENTURY: + case DatePartSpecifier::MILLENNIUM: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::ISOYEAR: + case DatePartSpecifier::WEEK: + case DatePartSpecifier::QUARTER: + case DatePartSpecifier::DOY: + case DatePartSpecifier::YEARWEEK: + case DatePartSpecifier::ERA: + case DatePartSpecifier::EPOCH: + case DatePartSpecifier::JULIAN_DAY: + return part; + default: + break; + } + break; + case LogicalType::TIME: + case LogicalType::TIME_TZ: + switch (part) { + case DatePartSpecifier::MICROSECONDS: + case DatePartSpecifier::MILLISECONDS: + case DatePartSpecifier::SECOND: + case DatePartSpecifier::MINUTE: + case DatePartSpecifier::HOUR: + case DatePartSpecifier::EPOCH: + case DatePartSpecifier::TIMEZONE: + case DatePartSpecifier::TIMEZONE_HOUR: + case DatePartSpecifier::TIMEZONE_MINUTE: + return part; + default: + break; + } + break; + case LogicalType::INTERVAL: + switch (part) { + case DatePartSpecifier::YEAR: + case DatePartSpecifier::MONTH: + case DatePartSpecifier::DAY: + case DatePartSpecifier::DECADE: + case DatePartSpecifier::CENTURY: + case DatePartSpecifier::QUARTER: + case DatePartSpecifier::MILLENNIUM: + case DatePartSpecifier::MICROSECONDS: + case DatePartSpecifier::MILLISECONDS: + case DatePartSpecifier::SECOND: + case DatePartSpecifier::MINUTE: + case DatePartSpecifier::HOUR: + case DatePartSpecifier::EPOCH: + return part; + default: + break; + } + break; + default: + break; + } + + throw NotImplementedException("\"%s\" units \"%s\" not recognized", EnumUtil::ToString(type.id()), specifier); +} + +template +static unique_ptr PropagateSimpleDatePartStatistics(vector &child_stats) { + // we can always propagate simple date part statistics + // since the min and max can never exceed these bounds + auto result = NumericStats::CreateEmpty(LogicalType::BIGINT); + result.CopyValidity(child_stats[0]); + NumericStats::SetMin(result, Value::BIGINT(MIN)); + NumericStats::SetMax(result, Value::BIGINT(MAX)); + return result.ToUnique(); +} + +template +struct DateCacheLocalState : public FunctionLocalState { + explicit DateCacheLocalState() { + } + + DateLookupCache cache; +}; + +template +unique_ptr InitDateCacheLocalState(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data) { + return make_uniq>(); +} + +struct DatePart { + template + static unique_ptr PropagateDatePartStatistics(vector &child_stats, + const LogicalType &stats_type = LogicalType::BIGINT) { + // we can only propagate complex date part stats if the child has stats + auto &nstats = child_stats[0]; + if (!NumericStats::HasMinMax(nstats)) { + return nullptr; + } + // run the operator on both the min and the max, this gives us the [min, max] bound + auto min = NumericStats::GetMin(nstats); + auto max = NumericStats::GetMax(nstats); + if (min > max) { + return nullptr; + } + // Infinities prevent us from computing generic ranges + if (!Value::IsFinite(min) || !Value::IsFinite(max)) { + return nullptr; + } + TR min_part = OP::template Operation(min); + TR max_part = OP::template Operation(max); + auto result = NumericStats::CreateEmpty(stats_type); + 
NumericStats::SetMin(result, Value(min_part)); + NumericStats::SetMax(result, Value(max_part)); + result.CopyValidity(child_stats[0]); + return result.ToUnique(); + } + + template + struct PartOperator { + template + static inline TR Operation(TA input, ValidityMask &mask, idx_t idx, void *dataptr) { + if (Value::IsFinite(input)) { + return OP::template Operation(input); + } else { + mask.SetInvalid(idx); + return TR(); + } + } + }; + + template + static void UnaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() >= 1); + using IOP = PartOperator; + UnaryExecutor::GenericExecute(input.data[0], result, input.size(), nullptr, true); + } + + struct YearOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractYear(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct MonthOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractMonth(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + // min/max of month operator is [1, 12] + return PropagateSimpleDatePartStatistics<1, 12>(input.child_stats); + } + }; + + struct DayOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractDay(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + // min/max of day operator is [1, 31] + return PropagateSimpleDatePartStatistics<1, 31>(input.child_stats); + } + }; + + struct DecadeOperator { + // From the PG docs: "The year field divided by 10" + template + static inline TR DecadeFromYear(TR yyyy) { + return yyyy / 10; + } + + template + static inline TR Operation(TA input) { + return DecadeFromYear(YearOperator::Operation(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct CenturyOperator { + // From the PG docs: + // "The first century starts at 0001-01-01 00:00:00 AD, although they did not know it at the time. + // This definition applies to all Gregorian calendar countries. + // There is no century number 0, you go from -1 century to 1 century. + // If you disagree with this, please write your complaint to: Pope, Cathedral Saint-Peter of Roma, Vatican." + // (To be fair, His Holiness had nothing to do with this - + // it was the lack of zero in the counting systems of the time...) 
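+	// Concretely: CenturyFromYear(2001) = ((2001 - 1) / 100) + 1 = 21,
+	// CenturyFromYear(2000) = ((2000 - 1) / 100) + 1 = 20, and
+	// CenturyFromYear(-1) = (-1 / 100) - 1 = -1, so no year maps to century 0.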
+ template + static inline TR CenturyFromYear(TR yyyy) { + if (yyyy > 0) { + return ((yyyy - 1) / 100) + 1; + } else { + return (yyyy / 100) - 1; + } + } + + template + static inline TR Operation(TA input) { + return CenturyFromYear(YearOperator::Operation(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct MillenniumOperator { + // See the century comment + template + static inline TR MillenniumFromYear(TR yyyy) { + if (yyyy > 0) { + return ((yyyy - 1) / 1000) + 1; + } else { + return (yyyy / 1000) - 1; + } + } + + template + static inline TR Operation(TA input) { + return MillenniumFromYear(YearOperator::Operation(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct QuarterOperator { + template + static inline TR QuarterFromMonth(TR mm) { + return (mm - 1) / Interval::MONTHS_PER_QUARTER + 1; + } + + template + static inline TR Operation(TA input) { + return QuarterFromMonth(Date::ExtractMonth(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + // min/max of quarter operator is [1, 4] + return PropagateSimpleDatePartStatistics<1, 4>(input.child_stats); + } + }; + + struct DayOfWeekOperator { + template + static inline TR DayOfWeekFromISO(TR isodow) { + // day of the week (Sunday = 0, Saturday = 6) + // turn sunday into 0 by doing mod 7 + return isodow % 7; + } + + template + static inline TR Operation(TA input) { + return DayOfWeekFromISO(Date::ExtractISODayOfTheWeek(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 6>(input.child_stats); + } + }; + + struct ISODayOfWeekOperator { + template + static inline TR Operation(TA input) { + // isodow (Monday = 1, Sunday = 7) + return Date::ExtractISODayOfTheWeek(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<1, 7>(input.child_stats); + } + }; + + struct DayOfYearOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractDayOfTheYear(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<1, 366>(input.child_stats); + } + }; + + struct WeekOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractISOWeekNumber(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<1, 54>(input.child_stats); + } + }; + + struct ISOYearOperator { + template + static inline TR Operation(TA input) { + return Date::ExtractISOYearNumber(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct YearWeekOperator { + template + static inline TR YearWeekFromParts(TR yyyy, TR ww) { + return yyyy * 100 + ((yyyy > 0) ? 
ww : -ww); + } + + template + static inline TR Operation(TA input) { + int32_t yyyy, ww; + Date::ExtractISOYearWeek(input, yyyy, ww); + return YearWeekFromParts(yyyy, ww); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct EpochNanosecondsOperator { + template + static inline TR Operation(TA input) { + return Timestamp::GetEpochNanoSeconds(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct EpochMicrosecondsOperator { + template + static inline TR Operation(TA input) { + return Timestamp::GetEpochMicroSeconds(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + }; + + struct EpochMillisOperator { + template + static inline TR Operation(TA input) { + return Cast::Operation(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats); + } + + static void Inverse(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 1); + + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](int64_t input) { + // millisecond amounts provided to epoch_ms should never be considered infinite + // instead such values will just throw when converted to microseconds + return Timestamp::FromEpochMsPossiblyInfinite(input); + }); + } + }; + + struct NanosecondsOperator { + template + static inline TR Operation(TA input) { + return MicrosecondsOperator::Operation(input) * Interval::NANOS_PER_MICRO; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 60000000000>(input.child_stats); + } + }; + + struct MicrosecondsOperator { + template + static inline TR Operation(TA input) { + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 60000000>(input.child_stats); + } + }; + + struct MillisecondsOperator { + template + static inline TR Operation(TA input) { + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 60000>(input.child_stats); + } + }; + + struct SecondsOperator { + template + static inline TR Operation(TA input) { + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 60>(input.child_stats); + } + }; + + struct MinutesOperator { + template + static inline TR Operation(TA input) { + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 60>(input.child_stats); + } + }; + + struct HoursOperator { + template + static inline TR Operation(TA input) { + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 24>(input.child_stats); + } + }; + + struct 
EpochOperator { + template + static inline TR Operation(TA input) { + return TR(Date::Epoch(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats, LogicalType::DOUBLE); + } + }; + + struct EraOperator { + template + static inline TR EraFromYear(TR yyyy) { + return yyyy > 0 ? 1 : 0; + } + + template + static inline TR Operation(TA input) { + return EraFromYear(Date::ExtractYear(input)); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 1>(input.child_stats); + } + }; + + struct TimezoneOperator { + template + static inline TR Operation(TA input) { + // Regular timestamps are UTC. + return 0; + } + + template + static TR Operation(TA interval, TB timetz) { + auto time = Time::NormalizeTimeTZ(timetz); + date_t date(0); + time = Interval::Add(time, interval, date); + auto offset = UnsafeNumericCast(interval.micros / Interval::MICROS_PER_SEC); + return TR(time, offset); + } + + template + static void BinaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 2); + auto &offset = input.data[0]; + auto &timetz = input.data[1]; + + auto func = DatePart::TimezoneOperator::Operation; + BinaryExecutor::Execute(offset, timetz, result, input.size(), func); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 0>(input.child_stats); + } + }; + + struct TimezoneHourOperator { + template + static inline TR Operation(TA input) { + // Regular timestamps are UTC. + return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 0>(input.child_stats); + } + }; + + struct TimezoneMinuteOperator { + template + static inline TR Operation(TA input) { + // Regular timestamps are UTC. 
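+			// As with TIMEZONE and TIMEZONE_HOUR above, this part is therefore a constant 0
+			// for naive inputs, e.g. date_part('timezone_minute', TIMESTAMP '2000-01-01')
+			// yields 0; only the dtime_tz_t specialisations later in this file carry a real offset.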
+ return 0; + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateSimpleDatePartStatistics<0, 0>(input.child_stats); + } + }; + + struct JulianDayOperator { + template + static inline TR Operation(TA input) { + return Timestamp::GetJulianDay(input); + } + + template + static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { + return PropagateDatePartStatistics(input.child_stats, LogicalType::DOUBLE); + } + }; + + struct StructOperator { + using part_codes_t = vector; + using part_mask_t = uint64_t; + + enum MaskBits : uint8_t { + YMD = 1 << 0, + DOW = 1 << 1, + DOY = 1 << 2, + EPOCH = 1 << 3, + TIME = 1 << 4, + ZONE = 1 << 5, + ISO = 1 << 6, + JD = 1 << 7 + }; + + static part_mask_t GetMask(const part_codes_t &part_codes) { + part_mask_t mask = 0; + for (const auto &part_code : part_codes) { + switch (part_code) { + case DatePartSpecifier::YEAR: + case DatePartSpecifier::MONTH: + case DatePartSpecifier::DAY: + case DatePartSpecifier::DECADE: + case DatePartSpecifier::CENTURY: + case DatePartSpecifier::MILLENNIUM: + case DatePartSpecifier::QUARTER: + case DatePartSpecifier::ERA: + mask |= YMD; + break; + case DatePartSpecifier::YEARWEEK: + case DatePartSpecifier::WEEK: + case DatePartSpecifier::ISOYEAR: + mask |= ISO; + break; + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + mask |= DOW; + break; + case DatePartSpecifier::DOY: + mask |= DOY; + break; + case DatePartSpecifier::EPOCH: + mask |= EPOCH; + break; + case DatePartSpecifier::JULIAN_DAY: + mask |= JD; + break; + case DatePartSpecifier::MICROSECONDS: + case DatePartSpecifier::MILLISECONDS: + case DatePartSpecifier::SECOND: + case DatePartSpecifier::MINUTE: + case DatePartSpecifier::HOUR: + mask |= TIME; + break; + case DatePartSpecifier::TIMEZONE: + case DatePartSpecifier::TIMEZONE_HOUR: + case DatePartSpecifier::TIMEZONE_MINUTE: + mask |= ZONE; + break; + case DatePartSpecifier::INVALID: + throw InternalException("Invalid DatePartSpecifier for STRUCT mask!"); + } + } + return mask; + } + + template + static inline P HasPartValue(vector
<P>
part_values, DatePartSpecifier part) { + auto idx = size_t(part); + if (IsBigintDatepart(part)) { + return part_values[idx - size_t(DatePartSpecifier::BEGIN_BIGINT)]; + } else { + return part_values[idx - size_t(DatePartSpecifier::BEGIN_DOUBLE)]; + } + } + + using bigint_vec = vector; + using double_vec = vector; + + template + static inline void Operation(bigint_vec &bigint_values, double_vec &double_values, const TA &input, + const idx_t idx, const part_mask_t mask) { + int64_t *bigint_data; + // YMD calculations + int32_t yyyy = 1970; + int32_t mm = 0; + int32_t dd = 1; + if (mask & YMD) { + Date::Convert(input, yyyy, mm, dd); + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::YEAR); + if (bigint_data) { + bigint_data[idx] = yyyy; + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::MONTH); + if (bigint_data) { + bigint_data[idx] = mm; + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DAY); + if (bigint_data) { + bigint_data[idx] = dd; + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DECADE); + if (bigint_data) { + bigint_data[idx] = DecadeOperator::DecadeFromYear(yyyy); + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::CENTURY); + if (bigint_data) { + bigint_data[idx] = CenturyOperator::CenturyFromYear(yyyy); + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::MILLENNIUM); + if (bigint_data) { + bigint_data[idx] = MillenniumOperator::MillenniumFromYear(yyyy); + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::QUARTER); + if (bigint_data) { + bigint_data[idx] = QuarterOperator::QuarterFromMonth(mm); + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ERA); + if (bigint_data) { + bigint_data[idx] = EraOperator::EraFromYear(yyyy); + } + } + + // Week calculations + if (mask & DOW) { + auto isodow = Date::ExtractISODayOfTheWeek(input); + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DOW); + if (bigint_data) { + bigint_data[idx] = DayOfWeekOperator::DayOfWeekFromISO(isodow); + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ISODOW); + if (bigint_data) { + bigint_data[idx] = isodow; + } + } + + // ISO calculations + if (mask & ISO) { + int32_t ww = 0; + int32_t iyyy = 0; + Date::ExtractISOYearWeek(input, iyyy, ww); + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::WEEK); + if (bigint_data) { + bigint_data[idx] = ww; + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ISOYEAR); + if (bigint_data) { + bigint_data[idx] = iyyy; + } + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::YEARWEEK); + if (bigint_data) { + bigint_data[idx] = YearWeekOperator::YearWeekFromParts(iyyy, ww); + } + } + + if (mask & EPOCH) { + auto double_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); + if (double_data) { + double_data[idx] = double(Date::Epoch(input)); + } + } + if (mask & DOY) { + bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DOY); + if (bigint_data) { + bigint_data[idx] = Date::ExtractDayOfTheYear(input); + } + } + if (mask & JD) { + auto double_data = HasPartValue(double_values, DatePartSpecifier::JULIAN_DAY); + if (double_data) { + double_data[idx] = double(Date::ExtractJulianDay(input)); + } + } + } + }; +}; + +template +static void DatePartCachedFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast>(); + UnaryExecutor::ExecuteWithNulls( + args.data[0], result, args.size(), + [&](T input, 
ValidityMask &mask, idx_t idx) { return lstate.cache.ExtractElement(input, mask, idx); }); +} + +template <> +int64_t DatePart::YearOperator::Operation(timestamp_t input) { + return YearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::YearOperator::Operation(interval_t input) { + return input.months / Interval::MONTHS_PER_YEAR; +} + +template <> +int64_t DatePart::YearOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"year\" not recognized"); +} + +template <> +int64_t DatePart::YearOperator::Operation(dtime_tz_t input) { + return YearOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::MonthOperator::Operation(timestamp_t input) { + return MonthOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::MonthOperator::Operation(interval_t input) { + return input.months % Interval::MONTHS_PER_YEAR; +} + +template <> +int64_t DatePart::MonthOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"month\" not recognized"); +} + +template <> +int64_t DatePart::MonthOperator::Operation(dtime_tz_t input) { + return MonthOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::DayOperator::Operation(timestamp_t input) { + return DayOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::DayOperator::Operation(interval_t input) { + return input.days; +} + +template <> +int64_t DatePart::DayOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"day\" not recognized"); +} + +template <> +int64_t DatePart::DayOperator::Operation(dtime_tz_t input) { + return DayOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::DecadeOperator::Operation(interval_t input) { + return input.months / Interval::MONTHS_PER_DECADE; +} + +template <> +int64_t DatePart::DecadeOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"decade\" not recognized"); +} + +template <> +int64_t DatePart::DecadeOperator::Operation(dtime_tz_t input) { + return DecadeOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::CenturyOperator::Operation(interval_t input) { + return input.months / Interval::MONTHS_PER_CENTURY; +} + +template <> +int64_t DatePart::CenturyOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"century\" not recognized"); +} + +template <> +int64_t DatePart::CenturyOperator::Operation(dtime_tz_t input) { + return CenturyOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::MillenniumOperator::Operation(interval_t input) { + return input.months / Interval::MONTHS_PER_MILLENIUM; +} + +template <> +int64_t DatePart::MillenniumOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"millennium\" not recognized"); +} + +template <> +int64_t DatePart::MillenniumOperator::Operation(dtime_tz_t input) { + return MillenniumOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::QuarterOperator::Operation(timestamp_t input) { + return QuarterOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::QuarterOperator::Operation(interval_t input) { + return MonthOperator::Operation(input) / Interval::MONTHS_PER_QUARTER + 1; +} + +template <> +int64_t DatePart::QuarterOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"quarter\" not recognized"); +} + +template <> +int64_t 
DatePart::QuarterOperator::Operation(dtime_tz_t input) { + return QuarterOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::DayOfWeekOperator::Operation(timestamp_t input) { + return DayOfWeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::DayOfWeekOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"dow\" not recognized"); +} + +template <> +int64_t DatePart::DayOfWeekOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"dow\" not recognized"); +} + +template <> +int64_t DatePart::DayOfWeekOperator::Operation(dtime_tz_t input) { + return DayOfWeekOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::ISODayOfWeekOperator::Operation(timestamp_t input) { + return ISODayOfWeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::ISODayOfWeekOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"isodow\" not recognized"); +} + +template <> +int64_t DatePart::ISODayOfWeekOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"isodow\" not recognized"); +} + +template <> +int64_t DatePart::ISODayOfWeekOperator::Operation(dtime_tz_t input) { + return ISODayOfWeekOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::DayOfYearOperator::Operation(timestamp_t input) { + return DayOfYearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::DayOfYearOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"doy\" not recognized"); +} + +template <> +int64_t DatePart::DayOfYearOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"doy\" not recognized"); +} + +template <> +int64_t DatePart::DayOfYearOperator::Operation(dtime_tz_t input) { + return DayOfYearOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::WeekOperator::Operation(timestamp_t input) { + return WeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::WeekOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"week\" not recognized"); +} + +template <> +int64_t DatePart::WeekOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"week\" not recognized"); +} + +template <> +int64_t DatePart::WeekOperator::Operation(dtime_tz_t input) { + return WeekOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::ISOYearOperator::Operation(timestamp_t input) { + return ISOYearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::ISOYearOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"isoyear\" not recognized"); +} + +template <> +int64_t DatePart::ISOYearOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"isoyear\" not recognized"); +} + +template <> +int64_t DatePart::ISOYearOperator::Operation(dtime_tz_t input) { + return ISOYearOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::YearWeekOperator::Operation(timestamp_t input) { + return YearWeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::YearWeekOperator::Operation(interval_t input) { + const auto yyyy = YearOperator::Operation(input); + const auto ww = WeekOperator::Operation(input); + return YearWeekOperator::YearWeekFromParts(yyyy, ww); +} + +template 
<> +int64_t DatePart::YearWeekOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"yearweek\" not recognized"); +} + +template <> +int64_t DatePart::YearWeekOperator::Operation(dtime_tz_t input) { + return YearWeekOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::EpochNanosecondsOperator::Operation(timestamp_t input) { + D_ASSERT(Timestamp::IsFinite(input)); + return Timestamp::GetEpochNanoSeconds(input); +} + +template <> +int64_t DatePart::EpochNanosecondsOperator::Operation(date_t input) { + D_ASSERT(Date::IsFinite(input)); + return Date::EpochNanoseconds(input); +} + +template <> +int64_t DatePart::EpochNanosecondsOperator::Operation(interval_t input) { + return Interval::GetNanoseconds(input); +} + +template <> +int64_t DatePart::EpochNanosecondsOperator::Operation(dtime_t input) { + return input.micros * Interval::NANOS_PER_MICRO; +} + +template <> +int64_t DatePart::EpochNanosecondsOperator::Operation(dtime_tz_t input) { + return DatePart::EpochNanosecondsOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::EpochMicrosecondsOperator::Operation(date_t input) { + return Date::EpochMicroseconds(input); +} + +template <> +int64_t DatePart::EpochMicrosecondsOperator::Operation(interval_t input) { + return Interval::GetMicro(input); +} + +template <> +int64_t DatePart::EpochMillisOperator::Operation(timestamp_t input) { + D_ASSERT(Timestamp::IsFinite(input)); + return Cast::Operation(input).value; +} + +template <> +int64_t DatePart::EpochMicrosecondsOperator::Operation(dtime_t input) { + return input.micros; +} + +template <> +int64_t DatePart::EpochMicrosecondsOperator::Operation(dtime_tz_t input) { + return DatePart::EpochMicrosecondsOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::EpochMillisOperator::Operation(date_t input) { + return Date::EpochMilliseconds(input); +} + +template <> +int64_t DatePart::EpochMillisOperator::Operation(interval_t input) { + return Interval::GetMilli(input); +} + +template <> +int64_t DatePart::EpochMillisOperator::Operation(dtime_t input) { + return input.micros / Interval::MICROS_PER_MSEC; +} + +template <> +int64_t DatePart::EpochMillisOperator::Operation(dtime_tz_t input) { + return DatePart::EpochMillisOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::NanosecondsOperator::Operation(timestamp_ns_t input) { + if (!Timestamp::IsFinite(input)) { + throw ConversionException("Can't get nanoseconds of infinite TIMESTAMP"); + } + date_t date; + dtime_t time; + int32_t nanos; + Timestamp::Convert(input, date, time, nanos); + // remove everything but the second & nanosecond part + return (time.micros % Interval::MICROS_PER_MINUTE) * Interval::NANOS_PER_MICRO + nanos; +} + +template <> +int64_t DatePart::MicrosecondsOperator::Operation(timestamp_t input) { + D_ASSERT(Timestamp::IsFinite(input)); + auto time = Timestamp::GetTime(input); + // remove everything but the second & microsecond part + return time.micros % Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::MicrosecondsOperator::Operation(interval_t input) { + // remove everything but the second & microsecond part + return input.micros % Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::MicrosecondsOperator::Operation(dtime_t input) { + // remove everything but the second & microsecond part + return input.micros % Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::MicrosecondsOperator::Operation(dtime_tz_t input) { + return 
DatePart::MicrosecondsOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::MillisecondsOperator::Operation(timestamp_t input) { + D_ASSERT(Timestamp::IsFinite(input)); + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; +} + +template <> +int64_t DatePart::MillisecondsOperator::Operation(interval_t input) { + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; +} + +template <> +int64_t DatePart::MillisecondsOperator::Operation(dtime_t input) { + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; +} + +template <> +int64_t DatePart::MillisecondsOperator::Operation(dtime_tz_t input) { + return DatePart::MillisecondsOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::SecondsOperator::Operation(timestamp_t input) { + D_ASSERT(Timestamp::IsFinite(input)); + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_SEC; +} + +template <> +int64_t DatePart::SecondsOperator::Operation(interval_t input) { + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_SEC; +} + +template <> +int64_t DatePart::SecondsOperator::Operation(dtime_t input) { + return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_SEC; +} + +template <> +int64_t DatePart::SecondsOperator::Operation(dtime_tz_t input) { + return DatePart::SecondsOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::MinutesOperator::Operation(timestamp_t input) { + D_ASSERT(Timestamp::IsFinite(input)); + auto time = Timestamp::GetTime(input); + // remove the hour part, and truncate to minutes + return (time.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::MinutesOperator::Operation(interval_t input) { + // remove the hour part, and truncate to minutes + return (input.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::MinutesOperator::Operation(dtime_t input) { + // remove the hour part, and truncate to minutes + return (input.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DatePart::MinutesOperator::Operation(dtime_tz_t input) { + return DatePart::MinutesOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::HoursOperator::Operation(timestamp_t input) { + D_ASSERT(Timestamp::IsFinite(input)); + return Timestamp::GetTime(input).micros / Interval::MICROS_PER_HOUR; +} + +template <> +int64_t DatePart::HoursOperator::Operation(interval_t input) { + return input.micros / Interval::MICROS_PER_HOUR; +} + +template <> +int64_t DatePart::HoursOperator::Operation(dtime_t input) { + return input.micros / Interval::MICROS_PER_HOUR; +} + +template <> +int64_t DatePart::HoursOperator::Operation(dtime_tz_t input) { + return DatePart::HoursOperator::Operation(input.time()); +} + +template <> +double DatePart::EpochOperator::Operation(timestamp_t input) { + D_ASSERT(Timestamp::IsFinite(input)); + return double(Timestamp::GetEpochMicroSeconds(input)) / double(Interval::MICROS_PER_SEC); +} + +template <> +double DatePart::EpochOperator::Operation(interval_t input) { + int64_t interval_years = input.months / Interval::MONTHS_PER_YEAR; + int64_t interval_days; + interval_days = Interval::DAYS_PER_YEAR * interval_years; + interval_days += Interval::DAYS_PER_MONTH * (input.months % Interval::MONTHS_PER_YEAR); + interval_days += input.days; + int64_t interval_epoch; + interval_epoch = interval_days * Interval::SECS_PER_DAY; + // we add 
0.25 days per year to sort of account for leap days + interval_epoch += interval_years * (Interval::SECS_PER_DAY / 4); + return double(interval_epoch) + double(input.micros) / double(Interval::MICROS_PER_SEC); +} + +// TODO: We can't propagate interval statistics because we can't easily compare interval_t for order. +template <> +unique_ptr DatePart::EpochOperator::PropagateStatistics(ClientContext &context, + FunctionStatisticsInput &input) { + return nullptr; +} + +template <> +double DatePart::EpochOperator::Operation(dtime_t input) { + return double(input.micros) / double(Interval::MICROS_PER_SEC); +} + +template <> +double DatePart::EpochOperator::Operation(dtime_tz_t input) { + return DatePart::EpochOperator::Operation(input.time()); +} + +template <> +unique_ptr DatePart::EpochOperator::PropagateStatistics(ClientContext &context, + FunctionStatisticsInput &input) { + auto result = NumericStats::CreateEmpty(LogicalType::DOUBLE); + result.CopyValidity(input.child_stats[0]); + NumericStats::SetMin(result, Value::DOUBLE(0)); + NumericStats::SetMax(result, Value::DOUBLE(Interval::SECS_PER_DAY)); + return result.ToUnique(); +} + +template <> +int64_t DatePart::EraOperator::Operation(timestamp_t input) { + D_ASSERT(Timestamp::IsFinite(input)); + return EraOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +int64_t DatePart::EraOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"era\" not recognized"); +} + +template <> +int64_t DatePart::EraOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"era\" not recognized"); +} + +template <> +int64_t DatePart::EraOperator::Operation(dtime_tz_t input) { + return EraOperator::Operation(input.time()); +} + +template <> +int64_t DatePart::TimezoneOperator::Operation(date_t input) { + throw NotImplementedException("\"date\" units \"timezone\" not recognized"); +} + +template <> +int64_t DatePart::TimezoneOperator::Operation(interval_t input) { + throw NotImplementedException("\"interval\" units \"timezone\" not recognized"); +} + +template <> +int64_t DatePart::TimezoneOperator::Operation(dtime_tz_t input) { + return input.offset(); +} + +template <> +int64_t DatePart::TimezoneHourOperator::Operation(date_t input) { + throw NotImplementedException("\"date\" units \"timezone_hour\" not recognized"); +} + +template <> +int64_t DatePart::TimezoneHourOperator::Operation(interval_t input) { + throw NotImplementedException("\"interval\" units \"timezone_hour\" not recognized"); +} + +template <> +int64_t DatePart::TimezoneHourOperator::Operation(dtime_tz_t input) { + return input.offset() / Interval::SECS_PER_HOUR; +} + +template <> +int64_t DatePart::TimezoneMinuteOperator::Operation(date_t input) { + throw NotImplementedException("\"date\" units \"timezone_minute\" not recognized"); +} + +template <> +int64_t DatePart::TimezoneMinuteOperator::Operation(interval_t input) { + throw NotImplementedException("\"interval\" units \"timezone_minute\" not recognized"); +} + +template <> +int64_t DatePart::TimezoneMinuteOperator::Operation(dtime_tz_t input) { + return (input.offset() / Interval::SECS_PER_MINUTE) % Interval::MINS_PER_HOUR; +} + +template <> +double DatePart::JulianDayOperator::Operation(date_t input) { + return double(Date::ExtractJulianDay(input)); +} + +template <> +double DatePart::JulianDayOperator::Operation(interval_t input) { + throw NotImplementedException("interval units \"julian\" not recognized"); +} + +template <> +double 
DatePart::JulianDayOperator::Operation(dtime_t input) { + throw NotImplementedException("\"time\" units \"julian\" not recognized"); +} + +template <> +double DatePart::JulianDayOperator::Operation(dtime_tz_t input) { + return JulianDayOperator::Operation(input.time()); +} + +template <> +void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const dtime_t &input, + const idx_t idx, const part_mask_t mask) { + int64_t *part_data; + if (mask & TIME) { + const auto micros = MicrosecondsOperator::Operation(input); + part_data = HasPartValue(bigint_values, DatePartSpecifier::MICROSECONDS); + if (part_data) { + part_data[idx] = micros; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLISECONDS); + if (part_data) { + part_data[idx] = micros / Interval::MICROS_PER_MSEC; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::SECOND); + if (part_data) { + part_data[idx] = micros / Interval::MICROS_PER_SEC; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MINUTE); + if (part_data) { + part_data[idx] = MinutesOperator::Operation(input); + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::HOUR); + if (part_data) { + part_data[idx] = HoursOperator::Operation(input); + } + } + + if (mask & EPOCH) { + auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); + if (part_data) { + part_data[idx] = EpochOperator::Operation(input); + } + } + + if (mask & ZONE) { + part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE); + if (part_data) { + part_data[idx] = 0; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_HOUR); + if (part_data) { + part_data[idx] = 0; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_MINUTE); + if (part_data) { + part_data[idx] = 0; + } + } +} + +template <> +void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const dtime_tz_t &input, + const idx_t idx, const part_mask_t mask) { + int64_t *part_data; + if (mask & TIME) { + const auto micros = MicrosecondsOperator::Operation(input); + part_data = HasPartValue(bigint_values, DatePartSpecifier::MICROSECONDS); + if (part_data) { + part_data[idx] = micros; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLISECONDS); + if (part_data) { + part_data[idx] = micros / Interval::MICROS_PER_MSEC; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::SECOND); + if (part_data) { + part_data[idx] = micros / Interval::MICROS_PER_SEC; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MINUTE); + if (part_data) { + part_data[idx] = MinutesOperator::Operation(input); + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::HOUR); + if (part_data) { + part_data[idx] = HoursOperator::Operation(input); + } + } + + if (mask & EPOCH) { + auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); + if (part_data) { + part_data[idx] = EpochOperator::Operation(input); + } + } + + if (mask & ZONE) { + part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE); + if (part_data) { + part_data[idx] = TimezoneOperator::Operation(input); + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_HOUR); + if (part_data) { + part_data[idx] = TimezoneHourOperator::Operation(input); + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_MINUTE); + if (part_data) { + part_data[idx] = TimezoneMinuteOperator::Operation(input); + } + return; + } +} + 
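+// Sketch of the struct fast path's calling convention (as implied by HasPartValue):
+// the caller supplies one output slot per possible part, BIGINT parts and DOUBLE parts
+// in separate vectors, with nullptr for parts that were not requested. For example,
+// date_part(['hour', 'epoch'], t) fills exactly two slots per row:
+//   bigint_values[<HOUR slot>][idx]  via HoursOperator
+//   double_values[<EPOCH slot>][idx] via EpochOperator
+// so a single pass over the input populates every requested struct field.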
+template <> +void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const timestamp_t &input, + const idx_t idx, const part_mask_t mask) { + D_ASSERT(Timestamp::IsFinite(input)); + date_t d; + dtime_t t; + Timestamp::Convert(input, d, t); + + // Both define epoch, and the correct value is the sum. + // So mask it out and compute it separately. + Operation(bigint_values, double_values, d, idx, mask & ~UnsafeNumericCast(EPOCH)); + Operation(bigint_values, double_values, t, idx, mask & ~UnsafeNumericCast(EPOCH)); + + if (mask & EPOCH) { + auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); + if (part_data) { + part_data[idx] = EpochOperator::Operation(input); + } + } + + if (mask & JD) { + auto part_data = HasPartValue(double_values, DatePartSpecifier::JULIAN_DAY); + if (part_data) { + part_data[idx] = JulianDayOperator::Operation(input); + } + } +} + +template <> +void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const interval_t &input, + const idx_t idx, const part_mask_t mask) { + int64_t *part_data; + if (mask & YMD) { + const auto mm = input.months % Interval::MONTHS_PER_YEAR; + part_data = HasPartValue(bigint_values, DatePartSpecifier::YEAR); + if (part_data) { + part_data[idx] = input.months / Interval::MONTHS_PER_YEAR; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MONTH); + if (part_data) { + part_data[idx] = mm; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::DAY); + if (part_data) { + part_data[idx] = input.days; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::DECADE); + if (part_data) { + part_data[idx] = input.months / Interval::MONTHS_PER_DECADE; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::CENTURY); + if (part_data) { + part_data[idx] = input.months / Interval::MONTHS_PER_CENTURY; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLENNIUM); + if (part_data) { + part_data[idx] = input.months / Interval::MONTHS_PER_MILLENIUM; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::QUARTER); + if (part_data) { + part_data[idx] = mm / Interval::MONTHS_PER_QUARTER + 1; + } + } + + if (mask & TIME) { + const auto micros = MicrosecondsOperator::Operation(input); + part_data = HasPartValue(bigint_values, DatePartSpecifier::MICROSECONDS); + if (part_data) { + part_data[idx] = micros; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLISECONDS); + if (part_data) { + part_data[idx] = micros / Interval::MICROS_PER_MSEC; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::SECOND); + if (part_data) { + part_data[idx] = micros / Interval::MICROS_PER_SEC; + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::MINUTE); + if (part_data) { + part_data[idx] = MinutesOperator::Operation(input); + } + part_data = HasPartValue(bigint_values, DatePartSpecifier::HOUR); + if (part_data) { + part_data[idx] = HoursOperator::Operation(input); + } + } + + if (mask & EPOCH) { + auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); + if (part_data) { + part_data[idx] = EpochOperator::Operation(input); + } + } +} + +template +static int64_t ExtractElement(DatePartSpecifier type, T element) { + switch (type) { + case DatePartSpecifier::YEAR: + return DatePart::YearOperator::template Operation(element); + case DatePartSpecifier::MONTH: + return DatePart::MonthOperator::template Operation(element); + case DatePartSpecifier::DAY: + return 
DatePart::DayOperator::template Operation(element); + case DatePartSpecifier::DECADE: + return DatePart::DecadeOperator::template Operation(element); + case DatePartSpecifier::CENTURY: + return DatePart::CenturyOperator::template Operation(element); + case DatePartSpecifier::MILLENNIUM: + return DatePart::MillenniumOperator::template Operation(element); + case DatePartSpecifier::QUARTER: + return DatePart::QuarterOperator::template Operation(element); + case DatePartSpecifier::DOW: + return DatePart::DayOfWeekOperator::template Operation(element); + case DatePartSpecifier::ISODOW: + return DatePart::ISODayOfWeekOperator::template Operation(element); + case DatePartSpecifier::DOY: + return DatePart::DayOfYearOperator::template Operation(element); + case DatePartSpecifier::WEEK: + return DatePart::WeekOperator::template Operation(element); + case DatePartSpecifier::ISOYEAR: + return DatePart::ISOYearOperator::template Operation(element); + case DatePartSpecifier::YEARWEEK: + return DatePart::YearWeekOperator::template Operation(element); + case DatePartSpecifier::MICROSECONDS: + return DatePart::MicrosecondsOperator::template Operation(element); + case DatePartSpecifier::MILLISECONDS: + return DatePart::MillisecondsOperator::template Operation(element); + case DatePartSpecifier::SECOND: + return DatePart::SecondsOperator::template Operation(element); + case DatePartSpecifier::MINUTE: + return DatePart::MinutesOperator::template Operation(element); + case DatePartSpecifier::HOUR: + return DatePart::HoursOperator::template Operation(element); + case DatePartSpecifier::ERA: + return DatePart::EraOperator::template Operation(element); + case DatePartSpecifier::TIMEZONE: + return DatePart::TimezoneOperator::template Operation(element); + case DatePartSpecifier::TIMEZONE_HOUR: + return DatePart::TimezoneHourOperator::template Operation(element); + case DatePartSpecifier::TIMEZONE_MINUTE: + return DatePart::TimezoneMinuteOperator::template Operation(element); + default: + throw NotImplementedException("Specifier type not implemented for DATEPART"); + } +} + +template +static void DatePartFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + auto &spec_arg = args.data[0]; + auto &date_arg = args.data[1]; + + BinaryExecutor::ExecuteWithNulls( + spec_arg, date_arg, result, args.size(), [&](string_t specifier, T date, ValidityMask &mask, idx_t idx) { + if (Value::IsFinite(date)) { + return ExtractElement(GetDatePartSpecifier(specifier.GetString()), date); + } else { + mask.SetInvalid(idx); + return int64_t(0); + } + }); +} + +static unique_ptr DatePartBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // If we are only looking for Julian Days for timestamps, + // then return doubles. 
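+ // The EPOCH specifier is rebound below in the same way: both it and + // JULIAN_DAY are replaced by unary implementations returning DOUBLE, + // while all other specifiers keep the generic BIGINT path.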
+ if (arguments[0]->HasParameter() || !arguments[0]->IsFoldable()) { + return nullptr; + } + + Value part_value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); + const auto part_name = part_value.ToString(); + switch (GetDatePartSpecifier(part_name)) { + case DatePartSpecifier::JULIAN_DAY: + arguments.erase(arguments.begin()); + bound_function.arguments.erase(bound_function.arguments.begin()); + bound_function.name = "julian"; + bound_function.return_type = LogicalType::DOUBLE; + switch (arguments[0]->return_type.id()) { + case LogicalType::TIMESTAMP: + case LogicalType::TIMESTAMP_S: + case LogicalType::TIMESTAMP_MS: + case LogicalType::TIMESTAMP_NS: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::JulianDayOperator::template PropagateStatistics; + break; + case LogicalType::DATE: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::JulianDayOperator::template PropagateStatistics; + break; + default: + throw BinderException("%s can only take DATE or TIMESTAMP arguments", bound_function.name); + } + break; + case DatePartSpecifier::EPOCH: + arguments.erase(arguments.begin()); + bound_function.arguments.erase(bound_function.arguments.begin()); + bound_function.name = "epoch"; + bound_function.return_type = LogicalType::DOUBLE; + switch (arguments[0]->return_type.id()) { + case LogicalType::TIMESTAMP: + case LogicalType::TIMESTAMP_S: + case LogicalType::TIMESTAMP_MS: + case LogicalType::TIMESTAMP_NS: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + break; + case LogicalType::DATE: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + break; + case LogicalType::INTERVAL: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + break; + case LogicalType::TIME: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + break; + case LogicalType::TIME_TZ: + bound_function.function = DatePart::UnaryFunction; + bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + break; + default: + throw BinderException("%s can only take temporal arguments", bound_function.name); + } + break; + default: + break; + } + + return nullptr; +} + +template +ScalarFunctionSet GetGenericDatePartFunction(scalar_function_t date_func, scalar_function_t ts_func, + scalar_function_t interval_func, function_statistics_t date_stats, + function_statistics_t ts_stats) { + ScalarFunctionSet operator_set; + operator_set.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::BIGINT, std::move(date_func), nullptr, + nullptr, date_stats, DATE_CACHE)); + operator_set.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BIGINT, std::move(ts_func), nullptr, + nullptr, ts_stats, DATE_CACHE)); + operator_set.AddFunction(ScalarFunction({LogicalType::INTERVAL}, LogicalType::BIGINT, std::move(interval_func))); + for (auto &func : operator_set.functions) { + BaseScalarFunction::SetReturnsError(func); + } + return operator_set; +} + +template +static ScalarFunctionSet GetDatePartFunction() { + return GetGenericDatePartFunction( + DatePart::UnaryFunction, DatePart::UnaryFunction, + ScalarFunction::UnaryFunction, OP::template PropagateStatistics, + OP::template 
PropagateStatistics); +} + +ScalarFunctionSet GetGenericTimePartFunction(const LogicalType &result_type, scalar_function_t date_func, + scalar_function_t ts_func, scalar_function_t interval_func, + scalar_function_t time_func, scalar_function_t timetz_func, + function_statistics_t date_stats, function_statistics_t ts_stats, + function_statistics_t time_stats, function_statistics_t timetz_stats) { + ScalarFunctionSet operator_set; + operator_set.AddFunction( + ScalarFunction({LogicalType::DATE}, result_type, std::move(date_func), nullptr, nullptr, date_stats)); + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP}, result_type, std::move(ts_func), nullptr, nullptr, ts_stats)); + operator_set.AddFunction(ScalarFunction({LogicalType::INTERVAL}, result_type, std::move(interval_func))); + operator_set.AddFunction( + ScalarFunction({LogicalType::TIME}, result_type, std::move(time_func), nullptr, nullptr, time_stats)); + operator_set.AddFunction( + ScalarFunction({LogicalType::TIME_TZ}, result_type, std::move(timetz_func), nullptr, nullptr, timetz_stats)); + return operator_set; +} + +template +static ScalarFunctionSet GetTimePartFunction(const LogicalType &result_type = LogicalType::BIGINT) { + return GetGenericTimePartFunction( + result_type, DatePart::UnaryFunction, DatePart::UnaryFunction, + ScalarFunction::UnaryFunction, ScalarFunction::UnaryFunction, + ScalarFunction::UnaryFunction, OP::template PropagateStatistics, + OP::template PropagateStatistics, OP::template PropagateStatistics, + OP::template PropagateStatistics); +} + +struct LastDayOperator { + template + static inline TR Operation(TA input) { + int32_t yyyy, mm, dd; + Date::Convert(input, yyyy, mm, dd); + yyyy += (mm / 12); + mm %= 12; + ++mm; + return Date::FromDate(yyyy, mm, 1) - 1; + } +}; + +template <> +date_t LastDayOperator::Operation(timestamp_t input) { + return LastDayOperator::Operation(Timestamp::GetDate(input)); +} + +struct MonthNameOperator { + template + static inline TR Operation(TA input) { + return Date::MONTH_NAMES[DatePart::MonthOperator::Operation(input) - 1]; + } +}; + +struct DayNameOperator { + template + static inline TR Operation(TA input) { + return Date::DAY_NAMES[DatePart::DayOfWeekOperator::Operation(input)]; + } +}; + +struct StructDatePart { + using part_codes_t = vector; + + struct BindData : public VariableReturnBindData { + part_codes_t part_codes; + + explicit BindData(const LogicalType &stype, const part_codes_t &part_codes_p) + : VariableReturnBindData(stype), part_codes(part_codes_p) { + } + + unique_ptr Copy() const override { + return make_uniq(stype, part_codes); + } + }; + + static unique_ptr Bind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // collect names and deconflict, construct return type + if (arguments[0]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[0]->IsFoldable()) { + throw BinderException("%s can only take constant lists of part names", bound_function.name); + } + + case_insensitive_set_t name_collision_set; + child_list_t struct_children; + part_codes_t part_codes; + + Value parts_list = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); + if (parts_list.type().id() == LogicalTypeId::LIST) { + auto &list_children = ListValue::GetChildren(parts_list); + if (list_children.empty()) { + throw BinderException("%s requires non-empty lists of part names", bound_function.name); + } + for (const auto &part_value : list_children) { + if (part_value.IsNull()) { + throw 
BinderException("NULL struct entry name in %s", bound_function.name); + } + const auto part_name = part_value.ToString(); + const auto part_code = GetDateTypePartSpecifier(part_name, arguments[1]->return_type); + if (name_collision_set.find(part_name) != name_collision_set.end()) { + throw BinderException("Duplicate struct entry name \"%s\" in %s", part_name, bound_function.name); + } + name_collision_set.insert(part_name); + part_codes.emplace_back(part_code); + const auto part_type = IsBigintDatepart(part_code) ? LogicalType::BIGINT : LogicalType::DOUBLE; + struct_children.emplace_back(make_pair(part_name, part_type)); + } + } else { + throw BinderException("%s can only take constant lists of part names", bound_function.name); + } + + Function::EraseArgument(bound_function, arguments, 0); + bound_function.return_type = LogicalType::STRUCT(struct_children); + return make_uniq(bound_function.return_type, part_codes); + } + + template + static void Function(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + D_ASSERT(args.ColumnCount() == 1); + + const auto count = args.size(); + Vector &input = args.data[0]; + + // Type counts + const auto BIGINT_COUNT = size_t(DatePartSpecifier::BEGIN_DOUBLE) - size_t(DatePartSpecifier::BEGIN_BIGINT); + const auto DOUBLE_COUNT = size_t(DatePartSpecifier::BEGIN_INVALID) - size_t(DatePartSpecifier::BEGIN_DOUBLE); + DatePart::StructOperator::bigint_vec bigint_values(BIGINT_COUNT, nullptr); + DatePart::StructOperator::double_vec double_values(DOUBLE_COUNT, nullptr); + const auto part_mask = DatePart::StructOperator::GetMask(info.part_codes); + + auto &child_entries = StructVector::GetEntries(result); + + // The first computer of a part "owns" it + // and other requestors just reference the owner + vector owners(int(DatePartSpecifier::JULIAN_DAY) + 1, child_entries.size()); + for (size_t col = 0; col < child_entries.size(); ++col) { + const auto part_index = size_t(info.part_codes[col]); + if (owners[part_index] == child_entries.size()) { + owners[part_index] = col; + } + } + + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + + if (ConstantVector::IsNull(input)) { + ConstantVector::SetNull(result, true); + } else { + ConstantVector::SetNull(result, false); + for (size_t col = 0; col < child_entries.size(); ++col) { + auto &child_entry = child_entries[col]; + ConstantVector::SetNull(*child_entry, false); + const auto part_index = size_t(info.part_codes[col]); + if (owners[part_index] == col) { + if (IsBigintDatepart(info.part_codes[col])) { + bigint_values[part_index - size_t(DatePartSpecifier::BEGIN_BIGINT)] = + ConstantVector::GetData(*child_entry); + } else { + double_values[part_index - size_t(DatePartSpecifier::BEGIN_DOUBLE)] = + ConstantVector::GetData(*child_entry); + } + } + } + auto tdata = ConstantVector::GetData(input); + if (Value::IsFinite(tdata[0])) { + DatePart::StructOperator::Operation(bigint_values, double_values, tdata[0], 0, part_mask); + } else { + for (auto &child_entry : child_entries) { + ConstantVector::SetNull(*child_entry, true); + } + } + } + } else { + UnifiedVectorFormat rdata; + input.ToUnifiedFormat(count, rdata); + + const auto &arg_valid = rdata.validity; + auto tdata = UnifiedVectorFormat::GetData(rdata); + + // Start with a valid flat vector + result.SetVectorType(VectorType::FLAT_VECTOR); + auto &res_valid = FlatVector::Validity(result); + if (res_valid.GetData()) { + 
res_valid.SetAllValid(count); + } + + // Start with valid children + for (size_t col = 0; col < child_entries.size(); ++col) { + auto &child_entry = child_entries[col]; + child_entry->SetVectorType(VectorType::FLAT_VECTOR); + auto &child_validity = FlatVector::Validity(*child_entry); + if (child_validity.GetData()) { + child_validity.SetAllValid(count); + } + + // Pre-multiplex + const auto part_index = size_t(info.part_codes[col]); + if (owners[part_index] == col) { + if (IsBigintDatepart(info.part_codes[col])) { + bigint_values[part_index - size_t(DatePartSpecifier::BEGIN_BIGINT)] = + FlatVector::GetData(*child_entry); + } else { + double_values[part_index - size_t(DatePartSpecifier::BEGIN_DOUBLE)] = + FlatVector::GetData(*child_entry); + } + } + } + + for (idx_t i = 0; i < count; ++i) { + const auto idx = rdata.sel->get_index(i); + if (arg_valid.RowIsValid(idx)) { + if (Value::IsFinite(tdata[idx])) { + DatePart::StructOperator::Operation(bigint_values, double_values, tdata[idx], i, part_mask); + } else { + for (auto &child_entry : child_entries) { + FlatVector::Validity(*child_entry).SetInvalid(i); + } + } + } else { + res_valid.SetInvalid(i); + for (auto &child_entry : child_entries) { + FlatVector::Validity(*child_entry).SetInvalid(i); + } + } + } + } + + // Reference any duplicate parts + for (size_t col = 0; col < child_entries.size(); ++col) { + const auto part_index = size_t(info.part_codes[col]); + const auto owner = owners[part_index]; + if (owner != col) { + child_entries[col]->Reference(*child_entries[owner]); + } + } + + result.Verify(count); + } + + static void SerializeFunction(Serializer &serializer, const optional_ptr bind_data_p, + const ScalarFunction &function) { + D_ASSERT(bind_data_p); + auto &info = bind_data_p->Cast(); + serializer.WriteProperty(100, "stype", info.stype); + serializer.WriteProperty(101, "part_codes", info.part_codes); + } + + static unique_ptr DeserializeFunction(Deserializer &deserializer, ScalarFunction &bound_function) { + auto stype = deserializer.ReadProperty(100, "stype"); + auto part_codes = deserializer.ReadProperty>(101, "part_codes"); + return make_uniq(std::move(stype), std::move(part_codes)); + } + + template + static ScalarFunction GetFunction(const LogicalType &temporal_type) { + auto part_type = LogicalType::LIST(LogicalType::VARCHAR); + auto result_type = LogicalType::STRUCT({}); + ScalarFunction result({part_type, temporal_type}, result_type, Function, Bind); + result.serialize = SerializeFunction; + result.deserialize = DeserializeFunction; + return result; + } +}; +template +ScalarFunctionSet GetCachedDatepartFunction() { + return GetGenericDatePartFunction>( + DatePartCachedFunction, DatePartCachedFunction, + ScalarFunction::UnaryFunction, OP::template PropagateStatistics, + OP::template PropagateStatistics); +} + +ScalarFunctionSet YearFun::GetFunctions() { + return GetCachedDatepartFunction(); +} + +ScalarFunctionSet MonthFun::GetFunctions() { + return GetCachedDatepartFunction(); +} + +ScalarFunctionSet DayFun::GetFunctions() { + return GetCachedDatepartFunction(); +} + +ScalarFunctionSet DecadeFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet CenturyFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet MillenniumFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet QuarterFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet DayOfWeekFun::GetFunctions() { + auto set = GetDatePartFunction(); + for (auto &func : set.functions) { 
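+ // (Already flagged inside GetGenericDatePartFunction; setting it again here is a harmless no-op.)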
+ BaseScalarFunction::SetReturnsError(func); + } + return set; +} + +ScalarFunctionSet ISODayOfWeekFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet DayOfYearFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet WeekFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet ISOYearFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet EraFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet TimezoneFun::GetFunctions() { + auto operator_set = GetDatePartFunction(); + + // PG also defines timezone(INTERVAL, TIME_TZ) => TIME_TZ + ScalarFunction function({LogicalType::INTERVAL, LogicalType::TIME_TZ}, LogicalType::TIME_TZ, + DatePart::TimezoneOperator::BinaryFunction); + + operator_set.AddFunction(function); + + for (auto &func : operator_set.functions) { + BaseScalarFunction::SetReturnsError(func); + } + + return operator_set; +} + +ScalarFunctionSet TimezoneHourFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet TimezoneMinuteFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet EpochFun::GetFunctions() { + return GetTimePartFunction(LogicalType::DOUBLE); +} + +struct GetEpochNanosOperator { + static int64_t Operation(timestamp_ns_t timestamp) { + return Timestamp::GetEpochNanoSeconds(timestamp); + } +}; + +static void ExecuteGetNanosFromTimestampNs(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 1); + + auto func = GetEpochNanosOperator::Operation; + UnaryExecutor::Execute(input.data[0], result, input.size(), func); +} + +ScalarFunctionSet EpochNsFun::GetFunctions() { + using OP = DatePart::EpochNanosecondsOperator; + auto operator_set = GetTimePartFunction(); + + // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU + auto tstz_func = DatePart::UnaryFunction; + auto tstz_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); + + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP_NS}, LogicalType::BIGINT, ExecuteGetNanosFromTimestampNs)); + return operator_set; +} + +ScalarFunctionSet EpochUsFun::GetFunctions() { + using OP = DatePart::EpochMicrosecondsOperator; + auto operator_set = GetTimePartFunction(); + + // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU + auto tstz_func = DatePart::UnaryFunction; + auto tstz_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); + return operator_set; +} + +ScalarFunctionSet EpochMsFun::GetFunctions() { + using OP = DatePart::EpochMillisOperator; + auto operator_set = GetTimePartFunction(); + + // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU + auto tstz_func = DatePart::UnaryFunction; + auto tstz_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); + + // Legacy inverse BIGINT => TIMESTAMP + operator_set.AddFunction( + ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP, DatePart::EpochMillisOperator::Inverse)); + + return operator_set; +} + +ScalarFunctionSet NanosecondsFun::GetFunctions() { + using OP = 
DatePart::NanosecondsOperator; + using TR = int64_t; + const LogicalType &result_type = LogicalType::BIGINT; + auto operator_set = GetTimePartFunction(); + + auto ns_func = DatePart::UnaryFunction; + auto ns_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP_NS}, result_type, ns_func, nullptr, nullptr, ns_stats)); + + // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU + auto tstz_func = DatePart::UnaryFunction; + auto tstz_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); + + return operator_set; +} + +ScalarFunctionSet MicrosecondsFun::GetFunctions() { + return GetTimePartFunction(); +} + +ScalarFunctionSet MillisecondsFun::GetFunctions() { + return GetTimePartFunction(); +} + +ScalarFunctionSet SecondsFun::GetFunctions() { + return GetTimePartFunction(); +} + +ScalarFunctionSet MinutesFun::GetFunctions() { + return GetTimePartFunction(); +} + +ScalarFunctionSet HoursFun::GetFunctions() { + return GetTimePartFunction(); +} + +ScalarFunctionSet YearWeekFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet DayOfMonthFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet WeekDayFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet WeekOfYearFun::GetFunctions() { + return GetDatePartFunction(); +} + +ScalarFunctionSet LastDayFun::GetFunctions() { + ScalarFunctionSet last_day; + last_day.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::DATE, + DatePart::UnaryFunction)); + last_day.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::DATE, + DatePart::UnaryFunction)); + return last_day; +} + +ScalarFunctionSet MonthNameFun::GetFunctions() { + ScalarFunctionSet monthname; + monthname.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::VARCHAR, + DatePart::UnaryFunction)); + monthname.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::VARCHAR, + DatePart::UnaryFunction)); + return monthname; +} + +ScalarFunctionSet DayNameFun::GetFunctions() { + ScalarFunctionSet dayname; + dayname.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::VARCHAR, + DatePart::UnaryFunction)); + dayname.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::VARCHAR, + DatePart::UnaryFunction)); + return dayname; +} + +ScalarFunctionSet JulianDayFun::GetFunctions() { + using OP = DatePart::JulianDayOperator; + + ScalarFunctionSet operator_set; + auto date_func = DatePart::UnaryFunction; + auto date_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::DATE}, LogicalType::DOUBLE, date_func, nullptr, nullptr, date_stats)); + auto ts_func = DatePart::UnaryFunction; + auto ts_stats = OP::template PropagateStatistics; + operator_set.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::DOUBLE, ts_func, nullptr, nullptr, ts_stats)); + + return operator_set; +} + +ScalarFunctionSet DatePartFun::GetFunctions() { + ScalarFunctionSet date_part; + date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::BIGINT, + DatePartFunction, DatePartBind)); + date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::BIGINT, + DatePartFunction, DatePartBind)); + date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME}, 
LogicalType::BIGINT, + DatePartFunction, DatePartBind)); + date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::INTERVAL}, LogicalType::BIGINT, + DatePartFunction, DatePartBind)); + date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME_TZ}, LogicalType::BIGINT, + DatePartFunction, DatePartBind)); + + // struct variants + date_part.AddFunction(StructDatePart::GetFunction(LogicalType::DATE)); + date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIMESTAMP)); + date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIME)); + date_part.AddFunction(StructDatePart::GetFunction(LogicalType::INTERVAL)); + date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIME_TZ)); + + for (auto &func : date_part.functions) { + BaseScalarFunction::SetReturnsError(func); + } + + return date_part; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp b/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp new file mode 100644 index 00000000..acfb2c79 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp @@ -0,0 +1,454 @@ +#include "core_functions/scalar/date_functions.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +struct DateSub { + static int64_t SubtractMicros(timestamp_t startdate, timestamp_t enddate) { + const auto start = Timestamp::GetEpochMicroSeconds(startdate); + const auto end = Timestamp::GetEpochMicroSeconds(enddate); + return SubtractOperatorOverflowCheck::Operation(end, start); + } + + template + static inline void BinaryExecute(Vector &left, Vector &right, Vector &result, idx_t count) { + BinaryExecutor::ExecuteWithNulls( + left, right, result, count, [&](TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { + if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { + return OP::template Operation(startdate, enddate); + } else { + mask.SetInvalid(idx); + return TR(); + } + }); + } + + struct MonthOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + + if (start_ts > end_ts) { + return -MonthOperator::Operation(end_ts, start_ts); + } + // The number of complete months depends on whether end_ts is on the last day of the month. + date_t end_date; + dtime_t end_time; + Timestamp::Convert(end_ts, end_date, end_time); + + int32_t yyyy, mm, dd; + Date::Convert(end_date, yyyy, mm, dd); + const auto end_days = Date::MonthDays(yyyy, mm); + if (end_days == dd) { + // Now check whether the start day is after the end day + date_t start_date; + dtime_t start_time; + Timestamp::Convert(start_ts, start_date, start_time); + Date::Convert(start_date, yyyy, mm, dd); + if (dd > end_days || (dd == end_days && start_time < end_time)) { + // Move back to the same time on the last day of the (shorter) end month + start_date = Date::FromDate(yyyy, mm, end_days); + start_ts = Timestamp::FromDatetime(start_date, start_time); + } + } + + // Our interval difference will now give the correct result. 
+ // Note that PG gives different interval subtraction results, + // so if we change this we will have to reimplement. + return Interval::GetAge(end_ts, start_ts).months; + } + }; + + struct QuarterOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_QUARTER; + } + }; + + struct YearOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_YEAR; + } + }; + + struct DecadeOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_DECADE; + } + }; + + struct CenturyOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_CENTURY; + } + }; + + struct MilleniumOperator { + template + static inline TR Operation(TA start_ts, TB end_ts) { + return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_MILLENIUM; + } + }; + + struct DayOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_DAY; + } + }; + + struct WeekOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_WEEK; + } + }; + + struct MicrosecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate); + } + }; + + struct MillisecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_MSEC; + } + }; + + struct SecondsOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_SEC; + } + }; + + struct MinutesOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_MINUTE; + } + }; + + struct HoursOperator { + template + static inline TR Operation(TA startdate, TB enddate) { + return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_HOUR; + } + }; +}; + +// DATE specialisations +template <> +int64_t DateSub::YearOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return YearOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::MonthOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return MonthOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::DayOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return DayOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::DecadeOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return DecadeOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::CenturyOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return CenturyOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t 
DateSub::MilleniumOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return MilleniumOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::QuarterOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return QuarterOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::WeekOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return WeekOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::MicrosecondsOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return MicrosecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::MillisecondsOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return MillisecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::SecondsOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return SecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::MinutesOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return MinutesOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +template <> +int64_t DateSub::HoursOperator::Operation(date_t startdate, date_t enddate) { + dtime_t t0(0); + return HoursOperator::Operation(Timestamp::FromDatetime(startdate, t0), + Timestamp::FromDatetime(enddate, t0)); +} + +// TIME specialisations +template <> +int64_t DateSub::YearOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"year\" not recognized"); +} + +template <> +int64_t DateSub::MonthOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"month\" not recognized"); +} + +template <> +int64_t DateSub::DayOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"day\" not recognized"); +} + +template <> +int64_t DateSub::DecadeOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"decade\" not recognized"); +} + +template <> +int64_t DateSub::CenturyOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"century\" not recognized"); +} + +template <> +int64_t DateSub::MilleniumOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"millennium\" not recognized"); +} + +template <> +int64_t DateSub::QuarterOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"quarter\" not recognized"); +} + +template <> +int64_t DateSub::WeekOperator::Operation(dtime_t startdate, dtime_t enddate) { + throw NotImplementedException("\"time\" units \"week\" not recognized"); +} + +template <> +int64_t DateSub::MicrosecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return enddate.micros - startdate.micros; +} + +template <> +int64_t DateSub::MillisecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return 
(enddate.micros - startdate.micros) / Interval::MICROS_PER_MSEC; +} + +template <> +int64_t DateSub::SecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { + return (enddate.micros - startdate.micros) / Interval::MICROS_PER_SEC; +} + +template <> +int64_t DateSub::MinutesOperator::Operation(dtime_t startdate, dtime_t enddate) { + return (enddate.micros - startdate.micros) / Interval::MICROS_PER_MINUTE; +} + +template <> +int64_t DateSub::HoursOperator::Operation(dtime_t startdate, dtime_t enddate) { + return (enddate.micros - startdate.micros) / Interval::MICROS_PER_HOUR; +} + +template +static int64_t SubtractDateParts(DatePartSpecifier type, TA startdate, TB enddate) { + switch (type) { + case DatePartSpecifier::YEAR: + case DatePartSpecifier::ISOYEAR: + return DateSub::YearOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MONTH: + return DateSub::MonthOperator::template Operation(startdate, enddate); + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + return DateSub::DayOperator::template Operation(startdate, enddate); + case DatePartSpecifier::DECADE: + return DateSub::DecadeOperator::template Operation(startdate, enddate); + case DatePartSpecifier::CENTURY: + return DateSub::CenturyOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MILLENNIUM: + return DateSub::MilleniumOperator::template Operation(startdate, enddate); + case DatePartSpecifier::QUARTER: + return DateSub::QuarterOperator::template Operation(startdate, enddate); + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + return DateSub::WeekOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MICROSECONDS: + return DateSub::MicrosecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MILLISECONDS: + return DateSub::MillisecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::SECOND: + case DatePartSpecifier::EPOCH: + return DateSub::SecondsOperator::template Operation(startdate, enddate); + case DatePartSpecifier::MINUTE: + return DateSub::MinutesOperator::template Operation(startdate, enddate); + case DatePartSpecifier::HOUR: + return DateSub::HoursOperator::template Operation(startdate, enddate); + default: + throw NotImplementedException("Specifier type not implemented for DATESUB"); + } +} + +struct DateSubTernaryOperator { + template + static inline TR Operation(TS part, TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { + if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { + return SubtractDateParts(GetDatePartSpecifier(part.GetString()), startdate, enddate); + } else { + mask.SetInvalid(idx); + return TR(); + } + } +}; + +template +static void DateSubBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, Vector &result, idx_t count) { + switch (type) { + case DatePartSpecifier::YEAR: + case DatePartSpecifier::ISOYEAR: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MONTH: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::DAY: + case DatePartSpecifier::DOW: + case DatePartSpecifier::ISODOW: + case DatePartSpecifier::DOY: + case DatePartSpecifier::JULIAN_DAY: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::DECADE: + DateSub::BinaryExecute(left, right, result, count); + break; + case 
DatePartSpecifier::CENTURY: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MILLENNIUM: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::QUARTER: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::WEEK: + case DatePartSpecifier::YEARWEEK: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MICROSECONDS: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MILLISECONDS: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::SECOND: + case DatePartSpecifier::EPOCH: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::MINUTE: + DateSub::BinaryExecute(left, right, result, count); + break; + case DatePartSpecifier::HOUR: + DateSub::BinaryExecute(left, right, result, count); + break; + default: + throw NotImplementedException("Specifier type not implemented for DATESUB"); + } +} + +template +static void DateSubFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3); + auto &part_arg = args.data[0]; + auto &start_arg = args.data[1]; + auto &end_arg = args.data[2]; + + if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { + // Common case of constant part. + if (ConstantVector::IsNull(part_arg)) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); + DateSubBinaryExecutor(type, start_arg, end_arg, result, args.size()); + } + } else { + TernaryExecutor::ExecuteWithNulls( + part_arg, start_arg, end_arg, result, args.size(), + DateSubTernaryOperator::Operation); + } +} + +ScalarFunctionSet DateSubFun::GetFunctions() { + ScalarFunctionSet date_sub("date_sub"); + date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE, LogicalType::DATE}, + LogicalType::BIGINT, DateSubFunction)); + date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, + LogicalType::BIGINT, DateSubFunction)); + date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME, LogicalType::TIME}, + LogicalType::BIGINT, DateSubFunction)); + return date_sub; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp b/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp new file mode 100644 index 00000000..cb54e30d --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp @@ -0,0 +1,737 @@ +#include "core_functions/scalar/date_functions.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +struct DateTrunc { + template + static inline TR UnaryFunction(TA input) { + if (Value::IsFinite(input)) { + return OP::template Operation(input); + } else { + return Cast::template Operation(input); + } + } + + template + static inline void UnaryExecute(Vector &left, Vector &result, idx_t count) { + UnaryExecutor::Execute(left, result, count, UnaryFunction); + } + + struct MillenniumOperator { 
+ template + static inline TR Operation(TA input) { + return Date::FromDate((Date::ExtractYear(input) / 1000) * 1000, 1, 1); + } + }; + + struct CenturyOperator { + template + static inline TR Operation(TA input) { + return Date::FromDate((Date::ExtractYear(input) / 100) * 100, 1, 1); + } + }; + + struct DecadeOperator { + template + static inline TR Operation(TA input) { + return Date::FromDate((Date::ExtractYear(input) / 10) * 10, 1, 1); + } + }; + + struct YearOperator { + template + static inline TR Operation(TA input) { + return Date::FromDate(Date::ExtractYear(input), 1, 1); + } + }; + + struct QuarterOperator { + template + static inline TR Operation(TA input) { + int32_t yyyy, mm, dd; + Date::Convert(input, yyyy, mm, dd); + mm = 1 + (((mm - 1) / 3) * 3); + return Date::FromDate(yyyy, mm, 1); + } + }; + + struct MonthOperator { + template + static inline TR Operation(TA input) { + return Date::FromDate(Date::ExtractYear(input), Date::ExtractMonth(input), 1); + } + }; + + struct WeekOperator { + template + static inline TR Operation(TA input) { + return Date::GetMondayOfCurrentWeek(input); + } + }; + + struct ISOYearOperator { + template + static inline TR Operation(TA input) { + date_t date = Date::GetMondayOfCurrentWeek(input); + date.days -= (Date::ExtractISOWeekNumber(date) - 1) * Interval::DAYS_PER_WEEK; + + return date; + } + }; + + struct DayOperator { + template + static inline TR Operation(TA input) { + return input; + } + }; + + struct HourOperator { + template + static inline TR Operation(TA input) { + int32_t hour, min, sec, micros; + date_t date; + dtime_t time; + Timestamp::Convert(input, date, time); + Time::Convert(time, hour, min, sec, micros); + return Timestamp::FromDatetime(date, Time::FromTime(hour, 0, 0, 0)); + } + }; + + struct MinuteOperator { + template + static inline TR Operation(TA input) { + int32_t hour, min, sec, micros; + date_t date; + dtime_t time; + Timestamp::Convert(input, date, time); + Time::Convert(time, hour, min, sec, micros); + return Timestamp::FromDatetime(date, Time::FromTime(hour, min, 0, 0)); + } + }; + + struct SecondOperator { + template + static inline TR Operation(TA input) { + int32_t hour, min, sec, micros; + date_t date; + dtime_t time; + Timestamp::Convert(input, date, time); + Time::Convert(time, hour, min, sec, micros); + return Timestamp::FromDatetime(date, Time::FromTime(hour, min, sec, 0)); + } + }; + + struct MillisecondOperator { + template + static inline TR Operation(TA input) { + int32_t hour, min, sec, micros; + date_t date; + dtime_t time; + Timestamp::Convert(input, date, time); + Time::Convert(time, hour, min, sec, micros); + micros -= UnsafeNumericCast(micros % Interval::MICROS_PER_MSEC); + return Timestamp::FromDatetime(date, Time::FromTime(hour, min, sec, micros)); + } + }; + + struct MicrosecondOperator { + template + static inline TR Operation(TA input) { + return input; + } + }; +}; + +// DATE specialisations +template <> +date_t DateTrunc::MillenniumOperator::Operation(timestamp_t input) { + return MillenniumOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::MillenniumOperator::Operation(date_t input) { + return Timestamp::FromDatetime(MillenniumOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::MillenniumOperator::Operation(timestamp_t input) { + return MillenniumOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::CenturyOperator::Operation(timestamp_t input) { + return 
CenturyOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::CenturyOperator::Operation(date_t input) { + return Timestamp::FromDatetime(CenturyOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::CenturyOperator::Operation(timestamp_t input) { + return CenturyOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::DecadeOperator::Operation(timestamp_t input) { + return DecadeOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::DecadeOperator::Operation(date_t input) { + return Timestamp::FromDatetime(DecadeOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::DecadeOperator::Operation(timestamp_t input) { + return DecadeOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::YearOperator::Operation(timestamp_t input) { + return YearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::YearOperator::Operation(date_t input) { + return Timestamp::FromDatetime(YearOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::YearOperator::Operation(timestamp_t input) { + return YearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::QuarterOperator::Operation(timestamp_t input) { + return QuarterOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::QuarterOperator::Operation(date_t input) { + return Timestamp::FromDatetime(QuarterOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::QuarterOperator::Operation(timestamp_t input) { + return QuarterOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::MonthOperator::Operation(timestamp_t input) { + return MonthOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::MonthOperator::Operation(date_t input) { + return Timestamp::FromDatetime(MonthOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::MonthOperator::Operation(timestamp_t input) { + return MonthOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::WeekOperator::Operation(timestamp_t input) { + return WeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::WeekOperator::Operation(date_t input) { + return Timestamp::FromDatetime(WeekOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::WeekOperator::Operation(timestamp_t input) { + return WeekOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::ISOYearOperator::Operation(timestamp_t input) { + return ISOYearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::ISOYearOperator::Operation(date_t input) { + return Timestamp::FromDatetime(ISOYearOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::ISOYearOperator::Operation(timestamp_t input) { + return ISOYearOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +date_t DateTrunc::DayOperator::Operation(timestamp_t input) { + return DayOperator::Operation(Timestamp::GetDate(input)); +} + +template <> +timestamp_t DateTrunc::DayOperator::Operation(date_t input) { + return Timestamp::FromDatetime(DayOperator::Operation(input), dtime_t(0)); +} + +template <> +timestamp_t DateTrunc::DayOperator::Operation(timestamp_t input) { + return 
DayOperator::Operation<date_t, date_t>(Timestamp::GetDate(input));
+}
+
+template <>
+date_t DateTrunc::HourOperator::Operation(date_t input) {
+	return DayOperator::Operation<date_t, date_t>(input);
+}
+
+template <>
+timestamp_t DateTrunc::HourOperator::Operation(date_t input) {
+	return DayOperator::Operation<date_t, timestamp_t>(input);
+}
+
+template <>
+date_t DateTrunc::HourOperator::Operation(timestamp_t input) {
+	return Timestamp::GetDate(HourOperator::Operation<timestamp_t, timestamp_t>(input));
+}
+
+template <>
+date_t DateTrunc::MinuteOperator::Operation(date_t input) {
+	return DayOperator::Operation<date_t, date_t>(input);
+}
+
+template <>
+timestamp_t DateTrunc::MinuteOperator::Operation(date_t input) {
+	return DayOperator::Operation<date_t, timestamp_t>(input);
+}
+
+template <>
+date_t DateTrunc::MinuteOperator::Operation(timestamp_t input) {
+	return Timestamp::GetDate(HourOperator::Operation<timestamp_t, timestamp_t>(input));
+}
+
+template <>
+date_t DateTrunc::SecondOperator::Operation(date_t input) {
+	return DayOperator::Operation<date_t, date_t>(input);
+}
+
+template <>
+timestamp_t DateTrunc::SecondOperator::Operation(date_t input) {
+	return DayOperator::Operation<date_t, timestamp_t>(input);
+}
+
+template <>
+date_t DateTrunc::SecondOperator::Operation(timestamp_t input) {
+	return Timestamp::GetDate(DayOperator::Operation<timestamp_t, timestamp_t>(input));
+}
+
+template <>
+date_t DateTrunc::MillisecondOperator::Operation(date_t input) {
+	return DayOperator::Operation<date_t, date_t>(input);
+}
+
+template <>
+timestamp_t DateTrunc::MillisecondOperator::Operation(date_t input) {
+	return DayOperator::Operation<date_t, timestamp_t>(input);
+}
+
+template <>
+date_t DateTrunc::MillisecondOperator::Operation(timestamp_t input) {
+	return Timestamp::GetDate(MillisecondOperator::Operation<timestamp_t, timestamp_t>(input));
+}
+
+template <>
+date_t DateTrunc::MicrosecondOperator::Operation(date_t input) {
+	return DayOperator::Operation<date_t, date_t>(input);
+}
+
+template <>
+timestamp_t DateTrunc::MicrosecondOperator::Operation(date_t input) {
+	return DayOperator::Operation<date_t, timestamp_t>(input);
+}
+
+template <>
+date_t DateTrunc::MicrosecondOperator::Operation(timestamp_t input) {
+	return Timestamp::GetDate(MicrosecondOperator::Operation<timestamp_t, timestamp_t>(input));
+}
+
+// INTERVAL specialisations
+template <>
+interval_t DateTrunc::MillenniumOperator::Operation(interval_t input) {
+	input.days = 0;
+	input.micros = 0;
+	input.months = (input.months / Interval::MONTHS_PER_MILLENIUM) * Interval::MONTHS_PER_MILLENIUM;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::CenturyOperator::Operation(interval_t input) {
+	input.days = 0;
+	input.micros = 0;
+	input.months = (input.months / Interval::MONTHS_PER_CENTURY) * Interval::MONTHS_PER_CENTURY;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::DecadeOperator::Operation(interval_t input) {
+	input.days = 0;
+	input.micros = 0;
+	input.months = (input.months / Interval::MONTHS_PER_DECADE) * Interval::MONTHS_PER_DECADE;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::YearOperator::Operation(interval_t input) {
+	input.days = 0;
+	input.micros = 0;
+	input.months = (input.months / Interval::MONTHS_PER_YEAR) * Interval::MONTHS_PER_YEAR;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::QuarterOperator::Operation(interval_t input) {
+	input.days = 0;
+	input.micros = 0;
+	input.months = (input.months / Interval::MONTHS_PER_QUARTER) * Interval::MONTHS_PER_QUARTER;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::MonthOperator::Operation(interval_t input) {
+	input.days = 0;
+	input.micros = 0;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::WeekOperator::Operation(interval_t input) {
+	input.micros = 0;
+	input.days = (input.days / Interval::DAYS_PER_WEEK) * Interval::DAYS_PER_WEEK;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::ISOYearOperator::Operation(interval_t input) {
+	return YearOperator::Operation<interval_t, interval_t>(input);
+}
+
+template <>
+interval_t DateTrunc::DayOperator::Operation(interval_t input) {
+	input.micros = 0;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::HourOperator::Operation(interval_t input) {
+	input.micros = (input.micros / Interval::MICROS_PER_HOUR) * Interval::MICROS_PER_HOUR;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::MinuteOperator::Operation(interval_t input) {
+	input.micros = (input.micros / Interval::MICROS_PER_MINUTE) * Interval::MICROS_PER_MINUTE;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::SecondOperator::Operation(interval_t input) {
+	input.micros = (input.micros / Interval::MICROS_PER_SEC) * Interval::MICROS_PER_SEC;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::MillisecondOperator::Operation(interval_t input) {
+	input.micros = (input.micros / Interval::MICROS_PER_MSEC) * Interval::MICROS_PER_MSEC;
+	return input;
+}
+
+template <>
+interval_t DateTrunc::MicrosecondOperator::Operation(interval_t input) {
+	return input;
+}
+
+template <typename TA, typename TR>
+static TR TruncateElement(DatePartSpecifier type, TA element) {
+	if (!Value::IsFinite(element)) {
+		return Cast::template Operation<TA, TR>(element);
+	}
+
+	switch (type) {
+	case DatePartSpecifier::MILLENNIUM:
+		return DateTrunc::MillenniumOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::CENTURY:
+		return DateTrunc::CenturyOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::DECADE:
+		return DateTrunc::DecadeOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::YEAR:
+		return DateTrunc::YearOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::QUARTER:
+		return DateTrunc::QuarterOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::MONTH:
+		return DateTrunc::MonthOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::WEEK:
+	case DatePartSpecifier::YEARWEEK:
+		return DateTrunc::WeekOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::ISOYEAR:
+		return DateTrunc::ISOYearOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::DAY:
+	case DatePartSpecifier::DOW:
+	case DatePartSpecifier::ISODOW:
+	case DatePartSpecifier::DOY:
+	case DatePartSpecifier::JULIAN_DAY:
+		return DateTrunc::DayOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::HOUR:
+		return DateTrunc::HourOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::MINUTE:
+		return DateTrunc::MinuteOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::SECOND:
+	case DatePartSpecifier::EPOCH:
+		return DateTrunc::SecondOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::MILLISECONDS:
+		return DateTrunc::MillisecondOperator::Operation<TA, TR>(element);
+	case DatePartSpecifier::MICROSECONDS:
+		return DateTrunc::MicrosecondOperator::Operation<TA, TR>(element);
+	default:
+		throw NotImplementedException("Specifier type not implemented for DATETRUNC");
+	}
+}
+
+struct DateTruncBinaryOperator {
+	template <class TA, class TB, class TR>
+	static inline TR Operation(TA specifier, TB date) {
+		return TruncateElement<TB, TR>(GetDatePartSpecifier(specifier.GetString()), date);
+	}
+};
+
+template <typename TA, typename TR>
+static void DateTruncUnaryExecutor(DatePartSpecifier type, Vector &left, Vector &result, idx_t count) {
+	switch (type) {
+	case DatePartSpecifier::MILLENNIUM:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::MillenniumOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::CENTURY:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::CenturyOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::DECADE:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::DecadeOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::YEAR:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::YearOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::QUARTER:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::QuarterOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::MONTH:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::MonthOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::WEEK:
+	case DatePartSpecifier::YEARWEEK:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::WeekOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::ISOYEAR:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::ISOYearOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::DAY:
+	case DatePartSpecifier::DOW:
+	case DatePartSpecifier::ISODOW:
+	case DatePartSpecifier::DOY:
+	case DatePartSpecifier::JULIAN_DAY:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::DayOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::HOUR:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::HourOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::MINUTE:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::MinuteOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::SECOND:
+	case DatePartSpecifier::EPOCH:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::SecondOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::MILLISECONDS:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::MillisecondOperator>(left, result, count);
+		break;
+	case DatePartSpecifier::MICROSECONDS:
+		DateTrunc::UnaryExecute<TA, TR, DateTrunc::MicrosecondOperator>(left, result, count);
+		break;
+	default:
+		throw NotImplementedException("Specifier type not implemented for DATETRUNC");
+	}
+}
+
+template <typename TA, typename TR>
+static void DateTruncFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+	D_ASSERT(args.ColumnCount() == 2);
+	auto &part_arg = args.data[0];
+	auto &date_arg = args.data[1];
+
+	if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) {
+		// Common case of constant part.
+		if (ConstantVector::IsNull(part_arg)) {
+			result.SetVectorType(VectorType::CONSTANT_VECTOR);
+			ConstantVector::SetNull(result, true);
+		} else {
+			const auto type = GetDatePartSpecifier(ConstantVector::GetData<string_t>(part_arg)->GetString());
+			DateTruncUnaryExecutor<TA, TR>(type, date_arg, result, args.size());
+		}
+	} else {
+		BinaryExecutor::ExecuteStandard<string_t, TA, TR, DateTruncBinaryOperator>(part_arg, date_arg, result,
+		                                                                          args.size());
+	}
+}
+
+template <class TA, class TR, class OP>
+static unique_ptr<BaseStatistics> DateTruncStatistics(vector<BaseStatistics> &child_stats) {
+	// we can only propagate date stats if the child has stats
+	auto &nstats = child_stats[1];
+	if (!NumericStats::HasMinMax(nstats)) {
+		return nullptr;
+	}
+	// run the operator on both the min and the max, this gives us the [min, max] bound
+	auto min = NumericStats::GetMin<TA>(nstats);
+	auto max = NumericStats::GetMax<TA>(nstats);
+	if (min > max) {
+		return nullptr;
+	}
+
+	// Infinite values are unmodified
+	auto min_part = DateTrunc::UnaryFunction<TA, TR, OP>(min);
+	auto max_part = DateTrunc::UnaryFunction<TA, TR, OP>(max);
+
+	auto min_value = Value::CreateValue(min_part);
+	auto max_value = Value::CreateValue(max_part);
+	auto result = NumericStats::CreateEmpty(min_value.type());
+	NumericStats::SetMin(result, min_value);
+	NumericStats::SetMax(result, max_value);
+	result.CopyValidity(child_stats[0]);
+	return result.ToUnique();
+}
+
+template <class TA, class TR, class OP>
+static unique_ptr<BaseStatistics> PropagateDateTruncStatistics(ClientContext &context,
+                                                               FunctionStatisticsInput &input) {
+	return DateTruncStatistics<TA, TR, OP>(input.child_stats);
+}
+
+template <typename TA, typename TR>
+static function_statistics_t DateTruncStats(DatePartSpecifier type) {
+	switch (type) {
+	case DatePartSpecifier::MILLENNIUM:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::MillenniumOperator>;
+	case DatePartSpecifier::CENTURY:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::CenturyOperator>;
+	case DatePartSpecifier::DECADE:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::DecadeOperator>;
+	case DatePartSpecifier::YEAR:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::YearOperator>;
+	case DatePartSpecifier::QUARTER:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::QuarterOperator>;
+	case DatePartSpecifier::MONTH:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::MonthOperator>;
+	case DatePartSpecifier::WEEK:
+	case DatePartSpecifier::YEARWEEK:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::WeekOperator>;
+	case DatePartSpecifier::ISOYEAR:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::ISOYearOperator>;
+	case DatePartSpecifier::DAY:
+	case DatePartSpecifier::DOW:
+	case DatePartSpecifier::ISODOW:
+	case DatePartSpecifier::DOY:
+	case DatePartSpecifier::JULIAN_DAY:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::DayOperator>;
+	case DatePartSpecifier::HOUR:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::HourOperator>;
+	case DatePartSpecifier::MINUTE:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::MinuteOperator>;
+	case DatePartSpecifier::SECOND:
+	case DatePartSpecifier::EPOCH:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::SecondOperator>;
+	case DatePartSpecifier::MILLISECONDS:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::MillisecondOperator>;
+	case DatePartSpecifier::MICROSECONDS:
+		return PropagateDateTruncStatistics<TA, TR, DateTrunc::MicrosecondOperator>;
+	default:
+		throw NotImplementedException("Specifier type not implemented for DATETRUNC statistics");
+	}
+}
+
+static unique_ptr<FunctionData> DateTruncBind(ClientContext &context, ScalarFunction &bound_function,
+                                              vector<unique_ptr<Expression>> &arguments) {
+	if (!arguments[0]->IsFoldable()) {
+		return nullptr;
+	}
+
+	// Rebind to return a date if we are truncating that far
+	Value part_value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]);
+	if (part_value.IsNull()) {
+		return nullptr;
+	}
+	const auto part_name = part_value.ToString();
+	const auto part_code = GetDatePartSpecifier(part_name);
+	switch (part_code) {
+	case DatePartSpecifier::MILLENNIUM:
+	case DatePartSpecifier::CENTURY:
+	case DatePartSpecifier::DECADE:
+	case DatePartSpecifier::YEAR:
+	case DatePartSpecifier::QUARTER:
+	case DatePartSpecifier::MONTH:
+	case DatePartSpecifier::WEEK:
+	case DatePartSpecifier::YEARWEEK:
+	case DatePartSpecifier::ISOYEAR:
+	case DatePartSpecifier::DAY:
+	case DatePartSpecifier::DOW:
+	case DatePartSpecifier::ISODOW:
+	case DatePartSpecifier::DOY:
+	case DatePartSpecifier::JULIAN_DAY:
+		switch (bound_function.arguments[1].id()) {
+		case LogicalType::TIMESTAMP:
+			bound_function.function = DateTruncFunction<timestamp_t, date_t>;
+			bound_function.statistics = DateTruncStats<timestamp_t, date_t>(part_code);
+			break;
+		case LogicalType::DATE:
+			bound_function.function = DateTruncFunction<date_t, date_t>;
+			bound_function.statistics = DateTruncStats<date_t, date_t>(part_code);
+			break;
+		default:
+			throw NotImplementedException("Temporal argument type for DATETRUNC");
+		}
+		bound_function.return_type = LogicalType::DATE;
+		break;
+	default:
+		switch (bound_function.arguments[1].id()) {
+		case LogicalType::TIMESTAMP:
+			bound_function.statistics = DateTruncStats<timestamp_t, timestamp_t>(part_code);
+			break;
+		case LogicalType::DATE:
+			bound_function.statistics = DateTruncStats<date_t, timestamp_t>(part_code);
+			break;
+		default:
+			throw NotImplementedException("Temporal argument type for DATETRUNC");
+		}
+		break;
+	}
+
+	return nullptr;
+}
+
+ScalarFunctionSet DateTruncFun::GetFunctions() {
+	ScalarFunctionSet date_trunc("date_trunc");
+	date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP,
+	                                      DateTruncFunction<timestamp_t, timestamp_t>, DateTruncBind));
+	date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::TIMESTAMP,
+	                                      DateTruncFunction<date_t, timestamp_t>, DateTruncBind));
+	date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::INTERVAL}, LogicalType::INTERVAL,
+	                                      DateTruncFunction<interval_t, interval_t>));
+	for (auto &func : date_trunc.functions) {
+		BaseScalarFunction::SetReturnsError(func);
+	}
+	return date_trunc;
+}
+
+} // namespace duckdb
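A note on the statistics propagation in the file above: DateTruncStatistics only ever evaluates the truncation operator on the child's min and max. That is sound because every truncation operator is monotone non-decreasing, so trunc(min) <= trunc(x) <= trunc(max) for any x in [min, max]. A minimal standalone sketch of that invariant (illustrative only, not DuckDB code; positive epoch values assumed so integer division floors):

#include <cassert>
#include <cstdint>

static const int64_t MICROS_PER_HOUR = 3600LL * 1000000LL;

// toy stand-in for DateTrunc::HourOperator on epoch microseconds
static int64_t TruncToHour(int64_t micros) {
	return (micros / MICROS_PER_HOUR) * MICROS_PER_HOUR;
}

int main() {
	const int64_t min = 5 * MICROS_PER_HOUR + 17;
	const int64_t max = 9 * MICROS_PER_HOUR + 42;
	// propagated statistics: truncate both endpoints of the child's [min, max]
	const int64_t new_min = TruncToHour(min);
	const int64_t new_max = TruncToHour(max);
	for (int64_t v = min; v <= max; v += MICROS_PER_HOUR / 4) {
		const int64_t t = TruncToHour(v);
		assert(new_min <= t && t <= new_max); // the propagated bound holds for every input
	}
	return 0;
}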
diff --git a/src/duckdb/extension/core_functions/scalar/date/epoch.cpp b/src/duckdb/extension/core_functions/scalar/date/epoch.cpp
new file mode 100644
index 00000000..cda3232a
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/date/epoch.cpp
@@ -0,0 +1,64 @@
+#include "core_functions/scalar/date_functions.hpp"
+
+#include "duckdb/common/operator/cast_operators.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/common/vector_operations/unary_executor.hpp"
+
+namespace duckdb {
+
+struct EpochSecOperator {
+	template <class INPUT_TYPE, class RESULT_TYPE>
+	static RESULT_TYPE Operation(INPUT_TYPE sec) {
+		int64_t result;
+		if (!TryCast::Operation(sec * Interval::MICROS_PER_SEC, result)) {
+			throw ConversionException("Epoch seconds out of range for TIMESTAMP WITH TIME ZONE");
+		}
+		return timestamp_t(result);
+	}
+};
+
+static void EpochSecFunction(DataChunk &input, ExpressionState &state, Vector &result) {
+	D_ASSERT(input.ColumnCount() == 1);
+
+	UnaryExecutor::Execute<double, timestamp_t, EpochSecOperator>(input.data[0], result, input.size());
+}
+
+ScalarFunction ToTimestampFun::GetFunction() {
+	// to_timestamp is an alias from Postgres that converts the time in seconds to a timestamp
+	return ScalarFunction({LogicalType::DOUBLE}, LogicalType::TIMESTAMP_TZ, EpochSecFunction);
+}
+
+struct NormalizedIntervalOperator {
+	template <class INPUT_TYPE, class RESULT_TYPE>
+	static RESULT_TYPE Operation(INPUT_TYPE input) {
+		return input.Normalize();
+	}
+};
+
+static void NormalizedIntervalFunction(DataChunk &input, ExpressionState &state, Vector &result) {
+	D_ASSERT(input.ColumnCount() == 1);
+
+	UnaryExecutor::Execute<interval_t, interval_t, NormalizedIntervalOperator>(input.data[0], result, input.size());
+}
+
+ScalarFunction NormalizedIntervalFun::GetFunction() {
+	return ScalarFunction({LogicalType::INTERVAL}, LogicalType::INTERVAL, NormalizedIntervalFunction);
+}
+
+struct TimeTZSortKeyOperator {
+	template <class INPUT_TYPE, class RESULT_TYPE>
+	static RESULT_TYPE Operation(INPUT_TYPE input) {
+		return input.sort_key();
+	}
+};
+
+static void TimeTZSortKeyFunction(DataChunk &input, ExpressionState &state, Vector &result) {
+	D_ASSERT(input.ColumnCount() == 1);
+
+	UnaryExecutor::Execute<dtime_tz_t, uint64_t, TimeTZSortKeyOperator>(input.data[0], result, input.size());
+}
+
+ScalarFunction TimeTZSortKeyFun::GetFunction() {
+	return ScalarFunction({LogicalType::TIME_TZ}, LogicalType::UBIGINT, TimeTZSortKeyFunction);
+}
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/date/make_date.cpp b/src/duckdb/extension/core_functions/scalar/date/make_date.cpp
new file mode 100644
index 00000000..0fe00a92
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/date/make_date.cpp
@@ -0,0 +1,181 @@
+#include "core_functions/scalar/date_functions.hpp"
+#include "duckdb/common/operator/cast_operators.hpp"
+#include "duckdb/common/types/date.hpp"
+#include "duckdb/common/types/time.hpp"
+#include "duckdb/common/types/timestamp.hpp"
+#include "duckdb/common/vector_operations/ternary_executor.hpp"
+#include "duckdb/common/vector_operations/senary_executor.hpp"
+#include "duckdb/common/exception/conversion_exception.hpp"
+
+#include <cmath>
+
+namespace duckdb {
+
+static void MakeDateFromEpoch(DataChunk &input, ExpressionState &state, Vector &result) {
+	D_ASSERT(input.ColumnCount() == 1);
+	result.Reinterpret(input.data[0]);
+}
+
+struct MakeDateOperator {
+	template <typename YYYY, typename MM, typename DD, typename RESULT_TYPE>
+	static RESULT_TYPE Operation(YYYY yyyy, MM mm, DD dd) {
+		return Date::FromDate(Cast::Operation<YYYY, int32_t>(yyyy), Cast::Operation<MM, int32_t>(mm),
+		                      Cast::Operation<DD, int32_t>(dd));
+	}
+};
+
+template <typename T>
+static void ExecuteMakeDate(DataChunk &input, ExpressionState &state, Vector &result) {
+	D_ASSERT(input.ColumnCount() == 3);
+	auto &yyyy = input.data[0];
+	auto &mm = input.data[1];
+	auto &dd = input.data[2];
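+	// TernaryExecutor::Execute pairs the three input vectors positionally and applies
+	// MakeDateOperator to each (year, month, day) row; NULL handling and selection
+	// vectors are dealt with by the executor, not the operator.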
+ + TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), + MakeDateOperator::Operation); +} + +template +static date_t FromDateCast(T year, T month, T day) { + date_t result; + if (!Date::TryFromDate(Cast::Operation(year), Cast::Operation(month), + Cast::Operation(day), result)) { + throw ConversionException("Date out of range: %d-%d-%d", year, month, day); + } + return result; +} + +template +static void ExecuteStructMakeDate(DataChunk &input, ExpressionState &state, Vector &result) { + // this should be guaranteed by the binder + D_ASSERT(input.ColumnCount() == 1); + auto &vec = input.data[0]; + + auto &children = StructVector::GetEntries(vec); + D_ASSERT(children.size() == 3); + auto &yyyy = *children[0]; + auto &mm = *children[1]; + auto &dd = *children[2]; + + TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), FromDateCast); +} + +struct MakeTimeOperator { + template + static RESULT_TYPE Operation(HH hh, MM mm, SS ss) { + + auto hh_32 = Cast::Operation(hh); + auto mm_32 = Cast::Operation(mm); + // Have to check this separately because safe casting of DOUBLE => INT32 can round. + int32_t ss_32 = 0; + if (ss < 0 || ss > Interval::SECS_PER_MINUTE) { + ss_32 = Cast::Operation(ss); + } else { + ss_32 = LossyNumericCast(ss); + } + auto micros = LossyNumericCast(std::round((ss - ss_32) * Interval::MICROS_PER_SEC)); + + if (!Time::IsValidTime(hh_32, mm_32, ss_32, micros)) { + throw ConversionException("Time out of range: %d:%d:%d.%d", hh_32, mm_32, ss_32, micros); + } + return Time::FromTime(hh_32, mm_32, ss_32, micros); + } +}; + +template +static void ExecuteMakeTime(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 3); + auto &yyyy = input.data[0]; + auto &mm = input.data[1]; + auto &dd = input.data[2]; + + TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), + MakeTimeOperator::Operation); +} + +struct MakeTimestampOperator { + template + static RESULT_TYPE Operation(YYYY yyyy, MM mm, DD dd, HR hr, MN mn, SS ss) { + const auto d = MakeDateOperator::Operation(yyyy, mm, dd); + const auto t = MakeTimeOperator::Operation(hr, mn, ss); + return Timestamp::FromDatetime(d, t); + } + + template + static RESULT_TYPE Operation(T value) { + const auto result = RESULT_TYPE(value); + if (!Timestamp::IsFinite(result)) { + throw ConversionException("Timestamp microseconds out of range: %ld", value); + } + return RESULT_TYPE(value); + } +}; + +template +static void ExecuteMakeTimestamp(DataChunk &input, ExpressionState &state, Vector &result) { + if (input.ColumnCount() == 1) { + auto func = MakeTimestampOperator::Operation; + UnaryExecutor::Execute(input.data[0], result, input.size(), func); + return; + } + + D_ASSERT(input.ColumnCount() == 6); + + auto func = MakeTimestampOperator::Operation; + SenaryExecutor::Execute(input, result, func); +} + +template +static void ExecuteMakeTimestampNs(DataChunk &input, ExpressionState &state, Vector &result) { + D_ASSERT(input.ColumnCount() == 1); + + auto func = MakeTimestampOperator::Operation; + UnaryExecutor::Execute(input.data[0], result, input.size(), func); + return; +} + +ScalarFunctionSet MakeDateFun::GetFunctions() { + ScalarFunctionSet make_date("make_date"); + make_date.AddFunction(ScalarFunction({LogicalType::INTEGER}, LogicalType::DATE, MakeDateFromEpoch)); + make_date.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::DATE, ExecuteMakeDate)); + + child_list_t make_date_children { + {"year", LogicalType::BIGINT}, {"month", 
LogicalType::BIGINT}, {"day", LogicalType::BIGINT}}; + make_date.AddFunction( + ScalarFunction({LogicalType::STRUCT(make_date_children)}, LogicalType::DATE, ExecuteStructMakeDate)); + for (auto &func : make_date.functions) { + BaseScalarFunction::SetReturnsError(func); + } + return make_date; +} + +ScalarFunction MakeTimeFun::GetFunction() { + ScalarFunction function({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::DOUBLE}, LogicalType::TIME, + ExecuteMakeTime); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunctionSet MakeTimestampFun::GetFunctions() { + ScalarFunctionSet operator_set("make_timestamp"); + operator_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, + LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::DOUBLE}, + LogicalType::TIMESTAMP, ExecuteMakeTimestamp)); + operator_set.AddFunction( + ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP, ExecuteMakeTimestamp)); + + for (auto &func : operator_set.functions) { + BaseScalarFunction::SetReturnsError(func); + } + return operator_set; +} + +ScalarFunctionSet MakeTimestampNsFun::GetFunctions() { + ScalarFunctionSet operator_set("make_timestamp_ns"); + operator_set.AddFunction( + ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP_NS, ExecuteMakeTimestampNs)); + return operator_set; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp b/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp new file mode 100644 index 00000000..726d6b54 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp @@ -0,0 +1,373 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "core_functions/scalar/date_functions.hpp" + +namespace duckdb { + +struct TimeBucket { + + // Use 2000-01-03 00:00:00 (Monday) as origin when bucket_width is days, hours, ... for TimescaleDB compatibility + // There are 10959 days between 1970-01-01 and 2000-01-03 + constexpr static const int64_t DEFAULT_ORIGIN_MICROS = 10959 * Interval::MICROS_PER_DAY; + // Use 2000-01-01 as origin when bucket_width is months, years, ... 
for TimescaleDB compatibility + // There are 360 months between 1970-01-01 and 2000-01-01 + constexpr static const int32_t DEFAULT_ORIGIN_MONTHS = 360; + + enum struct BucketWidthType : uint8_t { CONVERTIBLE_TO_MICROS, CONVERTIBLE_TO_MONTHS, UNCLASSIFIED }; + + static inline BucketWidthType ClassifyBucketWidth(const interval_t bucket_width) { + if (bucket_width.months == 0 && Interval::GetMicro(bucket_width) > 0) { + return BucketWidthType::CONVERTIBLE_TO_MICROS; + } else if (bucket_width.months > 0 && bucket_width.days == 0 && bucket_width.micros == 0) { + return BucketWidthType::CONVERTIBLE_TO_MONTHS; + } else { + return BucketWidthType::UNCLASSIFIED; + } + } + + static inline BucketWidthType ClassifyBucketWidthErrorThrow(const interval_t bucket_width) { + if (bucket_width.months == 0) { + int64_t bucket_width_micros = Interval::GetMicro(bucket_width); + if (bucket_width_micros <= 0) { + throw NotImplementedException("Period must be greater than 0"); + } + return BucketWidthType::CONVERTIBLE_TO_MICROS; + } else if (bucket_width.months != 0 && bucket_width.days == 0 && bucket_width.micros == 0) { + if (bucket_width.months < 0) { + throw NotImplementedException("Period must be greater than 0"); + } + return BucketWidthType::CONVERTIBLE_TO_MONTHS; + } else { + throw NotImplementedException("Month intervals cannot have day or time component"); + } + } + + template + static inline int32_t EpochMonths(T ts) { + date_t ts_date = Cast::template Operation(ts); + return (Date::ExtractYear(ts_date) - 1970) * 12 + Date::ExtractMonth(ts_date) - 1; + } + + static inline timestamp_t WidthConvertibleToMicrosCommon(int64_t bucket_width_micros, int64_t ts_micros, + int64_t origin_micros) { + origin_micros %= bucket_width_micros; + ts_micros = SubtractOperatorOverflowCheck::Operation(ts_micros, origin_micros); + + int64_t result_micros = (ts_micros / bucket_width_micros) * bucket_width_micros; + if (ts_micros < 0 && ts_micros % bucket_width_micros != 0) { + result_micros = + SubtractOperatorOverflowCheck::Operation(result_micros, bucket_width_micros); + } + result_micros += origin_micros; + + return Timestamp::FromEpochMicroSeconds(result_micros); + } + + static inline date_t WidthConvertibleToMonthsCommon(int32_t bucket_width_months, int32_t ts_months, + int32_t origin_months) { + origin_months %= bucket_width_months; + ts_months = SubtractOperatorOverflowCheck::Operation(ts_months, origin_months); + + int32_t result_months = (ts_months / bucket_width_months) * bucket_width_months; + if (ts_months < 0 && ts_months % bucket_width_months != 0) { + result_months = + SubtractOperatorOverflowCheck::Operation(result_months, bucket_width_months); + } + result_months += origin_months; + + int32_t year = + (result_months < 0 && result_months % 12 != 0) ? 1970 + result_months / 12 - 1 : 1970 + result_months / 12; + int32_t month = + (result_months < 0 && result_months % 12 != 0) ? 
result_months % 12 + 13 : result_months % 12 + 1; + + return Date::FromDate(year, month, 1); + } + + struct WidthConvertibleToMicrosBinaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int64_t bucket_width_micros = Interval::GetMicro(bucket_width); + int64_t ts_micros = Timestamp::GetEpochMicroSeconds(Cast::template Operation(ts)); + return Cast::template Operation( + WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, DEFAULT_ORIGIN_MICROS)); + } + }; + + struct WidthConvertibleToMonthsBinaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int32_t ts_months = EpochMonths(ts); + return Cast::template Operation( + WidthConvertibleToMonthsCommon(bucket_width.months, ts_months, DEFAULT_ORIGIN_MONTHS)); + } + }; + + struct BinaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts) { + BucketWidthType bucket_width_type = ClassifyBucketWidthErrorThrow(bucket_width); + switch (bucket_width_type) { + case BucketWidthType::CONVERTIBLE_TO_MICROS: + return WidthConvertibleToMicrosBinaryOperator::Operation(bucket_width, ts); + case BucketWidthType::CONVERTIBLE_TO_MONTHS: + return WidthConvertibleToMonthsBinaryOperator::Operation(bucket_width, ts); + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + }; + + struct OffsetWidthConvertibleToMicrosTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC offset) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int64_t bucket_width_micros = Interval::GetMicro(bucket_width); + int64_t ts_micros = Timestamp::GetEpochMicroSeconds( + Interval::Add(Cast::template Operation(ts), Interval::Invert(offset))); + return Cast::template Operation(Interval::Add( + WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, DEFAULT_ORIGIN_MICROS), offset)); + } + }; + + struct OffsetWidthConvertibleToMonthsTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC offset) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int32_t ts_months = EpochMonths(Interval::Add(ts, Interval::Invert(offset))); + return Interval::Add(Cast::template Operation(WidthConvertibleToMonthsCommon( + bucket_width.months, ts_months, DEFAULT_ORIGIN_MONTHS)), + offset); + } + }; + + struct OffsetTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC offset) { + BucketWidthType bucket_width_type = ClassifyBucketWidthErrorThrow(bucket_width); + switch (bucket_width_type) { + case BucketWidthType::CONVERTIBLE_TO_MICROS: + return OffsetWidthConvertibleToMicrosTernaryOperator::Operation(bucket_width, ts, + offset); + case BucketWidthType::CONVERTIBLE_TO_MONTHS: + return OffsetWidthConvertibleToMonthsTernaryOperator::Operation(bucket_width, ts, + offset); + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + }; + + struct OriginWidthConvertibleToMicrosTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC origin) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int64_t bucket_width_micros = Interval::GetMicro(bucket_width); + int64_t ts_micros = Timestamp::GetEpochMicroSeconds(Cast::template Operation(ts)); + int64_t origin_micros = Timestamp::GetEpochMicroSeconds(Cast::template 
Operation(origin)); + return Cast::template Operation( + WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, origin_micros)); + } + }; + + struct OriginWidthConvertibleToMonthsTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC origin) { + if (!Value::IsFinite(ts)) { + return Cast::template Operation(ts); + } + int32_t ts_months = EpochMonths(ts); + int32_t origin_months = EpochMonths(origin); + return Cast::template Operation( + WidthConvertibleToMonthsCommon(bucket_width.months, ts_months, origin_months)); + } + }; + + struct OriginTernaryOperator { + template + static inline TR Operation(TA bucket_width, TB ts, TC origin, ValidityMask &mask, idx_t idx) { + if (!Value::IsFinite(origin)) { + mask.SetInvalid(idx); + return TR(); + } + BucketWidthType bucket_width_type = ClassifyBucketWidthErrorThrow(bucket_width); + switch (bucket_width_type) { + case BucketWidthType::CONVERTIBLE_TO_MICROS: + return OriginWidthConvertibleToMicrosTernaryOperator::Operation(bucket_width, ts, + origin); + case BucketWidthType::CONVERTIBLE_TO_MONTHS: + return OriginWidthConvertibleToMonthsTernaryOperator::Operation(bucket_width, ts, + origin); + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + }; +}; + +template +static void TimeBucketFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 2); + + auto &bucket_width_arg = args.data[0]; + auto &ts_arg = args.data[1]; + + if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (ConstantVector::IsNull(bucket_width_arg)) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); + TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); + switch (bucket_width_type) { + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: + BinaryExecutor::Execute( + bucket_width_arg, ts_arg, result, args.size(), + TimeBucket::WidthConvertibleToMicrosBinaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: + BinaryExecutor::Execute( + bucket_width_arg, ts_arg, result, args.size(), + TimeBucket::WidthConvertibleToMonthsBinaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::UNCLASSIFIED: + BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, args.size(), + TimeBucket::BinaryOperator::Operation); + break; + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + } else { + BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, args.size(), + TimeBucket::BinaryOperator::Operation); + } +} + +template +static void TimeBucketOffsetFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3); + + auto &bucket_width_arg = args.data[0]; + auto &ts_arg = args.data[1]; + auto &offset_arg = args.data[2]; + + if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (ConstantVector::IsNull(bucket_width_arg)) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); + TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); + switch (bucket_width_type) { + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: + TernaryExecutor::Execute( + bucket_width_arg, 
ts_arg, offset_arg, result, args.size(), + TimeBucket::OffsetWidthConvertibleToMicrosTernaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: + TernaryExecutor::Execute( + bucket_width_arg, ts_arg, offset_arg, result, args.size(), + TimeBucket::OffsetWidthConvertibleToMonthsTernaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::UNCLASSIFIED: + TernaryExecutor::Execute( + bucket_width_arg, ts_arg, offset_arg, result, args.size(), + TimeBucket::OffsetTernaryOperator::Operation); + break; + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + } else { + TernaryExecutor::Execute( + bucket_width_arg, ts_arg, offset_arg, result, args.size(), + TimeBucket::OffsetTernaryOperator::Operation); + } +} + +template +static void TimeBucketOriginFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3); + + auto &bucket_width_arg = args.data[0]; + auto &ts_arg = args.data[1]; + auto &origin_arg = args.data[2]; + + if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR && + origin_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (ConstantVector::IsNull(bucket_width_arg) || ConstantVector::IsNull(origin_arg) || + !Value::IsFinite(*ConstantVector::GetData(origin_arg))) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + } else { + interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); + TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); + switch (bucket_width_type) { + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: + TernaryExecutor::Execute( + bucket_width_arg, ts_arg, origin_arg, result, args.size(), + TimeBucket::OriginWidthConvertibleToMicrosTernaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: + TernaryExecutor::Execute( + bucket_width_arg, ts_arg, origin_arg, result, args.size(), + TimeBucket::OriginWidthConvertibleToMonthsTernaryOperator::Operation); + break; + case TimeBucket::BucketWidthType::UNCLASSIFIED: + TernaryExecutor::ExecuteWithNulls( + bucket_width_arg, ts_arg, origin_arg, result, args.size(), + TimeBucket::OriginTernaryOperator::Operation); + break; + default: + throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); + } + } + } else { + TernaryExecutor::ExecuteWithNulls( + bucket_width_arg, ts_arg, origin_arg, result, args.size(), + TimeBucket::OriginTernaryOperator::Operation); + } +} + +ScalarFunctionSet TimeBucketFun::GetFunctions() { + ScalarFunctionSet time_bucket; + time_bucket.AddFunction( + ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE}, LogicalType::DATE, TimeBucketFunction)); + time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP, + TimeBucketFunction)); + time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE, LogicalType::INTERVAL}, + LogicalType::DATE, TimeBucketOffsetFunction)); + time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, + LogicalType::TIMESTAMP, TimeBucketOffsetFunction)); + time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE, LogicalType::DATE}, + LogicalType::DATE, TimeBucketOriginFunction)); + time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, + LogicalType::TIMESTAMP, 
TimeBucketOriginFunction)); + for (auto &func : time_bucket.functions) { + BaseScalarFunction::SetReturnsError(func); + } + return time_bucket; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp b/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp new file mode 100644 index 00000000..c8d50888 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp @@ -0,0 +1,258 @@ +#include "core_functions/scalar/date_functions.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/multiply.hpp" +#include "duckdb/function/to_interval.hpp" + +namespace duckdb { + +template <> +bool TryMultiplyOperator::Operation(double left, int64_t right, int64_t &result) { + return TryCast::Operation(left * double(right), result); +} + +struct ToMillenniaOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.days = 0; + result.micros = 0; + if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_MILLENIUM, + result.months)) { + throw OutOfRangeException("Interval value %s millennia out of range", NumericHelper::ToString(input)); + } + return result; + } +}; + +struct ToCenturiesOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.days = 0; + result.micros = 0; + if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_CENTURY, result.months)) { + throw OutOfRangeException("Interval value %s centuries out of range", NumericHelper::ToString(input)); + } + return result; + } +}; + +struct ToDecadesOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.days = 0; + result.micros = 0; + if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_DECADE, result.months)) { + throw OutOfRangeException("Interval value %s decades out of range", NumericHelper::ToString(input)); + } + return result; + } +}; + +struct ToYearsOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.days = 0; + result.micros = 0; + if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_YEAR, + result.months)) { + throw OutOfRangeException("Interval value %d years out of range", input); + } + return result; + } +}; + +struct ToQuartersOperator { + template + static inline TR Operation(TA input) { + interval_t result; + if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_QUARTER, + result.months)) { + throw OutOfRangeException("Interval value %d quarters out of range", input); + } + result.days = 0; + result.micros = 0; + return result; + } +}; + +struct ToMonthsOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = input; + result.days = 0; + result.micros = 0; + return result; + } +}; + +struct ToWeeksOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + if (!TryMultiplyOperator::Operation(input, Interval::DAYS_PER_WEEK, result.days)) { + throw OutOfRangeException("Interval value %d weeks out of range", input); + } + result.micros = 0; + return result; + } +}; + +struct ToDaysOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + result.days = input; + result.micros = 0; + return result; + } +}; + +struct ToHoursOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + result.days = 0; + if 
(!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_HOUR, result.micros)) { + throw OutOfRangeException("Interval value %s hours out of range", NumericHelper::ToString(input)); + } + return result; + } +}; + +struct ToMinutesOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + result.days = 0; + if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_MINUTE, result.micros)) { + throw OutOfRangeException("Interval value %s minutes out of range", NumericHelper::ToString(input)); + } + return result; + } +}; + +struct ToMilliSecondsOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + result.days = 0; + if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_MSEC, result.micros)) { + throw OutOfRangeException("Interval value %s milliseconds out of range", NumericHelper::ToString(input)); + } + return result; + } +}; + +struct ToMicroSecondsOperator { + template + static inline TR Operation(TA input) { + interval_t result; + result.months = 0; + result.days = 0; + result.micros = input; + return result; + } +}; + +ScalarFunction ToMillenniaFun::GetFunction() { + ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunction ToCenturiesFun::GetFunction() { + ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunction ToDecadesFun::GetFunction() { + ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunction ToYearsFun::GetFunction() { + ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunction ToQuartersFun::GetFunction() { + ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunction ToMonthsFun::GetFunction() { + ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunction ToWeeksFun::GetFunction() { + ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunction ToDaysFun::GetFunction() { + ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunction ToHoursFun::GetFunction() { + ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunction ToMinutesFun::GetFunction() { + ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunction ToSecondsFun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::INTERVAL, + 
ScalarFunction::UnaryFunction<double, interval_t, ToSecondsOperator>);
+	BaseScalarFunction::SetReturnsError(function);
+	return function;
+}
+
+ScalarFunction ToMillisecondsFun::GetFunction() {
+	ScalarFunction function({LogicalType::DOUBLE}, LogicalType::INTERVAL,
+	                        ScalarFunction::UnaryFunction<double, interval_t, ToMilliSecondsOperator>);
+	BaseScalarFunction::SetReturnsError(function);
+	return function;
+}
+
+ScalarFunction ToMicrosecondsFun::GetFunction() {
+	ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL,
+	                        ScalarFunction::UnaryFunction<int64_t, interval_t, ToMicroSecondsOperator>);
+	BaseScalarFunction::SetReturnsError(function);
+	return function;
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp b/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp
new file mode 100644
index 00000000..627d7ac2
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp
@@ -0,0 +1,24 @@
+#include "core_functions/scalar/debug_functions.hpp"
+
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/planner/expression/bound_function_expression.hpp"
+#include "duckdb/common/enum_util.hpp"
+
+namespace duckdb {
+
+static void VectorTypeFunction(DataChunk &input, ExpressionState &state, Vector &result) {
+	result.SetVectorType(VectorType::CONSTANT_VECTOR);
+	auto data = ConstantVector::GetData<string_t>(result);
+	data[0] = StringVector::AddString(result, EnumUtil::ToString(input.data[0].GetVectorType()));
+}
+
+ScalarFunction VectorTypeFun::GetFunction() {
+	auto vector_type_fun = ScalarFunction("vector_type",        // name of the function
+	                                      {LogicalType::ANY},   // argument list
+	                                      LogicalType::VARCHAR, // return type
+	                                      VectorTypeFunction);
+	vector_type_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+	return vector_type_fun;
+}
+
+} // namespace duckdb
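The vector_type() helper above reports the VectorType of its argument's vector, which is useful when checking whether constant or dictionary vectors actually flow through a plan. A hypothetical usage sketch, assuming the embedded C++ API from duckdb.hpp (the exact VectorType reported can vary by version and plan):

#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	// a column produced by a generator is typically reported as FLAT_VECTOR
	con.Query("SELECT vector_type(i) FROM range(10) t(i)")->Print();
	return 0;
}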
diff --git a/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp b/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp
new file mode 100644
index 00000000..a10ec381
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp
@@ -0,0 +1,164 @@
+#include "core_functions/scalar/enum_functions.hpp"
+
+namespace duckdb {
+
+static void EnumFirstFunction(DataChunk &input, ExpressionState &state, Vector &result) {
+	auto types = input.GetTypes();
+	D_ASSERT(types.size() == 1);
+	auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]);
+	auto val = Value(enum_vector.GetValue(0));
+	result.Reference(val);
+}
+
+static void EnumLastFunction(DataChunk &input, ExpressionState &state, Vector &result) {
+	auto types = input.GetTypes();
+	D_ASSERT(types.size() == 1);
+	auto enum_size = EnumType::GetSize(types[0]);
+	auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]);
+	auto val = Value(enum_vector.GetValue(enum_size - 1));
+	result.Reference(val);
+}
+
+static void EnumRangeFunction(DataChunk &input, ExpressionState &state, Vector &result) {
+	auto types = input.GetTypes();
+	D_ASSERT(types.size() == 1);
+	auto enum_size = EnumType::GetSize(types[0]);
+	auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]);
+	vector<Value> enum_values;
+	for (idx_t i = 0; i < enum_size; i++) {
+		enum_values.emplace_back(enum_vector.GetValue(i));
+	}
+	auto val = Value::LIST(LogicalType::VARCHAR, enum_values);
+	result.Reference(val);
+}
+
+static void EnumRangeBoundaryFunction(DataChunk &input, ExpressionState &state, Vector &result) {
+	auto types = input.GetTypes();
+	D_ASSERT(types.size() == 2);
+	idx_t start, end;
+	auto first_param = input.GetValue(0, 0);
+	auto second_param = input.GetValue(1, 0);
+
+	auto &enum_vector =
+	    first_param.IsNull() ? EnumType::GetValuesInsertOrder(types[1]) : EnumType::GetValuesInsertOrder(types[0]);
+
+	if (first_param.IsNull()) {
+		start = 0;
+	} else {
+		start = first_param.GetValue<uint32_t>();
+	}
+	if (second_param.IsNull()) {
+		end = EnumType::GetSize(types[0]);
+	} else {
+		end = second_param.GetValue<uint32_t>() + 1;
+	}
+	vector<Value> enum_values;
+	for (idx_t i = start; i < end; i++) {
+		enum_values.emplace_back(enum_vector.GetValue(i));
+	}
+	auto val = Value::LIST(LogicalType::VARCHAR, enum_values);
+	result.Reference(val);
+}
+
+static void EnumCodeFunction(DataChunk &input, ExpressionState &state, Vector &result) {
+	D_ASSERT(input.GetTypes().size() == 1);
+	result.Reinterpret(input.data[0]);
+}
+
+static void CheckEnumParameter(const Expression &expr) {
+	if (expr.HasParameter()) {
+		throw ParameterNotResolvedException();
+	}
+}
+
+unique_ptr<FunctionData> BindEnumFunction(ClientContext &context, ScalarFunction &bound_function,
+                                          vector<unique_ptr<Expression>> &arguments) {
+	CheckEnumParameter(*arguments[0]);
+	if (arguments[0]->return_type.id() != LogicalTypeId::ENUM) {
+		throw BinderException("This function needs an ENUM as an argument");
+	}
+	return nullptr;
+}
+
+unique_ptr<FunctionData> BindEnumCodeFunction(ClientContext &context, ScalarFunction &bound_function,
+                                              vector<unique_ptr<Expression>> &arguments) {
+	CheckEnumParameter(*arguments[0]);
+	if (arguments[0]->return_type.id() != LogicalTypeId::ENUM) {
+		throw BinderException("This function needs an ENUM as an argument");
+	}
+
+	auto phy_type = EnumType::GetPhysicalType(arguments[0]->return_type);
+	switch (phy_type) {
+	case PhysicalType::UINT8:
+		bound_function.return_type = LogicalType(LogicalTypeId::UTINYINT);
+		break;
+	case PhysicalType::UINT16:
+		bound_function.return_type = LogicalType(LogicalTypeId::USMALLINT);
+		break;
+	case PhysicalType::UINT32:
+		bound_function.return_type = LogicalType(LogicalTypeId::UINTEGER);
+		break;
+	case PhysicalType::UINT64:
+		bound_function.return_type = LogicalType(LogicalTypeId::UBIGINT);
+		break;
+	default:
+		throw InternalException("Unsupported Enum Internal Type");
+	}
+
+	return nullptr;
+}
+
+unique_ptr<FunctionData> BindEnumRangeBoundaryFunction(ClientContext &context, ScalarFunction &bound_function,
+                                                       vector<unique_ptr<Expression>> &arguments) {
+	CheckEnumParameter(*arguments[0]);
+	CheckEnumParameter(*arguments[1]);
+	if (arguments[0]->return_type.id() != LogicalTypeId::ENUM && arguments[0]->return_type != LogicalType::SQLNULL) {
+		throw BinderException("This function needs an ENUM as an argument");
+	}
+	if (arguments[1]->return_type.id() != LogicalTypeId::ENUM && arguments[1]->return_type != LogicalType::SQLNULL) {
+		throw BinderException("This function needs an ENUM as an argument");
+	}
+	if (arguments[0]->return_type == LogicalType::SQLNULL && arguments[1]->return_type == LogicalType::SQLNULL) {
+		throw BinderException("This function needs an ENUM as an argument");
+	}
+	if (arguments[0]->return_type.id() == LogicalTypeId::ENUM &&
+	    arguments[1]->return_type.id() == LogicalTypeId::ENUM &&
+	    arguments[0]->return_type != arguments[1]->return_type) {
+		throw BinderException("The parameters need to link to ONLY one enum OR be NULL ");
+	}
+	return nullptr;
+}
+
+ScalarFunction EnumFirstFun::GetFunction() {
+	auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumFirstFunction, BindEnumFunction);
+	fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+	return fun;
+}
+
+ScalarFunction EnumLastFun::GetFunction() {
+	auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumLastFunction, BindEnumFunction);
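+	// SPECIAL_HANDLING (set below): these functions inspect the ENUM *type* of the
+	// argument rather than its row values, so a NULL input must not short-circuit
+	// the call before the function runs.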
fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +ScalarFunction EnumCodeFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::ANY, EnumCodeFunction, BindEnumCodeFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +ScalarFunction EnumRangeFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), EnumRangeFunction, + BindEnumFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +ScalarFunction EnumRangeBoundaryFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), + EnumRangeBoundaryFunction, BindEnumRangeBoundaryFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/alias.cpp b/src/duckdb/extension/core_functions/scalar/generic/alias.cpp new file mode 100644 index 00000000..4edadcaa --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/generic/alias.cpp @@ -0,0 +1,18 @@ +#include "core_functions/scalar/generic_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +static void AliasFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + Value v(state.expr.GetAlias().empty() ? func_expr.children[0]->GetName() : state.expr.GetAlias()); + result.Reference(v); +} + +ScalarFunction AliasFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, AliasFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/binning.cpp b/src/duckdb/extension/core_functions/scalar/generic/binning.cpp new file mode 100644 index 00000000..83f7c070 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/generic/binning.cpp @@ -0,0 +1,508 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/hugeint.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/vector_operations/generic_executor.hpp" +#include "core_functions/scalar/generic_functions.hpp" + +namespace duckdb { + +static hugeint_t GetPreviousPowerOfTen(hugeint_t input) { + hugeint_t power_of_ten = 1; + while (power_of_ten < input) { + power_of_ten *= 10; + } + return power_of_ten / 10; +} + +enum class NiceRounding { CEILING, ROUND }; + +hugeint_t RoundToNumber(hugeint_t input, hugeint_t num, NiceRounding rounding) { + if (rounding == NiceRounding::ROUND) { + return (input + (num / 2)) / num * num; + } else { + return (input + (num - 1)) / num * num; + } +} + +hugeint_t MakeNumberNice(hugeint_t input, hugeint_t step, NiceRounding rounding) { + // we consider numbers nice if they are divisible by 2 or 5 times the power-of-ten one lower than the current + // e.g. 120 is a nice number because it is divisible by 20 + // 122 is not a nice number -> we make it nice by turning it into 120 [/20] + // 153 is not a nice number -> we make it nice by turning it into 150 [/50] + // 1220 is not a nice number -> we turn it into 1200 [/200] + // first figure out the previous power of 10 (i.e. 
for 67 we return 10) + // now the power of ten is the power BELOW the current number + // i.e. for 67, it is not 10 + // now we can get the 2 or 5 divisors + hugeint_t power_of_ten = GetPreviousPowerOfTen(step); + hugeint_t two = power_of_ten * 2; + hugeint_t five = power_of_ten; + if (power_of_ten * 3 <= step) { + two *= 5; + } + if (power_of_ten * 2 <= step) { + five *= 5; + } + + // compute the closest round number by adding the divisor / 2 and truncating + // do this for both divisors + hugeint_t round_to_two = RoundToNumber(input, two, rounding); + hugeint_t round_to_five = RoundToNumber(input, five, rounding); + // now pick the closest number of the two (i.e. for 147 we pick 150, not 140) + if (AbsValue(input - round_to_two) < AbsValue(input - round_to_five)) { + return round_to_two; + } else { + return round_to_five; + } +} + +static double GetPreviousPowerOfTen(double input) { + double power_of_ten = 1; + if (input < 1) { + while (power_of_ten > input) { + power_of_ten /= 10; + } + return power_of_ten; + } + while (power_of_ten < input) { + power_of_ten *= 10; + } + return power_of_ten / 10; +} + +double RoundToNumber(double input, double num, NiceRounding rounding) { + double result; + if (rounding == NiceRounding::ROUND) { + result = std::round(input / num) * num; + } else { + result = std::ceil(input / num) * num; + } + if (!Value::IsFinite(result)) { + return input; + } + return result; +} + +double MakeNumberNice(double input, const double step, NiceRounding rounding) { + if (input == 0) { + return 0; + } + // now the power of ten is the power BELOW the current number + // i.e. for 67, it is not 10 + // now we can get the 2 or 5 divisors + double power_of_ten = GetPreviousPowerOfTen(step); + double two = power_of_ten * 2; + double five = power_of_ten; + if (power_of_ten * 3 <= step) { + two *= 5; + } + if (power_of_ten * 2 <= step) { + five *= 5; + } + + double round_to_two = RoundToNumber(input, two, rounding); + double round_to_five = RoundToNumber(input, five, rounding); + // now pick the closest number of the two (i.e. 
for 147 we pick 150, not 140) + if (AbsValue(input - round_to_two) < AbsValue(input - round_to_five)) { + return round_to_two; + } else { + return round_to_five; + } +} + +struct EquiWidthBinsInteger { + static constexpr LogicalTypeId LOGICAL_TYPE = LogicalTypeId::BIGINT; + + static vector> Operation(const Expression &expr, int64_t input_min, int64_t input_max, + idx_t bin_count, bool nice_rounding) { + vector> result; + // to prevent integer truncation from affecting the bin boundaries we calculate them with numbers multiplied by + // 1000 we then divide to get the actual boundaries + const auto FACTOR = hugeint_t(1000); + auto min = hugeint_t(input_min) * FACTOR; + auto max = hugeint_t(input_max) * FACTOR; + + const hugeint_t span = max - min; + hugeint_t step = span / Hugeint::Convert(bin_count); + if (nice_rounding) { + // when doing nice rounding we try to make the max/step values nicer + hugeint_t new_step = MakeNumberNice(step, step, NiceRounding::ROUND); + hugeint_t new_max = RoundToNumber(max, new_step, NiceRounding::CEILING); + if (new_max != min && new_step != 0) { + max = new_max; + step = new_step; + } + // we allow for more bins when doing nice rounding since the bin count is approximate + bin_count *= 2; + } + for (hugeint_t bin_boundary = max; bin_boundary > min; bin_boundary -= step) { + const hugeint_t target_boundary = bin_boundary / FACTOR; + int64_t real_boundary = Hugeint::Cast(target_boundary); + if (!result.empty()) { + if (real_boundary < input_min || result.size() >= bin_count) { + // we can never generate input_min + break; + } + if (real_boundary == result.back().val) { + // we cannot generate the same value multiple times in a row - skip this step + continue; + } + } + result.push_back(real_boundary); + } + return result; + } +}; + +struct EquiWidthBinsDouble { + static constexpr LogicalTypeId LOGICAL_TYPE = LogicalTypeId::DOUBLE; + + static vector> Operation(const Expression &expr, double min, double input_max, + idx_t bin_count, bool nice_rounding) { + double max = input_max; + if (!Value::IsFinite(min) || !Value::IsFinite(max)) { + throw InvalidInputException("equi_width_bucket does not support infinite or nan as min/max value"); + } + vector> result; + const double span = max - min; + double step; + if (!Value::IsFinite(span)) { + // max - min does not fit + step = max / static_cast(bin_count) - min / static_cast(bin_count); + } else { + step = span / static_cast(bin_count); + } + const double step_power_of_ten = GetPreviousPowerOfTen(step); + if (nice_rounding) { + // when doing nice rounding we try to make the max/step values nicer + step = MakeNumberNice(step, step, NiceRounding::ROUND); + max = RoundToNumber(input_max, step, NiceRounding::CEILING); + // we allow for more bins when doing nice rounding since the bin count is approximate + bin_count *= 2; + } + if (step == 0) { + throw InternalException("step is 0!?"); + } + + const double round_multiplication = 10 / step_power_of_ten; + for (double bin_boundary = max; bin_boundary > min; bin_boundary -= step) { + // because floating point addition adds inaccuracies, we add rounding at every step + double real_boundary = bin_boundary; + if (nice_rounding) { + real_boundary = std::round(bin_boundary * round_multiplication) / round_multiplication; + } + if (!result.empty() && result.back().val == real_boundary) { + // skip this step + continue; + } + if (real_boundary <= min || result.size() >= bin_count) { + // we can never generate below input_min + break; + } + result.push_back(real_boundary); + } + 
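+		// the boundaries above were generated downward from max, so they are in
+		// descending order here; EquiWidthBinFunction clamps the first (largest)
+		// boundary to the input max and then reverses them into ascending order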
return result; + } +}; + +void NextMonth(int32_t &year, int32_t &month) { + month++; + if (month == 13) { + year++; + month = 1; + } +} + +void NextDay(int32_t &year, int32_t &month, int32_t &day) { + day++; + if (!Date::IsValid(year, month, day)) { + // day is out of range for month, move to next month + NextMonth(year, month); + day = 1; + } +} + +void NextHour(int32_t &year, int32_t &month, int32_t &day, int32_t &hour) { + hour++; + if (hour >= 24) { + NextDay(year, month, day); + hour = 0; + } +} + +void NextMinute(int32_t &year, int32_t &month, int32_t &day, int32_t &hour, int32_t &minute) { + minute++; + if (minute >= 60) { + NextHour(year, month, day, hour); + minute = 0; + } +} + +void NextSecond(int32_t &year, int32_t &month, int32_t &day, int32_t &hour, int32_t &minute, int32_t &sec) { + sec++; + if (sec >= 60) { + NextMinute(year, month, day, hour, minute); + sec = 0; + } +} + +timestamp_t MakeTimestampNice(int32_t year, int32_t month, int32_t day, int32_t hour, int32_t minute, int32_t sec, + int32_t micros, interval_t step) { + // how to make a timestamp nice depends on the step + if (step.months >= 12) { + // if the step involves one year or more, ceil to months + // set time component to 00:00:00.00 + if (day > 1 || hour > 0 || minute > 0 || sec > 0 || micros > 0) { + // move to next month + NextMonth(year, month); + hour = minute = sec = micros = 0; + day = 1; + } + } else if (step.months > 0 || step.days >= 1) { + // if the step involves a day or more, ceil to days + if (hour > 0 || minute > 0 || sec > 0 || micros > 0) { + NextDay(year, month, day); + hour = minute = sec = micros = 0; + } + } else if (step.days > 0 || step.micros >= Interval::MICROS_PER_HOUR) { + // if the step involves an hour or more, ceil to hours + if (minute > 0 || sec > 0 || micros > 0) { + NextHour(year, month, day, hour); + minute = sec = micros = 0; + } + } else if (step.micros >= Interval::MICROS_PER_MINUTE) { + // if the step involves a minute or more, ceil to minutes + if (sec > 0 || micros > 0) { + NextMinute(year, month, day, hour, minute); + sec = micros = 0; + } + } else if (step.micros >= Interval::MICROS_PER_SEC) { + // if the step involves a second or more, ceil to seconds + if (micros > 0) { + NextSecond(year, month, day, hour, minute, sec); + micros = 0; + } + } + return Timestamp::FromDatetime(Date::FromDate(year, month, day), Time::FromTime(hour, minute, sec, micros)); +} + +int64_t RoundNumberToDivisor(int64_t number, int64_t divisor) { + return (number + (divisor / 2)) / divisor * divisor; +} + +interval_t MakeIntervalNice(interval_t interval) { + if (interval.months >= 6) { + // if we have 6 months or more, we don't care about days + interval.days = 0; + interval.micros = 0; + } else if (interval.months > 0 || interval.days >= 5) { + // if we have any months or 5 days or more, we don't care about micros + interval.micros = 0; + } else if (interval.days > 0 || interval.micros >= 6 * Interval::MICROS_PER_HOUR) { + // if we have any days or 6 hours or more, we want micros to be roundable by hours at least + interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_HOUR); + } else if (interval.micros >= Interval::MICROS_PER_HOUR) { + // if we have an hour or more, we want micros to be divisible by quarter hours + interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_MINUTE * 15); + } else if (interval.micros >= Interval::MICROS_PER_MINUTE * 10) { + // if we have 10 minutes or more, we want micros to be divisible by minutes + 
interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_MINUTE); + } else if (interval.micros >= Interval::MICROS_PER_MINUTE) { + // if we have a minute or more, we want micros to be divisible by quarter minutes + interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_SEC * 15); + } else if (interval.micros >= Interval::MICROS_PER_SEC * 10) { + // if we have 10 seconds or more, we want micros to be divisible by seconds + interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_SEC); + } + return interval; +} + +void GetTimestampComponents(timestamp_t input, int32_t &year, int32_t &month, int32_t &day, int32_t &hour, + int32_t &minute, int32_t &sec, int32_t &micros) { + date_t date; + dtime_t time; + + Timestamp::Convert(input, date, time); + Date::Convert(date, year, month, day); + Time::Convert(time, hour, minute, sec, micros); +} + +struct EquiWidthBinsTimestamp { + static constexpr LogicalTypeId LOGICAL_TYPE = LogicalTypeId::TIMESTAMP; + + static vector<PrimitiveType<timestamp_t>> Operation(const Expression &expr, timestamp_t input_min, + timestamp_t input_max, idx_t bin_count, bool nice_rounding) { + if (!Value::IsFinite(input_min) || !Value::IsFinite(input_max)) { + throw InvalidInputException(expr, "equi_width_bucket does not support infinite or nan as min/max value"); + } + + if (!nice_rounding) { + // if we are not doing nice rounding it is pretty simple - just interpolate between the timestamp values + auto interpolated_values = + EquiWidthBinsInteger::Operation(expr, input_min.value, input_max.value, bin_count, false); + + vector<PrimitiveType<timestamp_t>> result; + for (auto &val : interpolated_values) { + result.push_back(timestamp_t(val.val)); + } + return result; + } + // fetch the components of the timestamps + int32_t min_year, min_month, min_day, min_hour, min_minute, min_sec, min_micros; + int32_t max_year, max_month, max_day, max_hour, max_minute, max_sec, max_micros; + GetTimestampComponents(input_min, min_year, min_month, min_day, min_hour, min_minute, min_sec, min_micros); + GetTimestampComponents(input_max, max_year, max_month, max_day, max_hour, max_minute, max_sec, max_micros); + + // get the interval differences per component + // note: these can be negative (except for the largest non-zero difference) + interval_t interval_diff; + interval_diff.months = (max_year - min_year) * Interval::MONTHS_PER_YEAR + (max_month - min_month); + interval_diff.days = max_day - min_day; + interval_diff.micros = (max_hour - min_hour) * Interval::MICROS_PER_HOUR + + (max_minute - min_minute) * Interval::MICROS_PER_MINUTE + + (max_sec - min_sec) * Interval::MICROS_PER_SEC + (max_micros - min_micros); + + double step_months = static_cast<double>(interval_diff.months) / static_cast<double>(bin_count); + double step_days = static_cast<double>(interval_diff.days) / static_cast<double>(bin_count); + double step_micros = static_cast<double>(interval_diff.micros) / static_cast<double>(bin_count); + // since we truncate the months/days, propagate any fractional component to the unit below (i.e. 
0.2 months + // becomes 6 days) + if (step_months > 0) { + double overflow_months = step_months - std::floor(step_months); + step_days += overflow_months * Interval::DAYS_PER_MONTH; + } + if (step_days > 0) { + double overflow_days = step_days - std::floor(step_days); + step_micros += overflow_days * Interval::MICROS_PER_DAY; + } + interval_t step; + step.months = static_cast<int32_t>(step_months); + step.days = static_cast<int32_t>(step_days); + step.micros = static_cast<int64_t>(step_micros); + + // now we make the max and the step nice + step = MakeIntervalNice(step); + timestamp_t timestamp_val = + MakeTimestampNice(max_year, max_month, max_day, max_hour, max_minute, max_sec, max_micros, step); + if (step.months <= 0 && step.days <= 0 && step.micros <= 0) { + // interval must be at least one microsecond + step.months = step.days = 0; + step.micros = 1; + } + + vector<PrimitiveType<timestamp_t>> result; + while (timestamp_val.value >= input_min.value && result.size() < bin_count) { + result.push_back(timestamp_val); + timestamp_val = SubtractOperator::Operation<timestamp_t, interval_t, timestamp_t>(timestamp_val, step); + } + return result; + } +};
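+ +// Note on the strategy above: bin edges are generated downward from the niced max, one niced step at a +// time, stopping at the true min or once bin_count edges exist. EquiWidthBinFunction (below) then snaps +// the topmost edge back up to the exact max and reverses the edges into ascending order, so the final +// bin boundary always equals the input max.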
+ +unique_ptr<FunctionData> BindEquiWidthFunction(ClientContext &, ScalarFunction &bound_function, + vector<unique_ptr<Expression>> &arguments) { + // while internally the bins are computed over a unified type + // the equi_width_bins function returns the same type as the input MAX + LogicalType child_type; + switch (arguments[1]->return_type.id()) { + case LogicalTypeId::UNKNOWN: + case LogicalTypeId::SQLNULL: + return nullptr; + case LogicalTypeId::DECIMAL: + // for decimals we promote to double because the bin boundaries are computed over doubles + child_type = LogicalType::DOUBLE; + break; + default: + child_type = arguments[1]->return_type; + break; + } + bound_function.return_type = LogicalType::LIST(child_type); + return nullptr; +} + +template <class T, class OP> +static void EquiWidthBinFunction(DataChunk &args, ExpressionState &state, Vector &result) { + static constexpr int64_t MAX_BIN_COUNT = 1000000; + auto &min_arg = args.data[0]; + auto &max_arg = args.data[1]; + auto &bin_count = args.data[2]; + auto &nice_rounding = args.data[3]; + + Vector intermediate_result(LogicalType::LIST(OP::LOGICAL_TYPE)); + GenericExecutor::ExecuteQuaternary<PrimitiveType<T>, PrimitiveType<T>, PrimitiveType<int64_t>, PrimitiveType<bool>, + GenericListType<PrimitiveType<T>>>( + min_arg, max_arg, bin_count, nice_rounding, intermediate_result, args.size(), + [&](PrimitiveType<T> min_p, PrimitiveType<T> max_p, PrimitiveType<int64_t> bins_p, + PrimitiveType<bool> nice_rounding_p) { + if (max_p.val < min_p.val) { + throw InvalidInputException(state.expr, + "Invalid input for bin function - max value is smaller than min value"); + } + if (bins_p.val <= 0) { + throw InvalidInputException(state.expr, "Invalid input for bin function - there must be > 0 bins"); + } + if (bins_p.val > MAX_BIN_COUNT) { + throw InvalidInputException(state.expr, "Invalid input for bin function - max bin count of %d exceeded", + MAX_BIN_COUNT); + } + GenericListType<PrimitiveType<T>> result_bins; + if (max_p.val == min_p.val) { + // if max = min return a single bucket + result_bins.values.push_back(max_p.val); + } else { + result_bins.values = OP::Operation(state.expr, min_p.val, max_p.val, static_cast<idx_t>(bins_p.val), + nice_rounding_p.val); + // last bin should always be the input max + if (result_bins.values[0].val < max_p.val) { + result_bins.values[0].val = max_p.val; + } + std::reverse(result_bins.values.begin(), result_bins.values.end()); + } + return result_bins; + }); + VectorOperations::DefaultCast(intermediate_result, result, args.size()); +} + +static void UnsupportedEquiWidth(DataChunk &args, ExpressionState &state, Vector &) { + throw BinderException(state.expr, "Unsupported type \"%s\" for equi_width_bins", args.data[0].GetType()); +} + +void EquiWidthBinSerialize(Serializer &, const optional_ptr<FunctionData>, const ScalarFunction &) { + return; +} + +unique_ptr<FunctionData> EquiWidthBinDeserialize(Deserializer &deserializer, ScalarFunction &function) { + function.return_type = deserializer.Get<LogicalType>(); + return nullptr; +} + +ScalarFunctionSet EquiWidthBinsFun::GetFunctions() { + ScalarFunctionSet functions("equi_width_bins"); + functions.AddFunction( + ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BOOLEAN}, + LogicalType::LIST(LogicalType::ANY), EquiWidthBinFunction<int64_t, EquiWidthBinsInteger>, + BindEquiWidthFunction)); + functions.AddFunction(ScalarFunction( + {LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::BIGINT, LogicalType::BOOLEAN}, + LogicalType::LIST(LogicalType::ANY), EquiWidthBinFunction<double, EquiWidthBinsDouble>, BindEquiWidthFunction)); + functions.AddFunction( + ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::BIGINT, LogicalType::BOOLEAN}, + LogicalType::LIST(LogicalType::ANY), EquiWidthBinFunction<timestamp_t, EquiWidthBinsTimestamp>, + BindEquiWidthFunction)); + functions.AddFunction( + ScalarFunction({LogicalType::ANY_PARAMS(LogicalType::ANY, 150), LogicalType::ANY_PARAMS(LogicalType::ANY, 150), + LogicalType::BIGINT, LogicalType::BOOLEAN}, + LogicalType::LIST(LogicalType::ANY), UnsupportedEquiWidth, BindEquiWidthFunction)); + for (auto &function : functions.functions) { + function.serialize = EquiWidthBinSerialize; + function.deserialize = EquiWidthBinDeserialize; + BaseScalarFunction::SetReturnsError(function); + } + return functions; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp b/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp new file mode 100644 index 00000000..5db38d60 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp @@ -0,0 +1,40 @@ +#include "core_functions/scalar/generic_functions.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/function/cast_rules.hpp" + +namespace duckdb { + +bool CanCastImplicitly(ClientContext &context, const LogicalType &source, const LogicalType &target) { + return CastFunctionSet::Get(context).ImplicitCastCost(source, target) >= 0; +} + +static void CanCastImplicitlyFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &context = state.GetContext(); + bool can_cast_implicitly = CanCastImplicitly(context, args.data[0].GetType(), args.data[1].GetType()); + auto v = Value::BOOLEAN(can_cast_implicitly); + result.Reference(v); +} + +unique_ptr<Expression> BindCanCastImplicitlyExpression(FunctionBindExpressionInput &input) { + auto &source_type = input.function.children[0]->return_type; + auto &target_type = input.function.children[1]->return_type; + if (source_type.id() == LogicalTypeId::UNKNOWN || source_type.id() == LogicalTypeId::SQLNULL || + target_type.id() == LogicalTypeId::UNKNOWN || target_type.id() == LogicalTypeId::SQLNULL) { + // parameter - unknown return type + return nullptr; + } + // emit a constant expression + return make_uniq<BoundConstantExpression>( + Value::BOOLEAN(CanCastImplicitly(input.context, source_type, target_type))); +} + +ScalarFunction CanCastImplicitlyFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::BOOLEAN, CanCastImplicitlyFunction);
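+ // SPECIAL_HANDLING disables the default "any NULL argument yields NULL" folding, so NULL and parameter + // inputs still reach the function and the bind_expression above, which handle them explicitly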
fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.bind_expression = BindCanCastImplicitlyExpression; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp new file mode 100644 index 00000000..b983b27c --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp @@ -0,0 +1,68 @@ +#include "core_functions/scalar/generic_functions.hpp" + +#include "duckdb/main/database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/catalog/catalog.hpp" +namespace duckdb { + +struct CurrentSettingBindData : public FunctionData { + explicit CurrentSettingBindData(Value value_p) : value(std::move(value_p)) { + } + + Value value; + +public: + unique_ptr<FunctionData> Copy() const override { + return make_uniq<CurrentSettingBindData>(value); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast<CurrentSettingBindData>(); + return Value::NotDistinctFrom(value, other.value); + } +}; + +static void CurrentSettingFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast<BoundFunctionExpression>(); + auto &info = func_expr.bind_info->Cast<CurrentSettingBindData>(); + result.Reference(info.value); +} + +unique_ptr<FunctionData> CurrentSettingBind(ClientContext &context, ScalarFunction &bound_function, + vector<unique_ptr<Expression>> &arguments) { + + auto &key_child = arguments[0]; + if (key_child->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { + throw ParserException("Key name for current_setting needs to be a constant string"); + } + Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child); + D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); + if (key_val.IsNull() || StringValue::Get(key_val).empty()) { + throw ParserException("Key name for current_setting needs to be neither NULL nor empty"); + } + + auto key = StringUtil::Lower(StringValue::Get(key_val)); + Value val; + if (!context.TryGetCurrentSetting(key, val)) { + Catalog::AutoloadExtensionByConfigName(context, key); + // If autoloader didn't throw, the config is now available + context.TryGetCurrentSetting(key, val); + } + + bound_function.return_type = val.type(); + return make_uniq<CurrentSettingBindData>(val); +} + +ScalarFunction CurrentSettingFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::VARCHAR}, LogicalType::ANY, CurrentSettingFunction, CurrentSettingBind); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/hash.cpp b/src/duckdb/extension/core_functions/scalar/generic/hash.cpp new file mode 100644 index 00000000..18491944 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/generic/hash.cpp @@ -0,0 +1,19 @@ +#include "core_functions/scalar/generic_functions.hpp" + +namespace duckdb { + +static void HashFunction(DataChunk &args, ExpressionState &state, Vector &result) { + args.Hash(result); + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +ScalarFunction HashFun::GetFunction() { + auto hash_fun = ScalarFunction({LogicalType::ANY}, LogicalType::HASH, HashFunction); + hash_fun.varargs = LogicalType::ANY;
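+ // varargs = ANY lets hash() take any number of arguments of any type; DataChunk::Hash combines all + // columns into a single 64-bit hash value per row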
hash_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return hash_fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/least.cpp b/src/duckdb/extension/core_functions/scalar/generic/least.cpp new file mode 100644 index 00000000..40a94310 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/generic/least.cpp @@ -0,0 +1,259 @@ +#include "duckdb/common/operator/comparison_operators.hpp" +#include "core_functions/scalar/generic_functions.hpp" +#include "duckdb/function/create_sort_key.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +struct LeastOp { + using OP = LessThan; + + static OrderByNullType NullOrdering() { + return OrderByNullType::NULLS_LAST; + } +}; + +struct GreaterOp { + using OP = GreaterThan; + + static OrderByNullType NullOrdering() { + return OrderByNullType::NULLS_FIRST; + } +}; + +template +struct LeastOperator { + template + static T Operation(T left, T right) { + return OP::Operation(left, right) ? left : right; + } +}; + +struct LeastGreatestSortKeyState : public FunctionLocalState { + explicit LeastGreatestSortKeyState(idx_t column_count, OrderByNullType null_ordering) + : intermediate(LogicalType::BLOB), modifiers(OrderType::ASCENDING, null_ordering) { + vector types; + // initialize sort key chunk + for (idx_t i = 0; i < column_count; i++) { + types.push_back(LogicalType::BLOB); + } + sort_keys.Initialize(Allocator::DefaultAllocator(), types); + } + + DataChunk sort_keys; + Vector intermediate; + OrderModifiers modifiers; +}; + +template +unique_ptr LeastGreatestSortKeyInit(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data) { + return make_uniq(expr.children.size(), OP::NullOrdering()); +} + +template +struct StandardLeastGreatest { + static constexpr bool IS_STRING = STRING; + + static DataChunk &Prepare(DataChunk &args, ExpressionState &) { + return args; + } + + static Vector &TargetVector(Vector &result, ExpressionState &) { + return result; + } + + static void FinalizeResult(idx_t rows, bool result_has_value[], Vector &result, ExpressionState &) { + auto &result_mask = FlatVector::Validity(result); + for (idx_t i = 0; i < rows; i++) { + if (!result_has_value[i]) { + result_mask.SetInvalid(i); + } + } + } +}; + +struct SortKeyLeastGreatest { + static constexpr bool IS_STRING = false; + + static DataChunk &Prepare(DataChunk &args, ExpressionState &state) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + lstate.sort_keys.Reset(); + for (idx_t c_idx = 0; c_idx < args.ColumnCount(); c_idx++) { + CreateSortKeyHelpers::CreateSortKey(args.data[c_idx], args.size(), lstate.modifiers, + lstate.sort_keys.data[c_idx]); + } + lstate.sort_keys.SetCardinality(args.size()); + return lstate.sort_keys; + } + + static Vector &TargetVector(Vector &result, ExpressionState &state) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + return lstate.intermediate; + } + + static void FinalizeResult(idx_t rows, bool result_has_value[], Vector &result, ExpressionState &state) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + auto result_keys = FlatVector::GetData(lstate.intermediate); + auto &result_mask = FlatVector::Validity(result); + for (idx_t i = 0; i < rows; i++) { + if (!result_has_value[i]) { + result_mask.SetInvalid(i); + } else { + CreateSortKeyHelpers::DecodeSortKey(result_keys[i], result, i, lstate.modifiers); + } + } + } +}; + +template > +static void 
LeastGreatestFunction(DataChunk &args, ExpressionState &state, Vector &result) { + if (args.ColumnCount() == 1) { + // single input: nop + result.Reference(args.data[0]); + return; + } + auto &input = BASE_OP::Prepare(args, state); + auto &result_vector = BASE_OP::TargetVector(result, state); + + auto result_type = VectorType::CONSTANT_VECTOR; + for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { + if (args.data[col_idx].GetVectorType() != VectorType::CONSTANT_VECTOR) { + // non-constant input: result is not a constant vector + result_type = VectorType::FLAT_VECTOR; + } + if (BASE_OP::IS_STRING) { + // for string vectors we add a reference to the heap of the children + StringVector::AddHeapReference(result_vector, input.data[col_idx]); + } + } + + auto result_data = FlatVector::GetData(result_vector); + bool result_has_value[STANDARD_VECTOR_SIZE] {false}; + // perform the operation column-by-column + for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { + if (input.data[col_idx].GetVectorType() == VectorType::CONSTANT_VECTOR && + ConstantVector::IsNull(input.data[col_idx])) { + // ignore null vector + continue; + } + + UnifiedVectorFormat vdata; + input.data[col_idx].ToUnifiedFormat(input.size(), vdata); + + auto input_data = UnifiedVectorFormat::GetData(vdata); + if (!vdata.validity.AllValid()) { + // potential new null entries: have to check the null mask + for (idx_t i = 0; i < input.size(); i++) { + auto vindex = vdata.sel->get_index(i); + if (vdata.validity.RowIsValid(vindex)) { + // not a null entry: perform the operation and add to new set + auto ivalue = input_data[vindex]; + if (!result_has_value[i] || OP::template Operation(ivalue, result_data[i])) { + result_has_value[i] = true; + result_data[i] = ivalue; + } + } + } + } else { + // no new null entries: only need to perform the operation + for (idx_t i = 0; i < input.size(); i++) { + auto vindex = vdata.sel->get_index(i); + + auto ivalue = input_data[vindex]; + if (!result_has_value[i] || OP::template Operation(ivalue, result_data[i])) { + result_has_value[i] = true; + result_data[i] = ivalue; + } + } + } + } + BASE_OP::FinalizeResult(input.size(), result_has_value, result, state); + result.SetVectorType(result_type); +} + +template +unique_ptr BindLeastGreatest(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + LogicalType child_type = ExpressionBinder::GetExpressionReturnType(*arguments[0]); + for (idx_t i = 1; i < arguments.size(); i++) { + auto arg_type = ExpressionBinder::GetExpressionReturnType(*arguments[i]); + if (!LogicalType::TryGetMaxLogicalType(context, child_type, arg_type, child_type)) { + throw BinderException(arguments[i]->GetQueryLocation(), + "Cannot combine types of %s and %s - an explicit cast is required", + child_type.ToString(), arg_type.ToString()); + } + } + switch (child_type.id()) { + case LogicalTypeId::UNKNOWN: + throw ParameterNotResolvedException(); + case LogicalTypeId::INTEGER_LITERAL: + child_type = IntegerLiteral::GetType(child_type); + break; + case LogicalTypeId::STRING_LITERAL: + child_type = LogicalType::VARCHAR; + break; + default: + break; + } + using OP = typename LEAST_GREATER_OP::OP; + switch (child_type.InternalType()) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::BOOL: + case PhysicalType::INT8: + bound_function.function = LeastGreatestFunction; + break; + case PhysicalType::INT16: + bound_function.function = LeastGreatestFunction; + break; + case PhysicalType::INT32: + bound_function.function = 
LeastGreatestFunction; + break; + case PhysicalType::INT64: + bound_function.function = LeastGreatestFunction; + break; + case PhysicalType::INT128: + bound_function.function = LeastGreatestFunction; + break; + case PhysicalType::DOUBLE: + bound_function.function = LeastGreatestFunction; + break; + case PhysicalType::VARCHAR: + bound_function.function = LeastGreatestFunction>; + break; +#endif + default: + // fallback with sort keys + bound_function.function = LeastGreatestFunction; + bound_function.init_local_state = LeastGreatestSortKeyInit; + break; + } + bound_function.arguments[0] = child_type; + bound_function.varargs = child_type; + bound_function.return_type = child_type; + return nullptr; +} + +template +ScalarFunction GetLeastGreatestFunction() { + return ScalarFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, BindLeastGreatest, nullptr, nullptr, + nullptr, LogicalType::ANY, FunctionStability::CONSISTENT, + FunctionNullHandling::SPECIAL_HANDLING); +} + +template +static ScalarFunctionSet GetLeastGreatestFunctions() { + ScalarFunctionSet fun_set; + fun_set.AddFunction(GetLeastGreatestFunction()); + return fun_set; +} + +ScalarFunctionSet LeastFun::GetFunctions() { + return GetLeastGreatestFunctions(); +} + +ScalarFunctionSet GreatestFun::GetFunctions() { + return GetLeastGreatestFunctions(); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/stats.cpp b/src/duckdb/extension/core_functions/scalar/generic/stats.cpp new file mode 100644 index 00000000..ad3f4cd0 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/generic/stats.cpp @@ -0,0 +1,54 @@ +#include "core_functions/scalar/generic_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +struct StatsBindData : public FunctionData { + explicit StatsBindData(string stats_p = string()) : stats(std::move(stats_p)) { + } + + string stats; + +public: + unique_ptr Copy() const override { + return make_uniq(stats); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return stats == other.stats; + } +}; + +static void StatsFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + if (info.stats.empty()) { + info.stats = "No statistics"; + } + Value v(info.stats); + result.Reference(v); +} + +unique_ptr StatsBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + return make_uniq(); +} + +static unique_ptr StatsPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &bind_data = input.bind_data; + auto &info = bind_data->Cast(); + info.stats = child_stats[0].ToString(); + return nullptr; +} + +ScalarFunction StatsFun::GetFunction() { + ScalarFunction stats({LogicalType::ANY}, LogicalType::VARCHAR, StatsFunction, StatsBind, nullptr, + StatsPropagateStats); + stats.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + stats.stability = FunctionStability::VOLATILE; + return stats; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp b/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp new file mode 100644 index 00000000..5e4251c0 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp @@ -0,0 +1,148 @@ +#include "duckdb/catalog/catalog_search_path.hpp" +#include 
"core_functions/scalar/generic_functions.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +// current_query +static void CurrentQueryFunction(DataChunk &input, ExpressionState &state, Vector &result) { + Value val(state.GetContext().GetCurrentQuery()); + result.Reference(val); +} + +// current_schema +static void CurrentSchemaFunction(DataChunk &input, ExpressionState &state, Vector &result) { + Value val(ClientData::Get(state.GetContext()).catalog_search_path->GetDefault().schema); + result.Reference(val); +} + +// current_database +static void CurrentDatabaseFunction(DataChunk &input, ExpressionState &state, Vector &result) { + Value val(DatabaseManager::GetDefaultDatabase(state.GetContext())); + result.Reference(val); +} + +struct CurrentSchemasBindData : public FunctionData { + explicit CurrentSchemasBindData(Value result_value) : result(std::move(result_value)) { + } + + Value result; + +public: + unique_ptr Copy() const override { + return make_uniq(result); + } + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return Value::NotDistinctFrom(result, other.result); + } +}; + +static unique_ptr CurrentSchemasBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments[0]->return_type.id() != LogicalTypeId::BOOLEAN) { + throw BinderException("current_schemas requires a boolean input"); + } + if (!arguments[0]->IsFoldable()) { + throw NotImplementedException("current_schemas requires a constant input"); + } + Value schema_value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); + Value result_val; + if (schema_value.IsNull()) { + // null + result_val = Value(LogicalType::LIST(LogicalType::VARCHAR)); + } else { + auto implicit_schemas = BooleanValue::Get(schema_value); + vector schema_list; + auto &catalog_search_path = ClientData::Get(context).catalog_search_path; + auto &search_path = implicit_schemas ? 
catalog_search_path->Get() : catalog_search_path->GetSetPaths(); + std::transform(search_path.begin(), search_path.end(), std::back_inserter(schema_list), + [](const CatalogSearchEntry &s) -> Value { return Value(s.schema); }); + result_val = Value::LIST(LogicalType::VARCHAR, schema_list); + } + return make_uniq(std::move(result_val)); +} + +// current_schemas +static void CurrentSchemasFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + result.Reference(info.result); +} + +// in_search_path +static void InSearchPathFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto &context = state.GetContext(); + auto &search_path = ClientData::Get(context).catalog_search_path; + BinaryExecutor::Execute( + input.data[0], input.data[1], result, input.size(), [&](string_t db_name, string_t schema_name) { + return search_path->SchemaInSearchPath(context, db_name.GetString(), schema_name.GetString()); + }); +} + +// txid_current +static void TransactionIdCurrent(DataChunk &input, ExpressionState &state, Vector &result) { + auto &context = state.GetContext(); + auto &catalog = Catalog::GetCatalog(context, DatabaseManager::GetDefaultDatabase(context)); + auto &transaction = DuckTransaction::Get(context, catalog); + auto val = Value::UBIGINT(transaction.start_time); + result.Reference(val); +} + +// version +static void VersionFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto val = Value(DuckDB::LibraryVersion()); + result.Reference(val); +} + +ScalarFunction CurrentQueryFun::GetFunction() { + ScalarFunction current_query({}, LogicalType::VARCHAR, CurrentQueryFunction); + current_query.stability = FunctionStability::VOLATILE; + return current_query; +} + +ScalarFunction CurrentSchemaFun::GetFunction() { + ScalarFunction current_schema({}, LogicalType::VARCHAR, CurrentSchemaFunction); + current_schema.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + return current_schema; +} + +ScalarFunction CurrentDatabaseFun::GetFunction() { + ScalarFunction current_database({}, LogicalType::VARCHAR, CurrentDatabaseFunction); + current_database.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + return current_database; +} + +ScalarFunction CurrentSchemasFun::GetFunction() { + auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); + ScalarFunction current_schemas({LogicalType::BOOLEAN}, varchar_list_type, CurrentSchemasFunction, + CurrentSchemasBind); + current_schemas.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + return current_schemas; +} + +ScalarFunction InSearchPathFun::GetFunction() { + ScalarFunction in_search_path({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, + InSearchPathFunction); + in_search_path.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + return in_search_path; +} + +ScalarFunction CurrentTransactionIdFun::GetFunction() { + ScalarFunction txid_current({}, LogicalType::UBIGINT, TransactionIdCurrent); + txid_current.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + return txid_current; +} + +ScalarFunction VersionFun::GetFunction() { + return ScalarFunction({}, LogicalType::VARCHAR, VersionFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp b/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp new file mode 100644 index 00000000..1f7caef8 --- /dev/null +++ 
b/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp @@ -0,0 +1,29 @@ +#include "core_functions/scalar/generic_functions.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +static void TypeOfFunction(DataChunk &args, ExpressionState &state, Vector &result) { + Value v(args.data[0].GetType().ToString()); + result.Reference(v); +} + +unique_ptr BindTypeOfFunctionExpression(FunctionBindExpressionInput &input) { + auto &return_type = input.function.children[0]->return_type; + if (return_type.id() == LogicalTypeId::UNKNOWN || return_type.id() == LogicalTypeId::SQLNULL) { + // parameter - unknown return type + return nullptr; + } + // emit a constant expression + return make_uniq(Value(return_type.ToString())); +} + +ScalarFunction TypeOfFun::GetFunction() { + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, TypeOfFunction); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.bind_expression = BindTypeOfFunctionExpression; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp b/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp new file mode 100644 index 00000000..0962b3c2 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp @@ -0,0 +1,460 @@ +#include "core_functions/scalar/list_functions.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/swap.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/function/scalar/string_common.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" + +namespace duckdb { + +struct ListSliceBindData : public FunctionData { + ListSliceBindData(const LogicalType &return_type_p, bool begin_is_empty_p, bool end_is_empty_p) + : return_type(return_type_p), begin_is_empty(begin_is_empty_p), end_is_empty(end_is_empty_p) { + } + ~ListSliceBindData() override; + + LogicalType return_type; + + bool begin_is_empty; + bool end_is_empty; + +public: + bool Equals(const FunctionData &other_p) const override; + unique_ptr Copy() const override; +}; + +ListSliceBindData::~ListSliceBindData() { +} + +bool ListSliceBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return return_type == other.return_type && begin_is_empty == other.begin_is_empty && + end_is_empty == other.end_is_empty; +} + +unique_ptr ListSliceBindData::Copy() const { + return make_uniq(return_type, begin_is_empty, end_is_empty); +} + +template +static idx_t CalculateSliceLength(idx_t begin, idx_t end, INDEX_TYPE step, bool svalid) { + if (step < 0) { + step = AbsValue(step); + } + if (step == 0 && svalid) { + throw InvalidInputException("Slice step cannot be zero"); + } + if (step == 1) { + return NumericCast(end - begin); + } else if (static_cast(step) >= (end - begin)) { + return 1; + } + if ((end - begin) % UnsafeNumericCast(step) != 0) { + return (end - begin) / UnsafeNumericCast(step) + 1; + } + return (end - begin) / UnsafeNumericCast(step); +} + +struct BlobSliceOperations { + static int64_t ValueLength(const string_t &value) { + return UnsafeNumericCast(value.GetSize()); + } + + static string_t SliceValue(Vector 
&result, string_t input, int64_t begin, int64_t end) { + return SubstringASCII(result, input, begin + 1, end - begin); + } + + static string_t SliceValueWithSteps(Vector &result, SelectionVector &sel, string_t input, int64_t begin, + int64_t end, int64_t step, idx_t &sel_idx) { + throw InternalException("Slicing with steps is not supported for strings"); + } +}; + +struct StringSliceOperations { + static int64_t ValueLength(const string_t &value) { + return Length(value); + } + + static string_t SliceValue(Vector &result, string_t input, int64_t begin, int64_t end) { + return SubstringUnicode(result, input, begin + 1, end - begin); + } + + static string_t SliceValueWithSteps(Vector &result, SelectionVector &sel, string_t input, int64_t begin, + int64_t end, int64_t step, idx_t &sel_idx) { + throw InternalException("Slicing with steps is not supported for strings"); + } +}; + +struct ListSliceOperations { + static int64_t ValueLength(const list_entry_t &value) { + return UnsafeNumericCast(value.length); + } + + static list_entry_t SliceValue(Vector &result, list_entry_t input, int64_t begin, int64_t end) { + input.offset = UnsafeNumericCast(UnsafeNumericCast(input.offset) + begin); + input.length = UnsafeNumericCast(end - begin); + return input; + } + + static list_entry_t SliceValueWithSteps(Vector &result, SelectionVector &sel, list_entry_t input, int64_t begin, + int64_t end, int64_t step, idx_t &sel_idx) { + if (end - begin == 0) { + input.length = 0; + input.offset = sel_idx; + return input; + } + input.length = CalculateSliceLength(UnsafeNumericCast(begin), UnsafeNumericCast(end), step, true); + idx_t child_idx = input.offset + UnsafeNumericCast(begin); + if (step < 0) { + child_idx = input.offset + UnsafeNumericCast(end) - 1; + } + input.offset = sel_idx; + for (idx_t i = 0; i < input.length; i++) { + sel.set_index(sel_idx, child_idx); + child_idx += static_cast(step); // intentional overflow?? + sel_idx++; + } + return input; + } +}; + +template +static void ClampIndex(INDEX_TYPE &index, const INPUT_TYPE &value, const INDEX_TYPE length, bool is_min) { + if (index < 0) { + index = (!is_min) ? index + 1 : index; + index = length + index; + return; + } else if (index > length) { + index = length; + } + return; +} + +template +static bool ClampSlice(const INPUT_TYPE &value, INDEX_TYPE &begin, INDEX_TYPE &end) { + // Clamp offsets + begin = (begin != 0 && begin != (INDEX_TYPE)NumericLimits::Minimum()) ? 
begin - 1 : begin; + + bool is_min = false; + if (begin == (INDEX_TYPE)NumericLimits::Minimum()) { + begin++; + is_min = true; + } + + const auto length = OP::ValueLength(value); + if (begin < 0 && -begin > length && end < 0 && end < -length) { + begin = 0; + end = 0; + return true; + } + if (begin < 0 && -begin > length) { + begin = 0; + } + ClampIndex(begin, value, length, is_min); + ClampIndex(end, value, length, false); + end = MaxValue(begin, end); + + return true; +} + +template +static void ExecuteConstantSlice(Vector &result, Vector &str_vector, Vector &begin_vector, Vector &end_vector, + optional_ptr step_vector, const idx_t count, SelectionVector &sel, + idx_t &sel_idx, optional_ptr result_child_vector, bool begin_is_empty, + bool end_is_empty) { + + // check all this nullness early + auto str_valid = !ConstantVector::IsNull(str_vector); + auto begin_valid = !ConstantVector::IsNull(begin_vector); + auto end_valid = !ConstantVector::IsNull(end_vector); + auto step_valid = step_vector && !ConstantVector::IsNull(*step_vector); + + if (!str_valid || !begin_valid || !end_valid || (step_vector && !step_valid)) { + ConstantVector::SetNull(result, true); + return; + } + + auto result_data = ConstantVector::GetData(result); + auto str_data = ConstantVector::GetData(str_vector); + auto begin_data = ConstantVector::GetData(begin_vector); + auto end_data = ConstantVector::GetData(end_vector); + auto step_data = step_vector ? ConstantVector::GetData(*step_vector) : nullptr; + + auto str = str_data[0]; + auto begin = begin_is_empty ? 0 : begin_data[0]; + auto end = end_is_empty ? OP::ValueLength(str) : end_data[0]; + auto step = step_data ? step_data[0] : 1; + + if (step < 0) { + swap(begin, end); + begin = end_is_empty ? 0 : begin; + end = begin_is_empty ? 
OP::ValueLength(str) : end; + } + + // Clamp offsets + bool clamp_result = false; + if (step_valid || step == 1) { + clamp_result = ClampSlice(str, begin, end); + } + + idx_t sel_length = 0; + bool sel_valid = false; + if (step_valid && step != 1 && end - begin > 0) { + sel_length = + CalculateSliceLength(UnsafeNumericCast(begin), UnsafeNumericCast(end), step, step_valid); + sel.Initialize(sel_length); + sel_valid = true; + } + + // Try to slice + if (!clamp_result) { + ConstantVector::SetNull(result, true); + } else if (step == 1) { + result_data[0] = OP::SliceValue(result, str, begin, end); + } else { + result_data[0] = OP::SliceValueWithSteps(result, sel, str, begin, end, step, sel_idx); + } + + if (sel_valid) { + result_child_vector->Slice(sel, sel_length); + result_child_vector->Flatten(sel_length); + ListVector::SetListSize(result, sel_length); + } +} + +template +static void ExecuteFlatSlice(Vector &result, Vector &list_vector, Vector &begin_vector, Vector &end_vector, + optional_ptr step_vector, const idx_t count, SelectionVector &sel, idx_t &sel_idx, + optional_ptr result_child_vector, bool begin_is_empty, bool end_is_empty) { + UnifiedVectorFormat list_data, begin_data, end_data, step_data; + idx_t sel_length = 0; + + list_vector.ToUnifiedFormat(count, list_data); + begin_vector.ToUnifiedFormat(count, begin_data); + end_vector.ToUnifiedFormat(count, end_data); + if (step_vector) { + step_vector->ToUnifiedFormat(count, step_data); + sel.Initialize(ListVector::GetListSize(list_vector)); + } + + auto result_data = FlatVector::GetData(result); + auto &result_mask = FlatVector::Validity(result); + + for (idx_t i = 0; i < count; ++i) { + auto list_idx = list_data.sel->get_index(i); + auto begin_idx = begin_data.sel->get_index(i); + auto end_idx = end_data.sel->get_index(i); + auto step_idx = step_vector ? step_data.sel->get_index(i) : 0; + + auto list_valid = list_data.validity.RowIsValid(list_idx); + auto begin_valid = begin_data.validity.RowIsValid(begin_idx); + auto end_valid = end_data.validity.RowIsValid(end_idx); + auto step_valid = step_vector && step_data.validity.RowIsValid(step_idx); + + if (!list_valid || !begin_valid || !end_valid || (step_vector && !step_valid)) { + result_mask.SetInvalid(i); + continue; + } + + auto sliced = reinterpret_cast(list_data.data)[list_idx]; + auto begin = begin_is_empty ? 0 : reinterpret_cast(begin_data.data)[begin_idx]; + auto end = end_is_empty ? OP::ValueLength(sliced) : reinterpret_cast(end_data.data)[end_idx]; + auto step = step_vector ? reinterpret_cast(step_data.data)[step_idx] : 1; + + if (step < 0) { + swap(begin, end); + begin = end_is_empty ? 0 : begin; + end = begin_is_empty ? 
OP::ValueLength(sliced) : end; + } + + bool clamp_result = false; + if (step_valid || step == 1) { + clamp_result = ClampSlice(sliced, begin, end); + } + + idx_t length = 0; + if (end - begin > 0) { + length = + CalculateSliceLength(UnsafeNumericCast(begin), UnsafeNumericCast(end), step, step_valid); + } + sel_length += length; + + if (!clamp_result) { + result_mask.SetInvalid(i); + } else if (!step_vector) { + result_data[i] = OP::SliceValue(result, sliced, begin, end); + } else { + result_data[i] = OP::SliceValueWithSteps(result, sel, sliced, begin, end, step, sel_idx); + } + } + if (step_vector) { + SelectionVector new_sel(sel_length); + for (idx_t i = 0; i < sel_length; ++i) { + new_sel.set_index(i, sel.get_index(i)); + } + result_child_vector->Slice(new_sel, sel_length); + result_child_vector->Flatten(sel_length); + ListVector::SetListSize(result, sel_length); + } +} + +template +static void ExecuteSlice(Vector &result, Vector &list_or_str_vector, Vector &begin_vector, Vector &end_vector, + optional_ptr step_vector, const idx_t count, bool begin_is_empty, bool end_is_empty) { + optional_ptr result_child_vector; + if (step_vector) { + result_child_vector = &ListVector::GetEntry(result); + } + + SelectionVector sel; + idx_t sel_idx = 0; + + if (result.GetVectorType() == VectorType::CONSTANT_VECTOR) { + ExecuteConstantSlice(result, list_or_str_vector, begin_vector, end_vector, + step_vector, count, sel, sel_idx, result_child_vector, + begin_is_empty, end_is_empty); + } else { + ExecuteFlatSlice(result, list_or_str_vector, begin_vector, end_vector, step_vector, + count, sel, sel_idx, result_child_vector, begin_is_empty, + end_is_empty); + } + result.Verify(count); +} + +static void ArraySliceFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3 || args.ColumnCount() == 4); + D_ASSERT(args.data.size() == 3 || args.data.size() == 4); + auto count = args.size(); + + Vector &list_or_str_vector = result; + // this ensures that we do not change the input chunk + VectorOperations::Copy(args.data[0], list_or_str_vector, count, 0, 0); + + if (list_or_str_vector.GetType().id() == LogicalTypeId::SQLNULL) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + + Vector &begin_vector = args.data[1]; + Vector &end_vector = args.data[2]; + + optional_ptr step_vector; + if (args.ColumnCount() == 4) { + step_vector = &args.data[3]; + } + + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto begin_is_empty = info.begin_is_empty; + auto end_is_empty = info.end_is_empty; + + result.SetVectorType(args.AllConstant() ? 
VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); + switch (result.GetType().id()) { + case LogicalTypeId::LIST: { + // Share the value dictionary as we are just going to slice it + if (list_or_str_vector.GetVectorType() != VectorType::FLAT_VECTOR && + list_or_str_vector.GetVectorType() != VectorType::CONSTANT_VECTOR) { + list_or_str_vector.Flatten(count); + } + ExecuteSlice<list_entry_t, int64_t, ListSliceOperations>(result, list_or_str_vector, begin_vector, end_vector, + step_vector, count, begin_is_empty, end_is_empty); + break; + } + case LogicalTypeId::BLOB: + ExecuteSlice<string_t, int64_t, BlobSliceOperations>(result, list_or_str_vector, begin_vector, end_vector, + step_vector, count, begin_is_empty, end_is_empty); + break; + case LogicalTypeId::VARCHAR: + ExecuteSlice<string_t, int64_t, StringSliceOperations>(result, list_or_str_vector, begin_vector, end_vector, + step_vector, count, begin_is_empty, end_is_empty); + break; + default: + throw NotImplementedException("Specifier type not implemented"); + } +} + +static bool CheckIfParamIsEmpty(duckdb::unique_ptr<Expression> &param) { + bool is_empty = false; + if (param->return_type.id() == LogicalTypeId::LIST) { + auto empty_list = make_uniq<BoundConstantExpression>(Value::LIST(LogicalType::INTEGER, vector<Value>())); + is_empty = param->Equals(*empty_list); + if (!is_empty) { + // if the param is not empty, the user has entered a list instead of a BIGINT + throw BinderException("The upper and lower bounds of the slice must be a BIGINT"); + } + } + return is_empty; +} + +static unique_ptr<FunctionData> ArraySliceBind(ClientContext &context, ScalarFunction &bound_function, + vector<unique_ptr<Expression>> &arguments) { + D_ASSERT(arguments.size() == 3 || arguments.size() == 4); + D_ASSERT(bound_function.arguments.size() == 3 || bound_function.arguments.size() == 4); + + switch (arguments[0]->return_type.id()) { + case LogicalTypeId::ARRAY: { + // Cast to list + auto child_type = ArrayType::GetChildType(arguments[0]->return_type); + auto target_type = LogicalType::LIST(child_type); + arguments[0] = BoundCastExpression::AddCastToType(context, std::move(arguments[0]), target_type); + bound_function.return_type = arguments[0]->return_type; + } break; + case LogicalTypeId::LIST: + // The result is the same type + bound_function.return_type = arguments[0]->return_type; + break; + case LogicalTypeId::BLOB: + case LogicalTypeId::VARCHAR: + // string slice returns a string + if (bound_function.arguments.size() == 4) { + throw NotImplementedException( + "Slice with steps has not been implemented for string types, you can consider rewriting your query as " + "follows:\n SELECT array_to_string(str_split(string, '')[begin:end:step], '');"); + } + bound_function.return_type = arguments[0]->return_type; + for (idx_t i = 1; i < 3; i++) { + if (arguments[i]->return_type.id() != LogicalTypeId::LIST) { + bound_function.arguments[i] = LogicalType::BIGINT; + } + } + break; + case LogicalTypeId::SQLNULL: + case LogicalTypeId::UNKNOWN: + bound_function.arguments[0] = LogicalTypeId::UNKNOWN; + bound_function.return_type = LogicalType::SQLNULL; + break; + default: + throw BinderException("ARRAY_SLICE can only operate on LISTs and VARCHARs"); + } + + bool begin_is_empty = CheckIfParamIsEmpty(arguments[1]); + if (!begin_is_empty) { + bound_function.arguments[1] = LogicalType::BIGINT; + } + bool end_is_empty = CheckIfParamIsEmpty(arguments[2]); + if (!end_is_empty) { + bound_function.arguments[2] = LogicalType::BIGINT; + } + + return make_uniq<ListSliceBindData>(bound_function.return_type, begin_is_empty, end_is_empty); +} + +ScalarFunctionSet ListSliceFun::GetFunctions() { + // the arguments and return types are actually set in the binder function + ScalarFunction 
fun({LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, ArraySliceFunction, + ArraySliceBind); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + BaseScalarFunction::SetReturnsError(fun); + ScalarFunctionSet set; + set.AddFunction(fun); + fun.arguments.push_back(LogicalType::BIGINT); + set.AddFunction(fun); + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/flatten.cpp b/src/duckdb/extension/core_functions/scalar/list/flatten.cpp new file mode 100644 index 00000000..849c20d1 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/list/flatten.cpp @@ -0,0 +1,171 @@ +#include "core_functions/scalar/list_functions.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/storage/statistics/list_stats.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" + +namespace duckdb { + +void ListFlattenFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + + Vector &input = args.data[0]; + if (input.GetType().id() == LogicalTypeId::SQLNULL) { + result.Reference(input); + return; + } + + idx_t count = args.size(); + + // Prepare the result vector + result.SetVectorType(VectorType::FLAT_VECTOR); + // This holds the new offsets and lengths + auto result_entries = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + + // The outermost list in each row + UnifiedVectorFormat row_data; + input.ToUnifiedFormat(count, row_data); + auto row_entries = UnifiedVectorFormat::GetData(row_data); + + // The list elements in each row: [HERE, ...] + auto &row_lists = ListVector::GetEntry(input); + UnifiedVectorFormat row_lists_data; + idx_t total_row_lists = ListVector::GetListSize(input); + row_lists.ToUnifiedFormat(total_row_lists, row_lists_data); + auto row_lists_entries = UnifiedVectorFormat::GetData(row_lists_data); + + if (row_lists.GetType().id() == LogicalTypeId::SQLNULL) { + for (idx_t row_cnt = 0; row_cnt < count; row_cnt++) { + auto row_idx = row_data.sel->get_index(row_cnt); + if (!row_data.validity.RowIsValid(row_idx)) { + result_validity.SetInvalid(row_cnt); + continue; + } + result_entries[row_cnt].offset = 0; + result_entries[row_cnt].length = 0; + } + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + return; + } + + // The actual elements inside each row list: [[HERE, ...], []] + // This one becomes the child vector of the result. + auto &elem_vector = ListVector::GetEntry(row_lists); + + // We'll use this selection vector to slice the elem_vector. + idx_t child_elem_cnt = ListVector::GetListSize(row_lists); + SelectionVector sel(child_elem_cnt); + idx_t sel_idx = 0; + + // HERE, [[]], ... + for (idx_t row_cnt = 0; row_cnt < count; row_cnt++) { + auto row_idx = row_data.sel->get_index(row_cnt); + + if (!row_data.validity.RowIsValid(row_idx)) { + result_validity.SetInvalid(row_cnt); + continue; + } + + idx_t list_offset = sel_idx; + idx_t list_length = 0; + + // [HERE, [...], ...] 
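+ // the row entry below delimits this row's inner lists; every valid inner list contributes its element + // offsets to the selection vector, which later slices elem_vector into the flattened child vector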
+ auto row_entry = row_entries[row_idx]; + for (idx_t row_lists_cnt = 0; row_lists_cnt < row_entry.length; row_lists_cnt++) { + auto row_lists_idx = row_lists_data.sel->get_index(row_entry.offset + row_lists_cnt); + + // Skip invalid lists + if (!row_lists_data.validity.RowIsValid(row_lists_idx)) { + continue; + } + + // [[HERE, ...], [.., ...]] + auto list_entry = row_lists_entries[row_lists_idx]; + list_length += list_entry.length; + + for (idx_t elem_cnt = 0; elem_cnt < list_entry.length; elem_cnt++) { + // offset of the element in the elem_vector. + idx_t offset = list_entry.offset + elem_cnt; + sel.set_index(sel_idx, offset); + sel_idx++; + } + } + + result_entries[row_cnt].offset = list_offset; + result_entries[row_cnt].length = list_length; + } + + ListVector::SetListSize(result, sel_idx); + + auto &result_child_vector = ListVector::GetEntry(result); + result_child_vector.Slice(elem_vector, sel, sel_idx); + result_child_vector.Flatten(sel_idx); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr ListFlattenBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(bound_function.arguments.size() == 1); + + if (arguments[0]->return_type.id() == LogicalTypeId::ARRAY) { + auto child_type = ArrayType::GetChildType(arguments[0]->return_type); + if (child_type.id() == LogicalTypeId::ARRAY) { + child_type = LogicalType::LIST(ArrayType::GetChildType(child_type)); + } + arguments[0] = + BoundCastExpression::AddCastToType(context, std::move(arguments[0]), LogicalType::LIST(child_type)); + } else if (arguments[0]->return_type.id() == LogicalTypeId::LIST) { + auto child_type = ListType::GetChildType(arguments[0]->return_type); + if (child_type.id() == LogicalTypeId::ARRAY) { + child_type = LogicalType::LIST(ArrayType::GetChildType(child_type)); + arguments[0] = + BoundCastExpression::AddCastToType(context, std::move(arguments[0]), LogicalType::LIST(child_type)); + } + } + + auto &input_type = arguments[0]->return_type; + bound_function.arguments[0] = input_type; + if (input_type.id() == LogicalTypeId::UNKNOWN) { + bound_function.arguments[0] = LogicalType(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + D_ASSERT(input_type.id() == LogicalTypeId::LIST); + + auto child_type = ListType::GetChildType(input_type); + if (child_type.id() == LogicalType::SQLNULL) { + bound_function.return_type = input_type; + return make_uniq(bound_function.return_type); + } + if (child_type.id() == LogicalTypeId::UNKNOWN) { + bound_function.arguments[0] = LogicalType(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + D_ASSERT(child_type.id() == LogicalTypeId::LIST); + + bound_function.return_type = child_type; + return make_uniq(bound_function.return_type); +} + +static unique_ptr ListFlattenStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &list_child_stats = ListStats::GetChildStats(child_stats[0]); + auto child_copy = list_child_stats.Copy(); + child_copy.Set(StatsInfo::CAN_HAVE_NULL_VALUES); + return child_copy.ToUnique(); +} + +ScalarFunction ListFlattenFun::GetFunction() { + return ScalarFunction({LogicalType::LIST(LogicalType::LIST(LogicalType::ANY))}, LogicalType::LIST(LogicalType::ANY), + ListFlattenFunction, ListFlattenBind, nullptr, ListFlattenStats); +} + +} // namespace duckdb diff --git 
a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp new file mode 100644 index 00000000..1b2aab71 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp @@ -0,0 +1,534 @@ +#include "core_functions/scalar/list_functions.hpp" +#include "core_functions/aggregate/nested_functions.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/function/create_sort_key.hpp" +#include "duckdb/common/owning_string_map.hpp" + +namespace duckdb { + +// FIXME: use a local state for each thread to increase performance? +// FIXME: benchmark the use of simple_update against using update (if applicable) + +static unique_ptr ListAggregatesBindFailure(ScalarFunction &bound_function) { + bound_function.arguments[0] = LogicalType::SQLNULL; + bound_function.return_type = LogicalType::SQLNULL; + return make_uniq(LogicalType::SQLNULL); +} + +struct ListAggregatesBindData : public FunctionData { + ListAggregatesBindData(const LogicalType &stype_p, unique_ptr aggr_expr_p); + ~ListAggregatesBindData() override; + + LogicalType stype; + unique_ptr aggr_expr; + + unique_ptr Copy() const override { + return make_uniq(stype, aggr_expr->Copy()); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return stype == other.stype && aggr_expr->Equals(*other.aggr_expr); + } + void Serialize(Serializer &serializer) const { + serializer.WriteProperty(1, "stype", stype); + serializer.WriteProperty(2, "aggr_expr", aggr_expr); + } + static unique_ptr Deserialize(Deserializer &deserializer) { + auto stype = deserializer.ReadProperty(1, "stype"); + auto aggr_expr = deserializer.ReadProperty>(2, "aggr_expr"); + auto result = make_uniq(std::move(stype), std::move(aggr_expr)); + return result; + } + + static void SerializeFunction(Serializer &serializer, const optional_ptr bind_data_p, + const ScalarFunction &function) { + auto bind_data = dynamic_cast(bind_data_p.get()); + serializer.WritePropertyWithDefault(100, "bind_data", bind_data, (const ListAggregatesBindData *)nullptr); + } + + static unique_ptr DeserializeFunction(Deserializer &deserializer, ScalarFunction &bound_function) { + auto result = deserializer.ReadPropertyWithExplicitDefault>( + 100, "bind_data", unique_ptr(nullptr)); + if (!result) { + return ListAggregatesBindFailure(bound_function); + } + return std::move(result); + } +}; + +ListAggregatesBindData::ListAggregatesBindData(const LogicalType &stype_p, unique_ptr aggr_expr_p) + : stype(stype_p), aggr_expr(std::move(aggr_expr_p)) { +} + +ListAggregatesBindData::~ListAggregatesBindData() { +} + +struct StateVector { + StateVector(idx_t count_p, unique_ptr aggr_expr_p) + : count(count_p), aggr_expr(std::move(aggr_expr_p)), state_vector(Vector(LogicalType::POINTER, count_p)) { + } + + ~StateVector() { // NOLINT + // destroy objects within the aggregate states + auto &aggr = aggr_expr->Cast(); + if 
(aggr.function.destructor) { + ArenaAllocator allocator(Allocator::DefaultAllocator()); + AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); + aggr.function.destructor(state_vector, aggr_input_data, count); + } + } + + idx_t count; + unique_ptr aggr_expr; + Vector state_vector; +}; + +struct FinalizeValueFunctor { + template + static void HistogramFinalize(T value, Vector &result, idx_t offset) { + FlatVector::GetData(result)[offset] = value; + } +}; + +struct FinalizeStringValueFunctor { + template + static void HistogramFinalize(T value, Vector &result, idx_t offset) { + FlatVector::GetData(result)[offset] = StringVector::AddStringOrBlob(result, value); + } +}; + +struct FinalizeGenericValueFunctor { + template + static void HistogramFinalize(T value, Vector &result, idx_t offset) { + CreateSortKeyHelpers::DecodeSortKey(value, result, offset, + OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); + } +}; + +struct AggregateFunctor { + template > + static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { + } +}; + +struct DistinctFunctor { + template > + static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states = UnifiedVectorFormat::GetData *>(sdata); + + auto old_len = ListVector::GetListSize(result); + idx_t new_entries = 0; + // figure out how much space we need + for (idx_t i = 0; i < count; i++) { + auto &state = *states[sdata.sel->get_index(i)]; + if (!state.hist) { + continue; + } + new_entries += state.hist->size(); + } + // reserve space in the list vector + ListVector::Reserve(result, old_len + new_entries); + auto &child_elements = ListVector::GetEntry(result); + auto list_entries = FlatVector::GetData(result); + + idx_t current_offset = old_len; + for (idx_t i = 0; i < count; i++) { + const auto rid = i; + auto &state = *states[sdata.sel->get_index(i)]; + auto &list_entry = list_entries[rid]; + list_entry.offset = current_offset; + if (!state.hist) { + list_entry.length = 0; + continue; + } + + for (auto &entry : *state.hist) { + OP::template HistogramFinalize(entry.first, child_elements, current_offset); + current_offset++; + } + list_entry.length = current_offset - list_entry.offset; + } + D_ASSERT(current_offset == old_len + new_entries); + ListVector::SetListSize(result, current_offset); + result.Verify(count); + } +}; + +struct UniqueFunctor { + template > + static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { + UnifiedVectorFormat sdata; + state_vector.ToUnifiedFormat(count, sdata); + auto states = UnifiedVectorFormat::GetData *>(sdata); + + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + + auto state = states[sdata.sel->get_index(i)]; + + if (!state->hist) { + result_data[i] = 0; + continue; + } + result_data[i] = state->hist->size(); + } + result.Verify(count); + } +}; + +template +static void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto count = args.size(); + Vector &lists = args.data[0]; + + // set the result vector + result.SetVectorType(VectorType::FLAT_VECTOR); + auto &result_validity = FlatVector::Validity(result); + + if (lists.GetType().id() == LogicalTypeId::SQLNULL) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + + // get the aggregate function + auto &func_expr = state.expr.Cast(); + auto &info = 
func_expr.bind_info->Cast(); + auto &aggr = info.aggr_expr->Cast(); + ArenaAllocator allocator(Allocator::DefaultAllocator()); + AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); + + D_ASSERT(aggr.function.update); + + auto lists_size = ListVector::GetListSize(lists); + auto &child_vector = ListVector::GetEntry(lists); + child_vector.Flatten(lists_size); + + UnifiedVectorFormat child_data; + child_vector.ToUnifiedFormat(lists_size, child_data); + + UnifiedVectorFormat lists_data; + lists.ToUnifiedFormat(count, lists_data); + auto list_entries = UnifiedVectorFormat::GetData(lists_data); + + // state_buffer holds the state for each list of this chunk + idx_t size = aggr.function.state_size(aggr.function); + auto state_buffer = make_unsafe_uniq_array_uninitialized(size * count); + + // state vector for initialize and finalize + StateVector state_vector(count, info.aggr_expr->Copy()); + auto states = FlatVector::GetData(state_vector.state_vector); + + // state vector of STANDARD_VECTOR_SIZE holds the pointers to the states + Vector state_vector_update = Vector(LogicalType::POINTER); + auto states_update = FlatVector::GetData(state_vector_update); + + // selection vector pointing to the data + SelectionVector sel_vector(STANDARD_VECTOR_SIZE); + idx_t states_idx = 0; + + for (idx_t i = 0; i < count; i++) { + + // initialize the state for this list + auto state_ptr = state_buffer.get() + size * i; + states[i] = state_ptr; + aggr.function.initialize(aggr.function, states[i]); + + auto lists_index = lists_data.sel->get_index(i); + const auto &list_entry = list_entries[lists_index]; + + // nothing to do for this list + if (!lists_data.validity.RowIsValid(lists_index)) { + result_validity.SetInvalid(i); + continue; + } + + // skip empty list + if (list_entry.length == 0) { + continue; + } + + for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { + // states vector is full, update + if (states_idx == STANDARD_VECTOR_SIZE) { + // update the aggregate state(s) + Vector slice(child_vector, sel_vector, states_idx); + aggr.function.update(&slice, aggr_input_data, 1, state_vector_update, states_idx); + + // reset values + states_idx = 0; + } + + auto source_idx = child_data.sel->get_index(list_entry.offset + child_idx); + sel_vector.set_index(states_idx, source_idx); + states_update[states_idx] = state_ptr; + states_idx++; + } + } + + // update the remaining elements of the last list(s) + if (states_idx != 0) { + Vector slice(child_vector, sel_vector, states_idx); + aggr.function.update(&slice, aggr_input_data, 1, state_vector_update, states_idx); + } + + if (IS_AGGR) { + // finalize all the aggregate states + aggr.function.finalize(state_vector.state_vector, aggr_input_data, result, count, 0); + + } else { + // finalize manually to use the map + D_ASSERT(aggr.function.arguments.size() == 1); + auto key_type = aggr.function.arguments[0]; + + switch (key_type.InternalType()) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::BOOL: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::UINT8: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::UINT16: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::UINT32: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::UINT64: + 
FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::INT8: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::INT16: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::INT32: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::INT64: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::FLOAT: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::DOUBLE: + FUNCTION_FUNCTOR::template ListExecuteFunction( + result, state_vector.state_vector, count); + break; + case PhysicalType::VARCHAR: + FUNCTION_FUNCTOR::template ListExecuteFunction>(result, state_vector.state_vector, + count); + break; +#endif + default: + FUNCTION_FUNCTOR::template ListExecuteFunction>(result, state_vector.state_vector, + count); + break; + } + } + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static void ListAggregateFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() >= 2); + ListAggregatesFunction(args, state, result); +} + +static void ListDistinctFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + ListAggregatesFunction(args, state, result); +} + +static void ListUniqueFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + ListAggregatesFunction(args, state, result); +} + +template +static unique_ptr +ListAggregatesBindFunction(ClientContext &context, ScalarFunction &bound_function, const LogicalType &list_child_type, + AggregateFunction &aggr_function, vector> &arguments) { + + // create the child expression and its type + vector> children; + auto expr = make_uniq(Value(list_child_type)); + children.push_back(std::move(expr)); + // push any extra arguments into the list aggregate bind + if (arguments.size() > 2) { + for (idx_t i = 2; i < arguments.size(); i++) { + children.push_back(std::move(arguments[i])); + } + arguments.resize(2); + } + + FunctionBinder function_binder(context); + auto bound_aggr_function = function_binder.BindAggregateFunction(aggr_function, std::move(children)); + bound_function.arguments[0] = LogicalType::LIST(bound_aggr_function->function.arguments[0]); + + if (IS_AGGR) { + bound_function.return_type = bound_aggr_function->function.return_type; + } + // check if the aggregate function consumed all the extra input arguments + if (bound_aggr_function->children.size() > 1) { + throw InvalidInputException( + "Aggregate function %s is not supported for list_aggr: extra arguments were not removed during bind", + bound_aggr_function->ToString()); + } + + return make_uniq(bound_function.return_type, std::move(bound_aggr_function)); +} + +template +static unique_ptr ListAggregatesBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + + if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) { + return ListAggregatesBindFailure(bound_function); + } + + bool is_parameter = arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN; + LogicalType 
child_type; + if (is_parameter) { + child_type = LogicalType::ANY; + } else if (arguments[0]->return_type.id() == LogicalTypeId::LIST || + arguments[0]->return_type.id() == LogicalTypeId::MAP) { + child_type = ListType::GetChildType(arguments[0]->return_type); + } else { + // Unreachable + throw InvalidInputException("First argument of list aggregate must be a list, map or array"); + } + + string function_name = "histogram"; + if (IS_AGGR) { // get the name of the aggregate function + if (!arguments[1]->IsFoldable()) { + throw InvalidInputException("Aggregate function name must be a constant"); + } + // get the function name + Value function_value = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + function_name = function_value.ToString(); + } + + // look up the aggregate function in the catalog + auto &func = Catalog::GetSystemCatalog(context).GetEntry(context, DEFAULT_SCHEMA, + function_name); + D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); + + if (is_parameter) { + bound_function.arguments[0] = LogicalTypeId::UNKNOWN; + bound_function.return_type = LogicalType::SQLNULL; + return nullptr; + } + + // find a matching aggregate function + ErrorData error; + vector types; + types.push_back(child_type); + // push any extra arguments into the type list + for (idx_t i = 2; i < arguments.size(); i++) { + types.push_back(arguments[i]->return_type); + } + + FunctionBinder function_binder(context); + auto best_function_idx = function_binder.BindFunction(func.name, func.functions, types, error); + if (!best_function_idx.IsValid()) { + throw BinderException("No matching aggregate function\n%s", error.Message()); + } + + // found a matching function, bind it as an aggregate + auto best_function = func.functions.GetFunctionByOffset(best_function_idx.GetIndex()); + if (IS_AGGR) { + bound_function.errors = best_function.errors; + return ListAggregatesBindFunction(context, bound_function, child_type, best_function, arguments); + } + + // create the unordered map histogram function + D_ASSERT(best_function.arguments.size() == 1); + auto aggr_function = HistogramFun::GetHistogramUnorderedMap(child_type); + return ListAggregatesBindFunction(context, bound_function, child_type, aggr_function, arguments); +} + +static unique_ptr ListAggregateBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + // the list column and the name of the aggregate function + D_ASSERT(bound_function.arguments.size() >= 2); + D_ASSERT(arguments.size() >= 2); + + return ListAggregatesBind(context, bound_function, arguments); +} + +static unique_ptr ListDistinctBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + D_ASSERT(bound_function.arguments.size() == 1); + D_ASSERT(arguments.size() == 1); + + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + bound_function.return_type = arguments[0]->return_type; + + return ListAggregatesBind<>(context, bound_function, arguments); +} + +static unique_ptr ListUniqueBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + D_ASSERT(bound_function.arguments.size() == 1); + D_ASSERT(arguments.size() == 1); + bound_function.return_type = LogicalType::UBIGINT; + + return ListAggregatesBind<>(context, bound_function, arguments); +} + +ScalarFunction ListAggregateFun::GetFunction() { + auto result = ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, LogicalType::ANY, + ListAggregateFunction, 
ListAggregateBind); + BaseScalarFunction::SetReturnsError(result); + result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + result.varargs = LogicalType::ANY; + result.serialize = ListAggregatesBindData::SerializeFunction; + result.deserialize = ListAggregatesBindData::DeserializeFunction; + return result; +} + +ScalarFunction ListDistinctFun::GetFunction() { + return ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), + ListDistinctFunction, ListDistinctBind); +} + +ScalarFunction ListUniqueFun::GetFunction() { + return ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::UBIGINT, ListUniqueFunction, + ListUniqueBind); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp b/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp new file mode 100644 index 00000000..5c3513b2 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp @@ -0,0 +1,131 @@ +#include "core_functions/scalar/list_functions.hpp" +#include "core_functions/array_kernels.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +//------------------------------------------------------------------------------ +// Generic "fold" function +//------------------------------------------------------------------------------ +// Given two lists of the same size, combine and reduce their elements into a +// single scalar value. + +template <class TYPE, class OP> +static void ListGenericFold(DataChunk &args, ExpressionState &state, Vector &result) { + const auto &lstate = state.Cast<ExecuteFunctionState>(); + const auto &expr = lstate.expr.Cast<BoundFunctionExpression>(); + const auto &func_name = expr.function.name; + + auto count = args.size(); + + auto &lhs_vec = args.data[0]; + auto &rhs_vec = args.data[1]; + + const auto lhs_count = ListVector::GetListSize(lhs_vec); + const auto rhs_count = ListVector::GetListSize(rhs_vec); + + auto &lhs_child = ListVector::GetEntry(lhs_vec); + auto &rhs_child = ListVector::GetEntry(rhs_vec); + + lhs_child.Flatten(lhs_count); + rhs_child.Flatten(rhs_count); + + D_ASSERT(lhs_child.GetVectorType() == VectorType::FLAT_VECTOR); + D_ASSERT(rhs_child.GetVectorType() == VectorType::FLAT_VECTOR); + + if (!FlatVector::Validity(lhs_child).CheckAllValid(lhs_count)) { + throw InvalidInputException("%s: left argument cannot contain NULL values", func_name); + } + + if (!FlatVector::Validity(rhs_child).CheckAllValid(rhs_count)) { + throw InvalidInputException("%s: right argument cannot contain NULL values", func_name); + } + + auto lhs_data = FlatVector::GetData<TYPE>(lhs_child); + auto rhs_data = FlatVector::GetData<TYPE>(rhs_child); + + BinaryExecutor::ExecuteWithNulls<list_entry_t, list_entry_t, TYPE>( + lhs_vec, rhs_vec, result, count, + [&](const list_entry_t &left, const list_entry_t &right, ValidityMask &mask, idx_t row_idx) { + if (left.length != right.length) { + throw InvalidInputException( + "%s: list dimensions must be equal, got left length '%d' and right length '%d'", func_name, + left.length, right.length); + } + + if (!OP::ALLOW_EMPTY && left.length == 0) { + mask.SetInvalid(row_idx); + return TYPE(); + } + + return OP::Operation(lhs_data + left.offset, rhs_data + right.offset, left.length); + }); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +//------------------------------------------------------------------------- +// Function Registration +//------------------------------------------------------------------------- + +template <class OP> +static void AddListFoldFunction(ScalarFunctionSet &set, const LogicalType &type) { + const auto list = LogicalType::LIST(type); + if (type.id() == LogicalTypeId::FLOAT) { + set.AddFunction(ScalarFunction({list, list}, type, ListGenericFold<float, OP>)); + } else if (type.id() == LogicalTypeId::DOUBLE) { + set.AddFunction(ScalarFunction({list, list}, type, ListGenericFold<double, OP>)); + } else { + throw NotImplementedException("List function not implemented for type %s", type.ToString()); + } +} + +ScalarFunctionSet ListDistanceFun::GetFunctions() { + ScalarFunctionSet set("list_distance"); + for (auto &type : LogicalType::Real()) { + AddListFoldFunction<DistanceOp>(set, type); + } + for (auto &func : set.functions) { + BaseScalarFunction::SetReturnsError(func); + } + return set; +} + +ScalarFunctionSet ListInnerProductFun::GetFunctions() { + ScalarFunctionSet set("list_inner_product"); + for (auto &type : LogicalType::Real()) { + AddListFoldFunction<InnerProductOp>(set, type); + } + return set; +} + +ScalarFunctionSet ListNegativeInnerProductFun::GetFunctions() { + ScalarFunctionSet set("list_negative_inner_product"); + for (auto &type : LogicalType::Real()) { + AddListFoldFunction<NegativeInnerProductOp>(set, type); + } + return set; +} + +ScalarFunctionSet ListCosineSimilarityFun::GetFunctions() { + ScalarFunctionSet set("list_cosine_similarity"); + for (auto &type : LogicalType::Real()) { + AddListFoldFunction<CosineSimilarityOp>(set, type); + } + for (auto &func : set.functions) { + BaseScalarFunction::SetReturnsError(func); + } + return set; +} + +ScalarFunctionSet ListCosineDistanceFun::GetFunctions() { + ScalarFunctionSet set("list_cosine_distance"); + for (auto &type : LogicalType::Real()) { + AddListFoldFunction<CosineDistanceOp>(set, type); + } + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp b/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp new file mode 100644 index 00000000..30ac79db --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp @@ -0,0 +1,49 @@ +#include "core_functions/scalar/list_functions.hpp" + +#include "duckdb/function/lambda_functions.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" + +namespace duckdb { + +static unique_ptr<FunctionData> ListFilterBind(ClientContext &context, ScalarFunction &bound_function, + vector<unique_ptr<Expression>> &arguments) { + + // the list column and the bound lambda expression + D_ASSERT(arguments.size() == 2); + if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { + throw BinderException("Invalid lambda expression!"); + } + + auto &bound_lambda_expr = arguments[1]->Cast<BoundLambdaExpression>(); + + // try to cast to BOOLEAN if the return type of the lambda filter expression is not already BOOLEAN + if (bound_lambda_expr.lambda_expr->return_type != LogicalType::BOOLEAN) { + auto cast_lambda_expr = + BoundCastExpression::AddCastToType(context, std::move(bound_lambda_expr.lambda_expr), LogicalType::BOOLEAN); + bound_lambda_expr.lambda_expr = std::move(cast_lambda_expr); + } + + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + + bound_function.return_type = arguments[0]->return_type; + auto has_index = bound_lambda_expr.parameter_count == 2; + return LambdaFunctions::ListLambdaBind(context, bound_function, arguments, has_index); +} + +static LogicalType ListFilterBindLambda(const idx_t parameter_idx, const LogicalType &list_child_type) { + return LambdaFunctions::BindBinaryLambda(parameter_idx, list_child_type); +} + +ScalarFunction ListFilterFun::GetFunction() { + ScalarFunction 
fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), + LambdaFunctions::ListFilterFunction, ListFilterBind, nullptr, nullptr); + + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.serialize = ListLambdaBindData::Serialize; + fun.deserialize = ListLambdaBindData::Deserialize; + fun.bind_lambda = ListFilterBindLambda; + + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp b/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp new file mode 100644 index 00000000..dd15edc9 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp @@ -0,0 +1,227 @@ +#include "duckdb/function/lambda_functions.hpp" +#include "core_functions/scalar/list_functions.hpp" +#include "duckdb/function/create_sort_key.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/common/string_map_set.hpp" + +namespace duckdb { + +static unique_ptr ListHasAnyOrAllBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + arguments[1] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[1])); + + const auto lhs_is_param = arguments[0]->HasParameter(); + const auto rhs_is_param = arguments[1]->HasParameter(); + + if (lhs_is_param && rhs_is_param) { + throw ParameterNotResolvedException(); + } + + const auto &lhs_list = arguments[0]->return_type; + const auto &rhs_list = arguments[1]->return_type; + + if (lhs_is_param) { + bound_function.arguments[0] = rhs_list; + bound_function.arguments[1] = rhs_list; + return nullptr; + } + if (rhs_is_param) { + bound_function.arguments[0] = lhs_list; + bound_function.arguments[1] = lhs_list; + return nullptr; + } + + bound_function.arguments[0] = lhs_list; + bound_function.arguments[1] = rhs_list; + + const auto &lhs_child = ListType::GetChildType(bound_function.arguments[0]); + const auto &rhs_child = ListType::GetChildType(bound_function.arguments[1]); + + if (lhs_child != LogicalType::SQLNULL && rhs_child != LogicalType::SQLNULL && lhs_child != rhs_child) { + LogicalType common_child; + if (!LogicalType::TryGetMaxLogicalType(context, lhs_child, rhs_child, common_child)) { + throw BinderException("'%s' cannot compare lists of different types: '%s' and '%s'", bound_function.name, + lhs_child.ToString(), rhs_child.ToString()); + } + bound_function.arguments[0] = LogicalType::LIST(common_child); + bound_function.arguments[1] = LogicalType::LIST(common_child); + } + + return nullptr; +} + +static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &result) { + + auto &l_vec = args.data[0]; + auto &r_vec = args.data[1]; + + if (ListType::GetChildType(l_vec.GetType()) == LogicalType::SQLNULL || + ListType::GetChildType(r_vec.GetType()) == LogicalType::SQLNULL) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::GetData(result)[0] = false; + return; + } + + const auto l_size = ListVector::GetListSize(l_vec); + const auto r_size = ListVector::GetListSize(r_vec); + + auto &l_child = ListVector::GetEntry(l_vec); + auto &r_child = ListVector::GetEntry(r_vec); + + // Setup unified formats for the list elements + UnifiedVectorFormat l_child_format; + UnifiedVectorFormat r_child_format; + + l_child.ToUnifiedFormat(l_size, l_child_format); + r_child.ToUnifiedFormat(r_size, r_child_format); + + // Create the sort 
keys for the list elements + Vector l_sortkey_vec(LogicalType::BLOB, l_size); + Vector r_sortkey_vec(LogicalType::BLOB, r_size); + + const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + + CreateSortKeyHelpers::CreateSortKey(l_child, l_size, order_modifiers, l_sortkey_vec); + CreateSortKeyHelpers::CreateSortKey(r_child, r_size, order_modifiers, r_sortkey_vec); + + const auto l_sortkey_ptr = FlatVector::GetData(l_sortkey_vec); + const auto r_sortkey_ptr = FlatVector::GetData(r_sortkey_vec); + + string_set_t set; + + BinaryExecutor::Execute( + l_vec, r_vec, result, args.size(), [&](const list_entry_t &l_list, const list_entry_t &r_list) { + // Short circuit if either list is empty + if (l_list.length == 0 || r_list.length == 0) { + return false; + } + + auto build_list = l_list; + auto probe_list = r_list; + + auto build_data = l_sortkey_ptr; + auto probe_data = r_sortkey_ptr; + + auto build_format = &l_child_format; + auto probe_format = &r_child_format; + + // Use the smaller list to build the set + if (r_list.length < l_list.length) { + + build_list = r_list; + probe_list = l_list; + + build_data = r_sortkey_ptr; + probe_data = l_sortkey_ptr; + + build_format = &r_child_format; + probe_format = &l_child_format; + } + + // Reset the set + set.clear(); + + // Build the set + for (auto idx = build_list.offset; idx < build_list.offset + build_list.length; idx++) { + const auto entry_idx = build_format->sel->get_index(idx); + if (build_format->validity.RowIsValid(entry_idx)) { + set.insert(build_data[entry_idx]); + } + } + // Probe the set + for (auto idx = probe_list.offset; idx < probe_list.offset + probe_list.length; idx++) { + const auto entry_idx = probe_format->sel->get_index(idx); + if (probe_format->validity.RowIsValid(entry_idx) && set.find(probe_data[entry_idx]) != set.end()) { + return true; + } + } + return false; + }); +} + +static void ListHasAllFunction(DataChunk &args, ExpressionState &state, Vector &result) { + + const auto &func_expr = state.expr.Cast(); + const auto swap = func_expr.function.name == "<@"; + + auto &l_vec = args.data[swap ? 1 : 0]; + auto &r_vec = args.data[swap ? 
0 : 1]; + + if (ListType::GetChildType(l_vec.GetType()) == LogicalType::SQLNULL && + ListType::GetChildType(r_vec.GetType()) == LogicalType::SQLNULL) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::GetData(result)[0] = true; + return; + } + + const auto l_size = ListVector::GetListSize(l_vec); + const auto r_size = ListVector::GetListSize(r_vec); + + auto &l_child = ListVector::GetEntry(l_vec); + auto &r_child = ListVector::GetEntry(r_vec); + + // Setup unified formats for the list elements + UnifiedVectorFormat build_format; + UnifiedVectorFormat probe_format; + + l_child.ToUnifiedFormat(l_size, build_format); + r_child.ToUnifiedFormat(r_size, probe_format); + + // Create the sort keys for the list elements + Vector l_sortkey_vec(LogicalType::BLOB, l_size); + Vector r_sortkey_vec(LogicalType::BLOB, r_size); + + const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + + CreateSortKeyHelpers::CreateSortKey(l_child, l_size, order_modifiers, l_sortkey_vec); + CreateSortKeyHelpers::CreateSortKey(r_child, r_size, order_modifiers, r_sortkey_vec); + + const auto build_data = FlatVector::GetData(l_sortkey_vec); + const auto probe_data = FlatVector::GetData(r_sortkey_vec); + + string_set_t set; + + BinaryExecutor::Execute( + l_vec, r_vec, result, args.size(), [&](const list_entry_t &build_list, const list_entry_t &probe_list) { + // Short circuit if the probe list is empty + if (probe_list.length == 0) { + return true; + } + + // Reset the set + set.clear(); + + // Build the set + for (auto idx = build_list.offset; idx < build_list.offset + build_list.length; idx++) { + const auto entry_idx = build_format.sel->get_index(idx); + if (build_format.validity.RowIsValid(entry_idx)) { + set.insert(build_data[entry_idx]); + } + } + + // Probe the set + for (auto idx = probe_list.offset; idx < probe_list.offset + probe_list.length; idx++) { + const auto entry_idx = probe_format.sel->get_index(idx); + if (probe_format.validity.RowIsValid(entry_idx) && set.find(probe_data[entry_idx]) == set.end()) { + return false; + } + } + return true; + }); +} + +ScalarFunction ListHasAnyFun::GetFunction() { + ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, LogicalType::BOOLEAN, + ListHasAnyFunction, ListHasAnyOrAllBind); + return fun; +} + +ScalarFunction ListHasAllFun::GetFunction() { + ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, LogicalType::BOOLEAN, + ListHasAllFunction, ListHasAnyOrAllBind); + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp b/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp new file mode 100644 index 00000000..173b5269 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp @@ -0,0 +1,232 @@ +#include "core_functions/scalar/list_functions.hpp" +#include "duckdb/function/lambda_functions.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +struct ReduceExecuteInfo { + ReduceExecuteInfo(LambdaFunctions::LambdaInfo &info, ClientContext &context) + : left_slice(make_uniq(*info.child_vector)) { + SelectionVector left_vector(info.row_count); + active_rows.Resize(info.row_count); + active_rows.SetAllValid(info.row_count); + + left_sel.Initialize(info.row_count); + active_rows_sel.Initialize(info.row_count); + + idx_t reduced_row_idx = 0; 
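+ + // Seed the reduction: every valid row contributes the first element of its + // list as the initial accumulator (gathered into left_slice via left_vector); + // NULL lists are marked invalid in the result and dropped from active_rows, + // while empty lists raise an error, since there is nothing to reduce.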
+ + for (idx_t original_row_idx = 0; original_row_idx < info.row_count; original_row_idx++) { + auto list_column_format_index = info.list_column_format.sel->get_index(original_row_idx); + if (info.list_column_format.validity.RowIsValid(list_column_format_index)) { + if (info.list_entries[list_column_format_index].length == 0) { + throw ParameterNotAllowedException("Cannot perform list_reduce on an empty input list"); + } + left_vector.set_index(reduced_row_idx, info.list_entries[list_column_format_index].offset); + reduced_row_idx++; + } else { + // Set the row as invalid and remove it from the active rows. + FlatVector::SetNull(info.result, original_row_idx, true); + active_rows.SetInvalid(original_row_idx); + } + } + left_slice->Slice(left_vector, reduced_row_idx); + + if (info.has_index) { + input_types.push_back(LogicalType::BIGINT); + } + input_types.push_back(left_slice->GetType()); + input_types.push_back(left_slice->GetType()); + for (auto &entry : info.column_infos) { + input_types.push_back(entry.vector.get().GetType()); + } + + expr_executor = make_uniq(context, *info.lambda_expr); + }; + ValidityMask active_rows; + unique_ptr left_slice; + unique_ptr expr_executor; + vector input_types; + + SelectionVector left_sel; + SelectionVector active_rows_sel; +}; + +static bool ExecuteReduce(idx_t loops, ReduceExecuteInfo &execute_info, LambdaFunctions::LambdaInfo &info, + DataChunk &result_chunk) { + idx_t original_row_idx = 0; + idx_t reduced_row_idx = 0; + idx_t valid_row_idx = 0; + + // create selection vectors for the left and right slice + auto data = execute_info.active_rows.GetData(); + + // reset right_sel each iteration to prevent referencing issues + SelectionVector right_sel; + right_sel.Initialize(info.row_count); + + idx_t bits_per_entry = sizeof(idx_t) * 8; + for (idx_t entry_idx = 0; original_row_idx < info.row_count; entry_idx++) { + if (data[entry_idx] == 0) { + original_row_idx += bits_per_entry; + continue; + } + + for (idx_t j = 0; entry_idx * bits_per_entry + j < info.row_count; j++) { + if (!execute_info.active_rows.RowIsValid(original_row_idx)) { + original_row_idx++; + continue; + } + auto list_column_format_index = info.list_column_format.sel->get_index(original_row_idx); + if (info.list_entries[list_column_format_index].length > loops + 1) { + right_sel.set_index(reduced_row_idx, info.list_entries[list_column_format_index].offset + loops + 1); + execute_info.left_sel.set_index(reduced_row_idx, valid_row_idx); + execute_info.active_rows_sel.set_index(reduced_row_idx, original_row_idx); + reduced_row_idx++; + + } else { + execute_info.active_rows.SetInvalid(original_row_idx); + auto val = execute_info.left_slice->GetValue(valid_row_idx); + info.result.SetValue(original_row_idx, val); + } + + original_row_idx++; + valid_row_idx++; + } + } + + if (reduced_row_idx == 0) { + return true; + } + + // create the index vector + Vector index_vector(Value::BIGINT(UnsafeNumericCast(loops + 2))); + + // slice the left and right slice + execute_info.left_slice->Slice(*execute_info.left_slice, execute_info.left_sel, reduced_row_idx); + Vector right_slice(*info.child_vector, right_sel, reduced_row_idx); + + // create the input chunk + DataChunk input_chunk; + input_chunk.InitializeEmpty(execute_info.input_types); + input_chunk.SetCardinality(reduced_row_idx); + + idx_t slice_offset = info.has_index ? 
1 : 0; + if (info.has_index) { + input_chunk.data[0].Reference(index_vector); + } + input_chunk.data[slice_offset + 1].Reference(*execute_info.left_slice); + input_chunk.data[slice_offset].Reference(right_slice); + + // add the other columns + vector slices; + for (idx_t i = 0; i < info.column_infos.size(); i++) { + if (info.column_infos[i].vector.get().GetVectorType() == VectorType::CONSTANT_VECTOR) { + // only reference constant vectors + input_chunk.data[slice_offset + 2 + i].Reference(info.column_infos[i].vector); + } else { + // slice the other vectors + slices.emplace_back(info.column_infos[i].vector, execute_info.active_rows_sel, reduced_row_idx); + input_chunk.data[slice_offset + 2 + i].Reference(slices.back()); + } + } + + result_chunk.Reset(); + result_chunk.SetCardinality(reduced_row_idx); + execute_info.expr_executor->Execute(input_chunk, result_chunk); + + // We need to copy the result into left_slice to avoid data loss due to vector.Reference(...). + // Otherwise, we only keep the data of the previous iteration alive, not that of previous iterations. + execute_info.left_slice = make_uniq(result_chunk.data[0].GetType(), reduced_row_idx); + VectorOperations::Copy(result_chunk.data[0], *execute_info.left_slice, reduced_row_idx, 0, 0); + return false; +} + +void LambdaFunctions::ListReduceFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // Initializes the left slice from the list entries, active rows, the expression executor and the input types + bool completed = false; + LambdaFunctions::LambdaInfo info(args, state, result, completed); + if (completed) { + return; + } + + ReduceExecuteInfo execute_info(info, state.GetContext()); + + // Since the left slice references the result chunk, we need to create two result chunks. + // This means there is always an empty result chunk for the next iteration, + // without the referenced chunk having to be reset until the current iteration is complete. + DataChunk odd_result_chunk; + odd_result_chunk.Initialize(Allocator::DefaultAllocator(), {info.lambda_expr->return_type}); + + DataChunk even_result_chunk; + even_result_chunk.Initialize(Allocator::DefaultAllocator(), {info.lambda_expr->return_type}); + + // Execute reduce until all rows are finished. + idx_t loops = 0; + bool end = false; + while (!end) { + auto &result_chunk = loops % 2 ? odd_result_chunk : even_result_chunk; + auto &spare_result_chunk = loops % 2 ? 
even_result_chunk : odd_result_chunk; + + end = ExecuteReduce(loops, execute_info, info, result_chunk); + spare_result_chunk.Reset(); + loops++; + } + + if (info.is_all_constant && !info.is_volatile) { + info.result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr ListReduceBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + // the list column and the bound lambda expression + D_ASSERT(arguments.size() == 2); + if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { + throw BinderException("Invalid lambda expression!"); + } + + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + + auto &bound_lambda_expr = arguments[1]->Cast(); + if (bound_lambda_expr.parameter_count < 2 || bound_lambda_expr.parameter_count > 3) { + throw BinderException("list_reduce expects a function with 2 or 3 arguments"); + } + auto has_index = bound_lambda_expr.parameter_count == 3; + + unique_ptr bind_data = LambdaFunctions::ListLambdaPrepareBind(arguments, context, bound_function); + if (bind_data) { + return bind_data; + } + + auto list_child_type = arguments[0]->return_type; + list_child_type = ListType::GetChildType(list_child_type); + + auto cast_lambda_expr = + BoundCastExpression::AddCastToType(context, std::move(bound_lambda_expr.lambda_expr), list_child_type, false); + if (!cast_lambda_expr) { + throw BinderException("Could not cast lambda expression to list child type"); + } + bound_function.return_type = cast_lambda_expr->return_type; + return make_uniq(bound_function.return_type, std::move(cast_lambda_expr), has_index); +} + +static LogicalType ListReduceBindLambda(const idx_t parameter_idx, const LogicalType &list_child_type) { + return LambdaFunctions::BindTernaryLambda(parameter_idx, list_child_type); +} + +ScalarFunction ListReduceFun::GetFunction() { + ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::ANY, + LambdaFunctions::ListReduceFunction, ListReduceBind, nullptr, nullptr); + + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.serialize = ListLambdaBindData::Serialize; + fun.deserialize = ListLambdaBindData::Deserialize; + fun.bind_lambda = ListReduceBindLambda; + + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp b/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp new file mode 100644 index 00000000..5ab523d2 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp @@ -0,0 +1,416 @@ +#include "core_functions/scalar/list_functions.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/numeric_utils.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/common/sort/sort.hpp" + +namespace duckdb { + +struct ListSortBindData : public FunctionData { + ListSortBindData(OrderType order_type_p, OrderByNullType null_order_p, bool is_grade_up, + const LogicalType &return_type_p, const LogicalType &child_type_p, ClientContext &context_p); + ~ListSortBindData() override; + + OrderType order_type; + OrderByNullType null_order; + LogicalType return_type; + LogicalType child_type; + bool is_grade_up; + + vector types; + vector payload_types; + + ClientContext 
&context; + RowLayout payload_layout; + vector orders; + +public: + bool Equals(const FunctionData &other_p) const override; + unique_ptr Copy() const override; +}; + +ListSortBindData::ListSortBindData(OrderType order_type_p, OrderByNullType null_order_p, bool is_grade_up_p, + const LogicalType &return_type_p, const LogicalType &child_type_p, + ClientContext &context_p) + : order_type(order_type_p), null_order(null_order_p), return_type(return_type_p), child_type(child_type_p), + is_grade_up(is_grade_up_p), context(context_p) { + + // get the vector types + types.emplace_back(LogicalType::USMALLINT); + types.emplace_back(child_type); + D_ASSERT(types.size() == 2); + + // get the payload types + payload_types.emplace_back(LogicalType::UINTEGER); + D_ASSERT(payload_types.size() == 1); + + // initialize the payload layout + payload_layout.Initialize(payload_types); + + // get the BoundOrderByNode + auto idx_col_expr = make_uniq_base(LogicalType::USMALLINT, 0U); + auto lists_col_expr = make_uniq_base(child_type, 1U); + orders.emplace_back(OrderType::ASCENDING, OrderByNullType::ORDER_DEFAULT, std::move(idx_col_expr)); + orders.emplace_back(order_type, null_order, std::move(lists_col_expr)); +} + +unique_ptr ListSortBindData::Copy() const { + return make_uniq(order_type, null_order, is_grade_up, return_type, child_type, context); +} + +bool ListSortBindData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return order_type == other.order_type && null_order == other.null_order && is_grade_up == other.is_grade_up; +} + +ListSortBindData::~ListSortBindData() { +} + +// create the key_chunk and the payload_chunk and sink them into the local_sort_state +void SinkDataChunk(Vector *child_vector, SelectionVector &sel, idx_t offset_lists_indices, vector &types, + vector &payload_types, Vector &payload_vector, LocalSortState &local_sort_state, + bool &data_to_sort, Vector &lists_indices) { + + // slice the child vector + Vector slice(*child_vector, sel, offset_lists_indices); + + // initialize and fill key_chunk + DataChunk key_chunk; + key_chunk.InitializeEmpty(types); + key_chunk.data[0].Reference(lists_indices); + key_chunk.data[1].Reference(slice); + key_chunk.SetCardinality(offset_lists_indices); + + // initialize and fill key_chunk and payload_chunk + DataChunk payload_chunk; + payload_chunk.InitializeEmpty(payload_types); + payload_chunk.data[0].Reference(payload_vector); + payload_chunk.SetCardinality(offset_lists_indices); + + key_chunk.Verify(); + payload_chunk.Verify(); + + // sink + key_chunk.Flatten(); + local_sort_state.SinkChunk(key_chunk, payload_chunk); + data_to_sort = true; +} + +static void ListSortFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() >= 1 && args.ColumnCount() <= 3); + auto count = args.size(); + Vector &input_lists = args.data[0]; + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto &result_validity = FlatVector::Validity(result); + + if (input_lists.GetType().id() == LogicalTypeId::SQLNULL) { + result_validity.SetInvalid(0); + return; + } + + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + + // initialize the global and local sorting state + auto &buffer_manager = BufferManager::GetBufferManager(info.context); + GlobalSortState global_sort_state(buffer_manager, info.orders, info.payload_layout); + LocalSortState local_sort_state; + local_sort_state.Initialize(global_sort_state, buffer_manager); + + Vector sort_result_vec = info.is_grade_up ? 
Vector(input_lists.GetType()) : result; + + // this ensures that we do not change the order of the entries in the input chunk + VectorOperations::Copy(input_lists, sort_result_vec, count, 0, 0); + + // get the child vector + auto lists_size = ListVector::GetListSize(sort_result_vec); + auto &child_vector = ListVector::GetEntry(sort_result_vec); + + // get the lists data + UnifiedVectorFormat lists_data; + sort_result_vec.ToUnifiedFormat(count, lists_data); + auto list_entries = UnifiedVectorFormat::GetData(lists_data); + + // create the lists_indices vector, this contains an element for each list's entry, + // the element corresponds to the list's index, e.g. for [1, 2, 4], [5, 4] + // lists_indices contains [0, 0, 0, 1, 1] + Vector lists_indices(LogicalType::USMALLINT); + auto lists_indices_data = FlatVector::GetData(lists_indices); + + // create the payload_vector, this is just a vector containing incrementing integers + // this will later be used as the 'new' selection vector of the child_vector, after + // rearranging the payload according to the sorting order + Vector payload_vector(LogicalType::UINTEGER); + auto payload_vector_data = FlatVector::GetData(payload_vector); + + // selection vector pointing to the data of the child vector, + // used for slicing the child_vector correctly + SelectionVector sel(STANDARD_VECTOR_SIZE); + + idx_t offset_lists_indices = 0; + uint32_t incr_payload_count = 0; + bool data_to_sort = false; + + for (idx_t i = 0; i < count; i++) { + auto lists_index = lists_data.sel->get_index(i); + const auto &list_entry = list_entries[lists_index]; + + // nothing to do for this list + if (!lists_data.validity.RowIsValid(lists_index)) { + result_validity.SetInvalid(i); + continue; + } + + // empty list, no sorting required + if (list_entry.length == 0) { + continue; + } + + for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { + // lists_indices vector is full, sink + if (offset_lists_indices == STANDARD_VECTOR_SIZE) { + SinkDataChunk(&child_vector, sel, offset_lists_indices, info.types, info.payload_types, payload_vector, + local_sort_state, data_to_sort, lists_indices); + offset_lists_indices = 0; + } + + auto source_idx = list_entry.offset + child_idx; + sel.set_index(offset_lists_indices, source_idx); + lists_indices_data[offset_lists_indices] = UnsafeNumericCast(i); + payload_vector_data[offset_lists_indices] = NumericCast(source_idx); + offset_lists_indices++; + incr_payload_count++; + } + } + + if (offset_lists_indices != 0) { + SinkDataChunk(&child_vector, sel, offset_lists_indices, info.types, info.payload_types, payload_vector, + local_sort_state, data_to_sort, lists_indices); + } + + if (info.is_grade_up) { + ListVector::Reserve(result, lists_size); + ListVector::SetListSize(result, lists_size); + auto result_data = ListVector::GetData(result); + memcpy(result_data, list_entries, count * sizeof(list_entry_t)); + } + + if (data_to_sort) { + // add local state to global state, which sorts the data + global_sort_state.AddLocalState(local_sort_state); + global_sort_state.PrepareMergePhase(); + + // selection vector that is to be filled with the 'sorted' payload + SelectionVector sel_sorted(incr_payload_count); + idx_t sel_sorted_idx = 0; + + // scan the sorted row data + PayloadScanner scanner(*global_sort_state.sorted_blocks[0]->payload_data, global_sort_state); + for (;;) { + DataChunk result_chunk; + result_chunk.Initialize(Allocator::DefaultAllocator(), info.payload_types); + result_chunk.SetCardinality(0); + scanner.Scan(result_chunk); 
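+ // an empty scanned chunk means the sorted payload is exhausted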
+ if (result_chunk.size() == 0) { + break; + } + + // construct the selection vector with the new order from the result vectors + Vector result_vector(result_chunk.data[0]); + auto result_data = FlatVector::GetData(result_vector); + auto row_count = result_chunk.size(); + + for (idx_t i = 0; i < row_count; i++) { + sel_sorted.set_index(sel_sorted_idx, result_data[i]); + D_ASSERT(result_data[i] < lists_size); + sel_sorted_idx++; + } + } + + D_ASSERT(sel_sorted_idx == incr_payload_count); + if (info.is_grade_up) { + auto &result_entry = ListVector::GetEntry(result); + auto result_data = ListVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + if (!result_validity.RowIsValid(i)) { + continue; + } + for (idx_t j = result_data[i].offset; j < result_data[i].offset + result_data[i].length; j++) { + auto b = sel_sorted.get_index(j) - result_data[i].offset; + result_entry.SetValue(j, Value::BIGINT(UnsafeNumericCast(b + 1))); + } + } + } else { + child_vector.Slice(sel_sorted, sel_sorted_idx); + child_vector.Flatten(sel_sorted_idx); + } + } + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr ListSortBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments, OrderType &order, + OrderByNullType &null_order) { + + LogicalType child_type; + if (arguments[0]->return_type == LogicalTypeId::UNKNOWN) { + bound_function.arguments[0] = LogicalTypeId::UNKNOWN; + bound_function.return_type = LogicalType::SQLNULL; + child_type = bound_function.return_type; + return make_uniq(order, null_order, false, bound_function.return_type, child_type, context); + } + + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + child_type = ListType::GetChildType(arguments[0]->return_type); + + bound_function.arguments[0] = arguments[0]->return_type; + bound_function.return_type = arguments[0]->return_type; + + return make_uniq(order, null_order, false, bound_function.return_type, child_type, context); +} + +template +static T GetOrder(ClientContext &context, Expression &expr) { + if (!expr.IsFoldable()) { + throw InvalidInputException("Sorting order must be a constant"); + } + Value order_value = ExpressionExecutor::EvaluateScalar(context, expr); + auto order_name = StringUtil::Upper(order_value.ToString()); + return EnumUtil::FromString(order_name.c_str()); +} + +static unique_ptr ListGradeUpBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + D_ASSERT(!arguments.empty() && arguments.size() <= 3); + auto order = OrderType::ORDER_DEFAULT; + auto null_order = OrderByNullType::ORDER_DEFAULT; + + // get the sorting order + if (arguments.size() >= 2) { + order = GetOrder(context, *arguments[1]); + } + // get the null sorting order + if (arguments.size() == 3) { + null_order = GetOrder(context, *arguments[2]); + } + auto &config = DBConfig::GetConfig(context); + order = config.ResolveOrder(order); + null_order = config.ResolveNullOrder(order, null_order); + + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + + bound_function.arguments[0] = arguments[0]->return_type; + bound_function.return_type = LogicalType::LIST(LogicalTypeId::BIGINT); + auto child_type = ListType::GetChildType(arguments[0]->return_type); + return make_uniq(order, null_order, true, bound_function.return_type, child_type, context); +} + +static unique_ptr ListNormalSortBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + 
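// list_sort takes one to three arguments: the list, an optional sort order + // ('ASC' or 'DESC') and an optional NULL order ('NULLS FIRST' or 'NULLS LAST'), + // all constant VARCHARs, e.g. list_sort([3, 6, 1, 2], 'DESC', 'NULLS FIRST'). + // Orders left unspecified resolve to the database defaults below. + 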
D_ASSERT(!arguments.empty() && arguments.size() <= 3); + auto order = OrderType::ORDER_DEFAULT; + auto null_order = OrderByNullType::ORDER_DEFAULT; + + // get the sorting order + if (arguments.size() >= 2) { + order = GetOrder(context, *arguments[1]); + } + // get the null sorting order + if (arguments.size() == 3) { + null_order = GetOrder(context, *arguments[2]); + } + auto &config = DBConfig::GetConfig(context); + order = config.ResolveOrder(order); + null_order = config.ResolveNullOrder(order, null_order); + return ListSortBind(context, bound_function, arguments, order, null_order); +} + +static unique_ptr ListReverseSortBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + auto order = OrderType::ORDER_DEFAULT; + auto null_order = OrderByNullType::ORDER_DEFAULT; + + if (arguments.size() == 2) { + null_order = GetOrder(context, *arguments[1]); + } + auto &config = DBConfig::GetConfig(context); + order = config.ResolveOrder(order); + switch (order) { + case OrderType::ASCENDING: + order = OrderType::DESCENDING; + break; + case OrderType::DESCENDING: + order = OrderType::ASCENDING; + break; + default: + throw InternalException("Unexpected order type in list reverse sort"); + } + null_order = config.ResolveNullOrder(order, null_order); + return ListSortBind(context, bound_function, arguments, order, null_order); +} + +ScalarFunctionSet ListSortFun::GetFunctions() { + // one parameter: list + ScalarFunction sort({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), ListSortFunction, + ListNormalSortBind); + + // two parameters: list, order + ScalarFunction sort_order({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, + LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListNormalSortBind); + + // three parameters: list, order, null order + ScalarFunction sort_orders({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListNormalSortBind); + + ScalarFunctionSet list_sort; + list_sort.AddFunction(sort); + list_sort.AddFunction(sort_order); + list_sort.AddFunction(sort_orders); + return list_sort; +} + +ScalarFunctionSet ListGradeUpFun::GetFunctions() { + // one parameter: list + ScalarFunction sort({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), ListSortFunction, + ListGradeUpBind); + + // two parameters: list, order + ScalarFunction sort_order({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, + LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListGradeUpBind); + + // three parameters: list, order, null order + ScalarFunction sort_orders({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListGradeUpBind); + + ScalarFunctionSet list_grade_up; + list_grade_up.AddFunction(sort); + list_grade_up.AddFunction(sort_order); + list_grade_up.AddFunction(sort_orders); + return list_grade_up; +} + +ScalarFunctionSet ListReverseSortFun::GetFunctions() { + // one parameter: list + ScalarFunction sort_reverse({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), + ListSortFunction, ListReverseSortBind); + + // two parameters: list, null order + ScalarFunction sort_reverse_null_order({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, + LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListReverseSortBind); + + ScalarFunctionSet list_reverse_sort; + 
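// the resolved default order is inverted in ListReverseSortBind, so the + // optional second argument is the NULL order only + 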
list_reverse_sort.AddFunction(sort_reverse); + list_reverse_sort.AddFunction(sort_reverse_null_order); + return list_reverse_sort; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp b/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp new file mode 100644 index 00000000..26c6ad4b --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp @@ -0,0 +1,41 @@ +#include "core_functions/scalar/list_functions.hpp" + +#include "duckdb/function/lambda_functions.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" + +namespace duckdb { + +static unique_ptr ListTransformBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + + // the list column and the bound lambda expression + D_ASSERT(arguments.size() == 2); + if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { + throw BinderException("Invalid lambda expression!"); + } + + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + + auto &bound_lambda_expr = arguments[1]->Cast(); + bound_function.return_type = LogicalType::LIST(bound_lambda_expr.lambda_expr->return_type); + auto has_index = bound_lambda_expr.parameter_count == 2; + return LambdaFunctions::ListLambdaBind(context, bound_function, arguments, has_index); +} + +static LogicalType ListTransformBindLambda(const idx_t parameter_idx, const LogicalType &list_child_type) { + return LambdaFunctions::BindBinaryLambda(parameter_idx, list_child_type); +} + +ScalarFunction ListTransformFun::GetFunction() { + ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), + LambdaFunctions::ListTransformFunction, ListTransformBind, nullptr, nullptr); + + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.serialize = ListLambdaBindData::Serialize; + fun.deserialize = ListLambdaBindData::Deserialize; + fun.bind_lambda = ListTransformBindLambda; + + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_value.cpp b/src/duckdb/extension/core_functions/scalar/list/list_value.cpp new file mode 100644 index 00000000..01b342ec --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/list/list_value.cpp @@ -0,0 +1,203 @@ +#include "core_functions/scalar/list_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/storage/statistics/list_stats.hpp" +#include "duckdb/planner/expression_binder.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/parser/query_error_context.hpp" + +namespace duckdb { + +struct ListValueAssign { + template + static T Assign(const T &input, Vector &result) { + return input; + } +}; + +struct ListValueStringAssign { + template + static T Assign(const T &input, Vector &result) { + return StringVector::AddStringOrBlob(result, input); + } +}; + +template +static void TemplatedListValueFunction(DataChunk &args, Vector &result) { + idx_t list_size = args.ColumnCount(); + ListVector::Reserve(result, args.size() * list_size); + auto result_data = FlatVector::GetData(result); + auto &list_child = ListVector::GetEntry(result); + auto child_data = FlatVector::GetData(list_child); + auto &child_validity = 
FlatVector::Validity(list_child); + + auto unified_format = args.ToUnifiedFormat(); + for (idx_t r = 0; r < args.size(); r++) { + for (idx_t c = 0; c < list_size; c++) { + auto input_idx = unified_format[c].sel->get_index(r); + auto result_idx = r * list_size + c; + auto input_data = UnifiedVectorFormat::GetData(unified_format[c]); + if (unified_format[c].validity.RowIsValid(input_idx)) { + child_data[result_idx] = OP::template Assign(input_data[input_idx], list_child); + } else { + child_validity.SetInvalid(result_idx); + } + } + result_data[r].offset = r * list_size; + result_data[r].length = list_size; + } + ListVector::SetListSize(result, args.size() * list_size); +} + +static void TemplatedListValueFunctionFallback(DataChunk &args, Vector &result) { + auto &child_type = ListType::GetChildType(result.GetType()); + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < args.size(); i++) { + result_data[i].offset = ListVector::GetListSize(result); + for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { + auto val = args.GetValue(col_idx, i).DefaultCastAs(child_type); + ListVector::PushBack(result, val); + } + result_data[i].length = args.ColumnCount(); + } +} + +static void ListValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + if (args.ColumnCount() == 0) { + // no columns - early out - result is a constant empty list + auto result_data = FlatVector::GetData(result); + result_data[0].length = 0; + result_data[0].offset = 0; + return; + } + for (idx_t i = 0; i < args.ColumnCount(); i++) { + if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::FLAT_VECTOR); + } + } + auto &result_type = ListVector::GetEntry(result).GetType(); + switch (result_type.InternalType()) { + case PhysicalType::BOOL: + case PhysicalType::INT8: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::INT16: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::INT32: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::INT64: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::UINT8: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::UINT16: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::UINT32: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::UINT64: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::INT128: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::UINT128: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::FLOAT: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::DOUBLE: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::INTERVAL: + TemplatedListValueFunction(args, result); + break; + case PhysicalType::VARCHAR: + TemplatedListValueFunction(args, result); + break; + default: { + TemplatedListValueFunctionFallback(args, result); + break; + } + } +} + +template +static unique_ptr ListValueBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // collect names and deconflict, construct return type + LogicalType child_type = + arguments.empty() ? 
LogicalType::SQLNULL : ExpressionBinder::GetExpressionReturnType(*arguments[0]); + for (idx_t i = 1; i < arguments.size(); i++) { + auto arg_type = ExpressionBinder::GetExpressionReturnType(*arguments[i]); + if (!LogicalType::TryGetMaxLogicalType(context, child_type, arg_type, child_type)) { + if (IS_UNPIVOT) { + string list_arguments = "Full list: "; + idx_t error_index = list_arguments.size(); + for (idx_t k = 0; k < arguments.size(); k++) { + if (k > 0) { + list_arguments += ", "; + } + if (k == i) { + error_index = list_arguments.size(); + } + list_arguments += arguments[k]->ToString() + " " + arguments[k]->return_type.ToString(); + } + auto error = + StringUtil::Format("Cannot unpivot columns of types %s and %s - an explicit cast is required", + child_type.ToString(), arg_type.ToString()); + throw BinderException(arguments[i]->GetQueryLocation(), + QueryErrorContext::Format(list_arguments, error, error_index, false)); + } else { + throw BinderException(arguments[i]->GetQueryLocation(), + "Cannot create a list of types %s and %s - an explicit cast is required", + child_type.ToString(), arg_type.ToString()); + } + } + } + child_type = LogicalType::NormalizeType(child_type); + + // this is more for completeness reasons + bound_function.varargs = child_type; + bound_function.return_type = LogicalType::LIST(child_type); + return make_uniq(bound_function.return_type); +} + +unique_ptr ListValueStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + auto list_stats = ListStats::CreateEmpty(expr.return_type); + auto &list_child_stats = ListStats::GetChildStats(list_stats); + for (idx_t i = 0; i < child_stats.size(); i++) { + list_child_stats.Merge(child_stats[i]); + } + return list_stats.ToUnique(); +} + +ScalarFunction ListValueFun::GetFunction() { + // the arguments and return types are actually set in the binder function + ScalarFunction fun("list_value", {}, LogicalTypeId::LIST, ListValueFunction, ListValueBind, nullptr, + ListValueStats); + fun.varargs = LogicalType::ANY; + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +ScalarFunction UnpivotListFun::GetFunction() { + auto fun = ListValueFun::GetFunction(); + fun.name = "unpivot_list"; + fun.bind = ListValueBind; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/range.cpp b/src/duckdb/extension/core_functions/scalar/list/range.cpp new file mode 100644 index 00000000..8c641d13 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/list/range.cpp @@ -0,0 +1,281 @@ +#include "core_functions/scalar/list_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/types/timestamp.hpp" + +namespace duckdb { + +struct NumericRangeInfo { + using TYPE = int64_t; + using INCREMENT_TYPE = int64_t; + + static int64_t DefaultStart() { + return 0; + } + static int64_t DefaultIncrement() { + return 1; + } + + static uint64_t ListLength(int64_t start_value, int64_t end_value, int64_t increment_value, bool inclusive_bound) { + if (increment_value == 0) { + return 0; + } + if (start_value > end_value && increment_value > 0) { + return 0; + } + if (start_value < end_value && increment_value < 0) { + return 0; + } + hugeint_t total_diff = AbsValue(hugeint_t(end_value) - hugeint_t(start_value)); + hugeint_t increment = 
AbsValue(hugeint_t(increment_value)); + hugeint_t total_values = total_diff / increment; + if (total_diff % increment == 0) { + if (inclusive_bound) { + total_values += 1; + } + } else { + total_values += 1; + } + if (total_values > NumericLimits::Maximum()) { + throw InvalidInputException("Lists larger than 2^32 elements are not supported"); + } + return Hugeint::Cast(total_values); + } + + static void Increment(int64_t &input, int64_t increment) { + input += increment; + } +}; +struct TimestampRangeInfo { + using TYPE = timestamp_t; + using INCREMENT_TYPE = interval_t; + + static timestamp_t DefaultStart() { + throw InternalException("Default start not implemented for timestamp range"); + } + static interval_t DefaultIncrement() { + throw InternalException("Default increment not implemented for timestamp range"); + } + static uint64_t ListLength(timestamp_t start_value, timestamp_t end_value, interval_t increment_value, + bool inclusive_bound) { + bool is_positive = increment_value.months > 0 || increment_value.days > 0 || increment_value.micros > 0; + bool is_negative = increment_value.months < 0 || increment_value.days < 0 || increment_value.micros < 0; + if (!is_negative && !is_positive) { + // interval is 0: no result + return 0; + } + // We don't allow infinite bounds because they generate errors or infinite loops + if (!Timestamp::IsFinite(start_value) || !Timestamp::IsFinite(end_value)) { + throw InvalidInputException("Interval infinite bounds not supported"); + } + + if (is_negative && is_positive) { + // we don't allow a mix of + throw InvalidInputException("Interval with mix of negative/positive entries not supported"); + } + if (start_value > end_value && is_positive) { + return 0; + } + if (start_value < end_value && is_negative) { + return 0; + } + uint64_t total_values = 0; + if (is_negative) { + // negative interval, start_value is going down + while (inclusive_bound ? start_value >= end_value : start_value > end_value) { + start_value = Interval::Add(start_value, increment_value); + total_values++; + if (total_values > NumericLimits::Maximum()) { + throw InvalidInputException("Lists larger than 2^32 elements are not supported"); + } + } + } else { + // positive interval, start_value is going up + while (inclusive_bound ? 
start_value <= end_value : start_value < end_value) { + start_value = Interval::Add(start_value, increment_value); + total_values++; + if (total_values > NumericLimits::Maximum()) { + throw InvalidInputException("Lists larger than 2^32 elements are not supported"); + } + } + } + return total_values; + } + + static void Increment(timestamp_t &input, interval_t increment) { + input = Interval::Add(input, increment); + } +}; + +template +class RangeInfoStruct { +public: + explicit RangeInfoStruct(DataChunk &args_p) : args(args_p) { + switch (args.ColumnCount()) { + case 1: + args.data[0].ToUnifiedFormat(args.size(), vdata[0]); + break; + case 2: + args.data[0].ToUnifiedFormat(args.size(), vdata[0]); + args.data[1].ToUnifiedFormat(args.size(), vdata[1]); + break; + case 3: + args.data[0].ToUnifiedFormat(args.size(), vdata[0]); + args.data[1].ToUnifiedFormat(args.size(), vdata[1]); + args.data[2].ToUnifiedFormat(args.size(), vdata[2]); + break; + default: + throw InternalException("Unsupported number of parameters for range"); + } + } + + bool RowIsValid(idx_t row_idx) { + for (idx_t i = 0; i < args.ColumnCount(); i++) { + auto idx = vdata[i].sel->get_index(row_idx); + if (!vdata[i].validity.RowIsValid(idx)) { + return false; + } + } + return true; + } + + typename OP::TYPE StartListValue(idx_t row_idx) { + if (args.ColumnCount() == 1) { + return OP::DefaultStart(); + } else { + auto data = (typename OP::TYPE *)vdata[0].data; + auto idx = vdata[0].sel->get_index(row_idx); + return data[idx]; + } + } + + typename OP::TYPE EndListValue(idx_t row_idx) { + idx_t vdata_idx = args.ColumnCount() == 1 ? 0 : 1; + auto data = (typename OP::TYPE *)vdata[vdata_idx].data; + auto idx = vdata[vdata_idx].sel->get_index(row_idx); + return data[idx]; + } + + typename OP::INCREMENT_TYPE ListIncrementValue(idx_t row_idx) { + if (args.ColumnCount() < 3) { + return OP::DefaultIncrement(); + } else { + auto data = (typename OP::INCREMENT_TYPE *)vdata[2].data; + auto idx = vdata[2].sel->get_index(row_idx); + return data[idx]; + } + } + + void GetListValues(idx_t row_idx, typename OP::TYPE &start_value, typename OP::TYPE &end_value, + typename OP::INCREMENT_TYPE &increment_value) { + start_value = StartListValue(row_idx); + end_value = EndListValue(row_idx); + increment_value = ListIncrementValue(row_idx); + } + + uint64_t ListLength(idx_t row_idx) { + typename OP::TYPE start_value; + typename OP::TYPE end_value; + typename OP::INCREMENT_TYPE increment_value; + GetListValues(row_idx, start_value, end_value, increment_value); + return OP::ListLength(start_value, end_value, increment_value, INCLUSIVE_BOUND); + } + +private: + DataChunk &args; + UnifiedVectorFormat vdata[3]; +}; + +template +static void ListRangeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + + RangeInfoStruct info(args); + idx_t args_size = 1; + auto result_type = VectorType::CONSTANT_VECTOR; + for (idx_t i = 0; i < args.ColumnCount(); i++) { + if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { + args_size = args.size(); + result_type = VectorType::FLAT_VECTOR; + break; + } + } + auto list_data = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + uint64_t total_size = 0; + for (idx_t i = 0; i < args_size; i++) { + if (!info.RowIsValid(i)) { + result_validity.SetInvalid(i); + list_data[i].offset = total_size; + list_data[i].length = 0; + } else { + list_data[i].offset = total_size; + list_data[i].length = info.ListLength(i); 
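
The loop above fixes each row's (offset, length) entry via ListLength before any values are materialized; for the numeric case that length computation reduces to integer arithmetic on the absolute difference. A standalone restatement with worked examples (plain int64_t instead of the hugeint_t widening the real code uses to avoid overflow; RangeLength is a hypothetical name, and the INCLUSIVE_BOUND flag corresponds to DuckDB's documented distinction between exclusive range and inclusive generate_series):

#include <cassert>
#include <cstdint>
#include <cstdlib>

uint64_t RangeLength(int64_t start, int64_t end, int64_t inc, bool inclusive) {
    if (inc == 0 || (start > end && inc > 0) || (start < end && inc < 0)) {
        return 0; // empty range: zero or wrong-signed increment
    }
    uint64_t diff = static_cast<uint64_t>(std::llabs(end - start));
    uint64_t step = static_cast<uint64_t>(std::llabs(inc));
    uint64_t n = diff / step;
    // an exact hit on the bound only counts when the bound is inclusive
    return (diff % step == 0) ? n + (inclusive ? 1 : 0) : n + 1;
}

int main() {
    assert(RangeLength(1, 10, 2, false) == 5); // range(1, 10, 2): 1, 3, 5, 7, 9
    assert(RangeLength(1, 9, 2, false) == 4);  // 9 is hit exactly and excluded
    assert(RangeLength(1, 9, 2, true) == 5);   // generate_series includes the bound
    return 0;
}
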
+ total_size += list_data[i].length; + } + } + + // now construct the child vector of the list + ListVector::Reserve(result, total_size); + auto range_data = FlatVector::GetData(ListVector::GetEntry(result)); + idx_t total_idx = 0; + for (idx_t i = 0; i < args_size; i++) { + typename OP::TYPE start_value = info.StartListValue(i); + typename OP::INCREMENT_TYPE increment = info.ListIncrementValue(i); + + typename OP::TYPE range_value = start_value; + for (idx_t range_idx = 0; range_idx < list_data[i].length; range_idx++) { + if (range_idx > 0) { + OP::Increment(range_value, increment); + } + range_data[total_idx++] = range_value; + } + } + + ListVector::SetListSize(result, total_size); + result.SetVectorType(result_type); + + result.Verify(args.size()); +} + +ScalarFunctionSet ListRangeFun::GetFunctions() { + // the arguments and return types are actually set in the binder function + ScalarFunctionSet range_set; + range_set.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + range_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + range_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + range_set.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, + LogicalType::LIST(LogicalType::TIMESTAMP), + ListRangeFunction)); + for (auto &func : range_set.functions) { + BaseScalarFunction::SetReturnsError(func); + } + return range_set; +} + +ScalarFunctionSet GenerateSeriesFun::GetFunctions() { + ScalarFunctionSet generate_series; + generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::BIGINT), + ListRangeFunction)); + generate_series.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, + LogicalType::LIST(LogicalType::TIMESTAMP), + ListRangeFunction)); + for (auto &func : generate_series.functions) { + BaseScalarFunction::SetReturnsError(func); + } + return generate_series; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp b/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp new file mode 100644 index 00000000..9c81223e --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp @@ -0,0 +1,50 @@ +#include "core_functions/scalar/map_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +static void CardinalityFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &map = args.data[0]; + UnifiedVectorFormat map_data; + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + auto &result_validity = FlatVector::Validity(result); + + map.ToUnifiedFormat(args.size(), map_data); + for (idx_t 
row = 0; row < args.size(); row++) { + auto list_entry = UnifiedVectorFormat::GetData<list_entry_t>(map_data)[map_data.sel->get_index(row)]; + result_data[row] = list_entry.length; + result_validity.Set(row, map_data.validity.RowIsValid(map_data.sel->get_index(row))); + } + + if (args.size() == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr<FunctionData> CardinalityBind(ClientContext &context, ScalarFunction &bound_function, + vector<unique_ptr<Expression>> &arguments) { + if (arguments.size() != 1) { + throw BinderException("Cardinality must have exactly one argument"); + } + + if (arguments[0]->return_type.id() != LogicalTypeId::MAP) { + throw BinderException("Cardinality can only operate on MAPs"); + } + + bound_function.return_type = LogicalType::UBIGINT; + return make_uniq<VariableReturnBindData>(bound_function.return_type); +} + +ScalarFunction CardinalityFun::GetFunction() { + ScalarFunction fun({LogicalType::ANY}, LogicalType::UBIGINT, CardinalityFunction, CardinalityBind); + fun.varargs = LogicalType::ANY; + fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map.cpp b/src/duckdb/extension/core_functions/scalar/map/map.cpp new file mode 100644 index 00000000..b83a4a08 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/map/map.cpp @@ -0,0 +1,223 @@ +#include "core_functions/scalar/map_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/types/value_map.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +static void MapFunctionEmptyInput(Vector &result, const idx_t row_count) { + // If no chunk is set in ExpressionExecutor::ExecuteExpression (args.data.empty(), e.g., in SELECT MAP()), + // then we always pass a row_count of 1. 
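
The empty-input path hinges on how LIST-backed vectors are addressed: each row stores only an (offset, length) window into a shared child vector, so the constant empty map is a single zero-length entry. A minimal model of that bookkeeping (hypothetical ListEntry type, mirroring the offset/length fields of the list_entry_t used throughout these files):

#include <cassert>
#include <cstdint>
#include <vector>

struct ListEntry {
    uint64_t offset;
    uint64_t length;
};

int main() {
    std::vector<int> child = {1, 2, 3, 4, 5};  // shared child vector
    ListEntry rows[2] = {{0, 2}, {2, 3}};      // row 0 is [1, 2], row 1 is [3, 4, 5]
    ListEntry empty_map{};                     // SELECT MAP(): offset 0, length 0
    assert(rows[1].offset + rows[1].length == child.size());
    assert(empty_map.length == 0);
    return 0;
}
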
+ result.SetVectorType(VectorType::CONSTANT_VECTOR); + ListVector::SetListSize(result, 0); + + auto result_data = ListVector::GetData(result); + result_data[0] = list_entry_t(); + result.Verify(row_count); +} + +static bool MapIsNull(DataChunk &chunk) { + if (chunk.data.empty()) { + return false; + } + D_ASSERT(chunk.data.size() == 2); + auto &keys = chunk.data[0]; + auto &values = chunk.data[1]; + + if (keys.GetType().id() == LogicalTypeId::SQLNULL) { + return true; + } + if (values.GetType().id() == LogicalTypeId::SQLNULL) { + return true; + } + return false; +} + +static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { + + // internal MAP representation + // - LIST-vector that contains STRUCTs as child entries + // - STRUCTs have exactly two fields, a key-field, and a value-field + // - key names are unique + D_ASSERT(result.GetType().id() == LogicalTypeId::MAP); + + if (MapIsNull(args)) { + auto &validity = FlatVector::Validity(result); + validity.SetInvalid(0); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + return; + } + + auto row_count = args.size(); + + // early-out, if no data + if (args.data.empty()) { + return MapFunctionEmptyInput(result, row_count); + } + + auto &keys = args.data[0]; + auto &values = args.data[1]; + + // a LIST vector, where each row contains a LIST of KEYS + UnifiedVectorFormat keys_data; + keys.ToUnifiedFormat(row_count, keys_data); + auto keys_entries = UnifiedVectorFormat::GetData(keys_data); + + // the KEYs child vector + auto keys_child_vector = ListVector::GetEntry(keys); + UnifiedVectorFormat keys_child_data; + keys_child_vector.ToUnifiedFormat(ListVector::GetListSize(keys), keys_child_data); + + // a LIST vector, where each row contains a LIST of VALUES + UnifiedVectorFormat values_data; + values.ToUnifiedFormat(row_count, values_data); + auto values_entries = UnifiedVectorFormat::GetData(values_data); + + // the VALUEs child vector + auto values_child_vector = ListVector::GetEntry(values); + UnifiedVectorFormat values_child_data; + values_child_vector.ToUnifiedFormat(ListVector::GetListSize(values), values_child_data); + + // a LIST vector, where each row contains a MAP (LIST of STRUCTs) + UnifiedVectorFormat result_data; + result.ToUnifiedFormat(row_count, result_data); + auto result_entries = UnifiedVectorFormat::GetDataNoConst(result_data); + + auto &result_validity = FlatVector::Validity(result); + + // get the resulting size of the key/value child lists + idx_t result_child_size = 0; + for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { + auto keys_idx = keys_data.sel->get_index(row_idx); + auto values_idx = values_data.sel->get_index(row_idx); + if (!keys_data.validity.RowIsValid(keys_idx) || !values_data.validity.RowIsValid(values_idx)) { + continue; + } + auto keys_entry = keys_entries[keys_idx]; + result_child_size += keys_entry.length; + } + + // we need to slice potential non-flat vectors + SelectionVector sel_keys(result_child_size); + SelectionVector sel_values(result_child_size); + idx_t offset = 0; + + for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { + + auto keys_idx = keys_data.sel->get_index(row_idx); + auto values_idx = values_data.sel->get_index(row_idx); + auto result_idx = result_data.sel->get_index(row_idx); + + // NULL MAP + if (!keys_data.validity.RowIsValid(keys_idx) || !values_data.validity.RowIsValid(values_idx)) { + result_validity.SetInvalid(row_idx); + continue; + } + + auto keys_entry = keys_entries[keys_idx]; + auto values_entry = values_entries[values_idx]; + + if 
(keys_entry.length != values_entry.length) { + MapVector::EvalMapInvalidReason(MapInvalidReason::NOT_ALIGNED); + } + + // set the selection vectors and perform a duplicate key check + value_set_t unique_keys; + for (idx_t child_idx = 0; child_idx < keys_entry.length; child_idx++) { + + auto key_idx = keys_child_data.sel->get_index(keys_entry.offset + child_idx); + auto value_idx = values_child_data.sel->get_index(values_entry.offset + child_idx); + + // NULL check + if (!keys_child_data.validity.RowIsValid(key_idx)) { + MapVector::EvalMapInvalidReason(MapInvalidReason::NULL_KEY); + } + + // unique check + auto value = keys_child_vector.GetValue(key_idx); + auto unique = unique_keys.insert(value).second; + if (!unique) { + MapVector::EvalMapInvalidReason(MapInvalidReason::DUPLICATE_KEY); + } + + // set selection vectors + sel_keys.set_index(offset + child_idx, key_idx); + sel_values.set_index(offset + child_idx, value_idx); + } + + // keys_entry and values_entry have the same length + result_entries[result_idx].length = keys_entry.length; + result_entries[result_idx].offset = offset; + offset += keys_entry.length; + } + D_ASSERT(offset == result_child_size); + + auto &result_key_vector = MapVector::GetKeys(result); + auto &result_value_vector = MapVector::GetValues(result); + + ListVector::SetListSize(result, offset); + result_key_vector.Slice(keys_child_vector, sel_keys, offset); + result_key_vector.Flatten(offset); + result_value_vector.Slice(values_child_vector, sel_values, offset); + result_value_vector.Flatten(offset); + FlatVector::Validity(ListVector::GetEntry(result)).Resize(result_child_size); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(row_count); +} + +static unique_ptr MapBind(ClientContext &, ScalarFunction &bound_function, + vector> &arguments) { + + if (arguments.size() != 2 && !arguments.empty()) { + MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS); + } + + bool is_null = false; + if (arguments.empty()) { + is_null = true; + } + if (!is_null) { + auto key_id = arguments[0]->return_type.id(); + auto value_id = arguments[1]->return_type.id(); + if (key_id == LogicalTypeId::SQLNULL || value_id == LogicalTypeId::SQLNULL) { + is_null = true; + } + } + + if (is_null) { + bound_function.return_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL); + return make_uniq(bound_function.return_type); + } + + // bind a MAP with key-value pairs + D_ASSERT(arguments.size() == 2); + if (arguments[0]->return_type.id() != LogicalTypeId::LIST) { + MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS); + } + if (arguments[1]->return_type.id() != LogicalTypeId::LIST) { + MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS); + } + + auto key_type = ListType::GetChildType(arguments[0]->return_type); + auto value_type = ListType::GetChildType(arguments[1]->return_type); + + bound_function.return_type = LogicalType::MAP(key_type, value_type); + return make_uniq(bound_function.return_type); +} + +ScalarFunction MapFun::GetFunction() { + ScalarFunction fun({}, LogicalTypeId::MAP, MapFunction, MapBind); + fun.varargs = LogicalType::ANY; + BaseScalarFunction::SetReturnsError(fun); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp b/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp new file mode 100644 index 00000000..c958f41b --- /dev/null +++ 
b/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp @@ -0,0 +1,200 @@ +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "core_functions/scalar/map_functions.hpp" + +namespace duckdb { + +namespace { + +struct MapKeyIndexPair { + MapKeyIndexPair(idx_t map, idx_t key) : map_index(map), key_index(key) { + } + // The index of the map that this key comes from + idx_t map_index; + // The index within the map's key_list + idx_t key_index; +}; + +} // namespace + +vector<Value> GetListEntries(vector<Value> keys, vector<Value> values) { + D_ASSERT(keys.size() == values.size()); + vector<Value> entries; + for (idx_t i = 0; i < keys.size(); i++) { + child_list_t<Value> children; + children.emplace_back(make_pair("key", std::move(keys[i]))); + children.emplace_back(make_pair("value", std::move(values[i]))); + entries.push_back(Value::STRUCT(std::move(children))); + } + return entries; +} + +static void MapConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { + if (result.GetType().id() == LogicalTypeId::SQLNULL) { + // All inputs are NULL, just return NULL + auto &validity = FlatVector::Validity(result); + validity.SetInvalid(0); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + return; + } + D_ASSERT(result.GetType().id() == LogicalTypeId::MAP); + auto count = args.size(); + + auto map_count = args.ColumnCount(); + vector<UnifiedVectorFormat> map_formats(map_count); + for (idx_t i = 0; i < map_count; i++) { + auto &map = args.data[i]; + map.ToUnifiedFormat(count, map_formats[i]); + } + auto result_data = FlatVector::GetData<list_entry_t>(result); + + for (idx_t i = 0; i < count; i++) { + // Loop through all the maps per list + // we can't do better because all the entries of the child vector have to be contiguous + // so we can't start the next row before we have finished the one before it + auto &result_entry = result_data[i]; + vector<MapKeyIndexPair> index_to_map; + vector<Value> keys_list; + bool all_null = true; + for (idx_t map_idx = 0; map_idx < map_count; map_idx++) { + if (args.data[map_idx].GetType().id() == LogicalTypeId::SQLNULL) { + continue; + } + + auto &map_format = map_formats[map_idx]; + auto index = map_format.sel->get_index(i); + if (!map_format.validity.RowIsValid(index)) { + continue; + } + + all_null = false; + auto &keys = MapVector::GetKeys(args.data[map_idx]); + auto entry = UnifiedVectorFormat::GetData<list_entry_t>(map_format)[index]; + + // Update the list for this row + for (idx_t list_idx = 0; list_idx < entry.length; list_idx++) { + auto key_index = entry.offset + list_idx; + auto key = keys.GetValue(key_index); + auto entry = std::find(keys_list.begin(), keys_list.end(), key); + if (entry == keys_list.end()) { + // Result list does not contain this key yet + keys_list.push_back(key); + index_to_map.emplace_back(map_idx, key_index); + } else { + // Result list already contains this key, update where to find the value + auto distance = std::distance(keys_list.begin(), entry); + auto &mapping = *(index_to_map.begin() + distance); + mapping.key_index = key_index; + mapping.map_index = map_idx; + } + } + } + + result_entry.offset = ListVector::GetListSize(result); + result_entry.length = keys_list.size(); + if (all_null) { + D_ASSERT(keys_list.empty() && index_to_map.empty()); + 
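
The merge loop above gives map_concat its semantics: keys keep the position of their first appearance, and when several maps share a key the mapping is redirected to the latest map, so the last value wins. A value-level sketch of those rules (hypothetical Entry/Map aliases, not DuckDB's vector machinery):

#include <algorithm>
#include <cassert>
#include <string>
#include <utility>
#include <vector>

using Entry = std::pair<std::string, int>;
using Map = std::vector<Entry>; // order-preserving list of key/value pairs

Map MapConcat(const std::vector<Map> &maps) {
    Map result;
    for (const auto &m : maps) {
        for (const auto &kv : m) {
            auto it = std::find_if(result.begin(), result.end(),
                                   [&](const Entry &e) { return e.first == kv.first; });
            if (it == result.end()) {
                it = result.end();
                result.push_back(kv);   // new key: append in first-seen order
            } else {
                it->second = kv.second; // duplicate key: the later map wins
            }
        }
    }
    return result;
}

int main() {
    Map merged = MapConcat({{{"a", 1}, {"b", 2}}, {{"b", 20}}});
    assert(merged.size() == 2 && merged[1].second == 20); // {'a': 1, 'b': 20}
    return 0;
}
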
FlatVector::SetNull(result, i, true); + continue; + } + + vector<Value> values_list; + D_ASSERT(keys_list.size() == index_to_map.size()); + // Get the values from the mapping + for (auto &mapping : index_to_map) { + auto &map = args.data[mapping.map_index]; + auto &values = MapVector::GetValues(map); + values_list.push_back(values.GetValue(mapping.key_index)); + } + D_ASSERT(values_list.size() == keys_list.size()); + auto list_entries = GetListEntries(std::move(keys_list), std::move(values_list)); + for (auto &list_entry : list_entries) { + ListVector::PushBack(result, list_entry); + } + } + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(count); +} + +static bool IsEmptyMap(const LogicalType &map) { + D_ASSERT(map.id() == LogicalTypeId::MAP); + auto &key_type = MapType::KeyType(map); + auto &value_type = MapType::ValueType(map); + return key_type.id() == LogicalType::SQLNULL && value_type.id() == LogicalType::SQLNULL; +} + +static unique_ptr<FunctionData> MapConcatBind(ClientContext &context, ScalarFunction &bound_function, + vector<unique_ptr<Expression>> &arguments) { + + auto arg_count = arguments.size(); + if (arg_count < 2) { + throw InvalidInputException("The provided number of arguments is incorrect, please provide 2 or more maps"); + } + + if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { + // Prepared statement + bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + + LogicalType expected = LogicalType::SQLNULL; + + bool is_null = true; + // Check and verify that all the maps are of the same type + for (idx_t i = 0; i < arg_count; i++) { + auto &arg = arguments[i]; + auto &map = arg->return_type; + if (map.id() == LogicalTypeId::UNKNOWN) { + // Prepared statement + bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + if (map.id() == LogicalTypeId::SQLNULL) { + // The maps are allowed to be NULL + continue; + } + if (map.id() != LogicalTypeId::MAP) { + throw InvalidInputException("MAP_CONCAT only takes map arguments"); + } + is_null = false; + if (IsEmptyMap(map)) { + // Map is allowed to be empty + continue; + } + + if (expected.id() == LogicalTypeId::SQLNULL) { + expected = map; + } else if (map != expected) { + throw InvalidInputException( + "'value' type of map differs between arguments, expected '%s', found '%s' instead", expected.ToString(), + map.ToString()); + } + } + + if (expected.id() == LogicalTypeId::SQLNULL && is_null == false) { + expected = LogicalType::MAP(LogicalType::SQLNULL, LogicalType::SQLNULL); + } + bound_function.return_type = expected; + return make_uniq<VariableReturnBindData>(bound_function.return_type); +} + +ScalarFunction MapConcatFun::GetFunction() { + //! 
the arguments and return types are actually set in the binder function + ScalarFunction fun("map_concat", {}, LogicalTypeId::LIST, MapConcatFunction, MapConcatBind); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.varargs = LogicalType::ANY; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp new file mode 100644 index 00000000..487fd75f --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp @@ -0,0 +1,79 @@ +#include "core_functions/scalar/map_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +// Reverse of map_from_entries +static void MapEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto count = args.size(); + + auto &map = args.data[0]; + if (map.GetType().id() == LogicalTypeId::SQLNULL) { + // Input is a constant NULL + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + + MapUtil::ReinterpretMap(result, map, count); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(count); +} + +static LogicalType CreateReturnType(const LogicalType &map) { + auto &key_type = MapType::KeyType(map); + auto &value_type = MapType::ValueType(map); + + child_list_t child_types; + child_types.push_back(make_pair("key", key_type)); + child_types.push_back(make_pair("value", value_type)); + + auto row_type = LogicalType::STRUCT(child_types); + return LogicalType::LIST(row_type); +} + +static unique_ptr MapEntriesBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments.size() != 1) { + throw InvalidInputException("Too many arguments provided, only expecting a single map"); + } + auto &map = arguments[0]->return_type; + + if (map.id() == LogicalTypeId::UNKNOWN) { + // Prepared statement + bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + + if (map.id() == LogicalTypeId::SQLNULL) { + // Input is NULL, output is STRUCT(NULL, NULL)[] + auto map_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL); + bound_function.return_type = CreateReturnType(map_type); + return make_uniq(bound_function.return_type); + } + + if (map.id() != LogicalTypeId::MAP) { + throw InvalidInputException("The provided argument is not a map"); + } + bound_function.return_type = CreateReturnType(map); + return make_uniq(bound_function.return_type); +} + +ScalarFunction MapEntriesFun::GetFunction() { + //! 
the arguments and return types are actually set in the binder function + ScalarFunction fun({}, LogicalTypeId::LIST, MapEntriesFunction, MapEntriesBind); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.varargs = LogicalType::ANY; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp b/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp new file mode 100644 index 00000000..170f2b7d --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp @@ -0,0 +1,104 @@ +#include "core_functions/scalar/map_functions.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/function/scalar/list/contains_or_position.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +static unique_ptr<FunctionData> MapExtractBind(ClientContext &, ScalarFunction &bound_function, + vector<unique_ptr<Expression>> &arguments) { + if (arguments.size() != 2) { + throw BinderException("MAP_EXTRACT must have exactly two arguments"); + } + + auto &map_type = arguments[0]->return_type; + auto &input_type = arguments[1]->return_type; + + if (map_type.id() == LogicalTypeId::SQLNULL) { + bound_function.return_type = LogicalTypeId::SQLNULL; + return make_uniq<VariableReturnBindData>(bound_function.return_type); + } + + if (map_type.id() != LogicalTypeId::MAP) { + throw BinderException("MAP_EXTRACT can only operate on MAPs"); + } + auto &value_type = MapType::ValueType(map_type); + + //! The return type is the value type of the MAP + bound_function.return_type = value_type; + auto key_type = MapType::KeyType(map_type); + if (key_type.id() != LogicalTypeId::SQLNULL && input_type.id() != LogicalTypeId::SQLNULL) { + bound_function.arguments[1] = MapType::KeyType(map_type); + } + return make_uniq<VariableReturnBindData>(bound_function.return_type); +} + +static void MapExtractFunc(DataChunk &args, ExpressionState &state, Vector &result) { + const auto count = args.size(); + + auto &map_vec = args.data[0]; + auto &arg_vec = args.data[1]; + + const auto map_is_null = map_vec.GetType().id() == LogicalTypeId::SQLNULL; + const auto arg_is_null = arg_vec.GetType().id() == LogicalTypeId::SQLNULL; + + if (map_is_null || arg_is_null) { + // Short-circuit if either the map or the arg is NULL + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + result.Verify(count); + return; + } + + auto &key_vec = MapVector::GetKeys(map_vec); + auto &val_vec = MapVector::GetValues(map_vec); + + // Collect the matching positions + Vector pos_vec(LogicalType::INTEGER, count); + ListSearchOp(map_vec, key_vec, arg_vec, pos_vec, args.size()); + + UnifiedVectorFormat pos_format; + UnifiedVectorFormat lst_format; + + pos_vec.ToUnifiedFormat(count, pos_format); + map_vec.ToUnifiedFormat(count, lst_format); + + const auto pos_data = UnifiedVectorFormat::GetData<int32_t>(pos_format); + const auto inc_list_data = ListVector::GetData(map_vec); + + auto &result_validity = FlatVector::Validity(result); + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + auto lst_idx = lst_format.sel->get_index(row_idx); + if (!lst_format.validity.RowIsValid(lst_idx)) { + FlatVector::SetNull(result, row_idx, true); + continue; + } + + const auto pos_idx = pos_format.sel->get_index(row_idx); + if (!pos_format.validity.RowIsValid(pos_idx)) { + // We didn't find the key in the map, so return NULL + result_validity.SetInvalid(row_idx); + continue; + } + + // Compute the actual position of the value in 
the map value vector + const auto pos = inc_list_data[lst_idx].offset + UnsafeNumericCast(pos_data[pos_idx] - 1); + VectorOperations::Copy(val_vec, result, pos + 1, pos, row_idx); + } + + if (args.size() == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + + result.Verify(count); +} + +ScalarFunction MapExtractFun::GetFunction() { + ScalarFunction fun({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, MapExtractFunc, MapExtractBind); + fun.varargs = LogicalType::ANY; + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp new file mode 100644 index 00000000..edbe1d4f --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp @@ -0,0 +1,60 @@ +#include "core_functions/scalar/map_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +static void MapFromEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto count = args.size(); + + MapUtil::ReinterpretMap(result, args.data[0], count); + MapVector::MapConversionVerify(result, count); + result.Verify(count); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +static unique_ptr MapFromEntriesBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments.size() != 1) { + throw InvalidInputException("The input argument must be a list of structs."); + } + auto &list = arguments[0]->return_type; + + if (list.id() == LogicalTypeId::UNKNOWN) { + bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + + if (list.id() != LogicalTypeId::LIST) { + throw InvalidInputException("The provided argument is not a list of structs"); + } + auto &elem_type = ListType::GetChildType(list); + if (elem_type.id() != LogicalTypeId::STRUCT) { + throw InvalidInputException("The elements of the list must be structs"); + } + auto &children = StructType::GetChildTypes(elem_type); + if (children.size() != 2) { + throw InvalidInputException("The provided struct type should only contain 2 fields, a key and a value"); + } + + bound_function.return_type = LogicalType::MAP(elem_type); + return make_uniq(bound_function.return_type); +} + +ScalarFunction MapFromEntriesFun::GetFunction() { + //! 
the arguments and return types are actually set in the binder function + ScalarFunction fun({}, LogicalTypeId::MAP, MapFromEntriesFunction, MapFromEntriesBind); + fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; + fun.varargs = LogicalType::ANY; + BaseScalarFunction::SetReturnsError(fun); + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp b/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp new file mode 100644 index 00000000..6d99a353 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp @@ -0,0 +1,112 @@ +#include "core_functions/scalar/map_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +static void MapKeyValueFunction(DataChunk &args, ExpressionState &state, Vector &result, + Vector &(*get_child_vector)(Vector &)) { + auto &map = args.data[0]; + + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + if (map.GetType().id() == LogicalTypeId::SQLNULL) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + + auto count = args.size(); + D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); + auto child = get_child_vector(map); + + auto &entries = ListVector::GetEntry(result); + entries.Reference(child); + + UnifiedVectorFormat map_data; + map.ToUnifiedFormat(count, map_data); + + D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); + FlatVector::SetData(result, map_data.data); + FlatVector::SetValidity(result, map_data.validity); + auto list_size = ListVector::GetListSize(map); + ListVector::SetListSize(result, list_size); + if (map.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + result.Slice(*map_data.sel, count); + } + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(count); +} + +static void MapKeysFunction(DataChunk &args, ExpressionState &state, Vector &result) { + MapKeyValueFunction(args, state, result, MapVector::GetKeys); +} + +static void MapValuesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + MapKeyValueFunction(args, state, result, MapVector::GetValues); +} + +static unique_ptr MapKeyValueBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments, + const LogicalType &(*type_func)(const LogicalType &)) { + if (arguments.size() != 1) { + throw InvalidInputException("Too many arguments provided, only expecting a single map"); + } + auto &map = arguments[0]->return_type; + + if (map.id() == LogicalTypeId::UNKNOWN) { + // Prepared statement + bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); + bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + return nullptr; + } + + if (map.id() == LogicalTypeId::SQLNULL) { + // Input is NULL, output is NULL[] + bound_function.return_type = LogicalType::LIST(LogicalTypeId::SQLNULL); + return make_uniq(bound_function.return_type); + } + + if (map.id() != LogicalTypeId::MAP) { + throw InvalidInputException("The provided argument is not a map"); + } + + auto &type = type_func(map); + + bound_function.return_type = LogicalType::LIST(type); + return make_uniq(bound_function.return_type); +} + +static unique_ptr 
MapKeysBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + return MapKeyValueBind(context, bound_function, arguments, MapType::KeyType); +} + +static unique_ptr MapValuesBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + return MapKeyValueBind(context, bound_function, arguments, MapType::ValueType); +} + +ScalarFunction MapKeysFun::GetFunction() { + //! the arguments and return types are actually set in the binder function + ScalarFunction function({}, LogicalTypeId::LIST, MapKeysFunction, MapKeysBind); + function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + BaseScalarFunction::SetReturnsError(function); + function.varargs = LogicalType::ANY; + return function; +} + +ScalarFunction MapValuesFun::GetFunction() { + ScalarFunction function({}, LogicalTypeId::LIST, MapValuesFunction, MapValuesBind); + function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + BaseScalarFunction::SetReturnsError(function); + function.varargs = LogicalType::ANY; + return function; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/math/numeric.cpp b/src/duckdb/extension/core_functions/scalar/math/numeric.cpp new file mode 100644 index 00000000..47eed7a7 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/math/numeric.cpp @@ -0,0 +1,1469 @@ +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/common/operator/abs.hpp" +#include "duckdb/common/operator/multiply.hpp" +#include "duckdb/common/operator/numeric_binary_operators.hpp" +#include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/uhugeint.hpp" +#include "duckdb/common/types/validity_mask.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "core_functions/scalar/math_functions.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +#include +#include +#include +#include +#include + +namespace duckdb { + +template +static scalar_function_t GetScalarIntegerUnaryFunctionFixedReturn(const LogicalType &type) { + scalar_function_t function; + switch (type.id()) { + case LogicalTypeId::TINYINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::SMALLINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::INTEGER: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::BIGINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::HUGEINT: + function = &ScalarFunction::UnaryFunction; + break; + default: + throw NotImplementedException("Unimplemented type for GetScalarIntegerUnaryFunctionFixedReturn"); + } + return function; +} + +//===--------------------------------------------------------------------===// +// nextafter +//===--------------------------------------------------------------------===// +struct NextAfterOperator { + template + static inline TR Operation(TA base, TB exponent) { + throw NotImplementedException("Unimplemented type for NextAfter Function"); + } + + template + static inline double Operation(double input, double approximate_to) { + return nextafter(input, approximate_to); + } + template + static inline float Operation(float input, float approximate_to) { + return 
nextafterf(input, approximate_to); + } +}; + +ScalarFunctionSet NextAfterFun::GetFunctions() { + ScalarFunctionSet next_after_fun; + next_after_fun.AddFunction( + ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::BinaryFunction)); + next_after_fun.AddFunction(ScalarFunction({LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT, + ScalarFunction::BinaryFunction)); + return next_after_fun; +} + +//===--------------------------------------------------------------------===// +// abs +//===--------------------------------------------------------------------===// +static unique_ptr PropagateAbsStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + D_ASSERT(child_stats.size() == 1); + // can only propagate stats if the children have stats + auto &lstats = child_stats[0]; + Value new_min, new_max; + bool potential_overflow = true; + if (NumericStats::HasMinMax(lstats)) { + switch (expr.return_type.InternalType()) { + case PhysicalType::INT8: + potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); + break; + case PhysicalType::INT16: + potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); + break; + case PhysicalType::INT32: + potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); + break; + case PhysicalType::INT64: + potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); + break; + default: + return nullptr; + } + } + if (potential_overflow) { + new_min = Value(expr.return_type); + new_max = Value(expr.return_type); + } else { + // no potential overflow + + // compute stats + auto current_min = NumericStats::Min(lstats).GetValue(); + auto current_max = NumericStats::Max(lstats).GetValue(); + + int64_t min_val, max_val; + + if (current_min < 0 && current_max < 0) { + // if both min and max are below zero, then min=abs(cur_max) and max=abs(cur_min) + min_val = AbsValue(current_max); + max_val = AbsValue(current_min); + } else if (current_min < 0) { + D_ASSERT(current_max >= 0); + // if min is below zero and max is above 0, then min=0 and max=max(cur_max, abs(cur_min)) + min_val = 0; + max_val = MaxValue(AbsValue(current_min), current_max); + } else { + // if both current_min and current_max are > 0, then the abs is a no-op and can be removed entirely + *input.expr_ptr = std::move(input.expr.children[0]); + return child_stats[0].ToUnique(); + } + new_min = Value::Numeric(expr.return_type, min_val); + new_max = Value::Numeric(expr.return_type, max_val); + expr.function.function = ScalarFunction::GetScalarUnaryFunction(expr.return_type); + } + auto stats = NumericStats::CreateEmpty(expr.return_type); + NumericStats::SetMin(stats, new_min); + NumericStats::SetMax(stats, new_max); + stats.CopyValidity(lstats); + return stats.ToUnique(); +} + +template +unique_ptr DecimalUnaryOpBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + auto decimal_type = arguments[0]->return_type; + switch (decimal_type.InternalType()) { + case PhysicalType::INT16: + bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::SMALLINT); + break; + case PhysicalType::INT32: + bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::INTEGER); + break; + case PhysicalType::INT64: + bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::BIGINT); + break; + default: + 
bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::HUGEINT); + break; + } + bound_function.arguments[0] = decimal_type; + bound_function.return_type = decimal_type; + return nullptr; +} + +ScalarFunctionSet AbsOperatorFun::GetFunctions() { + ScalarFunctionSet abs; + for (auto &type : LogicalType::Numeric()) { + switch (type.id()) { + case LogicalTypeId::DECIMAL: + abs.AddFunction(ScalarFunction({type}, type, nullptr, DecimalUnaryOpBind)); + break; + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: { + ScalarFunction function({type}, type, ScalarFunction::GetScalarUnaryFunction(type)); + function.statistics = PropagateAbsStats; + abs.AddFunction(function); + break; + } + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + abs.AddFunction(ScalarFunction({type}, type, ScalarFunction::NopFunction)); + break; + default: + abs.AddFunction(ScalarFunction({type}, type, ScalarFunction::GetScalarUnaryFunction(type))); + break; + } + } + for (auto &func : abs.functions) { + BaseScalarFunction::SetReturnsError(func); + } + return abs; +} + +//===--------------------------------------------------------------------===// +// bit_count +//===--------------------------------------------------------------------===// +struct BitCntOperator { + template + static inline TR Operation(TA input) { + using TU = typename std::make_unsigned::type; + TR count = 0; + for (auto value = TU(input); value; ++count) { + value &= (value - 1); + } + return count; + } +}; + +struct HugeIntBitCntOperator { + template + static inline TR Operation(TA input) { + using TU = typename std::make_unsigned::type; + TR count = 0; + + for (auto value = TU(input.upper); value; ++count) { + value &= (value - 1); + } + for (auto value = TU(input.lower); value; ++count) { + value &= (value - 1); + } + return count; + } +}; + +struct BitStringBitCntOperator { + template + static inline TR Operation(TA input) { + TR count = Bit::BitCount(input); + return count; + } +}; + +ScalarFunctionSet BitCountFun::GetFunctions() { + ScalarFunctionSet functions; + functions.AddFunction(ScalarFunction({LogicalType::TINYINT}, LogicalType::TINYINT, + ScalarFunction::UnaryFunction)); + functions.AddFunction(ScalarFunction({LogicalType::SMALLINT}, LogicalType::TINYINT, + ScalarFunction::UnaryFunction)); + functions.AddFunction(ScalarFunction({LogicalType::INTEGER}, LogicalType::TINYINT, + ScalarFunction::UnaryFunction)); + functions.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::TINYINT, + ScalarFunction::UnaryFunction)); + functions.AddFunction(ScalarFunction({LogicalType::HUGEINT}, LogicalType::TINYINT, + ScalarFunction::UnaryFunction)); + functions.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, + ScalarFunction::UnaryFunction)); + return functions; +} + +//===--------------------------------------------------------------------===// +// sign +//===--------------------------------------------------------------------===// +struct SignOperator { + template + static TR Operation(TA input) { + if (input == TA(0)) { + return 0; + } else if (input > TA(0)) { + return 1; + } else { + return -1; + } + } +}; + +template <> +int8_t SignOperator::Operation(float input) { + if (input == 0 || Value::IsNan(input)) { + return 0; + } else if (input > 0) { + return 1; + } else { + return -1; + } +} + +template <> +int8_t 
SignOperator::Operation(double input) { + if (input == 0 || Value::IsNan(input)) { + return 0; + } else if (input > 0) { + return 1; + } else { + return -1; + } +} + +ScalarFunctionSet SignFun::GetFunctions() { + ScalarFunctionSet sign; + for (auto &type : LogicalType::Numeric()) { + if (type.id() == LogicalTypeId::DECIMAL) { + continue; + } else { + sign.AddFunction( + ScalarFunction({type}, LogicalType::TINYINT, + ScalarFunction::GetScalarUnaryFunctionFixedReturn(type))); + } + } + return sign; +} + +//===--------------------------------------------------------------------===// +// ceil +//===--------------------------------------------------------------------===// +struct CeilOperator { + template + static inline TR Operation(TA left) { + return std::ceil(left); + } +}; + +template +static void GenericRoundFunctionDecimal(DataChunk &input, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + OP::template Operation(input, DecimalType::GetScale(func_expr.children[0]->return_type), result); +} + +template +unique_ptr BindGenericRoundFunctionDecimal(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // ceil essentially removes the scale + auto &decimal_type = arguments[0]->return_type; + auto scale = DecimalType::GetScale(decimal_type); + auto width = DecimalType::GetWidth(decimal_type); + if (scale == 0) { + bound_function.function = ScalarFunction::NopFunction; + } else { + switch (decimal_type.InternalType()) { + case PhysicalType::INT16: + bound_function.function = GenericRoundFunctionDecimal; + break; + case PhysicalType::INT32: + bound_function.function = GenericRoundFunctionDecimal; + break; + case PhysicalType::INT64: + bound_function.function = GenericRoundFunctionDecimal; + break; + default: + bound_function.function = GenericRoundFunctionDecimal; + break; + } + } + bound_function.arguments[0] = decimal_type; + bound_function.return_type = LogicalType::DECIMAL(width, 0); + return nullptr; +} + +struct CeilDecimalOperator { + template + static void Operation(DataChunk &input, uint8_t scale, Vector &result) { + T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]); + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input <= 0) { + // below 0 we floor the number (e.g. 
-10.5 -> -10) + return UnsafeNumericCast(input / power_of_ten); + } else { + // above 0 we adjust so that truncation rounds up (ceiling) + return UnsafeNumericCast(((input - 1) / power_of_ten) + 1); + } + }); + } +}; + +ScalarFunctionSet CeilFun::GetFunctions() { + ScalarFunctionSet ceil; + for (auto &type : LogicalType::Numeric()) { + scalar_function_t func = nullptr; + bind_scalar_function_t bind_func = nullptr; + if (type.IsIntegral()) { + // no ceil for integral numbers + continue; + } + switch (type.id()) { + case LogicalTypeId::FLOAT: + func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DOUBLE: + func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DECIMAL: + bind_func = BindGenericRoundFunctionDecimal; + break; + default: + throw InternalException("Unimplemented numeric type for function \"ceil\""); + } + ceil.AddFunction(ScalarFunction({type}, type, func, bind_func)); + } + return ceil; +} + +//===--------------------------------------------------------------------===// +// floor +//===--------------------------------------------------------------------===// +struct FloorOperator { + template + static inline TR Operation(TA left) { + return std::floor(left); + } +}; + +struct FloorDecimalOperator { + template + static void Operation(DataChunk &input, uint8_t scale, Vector &result) { + T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]); + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input < 0) { + // below 0, truncation rounds toward zero, so adjust to round down (e.g. -10.5 -> -11) + return UnsafeNumericCast(((input + 1) / power_of_ten) - 1); + } else { + // above 0 truncation already floors the number + return UnsafeNumericCast(input / power_of_ten); + } + }); + } +}; + +ScalarFunctionSet FloorFun::GetFunctions() { + ScalarFunctionSet floor; + for (auto &type : LogicalType::Numeric()) { + scalar_function_t func = nullptr; + bind_scalar_function_t bind_func = nullptr; + if (type.IsIntegral()) { + // no floor for integral numbers + continue; + } + switch (type.id()) { + case LogicalTypeId::FLOAT: + func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DOUBLE: + func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DECIMAL: + bind_func = BindGenericRoundFunctionDecimal; + break; + default: + throw InternalException("Unimplemented numeric type for function \"floor\""); + } + floor.AddFunction(ScalarFunction({type}, type, func, bind_func)); + } + return floor; +} + +//===--------------------------------------------------------------------===// +// trunc +//===--------------------------------------------------------------------===// +struct TruncOperator { + // Integer truncation is a NOP + template + static inline TR Operation(TA left) { + return std::trunc(left); + } +}; + +struct TruncDecimalOperator { + template + static void Operation(DataChunk &input, uint8_t scale, Vector &result) { + T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]); + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + // truncation toward zero is exactly integer division by the scale factor + return UnsafeNumericCast((input / power_of_ten)); + }); + } +}; + +ScalarFunctionSet TruncFun::GetFunctions() { + ScalarFunctionSet trunc; + for (auto &type : LogicalType::Numeric()) { + scalar_function_t func = nullptr; + bind_scalar_function_t bind_func = nullptr; + // Truncation of integers gets generated by some tools (e.g., Tableau/JDBC:Postgres) + switch (type.id()) { + case LogicalTypeId::FLOAT: + func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DOUBLE: + 
func = ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::DECIMAL: + bind_func = BindGenericRoundFunctionDecimal; + break; + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::UHUGEINT: + func = ScalarFunction::NopFunction; + break; + default: + throw InternalException("Unimplemented numeric type for function \"trunc\""); + } + trunc.AddFunction(ScalarFunction({type}, type, func, bind_func)); + } + return trunc; +} + +//===--------------------------------------------------------------------===// +// round +//===--------------------------------------------------------------------===// +struct RoundOperatorPrecision { + template + static inline TR Operation(TA input, TB precision) { + double rounded_value; + if (precision < 0) { + double modifier = std::pow(10, -TA(precision)); + rounded_value = (std::round(input / modifier)) * modifier; + if (std::isinf(rounded_value) || std::isnan(rounded_value)) { + return 0; + } + } else { + double modifier = std::pow(10, TA(precision)); + rounded_value = (std::round(input * modifier)) / modifier; + if (std::isinf(rounded_value) || std::isnan(rounded_value)) { + return input; + } + } + return LossyNumericCast(rounded_value); + } +}; + +struct RoundOperator { + template + static inline TR Operation(TA input) { + double rounded_value = round(input); + if (std::isinf(rounded_value) || std::isnan(rounded_value)) { + return input; + } + return LossyNumericCast(rounded_value); + } +}; + +struct RoundDecimalOperator { + template + static void Operation(DataChunk &input, uint8_t scale, Vector &result) { + T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]); + T addition = power_of_ten / 2; + // regular round rounds towards the nearest number + // in case of a tie we round away from zero + // i.e. -10.5 -> -11, 10.5 -> 11 + // we implement this by adding (positive) or subtracting (negative) 0.5 + // and then flooring the number + // e.g. 
10.5 + 0.5 = 11, floor(11) = 11 + // 10.4 + 0.5 = 10.9, floor(10.9) = 10 + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input < 0) { + input -= addition; + } else { + input += addition; + } + return UnsafeNumericCast(input / power_of_ten); + }); + } +}; + +struct RoundPrecisionFunctionData : public FunctionData { + explicit RoundPrecisionFunctionData(int32_t target_scale) : target_scale(target_scale) { + } + + int32_t target_scale; + + unique_ptr Copy() const override { + return make_uniq(target_scale); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return target_scale == other.target_scale; + } +}; + +template +static void DecimalRoundNegativePrecisionFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto source_scale = DecimalType::GetScale(func_expr.children[0]->return_type); + auto width = DecimalType::GetWidth(func_expr.children[0]->return_type); + if (info.target_scale <= -int32_t(width - source_scale)) { + // scale too big for width + result.SetVectorType(VectorType::CONSTANT_VECTOR); + result.SetValue(0, Value::INTEGER(0)); + return; + } + T divide_power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale + source_scale]); + T multiply_power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale]); + T addition = divide_power_of_ten / 2; + + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input < 0) { + input -= addition; + } else { + input += addition; + } + return UnsafeNumericCast(input / divide_power_of_ten * multiply_power_of_ten); + }); +} + +template +static void DecimalRoundPositivePrecisionFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto source_scale = DecimalType::GetScale(func_expr.children[0]->return_type); + T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[source_scale - info.target_scale]); + T addition = power_of_ten / 2; + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input < 0) { + input -= addition; + } else { + input += addition; + } + return UnsafeNumericCast(input / power_of_ten); + }); +} + +unique_ptr BindDecimalRoundPrecision(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + auto &decimal_type = arguments[0]->return_type; + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw NotImplementedException("ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); + } + Value val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]).DefaultCastAs(LogicalType::INTEGER); + if (val.IsNull()) { + throw NotImplementedException("ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); + } + // our new precision becomes the round value + // e.g. ROUND(DECIMAL(18,3), 1) -> DECIMAL(18,1) + // but ONLY if the round value is positive + // if it is negative the scale becomes zero + // i.e. 
ROUND(DECIMAL(18,3), -1) -> DECIMAL(18,0) + int32_t round_value = IntegerValue::Get(val); + uint8_t target_scale; + auto width = DecimalType::GetWidth(decimal_type); + auto scale = DecimalType::GetScale(decimal_type); + if (round_value < 0) { + target_scale = 0; + switch (decimal_type.InternalType()) { + case PhysicalType::INT16: + bound_function.function = DecimalRoundNegativePrecisionFunction; + break; + case PhysicalType::INT32: + bound_function.function = DecimalRoundNegativePrecisionFunction; + break; + case PhysicalType::INT64: + bound_function.function = DecimalRoundNegativePrecisionFunction; + break; + default: + bound_function.function = DecimalRoundNegativePrecisionFunction; + break; + } + } else { + if (round_value >= (int32_t)scale) { + // if round_value is bigger than or equal to scale we do nothing + bound_function.function = ScalarFunction::NopFunction; + target_scale = scale; + } else { + target_scale = NumericCast(round_value); + switch (decimal_type.InternalType()) { + case PhysicalType::INT16: + bound_function.function = DecimalRoundPositivePrecisionFunction; + break; + case PhysicalType::INT32: + bound_function.function = DecimalRoundPositivePrecisionFunction; + break; + case PhysicalType::INT64: + bound_function.function = DecimalRoundPositivePrecisionFunction; + break; + default: + bound_function.function = DecimalRoundPositivePrecisionFunction; + break; + } + } + } + bound_function.arguments[0] = decimal_type; + bound_function.return_type = LogicalType::DECIMAL(width, target_scale); + return make_uniq(round_value); +} + +ScalarFunctionSet RoundFun::GetFunctions() { + ScalarFunctionSet round; + for (auto &type : LogicalType::Numeric()) { + scalar_function_t round_prec_func = nullptr; + scalar_function_t round_func = nullptr; + bind_scalar_function_t bind_func = nullptr; + bind_scalar_function_t bind_prec_func = nullptr; + if (type.IsIntegral()) { + // no round for integral numbers + continue; + } + switch (type.id()) { + case LogicalTypeId::FLOAT: + round_func = ScalarFunction::UnaryFunction; + round_prec_func = ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::DOUBLE: + round_func = ScalarFunction::UnaryFunction; + round_prec_func = ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::DECIMAL: + bind_func = BindGenericRoundFunctionDecimal; + bind_prec_func = BindDecimalRoundPrecision; + break; + default: + throw InternalException("Unimplemented numeric type for function \"round\""); + } + round.AddFunction(ScalarFunction({type}, type, round_func, bind_func)); + round.AddFunction(ScalarFunction({type, LogicalType::INTEGER}, type, round_prec_func, bind_prec_func)); + } + return round; +} + +//===--------------------------------------------------------------------===// +// exp +//===--------------------------------------------------------------------===// +struct ExpOperator { + template + static inline TR Operation(TA left) { + return std::exp(left); + } +}; + +ScalarFunction ExpFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// pow +//===--------------------------------------------------------------------===// +struct PowOperator { + template + static inline TR Operation(TA base, TB exponent) { + return std::pow(base, exponent); + } +}; + +ScalarFunction PowOperatorFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, + 
ScalarFunction::BinaryFunction); +} + +//===--------------------------------------------------------------------===// +// sqrt +//===--------------------------------------------------------------------===// +struct SqrtOperator { + template + static inline TR Operation(TA input) { + if (input < 0) { + throw OutOfRangeException("cannot take square root of a negative number"); + } + return std::sqrt(input); + } +}; + +ScalarFunction SqrtFun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// cbrt +//===--------------------------------------------------------------------===// +struct CbRtOperator { + template + static inline TR Operation(TA left) { + return std::cbrt(left); + } +}; + +ScalarFunction CbrtFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// ln +//===--------------------------------------------------------------------===// + +struct LnOperator { + template + static inline TR Operation(TA input) { + if (input < 0) { + throw OutOfRangeException("cannot take logarithm of a negative number"); + } + if (input == 0) { + throw OutOfRangeException("cannot take logarithm of zero"); + } + return std::log(input); + } +}; + +ScalarFunction LnFun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// log +//===--------------------------------------------------------------------===// +struct Log10Operator { + template + static inline TR Operation(TA input) { + if (input < 0) { + throw OutOfRangeException("cannot take logarithm of a negative number"); + } + if (input == 0) { + throw OutOfRangeException("cannot take logarithm of zero"); + } + return std::log10(input); + } +}; + +ScalarFunction Log10Fun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// log with base +//===--------------------------------------------------------------------===// +struct LogBaseOperator { + template + static inline TR Operation(TA b, TB x) { + auto divisor = Log10Operator::Operation(b); + if (divisor == 0) { + throw OutOfRangeException("division by zero in base logarithm"); + } + return Log10Operator::Operation(x) / divisor; + } +}; + +ScalarFunctionSet LogFun::GetFunctions() { + ScalarFunctionSet funcs; + funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::BinaryFunction)); + for (auto &function : funcs.functions) { + BaseScalarFunction::SetReturnsError(function); + } + return funcs; +} + +//===--------------------------------------------------------------------===// +// log2 +//===--------------------------------------------------------------------===// +struct Log2Operator { + template + static inline TR Operation(TA 
input) { + if (input < 0) { + throw OutOfRangeException("cannot take logarithm of a negative number"); + } + if (input == 0) { + throw OutOfRangeException("cannot take logarithm of zero"); + } + return std::log2(input); + } +}; + +ScalarFunction Log2Fun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// pi +//===--------------------------------------------------------------------===// +static void PiFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 0); + Value pi_value = Value::DOUBLE(PI); + result.Reference(pi_value); +} + +ScalarFunction PiFun::GetFunction() { + return ScalarFunction({}, LogicalType::DOUBLE, PiFunction); +} + +//===--------------------------------------------------------------------===// +// degrees +//===--------------------------------------------------------------------===// +struct DegreesOperator { + template + static inline TR Operation(TA left) { + return left * (180 / PI); + } +}; + +ScalarFunction DegreesFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// radians +//===--------------------------------------------------------------------===// +struct RadiansOperator { + template + static inline TR Operation(TA left) { + return left * (PI / 180); + } +}; + +ScalarFunction RadiansFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// isnan +//===--------------------------------------------------------------------===// +struct IsNanOperator { + template + static inline TR Operation(TA input) { + return Value::IsNan(input); + } +}; + +ScalarFunctionSet IsNanFun::GetFunctions() { + ScalarFunctionSet funcs; + funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + return funcs; +} + +//===--------------------------------------------------------------------===// +// signbit +//===--------------------------------------------------------------------===// +struct SignBitOperator { + template + static inline TR Operation(TA input) { + return std::signbit(input); + } +}; + +ScalarFunctionSet SignBitFun::GetFunctions() { + ScalarFunctionSet funcs; + funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + return funcs; +} + +//===--------------------------------------------------------------------===// +// isinf +//===--------------------------------------------------------------------===// +struct IsInfiniteOperator { + template + static inline TR Operation(TA input) { + return !Value::IsNan(input) && !Value::IsFinite(input); + } +}; + +template <> +bool IsInfiniteOperator::Operation(date_t input) { + return !Value::IsFinite(input); +} + +template <> +bool IsInfiniteOperator::Operation(timestamp_t input) { + return !Value::IsFinite(input); +} + 
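+// Note on the overloads above: for FLOAT/DOUBLE, isinf must exclude NaN explicitly, since NaN is also not finite — the operator returns true for INFINITY but false for NAN. The date_t and timestamp_t specializations only test finiteness, because those types encode +/- infinity as sentinel values and have no NaN. +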
+ScalarFunctionSet IsInfiniteFun::GetFunctions() { + ScalarFunctionSet funcs("isinf"); + funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + return funcs; +} + +//===--------------------------------------------------------------------===// +// isfinite +//===--------------------------------------------------------------------===// +struct IsFiniteOperator { + template + static inline TR Operation(TA input) { + return Value::IsFinite(input); + } +}; + +ScalarFunctionSet IsFiniteFun::GetFunctions() { + ScalarFunctionSet funcs; + funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BOOLEAN, + ScalarFunction::UnaryFunction)); + return funcs; +} + +//===--------------------------------------------------------------------===// +// sin +//===--------------------------------------------------------------------===// +template +struct NoInfiniteDoubleWrapper { + template + static RESULT_TYPE Operation(INPUT_TYPE input) { + if (DUCKDB_UNLIKELY(!Value::IsFinite(input))) { + if (Value::IsNan(input)) { + return input; + } + throw OutOfRangeException("input value %lf is out of range for numeric function", input); + } + return OP::template Operation(input); + } +}; + +struct SinOperator { + template + static inline TR Operation(TA input) { + return std::sin(input); + } +}; + +ScalarFunction SinFun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// cos +//===--------------------------------------------------------------------===// +struct CosOperator { + template + static inline TR Operation(TA input) { + return (double)std::cos(input); + } +}; + +ScalarFunction CosFun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// tan +//===--------------------------------------------------------------------===// +struct TanOperator { + template + static inline TR Operation(TA input) { + return (double)std::tan(input); + } +}; + +ScalarFunction TanFun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); + BaseScalarFunction::SetReturnsError(function); + return function; +} + 
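+// Illustrative sketch (not part of the upstream sources): NoInfiniteDoubleWrapper composes with any of the unary operators above, so a guarded variant is built the same way the sin/cos/tan registrations do it, e.g. using GuardedTan = NoInfiniteDoubleWrapper<TanOperator>; then GuardedTan::Operation(0.5) behaves as std::tan(0.5), a NaN input passes through unchanged, and +/- infinity raises OutOfRangeException before the operator runs. +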
+//===--------------------------------------------------------------------===// +// asin +//===--------------------------------------------------------------------===// +struct ASinOperator { + template + static inline TR Operation(TA input) { + if (input < -1 || input > 1) { + throw InvalidInputException("ASIN is undefined outside [-1,1]"); + } + return (double)std::asin(input); + } +}; + +ScalarFunction AsinFun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// atan +//===--------------------------------------------------------------------===// +struct ATanOperator { + template + static inline TR Operation(TA input) { + return (double)std::atan(input); + } +}; + +ScalarFunction AtanFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// atan2 +//===--------------------------------------------------------------------===// +struct ATan2 { + template + static inline TR Operation(TA left, TB right) { + return (double)std::atan2(left, right); + } +}; + +ScalarFunction Atan2Fun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::BinaryFunction); +} + +//===--------------------------------------------------------------------===// +// acos +//===--------------------------------------------------------------------===// +struct ACos { + template + static inline TR Operation(TA input) { + if (input < -1 || input > 1) { + throw InvalidInputException("ACOS is undefined outside [-1,1]"); + } + return (double)std::acos(input); + } +}; + +ScalarFunction AcosFun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// cosh +//===--------------------------------------------------------------------===// +struct CoshOperator { + template + static inline TR Operation(TA input) { + return (double)std::cosh(input); + } +}; + +ScalarFunction CoshFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// acosh +//===--------------------------------------------------------------------===// +struct AcoshOperator { + template + static inline TR Operation(TA input) { + return (double)std::acosh(input); + } +}; + +ScalarFunction AcoshFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// sinh +//===--------------------------------------------------------------------===// +struct SinhOperator { + template + static inline TR Operation(TA input) { + return (double)std::sinh(input); + } +}; + +ScalarFunction SinhFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// asinh 
+//===--------------------------------------------------------------------===// +struct AsinhOperator { + template + static inline TR Operation(TA input) { + return (double)std::asinh(input); + } +}; + +ScalarFunction AsinhFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// tanh +//===--------------------------------------------------------------------===// +struct TanhOperator { + template + static inline TR Operation(TA input) { + return (double)std::tanh(input); + } +}; + +ScalarFunction TanhFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// atanh +//===--------------------------------------------------------------------===// +struct AtanhOperator { + template + static inline TR Operation(TA input) { + if (input < -1 || input > 1) { + throw InvalidInputException("ATANH is undefined outside [-1,1]"); + } + if (input == -1 || input == 1) { + return INFINITY; + } + return (double)std::atanh(input); + } +}; + +ScalarFunction AtanhFun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// cot +//===--------------------------------------------------------------------===// +template +struct NoInfiniteNoZeroDoubleWrapper { + template + static RESULT_TYPE Operation(INPUT_TYPE input) { + if (DUCKDB_UNLIKELY(!Value::IsFinite(input))) { + if (Value::IsNan(input)) { + return input; + } + throw OutOfRangeException("input value %lf is out of range for numeric function", input); + } + if (DUCKDB_UNLIKELY((double)input == 0.0 || (double)input == -0.0)) { + throw OutOfRangeException("input value %lf is out of range for numeric function cotangent", input); + } + return OP::template Operation(input); + } +}; + +struct CotOperator { + template + static inline TR Operation(TA input) { + return 1.0 / (double)std::tan(input); + } +}; + +ScalarFunction CotFun::GetFunction() { + ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction>); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// gamma +//===--------------------------------------------------------------------===// +struct GammaOperator { + template + static inline TR Operation(TA input) { + if (input == 0) { + throw OutOfRangeException("cannot take gamma of zero"); + } + return std::tgamma(input); + } +}; + +ScalarFunction GammaFun::GetFunction() { + auto func = ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(func); + return func; +} + +//===--------------------------------------------------------------------===// +// log gamma +//===--------------------------------------------------------------------===// +struct LogGammaOperator { + template + static inline TR Operation(TA input) { + if (input == 0) { + throw OutOfRangeException("cannot take log gamma of zero"); + } + return std::lgamma(input); + } +}; + +ScalarFunction LogGammaFun::GetFunction() { + ScalarFunction 
function({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// factorial(), ! +//===--------------------------------------------------------------------===// +struct FactorialOperator { + template + static inline TR Operation(TA left) { + TR ret = 1; + for (TA i = 2; i <= left; i++) { + if (!TryMultiplyOperator::Operation(ret, TR(i), ret)) { + throw OutOfRangeException("Value out of range"); + } + } + return ret; + } +}; + +ScalarFunction FactorialOperatorFun::GetFunction() { + ScalarFunction function({LogicalType::INTEGER}, LogicalType::HUGEINT, + ScalarFunction::UnaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +//===--------------------------------------------------------------------===// +// even +//===--------------------------------------------------------------------===// +struct EvenOperator { + template + static inline TR Operation(TA left) { + double value; + if (left >= 0) { + value = std::ceil(left); + } else { + value = std::ceil(-left); + value = -value; + } + if (std::floor(value / 2) * 2 != value) { + if (left >= 0) { + return value += 1; + } + return value -= 1; + } + return value; + } +}; + +ScalarFunction EvenFun::GetFunction() { + return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, + ScalarFunction::UnaryFunction); +} + +//===--------------------------------------------------------------------===// +// gcd +//===--------------------------------------------------------------------===// + +// should be replaced with std::gcd in a newer C++ standard +template +TA GreatestCommonDivisor(TA left, TA right) { + TA a = left; + TA b = right; + + // This protects the following modulo operations from a corner case, + // where we would get a runtime error due to an integer overflow. 
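+ // Concretely: gcd(INT64_MIN, -1) would otherwise evaluate INT64_MIN % -1, whose + // implied quotient 2^63 is unrepresentable in int64_t (undefined behavior, and a + // hardware trap on x86); mathematically gcd(-2^63, -1) = 1, so that pair is + // answered directly by the early return below.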
+ if ((left == NumericLimits::Minimum() && right == -1) || + (left == -1 && right == NumericLimits::Minimum())) { + return 1; + } + + while (true) { + if (a == 0) { + return TryAbsOperator::Operation(b); + } + b %= a; + + if (b == 0) { + return TryAbsOperator::Operation(a); + } + a %= b; + } +} + +struct GreatestCommonDivisorOperator { + template + static inline TR Operation(TA left, TB right) { + return GreatestCommonDivisor(left, right); + } +}; + +ScalarFunctionSet GreatestCommonDivisorFun::GetFunctions() { + ScalarFunctionSet funcs; + funcs.AddFunction( + ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT, + ScalarFunction::BinaryFunction)); + funcs.AddFunction( + ScalarFunction({LogicalType::HUGEINT, LogicalType::HUGEINT}, LogicalType::HUGEINT, + ScalarFunction::BinaryFunction)); + return funcs; +} + +//===--------------------------------------------------------------------===// +// lcm +//===--------------------------------------------------------------------===// + +// should be replaced with std::lcm in a newer C++ standard +struct LeastCommonMultipleOperator { + template + static inline TR Operation(TA left, TB right) { + if (left == 0 || right == 0) { + return 0; + } + TR result; + if (!TryMultiplyOperator::Operation(left, right / GreatestCommonDivisor(left, right), result)) { + throw OutOfRangeException("lcm value is out of range"); + } + return TryAbsOperator::Operation(result); + } +}; + +ScalarFunctionSet LeastCommonMultipleFun::GetFunctions() { + ScalarFunctionSet funcs; + + funcs.AddFunction( + ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT, + ScalarFunction::BinaryFunction)); + funcs.AddFunction( + ScalarFunction({LogicalType::HUGEINT, LogicalType::HUGEINT}, LogicalType::HUGEINT, + ScalarFunction::BinaryFunction)); + for (auto &function : funcs.functions) { + BaseScalarFunction::SetReturnsError(function); + } + return funcs; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp b/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp new file mode 100644 index 00000000..103e0c2a --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp @@ -0,0 +1,330 @@ +#include "core_functions/scalar/operators_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/types/bit.hpp" + +namespace duckdb { + +template +static scalar_function_t GetScalarIntegerUnaryFunction(const LogicalType &type) { + scalar_function_t function; + switch (type.id()) { + case LogicalTypeId::TINYINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::SMALLINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::INTEGER: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::BIGINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UTINYINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::USMALLINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UINTEGER: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UBIGINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::HUGEINT: + function = &ScalarFunction::UnaryFunction; + break; + case LogicalTypeId::UHUGEINT: + function = &ScalarFunction::UnaryFunction; + break; + default: + throw NotImplementedException("Unimplemented 
type for GetScalarIntegerUnaryFunction"); + } + return function; +} + +template +static scalar_function_t GetScalarIntegerBinaryFunction(const LogicalType &type) { + scalar_function_t function; + switch (type.id()) { + case LogicalTypeId::TINYINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::SMALLINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::INTEGER: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::BIGINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::UTINYINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::USMALLINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::UINTEGER: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::UBIGINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::HUGEINT: + function = &ScalarFunction::BinaryFunction; + break; + case LogicalTypeId::UHUGEINT: + function = &ScalarFunction::BinaryFunction; + break; + default: + throw NotImplementedException("Unimplemented type for GetScalarIntegerBinaryFunction"); + } + return function; +} + +//===--------------------------------------------------------------------===// +// & [bitwise_and] +//===--------------------------------------------------------------------===// +struct BitwiseANDOperator { + template + static inline TR Operation(TA left, TB right) { + return left & right; + } +}; + +static void BitwiseANDOperation(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { + string_t target = StringVector::EmptyString(result, rhs.GetSize()); + + Bit::BitwiseAnd(rhs, lhs, target); + return target; + }); +} + +ScalarFunctionSet BitwiseAndFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction( + ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); + } + functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseANDOperation)); + for (auto &function : functions.functions) { + BaseScalarFunction::SetReturnsError(function); + } + return functions; +} + +//===--------------------------------------------------------------------===// +// | [bitwise_or] +//===--------------------------------------------------------------------===// +struct BitwiseOROperator { + template + static inline TR Operation(TA left, TB right) { + return left | right; + } +}; + +static void BitwiseOROperation(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { + string_t target = StringVector::EmptyString(result, rhs.GetSize()); + + Bit::BitwiseOr(rhs, lhs, target); + return target; + }); +} + +ScalarFunctionSet BitwiseOrFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction( + ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); + } + functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseOROperation)); + for (auto &function : functions.functions) { + BaseScalarFunction::SetReturnsError(function); + } + return functions; +} + +//===--------------------------------------------------------------------===// +// # [bitwise_xor] 
+//===--------------------------------------------------------------------===// +struct BitwiseXOROperator { + template + static inline TR Operation(TA left, TB right) { + return left ^ right; + } +}; + +static void BitwiseXOROperation(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { + string_t target = StringVector::EmptyString(result, rhs.GetSize()); + + Bit::BitwiseXor(rhs, lhs, target); + return target; + }); +} + +ScalarFunctionSet BitwiseXorFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction( + ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); + } + functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseXOROperation)); + for (auto &function : functions.functions) { + BaseScalarFunction::SetReturnsError(function); + } + return functions; +} + +//===--------------------------------------------------------------------===// +// ~ [bitwise_not] +//===--------------------------------------------------------------------===// +struct BitwiseNotOperator { + template + static inline TR Operation(TA input) { + return ~input; + } +}; + +static void BitwiseNOTOperation(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { + string_t target = StringVector::EmptyString(result, input.GetSize()); + + Bit::BitwiseNot(input, target); + return target; + }); +} + +ScalarFunctionSet BitwiseNotFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction(ScalarFunction({type}, type, GetScalarIntegerUnaryFunction(type))); + } + functions.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIT, BitwiseNOTOperation)); + for (auto &function : functions.functions) { + BaseScalarFunction::SetReturnsError(function); + } + return functions; +} + +//===--------------------------------------------------------------------===// +// << [bitwise_left_shift] +//===--------------------------------------------------------------------===// +struct BitwiseShiftLeftOperator { + template + static inline TR Operation(TA input, TB shift) { + TA max_shift = TA(sizeof(TA) * 8) + (NumericLimits::IsSigned() ? 
0 : 1); + if (input < 0) { + throw OutOfRangeException("Cannot left-shift negative number %s", NumericHelper::ToString(input)); + } + if (shift < 0) { + throw OutOfRangeException("Cannot left-shift by negative number %s", NumericHelper::ToString(shift)); + } + if (shift >= max_shift) { + if (input == 0) { + return 0; + } + throw OutOfRangeException("Left-shift value %s is out of range", NumericHelper::ToString(shift)); + } + if (shift == 0) { + return input; + } + TA max_value = UnsafeNumericCast((TA(1) << (max_shift - shift - 1))); + if (input >= max_value) { + throw OutOfRangeException("Overflow in left shift (%s << %s)", NumericHelper::ToString(input), + NumericHelper::ToString(shift)); + } + return UnsafeNumericCast(input << shift); + } +}; + +static void BitwiseShiftLeftOperation(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t shift) { + auto max_shift = UnsafeNumericCast(Bit::BitLength(input)); + if (shift == 0) { + return input; + } + if (shift < 0) { + throw OutOfRangeException("Cannot left-shift by negative number %s", NumericHelper::ToString(shift)); + } + string_t target = StringVector::EmptyString(result, input.GetSize()); + + if (shift >= max_shift) { + Bit::SetEmptyBitString(target, input); + return target; + } + Bit::LeftShift(input, UnsafeNumericCast(shift), target); + return target; + }); +} + +ScalarFunctionSet LeftShiftFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction( + ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); + } + functions.AddFunction( + ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftLeftOperation)); + for (auto &function : functions.functions) { + BaseScalarFunction::SetReturnsError(function); + } + return functions; +} + +//===--------------------------------------------------------------------===// +// >> [bitwise_right_shift] +//===--------------------------------------------------------------------===// +template +bool RightShiftInRange(T shift) { + return shift >= 0 && shift < T(sizeof(T) * 8); +} + +struct BitwiseShiftRightOperator { + template + static inline TR Operation(TA input, TB shift) { + return RightShiftInRange(shift) ? 
input >> shift : 0; + } +}; + +static void BitwiseShiftRightOperation(DataChunk &args, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t shift) { + auto max_shift = UnsafeNumericCast(Bit::BitLength(input)); + if (shift == 0) { + return input; + } + string_t target = StringVector::EmptyString(result, input.GetSize()); + if (shift < 0 || shift >= max_shift) { + Bit::SetEmptyBitString(target, input); + return target; + } + Bit::RightShift(input, UnsafeNumericCast(shift), target); + return target; + }); +} + +ScalarFunctionSet RightShiftFun::GetFunctions() { + ScalarFunctionSet functions; + for (auto &type : LogicalType::Integral()) { + functions.AddFunction( + ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); + } + functions.AddFunction( + ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftRightOperation)); + for (auto &function : functions.functions) { + BaseScalarFunction::SetReturnsError(function); + } + return functions; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/random/random.cpp b/src/duckdb/extension/core_functions/scalar/random/random.cpp new file mode 100644 index 00000000..3054170f --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/random/random.cpp @@ -0,0 +1,64 @@ +#include "core_functions/scalar/random_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/random_engine.hpp" +#include "duckdb/common/types/uuid.hpp" + +namespace duckdb { + +struct RandomLocalState : public FunctionLocalState { + explicit RandomLocalState(uint64_t seed) : random_engine(0) { + random_engine.SetSeed(seed); + } + + RandomEngine random_engine; +}; + +static void RandomFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 0); + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < args.size(); i++) { + result_data[i] = lstate.random_engine.NextRandom(); + } +} + +static unique_ptr RandomInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data) { + auto &random_engine = RandomEngine::Get(state.GetContext()); + lock_guard guard(random_engine.lock); + return make_uniq(random_engine.NextRandomInteger64()); +} + +ScalarFunction RandomFun::GetFunction() { + ScalarFunction random("random", {}, LogicalType::DOUBLE, RandomFunction, nullptr, nullptr, nullptr, + RandomInitLocalState); + random.stability = FunctionStability::VOLATILE; + return random; +} + +static void GenerateUUIDFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 0); + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetData(result); + + for (idx_t i = 0; i < args.size(); i++) { + result_data[i] = UUID::GenerateRandomUUID(lstate.random_engine); + } +} + +ScalarFunction UUIDFun::GetFunction() { + ScalarFunction uuid_function({}, LogicalType::UUID, GenerateUUIDFunction, nullptr, nullptr, nullptr, + RandomInitLocalState); + // generate a random uuid + 
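+ // VOLATILE marks the result as changing on every invocation, which keeps the + // optimizer from constant-folding or caching the call, so each row gets a fresh UUID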
uuid_function.stability = FunctionStability::VOLATILE; + return uuid_function; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/random/setseed.cpp b/src/duckdb/extension/core_functions/scalar/random/setseed.cpp new file mode 100644 index 00000000..ca286528 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/random/setseed.cpp @@ -0,0 +1,62 @@ +#include "core_functions/scalar/random_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/random_engine.hpp" + +namespace duckdb { + +struct SetseedBindData : public FunctionData { + //! The client context for the function call + ClientContext &context; + + explicit SetseedBindData(ClientContext &context) : context(context) { + } + + unique_ptr Copy() const override { + return make_uniq(context); + } + + bool Equals(const FunctionData &other_p) const override { + return true; + } +}; + +static void SetSeedFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto &input = args.data[0]; + input.Flatten(args.size()); + + auto input_seeds = FlatVector::GetData(input); + uint32_t half_max = NumericLimits::Maximum() / 2; + + auto &random_engine = RandomEngine::Get(info.context); + for (idx_t i = 0; i < args.size(); i++) { + if (input_seeds[i] < -1.0 || input_seeds[i] > 1.0 || Value::IsNan(input_seeds[i])) { + throw InvalidInputException("SETSEED accepts seed values between -1.0 and 1.0, inclusive"); + } + auto norm_seed = LossyNumericCast((input_seeds[i] + 1.0) * half_max); + random_engine.SetSeed(norm_seed); + } + + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); +} + +unique_ptr SetSeedBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + return make_uniq(context); +} + +ScalarFunction SetseedFun::GetFunction() { + ScalarFunction setseed("setseed", {LogicalType::DOUBLE}, LogicalType::SQLNULL, SetSeedFunction, SetSeedBind); + setseed.stability = FunctionStability::VOLATILE; + BaseScalarFunction::SetReturnsError(setseed); + return setseed; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/ascii.cpp b/src/duckdb/extension/core_functions/scalar/string/ascii.cpp new file mode 100644 index 00000000..4083c85d --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/ascii.cpp @@ -0,0 +1,24 @@ +#include "core_functions/scalar/string_functions.hpp" +#include "utf8proc.hpp" +#include "utf8proc_wrapper.hpp" + +namespace duckdb { + +struct AsciiOperator { + template + static inline TR Operation(const TA &input) { + auto str = input.GetData(); + if (Utf8Proc::Analyze(str, input.GetSize()) == UnicodeType::ASCII) { + return str[0]; + } + int utf8_bytes = 4; + return Utf8Proc::UTF8ToCodepoint(str, utf8_bytes); + } +}; + +ScalarFunction ASCIIFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR}, LogicalType::INTEGER, + ScalarFunction::UnaryFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/bar.cpp b/src/duckdb/extension/core_functions/scalar/string/bar.cpp new file mode 100644 index 00000000..957b8c62 --- /dev/null +++ 
b/src/duckdb/extension/core_functions/scalar/string/bar.cpp @@ -0,0 +1,100 @@ +#include "core_functions/scalar/string_functions.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/unicode_bar.hpp" +#include "duckdb/common/vector_operations/generic_executor.hpp" + +namespace duckdb { + +static string_t BarScalarFunction(double x, double min, double max, double max_width, string &result) { + static const char *FULL_BLOCK = UnicodeBar::FullBlock(); + static const char *const *PARTIAL_BLOCKS = UnicodeBar::PartialBlocks(); + static const idx_t PARTIAL_BLOCKS_COUNT = UnicodeBar::PartialBlocksCount(); + + if (!Value::IsFinite(max_width)) { + throw OutOfRangeException("Max bar width must not be NaN or infinity"); + } + if (max_width < 1) { + throw OutOfRangeException("Max bar width must be >= 1"); + } + if (max_width > 1000) { + throw OutOfRangeException("Max bar width must be <= 1000"); + } + + double width; + + if (Value::IsNan(x) || Value::IsNan(min) || Value::IsNan(max) || x <= min) { + width = 0; + } else if (x >= max) { + width = max_width; + } else { + width = max_width * (x - min) / (max - min); + } + + if (!Value::IsFinite(width)) { + throw OutOfRangeException("Bar width must not be NaN or infinity"); + } + + result.clear(); + idx_t used_blocks = 0; + + auto width_as_int = LossyNumericCast(width * PARTIAL_BLOCKS_COUNT); + idx_t full_blocks_count = (width_as_int / PARTIAL_BLOCKS_COUNT); + for (idx_t i = 0; i < full_blocks_count; i++) { + used_blocks++; + result += FULL_BLOCK; + } + + idx_t remaining = width_as_int % PARTIAL_BLOCKS_COUNT; + + if (remaining) { + used_blocks++; + result += PARTIAL_BLOCKS[remaining]; + } + + const idx_t integer_max_width = (idx_t)max_width; + if (used_blocks < integer_max_width) { + result += std::string(integer_max_width - used_blocks, ' '); + } + return string_t(result); +} + +static void BarFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 3 || args.ColumnCount() == 4); + auto &x_arg = args.data[0]; + auto &min_arg = args.data[1]; + auto &max_arg = args.data[2]; + string buffer; + + if (args.ColumnCount() == 3) { + GenericExecutor::ExecuteTernary, PrimitiveType, PrimitiveType, + PrimitiveType>( + x_arg, min_arg, max_arg, result, args.size(), + [&](PrimitiveType x, PrimitiveType min, PrimitiveType max) { + return StringVector::AddString(result, BarScalarFunction(x.val, min.val, max.val, 80, buffer)); + }); + } else { + auto &width_arg = args.data[3]; + GenericExecutor::ExecuteQuaternary, PrimitiveType, PrimitiveType, + PrimitiveType, PrimitiveType>( + x_arg, min_arg, max_arg, width_arg, result, args.size(), + [&](PrimitiveType x, PrimitiveType min, PrimitiveType max, + PrimitiveType width) { + return StringVector::AddString(result, BarScalarFunction(x.val, min.val, max.val, width.val, buffer)); + }); + } +} + +ScalarFunctionSet BarFun::GetFunctions() { + ScalarFunctionSet bar; + bar.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE}, + LogicalType::VARCHAR, BarFunction)); + bar.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE}, + LogicalType::VARCHAR, BarFunction)); + return bar; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/chr.cpp 
b/src/duckdb/extension/core_functions/scalar/string/chr.cpp new file mode 100644 index 00000000..bca2de6d --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/chr.cpp @@ -0,0 +1,48 @@ +#include "core_functions/scalar/string_functions.hpp" +#include "utf8proc.hpp" +#include "utf8proc_wrapper.hpp" + +namespace duckdb { + +struct ChrOperator { + static void GetCodepoint(int32_t input, char c[], int &utf8_bytes) { + if (input < 0 || !Utf8Proc::CodepointToUtf8(input, utf8_bytes, &c[0])) { + throw InvalidInputException("Invalid UTF8 Codepoint %d", input); + } + } + + template + static inline TR Operation(const TA &input) { + char c[5] = {'\0', '\0', '\0', '\0', '\0'}; + int utf8_bytes; + GetCodepoint(input, c, utf8_bytes); + return string_t(&c[0], UnsafeNumericCast(utf8_bytes)); + } +}; + +#ifdef DUCKDB_DEBUG_NO_INLINE +// the chr function depends on the data always being inlined (which is always possible, since it outputs max 4 bytes) +// to enable chr when string inlining is disabled we create a special function here +static void ChrFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &code_vec = args.data[0]; + + char c[5] = {'\0', '\0', '\0', '\0', '\0'}; + int utf8_bytes; + UnaryExecutor::Execute(code_vec, result, args.size(), [&](int32_t input) { + ChrOperator::GetCodepoint(input, c, utf8_bytes); + return StringVector::AddString(result, &c[0], UnsafeNumericCast(utf8_bytes)); + }); +} +#endif + +ScalarFunction ChrFun::GetFunction() { + return ScalarFunction("chr", {LogicalType::INTEGER}, LogicalType::VARCHAR, +#ifdef DUCKDB_DEBUG_NO_INLINE + ChrFunction +#else + ScalarFunction::UnaryFunction +#endif + ); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/damerau_levenshtein.cpp b/src/duckdb/extension/core_functions/scalar/string/damerau_levenshtein.cpp new file mode 100644 index 00000000..91b0fbd3 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/damerau_levenshtein.cpp @@ -0,0 +1,104 @@ +#include "core_functions/scalar/string_functions.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +// Using Lowrance-Wagner (LW) algorithm: https://doi.org/10.1145%2F321879.321880 +// Can't calculate as trivial modification to levenshtein algorithm +// as we need to potentially know about earlier in the string +static idx_t DamerauLevenshteinDistance(const string_t &source, const string_t &target) { + // costs associated with each type of edit, to aid readability + constexpr uint8_t COST_SUBSTITUTION = 1; + constexpr uint8_t COST_INSERTION = 1; + constexpr uint8_t COST_DELETION = 1; + constexpr uint8_t COST_TRANSPOSITION = 1; + const auto source_len = source.GetSize(); + const auto target_len = target.GetSize(); + + // If one string is empty, the distance equals the length of the other string + // either through target_len insertions + // or source_len deletions + if (source_len == 0) { + return target_len * COST_INSERTION; + } else if (target_len == 0) { + return source_len * COST_DELETION; + } + + const auto source_str = source.GetData(); + const auto target_str = target.GetData(); + + // larger than the largest possible value: + const auto inf = source_len * COST_DELETION + target_len * COST_INSERTION + 1; + // minimum edit distance from prefix of source string to prefix of target string + // same object as H in LW paper (with indices offset by 1) + vector> distance(source_len + 2, vector(target_len + 2, inf)); + // keeps track of the largest string 
+	// keeps track of the largest string indices of the source string matching each character
+	// same as DA in the LW paper
+	map<char, idx_t> largest_source_chr_matching;
+
+	// initialise row/column corresponding to zero-length strings
+	// partial string -> empty requires a deletion for each character
+	for (idx_t source_idx = 0; source_idx <= source_len; source_idx++) {
+		distance[source_idx + 1][1] = source_idx * COST_DELETION;
+	}
+	// and empty -> partial string means simply inserting characters
+	for (idx_t target_idx = 1; target_idx <= target_len; target_idx++) {
+		distance[1][target_idx + 1] = target_idx * COST_INSERTION;
+	}
+	// loop through string indices - these are offset by 2 from distance indices
+	for (idx_t source_idx = 0; source_idx < source_len; source_idx++) {
+		// keeps track of the largest string indices of the target string matching the current source character
+		// same as DB in the LW paper
+		idx_t largest_target_chr_matching;
+		largest_target_chr_matching = 0;
+		for (idx_t target_idx = 0; target_idx < target_len; target_idx++) {
+			// correspond to i1 and j1 in the LW paper respectively
+			idx_t largest_source_chr_matching_target;
+			idx_t largest_target_chr_matching_source;
+			// cost associated with a diagonal shift in the distance matrix
+			// corresponds to d in the LW paper
+			uint8_t cost_diagonal_shift;
+			largest_source_chr_matching_target = largest_source_chr_matching[target_str[target_idx]];
+			largest_target_chr_matching_source = largest_target_chr_matching;
+			// if the characters match, the diagonal move costs nothing and we update our largest target index
+			// otherwise the move is a substitution and costs as such
+			if (source_str[source_idx] == target_str[target_idx]) {
+				cost_diagonal_shift = 0;
+				largest_target_chr_matching = target_idx + 1;
+			} else {
+				cost_diagonal_shift = COST_SUBSTITUTION;
+			}
+			distance[source_idx + 2][target_idx + 2] = MinValue(
+			    distance[source_idx + 1][target_idx + 1] + cost_diagonal_shift,
+			    MinValue(distance[source_idx + 2][target_idx + 1] + COST_INSERTION,
+			             MinValue(distance[source_idx + 1][target_idx + 2] + COST_DELETION,
+			                      distance[largest_source_chr_matching_target][largest_target_chr_matching_source] +
+			                          (source_idx - largest_source_chr_matching_target) * COST_DELETION +
+			                          COST_TRANSPOSITION +
+			                          (target_idx - largest_target_chr_matching_source) * COST_INSERTION)));
+		}
+		largest_source_chr_matching[source_str[source_idx]] = source_idx + 1;
+	}
+	return distance[source_len + 1][target_len + 1];
+}
+
+static int64_t DamerauLevenshteinScalarFunction(Vector &result, const string_t source, const string_t target) {
+	return (int64_t)DamerauLevenshteinDistance(source, target);
+}
+
+static void DamerauLevenshteinFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+	auto &source_vec = args.data[0];
+	auto &target_vec = args.data[1];
+
+	BinaryExecutor::Execute<string_t, string_t, int64_t>(
+	    source_vec, target_vec, result, args.size(),
+	    [&](string_t source, string_t target) { return DamerauLevenshteinScalarFunction(result, source, target); });
+}
+
+ScalarFunction DamerauLevenshteinFun::GetFunction() {
+	return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT,
+	                      DamerauLevenshteinFunction);
+}
+
+} // namespace duckdb
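A usage-level check of the transposition cost, assuming the registration above: damerau_levenshtein('CA', 'ABC') returns 2 (CA -> AC -> ABC, with the swap charged once as COST_TRANSPOSITION), while the plain levenshtein('CA', 'ABC') further down in this patch returns 3. Note also that this implementation allocates the full (source_len + 2) x (target_len + 2) distance matrix, unlike the two-row Levenshtein implementation below.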
"duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +template +static void FormatBytesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute(args.data[0], result, args.size(), [&](int64_t bytes) { + bool is_negative = bytes < 0; + idx_t unsigned_bytes; + if (bytes < 0) { + if (bytes == NumericLimits::Minimum()) { + unsigned_bytes = idx_t(NumericLimits::Maximum()) + 1; + } else { + unsigned_bytes = idx_t(-bytes); + } + } else { + unsigned_bytes = idx_t(bytes); + } + return StringVector::AddString(result, (is_negative ? "-" : "") + + StringUtil::BytesToHumanReadableString(unsigned_bytes, MULTIPLIER)); + }); +} + +ScalarFunction FormatBytesFun::GetFunction() { + return ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, FormatBytesFunction<1024>); +} + +ScalarFunction FormatreadabledecimalsizeFun::GetFunction() { + return ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, FormatBytesFunction<1000>); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/hamming.cpp b/src/duckdb/extension/core_functions/scalar/string/hamming.cpp new file mode 100644 index 00000000..b32a8019 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/hamming.cpp @@ -0,0 +1,45 @@ +#include "core_functions/scalar/string_functions.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" + +#include +#include + +namespace duckdb { + +static int64_t MismatchesScalarFunction(Vector &result, const string_t str, string_t tgt) { + idx_t str_len = str.GetSize(); + idx_t tgt_len = tgt.GetSize(); + + if (str_len != tgt_len) { + throw InvalidInputException("Mismatch Function: Strings must be of equal length!"); + } + if (str_len < 1) { + throw InvalidInputException("Mismatch Function: Strings must be of length > 0!"); + } + + idx_t mismatches = 0; + auto str_str = str.GetData(); + auto tgt_str = tgt.GetData(); + + for (idx_t idx = 0; idx < str_len; ++idx) { + if (str_str[idx] != tgt_str[idx]) { + mismatches++; + } + } + return (int64_t)mismatches; +} + +static void MismatchesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vec = args.data[0]; + auto &tgt_vec = args.data[1]; + + BinaryExecutor::Execute( + str_vec, tgt_vec, result, args.size(), + [&](string_t str, string_t tgt) { return MismatchesScalarFunction(result, str, tgt); }); +} + +ScalarFunction HammingFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, MismatchesFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/hex.cpp b/src/duckdb/extension/core_functions/scalar/string/hex.cpp new file mode 100644 index 00000000..cbf541e1 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/hex.cpp @@ -0,0 +1,440 @@ +#include "duckdb/common/bit_utils.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/numeric_utils.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/blob.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "core_functions/scalar/string_functions.hpp" + +namespace duckdb { + +static void WriteHexBytes(uint64_t x, char *&output, idx_t buffer_size) { + idx_t offset = buffer_size * 4; + + for (; offset >= 4; offset -= 4) { + uint8_t byte = (x >> 
(offset - 4)) & 0x0F; + *output = Blob::HEX_TABLE[byte]; + output++; + } +} + +template +static void WriteHugeIntHexBytes(T x, char *&output, idx_t buffer_size) { + idx_t offset = buffer_size * 4; + auto upper = x.upper; + auto lower = x.lower; + + for (; offset >= 68; offset -= 4) { + uint8_t byte = (upper >> (offset - 68)) & 0x0F; + *output = Blob::HEX_TABLE[byte]; + output++; + } + + for (; offset >= 4; offset -= 4) { + uint8_t byte = (lower >> (offset - 4)) & 0x0F; + *output = Blob::HEX_TABLE[byte]; + output++; + } +} + +static void WriteBinBytes(uint64_t x, char *&output, idx_t buffer_size) { + idx_t offset = buffer_size; + for (; offset >= 1; offset -= 1) { + *output = NumericCast(((x >> (offset - 1)) & 0x01) + '0'); + output++; + } +} + +template +static void WriteHugeIntBinBytes(T x, char *&output, idx_t buffer_size) { + auto upper = x.upper; + auto lower = x.lower; + idx_t offset = buffer_size; + + for (; offset >= 65; offset -= 1) { + *output = ((upper >> (offset - 65)) & 0x01) + '0'; + output++; + } + + for (; offset >= 1; offset -= 1) { + *output = ((lower >> (offset - 1)) & 0x01) + '0'; + output++; + } +} + +struct HexStrOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto data = input.GetData(); + auto size = input.GetSize(); + + // Allocate empty space + auto target = StringVector::EmptyString(result, size * 2); + auto output = target.GetDataWriteable(); + + for (idx_t i = 0; i < size; ++i) { + *output = Blob::HEX_TABLE[(data[i] >> 4) & 0x0F]; + output++; + *output = Blob::HEX_TABLE[data[i] & 0x0F]; + output++; + } + + target.Finalize(); + return target; + } +}; + +struct HexIntegralOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + + auto num_leading_zero = CountZeros::Leading(static_cast(input)); + idx_t num_bits_to_check = 64 - num_leading_zero; + D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); + + idx_t buffer_size = (num_bits_to_check + 3) / 4; + + // Special case: All bits are zero + if (buffer_size == 0) { + auto target = StringVector::EmptyString(result, 1); + auto output = target.GetDataWriteable(); + *output = '0'; + target.Finalize(); + return target; + } + + D_ASSERT(buffer_size > 0); + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + WriteHexBytes(static_cast(input), output, buffer_size); + + target.Finalize(); + return target; + } +}; + +struct HexHugeIntOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + + idx_t num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); + idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); + + // Special case: All bits are zero + if (buffer_size == 0) { + auto target = StringVector::EmptyString(result, 1); + auto output = target.GetDataWriteable(); + *output = '0'; + target.Finalize(); + return target; + } + + D_ASSERT(buffer_size > 0); + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + WriteHugeIntHexBytes(input, output, buffer_size); + + target.Finalize(); + return target; + } +}; + +struct HexUhugeIntOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + + idx_t num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); + idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); + + // Special case: All bits are zero + if (buffer_size == 0) { + auto target = StringVector::EmptyString(result, 1); + auto 
output = target.GetDataWriteable(); + *output = '0'; + target.Finalize(); + return target; + } + + D_ASSERT(buffer_size > 0); + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + WriteHugeIntHexBytes(input, output, buffer_size); + + target.Finalize(); + return target; + } +}; + +template +static void ToHexFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + auto &input = args.data[0]; + idx_t count = args.size(); + UnaryExecutor::ExecuteString(input, result, count); +} + +struct BinaryStrOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto data = input.GetData(); + auto size = input.GetSize(); + + // Allocate empty space + auto target = StringVector::EmptyString(result, size * 8); + auto output = target.GetDataWriteable(); + + for (idx_t i = 0; i < size; ++i) { + auto byte = static_cast(data[i]); + for (idx_t i = 8; i >= 1; --i) { + *output = ((byte >> (i - 1)) & 0x01) + '0'; + output++; + } + } + + target.Finalize(); + return target; + } +}; + +struct BinaryIntegralOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + + auto num_leading_zero = CountZeros::Leading(static_cast(input)); + idx_t num_bits_to_check = 64 - num_leading_zero; + D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); + + idx_t buffer_size = num_bits_to_check; + + // Special case: All bits are zero + if (buffer_size == 0) { + auto target = StringVector::EmptyString(result, 1); + auto output = target.GetDataWriteable(); + *output = '0'; + target.Finalize(); + return target; + } + + D_ASSERT(buffer_size > 0); + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + WriteBinBytes(static_cast(input), output, buffer_size); + + target.Finalize(); + return target; + } +}; + +struct BinaryHugeIntOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); + idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; + + // Special case: All bits are zero + if (buffer_size == 0) { + auto target = StringVector::EmptyString(result, 1); + auto output = target.GetDataWriteable(); + *output = '0'; + target.Finalize(); + return target; + } + + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + WriteHugeIntBinBytes(input, output, buffer_size); + + target.Finalize(); + return target; + } +}; + +struct BinaryUhugeIntOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); + idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; + + // Special case: All bits are zero + if (buffer_size == 0) { + auto target = StringVector::EmptyString(result, 1); + auto output = target.GetDataWriteable(); + *output = '0'; + target.Finalize(); + return target; + } + + auto target = StringVector::EmptyString(result, buffer_size); + auto output = target.GetDataWriteable(); + + WriteHugeIntBinBytes(input, output, buffer_size); + + target.Finalize(); + return target; + } +}; + +struct FromHexOperator { + template + static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { + auto data = input.GetData(); + auto size = input.GetSize(); + + if (size > NumericLimits::Maximum()) { + throw InvalidInputException("Hexadecimal input length larger 
than 2^32 is not supported");
+		}
+
+		D_ASSERT(size <= NumericLimits<uint32_t>::Maximum());
+		auto buffer_size = (size + 1) / 2;
+
+		// Allocate empty space
+		auto target = StringVector::EmptyString(result, buffer_size);
+		auto output = target.GetDataWriteable();
+
+		// An odd number of digits: treat the leading digit as a single byte
+		idx_t i = 0;
+		if (size % 2 != 0) {
+			*output = static_cast<char>(StringUtil::GetHexValue(data[i]));
+			i++;
+			output++;
+		}
+
+		for (; i < size; i += 2) {
+			uint8_t major = StringUtil::GetHexValue(data[i]);
+			uint8_t minor = StringUtil::GetHexValue(data[i + 1]);
+			*output = static_cast<char>((major << 4) | minor);
+			output++;
+		}
+
+		target.Finalize();
+		return target;
+	}
+};
+
+struct FromBinaryOperator {
+	template <class INPUT_TYPE, class RESULT_TYPE>
+	static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) {
+		auto data = input.GetData();
+		auto size = input.GetSize();
+
+		if (size > NumericLimits<uint32_t>::Maximum()) {
+			throw InvalidInputException("Binary input length larger than 2^32 is not supported");
+		}
+
+		D_ASSERT(size <= NumericLimits<uint32_t>::Maximum());
+		auto buffer_size = (size + 7) / 8;
+
+		// Allocate empty space
+		auto target = StringVector::EmptyString(result, buffer_size);
+		auto output = target.GetDataWriteable();
+
+		// Pack the leading (size % 8) bits into a partial first byte
+		idx_t i = 0;
+		if (size % 8 != 0) {
+			uint8_t byte = 0;
+			for (idx_t j = size % 8; j > 0; --j) {
+				byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1);
+				i++;
+			}
+			*output = static_cast<char>(byte);
+			output++;
+		}
+
+		while (i < size) {
+			uint8_t byte = 0;
+			for (idx_t j = 8; j > 0; --j) {
+				byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1);
+				i++;
+			}
+			*output = static_cast<char>(byte);
+			output++;
+		}
+
+		target.Finalize();
+		return target;
+	}
+};
+
+template <class INPUT, class OP>
+static void ToBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+	D_ASSERT(args.ColumnCount() == 1);
+	auto &input = args.data[0];
+	idx_t count = args.size();
+	UnaryExecutor::ExecuteString<INPUT, string_t, OP>(input, result, count);
+}
+
+static void FromBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+	D_ASSERT(args.ColumnCount() == 1);
+	D_ASSERT(args.data[0].GetType().InternalType() == PhysicalType::VARCHAR);
+	auto &input = args.data[0];
+	idx_t count = args.size();
+
+	UnaryExecutor::ExecuteString<string_t, string_t, FromBinaryOperator>(input, result, count);
+}
+
+static void FromHexFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+	D_ASSERT(args.ColumnCount() == 1);
+	D_ASSERT(args.data[0].GetType().InternalType() == PhysicalType::VARCHAR);
+	auto &input = args.data[0];
+	idx_t count = args.size();
+
+	UnaryExecutor::ExecuteString<string_t, string_t, FromHexOperator>(input, result, count);
+}
+
+ScalarFunctionSet HexFun::GetFunctions() {
+	ScalarFunctionSet to_hex;
+	to_hex.AddFunction(
+	    ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToHexFunction<string_t, HexStrOperator>));
+	to_hex.AddFunction(
+	    ScalarFunction({LogicalType::VARINT}, LogicalType::VARCHAR, ToHexFunction<string_t, HexStrOperator>));
+	to_hex.AddFunction(
+	    ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, ToHexFunction<string_t, HexStrOperator>));
+	to_hex.AddFunction(
+	    ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, ToHexFunction<int64_t, HexIntegralOperator>));
+	to_hex.AddFunction(
+	    ScalarFunction({LogicalType::UBIGINT}, LogicalType::VARCHAR, ToHexFunction<uint64_t, HexIntegralOperator>));
+	to_hex.AddFunction(
+	    ScalarFunction({LogicalType::HUGEINT}, LogicalType::VARCHAR, ToHexFunction<hugeint_t, HexHugeIntOperator>));
+	to_hex.AddFunction(
+	    ScalarFunction({LogicalType::UHUGEINT}, LogicalType::VARCHAR, ToHexFunction<uhugeint_t, HexUhugeIntOperator>));
+	return to_hex;
+}
+
+ScalarFunction UnhexFun::GetFunction() {
+	ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, FromHexFunction);
+
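+	// unhex throws (via StringUtil::GetHexValue in FromHexOperator above) whenever the input
+	// contains a character outside [0-9a-fA-F], hence the error-returning registration: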
BaseScalarFunction::SetReturnsError(function); + return function; +} + +ScalarFunctionSet BinFun::GetFunctions() { + ScalarFunctionSet to_binary; + + to_binary.AddFunction( + ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToBinaryFunction)); + to_binary.AddFunction( + ScalarFunction({LogicalType::VARINT}, LogicalType::VARCHAR, ToBinaryFunction)); + to_binary.AddFunction(ScalarFunction({LogicalType::UBIGINT}, LogicalType::VARCHAR, + ToBinaryFunction)); + to_binary.AddFunction( + ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, ToBinaryFunction)); + to_binary.AddFunction(ScalarFunction({LogicalType::HUGEINT}, LogicalType::VARCHAR, + ToBinaryFunction)); + to_binary.AddFunction(ScalarFunction({LogicalType::UHUGEINT}, LogicalType::VARCHAR, + ToBinaryFunction)); + return to_binary; +} + +ScalarFunction UnbinFun::GetFunction() { + ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, FromBinaryFunction); + BaseScalarFunction::SetReturnsError(function); + return function; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/instr.cpp b/src/duckdb/extension/core_functions/scalar/string/instr.cpp new file mode 100644 index 00000000..77539e7c --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/instr.cpp @@ -0,0 +1,58 @@ +#include "core_functions/scalar/string_functions.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/scalar/string_common.hpp" +#include "utf8proc.hpp" + +namespace duckdb { + +struct InstrOperator { + template + static inline TR Operation(TA haystack, TB needle) { + int64_t string_position = 0; + + auto location = FindStrInStr(haystack, needle); + if (location != DConstants::INVALID_INDEX) { + auto len = (utf8proc_ssize_t)location; + auto str = reinterpret_cast(haystack.GetData()); + D_ASSERT(len <= (utf8proc_ssize_t)haystack.GetSize()); + for (++string_position; len > 0; ++string_position) { + utf8proc_int32_t codepoint; + auto bytes = utf8proc_iterate(str, len, &codepoint); + str += bytes; + len -= bytes; + } + } + return string_position; + } +}; + +struct InstrAsciiOperator { + template + static inline TR Operation(TA haystack, TB needle) { + auto location = FindStrInStr(haystack, needle); + return UnsafeNumericCast(location == DConstants::INVALID_INDEX ? 
0U : location + 1U); + } +}; + +static unique_ptr InStrPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + D_ASSERT(child_stats.size() == 2); + // can only propagate stats if the children have stats + // for strpos, we only care if the FIRST string has unicode or not + if (!StringStats::CanContainUnicode(child_stats[0])) { + expr.function.function = ScalarFunction::BinaryFunction; + } + return nullptr; +} + +ScalarFunction InstrFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, + ScalarFunction::BinaryFunction, nullptr, nullptr, + InStrPropagateStats); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/jaccard.cpp b/src/duckdb/extension/core_functions/scalar/string/jaccard.cpp new file mode 100644 index 00000000..eae31dc9 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/jaccard.cpp @@ -0,0 +1,58 @@ +#include "duckdb/common/map.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "core_functions/scalar/string_functions.hpp" + +#include +#include + +namespace duckdb { + +namespace { +constexpr size_t MAX_SIZE = std::numeric_limits::max() + 1; +} + +static inline std::bitset GetSet(const string_t &str) { + std::bitset array_set; + + idx_t str_len = str.GetSize(); + auto s = str.GetData(); + + for (idx_t pos = 0; pos < str_len; pos++) { + array_set.set(static_cast(s[pos])); + } + return array_set; +} + +static double JaccardSimilarity(const string_t &str, const string_t &txt) { + if (str.GetSize() < 1 || txt.GetSize() < 1) { + throw InvalidInputException("Jaccard Function: An argument too short!"); + } + std::bitset m_str, m_txt; + + m_str = GetSet(str); + m_txt = GetSet(txt); + + idx_t size_intersect = (m_str & m_txt).count(); + idx_t size_union = (m_str | m_txt).count(); + + return static_cast(size_intersect) / static_cast(size_union); +} + +static double JaccardScalarFunction(Vector &result, const string_t str, string_t tgt) { + return (double)JaccardSimilarity(str, tgt); +} + +static void JaccardFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vec = args.data[0]; + auto &tgt_vec = args.data[1]; + + BinaryExecutor::Execute( + str_vec, tgt_vec, result, args.size(), + [&](string_t str, string_t tgt) { return JaccardScalarFunction(result, str, tgt); }); +} + +ScalarFunction JaccardFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaccardFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/jaro_winkler.cpp b/src/duckdb/extension/core_functions/scalar/string/jaro_winkler.cpp new file mode 100644 index 00000000..13db07c7 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/jaro_winkler.cpp @@ -0,0 +1,112 @@ +#include "jaro_winkler.hpp" + +#include "core_functions/scalar/string_functions.hpp" + +namespace duckdb { + +static inline double JaroScalarFunction(const string_t &s1, const string_t &s2, const double_t &score_cutoff = 0.0) { + auto s1_begin = s1.GetData(); + auto s2_begin = s2.GetData(); + return duckdb_jaro_winkler::jaro_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, s2_begin + s2.GetSize(), + score_cutoff); +} + +static inline double JaroWinklerScalarFunction(const string_t &s1, const string_t &s2, + const double_t &score_cutoff = 0.0) { + auto s1_begin = s1.GetData(); + auto s2_begin = 
s2.GetData();
+	return duckdb_jaro_winkler::jaro_winkler_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin,
+	                                                    s2_begin + s2.GetSize(), 0.1, score_cutoff);
+}
+
+template <class CACHED_SIMILARITY>
+static void CachedFunction(Vector &constant, Vector &other, Vector &result, DataChunk &args) {
+	auto val = constant.GetValue(0);
+	idx_t count = args.size();
+	if (val.IsNull()) {
+		auto &result_validity = FlatVector::Validity(result);
+		result_validity.SetAllInvalid(count);
+		return;
+	}
+
+	auto str_val = StringValue::Get(val);
+	auto cached = CACHED_SIMILARITY(str_val);
+
+	D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3);
+	if (args.ColumnCount() == 2) {
+		UnaryExecutor::Execute<string_t, double>(other, result, count, [&](const string_t &other_str) {
+			auto other_str_begin = other_str.GetData();
+			return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize());
+		});
+	} else {
+		auto score_cutoff = args.data[2];
+		BinaryExecutor::Execute<string_t, double, double>(
+		    other, score_cutoff, result, count, [&](const string_t &other_str, const double_t score_cutoff) {
+			    auto other_str_begin = other_str.GetData();
+			    return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize(), score_cutoff);
+		    });
+	}
+}
+
+template <class CACHED_SIMILARITY, class SIMILARITY_FUNCTION>
+static void TemplatedJaroWinklerFunction(DataChunk &args, Vector &result, SIMILARITY_FUNCTION fun) {
+	bool arg0_constant = args.data[0].GetVectorType() == VectorType::CONSTANT_VECTOR;
+	bool arg1_constant = args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR;
+	if (!(arg0_constant ^ arg1_constant)) {
+		// We can't optimize by caching one of the two strings
+		D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3);
+		if (args.ColumnCount() == 2) {
+			BinaryExecutor::Execute<string_t, string_t, double>(
+			    args.data[0], args.data[1], result, args.size(),
+			    [&](const string_t &s1, const string_t &s2) { return fun(s1, s2, 0.0); });
+			return;
+		} else {
+			TernaryExecutor::Execute<string_t, string_t, double, double>(args.data[0], args.data[1], args.data[2],
+			                                                             result, args.size(), fun);
+			return;
+		}
+	}
+
+	if (arg0_constant) {
+		CachedFunction<CACHED_SIMILARITY>(args.data[0], args.data[1], result, args);
+	} else {
+		CachedFunction<CACHED_SIMILARITY>(args.data[1], args.data[0], result, args);
+	}
+}
+
+static void JaroFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+	TemplatedJaroWinklerFunction<duckdb_jaro_winkler::CachedJaroSimilarity<char>>(args, result, JaroScalarFunction);
+}
+
+static void JaroWinklerFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+	TemplatedJaroWinklerFunction<duckdb_jaro_winkler::CachedJaroWinklerSimilarity<char>>(args, result,
+	                                                                                    JaroWinklerScalarFunction);
+}
+
+ScalarFunctionSet JaroSimilarityFun::GetFunctions() {
+	ScalarFunctionSet jaro;
+
+	const auto list_type = LogicalType::LIST(LogicalType::VARCHAR);
+	auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroFunction);
+	jaro.AddFunction(fun);
+
+	fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::DOUBLE}, LogicalType::DOUBLE,
+	                     JaroFunction);
+	jaro.AddFunction(fun);
+	return jaro;
+}
+
+ScalarFunctionSet JaroWinklerSimilarityFun::GetFunctions() {
+	ScalarFunctionSet jaroWinkler;
+
+	const auto list_type = LogicalType::LIST(LogicalType::VARCHAR);
+	auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroWinklerFunction);
+	jaroWinkler.AddFunction(fun);
+
+	fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::DOUBLE}, LogicalType::DOUBLE,
+	                     JaroWinklerFunction);
+	jaroWinkler.AddFunction(fun);
+	return jaroWinkler;
+}
+
+} // namespace duckdb
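The hard-coded 0.1 passed to jaro_winkler_similarity above is the Winkler prefix weight p. A minimal sketch of the boost it applies, assuming the conventional Winkler cap of four common-prefix characters, sim_w = sim_j + l * p * (1 - sim_j):

#include <algorithm>
#include <cassert>

// l is the length of the common prefix, capped at 4; p = 0.1 as in the call above
static double winkler_boost(double sim_j, int common_prefix) {
	int l = std::min(common_prefix, 4);
	return sim_j + l * 0.1 * (1.0 - sim_j);
}

int main() {
	// a Jaro similarity of 0.9 with a shared 4-character prefix is boosted to 0.94
	assert(winkler_boost(0.9, 5) > 0.9399 && winkler_boost(0.9, 5) < 0.9401);
	// with no shared prefix the Jaro score is returned unchanged
	assert(winkler_boost(0.9, 0) == 0.9);
	return 0;
}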
diff --git a/src/duckdb/extension/core_functions/scalar/string/left_right.cpp b/src/duckdb/extension/core_functions/scalar/string/left_right.cpp
new file mode 100644
index 00000000..b13ff956
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/string/left_right.cpp
@@ -0,0 +1,100 @@
+#include "core_functions/scalar/string_functions.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/common/limits.hpp"
+#include "duckdb/function/scalar/string_common.hpp"
+
+#include
+#include
+
+namespace duckdb {
+
+struct LeftRightUnicode {
+	template <class TA, class TR>
+	static inline TR Operation(TA input) {
+		return Length<TA, TR>(input);
+	}
+
+	static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) {
+		return SubstringUnicode(result, input, offset, length);
+	}
+};
+
+struct LeftRightGrapheme {
+	template <class TA, class TR>
+	static inline TR Operation(TA input) {
+		return GraphemeCount<TA, TR>(input);
+	}
+
+	static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) {
+		return SubstringGrapheme(result, input, offset, length);
+	}
+};
+
+template <class OP>
+static string_t LeftScalarFunction(Vector &result, const string_t str, int64_t pos) {
+	if (pos >= 0) {
+		return OP::Substring(result, str, 1, pos);
+	}
+
+	int64_t num_characters = OP::template Operation<string_t, int64_t>(str);
+	pos = MaxValue<int64_t>(0, num_characters + pos);
+	return OP::Substring(result, str, 1, pos);
+}
+
+template <class OP>
+static void LeftFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+	auto &str_vec = args.data[0];
+	auto &pos_vec = args.data[1];
+
+	BinaryExecutor::Execute<string_t, int64_t, string_t>(
+	    str_vec, pos_vec, result, args.size(),
+	    [&](string_t str, int64_t pos) { return LeftScalarFunction<OP>(result, str, pos); });
+}
+
+ScalarFunction LeftFun::GetFunction() {
+	return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR,
+	                      LeftFunction<LeftRightUnicode>);
+}
+
+ScalarFunction LeftGraphemeFun::GetFunction() {
+	return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR,
+	                      LeftFunction<LeftRightGrapheme>);
+}
+
+template <class OP>
+static string_t RightScalarFunction(Vector &result, const string_t str, int64_t pos) {
+	int64_t num_characters = OP::template Operation<string_t, int64_t>(str);
+	if (pos >= 0) {
+		int64_t len = MinValue<int64_t>(num_characters, pos);
+		int64_t start = num_characters - len + 1;
+		return OP::Substring(result, str, start, len);
+	}
+
+	int64_t len = 0;
+	if (pos != std::numeric_limits<int64_t>::min()) {
+		len = num_characters - MinValue<int64_t>(num_characters, -pos);
+	}
+	int64_t start = num_characters - len + 1;
+	return OP::Substring(result, str, start, len);
+}
+
+template <class OP>
+static void RightFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+	auto &str_vec = args.data[0];
+	auto &pos_vec = args.data[1];
+	BinaryExecutor::Execute<string_t, int64_t, string_t>(
+	    str_vec, pos_vec, result, args.size(),
+	    [&](string_t str, int64_t pos) { return RightScalarFunction<OP>(result, str, pos); });
+}
+
+ScalarFunction RightFun::GetFunction() {
+	return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR,
+	                      RightFunction<LeftRightUnicode>);
+}
+
+ScalarFunction RightGraphemeFun::GetFunction() {
+	return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR,
+	                      RightFunction<LeftRightGrapheme>);
+}
+
+} // namespace duckdb
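The negative-count branches above give the usual SQL semantics; as a worked example: left('hello', -2) keeps all but the last two characters and returns 'hel' (pos becomes max(0, 5 - 2) = 3), while right('hello', -2) drops the first two and returns 'llo' (len = 5 - min(5, 2) = 3, start = 3). The std::numeric_limits<int64_t>::min() guard exists because negating that one value would overflow.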
"duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/string_util.hpp" + +#include +#include + +namespace duckdb { + +// See: https://www.kdnuggets.com/2020/10/optimizing-levenshtein-distance-measuring-text-similarity.html +// And: Iterative 2-row algorithm: https://en.wikipedia.org/wiki/Levenshtein_distance +// Note: A first implementation using the array algorithm version resulted in an error raised by duckdb +// (too muach memory usage) + +static idx_t LevenshteinDistance(const string_t &txt, const string_t &tgt) { + auto txt_len = txt.GetSize(); + auto tgt_len = tgt.GetSize(); + + // If one string is empty, the distance equals the length of the other string + if (txt_len == 0) { + return tgt_len; + } else if (tgt_len == 0) { + return txt_len; + } + + auto txt_str = txt.GetData(); + auto tgt_str = tgt.GetData(); + + // Create two working vectors + vector distances0(tgt_len + 1, 0); + vector distances1(tgt_len + 1, 0); + + idx_t cost_substitution = 0; + idx_t cost_insertion = 0; + idx_t cost_deletion = 0; + + // initialize distances0 vector + // edit distance for an empty txt string is just the number of characters to delete from tgt + for (idx_t pos_tgt = 0; pos_tgt <= tgt_len; pos_tgt++) { + distances0[pos_tgt] = pos_tgt; + } + + for (idx_t pos_txt = 0; pos_txt < txt_len; pos_txt++) { + // calculate distances1 (current raw distances) from the previous row + + distances1[0] = pos_txt + 1; + + for (idx_t pos_tgt = 0; pos_tgt < tgt_len; pos_tgt++) { + cost_deletion = distances0[pos_tgt + 1] + 1; + cost_insertion = distances1[pos_tgt] + 1; + cost_substitution = distances0[pos_tgt]; + + if (txt_str[pos_txt] != tgt_str[pos_tgt]) { + cost_substitution += 1; + } + + distances1[pos_tgt + 1] = MinValue(cost_deletion, MinValue(cost_substitution, cost_insertion)); + } + // copy distances1 (current row) to distances0 (previous row) for next iteration + // since data in distances1 is always invalidated, a swap without copy is more efficient + distances0 = distances1; + } + + return distances0[tgt_len]; +} + +static int64_t LevenshteinScalarFunction(Vector &result, const string_t str, string_t tgt) { + return (int64_t)LevenshteinDistance(str, tgt); +} + +static void LevenshteinFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vec = args.data[0]; + auto &tgt_vec = args.data[1]; + + BinaryExecutor::Execute( + str_vec, tgt_vec, result, args.size(), + [&](string_t str, string_t tgt) { return LevenshteinScalarFunction(result, str, tgt); }); +} + +ScalarFunction LevenshteinFun::GetFunction() { + return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, LevenshteinFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/pad.cpp b/src/duckdb/extension/core_functions/scalar/string/pad.cpp new file mode 100644 index 00000000..586e1605 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/pad.cpp @@ -0,0 +1,147 @@ +#include "core_functions/scalar/string_functions.hpp" + +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_operations/ternary_executor.hpp" +#include "duckdb/common/pair.hpp" + +#include "utf8proc.hpp" + +namespace duckdb { + +static pair PadCountChars(const idx_t len, const char *data, const idx_t size) { + // Count how much of str will fit in the output + auto str = reinterpret_cast(data); + idx_t nbytes = 0; + idx_t nchars = 
0; + for (; nchars < len && nbytes < size; ++nchars) { + utf8proc_int32_t codepoint; + auto bytes = utf8proc_iterate(str + nbytes, UnsafeNumericCast(size - nbytes), &codepoint); + D_ASSERT(bytes > 0); + nbytes += UnsafeNumericCast(bytes); + } + + return pair(nbytes, nchars); +} + +static bool InsertPadding(const idx_t len, const string_t &pad, vector &result) { + // Copy the padding until the output is long enough + auto data = pad.GetData(); + auto size = pad.GetSize(); + + // Check whether we need data that we don't have + if (len > 0 && size == 0) { + return false; + } + + // Insert characters until we have all we need. + auto str = reinterpret_cast(data); + idx_t nbytes = 0; + for (idx_t nchars = 0; nchars < len; ++nchars) { + // If we are at the end of the pad, flush all of it and loop back + if (nbytes >= size) { + result.insert(result.end(), data, data + size); + nbytes = 0; + } + + // Write the next character + utf8proc_int32_t codepoint; + auto bytes = utf8proc_iterate(str + nbytes, UnsafeNumericCast(size - nbytes), &codepoint); + D_ASSERT(bytes > 0); + nbytes += UnsafeNumericCast(bytes); + } + + // Flush the remaining pad + result.insert(result.end(), data, data + nbytes); + + return true; +} + +static string_t LeftPadFunction(const string_t &str, const int32_t len, const string_t &pad, vector &result) { + // Reuse the buffer + result.clear(); + + // Get information about the base string + auto data_str = str.GetData(); + auto size_str = str.GetSize(); + + // Count how much of str will fit in the output + auto written = PadCountChars(UnsafeNumericCast(len), data_str, size_str); + + // Left pad by the number of characters still needed + if (!InsertPadding(UnsafeNumericCast(len) - written.second, pad, result)) { + throw InvalidInputException("Insufficient padding in LPAD."); + } + + // Append as much of the original string as fits + result.insert(result.end(), data_str, data_str + written.first); + + return string_t(result.data(), UnsafeNumericCast(result.size())); +} + +struct LeftPadOperator { + static inline string_t Operation(const string_t &str, const int32_t len, const string_t &pad, + vector &result) { + return LeftPadFunction(str, len, pad, result); + } +}; + +static string_t RightPadFunction(const string_t &str, const int32_t len, const string_t &pad, vector &result) { + // Reuse the buffer + result.clear(); + + // Get information about the base string + auto data_str = str.GetData(); + auto size_str = str.GetSize(); + + // Count how much of str will fit in the output + auto written = PadCountChars(UnsafeNumericCast(len), data_str, size_str); + + // Append as much of the original string as fits + result.insert(result.end(), data_str, data_str + written.first); + + // Right pad by the number of characters still needed + if (!InsertPadding(UnsafeNumericCast(len) - written.second, pad, result)) { + throw InvalidInputException("Insufficient padding in RPAD."); + }; + + return string_t(result.data(), UnsafeNumericCast(result.size())); +} + +struct RightPadOperator { + static inline string_t Operation(const string_t &str, const int32_t len, const string_t &pad, + vector &result) { + return RightPadFunction(str, len, pad, result); + } +}; + +template +static void PadFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vector = args.data[0]; + auto &len_vector = args.data[1]; + auto &pad_vector = args.data[2]; + + vector buffer; + TernaryExecutor::Execute( + str_vector, len_vector, pad_vector, result, args.size(), [&](string_t str, int32_t len, string_t 
pad) { + len = MaxValue(len, 0); + return StringVector::AddString(result, OP::Operation(str, len, pad, buffer)); + }); +} + +ScalarFunction LpadFun::GetFunction() { + ScalarFunction func({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, + PadFunction); + BaseScalarFunction::SetReturnsError(func); + return func; +} + +ScalarFunction RpadFun::GetFunction() { + ScalarFunction func({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, + PadFunction); + BaseScalarFunction::SetReturnsError(func); + return func; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/parse_path.cpp b/src/duckdb/extension/core_functions/scalar/string/parse_path.cpp new file mode 100644 index 00000000..9ed926b4 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/parse_path.cpp @@ -0,0 +1,348 @@ +#include "core_functions/scalar/string_functions.hpp" +#include "duckdb/function/scalar/string_common.hpp" +#include "duckdb/common/local_file_system.hpp" +#include + +namespace duckdb { + +static string GetSeparator(const string_t &input) { + string option = input.GetString(); + + // system's path separator + auto fs = FileSystem::CreateLocal(); + auto system_sep = fs->PathSeparator(option); + + string separator; + if (option == "system") { + separator = system_sep; + } else if (option == "forward_slash") { + separator = "/"; + } else if (option == "backslash") { + separator = "\\"; + } else { // both_slash (default) + separator = "/\\"; + } + return separator; +} + +struct SplitInput { + SplitInput(Vector &result_list, Vector &result_child, idx_t offset) + : result_list(result_list), result_child(result_child), offset(offset) { + } + + Vector &result_list; + Vector &result_child; + idx_t offset; + + void AddSplit(const char *split_data, idx_t split_size, idx_t list_idx) { + auto list_entry = offset + list_idx; + if (list_entry >= ListVector::GetListCapacity(result_list)) { + ListVector::SetListSize(result_list, offset + list_idx); + ListVector::Reserve(result_list, ListVector::GetListCapacity(result_list) * 2); + } + FlatVector::GetData(result_child)[list_entry] = + StringVector::AddString(result_child, split_data, split_size); + } +}; + +static bool IsIdxValid(const idx_t &i, const idx_t &sentence_size) { + if (i > sentence_size || i == DConstants::INVALID_INDEX) { + return false; + } + return true; +} + +static idx_t Find(const char *input_data, idx_t input_size, const string &sep_data) { + if (sep_data.empty()) { + return 0; + } + auto pos = FindStrInStr(const_uchar_ptr_cast(input_data), input_size, const_uchar_ptr_cast(&sep_data[0]), 1); + // both_slash option + if (sep_data.size() > 1) { + auto sec_pos = + FindStrInStr(const_uchar_ptr_cast(input_data), input_size, const_uchar_ptr_cast(&sep_data[1]), 1); + // choose the leftmost valid position + if (sec_pos != DConstants::INVALID_INDEX && (sec_pos < pos || pos == DConstants::INVALID_INDEX)) { + return sec_pos; + } + } + return pos; +} + +static idx_t FindLast(const char *data_ptr, idx_t input_size, const string &sep_data) { + idx_t start = 0; + while (input_size > 0) { + auto pos = Find(data_ptr, input_size, sep_data); + if (!IsIdxValid(pos, input_size)) { + break; + } + start += (pos + 1); + data_ptr += (pos + 1); + input_size -= (pos + 1); + } + if (start < 1) { + return DConstants::INVALID_INDEX; + } + return start - 1; +} + +static idx_t SplitPath(string_t input, const string &sep, SplitInput &state) { + auto input_data = 
input.GetData();
+	auto input_size = input.GetSize();
+	if (!input_size) {
+		return 0;
+	}
+	idx_t list_idx = 0;
+	while (input_size > 0) {
+		auto pos = Find(input_data, input_size, sep);
+		if (!IsIdxValid(pos, input_size)) {
+			break;
+		}
+
+		D_ASSERT(input_size >= pos);
+		if (pos == 0) {
+			if (list_idx == 0) { // first character in path is separator
+				state.AddSplit(input_data, 1, list_idx);
+				list_idx++;
+				if (input_size == 1) { // special case: the only character in path is a separator
+					return list_idx;
+				}
+			} // else: separator is in the path
+		} else {
+			state.AddSplit(input_data, pos, list_idx);
+			list_idx++;
+		}
+		input_data += (pos + 1);
+		input_size -= (pos + 1);
+	}
+	if (input_size > 0) {
+		state.AddSplit(input_data, input_size, list_idx);
+		list_idx++;
+	}
+	return list_idx;
+}
+
+static void ReadOptionalArgs(DataChunk &args, Vector &sep, Vector &trim, const bool &front_trim) {
+	switch (args.ColumnCount()) {
+	case 1: {
+		// use default values
+		break;
+	}
+	case 2: {
+		UnifiedVectorFormat sec_arg;
+		args.data[1].ToUnifiedFormat(args.size(), sec_arg);
+		if (sec_arg.validity.RowIsValid(0)) { // if not NULL
+			switch (args.data[1].GetType().id()) {
+			case LogicalTypeId::VARCHAR: {
+				sep.Reinterpret(args.data[1]);
+				break;
+			}
+			case LogicalTypeId::BOOLEAN: { // parse_path and parse_dirname won't get in here
+				trim.Reinterpret(args.data[1]);
+				break;
+			}
+			default:
+				throw InvalidInputException("Invalid argument type");
+			}
+		}
+		break;
+	}
+	case 3: {
+		if (!front_trim) {
+			// set trim_extension
+			UnifiedVectorFormat sec_arg;
+			args.data[1].ToUnifiedFormat(args.size(), sec_arg);
+			if (sec_arg.validity.RowIsValid(0)) {
+				trim.Reinterpret(args.data[1]);
+			}
+			UnifiedVectorFormat third_arg;
+			args.data[2].ToUnifiedFormat(args.size(), third_arg);
+			if (third_arg.validity.RowIsValid(0)) {
+				sep.Reinterpret(args.data[2]);
+			}
+		} else {
+			throw InvalidInputException("Invalid number of arguments");
+		}
+		break;
+	}
+	default:
+		throw InvalidInputException("Invalid number of arguments");
+	}
+}
+
+template <bool FRONT_TRIM>
+static void TrimPathFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+	// set default values
+	Vector &path = args.data[0];
+	Vector separator(string_t("default"));
+	Vector trim_extension(Value::BOOLEAN(false));
+	ReadOptionalArgs(args, separator, trim_extension, FRONT_TRIM);
+
+	TernaryExecutor::Execute<string_t, string_t, bool, string_t>(
+	    path, separator, trim_extension, result, args.size(),
+	    [&](string_t &inputs, string_t input_sep, bool trim_extension) {
+		    auto data = inputs.GetData();
+		    auto input_size = inputs.GetSize();
+		    auto sep = GetSeparator(input_sep.GetString());
+
+		    // find the beginning idx and the size of the result string
+		    idx_t begin = 0;
+		    idx_t new_size = input_size;
+		    if (FRONT_TRIM) { // left trim
+			    auto pos = Find(data, input_size, sep);
+			    if (pos == 0) { // path starts with separator
+				    pos = 1;
+			    }
+			    new_size = (IsIdxValid(pos, input_size)) ?
pos : 0; + } else { // right trim + auto idx_last_sep = FindLast(data, input_size, sep); + if (IsIdxValid(idx_last_sep, input_size)) { + begin = idx_last_sep + 1; + } + if (trim_extension) { + auto idx_extension_sep = FindLast(data, input_size, "."); + if (begin <= idx_extension_sep && IsIdxValid(idx_extension_sep, input_size)) { + new_size = idx_extension_sep; + } + } + } + // copy the trimmed string + D_ASSERT(begin <= new_size); + auto target = StringVector::EmptyString(result, new_size - begin); + auto output = target.GetDataWriteable(); + memcpy(output, data + begin, new_size - begin); + + target.Finalize(); + return target; + }); +} + +static void ParseDirpathFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // set default values + Vector &path = args.data[0]; + Vector separator(string_t("default")); + Vector trim_extension(false); + ReadOptionalArgs(args, separator, trim_extension, true); + + BinaryExecutor::Execute( + path, separator, result, args.size(), [&](string_t input_path, string_t input_sep) { + auto path = input_path.GetData(); + auto path_size = input_path.GetSize(); + auto sep = GetSeparator(input_sep.GetString()); + + auto last_sep = FindLast(path, path_size, sep); + if (last_sep == 0 && path_size == 1) { + last_sep = 1; + } + idx_t new_size = (IsIdxValid(last_sep, path_size)) ? last_sep : 0; + + auto target = StringVector::EmptyString(result, new_size); + auto output = target.GetDataWriteable(); + memcpy(output, path, new_size); + target.Finalize(); + return StringVector::AddString(result, target); + }); +} + +static void ParsePathFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1 || args.ColumnCount() == 2); + UnifiedVectorFormat input_data; + args.data[0].ToUnifiedFormat(args.size(), input_data); + auto inputs = UnifiedVectorFormat::GetData(input_data); + + // set the separator + string input_sep = "default"; + if (args.ColumnCount() == 2) { + UnifiedVectorFormat sep_data; + args.data[1].ToUnifiedFormat(args.size(), sep_data); + if (sep_data.validity.RowIsValid(0)) { + input_sep = UnifiedVectorFormat::GetData(sep_data)->GetString(); + } + } + const string sep = GetSeparator(input_sep); + + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + result.SetVectorType(VectorType::FLAT_VECTOR); + ListVector::SetListSize(result, 0); + + // set up the list entries + auto list_data = FlatVector::GetData(result); + auto &child_entry = ListVector::GetEntry(result); + auto &result_mask = FlatVector::Validity(result); + idx_t total_splits = 0; + for (idx_t i = 0; i < args.size(); i++) { + auto input_idx = input_data.sel->get_index(i); + if (!input_data.validity.RowIsValid(input_idx)) { + result_mask.SetInvalid(i); + continue; + } + SplitInput split_input(result, child_entry, total_splits); + auto list_length = SplitPath(inputs[input_idx], sep, split_input); + list_data[i].length = list_length; + list_data[i].offset = total_splits; + total_splits += list_length; + } + ListVector::SetListSize(result, total_splits); + D_ASSERT(ListVector::GetListSize(result) == total_splits); + + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +ScalarFunctionSet ParseDirnameFun::GetFunctions() { + ScalarFunctionSet parse_dirname; + ScalarFunction func({LogicalType::VARCHAR}, LogicalType::VARCHAR, TrimPathFunction, nullptr, nullptr, nullptr, + nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, + FunctionNullHandling::SPECIAL_HANDLING); + parse_dirname.AddFunction(func); + // 
separator options + func.arguments.emplace_back(LogicalType::VARCHAR); + parse_dirname.AddFunction(func); + return parse_dirname; +} + +ScalarFunctionSet ParseDirpathFun::GetFunctions() { + ScalarFunctionSet parse_dirpath; + ScalarFunction func({LogicalType::VARCHAR}, LogicalType::VARCHAR, ParseDirpathFunction, nullptr, nullptr, nullptr, + nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, + FunctionNullHandling::SPECIAL_HANDLING); + parse_dirpath.AddFunction(func); + // separator options + func.arguments.emplace_back(LogicalType::VARCHAR); + parse_dirpath.AddFunction(func); + return parse_dirpath; +} + +ScalarFunctionSet ParseFilenameFun::GetFunctions() { + ScalarFunctionSet parse_filename; + parse_filename.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, TrimPathFunction, + nullptr, nullptr, nullptr, nullptr, LogicalType::INVALID, + FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); + parse_filename.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, TrimPathFunction, nullptr, nullptr, + nullptr, nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); + parse_filename.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::BOOLEAN}, LogicalType::VARCHAR, TrimPathFunction, nullptr, nullptr, + nullptr, nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); + parse_filename.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BOOLEAN, LogicalType::VARCHAR}, + LogicalType::VARCHAR, TrimPathFunction, nullptr, nullptr, nullptr, + nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, + FunctionNullHandling::SPECIAL_HANDLING)); + return parse_filename; +} + +ScalarFunctionSet ParsePathFun::GetFunctions() { + auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); + ScalarFunctionSet parse_path; + ScalarFunction func({LogicalType::VARCHAR}, varchar_list_type, ParsePathFunction, nullptr, nullptr, nullptr, + nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, + FunctionNullHandling::SPECIAL_HANDLING); + parse_path.AddFunction(func); + // separator options + func.arguments.emplace_back(LogicalType::VARCHAR); + parse_path.AddFunction(func); + return parse_path; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/printf.cpp b/src/duckdb/extension/core_functions/scalar/string/printf.cpp new file mode 100644 index 00000000..1db25b0d --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/printf.cpp @@ -0,0 +1,189 @@ +#include "core_functions/scalar/string_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/limits.hpp" +#include "fmt/format.h" +#include "fmt/printf.h" + +namespace duckdb { + +struct FMTPrintf { + template + static string OP(const char *format_str, vector> &format_args) { + return duckdb_fmt::vsprintf( + format_str, duckdb_fmt::basic_format_args(format_args.data(), static_cast(format_args.size()))); + } +}; + +struct FMTFormat { + template + static string OP(const char *format_str, vector> &format_args) { + return duckdb_fmt::vformat( + format_str, duckdb_fmt::basic_format_args(format_args.data(), static_cast(format_args.size()))); + } +}; + +unique_ptr BindPrintfFunction(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + for (idx_t i = 1; i < arguments.size(); i++) { + switch 
(arguments[i]->return_type.id()) { + case LogicalTypeId::BOOLEAN: + bound_function.arguments.emplace_back(LogicalType::BOOLEAN); + break; + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + bound_function.arguments.emplace_back(LogicalType::BIGINT); + break; + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + bound_function.arguments.emplace_back(LogicalType::UBIGINT); + break; + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + bound_function.arguments.emplace_back(LogicalType::DOUBLE); + break; + case LogicalTypeId::VARCHAR: + bound_function.arguments.push_back(LogicalType::VARCHAR); + break; + case LogicalTypeId::DECIMAL: + // decimal type: add cast to double + bound_function.arguments.emplace_back(LogicalType::DOUBLE); + break; + case LogicalTypeId::UNKNOWN: + // parameter: accept any input and rebind later + bound_function.arguments.emplace_back(LogicalType::ANY); + break; + default: + // all other types: add cast to string + bound_function.arguments.emplace_back(LogicalType::VARCHAR); + break; + } + } + return nullptr; +} + +template +static void PrintfFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &format_string = args.data[0]; + auto &result_validity = FlatVector::Validity(result); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + result_validity.Initialize(args.size()); + for (idx_t i = 0; i < args.ColumnCount(); i++) { + switch (args.data[i].GetVectorType()) { + case VectorType::CONSTANT_VECTOR: + if (ConstantVector::IsNull(args.data[i])) { + // constant null! result is always NULL regardless of other input + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + break; + default: + // FLAT VECTOR, we can directly OR the nullmask + args.data[i].Flatten(args.size()); + result.SetVectorType(VectorType::FLAT_VECTOR); + result_validity.Combine(FlatVector::Validity(args.data[i]), args.size()); + break; + } + } + idx_t count = result.GetVectorType() == VectorType::CONSTANT_VECTOR ? 1 : args.size(); + + auto format_data = FlatVector::GetData(format_string); + auto result_data = FlatVector::GetData(result); + for (idx_t idx = 0; idx < count; idx++) { + if (result.GetVectorType() == VectorType::FLAT_VECTOR && FlatVector::IsNull(result, idx)) { + // this entry is NULL: skip it + continue; + } + + // first fetch the format string + auto fmt_idx = format_string.GetVectorType() == VectorType::CONSTANT_VECTOR ? 0 : idx; + auto format_string = format_data[fmt_idx].GetString(); + + // now gather all the format arguments + vector> format_args; + vector> string_args; + + for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { + auto &col = args.data[col_idx]; + idx_t arg_idx = col.GetVectorType() == VectorType::CONSTANT_VECTOR ? 
0 : idx; + switch (col.GetType().id()) { + case LogicalTypeId::BOOLEAN: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::TINYINT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::SMALLINT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::INTEGER: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::BIGINT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::UBIGINT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::FLOAT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::DOUBLE: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::VARCHAR: { + auto arg_data = FlatVector::GetData(col); + auto string_view = + duckdb_fmt::basic_string_view(arg_data[arg_idx].GetData(), arg_data[arg_idx].GetSize()); + format_args.emplace_back(duckdb_fmt::internal::make_arg(string_view)); + break; + } + default: + throw InternalException("Unexpected type for printf format"); + } + } + // finally actually perform the format + string dynamic_result = FORMAT_FUN::template OP(format_string.c_str(), format_args); + result_data[idx] = StringVector::AddString(result, dynamic_result); + } +} + +ScalarFunction PrintfFun::GetFunction() { + // duckdb_fmt::printf_context, duckdb_fmt::vsprintf + ScalarFunction printf_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, + PrintfFunction, BindPrintfFunction); + printf_fun.varargs = LogicalType::ANY; + BaseScalarFunction::SetReturnsError(printf_fun); + return printf_fun; +} + +ScalarFunction FormatFun::GetFunction() { + // duckdb_fmt::format_context, duckdb_fmt::vformat + ScalarFunction format_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, + PrintfFunction, BindPrintfFunction); + format_fun.varargs = LogicalType::ANY; + BaseScalarFunction::SetReturnsError(format_fun); + return format_fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/repeat.cpp b/src/duckdb/extension/core_functions/scalar/string/repeat.cpp new file mode 100644 index 00000000..154634f9 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/string/repeat.cpp @@ -0,0 +1,90 @@ +#include "duckdb/common/vector_operations/binary_executor.hpp" +#include "core_functions/scalar/string_functions.hpp" +#include "duckdb/common/operator/multiply.hpp" + +namespace duckdb { + +static void RepeatFunction(DataChunk &args, ExpressionState &, Vector &result) { + auto &str_vector = args.data[0]; + auto &cnt_vector = args.data[1]; + + BinaryExecutor::Execute( + str_vector, cnt_vector, result, args.size(), [&](string_t str, int64_t cnt) { + auto input_str = str.GetData(); + auto size_str = str.GetSize(); + idx_t copy_count = cnt <= 0 || size_str == 0 ? 
+auto input_str = str.GetData();
+auto size_str = str.GetSize();
+idx_t copy_count = cnt <= 0 || size_str == 0 ? 0 : UnsafeNumericCast<idx_t>(cnt);
+
+idx_t copy_size;
+if (TryMultiplyOperator::Operation(size_str, copy_count, copy_size)) {
+auto result_str = StringVector::EmptyString(result, copy_size);
+auto result_data = result_str.GetDataWriteable();
+for (idx_t i = 0; i < copy_count; i++) {
+memcpy(result_data + i * size_str, input_str, size_str);
+}
+result_str.Finalize();
+return result_str;
+} else {
+throw OutOfRangeException(
+"Cannot create a string of size: '%d' * '%d', the maximum supported string size is: '%d'", size_str,
+copy_count, string_t::MAX_STRING_SIZE);
+}
+});
+}
+
+unique_ptr<FunctionData> RepeatBindFunction(ClientContext &, ScalarFunction &bound_function,
+vector<unique_ptr<Expression>> &arguments) {
+switch (arguments[0]->return_type.id()) {
+case LogicalTypeId::UNKNOWN:
+throw ParameterNotResolvedException();
+case LogicalTypeId::LIST:
+break;
+default:
+throw NotImplementedException("repeat(list, count) requires a list as parameter");
+}
+bound_function.arguments[0] = arguments[0]->return_type;
+bound_function.return_type = arguments[0]->return_type;
+return nullptr;
+}
+
+static void RepeatListFunction(DataChunk &args, ExpressionState &, Vector &result) {
+auto &list_vector = args.data[0];
+auto &cnt_vector = args.data[1];
+
+auto &source_child = ListVector::GetEntry(list_vector);
+auto &result_child = ListVector::GetEntry(result);
+
+idx_t current_size = ListVector::GetListSize(result);
+BinaryExecutor::Execute<list_entry_t, int64_t, list_entry_t>(
+list_vector, cnt_vector, result, args.size(), [&](list_entry_t list_input, int64_t cnt) {
+idx_t copy_count = cnt <= 0 || list_input.length == 0 ? 0 : UnsafeNumericCast<idx_t>(cnt);
+idx_t result_length = list_input.length * copy_count;
+idx_t new_size = current_size + result_length;
+ListVector::Reserve(result, new_size);
+list_entry_t result_list;
+result_list.offset = current_size;
+result_list.length = result_length;
+for (idx_t i = 0; i < copy_count; i++) {
+// repeat the list contents "cnt" times
+VectorOperations::Copy(source_child, result_child, list_input.offset + list_input.length,
+list_input.offset, current_size);
+current_size += list_input.length;
+}
+return result_list;
+});
+ListVector::SetListSize(result, current_size);
+}
+
+ScalarFunctionSet RepeatFun::GetFunctions() {
+ScalarFunctionSet repeat;
+for (const auto &type : {LogicalType::VARCHAR, LogicalType::BLOB}) {
+repeat.AddFunction(ScalarFunction({type, LogicalType::BIGINT}, type, RepeatFunction));
+}
+repeat.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT},
+LogicalType::LIST(LogicalType::ANY), RepeatListFunction, RepeatBindFunction));
+for (auto &func : repeat.functions) {
+BaseScalarFunction::SetReturnsError(func);
+}
+return repeat;
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/string/replace.cpp b/src/duckdb/extension/core_functions/scalar/string/replace.cpp
new file mode 100644
index 00000000..4702292c
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/string/replace.cpp
@@ -0,0 +1,84 @@
+#include "core_functions/scalar/string_functions.hpp"
+
+#include "duckdb/common/exception.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/common/vector_operations/ternary_executor.hpp"
+
+#include <string.h>
+#include <ctype.h>
+#include <unordered_map>
+
+namespace duckdb {
+
+static idx_t NextNeedle(const char *input_haystack, idx_t size_haystack, const char *input_needle,
+const idx_t size_needle) {
+// The needle must be non-empty to proceed
+if (size_needle > 0) {
+// The haystack must be at least as large as the needle
+for (idx_t string_position = 0; (size_haystack - string_position) >= size_needle; ++string_position) {
+// Compare Needle to the Haystack
+if ((memcmp(input_haystack + string_position, input_needle, size_needle) == 0)) {
+return string_position;
+}
+}
+}
+// Did not find the needle
+return size_haystack;
+}
+
+static string_t ReplaceScalarFunction(const string_t &haystack, const string_t &needle, const string_t &thread,
+vector<char> &result) {
+// Get information about the needle, the haystack and the "thread"
+auto input_haystack = haystack.GetData();
+auto size_haystack = haystack.GetSize();
+
+auto input_needle = needle.GetData();
+auto size_needle = needle.GetSize();
+
+auto input_thread = thread.GetData();
+auto size_thread = thread.GetSize();
+
+// Reuse the buffer
+result.clear();
+
+for (;;) {
+// Append the non-matching characters
+auto string_position = NextNeedle(input_haystack, size_haystack, input_needle, size_needle);
+result.insert(result.end(), input_haystack, input_haystack + string_position);
+input_haystack += string_position;
+size_haystack -= string_position;
+
+// Stop when we have read the entire haystack
+if (size_haystack == 0) {
+break;
+}
+
+// Replace the matching characters
+result.insert(result.end(), input_thread, input_thread + size_thread);
+input_haystack += size_needle;
+size_haystack -= size_needle;
+}
+
+return string_t(result.data(), UnsafeNumericCast<uint32_t>(result.size()));
+}
+
+static void ReplaceFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+auto &haystack_vector = args.data[0];
+auto &needle_vector = args.data[1];
+auto &thread_vector = args.data[2];
+
+vector<char> buffer;
+TernaryExecutor::Execute<string_t, string_t, string_t, string_t>(
+haystack_vector, needle_vector, thread_vector, result, args.size(),
+[&](string_t input_string, string_t needle_string, string_t thread_string) {
+return StringVector::AddString(result,
+ReplaceScalarFunction(input_string, needle_string, thread_string, buffer));
+});
+}
+
+ScalarFunction ReplaceFun::GetFunction() {
+return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR,
+ReplaceFunction);
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/string/reverse.cpp b/src/duckdb/extension/core_functions/scalar/string/reverse.cpp
new file mode 100644
index 00000000..4ff65490
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/string/reverse.cpp
@@ -0,0 +1,55 @@
+#include "core_functions/scalar/string_functions.hpp"
+
+#include "duckdb/common/exception.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/common/vector_operations/unary_executor.hpp"
+#include "utf8proc_wrapper.hpp"
+
+#include <string.h>
+
+namespace duckdb {
+
+//! Fast ASCII string reverse, returns false if the input data is not ascii
+static bool StrReverseASCII(const char *input, idx_t n, char *output) {
+for (idx_t i = 0; i < n; i++) {
+if (input[i] & 0x80) {
+// non-ascii character
+return false;
+}
+output[n - i - 1] = input[i];
+}
+return true;
+}
+
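+// Grapheme note (illustrative): the non-ASCII path below reverses whole grapheme
+// clusters rather than bytes or codepoints, so a combining sequence such as 'e'
+// followed by U+0301 (combining acute) stays attached to its base after the reverse.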
+//! Unicode string reverse using grapheme clusters
+static void StrReverseUnicode(const char *input, idx_t n, char *output) {
+for (auto cluster : Utf8Proc::GraphemeClusters(input, n)) {
+memcpy(output + n - cluster.end, input + cluster.start, cluster.end - cluster.start);
+}
+}
+
+struct ReverseOperator {
+template <class INPUT_TYPE, class RESULT_TYPE>
+static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) {
+auto input_data = input.GetData();
+auto input_length = input.GetSize();
+
+auto target = StringVector::EmptyString(result, input_length);
+auto target_data = target.GetDataWriteable();
+if (!StrReverseASCII(input_data, input_length, target_data)) {
+StrReverseUnicode(input_data, input_length, target_data);
+}
+target.Finalize();
+return target;
+}
+};
+
+static void ReverseFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+UnaryExecutor::ExecuteString<string_t, string_t, ReverseOperator>(args.data[0], result, args.size());
+}
+
+ScalarFunction ReverseFun::GetFunction() {
+return ScalarFunction("reverse", {LogicalType::VARCHAR}, LogicalType::VARCHAR, ReverseFunction);
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp b/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp
new file mode 100644
index 00000000..7ef27729
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp
@@ -0,0 +1,46 @@
+#include "core_functions/scalar/string_functions.hpp"
+
+#include "duckdb/common/exception.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/planner/expression/bound_function_expression.hpp"
+
+namespace duckdb {
+
+static bool StartsWith(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle,
+idx_t needle_size) {
+D_ASSERT(needle_size > 0);
+if (needle_size > haystack_size) {
+// needle is bigger than haystack: haystack cannot start with needle
+return false;
+}
+return memcmp(haystack, needle, needle_size) == 0;
+}
+
+static bool StartsWith(const string_t &haystack_s, const string_t &needle_s) {
+auto haystack = const_uchar_ptr_cast(haystack_s.GetData());
+auto haystack_size = haystack_s.GetSize();
+auto needle = const_uchar_ptr_cast(needle_s.GetData());
+auto needle_size = needle_s.GetSize();
+if (needle_size == 0) {
+// empty needle: always true
+return true;
+}
+return StartsWith(haystack, haystack_size, needle, needle_size);
+}
+
+struct StartsWithOperator {
+template <class TA, class TB, class TR>
+static inline TR Operation(TA left, TB right) {
+return StartsWith(left, right);
+}
+};
+
+ScalarFunction StartsWithOperatorFun::GetFunction() {
+ScalarFunction starts_with({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN,
+ScalarFunction::BinaryFunction<string_t, string_t, bool, StartsWithOperator>);
+starts_with.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS;
+return starts_with;
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/string/to_base.cpp b/src/duckdb/extension/core_functions/scalar/string/to_base.cpp
new file mode 100644
index 00000000..f85f54be
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/string/to_base.cpp
@@ -0,0 +1,66 @@
+#include "core_functions/scalar/string_functions.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/planner/expression/bound_constant_expression.hpp"
+
+namespace duckdb {
+
+static const char alphabet[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+
+static unique_ptr<FunctionData> ToBaseBind(ClientContext &context, ScalarFunction &bound_function,
+vector<unique_ptr<Expression>> &arguments) {
+// If no min_length is specified, default to 0
+D_ASSERT(arguments.size() == 2 || arguments.size() == 3);
+if (arguments.size() == 2) {
+arguments.push_back(make_uniq_base<Expression, BoundConstantExpression>(Value::INTEGER(0)));
+}
+return nullptr;
+}
+
+static void ToBaseFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+auto &input = args.data[0];
+auto &radix = args.data[1];
+auto &min_length = args.data[2];
+auto count = args.size();
+
+TernaryExecutor::Execute<int64_t, int32_t, int32_t, string_t>(
+input, radix, min_length, result, count, [&](int64_t input, int32_t radix, int32_t min_length) {
+if (input < 0) {
+throw InvalidInputException("'to_base' number must be greater than or equal to 0");
+}
+if (radix < 2 || radix > 36) {
+throw InvalidInputException("'to_base' radix must be between 2 and 36");
+}
+if (min_length > 64 || min_length < 0) {
+throw InvalidInputException("'to_base' min_length must be between 0 and 64");
+}
+
+char buf[64];
+char *end = buf + sizeof(buf);
+char *ptr = end;
+do {
+*--ptr = alphabet[input % radix];
+input /= radix;
+} while (input > 0);
+
+auto length = end - ptr;
+while (length < min_length) {
+*--ptr = '0';
+length++;
+}
+
+return StringVector::AddString(result, ptr, UnsafeNumericCast<idx_t>(end - ptr));
+});
+}
+
+ScalarFunctionSet ToBaseFun::GetFunctions() {
+ScalarFunctionSet set("to_base");
+
+set.AddFunction(
+ScalarFunction({LogicalType::BIGINT, LogicalType::INTEGER}, LogicalType::VARCHAR, ToBaseFunction, ToBaseBind));
+set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::INTEGER, LogicalType::INTEGER},
+LogicalType::VARCHAR, ToBaseFunction, ToBaseBind));
+
+return set;
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/string/translate.cpp b/src/duckdb/extension/core_functions/scalar/string/translate.cpp
new file mode 100644
index 00000000..ca661cb3
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/string/translate.cpp
@@ -0,0 +1,96 @@
+#include "core_functions/scalar/string_functions.hpp"
+
+#include "duckdb/common/exception.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/common/vector_operations/ternary_executor.hpp"
+#include "utf8proc.hpp"
+#include "utf8proc_wrapper.hpp"
+
+#include <string.h>
+#include <ctype.h>
+#include <unordered_map>
+#include <unordered_set>
+
+namespace duckdb {
+
+static string_t TranslateScalarFunction(const string_t &haystack, const string_t &needle, const string_t &thread,
+vector<char> &result) {
+// Get information about the haystack, the needle and the "thread"
+auto input_haystack = haystack.GetData();
+auto size_haystack = haystack.GetSize();
+
+auto input_needle = needle.GetData();
+auto size_needle = needle.GetSize();
+
+auto input_thread = thread.GetData();
+auto size_thread = thread.GetSize();
+
+// Reuse the buffer
+result.clear();
+result.reserve(size_haystack);
+
+idx_t i = 0, j = 0;
+int sz = 0, c_sz = 0;
+
+// Characters to be replaced
+unordered_map<int32_t, int32_t> to_replace;
+while (i < size_needle && j < size_thread) {
+auto codepoint_needle = Utf8Proc::UTF8ToCodepoint(input_needle, sz);
+input_needle += sz;
+i += UnsafeNumericCast<idx_t>(sz);
+auto codepoint_thread = Utf8Proc::UTF8ToCodepoint(input_thread, sz);
+input_thread += sz;
+j += UnsafeNumericCast<idx_t>(sz);
+// Ignore a unicode character that already exists in to_replace
+if (to_replace.count(codepoint_needle) == 0) {
+to_replace[codepoint_needle] = codepoint_thread;
+}
+}
+
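+// Semantics note: needle characters with no counterpart left in the replacement
+// string are deleted rather than mapped, mirroring SQL translate(); e.g.
+// translate('12345', '143', 'ax') yields 'a2x5' ('1'->'a', '4'->'x', '3' dropped).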
+// Characters to be deleted
+unordered_set<int32_t> to_delete;
+while (i < size_needle) {
+auto codepoint_needle = Utf8Proc::UTF8ToCodepoint(input_needle, sz);
+input_needle += sz;
+i += UnsafeNumericCast<idx_t>(sz);
+// Add unicode character that will be deleted
+if (to_replace.count(codepoint_needle) == 0) {
+to_delete.insert(codepoint_needle);
+}
+}
+
+char c[5] = {'\0', '\0', '\0', '\0', '\0'};
+for (i = 0; i < size_haystack; i += UnsafeNumericCast<idx_t>(sz)) {
+auto codepoint_haystack = Utf8Proc::UTF8ToCodepoint(input_haystack, sz);
+if (to_replace.count(codepoint_haystack) != 0) {
+Utf8Proc::CodepointToUtf8(to_replace[codepoint_haystack], c_sz, c);
+result.insert(result.end(), c, c + c_sz);
+} else if (to_delete.count(codepoint_haystack) == 0) {
+result.insert(result.end(), input_haystack, input_haystack + sz);
+}
+input_haystack += sz;
+}
+
+return string_t(result.data(), UnsafeNumericCast<uint32_t>(result.size()));
+}
+
+static void TranslateFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+auto &haystack_vector = args.data[0];
+auto &needle_vector = args.data[1];
+auto &thread_vector = args.data[2];
+
+vector<char> buffer;
+TernaryExecutor::Execute<string_t, string_t, string_t, string_t>(
+haystack_vector, needle_vector, thread_vector, result, args.size(),
+[&](string_t input_string, string_t needle_string, string_t thread_string) {
+return StringVector::AddString(result,
+TranslateScalarFunction(input_string, needle_string, thread_string, buffer));
+});
+}
+
+ScalarFunction TranslateFun::GetFunction() {
+return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR,
+TranslateFunction);
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/string/trim.cpp b/src/duckdb/extension/core_functions/scalar/string/trim.cpp
new file mode 100644
index 00000000..5553d75e
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/string/trim.cpp
@@ -0,0 +1,158 @@
+#include "core_functions/scalar/string_functions.hpp"
+
+#include "duckdb/common/exception.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/common/vector_operations/unary_executor.hpp"
+#include "utf8proc.hpp"
+
+#include <string.h>
+
+namespace duckdb {
+
+template <bool LTRIM, bool RTRIM>
+struct TrimOperator {
+template <class INPUT_TYPE, class RESULT_TYPE>
+static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) {
+auto data = input.GetData();
+auto size = input.GetSize();
+
+utf8proc_int32_t codepoint;
+auto str = reinterpret_cast<const utf8proc_uint8_t *>(data);
+
+// Find the first character that is not left trimmed
+idx_t begin = 0;
+if (LTRIM) {
+while (begin < size) {
+auto bytes =
+utf8proc_iterate(str + begin, UnsafeNumericCast<utf8proc_ssize_t>(size - begin), &codepoint);
+D_ASSERT(bytes > 0);
+if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) {
+break;
+}
+begin += UnsafeNumericCast<idx_t>(bytes);
+}
+}
+
+// Find the last character that is not right trimmed
+idx_t end;
+if (RTRIM) {
+end = begin;
+for (auto next = begin; next < size;) {
+auto bytes = utf8proc_iterate(str + next, UnsafeNumericCast<utf8proc_ssize_t>(size - next), &codepoint);
+D_ASSERT(bytes > 0);
+next += UnsafeNumericCast<idx_t>(bytes);
+if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) {
+end = next;
+}
+}
+} else {
+end = size;
+}
+
+// Copy the trimmed string
+auto target = StringVector::EmptyString(result, end - begin);
+auto output = target.GetDataWriteable();
+memcpy(output, data + begin, end - begin);
+
+target.Finalize();
+return target;
+}
+};
+
+template <bool LTRIM, bool RTRIM>
+static void UnaryTrimFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+UnaryExecutor::ExecuteString<string_t, string_t, TrimOperator<LTRIM, RTRIM>>(args.data[0], result, args.size());
+}
+
+static void GetIgnoredCodepoints(string_t ignored, unordered_set<utf8proc_int32_t> &ignored_codepoints) {
+auto dataptr = reinterpret_cast<const utf8proc_uint8_t *>(ignored.GetData());
+auto size = ignored.GetSize();
+idx_t pos = 0;
+while (pos < size) {
+utf8proc_int32_t codepoint;
+pos += UnsafeNumericCast<idx_t>(
+utf8proc_iterate(dataptr + pos, UnsafeNumericCast<utf8proc_ssize_t>(size - pos), &codepoint));
+ignored_codepoints.insert(codepoint);
+}
+}
+
+template <bool LTRIM, bool RTRIM>
+static void BinaryTrimFunction(DataChunk &input, ExpressionState &state, Vector &result) {
+BinaryExecutor::Execute<string_t, string_t, string_t>(
+input.data[0], input.data[1], result, input.size(), [&](string_t input, string_t ignored) {
+auto data = input.GetData();
+auto size = input.GetSize();
+
+unordered_set<utf8proc_int32_t> ignored_codepoints;
+GetIgnoredCodepoints(ignored, ignored_codepoints);
+
+utf8proc_int32_t codepoint;
+auto str = reinterpret_cast<const utf8proc_uint8_t *>(data);
+
+// Find the first character that is not left trimmed
+idx_t begin = 0;
+if (LTRIM) {
+while (begin < size) {
+auto bytes =
+utf8proc_iterate(str + begin, UnsafeNumericCast<utf8proc_ssize_t>(size - begin), &codepoint);
+if (ignored_codepoints.find(codepoint) == ignored_codepoints.end()) {
+break;
+}
+begin += UnsafeNumericCast<idx_t>(bytes);
+}
+}
+
+// Find the last character that is not right trimmed
+idx_t end;
+if (RTRIM) {
+end = begin;
+for (auto next = begin; next < size;) {
+auto bytes =
+utf8proc_iterate(str + next, UnsafeNumericCast<utf8proc_ssize_t>(size - next), &codepoint);
+D_ASSERT(bytes > 0);
+next += UnsafeNumericCast<idx_t>(bytes);
+if (ignored_codepoints.find(codepoint) == ignored_codepoints.end()) {
+end = next;
+}
+}
+} else {
+end = size;
+}
+
+// Copy the trimmed string
+auto target = StringVector::EmptyString(result, end - begin);
+auto output = target.GetDataWriteable();
+memcpy(output, data + begin, end - begin);
+
+target.Finalize();
+return target;
+});
+}
+
+ScalarFunctionSet TrimFun::GetFunctions() {
+ScalarFunctionSet trim;
+trim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction<true, true>));
+trim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR,
+BinaryTrimFunction<true, true>));
+return trim;
+}
+
+ScalarFunctionSet LtrimFun::GetFunctions() {
+ScalarFunctionSet ltrim;
+ltrim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction<true, false>));
+ltrim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR,
+BinaryTrimFunction<true, false>));
+return ltrim;
+}
+
+ScalarFunctionSet RtrimFun::GetFunctions() {
+ScalarFunctionSet rtrim;
+rtrim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction<false, true>));
+rtrim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR,
+BinaryTrimFunction<false, true>));
+return rtrim;
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/string/unicode.cpp b/src/duckdb/extension/core_functions/scalar/string/unicode.cpp
new file mode 100644
index 00000000..902c7c5e
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/string/unicode.cpp
@@ -0,0 +1,28 @@
+#include "core_functions/scalar/string_functions.hpp"
+
+#include "duckdb/common/exception.hpp"
+#include "duckdb/common/vector_operations/vector_operations.hpp"
+#include "duckdb/common/vector_operations/unary_executor.hpp"
+#include "utf8proc.hpp"
+
+#include <string.h>
+
+namespace duckdb {
+
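+// Behavior note (hedged): only the first UTF-8 sequence of the argument is decoded,
+// so e.g. unicode('ü') returns 252 (U+00FC); for an empty string utf8proc_iterate
+// reports an error and leaves -1 in the codepoint, which is then returned as-is.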
+struct UnicodeOperator {
+template <class TA, class TR>
+static inline TR Operation(const TA &input) {
+auto str = reinterpret_cast<const utf8proc_uint8_t *>(input.GetData());
+auto len = input.GetSize();
+utf8proc_int32_t codepoint;
+(void)utf8proc_iterate(str, UnsafeNumericCast<utf8proc_ssize_t>(len), &codepoint);
+return codepoint;
+}
+};
+
+ScalarFunction UnicodeFun::GetFunction() {
+return ScalarFunction({LogicalType::VARCHAR}, LogicalType::INTEGER,
+ScalarFunction::UnaryFunction<string_t, int32_t, UnicodeOperator>);
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/string/url_encode.cpp b/src/duckdb/extension/core_functions/scalar/string/url_encode.cpp
new file mode 100644
index 00000000..17b9ad3c
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/string/url_encode.cpp
@@ -0,0 +1,49 @@
+#include "duckdb/common/vector_operations/unary_executor.hpp"
+#include "core_functions/scalar/string_functions.hpp"
+#include "duckdb/common/string_util.hpp"
+
+namespace duckdb {
+
+struct URLEncodeOperator {
+template <class INPUT_TYPE, class RESULT_TYPE>
+static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) {
+auto input_str = input.GetData();
+auto input_size = input.GetSize();
+idx_t result_length = StringUtil::URLEncodeSize(input_str, input_size);
+auto result_str = StringVector::EmptyString(result, result_length);
+StringUtil::URLEncodeBuffer(input_str, input_size, result_str.GetDataWriteable());
+result_str.Finalize();
+return result_str;
+}
+};
+
+static void URLEncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+UnaryExecutor::ExecuteString<string_t, string_t, URLEncodeOperator>(args.data[0], result, args.size());
+}
+
+ScalarFunction UrlEncodeFun::GetFunction() {
+return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, URLEncodeFunction);
+}
+
+struct URLDecodeOperator {
+template <class INPUT_TYPE, class RESULT_TYPE>
+static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) {
+auto input_str = input.GetData();
+auto input_size = input.GetSize();
+idx_t result_length = StringUtil::URLDecodeSize(input_str, input_size);
+auto result_str = StringVector::EmptyString(result, result_length);
+StringUtil::URLDecodeBuffer(input_str, input_size, result_str.GetDataWriteable());
+result_str.Finalize();
+return result_str;
+}
+};
+
+static void URLDecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+UnaryExecutor::ExecuteString<string_t, string_t, URLDecodeOperator>(args.data[0], result, args.size());
+}
+
+ScalarFunction UrlDecodeFun::GetFunction() {
+return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, URLDecodeFunction);
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp
new file mode 100644
index 00000000..c83a83e3
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp
@@ -0,0 +1,103 @@
+#include "core_functions/scalar/struct_functions.hpp"
+#include "duckdb/planner/expression/bound_function_expression.hpp"
+#include "duckdb/common/string_util.hpp"
+#include "duckdb/parser/expression/bound_expression.hpp"
+#include "duckdb/function/scalar/nested_functions.hpp"
+#include "duckdb/common/case_insensitive_map.hpp"
+#include "duckdb/storage/statistics/struct_stats.hpp"
+#include "duckdb/planner/expression_binder.hpp"
+
+namespace duckdb {
+
+static void StructInsertFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+auto &starting_vec = args.data[0];
+starting_vec.Verify(args.size());
+
+auto &starting_child_entries = StructVector::GetEntries(starting_vec);
+auto &result_child_entries = StructVector::GetEntries(result);
+
+// Assign the original child entries to the STRUCT.
+for (idx_t i = 0; i < starting_child_entries.size(); i++) {
+auto &starting_child = starting_child_entries[i];
+result_child_entries[i]->Reference(*starting_child);
+}
+
+// Assign the new children to the result vector.
+for (idx_t i = 1; i < args.ColumnCount(); i++) {
+result_child_entries[starting_child_entries.size() + i - 1]->Reference(args.data[i]);
+}
+
+result.Verify(args.size());
+if (args.AllConstant()) {
+result.SetVectorType(VectorType::CONSTANT_VECTOR);
+}
+}
+
+static unique_ptr<FunctionData> StructInsertBind(ClientContext &context, ScalarFunction &bound_function,
+vector<unique_ptr<Expression>> &arguments) {
+if (arguments.empty()) {
+throw InvalidInputException("Missing required arguments for struct_insert function.");
+}
+if (LogicalTypeId::STRUCT != arguments[0]->return_type.id()) {
+throw InvalidInputException("The first argument to struct_insert must be a STRUCT");
+}
+if (arguments.size() < 2) {
+throw InvalidInputException("Can't insert nothing into a STRUCT");
+}
+
+case_insensitive_set_t name_collision_set;
+child_list_t<LogicalType> new_children;
+auto &existing_children = StructType::GetChildTypes(arguments[0]->return_type);
+
+for (idx_t i = 0; i < existing_children.size(); i++) {
+auto &child = existing_children[i];
+name_collision_set.insert(child.first);
+new_children.push_back(make_pair(child.first, child.second));
+}
+
+// Loop through the additional arguments (name/value pairs)
+for (idx_t i = 1; i < arguments.size(); i++) {
+auto &child = arguments[i];
+if (child->GetAlias().empty()) {
+throw BinderException("Need named argument for struct insert, e.g., a := b");
+}
+if (name_collision_set.find(child->GetAlias()) != name_collision_set.end()) {
+throw BinderException("Duplicate struct entry name \"%s\"", child->GetAlias());
+}
+name_collision_set.insert(child->GetAlias());
+new_children.push_back(make_pair(child->GetAlias(), arguments[i]->return_type));
+}
+
+bound_function.return_type = LogicalType::STRUCT(new_children);
+return make_uniq<VariableReturnBindData>(bound_function.return_type);
+}
+
+unique_ptr<BaseStatistics> StructInsertStats(ClientContext &context, FunctionStatisticsInput &input) {
+auto &child_stats = input.child_stats;
+auto &expr = input.expr;
+auto new_stats = StructStats::CreateUnknown(expr.return_type);
+
+auto existing_count = StructType::GetChildCount(child_stats[0].GetType());
+auto existing_stats = StructStats::GetChildStats(child_stats[0]);
+for (idx_t i = 0; i < existing_count; i++) {
+StructStats::SetChildStats(new_stats, i, existing_stats[i]);
+}
+
+auto new_count = StructType::GetChildCount(expr.return_type);
+auto offset = new_count - child_stats.size();
+for (idx_t i = 1; i < child_stats.size(); i++) {
+StructStats::SetChildStats(new_stats, offset + i, child_stats[i]);
+}
+return new_stats.ToUnique();
+}
+
+ScalarFunction StructInsertFun::GetFunction() {
+ScalarFunction fun({}, LogicalTypeId::STRUCT, StructInsertFunction, StructInsertBind, nullptr, StructInsertStats);
+fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+fun.varargs = LogicalType::ANY;
+fun.serialize = VariableReturnBindData::Serialize;
+fun.deserialize = VariableReturnBindData::Deserialize;
+return fun;
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp b/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp
new file mode 100644
index 00000000..2a537107
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp
@@ -0,0 +1,108 @@
+#include "core_functions/scalar/union_functions.hpp"
+#include "duckdb/common/string_util.hpp"
+#include "duckdb/execution/expression_executor.hpp"
+#include "duckdb/planner/expression/bound_function_expression.hpp"
+#include "duckdb/planner/expression/bound_parameter_expression.hpp"
+
+namespace duckdb {
+
+struct UnionExtractBindData : public FunctionData {
+UnionExtractBindData(string key, idx_t index, LogicalType type)
+: key(std::move(key)), index(index), type(std::move(type)) {
+}
+
+string key;
+idx_t index;
+LogicalType type;
+
+public:
+unique_ptr<FunctionData> Copy() const override {
+return make_uniq<UnionExtractBindData>(key, index, type);
+}
+bool Equals(const FunctionData &other_p) const override {
+auto &other = other_p.Cast<UnionExtractBindData>();
+return key == other.key && index == other.index && type == other.type;
+}
+};
+
+static void UnionExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+auto &func_expr = state.expr.Cast<BoundFunctionExpression>();
+auto &info = func_expr.bind_info->Cast<UnionExtractBindData>();
+
+// this should be guaranteed by the binder
+auto &vec = args.data[0];
+vec.Verify(args.size());
+
+D_ASSERT(info.index < UnionType::GetMemberCount(vec.GetType()));
+auto &member = UnionVector::GetMember(vec, info.index);
+result.Reference(member);
+result.Verify(args.size());
+}
+
+static unique_ptr<FunctionData> UnionExtractBind(ClientContext &context, ScalarFunction &bound_function,
+vector<unique_ptr<Expression>> &arguments) {
+D_ASSERT(bound_function.arguments.size() == 2);
+if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) {
+throw ParameterNotResolvedException();
+}
+if (arguments[0]->return_type.id() != LogicalTypeId::UNION) {
+throw BinderException("union_extract can only take a union parameter");
+}
+idx_t union_member_count = UnionType::GetMemberCount(arguments[0]->return_type);
+if (union_member_count == 0) {
+throw InternalException("Can't extract something from an empty union");
+}
+bound_function.arguments[0] = arguments[0]->return_type;
+
+auto &key_child = arguments[1];
+if (key_child->HasParameter()) {
+throw ParameterNotResolvedException();
+}
+
+if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) {
+throw BinderException("Key name for union_extract needs to be a constant string");
+}
+Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child);
+D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR);
+auto &key_str = StringValue::Get(key_val);
+if (key_val.IsNull() || key_str.empty()) {
+throw BinderException("Key name for union_extract needs to be neither NULL nor empty");
+}
+string key = StringUtil::Lower(key_str);
+
+LogicalType return_type;
+idx_t key_index = 0;
+bool found_key = false;
+
+for (size_t i = 0; i < union_member_count; i++) {
+auto &member_name = UnionType::GetMemberName(arguments[0]->return_type, i);
+if (StringUtil::Lower(member_name) == key) {
+found_key = true;
+key_index = i;
+return_type = UnionType::GetMemberType(arguments[0]->return_type, i);
+break;
+}
+}
+
+if (!found_key) {
+vector<string> candidates;
+candidates.reserve(union_member_count);
+for (idx_t i = 0; i < union_member_count; i++) {
+candidates.push_back(UnionType::GetMemberName(arguments[0]->return_type, i));
+}
+auto closest_settings = StringUtil::TopNJaroWinkler(candidates, key);
+auto message = StringUtil::CandidatesMessage(closest_settings, "Candidate Entries");
+throw BinderException("Could not find key \"%s\" in union\n%s", key, message);
+}
+
+bound_function.return_type = return_type;
+return make_uniq<UnionExtractBindData>(key, key_index, return_type);
+}
+
+ScalarFunction UnionExtractFun::GetFunction() {
+// the arguments and return types are actually set in the binder function
+return ScalarFunction({LogicalTypeId::UNION, LogicalType::VARCHAR}, LogicalType::ANY, UnionExtractFunction,
+UnionExtractBind, nullptr, nullptr);
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp b/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp
new file mode 100644
index 00000000..173e36d6
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp
@@ -0,0 +1,58 @@
+#include "core_functions/scalar/union_functions.hpp"
+#include "duckdb/common/string_util.hpp"
+#include "duckdb/execution/expression_executor.hpp"
+#include "duckdb/planner/expression/bound_function_expression.hpp"
+#include "duckdb/planner/expression/bound_parameter_expression.hpp"
+
+namespace duckdb {
+
+static unique_ptr<FunctionData> UnionTagBind(ClientContext &context, ScalarFunction &bound_function,
+vector<unique_ptr<Expression>> &arguments) {
+
+if (arguments.empty()) {
+throw BinderException("Missing required arguments for union_tag function.");
+}
+
+if (LogicalTypeId::UNKNOWN == arguments[0]->return_type.id()) {
+throw ParameterNotResolvedException();
+}
+
+if (LogicalTypeId::UNION != arguments[0]->return_type.id()) {
+throw BinderException("First argument to union_tag function must be a union type.");
+}
+
+if (arguments.size() > 1) {
+throw BinderException("Too many arguments, union_tag takes at most one argument.");
+}
+
+auto member_count = UnionType::GetMemberCount(arguments[0]->return_type);
+if (member_count == 0) {
+// this should never happen, empty unions are not allowed
+throw InternalException("Can't get tags from an empty union");
+}
+
+bound_function.arguments[0] = arguments[0]->return_type;
+
+auto varchar_vector = Vector(LogicalType::VARCHAR, member_count);
+for (idx_t i = 0; i < member_count; i++) {
+auto str = string_t(UnionType::GetMemberName(arguments[0]->return_type, i));
+FlatVector::GetData<string_t>(varchar_vector)[i] =
+str.IsInlined() ? str : StringVector::AddString(varchar_vector, str);
+}
+auto enum_type = LogicalType::ENUM(varchar_vector, member_count);
+bound_function.return_type = enum_type;
+
+return nullptr;
+}
+
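+// Zero-copy note: the ENUM built in the bind above lists the members in declaration
+// order, so the function below can reinterpret the union's physical tag buffer
+// directly as that ENUM; e.g. for UNION(num INTEGER, str VARCHAR) a stored tag of 1
+// reads back as 'str' (member names here are illustrative).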
+static void UnionTagFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+D_ASSERT(result.GetType().id() == LogicalTypeId::ENUM);
+result.Reinterpret(UnionVector::GetTags(args.data[0]));
+}
+
+ScalarFunction UnionTagFun::GetFunction() {
+return ScalarFunction({LogicalTypeId::UNION}, LogicalTypeId::ANY, UnionTagFunction, UnionTagBind, nullptr,
+nullptr); // TODO: Statistics?
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/union/union_value.cpp b/src/duckdb/extension/core_functions/scalar/union/union_value.cpp
new file mode 100644
index 00000000..655003da
--- /dev/null
+++ b/src/duckdb/extension/core_functions/scalar/union/union_value.cpp
@@ -0,0 +1,68 @@
+#include "core_functions/scalar/union_functions.hpp"
+#include "duckdb/common/string_util.hpp"
+#include "duckdb/execution/expression_executor.hpp"
+#include "duckdb/function/scalar/nested_functions.hpp"
+#include "duckdb/planner/expression/bound_function_expression.hpp"
+#include "duckdb/planner/expression/bound_parameter_expression.hpp"
+
+namespace duckdb {
+
+struct UnionValueBindData : public FunctionData {
+UnionValueBindData() {
+}
+
+public:
+unique_ptr<FunctionData> Copy() const override {
+return make_uniq<UnionValueBindData>();
+}
+bool Equals(const FunctionData &other_p) const override {
+return true;
+}
+};
+
+static void UnionValueFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+// Assign the new entries to the result vector
+UnionVector::GetMember(result, 0).Reference(args.data[0]);
+
+// Set the result tag vector to a constant value
+auto &tag_vector = UnionVector::GetTags(result);
+tag_vector.SetVectorType(VectorType::CONSTANT_VECTOR);
+ConstantVector::GetData<union_tag_t>(tag_vector)[0] = 0;
+
+if (args.AllConstant()) {
+result.SetVectorType(VectorType::CONSTANT_VECTOR);
+}
+
+result.Verify(args.size());
+}
+
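+// Tag note: union_value produces a single-member union, so the constant tag written
+// above is always 0; e.g. union_value(num := 2) yields a UNION(num INTEGER) whose tag
+// buffer is all zeros (the member name here is illustrative).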
+static unique_ptr<FunctionData> UnionValueBind(ClientContext &context, ScalarFunction &bound_function,
+vector<unique_ptr<Expression>> &arguments) {
+
+if (arguments.size() != 1) {
+throw BinderException("union_value takes exactly one argument");
+}
+auto &child = arguments[0];
+
+if (child->GetAlias().empty()) {
+throw BinderException("Need named argument for union tag, e.g. UNION_VALUE(a := b)");
+}
+
+child_list_t<LogicalType> union_members;
+
+union_members.push_back(make_pair(child->GetAlias(), child->return_type));
+
+bound_function.return_type = LogicalType::UNION(std::move(union_members));
+return make_uniq<VariableReturnBindData>(bound_function.return_type);
+}
+
+ScalarFunction UnionValueFun::GetFunction() {
+ScalarFunction fun("union_value", {}, LogicalTypeId::UNION, UnionValueFunction, UnionValueBind, nullptr, nullptr);
+fun.varargs = LogicalType::ANY;
+fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+fun.serialize = VariableReturnBindData::Serialize;
+fun.deserialize = VariableReturnBindData::Deserialize;
+return fun;
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp b/src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp
new file mode 100644
index 00000000..ac1532ac
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp
@@ -0,0 +1,8 @@
+#include "extension/core_functions/aggregate/algebraic/corr.cpp"
+
+#include "extension/core_functions/aggregate/algebraic/stddev.cpp"
+
+#include "extension/core_functions/aggregate/algebraic/avg.cpp"
+
+#include "extension/core_functions/aggregate/algebraic/covar.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp b/src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp
new file mode 100644
index 00000000..22a9ac79
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp
@@ -0,0 +1,20 @@
+#include "extension/core_functions/aggregate/distributive/kurtosis.cpp"
+
+#include "extension/core_functions/aggregate/distributive/string_agg.cpp"
+
+#include "extension/core_functions/aggregate/distributive/sum.cpp"
+
+#include "extension/core_functions/aggregate/distributive/arg_min_max.cpp"
+
+#include "extension/core_functions/aggregate/distributive/approx_count.cpp"
+
+#include "extension/core_functions/aggregate/distributive/skew.cpp"
+
+#include "extension/core_functions/aggregate/distributive/bitagg.cpp"
+
+#include "extension/core_functions/aggregate/distributive/bitstring_agg.cpp"
+
+#include "extension/core_functions/aggregate/distributive/product.cpp"
+
+#include "extension/core_functions/aggregate/distributive/bool.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_aggregate_holistic.cpp b/src/duckdb/ub_extension_core_functions_aggregate_holistic.cpp
new file mode 100644
index 00000000..7ee6f047
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_aggregate_holistic.cpp
@@ -0,0 +1,12 @@
+#include "extension/core_functions/aggregate/holistic/approx_top_k.cpp"
+
+#include "extension/core_functions/aggregate/holistic/quantile.cpp"
+
+#include "extension/core_functions/aggregate/holistic/reservoir_quantile.cpp"
+
+#include "extension/core_functions/aggregate/holistic/mad.cpp"
+
+#include "extension/core_functions/aggregate/holistic/approximate_quantile.cpp"
+
+#include "extension/core_functions/aggregate/holistic/mode.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_aggregate_nested.cpp b/src/duckdb/ub_extension_core_functions_aggregate_nested.cpp
new file mode 100644
index 00000000..9d9f036b
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_aggregate_nested.cpp
@@ -0,0 +1,6 @@
+#include "extension/core_functions/aggregate/nested/binned_histogram.cpp"
+
+#include "extension/core_functions/aggregate/nested/list.cpp"
+
+#include "extension/core_functions/aggregate/nested/histogram.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_aggregate_regression.cpp b/src/duckdb/ub_extension_core_functions_aggregate_regression.cpp
new file mode 100644
index 00000000..a7d5acb1
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_aggregate_regression.cpp
@@ -0,0 +1,14 @@
+#include "extension/core_functions/aggregate/regression/regr_sxy.cpp"
+
+#include "extension/core_functions/aggregate/regression/regr_intercept.cpp"
+
+#include "extension/core_functions/aggregate/regression/regr_count.cpp"
+
+#include "extension/core_functions/aggregate/regression/regr_r2.cpp"
+
+#include "extension/core_functions/aggregate/regression/regr_avg.cpp"
+
+#include "extension/core_functions/aggregate/regression/regr_slope.cpp"
+
+#include "extension/core_functions/aggregate/regression/regr_sxx_syy.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_array.cpp b/src/duckdb/ub_extension_core_functions_scalar_array.cpp
new file mode 100644
index 00000000..e4f63a36
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_array.cpp
@@ -0,0 +1,4 @@
+#include "extension/core_functions/scalar/array/array_functions.cpp"
+
+#include "extension/core_functions/scalar/array/array_value.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_bit.cpp b/src/duckdb/ub_extension_core_functions_scalar_bit.cpp
new file mode 100644
index 00000000..0e48db86
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_bit.cpp
@@ -0,0 +1,2 @@
+#include "extension/core_functions/scalar/bit/bitstring.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_blob.cpp b/src/duckdb/ub_extension_core_functions_scalar_blob.cpp
new file mode 100644
index 00000000..1eda3bad
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_blob.cpp
@@ -0,0 +1,4 @@
+#include "extension/core_functions/scalar/blob/base64.cpp"
+
+#include "extension/core_functions/scalar/blob/encode.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_date.cpp b/src/duckdb/ub_extension_core_functions_scalar_date.cpp
new file mode 100644
index 00000000..614e5e4e
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_date.cpp
@@ -0,0 +1,20 @@
+#include "extension/core_functions/scalar/date/current.cpp"
+
+#include "extension/core_functions/scalar/date/age.cpp"
+
+#include "extension/core_functions/scalar/date/date_diff.cpp"
+
+#include "extension/core_functions/scalar/date/date_sub.cpp"
+
+#include "extension/core_functions/scalar/date/to_interval.cpp"
+
+#include "extension/core_functions/scalar/date/time_bucket.cpp"
+
+#include "extension/core_functions/scalar/date/date_trunc.cpp"
+
+#include "extension/core_functions/scalar/date/epoch.cpp"
+
+#include "extension/core_functions/scalar/date/date_part.cpp"
+
+#include "extension/core_functions/scalar/date/make_date.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_debug.cpp b/src/duckdb/ub_extension_core_functions_scalar_debug.cpp
new file mode 100644
index 00000000..f1c3fa82
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_debug.cpp
@@ -0,0 +1,2 @@
+#include "extension/core_functions/scalar/debug/vector_type.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_enum.cpp b/src/duckdb/ub_extension_core_functions_scalar_enum.cpp
new file mode 100644
index 00000000..74e9bf3f
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_enum.cpp
@@ -0,0 +1,2 @@
+#include "extension/core_functions/scalar/enum/enum_functions.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_generic.cpp b/src/duckdb/ub_extension_core_functions_scalar_generic.cpp
new file mode 100644
index 00000000..d2458089
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_generic.cpp
@@ -0,0 +1,18 @@
+#include "extension/core_functions/scalar/generic/alias.cpp"
+
+#include "extension/core_functions/scalar/generic/binning.cpp"
+
+#include "extension/core_functions/scalar/generic/can_implicitly_cast.cpp"
+
+#include "extension/core_functions/scalar/generic/current_setting.cpp"
+
+#include "extension/core_functions/scalar/generic/hash.cpp"
+
+#include "extension/core_functions/scalar/generic/least.cpp"
+
+#include "extension/core_functions/scalar/generic/stats.cpp"
+
+#include "extension/core_functions/scalar/generic/typeof.cpp"
+
+#include "extension/core_functions/scalar/generic/system_functions.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_list.cpp b/src/duckdb/ub_extension_core_functions_scalar_list.cpp
new file mode 100644
index 00000000..e3ad4275
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_list.cpp
@@ -0,0 +1,22 @@
+#include "extension/core_functions/scalar/list/flatten.cpp"
+
+#include "extension/core_functions/scalar/list/list_transform.cpp"
+
+#include "extension/core_functions/scalar/list/range.cpp"
+
+#include "extension/core_functions/scalar/list/list_value.cpp"
+
+#include "extension/core_functions/scalar/list/list_filter.cpp"
+
+#include "extension/core_functions/scalar/list/list_has_any_or_all.cpp"
+
+#include "extension/core_functions/scalar/list/list_aggregates.cpp"
+
+#include "extension/core_functions/scalar/list/list_distance.cpp"
+
+#include "extension/core_functions/scalar/list/array_slice.cpp"
+
+#include "extension/core_functions/scalar/list/list_sort.cpp"
+
+#include "extension/core_functions/scalar/list/list_reduce.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_map.cpp b/src/duckdb/ub_extension_core_functions_scalar_map.cpp
new file mode 100644
index 00000000..52bd226f
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_map.cpp
@@ -0,0 +1,14 @@
+#include "extension/core_functions/scalar/map/map_keys_values.cpp"
+
+#include "extension/core_functions/scalar/map/map_extract.cpp"
+
+#include "extension/core_functions/scalar/map/map_from_entries.cpp"
+
+#include "extension/core_functions/scalar/map/map_entries.cpp"
+
+#include "extension/core_functions/scalar/map/map.cpp"
+
+#include "extension/core_functions/scalar/map/map_concat.cpp"
+
+#include "extension/core_functions/scalar/map/cardinality.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_math.cpp b/src/duckdb/ub_extension_core_functions_scalar_math.cpp
new file mode 100644
index 00000000..27320ea9
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_math.cpp
@@ -0,0 +1,2 @@
+#include "extension/core_functions/scalar/math/numeric.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_operators.cpp b/src/duckdb/ub_extension_core_functions_scalar_operators.cpp
new file mode 100644
index 00000000..47383d4e
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_operators.cpp
@@ -0,0 +1,2 @@
+#include "extension/core_functions/scalar/operators/bitwise.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_random.cpp b/src/duckdb/ub_extension_core_functions_scalar_random.cpp
new file mode 100644
index 00000000..f71b7b4c
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_random.cpp
@@ -0,0 +1,4 @@
+#include "extension/core_functions/scalar/random/random.cpp"
+
+#include "extension/core_functions/scalar/random/setseed.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_string.cpp b/src/duckdb/ub_extension_core_functions_scalar_string.cpp
new file mode 100644
index 00000000..f01d70e2
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_string.cpp
@@ -0,0 +1,48 @@
+#include "extension/core_functions/scalar/string/starts_with.cpp"
+
+#include "extension/core_functions/scalar/string/jaccard.cpp"
+
+#include "extension/core_functions/scalar/string/levenshtein.cpp"
+
+#include "extension/core_functions/scalar/string/damerau_levenshtein.cpp"
+
+#include "extension/core_functions/scalar/string/bar.cpp"
+
+#include "extension/core_functions/scalar/string/printf.cpp"
+
+#include "extension/core_functions/scalar/string/replace.cpp"
+
+#include "extension/core_functions/scalar/string/hamming.cpp"
+
+#include "extension/core_functions/scalar/string/instr.cpp"
+
+#include "extension/core_functions/scalar/string/ascii.cpp"
+
+#include "extension/core_functions/scalar/string/reverse.cpp"
+
+#include "extension/core_functions/scalar/string/url_encode.cpp"
+
+#include "extension/core_functions/scalar/string/parse_path.cpp"
+
+#include "extension/core_functions/scalar/string/left_right.cpp"
+
+#include "extension/core_functions/scalar/string/to_base.cpp"
+
+#include "extension/core_functions/scalar/string/pad.cpp"
+
+#include "extension/core_functions/scalar/string/trim.cpp"
+
+#include "extension/core_functions/scalar/string/format_bytes.cpp"
+
+#include "extension/core_functions/scalar/string/hex.cpp"
+
+#include "extension/core_functions/scalar/string/repeat.cpp"
+
+#include "extension/core_functions/scalar/string/translate.cpp"
+
+#include "extension/core_functions/scalar/string/chr.cpp"
+
+#include "extension/core_functions/scalar/string/unicode.cpp"
+
+#include "extension/core_functions/scalar/string/jaro_winkler.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_struct.cpp b/src/duckdb/ub_extension_core_functions_scalar_struct.cpp
new file mode 100644
index 00000000..8d52b2f9
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_struct.cpp
@@ -0,0 +1,2 @@
+#include "extension/core_functions/scalar/struct/struct_insert.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_union.cpp b/src/duckdb/ub_extension_core_functions_scalar_union.cpp
new file mode 100644
index 00000000..fad24f29
--- /dev/null
+++ b/src/duckdb/ub_extension_core_functions_scalar_union.cpp
@@ -0,0 +1,6 @@
+#include "extension/core_functions/scalar/union/union_extract.cpp"
+
+#include "extension/core_functions/scalar/union/union_value.cpp"
+
+#include "extension/core_functions/scalar/union/union_tag.cpp"
+
diff --git a/vendor.py b/vendor.py
index cb8283cd..c9b10461 100644
--- a/vendor.py
+++ b/vendor.py
@@ -15,7 +15,7 @@
 # list of extensions to bundle
-extensions = ['parquet', 'icu', 'json']
+extensions = ['parquet', 'icu', 'json', 'core_functions']
 # path to target
 basedir = os.getcwd()