Skip to content

Commit

Permalink
query rewrite support filter valid partition filter
Browse files Browse the repository at this point in the history
  • Loading branch information
seawinde committed Dec 17, 2023
1 parent a0f85f8 commit 8dcbdb4
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ private static long getTableMinVisibleVersionTime(OlapTable table) {
* @param relatedTable
* @return mv.partitionId ==> relatedTable.partitionId
*/
private static Map<Long, Set<Long>> getMvToBasePartitions(MTMV mtmv, OlapTable relatedTable)
public static Map<Long, Set<Long>> getMvToBasePartitions(MTMV mtmv, OlapTable relatedTable)
throws AnalysisException {
HashMap<Long, Set<Long>> res = Maps.newHashMap();
Map<Long, PartitionItem> relatedTableItems = relatedTable.getPartitionInfo().getIdToItem(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ public boolean equals(Object o) {
return false;
}
PlaceholderExpression that = (PlaceholderExpression) o;
return position == that.position && Objects.equals(delegateClazz, that.delegateClazz);
return position == that.position
&& Objects.equals(delegateClazz, that.delegateClazz)
&& distinct == that.distinct;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapCount;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
Expand Down Expand Up @@ -69,9 +68,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate
static {
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
PlaceholderExpression.of(Count.class, 0, true),
new PlaceholderExpression(
ImmutableList.of(PlaceholderExpression.of(ToBitmap.class, 0)),
BitmapUnion.class, 0));
new PlaceholderExpression(ImmutableList.of(), BitmapUnion.class, 0));
}

@Override
Expand Down Expand Up @@ -153,25 +150,19 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
for (Expression topExpression : queryTopPlan.getExpressions()) {
// is agg function, try to roll up and rewrite
if (queryTopPlanFunctionSet.contains(topExpression)) {
Expression needRollupShuttledExpr = ExpressionUtils.shuttleExpressionWithLineage(
Expression queryFunctionShuttled = ExpressionUtils.shuttleExpressionWithLineage(
topExpression,
queryTopPlan);

if (!mvExprToMvScanExprQueryBased.containsKey(needRollupShuttledExpr)) {
// function can not rewrite by view
return null;
}

// try to roll up
AggregateFunction needRollupAggFunction = (AggregateFunction) topExpression.firstMatch(
AggregateFunction queryFunction = (AggregateFunction) topExpression.firstMatch(
expr -> expr instanceof AggregateFunction);
Function rollupAggregateFunction = rollup(needRollupAggFunction,
mvExprToMvScanExprQueryBased.get(needRollupShuttledExpr));
Function rollupAggregateFunction = rollup(queryFunction,
queryFunctionShuttled, mvExprToMvScanExprQueryBased);
if (rollupAggregateFunction == null) {
return null;
}
// key is query need roll up expr, value is mv scan based roll up expr
needRollupExprMap.put(needRollupShuttledExpr, rollupAggregateFunction);
needRollupExprMap.put(queryFunctionShuttled, rollupAggregateFunction);
// rewrite query function expression by mv expression
Expression rewrittenFunctionExpression = rewriteExpression(topExpression,
queryTopPlan,
Expand Down Expand Up @@ -249,23 +240,44 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
}

// only support sum roll up, support other agg functions later.
private Function rollup(AggregateFunction originFunction,
Expression mappedExpression) {
Class<? extends Function> rollupAggregateFunction = originFunction.getRollup();
private Function rollup(AggregateFunction queryFunction,
Expression queryFunctionShuttled,
Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
Expression rollupParam = null;
if (mvExprToMvScanExprQueryBased.containsKey(queryFunctionShuttled)) {
// function can not rewrite by view
rollupParam = mvExprToMvScanExprQueryBased.get(queryFunctionShuttled);
} else {
// try to use complex roll up param
// eg: query is count(distinct param), mv sql is bitmap_union(to_bitmap(param))
for (Expression mvExprShuttled : mvExprToMvScanExprQueryBased.keySet()) {
if (!(mvExprShuttled instanceof Function)) {
continue;
}
if (isAggregateFunctionEquivalent(queryFunction, (Function) mvExprShuttled)) {
rollupParam = mvExprToMvScanExprQueryBased.get(mvExprShuttled);
}
}
}
if (rollupParam == null) {
return null;
}
// do roll up
Class<? extends Function> rollupAggregateFunction = queryFunction.getRollup();
if (rollupAggregateFunction == null) {
return null;
}
if (Sum.class.isAssignableFrom(rollupAggregateFunction)) {
return new Sum(originFunction.isDistinct(), mappedExpression);
return new Sum(queryFunction.isDistinct(), rollupParam);
}
if (Max.class.isAssignableFrom(rollupAggregateFunction)) {
return new Max(originFunction.isDistinct(), mappedExpression);
return new Max(queryFunction.isDistinct(), rollupParam);
}
if (Min.class.isAssignableFrom(rollupAggregateFunction)) {
return new Min(originFunction.isDistinct(), mappedExpression);
return new Min(queryFunction.isDistinct(), rollupParam);
}
if (BitmapCount.class.isAssignableFrom(rollupAggregateFunction)) {
return new BitmapCount(mappedExpression);
if (BitmapUnionCount.class.isAssignableFrom(rollupAggregateFunction)) {
return new BitmapUnionCount(rollupParam);
}
// can rollup return null
return null;
Expand Down Expand Up @@ -345,6 +357,7 @@ private boolean isAggregateFunctionEquivalent(Function queryFunction, Function v
if (queryClazz.isAssignableFrom(viewClazz)) {
return true;
}
// bitmap roll up
boolean isDistinct = queryFunction instanceof AggregateFunction
&& ((AggregateFunction) queryFunction).isDistinct();
PlaceholderExpression equivalentFunction = AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.get(
Expand All @@ -357,7 +370,7 @@ private boolean isAggregateFunctionEquivalent(Function queryFunction, Function v
if (!viewFunction.getClass().isAssignableFrom(equivalentFunction.getDelegateClazz())) {
return false;
}
if (!viewFunction.children().isEmpty()) {
if (!viewFunction.children().isEmpty() && !equivalentFunction.children().isEmpty()) {
// children compare, just compare two level, support more later
List<Expression> equivalentFunctions = equivalentFunction.children();
if (viewFunction.children().size() != equivalentFunctions.size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
package org.apache.doris.nereids.rules.exploration.mv;

import org.apache.doris.catalog.MTMV;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Partition;
import org.apache.doris.catalog.PartitionInfo;
import org.apache.doris.catalog.PartitionItem;
import org.apache.doris.catalog.PartitionType;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.mtmv.BaseTableInfo;
import org.apache.doris.mtmv.MTMVCache;
import org.apache.doris.mtmv.MTMVPartitionInfo;
import org.apache.doris.mtmv.MTMVUtil;
Expand All @@ -33,11 +35,6 @@
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand All @@ -49,10 +46,9 @@
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.commands.UpdateMvByPartitionCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
Expand All @@ -61,11 +57,11 @@

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -102,11 +98,12 @@ protected List<Plan> rewrite(Plan queryPlan, CascadesContext cascadesContext) {
queryPlan.getGroupExpression().get().getOwnerGroup().getGroupId())) {
continue;
}
Plan mvPlan = handleValidPartition(materializationContext.getMtmv(), cascadesContext);
if (mvPlan == null) {
continue;
MTMV mtmv = materializationContext.getMtmv();
MTMVCache mtmvCache = getCacheFromMTMV(mtmv);
if (mtmvCache == null) {
return null;
}
List<StructInfo> viewStructInfos = extractStructInfo(mvPlan, cascadesContext);
List<StructInfo> viewStructInfos = extractStructInfo(mtmvCache.getLogicalPlan(), cascadesContext);
if (viewStructInfos.size() > 1) {
// view struct info should only have one
return rewriteResults;
Expand Down Expand Up @@ -176,76 +173,88 @@ protected List<Plan> rewrite(Plan queryPlan, CascadesContext cascadesContext) {
if (rewritedPlan == null) {
continue;
}
if (!checkPartitionIsValid(queryStructInfo, materializationContext, cascadesContext)) {
continue;
}
rewriteResults.add(rewritedPlan);
}
}
return rewriteResults;
}

private MTMVCache getCacheFromMTMV(MTMV mtmv) {
MTMVCache cache;
try {
cache = mtmv.getOrGenerateCache();
} catch (AnalysisException analysisException) {
logger.warn("get mtmv cache analysisException", analysisException);
return null;
protected boolean checkPartitionIsValid(
StructInfo queryInfo,
MaterializationContext materializationContext,
CascadesContext cascadesContext) {
// check partition is valid or not
MTMV mtmv = materializationContext.getMtmv();
PartitionInfo mvPartitionInfo = mtmv.getPartitionInfo();
if (PartitionType.UNPARTITIONED.equals(mvPartitionInfo.getType())) {
// if not partition, if rewrite success, it means mv is available
return true;
}
return cache;
}

// return the plan with filter if some partition is valid
private Plan handleValidPartition(MTMV mtmv, CascadesContext cascadesContext) {
PartitionInfo partitionInfo = mtmv.getPartitionInfo();
PartitionType partitionType = partitionInfo.getType();
MTMVCache mtmvCache = getCacheFromMTMV(mtmv);
if (mtmvCache == null) {
return null;
// check mv related table partition is valid or not
MTMVPartitionInfo mvCustomPartitionInfo = mtmv.getMvPartitionInfo();
BaseTableInfo relatedPartitionTable = mvCustomPartitionInfo.getRelatedTable();
if (relatedPartitionTable == null) {
return true;
}
Optional<LogicalOlapScan> relatedTableRelation = queryInfo.getRelations().stream()
.filter(relation -> relatedPartitionTable.equals(new BaseTableInfo(relation.getTable()))
&& relation instanceof LogicalOlapScan)
.map(relation -> (LogicalOlapScan) relation)
.findFirst();
if (!relatedTableRelation.isPresent()) {
logger.warn("mv is partition update, but related table relation is null");
return false;
}
if (PartitionType.UNPARTITIONED.equals(partitionType)) {
// not handle un partition table
return mtmvCache.getLogicalPlan();
OlapTable relatedTable = relatedTableRelation.get().getTable();
Map<Long, Set<Long>> mvToBasePartitionMap;
try {
mvToBasePartitionMap = MTMVUtil.getMvToBasePartitions(mtmv, relatedTable);
} catch (AnalysisException e) {
logger.error("mvRewriteSuccess getMvToBasePartitions fail", e);
return false;
}
Map<Long, PartitionItem> allPartitions = partitionInfo.getAllPartitions();
Collection<Partition> dataValidPartitions = MTMVUtil.getMTMVCanRewritePartitions(mtmv,
// get mv valid partitions
Collection<Partition> mvDataValidPartitions = MTMVUtil.getMTMVCanRewritePartitions(mtmv,
cascadesContext.getConnectContext());
if (!allPartitions.isEmpty() && dataValidPartitions.isEmpty()) {
Map<Long, PartitionItem> allPartitions = mvPartitionInfo.getAllPartitions();
if (!allPartitions.isEmpty() && mvDataValidPartitions.isEmpty()) {
// do not have valid partition
return null;
return false;
}
if (allPartitions.size() == dataValidPartitions.size()) {
// todo deep equals check,all partition is valid just return the plan
return mtmvCache.getLogicalPlan();
// get mv related table valid partitions
Set<Long> relatedTalbeValidSet = mvDataValidPartitions.stream()
.map(partition -> {
Set<Long> relatedBaseTablePartitions = mvToBasePartitionMap.get(partition.getId());
if (relatedBaseTablePartitions == null || relatedBaseTablePartitions.isEmpty()) {
return ImmutableList.of();
} else {
return relatedBaseTablePartitions;
}
})
.flatMap(Collection::stream)
.map(Long.class::cast)
.collect(Collectors.toSet());
// get query selected partitions to make the partitions is valid or not
Set<Long> relatedTableSelectedPartitionToCheck =
new HashSet<>(relatedTableRelation.get().getSelectedPartitionIds());
if (relatedTableSelectedPartitionToCheck.isEmpty()) {
relatedTableSelectedPartitionToCheck.addAll(relatedTable.getPartitionIds());
}
// handle the scene when some partition is valid
Set<Expression> disjunctions = new HashSet<>();
Set<Long> allPartitionIdSet = allPartitions.keySet();
Plan logicalPlan = mtmvCache.getLogicalPlan();
// get mv partition column name
Map<String, Slot> mvPlanOutputNameMap = new HashMap<>();
logicalPlan.getOutput().forEach(slot -> mvPlanOutputNameMap.putIfAbsent(slot.getName(), slot));
MTMVPartitionInfo mvPartitionInfo = mtmv.getMvPartitionInfo();
Slot partitionColumnSlot = mvPlanOutputNameMap.get(mvPartitionInfo.getPartitionCol());
if (partitionColumnSlot == null) {
return relatedTalbeValidSet.containsAll(relatedTableSelectedPartitionToCheck);
}

private MTMVCache getCacheFromMTMV(MTMV mtmv) {
MTMVCache cache;
try {
cache = mtmv.getOrGenerateCache();
} catch (AnalysisException analysisException) {
logger.warn("get mtmv cache analysisException", analysisException);
return null;
}
for (Partition validPartition : dataValidPartitions) {
if (!allPartitionIdSet.contains(validPartition.getId())) {
return null;
}
disjunctions.add(UpdateMvByPartitionCommand.convertPartitionItemToPredicate(
allPartitions.get(validPartition.getId()),
partitionColumnSlot
));
}

// filter condition optimization
ExpressionOptimization expressionOptimization = new ExpressionOptimization();
ExpressionNormalization expressionNormalization = new ExpressionNormalization();
ExpressionRewriteContext expressionRewriteContext = new ExpressionRewriteContext(cascadesContext);
Expression optimizedExpression = expressionOptimization.rewrite(ExpressionUtils.or(disjunctions),
expressionRewriteContext);
optimizedExpression = expressionNormalization.rewrite(optimizedExpression, expressionRewriteContext);
return new LogicalFilter<>(ExpressionUtils.extractConjunctionToSet(optimizedExpression), mtmvCache.getLogicalPlan());
return cache;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public List<FunctionSignature> getSignatures() {
@Override
public Class<? extends Function> getRollup() {
if (this.isDistinct()) {
return BitmapCount.class;
return BitmapUnionCount.class;
} else {
return Sum.class;
}
Expand Down

0 comments on commit 8dcbdb4

Please sign in to comment.