Skip to content

Commit

Permalink
agg-cse
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Sep 6, 2024
1 parent b2ddc8d commit bd5e1b3
Show file tree
Hide file tree
Showing 8 changed files with 262 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
import org.apache.doris.nereids.rules.analysis.OneRowRelationExtractAggregate;
import org.apache.doris.nereids.rules.analysis.ProjectAggregateExpressionsForCse;
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate;
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
Expand Down Expand Up @@ -165,6 +166,7 @@ private static List<RewriteJob> buildAnalyzerJobs(Optional<CustomTableResolver>
topDown(new SimplifyAggGroupBy()),
// run BuildAggForRandomDistributedTable before NormalizeAggregate in order to optimize the agg plan
topDown(new BuildAggForRandomDistributedTable()),
topDown(new ProjectAggregateExpressionsForCse()),
topDown(new NormalizeAggregate()),
topDown(new HavingToFilter()),
bottomUp(new SemiJoinCommute()),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package org.apache.doris.nereids.processor.post;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.clearspring.analytics.util.Lists;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Post processor that pushes non-trivial aggregate inputs into a project below the aggregate.
 *
 * <pre>
 * agg(output[sum(A+B), sum(A+B+1)])
 * =>
 * agg(output[sum(#1), sum(#2)])
 *   +--->project(A+B as #1, A+B+1 as #2)
 * </pre>
 *
 * After this transformation the CommonSubExpressionOpt processor has the opportunity to
 * extract the common sub expression "A+B" from the project.
 */
public class AggCse extends PlanPostProcessor {
    /**
     * Rewrites one physical hash aggregate after its children have been visited.
     *
     * @param aggregate the aggregate to rewrite
     * @param ctx the cascades context handed to the visitor
     * @return the original aggregate when nothing can be pushed down or when its child is a
     *         PhysicalDistribute; otherwise a copy whose outputs read the new project's slots
     */
    @Override
    public PhysicalHashAggregate<? extends Plan> visitPhysicalHashAggregate(
            PhysicalHashAggregate<? extends Plan> aggregate, CascadesContext ctx) {
        aggregate = (PhysicalHashAggregate<? extends Plan>) super.visit(aggregate, ctx);

        // do not increase network cost: extra projected columns below the aggregate
        // would sit directly on top of a distribute node
        if (aggregate.child(0) instanceof PhysicalDistribute) {
            return aggregate;
        }

        // select sum(A+B), ...
        // "A+B" is a cse candidate
        // cseCandidates: A+B -> alias(A+B)
        Map<Expression, Alias> cseCandidates = new HashMap<>();
        // every slot referenced by the aggregate; passed through when a new project is created
        Set<Slot> inputSlots = new HashSet<>();

        for (Expression expr : aggregate.getExpressions()) {
            boolean hasAggFunc = getCseCandidatesFromAggregateFunction(expr, cseCandidates);
            if (!(expr instanceof SlotReference) && !expr.isConstant() && !hasAggFunc) {
                // select sum(A+B), C+1, abs(C+1) from T group by C+1
                // C+1 is cse candidate
                // NOTE(review): only the output expressions are rewritten below; a candidate
                // that is also a group-by key leaves the group-by list untouched - confirm
                // that the resulting aggregate is still consistent.
                cseCandidates.put(expr, new Alias(expr));
            }
            inputSlots.addAll(expr.getInputSlots());
        }
        if (cseCandidates.isEmpty()) {
            return aggregate;
        }

        // select sum(A+B),...
        // slotMap: A+B -> alias(A+B) to slot#3
        // sum(A+B) is replaced by sum(slot#3)
        Map<Expression, Slot> slotMap = new HashMap<>();
        for (Map.Entry<Expression, Alias> candidate : cseCandidates.entrySet()) {
            slotMap.put(candidate.getKey(), candidate.getValue().toSlot());
        }
        List<NamedExpression> aggOutputReplaced = new ArrayList<>();
        for (NamedExpression output : aggregate.getOutputExpressions()) {
            aggOutputReplaced.add((NamedExpression) ExpressionUtils.replace(output, slotMap));
        }

        if (aggregate.child() instanceof PhysicalProject) {
            // NOTE(review): the candidates are written against this project's OUTPUT slots.
            // Appending them to the same project is only valid when those slots are also
            // produced by the project's child (plain pass-through); a candidate that reads
            // an alias defined in this project will reference a slot its input lacks.
            PhysicalProject<? extends Plan> project = (PhysicalProject<? extends Plan>) aggregate.child();
            List<NamedExpression> newProjections = new ArrayList<>(project.getProjects());
            newProjections.addAll(cseCandidates.values());
            project = project.withProjectionsAndChild(newProjections, project.child());
            aggregate = (PhysicalHashAggregate<? extends Plan>) aggregate
                    .withAggOutput(aggOutputReplaced)
                    .withChildren(project);
        } else {
            List<NamedExpression> projections = new ArrayList<>();
            projections.addAll(inputSlots);
            projections.addAll(cseCandidates.values());
            // NOTE(review): the new project reuses the child's logical properties although
            // its projection list differs from the child's output - confirm they are
            // recomputed before anything reads the project's output.
            PhysicalProject<? extends Plan> project = new PhysicalProject<>(projections,
                    aggregate.child(0).getLogicalProperties(), aggregate.child(0));
            aggregate = (PhysicalHashAggregate<? extends Plan>) aggregate
                    .withAggOutput(aggOutputReplaced)
                    .withChildren(project);
        }
        return aggregate;
    }

    /**
     * Collects the direct children of every aggregate function found in {@code expr} that are
     * neither a slot reference nor a constant.
     *
     * @param expr the expression tree to search
     * @param result receives one {@code child -> new Alias(child)} entry per candidate
     * @return true when an aggregate function was found in {@code expr}; sub-trees rooted at
     *         a slot reference or a constant are not searched
     */
    private boolean getCseCandidatesFromAggregateFunction(Expression expr, Map<Expression, Alias> result) {
        if (expr instanceof AggregateFunction) {
            for (Expression child : expr.children()) {
                if (!(child instanceof SlotReference) && !child.isConstant()) {
                    result.put(child, new Alias(child));
                }
            }
            return true;
        }
        boolean hasAggFunc = false;
        for (Expression child : expr.children()) {
            if (!(child instanceof SlotReference) && !child.isConstant()) {
                hasAggFunc |= getCseCandidatesFromAggregateFunction(child, result);
            }
        }
        return hasAggFunc;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;

import com.google.common.collect.Lists;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public List<PlanPostProcessor> getProcessors() {
builder.add(new MergeProjectPostProcessor());
builder.add(new RecomputeLogicalPropertiesProcessor());
builder.add(new AddOffsetIntoDistribute());
// builder.add(new AggCse());
builder.add(new CommonSubExpressionOpt());
// DO NOT replace PLAN NODE from here
if (cascadesContext.getConnectContext().getSessionVariable().pushTopnToAgg) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public Plan visit(Plan plan, CascadesContext context) {
List<Slot> childrenOutput = plan.children().stream().flatMap(p -> p.getOutput().stream()).collect(
Collectors.toList());
throw new AnalysisException("A expression contains slot not from children\n"
+ "Plan: " + plan + "\n"
+ "Plan: " + plan.treeString() + "\n"
+ "Children Output:" + childrenOutput + "\n"
+ "Slot: " + opt.get() + "\n");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public enum RuleType {
RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
HAVING_TO_FILTER(RuleTypeClass.REWRITE),
PROJECT_AGGREGATE_EXPRESSIONS_FOR_CSE(RuleTypeClass.REWRITE),
ONE_ROW_RELATION_EXTRACT_AGGREGATE(RuleTypeClass.REWRITE),
PROJECT_WITH_DISTINCT_TO_AGGREGATE(RuleTypeClass.REWRITE),
AVG_DISTINCT_TO_SUM_DIV_COUNT(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package org.apache.doris.nereids.rules.analysis;


import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.clearspring.analytics.util.Lists;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Analysis rule that moves the arguments of aggregate functions into a project placed
 * directly below the aggregate.
 *
 * <pre>
 * Plan pattern:
 * agg(output[sum(A+B), sum(A+B+1)])
 * =>
 * agg(output[sum(#1), sum(#2)])
 *   +--->project(A+B as #1, A+B+1 as #2)
 * </pre>
 *
 * After this transformation the CommonSubExpressionOpt processor has the opportunity to
 * extract the common sub expression "A+B" from the project.
 *
 * <pre>
 * note:
 * select sum(A), C+1, abs(C+1) from T group by C
 * C+1 is not pushed down to bottom project, because C+1 is not agg output.
 * after AggregateNormalize, the plan is:
 * project: output C+1, abs(C+1)
 *   +-->agg: output sum(A), C
 *     +--->Scan
 * C+1 is processed with the project above agg
 * </pre>
 */
public class ProjectAggregateExpressionsForCse extends OneAnalysisRuleFactory {
    @Override
    public Rule build() {
        return logicalAggregate()
                .then(this::addProjectionIfNeed)
                .toRule(RuleType.PROJECT_AGGREGATE_EXPRESSIONS_FOR_CSE);
    }

    /**
     * Builds the rewritten aggregate for one matched plan node.
     *
     * @param aggregate the matched logical aggregate
     * @return null when no aggregate function has a non-slot, non-constant argument;
     *         otherwise a copy of the aggregate whose outputs read slots of a new project
     *         inserted between the aggregate and its former child
     */
    private LogicalAggregate<? extends Plan> addProjectionIfNeed(LogicalAggregate<? extends Plan> aggregate) {
        // select sum(A+B), ...
        // "A+B" is a cse candidate
        // cseCandidates: A+B -> alias(A+B)
        Map<Expression, Alias> cseCandidates = new HashMap<>();
        // every slot the aggregate reads from its child; all of them are passed through the
        // new project so that expressions which are not rewritten (e.g. group-by keys)
        // still find their inputs
        Set<Slot> inputSlots = new HashSet<>();

        for (Expression expr : aggregate.getExpressions()) {
            getCseCandidatesFromAggregateFunction(expr, cseCandidates);
            inputSlots.addAll(expr.getInputSlots());
        }
        if (cseCandidates.isEmpty()) {
            return null;
        }

        // select sum(A+B),...
        // slotMap: A+B -> alias(A+B) to slot#3
        // sum(A+B) is replaced by sum(slot#3)
        Map<Expression, Slot> slotMap = new HashMap<>();
        for (Map.Entry<Expression, Alias> candidate : cseCandidates.entrySet()) {
            slotMap.put(candidate.getKey(), candidate.getValue().toSlot());
        }
        List<NamedExpression> aggOutputReplaced = new ArrayList<>();
        for (NamedExpression output : aggregate.getOutputExpressions()) {
            aggOutputReplaced.add((NamedExpression) ExpressionUtils.replace(output, slotMap));
        }

        // Always stack a new project on top of the current child, even when that child is
        // already a project. The candidates are written against the child's OUTPUT slots;
        // appending them to an existing child project would make them siblings of the
        // aliases they read, i.e. they would reference slots that the project's own input
        // does not produce.
        List<NamedExpression> projections = new ArrayList<>();
        projections.addAll(inputSlots);
        projections.addAll(cseCandidates.values());
        LogicalProject<? extends Plan> project = new LogicalProject<>(projections, aggregate.child(0));
        return (LogicalAggregate<? extends Plan>) aggregate
                .withAggOutput(aggOutputReplaced)
                .withChildren(project);
    }

    /**
     * Collects the direct children of every aggregate function found in {@code expr} that are
     * neither a slot reference nor a constant.
     *
     * @param expr the expression tree to search; sub-trees rooted at a slot reference or a
     *        constant are not searched
     * @param result receives one entry per candidate, mapping the child to the alias that
     *        will compute it in the project (the child itself when it already is an alias)
     */
    private void getCseCandidatesFromAggregateFunction(Expression expr, Map<Expression, Alias> result) {
        boolean isAggFunc = expr instanceof AggregateFunction;
        for (Expression child : expr.children()) {
            if (child instanceof SlotReference || child.isConstant()) {
                continue;
            }
            if (!isAggFunc) {
                getCseCandidatesFromAggregateFunction(child, result);
            } else if (child instanceof Alias) {
                result.put(child, (Alias) child);
            } else {
                result.put(child, new Alias(child));
            }
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
public List<? extends Expression> getExpressions() {
return new ImmutableList.Builder<Expression>()
.addAll(outputExpressions)
.addAll(groupByExpressions)
.build();
}

Expand Down

0 comments on commit bd5e1b3

Please sign in to comment.