Skip to content

Commit

Permalink
decode-encode
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Dec 3, 2024
1 parent 7c7375b commit 06337a7
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite;
import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite;
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
import org.apache.doris.nereids.rules.rewrite.DecoupleEncodeDecode;
import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult;
import org.apache.doris.nereids.rules.rewrite.EliminateAggCaseWhen;
import org.apache.doris.nereids.rules.rewrite.EliminateAggregate;
Expand Down Expand Up @@ -251,7 +252,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
new CountLiteralRewrite(),
new NormalizeSort(),
new MergeProjects(),
new PushDownEncodeSlot()
new PushDownEncodeSlot(),
new DecoupleEncodeDecode()
),
topic("Window analysis",
topDown(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ public enum RuleType {
COMPRESSED_MATERIALIZE_AGG(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_SORT(RuleTypeClass.REWRITE),
PUSH_DOWN_ENCODE_SLOT(RuleTypeClass.REWRITE),
DECOUPLE_DECODE_ENCODE_SLOT(RuleTypeClass.REWRITE),
NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
NORMALIZE_SORT(RuleTypeClass.REWRITE),
NORMALIZE_REPEAT(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeString;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.Lists;

import java.util.List;

/**
 * Removes encode/decode pairs that cancel each other out inside a project list.
 *
 * <p>For every aliased projection whose body is one of
 * <pre>
 *   decode_as_varchar(encode_as_xxx(v))  =>  v
 *   encode_as_xxx(decode_as_varchar(v))  =>  v
 * </pre>
 * the alias is kept and only its child is replaced by {@code v}. All other projections
 * are passed through untouched.
 */
public class DecoupleEncodeDecode extends OneRewriteRuleFactory {
    /**
     * Builds a rule that applies {@link #rewrite(LogicalProject)} to every logical project.
     */
    @Override
    public Rule build() {
        return logicalProject().then(this::rewrite)
                .toRule(RuleType.DECOUPLE_DECODE_ENCODE_SLOT);
    }

    /**
     * Strips redundant encode/decode pairs from the projections of {@code project}.
     *
     * @param project the project to inspect
     * @return a project built with the simplified projections when at least one projection
     *         was rewritten; otherwise the very same {@code project} instance
     */
    private LogicalProject<?> rewrite(LogicalProject<?> project) {
        List<NamedExpression> newProjections = Lists.newArrayList();
        // Must become true only when a projection was actually rewritten; otherwise the
        // input project has to be returned as-is so that an unchanged plan is not rebuilt.
        boolean hasNewProjections = false;
        for (NamedExpression e : project.getProjects()) {
            Expression unwrapped = null;
            Alias alias = null;
            if (e instanceof Alias) {
                alias = (Alias) e;
                Expression body = alias.child();
                boolean decodeOfEncode = body instanceof DecodeAsVarchar
                        && body.child(0) instanceof EncodeString;
                boolean encodeOfDecode = body instanceof EncodeString
                        && body.child(0) instanceof DecodeAsVarchar;
                if (decodeOfEncode || encodeOfDecode) {
                    // skip both wrapping functions and keep only the innermost argument
                    unwrapped = body.child(0).child(0);
                }
            }
            if (unwrapped != null) {
                newProjections.add((NamedExpression) alias.withChildren(unwrapped));
                hasNewProjections = true;
            } else {
                newProjections.add(e);
            }
        }
        if (hasNewProjections) {
            project = project.withProjects(newProjections);
        }
        return project;
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ private LogicalProject<? extends Plan> rewriteRootProject(LogicalProject<? exten
private LogicalProject<? extends Plan> pushDownEncodeSlot(LogicalProject<? extends Plan> project) {
List<Alias> encodeAliases = collectEncodeAliases(project);
PushDownContext ctx = new PushDownContext(project, encodeAliases);
ctx.prepare();
if (ctx.notPushed.size() == encodeAliases.size()) {
return project;
}
Expand Down Expand Up @@ -155,32 +156,31 @@ public static class PushDownContext {
public PushDownContext(Plan plan, List<Alias> encodeAliases) {
this.plan = plan;
this.encodeAliases = encodeAliases;
prepare();
}

/**
 * Tells whether both operands of {@code compare} qualify for encoding.
 *
 * <p>Both operands must be {@link SlotReference}s whose data type is a {@link CharacterType}
 * with {@code getLen() < 15}.
 *
 * @param compare the comparison predicate to inspect
 * @return true only if the left and the right operand both satisfy the conditions above
 */
public static boolean canBothSidesEncode(ComparisonPredicate compare) {
    return compare.left() instanceof SlotReference
            && compare.right() instanceof SlotReference
            && compare.left().getDataType() instanceof CharacterType
            // the right side needs its own type check: casting it blindly would throw
            // ClassCastException when only the left side is a character type
            && compare.right().getDataType() instanceof CharacterType
            && ((CharacterType) compare.left().getDataType()).getLen() < 15
            && ((CharacterType) compare.right().getDataType()).getLen() < 15;
}

/**
 * Gathers, from the hash and other conjuncts of {@code join}, the slot pairs of every
 * comparison predicate accepted by {@code canBothSidesEncode}.
 *
 * @param join the join whose conditions are scanned
 * @return a bi-map from the left slot to the right slot of each accepted comparison
 */
private BiMap<SlotReference, SlotReference> getCompareSlotsFromJoinCondition(LogicalJoin<?, ?> join) {
    List<Expression> allConjuncts = new ArrayList<>(join.getHashJoinConjuncts());
    allConjuncts.addAll(join.getOtherJoinConjuncts());

    BiMap<SlotReference, SlotReference> slotPairs = HashBiMap.create();
    for (Expression conjunct : allConjuncts) {
        if (!(conjunct instanceof ComparisonPredicate)) {
            continue;
        }
        ComparisonPredicate predicate = (ComparisonPredicate) conjunct;
        if (!canBothSidesEncode(predicate)) {
            continue;
        }
        SlotReference leftSlot = (SlotReference) predicate.left();
        SlotReference rightSlot = (SlotReference) predicate.right();
        slotPairs.put(leftSlot, rightSlot);
    }
    return slotPairs;
}
// public static boolean canBothSidesEncode(ComparisonPredicate compare) {
// return compare.left().getDataType() instanceof CharacterType
// && ((CharacterType) compare.left().getDataType()).getLen() < 15
// && ((CharacterType) compare.right().getDataType()).getLen() < 15
// && compare.left() instanceof SlotReference && compare.right() instanceof SlotReference;
// }
//
// private BiMap<SlotReference, SlotReference> getCompareSlotsFromJoinCondition(LogicalJoin<?, ?> join) {
// BiMap<SlotReference, SlotReference> compareSlots = HashBiMap.create();
// List<Expression> conditions = new ArrayList<>();
// conditions.addAll(join.getHashJoinConjuncts());
// conditions.addAll(join.getOtherJoinConjuncts());
// for (Expression e : conditions) {
// if (e instanceof ComparisonPredicate) {
// ComparisonPredicate compare = (ComparisonPredicate) e;
// if (canBothSidesEncode(compare)) {
// compareSlots.put((SlotReference) compare.left(), (SlotReference) compare.right());
// }
// }
// }
//
// return compareSlots;
// }

// init replaceMap/toBePushed/notPushed
private void prepare() {
Expand All @@ -190,16 +190,17 @@ private void prepare() {
Plan child = plan.children().get(i);
if (child instanceof LogicalJoin) {
LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) child;
BiMap<SlotReference, SlotReference> compareSlots = getCompareSlotsFromJoinCondition(join);
BiMap<SlotReference, SlotReference> compareSlots =
EncodeSlotPushDownVisitor.getEncodeCandidateSlotsFromJoinCondition(join);
childrenPassThroughSlots.get(i).addAll(compareSlots.keySet());
}
}

if (plan instanceof LogicalJoin) {
LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) plan;
BiMap<SlotReference, SlotReference> compareSlots = getCompareSlotsFromJoinCondition(join);
expandEncodeAliasForJoin(compareSlots);
}
//
// if (plan instanceof LogicalJoin) {
// LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) plan;
// BiMap<SlotReference, SlotReference> compareSlots = getCompareSlotsFromJoinCondition(join);
// expandEncodeAliasForJoin(compareSlots);
// }
for (Alias alias : encodeAliases) {
EncodeString encode = (EncodeString) alias.child();
Expression strExpr = encode.child();
Expand All @@ -223,17 +224,10 @@ private void prepare() {
}
}

/**
 * Empties the collections this context keeps between runs:
 * notPushed, replaceMap, toBePushedToChild and toBePused.
 */
public void reset() {
    notPushed.clear();
    replaceMap.clear();
    toBePushedToChild.clear();
    toBePused.clear();
}
/**
* expandEncodeAliasForJoin
*/
public List<SlotReference> expandEncodeAliasForJoin(BiMap<SlotReference, SlotReference> equalSlots) {
List<SlotReference> expandedOtherHand = new ArrayList<>();
public void expandEncodeAliasForJoin(BiMap<SlotReference, SlotReference> equalSlots) {
List<Alias> expanded = new ArrayList<>();
for (Alias alias : encodeAliases) {
if (alias.child().child(0) instanceof SlotReference) {
Expand All @@ -244,13 +238,11 @@ public List<SlotReference> expandEncodeAliasForJoin(BiMap<SlotReference, SlotRef
Alias encodeOtherHandAlias = new Alias(encodeOtherHand, encodeOtherHand.toSql());
if (!encodeAliases.contains(encodeOtherHandAlias)) {
expanded.add(encodeOtherHandAlias);
expandedOtherHand.add(equalSlots.get(slot));
}
}
}
}
encodeAliases.addAll(expanded);
return expandedOtherHand;
}

// the child of alias is a slot reference. for example: slotA as B
Expand Down Expand Up @@ -317,6 +309,7 @@ private Plan projectNotPushedAlias(Plan plan, List<Alias> notPushedAlias) {

@Override
public Plan visit(Plan plan, PushDownContext ctx) {
ctx.prepare();
plan = visitChildren(plan, ctx);
plan = projectNotPushedAlias(plan, ctx.notPushed);
return plan;
Expand All @@ -334,7 +327,6 @@ private Optional<Alias> findEncodeAliasByEncodeSlot(SlotReference slot, List<Ali
@Override
public LogicalProject<? extends Plan> visitLogicalProject(
LogicalProject<? extends Plan> project, PushDownContext ctx) {
ctx.reset();
/*
* case 1
* push down "encode(v1) as v2
Expand Down Expand Up @@ -424,17 +416,76 @@ public LogicalProject<? extends Plan> visitLogicalProject(
return project;
}

/**
 * Tells whether both operands of {@code compare} qualify for encoding.
 *
 * <p>Both operands must be {@link SlotReference}s whose data type is a {@link CharacterType}
 * with {@code getLen() < 15}.
 *
 * @param compare the comparison predicate to inspect
 * @return true only if the left and the right operand both satisfy the conditions above
 */
private static boolean canBothSidesEncode(ComparisonPredicate compare) {
    return compare.left() instanceof SlotReference
            && compare.right() instanceof SlotReference
            && compare.left().getDataType() instanceof CharacterType
            // the right side needs its own type check: casting it blindly would throw
            // ClassCastException when only the left side is a character type
            && compare.right().getDataType() instanceof CharacterType
            && ((CharacterType) compare.left().getDataType()).getLen() < 15
            && ((CharacterType) compare.right().getDataType()).getLen() < 15;
}

/**
 * Collects the slot pairs of a join condition that are candidates for encoding.
 *
 * <p>A pair {@code left -> right} is recorded for every conjunct that is a comparison
 * accepted by {@code canBothSidesEncode}. Every slot that is an input of any other
 * conjunct is then excluded, no matter whether it was recorded as the left or as the
 * right side of a pair.
 *
 * @param join the join whose hash conjuncts and other conjuncts are inspected
 * @return a bi-map from the left slot to the right slot of each remaining candidate pair
 */
public static BiMap<SlotReference, SlotReference> getEncodeCandidateSlotsFromJoinCondition(
        LogicalJoin<?, ?> join) {
    // T1 join T2 on v1=v2 => v1/v2 can be encoded
    // T1 join T2 on v1=v2 and fun(v1) => v1/v2 can not be encoded
    // T1 join T2 on v1=v2 and fun(v2) => v1/v2 can not be encoded
    BiMap<SlotReference, SlotReference> compareSlots = HashBiMap.create();
    List<Expression> conditions = new ArrayList<>();
    conditions.addAll(join.getHashJoinConjuncts());
    conditions.addAll(join.getOtherJoinConjuncts());
    Set<Slot> shouldNotPushSlots = Sets.newHashSet();
    for (Expression e : conditions) {
        boolean canPush = false;
        if (e instanceof ComparisonPredicate) {
            ComparisonPredicate compare = (ComparisonPredicate) e;
            if (canBothSidesEncode(compare)) {
                // NOTE(review): BiMap.put rejects a value that is already bound to another
                // key (e.g. "a = c and b = c") - confirm such conditions cannot reach here.
                compareSlots.put((SlotReference) compare.left(), (SlotReference) compare.right());
                canPush = true;
            }
        }
        if (!canPush) {
            shouldNotPushSlots.addAll(e.getInputSlots());
        }
    }
    // A blocked slot disqualifies its pair whichever side it sits on, so it has to be
    // removed both as a key and, through the inverse view, as a value.
    BiMap<SlotReference, SlotReference> rightToLeft = compareSlots.inverse();
    for (Slot notPushSlot : shouldNotPushSlots) {
        if (compareSlots.isEmpty()) {
            break;
        }
        // Map.remove accepts any Object, so no (possibly failing) cast to SlotReference is needed
        compareSlots.remove(notPushSlot);
        rightToLeft.remove(notPushSlot);
    }
    return compareSlots;
}

@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, PushDownContext ctx) {
join = (LogicalJoin) visitChildren(join, ctx);
// TODO: rewrite join condition
List<Alias> pushLeft = new ArrayList<>();
List<Alias> pushRight = new ArrayList<>();
BiMap<SlotReference, SlotReference> encodeCandidateSlots = getEncodeCandidateSlotsFromJoinCondition(join);
Set<Slot> leftOutputSlots = join.left().getOutputSet();
Map<SlotReference, SlotReference> replaceMap = new HashMap<>();

for (Alias encodeAlias : ctx.encodeAliases) {
SlotReference encodeSlot = (SlotReference) encodeAlias.child().child(0);
if (encodeCandidateSlots.containsKey(encodeSlot)) {
SlotReference otherHand = encodeCandidateSlots.get(encodeSlot);
Alias otherHandAlias = new Alias(encodeAlias.child().withChildren(otherHand));
if (leftOutputSlots.contains(encodeSlot)) {
pushLeft.add(encodeAlias);
pushRight.add(otherHandAlias);
} else {
pushRight.add(encodeAlias);
pushLeft.add(otherHandAlias);
}
replaceMap.put(encodeSlot, (SlotReference) encodeAlias.toSlot());
replaceMap.put(otherHand, (SlotReference) otherHandAlias.toSlot());
}
}
List<Expression> newConjuncts = Lists.newArrayListWithCapacity(join.getOtherJoinConjuncts().size());
boolean changed = false;
for (Expression e : join.getOtherJoinConjuncts()) {
ComparisonPredicate compare = (ComparisonPredicate) e;
if (PushDownContext.canBothSidesEncode(compare)) {
SlotReference newLeft = ctx.replaceMap.get(compare.left());
SlotReference newRight = ctx.replaceMap.get(compare.right());
if (canBothSidesEncode(compare)) {
SlotReference newLeft = replaceMap.get(compare.left());
SlotReference newRight = replaceMap.get(compare.right());
Preconditions.checkArgument(newLeft != null,
"PushDownEncodeSlot replaceMap is not valid, " + compare.left() + " is not found");
Preconditions.checkArgument(newRight != null,
Expand All @@ -447,10 +498,22 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, P
if (changed) {
join = join.withJoinConjuncts(join.getHashJoinConjuncts(), newConjuncts, join.getJoinReorderContext());
}
return projectNotPushedAlias(join, ctx.notPushed);
Plan newLeft;
if (pushLeft.isEmpty()) {
newLeft = join.left();
} else {
newLeft = join.left().accept(this, new PushDownContext(join.left(), pushLeft));
}
Plan newRight;
if (pushRight.isEmpty()) {
newRight = join.right();
} else {
newRight = join.right().accept(this, new PushDownContext(join.right(), pushRight));
}
join = (LogicalJoin<? extends Plan, ? extends Plan>) join.withChildren(newLeft, newRight);
return join;
}


@Override
public Plan visitLogicalSetOperation(LogicalSetOperation op, PushDownContext ctx) {
// push down "encode(v) as x" through
Expand Down
Loading

0 comments on commit 06337a7

Please sign in to comment.