Skip to content

Commit

Permalink
complete aggregation function factory
Browse files Browse the repository at this point in the history
  • Loading branch information
MrPresent-Han committed Nov 22, 2024
1 parent 4a2d8df commit 22ad71b
Show file tree
Hide file tree
Showing 23 changed files with 176 additions and 25 deletions.
4 changes: 4 additions & 0 deletions internal/core/src/common/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "Json.h"

#include "CustomBitset.h"
#include "log/Log.h"

namespace milvus {

Expand Down Expand Up @@ -752,9 +753,12 @@ class RowType final {

column_index_t GetChildIndex(std::string name) const {
std::optional<column_index_t> idx;
LOG_INFO("hc===names_.size():{}", names_.size());
for(auto i = 0; i < names_.size(); i++) {
LOG_INFO("hc===names_[i]:{}, name:{}", names_[i], name);
if (names_[i] == name) {
idx = i;
break;
}
}
AssertInfo(idx.has_value(), "Cannot find target column in the rowType list");
Expand Down
11 changes: 11 additions & 0 deletions internal/core/src/exec/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ Driver::Run(std::shared_ptr<Driver> self) {
}
}

/// Runs initialize() on every operator owned by this driver, at most once.
/// The flag is raised before the loop starts, so later calls are no-ops.
void Driver::initializeOperators() {
    if (!operatorsInitialized_) {
        operatorsInitialized_ = true;
        for (auto it = operators_.begin(); it != operators_.end(); ++it) {
            (*it)->initialize();
        }
    }
}

void
Driver::Init(std::unique_ptr<DriverContext> ctx,
std::vector<std::unique_ptr<Operator>> operators) {
Expand Down Expand Up @@ -202,6 +212,7 @@ Driver::RunInternal(std::shared_ptr<Driver>& self,
std::shared_ptr<BlockingState>& blocking_state,
RowVectorPtr& result) {
try {
initializeOperators();
int num_operators = operators_.size();
LOG_INFO("hc===operator_size:{}", num_operators);
ContinueFuture future;
Expand Down
6 changes: 6 additions & 0 deletions internal/core/src/exec/Driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ class Driver : public std::enable_shared_from_this<Driver> {
EnqueueInternal() {
}

/// Invoked to initialize the operators from this driver once on its first
/// execution.
void initializeOperators();

static void
Run(std::shared_ptr<Driver> self);

Expand All @@ -238,6 +242,8 @@ class Driver : public std::enable_shared_from_this<Driver> {

size_t current_operator_index_{0};

bool operatorsInitialized_{false};

BlockingReason blocking_reason_{BlockingReason::kNotBlocked};

friend struct DriverFactory;
Expand Down
9 changes: 9 additions & 0 deletions internal/core/src/exec/expression/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,14 @@ GetValueFromProtoWithOverflow(
return GetValueFromProtoInternal<T>(value_proto, overflowed);
}

/// Returns a lower-cased copy of `name`.
///
/// Declared `inline` because this definition lives in a header: without it,
/// every translation unit that includes the header emits its own external
/// definition and linking fails with duplicate symbols (ODR violation).
///
/// @param name  Identifier to normalize; may be empty.
/// @return      `name` with every character passed through std::tolower.
inline std::string sanitizeName(const std::string& name) {
    std::string sanitizedName(name.size(), '\0');
    std::transform(name.begin(),
                   name.end(),
                   sanitizedName.begin(),
                   [](unsigned char c) {
                       // std::tolower is only defined for values representable
                       // as unsigned char (or EOF), hence the parameter type;
                       // the cast makes the int -> char narrowing explicit.
                       return static_cast<char>(std::tolower(c));
                   });
    return sanitizedName;
}

} // namespace exec
} // namespace milvus
6 changes: 6 additions & 0 deletions internal/core/src/exec/expression/function/init_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,14 @@

#include "exec/expression/function/init_c.h"
#include "exec/expression/function/FunctionFactory.h"
#include "exec/operator/query-agg/RegisterAggregateFunctions.h"

// C-linkage entry point: calls Initialize() on the FunctionFactory object
// returned by FunctionFactory::Instance().
void
InitExecExpressionFunctionFactory() {
    auto&& factory = milvus::exec::expression::FunctionFactory::Instance();
    factory.Initialize();
}

void
RegisterAggregationFunctions(){
milvus::exec::registerAllAggregateFunctions();
}
5 changes: 5 additions & 0 deletions internal/core/src/exec/expression/function/init_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@

#pragma once

#include "RegisterAggregateFunctions.h"

// The declarations below get C linkage when this header is compiled as C++.
// NOTE(review): the "RegisterAggregateFunctions.h" include above sits outside
// this __cplusplus guard; if that header contains C++-only constructs
// (namespaces, std::string), this header can no longer be included from C —
// confirm whether a C/cgo consumer includes it.
#ifdef __cplusplus
extern "C" {
#endif

// Initializes the expression FunctionFactory; implemented in init_c.cpp.
void
InitExecExpressionFunctionFactory();

// Registers the aggregate functions by calling
// milvus::exec::registerAllAggregateFunctions(); implemented in init_c.cpp.
void
RegisterAggregationFunctions();

#ifdef __cplusplus
};
#endif
9 changes: 7 additions & 2 deletions internal/core/src/exec/operator/AggregationNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void PhyAggregationNode::prepareOutput(vector_size_t size){
}

RowVectorPtr PhyAggregationNode::GetOutput() {
LOG_INFO("hc==enter PhyAggregationNode, {}", grouping_set_==nullptr);
if (finished_||(!no_more_input_ && !grouping_set_->hasOutput())) {
LOG_INFO("hc==skip running aggnode");
input_ = nullptr;
Expand All @@ -52,15 +53,18 @@ RowVectorPtr PhyAggregationNode::GetOutput() {

void PhyAggregationNode::initialize() {
Operator::initialize();
LOG_INFO("hc===start to init phy agg operator, aggregationNode_->sources.size:{}", aggregationNode_->sources().size());
const auto& input_type = aggregationNode_->sources()[0]->output_type();
auto hashers = createVectorHashers(input_type, aggregationNode_->GroupingKeys());
auto numHashers = hashers.size();
LOG_INFO("hc===hasher.size:{}", numHashers);
std::vector<AggregateInfo> aggregateInfos = toAggregateInfo(*aggregationNode_,
*operator_context_,
numHashers);

LOG_INFO("hc===aggregateInfos.size:{}", aggregateInfos.size());
// Check that aggregate result type match the output type.
for (auto i = 0; i < aggregateInfos.size(); i++) {
LOG_INFO("hc===asserted aggregation type:{}", i);
const auto aggResultType = aggregateInfos[i].function_->resultType();
const auto expectedType = output_type_->column_type(numHashers + i);
AssertInfo(aggResultType==expectedType,
Expand All @@ -69,13 +73,14 @@ void PhyAggregationNode::initialize() {
expectedType,
plan::AggregationNode::stepName(aggregationNode_->step()));
}

LOG_INFO("hc===asserted aggregation type");
grouping_set_ = std::make_unique<GroupingSet>(
input_type,
std::move(hashers),
std::move(aggregateInfos),
!aggregationNode_->ignoreNullKeys(),
isRawInput(aggregationNode_->step()));
LOG_INFO("hc===has init AggregationNode");
aggregationNode_.reset();
}

Expand Down
6 changes: 3 additions & 3 deletions internal/core/src/exec/operator/ProjectNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,18 @@ PhyProjectNode::GetOutput() {
LOG_INFO("hc==start running project node");
auto col_input = GetColumnVector(input_);
TargetBitmapView bitset_view(col_input->GetRawData(), col_input->size());
auto result_pair = segment_->find_first(0, bitset_view);
auto result_pair = segment_->find_first(10000, bitset_view);
auto selected_offsets = result_pair.first;
auto selected_count = selected_offsets.size();
is_finished_ = true;

LOG_INFO("hc==project_selected_count:{}", selected_count);
auto row_type = OutputType();
std::vector<VectorPtr> column_vectors;
for (int i = 0; i < fields_to_project_.size(); i++) {
auto column_type = row_type->column_type(i);
auto field_id = fields_to_project_.at(i);
auto field_data = segment_->bulk_subscript(field_id, selected_offsets.data(), selected_count);
auto column_vector = std::make_shared<ColumnVector>(column_type, selected_count);
auto field_data = segment_->bulk_subscript(field_id, selected_offsets.data(), selected_count);
column_vectors.emplace_back(column_vector);
}
auto row_vector = std::make_shared<RowVector>(std::move(column_vectors));
Expand Down
26 changes: 23 additions & 3 deletions internal/core/src/exec/operator/query-agg/Aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//
#include "Aggregate.h"
#include "AggregateUtil.h"
#include "exec/expression/Utils.h"

namespace milvus{
namespace exec{
Expand All @@ -19,13 +20,32 @@ void Aggregate::setOffsetsInternal(int32_t offset,
initializedByte_ = initializedByte;
initializedMask_ = initializedMask;
rowSizeOffset_ = rowSizeOffset;
}
}

/// Finds the registry entry for the aggregate function called `name`.
///
/// The name is normalized with sanitizeName() before the lookup, and the
/// registry map is consulted under its read lock.
///
/// @return Pointer to the entry stored in the registry map, or nullptr when
///         no function is registered under the normalized name. The pointer
///         refers to the map element itself and is handed back after the
///         read lock has been released.
const AggregateFunctionEntry*
getAggregateFunctionEntry(const std::string& name) {
    const auto key = milvus::exec::sanitizeName(name);

    auto lookup = [&key](const auto& registry) -> const AggregateFunctionEntry* {
        const auto found = registry.find(key);
        return found == registry.end() ? nullptr : &found->second;
    };
    return aggregateFunctions().withRLock(lookup);
}

std::unique_ptr<Aggregate> Aggregate::create(const std::string& name,
plan::AggregationNode::Step step,
const std::vector<DataType>& argTypes,
DataType resultType) {
return nullptr;
DataType resultType,
const QueryConfig& query_config) {
if(auto func = getAggregateFunctionEntry(name)) {
LOG_INFO("hc=== found aggregation function factory for name:{}", name);
return func->factory(step, argTypes, resultType, query_config);
}
PanicInfo(UnexpectedError, "Aggregate function not registered: {}", name);
}

bool isRawInput(milvus::plan::AggregationNode::Step step) {
Expand Down
6 changes: 5 additions & 1 deletion internal/core/src/exec/operator/query-agg/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class Aggregate {
const std::string& name,
plan::AggregationNode::Step step,
const std::vector<DataType>& argTypes,
DataType resultType);
DataType resultType,
const QueryConfig& query_config);

void setOffsets(
int32_t offset,
Expand Down Expand Up @@ -159,6 +160,9 @@ struct AggregateFunctionEntry {
AggregateFunctionFactory factory;
};

/// Returns the registry entry for the aggregate function `name`, or nullptr
/// when nothing is registered under that name.
const AggregateFunctionEntry*
getAggregateFunctionEntry(const std::string& name);

/// Registry type: a folly::Synchronized-wrapped map from function name to its
/// AggregateFunctionEntry.
using AggregateFunctionMap = folly::Synchronized<std::unordered_map<std::string, AggregateFunctionEntry>>;

/// Accessor for the aggregate function registry.
AggregateFunctionMap& aggregateFunctions();
Expand Down
3 changes: 2 additions & 1 deletion internal/core/src/exec/operator/query-agg/AggregateInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ std::vector<AggregateInfo> toAggregateInfo(
aggregate.call_->fun_name(),
isPartialOutput(step)? plan::AggregationNode::Step::kPartial:plan::AggregationNode::Step::kSingle,
aggregate.rawInputTypes_,
aggResultType);
aggResultType,
*(operatorCtx.get_exec_context()->get_query_config()));
info.output_ = index;
aggregates.emplace_back(std::move(info));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,11 @@
namespace milvus {
namespace exec {

/// Registers the aggregate functions known to this module. At present that is
/// only the sum aggregate; all three arguments are passed through unchanged
/// to registerSumAggregate().
void
registerAllAggregateFunctions(
    const std::string& prefix,
    bool withCompanionFunctions,
    bool overwrite) {
    registerSumAggregate(
        prefix,
        withCompanionFunctions,
        overwrite);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

namespace milvus{
namespace exec {
/// Registers all aggregate functions provided by this module.
///
/// @param prefix                  Passed to each per-function registration
///                                call (presumably a name prefix — TODO confirm).
/// @param withCompanionFunctions  Passed to each per-function registration call.
/// @param overwrite               Passed to each per-function registration call.
void registerAllAggregateFunctions(const std::string& prefix = "",
                                   bool withCompanionFunctions = true,
                                   bool overwrite = true);

Expand Down
3 changes: 1 addition & 2 deletions internal/core/src/plan/PlanNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ RowTypePtr getAggregationOutputType(const std::vector<expr::FieldAccessTypeExprP
std::vector<std::string> names;
std::vector<milvus::DataType> types;
for (auto& key : groupingKeys) {
types.emplace_back(key->type());
names.emplace_back(key->name());
types.emplace_back(key->type());
}

for (int i = 0; i < aggregateNames.size(); i++) {
Expand All @@ -41,7 +41,6 @@ AggregationNode::AggregationNode(const milvus::plan::PlanNodeId &id,
std::vector<expr::FieldAccessTypeExprPtr> &&groupingKeys,
std::vector<std::string> &&aggNames,
std::vector<Aggregate> &&aggregates,
std::shared_ptr<const RowType> output_type,
std::vector<PlanNodePtr> sources):
PlanNode(id),
step_(step),
Expand Down
13 changes: 9 additions & 4 deletions internal/core/src/plan/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,13 @@ class ProjectNode : public PlanNode {
public:
/// Builds a projection node over `sources`.
///
/// NOTE: `field_ids`, `field_names` and `field_types` are taken by non-const
/// lvalue reference and moved from, so the caller's vectors are left in a
/// moved-from state after construction.
///
/// @param id           Plan node id forwarded to PlanNode.
/// @param field_ids    Ids of the fields to project; moved into the node.
/// @param field_names  Column names used to build the output RowType; moved.
/// @param field_types  Column types used to build the output RowType; moved.
/// @param sources      Upstream plan nodes (defaults to none).
ProjectNode(const PlanNodeId& id,
            std::vector<FieldId>& field_ids,
            std::vector<std::string>& field_names,
            std::vector<milvus::DataType>& field_types,
            std::vector<PlanNodePtr> sources = std::vector<PlanNodePtr>{})
    : PlanNode(id),
      sources_(std::move(sources)),
      field_ids_(std::move(field_ids)),
      output_type_(std::make_shared<RowType>(std::move(field_names),
                                             std::move(field_types))) {
}

std::vector<PlanNodePtr>
Expand All @@ -286,7 +291,7 @@ class ProjectNode : public PlanNode {

/// Returns the row type built from the projected field names and types at
/// construction time.
RowTypePtr
output_type() const override {
    // The stale `return RowType::None;` that preceded this line made the
    // stored type unreachable; only the stored type is returned now.
    return output_type_;
}

std::string_view
Expand All @@ -307,6 +312,7 @@ class ProjectNode : public PlanNode {
private:
const std::vector<PlanNodePtr> sources_;
const std::vector<FieldId> field_ids_;
const RowTypePtr output_type_;
};

class MvccNode : public PlanNode {
Expand Down Expand Up @@ -470,13 +476,12 @@ class AggregationNode: public PlanNode {
std::vector<expr::FieldAccessTypeExprPtr>&& groupingKeys,
std::vector<std::string>&& aggNames,
std::vector<Aggregate>&& aggregates,
std::shared_ptr<const RowType> output_type,
std::vector<PlanNodePtr> sources = std::vector<PlanNodePtr>{});


/// Returns the output row type stored on this aggregation node.
RowTypePtr
output_type() const override {
    // The stale `return RowType::None;` that preceded this line made the
    // stored type unreachable; only the stored type is returned now.
    return output_type_;
}

std::vector<PlanNodePtr> sources() const override {
Expand Down
Loading

0 comments on commit 22ad71b

Please sign in to comment.