Skip to content

Commit

Permalink
[fix](Nereids) nested type literal type coercion and insert values wi…
Browse files Browse the repository at this point in the history
…th map (apache#26669)
  • Loading branch information
morrySnow authored Nov 15, 2023
1 parent febf4bc commit 2c6d225
Show file tree
Hide file tree
Showing 13 changed files with 669 additions and 141 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.RelationUtil;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -226,7 +225,6 @@ public List<Rule> buildRules() {
// we skip it.
continue;
}
maybeFallbackCastUnsupportedType(expr, ctx.connectContext);
DataType inputType = expr.getDataType();
DataType targetType = DataType.fromCatalogType(table.getFullSchema().get(i).getType());
Expression castExpr = expr;
Expand Down Expand Up @@ -309,17 +307,6 @@ private List<Column> bindTargetColumns(OlapTable table, List<String> colsName, b
}).collect(ImmutableList.toImmutableList());
}

private void maybeFallbackCastUnsupportedType(Expression expression, ConnectContext ctx) {
if (expression.getDataType().isMapType()) {
try {
ctx.getSessionVariable().enableFallbackToOriginalPlannerOnce();
} catch (Exception e) {
throw new AnalysisException("failed to try to fall back to original planner");
}
throw new AnalysisException("failed to cast type when binding sink, type is: " + expression.getDataType());
}
}

private boolean isSourceAndTargetStringLikeType(DataType input, DataType target) {
return input.isStringLikeType() && target.isStringLikeType();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,8 @@
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.MapType;
import org.apache.doris.nereids.types.StructType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
Expand Down Expand Up @@ -167,39 +164,14 @@ private void registerFEFunction(ImmutableMultimap.Builder<String, FunctionInvoke
DataType returnType = DataType.convertFromString(annotation.returnType());
List<DataType> argTypes = new ArrayList<>();
for (String type : annotation.argTypes()) {
argTypes.add(replaceDecimalV3WithWildcard(DataType.convertFromString(type)));
argTypes.add(TypeCoercionUtils.replaceDecimalV3WithWildcard(DataType.convertFromString(type)));
}
FunctionSignature signature = new FunctionSignature(name,
argTypes.toArray(new DataType[0]), returnType);
mapBuilder.put(name, new FunctionInvoker(method, signature));
}
}

private DataType replaceDecimalV3WithWildcard(DataType input) {
if (input instanceof ArrayType) {
DataType item = replaceDecimalV3WithWildcard(((ArrayType) input).getItemType());
if (item == ((ArrayType) input).getItemType()) {
return input;
}
return ArrayType.of(item);
} else if (input instanceof MapType) {
DataType keyType = replaceDecimalV3WithWildcard(((MapType) input).getKeyType());
DataType valueType = replaceDecimalV3WithWildcard(((MapType) input).getValueType());
if (keyType == ((MapType) input).getKeyType() && valueType == ((MapType) input).getValueType()) {
return input;
}
return MapType.of(keyType, valueType);
} else if (input instanceof StructType) {
// TODO: support struct type
return input;
} else {
if (input instanceof DecimalV3Type) {
return DecimalV3Type.WILDCARD;
}
return input;
}
}

/**
* function invoker.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,27 @@ public class ComputeSignatureHelper {
/** implementAbstractReturnType */
public static FunctionSignature implementFollowToArgumentReturnType(
FunctionSignature signature, List<Expression> arguments) {
if (signature.returnType instanceof FollowToArgumentType) {
int argumentIndex = ((FollowToArgumentType) signature.returnType).argumentIndex;
return signature.withReturnType(arguments.get(argumentIndex).getDataType());
return signature.withReturnType(replaceFollowToArgumentReturnType(
signature.returnType, signature.argumentsTypes));
}

private static DataType replaceFollowToArgumentReturnType(DataType returnType, List<DataType> argumentTypes) {
if (returnType instanceof ArrayType) {
return ArrayType.of(replaceFollowToArgumentReturnType(
((ArrayType) returnType).getItemType(), argumentTypes));
} else if (returnType instanceof MapType) {
return MapType.of(replaceFollowToArgumentReturnType(((MapType) returnType).getKeyType(), argumentTypes),
replaceFollowToArgumentReturnType(((MapType) returnType).getValueType(), argumentTypes));
} else if (returnType instanceof StructType) {
// TODO: do not support struct type now
// throw new AnalysisException("do not support struct type now");
return returnType;
} else if (returnType instanceof FollowToArgumentType) {
int argumentIndex = ((FollowToArgumentType) returnType).argumentIndex;
return argumentTypes.get(argumentIndex);
} else {
return returnType;
}
return signature;
}

private static DataType replaceAnyDataTypeWithOutIndex(DataType sigType, DataType expressionType) {
Expand Down Expand Up @@ -308,10 +324,10 @@ public static FunctionSignature computePrecision(
if (computeSignature instanceof ComputePrecision) {
return ((ComputePrecision) computeSignature).computePrecision(signature);
}
if (signature.argumentsTypes.stream().anyMatch(DateTimeV2Type.class::isInstance)) {
if (signature.argumentsTypes.stream().anyMatch(TypeCoercionUtils::hasDateTimeV2Type)) {
signature = defaultDateTimeV2PrecisionPromotion(signature, arguments);
}
if (signature.argumentsTypes.stream().anyMatch(DecimalV3Type.class::isInstance)) {
if (signature.argumentsTypes.stream().anyMatch(TypeCoercionUtils::hasDecimalV3Type)) {
// do decimal v3 precision
signature = defaultDecimalV3PrecisionPromotion(signature, arguments);
}
Expand Down Expand Up @@ -354,30 +370,34 @@ private static FunctionSignature defaultDateTimeV2PrecisionPromotion(
} else {
targetType = signature.getArgType(i);
}
if (!(targetType instanceof DateTimeV2Type)) {
List<DataType> argTypes = extractArgumentType(DateTimeV2Type.class,
targetType, arguments.get(i).getDataType());
if (argTypes.isEmpty()) {
continue;
}
if (finalType == null) {
if (arguments.get(i) instanceof StringLikeLiteral) {
// We need to determine the scale based on the string literal.

for (DataType argType : argTypes) {
Expression arg = arguments.get(i);
DateTimeV2Type dateTimeV2Type;
if (arg instanceof StringLikeLiteral) {
StringLikeLiteral str = (StringLikeLiteral) arguments.get(i);
finalType = DateTimeV2Type.forTypeFromString(str.getStringValue());
dateTimeV2Type = DateTimeV2Type.forTypeFromString(str.getStringValue());
} else {
finalType = DateTimeV2Type.forType(arguments.get(i).getDataType());
dateTimeV2Type = DateTimeV2Type.forType(argType);
}
if (finalType == null) {
finalType = dateTimeV2Type;
} else {
finalType = DateTimeV2Type.getWiderDatetimeV2Type(finalType,
DateTimeV2Type.forType(arguments.get(i).getDataType()));
}
} else {
finalType = DateTimeV2Type.getWiderDatetimeV2Type(finalType,
DateTimeV2Type.forType(arguments.get(i).getDataType()));
}
}
DateTimeV2Type argType = finalType;
List<DataType> newArgTypes = signature.argumentsTypes.stream().map(t -> {
if (t instanceof DateTimeV2Type) {
return argType;
} else {
return t;
}
}).collect(Collectors.toList());
List<DataType> newArgTypes = signature.argumentsTypes.stream()
.map(at -> TypeCoercionUtils.replaceDateTimeV2WithTarget(at, argType))
.collect(Collectors.toList());
signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes);
signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes);
if (signature.returnType instanceof DateTimeV2Type) {
signature = signature.withReturnType(argType);
Expand All @@ -387,7 +407,7 @@ private static FunctionSignature defaultDateTimeV2PrecisionPromotion(

private static FunctionSignature defaultDecimalV3PrecisionPromotion(
FunctionSignature signature, List<Expression> arguments) {
DataType finalType = null;
DecimalV3Type finalType = null;
for (int i = 0; i < arguments.size(); i++) {
DataType targetType;
if (i >= signature.argumentsTypes.size()) {
Expand All @@ -397,37 +417,32 @@ private static FunctionSignature defaultDecimalV3PrecisionPromotion(
} else {
targetType = signature.getArgType(i);
}
if (!(targetType instanceof DecimalV3Type)) {
continue;
}
// only process wildcard decimalv3
if (((DecimalV3Type) targetType).getPrecision() > 0) {
List<DataType> argTypes = extractArgumentType(DecimalV3Type.class,
targetType, arguments.get(i).getDataType());
if (argTypes.isEmpty()) {
continue;
}
if (finalType == null) {
finalType = DecimalV3Type.forType(arguments.get(i).getDataType());
} else {

for (DataType argType : argTypes) {
Expression arg = arguments.get(i);
DecimalV3Type argType;
DecimalV3Type decimalV3Type;
if (arg.isLiteral() && arg.getDataType().isIntegralType()) {
// create decimalV3 with minimum scale enough to hold the integral literal
argType = DecimalV3Type.createDecimalV3Type(new BigDecimal(((Literal) arg).getStringValue()));
decimalV3Type = DecimalV3Type.createDecimalV3Type(new BigDecimal(((Literal) arg).getStringValue()));
} else {
decimalV3Type = DecimalV3Type.forType(argType);
}
if (finalType == null) {
finalType = decimalV3Type;
} else {
argType = DecimalV3Type.forType(arg.getDataType());
finalType = (DecimalV3Type) DecimalV3Type.widerDecimalV3Type(finalType, decimalV3Type, false);
}
finalType = DecimalV3Type.widerDecimalV3Type((DecimalV3Type) finalType, argType, true);
}
Preconditions.checkState(finalType.isDecimalV3Type(), "decimalv3 precision promotion failed.");
}
DataType argType = finalType;
List<DataType> newArgTypes = signature.argumentsTypes.stream().map(t -> {
// only process wildcard decimalv3
if (t instanceof DecimalV3Type && ((DecimalV3Type) t).getPrecision() <= 0) {
return argType;
} else {
return t;
}
}).collect(Collectors.toList());
DecimalV3Type argType = finalType;
List<DataType> newArgTypes = signature.argumentsTypes.stream()
.map(at -> TypeCoercionUtils.replaceDecimalV3WithTarget(at, argType))
.collect(Collectors.toList());
signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes);
if (signature.returnType instanceof DecimalV3Type
&& ((DecimalV3Type) signature.returnType).getPrecision() <= 0) {
Expand All @@ -436,6 +451,42 @@ private static FunctionSignature defaultDecimalV3PrecisionPromotion(
return signature;
}

private static List<DataType> extractArgumentType(Class<? extends DataType> targetType,
DataType signatureType, DataType argumentType) {
if (targetType.isAssignableFrom(signatureType.getClass())) {
return Lists.newArrayList(argumentType);
} else if (signatureType instanceof ArrayType) {
if (argumentType instanceof NullType) {
return extractArgumentType(targetType, ((ArrayType) signatureType).getItemType(), argumentType);
} else if (argumentType instanceof ArrayType) {
return extractArgumentType(targetType,
((ArrayType) signatureType).getItemType(), ((ArrayType) argumentType).getItemType());
} else {
return Lists.newArrayList();
}
} else if (signatureType instanceof MapType) {
if (argumentType instanceof NullType) {
List<DataType> ret = extractArgumentType(targetType,
((MapType) signatureType).getKeyType(), argumentType);
ret.addAll(extractArgumentType(targetType, ((MapType) signatureType).getValueType(), argumentType));
return ret;
} else if (argumentType instanceof MapType) {
List<DataType> ret = extractArgumentType(targetType,
((MapType) signatureType).getKeyType(), ((MapType) argumentType).getKeyType());
ret.addAll(extractArgumentType(targetType,
((MapType) signatureType).getValueType(), ((MapType) argumentType).getValueType()));
return ret;
} else {
return Lists.newArrayList();
}
} else if (signatureType instanceof StructType) {
// TODO: do not support struct type now
return Lists.newArrayList();
} else {
return Lists.newArrayList();
}
}

static class ComputeSignatureChain {
private final ResponsibilityChain<SignatureContext> computeChain;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.FollowToArgumentType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -79,9 +80,27 @@ public List<FunctionSignature> getSignatures() {
.map(TypeCoercionUtils::replaceCharacterToString)
.collect(Collectors.toList());
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDecimalV2Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDecimalV2WithDefault).collect(Collectors.toList()));
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDecimalV3Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDecimalV3WithWildcard).collect(Collectors.toList()));
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDateTimeV2Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDateTimeV2WithMax).collect(Collectors.toList()));
}
needTypeCoercion.addAll(partitioned.get(false));
return needTypeCoercion.stream()
.map(dataType -> FunctionSignature.ret(ArrayType.of(dataType)).varArgs(dataType))
.map(dataType -> FunctionSignature.ret(ArrayType.of(new FollowToArgumentType(0))).varArgs(dataType))
.collect(ImmutableList.toImmutableList());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import org.apache.doris.nereids.types.HllType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.MapType;
import org.apache.doris.nereids.types.NullType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.TinyIntType;
Expand All @@ -55,9 +57,8 @@ public class If extends ScalarFunction
implements TernaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(1)
.args(BooleanType.INSTANCE, ArrayType.of(new AnyDataType(0)),
ArrayType.of(new AnyDataType(0))),
FunctionSignature.ret(NullType.INSTANCE)
.args(BooleanType.INSTANCE, NullType.INSTANCE, NullType.INSTANCE),
FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT)
.args(BooleanType.INSTANCE, DateTimeV2Type.SYSTEM_DEFAULT, DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DateV2Type.INSTANCE)
Expand Down Expand Up @@ -88,6 +89,15 @@ public class If extends ScalarFunction
FunctionSignature.ret(BitmapType.INSTANCE)
.args(BooleanType.INSTANCE, BitmapType.INSTANCE, BitmapType.INSTANCE),
FunctionSignature.ret(HllType.INSTANCE).args(BooleanType.INSTANCE, HllType.INSTANCE, HllType.INSTANCE),
FunctionSignature.retArgType(1)
.args(BooleanType.INSTANCE, ArrayType.of(new AnyDataType(0)),
ArrayType.of(new AnyDataType(0))),
FunctionSignature.retArgType(1)
.args(BooleanType.INSTANCE, MapType.of(new AnyDataType(0), new AnyDataType(1)),
MapType.of(new AnyDataType(0), new AnyDataType(1))),
FunctionSignature.retArgType(1)
.args(BooleanType.INSTANCE, new AnyDataType(0), new AnyDataType(0)),
// NOTICE string must at least of signature list, because all complex type could implicit cast to string
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(BooleanType.INSTANCE, VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(StringType.INSTANCE)
Expand Down
Loading

0 comments on commit 2c6d225

Please sign in to comment.