diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index fe9177ea9e01..1f9242214611 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -47,6 +47,7 @@ import org.apache.pinot.common.function.schema.PinotFunction; import org.apache.pinot.common.function.schema.PinotScalarFunction; import org.apache.pinot.common.function.sql.PinotSqlAggFunction; +import org.apache.pinot.common.function.sql.PinotSqlScalarFunction; import org.apache.pinot.common.function.sql.PinotSqlTransformFunction; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.segment.spi.AggregationFunctionType; @@ -188,7 +189,7 @@ public static FunctionInfo getFunctionInfo(String functionName, int numParams) { */ @Nullable public static FunctionInfo getFunctionInfo(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, - String functionName, List argTypes) { + String functionName, List argTypes) { PinotScalarFunction scalarFunction = getScalarFunction(operatorTable, typeFactory, functionName, argTypes); if (scalarFunction != null) { return scalarFunction.getFunctionInfo(); @@ -214,17 +215,18 @@ private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionNa @Nullable private static PinotScalarFunction getScalarFunction(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory, - String functionName, List argTypes) { - List relArgTypes = convertArgumentTypes(typeFactory, argTypes); + String functionName, List argTypes) { SqlOperator sqlOperator = SqlUtil.lookupRoutine(operatorTable, typeFactory, new SqlIdentifier(functionName, SqlParserPos.QUOTED_ZERO), - relArgTypes, null, null, SqlSyntax.FUNCTION, SqlKind.OTHER_FUNCTION, + argTypes, null, null, SqlSyntax.FUNCTION, SqlKind.OTHER_FUNCTION, SqlNameMatchers.withCaseSensitive(false), true); if (sqlOperator instanceof SqlUserDefinedFunction) { Function function = ((SqlUserDefinedFunction) sqlOperator).getFunction(); if (function instanceof PinotScalarFunction) { return (PinotScalarFunction) function; } + } else if (sqlOperator instanceof PinotSqlScalarFunction) { + return ((PinotSqlScalarFunction) sqlOperator).getFunction(); } return null; } @@ -296,7 +298,7 @@ private static void registerTransformFunction(String functionName, TransformFunc } } - private static List convertArgumentTypes(RelDataTypeFactory typeFactory, + public static List convertArgumentTypes(RelDataTypeFactory typeFactory, List argTypes) { return argTypes.stream().map(type -> toRelType(typeFactory, type)).collect(Collectors.toList()); } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java index ee767c1d8525..caea2dc97345 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java @@ -18,6 +18,7 @@ */ package org.apache.pinot.common.function.scalar; +import com.google.common.base.Preconditions; import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntLinkedOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; @@ -25,10 +26,16 @@ import it.unimi.dsi.fastutil.objects.ObjectSet; import java.math.BigDecimal; import java.util.Arrays; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.type.NonNullableAccessors; +import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SameOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.commons.lang3.ArrayUtils; import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder; @@ -232,6 +239,38 @@ public static String arrayElementAtString(String[] arr, int idx) { return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING; } + @ScalarFunction(names = {"arrayElementAt", "array_element_at"}) + public static class ArrayElementAt { + public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = (opBinding) -> { + final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + final RelDataType operandType = opBinding.getOperandType(0); + Preconditions.checkState(operandType.getSqlTypeName() == SqlTypeName.ARRAY); + return typeFactory.createTypeWithNullability(NonNullableAccessors.getComponentTypeOrThrow(operandType), true); + }; + public static final SqlOperandTypeChecker OPERAND_TYPE_CHECKER = + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER); + + public static int eval(int[] arr, int idx) { + return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.INT; + } + + public static long eval(long[] arr, int idx) { + return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.LONG; + } + + public static float eval(float[] arr, int idx) { + return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.FLOAT; + } + + public static double eval(double[] arr, int idx) { + return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.DOUBLE; + } + + public static String eval(String[] arr, int idx) { + return idx > 0 && idx <= arr.length ? arr[idx - 1] : NullValuePlaceHolder.STRING; + } + } + @ScalarFunction(names = {"array", "arrayValueConstructor"}, isVarArg = true) public static class ArrayValueConstructor { public static final SqlReturnTypeInference RETURN_TYPE_INFERENCE = ReturnTypes.TO_ARRAY; diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java index 0d6a94105c42..34d4079962b7 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotCalciteCatalogReader.java @@ -59,6 +59,7 @@ import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlOperandMetadata; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandTypeInference; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeFamily; @@ -323,6 +324,17 @@ public static SqlOperatorTable operatorTable(String... classNames) { // ==================================================================== public static SqlOperator toOp(SqlIdentifier name, final org.apache.calcite.schema.Function function) { + + // TODO: support AGG and TABLE function in the future + if (function instanceof PinotScalarFunction && ((PinotScalarFunction) function).getOperandTypeChecker() != null) { + final SqlOperandTypeChecker operandTypeChecker = + ((PinotScalarFunction) function).getOperandTypeChecker(); + final SqlReturnTypeInference returnTypeInference = + ((PinotScalarFunction) function).getReturnTypeInference(); + final SqlKind kind = kind(function); + return new PinotSqlScalarFunction(name.toString(), kind, returnTypeInference, null, operandTypeChecker, + SqlFunctionCategory.USER_DEFINED_FUNCTION, ((PinotScalarFunction) function)); + } // ==================================================================== // LINES CHANGED ABOVE // ==================================================================== @@ -353,13 +365,7 @@ public static SqlOperator toOp(SqlIdentifier name, final List typeFamilies = typeFamiliesFactory.apply(dummyTypeFactory); - final SqlOperandTypeInference operandTypeInference; - if (function instanceof PinotScalarFunction && ((PinotScalarFunction) function).getOperandTypeChecker() != null) { - operandTypeInference = ((PinotScalarFunction) function).getOperandTypeChecker().typeInference(); - } else { - operandTypeInference = InferTypes.explicit(argTypes); - } - + final SqlOperandTypeInference operandTypeInference = InferTypes.explicit(argTypes); final SqlOperandMetadata operandMetadata = OperandTypes.operandMetadata(typeFamilies, paramTypesFactory, i -> function.getParameters().get(i).getName(), diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlScalarFunction.java new file mode 100644 index 000000000000..1cd6d3df0370 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlScalarFunction.java @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.sql; + +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlOperandTypeInference; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.pinot.common.function.schema.PinotScalarFunction; +import org.checkerframework.checker.nullness.qual.Nullable; + + +public class PinotSqlScalarFunction extends SqlFunction { + private final PinotScalarFunction _pinotScalarFunction; + + public PinotSqlScalarFunction(String name, SqlKind kind, @Nullable SqlReturnTypeInference returnTypeInference, + @Nullable SqlOperandTypeInference operandTypeInference, @Nullable SqlOperandTypeChecker operandTypeChecker, + SqlFunctionCategory category, PinotScalarFunction pinotScalarFunction) { + super(name, kind, returnTypeInference, operandTypeInference, operandTypeChecker, category); + _pinotScalarFunction = pinotScalarFunction; + } + + public PinotScalarFunction getFunction() { + return _pinotScalarFunction; + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java index 773f73058fb7..6c954caf0f2f 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.stream.Collectors; import javax.annotation.Nullable; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.sql.SqlOperatorTable; import org.apache.pinot.common.function.FunctionInfo; @@ -58,7 +59,10 @@ public FunctionOperand(SqlOperatorTable sqlOperatorTable, RelDataTypeFactory rel return e.getDataType(); } }).collect(Collectors.toList()); - FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(canonicalName, operandTypes.size()); + + List argTypes = FunctionRegistry.convertArgumentTypes(relDataTypeFactory, operandTypes); + FunctionInfo functionInfo = + FunctionRegistry.getFunctionInfo(sqlOperatorTable, relDataTypeFactory, canonicalName, argTypes); Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", canonicalName); _functionInvoker = new FunctionInvoker(functionInfo); if (!_functionInvoker.getMethod().isVarArgs()) {