Skip to content

Commit

Permalink
compress_materialize for aggregate and sort
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Nov 5, 2024
1 parent 800f5c6 commit 916bd8d
Show file tree
Hide file tree
Showing 14 changed files with 496 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
import org.apache.doris.nereids.rules.analysis.CollectJoinConstraint;
import org.apache.doris.nereids.rules.analysis.CollectSubQueryAlias;
import org.apache.doris.nereids.rules.analysis.CompressedMaterialize;
import org.apache.doris.nereids.rules.analysis.EliminateDistinctConstant;
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.EliminateLogicalSelectHint;
Expand Down Expand Up @@ -166,6 +167,7 @@ private static List<RewriteJob> buildAnalyzerJobs(Optional<CustomTableResolver>
topDown(new EliminateGroupByConstant()),

topDown(new SimplifyAggGroupBy()),
topDown(new CompressedMaterialize()),
topDown(new NormalizeAggregate()),
topDown(new HavingToFilter()),
topDown(new QualifyToFilter()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ public enum RuleType {
CHECK_DATA_TYPES(RuleTypeClass.CHECK),

// rewrite rules
COMPRESSED_MATERIALIZE_AGG(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_SORT(RuleTypeClass.REWRITE),
NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
NORMALIZE_SORT(RuleTypeClass.REWRITE),
NORMALIZE_REPEAT(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsBigInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsLargeInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsSmallInt;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
* select A from T group by A
* =>
* select any_value(A) from T group by encode_as_int(A)
*/
public class CompressedMaterialize implements AnalysisRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.COMPRESSED_MATERIALIZE_AGG.build(
logicalAggregate().when(a -> ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().enableCompressMaterialize)
.then(this::compressedMaterializeAggregate)),
RuleType.COMPRESSED_MATERIALIZE_SORT.build(
logicalSort().when(a -> ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().enableCompressMaterialize)
.then(this::compressMaterializeSort)
)
);
}

private LogicalSort<Plan> compressMaterializeSort(LogicalSort<Plan> sort) {
// List<Expression> orderExpressions = sort.getOrderKeys().stream()
// .map(OrderKey::getExpr).collect(Collectors.toList());
List<OrderKey> newOrderKeys = Lists.newArrayList();
boolean changed = false;
for (OrderKey orderKey : sort.getOrderKeys()) {
Expression expr = orderKey.getExpr();
Optional<Expression> encode = getEncodeExpression(expr);
if (encode.isPresent()) {
newOrderKeys.add(new OrderKey(encode.get(),
orderKey.isAsc(),
orderKey.isNullFirst()));
changed = true;
} else {
newOrderKeys.add(orderKey);
}
}
return changed ? sort.withOrderKeys(newOrderKeys) : sort;
}

private Optional<Expression> getEncodeExpression(Expression expression) {
DataType type = expression.getDataType();
Expression encodeExpr = null;
if (type instanceof CharacterType) {
CharacterType ct = (CharacterType) type;
if (ct.getLen() > 0) {
// skip column from variant, like 'L.var["L_SHIPMODE"] AS TEXT'
if (ct.getLen() < 2) {
encodeExpr = new EncodeAsSmallInt(expression);
} else if (ct.getLen() < 4) {
encodeExpr = new EncodeAsInt(expression);
} else if (ct.getLen() < 7) {
encodeExpr = new EncodeAsBigInt(expression);
} else if (ct.getLen() < 15) {
encodeExpr = new EncodeAsLargeInt(expression);
}
}
}
return Optional.ofNullable(encodeExpr);
}

/*
example:
[support] select sum(v) from t group by substring(k, 1,2)
[not support] select substring(k, 1,2), sum(v) from t group by substring(k, 1,2)
[support] select k, sum(v) from t group by k
[not support] select substring(k, 1,2), sum(v) from t group by k
[support] select A as B from T group by A
*/
private Map<Expression, Expression> getEncodeGroupByExpressions(LogicalAggregate<Plan> aggregate) {
Map<Expression, Expression> encodeGroupbyExpressions = Maps.newHashMap();
for (Expression gb : aggregate.getGroupByExpressions()) {
Optional<Expression> encodeExpr = getEncodeExpression(gb);
encodeExpr.ifPresent(expression -> encodeGroupbyExpressions.put(gb, expression));
}
return encodeGroupbyExpressions;
}

private LogicalAggregate<Plan> compressedMaterializeAggregate(LogicalAggregate<Plan> aggregate) {
Map<Expression, Expression> encodeGroupByExpressions = getEncodeGroupByExpressions(aggregate);
if (!encodeGroupByExpressions.isEmpty()) {
List<Expression> newGroupByExpressions = Lists.newArrayList();
for (Expression gp : aggregate.getGroupByExpressions()) {
newGroupByExpressions.add(encodeGroupByExpressions.getOrDefault(gp, gp));
}
List<NamedExpression> newOutputs = Lists.newArrayList();
Map<Expression, Expression> decodeMap = new HashMap<>();
for (Expression gp : encodeGroupByExpressions.keySet()) {
decodeMap.put(gp, new DecodeAsVarchar(encodeGroupByExpressions.get(gp)));
}
for (NamedExpression out : aggregate.getOutputExpressions()) {
Expression replaced = ExpressionUtils.replace(out, decodeMap);
if (out != replaced) {
if (out instanceof SlotReference) {
newOutputs.add(new Alias(out.getExprId(), replaced, out.getName()));
} else if (out instanceof Alias) {
newOutputs.add(((Alias) out).withChildren(replaced.children()));
} else {
// should not reach here
Preconditions.checkArgument(false, "output abnormal: " + aggregate);
}
} else {
newOutputs.add(out);
}
}
aggregate = aggregate.withGroupByAndOutput(newGroupByExpressions, newOutputs);
}
return aggregate;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeStrToInteger;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
Expand All @@ -30,10 +32,12 @@
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Push down filter through project.
Expand Down Expand Up @@ -81,7 +85,7 @@ private static Plan pushDownFilterThroughProject(LogicalFilter<LogicalProject<Pl
return null;
}
project = (LogicalProject<? extends Plan>) project.withChildren(new LogicalFilter<>(
ExpressionUtils.replace(splitConjuncts.second, project.getAliasToProducer()),
ExpressionUtils.replace(eliminateDecodeAndEncode(splitConjuncts.second), project.getAliasToProducer()),
project.child()));
return PlanUtils.filterOrSelf(splitConjuncts.first, project);
}
Expand All @@ -99,7 +103,7 @@ private static Plan pushDownFilterThroughLimitProject(
}
project = project.withProjectsAndChild(project.getProjects(),
new LogicalFilter<>(
ExpressionUtils.replace(splitConjuncts.second,
ExpressionUtils.replace(eliminateDecodeAndEncode(splitConjuncts.second),
project.getAliasToProducer()),
limit.withChildren(project.child())));
return PlanUtils.filterOrSelf(splitConjuncts.first, project);
Expand All @@ -119,4 +123,31 @@ private static Pair<Set<Expression>, Set<Expression>> splitConjunctsByChildOutpu
}
return Pair.of(remainPredicates, pushDownPredicates);
}

private static Set<Expression> eliminateDecodeAndEncode(Set<Expression> expressions) {
return expressions.stream()
.map(PushDownFilterThroughProject::eliminateDecodeAndEncode)
.collect(Collectors.toSet());
}

private static Expression eliminateDecodeAndEncode(Expression expression) {
if (expression instanceof DecodeAsVarchar && expression.child(0) instanceof EncodeStrToInteger) {
return expression.child(0).child(0);
}
boolean hasNewChild = false;
List<Expression> newChildren = Lists.newArrayList();
for (Expression child : expression.children()) {
Expression replace = eliminateDecodeAndEncode(child);
if (replace != child) {
hasNewChild = true;
newChildren.add(replace);
} else {
newChildren.add(child);
}
}
if (hasNewChild) {
return expression.withChildren(newChildren);
}
return expression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* ScalarFunction 'EncodeAsBigInt'.
*/
public class EncodeAsBigInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* ScalarFunction 'EncodeAsInt'.
*/
public class EncodeAsInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(IntegerType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* ScalarFunction 'EncodeAsLargeInt'.
*/
public class EncodeAsLargeInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(LargeIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* ScalarFunction 'CompressAsSmallInt'.
*/
public class EncodeAsSmallInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(SmallIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

/**
* Encode_as_XXXInt
*/
public interface EncodeStrToInteger {
}
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ public int hashCode() {
aggregateParam, maybeUsingStream, requireProperties);
}

public PhysicalHashAggregate<Plan> withGroupByExpressions(List<Expression> newGroupByExpressions) {
return new PhysicalHashAggregate<>(newGroupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(),
requireProperties, physicalProperties, statistics,
child());
}

@Override
public PhysicalHashAggregate<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2118,6 +2118,15 @@ public void setIgnoreShapePlanNodes(String ignoreShapePlanNodes) {
needForward = true, fuzzy = true)
public boolean enableSortSpill = false;

@VariableMgr.VarAttr(
name = "ENABLE_COMPRESS_MATERIALIZE",
description = {"控制是否启用compress materialize。",
"enable compress-materialize. "},
needForward = false, fuzzy = false,
varType = VariableAnnotation.EXPERIMENTAL
)
public boolean enableCompressMaterialize = false;

@VariableMgr.VarAttr(
name = ENABLE_AGG_SPILL,
description = {"控制是否启用聚合算子落盘。默认为 false。",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ public void testCTEInHavingAndSubquery() {
logicalFilter(
logicalProject(
logicalJoin(
logicalAggregate(),
logicalProject(
logicalAggregate()),
logicalProject()
)
)
Expand Down
Loading

0 comments on commit 916bd8d

Please sign in to comment.