Skip to content

Commit

Permalink
Support array gen in literal evaluation (apache#12278)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangfu0 authored Jan 18, 2024
1 parent 4ad36c3 commit 6bb387a
Show file tree
Hide file tree
Showing 14 changed files with 341 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ private FunctionRegistry() {
// This FUNCTION_MAP is used by Calcite function catalog to look up function by function signature.
private static final NameMultimap<Function> FUNCTION_MAP = new NameMultimap<>();

private static final int VAR_ARG_KEY = -1;

/**
* Registers the scalar functions via reflection.
* NOTE: In order to plugin methods using reflection, the methods should be inside a class that includes ".function."
Expand All @@ -69,12 +71,14 @@ private FunctionRegistry() {
// Annotated function names
String[] scalarFunctionNames = scalarFunction.names();
boolean nullableParameters = scalarFunction.nullableParameters();
boolean isPlaceholder = scalarFunction.isPlaceholder();
boolean isVarArg = scalarFunction.isVarArg();
if (scalarFunctionNames.length > 0) {
for (String name : scalarFunctionNames) {
FunctionRegistry.registerFunction(name, method, nullableParameters, scalarFunction.isPlaceholder());
FunctionRegistry.registerFunction(name, method, nullableParameters, isPlaceholder, isVarArg);
}
} else {
FunctionRegistry.registerFunction(method, nullableParameters, scalarFunction.isPlaceholder());
FunctionRegistry.registerFunction(method, nullableParameters, isPlaceholder, isVarArg);
}
}
}
Expand All @@ -93,31 +97,40 @@ public static void init() {
/**
* Registers a method with the name of the method.
* <p>The name is taken from {@code method.getName()}; all flags are forwarded unchanged to the overload that takes an
* explicit function name.
*/
public static void registerFunction(Method method, boolean nullableParameters, boolean isPlaceholder) {
registerFunction(method.getName(), method, nullableParameters, isPlaceholder);
public static void registerFunction(Method method, boolean nullableParameters, boolean isPlaceholder,
boolean isVarArg) {
registerFunction(method.getName(), method, nullableParameters, isPlaceholder, isVarArg);
}

/**
* Registers a method with the given function name.
* <p>Placeholder functions are skipped for the function info map and are only passed on to the Calcite named
* function map registration; every other function goes through both registrations.
*/
public static void registerFunction(String functionName, Method method, boolean nullableParameters,
boolean isPlaceholder) {
boolean isPlaceholder, boolean isVarArg) {
if (!isPlaceholder) {
registerFunctionInfoMap(functionName, method, nullableParameters);
registerFunctionInfoMap(functionName, method, nullableParameters, isVarArg);
}
registerCalciteNamedFunctionMap(functionName, method, nullableParameters);
registerCalciteNamedFunctionMap(functionName, method, nullableParameters, isVarArg);
}

// Stores the function in FUNCTION_INFO_MAP under its canonical name, keyed by parameter count (or by VAR_ARG_KEY
// for var-arg functions). Registering a different method under an already used key fails the state check.
private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters) {
private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters,
boolean isVarArg) {
FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters);
String canonicalName = canonicalize(functionName);
Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>());
FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo);
Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(),
"Function: %s with %s parameters is already registered", functionName, method.getParameterCount());
// Var-arg functions are keyed by the VAR_ARG_KEY sentinel rather than by their declared parameter count.
if (isVarArg) {
// NOTE: put() runs before the check, so the previous entry has already been replaced when the check fails.
FunctionInfo existFunctionInfo = functionInfoMap.put(VAR_ARG_KEY, functionInfo);
Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(),
"Function: %s with variable number of parameters is already registered", functionName);
} else {
FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo);
Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(),
"Function: %s with %s parameters is already registered", functionName, method.getParameterCount());
}
}

private static void registerCalciteNamedFunctionMap(String functionName, Method method, boolean nullableParameters) {
private static void registerCalciteNamedFunctionMap(String functionName, Method method, boolean nullableParameters,
boolean isVarArg) {
if (method.getAnnotation(Deprecated.class) == null) {
FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method));
}
Expand Down Expand Up @@ -146,7 +159,14 @@ public static boolean containsFunction(String functionName) {
@Nullable
public static FunctionInfo getFunctionInfo(String functionName, int numParameters) {
// Look up by canonical name; a missing entry means no function of that name was registered in FUNCTION_INFO_MAP.
Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName));
return functionInfoMap != null ? functionInfoMap.get(numParameters) : null;
if (functionInfoMap != null) {
// An exact parameter-count match wins over a var-arg registration.
FunctionInfo functionInfo = functionInfoMap.get(numParameters);
if (functionInfo != null) {
return functionInfo;
}
// Fall back to the var-arg registration; this is null when the function has none.
return functionInfoMap.get(VAR_ARG_KEY);
}
return null;
}

private static String canonicalize(String functionName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.objects.ObjectLinkedOpenHashSet;
import it.unimi.dsi.fastutil.objects.ObjectSet;
import java.math.BigDecimal;
import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.pinot.spi.annotations.ScalarFunction;
Expand Down Expand Up @@ -226,4 +227,62 @@ public static double arrayElementAtDouble(double[] arr, int idx) {
public static String arrayElementAtString(String[] arr, int idx) {
return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING;
}

/**
 * Builds an array value from the given elements.
 * <p>The element type is decided by the runtime class of the FIRST element only:
 * <ul>
 *   <li>{@code Integer}, {@code Long}, {@code Float}, {@code Double} and {@code Boolean} produce the matching
 *   primitive array ({@code int[]}, {@code long[]}, {@code float[]}, {@code double[]}, {@code boolean[]})</li>
 *   <li>{@code BigDecimal} and {@code String} produce a {@code BigDecimal[]} / {@code String[]}</li>
 *   <li>any other class returns the input {@code Object[]} unchanged</li>
 * </ul>
 * A {@code null} or empty input, or an input whose first element is {@code null}, carries no usable element type
 * and is returned unchanged as well (instead of failing with a {@code NullPointerException}).
 *
 * @param arr the elements of the array, all expected to share the class of the first element
 * @return a typed array as described above, or {@code arr} itself when no element type can be derived
 * @throws ClassCastException if a later element is not of the same class as the first element
 * @throws NullPointerException if a later element is {@code null} and the target array is a primitive array
 */
@ScalarFunction(names = {"array", "arrayValueConstructor"}, isVarArg = true)
public static Object arrayValueConstructor(Object... arr) {
  // A var-arg call site may pass a null array explicitly; an empty array or a null leading element gives no
  // class to specialize on either. Treat all of them like an unrecognized element type and hand the input back.
  if (arr == null || arr.length == 0 || arr[0] == null) {
    return arr;
  }
  Class<?> clazz = arr[0].getClass();
  if (clazz == Integer.class) {
    int[] intArr = new int[arr.length];
    for (int i = 0; i < arr.length; i++) {
      intArr[i] = (Integer) arr[i];
    }
    return intArr;
  }
  if (clazz == Long.class) {
    long[] longArr = new long[arr.length];
    for (int i = 0; i < arr.length; i++) {
      longArr[i] = (Long) arr[i];
    }
    return longArr;
  }
  if (clazz == Float.class) {
    float[] floatArr = new float[arr.length];
    for (int i = 0; i < arr.length; i++) {
      floatArr[i] = (Float) arr[i];
    }
    return floatArr;
  }
  if (clazz == Double.class) {
    double[] doubleArr = new double[arr.length];
    for (int i = 0; i < arr.length; i++) {
      doubleArr[i] = (Double) arr[i];
    }
    return doubleArr;
  }
  if (clazz == Boolean.class) {
    boolean[] boolArr = new boolean[arr.length];
    for (int i = 0; i < arr.length; i++) {
      boolArr[i] = (Boolean) arr[i];
    }
    return boolArr;
  }
  if (clazz == BigDecimal.class) {
    BigDecimal[] bigDecimalArr = new BigDecimal[arr.length];
    for (int i = 0; i < arr.length; i++) {
      bigDecimalArr[i] = (BigDecimal) arr[i];
    }
    return bigDecimalArr;
  }
  if (clazz == String.class) {
    String[] strArr = new String[arr.length];
    for (int i = 0; i < arr.length; i++) {
      strArr[i] = (String) arr[i];
    }
    return strArr;
  }
  // Unrecognized element type: return the elements as they were passed in.
  return arr;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,26 @@ public LiteralContext(Literal literal) {
_type = DataType.BYTES;
_value = literal.getBinaryValue();
break;
case INT_ARRAY_VALUE:
_type = DataType.INT;
_value = literal.getIntArrayValue();
break;
case LONG_ARRAY_VALUE:
_type = DataType.LONG;
_value = literal.getLongArrayValue();
break;
case FLOAT_ARRAY_VALUE:
_type = DataType.FLOAT;
_value = literal.getFloatArrayValue();
break;
case DOUBLE_ARRAY_VALUE:
_type = DataType.DOUBLE;
_value = literal.getDoubleArrayValue();
break;
case STRING_ARRAY_VALUE:
_type = DataType.STRING;
_value = literal.getStringArrayValue();
break;
case NULL_VALUE:
_type = DataType.UNKNOWN;
_value = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,13 @@ protected static Expression invokeCompileTimeFunctionExpression(@Nullable Expres
}
try {
FunctionInvoker invoker = new FunctionInvoker(functionInfo);
invoker.convertTypes(arguments);
Object result = invoker.invoke(arguments);
Object result;
if (invoker.getMethod().isVarArgs()) {
result = invoker.invoke(new Object[] {arguments});
} else {
invoker.convertTypes(arguments);
result = invoker.invoke(arguments);
}
return RequestUtils.getLiteralExpression(result);
} catch (Exception e) {
throw new SqlCompilationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.LiteralContext;
import org.apache.pinot.core.operator.ColumnContext;
import org.apache.pinot.core.operator.blocks.ValueBlock;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
Expand Down Expand Up @@ -54,6 +55,77 @@ public class ArrayLiteralTransformFunction implements TransformFunction {
private double[][] _doubleArrayResult;
private String[][] _stringArrayResult;

public ArrayLiteralTransformFunction(LiteralContext literalContext) {
List literalArray = (List) literalContext.getValue();
Preconditions.checkNotNull(literalArray);
if (literalArray.isEmpty()) {
_dataType = DataType.UNKNOWN;
_intArrayLiteral = new int[0];
_longArrayLiteral = new long[0];
_floatArrayLiteral = new float[0];
_doubleArrayLiteral = new double[0];
_stringArrayLiteral = new String[0];
return;
}
_dataType = literalContext.getType();
switch (_dataType) {
case INT:
_intArrayLiteral = new int[literalArray.size()];
for (int i = 0; i < _intArrayLiteral.length; i++) {
_intArrayLiteral[i] = (int) literalArray.get(i);
}
_longArrayLiteral = null;
_floatArrayLiteral = null;
_doubleArrayLiteral = null;
_stringArrayLiteral = null;
break;
case LONG:
_longArrayLiteral = new long[literalArray.size()];
for (int i = 0; i < _longArrayLiteral.length; i++) {
_longArrayLiteral[i] = (long) literalArray.get(i);
}
_intArrayLiteral = null;
_floatArrayLiteral = null;
_doubleArrayLiteral = null;
_stringArrayLiteral = null;
break;
case FLOAT:
_floatArrayLiteral = new float[literalArray.size()];
for (int i = 0; i < _floatArrayLiteral.length; i++) {
_floatArrayLiteral[i] = (float) literalArray.get(i);
}
_intArrayLiteral = null;
_longArrayLiteral = null;
_doubleArrayLiteral = null;
_stringArrayLiteral = null;
break;
case DOUBLE:
_doubleArrayLiteral = new double[literalArray.size()];
for (int i = 0; i < _doubleArrayLiteral.length; i++) {
_doubleArrayLiteral[i] = (double) literalArray.get(i);
}
_intArrayLiteral = null;
_longArrayLiteral = null;
_floatArrayLiteral = null;
_stringArrayLiteral = null;
break;
case STRING:
_stringArrayLiteral = new String[literalArray.size()];
for (int i = 0; i < _stringArrayLiteral.length; i++) {
_stringArrayLiteral[i] = (String) literalArray.get(i);
}
_intArrayLiteral = null;
_longArrayLiteral = null;
_floatArrayLiteral = null;
_doubleArrayLiteral = null;
break;
default:
throw new IllegalStateException(
"Illegal data type for ArrayLiteralTransformFunction: " + _dataType + ", literal contexts: "
+ Arrays.toString(literalArray.toArray()));
}
}

public ArrayLiteralTransformFunction(List<ExpressionContext> literalContexts) {
Preconditions.checkNotNull(literalContexts);
if (literalContexts.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.pinot.common.function.TransformFunctionType;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.request.context.LiteralContext;
import org.apache.pinot.common.utils.HashUtil;
import org.apache.pinot.core.geospatial.transform.function.GeoToH3Function;
import org.apache.pinot.core.geospatial.transform.function.StAreaFunction;
Expand Down Expand Up @@ -335,7 +336,12 @@ public static TransformFunction get(ExpressionContext expression, Map<String, Co
String columnName = expression.getIdentifier();
return new IdentifierTransformFunction(columnName, columnContextMap.get(columnName));
case LITERAL:
return queryContext.getOrComputeSharedValue(LiteralTransformFunction.class, expression.getLiteral(),
LiteralContext literal = expression.getLiteral();
if (literal.getValue() != null && literal.getValue() instanceof ArrayList) {
return queryContext.getOrComputeSharedValue(ArrayLiteralTransformFunction.class, literal,
ArrayLiteralTransformFunction::new);
}
return queryContext.getOrComputeSharedValue(LiteralTransformFunction.class, literal,
LiteralTransformFunction::new);
default:
throw new IllegalStateException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,25 @@ public HistogramAggregationFunction(List<ExpressionContext> arguments) {
if (numArguments == 2) {
ExpressionContext arrayExpression = arguments.get(1);
Preconditions.checkArgument(
(arrayExpression.getType() == ExpressionContext.Type.FUNCTION) && (arrayExpression.getFunction()
.getFunctionName().equals(ARRAY_CONSTRUCTOR)),
// ARRAY function
((arrayExpression.getType() == ExpressionContext.Type.FUNCTION)
&& (arrayExpression.getFunction().getFunctionName().equals(ARRAY_CONSTRUCTOR)))
|| ((arrayExpression.getType() == ExpressionContext.Type.LITERAL)
&& (arrayExpression.getLiteral().getValue() instanceof List)),
"Please use the format of `Histogram(columnName, ARRAY[1,10,100])` to specify the bin edges");
_bucketEdges = parseVector(arrayExpression.getFunction().getArguments());
if (arrayExpression.getType() == ExpressionContext.Type.FUNCTION) {
_bucketEdges = parseVector(arrayExpression.getFunction().getArguments());
} else {
_bucketEdges = parseVectorLiteral((List) arrayExpression.getLiteral().getValue());
}
_lower = _bucketEdges[0];
_upper = _bucketEdges[_bucketEdges.length - 1];
} else {
_isEqualLength = true;
_lower = arguments.get(1).getLiteral().getDoubleValue();
_upper = arguments.get(2).getLiteral().getDoubleValue();
int numBins = arguments.get(3).getLiteral().getIntValue();;
int numBins = arguments.get(3).getLiteral().getIntValue();
;
Preconditions.checkArgument(_upper > _lower,
"The right most edge must be greater than left most edge, given %s and %s", _lower, _upper);
Preconditions.checkArgument(numBins > 0, "The number of bins must be greater than zero, given %s", numBins);
Expand Down Expand Up @@ -109,8 +117,23 @@ private double[] parseVector(List<ExpressionContext> arrayStr) {
return ret;
}

/**
 * Converts the elements of an array literal into histogram bin edges.
 * <p>Each element is turned into a double by parsing its {@code toString()} form.
 *
 * @param arrayStr the literal's elements; must contain at least two entries
 * @return the parsed bin edges, in input order
 * @throws IllegalArgumentException if fewer than two edges are given
 * @throws IllegalStateException if an edge is not strictly greater than the one before it
 */
private double[] parseVectorLiteral(List arrayStr) {
  int numEdges = arrayStr.size();
  Preconditions.checkArgument(numEdges > 1, "The number of bin edges must be greater than 1");
  double[] edges = new double[numEdges];
  // TODO: Represent infinity as literal instead of identifier
  // The first edge has no predecessor to compare against, so it is parsed outside the loop.
  edges[0] = Double.parseDouble(arrayStr.get(0).toString());
  for (int pos = 1; pos < numEdges; pos++) {
    double edge = Double.parseDouble(arrayStr.get(pos).toString());
    edges[pos] = edge;
    Preconditions.checkState(edge > edges[pos - 1], "The bin edges must be strictly increasing");
  }
  return edges;
}

/**
* Find the bin id for the input value. Use division for equal-length bins, and binary search otherwise.
*
* @param val input value
* @return bin id
*/
Expand All @@ -135,7 +158,7 @@ private int getBinId(double val) {
i = mid;
}
}
id = i;
id = i;
}
return id;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ public void testStateSharedBetweenRowsForExecution()
throws Exception {
MyFunc myFunc = new MyFunc();
Method method = myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class);
FunctionRegistry.registerFunction(method, false, false);
FunctionRegistry.registerFunction(method, false, false, false);
String expression = "appendToStringAndReturn('test ')";
InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression);
assertTrue(evaluator.getArguments().isEmpty());
Expand Down
Loading

0 comments on commit 6bb387a

Please sign in to comment.