diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java index 33566437837c1b9..942ca01e143bcf3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java @@ -148,7 +148,8 @@ public void inferPredicatesTest04() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), logicalOlapScan() ) ) @@ -247,7 +248,8 @@ public void inferPredicatesTest08() { logicalOlapScan(), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ); @@ -317,7 +319,8 @@ public void inferPredicatesTest11() { ), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ); @@ -441,7 +444,8 @@ public void inferPredicatesTest16() { logicalProject( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ) @@ -462,7 +466,8 @@ public void inferPredicatesTest17() { logicalProject( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ) @@ -528,19 +533,22 @@ public void inferPredicatesTest19() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("k1 = 3")), + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("k1 = 3")), logicalProject( logicalJoin( logicalJoin( logicalProject( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("k3 = 3")) + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("k3 = 3")) ), logicalProject( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("k1 = 3")) + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("k1 = 3")) ) ), logicalAggregate( @@ -568,10 +576,12 @@ public void inferPredicatesTest20() { innerLogicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ), logicalFilter( logicalOlapScan() @@ -594,10 +604,12 @@ public void inferPredicatesTest21() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ), logicalFilter( logicalOlapScan() @@ -622,11 +634,13 @@ public void inferPredicatesTest22() { logicalJoin( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("id > 1")), + ).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("id > 1")), logicalProject( logicalFilter( logicalOlapScan() - ).when(filter -> filter.getPredicate().toSql().contains("sid > 1")) + ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid > 1")) ) ) ) @@ -650,6 +664,7 @@ public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() { logicalFilter( logicalOlapScan() ).when(filter -> filter.getConjuncts().size() == 1 + && !ExpressionUtils.isInferred(filter.getPredicate()) && filter.getPredicate().toSql().contains("id = 2")), any() ).when(join -> join.getJoinType() == JoinType.LEFT_OUTER_JOIN)