diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java index 040bd9c10f3f3ea..6b17eb1c3a9ce1b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java @@ -21,6 +21,8 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; @@ -62,12 +64,37 @@ public List buildRules() { && ConnectContext.get().getSessionVariable().pushTopnToAgg && ConnectContext.get().getSessionVariable().topnOptLimitThreshold >= limit.getLimit() + limit.getOffset()) - .when(limit -> outputAllGroupKeys(limit, limit.child().child())) .then(limit -> { LogicalProject project = limit.child(); - LogicalAggregate agg = (LogicalAggregate) project.child(); + LogicalAggregate agg + = (LogicalAggregate) project.child(); List orderKeys = generateOrderKeyByGroupKey(agg); - return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), project); + Plan result; + + if (outputAllGroupKeys(limit, agg)) { + result = new LogicalTopN<>(orderKeys, limit.getLimit(), + limit.getOffset(), project); + } else { + // add the first group by key to topn, and prune this key by upper project + // topn order keys are prefix of group by keys + // refer to PushTopnToAgg.tryGenerateOrderKeyByGroupKeyAndTopnKey() + List bottomProjections = Lists.newArrayList(project.getProjects()); + if (agg.getGroupByExpressions().isEmpty()) { + return null; + } + Expression firstGroupByKey = agg.getGroupByExpressions().get(0); + if (!(firstGroupByKey instanceof SlotReference)) { + return null; + } + bottomProjections.add((SlotReference) firstGroupByKey); + LogicalProject bottomProject = project.withProjects(bottomProjections); + LogicalTopN topn = new LogicalTopN<>(orderKeys, limit.getLimit(), + limit.getOffset(), bottomProject); + List limitOutput = limit.getOutput().stream() + .map(e -> (NamedExpression) e).collect(Collectors.toList()); + result = new LogicalProject<>(limitOutput, topn); + } + return result; }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG), // topn -> agg: add all group key to sort key, if sort key is prefix of group key logicalTopN(logicalAggregate())