Skip to content

Commit

Permalink
[stash on top] use signature type lookup for v2 engine
Browse files Browse the repository at this point in the history
  • Loading branch information
Rong Rong committed Dec 18, 2023
1 parent 5f65e91 commit e04e949
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.common.function.registry.PinotFunction;
import org.apache.pinot.common.function.registry.PinotScalarFunction;
import org.apache.pinot.common.function.sql.PinotCalciteCatalogReader;
import org.apache.pinot.common.function.sql.PinotFunctionRegistry;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.spi.annotations.ScalarFunction;
import org.apache.pinot.spi.utils.PinotReflectionUtils;
import org.slf4j.Logger;
Expand Down Expand Up @@ -96,13 +100,6 @@ public static Set<String> getRegisteredCalciteFunctionNames() {
return PinotFunctionRegistry.getFunctionMap().map().keySet();
}

/**
* Returns the full list of all registered ScalarFunction to Calcite.
*/
public static Map<String, List<PinotFunction>> getRegisteredCalciteFunctionMap() {
return PinotFunctionRegistry.getFunctionMap().map();
}

/**
* Returns {@code true} if the given function name is registered, {@code false} otherwise.
*/
Expand All @@ -127,6 +124,24 @@ public static FunctionInfo getFunctionInfo(String functionName, int numParams) {
}
}

/**
* Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null}
* if there is no matching method. This method should be called after the FunctionRegistry is initialized and all
* methods are already registered.
*/
@Nullable
public static FunctionInfo getFunctionInfo(SqlOperatorTable operatorTable, RelDataTypeFactory typeFactory,
String functionName, List<DataSchema.ColumnDataType> argTypes) {
PinotScalarFunction scalarFunction =
PinotFunctionRegistry.getScalarFunction(operatorTable, typeFactory, functionName, argTypes);
if (scalarFunction != null) {
return scalarFunction.getFunctionInfo();
} else {
throw new IllegalArgumentException(
"Unable to lookup function: " + functionName + " with parameter type signature: " + argTypes);
}
}

// TODO: remove deprecated FUNCTION_INFO_MAP
private static void registerFunction(Method method, Set<String> alias, boolean nullableParameters) {
if (method.getAnnotation(Deprecated.class) == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionInvoker;
import org.apache.pinot.common.function.FunctionRegistry;
Expand All @@ -43,11 +46,19 @@ public class FunctionOperand implements TransformOperand {
private final List<TransformOperand> _operands;
private final Object[] _reusableOperandHolder;

public FunctionOperand(RexExpression.FunctionCall functionCall, String canonicalName, DataSchema dataSchema) {
public FunctionOperand(SqlOperatorTable sqlOperatorTable, RelDataTypeFactory relDataTypeFactory,
RexExpression.FunctionCall functionCall, String canonicalName, DataSchema dataSchema) {
_resultType = functionCall.getDataType();
List<RexExpression> operands = functionCall.getFunctionOperands();
int numOperands = operands.size();
FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(canonicalName, numOperands);
List<ColumnDataType> operandTypes = operands.stream().map(e -> {
if (e instanceof RexExpression.InputRef) {
return dataSchema.getColumnDataType(((RexExpression.InputRef) e).getIndex());
} else {
return e.getDataType();
}
}).collect(Collectors.toList());
FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(sqlOperatorTable, relDataTypeFactory, canonicalName, operandTypes);
Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", canonicalName);
_functionInvoker = new FunctionInvoker(functionInfo);
Class<?>[] parameterClasses = _functionInvoker.getParameterClasses();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,31 @@

import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Properties;
import org.apache.calcite.config.CalciteConnectionConfig;
import org.apache.calcite.config.CalciteConnectionConfigImpl;
import org.apache.calcite.config.CalciteConnectionProperty;
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.jdbc.CalciteSchemaBuilder;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.pinot.common.function.sql.PinotCalciteCatalogReader;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.query.catalog.PinotCatalog;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.runtime.operator.utils.OperatorUtils;
import org.apache.pinot.query.type.TypeFactory;
import org.apache.pinot.query.type.TypeSystem;


public class TransformOperandFactory {
private static final RelDataTypeFactory FUNCTION_CATALOGREL_DATA_TYPE_FACTORY = new TypeFactory(new TypeSystem());
private static final CalciteSchema FUNCTION_CATALOG_ROOT_SCHEMA =
CalciteSchemaBuilder.asRootSchema(new PinotCatalog(null));
private static final CalciteConnectionConfig FUNCTION_CATALOG_CONFIG = CalciteConnectionConfig.DEFAULT;
private static final PinotCalciteCatalogReader FUNCTION_CATALOG_OPERATOR_TABLE =
new PinotCalciteCatalogReader(FUNCTION_CATALOG_ROOT_SCHEMA, FUNCTION_CATALOG_ROOT_SCHEMA.path(null),
FUNCTION_CATALOGREL_DATA_TYPE_FACTORY, FUNCTION_CATALOG_CONFIG);

private TransformOperandFactory() {
}

Expand Down Expand Up @@ -74,7 +93,8 @@ private static TransformOperand getTransformOperand(RexExpression.FunctionCall f
case "lessThanOrEqual":
return new FilterOperand.Predicate(operands, dataSchema, v -> v <= 0);
default:
return new FunctionOperand(functionCall, canonicalName, dataSchema);
return new FunctionOperand(FUNCTION_CATALOG_OPERATOR_TABLE, FUNCTION_CATALOGREL_DATA_TYPE_FACTORY,
functionCall, canonicalName, dataSchema);
}
}
}

0 comments on commit e04e949

Please sign in to comment.