Skip to content

Commit

Permalink
push down encode slot
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Nov 29, 2024
1 parent 2a5e992 commit 44c7cc3
Show file tree
Hide file tree
Showing 15 changed files with 245 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOnPkFk;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownEncodeSlot;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.PushDownLimit;
import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin;
Expand Down Expand Up @@ -248,7 +249,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(
new NormalizeAggregate(),
new CountLiteralRewrite(),
new NormalizeSort()
new NormalizeSort(),
new PushDownEncodeSlot()
),
topic("Window analysis",
topDown(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,14 @@ public class LogicalProperties {
protected final Supplier<DataTrait> dataTraitSupplier;
private Integer hashCode = null;

public LogicalProperties(Supplier<List<Slot>> outputSupplier,
Supplier<DataTrait> dataTraitSupplier) {
this(outputSupplier, dataTraitSupplier, ImmutableList::of);
}

/**
* constructor of LogicalProperties.
*
* @param outputSupplier provide the output. Supplier can lazy compute output without
* throw exception for which children have UnboundRelation
*/
public LogicalProperties(Supplier<List<Slot>> outputSupplier,
Supplier<DataTrait> dataTraitSupplier,
Supplier<List<Slot>> nonUserVisibleOutputSupplier) {
Supplier<DataTrait> dataTraitSupplier) {
this.outputSupplier = Suppliers.memoize(
Objects.requireNonNull(outputSupplier, "outputSupplier can not be null")
);
Expand Down Expand Up @@ -100,6 +94,10 @@ public LogicalProperties(Supplier<List<Slot>> outputSupplier,
);
}

public LogicalProperties withOutputSupplier(Supplier<List<Slot>> outputSupplier) {
return new LogicalProperties(outputSupplier, dataTraitSupplier);
}

public List<Slot> getOutput() {
return outputSupplier.get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public enum RuleType {
// rewrite rules
COMPRESSED_MATERIALIZE_AGG(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_SORT(RuleTypeClass.REWRITE),
PUSH_DOWN_ENCODE_SLOT(RuleTypeClass.REWRITE),
NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
NORMALIZE_SORT(RuleTypeClass.REWRITE),
NORMALIZE_REPEAT(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -586,10 +586,17 @@ boolean comparePlan(Plan plan1, Plan plan2) {
isEqual = false;
}
for (int i = 0; isEqual && i < plan2.getOutput().size(); i++) {
NamedExpression expr = ((LogicalProject<?>) plan1).getProjects().get(i);
NamedExpression replacedExpr = (NamedExpression)
expr.rewriteUp(e -> plan1ToPlan2.getOrDefault(e, e));
if (!replacedExpr.equals(((LogicalProject<?>) plan2).getProjects().get(i))) {
Expression expr1 = ((LogicalProject<?>) plan1).getProjects().get(i);
Expression expr2 = ((LogicalProject<?>) plan2).getProjects().get(i);
if (expr1 instanceof Alias) {
if (!(expr2 instanceof Alias)) {
return false;
}
expr1 = expr1.child(0);
expr2 = expr2.child(0);
}
Expression replacedExpr = expr1.rewriteUp(e -> plan1ToPlan2.getOrDefault(e, e));
if (!replacedExpr.equals(expr2)) {
isEqual = false;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeStr;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
* push down encode_as_int(slot) down
* example:
* group by x
* -->project(encode_as_int(A) as x)
* -->Any(A)
* -->project(A)
* --> scan
* =>
* group by x
* -->project(x)
* -->Any(x)
* --> project(encode_as_int(A) as x)
* -->scan
* Note:
* do not push down encode if encode.child() is not slot,
* example
* group by encode_as_int(A + B)
* --> any(A, B)
*/
public class PushDownEncodeSlot extends OneRewriteRuleFactory {

@Override
public Rule build() {
return logicalProject()
.when(this::containsEncode)
.when(project -> !(project.child() instanceof LogicalCatalogRelation))
.then(project -> pushDownEncodeSlot(project))
.toRule(RuleType.PUSH_DOWN_ENCODE_SLOT);
}

private boolean containsEncode(LogicalProject<? extends Plan> project) {
return project.getProjects().stream()
.anyMatch(e -> e instanceof Alias && containsEncode(e.child(0)));
}

private boolean containsEncode(Expression expr) {
return expr instanceof EncodeStr && expr.child(0) instanceof SlotReference;
}

private List<Alias> collectEncodeAlias(LogicalProject<? extends Plan> project) {
List<Alias> encodeAlias = new ArrayList<>();
project.getProjects().forEach(e -> {
if (e instanceof Alias && e.child(0) instanceof EncodeStr) {
encodeAlias.add((Alias) e);
}
});
return encodeAlias;
}

private LogicalProject<? extends Plan> pushDownEncodeSlot(LogicalProject<? extends Plan> project) {
List<Alias> encodeAlias = collectEncodeAlias(project);
LogicalProject<? extends Plan> result = (LogicalProject<? extends Plan>)
project.accept(EncodeSlotPushDownVisitor.visitor, encodeAlias);
return result;
}

/**
* push down encode slot
*/
public static class EncodeSlotPushDownVisitor extends PlanVisitor<Plan, List<Alias>> {
public static EncodeSlotPushDownVisitor visitor = new EncodeSlotPushDownVisitor();

@Override
public Plan visit(Plan plan, List<Alias> encodeAlias) {
// replaceMap:
// encode_as_int(slot1) -> slot2
// slot1 -> slot2
Map<Expression, Slot> replaceMap = new HashMap<>();
List<Set<Slot>> byPassSlots = plan.children().stream()
.map(this::getByPassSlot)
.collect(Collectors.toList());
Map<Plan, List<Alias>> toBePushed = new HashMap<>();
for (Alias alias : encodeAlias) {
EncodeStr encode = (EncodeStr) alias.child();
Expression strExpr = encode.child();
if (strExpr instanceof SlotReference) {
for (int i = 0; i < byPassSlots.size(); i++) {
if (byPassSlots.get(i).contains(strExpr)) {
toBePushed.putIfAbsent(plan.child(i), new ArrayList<>());
toBePushed.get(plan.child(i)).add(alias);
replaceMap.put(alias, alias.toSlot());
replaceMap.put(alias.child().child(0), alias.toSlot());
break;
}
}
}
}
// rewrite plan according to encode expression
// for example: project(encode_as_int(slot1) as slot2)
// 1. rewrite project's expressions: project(slot2),
// 2. push encode_as_int(slot1) as slot2 down to project.child()
// rewrite expressions
plan = plan.replaceExpressions(replaceMap);
// rewrite children
ImmutableList.Builder<Plan> newChildren = ImmutableList.builderWithExpectedSize(plan.arity());
boolean hasNewChildren = false;
for (Plan child : plan.children()) {
Plan newChild;
if (toBePushed.containsKey(child)) {
if (child instanceof LogicalProject && child.child(0) instanceof LogicalCatalogRelation) {
LogicalProject project = (LogicalProject) child;
List<NamedExpression> projections = new ArrayList<>();
projections.addAll(toBePushed.get(project));
projections.addAll(project.getProjects());
newChild = project.withProjects(projections);
} else if (child instanceof LogicalCatalogRelation) {
List<NamedExpression> newProjections = new ArrayList<>();
newProjections.addAll(child.getOutput());
newProjections.addAll(toBePushed.get(child));
newChild = new LogicalProject<>(newProjections, child);
hasNewChildren = true;
} else {
newChild = child.accept(this, toBePushed.get(child));
}
if (!hasNewChildren && newChild != child) {
hasNewChildren = true;
}
} else {
newChild = child;
}
newChildren.add(newChild);
}

if (hasNewChildren) {
plan = plan.withChildren(newChildren.build());
}
return plan;
}

private Set<Slot> getByPassSlot(Plan plan) {
Set<Slot> outputSlots = Sets.newHashSet(plan.getOutput());
outputSlots.removeAll(plan.getInputSlots());
return outputSlots;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeStrToInteger;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeStr;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
Expand Down Expand Up @@ -149,7 +149,7 @@ private static Set<Expression> eliminateDecodeAndEncode(Set<Expression> expressi
}

private static Expression eliminateDecodeAndEncode(Expression expression) {
if (expression instanceof DecodeAsVarchar && expression.child(0) instanceof EncodeStrToInteger) {
if (expression instanceof DecodeAsVarchar && expression.child(0) instanceof EncodeStr) {
return expression.child(0).child(0);
}
boolean hasNewChild = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
/**
* ScalarFunction 'EncodeAsBigInt'.
*/
public class EncodeAsBigInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {
public class EncodeAsBigInt extends EncodeStr
implements ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
/**
* ScalarFunction 'EncodeAsInt'.
*/
public class EncodeAsInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {
public class EncodeAsInt extends EncodeStr
implements ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(IntegerType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
/**
* ScalarFunction 'EncodeAsLargeInt'.
*/
public class EncodeAsLargeInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {
public class EncodeAsLargeInt extends EncodeStr
implements ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(LargeIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
/**
* ScalarFunction 'CompressAsSmallInt'.
*/
public class EncodeAsSmallInt extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable, EncodeStrToInteger {
public class EncodeAsSmallInt extends EncodeStr
implements ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(SmallIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,17 @@

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;

/**
* Encode_as_XXXInt
*/
public interface EncodeStrToInteger {
public abstract class EncodeStr extends ScalarFunction implements UnaryExpression {
/**
* constructor with 1 argument.
*/
public EncodeStr(String name, Expression arg0) {
super("encode_as_int", arg0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

Expand Down Expand Up @@ -232,4 +233,8 @@ default String getGroupIdAsString() {
default String getGroupIdWithPrefix() {
return "@" + getGroupIdAsString();
}

default Plan replaceExpressions(Map<? extends Expression, ? extends Expression> replaceMap) {
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ public OlapTable getTable() {
@Override
public String toString() {
return Utils.toSqlString("LogicalOlapScan",
"qualified", qualifiedName(),
"name", table.getName(),
"indexName", getSelectedMaterializedIndexName().orElse("<index_not_selected>"),
"selectedIndexId", selectedIndexId,
"preAgg", preAggStatus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,4 +278,22 @@ public void computeFd(DataTrait.Builder builder) {
}
}
}

@Override
public Plan replaceExpressions(Map<? extends Expression, ? extends Expression> replaceMap) {
List<NamedExpression> newProjections = new ArrayList<>();
boolean changed = false;
for (NamedExpression expr : getProjects()) {
if (replaceMap.containsKey(expr) && replaceMap.get(expr) instanceof NamedExpression) {
newProjections.add((NamedExpression) replaceMap.get(expr));
changed = true;
} else {
newProjections.add(expr);
}
}
if (changed) {
return this.withProjects(newProjections);
}
return this;
}
}
Loading

0 comments on commit 44c7cc3

Please sign in to comment.