forked from apache/doris
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
262 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
124 changes: 124 additions & 0 deletions
124
fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/AggCse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
package org.apache.doris.nereids.processor.post; | ||
|
||
import org.apache.doris.nereids.CascadesContext; | ||
import org.apache.doris.nereids.trees.expressions.Alias; | ||
import org.apache.doris.nereids.trees.expressions.Expression; | ||
import org.apache.doris.nereids.trees.expressions.NamedExpression; | ||
import org.apache.doris.nereids.trees.expressions.Slot; | ||
import org.apache.doris.nereids.trees.expressions.SlotReference; | ||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; | ||
import org.apache.doris.nereids.trees.plans.Plan; | ||
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; | ||
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; | ||
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; | ||
import org.apache.doris.nereids.util.ExpressionUtils; | ||
|
||
import com.clearspring.analytics.util.Lists; | ||
|
||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.HashSet; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Set; | ||
|
||
/** | ||
* agg(output[sum(A+B), sum(A+B+1)]) | ||
* => | ||
* agg(output[sum(#1), sum(#2)]) | ||
* +--->project(A+B as #1, A+B+1 as #2) | ||
* after this transformation, we have the opportunity to extract | ||
* common sub expression "A+B" by CommonSubExpressionOpt processor | ||
* | ||
*/ | ||
public class AggCse extends PlanPostProcessor { | ||
@Override | ||
public PhysicalHashAggregate<? extends Plan> visitPhysicalHashAggregate( | ||
PhysicalHashAggregate<? extends Plan> aggregate, CascadesContext ctx) { | ||
aggregate = (PhysicalHashAggregate<? extends Plan>) super.visit(aggregate, ctx); | ||
|
||
// do not increase network cost | ||
if (aggregate.child(0) instanceof PhysicalDistribute) { | ||
return aggregate; | ||
} | ||
|
||
// select sum(A+B), ... | ||
// "A+B" is a cse candidate | ||
// cseCandidates: A+B -> alias(A+B) | ||
Map<Expression, Alias> cseCandidates = new HashMap<>(); | ||
Set<Slot> inputSlots = new HashSet<>(); | ||
|
||
// even with cse, the slots can not be replaced by cse | ||
// example: select sum(f(a)), avg(f(a)), b from T group by b | ||
// agg-->project(f(a), b)-->scan(T), | ||
// in which 'b' must be reserved in project, but 'a' is not. | ||
Set<Slot> reservedInputSlots = new HashSet<>(); | ||
|
||
for (Expression expr : aggregate.getExpressions()) { | ||
boolean hasAggFunc = getCseCandidatesFromAggregateFunction(expr, cseCandidates); | ||
if (!(expr instanceof SlotReference) && !(expr.isConstant()) && !hasAggFunc) { | ||
// select sum(A+B), C+1, abs(C+1) from T group by C+1 | ||
// C+1 is cse candidate | ||
cseCandidates.put(expr, new Alias(expr)); | ||
} | ||
|
||
if (expr instanceof SlotReference) { | ||
reservedInputSlots.add((SlotReference) expr); | ||
} | ||
inputSlots.addAll(expr.getInputSlots()); | ||
} | ||
if (cseCandidates.isEmpty()) { | ||
return aggregate; | ||
} | ||
|
||
// select sum(A+B),... | ||
// slotMap: A+B -> alias(A+B) to slot#3 | ||
// sum(A+B) is replaced by sum(slot#3) | ||
Map<Expression, Slot> slotMap = new HashMap<>(); | ||
for (Expression key : cseCandidates.keySet()) { | ||
slotMap.put(key, cseCandidates.get(key).toSlot()); | ||
} | ||
List<NamedExpression> aggOutputReplaced = new ArrayList<>(); | ||
for (NamedExpression expr: aggregate.getOutputExpressions()) { | ||
aggOutputReplaced.add((NamedExpression) ExpressionUtils.replace(expr, slotMap)); | ||
} | ||
|
||
if (aggregate.child() instanceof PhysicalProject) { | ||
PhysicalProject project = (PhysicalProject) aggregate.child(); | ||
List<NamedExpression> newProjections = Lists.newArrayList(project.getProjects()); | ||
newProjections.addAll(cseCandidates.values()); | ||
project = project.withProjectionsAndChild(newProjections, (Plan) project.child()); | ||
aggregate = (PhysicalHashAggregate<? extends Plan>) aggregate | ||
.withAggOutput(aggOutputReplaced) | ||
.withChildren(project); | ||
} else { | ||
List<NamedExpression> projections = new ArrayList<>(); | ||
projections.addAll(inputSlots); | ||
projections.addAll(cseCandidates.values()); | ||
PhysicalProject project = new PhysicalProject(projections, aggregate.child(0).getLogicalProperties(), | ||
(Plan) aggregate.child(0)); | ||
aggregate = (PhysicalHashAggregate<? extends Plan>) aggregate | ||
.withAggOutput(aggOutputReplaced).withChildren(project); | ||
} | ||
return aggregate; | ||
} | ||
|
||
private boolean getCseCandidatesFromAggregateFunction(Expression expr, Map<Expression, Alias> result) { | ||
if (expr instanceof AggregateFunction) { | ||
for (Expression child : expr.children()) { | ||
if (!(child instanceof SlotReference) && !child.isConstant()) { | ||
result.put(child, new Alias(child)); | ||
} | ||
} | ||
return true; | ||
} else { | ||
boolean hasAggFunc = false; | ||
for (Expression child : expr.children()) { | ||
if (!(child instanceof SlotReference) && !child.isConstant()) { | ||
hasAggFunc |= getCseCandidatesFromAggregateFunction(child, result); | ||
} | ||
} | ||
return hasAggFunc; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
131 changes: 131 additions & 0 deletions
131
.../main/java/org/apache/doris/nereids/rules/analysis/ProjectAggregateExpressionsForCse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
package org.apache.doris.nereids.rules.analysis; | ||
|
||
|
||
import org.apache.doris.nereids.rules.Rule; | ||
import org.apache.doris.nereids.rules.RuleType; | ||
import org.apache.doris.nereids.trees.expressions.Alias; | ||
import org.apache.doris.nereids.trees.expressions.Expression; | ||
import org.apache.doris.nereids.trees.expressions.NamedExpression; | ||
import org.apache.doris.nereids.trees.expressions.Slot; | ||
import org.apache.doris.nereids.trees.expressions.SlotReference; | ||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; | ||
import org.apache.doris.nereids.trees.plans.Plan; | ||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; | ||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject; | ||
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; | ||
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; | ||
import org.apache.doris.nereids.util.ExpressionUtils; | ||
|
||
import com.clearspring.analytics.util.Lists; | ||
|
||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.HashSet; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Set; | ||
|
||
/** | ||
* Plan pattern: | ||
* agg(output[sum(A+B), sum(A+B+1)]) | ||
* => | ||
* agg(output[sum(#1), sum(#2)]) | ||
* +--->project(A+B as #1, A+B+1 as #2) | ||
* | ||
* after this transformation, we have the opportunity to extract | ||
* common sub expression "A+B" by CommonSubExpressionOpt processor | ||
* | ||
* note: | ||
* select sum(A), C+1, abs(C+1) from T group by C | ||
* C+1 is not pushed down to bottom project, because C+1 is not agg output. | ||
* after AggregateNormalize, the plan is: | ||
* project: output C+1, abs(C+1) | ||
* +-->agg: output sum(A), C | ||
* +--->Scan | ||
* C+1 is processed with the project above agg | ||
* | ||
*/ | ||
public class ProjectAggregateExpressionsForCse extends OneAnalysisRuleFactory { | ||
@Override | ||
public Rule build() { | ||
return logicalAggregate() | ||
.then(this::addProjectionIfNeed) | ||
.toRule(RuleType.PROJECT_AGGREGATE_EXPRESSIONS_FOR_CSE); | ||
} | ||
|
||
private LogicalAggregate<? extends Plan> addProjectionIfNeed(LogicalAggregate<? extends Plan> aggregate) { | ||
// select sum(A+B), ... | ||
// "A+B" is a cse candidate | ||
// cseCandidates: A+B -> alias(A+B) | ||
Map<Expression, Alias> cseCandidates = new HashMap<>(); | ||
Set<Slot> inputSlots = new HashSet<>(); | ||
|
||
// even with cse, the slots can not be replaced by cse | ||
// example: select sum(f(a)), avg(f(a)), b from T group by b | ||
// agg-->project(f(a), b)-->scan(T), | ||
// in which 'b' must be reserved in project, but 'a' is not. | ||
Set<Slot> reservedInputSlots = new HashSet<>(); | ||
|
||
for (Expression expr : aggregate.getExpressions()) { | ||
getCseCandidatesFromAggregateFunction(expr, cseCandidates); | ||
|
||
if (expr instanceof SlotReference) { | ||
reservedInputSlots.add((SlotReference) expr); | ||
} | ||
inputSlots.addAll(expr.getInputSlots()); | ||
} | ||
if (cseCandidates.isEmpty()) { | ||
return null; | ||
} | ||
|
||
// select sum(A+B),... | ||
// slotMap: A+B -> alias(A+B) to slot#3 | ||
// sum(A+B) is replaced by sum(slot#3) | ||
Map<Expression, Slot> slotMap = new HashMap<>(); | ||
for (Expression key : cseCandidates.keySet()) { | ||
slotMap.put(key, cseCandidates.get(key).toSlot()); | ||
} | ||
List<NamedExpression> aggOutputReplaced = new ArrayList<>(); | ||
for (NamedExpression expr : aggregate.getOutputExpressions()) { | ||
aggOutputReplaced.add((NamedExpression) ExpressionUtils.replace(expr, slotMap)); | ||
} | ||
|
||
if (aggregate.child() instanceof LogicalProject) { | ||
LogicalProject<? extends Plan> project = (LogicalProject<? extends Plan>) aggregate.child(); | ||
List<NamedExpression> newProjections = Lists.newArrayList(project.getProjects()); | ||
newProjections.addAll(cseCandidates.values()); | ||
project = project.withProjectsAndChild(newProjections, (Plan) project.child()); | ||
aggregate = (LogicalAggregate<? extends Plan>) aggregate | ||
.withAggOutput(aggOutputReplaced) | ||
.withChildren(project); | ||
} else { | ||
List<NamedExpression> projections = new ArrayList<>(); | ||
projections.addAll(inputSlots); | ||
projections.addAll(cseCandidates.values()); | ||
LogicalProject<? extends Plan> project = new LogicalProject<>(projections, aggregate.child(0)); | ||
aggregate = (LogicalAggregate<? extends Plan>) aggregate | ||
.withAggOutput(aggOutputReplaced).withChildren(project); | ||
} | ||
return aggregate; | ||
} | ||
|
||
private void getCseCandidatesFromAggregateFunction(Expression expr, Map<Expression, Alias> result) { | ||
if (expr instanceof AggregateFunction) { | ||
for (Expression child : expr.children()) { | ||
if (!(child instanceof SlotReference) && !child.isConstant()) { | ||
if (child instanceof Alias) { | ||
result.put(child, (Alias)child); | ||
} else { | ||
result.put(child, new Alias(child)); | ||
} | ||
} | ||
} | ||
} else { | ||
for (Expression child : expr.children()) { | ||
if (!(child instanceof SlotReference) && !child.isConstant()) { | ||
getCseCandidatesFromAggregateFunction(child, result); | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters