Skip to content

Commit

Permalink
[feature](nereids) Support basic aggregate rewrite and function roll …
Browse files Browse the repository at this point in the history
…up using materialized view
  • Loading branch information
seawinde committed Dec 12, 2023
1 parent a2cc509 commit fae6000
Show file tree
Hide file tree
Showing 19 changed files with 553 additions and 177 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

import java.util.List;
import java.util.stream.Collectors;
Expand All @@ -43,13 +42,10 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
SlotMapping queryToViewSlotMapping,
Plan tempRewritedPlan,
MaterializationContext materializationContext) {

List<? extends Expression> queryShuttleExpression = ExpressionUtils.shuttleExpressionWithLineage(
queryStructInfo.getExpressions(),
queryStructInfo.getOriginalPlan());
// Rewrite top projects, represent the query projects by view
List<Expression> expressionsRewritten = rewriteExpression(
queryShuttleExpression,
queryStructInfo.getExpressions(),
queryStructInfo.getOriginalPlan(),
materializationContext.getMvExprToMvScanExprMapping(),
queryToViewSlotMapping,
true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
Expand Down Expand Up @@ -127,6 +128,7 @@ protected List<Plan> rewrite(Plan queryPlan, CascadesContext cascadesContext) {
// Try to rewrite compensate predicates by using mv scan
List<Expression> rewriteCompensatePredicates = rewriteExpression(
compensatePredicates.toList(),
queryPlan,
materializationContext.getMvExprToMvScanExprMapping(),
queryToViewSlotMapping,
true);
Expand Down Expand Up @@ -164,20 +166,23 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
}

/**
* Use target expression to represent the source expression.
* Visit the source expression, try to replace the source expression with target expression, if found then
* replace the source expression by target expression map value.
* Note: make the target expression map key to source based according to targetNeedToQueryBased,
* if targetNeedToQueryBased is true, we should not make it source based.
* Use target expression to represent the source expression. Visit the source expression,
* try to replace the source expression with target expression in targetExpressionMapping, if found then
* replace the source expression by target expression mapping value.
* Note: make the target expression map key to source based according to targetExpressionNeedSourceBased,
* if targetExpressionNeedSourceBased is true, we should make it source based.
* The key expressions in targetExpressionMapping should be shuttled with the method
* ExpressionUtils.shuttleExpressionWithLineage.
*/
protected List<Expression> rewriteExpression(
List<? extends Expression> sourceExpressionsToWrite,
Plan sourcePlan,
ExpressionMapping targetExpressionMapping,
SlotMapping sourceToTargetMapping,
boolean targetNeedToQueryBased) {
// Firstly, rewrite the target plan output expression using query with inverse mapping
// then try to use the mv expression to represent the query. if any of source expressions
// can not be represented by mv, return null
boolean targetExpressionNeedSourceBased) {
// Firstly, rewrite the target expression using source with inverse mapping,
// then try to use the target expression to represent the query. If any of the source expressions
// can not be represented by target expressions, return an empty list.
//
// example as following:
// source target
Expand All @@ -187,35 +192,58 @@ protected List<Expression> rewriteExpression(
// transform source to:
// project(slot 2, 1)
// target
// generate mvSql to mvScan targetExpressionMapping, and change mv sql expression to query based
ExpressionMapping expressionMappingKeySourceBased = targetNeedToQueryBased
? targetExpressionMapping : targetExpressionMapping.keyPermute(sourceToTargetMapping.inverse());
// generate target to target replacement expression mapping, and change target expression to source based
List<? extends Expression> sourceShuttledExpressions =
ExpressionUtils.shuttleExpressionWithLineage(sourceExpressionsToWrite, sourcePlan);
ExpressionMapping expressionMappingKeySourceBased = targetExpressionNeedSourceBased
? targetExpressionMapping.keyPermute(sourceToTargetMapping.inverse()) : targetExpressionMapping;
// target to target replacement expression mapping, because mv is 1:1 so get first element
List<Map<Expression, Expression>> flattenExpressionMap =
expressionMappingKeySourceBased.flattenMap();
// view to view scan expression is 1:1 so get first element
Map<? extends Expression, ? extends Expression> mvSqlToMvScanMappingQueryBased = flattenExpressionMap.get(0);
Map<? extends Expression, ? extends Expression> targetToTargetReplacementMapping = flattenExpressionMap.get(0);

List<Expression> rewrittenExpressions = new ArrayList<>();
for (Expression expressionToRewrite : sourceExpressionsToWrite) {
for (int index = 0; index < sourceShuttledExpressions.size(); index++) {
Expression expressionToRewrite = sourceShuttledExpressions.get(index);
if (expressionToRewrite instanceof Literal) {
rewrittenExpressions.add(expressionToRewrite);
continue;
}
final Set<Object> slotsToRewrite =
expressionToRewrite.collectToSet(expression -> expression instanceof Slot);
boolean wiAlias = expressionToRewrite instanceof NamedExpression;
Expression replacedExpression = ExpressionUtils.replace(expressionToRewrite,
mvSqlToMvScanMappingQueryBased,
wiAlias);
targetToTargetReplacementMapping);
if (replacedExpression.anyMatch(slotsToRewrite::contains)) {
// if contains any slot to rewrite, which means can not be rewritten by target, bail out
return ImmutableList.of();
}
Expression sourceExpression = sourceExpressionsToWrite.get(index);
if (sourceExpression instanceof NamedExpression) {
NamedExpression sourceNamedExpression = (NamedExpression) sourceExpression;
replacedExpression = new Alias(sourceNamedExpression.getExprId(), replacedExpression,
sourceNamedExpression.getName());
}
rewrittenExpressions.add(replacedExpression);
}
return rewrittenExpressions;
}

/**
 * Rewrite a single source expression by the target expression mapping.
 * Wraps the expression into a one-element list, delegates to the list based rewriteExpression overload
 * and unwraps the result.
 *
 * @return the first rewritten expression, or null if the list based overload returns an empty list
 */
protected Expression rewriteExpression(
        Expression sourceExpressionsToWrite,
        Plan sourcePlan,
        ExpressionMapping targetExpressionMapping,
        SlotMapping sourceToTargetMapping,
        boolean targetExpressionNeedSourceBased) {
    // the list based overload expects a list, so wrap the single expression
    List<Expression> singleSourceExpression = new ArrayList<>();
    singleSourceExpression.add(sourceExpressionsToWrite);
    List<Expression> rewriteResult = rewriteExpression(
            singleSourceExpression,
            sourcePlan,
            targetExpressionMapping,
            sourceToTargetMapping,
            targetExpressionNeedSourceBased);
    // an empty result is reported to the caller as null
    return rewriteResult.isEmpty() ? null : rewriteResult.get(0);
}

/**
* Compensate mv predicates by query predicates, compensate predicate result is query based.
* Such as a > 5 in mv, and a > 10 in query, the compensatory predicate is a > 10.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.doris.nereids.PlannerHook;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.TableCollector;
Expand Down Expand Up @@ -90,19 +91,22 @@ private void initMaterializationContext(CascadesContext cascadesContext) {
.getDbOrMetaException(mvBaseTableInfo.getDbId())
.getTableOrMetaException(mvBaseTableInfo.getTableId(), TableType.MATERIALIZED_VIEW);

String qualifiedName = materializedView.getQualifiedName();
// generate outside, maybe add partition filter in the future
Plan mvScan = new LogicalOlapScan(cascadesContext.getStatementContext().getNextRelationId(),
LogicalOlapScan mvScan = new LogicalOlapScan(
cascadesContext.getStatementContext().getNextRelationId(),
(OlapTable) materializedView,
ImmutableList.of(qualifiedName),
Lists.newArrayList(materializedView.getId()),
ImmutableList.of(materializedView.getQualifiedDbName()),
// this must be empty, or it will be used to sample
Lists.newArrayList(),
Lists.newArrayList(),
Optional.empty());
mvScan = mvScan.withMaterializedIndexSelected(PreAggStatus.on(), materializedView.getBaseIndexId());
List<NamedExpression> mvProjects = mvScan.getOutput().stream().map(NamedExpression.class::cast)
.collect(Collectors.toList());
mvScan = new LogicalProject<Plan>(mvProjects, mvScan);
// todo should force keep consistency to mv sql plan output
Plan projectScan = new LogicalProject<Plan>(mvProjects, mvScan);
cascadesContext.addMaterializationContext(
MaterializationContext.fromMaterializedView(materializedView, mvScan, cascadesContext));
MaterializationContext.fromMaterializedView(materializedView, projectScan, cascadesContext));
} catch (MetaNotFoundException metaNotFoundException) {
LOG.error(mvBaseTableInfo.toString() + " can not find corresponding materialized view.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.GroupId;
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.ExpressionUtils;

Expand All @@ -32,7 +31,6 @@
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Maintain the context for query rewrite by materialized view
Expand Down Expand Up @@ -67,14 +65,11 @@ public MaterializationContext(MTMV mtmv,
mvCache = MVCache.from(mtmv, cascadesContext.getConnectContext());
mtmv.setMvCache(mvCache);
}
List<NamedExpression> mvOutputExpressions = mvCache.getMvOutputExpressions();
// mv output expression shuttle, this will be used to expression rewrite
mvOutputExpressions =
ExpressionUtils.shuttleExpressionWithLineage(mvOutputExpressions, mvCache.getLogicalPlan()).stream()
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
this.mvExprToMvScanExprMapping = ExpressionMapping.generate(
mvOutputExpressions,
ExpressionUtils.shuttleExpressionWithLineage(
mvCache.getMvOutputExpressions(),
mvCache.getLogicalPlan()),
mvScanPlan.getExpressions());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;

Expand All @@ -37,8 +38,8 @@ public class MaterializedViewProjectAggregateRule extends AbstractMaterializedVi
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalAggregate(any()).thenApplyMulti(ctx -> {
LogicalAggregate<Plan> root = ctx.root;
logicalProject(logicalAggregate(any())).thenApplyMulti(ctx -> {
LogicalProject<LogicalAggregate<Plan>> root = ctx.root;
return rewrite(root, ctx.cascadesContext);
}).toRule(RuleType.MATERIALIZED_VIEW_PROJECT_AGGREGATE, RulePromise.EXPLORE));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,15 @@ default boolean anyMatch(Predicate<TreeNode<NODE_TYPE>> predicate) {
/**
* iterate top down and test predicate if any matched. Top-down traverse implicitly.
* @param predicate predicate
* @return true if all predicate return true
* @return the first node that matches the predicate
*/
default TreeNode<NODE_TYPE> firstMatch(Predicate<TreeNode<NODE_TYPE>> predicate) {
if (!predicate.test(this)) {
if (predicate.test(this)) {
return this;
}
for (NODE_TYPE child : children()) {
if (!child.anyMatch(predicate)) {
return this;
if (child.anyMatch(predicate)) {
return child;
}
}
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public boolean isDistinct() {
return distinct;
}

public AggregateFunction getRollup() {
public Class<? extends AggregateFunction> getRollup() {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public List<FunctionSignature> getSignatures() {
}

@Override
public AggregateFunction getRollup() {
return this;
public Class<? extends AggregateFunction> getRollup() {
return Sum.class;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ private void getColumns(Plan plan) {
colNames.add(colName);
}
columns.add(new ColumnDefinition(
colName, slots.get(i).getDataType(), true,
colName, slots.get(i).getDataType(), slots.get(i).nullable(),
CollectionUtils.isEmpty(simpleColumnDefinitions) ? null
: simpleColumnDefinitions.get(i).getComment()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
@Override
public List<? extends Expression> getExpressions() {
return new ImmutableList.Builder<Expression>()
.addAll(groupByExpressions)
.addAll(outputExpressions)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,10 @@ public ExpressionReplaceContext(List<Expression> targetExpressions,
this.tableIdentifiers = tableIdentifiers;
// collect the named expressions used in target expression and will be replaced later
this.exprIdExpressionMap = targetExpressions.stream()
.map(each -> {
// if Alias, shuttle from the child of alias to retain the alias
if (each instanceof Alias && !each.children().isEmpty()) {
return each.child(0).collectToList(NamedExpression.class::isInstance);
}
return each.collectToList(NamedExpression.class::isInstance);
})
.map(each -> each.collectToList(NamedExpression.class::isInstance))
.flatMap(Collection::stream)
.map(NamedExpression.class::cast)
.distinct()
.collect(Collectors.toMap(NamedExpression::getExprId, expr -> expr));
}

Expand Down
Loading

0 comments on commit fae6000

Please sign in to comment.