Skip to content

Commit

Permalink
fix lookup, using special pinot udf extension
Browse files Browse the repository at this point in the history
- support arrayElementAt
- support arrayValuConstructor
  • Loading branch information
Rong Rong committed Jan 26, 2024
1 parent 89597f3 commit ffb940b
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.apache.pinot.common.function.schema.PinotFunction;
import org.apache.pinot.common.function.schema.PinotScalarFunction;
import org.apache.pinot.common.function.sql.PinotSqlAggFunction;
import org.apache.pinot.common.function.sql.PinotSqlScalarFunction;
import org.apache.pinot.common.function.sql.PinotSqlTransformFunction;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.segment.spi.AggregationFunctionType;
Expand Down Expand Up @@ -188,7 +189,7 @@ public static FunctionInfo getFunctionInfo(String functionName, int numParams) {
*/
@Nullable
public static FunctionInfo getFunctionInfo(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory,
String functionName, List<DataSchema.ColumnDataType> argTypes) {
String functionName, List<RelDataType> argTypes) {
PinotScalarFunction scalarFunction = getScalarFunction(operatorTable, typeFactory, functionName, argTypes);
if (scalarFunction != null) {
return scalarFunction.getFunctionInfo();
Expand All @@ -214,17 +215,18 @@ private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionNa

@Nullable
private static PinotScalarFunction getScalarFunction(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory,
String functionName, List<DataSchema.ColumnDataType> argTypes) {
List<RelDataType> relArgTypes = convertArgumentTypes(typeFactory, argTypes);
String functionName, List<RelDataType> argTypes) {
SqlOperator sqlOperator =
SqlUtil.lookupRoutine(operatorTable, typeFactory, new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO),
relArgTypes, null, null, SqlSyntax.FUNCTION, SqlKind.OTHER_FUNCTION,
argTypes, null, null, SqlSyntax.FUNCTION, SqlKind.OTHER_FUNCTION,
SqlNameMatchers.withCaseSensitive(false), true);
if (sqlOperator instanceof SqlUserDefinedFunction) {
Function function = ((SqlUserDefinedFunction) sqlOperator).getFunction();
if (function instanceof PinotScalarFunction) {
return (PinotScalarFunction) function;
}
} else if (sqlOperator instanceof PinotSqlScalarFunction) {
return ((PinotSqlScalarFunction) sqlOperator).getFunction();
}
return null;
}
Expand Down Expand Up @@ -296,7 +298,7 @@ private static void registerTransformFunction(String functionName, TransformFunc
}
}

private static List<RelDataType> convertArgumentTypes(RelDataTypeFactory typeFactory,
public static List<RelDataType> convertArgumentTypes(RelDataTypeFactory typeFactory,
List<DataSchema.ColumnDataType> argTypes) {
return argTypes.stream().map(type -> toRelType(typeFactory, type)).collect(Collectors.toList());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,24 @@
*/
package org.apache.pinot.common.function.scalar;

import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntLinkedOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.objects.ObjectLinkedOpenHashSet;
import it.unimi.dsi.fastutil.objects.ObjectSet;
import java.math.BigDecimal;
import java.util.Arrays;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.type.NonNullableAccessors;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SameOperandTypeChecker;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.pinot.spi.annotations.ScalarFunction;
import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder;
Expand Down Expand Up @@ -232,6 +239,38 @@ public static String arrayElementAtString(String[] arr, int idx) {
return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING;
}

@ScalarFunction(names = {"arrayElementAt", "array_element_at"})
public static class ArrayElementAt {
public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = (opBinding) -> {
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
final RelDataType operandType = opBinding.getOperandType(0);
Preconditions.checkState(operandType.getSqlTypeName() == SqlTypeName.ARRAY);
return typeFactory.createTypeWithNullability(NonNullableAccessors.getComponentTypeOrThrow(operandType), true);
};
public static final SqlOperandTypeChecker OPERAND_TYPE_CHECKER =
OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER);

public static int eval(int[] arr, int idx) {
return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.INT;
}

public static long eval(long[] arr, int idx) {
return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.LONG;
}

public static float eval(float[] arr, int idx) {
return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.FLOAT;
}

public static double eval(double[] arr, int idx) {
return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.DOUBLE;
}

public static String eval(String[] arr, int idx) {
return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING;
}
}

@ScalarFunction(names = {"array", "arrayValueConstructor"}, isVarArg = true)
public static class ArrayValueConstructor {
public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = ReturnTypes.TO_ARRAY;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandMetadata;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlOperandTypeInference;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
Expand Down Expand Up @@ -323,6 +324,17 @@ public static SqlOperatorTable operatorTable(String... classNames) {
// ====================================================================
public static SqlOperator toOp(SqlIdentifier name,
final org.apache.calcite.schema.Function function) {

// TODO: support AGG and TABLE function in the future
if (function instanceof PinotScalarFunction && ((PinotScalarFunction) function).getOperandTypeChecker() != null) {
final SqlOperandTypeChecker operandTypeChecker =
((PinotScalarFunction) function).getOperandTypeChecker();
final SqlReturnTypeInference returnTypeInference =
((PinotScalarFunction) function).getReturnTypeInference();
final SqlKind kind = kind(function);
return new PinotSqlScalarFunction(name.toString(), kind, returnTypeInference, null, operandTypeChecker,
SqlFunctionCategory.USER_DEFINED_FUNCTION, ((PinotScalarFunction) function));
}
// ====================================================================
// LINES CHANGED ABOVE
// ====================================================================
Expand Down Expand Up @@ -353,13 +365,7 @@ public static SqlOperator toOp(SqlIdentifier name,
final List<SqlTypeFamily> typeFamilies =
typeFamiliesFactory.apply(dummyTypeFactory);

final SqlOperandTypeInference operandTypeInference;
if (function instanceof PinotScalarFunction && ((PinotScalarFunction) function).getOperandTypeChecker() != null) {
operandTypeInference = ((PinotScalarFunction) function).getOperandTypeChecker().typeInference();
} else {
operandTypeInference = InferTypes.explicit(argTypes);
}

final SqlOperandTypeInference operandTypeInference = InferTypes.explicit(argTypes);
final SqlOperandMetadata operandMetadata =
OperandTypes.operandMetadata(typeFamilies, paramTypesFactory,
i -> function.getParameters().get(i).getName(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.common.function.sql;

import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlOperandTypeInference;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.pinot.common.function.schema.PinotScalarFunction;
import org.checkerframework.checker.nullness.qual.Nullable;


public class PinotSqlScalarFunction extends SqlFunction {
private final PinotScalarFunction _pinotScalarFunction;

public PinotSqlScalarFunction(String name, SqlKind kind, @Nullable SqlReturnTypeInference returnTypeInference,
@Nullable SqlOperandTypeInference operandTypeInference, @Nullable SqlOperandTypeChecker operandTypeChecker,
SqlFunctionCategory category, PinotScalarFunction pinotScalarFunction) {
super(name, kind, returnTypeInference, operandTypeInference, operandTypeChecker, category);
_pinotScalarFunction = pinotScalarFunction;
}

public PinotScalarFunction getFunction() {
return _pinotScalarFunction;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.pinot.common.function.FunctionInfo;
Expand Down Expand Up @@ -58,7 +59,10 @@ public FunctionOperand(SqlOperatorTable sqlOperatorTable, RelDataTypeFactory rel
return e.getDataType();
}
}).collect(Collectors.toList());
FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(canonicalName, operandTypes.size());

List<RelDataType> argTypes = FunctionRegistry.convertArgumentTypes(relDataTypeFactory, operandTypes);
FunctionInfo functionInfo =
FunctionRegistry.getFunctionInfo(sqlOperatorTable, relDataTypeFactory, canonicalName, argTypes);
Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", canonicalName);
_functionInvoker = new FunctionInvoker(functionInfo);
if (!_functionInvoker.getMethod().isVarArgs()) {
Expand Down

0 comments on commit ffb940b

Please sign in to comment.