From c7f134a471018510dd6cb830a28a5f89d2bc670a Mon Sep 17 00:00:00 2001 From: minghong Date: Mon, 18 Nov 2024 15:33:20 +0800 Subject: [PATCH] fix --- .../rules/rewrite/LimitAggToTopNAgg.java | 52 +++++++++++++------ .../plans/physical/PhysicalHashAggregate.java | 4 +- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java index 6b17eb1c3a9ce1b..4534213b51b1a7a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java @@ -34,6 +34,7 @@ import com.google.common.collect.Lists; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; /** @@ -55,7 +56,11 @@ public List buildRules() { >= limit.getLimit() + limit.getOffset()) .then(limit -> { LogicalAggregate agg = limit.child(); - List orderKeys = generateOrderKeyByGroupKey(agg); + Optional orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg); + if (!orderKeysOpt.isPresent()) { + return null; + } + List orderKeys = Lists.newArrayList(orderKeysOpt.get()); return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), agg); }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG), //limit->project->agg to topn->project->agg @@ -68,7 +73,11 @@ public List buildRules() { LogicalProject project = limit.child(); LogicalAggregate agg = (LogicalAggregate) project.child(); - List orderKeys = generateOrderKeyByGroupKey(agg); + Optional orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg); + if (!orderKeysOpt.isPresent()) { + return null; + } + List orderKeys = Lists.newArrayList(orderKeysOpt.get()); Plan result; if (outputAllGroupKeys(limit, agg)) { @@ -78,21 +87,27 @@ public List buildRules() { // add the first group by key to topn, and prune this key by upper project // topn order keys are prefix of group by keys // refer to PushTopnToAgg.tryGenerateOrderKeyByGroupKeyAndTopnKey() - List bottomProjections = Lists.newArrayList(project.getProjects()); - if (agg.getGroupByExpressions().isEmpty()) { - return null; - } Expression firstGroupByKey = agg.getGroupByExpressions().get(0); if (!(firstGroupByKey instanceof SlotReference)) { return null; } - bottomProjections.add((SlotReference) firstGroupByKey); - LogicalProject bottomProject = project.withProjects(bottomProjections); + boolean shouldPruneFirstGroupByKey = true; + if (project.getOutputs().contains(firstGroupByKey)) { + shouldPruneFirstGroupByKey = false; + } else { + List bottomProjections = Lists.newArrayList(project.getProjects()); + bottomProjections.add((SlotReference) firstGroupByKey); + project = project.withProjects(bottomProjections); + } LogicalTopN topn = new LogicalTopN<>(orderKeys, limit.getLimit(), - limit.getOffset(), bottomProject); - List limitOutput = limit.getOutput().stream() - .map(e -> (NamedExpression) e).collect(Collectors.toList()); - result = new LogicalProject<>(limitOutput, topn); + limit.getOffset(), project); + if (shouldPruneFirstGroupByKey) { + List limitOutput = limit.getOutput().stream() + .map(e -> (NamedExpression) e).collect(Collectors.toList()); + result = new LogicalProject<>(limitOutput, topn); + } else { + result = topn; + } } return result; }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG), @@ -138,9 +153,14 @@ private boolean outputAllGroupKeys(LogicalLimit limit, LogicalAggregate agg) { return limit.getOutputSet().containsAll(agg.getGroupByExpressions()); } - private List generateOrderKeyByGroupKey(LogicalAggregate agg) { - return agg.getGroupByExpressions().stream() - .map(key -> new OrderKey(key, true, false)) - .collect(Collectors.toList()); + private Optional tryGenerateOrderKeyByTheFirstGroupKey(LogicalAggregate agg) { + if (agg.getGroupByExpressions().isEmpty()) { + return Optional.empty(); + } + if (agg.getGroupByExpressions().get(0) instanceof SlotReference) { + // agg normalize projects the expression under agg. we cannot use it as order key above agg + return Optional.of(new OrderKey(agg.getGroupByExpressions().get(0), true, false)); + } + return Optional.empty(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java index 404c30fe379d4a4..7ed39fed8b60fb7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java @@ -206,8 +206,8 @@ public String toString() { "groupByExpr", groupByExpressions, "outputExpr", outputExpressions, "partitionExpr", partitionExpressions, - "requireProperties", requireProperties, - "topnOpt", topnPushInfo != null + "topnFilter", topnPushInfo != null, + "topnPushDown", getMutableState(MutableState.KEY_PUSH_TOPN_TO_AGG).isPresent() ); }