Skip to content

Commit

Permalink
[opt](mtmv) Support mv rewrite when mv has date_trunc but query not
Browse files Browse the repository at this point in the history
  • Loading branch information
seawinde committed Dec 9, 2024
1 parent 0a73618 commit 67befe3
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

Expand Down Expand Up @@ -126,7 +127,8 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
queryTopPlan,
materializationContext.getShuttledExprToScanExprMapping(),
viewToQuerySlotMapping,
queryStructInfo.getTableBitSet());
queryStructInfo.getTableBitSet(),
ImmutableMap.of(), cascadesContext);
boolean isRewrittenQueryExpressionValid = true;
if (!rewrittenQueryExpressions.isEmpty()) {
List<NamedExpression> projects = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableMap;

import java.util.List;
import java.util.stream.Collectors;

Expand All @@ -49,7 +51,8 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
queryStructInfo.getTopPlan(),
materializationContext.getShuttledExprToScanExprMapping(),
targetToSourceMapping,
queryStructInfo.getTableBitSet()
queryStructInfo.getTableBitSet(),
ImmutableMap.of(), cascadesContext
);
// Can not rewrite, bail out
if (expressionsRewritten.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.rules.rewrite.MergeProjects;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand All @@ -45,6 +47,7 @@
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DateTrunc;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NonNullable;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable;
Expand Down Expand Up @@ -242,7 +245,9 @@ protected List<Plan> doRewrite(StructInfo queryStructInfo, CascadesContext casca
// Try to rewrite compensate predicates by using mv scan
List<Expression> rewriteCompensatePredicates = rewriteExpression(compensatePredicates.toList(),
queryPlan, materializationContext.getShuttledExprToScanExprMapping(),
viewToQuerySlotMapping, queryStructInfo.getTableBitSet());
viewToQuerySlotMapping, queryStructInfo.getTableBitSet(),
compensatePredicates.getRangePredicateMap(),
cascadesContext);
if (rewriteCompensatePredicates.isEmpty()) {
materializationContext.recordFailReason(queryStructInfo,
"Rewrite compensate predicate by view fail",
Expand Down Expand Up @@ -560,7 +565,8 @@ protected Plan rewriteQueryByView(MatchMode matchMode, StructInfo queryStructInf
* then use the corresponding value of mapping to replace it
*/
protected List<Expression> rewriteExpression(List<? extends Expression> sourceExpressionsToWrite, Plan sourcePlan,
ExpressionMapping targetExpressionMapping, SlotMapping targetToSourceMapping, BitSet sourcePlanBitSet) {
ExpressionMapping targetExpressionMapping, SlotMapping targetToSourceMapping, BitSet sourcePlanBitSet,
Map<Expression, Literal> shuttledQueryMap, CascadesContext cascadesContext) {
// Firstly, rewrite the target expression using source with inverse mapping
// then try to use the target expression to represent the query. if any of source expressions
// can not be represented by target expressions, return null.
Expand All @@ -579,18 +585,62 @@ protected List<Expression> rewriteExpression(List<? extends Expression> sourceEx
rewrittenExpressions.add(expressionShuttledToRewrite);
continue;
}
final Set<Object> slotsToRewrite =
final Set<Expression> slotsToRewrite =
expressionShuttledToRewrite.collectToSet(expression -> expression instanceof Slot);

final Set<SlotReference> variants =
expressionShuttledToRewrite.collectToSet(expression -> expression instanceof SlotReference
&& ((SlotReference) expression).getDataType() instanceof VariantType);
&& ((SlotReference) expression).getDataType() instanceof VariantType);
extendMappingByVariant(variants, targetToTargetReplacementMappingQueryBased);
Expression replacedExpression = ExpressionUtils.replace(expressionShuttledToRewrite,
targetToTargetReplacementMappingQueryBased);
if (replacedExpression.anyMatch(slotsToRewrite::contains)) {
// if contains any slot to rewrite, which means can not be rewritten by target, bail out
return ImmutableList.of();
Set<Expression> replacedExpressionSlotQueryUsed = replacedExpression.collect(slotsToRewrite::contains);
if (!replacedExpressionSlotQueryUsed.isEmpty()) {
// if contains any slot to rewrite, which means can not be rewritten by target,
// expressionShuttledToRewrite is slot#0 > '2024-01-01' but mv plan output is date_trunc(slot#0, 'day')
// which would try to rewrite
// paramExpressionToDateTruncMap is {slot#0 : date_trunc(slot#0, 'day')}
Map<Expression, DateTrunc> paramExpressionToDateTruncMap = new HashMap<>();
targetToTargetReplacementMappingQueryBased.keySet().forEach(expr -> {
if (expr instanceof DateTrunc) {
paramExpressionToDateTruncMap.put(expr.child(0), (DateTrunc) expr);
}
});
Expression queryExpr = expressionShuttledToRewrite.child(0);
Map<Expression, Literal> shuttledQueryParamToExpressionMap = new HashMap<>();
// TODO: 2024/12/5 optimize performance
for (Map.Entry<Expression, Literal> expressionEntry : shuttledQueryMap.entrySet()) {
Expression shuttledQueryParamExpression = ExpressionUtils.shuttleExpressionWithLineage(
expressionEntry.getKey(), sourcePlan, sourcePlanBitSet);
shuttledQueryParamToExpressionMap.put(shuttledQueryParamExpression.child(0) instanceof Literal
? shuttledQueryParamExpression.child(1) : shuttledQueryParamExpression.child(0),
expressionEntry.getValue());
}

if (paramExpressionToDateTruncMap.isEmpty() || shuttledQueryMap.isEmpty()
|| !shuttledQueryMap.containsKey(expressionShuttledToRewrite)
|| !paramExpressionToDateTruncMap.containsKey(queryExpr)) {
// mv date_trunc expression can not offer expression for query,
// can not try to rewrite by date_trunc, bail out
return ImmutableList.of();
}

Map<Expression, Expression> datetruncMap = new HashMap<>();
Literal queryLiteral = shuttledQueryMap.get(expressionShuttledToRewrite);
datetruncMap.put(queryExpr, queryLiteral);
Expression replacedWithLiteral = ExpressionUtils.replace(
paramExpressionToDateTruncMap.get(queryExpr), datetruncMap);
Expression foldedExpressionWithLiteral = FoldConstantRuleOnFE.evaluate(replacedWithLiteral,
new ExpressionRewriteContext(cascadesContext));
if (foldedExpressionWithLiteral.equals(queryLiteral)) {
// after date_trunc simplify if equals to original expression, could rewritten by mv
replacedExpression = ExpressionUtils.replace(expressionShuttledToRewrite,
targetToTargetReplacementMappingQueryBased,
paramExpressionToDateTruncMap);
}
if (replacedExpression.anyMatch(slotsToRewrite::contains)) {
return ImmutableList.of();
}
}
rewrittenExpressions.add(replacedExpression);
}
Expand Down Expand Up @@ -758,7 +808,7 @@ protected SplitPredicate predicatesCompensate(
viewToQuerySlotMapping,
comparisonResult);
// range compensate
final Set<Expression> rangeCompensatePredicates = Predicates.compensateRangePredicate(
final Map<Expression, Literal> rangeCompensatePredicates = Predicates.compensateRangePredicate(
queryStructInfo,
viewStructInfo,
viewToQuerySlotMapping,
Expand All @@ -775,15 +825,17 @@ protected SplitPredicate predicatesCompensate(
return SplitPredicate.INVALID_INSTANCE;
}
if (equalCompensateConjunctions.stream().anyMatch(expr -> expr.containsType(AggregateFunction.class))
|| rangeCompensatePredicates.stream().anyMatch(expr -> expr.containsType(AggregateFunction.class))
|| rangeCompensatePredicates.keySet().stream()
.anyMatch(expr -> expr.containsType(AggregateFunction.class))
|| residualCompensatePredicates.stream().anyMatch(expr ->
expr.containsType(AggregateFunction.class))) {
return SplitPredicate.INVALID_INSTANCE;
}
return SplitPredicate.of(equalCompensateConjunctions.isEmpty() ? BooleanLiteral.TRUE
: ExpressionUtils.and(equalCompensateConjunctions),
rangeCompensatePredicates.isEmpty() ? BooleanLiteral.TRUE
: ExpressionUtils.and(rangeCompensatePredicates),
: ExpressionUtils.and(rangeCompensatePredicates.keySet()),
rangeCompensatePredicates.isEmpty() ? ImmutableMap.of() : rangeCompensatePredicates,
residualCompensatePredicates.isEmpty() ? BooleanLiteral.TRUE
: ExpressionUtils.and(residualCompensatePredicates));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

import java.util.List;
Expand All @@ -50,7 +51,8 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
queryStructInfo.getTopPlan(),
materializationContext.getShuttledExprToScanExprMapping(),
targetToSourceMapping,
queryStructInfo.getTableBitSet()
queryStructInfo.getTableBitSet(),
ImmutableMap.of(), cascadesContext
);
// Can not rewrite, bail out
if (expressionsRewritten.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,24 @@
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -139,7 +145,7 @@ public static Set<Expression> compensateEquivalence(StructInfo queryStructInfo,
/**
* compensate range predicates
*/
public static Set<Expression> compensateRangePredicate(StructInfo queryStructInfo,
public static Map<Expression, Literal> compensateRangePredicate(StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping viewToQuerySlotMapping,
ComparisonResult comparisonResult,
Expand All @@ -159,7 +165,7 @@ public static Set<Expression> compensateRangePredicate(StructInfo queryStructInf
Sets.difference(viewRangeQueryBasedSet, queryRangeSet).copyInto(differentExpressions);
// the range predicate in query and view is same, don't need to compensate
if (differentExpressions.isEmpty()) {
return differentExpressions;
return ImmutableMap.of();
}
// try to normalize the different expressions
Set<Expression> normalizedExpressions =
Expand All @@ -168,7 +174,18 @@ public static Set<Expression> compensateRangePredicate(StructInfo queryStructInf
// normalized expressions is not in query, can not compensate
return null;
}
return normalizedExpressions;
Map<Expression, Literal> normalizedExpressionsWithLiteral = new HashMap<>();
for (Expression expression : normalizedExpressions) {
Set<Literal> literalSet = expression.collect(expressionTreeNode -> expressionTreeNode instanceof Literal);
if (!(expression instanceof ComparisonPredicate)
|| (expression instanceof GreaterThan || expression instanceof LessThan)
|| literalSet.size() != 1) {
normalizedExpressionsWithLiteral.put(expression, null);
continue;
}
normalizedExpressionsWithLiteral.put(expression, literalSet.iterator().next());
}
return normalizedExpressionsWithLiteral;
}

private static Set<Expression> normalizeExpression(Expression expression, CascadesContext cascadesContext) {
Expand Down Expand Up @@ -220,14 +237,19 @@ public String toString() {
*/
public static final class SplitPredicate {
public static final SplitPredicate INVALID_INSTANCE =
SplitPredicate.of(null, null, null);
SplitPredicate.of(null, null, null, null);
private final Optional<Expression> equalPredicate;
private final Optional<Expression> rangePredicate;
private final Optional<Map<Expression, Literal>> rangePredicateMap;
private final Optional<Expression> residualPredicate;

public SplitPredicate(Expression equalPredicate, Expression rangePredicate, Expression residualPredicate) {
public SplitPredicate(Expression equalPredicate,
Expression rangePredicate,
Map<Expression, Literal> rangePredicateMap,
Expression residualPredicate) {
this.equalPredicate = Optional.ofNullable(equalPredicate);
this.rangePredicate = Optional.ofNullable(rangePredicate);
this.rangePredicateMap = Optional.ofNullable(rangePredicateMap);
this.residualPredicate = Optional.ofNullable(residualPredicate);
}

Expand All @@ -239,6 +261,10 @@ public Expression getRangePredicate() {
return rangePredicate.orElse(BooleanLiteral.TRUE);
}

public Map<Expression, Literal> getRangePredicateMap() {
return rangePredicateMap.orElse(ImmutableMap.of());
}

public Expression getResidualPredicate() {
return residualPredicate.orElse(BooleanLiteral.TRUE);
}
Expand All @@ -248,8 +274,9 @@ public Expression getResidualPredicate() {
*/
public static SplitPredicate of(Expression equalPredicates,
Expression rangePredicates,
Map<Expression, Literal> rangePredicateSet,
Expression residualPredicates) {
return new SplitPredicate(equalPredicates, rangePredicates, residualPredicates);
return new SplitPredicate(equalPredicates, rangePredicates, rangePredicateSet, residualPredicates);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ public Predicates.SplitPredicate getSplitPredicate() {
return Predicates.SplitPredicate.of(
equalPredicates.isEmpty() ? null : ExpressionUtils.and(equalPredicates),
rangePredicates.isEmpty() ? null : ExpressionUtils.and(rangePredicates),
null,
residualPredicates.isEmpty() ? null : ExpressionUtils.and(residualPredicates));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,29 @@ public static Expression replace(Expression expr, Map<? extends Expression, ? ex
});
}

/**
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
* For example.
* <pre>
* input expression: a > 1
* replaceMap: a -> b + c
*
* output:
* b + c > 1
* </pre>
*/
public static Expression replace(Expression expr, Map<? extends Expression, ? extends Expression> replaceMap,
Map<? extends Expression, ? extends Expression> transferMap) {
return expr.rewriteDownShortCircuit(e -> {
Expression replacedExpr = replaceMap.get(e);
if (replacedExpr != null) {
return replacedExpr;
}
replacedExpr = replaceMap.get(transferMap.get(e));
return replacedExpr == null ? e : replacedExpr;
});
}

public static List<Expression> replace(List<Expression> exprs,
Map<? extends Expression, ? extends Expression> replaceMap) {
ImmutableList.Builder<Expression> result = ImmutableList.builderWithExpectedSize(exprs.size());
Expand Down

0 comments on commit 67befe3

Please sign in to comment.