Skip to content

Commit

Permalink
enhance PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Nov 13, 2024
1 parent 936ad96 commit ccc5bc2
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -88,15 +89,16 @@ public List<Rule> buildRules() {
})
.toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
// .when(agg -> agg.child().isAllSlots())
// .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child()
.child(0).children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum
|| (f instanceof Count && (!((Count) f).isCountStar()))) && !f.isDistinct()
&& f.child(0) instanceof Slot);
|| f instanceof Count) && !f.isDistinct()
&& (f.children().isEmpty() || f.child(0) instanceof Slot));
})
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
Expand All @@ -111,6 +113,7 @@ public List<Rule> buildRules() {
);
}


/**
* Push down Min/Max/Sum through join.
*/
Expand All @@ -119,21 +122,6 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
List<Slot> leftOutput = join.left().getOutput();
List<Slot> rightOutput = join.right().getOutput();

List<AggregateFunction> leftFuncs = new ArrayList<>();
List<AggregateFunction> rightFuncs = new ArrayList<>();
for (AggregateFunction func : agg.getAggregateFunctions()) {
Slot slot = (Slot) func.child(0);
if (leftOutput.contains(slot)) {
leftFuncs.add(func);
} else if (rightOutput.contains(slot)) {
rightFuncs.add(func);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}
if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) {
return null;
}

Set<Slot> leftGroupBy = new HashSet<>();
Set<Slot> rightGroupBy = new HashSet<>();
Expand All @@ -144,18 +132,76 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
} else if (rightOutput.contains(slot)) {
rightGroupBy.add(slot);
} else {
return null;
if (projects.isEmpty()) {
// TODO: select ... from ... group by A , B, 1.2; 1.2 is constant
return null;
} else {
for (NamedExpression proj : projects) {
if (proj instanceof Alias && proj.toSlot().equals(slot)) {
Set<Slot> inputForAlias = proj.getInputSlots();
if (leftOutput.containsAll(inputForAlias)) {
leftGroupBy.addAll(inputForAlias);
} else if (rightOutput.containsAll(inputForAlias)) {
rightGroupBy.addAll(inputForAlias);
} else {
/*
groupBy(X)
+---> project( a + b as X)
--> join(output: T1.a, T2.b)
--> T1(a)
--> T2(b)
X can not be pushed
*/
return null;
}
break;
}
}
}
}
}
join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> {
if (leftOutput.contains(slot)) {
leftGroupBy.add(slot);
} else if (rightOutput.contains(slot)) {
rightGroupBy.add(slot);

List<AggregateFunction> leftFuncs = new ArrayList<>();
List<AggregateFunction> rightFuncs = new ArrayList<>();
Count countStar = null;
for (AggregateFunction func : agg.getAggregateFunctions()) {
if (func instanceof Count && ((Count) func).isCountStar()) {
countStar = (Count) func;
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
Slot slot = (Slot) func.child(0);
if (leftOutput.contains(slot)) {
leftFuncs.add(func);
} else if (rightOutput.contains(slot)) {
rightFuncs.add(func);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}
}
// determine count(*)
if (countStar != null) {
if (!leftGroupBy.isEmpty()) {
countStar = (Count) countStar.withChildren(leftGroupBy.iterator().next());
leftFuncs.add(countStar);
} else if (!rightGroupBy.isEmpty()) {
countStar = (Count) countStar.withChildren(rightGroupBy.iterator().next());
rightFuncs.add(countStar);
} else {
return null;
}
}
for (Expression condition : join.getHashJoinConjuncts()) {
for (Slot joinConditionSlot : condition.getInputSlots()) {
if (leftOutput.contains(joinConditionSlot)) {
leftGroupBy.add(joinConditionSlot);
} else if (rightOutput.contains(joinConditionSlot)) {
rightGroupBy.add(joinConditionSlot);
} else {
// apply failed
return null;
}
}
}));
}

Plan left = join.left();
Plan right = join.right();
Expand Down Expand Up @@ -196,6 +242,10 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
for (NamedExpression ne : agg.getOutputExpressions()) {
if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) {
AggregateFunction func = (AggregateFunction) ((Alias) ne).child();
if (func instanceof Count && ((Count) func).isCountStar()) {
// countStar is already rewritten as count(left_slot) or count(right_slot)
func = countStar;
}
Slot slot = (Slot) func.child(0);
if (leftSlotToOutput.containsKey(slot)) {
Expression newFunc = replaceAggFunc(func, leftSlotToOutput.get(slot).toSlot());
Expand All @@ -210,9 +260,27 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
newOutputExprs.add(ne);
}
}
Plan newAggChild = newJoin;
if (agg.child() instanceof LogicalProject) {
LogicalProject project = (LogicalProject) agg.child();
List<NamedExpression> newProjections = Lists.newArrayList();
newProjections.addAll(project.getProjects());
Set<NamedExpression> leftDifference = new HashSet<NamedExpression>(left.getOutput());
leftDifference.removeAll(project.getProjects());
newProjections.addAll(leftDifference);
Set<NamedExpression> rightDifference = new HashSet<NamedExpression>(right.getOutput());
rightDifference.removeAll(project.getProjects());
newProjections.addAll(rightDifference);

newAggChild = ((LogicalProject) agg.child()).withProjectsAndChild(newProjections, newJoin);
}
// TODO: column prune project
return agg.withAggOutputChild(newOutputExprs, newJoin);
LogicalAggregate<Plan> newAgg = agg.withAggOutputChild(newOutputExprs, newAggChild);
if (checkOutput(newAgg)) {
return newAgg;
} else {
return (LogicalAggregate<Plan>) agg;
}
}

private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) {
Expand All @@ -222,4 +290,24 @@ private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot)
return func.withChildren(inputSlot);
}
}

private static boolean checkOutput(LogicalAggregate agg) {
if (agg.child() instanceof LogicalProject) {
Set<Slot> joinOutputs = ((Plan) agg.child().child(0)).getOutputSet();
if (!joinOutputs.containsAll(((LogicalProject<?>) agg.child()).getInputSlots())) {
return false;
}
Set<Slot> projectOutputs = ((LogicalProject<?>) agg.child()).getOutputSet();
if (!projectOutputs.containsAll(agg.getInputSlots())) {
return false;
}
return true;
} else {
Set<Slot> joinOutputs = ((Plan) agg.child()).getOutputSet();
if (!joinOutputs.containsAll(agg.getInputSlots())) {
return false;
}
return true;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1034,3 +1034,23 @@ Used:
UnUsed: use_push_down_agg_through_join_one_side
SyntaxError:

-- !shape --
PhysicalResultSink
--PhysicalTopN[MERGE_SORT]
----PhysicalTopN[LOCAL_SORT]
------hashAgg[GLOBAL]
--------hashAgg[LOCAL]
----------hashJoin[INNER_JOIN] hashCondition=((dwd_tracking_sensor_init_tmp_ymd.dt = dw_user_b2c_tracking_info_tmp_ymd.dt) and (dwd_tracking_sensor_init_tmp_ymd.guid = dw_user_b2c_tracking_info_tmp_ymd.guid)) otherCondition=((dwd_tracking_sensor_init_tmp_ymd.dt >= substring(first_visit_time, 1, 10)))
------------filter((dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19') and (dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'))
--------------PhysicalOlapScan[dwd_tracking_sensor_init_tmp_ymd]
------------filter((dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'))
--------------PhysicalOlapScan[dw_user_b2c_tracking_info_tmp_ymd]

Hint log:
Used:
UnUsed: use_PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE
SyntaxError:

-- !agg_pushed --
2 是 2024-08-19

Original file line number Diff line number Diff line change
Expand Up @@ -426,4 +426,99 @@ suite("push_down_count_through_join_one_side") {
qt_with_hint_groupby_pushdown_nested_queries """
explain shape plan select /*+ USE_CBO_RULE(push_down_agg_through_join_one_side) */ count(*) from (select * from count_t_one_side where score > 20) t1 join (select * from count_t_one_side where id < 100) t2 on t1.id = t2.id group by t1.name;
"""

sql """
drop table if exists dw_user_b2c_tracking_info_tmp_ymd;
create table dw_user_b2c_tracking_info_tmp_ymd (
guid int,
dt varchar,
first_visit_time varchar
)Engine=Olap
DUPLICATE KEY(guid)
distributed by hash(dt) buckets 3
properties('replication_num' = '1');
insert into dw_user_b2c_tracking_info_tmp_ymd values (1, '2024-08-19', '2024-08-19');
drop table if exists dwd_tracking_sensor_init_tmp_ymd;
create table dwd_tracking_sensor_init_tmp_ymd (
guid int,
dt varchar,
tracking_type varchar
)Engine=Olap
DUPLICATE KEY(guid)
distributed by hash(dt) buckets 3
properties('replication_num' = '1');
insert into dwd_tracking_sensor_init_tmp_ymd values(1, '2024-08-19', 'click'), (1, '2024-08-19', 'click');
"""
sql """
set ENABLE_NEREIDS_RULES = "PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE";
set disable_join_reorder=true;
"""

qt_shape """
explain shape plan
SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/
Count(*) AS accee593,
CASE
WHEN dwd_tracking_sensor_init_tmp_ymd.dt =
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
10) THEN
'是'
WHEN dwd_tracking_sensor_init_tmp_ymd.dt >
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
10) THEN
'否'
ELSE '-1'
end AS a1302fb2,
dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123
FROM dwd_tracking_sensor_init_tmp_ymd
LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd
ON dwd_tracking_sensor_init_tmp_ymd.guid =
dw_user_b2c_tracking_info_tmp_ymd.guid
AND dwd_tracking_sensor_init_tmp_ymd.dt =
dw_user_b2c_tracking_info_tmp_ymd.dt
WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19'
AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'
AND dwd_tracking_sensor_init_tmp_ymd.dt >=
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10)
AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'
GROUP BY 2,
3
ORDER BY 3 ASC
LIMIT 10000;
"""

qt_agg_pushed """
SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/
Count(*) AS accee593,
CASE
WHEN dwd_tracking_sensor_init_tmp_ymd.dt =
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
10) THEN
'是'
WHEN dwd_tracking_sensor_init_tmp_ymd.dt >
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
10) THEN
'否'
ELSE '-1'
end AS a1302fb2,
dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123
FROM dwd_tracking_sensor_init_tmp_ymd
LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd
ON dwd_tracking_sensor_init_tmp_ymd.guid =
dw_user_b2c_tracking_info_tmp_ymd.guid
AND dwd_tracking_sensor_init_tmp_ymd.dt =
dw_user_b2c_tracking_info_tmp_ymd.dt
WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19'
AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'
AND dwd_tracking_sensor_init_tmp_ymd.dt >=
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10)
AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'
GROUP BY 2,
3
ORDER BY 3 ASC
LIMIT 10000;
"""
}

0 comments on commit ccc5bc2

Please sign in to comment.