Skip to content

Commit

Permalink
v1
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Oct 16, 2024
1 parent ec02796 commit 796b747
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
import org.apache.doris.nereids.trees.expressions.WindowFrame;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.nereids.trees.plans.AggMode;
import org.apache.doris.nereids.trees.plans.AggPhase;
Expand Down Expand Up @@ -963,7 +964,7 @@ public PlanFragment visitPhysicalHashAggregate(
// 2. collect agg expressions and generate agg function to slot reference map
List<Slot> aggFunctionOutput = Lists.newArrayList();
List<AggregateExpression> aggregateExpressionList = outputExpressions.stream()
.filter(o -> o.anyMatch(AggregateExpression.class::isInstance))
.filter(o -> o.anyMatch(AggregateExpression.class::isInstance) || o.anyMatch(AnyValue.class::isInstance))
.peek(o -> aggFunctionOutput.add(o.toSlot()))
.map(o -> o.<AggregateExpression>collect(AggregateExpression.class::isInstance))
.flatMap(Set::stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,30 @@
import org.apache.doris.nereids.types.coercion.CharacterType;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import java.util.List;
import java.util.Map;

/**
* select A, sum(B) from T group by A
* =>
* select any_value(A) from T group by encode_as_int(A)
*/

public class CompressedMaterialization extends PlanPostProcessor{
@Override
public PhysicalHashAggregate visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> aggregate,
CascadesContext context) {
List<Expression> newGroupByExpressions = Lists.newArrayList();
List<Expression> encodedExpressions = Lists.newArrayList();
Map<Expression, Alias> encodeMap = Maps.newHashMap();
for (Expression gp : aggregate.getGroupByExpressions()) {
if (gp instanceof SlotReference && canCompress(gp)) {
newGroupByExpressions.add(new Alias(new EncodeAsInt(gp), ((SlotReference) gp).getName()));
Alias alias = new Alias(new EncodeAsInt(gp), ((SlotReference) gp).getName());
newGroupByExpressions.add(alias);
encodedExpressions.add(gp);
encodeMap.put(gp, alias);
} else {
newGroupByExpressions.add(gp);
}
Expand All @@ -41,6 +52,7 @@ public PhysicalHashAggregate visitPhysicalHashAggregate(PhysicalHashAggregate<?
for (NamedExpression ne : output) {
if (ne instanceof SlotReference && encodedExpressions.contains(ne)) {
newOutput.add(new Alias(ne.getExprId(), new AnyValue(ne), ne.getName()));
newOutput.add(encodeMap.get(ne));
hasNewOutput = true;
} else {
newOutput.add(ne);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public List<PlanPostProcessor> getProcessors() {
if (cascadesContext.getConnectContext().getSessionVariable().enableAggregateCse) {
builder.add(new ProjectAggregateExpressionsForCse());
}
builder.add(new CompressedMaterialization());
// builder.add(new CompressedMaterialization());
builder.add(new CommonSubExpressionOpt());
// DO NOT replace PLAN NODE from here
if (cascadesContext.getConnectContext().getSessionVariable().pushTopnToAgg) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,26 @@
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsInt;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.ArrayList;
Expand Down Expand Up @@ -118,7 +124,59 @@ public List<Rule> buildRules() {
.toRule(RuleType.NORMALIZE_AGGREGATE));
}

private boolean canCompress(Expression expression) {
DataType type = expression.getDataType();
if (type instanceof CharacterType) {
CharacterType ct = (CharacterType) type;
if (ct.getLen() < 7) {
return true;
}
}
return false;
}

private LogicalAggregate<Plan> encode(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having) {
List<Expression> newGroupByExpressions = Lists.newArrayList();
List<Expression> encodedExpressions = Lists.newArrayList();
Map<Expression, Alias> encodeMap = Maps.newHashMap();
for (Expression gp : aggregate.getGroupByExpressions()) {
if (gp instanceof SlotReference && canCompress(gp)) {
Alias alias = new Alias(new EncodeAsInt(gp), ((SlotReference) gp).getName());
newGroupByExpressions.add(alias);
encodedExpressions.add(gp);
encodeMap.put(gp, alias);
} else {
newGroupByExpressions.add(gp);
}
}
if (!encodedExpressions.isEmpty()) {
// aggregate = aggregate.withGroupByExpressions(newGroupByExpressions);
// boolean hasNewOutput = false;
List<NamedExpression> newOutput = Lists.newArrayList();
List<NamedExpression> output = aggregate.getOutputExpressions();
for (NamedExpression ne : output) {
if (ne instanceof SlotReference && encodedExpressions.contains(ne)) {
newOutput.add(new Alias(ne.getExprId(), new AnyValue(ne), ne.getName()));
newOutput.add(encodeMap.get(ne));
// hasNewOutput = true;
} else {
newOutput.add(ne);
}
}
aggregate = aggregate.withGroupByAndOutput(newGroupByExpressions, newOutput);
// if (hasNewOutput) {
// aggregate = aggregate.withAggOutput(newOutput);
// }
}
return aggregate;
}

private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having) {
aggregate = encode(aggregate, having);
return normalizeAggInner(aggregate, having);
}

private LogicalPlan normalizeAggInner(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having) {
// The LogicalAggregate node may contain window agg functions and usual agg functions
// we call window agg functions as window-agg and usual agg functions as trivial-agg for short
// This rule simplify LogicalAggregate node by:
Expand All @@ -145,6 +203,7 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<Logi

// Push down exprs:
// collect group by exprs

Set<Expression> groupingByExprs = Utils.fastToImmutableSet(aggregate.getGroupByExpressions());

// collect all trivial-agg
Expand Down

0 comments on commit 796b747

Please sign in to comment.