Skip to content

Commit

Permalink
[refactor](nereids) Remove SlotBinder and FunctionBinder (apache#36872)
Browse files Browse the repository at this point in the history
The new ExpressionAnalyzer can do both bind slot and bind function, so I remove SlotBinder and FunctionBinder.
  • Loading branch information
924060929 authored Jun 27, 2024
1 parent 3493a10 commit 8f2d075
Show file tree
Hide file tree
Showing 12 changed files with 141 additions and 776 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.doris.datasource.iceberg.IcebergExternalDatabase;
import org.apache.doris.datasource.iceberg.IcebergExternalTable;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundHiveTableSink;
import org.apache.doris.nereids.analyzer.UnboundIcebergTableSink;
import org.apache.doris.nereids.analyzer.UnboundSlot;
Expand All @@ -43,7 +44,6 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FunctionBinder;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.DefaultValueSlot;
Expand All @@ -54,7 +54,6 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Substring;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.commands.info.DMLCommandType;
import org.apache.doris.nereids.trees.plans.logical.LogicalHiveTableSink;
Expand Down Expand Up @@ -204,8 +203,8 @@ private Plan bindOlapTableSink(MatchingContext<UnboundTableSink<Plan>> ctx) {
throw new AnalysisException(e.getMessage(), e.getCause());
}

Map<String, NamedExpression> columnToOutput = getColumnToOutput(ctx, table, isPartialUpdate,
boundSink, child);
Map<String, NamedExpression> columnToOutput = getColumnToOutput(
ctx, table, isPartialUpdate, boundSink, child);
LogicalProject<?> fullOutputProject = getOutputProjectByCoercion(table.getFullSchema(), child, columnToOutput);
return boundSink.withChildAndUpdateOutput(fullOutputProject);
}
Expand Down Expand Up @@ -323,12 +322,15 @@ private static Map<String, NamedExpression> getColumnToOutput(
// update the value of the column to the current timestamp whenever there
// is an update on the row
if (column.hasOnUpdateDefaultValue()) {
Expression defualtValueExpression = FunctionBinder.INSTANCE.rewrite(
new NereidsParser().parseExpression(
column.getOnUpdateDefaultValueExpr().toSqlWithoutTbl()),
new ExpressionRewriteContext(ctx.cascadesContext));
Expression unboundFunctionDefaultValue = new NereidsParser().parseExpression(
column.getOnUpdateDefaultValueExpr().toSqlWithoutTbl()
);
Expression defualtValueExpression = ExpressionAnalyzer.analyzeFunction(
boundSink, ctx.cascadesContext, unboundFunctionDefaultValue
);
columnToOutput.put(column.getName(),
new Alias(defualtValueExpression, column.getName()));
new Alias(defualtValueExpression, column.getName())
);
} else {
continue;
}
Expand Down Expand Up @@ -356,10 +358,10 @@ private static Map<String, NamedExpression> getColumnToOutput(
.checkedCastTo(DataType.fromCatalogType(column.getType())),
column.getName()));
} else {
Expression defualtValueExpression = FunctionBinder.INSTANCE.rewrite(
new NereidsParser().parseExpression(
column.getDefaultValueExpr().toSqlWithoutTbl()),
new ExpressionRewriteContext(ctx.cascadesContext));
Expression unboundDefaultValue = new NereidsParser().parseExpression(
column.getDefaultValueExpr().toSqlWithoutTbl());
Expression defualtValueExpression = ExpressionAnalyzer.analyzeFunction(
boundSink, ctx.cascadesContext, unboundDefaultValue);
if (defualtValueExpression instanceof Alias) {
defualtValueExpression = ((Alias) defualtValueExpression).child();
}
Expand All @@ -374,13 +376,12 @@ private static Map<String, NamedExpression> getColumnToOutput(
}
// the generated columns can use all ordinary columns,
// if processed in upper for loop, will lead to not found slot error
//It's the same reason for moving the processing of materialized columns down.
// It's the same reason for moving the processing of materialized columns down.
for (Column column : generatedColumns) {
GeneratedColumnInfo info = column.getGeneratedColumnInfo();
Expression parsedExpression = new NereidsParser().parseExpression(info.getExpr().toSqlWithoutTbl());
Expression boundSlotExpression = SlotReplacer.INSTANCE.replace(parsedExpression, columnToOutput);
Expression boundExpression = FunctionBinder.INSTANCE.rewrite(boundSlotExpression,
new ExpressionRewriteContext(ctx.cascadesContext));
Expression boundExpression = new CustomExpressionAnalyzer(boundSink, ctx.cascadesContext, columnToOutput)
.analyze(parsedExpression);
if (boundExpression instanceof Alias) {
boundExpression = ((Alias) boundExpression).child();
}
Expand All @@ -395,13 +396,11 @@ private static Map<String, NamedExpression> getColumnToOutput(
"mv column %s 's ref column cannot be null", column);
Expression parsedExpression = expressionParser.parseExpression(
column.getDefineExpr().toSqlWithoutTbl());
Expression boundSlotExpression = SlotReplacer.INSTANCE
.replace(parsedExpression, columnToOutput);
// the boundSlotExpression is an expression whose slots are bound but function
// may not be bound, we have to bind it again.
// for example: to_bitmap.
Expression boundExpression = FunctionBinder.INSTANCE.rewrite(
boundSlotExpression, new ExpressionRewriteContext(ctx.cascadesContext));
Expression boundExpression = new CustomExpressionAnalyzer(
boundSink, ctx.cascadesContext, columnToOutput).analyze(parsedExpression);
if (boundExpression instanceof Alias) {
boundExpression = ((Alias) boundExpression).child();
}
Expand Down Expand Up @@ -599,19 +598,21 @@ private boolean validColumn(Column column, boolean isNeedSequenceCol) {
&& !column.isMaterializedViewColumn();
}

private static class SlotReplacer extends DefaultExpressionRewriter<Map<String, NamedExpression>> {
public static final SlotReplacer INSTANCE = new SlotReplacer();
private static class CustomExpressionAnalyzer extends ExpressionAnalyzer {
private Map<String, NamedExpression> slotBinder;

public Expression replace(Expression e, Map<String, NamedExpression> replaceMap) {
return e.accept(this, replaceMap);
public CustomExpressionAnalyzer(
Plan currentPlan, CascadesContext cascadesContext, Map<String, NamedExpression> slotBinder) {
super(currentPlan, new Scope(ImmutableList.of()), cascadesContext, false, false);
this.slotBinder = slotBinder;
}

@Override
public Expression visitUnboundSlot(UnboundSlot unboundSlot, Map<String, NamedExpression> replaceMap) {
if (!replaceMap.containsKey(unboundSlot.getName())) {
public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteContext context) {
if (!slotBinder.containsKey(unboundSlot.getName())) {
throw new AnalysisException("cannot find column from target table " + unboundSlot.getNameParts());
}
return replaceMap.get(unboundSlot.getName());
return slotBinder.get(unboundSlot.getName());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.doris.nereids.analyzer.UnboundVariable;
import org.apache.doris.nereids.analyzer.UnboundVariable.VariableType;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.trees.expressions.Alias;
Expand Down Expand Up @@ -78,6 +79,7 @@
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
Expand All @@ -91,6 +93,7 @@
import org.apache.doris.qe.VariableVarConverters;
import org.apache.doris.qe.cache.CacheAnalyzer;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
Expand All @@ -102,9 +105,19 @@
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/** ExpressionAnalyzer */
public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext> {
@VisibleForTesting
public static final AbstractExpressionRewriteRule FUNCTION_ANALYZER_RULE = new AbstractExpressionRewriteRule() {
@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
return new ExpressionAnalyzer(
null, new Scope(ImmutableList.of()), null, false, false
).analyze(expr, ctx);
}
};

private final Plan currentPlan;
/*
Expand All @@ -124,14 +137,30 @@ public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext
private boolean hasNondeterministic;

/** ExpressionAnalyzer */
public ExpressionAnalyzer(Plan currentPlan, Scope scope, CascadesContext cascadesContext,
boolean enableExactMatch, boolean bindSlotInOuterScope) {
public ExpressionAnalyzer(Plan currentPlan, Scope scope,
@Nullable CascadesContext cascadesContext, boolean enableExactMatch, boolean bindSlotInOuterScope) {
super(scope, cascadesContext);
this.currentPlan = currentPlan;
this.enableExactMatch = enableExactMatch;
this.bindSlotInOuterScope = bindSlotInOuterScope;
this.wantToParseSqlFromSqlCache = CacheAnalyzer.canUseSqlCache(
cascadesContext.getConnectContext().getSessionVariable());
this.wantToParseSqlFromSqlCache = cascadesContext != null
&& CacheAnalyzer.canUseSqlCache(cascadesContext.getConnectContext().getSessionVariable());
}

/** analyzeFunction */
public static Expression analyzeFunction(
@Nullable LogicalPlan plan, @Nullable CascadesContext cascadesContext, Expression expression) {
ExpressionAnalyzer analyzer = new ExpressionAnalyzer(plan, new Scope(ImmutableList.of()),
cascadesContext, false, false);
return analyzer.analyze(
expression,
cascadesContext == null ? null : new ExpressionRewriteContext(cascadesContext)
);
}

public Expression analyze(Expression expression) {
CascadesContext cascadesContext = getCascadesContext();
return analyze(expression, cascadesContext == null ? null : new ExpressionRewriteContext(cascadesContext));
}

/** analyze */
Expand Down Expand Up @@ -258,10 +287,7 @@ public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteCon
if (tableName.isEmpty()) {
tableName = "table list";
}
throw new AnalysisException("Unknown column '"
+ unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1)
+ "' in '" + tableName + "' in "
+ currentPlan.getType().toString().substring("LOGICAL_".length()) + " clause");
couldNotFoundColumn(unboundSlot, tableName);
}
return unboundSlot;
case 1:
Expand Down Expand Up @@ -300,6 +326,16 @@ public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteCon
}
}

protected void couldNotFoundColumn(UnboundSlot unboundSlot, String tableName) {
String message = "Unknown column '"
+ unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1)
+ "' in '" + tableName;
if (currentPlan != null) {
message += "' in " + currentPlan.getType().toString().substring("LOGICAL_".length()) + " clause";
}
throw new AnalysisException(message);
}

@Override
public Expression visitUnboundStar(UnboundStar unboundStar, ExpressionRewriteContext context) {
List<String> qualifier = unboundStar.getQualifier();
Expand Down Expand Up @@ -356,15 +392,18 @@ public Expression visitUnboundFunction(UnboundFunction unboundFunction, Expressi
FunctionBuilder builder = functionRegistry.findFunctionBuilder(
unboundFunction.getDbName(), functionName, arguments);
Pair<? extends Expression, ? extends BoundFunction> buildResult = builder.build(functionName, arguments);
StatementContext statementContext = context.cascadesContext.getStatementContext();
if (buildResult.second instanceof Nondeterministic) {
hasNondeterministic = true;
}
Optional<SqlCacheContext> sqlCacheContext = statementContext.getSqlCacheContext();
if (builder instanceof AliasUdfBuilder
|| buildResult.second instanceof JavaUdf || buildResult.second instanceof JavaUdaf) {
if (sqlCacheContext.isPresent()) {
sqlCacheContext.get().setCannotProcessExpression(true);
Optional<SqlCacheContext> sqlCacheContext = Optional.empty();
if (wantToParseSqlFromSqlCache) {
StatementContext statementContext = context.cascadesContext.getStatementContext();
if (buildResult.second instanceof Nondeterministic) {
hasNondeterministic = true;
}
sqlCacheContext = statementContext.getSqlCacheContext();
if (builder instanceof AliasUdfBuilder
|| buildResult.second instanceof JavaUdf || buildResult.second instanceof JavaUdaf) {
if (sqlCacheContext.isPresent()) {
sqlCacheContext.get().setCannotProcessExpression(true);
}
}
}
if (builder instanceof AliasUdfBuilder) {
Expand All @@ -376,9 +415,9 @@ public Expression visitUnboundFunction(UnboundFunction unboundFunction, Expressi
} else {
Expression castFunction = TypeCoercionUtils.processBoundFunction((BoundFunction) buildResult.first);
if (castFunction instanceof Count
&& context != null
&& context.cascadesContext.getOuterScope().isPresent()
&& !context.cascadesContext.getOuterScope().get().getCorrelatedSlots()
.isEmpty()) {
&& !context.cascadesContext.getOuterScope().get().getCorrelatedSlots().isEmpty()) {
// consider sql: SELECT * FROM t1 WHERE t1.a <= (SELECT COUNT(t2.a) FROM t2 WHERE (t1.b = t2.b));
// when unnest correlated subquery, we create a left join node.
// outer query is left table and subquery is right one
Expand Down Expand Up @@ -488,6 +527,9 @@ public Expression visitNot(Not not, ExpressionRewriteContext context) {

@Override
public Expression visitPlaceholder(Placeholder placeholder, ExpressionRewriteContext context) {
if (context == null) {
return super.visitPlaceholder(placeholder, context);
}
Expression realExpr = context.cascadesContext.getStatementContext()
.getIdToPlaceholderRealExpr().get(placeholder.getPlaceholderId());
return visit(realExpr, context);
Expand Down Expand Up @@ -714,17 +756,6 @@ public static boolean sameTableName(String boundSlot, String unboundSlot) {
}
}

private void checkBoundLambda(Expression lambdaFunction, List<String> argumentNames) {
lambdaFunction.foreachUp(e -> {
if (e instanceof UnboundSlot) {
UnboundSlot unboundSlot = (UnboundSlot) e;
throw new AnalysisException("Unknown lambda slot '"
+ unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1)
+ " in lambda arguments" + argumentNames);
}
});
}

private UnboundFunction bindHighOrderFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) {
int childrenSize = unboundFunction.children().size();
List<Expression> subChildren = new ArrayList<>();
Expand All @@ -737,16 +768,20 @@ private UnboundFunction bindHighOrderFunction(UnboundFunction unboundFunction, E
Expression lambdaFunction = lambda.getLambdaFunction();
List<ArrayItemReference> arrayItemReferences = lambda.makeArguments(subChildren);

// 1.bindSlot
List<Slot> boundedSlots = arrayItemReferences.stream()
.map(ArrayItemReference::toSlot)
.collect(ImmutableList.toImmutableList());
lambdaFunction = new SlotBinder(new Scope(boundedSlots), context.cascadesContext,
true, false).bind(lambdaFunction);
checkBoundLambda(lambdaFunction, lambda.getLambdaArgumentNames());

// 2.bindFunction
lambdaFunction = lambdaFunction.accept(this, context);
ExpressionAnalyzer lambdaAnalyzer = new ExpressionAnalyzer(currentPlan, new Scope(boundedSlots),
context == null ? null : context.cascadesContext, true, false) {
@Override
protected void couldNotFoundColumn(UnboundSlot unboundSlot, String tableName) {
throw new AnalysisException("Unknown lambda slot '"
+ unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1)
+ " in lambda arguments" + lambda.getLambdaArgumentNames());
}
};
lambdaFunction = lambdaAnalyzer.analyze(lambdaFunction, context);

Lambda lambdaClosure = lambda.withLambdaFunctionArguments(lambdaFunction, arrayItemReferences);

Expand Down
Loading

0 comments on commit 8f2d075

Please sign in to comment.