diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index cb61795865239bb..ca0490f8f06d2e9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -48,12 +48,14 @@ import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.types.BooleanType; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateTimeV2Type; import org.apache.doris.nereids.types.DateType; import org.apache.doris.nereids.types.DateV2Type; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.coercion.DateLikeType; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.base.Preconditions; @@ -71,12 +73,6 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule implements ExpressionPatternRuleFactory { public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate(); - enum AdjustType { - LOWER, - UPPER, - NONE - } - @Override public List> buildRules() { return ImmutableList.of( @@ -116,77 +112,83 @@ public static Expression simplify(ComparisonPredicate cp) { return cp; } - private static Expression processComparisonPredicateDateTimeV2Literal( + private static Expression processDateTimeLikeComparisonPredicateDateTimeV2Literal( ComparisonPredicate comparisonPredicate, Expression left, DateTimeV2Literal right) { - DateTimeV2Type leftType = (DateTimeV2Type) left.getDataType(); + DataType leftType = left.getDataType(); + int toScale = 0; + if (leftType instanceof DateTimeType) { + toScale = 0; + } else if (leftType instanceof DateTimeV2Type) { + toScale = ((DateTimeV2Type) leftType).getScale(); + } else { + return comparisonPredicate; + } DateTimeV2Type rightType = right.getDataType(); - if (leftType.getScale() < rightType.getScale()) { - int toScale = leftType.getScale(); + if (toScale < rightType.getScale()) { if (comparisonPredicate instanceof EqualTo) { long originValue = right.getMicroSecond(); right = right.roundCeiling(toScale); - if (right.getMicroSecond() == originValue) { - return comparisonPredicate.withChildren(left, right); - } else { - if (left.nullable()) { - // TODO: the ideal way is to return an If expr like: - // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), - // BooleanLiteral.of(false)); - // but current fold constant rule can't handle such complex expr with null literal - // before supporting complex conjuncts with null literal folding rules, - // we use a trick way like this: - return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE)); - } else { - return BooleanLiteral.of(false); - } + if (right.getMicroSecond() != originValue) { + // TODO: the ideal way is to return an If expr like: + // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), + // BooleanLiteral.of(false)); + // but current fold constant rule can't handle such complex expr with null literal + // before supporting complex conjuncts with null literal folding rules, + // we use a trick way like this: + return ExpressionUtils.falseOrNull(left); } } else if (comparisonPredicate instanceof NullSafeEqual) { long originValue = right.getMicroSecond(); right = right.roundCeiling(toScale); - if (right.getMicroSecond() == originValue) { - return comparisonPredicate.withChildren(left, right); - } else { + if (right.getMicroSecond() != originValue) { return BooleanLiteral.of(false); } } else if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof LessThanEqual) { - return comparisonPredicate.withChildren(left, right.roundFloor(toScale)); + right = right.roundFloor(toScale); } else if (comparisonPredicate instanceof LessThan || comparisonPredicate instanceof GreaterThanEqual) { - return comparisonPredicate.withChildren(left, right.roundCeiling(toScale)); + right = right.roundCeiling(toScale); + } else { + return comparisonPredicate; + } + Expression newRight = leftType instanceof DateTimeType ? migrateToDateTime(right) : right; + return comparisonPredicate.withChildren(left, newRight); + } else { + if (leftType instanceof DateTimeType) { + return comparisonPredicate.withChildren(left, migrateToDateTime(right)); + } else { + return comparisonPredicate; } } - return comparisonPredicate; } private static Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) { if (left instanceof Cast && right instanceof DateLiteral) { Cast cast = (Cast) left; - if (cast.child().getDataType() instanceof DateTimeType) { + if (cast.child().getDataType() instanceof DateTimeType + || cast.child().getDataType() instanceof DateTimeV2Type) { if (right instanceof DateTimeV2Literal) { - left = cast.child(); - right = migrateToDateTime((DateTimeV2Literal) right); - } - } - if (cast.child().getDataType() instanceof DateTimeV2Type) { - if (right instanceof DateTimeV2Literal) { - left = cast.child(); - return processComparisonPredicateDateTimeV2Literal(cp, left, (DateTimeV2Literal) right); + return processDateTimeLikeComparisonPredicateDateTimeV2Literal( + cp, cast.child(), (DateTimeV2Literal) right); } } + // datetime to datev2 if (cast.child().getDataType() instanceof DateType || cast.child().getDataType() instanceof DateV2Type) { if (right instanceof DateTimeLiteral) { - if (cannotAdjust((DateTimeLiteral) right, cp)) { - return cp; - } - AdjustType type = AdjustType.NONE; - if (cp instanceof GreaterThanEqual || cp instanceof LessThan) { - type = AdjustType.UPPER; - } else if (cp instanceof GreaterThan || cp instanceof LessThanEqual) { - type = AdjustType.LOWER; + DateTimeLiteral dateTimeLiteral = (DateTimeLiteral) right; + right = migrateToDateV2(dateTimeLiteral); + if (dateTimeLiteral.getHour() != 0 || dateTimeLiteral.getMinute() != 0 + || dateTimeLiteral.getSecond() != 0) { + if (cp instanceof EqualTo) { + return ExpressionUtils.falseOrNull(cast.child()); + } else if (cp instanceof NullSafeEqual) { + return BooleanLiteral.FALSE; + } else if (cp instanceof GreaterThanEqual || cp instanceof LessThan) { + right = ((DateV2Literal) right).plusDays(1); + } } - right = migrateToDateV2((DateTimeLiteral) right, type); if (cast.child().getDataType() instanceof DateV2Type) { left = cast.child(); } @@ -340,17 +342,8 @@ private static Expression migrateToDateTime(DateTimeV2Literal l) { return new DateTimeLiteral(l.getYear(), l.getMonth(), l.getDay(), l.getHour(), l.getMinute(), l.getSecond()); } - private static boolean cannotAdjust(DateTimeLiteral l, ComparisonPredicate cp) { - return cp instanceof EqualTo && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0); - } - - private static Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) { - DateV2Literal d = new DateV2Literal(l.getYear(), l.getMonth(), l.getDay()); - if (type == AdjustType.UPPER && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0)) { - return d.plusDays(1); - } else { - return d; - } + private static Expression migrateToDateV2(DateTimeLiteral l) { + return new DateV2Literal(l.getYear(), l.getMonth(), l.getDay()); } private static Expression migrateToDate(DateV2Literal l) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 22b681a6246d920..25637d1b8166568 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -260,6 +260,22 @@ public static Expression or(Collection expressions) { } } + public static Expression falseOrNull(Expression expression) { + if (expression.nullable()) { + return new And(new IsNull(expression), new NullLiteral(BooleanType.INSTANCE)); + } else { + return BooleanLiteral.FALSE; + } + } + + public static Expression trueOrNull(Expression expression) { + if (expression.nullable()) { + return new Or(new Not(new IsNull(expression)), new NullLiteral(BooleanType.INSTANCE)); + } else { + return BooleanLiteral.TRUE; + } + } + /** * Use AND/OR to combine expressions together. */ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 84ebd7c72501984..db95b705b0dfe37 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -40,9 +40,13 @@ import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.types.BooleanType; +import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateTimeV2Type; +import org.apache.doris.nereids.types.DateType; +import org.apache.doris.nereids.types.DateV2Type; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Assertions; @@ -81,11 +85,11 @@ void testSimplifyComparisonPredicateRule() { new LessThan(dv2, dv2PlusOne)); assertRewrite( new EqualTo(new Cast(dv2, DateTimeV2Type.SYSTEM_DEFAULT), dtv2), - new EqualTo(new Cast(dv2, DateTimeV2Type.SYSTEM_DEFAULT), dtv2)); + BooleanLiteral.FALSE); assertRewrite( new EqualTo(new Cast(d, DateTimeV2Type.SYSTEM_DEFAULT), dtv2), - new EqualTo(new Cast(d, DateTimeV2Type.SYSTEM_DEFAULT), dtv2)); + BooleanLiteral.FALSE); // test hour, minute and second all zero Expression dtv2AtZeroClock = new DateTimeV2Literal(1, 1, 1, 0, 0, 0, 0); @@ -126,6 +130,100 @@ void testDateTimeV2CmpDateTimeV2() { expression = new GreaterThan(left, right); rewrittenExpression = executor.rewrite(typeCoercion(expression), context); Assertions.assertEquals(dt.getDataType(), rewrittenExpression.child(0).getDataType()); + + Expression date = new SlotReference("a", DateV2Type.INSTANCE); + Expression datev1 = new SlotReference("a", DateType.INSTANCE); + Expression datetime0 = new SlotReference("a", DateTimeV2Type.of(0)); + Expression datetime2 = new SlotReference("a", DateTimeV2Type.of(2)); + Expression datetimev1 = new SlotReference("a", DateTimeType.INSTANCE); + + // date + // cast (date as datetimev1) cmp datetimev1 + assertRewrite(new EqualTo(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:00")), + new EqualTo(date, new DateV2Literal("2020-01-01"))); + assertRewrite(new EqualTo(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + ExpressionUtils.falseOrNull(date)); + assertRewrite(new NullSafeEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + new GreaterThan(date, new DateV2Literal("2020-01-01"))); + assertRewrite(new GreaterThanEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + new GreaterThanEqual(date, new DateV2Literal("2020-01-02"))); + assertRewrite(new LessThan(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + new LessThan(date, new DateV2Literal("2020-01-02"))); + assertRewrite(new LessThanEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + new LessThanEqual(date, new DateV2Literal("2020-01-01"))); + // cast (date as datev1) = datev1-literal + // assertRewrite(new EqualTo(new Cast(date, DateType.INSTANCE), new DateLiteral("2020-01-01")), + // new EqualTo(date, new DateV2Literal("2020-01-01"))); + // assertRewrite(new GreaterThan(new Cast(date, DateType.INSTANCE), new DateLiteral("2020-01-01")), + // new GreaterThan(date, new DateV2Literal("2020-01-01"))); + + // cast (datev1 as datetimev1) cmp datetimev1 + assertRewrite(new EqualTo(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:00")), + new EqualTo(datev1, new DateLiteral("2020-01-01"))); + assertRewrite(new EqualTo(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + ExpressionUtils.falseOrNull(datev1)); + assertRewrite(new NullSafeEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + new GreaterThan(datev1, new DateLiteral("2020-01-01"))); + assertRewrite(new GreaterThanEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + new GreaterThanEqual(datev1, new DateLiteral("2020-01-02"))); + assertRewrite(new LessThan(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + new LessThan(datev1, new DateLiteral("2020-01-02"))); + assertRewrite(new LessThanEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")), + new LessThanEqual(datev1, new DateLiteral("2020-01-01"))); + assertRewrite(new EqualTo(new Cast(datev1, DateV2Type.INSTANCE), new DateV2Literal("2020-01-01")), + new EqualTo(datev1, new DateLiteral("2020-01-01"))); + assertRewrite(new GreaterThan(new Cast(datev1, DateV2Type.INSTANCE), new DateV2Literal("2020-01-01")), + new GreaterThan(datev1, new DateLiteral("2020-01-01"))); + + // cast (datetimev1 as datetime) cmp datetime + assertRewrite(new EqualTo(new Cast(datetimev1, DateTimeV2Type.of(0)), new DateTimeV2Literal("2020-01-01 00:00:00")), + new EqualTo(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00"))); + assertRewrite(new GreaterThan(new Cast(datetimev1, DateTimeV2Type.of(0)), new DateTimeV2Literal("2020-01-01 00:00:00")), + new GreaterThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00"))); + assertRewrite(new EqualTo(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + ExpressionUtils.falseOrNull(datetimev1)); + assertRewrite(new NullSafeEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + new GreaterThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00"))); + assertRewrite(new GreaterThanEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + new GreaterThanEqual(datetimev1, new DateTimeLiteral("2020-01-01 00:00:01"))); + assertRewrite(new LessThan(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + new LessThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:01"))); + assertRewrite(new LessThanEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + new LessThanEqual(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00"))); + + // cast (datetime0 as datetime) cmp datetime + assertRewrite(new EqualTo(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + ExpressionUtils.falseOrNull(datetime0)); + assertRewrite(new NullSafeEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + new GreaterThan(datetime0, new DateTimeV2Literal("2020-01-01 00:00:00"))); + assertRewrite(new GreaterThanEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + new GreaterThanEqual(datetime0, new DateTimeV2Literal("2020-01-01 00:00:01"))); + assertRewrite(new LessThan(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + new LessThan(datetime0, new DateTimeV2Literal("2020-01-01 00:00:01"))); + assertRewrite(new LessThanEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")), + new LessThanEqual(datetime0, new DateTimeV2Literal("2020-01-01 00:00:00"))); + + // cast (datetime2 as datetime) cmp datetime + assertRewrite(new EqualTo(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")), + ExpressionUtils.falseOrNull(datetime2)); + assertRewrite(new NullSafeEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")), + new GreaterThan(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.12"))); + assertRewrite(new GreaterThanEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")), + new GreaterThanEqual(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.13"))); + assertRewrite(new LessThan(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")), + new LessThan(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.13"))); + assertRewrite(new LessThanEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")), + new LessThanEqual(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.12"))); } @Test