Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC] function reg full use arg type lookup #92

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,26 @@
*/
package org.apache.pinot.common.function;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.schema.Function;
import org.apache.calcite.schema.impl.ScalarFunctionImpl;
import org.apache.calcite.util.NameMultimap;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.common.function.registry.PinotFunction;
import org.apache.pinot.common.function.registry.PinotScalarFunction;
import org.apache.pinot.common.function.sql.PinotCalciteCatalogReader;
import org.apache.pinot.common.function.sql.PinotFunctionRegistry;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.spi.annotations.ScalarFunction;
import org.apache.pinot.spi.utils.PinotReflectionUtils;
import org.slf4j.Logger;
Expand All @@ -38,19 +46,14 @@

/**
* Registry for scalar functions.
* <p>TODO: Merge FunctionRegistry and FunctionDefinitionRegistry to provide one single registry for all functions.
*/
public class FunctionRegistry {
private FunctionRegistry() {
}

public static final boolean CASE_SENSITIVITY = false;
private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class);

// TODO: consolidate the following 2
// This FUNCTION_INFO_MAP is used by Pinot server to look up function by # of arguments
private static final Map<String, Map<Integer, FunctionInfo>> FUNCTION_INFO_MAP = new HashMap<>();
// This FUNCTION_MAP is used by Calcite function catalog to look up function by function signature.
private static final NameMultimap<Function> FUNCTION_MAP = new NameMultimap<>();

private FunctionRegistry() {
}

/**
* Registers the scalar functions via reflection.
Expand All @@ -66,16 +69,13 @@ private FunctionRegistry() {
}
ScalarFunction scalarFunction = method.getAnnotation(ScalarFunction.class);
if (scalarFunction.enabled()) {
// Annotated function names
String[] scalarFunctionNames = scalarFunction.names();
boolean nullableParameters = scalarFunction.nullableParameters();
if (scalarFunctionNames.length > 0) {
for (String name : scalarFunctionNames) {
FunctionRegistry.registerFunction(name, method, nullableParameters, scalarFunction.isPlaceholder());
}
} else {
FunctionRegistry.registerFunction(method, nullableParameters, scalarFunction.isPlaceholder());
// Parse annotated function names and alias
Set<String> scalarFunctionNames = Arrays.stream(scalarFunction.names()).collect(Collectors.toSet());
if (scalarFunctionNames.size() == 0) {
scalarFunctionNames.add(method.getName());
}
boolean nullableParameters = scalarFunction.nullableParameters();
FunctionRegistry.registerFunction(method, scalarFunctionNames, nullableParameters);
}
}
LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_INFO_MAP.size(),
Expand All @@ -90,22 +90,65 @@ private FunctionRegistry() {
public static void init() {
}

@VisibleForTesting
public static void registerFunction(Method method, boolean nullableParameters) {
registerFunction(method, Collections.singleton(method.getName()), nullableParameters);
}

@VisibleForTesting
public static Set<String> getRegisteredCalciteFunctionNames() {
return PinotFunctionRegistry.getFunctionMap().map().keySet();
}

/**
* Registers a method with the name of the method.
* Returns {@code true} if the given function name is registered, {@code false} otherwise.
*/
public static boolean containsFunction(String functionName) {
// TODO: remove deprecated FUNCTION_INFO_MAP
return PinotFunctionRegistry.getFunctionMap().containsKey(functionName, CASE_SENSITIVITY)
|| FUNCTION_INFO_MAP.containsKey(canonicalize(functionName));
}

/**
* Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null}
* if there is no matching method. This method should be called after the FunctionRegistry is initialized and all
* methods are already registered.
*/
public static void registerFunction(Method method, boolean nullableParameters, boolean isPlaceholder) {
registerFunction(method.getName(), method, nullableParameters, isPlaceholder);
@Nullable
public static FunctionInfo getFunctionInfo(String functionName, int numParams) {
try {
return getFunctionInfoFromCalciteNamedMap(functionName, numParams);
} catch (IllegalArgumentException iae) {
// TODO: remove deprecated FUNCTION_INFO_MAP
return getFunctionInfoFromFunctionInfoMap(functionName, numParams);
}
}

/**
* Registers a method with the given function name.
* Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null}
* if there is no matching method. This method should be called after the FunctionRegistry is initialized and all
* methods are already registered.
*/
public static void registerFunction(String functionName, Method method, boolean nullableParameters,
boolean isPlaceholder) {
if (!isPlaceholder) {
registerFunctionInfoMap(functionName, method, nullableParameters);
@Nullable
public static FunctionInfo getFunctionInfo(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory,
String functionName, List<DataSchema.ColumnDataType> argTypes) {
PinotScalarFunction scalarFunction =
PinotFunctionRegistry.getScalarFunction(operatorTable, typeFactory, functionName, argTypes);
if (scalarFunction != null) {
return scalarFunction.getFunctionInfo();
} else {
throw new IllegalArgumentException(
"Unable to lookup function: " + functionName + " with parameter type signature: " + argTypes);
}
}

// TODO: remove deprecated FUNCTION_INFO_MAP
private static void registerFunction(Method method, Set<String> alias, boolean nullableParameters) {
if (method.getAnnotation(Deprecated.class) == null) {
for (String name : alias) {
registerFunctionInfoMap(name, method, nullableParameters);
}
}
registerCalciteNamedFunctionMap(functionName, method, nullableParameters);
}

private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters) {
Expand All @@ -117,36 +160,25 @@ private static void registerFunctionInfoMap(String functionName, Method method,
"Function: %s with %s parameters is already registered", functionName, method.getParameterCount());
}

private static void registerCalciteNamedFunctionMap(String functionName, Method method, boolean nullableParameters) {
if (method.getAnnotation(Deprecated.class) == null) {
FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method));
}
}

public static Map<String, List<Function>> getRegisteredCalciteFunctionMap() {
return FUNCTION_MAP.map();
}

public static Set<String> getRegisteredCalciteFunctionNames() {
return FUNCTION_MAP.map().keySet();
}

/**
* Returns {@code true} if the given function name is registered, {@code false} otherwise.
*/
public static boolean containsFunction(String functionName) {
return FUNCTION_INFO_MAP.containsKey(canonicalize(functionName));
@Nullable
private static FunctionInfo getFunctionInfoFromFunctionInfoMap(String functionName, int numParams) {
Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName));
return functionInfoMap != null ? functionInfoMap.get(numParams) : null;
}

/**
* Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null}
* if there is no matching method. This method should be called after the FunctionRegistry is initialized and all
* methods are already registered.
*/
@Nullable
public static FunctionInfo getFunctionInfo(String functionName, int numParameters) {
Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName));
return functionInfoMap != null ? functionInfoMap.get(numParameters) : null;
private static FunctionInfo getFunctionInfoFromCalciteNamedMap(String functionName, int numParams) {
List<PinotScalarFunction> candidates = PinotFunctionRegistry.getFunctionMap()
.range(functionName, CASE_SENSITIVITY).stream()
.filter(e -> e.getValue() instanceof PinotScalarFunction && e.getValue().getParameters().size() == numParams)
.map(e -> (PinotScalarFunction) e.getValue()).collect(Collectors.toList());
if (candidates.size() == 1) {
return candidates.get(0).getFunctionInfo();
} else {
throw new IllegalArgumentException(
"Unable to lookup function: " + functionName + " by parameter count: " + numParams + " Found "
+ candidates.size() + " candidates. Try to use argument types to resolve the correct one!");
}
}

private static String canonicalize(String functionName) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.common.function.registry;

import org.apache.calcite.schema.Function;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlReturnTypeInference;


public interface PinotFunction extends Function {
SqlOperandTypeChecker getOperandTypeChecker();
SqlReturnTypeInference getReturnTypeInference();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.common.function.registry;

import java.lang.reflect.Method;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.schema.ScalarFunction;
import org.apache.calcite.schema.impl.ReflectiveFunctionBase;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.pinot.common.function.FunctionInfo;


/**
* Pinot specific implementation of the {@link ScalarFunction}.
*
* @see "{@link org.apache.calcite.schema.impl.ScalarFunctionImpl}"
*/
public class PinotScalarFunction extends ReflectiveFunctionBase implements PinotFunction, ScalarFunction {
private final FunctionInfo _functionInfo;
private final String _name;
private final Method _method;
private final SqlOperandTypeChecker _sqlOperandTypeChecker;
private final SqlReturnTypeInference _sqlReturnTypeInference;

public PinotScalarFunction(String name, Method method, boolean isNullableParameter) {
this(name, method, isNullableParameter, null, null);
}

public PinotScalarFunction(String name, Method method, boolean isNullableParameter,
SqlOperandTypeChecker sqlOperandTypeChecker, SqlReturnTypeInference sqlReturnTypeInference) {
super(method);
_name = name;
_method = method;
_functionInfo = new FunctionInfo(method, method.getDeclaringClass(), isNullableParameter);
_sqlOperandTypeChecker = sqlOperandTypeChecker;
_sqlReturnTypeInference = sqlReturnTypeInference;
}

@Override
public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
return typeFactory.createJavaType(method.getReturnType());
}

public String getName() {
return _name;
}

public Method getMethod() {
return _method;
}

public FunctionInfo getFunctionInfo() {
return _functionInfo;
}

@Override
public SqlOperandTypeChecker getOperandTypeChecker() {
return _sqlOperandTypeChecker;
}

@Override
public SqlReturnTypeInference getReturnTypeInference() {
return _sqlReturnTypeInference;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public static String substring(String input, int beginIndex, int length) {
* @param seperator
* @return The two input strings joined by the seperator
*/
@ScalarFunction(names = "concat_ws")
@ScalarFunction(names = {"concatWS", "concat_ws"})
public static String concatws(String seperator, String input1, String input2) {
return concat(input1, input2, seperator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.calcite.prepare;
package org.apache.pinot.common.function.sql;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
Expand All @@ -35,6 +35,8 @@
import org.apache.calcite.linq4j.function.Hints;
import org.apache.calcite.model.ModelHandler;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.prepare.RelOptTableImpl;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeFactoryImpl;
Expand Down Expand Up @@ -310,7 +312,7 @@ public static SqlOperatorTable operatorTable(String... classNames) {
}

/** Converts a function to a {@link org.apache.calcite.sql.SqlOperator}. */
private static SqlOperator toOp(SqlIdentifier name,
public static SqlOperator toOp(SqlIdentifier name,
final org.apache.calcite.schema.Function function) {
final Function<RelDataTypeFactory, List<RelDataType>> argTypesFactory =
typeFactory -> function.getParameters()
Expand Down
Loading