Skip to content

Commit

Permalink
push encode
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Dec 24, 2024
1 parent 9b0d962 commit d436870
Show file tree
Hide file tree
Showing 40 changed files with 1,377 additions and 180 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -342,4 +342,9 @@ public TSortInfo toThrift() {
}
return sortInfo;
}

@Override
public String toString() {
return orderingExprs.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ public static ConnectContext createMTMVContext(MTMV mtmv) {
ctx.getSessionVariable().allowModifyMaterializedViewData = true;
// Disable add default limit rule to avoid refresh data wrong
ctx.getSessionVariable().setDisableNereidsRules(
String.join(",", ImmutableSet.of(RuleType.ADD_DEFAULT_LIMIT.name())));
String.join(",", ImmutableSet.of(
"COMPRESSED_MATERIALIZE_AGG", "COMPRESSED_MATERIALIZE_SORT",
RuleType.ADD_DEFAULT_LIMIT.name())));
Optional<String> workloadGroup = mtmv.getWorkloadGroup();
if (workloadGroup.isPresent()) {
ctx.getSessionVariable().setWorkloadGroup(workloadGroup.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite;
import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite;
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
import org.apache.doris.nereids.rules.rewrite.DecoupleEncodeDecode;
import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult;
import org.apache.doris.nereids.rules.rewrite.EliminateAggCaseWhen;
import org.apache.doris.nereids.rules.rewrite.EliminateAggregate;
Expand Down Expand Up @@ -115,6 +116,7 @@
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
import org.apache.doris.nereids.rules.rewrite.PushDownAggWithDistinctThroughJoinOneSide;
import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownEncodeSlot;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.PushDownLimit;
import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin;
Expand Down Expand Up @@ -253,6 +255,13 @@ public class Rewriter extends AbstractBatchJobExecutor {
new CountLiteralRewrite(),
new NormalizeSort()
),

topDown(// must behind NormalizeAggregate/NormalizeSort
new MergeProjects(),
new PushDownEncodeSlot(),
new DecoupleEncodeDecode()
),

topic("Window analysis",
topDown(
new ExtractAndNormalizeWindowExpression(),
Expand Down Expand Up @@ -372,9 +381,11 @@ public class Rewriter extends AbstractBatchJobExecutor {
// generate one PhysicalLimit if current distribution is gather or two
// PhysicalLimits with gather exchange
topDown(new LimitSortToTopN()),
topDown(new SimplifyEncodeDecode()),
topDown(new LimitAggToTopNAgg()),
topDown(new MergeTopNs()),
topDown(new SimplifyEncodeDecode(),
new MergeProjects()
),
topDown(new LimitAggToTopNAgg()),
topDown(new SplitLimit()),
topDown(
new PushDownLimit(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,14 @@ public class LogicalProperties {
protected final Supplier<DataTrait> dataTraitSupplier;
private Integer hashCode = null;

public LogicalProperties(Supplier<List<Slot>> outputSupplier,
Supplier<DataTrait> dataTraitSupplier) {
this(outputSupplier, dataTraitSupplier, ImmutableList::of);
}

/**
* constructor of LogicalProperties.
*
* @param outputSupplier provide the output. Supplier can lazy compute output without
* throw exception for which children have UnboundRelation
*/
public LogicalProperties(Supplier<List<Slot>> outputSupplier,
Supplier<DataTrait> dataTraitSupplier,
Supplier<List<Slot>> nonUserVisibleOutputSupplier) {
Supplier<DataTrait> dataTraitSupplier) {
this.outputSupplier = Suppliers.memoize(
Objects.requireNonNull(outputSupplier, "outputSupplier can not be null")
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ public enum RuleType {
// rewrite rules
COMPRESSED_MATERIALIZE_AGG(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_SORT(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_REPEAT(RuleTypeClass.REWRITE),
PUSH_DOWN_ENCODE_SLOT(RuleTypeClass.REWRITE),
DECOUPLE_DECODE_ENCODE_SLOT(RuleTypeClass.REWRITE),
SIMPLIFY_ENCODE_DECODE(RuleTypeClass.REWRITE),
NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
NORMALIZE_SORT(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsSmallInt;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.CharacterType;
Expand Down Expand Up @@ -72,6 +73,11 @@ public List<Rule> buildRules() {
logicalSort().when(a -> ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().enableCompressMaterialize)
.then(this::compressMaterializeSort)
// ),
// RuleType.COMPRESSED_MATERIALIZE_REPEAT.build(
// logicalRepeat().when(r -> ConnectContext.get() != null
// && ConnectContext.get().getSessionVariable().enableCompressMaterialize)
// .then(this::compressMaterializeRepeat)
)
);
}
Expand Down Expand Up @@ -101,6 +107,9 @@ private LogicalSort<Plan> compressMaterializeSort(LogicalSort<Plan> sort) {
}

private Optional<Expression> getEncodeExpression(Expression expression) {
if (expression.isConstant()) {
return Optional.empty();
}
DataType type = expression.getDataType();
Expression encodeExpr = null;
if (type instanceof CharacterType) {
Expand Down Expand Up @@ -169,4 +178,52 @@ private LogicalAggregate<Plan> compressedMaterializeAggregate(LogicalAggregate<P
}
return aggregate;
}

private Map<Expression, Expression> getEncodeGroupingSets(LogicalRepeat<Plan> repeat) {
Map<Expression, Expression> encode = Maps.newHashMap();
// the first grouping set contains all group by keys
for (Expression gb : repeat.getGroupingSets().get(0)) {
Optional<Expression> encodeExpr = getEncodeExpression(gb);
encodeExpr.ifPresent(expression -> encode.put(gb, expression));
}
return encode;
}

private LogicalRepeat<Plan> compressMaterializeRepeat(LogicalRepeat<Plan> repeat) {
Map<Expression, Expression> encode = getEncodeGroupingSets(repeat);
if (encode.isEmpty()) {
return repeat;
}
List<List<Expression>> newGroupingSets = Lists.newArrayList();
for (int i = 0; i < repeat.getGroupingSets().size(); i++) {
List<Expression> grouping = Lists.newArrayList();
for (int j = 0; j < repeat.getGroupingSets().get(i).size(); j++) {
Expression groupingExpr = repeat.getGroupingSets().get(i).get(j);
grouping.add(encode.getOrDefault(groupingExpr, groupingExpr));
}
newGroupingSets.add(grouping);
}
List<NamedExpression> newOutputs = Lists.newArrayList();
Map<Expression, Expression> decodeMap = new HashMap<>();
for (Expression gp : encode.keySet()) {
decodeMap.put(gp, new DecodeAsVarchar(encode.get(gp)));
}
for (NamedExpression out : repeat.getOutputExpressions()) {
Expression replaced = ExpressionUtils.replace(out, decodeMap);
if (out != replaced) {
if (out instanceof SlotReference) {
newOutputs.add(new Alias(out.getExprId(), replaced, out.getName()));
} else if (out instanceof Alias) {
newOutputs.add(((Alias) out).withChildren(replaced.children()));
} else {
// should not reach here
Preconditions.checkArgument(false, "output abnormal: " + repeat);
}
} else {
newOutputs.add(out);
}
}
repeat = repeat.withGroupSetsAndOutput(newGroupingSets, newOutputs);
return repeat;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeString;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.Lists;

import java.util.List;

/**
* in project:
* decode_as_varchar(encode_as_xxx(v)) => v
*/
public class DecoupleEncodeDecode extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalProject().then(this::rewrite)
.toRule(RuleType.DECOUPLE_DECODE_ENCODE_SLOT);
}

private LogicalProject<?> rewrite(LogicalProject<?> project) {
List<NamedExpression> newProjections = Lists.newArrayList();
boolean hasNewProjections = false;
for (NamedExpression e : project.getProjects()) {
boolean changed = false;
if (e instanceof Alias) {
Alias alias = (Alias) e;
Expression body = alias.child();
if (body instanceof DecodeAsVarchar && body.child(0) instanceof EncodeString) {
Expression encodeBody = body.child(0).child(0);
newProjections.add((NamedExpression) alias.withChildren(encodeBody));
changed = true;
} else if (body instanceof EncodeString && body.child(0) instanceof DecodeAsVarchar) {
Expression decodeBody = body.child(0).child(0);
newProjections.add((NamedExpression) alias.withChildren(decodeBody));
changed = true;
}
}
if (!changed) {
newProjections.add(e);
hasNewProjections = true;
}
}
if (hasNewProjections) {
project = project.withProjects(newProjections);
}
return project;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
Expand All @@ -35,6 +34,8 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.HashMap;
import java.util.List;
Expand All @@ -50,6 +51,8 @@
* 2. push limit to local agg
*/
public class LimitAggToTopNAgg implements RewriteRuleFactory {
public static final Logger LOG = LogManager.getLogger(LimitAggToTopNAgg.class);

@Override
public List<Rule> buildRules() {
return ImmutableList.of(
Expand Down Expand Up @@ -122,6 +125,8 @@ public List<Rule> buildRules() {
LogicalTopN originTopn = topn;
LogicalProject<? extends Plan> project = topn.child();
LogicalAggregate<? extends Plan> agg = (LogicalAggregate) project.child();
StringBuilder builder = new StringBuilder();
builder.append("@@@@@###");
if (!project.isAllSlots()) {
/*
topn(orderKey=[a])
Expand All @@ -139,6 +144,9 @@ public List<Rule> buildRules() {
keyAsKey.put((SlotReference) e.toSlot(), (SlotReference) e.child(0));
}
}
builder.append(topn);
builder.append(project);

List<OrderKey> projectOrderKeys = Lists.newArrayList();
boolean hasNew = false;
for (OrderKey orderKey : topn.getOrderKeys()) {
Expand All @@ -157,36 +165,24 @@ public List<Rule> buildRules() {
supplementOrderKeyByGroupKeyIfCompatible(topn, agg);
Plan result;
if (pair == null) {
builder.append("|not compatible");
result = originTopn;
} else {
builder.append("|compatible");
agg = agg.withGroupBy(pair.second);
topn = (LogicalTopN) topn.withOrderKeys(pair.first);
if (isOrderKeysInProject(topn, project)) {
project = (LogicalProject<? extends Plan>) project.withChildren(agg);
topn = (LogicalTopN<LogicalProject<LogicalAggregate<Plan>>>)
topn.withChildren(project);
result = topn;
} else {
topn = (LogicalTopN) topn.withChildren(agg);
project = (LogicalProject<? extends Plan>) project.withChildren(topn);
result = project;
}
topn = (LogicalTopN) topn.withChildren(agg);
project = (LogicalProject<? extends Plan>) project.withChildren(topn);
result = project;
}
LOG.warn(builder.toString());
LOG.warn("@@@@@###originTopn " + originTopn.treeString());
LOG.warn("@@@@@###result " + result.treeString());
return result;
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG)
);
}

private boolean isOrderKeysInProject(LogicalTopN<? extends Plan> topn, LogicalProject project) {
Set<Slot> projectSlots = project.getOutputSet();
for (OrderKey orderKey : topn.getOrderKeys()) {
if (!projectSlots.contains(orderKey.getExpr())) {
return false;
}
}
return true;
}

private List<OrderKey> generateOrderKeyByGroupKey(LogicalAggregate<? extends Plan> agg) {
return agg.getGroupByExpressions().stream()
.map(key -> new OrderKey(key, true, false))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -586,10 +586,17 @@ boolean comparePlan(Plan plan1, Plan plan2) {
isEqual = false;
}
for (int i = 0; isEqual && i < plan2.getOutput().size(); i++) {
NamedExpression expr = ((LogicalProject<?>) plan1).getProjects().get(i);
NamedExpression replacedExpr = (NamedExpression)
expr.rewriteUp(e -> plan1ToPlan2.getOrDefault(e, e));
if (!replacedExpr.equals(((LogicalProject<?>) plan2).getProjects().get(i))) {
Expression expr1 = ((LogicalProject<?>) plan1).getProjects().get(i);
Expression expr2 = ((LogicalProject<?>) plan2).getProjects().get(i);
if (expr1 instanceof Alias) {
if (!(expr2 instanceof Alias)) {
return false;
}
expr1 = expr1.child(0);
expr2 = expr2.child(0);
}
Expression replacedExpr = expr1.rewriteUp(e -> plan1ToPlan2.getOrDefault(e, e));
if (!replacedExpr.equals(expr2)) {
isEqual = false;
}
}
Expand Down
Loading

0 comments on commit d436870

Please sign in to comment.