Skip to content

Commit

Permalink
[improve](agg)support push down min/max on unique table
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangstar333 committed Dec 28, 2023
1 parent 065eb9a commit 0fb732e
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ public enum RuleType {
STORAGE_LAYER_AGGREGATE_WITH_PROJECT(RuleTypeClass.IMPLEMENTATION),
STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT_FOR_FILE_SCAN(RuleTypeClass.IMPLEMENTATION),
STORAGE_LAYER_AGGREGATE_WITH_PROJECT_FOR_FILE_SCAN(RuleTypeClass.IMPLEMENTATION),
MINMAX_ON_UNIQUE_TABLE(RuleTypeClass.IMPLEMENTATION),
MINMAX_ON_UNIQUE_TABLE_WITHOUT_PROJECT(RuleTypeClass.IMPLEMENTATION),
COUNT_ON_INDEX(RuleTypeClass.IMPLEMENTATION),
COUNT_ON_INDEX_WITHOUT_PROJECT(RuleTypeClass.IMPLEMENTATION),
ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT(RuleTypeClass.IMPLEMENTATION),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,15 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
Expand Down Expand Up @@ -140,6 +143,48 @@ public List<Rule> buildRules() {
return pushdownCountOnIndex(agg, project, filter, olapScan, ctx.cascadesContext);
})
),
RuleType.MINMAX_ON_UNIQUE_TABLE_WITHOUT_PROJECT.build(
logicalAggregate(
logicalFilter(
logicalOlapScan().when(this::isUniqueKeyTable))
.when(filter -> filter.getConjuncts().size() == 1))
.when(agg -> enablePushDownMinMaxOnUnique())
.when(agg -> agg.getGroupByExpressions().size() == 0)
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> (f instanceof Min) || (f instanceof Max));
})
.thenApply(ctx -> {
LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
LogicalFilter<LogicalOlapScan> filter = agg.child();
LogicalOlapScan olapScan = filter.child();
return pushdownMinMaxOnUniqueTable(agg, null, filter, olapScan,
ctx.cascadesContext);
})
),
RuleType.MINMAX_ON_UNIQUE_TABLE.build(
logicalAggregate(
logicalProject(
logicalFilter(
logicalOlapScan().when(this::isUniqueKeyTable))
.when(filter -> filter.getConjuncts().size() == 1)))
.when(agg -> enablePushDownMinMaxOnUnique())
.when(agg -> agg.getGroupByExpressions().size() == 0)
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty()
&& funcs.stream().allMatch(f -> (f instanceof Min) || (f instanceof Max));
})
.thenApply(ctx -> {
LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
LogicalFilter<LogicalOlapScan> filter = project.child();
LogicalOlapScan olapScan = filter.child();
return pushdownMinMaxOnUniqueTable(agg, project, filter, olapScan,
ctx.cascadesContext);
})
),
RuleType.STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT.build(
logicalAggregate(
logicalOlapScan()
Expand Down Expand Up @@ -238,6 +283,19 @@ && couldConvertToMulti(agg))
);
}

private boolean enablePushDownMinMaxOnUnique() {
ConnectContext connectContext = ConnectContext.get();
return connectContext != null && connectContext.getSessionVariable().isEnablePushDownMinMaxOnUnique();
}

private boolean isUniqueKeyTable(LogicalOlapScan logicalScan) {
if (logicalScan != null) {
KeysType keysType = logicalScan.getTable().getKeysType();
return keysType == KeysType.UNIQUE_KEYS;
}
return false;
}

private boolean enablePushDownCountOnIndex() {
ConnectContext connectContext = ConnectContext.get();
return connectContext != null && connectContext.getSessionVariable().isEnablePushDownCountOnIndex();
Expand Down Expand Up @@ -314,6 +372,106 @@ private LogicalAggregate<? extends Plan> pushdownCountOnIndex(
}
}

//select /*+SET_VAR(enable_pushdown_minmax_on_unique=true) */min(user_id) from table_unique;
//push pushAggOp=MINMAX to scan node
private LogicalAggregate<? extends Plan> pushdownMinMaxOnUniqueTable(
LogicalAggregate<? extends Plan> aggregate,
@Nullable LogicalProject<? extends Plan> project,
LogicalFilter<? extends Plan> filter,
LogicalOlapScan olapScan,
CascadesContext cascadesContext) {
final LogicalAggregate<? extends Plan> canNotPush = aggregate;
Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
if (checkWhetherPushDownMinMax(aggregateFunctions, project, olapScan.getOutput())) {
PhysicalOlapScan physicalOlapScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan()
.build()
.transform(olapScan, cascadesContext)
.get(0);
if (project != null) {
return aggregate.withChildren(ImmutableList.of(
project.withChildren(ImmutableList.of(
filter.withChildren(ImmutableList.of(
new PhysicalStorageLayerAggregate(
physicalOlapScan,
PushDownAggOp.MIN_MAX)))))));
} else {
return aggregate.withChildren(ImmutableList.of(
filter.withChildren(ImmutableList.of(
new PhysicalStorageLayerAggregate(
physicalOlapScan,
PushDownAggOp.MIN_MAX)))));
}
} else {
return canNotPush;
}
}

private boolean checkWhetherPushDownMinMax(Set<AggregateFunction> aggregateFunctions,
@Nullable LogicalProject<? extends Plan> project, List<Slot> outPutSlots) {
boolean onlyContainsSlotOrNumericCastSlot = aggregateFunctions.stream()
.map(ExpressionTrait::getArguments)
.flatMap(List::stream)
.allMatch(argument -> {
if (argument instanceof SlotReference) {
return true;
}
if (argument instanceof Cast) {
return argument.child(0) instanceof SlotReference
&& argument.getDataType().isNumericType()
&& argument.child(0).getDataType().isNumericType();
}
return false;
});
if (!onlyContainsSlotOrNumericCastSlot) {
return false;
}
List<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
.flatMap(aggregateFunction -> aggregateFunction.getArguments().stream())
.collect(ImmutableList.toImmutableList());

if (project != null) {
argumentsOfAggregateFunction = Project.findProject(
argumentsOfAggregateFunction, project.getProjects())
.stream()
.map(p -> p instanceof Alias ? p.child(0) : p)
.collect(ImmutableList.toImmutableList());
}
onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction
.stream()
.allMatch(argument -> {
if (argument instanceof SlotReference) {
return true;
}
if (argument instanceof Cast) {
return argument.child(0) instanceof SlotReference
&& argument.getDataType().isNumericType()
&& argument.child(0).getDataType().isNumericType();
}
return false;
});
if (!onlyContainsSlotOrNumericCastSlot) {
return false;
}
Set<SlotReference> aggUsedSlots = ExpressionUtils.collect(argumentsOfAggregateFunction,
SlotReference.class::isInstance);
List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots,
outPutSlots);
for (SlotReference slot : usedSlotInTable) {
Column column = slot.getColumn().get();
// The zone map max length of CharFamily is 512, do not
// over the length: https://github.com/apache/doris/pull/6293
PrimitiveType colType = column.getType().getPrimitiveType();
if (colType.isComplexType() || colType.isHllType() || colType.isBitmapType()
|| colType == PrimitiveType.STRING) {
return false;
}
if (colType.isCharFamily() && column.getType().getLength() > 512) {
return false;
}
}
return true;
}

/**
* sql: select count(*) from tbl
* <p>
Expand Down
15 changes: 15 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,8 @@ public class SessionVariable implements Serializable, Writable {
public static final String MATERIALIZED_VIEW_REWRITE_ENABLE_CONTAIN_FOREIGN_TABLE
= "materialized_view_rewrite_enable_contain_foreign_table";

public static final String ENABLE_PUSHDOWN_MINMAX_ON_UNIQUE = "enable_pushdown_minmax_on_unique";

// When set use fix replica = true, the fixed replica maybe bad, try to use the health one if
// this session variable is set to true.
public static final String FALLBACK_OTHER_REPLICA_WHEN_FIXED_CORRUPT = "fallback_other_replica_when_fixed_corrupt";
Expand Down Expand Up @@ -1200,6 +1202,11 @@ public void setEnableLeftZigZag(boolean enableLeftZigZag) {
"是否启用count_on_index pushdown。", "Set whether to pushdown count_on_index."})
public boolean enablePushDownCountOnIndex = true;

// Whether enable pushdown minmax to scan node of unique table.
@VariableMgr.VarAttr(name = ENABLE_PUSHDOWN_MINMAX_ON_UNIQUE, needForward = true, description = {
"是否启用pushdown minmax on unique table。", "Set whether to pushdown minmax on unique table."})
public boolean enablePushDownMinMaxOnUnique = false;

// Whether drop table when create table as select insert data appear error.
@VariableMgr.VarAttr(name = DROP_TABLE_IF_CTAS_FAILED, needForward = true)
public boolean dropTableIfCtasFailed = true;
Expand Down Expand Up @@ -2417,6 +2424,14 @@ public void setDisableJoinReorder(boolean disableJoinReorder) {
this.disableJoinReorder = disableJoinReorder;
}

public boolean isEnablePushDownMinMaxOnUnique() {
return enablePushDownMinMaxOnUnique;
}

public void setEnablePushDownMinMaxOnUnique(boolean enablePushDownMinMaxOnUnique) {
this.enablePushDownMinMaxOnUnique = enablePushDownMinMaxOnUnique;
}

/**
* Nereids only support vectorized engine.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,54 @@ suite("test_pushdown_explain") {
sql("select count(cast(lo_orderkey as bigint)) from test_lineorder;")
contains "pushAggOp=COUNT"
}

sql "DROP TABLE IF EXISTS table_unique"
sql """
CREATE TABLE `table_unique` (
`user_id` LARGEINT NOT NULL COMMENT '\"用户id\"',
`username` VARCHAR(50) NOT NULL COMMENT '\"用户昵称\"'
) ENGINE=OLAP
UNIQUE KEY(`user_id`, `username`)
COMMENT 'OLAP'
DISTRIBUTED BY HASH(`user_id`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);
"""

sql "set enable_pushdown_minmax_on_unique = true;"
explain {
sql("select min(user_id) from table_unique;")
contains "pushAggOp=MINMAX"
}
explain {
sql("select max(user_id) from table_unique;")
contains "pushAggOp=MINMAX"
}
explain {
sql("select min(username) from table_unique;")
contains "pushAggOp=MINMAX"
}
explain {
sql("select max(username) from table_unique;")
contains "pushAggOp=MINMAX"
}

sql "set enable_pushdown_minmax_on_unique = false;"
explain {
sql("select min(user_id) from table_unique;")
contains "pushAggOp=NONE"
}
explain {
sql("select max(user_id) from table_unique;")
contains "pushAggOp=NONE"
}
explain {
sql("select min(username) from table_unique;")
contains "pushAggOp=NONE"
}
explain {
sql("select max(username) from table_unique;")
contains "pushAggOp=NONE"
}
}

0 comments on commit 0fb732e

Please sign in to comment.