Skip to content

Commit

Permalink
support agg rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
seawinde committed Dec 12, 2023
1 parent 8af38fb commit a2cc509
Show file tree
Hide file tree
Showing 13 changed files with 285 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,61 +17,186 @@

package org.apache.doris.nereids.rules.exploration.mv;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext;
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import static org.apache.doris.nereids.rules.exploration.mv.StructInfo.AGGREGATE_PATTERN_CHECKER;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
* AbstractMaterializedViewAggregateRule
* This is responsible for common aggregate rewriting
* */
*/
public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule {

/**
 * Rewrite an aggregate query by the materialized view.
 * The query and the view are each split into a top plan and a LogicalAggregate. If the group by
 * expressions of the query equal those of the view (after mapping the view ones to query based slots),
 * no roll up is needed and a LogicalProject over tempRewritedPlan is returned. Otherwise the aggregate
 * functions are rolled up and a LogicalAggregate over tempRewritedPlan is returned.
 * Returns null whenever the query can not be rewritten by the view.
 *
 * NOTE(review): this span comes from a diff rendering, some statements below look like the
 * pre-change version interleaved with the new one (they are marked individually) -- confirm against
 * the real file.
 */
@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
StructInfo viewStructInfo,
// NOTE(review): the next two lines look like the old and the renamed form of the same parameter,
// the body below only uses queryToViewSlotMapping -- confirm.
SlotMapping queryToViewSlotMappings,
SlotMapping queryToViewSlotMapping,
Plan tempRewritedPlan,
MaterializationContext materializationContext) {

// NOTE(review): the statements from here down to aggregateToTopExpressionMapping call
// generateAggregateToTopMapping and their results are not used by the rest of the method, they look
// like the pre-change implementation -- confirm.
PlanSplitContext planSplitContext = new PlanSplitContext(Sets.newHashSet(LogicalAggregate.class));
viewStructInfo.getTopPlan().accept(StructInfo.PLAN_SPLITTER, planSplitContext);
// generate aggregate in mv and mv output expression mapping
LogicalAggregate<Plan> bottomAggregate = (LogicalAggregate<Plan>) planSplitContext.getBottomPlan().get(0);
Plan topPlan = planSplitContext.getTopPlan();
ExpressionMapping aggregateToTopExpressionMapping = generateAggregateToTopMapping(bottomAggregate, topPlan);
// get view and query aggregate and top plan correspondingly
Pair<Plan, LogicalAggregate<Plan>> viewTopPlanAndAggPair = splitToTopPlanAndAggregate(viewStructInfo);
if (viewTopPlanAndAggPair == null) {
// the view can not be split into top plan and aggregate, bail out.
return null;
}
Pair<Plan, LogicalAggregate<Plan>> queryTopPlanAndAggPair = splitToTopPlanAndAggregate(queryStructInfo);
if (queryTopPlanAndAggPair == null) {
// the query can not be split into top plan and aggregate, bail out.
return null;
}

// NOTE(review): this unconditional return makes everything below unreachable as written, it looks
// like the old placeholder body of the method -- confirm.
return null;
// Firstly, handle query group by expression rewrite
LogicalAggregate<Plan> queryAggregate = queryTopPlanAndAggPair.value();
Plan queryTopPlan = queryTopPlanAndAggPair.key();
// Roll up is needed when the group by expression count differs. If the count is the same, compare the
// query group by expressions with the view group by expressions mapped to query based slots.
LogicalAggregate<Plan> viewAggregate = viewTopPlanAndAggPair.value();
boolean needRollUp =
queryAggregate.getGroupByExpressions().size() != viewAggregate.getGroupByExpressions().size();
if (queryAggregate.getGroupByExpressions().size() == viewAggregate.getGroupByExpressions().size()) {
// TODO: consider alias
List<Expression> viewGroupByExpressionQueryBased = ExpressionUtils.replace(
viewAggregate.getGroupByExpressions(),
queryToViewSlotMapping.inverse().toSlotReferenceMap());
// List.equals is order sensitive, so the same expressions in a different order still need roll up
needRollUp = !queryAggregate.getGroupByExpressions().equals(viewGroupByExpressionQueryBased);
}
if (!needRollUp) {
// query and view have the same dimension, rewrite the query top plan output by the mv scan
// expressions and put a project on tempRewritedPlan
List<? extends Expression> queryShuttledExpressions = ExpressionUtils.shuttleExpressionWithLineage(
queryTopPlan.getOutput(), queryTopPlan);
List<Expression> rewrittenQueryGroupExpr = rewriteExpression(queryShuttledExpressions,
materializationContext.getMvExprToMvScanExprMapping(),
queryToViewSlotMapping,
true);
if (rewrittenQueryGroupExpr == null) {
// can not rewrite, bail out.
return null;
}
return new LogicalProject<>(
rewrittenQueryGroupExpr.stream().map(NamedExpression.class::cast).collect(Collectors.toList()),
tempRewritedPlan);
}
// the dimension in query and view are different, try to roll up
// Firstly, check that the view aggregate output contains no distinct aggregate function, then split the
// query top plan expressions into group expressions and aggregate function expressions.
if (viewAggregate.getOutputExpressions().stream().anyMatch(
viewExpr -> viewExpr.anyMatch(expr -> expr instanceof AggregateFunction
&& ((AggregateFunction) expr).isDistinct())
)) {
// if mv function contains distinct, can not roll up.
return null;
}
// the query aggregate outputs which are not group by expressions are treated as the aggregate functions
Set<Expression> queryAggGroupSet = new HashSet<>(queryAggregate.getGroupByExpressions());
List<NamedExpression> queryAggFunctions = queryAggregate.getOutputExpressions().stream()
.filter(expr -> !queryAggGroupSet.contains(expr))
.collect(Collectors.toList());
Set<Expression> queryAggFunctionSet = new HashSet<>(queryAggFunctions);
// key is the top plan expressions which use no query aggregate function (group expressions),
// value is the top plan expressions which use a query aggregate function
Pair<List<? extends Expression>, List<? extends Expression>> queryGroupAndFunctionPair
= splitToGroupAndFunction(
queryTopPlanAndAggPair,
queryAggFunctionSet);
// filter the expression which use the child agg function in query top plan, only support to reference the
// aggregate function directly, will support expression later.
List<? extends Expression> queryTopPlanFunctionList = queryGroupAndFunctionPair.value();
if (queryTopPlanFunctionList.stream().anyMatch(
topAggFunc -> !(topAggFunc instanceof NamedExpression)
&& (!queryAggFunctionSet.contains(topAggFunc)
|| !queryAggFunctionSet.contains(topAggFunc.child(0))))) {
return null;
}
// Secondly, try to roll up the agg functions and add aggregate
Multimap<Expression, Expression> needRollupFunctionExprMap = HashMultimap.create();
// mv expression to mv scan expression mapping with the keys permuted to be query based,
// only the first flattened map is used
Map<Expression, Expression> mvExprToMvScanExprQueryBased =
materializationContext.getMvExprToMvScanExprMapping().keyPermute(
queryToViewSlotMapping.inverse()).flattenMap().get(0);
for (Expression needRollUpExpr : queryTopPlanFunctionList) {
Expression needRollupShuttledExpr = ExpressionUtils.shuttleExpressionWithLineage(needRollUpExpr,
queryTopPlan);
if (!mvExprToMvScanExprQueryBased.containsKey(needRollupShuttledExpr)) {
// function can not rewrite by view
return null;
}
AggregateFunction aggregateFunction = (AggregateFunction) needRollUpExpr.firstMatch(
expr -> expr instanceof AggregateFunction);
AggregateFunction rollup = aggregateFunction.getRollup();
if (rollup == null) {
// the aggregate function provides no roll up function, bail out.
return null;
}
// key is query need roll up expr, value is mv scan based roll up expr
needRollupFunctionExprMap.put(needRollUpExpr,
rollup.withChildren(mvExprToMvScanExprQueryBased.get(needRollupShuttledExpr)));
}
// query group rewrite
Multimap<Expression, Expression> groupRewrittenExprMap = HashMultimap.create();
List<? extends Expression> queryTopPlanGroupExprList = queryGroupAndFunctionPair.key();
for (Expression needRewriteGroupExpr : queryTopPlanGroupExprList) {
Expression queryGroupShuttledExpr =
ExpressionUtils.shuttleExpressionWithLineage(needRewriteGroupExpr, queryTopPlan);
if (!mvExprToMvScanExprQueryBased.containsKey(queryGroupShuttledExpr)) {
// group expr can not rewrite by view
return null;
}
groupRewrittenExprMap.put(needRewriteGroupExpr, mvExprToMvScanExprQueryBased.get(queryGroupShuttledExpr));
}
// rewrite expression for group and function expr
List<Expression> rewriteFunctionExprList = rewriteExpression(queryTopPlanFunctionList,
new ExpressionMapping(needRollupFunctionExprMap),
queryToViewSlotMapping,
true);
if (rewriteFunctionExprList == null) {
return null;
}
List<Expression> rewriteGroupExprList = rewriteExpression(queryTopPlanGroupExprList,
new ExpressionMapping(groupRewrittenExprMap),
queryToViewSlotMapping,
true);
if (rewriteGroupExprList == null) {
return null;
}
// add the roll up aggregate on tempRewritedPlan
// NOTE(review): LogicalAggregate is constructed as a raw type here, so the argument list types are
// not checked by the compiler -- confirm the expected constructor signature.
return new LogicalAggregate(rewriteGroupExprList, rewriteFunctionExprList, tempRewritedPlan);
}

private ExpressionMapping generateAggregateToTopMapping(Plan source, Plan target) {
ImmutableMultimap.Builder<Slot, Slot> expressionMappingBuilder = ImmutableMultimap.builder();
List<Slot> sourceOutput = source.getOutput();
List<Slot> targetOutputOutput = target.getOutput();
for (Slot sourceSlot : sourceOutput) {
for (Slot targetSlot : targetOutputOutput) {
if (sourceSlot.equals(targetSlot)) {
expressionMappingBuilder.put(targetSlot, sourceSlot);
}
}
/**
 * Partition the expressions of the top plan by whether they reference any expression of the given set.
 *
 * @param topPlanAndAggPair pair whose key is the top plan, only the key is read here
 * @param queryAggGroupFunctionSet the expressions to look for inside each top plan expression
 * @return a pair, the key holds the top plan expressions which contain no NamedExpression of the set,
 *         the value holds the top plan expressions which contain at least one
 */
private Pair<List<? extends Expression>, List<? extends Expression>> splitToGroupAndFunction(
        Pair<Plan, LogicalAggregate<Plan>> topPlanAndAggPair,
        Set<? extends Expression> queryAggGroupFunctionSet) {
    List<? extends Expression> topPlanExpressions = topPlanAndAggPair.key().getExpressions();
    // true bucket: the expression tree contains a NamedExpression which is in the given set
    Map<Boolean, ? extends List<? extends Expression>> partitioned = topPlanExpressions.stream()
            .collect(Collectors.partitioningBy(topExpr -> topExpr.anyMatch(node ->
                    node instanceof NamedExpression && queryAggGroupFunctionSet.contains((NamedExpression) node))));
    List<? extends Expression> groupExpressions = partitioned.get(false);
    List<? extends Expression> functionExpressions = partitioned.get(true);
    return Pair.of(groupExpressions, functionExpressions);
}

/**
 * Split the plan of the given struct info into its top plan and the aggregate below it.
 * The top plan of the struct info is visited by StructInfo.PLAN_SPLITTER with a split context
 * created for LogicalAggregate.
 *
 * @param structInfo the struct info whose top plan is split
 * @return a pair, the key is the top plan of the struct info and the value is the bottom plan of the
 *         split context cast to LogicalAggregate, or null if the bottom plan is not a LogicalAggregate
 */
private Pair<Plan, LogicalAggregate<Plan>> splitToTopPlanAndAggregate(StructInfo structInfo) {
Plan topPlan = structInfo.getTopPlan();
// presumably the class set tells the splitter at which plan type to cut -- verify against PlanSplitContext
PlanSplitContext splitContext = new PlanSplitContext(Sets.newHashSet(LogicalAggregate.class));
topPlan.accept(StructInfo.PLAN_SPLITTER, splitContext);
if (!(splitContext.getBottomPlan() instanceof LogicalAggregate)) {
// the bottom plan is not an aggregate, the caller treats null as "can not rewrite"
return null;
} else {
// the key is the top plan of the struct info itself, not a top plan taken from the split context
return Pair.of(topPlan, (LogicalAggregate<Plan>) splitContext.getBottomPlan());
}
// NOTE(review): the statement below is unreachable and uses expressionMappingBuilder which is not
// declared in this method, it looks like a leftover line of the removed
// generateAggregateToTopMapping in the diff rendering -- confirm.
return new ExpressionMapping(expressionMappingBuilder.build());
}

// Check whether the aggregate is simple and whether the joins are valid.
Expand All @@ -81,20 +206,19 @@ private ExpressionMapping generateAggregateToTopMapping(Plan source, Plan target
protected boolean checkPattern(StructInfo structInfo) {

Plan topPlan = structInfo.getTopPlan();
Boolean valid = topPlan.accept(AGGREGATE_PATTERN_CHECKER, null);
Boolean valid = topPlan.accept(StructInfo.AGGREGATE_PATTERN_CHECKER, null);
if (!valid) {
return false;
}
HyperGraph hyperGraph = structInfo.getHyperGraph();
HashSet<JoinType> requiredJoinType = Sets.newHashSet(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN);
for (AbstractNode node : hyperGraph.getNodes()) {
StructInfoNode structInfoNode = (StructInfoNode) node;
if (!structInfoNode.getPlan().accept(StructInfo.JOIN_PATTERN_CHECKER,
requiredJoinType)) {
SUPPORTED_JOIN_TYPE_SET)) {
return false;
}
for (Edge edge : hyperGraph.getEdges()) {
if (!edge.getJoin().accept(StructInfo.JOIN_PATTERN_CHECKER, requiredJoinType)) {
if (!edge.getJoin().accept(StructInfo.JOIN_PATTERN_CHECKER, SUPPORTED_JOIN_TYPE_SET)) {
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,10 @@
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.Sets;

import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;

Expand All @@ -40,14 +36,11 @@
* This is responsible for common join rewriting
*/
public abstract class AbstractMaterializedViewJoinRule extends AbstractMaterializedViewRule {
private static final HashSet<JoinType> SUPPORTED_JOIN_TYPE_SET =
Sets.newHashSet(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN);

@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping queryToViewSlotMappings,
SlotMapping queryToViewSlotMapping,
Plan tempRewritedPlan,
MaterializationContext materializationContext) {

Expand All @@ -57,8 +50,9 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
// Rewrite top projects, represent the query projects by view
List<Expression> expressionsRewritten = rewriteExpression(
queryShuttleExpression,
materializationContext.getViewExpressionIndexMapping(),
queryToViewSlotMappings
materializationContext.getMvExprToMvScanExprMapping(),
queryToViewSlotMapping,
true
);
// Can not rewrite, bail out
if (expressionsRewritten.isEmpty()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
Expand All @@ -52,6 +53,9 @@
*/
public abstract class AbstractMaterializedViewRule {

public static final HashSet<JoinType> SUPPORTED_JOIN_TYPE_SET =
Sets.newHashSet(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN);

/**
* The abstract template method for query rewrite, it contains the main logic and different query
* pattern should override the sub logic.
Expand Down Expand Up @@ -123,8 +127,9 @@ protected List<Plan> rewrite(Plan queryPlan, CascadesContext cascadesContext) {
// Try to rewrite compensate predicates by using mv scan
List<Expression> rewriteCompensatePredicates = rewriteExpression(
compensatePredicates.toList(),
materializationContext.getViewExpressionIndexMapping(),
queryToViewSlotMapping);
materializationContext.getMvExprToMvScanExprMapping(),
queryToViewSlotMapping,
true);
if (rewriteCompensatePredicates.isEmpty()) {
continue;
}
Expand Down Expand Up @@ -152,19 +157,24 @@ protected List<Plan> rewrite(Plan queryPlan, CascadesContext cascadesContext) {
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
StructInfo viewStructInfo,
// NOTE(review): the next two lines look like the old and the renamed form of the same parameter
// in the diff rendering -- confirm against the real file.
SlotMapping queryToViewSlotMappings,
SlotMapping queryToViewSlotMapping,
Plan tempRewritedPlan,
MaterializationContext materializationContext) {
// base implementation does no rewrite of its own, it returns the given temporary rewritten plan as is
return tempRewritedPlan;
}

/**
* Use target output expression to represent the source expression
* Use target expression to represent the source expression.
* Visit the source expression, try to replace the source expression with target expression, if found then
* replace the source expression by target expression map value.
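* Note: targetNeedToQueryBased decides how the keys of the target expression map are treated.
* If targetNeedToQueryBased is true, the target expression map is used as is; otherwise its keys are
* permuted by the inverse of sourceToTargetMapping to make them source based.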
*/
protected List<Expression> rewriteExpression(
List<? extends Expression> sourceExpressionsToWrite,
ExpressionMapping mvExpressionToMvScanExpressionMapping,
SlotMapping sourceToTargetMapping) {
ExpressionMapping targetExpressionMapping,
SlotMapping sourceToTargetMapping,
boolean targetNeedToQueryBased) {
// Firstly, rewrite the target plan output expression using query with inverse mapping
// then try to use the mv expression to represent the query. if any of source expressions
// can not be represented by mv, return null
Expand All @@ -177,10 +187,10 @@ protected List<Expression> rewriteExpression(
// transform source to:
// project(slot 2, 1)
// target
// generate mvSql to mvScan mvExpressionToMvScanExpressionMapping, and change mv sql expression to query based
ExpressionMapping expressionMappingKeySourceBased =
mvExpressionToMvScanExpressionMapping.keyPermute(sourceToTargetMapping.inverse());
List<Map<? extends Expression, ? extends Expression>> flattenExpressionMap =
// generate mvSql to mvScan targetExpressionMapping, and change mv sql expression to query based
ExpressionMapping expressionMappingKeySourceBased = targetNeedToQueryBased
? targetExpressionMapping : targetExpressionMapping.keyPermute(sourceToTargetMapping.inverse());
List<Map<Expression, Expression>> flattenExpressionMap =
expressionMappingKeySourceBased.flattenMap();
// view to view scan expression is 1:1 so get first element
Map<? extends Expression, ? extends Expression> mvSqlToMvScanMappingQueryBased = flattenExpressionMap.get(0);
Expand Down
Loading

0 comments on commit a2cc509

Please sign in to comment.