From 75b71278e729f3bbb5786522ea23b7ff1e73c451 Mon Sep 17 00:00:00 2001 From: seawinde Date: Fri, 22 Dec 2023 16:51:34 +0800 Subject: [PATCH] add infer props_to_expression --- .../rules/rewrite/PredicatePropagation.java | 4 ++- .../nereids/trees/expressions/EqualTo.java | 2 +- .../nereids/trees/expressions/Expression.java | 25 +++++++++++++++++++ .../trees/expressions/GreaterThan.java | 2 +- .../trees/expressions/GreaterThanEqual.java | 2 +- .../nereids/trees/expressions/LessThan.java | 2 +- .../trees/expressions/LessThanEqual.java | 2 +- .../trees/expressions/NullSafeEqual.java | 2 +- .../doris/nereids/util/ExpressionUtils.java | 21 ++++++++++++++++ .../rules/rewrite/InferPredicatesTest.java | 6 +++-- 10 files changed, 59 insertions(+), 9 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index ecb1c5499bde777..aff8b36bc36116f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -88,6 +88,8 @@ public ComparisonInferInfo(InferType inferType, public Set infer(Set predicates) { Set inferred = Sets.newHashSet(); for (Expression predicate : predicates) { + // if we support more predicate infer, we should add .withInferred(this.isInferred()) + // to mark the predicate is from infer when call withChildren method if (!(predicate instanceof ComparisonPredicate)) { continue; } @@ -130,7 +132,7 @@ private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo pr .comparisonPredicate.withChildren(newLeft, newRight); Expression expr = SimplifyComparisonPredicate.INSTANCE .rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate), null); - return DateFunctionRewrite.INSTANCE.rewrite(expr, null); + return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true); } private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java index 3faccff6d99651f..90c59b8b6022d06 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java @@ -47,7 +47,7 @@ public boolean nullable() throws UnboundException { @Override public EqualTo withChildren(List children) { Preconditions.checkArgument(children.size() == 2); - return new EqualTo(children); + return new EqualTo(children).withInferred(this.isInferred()); } public R accept(ExpressionVisitor visitor, C context) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index bdd776ffe97cc51..4148b1c53d9a0cd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -59,6 +59,8 @@ public abstract class Expression extends AbstractTreeNode implements protected Optional exprName = Optional.empty(); private final int depth; private final int width; + // Mark this expression is from predicate infer or something else infer + private boolean inferred; protected Expression(Expression... children) { super(children); @@ -69,6 +71,7 @@ protected Expression(Expression... children) { .mapToInt(e -> e.width) .sum() + (children.length == 0 ? 1 : 0); checkLimit(); + this.inferred = false; } protected Expression(List children) { @@ -80,6 +83,19 @@ protected Expression(List children) { .mapToInt(e -> e.width) .sum() + (children.isEmpty() ? 1 : 0); checkLimit(); + this.inferred = false; + } + + protected Expression(List children, boolean inferred) { + super(children); + depth = children.stream() + .mapToInt(e -> e.depth) + .max().orElse(0) + 1; + width = children.stream() + .mapToInt(e -> e.width) + .sum() + (children.isEmpty() ? 1 : 0); + checkLimit(); + this.inferred = inferred; } private void checkLimit() { @@ -216,11 +232,20 @@ public int getDepth() { return depth; } + public boolean isInferred() { + return inferred; + } + @Override public Expression withChildren(List children) { throw new RuntimeException(); } + public E withInferred(boolean inferred) { + this.inferred = inferred; + return (E) this; + } + /** * Whether the expression is a constant. */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThan.java index 1871781ca9cd106..b7ae656f787345e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThan.java @@ -57,7 +57,7 @@ public String toString() { @Override public GreaterThan withChildren(List children) { Preconditions.checkArgument(children.size() == 2); - return new GreaterThan(children); + return new GreaterThan(children).withInferred(this.isInferred()); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThanEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThanEqual.java index e2995298e3734b7..77827c5c5628c2b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThanEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/GreaterThanEqual.java @@ -52,7 +52,7 @@ public String toString() { @Override public GreaterThanEqual withChildren(List children) { Preconditions.checkArgument(children.size() == 2); - return new GreaterThanEqual(children); + return new GreaterThanEqual(children).withInferred(this.isInferred()); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThan.java index 4d34b50bb9c7b81..d1070fb06424659 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThan.java @@ -51,7 +51,7 @@ public String toString() { @Override public LessThan withChildren(List children) { Preconditions.checkArgument(children.size() == 2); - return new LessThan(children); + return new LessThan(children).withInferred(this.isInferred()); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThanEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThanEqual.java index 4ac997d3ab7b4e4..420b86b44306152 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThanEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/LessThanEqual.java @@ -57,7 +57,7 @@ public String toString() { @Override public LessThanEqual withChildren(List children) { Preconditions.checkArgument(children.size() == 2); - return new LessThanEqual(children); + return new LessThanEqual(children).withInferred(this.isInferred()); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java index 48d05364fa3441c..71cd911af30c314 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java @@ -51,7 +51,7 @@ public R accept(ExpressionVisitor visitor, C context) { @Override public NullSafeEqual withChildren(List children) { Preconditions.checkArgument(children.size() == 2); - return new NullSafeEqual(children); + return new NullSafeEqual(children).withInferred(this.isInferred()); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 70b91dfe1028c1c..e13b3c06240c73a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -48,6 +48,7 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer; @@ -639,4 +640,24 @@ public static boolean checkSlotConstant(Slot slot, Set predicates) { } ); } + + /** + * isInferred + */ + public static boolean isInferred(Expression expression) { + return expression.accept(new DefaultExpressionVisitor() { + @Override + public Boolean visit(Expression expr, Void context) { + boolean inferred = expr.isInferred(); + if (expr.isInferred()) { + return inferred; + } + inferred = true; + for (Expression child : expr.children()) { + inferred = inferred && child.accept(this, context); + } + return inferred; + } + }, null); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java index 1421a912a028de7..16a212728f2c77c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java @@ -87,10 +87,12 @@ public void inferPredicatesTest01() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> !filter.getPredicate().isInferred() + && filter.getPredicate().toSql().contains("id > 1")), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> filter.getPredicate().isInferred() + && filter.getPredicate().toSql().contains("sid > 1")) ) ) );