Skip to content

Commit

Permalink
aggregate tmp commit
Browse files Browse the repository at this point in the history
  • Loading branch information
seawinde committed Dec 5, 2023
1 parent 720ea15 commit 321c0e5
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.apache.doris.nereids.rules.exploration.join.PushDownProjectThroughSemiJoin;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewAggregateRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectAggregateRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectJoinRule;
import org.apache.doris.nereids.rules.implementation.AggregateStrategies;
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
Expand Down Expand Up @@ -223,6 +225,8 @@ public class RuleSet {

public static final List<Rule> MATERIALIZED_VIEW_RULES = planRuleFactories()
.add(MaterializedViewProjectJoinRule.INSTANCE)
.add(MaterializedViewAggregateRule.INSTANCE)
.add(MaterializedViewProjectAggregateRule.INSTANCE)
.build();

public List<Rule> getDPHypReorderRules() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,87 @@

package org.apache.doris.nereids.rules.exploration.mv;

import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext;
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;

import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Sets;
import static org.apache.doris.nereids.rules.exploration.mv.StructInfo.AGGREGATE_PATTERN_CHECKER;

import java.util.HashSet;
import java.util.List;

/**
* AbstractMaterializedViewAggregateRule
* This is responsible for common aggregate rewriting
* */
public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule {

@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping queryToViewSlotMappings,
Plan tempRewritedPlan,
MaterializationContext materializationContext) {

PlanSplitContext planSplitContext = new PlanSplitContext(Sets.newHashSet(LogicalAggregate.class));
viewStructInfo.getTopPlan().accept(StructInfo.PLAN_SPLITTER, planSplitContext);

LogicalAggregate<Plan> bottomAggregate = (LogicalAggregate<Plan>) planSplitContext.getBottomPlan().get(0);
Plan topPlan = planSplitContext.getTopPlan();
ExpressionMapping aggregateToTopExpressionMapping = generateAggregateToTopMapping(bottomAggregate, topPlan);
return null;
}

private ExpressionMapping generateAggregateToTopMapping(Plan source, Plan target) {
ImmutableMultimap.Builder<Slot, Slot> expressionMappingBuilder = ImmutableMultimap.builder();
List<Slot> sourceOutput = source.getOutput();
List<Slot> targetOutputOutput = target.getOutput();
for (Slot sourceSlot : sourceOutput) {
for (Slot targetSlot : targetOutputOutput) {
if (sourceSlot.equals(targetSlot)) {
expressionMappingBuilder.put(targetSlot, sourceSlot);
}
}
}
return new ExpressionMapping(expressionMappingBuilder.build());
}

// Check Aggregate is simple or not and check join is whether valid or not.
// Support join's input can not contain aggregate Only support project, filter, join, logical relation node and
// join condition should be slot reference equals currently
@Override
protected boolean checkPattern(StructInfo structInfo) {

Plan topPlan = structInfo.getTopPlan();
Boolean valid = topPlan.accept(AGGREGATE_PATTERN_CHECKER, null);
if (!valid) {
return false;
}
HyperGraph hyperGraph = structInfo.getHyperGraph();
HashSet<JoinType> requiredJoinType = Sets.newHashSet(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN);
for (AbstractNode node : hyperGraph.getNodes()) {
StructInfoNode structInfoNode = (StructInfoNode) node;
if (!structInfoNode.getPlan().accept(StructInfo.JOIN_PATTERN_CHECKER,
requiredJoinType)) {
return false;
}
for (Edge edge : hyperGraph.getEdges()) {
if (!edge.getJoin().accept(StructInfo.JOIN_PATTERN_CHECKER, requiredJoinType)) {
return false;
}
}
}
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,29 @@
package org.apache.doris.nereids.rules.exploration.mv;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RulePromise;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;

import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* This is responsible for aggregate rewriting according to different pattern
* */
public class MaterializedViewAggregateRule extends AbstractMaterializedViewAggregateRule implements RewriteRuleFactory {

public static final MaterializedViewAggregateRule INSTANCE = new MaterializedViewAggregateRule();

@Override
public List<Rule> buildRules() {
return null;
return ImmutableList.of(
logicalAggregate(any()).thenApplyMulti(ctx -> {
LogicalAggregate<Plan> root = ctx.root;
return rewrite(root, ctx.cascadesContext);
}).toRule(RuleType.MATERIALIZED_VIEW_ONLY_AGGREGATE, RulePromise.EXPLORE));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.exploration.mv;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RulePromise;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;

import com.google.common.collect.ImmutableList;

import java.util.List;

/**MaterializedViewProjectAggregateRule*/
public class MaterializedViewProjectAggregateRule extends AbstractMaterializedViewAggregateRule implements
RewriteRuleFactory {

public static final MaterializedViewProjectAggregateRule INSTANCE = new MaterializedViewProjectAggregateRule();

@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalAggregate(any()).thenApplyMulti(ctx -> {
LogicalAggregate<Plan> root = ctx.root;
return rewrite(root, ctx.cascadesContext);
}).toRule(RuleType.MATERIALIZED_VIEW_PROJECT_AGGREGATE, RulePromise.EXPLORE));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ public List<Rule> buildRules() {
logicalProject(logicalJoin(any(), any())).thenApplyMulti(ctx -> {
LogicalProject<LogicalJoin<Plan, Plan>> root = ctx.root;
return rewrite(root, ctx.cascadesContext);
}).toRule(RuleType.MATERIALIZED_VIEW_ONLY_JOIN, RulePromise.EXPLORE));
}).toRule(RuleType.MATERIALIZED_VIEW_PROJECT_JOIN, RulePromise.EXPLORE));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,20 @@
import org.apache.doris.nereids.trees.plans.algebra.Filter;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.trees.plans.algebra.Project;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
Expand All @@ -52,8 +55,9 @@ public class StructInfo {
public static final JoinPatternChecker JOIN_PATTERN_CHECKER = new JoinPatternChecker();
private static final RelationCollector RELATION_COLLECTOR = new RelationCollector();
private static final PredicateCollector PREDICATE_COLLECTOR = new PredicateCollector();
public static final AggregatePatternChecker AGGREGATE_PATTERN_CHECKER = new AggregatePatternChecker();
// struct info splitter
private static final PlanSplitter PLAN_SPLITTER = new PlanSplitter();
public static final PlanSplitter PLAN_SPLITTER = new PlanSplitter();
// source data
private final Plan originalPlan;
private final HyperGraph hyperGraph;
Expand All @@ -79,9 +83,10 @@ private StructInfo(Plan originalPlan, @Nullable Plan topPlan, @Nullable Plan bot
private void init() {

if (topPlan == null || bottomPlan == null) {
List<Plan> topPlans = new ArrayList<>();
this.bottomPlan = originalPlan.accept(PLAN_SPLITTER, topPlans);
this.topPlan = topPlans.get(0);
PlanSplitContext planSplitContext = new PlanSplitContext(Sets.newHashSet(LogicalJoin.class));
originalPlan.accept(PLAN_SPLITTER, planSplitContext);
this.bottomPlan = planSplitContext.getBottomPlan().get(0);
this.topPlan = planSplitContext.getTopPlan();
}

this.predicates = Predicates.of();
Expand Down Expand Up @@ -142,13 +147,13 @@ private void init() {
public static List<StructInfo> of(Plan originalPlan) {
// TODO only consider the inner join currently, Should support outer join
// Split plan by the boundary which contains multi child
List<Plan> topPlans = new ArrayList<>();
Plan bottomPlan = originalPlan.accept(PLAN_SPLITTER, topPlans);
Plan topPlan = topPlans.get(0);
PlanSplitContext planSplitContext = new PlanSplitContext(Sets.newHashSet(LogicalJoin.class));
originalPlan.accept(PLAN_SPLITTER, planSplitContext);

List<HyperGraph> structInfos = HyperGraph.toStructInfo(bottomPlan);
List<HyperGraph> structInfos = HyperGraph.toStructInfo(planSplitContext.getBottomPlan().get(0));
return structInfos.stream()
.map(hyperGraph -> new StructInfo(originalPlan, topPlan, bottomPlan, hyperGraph))
.map(hyperGraph -> new StructInfo(originalPlan, planSplitContext.getTopPlan(),
planSplitContext.getBottomPlan().get(0), hyperGraph))
.collect(Collectors.toList());
}

Expand Down Expand Up @@ -189,6 +194,14 @@ public boolean isValid() {
return valid;
}

public Plan getTopPlan() {
return topPlan;
}

public Plan getBottomPlan() {
return bottomPlan;
}

public List<? extends Expression> getExpressions() {
return originalPlan instanceof LogicalProject
? ((LogicalProject<Plan>) originalPlan).getProjects() : originalPlan.getOutput();
Expand Down Expand Up @@ -224,23 +237,56 @@ public Void visit(Plan plan, Set<Expression> predicates) {
}
}

private static class PlanSplitter extends DefaultPlanRewriter<List<Plan>> {

/**
* Split the plan into bottom and up, the boundary is given by context,
* the bottom contains the boundary.
*/
public static class PlanSplitter extends DefaultPlanVisitor<Void, PlanSplitContext> {
@Override
public Plan visitLogicalRelation(LogicalRelation relation, List<Plan> topPlans) {
return relation;
public Void visit(Plan plan, PlanSplitContext context) {
if (context.getTopPlan() == null) {
context.setTopPlan(plan);
}
if (context.getClass().isAssignableFrom(plan.getClass())) {
context.setBottomPlan(plan.children());
context.setTopPlan(plan);
}
return super.visit(plan, context);
}
}

@Override
public Plan visit(Plan plan, List<Plan> topPlans) {
if (plan instanceof Join || plan instanceof SetOperation) {
return plan;
} else {
if (topPlans.isEmpty()) {
topPlans.add(plan);
public static class PlanSplitContext {
private List<Plan> bottomPlan;
private Plan topPlan;
private Set<Class<? extends Plan>> boundaryPlanClazzSet;

public PlanSplitContext(Set<Class<? extends Plan>> boundaryPlanClazzSet) {
this.boundaryPlanClazzSet = boundaryPlanClazzSet;
}

public List<Plan> getBottomPlan() {
return bottomPlan;
}

public void setBottomPlan(List<Plan> bottomPlan) {
this.bottomPlan = bottomPlan;
}

public Plan getTopPlan() {
return topPlan;
}

public void setTopPlan(Plan topPlan) {
this.topPlan = topPlan;
}

public boolean isBoundary(Plan plan) {
for (Class<? extends Plan> boundaryPlanClazz : boundaryPlanClazzSet) {
if (boundaryPlanClazz.isAssignableFrom(plan.getClass())) {
return true;
}
return plan.children().get(0).accept(this, topPlans);
}
return false;
}
}

Expand Down Expand Up @@ -269,4 +315,28 @@ public Boolean visit(Plan plan, Set<JoinType> requiredJoinType) {
return true;
}
}

/**
* AggregatePatternChecker
*/
public static class AggregatePatternChecker extends DefaultPlanVisitor<Boolean, Void> {
@Override
public Boolean visit(Plan plan, Void context) {
if (plan instanceof LogicalAggregate) {
LogicalAggregate<Plan> aggregate = (LogicalAggregate<Plan>) plan;
Optional<LogicalRepeat<?>> sourceRepeat = aggregate.getSourceRepeat();
if (sourceRepeat.isPresent()) {
return false;
}
super.visit(aggregate, context);
return true;
}
if (plan instanceof LogicalProject) {
super.visit(plan, context);
return true;
}
super.visit(plan, context);
return false;
}
}
}

0 comments on commit 321c0e5

Please sign in to comment.