Skip to content

Commit

Permalink
add infer props_to_expression
Browse files Browse the repository at this point in the history
  • Loading branch information
seawinde committed Dec 22, 2023
1 parent 9b67c86 commit 75b7127
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ public ComparisonInferInfo(InferType inferType,
public Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
for (Expression predicate : predicates) {
// if we support more predicate infer, we should add .withInferred(this.isInferred())
// to mark the predicate is from infer when call withChildren method
if (!(predicate instanceof ComparisonPredicate)) {
continue;
}
Expand Down Expand Up @@ -130,7 +132,7 @@ private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo pr
.comparisonPredicate.withChildren(newLeft, newRight);
Expression expr = SimplifyComparisonPredicate.INSTANCE
.rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate), null);
return DateFunctionRewrite.INSTANCE.rewrite(expr, null);
return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true);
}

private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public boolean nullable() throws UnboundException {
@Override
public EqualTo withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new EqualTo(children);
return new EqualTo(children).withInferred(this.isInferred());
}

public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
protected Optional<String> exprName = Optional.empty();
private final int depth;
private final int width;
// Mark this expression is from predicate infer or something else infer
private boolean inferred;

protected Expression(Expression... children) {
super(children);
Expand All @@ -69,6 +71,7 @@ protected Expression(Expression... children) {
.mapToInt(e -> e.width)
.sum() + (children.length == 0 ? 1 : 0);
checkLimit();
this.inferred = false;
}

protected Expression(List<Expression> children) {
Expand All @@ -80,6 +83,19 @@ protected Expression(List<Expression> children) {
.mapToInt(e -> e.width)
.sum() + (children.isEmpty() ? 1 : 0);
checkLimit();
this.inferred = false;
}

protected Expression(List<Expression> children, boolean inferred) {
super(children);
depth = children.stream()
.mapToInt(e -> e.depth)
.max().orElse(0) + 1;
width = children.stream()
.mapToInt(e -> e.width)
.sum() + (children.isEmpty() ? 1 : 0);
checkLimit();
this.inferred = inferred;
}

private void checkLimit() {
Expand Down Expand Up @@ -216,11 +232,20 @@ public int getDepth() {
return depth;
}

public boolean isInferred() {
return inferred;
}

@Override
public Expression withChildren(List<Expression> children) {
throw new RuntimeException();
}

public <E> E withInferred(boolean inferred) {
this.inferred = inferred;
return (E) this;
}

/**
* Whether the expression is a constant.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public String toString() {
@Override
public GreaterThan withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new GreaterThan(children);
return new GreaterThan(children).withInferred(this.isInferred());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public String toString() {
@Override
public GreaterThanEqual withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new GreaterThanEqual(children);
return new GreaterThanEqual(children).withInferred(this.isInferred());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public String toString() {
@Override
public LessThan withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new LessThan(children);
return new LessThan(children).withInferred(this.isInferred());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public String toString() {
@Override
public LessThanEqual withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new LessThanEqual(children);
return new LessThanEqual(children).withInferred(this.isInferred());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
@Override
public NullSafeEqual withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new NullSafeEqual(children);
return new NullSafeEqual(children).withInferred(this.isInferred());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer;

Expand Down Expand Up @@ -639,4 +640,24 @@ public static boolean checkSlotConstant(Slot slot, Set<Expression> predicates) {
}
);
}

/**
* isInferred
*/
public static boolean isInferred(Expression expression) {
return expression.accept(new DefaultExpressionVisitor<Boolean, Void>() {
@Override
public Boolean visit(Expression expr, Void context) {
boolean inferred = expr.isInferred();
if (expr.isInferred()) {
return inferred;
}
inferred = true;
for (Expression child : expr.children()) {
inferred = inferred && child.accept(this, context);
}
return inferred;
}
}, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ public void inferPredicatesTest01() {
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
).when(filter -> !filter.getPredicate().isInferred()
&& filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
).when(filter -> filter.getPredicate().isInferred()
&& filter.getPredicate().toSql().contains("sid > 1"))
)
)
);
Expand Down

0 comments on commit 75b7127

Please sign in to comment.