diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java index 356941694988704..7698d881c661aa7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java @@ -42,7 +42,9 @@ public abstract class Edge { // added by the graph simplifier. private final long leftRequiredNodes; private final long rightRequiredNodes; + // The nodes needed to prevent wrong association or l-association private long leftExtendedNodes; + // The nodes needed to prevent wrong association or r-association private long rightExtendedNodes; // record the left child edges and right child edges in origin plan tree @@ -53,8 +55,11 @@ public abstract class Edge { private final BitSet curOperatorEdges = new BitSet(); // record all sub nodes behind in this operator. 
It's T function in paper private final long subTreeNodes; - + // The edges which prevent association or l-association when join edge + // and prevent push down or pull up when filter edge on the left of the edge private final Set leftRejectEdges; + // The edges which prevent association or r-association + // and prevent push down or pull up when filter edge on the right of the edge private final Set rightRejectEdges; /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java index 868f97949c07054..d4594583c314c23 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java @@ -35,12 +35,16 @@ import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; import com.google.common.collect.Sets; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -411,25 +415,72 @@ private Map constructQueryToViewJoinMapWithExpr() { return edgeMap; } + // For example, the filters below have the same expression, but should be different filter edges + // Only construct edges that can be mapped, edges that can not be mapped would be handled by buildComparisonRes + // LogicalJoin[569] + // |--LogicalProject[567] + // | +--LogicalFilter[566] ( predicates=(l_orderkey#10 IS NULL OR ( not (l_orderkey#10 = 1))) ) + // | +--LogicalJoin[565] + // | |--LogicalProject[562] + // | | +--LogicalOlapScan + // | 
+--LogicalProject[564] + // | +--LogicalFilter[563] ( predicates=(l_orderkey#10 IS NULL OR ( not (l_orderkey#10 = 1)))) + // | +--LogicalOlapScan + // +--LogicalProject[568] + // +--LogicalOlapScan private Map constructQueryToViewFilterMapWithExpr() { - Map viewExprToEdge = getViewFilterEdges().stream() + Multimap viewExprToEdge = HashMultimap.create(); + getViewFilterEdges().stream() .flatMap(e -> e.getExpressions().stream().map(expr -> Pair.of(expr, e))) - .collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second)); - Map queryExprToEdge = getQueryFilterEdges().stream() + .forEach(pair -> viewExprToEdge.put(pair.key(), pair.value())); + + Multimap queryExprToEdge = HashMultimap.create(); + getQueryFilterEdges().stream() .flatMap(e -> e.getExpressions().stream().map(expr -> Pair.of(expr, e))) - .collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second)); + .forEach(pair -> queryExprToEdge.put(pair.key(), pair.value())); - HashMap edgeMap = new HashMap<>(); - for (Entry entry : queryExprToEdge.entrySet()) { - if (edgeMap.containsKey(entry.getValue())) { + HashMap queryToViewEdgeMap = new HashMap<>(); + for (Entry> entry : queryExprToEdge.asMap().entrySet()) { + Expression queryExprViewBased = logicalCompatibilityContext.getViewFilterExprFromQuery(entry.getKey()); + if (queryExprViewBased == null) { continue; } - Expression viewExpr = logicalCompatibilityContext.getViewFilterExprFromQuery(entry.getKey()); - if (viewExprToEdge.containsKey(viewExpr)) { - edgeMap.put(entry.getValue(), Objects.requireNonNull(viewExprToEdge.get(viewExpr))); + Collection viewEdges = viewExprToEdge.get(queryExprViewBased); + if (viewEdges.isEmpty()) { + continue; + } + for (Edge queryEdge : entry.getValue()) { + for (Edge viewEdge : viewEdges) { + if (!isSubTreeNodesEquals(queryEdge, viewEdge, logicalCompatibilityContext)) { + // For example, query filter edge is <{1} --FILTER-- {}> but view filter edge is + // <{0, 1} --FILTER-- {}>, though they are both + // l_orderkey#10 IS NULL 
OR ( not (l_orderkey#10 = 1)) but they are actually different + continue; + } + queryToViewEdgeMap.put(queryEdge, viewEdge); + } } } - return edgeMap; + return queryToViewEdgeMap; + } + + private static boolean isSubTreeNodesEquals(Edge queryEdge, Edge viewEdge, + LogicalCompatibilityContext logicalCompatibilityContext) { + if (!(queryEdge instanceof FilterEdge) || !(viewEdge instanceof FilterEdge)) { + return false; + } + // The query edge's subTreeNodes, mapped to view node ids, must equal the view edge's subTreeNodes + BiMap queryToViewNodeIdMapping = + logicalCompatibilityContext.getQueryToViewNodeIDMapping(); + List queryNodeIndexViewBasedList = new ArrayList<>(); + for (int queryNodeIndex : LongBitmap.getIterator(queryEdge.getSubTreeNodes())) { + Integer queryNodeIndexViewBased = queryToViewNodeIdMapping.get(queryNodeIndex); + if (queryNodeIndexViewBased == null) { + return false; + } + queryNodeIndexViewBasedList.add(queryNodeIndexViewBased); + } + return LongBitmap.newBitmap(queryNodeIndexViewBasedList) == viewEdge.getSubTreeNodes(); } private void refreshViewEdges() { diff --git a/regression-test/data/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.out b/regression-test/data/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.out index 3b9c3a1219ea15c..5c9df6b7f92256a 100644 --- a/regression-test/data/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.out +++ b/regression-test/data/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.out @@ -315,3 +315,11 @@ a 3 3 a,a,a 4.0 yy 3 1 a 4 2 a,a 4.0 yy 2 1 c 3 6 c,c,c 5.333333333333333 mi 3 2 +-- !query28_0_before -- +1 2023-12-09 1 yy 2 2 2 4 3 \N 2 3 \N \N 8 8 1 +1 2023-12-09 1 yy 2 2 2 4 3 \N 2 3 1 2 8 8 1 + +-- !query28_0_after -- +1 2023-12-09 1 yy 2 2 2 4 3 \N 2 3 \N \N 8 8 1 +1 2023-12-09 1 yy 2 2 2 4 3 \N 2 3 1 2 8 8 1 + diff --git a/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy 
b/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy index 9d60280503cf748..13ed0ca7dd5d804 100644 --- a/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy +++ b/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy @@ -36,7 +36,8 @@ suite("aggregate_without_roll_up") { o_orderpriority CHAR(15) NOT NULL, o_clerk CHAR(15) NOT NULL, o_shippriority INTEGER NOT NULL, - o_comment VARCHAR(79) NOT NULL + o_comment VARCHAR(79) NOT NULL, + public_col INT NULL ) DUPLICATE KEY(o_orderkey, o_custkey) DISTRIBUTED BY HASH(o_orderkey) BUCKETS 3 @@ -66,7 +67,8 @@ suite("aggregate_without_roll_up") { l_receiptdate DATE NOT NULL, l_shipinstruct CHAR(25) NOT NULL, l_shipmode CHAR(10) NOT NULL, - l_comment VARCHAR(44) NOT NULL + l_comment VARCHAR(44) NOT NULL, + public_col INT NULL ) DUPLICATE KEY(l_orderkey, l_partkey, l_suppkey, l_linenumber) DISTRIBUTED BY HASH(l_orderkey) BUCKETS 3 @@ -85,7 +87,8 @@ suite("aggregate_without_roll_up") { ps_suppkey INTEGER NOT NULL, ps_availqty INTEGER NOT NULL, ps_supplycost DECIMALV3(15,2) NOT NULL, - ps_comment VARCHAR(199) NOT NULL + ps_comment VARCHAR(199) NOT NULL, + public_col INT NULL ) DUPLICATE KEY(ps_partkey, ps_suppkey) DISTRIBUTED BY HASH(ps_partkey) BUCKETS 3 @@ -96,29 +99,29 @@ suite("aggregate_without_roll_up") { sql """ insert into lineitem values - (1, 2, 3, 4, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-12-08', '2023-12-09', '2023-12-10', 'a', 'b', 'yyyyyyyyy'), - (2, 4, 3, 4, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-12-09', '2023-12-09', '2023-12-10', 'a', 'b', 'yyyyyyyyy'), - (3, 2, 4, 4, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-12-10', '2023-12-09', '2023-12-10', 'a', 'b', 'yyyyyyyyy'), - (4, 3, 3, 4, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-12-11', '2023-12-09', '2023-12-10', 'a', 'b', 'yyyyyyyyy'), - (5, 2, 3, 6, 7.5, 8.5, 9.5, 10.5, 'k', 'o', '2023-12-12', '2023-12-12', '2023-12-13', 'c', 'd', 'xxxxxxxxx'); + (1, 2, 
3, 4, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-12-08', '2023-12-09', '2023-12-10', 'a', 'b', 'yyyyyyyyy', 1), + (2, 4, 3, 4, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-12-09', '2023-12-09', '2023-12-10', 'a', 'b', 'yyyyyyyyy', null), + (3, 2, 4, 4, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-12-10', '2023-12-09', '2023-12-10', 'a', 'b', 'yyyyyyyyy', 2), + (4, 3, 3, 4, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-12-11', '2023-12-09', '2023-12-10', 'a', 'b', 'yyyyyyyyy', null), + (5, 2, 3, 6, 7.5, 8.5, 9.5, 10.5, 'k', 'o', '2023-12-12', '2023-12-12', '2023-12-13', 'c', 'd', 'xxxxxxxxx', 3); """ sql """ insert into orders values - (1, 1, 'o', 9.5, '2023-12-08', 'a', 'b', 1, 'yy'), - (1, 1, 'o', 10.5, '2023-12-08', 'a', 'b', 1, 'yy'), - (2, 1, 'o', 11.5, '2023-12-09', 'a', 'b', 1, 'yy'), - (3, 1, 'o', 12.5, '2023-12-10', 'a', 'b', 1, 'yy'), - (3, 1, 'o', 33.5, '2023-12-10', 'a', 'b', 1, 'yy'), - (4, 2, 'o', 43.2, '2023-12-11', 'c','d',2, 'mm'), - (5, 2, 'o', 56.2, '2023-12-12', 'c','d',2, 'mi'), - (5, 2, 'o', 1.2, '2023-12-12', 'c','d',2, 'mi'); + (1, 1, 'o', 9.5, '2023-12-08', 'a', 'b', 1, 'yy', 1), + (1, 1, 'o', 10.5, '2023-12-08', 'a', 'b', 1, 'yy', null), + (2, 1, 'o', 11.5, '2023-12-09', 'a', 'b', 1, 'yy', 2), + (3, 1, 'o', 12.5, '2023-12-10', 'a', 'b', 1, 'yy', null), + (3, 1, 'o', 33.5, '2023-12-10', 'a', 'b', 1, 'yy', 3), + (4, 2, 'o', 43.2, '2023-12-11', 'c','d',2, 'mm', null), + (5, 2, 'o', 56.2, '2023-12-12', 'c','d',2, 'mi', 4), + (5, 2, 'o', 1.2, '2023-12-12', 'c','d',2, 'mi', null); """ sql """ insert into partsupp values - (2, 3, 9, 10.01, 'supply1'), - (2, 3, 10, 11.01, 'supply2'); + (2, 3, 9, 10.01, 'supply1', 1), + (2, 3, 10, 11.01, 'supply2', null); """ // single table @@ -1356,4 +1359,156 @@ suite("aggregate_without_roll_up") { """ async_mv_rewrite_fail(db, mv27_0, query27_0, "mv27_0") sql """ DROP MATERIALIZED VIEW IF EXISTS mv27_0""" + + + // query and mv have the same filter but at different positions, should rewrite successfully + def mv28_0 = """ + select + o_custkey, 
+ o_orderdate, + o_shippriority, + o_comment, + o_orderkey, + orders.public_col as col1, + l_orderkey, + l_partkey, + l_suppkey, + lineitem.public_col as col2, + ps_partkey, + ps_suppkey, + partsupp.public_col as col3, + partsupp.public_col * 2 as col4, + o_orderkey + l_orderkey + ps_partkey * 2, + sum( + o_orderkey + l_orderkey + ps_partkey * 2 + ), + count() as count_all + from + ( + select + o_custkey, + o_orderdate, + o_shippriority, + o_comment, + o_orderkey, + orders.public_col as public_col + from + orders + ) orders + left join ( + select + l_orderkey, + l_partkey, + l_suppkey, + lineitem.public_col as public_col + from + lineitem + where + l_orderkey is null + or l_orderkey <> 8 + ) lineitem on l_orderkey = o_orderkey + inner join ( + select + ps_partkey, + ps_suppkey, + partsupp.public_col as public_col + from + partsupp + ) partsupp on ps_partkey = o_orderkey + where + l_orderkey is null + or l_orderkey <> 8 + group by + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14; + """ + def query28_0 = """ + select + o_custkey, + o_orderdate, + o_shippriority, + o_comment, + o_orderkey, + orders.public_col as col1, + l_orderkey, + l_partkey, + l_suppkey, + lineitem.public_col as col2, + ps_partkey, + ps_suppkey, + partsupp.public_col as col3, + partsupp.public_col * 2 as col4, + o_orderkey + l_orderkey + ps_partkey * 2, + sum( + o_orderkey + l_orderkey + ps_partkey * 2 + ), + count() as count_all + from + ( + select + o_custkey, + o_orderdate, + o_shippriority, + o_comment, + o_orderkey, + orders.public_col as public_col + from + orders + ) orders + left join ( + select + l_orderkey, + l_partkey, + l_suppkey, + lineitem.public_col as public_col + from + lineitem + where + l_orderkey is null + or l_orderkey <> 8 + ) lineitem on l_orderkey = o_orderkey + inner join ( + select + ps_partkey, + ps_suppkey, + partsupp.public_col as public_col + from + partsupp + ) partsupp on ps_partkey = o_orderkey + where + l_orderkey is null + or l_orderkey <> 8 
+ group by + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14; + """ + order_qt_query28_0_before "${query28_0}" + async_mv_rewrite_success(db, mv28_0, query28_0, "mv28_0") + order_qt_query28_0_after "${query28_0}" + sql """ DROP MATERIALIZED VIEW IF EXISTS mv28_0""" + }