From 74c2ca82dfbdec406230001320b3a347ef96b628 Mon Sep 17 00:00:00 2001
From: Kevin Ge
Date: Wed, 31 Jul 2024 16:58:50 -0400
Subject: [PATCH] Correctly handle single type uniontypes in Coral (#507)

* fix single uniontypes in Coral
* remove SingleUnionFieldReferenceTransformer
* remove field reference fix to put in separate PR
* spotless
* update comments
* fix comment + add single uniontype test for RelDataTypeToHiveTypeStringConverter
* spotless
* improve Javadoc for ExtractUnionFunctionTransformer
* use html brackets in javadoc
---
 .../linkedin/coral/common/TypeConverter.java  |  5 +-
 .../FunctionFieldReferenceOperator.java       | 11 +--
 .../RelDataTypeToHiveTypeStringConverter.java | 29 ++++++-
 ...TypeToHiveDataTypeStringConverterTest.java | 16 +++-
 .../HiveSqlNodeToCoralSqlNodeConverter.java   |  6 +-
 .../SingleUnionFieldReferenceTransformer.java | 49 -----------
 .../com/linkedin/coral/spark/CoralSpark.java  | 29 +++++--
 .../spark/CoralToSparkSqlCallConverter.java   |  6 +-
 .../DataTypeDerivedSqlCallConverter.java      | 47 +++++++++++
 .../ExtractUnionFunctionTransformer.java      | 82 ++++++++++++++++++-
 .../linkedin/coral/spark/CoralSparkTest.java  | 26 +++++-
 .../com/linkedin/coral/spark/TestUtils.java   |  4 +-
 12 files changed, 224 insertions(+), 86 deletions(-)
 delete mode 100644 coral-hive/src/main/java/com/linkedin/coral/transformers/SingleUnionFieldReferenceTransformer.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/DataTypeDerivedSqlCallConverter.java

diff --git a/coral-common/src/main/java/com/linkedin/coral/common/TypeConverter.java b/coral-common/src/main/java/com/linkedin/coral/common/TypeConverter.java
index 183b4730b..022538eaa 100644
--- a/coral-common/src/main/java/com/linkedin/coral/common/TypeConverter.java
+++ b/coral-common/src/main/java/com/linkedin/coral/common/TypeConverter.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2017-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
@@ -147,9 +147,6 @@ public static RelDataType convert(StructTypeInfo structType, final RelDataTypeFa
   public static RelDataType convert(UnionTypeInfo unionType, RelDataTypeFactory dtFactory) {
     List<RelDataType> fTypes = unionType.getAllUnionObjectTypeInfos().stream()
         .map(typeInfo -> convert(typeInfo, dtFactory)).collect(Collectors.toList());
-    if (fTypes.size() == 1) {
-      return dtFactory.createTypeWithNullability(fTypes.get(0), true);
-    }
     List<String> fNames = IntStream.range(0, unionType.getAllUnionObjectTypeInfos().size()).mapToObj(i -> "field" + i)
         .collect(Collectors.toList());
     fTypes.add(0, dtFactory.createSqlType(SqlTypeName.INTEGER));
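
For illustration only (not part of this patch): a minimal sketch of what the updated TypeConverter.convert produces for a single-alternative uniontype. It assumes a standalone Calcite type factory; Coral wires up its factory differently in production.

    import org.apache.calcite.rel.type.RelDataType;
    import org.apache.calcite.rel.type.RelDataTypeSystem;
    import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
    import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;

    import com.linkedin.coral.common.TypeConverter;

    public class SingleUnionTypeSketch {
      public static void main(String[] args) {
        // A Hive union that holds only one alternative
        UnionTypeInfo unionType =
            (UnionTypeInfo) TypeInfoUtils.getTypeInfoFromTypeString("uniontype<string>");

        RelDataType coralType =
            TypeConverter.convert(unionType, new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT));

        // Before this patch the single alternative was unwrapped to a plain VARCHAR;
        // now it keeps the uniform union shape, roughly RecordType(INTEGER tag, VARCHAR field0).
        System.out.println(coralType.getFullTypeString());
      }
    }
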
diff --git a/coral-common/src/main/java/com/linkedin/coral/common/functions/FunctionFieldReferenceOperator.java b/coral-common/src/main/java/com/linkedin/coral/common/functions/FunctionFieldReferenceOperator.java
index ec4e125ea..f947da078 100644
--- a/coral-common/src/main/java/com/linkedin/coral/common/functions/FunctionFieldReferenceOperator.java
+++ b/coral-common/src/main/java/com/linkedin/coral/common/functions/FunctionFieldReferenceOperator.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2018-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
@@ -74,15 +74,6 @@ public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, S
       if (funcType.isStruct()) {
         return funcType.getField(fieldNameStripQuotes(call.operand(1)), false, false).getType();
       }
-
-      // When the first operand is a SqlBasicCall with a non-struct RelDataType and the second operand is `tag_0`,
-      // such as `extract_union`(`product`.`value`).`tag_0` or (`extract_union`(`product`.`value`).`id`).`tag_0`,
-      // derived data type is first operand's RelDataType.
-      // This strategy ensures that RelDataType derivation remains successful for the specified sqlCalls while maintaining backward compatibility.
-      // Such SqlCalls are transformed {@link com.linkedin.coral.transformers.SingleUnionFieldReferenceTransformer}
-      if (FunctionFieldReferenceOperator.fieldNameStripQuotes(call.operand(1)).equalsIgnoreCase("tag_0")) {
-        return funcType;
-      }
     }
     return super.deriveType(validator, scope, call);
   }
diff --git a/coral-common/src/main/java/com/linkedin/coral/common/utils/RelDataTypeToHiveTypeStringConverter.java b/coral-common/src/main/java/com/linkedin/coral/common/utils/RelDataTypeToHiveTypeStringConverter.java
index 1aea5a064..c5a652af2 100644
--- a/coral-common/src/main/java/com/linkedin/coral/common/utils/RelDataTypeToHiveTypeStringConverter.java
+++ b/coral-common/src/main/java/com/linkedin/coral/common/utils/RelDataTypeToHiveTypeStringConverter.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2022-2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2022-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
@@ -41,6 +41,25 @@ public class RelDataTypeToHiveTypeStringConverter {
   private RelDataTypeToHiveTypeStringConverter() {
   }
+  public RelDataTypeToHiveTypeStringConverter(boolean convertUnionTypes) {
+    this.convertUnionTypes = convertUnionTypes;
+  }
+
+  /**
+   * If true, Coral will convert single uniontypes back to Hive's native uniontype representation. This is necessary
+   * because some engines have readers that unwrap Hive single uniontypes to just the underlying data type, causing
+   * the loss of information that the column was originally a uniontype in Hive. This can be problematic when calling
+   * the `coalesce_struct` UDF on such columns, as they are expected to be treated as uniontypes. Retaining the
+   * original uniontype record and passing it into `coalesce_struct` ensures correct handling.
+   *
+   * Example:
+   * RelDataType:
+   *   struct(tag:integer,field0:varchar)
+   * Hive Type String:
+   *   uniontype<string>
+   */
+  private static boolean convertUnionTypes = false;
+
   /**
    * @param relDataType a given RelDataType
    * @return a syntactically and semantically correct Hive type string for relDataType
@@ -110,6 +129,14 @@ public static String convertRelDataType(RelDataType relDataType) {
    */
   private static String buildStructDataTypeString(RelRecordType relRecordType) {
     List<String> structFieldStrings = new ArrayList<>();
+
+    // Convert single uniontypes (represented as structs in Coral IR) back to the native Hive uniontype representation
+    if (convertUnionTypes && relRecordType.getFieldList().size() == 2
+        && relRecordType.getFieldList().get(0).getName().equals("tag")
+        && relRecordType.getFieldList().get(1).getName().equals("field0")) {
+      return String.format("uniontype<%s>", convertRelDataType(relRecordType.getFieldList().get(1).getType()));
+    }
+
     for (RelDataTypeField fieldRelDataType : relRecordType.getFieldList()) {
       structFieldStrings
           .add(String.format("%s:%s", fieldRelDataType.getName(), convertRelDataType(fieldRelDataType.getType())));
diff --git a/coral-common/src/test/java/com/linkedin/coral/common/utils/RelDataTypeToHiveDataTypeStringConverterTest.java b/coral-common/src/test/java/com/linkedin/coral/common/utils/RelDataTypeToHiveDataTypeStringConverterTest.java
index e2c5ca637..0ac88bcc9 100644
--- a/coral-common/src/test/java/com/linkedin/coral/common/utils/RelDataTypeToHiveDataTypeStringConverterTest.java
+++ b/coral-common/src/test/java/com/linkedin/coral/common/utils/RelDataTypeToHiveDataTypeStringConverterTest.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2019-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
@@ -178,4 +178,18 @@ public void testCharRelDataType() {
 
     assertEquals(hiveDataTypeSchemaString, expectedHiveDataTypeSchemaString);
   }
+
+  @Test
+  public void testSingleUniontypeStructRelDataType() {
+    String expectedHiveDataTypeSchemaString = "uniontype<string>";
+
+    List<RelDataTypeField> fields = new ArrayList<>();
+    fields.add(new RelDataTypeFieldImpl("tag", 0, new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER)));
+    fields.add(new RelDataTypeFieldImpl("field0", 0, new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.VARCHAR)));
+
+    RelRecordType relRecordType = new RelRecordType(fields);
+    String hiveDataTypeSchemaString = new RelDataTypeToHiveTypeStringConverter(true).convertRelDataType(relRecordType);
+
+    assertEquals(hiveDataTypeSchemaString, expectedHiveDataTypeSchemaString);
+  }
 }
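
For illustration only (not part of this patch): a minimal sketch of the nested case, built the same way the new test builds its RelDataType. With convertUnionTypes enabled, a single-union record nested inside an ordinary struct is rendered as a Hive uniontype.

    import java.util.Arrays;
    import java.util.List;

    import org.apache.calcite.rel.type.RelDataTypeField;
    import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
    import org.apache.calcite.rel.type.RelDataTypeSystem;
    import org.apache.calcite.rel.type.RelRecordType;
    import org.apache.calcite.sql.type.BasicSqlType;
    import org.apache.calcite.sql.type.SqlTypeName;

    import com.linkedin.coral.common.utils.RelDataTypeToHiveTypeStringConverter;

    public class NestedSingleUnionSketch {
      public static void main(String[] args) {
        // Coral IR shape of a single uniontype: struct(tag:integer, field0:varchar)
        List<RelDataTypeField> unionFields = Arrays.asList(
            new RelDataTypeFieldImpl("tag", 0, new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER)),
            new RelDataTypeFieldImpl("field0", 1, new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.VARCHAR)));
        RelRecordType singleUnion = new RelRecordType(unionFields);

        // Nest it inside an ordinary struct under a field named "s"
        List<RelDataTypeField> outerFields = Arrays.asList(new RelDataTypeFieldImpl("s", 0, singleUnion));
        RelRecordType outer = new RelRecordType(outerFields);

        // With the flag enabled the nested record is rendered as a Hive uniontype,
        // e.g. struct<s:uniontype<string>>; with the default (false) it stays a plain struct field.
        System.out.println(new RelDataTypeToHiveTypeStringConverter(true).convertRelDataType(outer));
      }
    }
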
diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveSqlNodeToCoralSqlNodeConverter.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveSqlNodeToCoralSqlNodeConverter.java
index ab2fd8c32..8525a62f8 100644
--- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveSqlNodeToCoralSqlNodeConverter.java
+++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveSqlNodeToCoralSqlNodeConverter.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2017-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
@@ -13,7 +13,6 @@
 import com.linkedin.coral.common.transformers.SqlCallTransformers;
 import com.linkedin.coral.common.utils.TypeDerivationUtil;
 import com.linkedin.coral.transformers.ShiftArrayIndexTransformer;
-import com.linkedin.coral.transformers.SingleUnionFieldReferenceTransformer;
 
 
 /**
@@ -24,8 +23,7 @@ public class HiveSqlNodeToCoralSqlNodeConverter extends SqlShuttle {
 
   public HiveSqlNodeToCoralSqlNodeConverter(SqlValidator sqlValidator, SqlNode topSqlNode) {
     TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(sqlValidator, topSqlNode);
-    operatorTransformerList = SqlCallTransformers.of(new ShiftArrayIndexTransformer(typeDerivationUtil),
-        new SingleUnionFieldReferenceTransformer(typeDerivationUtil));
+    operatorTransformerList = SqlCallTransformers.of(new ShiftArrayIndexTransformer(typeDerivationUtil));
   }
 
   @Override
diff --git a/coral-hive/src/main/java/com/linkedin/coral/transformers/SingleUnionFieldReferenceTransformer.java b/coral-hive/src/main/java/com/linkedin/coral/transformers/SingleUnionFieldReferenceTransformer.java
deleted file mode 100644
index 32c721789..000000000
--- a/coral-hive/src/main/java/com/linkedin/coral/transformers/SingleUnionFieldReferenceTransformer.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/**
- * Copyright 2023 LinkedIn Corporation. All rights reserved.
- * Licensed under the BSD-2 Clause license.
- * See LICENSE in the project root for license information.
- */
-package com.linkedin.coral.transformers;
-
-import org.apache.calcite.sql.SqlBasicCall;
-import org.apache.calcite.sql.SqlCall;
-
-import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;
-import com.linkedin.coral.common.transformers.SqlCallTransformer;
-import com.linkedin.coral.common.utils.TypeDerivationUtil;
-
-
-/**
- * This transformer focuses on SqlCalls that involve a FunctionFieldReferenceOperator with the following characteristics:
- * (1) The first operand is a SqlBasicCall with a non-struct RelDataType, and the second operand is tag_0.
- *     This indicates that the first operand represents a Union data type with a single data type inside.
- * (2) Examples of such SqlCalls include extract_union(product.value).tag_0 or (extract_union(product.value).id).tag_0.
- * (3) The transformation for such SqlCalls is to return the first operand.
- */
-public class SingleUnionFieldReferenceTransformer extends SqlCallTransformer {
-  private static final String TAG_0_OPERAND = "tag_0";
-
-  public SingleUnionFieldReferenceTransformer(TypeDerivationUtil typeDerivationUtil) {
-    super(typeDerivationUtil);
-  }
-
-  @Override
-  protected boolean condition(SqlCall sqlCall) {
-    if (FunctionFieldReferenceOperator.DOT.getName().equalsIgnoreCase(sqlCall.getOperator().getName())) {
-      if (sqlCall.operand(0) instanceof SqlBasicCall) {
-        SqlBasicCall outerSqlBasicCall = sqlCall.operand(0);
-        boolean isOperandStruct = deriveRelDatatype(outerSqlBasicCall).isStruct();
-
-        return !isOperandStruct
-            && FunctionFieldReferenceOperator.fieldNameStripQuotes(sqlCall.operand(1)).equalsIgnoreCase(TAG_0_OPERAND);
-      }
-    }
-    return false;
-  }
-
-  @Override
-  protected SqlCall transform(SqlCall sqlCall) {
-    // convert x.tag_0 -> x where x is a sqlBasicCall with non-struct RelDataType
-    return sqlCall.operand(0);
-  }
-}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java
index 8ffba3f71..9ef1a23f2 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2018-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
@@ -13,6 +13,7 @@
 import org.apache.calcite.plan.RelOptTable;
 import org.apache.calcite.plan.RelOptUtil;
 import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.sql.SqlIdentifier;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.SqlSelect;
@@ -76,7 +77,7 @@ public static CoralSpark create(RelNode irRelNode, HiveMetastoreClient hmsClient
     SparkRelInfo sparkRelInfo = IRRelToSparkRelTransformer.transform(irRelNode);
     Set<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfos();
     RelNode sparkRelNode = sparkRelInfo.getSparkRelNode();
-    SqlNode sparkSqlNode = constructSparkSqlNode(sparkRelNode, sparkUDFInfos);
+    SqlNode sparkSqlNode = constructSparkSqlNode(sparkRelNode, sparkUDFInfos, hmsClient);
     String sparkSQL = constructSparkSQL(sparkSqlNode);
     List<String> baseTables = constructBaseTables(sparkRelNode);
     return new CoralSpark(baseTables, ImmutableList.copyOf(sparkUDFInfos), sparkSQL, hmsClient, sparkSqlNode);
@@ -101,11 +102,11 @@ private static CoralSpark createWithAlias(RelNode irRelNode, List aliase
     SparkRelInfo sparkRelInfo = IRRelToSparkRelTransformer.transform(irRelNode);
     Set<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfos();
     RelNode sparkRelNode = sparkRelInfo.getSparkRelNode();
-    SqlNode sparkSqlNode = constructSparkSqlNode(sparkRelNode, sparkUDFInfos);
+    SqlNode sparkSqlNode = constructSparkSqlNode(sparkRelNode, sparkUDFInfos, hmsClient);
     // Use a second pass visit to add explicit alias names,
     // only do this when it's not a select star case,
     // since for select star we don't need to add any explicit aliases
-    if (sparkSqlNode.getKind() == SqlKind.SELECT && ((SqlSelect) sparkSqlNode).getSelectList() != null) {
+    if (sparkSqlNode.getKind() == SqlKind.SELECT && !isSelectStar(sparkSqlNode)) {
       sparkSqlNode = sparkSqlNode.accept(new AddExplicitAlias(aliases));
     }
     String sparkSQL = constructSparkSQL(sparkSqlNode);
@@ -113,11 +114,16 @@ private static CoralSpark createWithAlias(RelNode irRelNode, List aliase
     return new CoralSpark(baseTables, ImmutableList.copyOf(sparkUDFInfos), sparkSQL, hmsClient, sparkSqlNode);
   }
 
-  private static SqlNode constructSparkSqlNode(RelNode sparkRelNode, Set<SparkUDFInfo> sparkUDFInfos) {
+  private static SqlNode constructSparkSqlNode(RelNode sparkRelNode, Set<SparkUDFInfo> sparkUDFInfos,
+      HiveMetastoreClient hmsClient) {
     CoralRelToSqlNodeConverter rel2sql = new CoralRelToSqlNodeConverter();
     SqlNode coralSqlNode = rel2sql.convert(sparkRelNode);
-    SqlNode sparkSqlNode = coralSqlNode.accept(new CoralSqlNodeToSparkSqlNodeConverter())
-        .accept(new CoralToSparkSqlCallConverter(sparkUDFInfos));
+
+    SqlNode coralSqlNodeWithRelDataTypeDerivedConversions =
+        coralSqlNode.accept(new DataTypeDerivedSqlCallConverter(hmsClient, coralSqlNode, sparkUDFInfos));
+
+    SqlNode sparkSqlNode = coralSqlNodeWithRelDataTypeDerivedConversions
+        .accept(new CoralSqlNodeToSparkSqlNodeConverter()).accept(new CoralToSparkSqlCallConverter(sparkUDFInfos));
     return sparkSqlNode.accept(new SparkSqlRewriter());
   }
 
@@ -125,6 +131,15 @@ public static String constructSparkSQL(SqlNode sparkSqlNode) {
     return sparkSqlNode.toSqlString(SparkSqlDialect.INSTANCE).getSql();
   }
 
+  private static boolean isSelectStar(SqlNode node) {
+    if (node.getKind() == SqlKind.SELECT && ((SqlSelect) node).getSelectList().size() == 1
+        && ((SqlSelect) node).getSelectList().get(0) instanceof SqlIdentifier) {
+      SqlIdentifier identifier = (SqlIdentifier) ((SqlSelect) node).getSelectList().get(0);
+      return identifier.isStar();
+    }
+    return false;
+  }
+
   /**
    * This function returns the list of base table names, in the format
    * "database_name.table_name".
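
For illustration only (not part of this patch): a rough sketch of how a caller exercises the new code path. It assumes an already-configured HiveMetastoreClient, and the query and table names are hypothetical.

    import org.apache.calcite.rel.RelNode;

    import com.linkedin.coral.common.HiveMetastoreClient;
    import com.linkedin.coral.hive.hive2rel.HiveToRelConverter;
    import com.linkedin.coral.spark.CoralSpark;

    public class CoralSparkUsageSketch {
      // Translate a Hive query into Spark SQL; create() now threads hmsClient through to
      // constructSparkSqlNode so DataTypeDerivedSqlCallConverter can derive RelDataTypes.
      public static String toSparkSql(HiveMetastoreClient hmsClient) {
        RelNode relNode = new HiveToRelConverter(hmsClient)
            .convertSql("SELECT extract_union(bar) FROM default.union_table");
        return CoralSpark.create(relNode, hmsClient).getSparkSql();
      }
    }
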
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java
index c8a86d35c..5170507c8 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
@@ -15,7 +15,6 @@
 import com.linkedin.coral.common.transformers.OperatorRenameSqlCallTransformer;
 import com.linkedin.coral.common.transformers.SqlCallTransformers;
 import com.linkedin.coral.spark.containers.SparkUDFInfo;
-import com.linkedin.coral.spark.transformers.ExtractUnionFunctionTransformer;
 import com.linkedin.coral.spark.transformers.FallBackToLinkedInHiveUDFTransformer;
 import com.linkedin.coral.spark.transformers.FuzzyUnionGenericProjectTransformer;
 import com.linkedin.coral.spark.transformers.TransportUDFTransformer;
@@ -157,9 +156,6 @@ public CoralToSparkSqlCallConverter(Set<SparkUDFInfo> sparkUDFInfos) {
         // Fall back to the original Hive UDF defined in StaticHiveFunctionRegistry after failing to apply transformers above
         new FallBackToLinkedInHiveUDFTransformer(sparkUDFInfos),
 
-        // Transform `extract_union` to `coalesce_struct`
-        new ExtractUnionFunctionTransformer(sparkUDFInfos),
-
         // Transform `generic_project` function
         new FuzzyUnionGenericProjectTransformer(sparkUDFInfos));
   }
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/DataTypeDerivedSqlCallConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/DataTypeDerivedSqlCallConverter.java
new file mode 100644
index 000000000..079031a7f
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/DataTypeDerivedSqlCallConverter.java
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2022-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark;
+
+import java.util.Set;
+
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.util.SqlShuttle;
+
+import com.linkedin.coral.common.HiveMetastoreClient;
+import com.linkedin.coral.common.transformers.SqlCallTransformers;
+import com.linkedin.coral.common.utils.TypeDerivationUtil;
+import com.linkedin.coral.hive.hive2rel.HiveToRelConverter;
+import com.linkedin.coral.spark.containers.SparkUDFInfo;
+import com.linkedin.coral.spark.transformers.ExtractUnionFunctionTransformer;
+
+
+/**
+ * DataTypeDerivedSqlCallConverter transforms the SqlCalls
+ * in the input SqlNode representation to be compatible with the Spark engine.
+ * The transformation may involve a change of operator, reordering of operands,
+ * or even re-construction of the SqlNode.
+ *
+ * All the transformations performed as part of this shuttle require RelDataType derivation.
+ */
+public class DataTypeDerivedSqlCallConverter extends SqlShuttle {
+  private final SqlCallTransformers operatorTransformerList;
+  private final HiveToRelConverter toRelConverter;
+  TypeDerivationUtil typeDerivationUtil;
+
+  public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode topSqlNode,
+      Set<SparkUDFInfo> sparkUDFInfos) {
+    toRelConverter = new HiveToRelConverter(mscClient);
+    typeDerivationUtil = new TypeDerivationUtil(toRelConverter.getSqlValidator(), topSqlNode);
+    operatorTransformerList =
+        SqlCallTransformers.of(new ExtractUnionFunctionTransformer(typeDerivationUtil, sparkUDFInfos));
+  }
+
+  @Override
+  public SqlNode visit(SqlCall call) {
+    return operatorTransformerList.apply((SqlCall) super.visit(call));
+  }
+}
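
For illustration only (not part of this patch): how this shuttle is applied in practice, mirroring the CoralSpark.constructSparkSqlNode change above. Because visit(SqlCall) calls super.visit(call) before applying the transformers, operands are rewritten before the enclosing call, so nested calls are converted bottom-up. The variable names below are hypothetical.

    // coralSqlNode: output of CoralRelToSqlNodeConverter; hmsClient: a configured HiveMetastoreClient
    Set<SparkUDFInfo> sparkUDFInfos = new HashSet<>();
    SqlNode rewritten =
        coralSqlNode.accept(new DataTypeDerivedSqlCallConverter(hmsClient, coralSqlNode, sparkUDFInfos));
    // sparkUDFInfos collects any Spark UDF registrations needed by the rewritten calls
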
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java
index 27d6884b1..9572bcbbd 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java
@@ -1,22 +1,31 @@
 /**
- * Copyright 2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
 package com.linkedin.coral.spark.transformers;
 
 import java.net.URI;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
 
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlLiteral;
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.SqlNumericLiteral;
 import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
 
 import com.linkedin.coral.com.google.common.collect.ImmutableList;
+import com.linkedin.coral.common.TypeConverter;
 import com.linkedin.coral.common.transformers.SqlCallTransformer;
+import com.linkedin.coral.common.utils.RelDataTypeToHiveTypeStringConverter;
+import com.linkedin.coral.common.utils.TypeDerivationUtil;
 import com.linkedin.coral.hive.hive2rel.functions.CoalesceStructUtility;
 import com.linkedin.coral.spark.containers.SparkUDFInfo;
 
@@ -30,14 +39,34 @@
  * See {@link CoalesceStructUtility#COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY} and its comments for more details.
  *
  * Check `CoralSparkTest#testUnionExtractUDF` for examples.
+ *
+ * Note that there is a Spark-specific mechanism that unwraps a single uniontype (a uniontype holding only one data type)
+ * to just the single underlying data type. This happens during the Avro schema to Spark schema conversion of base tables.
+ * The problem with this behavior is that we expect `coalesce_struct` to coalesce columns that originally contained
+ * single uniontypes, yet that information is lost once Spark strips the uniontype. To work around this, we retain
+ * information about the original schema and pass it to the `coalesce_struct` UDF as a schema string.
+ * Reference: https://spark.apache.org/docs/latest/sql-data-sources-avro.html#supported-types-for-avro---spark-sql-conversion
+ *
+ * For example, given an input SqlNode where `col` is a uniontype column holding only a string type:
+ *   "SELECT extract_union(col) FROM table"
+ *
+ * This transformer would transform the above SqlNode to:
+ *   "SELECT coalesce_struct(col, 'uniontype<string>') FROM table"
+ *
+ * Check `CoralSparkTest#testUnionExtractUDFOnSingleTypeUnions` for more examples, including ones where single
+ * uniontypes are nested in a struct.
+ *
  */
 public class ExtractUnionFunctionTransformer extends SqlCallTransformer {
   private static final String EXTRACT_UNION = "extract_union";
   private static final String COALESCE_STRUCT = "coalesce_struct";
 
   private final Set<SparkUDFInfo> sparkUDFInfos;
+  private static final RelDataTypeToHiveTypeStringConverter hiveTypeStringConverter =
+      new RelDataTypeToHiveTypeStringConverter(true);
 
-  public ExtractUnionFunctionTransformer(Set<SparkUDFInfo> sparkUDFInfos) {
+  public ExtractUnionFunctionTransformer(TypeDerivationUtil typeDerivationUtil, Set<SparkUDFInfo> sparkUDFInfos) {
+    super(typeDerivationUtil);
     this.sparkUDFInfos = sparkUDFInfos;
   }
 
@@ -56,6 +85,18 @@ protected SqlCall transform(SqlCall sqlCall) {
         createSqlOperator(COALESCE_STRUCT, CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY);
     if (operandList.size() == 1) {
       // one arg case: extract_union(field_name)
+      RelDataType operandType = deriveRelDatatype(sqlCall.operand(0));
+
+      if (containsSingleUnionType(operandType)) {
+        // Pass in a schema string to keep track of the original Hive schema containing single uniontypes, so the
+        // coalesce_struct UDF knows which fields are unwrapped single uniontypes. Without it, coalesce_struct would
+        // not coalesce the single uniontype fields as expected.
+        String operandSchemaString = hiveTypeStringConverter.convertRelDataType(deriveRelDatatype(sqlCall.operand(0)));
+        List<SqlNode> newOperandList = new ArrayList<>(operandList);
+        newOperandList.add(SqlLiteral.createCharString(operandSchemaString, SqlParserPos.ZERO));
+        return coalesceStructFunction.createCall(sqlCall.getParserPosition(), newOperandList);
+      }
+
       return coalesceStructFunction.createCall(sqlCall.getParserPosition(), operandList);
     } else if (operandList.size() == 2) {
       // two arg case: extract_union(field_name, ordinal)
@@ -66,4 +107,41 @@ protected SqlCall transform(SqlCall sqlCall) {
       return sqlCall;
     }
   }
+
+  private boolean containsSingleUnionType(RelDataType relDataType) {
+    if (isSingleUnionType(relDataType)) {
+      return true;
+    }
+
+    // Recursive case: if the current type is a struct, map or collection, check its nested types
+    if (relDataType.isStruct()) {
+      for (RelDataTypeField field : relDataType.getFieldList()) {
+        if (containsSingleUnionType(field.getType())) {
+          return true;
+        }
+      }
+    } else if (relDataType.getKeyType() != null) {
+      // For map type, check both key and value types
+      if (containsSingleUnionType(relDataType.getKeyType()) || containsSingleUnionType(relDataType.getValueType())) {
+        return true;
+      }
+    } else if (relDataType.getComponentType() != null) {
+      // For collection type, check the component type
+      if (containsSingleUnionType(relDataType.getComponentType())) {
+        return true;
+      }
+    }
+
+    return false;
+  }
+
+  /**
+   * Check if the given RelDataType is a single union type in Coral IR representation; the conversion to this
+   * representation happens in {@link TypeConverter#convert(UnionTypeInfo, RelDataTypeFactory)}.
+   */
+  private boolean isSingleUnionType(RelDataType relDataType) {
+    return relDataType.isStruct() && relDataType.getFieldList().size() == 2
+        && relDataType.getFieldList().get(0).getKey().equalsIgnoreCase("tag")
+        && relDataType.getFieldList().get(1).getKey().equalsIgnoreCase("field0");
+  }
 }
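
For illustration only (not part of this patch): a hedged example of the rewrite when the single uniontype is nested inside an array rather than a struct. The table and column names are hypothetical, and the exact type string follows from RelDataTypeToHiveTypeStringConverter with convertUnionTypes enabled.

    // Hypothetical column `arr` declared as array<uniontype<string>> in the Hive table:
    String input = "SELECT extract_union(arr) FROM db.t";
    // Expected shape of the rewritten query, with the original Hive schema passed as the second argument:
    String rewritten = "SELECT coalesce_struct(t.arr, 'array<uniontype<string>>') FROM db.t t";
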
diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java
index 666ea87e3..454ea0ac7 100644
--- a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java
+++ b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2018-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
@@ -421,6 +421,30 @@ public void testUnionExtractUDF() {
     assertEquals(createCoralSpark(relNode2).getSparkSql(), targetSql2);
   }
 
+  @Test
+  public void testUnionExtractUDFOnSingleTypeUnions() {
+    RelNode relNode = TestUtils.toRelNode("SELECT extract_union(bar) from union_table");
+    String targetSql = "SELECT coalesce_struct(union_table.bar, 'uniontype>')\n"
+        + "FROM default.union_table union_table";
+    assertEquals(createCoralSpark(relNode).getSparkSql(), targetSql);
+
+    RelNode relNode1 = TestUtils.toRelNode("SELECT extract_union(baz) from union_table");
+    String targetSql1 = "SELECT coalesce_struct(union_table.baz, 'struct>>')\n"
+        + "FROM default.union_table union_table";
+    assertEquals(createCoralSpark(relNode1).getSparkSql(), targetSql1);
+
+    RelNode relNode2 = TestUtils.toRelNode("SELECT extract_union(bar).tag_0 from union_table");
+    String targetSql2 = "SELECT coalesce_struct(union_table.bar, 'uniontype>').tag_0\n"
+        + "FROM default.union_table union_table";
+    assertEquals(createCoralSpark(relNode2).getSparkSql(), targetSql2);
+
+    RelNode relNode3 = TestUtils.toRelNode("SELECT extract_union(baz).single.tag_0 from union_table");
+    String targetSql4 =
+        "SELECT (coalesce_struct(union_table.baz, 'struct>>').single).tag_0\n"
+            + "FROM default.union_table union_table";
+    assertEquals(createCoralSpark(relNode3).getSparkSql(), targetSql4);
+  }
+
   @Test
   public void testDateFunction() {
     RelNode relNode = TestUtils.toRelNode("SELECT date('2021-01-02') as a FROM foo");
diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java b/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java
index 845e4ba39..7f4ae0617 100644
--- a/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java
+++ b/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2018-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
@@ -223,7 +223,7 @@ public static void initializeViews(HiveConf conf) throws HiveException, MetaExce
     run(driver, String.join("\n", "", "ALTER TABLE schema_promotion CHANGE COLUMN b b array"));
     run(driver,
-        "CREATE TABLE IF NOT EXISTS union_table(foo uniontype, struct>)");
+        "CREATE TABLE IF NOT EXISTS union_table(foo uniontype, struct>, bar uniontype>, baz struct>>)");
     run(driver,
         "CREATE TABLE IF NOT EXISTS nested_union(a uniontype, b:int>>)");