Skip to content

Commit

Permalink
opt push agg through join on side
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Nov 6, 2024
1 parent 800f5c6 commit 9bc2486
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -88,15 +89,16 @@ public List<Rule> buildRules() {
})
.toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
// .when(agg -> agg.child().isAllSlots())
// .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child()
.child(0).children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum
|| (f instanceof Count && (!((Count) f).isCountStar()))) && !f.isDistinct()
&& f.child(0) instanceof Slot);
|| f instanceof Count) && !f.isDistinct()
&& (f.children().isEmpty() || f.child(0) instanceof Slot));
})
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
Expand All @@ -111,6 +113,7 @@ public List<Rule> buildRules() {
);
}


/**
* Push down Min/Max/Sum through join.
*/
Expand All @@ -119,21 +122,6 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
List<Slot> leftOutput = join.left().getOutput();
List<Slot> rightOutput = join.right().getOutput();

List<AggregateFunction> leftFuncs = new ArrayList<>();
List<AggregateFunction> rightFuncs = new ArrayList<>();
for (AggregateFunction func : agg.getAggregateFunctions()) {
Slot slot = (Slot) func.child(0);
if (leftOutput.contains(slot)) {
leftFuncs.add(func);
} else if (rightOutput.contains(slot)) {
rightFuncs.add(func);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}
if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) {
return null;
}

Set<Slot> leftGroupBy = new HashSet<>();
Set<Slot> rightGroupBy = new HashSet<>();
Expand All @@ -143,19 +131,74 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
leftGroupBy.add(slot);
} else if (rightOutput.contains(slot)) {
rightGroupBy.add(slot);
} else {
if (projects.isEmpty()) {
return null;
} else {
for (NamedExpression proj : projects) {
if (proj instanceof Alias && proj.toSlot().equals(slot)) {
for (Slot inSlot : proj.getInputSlots()) {
if (leftOutput.contains(inSlot)) {
leftGroupBy.add(inSlot);
} else if (rightOutput.contains(inSlot)) {
rightGroupBy.add(inSlot);
} else {
// dangling slot. should not come here
return null;
}
}
break;
}
}
}
}
}

List<AggregateFunction> leftFuncs = new ArrayList<>();
List<AggregateFunction> rightFuncs = new ArrayList<>();
Count countStar = null;
for (AggregateFunction func : agg.getAggregateFunctions()) {
if (func instanceof Count && ((Count) func).isCountStar()) {
countStar = (Count) func;
} else {
Slot slot = (Slot) func.child(0);
if (leftOutput.contains(slot)) {
leftFuncs.add(func);
} else if (rightOutput.contains(slot)) {
rightFuncs.add(func);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}
}
// determine count(*)
if (countStar != null) {
if (!leftGroupBy.isEmpty()) {
if (!leftFuncs.isEmpty() || rightFuncs.isEmpty()) {
countStar = (Count) countStar.withChildren(leftGroupBy.iterator().next());
leftFuncs.add(countStar);
}
} else if (!rightGroupBy.isEmpty()) {
countStar = (Count) countStar.withChildren(rightGroupBy.iterator().next());
rightFuncs.add(countStar);
} else {
return null;
}
}
join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> {

if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) {
return null;
}

join.getConditionSlot().forEach(slot -> {
if (leftOutput.contains(slot)) {
leftGroupBy.add(slot);
} else if (rightOutput.contains(slot)) {
rightGroupBy.add(slot);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}));
});

Plan left = join.left();
Plan right = join.right();
Expand Down Expand Up @@ -196,6 +239,9 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
for (NamedExpression ne : agg.getOutputExpressions()) {
if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) {
AggregateFunction func = (AggregateFunction) ((Alias) ne).child();
if (func instanceof Count && ((Count) func).isCountStar()) {
func = countStar;
}
Slot slot = (Slot) func.child(0);
if (leftSlotToOutput.containsKey(slot)) {
Expression newFunc = replaceAggFunc(func, leftSlotToOutput.get(slot).toSlot());
Expand All @@ -210,9 +256,27 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
newOutputExprs.add(ne);
}
}
Plan newAggChild = newJoin;
if (agg.child() instanceof LogicalProject) {
LogicalProject project = (LogicalProject) agg.child();
List<NamedExpression> newProjections = Lists.newArrayList();
newProjections.addAll(project.getProjects());
Set<NamedExpression> leftDifference = new HashSet<NamedExpression>(left.getOutput());
leftDifference.removeAll(project.getProjects());
newProjections.addAll(leftDifference);
Set<NamedExpression> rightDifference = new HashSet<NamedExpression>(right.getOutput());
rightDifference.removeAll(project.getProjects());
newProjections.addAll(rightDifference);

newAggChild = ((LogicalProject) agg.child()).withProjectsAndChild(newProjections, newJoin);
}
// TODO: column prune project
return agg.withAggOutputChild(newOutputExprs, newJoin);
LogicalAggregate<Plan> newAgg = agg.withAggOutputChild(newOutputExprs, newAggChild);
if (checkOutput(newAgg)) {
return newAgg;
} else {
return (LogicalAggregate<Plan>) agg;
}
}

private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) {
Expand All @@ -222,4 +286,24 @@ private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot)
return func.withChildren(inputSlot);
}
}

private static boolean checkOutput(LogicalAggregate agg) {
if (agg.child() instanceof LogicalProject) {
Set<Slot> joinOutputs = ((Plan) agg.child().child(0)).getOutputSet();
if (!joinOutputs.containsAll(((LogicalProject<?>) agg.child()).getInputSlots())) {
return false;
}
Set<Slot> projectOutputs = ((LogicalProject<?>) agg.child()).getOutputSet();
if (!projectOutputs.containsAll(agg.getInputSlots())) {
return false;
}
return true;
} else {
Set<Slot> joinOutputs = ((Plan) agg.child()).getOutputSet();
if (!joinOutputs.containsAll(agg.getInputSlots())) {
return false;
}
return true;
}
}
}
Loading

0 comments on commit 9bc2486

Please sign in to comment.