diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/OrdinalReturnTypeInferenceV2.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/OrdinalReturnTypeInferenceV2.java index 55a5da4da..897dc83a1 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/OrdinalReturnTypeInferenceV2.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/OrdinalReturnTypeInferenceV2.java @@ -2,7 +2,10 @@ import org.apache.calcite.sql.type.OrdinalReturnTypeInference; - +/** + * Custom implementation of {@link OrdinalReturnTypeInference} which allows inferring the return type + * based on the ordinal of a given input argument and also exposes the ordinal. + */ public class OrdinalReturnTypeInferenceV2 extends OrdinalReturnTypeInference { private final int ordinal; diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java index c94b37d70..98c5fbe5d 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java @@ -665,8 +665,10 @@ public boolean isOptional(int i) { STRING_STRING_STRING); createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.HasMemberConsent", ReturnTypes.BOOLEAN, family(SqlTypeFamily.STRING, SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP)); - createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactFieldIf", new OrdinalReturnTypeInferenceV2(1), + createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactFieldIf", ARG1, family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.ANY)); + createAddUserDefinedFunction("cast_nullability", new OrdinalReturnTypeInferenceV2(1), + family(SqlTypeFamily.ANY, SqlTypeFamily.ANY)); createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactSecondarySchemaFieldIf", ARG1, family( SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ARRAY, SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY)); diff --git a/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelDataTypeToAvroType.java b/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelDataTypeToAvroType.java index 5b2f04406..53963c321 100644 --- a/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelDataTypeToAvroType.java +++ b/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelDataTypeToAvroType.java @@ -84,6 +84,8 @@ static Schema relDataTypeToAvroTypeNonNullable(@Nonnull RelDataType relDataType, private static Schema relDataTypeToAvroType(RelDataType relDataType, String recordName) { final Schema avroSchema = relDataTypeToAvroTypeNonNullable(relDataType, recordName); + // TODO: Current logic ALWAYS sets the inner fields of RelDataType record nullable. + // Modify this to be applied only when RelDataType record was generated from a HIVE_UDF RexCall return SchemaUtilities.makeNullable(avroSchema, false); } diff --git a/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelToAvroSchemaConverter.java b/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelToAvroSchemaConverter.java index 6c94ac56c..f3b791394 100644 --- a/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelToAvroSchemaConverter.java +++ b/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelToAvroSchemaConverter.java @@ -5,6 +5,7 @@ */ package com.linkedin.coral.schema.avro; +import com.linkedin.coral.hive.hive2rel.functions.OrdinalReturnTypeInferenceV2; import java.util.Deque; import java.util.HashMap; import java.util.LinkedList; @@ -407,14 +408,7 @@ public SchemaRexShuttle(Schema inputSchema, RelNode inputNode, Queue sug @Override public RexNode visitInputRef(RexInputRef rexInputRef) { RexNode rexNode = super.visitInputRef(rexInputRef); - - Schema.Field field = inputSchema.getFields().get(rexInputRef.getIndex()); - String oldFieldName = field.name(); - String suggestNewFieldName = suggestedFieldNames.poll(); - String newFieldName = SchemaUtilities.getFieldName(oldFieldName, suggestNewFieldName); - - SchemaUtilities.appendField(newFieldName, field, fieldAssembler); - + appendRexInputRefField(rexInputRef); return rexNode; } @@ -442,6 +436,23 @@ public RexNode visitCall(RexCall rexCall) { * For SqlUserDefinedFunction and SqlOperator RexCall, no need to handle it recursively * and only return type of udf or sql operator is relevant */ + + /** + * If the return type of RexCall is based on the ordinal of its input argument + * and the corresponding input argument refers to a field from the input schema, + * use the field's schema as is. + */ + if (rexCall.getOperator().getReturnTypeInference() instanceof OrdinalReturnTypeInferenceV2) { + int index = ((OrdinalReturnTypeInferenceV2) rexCall.getOperator().getReturnTypeInference()) + .getOrdinal(); + RexNode operand = rexCall.operands.get(index); + + if (operand instanceof RexInputRef) { + appendRexInputRefField((RexInputRef) operand); + return rexCall; + } + } + RelDataType fieldType = rexCall.getType(); boolean isNullable = SchemaUtilities.isFieldNullable(rexCall, inputSchema); @@ -545,6 +556,15 @@ public RexNode visitPatternFieldRef(RexPatternFieldRef rexPatternFieldRef) { return super.visitPatternFieldRef(rexPatternFieldRef); } + private void appendRexInputRefField(RexInputRef rexInputRef) { + Schema.Field field = inputSchema.getFields().get(rexInputRef.getIndex()); + String oldFieldName = field.name(); + String suggestNewFieldName = suggestedFieldNames.poll(); + String newFieldName = SchemaUtilities.getFieldName(oldFieldName, suggestNewFieldName); + + SchemaUtilities.appendField(newFieldName, field, fieldAssembler); + } + private void appendField(RelDataType fieldType, boolean isNullable, String doc) { String fieldName = SchemaUtilities.getFieldName("", suggestedFieldNames.poll()); SchemaUtilities.appendField(fieldName, fieldType, doc, fieldAssembler, isNullable); diff --git a/coral-schema/src/main/java/com/linkedin/coral/schema/avro/SchemaUtilities.java b/coral-schema/src/main/java/com/linkedin/coral/schema/avro/SchemaUtilities.java index b4f49c833..cde0a992d 100644 --- a/coral-schema/src/main/java/com/linkedin/coral/schema/avro/SchemaUtilities.java +++ b/coral-schema/src/main/java/com/linkedin/coral/schema/avro/SchemaUtilities.java @@ -5,7 +5,6 @@ */ package com.linkedin.coral.schema.avro; -import com.linkedin.coral.hive.hive2rel.functions.OrdinalReturnTypeInferenceV2; import java.io.PrintWriter; import java.io.StringWriter; import java.util.*; @@ -30,7 +29,6 @@ import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.OrdinalReturnTypeInference; import org.apache.calcite.sql.validate.SqlUserDefinedFunction; import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.hive.metastore.api.FieldSchema; @@ -239,19 +237,6 @@ static boolean isFieldNullable(@Nonnull RexCall rexCall, @Nonnull Schema inputSc return rexCall.getType().isNullable(); } - if (rexCall.getOperator().getReturnTypeInference() instanceof OrdinalReturnTypeInferenceV2) { - int index = ((OrdinalReturnTypeInferenceV2) rexCall.getOperator().getReturnTypeInference()) - .getOrdinal(); - RexNode operand = rexCall.operands.get(index); - - if (operand instanceof RexInputRef) { - Schema schema = inputSchema.getFields().get(((RexInputRef) operand).getIndex()).schema(); - return isNullableType(schema); - } else if (operand instanceof RexCall) { - return isFieldNullable((RexCall) operand, inputSchema); - } - } - // the field is non-nullable only if all operands are RexInputRef // and corresponding field schema type of RexInputRef index is not UNION List operands = rexCall.getOperands(); diff --git a/coral-schema/src/test/java/com/linkedin/coral/schema/avro/TestUtils.java b/coral-schema/src/test/java/com/linkedin/coral/schema/avro/TestUtils.java index 73340ec98..cb79495f4 100644 --- a/coral-schema/src/test/java/com/linkedin/coral/schema/avro/TestUtils.java +++ b/coral-schema/src/test/java/com/linkedin/coral/schema/avro/TestUtils.java @@ -99,6 +99,7 @@ private static void initializeTables() { String baseComplexUnionTypeSchema = loadSchema("base-complex-union-type.avsc"); String baseNestedUnionSchema = loadSchema("base-nested-union.avsc"); String baseComplexLowercase = loadSchema("base-complex-lowercase.avsc"); + String baseComplexNonNullable = loadSchema("base-complex-non-nullable.avsc"); String baseComplexNullableWithDefaults = loadSchema("base-complex-nullable-with-defaults.avsc"); String basePrimitive = loadSchema("base-primitive.avsc"); String baseComplexNestedStructSameName = loadSchema("base-complex-nested-struct-same-name.avsc"); @@ -121,6 +122,7 @@ private static void initializeTables() { executeCreateTableWithPartitionFieldSchemaQuery("default", "basecomplexfieldschema", baseComplexFieldSchema); executeCreateTableWithPartitionQuery("default", "basenestedcomplex", baseNestedComplexSchema); executeCreateTableWithPartitionQuery("default", "basecomplexnullablewithdefaults", baseComplexNullableWithDefaults); + executeCreateTableWithPartitionQuery("default", "basecomplexnonnullable", baseComplexNonNullable); String baseComplexSchemaWithDoc = loadSchema("docTestResources/base-complex-with-doc.avsc"); String baseEnumSchemaWithDoc = loadSchema("docTestResources/base-enum-with-doc.avsc"); diff --git a/coral-schema/src/test/java/com/linkedin/coral/schema/avro/ViewToAvroSchemaConverterTests.java b/coral-schema/src/test/java/com/linkedin/coral/schema/avro/ViewToAvroSchemaConverterTests.java index 69d28eef3..7b9d121dc 100644 --- a/coral-schema/src/test/java/com/linkedin/coral/schema/avro/ViewToAvroSchemaConverterTests.java +++ b/coral-schema/src/test/java/com/linkedin/coral/schema/avro/ViewToAvroSchemaConverterTests.java @@ -1102,5 +1102,16 @@ public void testDivideReturnType() { Assert.assertEquals(actualSchema.toString(true), TestUtils.loadSchema("testDivideReturnType-expected.avsc")); } + @Test + public void testCastNullabilityUDF() { + ViewToAvroSchemaConverter viewToAvroSchemaConverter = ViewToAvroSchemaConverter.create(hiveMetastoreClient); + + Schema schemaWithUDF = viewToAvroSchemaConverter.toAvroSchema("SELECT cast_nullability(Struct_Col, Struct_Col) AS modCol FROM basecomplexnonnullable"); + Schema schemaWithField = viewToAvroSchemaConverter.toAvroSchema("SELECT Struct_Col AS modCol FROM basecomplexnonnullable"); + + Assert.assertEquals(schemaWithUDF.toString(true), TestUtils.loadSchema("testCastNullabilityUDF-expected.avsc")); + Assert.assertEquals(schemaWithField.toString(true), TestUtils.loadSchema("testCastNullabilityUDF-expected.avsc")); + } + // TODO: add more unit tests } diff --git a/coral-schema/src/test/resources/testCastNullabilityUDF-expected.avsc b/coral-schema/src/test/resources/testCastNullabilityUDF-expected.avsc new file mode 100644 index 000000000..c5b5f18ca --- /dev/null +++ b/coral-schema/src/test/resources/testCastNullabilityUDF-expected.avsc @@ -0,0 +1,58 @@ +{ + "type" : "record", + "name" : "basecomplexnonnullable", + "namespace" : "coral.schema.avro.base.complex.nonnullable", + "fields" : [ { + "name" : "modCol", + "type" : { + "type" : "record", + "name" : "Struct_col", + "namespace" : "coral.schema.avro.base.complex.nonnullable.basecomplexnonnullable", + "fields" : [ { + "name" : "Bool_Field", + "type" : "boolean" + }, { + "name" : "Int_Field", + "type" : "int" + }, { + "name" : "Bigint_Field", + "type" : "long" + }, { + "name" : "Float_Field", + "type" : "float" + }, { + "name" : "Double_Field", + "type" : "double" + }, { + "name" : "Date_String_Field", + "type" : "string" + }, { + "name" : "String_Field", + "type" : "string" + }, { + "name" : "Array_Col", + "type" : { + "type" : "array", + "items" : { + "type" : "record", + "name" : "Struct_col", + "namespace" : "coral.schema.avro.base.complex.nonnullable.basecomplexnonnullable.basecomplexnonnullable", + "fields" : [ { + "name" : "key", + "type" : "string" + }, { + "name" : "value", + "type" : "string" + } ] + } + } + }, { + "name" : "Map_Col", + "type" : { + "type" : "map", + "values" : "string" + } + } ] + } + } ] +} \ No newline at end of file