Skip to content

Commit

Permalink
setop
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Dec 3, 2024
1 parent 50f6fde commit 7c7375b
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.types.coercion.CharacterType;

Expand All @@ -44,6 +45,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -221,6 +223,12 @@ private void prepare() {
}
}

public void reset() {
toBePused.clear();
toBePushedToChild.clear();
replaceMap.clear();
notPushed.clear();
}
/**
* expandEncodeAliasForJoin
*/
Expand All @@ -245,15 +253,17 @@ public List<SlotReference> expandEncodeAliasForJoin(BiMap<SlotReference, SlotRef
return expandedOtherHand;
}

/**
*
* getPassThroughSlots
*/
public static Set<Slot> getPassThroughSlots(Plan plan) {
// the child of alias is a slot reference. for example: slotA as B
//
private boolean isSlotAlias(Expression expr) {
return expr instanceof Alias && expr.child(0) instanceof SlotReference;
}

private Set<Slot> getPassThroughSlots(Plan plan) {
Set<Slot> outputSlots = Sets.newHashSet(plan.getOutputSet());
Set<Slot> keySlots = Sets.newHashSet();
for (Expression e : plan.getExpressions()) {
if (!(e instanceof SlotReference)) {
if (!(e instanceof SlotReference) && !isSlotAlias(e)) {
keySlots.addAll(e.getInputSlots());
}
}
Expand Down Expand Up @@ -312,28 +322,106 @@ public Plan visit(Plan plan, PushDownContext ctx) {
return plan;
}

private Optional<Alias> findEncodeAliasByEncodeSlot(SlotReference slot, List<Alias> aliases) {
for (Alias alias : aliases) {
if (alias.child().child(0).equals(slot)) {
return Optional.of(alias);
}
}
return Optional.empty();
}

@Override
public LogicalProject<? extends Plan> visitLogicalProject(
LogicalProject<? extends Plan> project, PushDownContext ctx) {
project = (LogicalProject<? extends Plan>) visitChildren(project, ctx);
ctx.reset();
/*
* case 1
* push down "encode(v1) as v2
* project(v1, ...)
* +--->any(v1)
* =>
* project(v2, ...)
* +--->any(v1)
* and push down "encode(v1) as v2" to any(v1)
*
* case 2
* push down "encode(v1) as v2
* project(k as v1, ...)
* +--->any(k)
* =>
* project(v2, ...)
* +--->any(k)
* and push down "encode(k) as v2" to any(v1)
*
* case 3
* push down "encode(v1) as v2
* project(a + b as v1, ...)
* +--->any(a, b)
* =>
* project(encode(a+b) as v2, ...)
* +-->any(a, b)
*/
List<NamedExpression> projections = Lists.newArrayListWithCapacity(project.getProjects().size());
List<Alias> toBePushed = Lists.newArrayList();
List<Alias> notPushed = Lists.newArrayList(ctx.encodeAliases);

for (NamedExpression e : project.getProjects()) {
if (ctx.toBePused.contains(e)) {
projections.add(e.toSlot());
boolean changed = false;

if (e instanceof SlotReference) {
Optional<Alias> encodeAliasOpt = findEncodeAliasByEncodeSlot((SlotReference) e, ctx.encodeAliases);
if (encodeAliasOpt.isPresent()) {
// case 1
projections.add(encodeAliasOpt.get().toSlot());
toBePushed.add(encodeAliasOpt.get());
notPushed.remove(encodeAliasOpt.get());
changed = true;
}
} else {
// e is Alias
Expression aliasExpr = e.child(0);
if (aliasExpr instanceof SlotReference) {
//case 2
Optional<Alias> encodeAliasOpt = findEncodeAliasByEncodeSlot((SlotReference) e.toSlot(),
ctx.encodeAliases);
if (encodeAliasOpt.isPresent()) {
projections.add(encodeAliasOpt.get().toSlot());
Alias encodeAlias = encodeAliasOpt.get();
EncodeString encode = (EncodeString) encodeAlias.child();
SlotReference baseSlot = (SlotReference) aliasExpr;
Alias encodeAliasForChild = (Alias) encodeAlias.withChildren(encode.withChildren(baseSlot));
toBePushed.add(encodeAliasForChild);
notPushed.remove(encodeAlias);
changed = true;
}
} else {
// case 3
Optional<Alias> encodeAliasOpt = findEncodeAliasByEncodeSlot((SlotReference) e.toSlot(),
ctx.encodeAliases);
if (encodeAliasOpt.isPresent()) {
Alias encodeAlias = encodeAliasOpt.get();
EncodeString encode = (EncodeString) encodeAlias.child();
Alias encodeAliasForChild = (Alias) encodeAlias
.withChildren(encode.withChildren(aliasExpr));
notPushed.add(encodeAliasForChild);
changed = true;
}
}
}
if (!changed) {
projections.add(e);
}
}
return project.withProjects(projections);
projections.addAll(notPushed);

project = project.withProjects(projections);
PushDownContext childContext = new PushDownContext(project.child(), toBePushed);
Plan newChild = project.child().accept(this, childContext);
if (project.child() != newChild) {
project = (LogicalProject<? extends Plan>) project.withChildren(newChild);
}
return project;
}

@Override
Expand All @@ -359,8 +447,61 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, P
if (changed) {
join = join.withJoinConjuncts(join.getHashJoinConjuncts(), newConjuncts, join.getJoinReorderContext());
}
Plan plan = projectNotPushedAlias(join, ctx.notPushed);
return plan;
return projectNotPushedAlias(join, ctx.notPushed);
}


@Override
public Plan visitLogicalSetOperation(LogicalSetOperation op, PushDownContext ctx) {
// push down "encode(v) as x" through
// union(output[v], regular([v1],[v2]))
// -->child1(v1)
// -->child2(v2)
// rewrite union to: union(output[x], regular([x1], [x2]))
// and then push "encode(v1) as x1" to child(v1)
// push "encode(v2) as x2" to child(v2)

List<NamedExpression> newOutput = Lists.newArrayListWithCapacity(op.getOutput().size());
List<List<SlotReference>> newRegularOutputs = Lists.newArrayListWithCapacity(op.getOutput().size());
for (int cid = 0; cid < op.children().size(); cid++) {
newRegularOutputs.add(Lists.newArrayList(op.getRegularChildOutput(cid)));
}

for (int oid = 0; oid < op.getOutput().size(); oid++) {
NamedExpression e = op.getOutput().get(oid);
boolean changed = false;
for (Alias alias : ctx.encodeAliases) {
if (alias.child().child(0).equals(e)) {
newOutput.add(alias.toSlot());
changed = true;
EncodeString encode = (EncodeString) alias.child();
ctx.toBePused.add(alias);
for (int cid = 0; cid < op.children().size(); cid++) {
Plan child = op.child(cid);
ctx.toBePushedToChild.putIfAbsent(child, new ArrayList<>());
Alias aliasForChild = new Alias(
encode.withChildren(op.getRegularChildrenOutputs().get(cid).get(oid)));
ctx.toBePushedToChild.get(child).add(aliasForChild);
newRegularOutputs.get(cid).set(oid, (SlotReference) aliasForChild.toSlot());
}
break;
}
}
if (!changed) {
newOutput.add(e);
}
}
op = op.withNewOutputs(newOutput);

//rewrite children
List<Plan> newChildren = Lists.newArrayListWithCapacity(op.children().size());
for (Plan child : op.children()) {
PushDownContext childCtx = new PushDownContext(child, ctx.toBePushedToChild.get(child));
Plan newChild = child.accept(this, childCtx);
newChildren.add(newChild);
}
op = op.withChildrenAndTheirOutputs(newChildren, newRegularOutputs);
return op;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,16 @@ b 2
a 9
b 7

-- !union --
1
2
2
3
4

-- !intersect --
2

-- !except --
1

Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,35 @@

suite("pushdown_encode") {
// push down encode slot
sql """
drop table if exists t1;
CREATE TABLE t1 (
`k1` int NOT NULL,
`v1` char(5) NOT NULL
) ENGINE=OLAP
DUPLICATE KEY(`k1`)
DISTRIBUTED BY HASH(`k1`) BUCKETS 3
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);
insert into t1 values (1, "a"), (2, "b");
drop table if exists t2;
CREATE TABLE t2 (
`k2` int NOT NULL,
`v2` char(5) NOT NULL
) ENGINE=OLAP
DUPLICATE KEY(`k2`)
DISTRIBUTED BY HASH(`k2`) BUCKETS 3
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);
insert into t2 values (3, "c"), (4, "d"), (2, "b");
set disable_join_reorder=true;
"""
// sql """
// drop table if exists t1;
// CREATE TABLE t1 (
// `k1` int NOT NULL,
// `v1` char(5) NOT NULL
// ) ENGINE=OLAP
// DUPLICATE KEY(`k1`)
// DISTRIBUTED BY HASH(`k1`) BUCKETS 3
// PROPERTIES (
// "replication_allocation" = "tag.location.default: 1"
// );

// insert into t1 values (1, "a"), (2, "b");

// drop table if exists t2;
// CREATE TABLE t2 (
// `k2` int NOT NULL,
// `v2` char(5) NOT NULL
// ) ENGINE=OLAP
// DUPLICATE KEY(`k2`)
// DISTRIBUTED BY HASH(`k2`) BUCKETS 3
// PROPERTIES (
// "replication_allocation" = "tag.location.default: 1"
// );

// insert into t2 values (3, "c"), (4, "d"), (2, "b");

// set disable_join_reorder=true;
// """

explain{
sql """
Expand Down Expand Up @@ -154,32 +154,53 @@ suite("pushdown_encode") {
from t1 right outer join t2 on v1 < v2 and v2>"abc"
group by v1;
"""

// explain{
// // do not push down, because v is output
// sql """
// physical plan
// select k, v as x
// from (select k1 as k, v1 as v from t1) A
// union all (select k2 as k, v2 as v from t2)
// order by x
// """
// // this project is above union
// contains("""projects=[k#9, x#10, encode_as_bigint(x#10) AS `encode_as_bigint(x)`#11]""")
// }

// explain {
// sql """
// explain physical plan
// select k
// from (
// (select k1 as k, v1 as v from t1)
// union all
// (select k2 as k, v2 as v from t2)
// ) T
// order by v;
// """
// }


explain {
sql """
physical plan
select k
from (
(select k1 as k, v1 as v from t1)
union all
(select k2 as k, v2 as v from t2)
) T
order by v;
"""
contains("orderKeys=[encode_as_bigint(v)#10 asc null first]")
contains("outputs=[k#8, encode_as_bigint(v)#10], regularChildrenOutputs=[[k#4, encode_as_bigint(v)#11], [k#6, encode_as_bigint(v)#12]]")
contains("projects=[k1#0 AS `k`#4, encode_as_bigint(v1#1) AS `encode_as_bigint(v)`#11]")
contains("projects=[k2#2 AS `k`#6, encode_as_bigint(v2#3) AS `encode_as_bigint(v)`#12]")
}

order_qt_union """
select k
from (
(select k1 as k, v1 as v from t1)
union all
(select k2 as k, v2 as v from t2)
) T
order by v;
"""
order_qt_intersect """
select k
from (
(select k1 as k, v1 as v from t1)
intersect
(select k2 as k, v2 as v from t2)
) T
order by v;
"""

order_qt_except """
select k
from (
(select k1 as k, v1 as v from t1)
except
(select k2 as k, v2 as v from t2)
) T
order by v;
"""
// not pushed
// project(encode(A))
// +-->join(A=B+1)
Expand Down

0 comments on commit 7c7375b

Please sign in to comment.