Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat](nereids) push down encode slot #44748

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
The diff you're trying to view is too large. We only load the first 3000 changed files.
Original file line number Diff line number Diff line change
Expand Up @@ -342,4 +342,9 @@ public TSortInfo toThrift() {
}
return sortInfo;
}

@Override
public String toString() {
return orderingExprs.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ public static ConnectContext createMTMVContext(MTMV mtmv) {
ctx.getSessionVariable().allowModifyMaterializedViewData = true;
// Disable add default limit rule to avoid refresh data wrong
ctx.getSessionVariable().setDisableNereidsRules(
String.join(",", ImmutableSet.of(RuleType.ADD_DEFAULT_LIMIT.name())));
String.join(",", ImmutableSet.of(
"COMPRESSED_MATERIALIZE_AGG", "COMPRESSED_MATERIALIZE_SORT",
RuleType.ADD_DEFAULT_LIMIT.name())));
Optional<String> workloadGroup = mtmv.getWorkloadGroup();
if (workloadGroup.isPresent()) {
ctx.getSessionVariable().setWorkloadGroup(workloadGroup.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType;
import org.apache.doris.nereids.rules.rewrite.AdjustNullable;
import org.apache.doris.nereids.rules.rewrite.AdjustPreAggStatus;
import org.apache.doris.nereids.rules.rewrite.AdjustTopNProject;
import org.apache.doris.nereids.rules.rewrite.AggScalarSubQueryToWindowFunction;
import org.apache.doris.nereids.rules.rewrite.BuildAggForUnion;
import org.apache.doris.nereids.rules.rewrite.CTEInline;
Expand All @@ -55,6 +56,7 @@
import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite;
import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite;
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
import org.apache.doris.nereids.rules.rewrite.DecoupleEncodeDecode;
import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult;
import org.apache.doris.nereids.rules.rewrite.EliminateAggCaseWhen;
import org.apache.doris.nereids.rules.rewrite.EliminateAggregate;
Expand Down Expand Up @@ -115,6 +117,7 @@
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
import org.apache.doris.nereids.rules.rewrite.PushDownAggWithDistinctThroughJoinOneSide;
import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownEncodeSlot;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.PushDownLimit;
import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin;
Expand Down Expand Up @@ -253,6 +256,13 @@ public class Rewriter extends AbstractBatchJobExecutor {
new CountLiteralRewrite(),
new NormalizeSort()
),

topDown(// must behind NormalizeAggregate/NormalizeSort
new MergeProjects(),
new PushDownEncodeSlot(),
new DecoupleEncodeDecode()
),

topic("Window analysis",
topDown(
new ExtractAndNormalizeWindowExpression(),
Expand Down Expand Up @@ -372,9 +382,11 @@ public class Rewriter extends AbstractBatchJobExecutor {
// generate one PhysicalLimit if current distribution is gather or two
// PhysicalLimits with gather exchange
topDown(new LimitSortToTopN()),
topDown(new SimplifyEncodeDecode()),
topDown(new LimitAggToTopNAgg()),
topDown(new MergeTopNs()),
topDown(new SimplifyEncodeDecode(),
new MergeProjects()
),
topDown(new LimitAggToTopNAgg()),
topDown(new SplitLimit()),
topDown(
new PushDownLimit(),
Expand Down Expand Up @@ -466,6 +478,9 @@ public class Rewriter extends AbstractBatchJobExecutor {
custom(RuleType.ADD_PROJECT_FOR_JOIN, AddProjectForJoin::new),
topDown(new MergeProjects())
),
topic("Adjust topN project",
topDown(new MergeProjects(),
new AdjustTopNProject())),
// this rule batch must keep at the end of rewrite to do some plan check
topic("Final rewrite and check",
custom(RuleType.CHECK_DATA_TYPES, CheckDataTypes::new),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,14 @@ public class LogicalProperties {
protected final Supplier<DataTrait> dataTraitSupplier;
private Integer hashCode = null;

public LogicalProperties(Supplier<List<Slot>> outputSupplier,
Supplier<DataTrait> dataTraitSupplier) {
this(outputSupplier, dataTraitSupplier, ImmutableList::of);
}

/**
* constructor of LogicalProperties.
*
* @param outputSupplier provide the output. Supplier can lazy compute output without
* throw exception for which children have UnboundRelation
*/
public LogicalProperties(Supplier<List<Slot>> outputSupplier,
Supplier<DataTrait> dataTraitSupplier,
Supplier<List<Slot>> nonUserVisibleOutputSupplier) {
Supplier<DataTrait> dataTraitSupplier) {
this.outputSupplier = Suppliers.memoize(
Objects.requireNonNull(outputSupplier, "outputSupplier can not be null")
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ public enum RuleType {
// rewrite rules
COMPRESSED_MATERIALIZE_AGG(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_SORT(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_REPEAT(RuleTypeClass.REWRITE),
PUSH_DOWN_ENCODE_SLOT(RuleTypeClass.REWRITE),
ADJUST_TOPN_PROJECT(RuleTypeClass.REWRITE),
DECOUPLE_DECODE_ENCODE_SLOT(RuleTypeClass.REWRITE),
SIMPLIFY_ENCODE_DECODE(RuleTypeClass.REWRITE),
NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
NORMALIZE_SORT(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsSmallInt;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.CharacterType;
Expand Down Expand Up @@ -72,6 +73,11 @@ public List<Rule> buildRules() {
logicalSort().when(a -> ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().enableCompressMaterialize)
.then(this::compressMaterializeSort)
// ),
// RuleType.COMPRESSED_MATERIALIZE_REPEAT.build(
// logicalRepeat().when(r -> ConnectContext.get() != null
// && ConnectContext.get().getSessionVariable().enableCompressMaterialize)
// .then(this::compressMaterializeRepeat)
)
);
}
Expand Down Expand Up @@ -101,6 +107,9 @@ private LogicalSort<Plan> compressMaterializeSort(LogicalSort<Plan> sort) {
}

private Optional<Expression> getEncodeExpression(Expression expression) {
if (expression.isConstant()) {
return Optional.empty();
}
DataType type = expression.getDataType();
Expression encodeExpr = null;
if (type instanceof CharacterType) {
Expand Down Expand Up @@ -169,4 +178,52 @@ private LogicalAggregate<Plan> compressedMaterializeAggregate(LogicalAggregate<P
}
return aggregate;
}

private Map<Expression, Expression> getEncodeGroupingSets(LogicalRepeat<Plan> repeat) {
Map<Expression, Expression> encode = Maps.newHashMap();
// the first grouping set contains all group by keys
for (Expression gb : repeat.getGroupingSets().get(0)) {
Optional<Expression> encodeExpr = getEncodeExpression(gb);
encodeExpr.ifPresent(expression -> encode.put(gb, expression));
}
return encode;
}

private LogicalRepeat<Plan> compressMaterializeRepeat(LogicalRepeat<Plan> repeat) {
Map<Expression, Expression> encode = getEncodeGroupingSets(repeat);
if (encode.isEmpty()) {
return repeat;
}
List<List<Expression>> newGroupingSets = Lists.newArrayList();
for (int i = 0; i < repeat.getGroupingSets().size(); i++) {
List<Expression> grouping = Lists.newArrayList();
for (int j = 0; j < repeat.getGroupingSets().get(i).size(); j++) {
Expression groupingExpr = repeat.getGroupingSets().get(i).get(j);
grouping.add(encode.getOrDefault(groupingExpr, groupingExpr));
}
newGroupingSets.add(grouping);
}
List<NamedExpression> newOutputs = Lists.newArrayList();
Map<Expression, Expression> decodeMap = new HashMap<>();
for (Expression gp : encode.keySet()) {
decodeMap.put(gp, new DecodeAsVarchar(encode.get(gp)));
}
for (NamedExpression out : repeat.getOutputExpressions()) {
Expression replaced = ExpressionUtils.replace(out, decodeMap);
if (out != replaced) {
if (out instanceof SlotReference) {
newOutputs.add(new Alias(out.getExprId(), replaced, out.getName()));
} else if (out instanceof Alias) {
newOutputs.add(((Alias) out).withChildren(replaced.children()));
} else {
// should not reach here
Preconditions.checkArgument(false, "output abnormal: " + repeat);
}
} else {
newOutputs.add(out);
}
}
repeat = repeat.withGroupSetsAndOutput(newGroupingSets, newOutputs);
return repeat;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
*
* try to reduce shuffle cost of topN operator
*
* topn(orderKey=[a])
* --> project(a+1 as x, a+2 as y, a)
* --> any(output(a))
* =>
* project(a+1 as x, a+2 as y, a)
* --> topn(orderKey=[a])
* --> any(output(a))
*
*/
public class AdjustTopNProject extends OneRewriteRuleFactory {
public static final Logger LOG = LogManager.getLogger(AdjustTopNProject.class);
@Override
public Rule build() {
return logicalTopN(logicalProject(logicalAggregate()))
.then(topN -> adjust(topN)).toRule(RuleType.ADJUST_TOPN_PROJECT);
}

private Plan adjust(LogicalTopN<? extends Plan> topN) {
LogicalProject<Plan> project = (LogicalProject<Plan>) topN.child();
Set<Slot> projectInputSlots = project.getInputSlots();
Map<SlotReference, SlotReference> keyAsKey = new HashMap<>();
for (NamedExpression proj : project.getProjects()) {
if (proj instanceof Alias && ((Alias) proj).child(0) instanceof SlotReference) {
keyAsKey.put((SlotReference) ((Alias) proj).toSlot(), (SlotReference) ((Alias) proj).child());
}
}
boolean match = true;
List<OrderKey> newOrderKeys = new ArrayList<>();
for (OrderKey orderKey : topN.getOrderKeys()) {
Expression orderExpr = orderKey.getExpr();
if (orderExpr instanceof SlotReference) {
if (projectInputSlots.contains(orderExpr)) {
newOrderKeys.add(orderKey);
} else if (keyAsKey.containsKey(orderExpr)) {
newOrderKeys.add(orderKey.withExpression(keyAsKey.get(orderExpr)));
} else {
match = false;
break;
}
} else {
match = false;
break;
}
}
if (match) {
if (project.getProjects().size() >= project.getInputSlots().size()) {
LOG.info("$$$$ before: project.getProjects() = " + project.getProjects());
LOG.info("$$$$ before: project.getInputSlots() = " + project.getInputSlots());
LOG.info("$$$$ before: " + topN.treeString());
topN = topN.withChildren(project.children()).withOrderKeys(newOrderKeys);
project = (LogicalProject<Plan>) project.withChildren(topN);
LOG.info("$$$$ after:" + project.treeString());
return project;
}
}
return topN;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeString;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.Lists;

import java.util.List;

/**
* in project:
* decode_as_varchar(encode_as_xxx(v)) => v
*/
public class DecoupleEncodeDecode extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalProject().then(this::rewrite)
.toRule(RuleType.DECOUPLE_DECODE_ENCODE_SLOT);
}

private LogicalProject<?> rewrite(LogicalProject<?> project) {
List<NamedExpression> newProjections = Lists.newArrayList();
boolean hasNewProjections = false;
for (NamedExpression e : project.getProjects()) {
boolean changed = false;
if (e instanceof Alias) {
Alias alias = (Alias) e;
Expression body = alias.child();
if (body instanceof DecodeAsVarchar && body.child(0) instanceof EncodeString) {
Expression encodeBody = body.child(0).child(0);
newProjections.add((NamedExpression) alias.withChildren(encodeBody));
changed = true;
} else if (body instanceof EncodeString && body.child(0) instanceof DecodeAsVarchar) {
Expression decodeBody = body.child(0).child(0);
newProjections.add((NamedExpression) alias.withChildren(decodeBody));
changed = true;
}
}
if (!changed) {
newProjections.add(e);
hasNewProjections = true;
}
}
if (hasNewProjections) {
project = project.withProjects(newProjections);
}
return project;
}

}
Loading
Loading