diff --git a/src/core/functions/table/match.cpp b/src/core/functions/table/match.cpp index 30b48b9d..80a092c2 100644 --- a/src/core/functions/table/match.cpp +++ b/src/core/functions/table/match.cpp @@ -394,7 +394,7 @@ unique_ptr PGQMatchFunction::GenerateCSROperatorSubquery(sha csr_operator_children.emplace_back(CreateColumnRefExpression("rowid", dst_table_alias)); csr_operator_children.emplace_back(CreateColumnRefExpression("rowid", edge_table_alias)); - csr_operator_children.emplace_back(CreateColumnRefExpression("cnt", "count_per_vertex")); + csr_operator_children.emplace_back(CreateColumnRefExpression("cnt", "__count_per_vertex")); select_node->select_list.emplace_back(make_uniq("csr_operator", std::move(csr_operator_children))); auto edge_src_joinref = make_uniq(JoinRefType::REGULAR); @@ -425,9 +425,19 @@ unique_ptr PGQMatchFunction::GenerateCSROperatorSubquery(sha left_join->condition = make_uniq(ExpressionType::COMPARE_EQUAL, CreateColumnRefExpression(edge_table->source_pk[0], src_table_alias), CreateColumnRefExpression(edge_table->source_fk[0], edge_table_alias)); edges_per_vertex_select_node->from_table = std::move(left_join); - // edges_per_vertex_select_node->groups = make_uniq(); - - + edges_per_vertex_select_node->groups.group_expressions.emplace_back(CreateColumnRefExpression("rowid", src_table_alias, src_table_alias + "_rowid")); + GroupingSet grouping_set = {0}; + edges_per_vertex_select_node->groups.grouping_sets.push_back(grouping_set); + + edges_per_vertex_select_statement->node = std::move(edges_per_vertex_select_node); + edges_per_vertex_subquery->subquery = std::move(edges_per_vertex_select_statement); + edges_per_vertex_subquery->alias = "__count_per_vertex"; + auto final_join = make_uniq(JoinRefType::REGULAR); + final_join->left = std::move(edge_dst_joinref); + final_join->right = std::move(edges_per_vertex_subquery); + final_join->condition = make_uniq(ExpressionType::COMPARE_EQUAL, CreateColumnRefExpression(src_table_alias + "_rowid", "__count_per_vertex"), CreateColumnRefExpression("rowid", src_table_alias)); + + select_node->from_table = std::move(final_join); select_statement->node = std::move(select_node); result->subquery = std::move(select_statement); return result;