diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h index da530c94ba778..730a67810c1dd 100644 --- a/internal/core/src/common/Utils.h +++ b/internal/core/src/common/Utils.h @@ -337,5 +337,15 @@ T checkPlus(const T& a, const T& b, const char* typeName = "integer"){ return result; } +template +T checkedMultiply(const T& a, const T& b, const char* typeName = "integer") { + T result; + bool overflow = __builtin_mul_overflow(a, b, &result); + if (UNLIKELY(overflow)) { + VELOX_ARITHMETIC_ERROR("{} overflow: {} * {}", typeName, a, b); + } + return result; +} + } // namespace milvus diff --git a/internal/core/src/common/Vector.h b/internal/core/src/common/Vector.h index 464abe55b5d46..a288bfc47dc19 100644 --- a/internal/core/src/common/Vector.h +++ b/internal/core/src/common/Vector.h @@ -141,6 +141,11 @@ class ColumnVector final : public SimpleVector { return reinterpret_cast(GetRawData()) + index * size_of_element; } + template + T ValueAt(size_t index) const { + return *(reinterpret_cast(GetRawData()) + index); + } + template void SetValueAt(size_t index, const T& value) { diff --git a/internal/core/src/exec/operator/query-agg/Aggregate.h b/internal/core/src/exec/operator/query-agg/Aggregate.h index 0a735b3d092c0..a05539b49c6df 100644 --- a/internal/core/src/exec/operator/query-agg/Aggregate.h +++ b/internal/core/src/exec/operator/query-agg/Aggregate.h @@ -65,18 +65,19 @@ class Aggregate { } virtual void initializeNewGroups(char** groups, folly::Range indices) { + initializeNewGroupsInternal(groups, indices); for(auto index : indices) { groups[index][initializedByte_] |= initializedMask_; } } virtual void addSingleGroupRawInput(char* group, const TargetBitmapView& activeRows, - const std::vector& input, bool mayPushDown) {}; + const std::vector& input, bool mayPushDown) = 0; virtual void addRawInput(char** groups, const TargetBitmapView& activeRows, - const std::vector& input, bool mayPushDown) {} ; + const std::vector& input, bool mayPushDown) = 0; - virtual void extractValues(char** groups, int32_t numGroups, VectorPtr* result) {}; + virtual void extractValues(char** groups, int32_t numGroups, VectorPtr* result) = 0; template T* value(char* group) const { @@ -127,6 +128,25 @@ class Aggregate { // operator for this aggregate. If 0, clearing the null as part of update // is not needed. uint64_t numNulls_ = 0; + + inline bool clearNull(char* group) { + if (numNulls_) { + uint8_t mask = group[nullByte_]; + if (mask & nullMask_) { + group[nullByte_] = mask & ~nullMask_; + numNulls_--; + return true; + } + } + return false; + } + + void setAllNulls(char** groups, folly::Range indices) { + for(auto i:indices) { + groups[i][nullByte_] = nullMask_; + } + numNulls_ += indices.size(); + } }; using AggregateFunctionFactory = std::function(plan::AggregationNode::Step step, diff --git a/internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h b/internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h index ac0ce098dea73..f9ec3fa9ba8cb 100644 --- a/internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h +++ b/internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h @@ -49,8 +49,62 @@ class SimpleNumericAggregate : public exec::Aggregate { const VectorPtr& vector, UpdateSingleValue updateSingleValue, bool mayPushdown){ + auto start = 0; + auto column_data = std::dynamic_pointer_cast(vector); + AssertInfo(column_data!=nullptr, "input column data for upgrading groups should not be nullptr"); + while(true) { + auto next_selected = rows.find_next(start); + if (!next_selected.has_value()) { + return; + } + auto selected_idx = next_selected.value(); + if (column_data->ValidAt(selected_idx)) { + continue; + } + updateNonNullValue(groups[selected_idx], column_data->ValueAt(), updateSingleValue); + start = selected_idx; + } + } + template < + typename TData = TResult, + typename TValue = TInput, + typename UpdateSingle, + typename UpdateDuplicate> + void updateOneGroup( + char* group, + const TargetBitmapView& rows, + const VectorPtr& vector, + UpdateSingle updateSingleValue, + UpdateDuplicate /*updateDuplicateValues*/, + bool /*mayPushdown*/, + TData initialValue) { + auto start = 0; + auto column_data = std::dynamic_pointer_cast(vector); + AssertInfo(column_data!=nullptr, "input column data for upgrading groups should not be nullptr"); + while(true) { + auto next_selected = rows.find_next(start); + if (!next_selected.has_value()) { + return; + } + auto selected_idx = next_selected.value(); + if (column_data->ValidAt(selected_idx)) { + continue; + } + updateNonNullValue(group, column_data->ValueAt(), updateSingleValue); + start = selected_idx; + } } + + template + inline void + updateNonNullValue(char* group, TDataType value, Update updateValue) { + if constexpr (tableHasNulls) { + Aggregate::clearNull(group); + } + updateValue(Aggregate::value(group), value); + } + }; } diff --git a/internal/core/src/exec/operator/query-agg/SumAggregateBase.h b/internal/core/src/exec/operator/query-agg/SumAggregateBase.h index afbd6d0b2f5fd..2d803c148c601 100644 --- a/internal/core/src/exec/operator/query-agg/SumAggregateBase.h +++ b/internal/core/src/exec/operator/query-agg/SumAggregateBase.h @@ -41,11 +41,35 @@ class SumAggregateBase: public SimpleNumericAggregate& input, bool mayPushDown) override { + updateInternal(groups, activeRows, input, mayPushDown); + } + void addSingleGroupRawInput(char* group, const TargetBitmapView& activeRows, + const std::vector& input, bool mayPushDown) override { + BaseAggregate::template updateOneGroup(group, activeRows, input[0], + &updateSingleValue, &updateDuplicateValues, mayPushDown, TAccumulator(0)); } void initializeNewGroupsInternal(char** groups, folly::Range indices) override { + Aggregate::setAllNulls(groups, indices); + for(auto i: indices) { + (*Aggregate::value(groups[i])) = 0; + } + } + template + #if defined(FOLLY_DISABLE_UNDEFINED_BEHAVIOR_SANITIZER) + FOLLY_DISABLE_UNDEFINED_BEHAVIOR_SANITIZER("signed-integer-overflow") + #endif + static void updateDuplicateValues(TData& result, TData value, int n) { + if constexpr( + (std::is_same_v && Overflow) || + std::is_same_v || std::is_same_v) { + result += n * value; + } else { + result = checkPlus(result, + checkedMultiply(TData(n), value)); + } } protected: