Add r_base::as.integer (BOOLEAN, INTEGER, DOUBLE, VARCHAR, DATE, TIMESTAMP -> INTEGER).

Rows that cannot be represented as int32_t (NaN, strings that do not parse,
values outside the int32_t range) are marked invalid and return a placeholder
right away, so no NaN, out-of-range or uninitialized value reaches static_cast.

NOTE(review): the reviewed copy of this patch had lost its line breaks, its
indentation and every non-empty <...> token (template arguments, system header
names). They are restored here by inference -- in particular the two system
headers in the trailing context of the src/rfuns.cpp hunk, the positions of
blank context lines, and the LogicalType arguments of AsIntegerFunction.
Confirm against the repository before applying. The post-image hash on the
src/rfuns.cpp index line predates the fixes described above.

diff --git a/src/include/rfuns_extension.hpp b/src/include/rfuns_extension.hpp
index 7280fc878..0a2e3ff03 100644
--- a/src/include/rfuns_extension.hpp
+++ b/src/include/rfuns_extension.hpp
@@ -59,6 +59,9 @@ ScalarFunctionSet base_r_lte();
 ScalarFunctionSet base_r_gt();
 ScalarFunctionSet base_r_gte();
 
+ScalarFunctionSet base_r_is_na();
+ScalarFunctionSet base_r_as_integer();
+
 // sum
 AggregateFunctionSet base_r_sum();
 AggregateFunctionSet base_r_min();
@@ -66,9 +69,6 @@ AggregateFunctionSet base_r_max();
 
 ScalarFunctionSet binary_dispatch(ScalarFunctionSet fn) ;
 
-// is_na
-ScalarFunctionSet base_r_is_na();
-
 } // namespace rfuns
 
 class RfunsExtension : public Extension {
diff --git a/src/rfuns.cpp b/src/rfuns.cpp
index 750de9cbc..e5231d95b 100644
--- a/src/rfuns.cpp
+++ b/src/rfuns.cpp
@@ -81,6 +81,112 @@ ScalarFunctionSet base_r_add() {
 } // namespace duckdb
 #include "rfuns_extension.hpp"
 #include "duckdb/parser/parsed_data/create_scalar_function_info.hpp"
+#include "duckdb/common/operator/double_cast_operator.hpp"
+
+#include <cmath>
+#include <cstdint>
+#include <limits>
+
+namespace duckdb {
+namespace rfuns {
+
+namespace {
+
+// Returned for a row that has just been marked invalid, so that no NaN,
+// out-of-range or unparsed value is ever converted for that row.
+constexpr int32_t NULL_PLACEHOLDER = 0;
+
+// Narrows `value` to int32_t. A value outside the int32_t range marks row
+// `idx` invalid in `mask` and is not converted: static_cast of an
+// out-of-range floating-point value to an integer is undefined behavior.
+template <typename T>
+int32_t check_range(T value, ValidityMask &mask, idx_t idx) {
+	if (value > std::numeric_limits<int32_t>::max() || value < std::numeric_limits<int32_t>::min()) {
+		mask.SetInvalid(idx);
+		return NULL_PLACEHOLDER;
+	}
+
+	return static_cast<int32_t>(value);
+}
+
+// Fallback conversion for input types without a specialization below.
+template <typename T>
+int32_t cast(T input, ValidityMask &mask, idx_t idx) {
+	return static_cast<int32_t>(input);
+}
+
+// double: NaN marks the row invalid; everything else is range-checked.
+template <>
+int32_t cast(double input, ValidityMask &mask, idx_t idx) {
+	if (std::isnan(input)) {
+		mask.SetInvalid(idx);
+		return NULL_PLACEHOLDER;
+	}
+	return check_range(input, mask, idx);
+}
+
+// string: parsed with TryDoubleCast, then converted like a double. A string
+// that fails to parse marks the row invalid.
+template <>
+int32_t cast(string_t input, ValidityMask &mask, idx_t idx) {
+	double parsed = 0;
+	if (!TryDoubleCast(input.GetData(), input.GetSize(), parsed, false)) {
+		// `parsed` holds no usable value on failure, so do not convert it.
+		mask.SetInvalid(idx);
+		return NULL_PLACEHOLDER;
+	}
+
+	return cast(parsed, mask, idx);
+}
+
+// date: returns the `days` member of the date_t as is.
+template <>
+int32_t cast(date_t input, ValidityMask &mask, idx_t idx) {
+	return input.days;
+}
+
+// timestamp: Timestamp::GetEpochSeconds(input), range-checked.
+template <>
+int32_t cast(timestamp_t input, ValidityMask &mask, idx_t idx) {
+	return check_range(Timestamp::GetEpochSeconds(input), mask, idx);
+}
+
+// Builds the overload taking one argument of TYPE and returning INTEGER; each
+// row is converted with the cast() selected for TYPE's physical type.
+template <LogicalTypeId TYPE>
+ScalarFunction AsIntegerFunction() {
+	using physical_type = typename physical<TYPE>::type;
+
+	auto fun = [](DataChunk &args, ExpressionState &state, Vector &result) {
+		UnaryExecutor::ExecuteWithNulls<physical_type, int32_t>(
+			args.data[0], result, args.size(), cast<physical_type>
+		);
+	};
+	return ScalarFunction({TYPE}, LogicalType::INTEGER, fun);
+}
+
+} // namespace
+
+// Collects the as.integer overloads under the name "r_base::as.integer".
+ScalarFunctionSet base_r_as_integer() {
+	ScalarFunctionSet set("r_base::as.integer");
+
+	set.AddFunction(AsIntegerFunction<LogicalType::BOOLEAN>());
+	set.AddFunction(AsIntegerFunction<LogicalType::INTEGER>());
+	set.AddFunction(AsIntegerFunction<LogicalType::DOUBLE>());
+
+	set.AddFunction(AsIntegerFunction<LogicalType::VARCHAR>());
+
+	set.AddFunction(AsIntegerFunction<LogicalType::DATE>());
+	set.AddFunction(AsIntegerFunction<LogicalType::TIMESTAMP>());
+
+	return set;
+}
+
+} // namespace rfuns
+} // namespace duckdb
+#include "rfuns_extension.hpp"
+#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp"
 
 #include <math.h>
 #include <climits>
@@ -685,6 +791,8 @@ static void register_rfuns(DatabaseInstance &instance) {
 	register_binary(instance, base_r_gte());
 
 	ExtensionUtil::RegisterFunction(instance, base_r_is_na());
+	ExtensionUtil::RegisterFunction(instance, base_r_as_integer());
+
 	ExtensionUtil::RegisterFunction(instance, base_r_sum());
 	ExtensionUtil::RegisterFunction(instance, base_r_min());
 	ExtensionUtil::RegisterFunction(instance, base_r_max());