Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce OrdinalReturnTypeInferenceV2 to infer RexCall's return type #481

Merged
merged 6 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.hive.hive2rel.functions;

import org.apache.calcite.sql.type.OrdinalReturnTypeInference;


/**
 * Variant of {@link OrdinalReturnTypeInference} that, in addition to inferring the return type
 * from the operand at a given ordinal, makes that ordinal readable by callers.
 *
 * <p>The ordinal is forwarded to the parent constructor and a copy is retained here so that it
 * can be returned from {@link #getOrdinal()}.
 */
public class OrdinalReturnTypeInferenceV2 extends OrdinalReturnTypeInference {
  /** Position of the operand whose type is used as the return type. */
  private final int operandOrdinal;

  /**
   * Creates a return type inference bound to the operand at the given position.
   *
   * @param ordinal position of the operand whose type becomes the return type
   */
  public OrdinalReturnTypeInferenceV2(int ordinal) {
    super(ordinal);
    operandOrdinal = ordinal;
  }

  /**
   * Returns the ordinal that was supplied at construction time.
   *
   * @return position of the operand that drives the return type
   */
  public int getOrdinal() {
    return this.operandOrdinal;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,8 @@ public boolean isOptional(int i) {
family(SqlTypeFamily.STRING, SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP));
createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactFieldIf", ARG1,
family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.ANY));
createAddUserDefinedFunction("li_groot_cast_nullability", new OrdinalReturnTypeInferenceV2(1),
ljfgem marked this conversation as resolved.
Show resolved Hide resolved
family(SqlTypeFamily.ANY, SqlTypeFamily.ANY));
ljfgem marked this conversation as resolved.
Show resolved Hide resolved

createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactSecondarySchemaFieldIf", ARG1, family(
SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ARRAY, SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ static Schema relDataTypeToAvroTypeNonNullable(@Nonnull RelDataType relDataType,

/**
 * Converts the given {@link RelDataType} to an Avro {@link Schema} by first building the schema
 * through {@code relDataTypeToAvroTypeNonNullable} and then passing the result through
 * {@code SchemaUtilities.makeNullable}.
 *
 * @param relDataType type to convert
 * @param recordName record name forwarded to {@code relDataTypeToAvroTypeNonNullable}
 * @return the schema returned by {@code SchemaUtilities.makeNullable}
 */
private static Schema relDataTypeToAvroType(RelDataType relDataType, String recordName) {
  // TODO: Current logic ALWAYS sets the inner fields of RelDataType record nullable.
  //  Modify this to be applied only when RelDataType record was generated from a HIVE_UDF RexCall
  return SchemaUtilities.makeNullable(relDataTypeToAvroTypeNonNullable(relDataType, recordName), false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import com.linkedin.coral.com.google.common.base.Preconditions;
import com.linkedin.coral.common.HiveMetastoreClient;
import com.linkedin.coral.common.HiveUncollect;
import com.linkedin.coral.hive.hive2rel.functions.OrdinalReturnTypeInferenceV2;


/**
Expand Down Expand Up @@ -407,14 +408,7 @@ public SchemaRexShuttle(Schema inputSchema, RelNode inputNode, Queue<String> sug
/**
 * Visits an input reference and appends the referenced input-schema field to the schema
 * being assembled.
 *
 * <p>The field lookup, name resolution and append are delegated to
 * {@code appendRexInputRefField}. They must run exactly once per reference: doing the same work
 * inline as well would poll {@code suggestedFieldNames} twice and append the field twice.
 *
 * @param rexInputRef reference to a field of the input
 * @return the node returned by the parent shuttle's {@code visitInputRef}
 */
@Override
public RexNode visitInputRef(RexInputRef rexInputRef) {
  RexNode rexNode = super.visitInputRef(rexInputRef);
  appendRexInputRefField(rexInputRef);
  return rexNode;
}

Expand Down Expand Up @@ -442,6 +436,22 @@ public RexNode visitCall(RexCall rexCall) {
* For SqlUserDefinedFunction and SqlOperator RexCall, no need to handle it recursively
* and only return type of udf or sql operator is relevant
*/

/**
* If the return type of RexCall is based on the ordinal of its input argument
* and the corresponding input argument refers to a field from the input schema,
* use the field's schema as is.
*/
if (rexCall.getOperator().getReturnTypeInference() instanceof OrdinalReturnTypeInferenceV2) {
int index = ((OrdinalReturnTypeInferenceV2) rexCall.getOperator().getReturnTypeInference()).getOrdinal();
RexNode operand = rexCall.operands.get(index);

if (operand instanceof RexInputRef) {
appendRexInputRefField((RexInputRef) operand);
return rexCall;
}
}

RelDataType fieldType = rexCall.getType();
boolean isNullable = SchemaUtilities.isFieldNullable(rexCall, inputSchema);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the reason OrdinalReturnTypeInference does not work out of the box?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[OrdinalReturnTypeInference](https://github.com/linkedin/linkedin-calcite/blob/li-1.21.0/core/src/main/java/org/apache/calcite/sql/type/OrdinalReturnTypeInference.java#L25) cannot be used directly because its private field `ordinal` has no public accessor method. The class supports type derivation via the method `RelDataType inferReturnType(SqlOperatorBinding opBinding)`. To leverage this API, we need to make coral-schema compliant with Coral IR and enable type derivation in coral-schema, similar to our ongoing work in coral-spark and coral-trino.

Copy link
Contributor

@wmoustafa wmoustafa Dec 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So once we append a RexCall field, what gets used at the end to infer its type if not inferReturnType?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We append the field AFTER extracting the ordinal here and deriving its field type here.

Since this is a rexShuttle, the field types from the inputSchema are available in the RelNode representation.

SchemaUtilities.isFieldNullable(rexCall, inputSchema) derives incorrect nullability (always nullable) for the top level fields and SchemaUtilities.makeNullable here derives incorrect nullability (always nullable) for the inner fields.

By introducing the `if` condition above (line 445), we avoid those code paths.


Expand Down Expand Up @@ -545,6 +555,15 @@ public RexNode visitPatternFieldRef(RexPatternFieldRef rexPatternFieldRef) {
return super.visitPatternFieldRef(rexPatternFieldRef);
}

/**
 * Appends the input-schema field referenced by {@code rexInputRef} to {@code fieldAssembler}.
 *
 * <p>The field is looked up in {@code inputSchema} by the reference's index. Its name is
 * resolved through {@code SchemaUtilities.getFieldName} from the field's current name and the
 * next entry polled from {@code suggestedFieldNames}.
 *
 * @param rexInputRef reference whose index selects the field in the input schema
 */
private void appendRexInputRefField(RexInputRef rexInputRef) {
  final Schema.Field sourceField = inputSchema.getFields().get(rexInputRef.getIndex());
  // The field's own name is read before the suggestion queue is polled.
  final String resolvedName = SchemaUtilities.getFieldName(sourceField.name(), suggestedFieldNames.poll());
  SchemaUtilities.appendField(resolvedName, sourceField, fieldAssembler);
}

private void appendField(RelDataType fieldType, boolean isNullable, String doc) {
String fieldName = SchemaUtilities.getFieldName("", suggestedFieldNames.poll());
SchemaUtilities.appendField(fieldName, fieldType, doc, fieldAssembler, isNullable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ private static void initializeTables() {
String baseComplexUnionTypeSchema = loadSchema("base-complex-union-type.avsc");
String baseNestedUnionSchema = loadSchema("base-nested-union.avsc");
String baseComplexLowercase = loadSchema("base-complex-lowercase.avsc");
String baseComplexNonNullable = loadSchema("base-complex-non-nullable.avsc");
String baseComplexNullableWithDefaults = loadSchema("base-complex-nullable-with-defaults.avsc");
String basePrimitive = loadSchema("base-primitive.avsc");
String baseComplexNestedStructSameName = loadSchema("base-complex-nested-struct-same-name.avsc");
Expand All @@ -121,6 +122,7 @@ private static void initializeTables() {
executeCreateTableWithPartitionFieldSchemaQuery("default", "basecomplexfieldschema", baseComplexFieldSchema);
executeCreateTableWithPartitionQuery("default", "basenestedcomplex", baseNestedComplexSchema);
executeCreateTableWithPartitionQuery("default", "basecomplexnullablewithdefaults", baseComplexNullableWithDefaults);
executeCreateTableWithPartitionQuery("default", "basecomplexnonnullable", baseComplexNonNullable);

String baseComplexSchemaWithDoc = loadSchema("docTestResources/base-complex-with-doc.avsc");
String baseEnumSchemaWithDoc = loadSchema("docTestResources/base-enum-with-doc.avsc");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1102,5 +1102,19 @@ public void testDivideReturnType() {
Assert.assertEquals(actualSchema.toString(true), TestUtils.loadSchema("testDivideReturnType-expected.avsc"));
}

/**
 * Checks that selecting a struct column through {@code li_groot_cast_nullability} yields the
 * same Avro schema as selecting the column directly. Both results are compared against
 * {@code testLiGrootCastNullability-expected.avsc}.
 */
@Test
public void testLiGrootCastNullability() {
  ViewToAvroSchemaConverter converter = ViewToAvroSchemaConverter.create(hiveMetastoreClient);

  Schema viaUdf = converter
      .toAvroSchema("SELECT li_groot_cast_nullability(Struct_Col, Struct_Col) AS modCol FROM basecomplexnonnullable");
  Schema viaField = converter.toAvroSchema("SELECT Struct_Col AS modCol FROM basecomplexnonnullable");

  // Both queries are expected to produce the very same schema document.
  String expected = TestUtils.loadSchema("testLiGrootCastNullability-expected.avsc");
  Assert.assertEquals(viaUdf.toString(true), expected);
  Assert.assertEquals(viaField.toString(true), expected);
}

// TODO: add more unit tests
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"type" : "record",
"name" : "basecomplexnonnullable",
"namespace" : "coral.schema.avro.base.complex.nonnullable",
"fields" : [ {
"name" : "modCol",
"type" : {
"type" : "record",
"name" : "Struct_col",
"namespace" : "coral.schema.avro.base.complex.nonnullable.basecomplexnonnullable",
"fields" : [ {
"name" : "Bool_Field",
"type" : "boolean"
}, {
"name" : "Int_Field",
"type" : "int"
}, {
"name" : "Bigint_Field",
"type" : "long"
}, {
"name" : "Float_Field",
"type" : "float"
}, {
"name" : "Double_Field",
"type" : "double"
}, {
"name" : "Date_String_Field",
"type" : "string"
}, {
"name" : "String_Field",
"type" : "string"
}, {
"name" : "Array_Col",
"type" : {
"type" : "array",
"items" : {
"type" : "record",
"name" : "Struct_col",
"namespace" : "coral.schema.avro.base.complex.nonnullable.basecomplexnonnullable.basecomplexnonnullable",
"fields" : [ {
"name" : "key",
"type" : "string"
}, {
"name" : "value",
"type" : "string"
} ]
}
}
}, {
"name" : "Map_Col",
"type" : {
"type" : "map",
"values" : "string"
}
} ]
}
} ]
}
Loading