Skip to content

Commit

Permalink
fix compare with date like
Browse files Browse the repository at this point in the history
  • Loading branch information
yujun777 committed Dec 13, 2024
1 parent 09619fc commit dbe6fcd
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.base.Preconditions;
Expand All @@ -71,12 +73,6 @@
public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule implements ExpressionPatternRuleFactory {
public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate();

enum AdjustType {
LOWER,
UPPER,
NONE
}

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
Expand Down Expand Up @@ -116,77 +112,83 @@ public static Expression simplify(ComparisonPredicate cp) {
return cp;
}

private static Expression processComparisonPredicateDateTimeV2Literal(
private static Expression processDateTimeLikeComparisonPredicateDateTimeV2Literal(
ComparisonPredicate comparisonPredicate, Expression left, DateTimeV2Literal right) {
DateTimeV2Type leftType = (DateTimeV2Type) left.getDataType();
DataType leftType = left.getDataType();
int toScale = 0;
if (leftType instanceof DateTimeType) {
toScale = 0;
} else if (leftType instanceof DateTimeV2Type) {
toScale = ((DateTimeV2Type) leftType).getScale();
} else {
return comparisonPredicate;
}
DateTimeV2Type rightType = right.getDataType();
if (leftType.getScale() < rightType.getScale()) {
int toScale = leftType.getScale();
if (toScale < rightType.getScale()) {
if (comparisonPredicate instanceof EqualTo) {
long originValue = right.getMicroSecond();
right = right.roundCeiling(toScale);
if (right.getMicroSecond() == originValue) {
return comparisonPredicate.withChildren(left, right);
} else {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.of(false);
}
if (right.getMicroSecond() != originValue) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return ExpressionUtils.falseOrNull(left);
}
} else if (comparisonPredicate instanceof NullSafeEqual) {
long originValue = right.getMicroSecond();
right = right.roundCeiling(toScale);
if (right.getMicroSecond() == originValue) {
return comparisonPredicate.withChildren(left, right);
} else {
if (right.getMicroSecond() != originValue) {
return BooleanLiteral.of(false);
}
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return comparisonPredicate.withChildren(left, right.roundFloor(toScale));
right = right.roundFloor(toScale);
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
return comparisonPredicate.withChildren(left, right.roundCeiling(toScale));
right = right.roundCeiling(toScale);
} else {
return comparisonPredicate;
}
Expression newRight = leftType instanceof DateTimeType ? migrateToDateTime(right) : right;
return comparisonPredicate.withChildren(left, newRight);
} else {
if (leftType instanceof DateTimeType) {
return comparisonPredicate.withChildren(left, migrateToDateTime(right));
} else {
return comparisonPredicate;
}
}
return comparisonPredicate;
}

private static Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) {
if (left instanceof Cast && right instanceof DateLiteral) {
Cast cast = (Cast) left;
if (cast.child().getDataType() instanceof DateTimeType) {
if (cast.child().getDataType() instanceof DateTimeType
|| cast.child().getDataType() instanceof DateTimeV2Type) {
if (right instanceof DateTimeV2Literal) {
left = cast.child();
right = migrateToDateTime((DateTimeV2Literal) right);
}
}
if (cast.child().getDataType() instanceof DateTimeV2Type) {
if (right instanceof DateTimeV2Literal) {
left = cast.child();
return processComparisonPredicateDateTimeV2Literal(cp, left, (DateTimeV2Literal) right);
return processDateTimeLikeComparisonPredicateDateTimeV2Literal(
cp, cast.child(), (DateTimeV2Literal) right);
}
}

// datetime to datev2
if (cast.child().getDataType() instanceof DateType || cast.child().getDataType() instanceof DateV2Type) {
if (right instanceof DateTimeLiteral) {
if (cannotAdjust((DateTimeLiteral) right, cp)) {
return cp;
}
AdjustType type = AdjustType.NONE;
if (cp instanceof GreaterThanEqual || cp instanceof LessThan) {
type = AdjustType.UPPER;
} else if (cp instanceof GreaterThan || cp instanceof LessThanEqual) {
type = AdjustType.LOWER;
DateTimeLiteral dateTimeLiteral = (DateTimeLiteral) right;
right = migrateToDateV2(dateTimeLiteral);
if (dateTimeLiteral.getHour() != 0 || dateTimeLiteral.getMinute() != 0
|| dateTimeLiteral.getSecond() != 0) {
if (cp instanceof EqualTo) {
return ExpressionUtils.falseOrNull(cast.child());
} else if (cp instanceof NullSafeEqual) {
return BooleanLiteral.FALSE;
} else if (cp instanceof GreaterThanEqual || cp instanceof LessThan) {
right = ((DateV2Literal) right).plusDays(1);
}
}
right = migrateToDateV2((DateTimeLiteral) right, type);
if (cast.child().getDataType() instanceof DateV2Type) {
left = cast.child();
}
Expand Down Expand Up @@ -340,17 +342,8 @@ private static Expression migrateToDateTime(DateTimeV2Literal l) {
return new DateTimeLiteral(l.getYear(), l.getMonth(), l.getDay(), l.getHour(), l.getMinute(), l.getSecond());
}

private static boolean cannotAdjust(DateTimeLiteral l, ComparisonPredicate cp) {
return cp instanceof EqualTo && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0);
}

private static Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) {
DateV2Literal d = new DateV2Literal(l.getYear(), l.getMonth(), l.getDay());
if (type == AdjustType.UPPER && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0)) {
return d.plusDays(1);
} else {
return d;
}
private static Expression migrateToDateV2(DateTimeLiteral l) {
return new DateV2Literal(l.getYear(), l.getMonth(), l.getDay());
}

private static Expression migrateToDate(DateV2Literal l) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,22 @@ public static Expression or(Collection<Expression> expressions) {
}
}

public static Expression falseOrNull(Expression expression) {
if (expression.nullable()) {
return new And(new IsNull(expression), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.FALSE;
}
}

public static Expression trueOrNull(Expression expression) {
if (expression.nullable()) {
return new Or(new Not(new IsNull(expression)), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.TRUE;
}
}

/**
* Use AND/OR to combine expressions together.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
Expand Down Expand Up @@ -81,11 +85,11 @@ void testSimplifyComparisonPredicateRule() {
new LessThan(dv2, dv2PlusOne));
assertRewrite(
new EqualTo(new Cast(dv2, DateTimeV2Type.SYSTEM_DEFAULT), dtv2),
new EqualTo(new Cast(dv2, DateTimeV2Type.SYSTEM_DEFAULT), dtv2));
BooleanLiteral.FALSE);

assertRewrite(
new EqualTo(new Cast(d, DateTimeV2Type.SYSTEM_DEFAULT), dtv2),
new EqualTo(new Cast(d, DateTimeV2Type.SYSTEM_DEFAULT), dtv2));
BooleanLiteral.FALSE);

// test hour, minute and second all zero
Expression dtv2AtZeroClock = new DateTimeV2Literal(1, 1, 1, 0, 0, 0, 0);
Expand Down Expand Up @@ -126,6 +130,100 @@ void testDateTimeV2CmpDateTimeV2() {
expression = new GreaterThan(left, right);
rewrittenExpression = executor.rewrite(typeCoercion(expression), context);
Assertions.assertEquals(dt.getDataType(), rewrittenExpression.child(0).getDataType());

Expression date = new SlotReference("a", DateV2Type.INSTANCE);
Expression datev1 = new SlotReference("a", DateType.INSTANCE);
Expression datetime0 = new SlotReference("a", DateTimeV2Type.of(0));
Expression datetime2 = new SlotReference("a", DateTimeV2Type.of(2));
Expression datetimev1 = new SlotReference("a", DateTimeType.INSTANCE);

// date
// cast (date as datetimev1) cmp datetimev1
assertRewrite(new EqualTo(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:00")),
new EqualTo(date, new DateV2Literal("2020-01-01")));
assertRewrite(new EqualTo(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
ExpressionUtils.falseOrNull(date));
assertRewrite(new NullSafeEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThan(date, new DateV2Literal("2020-01-01")));
assertRewrite(new GreaterThanEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThanEqual(date, new DateV2Literal("2020-01-02")));
assertRewrite(new LessThan(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThan(date, new DateV2Literal("2020-01-02")));
assertRewrite(new LessThanEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThanEqual(date, new DateV2Literal("2020-01-01")));
// cast (date as datev1) = datev1-literal
// assertRewrite(new EqualTo(new Cast(date, DateType.INSTANCE), new DateLiteral("2020-01-01")),
// new EqualTo(date, new DateV2Literal("2020-01-01")));
// assertRewrite(new GreaterThan(new Cast(date, DateType.INSTANCE), new DateLiteral("2020-01-01")),
// new GreaterThan(date, new DateV2Literal("2020-01-01")));

// cast (datev1 as datetimev1) cmp datetimev1
assertRewrite(new EqualTo(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:00")),
new EqualTo(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new EqualTo(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
ExpressionUtils.falseOrNull(datev1));
assertRewrite(new NullSafeEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThan(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new GreaterThanEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThanEqual(datev1, new DateLiteral("2020-01-02")));
assertRewrite(new LessThan(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThan(datev1, new DateLiteral("2020-01-02")));
assertRewrite(new LessThanEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThanEqual(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new EqualTo(new Cast(datev1, DateV2Type.INSTANCE), new DateV2Literal("2020-01-01")),
new EqualTo(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new GreaterThan(new Cast(datev1, DateV2Type.INSTANCE), new DateV2Literal("2020-01-01")),
new GreaterThan(datev1, new DateLiteral("2020-01-01")));

// cast (datetimev1 as datetime) cmp datetime
assertRewrite(new EqualTo(new Cast(datetimev1, DateTimeV2Type.of(0)), new DateTimeV2Literal("2020-01-01 00:00:00")),
new EqualTo(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));
assertRewrite(new GreaterThan(new Cast(datetimev1, DateTimeV2Type.of(0)), new DateTimeV2Literal("2020-01-01 00:00:00")),
new GreaterThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));
assertRewrite(new EqualTo(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
ExpressionUtils.falseOrNull(datetimev1));
assertRewrite(new NullSafeEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));
assertRewrite(new GreaterThanEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThanEqual(datetimev1, new DateTimeLiteral("2020-01-01 00:00:01")));
assertRewrite(new LessThan(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:01")));
assertRewrite(new LessThanEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThanEqual(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));

// cast (datetime0 as datetime) cmp datetime
assertRewrite(new EqualTo(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
ExpressionUtils.falseOrNull(datetime0));
assertRewrite(new NullSafeEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThan(datetime0, new DateTimeV2Literal("2020-01-01 00:00:00")));
assertRewrite(new GreaterThanEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThanEqual(datetime0, new DateTimeV2Literal("2020-01-01 00:00:01")));
assertRewrite(new LessThan(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThan(datetime0, new DateTimeV2Literal("2020-01-01 00:00:01")));
assertRewrite(new LessThanEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThanEqual(datetime0, new DateTimeV2Literal("2020-01-01 00:00:00")));

// cast (datetime2 as datetime) cmp datetime
assertRewrite(new EqualTo(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
ExpressionUtils.falseOrNull(datetime2));
assertRewrite(new NullSafeEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new GreaterThan(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.12")));
assertRewrite(new GreaterThanEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new GreaterThanEqual(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.13")));
assertRewrite(new LessThan(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new LessThan(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.13")));
assertRewrite(new LessThanEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new LessThanEqual(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.12")));
}

@Test
Expand Down

0 comments on commit dbe6fcd

Please sign in to comment.