From ccc5bc2366b7ad680d69794f06ebe50d95241729 Mon Sep 17 00:00:00 2001 From: minghong Date: Wed, 13 Nov 2024 14:58:31 +0800 Subject: [PATCH] enhance PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE --- .../PushDownAggThroughJoinOneSide.java | 146 ++++++++++++++---- .../push_down_count_through_join_one_side.out | 20 +++ ...sh_down_count_through_join_one_side.groovy | 95 ++++++++++++ 3 files changed, 232 insertions(+), 29 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java index f32bf8ea91b3552..ac42182890cbc29 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java @@ -36,6 +36,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.HashMap; @@ -88,15 +89,16 @@ public List buildRules() { }) .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE), logicalAggregate(logicalProject(innerLogicalJoin())) - .when(agg -> agg.child().isAllSlots()) - .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) - .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) + // .when(agg -> agg.child().isAllSlots()) + // .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) + .whenNot(agg -> agg.child() + .child(0).children().stream().anyMatch(p -> p instanceof LogicalAggregate)) .when(agg -> { Set funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum - || (f instanceof Count && (!((Count) f).isCountStar()))) && !f.isDistinct() - && f.child(0) instanceof Slot); + || f instanceof Count) && !f.isDistinct() + && (f.children().isEmpty() || f.child(0) instanceof Slot)); }) .thenApply(ctx -> { Set enableNereidsRules = ctx.cascadesContext.getConnectContext() @@ -111,6 +113,7 @@ public List buildRules() { ); } + /** * Push down Min/Max/Sum through join. */ @@ -119,21 +122,6 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate leftOutput = join.left().getOutput(); List rightOutput = join.right().getOutput(); - List leftFuncs = new ArrayList<>(); - List rightFuncs = new ArrayList<>(); - for (AggregateFunction func : agg.getAggregateFunctions()) { - Slot slot = (Slot) func.child(0); - if (leftOutput.contains(slot)) { - leftFuncs.add(func); - } else if (rightOutput.contains(slot)) { - rightFuncs.add(func); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } - if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) { - return null; - } Set leftGroupBy = new HashSet<>(); Set rightGroupBy = new HashSet<>(); @@ -144,18 +132,76 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate inputForAlias = proj.getInputSlots(); + if (leftOutput.containsAll(inputForAlias)) { + leftGroupBy.addAll(inputForAlias); + } else if (rightOutput.containsAll(inputForAlias)) { + rightGroupBy.addAll(inputForAlias); + } else { + /* + groupBy(X) + +---> project( a + b as X) + --> join(output: T1.a, T2.b) + --> T1(a) + --> T2(b) + X can not be pushed + */ + return null; + } + break; + } + } + } } } - join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); + + List leftFuncs = new ArrayList<>(); + List rightFuncs = new ArrayList<>(); + Count countStar = null; + for (AggregateFunction func : agg.getAggregateFunctions()) { + if (func instanceof Count && ((Count) func).isCountStar()) { + countStar = (Count) func; } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); + Slot slot = (Slot) func.child(0); + if (leftOutput.contains(slot)) { + leftFuncs.add(func); + } else if (rightOutput.contains(slot)) { + rightFuncs.add(func); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + } + } + // determine count(*) + if (countStar != null) { + if (!leftGroupBy.isEmpty()) { + countStar = (Count) countStar.withChildren(leftGroupBy.iterator().next()); + leftFuncs.add(countStar); + } else if (!rightGroupBy.isEmpty()) { + countStar = (Count) countStar.withChildren(rightGroupBy.iterator().next()); + rightFuncs.add(countStar); + } else { + return null; + } + } + for (Expression condition : join.getHashJoinConjuncts()) { + for (Slot joinConditionSlot : condition.getInputSlots()) { + if (leftOutput.contains(joinConditionSlot)) { + leftGroupBy.add(joinConditionSlot); + } else if (rightOutput.contains(joinConditionSlot)) { + rightGroupBy.add(joinConditionSlot); + } else { + // apply failed + return null; + } } - })); + } Plan left = join.left(); Plan right = join.right(); @@ -196,6 +242,10 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate pushMinMaxSumCount(LogicalAggregate newProjections = Lists.newArrayList(); + newProjections.addAll(project.getProjects()); + Set leftDifference = new HashSet(left.getOutput()); + leftDifference.removeAll(project.getProjects()); + newProjections.addAll(leftDifference); + Set rightDifference = new HashSet(right.getOutput()); + rightDifference.removeAll(project.getProjects()); + newProjections.addAll(rightDifference); + newAggChild = ((LogicalProject) agg.child()).withProjectsAndChild(newProjections, newJoin); + } // TODO: column prune project - return agg.withAggOutputChild(newOutputExprs, newJoin); + LogicalAggregate newAgg = agg.withAggOutputChild(newOutputExprs, newAggChild); + if (checkOutput(newAgg)) { + return newAgg; + } else { + return (LogicalAggregate) agg; + } } private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) { @@ -222,4 +290,24 @@ private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) return func.withChildren(inputSlot); } } + + private static boolean checkOutput(LogicalAggregate agg) { + if (agg.child() instanceof LogicalProject) { + Set joinOutputs = ((Plan) agg.child().child(0)).getOutputSet(); + if (!joinOutputs.containsAll(((LogicalProject) agg.child()).getInputSlots())) { + return false; + } + Set projectOutputs = ((LogicalProject) agg.child()).getOutputSet(); + if (!projectOutputs.containsAll(agg.getInputSlots())) { + return false; + } + return true; + } else { + Set joinOutputs = ((Plan) agg.child()).getOutputSet(); + if (!joinOutputs.containsAll(agg.getInputSlots())) { + return false; + } + return true; + } + } } diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out index da69919becd7f28..eddd2733ee903fc 100644 --- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out +++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out @@ -1034,3 +1034,23 @@ Used: UnUsed: use_push_down_agg_through_join_one_side SyntaxError: +-- !shape -- +PhysicalResultSink +--PhysicalTopN[MERGE_SORT] +----PhysicalTopN[LOCAL_SORT] +------hashAgg[GLOBAL] +--------hashAgg[LOCAL] +----------hashJoin[INNER_JOIN] hashCondition=((dwd_tracking_sensor_init_tmp_ymd.dt = dw_user_b2c_tracking_info_tmp_ymd.dt) and (dwd_tracking_sensor_init_tmp_ymd.guid = dw_user_b2c_tracking_info_tmp_ymd.guid)) otherCondition=((dwd_tracking_sensor_init_tmp_ymd.dt >= substring(first_visit_time, 1, 10))) +------------filter((dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19') and (dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click')) +--------------PhysicalOlapScan[dwd_tracking_sensor_init_tmp_ymd] +------------filter((dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19')) +--------------PhysicalOlapScan[dw_user_b2c_tracking_info_tmp_ymd] + +Hint log: +Used: +UnUsed: use_PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE +SyntaxError: + +-- !agg_pushed -- +2 是 2024-08-19 + diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy index 02e067102963335..e551fa04c9110a8 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy @@ -426,4 +426,99 @@ suite("push_down_count_through_join_one_side") { qt_with_hint_groupby_pushdown_nested_queries """ explain shape plan select /*+ USE_CBO_RULE(push_down_agg_through_join_one_side) */ count(*) from (select * from count_t_one_side where score > 20) t1 join (select * from count_t_one_side where id < 100) t2 on t1.id = t2.id group by t1.name; """ + + sql """ + drop table if exists dw_user_b2c_tracking_info_tmp_ymd; + create table dw_user_b2c_tracking_info_tmp_ymd ( + guid int, + dt varchar, + first_visit_time varchar + )Engine=Olap + DUPLICATE KEY(guid) + distributed by hash(dt) buckets 3 + properties('replication_num' = '1'); + + insert into dw_user_b2c_tracking_info_tmp_ymd values (1, '2024-08-19', '2024-08-19'); + + drop table if exists dwd_tracking_sensor_init_tmp_ymd; + create table dwd_tracking_sensor_init_tmp_ymd ( + guid int, + dt varchar, + tracking_type varchar + )Engine=Olap + DUPLICATE KEY(guid) + distributed by hash(dt) buckets 3 + properties('replication_num' = '1'); + + insert into dwd_tracking_sensor_init_tmp_ymd values(1, '2024-08-19', 'click'), (1, '2024-08-19', 'click'); + """ + sql """ + set ENABLE_NEREIDS_RULES = "PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE"; + set disable_join_reorder=true; + """ + + qt_shape """ + explain shape plan + SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/ + Count(*) AS accee593, + CASE + WHEN dwd_tracking_sensor_init_tmp_ymd.dt = + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '是' + WHEN dwd_tracking_sensor_init_tmp_ymd.dt > + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '否' + ELSE '-1' + end AS a1302fb2, + dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123 + FROM dwd_tracking_sensor_init_tmp_ymd + LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd + ON dwd_tracking_sensor_init_tmp_ymd.guid = + dw_user_b2c_tracking_info_tmp_ymd.guid + AND dwd_tracking_sensor_init_tmp_ymd.dt = + dw_user_b2c_tracking_info_tmp_ymd.dt + WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19' + AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19' + AND dwd_tracking_sensor_init_tmp_ymd.dt >= + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10) + AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click' + GROUP BY 2, + 3 + ORDER BY 3 ASC + LIMIT 10000; + """ + + qt_agg_pushed """ + SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/ + Count(*) AS accee593, + CASE + WHEN dwd_tracking_sensor_init_tmp_ymd.dt = + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '是' + WHEN dwd_tracking_sensor_init_tmp_ymd.dt > + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '否' + ELSE '-1' + end AS a1302fb2, + dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123 + FROM dwd_tracking_sensor_init_tmp_ymd + LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd + ON dwd_tracking_sensor_init_tmp_ymd.guid = + dw_user_b2c_tracking_info_tmp_ymd.guid + AND dwd_tracking_sensor_init_tmp_ymd.dt = + dw_user_b2c_tracking_info_tmp_ymd.dt + WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19' + AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19' + AND dwd_tracking_sensor_init_tmp_ymd.dt >= + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10) + AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click' + GROUP BY 2, + 3 + ORDER BY 3 ASC + LIMIT 10000; + """ }