diff --git a/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java b/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java index 0ea9d89af..6d245aa4f 100644 --- a/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java +++ b/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java @@ -26,7 +26,6 @@ public abstract class SqlCallTransformer { private TypeDerivationUtil typeDerivationUtil; public SqlCallTransformer() { - } public SqlCallTransformer(TypeDerivationUtil typeDerivationUtil) { diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverter.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverter.java index f29fe4636..7243efd47 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverter.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverter.java @@ -63,6 +63,10 @@ public HiveToRelConverter(Map>> localMetaStore) this.parseTreeBuilder = new ParseTreeBuilder(functionResolver); } + public HiveFunctionResolver getFunctionResolver() { + return functionResolver; + } + @Override protected SqlRexConvertletTable getConvertletTable() { return new HiveConvertletTable(); diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/HiveFunctionResolver.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/HiveFunctionResolver.java index 7e9c922e6..5159e5533 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/HiveFunctionResolver.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/HiveFunctionResolver.java @@ -203,6 +203,12 @@ public Collection tryResolveAsDaliFunction(String functionName, @Nonnu .collect(Collectors.toList()); } + public void addDynamicFunctionToTheRegistry(String funcClassName, Function function) { + if (!dynamicFunctionRegistry.contains(funcClassName)) { + dynamicFunctionRegistry.put(funcClassName, function); + } + } + private @Nonnull Collection resolveDaliFunctionDynamically(String functionName, String funcClassName, HiveTable hiveTable, int numOfOperands) { if (dynamicFunctionRegistry.contains(funcClassName)) { diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java index 1bc27d1bc..6cb29c496 100644 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java @@ -5,7 +5,6 @@ */ package com.linkedin.coral.trino.rel2trino; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -175,34 +174,9 @@ public RexNode visitCall(RexCall call) { } } - if (operatorName.equalsIgnoreCase("concat")) { - Optional modifiedCall = visitConcat(call); - if (modifiedCall.isPresent()) { - return modifiedCall.get(); - } - } - return super.visitCall(call); } - private Optional visitConcat(RexCall call) { - // Hive supports operations like CONCAT(date, varchar) while Trino only supports CONCAT(varchar, varchar) - // So we need to cast the unsupported types to varchar - final SqlOperator op = call.getOperator(); - List convertedOperands = visitList(call.getOperands(), (boolean[]) null); - List castOperands = new ArrayList<>(); - - for (RexNode inputOperand : convertedOperands) { - if (inputOperand.getType().getSqlTypeName() != VARCHAR && inputOperand.getType().getSqlTypeName() != CHAR) { - final RexNode castOperand = rexBuilder.makeCast(typeFactory.createSqlType(VARCHAR), inputOperand); - castOperands.add(castOperand); - } else { - castOperands.add(inputOperand); - } - } - return Optional.of(rexBuilder.makeCall(op, castOperands)); - } - // Hive allows passing in a byte array or String to substr/substring, so we can make an effort to emulate the // behavior by casting non-String input to String // https://cwiki.apache.org/confluence/display/hive/languagemanual+udf diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java index 2da1962a6..c6334c953 100644 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java @@ -5,15 +5,18 @@ */ package com.linkedin.coral.trino.rel2trino; +import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.util.SqlShuttle; -import org.apache.calcite.sql.validate.SqlValidator; import com.linkedin.coral.common.HiveMetastoreClient; +import com.linkedin.coral.common.functions.Function; import com.linkedin.coral.common.transformers.SqlCallTransformers; import com.linkedin.coral.common.utils.TypeDerivationUtil; import com.linkedin.coral.hive.hive2rel.HiveToRelConverter; +import com.linkedin.coral.hive.hive2rel.functions.VersionedSqlUserDefinedFunction; +import com.linkedin.coral.trino.rel2trino.transformers.ConcatOperatorTransformer; import com.linkedin.coral.trino.rel2trino.transformers.FromUtcTimestampOperatorTransformer; import com.linkedin.coral.trino.rel2trino.transformers.GenericProjectTransformer; import com.linkedin.coral.trino.rel2trino.transformers.NamedStructToCastTransformer; @@ -29,16 +32,34 @@ */ public class DataTypeDerivedSqlCallConverter extends SqlShuttle { private final SqlCallTransformers operatorTransformerList; + private final HiveToRelConverter toRelConverter; public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode topSqlNode) { - SqlValidator sqlValidator = new HiveToRelConverter(mscClient).getSqlValidator(); - TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(sqlValidator, topSqlNode); + toRelConverter = new HiveToRelConverter(mscClient); + topSqlNode.accept(new RegisterDynamicFunctionsForTypeDerivation()); + + TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(toRelConverter.getSqlValidator(), topSqlNode); operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil), - new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil)); + new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil), + new ConcatOperatorTransformer(typeDerivationUtil)); } @Override public SqlNode visit(final SqlCall call) { return operatorTransformerList.apply((SqlCall) super.visit(call)); } + + private class RegisterDynamicFunctionsForTypeDerivation extends SqlShuttle { + @Override + public SqlNode visit(SqlCall sqlCall) { + if (sqlCall instanceof SqlBasicCall && sqlCall.getOperator() instanceof VersionedSqlUserDefinedFunction + && sqlCall.getOperator().getName().contains(".")) { + // Register versioned SqlUserDefinedFunctions in RelConverter's dynamicFunctionRegistry. + // This enables the SqlValidator to derive RelDataType for SqlCalls that involve these operators. + Function function = new Function(sqlCall.getOperator().getName(), sqlCall.getOperator()); + toRelConverter.getFunctionResolver().addDynamicFunctionToTheRegistry(sqlCall.getOperator().getName(), function); + } + return super.visit(sqlCall); + } + } } diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java index 47b87a6e8..416c0299f 100644 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java @@ -138,13 +138,11 @@ public Result visit(Project e) { final List selectList = new ArrayList<>(); for (RexNode ref : e.getChildExps()) { SqlNode sqlExpr = builder.context.toSql(null, ref); - // Append the CAST operator when the derived data type is NON-NULL. RelDataTypeField targetField = e.getRowType().getFieldList().get(selectList.size()); if (SqlUtil.isNullLiteral(sqlExpr, false) && !targetField.getValue().getSqlTypeName().equals(SqlTypeName.NULL)) { sqlExpr = SqlStdOperatorTable.CAST.createCall(POS, sqlExpr, dialect.getCastSpec(targetField.getType())); } - addSelect(selectList, sqlExpr, e.getRowType()); } builder.setSelect(new SqlNodeList(selectList, POS)); diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/ConcatOperatorTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/ConcatOperatorTransformer.java new file mode 100644 index 000000000..9ba38a274 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/ConcatOperatorTransformer.java @@ -0,0 +1,69 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.trino.rel2trino.transformers; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlBasicTypeNameSpec; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; + +import com.linkedin.coral.common.HiveTypeSystem; +import com.linkedin.coral.common.transformers.SqlCallTransformer; +import com.linkedin.coral.common.utils.TypeDerivationUtil; + +import static org.apache.calcite.rel.rel2sql.SqlImplementor.*; +import static org.apache.calcite.sql.parser.SqlParserPos.*; + + +/** + * This transformer is designed for SqlCalls that use the CONCAT operator. + * Its purpose is to convert the data types of the operands to be compatible with Trino. + * Trino only allows VARCHAR type operands for the CONCAT operator. Therefore, if there are any other data type operands present, + * an extra CAST operator is added around the operand to cast it to VARCHAR. + */ +public class ConcatOperatorTransformer extends SqlCallTransformer { + private static final int DEFAULT_VARCHAR_PRECISION = new HiveTypeSystem().getDefaultPrecision(SqlTypeName.VARCHAR); + private static final String OPERATOR_NAME = "concat"; + private static final Set OPERAND_SQL_TYPE_NAMES = + new HashSet<>(Arrays.asList(SqlTypeName.VARCHAR, SqlTypeName.CHAR)); + private static final SqlDataTypeSpec VARCHAR_SQL_DATA_TYPE_SPEC = + new SqlDataTypeSpec(new SqlBasicTypeNameSpec(SqlTypeName.VARCHAR, DEFAULT_VARCHAR_PRECISION, ZERO), ZERO); + + public ConcatOperatorTransformer(TypeDerivationUtil typeDerivationUtil) { + super(typeDerivationUtil); + } + + @Override + protected boolean condition(SqlCall sqlCall) { + return sqlCall.getOperator().getName().equalsIgnoreCase(OPERATOR_NAME); + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + List updatedOperands = new ArrayList<>(); + + for (SqlNode operand : sqlCall.getOperandList()) { + RelDataType type = deriveRelDatatype(operand); + if (!OPERAND_SQL_TYPE_NAMES.contains(type.getSqlTypeName())) { + SqlNode castOperand = SqlStdOperatorTable.CAST.createCall(POS, + new ArrayList<>(Arrays.asList(operand, VARCHAR_SQL_DATA_TYPE_SPEC))); + updatedOperands.add(castOperand); + } else { + updatedOperands.add(operand); + } + } + return sqlCall.getOperator().createCall(POS, updatedOperands); + } +} diff --git a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java index af23538fb..2bad16a47 100644 --- a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java +++ b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java @@ -714,6 +714,19 @@ public void testDateFormatFunction() { assertEquals(expandedSql, targetSql); } + @Test + public void testConcatWithUnionAndStar() { + RelNode relNode = TestUtils.getHiveToRelConverter().convertSql( + "SELECT * from test.tableA union all SELECT * from test.tableB where concat(current_date(), '|', tableB.a) = 'invalid'"); + RelToTrinoConverter relToTrinoConverter = TestUtils.getRelToTrinoConverter(); + String expandedSql = relToTrinoConverter.convert(relNode); + + String expected = "SELECT *\n" + "FROM \"test\".\"tablea\" AS \"tablea\"\n" + "UNION ALL\n" + "SELECT *\n" + + "FROM \"test\".\"tableb\" AS \"tableb\"\n" + + "WHERE \"concat\"(CAST(CURRENT_DATE AS VARCHAR(65535)), '|', CAST(\"tableb\".\"a\" AS VARCHAR(65535))) = 'invalid'"; + assertEquals(expandedSql, expected); + } + @Test public void testConcatFunction() { RelToTrinoConverter relToTrinoConverter = TestUtils.getRelToTrinoConverter(); @@ -771,8 +784,9 @@ public void testRegexpTransformation() { assertEquals(expandedSql, targetSql); } + @Test public void testSqlSelectAliasAppenderTransformer() { - //test.tableA(a int, b struct + // test.tableA(a int, b struct RelNode relNode = TestUtils.getHiveToRelConverter().convertSql("SELECT tableA.b.b1 FROM test.tableA where a > 5"); RelToTrinoConverter relToTrinoConverter = TestUtils.getRelToTrinoConverter(); String expandedSql = relToTrinoConverter.convert(relNode);