Skip to content

Commit

Permalink
[enhancement](nereids) convert string literal to commontype in in-exp…
Browse files Browse the repository at this point in the history
…r and cass-when-expr (apache#17200)
  • Loading branch information
morrySnow authored Mar 2, 2023
1 parent 93d2d46 commit 3eeeff0
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 80 deletions.
1 change: 1 addition & 0 deletions .licenserc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ header:
- "**/*.log"
- "**/*.sql"
- "**/*.lock"
- "**/*.out"
- "tsan_suppressions"
- "docs/.markdownlintignore"
- "fe/fe-core/src/test/resources/data/net_snmp_normal"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,26 @@
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonArray;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonObject;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -218,9 +223,6 @@ public Expression visitCaseWhen(CaseWhen caseWhen, CascadesContext context) {
.map(e -> e.accept(this, context)).collect(Collectors.toList());
CaseWhen newCaseWhen = caseWhen.withChildren(rewrittenChildren);

// check
newCaseWhen.checkLegalityBeforeTypeCoercion();

// type coercion
List<DataType> dataTypesForCoercion = newCaseWhen.dataTypesForCoercion();
if (dataTypesForCoercion.size() <= 1) {
Expand All @@ -230,20 +232,37 @@ public Expression visitCaseWhen(CaseWhen caseWhen, CascadesContext context) {
if (dataTypesForCoercion.stream().allMatch(dataType -> dataType.equals(first))) {
return newCaseWhen;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(dataTypesForCoercion);
return optionalCommonType
.map(commonType -> {
List<Expression> newChildren
= newCaseWhen.getWhenClauses().stream()
.map(wc -> wc.withChildren(wc.getOperand(),
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)))
.collect(Collectors.toList());
newCaseWhen.getDefaultValue()
.map(dv -> TypeCoercionUtils.castIfNotMatchType(dv, commonType))
.ifPresent(newChildren::add);
return newCaseWhen.withChildren(newChildren);
})
.orElse(newCaseWhen);

Map<Boolean, List<Expression>> filteredStringLiteral = newCaseWhen.expressionForCoercion()
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));

Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
.stream().map(Expression::getDataType).collect(Collectors.toList()));

if (!optionalCommonType.isPresent()) {
return newCaseWhen;
}
DataType commonType = optionalCommonType.get();

// process character literal
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
Literal literal = (Literal) stringLikeLiteral;
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
literal.getStringValue(), commonType).isPresent()) {
commonType = StringType.INSTANCE;
break;
}
}

List<Expression> newChildren = Lists.newArrayList();
for (WhenClause wc : newCaseWhen.getWhenClauses()) {
newChildren.add(wc.withChildren(wc.getOperand(),
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)));
}
if (newCaseWhen.getDefaultValue().isPresent()) {
newChildren.add(TypeCoercionUtils.castIfNotMatchType(newCaseWhen.getDefaultValue().get(), commonType));
}
return newCaseWhen.withChildren(newChildren);
}

@Override
Expand All @@ -257,17 +276,32 @@ public Expression visitInPredicate(InPredicate inPredicate, CascadesContext cont
.allMatch(dt -> dt.equals(newInPredicate.getCompareExpr().getDataType()))) {
return newInPredicate;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(newInPredicate.children()

Map<Boolean, List<Expression>> filteredStringLiteral = newInPredicate.children()
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
.stream().map(Expression::getDataType).collect(Collectors.toList()));

return optionalCommonType
.map(commonType -> {
List<Expression> newChildren = newInPredicate.children().stream()
.map(e -> TypeCoercionUtils.castIfNotMatchType(e, commonType))
.collect(Collectors.toList());
return newInPredicate.withChildren(newChildren);
})
.orElse(newInPredicate);
if (!optionalCommonType.isPresent()) {
return newInPredicate;
}
DataType commonType = optionalCommonType.get();

// process character literal
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
Literal literal = (Literal) stringLikeLiteral;
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
literal.getStringValue(), commonType).isPresent()) {
commonType = StringType.INSTANCE;
break;
}
}

List<Expression> newChildren = Lists.newArrayList();
for (Expression child : newInPredicate.children()) {
newChildren.add(TypeCoercionUtils.castIfNotMatchType(child, commonType));
}
return newInPredicate.withChildren(newChildren);
}

private Expression visitImplicitCastInputTypes(Expression expr, List<AbstractDataType> expectedInputTypes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,19 @@
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.Lists;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -152,20 +158,37 @@ public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext cont
if (dataTypesForCoercion.stream().allMatch(dataType -> dataType.equals(first))) {
return newCaseWhen;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(dataTypesForCoercion);
return optionalCommonType
.map(commonType -> {
List<Expression> newChildren
= newCaseWhen.getWhenClauses().stream()
.map(wc -> wc.withChildren(wc.getOperand(),
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)))
.collect(Collectors.toList());
newCaseWhen.getDefaultValue()
.map(dv -> TypeCoercionUtils.castIfNotMatchType(dv, commonType))
.ifPresent(newChildren::add);
return newCaseWhen.withChildren(newChildren);
})
.orElse(newCaseWhen);

Map<Boolean, List<Expression>> filteredStringLiteral = newCaseWhen.expressionForCoercion()
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));

Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
.stream().map(Expression::getDataType).collect(Collectors.toList()));

if (!optionalCommonType.isPresent()) {
return newCaseWhen;
}
DataType commonType = optionalCommonType.get();

// process character literal
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
Literal literal = (Literal) stringLikeLiteral;
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
literal.getStringValue(), commonType).isPresent()) {
commonType = StringType.INSTANCE;
break;
}
}

List<Expression> newChildren = Lists.newArrayList();
for (WhenClause wc : newCaseWhen.getWhenClauses()) {
newChildren.add(wc.withChildren(wc.getOperand(),
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)));
}
if (newCaseWhen.getDefaultValue().isPresent()) {
newChildren.add(TypeCoercionUtils.castIfNotMatchType(newCaseWhen.getDefaultValue().get(), commonType));
}
return newCaseWhen.withChildren(newChildren);
}

@Override
Expand All @@ -178,17 +201,32 @@ public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteCon
.allMatch(dt -> dt.equals(newInPredicate.getCompareExpr().getDataType()))) {
return newInPredicate;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(newInPredicate.children()

Map<Boolean, List<Expression>> filteredStringLiteral = newInPredicate.children()
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
.stream().map(Expression::getDataType).collect(Collectors.toList()));

return optionalCommonType
.map(commonType -> {
List<Expression> newChildren = newInPredicate.children().stream()
.map(e -> TypeCoercionUtils.castIfNotMatchType(e, commonType))
.collect(Collectors.toList());
return newInPredicate.withChildren(newChildren);
})
.orElse(newInPredicate);
if (!optionalCommonType.isPresent()) {
return newInPredicate;
}
DataType commonType = optionalCommonType.get();

// process character literal
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
Literal literal = (Literal) stringLikeLiteral;
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
literal.getStringValue(), commonType).isPresent()) {
commonType = StringType.INSTANCE;
break;
}
}

List<Expression> newChildren = Lists.newArrayList();
for (Expression child : newInPredicate.children()) {
newChildren.add(TypeCoercionUtils.castIfNotMatchType(child, commonType));
}
return newInPredicate.withChildren(newChildren);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
Expand Down Expand Up @@ -71,6 +72,12 @@ public List<DataType> dataTypesForCoercion() {
.collect(ImmutableList.toImmutableList());
}

public List<Expression> expressionForCoercion() {
List<Expression> ret = whenClauses.stream().map(WhenClause::getResult).collect(Collectors.toList());
defaultValue.ifPresent(ret::add);
return ret;
}

public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitCaseWhen(this, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !test_compare_expression --
0

-- !test_compare_expression_2 --
0

-- !test_compare_expression_3 --
true

-- !test_compare_expression_4 --
\N

-- !test_compare_expression_5 --
false

-- !test_compare_expression_6 --
\N

-- !test_compare_expression_7 --
true

-- !test_compare_expression_8 --
\N

-- !test_compare_expression_9 --
false

-- !test_compare_expression_10 --
\N

-- !test_compare_expression_11 --
true

-- !test_compare_expression_12 --
2008-08-08T00:00

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,6 @@ select 1 not in (null, 1);
select 1 not in (null, 2);


select timestamp '2008-08-08 00:00:00' in ('2008-08-08');
select case when true then timestamp '2008-08-08 00:00:00' else '2008-08-08' end;

0 comments on commit 3eeeff0

Please sign in to comment.