diff --git a/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java b/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java index 6d245aa4f..e309313e8 100644 --- a/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java +++ b/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java @@ -5,6 +5,8 @@ */ package com.linkedin.coral.common.transformers; +import java.util.List; + import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.type.RelDataType; @@ -66,6 +68,10 @@ protected RelDataType deriveRelDatatype(SqlNode sqlNode) { return typeDerivationUtil.getRelDataType(sqlNode); } + protected RelDataType leastRestrictive(List types) { + return typeDerivationUtil.leastRestrictive(types); + } + /** * This function creates a {@link SqlOperator} for a function with the function name and return type inference. */ diff --git a/coral-common/src/main/java/com/linkedin/coral/common/utils/TypeDerivationUtil.java b/coral-common/src/main/java/com/linkedin/coral/common/utils/TypeDerivationUtil.java index c9b181198..e0ba27c51 100644 --- a/coral-common/src/main/java/com/linkedin/coral/common/utils/TypeDerivationUtil.java +++ b/coral-common/src/main/java/com/linkedin/coral/common/utils/TypeDerivationUtil.java @@ -88,6 +88,10 @@ public RelDataType getRelDataType(SqlNode sqlNode) { sqlNode, topSelectNodes.get(0))); } + public RelDataType leastRestrictive(List types) { + return sqlValidator.getTypeFactory().leastRestrictive(types); + } + private class SqlNodePreprocessorForTypeDerivation extends SqlShuttle { @Override public SqlNode visit(SqlCall sqlCall) { diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java index 546edec86..056120e91 100644 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java @@ -21,6 +21,7 @@ import com.linkedin.coral.trino.rel2trino.transformers.GenericProjectTransformer; import com.linkedin.coral.trino.rel2trino.transformers.NamedStructToCastTransformer; import com.linkedin.coral.trino.rel2trino.transformers.SubstrOperatorTransformer; +import com.linkedin.coral.trino.rel2trino.transformers.UnionSqlCallTransformer; /** @@ -42,7 +43,8 @@ public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode to TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(toRelConverter.getSqlValidator(), topSqlNode); operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil), new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil), - new ConcatOperatorTransformer(typeDerivationUtil), new SubstrOperatorTransformer(typeDerivationUtil)); + new ConcatOperatorTransformer(typeDerivationUtil), new SubstrOperatorTransformer(typeDerivationUtil), + new UnionSqlCallTransformer(typeDerivationUtil)); } @Override diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/UnionSqlCallTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/UnionSqlCallTransformer.java new file mode 100644 index 000000000..32b5e9383 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/UnionSqlCallTransformer.java @@ -0,0 +1,165 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.trino.rel2trino.transformers; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlBasicTypeNameSpec; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlTypeNameSpec; +import org.apache.calcite.sql.fun.SqlCastFunction; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; + +import com.linkedin.coral.common.transformers.SqlCallTransformer; +import com.linkedin.coral.common.utils.TypeDerivationUtil; + +import static org.apache.calcite.rel.rel2sql.SqlImplementor.POS; +import static org.apache.calcite.sql.parser.SqlParserPos.ZERO; + + +/** + * This transformer class is used to adapt expressions in + * set statements (e.g.: `UNION`, `INTERSECT`, `MINUS`) + * in case that the branches of the set statement contain fields in + * char family which have different types. + * The `char` fields which are differing from the expected `varchar` output + * of the set statement will be adapted through an explicit `CAST` to the `varchar` type. + * + * @see Change char/varchar coercion in Trino + */ +public class UnionSqlCallTransformer extends SqlCallTransformer { + + public UnionSqlCallTransformer(TypeDerivationUtil typeDerivationUtil) { + super(typeDerivationUtil); + } + + @Override + protected boolean condition(SqlCall sqlCall) { + return sqlCall.getOperator().kind == SqlKind.UNION || sqlCall.getOperator().kind == SqlKind.INTERSECT + || sqlCall.getOperator().kind == SqlKind.MINUS; + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + List operandsList = sqlCall.getOperandList(); + List>> columnsTypesLists = new ArrayList<>(); + for (SqlNode operand : operandsList) { + if (operand.getKind() == SqlKind.SELECT && ((SqlSelect) operand).getSelectList() != null) { + SqlSelect select = (SqlSelect) operand; + List selectNodes = select.getSelectList().getList(); + if (columnsTypesLists.isEmpty()) { + for (int i = 0; i < selectNodes.size(); i++) { + columnsTypesLists.add(new ArrayList<>()); + } + } + + for (int i = 0; i < selectNodes.size(); i++) { + SqlNode sqlNode = selectNodes.get(i); + if (sqlNode.getKind() == SqlKind.IDENTIFIER && ((SqlIdentifier) sqlNode).isStar()) { + // The type derivation of * on the SqlNode layer is not supported now. + return sqlCall; + } + + Optional selectNodeDataType; + try { + selectNodeDataType = Optional.of(deriveRelDatatype(sqlNode)); + } catch (RuntimeException e) { + // The type derivation may fail for complex expressions + selectNodeDataType = Optional.empty(); + } + columnsTypesLists.get(i).add(selectNodeDataType); + } + } else { + return sqlCall; + } + } + + List updatedOperands = new ArrayList<>(sqlCall.getOperandList().size()); + for (SqlNode operand : operandsList) { + SqlSelect select = (SqlSelect) operand; + + List selectNodes = select.getSelectList().getList(); + + List rewrittenSelectNodes = new ArrayList<>(); + boolean useRewrittenSelectNodes = false; + for (int i = 0; i < selectNodes.size(); i++) { + + if (columnsTypesLists.get(i).stream().anyMatch(columnType -> !columnType.isPresent())) { + // Couldn't determine the type for all the expressions + continue; + } + RelDataType inferredColumnType = + leastRestrictive(columnsTypesLists.get(i).stream().map(Optional::get).collect(Collectors.toList())); + + SqlNode selectNode = selectNodes.get(i); + RelDataType selectNodeDataType = deriveRelDatatype(selectNodes.get(i)); + if (!selectNodeDataType.equals(inferredColumnType) && selectNodeDataType.getSqlTypeName() == SqlTypeName.CHAR + && inferredColumnType.getSqlTypeName() == SqlTypeName.VARCHAR) { + // Work-around for the Trino limitation in dealing UNION statements between `char` and `varchar`. + // See https://github.com/trinodb/trino/issues/9031 + SqlNode rewrittenSelectNode = castNode(selectNode, inferredColumnType); + if (!useRewrittenSelectNodes) { + rewrittenSelectNodes.addAll(selectNodes.subList(0, i)); + useRewrittenSelectNodes = true; + } + rewrittenSelectNodes.add(rewrittenSelectNode); + } else if (useRewrittenSelectNodes) { + rewrittenSelectNodes.add(selectNode); + } + } + + if (useRewrittenSelectNodes) { + select.setSelectList(new SqlNodeList(rewrittenSelectNodes, SqlParserPos.ZERO)); + } + updatedOperands.add(operand); + } + return sqlCall.getOperator().createCall(POS, updatedOperands); + } + + private SqlNode castNode(SqlNode node, RelDataType type) { + if (node.getKind() == SqlKind.AS) { + SqlNode expression = ((SqlCall) node).getOperandList().get(0); + SqlIdentifier identifier = (SqlIdentifier) ((SqlCall) node).getOperandList().get(1); + return SqlStdOperatorTable.AS.createCall(POS, new SqlCastFunction() { + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { + SqlCallBinding opBinding = new SqlCallBinding(validator, scope, call); + return inferReturnType(opBinding); + } + }.createCall(ZERO, expression, getSqlDataTypeSpecForCasting(type)), identifier); + } else { + // If there's no existing alias, just do the cast + return new SqlCastFunction() { + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { + SqlCallBinding opBinding = new SqlCallBinding(validator, scope, call); + return inferReturnType(opBinding); + } + }.createCall(ZERO, node, getSqlDataTypeSpecForCasting(type)); + } + } + + private static SqlDataTypeSpec getSqlDataTypeSpecForCasting(RelDataType relDataType) { + final SqlTypeNameSpec typeNameSpec = new SqlBasicTypeNameSpec(relDataType.getSqlTypeName(), + relDataType.getPrecision(), relDataType.getScale(), null, ZERO); + return new SqlDataTypeSpec(typeNameSpec, ZERO); + } +} diff --git a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java index 4d37a65ef..c49eb6f7b 100644 --- a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java +++ b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java @@ -193,7 +193,50 @@ public Object[][] viewTestCasesProvider() { + "FROM \"test\".\"duplicate_column_name_a\" AS \"duplicate_column_name_a\"\n" + "LEFT JOIN (SELECT TRIM(\"duplicate_column_name_b\".\"some_id\") AS \"SOME_ID\", CAST(TRIM(\"duplicate_column_name_b\".\"some_id\") AS VARCHAR(65536)) AS \"$f1\"\n" + "FROM \"test\".\"duplicate_column_name_b\" AS \"duplicate_column_name_b\") AS \"t\" ON \"duplicate_column_name_a\".\"some_id\" = \"t\".\"$f1\") AS \"t0\"\n" - + "WHERE \"t0\".\"some_id\" <> ''" } }; + + "WHERE \"t0\".\"some_id\" <> ''" }, + + { "test", "view_char_different_size_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(255)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT CAST(\"table_with_mixed_columns0\".\"a_char255\" AS VARCHAR(255)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" }, + + { "test", "view_cast_char_to_varchar", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"" }, + + { "test", "view_cast_char_to_varchar_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT CAST(CASE WHEN \"table_with_mixed_columns0\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns0\".\"a_char1\" ELSE 'N' END AS VARCHAR(65535)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" }, + + { "test", "view_cast_char_to_varchar_in_union_flipped", "SELECT CAST(CASE WHEN \"table_with_mixed_columns\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns\".\"a_char1\" ELSE 'N' END AS VARCHAR(65535)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT CAST(\"table_with_mixed_columns0\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" }, + + { "test", "view_cast_char_to_varchar_with_other_fields_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"text\", \"table_with_mixed_columns\".\"a_boolean\" AS \"a_boolean\", \"table_with_mixed_columns\".\"a_smallint\" AS \"a_number\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT CAST(CASE WHEN \"table_with_mixed_columns0\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns0\".\"a_char1\" ELSE 'N' END AS VARCHAR(65535)) AS \"text\", \"table_with_mixed_columns0\".\"a_boolean\" AS \"a_boolean\", \"table_with_mixed_columns0\".\"a_integer\" AS \"a_number\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" }, + + { "test", "view_char_and_null_in_union", "SELECT \"table_with_mixed_columns\".\"a_char1\" AS \"text\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT NULL AS \"text\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" }, + + { "test", "view_different_numerical_types_in_union", "SELECT *\n" + "FROM (SELECT *\n" + + "FROM (SELECT \"table_with_mixed_columns\".\"a_tinyint\" AS \"a_number\", \"table_with_mixed_columns\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT \"table_with_mixed_columns0\".\"a_smallint\" AS \"a_number\", \"table_with_mixed_columns0\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\") AS \"t1\"\n" + "UNION ALL\n" + + "SELECT \"table_with_mixed_columns1\".\"a_integer\" AS \"a_number\", \"table_with_mixed_columns1\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns1\") AS \"t3\"\n" + "UNION ALL\n" + + "SELECT \"table_with_mixed_columns2\".\"a_bigint\" AS \"a_number\", \"table_with_mixed_columns2\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns2\"" }, + + { "test", "view_union_no_casting", "SELECT \"table_with_mixed_columns\".\"a_tinyint\" AS \"a_tinyint\", \"table_with_mixed_columns\".\"a_smallint\" AS \"a_smallint\", \"table_with_mixed_columns\".\"a_integer\" AS \"a_integer\", \"table_with_mixed_columns\".\"a_bigint\" AS \"a_bigint\", \"table_with_mixed_columns\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT \"table_with_mixed_columns0\".\"a_tinyint\" AS \"a_tinyint\", \"table_with_mixed_columns0\".\"a_smallint\" AS \"a_smallint\", \"table_with_mixed_columns0\".\"a_integer\" AS \"a_integer\", \"table_with_mixed_columns0\".\"a_bigint\" AS \"a_bigint\", \"table_with_mixed_columns0\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" } }; } @Test diff --git a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/TestUtils.java b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/TestUtils.java index 2c6bb0f97..0409333ff 100644 --- a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/TestUtils.java +++ b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/TestUtils.java @@ -378,6 +378,42 @@ public static void initializeTablesAndViews(HiveConf conf) throws HiveException, run(driver, "CREATE TABLE test.table_with_binary_column (b binary)"); + run(driver, + "CREATE TABLE test.table_with_mixed_columns (a_char1 char(1), a_char255 char(255), a_string string, a_tinyint tinyint, a_smallint smallint, a_integer int, a_bigint bigint, a_float float, a_double double, a_boolean boolean)"); + run(driver, "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar AS \n" + + "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT EXISTS test.view_char_different_size_in_union AS \n" + + "SELECT a_char1 AS col FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT a_char255 AS col FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_in_union AS \n" + + "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT COALESCE(a_char1, 'N') AS col FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_in_union_flipped AS \n" + + "SELECT COALESCE(a_char1, 'N') as col FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns"); + run(driver, "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_with_other_fields_in_union AS \n" + + "SELECT CAST(a_char1 AS VARCHAR(65535)) AS text , a_boolean, a_smallint as a_number FROM test.table_with_mixed_columns \n" + + "UNION ALL\n" + + "SELECT COALESCE(a_char1, 'N') as text, a_boolean, a_integer as a_number FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT EXISTS test.view_char_and_null_in_union AS \n" + + "SELECT a_char1 as text FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT NULL text FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT EXISTS test.view_different_numerical_types_in_union AS \n" + + "SELECT a_tinyint AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT a_smallint AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT a_integer AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT a_bigint AS a_number, a_float FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT EXISTS test.view_union_no_casting AS \n" + + "SELECT a_tinyint, a_smallint, a_integer, a_bigint, a_float FROM test.table_with_mixed_columns \n" + + "UNION ALL\n" + + "SELECT a_tinyint, a_smallint, a_integer, a_bigint, a_float FROM test.table_with_mixed_columns"); + // Tables used in RelToTrinoConverterTest run(driver, "CREATE TABLE IF NOT EXISTS test.tableOne(icol int, dcol double, scol string, tcol timestamp, acol array)");