Skip to content

Commit

Permalink
[opt](nereids) refine operator estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongjian.xzj authored and zhongjian.xzj committed Sep 18, 2024
1 parent b69ae32 commit d6c2925
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ public Void visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join,
if (joinType.isInnerJoin() || joinType.isCrossJoin()) {
return visit(join, context);
} else if ((joinType.isLeftJoin()
|| joinType.isLefSemiJoin()
|| joinType.isLeftSemiJoin()
|| joinType.isLeftAntiJoin()) && useLeft) {
return visit(join.left(), context);
} else if ((joinType.isRightJoin()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@
public class JoinEstimation {
private static double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.3;
private static double UNKNOWN_COL_STATS_FILTER_SEL_LOWER_BOUND = 0.5;
private static double TRUSTABLE_CONDITION_SELECTIVITY_POW_FACTOR = 2.0;
private static double UNTRUSTABLE_CONDITION_SELECTIVITY_LINEAR_FACTOR = 0.9;
private static double TRUSTABLE_UNIQ_THRESHOLD = 0.9;

private static EqualPredicate normalizeEqualPredJoinCondition(EqualPredicate equal, Statistics leftStats,
Statistics rightStats) {
private static EqualPredicate normalizeEqualPredJoinCondition(EqualPredicate equal, Statistics rightStats) {
boolean changeOrder = equal.left().getInputSlots().stream()
.anyMatch(slot -> rightStats.findColumnStatistics(slot) != null);
if (changeOrder) {
Expand Down Expand Up @@ -92,14 +94,13 @@ private static Statistics estimateInnerJoinWithEqualPredicate(Statistics leftSta
.map(expression -> (EqualPredicate) expression)
.filter(
expression -> {
// since ndv is not accurate, if ndv/rowcount < almostUniqueThreshold,
// since ndv is not accurate, if ndv/rowcount < TRUSTABLE_UNIQ_THRESHOLD,
// this column is regarded as unique.
double almostUniqueThreshold = 0.9;
EqualPredicate equal = normalizeEqualPredJoinCondition(expression, leftStats, rightStats);
EqualPredicate equal = normalizeEqualPredJoinCondition(expression, rightStats);
ColumnStatistic eqLeftColStats = ExpressionEstimation.estimate(equal.left(), leftStats);
ColumnStatistic eqRightColStats = ExpressionEstimation.estimate(equal.right(), rightStats);
boolean trustable = eqRightColStats.ndv / rightStatsRowCount > almostUniqueThreshold
|| eqLeftColStats.ndv / leftStatsRowCount > almostUniqueThreshold;
boolean trustable = eqRightColStats.ndv / rightStatsRowCount > TRUSTABLE_UNIQ_THRESHOLD
|| eqLeftColStats.ndv / leftStatsRowCount > TRUSTABLE_UNIQ_THRESHOLD;
if (!trustable) {
double rNdv = StatsMathUtil.nonZeroDivisor(eqRightColStats.ndv);
double lNdv = StatsMathUtil.nonZeroDivisor(eqLeftColStats.ndv);
Expand All @@ -125,6 +126,8 @@ private static Statistics estimateInnerJoinWithEqualPredicate(Statistics leftSta

double outputRowCount;
if (!trustableConditions.isEmpty()) {
// TODO: trustable condition should use pk-fk like non-expanding(use one side row cnt directly) estimation,
// for the others, use the pow-down way to estimation.
List<Double> joinConditionSels = trustableConditions.stream()
.map(expression -> estimateJoinConditionSel(crossJoinStats, expression))
.sorted()
Expand All @@ -134,11 +137,13 @@ private static Statistics estimateInnerJoinWithEqualPredicate(Statistics leftSta
double denominator = 1.0;
for (Double joinConditionSel : joinConditionSels) {
sel *= Math.pow(joinConditionSel, 1 / denominator);
denominator *= 2;
denominator *= TRUSTABLE_CONDITION_SELECTIVITY_POW_FACTOR;
}
outputRowCount = Math.max(1, crossJoinStats.getRowCount() * sel);
outputRowCount = outputRowCount * Math.pow(0.9, unTrustableCondition.size());
outputRowCount = outputRowCount * Math.pow(UNTRUSTABLE_CONDITION_SELECTIVITY_LINEAR_FACTOR,
unTrustableCondition.size());
} else {
// TODO: refine this common path estimation
outputRowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount());
Optional<Double> ratio = unTrustEqualRatio.stream().min(Double::compareTo);
if (ratio.isPresent()) {
Expand Down Expand Up @@ -268,7 +273,8 @@ private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics leftStat
return Math.max(1, rowCount);
}

private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics rightStats, Join join) {
private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics rightStats,
Statistics innerJoinStats, Join join) {
if (joinConditionContainsUnknownColumnStats(leftStats, rightStats, join) || join.isMarkJoin()) {
double sel = join.isMarkJoin() ? 1.0 : computeSelectivityForBuildSideWhenColStatsUnknown(rightStats, join);
Statistics result;
Expand All @@ -287,6 +293,7 @@ private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics ri
result.normalizeColumnStatistics();
return result;
}
// TODO: replace with the else branch estimation
double rowCount = Double.POSITIVE_INFINITY;
for (Expression conjunct : join.getEqualPredicates()) {
double eqRowCount = estimateSemiOrAntiRowCountBySlotsEqual(leftStats, rightStats,
Expand All @@ -297,21 +304,37 @@ private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics ri
}
if (Double.isInfinite(rowCount)) {
//slotsEqual estimation failed, fall back to original algorithm
Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
double baseRowCount =
join.getJoinType().isLeftSemiOrAntiJoin() ? leftStats.getRowCount() : rightStats.getRowCount();
rowCount = Math.min(innerJoinStats.getRowCount(), baseRowCount);
return innerJoinStats.withRowCountAndEnforceValid(rowCount);
} else {
double crossRowCount = Math.max(1, leftStats.getRowCount()) * Math.max(1, rightStats.getRowCount());
double selectivity = innerJoinStats.getRowCount() / crossRowCount;
selectivity = Statistics.getValidSelectivity(selectivity);
double outputRowCount;
StatisticsBuilder builder;

if (join.getJoinType().isLeftSemiOrAntiJoin()) {
outputRowCount = leftStats.getRowCount();
builder = new StatisticsBuilder(leftStats);
builder.setRowCount(rowCount);
} else {
//right semi or anti
outputRowCount = rightStats.getRowCount();
builder = new StatisticsBuilder(rightStats);
builder.setRowCount(rowCount);
}
if (join.getJoinType().isLeftSemiJoin() || join.getJoinType().isRightSemiJoin()) {
outputRowCount *= selectivity;
} else {
outputRowCount *= 1 - selectivity;
if (join.getJoinType().isLeftAntiJoin() && rightStats.getRowCount() < 1) {
outputRowCount = leftStats.getRowCount();
} else if (join.getJoinType().isRightAntiJoin() && leftStats.getRowCount() < 1) {
outputRowCount = rightStats.getRowCount();
} else {
outputRowCount = StatsMathUtil.normalizeRowCountOrNdv(outputRowCount);
}
}
builder.setRowCount(outputRowCount);
Statistics outputStats = builder.build();
outputStats.normalizeColumnStatistics();
return outputStats;
Expand All @@ -328,26 +351,22 @@ public static Statistics estimate(Statistics leftStats, Statistics rightStats, J
.putColumnStatistics(leftStats.columnStatistics())
.putColumnStatistics(rightStats.columnStatistics())
.build();
Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
if (joinType.isSemiOrAntiJoin()) {
return estimateSemiOrAnti(leftStats, rightStats, join);
return estimateSemiOrAnti(leftStats, rightStats, innerJoinStats, join);
} else if (joinType == JoinType.INNER_JOIN) {
Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
innerJoinStats = updateJoinResultStatsByHashJoinCondition(innerJoinStats, join);
return innerJoinStats;
// TODO: why only inner but not for other outer join?
return updateJoinConditionColumnStatistics(innerJoinStats, join);
} else if (joinType == JoinType.LEFT_OUTER_JOIN) {
Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
double rowCount = Math.max(leftStats.getRowCount(), innerJoinStats.getRowCount());
rowCount = Math.max(leftStats.getRowCount(), rowCount);
return crossJoinStats.withRowCountAndEnforceValid(rowCount);
} else if (joinType == JoinType.RIGHT_OUTER_JOIN) {
Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
double rowCount = Math.max(rightStats.getRowCount(), innerJoinStats.getRowCount());
rowCount = Math.max(rowCount, rightStats.getRowCount());
return crossJoinStats.withRowCountAndEnforceValid(rowCount);
} else if (joinType == JoinType.FULL_OUTER_JOIN) {
Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
return crossJoinStats.withRowCountAndEnforceValid(leftStats.getRowCount()
+ rightStats.getRowCount() + innerJoinStats.getRowCount());
double rowCount = Math.max(leftStats.getRowCount(), innerJoinStats.getRowCount());
rowCount = Math.max(rightStats.getRowCount(), rowCount);
return crossJoinStats.withRowCountAndEnforceValid(rowCount);
} else if (joinType == JoinType.CROSS_JOIN) {
return new StatisticsBuilder()
.setRowCount(leftStats.getRowCount() * rightStats.getRowCount())
Expand All @@ -362,7 +381,7 @@ public static Statistics estimate(Statistics leftStats, Statistics rightStats, J
* L join R on a = b
* after join, a.ndv and b.ndv should be equal to min(a.ndv, b.ndv)
*/
private static Statistics updateJoinResultStatsByHashJoinCondition(Statistics innerStats, Join join) {
private static Statistics updateJoinConditionColumnStatistics(Statistics innerStats, Join join) {
Map<Expression, ColumnStatistic> updatedCols = new HashMap<>();
for (Expression expr : join.getEqualPredicates()) {
EqualPredicate equalTo = (EqualPredicate) expr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,9 @@ public static double divide(double a, double b) {
}
return a / nonZeroDivisor(b);
}

// TODO: add more protection at other stats estimation
public static double normalizeRowCountOrNdv(double value) {
return value >= 0 && value < 1 ? 1 : value;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ public final boolean isLeftAntiJoin() {
return this == LEFT_ANTI_JOIN;
}

public final boolean isLefSemiJoin() {
public final boolean isLeftSemiJoin() {
return this == LEFT_SEMI_JOIN;
}

Expand Down

0 comments on commit d6c2925

Please sign in to comment.