Skip to content

Commit

Permalink
[fix](round) fix round decimal128 overflow (apache#38106)
Browse files Browse the repository at this point in the history
  • Loading branch information
cambyzju authored Jul 22, 2024
1 parent 10aa083 commit 1d25dff
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions be/src/vec/functions/round.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include <fenv.h>
#endif
#include <algorithm>
#include <type_traits>

#include "vec/columns/column.h"
#include "vec/columns/column_decimal.h"
Expand Down Expand Up @@ -75,7 +76,7 @@ enum class TieBreakingMode {
};

template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
TieBreakingMode tie_breaking_mode>
TieBreakingMode tie_breaking_mode, typename U>
struct IntegerRoundingComputation {
static const size_t data_count = 1;

Expand Down Expand Up @@ -139,10 +140,10 @@ struct IntegerRoundingComputation {
__builtin_unreachable();
}

static ALWAYS_INLINE void compute(const T* __restrict in, size_t scale, T* __restrict out,
size_t target_scale) {
static ALWAYS_INLINE void compute(const T* __restrict in, U scale, T* __restrict out,
U target_scale) {
if constexpr (sizeof(T) <= sizeof(scale) && scale_mode == ScaleMode::Negative) {
if (scale > size_t(std::numeric_limits<T>::max())) {
if (scale >= std::numeric_limits<T>::max()) {
*out = 0;
return;
}
Expand All @@ -156,7 +157,7 @@ class DecimalRoundingImpl {
private:
using NativeType = typename T::NativeType;
using Op = IntegerRoundingComputation<NativeType, rounding_mode, ScaleMode::Negative,
tie_breaking_mode>;
tie_breaking_mode, NativeType>;
using Container = typename ColumnDecimal<T>::Container;

public:
Expand All @@ -173,13 +174,13 @@ class DecimalRoundingImpl {
if (out_scale < 0) {
auto negative_scale = DecimalScaleParams::get_scale_factor<T>(-out_scale);
while (p_in < end_in) {
*p_out = Op::compute(*p_in, scale, negative_scale);
Op::compute(p_in, scale, p_out, negative_scale);
++p_in;
++p_out;
}
} else {
while (p_in < end_in) {
*p_out = Op::compute(*p_in, scale, 1);
Op::compute(p_in, scale, p_out, 1);
++p_in;
++p_out;
}
Expand All @@ -196,9 +197,9 @@ class DecimalRoundingImpl {
auto scale = DecimalScaleParams::get_scale_factor<T>(scale_arg);
if (out_scale < 0) {
auto negative_scale = DecimalScaleParams::get_scale_factor<T>(-out_scale);
out = Op::compute(in, scale, negative_scale);
Op::compute(&in, scale, &out, negative_scale);
} else {
out = Op::compute(in, scale, 1);
Op::compute(&in, scale, &out, 1);
}
} else {
memcpy(&out, &in, sizeof(NativeType));
Expand Down Expand Up @@ -353,7 +354,7 @@ template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
TieBreakingMode tie_breaking_mode>
struct IntegerRoundingImpl {
private:
using Op = IntegerRoundingComputation<T, rounding_mode, scale_mode, tie_breaking_mode>;
using Op = IntegerRoundingComputation<T, rounding_mode, scale_mode, tie_breaking_mode, size_t>;
using Container = typename ColumnVector<T>::Container;

public:
Expand Down

0 comments on commit 1d25dff

Please sign in to comment.