From 5321aab03a09e7f04da11bcbefb8c57392cbd3dd Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Tue, 12 Dec 2023 11:45:25 -0800 Subject: [PATCH 1/9] initial commit to make FunctionRegistry use Calcite functions --- .../common/function/FunctionRegistry.java | 146 ++++++++++-------- .../InbuiltFunctionEvaluatorTest.java | 2 +- .../calcite/jdbc/CalciteSchemaBuilder.java | 3 +- 3 files changed, 84 insertions(+), 67 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index deb1673d8bac..615137c0a3dd 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -18,16 +18,20 @@ */ package org.apache.pinot.common.function; -import com.google.common.base.Preconditions; +import com.google.common.annotations.VisibleForTesting; import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.util.HashMap; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import javax.annotation.Nullable; -import org.apache.calcite.schema.Function; -import org.apache.calcite.schema.impl.ScalarFunctionImpl; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.schema.impl.ReflectiveFunctionBase; import org.apache.calcite.util.NameMultimap; import org.apache.commons.lang3.StringUtils; import org.apache.pinot.spi.annotations.ScalarFunction; @@ -46,11 +50,7 @@ private FunctionRegistry() { private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class); - // TODO: consolidate the following 2 - // This FUNCTION_INFO_MAP is used by Pinot server to look up function by # of arguments - private static final Map> FUNCTION_INFO_MAP = new HashMap<>(); - // This FUNCTION_MAP is used by Calcite function catalog to look up function by function signature. - private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); + private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); private static final int VAR_ARG_KEY = -1; @@ -68,22 +68,15 @@ private FunctionRegistry() { } ScalarFunction scalarFunction = method.getAnnotation(ScalarFunction.class); if (scalarFunction.enabled()) { - // Annotated function names - String[] scalarFunctionNames = scalarFunction.names(); + // Parse annotated function names and alias + Set scalarFunctionNames = Arrays.stream(scalarFunction.names()).collect(Collectors.toSet()); + scalarFunctionNames.add(method.getName()); boolean nullableParameters = scalarFunction.nullableParameters(); - boolean isPlaceholder = scalarFunction.isPlaceholder(); - boolean isVarArg = scalarFunction.isVarArg(); - if (scalarFunctionNames.length > 0) { - for (String name : scalarFunctionNames) { - FunctionRegistry.registerFunction(name, method, nullableParameters, isPlaceholder, isVarArg); - } - } else { - FunctionRegistry.registerFunction(method, nullableParameters, isPlaceholder, isVarArg); - } + FunctionRegistry.registerFunction(method, scalarFunctionNames, nullableParameters); } } - LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_INFO_MAP.size(), - FUNCTION_INFO_MAP.keySet(), System.currentTimeMillis() - startTimeMs); + LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_MAP.map().size(), + FUNCTION_MAP.map().keySet(), System.currentTimeMillis() - startTimeMs); } /** @@ -97,46 +90,21 @@ public static void init() { /** * Registers a method with the name of the method. */ - public static void registerFunction(Method method, boolean nullableParameters, boolean isPlaceholder, - boolean isVarArg) { - registerFunction(method.getName(), method, nullableParameters, isPlaceholder, isVarArg); + @VisibleForTesting + public static void registerFunction(Method method, boolean nullableParameters) { + registerFunction(method, Collections.singleton(method.getName()), nullableParameters); } - /** - * Registers a method with the given function name. - */ - public static void registerFunction(String functionName, Method method, boolean nullableParameters, - boolean isPlaceholder, boolean isVarArg) { - if (!isPlaceholder) { - registerFunctionInfoMap(functionName, method, nullableParameters, isVarArg); - } - registerCalciteNamedFunctionMap(functionName, method, nullableParameters, isVarArg); - } - - private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters, - boolean isVarArg) { - FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters); - String canonicalName = canonicalize(functionName); - Map functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>()); - if (isVarArg) { - FunctionInfo existFunctionInfo = functionInfoMap.put(VAR_ARG_KEY, functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with variable number of parameters is already registered", functionName); - } else { - FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with %s parameters is already registered", functionName, method.getParameterCount()); - } - } - - private static void registerCalciteNamedFunctionMap(String functionName, Method method, boolean nullableParameters, - boolean isVarArg) { + private static void registerFunction(Method method, Set alias, boolean nullableParameters) { if (method.getAnnotation(Deprecated.class) == null) { - FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method)); +// String name = canonicalize(method.getName()); + for (String name : alias) { + FUNCTION_MAP.put(name, new PinotScalarFunction(name, alias, method, nullableParameters)); + } } } - public static Map> getRegisteredCalciteFunctionMap() { + public static Map> getRegisteredCalciteFunctionMap() { return FUNCTION_MAP.map(); } @@ -148,7 +116,7 @@ public static Set getRegisteredCalciteFunctionNames() { * Returns {@code true} if the given function name is registered, {@code false} otherwise. */ public static boolean containsFunction(String functionName) { - return FUNCTION_INFO_MAP.containsKey(canonicalize(functionName)); + return FUNCTION_MAP.containsKey(canonicalize(functionName), false); } /** @@ -158,15 +126,21 @@ public static boolean containsFunction(String functionName) { */ @Nullable public static FunctionInfo getFunctionInfo(String functionName, int numParameters) { - Map functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName)); - if (functionInfoMap != null) { - FunctionInfo functionInfo = functionInfoMap.get(numParameters); - if (functionInfo != null) { - return functionInfo; - } - return functionInfoMap.get(VAR_ARG_KEY); + List candidates = findByNumParameters(FUNCTION_MAP.range(functionName, false), numParameters); + if (candidates.size() <= 1) { + return candidates.size() == 1 ? candidates.get(0).getFunctionInfo() : null; + } else { + throw new IllegalArgumentException( + "Unable to lookup function: " + functionName + " by parameter count: " + numParameters + + " Found multiple candidates. Try to use argument types to resolve the correct one!"); } - return null; + } + + private static List findByNumParameters( + Collection> scalarFunctionList, int numParameters) { + return scalarFunctionList == null ? Collections.emptyList() + : scalarFunctionList.stream().filter(e -> e.getValue().getParameters().size() == numParameters) + .map(Map.Entry::getValue).collect(Collectors.toList()); } private static String canonicalize(String functionName) { @@ -199,4 +173,46 @@ public static double vectorSimilarity(float[] vector1, float[] vector2) { throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); } } + + /** + * Pinot specific implementation of the {@link org.apache.calcite.schema.ScalarFunction}. + * + * @see "{@link org.apache.calcite.schema.impl.ScalarFunctionImpl}" + */ + public static class PinotScalarFunction extends ReflectiveFunctionBase + implements org.apache.calcite.schema.ScalarFunction { + private final FunctionInfo _functionInfo; + private final String _name; + private final Set _alias; + private final Method _method; + + public PinotScalarFunction(String name, Set alias, Method method, boolean isNullableParameter) { + super(method); + _name = name; + _alias = alias; + _method = method; + _functionInfo = new FunctionInfo(method, method.getClass(), isNullableParameter); + } + + @Override + public RelDataType getReturnType(RelDataTypeFactory typeFactory) { + return typeFactory.createJavaType(method.getReturnType()); + } + + public String getName() { + return _name; + } + + public Set getAlias() { + return _alias; + } + + public Method getMethod() { + return _method; + } + + public FunctionInfo getFunctionInfo() { + return _functionInfo; + } + } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java index 07ba2e00c72d..d455faf24515 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java @@ -131,7 +131,7 @@ public void testStateSharedBetweenRowsForExecution() throws Exception { MyFunc myFunc = new MyFunc(); Method method = myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class); - FunctionRegistry.registerFunction(method, false, false, false); + FunctionRegistry.registerFunction(method, false); String expression = "appendToStringAndReturn('test ')"; InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); assertTrue(evaluator.getArguments().isEmpty()); diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java b/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java index edb2d74bf07c..efe8b56a07c0 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java +++ b/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java @@ -54,7 +54,8 @@ private CalciteSchemaBuilder() { public static CalciteSchema asRootSchema(Schema root) { CalciteSchema rootSchema = CalciteSchema.createRootSchema(false, false, "", root); SchemaPlus schemaPlus = rootSchema.plus(); - for (Map.Entry> e : FunctionRegistry.getRegisteredCalciteFunctionMap().entrySet()) { + for (Map.Entry> e + : FunctionRegistry.getRegisteredCalciteFunctionMap().entrySet()) { for (Function f : e.getValue()) { schemaPlus.add(e.getKey(), f); } From 6e0d8ab42b4be8b757eaecb879c89a0c8d4553c2 Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Thu, 14 Dec 2023 06:55:10 -0800 Subject: [PATCH 2/9] adding back fallback option to current resolution --- .../common/function/FunctionRegistry.java | 62 ++++++++++++++----- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index 615137c0a3dd..6892ea2946a0 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -19,11 +19,13 @@ package org.apache.pinot.common.function; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -42,15 +44,16 @@ /** * Registry for scalar functions. - *

TODO: Merge FunctionRegistry and FunctionDefinitionRegistry to provide one single registry for all functions. */ public class FunctionRegistry { + private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class); + // TODO: remove both when FunctionOperatorTable is in used. + private static final Map> FUNCTION_INFO_MAP = new HashMap<>(); + private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); + private FunctionRegistry() { } - private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class); - - private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); private static final int VAR_ARG_KEY = -1; @@ -70,7 +73,9 @@ private FunctionRegistry() { if (scalarFunction.enabled()) { // Parse annotated function names and alias Set scalarFunctionNames = Arrays.stream(scalarFunction.names()).collect(Collectors.toSet()); - scalarFunctionNames.add(method.getName()); + if (scalarFunctionNames.size() == 0) { + scalarFunctionNames.add(method.getName()); + } boolean nullableParameters = scalarFunction.nullableParameters(); FunctionRegistry.registerFunction(method, scalarFunctionNames, nullableParameters); } @@ -97,13 +102,26 @@ public static void registerFunction(Method method, boolean nullableParameters) { private static void registerFunction(Method method, Set alias, boolean nullableParameters) { if (method.getAnnotation(Deprecated.class) == null) { -// String name = canonicalize(method.getName()); for (String name : alias) { - FUNCTION_MAP.put(name, new PinotScalarFunction(name, alias, method, nullableParameters)); + registerFunctionInfoMap(name, method, nullableParameters); + registerCalciteNamedFunctionMap(name, method, nullableParameters); } } } + private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters) { + FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters); + String canonicalName = canonicalize(functionName); + Map functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>()); + FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo); + Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), + "Function: %s with %s parameters is already registered", functionName, method.getParameterCount()); + } + + private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters) { + FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters)); + } + public static Map> getRegisteredCalciteFunctionMap() { return FUNCTION_MAP.map(); } @@ -116,7 +134,9 @@ public static Set getRegisteredCalciteFunctionNames() { * Returns {@code true} if the given function name is registered, {@code false} otherwise. */ public static boolean containsFunction(String functionName) { - return FUNCTION_MAP.containsKey(canonicalize(functionName), false); + // TODO: remove fallback to FUNCTION_INFO_MAP + return FUNCTION_MAP.containsKey(canonicalize(functionName), false) + || FUNCTION_INFO_MAP.containsKey(canonicalize(functionName)); } /** @@ -126,6 +146,22 @@ public static boolean containsFunction(String functionName) { */ @Nullable public static FunctionInfo getFunctionInfo(String functionName, int numParameters) { + // TODO: remove fallback to FUNCTION_INFO_MAP + try { + return getFunctionInfoFromCalciteNamedMap(functionName, numParameters); + } catch (IllegalArgumentException iae) { + return getFunctionInfoFromFunctionInfoMap(functionName, numParameters); + } + } + + @Nullable + private static FunctionInfo getFunctionInfoFromFunctionInfoMap(String functionName, int numParameters) { + Map functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName)); + return functionInfoMap != null ? functionInfoMap.get(numParameters) : null; + } + + @Nullable + private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionName, int numParameters) { List candidates = findByNumParameters(FUNCTION_MAP.range(functionName, false), numParameters); if (candidates.size() <= 1) { return candidates.size() == 1 ? candidates.get(0).getFunctionInfo() : null; @@ -183,15 +219,13 @@ public static class PinotScalarFunction extends ReflectiveFunctionBase implements org.apache.calcite.schema.ScalarFunction { private final FunctionInfo _functionInfo; private final String _name; - private final Set _alias; private final Method _method; - public PinotScalarFunction(String name, Set alias, Method method, boolean isNullableParameter) { + public PinotScalarFunction(String name, Method method, boolean isNullableParameter) { super(method); _name = name; - _alias = alias; _method = method; - _functionInfo = new FunctionInfo(method, method.getClass(), isNullableParameter); + _functionInfo = new FunctionInfo(method, method.getDeclaringClass(), isNullableParameter); } @Override @@ -203,10 +237,6 @@ public String getName() { return _name; } - public Set getAlias() { - return _alias; - } - public Method getMethod() { return _method; } From ee35d90f83c9b0ca5ef3acebd62be3c6f48bab44 Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Thu, 14 Dec 2023 08:59:53 -0800 Subject: [PATCH 3/9] fix tests --- .../common/function/scalar/StringFunctions.java | 2 +- .../function/InbuiltFunctionEvaluatorTest.java | 16 ---------------- .../PostAggregationFunctionTest.java | 2 +- 3 files changed, 2 insertions(+), 18 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java index 5a49314943ba..3af1405f6cdc 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java @@ -140,7 +140,7 @@ public static String substring(String input, int beginIndex, int length) { * @param seperator * @return The two input strings joined by the seperator */ - @ScalarFunction(names = "concat_ws") + @ScalarFunction(names = {"concatWS", "concat_ws"}) public static String concatws(String seperator, String input1, String input2) { return concat(input1, input2, seperator); } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java index d455faf24515..5c6835e293f6 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java @@ -30,7 +30,6 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; -import static org.testng.Assert.fail; public class InbuiltFunctionEvaluatorTest { @@ -158,21 +157,6 @@ public void testNullReturnedByInbuiltFunctionEvaluatorThatCannotTakeNull() { } } - @Test - public void testPlaceholderFunctionShouldNotBeRegistered() - throws Exception { - GenericRow row = new GenericRow(); - row.putValue("testColumn", "testValue"); - String expression = "text_match(testColumn, 'pattern')"; - try { - InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); - evaluator.evaluate(row); - fail(); - } catch (Throwable t) { - assertTrue(t.getMessage().contains("text_match")); - } - } - public static class MyFunc { String _baseString = ""; diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java index 0c7b0e3e52dc..a69f537687a3 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java @@ -56,7 +56,7 @@ public void testPostAggregationFunction() { assertEquals(function.invoke(new Object[]{"1234567890"}), "0987654321"); // ST_AsText - function = new PostAggregationFunction("ST_AsText", new ColumnDataType[]{ColumnDataType.BYTES}); + function = new PostAggregationFunction("ST_As_Text", new ColumnDataType[]{ColumnDataType.BYTES}); assertEquals(function.getResultType(), ColumnDataType.STRING); assertEquals(function.invoke( new Object[]{GeometrySerializer.serialize(GeometryUtils.GEOMETRY_FACTORY.createPoint(new Coordinate(10, 20)))}), From 5afec12c7d588b28a1547ea9f625c715edecd4ba Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Mon, 18 Dec 2023 07:49:56 -0800 Subject: [PATCH 4/9] create new function registry mechansim 1. FunctionRegistry keeps the old FUNCTION_INFO_MAP only 2. moved Calcite Catalog-based schema.Function registry to its own package; along with a SqlOperator based PinotOperatorTable 3. both CatalogReader and OperatorTable utilizes ground truth function from PinotFunctionRegistry --> will be default once deprecate FunctionRegistry 4. PinotFunctionRegistry provides argument-type based lookup via the same method SqlValidator utilize to lookup routine (and lookup operator overload) 5. clean up multi-stage engine side accordingly --- .../common/function/FunctionRegistry.java | 147 ++++------- .../function/registry/PinotFunction.java | 29 +++ .../registry/PinotScalarFunction.java | 83 ++++++ .../sql}/PinotCalciteCatalogReader.java | 6 +- .../function/sql/PinotFunctionRegistry.java | 237 ++++++++++++++++++ .../function/sql/PinotOperatorTable.java | 101 ++++++++ .../function}/sql/PinotSqlAggFunction.java | 6 +- .../sql/PinotSqlTransformFunction.java | 5 +- .../calcite/jdbc/CalciteSchemaBuilder.java | 6 +- .../PinotAggregateExchangeNodeInsertRule.java | 2 +- .../{fun => }/PinotSqlCoalesceFunction.java | 5 +- .../calcite/sql/fun/PinotOperatorTable.java | 171 ------------- .../sql/util/PinotSqlStdOperatorTable.java | 98 ++++++++ .../apache/pinot/query/QueryEnvironment.java | 6 +- 14 files changed, 620 insertions(+), 282 deletions(-) create mode 100644 pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotFunction.java create mode 100644 pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotScalarFunction.java rename {pinot-query-planner/src/main/java/org/apache/calcite/prepare => pinot-common/src/main/java/org/apache/pinot/common/function/sql}/PinotCalciteCatalogReader.java (98%) create mode 100644 pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotFunctionRegistry.java create mode 100644 pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotOperatorTable.java rename {pinot-query-planner/src/main/java/org/apache/calcite => pinot-common/src/main/java/org/apache/pinot/common/function}/sql/PinotSqlAggFunction.java (91%) rename {pinot-query-planner/src/main/java/org/apache/calcite => pinot-common/src/main/java/org/apache/pinot/common/function}/sql/PinotSqlTransformFunction.java (89%) rename pinot-query-planner/src/main/java/org/apache/calcite/sql/{fun => }/PinotSqlCoalesceFunction.java (90%) delete mode 100644 pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java create mode 100644 pinot-query-planner/src/main/java/org/apache/calcite/sql/util/PinotSqlStdOperatorTable.java diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index 6892ea2946a0..e5ee9899d463 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -23,7 +23,6 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -31,11 +30,10 @@ import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.schema.impl.ReflectiveFunctionBase; -import org.apache.calcite.util.NameMultimap; import org.apache.commons.lang3.StringUtils; +import org.apache.pinot.common.function.registry.PinotFunction; +import org.apache.pinot.common.function.registry.PinotScalarFunction; +import org.apache.pinot.common.function.sql.PinotFunctionRegistry; import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.utils.PinotReflectionUtils; import org.slf4j.Logger; @@ -46,17 +44,13 @@ * Registry for scalar functions. */ public class FunctionRegistry { + public static final boolean CASE_SENSITIVITY = false; private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class); - // TODO: remove both when FunctionOperatorTable is in used. private static final Map> FUNCTION_INFO_MAP = new HashMap<>(); - private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); private FunctionRegistry() { } - - private static final int VAR_ARG_KEY = -1; - /** * Registers the scalar functions via reflection. * NOTE: In order to plugin methods using reflection, the methods should be inside a class that includes ".function." @@ -80,8 +74,8 @@ private FunctionRegistry() { FunctionRegistry.registerFunction(method, scalarFunctionNames, nullableParameters); } } - LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_MAP.map().size(), - FUNCTION_MAP.map().keySet(), System.currentTimeMillis() - startTimeMs); + LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_INFO_MAP.size(), + FUNCTION_INFO_MAP.keySet(), System.currentTimeMillis() - startTimeMs); } /** @@ -92,50 +86,29 @@ private FunctionRegistry() { public static void init() { } - /** - * Registers a method with the name of the method. - */ @VisibleForTesting public static void registerFunction(Method method, boolean nullableParameters) { registerFunction(method, Collections.singleton(method.getName()), nullableParameters); } - private static void registerFunction(Method method, Set alias, boolean nullableParameters) { - if (method.getAnnotation(Deprecated.class) == null) { - for (String name : alias) { - registerFunctionInfoMap(name, method, nullableParameters); - registerCalciteNamedFunctionMap(name, method, nullableParameters); - } - } - } - - private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters) { - FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters); - String canonicalName = canonicalize(functionName); - Map functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>()); - FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with %s parameters is already registered", functionName, method.getParameterCount()); - } - - private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters) { - FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters)); - } - - public static Map> getRegisteredCalciteFunctionMap() { - return FUNCTION_MAP.map(); + @VisibleForTesting + public static Set getRegisteredCalciteFunctionNames() { + return PinotFunctionRegistry.getFunctionMap().map().keySet(); } - public static Set getRegisteredCalciteFunctionNames() { - return FUNCTION_MAP.map().keySet(); + /** + * Returns the full list of all registered ScalarFunction to Calcite. + */ + public static Map> getRegisteredCalciteFunctionMap() { + return PinotFunctionRegistry.getFunctionMap().map(); } /** * Returns {@code true} if the given function name is registered, {@code false} otherwise. */ public static boolean containsFunction(String functionName) { - // TODO: remove fallback to FUNCTION_INFO_MAP - return FUNCTION_MAP.containsKey(canonicalize(functionName), false) + // TODO: remove deprecated FUNCTION_INFO_MAP + return PinotFunctionRegistry.getFunctionMap().containsKey(functionName, CASE_SENSITIVITY) || FUNCTION_INFO_MAP.containsKey(canonicalize(functionName)); } @@ -145,40 +118,54 @@ public static boolean containsFunction(String functionName) { * methods are already registered. */ @Nullable - public static FunctionInfo getFunctionInfo(String functionName, int numParameters) { - // TODO: remove fallback to FUNCTION_INFO_MAP + public static FunctionInfo getFunctionInfo(String functionName, int numParams) { try { - return getFunctionInfoFromCalciteNamedMap(functionName, numParameters); + return getFunctionInfoFromCalciteNamedMap(functionName, numParams); } catch (IllegalArgumentException iae) { - return getFunctionInfoFromFunctionInfoMap(functionName, numParameters); + // TODO: remove deprecated FUNCTION_INFO_MAP + return getFunctionInfoFromFunctionInfoMap(functionName, numParams); + } + } + + // TODO: remove deprecated FUNCTION_INFO_MAP + private static void registerFunction(Method method, Set alias, boolean nullableParameters) { + if (method.getAnnotation(Deprecated.class) == null) { + for (String name : alias) { + registerFunctionInfoMap(name, method, nullableParameters); + } } } + private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters) { + FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters); + String canonicalName = canonicalize(functionName); + Map functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>()); + FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo); + Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), + "Function: %s with %s parameters is already registered", functionName, method.getParameterCount()); + } + @Nullable - private static FunctionInfo getFunctionInfoFromFunctionInfoMap(String functionName, int numParameters) { + private static FunctionInfo getFunctionInfoFromFunctionInfoMap(String functionName, int numParams) { Map functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName)); - return functionInfoMap != null ? functionInfoMap.get(numParameters) : null; + return functionInfoMap != null ? functionInfoMap.get(numParams) : null; } @Nullable - private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionName, int numParameters) { - List candidates = findByNumParameters(FUNCTION_MAP.range(functionName, false), numParameters); - if (candidates.size() <= 1) { - return candidates.size() == 1 ? candidates.get(0).getFunctionInfo() : null; + private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionName, int numParams) { + List candidates = PinotFunctionRegistry.getFunctionMap() + .range(functionName, CASE_SENSITIVITY).stream() + .filter(e -> e.getValue() instanceof PinotScalarFunction && e.getValue().getParameters().size() == numParams) + .map(e -> (PinotScalarFunction) e.getValue()).collect(Collectors.toList()); + if (candidates.size() == 1) { + return candidates.get(0).getFunctionInfo(); } else { throw new IllegalArgumentException( - "Unable to lookup function: " + functionName + " by parameter count: " + numParameters - + " Found multiple candidates. Try to use argument types to resolve the correct one!"); + "Unable to lookup function: " + functionName + " by parameter count: " + numParams + " Found " + + candidates.size() + " candidates. Try to use argument types to resolve the correct one!"); } } - private static List findByNumParameters( - Collection> scalarFunctionList, int numParameters) { - return scalarFunctionList == null ? Collections.emptyList() - : scalarFunctionList.stream().filter(e -> e.getValue().getParameters().size() == numParameters) - .map(Map.Entry::getValue).collect(Collectors.toList()); - } - private static String canonicalize(String functionName) { return StringUtils.remove(functionName, '_').toLowerCase(); } @@ -209,40 +196,4 @@ public static double vectorSimilarity(float[] vector1, float[] vector2) { throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); } } - - /** - * Pinot specific implementation of the {@link org.apache.calcite.schema.ScalarFunction}. - * - * @see "{@link org.apache.calcite.schema.impl.ScalarFunctionImpl}" - */ - public static class PinotScalarFunction extends ReflectiveFunctionBase - implements org.apache.calcite.schema.ScalarFunction { - private final FunctionInfo _functionInfo; - private final String _name; - private final Method _method; - - public PinotScalarFunction(String name, Method method, boolean isNullableParameter) { - super(method); - _name = name; - _method = method; - _functionInfo = new FunctionInfo(method, method.getDeclaringClass(), isNullableParameter); - } - - @Override - public RelDataType getReturnType(RelDataTypeFactory typeFactory) { - return typeFactory.createJavaType(method.getReturnType()); - } - - public String getName() { - return _name; - } - - public Method getMethod() { - return _method; - } - - public FunctionInfo getFunctionInfo() { - return _functionInfo; - } - } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotFunction.java new file mode 100644 index 000000000000..f0e756513739 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotFunction.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.registry; + +import org.apache.calcite.schema.Function; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlReturnTypeInference; + + +public interface PinotFunction extends Function { + SqlOperandTypeChecker getOperandTypeChecker(); + SqlReturnTypeInference getReturnTypeInference(); +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotScalarFunction.java new file mode 100644 index 000000000000..c1708aab4203 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotScalarFunction.java @@ -0,0 +1,83 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.registry; + +import java.lang.reflect.Method; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.schema.ScalarFunction; +import org.apache.calcite.schema.impl.ReflectiveFunctionBase; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.pinot.common.function.FunctionInfo; + + +/** + * Pinot specific implementation of the {@link ScalarFunction}. + * + * @see "{@link org.apache.calcite.schema.impl.ScalarFunctionImpl}" + */ +public class PinotScalarFunction extends ReflectiveFunctionBase implements PinotFunction, ScalarFunction { + private final FunctionInfo _functionInfo; + private final String _name; + private final Method _method; + private final SqlOperandTypeChecker _sqlOperandTypeChecker; + private final SqlReturnTypeInference _sqlReturnTypeInference; + + public PinotScalarFunction(String name, Method method, boolean isNullableParameter) { + this(name, method, isNullableParameter, null, null); + } + + public PinotScalarFunction(String name, Method method, boolean isNullableParameter, + SqlOperandTypeChecker sqlOperandTypeChecker, SqlReturnTypeInference sqlReturnTypeInference) { + super(method); + _name = name; + _method = method; + _functionInfo = new FunctionInfo(method, method.getDeclaringClass(), isNullableParameter); + _sqlOperandTypeChecker = sqlOperandTypeChecker; + _sqlReturnTypeInference = sqlReturnTypeInference; + } + + @Override + public RelDataType getReturnType(RelDataTypeFactory typeFactory) { + return typeFactory.createJavaType(method.getReturnType()); + } + + public String getName() { + return _name; + } + + public Method getMethod() { + return _method; + } + + public FunctionInfo getFunctionInfo() { + return _functionInfo; + } + + @Override + public SqlOperandTypeChecker getOperandTypeChecker() { + return _sqlOperandTypeChecker; + } + + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return _sqlReturnTypeInference; + } +} diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/prepare/PinotCalciteCatalogReader.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java similarity index 98% rename from pinot-query-planner/src/main/java/org/apache/calcite/prepare/PinotCalciteCatalogReader.java rename to pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java index 84c71be601f0..672689dff5c5 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/prepare/PinotCalciteCatalogReader.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.calcite.prepare; +package org.apache.pinot.common.function.sql; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; @@ -35,6 +35,8 @@ import org.apache.calcite.linq4j.function.Hints; import org.apache.calcite.model.ModelHandler; import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.prepare.RelOptTableImpl; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeFactoryImpl; @@ -310,7 +312,7 @@ public static SqlOperatorTable operatorTable(String... classNames) { } /** Converts a function to a {@link org.apache.calcite.sql.SqlOperator}. */ - private static SqlOperator toOp(SqlIdentifier name, + public static SqlOperator toOp(SqlIdentifier name, final org.apache.calcite.schema.Function function) { final Function> argTypesFactory = typeFactory -> function.getParameters() diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotFunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotFunctionRegistry.java new file mode 100644 index 000000000000..776a95927255 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotFunctionRegistry.java @@ -0,0 +1,237 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.sql; + +import com.google.common.annotations.VisibleForTesting; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.schema.Function; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.SqlUtil; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlNameMatchers; +import org.apache.calcite.sql.validate.SqlUserDefinedFunction; +import org.apache.calcite.util.NameMultimap; +import org.apache.pinot.common.function.TransformFunctionType; +import org.apache.pinot.common.function.registry.PinotFunction; +import org.apache.pinot.common.function.registry.PinotScalarFunction; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.segment.spi.AggregationFunctionType; +import org.apache.pinot.spi.annotations.ScalarFunction; +import org.apache.pinot.spi.utils.PinotReflectionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Registry for scalar functions. + */ +public class PinotFunctionRegistry { + private static final Logger LOGGER = LoggerFactory.getLogger(PinotFunctionRegistry.class); + private static final NameMultimap OPERATOR_MAP = new NameMultimap<>(); + private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); + + private PinotFunctionRegistry() { + } + + /** + * Registers the scalar functions via reflection. + * NOTE: In order to plugin methods using reflection, the methods should be inside a class that includes ".function." + * in its class path. This convention can significantly reduce the time of class scanning. + */ + static { + // REGISTER FUNCTIONS + long startTimeMs = System.currentTimeMillis(); + Set methods = PinotReflectionUtils.getMethodsThroughReflection(".*\\.function\\..*", ScalarFunction.class); + for (Method method : methods) { + if (!Modifier.isPublic(method.getModifiers())) { + continue; + } + ScalarFunction scalarFunction = method.getAnnotation(ScalarFunction.class); + if (scalarFunction.enabled()) { + // Parse annotated function names and alias + Set scalarFunctionNames = Arrays.stream(scalarFunction.names()).collect(Collectors.toSet()); + if (scalarFunctionNames.size() == 0) { + scalarFunctionNames.add(method.getName()); + } + boolean nullableParameters = scalarFunction.nullableParameters(); + PinotFunctionRegistry.registerFunction(method, scalarFunctionNames, nullableParameters); + } + } + LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_MAP.map().size(), + FUNCTION_MAP.map().keySet(), System.currentTimeMillis() - startTimeMs); + + // REGISTER OPERATORS + // Walk through all the Pinot aggregation types and + // 1. register those that are supported in multistage in addition to calcite standard opt table. + // 2. register special handling that differs from calcite standard. + for (AggregationFunctionType aggregationFunctionType : AggregationFunctionType.values()) { + if (aggregationFunctionType.getSqlKind() != null) { + // 1. Register the aggregation function with Calcite + registerAggregateFunction(aggregationFunctionType.getName(), aggregationFunctionType); + // 2. Register the aggregation function with Calcite on all alternative names + List alternativeFunctionNames = aggregationFunctionType.getAlternativeNames(); + for (String alternativeFunctionName : alternativeFunctionNames) { + registerAggregateFunction(alternativeFunctionName, aggregationFunctionType); + } + } + } + + // Walk through all the Pinot transform types and + // 1. register those that are supported in multistage in addition to calcite standard opt table. + // 2. register special handling that differs from calcite standard. + for (TransformFunctionType transformFunctionType : TransformFunctionType.values()) { + if (transformFunctionType.getSqlKind() != null) { + // 1. Register the transform function with Calcite + registerTransformFunction(transformFunctionType.getName(), transformFunctionType); + // 2. Register the transform function with Calcite on all alternative names + List alternativeFunctionNames = transformFunctionType.getAlternativeNames(); + for (String alternativeFunctionName : alternativeFunctionNames) { + registerTransformFunction(alternativeFunctionName, transformFunctionType); + } + } + } + } + + public static void init() { + } + + @VisibleForTesting + public static void registerFunction(Method method, boolean nullableParameters) { + registerFunction(method, Collections.singleton(method.getName()), nullableParameters); + } + + public static NameMultimap getFunctionMap() { + return FUNCTION_MAP; + } + + public static NameMultimap getOperatorMap() { + return OPERATOR_MAP; + } + + @Nullable + public static PinotScalarFunction getScalarFunction(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, + String functionName, List argTypes) { + List relArgTypes = convertArgumentTypes(typeFactory, argTypes); + SqlOperator sqlOperator = SqlUtil.lookupRoutine(operatorTable, typeFactory, + new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO), relArgTypes, null, null, SqlSyntax.FUNCTION, + SqlKind.OTHER_FUNCTION, SqlNameMatchers.withCaseSensitive(false), true); + if (sqlOperator instanceof SqlUserDefinedFunction) { + Function function = ((SqlUserDefinedFunction) sqlOperator).getFunction(); + if (function instanceof PinotScalarFunction) { + return (PinotScalarFunction) function; + } + } + return null; + } + + private static void registerFunction(Method method, Set alias, boolean nullableParameters) { + if (method.getAnnotation(Deprecated.class) == null) { + for (String name : alias) { + registerCalciteNamedFunctionMap(name, method, nullableParameters); + } + } + } + + private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters) { + FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters)); + } + + private static List convertArgumentTypes(RelDataTypeFactory typeFactory, + List argTypes) { + return argTypes.stream().map(type -> toRelType(typeFactory, type)).collect(Collectors.toList()); + } + + private static RelDataType toRelType(RelDataTypeFactory typeFactory, DataSchema.ColumnDataType dataType) { + switch (dataType) { + case INT: + return typeFactory.createSqlType(SqlTypeName.INTEGER); + case LONG: + return typeFactory.createSqlType(SqlTypeName.BIGINT); + case FLOAT: + return typeFactory.createSqlType(SqlTypeName.REAL); + case DOUBLE: + return typeFactory.createSqlType(SqlTypeName.DOUBLE); + case BIG_DECIMAL: + return typeFactory.createSqlType(SqlTypeName.DECIMAL); + case BOOLEAN: + return typeFactory.createSqlType(SqlTypeName.BOOLEAN); + case TIMESTAMP: + return typeFactory.createSqlType(SqlTypeName.TIMESTAMP); + case JSON: + case STRING: + return typeFactory.createSqlType(SqlTypeName.VARCHAR); + case BYTES: + return typeFactory.createSqlType(SqlTypeName.VARBINARY); + case INT_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.INTEGER), -1); + case LONG_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.BIGINT), -1); + case FLOAT_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.REAL), -1); + case DOUBLE_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.DOUBLE), -1); + case BOOLEAN_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.BOOLEAN), -1); + case TIMESTAMP_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.TIMESTAMP), -1); + case STRING_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.VARCHAR), -1); + case BYTES_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.VARBINARY), -1); + case UNKNOWN: + case OBJECT: + default: + return typeFactory.createSqlType(SqlTypeName.ANY); + } + } + + private static void registerAggregateFunction(String functionName, AggregationFunctionType functionType) { + if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { + PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, + functionType.getSqlKind(), functionType.getReturnTypeInference(), null, + functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); + OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlAggFunction); + } + } + + private static void registerTransformFunction(String functionName, TransformFunctionType functionType) { + if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { + PinotSqlTransformFunction sqlTransformFunction = + new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), + functionType.getSqlKind(), functionType.getReturnTypeInference(), null, + functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); + OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlTransformFunction); + } + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotOperatorTable.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotOperatorTable.java new file mode 100644 index 000000000000..ca4513a5ba73 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotOperatorTable.java @@ -0,0 +1,101 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.sql; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.validate.SqlNameMatcher; +import org.apache.pinot.common.function.FunctionRegistry; +import org.checkerframework.checker.nullness.qual.Nullable; + + +/** + * Temporary implementation of all dynamic arg/return type inference operators. + * TODO: merge this with {@link PinotCalciteCatalogReader} once we support + * 1. Return/Inference configuration in @ScalarFunction + * 2. Allow @ScalarFunction registry towards class (with multiple impl) + */ +public class PinotOperatorTable implements SqlOperatorTable { + private static final PinotOperatorTable INSTANCE = new PinotOperatorTable(); + + public static synchronized PinotOperatorTable instance() { + return INSTANCE; + } + + @Override public void lookupOperatorOverloads(SqlIdentifier opName, + @Nullable SqlFunctionCategory category, SqlSyntax syntax, + List operatorList, SqlNameMatcher nameMatcher) { + String simpleName = opName.getSimple(); + final Collection list = + lookUpOperators(simpleName); + if (list.isEmpty()) { + return; + } + for (SqlOperator op : list) { + if (op.getSyntax() == syntax) { + operatorList.add(op); + } else if (syntax == SqlSyntax.FUNCTION + && op instanceof SqlFunction) { + // this special case is needed for operators like CAST, + // which are treated as functions but have special syntax + operatorList.add(op); + } + } + + // REVIEW jvs 1-Jan-2005: why is this extra lookup required? + // Shouldn't it be covered by search above? + switch (syntax) { + case BINARY: + case PREFIX: + case POSTFIX: + for (SqlOperator extra + : lookUpOperators(simpleName)) { + // REVIEW: should only search operators added during this method? + if (extra != null && !operatorList.contains(extra)) { + operatorList.add(extra); + } + } + break; + default: + break; + } + } + + /** + * Look up operators based on case-sensitiveness. + */ + private Collection lookUpOperators(String name) { + return PinotFunctionRegistry.getOperatorMap().range(name, FunctionRegistry.CASE_SENSITIVITY).stream() + .map(Map.Entry::getValue).collect(Collectors.toSet()); + } + + @Override + public List getOperatorList() { + return PinotFunctionRegistry.getOperatorMap().map().values().stream().flatMap(List::stream) + .collect(Collectors.toList()); + } +} diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlAggFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlAggFunction.java similarity index 91% rename from pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlAggFunction.java rename to pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlAggFunction.java index 0d4146f4317d..bc8f55474e00 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlAggFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlAggFunction.java @@ -16,8 +16,12 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.calcite.sql; +package org.apache.pinot.common.function.sql; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandTypeInference; import org.apache.calcite.sql.type.SqlReturnTypeInference; diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlTransformFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlTransformFunction.java similarity index 89% rename from pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlTransformFunction.java rename to pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlTransformFunction.java index 827c9f37337b..c2e8ce8130c1 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlTransformFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlTransformFunction.java @@ -16,8 +16,11 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.calcite.sql; +package org.apache.pinot.common.function.sql; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandTypeInference; import org.apache.calcite.sql.type.SqlReturnTypeInference; diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java b/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java index efe8b56a07c0..adadcd6992b6 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java +++ b/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java @@ -23,7 +23,8 @@ import org.apache.calcite.schema.Function; import org.apache.calcite.schema.Schema; import org.apache.calcite.schema.SchemaPlus; -import org.apache.pinot.common.function.FunctionRegistry; +import org.apache.pinot.common.function.registry.PinotFunction; +import org.apache.pinot.common.function.sql.PinotFunctionRegistry; /** @@ -54,8 +55,7 @@ private CalciteSchemaBuilder() { public static CalciteSchema asRootSchema(Schema root) { CalciteSchema rootSchema = CalciteSchema.createRootSchema(false, false, "", root); SchemaPlus schemaPlus = rootSchema.plus(); - for (Map.Entry> e - : FunctionRegistry.getRegisteredCalciteFunctionMap().entrySet()) { + for (Map.Entry> e : PinotFunctionRegistry.getFunctionMap().map().entrySet()) { for (Function f : e.getValue()) { schemaPlus.add(e.getKey(), f); } diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java index df904123d201..68e331eb5d8b 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java +++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java @@ -40,7 +40,6 @@ import org.apache.calcite.rel.logical.PinotLogicalExchange; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.PinotSqlAggFunction; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.OperandTypes; @@ -53,6 +52,7 @@ import org.apache.calcite.util.mapping.Mapping; import org.apache.calcite.util.mapping.MappingType; import org.apache.calcite.util.mapping.Mappings; +import org.apache.pinot.common.function.sql.PinotSqlAggFunction; import org.apache.pinot.query.planner.plannode.AggregateNode.AggType; import org.apache.pinot.segment.spi.AggregationFunctionType; diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotSqlCoalesceFunction.java b/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlCoalesceFunction.java similarity index 90% rename from pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotSqlCoalesceFunction.java rename to pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlCoalesceFunction.java index 92ef85857f9a..2b2ee32f083f 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotSqlCoalesceFunction.java +++ b/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlCoalesceFunction.java @@ -16,10 +16,9 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.calcite.sql.fun; +package org.apache.calcite.sql; -import org.apache.calcite.sql.SqlCall; -import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.fun.SqlCoalesceFunction; import org.apache.calcite.sql.validate.SqlValidator; diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java deleted file mode 100644 index 3617a7c06270..000000000000 --- a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java +++ /dev/null @@ -1,171 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.calcite.sql.fun; - -import java.lang.reflect.Field; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; -import org.apache.calcite.sql.PinotSqlAggFunction; -import org.apache.calcite.sql.PinotSqlTransformFunction; -import org.apache.calcite.sql.SqlFunction; -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.validate.SqlNameMatchers; -import org.apache.calcite.util.Util; -import org.apache.pinot.common.function.TransformFunctionType; -import org.apache.pinot.segment.spi.AggregationFunctionType; -import org.checkerframework.checker.nullness.qual.MonotonicNonNull; - - -/** - * {@link PinotOperatorTable} defines the {@link SqlOperator} overrides on top of the {@link SqlStdOperatorTable}. - * - *

The main purpose of this Pinot specific SQL operator table is to - *

    - *
  • Ensure that any specific SQL validation rules can apply with Pinot override entirely over Calcite's.
  • - *
  • Ability to create customer operators that are not function and cannot use - * {@link org.apache.calcite.prepare.Prepare.CatalogReader} to override
  • - *
  • Still maintain minimum customization and benefit from Calcite's original operator table setting.
  • - *
- */ -@SuppressWarnings("unused") // unused fields are accessed by reflection -public class PinotOperatorTable extends SqlStdOperatorTable { - - private static @MonotonicNonNull PinotOperatorTable _instance; - - // TODO: clean up lazy init by using Suppliers.memorized(this::computeInstance) and make getter wrapped around - // supplier instance. this should replace all lazy init static objects in the codebase - public static synchronized PinotOperatorTable instance() { - if (_instance == null) { - // Creates and initializes the standard operator table. - // Uses two-phase construction, because we can't initialize the - // table until the constructor of the sub-class has completed. - _instance = new PinotOperatorTable(); - _instance.initNoDuplicate(); - } - return _instance; - } - - /** - * Initialize without duplicate, e.g. when 2 duplicate operator is linked with the same op - * {@link org.apache.calcite.sql.SqlKind} it causes problem. - * - *

This is a direct copy of the {@link org.apache.calcite.sql.util.ReflectiveSqlOperatorTable} and can be hard to - * debug, suggest changing to a non-dynamic registration. Dynamic function support should happen via catalog. - * - * This also registers aggregation functions defined in {@link org.apache.pinot.segment.spi.AggregationFunctionType} - * which are multistage enabled. - */ - public final void initNoDuplicate() { - // Pinot supports native COALESCE function, thus no need to create CASE WHEN conversion. - register(new PinotSqlCoalesceFunction()); - // Ensure ArrayValueConstructor is registered before ArrayQueryConstructor - register(ARRAY_VALUE_CONSTRUCTOR); - - // TODO: reflection based registration is not ideal, we should use a static list of operators and register them - // Use reflection to register the expressions stored in public fields. - for (Field field : getClass().getFields()) { - try { - if (SqlFunction.class.isAssignableFrom(field.getType())) { - SqlFunction op = (SqlFunction) field.get(this); - if (op != null && notRegistered(op)) { - register(op); - } - } else if ( - SqlOperator.class.isAssignableFrom(field.getType())) { - SqlOperator op = (SqlOperator) field.get(this); - if (op != null && notRegistered(op)) { - register(op); - } - } - } catch (IllegalArgumentException | IllegalAccessException e) { - throw Util.throwAsRuntime(Util.causeOrSelf(e)); - } - } - - // Walk through all the Pinot aggregation types and - // 1. register those that are supported in multistage in addition to calcite standard opt table. - // 2. register special handling that differs from calcite standard. - for (AggregationFunctionType aggregationFunctionType : AggregationFunctionType.values()) { - if (aggregationFunctionType.getSqlKind() != null) { - // 1. Register the aggregation function with Calcite - registerAggregateFunction(aggregationFunctionType.getName(), aggregationFunctionType); - // 2. Register the aggregation function with Calcite on all alternative names - List alternativeFunctionNames = aggregationFunctionType.getAlternativeNames(); - for (String alternativeFunctionName : alternativeFunctionNames) { - registerAggregateFunction(alternativeFunctionName, aggregationFunctionType); - } - } - } - - // Walk through all the Pinot transform types and - // 1. register those that are supported in multistage in addition to calcite standard opt table. - // 2. register special handling that differs from calcite standard. - for (TransformFunctionType transformFunctionType : TransformFunctionType.values()) { - if (transformFunctionType.getSqlKind() != null) { - // 1. Register the transform function with Calcite - registerTransformFunction(transformFunctionType.getName(), transformFunctionType); - // 2. Register the transform function with Calcite on all alternative names - List alternativeFunctionNames = transformFunctionType.getAlternativeNames(); - for (String alternativeFunctionName : alternativeFunctionNames) { - registerTransformFunction(alternativeFunctionName, transformFunctionType); - } - } - } - } - - private void registerAggregateFunction(String functionName, AggregationFunctionType functionType) { - // register function behavior that's different from Calcite - if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { - PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, - functionType.getSqlKind(), functionType.getReturnTypeInference(), null, - functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); - if (notRegistered(sqlAggFunction)) { - register(sqlAggFunction); - } - } - } - - private void registerTransformFunction(String functionName, TransformFunctionType functionType) { - // register function behavior that's different from Calcite - if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { - PinotSqlTransformFunction sqlTransformFunction = - new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), - functionType.getSqlKind(), functionType.getReturnTypeInference(), null, - functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); - if (notRegistered(sqlTransformFunction)) { - register(sqlTransformFunction); - } - } - } - - private boolean notRegistered(SqlFunction op) { - List operatorList = new ArrayList<>(); - lookupOperatorOverloads(op.getNameAsId(), op.getFunctionType(), op.getSyntax(), operatorList, - SqlNameMatchers.withCaseSensitive(false)); - return operatorList.size() == 0; - } - - private boolean notRegistered(SqlOperator op) { - List operatorList = new ArrayList<>(); - lookupOperatorOverloads(op.getNameAsId(), null, op.getSyntax(), operatorList, - SqlNameMatchers.withCaseSensitive(false)); - return operatorList.size() == 0; - } -} diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/util/PinotSqlStdOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/calcite/sql/util/PinotSqlStdOperatorTable.java new file mode 100644 index 000000000000..e3f6d8a72214 --- /dev/null +++ b/pinot-query-planner/src/main/java/org/apache/calcite/sql/util/PinotSqlStdOperatorTable.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.calcite.sql.util; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.List; +import org.apache.calcite.sql.PinotSqlCoalesceFunction; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.validate.SqlNameMatchers; +import org.apache.calcite.util.Util; +import org.apache.pinot.common.function.FunctionRegistry; + + +public class PinotSqlStdOperatorTable extends SqlStdOperatorTable { + private static PinotSqlStdOperatorTable _instance; + + // supplier instance. this should replace all lazy init static objects in the codebase + public static synchronized PinotSqlStdOperatorTable instance() { + if (_instance == null) { + // Creates and initializes the standard operator table. + // Uses two-phase construction, because we can't initialize the + // table until the constructor of the sub-class has completed. + _instance = new PinotSqlStdOperatorTable(); + _instance.initNoDuplicate(); + } + return _instance; + } + + /** + * Initialize without duplicate, e.g. when 2 duplicate operator is linked with the same op + * {@link org.apache.calcite.sql.SqlKind} it causes problem. + * + *

This is a direct copy of the {@link org.apache.calcite.sql.util.ReflectiveSqlOperatorTable} and can be hard to + * debug, suggest changing to a non-dynamic registration. Dynamic function support should happen via catalog. + * + * This also registers aggregation functions defined in {@link org.apache.pinot.segment.spi.AggregationFunctionType} + * which are multistage enabled. + */ + public final void initNoDuplicate() { + // Pinot supports native COALESCE function, thus no need to create CASE WHEN conversion. + register(new PinotSqlCoalesceFunction()); + // Ensure ArrayValueConstructor is registered before ArrayQueryConstructor + register(ARRAY_VALUE_CONSTRUCTOR); + + // TODO: reflection based registration is not ideal, we should use a static list of operators and register them + // Use reflection to register the expressions stored in public fields. + for (Field field : getClass().getFields()) { + try { + if (SqlFunction.class.isAssignableFrom(field.getType())) { + SqlFunction op = (SqlFunction) field.get(this); + if (op != null && notRegistered(op)) { + register(op); + } + } else if (SqlOperator.class.isAssignableFrom(field.getType())) { + SqlOperator op = (SqlOperator) field.get(this); + if (op != null && notRegistered(op)) { + register(op); + } + } + } catch (IllegalArgumentException | IllegalAccessException e) { + throw Util.throwAsRuntime(Util.causeOrSelf(e)); + } + } + } + + private boolean notRegistered(SqlFunction op) { + List operatorList = new ArrayList<>(); + lookupOperatorOverloads(op.getNameAsId(), op.getFunctionType(), op.getSyntax(), operatorList, + SqlNameMatchers.withCaseSensitive(FunctionRegistry.CASE_SENSITIVITY)); + return operatorList.size() == 0; + } + + private boolean notRegistered(SqlOperator op) { + List operatorList = new ArrayList<>(); + lookupOperatorOverloads(op.getNameAsId(), null, op.getSyntax(), operatorList, + SqlNameMatchers.withCaseSensitive(FunctionRegistry.CASE_SENSITIVITY)); + return operatorList.size() == 0; + } +} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java index 769d6a607fc2..c835eb36cd93 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java @@ -35,7 +35,6 @@ import org.apache.calcite.plan.hep.HepMatchOrder; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; -import org.apache.calcite.prepare.PinotCalciteCatalogReader; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; @@ -52,8 +51,8 @@ import org.apache.calcite.sql.SqlExplainLevel; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.fun.PinotOperatorTable; import org.apache.calcite.sql.util.PinotChainedSqlOperatorTable; +import org.apache.calcite.sql.util.PinotSqlStdOperatorTable; import org.apache.calcite.sql2rel.PinotConvertletTable; import org.apache.calcite.sql2rel.RelDecorrelator; import org.apache.calcite.sql2rel.SqlToRelConverter; @@ -61,6 +60,8 @@ import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.RelBuilder; import org.apache.pinot.common.config.provider.TableCache; +import org.apache.pinot.common.function.sql.PinotCalciteCatalogReader; +import org.apache.pinot.common.function.sql.PinotOperatorTable; import org.apache.pinot.query.context.PlannerContext; import org.apache.pinot.query.planner.PlannerUtils; import org.apache.pinot.query.planner.QueryPlan; @@ -113,6 +114,7 @@ public QueryEnvironment(TypeFactory typeFactory, CalciteSchema rootSchema, Worke _config = Frameworks.newConfigBuilder().traitDefs() .operatorTable(new PinotChainedSqlOperatorTable(Arrays.asList( + PinotSqlStdOperatorTable.instance(), PinotOperatorTable.instance(), _catalogReader))) .defaultSchema(_rootSchema.plus()) From 169412f1edb2ef891856b1ebe98bd15d9e9f5bb2 Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Mon, 18 Dec 2023 09:23:09 -0800 Subject: [PATCH 5/9] wire up function lookup - use signature type lookup for v2 engine - deprecate usage of FunctionRegistry - allow nullable return from function lookup b/c some operators doesn't have scalar equivalent at the moment --- .../common/function/FunctionRegistry.java | 109 ++++-------------- .../InbuiltFunctionEvaluatorTest.java | 4 +- .../operator/operands/FunctionOperand.java | 16 ++- .../operands/TransformOperandFactory.java | 19 ++- 4 files changed, 56 insertions(+), 92 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index e5ee9899d463..4113da269654 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -19,23 +19,16 @@ package org.apache.pinot.common.function; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; -import org.apache.commons.lang3.StringUtils; -import org.apache.pinot.common.function.registry.PinotFunction; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlOperatorTable; import org.apache.pinot.common.function.registry.PinotScalarFunction; import org.apache.pinot.common.function.sql.PinotFunctionRegistry; +import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.spi.annotations.ScalarFunction; -import org.apache.pinot.spi.utils.PinotReflectionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,38 +39,10 @@ public class FunctionRegistry { public static final boolean CASE_SENSITIVITY = false; private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class); - private static final Map> FUNCTION_INFO_MAP = new HashMap<>(); private FunctionRegistry() { } - /** - * Registers the scalar functions via reflection. - * NOTE: In order to plugin methods using reflection, the methods should be inside a class that includes ".function." - * in its class path. This convention can significantly reduce the time of class scanning. - */ - static { - long startTimeMs = System.currentTimeMillis(); - Set methods = PinotReflectionUtils.getMethodsThroughReflection(".*\\.function\\..*", ScalarFunction.class); - for (Method method : methods) { - if (!Modifier.isPublic(method.getModifiers())) { - continue; - } - ScalarFunction scalarFunction = method.getAnnotation(ScalarFunction.class); - if (scalarFunction.enabled()) { - // Parse annotated function names and alias - Set scalarFunctionNames = Arrays.stream(scalarFunction.names()).collect(Collectors.toSet()); - if (scalarFunctionNames.size() == 0) { - scalarFunctionNames.add(method.getName()); - } - boolean nullableParameters = scalarFunction.nullableParameters(); - FunctionRegistry.registerFunction(method, scalarFunctionNames, nullableParameters); - } - } - LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_INFO_MAP.size(), - FUNCTION_INFO_MAP.keySet(), System.currentTimeMillis() - startTimeMs); - } - /** * Initializes the FunctionRegistry. * NOTE: This method itself is a NO-OP, but can be used to explicitly trigger the static block of registering the @@ -86,30 +51,16 @@ private FunctionRegistry() { public static void init() { } - @VisibleForTesting - public static void registerFunction(Method method, boolean nullableParameters) { - registerFunction(method, Collections.singleton(method.getName()), nullableParameters); - } - @VisibleForTesting public static Set getRegisteredCalciteFunctionNames() { return PinotFunctionRegistry.getFunctionMap().map().keySet(); } - /** - * Returns the full list of all registered ScalarFunction to Calcite. - */ - public static Map> getRegisteredCalciteFunctionMap() { - return PinotFunctionRegistry.getFunctionMap().map(); - } - /** * Returns {@code true} if the given function name is registered, {@code false} otherwise. */ public static boolean containsFunction(String functionName) { - // TODO: remove deprecated FUNCTION_INFO_MAP - return PinotFunctionRegistry.getFunctionMap().containsKey(functionName, CASE_SENSITIVITY) - || FUNCTION_INFO_MAP.containsKey(canonicalize(functionName)); + return PinotFunctionRegistry.getFunctionMap().containsKey(functionName, CASE_SENSITIVITY); } /** @@ -119,36 +70,25 @@ public static boolean containsFunction(String functionName) { */ @Nullable public static FunctionInfo getFunctionInfo(String functionName, int numParams) { - try { - return getFunctionInfoFromCalciteNamedMap(functionName, numParams); - } catch (IllegalArgumentException iae) { - // TODO: remove deprecated FUNCTION_INFO_MAP - return getFunctionInfoFromFunctionInfoMap(functionName, numParams); - } - } - - // TODO: remove deprecated FUNCTION_INFO_MAP - private static void registerFunction(Method method, Set alias, boolean nullableParameters) { - if (method.getAnnotation(Deprecated.class) == null) { - for (String name : alias) { - registerFunctionInfoMap(name, method, nullableParameters); - } - } - } - - private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters) { - FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters); - String canonicalName = canonicalize(functionName); - Map functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>()); - FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with %s parameters is already registered", functionName, method.getParameterCount()); + return getFunctionInfoFromCalciteNamedMap(functionName, numParams); } + /** + * Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null} + * if there is no matching method. This method should be called after the FunctionRegistry is initialized and all + * methods are already registered. + */ @Nullable - private static FunctionInfo getFunctionInfoFromFunctionInfoMap(String functionName, int numParams) { - Map functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName)); - return functionInfoMap != null ? functionInfoMap.get(numParams) : null; + public static FunctionInfo getFunctionInfo(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, + String functionName, List argTypes) { + PinotScalarFunction scalarFunction = + PinotFunctionRegistry.getScalarFunction(operatorTable, typeFactory, functionName, argTypes); + if (scalarFunction != null) { + return scalarFunction.getFunctionInfo(); + } else { + // TODO: convert this to throw IAE when all operator has scalar equivalent. + return null; + } } @Nullable @@ -160,16 +100,11 @@ private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionNa if (candidates.size() == 1) { return candidates.get(0).getFunctionInfo(); } else { - throw new IllegalArgumentException( - "Unable to lookup function: " + functionName + " by parameter count: " + numParams + " Found " - + candidates.size() + " candidates. Try to use argument types to resolve the correct one!"); + // TODO: convert this to throw IAE when all operator has scalar equivalent. + return null; } } - private static String canonicalize(String functionName) { - return StringUtils.remove(functionName, '_').toLowerCase(); - } - /** * Placeholders for scalar function, they register and represents the signature for transform and filter predicate * so that v2 engine can understand and plan them correctly. diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java index 5c6835e293f6..8a1e418cff8d 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java @@ -20,7 +20,7 @@ import java.lang.reflect.Method; import java.util.Collections; -import org.apache.pinot.common.function.FunctionRegistry; +import org.apache.pinot.common.function.sql.PinotFunctionRegistry; import org.apache.pinot.segment.local.function.InbuiltFunctionEvaluator; import org.apache.pinot.spi.data.readers.GenericRow; import org.joda.time.DateTime; @@ -130,7 +130,7 @@ public void testStateSharedBetweenRowsForExecution() throws Exception { MyFunc myFunc = new MyFunc(); Method method = myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class); - FunctionRegistry.registerFunction(method, false); + PinotFunctionRegistry.registerFunction(method, false); String expression = "appendToStringAndReturn('test ')"; InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); assertTrue(evaluator.getArguments().isEmpty()); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java index cccc065be36d..29419dc5e4d9 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java @@ -21,7 +21,10 @@ import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; import javax.annotation.Nullable; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlOperatorTable; import org.apache.pinot.common.function.FunctionInfo; import org.apache.pinot.common.function.FunctionInvoker; import org.apache.pinot.common.function.FunctionRegistry; @@ -43,11 +46,20 @@ public class FunctionOperand implements TransformOperand { private final List _operands; private final Object[] _reusableOperandHolder; - public FunctionOperand(RexExpression.FunctionCall functionCall, String canonicalName, DataSchema dataSchema) { + public FunctionOperand(SqlOperatorTable sqlOperatorTable, RelDataTypeFactory relDataTypeFactory, + RexExpression.FunctionCall functionCall, String canonicalName, DataSchema dataSchema) { _resultType = functionCall.getDataType(); List operands = functionCall.getFunctionOperands(); int numOperands = operands.size(); - FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(canonicalName, numOperands); + List operandTypes = operands.stream().map(e -> { + if (e instanceof RexExpression.InputRef) { + return dataSchema.getColumnDataType(((RexExpression.InputRef) e).getIndex()); + } else { + return e.getDataType(); + } + }).collect(Collectors.toList()); + FunctionInfo functionInfo = + FunctionRegistry.getFunctionInfo(sqlOperatorTable, relDataTypeFactory, canonicalName, operandTypes); Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", canonicalName); _functionInvoker = new FunctionInvoker(functionInfo); if (!_functionInvoker.getMethod().isVarArgs()) { diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperandFactory.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperandFactory.java index 4a95a8b16278..5de96701f157 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperandFactory.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperandFactory.java @@ -20,12 +20,28 @@ import com.google.common.base.Preconditions; import java.util.List; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.jdbc.CalciteSchemaBuilder; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.pinot.common.function.sql.PinotCalciteCatalogReader; import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.query.catalog.PinotCatalog; import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.query.runtime.operator.utils.OperatorUtils; +import org.apache.pinot.query.type.TypeFactory; +import org.apache.pinot.query.type.TypeSystem; public class TransformOperandFactory { + private static final RelDataTypeFactory FUNCTION_CATALOGREL_DATA_TYPE_FACTORY = new TypeFactory(new TypeSystem()); + private static final CalciteSchema FUNCTION_CATALOG_ROOT_SCHEMA = + CalciteSchemaBuilder.asRootSchema(new PinotCatalog(null)); + private static final CalciteConnectionConfig FUNCTION_CATALOG_CONFIG = CalciteConnectionConfig.DEFAULT; + private static final PinotCalciteCatalogReader FUNCTION_CATALOG_OPERATOR_TABLE = + new PinotCalciteCatalogReader(FUNCTION_CATALOG_ROOT_SCHEMA, FUNCTION_CATALOG_ROOT_SCHEMA.path(null), + FUNCTION_CATALOGREL_DATA_TYPE_FACTORY, FUNCTION_CATALOG_CONFIG); + private TransformOperandFactory() { } @@ -74,7 +90,8 @@ private static TransformOperand getTransformOperand(RexExpression.FunctionCall f case "lessThanOrEqual": return new FilterOperand.Predicate(operands, dataSchema, v -> v <= 0); default: - return new FunctionOperand(functionCall, canonicalName, dataSchema); + return new FunctionOperand(FUNCTION_CATALOG_OPERATOR_TABLE, FUNCTION_CATALOGREL_DATA_TYPE_FACTORY, + functionCall, canonicalName, dataSchema); } } } From 374616ff71097d34daa5027ea52fec417451db29 Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Fri, 5 Jan 2024 10:56:17 -0800 Subject: [PATCH 6/9] clean up, refactor encapsulations - merge PinotFunctionRegistry with FunctionRegistry - renamed to match calcite.schema and calcite.sql from pinot.common.function package --- .../common/function/FunctionRegistry.java | 206 ++++++++++++++- .../{registry => schema}/PinotFunction.java | 6 +- .../PinotScalarFunction.java | 2 +- .../sql/PinotCalciteCatalogReader.java | 17 +- .../function/sql/PinotFunctionRegistry.java | 237 ------------------ .../function/sql/PinotOperatorTable.java | 4 +- .../InbuiltFunctionEvaluatorTest.java | 4 +- .../calcite/jdbc/CalciteSchemaBuilder.java | 9 +- 8 files changed, 226 insertions(+), 259 deletions(-) rename pinot-common/src/main/java/org/apache/pinot/common/function/{registry => schema}/PinotFunction.java (81%) rename pinot-common/src/main/java/org/apache/pinot/common/function/{registry => schema}/PinotScalarFunction.java (98%) delete mode 100644 pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotFunctionRegistry.java diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index 4113da269654..5ea6f3806696 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -19,30 +19,112 @@ package org.apache.pinot.common.function; import com.google.common.annotations.VisibleForTesting; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.schema.Function; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlOperatorTable; -import org.apache.pinot.common.function.registry.PinotScalarFunction; -import org.apache.pinot.common.function.sql.PinotFunctionRegistry; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.SqlUtil; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlNameMatchers; +import org.apache.calcite.sql.validate.SqlUserDefinedFunction; +import org.apache.calcite.util.NameMultimap; +import org.apache.pinot.common.function.schema.PinotFunction; +import org.apache.pinot.common.function.schema.PinotScalarFunction; +import org.apache.pinot.common.function.sql.PinotSqlAggFunction; +import org.apache.pinot.common.function.sql.PinotSqlTransformFunction; import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.segment.spi.AggregationFunctionType; import org.apache.pinot.spi.annotations.ScalarFunction; +import org.apache.pinot.spi.utils.PinotReflectionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * Registry for scalar functions. + * Registry for functions. */ public class FunctionRegistry { public static final boolean CASE_SENSITIVITY = false; private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class); + private static final NameMultimap OPERATOR_MAP = new NameMultimap<>(); + private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); private FunctionRegistry() { } + /** + * Registers the scalar functions via reflection. + * NOTE: In order to plugin methods using reflection, the methods should be inside a class that includes ".function." + * in its class path. This convention can significantly reduce the time of class scanning. + */ + static { + // REGISTER FUNCTIONS + long startTimeMs = System.currentTimeMillis(); + Set methods = PinotReflectionUtils.getMethodsThroughReflection(".*\\.function\\..*", ScalarFunction.class); + for (Method method : methods) { + if (!Modifier.isPublic(method.getModifiers())) { + continue; + } + ScalarFunction scalarFunction = method.getAnnotation(ScalarFunction.class); + if (scalarFunction.enabled()) { + // Parse annotated function names and alias + Set scalarFunctionNames = Arrays.stream(scalarFunction.names()).collect(Collectors.toSet()); + if (scalarFunctionNames.size() == 0) { + scalarFunctionNames.add(method.getName()); + } + boolean nullableParameters = scalarFunction.nullableParameters(); + registerFunction(method, scalarFunctionNames, nullableParameters); + } + } + LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_MAP.map().size(), + FUNCTION_MAP.map().keySet(), System.currentTimeMillis() - startTimeMs); + + // REGISTER OPERATORS + // Walk through all the Pinot aggregation types and + // 1. register those that are supported in multistage in addition to calcite standard opt table. + // 2. register special handling that differs from calcite standard. + for (AggregationFunctionType aggregationFunctionType : AggregationFunctionType.values()) { + if (aggregationFunctionType.getSqlKind() != null) { + // 1. Register the aggregation function with Calcite + registerAggregateFunction(aggregationFunctionType.getName(), aggregationFunctionType); + // 2. Register the aggregation function with Calcite on all alternative names + List alternativeFunctionNames = aggregationFunctionType.getAlternativeNames(); + for (String alternativeFunctionName : alternativeFunctionNames) { + registerAggregateFunction(alternativeFunctionName, aggregationFunctionType); + } + } + } + + // Walk through all the Pinot transform types and + // 1. register those that are supported in multistage in addition to calcite standard opt table. + // 2. register special handling that differs from calcite standard. + for (TransformFunctionType transformFunctionType : TransformFunctionType.values()) { + if (transformFunctionType.getSqlKind() != null) { + // 1. Register the transform function with Calcite + registerTransformFunction(transformFunctionType.getName(), transformFunctionType); + // 2. Register the transform function with Calcite on all alternative names + List alternativeFunctionNames = transformFunctionType.getAlternativeNames(); + for (String alternativeFunctionName : alternativeFunctionNames) { + registerTransformFunction(alternativeFunctionName, transformFunctionType); + } + } + } + } + /** * Initializes the FunctionRegistry. * NOTE: This method itself is a NO-OP, but can be used to explicitly trigger the static block of registering the @@ -51,16 +133,21 @@ private FunctionRegistry() { public static void init() { } + @VisibleForTesting + public static void registerFunction(Method method, boolean nullableParameters) { + registerFunction(method, Collections.singleton(method.getName()), nullableParameters); + } + @VisibleForTesting public static Set getRegisteredCalciteFunctionNames() { - return PinotFunctionRegistry.getFunctionMap().map().keySet(); + return getFunctionMap().map().keySet(); } /** * Returns {@code true} if the given function name is registered, {@code false} otherwise. */ public static boolean containsFunction(String functionName) { - return PinotFunctionRegistry.getFunctionMap().containsKey(functionName, CASE_SENSITIVITY); + return getFunctionMap().containsKey(functionName, CASE_SENSITIVITY); } /** @@ -81,8 +168,7 @@ public static FunctionInfo getFunctionInfo(String functionName, int numParams) { @Nullable public static FunctionInfo getFunctionInfo(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, String functionName, List argTypes) { - PinotScalarFunction scalarFunction = - PinotFunctionRegistry.getScalarFunction(operatorTable, typeFactory, functionName, argTypes); + PinotScalarFunction scalarFunction = getScalarFunction(operatorTable, typeFactory, functionName, argTypes); if (scalarFunction != null) { return scalarFunction.getFunctionInfo(); } else { @@ -93,7 +179,7 @@ public static FunctionInfo getFunctionInfo(SqlOperatorTable operatorTable, RelDa @Nullable private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionName, int numParams) { - List candidates = PinotFunctionRegistry.getFunctionMap() + List candidates = getFunctionMap() .range(functionName, CASE_SENSITIVITY).stream() .filter(e -> e.getValue() instanceof PinotScalarFunction && e.getValue().getParameters().size() == numParams) .map(e -> (PinotScalarFunction) e.getValue()).collect(Collectors.toList()); @@ -105,6 +191,110 @@ private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionNa } } + @Nullable + private static PinotScalarFunction getScalarFunction(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, + String functionName, List argTypes) { + List relArgTypes = convertArgumentTypes(typeFactory, argTypes); + SqlOperator sqlOperator = SqlUtil.lookupRoutine(operatorTable, typeFactory, + new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO), relArgTypes, null, null, SqlSyntax.FUNCTION, + SqlKind.OTHER_FUNCTION, SqlNameMatchers.withCaseSensitive(false), true); + if (sqlOperator instanceof SqlUserDefinedFunction) { + Function function = ((SqlUserDefinedFunction) sqlOperator).getFunction(); + if (function instanceof PinotScalarFunction) { + return (PinotScalarFunction) function; + } + } + return null; + } + + public static NameMultimap getFunctionMap() { + return FUNCTION_MAP; + } + + public static NameMultimap getOperatorMap() { + return OPERATOR_MAP; + } + + private static void registerFunction(Method method, Set alias, boolean nullableParameters) { + if (method.getAnnotation(Deprecated.class) == null) { + for (String name : alias) { + registerCalciteNamedFunctionMap(name, method, nullableParameters); + } + } + } + + private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters) { + FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters)); + } + + private static void registerAggregateFunction(String functionName, AggregationFunctionType functionType) { + if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { + PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, + functionType.getSqlKind(), functionType.getReturnTypeInference(), null, + functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); + OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlAggFunction); + } + } + + private static void registerTransformFunction(String functionName, TransformFunctionType functionType) { + if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { + PinotSqlTransformFunction sqlTransformFunction = + new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), + functionType.getSqlKind(), functionType.getReturnTypeInference(), null, + functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); + OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlTransformFunction); + } + } + + private static List convertArgumentTypes(RelDataTypeFactory typeFactory, + List argTypes) { + return argTypes.stream().map(type -> toRelType(typeFactory, type)).collect(Collectors.toList()); + } + + private static RelDataType toRelType(RelDataTypeFactory typeFactory, DataSchema.ColumnDataType dataType) { + switch (dataType) { + case INT: + return typeFactory.createSqlType(SqlTypeName.INTEGER); + case LONG: + return typeFactory.createSqlType(SqlTypeName.BIGINT); + case FLOAT: + return typeFactory.createSqlType(SqlTypeName.REAL); + case DOUBLE: + return typeFactory.createSqlType(SqlTypeName.DOUBLE); + case BIG_DECIMAL: + return typeFactory.createSqlType(SqlTypeName.DECIMAL); + case BOOLEAN: + return typeFactory.createSqlType(SqlTypeName.BOOLEAN); + case TIMESTAMP: + return typeFactory.createSqlType(SqlTypeName.TIMESTAMP); + case JSON: + case STRING: + return typeFactory.createSqlType(SqlTypeName.VARCHAR); + case BYTES: + return typeFactory.createSqlType(SqlTypeName.VARBINARY); + case INT_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.INTEGER), -1); + case LONG_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.BIGINT), -1); + case FLOAT_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.REAL), -1); + case DOUBLE_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.DOUBLE), -1); + case BOOLEAN_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.BOOLEAN), -1); + case TIMESTAMP_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.TIMESTAMP), -1); + case STRING_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.VARCHAR), -1); + case BYTES_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.VARBINARY), -1); + case UNKNOWN: + case OBJECT: + default: + return typeFactory.createSqlType(SqlTypeName.ANY); + } + } + /** * Placeholders for scalar function, they register and represents the signature for transform and filter predicate * so that v2 engine can understand and plan them correctly. diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotFunction.java similarity index 81% rename from pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotFunction.java rename to pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotFunction.java index f0e756513739..95f5f5f0bf64 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotFunction.java @@ -16,13 +16,17 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.common.function.registry; +package org.apache.pinot.common.function.schema; import org.apache.calcite.schema.Function; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlReturnTypeInference; +/** + * Function Schema used to resolve function signature by {@link org.apache.pinot.common.function.FunctionRegistry} and + * {@link org.apache.pinot.common.function.sql.PinotCalciteCatalogReader}. + */ public interface PinotFunction extends Function { SqlOperandTypeChecker getOperandTypeChecker(); SqlReturnTypeInference getReturnTypeInference(); diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotScalarFunction.java similarity index 98% rename from pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotScalarFunction.java rename to pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotScalarFunction.java index c1708aab4203..5030dc47bb53 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotScalarFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotScalarFunction.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.common.function.registry; +package org.apache.pinot.common.function.schema; import java.lang.reflect.Method; import org.apache.calcite.rel.type.RelDataType; diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java index 672689dff5c5..99863aeddf1e 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java @@ -81,9 +81,14 @@ /** * ============================================================================ - * THIS CLASS IS COPIED FROM Calcite's {@link org.apache.calcite.prepare.CalciteCatalogReader} and modified the - * case sensitivity of Function lookup. which is ALWAYS case-insensitive regardless of conventions on - * column/table identifier. + * THIS CLASS IS COPIED FROM Calcite's {@link org.apache.calcite.prepare.CalciteCatalogReader} and modified + *

    + *
  • the case sensitivity of Function lookup. Pinot ALWAYS resolve case-insensitive function regardless of + * case sensitivity conventions of column/table identifier.
  • + *
  • made the {@link PinotCalciteCatalogReader#toOp(SqlIdentifier, org.apache.calcite.schema.Function)}
  • method + * public access for overriding behavior for catalog function operand/return type inference. + *
+ * * ============================================================================ * * Pinot's implementation of {@link org.apache.calcite.prepare.Prepare.CatalogReader} @@ -312,8 +317,14 @@ public static SqlOperatorTable operatorTable(String... classNames) { } /** Converts a function to a {@link org.apache.calcite.sql.SqlOperator}. */ + // ==================================================================== + // LINES CHANGED BELOW + // ==================================================================== public static SqlOperator toOp(SqlIdentifier name, final org.apache.calcite.schema.Function function) { + // ==================================================================== + // LINES CHANGED ABOVE + // ==================================================================== final Function> argTypesFactory = typeFactory -> function.getParameters() .stream() diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotFunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotFunctionRegistry.java deleted file mode 100644 index 776a95927255..000000000000 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotFunctionRegistry.java +++ /dev/null @@ -1,237 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.common.function.sql; - -import com.google.common.annotations.VisibleForTesting; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Locale; -import java.util.Set; -import java.util.stream.Collectors; -import javax.annotation.Nullable; -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.schema.Function; -import org.apache.calcite.sql.SqlIdentifier; -import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.SqlOperatorTable; -import org.apache.calcite.sql.SqlSyntax; -import org.apache.calcite.sql.SqlUtil; -import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.sql.validate.SqlNameMatchers; -import org.apache.calcite.sql.validate.SqlUserDefinedFunction; -import org.apache.calcite.util.NameMultimap; -import org.apache.pinot.common.function.TransformFunctionType; -import org.apache.pinot.common.function.registry.PinotFunction; -import org.apache.pinot.common.function.registry.PinotScalarFunction; -import org.apache.pinot.common.utils.DataSchema; -import org.apache.pinot.segment.spi.AggregationFunctionType; -import org.apache.pinot.spi.annotations.ScalarFunction; -import org.apache.pinot.spi.utils.PinotReflectionUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - -/** - * Registry for scalar functions. - */ -public class PinotFunctionRegistry { - private static final Logger LOGGER = LoggerFactory.getLogger(PinotFunctionRegistry.class); - private static final NameMultimap OPERATOR_MAP = new NameMultimap<>(); - private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); - - private PinotFunctionRegistry() { - } - - /** - * Registers the scalar functions via reflection. - * NOTE: In order to plugin methods using reflection, the methods should be inside a class that includes ".function." - * in its class path. This convention can significantly reduce the time of class scanning. - */ - static { - // REGISTER FUNCTIONS - long startTimeMs = System.currentTimeMillis(); - Set methods = PinotReflectionUtils.getMethodsThroughReflection(".*\\.function\\..*", ScalarFunction.class); - for (Method method : methods) { - if (!Modifier.isPublic(method.getModifiers())) { - continue; - } - ScalarFunction scalarFunction = method.getAnnotation(ScalarFunction.class); - if (scalarFunction.enabled()) { - // Parse annotated function names and alias - Set scalarFunctionNames = Arrays.stream(scalarFunction.names()).collect(Collectors.toSet()); - if (scalarFunctionNames.size() == 0) { - scalarFunctionNames.add(method.getName()); - } - boolean nullableParameters = scalarFunction.nullableParameters(); - PinotFunctionRegistry.registerFunction(method, scalarFunctionNames, nullableParameters); - } - } - LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_MAP.map().size(), - FUNCTION_MAP.map().keySet(), System.currentTimeMillis() - startTimeMs); - - // REGISTER OPERATORS - // Walk through all the Pinot aggregation types and - // 1. register those that are supported in multistage in addition to calcite standard opt table. - // 2. register special handling that differs from calcite standard. - for (AggregationFunctionType aggregationFunctionType : AggregationFunctionType.values()) { - if (aggregationFunctionType.getSqlKind() != null) { - // 1. Register the aggregation function with Calcite - registerAggregateFunction(aggregationFunctionType.getName(), aggregationFunctionType); - // 2. Register the aggregation function with Calcite on all alternative names - List alternativeFunctionNames = aggregationFunctionType.getAlternativeNames(); - for (String alternativeFunctionName : alternativeFunctionNames) { - registerAggregateFunction(alternativeFunctionName, aggregationFunctionType); - } - } - } - - // Walk through all the Pinot transform types and - // 1. register those that are supported in multistage in addition to calcite standard opt table. - // 2. register special handling that differs from calcite standard. - for (TransformFunctionType transformFunctionType : TransformFunctionType.values()) { - if (transformFunctionType.getSqlKind() != null) { - // 1. Register the transform function with Calcite - registerTransformFunction(transformFunctionType.getName(), transformFunctionType); - // 2. Register the transform function with Calcite on all alternative names - List alternativeFunctionNames = transformFunctionType.getAlternativeNames(); - for (String alternativeFunctionName : alternativeFunctionNames) { - registerTransformFunction(alternativeFunctionName, transformFunctionType); - } - } - } - } - - public static void init() { - } - - @VisibleForTesting - public static void registerFunction(Method method, boolean nullableParameters) { - registerFunction(method, Collections.singleton(method.getName()), nullableParameters); - } - - public static NameMultimap getFunctionMap() { - return FUNCTION_MAP; - } - - public static NameMultimap getOperatorMap() { - return OPERATOR_MAP; - } - - @Nullable - public static PinotScalarFunction getScalarFunction(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, - String functionName, List argTypes) { - List relArgTypes = convertArgumentTypes(typeFactory, argTypes); - SqlOperator sqlOperator = SqlUtil.lookupRoutine(operatorTable, typeFactory, - new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO), relArgTypes, null, null, SqlSyntax.FUNCTION, - SqlKind.OTHER_FUNCTION, SqlNameMatchers.withCaseSensitive(false), true); - if (sqlOperator instanceof SqlUserDefinedFunction) { - Function function = ((SqlUserDefinedFunction) sqlOperator).getFunction(); - if (function instanceof PinotScalarFunction) { - return (PinotScalarFunction) function; - } - } - return null; - } - - private static void registerFunction(Method method, Set alias, boolean nullableParameters) { - if (method.getAnnotation(Deprecated.class) == null) { - for (String name : alias) { - registerCalciteNamedFunctionMap(name, method, nullableParameters); - } - } - } - - private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters) { - FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters)); - } - - private static List convertArgumentTypes(RelDataTypeFactory typeFactory, - List argTypes) { - return argTypes.stream().map(type -> toRelType(typeFactory, type)).collect(Collectors.toList()); - } - - private static RelDataType toRelType(RelDataTypeFactory typeFactory, DataSchema.ColumnDataType dataType) { - switch (dataType) { - case INT: - return typeFactory.createSqlType(SqlTypeName.INTEGER); - case LONG: - return typeFactory.createSqlType(SqlTypeName.BIGINT); - case FLOAT: - return typeFactory.createSqlType(SqlTypeName.REAL); - case DOUBLE: - return typeFactory.createSqlType(SqlTypeName.DOUBLE); - case BIG_DECIMAL: - return typeFactory.createSqlType(SqlTypeName.DECIMAL); - case BOOLEAN: - return typeFactory.createSqlType(SqlTypeName.BOOLEAN); - case TIMESTAMP: - return typeFactory.createSqlType(SqlTypeName.TIMESTAMP); - case JSON: - case STRING: - return typeFactory.createSqlType(SqlTypeName.VARCHAR); - case BYTES: - return typeFactory.createSqlType(SqlTypeName.VARBINARY); - case INT_ARRAY: - return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.INTEGER), -1); - case LONG_ARRAY: - return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.BIGINT), -1); - case FLOAT_ARRAY: - return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.REAL), -1); - case DOUBLE_ARRAY: - return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.DOUBLE), -1); - case BOOLEAN_ARRAY: - return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.BOOLEAN), -1); - case TIMESTAMP_ARRAY: - return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.TIMESTAMP), -1); - case STRING_ARRAY: - return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.VARCHAR), -1); - case BYTES_ARRAY: - return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.VARBINARY), -1); - case UNKNOWN: - case OBJECT: - default: - return typeFactory.createSqlType(SqlTypeName.ANY); - } - } - - private static void registerAggregateFunction(String functionName, AggregationFunctionType functionType) { - if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { - PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, - functionType.getSqlKind(), functionType.getReturnTypeInference(), null, - functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); - OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlAggFunction); - } - } - - private static void registerTransformFunction(String functionName, TransformFunctionType functionType) { - if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { - PinotSqlTransformFunction sqlTransformFunction = - new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), - functionType.getSqlKind(), functionType.getReturnTypeInference(), null, - functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); - OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlTransformFunction); - } - } -} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotOperatorTable.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotOperatorTable.java index ca4513a5ba73..3782e9010efc 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotOperatorTable.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotOperatorTable.java @@ -89,13 +89,13 @@ public static synchronized PinotOperatorTable instance() { * Look up operators based on case-sensitiveness. */ private Collection lookUpOperators(String name) { - return PinotFunctionRegistry.getOperatorMap().range(name, FunctionRegistry.CASE_SENSITIVITY).stream() + return FunctionRegistry.getOperatorMap().range(name, FunctionRegistry.CASE_SENSITIVITY).stream() .map(Map.Entry::getValue).collect(Collectors.toSet()); } @Override public List getOperatorList() { - return PinotFunctionRegistry.getOperatorMap().map().values().stream().flatMap(List::stream) + return FunctionRegistry.getOperatorMap().map().values().stream().flatMap(List::stream) .collect(Collectors.toList()); } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java index 8a1e418cff8d..5c6835e293f6 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java @@ -20,7 +20,7 @@ import java.lang.reflect.Method; import java.util.Collections; -import org.apache.pinot.common.function.sql.PinotFunctionRegistry; +import org.apache.pinot.common.function.FunctionRegistry; import org.apache.pinot.segment.local.function.InbuiltFunctionEvaluator; import org.apache.pinot.spi.data.readers.GenericRow; import org.joda.time.DateTime; @@ -130,7 +130,7 @@ public void testStateSharedBetweenRowsForExecution() throws Exception { MyFunc myFunc = new MyFunc(); Method method = myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class); - PinotFunctionRegistry.registerFunction(method, false); + FunctionRegistry.registerFunction(method, false); String expression = "appendToStringAndReturn('test ')"; InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); assertTrue(evaluator.getArguments().isEmpty()); diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java b/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java index adadcd6992b6..8c87cd67d955 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java +++ b/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java @@ -20,11 +20,10 @@ import java.util.List; import java.util.Map; -import org.apache.calcite.schema.Function; import org.apache.calcite.schema.Schema; import org.apache.calcite.schema.SchemaPlus; -import org.apache.pinot.common.function.registry.PinotFunction; -import org.apache.pinot.common.function.sql.PinotFunctionRegistry; +import org.apache.pinot.common.function.FunctionRegistry; +import org.apache.pinot.common.function.schema.PinotFunction; /** @@ -55,8 +54,8 @@ private CalciteSchemaBuilder() { public static CalciteSchema asRootSchema(Schema root) { CalciteSchema rootSchema = CalciteSchema.createRootSchema(false, false, "", root); SchemaPlus schemaPlus = rootSchema.plus(); - for (Map.Entry> e : PinotFunctionRegistry.getFunctionMap().map().entrySet()) { - for (Function f : e.getValue()) { + for (Map.Entry> e : FunctionRegistry.getFunctionMap().map().entrySet()) { + for (PinotFunction f : e.getValue()) { schemaPlus.add(e.getKey(), f); } } From b7f7b7e728843ac35bafabb70811f73c875fecb3 Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Mon, 18 Dec 2023 10:07:42 -0800 Subject: [PATCH 7/9] support ScalarFunction annotation to class - allow complex transform and other dynamic operand/return inference - example added for array value constructor --- .../common/function/FunctionRegistry.java | 40 +++++++ .../function/scalar/ArrayFunctions.java | 108 ++++++++++-------- .../sql/PinotCalciteCatalogReader.java | 33 +++--- .../pinot/spi/annotations/ScalarFunction.java | 2 +- 4 files changed, 121 insertions(+), 62 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index 5ea6f3806696..ed2f50154451 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -38,6 +38,8 @@ import org.apache.calcite.sql.SqlSyntax; import org.apache.calcite.sql.SqlUtil; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlNameMatchers; import org.apache.calcite.sql.validate.SqlUserDefinedFunction; @@ -90,6 +92,22 @@ private FunctionRegistry() { registerFunction(method, scalarFunctionNames, nullableParameters); } } + Set> classes = PinotReflectionUtils.getClassesThroughReflection(".*\\.function\\..*", ScalarFunction.class); + for (Class clazz : classes) { + if (!Modifier.isPublic(clazz.getModifiers())) { + continue; + } + ScalarFunction scalarFunction = clazz.getAnnotation(ScalarFunction.class); + if (scalarFunction.enabled()) { + // Parse annotated function names and alias + Set scalarFunctionNames = Arrays.stream(scalarFunction.names()).collect(Collectors.toSet()); + if (scalarFunctionNames.size() == 0) { + scalarFunctionNames.add(clazz.getName()); + } + boolean nullableParameters = scalarFunction.nullableParameters(); + registerFunction(clazz, scalarFunctionNames, nullableParameters); + } + } LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_MAP.map().size(), FUNCTION_MAP.map().keySet(), System.currentTimeMillis() - startTimeMs); @@ -223,10 +241,32 @@ private static void registerFunction(Method method, Set alias, boolean n } } + private static void registerFunction(Class clazz, Set alias, boolean nullableParameters) { + if (clazz.getAnnotation(Deprecated.class) == null) { + for (String name : alias) { + registerCalciteNamedFunctionMap(name, clazz, nullableParameters); + } + } + } + private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters) { FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters)); } + private static void registerCalciteNamedFunctionMap(String name, Class clazz, boolean nullableParameters) { + try { + SqlReturnTypeInference returnTypeInference = (SqlReturnTypeInference) clazz.getField("RETURN_TYPE_INFERENCE").get(null); + SqlOperandTypeChecker operandTypeChecker = (SqlOperandTypeChecker) clazz.getField("OPERAND_TYPE_CHECKER").get(null); + for (Method method : clazz.getMethods()) { + if (method.getName().equals("eval")) { + FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters, operandTypeChecker, returnTypeInference)); + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + private static void registerAggregateFunction(String functionName, AggregationFunctionType functionType) { if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java index 32f115b51a70..aae61e94a7c8 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java @@ -25,6 +25,13 @@ import it.unimi.dsi.fastutil.objects.ObjectSet; import java.math.BigDecimal; import java.util.Arrays; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SameOperandTypeChecker; +import org.apache.calcite.sql.type.SqlOperandCountRanges; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.commons.lang3.ArrayUtils; import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder; @@ -228,61 +235,66 @@ public static String arrayElementAtString(String[] arr, int idx) { return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING; } - @ScalarFunction(names = {"array", "arrayValueConstructor"}, isVarArg = true) - public static Object arrayValueConstructor(Object... arr) { - if (arr.length == 0) { - return arr; - } - Class clazz = arr[0].getClass(); - if (clazz == Integer.class) { - int[] intArr = new int[arr.length]; - for (int i = 0; i < arr.length; i++) { - intArr[i] = (Integer) arr[i]; + @ScalarFunction(names = {"array", "arrayValueConstructor"}) + public static class ArrayValueConstructor { + public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = ReturnTypes.TO_ARRAY; + public static final SqlOperandTypeChecker OPERAND_TYPE_CHECKER = new SameOperandTypeChecker(-1); + + public static Object eval(Object... arr) { + if (arr.length == 0) { + return arr; } - return intArr; - } - if (clazz == Long.class) { - long[] longArr = new long[arr.length]; - for (int i = 0; i < arr.length; i++) { - longArr[i] = (Long) arr[i]; + Class clazz = arr[0].getClass(); + if (clazz == Integer.class) { + int[] intArr = new int[arr.length]; + for (int i = 0; i < arr.length; i++) { + intArr[i] = (Integer) arr[i]; + } + return intArr; } - return longArr; - } - if (clazz == Float.class) { - float[] floatArr = new float[arr.length]; - for (int i = 0; i < arr.length; i++) { - floatArr[i] = (Float) arr[i]; + if (clazz == Long.class) { + long[] longArr = new long[arr.length]; + for (int i = 0; i < arr.length; i++) { + longArr[i] = (Long) arr[i]; + } + return longArr; } - return floatArr; - } - if (clazz == Double.class) { - double[] doubleArr = new double[arr.length]; - for (int i = 0; i < arr.length; i++) { - doubleArr[i] = (Double) arr[i]; + if (clazz == Float.class) { + float[] floatArr = new float[arr.length]; + for (int i = 0; i < arr.length; i++) { + floatArr[i] = (Float) arr[i]; + } + return floatArr; } - return doubleArr; - } - if (clazz == Boolean.class) { - boolean[] boolArr = new boolean[arr.length]; - for (int i = 0; i < arr.length; i++) { - boolArr[i] = (Boolean) arr[i]; + if (clazz == Double.class) { + double[] doubleArr = new double[arr.length]; + for (int i = 0; i < arr.length; i++) { + doubleArr[i] = (Double) arr[i]; + } + return doubleArr; } - return boolArr; - } - if (clazz == BigDecimal.class) { - BigDecimal[] bigDecimalArr = new BigDecimal[arr.length]; - for (int i = 0; i < arr.length; i++) { - bigDecimalArr[i] = (BigDecimal) arr[i]; + if (clazz == Boolean.class) { + boolean[] boolArr = new boolean[arr.length]; + for (int i = 0; i < arr.length; i++) { + boolArr[i] = (Boolean) arr[i]; + } + return boolArr; } - return bigDecimalArr; - } - if (clazz == String.class) { - String[] strArr = new String[arr.length]; - for (int i = 0; i < arr.length; i++) { - strArr[i] = (String) arr[i]; + if (clazz == BigDecimal.class) { + BigDecimal[] bigDecimalArr = new BigDecimal[arr.length]; + for (int i = 0; i < arr.length; i++) { + bigDecimalArr[i] = (BigDecimal) arr[i]; + } + return bigDecimalArr; } - return strArr; + if (clazz == String.class) { + String[] strArr = new String[arr.length]; + for (int i = 0; i < arr.length; i++) { + strArr[i] = (String) arr[i]; + } + return strArr; + } + return arr; } - return arr; } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java index 99863aeddf1e..e3f3d2b737ef 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java @@ -352,8 +352,12 @@ public static SqlOperator toOp(SqlIdentifier name, final List typeFamilies = typeFamiliesFactory.apply(dummyTypeFactory); - final SqlOperandTypeInference operandTypeInference = - InferTypes.explicit(argTypes); + final SqlOperandTypeInference operandTypeInference; + if (function instanceof PinotScalarFunction && ((PinotScalarFunction) function).getOperandTypeChecker() != null) { + operandTypeInference = ((PinotScalarFunction) function).getOperandTypeChecker().typeInference(); + } else { + operandTypeInference = InferTypes.explicit(argTypes); + } final SqlOperandMetadata operandMetadata = OperandTypes.operandMetadata(typeFamilies, paramTypesFactory, @@ -402,17 +406,20 @@ private static SqlKind kind(org.apache.calcite.schema.Function function) { } private static SqlReturnTypeInference infer(final ScalarFunction function) { - return opBinding -> { - final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); - final RelDataType type; - if (function instanceof ScalarFunctionImpl) { - type = ((ScalarFunctionImpl) function).getReturnType(typeFactory, - opBinding); - } else { - type = function.getReturnType(typeFactory); - } - return toSql(typeFactory, type); - }; + if (function instanceof PinotScalarFunction && ((PinotScalarFunction) function).getReturnTypeInference() != null) { + return ((PinotScalarFunction) function).getReturnTypeInference(); + } else { + return opBinding -> { + final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + final RelDataType type; + if (function instanceof ScalarFunctionImpl) { + type = ((ScalarFunctionImpl) function).getReturnType(typeFactory, opBinding); + } else { + type = function.getReturnType(typeFactory); + } + return toSql(typeFactory, type); + }; + } } private static SqlReturnTypeInference infer( diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java index 0a647a879212..e91d3617fe4e 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java @@ -41,7 +41,7 @@ * - byte[] */ @Retention(RetentionPolicy.RUNTIME) -@Target(ElementType.METHOD) +@Target({ElementType.METHOD, ElementType.TYPE}) public @interface ScalarFunction { boolean enabled() default true; From 89597f3191d547ccbda696b168eea1de8723863d Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Mon, 22 Jan 2024 11:49:34 -0800 Subject: [PATCH 8/9] fix array function --- .../common/function/FunctionRegistry.java | 66 +++++++++++-------- .../function/scalar/ArrayFunctions.java | 5 +- .../function/schema/PinotScalarFunction.java | 18 ++++- .../sql/PinotCalciteCatalogReader.java | 1 + .../InbuiltFunctionEvaluatorTest.java | 2 +- .../operator/operands/FunctionOperand.java | 3 +- 6 files changed, 57 insertions(+), 38 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index ed2f50154451..fe9177ea9e01 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -89,10 +89,12 @@ private FunctionRegistry() { scalarFunctionNames.add(method.getName()); } boolean nullableParameters = scalarFunction.nullableParameters(); - registerFunction(method, scalarFunctionNames, nullableParameters); + boolean varArg = scalarFunction.isVarArg(); + registerFunction(method, scalarFunctionNames, nullableParameters, varArg); } } - Set> classes = PinotReflectionUtils.getClassesThroughReflection(".*\\.function\\..*", ScalarFunction.class); + Set> classes = + PinotReflectionUtils.getClassesThroughReflection(".*\\.function\\..*", ScalarFunction.class); for (Class clazz : classes) { if (!Modifier.isPublic(clazz.getModifiers())) { continue; @@ -105,7 +107,8 @@ private FunctionRegistry() { scalarFunctionNames.add(clazz.getName()); } boolean nullableParameters = scalarFunction.nullableParameters(); - registerFunction(clazz, scalarFunctionNames, nullableParameters); + boolean varArg = scalarFunction.isVarArg(); + registerFunction(clazz, scalarFunctionNames, nullableParameters, varArg); } } LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_MAP.map().size(), @@ -152,8 +155,8 @@ public static void init() { } @VisibleForTesting - public static void registerFunction(Method method, boolean nullableParameters) { - registerFunction(method, Collections.singleton(method.getName()), nullableParameters); + public static void registerFunction(Method method) { + registerFunction(method, Collections.singleton(method.getName()), false, false); } @VisibleForTesting @@ -197,10 +200,10 @@ public static FunctionInfo getFunctionInfo(SqlOperatorTable operatorTable, RelDa @Nullable private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionName, int numParams) { - List candidates = getFunctionMap() - .range(functionName, CASE_SENSITIVITY).stream() - .filter(e -> e.getValue() instanceof PinotScalarFunction && e.getValue().getParameters().size() == numParams) - .map(e -> (PinotScalarFunction) e.getValue()).collect(Collectors.toList()); + List candidates = getFunctionMap().range(functionName, CASE_SENSITIVITY).stream().filter( + e -> e.getValue() instanceof PinotScalarFunction && (e.getValue().getParameters().size() == numParams + || ((PinotScalarFunction) e.getValue()).isVarArgs())).map(e -> (PinotScalarFunction) e.getValue()) + .collect(Collectors.toList()); if (candidates.size() == 1) { return candidates.get(0).getFunctionInfo(); } else { @@ -213,9 +216,10 @@ private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionNa private static PinotScalarFunction getScalarFunction(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, String functionName, List argTypes) { List relArgTypes = convertArgumentTypes(typeFactory, argTypes); - SqlOperator sqlOperator = SqlUtil.lookupRoutine(operatorTable, typeFactory, - new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO), relArgTypes, null, null, SqlSyntax.FUNCTION, - SqlKind.OTHER_FUNCTION, SqlNameMatchers.withCaseSensitive(false), true); + SqlOperator sqlOperator = + SqlUtil.lookupRoutine(operatorTable, typeFactory, new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO), + relArgTypes, null, null, SqlSyntax.FUNCTION, SqlKind.OTHER_FUNCTION, + SqlNameMatchers.withCaseSensitive(false), true); if (sqlOperator instanceof SqlUserDefinedFunction) { Function function = ((SqlUserDefinedFunction) sqlOperator).getFunction(); if (function instanceof PinotScalarFunction) { @@ -233,33 +237,38 @@ public static NameMultimap getOperatorMap() { return OPERATOR_MAP; } - private static void registerFunction(Method method, Set alias, boolean nullableParameters) { + private static void registerFunction(Method method, Set alias, boolean nullableParameters, boolean varArg) { if (method.getAnnotation(Deprecated.class) == null) { for (String name : alias) { - registerCalciteNamedFunctionMap(name, method, nullableParameters); + registerCalciteNamedFunctionMap(name, method, nullableParameters, varArg); } } } - private static void registerFunction(Class clazz, Set alias, boolean nullableParameters) { + private static void registerFunction(Class clazz, Set alias, boolean nullableParameters, boolean varArg) { if (clazz.getAnnotation(Deprecated.class) == null) { for (String name : alias) { - registerCalciteNamedFunctionMap(name, clazz, nullableParameters); + registerCalciteNamedFunctionMap(name, clazz, nullableParameters, varArg); } } } - private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters) { - FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters)); + private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters, + boolean varArg) { + FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters, varArg)); } - private static void registerCalciteNamedFunctionMap(String name, Class clazz, boolean nullableParameters) { + private static void registerCalciteNamedFunctionMap(String name, Class clazz, boolean nullableParameters, + boolean varArg) { try { - SqlReturnTypeInference returnTypeInference = (SqlReturnTypeInference) clazz.getField("RETURN_TYPE_INFERENCE").get(null); - SqlOperandTypeChecker operandTypeChecker = (SqlOperandTypeChecker) clazz.getField("OPERAND_TYPE_CHECKER").get(null); + SqlReturnTypeInference returnTypeInference = + (SqlReturnTypeInference) clazz.getField("RETURN_TYPE_INFERENCE").get(null); + SqlOperandTypeChecker operandTypeChecker = + (SqlOperandTypeChecker) clazz.getField("OPERAND_TYPE_CHECKER").get(null); for (Method method : clazz.getMethods()) { if (method.getName().equals("eval")) { - FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters, operandTypeChecker, returnTypeInference)); + FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters, varArg, operandTypeChecker, + returnTypeInference)); } } } catch (Exception e) { @@ -269,9 +278,10 @@ private static void registerCalciteNamedFunctionMap(String name, Class clazz, private static void registerAggregateFunction(String functionName, AggregationFunctionType functionType) { if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { - PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, - functionType.getSqlKind(), functionType.getReturnTypeInference(), null, - functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); + PinotSqlAggFunction sqlAggFunction = + new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, functionType.getSqlKind(), + functionType.getReturnTypeInference(), null, functionType.getOperandTypeChecker(), + functionType.getSqlFunctionCategory()); OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlAggFunction); } } @@ -279,9 +289,9 @@ private static void registerAggregateFunction(String functionName, AggregationFu private static void registerTransformFunction(String functionName, TransformFunctionType functionType) { if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { PinotSqlTransformFunction sqlTransformFunction = - new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), - functionType.getSqlKind(), functionType.getReturnTypeInference(), null, - functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); + new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), functionType.getSqlKind(), + functionType.getReturnTypeInference(), null, functionType.getOperandTypeChecker(), + functionType.getSqlFunctionCategory()); OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlTransformFunction); } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java index aae61e94a7c8..ee767c1d8525 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java @@ -25,13 +25,10 @@ import it.unimi.dsi.fastutil.objects.ObjectSet; import java.math.BigDecimal; import java.util.Arrays; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SameOperandTypeChecker; -import org.apache.calcite.sql.type.SqlOperandCountRanges; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.commons.lang3.ArrayUtils; import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder; @@ -235,7 +232,7 @@ public static String arrayElementAtString(String[] arr, int idx) { return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING; } - @ScalarFunction(names = {"array", "arrayValueConstructor"}) + @ScalarFunction(names = {"array", "arrayValueConstructor"}, isVarArg = true) public static class ArrayValueConstructor { public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = ReturnTypes.TO_ARRAY; public static final SqlOperandTypeChecker OPERAND_TYPE_CHECKER = new SameOperandTypeChecker(-1); diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotScalarFunction.java index 5030dc47bb53..e277d958a344 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotScalarFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/schema/PinotScalarFunction.java @@ -39,16 +39,20 @@ public class PinotScalarFunction extends ReflectiveFunctionBase implements Pinot private final Method _method; private final SqlOperandTypeChecker _sqlOperandTypeChecker; private final SqlReturnTypeInference _sqlReturnTypeInference; + private final boolean _isNullableParameter; + private final boolean _isVarArgs; - public PinotScalarFunction(String name, Method method, boolean isNullableParameter) { - this(name, method, isNullableParameter, null, null); + public PinotScalarFunction(String name, Method method, boolean isNullableParameter, boolean isVarArg) { + this(name, method, isNullableParameter, isVarArg, null, null); } - public PinotScalarFunction(String name, Method method, boolean isNullableParameter, + public PinotScalarFunction(String name, Method method, boolean isNullableParameter, boolean isVarArgs, SqlOperandTypeChecker sqlOperandTypeChecker, SqlReturnTypeInference sqlReturnTypeInference) { super(method); _name = name; _method = method; + _isNullableParameter = isNullableParameter; + _isVarArgs = isVarArgs; _functionInfo = new FunctionInfo(method, method.getDeclaringClass(), isNullableParameter); _sqlOperandTypeChecker = sqlOperandTypeChecker; _sqlReturnTypeInference = sqlReturnTypeInference; @@ -80,4 +84,12 @@ public SqlOperandTypeChecker getOperandTypeChecker() { public SqlReturnTypeInference getReturnTypeInference() { return _sqlReturnTypeInference; } + + public boolean isNullableParameter() { + return _isNullableParameter; + } + + public boolean isVarArgs() { + return _isVarArgs; + } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java index e3f3d2b737ef..0d6a94105c42 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java @@ -76,6 +76,7 @@ import org.apache.calcite.sql.validate.SqlValidatorUtil; import org.apache.calcite.util.Optionality; import org.apache.calcite.util.Util; +import org.apache.pinot.common.function.schema.PinotScalarFunction; import org.checkerframework.checker.nullness.qual.Nullable; diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java index 5c6835e293f6..031cfd169ecd 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java @@ -130,7 +130,7 @@ public void testStateSharedBetweenRowsForExecution() throws Exception { MyFunc myFunc = new MyFunc(); Method method = myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class); - FunctionRegistry.registerFunction(method, false); + FunctionRegistry.registerFunction(method); String expression = "appendToStringAndReturn('test ')"; InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); assertTrue(evaluator.getArguments().isEmpty()); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java index 29419dc5e4d9..773f73058fb7 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java @@ -58,8 +58,7 @@ public FunctionOperand(SqlOperatorTable sqlOperatorTable, RelDataTypeFactory rel return e.getDataType(); } }).collect(Collectors.toList()); - FunctionInfo functionInfo = - FunctionRegistry.getFunctionInfo(sqlOperatorTable, relDataTypeFactory, canonicalName, operandTypes); + FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(canonicalName, operandTypes.size()); Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", canonicalName); _functionInvoker = new FunctionInvoker(functionInfo); if (!_functionInvoker.getMethod().isVarArgs()) { From ffb940b4e3a1f85f677de7e0431d804964476ea6 Mon Sep 17 00:00:00 2001 From: Rong Rong Date: Wed, 24 Jan 2024 10:48:18 -0800 Subject: [PATCH 9/9] fix lookup, using special pinot udf extension - support arrayElementAt - support arrayValuConstructor --- .../common/function/FunctionRegistry.java | 12 ++--- .../function/scalar/ArrayFunctions.java | 39 ++++++++++++++++ .../sql/PinotCalciteCatalogReader.java | 20 ++++++--- .../function/sql/PinotSqlScalarFunction.java | 44 +++++++++++++++++++ .../operator/operands/FunctionOperand.java | 6 ++- 5 files changed, 108 insertions(+), 13 deletions(-) create mode 100644 pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlScalarFunction.java diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index fe9177ea9e01..1f9242214611 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -47,6 +47,7 @@ import org.apache.pinot.common.function.schema.PinotFunction; import org.apache.pinot.common.function.schema.PinotScalarFunction; import org.apache.pinot.common.function.sql.PinotSqlAggFunction; +import org.apache.pinot.common.function.sql.PinotSqlScalarFunction; import org.apache.pinot.common.function.sql.PinotSqlTransformFunction; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.segment.spi.AggregationFunctionType; @@ -188,7 +189,7 @@ public static FunctionInfo getFunctionInfo(String functionName, int numParams) { */ @Nullable public static FunctionInfo getFunctionInfo(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, - String functionName, List argTypes) { + String functionName, List argTypes) { PinotScalarFunction scalarFunction = getScalarFunction(operatorTable, typeFactory, functionName, argTypes); if (scalarFunction != null) { return scalarFunction.getFunctionInfo(); @@ -214,17 +215,18 @@ private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionNa @Nullable private static PinotScalarFunction getScalarFunction(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, - String functionName, List argTypes) { - List relArgTypes = convertArgumentTypes(typeFactory, argTypes); + String functionName, List argTypes) { SqlOperator sqlOperator = SqlUtil.lookupRoutine(operatorTable, typeFactory, new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO), - relArgTypes, null, null, SqlSyntax.FUNCTION, SqlKind.OTHER_FUNCTION, + argTypes, null, null, SqlSyntax.FUNCTION, SqlKind.OTHER_FUNCTION, SqlNameMatchers.withCaseSensitive(false), true); if (sqlOperator instanceof SqlUserDefinedFunction) { Function function = ((SqlUserDefinedFunction) sqlOperator).getFunction(); if (function instanceof PinotScalarFunction) { return (PinotScalarFunction) function; } + } else if (sqlOperator instanceof PinotSqlScalarFunction) { + return ((PinotSqlScalarFunction) sqlOperator).getFunction(); } return null; } @@ -296,7 +298,7 @@ private static void registerTransformFunction(String functionName, TransformFunc } } - private static List convertArgumentTypes(RelDataTypeFactory typeFactory, + public static List convertArgumentTypes(RelDataTypeFactory typeFactory, List argTypes) { return argTypes.stream().map(type -> toRelType(typeFactory, type)).collect(Collectors.toList()); } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java index ee767c1d8525..caea2dc97345 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java @@ -18,6 +18,7 @@ */ package org.apache.pinot.common.function.scalar; +import com.google.common.base.Preconditions; import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntLinkedOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; @@ -25,10 +26,16 @@ import it.unimi.dsi.fastutil.objects.ObjectSet; import java.math.BigDecimal; import java.util.Arrays; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.type.NonNullableAccessors; +import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SameOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.commons.lang3.ArrayUtils; import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder; @@ -232,6 +239,38 @@ public static String arrayElementAtString(String[] arr, int idx) { return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING; } + @ScalarFunction(names = {"arrayElementAt", "array_element_at"}) + public static class ArrayElementAt { + public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = (opBinding) -> { + final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + final RelDataType operandType = opBinding.getOperandType(0); + Preconditions.checkState(operandType.getSqlTypeName() == SqlTypeName.ARRAY); + return typeFactory.createTypeWithNullability(NonNullableAccessors.getComponentTypeOrThrow(operandType), true); + }; + public static final SqlOperandTypeChecker OPERAND_TYPE_CHECKER = + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER); + + public static int eval(int[] arr, int idx) { + return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.INT; + } + + public static long eval(long[] arr, int idx) { + return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.LONG; + } + + public static float eval(float[] arr, int idx) { + return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.FLOAT; + } + + public static double eval(double[] arr, int idx) { + return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.DOUBLE; + } + + public static String eval(String[] arr, int idx) { + return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING; + } + } + @ScalarFunction(names = {"array", "arrayValueConstructor"}, isVarArg = true) public static class ArrayValueConstructor { public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = ReturnTypes.TO_ARRAY; diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java index 0d6a94105c42..34d4079962b7 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java @@ -59,6 +59,7 @@ import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlOperandMetadata; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandTypeInference; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeFamily; @@ -323,6 +324,17 @@ public static SqlOperatorTable operatorTable(String... classNames) { // ==================================================================== public static SqlOperator toOp(SqlIdentifier name, final org.apache.calcite.schema.Function function) { + + // TODO: support AGG and TABLE function in the future + if (function instanceof PinotScalarFunction && ((PinotScalarFunction) function).getOperandTypeChecker() != null) { + final SqlOperandTypeChecker operandTypeChecker = + ((PinotScalarFunction) function).getOperandTypeChecker(); + final SqlReturnTypeInference returnTypeInference = + ((PinotScalarFunction) function).getReturnTypeInference(); + final SqlKind kind = kind(function); + return new PinotSqlScalarFunction(name.toString(), kind, returnTypeInference, null, operandTypeChecker, + SqlFunctionCategory.USER_DEFINED_FUNCTION, ((PinotScalarFunction) function)); + } // ==================================================================== // LINES CHANGED ABOVE // ==================================================================== @@ -353,13 +365,7 @@ public static SqlOperator toOp(SqlIdentifier name, final List typeFamilies = typeFamiliesFactory.apply(dummyTypeFactory); - final SqlOperandTypeInference operandTypeInference; - if (function instanceof PinotScalarFunction && ((PinotScalarFunction) function).getOperandTypeChecker() != null) { - operandTypeInference = ((PinotScalarFunction) function).getOperandTypeChecker().typeInference(); - } else { - operandTypeInference = InferTypes.explicit(argTypes); - } - + final SqlOperandTypeInference operandTypeInference = InferTypes.explicit(argTypes); final SqlOperandMetadata operandMetadata = OperandTypes.operandMetadata(typeFamilies, paramTypesFactory, i -> function.getParameters().get(i).getName(), diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlScalarFunction.java new file mode 100644 index 000000000000..1cd6d3df0370 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlScalarFunction.java @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.sql; + +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlOperandTypeInference; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.pinot.common.function.schema.PinotScalarFunction; +import org.checkerframework.checker.nullness.qual.Nullable; + + +public class PinotSqlScalarFunction extends SqlFunction { + private final PinotScalarFunction _pinotScalarFunction; + + public PinotSqlScalarFunction(String name, SqlKind kind, @Nullable SqlReturnTypeInference returnTypeInference, + @Nullable SqlOperandTypeInference operandTypeInference, @Nullable SqlOperandTypeChecker operandTypeChecker, + SqlFunctionCategory category, PinotScalarFunction pinotScalarFunction) { + super(name, kind, returnTypeInference, operandTypeInference, operandTypeChecker, category); + _pinotScalarFunction = pinotScalarFunction; + } + + public PinotScalarFunction getFunction() { + return _pinotScalarFunction; + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java index 773f73058fb7..6c954caf0f2f 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.stream.Collectors; import javax.annotation.Nullable; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.sql.SqlOperatorTable; import org.apache.pinot.common.function.FunctionInfo; @@ -58,7 +59,10 @@ public FunctionOperand(SqlOperatorTable sqlOperatorTable, RelDataTypeFactory rel return e.getDataType(); } }).collect(Collectors.toList()); - FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(canonicalName, operandTypes.size()); + + List argTypes = FunctionRegistry.convertArgumentTypes(relDataTypeFactory, operandTypes); + FunctionInfo functionInfo = + FunctionRegistry.getFunctionInfo(sqlOperatorTable, relDataTypeFactory, canonicalName, argTypes); Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", canonicalName); _functionInvoker = new FunctionInvoker(functionInfo); if (!_functionInvoker.getMethod().isVarArgs()) {