diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index e13b3c06240c73a..51f481b51ac32c2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -646,10 +646,11 @@ public static boolean checkSlotConstant(Slot slot, Set predicates) { */ public static boolean isInferred(Expression expression) { return expression.accept(new DefaultExpressionVisitor() { + @Override public Boolean visit(Expression expr, Void context) { boolean inferred = expr.isInferred(); - if (expr.isInferred()) { + if (expr.isInferred() || expr.children().isEmpty()) { return inferred; } inferred = true; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java index 16a212728f2c77c..33566437837c1b9 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; @@ -87,12 +88,12 @@ public void inferPredicatesTest01() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> !filter.getPredicate().isInferred() - && filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().isInferred() - && filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ); @@ -127,7 +128,8 @@ public void inferPredicatesTest03() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), logicalOlapScan() ) ) @@ -166,10 +168,12 @@ public void inferPredicatesTest05() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ), logicalFilter( logicalOlapScan() @@ -192,10 +196,12 @@ public void inferPredicatesTest06() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ), logicalFilter( logicalOlapScan() @@ -217,10 +223,12 @@ public void inferPredicatesTest07() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ); @@ -258,10 +266,12 @@ public void inferPredicatesTest09() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ); @@ -280,11 +290,13 @@ public void inferPredicatesTest10() { logicalProject( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")) + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")) ), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ); @@ -323,13 +335,15 @@ public void inferPredicatesTest12() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filer -> filer.getPredicate().toSql().contains("id > 1")), + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalProject( logicalAggregate( logicalProject( logicalFilter( logicalOlapScan() - ).when(filer -> filer.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ) @@ -351,11 +365,13 @@ public void inferPredicatesTest13() { logicalProject( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id = 1")) + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id = 1")) ), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid = 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid = 1")) ) ) ); @@ -373,11 +389,13 @@ public void inferPredicatesTest14() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalProject( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ) @@ -396,11 +414,13 @@ public void inferPredicatesTest15() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalProject( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ) @@ -461,11 +481,13 @@ public void inferPredicatesTest18() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalProject( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) )