diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java index 16fe1353facfb6e..c0ab8aa6d53d7a0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java @@ -31,6 +31,7 @@ import org.apache.doris.common.util.DebugUtil; import org.apache.doris.mysql.FieldInfo; import org.apache.doris.nereids.CascadesContext.Lock; +import org.apache.doris.nereids.commonCTE.CteExtractor; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.glue.LogicalPlanAdapter; import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator; @@ -59,6 +60,7 @@ import org.apache.doris.nereids.trees.plans.distribute.DistributePlanner; import org.apache.doris.nereids.trees.plans.distribute.DistributedPlan; import org.apache.doris.nereids.trees.plans.distribute.FragmentIdMapping; +import org.apache.doris.nereids.trees.plans.logical.AbstractLogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -284,6 +286,8 @@ protected Plan planWithoutLock( setRuntimeFilterWaitTimeByTableRowCountAndType(); + extractCommonCTE(); + optimize(); if (statementContext.getConnectContext().getExecutor() != null) { statementContext.getConnectContext().getExecutor().getSummaryProfile().setNereidsOptimizeTime(); @@ -320,6 +324,11 @@ protected LogicalPlan preprocess(LogicalPlan logicalPlan) { return new PlanPreprocessors(statementContext).process(logicalPlan); } + private void extractCommonCTE() { + CteExtractor commonCTE = new CteExtractor((AbstractLogicalPlan) cascadesContext.getRewritePlan()); + commonCTE.execute(); + } + /** * compute rf wait time according to max table row count, if wait time is not default value * olap: diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/commonCTE/CteExtractor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/commonCTE/CteExtractor.java new file mode 100644 index 000000000000000..ab907b456267844 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/commonCTE/CteExtractor.java @@ -0,0 +1,44 @@ +package org.apache.doris.nereids.commonCTE; + +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.AbstractLogicalPlan; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class CteExtractor { + private AbstractLogicalPlan plan; + private Map anchorToSignature = new HashMap<>(); + private Map> signatureToAnchorList = new HashMap<>(); + + public CteExtractor(AbstractLogicalPlan plan) { + this.plan = plan; + } + + public AbstractLogicalPlan execute() { + sign(); + return plan; + } + + private void sign() { + SignatureVisitor visitor = new SignatureVisitor(); + visitor.visit(plan, anchorToSignature); + extract(); + } + + private void extract() { + List a = anchorToSignature.values().stream().collect(Collectors.toList()); + + + for (Plan plan : anchorToSignature.keySet()) { + TableSignature signature = anchorToSignature.get(plan); + List plans = signatureToAnchorList.computeIfAbsent(signature, key -> new ArrayList<>()); + plans.add(plan); + } + } + + +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/commonCTE/SignatureVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/commonCTE/SignatureVisitor.java new file mode 100644 index 000000000000000..38ec46b78226589 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/commonCTE/SignatureVisitor.java @@ -0,0 +1,88 @@ +package org.apache.doris.nereids.commonCTE; + +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor; + +import com.google.common.collect.ImmutableSet; + +import java.util.Map; + +public class SignatureVisitor extends DefaultPlanVisitor >{ + + @Override + public TableSignature visitLogicalCatalogRelation(LogicalCatalogRelation relation, + Map signatureMap) { + TableSignature signature = new TableSignature(true, false, + ImmutableSet.of(relation.getTable().getId())); + signatureMap.put(relation, signature); + return signature; + } + + @Override + public TableSignature visitLogicalFilter(LogicalFilter filter, + Map signatureMap) { + TableSignature childSignature = filter.child().accept(this, signatureMap); + if (filter.child() instanceof LogicalAggregate) { + return TableSignature.EMPTY; + } + signatureMap.put(filter, childSignature); + return childSignature; + } + + @Override + public TableSignature visitLogicalJoin(LogicalJoin join, + Map signatureMap) { + TableSignature signature = TableSignature.EMPTY; + TableSignature leftSignature = join.left().accept(this, signatureMap); + + TableSignature rightSignature = join.right().accept(this, signatureMap); + + if (leftSignature != TableSignature.EMPTY && rightSignature != TableSignature.EMPTY) { + signature = new TableSignature(true, + leftSignature.isContainsAggregation() || rightSignature.isContainsAggregation(), + new ImmutableSet.Builder() + .addAll(leftSignature.getTableIds()) + .addAll(rightSignature.getTableIds()) + .build()); + signatureMap.put(join, signature); + } + return signature; + } + + @Override + public TableSignature visitLogicalAggregate(LogicalAggregate aggregate, + Map signatureMap) { + TableSignature signature = TableSignature.EMPTY; + TableSignature childSignature = aggregate.child().accept(this, signatureMap); + if (childSignature != TableSignature.EMPTY) { + signature = childSignature.withContainsAggregation(true); + signatureMap.put(aggregate, signature); + } + return signature; + } + + @Override + public TableSignature visitLogicalProject(LogicalProject project, + Map signatureMap) { + TableSignature childSignature = project.child().accept(this, signatureMap); + if (childSignature != TableSignature.EMPTY) { + signatureMap.put(project, childSignature); + } + return childSignature; + } + + @Override + public TableSignature visit(Plan plan, Map signatureMap) { + for (Plan child : plan.children()) { + child.accept(this, signatureMap); + } + return TableSignature.EMPTY; + } +} + + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/commonCTE/TableSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/commonCTE/TableSignature.java new file mode 100644 index 000000000000000..d91edca2898380d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/commonCTE/TableSignature.java @@ -0,0 +1,70 @@ +package org.apache.doris.nereids.commonCTE; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Set; + +public class TableSignature { + private final boolean isSPJG; + private final boolean containsAggregation; + private final Set tableIds; + + public static TableSignature EMPTY = new TableSignature(false, false, ImmutableSet.of()); + + public TableSignature(boolean isSPJG, boolean containsAggregation , Set tableIds) { + this.isSPJG = isSPJG; + this.tableIds = ImmutableSet.copyOf(tableIds); + this.containsAggregation = containsAggregation; + } + + public boolean isSPJG() { + return isSPJG; + } + + public boolean isContainsAggregation() { + return containsAggregation; + } + + public Set getTableIds() { + return tableIds; + } + + public TableSignature withContainsAggregation(boolean containsAggregation) { + return new TableSignature(isSPJG, containsAggregation, tableIds); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + if (isSPJG) { + sb.append("SPJG "); + } + if (containsAggregation) { + sb.append("AGG "); + } + if (tableIds != null && !tableIds.isEmpty()) { + sb.append(tableIds); + } + return sb.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TableSignature that = (TableSignature) o; + return isSPJG == that.isSPJG && containsAggregation == that.containsAggregation + && tableIds.equals(that.tableIds); + } + + @Override + public int hashCode() { + return getClass().hashCode() + tableIds.hashCode(); + } +}