From 0f14a992a9d9500916d0bc10afc6c8e521943ff6 Mon Sep 17 00:00:00 2001 From: "zhongjian.xzj" Date: Tue, 24 Sep 2024 15:33:10 +0800 Subject: [PATCH] [opt](nereids) refine operator estimation --- .../exploration/mv/MaterializedViewUtils.java | 2 +- .../nereids/stats/ExpressionEstimation.java | 8 +- .../doris/nereids/stats/FilterEstimation.java | 464 ++++++++++-------- .../doris/nereids/stats/JoinEstimation.java | 169 +++++-- .../doris/nereids/stats/StatsCalculator.java | 25 +- .../doris/nereids/stats/StatsMathUtil.java | 5 + .../doris/nereids/trees/plans/JoinType.java | 2 +- .../doris/statistics/ColumnStatistic.java | 2 +- .../doris/statistics/StatisticRange.java | 56 +-- .../apache/doris/statistics/Statistics.java | 60 ++- 10 files changed, 461 insertions(+), 332 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtils.java index 6af72b1e81db3f4..342c88ff677ff09 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtils.java @@ -414,7 +414,7 @@ public Void visitLogicalJoin(LogicalJoin join, if (joinType.isInnerJoin() || joinType.isCrossJoin()) { return visit(join, context); } else if ((joinType.isLeftJoin() - || joinType.isLefSemiJoin() + || joinType.isLeftSemiJoin() || joinType.isLeftAntiJoin()) && useLeft) { return visit(join.left(), context); } else if ((joinType.isRightJoin() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java index 126e90417213125..0a2bda874c1c090 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java @@ -242,9 +242,9 @@ public ColumnStatistic visitLiteral(Literal literal, Statistics context) { return new ColumnStatisticBuilder() .setMaxValue(literalVal) .setMinValue(literalVal) - .setNdv(1) + .setNdv(literal.isNullLiteral() ? 0 : 1) .setNumNulls(literal.isNullLiteral() ? 1 : 0) - .setAvgSizeByte(1) + .setAvgSizeByte(literal.getDataType().width()) .setMinExpr(literal.toLegacyLiteral()) .setMaxExpr(literal.toLegacyLiteral()) .build(); @@ -579,7 +579,7 @@ public ColumnStatistic visitToDate(ToDate toDate, Statistics context) { ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder(childColumnStats) .setAvgSizeByte(toDate.getDataType().width()) .setDataSize(toDate.getDataType().width() * context.getRowCount()); - if (childColumnStats.minOrMaxIsInf()) { + if (childColumnStats.isMinMaxInvalid()) { return columnStatisticBuilder.build(); } double minValue; @@ -610,7 +610,7 @@ public ColumnStatistic visitToDays(ToDays toDays, Statistics context) { ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder(childColumnStats) .setAvgSizeByte(toDays.getDataType().width()) .setDataSize(toDays.getDataType().width() * context.getRowCount()); - if (childColumnStats.minOrMaxIsInf()) { + if (childColumnStats.isMinMaxInvalid()) { return columnStatisticBuilder.build(); } double minValue; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java index b3576a0e58e61e6..20617deef83a924 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java @@ -103,7 +103,7 @@ public Statistics estimate(Expression expression, Statistics inputStats) { } outputStats = expression.accept(this, new EstimationContext(deltaStats.build())); } - outputStats.enforceValid(); + outputStats.normalizeColumnStatistics(); return outputStats; } @@ -117,13 +117,15 @@ public Statistics visitCompoundPredicate(CompoundPredicate predicate, Estimation Expression leftExpr = predicate.child(0); Expression rightExpr = predicate.child(1); Statistics leftStats = leftExpr.accept(this, context); - leftStats = leftStats.normalizeByRatio(context.statistics.getRowCount()); - Statistics andStats = rightExpr.accept(this, - new EstimationContext(leftStats)); + leftStats.normalizeColumnStatistics(context.statistics.getRowCount(), true); + Statistics andStats = rightExpr.accept(this, new EstimationContext(leftStats)); if (predicate instanceof And) { + // TODO: this will cause estimation change + //andStats.normalizeColumnStatistics(); return andStats; } else if (predicate instanceof Or) { Statistics rightStats = rightExpr.accept(this, context); + rightStats.normalizeColumnStatistics(context.statistics.getRowCount(), true); double rowCount = leftStats.getRowCount() + rightStats.getRowCount() - andStats.getRowCount(); Statistics orStats = context.statistics.withRowCount(rowCount); Set leftInputSlots = leftExpr.getInputSlots(); @@ -171,42 +173,32 @@ public Statistics visitComparisonPredicate(ComparisonPredicate cp, EstimationCon ColumnStatistic statsForLeft = ExpressionEstimation.estimate(left, context.statistics); ColumnStatistic statsForRight = ExpressionEstimation.estimate(right, context.statistics); if (!left.isConstant() && !right.isConstant()) { - return calculateWhenBothColumn(cp, context, statsForLeft, statsForRight); + return estimateColumnToColumn(cp, context, statsForLeft, statsForRight); } else { - // For literal, it's max min is same value. - return calculateWhenLiteralRight(cp, - statsForLeft, - statsForRight, - context); + return estimateColumnToConstant(cp, statsForLeft, statsForRight, context); } } - private Statistics updateLessThanLiteral(Expression leftExpr, DataType dataType, ColumnStatistic statsForLeft, - ColumnStatistic statsForRight, EstimationContext context) { - StatisticRange rightRange = new StatisticRange(statsForLeft.minValue, statsForLeft.minExpr, - statsForRight.maxValue, statsForRight.maxExpr, - statsForLeft.ndv, dataType); - return estimateBinaryComparisonFilter(leftExpr, dataType, - statsForLeft, - rightRange, context); + private Statistics estimateColumnLessThanConstant(Expression leftExpr, DataType dataType, + ColumnStatistic statsForLeft, ColumnStatistic statsForRight, EstimationContext context) { + StatisticRange constantRange = new StatisticRange(statsForLeft.minValue, statsForLeft.minExpr, + statsForRight.maxValue, statsForRight.maxExpr, statsForLeft.ndv, dataType); + return estimateColumnToConstantRange(leftExpr, dataType, statsForLeft, constantRange, context); } - private Statistics updateGreaterThanLiteral(Expression leftExpr, DataType dataType, ColumnStatistic statsForLeft, - ColumnStatistic statsForRight, EstimationContext context) { - StatisticRange rightRange = new StatisticRange(statsForRight.minValue, statsForRight.minExpr, - statsForLeft.maxValue, statsForLeft.maxExpr, - statsForLeft.ndv, dataType); - return estimateBinaryComparisonFilter(leftExpr, dataType, statsForLeft, rightRange, context); + private Statistics estimateColumnGreaterThanConstant(Expression leftExpr, DataType dataType, + ColumnStatistic statsForLeft, ColumnStatistic statsForRight, EstimationContext context) { + StatisticRange constantRange = new StatisticRange(statsForRight.minValue, statsForRight.minExpr, + statsForLeft.maxValue, statsForLeft.maxExpr, statsForLeft.ndv, dataType); + return estimateColumnToConstantRange(leftExpr, dataType, statsForLeft, constantRange, context); } - private Statistics calculateWhenLiteralRight(ComparisonPredicate cp, + private Statistics estimateColumnToConstant(ComparisonPredicate cp, ColumnStatistic statsForLeft, ColumnStatistic statsForRight, EstimationContext context) { if (statsForLeft.isUnKnown) { return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT); - } - - if (cp instanceof EqualPredicate) { - return estimateEqualTo(cp, statsForLeft, statsForRight, context); + } else if (cp instanceof EqualPredicate) { + return estimateColumnEqualToConstant(cp, statsForLeft, statsForRight, context); } else { // literal Map used to covert dateLiteral back to stringLiteral Map literalMap = new HashMap<>(); @@ -229,12 +221,13 @@ private Statistics calculateWhenLiteralRight(ComparisonPredicate cp, statsForLeftMayConverted = statsForLeftMayConvertedOpt.get(); statsForRightMayConverted = statsForRightMayConvertedOpt.get(); } - Statistics result = null; + + Statistics result; if (cp instanceof LessThan || cp instanceof LessThanEqual) { - result = updateLessThanLiteral(cp.left(), compareType, statsForLeftMayConverted, + result = estimateColumnLessThanConstant(cp.left(), compareType, statsForLeftMayConverted, statsForRightMayConverted, context); } else if (cp instanceof GreaterThan || cp instanceof GreaterThanEqual) { - result = updateGreaterThanLiteral(cp.left(), compareType, statsForLeftMayConverted, + result = estimateColumnGreaterThanConstant(cp.left(), compareType, statsForLeftMayConverted, statsForRightMayConverted, context); } else { throw new RuntimeException(String.format("Unexpected expression : %s", cp.toSql())); @@ -315,7 +308,7 @@ private Optional tryConvertStrLiteralToDateLiteral(LiteralExpr lite return dt == null ? Optional.empty() : Optional.of(dt); } - private Statistics estimateEqualTo(ComparisonPredicate cp, ColumnStatistic statsForLeft, + private Statistics estimateColumnEqualToConstant(ComparisonPredicate cp, ColumnStatistic statsForLeft, ColumnStatistic statsForRight, EstimationContext context) { double selectivity; @@ -334,6 +327,8 @@ private Statistics estimateEqualTo(ComparisonPredicate cp, ColumnStatistic stats } else { double val = statsForRight.maxValue; if (val > statsForLeft.maxValue || val < statsForLeft.minValue) { + // FIXME: make sure left's stats is RangeScalable whose min/max is trustable + // The equal to constant which don't rely on the range, so maybe safe here. selectivity = 0.0; } else if (ndv >= 1.0) { selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv); @@ -350,113 +345,132 @@ private Statistics estimateEqualTo(ComparisonPredicate cp, ColumnStatistic stats if (!(left instanceof SlotReference)) { left.accept(new ColumnStatsAdjustVisitor(), equalStats); } + // TODO: normalizeColumnStatistics() will have problem after ColumnStatsAdjustVisitor + //equalStats.normalizeColumnStatistics(); return equalStats; } - private Statistics calculateWhenBothColumn(ComparisonPredicate cp, EstimationContext context, + private Statistics estimateColumnToColumn(ComparisonPredicate cp, EstimationContext context, ColumnStatistic statsForLeft, ColumnStatistic statsForRight) { Expression left = cp.left(); Expression right = cp.right(); if (cp instanceof EqualPredicate) { return estimateColumnEqualToColumn(left, statsForLeft, right, statsForRight, cp instanceof NullSafeEqual, context); - } - if (cp instanceof GreaterThan || cp instanceof GreaterThanEqual) { + } else if (cp instanceof GreaterThan || cp instanceof GreaterThanEqual) { return estimateColumnLessThanColumn(right, statsForRight, left, statsForLeft, context); - } - if (cp instanceof LessThan || cp instanceof LessThanEqual) { + } else if (cp instanceof LessThan || cp instanceof LessThanEqual) { return estimateColumnLessThanColumn(left, statsForLeft, right, statsForRight, context); + } else { + return context.statistics; } - return context.statistics; } - @Override - public Statistics visitInPredicate(InPredicate inPredicate, EstimationContext context) { - Expression compareExpr = inPredicate.getCompareExpr(); - ColumnStatistic compareExprStats = ExpressionEstimation.estimate(compareExpr, context.statistics); - if (compareExprStats.isUnKnown || compareExpr instanceof Function) { - return context.statistics.withSel(DEFAULT_IN_COEFFICIENT); - } + private ColumnStatistic updateInPredicateColumnStatistics(InPredicate inPredicate, EstimationContext context, + ColumnStatistic compareExprStats) { List options = inPredicate.getOptions(); - // init minOption and maxOption by compareExpr.max and compareExpr.min respectively, - // and then adjust min/max by options - double minOptionValue = compareExprStats.maxValue; - double maxOptionValue = compareExprStats.minValue; - LiteralExpr minOptionLiteral = compareExprStats.maxExpr; - LiteralExpr maxOptionLiteral = compareExprStats.minExpr; - /* suppose A.(min, max) = (0, 10), A.ndv=10 - A in ( 1, 2, 5, 100): - validInOptCount = 3, that is (1, 2, 5) - table selectivity = 3/10 - A.min = 1, A.max=5 - A.selectivity = 3/5 - A.ndv = 3 - A not in (1, 2, 3, 100): - validInOptCount = 10 - 3 - we assume that 1, 2, 3 exist in A - A.ndv = 10 - 3 = 7 - table selectivity = 7/10 - A.(min, max) not changed - A.selectivity = 7/10 - */ - int validInOptCount = 0; - double selectivity = 1.0; ColumnStatisticBuilder compareExprStatsBuilder = new ColumnStatisticBuilder(compareExprStats); - int nonLiteralOptionCount = 0; - for (Expression option : options) { - ColumnStatistic optionStats = ExpressionEstimation.estimate(option, context.statistics); - if (option instanceof Literal) { - // remove the options which is out of compareExpr.range - if (compareExprStats.minValue <= optionStats.maxValue - && optionStats.minValue <= compareExprStats.maxValue) { - validInOptCount++; - LiteralExpr optionLiteralExpr = ((Literal) option).toLegacyLiteral(); - if (maxOptionLiteral == null || optionLiteralExpr.compareTo(maxOptionLiteral) >= 0) { - maxOptionLiteral = optionLiteralExpr; - maxOptionValue = optionStats.maxValue; - } - if (minOptionLiteral == null || optionLiteralExpr.compareTo(minOptionLiteral) <= 0) { - minOptionLiteral = optionLiteralExpr; - minOptionValue = optionStats.minValue; + if (!compareExprStats.isMinMaxInvalid()) { + // init minOption and maxOption by compareExpr.max and compareExpr.min respectively, + // and then adjust min/max by options + double minOptionValue = compareExprStats.maxValue; + double maxOptionValue = compareExprStats.minValue; + LiteralExpr minOptionLiteral = compareExprStats.maxExpr; + LiteralExpr maxOptionLiteral = compareExprStats.minExpr; + /* suppose A.(min, max) = (0, 10), A.ndv=10 + A in ( 1, 2, 5, 100): + validInOptCount = 3, that is (1, 2, 5) + table selectivity = 3/10 + A.min = 1, A.max=5 + A.selectivity = 3/5 + A.ndv = 3 + A not in (1, 2, 3, 100): + validInOptCount = 10 - 3 + we assume that 1, 2, 3 exist in A + A.ndv = 10 - 3 = 7 + table selectivity = 7/10 + A.(min, max) not changed + A.selectivity = 7/10 + */ + int validInOptCount = 0; + int nonLiteralOptionCount = 0; + for (Expression option : options) { + ColumnStatistic optionStats = ExpressionEstimation.estimate(option, context.statistics); + if (option instanceof Literal) { + // remove the options which is out of compareExpr.range + Preconditions.checkState(Math.abs(optionStats.maxValue - optionStats.minValue) < 1e-06, + "literal's min/max doesn't equal"); + double constValue = optionStats.maxValue; + if (compareExprStats.minValue <= constValue && compareExprStats.maxValue >= constValue) { + validInOptCount++; + LiteralExpr optionLiteralExpr = ((Literal) option).toLegacyLiteral(); + if (maxOptionLiteral == null || optionLiteralExpr.compareTo(maxOptionLiteral) >= 0) { + maxOptionLiteral = optionLiteralExpr; + maxOptionValue = constValue; + } + + if (minOptionLiteral == null || optionLiteralExpr.compareTo(minOptionLiteral) <= 0) { + minOptionLiteral = optionLiteralExpr; + minOptionValue = constValue; + } } + } else { + nonLiteralOptionCount++; } - } else { - nonLiteralOptionCount++; } - } - if (nonLiteralOptionCount > 0) { - // A in (x+1, ...) - // "x+1" is not literal, and if const-fold can not handle it, it blocks estimation of min/max value. - // and hence, we do not adjust compareExpr.stats.range. - int newNdv = nonLiteralOptionCount + validInOptCount; - if (newNdv < compareExprStats.ndv) { - compareExprStatsBuilder.setNdv(newNdv); - selectivity = StatsMathUtil.divide(newNdv, compareExprStats.ndv); + if (nonLiteralOptionCount > 0) { + // A in (x+1, ...) + // "x+1" is not literal, and if const-fold can not handle it, it blocks estimation of min/max value. + // and hence, we do not adjust compareExpr.stats.range. + int newNdv = nonLiteralOptionCount + validInOptCount; + if (newNdv < compareExprStats.ndv) { + compareExprStatsBuilder.setNdv(newNdv); + } } else { - selectivity = 1.0; + maxOptionValue = Math.min(maxOptionValue, compareExprStats.maxValue); + minOptionValue = Math.max(minOptionValue, compareExprStats.minValue); + compareExprStatsBuilder.setMaxValue(maxOptionValue); + compareExprStatsBuilder.setMaxExpr(maxOptionLiteral); + compareExprStatsBuilder.setMinValue(minOptionValue); + compareExprStatsBuilder.setMinExpr(minOptionLiteral); + if (validInOptCount < compareExprStats.ndv) { + compareExprStatsBuilder.setNdv(validInOptCount); + } } } else { - maxOptionValue = Math.min(maxOptionValue, compareExprStats.maxValue); - minOptionValue = Math.max(minOptionValue, compareExprStats.minValue); - compareExprStatsBuilder.setMaxValue(maxOptionValue); - compareExprStatsBuilder.setMaxExpr(maxOptionLiteral); - compareExprStatsBuilder.setMinValue(minOptionValue); - compareExprStatsBuilder.setMinExpr(minOptionLiteral); - if (validInOptCount < compareExprStats.ndv) { - compareExprStatsBuilder.setNdv(validInOptCount); - selectivity = StatsMathUtil.divide(validInOptCount, compareExprStats.ndv); - } else { - selectivity = 1.0; - } + // other types, such as string type, using option's size to estimate + // min/max will not be updated + compareExprStatsBuilder.setNdv(Math.min(options.size(), compareExprStats.getOriginalNdv())); } compareExprStatsBuilder.setNumNulls(0); + return compareExprStatsBuilder.build(); + } + + @Override + public Statistics visitInPredicate(InPredicate inPredicate, EstimationContext context) { + Expression compareExpr = inPredicate.getCompareExpr(); + ColumnStatistic compareExprStats = ExpressionEstimation.estimate(compareExpr, context.statistics); + if (compareExprStats.isUnKnown || compareExpr instanceof Function) { + return context.statistics.withSel(DEFAULT_IN_COEFFICIENT); + } + + List options = inPredicate.getOptions(); + ColumnStatistic newCompareExprStats = updateInPredicateColumnStatistics(inPredicate, context, compareExprStats); + double selectivity; + if (!newCompareExprStats.isMinMaxInvalid()) { + selectivity = Statistics.getValidSelectivity( + Math.min(StatsMathUtil.divide(newCompareExprStats.ndv, compareExprStats.ndv), 1)); + } else { + selectivity = Statistics.getValidSelectivity( + Math.min(options.size() / compareExprStats.getOriginalNdv(), 1)); + } + Statistics estimated = new StatisticsBuilder(context.statistics).build(); - ColumnStatistic stats = compareExprStatsBuilder.build(); - selectivity = getNotNullSelectivity(compareExprStats.numNulls, estimated.getRowCount(), - compareExprStats.ndv, selectivity); + selectivity = getNotNullSelectivity(newCompareExprStats.numNulls, estimated.getRowCount(), + newCompareExprStats.ndv, selectivity); estimated = estimated.withSel(selectivity); - estimated.addColumnStats(compareExpr, stats); + estimated.addColumnStats(compareExpr, newCompareExprStats); context.addKeyIfSlot(compareExpr); return estimated; } @@ -473,6 +487,7 @@ public Statistics visitNot(Not not, EstimationContext context) { } Expression child = not.child(); Statistics childStats = child.accept(this, context); + childStats.normalizeColumnStatistics(); //if estimated rowCount is 0, adjust to 1 to make upper join reorder reasonable. double rowCount = Math.max(context.statistics.getRowCount() - childStats.getRowCount(), 1); StatisticsBuilder statisticsBuilder = new StatisticsBuilder(context.statistics).setRowCount(rowCount); @@ -496,7 +511,6 @@ public Statistics visitNot(Not not, EstimationContext context) { || child instanceof Match, "Not-predicate meet unexpected child: %s", child.toSql()); if (child instanceof Like) { - rowCount = context.statistics.getRowCount() - childStats.getRowCount(); colBuilder.setNdv(Math.max(1.0, originColStats.ndv - childColStats.ndv)); } else if (child instanceof InPredicate) { colBuilder.setNdv(Math.max(1.0, originColStats.ndv - childColStats.ndv)); @@ -517,7 +531,6 @@ public Statistics visitNot(Not not, EstimationContext context) { .setMaxValue(originColStats.maxValue) .setMaxExpr(originColStats.maxExpr); } else if (child instanceof Match) { - rowCount = context.statistics.getRowCount() - childStats.getRowCount(); colBuilder.setNdv(Math.max(1.0, originColStats.ndv - childColStats.ndv)); } if (not.child().getInputSlots().size() == 1 && !(child instanceof IsNull)) { @@ -539,15 +552,18 @@ public Statistics visitIsNull(IsNull isNull, EstimationContext context) { double row = context.statistics.getRowCount() * DEFAULT_ISNULL_SELECTIVITY; return new StatisticsBuilder(context.statistics).setRowCount(row).build(); } - double outputRowCount = childColStats.numNulls; + double childOutputRowCount = context.statistics.getRowCount(); + double outputRowCount = Math.min(childColStats.numNulls, childOutputRowCount); if (!isOnBaseTable) { // for is null on base table, use the numNulls, otherwise // nulls will be generated such as outer join and then we do a protection Expression child = isNull.child(); Statistics childStats = child.accept(this, context); + childStats.normalizeColumnStatistics(); outputRowCount = Math.max(childStats.getRowCount() * DEFAULT_ISNULL_SELECTIVITY, outputRowCount); outputRowCount = Math.max(outputRowCount, 1); } + ColumnStatisticBuilder colBuilder = new ColumnStatisticBuilder(childColStats); colBuilder.setNumNulls(outputRowCount) .setMaxValue(Double.POSITIVE_INFINITY) @@ -583,17 +599,18 @@ public boolean isKeySlot(Expression expr) { } } - private Statistics estimateBinaryComparisonFilter(Expression leftExpr, DataType dataType, ColumnStatistic leftStats, + private Statistics estimateColumnToConstantRange(Expression leftExpr, DataType dataType, ColumnStatistic leftStats, StatisticRange rightRange, EstimationContext context) { - StatisticRange leftRange = - new StatisticRange(leftStats.minValue, leftStats.minExpr, leftStats.maxValue, leftStats.maxExpr, - leftStats.ndv, dataType); - StatisticRange intersectRange = leftRange.cover(rightRange); - + StatisticRange leftRange = new StatisticRange(leftStats.minValue, leftStats.minExpr, + leftStats.maxValue, leftStats.maxExpr, leftStats.ndv, dataType); ColumnStatisticBuilder leftColumnStatisticBuilder; Statistics updatedStatistics; + + StatisticRange intersectRange = leftRange.intersect(rightRange); + double sel = leftRange.getDistinctValues() == 0 + ? 1.0 + : intersectRange.getDistinctValues() / leftRange.getDistinctValues(); if (intersectRange.isEmpty()) { - updatedStatistics = context.statistics.withRowCount(0); leftColumnStatisticBuilder = new ColumnStatisticBuilder(leftStats) .setMinValue(Double.NEGATIVE_INFINITY) .setMinExpr(null) @@ -601,7 +618,8 @@ private Statistics estimateBinaryComparisonFilter(Expression leftExpr, DataType .setMaxExpr(null) .setNdv(0) .setNumNulls(0); - } else { + updatedStatistics = context.statistics.withRowCount(0); + } else if (dataType instanceof RangeScalable || sel == 0 || sel == 1) { leftColumnStatisticBuilder = new ColumnStatisticBuilder(leftStats) .setMinValue(intersectRange.getLow()) .setMinExpr(intersectRange.getLowExpr()) @@ -609,42 +627,63 @@ private Statistics estimateBinaryComparisonFilter(Expression leftExpr, DataType .setMaxExpr(intersectRange.getHighExpr()) .setNdv(intersectRange.getDistinctValues()) .setNumNulls(0); - double sel = leftRange.getDistinctValues() == 0 - ? 1.0 - : intersectRange.getDistinctValues() / leftRange.getDistinctValues(); - if (!(dataType instanceof RangeScalable) && (sel != 0.0 && sel != 1.0)) { - sel = DEFAULT_INEQUALITY_COEFFICIENT; - } else { - sel = Math.max(sel, RANGE_SELECTIVITY_THRESHOLD); - } + sel = Math.max(sel, RANGE_SELECTIVITY_THRESHOLD); sel = getNotNullSelectivity(leftStats.numNulls, context.statistics.getRowCount(), leftStats.ndv, sel); updatedStatistics = context.statistics.withSel(sel); + } else { + sel = DEFAULT_INEQUALITY_COEFFICIENT; + sel = getNotNullSelectivity(leftStats.numNulls, context.statistics.getRowCount(), leftStats.ndv, sel); + leftColumnStatisticBuilder = new ColumnStatisticBuilder(leftStats) + .setMinValue(intersectRange.getLow()) + .setMinExpr(intersectRange.getLowExpr()) + .setMaxValue(intersectRange.getHigh()) + .setMaxExpr(intersectRange.getHighExpr()) + .setNdv(Math.max(1, Math.min(leftStats.ndv * sel, intersectRange.getDistinctValues()))) + .setNumNulls(0); + updatedStatistics = context.statistics.withSel(sel); } updatedStatistics.addColumnStats(leftExpr, leftColumnStatisticBuilder.build()); context.addKeyIfSlot(leftExpr); leftExpr.accept(new ColumnStatsAdjustVisitor(), updatedStatistics); + // TODO: normalizeColumnStatistics() will have problem after ColumnStatsAdjustVisitor + //updatedStatistics.normalizeColumnStatistics(); + return updatedStatistics; } private Statistics estimateColumnEqualToColumn(Expression leftExpr, ColumnStatistic leftStats, Expression rightExpr, ColumnStatistic rightStats, boolean keepNull, EstimationContext context) { + ColumnStatisticBuilder intersectBuilder = new ColumnStatisticBuilder(leftStats); StatisticRange leftRange = StatisticRange.from(leftStats, leftExpr.getDataType()); StatisticRange rightRange = StatisticRange.from(rightStats, rightExpr.getDataType()); - StatisticRange leftIntersectRight = leftRange.intersect(rightRange); - StatisticRange intersect = rightRange.intersect(leftIntersectRight); - ColumnStatisticBuilder intersectBuilder = new ColumnStatisticBuilder(leftStats); - intersectBuilder.setNdv(intersect.getDistinctValues()); + StatisticRange intersect = leftRange.intersect(rightRange); intersectBuilder.setMinValue(intersect.getLow()); intersectBuilder.setMaxValue(intersect.getHigh()); - double numNull = 0; - if (keepNull) { - numNull = Math.min(leftStats.numNulls, rightStats.numNulls); + + if (leftExpr.getDataType() instanceof RangeScalable && rightExpr.getDataType() instanceof RangeScalable + && !leftStats.isMinMaxInvalid() && !rightStats.isMinMaxInvalid()) { + intersectBuilder.setNdv(intersect.getDistinctValues()); + } else { + // intersect ndv uses min ndv but selectivity computing use the max + intersectBuilder.setNdv(Math.min(leftStats.ndv, rightStats.ndv)); } + double numNull = keepNull ? Math.min(leftStats.numNulls, rightStats.numNulls) : 0; intersectBuilder.setNumNulls(numNull); - double sel = 1 / StatsMathUtil.nonZeroDivisor(Math.max(leftStats.ndv, rightStats.ndv)); - Statistics updatedStatistics = context.statistics.withSel(sel, numNull); - updatedStatistics.addColumnStats(leftExpr, intersectBuilder.build()); - updatedStatistics.addColumnStats(rightExpr, intersectBuilder.build()); + + // TODO: consider notNullSelectivity + //double origRowCount = context.statistics.getRowCount(); + double leftNotNullSel = 1.0; //Statistics.getValidSelectivity(1 - (leftStats.numNulls / origRowCount)); + double rightNotNullSel = 1.0; //Statistics.getValidSelectivity(1 - (rightStats.numNulls / origRowCount)); + double notNullSel = 1 / StatsMathUtil.nonZeroDivisor(Math.max(leftStats.ndv, rightStats.ndv)) + * (keepNull ? 1 : leftNotNullSel * rightNotNullSel); + + Statistics updatedStatistics = context.statistics.withSel(notNullSel, numNull); + ColumnStatistic newLeftStatistics = intersectBuilder + .setAvgSizeByte(leftStats.avgSizeByte).build(); + ColumnStatistic newRightStatistics = intersectBuilder + .setAvgSizeByte(rightStats.avgSizeByte).build(); + updatedStatistics.addColumnStats(leftExpr, newLeftStatistics); + updatedStatistics.addColumnStats(rightExpr, newRightStatistics); context.addKeyIfSlot(leftExpr); context.addKeyIfSlot(rightExpr); @@ -655,67 +694,90 @@ private Statistics estimateColumnLessThanColumn(Expression leftExpr, ColumnStati Expression rightExpr, ColumnStatistic rightStats, EstimationContext context) { StatisticRange leftRange = StatisticRange.from(leftStats, leftExpr.getDataType()); StatisticRange rightRange = StatisticRange.from(rightStats, rightExpr.getDataType()); - Statistics statistics = null; - // Left always less than Right - if (leftRange.getHigh() < rightRange.getLow()) { - statistics = - context.statistics.withRowCount(Math.min(context.statistics.getRowCount() - leftStats.numNulls, - context.statistics.getRowCount() - rightStats.numNulls)); - statistics.addColumnStats(leftExpr, new ColumnStatisticBuilder(leftStats).setNumNulls(0.0).build()); - statistics.addColumnStats(rightExpr, new ColumnStatisticBuilder(rightStats).setNumNulls(0.0).build()); - context.addKeyIfSlot(leftExpr); - context.addKeyIfSlot(rightExpr); - return statistics; - } - if (leftRange.isInfinite() || rightRange.isInfinite()) { - return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT); - } + StatisticRange intersect = leftRange.intersect(rightRange); + + if (leftExpr.getDataType() instanceof RangeScalable && rightExpr.getDataType() instanceof RangeScalable + && !leftStats.isMinMaxInvalid() && !rightStats.isMinMaxInvalid()) { + // TODO: use intersect interface to refine this to avoid this kind of left-dominating style + Statistics statistics; + // Left always less than Right + if (leftRange.getHigh() < rightRange.getLow()) { + statistics = + context.statistics.withRowCount(Math.min(context.statistics.getRowCount() - leftStats.numNulls, + context.statistics.getRowCount() - rightStats.numNulls)); + statistics.addColumnStats(leftExpr, new ColumnStatisticBuilder(leftStats).setNumNulls(0.0).build()); + statistics.addColumnStats(rightExpr, new ColumnStatisticBuilder(rightStats).setNumNulls(0.0).build()); + context.addKeyIfSlot(leftExpr); + context.addKeyIfSlot(rightExpr); + return statistics; + } - double leftOverlapPercent = leftRange.overlapPercentWith(rightRange); + double leftOverlapPercent = leftRange.overlapPercentWith(rightRange); - if (leftOverlapPercent == 0.0) { - // Left always greater than right - return context.statistics.withRowCount(0.0); - } - StatisticRange leftAlwaysLessThanRightRange = new StatisticRange(leftStats.minValue, leftStats.minExpr, - rightStats.minValue, rightStats.minExpr, Double.NaN, leftExpr.getDataType()); - double leftAlwaysLessThanRightPercent = 0; - if (leftRange.getLow() < rightRange.getLow()) { - leftAlwaysLessThanRightPercent = leftRange.overlapPercentWith(leftAlwaysLessThanRightRange); - } - ColumnStatistic leftColumnStatistic = new ColumnStatisticBuilder(leftStats) - .setMaxValue(Math.min(leftRange.getHigh(), rightRange.getHigh())) - .setMinValue(leftRange.getLow()) - .setNdv(leftStats.ndv * (leftAlwaysLessThanRightPercent + leftOverlapPercent)) - .setNumNulls(0) - .build(); - double rightOverlappingRangeFraction = rightRange.overlapPercentWith(leftRange); - double rightAlwaysGreaterRangeFraction = 0; - if (leftRange.getHigh() < rightRange.getHigh()) { - rightAlwaysGreaterRangeFraction = rightRange.overlapPercentWith(new StatisticRange( - leftRange.getHigh(), leftRange.getHighExpr(), - rightRange.getHigh(), rightRange.getHighExpr(), - Double.NaN, rightExpr.getDataType())); - } - ColumnStatistic rightColumnStatistic = new ColumnStatisticBuilder(rightStats) - .setMinValue(Math.max(leftRange.getLow(), rightRange.getLow())) - .setMaxValue(rightRange.getHigh()) - .setNdv(rightStats.ndv * (rightAlwaysGreaterRangeFraction + rightOverlappingRangeFraction)) - .setNumNulls(0) - .build(); - double sel = DEFAULT_INEQUALITY_COEFFICIENT; - if (leftExpr.getDataType() instanceof RangeScalable) { - sel = leftAlwaysLessThanRightPercent - + leftOverlapPercent * rightOverlappingRangeFraction * DEFAULT_INEQUALITY_COEFFICIENT - + leftOverlapPercent * rightAlwaysGreaterRangeFraction; - } else if (leftOverlapPercent == 1.0) { - sel = 1.0; + if (leftOverlapPercent == 0.0) { + // Left always greater than right + return context.statistics.withRowCount(0.0); + } + StatisticRange leftAlwaysLessThanRightRange = new StatisticRange(leftStats.minValue, leftStats.minExpr, + rightStats.minValue, rightStats.minExpr, Double.NaN, leftExpr.getDataType()); + double leftAlwaysLessThanRightPercent = 0; + if (leftRange.getLow() < rightRange.getLow()) { + leftAlwaysLessThanRightPercent = leftRange.overlapPercentWith(leftAlwaysLessThanRightRange); + } + ColumnStatistic leftColumnStatistic = new ColumnStatisticBuilder(leftStats) + .setMaxValue(Math.min(leftRange.getHigh(), rightRange.getHigh())) + .setMinValue(leftRange.getLow()) + .setNdv(leftStats.ndv * (leftAlwaysLessThanRightPercent + leftOverlapPercent)) + .setNumNulls(0) + .build(); + double rightOverlappingRangeFraction = rightRange.overlapPercentWith(leftRange); + double rightAlwaysGreaterRangeFraction = 0; + if (leftRange.getHigh() < rightRange.getHigh()) { + rightAlwaysGreaterRangeFraction = rightRange.overlapPercentWith(new StatisticRange( + leftRange.getHigh(), leftRange.getHighExpr(), + rightRange.getHigh(), rightRange.getHighExpr(), + Double.NaN, rightExpr.getDataType())); + } + ColumnStatistic rightColumnStatistic = new ColumnStatisticBuilder(rightStats) + .setMinValue(Math.max(leftRange.getLow(), rightRange.getLow())) + .setMaxValue(rightRange.getHigh()) + .setNdv(rightStats.ndv * (rightAlwaysGreaterRangeFraction + rightOverlappingRangeFraction)) + .setNumNulls(0) + .build(); + double sel; + if (leftExpr.getDataType() instanceof RangeScalable) { + sel = leftAlwaysLessThanRightPercent + + leftOverlapPercent * rightOverlappingRangeFraction * DEFAULT_INEQUALITY_COEFFICIENT + + leftOverlapPercent * rightAlwaysGreaterRangeFraction; + } else if (leftOverlapPercent == 1.0) { + sel = 1.0; + } else { + sel = DEFAULT_INEQUALITY_COEFFICIENT; + } + context.addKeyIfSlot(leftExpr); + context.addKeyIfSlot(rightExpr); + return context.statistics.withSel(sel) + .addColumnStats(leftExpr, leftColumnStatistic) + .addColumnStats(rightExpr, rightColumnStatistic); + } else { + ColumnStatistic leftColumnStatistic = new ColumnStatisticBuilder(leftStats) + .setMaxValue(intersect.getHigh()) + .setMinValue(intersect.getLow()) + .setNumNulls(0) + .setNdv(Math.max(leftStats.ndv * DEFAULT_INEQUALITY_COEFFICIENT, 1)) + .build(); + ColumnStatistic rightColumnStatistic = new ColumnStatisticBuilder(rightStats) + .setMaxValue(intersect.getHigh()) + .setMinValue(intersect.getLow()) + .setNumNulls(0) + .setNdv(Math.max(rightStats.ndv * DEFAULT_INEQUALITY_COEFFICIENT, 1)) + .build(); + context.addKeyIfSlot(leftExpr); + context.addKeyIfSlot(rightExpr); + return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT) + .addColumnStats(leftExpr, leftColumnStatistic) + .addColumnStats(rightExpr, rightColumnStatistic); } - context.addKeyIfSlot(leftExpr); - context.addKeyIfSlot(rightExpr); - return context.statistics.withSel(sel) - .addColumnStats(leftExpr, leftColumnStatistic) - .addColumnStats(rightExpr, rightColumnStatistic); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java index f8298871f0d632a..99080566e848ca2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java @@ -46,9 +46,11 @@ public class JoinEstimation { private static double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.3; private static double UNKNOWN_COL_STATS_FILTER_SEL_LOWER_BOUND = 0.5; + private static double TRUSTABLE_CONDITION_SELECTIVITY_POW_FACTOR = 2.0; + private static double UNTRUSTABLE_CONDITION_SELECTIVITY_LINEAR_FACTOR = 0.9; + private static double TRUSTABLE_UNIQ_THRESHOLD = 0.9; - private static EqualPredicate normalizeHashJoinCondition(EqualPredicate equal, Statistics leftStats, - Statistics rightStats) { + private static EqualPredicate normalizeEqualPredJoinCondition(EqualPredicate equal, Statistics rightStats) { boolean changeOrder = equal.left().getInputSlots().stream() .anyMatch(slot -> rightStats.findColumnStatistics(slot) != null); if (changeOrder) { @@ -58,7 +60,7 @@ private static EqualPredicate normalizeHashJoinCondition(EqualPredicate equal, S } } - private static boolean hashJoinConditionContainsUnknownColumnStats(Statistics leftStats, + private static boolean joinConditionContainsUnknownColumnStats(Statistics leftStats, Statistics rightStats, Join join) { for (Expression expr : join.getEqualPredicates()) { for (Slot slot : expr.getInputSlots()) { @@ -74,7 +76,8 @@ private static boolean hashJoinConditionContainsUnknownColumnStats(Statistics le return false; } - private static Statistics estimateHashJoin(Statistics leftStats, Statistics rightStats, Join join) { + private static Statistics estimateInnerJoinWithEqualPredicate(Statistics leftStats, + Statistics rightStats, Join join) { /* * When we estimate filter A=B, * if any side of equation, A or B, is almost unique, the confidence level of estimation is high. @@ -91,14 +94,13 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ .map(expression -> (EqualPredicate) expression) .filter( expression -> { - // since ndv is not accurate, if ndv/rowcount < almostUniqueThreshold, + // since ndv is not accurate, if ndv/rowcount < TRUSTABLE_UNIQ_THRESHOLD, // this column is regarded as unique. - double almostUniqueThreshold = 0.9; - EqualPredicate equal = normalizeHashJoinCondition(expression, leftStats, rightStats); + EqualPredicate equal = normalizeEqualPredJoinCondition(expression, rightStats); ColumnStatistic eqLeftColStats = ExpressionEstimation.estimate(equal.left(), leftStats); ColumnStatistic eqRightColStats = ExpressionEstimation.estimate(equal.right(), rightStats); - boolean trustable = eqRightColStats.ndv / rightStatsRowCount > almostUniqueThreshold - || eqLeftColStats.ndv / leftStatsRowCount > almostUniqueThreshold; + boolean trustable = eqRightColStats.ndv / rightStatsRowCount > TRUSTABLE_UNIQ_THRESHOLD + || eqLeftColStats.ndv / leftStatsRowCount > TRUSTABLE_UNIQ_THRESHOLD; if (!trustable) { double rNdv = StatsMathUtil.nonZeroDivisor(eqRightColStats.ndv); double lNdv = StatsMathUtil.nonZeroDivisor(eqLeftColStats.ndv); @@ -124,6 +126,8 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ double outputRowCount; if (!trustableConditions.isEmpty()) { + // TODO: strict pk-fk can use one-side stats instead of crossJoinStats + // in estimateJoinConditionSel, to get more accurate estimation. List joinConditionSels = trustableConditions.stream() .map(expression -> estimateJoinConditionSel(crossJoinStats, expression)) .sorted() @@ -133,10 +137,11 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ double denominator = 1.0; for (Double joinConditionSel : joinConditionSels) { sel *= Math.pow(joinConditionSel, 1 / denominator); - denominator *= 2; + denominator *= TRUSTABLE_CONDITION_SELECTIVITY_POW_FACTOR; } outputRowCount = Math.max(1, crossJoinStats.getRowCount() * sel); - outputRowCount = outputRowCount * Math.pow(0.9, unTrustableCondition.size()); + outputRowCount = outputRowCount * Math.pow(UNTRUSTABLE_CONDITION_SELECTIVITY_LINEAR_FACTOR, + unTrustableCondition.size()); } else { outputRowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount()); Optional ratio = unTrustEqualRatio.stream().min(Double::compareTo); @@ -148,8 +153,9 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ return innerJoinStats; } - private static Statistics estimateNestLoopJoin(Statistics leftStats, Statistics rightStats, Join join) { - if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) { + private static Statistics estimateInnerJoinWithoutEqualPredicate(Statistics leftStats, + Statistics rightStats, Join join) { + if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) { double rowCount = (leftStats.getRowCount() + rightStats.getRowCount()); // We do more like the nested loop join with one rows than inner join if (leftStats.getRowCount() == 1 || rightStats.getRowCount() == 1) { @@ -193,7 +199,7 @@ private static double computeSelectivityForBuildSideWhenColStatsUnknown(Statisti } private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rightStats, Join join) { - if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) { + if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) { double rowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount()); rowCount = Math.max(1, rowCount); return new StatisticsBuilder() @@ -205,9 +211,9 @@ private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rig Statistics innerJoinStats; if (join.getEqualPredicates().isEmpty()) { - innerJoinStats = estimateNestLoopJoin(leftStats, rightStats, join); + innerJoinStats = estimateInnerJoinWithoutEqualPredicate(leftStats, rightStats, join); } else { - innerJoinStats = estimateHashJoin(leftStats, rightStats, join); + innerJoinStats = estimateInnerJoinWithEqualPredicate(leftStats, rightStats, join); } if (!join.getOtherJoinConjuncts().isEmpty()) { @@ -266,21 +272,25 @@ private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics leftStat return Math.max(1, rowCount); } - private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics rightStats, Join join) { - if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join) || join.isMarkJoin()) { + private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics rightStats, + Statistics innerJoinStats, Join join) { + if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join) || join.isMarkJoin()) { double sel = join.isMarkJoin() ? 1.0 : computeSelectivityForBuildSideWhenColStatsUnknown(rightStats, join); + Statistics result; if (join.getJoinType().isLeftSemiOrAntiJoin()) { - return new StatisticsBuilder().setRowCount(leftStats.getRowCount() * sel) + result = new StatisticsBuilder().setRowCount(leftStats.getRowCount() * sel) .putColumnStatistics(leftStats.columnStatistics()) .putColumnStatistics(rightStats.columnStatistics()) .build(); } else { //right semi or anti - return new StatisticsBuilder().setRowCount(rightStats.getRowCount() * sel) + result = new StatisticsBuilder().setRowCount(rightStats.getRowCount() * sel) .putColumnStatistics(leftStats.columnStatistics()) .putColumnStatistics(rightStats.columnStatistics()) .build(); } + result.normalizeColumnStatistics(); + return result; } double rowCount = Double.POSITIVE_INFINITY; for (Expression conjunct : join.getEqualPredicates()) { @@ -292,12 +302,40 @@ private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics ri } if (Double.isInfinite(rowCount)) { //slotsEqual estimation failed, fall back to original algorithm - Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join); double baseRowCount = join.getJoinType().isLeftSemiOrAntiJoin() ? leftStats.getRowCount() : rightStats.getRowCount(); rowCount = Math.min(innerJoinStats.getRowCount(), baseRowCount); return innerJoinStats.withRowCountAndEnforceValid(rowCount); } else { + /*double crossRowCount = Math.max(1, leftStats.getRowCount()) * Math.max(1, rightStats.getRowCount()); + double selectivity = innerJoinStats.getRowCount() / crossRowCount; + selectivity = Statistics.getValidSelectivity(selectivity); + double outputRowCount; + StatisticsBuilder builder; + + if (join.getJoinType().isLeftSemiOrAntiJoin()) { + outputRowCount = leftStats.getRowCount(); + builder = new StatisticsBuilder(leftStats); + } else { + outputRowCount = rightStats.getRowCount(); + builder = new StatisticsBuilder(rightStats); + } + if (join.getJoinType().isLeftSemiJoin() || join.getJoinType().isRightSemiJoin()) { + outputRowCount *= selectivity; + } else { + outputRowCount *= 1 - selectivity; + if (join.getJoinType().isLeftAntiJoin() && rightStats.getRowCount() < 1) { + outputRowCount = leftStats.getRowCount(); + } else if (join.getJoinType().isRightAntiJoin() && leftStats.getRowCount() < 1) { + outputRowCount = rightStats.getRowCount(); + } else { + outputRowCount = StatsMathUtil.normalizeRowCountOrNdv(outputRowCount); + } + } + builder.setRowCount(outputRowCount); + Statistics outputStats = builder.build(); + outputStats.normalizeColumnStatistics(); + return outputStats;*/ StatisticsBuilder builder; if (join.getJoinType().isLeftSemiOrAntiJoin()) { builder = new StatisticsBuilder(leftStats); @@ -308,7 +346,7 @@ private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics ri builder.setRowCount(rowCount); } Statistics outputStats = builder.build(); - outputStats.enforceValid(); + outputStats.normalizeColumnStatistics(); return outputStats; } } @@ -323,49 +361,48 @@ public static Statistics estimate(Statistics leftStats, Statistics rightStats, J .putColumnStatistics(leftStats.columnStatistics()) .putColumnStatistics(rightStats.columnStatistics()) .build(); + Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join); if (joinType.isSemiOrAntiJoin()) { - return estimateSemiOrAnti(leftStats, rightStats, join); + Statistics outputStats = estimateSemiOrAnti(leftStats, rightStats, innerJoinStats, join); + updateJoinConditionColumnStatistics(outputStats, join); + return outputStats; } else if (joinType == JoinType.INNER_JOIN) { - Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join); - innerJoinStats = updateJoinResultStatsByHashJoinCondition(innerJoinStats, join); + updateJoinConditionColumnStatistics(innerJoinStats, join); return innerJoinStats; } else if (joinType == JoinType.LEFT_OUTER_JOIN) { - Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join); double rowCount = Math.max(leftStats.getRowCount(), innerJoinStats.getRowCount()); - rowCount = Math.max(leftStats.getRowCount(), rowCount); + updateJoinConditionColumnStatistics(crossJoinStats, join); return crossJoinStats.withRowCountAndEnforceValid(rowCount); } else if (joinType == JoinType.RIGHT_OUTER_JOIN) { - Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join); double rowCount = Math.max(rightStats.getRowCount(), innerJoinStats.getRowCount()); - rowCount = Math.max(rowCount, rightStats.getRowCount()); + updateJoinConditionColumnStatistics(crossJoinStats, join); return crossJoinStats.withRowCountAndEnforceValid(rowCount); } else if (joinType == JoinType.FULL_OUTER_JOIN) { - Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join); - return crossJoinStats.withRowCountAndEnforceValid(leftStats.getRowCount() - + rightStats.getRowCount() + innerJoinStats.getRowCount()); + double rowCount = Math.max(leftStats.getRowCount(), innerJoinStats.getRowCount()); + rowCount = Math.max(rightStats.getRowCount(), rowCount); + updateJoinConditionColumnStatistics(crossJoinStats, join); + return crossJoinStats.withRowCountAndEnforceValid(rowCount); } else if (joinType == JoinType.CROSS_JOIN) { - return new StatisticsBuilder() - .setRowCount(leftStats.getRowCount() * rightStats.getRowCount()) - .putColumnStatistics(leftStats.columnStatistics()) - .putColumnStatistics(rightStats.columnStatistics()) - .build(); + updateJoinConditionColumnStatistics(crossJoinStats, join); + return crossJoinStats; } throw new AnalysisException("join type not supported: " + join.getJoinType()); } /** - * L join R on a = b - * after join, a.ndv and b.ndv should be equal to min(a.ndv, b.ndv) + * update join condition columns' ColumnStatistics, based on different join type. */ - private static Statistics updateJoinResultStatsByHashJoinCondition(Statistics innerStats, Join join) { + private static void updateJoinConditionColumnStatistics(Statistics inputStats, Join join) { Map updatedCols = new HashMap<>(); + JoinType joinType = join.getJoinType(); for (Expression expr : join.getEqualPredicates()) { EqualPredicate equalTo = (EqualPredicate) expr; - ColumnStatistic leftColStats = ExpressionEstimation.estimate(equalTo.left(), innerStats); - ColumnStatistic rightColStats = ExpressionEstimation.estimate(equalTo.right(), innerStats); - double minNdv = Math.min(leftColStats.ndv, rightColStats.ndv); - leftColStats = new ColumnStatisticBuilder(leftColStats).setNdv(minNdv).build(); - rightColStats = new ColumnStatisticBuilder(rightColStats).setNdv(minNdv).build(); + ColumnStatistic leftColStats = ExpressionEstimation.estimate(equalTo.left(), inputStats); + ColumnStatistic rightColStats = ExpressionEstimation.estimate(equalTo.right(), inputStats); + double leftNdv = 1.0; + double rightNdv = 1.0; + boolean updateLeft = false; + boolean updateRight = false; Expression eqLeft = equalTo.left(); if (eqLeft instanceof Cast) { eqLeft = eqLeft.child(0); @@ -374,13 +411,47 @@ private static Statistics updateJoinResultStatsByHashJoinCondition(Statistics in if (eqRight instanceof Cast) { eqRight = eqRight.child(0); } - updatedCols.put(eqLeft, leftColStats); - updatedCols.put(eqRight, rightColStats); + if (joinType == JoinType.INNER_JOIN) { + leftNdv = Math.min(leftColStats.ndv, rightColStats.ndv); + rightNdv = Math.min(leftColStats.ndv, rightColStats.ndv); + updateLeft = true; + updateRight = true; + } else if (joinType == JoinType.LEFT_OUTER_JOIN) { + leftNdv = leftColStats.ndv; + rightNdv = Math.min(leftColStats.ndv, rightColStats.ndv); + updateLeft = true; + updateRight = true; + } else if (joinType == JoinType.LEFT_SEMI_JOIN + || joinType == JoinType.LEFT_ANTI_JOIN + || joinType == JoinType.NULL_AWARE_LEFT_ANTI_JOIN) { + leftNdv = Math.min(leftColStats.ndv, rightColStats.ndv); + updateLeft = true; + } else if (joinType == JoinType.RIGHT_OUTER_JOIN) { + leftNdv = Math.min(leftColStats.ndv, rightColStats.ndv); + rightNdv = rightColStats.ndv; + } else if (joinType == JoinType.RIGHT_SEMI_JOIN + || joinType == JoinType.RIGHT_ANTI_JOIN) { + rightNdv = Math.min(leftColStats.ndv, rightColStats.ndv); + updateRight = true; + } else if (joinType == JoinType.FULL_OUTER_JOIN || joinType == JoinType.CROSS_JOIN) { + leftNdv = leftColStats.ndv; + rightNdv = rightColStats.ndv; + updateLeft = true; + updateRight = true; + } + + if (updateLeft) { + leftColStats = new ColumnStatisticBuilder(leftColStats).setNdv(leftNdv).build(); + updatedCols.put(eqLeft, leftColStats); + } + if (updateRight) { + rightColStats = new ColumnStatisticBuilder(rightColStats).setNdv(rightNdv).build(); + updatedCols.put(eqRight, rightColStats); + } } updatedCols.entrySet().stream().forEach( - entry -> innerStats.addColumnStats(entry.getKey(), entry.getValue()) + entry -> inputStats.addColumnStats(entry.getKey(), entry.getValue()) ); - return innerStats; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index bac66f34ae665f8..7394aaa2f7530f8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -49,6 +49,7 @@ import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation; import org.apache.doris.nereids.trees.plans.algebra.Filter; import org.apache.doris.nereids.trees.plans.algebra.Generate; +import org.apache.doris.nereids.trees.plans.algebra.Join; import org.apache.doris.nereids.trees.plans.algebra.Limit; import org.apache.doris.nereids.trees.plans.algebra.OlapScan; import org.apache.doris.nereids.trees.plans.algebra.PartitionTopN; @@ -253,7 +254,7 @@ public static void estimate(GroupExpression groupExpression, CascadesContext con private void estimate() { Plan plan = groupExpression.getPlan(); Statistics newStats = plan.accept(this, null); - newStats.enforceValid(); + newStats.normalizeColumnStatistics(); // We ensure that the rowCount remains unchanged in order to make the cost of each plan comparable. if (groupExpression.getOwnerGroup().getStatistics() == null) { @@ -594,8 +595,9 @@ public Statistics visitLogicalPartitionTopN(LogicalPartitionTopN @Override public Statistics visitLogicalJoin(LogicalJoin join, Void context) { - Statistics joinStats = JoinEstimation.estimate(groupExpression.childStatistics(0), - groupExpression.childStatistics(1), join); + Statistics joinStats = computeJoin(join); + // NOTE: physical operator visiting doesn't need the following + // logic which will ONLY be used in no-stats estimation. joinStats = new StatisticsBuilder(joinStats).setWidthInJoinCluster( groupExpression.childStatistics(0).getWidthInJoinCluster() + groupExpression.childStatistics(1).getWidthInJoinCluster()).build(); @@ -739,16 +741,14 @@ public Statistics visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN @Override public Statistics visitPhysicalHashJoin( PhysicalHashJoin hashJoin, Void context) { - return JoinEstimation.estimate(groupExpression.childStatistics(0), - groupExpression.childStatistics(1), hashJoin); + return computeJoin(hashJoin); } @Override public Statistics visitPhysicalNestedLoopJoin( PhysicalNestedLoopJoin nestedLoopJoin, Void context) { - return JoinEstimation.estimate(groupExpression.childStatistics(0), - groupExpression.childStatistics(1), nestedLoopJoin); + return computeJoin(nestedLoopJoin); } // TODO: We should subtract those pruned column, and consider the expression transformations in the node. @@ -865,7 +865,7 @@ private Statistics computeFilter(Filter filter) { } builder.setRowCount(isNullStats.getRowCount()); stats = builder.build(); - stats.enforceValid(); + stats.normalizeColumnStatistics(); } } } @@ -937,7 +937,7 @@ false, getTotalColumnStatisticMap(), false, newStats = ((Plan) newJoin).accept(statsCalculator, null); } - newStats.enforceValid(); + newStats.normalizeColumnStatistics(); double selectivity = Statistics.getValidSelectivity( newStats.getRowCount() / (leftRowCount * rightRowCount)); @@ -1087,6 +1087,11 @@ private Statistics computeCatalogRelation(CatalogRelation catalogRelation) { return builder.setRowCount(tableRowCount).build(); } + private Statistics computeJoin(Join join) { + return JoinEstimation.estimate(groupExpression.childStatistics(0), + groupExpression.childStatistics(1), join); + } + private Statistics computeTopN(TopN topN) { Statistics stats = groupExpression.childStatistics(0); return stats.withRowCountAndEnforceValid(Math.min(stats.getRowCount(), topN.getLimit())); @@ -1190,7 +1195,7 @@ private Statistics computeAggregate(Aggregate aggregate) { slotToColumnStats.put(outputExpression.toSlot(), columnStat); } Statistics aggOutputStats = new Statistics(rowCount, 1, slotToColumnStats); - aggOutputStats.enforceValid(); + aggOutputStats.normalizeColumnStatistics(); return aggOutputStats; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsMathUtil.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsMathUtil.java index c56437f53bcb7a5..49cc466b780ec33 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsMathUtil.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsMathUtil.java @@ -58,4 +58,9 @@ public static double divide(double a, double b) { } return a / nonZeroDivisor(b); } + + // TODO: add more protection at other stats estimation + public static double normalizeRowCountOrNdv(double value) { + return value >= 0 && value < 1 ? 1 : value; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java index 3423b13168b4289..112c1d98a98a1f2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java @@ -164,7 +164,7 @@ public final boolean isLeftAntiJoin() { return this == LEFT_ANTI_JOIN; } - public final boolean isLefSemiJoin() { + public final boolean isLeftSemiJoin() { return this == LEFT_SEMI_JOIN; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStatistic.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStatistic.java index 3edc14577d9efd5..7a64f2031345967 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStatistic.java +++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/ColumnStatistic.java @@ -333,7 +333,7 @@ public static ColumnStatistic fromJson(String statJson) { ); } - public boolean minOrMaxIsInf() { + public boolean isMinMaxInvalid() { return Double.isInfinite(maxValue) || Double.isInfinite(minValue); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticRange.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticRange.java index ca9735b56654b12..4e1d0ac491f1590 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticRange.java @@ -158,8 +158,8 @@ public StatisticRange intersect(StatisticRange other) { double newHigh = smallerHigh.first; LiteralExpr newHighExpr = smallerHigh.second; if (newLow <= newHigh) { - return new StatisticRange(newLow, newLowExpr, newHigh, newHighExpr, - overlappingDistinctValues(other), dataType); + double distinctValues = overlappingDistinctValues(other); + return new StatisticRange(newLow, newLowExpr, newHigh, newHighExpr, distinctValues, dataType); } return empty(dataType); } @@ -178,33 +178,6 @@ public Pair maxPair(double r1, LiteralExpr e1, double r2, L return Pair.of(r2, e2); } - public StatisticRange cover(StatisticRange other) { - StatisticRange resultRange; - Pair biggerLow = maxPair(low, lowExpr, other.low, other.lowExpr); - double newLow = biggerLow.first; - LiteralExpr newLowExpr = biggerLow.second; - Pair smallerHigh = minPair(high, highExpr, other.high, other.highExpr); - double newHigh = smallerHigh.first; - LiteralExpr newHighExpr = smallerHigh.second; - - if (newLow <= newHigh) { - double overlapPercentOfLeft = overlapPercentWith(other); - double overlapDistinctValuesLeft = overlapPercentOfLeft * distinctValues; - double coveredDistinctValues = minExcludeNaN(distinctValues, overlapDistinctValuesLeft); - if (this.isBothInfinite() && other.isOneSideInfinite()) { - resultRange = new StatisticRange(newLow, newLowExpr, newHigh, newHighExpr, - distinctValues * INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR, - dataType); - } else { - resultRange = new StatisticRange(newLow, newLowExpr, newHigh, newHighExpr, coveredDistinctValues, - dataType); - } - } else { - resultRange = empty(dataType); - } - return resultRange; - } - public StatisticRange union(StatisticRange other) { double overlapPercentThis = this.overlapPercentWith(other); double overlapPercentOther = other.overlapPercentWith(this); @@ -220,10 +193,27 @@ public StatisticRange union(StatisticRange other) { } private double overlappingDistinctValues(StatisticRange other) { - double overlapPercentOfLeft = overlapPercentWith(other); - double overlapPercentOfRight = other.overlapPercentWith(this); - double overlapDistinctValuesLeft = overlapPercentOfLeft * distinctValues; - double overlapDistinctValuesRight = overlapPercentOfRight * other.distinctValues; + double overlapDistinctValuesLeft; + double overlapDistinctValuesRight; + if (this.isInfinite() || other.isInfinite()) { + overlapDistinctValuesRight = distinctValues * INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR; + } else if (Math.abs(other.low - other.high) < 1e-6) { + // other is constant + overlapDistinctValuesRight = distinctValues; + } else { + double overlapPercentOfRight = other.overlapPercentWith(this); + overlapDistinctValuesRight = overlapPercentOfRight * other.distinctValues; + } + + if (other.isInfinite() || this.isInfinite()) { + overlapDistinctValuesLeft = distinctValues * INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR; + } else if (Math.abs(this.low - this.high) < 1e-6) { + overlapDistinctValuesLeft = distinctValues; + } else { + double overlapPercentOfLeft = this.overlapPercentWith(other); + overlapDistinctValuesLeft = overlapPercentOfLeft * distinctValues; + } + return minExcludeNaN(overlapDistinctValuesLeft, overlapDistinctValuesRight); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java index e18dc09792054e5..4236993977aaa10 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java +++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java @@ -98,39 +98,53 @@ public Statistics withExpressionToColumnStats(Map e */ public Statistics withRowCountAndEnforceValid(double rowCount) { Statistics statistics = new Statistics(rowCount, widthInJoinCluster, expressionToColumnStats); - statistics.enforceValid(); + statistics.normalizeColumnStatistics(); return statistics; } - public void enforceValid() { + // IMPORTANT: it is suggested to do this action after each estimation critical visiting, + // since statistics will have serious deviation during the partial deriving. + public void normalizeColumnStatistics() { + normalizeColumnStatistics(this.rowCount); + } + + public void normalizeColumnStatistics(double inputRowCount) { + normalizeColumnStatistics(this.rowCount, false); + } + + public void normalizeColumnStatistics(double inputRowCount, boolean isNumNullsDecreaseByProportion) { for (Entry entry : expressionToColumnStats.entrySet()) { ColumnStatistic columnStatistic = entry.getValue(); - if (!checkColumnStatsValid(columnStatistic) && !columnStatistic.isUnKnown()) { - double ndv = Math.min(columnStatistic.ndv, rowCount); + double factor = isNumNullsDecreaseByProportion ? rowCount / inputRowCount : 1; + // the following columnStatistic.isUnKnown() judgment is loop inside since current doris + // supports partial stats deriving, i.e, allowing part of tables have stats and other parts don't, + // or part of columns have stats but other parts don't, especially join and filter estimation. + if (!checkColumnStatsValid(columnStatistic, rowCount) && !columnStatistic.isUnKnown()) { ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder(columnStatistic); - columnStatisticBuilder.setNdv(ndv); - columnStatisticBuilder.setNumNulls(Math.min(columnStatistic.numNulls, rowCount - ndv)); + double ndv = Math.min(columnStatistic.ndv, rowCount); + double numNulls = Math.min(columnStatistic.numNulls * factor, rowCount - ndv); + columnStatisticBuilder.setNumNulls(numNulls); + columnStatisticBuilder.setNdv(Math.min(ndv, rowCount - numNulls)); columnStatistic = columnStatisticBuilder.build(); expressionToColumnStats.put(entry.getKey(), columnStatistic); } } } - public boolean checkColumnStatsValid(ColumnStatistic columnStatistic) { - return columnStatistic.ndv <= rowCount - && columnStatistic.numNulls <= rowCount - columnStatistic.ndv; + public boolean checkColumnStatsValid(ColumnStatistic columnStatistic, double rowCount) { + return columnStatistic.ndv <= rowCount && columnStatistic.numNulls <= rowCount - columnStatistic.ndv; } public Statistics withSel(double sel) { return withSel(sel, 0); } - public Statistics withSel(double sel, double numNull) { - sel = StatsMathUtil.minNonNaN(sel, 1); + public Statistics withSel(double notNullSel, double numNull) { + notNullSel = StatsMathUtil.minNonNaN(notNullSel, 1); if (Double.isNaN(rowCount)) { return this; } - double newCount = rowCount * sel + numNull; + double newCount = rowCount * notNullSel + numNull; return new Statistics(newCount, widthInJoinCluster, new HashMap<>(expressionToColumnStats)); } @@ -227,8 +241,8 @@ public int getBENumber() { return 1; } - public static double getValidSelectivity(double nullSel) { - return nullSel < 0 ? 0 : (nullSel > 1 ? 1 : nullSel); + public static double getValidSelectivity(double selectivity) { + return selectivity < 0 ? 0 : (selectivity > 1 ? 1 : selectivity); } /** @@ -263,24 +277,6 @@ public int getWidthInJoinCluster() { return widthInJoinCluster; } - public Statistics normalizeByRatio(double originRowCount) { - if (rowCount >= originRowCount || rowCount <= 0) { - return this; - } - StatisticsBuilder builder = new StatisticsBuilder(this); - double ratio = rowCount / originRowCount; - for (Entry entry : expressionToColumnStats.entrySet()) { - ColumnStatistic colStats = entry.getValue(); - if (colStats.numNulls != 0 || colStats.ndv > rowCount) { - ColumnStatisticBuilder colStatsBuilder = new ColumnStatisticBuilder(colStats); - colStatsBuilder.setNumNulls(colStats.numNulls * ratio); - colStatsBuilder.setNdv(Math.min(rowCount - colStatsBuilder.getNumNulls(), colStats.ndv)); - builder.putColumnStatistics(entry.getKey(), colStatsBuilder.build()); - } - } - return builder.build(); - } - public double getDeltaRowCount() { return deltaRowCount; }