Skip to content

Commit

Permalink
TODO distinct
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Dec 15, 2024
1 parent 02e87e1 commit 0294963
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,47 +21,27 @@
package org.apache.doris.nereids.processor.post;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.rewrite.LimitAggToTopNAgg;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.AggMode;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate.TopnPushInfo;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
import org.apache.doris.qe.ConnectContext;

import org.apache.hadoop.util.Lists;

import java.util.List;
import java.util.stream.Collectors;

/**
* Add SortInfo to Agg. This SortInfo is used as boundary, not used to sort elements.
* Add TopNInfo to Agg. This TopNInfo is used as boundary, not used to sort elements.
* example
* sql: select count(*) from orders group by o_clerk order by o_clerk limit 1;
* plan: topn(1) -> aggGlobal -> shuffle -> aggLocal -> scan
* optimization: aggLocal and aggGlobal only need to generate the smallest row with respect to o_clerk.
*
* Attention: the following case is error-prone
* sql: select sum(o_shippriority) from orders group by o_clerk limit 1;
* plan: limit -> aggGlobal -> shuffle -> aggLocal -> scan
* aggGlobal may receive partial aggregate results, and hence is not supported now
* instance1: input (key=2, v=1) => localAgg => (2, 1) => aggGlobal inst1 => (2, 1)
* instance2: input (key=1, v=1), (key=2, v=2) => localAgg inst2 => (1, 1)
* (2,1),(1,1) => limit => may output (2, 1), which is not complete, missing (2, 2) in instance2
*
*TOPN:
* Pattern 2-phase agg:
* topn -> aggGlobal -> distribute -> aggLocal
* =>
* topn(n) -> aggGlobal(topNInfo) -> distribute -> aggLocal(topNInfo)
* Pattern 1-phase agg:
* topn->agg->Any(not agg) -> topn -> agg(topNInfo) -> any
* This rule only applies to the pattern that
* 1. aggregate is the child of topN (there is no project between topN and aggregate).
* 2. aggregate is not scalar agg, and there is no distinct arguments
* Refer to LimitAggToTopNAgg rule.
*/
public class PushTopnToAgg extends PlanPostProcessor {
@Override
Expand All @@ -71,95 +51,47 @@ public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext
return topN;
}
Plan topnChild = topN.child();
if (topnChild instanceof PhysicalProject) {
topnChild = topnChild.child(0);
}
if (topnChild instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topnChild;
List<OrderKey> orderKeys = generateOrderKeyByGroupKeyAndTopNKey(topN, upperAgg);
if (!orderKeys.isEmpty()) {
// TODO detect distinct
if (isGroupKeyIdenticalToOrderKey(topN, upperAgg)
&& LimitAggToTopNAgg.isSortableAggregate(upperAgg)) {
if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) {
upperAgg.setTopnPushInfo(new TopnPushInfo(
orderKeys,
topN.getLimit() + topN.getOffset()));
if (upperAgg.child() instanceof PhysicalDistribute
&& upperAgg.child().child(0) instanceof PhysicalHashAggregate) {
upperAgg.setTopnPushInfo(new TopnPushInfo(
topN.getOrderKeys(),
topN.getLimit() + topN.getOffset()));
PhysicalHashAggregate<? extends Plan> bottomAgg =
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
bottomAgg.setTopnPushInfo(new TopnPushInfo(
orderKeys,
topN.getOrderKeys(),
topN.getLimit() + topN.getOffset()));
}
} else if (upperAgg.getAggPhase().isLocal() && upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) {
// one phase agg
upperAgg.setTopnPushInfo(new TopnPushInfo(
orderKeys,
topN.getOrderKeys(),
topN.getLimit() + topN.getOffset()));
}
}
}
return topN;
}

private List<OrderKey> generateOrderKeyByGroupKeyAndTopNKey(PhysicalTopN<? extends Plan> topN,
private boolean isGroupKeyIdenticalToOrderKey(PhysicalTopN<? extends Plan> topN,
PhysicalHashAggregate<? extends Plan> agg) {
List<OrderKey> orderKeys = Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size());
if (topN.getOrderKeys().size() < agg.getGroupByExpressions().size()) {
return Lists.newArrayList();
if (topN.getOrderKeys().size() != agg.getGroupByExpressions().size()) {
return false;
}
for (int i = 0; i < agg.getGroupByExpressions().size(); i++) {
Expression groupByKey = agg.getGroupByExpressions().get(i);
Expression orderKey = topN.getOrderKeys().get(i).getExpr();
if (groupByKey.equals(orderKey)) {
orderKeys.add(topN.getOrderKeys().get(i));
} else {
orderKeys.clear();
break;
}
}
return orderKeys;
}

@Override
public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, CascadesContext ctx) {
limit.child().accept(this, ctx);
if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= limit.getLimit() + limit.getOffset()) {
return limit;
}
Plan limitChild = limit.child();
if (limitChild instanceof PhysicalProject) {
limitChild = limitChild.child(0);
}
if (limitChild instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) limitChild;
if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) {
Plan child = upperAgg.child();
Plan grandChild = child.child(0);
if (child instanceof PhysicalDistribute
&& ((PhysicalDistribute<?>) child).getDistributionSpec() instanceof DistributionSpecGather
&& grandChild instanceof PhysicalHashAggregate) {
upperAgg.setTopnPushInfo(new TopnPushInfo(
generateOrderKeyByGroupKey(upperAgg),
limit.getLimit() + limit.getOffset()));
PhysicalHashAggregate<? extends Plan> bottomAgg =
(PhysicalHashAggregate<? extends Plan>) grandChild;
bottomAgg.setTopnPushInfo(new TopnPushInfo(
generateOrderKeyByGroupKey(bottomAgg),
limit.getLimit() + limit.getOffset()));
}
} else if (upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) {
// 1-phase agg
upperAgg.setTopnPushInfo(new TopnPushInfo(
generateOrderKeyByGroupKey(upperAgg),
limit.getLimit() + limit.getOffset()));
if (!groupByKey.equals(orderKey)) {
return false;
}
}
return limit;
return true;
}

private List<OrderKey> generateOrderKeyByGroupKey(PhysicalHashAggregate<? extends Plan> agg) {
return agg.getGroupByExpressions().stream()
.map(key -> new OrderKey(key, true, false))
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
Expand All @@ -32,7 +32,6 @@
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
Expand All @@ -52,6 +51,10 @@ public List<Rule> buildRules() {
&& ConnectContext.get().getSessionVariable().pushTopnToAgg
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= limit.getLimit() + limit.getOffset())
.when(limit -> {
LogicalAggregate<? extends Plan> agg = limit.child();
return isSortableAggregate(agg);
})
.then(limit -> {
LogicalAggregate<? extends Plan> agg = limit.child();
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
Expand All @@ -64,10 +67,13 @@ public List<Rule> buildRules() {
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= limit.getLimit() + limit.getOffset())
.when(limit -> limit.child().isAllSlots())
.when(limit -> {
LogicalAggregate<? extends Plan> agg = limit.child().child();
return isSortableAggregate(agg);
})
.then(limit -> {
LogicalProject<? extends Plan> project = limit.child();
LogicalAggregate<? extends Plan> agg
= (LogicalAggregate<? extends Plan>) project.child();
LogicalAggregate<? extends Plan> agg = (LogicalAggregate<? extends Plan>) project.child();
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
LogicalTopN topn = new LogicalTopN<>(orderKeys, limit.getLimit(),
limit.getOffset(), agg);
Expand All @@ -80,22 +86,18 @@ public List<Rule> buildRules() {
&& ConnectContext.get().getSessionVariable().pushTopnToAgg
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= topn.getLimit() + topn.getOffset())
.when(topn -> {
LogicalAggregate<? extends Plan> agg = topn.child();
return isSortableAggregate(agg);
})
.then(topn -> {
LogicalAggregate<? extends Plan> agg = (LogicalAggregate<? extends Plan>) topn.child();
List<OrderKey> newOrders = Lists.newArrayList(topn.getOrderKeys());
Set<Expression> orderExprs = topn.getOrderKeys().stream()
.map(orderKey -> orderKey.getExpr()).collect(Collectors.toSet());
boolean orderKeyChanged = false;
for (Expression expr : agg.getGroupByExpressions()) {
if (!orderExprs.contains(expr)) {
// after NormalizeAggregate, expr should be SlotReference
if (expr instanceof SlotReference) {
orderKeyChanged = true;
newOrders.add(new OrderKey(expr, true, true));
}
}
LogicalAggregate<? extends Plan> agg = topn.child();
List<OrderKey> newOrderKyes = supplementOrderKeyByGroupKeyIfCompatible(topn, agg);
if (newOrderKyes.isEmpty()) {
return topn;
} else {
return topn.withOrderKeys(newOrderKyes);
}
return orderKeyChanged ? topn.withOrderKeys(newOrders) : topn;
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
//topn -> project ->agg: add all group key to sort key, and prune column
logicalTopN(logicalProject(logicalAggregate()))
Expand All @@ -104,38 +106,65 @@ public List<Rule> buildRules() {
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= topn.getLimit() + topn.getOffset())
.when(topn -> topn.child().isAllSlots())
.when(topn -> {
LogicalAggregate<? extends Plan> agg = topn.child().child();
return isSortableAggregate(agg);
})
.then(topn -> {
LogicalProject project = topn.child();
LogicalProject<? extends Plan> project = topn.child();
LogicalAggregate<? extends Plan> agg = (LogicalAggregate) project.child();
List<OrderKey> newOrders = Lists.newArrayList(topn.getOrderKeys());
Set<Expression> orderExprs = topn.getOrderKeys().stream()
.map(orderKey -> orderKey.getExpr()).collect(Collectors.toSet());
boolean orderKeyChanged = false;
for (Expression expr : agg.getGroupByExpressions()) {
if (!orderExprs.contains(expr)) {
// after NormalizeAggregate, expr should be SlotReference
if (expr instanceof SlotReference) {
orderKeyChanged = true;
newOrders.add(new OrderKey(expr, true, true));
}
}
}
Plan result;
if (orderKeyChanged) {
topn = (LogicalTopN) topn.withChildren(agg);
topn.withOrderKeys(newOrders);
result = (Plan) project.withChildren(topn);
List<OrderKey> newOrders = supplementOrderKeyByGroupKeyIfCompatible(topn, agg);
if (newOrders.isEmpty()) {
return topn;
} else {
result = topn;
topn = (LogicalTopN) topn.withChildren(agg);
topn = (LogicalTopN) topn.withOrderKeys(newOrders);
project = (LogicalProject) project.withChildren(topn);
return project;
}
return result;
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG)
);
}

/**
* not scalar agg
* no distinct
*/
public static boolean isSortableAggregate(Aggregate agg) {
return !agg.getGroupByExpressions().isEmpty() && agg.getDistinctArguments().isEmpty();
}

private List<OrderKey> generateOrderKeyByGroupKey(LogicalAggregate<? extends Plan> agg) {
return agg.getGroupByExpressions().stream()
.map(key -> new OrderKey(key, true, false))
.collect(Collectors.toList());
}

private List<OrderKey> supplementOrderKeyByGroupKeyIfCompatible(LogicalTopN<? extends Plan> topn,
LogicalAggregate<? extends Plan> agg) {
int groupKeyCount = agg.getGroupByExpressions().size();
int orderKeyCount = topn.getOrderKeys().size();
if (orderKeyCount <= groupKeyCount) {
boolean canAppendOrderKey = true;
for (int i = 0; i < orderKeyCount; i++) {
Expression groupKey = agg.getGroupByExpressions().get(i);
Expression orderKey = topn.getOrderKeys().get(i).getExpr();
if (!groupKey.equals(orderKey)) {
canAppendOrderKey = false;
break;
}
}
if (canAppendOrderKey && orderKeyCount < groupKeyCount) {
List<OrderKey> newOrderKeys = Lists.newArrayList(topn.getOrderKeys());
for (int i = orderKeyCount; i < groupKeyCount; i++) {
newOrderKeys.add(new OrderKey(agg.getGroupByExpressions().get(i), true, false));
}
return newOrderKeys;
} else {
return Lists.newArrayList();
}
} else {
return Lists.newArrayList();
}
}
}

0 comments on commit 0294963

Please sign in to comment.