Skip to content

Commit

Permalink
all
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Dec 4, 2024
1 parent 00d9d42 commit c7fc744
Show file tree
Hide file tree
Showing 296 changed files with 9,541 additions and 9,052 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,15 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(
new NormalizeAggregate(),
new CountLiteralRewrite(),
new NormalizeSort(),
new NormalizeSort()
),

topDown(// must run after NormalizeAggregate/NormalizeSort
new MergeProjects(),
new PushDownEncodeSlot(),
new DecoupleEncodeDecode()
),

topic("Window analysis",
topDown(
new ExtractAndNormalizeWindowExpression(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public enum RuleType {
// rewrite rules
COMPRESSED_MATERIALIZE_AGG(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_SORT(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_REPEAT(RuleTypeClass.REWRITE),
PUSH_DOWN_ENCODE_SLOT(RuleTypeClass.REWRITE),
DECOUPLE_DECODE_ENCODE_SLOT(RuleTypeClass.REWRITE),
NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsSmallInt;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.CharacterType;
Expand Down Expand Up @@ -72,6 +73,11 @@ public List<Rule> buildRules() {
logicalSort().when(a -> ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().enableCompressMaterialize)
.then(this::compressMaterializeSort)
// ),
// RuleType.COMPRESSED_MATERIALIZE_REPEAT.build(
// logicalRepeat().when(r -> ConnectContext.get() != null
// && ConnectContext.get().getSessionVariable().enableCompressMaterialize)
// .then(this::compressMaterializeRepeat)
)
);
}
Expand Down Expand Up @@ -163,4 +169,52 @@ private LogicalAggregate<Plan> compressedMaterializeAggregate(LogicalAggregate<P
}
return aggregate;
}

/**
 * Collects, for the group-by keys of a repeat node, the keys that have an encoded form.
 *
 * @param repeat the repeat node whose grouping sets are inspected
 * @return a map from each key of the first grouping set to the expression returned by
 *         {@code getEncodeExpression}; keys for which that call yields an empty
 *         {@code Optional} are left out
 */
private Map<Expression, Expression> getEncodeGroupingSets(LogicalRepeat<Plan> repeat) {
    Map<Expression, Expression> encodedByKey = Maps.newHashMap();
    // the first grouping set contains all group by keys
    List<Expression> allGroupByKeys = repeat.getGroupingSets().get(0);
    for (Expression key : allGroupByKeys) {
        Optional<Expression> candidate = getEncodeExpression(key);
        if (candidate.isPresent()) {
            encodedByKey.put(key, candidate.get());
        }
    }
    return encodedByKey;
}

/**
 * Rewrites a repeat node so that encodable grouping keys are replaced by their encoded
 * form, while the output expressions decode them back via {@code DecodeAsVarchar}.
 *
 * @param repeat the repeat node to rewrite
 * @return the same {@code repeat} instance when no grouping key can be encoded, otherwise
 *         the result of {@code withGroupSetsAndOutput} with the rewritten grouping sets
 *         and outputs
 */
private LogicalRepeat<Plan> compressMaterializeRepeat(LogicalRepeat<Plan> repeat) {
    Map<Expression, Expression> encode = getEncodeGroupingSets(repeat);
    if (encode.isEmpty()) {
        return repeat;
    }

    // Substitute every grouping key that has an encoded counterpart; keep the others.
    List<List<Expression>> newGroupingSets = Lists.newArrayList();
    for (List<Expression> groupingSet : repeat.getGroupingSets()) {
        List<Expression> rewrittenSet = Lists.newArrayList();
        for (Expression key : groupingSet) {
            rewrittenSet.add(encode.getOrDefault(key, key));
        }
        newGroupingSets.add(rewrittenSet);
    }

    // original key -> decode(encoded key), used to rewrite the outputs
    Map<Expression, Expression> decodeMap = new HashMap<>();
    for (Map.Entry<Expression, Expression> entry : encode.entrySet()) {
        decodeMap.put(entry.getKey(), new DecodeAsVarchar(entry.getValue()));
    }

    List<NamedExpression> newOutputs = Lists.newArrayList();
    for (NamedExpression out : repeat.getOutputExpressions()) {
        Expression replaced = ExpressionUtils.replace(out, decodeMap);
        if (out == replaced) {
            // nothing was substituted in this output, keep it untouched
            newOutputs.add(out);
            continue;
        }
        if (out instanceof SlotReference) {
            // keep the original expr id and name so the output slot stays the same
            newOutputs.add(new Alias(out.getExprId(), replaced, out.getName()));
        } else if (out instanceof Alias) {
            newOutputs.add(((Alias) out).withChildren(replaced.children()));
        } else {
            // should not reach here
            Preconditions.checkArgument(false, "output abnormal: " + repeat);
        }
    }
    return repeat.withGroupSetsAndOutput(newGroupingSets, newOutputs);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeString;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.types.coercion.CharacterType;
Expand Down Expand Up @@ -73,47 +76,86 @@ public class PushDownEncodeSlot extends OneRewriteRuleFactory {
/**
 * Builds the rewrite rule for this factory.
 *
 * The pattern is a logical project that passes every predicate below; matching projects
 * are handed to {@code pushDownEncodeSlot}, and the rule is registered under
 * {@code RuleType.PUSH_DOWN_ENCODE_SLOT}.
 */
@Override
public Rule build() {
return logicalProject()
// only projects for which containsEncode(project) is true
.when(this::containsEncode)
// skip projects whose child is a LogicalRepeat
.whenNot(project -> project.child() instanceof LogicalRepeat)
// skip projects whose child is a LogicalCatalogRelation
.when(project -> !(project.child() instanceof LogicalCatalogRelation))
.then(project -> pushDownEncodeSlot(project))
.toRule(RuleType.PUSH_DOWN_ENCODE_SLOT);
}

/**
 * Tells whether at least one projection of {@code project} is an {@code Alias} whose
 * first child satisfies {@link #containsEncode(Expression)}.
 *
 * @param project the project whose projections are scanned
 * @return {@code true} as soon as one such alias is found, {@code false} otherwise
 */
private boolean containsEncode(LogicalProject<? extends Plan> project) {
    for (NamedExpression projection : project.getProjects()) {
        if (projection instanceof Alias && containsEncode(projection.child(0))) {
            return true;
        }
    }
    return false;
}

/**
 * Tells whether {@code expr} is an {@code EncodeString} applied directly to a
 * {@code SlotReference}.
 *
 * @param expr the expression to test
 * @return {@code true} only when {@code expr} is an {@code EncodeString} and its first
 *         child is a {@code SlotReference}
 */
private boolean containsEncode(Expression expr) {
    if (!(expr instanceof EncodeString)) {
        return false;
    }
    return expr.child(0) instanceof SlotReference;
}

private List<Alias> collectEncodeAliases(LogicalProject<? extends Plan> project) {
List<Alias> encodeAliases = new ArrayList<>();
project.getProjects().forEach(e -> {
if (e instanceof Alias
&& e.child(0) instanceof EncodeString
for (NamedExpression e : project.getProjects()) {
if (e instanceof Alias && e.child(0) instanceof EncodeString
&& e.child(0).child(0) instanceof SlotReference) {
encodeAliases.add((Alias) e);
}
});
}
return encodeAliases;
}

/**
* case 1
* project(encode(A) as B)
* --> any(A)
* =>
* project(B)
* -->any(A): push "encode(A) as B"
*
* case 2
* project(A, encode(A) as B)
* -->any(A)
* =>
* project(decode(B) as A, B)
* -->any(A): push "encode(A) as B"
*
* case 3
* project(A as C, encode(A) as B)
* -->any(A)
* =>
* project(decode(B) as C, B)
* -->any(A): push "encode(A) as B"
*/
private LogicalProject<? extends Plan> rewriteRootProject(LogicalProject<? extends Plan> project,
List<Alias> encodeAlias) {
if (encodeAlias.isEmpty()) {
List<Alias> pushedEncodeAlias) {
if (pushedEncodeAlias.isEmpty()) {
return project;
}
List<NamedExpression> projections = project.getProjects().stream().map(
e -> encodeAlias.contains(e) ? e.toSlot() : e)
.collect(Collectors.toList());
Map<Expression, Alias> encodeBodyToEncodeAlias = new HashMap<>();
for (Alias alias : pushedEncodeAlias) {
Expression encodeBody = alias.child().child(0);
encodeBodyToEncodeAlias.put(encodeBody, alias);
}
List<NamedExpression> projections = Lists.newArrayListWithCapacity(project.getProjects().size());
for (NamedExpression e : project.getProjects()) {
if (pushedEncodeAlias.contains(e)) {
// case 1
projections.add(e.toSlot());
} else if (encodeBodyToEncodeAlias.containsKey(e)) {
// case 2
ExprId id = e.getExprId();
DecodeAsVarchar decode = new DecodeAsVarchar(encodeBodyToEncodeAlias.get(e).toSlot());
Alias alias = new Alias(id, decode, decode.toSql());
projections.add(alias);
} else if (e instanceof Alias && encodeBodyToEncodeAlias.containsKey(e.child(0))) {
// case 3
Alias alias = (Alias) e;
DecodeAsVarchar decode = new DecodeAsVarchar(encodeBodyToEncodeAlias.get(e.child(0)).toSlot());
Alias newAlias = (Alias) alias.withChildren(decode);
projections.add(newAlias);
} else {
projections.add(e);
}
}
return project.withProjects(projections);

}

private LogicalProject<? extends Plan> pushDownEncodeSlot(LogicalProject<? extends Plan> project) {
List<Alias> encodeAliases = collectEncodeAliases(project);
if (encodeAliases.isEmpty()) {
return project;
}

PushDownContext ctx = new PushDownContext(project, encodeAliases);
ctx.prepare();
if (ctx.notPushed.size() == encodeAliases.size()) {
Expand Down Expand Up @@ -158,30 +200,6 @@ public PushDownContext(Plan plan, List<Alias> encodeAliases) {
this.encodeAliases = encodeAliases;
}

// public static boolean canBothSidesEncode(ComparisonPredicate compare) {
// return compare.left().getDataType() instanceof CharacterType
// && ((CharacterType) compare.left().getDataType()).getLen() < 15
// && ((CharacterType) compare.right().getDataType()).getLen() < 15
// && compare.left() instanceof SlotReference && compare.right() instanceof SlotReference;
// }
//
// private BiMap<SlotReference, SlotReference> getCompareSlotsFromJoinCondition(LogicalJoin<?, ?> join) {
// BiMap<SlotReference, SlotReference> compareSlots = HashBiMap.create();
// List<Expression> conditions = new ArrayList<>();
// conditions.addAll(join.getHashJoinConjuncts());
// conditions.addAll(join.getOtherJoinConjuncts());
// for (Expression e : conditions) {
// if (e instanceof ComparisonPredicate) {
// ComparisonPredicate compare = (ComparisonPredicate) e;
// if (canBothSidesEncode(compare)) {
// compareSlots.put((SlotReference) compare.left(), (SlotReference) compare.right());
// }
// }
// }
//
// return compareSlots;
// }

// init replaceMap/toBePushed/notPushed
private void prepare() {
List<Set<Slot>> childrenPassThroughSlots =
Expand All @@ -195,12 +213,6 @@ private void prepare() {
childrenPassThroughSlots.get(i).addAll(compareSlots.keySet());
}
}
//
// if (plan instanceof LogicalJoin) {
// LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) plan;
// BiMap<SlotReference, SlotReference> compareSlots = getCompareSlotsFromJoinCondition(join);
// expandEncodeAliasForJoin(compareSlots);
// }
for (Alias alias : encodeAliases) {
EncodeString encode = (EncodeString) alias.child();
Expression strExpr = encode.child();
Expand Down Expand Up @@ -315,6 +327,12 @@ public Plan visit(Plan plan, PushDownContext ctx) {
return plan;
}

/**
 * Handles a repeat node by delegating to {@code projectNotPushedAlias} with all encode
 * aliases of the context; the repeat's children are not visited here.
 * NOTE(review): presumably this stops the push-down at the repeat and evaluates the
 * aliases above it - confirm against projectNotPushedAlias.
 */
@Override
public Plan visitLogicalRepeat(LogicalRepeat repeat, PushDownContext ctx) {
    return projectNotPushedAlias(repeat, ctx.encodeAliases);
}

private Optional<Alias> findEncodeAliasByEncodeSlot(SlotReference slot, List<Alias> aliases) {
for (Alias alias : aliases) {
if (alias.child().child(0).equals(slot)) {
Expand Down Expand Up @@ -347,12 +365,23 @@ public LogicalProject<? extends Plan> visitLogicalProject(
* and push down "encode(k) as v2" to any(v1)
*
* case 3
* push down "encode(v44) as v307"
* project(decode(v305) as v44)
* +-->agg(v305, groupBy[v305])
* +--->project(encode(v44) as v305)
* =>
* project(v305 as v307)
* +-->agg
*
* case 4
* push down "encode(v1) as v2"
* project(a + b as v1, ...)
* +--->any(a, b)
* =>
* project(encode(a+b) as v2, ...)
* +-->any(a, b)
*
*/
List<NamedExpression> projections = Lists.newArrayListWithCapacity(project.getProjects().size());
List<Alias> toBePushed = Lists.newArrayList();
Expand Down Expand Up @@ -388,16 +417,30 @@ public LogicalProject<? extends Plan> visitLogicalProject(
changed = true;
}
} else {
// case 3
Optional<Alias> encodeAliasOpt = findEncodeAliasByEncodeSlot((SlotReference) e.toSlot(),
ctx.encodeAliases);
if (encodeAliasOpt.isPresent()) {
Alias encodeAlias = encodeAliasOpt.get();
EncodeString encode = (EncodeString) encodeAlias.child();
Alias encodeAliasForChild = (Alias) encodeAlias
.withChildren(encode.withChildren(aliasExpr));
notPushed.add(encodeAliasForChild);
changed = true;
if (aliasExpr instanceof DecodeAsVarchar) {
// case 3
// push down "encode(v44) as v307"
// project(decode(v305) as v44)
// +-->agg(v305, groupBy[v305])
// +--->project(encode(v44) as v305)
Expression decodeBody = aliasExpr.child(0);
Alias aliasForProject = (Alias) encodeAlias.withChildren(decodeBody);
projections.add(aliasForProject);
notPushed.remove(encodeAlias);
changed = true;
} else {
// case 4
EncodeString encode = (EncodeString) encodeAlias.child();
Alias encodeAliasForProject = (Alias) encodeAlias
.withChildren(encode.withChildren(aliasExpr));
projections.add(encodeAliasForProject);
notPushed.remove(encodeAlias);
changed = true;
}
}
}
}
Expand All @@ -408,10 +451,12 @@ public LogicalProject<? extends Plan> visitLogicalProject(
projections.addAll(notPushed);

project = project.withProjects(projections);
PushDownContext childContext = new PushDownContext(project.child(), toBePushed);
Plan newChild = project.child().accept(this, childContext);
if (project.child() != newChild) {
project = (LogicalProject<? extends Plan>) project.withChildren(newChild);
if (!toBePushed.isEmpty()) {
PushDownContext childContext = new PushDownContext(project.child(), toBePushed);
Plan newChild = project.child().accept(this, childContext);
if (project.child() != newChild) {
project = (LogicalProject<? extends Plan>) project.withChildren(newChild);
}
}
return project;
}
Expand Down Expand Up @@ -491,12 +536,12 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, P
if (canBothSidesEncode(compare)) {
SlotReference newLeft = replaceMap.get(compare.left());
SlotReference newRight = replaceMap.get(compare.right());
Preconditions.checkArgument(newLeft != null,
"PushDownEncodeSlot replaceMap is not valid, " + compare.left() + " is not found");
Preconditions.checkArgument(newRight != null,
"PushDownEncodeSlot replaceMap is not valid, " + compare.right() + " is not found");
compare = (ComparisonPredicate) compare.withChildren(newLeft, newRight);
changed = true;
if (newLeft != null && newRight != null) {
compare = (ComparisonPredicate) compare.withChildren(newLeft, newRight);
changed = true;
}
Preconditions.checkArgument((newLeft == null) == (newRight == null),
"PushDownEncodeSlot replaceMap is not valid, " + compare);
}
newConjuncts.add(compare);
}
Expand Down Expand Up @@ -564,9 +609,13 @@ public Plan visitLogicalSetOperation(LogicalSetOperation op, PushDownContext ctx
//rewrite children
List<Plan> newChildren = Lists.newArrayListWithCapacity(op.children().size());
for (Plan child : op.children()) {
PushDownContext childCtx = new PushDownContext(child, ctx.toBePushedToChild.get(child));
Plan newChild = child.accept(this, childCtx);
newChildren.add(newChild);
if (!ctx.toBePushedToChild.get(child).isEmpty()) {
PushDownContext childCtx = new PushDownContext(child, ctx.toBePushedToChild.get(child));
Plan newChild = child.accept(this, childCtx);
newChildren.add(newChild);
} else {
newChildren.add(child);
}
}
op = op.withChildrenAndTheirOutputs(newChildren, newRegularOutputs);
return op;
Expand Down
Loading

0 comments on commit c7fc744

Please sign in to comment.