Skip to content

Commit

Permalink
[enhancement](Nereids) avoiding broadcast join heuristically and prun…
Browse files Browse the repository at this point in the history
…ing more in CostAndEnforceJob (apache#25137)

When the rowCount exceeds a certain threshold, refrain from generating a broadcast join.
Only enforce the best expression in CostAndEnforce Job, rather than enforcing every expression.
Remove lower bound group pruning
  • Loading branch information
keanji-x authored Oct 10, 2023
1 parent 181c58c commit 7276665
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,6 @@ public Cost visitPhysicalDistribute(

// replicate
if (spec instanceof DistributionSpecReplicated) {
double dataSize = childStatistics.computeSize();
double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
// if the build side is big, avoid using a broadcast join
double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
if (dataSize > memLimit * brMemlimit
|| childStatistics.getRowCount() > rowsLimit) {
return CostV1.of(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE);
}
// estimate broadcast cost by an experience formula: beNumber^0.5 * rowCount
// - sender number and receiver number are not available at the RBO stage now, so we use beNumber
// - senders and receivers work in parallel, that's why we use the square root of beNumber
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,17 @@ public void execute() {
curNodeCost,
lowestCostExpr.getCostValueByProperties(requestChildProperty),
curChildIndex);
if (curTotalCost.getValue() > context.getCostUpperBound()) {
curTotalCost = Cost.infinite();
}

// Not performing lower bound group pruning here is to avoid redundant optimization of children.
// For example:
// Group1 : betterExpr, currentExpr(child: Group2), otherExpr(child: Group2)
// steps
// 1. CostAndEnforce(currentExpr) with upperBound betterExpr.cost
// 2. OptimizeGroup(Group2) with upperBound betterExpr.cost - currentExpr.nodeCost
// 3. CostAndEnforce(Expr in Group2) trigger here and exit
// ...
// n. CostAndEnforce(otherExpr) can trigger optimize group2 again for the same requireProp

// The request child properties will be covered by the output properties
// that correspond to the request properties. So if we run a CostAndEnforceJob on the same
// group expression again, its request child properties will be different from these.
Expand Down Expand Up @@ -275,6 +283,23 @@ private void enforce(PhysicalProperties outputProperty, List<PhysicalProperties>
}
return;
}

if (context.getRequiredProperties().isDistributionOnlyProperties()) {
// For properties without an orderSpec, enforceMissingPropertiesHelper always adds a distributor
// above this group expression. The total cost of the enforced plan is equal to the cost of the groupExpression
// plus the cost of the distributor. The distributor remains unchanged for different groupExpressions.
// Therefore, if there is a better groupExpr, it is preferable to enforce the better groupExpr.
// Consequently, we can avoid this enforcement.
Optional<Pair<Cost, GroupExpression>> bestExpr = groupExpression.getOwnerGroup()
.getLowestCostPlan(context.getRequiredProperties());
double bestCost = bestExpr
.map(costGroupExpressionPair -> costGroupExpressionPair.first.getValue())
.orElse(Double.POSITIVE_INFINITY);
if (curTotalCost.getValue() > bestCost) {
return;
}
}

EnforceMissingPropertiesHelper enforceMissingPropertiesHelper
= new EnforceMissingPropertiesHelper(context, groupExpression, curTotalCost);
PhysicalProperties addEnforcedProperty = enforceMissingPropertiesHelper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ public DistributionSpec getDistributionSpec() {
return distributionSpec;
}

/**
 * Reports whether these properties carry no ordering requirement.
 *
 * @return {@code true} when the order spec holds no order keys, {@code false} otherwise
 */
public boolean isDistributionOnlyProperties() {
    final boolean hasNoOrderKeys = orderSpec.getOrderKeys().isEmpty();
    return hasNoOrderKeys;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,14 @@ public Void visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? extends Pla
if (JoinUtils.couldShuffle(hashJoin)) {
addShuffleJoinRequestProperty(hashJoin);
}

// for broadcast join
if (JoinUtils.couldBroadcast(hashJoin)) {
double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
double datasize = hashJoin.getGroupExpression().get().child(1).getStatistics().computeSize();
double rowCount = hashJoin.getGroupExpression().get().child(1).getStatistics().getRowCount();
if (JoinUtils.couldBroadcast(hashJoin) && rowCount <= rowsLimit && datasize <= memLimit * brMemlimit) {
addBroadcastJoinRequestProperty();
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
Expand All @@ -60,6 +61,12 @@ class RequestPropertyDeriverTest {
@Mocked
LogicalProperties logicalProperties;

@Mocked
ConnectContext connectContext;

@Injectable
Group group;

@Injectable
JobContext jobContext;

Expand Down Expand Up @@ -105,7 +112,7 @@ Pair<List<ExprId>, List<ExprId>> getHashConjunctsExprIds() {
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, Optional.empty(),
logicalProperties,
groupPlan, groupPlan);
GroupExpression groupExpression = new GroupExpression(join);
GroupExpression groupExpression = new GroupExpression(join, Lists.newArrayList(group, group));
new Group(null, groupExpression, null);

RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext);
Expand All @@ -130,11 +137,18 @@ Pair<List<ExprId>, List<ExprId>> getHashConjunctsExprIds() {
}
};

new MockUp<ConnectContext>() {
@Mock
ConnectContext get() {
return connectContext;
}
};

PhysicalHashJoin<GroupPlan, GroupPlan> join = new PhysicalHashJoin<>(JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, Optional.empty(),
logicalProperties,
groupPlan, groupPlan);
GroupExpression groupExpression = new GroupExpression(join);
GroupExpression groupExpression = new GroupExpression(join, Lists.newArrayList(group, group));
new Group(null, groupExpression, null);

RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext);
Expand Down

0 comments on commit 7276665

Please sign in to comment.