Skip to content

Commit

Permalink
[opt](nereids) refine operator estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongjian.xzj authored and zhongjian.xzj committed Sep 24, 2024
1 parent ace2796 commit 0f14a99
Show file tree
Hide file tree
Showing 10 changed files with 461 additions and 332 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ public Void visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join,
if (joinType.isInnerJoin() || joinType.isCrossJoin()) {
return visit(join, context);
} else if ((joinType.isLeftJoin()
|| joinType.isLefSemiJoin()
|| joinType.isLeftSemiJoin()
|| joinType.isLeftAntiJoin()) && useLeft) {
return visit(join.left(), context);
} else if ((joinType.isRightJoin()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,9 @@ public ColumnStatistic visitLiteral(Literal literal, Statistics context) {
return new ColumnStatisticBuilder()
.setMaxValue(literalVal)
.setMinValue(literalVal)
.setNdv(1)
.setNdv(literal.isNullLiteral() ? 0 : 1)
.setNumNulls(literal.isNullLiteral() ? 1 : 0)
.setAvgSizeByte(1)
.setAvgSizeByte(literal.getDataType().width())
.setMinExpr(literal.toLegacyLiteral())
.setMaxExpr(literal.toLegacyLiteral())
.build();
Expand Down Expand Up @@ -579,7 +579,7 @@ public ColumnStatistic visitToDate(ToDate toDate, Statistics context) {
ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder(childColumnStats)
.setAvgSizeByte(toDate.getDataType().width())
.setDataSize(toDate.getDataType().width() * context.getRowCount());
if (childColumnStats.minOrMaxIsInf()) {
if (childColumnStats.isMinMaxInvalid()) {
return columnStatisticBuilder.build();
}
double minValue;
Expand Down Expand Up @@ -610,7 +610,7 @@ public ColumnStatistic visitToDays(ToDays toDays, Statistics context) {
ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder(childColumnStats)
.setAvgSizeByte(toDays.getDataType().width())
.setDataSize(toDays.getDataType().width() * context.getRowCount());
if (childColumnStats.minOrMaxIsInf()) {
if (childColumnStats.isMinMaxInvalid()) {
return columnStatisticBuilder.build();
}
double minValue;
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation;
import org.apache.doris.nereids.trees.plans.algebra.Filter;
import org.apache.doris.nereids.trees.plans.algebra.Generate;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.trees.plans.algebra.Limit;
import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
import org.apache.doris.nereids.trees.plans.algebra.PartitionTopN;
Expand Down Expand Up @@ -253,7 +254,7 @@ public static void estimate(GroupExpression groupExpression, CascadesContext con
private void estimate() {
Plan plan = groupExpression.getPlan();
Statistics newStats = plan.accept(this, null);
newStats.enforceValid();
newStats.normalizeColumnStatistics();

// We ensure that the rowCount remains unchanged in order to make the cost of each plan comparable.
if (groupExpression.getOwnerGroup().getStatistics() == null) {
Expand Down Expand Up @@ -594,8 +595,9 @@ public Statistics visitLogicalPartitionTopN(LogicalPartitionTopN<? extends Plan>

@Override
public Statistics visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
Statistics joinStats = JoinEstimation.estimate(groupExpression.childStatistics(0),
groupExpression.childStatistics(1), join);
Statistics joinStats = computeJoin(join);
// NOTE: physical operator visiting doesn't need the following
// logic which will ONLY be used in no-stats estimation.
joinStats = new StatisticsBuilder(joinStats).setWidthInJoinCluster(
groupExpression.childStatistics(0).getWidthInJoinCluster()
+ groupExpression.childStatistics(1).getWidthInJoinCluster()).build();
Expand Down Expand Up @@ -739,16 +741,14 @@ public Statistics visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN
@Override
public Statistics visitPhysicalHashJoin(
PhysicalHashJoin<? extends Plan, ? extends Plan> hashJoin, Void context) {
return JoinEstimation.estimate(groupExpression.childStatistics(0),
groupExpression.childStatistics(1), hashJoin);
return computeJoin(hashJoin);
}

@Override
public Statistics visitPhysicalNestedLoopJoin(
PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin,
Void context) {
return JoinEstimation.estimate(groupExpression.childStatistics(0),
groupExpression.childStatistics(1), nestedLoopJoin);
return computeJoin(nestedLoopJoin);
}

// TODO: We should subtract those pruned column, and consider the expression transformations in the node.
Expand Down Expand Up @@ -865,7 +865,7 @@ private Statistics computeFilter(Filter filter) {
}
builder.setRowCount(isNullStats.getRowCount());
stats = builder.build();
stats.enforceValid();
stats.normalizeColumnStatistics();
}
}
}
Expand Down Expand Up @@ -937,7 +937,7 @@ false, getTotalColumnStatisticMap(), false,

newStats = ((Plan) newJoin).accept(statsCalculator, null);
}
newStats.enforceValid();
newStats.normalizeColumnStatistics();

double selectivity = Statistics.getValidSelectivity(
newStats.getRowCount() / (leftRowCount * rightRowCount));
Expand Down Expand Up @@ -1087,6 +1087,11 @@ private Statistics computeCatalogRelation(CatalogRelation catalogRelation) {
return builder.setRowCount(tableRowCount).build();
}

private Statistics computeJoin(Join join) {
return JoinEstimation.estimate(groupExpression.childStatistics(0),
groupExpression.childStatistics(1), join);
}

private Statistics computeTopN(TopN topN) {
Statistics stats = groupExpression.childStatistics(0);
return stats.withRowCountAndEnforceValid(Math.min(stats.getRowCount(), topN.getLimit()));
Expand Down Expand Up @@ -1190,7 +1195,7 @@ private Statistics computeAggregate(Aggregate<? extends Plan> aggregate) {
slotToColumnStats.put(outputExpression.toSlot(), columnStat);
}
Statistics aggOutputStats = new Statistics(rowCount, 1, slotToColumnStats);
aggOutputStats.enforceValid();
aggOutputStats.normalizeColumnStatistics();
return aggOutputStats;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,9 @@ public static double divide(double a, double b) {
}
return a / nonZeroDivisor(b);
}

// TODO: add more protection at other stats estimation
public static double normalizeRowCountOrNdv(double value) {
return value >= 0 && value < 1 ? 1 : value;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ public final boolean isLeftAntiJoin() {
return this == LEFT_ANTI_JOIN;
}

public final boolean isLefSemiJoin() {
public final boolean isLeftSemiJoin() {
return this == LEFT_SEMI_JOIN;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ public static ColumnStatistic fromJson(String statJson) {
);
}

public boolean minOrMaxIsInf() {
public boolean isMinMaxInvalid() {
return Double.isInfinite(maxValue) || Double.isInfinite(minValue);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ public StatisticRange intersect(StatisticRange other) {
double newHigh = smallerHigh.first;
LiteralExpr newHighExpr = smallerHigh.second;
if (newLow <= newHigh) {
return new StatisticRange(newLow, newLowExpr, newHigh, newHighExpr,
overlappingDistinctValues(other), dataType);
double distinctValues = overlappingDistinctValues(other);
return new StatisticRange(newLow, newLowExpr, newHigh, newHighExpr, distinctValues, dataType);
}
return empty(dataType);
}
Expand All @@ -178,33 +178,6 @@ public Pair<Double, LiteralExpr> maxPair(double r1, LiteralExpr e1, double r2, L
return Pair.of(r2, e2);
}

public StatisticRange cover(StatisticRange other) {
StatisticRange resultRange;
Pair<Double, LiteralExpr> biggerLow = maxPair(low, lowExpr, other.low, other.lowExpr);
double newLow = biggerLow.first;
LiteralExpr newLowExpr = biggerLow.second;
Pair<Double, LiteralExpr> smallerHigh = minPair(high, highExpr, other.high, other.highExpr);
double newHigh = smallerHigh.first;
LiteralExpr newHighExpr = smallerHigh.second;

if (newLow <= newHigh) {
double overlapPercentOfLeft = overlapPercentWith(other);
double overlapDistinctValuesLeft = overlapPercentOfLeft * distinctValues;
double coveredDistinctValues = minExcludeNaN(distinctValues, overlapDistinctValuesLeft);
if (this.isBothInfinite() && other.isOneSideInfinite()) {
resultRange = new StatisticRange(newLow, newLowExpr, newHigh, newHighExpr,
distinctValues * INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR,
dataType);
} else {
resultRange = new StatisticRange(newLow, newLowExpr, newHigh, newHighExpr, coveredDistinctValues,
dataType);
}
} else {
resultRange = empty(dataType);
}
return resultRange;
}

public StatisticRange union(StatisticRange other) {
double overlapPercentThis = this.overlapPercentWith(other);
double overlapPercentOther = other.overlapPercentWith(this);
Expand All @@ -220,10 +193,27 @@ public StatisticRange union(StatisticRange other) {
}

private double overlappingDistinctValues(StatisticRange other) {
double overlapPercentOfLeft = overlapPercentWith(other);
double overlapPercentOfRight = other.overlapPercentWith(this);
double overlapDistinctValuesLeft = overlapPercentOfLeft * distinctValues;
double overlapDistinctValuesRight = overlapPercentOfRight * other.distinctValues;
double overlapDistinctValuesLeft;
double overlapDistinctValuesRight;
if (this.isInfinite() || other.isInfinite()) {
overlapDistinctValuesRight = distinctValues * INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR;
} else if (Math.abs(other.low - other.high) < 1e-6) {
// other is constant
overlapDistinctValuesRight = distinctValues;
} else {
double overlapPercentOfRight = other.overlapPercentWith(this);
overlapDistinctValuesRight = overlapPercentOfRight * other.distinctValues;
}

if (other.isInfinite() || this.isInfinite()) {
overlapDistinctValuesLeft = distinctValues * INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR;
} else if (Math.abs(this.low - this.high) < 1e-6) {
overlapDistinctValuesLeft = distinctValues;
} else {
double overlapPercentOfLeft = this.overlapPercentWith(other);
overlapDistinctValuesLeft = overlapPercentOfLeft * distinctValues;
}

return minExcludeNaN(overlapDistinctValuesLeft, overlapDistinctValuesRight);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,39 +98,53 @@ public Statistics withExpressionToColumnStats(Map<Expression, ColumnStatistic> e
*/
public Statistics withRowCountAndEnforceValid(double rowCount) {
Statistics statistics = new Statistics(rowCount, widthInJoinCluster, expressionToColumnStats);
statistics.enforceValid();
statistics.normalizeColumnStatistics();
return statistics;
}

public void enforceValid() {
// IMPORTANT: it is suggested to do this action after each estimation critical visiting,
// since statistics will have serious deviation during the partial deriving.
public void normalizeColumnStatistics() {
normalizeColumnStatistics(this.rowCount);
}

public void normalizeColumnStatistics(double inputRowCount) {
normalizeColumnStatistics(this.rowCount, false);
}

public void normalizeColumnStatistics(double inputRowCount, boolean isNumNullsDecreaseByProportion) {
for (Entry<Expression, ColumnStatistic> entry : expressionToColumnStats.entrySet()) {
ColumnStatistic columnStatistic = entry.getValue();
if (!checkColumnStatsValid(columnStatistic) && !columnStatistic.isUnKnown()) {
double ndv = Math.min(columnStatistic.ndv, rowCount);
double factor = isNumNullsDecreaseByProportion ? rowCount / inputRowCount : 1;
// the following columnStatistic.isUnKnown() judgment is loop inside since current doris
// supports partial stats deriving, i.e, allowing part of tables have stats and other parts don't,
// or part of columns have stats but other parts don't, especially join and filter estimation.
if (!checkColumnStatsValid(columnStatistic, rowCount) && !columnStatistic.isUnKnown()) {
ColumnStatisticBuilder columnStatisticBuilder = new ColumnStatisticBuilder(columnStatistic);
columnStatisticBuilder.setNdv(ndv);
columnStatisticBuilder.setNumNulls(Math.min(columnStatistic.numNulls, rowCount - ndv));
double ndv = Math.min(columnStatistic.ndv, rowCount);
double numNulls = Math.min(columnStatistic.numNulls * factor, rowCount - ndv);
columnStatisticBuilder.setNumNulls(numNulls);
columnStatisticBuilder.setNdv(Math.min(ndv, rowCount - numNulls));
columnStatistic = columnStatisticBuilder.build();
expressionToColumnStats.put(entry.getKey(), columnStatistic);
}
}
}

public boolean checkColumnStatsValid(ColumnStatistic columnStatistic) {
return columnStatistic.ndv <= rowCount
&& columnStatistic.numNulls <= rowCount - columnStatistic.ndv;
public boolean checkColumnStatsValid(ColumnStatistic columnStatistic, double rowCount) {
return columnStatistic.ndv <= rowCount && columnStatistic.numNulls <= rowCount - columnStatistic.ndv;
}

public Statistics withSel(double sel) {
return withSel(sel, 0);
}

public Statistics withSel(double sel, double numNull) {
sel = StatsMathUtil.minNonNaN(sel, 1);
public Statistics withSel(double notNullSel, double numNull) {
notNullSel = StatsMathUtil.minNonNaN(notNullSel, 1);
if (Double.isNaN(rowCount)) {
return this;
}
double newCount = rowCount * sel + numNull;
double newCount = rowCount * notNullSel + numNull;
return new Statistics(newCount, widthInJoinCluster, new HashMap<>(expressionToColumnStats));
}

Expand Down Expand Up @@ -227,8 +241,8 @@ public int getBENumber() {
return 1;
}

public static double getValidSelectivity(double nullSel) {
return nullSel < 0 ? 0 : (nullSel > 1 ? 1 : nullSel);
public static double getValidSelectivity(double selectivity) {
return selectivity < 0 ? 0 : (selectivity > 1 ? 1 : selectivity);
}

/**
Expand Down Expand Up @@ -263,24 +277,6 @@ public int getWidthInJoinCluster() {
return widthInJoinCluster;
}

public Statistics normalizeByRatio(double originRowCount) {
if (rowCount >= originRowCount || rowCount <= 0) {
return this;
}
StatisticsBuilder builder = new StatisticsBuilder(this);
double ratio = rowCount / originRowCount;
for (Entry<Expression, ColumnStatistic> entry : expressionToColumnStats.entrySet()) {
ColumnStatistic colStats = entry.getValue();
if (colStats.numNulls != 0 || colStats.ndv > rowCount) {
ColumnStatisticBuilder colStatsBuilder = new ColumnStatisticBuilder(colStats);
colStatsBuilder.setNumNulls(colStats.numNulls * ratio);
colStatsBuilder.setNdv(Math.min(rowCount - colStatsBuilder.getNumNulls(), colStats.ndv));
builder.putColumnStatistics(entry.getKey(), colStatsBuilder.build());
}
}
return builder.build();
}

public double getDeltaRowCount() {
return deltaRowCount;
}
Expand Down

0 comments on commit 0f14a99

Please sign in to comment.