Skip to content

Commit

Permalink
[fix](Nereids) set correct sort key for aggregate (apache#45369)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?
in previous apache#44042, we supported more patterns for PushTopnToAgg rule.
the new pattern:
topn
  +-->agg(global)
      +-->shuffle
          +-->agg(local)

In order to support this new pattern, the group by keys and orderkeys
are identical, but group keys can be in different order.
that is 
topn(orderkey=[B,A])->agg(groupkey=[A,B,C])
=>
topn(orderkey=[B, A, C]) ->agg(groupKey=[A, B, C])
  • Loading branch information
englefly authored Dec 18, 2024
1 parent fa10bdd commit a886463
Show file tree
Hide file tree
Showing 14 changed files with 428 additions and 345 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ private static List<RewriteJob> buildAnalyzerJobs(Optional<CustomTableResolver>
topDown(new EliminateGroupByConstant()),

topDown(new SimplifyAggGroupBy()),
topDown(new CompressedMaterialize()),
bottomUp(new CompressedMaterialize()),
topDown(new NormalizeAggregate()),
topDown(new HavingToFilter()),
topDown(new QualifyToFilter()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@
import org.apache.doris.nereids.rules.rewrite.ReduceAggregateChildOutputRows;
import org.apache.doris.nereids.rules.rewrite.ReorderJoin;
import org.apache.doris.nereids.rules.rewrite.RewriteCteChildren;
import org.apache.doris.nereids.rules.rewrite.SimplifyEncodeDecode;
import org.apache.doris.nereids.rules.rewrite.SimplifyWindowExpression;
import org.apache.doris.nereids.rules.rewrite.SplitLimit;
import org.apache.doris.nereids.rules.rewrite.SumLiteralRewrite;
Expand Down Expand Up @@ -371,6 +372,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
// generate one PhysicalLimit if current distribution is gather or two
// PhysicalLimits with gather exchange
topDown(new LimitSortToTopN()),
topDown(new SimplifyEncodeDecode()),
topDown(new LimitAggToTopNAgg()),
topDown(new MergeTopNs()),
topDown(new SplitLimit()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,157 +21,84 @@
package org.apache.doris.nereids.processor.post;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.AggMode;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate.TopnPushInfo;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
import org.apache.doris.qe.ConnectContext;

import org.apache.hadoop.util.Lists;

import java.util.List;
import java.util.stream.Collectors;

/**
* Add SortInfo to Agg. This SortInfo is used as boundary, not used to sort elements.
* Add TopNInfo to Agg. This TopNInfo is used as boundary, not used to sort elements.
* example
* sql: select count(*) from orders group by o_clerk order by o_clerk limit 1;
* plan: topn(1) -> aggGlobal -> shuffle -> aggLocal -> scan
* optimization: aggLocal and aggGlobal only need to generate the smallest row with respect to o_clerk.
*
* TODO: the following case is not covered:
* sql: select sum(o_shippriority) from orders group by o_clerk limit 1;
* plan: limit -> aggGlobal -> shuffle -> aggLocal -> scan
* aggGlobal may receive partial aggregate results, and hence is not supported now
* instance1: input (key=2, v=1) => localAgg => (2, 1) => aggGlobal inst1 => (2, 1)
* instance2: input (key=1, v=1), (key=2, v=2) => localAgg inst2 => (1, 1)
* (2,1),(1,1) => limit => may output (2, 1), which is not complete, missing (2, 2) in instance2
*
*TOPN:
* Precondition: topn orderkeys are the prefix of group keys
* TODO: topnKeys could be subset of groupKeys. This will be implemented in future
* Pattern 2-phase agg:
* topn -> aggGlobal -> distribute -> aggLocal
* =>
* topn(n) -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n)
* Pattern 1-phase agg:
* topn->agg->Any(not agg) -> topn -> agg(topn=n) -> any
*
* LIMIT:
* Pattern 1: limit->agg(1phase)->any
* Pattern 2: limit->agg(global)->gather->agg(local)
* This rule only applies to the patterns
* 1. topn->project->agg, or
* 2. topn->agg
* that
* 1. orderKeys and groupkeys are one-one mapping
* 2. aggregate is not scalar agg
* Refer to LimitAggToTopNAgg rule.
*/
public class PushTopnToAgg extends PlanPostProcessor {
@Override
public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext ctx) {
topN.child().accept(this, ctx);
if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= topN.getLimit() + topN.getOffset()) {
if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= topN.getLimit() + topN.getOffset()
&& !ConnectContext.get().getSessionVariable().pushTopnToAgg) {
return topN;
}
Plan topnChild = topN.child();
if (topnChild instanceof PhysicalProject) {
topnChild = topnChild.child(0);
Plan topNChild = topN.child();
if (topNChild instanceof PhysicalProject) {
topNChild = topNChild.child(0);
}
if (topnChild instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topnChild;
List<OrderKey> orderKeys = tryGenerateOrderKeyByGroupKeyAndTopnKey(topN, upperAgg);
if (!orderKeys.isEmpty()) {

if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) {
upperAgg.setTopnPushInfo(new TopnPushInfo(
orderKeys,
topN.getLimit() + topN.getOffset()));
if (upperAgg.child() instanceof PhysicalDistribute
&& upperAgg.child().child(0) instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> bottomAgg =
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
if (topNChild instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topNChild;
if (isGroupKeyIdenticalToOrderKey(topN, upperAgg)) {
upperAgg.setTopnPushInfo(new TopnPushInfo(
topN.getOrderKeys(),
topN.getLimit() + topN.getOffset()));
if (upperAgg.child() instanceof PhysicalDistribute
&& upperAgg.child().child(0) instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> bottomAgg =
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
if (isGroupKeyIdenticalToOrderKey(topN, bottomAgg)) {
bottomAgg.setTopnPushInfo(new TopnPushInfo(
topN.getOrderKeys(),
topN.getLimit() + topN.getOffset()));
}
} else if (upperAgg.child() instanceof PhysicalHashAggregate) {
// multi-distinct plan
PhysicalHashAggregate<? extends Plan> bottomAgg =
(PhysicalHashAggregate<? extends Plan>) upperAgg.child();
if (isGroupKeyIdenticalToOrderKey(topN, bottomAgg)) {
bottomAgg.setTopnPushInfo(new TopnPushInfo(
orderKeys,
topN.getOrderKeys(),
topN.getLimit() + topN.getOffset()));
}
} else if (upperAgg.getAggPhase().isLocal() && upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) {
// one phase agg
upperAgg.setTopnPushInfo(new TopnPushInfo(
orderKeys,
topN.getLimit() + topN.getOffset()));
}
}
}
return topN;
}

/**
return true, if topn order-key is prefix of agg group-key, ignore asc/desc and null_first
TODO order-key can be subset of group-key. BE does not support now.
*/
private List<OrderKey> tryGenerateOrderKeyByGroupKeyAndTopnKey(PhysicalTopN<? extends Plan> topN,
PhysicalHashAggregate<? extends Plan> agg) {
List<OrderKey> orderKeys = Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size());
if (topN.getOrderKeys().size() > agg.getGroupByExpressions().size()) {
return orderKeys;
}
List<Expression> topnKeys = topN.getOrderKeys().stream()
.map(OrderKey::getExpr).collect(Collectors.toList());
for (int i = 0; i < topN.getOrderKeys().size(); i++) {
// prefix check
if (!topnKeys.get(i).equals(agg.getGroupByExpressions().get(i))) {
return Lists.newArrayList();
}
orderKeys.add(topN.getOrderKeys().get(i));
}
for (int i = topN.getOrderKeys().size(); i < agg.getGroupByExpressions().size(); i++) {
orderKeys.add(new OrderKey(agg.getGroupByExpressions().get(i), true, false));
}
return orderKeys;
}

@Override
public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, CascadesContext ctx) {
limit.child().accept(this, ctx);
if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= limit.getLimit() + limit.getOffset()) {
return limit;
private boolean isGroupKeyIdenticalToOrderKey(PhysicalTopN<? extends Plan> topN,
PhysicalHashAggregate<? extends Plan> agg) {
if (topN.getOrderKeys().size() != agg.getGroupByExpressions().size()) {
return false;
}
Plan limitChild = limit.child();
if (limitChild instanceof PhysicalProject) {
limitChild = limitChild.child(0);
}
if (limitChild instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) limitChild;
if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) {
Plan child = upperAgg.child();
Plan grandChild = child.child(0);
if (child instanceof PhysicalDistribute
&& ((PhysicalDistribute<?>) child).getDistributionSpec() instanceof DistributionSpecGather
&& grandChild instanceof PhysicalHashAggregate) {
upperAgg.setTopnPushInfo(new TopnPushInfo(
generateOrderKeyByGroupKey(upperAgg),
limit.getLimit() + limit.getOffset()));
PhysicalHashAggregate<? extends Plan> bottomAgg =
(PhysicalHashAggregate<? extends Plan>) grandChild;
bottomAgg.setTopnPushInfo(new TopnPushInfo(
generateOrderKeyByGroupKey(bottomAgg),
limit.getLimit() + limit.getOffset()));
}
} else if (upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) {
// 1-phase agg
upperAgg.setTopnPushInfo(new TopnPushInfo(
generateOrderKeyByGroupKey(upperAgg),
limit.getLimit() + limit.getOffset()));
for (int i = 0; i < agg.getGroupByExpressions().size(); i++) {
Expression groupByKey = agg.getGroupByExpressions().get(i);
Expression orderKey = topN.getOrderKeys().get(i).getExpr();
if (!groupByKey.equals(orderKey)) {
return false;
}
}
return limit;
}

private List<OrderKey> generateOrderKeyByGroupKey(PhysicalHashAggregate<? extends Plan> agg) {
return agg.getGroupByExpressions().stream()
.map(key -> new OrderKey(key, true, false))
.collect(Collectors.toList());
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public enum RuleType {
// rewrite rules
COMPRESSED_MATERIALIZE_AGG(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_SORT(RuleTypeClass.REWRITE),
SIMPLIFY_ENCODE_DECODE(RuleTypeClass.REWRITE),
NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
NORMALIZE_SORT(RuleTypeClass.REWRITE),
NORMALIZE_REPEAT(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,26 @@ public List<Rule> buildRules() {

private LogicalSort<Plan> compressMaterializeSort(LogicalSort<Plan> sort) {
List<OrderKey> newOrderKeys = Lists.newArrayList();
boolean changed = false;
List<Expression> orderKeysToEncode = Lists.newArrayList();
for (OrderKey orderKey : sort.getOrderKeys()) {
Expression expr = orderKey.getExpr();
Optional<Expression> encode = getEncodeExpression(expr);
if (encode.isPresent()) {
newOrderKeys.add(new OrderKey(encode.get(),
orderKey.isAsc(),
orderKey.isNullFirst()));
changed = true;
orderKeysToEncode.add(expr);
} else {
newOrderKeys.add(orderKey);
}
}
return changed ? sort.withOrderKeys(newOrderKeys) : sort;
if (orderKeysToEncode.isEmpty()) {
return sort;
} else {
sort = sort.withOrderKeys(newOrderKeys);
return sort;
}

}

private Optional<Expression> getEncodeExpression(Expression expression) {
Expand Down
Loading

0 comments on commit a886463

Please sign in to comment.