Skip to content

Commit

Permalink
create projection under aggregate node to enable CSE
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Sep 9, 2024
1 parent dc706ed commit 3384836
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.analysis.ProjectAggregateExpressionsForCse;
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
Expand Down Expand Up @@ -191,6 +192,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(new PushDownFilterThroughProject()),
custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION,
AggScalarSubQueryToWindowFunction::new),
topDown(new ProjectAggregateExpressionsForCse()),
bottomUp(
new EliminateUselessPlanUnderApply(),
// CorrelateApplyToUnCorrelateApply and ApplyToJoin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public Plan visit(Plan plan, CascadesContext context) {
List<Slot> childrenOutput = plan.children().stream().flatMap(p -> p.getOutput().stream()).collect(
Collectors.toList());
throw new AnalysisException("A expression contains slot not from children\n"
+ "Plan: " + plan + "\n"
+ "Plan: " + plan.treeString() + "\n"
+ "Children Output:" + childrenOutput + "\n"
+ "Slot: " + opt.get() + "\n");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public enum RuleType {
RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
HAVING_TO_FILTER(RuleTypeClass.REWRITE),
PROJECT_AGGREGATE_EXPRESSIONS_FOR_CSE(RuleTypeClass.REWRITE),
ONE_ROW_RELATION_EXTRACT_AGGREGATE(RuleTypeClass.REWRITE),
PROJECT_WITH_DISTINCT_TO_AGGREGATE(RuleTypeClass.REWRITE),
AVG_DISTINCT_TO_SUM_DIV_COUNT(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Plan pattern:
* agg(output[sum(A+B), sum(A+B+1)])
* =>
* agg(output[sum(#1), sum(#2)])
* +--->project(A+B as #1, A+B+1 as #2)
*
* after this transformation, we have the opportunity to extract
* common sub expression "A+B" by CommonSubExpressionOpt processor
*
* note:
* select sum(A), C+1, abs(C+1) from T group by C
* C+1 is not pushed down to bottom project, because C+1 is not agg output.
* after AggregateNormalize, the plan is:
* project: output C+1, abs(C+1)
* +-->agg: output sum(A), C
* +--->Scan
* C+1 is processed with the project above agg
*
*/
public class ProjectAggregateExpressionsForCse extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return logicalAggregate()
.then(this::addProjectionIfNeed)
.toRule(RuleType.PROJECT_AGGREGATE_EXPRESSIONS_FOR_CSE);
}

private LogicalAggregate<? extends Plan> addProjectionIfNeed(LogicalAggregate<? extends Plan> aggregate) {
// select sum(A+B), ...
// "A+B" is a cse candidate
// cseCandidates: A+B -> alias(A+B)
Map<Expression, Alias> cseCandidates = new HashMap<>();
Set<Slot> inputSlots = new HashSet<>();

for (Expression expr : aggregate.getExpressions()) {
getCseCandidatesFromAggregateFunction(expr, cseCandidates);
inputSlots.addAll(expr.getInputSlots());
}

if (cseCandidates.isEmpty()) {
// no opportunity to generate cse
return null;
}

// select sum(A+B),...
// slotMap: A+B -> alias(A+B) to slot#3
// sum(A+B) is replaced by sum(slot#3)
Map<Expression, Slot> slotMap = new HashMap<>();
for (Expression key : cseCandidates.keySet()) {
slotMap.put(key, cseCandidates.get(key).toSlot());
}
List<NamedExpression> aggOutputReplaced = new ArrayList<>();
for (NamedExpression expr : aggregate.getOutputExpressions()) {
aggOutputReplaced.add((NamedExpression) ExpressionUtils.replace(expr, slotMap));
}

if (aggregate.child() instanceof LogicalProject) {
LogicalProject<? extends Plan> project = (LogicalProject<? extends Plan>) aggregate.child();
List<NamedExpression> newProjections = Lists.newArrayList(project.getProjects());
newProjections.addAll(cseCandidates.values());
project = project.withProjectsAndChild(newProjections, (Plan) project.child());
aggregate = (LogicalAggregate<? extends Plan>) aggregate
.withAggOutput(aggOutputReplaced)
.withChildren(project);
} else {
List<NamedExpression> projections = new ArrayList<>();
projections.addAll(inputSlots);
projections.addAll(cseCandidates.values());
LogicalProject<? extends Plan> project = new LogicalProject<>(projections, aggregate.child(0));
aggregate = (LogicalAggregate<? extends Plan>) aggregate
.withAggOutput(aggOutputReplaced).withChildren(project);
}
return aggregate;
}

private void getCseCandidatesFromAggregateFunction(Expression expr, Map<Expression, Alias> result) {
if (expr instanceof AggregateFunction) {
for (Expression child : expr.children()) {
if (!(child instanceof SlotReference) && !child.isConstant()) {
if (child instanceof Alias) {
result.put(child, (Alias) child);
} else {
result.put(child, new Alias(child));
}
}
}
} else {
for (Expression child : expr.children()) {
if (!(child instanceof SlotReference) && !child.isConstant()) {
getCseCandidatesFromAggregateFunction(child, result);
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
public List<? extends Expression> getExpressions() {
return new ImmutableList.Builder<Expression>()
.addAll(outputExpressions)
.addAll(groupByExpressions)
.build();
}

Expand Down
56 changes: 56 additions & 0 deletions regression-test/suites/nereids_tpch_p0/tpch/agg_cse.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

suite("agg_cse") {
String db = context.config.getDbNameByFile(new File(context.file.parent))
sql "use ${db}"
sql 'set enable_nereids_planner=true'
sql 'set enable_fallback_to_original_planner=false'

qt_select """
select sum(r_regionkey + r_regionkey), avg(r_regionkey + r_regionkey), sum(r_regionkey + (r_regionkey+1))
from region group by r_name;
"""
explain{
sql """
select sum(r_regionkey + r_regionkey), avg(r_regionkey + r_regionkey), sum(r_regionkey + (r_regionkey+1))
from region group by r_name;
"""
contains("intermediate projections:")
}
// expect plan: intermediate projections in OlapScanNode
// 0:VOlapScanNode(168)
// TABLE: tpch.region(region), PREAGGREGATION: ON
// partitions=1/1 (region)
// tablets=3/3, tabletList=135142,135144,135146
// cardinality=5, avgRowSize=2978.0, numNodes=1
// pushAggOp=NONE
// final projections: r_name[#3], ((r_regionkey + r_regionkey)[#5] + 1), (r_regionkey + r_regionkey)[#5]
// final project output tuple id: 2
// intermediate projections: R_NAME[#1], R_REGIONKEY[#0], (R_REGIONKEY[#0] + R_REGIONKEY[#0])
// intermediate tuple id: 1

explain{
sql """
select sum(r_regionkey), avg(r_regionkey), r_name
from region group by r_name;
"""
contains("intermediate projections:")
}
}

0 comments on commit 3384836

Please sign in to comment.