Skip to content

Commit

Permalink
[feature](Nereids): use session variable to enable rule (apache#27036)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener authored Nov 20, 2023
1 parent 20d7ab0 commit fec94b7
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -71,7 +72,7 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
logicalAggregate(innerLogicalJoin())
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
.when(agg -> {
Expand All @@ -80,11 +81,19 @@ public List<Rule> buildRules() {
.allMatch(f -> f instanceof Count && !f.isDistinct()
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot));
})
.then(agg -> pushCount(agg, agg.child(), ImmutableList.of()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root;
return pushCount(agg, agg.child(), ImmutableList.of());
})
.toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
.when(agg -> {
Expand All @@ -93,7 +102,15 @@ public List<Rule> buildRules() {
.allMatch(f -> f instanceof Count && !f.isDistinct()
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot));
})
.then(agg -> pushCount(agg, agg.child().child(), agg.child().getProjects()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root;
return pushCount(agg, agg.child().child(), agg.child().getProjects());
})
.toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Relation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
Expand All @@ -29,6 +30,7 @@

import com.google.common.collect.ImmutableList;

import java.util.Set;
import java.util.function.Function;

/**
Expand All @@ -37,6 +39,11 @@
public class PushdownDistinctThroughJoin extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
@Override
public Plan rewriteRoot(Plan plan, JobContext context) {
Set<Integer> enableNereidsRules = context.getCascadesContext().getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_DISTINCT_THROUGH_JOIN.type())) {
return null;
}
return plan.accept(this, context);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -65,19 +66,27 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
logicalAggregate(innerLogicalJoin())
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
0) instanceof Slot);
})
.then(agg -> pushMinMax(agg, agg.child(), ImmutableList.of()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root;
return pushMinMax(agg, agg.child(), ImmutableList.of());
})
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
Expand All @@ -86,7 +95,15 @@ public List<Rule> buildRules() {
f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
0) instanceof Slot);
})
.then(agg -> pushMinMax(agg, agg.child().child(), agg.child().getProjects()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root;
return pushMinMax(agg, agg.child().child(), agg.child().getProjects());
})
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
Expand Down Expand Up @@ -65,25 +66,41 @@ public class PushdownSumThroughJoin implements RewriteRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
logicalAggregate(innerLogicalJoin())
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
})
.then(agg -> pushSum(agg, agg.child(), ImmutableList.of()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_SUM_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root;
return pushSum(agg, agg.child(), ImmutableList.of());
})
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
})
.then(agg -> pushSum(agg, agg.child().child(), agg.child().getProjects()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_SUM_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root;
return pushSum(agg, agg.child().child(), agg.child().getProjects());
})
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN)
);
}
Expand Down
11 changes: 11 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,9 @@ public void setEnableLeftZigZag(boolean enableLeftZigZag) {
@VariableMgr.VarAttr(name = DISABLE_NEREIDS_RULES, needForward = true)
private String disableNereidsRules = "";

@VariableMgr.VarAttr(name = "ENABLE_NEREIDS_RULES", needForward = true)
public String enableNereidsRules = "";

@VariableMgr.VarAttr(name = ENABLE_NEW_COST_MODEL, needForward = true)
private boolean enableNewCostModel = false;

Expand Down Expand Up @@ -2285,6 +2288,14 @@ public Set<Integer> getDisableNereidsRules() {
.collect(ImmutableSet.toImmutableSet());
}

public Set<Integer> getEnableNereidsRules() {
return Arrays.stream(enableNereidsRules.split(",[\\s]*"))
.filter(rule -> !rule.isEmpty())
.map(rule -> rule.toUpperCase(Locale.ROOT))
.map(rule -> RuleType.valueOf(rule).type())
.collect(ImmutableSet.toImmutableSet());
}

public void setEnableNewCostModel(boolean enable) {
this.enableNewCostModel = enable;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.plans.JoinType;
Expand All @@ -28,16 +29,28 @@
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.SessionVariable;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import mockit.Mock;
import mockit.MockUp;
import org.junit.jupiter.api.Test;

import java.util.Set;

class PushdownCountThroughJoinTest implements MemoPatternMatchSupported {
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

@Test
void testSingleCount() {
new MockUp<SessionVariable>() {
@Mock
public Set<Integer> getEnableNereidsRules() {
return ImmutableSet.of(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type());
}
};
Alias count = new Count(scan1.getOutput().get(0)).alias("count");
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
Expand All @@ -46,11 +59,24 @@ void testSingleCount() {

PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushdownCountThroughJoin())
.printlnTree();
.matches(
logicalAggregate(
logicalJoin(
logicalAggregate(),
logicalAggregate()
)
)
);
}

@Test
void testMultiCount() {
new MockUp<SessionVariable>() {
@Mock
public Set<Integer> getEnableNereidsRules() {
return ImmutableSet.of(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type());
}
};
Alias leftCnt1 = new Count(scan1.getOutput().get(0)).alias("leftCnt1");
Alias leftCnt2 = new Count(scan1.getOutput().get(1)).alias("leftCnt2");
Alias rightCnt1 = new Count(scan2.getOutput().get(1)).alias("rightCnt1");
Expand All @@ -62,11 +88,24 @@ void testMultiCount() {

PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushdownCountThroughJoin())
.printlnTree();
.matches(
logicalAggregate(
logicalJoin(
logicalAggregate(),
logicalAggregate()
)
)
);
}

@Test
void testSingleCountStar() {
new MockUp<SessionVariable>() {
@Mock
public Set<Integer> getEnableNereidsRules() {
return ImmutableSet.of(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type());
}
};
Alias count = new Count().alias("countStar");
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
Expand All @@ -75,11 +114,24 @@ void testSingleCountStar() {

PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushdownCountThroughJoin())
.printlnTree();
.matches(
logicalAggregate(
logicalJoin(
logicalAggregate(),
logicalAggregate()
)
)
);
}

@Test
void testSingleCountStarEmptyGroupBy() {
new MockUp<SessionVariable>() {
@Mock
public Set<Integer> getEnableNereidsRules() {
return ImmutableSet.of(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type());
}
};
Alias count = new Count().alias("countStar");
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
Expand All @@ -89,11 +141,24 @@ void testSingleCountStarEmptyGroupBy() {
// shouldn't rewrite.
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushdownCountThroughJoin())
.printlnTree();
.matches(
logicalAggregate(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
)
);
}

@Test
void testBothSideCountAndCountStar() {
new MockUp<SessionVariable>() {
@Mock
public Set<Integer> getEnableNereidsRules() {
return ImmutableSet.of(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type());
}
};
Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt");
Alias rightCnt = new Count(scan2.getOutput().get(0)).alias("rightCnt");
Alias countStar = new Count().alias("countStar");
Expand All @@ -105,6 +170,13 @@ void testBothSideCountAndCountStar() {

PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushdownCountThroughJoin())
.printlnTree();
.matches(
logicalAggregate(
logicalJoin(
logicalAggregate(),
logicalAggregate()
)
)
);
}
}
Loading

0 comments on commit fec94b7

Please sign in to comment.