Skip to content

Commit

Permalink
fix array function
Browse files Browse the repository at this point in the history
  • Loading branch information
Rong Rong committed Jan 26, 2024
1 parent b7f7b7e commit 89597f3
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ private FunctionRegistry() {
scalarFunctionNames.add(method.getName());
}
boolean nullableParameters = scalarFunction.nullableParameters();
registerFunction(method, scalarFunctionNames, nullableParameters);
boolean varArg = scalarFunction.isVarArg();
registerFunction(method, scalarFunctionNames, nullableParameters, varArg);
}
}
Set<Class<?>> classes = PinotReflectionUtils.getClassesThroughReflection(".*\\.function\\..*", ScalarFunction.class);
Set<Class<?>> classes =
PinotReflectionUtils.getClassesThroughReflection(".*\\.function\\..*", ScalarFunction.class);
for (Class<?> clazz : classes) {
if (!Modifier.isPublic(clazz.getModifiers())) {
continue;
Expand All @@ -105,7 +107,8 @@ private FunctionRegistry() {
scalarFunctionNames.add(clazz.getName());
}
boolean nullableParameters = scalarFunction.nullableParameters();
registerFunction(clazz, scalarFunctionNames, nullableParameters);
boolean varArg = scalarFunction.isVarArg();
registerFunction(clazz, scalarFunctionNames, nullableParameters, varArg);
}
}
LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_MAP.map().size(),
Expand Down Expand Up @@ -152,8 +155,8 @@ public static void init() {
}

@VisibleForTesting
public static void registerFunction(Method method, boolean nullableParameters) {
registerFunction(method, Collections.singleton(method.getName()), nullableParameters);
public static void registerFunction(Method method) {
registerFunction(method, Collections.singleton(method.getName()), false, false);
}

@VisibleForTesting
Expand Down Expand Up @@ -197,10 +200,10 @@ public static FunctionInfo getFunctionInfo(SqlOperatorTable operatorTable, RelDa

@Nullable
private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionName, int numParams) {
List<PinotScalarFunction> candidates = getFunctionMap()
.range(functionName, CASE_SENSITIVITY).stream()
.filter(e -> e.getValue() instanceof PinotScalarFunction && e.getValue().getParameters().size() == numParams)
.map(e -> (PinotScalarFunction) e.getValue()).collect(Collectors.toList());
List<PinotScalarFunction> candidates = getFunctionMap().range(functionName, CASE_SENSITIVITY).stream().filter(
e -> e.getValue() instanceof PinotScalarFunction && (e.getValue().getParameters().size() == numParams
|| ((PinotScalarFunction) e.getValue()).isVarArgs())).map(e -> (PinotScalarFunction) e.getValue())
.collect(Collectors.toList());
if (candidates.size() == 1) {
return candidates.get(0).getFunctionInfo();
} else {
Expand All @@ -213,9 +216,10 @@ private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionNa
private static PinotScalarFunction getScalarFunction(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory,
String functionName, List<DataSchema.ColumnDataType> argTypes) {
List<RelDataType> relArgTypes = convertArgumentTypes(typeFactory, argTypes);
SqlOperator sqlOperator = SqlUtil.lookupRoutine(operatorTable, typeFactory,
new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO), relArgTypes, null, null, SqlSyntax.FUNCTION,
SqlKind.OTHER_FUNCTION, SqlNameMatchers.withCaseSensitive(false), true);
SqlOperator sqlOperator =
SqlUtil.lookupRoutine(operatorTable, typeFactory, new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO),
relArgTypes, null, null, SqlSyntax.FUNCTION, SqlKind.OTHER_FUNCTION,
SqlNameMatchers.withCaseSensitive(false), true);
if (sqlOperator instanceof SqlUserDefinedFunction) {
Function function = ((SqlUserDefinedFunction) sqlOperator).getFunction();
if (function instanceof PinotScalarFunction) {
Expand All @@ -233,33 +237,38 @@ public static NameMultimap<SqlOperator> getOperatorMap() {
return OPERATOR_MAP;
}

private static void registerFunction(Method method, Set<String> alias, boolean nullableParameters) {
private static void registerFunction(Method method, Set<String> alias, boolean nullableParameters, boolean varArg) {
if (method.getAnnotation(Deprecated.class) == null) {
for (String name : alias) {
registerCalciteNamedFunctionMap(name, method, nullableParameters);
registerCalciteNamedFunctionMap(name, method, nullableParameters, varArg);
}
}
}

private static void registerFunction(Class<?> clazz, Set<String> alias, boolean nullableParameters) {
private static void registerFunction(Class<?> clazz, Set<String> alias, boolean nullableParameters, boolean varArg) {
if (clazz.getAnnotation(Deprecated.class) == null) {
for (String name : alias) {
registerCalciteNamedFunctionMap(name, clazz, nullableParameters);
registerCalciteNamedFunctionMap(name, clazz, nullableParameters, varArg);
}
}
}

private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters) {
FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters));
private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters,
boolean varArg) {
FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters, varArg));
}

private static void registerCalciteNamedFunctionMap(String name, Class<?> clazz, boolean nullableParameters) {
private static void registerCalciteNamedFunctionMap(String name, Class<?> clazz, boolean nullableParameters,
boolean varArg) {
try {
SqlReturnTypeInference returnTypeInference = (SqlReturnTypeInference) clazz.getField("RETURN_TYPE_INFERENCE").get(null);
SqlOperandTypeChecker operandTypeChecker = (SqlOperandTypeChecker) clazz.getField("OPERAND_TYPE_CHECKER").get(null);
SqlReturnTypeInference returnTypeInference =
(SqlReturnTypeInference) clazz.getField("RETURN_TYPE_INFERENCE").get(null);
SqlOperandTypeChecker operandTypeChecker =
(SqlOperandTypeChecker) clazz.getField("OPERAND_TYPE_CHECKER").get(null);
for (Method method : clazz.getMethods()) {
if (method.getName().equals("eval")) {
FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters, operandTypeChecker, returnTypeInference));
FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters, varArg, operandTypeChecker,
returnTypeInference));
}
}
} catch (Exception e) {
Expand All @@ -269,19 +278,20 @@ private static void registerCalciteNamedFunctionMap(String name, Class<?> clazz,

private static void registerAggregateFunction(String functionName, AggregationFunctionType functionType) {
if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) {
PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
functionType.getSqlKind(), functionType.getReturnTypeInference(), null,
functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory());
PinotSqlAggFunction sqlAggFunction =
new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, functionType.getSqlKind(),
functionType.getReturnTypeInference(), null, functionType.getOperandTypeChecker(),
functionType.getSqlFunctionCategory());
OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlAggFunction);
}
}

private static void registerTransformFunction(String functionName, TransformFunctionType functionType) {
if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) {
PinotSqlTransformFunction sqlTransformFunction =
new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT),
functionType.getSqlKind(), functionType.getReturnTypeInference(), null,
functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory());
new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), functionType.getSqlKind(),
functionType.getReturnTypeInference(), null, functionType.getOperandTypeChecker(),
functionType.getSqlFunctionCategory());
OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlTransformFunction);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,10 @@
import it.unimi.dsi.fastutil.objects.ObjectSet;
import java.math.BigDecimal;
import java.util.Arrays;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SameOperandTypeChecker;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.pinot.spi.annotations.ScalarFunction;
import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder;
Expand Down Expand Up @@ -235,7 +232,7 @@ public static String arrayElementAtString(String[] arr, int idx) {
return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING;
}

@ScalarFunction(names = {"array", "arrayValueConstructor"})
@ScalarFunction(names = {"array", "arrayValueConstructor"}, isVarArg = true)
public static class ArrayValueConstructor {
public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = ReturnTypes.TO_ARRAY;
public static final SqlOperandTypeChecker OPERAND_TYPE_CHECKER = new SameOperandTypeChecker(-1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,20 @@ public class PinotScalarFunction extends ReflectiveFunctionBase implements Pinot
private final Method _method;
private final SqlOperandTypeChecker _sqlOperandTypeChecker;
private final SqlReturnTypeInference _sqlReturnTypeInference;
private final boolean _isNullableParameter;
private final boolean _isVarArgs;

public PinotScalarFunction(String name, Method method, boolean isNullableParameter) {
this(name, method, isNullableParameter, null, null);
public PinotScalarFunction(String name, Method method, boolean isNullableParameter, boolean isVarArg) {
this(name, method, isNullableParameter, isVarArg, null, null);
}

public PinotScalarFunction(String name, Method method, boolean isNullableParameter,
public PinotScalarFunction(String name, Method method, boolean isNullableParameter, boolean isVarArgs,
SqlOperandTypeChecker sqlOperandTypeChecker, SqlReturnTypeInference sqlReturnTypeInference) {
super(method);
_name = name;
_method = method;
_isNullableParameter = isNullableParameter;
_isVarArgs = isVarArgs;
_functionInfo = new FunctionInfo(method, method.getDeclaringClass(), isNullableParameter);
_sqlOperandTypeChecker = sqlOperandTypeChecker;
_sqlReturnTypeInference = sqlReturnTypeInference;
Expand Down Expand Up @@ -80,4 +84,12 @@ public SqlOperandTypeChecker getOperandTypeChecker() {
public SqlReturnTypeInference getReturnTypeInference() {
return _sqlReturnTypeInference;
}

public boolean isNullableParameter() {
return _isNullableParameter;
}

public boolean isVarArgs() {
return _isVarArgs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.util.Optionality;
import org.apache.calcite.util.Util;
import org.apache.pinot.common.function.schema.PinotScalarFunction;
import org.checkerframework.checker.nullness.qual.Nullable;


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public void testStateSharedBetweenRowsForExecution()
throws Exception {
MyFunc myFunc = new MyFunc();
Method method = myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class);
FunctionRegistry.registerFunction(method, false);
FunctionRegistry.registerFunction(method);
String expression = "appendToStringAndReturn('test ')";
InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression);
assertTrue(evaluator.getArguments().isEmpty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ public FunctionOperand(SqlOperatorTable sqlOperatorTable, RelDataTypeFactory rel
return e.getDataType();
}
}).collect(Collectors.toList());
FunctionInfo functionInfo =
FunctionRegistry.getFunctionInfo(sqlOperatorTable, relDataTypeFactory, canonicalName, operandTypes);
FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(canonicalName, operandTypes.size());
Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", canonicalName);
_functionInvoker = new FunctionInvoker(functionInfo);
if (!_functionInvoker.getMethod().isVarArgs()) {
Expand Down

0 comments on commit 89597f3

Please sign in to comment.