From 5dfa8908f39f55b1dcd585915235823dcc7b5145 Mon Sep 17 00:00:00 2001 From: minghong Date: Fri, 13 Sep 2024 18:38:11 +0800 Subject: [PATCH] rf targets to cte --- .../translator/PhysicalPlanTranslator.java | 3 - .../processor/post/RuntimeFilterContext.java | 18 -- .../post/RuntimeFilterGenerator.java | 273 +++++++----------- .../post/RuntimeFilterPushDownVisitor.java | 1 + 4 files changed, 112 insertions(+), 183 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 2a34bc3ca91dd20..f9f6caee7a9f1d8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -1238,9 +1238,6 @@ public PlanFragment visitPhysicalCTEProducer(PhysicalCTEProducer multiCastPlanFragment.setOutputExprs(outputs); context.getCteProduceFragments().put(cteId, multiCastPlanFragment); context.getCteProduceMap().put(cteId, cteProducer); - if (context.getRuntimeTranslator().isPresent()) { - context.getRuntimeTranslator().get().getContext().getCteProduceMap().put(cteId, cteProducer); - } context.getPlanFragments().add(multiCastPlanFragment); return child; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java index 746bb05e9fd191c..50b46848dfc9a14 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin; -import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation; import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter; @@ -118,11 +117,6 @@ public boolean equals(Object other) { private final Map effectiveSrcNodes = Maps.newHashMap(); - private final Map cteProducerMap = Maps.newLinkedHashMap(); - - // cte whose runtime filter has been extracted - private final Set processedCTE = Sets.newHashSet(); - private final SessionVariable sessionVariable; private final FilterSizeLimits limits; @@ -160,10 +154,6 @@ public RuntimeFilterContext(SessionVariable sessionVariable) { this.limits = new FilterSizeLimits(sessionVariable); } - public void setRelationsUsedByPlan(Plan plan, Set relations) { - relationsUsedByPlan.put(plan, relations); - } - /** * return true, if the relation is in the subtree */ @@ -185,14 +175,6 @@ public FilterSizeLimits getLimits() { return limits; } - public Map getCteProduceMap() { - return cteProducerMap; - } - - public Set getProcessedCTE() { - return processedCTE; - } - public void setTargetExprIdToFilter(ExprId id, RuntimeFilter filter) { Preconditions.checkArgument(filter.getTargetSlots().stream().anyMatch(expr -> expr.getExprId() == id)); this.targetExprIdToFilter.computeIfAbsent(id, k -> Lists.newArrayList()).add(filter); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java index 1192a66069716a3..1171c6551dc245f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java @@ -101,132 +101,131 @@ public class RuntimeFilterGenerator extends PlanPostProcessor { @Override public Plan processRoot(Plan plan, CascadesContext ctx) { Plan result = plan.accept(this, ctx); - // cte rf + // try to push rf inside CTEProducer + // collect cteProducers RuntimeFilterContext rfCtx = ctx.getRuntimeFilterContext(); - int cteCount = rfCtx.getProcessedCTE().size(); - if (cteCount != 0) { - Map> cteIdToConsumersWithRF = Maps.newHashMap(); - Map> cteToRFsMap = Maps.newHashMap(); - Map> consumerToRFs = Maps.newHashMap(); - Map> consumerToSrcExpression = Maps.newHashMap(); - List allRFs = rfCtx.getNereidsRuntimeFilter(); - for (RuntimeFilter rf : allRFs) { - for (PhysicalRelation rel : rf.getTargetScans()) { - if (rel instanceof PhysicalCTEConsumer) { - PhysicalCTEConsumer consumer = (PhysicalCTEConsumer) rel; - CTEId cteId = consumer.getCteId(); - cteToRFsMap.computeIfAbsent(cteId, key -> Lists.newArrayList()).add(rf); - cteIdToConsumersWithRF.computeIfAbsent(cteId, key -> Sets.newHashSet()).add(consumer); - consumerToRFs.computeIfAbsent(consumer, key -> Sets.newHashSet()).add(rf); - consumerToSrcExpression.computeIfAbsent(consumer, key -> Sets.newHashSet()) - .add(rf.getSrcExpr()); - } + Map cteProducerMap = plan.collect(PhysicalCTEProducer.class::isInstance) + .stream().collect(Collectors.toMap(p -> ((PhysicalCTEProducer) p).getCteId(), + p-> (PhysicalCTEProducer) p)); + // collect cteConsumers which are RF targets + Map> cteIdToConsumersWithRF = Maps.newHashMap(); + Map> consumerToRFs = Maps.newHashMap(); + Map> consumerToSrcExpression = Maps.newHashMap(); + List allRFs = rfCtx.getNereidsRuntimeFilter(); + for (RuntimeFilter rf : allRFs) { + for (PhysicalRelation rel : rf.getTargetScans()) { + if (rel instanceof PhysicalCTEConsumer) { + PhysicalCTEConsumer consumer = (PhysicalCTEConsumer) rel; + CTEId cteId = consumer.getCteId(); + cteIdToConsumersWithRF.computeIfAbsent(cteId, key -> Sets.newHashSet()).add(consumer); + consumerToRFs.computeIfAbsent(consumer, key -> Sets.newHashSet()).add(rf); + consumerToSrcExpression.computeIfAbsent(consumer, key -> Sets.newHashSet()) + .add(rf.getSrcExpr()); } } - for (CTEId cteId : rfCtx.getCteProduceMap().keySet()) { - // if any consumer does not have RF, RF cannot be pushed down. - // cteIdToConsumersWithRF.get(cteId).size() can not be 1, o.w. this cte will be inlined. - if (cteIdToConsumersWithRF.get(cteId) != null - && ctx.getCteIdToConsumers().get(cteId).size() == cteIdToConsumersWithRF.get(cteId).size() - && cteIdToConsumersWithRF.get(cteId).size() >= 2) { - // check if there is a common srcExpr among all the consumers - Set consumers = cteIdToConsumersWithRF.get(cteId); - PhysicalCTEConsumer consumer0 = consumers.iterator().next(); - Set candidateSrcExpressions = consumerToSrcExpression.get(consumer0); - for (PhysicalCTEConsumer currentConsumer : consumers) { - Set srcExpressionsOnCurrentConsumer = consumerToSrcExpression.get(currentConsumer); - candidateSrcExpressions.retainAll(srcExpressionsOnCurrentConsumer); - if (candidateSrcExpressions.isEmpty()) { - break; - } + } + for (CTEId cteId : cteIdToConsumersWithRF.keySet()) { + // if any consumer does not have RF, RF cannot be pushed down. + // cteIdToConsumersWithRF.get(cteId).size() can not be 1, o.w. this cte will be inlined. + if (ctx.getCteIdToConsumers().get(cteId).size() == cteIdToConsumersWithRF.get(cteId).size() + && cteIdToConsumersWithRF.get(cteId).size() >= 2) { + // check if there is a common srcExpr among all the consumers + Set consumers = cteIdToConsumersWithRF.get(cteId); + PhysicalCTEConsumer consumer0 = consumers.iterator().next(); + Set candidateSrcExpressions = consumerToSrcExpression.get(consumer0); + for (PhysicalCTEConsumer currentConsumer : consumers) { + Set srcExpressionsOnCurrentConsumer = consumerToSrcExpression.get(currentConsumer); + candidateSrcExpressions.retainAll(srcExpressionsOnCurrentConsumer); + if (candidateSrcExpressions.isEmpty()) { + break; } - if (!candidateSrcExpressions.isEmpty()) { - // find RFs to push down - for (Expression srcExpr : candidateSrcExpressions) { - List rfsToPushDown = Lists.newArrayList(); - for (PhysicalCTEConsumer consumer : cteIdToConsumersWithRF.get(cteId)) { - for (RuntimeFilter rf : consumerToRFs.get(consumer)) { - if (rf.getSrcExpr().equals(srcExpr)) { - rfsToPushDown.add(rf); - } + } + if (!candidateSrcExpressions.isEmpty()) { + // find RFs to push down + for (Expression srcExpr : candidateSrcExpressions) { + List rfsToPushDown = Lists.newArrayList(); + for (PhysicalCTEConsumer consumer : cteIdToConsumersWithRF.get(cteId)) { + for (RuntimeFilter rf : consumerToRFs.get(consumer)) { + if (rf.getSrcExpr().equals(srcExpr)) { + rfsToPushDown.add(rf); } } - if (rfsToPushDown.isEmpty()) { - break; - } + } + if (rfsToPushDown.isEmpty()) { + break; + } - // the most right deep buildNode from rfsToPushDown is used as buildNode for pushDown rf - // since the srcExpr are the same, all buildNodes of rfToPushDown are in the same tree path - // the longest ancestors means its corresponding rf build node is the most right deep one. - List rightDeepRfs = Lists.newArrayList(); - List rightDeepAncestors = rfsToPushDown.get(0).getBuilderNode().getAncestors(); - int rightDeepAncestorsSize = rightDeepAncestors.size(); - RuntimeFilter leftTop = rfsToPushDown.get(0); - int leftTopAncestorsSize = rightDeepAncestorsSize; - for (RuntimeFilter rf : rfsToPushDown) { - List ancestors = rf.getBuilderNode().getAncestors(); - int currentAncestorsSize = ancestors.size(); - if (currentAncestorsSize >= rightDeepAncestorsSize) { - if (currentAncestorsSize == rightDeepAncestorsSize) { - rightDeepRfs.add(rf); - } else { - rightDeepAncestorsSize = currentAncestorsSize; - rightDeepAncestors = ancestors; - rightDeepRfs.clear(); - rightDeepRfs.add(rf); - } - } - if (currentAncestorsSize < leftTopAncestorsSize) { - leftTopAncestorsSize = currentAncestorsSize; - leftTop = rf; + // the most right deep buildNode from rfsToPushDown is used as buildNode for pushDown rf + // since the srcExpr are the same, all buildNodes of rfToPushDown are in the same tree path + // the longest ancestors means its corresponding rf build node is the most right deep one. + List rightDeepRfs = Lists.newArrayList(); + List rightDeepAncestors = rfsToPushDown.get(0).getBuilderNode().getAncestors(); + int rightDeepAncestorsSize = rightDeepAncestors.size(); + RuntimeFilter leftTop = rfsToPushDown.get(0); + int leftTopAncestorsSize = rightDeepAncestorsSize; + for (RuntimeFilter rf : rfsToPushDown) { + List ancestors = rf.getBuilderNode().getAncestors(); + int currentAncestorsSize = ancestors.size(); + if (currentAncestorsSize >= rightDeepAncestorsSize) { + if (currentAncestorsSize == rightDeepAncestorsSize) { + rightDeepRfs.add(rf); + } else { + rightDeepAncestorsSize = currentAncestorsSize; + rightDeepAncestors = ancestors; + rightDeepRfs.clear(); + rightDeepRfs.add(rf); } } - Preconditions.checkArgument(rightDeepAncestors.contains(leftTop.getBuilderNode())); - // check nodes between right deep and left top are SPJ and not denied join and not mark join - boolean valid = true; - for (Plan cursor : rightDeepAncestors) { - if (cursor.equals(leftTop.getBuilderNode())) { - break; - } - // valid = valid && SPJ_PLAN.contains(cursor.getClass()); - if (cursor instanceof AbstractPhysicalJoin) { - AbstractPhysicalJoin cursorJoin = (AbstractPhysicalJoin) cursor; - valid = (!RuntimeFilterGenerator.DENIED_JOIN_TYPES - .contains(cursorJoin.getJoinType()) - || cursorJoin.isMarkJoin()) && valid; - } - if (!valid) { - break; - } + if (currentAncestorsSize < leftTopAncestorsSize) { + leftTopAncestorsSize = currentAncestorsSize; + leftTop = rf; + } + } + Preconditions.checkArgument(rightDeepAncestors.contains(leftTop.getBuilderNode())); + // check nodes between right deep and left top are SPJ and not denied join and not mark join + boolean valid = true; + for (Plan cursor : rightDeepAncestors) { + if (cursor.equals(leftTop.getBuilderNode())) { + break; + } + // valid = valid && SPJ_PLAN.contains(cursor.getClass()); + if (cursor instanceof AbstractPhysicalJoin) { + AbstractPhysicalJoin cursorJoin = (AbstractPhysicalJoin) cursor; + valid = (!RuntimeFilterGenerator.DENIED_JOIN_TYPES + .contains(cursorJoin.getJoinType()) + || cursorJoin.isMarkJoin()) && valid; } - if (!valid) { break; } + } + + if (!valid) { + break; + } - for (RuntimeFilter rfToPush : rightDeepRfs) { - Expression rightDeepTargetExpressionOnCTE = null; - int targetCount = rfToPush.getTargetExpressions().size(); - for (int i = 0; i < targetCount; i++) { - PhysicalRelation rel = rfToPush.getTargetScans().get(i); - if (rel instanceof PhysicalCTEConsumer - && ((PhysicalCTEConsumer) rel).getCteId().equals(cteId)) { - rightDeepTargetExpressionOnCTE = rfToPush.getTargetExpressions().get(i); - break; - } + for (RuntimeFilter rfToPush : rightDeepRfs) { + Expression rightDeepTargetExpressionOnCTE = null; + int targetCount = rfToPush.getTargetExpressions().size(); + for (int i = 0; i < targetCount; i++) { + PhysicalRelation rel = rfToPush.getTargetScans().get(i); + if (rel instanceof PhysicalCTEConsumer + && ((PhysicalCTEConsumer) rel).getCteId().equals(cteId)) { + rightDeepTargetExpressionOnCTE = rfToPush.getTargetExpressions().get(i); + break; } + } - boolean pushedDown = doPushDownIntoCTEProducerInternal( + boolean pushedDown = doPushDownIntoCTEProducerInternal( + rfToPush, + rightDeepTargetExpressionOnCTE, + rfCtx, + cteProducerMap.get(cteId) + ); + if (pushedDown) { + rfCtx.removeFilter( rfToPush, - rightDeepTargetExpressionOnCTE, - rfCtx, - rfCtx.getCteProduceMap().get(cteId) - ); - if (pushedDown) { - rfCtx.removeFilter( - rfToPush, - rightDeepTargetExpressionOnCTE.getInputSlotExprIds().iterator().next()); - } + rightDeepTargetExpressionOnCTE.getInputSlotExprIds().iterator().next()); } } } @@ -265,8 +264,8 @@ public PhysicalPlan visitPhysicalHashJoin(PhysicalHashJoin (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0) .collect(Collectors.toList()); - List hashJoinConjuncts = join.getHashJoinConjuncts().stream().collect(Collectors.toList()); - boolean buildSideContainsConsumer = hasCTEConsumerDescendant((PhysicalPlan) join.right()); + List hashJoinConjuncts = join.getHashJoinConjuncts(); + for (int i = 0; i < hashJoinConjuncts.size(); i++) { EqualPredicate equalTo = JoinUtils.swapEqualToForChildrenOrder( (EqualPredicate) hashJoinConjuncts.get(i), join.left().getOutputSet()); @@ -277,9 +276,7 @@ public PhysicalPlan visitPhysicalHashJoin(PhysicalHashJoin pair = ctx.getAliasTransferMap().get(equalTo.right()); - // CteConsumer is not allowed to generate RF in order to avoid RF cycle. - if ((pair == null && buildSideContainsConsumer) - || (pair != null && pair.first instanceof PhysicalCTEConsumer)) { + if (pair == null) { continue; } if (equalTo.left().getInputSlots().size() == 1) { @@ -306,20 +303,6 @@ public PhysicalCTEConsumer visitPhysicalCTEConsumer(PhysicalCTEConsumer scan, Ca return scan; } - @Override - public PhysicalCTEProducer visitPhysicalCTEProducer(PhysicalCTEProducer producer, - CascadesContext context) { - CTEId cteId = producer.getCteId(); - context.getRuntimeFilterContext().getCteProduceMap().put(cteId, producer); - Set processedCTE = context.getRuntimeFilterContext().getProcessedCTE(); - if (!processedCTE.contains(cteId)) { - PhysicalPlan inputPlanNode = (PhysicalPlan) producer.child(0); - inputPlanNode.accept(this, context); - processedCTE.add(cteId); - } - return producer; - } - private void generateBitMapRuntimeFilterForNLJ(PhysicalNestedLoopJoin join, RuntimeFilterContext ctx) { if (join.getJoinType() != JoinType.LEFT_SEMI_JOIN && join.getJoinType() != JoinType.CROSS_JOIN) { @@ -680,38 +663,4 @@ public static void getAllScanInfo(Plan root, Set scans) { } } } - - /** - * Check whether plan root contains cte consumer descendant. - */ - public static boolean hasCTEConsumerDescendant(PhysicalPlan root) { - if (root instanceof PhysicalCTEConsumer) { - return true; - } else if (root.children().size() == 1) { - return hasCTEConsumerDescendant((PhysicalPlan) root.child(0)); - } else { - for (Object child : root.children()) { - if (hasCTEConsumerDescendant((PhysicalPlan) child)) { - return true; - } - } - return false; - } - } - - /** - * Check whether runtime filter target is remote or local - */ - public static boolean hasRemoteTarget(AbstractPlan join, AbstractPlan scan) { - if (scan instanceof PhysicalCTEConsumer) { - return true; - } else { - Preconditions.checkArgument(join.getMutableState(AbstractPlan.FRAGMENT_ID).isPresent(), - "cannot find fragment id for Join node"); - Preconditions.checkArgument(scan.getMutableState(AbstractPlan.FRAGMENT_ID).isPresent(), - "cannot find fragment id for scan node"); - return join.getMutableState(AbstractPlan.FRAGMENT_ID).get() - != scan.getMutableState(AbstractPlan.FRAGMENT_ID).get(); - } - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPushDownVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPushDownVisitor.java index 9a5c6a1daa96923..6f22de2cc2c4c50 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPushDownVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPushDownVisitor.java @@ -33,6 +33,7 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalPlan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEConsumer; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;