Skip to content

Commit

Permalink
add infer props test
Browse files Browse the repository at this point in the history
  • Loading branch information
seawinde committed Dec 23, 2023
1 parent 75b7127 commit 713cd0d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -646,10 +646,11 @@ public static boolean checkSlotConstant(Slot slot, Set<Expression> predicates) {
*/
public static boolean isInferred(Expression expression) {
return expression.accept(new DefaultExpressionVisitor<Boolean, Void>() {

@Override
public Boolean visit(Expression expr, Void context) {
boolean inferred = expr.isInferred();
if (expr.isInferred()) {
if (expr.isInferred() || expr.children().isEmpty()) {
return inferred;
}
inferred = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
Expand Down Expand Up @@ -87,12 +88,12 @@ public void inferPredicatesTest01() {
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> !filter.getPredicate().isInferred()
&& filter.getPredicate().toSql().contains("id > 1")),
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().isInferred()
&& filter.getPredicate().toSql().contains("sid > 1"))
).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid > 1"))
)
)
);
Expand Down Expand Up @@ -127,7 +128,8 @@ public void inferPredicatesTest03() {
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
logicalOlapScan()
)
)
Expand Down Expand Up @@ -166,10 +168,12 @@ public void inferPredicatesTest05() {
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid > 1"))
),
logicalFilter(
logicalOlapScan()
Expand All @@ -192,10 +196,12 @@ public void inferPredicatesTest06() {
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid > 1"))
),
logicalFilter(
logicalOlapScan()
Expand All @@ -217,10 +223,12 @@ public void inferPredicatesTest07() {
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid > 1"))
)
)
);
Expand Down Expand Up @@ -258,10 +266,12 @@ public void inferPredicatesTest09() {
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid > 1"))
)
)
);
Expand All @@ -280,11 +290,13 @@ public void inferPredicatesTest10() {
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1"))
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id > 1"))
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid > 1"))
)
)
);
Expand Down Expand Up @@ -323,13 +335,15 @@ public void inferPredicatesTest12() {
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicate().toSql().contains("id > 1")),
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id > 1")),
logicalProject(
logicalAggregate(
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicate().toSql().contains("sid > 1"))
).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid > 1"))
)
)
)
Expand All @@ -351,11 +365,13 @@ public void inferPredicatesTest13() {
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id = 1"))
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id = 1"))
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid = 1"))
).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid = 1"))
)
)
);
Expand All @@ -373,11 +389,13 @@ public void inferPredicatesTest14() {
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id > 1")),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid > 1"))
)
)
)
Expand All @@ -396,11 +414,13 @@ public void inferPredicatesTest15() {
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id > 1")),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid > 1"))
)
)
)
Expand Down Expand Up @@ -461,11 +481,13 @@ public void inferPredicatesTest18() {
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id > 1")),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid > 1"))
)
)
)
Expand Down

0 comments on commit 713cd0d

Please sign in to comment.