From b69ae32c3231a7be62a3ab3e4c010b579f3eeeba Mon Sep 17 00:00:00 2001 From: "zhongjian.xzj" Date: Wed, 18 Sep 2024 17:48:20 +0800 Subject: [PATCH] [opt](nereids) refine operator estimation --- .../doris/nereids/stats/JoinEstimation.java | 22 ++++++++++--------- .../doris/nereids/stats/StatsCalculator.java | 16 +++++++++----- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java index 05d45c01b58fb83..bc8dbf2761bc947 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java @@ -47,7 +47,7 @@ public class JoinEstimation { private static double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.3; private static double UNKNOWN_COL_STATS_FILTER_SEL_LOWER_BOUND = 0.5; - private static EqualPredicate normalizeHashJoinCondition(EqualPredicate equal, Statistics leftStats, + private static EqualPredicate normalizeEqualPredJoinCondition(EqualPredicate equal, Statistics leftStats, Statistics rightStats) { boolean changeOrder = equal.left().getInputSlots().stream() .anyMatch(slot -> rightStats.findColumnStatistics(slot) != null); @@ -58,7 +58,7 @@ private static EqualPredicate normalizeHashJoinCondition(EqualPredicate equal, S } } - private static boolean hashJoinConditionContainsUnknownColumnStats(Statistics leftStats, + private static boolean joinConditionContainsUnknownColumnStats(Statistics leftStats, Statistics rightStats, Join join) { for (Expression expr : join.getEqualPredicates()) { for (Slot slot : expr.getInputSlots()) { @@ -74,7 +74,8 @@ private static boolean hashJoinConditionContainsUnknownColumnStats(Statistics le return false; } - private static Statistics estimateHashJoin(Statistics leftStats, Statistics rightStats, Join join) { + private static Statistics estimateInnerJoinWithEqualPredicate(Statistics leftStats, + Statistics rightStats, Join join) { /* * When we estimate filter A=B, * if any side of equation, A or B, is almost unique, the confidence level of estimation is high. @@ -94,7 +95,7 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ // since ndv is not accurate, if ndv/rowcount < almostUniqueThreshold, // this column is regarded as unique. double almostUniqueThreshold = 0.9; - EqualPredicate equal = normalizeHashJoinCondition(expression, leftStats, rightStats); + EqualPredicate equal = normalizeEqualPredJoinCondition(expression, leftStats, rightStats); ColumnStatistic eqLeftColStats = ExpressionEstimation.estimate(equal.left(), leftStats); ColumnStatistic eqRightColStats = ExpressionEstimation.estimate(equal.right(), rightStats); boolean trustable = eqRightColStats.ndv / rightStatsRowCount > almostUniqueThreshold @@ -148,8 +149,9 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ return innerJoinStats; } - private static Statistics estimateNestLoopJoin(Statistics leftStats, Statistics rightStats, Join join) { - if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) { + private static Statistics estimateInnerJoinWithoutEqualPredicate(Statistics leftStats, + Statistics rightStats, Join join) { + if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) { double rowCount = (leftStats.getRowCount() + rightStats.getRowCount()); // We do more like the nested loop join with one rows than inner join if (leftStats.getRowCount() == 1 || rightStats.getRowCount() == 1) { @@ -193,7 +195,7 @@ private static double computeSelectivityForBuildSideWhenColStatsUnknown(Statisti } private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rightStats, Join join) { - if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) { + if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) { double rowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount()); rowCount = Math.max(1, rowCount); return new StatisticsBuilder() @@ -205,9 +207,9 @@ private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rig Statistics innerJoinStats; if (join.getEqualPredicates().isEmpty()) { - innerJoinStats = estimateNestLoopJoin(leftStats, rightStats, join); + innerJoinStats = estimateInnerJoinWithoutEqualPredicate(leftStats, rightStats, join); } else { - innerJoinStats = estimateHashJoin(leftStats, rightStats, join); + innerJoinStats = estimateInnerJoinWithEqualPredicate(leftStats, rightStats, join); } if (!join.getOtherJoinConjuncts().isEmpty()) { @@ -267,7 +269,7 @@ private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics leftStat } private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics rightStats, Join join) { - if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join) || join.isMarkJoin()) { + if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join) || join.isMarkJoin()) { double sel = join.isMarkJoin() ? 1.0 : computeSelectivityForBuildSideWhenColStatsUnknown(rightStats, join); Statistics result; if (join.getJoinType().isLeftSemiOrAntiJoin()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index 1a661ce4cacd039..e913232d6c8e356 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -49,6 +49,7 @@ import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation; import org.apache.doris.nereids.trees.plans.algebra.Filter; import org.apache.doris.nereids.trees.plans.algebra.Generate; +import org.apache.doris.nereids.trees.plans.algebra.Join; import org.apache.doris.nereids.trees.plans.algebra.Limit; import org.apache.doris.nereids.trees.plans.algebra.OlapScan; import org.apache.doris.nereids.trees.plans.algebra.PartitionTopN; @@ -576,8 +577,8 @@ public Statistics visitLogicalPartitionTopN(LogicalPartitionTopN @Override public Statistics visitLogicalJoin(LogicalJoin join, Void context) { - Statistics joinStats = JoinEstimation.estimate(groupExpression.childStatistics(0), - groupExpression.childStatistics(1), join); + Statistics joinStats = computeJoin(join); + // TODO: make sure it is only applied here? inconsistent with physical join? joinStats = new StatisticsBuilder(joinStats).setWidthInJoinCluster( groupExpression.childStatistics(0).getWidthInJoinCluster() + groupExpression.childStatistics(1).getWidthInJoinCluster()).build(); @@ -721,16 +722,14 @@ public Statistics visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN @Override public Statistics visitPhysicalHashJoin( PhysicalHashJoin hashJoin, Void context) { - return JoinEstimation.estimate(groupExpression.childStatistics(0), - groupExpression.childStatistics(1), hashJoin); + return computeJoin(hashJoin); } @Override public Statistics visitPhysicalNestedLoopJoin( PhysicalNestedLoopJoin nestedLoopJoin, Void context) { - return JoinEstimation.estimate(groupExpression.childStatistics(0), - groupExpression.childStatistics(1), nestedLoopJoin); + return computeJoin(nestedLoopJoin); } // TODO: We should subtract those pruned column, and consider the expression transformations in the node. @@ -1070,6 +1069,11 @@ private Statistics computeCatalogRelation(CatalogRelation catalogRelation) { return builder.setRowCount(tableRowCount).build(); } + private Statistics computeJoin(Join join) { + return JoinEstimation.estimate(groupExpression.childStatistics(0), + groupExpression.childStatistics(1), join); + } + private Statistics computeTopN(TopN topN) { Statistics stats = groupExpression.childStatistics(0); return stats.withRowCountAndEnforceValid(Math.min(stats.getRowCount(), topN.getLimit()));