From 8c41b0f2b8ce68b6d264557cd964753085478f51 Mon Sep 17 00:00:00 2001
From: MrPresent-Han
Date: Mon, 11 Nov 2024 04:55:06 -0500
Subject: [PATCH] adding sum agg

---
 internal/core/src/common/Vector.h | 9 +++++
 internal/core/src/exec/CMakeLists.txt | 2 +-
 .../src/exec/operator/query-agg/Aggregate.h | 17 +++++++
 .../exec/operator/query-agg/AggregateInfo.cpp | 6 ++-
 .../exec/operator/query-agg/AggregateInfo.h | 2 +-
 .../exec/operator/query-agg/AggregateUtil.h | 39 +++++++++++++++
 .../exec/operator/query-agg/GroupingSet.cpp | 2 +-
 .../query-agg/RegisterAggregateFunctions.cpp | 23 +++++++++
 .../query-agg/RegisterAggregateFunctions.h | 33 +++++++++++++
 .../exec/operator/query-agg/RowContainer.cpp | 1 +
 .../exec/operator/query-agg/RowContainer.h | 35 +++++++-------
 .../query-agg/SimpleNumericAggregate.h | 47 +++++++++++++++++++
 .../exec/operator/query-agg/SumAggregate.cpp | 35 ++++++++++++++
 .../operator/query-agg/SumAggregateBase.h | 42 +++++++++++++++++
 14 files changed, 271 insertions(+), 22 deletions(-)
 create mode 100644 internal/core/src/exec/operator/query-agg/AggregateUtil.h
 create mode 100644 internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.cpp
 create mode 100644 internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.h
 create mode 100644 internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h
 create mode 100644 internal/core/src/exec/operator/query-agg/SumAggregate.cpp
 create mode 100644 internal/core/src/exec/operator/query-agg/SumAggregateBase.h

diff --git a/internal/core/src/common/Vector.h b/internal/core/src/common/Vector.h
index 3267885442396..464abe55b5d46 100644
--- a/internal/core/src/common/Vector.h
+++ b/internal/core/src/common/Vector.h
@@ -52,6 +52,11 @@ class BaseVector {
         return type_kind_;
     }
 
+    // Size of one element of type_kind_, as reported by GetDataTypeSize().
+    int32_t elementSize() const {
+        return GetDataTypeSize(type_kind_);
+    }
+
     size_t nullCount() const {
         return null_count_.has_value()?null_count_.value():0;
@@ -148,6 +153,10 @@ class ColumnVector final : public
SimpleVector {
         valid_values_.set(index, false);
     }
+    void
+    clearNullAt(size_t index) {
+        valid_values_.set(index, true);
+    }
 
     bool ValidAt(size_t index) override {
diff --git a/internal/core/src/exec/CMakeLists.txt b/internal/core/src/exec/CMakeLists.txt
index 53c599ef80a54..ccadf28bc4bc1 100644
--- a/internal/core/src/exec/CMakeLists.txt
+++ b/internal/core/src/exec/CMakeLists.txt
@@ -10,4 +10,4 @@
 # or implied. See the License for the specific language governing permissions and limitations under the License
 
 add_source_at_current_directory_recursively()
-add_library(milvus_exec OBJECT ${SOURCE_FILES})
+add_library(milvus_exec OBJECT ${SOURCE_FILES} operator/query-agg/AggregateUtil.h)
diff --git a/internal/core/src/exec/operator/query-agg/Aggregate.h b/internal/core/src/exec/operator/query-agg/Aggregate.h
index 418f579134202..b26afd9b005b2 100644
--- a/internal/core/src/exec/operator/query-agg/Aggregate.h
+++ b/internal/core/src/exec/operator/query-agg/Aggregate.h
@@ -37,6 +37,11 @@ class Aggregate {
     // so larger values need not be represented.
     int32_t rowSizeOffset_ = 0;
 
+    // Number of null accumulators in the current state of the aggregation
+    // operator for this aggregate. If 0, clearing the null as part of update
+    // is not needed.
+    uint64_t numNulls_ = 0;
+
  public:
     DataType resultType() const {
         return result_type_;
@@ -72,6 +72,25 @@ class Aggregate {
     virtual void extractValues(char** groups, int32_t numGroups, VectorPtr* result) {};
 
+    // Returns this aggregate's accumulator inside the row 'group', i.e. the
+    // bytes at group + offset_, viewed as T. Asserts that the address is a
+    // multiple of accumulatorAlignmentSize().
+    // NOTE(review): the template header and both cast targets were missing from
+    // the patch text; <typename T>, uintptr_t and T* are inferred from the
+    // signature and the modulo test - confirm against the repository.
+    template <typename T>
+    T* value(char* group) const {
+        AssertInfo(reinterpret_cast<uintptr_t>(group + offset_) % accumulatorAlignmentSize() == 0,
+                   "aggregation value in the groups is not aligned");
+        return reinterpret_cast<T*>(group + offset_);
+    }
+
+    // True when numNulls_ is non-zero and the bit(s) selected by nullMask_ are
+    // set in byte nullByte_ of the row 'group'.
+    bool isNull(char* group) const {
+        return numNulls_ && (group[nullByte_] & nullMask_);
+    }
+
     // Returns true if the accumulator never takes more than
     // accumulatorFixedWidthSize() bytes.
If this is false, the
     // accumulator needs to track its changing variable length footprint
diff --git a/internal/core/src/exec/operator/query-agg/AggregateInfo.cpp b/internal/core/src/exec/operator/query-agg/AggregateInfo.cpp
index 1d759deb9c232..382b1664a8fa7 100644
--- a/internal/core/src/exec/operator/query-agg/AggregateInfo.cpp
+++ b/internal/core/src/exec/operator/query-agg/AggregateInfo.cpp
@@ -23,8 +23,10 @@ std::vector toAggregateInfo(
         AggregateInfo info;
         auto& inputColumnIdxes = info.input_column_idxes_;
         for (const auto& inputExpr: aggregate.call_->inputs()) {
-            if (auto fieldExpr = dynamic_cast(inputExpr.get())) {
-                //inputColumnIdxes.emplace_back(inputType->GetChildIndex(fieldExpr->fun_name()));
+            if (auto fieldExpr = dynamic_cast(inputExpr.get())) {
+                inputColumnIdxes.emplace_back(inputType->GetChildIndex(fieldExpr->name()));
+            } else if (inputExpr != nullptr) {
+                PanicInfo(ExprInvalid, "Only support aggregation towards column for now");
             }
         }
         auto index = numKeys + i;
diff --git a/internal/core/src/exec/operator/query-agg/AggregateInfo.h b/internal/core/src/exec/operator/query-agg/AggregateInfo.h
index b4dda8894dca2..5a3909ca7b07e 100644
--- a/internal/core/src/exec/operator/query-agg/AggregateInfo.h
+++ b/internal/core/src/exec/operator/query-agg/AggregateInfo.h
@@ -36,7 +36,7 @@ struct AggregateInfo{
     column_index_t output_;
 
     /// Type of intermediate results. Used for spilling.
-    DataType intermediateType_;
+    DataType intermediateType_{DataType::NONE};
 };
 
 std::vector toAggregateInfo(
diff --git a/internal/core/src/exec/operator/query-agg/AggregateUtil.h b/internal/core/src/exec/operator/query-agg/AggregateUtil.h
new file mode 100644
index 0000000000000..7ddf186dfb476
--- /dev/null
+++ b/internal/core/src/exec/operator/query-agg/AggregateUtil.h
@@ -0,0 +1,41 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+namespace milvus {
+namespace exec {
+// The result of aggregation function registration: one flag per function
+// variant. All flags start out false.
+// NOTE(review): presumably a flag is set when that variant was registered -
+// confirm against the registration code, which is not part of this patch.
+struct AggregateRegistrationResult {
+    bool mainFunction{false};
+    bool partialFunction{false};
+    bool mergeFunction{false};
+    bool extractFunction{false};
+    bool mergeExtractFunction{false};
+
+    // Equal when all five flags match.
+    bool operator==(const AggregateRegistrationResult& other) const {
+        return mainFunction == other.mainFunction &&
+               partialFunction == other.partialFunction &&
+               mergeFunction == other.mergeFunction &&
+               extractFunction == other.extractFunction &&
+               mergeExtractFunction == other.mergeExtractFunction;
+    }
+};
+}  // namespace exec
+}  // namespace milvus
diff --git a/internal/core/src/exec/operator/query-agg/GroupingSet.cpp b/internal/core/src/exec/operator/query-agg/GroupingSet.cpp
index 453ddec8a3afa..d6e771c0b4691 100644
--- a/internal/core/src/exec/operator/query-agg/GroupingSet.cpp
+++ b/internal/core/src/exec/operator/query-agg/GroupingSet.cpp
@@ -59,7 +59,7 @@ void GroupingSet::initializeGlobalAggregation() {
 
     for(auto& aggregate: aggregates_) {
         auto& function = aggregate.function_;
-        Accumulator accumulator(function.get());
+        Accumulator accumulator(function.get(), function->resultType());
         // Accumulator offset must be aligned by their alignment size.
offset = milvus::bits::roundUp(offset, accumulator.alignment());
         function->setOffsets(offset,
diff --git a/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.cpp b/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.cpp
new file mode 100644
index 0000000000000..a22fcd1fdce8a
--- /dev/null
+++ b/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.cpp
@@ -0,0 +1,31 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "RegisterAggregateFunctions.h"
+
+namespace milvus {
+namespace exec {
+
+// Registers every aggregate function declared in RegisterAggregateFunctions.h;
+// at present that is only the sum aggregate. Arguments are forwarded unchanged.
+void registerAllFunctions(const std::string& prefix,
+                          bool withCompanionFunctions,
+                          bool overwrite) {
+    registerSumAggregate(prefix, withCompanionFunctions, overwrite);
+}
+
+}  // namespace exec
+}  // namespace milvus
diff --git a/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.h b/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.h
new file mode 100644
index 0000000000000..8773f41ea46c7
--- /dev/null
+++ b/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.h
@@ -0,0 +1,33 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +namespace milvus{ +namespace exec { +void registerAllFunctions(const std::string& prefix = "", + bool withCompanionFunctions = true, + bool overwrite = true); + + +extern void registerSumAggregate(const std::string& prefix, + bool withCompanionFunctions, + bool overwrite); +} +} + diff --git a/internal/core/src/exec/operator/query-agg/RowContainer.cpp b/internal/core/src/exec/operator/query-agg/RowContainer.cpp index b6b2ac15768bc..e8b02419c74d8 100644 --- a/internal/core/src/exec/operator/query-agg/RowContainer.cpp +++ b/internal/core/src/exec/operator/query-agg/RowContainer.cpp @@ -163,6 +163,7 @@ Accumulator::Accumulator(milvus::exec::Aggregate *aggregate, DataType spillType) fixedSize_{aggregate->accumulatorFixedWidthSize()}, alignment_(aggregate->accumulatorAlignmentSize()), spillType_(spillType){ + AssertInfo(aggregate!=nullptr, "Input aggregate for accumulator cannot be nullptr!"); } diff --git a/internal/core/src/exec/operator/query-agg/RowContainer.h b/internal/core/src/exec/operator/query-agg/RowContainer.h index 65bda817a9b08..2a7d30c4ea089 100644 --- a/internal/core/src/exec/operator/query-agg/RowContainer.h +++ b/internal/core/src/exec/operator/query-agg/RowContainer.h @@ -343,24 +343,25 @@ class RowContainer { result->resize(numRows + resultOffset); if constexpr (Type == DataType::ROW || Type == DataType::JSON || Type == DataType::ARRAY 
|| Type == DataType::NONE) {
             PanicInfo(DataTypeInvalid, "Not Support Extract types:[ROW/JSON/ARRAY/NONE]");
-        }
-        using T = typename milvus::TypeTraits::NativeType;
-
-        auto nullMask = column.nullMask();
-        auto offset = column.offset();
-        if (nullMask) {
-            extractValuesWithNulls(
-                rows,
-                rowNumbers,
-                numRows,
-                offset,
-                column.nullByte(),
-                nullMask,
-                resultOffset,
-                result);
         } else {
-            extractValuesNoNulls(
-                rows, rowNumbers, numRows, offset, resultOffset, result);
+            using T = typename milvus::TypeTraits::NativeType;
+
+            auto nullMask = column.nullMask();
+            auto offset = column.offset();
+            if (nullMask) {
+                extractValuesWithNulls(
+                    rows,
+                    rowNumbers,
+                    numRows,
+                    offset,
+                    column.nullByte(),
+                    nullMask,
+                    resultOffset,
+                    result);
+            } else {
+                extractValuesNoNulls(
+                    rows, rowNumbers, numRows, offset, resultOffset, result);
+            }
         }
     }
 
diff --git a/internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h b/internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h
new file mode 100644
index 0000000000000..c514043c99c8e
--- /dev/null
+++ b/internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h
@@ -0,0 +1,65 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License
+#pragma once
+
+#include "Aggregate.h"
+namespace milvus{
+namespace exec {
+// Aggregate base class that supplies doExtractValues(), the shared loop that
+// copies one value per group into an output ColumnVector.
+// NOTE(review): the template parameter lists and cast targets in this file
+// were missing from the patch text. TData, ExtractOneValue and ColumnVector
+// come from the body; the class parameter names (TInput in particular) are
+// assumptions - confirm against the repository.
+template <typename TInput, typename TAccumulator, typename ResultType>
+class SimpleNumericAggregate : public exec::Aggregate {
+protected:
+    explicit SimpleNumericAggregate(DataType resultType) : Aggregate(resultType) {}
+
+    // Writes one value per group into '*result'.
+    //
+    // groups          - 'numGroups' row pointers, one per group.
+    // result          - must hold a ColumnVector whose elementSize() equals
+    //                   sizeof(TData); it is resized to 'numGroups'.
+    // extractOneValue - invoked as extractOneValue(group) for every group that
+    //                   is not null; its result is stored at the group's index.
+    //                   Null groups are marked with nullAt(i) instead.
+    //
+    // TData is either the accumulator type or the result type, which in most
+    // cases are the same, but for sum(real) can differ.
+    template <typename TData, typename ExtractOneValue>
+    void doExtractValues(
+        char** groups,
+        int32_t numGroups,
+        VectorPtr* result,
+        ExtractOneValue extractOneValue) {
+        // Fail with a message instead of dereferencing a null output slot.
+        AssertInfo(result != nullptr && *result != nullptr,
+                   "result vector for extracting aggregation cannot be nullptr");
+        AssertInfo((*result)->elementSize() == sizeof(TData), "Incorrect type size of input result vector");
+        ColumnVectorPtr result_column = std::dynamic_pointer_cast<ColumnVector>(*result);
+        AssertInfo(result_column != nullptr, "input vector for extracting aggregation must be of Type ColumnVector");
+        result_column->resize(numGroups);
+        TData* rawValues = static_cast<TData*>(result_column->GetRawData());
+        for (int32_t i = 0; i < numGroups; i++) {
+            char* group = groups[i];
+            if (isNull(group)) {
+                result_column->nullAt(i);
+            } else {
+                result_column->clearNullAt(i);
+                rawValues[i] = extractOneValue(group);
+            }
+        }
+    }
+};
+
+}  // namespace exec
+}  // namespace milvus
diff --git a/internal/core/src/exec/operator/query-agg/SumAggregate.cpp b/internal/core/src/exec/operator/query-agg/SumAggregate.cpp
new file mode 100644
index 0000000000000..89d8958129c6b
--- /dev/null
+++ b/internal/core/src/exec/operator/query-agg/SumAggregate.cpp
@@ -0,0 +1,35 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "SumAggregateBase.h"
+#include "RegisterAggregateFunctions.h"
+
+namespace milvus {
+namespace exec {
+
+template
+using SumAggregate = SumAggregateBase;
+
+template class T>
+// NOTE(review): the body of registerSumAggregate() below stops at the incomplete token 'regis'; the registration call still has to be written.
+
+void registerSumAggregate(const std::string& prefix,
+                          bool withCompanionFunctions,
+                          bool overwrite) {
+    regis
+}
+}
+}
diff --git a/internal/core/src/exec/operator/query-agg/SumAggregateBase.h b/internal/core/src/exec/operator/query-agg/SumAggregateBase.h
new file mode 100644
index 0000000000000..7640f45b0cdbe
--- /dev/null
+++ b/internal/core/src/exec/operator/query-agg/SumAggregateBase.h
@@ -0,0 +1,55 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License
+#pragma once
+#include "SimpleNumericAggregate.h"
+
+namespace milvus{
+namespace exec {
+// Base for the sum aggregate: the accumulator is a single TAccumulator kept in
+// the group row, and extractValues() converts it to ResultType.
+// NOTE(review): the template parameter lists in this file were missing from
+// the patch text. TAccumulator and ResultType are the names the body uses;
+// TInput, Overflow and the base-class arguments are assumptions - confirm
+// against the repository.
+template <typename TInput,
+          typename TAccumulator,
+          typename ResultType,
+          bool Overflow>
+class SumAggregateBase
+    : public SimpleNumericAggregate<TInput, TAccumulator, ResultType> {
+    using BaseAggregate =
+        SimpleNumericAggregate<TInput, TAccumulator, ResultType>;
+
+public:
+    explicit SumAggregateBase(DataType resultType) : BaseAggregate(resultType) {}
+
+    // The accumulator occupies exactly sizeof(TAccumulator) bytes.
+    int32_t accumulatorFixedWidthSize() const override {
+        return static_cast<int32_t>(sizeof(TAccumulator));
+    }
+
+    // Reports an alignment of 1, i.e. any offset in the row is accepted.
+    int32_t accumulatorAlignmentSize() const override {
+        return 1;
+    }
+
+    // Copies each group's accumulator, converted to ResultType, into
+    // '*result' through doExtractValues().
+    void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override {
+        BaseAggregate::template doExtractValues<ResultType>(
+            groups, numGroups, result, [&](char* group) {
+                return static_cast<ResultType>(
+                    *BaseAggregate::Aggregate::template value<TAccumulator>(group));
+            });
+    }
+};
+}  // namespace exec
+}  // namespace milvus