Skip to content

Commit

Permalink
join not finish
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Dec 1, 2024
1 parent 37f3182 commit 00508e8
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,37 @@

package org.apache.doris.nereids.rules.rewrite;

import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import org.apache.commons.logging.Log;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeString;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.hadoop.util.Lists;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;


import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -63,7 +74,7 @@
* --> any(A, B)
*/
public class PushDownEncodeSlot extends OneRewriteRuleFactory {

private static final Logger LOG = LogManager.getLogger(PushDownEncodeSlot.class);
@Override
public Rule build() {
return logicalProject()
Expand All @@ -82,18 +93,23 @@ private boolean containsEncode(Expression expr) {
return expr instanceof EncodeString && expr.child(0) instanceof SlotReference;
}

private List<Alias> collectEncodeAlias(LogicalProject<? extends Plan> project) {
List<Alias> encodeAlias = new ArrayList<>();
private List<Alias> collectEncodeAliases(LogicalProject<? extends Plan> project) {
List<Alias> encodeAliases = new ArrayList<>();
project.getProjects().forEach(e -> {
if (e instanceof Alias && e.child(0) instanceof EncodeString) {
encodeAlias.add((Alias) e);
if (e instanceof Alias
&& e.child(0) instanceof EncodeString
&& e.child(0).child(0) instanceof SlotReference) {
encodeAliases.add((Alias) e);
}
});
return encodeAlias;
return encodeAliases;
}

private LogicalProject<? extends Plan> rewriteRootProject(LogicalProject<? extends Plan> project,
private LogicalProject<? extends Plan> rewriteRootProject(LogicalProject<? extends Plan> project,
List<Alias> encodeAlias) {
if (encodeAlias.isEmpty()) {
return project;
}
List<NamedExpression> projections = project.getProjects().stream().map(
e -> encodeAlias.contains(e) ? e.toSlot() : e)
.collect(Collectors.toList());
Expand All @@ -102,93 +118,149 @@ private LogicalProject<? extends Plan> rewriteRootProject(LogicalProject<? exte
}

private LogicalProject<? extends Plan> pushDownEncodeSlot(LogicalProject<? extends Plan> project) {
List<Alias> encodeAlias = collectEncodeAlias(project);
project = rewriteRootProject(project, encodeAlias);
List<Alias> encodeAliases = collectEncodeAliases(project);
PushDownContext ctx = new PushDownContext(project, encodeAliases);
if (ctx.notPushed.size() == encodeAliases.size()) {
return project;
}
Plan child = project.child();
Plan newChild = child.accept(EncodeSlotPushDownVisitor.INSTANCE, encodeAlias);
if (child.equals(newChild)) {
PushDownContext childContext = new PushDownContext(child, ctx.toBePushedToChild.get(child));
Plan newChild = child.accept(EncodeSlotPushDownVisitor.INSTANCE, childContext);
List<Alias> pushed = ctx.toBePused;
if (!child.equals(newChild)) {
if (newChild instanceof LogicalProject) {
pushed.removeAll(childContext.notPushed);
newChild = ((LogicalProject<?>) newChild).child();
}
project = (LogicalProject<? extends Plan>) project.withChildren(newChild);
project = rewriteRootProject(project, pushed);
}
return project;
}

/**
* push down encode slot
*/
public static class EncodeSlotPushDownVisitor extends PlanVisitor<Plan, List<Alias>> {
public static EncodeSlotPushDownVisitor INSTANCE = new EncodeSlotPushDownVisitor();
public static class PushDownContext {
public Plan plan;

public LogicalProject<Plan> replaceProjectsEncodeSlot(LogicalProject<Plan> project,
Map<? extends Expression, ? extends Expression> replaceMap) {
List<NamedExpression> newProjections = new ArrayList<>();
boolean changed = false;
for (NamedExpression expr : project.getProjects()) {
if (replaceMap.containsKey(expr) && replaceMap.get(expr) instanceof NamedExpression) {
newProjections.add((NamedExpression) replaceMap.get(expr));
changed = true;
} else {
newProjections.add(expr);
public List<Alias> encodeAliases;
// encode_as_int(slot1) as slot2
// replaceMap:
// slot1 -> slot2
Map<Expression, SlotReference> replaceMap = new HashMap<>();
// child plan -> aliases in encodeAliases which can be pushed down to child plan
Map<Plan, List<Alias>> toBePushedToChild = new HashMap<>();
List<Alias> toBePused = new ArrayList<>();
// the aliases that cannot be pushed down to any child plan
// for example:
// encode(A+B) as x, where plan is a join, and A, B comes from join's left and right child respectively
List<Alias> notPushed = new ArrayList<>();
public PushDownContext(Plan plan, List<Alias> encodeAliases) {
this.plan = plan;
this.encodeAliases = encodeAliases;
prepare();
}

public static boolean canBothSidesEncode(EqualPredicate equal) {
return equal.left().getDataType() instanceof CharacterType
&& ((CharacterType) equal.left().getDataType()).getLen() < 15
&& ((CharacterType) equal.right().getDataType()).getLen() < 15
&& equal.left() instanceof SlotReference && equal.right() instanceof SlotReference;
}

// init replaceMap/toBePushed/notPushed
private void prepare() {
List<Set<Slot>> childrenPassThroughSlots =
plan.children().stream().map(n -> getPassThroughSlots(n)).collect(Collectors.toList());
if (plan instanceof LogicalJoin) {
LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) plan;
BiMap<SlotReference, SlotReference> equalSlots = HashBiMap.create();
for (Expression e : join.getHashJoinConjuncts()) {
EqualPredicate equal = (EqualPredicate) e;
if (canBothSidesEncode(equal)) {
equalSlots.put((SlotReference) equal.left(), (SlotReference) equal.right());
}
}
List<SlotReference> expandedOtherHands = expandEncodeAliasForJoin(equalSlots);

for (SlotReference otherHand : expandedOtherHands) {
if (join.left().getOutputSet().contains(otherHand)) {
childrenPassThroughSlots.get(0).add(otherHand);
} else {
childrenPassThroughSlots.get(1).add(otherHand);
}
}
}
if (changed) {
return project.withProjects(newProjections);
for (Alias alias : encodeAliases) {
EncodeString encode = (EncodeString) alias.child();
Expression strExpr = encode.child();
boolean pushed = false;
Preconditions.checkArgument(strExpr instanceof SlotReference,
"expect encode_as_xxx(slot), but " + alias);

for (int i = 0; i < childrenPassThroughSlots.size(); i++) {
if (childrenPassThroughSlots.get(i).contains(strExpr)) {
toBePushedToChild.putIfAbsent(plan.child(i), new ArrayList<>());
toBePushedToChild.get(plan.child(i)).add(alias);
toBePused.add(alias);
replaceMap.put(alias.child().child(0), (SlotReference) alias.toSlot());
pushed = true;
break;
}
}
if (!pushed) {
notPushed.add(alias);
}
}
return project;
}

@Override
public Plan visit(Plan plan, List<Alias> encodeAlias) {
// replaceMap:
// encode_as_int(slot1) -> slot2
// slot1 -> slot2
Map<Expression, Slot> replaceMap = new HashMap<>();
List<Set<Slot>> byPassSlots = plan.children().stream()
.map(this::getPassThroughSlots)
.collect(Collectors.toList());
Map<Plan, List<Alias>> toBePushed = new HashMap<>();
for (Alias alias : encodeAlias) {
EncodeString encode = (EncodeString) alias.child();
Expression strExpr = encode.child();
if (strExpr instanceof SlotReference) {
for (int i = 0; i < byPassSlots.size(); i++) {
if (byPassSlots.get(i).contains(strExpr)) {
toBePushed.putIfAbsent(plan.child(i), new ArrayList<>());
toBePushed.get(plan.child(i)).add(alias);
replaceMap.put(alias, alias.toSlot());
replaceMap.put(alias.child().child(0), alias.toSlot());
break;
public List<SlotReference> expandEncodeAliasForJoin(BiMap<SlotReference, SlotReference> equalSlots) {
List<SlotReference> expandedOtherHand = new ArrayList<>();
List<Alias> expanded = new ArrayList<>();
for (Alias alias : encodeAliases) {
if (alias.child().child(0) instanceof SlotReference) {
SlotReference slot = (SlotReference) alias.child().child(0);
if (equalSlots.containsKey(slot)) {
Alias encodeOtherHand = (Alias) alias.withChildren(alias.child().withChildren(equalSlots.get(slot)));
if (!encodeAliases.contains(encodeOtherHand)) {
expanded.add(encodeOtherHand);
expandedOtherHand.add(equalSlots.get(slot));
}
}
}
}
// rewrite plan according to encode expression
// for example: project(encode_as_int(slot1) as slot2)
// 1. rewrite project's expressions: project(slot2),
// 2. push "encode_as_int(slot1) as slot2" down to project.child()
// rewrite expressions
if (plan instanceof LogicalProject) {
plan = replaceProjectsEncodeSlot((LogicalProject)plan, replaceMap);
encodeAliases.addAll(expanded);
return expandedOtherHand;
}

public static Set<Slot> getPassThroughSlots(Plan plan) {
if (plan instanceof LogicalRelation) {
return new HashSet<>();
}
Set<Slot> outputSlots = Sets.newHashSet(plan.getOutputSet());
Set<Slot> keySlots = Sets.newHashSet();
for (Expression e : plan.getExpressions()) {
if (!(e instanceof SlotReference)) {
keySlots.addAll(e.getInputSlots());
}
}
// rewrite children
outputSlots.removeAll(keySlots);
return outputSlots;
}
}


/**
* push down encode slot
*/
public static class EncodeSlotPushDownVisitor extends PlanVisitor<Plan, PushDownContext> {
public static EncodeSlotPushDownVisitor INSTANCE = new EncodeSlotPushDownVisitor();

public Plan visitChildren(Plan plan, PushDownContext ctx) {
ImmutableList.Builder<Plan> newChildren = ImmutableList.builderWithExpectedSize(plan.arity());
boolean hasNewChildren = false;
for (Plan child : plan.children()) {
Plan newChild;
if (toBePushed.containsKey(child)) {
if (child instanceof LogicalProject && child.child(0) instanceof LogicalCatalogRelation) {
LogicalProject project = (LogicalProject) child;
List<NamedExpression> projections =
PlanUtils.mergeProjections(project.getProjects(), toBePushed.get(child));
newChild = project.withProjects(projections);
} else if (child instanceof LogicalCatalogRelation) {
List<NamedExpression> newProjections = new ArrayList<>();
newProjections.addAll(child.getOutput());
newProjections.addAll(toBePushed.get(child));
newChild = new LogicalProject<>(newProjections, child);
hasNewChildren = true;
} else {
newChild = child.accept(this, toBePushed.get(child));
}
if (ctx.toBePushedToChild.containsKey(child)) {
newChild = child.accept(this, new PushDownContext(child, ctx.toBePushedToChild.get(child)));
if (!hasNewChildren && newChild != child) {
hasNewChildren = true;
}
Expand All @@ -197,28 +269,81 @@ public Plan visit(Plan plan, List<Alias> encodeAlias) {
}
newChildren.add(newChild);
}

if (hasNewChildren) {
plan = plan.withChildren(newChildren.build());
}
return plan;
}

private Set<Slot> getPassThroughSlots(Plan plan) {
Set<Slot> outputSlots = Sets.newHashSet(plan.getOutputSet());
Set<Slot> keySlots = Sets.newHashSet();
if (plan instanceof LogicalProject) {
for (NamedExpression e : ((LogicalProject<?>) plan).getProjects()) {
if (!(e instanceof SlotReference)) {
keySlots.addAll(e.getInputSlots());
}
private Plan projectNotPushedAlias(Plan plan, List<Alias> notPushedAlias) {
if (!notPushedAlias.isEmpty()) {
// project encode expressions if they are not pushed down
// project(encode)
// +--> plan
List<NamedExpression> projections =
notPushedAlias.stream().map(e -> (NamedExpression) e).collect(Collectors.toList());
projections.addAll(plan.getOutput());
plan = new LogicalProject<>(projections, plan);
}
return plan;
}

@Override
public Plan visit(Plan plan, PushDownContext ctx) {
plan = visitChildren(plan, ctx);
plan = projectNotPushedAlias(plan, ctx.notPushed);
return plan;
}

@Override
public LogicalProject<? extends Plan> visitLogicalProject(
LogicalProject<? extends Plan> project, PushDownContext ctx) {
project = (LogicalProject<? extends Plan>) visitChildren(project, ctx);
/*
* push down "encode(v1) as v2
* project(v1, ...)
* +--->any(v1)
* =>
* project(v2, ...)
* +--->any(v1)
* and push down "encode(v1) as v2" to any(v1)
*/
List<NamedExpression> projections = Lists.newArrayListWithCapacity(project.getProjects().size());
for (NamedExpression e : project.getProjects()) {
if (ctx.toBePused.contains(e)) {
projections.add(e.toSlot());
} else {
projections.add(e);
}
} else {
keySlots = plan.getInputSlots();
}
outputSlots.removeAll(keySlots);
return outputSlots;
return project.withProjects(projections);
}
}

@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, PushDownContext ctx) {
join = (LogicalJoin) visitChildren(join, ctx);
// TODO: rewrite join condition
List<Expression> newConjuncts = Lists.newArrayListWithCapacity(join.getHashJoinConjuncts().size());
boolean changed = false;
for (Expression e : join.getHashJoinConjuncts()) {
EqualPredicate equal = (EqualPredicate) e;
if (PushDownContext.canBothSidesEncode(equal)) {
SlotReference newLeft = ctx.replaceMap.get(equal.left());
SlotReference newRight = ctx.replaceMap.get(equal.right());
Preconditions.checkArgument(newLeft != null,
"PushDownEncodeSlot replaceMap is not valid, " + equal.left() + " is not found" );
Preconditions.checkArgument(newRight != null,
"PushDownEncodeSlot replaceMap is not valid, " + equal.right() + " is not found" );
equal = (EqualPredicate) equal.withChildren(newLeft, newRight);
changed = true;
}
newConjuncts.add(equal);
}
if (changed) {
join = join.withJoinConjuncts(newConjuncts, join.getOtherJoinConjuncts(), join.getJoinReorderContext());
}
Plan plan = projectNotPushedAlias(join, ctx.notPushed);
return plan;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

Expand Down
Loading

0 comments on commit 00508e8

Please sign in to comment.