-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
165 changed files
with
28,058 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
196 changes: 196 additions & 0 deletions
196
src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
#include "core_functions/aggregate/algebraic_functions.hpp" | ||
#include "core_functions/aggregate/sum_helpers.hpp" | ||
#include "duckdb/common/types/hugeint.hpp" | ||
#include "duckdb/common/exception.hpp" | ||
#include "duckdb/function/function_set.hpp" | ||
#include "duckdb/planner/expression.hpp" | ||
|
||
namespace duckdb { | ||
|
||
template <class T> | ||
struct AvgState { | ||
uint64_t count; | ||
T value; | ||
|
||
void Initialize() { | ||
this->count = 0; | ||
} | ||
|
||
void Combine(const AvgState<T> &other) { | ||
this->count += other.count; | ||
this->value += other.value; | ||
} | ||
}; | ||
|
||
struct KahanAvgState { | ||
uint64_t count; | ||
double value; | ||
double err; | ||
|
||
void Initialize() { | ||
this->count = 0; | ||
this->err = 0.0; | ||
} | ||
|
||
void Combine(const KahanAvgState &other) { | ||
this->count += other.count; | ||
KahanAddInternal(other.value, this->value, this->err); | ||
KahanAddInternal(other.err, this->value, this->err); | ||
} | ||
}; | ||
|
||
struct AverageDecimalBindData : public FunctionData { | ||
explicit AverageDecimalBindData(double scale) : scale(scale) { | ||
} | ||
|
||
double scale; | ||
|
||
public: | ||
unique_ptr<FunctionData> Copy() const override { | ||
return make_uniq<AverageDecimalBindData>(scale); | ||
}; | ||
|
||
bool Equals(const FunctionData &other_p) const override { | ||
auto &other = other_p.Cast<AverageDecimalBindData>(); | ||
return scale == other.scale; | ||
} | ||
}; | ||
|
||
struct AverageSetOperation { | ||
template <class STATE> | ||
static void Initialize(STATE &state) { | ||
state.Initialize(); | ||
} | ||
template <class STATE> | ||
static void Combine(const STATE &source, STATE &target, AggregateInputData &) { | ||
target.Combine(source); | ||
} | ||
template <class STATE> | ||
static void AddValues(STATE &state, idx_t count) { | ||
state.count += count; | ||
} | ||
}; | ||
|
||
template <class T> | ||
static T GetAverageDivident(uint64_t count, optional_ptr<FunctionData> bind_data) { | ||
T divident = T(count); | ||
if (bind_data) { | ||
auto &avg_bind_data = bind_data->Cast<AverageDecimalBindData>(); | ||
divident *= avg_bind_data.scale; | ||
} | ||
return divident; | ||
} | ||
|
||
struct IntegerAverageOperation : public BaseSumOperation<AverageSetOperation, RegularAdd> { | ||
template <class T, class STATE> | ||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { | ||
if (state.count == 0) { | ||
finalize_data.ReturnNull(); | ||
} else { | ||
double divident = GetAverageDivident<double>(state.count, finalize_data.input.bind_data); | ||
target = double(state.value) / divident; | ||
} | ||
} | ||
}; | ||
|
||
struct IntegerAverageOperationHugeint : public BaseSumOperation<AverageSetOperation, AddToHugeint> { | ||
template <class T, class STATE> | ||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { | ||
if (state.count == 0) { | ||
finalize_data.ReturnNull(); | ||
} else { | ||
long double divident = GetAverageDivident<long double>(state.count, finalize_data.input.bind_data); | ||
target = Hugeint::Cast<long double>(state.value) / divident; | ||
} | ||
} | ||
}; | ||
|
||
struct HugeintAverageOperation : public BaseSumOperation<AverageSetOperation, HugeintAdd> { | ||
template <class T, class STATE> | ||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { | ||
if (state.count == 0) { | ||
finalize_data.ReturnNull(); | ||
} else { | ||
long double divident = GetAverageDivident<long double>(state.count, finalize_data.input.bind_data); | ||
target = Hugeint::Cast<long double>(state.value) / divident; | ||
} | ||
} | ||
}; | ||
|
||
struct NumericAverageOperation : public BaseSumOperation<AverageSetOperation, RegularAdd> { | ||
template <class T, class STATE> | ||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { | ||
if (state.count == 0) { | ||
finalize_data.ReturnNull(); | ||
} else { | ||
target = state.value / state.count; | ||
} | ||
} | ||
}; | ||
|
||
struct KahanAverageOperation : public BaseSumOperation<AverageSetOperation, KahanAdd> { | ||
template <class T, class STATE> | ||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { | ||
if (state.count == 0) { | ||
finalize_data.ReturnNull(); | ||
} else { | ||
target = (state.value / state.count) + (state.err / state.count); | ||
} | ||
} | ||
}; | ||
|
||
AggregateFunction GetAverageAggregate(PhysicalType type) { | ||
switch (type) { | ||
case PhysicalType::INT16: { | ||
return AggregateFunction::UnaryAggregate<AvgState<int64_t>, int16_t, double, IntegerAverageOperation>( | ||
LogicalType::SMALLINT, LogicalType::DOUBLE); | ||
} | ||
case PhysicalType::INT32: { | ||
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int32_t, double, IntegerAverageOperationHugeint>( | ||
LogicalType::INTEGER, LogicalType::DOUBLE); | ||
} | ||
case PhysicalType::INT64: { | ||
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, double, IntegerAverageOperationHugeint>( | ||
LogicalType::BIGINT, LogicalType::DOUBLE); | ||
} | ||
case PhysicalType::INT128: { | ||
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, hugeint_t, double, HugeintAverageOperation>( | ||
LogicalType::HUGEINT, LogicalType::DOUBLE); | ||
} | ||
default: | ||
throw InternalException("Unimplemented average aggregate"); | ||
} | ||
} | ||
|
||
unique_ptr<FunctionData> BindDecimalAvg(ClientContext &context, AggregateFunction &function, | ||
vector<unique_ptr<Expression>> &arguments) { | ||
auto decimal_type = arguments[0]->return_type; | ||
function = GetAverageAggregate(decimal_type.InternalType()); | ||
function.name = "avg"; | ||
function.arguments[0] = decimal_type; | ||
function.return_type = LogicalType::DOUBLE; | ||
return make_uniq<AverageDecimalBindData>( | ||
Hugeint::Cast<double>(Hugeint::POWERS_OF_TEN[DecimalType::GetScale(decimal_type)])); | ||
} | ||
|
||
AggregateFunctionSet AvgFun::GetFunctions() { | ||
AggregateFunctionSet avg; | ||
|
||
avg.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, | ||
nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, | ||
BindDecimalAvg)); | ||
avg.AddFunction(GetAverageAggregate(PhysicalType::INT16)); | ||
avg.AddFunction(GetAverageAggregate(PhysicalType::INT32)); | ||
avg.AddFunction(GetAverageAggregate(PhysicalType::INT64)); | ||
avg.AddFunction(GetAverageAggregate(PhysicalType::INT128)); | ||
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<double>, double, double, NumericAverageOperation>( | ||
LogicalType::DOUBLE, LogicalType::DOUBLE)); | ||
return avg; | ||
} | ||
|
||
AggregateFunction FAvgFun::GetFunction() { | ||
return AggregateFunction::UnaryAggregate<KahanAvgState, double, double, KahanAverageOperation>(LogicalType::DOUBLE, | ||
LogicalType::DOUBLE); | ||
} | ||
|
||
} // namespace duckdb |
13 changes: 13 additions & 0 deletions
13
src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#include "core_functions/aggregate/algebraic_functions.hpp" | ||
#include "core_functions/aggregate/algebraic/covar.hpp" | ||
#include "core_functions/aggregate/algebraic/stddev.hpp" | ||
#include "core_functions/aggregate/algebraic/corr.hpp" | ||
#include "duckdb/function/function_set.hpp" | ||
|
||
namespace duckdb { | ||
|
||
AggregateFunction CorrFun::GetFunction() { | ||
return AggregateFunction::BinaryAggregate<CorrState, double, double, double, CorrOperation>( | ||
LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); | ||
} | ||
} // namespace duckdb |
17 changes: 17 additions & 0 deletions
17
src/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#include "core_functions/aggregate/algebraic_functions.hpp" | ||
#include "duckdb/common/types/null_value.hpp" | ||
#include "core_functions/aggregate/algebraic/covar.hpp" | ||
|
||
namespace duckdb { | ||
|
||
AggregateFunction CovarPopFun::GetFunction() { | ||
return AggregateFunction::BinaryAggregate<CovarState, double, double, double, CovarPopOperation>( | ||
LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); | ||
} | ||
|
||
AggregateFunction CovarSampFun::GetFunction() { | ||
return AggregateFunction::BinaryAggregate<CovarState, double, double, double, CovarSampOperation>( | ||
LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); | ||
} | ||
|
||
} // namespace duckdb |
34 changes: 34 additions & 0 deletions
34
src/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#include "core_functions/aggregate/algebraic_functions.hpp" | ||
#include "duckdb/common/vector_operations/vector_operations.hpp" | ||
#include "duckdb/function/function_set.hpp" | ||
#include "core_functions/aggregate/algebraic/stddev.hpp" | ||
#include <cmath> | ||
|
||
namespace duckdb { | ||
|
||
AggregateFunction StdDevSampFun::GetFunction() { | ||
return AggregateFunction::UnaryAggregate<StddevState, double, double, STDDevSampOperation>(LogicalType::DOUBLE, | ||
LogicalType::DOUBLE); | ||
} | ||
|
||
AggregateFunction StdDevPopFun::GetFunction() { | ||
return AggregateFunction::UnaryAggregate<StddevState, double, double, STDDevPopOperation>(LogicalType::DOUBLE, | ||
LogicalType::DOUBLE); | ||
} | ||
|
||
AggregateFunction VarPopFun::GetFunction() { | ||
return AggregateFunction::UnaryAggregate<StddevState, double, double, VarPopOperation>(LogicalType::DOUBLE, | ||
LogicalType::DOUBLE); | ||
} | ||
|
||
AggregateFunction VarSampFun::GetFunction() { | ||
return AggregateFunction::UnaryAggregate<StddevState, double, double, VarSampOperation>(LogicalType::DOUBLE, | ||
LogicalType::DOUBLE); | ||
} | ||
|
||
AggregateFunction StandardErrorOfTheMeanFun::GetFunction() { | ||
return AggregateFunction::UnaryAggregate<StddevState, double, double, StandardErrorOfTheMeanOperation>( | ||
LogicalType::DOUBLE, LogicalType::DOUBLE); | ||
} | ||
|
||
} // namespace duckdb |
Oops, something went wrong.