diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 6e140fff13f87a..149505b5785e13 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -99,6 +99,7 @@ import org.apache.doris.nereids.trees.expressions.WindowFrame; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; +import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue; import org.apache.doris.nereids.trees.plans.AbstractPlan; import org.apache.doris.nereids.trees.plans.AggMode; import org.apache.doris.nereids.trees.plans.AggPhase; @@ -963,7 +964,7 @@ public PlanFragment visitPhysicalHashAggregate( // 2. collect agg expressions and generate agg function to slot reference map List aggFunctionOutput = Lists.newArrayList(); List aggregateExpressionList = outputExpressions.stream() - .filter(o -> o.anyMatch(AggregateExpression.class::isInstance)) + .filter(o -> o.anyMatch(AggregateExpression.class::isInstance) || o.anyMatch(AnyValue.class::isInstance)) .peek(o -> aggFunctionOutput.add(o.toSlot())) .map(o -> o.collect(AggregateExpression.class::isInstance)) .flatMap(Set::stream) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CompressedMaterialization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CompressedMaterialization.java index e7557e9d562f57..a8b32d407c38da 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CompressedMaterialization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CompressedMaterialization.java @@ -16,8 +16,16 @@ import org.apache.doris.nereids.types.coercion.CharacterType; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import java.util.List; +import java.util.Map; + +/** + * select A, sum(B) from T group by A + * => + * select any_value(A) from T group by encode_as_int(A) + */ public class CompressedMaterialization extends PlanPostProcessor{ @Override @@ -25,10 +33,13 @@ public PhysicalHashAggregate visitPhysicalHashAggregate(PhysicalHashAggregate newGroupByExpressions = Lists.newArrayList(); List encodedExpressions = Lists.newArrayList(); + Map encodeMap = Maps.newHashMap(); for (Expression gp : aggregate.getGroupByExpressions()) { if (gp instanceof SlotReference && canCompress(gp)) { - newGroupByExpressions.add(new Alias(new EncodeAsInt(gp), ((SlotReference) gp).getName())); + Alias alias = new Alias(new EncodeAsInt(gp), ((SlotReference) gp).getName()); + newGroupByExpressions.add(alias); encodedExpressions.add(gp); + encodeMap.put(gp, alias); } else { newGroupByExpressions.add(gp); } @@ -41,6 +52,7 @@ public PhysicalHashAggregate visitPhysicalHashAggregate(PhysicalHashAggregate getProcessors() { if (cascadesContext.getConnectContext().getSessionVariable().enableAggregateCse) { builder.add(new ProjectAggregateExpressionsForCse()); } - builder.add(new CompressedMaterialization()); + // builder.add(new CompressedMaterialization()); builder.add(new CommonSubExpressionOpt()); // DO NOT replace PLAN NODE from here if (cascadesContext.getConnectContext().getSessionVariable().pushTopnToAgg) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index e5ebee120a310c..a18ff819d36676 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -33,13 +33,17 @@ import org.apache.doris.nereids.trees.expressions.SubqueryExpr; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue; import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction; +import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsInt; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs; import org.apache.doris.nereids.util.Utils; @@ -47,6 +51,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import com.google.common.collect.Sets; import java.util.ArrayList; @@ -118,7 +124,59 @@ public List buildRules() { .toRule(RuleType.NORMALIZE_AGGREGATE)); } + private boolean canCompress(Expression expression) { + DataType type = expression.getDataType(); + if (type instanceof CharacterType) { + CharacterType ct = (CharacterType) type; + if (ct.getLen() < 7) { + return true; + } + } + return false; + } + + private LogicalAggregate encode(LogicalAggregate aggregate, Optional> having) { + List newGroupByExpressions = Lists.newArrayList(); + List encodedExpressions = Lists.newArrayList(); + Map encodeMap = Maps.newHashMap(); + for (Expression gp : aggregate.getGroupByExpressions()) { + if (gp instanceof SlotReference && canCompress(gp)) { + Alias alias = new Alias(new EncodeAsInt(gp), ((SlotReference) gp).getName()); + newGroupByExpressions.add(alias); + encodedExpressions.add(gp); + encodeMap.put(gp, alias); + } else { + newGroupByExpressions.add(gp); + } + } + if (!encodedExpressions.isEmpty()) { + // aggregate = aggregate.withGroupByExpressions(newGroupByExpressions); + // boolean hasNewOutput = false; + List newOutput = Lists.newArrayList(); + List output = aggregate.getOutputExpressions(); + for (NamedExpression ne : output) { + if (ne instanceof SlotReference && encodedExpressions.contains(ne)) { + newOutput.add(new Alias(ne.getExprId(), new AnyValue(ne), ne.getName())); + newOutput.add(encodeMap.get(ne)); + // hasNewOutput = true; + } else { + newOutput.add(ne); + } + } + aggregate = aggregate.withGroupByAndOutput(newGroupByExpressions, newOutput); + // if (hasNewOutput) { + // aggregate = aggregate.withAggOutput(newOutput); + // } + } + return aggregate; + } + private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional> having) { + aggregate = encode(aggregate, having); + return normalizeAggInner(aggregate, having); + } + + private LogicalPlan normalizeAggInner(LogicalAggregate aggregate, Optional> having) { // The LogicalAggregate node may contain window agg functions and usual agg functions // we call window agg functions as window-agg and usual agg functions as trivial-agg for short // This rule simplify LogicalAggregate node by: @@ -145,6 +203,7 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional groupingByExprs = Utils.fastToImmutableSet(aggregate.getGroupByExpressions()); // collect all trivial-agg