diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index 97fa972bee18..deb1673d8bac 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -52,6 +52,8 @@ private FunctionRegistry() { // This FUNCTION_MAP is used by Calcite function catalog to look up function by function signature. private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); + private static final int VAR_ARG_KEY = -1; + /** * Registers the scalar functions via reflection. * NOTE: In order to plugin methods using reflection, the methods should be inside a class that includes ".function." @@ -69,12 +71,14 @@ private FunctionRegistry() { // Annotated function names String[] scalarFunctionNames = scalarFunction.names(); boolean nullableParameters = scalarFunction.nullableParameters(); + boolean isPlaceholder = scalarFunction.isPlaceholder(); + boolean isVarArg = scalarFunction.isVarArg(); if (scalarFunctionNames.length > 0) { for (String name : scalarFunctionNames) { - FunctionRegistry.registerFunction(name, method, nullableParameters, scalarFunction.isPlaceholder()); + FunctionRegistry.registerFunction(name, method, nullableParameters, isPlaceholder, isVarArg); } } else { - FunctionRegistry.registerFunction(method, nullableParameters, scalarFunction.isPlaceholder()); + FunctionRegistry.registerFunction(method, nullableParameters, isPlaceholder, isVarArg); } } } @@ -93,31 +97,40 @@ public static void init() { /** * Registers a method with the name of the method. */ - public static void registerFunction(Method method, boolean nullableParameters, boolean isPlaceholder) { - registerFunction(method.getName(), method, nullableParameters, isPlaceholder); + public static void registerFunction(Method method, boolean nullableParameters, boolean isPlaceholder, + boolean isVarArg) { + registerFunction(method.getName(), method, nullableParameters, isPlaceholder, isVarArg); } /** * Registers a method with the given function name. */ public static void registerFunction(String functionName, Method method, boolean nullableParameters, - boolean isPlaceholder) { + boolean isPlaceholder, boolean isVarArg) { if (!isPlaceholder) { - registerFunctionInfoMap(functionName, method, nullableParameters); + registerFunctionInfoMap(functionName, method, nullableParameters, isVarArg); } - registerCalciteNamedFunctionMap(functionName, method, nullableParameters); + registerCalciteNamedFunctionMap(functionName, method, nullableParameters, isVarArg); } - private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters) { + private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters, + boolean isVarArg) { FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters); String canonicalName = canonicalize(functionName); Map functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>()); - FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with %s parameters is already registered", functionName, method.getParameterCount()); + if (isVarArg) { + FunctionInfo existFunctionInfo = functionInfoMap.put(VAR_ARG_KEY, functionInfo); + Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), + "Function: %s with variable number of parameters is already registered", functionName); + } else { + FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo); + Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), + "Function: %s with %s parameters is already registered", functionName, method.getParameterCount()); + } } - private static void registerCalciteNamedFunctionMap(String functionName, Method method, boolean nullableParameters) { + private static void registerCalciteNamedFunctionMap(String functionName, Method method, boolean nullableParameters, + boolean isVarArg) { if (method.getAnnotation(Deprecated.class) == null) { FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method)); } @@ -146,7 +159,14 @@ public static boolean containsFunction(String functionName) { @Nullable public static FunctionInfo getFunctionInfo(String functionName, int numParameters) { Map functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName)); - return functionInfoMap != null ? functionInfoMap.get(numParameters) : null; + if (functionInfoMap != null) { + FunctionInfo functionInfo = functionInfoMap.get(numParameters); + if (functionInfo != null) { + return functionInfo; + } + return functionInfoMap.get(VAR_ARG_KEY); + } + return null; } private static String canonicalize(String functionName) { diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java index d8529e9842ea..32f115b51a70 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java @@ -23,6 +23,7 @@ import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.objects.ObjectLinkedOpenHashSet; import it.unimi.dsi.fastutil.objects.ObjectSet; +import java.math.BigDecimal; import java.util.Arrays; import org.apache.commons.lang3.ArrayUtils; import org.apache.pinot.spi.annotations.ScalarFunction; @@ -226,4 +227,62 @@ public static double arrayElementAtDouble(double[] arr, int idx) { public static String arrayElementAtString(String[] arr, int idx) { return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING; } + + @ScalarFunction(names = {"array", "arrayValueConstructor"}, isVarArg = true) + public static Object arrayValueConstructor(Object... arr) { + if (arr.length == 0) { + return arr; + } + Class clazz = arr[0].getClass(); + if (clazz == Integer.class) { + int[] intArr = new int[arr.length]; + for (int i = 0; i < arr.length; i++) { + intArr[i] = (Integer) arr[i]; + } + return intArr; + } + if (clazz == Long.class) { + long[] longArr = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + longArr[i] = (Long) arr[i]; + } + return longArr; + } + if (clazz == Float.class) { + float[] floatArr = new float[arr.length]; + for (int i = 0; i < arr.length; i++) { + floatArr[i] = (Float) arr[i]; + } + return floatArr; + } + if (clazz == Double.class) { + double[] doubleArr = new double[arr.length]; + for (int i = 0; i < arr.length; i++) { + doubleArr[i] = (Double) arr[i]; + } + return doubleArr; + } + if (clazz == Boolean.class) { + boolean[] boolArr = new boolean[arr.length]; + for (int i = 0; i < arr.length; i++) { + boolArr[i] = (Boolean) arr[i]; + } + return boolArr; + } + if (clazz == BigDecimal.class) { + BigDecimal[] bigDecimalArr = new BigDecimal[arr.length]; + for (int i = 0; i < arr.length; i++) { + bigDecimalArr[i] = (BigDecimal) arr[i]; + } + return bigDecimalArr; + } + if (clazz == String.class) { + String[] strArr = new String[arr.length]; + for (int i = 0; i < arr.length; i++) { + strArr[i] = (String) arr[i]; + } + return strArr; + } + return arr; + } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java index 802ec48a172e..5e80a4f6deb3 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java @@ -102,6 +102,26 @@ public LiteralContext(Literal literal) { _type = DataType.BYTES; _value = literal.getBinaryValue(); break; + case INT_ARRAY_VALUE: + _type = DataType.INT; + _value = literal.getIntArrayValue(); + break; + case LONG_ARRAY_VALUE: + _type = DataType.LONG; + _value = literal.getLongArrayValue(); + break; + case FLOAT_ARRAY_VALUE: + _type = DataType.FLOAT; + _value = literal.getFloatArrayValue(); + break; + case DOUBLE_ARRAY_VALUE: + _type = DataType.DOUBLE; + _value = literal.getDoubleArrayValue(); + break; + case STRING_ARRAY_VALUE: + _type = DataType.STRING; + _value = literal.getStringArrayValue(); + break; case NULL_VALUE: _type = DataType.UNKNOWN; _value = null; diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java index 2cb89d5fe8b7..92bc6bb1b019 100644 --- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java +++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java @@ -84,8 +84,13 @@ protected static Expression invokeCompileTimeFunctionExpression(@Nullable Expres } try { FunctionInvoker invoker = new FunctionInvoker(functionInfo); - invoker.convertTypes(arguments); - Object result = invoker.invoke(arguments); + Object result; + if (invoker.getMethod().isVarArgs()) { + result = invoker.invoke(new Object[] {arguments}); + } else { + invoker.convertTypes(arguments); + result = invoker.invoke(arguments); + } return RequestUtils.getLiteralExpression(result); } catch (Exception e) { throw new SqlCompilationException( diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java index 7619b53a3ea9..b2065e20d3a4 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.LiteralContext; import org.apache.pinot.core.operator.ColumnContext; import org.apache.pinot.core.operator.blocks.ValueBlock; import org.apache.pinot.core.operator.transform.TransformResultMetadata; @@ -54,6 +55,77 @@ public class ArrayLiteralTransformFunction implements TransformFunction { private double[][] _doubleArrayResult; private String[][] _stringArrayResult; + public ArrayLiteralTransformFunction(LiteralContext literalContext) { + List literalArray = (List) literalContext.getValue(); + Preconditions.checkNotNull(literalArray); + if (literalArray.isEmpty()) { + _dataType = DataType.UNKNOWN; + _intArrayLiteral = new int[0]; + _longArrayLiteral = new long[0]; + _floatArrayLiteral = new float[0]; + _doubleArrayLiteral = new double[0]; + _stringArrayLiteral = new String[0]; + return; + } + _dataType = literalContext.getType(); + switch (_dataType) { + case INT: + _intArrayLiteral = new int[literalArray.size()]; + for (int i = 0; i < _intArrayLiteral.length; i++) { + _intArrayLiteral[i] = (int) literalArray.get(i); + } + _longArrayLiteral = null; + _floatArrayLiteral = null; + _doubleArrayLiteral = null; + _stringArrayLiteral = null; + break; + case LONG: + _longArrayLiteral = new long[literalArray.size()]; + for (int i = 0; i < _longArrayLiteral.length; i++) { + _longArrayLiteral[i] = (long) literalArray.get(i); + } + _intArrayLiteral = null; + _floatArrayLiteral = null; + _doubleArrayLiteral = null; + _stringArrayLiteral = null; + break; + case FLOAT: + _floatArrayLiteral = new float[literalArray.size()]; + for (int i = 0; i < _floatArrayLiteral.length; i++) { + _floatArrayLiteral[i] = (float) literalArray.get(i); + } + _intArrayLiteral = null; + _longArrayLiteral = null; + _doubleArrayLiteral = null; + _stringArrayLiteral = null; + break; + case DOUBLE: + _doubleArrayLiteral = new double[literalArray.size()]; + for (int i = 0; i < _doubleArrayLiteral.length; i++) { + _doubleArrayLiteral[i] = (double) literalArray.get(i); + } + _intArrayLiteral = null; + _longArrayLiteral = null; + _floatArrayLiteral = null; + _stringArrayLiteral = null; + break; + case STRING: + _stringArrayLiteral = new String[literalArray.size()]; + for (int i = 0; i < _stringArrayLiteral.length; i++) { + _stringArrayLiteral[i] = (String) literalArray.get(i); + } + _intArrayLiteral = null; + _longArrayLiteral = null; + _floatArrayLiteral = null; + _doubleArrayLiteral = null; + break; + default: + throw new IllegalStateException( + "Illegal data type for ArrayLiteralTransformFunction: " + _dataType + ", literal contexts: " + + Arrays.toString(literalArray.toArray())); + } + } + public ArrayLiteralTransformFunction(List literalContexts) { Preconditions.checkNotNull(literalContexts); if (literalContexts.isEmpty()) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java index 49541841ca92..82afb6dbeb22 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java @@ -31,6 +31,7 @@ import org.apache.pinot.common.function.TransformFunctionType; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.common.request.context.FunctionContext; +import org.apache.pinot.common.request.context.LiteralContext; import org.apache.pinot.common.utils.HashUtil; import org.apache.pinot.core.geospatial.transform.function.GeoToH3Function; import org.apache.pinot.core.geospatial.transform.function.StAreaFunction; @@ -335,7 +336,12 @@ public static TransformFunction get(ExpressionContext expression, Map arguments) { if (numArguments == 2) { ExpressionContext arrayExpression = arguments.get(1); Preconditions.checkArgument( - (arrayExpression.getType() == ExpressionContext.Type.FUNCTION) && (arrayExpression.getFunction() - .getFunctionName().equals(ARRAY_CONSTRUCTOR)), + // ARRAY function + ((arrayExpression.getType() == ExpressionContext.Type.FUNCTION) + && (arrayExpression.getFunction().getFunctionName().equals(ARRAY_CONSTRUCTOR))) + || ((arrayExpression.getType() == ExpressionContext.Type.LITERAL) + && (arrayExpression.getLiteral().getValue() instanceof List)), "Please use the format of `Histogram(columnName, ARRAY[1,10,100])` to specify the bin edges"); - _bucketEdges = parseVector(arrayExpression.getFunction().getArguments()); + if (arrayExpression.getType() == ExpressionContext.Type.FUNCTION) { + _bucketEdges = parseVector(arrayExpression.getFunction().getArguments()); + } else { + _bucketEdges = parseVectorLiteral((List) arrayExpression.getLiteral().getValue()); + } _lower = _bucketEdges[0]; _upper = _bucketEdges[_bucketEdges.length - 1]; } else { _isEqualLength = true; _lower = arguments.get(1).getLiteral().getDoubleValue(); _upper = arguments.get(2).getLiteral().getDoubleValue(); - int numBins = arguments.get(3).getLiteral().getIntValue();; + int numBins = arguments.get(3).getLiteral().getIntValue(); + ; Preconditions.checkArgument(_upper > _lower, "The right most edge must be greater than left most edge, given %s and %s", _lower, _upper); Preconditions.checkArgument(numBins > 0, "The number of bins must be greater than zero, given %s", numBins); @@ -109,8 +117,23 @@ private double[] parseVector(List arrayStr) { return ret; } + private double[] parseVectorLiteral(List arrayStr) { + int len = arrayStr.size(); + Preconditions.checkArgument(len > 1, "The number of bin edges must be greater than 1"); + double[] ret = new double[len]; + for (int i = 0; i < len; i++) { + // TODO: Represent infinity as literal instead of identifier + ret[i] = Double.parseDouble(arrayStr.get(i).toString()); + if (i > 0) { + Preconditions.checkState(ret[i] > ret[i - 1], "The bin edges must be strictly increasing"); + } + } + return ret; + } + /** * Find the bin id for the input value. Use division for equal-length bins, and binary search otherwise. + * * @param val input value * @return bin id */ @@ -135,7 +158,7 @@ private int getBinId(double val) { i = mid; } } - id = i; + id = i; } return id; } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java index 82be9bcf52c3..07ba2e00c72d 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java @@ -131,7 +131,7 @@ public void testStateSharedBetweenRowsForExecution() throws Exception { MyFunc myFunc = new MyFunc(); Method method = myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class); - FunctionRegistry.registerFunction(method, false, false); + FunctionRegistry.registerFunction(method, false, false, false); String expression = "appendToStringAndReturn('test ')"; InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); assertTrue(evaluator.getArguments().isEmpty()); diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java index 9fe29ec5d09a..c400546c4f3e 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/HistogramQueriesTest.java @@ -312,7 +312,7 @@ public void testInvalidInput() { operator.nextBlock(); } catch (Exception e) { assertEquals(e.getMessage(), - "Invalid aggregation function: histogram(intColumn,arrayvalueconstructor('0')); Reason: The number of " + "Invalid aggregation function: histogram(intColumn,'[0]'); Reason: The number of " + "bin edges must be greater than 1"); } @@ -333,7 +333,7 @@ public void testInvalidInput() { operator.nextBlock(); } catch (Exception e) { assertEquals(e.getMessage(), - "Invalid aggregation function: histogram(intColumn,arrayvalueconstructor('0','0','1','2')); Reason: The " + "Invalid aggregation function: histogram(intColumn,'[0, 0, 1, 2]'); Reason: The " + "bin edges must be strictly increasing"); } diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java index 9275d3ce9073..19bf45b373f3 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java @@ -217,6 +217,22 @@ public void testIntArrayLiteral(boolean useMultiStageQueryEngine) Assert.assertEquals(row.get(0).get(2).asInt(), 3); } + @Test(dataProvider = "useBothQueryEngines") + public void testIntArrayLiteralWithoutFrom(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = "SELECT ARRAY[1,2,3] "; + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + Assert.assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + Assert.assertEquals(row.size(), 1); + Assert.assertEquals(row.get(0).size(), 3); + Assert.assertEquals(row.get(0).get(0).asInt(), 1); + Assert.assertEquals(row.get(0).get(1).asInt(), 2); + Assert.assertEquals(row.get(0).get(2).asInt(), 3); + } + @Test(dataProvider = "useBothQueryEngines") public void testLongArrayLiteral(boolean useMultiStageQueryEngine) throws Exception { @@ -236,6 +252,22 @@ public void testLongArrayLiteral(boolean useMultiStageQueryEngine) Assert.assertEquals(row.get(0).get(2).asLong(), 2147483650L); } + @Test(dataProvider = "useBothQueryEngines") + public void testLongArrayLiteralWithoutFrom(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = "SELECT ARRAY[2147483648,2147483649,2147483650]"; + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + Assert.assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + Assert.assertEquals(row.size(), 1); + Assert.assertEquals(row.get(0).size(), 3); + Assert.assertEquals(row.get(0).get(0).asLong(), 2147483648L); + Assert.assertEquals(row.get(0).get(1).asLong(), 2147483649L); + Assert.assertEquals(row.get(0).get(2).asLong(), 2147483650L); + } + @Test(dataProvider = "useBothQueryEngines") public void testFloatArrayLiteral(boolean useMultiStageQueryEngine) throws Exception { @@ -255,6 +287,22 @@ public void testFloatArrayLiteral(boolean useMultiStageQueryEngine) Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3); } + @Test(dataProvider = "useBothQueryEngines") + public void testFloatArrayLiteralWithoutFrom(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = "SELECT ARRAY[0.1, 0.2, 0.3]"; + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + Assert.assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + Assert.assertEquals(row.size(), 1); + Assert.assertEquals(row.get(0).size(), 3); + Assert.assertEquals(row.get(0).get(0).asDouble(), 0.1); + Assert.assertEquals(row.get(0).get(1).asDouble(), 0.2); + Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3); + } + @Test(dataProvider = "useBothQueryEngines") public void testDoubleArrayLiteral(boolean useMultiStageQueryEngine) throws Exception { @@ -274,6 +322,22 @@ public void testDoubleArrayLiteral(boolean useMultiStageQueryEngine) Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3); } + @Test(dataProvider = "useBothQueryEngines") + public void testDoubleArrayLiteralWithoutFrom(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = "SELECT ARRAY[CAST(0.1 AS DOUBLE), CAST(0.2 AS DOUBLE), CAST(0.3 AS DOUBLE)]"; + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + Assert.assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + Assert.assertEquals(row.size(), 1); + Assert.assertEquals(row.get(0).size(), 3); + Assert.assertEquals(row.get(0).get(0).asDouble(), 0.1); + Assert.assertEquals(row.get(0).get(1).asDouble(), 0.2); + Assert.assertEquals(row.get(0).get(2).asDouble(), 0.3); + } + @Test(dataProvider = "useBothQueryEngines") public void testStringArrayLiteral(boolean useMultiStageQueryEngine) throws Exception { @@ -293,6 +357,22 @@ public void testStringArrayLiteral(boolean useMultiStageQueryEngine) Assert.assertEquals(row.get(0).get(2).asText(), "ccc"); } + @Test(dataProvider = "useBothQueryEngines") + public void testStringArrayLiteralWithoutFrom(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = "SELECT ARRAY['a', 'bb', 'ccc']"; + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + Assert.assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + Assert.assertEquals(row.size(), 1); + Assert.assertEquals(row.get(0).size(), 3); + Assert.assertEquals(row.get(0).get(0).asText(), "a"); + Assert.assertEquals(row.get(0).get(1).asText(), "bb"); + Assert.assertEquals(row.get(0).get(2).asText(), "ccc"); + } + @Override public String getTableName() { return DEFAULT_TABLE_NAME; @@ -352,7 +432,7 @@ public File createAvroFile() fileWriter.append(recordCache.get((int) (i % (getCountStarResult() / 10)), () -> { // create avro record GenericData.Record record = new GenericData.Record(avroSchema); - record.put(BOOLEAN_COLUMN, RANDOM.nextBoolean()); + record.put(BOOLEAN_COLUMN, finalI % 4 == 0 || finalI % 4 == 1); record.put(INT_COLUMN, finalI); record.put(LONG_COLUMN, finalI); record.put(FLOAT_COLUMN, finalI + RANDOM.nextFloat()); diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java index 02d44ade9be1..ea0d531faa6f 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java +++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotEvaluateLiteralRule.java @@ -159,8 +159,12 @@ private static RexNode evaluateLiteralOnlyFunction(RexCall rexCall, RexBuilder r Object resultValue; try { FunctionInvoker invoker = new FunctionInvoker(functionInfo); - invoker.convertTypes(arguments); - resultValue = invoker.invoke(arguments); + if (functionInfo.getMethod().isVarArgs()) { + resultValue = invoker.invoke(new Object[] {arguments}); + } else { + invoker.convertTypes(arguments); + resultValue = invoker.invoke(arguments); + } if (rexNodeType.getSqlTypeName() == SqlTypeName.ARRAY) { RelDataType componentType = rexNodeType.getComponentType(); if (componentType != null) { diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java index 9e777d786780..cccc065be36d 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java @@ -50,11 +50,13 @@ public FunctionOperand(RexExpression.FunctionCall functionCall, String canonical FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(canonicalName, numOperands); Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", canonicalName); _functionInvoker = new FunctionInvoker(functionInfo); - Class[] parameterClasses = _functionInvoker.getParameterClasses(); - PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes(); - for (int i = 0; i < numOperands; i++) { - Preconditions.checkState(parameterTypes[i] != null, "Unsupported parameter class: %s for method: %s", - parameterClasses[i], functionInfo.getMethod()); + if (!_functionInvoker.getMethod().isVarArgs()) { + Class[] parameterClasses = _functionInvoker.getParameterClasses(); + PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes(); + for (int i = 0; i < numOperands; i++) { + Preconditions.checkState(parameterTypes[i] != null, "Unsupported parameter class: %s for method: %s", + parameterClasses[i], functionInfo.getMethod()); + } } ColumnDataType functionInvokerResultType = FunctionUtils.getColumnDataType(_functionInvoker.getResultClass()); // Handle unrecognized result class with STRING @@ -80,8 +82,13 @@ public Object apply(Object[] row) { _reusableOperandHolder[i] = value != null ? operand.getResultType().toExternal(value) : null; } // TODO: Optimize per record conversion - _functionInvoker.convertTypes(_reusableOperandHolder); - Object result = _functionInvoker.invoke(_reusableOperandHolder); + Object result; + if (_functionInvoker.getMethod().isVarArgs()) { + result = _functionInvoker.invoke(new Object[]{_reusableOperandHolder}); + } else { + _functionInvoker.convertTypes(_reusableOperandHolder); + result = _functionInvoker.invoke(_reusableOperandHolder); + } return result != null ? TypeUtils.convert(_functionInvokerResultType.toInternal(result), _resultType.getStoredType()) : null; } diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java index 823dd23b885d..8c3909e78482 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java @@ -233,6 +233,9 @@ public Object execute(GenericRow row) { } } } + if (_functionInvoker.getMethod().isVarArgs()) { + return _functionInvoker.invoke(new Object[]{_arguments}); + } _functionInvoker.convertTypes(_arguments); return _functionInvoker.invoke(_arguments); } catch (Exception e) { @@ -256,6 +259,9 @@ public Object execute(Object[] values) { } } } + if (_functionInvoker.getMethod().isVarArgs()) { + return _functionInvoker.invoke(new Object[]{_arguments}); + } _functionInvoker.convertTypes(_arguments); return _functionInvoker.invoke(_arguments); } catch (Exception e) { diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java index 46a743d52c79..0a647a879212 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java @@ -59,4 +59,9 @@ boolean nullableParameters() default false; boolean isPlaceholder() default false; + + /** + * Whether the scalar function takes various number of arguments. + */ + boolean isVarArg() default false; }