Skip to content

Commit

Permalink
[opt](nereids) refine operator estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongjian.xzj authored and zhongjian.xzj committed Sep 18, 2024
1 parent 7f9370b commit b69ae32
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public class JoinEstimation {
private static double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.3;
private static double UNKNOWN_COL_STATS_FILTER_SEL_LOWER_BOUND = 0.5;

private static EqualPredicate normalizeHashJoinCondition(EqualPredicate equal, Statistics leftStats,
private static EqualPredicate normalizeEqualPredJoinCondition(EqualPredicate equal, Statistics leftStats,
Statistics rightStats) {
boolean changeOrder = equal.left().getInputSlots().stream()
.anyMatch(slot -> rightStats.findColumnStatistics(slot) != null);
Expand All @@ -58,7 +58,7 @@ private static EqualPredicate normalizeHashJoinCondition(EqualPredicate equal, S
}
}

private static boolean hashJoinConditionContainsUnknownColumnStats(Statistics leftStats,
private static boolean joinConditionContainsUnknownColumnStats(Statistics leftStats,
Statistics rightStats, Join join) {
for (Expression expr : join.getEqualPredicates()) {
for (Slot slot : expr.getInputSlots()) {
Expand All @@ -74,7 +74,8 @@ private static boolean hashJoinConditionContainsUnknownColumnStats(Statistics le
return false;
}

private static Statistics estimateHashJoin(Statistics leftStats, Statistics rightStats, Join join) {
private static Statistics estimateInnerJoinWithEqualPredicate(Statistics leftStats,
Statistics rightStats, Join join) {
/*
* When we estimate filter A=B,
* if any side of equation, A or B, is almost unique, the confidence level of estimation is high.
Expand All @@ -94,7 +95,7 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ
// since ndv is not accurate, if ndv/rowcount < almostUniqueThreshold,
// this column is regarded as unique.
double almostUniqueThreshold = 0.9;
EqualPredicate equal = normalizeHashJoinCondition(expression, leftStats, rightStats);
EqualPredicate equal = normalizeEqualPredJoinCondition(expression, leftStats, rightStats);
ColumnStatistic eqLeftColStats = ExpressionEstimation.estimate(equal.left(), leftStats);
ColumnStatistic eqRightColStats = ExpressionEstimation.estimate(equal.right(), rightStats);
boolean trustable = eqRightColStats.ndv / rightStatsRowCount > almostUniqueThreshold
Expand Down Expand Up @@ -148,8 +149,9 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ
return innerJoinStats;
}

private static Statistics estimateNestLoopJoin(Statistics leftStats, Statistics rightStats, Join join) {
if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
private static Statistics estimateInnerJoinWithoutEqualPredicate(Statistics leftStats,
Statistics rightStats, Join join) {
if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
double rowCount = (leftStats.getRowCount() + rightStats.getRowCount());
// We do more like the nested loop join with one rows than inner join
if (leftStats.getRowCount() == 1 || rightStats.getRowCount() == 1) {
Expand Down Expand Up @@ -193,7 +195,7 @@ private static double computeSelectivityForBuildSideWhenColStatsUnknown(Statisti
}

private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rightStats, Join join) {
if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
double rowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount());
rowCount = Math.max(1, rowCount);
return new StatisticsBuilder()
Expand All @@ -205,9 +207,9 @@ private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rig

Statistics innerJoinStats;
if (join.getEqualPredicates().isEmpty()) {
innerJoinStats = estimateNestLoopJoin(leftStats, rightStats, join);
innerJoinStats = estimateInnerJoinWithoutEqualPredicate(leftStats, rightStats, join);
} else {
innerJoinStats = estimateHashJoin(leftStats, rightStats, join);
innerJoinStats = estimateInnerJoinWithEqualPredicate(leftStats, rightStats, join);
}

if (!join.getOtherJoinConjuncts().isEmpty()) {
Expand Down Expand Up @@ -267,7 +269,7 @@ private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics leftStat
}

private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics rightStats, Join join) {
if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join) || join.isMarkJoin()) {
if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join) || join.isMarkJoin()) {
double sel = join.isMarkJoin() ? 1.0 : computeSelectivityForBuildSideWhenColStatsUnknown(rightStats, join);
Statistics result;
if (join.getJoinType().isLeftSemiOrAntiJoin()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation;
import org.apache.doris.nereids.trees.plans.algebra.Filter;
import org.apache.doris.nereids.trees.plans.algebra.Generate;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.trees.plans.algebra.Limit;
import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
import org.apache.doris.nereids.trees.plans.algebra.PartitionTopN;
Expand Down Expand Up @@ -576,8 +577,8 @@ public Statistics visitLogicalPartitionTopN(LogicalPartitionTopN<? extends Plan>

@Override
public Statistics visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
Statistics joinStats = JoinEstimation.estimate(groupExpression.childStatistics(0),
groupExpression.childStatistics(1), join);
Statistics joinStats = computeJoin(join);
// TODO: make sure it is only applied here? inconsistent with physical join?
joinStats = new StatisticsBuilder(joinStats).setWidthInJoinCluster(
groupExpression.childStatistics(0).getWidthInJoinCluster()
+ groupExpression.childStatistics(1).getWidthInJoinCluster()).build();
Expand Down Expand Up @@ -721,16 +722,14 @@ public Statistics visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN
@Override
public Statistics visitPhysicalHashJoin(
PhysicalHashJoin<? extends Plan, ? extends Plan> hashJoin, Void context) {
return JoinEstimation.estimate(groupExpression.childStatistics(0),
groupExpression.childStatistics(1), hashJoin);
return computeJoin(hashJoin);
}

@Override
public Statistics visitPhysicalNestedLoopJoin(
PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin,
Void context) {
return JoinEstimation.estimate(groupExpression.childStatistics(0),
groupExpression.childStatistics(1), nestedLoopJoin);
return computeJoin(nestedLoopJoin);
}

// TODO: We should subtract those pruned column, and consider the expression transformations in the node.
Expand Down Expand Up @@ -1070,6 +1069,11 @@ private Statistics computeCatalogRelation(CatalogRelation catalogRelation) {
return builder.setRowCount(tableRowCount).build();
}

private Statistics computeJoin(Join join) {
return JoinEstimation.estimate(groupExpression.childStatistics(0),
groupExpression.childStatistics(1), join);
}

private Statistics computeTopN(TopN topN) {
Statistics stats = groupExpression.childStatistics(0);
return stats.withRowCountAndEnforceValid(Math.min(stats.getRowCount(), topN.getLimit()));
Expand Down

0 comments on commit b69ae32

Please sign in to comment.