diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 4ab4165a446733e..4848248e43d634e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -28,6 +28,7 @@
 import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
 import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
 import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
+import org.apache.doris.nereids.rules.analysis.ProjectAggregateExpressionsForCse;
 import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
 import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization;
 import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
@@ -191,6 +192,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
                         topDown(new PushDownFilterThroughProject()),
                         custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION,
                                 AggScalarSubQueryToWindowFunction::new),
+                        topDown(new ProjectAggregateExpressionsForCse()),
                         bottomUp(
                                 new EliminateUselessPlanUnderApply(),
                                 // CorrelateApplyToUnCorrelateApply and ApplyToJoin
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java
index 8ff3a43bf79e215..b6bded127489e4f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java
@@ -80,7 +80,7 @@ public Plan visit(Plan plan, CascadesContext context) {
             List childrenOutput = plan.children().stream().flatMap(p -> p.getOutput().stream()).collect(
                     Collectors.toList());
             throw new AnalysisException("A expression contains slot not from children\n"
-                    + "Plan: " + plan + "\n"
+                    + "Plan: " + plan.treeString() + "\n"
                     + "Children Output:" + childrenOutput + "\n"
                     + "Slot: " + opt.get() + "\n");
         }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index ca26ab1d9f843c5..20b16e3d428d23d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -71,6 +71,7 @@ public enum RuleType {
     RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
     PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
     HAVING_TO_FILTER(RuleTypeClass.REWRITE),
+    PROJECT_AGGREGATE_EXPRESSIONS_FOR_CSE(RuleTypeClass.REWRITE),
     ONE_ROW_RELATION_EXTRACT_AGGREGATE(RuleTypeClass.REWRITE),
     PROJECT_WITH_DISTINCT_TO_AGGREGATE(RuleTypeClass.REWRITE),
     AVG_DISTINCT_TO_SUM_DIV_COUNT(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectAggregateExpressionsForCse.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectAggregateExpressionsForCse.java
new file mode 100644
index 000000000000000..1bdeb684cbb555a
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectAggregateExpressionsForCse.java
@@ -0,0 +1,143 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.analysis;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Plan pattern:
+ * agg(output[sum(A+B), sum(A+B+1)])
+ *   =>
+ * agg(output[sum(#1), sum(#2)])
+ *    +--->project(A+B as #1, A+B+1 as #2)
+ *
+ * after this transformation, we have the opportunity to extract
+ * common sub expression "A+B" by CommonSubExpressionOpt processor
+ *
+ * note:
+ * select sum(A), C+1, abs(C+1) from T group by C
+ * C+1 is not pushed down to bottom project, because C+1 is not agg output.
+ * after AggregateNormalize, the plan is:
+ * project: output C+1, abs(C+1)
+ *    +-->agg: output sum(A), C
+ *      +--->Scan
+ * C+1 is processed with the project above agg
+ *
+ */
+public class ProjectAggregateExpressionsForCse extends OneAnalysisRuleFactory {
+    @Override
+    public Rule build() {
+        return logicalAggregate()
+                .then(this::addProjectionIfNeed)
+                .toRule(RuleType.PROJECT_AGGREGATE_EXPRESSIONS_FOR_CSE);
+    }
+
+    /**
+     * Moves every non-slot, non-constant argument of an aggregate function into a project below the aggregate.
+     *
+     * @param aggregate the matched aggregate
+     * @return the aggregate rebuilt on top of the project, or null when no such argument exists
+     */
+    private LogicalAggregate<? extends Plan> addProjectionIfNeed(LogicalAggregate<? extends Plan> aggregate) {
+        // select sum(A+B), ...
+        // "A+B" is a cse candidate
+        // cseCandidates: A+B -> alias(A+B)
+        // insertion-ordered collections keep the generated projections in a stable order
+        Map<Expression, Alias> cseCandidates = new LinkedHashMap<>();
+        Set<Slot> inputSlots = new LinkedHashSet<>();
+
+        for (Expression expr : aggregate.getExpressions()) {
+            getCseCandidatesFromAggregateFunction(expr, cseCandidates);
+            inputSlots.addAll(expr.getInputSlots());
+        }
+
+        if (cseCandidates.isEmpty()) {
+            // no opportunity to generate cse
+            return null;
+        }
+
+        // select sum(A+B),...
+        // slotMap: A+B -> alias(A+B) to slot#3
+        // sum(A+B) is replaced by sum(slot#3)
+        Map<Expression, Slot> slotMap = new LinkedHashMap<>();
+        for (Map.Entry<Expression, Alias> candidate : cseCandidates.entrySet()) {
+            slotMap.put(candidate.getKey(), candidate.getValue().toSlot());
+        }
+        List<NamedExpression> aggOutputReplaced = new ArrayList<>();
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            aggOutputReplaced.add((NamedExpression) ExpressionUtils.replace(expr, slotMap));
+        }
+
+        LogicalProject<? extends Plan> project;
+        if (aggregate.child() instanceof LogicalProject) {
+            // a project already sits below the aggregate: keep its outputs and append the new aliases
+            LogicalProject<? extends Plan> childProject = (LogicalProject<? extends Plan>) aggregate.child();
+            List<NamedExpression> newProjections = new ArrayList<>(childProject.getProjects());
+            newProjections.addAll(cseCandidates.values());
+            project = childProject.withProjectsAndChild(newProjections, childProject.child());
+        } else {
+            // no project below: create one that forwards every slot the aggregate reads plus the new aliases
+            List<NamedExpression> projections = new ArrayList<>(inputSlots);
+            projections.addAll(cseCandidates.values());
+            project = new LogicalProject<>(projections, aggregate.child(0));
+        }
+        return (LogicalAggregate<? extends Plan>) aggregate
+                .withAggOutput(aggOutputReplaced)
+                .withChildren(project);
+    }
+
+    /**
+     * Collects the arguments of aggregate functions found in expr that are neither slots nor constants.
+     *
+     * @param expr the expression to search; sub-expressions outside aggregate functions are searched recursively
+     * @param result receives one entry per collected argument: argument -> alias(argument)
+     */
+    private void getCseCandidatesFromAggregateFunction(Expression expr, Map<Expression, Alias> result) {
+        boolean isAggregateFunction = expr instanceof AggregateFunction;
+        for (Expression child : expr.children()) {
+            if (child instanceof SlotReference || child.isConstant()) {
+                continue;
+            }
+            if (!isAggregateFunction) {
+                getCseCandidatesFromAggregateFunction(child, result);
+            } else if (child instanceof Alias) {
+                result.put(child, (Alias) child);
+            } else {
+                result.put(child, new Alias(child));
+            }
+        }
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index 2798b1aef0102b2..28f7ff803079759 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -209,6 +209,7 @@ public R accept(PlanVisitor visitor, C context) {
     public List getExpressions() {
         return new ImmutableList.Builder()
                 .addAll(outputExpressions)
+                .addAll(groupByExpressions)
                 .build();
     }
 
diff --git a/regression-test/suites/nereids_tpch_p0/tpch/agg_cse.groovy b/regression-test/suites/nereids_tpch_p0/tpch/agg_cse.groovy
new file mode 100644
index 000000000000000..31c74c5f9b810f9
--- /dev/null
+++ b/regression-test/suites/nereids_tpch_p0/tpch/agg_cse.groovy
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+suite("agg_cse") {
+    String db = context.config.getDbNameByFile(new File(context.file.parent))
+    sql "use ${db}"
+    sql 'set enable_nereids_planner=true'
+    sql 'set enable_fallback_to_original_planner=false'
+
+    qt_select """
+    select sum(r_regionkey + r_regionkey), avg(r_regionkey + r_regionkey), sum(r_regionkey + (r_regionkey+1))
+    from region group by r_name;
+    """
+    explain{
+        sql """
+            select sum(r_regionkey + r_regionkey), avg(r_regionkey + r_regionkey), sum(r_regionkey + (r_regionkey+1))
+            from region group by r_name;
+            """
+        contains("intermediate projections:")
+    }
+// expect plan: intermediate projections in OlapScanNode
+// 0:VOlapScanNode(168)
+// TABLE: tpch.region(region), PREAGGREGATION: ON
+// partitions=1/1 (region)
+// tablets=3/3, tabletList=135142,135144,135146
+// cardinality=5, avgRowSize=2978.0, numNodes=1
+// pushAggOp=NONE
+// final projections: r_name[#3], ((r_regionkey + r_regionkey)[#5] + 1), (r_regionkey + r_regionkey)[#5]
+// final project output tuple id: 2
+// intermediate projections: R_NAME[#1], R_REGIONKEY[#0], (R_REGIONKEY[#0] + R_REGIONKEY[#0])
+// intermediate tuple id: 1
+
+    explain{
+        sql """
+            select sum(r_regionkey), avg(r_regionkey), r_name
+            from region group by r_name;
+            """
+        contains("intermediate projections:")
+    }
+}