Skip to content

Commit

Permalink
compress sort key
Browse files Browse the repository at this point in the history
compressed materialize for group by
  • Loading branch information
englefly committed Oct 31, 2024
1 parent dd03546 commit 8d6d16b
Show file tree
Hide file tree
Showing 28 changed files with 769 additions and 264 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
import org.apache.doris.nereids.rules.analysis.CollectJoinConstraint;
import org.apache.doris.nereids.rules.analysis.CollectSubQueryAlias;
import org.apache.doris.nereids.rules.analysis.CompressedMaterialize;
import org.apache.doris.nereids.rules.analysis.EliminateDistinctConstant;
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.EliminateLogicalSelectHint;
Expand Down Expand Up @@ -163,6 +164,7 @@ private static List<RewriteJob> buildAnalyzerJobs(Optional<CustomTableResolver>
topDown(new EliminateGroupByConstant()),

topDown(new SimplifyAggGroupBy()),
topDown(new CompressedMaterialize()),
topDown(new NormalizeAggregate()),
topDown(new HavingToFilter()),
bottomUp(new SemiJoinCommute()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ public enum RuleType {
CHECK_DATA_TYPES(RuleTypeClass.CHECK),

// rewrite rules
COMPRESSED_MATERIALIZE_AGG(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_SORT(RuleTypeClass.REWRITE),
NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
NORMALIZE_SORT(RuleTypeClass.REWRITE),
NORMALIZE_REPEAT(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsBigInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsLargeInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsSmallInt;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* select A from T group by A
* =>
* select any_value(A) from T group by encode_as_int(A)
*/
public class CompressedMaterialize implements AnalysisRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.COMPRESSED_MATERIALIZE_AGG.build(
logicalAggregate().when(a -> ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().enableCompressMaterialize)
.then(this::compressedMaterializeAggregate)),
RuleType.COMPRESSED_MATERIALIZE_SORT.build(
logicalSort().then(this::compressMaterializeSort)
)
);
}

private LogicalSort<Plan> compressMaterializeSort(LogicalSort<Plan> sort) {
// List<Expression> orderExpressions = sort.getOrderKeys().stream()
// .map(OrderKey::getExpr).collect(Collectors.toList());
List<OrderKey> newOrderKeys = Lists.newArrayList();
boolean changed = false;
for (OrderKey orderKey : sort.getOrderKeys()) {
Expression expr = orderKey.getExpr();
Optional<Expression> encode = getEncodeExpression(expr);
if (encode.isPresent()) {
newOrderKeys.add(new OrderKey(encode.get(),
orderKey.isAsc(),
orderKey.isNullFirst()));
changed = true;
} else {
newOrderKeys.add(orderKey);
}
}
return changed ? sort.withOrderKeys(newOrderKeys) : sort;
}

private Optional<Expression> getEncodeExpression(Expression expression) {
DataType type = expression.getDataType();
Expression encodeExpr = null;
if (type instanceof CharacterType) {
CharacterType ct = (CharacterType) type;
if (ct.getLen() > 1) {
// skip column from variant, like 'L.var["L_SHIPMODE"] AS TEXT'
if (ct.getLen() < 2) {
encodeExpr = new EncodeAsSmallInt(expression);
} else if (ct.getLen() < 4) {
encodeExpr = new EncodeAsInt(expression);
} else if (ct.getLen() < 7) {
encodeExpr = new EncodeAsBigInt(expression);
} else if (ct.getLen() < 15) {
encodeExpr = new EncodeAsLargeInt(expression);
}
}
}
return Optional.ofNullable(encodeExpr);
}

/*
example:
[support] select sum(v) from t group by substring(k, 1,2)
[not support] select substring(k, 1,2), sum(v) from t group by substring(k, 1,2)
[support] select k, sum(v) from t group by k
[not support] select substring(k, 1,2), sum(v) from t group by k
[support] select A as B from T group by A
*/
private Map<Expression, Expression> getEncodeGroupByExpressions(LogicalAggregate<Plan> aggregate) {
Map<Expression, Expression> encodeGroupbyExpressions = Maps.newHashMap();
for (Expression gb : aggregate.getGroupByExpressions()) {
Optional<Expression> encodeExpr = getEncodeExpression(gb);
encodeExpr.ifPresent(expression -> encodeGroupbyExpressions.put(gb, expression));
}
return encodeGroupbyExpressions;
}

private LogicalAggregate<Plan> compressedMaterializeAggregate(LogicalAggregate<Plan> aggregate) {
Map<Expression, Expression> encodeGroupByExpressions = getEncodeGroupByExpressions(aggregate);
if (!encodeGroupByExpressions.isEmpty()) {
List<Expression> newGroupByExpressions = Lists.newArrayList();
for (Expression gp : aggregate.getGroupByExpressions()) {
newGroupByExpressions.add(encodeGroupByExpressions.getOrDefault(gp, gp));
}
List<NamedExpression> newOutputs = Lists.newArrayList();
Map<Expression, Expression> decodeMap = new HashMap<>();
for (Expression gp : encodeGroupByExpressions.keySet()) {
decodeMap.put(gp, new DecodeAsVarchar(encodeGroupByExpressions.get(gp)));
}
for (NamedExpression out : aggregate.getOutputExpressions()) {
Expression replaced = ExpressionUtils.replace(out, decodeMap);
if (out != replaced) {
if (out instanceof SlotReference) {
newOutputs.add(new Alias(out.getExprId(), replaced, out.getName()));
} else if (out instanceof Alias) {
newOutputs.add(((Alias) out).withChildren(replaced.children()));
} else {
// should not reach here
Preconditions.checkArgument(false, "output abnormal: " + aggregate);
}
} else {
newOutputs.add(out);
}
}
aggregate = aggregate.withGroupByAndOutput(newGroupByExpressions, newOutputs);
}
return aggregate;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<Logi

// Push down exprs:
// collect group by exprs

Set<Expression> groupingByExprs = Utils.fastToImmutableSet(aggregate.getGroupByExpressions());

// collect all trivial-agg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeStrToInteger;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
Expand All @@ -30,10 +32,12 @@
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Push down filter through project.
Expand Down Expand Up @@ -81,7 +85,7 @@ private static Plan pushDownFilterThroughProject(LogicalFilter<LogicalProject<Pl
return null;
}
project = (LogicalProject<? extends Plan>) project.withChildren(new LogicalFilter<>(
ExpressionUtils.replace(splitConjuncts.second, project.getAliasToProducer()),
ExpressionUtils.replace(eliminateDecodeAndEncode(splitConjuncts.second), project.getAliasToProducer()),
project.child()));
return PlanUtils.filterOrSelf(splitConjuncts.first, project);
}
Expand All @@ -99,7 +103,7 @@ private static Plan pushDownFilterThroughLimitProject(
}
project = project.withProjectsAndChild(project.getProjects(),
new LogicalFilter<>(
ExpressionUtils.replace(splitConjuncts.second,
ExpressionUtils.replace(eliminateDecodeAndEncode(splitConjuncts.second),
project.getAliasToProducer()),
limit.withChildren(project.child())));
return PlanUtils.filterOrSelf(splitConjuncts.first, project);
Expand All @@ -119,4 +123,31 @@ private static Pair<Set<Expression>, Set<Expression>> splitConjunctsByChildOutpu
}
return Pair.of(remainPredicates, pushDownPredicates);
}

private static Set<Expression> eliminateDecodeAndEncode(Set<Expression> expressions) {
return expressions.stream()
.map(PushDownFilterThroughProject::eliminateDecodeAndEncode)
.collect(Collectors.toSet());
}

private static Expression eliminateDecodeAndEncode(Expression expression) {
if (expression instanceof DecodeAsVarchar && expression.child(0) instanceof EncodeStrToInteger) {
return expression.child(0).child(0);
}
boolean hasNewChild = false;
List<Expression> newChildren = Lists.newArrayList();
for (Expression child : expression.children()) {
Expression replace = eliminateDecodeAndEncode(child);
if (replace != child) {
hasNewChild = true;
newChildren.add(replace);
} else {
newChildren.add(child);
}
}
if (hasNewChild) {
return expression.withChildren(newChildren);
}
return expression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* ScalarFunction 'EncodeAsBigInt'.
*/
public class EncodeAsBigInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* ScalarFunction 'EncodeAsInt'.
*/
public class EncodeAsInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(IntegerType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* ScalarFunction 'EncodeAsLargeInt'.
*/
public class EncodeAsLargeInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(LargeIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* ScalarFunction 'CompressAsSmallInt'.
*/
public class EncodeAsSmallInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(SmallIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

/**
* Encode_as_XXXInt
*/
public interface EncodeStrToInteger {
}
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ public int hashCode() {
aggregateParam, maybeUsingStream, requireProperties);
}

public PhysicalHashAggregate<Plan> withGroupByExpressions(List<Expression> newGroupByExpressions) {
return new PhysicalHashAggregate<>(newGroupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(),
requireProperties, physicalProperties, statistics,
child());
}

@Override
public PhysicalHashAggregate<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2098,6 +2098,14 @@ public void setIgnoreShapePlanNodes(String ignoreShapePlanNodes) {
needForward = true, fuzzy = true)
public boolean enableSortSpill = false;

@VariableMgr.VarAttr(
name = "ENABLE_COMPRESS_MATERIALIZE",
description = {"控制是否启用compress materialize。默认为 true。",
"Controls whether to enable compress materialize. "
+ "The default value is true."},
needForward = true, fuzzy = false)
public boolean enableCompressMaterialize = true;

@VariableMgr.VarAttr(
name = ENABLE_AGG_SPILL,
description = {"控制是否启用聚合算子落盘。默认为 false。",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ public void testCTEInHavingAndSubquery() {
logicalFilter(
logicalProject(
logicalJoin(
logicalAggregate(),
logicalProject(
logicalAggregate()),
logicalProject()
)
)
Expand Down
Loading

0 comments on commit 8d6d16b

Please sign in to comment.