diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/MTMV.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/MTMV.java index 0c59f56db5c8317..af0691d94fbecc6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/MTMV.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/MTMV.java @@ -30,6 +30,7 @@ import org.apache.doris.mtmv.MTMVRefreshInfo; import org.apache.doris.mtmv.MTMVRelation; import org.apache.doris.mtmv.MTMVStatus; +import org.apache.doris.mtmv.MVCache; import org.apache.doris.persist.gson.GsonUtils; import com.google.gson.annotations.SerializedName; @@ -61,6 +62,8 @@ public class MTMV extends OlapTable { private Map mvProperties; @SerializedName("r") private MTMVRelation relation; + // Should update after every fresh + private MVCache mvCache; // For deserialization public MTMV() { @@ -116,6 +119,14 @@ public MTMVRelation getRelation() { return relation; } + public MVCache getMvCache() { + return mvCache; + } + + public void setMvCache(MVCache mvCache) { + this.mvCache = mvCache; + } + public MTMVRefreshInfo alterRefreshInfo(MTMVRefreshInfo newRefreshInfo) { return refreshInfo.updateNotNull(newRefreshInfo); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/mtmv/MTMVCacheManager.java b/fe/fe-core/src/main/java/org/apache/doris/mtmv/MTMVCacheManager.java index 847f85b90749958..ab70d3a962bcc0b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mtmv/MTMVCacheManager.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mtmv/MTMVCacheManager.java @@ -48,6 +48,7 @@ import org.apache.doris.persist.AlterMTMV; import org.apache.doris.qe.ConnectContext; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import org.apache.commons.collections.CollectionUtils; @@ -70,6 +71,11 @@ public Set getMtmvsByBaseTable(BaseTableInfo table) { return tableMTMVs.get(table); } + // TODO Implement the method which getting materialized view by tables + public List getAvailableMaterializedView(List tables){ + return ImmutableList.of(); + } + public boolean isAvailableMTMV(MTMV mtmv, ConnectContext ctx) throws AnalysisException, DdlException { // check session variable if enable rewrite if (!ctx.getSessionVariable().isEnableMvRewrite()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/mtmv/MVCache.java b/fe/fe-core/src/main/java/org/apache/doris/mtmv/MVCache.java new file mode 100644 index 000000000000000..6a74c0c9562f437 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/mtmv/MVCache.java @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.mtmv; + +import org.apache.doris.catalog.MTMV; +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.plans.Plan; + +import java.util.List; + +/**The cache for materialized view cache */ +public class MVCache { + + // the materialized view plan which should be optimized by the same rules to query + private final Plan logicalPlan; + // this should be shuttle expression with lineage + private final List mvOutputExpressions; + // the context when parse, analyze, optimize the mv logical plan + private final CascadesContext context; + + public MVCache(MTMV materializedView, Plan logicalPlan, List mvOutputExpressions, + CascadesContext context) { + this.logicalPlan = logicalPlan; + this.mvOutputExpressions = mvOutputExpressions; + this.context = context; + } + + public Plan getLogicalPlan() { + return logicalPlan; + } + + public List getMvOutputExpressions() { + return mvOutputExpressions; + } + + public CascadesContext getContext() { + return context; + } + + public static MVCache from(MTMV mtmv) { + // TODO Init the MVCache + return new MVCache(mtmv, null, null, null); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java index 899ba216e9965a3..7eeb5e7b288961e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java @@ -18,9 +18,12 @@ package org.apache.doris.nereids.rules.exploration.mv; import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping; +import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; import java.util.List; @@ -34,15 +37,18 @@ public abstract class AbstractMaterializedViewJoinRule extends AbstractMateriali protected Plan rewriteQueryByView(MatchMode matchMode, StructInfo queryStructInfo, StructInfo viewStructInfo, - RelationMapping queryToViewTableMappings, - Plan tempRewritedPlan) { + SlotMapping queryToViewSlotMappings, + Plan tempRewritedPlan, + MaterializationContext materializationContext) { + List queryShuttleExpression = ExpressionUtils.shuttleExpressionWithLineage( + queryStructInfo.getExpressions(), + queryStructInfo.getOriginalPlan()); // Rewrite top projects, represent the query projects by view List expressions = rewriteExpression( - queryStructInfo.getExpressions(), - queryStructInfo, - viewStructInfo, - queryToViewTableMappings, + queryShuttleExpression, + materializationContext.getViewExpressionIndexMapping(), + queryToViewSlotMappings, tempRewritedPlan ); // Can not rewrite, bail out @@ -58,6 +64,6 @@ protected Plan rewriteQueryByView(MatchMode matchMode, @Override protected boolean checkPattern(StructInfo structInfo) { // TODO Should get struct info from hyper graph and check - return false; + return true; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java index ea5e86f38f2af72..6b0045b6706c9b4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java @@ -22,12 +22,13 @@ import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.rules.exploration.mv.Predicates.SplitPredicate; import org.apache.doris.nereids.rules.exploration.mv.mapping.EquivalenceClassSetMapping; -import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionIndexMapping; +import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping; import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping; import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.plans.Plan; @@ -35,11 +36,13 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.util.ExpressionUtils; -import com.clearspring.analytics.util.Lists; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Multimap; import com.google.common.collect.Sets; import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -69,7 +72,7 @@ protected List rewrite(Plan queryPlan, CascadesContext cascadesContext) { } for (MaterializationContext materializationContext : materializationContexts) { - Plan mvPlan = materializationContext.getMvPlan(); + Plan mvPlan = materializationContext.getMtmv().getMvCache().getLogicalPlan(); StructInfo viewStructInfo = extractStructInfo(mvPlan, cascadesContext); if (!checkPattern(viewStructInfo)) { continue; @@ -92,16 +95,15 @@ protected List rewrite(Plan queryPlan, CascadesContext cascadesContext) { continue; } Plan rewritedPlan; - Plan mvScan = materializationContext.getScanPlan(); + Plan mvScan = materializationContext.getMvScanPlan(); if (compensatePredicates.isAlwaysTrue()) { rewritedPlan = mvScan; } else { // Try to rewrite compensate predicates by using mv scan List rewriteCompensatePredicates = rewriteExpression( compensatePredicates.toList(), - queryStructInfo, - viewStructInfo, - queryToViewTableMapping, + materializationContext.getViewExpressionIndexMapping(), + queryToViewSlotMapping, mvScan); if (rewriteCompensatePredicates.isEmpty()) { continue; @@ -109,8 +111,12 @@ protected List rewrite(Plan queryPlan, CascadesContext cascadesContext) { rewritedPlan = new LogicalFilter<>(Sets.newHashSet(rewriteCompensatePredicates), mvScan); } // Rewrite query by view - rewritedPlan = rewriteQueryByView(matchMode, queryStructInfo, viewStructInfo, - queryToViewTableMapping, rewritedPlan); + rewritedPlan = rewriteQueryByView(matchMode, + queryStructInfo, + viewStructInfo, + queryToViewSlotMapping, + rewritedPlan, + materializationContext); if (rewritedPlan == null) { continue; } @@ -124,18 +130,19 @@ protected List rewrite(Plan queryPlan, CascadesContext cascadesContext) { protected Plan rewriteQueryByView(MatchMode matchMode, StructInfo queryStructInfo, StructInfo viewStructInfo, - RelationMapping queryToViewTableMappings, - Plan tempRewritedPlan) { + SlotMapping queryToViewSlotMappings, + Plan tempRewritedPlan, + MaterializationContext materializationContext) { return tempRewritedPlan; } - /**Use target output expression to represent the source expression*/ - protected List rewriteExpression(List sourceExpressions, - StructInfo sourceStructInfo, - StructInfo targetStructInfo, - RelationMapping sourceToTargetMapping, + /**Use target output expression to represent the source expression + * */ + protected List rewriteExpression( + List sourceExpressions, + ExpressionMapping expressionMapping, + SlotMapping sourceToTargetMapping, Plan targetScanNode) { - // TODO represent the sourceExpressions by using target scan node // Firstly, rewrite the target plan output expression using query with inverse mapping // then try to use the mv expression to represent the query. if any of source expressions // can not be represented by mv, return null @@ -148,18 +155,21 @@ protected List rewriteExpression(List sou // transform source to: // project(slot 2, 1) // target - List targetTopExpressions = targetStructInfo.getExpressions(); - List shuttledTargetExpressions = ExpressionUtils.shuttleExpressionWithLineage( - targetTopExpressions, targetStructInfo.getOriginalPlan(), Sets.newHashSet(), Sets.newHashSet()); - SlotMapping sourceToTargetSlotMapping = SlotMapping.generate(sourceToTargetMapping); - // mv sql plan expressions transform to query based - List queryBasedExpressions = ExpressionUtils.replace( - shuttledTargetExpressions.stream().map(Expression.class::cast).collect(Collectors.toList()), - (Map)sourceToTargetSlotMapping.inverse().toMappedSlotMap()); - // mv sql query based expression and index mapping - ExpressionIndexMapping.generate(queryBasedExpressions); - // TODO visit source expression and replace the expression with expressionIndexMapping - return ImmutableList.of(); + List> maps = expressionMapping.flattenMap(); + // view to view scan expression is 1:1 so get first element + Map expressionMap = maps.get(0); + + List result = new ArrayList<>(); + for (Expression expressionToRewrite : sourceExpressions) { + final Set slotSet = expressionToRewrite.collectToSet(expression -> expression instanceof Slot); + Expression replacedExpression = ExpressionUtils.replace(expressionToRewrite, expressionMap, true); + if (replacedExpression.anyMatch(slotSet::contains)) { + // can not rewrite + return null; + } + result.add((NamedExpression) replacedExpression); + } + return result; } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializationContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializationContext.java index b0de1ccfa469bb7..cf80b9b746f65c7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializationContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializationContext.java @@ -17,12 +17,15 @@ package org.apache.doris.nereids.rules.exploration.mv; +import org.apache.doris.catalog.MTMV; import org.apache.doris.catalog.Table; -import org.apache.doris.catalog.View; -import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.mtmv.MVCache; import org.apache.doris.nereids.memo.GroupId; +import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping; import org.apache.doris.nereids.trees.plans.Plan; +import com.google.common.collect.ImmutableList; + import java.util.HashSet; import java.util.List; import java.util.Set; @@ -32,22 +35,32 @@ */ public class MaterializationContext { - // TODO add MaterializedView class - private final Plan mvPlan; - private final CascadesContext context; + private MTMV mtmv; + // Should use stmt id generator in query context + private final Plan mvScanPlan; private final List baseTables; - private final List baseViews; + private final List
baseViews; // Group ids that are rewritten by this mv to reduce rewrite times private final Set matchedGroups = new HashSet<>(); - private final Plan scanPlan; + // generate form mv scan plan + private ExpressionMapping viewExpressionMapping; - public MaterializationContext(Plan mvPlan, CascadesContext context, - List
baseTables, List baseViews, Plan scanPlan) { - this.mvPlan = mvPlan; - this.context = context; + public MaterializationContext(MTMV mtmv, Plan mvScanPlan, + List
baseTables, + List
baseViews) { + this.mtmv = mtmv; + this.mvScanPlan = mvScanPlan; this.baseTables = baseTables; this.baseViews = baseViews; - this.scanPlan = scanPlan; + MVCache mvCache = mtmv.getMvCache(); + if (mvCache == null) { + // Laze init + mvCache = MVCache.from(mtmv); + mtmv.setMvCache(mvCache); + } + this.viewExpressionMapping = ExpressionMapping.generate( + mvCache.getMvOutputExpressions(), + mvScanPlan.getExpressions()); } public Set getMatchedGroups() { @@ -58,11 +71,32 @@ public void addMatchedGroup(GroupId groupId) { matchedGroups.add(groupId); } - public Plan getMvPlan() { - return mvPlan; + public MTMV getMtmv() { + return mtmv; + } + + public Plan getMvScanPlan() { + return mvScanPlan; + } + + public List
getBaseTables() { + return baseTables; + } + + public List
getBaseViews() { + return baseViews; + } + + public ExpressionMapping getViewExpressionIndexMapping() { + return viewExpressionMapping; } - public Plan getScanPlan() { - return scanPlan; + public static MaterializationContext fromMaterializedView(MTMV materializedView, + Plan mvScanPlan){ + return new MaterializationContext( + materializedView, + mvScanPlan, + ImmutableList.of(), + ImmutableList.of()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/mapping/ExpressionIndexMapping.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/mapping/ExpressionIndexMapping.java deleted file mode 100644 index f63017633a82aab..000000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/mapping/ExpressionIndexMapping.java +++ /dev/null @@ -1,48 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.exploration.mv.mapping; - -import org.apache.doris.nereids.trees.expressions.Expression; - -import com.google.common.collect.ArrayListMultimap; -import com.google.common.collect.Multimap; - -import java.util.List; - -/** - * Expression and it's index mapping - */ -public class ExpressionIndexMapping extends Mapping { - private final Multimap expressionIndexMapping; - - public ExpressionIndexMapping(Multimap expressionIndexMapping) { - this.expressionIndexMapping = expressionIndexMapping; - } - - public Multimap getExpressionIndexMapping() { - return expressionIndexMapping; - } - - public static ExpressionIndexMapping generate(List expressions) { - Multimap expressionIndexMapping = ArrayListMultimap.create(); - for (int i = 0; i < expressions.size(); i++) { - expressionIndexMapping.put(expressions.get(i), i); - } - return new ExpressionIndexMapping(expressionIndexMapping); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/mapping/ExpressionMapping.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/mapping/ExpressionMapping.java new file mode 100644 index 000000000000000..52f751878e033a9 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/mapping/ExpressionMapping.java @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.mv.mapping; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.Lists; +import com.google.common.collect.Multimap; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Expression and it's index mapping + */ +public class ExpressionMapping extends Mapping { + private final Multimap expressionMapping; + + public ExpressionMapping(Multimap expressionMapping) { + this.expressionMapping = expressionMapping; + } + + public Multimap getExpressionMapping() { + return expressionMapping; + } + + public List> flattenMap() { + List>> tmpExpressionPairs = new ArrayList<>(this.expressionMapping.size()); + Map> map = expressionMapping.asMap(); + for (Map.Entry> entry : map.entrySet()) { + List> valueList= new ArrayList<>(entry.getValue().size()); + for (Expression valueExpression : entry.getValue()) { + valueList.add(Pair.of(entry.getKey(), valueExpression)); + } + tmpExpressionPairs.add(valueList); + } + List>> cartesianExpressionMap = Lists.cartesianProduct(tmpExpressionPairs); + + final List> flattenedMap = new ArrayList<>(); + for (List> listPair : cartesianExpressionMap) { + final Map expressionMap = new HashMap<>(); + listPair.forEach(pair -> expressionMap.put(pair.key(), pair.value())); + flattenedMap.add(expressionMap); + } + return flattenedMap; + } + + public static ExpressionMapping generate( + List sourceExpressions, + List targetExpressions) { + final Multimap expressionMultiMap = + ArrayListMultimap.create(); + for (int i = 0; i< sourceExpressions.size(); i++) { + expressionMultiMap.put(sourceExpressions.get(i), targetExpressions.get(i)); + } + return new ExpressionMapping(expressionMultiMap); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java index 6f937cc96d29535..694f0611567ec81 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java @@ -24,6 +24,7 @@ import java.util.ArrayDeque; import java.util.Deque; import java.util.List; +import java.util.Set; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; @@ -210,6 +211,19 @@ default List collectToList(Predicate> predicate) { return (List) result.build(); } + /** + * Collect the nodes that satisfied the predicate to set. + */ + default Set collectToSet(Predicate> predicate) { + ImmutableSet.Builder> result = ImmutableSet.builder(); + foreach(node -> { + if (predicate.test(node)) { + result.add(node); + } + }); + return (Set) result.build(); + } + /** * iterate top down and test predicate if contains any instance of the classes * @param types classes array diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index b24f782b5da93a2..611f6e3c28b328b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule; import org.apache.doris.nereids.trees.TreeNode; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; @@ -44,6 +45,7 @@ import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer; +import org.apache.doris.nereids.util.ExpressionUtils.ExpressionReplacer.ExpressionReplacerContext; import com.google.common.base.Preconditions; import com.google.common.base.Predicate; @@ -303,7 +305,24 @@ public static Optional extractSlotOrCastOnSlot(Expression expr) { * */ public static Expression replace(Expression expr, Map replaceMap) { - return expr.accept(ExpressionReplacer.INSTANCE, replaceMap); + return expr.accept(ExpressionReplacer.INSTANCE, ExpressionReplacerContext.of(replaceMap, false)); + } + + /** + * Replace expression node in the expression tree by `replaceMap` in top-down manner. + * if replaced, create alias + * For example. + *
+     * input expression: a > 1
+     * replaceMap: a -> b + c
+     *
+     * output:
+     * (b + c) as a > 1
+     * 
+ */ + public static Expression replace(Expression expr, Map replaceMap, + boolean withAlias) { + return expr.accept(ExpressionReplacer.INSTANCE, ExpressionReplacerContext.of(replaceMap, true)); } public static List replace(List exprs, @@ -328,18 +347,48 @@ public static List rewriteDownShortCircuit( } private static class ExpressionReplacer - extends DefaultExpressionRewriter> { + extends DefaultExpressionRewriter { public static final ExpressionReplacer INSTANCE = new ExpressionReplacer(); private ExpressionReplacer() { } @Override - public Expression visit(Expression expr, Map replaceMap) { + public Expression visit(Expression expr, ExpressionReplacerContext replacerContext) { + Map replaceMap = replacerContext.getReplaceMap(); + Expression replacedExpression = expr; if (replaceMap.containsKey(expr)) { - return replaceMap.get(expr); + replacedExpression = replaceMap.get(expr); + return replacerContext.isWithAlias() && expr instanceof NamedExpression + ? new Alias(((NamedExpression) expr).getExprId(), + replacedExpression, + expr.getExpressionName()) : replacedExpression; } - return super.visit(expr, replaceMap); + return super.visit(expr, replacerContext); + } + } + + private static class ExpressionReplacerContext { + private final Map replaceMap; + private final boolean withAlias; + + private ExpressionReplacerContext(Map replaceMap, + boolean withAlias) { + this.replaceMap = replaceMap; + this.withAlias = withAlias; + } + + public static ExpressionReplacerContext of(Map replaceMap, + boolean withAlias) { + return new ExpressionReplacerContext(replaceMap, withAlias); + } + + public Map getReplaceMap() { + return replaceMap; + } + + public boolean isWithAlias() { + return withAlias; } }