Skip to content

Commit

Permalink
[Coral-Trino] Cast char fields, if necessary, to varchar type in the …
Browse files Browse the repository at this point in the history
…set operation (#442)

In case of dealing with Hive views which make use of set operation (e.g. UNION)
ensure that the `char` fields from the inner SELECT statements have the same type
as the output field types of the set operation.

Due to wrong coercion between `varchar` and `char` in Trino, as described in
trinodb/trino#9031
, a work-around needs to be applied in case of translating Hive views which
contain a UNION dealing with char and varchar types.
The work-around consists in the explicit cast of the field
having char type towards varchar type corresponding of the set operation
output type.
  • Loading branch information
findinpath authored Oct 5, 2023
1 parent a886a95 commit 4bb38a1
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
*/
package com.linkedin.coral.common.transformers;

import java.util.List;

import com.google.common.collect.ImmutableList;

import org.apache.calcite.rel.type.RelDataType;
Expand Down Expand Up @@ -66,6 +68,10 @@ protected RelDataType deriveRelDatatype(SqlNode sqlNode) {
return typeDerivationUtil.getRelDataType(sqlNode);
}

protected RelDataType leastRestrictive(List<RelDataType> types) {
return typeDerivationUtil.leastRestrictive(types);
}

/**
* This function creates a {@link SqlOperator} for a function with the function name and return type inference.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ public RelDataType getRelDataType(SqlNode sqlNode) {
sqlNode, topSelectNodes.get(0)));
}

public RelDataType leastRestrictive(List<RelDataType> types) {
return sqlValidator.getTypeFactory().leastRestrictive(types);
}

private class SqlNodePreprocessorForTypeDerivation extends SqlShuttle {
@Override
public SqlNode visit(SqlCall sqlCall) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.linkedin.coral.trino.rel2trino.transformers.GenericProjectTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.NamedStructToCastTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.SubstrOperatorTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.UnionSqlCallTransformer;


/**
Expand All @@ -42,7 +43,8 @@ public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode to
TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(toRelConverter.getSqlValidator(), topSqlNode);
operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil),
new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil),
new ConcatOperatorTransformer(typeDerivationUtil), new SubstrOperatorTransformer(typeDerivationUtil));
new ConcatOperatorTransformer(typeDerivationUtil), new SubstrOperatorTransformer(typeDerivationUtil),
new UnionSqlCallTransformer(typeDerivationUtil));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.trino.rel2trino.transformers;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlBasicTypeNameSpec;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlTypeNameSpec;
import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;

import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.common.utils.TypeDerivationUtil;

import static org.apache.calcite.rel.rel2sql.SqlImplementor.POS;
import static org.apache.calcite.sql.parser.SqlParserPos.ZERO;


/**
* This transformer class is used to adapt expressions in
* set statements (e.g.: `UNION`, `INTERSECT`, `MINUS`)
* in case that the branches of the set statement contain fields in
* char family which have different types.
* The `char` fields which are differing from the expected `varchar` output
* of the set statement will be adapted through an explicit `CAST` to the `varchar` type.
*
* @see <a href="https://github.com/trinodb/trino/issues/9031">Change char/varchar coercion in Trino</a>
*/
public class UnionSqlCallTransformer extends SqlCallTransformer {

public UnionSqlCallTransformer(TypeDerivationUtil typeDerivationUtil) {
super(typeDerivationUtil);
}

@Override
protected boolean condition(SqlCall sqlCall) {
return sqlCall.getOperator().kind == SqlKind.UNION || sqlCall.getOperator().kind == SqlKind.INTERSECT
|| sqlCall.getOperator().kind == SqlKind.MINUS;
}

@Override
protected SqlCall transform(SqlCall sqlCall) {
List<SqlNode> operandsList = sqlCall.getOperandList();
if (sqlCall.getOperandList().isEmpty()) {
return sqlCall;
}
Integer selectListSize = null;
// Ensure all the operand are SELECT nodes and they have the same number of expressions projected
for (SqlNode operand : operandsList) {
if (operand.getKind() == SqlKind.SELECT && ((SqlSelect) operand).getSelectList() != null) {
SqlSelect select = (SqlSelect) operand;
List<SqlNode> selectList = select.getSelectList().getList();
if (selectListSize == null) {
selectListSize = selectList.size();
} else if (selectListSize != selectList.size()) {
return sqlCall;
}
} else {
return sqlCall;
}
}

List<Optional<RelDataType>> leastRestrictiveSelectItemTypes = new ArrayList<>(selectListSize);
for (int i = 0; i < selectListSize; i++) {
List<RelDataType> selectItemTypes = new ArrayList<>();
boolean selectItemTypesDerived = true;
for (SqlNode operand : operandsList) {
SqlSelect select = (SqlSelect) operand;
List<SqlNode> selectList = select.getSelectList().getList();
SqlNode selectItem = selectList.get(i);
if (selectItem.getKind() == SqlKind.IDENTIFIER && ((SqlIdentifier) selectItem).isStar()) {
// The type derivation of * on the SqlNode layer is not supported now.
return sqlCall;
}

try {
selectItemTypes.add(deriveRelDatatype(selectItem));
} catch (RuntimeException e) {
// The type derivation may fail for complex expressions
selectItemTypesDerived = false;
break;
}
}

Optional<RelDataType> leastRestrictiveSelectItemType =
selectItemTypesDerived ? Optional.ofNullable(leastRestrictive(selectItemTypes)) : Optional.empty();
leastRestrictiveSelectItemTypes.add(leastRestrictiveSelectItemType);
}

boolean operandsUpdated = false;
for (SqlNode operand : operandsList) {
SqlSelect select = (SqlSelect) operand;
List<SqlNode> selectList = select.getSelectList().getList();
List<SqlNode> rewrittenSelectList = null;
for (int i = 0; i < selectList.size(); i++) {
SqlNode selectItem = selectList.get(i);
if (!leastRestrictiveSelectItemTypes.get(i).isPresent()) {
// Couldn't determine the type for all the expressions corresponding to the selection index
if (rewrittenSelectList != null) {
rewrittenSelectList.add(selectItem);
}
continue;
}
RelDataType leastRestrictiveSelectItemType = leastRestrictiveSelectItemTypes.get(i).get();
RelDataType selectItemType = deriveRelDatatype(selectItem);
if (selectItemType.getSqlTypeName() == SqlTypeName.CHAR
&& leastRestrictiveSelectItemType.getSqlTypeName() == SqlTypeName.VARCHAR) {
// Work-around for the Trino limitation in dealing UNION statements between `char` and `varchar`.
// See https://github.com/trinodb/trino/issues/9031
SqlNode rewrittenSelectItem = castNode(selectItem, leastRestrictiveSelectItemType);
if (rewrittenSelectList == null) {
rewrittenSelectList = new ArrayList<>(selectListSize);
rewrittenSelectList.addAll(selectList.subList(0, i));
operandsUpdated = true;
}
rewrittenSelectList.add(rewrittenSelectItem);
} else if (rewrittenSelectList != null) {
rewrittenSelectList.add(selectItem);
}
}

if (rewrittenSelectList != null) {
select.setSelectList(new SqlNodeList(rewrittenSelectList, SqlParserPos.ZERO));
}
}

if (operandsUpdated) {
return sqlCall.getOperator().createCall(POS, operandsList);
}

return sqlCall;
}

private SqlNode castNode(SqlNode node, RelDataType type) {
if (node.getKind() == SqlKind.AS) {
SqlNode expression = ((SqlCall) node).getOperandList().get(0);
SqlIdentifier identifier = (SqlIdentifier) ((SqlCall) node).getOperandList().get(1);
return SqlStdOperatorTable.AS.createCall(POS, new SqlCastFunction() {
@Override
public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) {
SqlCallBinding opBinding = new SqlCallBinding(validator, scope, call);
return inferReturnType(opBinding);
}
}.createCall(ZERO, expression, getSqlDataTypeSpec(type)), identifier);
} else {
// If there's no existing alias, just do the cast
return new SqlCastFunction() {
@Override
public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) {
SqlCallBinding opBinding = new SqlCallBinding(validator, scope, call);
return inferReturnType(opBinding);
}
}.createCall(ZERO, node, getSqlDataTypeSpec(type));
}
}

private static SqlDataTypeSpec getSqlDataTypeSpec(RelDataType relDataType) {
final SqlTypeNameSpec typeNameSpec = new SqlBasicTypeNameSpec(relDataType.getSqlTypeName(),
relDataType.getPrecision(), relDataType.getScale(), null, ZERO);
return new SqlDataTypeSpec(typeNameSpec, ZERO);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,50 @@ public Object[][] viewTestCasesProvider() {
+ "FROM \"test\".\"duplicate_column_name_a\" AS \"duplicate_column_name_a\"\n"
+ "LEFT JOIN (SELECT TRIM(\"duplicate_column_name_b\".\"some_id\") AS \"SOME_ID\", CAST(TRIM(\"duplicate_column_name_b\".\"some_id\") AS VARCHAR(65536)) AS \"$f1\"\n"
+ "FROM \"test\".\"duplicate_column_name_b\" AS \"duplicate_column_name_b\") AS \"t\" ON \"duplicate_column_name_a\".\"some_id\" = \"t\".\"$f1\") AS \"t0\"\n"
+ "WHERE \"t0\".\"some_id\" <> ''" } };
+ "WHERE \"t0\".\"some_id\" <> ''" },

{ "test", "view_char_different_size_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(255)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT CAST(\"table_with_mixed_columns0\".\"a_char255\" AS VARCHAR(255)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" },

{ "test", "view_cast_char_to_varchar", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"" },

{ "test", "view_cast_char_to_varchar_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT CAST(CASE WHEN \"table_with_mixed_columns0\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns0\".\"a_char1\" ELSE 'N' END AS VARCHAR(65535)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" },

{ "test", "view_cast_char_to_varchar_in_union_flipped", "SELECT CAST(CASE WHEN \"table_with_mixed_columns\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns\".\"a_char1\" ELSE 'N' END AS VARCHAR(65535)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT CAST(\"table_with_mixed_columns0\".\"a_char1\" AS VARCHAR(65535)) AS \"col\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" },

{ "test", "view_cast_char_to_varchar_with_other_fields_in_union", "SELECT CAST(\"table_with_mixed_columns\".\"a_char1\" AS VARCHAR(65535)) AS \"text\", \"table_with_mixed_columns\".\"a_boolean\" AS \"a_boolean\", \"table_with_mixed_columns\".\"a_smallint\" AS \"a_number\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT CAST(CASE WHEN \"table_with_mixed_columns0\".\"a_char1\" IS NOT NULL THEN \"table_with_mixed_columns0\".\"a_char1\" ELSE 'N' END AS VARCHAR(65535)) AS \"text\", \"table_with_mixed_columns0\".\"a_boolean\" AS \"a_boolean\", \"table_with_mixed_columns0\".\"a_integer\" AS \"a_number\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" },

{ "test", "view_char_and_null_in_union", "SELECT \"table_with_mixed_columns\".\"a_char1\" AS \"text\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT NULL AS \"text\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" },

{ "test", "view_different_numerical_types_in_union", "SELECT *\n" + "FROM (SELECT *\n"
+ "FROM (SELECT \"table_with_mixed_columns\".\"a_tinyint\" AS \"a_number\", \"table_with_mixed_columns\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT \"table_with_mixed_columns0\".\"a_smallint\" AS \"a_number\", \"table_with_mixed_columns0\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\") AS \"t1\"\n" + "UNION ALL\n"
+ "SELECT \"table_with_mixed_columns1\".\"a_integer\" AS \"a_number\", \"table_with_mixed_columns1\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns1\") AS \"t3\"\n" + "UNION ALL\n"
+ "SELECT \"table_with_mixed_columns2\".\"a_bigint\" AS \"a_number\", \"table_with_mixed_columns2\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns2\"" },

{ "test", "view_union_no_casting", "SELECT \"table_with_mixed_columns\".\"a_tinyint\" AS \"a_tinyint\", \"table_with_mixed_columns\".\"a_smallint\" AS \"a_smallint\", \"table_with_mixed_columns\".\"a_integer\" AS \"a_integer\", \"table_with_mixed_columns\".\"a_bigint\" AS \"a_bigint\", \"table_with_mixed_columns\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns\"\n" + "UNION ALL\n"
+ "SELECT \"table_with_mixed_columns0\".\"a_tinyint\" AS \"a_tinyint\", \"table_with_mixed_columns0\".\"a_smallint\" AS \"a_smallint\", \"table_with_mixed_columns0\".\"a_integer\" AS \"a_integer\", \"table_with_mixed_columns0\".\"a_bigint\" AS \"a_bigint\", \"table_with_mixed_columns0\".\"a_float\" AS \"a_float\"\n"
+ "FROM \"test\".\"table_with_mixed_columns\" AS \"table_with_mixed_columns0\"" } };
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,42 @@ public static void initializeTablesAndViews(HiveConf conf) throws HiveException,

run(driver, "CREATE TABLE test.table_with_binary_column (b binary)");

run(driver,
"CREATE TABLE test.table_with_mixed_columns (a_char1 char(1), a_char255 char(255), a_string string, a_tinyint tinyint, a_smallint smallint, a_integer int, a_bigint bigint, a_float float, a_double double, a_boolean boolean)");
run(driver, "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar AS \n"
+ "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_char_different_size_in_union AS \n"
+ "SELECT a_char1 AS col FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT a_char255 AS col FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_in_union AS \n"
+ "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT COALESCE(a_char1, 'N') AS col FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_in_union_flipped AS \n"
+ "SELECT COALESCE(a_char1, 'N') as col FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT CAST(a_char1 AS VARCHAR(65535)) AS col FROM test.table_with_mixed_columns");
run(driver, "CREATE VIEW IF NOT EXISTS test.view_cast_char_to_varchar_with_other_fields_in_union AS \n"
+ "SELECT CAST(a_char1 AS VARCHAR(65535)) AS text , a_boolean, a_smallint as a_number FROM test.table_with_mixed_columns \n"
+ "UNION ALL\n"
+ "SELECT COALESCE(a_char1, 'N') as text, a_boolean, a_integer as a_number FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_char_and_null_in_union AS \n"
+ "SELECT a_char1 as text FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT NULL text FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_different_numerical_types_in_union AS \n"
+ "SELECT a_tinyint AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT a_smallint AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT a_integer AS a_number, a_float FROM test.table_with_mixed_columns \n" + "UNION ALL\n"
+ "SELECT a_bigint AS a_number, a_float FROM test.table_with_mixed_columns");
run(driver,
"CREATE VIEW IF NOT EXISTS test.view_union_no_casting AS \n"
+ "SELECT a_tinyint, a_smallint, a_integer, a_bigint, a_float FROM test.table_with_mixed_columns \n"
+ "UNION ALL\n"
+ "SELECT a_tinyint, a_smallint, a_integer, a_bigint, a_float FROM test.table_with_mixed_columns");

// Tables used in RelToTrinoConverterTest
run(driver,
"CREATE TABLE IF NOT EXISTS test.tableOne(icol int, dcol double, scol string, tcol timestamp, acol array<string>)");
Expand Down

0 comments on commit 4bb38a1

Please sign in to comment.