Skip to content

Commit

Permalink
[feature](nereids) Fix data wrong using mv rewrite and ignore case wh…
Browse files Browse the repository at this point in the history
…en getting mv related partition table
  • Loading branch information
seawinde committed Dec 20, 2023
1 parent 3296487 commit 0d1e928
Show file tree
Hide file tree
Showing 20 changed files with 358 additions and 152 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ public MTMVCache(Plan logicalPlan, List<NamedExpression> mvOutputExpressions) {

public static MTMVCache from(MTMV mtmv, ConnectContext connectContext) {
LogicalPlan unboundMvPlan = new NereidsParser().parseSingle(mtmv.getQuerySql());
// TODO: connect context set current db when create mv by use database
// this will be removed in the future when support join derivation
connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES, ELIMINATE_OUTER_JOIN");
StatementContext mvSqlStatementContext = new StatementContext(connectContext,
new OriginStatement(mtmv.getQuerySql(), 0));
NereidsPlanner planner = new NereidsPlanner(mvSqlStatementContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,22 @@ public <OUTPUT_TYPE extends Plan> PatternMatcher<INPUT_TYPE, OUTPUT_TYPE> thenAp
return new PatternMatcher<>(pattern, defaultPromise, matchedAction);
}

/**
* Apply rule to return multi result, catch exception to make sure no influence on other rule
*/
public <OUTPUT_TYPE extends Plan> PatternMatcher<INPUT_TYPE, OUTPUT_TYPE> thenApplyMultiNoThrow(
MatchedMultiAction<INPUT_TYPE, OUTPUT_TYPE> matchedMultiAction) {
MatchedMultiAction<INPUT_TYPE, OUTPUT_TYPE> adaptMatchedMultiAction = ctx -> {
try {
return matchedMultiAction.apply(ctx);
} catch (Exception ex) {
LOG.warn("nereids apply rule failed, because {}", ex.getMessage(), ex);
return null;
}
};
return new PatternMatcher<>(pattern, defaultPromise, adaptMatchedMultiAction);
}

public Pattern<INPUT_TYPE> getPattern() {
return pattern;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ protected List<Plan> rewrite(Plan queryPlan, CascadesContext cascadesContext) {
logger.info(currentClassName + " predicate compensate fail so continue");
continue;
}
Plan rewritedPlan;
Plan rewrittenPlan;
Plan mvScan = materializationContext.getMvScanPlan();
if (compensatePredicates.isAlwaysTrue()) {
rewritedPlan = mvScan;
rewrittenPlan = mvScan;
} else {
// Try to rewrite compensate predicates by using mv scan
List<Expression> rewriteCompensatePredicates = rewriteExpression(
Expand All @@ -175,36 +175,48 @@ protected List<Plan> rewrite(Plan queryPlan, CascadesContext cascadesContext) {
logger.info(currentClassName + " compensate predicate rewrite by view fail so continue");
continue;
}
rewritedPlan = new LogicalFilter<>(Sets.newHashSet(rewriteCompensatePredicates), mvScan);
rewrittenPlan = new LogicalFilter<>(Sets.newHashSet(rewriteCompensatePredicates), mvScan);
}
// Rewrite query by view
rewritedPlan = rewriteQueryByView(matchMode,
rewrittenPlan = rewriteQueryByView(matchMode,
queryStructInfo,
viewStructInfo,
queryToViewSlotMapping,
rewritedPlan,
rewrittenPlan,
materializationContext);
if (rewritedPlan == null) {
if (rewrittenPlan == null) {
logger.info(currentClassName + " rewrite query by view fail so continue");
continue;
}
if (!checkPartitionIsValid(queryStructInfo, materializationContext, cascadesContext)) {
logger.info(currentClassName + " check partition validation fail so continue");
continue;
}
if (!checkOutput(queryPlan, rewrittenPlan)) {
continue;
}
// run rbo job on mv rewritten plan
CascadesContext rewrittenPlanContext =
CascadesContext.initContext(cascadesContext.getStatementContext(), rewritedPlan,
CascadesContext.initContext(cascadesContext.getStatementContext(), rewrittenPlan,
cascadesContext.getCurrentJobContext().getRequiredProperties());
Rewriter.getWholeTreeRewriter(cascadesContext).execute();
rewritedPlan = rewrittenPlanContext.getRewritePlan();
rewrittenPlan = rewrittenPlanContext.getRewritePlan();
logger.info(currentClassName + "rewrite by materialized view success");
rewriteResults.add(rewritedPlan);
rewriteResults.add(rewrittenPlan);
}
}
return rewriteResults;
}

protected boolean checkOutput(Plan sourcePlan, Plan rewrittenPlan) {
if (sourcePlan.getGroupExpression().isPresent() && !rewrittenPlan.getLogicalProperties().equals(
sourcePlan.getGroupExpression().get().getOwnerGroup().getLogicalProperties())) {
logger.error("rewrittenPlan output logical properties is not same with target group");
return false;
}
return true;
}

/**
* Partition will be pruned in query then add the record the partitions to select partitions on
* catalog relation.
Expand Down Expand Up @@ -279,7 +291,7 @@ protected boolean checkPartitionIsValid(
private MTMVCache getCacheFromMTMV(MTMV mtmv, CascadesContext cascadesContext) {
MTMVCache cache;
try {
cache = mtmv.getOrGenerateCache(cascadesContext.getConnectContext());
cache = mtmv.getOrGenerateCache();
} catch (AnalysisException analysisException) {
logger.warn(this.getClass().getSimpleName() + " get mtmv cache analysisException", analysisException);
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public MaterializationContext(MTMV mtmv,

MTMVCache mtmvCache = null;
try {
mtmvCache = mtmv.getOrGenerateCache(cascadesContext.getConnectContext());
mtmvCache = mtmv.getOrGenerateCache();
} catch (AnalysisException e) {
LOG.warn("MaterializationContext init mv cache generate fail", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class MaterializedViewAggregateRule extends AbstractMaterializedViewAggre
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalAggregate(any()).thenApplyMulti(ctx -> {
logicalAggregate(any()).thenApplyMultiNoThrow(ctx -> {
LogicalAggregate<Plan> root = ctx.root;
return rewrite(root, ctx.cascadesContext);
}).toRule(RuleType.MATERIALIZED_VIEW_ONLY_AGGREGATE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class MaterializedViewProjectAggregateRule extends AbstractMaterializedVi
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalProject(logicalAggregate(any())).thenApplyMulti(ctx -> {
logicalProject(logicalAggregate(any())).thenApplyMultiNoThrow(ctx -> {
LogicalProject<LogicalAggregate<Plan>> root = ctx.root;
return rewrite(root, ctx.cascadesContext);
}).toRule(RuleType.MATERIALIZED_VIEW_PROJECT_AGGREGATE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class MaterializedViewProjectJoinRule extends AbstractMaterializedViewJoi
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalProject(logicalJoin(any(), any())).thenApplyMulti(ctx -> {
logicalProject(logicalJoin(any(), any())).thenApplyMultiNoThrow(ctx -> {
LogicalProject<LogicalJoin<Plan, Plan>> root = ctx.root;
return rewrite(root, ctx.cascadesContext);
}).toRule(RuleType.MATERIALIZED_VIEW_PROJECT_JOIN));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public static Optional<RelatedTableInfo> getRelatedTableInfo(String column, Plan
Slot columnExpr = null;
// get column slot
for (Slot outputSlot : outputExpressions) {
if (outputSlot.getName().equals(column)) {
if (outputSlot.getName().equalsIgnoreCase(column)) {
columnExpr = outputSlot;
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
Expand All @@ -34,7 +35,7 @@
public class Predicates {

// Predicates that can be pulled up
private final List<Expression> pulledUpPredicates = new ArrayList<>();
private final Set<Expression> pulledUpPredicates = new HashSet<>();

private Predicates() {
}
Expand All @@ -49,7 +50,7 @@ public static Predicates of(List<? extends Expression> pulledUpPredicates) {
return predicates;
}

public List<? extends Expression> getPulledUpPredicates() {
public Set<? extends Expression> getPulledUpPredicates() {
return pulledUpPredicates;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* Split the expression to equal, range and residual predicate.
Expand All @@ -39,27 +38,26 @@
*/
public class PredicatesSplitter {

private final List<Expression> equalPredicates = new ArrayList<>();
private final List<Expression> rangePredicates = new ArrayList<>();
private final List<Expression> residualPredicates = new ArrayList<>();
private final Set<Expression> equalPredicates = new HashSet<>();
private final Set<Expression> rangePredicates = new HashSet<>();
private final Set<Expression> residualPredicates = new HashSet<>();
private final List<Expression> conjunctExpressions;

private final PredicateExtract instance = new PredicateExtract();

public PredicatesSplitter(Expression target) {
this.conjunctExpressions = ExpressionUtils.extractConjunction(target);
PredicateExtract instance = new PredicateExtract();
for (Expression expression : conjunctExpressions) {
expression.accept(instance, expression);
expression.accept(instance, null);
}
}

/**
* PredicateExtract
* extract to equal, range, residual predicate set
*/
public class PredicateExtract extends DefaultExpressionVisitor<Void, Expression> {
public class PredicateExtract extends DefaultExpressionVisitor<Void, Void> {

@Override
public Void visitComparisonPredicate(ComparisonPredicate comparisonPredicate, Expression sourceExpression) {
public Void visitComparisonPredicate(ComparisonPredicate comparisonPredicate, Void context) {
Expression leftArg = comparisonPredicate.getArgument(0);
Expression rightArg = comparisonPredicate.getArgument(1);
boolean leftArgOnlyContainsColumnRef = containOnlyColumnRef(leftArg, true);
Expand All @@ -69,7 +67,7 @@ public Void visitComparisonPredicate(ComparisonPredicate comparisonPredicate, Ex
equalPredicates.add(comparisonPredicate);
return null;
} else {
residualPredicates.add(comparisonPredicate);
rangePredicates.add(comparisonPredicate);
}
} else if ((leftArgOnlyContainsColumnRef && rightArg instanceof Literal)
|| (rightArgOnlyContainsColumnRef && leftArg instanceof Literal)) {
Expand All @@ -81,12 +79,9 @@ public Void visitComparisonPredicate(ComparisonPredicate comparisonPredicate, Ex
}

@Override
public Void visitCompoundPredicate(CompoundPredicate compoundPredicate, Expression context) {
if (compoundPredicate instanceof Or) {
residualPredicates.add(compoundPredicate);
return null;
}
return super.visitCompoundPredicate(compoundPredicate, context);
public Void visit(Expression expr, Void context) {
residualPredicates.add(expr);
return null;
}
}

Expand All @@ -98,7 +93,7 @@ public Predicates.SplitPredicate getSplitPredicate() {
}

private static boolean containOnlyColumnRef(Expression expression, boolean allowCast) {
if (expression instanceof SlotReference && ((SlotReference) expression).isColumnFromTable()) {
if (expression instanceof SlotReference && expression.isColumnFromTable()) {
return true;
}
if (allowCast && expression instanceof Cast) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ private void initPredicates() {
private void predicatesDerive() {
// construct equivalenceClass according to equals predicates
List<Expression> shuttledExpression = ExpressionUtils.shuttleExpressionWithLineage(
this.predicates.getPulledUpPredicates(), originalPlan).stream()
new ArrayList<>(this.predicates.getPulledUpPredicates()), originalPlan).stream()
.map(Expression.class::cast)
.collect(Collectors.toList());
SplitPredicate splitPredicate = Predicates.splitPredicates(ExpressionUtils.and(shuttledExpression));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions;

import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;

import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* This represents any expression, it means it equals any expression
*/
public class Any extends Expression {

public static final Any INSTANCE = new Any(ImmutableList.of());

private Any(Expression... children) {
super(children);
}

private Any(List<Expression> children) {
super(children);
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return null;
}

@Override
public boolean nullable() {
return false;
}

@Override
public boolean equals(Object o) {
return true;
}

@Override
public int hashCode() {
return 0;
}

@Override
public boolean deepEquals(TreeNode<?> that) {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.doris.mtmv.MTMVUtil;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.analyzer.UnboundResultSink;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewUtils;
Expand All @@ -54,6 +55,7 @@
import org.apache.doris.nereids.trees.plans.algebra.OneRowRelation;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
import org.apache.doris.nereids.trees.plans.visitor.NondeterministicFunctionCollector;
import org.apache.doris.nereids.trees.plans.visitor.TableCollector;
import org.apache.doris.nereids.trees.plans.visitor.TableCollector.TableCollectorContext;
Expand Down Expand Up @@ -199,7 +201,9 @@ private void analyzeProperties() {
public void analyzeQuery(ConnectContext ctx) {
// create table as select
NereidsPlanner planner = new NereidsPlanner(ctx.getStatementContext());
Plan plan = planner.plan(logicalQuery, PhysicalProperties.ANY, ExplainLevel.ALL_PLAN);
// this is for column infer
LogicalSink<Plan> logicalSink = new UnboundResultSink<>(logicalQuery);
Plan plan = planner.plan(logicalSink, PhysicalProperties.ANY, ExplainLevel.ALL_PLAN);
if (plan.anyMatch(node -> node instanceof OneRowRelation)) {
throw new AnalysisException("at least contain one table");
}
Expand Down Expand Up @@ -302,6 +306,8 @@ private void getColumns(Plan plan) {
if (!CollectionUtils.isEmpty(simpleColumnDefinitions) && simpleColumnDefinitions.size() != slots.size()) {
throw new AnalysisException("simpleColumnDefinitions size is not equal to the query's");
}
// slots = BindExpression.inferColumnNames(plan).stream()
// .map(NamedExpression::toSlot).collect(Collectors.toList());
Set<String> colNames = Sets.newHashSet();
for (int i = 0; i < slots.size(); i++) {
String colName = CollectionUtils.isEmpty(simpleColumnDefinitions) ? slots.get(i).getName()
Expand Down
Loading

0 comments on commit 0d1e928

Please sign in to comment.