Skip to content

Commit

Permalink
bitmap roll up develop
Browse files Browse the repository at this point in the history
  • Loading branch information
seawinde committed Dec 15, 2023
1 parent dccba70 commit 4775c12
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,31 @@ public class PlaceholderExpression extends Expression implements AlwaysNotNullab
* 1 based
*/
private final int position;
protected boolean distinct;

public PlaceholderExpression(List<Expression> children, Class<? extends Expression> delegateClazz, int position) {
super(children);
this.delegateClazz = Objects.requireNonNull(delegateClazz, "delegateClazz should not be null");
this.position = position;
}

public PlaceholderExpression(List<Expression> children, Class<? extends Expression> delegateClazz, int position,
boolean distinct) {
super(children);
this.delegateClazz = Objects.requireNonNull(delegateClazz, "delegateClazz should not be null");
this.position = position;
this.distinct = distinct;
}

public static PlaceholderExpression of(Class<? extends Expression> delegateClazz, int position) {
return new PlaceholderExpression(ImmutableList.of(), delegateClazz, position);
}

public static PlaceholderExpression of(Class<? extends Expression> delegateClazz, int position,
boolean distinct) {
return new PlaceholderExpression(ImmutableList.of(), delegateClazz, position, distinct);
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visit(this, context);
Expand All @@ -63,6 +77,10 @@ public int getPosition() {
return position;
}

public boolean isDistinct() {
return distinct;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -80,6 +98,6 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), delegateClazz, position);
return Objects.hash(super.hashCode(), delegateClazz, position, distinct);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.exploration.mv;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.analyzer.PlaceholderExpression;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
Expand All @@ -29,18 +30,27 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapCount;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -53,6 +63,17 @@
*/
public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule {

protected static final Map<PlaceholderExpression, PlaceholderExpression>
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = new HashMap<>();

static {
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
PlaceholderExpression.of(Count.class, 0, true),
new PlaceholderExpression(
ImmutableList.of(PlaceholderExpression.of(ToBitmap.class, 0)),
BitmapUnion.class, 0));
}

@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
Expand Down Expand Up @@ -135,14 +156,16 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
Expression needRollupShuttledExpr = ExpressionUtils.shuttleExpressionWithLineage(
topExpression,
queryTopPlan);

if (!mvExprToMvScanExprQueryBased.containsKey(needRollupShuttledExpr)) {
// function can not rewrite by view
return null;
}

// try to roll up
AggregateFunction needRollupAggFunction = (AggregateFunction) topExpression.firstMatch(
expr -> expr instanceof AggregateFunction);
AggregateFunction rollupAggregateFunction = rollup(needRollupAggFunction,
Function rollupAggregateFunction = rollup(needRollupAggFunction,
mvExprToMvScanExprQueryBased.get(needRollupShuttledExpr));
if (rollupAggregateFunction == null) {
return null;
Expand Down Expand Up @@ -226,15 +249,24 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
}

// only support sum roll up, support other agg functions later.
private AggregateFunction rollup(AggregateFunction originFunction,
private Function rollup(AggregateFunction originFunction,
Expression mappedExpression) {
Class<? extends AggregateFunction> rollupAggregateFunction = originFunction.getRollup();
Class<? extends Function> rollupAggregateFunction = originFunction.getRollup();
if (rollupAggregateFunction == null) {
return null;
}
if (Sum.class.isAssignableFrom(rollupAggregateFunction)) {
return new Sum(originFunction.isDistinct(), mappedExpression);
}
if (Max.class.isAssignableFrom(rollupAggregateFunction)) {
return new Max(originFunction.isDistinct(), mappedExpression);
}
if (Min.class.isAssignableFrom(rollupAggregateFunction)) {
return new Min(originFunction.isDistinct(), mappedExpression);
}
if (BitmapCount.class.isAssignableFrom(rollupAggregateFunction)) {
return new BitmapCount(mappedExpression);
}
// can rollup return null
return null;
}
Expand Down Expand Up @@ -306,4 +338,38 @@ protected boolean checkPattern(StructInfo structInfo) {
}
return true;
}

private boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) {
Class<? extends Function> queryClazz = queryFunction.getClass();
Class<? extends Function> viewClazz = viewFunction.getClass();
if (queryClazz.isAssignableFrom(viewClazz)) {
return true;
}
boolean isDistinct = queryFunction instanceof AggregateFunction
&& ((AggregateFunction) queryFunction).isDistinct();
PlaceholderExpression equivalentFunction = AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.get(
PlaceholderExpression.of(queryFunction.getClass(), 0, isDistinct));
// check is have equivalent function or not
if (equivalentFunction == null){
return false;
}
// current compare
if (!viewFunction.getClass().isAssignableFrom(equivalentFunction.getDelegateClazz())) {
return false;
}
if (!viewFunction.children().isEmpty()) {
// children compare, just compare two level, support more later
List<Expression> equivalentFunctions = equivalentFunction.children();
if (viewFunction.children().size() != equivalentFunctions.size()) {
return false;
}
for (int i = 0; i < viewFunction.children().size(); i++) {
if (!viewFunction.child(i).getClass().equals(
((PlaceholderExpression)equivalentFunctions.get(i)).getDelegateClazz())) {
return false;
}
}
}
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.typecoercion.ExpectsInputTypes;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
Expand Down Expand Up @@ -77,7 +78,7 @@ public boolean isDistinct() {
return distinct;
}

public Class<? extends AggregateFunction> getRollup() {
public Class<? extends Function> getRollup() {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapCount;
import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
Expand Down Expand Up @@ -142,4 +144,13 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

@Override
public Class<? extends Function> getRollup() {
if (this.isDistinct()) {
return BitmapCount.class;
} else {
return Sum.class;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
Expand Down Expand Up @@ -80,4 +81,9 @@ public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) {
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitMax(this, context);
}

@Override
public Class<? extends Function> getRollup() {
return Max.class;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
Expand Down Expand Up @@ -81,4 +82,9 @@ public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) {
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitMin(this, context);
}

@Override
public Class<? extends Function> getRollup() {
return Min.class;
}
}
Loading

0 comments on commit 4775c12

Please sign in to comment.