From 4bb38a120051a58ab8b06b7865e279bfb05da209 Mon Sep 17 00:00:00 2001 From: Marius Grama Date: Thu, 5 Oct 2023 03:35:03 +0200 Subject: [PATCH] [Coral-Trino] Cast char fields, if necessary, to varchar type in the set operation (#442) When dealing with Hive views which make use of set operations (e.g. UNION), ensure that the `char` fields from the inner SELECT statements have the same type as the output field types of the set operation. Due to wrong coercion between `varchar` and `char` in Trino, as described in https://github.com/trinodb/trino/issues/9031 , a work-around needs to be applied when translating Hive views which contain a UNION dealing with char and varchar types. The work-around consists of an explicit cast of the field having char type to the varchar type corresponding to the set operation output type. --- .../transformers/SqlCallTransformer.java | 6 + .../common/utils/TypeDerivationUtil.java | 4 + .../DataTypeDerivedSqlCallConverter.java | 4 +- .../transformers/UnionSqlCallTransformer.java | 180 ++++++++++++++++++ .../rel2trino/HiveToTrinoConverterTest.java | 45 ++++- .../coral/trino/rel2trino/TestUtils.java | 36 ++++ 6 files changed, 273 insertions(+), 2 deletions(-) create mode 100644 coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/UnionSqlCallTransformer.java diff --git a/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java b/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java index 6d245aa4f..e309313e8 100644 --- a/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java +++ b/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java @@ -5,6 +5,8 @@ */ package com.linkedin.coral.common.transformers; +import java.util.List; + import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.type.RelDataType; @@ -66,6 +68,10 @@ protected RelDataType 
deriveRelDatatype(SqlNode sqlNode) { return typeDerivationUtil.getRelDataType(sqlNode); } + protected RelDataType leastRestrictive(List<RelDataType> types) { + return typeDerivationUtil.leastRestrictive(types); + } + /** * This function creates a {@link SqlOperator} for a function with the function name and return type inference. */ diff --git a/coral-common/src/main/java/com/linkedin/coral/common/utils/TypeDerivationUtil.java b/coral-common/src/main/java/com/linkedin/coral/common/utils/TypeDerivationUtil.java index c9b181198..e0ba27c51 100644 --- a/coral-common/src/main/java/com/linkedin/coral/common/utils/TypeDerivationUtil.java +++ b/coral-common/src/main/java/com/linkedin/coral/common/utils/TypeDerivationUtil.java @@ -88,6 +88,10 @@ public RelDataType getRelDataType(SqlNode sqlNode) { sqlNode, topSelectNodes.get(0))); } + public RelDataType leastRestrictive(List<RelDataType> types) { + return sqlValidator.getTypeFactory().leastRestrictive(types); + } + private class SqlNodePreprocessorForTypeDerivation extends SqlShuttle { @Override public SqlNode visit(SqlCall sqlCall) { diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java index 546edec86..056120e91 100644 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java @@ -21,6 +21,7 @@ import com.linkedin.coral.trino.rel2trino.transformers.GenericProjectTransformer; import com.linkedin.coral.trino.rel2trino.transformers.NamedStructToCastTransformer; import com.linkedin.coral.trino.rel2trino.transformers.SubstrOperatorTransformer; +import com.linkedin.coral.trino.rel2trino.transformers.UnionSqlCallTransformer; /** @@ -42,7 +43,8 @@ public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode to TypeDerivationUtil 
typeDerivationUtil = new TypeDerivationUtil(toRelConverter.getSqlValidator(), topSqlNode); operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil), new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil), - new ConcatOperatorTransformer(typeDerivationUtil), new SubstrOperatorTransformer(typeDerivationUtil)); + new ConcatOperatorTransformer(typeDerivationUtil), new SubstrOperatorTransformer(typeDerivationUtil), + new UnionSqlCallTransformer(typeDerivationUtil)); } @Override diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/UnionSqlCallTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/UnionSqlCallTransformer.java new file mode 100644 index 000000000..61bcd3427 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/UnionSqlCallTransformer.java @@ -0,0 +1,180 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.coral.trino.rel2trino.transformers; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlBasicTypeNameSpec; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlTypeNameSpec; +import org.apache.calcite.sql.fun.SqlCastFunction; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; + +import com.linkedin.coral.common.transformers.SqlCallTransformer; +import com.linkedin.coral.common.utils.TypeDerivationUtil; + +import static org.apache.calcite.rel.rel2sql.SqlImplementor.POS; +import static org.apache.calcite.sql.parser.SqlParserPos.ZERO; + + +/** + * This transformer class is used to adapt expressions in + * set statements (e.g.: `UNION`, `INTERSECT`, `MINUS`) + * in case that the branches of the set statement contain fields in + * char family which have different types. + * The `char` fields which are differing from the expected `varchar` output + * of the set statement will be adapted through an explicit `CAST` to the `varchar` type. 
+ * + * @see <a href="https://github.com/trinodb/trino/issues/9031">Change char/varchar coercion in Trino</a> + */ +public class UnionSqlCallTransformer extends SqlCallTransformer { + + public UnionSqlCallTransformer(TypeDerivationUtil typeDerivationUtil) { + super(typeDerivationUtil); + } + + @Override + protected boolean condition(SqlCall sqlCall) { + return sqlCall.getOperator().kind == SqlKind.UNION || sqlCall.getOperator().kind == SqlKind.INTERSECT + || sqlCall.getOperator().kind == SqlKind.MINUS; + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + List<SqlNode> operandsList = sqlCall.getOperandList(); + if (sqlCall.getOperandList().isEmpty()) { + return sqlCall; + } + Integer selectListSize = null; + // Ensure all the operands are SELECT nodes and they have the same number of expressions projected + for (SqlNode operand : operandsList) { + if (operand.getKind() == SqlKind.SELECT && ((SqlSelect) operand).getSelectList() != null) { + SqlSelect select = (SqlSelect) operand; + List<SqlNode> selectList = select.getSelectList().getList(); + if (selectListSize == null) { + selectListSize = selectList.size(); + } else if (selectListSize != selectList.size()) { + return sqlCall; + } + } else { + return sqlCall; + } + } + + List<Optional<RelDataType>> leastRestrictiveSelectItemTypes = new ArrayList<>(selectListSize); + for (int i = 0; i < selectListSize; i++) { + List<RelDataType> selectItemTypes = new ArrayList<>(); + boolean selectItemTypesDerived = true; + for (SqlNode operand : operandsList) { + SqlSelect select = (SqlSelect) operand; + List<SqlNode> selectList = select.getSelectList().getList(); + SqlNode selectItem = selectList.get(i); + if (selectItem.getKind() == SqlKind.IDENTIFIER && ((SqlIdentifier) selectItem).isStar()) { + // The type derivation of * on the SqlNode layer is not supported now. 
+ return sqlCall; + } + + try { + selectItemTypes.add(deriveRelDatatype(selectItem)); + } catch (RuntimeException e) { + // The type derivation may fail for complex expressions + selectItemTypesDerived = false; + break; + } + } + + Optional<RelDataType> leastRestrictiveSelectItemType = + selectItemTypesDerived ? Optional.ofNullable(leastRestrictive(selectItemTypes)) : Optional.empty(); + leastRestrictiveSelectItemTypes.add(leastRestrictiveSelectItemType); + } + + boolean operandsUpdated = false; + for (SqlNode operand : operandsList) { + SqlSelect select = (SqlSelect) operand; + List<SqlNode> selectList = select.getSelectList().getList(); + List<SqlNode> rewrittenSelectList = null; + for (int i = 0; i < selectList.size(); i++) { + SqlNode selectItem = selectList.get(i); + if (!leastRestrictiveSelectItemTypes.get(i).isPresent()) { + // Couldn't determine the type for all the expressions corresponding to the selection index + if (rewrittenSelectList != null) { + rewrittenSelectList.add(selectItem); + } + continue; + } + RelDataType leastRestrictiveSelectItemType = leastRestrictiveSelectItemTypes.get(i).get(); + RelDataType selectItemType = deriveRelDatatype(selectItem); + if (selectItemType.getSqlTypeName() == SqlTypeName.CHAR + && leastRestrictiveSelectItemType.getSqlTypeName() == SqlTypeName.VARCHAR) { + // Work-around for the Trino limitation in dealing with UNION statements between `char` and `varchar`. 
+ // See https://github.com/trinodb/trino/issues/9031 + SqlNode rewrittenSelectItem = castNode(selectItem, leastRestrictiveSelectItemType); + if (rewrittenSelectList == null) { + rewrittenSelectList = new ArrayList<>(selectListSize); + rewrittenSelectList.addAll(selectList.subList(0, i)); + operandsUpdated = true; + } + rewrittenSelectList.add(rewrittenSelectItem); + } else if (rewrittenSelectList != null) { + rewrittenSelectList.add(selectItem); + } + } + + if (rewrittenSelectList != null) { + select.setSelectList(new SqlNodeList(rewrittenSelectList, SqlParserPos.ZERO)); + } + } + + if (operandsUpdated) { + return sqlCall.getOperator().createCall(POS, operandsList); + } + + return sqlCall; + } + + private SqlNode castNode(SqlNode node, RelDataType type) { + if (node.getKind() == SqlKind.AS) { + SqlNode expression = ((SqlCall) node).getOperandList().get(0); + SqlIdentifier identifier = (SqlIdentifier) ((SqlCall) node).getOperandList().get(1); + return SqlStdOperatorTable.AS.createCall(POS, new SqlCastFunction() { + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { + SqlCallBinding opBinding = new SqlCallBinding(validator, scope, call); + return inferReturnType(opBinding); + } + }.createCall(ZERO, expression, getSqlDataTypeSpec(type)), identifier); + } else { + // If there's no existing alias, just do the cast + return new SqlCastFunction() { + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { + SqlCallBinding opBinding = new SqlCallBinding(validator, scope, call); + return inferReturnType(opBinding); + } + }.createCall(ZERO, node, getSqlDataTypeSpec(type)); + } + } + + private static SqlDataTypeSpec getSqlDataTypeSpec(RelDataType relDataType) { + final SqlTypeNameSpec typeNameSpec = new SqlBasicTypeNameSpec(relDataType.getSqlTypeName(), + relDataType.getPrecision(), relDataType.getScale(), null, ZERO); + return new SqlDataTypeSpec(typeNameSpec, ZERO); + 
} +} diff --git a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java index 3a1bddc3c..816df7bc0 100644 --- a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java +++ b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java @@ -193,7 +193,50 @@ public Object[][] viewTestCasesProvider() { + "FROM \"test\".\"duplicate_column_name_a\" AS \"duplicate_column_name_a\"\n" + "LEFT JOIN (SELECT TRIM(\"duplicate_column_name_b\".\"some_id\") AS \"SOME_ID\", CAST(TRIM(\"duplicate_column_name_b\".\"some_id\") AS VARCHAR(65536)) AS \"$f1\"\n" + "FROM \"test\".\"duplicate_column_name_b\" AS \"duplicate_column_name_b\") AS \"t\" ON \"duplicate_column_name_a\".\"some_id\" = \"t\".\"$f1\") AS \"t0\"\n" - + "WHERE \"t0\".\"some_id\" <> ''" } }; + + "WHERE \"t0\".\"some_id\" <> ''" }, + + { "test", "view_char_different_size_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(255)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT CAST(\"table_with_mixed_columns0\".\"a_char255\" AS VARCHAR(255)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" }, + + { "test", "view_cast_char_to_varchar", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"" }, + + { "test", "view_cast_char_to_varchar_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT CAST(CASE WHEN \"table_with_mixed_columns0\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns0\".\"a_char1\" ELSE 'N' END AS 
VARCHAR(65535)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" }, + + { "test", "view_cast_char_to_varchar_in_union_flipped", "SELECT CAST(CASE WHEN \"table_with_mixed_columns\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns\".\"a_char1\" ELSE 'N' END AS VARCHAR(65535)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT CAST(\"table_with_mixed_columns0\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" }, + + { "test", "view_cast_char_to_varchar_with_other_fields_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"text\", \"table_with_mixed_columns\".\"a_boolean\" AS \"a_boolean\", \"table_with_mixed_columns\".\"a_smallint\" AS \"a_number\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT CAST(CASE WHEN \"table_with_mixed_columns0\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns0\".\"a_char1\" ELSE 'N' END AS VARCHAR(65535)) AS \"text\", \"table_with_mixed_columns0\".\"a_boolean\" AS \"a_boolean\", \"table_with_mixed_columns0\".\"a_integer\" AS \"a_number\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" }, + + { "test", "view_char_and_null_in_union", "SELECT \"table_with_mixed_columns\".\"a_char1\" AS \"text\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT NULL AS \"text\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" }, + + { "test", "view_different_numerical_types_in_union", "SELECT *\n" + "FROM (SELECT *\n" + + "FROM (SELECT \"table_with_mixed_columns\".\"a_tinyint\" AS \"a_number\", \"table_with_mixed_columns\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS 
\"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT \"table_with_mixed_columns0\".\"a_smallint\" AS \"a_number\", \"table_with_mixed_columns0\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\") AS \"t1\"\n" + "UNION ALL\n" + + "SELECT \"table_with_mixed_columns1\".\"a_integer\" AS \"a_number\", \"table_with_mixed_columns1\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns1\") AS \"t3\"\n" + "UNION ALL\n" + + "SELECT \"table_with_mixed_columns2\".\"a_bigint\" AS \"a_number\", \"table_with_mixed_columns2\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns2\"" }, + + { "test", "view_union_no_casting", "SELECT \"table_with_mixed_columns\".\"a_tinyint\" AS \"a_tinyint\", \"table_with_mixed_columns\".\"a_smallint\" AS \"a_smallint\", \"table_with_mixed_columns\".\"a_integer\" AS \"a_integer\", \"table_with_mixed_columns\".\"a_bigint\" AS \"a_bigint\", \"table_with_mixed_columns\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n" + + "SELECT \"table_with_mixed_columns0\".\"a_tinyint\" AS \"a_tinyint\", \"table_with_mixed_columns0\".\"a_smallint\" AS \"a_smallint\", \"table_with_mixed_columns0\".\"a_integer\" AS \"a_integer\", \"table_with_mixed_columns0\".\"a_bigint\" AS \"a_bigint\", \"table_with_mixed_columns0\".\"a_float\" AS \"a_float\"\n" + + "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" } }; } @Test diff --git a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/TestUtils.java b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/TestUtils.java index 2c6bb0f97..0409333ff 100644 --- a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/TestUtils.java +++ b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/TestUtils.java @@ -378,6 
+378,42 @@ public static void initializeTablesAndViews(HiveConf conf) throws HiveException, run(driver, "CREATE TABLE test.table_with_binary_column (b binary)"); + run(driver, + "CREATE TABLE test.table_with_mixed_columns (a_char1 char(1), a_char255 char(255), a_string string, a_tinyint tinyint, a_smallint smallint, a_integer int, a_bigint bigint, a_float float, a_double double, a_boolean boolean)"); + run(driver, "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar AS \n" + + "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT EXISTS test.view_char_different_size_in_union AS \n" + + "SELECT a_char1 AS col FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT a_char255 AS col FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_in_union AS \n" + + "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT COALESCE(a_char1, 'N') AS col FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_in_union_flipped AS \n" + + "SELECT COALESCE(a_char1, 'N') as col FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns"); + run(driver, "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_with_other_fields_in_union AS \n" + + "SELECT CAST(a_char1 AS VARCHAR(65535)) AS text , a_boolean, a_smallint as a_number FROM test.table_with_mixed_columns \n" + + "UNION ALL\n" + + "SELECT COALESCE(a_char1, 'N') as text, a_boolean, a_integer as a_number FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT EXISTS test.view_char_and_null_in_union AS \n" + + "SELECT a_char1 as text FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT NULL text FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT 
EXISTS test.view_different_numerical_types_in_union AS \n" + + "SELECT a_tinyint AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT a_smallint AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT a_integer AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n" + + "SELECT a_bigint AS a_number, a_float FROM test.table_with_mixed_columns"); + run(driver, + "CREATE VIEW IF NOT EXISTS test.view_union_no_casting AS \n" + + "SELECT a_tinyint, a_smallint, a_integer, a_bigint, a_float FROM test.table_with_mixed_columns \n" + + "UNION ALL\n" + + "SELECT a_tinyint, a_smallint, a_integer, a_bigint, a_float FROM test.table_with_mixed_columns"); + // Tables used in RelToTrinoConverterTest run(driver, "CREATE TABLE IF NOT EXISTS test.tableOne(icol int, dcol double, scol string, tcol timestamp, acol array)");