Skip to content

Commit

Permalink
compute large and
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Nov 11, 2024
1 parent 2e28db9 commit e2305d2
Show file tree
Hide file tree
Showing 68 changed files with 984 additions and 937 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.coercion.RangeScalable;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.ColumnStatisticBuilder;
import org.apache.doris.statistics.StatisticRange;
Expand All @@ -60,6 +61,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Calculate selectivity of expression that produces boolean value.
Expand Down Expand Up @@ -112,42 +114,89 @@ public Statistics visit(Expression expr, EstimationContext context) {
return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT);
}

private Statistics estimateLargeAnd(List<Expression> conjuncts, EstimationContext context) {
StatisticsBuilder builder = new StatisticsBuilder(context.statistics);
if (context.statistics.getRowCount() <= 1) {
return builder.build();
}
List<Double> orderedSelections = conjuncts.stream()
.mapToDouble(conjunct ->
conjunct.accept(this, context).getBENumber() / context.statistics.getRowCount())
.boxed()
.sorted()
.collect(Collectors.toList());

double sel = 1.0;
for (int i = 0; i < orderedSelections.size(); i++) {
double value = orderedSelections.get(i);
if (value >= 1) {
break;
}
double root = Math.pow(value, 1.0 / (i + 1));
sel *= root;
}
double inputRowCount = context.statistics.getRowCount();
Statistics outputStats = builder.setRowCount(inputRowCount * sel).build();
return outputStats.normalizeColumnStatistics(inputRowCount, true);
}

private Statistics estimateBasicAnd(List<Expression> conjuncts, EstimationContext context) {
Expression leftExpr = conjuncts.get(0);
Expression rightExpr = conjuncts.get(1);
Statistics leftStats = leftExpr.accept(this, context);
leftStats.normalizeColumnStatistics(context.statistics.getRowCount(), true);
Statistics andStats = rightExpr.accept(this, new EstimationContext(leftStats));
return andStats.normalizeColumnStatistics(context.statistics.getRowCount(), true);
}

@Override
public Statistics visitCompoundPredicate(CompoundPredicate predicate, EstimationContext context) {
Expression leftExpr = predicate.child(0);
Expression rightExpr = predicate.child(1);
public Statistics visitAnd(And and, EstimationContext context) {
List<Expression> conjuncts = ExpressionUtils.extractConjunction(and);
Statistics outputStats;
if (conjuncts.size() >= LARGE_COMPOUND_PREDICATE) {
outputStats = estimateLargeAnd(conjuncts, context);
} else {
outputStats = estimateBasicAnd(conjuncts, context);
}
return outputStats;
}

@Override
public Statistics visitOr(Or or, EstimationContext context) {
Expression leftExpr = or.child(0);
Expression rightExpr = or.child(1);
Statistics leftStats = leftExpr.accept(this, context);
leftStats.normalizeColumnStatistics(context.statistics.getRowCount(), true);
Statistics andStats = rightExpr.accept(this, new EstimationContext(leftStats));
if (predicate instanceof And) {
andStats.normalizeColumnStatistics(context.statistics.getRowCount(), true);
return andStats;
} else if (predicate instanceof Or) {
Statistics rightStats = rightExpr.accept(this, context);
rightStats.normalizeColumnStatistics(context.statistics.getRowCount(), true);
double rowCount = leftStats.getRowCount() + rightStats.getRowCount() - andStats.getRowCount();
Statistics orStats = context.statistics.withRowCount(rowCount);
Set<Slot> leftInputSlots = leftExpr.getInputSlots();
Set<Slot> rightInputSlots = rightExpr.getInputSlots();
for (Slot slot : context.keyColumns) {
if (leftInputSlots.contains(slot) && rightInputSlots.contains(slot)) {
ColumnStatistic leftColStats = leftStats.findColumnStatistics(slot);
ColumnStatistic rightColStats = rightStats.findColumnStatistics(slot);
StatisticRange leftRange = StatisticRange.from(leftColStats, slot.getDataType());
StatisticRange rightRange = StatisticRange.from(rightColStats, slot.getDataType());
StatisticRange union = leftRange.union(rightRange);
ColumnStatisticBuilder colBuilder = new ColumnStatisticBuilder(
context.statistics.findColumnStatistics(slot));
colBuilder.setMinValue(union.getLow()).setMinExpr(union.getLowExpr())
.setMaxValue(union.getHigh()).setMaxExpr(union.getHighExpr())
.setNdv(union.getDistinctValues());
double maxNumNulls = Math.max(leftColStats.numNulls, rightColStats.numNulls);
colBuilder.setNumNulls(Math.min(colBuilder.getCount(), maxNumNulls));
orStats.addColumnStats(slot, colBuilder.build());
}

Statistics rightStats = rightExpr.accept(this, context);
rightStats.normalizeColumnStatistics(context.statistics.getRowCount(), true);
double rowCount = leftStats.getRowCount() + rightStats.getRowCount() - andStats.getRowCount();
Statistics orStats = context.statistics.withRowCount(rowCount);
Set<Slot> leftInputSlots = leftExpr.getInputSlots();
Set<Slot> rightInputSlots = rightExpr.getInputSlots();
for (Slot slot : context.keyColumns) {
if (leftInputSlots.contains(slot) && rightInputSlots.contains(slot)) {
ColumnStatistic leftColStats = leftStats.findColumnStatistics(slot);
ColumnStatistic rightColStats = rightStats.findColumnStatistics(slot);
StatisticRange leftRange = StatisticRange.from(leftColStats, slot.getDataType());
StatisticRange rightRange = StatisticRange.from(rightColStats, slot.getDataType());
StatisticRange union = leftRange.union(rightRange);
ColumnStatisticBuilder colBuilder = new ColumnStatisticBuilder(
context.statistics.findColumnStatistics(slot));
colBuilder.setMinValue(union.getLow()).setMinExpr(union.getLowExpr())
.setMaxValue(union.getHigh()).setMaxExpr(union.getHighExpr())
.setNdv(union.getDistinctValues());
double maxNumNulls = Math.max(leftColStats.numNulls, rightColStats.numNulls);
colBuilder.setNumNulls(Math.min(colBuilder.getCount(), maxNumNulls));
orStats.addColumnStats(slot, colBuilder.build());
}
return orStats;
}
return orStats;
}

@Override
public Statistics visitCompoundPredicate(CompoundPredicate predicate, EstimationContext context) {
// should not come here
Preconditions.checkArgument(false,
"unsupported compound operator: %s in %s",
Expand Down Expand Up @@ -577,7 +626,10 @@ public Statistics visitIsNull(IsNull isNull, EstimationContext context) {
return builder.build();
}

static class EstimationContext {
/**
* EstimationContext
*/
public static class EstimationContext {
private final Statistics statistics;

private final Set<Slot> keyColumns = Sets.newHashSet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public void normalizeColumnStatistics(double inputRowCount) {
normalizeColumnStatistics(this.rowCount, false);
}

public void normalizeColumnStatistics(double inputRowCount, boolean isNumNullsDecreaseByProportion) {
public Statistics normalizeColumnStatistics(double inputRowCount, boolean isNumNullsDecreaseByProportion) {
double factor = isNumNullsDecreaseByProportion ? rowCount / inputRowCount : 1.0;
for (Entry<Expression, ColumnStatistic> entry : expressionToColumnStats.entrySet()) {
ColumnStatistic columnStatistic = entry.getValue();
Expand All @@ -130,6 +130,7 @@ public void normalizeColumnStatistics(double inputRowCount, boolean isNumNullsDe
expressionToColumnStats.put(entry.getKey(), columnStatistic);
}
}
return this;
}

public boolean checkColumnStatsValid(ColumnStatistic columnStatistic, double rowCount) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,20 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------PhysicalDistribute[DistributionSpecGather]
--------PhysicalTopN[LOCAL_SORT]
----------PhysicalProject
------------hashJoin[INNER_JOIN shuffleBucket] hashCondition=((t_s_firstyear.customer_id = t_w_secyear.customer_id)) otherCondition=((if((year_total > 0.00), (cast(year_total as DECIMALV3(38, 8)) / year_total), 0.000000) > if((year_total > 0.00), (cast(year_total as DECIMALV3(38, 8)) / year_total), 0.000000))) build RFs:RF5 customer_id->[customer_id]
------------hashJoin[INNER_JOIN broadcast] hashCondition=((t_s_firstyear.customer_id = t_w_firstyear.customer_id)) otherCondition=((if((year_total > 0.00), (cast(year_total as DECIMALV3(38, 8)) / year_total), 0.000000) > if((year_total > 0.00), (cast(year_total as DECIMALV3(38, 8)) / year_total), 0.000000))) build RFs:RF5 customer_id->[customer_id,customer_id,customer_id]
--------------PhysicalProject
----------------filter((t_w_secyear.dyear = 2002) and (t_w_secyear.sale_type = 'w'))
------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF5
--------------PhysicalProject
----------------hashJoin[INNER_JOIN bucketShuffle] hashCondition=((t_s_firstyear.customer_id = t_w_firstyear.customer_id)) otherCondition=() build RFs:RF4 customer_id->[customer_id,customer_id]
----------------hashJoin[INNER_JOIN shuffleBucket] hashCondition=((t_s_firstyear.customer_id = t_w_secyear.customer_id)) otherCondition=() build RFs:RF4 customer_id->[customer_id]
------------------PhysicalProject
--------------------filter((t_w_secyear.dyear = 2002) and (t_w_secyear.sale_type = 'w'))
----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF4 RF5
------------------hashJoin[INNER_JOIN shuffle] hashCondition=((t_s_secyear.customer_id = t_s_firstyear.customer_id)) otherCondition=() build RFs:RF3 customer_id->[customer_id]
--------------------PhysicalProject
----------------------filter((t_s_secyear.dyear = 2002) and (t_s_secyear.sale_type = 's'))
------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4
------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF5
--------------------PhysicalProject
----------------------filter((t_s_firstyear.dyear = 2001) and (t_s_firstyear.sale_type = 's') and (t_s_firstyear.year_total > 0.00))
------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF4
------------------PhysicalProject
--------------------filter((t_w_firstyear.dyear = 2001) and (t_w_firstyear.sale_type = 'w') and (t_w_firstyear.year_total > 0.00))
----------------------PhysicalCteConsumer ( cteId=CTEId#0 )
------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF5
--------------PhysicalProject
----------------filter((t_w_firstyear.dyear = 2001) and (t_w_firstyear.sale_type = 'w') and (t_w_firstyear.year_total > 0.00))
------------------PhysicalCteConsumer ( cteId=CTEId#0 )

Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,27 @@ PhysicalResultSink
--------PhysicalProject
----------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ss_sold_date_sk]
------------PhysicalProject
--------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_addr_sk = customer_address.ca_address_sk)) otherCondition=((((ca_state IN ('KS', 'MI', 'SD') AND ((store_sales.ss_net_profit >= 100.00) AND (store_sales.ss_net_profit <= 200.00))) OR (ca_state IN ('CO', 'MO', 'ND') AND ((store_sales.ss_net_profit >= 150.00) AND (store_sales.ss_net_profit <= 300.00)))) OR (ca_state IN ('NH', 'OH', 'TX') AND ((store_sales.ss_net_profit >= 50.00) AND (store_sales.ss_net_profit <= 250.00))))) build RFs:RF3 ca_address_sk->[ss_addr_sk]
--------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_addr_sk = customer_address.ca_address_sk)) otherCondition=((((ca_state IN ('KS', 'MI', 'SD') AND ((store_sales.ss_net_profit >= 100.00) AND (store_sales.ss_net_profit <= 200.00))) OR (ca_state IN ('CO', 'MO', 'ND') AND ((store_sales.ss_net_profit >= 150.00) AND (store_sales.ss_net_profit <= 300.00)))) OR (ca_state IN ('NH', 'OH', 'TX') AND ((store_sales.ss_net_profit >= 50.00) AND (store_sales.ss_net_profit <= 250.00))))) build RFs:RF3 ss_addr_sk->[ca_address_sk]
----------------PhysicalProject
------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=((((((household_demographics.hd_dep_count = 1) AND cd_marital_status IN ('M', 'S')) AND cd_education_status IN ('4 yr Degree', 'College')) AND ((((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) OR (((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))))) OR ((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)))) build RFs:RF2 hd_demo_sk->[ss_hdemo_sk]
------------------filter((customer_address.ca_country = 'United States') and ca_state IN ('CO', 'KS', 'MI', 'MO', 'ND', 'NH', 'OH', 'SD', 'TX'))
--------------------PhysicalOlapScan[customer_address] apply RFs: RF3
----------------PhysicalProject
------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=((((((household_demographics.hd_dep_count = 1) AND cd_marital_status IN ('M', 'S')) AND cd_education_status IN ('4 yr Degree', 'College')) AND ((((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) OR (((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))))) OR ((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)))) build RFs:RF2 ss_hdemo_sk->[hd_demo_sk]
--------------------PhysicalProject
----------------------filter(hd_dep_count IN (1, 3))
------------------------PhysicalOlapScan[household_demographics] apply RFs: RF2
--------------------PhysicalProject
----------------------hashJoin[INNER_JOIN broadcast] hashCondition=((customer_demographics.cd_demo_sk = store_sales.ss_cdemo_sk)) otherCondition=() build RFs:RF1 cd_demo_sk->[ss_cdemo_sk]
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store.s_store_sk = store_sales.ss_store_sk)) otherCondition=()
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store.s_store_sk = store_sales.ss_store_sk)) otherCondition=() build RFs:RF0 ss_store_sk->[s_store_sk]
----------------------------PhysicalProject
------------------------------filter((store_sales.ss_net_profit <= 300.00) and (store_sales.ss_net_profit >= 50.00) and (store_sales.ss_sales_price <= 200.00) and (store_sales.ss_sales_price >= 50.00))
--------------------------------PhysicalOlapScan[store_sales] apply RFs: RF1 RF2 RF3 RF4
------------------------------PhysicalOlapScan[store] apply RFs: RF0
----------------------------PhysicalProject
------------------------------PhysicalOlapScan[store]
------------------------------filter((store_sales.ss_net_profit <= 300.00) and (store_sales.ss_net_profit >= 50.00) and (store_sales.ss_sales_price <= 200.00) and (store_sales.ss_sales_price >= 50.00))
--------------------------------PhysicalOlapScan[store_sales] apply RFs: RF1 RF4
------------------------PhysicalProject
--------------------------filter(((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) OR ((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College'))) OR ((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree'))) and cd_education_status IN ('4 yr Degree', 'College', 'Unknown') and cd_marital_status IN ('D', 'M', 'S'))
----------------------------PhysicalOlapScan[customer_demographics]
--------------------PhysicalProject
----------------------filter(hd_dep_count IN (1, 3))
------------------------PhysicalOlapScan[household_demographics]
----------------PhysicalProject
------------------filter((customer_address.ca_country = 'United States') and ca_state IN ('CO', 'KS', 'MI', 'MO', 'ND', 'NH', 'OH', 'SD', 'TX'))
--------------------PhysicalOlapScan[customer_address]
------------PhysicalProject
--------------filter((date_dim.d_year = 2001))
----------------PhysicalOlapScan[date_dim]
Expand Down
Loading

0 comments on commit e2305d2

Please sign in to comment.