Skip to content

Commit

Permalink
Correctly handle single type uniontypes in Coral (#507)
Browse files Browse the repository at this point in the history
* fix single uniontypes in Coral

* remove SingleUnionFieldReferenceTransformer

* remove field reference fix to put in separate PR

* spotless

* update comments

* fix comment + add single uniontype test for RelDataTypeToHiveTypeStringConverter

* spotless

* improve Javadoc for ExtractUnionFunctionTransformer

* use html brackets in javadoc
  • Loading branch information
KevinGe00 authored Jul 31, 2024
1 parent d1d5b1e commit 74c2ca8
Show file tree
Hide file tree
Showing 12 changed files with 224 additions and 86 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2023 LinkedIn Corporation. All rights reserved.
* Copyright 2017-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -147,9 +147,6 @@ public static RelDataType convert(StructTypeInfo structType, final RelDataTypeFa
public static RelDataType convert(UnionTypeInfo unionType, RelDataTypeFactory dtFactory) {
List<RelDataType> fTypes = unionType.getAllUnionObjectTypeInfos().stream()
.map(typeInfo -> convert(typeInfo, dtFactory)).collect(Collectors.toList());
if (fTypes.size() == 1) {
return dtFactory.createTypeWithNullability(fTypes.get(0), true);
}
List<String> fNames = IntStream.range(0, unionType.getAllUnionObjectTypeInfos().size()).mapToObj(i -> "field" + i)
.collect(Collectors.toList());
fTypes.add(0, dtFactory.createSqlType(SqlTypeName.INTEGER));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -74,15 +74,6 @@ public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, S
if (funcType.isStruct()) {
return funcType.getField(fieldNameStripQuotes(call.operand(1)), false, false).getType();
}

// When the first operand is a SqlBasicCall with a non-struct RelDataType and the second operand is `tag_0`,
// such as `extract_union`(`product`.`value`).`tag_0` or (`extract_union`(`product`.`value`).`id`).`tag_0`,
// derived data type is first operand's RelDataType.
// This strategy ensures that RelDataType derivation remains successful for the specified sqlCalls while maintaining backward compatibility.
// Such SqlCalls are transformed {@link com.linkedin.coral.transformers.SingleUnionFieldReferenceTransformer}
if (FunctionFieldReferenceOperator.fieldNameStripQuotes(call.operand(1)).equalsIgnoreCase("tag_0")) {
return funcType;
}
}
return super.deriveType(validator, scope, call);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2022-2023 LinkedIn Corporation. All rights reserved.
* Copyright 2022-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -41,6 +41,25 @@ public class RelDataTypeToHiveTypeStringConverter {
private RelDataTypeToHiveTypeStringConverter() {
}

/**
 * Creates a converter that can optionally restore Hive's native single-uniontype representation.
 *
 * @param convertUnionTypes if true, structs matching Coral's single-uniontype encoding
 *        (a two-field struct named "tag" and "field0") are rendered as {@code uniontype<...>}
 *        Hive type strings instead of struct type strings
 *
 * NOTE(review): this instance constructor assigns a static field, so the flag is shared by
 * all instances and by static call sites of this class — confirm that is intended.
 */
public RelDataTypeToHiveTypeStringConverter(boolean convertUnionTypes) {
this.convertUnionTypes = convertUnionTypes;
}

/**
* If true, Coral will convert single uniontypes back to Hive's native uniontype representation. This is necessary
* because some engines have readers that unwrap Hive single uniontypes to just the underlying data type, causing
* the loss of information that the column was originally a uniontype in Hive. This can be problematic when calling
* the `coalesce_struct` UDF on such columns, as they are expected to be treated as uniontypes. Retaining the
* original uniontype record and passing it into `coalesce_struct` ensures correct handling.
*
* Example:
* RelDataType:
* struct(tag:integer,field0:varchar)
* Hive Type String:
* uniontype&lt;string&gt;
*/
private static boolean convertUnionTypes = false;

/**
* @param relDataType a given RelDataType
* @return a syntactically and semantically correct Hive type string for relDataType
Expand Down Expand Up @@ -110,6 +129,14 @@ public static String convertRelDataType(RelDataType relDataType) {
*/
private static String buildStructDataTypeString(RelRecordType relRecordType) {
List<String> structFieldStrings = new ArrayList<>();

// Convert single uniontypes as structs back to native Hive representation
if (convertUnionTypes && relRecordType.getFieldList().size() == 2
&& relRecordType.getFieldList().get(0).getName().equals("tag")
&& relRecordType.getFieldList().get(1).getName().equals("field0")) {
return String.format("uniontype<%s>", convertRelDataType(relRecordType.getFieldList().get(1).getType()));
}

for (RelDataTypeField fieldRelDataType : relRecordType.getFieldList()) {
structFieldStrings
.add(String.format("%s:%s", fieldRelDataType.getName(), convertRelDataType(fieldRelDataType.getType())));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2019-2023 LinkedIn Corporation. All rights reserved.
* Copyright 2019-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -178,4 +178,18 @@ public void testCharRelDataType() {

assertEquals(hiveDataTypeSchemaString, expectedHiveDataTypeSchemaString);
}

/**
 * Verifies that a struct shaped like Coral's single-uniontype encoding
 * (fields "tag" and "field0") is converted back to Hive's native
 * {@code uniontype<...>} string when convertUnionTypes is enabled.
 */
@Test
public void testSingleUniontypeStructRelDataType() {
  String expectedHiveDataTypeSchemaString = "uniontype<string>";

  List<RelDataTypeField> fields = new ArrayList<>();
  // A RelDataTypeFieldImpl's index must match the field's ordinal position in the
  // record: "tag" is field 0 and "field0" is field 1 (the original used 0 for both).
  fields.add(new RelDataTypeFieldImpl("tag", 0, new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER)));
  fields.add(new RelDataTypeFieldImpl("field0", 1, new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.VARCHAR)));

  RelRecordType relRecordType = new RelRecordType(fields);
  String hiveDataTypeSchemaString = new RelDataTypeToHiveTypeStringConverter(true).convertRelDataType(relRecordType);

  assertEquals(hiveDataTypeSchemaString, expectedHiveDataTypeSchemaString);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2023 LinkedIn Corporation. All rights reserved.
* Copyright 2017-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand All @@ -13,7 +13,6 @@
import com.linkedin.coral.common.transformers.SqlCallTransformers;
import com.linkedin.coral.common.utils.TypeDerivationUtil;
import com.linkedin.coral.transformers.ShiftArrayIndexTransformer;
import com.linkedin.coral.transformers.SingleUnionFieldReferenceTransformer;


/**
Expand All @@ -24,8 +23,7 @@ public class HiveSqlNodeToCoralSqlNodeConverter extends SqlShuttle {

public HiveSqlNodeToCoralSqlNodeConverter(SqlValidator sqlValidator, SqlNode topSqlNode) {
TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(sqlValidator, topSqlNode);
operatorTransformerList = SqlCallTransformers.of(new ShiftArrayIndexTransformer(typeDerivationUtil),
new SingleUnionFieldReferenceTransformer(typeDerivationUtil));
operatorTransformerList = SqlCallTransformers.of(new ShiftArrayIndexTransformer(typeDerivationUtil));
}

@Override
Expand Down

This file was deleted.

29 changes: 22 additions & 7 deletions coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand All @@ -13,6 +13,7 @@
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlSelect;
Expand Down Expand Up @@ -76,7 +77,7 @@ public static CoralSpark create(RelNode irRelNode, HiveMetastoreClient hmsClient
SparkRelInfo sparkRelInfo = IRRelToSparkRelTransformer.transform(irRelNode);
Set<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfos();
RelNode sparkRelNode = sparkRelInfo.getSparkRelNode();
SqlNode sparkSqlNode = constructSparkSqlNode(sparkRelNode, sparkUDFInfos);
SqlNode sparkSqlNode = constructSparkSqlNode(sparkRelNode, sparkUDFInfos, hmsClient);
String sparkSQL = constructSparkSQL(sparkSqlNode);
List<String> baseTables = constructBaseTables(sparkRelNode);
return new CoralSpark(baseTables, ImmutableList.copyOf(sparkUDFInfos), sparkSQL, hmsClient, sparkSqlNode);
Expand All @@ -101,30 +102,44 @@ private static CoralSpark createWithAlias(RelNode irRelNode, List<String> aliase
SparkRelInfo sparkRelInfo = IRRelToSparkRelTransformer.transform(irRelNode);
Set<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfos();
RelNode sparkRelNode = sparkRelInfo.getSparkRelNode();
SqlNode sparkSqlNode = constructSparkSqlNode(sparkRelNode, sparkUDFInfos);
SqlNode sparkSqlNode = constructSparkSqlNode(sparkRelNode, sparkUDFInfos, hmsClient);
// Use a second pass visit to add explicit alias names,
// only do this when it's not a select star case,
// since for select star we don't need to add any explicit aliases
if (sparkSqlNode.getKind() == SqlKind.SELECT && ((SqlSelect) sparkSqlNode).getSelectList() != null) {
if (sparkSqlNode.getKind() == SqlKind.SELECT && !isSelectStar(sparkSqlNode)) {
sparkSqlNode = sparkSqlNode.accept(new AddExplicitAlias(aliases));
}
String sparkSQL = constructSparkSQL(sparkSqlNode);
List<String> baseTables = constructBaseTables(sparkRelNode);
return new CoralSpark(baseTables, ImmutableList.copyOf(sparkUDFInfos), sparkSQL, hmsClient, sparkSqlNode);
}

private static SqlNode constructSparkSqlNode(RelNode sparkRelNode, Set<SparkUDFInfo> sparkUDFInfos) {
private static SqlNode constructSparkSqlNode(RelNode sparkRelNode, Set<SparkUDFInfo> sparkUDFInfos,
HiveMetastoreClient hmsClient) {
CoralRelToSqlNodeConverter rel2sql = new CoralRelToSqlNodeConverter();
SqlNode coralSqlNode = rel2sql.convert(sparkRelNode);
SqlNode sparkSqlNode = coralSqlNode.accept(new CoralSqlNodeToSparkSqlNodeConverter())
.accept(new CoralToSparkSqlCallConverter(sparkUDFInfos));

SqlNode coralSqlNodeWithRelDataTypeDerivedConversions =
coralSqlNode.accept(new DataTypeDerivedSqlCallConverter(hmsClient, coralSqlNode, sparkUDFInfos));

SqlNode sparkSqlNode = coralSqlNodeWithRelDataTypeDerivedConversions
.accept(new CoralSqlNodeToSparkSqlNodeConverter()).accept(new CoralToSparkSqlCallConverter(sparkUDFInfos));
return sparkSqlNode.accept(new SparkSqlRewriter());
}

/**
 * Renders the given Spark-compatible SqlNode as a SQL string using the Spark dialect.
 *
 * @param sparkSqlNode a SqlNode already rewritten for the Spark engine
 * @return the Spark SQL string for the given node
 */
public static String constructSparkSQL(SqlNode sparkSqlNode) {
return sparkSqlNode.toSqlString(SparkSqlDialect.INSTANCE).getSql();
}

/**
 * Checks whether the given SqlNode is a plain {@code SELECT *}, i.e. a SELECT whose
 * only select item is a star identifier.
 *
 * @param node SqlNode to inspect
 * @return true if the node represents {@code SELECT *}
 */
private static boolean isSelectStar(SqlNode node) {
  if (node.getKind() != SqlKind.SELECT) {
    return false;
  }
  SqlSelect select = (SqlSelect) node;
  // A null select list also denotes SELECT * (the pre-existing caller check treated
  // a null select list as the select-star case); guard it to avoid an NPE.
  if (select.getSelectList() == null) {
    return true;
  }
  if (select.getSelectList().size() != 1) {
    return false;
  }
  SqlNode onlyItem = select.getSelectList().get(0);
  return onlyItem instanceof SqlIdentifier && ((SqlIdentifier) onlyItem).isStar();
}

/**
* This function returns the list of base table names, in the format
* "database_name.table_name".
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand All @@ -15,7 +15,6 @@
import com.linkedin.coral.common.transformers.OperatorRenameSqlCallTransformer;
import com.linkedin.coral.common.transformers.SqlCallTransformers;
import com.linkedin.coral.spark.containers.SparkUDFInfo;
import com.linkedin.coral.spark.transformers.ExtractUnionFunctionTransformer;
import com.linkedin.coral.spark.transformers.FallBackToLinkedInHiveUDFTransformer;
import com.linkedin.coral.spark.transformers.FuzzyUnionGenericProjectTransformer;
import com.linkedin.coral.spark.transformers.TransportUDFTransformer;
Expand Down Expand Up @@ -157,9 +156,6 @@ public CoralToSparkSqlCallConverter(Set<SparkUDFInfo> sparkUDFInfos) {
// Fall back to the original Hive UDF defined in StaticHiveFunctionRegistry after failing to apply transformers above
new FallBackToLinkedInHiveUDFTransformer(sparkUDFInfos),

// Transform `extract_union` to `coalesce_struct`
new ExtractUnionFunctionTransformer(sparkUDFInfos),

// Transform `generic_project` function
new FuzzyUnionGenericProjectTransformer(sparkUDFInfos));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/**
* Copyright 2022-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.spark;

import java.util.Set;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.util.SqlShuttle;

import com.linkedin.coral.common.HiveMetastoreClient;
import com.linkedin.coral.common.transformers.SqlCallTransformers;
import com.linkedin.coral.common.utils.TypeDerivationUtil;
import com.linkedin.coral.hive.hive2rel.HiveToRelConverter;
import com.linkedin.coral.spark.containers.SparkUDFInfo;
import com.linkedin.coral.spark.transformers.ExtractUnionFunctionTransformer;


/**
* DataTypeDerivedSqlCallConverter transforms the sqlCalls
* in the input SqlNode representation to be compatible with Spark engine.
* The transformation may involve change in operator, reordering the operands
* or even re-constructing the SqlNode.
*
* All the transformations performed as part of this shuttle require RelDataType derivation.
*/
public class DataTypeDerivedSqlCallConverter extends SqlShuttle {
  private final SqlCallTransformers operatorTransformerList;
  private final HiveToRelConverter toRelConverter;
  TypeDerivationUtil typeDerivationUtil;

  /**
   * Builds the converter for Spark.
   *
   * @param mscClient Hive metastore client backing the validator used for type derivation
   * @param topSqlNode the top-level SqlNode of the query being converted
   * @param sparkUDFInfos mutable set collecting the Spark UDFs required by the transformed query
   */
  public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode topSqlNode,
      Set<SparkUDFInfo> sparkUDFInfos) {
    this.toRelConverter = new HiveToRelConverter(mscClient);
    this.typeDerivationUtil = new TypeDerivationUtil(toRelConverter.getSqlValidator(), topSqlNode);
    // Every transformer registered here needs typeDerivationUtil to derive RelDataTypes.
    this.operatorTransformerList =
        SqlCallTransformers.of(new ExtractUnionFunctionTransformer(typeDerivationUtil, sparkUDFInfos));
  }

  @Override
  public SqlNode visit(SqlCall call) {
    // Rewrite the call's children first, then apply the registered transformers to the result.
    SqlCall visitedCall = (SqlCall) super.visit(call);
    return operatorTransformerList.apply(visitedCall);
  }
}
Loading

0 comments on commit 74c2ca8

Please sign in to comment.