From 4775c127d93b9ada42589999f2786d6c426f3026 Mon Sep 17 00:00:00 2001 From: seawinde Date: Fri, 15 Dec 2023 10:22:11 +0800 Subject: [PATCH] bitmap roll up develop --- .../analyzer/PlaceholderExpression.java | 20 ++- ...AbstractMaterializedViewAggregateRule.java | 72 ++++++++- .../functions/agg/AggregateFunction.java | 3 +- .../expressions/functions/agg/Count.java | 11 ++ .../trees/expressions/functions/agg/Max.java | 6 + .../trees/expressions/functions/agg/Min.java | 6 + .../mv/MaterializedViewUtilsTest.java | 140 +++++++++++++----- 7 files changed, 217 insertions(+), 41 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java index 8f069b25694f955..2f55ca87965118e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java @@ -39,6 +39,7 @@ public class PlaceholderExpression extends Expression implements AlwaysNotNullab * 1 based */ private final int position; + protected boolean distinct; public PlaceholderExpression(List children, Class delegateClazz, int position) { super(children); @@ -46,10 +47,23 @@ public PlaceholderExpression(List children, Class children, Class delegateClazz, int position, + boolean distinct) { + super(children); + this.delegateClazz = Objects.requireNonNull(delegateClazz, "delegateClazz should not be null"); + this.position = position; + this.distinct = distinct; + } + public static PlaceholderExpression of(Class delegateClazz, int position) { return new PlaceholderExpression(ImmutableList.of(), delegateClazz, position); } + public static PlaceholderExpression of(Class delegateClazz, int position, + boolean distinct) { + return new PlaceholderExpression(ImmutableList.of(), delegateClazz, position, distinct); + } + @Override public R accept(ExpressionVisitor visitor, C 
context) { return visitor.visit(this, context); @@ -63,6 +77,10 @@ public int getPosition() { return position; } + public boolean isDistinct() { + return distinct; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -80,6 +98,6 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(super.hashCode(), delegateClazz, position); + return Objects.hash(super.hashCode(), delegateClazz, position, distinct); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java index 0a5d3f0948c873e..86e2e664e45cfb1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.exploration.mv; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.analyzer.PlaceholderExpression; import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode; @@ -29,18 +30,27 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.Function; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import 
org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapCount; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Multimap; import com.google.common.collect.Sets; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -53,6 +63,17 @@ */ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule { + protected static final Map + AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = new HashMap<>(); + + static { + AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put( + PlaceholderExpression.of(Count.class, 0, true), + new PlaceholderExpression( + ImmutableList.of(PlaceholderExpression.of(ToBitmap.class, 0)), + BitmapUnion.class, 0)); + } + @Override protected Plan rewriteQueryByView(MatchMode matchMode, StructInfo queryStructInfo, @@ -135,14 +156,16 @@ protected Plan rewriteQueryByView(MatchMode matchMode, Expression needRollupShuttledExpr = ExpressionUtils.shuttleExpressionWithLineage( topExpression, queryTopPlan); + if (!mvExprToMvScanExprQueryBased.containsKey(needRollupShuttledExpr)) { // function can not rewrite by view return null; } + // try to roll up AggregateFunction needRollupAggFunction = (AggregateFunction) topExpression.firstMatch( expr -> expr instanceof AggregateFunction); - AggregateFunction rollupAggregateFunction = rollup(needRollupAggFunction, + Function rollupAggregateFunction = rollup(needRollupAggFunction, 
mvExprToMvScanExprQueryBased.get(needRollupShuttledExpr)); if (rollupAggregateFunction == null) { return null; @@ -226,15 +249,24 @@ protected Plan rewriteQueryByView(MatchMode matchMode, } // only support sum roll up, support other agg functions later. - private AggregateFunction rollup(AggregateFunction originFunction, + private Function rollup(AggregateFunction originFunction, Expression mappedExpression) { - Class rollupAggregateFunction = originFunction.getRollup(); + Class rollupAggregateFunction = originFunction.getRollup(); if (rollupAggregateFunction == null) { return null; } if (Sum.class.isAssignableFrom(rollupAggregateFunction)) { return new Sum(originFunction.isDistinct(), mappedExpression); } + if (Max.class.isAssignableFrom(rollupAggregateFunction)) { + return new Max(originFunction.isDistinct(), mappedExpression); + } + if (Min.class.isAssignableFrom(rollupAggregateFunction)) { + return new Min(originFunction.isDistinct(), mappedExpression); + } + if (BitmapCount.class.isAssignableFrom(rollupAggregateFunction)) { + return new BitmapCount(mappedExpression); + } // can rollup return null return null; } @@ -306,4 +338,38 @@ protected boolean checkPattern(StructInfo structInfo) { } return true; } + + private boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) { + Class queryClazz = queryFunction.getClass(); + Class viewClazz = viewFunction.getClass(); + if (queryClazz.isAssignableFrom(viewClazz)) { + return true; + } + boolean isDistinct = queryFunction instanceof AggregateFunction + && ((AggregateFunction) queryFunction).isDistinct(); + PlaceholderExpression equivalentFunction = AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.get( + PlaceholderExpression.of(queryFunction.getClass(), 0, isDistinct)); + // check whether an equivalent function is registered for the query function + if (equivalentFunction == null){ + return false; + } + // compare the class of the view function itself with the equivalent function + if (!viewFunction.getClass().isAssignableFrom(equivalentFunction.getDelegateClazz())) { + return false; +
} + if (!viewFunction.children().isEmpty()) { + // compare the children; only the direct children are compared for now, deeper nesting to be supported later + List equivalentFunctions = equivalentFunction.children(); + if (viewFunction.children().size() != equivalentFunctions.size()) { + return false; + } + for (int i = 0; i < viewFunction.children().size(); i++) { + if (!viewFunction.child(i).getClass().equals( + ((PlaceholderExpression)equivalentFunctions.get(i)).getDelegateClazz())) { + return false; + } + } + } + return true; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java index 61a589daba27819..ce80dec27415fe2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.functions.Function; import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; @@ -77,7 +78,7 @@ public boolean isDistinct() { return distinct; } - public Class getRollup() { + public Class getRollup() { return null; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java index a8a3fdd033ddc49..7176369cfca0a16 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java +++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java @@ -22,6 +22,8 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapCount; import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; @@ -142,4 +144,13 @@ public R accept(ExpressionVisitor visitor, C context) { public List getSignatures() { return SIGNATURES; } + + @Override + public Class getRollup() { + if (this.isDistinct()) { + return BitmapCount.class; + } else { + return Sum.class; + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java index 19cd0190bb6e37f..b32345b46518a1e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.CustomSignature; +import org.apache.doris.nereids.trees.expressions.functions.Function; import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; @@ -80,4 +81,9 @@ public 
NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) { public R accept(ExpressionVisitor visitor, C context) { return visitor.visitMax(this, context); } + + @Override + public Class getRollup() { + return Max.class; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java index 72b2162eb51f6ea..097b1246d1fd0f6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.CustomSignature; +import org.apache.doris.nereids.trees.expressions.functions.Function; import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; @@ -81,4 +82,9 @@ public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) { public R accept(ExpressionVisitor visitor, C context) { return visitor.visitMin(this, context); } + + @Override + public Class getRollup() { + return Min.class; + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtilsTest.java index 2e402cd5c7aa386..4d164ba4cf8e91c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtilsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtilsTest.java @@ -41,46 +41,90 @@ protected void runBeforeAll() 
throws Exception { useDatabase("mv_util_test"); createTable("CREATE TABLE IF NOT EXISTS lineitem (\n" - + " L_ORDERKEY INTEGER NOT NULL,\n" - + " L_PARTKEY INTEGER NOT NULL,\n" - + " L_SUPPKEY INTEGER NOT NULL,\n" - + " L_LINENUMBER INTEGER NOT NULL,\n" - + " L_QUANTITY DECIMALV3(15,2) NOT NULL,\n" - + " L_EXTENDEDPRICE DECIMALV3(15,2) NOT NULL,\n" - + " L_DISCOUNT DECIMALV3(15,2) NOT NULL,\n" - + " L_TAX DECIMALV3(15,2) NOT NULL,\n" - + " L_RETURNFLAG CHAR(1) NOT NULL,\n" - + " L_LINESTATUS CHAR(1) NOT NULL,\n" - + " L_SHIPDATE DATE NOT NULL,\n" - + " L_COMMITDATE DATE NOT NULL,\n" - + " L_RECEIPTDATE DATE NOT NULL,\n" - + " L_SHIPINSTRUCT CHAR(25) NOT NULL,\n" - + " L_SHIPMODE CHAR(10) NOT NULL,\n" - + " L_COMMENT VARCHAR(44) NOT NULL\n" + + " l_orderkey integer not null,\n" + + " l_partkey integer not null,\n" + + " l_suppkey integer not null,\n" + + " l_linenumber integer not null,\n" + + " l_quantity decimalv3(15,2) not null,\n" + + " l_extendedprice decimalv3(15,2) not null,\n" + + " l_discount decimalv3(15,2) not null,\n" + + " l_tax decimalv3(15,2) not null,\n" + + " l_returnflag char(1) not null,\n" + + " l_linestatus char(1) not null,\n" + + " l_shipdate date not null,\n" + + " l_commitdate date not null,\n" + + " l_receiptdate date not null,\n" + + " l_shipinstruct char(25) not null,\n" + + " l_shipmode char(10) not null,\n" + + " l_comment varchar(44) not null\n" + ")\n" - + "DUPLICATE KEY(L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER)\n" - + "PARTITION BY RANGE(L_SHIPDATE) (PARTITION `day_1` VALUES LESS THAN ('2017-02-01'))\n" - + "DISTRIBUTED BY HASH(L_ORDERKEY) BUCKETS 3\n" + + "DUPLICATE KEY(l_orderkey, l_partkey, l_suppkey, l_linenumber)\n" + + "PARTITION BY RANGE(l_shipdate) \n" + + "(FROM ('2023-10-17') TO ('2023-10-20') INTERVAL 1 DAY)\n" + + "DISTRIBUTED BY HASH(l_orderkey) BUCKETS 3\n" + "PROPERTIES (\n" + " \"replication_num\" = \"1\"\n" - + ")"); + + ");"); + // createTable("CREATE TABLE IF NOT EXISTS lineitem (\n" + // + " L_ORDERKEY 
INTEGER NOT NULL,\n" + // + " L_PARTKEY INTEGER NOT NULL,\n" + // + " L_SUPPKEY INTEGER NOT NULL,\n" + // + " L_LINENUMBER INTEGER NOT NULL,\n" + // + " L_QUANTITY DECIMALV3(15,2) NOT NULL,\n" + // + " L_EXTENDEDPRICE DECIMALV3(15,2) NOT NULL,\n" + // + " L_DISCOUNT DECIMALV3(15,2) NOT NULL,\n" + // + " L_TAX DECIMALV3(15,2) NOT NULL,\n" + // + " L_RETURNFLAG CHAR(1) NOT NULL,\n" + // + " L_LINESTATUS CHAR(1) NOT NULL,\n" + // + " L_SHIPDATE DATE NOT NULL,\n" + // + " L_COMMITDATE DATE NOT NULL,\n" + // + " L_RECEIPTDATE DATE NOT NULL,\n" + // + " L_SHIPINSTRUCT CHAR(25) NOT NULL,\n" + // + " L_SHIPMODE CHAR(10) NOT NULL,\n" + // + " L_COMMENT VARCHAR(44) NOT NULL\n" + // + ")\n" + // + "DUPLICATE KEY(L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER)\n" + // + "PARTITION BY RANGE(L_SHIPDATE) (FROM ('2023-10-17') TO ('2023-10-20') INTERVAL 1 DAY)\n" + // + "DISTRIBUTED BY HASH(L_ORDERKEY) BUCKETS 3\n" + // + "PROPERTIES (\n" + // + " \"replication_num\" = \"1\"\n" + // + ")"); + // createTable("CREATE TABLE IF NOT EXISTS orders (\n" + // + " O_ORDERKEY INTEGER NOT NULL,\n" + // + " O_CUSTKEY INTEGER NOT NULL,\n" + // + " O_ORDERSTATUS CHAR(1) NOT NULL,\n" + // + " O_TOTALPRICE DECIMALV3(15,2) NOT NULL,\n" + // + " O_ORDERDATE DATE NOT NULL,\n" + // + " O_ORDERPRIORITY CHAR(15) NOT NULL, \n" + // + " O_CLERK CHAR(15) NOT NULL, \n" + // + " O_SHIPPRIORITY INTEGER NOT NULL,\n" + // + " O_COMMENT VARCHAR(79) NOT NULL\n" + // + ")\n" + // + "DUPLICATE KEY(O_ORDERKEY, O_CUSTKEY)\n" + // + "PARTITION BY RANGE(O_ORDERDATE) (FROM ('2023-10-17') TO ('2023-10-20') INTERVAL 1 DAY)\n" + // + "DISTRIBUTED BY HASH(O_ORDERKEY) BUCKETS 3\n" + // + "PROPERTIES (\n" + // + " \"replication_num\" = \"1\"\n" + // + ")"); createTable("CREATE TABLE IF NOT EXISTS orders (\n" - + " O_ORDERKEY INTEGER NOT NULL,\n" - + " O_CUSTKEY INTEGER NOT NULL,\n" - + " O_ORDERSTATUS CHAR(1) NOT NULL,\n" - + " O_TOTALPRICE DECIMALV3(15,2) NOT NULL,\n" - + " O_ORDERDATE DATE NOT NULL,\n" - + " O_ORDERPRIORITY 
CHAR(15) NOT NULL, \n" - + " O_CLERK CHAR(15) NOT NULL, \n" - + " O_SHIPPRIORITY INTEGER NOT NULL,\n" - + " O_COMMENT VARCHAR(79) NOT NULL\n" - + ")\n" - + "DUPLICATE KEY(O_ORDERKEY, O_CUSTKEY)\n" - + "PARTITION BY RANGE(O_ORDERDATE) (PARTITION `day_2` VALUES LESS THAN ('2017-03-01'))\n" - + "DISTRIBUTED BY HASH(O_ORDERKEY) BUCKETS 3\n" - + "PROPERTIES (\n" - + " \"replication_num\" = \"1\"\n" - + ")"); + + " o_orderkey integer not null,\n" + + " o_custkey integer not null,\n" + + " o_orderstatus char(1) not null,\n" + + " o_totalprice decimalv3(15,2) not null,\n" + + " o_orderdate date not null,\n" + + " o_orderpriority char(15) not null, \n" + + " o_clerk char(15) not null, \n" + + " o_shippriority integer not null,\n" + + " o_comment varchar(79) not null\n" + + " )\n" + + " DUPLICATE KEY(o_orderkey, o_custkey)\n" + + " PARTITION BY RANGE(o_orderdate)(\n" + + " FROM ('2023-10-17') TO ('2023-10-20') INTERVAL 1 DAY\n" + + " )\n" + + " DISTRIBUTED BY HASH(o_orderkey) BUCKETS 3\n" + + " PROPERTIES (\n" + + " \"replication_num\" = \"1\"\n" + + " );"); createTable("CREATE TABLE IF NOT EXISTS partsupp (\n" + " PS_PARTKEY INTEGER NOT NULL,\n" + " PS_SUPPKEY INTEGER NOT NULL,\n" @@ -93,6 +137,8 @@ protected void runBeforeAll() throws Exception { + "PROPERTIES (\n" + " \"replication_num\" = \"1\"\n" + ")"); + + connectContext.getSessionVariable().enableNereidsTimeout = false; } @Test @@ -243,6 +289,28 @@ public void getRelatedTableInfoTestWithWindowButNotPartitionTest() { }); } + @Test + public void getRelatedTableInfoWithLeftJoinTest() { + PlanChecker.from(connectContext) + .checkExplain("select l_shipdate, o_orderdate, l_partkey, l_suppkey, sum(o_totalprice) as sum_total\n" + + " from lineitem\n" + + " left join orders on lineitem.l_orderkey = orders.o_orderkey and l_shipdate = o_orderdate\n" + + " group by\n" + + " l_shipdate,\n" + + " o_orderdate,\n" + + " l_partkey,\n" + + " l_suppkey;", + nereidsPlanner -> { + Plan rewrittenPlan = nereidsPlanner.getRewrittenPlan(); + 
Optional relatedTableInfo = + MaterializedViewUtils.getRelatedTableInfo("o_orderdate", rewrittenPlan); + checkRelatedTableInfo(relatedTableInfo, + "orders", + "o_orderdate", + true); + }); + } + @Test public void containTableQueryOperatorWithTabletTest() { PlanChecker.from(connectContext)