Skip to content

Commit

Permalink
enhance: adding agg operator(milvus-io#37009)
Browse files Browse the repository at this point in the history
Signed-off-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han committed Oct 21, 2024
1 parent 66c555f commit c4c58aa
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 35 deletions.
4 changes: 2 additions & 2 deletions internal/core/src/common/Schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,14 @@ class Schema {
}

DataType
FieldType(const FieldId& field_id) const {
GetFieldType(const FieldId& field_id) const {
AssertInfo(fields_.count(field_id), "field_id:{} is not existed in the schema", field_id.get());
auto& meta = fields_.at(field_id);
return meta.get_data_type();
}

const std::string&
FieldName(const FieldId& field_id) const {
GetFieldName(const FieldId& field_id) const {
AssertInfo(fields_.count(field_id), "field_id:{} is not existed in the schema", field_id.get());
auto& meta = fields_.at(field_id);
return meta.get_name().get();
Expand Down
12 changes: 8 additions & 4 deletions internal/core/src/exec/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
#include "exec/operator/MvccNode.h"
#include "exec/operator/Operator.h"
#include "exec/operator/VectorSearchNode.h"
#include "exec/operator/GroupByNode.h"
#include "exec/operator/SearchGroupByNode.h"
#include "exec/operator/QueryGroupByNode.h"
#include "exec/Task.h"

#include "common/EasyAssert.h"
Expand Down Expand Up @@ -72,11 +73,14 @@ DriverFactory::CreateDriver(std::unique_ptr<DriverContext> ctx,
plannode)) {
operators.push_back(std::make_unique<PhyVectorSearchNode>(
id, ctx.get(), vectorsearchnode));
} else if (auto groupbynode =
std::dynamic_pointer_cast<const plan::VectorGroupByNode>(
} else if (auto vectorGroupByNode =
std::dynamic_pointer_cast<const plan::SearchGroupByNode>(
plannode)) {
operators.push_back(
std::make_unique<PhyGroupByNode>(id, ctx.get(), groupbynode));
std::make_unique<PhySearchGroupByNode>(id, ctx.get(), vectorGroupByNode));
} else if (auto queryGroupByNode = std::dynamic_pointer_cast<const plan::AggregationNode>(plannode)) {
operators.push_back(
std::make_unique<PhyQueryGroupByNode>(id, ctx.get(), queryGroupByNode));
}
// TODO: add more operators
}
Expand Down
20 changes: 20 additions & 0 deletions internal/core/src/exec/operator/QueryGroupByNode.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//
// Created by hanchun on 24-10-18.
//

#include "QueryGroupByNode.h"

namespace milvus {
namespace exec {

PhyQueryGroupByNode::PhyQueryGroupByNode(int32_t operator_id,
DriverContext* ctx,
const std::shared_ptr<const plan::AggregationNode>& node):
Operator(ctx, node->output_type(), operator_id, node->id()){

}


}
}

62 changes: 62 additions & 0 deletions internal/core/src/exec/operator/QueryGroupByNode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "exec/operator/Operator.h"

namespace milvus{
namespace exec{
class PhyQueryGroupByNode: public Operator {
public:
PhyQueryGroupByNode(int32_t operator_id,
DriverContext* ctx,
const std::shared_ptr<const plan::AggregationNode>& node);

bool NeedInput() const override {
return true;
}

void AddInput(RowVectorPtr& input) override {

}

RowVectorPtr
GetOutput() override {
return nullptr;
}

bool
IsFinished() override {
return false;
}

bool
IsFilter() override {
return false;
}

BlockingReason
IsBlocked(ContinueFuture* future){
return BlockingReason::kNotBlocked;
}

virtual void
Close() {
input_ = nullptr;
results_.clear();
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "GroupByNode.h"
#include "SearchGroupByNode.h"

#include "exec/operator/groupby/SearchGroupByOperator.h"
#include "monitor/prometheus_client.h"

namespace milvus {
namespace exec {

PhyGroupByNode::PhyGroupByNode(
PhySearchGroupByNode::PhySearchGroupByNode(
int32_t operator_id,
DriverContext* driverctx,
const std::shared_ptr<const plan::VectorGroupByNode>& node)
const std::shared_ptr<const plan::SearchGroupByNode>& node)
: Operator(driverctx, node->output_type(), operator_id, node->id()) {
ExecContext* exec_context = operator_context_->get_exec_context();
query_context_ = exec_context->get_query_context();
Expand All @@ -34,12 +34,12 @@ PhyGroupByNode::PhyGroupByNode(
}

void
PhyGroupByNode::AddInput(RowVectorPtr& input) {
PhySearchGroupByNode::AddInput(RowVectorPtr& input) {
input_ = std::move(input);
}

RowVectorPtr
PhyGroupByNode::GetOutput() {
PhySearchGroupByNode::GetOutput() {
if (is_finished_ || !no_more_input_) {
return nullptr;
}
Expand Down Expand Up @@ -86,7 +86,7 @@ PhyGroupByNode::GetOutput() {
}

bool
PhyGroupByNode::IsFinished() {
PhySearchGroupByNode::IsFinished() {
return is_finished_;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
namespace milvus {
namespace exec {

class PhyGroupByNode : public Operator {
class PhySearchGroupByNode : public Operator {
public:
PhyGroupByNode(int32_t operator_id,
DriverContext* ctx,
const std::shared_ptr<const plan::VectorGroupByNode>& node);
PhySearchGroupByNode(int32_t operator_id,
DriverContext* ctx,
const std::shared_ptr<const plan::SearchGroupByNode>& node);

bool
IsFilter() override {
Expand Down Expand Up @@ -63,7 +63,7 @@ class PhyGroupByNode : public Operator {

virtual std::string
ToString() const override {
return "PhyGroupByNode";
return "PhySearchGroupByNode";
}

private:
Expand Down
39 changes: 27 additions & 12 deletions internal/core/src/plan/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,9 @@ class VectorSearchNode : public PlanNode {
const std::vector<PlanNodePtr> sources_;
};

class VectorGroupByNode : public PlanNode {
class SearchGroupByNode : public PlanNode {
public:
VectorGroupByNode(const PlanNodeId& id,
SearchGroupByNode(const PlanNodeId& id,
std::vector<PlanNodePtr> sources = std::vector<PlanNodePtr>{})
: PlanNode(id), sources_{std::move(sources)} {
}
Expand All @@ -354,12 +354,12 @@ class VectorGroupByNode : public PlanNode {

std::string_view
name() const override {
return "VectorGroupByNode";
return "SearchGroupByNode";
}

std::string
ToString() const override {
return fmt::format("VectorGroupByNode:\n\t[source node:{}]",
return fmt::format("SearchGroupByNode:\n\t[source node:{}]",
SourceToString());
}

Expand Down Expand Up @@ -420,26 +420,41 @@ class AggregationNode: public PlanNode {
Aggregate(expr::CallTypeExprPtr call):call_(call){}
};

std::vector<PlanNodePtr> sources() const override {
return sources_;
}

AggregationNode(const PlanNodeId& id,
std::vector<expr::FieldAccessTypeExprPtr>&& groupingKeys,
std::vector<std::string>&& aggNames,
std::vector<Aggregate>&& aggregates,
std::vector<PlanNodePtr>&& sources,
RowType&& output_type)
/*RowType&& output_type,*/
std::vector<PlanNodePtr> sources = std::vector<PlanNodePtr>{})
: PlanNode(id), groupingKeys_(std::move(groupingKeys)), aggregateNames_(std::move(aggNames)), aggregates_(std::move(aggregates)),
sources_(std::move(sources)), output_type_(std::move(output_type)), ignoreNullKeys_(true){}
sources_(std::move(sources))/*, output_type_(std::move(output_type))*/, ignoreNullKeys_(true){}

DataType
output_type() const override {
return DataType::BOOL;
}

std::vector<PlanNodePtr> sources() const override {
return sources_;
}

std::string
ToString() const override{
return "";
}

std::string_view
name() const override {
return "agg";
}

private:
const std::vector<expr::FieldAccessTypeExprPtr> groupingKeys_;
const std::vector<std::string> aggregateNames_;
const std::vector<Aggregate> aggregates_;
const bool ignoreNullKeys_;
const std::vector<PlanNodePtr> sources_;
const RowType output_type_;
//const RowType output_type_;
};

enum class ExecutionStrategy {
Expand Down
18 changes: 12 additions & 6 deletions internal/core/src/query/PlanProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ std::string getAggregateOpName(planpb::AggregateOp op) {
case planpb::avg: return "avg";
case planpb::min: return "min";
case planpb::max: return "max";
default: return "unknown";
default:
PanicInfo(OpTypeInvalid, "Unknown op type for aggregation");
}
}

Expand Down Expand Up @@ -134,7 +135,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};

if (plan_node->search_info_.group_by_field_id_ != std::nullopt) {
plannode = std::make_shared<milvus::plan::VectorGroupByNode>(
plannode = std::make_shared<milvus::plan::SearchGroupByNode>(
milvus::plan::GetNextPlanNodeId(), sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
}
Expand Down Expand Up @@ -203,27 +204,32 @@ ProtoParser::RetrievePlanNodeFromProto(
auto input_field_id = query.group_by_field_ids(i);
AssertInfo(input_field_id > 0, "input field_id to group by must be positive, but is:{}", input_field_id);
auto field_id = FieldId(input_field_id);
auto field_type = schema.FieldType(field_id);
auto field_type = schema.GetFieldType(field_id);
groupingKeys.emplace_back(std::make_shared<const expr::FieldAccessTypeExpr>(field_type, field_id));
}
}
std::vector<plan::AggregationNode::Aggregate> aggregates;
std::vector<std::string> agg_names;
if (query.aggregates_size() > 0) {
aggregates.reserve(query.aggregates_size());
for(int i = 0; i < query.aggregates_size(); i++) {
auto aggregate = query.aggregates(i);
auto input_agg_field_id = aggregate.field_id();
AssertInfo(input_agg_field_id > 0, "input field_id to aggregate must be positive, but is:{}", input_agg_field_id);
auto field_id = FieldId(input_agg_field_id);
auto field_type = schema.FieldType(field_id);
auto field_name = schema.FieldName(field_id);
auto field_type = schema.GetFieldType(field_id);
auto field_name = schema.GetFieldName(field_id);
auto agg_name = getAggregateOpName(aggregate.op());
agg_names.emplace_back(agg_name);
auto agg_input = std::make_shared<expr::FieldAccessTypeExpr>(field_type, field_name, field_id);
auto call = std::make_shared<const expr::CallTypeExpr>(field_type, std::vector<expr::TypedExprPtr>{agg_input}, agg_name);
aggregates.emplace_back(plan::AggregationNode::Aggregate{call});
//check type conversion here
}
}

plannode = std::make_shared<plan::AggregationNode>(milvus::plan::GetNextPlanNodeId(), std::move(groupingKeys),
std::move(agg_names), std::move(aggregates), std::move(sources));
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
node->plannodes_ = plannode;
}
return node;
Expand Down

0 comments on commit c4c58aa

Please sign in to comment.