diff --git a/internal/core/src/common/Schema.h b/internal/core/src/common/Schema.h index 57bef10867f7d..f0c2842d7f62f 100644 --- a/internal/core/src/common/Schema.h +++ b/internal/core/src/common/Schema.h @@ -233,14 +233,14 @@ class Schema { } DataType - FieldType(const FieldId& field_id) const { + GetFieldType(const FieldId& field_id) const { AssertInfo(fields_.count(field_id), "field_id:{} is not existed in the schema", field_id.get()); auto& meta = fields_.at(field_id); return meta.get_data_type(); } const std::string& - FieldName(const FieldId& field_id) const { + GetFieldName(const FieldId& field_id) const { AssertInfo(fields_.count(field_id), "field_id:{} is not existed in the schema", field_id.get()); auto& meta = fields_.at(field_id); return meta.get_name().get(); diff --git a/internal/core/src/exec/Driver.cpp b/internal/core/src/exec/Driver.cpp index 794bb2e4362b0..971768a26bb11 100644 --- a/internal/core/src/exec/Driver.cpp +++ b/internal/core/src/exec/Driver.cpp @@ -26,7 +26,8 @@ #include "exec/operator/MvccNode.h" #include "exec/operator/Operator.h" #include "exec/operator/VectorSearchNode.h" -#include "exec/operator/GroupByNode.h" +#include "exec/operator/SearchGroupByNode.h" +#include "exec/operator/QueryGroupByNode.h" #include "exec/Task.h" #include "common/EasyAssert.h" @@ -72,11 +73,14 @@ DriverFactory::CreateDriver(std::unique_ptr ctx, plannode)) { operators.push_back(std::make_unique( id, ctx.get(), vectorsearchnode)); - } else if (auto groupbynode = - std::dynamic_pointer_cast( + } else if (auto vectorGroupByNode = + std::dynamic_pointer_cast( plannode)) { operators.push_back( - std::make_unique(id, ctx.get(), groupbynode)); + std::make_unique(id, ctx.get(), vectorGroupByNode)); + } else if (auto queryGroupByNode = std::dynamic_pointer_cast(plannode)) { + operators.push_back( + std::make_unique(id, ctx.get(), queryGroupByNode)); } // TODO: add more operators } diff --git a/internal/core/src/exec/operator/QueryGroupByNode.cpp 
b/internal/core/src/exec/operator/QueryGroupByNode.cpp new file mode 100644 index 0000000000000..a2d04040d7e65 --- /dev/null +++ b/internal/core/src/exec/operator/QueryGroupByNode.cpp @@ -0,0 +1,20 @@ +// +// Created by hanchun on 24-10-18. +// + +#include "QueryGroupByNode.h" + +namespace milvus { +namespace exec { +/* Constructs the operator by forwarding ctx, node->output_type(), operator_id and node->id() to the Operator base; the constructor body itself is empty. */ +PhyQueryGroupByNode::PhyQueryGroupByNode(int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& node): + Operator(ctx, node->output_type(), operator_id, node->id()){ + +} + + +} /* namespace exec */ +} /* namespace milvus */ + diff --git a/internal/core/src/exec/operator/QueryGroupByNode.h b/internal/core/src/exec/operator/QueryGroupByNode.h new file mode 100644 index 0000000000000..a62ff258c8d2c --- /dev/null +++ b/internal/core/src/exec/operator/QueryGroupByNode.h @@ -0,0 +1,62 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "exec/operator/Operator.h" + +namespace milvus{ +namespace exec{ +class PhyQueryGroupByNode: public Operator { /* Stub physical operator for query-side GROUP BY: input is accepted but dropped, and no output batch is produced yet. */ +public: + PhyQueryGroupByNode(int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& node); + + bool NeedInput() const override { + return true; /* always asks the driver for more input */ + } + + void AddInput(RowVectorPtr& input) override { /* input is currently ignored */ + } + + RowVectorPtr + GetOutput() override { + return nullptr; /* never yields a batch */ + } + + bool + IsFinished() override { + return false; /* NOTE(review): never reports completion; confirm the driver loop can still terminate */ + } + + bool + IsFilter() override { + return false; + } + + BlockingReason + IsBlocked(ContinueFuture* future) override { + return BlockingReason::kNotBlocked; /* future is unused */ + } + + void + Close() override { + input_ = nullptr; + results_.clear(); + } +}; +} +} diff --git a/internal/core/src/exec/operator/GroupByNode.cpp b/internal/core/src/exec/operator/SearchGroupByNode.cpp similarity index 92% rename from internal/core/src/exec/operator/GroupByNode.cpp rename to internal/core/src/exec/operator/SearchGroupByNode.cpp index b19461470d1e2..74b24d71980ee 100644 --- a/internal/core/src/exec/operator/GroupByNode.cpp +++ b/internal/core/src/exec/operator/SearchGroupByNode.cpp @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "GroupByNode.h" +#include "SearchGroupByNode.h" #include "exec/operator/groupby/SearchGroupByOperator.h" #include "monitor/prometheus_client.h" @@ -22,10 +22,10 @@ namespace milvus { namespace exec { -PhyGroupByNode::PhyGroupByNode( +PhySearchGroupByNode::PhySearchGroupByNode( int32_t operator_id, DriverContext* driverctx, - const std::shared_ptr& node) + const std::shared_ptr& node) : Operator(driverctx, node->output_type(), operator_id, node->id()) { ExecContext* exec_context = operator_context_->get_exec_context(); query_context_ = exec_context->get_query_context(); @@ -34,12 +34,12 @@ PhyGroupByNode::PhyGroupByNode( } void -PhyGroupByNode::AddInput(RowVectorPtr& input) { +PhySearchGroupByNode::AddInput(RowVectorPtr& input) { input_ = std::move(input); } RowVectorPtr -PhyGroupByNode::GetOutput() { +PhySearchGroupByNode::GetOutput() { if (is_finished_ || !no_more_input_) { return nullptr; } @@ -86,7 +86,7 @@ PhyGroupByNode::GetOutput() { } bool -PhyGroupByNode::IsFinished() { +PhySearchGroupByNode::IsFinished() { return is_finished_; } diff --git a/internal/core/src/exec/operator/GroupByNode.h b/internal/core/src/exec/operator/SearchGroupByNode.h similarity index 86% rename from internal/core/src/exec/operator/GroupByNode.h rename to internal/core/src/exec/operator/SearchGroupByNode.h index 90ce08832fd23..dc5c95bcd3369 100644 --- a/internal/core/src/exec/operator/GroupByNode.h +++ b/internal/core/src/exec/operator/SearchGroupByNode.h @@ -27,11 +27,11 @@ namespace milvus { namespace exec { -class PhyGroupByNode : public Operator { +class PhySearchGroupByNode : public Operator { public: - PhyGroupByNode(int32_t operator_id, - DriverContext* ctx, - const std::shared_ptr& node); + PhySearchGroupByNode(int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& node); bool IsFilter() override { @@ -63,7 +63,7 @@ class PhyGroupByNode : public Operator { virtual std::string ToString() const override { - return "PhyGroupByNode"; + return 
"PhySearchGroupByNode"; } private: diff --git a/internal/core/src/plan/PlanNode.h b/internal/core/src/plan/PlanNode.h index edda348a14c92..1d12e820d95ef 100644 --- a/internal/core/src/plan/PlanNode.h +++ b/internal/core/src/plan/PlanNode.h @@ -335,9 +335,9 @@ class VectorSearchNode : public PlanNode { const std::vector sources_; }; -class VectorGroupByNode : public PlanNode { +class SearchGroupByNode : public PlanNode { public: - VectorGroupByNode(const PlanNodeId& id, + SearchGroupByNode(const PlanNodeId& id, std::vector sources = std::vector{}) : PlanNode(id), sources_{std::move(sources)} { } @@ -354,12 +354,12 @@ class VectorGroupByNode : public PlanNode { std::string_view name() const override { - return "VectorGroupByNode"; + return "SearchGroupByNode"; } std::string ToString() const override { - return fmt::format("VectorGroupByNode:\n\t[source node:{}]", + return fmt::format("SearchGroupByNode:\n\t[source node:{}]", SourceToString()); } @@ -420,18 +420,33 @@ class AggregationNode: public PlanNode { Aggregate(expr::CallTypeExprPtr call):call_(call){} }; - std::vector sources() const override { - return sources_; - } - AggregationNode(const PlanNodeId& id, std::vector&& groupingKeys, std::vector&& aggNames, std::vector&& aggregates, - std::vector&& sources, - RowType&& output_type) + /*RowType&& output_type,*/ + std::vector sources = std::vector{}) /* sources is optional and defaults to an empty list */ : PlanNode(id), groupingKeys_(std::move(groupingKeys)), aggregateNames_(std::move(aggNames)), aggregates_(std::move(aggregates)), - sources_(std::move(sources)), output_type_(std::move(output_type)), ignoreNullKeys_(true){} + ignoreNullKeys_(true), sources_(std::move(sources))/*, output_type_(std::move(output_type))*/{} /* initializers are listed in member declaration order (ignoreNullKeys_ is declared before sources_) to avoid a reorder warning */ + + DataType + output_type() const override { + return DataType::BOOL; /* placeholder: hard-coded while the RowType 
output_type parameter is commented out */ + } + + std::vector sources() const override { + return sources_; + } + + std::string + ToString() const override{ + return ""; /* NOTE(review): empty description, whereas SearchGroupByNode::ToString formats its source nodes */ + } + + std::string_view + name() const override { + return "agg"; + } private: const std::vector 
groupingKeys_; @@ -439,7 +454,7 @@ class AggregationNode: public PlanNode { const std::vector aggregates_; const bool ignoreNullKeys_; const std::vector sources_; - const RowType output_type_; + //const RowType output_type_; }; enum class ExecutionStrategy { diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index 9cffdaf4cf82a..3d711d866ffa5 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -34,7 +34,8 @@ std::string getAggregateOpName(planpb::AggregateOp op) { case planpb::avg: return "avg"; case planpb::min: return "min"; case planpb::max: return "max"; - default: return "unknown"; + default: + PanicInfo(OpTypeInvalid, "Unknown op type for aggregation"); } } @@ -134,7 +135,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { sources = std::vector{plannode}; if (plan_node->search_info_.group_by_field_id_ != std::nullopt) { - plannode = std::make_shared( + plannode = std::make_shared( milvus::plan::GetNextPlanNodeId(), sources); sources = std::vector{plannode}; } @@ -203,11 +204,12 @@ ProtoParser::RetrievePlanNodeFromProto( auto input_field_id = query.group_by_field_ids(i); AssertInfo(input_field_id > 0, "input field_id to group by must be positive, but is:{}", input_field_id); auto field_id = FieldId(input_field_id); - auto field_type = schema.FieldType(field_id); + auto field_type = schema.GetFieldType(field_id); groupingKeys.emplace_back(std::make_shared(field_type, field_id)); } } std::vector aggregates; + std::vector agg_names; if (query.aggregates_size() > 0) { aggregates.reserve(query.aggregates_size()); for(int i = 0; i < query.aggregates_size(); i++) { @@ -215,15 +217,19 @@ ProtoParser::RetrievePlanNodeFromProto( auto input_agg_field_id = aggregate.field_id(); AssertInfo(input_agg_field_id > 0, "input field_id to aggregate must be positive, but is:{}", input_agg_field_id); auto field_id = FieldId(input_agg_field_id); - auto field_type = 
schema.FieldType(field_id); - auto field_name = schema.FieldName(field_id); + auto field_type = schema.GetFieldType(field_id); + auto field_name = schema.GetFieldName(field_id); auto agg_name = getAggregateOpName(aggregate.op()); + agg_names.emplace_back(agg_name); auto agg_input = std::make_shared(field_type, field_name, field_id); auto call = std::make_shared(field_type, std::vector{agg_input}, agg_name); aggregates.emplace_back(plan::AggregationNode::Aggregate{call}); - //check type conversion here } } + + plannode = std::make_shared(milvus::plan::GetNextPlanNodeId(), std::move(groupingKeys), + std::move(agg_names), std::move(aggregates), std::move(sources)); + sources = std::vector{plannode}; node->plannodes_ = plannode; } return node;