Skip to content

Commit

Permalink
add init groups logic
Browse files Browse the repository at this point in the history
  • Loading branch information
MrPresent-Han committed Nov 12, 2024
1 parent 73c3bcc commit 84cbbc8
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 3 deletions.
10 changes: 10 additions & 0 deletions internal/core/src/common/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,5 +337,15 @@ T checkPlus(const T& a, const T& b, const char* typeName = "integer"){
return result;
}

template <typename T>
T checkedMultiply(const T& a, const T& b, const char* typeName = "integer") {
T result;
bool overflow = __builtin_mul_overflow(a, b, &result);
if (UNLIKELY(overflow)) {
VELOX_ARITHMETIC_ERROR("{} overflow: {} * {}", typeName, a, b);
}
return result;
}

} // namespace milvus

5 changes: 5 additions & 0 deletions internal/core/src/common/Vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ class ColumnVector final : public SimpleVector {
return reinterpret_cast<char*>(GetRawData()) + index * size_of_element;
}

template <typename T>
T ValueAt(size_t index) const {
return *(reinterpret_cast<T*>(GetRawData()) + index);
}

template<typename T>
void
SetValueAt(size_t index, const T& value) {
Expand Down
26 changes: 23 additions & 3 deletions internal/core/src/exec/operator/query-agg/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,19 @@ class Aggregate {
}

virtual void initializeNewGroups(char** groups, folly::Range<const vector_size_t*> indices) {
initializeNewGroupsInternal(groups, indices);
for(auto index : indices) {
groups[index][initializedByte_] |= initializedMask_;
}
}

virtual void addSingleGroupRawInput(char* group, const TargetBitmapView& activeRows,
const std::vector<VectorPtr>& input, bool mayPushDown) {};
const std::vector<VectorPtr>& input, bool mayPushDown) = 0;

virtual void addRawInput(char** groups, const TargetBitmapView& activeRows,
const std::vector<VectorPtr>& input, bool mayPushDown) {} ;
const std::vector<VectorPtr>& input, bool mayPushDown) = 0;

virtual void extractValues(char** groups, int32_t numGroups, VectorPtr* result) {};
virtual void extractValues(char** groups, int32_t numGroups, VectorPtr* result) = 0;

template <typename T>
T* value(char* group) const {
Expand Down Expand Up @@ -127,6 +128,25 @@ class Aggregate {
// operator for this aggregate. If 0, clearing the null as part of update
// is not needed.
uint64_t numNulls_ = 0;

inline bool clearNull(char* group) {
if (numNulls_) {
uint8_t mask = group[nullByte_];
if (mask & nullMask_) {
group[nullByte_] = mask & ~nullMask_;
numNulls_--;
return true;
}
}
return false;
}

void setAllNulls(char** groups, folly::Range<const vector_size_t*> indices) {
for(auto i:indices) {
groups[i][nullByte_] = nullMask_;
}
numNulls_ += indices.size();
}
};

using AggregateFunctionFactory = std::function<std::unique_ptr<Aggregate>(plan::AggregationNode::Step step,
Expand Down
54 changes: 54 additions & 0 deletions internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,62 @@ class SimpleNumericAggregate : public exec::Aggregate {
const VectorPtr& vector,
UpdateSingleValue updateSingleValue,
bool mayPushdown){
auto start = 0;
auto column_data = std::dynamic_pointer_cast<ColumnVector>(vector);
AssertInfo(column_data!=nullptr, "input column data for upgrading groups should not be nullptr");
while(true) {
auto next_selected = rows.find_next(start);
if (!next_selected.has_value()) {
return;
}
auto selected_idx = next_selected.value();
if (column_data->ValidAt(selected_idx)) {
continue;
}
updateNonNullValue<tableHasNulls, TData>(groups[selected_idx], column_data->ValueAt<TData>(), updateSingleValue);
start = selected_idx;
}
}

template <
typename TData = TResult,
typename TValue = TInput,
typename UpdateSingle,
typename UpdateDuplicate>
void updateOneGroup(
char* group,
const TargetBitmapView& rows,
const VectorPtr& vector,
UpdateSingle updateSingleValue,
UpdateDuplicate /*updateDuplicateValues*/,
bool /*mayPushdown*/,
TData initialValue) {
auto start = 0;
auto column_data = std::dynamic_pointer_cast<ColumnVector>(vector);
AssertInfo(column_data!=nullptr, "input column data for upgrading groups should not be nullptr");
while(true) {
auto next_selected = rows.find_next(start);
if (!next_selected.has_value()) {
return;
}
auto selected_idx = next_selected.value();
if (column_data->ValidAt(selected_idx)) {
continue;
}
updateNonNullValue<true, TData>(group, column_data->ValueAt<TData>(), updateSingleValue);
start = selected_idx;
}
}

template<bool tableHasNulls, typename TDataType = TAccumulator, typename Update>
inline void
updateNonNullValue(char* group, TDataType value, Update updateValue) {
if constexpr (tableHasNulls) {
Aggregate::clearNull(group);
}
updateValue(Aggregate::value<TDataType>(group), value);
}

};

}
Expand Down
24 changes: 24 additions & 0 deletions internal/core/src/exec/operator/query-agg/SumAggregateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,35 @@ class SumAggregateBase: public SimpleNumericAggregate<TInput, TAccumulator, Resu

void addRawInput(char** groups, const TargetBitmapView& activeRows,
const std::vector<VectorPtr>& input, bool mayPushDown) override {
updateInternal<TAccumulator>(groups, activeRows, input, mayPushDown);
}

void addSingleGroupRawInput(char* group, const TargetBitmapView& activeRows,
const std::vector<VectorPtr>& input, bool mayPushDown) override {
BaseAggregate::template updateOneGroup<TAccumulator>(group, activeRows, input[0],
&updateSingleValue<TAccumulator>, &updateDuplicateValues<TAccumulator>, mayPushDown, TAccumulator(0));
}

void initializeNewGroupsInternal(char** groups, folly::Range<const vector_size_t*> indices) override {
Aggregate::setAllNulls(groups, indices);
for(auto i: indices) {
(*Aggregate::value<TAccumulator>(groups[i])) = 0;
}
}

template <typename TData>
#if defined(FOLLY_DISABLE_UNDEFINED_BEHAVIOR_SANITIZER)
FOLLY_DISABLE_UNDEFINED_BEHAVIOR_SANITIZER("signed-integer-overflow")
#endif
static void updateDuplicateValues(TData& result, TData value, int n) {
if constexpr(
(std::is_same_v<TData, int64_t> && Overflow) ||
std::is_same_v<TData, double> || std::is_same_v<TData, float>) {
result += n * value;
} else {
result = checkPlus<TData>(result,
checkedMultiply<TData>(TData(n), value));
}
}

protected:
Expand Down

0 comments on commit 84cbbc8

Please sign in to comment.