diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index 97fa972bee18..e5ee9899d463 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -18,18 +18,22 @@ */ package org.apache.pinot.common.function; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import javax.annotation.Nullable; -import org.apache.calcite.schema.Function; -import org.apache.calcite.schema.impl.ScalarFunctionImpl; -import org.apache.calcite.util.NameMultimap; import org.apache.commons.lang3.StringUtils; +import org.apache.pinot.common.function.registry.PinotFunction; +import org.apache.pinot.common.function.registry.PinotScalarFunction; +import org.apache.pinot.common.function.sql.PinotFunctionRegistry; import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.utils.PinotReflectionUtils; import org.slf4j.Logger; @@ -38,19 +42,14 @@ /** * Registry for scalar functions. - *

TODO: Merge FunctionRegistry and FunctionDefinitionRegistry to provide one single registry for all functions. */ public class FunctionRegistry { - private FunctionRegistry() { - } - + public static final boolean CASE_SENSITIVITY = false; private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class); - - // TODO: consolidate the following 2 - // This FUNCTION_INFO_MAP is used by Pinot server to look up function by # of arguments private static final Map> FUNCTION_INFO_MAP = new HashMap<>(); - // This FUNCTION_MAP is used by Calcite function catalog to look up function by function signature. - private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); + + private FunctionRegistry() { + } /** * Registers the scalar functions via reflection. @@ -66,16 +65,13 @@ private FunctionRegistry() { } ScalarFunction scalarFunction = method.getAnnotation(ScalarFunction.class); if (scalarFunction.enabled()) { - // Annotated function names - String[] scalarFunctionNames = scalarFunction.names(); - boolean nullableParameters = scalarFunction.nullableParameters(); - if (scalarFunctionNames.length > 0) { - for (String name : scalarFunctionNames) { - FunctionRegistry.registerFunction(name, method, nullableParameters, scalarFunction.isPlaceholder()); - } - } else { - FunctionRegistry.registerFunction(method, nullableParameters, scalarFunction.isPlaceholder()); + // Parse annotated function names and alias + Set scalarFunctionNames = Arrays.stream(scalarFunction.names()).collect(Collectors.toSet()); + if (scalarFunctionNames.size() == 0) { + scalarFunctionNames.add(method.getName()); } + boolean nullableParameters = scalarFunction.nullableParameters(); + FunctionRegistry.registerFunction(method, scalarFunctionNames, nullableParameters); } } LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_INFO_MAP.size(), @@ -90,22 +86,54 @@ private FunctionRegistry() { public static void init() { } + @VisibleForTesting + public static void registerFunction(Method method, boolean nullableParameters) { + registerFunction(method, Collections.singleton(method.getName()), nullableParameters); + } + + @VisibleForTesting + public static Set getRegisteredCalciteFunctionNames() { + return PinotFunctionRegistry.getFunctionMap().map().keySet(); + } + /** - * Registers a method with the name of the method. + * Returns the full list of all registered ScalarFunction to Calcite. */ - public static void registerFunction(Method method, boolean nullableParameters, boolean isPlaceholder) { - registerFunction(method.getName(), method, nullableParameters, isPlaceholder); + public static Map> getRegisteredCalciteFunctionMap() { + return PinotFunctionRegistry.getFunctionMap().map(); } /** - * Registers a method with the given function name. + * Returns {@code true} if the given function name is registered, {@code false} otherwise. */ - public static void registerFunction(String functionName, Method method, boolean nullableParameters, - boolean isPlaceholder) { - if (!isPlaceholder) { - registerFunctionInfoMap(functionName, method, nullableParameters); + public static boolean containsFunction(String functionName) { + // TODO: remove deprecated FUNCTION_INFO_MAP + return PinotFunctionRegistry.getFunctionMap().containsKey(functionName, CASE_SENSITIVITY) + || FUNCTION_INFO_MAP.containsKey(canonicalize(functionName)); + } + + /** + * Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null} + * if there is no matching method. This method should be called after the FunctionRegistry is initialized and all + * methods are already registered. + */ + @Nullable + public static FunctionInfo getFunctionInfo(String functionName, int numParams) { + try { + return getFunctionInfoFromCalciteNamedMap(functionName, numParams); + } catch (IllegalArgumentException iae) { + // TODO: remove deprecated FUNCTION_INFO_MAP + return getFunctionInfoFromFunctionInfoMap(functionName, numParams); + } + } + + // TODO: remove deprecated FUNCTION_INFO_MAP + private static void registerFunction(Method method, Set alias, boolean nullableParameters) { + if (method.getAnnotation(Deprecated.class) == null) { + for (String name : alias) { + registerFunctionInfoMap(name, method, nullableParameters); + } } - registerCalciteNamedFunctionMap(functionName, method, nullableParameters); } private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters) { @@ -117,36 +145,25 @@ private static void registerFunctionInfoMap(String functionName, Method method, "Function: %s with %s parameters is already registered", functionName, method.getParameterCount()); } - private static void registerCalciteNamedFunctionMap(String functionName, Method method, boolean nullableParameters) { - if (method.getAnnotation(Deprecated.class) == null) { - FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method)); - } - } - - public static Map> getRegisteredCalciteFunctionMap() { - return FUNCTION_MAP.map(); - } - - public static Set getRegisteredCalciteFunctionNames() { - return FUNCTION_MAP.map().keySet(); - } - - /** - * Returns {@code true} if the given function name is registered, {@code false} otherwise. - */ - public static boolean containsFunction(String functionName) { - return FUNCTION_INFO_MAP.containsKey(canonicalize(functionName)); + @Nullable + private static FunctionInfo getFunctionInfoFromFunctionInfoMap(String functionName, int numParams) { + Map functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName)); + return functionInfoMap != null ? functionInfoMap.get(numParams) : null; } - /** - * Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null} - * if there is no matching method. This method should be called after the FunctionRegistry is initialized and all - * methods are already registered. - */ @Nullable - public static FunctionInfo getFunctionInfo(String functionName, int numParameters) { - Map functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName)); - return functionInfoMap != null ? functionInfoMap.get(numParameters) : null; + private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionName, int numParams) { + List candidates = PinotFunctionRegistry.getFunctionMap() + .range(functionName, CASE_SENSITIVITY).stream() + .filter(e -> e.getValue() instanceof PinotScalarFunction && e.getValue().getParameters().size() == numParams) + .map(e -> (PinotScalarFunction) e.getValue()).collect(Collectors.toList()); + if (candidates.size() == 1) { + return candidates.get(0).getFunctionInfo(); + } else { + throw new IllegalArgumentException( + "Unable to lookup function: " + functionName + " by parameter count: " + numParams + " Found " + + candidates.size() + " candidates. Try to use argument types to resolve the correct one!"); + } } private static String canonicalize(String functionName) { diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java index c2d7d3e24047..069503fcd493 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java @@ -138,12 +138,7 @@ public enum TransformFunctionType { ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)), "date_time_convert_window_hop"), - DATE_TRUNC("dateTrunc", - ReturnTypes.BIGINT_FORCE_NULLABLE, - OperandTypes.family( - ImmutableList.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, - SqlTypeFamily.CHARACTER), - ordinal -> ordinal > 1), "date_trunc"), + DATE_TRUNC("dateTrunc","date_trunc"), FROM_DATE_TIME("fromDateTime", ReturnTypes.TIMESTAMP_NULLABLE, OperandTypes.family(ImmutableList.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER), diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotFunction.java new file mode 100644 index 000000000000..f0e756513739 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotFunction.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.registry; + +import org.apache.calcite.schema.Function; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlReturnTypeInference; + + +public interface PinotFunction extends Function { + SqlOperandTypeChecker getOperandTypeChecker(); + SqlReturnTypeInference getReturnTypeInference(); +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotScalarFunction.java new file mode 100644 index 000000000000..c1708aab4203 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/registry/PinotScalarFunction.java @@ -0,0 +1,83 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.registry; + +import java.lang.reflect.Method; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.schema.ScalarFunction; +import org.apache.calcite.schema.impl.ReflectiveFunctionBase; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.pinot.common.function.FunctionInfo; + + +/** + * Pinot specific implementation of the {@link ScalarFunction}. + * + * @see "{@link org.apache.calcite.schema.impl.ScalarFunctionImpl}" + */ +public class PinotScalarFunction extends ReflectiveFunctionBase implements PinotFunction, ScalarFunction { + private final FunctionInfo _functionInfo; + private final String _name; + private final Method _method; + private final SqlOperandTypeChecker _sqlOperandTypeChecker; + private final SqlReturnTypeInference _sqlReturnTypeInference; + + public PinotScalarFunction(String name, Method method, boolean isNullableParameter) { + this(name, method, isNullableParameter, null, null); + } + + public PinotScalarFunction(String name, Method method, boolean isNullableParameter, + SqlOperandTypeChecker sqlOperandTypeChecker, SqlReturnTypeInference sqlReturnTypeInference) { + super(method); + _name = name; + _method = method; + _functionInfo = new FunctionInfo(method, method.getDeclaringClass(), isNullableParameter); + _sqlOperandTypeChecker = sqlOperandTypeChecker; + _sqlReturnTypeInference = sqlReturnTypeInference; + } + + @Override + public RelDataType getReturnType(RelDataTypeFactory typeFactory) { + return typeFactory.createJavaType(method.getReturnType()); + } + + public String getName() { + return _name; + } + + public Method getMethod() { + return _method; + } + + public FunctionInfo getFunctionInfo() { + return _functionInfo; + } + + @Override + public SqlOperandTypeChecker getOperandTypeChecker() { + return _sqlOperandTypeChecker; + } + + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return _sqlReturnTypeInference; + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeFunctions.java index 16dfad75c7ab..5f9807d809cc 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeFunctions.java @@ -18,9 +18,16 @@ */ package org.apache.pinot.common.function.scalar; +import com.google.common.collect.ImmutableList; import java.sql.Timestamp; import java.time.Duration; import java.util.concurrent.TimeUnit; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.pinot.common.function.DateTimePatternHandler; import org.apache.pinot.common.function.DateTimeUtils; import org.apache.pinot.common.function.TimeZoneKey; @@ -1058,20 +1065,64 @@ public static int[] millisecondMV(long[] millis, String timezoneId) { return results; } + /** * The sql compatible date_trunc function for epoch time. - * - * @param unit truncate to unit (millisecond, second, minute, hour, day, week, month, quarter, year) - * @param timeValue value to truncate - * @return truncated timeValue in TimeUnit.MILLISECONDS */ @ScalarFunction(names = {"dateTrunc", "date_trunc"}) + public static class dateTruncScalarFunctions { + public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = ReturnTypes.BIGINT_FORCE_NULLABLE; + public static final SqlOperandTypeChecker OPERAND_TYPE_CHECKER = OperandTypes.family( + ImmutableList.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER), + ordinal -> ordinal > 1); + + public static long eval(String unit, long timeValue) { + return dateTrunc(unit, timeValue, TimeUnit.MILLISECONDS.name(), ISOChronology.getInstanceUTC(), + TimeUnit.MILLISECONDS.name()); + } + + public static long eval(String unit, long timeValue, String inputTimeUnit) { + return dateTrunc(unit, timeValue, inputTimeUnit, ISOChronology.getInstanceUTC(), inputTimeUnit); + } + public static long eval(String unit, long timeValue, String inputTimeUnit, String timeZone) { + return dateTrunc(unit, timeValue, inputTimeUnit, DateTimeUtils.getChronology(TimeZoneKey.getTimeZoneKey(timeZone)), + inputTimeUnit); + } + public static long eval(String unit, long timeValue, String inputTimeUnit, String timeZone, + String outputTimeUnit) { + return dateTrunc(unit, timeValue, inputTimeUnit, DateTimeUtils.getChronology(TimeZoneKey.getTimeZoneKey(timeZone)), + outputTimeUnit); + } + } + + + @ScalarFunction(names = {"dateTruncMV", "date_trunc_mv"}) + public static class dateTruncMvScalarFunction { + public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = ReturnTypes.BIGINT_FORCE_NULLABLE.andThen(SqlTypeTransforms.TO_ARRAY); + public static final SqlOperandTypeChecker OPERAND_TYPE_CHECKER = OperandTypes.family( + ImmutableList.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.ARRAY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER), ordinal -> ordinal > 1); + public static long[] eval(String unit, long[] timeValue) { + return dateTruncMV(unit, timeValue); + } + public static long[] eval(String unit, long[] timeValue, String inputTimeUnit) { + return dateTruncMV(unit, timeValue, inputTimeUnit); + } + public static long[] eval(String unit, long[] timeValue, String inputTimeUnit, String timeZone) { + return dateTruncMV(unit, timeValue, inputTimeUnit, timeZone); + } + public static long[] eval(String unit, long[] timeValue, String inputTimeUnit, String timeZone, + String outputTimeUnit) { + return dateTruncMV(unit, timeValue, inputTimeUnit, timeZone, outputTimeUnit); + } + } + public static long dateTrunc(String unit, long timeValue) { return dateTrunc(unit, timeValue, TimeUnit.MILLISECONDS.name(), ISOChronology.getInstanceUTC(), TimeUnit.MILLISECONDS.name()); } - @ScalarFunction(names = {"dateTruncMV", "date_trunc_mv"}) public static long[] dateTruncMV(String unit, long[] timeValue) { long[] results = new long[timeValue.length]; for (int i = 0; i < timeValue.length; i++) { @@ -1080,20 +1131,10 @@ public static long[] dateTruncMV(String unit, long[] timeValue) { return results; } - /** - * The sql compatible date_trunc function for epoch time. - * - * @param unit truncate to unit (millisecond, second, minute, hour, day, week, month, quarter, year) - * @param timeValue value to truncate - * @param inputTimeUnit TimeUnit of value, expressed in Java's joda TimeUnit - * @return truncated timeValue in same TimeUnit as the input - */ - @ScalarFunction(names = {"dateTrunc", "date_trunc"}) public static long dateTrunc(String unit, long timeValue, String inputTimeUnit) { return dateTrunc(unit, timeValue, inputTimeUnit, ISOChronology.getInstanceUTC(), inputTimeUnit); } - @ScalarFunction(names = {"dateTruncMV", "date_trunc_mv"}) public static long[] dateTruncMV(String unit, long[] timeValue, String inputTimeUnit) { long[] results = new long[timeValue.length]; for (int i = 0; i < timeValue.length; i++) { @@ -1102,22 +1143,11 @@ public static long[] dateTruncMV(String unit, long[] timeValue, String inputTime return results; } - /** - * The sql compatible date_trunc function for epoch time. - * - * @param unit truncate to unit (millisecond, second, minute, hour, day, week, month, quarter, year) - * @param timeValue value to truncate - * @param inputTimeUnit TimeUnit of value, expressed in Java's joda TimeUnit - * @param timeZone timezone of the input - * @return truncated timeValue in same TimeUnit as the input - */ - @ScalarFunction(names = {"dateTrunc", "date_trunc"}) public static long dateTrunc(String unit, long timeValue, String inputTimeUnit, String timeZone) { return dateTrunc(unit, timeValue, inputTimeUnit, DateTimeUtils.getChronology(TimeZoneKey.getTimeZoneKey(timeZone)), inputTimeUnit); } - @ScalarFunction(names = {"dateTruncMV", "date_trunc_mv"}) public static long[] dateTruncMV(String unit, long[] timeValue, String inputTimeUnit, String timeZone) { long[] results = new long[timeValue.length]; for (int i = 0; i < timeValue.length; i++) { @@ -1126,25 +1156,12 @@ public static long[] dateTruncMV(String unit, long[] timeValue, String inputTime return results; } - /** - * The sql compatible date_trunc function for epoch time. - * - * @param unit truncate to unit (millisecond, second, minute, hour, day, week, month, quarter, year) - * @param timeValue value to truncate - * @param inputTimeUnit TimeUnit of value, expressed in Java's joda TimeUnit - * @param timeZone timezone of the input - * @param outputTimeUnit TimeUnit to convert the output to - * @return truncated timeValue - * - */ - @ScalarFunction(names = {"dateTrunc", "date_trunc"}) public static long dateTrunc(String unit, long timeValue, String inputTimeUnit, String timeZone, String outputTimeUnit) { return dateTrunc(unit, timeValue, inputTimeUnit, DateTimeUtils.getChronology(TimeZoneKey.getTimeZoneKey(timeZone)), outputTimeUnit); } - @ScalarFunction(names = {"dateTruncMV", "date_trunc_mv"}) public static long[] dateTruncMV(String unit, long[] timeValue, String inputTimeUnit, String timeZone, String outputTimeUnit) { long[] results = new long[timeValue.length]; @@ -1154,7 +1171,7 @@ public static long[] dateTruncMV(String unit, long[] timeValue, String inputTime return results; } - private static long dateTrunc(String unit, long timeValue, String inputTimeUnit, ISOChronology chronology, + public static long dateTrunc(String unit, long timeValue, String inputTimeUnit, ISOChronology chronology, String outputTimeUnit) { return TimeUnit.valueOf(outputTimeUnit.toUpperCase()).convert(DateTimeUtils.getTimestampField(chronology, unit) .roundFloor(TimeUnit.MILLISECONDS.convert(timeValue, TimeUnit.valueOf(inputTimeUnit.toUpperCase()))), diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java index 5a49314943ba..3af1405f6cdc 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java @@ -140,7 +140,7 @@ public static String substring(String input, int beginIndex, int length) { * @param seperator * @return The two input strings joined by the seperator */ - @ScalarFunction(names = "concat_ws") + @ScalarFunction(names = {"concatWS", "concat_ws"}) public static String concatws(String seperator, String input1, String input2) { return concat(input1, input2, seperator); } diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/prepare/PinotCalciteCatalogReader.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java similarity index 93% rename from pinot-query-planner/src/main/java/org/apache/calcite/prepare/PinotCalciteCatalogReader.java rename to pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java index 84c71be601f0..e536a7bdfa70 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/prepare/PinotCalciteCatalogReader.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.calcite.prepare; +package org.apache.pinot.common.function.sql; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; @@ -35,6 +35,8 @@ import org.apache.calcite.linq4j.function.Hints; import org.apache.calcite.model.ModelHandler; import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.prepare.RelOptTableImpl; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeFactoryImpl; @@ -74,6 +76,7 @@ import org.apache.calcite.sql.validate.SqlValidatorUtil; import org.apache.calcite.util.Optionality; import org.apache.calcite.util.Util; +import org.apache.pinot.common.function.registry.PinotScalarFunction; import org.checkerframework.checker.nullness.qual.Nullable; @@ -310,7 +313,7 @@ public static SqlOperatorTable operatorTable(String... classNames) { } /** Converts a function to a {@link org.apache.calcite.sql.SqlOperator}. */ - private static SqlOperator toOp(SqlIdentifier name, + public static SqlOperator toOp(SqlIdentifier name, final org.apache.calcite.schema.Function function) { final Function> argTypesFactory = typeFactory -> function.getParameters() @@ -339,8 +342,12 @@ private static SqlOperator toOp(SqlIdentifier name, final List typeFamilies = typeFamiliesFactory.apply(dummyTypeFactory); - final SqlOperandTypeInference operandTypeInference = - InferTypes.explicit(argTypes); + final SqlOperandTypeInference operandTypeInference; + if (function instanceof PinotScalarFunction && ((PinotScalarFunction) function).getOperandTypeChecker() != null) { + operandTypeInference = ((PinotScalarFunction) function).getOperandTypeChecker().typeInference(); + } else { + operandTypeInference = InferTypes.explicit(argTypes); + } final SqlOperandMetadata operandMetadata = OperandTypes.operandMetadata(typeFamilies, paramTypesFactory, @@ -389,17 +396,20 @@ private static SqlKind kind(org.apache.calcite.schema.Function function) { } private static SqlReturnTypeInference infer(final ScalarFunction function) { - return opBinding -> { - final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); - final RelDataType type; - if (function instanceof ScalarFunctionImpl) { - type = ((ScalarFunctionImpl) function).getReturnType(typeFactory, - opBinding); - } else { - type = function.getReturnType(typeFactory); - } - return toSql(typeFactory, type); - }; + if (function instanceof PinotScalarFunction && ((PinotScalarFunction) function).getReturnTypeInference() != null) { + return ((PinotScalarFunction) function).getReturnTypeInference(); + } else { + return opBinding -> { + final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + final RelDataType type; + if (function instanceof ScalarFunctionImpl) { + type = ((ScalarFunctionImpl) function).getReturnType(typeFactory, opBinding); + } else { + type = function.getReturnType(typeFactory); + } + return toSql(typeFactory, type); + }; + } } private static SqlReturnTypeInference infer( diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotFunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotFunctionRegistry.java new file mode 100644 index 000000000000..d21e3dd261d0 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotFunctionRegistry.java @@ -0,0 +1,277 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.sql; + +import com.google.common.annotations.VisibleForTesting; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.schema.Function; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.SqlUtil; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlNameMatchers; +import org.apache.calcite.sql.validate.SqlUserDefinedFunction; +import org.apache.calcite.util.NameMultimap; +import org.apache.pinot.common.function.TransformFunctionType; +import org.apache.pinot.common.function.registry.PinotFunction; +import org.apache.pinot.common.function.registry.PinotScalarFunction; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.segment.spi.AggregationFunctionType; +import org.apache.pinot.spi.annotations.ScalarFunction; +import org.apache.pinot.spi.utils.PinotReflectionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Registry for scalar functions. + */ +public class PinotFunctionRegistry { + private static final Logger LOGGER = LoggerFactory.getLogger(PinotFunctionRegistry.class); + private static final NameMultimap OPERATOR_MAP = new NameMultimap<>(); + private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); + + private PinotFunctionRegistry() { + } + + /** + * Registers the scalar functions via reflection. + * NOTE: In order to plugin methods using reflection, the methods should be inside a class that includes ".function." + * in its class path. This convention can significantly reduce the time of class scanning. + */ + static { + // REGISTER FUNCTIONS + long startTimeMs = System.currentTimeMillis(); + Set methods = PinotReflectionUtils.getMethodsThroughReflection(".*\\.function\\..*", ScalarFunction.class); + for (Method method : methods) { + if (!Modifier.isPublic(method.getModifiers())) { + continue; + } + ScalarFunction scalarFunction = method.getAnnotation(ScalarFunction.class); + if (scalarFunction.enabled()) { + // Parse annotated function names and alias + Set scalarFunctionNames = Arrays.stream(scalarFunction.names()).collect(Collectors.toSet()); + if (scalarFunctionNames.size() == 0) { + scalarFunctionNames.add(method.getName()); + } + boolean nullableParameters = scalarFunction.nullableParameters(); + PinotFunctionRegistry.registerFunction(method, scalarFunctionNames, nullableParameters); + } + } + Set> classes = PinotReflectionUtils.getClassesThroughReflection(".*\\.function\\..*", ScalarFunction.class); + for (Class clazz : classes) { + if (!Modifier.isPublic(clazz.getModifiers())) { + continue; + } + ScalarFunction scalarFunction = clazz.getAnnotation(ScalarFunction.class); + if (scalarFunction.enabled()) { + // Parse annotated function names and alias + Set scalarFunctionNames = Arrays.stream(scalarFunction.names()).collect(Collectors.toSet()); + if (scalarFunctionNames.size() == 0) { + scalarFunctionNames.add(clazz.getName()); + } + boolean nullableParameters = scalarFunction.nullableParameters(); + PinotFunctionRegistry.registerFunction(clazz, scalarFunctionNames, nullableParameters); + } + } + LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_MAP.map().size(), + FUNCTION_MAP.map().keySet(), System.currentTimeMillis() - startTimeMs); + + // REGISTER OPERATORS + // Walk through all the Pinot aggregation types and + // 1. register those that are supported in multistage in addition to calcite standard opt table. + // 2. register special handling that differs from calcite standard. + for (AggregationFunctionType aggregationFunctionType : AggregationFunctionType.values()) { + if (aggregationFunctionType.getSqlKind() != null) { + // 1. Register the aggregation function with Calcite + registerAggregateFunction(aggregationFunctionType.getName(), aggregationFunctionType); + // 2. Register the aggregation function with Calcite on all alternative names + List alternativeFunctionNames = aggregationFunctionType.getAlternativeNames(); + for (String alternativeFunctionName : alternativeFunctionNames) { + registerAggregateFunction(alternativeFunctionName, aggregationFunctionType); + } + } + } + + // Walk through all the Pinot transform types and + // 1. register those that are supported in multistage in addition to calcite standard opt table. + // 2. register special handling that differs from calcite standard. + for (TransformFunctionType transformFunctionType : TransformFunctionType.values()) { + if (transformFunctionType.getSqlKind() != null) { + // 1. Register the transform function with Calcite + registerTransformFunction(transformFunctionType.getName(), transformFunctionType); + // 2. Register the transform function with Calcite on all alternative names + List alternativeFunctionNames = transformFunctionType.getAlternativeNames(); + for (String alternativeFunctionName : alternativeFunctionNames) { + registerTransformFunction(alternativeFunctionName, transformFunctionType); + } + } + } + } + + public static void init() { + } + + @VisibleForTesting + public static void registerFunction(Method method, boolean nullableParameters) { + registerFunction(method, Collections.singleton(method.getName()), nullableParameters); + } + + public static NameMultimap getFunctionMap() { + return FUNCTION_MAP; + } + + public static NameMultimap getOperatorMap() { + return OPERATOR_MAP; + } + + @Nullable + public static PinotScalarFunction getScalarFunction(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, + String functionName, List argTypes) { + List relArgTypes = convertArgumentTypes(typeFactory, argTypes); + SqlOperator sqlOperator = SqlUtil.lookupRoutine(operatorTable, typeFactory, + new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO), relArgTypes, null, null, SqlSyntax.FUNCTION, + SqlKind.OTHER_FUNCTION, SqlNameMatchers.withCaseSensitive(false), true); + if (sqlOperator instanceof SqlUserDefinedFunction) { + Function function = ((SqlUserDefinedFunction) sqlOperator).getFunction(); + if (function instanceof PinotScalarFunction) { + return (PinotScalarFunction) function; + } + } + return null; + } + + private static void registerFunction(Method method, Set alias, boolean nullableParameters) { + if (method.getAnnotation(Deprecated.class) == null) { + for (String name : alias) { + registerCalciteNamedFunctionMap(name, method, nullableParameters); + } + } + } + + private static void registerFunction(Class clazz, Set alias, boolean nullableParameters) { + if (clazz.getAnnotation(Deprecated.class) == null) { + for (String name : alias) { + registerCalciteNamedFunctionMap(name, clazz, nullableParameters); + } + } + } + + private static void registerCalciteNamedFunctionMap(String name, Class clazz, boolean nullableParameters) { + try { + SqlReturnTypeInference returnTypeInference = (SqlReturnTypeInference) clazz.getField("RETURN_TYPE_INFERENCE").get(null); + SqlOperandTypeChecker operandTypeChecker = (SqlOperandTypeChecker) clazz.getField("OPERAND_TYPE_CHECKER").get(null); + for (Method method : clazz.getMethods()) { + if (method.getName().equals("eval")) { + FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters, operandTypeChecker, returnTypeInference)); + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static void registerCalciteNamedFunctionMap(String name, Method method, boolean nullableParameters) { + FUNCTION_MAP.put(name, new PinotScalarFunction(name, method, nullableParameters)); + } + + private static List convertArgumentTypes(RelDataTypeFactory typeFactory, + List argTypes) { + return argTypes.stream().map(type -> toRelType(typeFactory, type)).collect(Collectors.toList()); + } + + private static RelDataType toRelType(RelDataTypeFactory typeFactory, DataSchema.ColumnDataType dataType) { + switch (dataType) { + case INT: + return typeFactory.createSqlType(SqlTypeName.INTEGER); + case LONG: + return typeFactory.createSqlType(SqlTypeName.BIGINT); + case FLOAT: + return typeFactory.createSqlType(SqlTypeName.REAL); + case DOUBLE: + return typeFactory.createSqlType(SqlTypeName.DOUBLE); + case BIG_DECIMAL: + return typeFactory.createSqlType(SqlTypeName.DECIMAL); + case BOOLEAN: + return typeFactory.createSqlType(SqlTypeName.BOOLEAN); + case TIMESTAMP: + return typeFactory.createSqlType(SqlTypeName.TIMESTAMP); + case JSON: + case STRING: + return typeFactory.createSqlType(SqlTypeName.VARCHAR); + case BYTES: + return typeFactory.createSqlType(SqlTypeName.VARBINARY); + case INT_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.INTEGER), -1); + case LONG_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.BIGINT), -1); + case FLOAT_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.REAL), -1); + case DOUBLE_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.DOUBLE), -1); + case BOOLEAN_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.BOOLEAN), -1); + case TIMESTAMP_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.TIMESTAMP), -1); + case STRING_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.VARCHAR), -1); + case BYTES_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.VARBINARY), -1); + case UNKNOWN: + case OBJECT: + default: + return typeFactory.createSqlType(SqlTypeName.ANY); + } + } + + private static void registerAggregateFunction(String functionName, AggregationFunctionType functionType) { + if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { + PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, + functionType.getSqlKind(), functionType.getReturnTypeInference(), null, + functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); + OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlAggFunction); + } + } + + private static void registerTransformFunction(String functionName, TransformFunctionType functionType) { + if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { + PinotSqlTransformFunction sqlTransformFunction = + new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), + functionType.getSqlKind(), functionType.getReturnTypeInference(), null, + functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); + OPERATOR_MAP.put(functionName.toUpperCase(Locale.ROOT), sqlTransformFunction); + } + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotOperatorTable.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotOperatorTable.java new file mode 100644 index 000000000000..ca4513a5ba73 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotOperatorTable.java @@ -0,0 +1,101 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.sql; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.validate.SqlNameMatcher; +import org.apache.pinot.common.function.FunctionRegistry; +import org.checkerframework.checker.nullness.qual.Nullable; + + +/** + * Temporary implementation of all dynamic arg/return type inference operators. + * TODO: merge this with {@link PinotCalciteCatalogReader} once we support + * 1. Return/Inference configuration in @ScalarFunction + * 2. Allow @ScalarFunction registry towards class (with multiple impl) + */ +public class PinotOperatorTable implements SqlOperatorTable { + private static final PinotOperatorTable INSTANCE = new PinotOperatorTable(); + + public static synchronized PinotOperatorTable instance() { + return INSTANCE; + } + + @Override public void lookupOperatorOverloads(SqlIdentifier opName, + @Nullable SqlFunctionCategory category, SqlSyntax syntax, + List operatorList, SqlNameMatcher nameMatcher) { + String simpleName = opName.getSimple(); + final Collection list = + lookUpOperators(simpleName); + if (list.isEmpty()) { + return; + } + for (SqlOperator op : list) { + if (op.getSyntax() == syntax) { + operatorList.add(op); + } else if (syntax == SqlSyntax.FUNCTION + && op instanceof SqlFunction) { + // this special case is needed for operators like CAST, + // which are treated as functions but have special syntax + operatorList.add(op); + } + } + + // REVIEW jvs 1-Jan-2005: why is this extra lookup required? + // Shouldn't it be covered by search above? + switch (syntax) { + case BINARY: + case PREFIX: + case POSTFIX: + for (SqlOperator extra + : lookUpOperators(simpleName)) { + // REVIEW: should only search operators added during this method? + if (extra != null && !operatorList.contains(extra)) { + operatorList.add(extra); + } + } + break; + default: + break; + } + } + + /** + * Look up operators based on case-sensitiveness. + */ + private Collection lookUpOperators(String name) { + return PinotFunctionRegistry.getOperatorMap().range(name, FunctionRegistry.CASE_SENSITIVITY).stream() + .map(Map.Entry::getValue).collect(Collectors.toSet()); + } + + @Override + public List getOperatorList() { + return PinotFunctionRegistry.getOperatorMap().map().values().stream().flatMap(List::stream) + .collect(Collectors.toList()); + } +} diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlAggFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlAggFunction.java similarity index 91% rename from pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlAggFunction.java rename to pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlAggFunction.java index 0d4146f4317d..bc8f55474e00 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlAggFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlAggFunction.java @@ -16,8 +16,12 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.calcite.sql; +package org.apache.pinot.common.function.sql; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandTypeInference; import org.apache.calcite.sql.type.SqlReturnTypeInference; diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlTransformFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlTransformFunction.java similarity index 89% rename from pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlTransformFunction.java rename to pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlTransformFunction.java index 827c9f37337b..c2e8ce8130c1 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlTransformFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlTransformFunction.java @@ -16,8 +16,11 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.calcite.sql; +package org.apache.pinot.common.function.sql; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandTypeInference; import org.apache.calcite.sql.type.SqlReturnTypeInference; diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java index 82be9bcf52c3..5c6835e293f6 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java @@ -30,7 +30,6 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; -import static org.testng.Assert.fail; public class InbuiltFunctionEvaluatorTest { @@ -131,7 +130,7 @@ public void testStateSharedBetweenRowsForExecution() throws Exception { MyFunc myFunc = new MyFunc(); Method method = myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class); - FunctionRegistry.registerFunction(method, false, false); + FunctionRegistry.registerFunction(method, false); String expression = "appendToStringAndReturn('test ')"; InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); assertTrue(evaluator.getArguments().isEmpty()); @@ -158,21 +157,6 @@ public void testNullReturnedByInbuiltFunctionEvaluatorThatCannotTakeNull() { } } - @Test - public void testPlaceholderFunctionShouldNotBeRegistered() - throws Exception { - GenericRow row = new GenericRow(); - row.putValue("testColumn", "testValue"); - String expression = "text_match(testColumn, 'pattern')"; - try { - InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); - evaluator.evaluate(row); - fail(); - } catch (Throwable t) { - assertTrue(t.getMessage().contains("text_match")); - } - } - public static class MyFunc { String _baseString = ""; diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java index 0c7b0e3e52dc..a69f537687a3 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java @@ -56,7 +56,7 @@ public void testPostAggregationFunction() { assertEquals(function.invoke(new Object[]{"1234567890"}), "0987654321"); // ST_AsText - function = new PostAggregationFunction("ST_AsText", new ColumnDataType[]{ColumnDataType.BYTES}); + function = new PostAggregationFunction("ST_As_Text", new ColumnDataType[]{ColumnDataType.BYTES}); assertEquals(function.getResultType(), ColumnDataType.STRING); assertEquals(function.invoke( new Object[]{GeometrySerializer.serialize(GeometryUtils.GEOMETRY_FACTORY.createPoint(new Coordinate(10, 20)))}), diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java b/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java index edb2d74bf07c..adadcd6992b6 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java +++ b/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java @@ -23,7 +23,8 @@ import org.apache.calcite.schema.Function; import org.apache.calcite.schema.Schema; import org.apache.calcite.schema.SchemaPlus; -import org.apache.pinot.common.function.FunctionRegistry; +import org.apache.pinot.common.function.registry.PinotFunction; +import org.apache.pinot.common.function.sql.PinotFunctionRegistry; /** @@ -54,7 +55,7 @@ private CalciteSchemaBuilder() { public static CalciteSchema asRootSchema(Schema root) { CalciteSchema rootSchema = CalciteSchema.createRootSchema(false, false, "", root); SchemaPlus schemaPlus = rootSchema.plus(); - for (Map.Entry> e : FunctionRegistry.getRegisteredCalciteFunctionMap().entrySet()) { + for (Map.Entry> e : PinotFunctionRegistry.getFunctionMap().map().entrySet()) { for (Function f : e.getValue()) { schemaPlus.add(e.getKey(), f); } diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java index df904123d201..68e331eb5d8b 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java +++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java @@ -40,7 +40,6 @@ import org.apache.calcite.rel.logical.PinotLogicalExchange; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.PinotSqlAggFunction; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.OperandTypes; @@ -53,6 +52,7 @@ import org.apache.calcite.util.mapping.Mapping; import org.apache.calcite.util.mapping.MappingType; import org.apache.calcite.util.mapping.Mappings; +import org.apache.pinot.common.function.sql.PinotSqlAggFunction; import org.apache.pinot.query.planner.plannode.AggregateNode.AggType; import org.apache.pinot.segment.spi.AggregationFunctionType; diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotSqlCoalesceFunction.java b/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlCoalesceFunction.java similarity index 90% rename from pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotSqlCoalesceFunction.java rename to pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlCoalesceFunction.java index 92ef85857f9a..2b2ee32f083f 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotSqlCoalesceFunction.java +++ b/pinot-query-planner/src/main/java/org/apache/calcite/sql/PinotSqlCoalesceFunction.java @@ -16,10 +16,9 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.calcite.sql.fun; +package org.apache.calcite.sql; -import org.apache.calcite.sql.SqlCall; -import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.fun.SqlCoalesceFunction; import org.apache.calcite.sql.validate.SqlValidator; diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java deleted file mode 100644 index 3617a7c06270..000000000000 --- a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java +++ /dev/null @@ -1,171 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.calcite.sql.fun; - -import java.lang.reflect.Field; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; -import org.apache.calcite.sql.PinotSqlAggFunction; -import org.apache.calcite.sql.PinotSqlTransformFunction; -import org.apache.calcite.sql.SqlFunction; -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.validate.SqlNameMatchers; -import org.apache.calcite.util.Util; -import org.apache.pinot.common.function.TransformFunctionType; -import org.apache.pinot.segment.spi.AggregationFunctionType; -import org.checkerframework.checker.nullness.qual.MonotonicNonNull; - - -/** - * {@link PinotOperatorTable} defines the {@link SqlOperator} overrides on top of the {@link SqlStdOperatorTable}. - * - *

The main purpose of this Pinot specific SQL operator table is to - *

- */ -@SuppressWarnings("unused") // unused fields are accessed by reflection -public class PinotOperatorTable extends SqlStdOperatorTable { - - private static @MonotonicNonNull PinotOperatorTable _instance; - - // TODO: clean up lazy init by using Suppliers.memorized(this::computeInstance) and make getter wrapped around - // supplier instance. this should replace all lazy init static objects in the codebase - public static synchronized PinotOperatorTable instance() { - if (_instance == null) { - // Creates and initializes the standard operator table. - // Uses two-phase construction, because we can't initialize the - // table until the constructor of the sub-class has completed. - _instance = new PinotOperatorTable(); - _instance.initNoDuplicate(); - } - return _instance; - } - - /** - * Initialize without duplicate, e.g. when 2 duplicate operator is linked with the same op - * {@link org.apache.calcite.sql.SqlKind} it causes problem. - * - *

This is a direct copy of the {@link org.apache.calcite.sql.util.ReflectiveSqlOperatorTable} and can be hard to - * debug, suggest changing to a non-dynamic registration. Dynamic function support should happen via catalog. - * - * This also registers aggregation functions defined in {@link org.apache.pinot.segment.spi.AggregationFunctionType} - * which are multistage enabled. - */ - public final void initNoDuplicate() { - // Pinot supports native COALESCE function, thus no need to create CASE WHEN conversion. - register(new PinotSqlCoalesceFunction()); - // Ensure ArrayValueConstructor is registered before ArrayQueryConstructor - register(ARRAY_VALUE_CONSTRUCTOR); - - // TODO: reflection based registration is not ideal, we should use a static list of operators and register them - // Use reflection to register the expressions stored in public fields. - for (Field field : getClass().getFields()) { - try { - if (SqlFunction.class.isAssignableFrom(field.getType())) { - SqlFunction op = (SqlFunction) field.get(this); - if (op != null && notRegistered(op)) { - register(op); - } - } else if ( - SqlOperator.class.isAssignableFrom(field.getType())) { - SqlOperator op = (SqlOperator) field.get(this); - if (op != null && notRegistered(op)) { - register(op); - } - } - } catch (IllegalArgumentException | IllegalAccessException e) { - throw Util.throwAsRuntime(Util.causeOrSelf(e)); - } - } - - // Walk through all the Pinot aggregation types and - // 1. register those that are supported in multistage in addition to calcite standard opt table. - // 2. register special handling that differs from calcite standard. - for (AggregationFunctionType aggregationFunctionType : AggregationFunctionType.values()) { - if (aggregationFunctionType.getSqlKind() != null) { - // 1. Register the aggregation function with Calcite - registerAggregateFunction(aggregationFunctionType.getName(), aggregationFunctionType); - // 2. Register the aggregation function with Calcite on all alternative names - List alternativeFunctionNames = aggregationFunctionType.getAlternativeNames(); - for (String alternativeFunctionName : alternativeFunctionNames) { - registerAggregateFunction(alternativeFunctionName, aggregationFunctionType); - } - } - } - - // Walk through all the Pinot transform types and - // 1. register those that are supported in multistage in addition to calcite standard opt table. - // 2. register special handling that differs from calcite standard. - for (TransformFunctionType transformFunctionType : TransformFunctionType.values()) { - if (transformFunctionType.getSqlKind() != null) { - // 1. Register the transform function with Calcite - registerTransformFunction(transformFunctionType.getName(), transformFunctionType); - // 2. Register the transform function with Calcite on all alternative names - List alternativeFunctionNames = transformFunctionType.getAlternativeNames(); - for (String alternativeFunctionName : alternativeFunctionNames) { - registerTransformFunction(alternativeFunctionName, transformFunctionType); - } - } - } - } - - private void registerAggregateFunction(String functionName, AggregationFunctionType functionType) { - // register function behavior that's different from Calcite - if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { - PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, - functionType.getSqlKind(), functionType.getReturnTypeInference(), null, - functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); - if (notRegistered(sqlAggFunction)) { - register(sqlAggFunction); - } - } - } - - private void registerTransformFunction(String functionName, TransformFunctionType functionType) { - // register function behavior that's different from Calcite - if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { - PinotSqlTransformFunction sqlTransformFunction = - new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), - functionType.getSqlKind(), functionType.getReturnTypeInference(), null, - functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory()); - if (notRegistered(sqlTransformFunction)) { - register(sqlTransformFunction); - } - } - } - - private boolean notRegistered(SqlFunction op) { - List operatorList = new ArrayList<>(); - lookupOperatorOverloads(op.getNameAsId(), op.getFunctionType(), op.getSyntax(), operatorList, - SqlNameMatchers.withCaseSensitive(false)); - return operatorList.size() == 0; - } - - private boolean notRegistered(SqlOperator op) { - List operatorList = new ArrayList<>(); - lookupOperatorOverloads(op.getNameAsId(), null, op.getSyntax(), operatorList, - SqlNameMatchers.withCaseSensitive(false)); - return operatorList.size() == 0; - } -} diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/util/PinotSqlStdOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/calcite/sql/util/PinotSqlStdOperatorTable.java new file mode 100644 index 000000000000..e3f6d8a72214 --- /dev/null +++ b/pinot-query-planner/src/main/java/org/apache/calcite/sql/util/PinotSqlStdOperatorTable.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.calcite.sql.util; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.List; +import org.apache.calcite.sql.PinotSqlCoalesceFunction; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.validate.SqlNameMatchers; +import org.apache.calcite.util.Util; +import org.apache.pinot.common.function.FunctionRegistry; + + +public class PinotSqlStdOperatorTable extends SqlStdOperatorTable { + private static PinotSqlStdOperatorTable _instance; + + // supplier instance. this should replace all lazy init static objects in the codebase + public static synchronized PinotSqlStdOperatorTable instance() { + if (_instance == null) { + // Creates and initializes the standard operator table. + // Uses two-phase construction, because we can't initialize the + // table until the constructor of the sub-class has completed. + _instance = new PinotSqlStdOperatorTable(); + _instance.initNoDuplicate(); + } + return _instance; + } + + /** + * Initialize without duplicate, e.g. when 2 duplicate operator is linked with the same op + * {@link org.apache.calcite.sql.SqlKind} it causes problem. + * + *

This is a direct copy of the {@link org.apache.calcite.sql.util.ReflectiveSqlOperatorTable} and can be hard to + * debug, suggest changing to a non-dynamic registration. Dynamic function support should happen via catalog. + * + * This also registers aggregation functions defined in {@link org.apache.pinot.segment.spi.AggregationFunctionType} + * which are multistage enabled. + */ + public final void initNoDuplicate() { + // Pinot supports native COALESCE function, thus no need to create CASE WHEN conversion. + register(new PinotSqlCoalesceFunction()); + // Ensure ArrayValueConstructor is registered before ArrayQueryConstructor + register(ARRAY_VALUE_CONSTRUCTOR); + + // TODO: reflection based registration is not ideal, we should use a static list of operators and register them + // Use reflection to register the expressions stored in public fields. + for (Field field : getClass().getFields()) { + try { + if (SqlFunction.class.isAssignableFrom(field.getType())) { + SqlFunction op = (SqlFunction) field.get(this); + if (op != null && notRegistered(op)) { + register(op); + } + } else if (SqlOperator.class.isAssignableFrom(field.getType())) { + SqlOperator op = (SqlOperator) field.get(this); + if (op != null && notRegistered(op)) { + register(op); + } + } + } catch (IllegalArgumentException | IllegalAccessException e) { + throw Util.throwAsRuntime(Util.causeOrSelf(e)); + } + } + } + + private boolean notRegistered(SqlFunction op) { + List operatorList = new ArrayList<>(); + lookupOperatorOverloads(op.getNameAsId(), op.getFunctionType(), op.getSyntax(), operatorList, + SqlNameMatchers.withCaseSensitive(FunctionRegistry.CASE_SENSITIVITY)); + return operatorList.size() == 0; + } + + private boolean notRegistered(SqlOperator op) { + List operatorList = new ArrayList<>(); + lookupOperatorOverloads(op.getNameAsId(), null, op.getSyntax(), operatorList, + SqlNameMatchers.withCaseSensitive(FunctionRegistry.CASE_SENSITIVITY)); + return operatorList.size() == 0; + } +} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java index 769d6a607fc2..c835eb36cd93 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java @@ -35,7 +35,6 @@ import org.apache.calcite.plan.hep.HepMatchOrder; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; -import org.apache.calcite.prepare.PinotCalciteCatalogReader; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; @@ -52,8 +51,8 @@ import org.apache.calcite.sql.SqlExplainLevel; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.fun.PinotOperatorTable; import org.apache.calcite.sql.util.PinotChainedSqlOperatorTable; +import org.apache.calcite.sql.util.PinotSqlStdOperatorTable; import org.apache.calcite.sql2rel.PinotConvertletTable; import org.apache.calcite.sql2rel.RelDecorrelator; import org.apache.calcite.sql2rel.SqlToRelConverter; @@ -61,6 +60,8 @@ import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.RelBuilder; import org.apache.pinot.common.config.provider.TableCache; +import org.apache.pinot.common.function.sql.PinotCalciteCatalogReader; +import org.apache.pinot.common.function.sql.PinotOperatorTable; import org.apache.pinot.query.context.PlannerContext; import org.apache.pinot.query.planner.PlannerUtils; import org.apache.pinot.query.planner.QueryPlan; @@ -113,6 +114,7 @@ public QueryEnvironment(TypeFactory typeFactory, CalciteSchema rootSchema, Worke _config = Frameworks.newConfigBuilder().traitDefs() .operatorTable(new PinotChainedSqlOperatorTable(Arrays.asList( + PinotSqlStdOperatorTable.instance(), PinotOperatorTable.instance(), _catalogReader))) .defaultSchema(_rootSchema.plus()) diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java index 46a743d52c79..0f7ce1119841 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java @@ -41,7 +41,7 @@ * - byte[] */ @Retention(RetentionPolicy.RUNTIME) -@Target(ElementType.METHOD) +@Target({ElementType.METHOD, ElementType.TYPE}) public @interface ScalarFunction { boolean enabled() default true;