Skip to content

Commit

Permalink
[Coral-Trino] Cast char fields, if necessary, to varchar type in the …
Browse files Browse the repository at this point in the history
…set operation

When dealing with Hive views that make use of a set operation (e.g. UNION),
ensure that the `char` fields from the inner SELECT statements have the same type
as the output field types of the set operation.

Due to the incorrect coercion between `varchar` and `char` in Trino, as described in
trinodb/trino#9031, a work-around needs to be applied when translating Hive views
which contain a UNION dealing with char and varchar types.
The work-around consists of explicitly casting each field of char type
to the varchar type that corresponds to the output type of the set
operation.
  • Loading branch information
findinpath committed Sep 19, 2023
1 parent db56069 commit 37f39d3
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
*/
package com.linkedin.coral.common.transformers;

import java.util.List;

import com.google.common.collect.ImmutableList;

import org.apache.calcite.rel.type.RelDataType;
Expand Down Expand Up @@ -66,6 +68,10 @@ protected RelDataType deriveRelDatatype(SqlNode sqlNode) {
return typeDerivationUtil.getRelDataType(sqlNode);
}

/**
 * Determines the least restrictive type for the given list of types by delegating to
 * {@link TypeDerivationUtil#leastRestrictive(List)}.
 *
 * @param types the types for which a common type is requested
 * @return whatever the underlying {@code typeDerivationUtil} reports as the least restrictive type
 */
protected RelDataType leastRestrictive(List<RelDataType> types) {
  final RelDataType commonType = typeDerivationUtil.leastRestrictive(types);
  return commonType;
}

/**
* This function creates a {@link SqlOperator} for a function with the function name and return type inference.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ public RelDataType getRelDataType(SqlNode sqlNode) {
sqlNode, topSelectNodes.get(0)));
}

/**
 * Determines the least restrictive type for the given list of types, using the type factory
 * of the {@code sqlValidator} held by this utility.
 *
 * @param types the types for which a common type is requested
 * @return the result of the type factory's {@code leastRestrictive} call
 *         (NOTE(review): Calcite documents a {@code null} result when no common type exists - callers should check)
 */
public RelDataType leastRestrictive(List<RelDataType> types) {
  final RelDataType commonType = sqlValidator.getTypeFactory().leastRestrictive(types);
  return commonType;
}

private class SqlNodePreprocessorForTypeDerivation extends SqlShuttle {
@Override
public SqlNode visit(SqlCall sqlCall) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.linkedin.coral.trino.rel2trino.transformers.GenericProjectTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.NamedStructToCastTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.SubstrOperatorTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.UnionSqlCallTransformer;


/**
Expand All @@ -42,7 +43,8 @@ public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode to
TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(toRelConverter.getSqlValidator(), topSqlNode);
operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil),
new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil),
new ConcatOperatorTransformer(typeDerivationUtil), new SubstrOperatorTransformer(typeDerivationUtil));
new ConcatOperatorTransformer(typeDerivationUtil), new SubstrOperatorTransformer(typeDerivationUtil),
new UnionSqlCallTransformer(typeDerivationUtil));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.trino.rel2trino.transformers;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlBasicTypeNameSpec;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlTypeNameSpec;
import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;

import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.common.utils.TypeDerivationUtil;

import static org.apache.calcite.rel.rel2sql.SqlImplementor.POS;
import static org.apache.calcite.sql.parser.SqlParserPos.ZERO;


/**
 * This transformer class is used to adapt expressions in
 * set statements (e.g.: `UNION`, `INTERSECT`, `EXCEPT`/`MINUS`)
 * in case that the branches of the set statement contain fields in
 * char family which have different types.
 * The `char` fields which are differing from the expected `varchar` output
 * of the set statement will be adapted through an explicit `CAST` to the `varchar` type.
 *
 * @see <a href="https://github.com/trinodb/trino/issues/9031">Change char/varchar coercion in Trino</a>
 */
public class UnionSqlCallTransformer extends SqlCallTransformer {

  /**
   * CAST operator whose type derivation goes directly through return type inference.
   * A single stateless instance is shared by all the casts created by this transformer.
   */
  private static final SqlCastFunction CAST_OPERATOR = new SqlCastFunction() {
    @Override
    public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) {
      SqlCallBinding opBinding = new SqlCallBinding(validator, scope, call);
      return inferReturnType(opBinding);
    }
  };

  public UnionSqlCallTransformer(TypeDerivationUtil typeDerivationUtil) {
    super(typeDerivationUtil);
  }

  /**
   * Matches the set operators only.
   * Note that in Calcite {@link SqlKind#MINUS} denotes the arithmetic subtraction operator,
   * the set difference (spelled `EXCEPT` or `MINUS` depending on the dialect) is {@link SqlKind#EXCEPT}.
   */
  @Override
  protected boolean condition(SqlCall sqlCall) {
    SqlKind kind = sqlCall.getOperator().kind;
    return kind == SqlKind.UNION || kind == SqlKind.INTERSECT || kind == SqlKind.EXCEPT;
  }

  /**
   * Wraps in a `CAST(... AS VARCHAR(n))` every `char` select item of the set statement branches whose
   * column resolves to a `varchar` type as output of the set operation.
   *
   * The call is returned untouched when one of the branches is not a plain `SELECT` with a select list,
   * when a branch selects `*`, or when the branches do not expose the same number of select items.
   *
   * @param sqlCall the set operation call
   * @return the original call in case it cannot be analyzed, otherwise a new call of the same operator
   *         over the (possibly updated in place) branches
   */
  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    List<SqlNode> operandsList = sqlCall.getOperandList();

    Optional<List<List<Optional<RelDataType>>>> collectedColumnTypes = collectColumnTypes(operandsList);
    if (!collectedColumnTypes.isPresent()) {
      return sqlCall;
    }
    // columnsTypesLists.get(column).get(operand) is the type of the select item `column` of the branch `operand`
    List<List<Optional<RelDataType>>> columnsTypesLists = collectedColumnTypes.get();
    // Computed once per column instead of once per column and branch
    List<RelDataType> outputColumnTypes = resolveOutputColumnTypes(columnsTypesLists);

    List<SqlNode> updatedOperands = new ArrayList<>(operandsList.size());
    for (int operandIndex = 0; operandIndex < operandsList.size(); operandIndex++) {
      SqlNode operand = operandsList.get(operandIndex);
      SqlSelect select = (SqlSelect) operand;
      List<SqlNode> selectNodes = select.getSelectList().getList();

      List<SqlNode> rewrittenSelectNodes = new ArrayList<>(selectNodes.size());
      boolean useRewrittenSelectNodes = false;
      for (int i = 0; i < selectNodes.size(); i++) {
        SqlNode selectNode = selectNodes.get(i);
        RelDataType outputColumnType = outputColumnTypes.get(i);
        // A non-null output type guarantees that the type of every branch is known for this column
        if (outputColumnType != null
            && requiresCast(columnsTypesLists.get(i).get(operandIndex).get(), outputColumnType)) {
          // Work-around for the Trino limitation in dealing with UNION statements between `char` and `varchar`.
          // See https://github.com/trinodb/trino/issues/9031
          rewrittenSelectNodes.add(castNode(selectNode, outputColumnType));
          useRewrittenSelectNodes = true;
        } else {
          // Every other select item is retained as is, including the ones with an unknown type,
          // so that the rewritten select list never loses columns.
          rewrittenSelectNodes.add(selectNode);
        }
      }

      if (useRewrittenSelectNodes) {
        select.setSelectList(new SqlNodeList(rewrittenSelectNodes, SqlParserPos.ZERO));
      }
      updatedOperands.add(operand);
    }
    return sqlCall.getOperator().createCall(POS, updatedOperands);
  }

  /**
   * Derives the type of each select item of each branch of the set statement.
   *
   * @return the derived types grouped by column position (one entry per branch, in branch order),
   *         or an empty {@link Optional} when the set statement cannot be analyzed
   */
  private Optional<List<List<Optional<RelDataType>>>> collectColumnTypes(List<SqlNode> operandsList) {
    List<List<Optional<RelDataType>>> columnsTypesLists = new ArrayList<>();
    boolean firstOperand = true;
    for (SqlNode operand : operandsList) {
      if (operand.getKind() != SqlKind.SELECT || ((SqlSelect) operand).getSelectList() == null) {
        return Optional.empty();
      }
      List<SqlNode> selectNodes = ((SqlSelect) operand).getSelectList().getList();
      if (firstOperand) {
        for (int i = 0; i < selectNodes.size(); i++) {
          columnsTypesLists.add(new ArrayList<>());
        }
        firstOperand = false;
      } else if (selectNodes.size() != columnsTypesLists.size()) {
        // Branches with a different number of select items can't be aligned column by column
        return Optional.empty();
      }

      for (int i = 0; i < selectNodes.size(); i++) {
        SqlNode sqlNode = selectNodes.get(i);
        if (sqlNode.getKind() == SqlKind.IDENTIFIER && ((SqlIdentifier) sqlNode).isStar()) {
          // The type derivation of * on the SqlNode layer is not supported now.
          return Optional.empty();
        }
        columnsTypesLists.get(i).add(tryDeriveRelDatatype(sqlNode));
      }
    }
    return Optional.of(columnsTypesLists);
  }

  /**
   * Derives the type of the node, an empty {@link Optional} is returned when the derivation fails.
   */
  private Optional<RelDataType> tryDeriveRelDatatype(SqlNode sqlNode) {
    try {
      return Optional.of(deriveRelDatatype(sqlNode));
    } catch (RuntimeException e) {
      // The type derivation may fail for complex expressions
      return Optional.empty();
    }
  }

  /**
   * Computes for each column the least restrictive type across the branches.
   * The entry of a column is {@code null} when the type of at least one branch is unknown
   * or when no common type could be determined.
   */
  private List<RelDataType> resolveOutputColumnTypes(List<List<Optional<RelDataType>>> columnsTypesLists) {
    List<RelDataType> outputColumnTypes = new ArrayList<>(columnsTypesLists.size());
    for (List<Optional<RelDataType>> columnTypes : columnsTypesLists) {
      if (columnTypes.stream().allMatch(Optional::isPresent)) {
        outputColumnTypes.add(leastRestrictive(columnTypes.stream().map(Optional::get).collect(Collectors.toList())));
      } else {
        // Couldn't determine the type for all the expressions
        outputColumnTypes.add(null);
      }
    }
    return outputColumnTypes;
  }

  /**
   * Tells whether a select item of type {@code selectNodeDataType} is a `char` which has to be
   * cast to the `varchar` type {@code outputColumnType} of the set statement.
   */
  private static boolean requiresCast(RelDataType selectNodeDataType, RelDataType outputColumnType) {
    return !selectNodeDataType.equals(outputColumnType) && selectNodeDataType.getSqlTypeName() == SqlTypeName.CHAR
        && outputColumnType.getSqlTypeName() == SqlTypeName.VARCHAR;
  }

  /**
   * Wraps the node in a `CAST` towards the specified type.
   * For an aliased expression only the expression is cast and the alias is retained.
   */
  private SqlNode castNode(SqlNode node, RelDataType type) {
    if (node.getKind() != SqlKind.AS) {
      // If there's no existing alias, just do the cast
      return CAST_OPERATOR.createCall(ZERO, node, getSqlDataTypeSpecForCasting(type));
    }
    List<SqlNode> aliasOperands = ((SqlCall) node).getOperandList();
    SqlNode expression = aliasOperands.get(0);
    SqlIdentifier identifier = (SqlIdentifier) aliasOperands.get(1);
    SqlNode castExpression = CAST_OPERATOR.createCall(ZERO, expression, getSqlDataTypeSpecForCasting(type));
    return SqlStdOperatorTable.AS.createCall(POS, castExpression, identifier);
  }

  /**
   * Builds the type specification used as target of the `CAST`
   * out of the type name, precision and scale of the specified type.
   */
  private static SqlDataTypeSpec getSqlDataTypeSpecForCasting(RelDataType relDataType) {
    final SqlTypeNameSpec typeNameSpec = new SqlBasicTypeNameSpec(relDataType.getSqlTypeName(),
        relDataType.getPrecision(), relDataType.getScale(), null, ZERO);
    return new SqlDataTypeSpec(typeNameSpec, ZERO);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,50 @@ public Object[][] viewTestCasesProvider() {
+ "FROM \"test\".\"duplicate_column_name_a\" AS \"duplicate_column_name_a\"\n"
+ "LEFT JOIN (SELECT TRIM(\"duplicate_column_name_b\".\"some_id\") AS \"SOME_ID\", CAST(TRIM(\"duplicate_column_name_b\".\"some_id\") AS VARCHAR(65536)) AS \"$f1\"\n"
+ "FROM \"test\".\"duplicate_column_name_b\" AS \"duplicate_column_name_b\") AS \"t\" ON \"duplicate_column_name_a\".\"some_id\" = \"t\".\"$f1\") AS \"t0\"\n"
+ "WHERE \"t0\".\"some_id\" <> ''" } };
+ "WHERE \"t0\".\"some_id\" <> ''" },

{ "test", "view_char_different_size_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(255)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT CAST(\"table_with_mixed_columns0\".\"a_char255\" AS VARCHAR(255)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" },

{ "test", "view_cast_char_to_varchar", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"" },

{ "test", "view_cast_char_to_varchar_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT CAST(CASE WHEN \"table_with_mixed_columns0\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns0\".\"a_char1\" ELSE 'N' END AS VARCHAR(65535)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" },

{ "test", "view_cast_char_to_varchar_in_union_flipped", "SELECT CAST(CASE WHEN \"table_with_mixed_columns\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns\".\"a_char1\" ELSE 'N' END AS VARCHAR(65535)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT CAST(\"table_with_mixed_columns0\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" },

{ "test", "view_cast_char_to_varchar_with_other_fields_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"text\", \"table_with_mixed_columns\".\"a_boolean\" AS \"a_boolean\", \"table_with_mixed_columns\".\"a_smallint\" AS \"a_number\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT CAST(CASE WHEN \"table_with_mixed_columns0\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns0\".\"a_char1\" ELSE 'N' END AS VARCHAR(65535)) AS \"text\", \"table_with_mixed_columns0\".\"a_boolean\" AS \"a_boolean\", \"table_with_mixed_columns0\".\"a_integer\" AS \"a_number\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" },

{ "test", "view_char_and_null_in_union", "SELECT \"table_with_mixed_columns\".\"a_char1\" AS \"text\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT NULL AS \"text\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" },

{ "test", "view_different_numerical_types_in_union", "SELECT *\n" + "FROM (SELECT *\n"
+ "FROM (SELECT \"table_with_mixed_columns\".\"a_tinyint\" AS \"a_number\", \"table_with_mixed_columns\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT \"table_with_mixed_columns0\".\"a_smallint\" AS \"a_number\", \"table_with_mixed_columns0\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\") AS \"t1\"\n" + "UNION ALL\n"
+ "SELECT \"table_with_mixed_columns1\".\"a_integer\" AS \"a_number\", \"table_with_mixed_columns1\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns1\") AS \"t3\"\n" + "UNION ALL\n"
+ "SELECT \"table_with_mixed_columns2\".\"a_bigint\" AS \"a_number\", \"table_with_mixed_columns2\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns2\"" },

{ "test", "view_union_no_casting", "SELECT \"table_with_mixed_columns\".\"a_tinyint\" AS \"a_tinyint\", \"table_with_mixed_columns\".\"a_smallint\" AS \"a_smallint\", \"table_with_mixed_columns\".\"a_integer\" AS \"a_integer\", \"table_with_mixed_columns\".\"a_bigint\" AS \"a_bigint\", \"table_with_mixed_columns\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT \"table_with_mixed_columns0\".\"a_tinyint\" AS \"a_tinyint\", \"table_with_mixed_columns0\".\"a_smallint\" AS \"a_smallint\", \"table_with_mixed_columns0\".\"a_integer\" AS \"a_integer\", \"table_with_mixed_columns0\".\"a_bigint\" AS \"a_bigint\", \"table_with_mixed_columns0\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" } };
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,42 @@ public static void initializeTablesAndViews(HiveConf conf) throws HiveException,

run(driver, "CREATE TABLE test.table_with_binary_column (b binary)");

run(driver,
"CREATE TABLE test.table_with_mixed_columns (a_char1 char(1), a_char255 char(255), a_string string, a_tinyint tinyint, a_smallint smallint, a_integer int, a_bigint bigint, a_float float, a_double double, a_boolean boolean)");
run(driver, "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar AS \n"
+ "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_char_different_size_in_union AS \n"
+ "SELECT a_char1 AS col FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT a_char255 AS col FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_in_union AS \n"
+ "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT COALESCE(a_char1, 'N') AS col FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_in_union_flipped AS \n"
+ "SELECT COALESCE(a_char1, 'N') as col FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns");
run(driver, "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_with_other_fields_in_union AS \n"
+ "SELECT CAST(a_char1 AS VARCHAR(65535)) AS text , a_boolean, a_smallint as a_number FROM test.table_with_mixed_columns \n"
+ "UNION ALL\n"
+ "SELECT COALESCE(a_char1, 'N') as text, a_boolean, a_integer as a_number FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_char_and_null_in_union AS \n"
+ "SELECT a_char1 as text FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT NULL text FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_different_numerical_types_in_union AS \n"
+ "SELECT a_tinyint AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT a_smallint AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT a_integer AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT a_bigint AS a_number, a_float FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_union_no_casting AS \n"
+ "SELECT a_tinyint, a_smallint, a_integer, a_bigint, a_float FROM test.table_with_mixed_columns \n"
+ "UNION ALL\n"
+ "SELECT a_tinyint, a_smallint, a_integer, a_bigint, a_float FROM test.table_with_mixed_columns");

// Tables used in RelToTrinoConverterTest
run(driver,
"CREATE TABLE IF NOT EXISTS test.tableOne(icol int, dcol double, scol string, tcol timestamp, acol array<string>)");
Expand Down

0 comments on commit 37f39d3

Please sign in to comment.