Skip to content

Commit

Permalink
adding sum agg
Browse files Browse the repository at this point in the history
  • Loading branch information
MrPresent-Han committed Nov 11, 2024
1 parent 64e961b commit 8c41b0f
Show file tree
Hide file tree
Showing 14 changed files with 270 additions and 22 deletions.
8 changes: 8 additions & 0 deletions internal/core/src/common/Vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class BaseVector {
return type_kind_;
}

// Returns the value GetDataTypeSize() reports for this vector's type_kind_.
int32_t
elementSize() const {
    const int32_t width = GetDataTypeSize(type_kind_);
    return width;
}

size_t
nullCount() const {
return null_count_.has_value()?null_count_.value():0;
Expand Down Expand Up @@ -148,6 +152,10 @@ class ColumnVector final : public SimpleVector {
valid_values_.set(index, false);
}

// Marks the entry at `index` as valid by storing `true` for it in
// valid_values_.
void
clearNullAt(size_t index) {
    constexpr bool kIsValid = true;
    valid_values_.set(index, kIsValid);
}

bool
ValidAt(size_t index) override {
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/exec/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
# or implied. See the License for the specific language governing permissions and limitations under the License

add_source_at_current_directory_recursively()
add_library(milvus_exec OBJECT ${SOURCE_FILES})
add_library(milvus_exec OBJECT ${SOURCE_FILES} operator/query-agg/AggregateUtil.h)
17 changes: 17 additions & 0 deletions internal/core/src/exec/operator/query-agg/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class Aggregate {
// so larger values need not be represented.
int32_t rowSizeOffset_ = 0;

// Number of null accumulators in the current state of the aggregation
// operator for this aggregate. If 0, clearing the null as part of update
// is not needed.
uint64_t numNulls_ = 0;

public:
DataType resultType() const {
return result_type_;
Expand Down Expand Up @@ -72,6 +77,18 @@ class Aggregate {

// Default implementation does nothing; it is virtual so that concrete
// aggregates can provide their own extraction into `*result`.
virtual void
extractValues(char** groups, int32_t numGroups, VectorPtr* result) {
}

// Returns the address `group + offset_` reinterpreted as a T*.
// Asserts first that this address is a multiple of
// accumulatorAlignmentSize().
template <typename T>
T*
value(char* group) const {
    AssertInfo(reinterpret_cast<uintptr_t>(group + offset_) % accumulatorAlignmentSize() == 0,
               "aggregation value in the groups is not aligned");
    char* const slot = group + offset_;
    return reinterpret_cast<T*>(slot);
}


// Reports whether the null flag selected by nullByte_/nullMask_ is set in
// `group`. While numNulls_ is 0 the row is not inspected at all and the
// answer is always false.
bool
isNull(char* group) const {
    if (numNulls_ == 0) {
        return false;
    }
    return (group[nullByte_] & nullMask_) != 0;
}

// Returns true if the accumulator never takes more than
// accumulatorFixedWidthSize() bytes. If this is false, the
// accumulator needs to track its changing variable length footprint
Expand Down
6 changes: 4 additions & 2 deletions internal/core/src/exec/operator/query-agg/AggregateInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ std::vector<AggregateInfo> toAggregateInfo(
AggregateInfo info;
auto& inputColumnIdxes = info.input_column_idxes_;
for (const auto& inputExpr: aggregate.call_->inputs()) {
if (auto fieldExpr = dynamic_cast<const expr::CallExpr*>(inputExpr.get())) {
//inputColumnIdxes.emplace_back(inputType->GetChildIndex(fieldExpr->fun_name()));
if (auto fieldExpr = dynamic_cast<const expr::FieldAccessTypeExpr*>(inputExpr.get())) {
inputColumnIdxes.emplace_back(inputType->GetChildIndex(fieldExpr->name()));
} else if (inputExpr != nullptr) {
PanicInfo(ExprInvalid, "Only support aggregation towards column for now");
}
}
auto index = numKeys + i;
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/exec/operator/query-agg/AggregateInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct AggregateInfo{
column_index_t output_;

/// Type of intermediate results. Used for spilling.
DataType intermediateType_;
DataType intermediateType_{DataType::None};
};

std::vector<AggregateInfo> toAggregateInfo(
Expand Down
39 changes: 39 additions & 0 deletions internal/core/src/exec/operator/query-agg/AggregateUtil.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

namespace milvus {
namespace exec {
// Outcome of registering an aggregation function: one boolean flag per
// function variant, every flag defaulting to false.
struct AggregateRegistrationResult {
    bool mainFunction{false};
    bool partialFunction{false};
    bool mergeFunction{false};
    bool extractFunction{false};
    bool mergeExtractFunction{false};

    // Two results compare equal only when all five flags match.
    bool
    operator==(const AggregateRegistrationResult& other) const {
        if (mainFunction != other.mainFunction) {
            return false;
        }
        if (partialFunction != other.partialFunction) {
            return false;
        }
        if (mergeFunction != other.mergeFunction) {
            return false;
        }
        if (extractFunction != other.extractFunction) {
            return false;
        }
        return mergeExtractFunction == other.mergeExtractFunction;
    }
};
}
}


2 changes: 1 addition & 1 deletion internal/core/src/exec/operator/query-agg/GroupingSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void GroupingSet::initializeGlobalAggregation() {

for(auto& aggregate: aggregates_) {
auto& function = aggregate.function_;
Accumulator accumulator(function.get());
Accumulator accumulator(function.get(), function->resultType());
// Accumulator offset must be aligned by their alignment size.
offset = milvus::bits::roundUp(offset, accumulator.alignment());
function->setOffsets(offset,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "RegisterAggregateFunctions.h"

namespace milvus {
namespace exec {

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <string>

namespace milvus{
namespace exec {
// Declaration only; the definition is not part of this header.
// Every parameter is defaulted: prefix = "", withCompanionFunctions = true,
// overwrite = true.
// NOTE(review): presumably registers every available aggregate function
// under `prefix` -- confirm against the definition.
void registerAllFunctions(const std::string& prefix = "",
bool withCompanionFunctions = true,
bool overwrite = true);


// Declaration of the sum-aggregate registration entry point. It takes the
// same three parameters as registerAllFunctions, but none are defaulted here.
// NOTE(review): `extern` is redundant on a function declaration.
extern void registerSumAggregate(const std::string& prefix,
bool withCompanionFunctions,
bool overwrite);
}
}

1 change: 1 addition & 0 deletions internal/core/src/exec/operator/query-agg/RowContainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ Accumulator::Accumulator(milvus::exec::Aggregate *aggregate, DataType spillType)
fixedSize_{aggregate->accumulatorFixedWidthSize()},
alignment_(aggregate->accumulatorAlignmentSize()),
spillType_(spillType){

AssertInfo(aggregate!=nullptr, "Input aggregate for accumulator cannot be nullptr!");
}

Expand Down
35 changes: 18 additions & 17 deletions internal/core/src/exec/operator/query-agg/RowContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,24 +343,25 @@ class RowContainer {
result->resize(numRows + resultOffset);
if constexpr (Type == DataType::ROW || Type == DataType::JSON || Type == DataType::ARRAY || Type == DataType::NONE) {
PanicInfo(DataTypeInvalid, "Not Support Extract types:[ROW/JSON/ARRAY/NONE]");
}
using T = typename milvus::TypeTraits<Type>::NativeType;

auto nullMask = column.nullMask();
auto offset = column.offset();
if (nullMask) {
extractValuesWithNulls<useRowNumbers, T>(
rows,
rowNumbers,
numRows,
offset,
column.nullByte(),
nullMask,
resultOffset,
result);
} else {
extractValuesNoNulls<useRowNumbers, T>(
rows, rowNumbers, numRows, offset, resultOffset, result);
using T = typename milvus::TypeTraits<Type>::NativeType;

auto nullMask = column.nullMask();
auto offset = column.offset();
if (nullMask) {
extractValuesWithNulls<useRowNumbers, T>(
rows,
rowNumbers,
numRows,
offset,
column.nullByte(),
nullMask,
resultOffset,
result);
} else {
extractValuesNoNulls<useRowNumbers, T>(
rows, rowNumbers, numRows, offset, resultOffset, result);
}
}
}

Expand Down
47 changes: 47 additions & 0 deletions internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once

#include "Aggregate.h"
namespace milvus{
namespace exec {
// Common base for aggregates over numeric values. It adds a helper that
// copies one value per group into a ColumnVector, marking the entry null
// when the group's accumulator is null.
template <typename TInput, typename TAccumulator, typename TResult>
class SimpleNumericAggregate : public exec::Aggregate {
 protected:
    explicit SimpleNumericAggregate(DataType resultType)
        : Aggregate(resultType) {
    }

    // Writes `numGroups` entries into `*result`.
    //
    // TData is the element type written to the vector. It is either
    // TAccumulator or TResult; these are usually the same type but can
    // differ, e.g. for sum(real).
    //
    // `*result` must be a ColumnVector whose elementSize() equals
    // sizeof(TData); both conditions are asserted. The vector is resized to
    // `numGroups`. For each group, a null accumulator marks the entry null;
    // otherwise the entry is marked valid and set to extractOneValue(group).
    template <typename TData = TResult, typename ExtractOneValue>
    void
    doExtractValues(char** groups,
                    int32_t numGroups,
                    VectorPtr* result,
                    ExtractOneValue extractOneValue) {
        AssertInfo((*result)->elementSize()==sizeof(TData), "Incorrect type size of input result vector");
        ColumnVectorPtr result_column = std::dynamic_pointer_cast<ColumnVector>(*result);
        AssertInfo(result_column != nullptr, "input vector for extracting aggregation must be of Type ColumnVector");
        result_column->resize(numGroups);

        auto* const out = static_cast<TData*>(result_column->GetRawData());
        for (int32_t row = 0; row < numGroups; ++row) {
            char* const group = groups[row];
            if (isNull(group)) {
                result_column->nullAt(row);
                continue;
            }
            result_column->clearNullAt(row);
            out[row] = extractOneValue(group);
        }
    }
};

}
}
35 changes: 35 additions & 0 deletions internal/core/src/exec/operator/query-agg/SumAggregate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "SumAggregateBase.h"
#include "RegisterAggregateFunctions.h"

namespace milvus {
namespace exec {

// Alias for SumAggregateBase with its fourth template argument
// (bool Overflow) fixed to false.
template <typename TInput, typename TAccumulator, typename ResultType>
using SumAggregate = SumAggregateBase<TInput, TAccumulator, ResultType, false>;

// Registration entry point for the sum aggregate.
//
// prefix                  name prefix supplied by the caller.
// withCompanionFunctions  caller's request to also register companion
//                         functions.
// overwrite               caller's request to replace an existing
//                         registration.
//
// TODO: registration is not implemented yet, so this is a no-op and all
// three arguments are ignored. The previous revision did not compile: the
// body held only the truncated token `regis`, and it was preceded by a
// dangling, malformed template header. A template-template parameter is
// spelled `template <template <typename, typename, typename> class T>`.
// Re-introduce that header together with the helper it was meant to
// declare once the aggregate registry API is available.
void
registerSumAggregate([[maybe_unused]] const std::string& prefix,
                     [[maybe_unused]] bool withCompanionFunctions,
                     [[maybe_unused]] bool overwrite) {
    // Intentionally empty; see the TODO above.
}
}
}
42 changes: 42 additions & 0 deletions internal/core/src/exec/operator/query-agg/SumAggregateBase.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "SimpleNumericAggregate.h"

namespace milvus{
namespace exec {
// Base class for sum aggregates.
//
// TInput        input value type, forwarded to SimpleNumericAggregate.
// TAccumulator  type stored in each group's accumulator slot.
// ResultType    element type written to the result vector.
// Overflow      not referenced anywhere in this class.
template <typename TInput,
          typename TAccumulator,
          typename ResultType,
          bool Overflow>
class SumAggregateBase
    : public SimpleNumericAggregate<TInput, TAccumulator, ResultType> {
    using BaseAggregate =
        SimpleNumericAggregate<TInput, TAccumulator, ResultType>;

 public:
    explicit SumAggregateBase(DataType resultType)
        : BaseAggregate(resultType) {
    }

    // The accumulator is exactly one TAccumulator.
    // Deliberately not constexpr: a virtual function may only be declared
    // constexpr from C++20 onwards.
    int32_t
    accumulatorFixedWidthSize() const override {
        return static_cast<int32_t>(sizeof(TAccumulator));
    }

    // NOTE(review): an alignment of 1 places no constraint on the slot
    // address, so the alignment assertion in Aggregate::value() can never
    // fire for this aggregate -- confirm that unaligned TAccumulator access
    // is acceptable on all target platforms.
    int32_t
    accumulatorAlignmentSize() const override {
        return 1;
    }

    // Writes one ResultType value per group into `*result`, converting
    // from the stored TAccumulator. The vector's element type is ResultType,
    // not TAccumulator, so doExtractValues must be instantiated with
    // ResultType; otherwise its element-size assertion fails whenever the
    // two types differ in size.
    void
    extractValues(char** groups,
                  int32_t numGroups,
                  VectorPtr* result) override {
        BaseAggregate::template doExtractValues<ResultType>(
            groups, numGroups, result, [this](char* group) {
                return static_cast<ResultType>(
                    *this->template value<TAccumulator>(group));
            });
    }
};
}
}

0 comments on commit 8c41b0f

Please sign in to comment.