Skip to content

Commit

Permalink
[improve](fold) support complex type for constant folding
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangstar333 committed Mar 26, 2024
1 parent c8825a4 commit 4bebd11
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 49 deletions.
3 changes: 3 additions & 0 deletions be/src/runtime/fold_constant_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ Status FoldConstantExecutor::fold_constant_vexpr(const TFoldConstantParams& para
ctx->root()->type(), column_ptr, column_type, result));
}

PTypeDesc* p_type = expr_result.mutable_p_type();
res_type.to_protobuf(p_type);
expr_result.set_content(std::move(result));
//maybe could remove this field, all of version use type_desc field
expr_result.mutable_type()->set_type(t_type);
expr_result.mutable_type()->set_scale(res_type.scale);
expr_result.mutable_type()->set_precision(res_type.precision);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.common.Pair;
import org.apache.doris.common.UserException;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.common.util.TimeUtils;
Expand All @@ -35,15 +36,17 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sleep;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NumericLiteral;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.MapType;
import org.apache.doris.nereids.types.StructField;
import org.apache.doris.nereids.types.StructType;
import org.apache.doris.proto.InternalService;
import org.apache.doris.proto.InternalService.PConstantExprResult;
import org.apache.doris.proto.Types.PScalarType;
import org.apache.doris.proto.Types.PStructField;
import org.apache.doris.proto.Types.PTypeDesc;
import org.apache.doris.proto.Types.PTypeNode;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.rpc.BackendServiceProxy;
import org.apache.doris.system.Backend;
Expand All @@ -54,7 +57,6 @@
import org.apache.doris.thrift.TQueryGlobals;
import org.apache.doris.thrift.TQueryOptions;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -66,7 +68,6 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -216,43 +217,11 @@ private Map<String, Expression> evalOnBE(Map<String, Map<String, TExpr>> paramMa
if (result.getStatus().getStatusCode() == 0) {
for (Entry<String, InternalService.PExprResultMap> e : result.getExprResultMapMap().entrySet()) {
for (Entry<String, InternalService.PExprResult> e1 : e.getValue().getMapMap().entrySet()) {
PScalarType pScalarType = e1.getValue().getType();
TPrimitiveType tPrimitiveType = TPrimitiveType.findByValue(pScalarType.getType());
PrimitiveType primitiveType = PrimitiveType.fromThrift(Objects.requireNonNull(tPrimitiveType));
Expression ret;
if (e1.getValue().getSuccess()) {
DataType type;
if (PrimitiveType.ARRAY == primitiveType
|| PrimitiveType.MAP == primitiveType
|| PrimitiveType.STRUCT == primitiveType
|| PrimitiveType.AGG_STATE == primitiveType) {
ret = constMap.get(e1.getKey());
} else {
if (primitiveType == PrimitiveType.CHAR) {
Preconditions.checkState(pScalarType.hasLen(),
"be return char type without len");
type = CharType.createCharType(pScalarType.getLen());
} else if (primitiveType == PrimitiveType.VARCHAR) {
Preconditions.checkState(pScalarType.hasLen(),
"be return varchar type without len");
type = VarcharType.createVarcharType(pScalarType.getLen());
} else if (primitiveType == PrimitiveType.DECIMALV2) {
type = DecimalV2Type.createDecimalV2Type(
pScalarType.getPrecision(), pScalarType.getScale());
} else if (primitiveType == PrimitiveType.DATETIMEV2) {
type = DateTimeV2Type.of(pScalarType.getScale());
} else if (primitiveType == PrimitiveType.DECIMAL32
|| primitiveType == PrimitiveType.DECIMAL64
|| primitiveType == PrimitiveType.DECIMAL128
|| primitiveType == PrimitiveType.DECIMAL256) {
type = DecimalV3Type.createDecimalV3TypeLooseCheck(
pScalarType.getPrecision(), pScalarType.getScale());
} else {
type = DataType.fromCatalogType(ScalarType.createType(
PrimitiveType.fromThrift(tPrimitiveType)));
}
ret = Literal.of(e1.getValue().getContent()).castTo(type);
}
PTypeDesc pTypeDesc = e1.getValue().getPType();
DataType type = convertToNereidsType(pTypeDesc.getTypesList(), 0).key();
ret = Literal.of(e1.getValue().getContent()).castTo(type);
} else {
ret = constMap.get(e1.getKey());
}
Expand All @@ -262,7 +231,6 @@ private Map<String, Expression> evalOnBE(Map<String, Map<String, TExpr>> paramMa
resultMap.put(e1.getKey(), ret);
}
}

} else {
LOG.warn("query {} failed to get const expr value from be: {}",
DebugUtil.printId(context.queryId()), result.getStatus().getErrorMsgsList());
Expand All @@ -273,4 +241,38 @@ private Map<String, Expression> evalOnBE(Map<String, Map<String, TExpr>> paramMa
}
return resultMap;
}

private Pair<DataType, Integer> convertToNereidsType(List<PTypeNode> typeNodes, int start) {
PScalarType pScalarType = typeNodes.get(start).getScalarType();
TPrimitiveType tPrimitiveType = TPrimitiveType.findByValue(pScalarType.getType());
DataType type;
int parsedNodes;
if (tPrimitiveType == TPrimitiveType.ARRAY) {
Pair<DataType, Integer> itemType = convertToNereidsType(typeNodes, start + 1);
type = ArrayType.of(itemType.key(), true);
parsedNodes = 1 + itemType.value();
} else if (tPrimitiveType == TPrimitiveType.MAP) {
Pair<DataType, Integer> keyType = convertToNereidsType(typeNodes, start + 1);
Pair<DataType, Integer> valueType = convertToNereidsType(typeNodes, start + 1 + keyType.value());
type = MapType.of(keyType.key(), valueType.key());
parsedNodes = 1 + keyType.value() + valueType.value();
} else if (tPrimitiveType == TPrimitiveType.STRUCT) {
parsedNodes = 1;
ArrayList<StructField> fields = new ArrayList<>();
for (int i = 0; i < typeNodes.get(start).getStructFieldsCount(); ++i) {
Pair<DataType, Integer> fieldType = convertToNereidsType(typeNodes, start + parsedNodes);
PStructField structField = typeNodes.get(start).getStructFields(i);
fields.add(new StructField(structField.getName(), fieldType.key(),
structField.getContainsNull(),
structField.getComment() == null ? "" : structField.getComment()));
parsedNodes += fieldType.value();
}
type = new StructType(fields);
} else {
type = DataType.fromCatalogType(ScalarType.createType(PrimitiveType.fromThrift(tPrimitiveType),
pScalarType.getLen(), pScalarType.getPrecision(), pScalarType.getScale()));
parsedNodes = 1;
}
return Pair.of(type, parsedNodes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.google.common.collect.ImmutableList;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -57,6 +58,27 @@ public ArrayLiteral(List<Literal> items, DataType dataType) {
this.items = ImmutableList.copyOf(Objects.requireNonNull(items, "items should not null"));
}

/**
* eg: "[10, 11, 12, 13, 14, 15, 16]" array_int
* eg: "['world', 'c++']" array_string
*/
public ArrayLiteral(String str, DataType dataType) {
super(dataType);
Preconditions.checkArgument(dataType instanceof ArrayType,
"dataType should be ArrayType, but we meet %s, and str is %s", dataType, str);
DataType nestedType = ((ArrayType) dataType).getItemType();
this.items = new ArrayList<>();
//str maybe empty need handle
String[] parts = str.substring(1, str.length() - 1).split(", ");
for (String s : parts) {
if (nestedType.isStringLikeType()) {
s = s.substring(1, s.length() - 1); // 'world'----> world
}
StringLiteral strLiteral = new StringLiteral(s);
this.items.add((Literal) strLiteral.uncheckedCastTo(nestedType));
}
}

@Override
public List<Literal> getValue() {
return items;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,12 @@ protected Expression uncheckedCastTo(DataType targetType) throws AnalysisExcepti
return new IPv4Literal(desc);
} else if (targetType.isIPv6Type()) {
return new IPv6Literal(desc);
} else if (targetType.isArrayType()) {
return new ArrayLiteral(desc, targetType);
} else if (targetType.isMapType()) {
return new MapLiteral(desc, targetType);
} else if (targetType.isStructType()) {
return new StructLiteral(desc, targetType);
}
throw new AnalysisException("cannot cast " + desc + " from type " + this.dataType + " to type " + targetType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -58,6 +59,36 @@ private MapLiteral(List<Literal> keys, List<Literal> values, DataType dataType)
"key size %s is not equal to value size %s", keys.size(), values.size());
}

/**
* eg: {"k11":1000, "k22":2000} string: int
*/
public MapLiteral(String str, DataType dataType) {
super(dataType);
Preconditions.checkArgument(dataType instanceof MapType,
"dataType should be MapType, but we meet %s", dataType);
DataType nestedKeyType = ((MapType) dataType).getKeyType();
DataType nestedValueType = ((MapType) dataType).getValueType();
this.keys = new ArrayList<>();
this.values = new ArrayList<>();
String[] elements = str.substring(1, str.length() - 1).split(", ");
for (String element : elements) {
String[] keyValue = element.split(":");
Preconditions.checkArgument(keyValue.length == 2, "error key value map is %s", keyValue.toString());
String key = keyValue[0];
String value = keyValue[1];
if (nestedKeyType.isStringLikeType()) {
key = key.substring(1, key.length() - 1); // "k11"----> k11
}
if (nestedValueType.isStringLikeType()) {
value = value.substring(1, value.length() - 1);
}
StringLiteral keyLiteral = new StringLiteral(key);
StringLiteral valueLiteral = new StringLiteral(value);
this.keys.add((Literal) keyLiteral.uncheckedCastTo(nestedKeyType));
this.values.add((Literal) valueLiteral.uncheckedCastTo(nestedValueType));
}
}

@Override
public List<List<Literal>> getValue() {
return ImmutableList.of(keys, values);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
Expand All @@ -49,6 +50,26 @@ public StructLiteral(List<Literal> fields) {
this(fields, computeDataType(fields));
}

/**
* eg: {1, a, abc} int, string, string
*/
public StructLiteral(String str, DataType dataType) {
super(dataType);
Preconditions.checkArgument(dataType instanceof StructType,
"dataType should be StructType, but we meet %s", dataType);
this.fields = new ArrayList<>();
StructType structType = (StructType) dataType;
String[] parts = str.substring(1, str.length() - 1).split(", ");
Preconditions.checkArgument(parts.length == structType.getFields().size(),
"parts length is not same with structType size. %s vs %s",
parts.length, structType.getFields().size());
for (int i = 0; i < parts.length; i++) {
DataType fieldDataType = structType.getFields().get(i).getDataType();
StringLiteral strLiteral = new StringLiteral(parts[i]);
this.fields.add((Literal) strLiteral.uncheckedCastTo(fieldDataType));
}
}

private StructLiteral(List<Literal> fields, DataType dataType) {
super(dataType);
this.fields = ImmutableList.copyOf(Objects.requireNonNull(fields, "fields should not be null"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1802,12 +1802,6 @@ public void initFuzzyModeVariables() {
default:
break;
}
randomInt = random.nextInt(2);
if (randomInt % 2 == 0) {
this.enableFoldConstantByBe = false;
} else {
this.enableFoldConstantByBe = true;
}

switch (Config.pull_request_id % 3) {
case 0:
Expand Down Expand Up @@ -1845,8 +1839,10 @@ public void initFuzzyModeVariables() {
if (Config.pull_request_id > 0) {
if (Config.pull_request_id % 2 == 1) {
this.batchSize = 4064;
this.enableFoldConstantByBe = true;
} else {
this.batchSize = 50;
this.enableFoldConstantByBe = false;
}
}
}
Expand Down
1 change: 1 addition & 0 deletions gensrc/proto/internal_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ message PExprResult {
required PScalarType type = 1;
required string content = 2;
required bool success = 3;
optional PTypeDesc p_type = 4; // want convert complex types: array map struct
};

message PExprResultMap {
Expand Down

0 comments on commit 4bebd11

Please sign in to comment.