Skip to content

Commit

Permalink
Add cast_nullability as a built-in function
Browse files Browse the repository at this point in the history
  • Loading branch information
aastha25 committed Dec 13, 2023
1 parent 6536236 commit 3a7e7af
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import org.apache.calcite.sql.type.OrdinalReturnTypeInference;


/**
* Custom implementation of {@link OrdinalReturnTypeInference} which allows inferring the return type
* based on the ordinal of a given input argument and also exposes the ordinal.
*/
public class OrdinalReturnTypeInferenceV2 extends OrdinalReturnTypeInference {
private final int ordinal;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -665,8 +665,10 @@ public boolean isOptional(int i) {
STRING_STRING_STRING);
createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.HasMemberConsent", ReturnTypes.BOOLEAN,
family(SqlTypeFamily.STRING, SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP));
createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactFieldIf", new OrdinalReturnTypeInferenceV2(1),
createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactFieldIf", ARG1,
family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.ANY));
createAddUserDefinedFunction("cast_nullability", new OrdinalReturnTypeInferenceV2(1),
family(SqlTypeFamily.ANY, SqlTypeFamily.ANY));

createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactSecondarySchemaFieldIf", ARG1, family(
SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ARRAY, SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ static Schema relDataTypeToAvroTypeNonNullable(@Nonnull RelDataType relDataType,

/**
 * Converts the given {@code relDataType} into an Avro {@link Schema} named {@code recordName},
 * then wraps the result via {@code SchemaUtilities.makeNullable(avroSchema, false)}.
 *
 * NOTE(review): per the TODO below, nested fields of a record type are always made nullable here,
 * regardless of how the RelDataType was produced — confirm whether this should be restricted
 * to HIVE_UDF-generated RexCall types.
 */
private static Schema relDataTypeToAvroType(RelDataType relDataType, String recordName) {
final Schema avroSchema = relDataTypeToAvroTypeNonNullable(relDataType, recordName);
// TODO: Current logic ALWAYS sets the inner fields of RelDataType record nullable.
// Modify this to be applied only when RelDataType record was generated from a HIVE_UDF RexCall
return SchemaUtilities.makeNullable(avroSchema, false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package com.linkedin.coral.schema.avro;

import com.linkedin.coral.hive.hive2rel.functions.OrdinalReturnTypeInferenceV2;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
Expand Down Expand Up @@ -407,14 +408,7 @@ public SchemaRexShuttle(Schema inputSchema, RelNode inputNode, Queue<String> sug
/**
 * Visits an input reference and appends the referenced input-schema field to the
 * schema being assembled, delegating name resolution and the actual append to
 * {@link #appendRexInputRefField(RexInputRef)}.
 *
 * Fix: the body previously contained BOTH the inlined field-append logic and the
 * call to {@code appendRexInputRefField}, which would append the same field twice
 * (and poll {@code suggestedFieldNames} twice). Only the helper call is kept.
 *
 * @param rexInputRef reference to a field of the input relation
 * @return the (unchanged) RexNode produced by the superclass visit
 */
@Override
public RexNode visitInputRef(RexInputRef rexInputRef) {
  RexNode rexNode = super.visitInputRef(rexInputRef);
  appendRexInputRefField(rexInputRef);
  return rexNode;
}

Expand Down Expand Up @@ -442,6 +436,23 @@ public RexNode visitCall(RexCall rexCall) {
* For SqlUserDefinedFunction and SqlOperator RexCall, no need to handle it recursively
* and only return type of udf or sql operator is relevant
*/

/**
* If the return type of RexCall is based on the ordinal of its input argument
* and the corresponding input argument refers to a field from the input schema,
* use the field's schema as is.
*/
if (rexCall.getOperator().getReturnTypeInference() instanceof OrdinalReturnTypeInferenceV2) {
int index = ((OrdinalReturnTypeInferenceV2) rexCall.getOperator().getReturnTypeInference())
.getOrdinal();
RexNode operand = rexCall.operands.get(index);

if (operand instanceof RexInputRef) {
appendRexInputRefField((RexInputRef) operand);
return rexCall;
}
}

RelDataType fieldType = rexCall.getType();
boolean isNullable = SchemaUtilities.isFieldNullable(rexCall, inputSchema);

Expand Down Expand Up @@ -545,6 +556,15 @@ public RexNode visitPatternFieldRef(RexPatternFieldRef rexPatternFieldRef) {
return super.visitPatternFieldRef(rexPatternFieldRef);
}

/**
 * Appends the input-schema field referenced by {@code rexInputRef} to the field
 * assembler, naming it via {@code SchemaUtilities.getFieldName} — which prefers the
 * next queued suggested name and otherwise keeps the field's original name.
 *
 * @param rexInputRef reference identifying which input-schema field to append
 */
private void appendRexInputRefField(RexInputRef rexInputRef) {
  final Schema.Field referencedField = inputSchema.getFields().get(rexInputRef.getIndex());
  // Argument order matters: field.name() is read before the next suggestion is polled,
  // exactly as in SchemaUtilities.getFieldName(oldName, suggestedName).
  final String resolvedName =
      SchemaUtilities.getFieldName(referencedField.name(), suggestedFieldNames.poll());
  SchemaUtilities.appendField(resolvedName, referencedField, fieldAssembler);
}

private void appendField(RelDataType fieldType, boolean isNullable, String doc) {
String fieldName = SchemaUtilities.getFieldName("", suggestedFieldNames.poll());
SchemaUtilities.appendField(fieldName, fieldType, doc, fieldAssembler, isNullable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/
package com.linkedin.coral.schema.avro;

import com.linkedin.coral.hive.hive2rel.functions.OrdinalReturnTypeInferenceV2;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.*;
Expand All @@ -30,7 +29,6 @@
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OrdinalReturnTypeInference;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
Expand Down Expand Up @@ -239,19 +237,6 @@ static boolean isFieldNullable(@Nonnull RexCall rexCall, @Nonnull Schema inputSc
return rexCall.getType().isNullable();
}

if (rexCall.getOperator().getReturnTypeInference() instanceof OrdinalReturnTypeInferenceV2) {
int index = ((OrdinalReturnTypeInferenceV2) rexCall.getOperator().getReturnTypeInference())
.getOrdinal();
RexNode operand = rexCall.operands.get(index);

if (operand instanceof RexInputRef) {
Schema schema = inputSchema.getFields().get(((RexInputRef) operand).getIndex()).schema();
return isNullableType(schema);
} else if (operand instanceof RexCall) {
return isFieldNullable((RexCall) operand, inputSchema);
}
}

// the field is non-nullable only if all operands are RexInputRef
// and corresponding field schema type of RexInputRef index is not UNION
List<RexNode> operands = rexCall.getOperands();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ private static void initializeTables() {
String baseComplexUnionTypeSchema = loadSchema("base-complex-union-type.avsc");
String baseNestedUnionSchema = loadSchema("base-nested-union.avsc");
String baseComplexLowercase = loadSchema("base-complex-lowercase.avsc");
String baseComplexNonNullable = loadSchema("base-complex-non-nullable.avsc");
String baseComplexNullableWithDefaults = loadSchema("base-complex-nullable-with-defaults.avsc");
String basePrimitive = loadSchema("base-primitive.avsc");
String baseComplexNestedStructSameName = loadSchema("base-complex-nested-struct-same-name.avsc");
Expand All @@ -121,6 +122,7 @@ private static void initializeTables() {
executeCreateTableWithPartitionFieldSchemaQuery("default", "basecomplexfieldschema", baseComplexFieldSchema);
executeCreateTableWithPartitionQuery("default", "basenestedcomplex", baseNestedComplexSchema);
executeCreateTableWithPartitionQuery("default", "basecomplexnullablewithdefaults", baseComplexNullableWithDefaults);
executeCreateTableWithPartitionQuery("default", "basecomplexnonnullable", baseComplexNonNullable);

String baseComplexSchemaWithDoc = loadSchema("docTestResources/base-complex-with-doc.avsc");
String baseEnumSchemaWithDoc = loadSchema("docTestResources/base-enum-with-doc.avsc");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1102,5 +1102,16 @@ public void testDivideReturnType() {
Assert.assertEquals(actualSchema.toString(true), TestUtils.loadSchema("testDivideReturnType-expected.avsc"));
}

/**
 * Verifies that the built-in cast_nullability UDF yields the same Avro schema as
 * projecting the underlying field directly: both queries over
 * basecomplexnonnullable must match the expected schema fixture.
 */
@Test
public void testCastNullabilityUDF() {
  ViewToAvroSchemaConverter converter = ViewToAvroSchemaConverter.create(hiveMetastoreClient);

  Schema udfSchema = converter.toAvroSchema("SELECT cast_nullability(Struct_Col, Struct_Col) AS modCol FROM basecomplexnonnullable");
  Schema directFieldSchema = converter.toAvroSchema("SELECT Struct_Col AS modCol FROM basecomplexnonnullable");

  Assert.assertEquals(udfSchema.toString(true), TestUtils.loadSchema("testCastNullabilityUDF-expected.avsc"));
  Assert.assertEquals(directFieldSchema.toString(true), TestUtils.loadSchema("testCastNullabilityUDF-expected.avsc"));
}

// TODO: add more unit tests
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"type" : "record",
"name" : "basecomplexnonnullable",
"namespace" : "coral.schema.avro.base.complex.nonnullable",
"fields" : [ {
"name" : "modCol",
"type" : {
"type" : "record",
"name" : "Struct_col",
"namespace" : "coral.schema.avro.base.complex.nonnullable.basecomplexnonnullable",
"fields" : [ {
"name" : "Bool_Field",
"type" : "boolean"
}, {
"name" : "Int_Field",
"type" : "int"
}, {
"name" : "Bigint_Field",
"type" : "long"
}, {
"name" : "Float_Field",
"type" : "float"
}, {
"name" : "Double_Field",
"type" : "double"
}, {
"name" : "Date_String_Field",
"type" : "string"
}, {
"name" : "String_Field",
"type" : "string"
}, {
"name" : "Array_Col",
"type" : {
"type" : "array",
"items" : {
"type" : "record",
"name" : "Struct_col",
"namespace" : "coral.schema.avro.base.complex.nonnullable.basecomplexnonnullable.basecomplexnonnullable",
"fields" : [ {
"name" : "key",
"type" : "string"
}, {
"name" : "value",
"type" : "string"
} ]
}
}
}, {
"name" : "Map_Col",
"type" : {
"type" : "map",
"values" : "string"
}
} ]
}
} ]
}

0 comments on commit 3a7e7af

Please sign in to comment.