diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index ed2f50154451..fe9177ea9e01 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -89,10 +89,12 @@ private FunctionRegistry() { scalarFunctionNames.add(method.getName()); } boolean nullableParameters = scalarFunction.nullableParameters(); - registerFunction(method, scalarFunctionNames, nullableParameters); + boolean varArg = scalarFunction.isVarArg(); + registerFunction(method, scalarFunctionNames, nullableParameters, varArg); } } - Set> classes = PinotReflectionUtils.getClassesThroughReflection(".*\\.function\\..*", ScalarFunction.class); + Set> classes = + PinotReflectionUtils.getClassesThroughReflection(".*\\.function\\..*", ScalarFunction.class); for (Class clazz : classes) { if (!Modifier.isPublic(clazz.getModifiers())) { continue; @@ -105,7 +107,8 @@ private FunctionRegistry() { scalarFunctionNames.add(clazz.getName()); } boolean nullableParameters = scalarFunction.nullableParameters(); - registerFunction(clazz, scalarFunctionNames, nullableParameters); + boolean varArg = scalarFunction.isVarArg(); + registerFunction(clazz, scalarFunctionNames, nullableParameters, varArg); } } LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_MAP.map().size(), @@ -152,8 +155,8 @@ public static void init() { } @VisibleForTesting - public static void registerFunction(Method method, boolean nullableParameters) { - registerFunction(method, Collections.singleton(method.getName()), nullableParameters); + public static void registerFunction(Method method) { + registerFunction(method, Collections.singleton(method.getName()), false, false); } @VisibleForTesting @@ -197,10 +200,10 @@ public static FunctionInfo getFunctionInfo(SqlOperatorTable operatorTable, RelDa @Nullable private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionName, int numParams) { - List candidates = getFunctionMap() - .range(functionName, CASE_SENSITIVITY).stream() - .filter(e -> e.getValue() instanceof PinotScalarFunction && e.getValue().getParameters().size() == numParams) - .map(e -> (PinotScalarFunction) e.getValue()).collect(Collectors.toList()); + List candidates = getFunctionMap().range(functionName, CASE_SENSITIVITY).stream().filter( + e -> e.getValue() instanceof PinotScalarFunction && (e.getValue().getParameters().size() == numParams + || ((PinotScalarFunction) e.getValue()).isVarArgs())).map(e -> (PinotScalarFunction) e.getValue()) + .collect(Collectors.toList()); if (candidates.size() == 1) { return candidates.get(0).getFunctionInfo(); } else { @@ -213,9 +216,10 @@ private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionNa private static PinotScalarFunction getScalarFunction(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, String functionName, List argTypes) { List relArgTypes = convertArgumentTypes(typeFactory, argTypes); - SqlOperator sqlOperator = SqlUtil.lookupRoutine(operatorTable, typeFactory, - new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO), relArgTypes, null, null, SqlSyntax.FUNCTION, - SqlKind.OTHER_FUNCTION, SqlNameMatchers.withCaseSensitive(false), true); + SqlOperator sqlOperator = + SqlUtil.lookupRoutine(operatorTable, typeFactory, new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO), + relArgTypes, null, null, SqlSyntax.FUNCTION, SqlKind.OTHER_FUNCTION, + SqlNameMatchers.withCaseSensitive(false), true); if (sqlOperator instanceof SqlUserDefinedFunction) { Function function = ((SqlUserDefinedFunction) sqlOperator).getFunction(); if (function instanceof PinotScalarFunction) { @@ -233,33 +237,38 @@ public static NameMultimap getOperatorMap() { return OPERATOR_MAP; } - private static void registerFunction(Method method, Set alias, boolean nullableParameters) { + private static void registerFunction(Method method, Set alias, boolean nullableParameters, boolean varArg) { if (method.getAnnotation(Deprecated.class) == null) { for (String name : alias) { - registerCalciteNamedFunctionMap(name, method, nullableParameters); + registerCalciteNamedFunctionMap(name, method, nullableParameters, varArg); } } } - private static void registerFunction(Class clazz, Set alias, boolean nullableParameters) { + private static void registerFunction(Class clazz, Set alias, boolean nullableParameters, boolean varArg) { if (clazz.getAnnotation(Deprecated.class) == null) { for (String name : alias) { - registerCalciteNamedFunctionMap(name, clazz, nullableParameters); + registerCalciteNamedFunctionMap(name, clazz, nullableParameters, varArg); } } } - private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters) { - FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters)); + private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters, + boolean varArg) { + FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters, varArg)); } - private static void registerCalciteNamedFunctionMap(String name, Class clazz, boolean nullableParameters) { + private static void registerCalciteNamedFunctionMap(String name, Class clazz, boolean nullableParameters, + boolean varArg) { try { - SqlReturnTypeInference returnTypeInference = (SqlReturnTypeInference) clazz.getField("RETURN_TYPE_INFERENCE").get(null); - SqlOperandTypeChecker operandTypeChecker = (SqlOperandTypeChecker) clazz.getField("OPERAND_TYPE_CHECKER").get(null); + SqlReturnTypeInference returnTypeInference = + (SqlReturnTypeInference) clazz.getField("RETURN_TYPE_INFERENCE").get(null); + SqlOperandTypeChecker operandTypeChecker = + (SqlOperandTypeChecker) clazz.getField("OPERAND_TYPE_CHECKER").get(null); for (Method method : clazz.getMethods()) { if (method.getName().equals("eval")) { - FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters, operandTypeChecker, returnTypeInference)); + FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters, varArg, operandTypeChecker, + returnTypeInference)); } } } catch (Exception e) { @@ -269,9 +278,10 @@ private static void registerCalciteNamedFunctionMap(String name, Class clazz, private static void registerAggregateFunction(String functionName, AggregationFunctionType functionType) { if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { - PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, - functionType.getSqlKind(), functionType.getReturnTypeInference(), null, - functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); + PinotSqlAggFunction sqlAggFunction = + new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, functionType.getSqlKind(), + functionType.getReturnTypeInference(), null, functionType.getOperandTypeChecker(), + functionType.getSqlFunctionCategory()); OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlAggFunction); } } @@ -279,9 +289,9 @@ private static void registerAggregateFunction(String functionName, AggregationFu private static void registerTransformFunction(String functionName, TransformFunctionType functionType) { if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { PinotSqlTransformFunction sqlTransformFunction = - new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), - functionType.getSqlKind(), functionType.getReturnTypeInference(), null, - functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); + new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), functionType.getSqlKind(), + functionType.getReturnTypeInference(), null, functionType.getOperandTypeChecker(), + functionType.getSqlFunctionCategory()); OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlTransformFunction); } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java index aae61e94a7c8..ee767c1d8525 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java @@ -25,13 +25,10 @@ import it.unimi.dsi.fastutil.objects.ObjectSet; import java.math.BigDecimal; import java.util.Arrays; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SameOperandTypeChecker; -import org.apache.calcite.sql.type.SqlOperandCountRanges; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.commons.lang3.ArrayUtils; import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder; @@ -235,7 +232,7 @@ public static String arrayElementAtString(String[] arr, int idx) { return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING; } - @ScalarFunction(names = {"array", "arrayValueConstructor"}) + @ScalarFunction(names = {"array", "arrayValueConstructor"}, isVarArg = true) public static class ArrayValueConstructor { public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = ReturnTypes.TO_ARRAY; public static final SqlOperandTypeChecker OPERAND_TYPE_CHECKER = new SameOperandTypeChecker(-1); diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotScalarFunction.java index 5030dc47bb53..e277d958a344 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotScalarFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotScalarFunction.java @@ -39,16 +39,20 @@ public class PinotScalarFunction extends ReflectiveFunctionBase implements Pinot private final Method _method; private final SqlOperandTypeChecker _sqlOperandTypeChecker; private final SqlReturnTypeInference _sqlReturnTypeInference; + private final boolean _isNullableParameter; + private final boolean _isVarArgs; - public PinotScalarFunction(String name, Method method, boolean isNullableParameter) { - this(name, method, isNullableParameter, null, null); + public PinotScalarFunction(String name, Method method, boolean isNullableParameter, boolean isVarArg) { + this(name, method, isNullableParameter, isVarArg, null, null); } - public PinotScalarFunction(String name, Method method, boolean isNullableParameter, + public PinotScalarFunction(String name, Method method, boolean isNullableParameter, boolean isVarArgs, SqlOperandTypeChecker sqlOperandTypeChecker, SqlReturnTypeInference sqlReturnTypeInference) { super(method); _name = name; _method = method; + _isNullableParameter = isNullableParameter; + _isVarArgs = isVarArgs; _functionInfo = new FunctionInfo(method, method.getDeclaringClass(), isNullableParameter); _sqlOperandTypeChecker = sqlOperandTypeChecker; _sqlReturnTypeInference = sqlReturnTypeInference; @@ -80,4 +84,12 @@ public SqlOperandTypeChecker getOperandTypeChecker() { public SqlReturnTypeInference getReturnTypeInference() { return _sqlReturnTypeInference; } + + public boolean isNullableParameter() { + return _isNullableParameter; + } + + public boolean isVarArgs() { + return _isVarArgs; + } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java index e3f3d2b737ef..0d6a94105c42 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java @@ -76,6 +76,7 @@ import org.apache.calcite.sql.validate.SqlValidatorUtil; import org.apache.calcite.util.Optionality; import org.apache.calcite.util.Util; +import org.apache.pinot.common.function.schema.PinotScalarFunction; import org.checkerframework.checker.nullness.qual.Nullable; diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java index 5c6835e293f6..031cfd169ecd 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java @@ -130,7 +130,7 @@ public void testStateSharedBetweenRowsForExecution() throws Exception { MyFunc myFunc = new MyFunc(); Method method = myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class); - FunctionRegistry.registerFunction(method, false); + FunctionRegistry.registerFunction(method); String expression = "appendToStringAndReturn('test ')"; InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); assertTrue(evaluator.getArguments().isEmpty()); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java index 29419dc5e4d9..773f73058fb7 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java @@ -58,8 +58,7 @@ public FunctionOperand(SqlOperatorTable sqlOperatorTable, RelDataTypeFactory rel return e.getDataType(); } }).collect(Collectors.toList()); - FunctionInfo functionInfo = - FunctionRegistry.getFunctionInfo(sqlOperatorTable, relDataTypeFactory, canonicalName, operandTypes); + FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(canonicalName, operandTypes.size()); Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", canonicalName); _functionInvoker = new FunctionInvoker(functionInfo); if (!_functionInvoker.getMethod().isVarArgs()) {