diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java index a2b5875fa89f1f..049709dd23a311 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; @@ -34,8 +35,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import java.util.HashMap; import java.util.List; @@ -51,7 +50,6 @@ * 2. push limit to local agg */ public class LimitAggToTopNAgg implements RewriteRuleFactory { - public static final Logger LOG = LogManager.getLogger(LimitAggToTopNAgg.class); @Override public List buildRules() { @@ -125,8 +123,6 @@ public List buildRules() { LogicalTopN originTopn = topn; LogicalProject project = topn.child(); LogicalAggregate agg = (LogicalAggregate) project.child(); - StringBuilder builder = new StringBuilder(); - builder.append("@@@@@###"); if (!project.isAllSlots()) { /* topn(orderKey=[a]) @@ -144,9 +140,6 @@ public List buildRules() { keyAsKey.put((SlotReference) e.toSlot(), (SlotReference) e.child(0)); } } - builder.append(topn); - builder.append(project); - List projectOrderKeys = Lists.newArrayList(); boolean hasNew = false; for (OrderKey orderKey : topn.getOrderKeys()) { @@ -165,22 +158,36 @@ public List buildRules() { supplementOrderKeyByGroupKeyIfCompatible(topn, agg); Plan result; if (pair == null) { - builder.append("|not compatible"); result = originTopn; } else { - builder.append("|compatible"); agg = agg.withGroupBy(pair.second); topn = (LogicalTopN) topn.withOrderKeys(pair.first); - topn = (LogicalTopN) topn.withChildren(agg); - project = (LogicalProject) project.withChildren(topn); - result = project; + if (isOrderKeysInProject(topn, project)) { + project = (LogicalProject) project.withChildren(agg); + topn = (LogicalTopN>>) + topn.withChildren(project); + result = topn; + } else { + topn = (LogicalTopN) topn.withChildren(agg); + project = (LogicalProject) project.withChildren(topn); + result = project; + } } - LOG.warn(builder.toString()); return result; }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG) ); } + private boolean isOrderKeysInProject(LogicalTopN topn, LogicalProject project) { + Set projectSlots = project.getOutputSet(); + for (OrderKey orderKey : topn.getOrderKeys()) { + if (!projectSlots.contains(orderKey.getExpr())) { + return false; + } + } + return true; + } + private List generateOrderKeyByGroupKey(LogicalAggregate agg) { return agg.getGroupByExpressions().stream() .map(key -> new OrderKey(key, true, false))