Skip to content

Commit

Permalink
or-to-in
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Sep 24, 2024
1 parent c2e4d97 commit 7f06720
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@
public class ExpressionOptimization extends ExpressionRewrite {
public static final List<ExpressionRewriteRule> OPTIMIZE_REWRITE_RULES = ImmutableList.of(
bottomUp(
ExtractCommonFactorRule.INSTANCE,
DistinctPredicatesRule.INSTANCE,
SimplifyComparisonPredicate.INSTANCE,
SimplifyInPredicate.INSTANCE,
SimplifyDecimalV3Comparison.INSTANCE,
OrToIn.INSTANCE,
SimplifyRange.INSTANCE,
DateFunctionRewrite.INSTANCE,
ArrayContainToArrayOverlap.INSTANCE,
CaseWhenToIf.INSTANCE,
TopnToMax.INSTANCE,
NullSafeEqualToEqual.INSTANCE,
LikeToEqualRewrite.INSTANCE
ExtractCommonFactorRule.INSTANCE,
DistinctPredicatesRule.INSTANCE,
SimplifyComparisonPredicate.INSTANCE,
SimplifyInPredicate.INSTANCE,
SimplifyDecimalV3Comparison.INSTANCE,
OrToIn.INSTANCE,
SimplifyRange.INSTANCE,
DateFunctionRewrite.INSTANCE,
ArrayContainToArrayOverlap.INSTANCE,
CaseWhenToIf.INSTANCE,
TopnToMax.INSTANCE,
NullSafeEqualToEqual.INSTANCE,
LikeToEqualRewrite.INSTANCE
)
);
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,29 @@
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Depends on the SimplifyRange rule.
*
* Converts a disjunction of EqualTo predicates that compare the same slot with literals into a single
* InPredicate so that it can be pushed down to the storage engine.
* example:
Expand All @@ -65,7 +70,7 @@ public class OrToIn implements ExpressionPatternRuleFactory {
@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
// Registers one pattern: every expression whose top-level type is Or is handed to rewrite(Or).
// NOTE(review): both the static (OrToIn::rewrite) and the instance (OrToIn.INSTANCE::rewrite)
// bindings are listed below; presumably only one of them is meant to remain -- confirm.
return ImmutableList.of(
matchesTopType(Or.class).then(OrToIn::rewrite)
matchesTopType(Or.class).then(OrToIn.INSTANCE::rewrite)
);
}

Expand All @@ -74,73 +79,202 @@ public Expression rewriteTree(Expression expr, ExpressionRewriteContext context)
return bottomUpRewriter.rewrite(expr, context);
}

// NOTE(review): two declarations of rewrite(Or) -- a static one and an instance one -- are
// interleaved in the lines below; the comments annotate the individual statements in place.
private static Expression rewrite(Or or) {
// NOTICE: use linked hash maps to keep the order of the entries stable.
// An unstable entry order leads to a dead loop, since the returned expression would never equal
// the original one.
Map<NamedExpression, Set<Literal>> slotNameToLiteral = Maps.newLinkedHashMap();
Map<Expression, NamedExpression> disConjunctToSlot = Maps.newLinkedHashMap();
List<Expression> expressions = ExpressionUtils.extractDisjunction(or);
for (Expression expression : expressions) {
if (expression instanceof EqualTo) {
handleEqualTo((EqualTo) expression, slotNameToLiteral, disConjunctToSlot);
} else if (expression instanceof InPredicate) {
handleInPredicate((InPredicate) expression, slotNameToLiteral, disConjunctToSlot);
private Expression rewrite(Or or) {
// The "OrToIn" mutable state is the visited marker set further down; an expression that already
// carries it is returned unchanged.
if (or.getMutableState("OrToIn").isPresent()) {
return or;
}

// Run SimplifyRange on the disjunction first. If the result is no longer an Or, it is marked as
// visited and returned as is.
Expression simple = SimplifyRange.rewrite(or);
if (!(simple instanceof Or)) {
simple.setMutableState("OrToIn", 1);
return simple;
}

// From here on work on the simplified Or and mark it as visited.
or = (Or) simple;
or.setMutableState("OrToIn", 1);

// Give up unless every disjunct has at least one top-level EqualTo / InPredicate conjunct.
List<Expression> disjuncts = ExpressionUtils.extractDisjunction(or);
for (Expression disjunct : disjuncts) {
if (!hasInOrEqualChildren(disjunct)) {
return or;
}
}
if (disConjunctToSlot.isEmpty()) {

// Seed the candidates (compare expression -> literals) from the first disjunct.
Map<Expression, Set<Literal>> candidates = getCandidate(disjuncts.get(0));
if (candidates.isEmpty()) {
return or;
}

List<Expression> rewrittenOr = new ArrayList<>();
for (Map.Entry<NamedExpression, Set<Literal>> entry : slotNameToLiteral.entrySet()) {
Set<Literal> literals = entry.getValue();
if (literals.size() >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
InPredicate inPredicate = new InPredicate(entry.getKey(), ImmutableList.copyOf(entry.getValue()));
rewrittenOr.add(inPredicate);
// Verify each candidate against the remaining disjuncts: mergeCandidates keeps only the
// expressions present in both maps, so the rewrite is abandoned as soon as nothing is left.
for (int i = 1; i < disjuncts.size(); i++) {
Map<Expression, Set<Literal>> otherCandidates = getCandidate(disjuncts.get(i));
if (otherCandidates.isEmpty()) {
return or;
}
candidates = mergeCandidates(candidates, otherCandidates);
if (candidates.isEmpty()) {
return or;
}
}
for (Expression expression : expressions) {
if (disConjunctToSlot.get(expression) == null) {
rewrittenOr.add(expression);
if (!candidates.isEmpty()) {
// Build the extracted predicate, then check whether any disjunct still has conjuncts that do
// not touch a candidate. cleanedDisjuncts is only inspected for emptiness.
Expression conjunct = candidatesToFinalResult(candidates);
List<Expression> cleanedDisjuncts = cleanDisjunctsByCandidates(disjuncts, candidates.keySet());
// NOTE(review): when every conjunct of every disjunct touches a candidate, only the extracted
// predicate is returned. For (a = 1 and b = 2) or (a = 3 and b = 4) this appears to produce
// a in (1, 3) and b in (2, 4), which also accepts a = 1 and b = 4 -- confirm that dropping the
// original OR is safe when more than one candidate survives.
if (cleanedDisjuncts.isEmpty()) {
return conjunct;
} else {
Set<Literal> literals = slotNameToLiteral.get(disConjunctToSlot.get(expression));
if (literals.size() < REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
rewrittenOr.add(expression);
// Otherwise the original OR is kept and the extracted predicate is added as an extra conjunct.
return new And(conjunct, or);
}
}
return or;
}

/**
 * Removes from every disjunct the conjuncts whose input slots include one of the candidate keys.
 *
 * @param disjuncts the disjuncts to clean
 * @param candidateKeys the candidate expressions
 * @return the conjunction of the surviving conjuncts of each disjunct, in the original order;
 *         a disjunct with no surviving conjunct contributes nothing
 */
private List<Expression> cleanDisjunctsByCandidates(List<Expression> disjuncts, Set<Expression> candidateKeys) {
    List<Expression> remaining = Lists.newArrayList();
    for (Expression disjunct : disjuncts) {
        List<Expression> kept = ExpressionUtils.extractConjunction(disjunct).stream()
                .filter(conjunct -> !containsAny(conjunct.getInputSlots(), candidateKeys))
                .collect(Collectors.toList());
        if (kept.isEmpty()) {
            // Everything in this disjunct referred to a candidate.
            continue;
        }
        remaining.add(ExpressionUtils.and(kept));
    }
    return remaining;
}

return ExpressionUtils.or(rewrittenOr);
/**
 * Tells whether the two sets have at least one element in common.
 *
 * @param a the set whose elements are probed
 * @param b the set the elements of {@code a} are looked up in
 * @return {@code true} if some element of {@code a} is contained in {@code b}
 */
private boolean containsAny(Set<?> a, Set<?> b) {
    // Unbounded wildcards instead of raw Set: callers pass sets with different element types
    // (input slots vs. candidate expressions) and no longer trigger raw-type warnings.
    return a.stream().anyMatch(b::contains);
}

private static void handleEqualTo(EqualTo equal, Map<NamedExpression, Set<Literal>> slotNameToLiteral,
Map<Expression, NamedExpression> disConjunctToSlot) {
Expression left = equal.left();
Expression right = equal.right();
if (left instanceof NamedExpression && right instanceof Literal) {
addSlotToLiteral((NamedExpression) left, (Literal) right, slotNameToLiteral);
disConjunctToSlot.put(equal, (NamedExpression) left);
} else if (right instanceof NamedExpression && left instanceof Literal) {
addSlotToLiteral((NamedExpression) right, (Literal) left, slotNameToLiteral);
disConjunctToSlot.put(equal, (NamedExpression) right);
/**
 * Merges the candidates of two disjuncts.
 * <p>
 * Only expressions that are keys of both maps survive; their literal sets are unioned.
 * Neither argument nor any of their literal sets is modified.
 *
 * @param a candidates collected so far
 * @param b candidates of the next disjunct
 * @return a new insertion-ordered map, following the key order of {@code a}
 */
private Map<Expression, Set<Literal>> mergeCandidates(
        Map<Expression, Set<Literal>> a,
        Map<Expression, Set<Literal>> b) {
    Map<Expression, Set<Literal>> result = new LinkedHashMap<>();
    for (Map.Entry<Expression, Set<Literal>> entry : a.entrySet()) {
        Set<Literal> otherLiterals = b.get(entry.getKey());
        if (otherLiterals == null) {
            // Not constrained by the other disjunct, so it cannot be a common candidate.
            continue;
        }
        // Union into a fresh set instead of mutating the set owned by 'a'; LinkedHashSet keeps
        // the literals in a stable order (those of 'a' first, then the new ones from 'b').
        Set<Literal> merged = new LinkedHashSet<>(entry.getValue());
        merged.addAll(otherLiterals);
        if (!merged.isEmpty()) {
            result.put(entry.getKey(), merged);
        }
    }
    return result;
}

private static void handleInPredicate(InPredicate inPredicate, Map<NamedExpression, Set<Literal>> slotNameToLiteral,
Map<Expression, NamedExpression> disConjunctToSlot) {
// TODO a+b in (1,2,3...) is not supported now
if (inPredicate.getCompareExpr() instanceof NamedExpression
&& inPredicate.getOptions().stream().allMatch(opt -> opt instanceof Literal)) {
for (Expression opt : inPredicate.getOptions()) {
addSlotToLiteral((NamedExpression) inPredicate.getCompareExpr(), (Literal) opt, slotNameToLiteral);
/**
 * Turns the surviving candidates into one predicate.
 * <p>
 * Each candidate contributes one conjunct: an InPredicate when it has at least
 * REWRITE_OR_TO_IN_PREDICATE_THRESHOLD literals, otherwise the disjunction of the
 * individual equalities. A candidate with no literal contributes nothing.
 *
 * @param candidates compare expression -> literals it may be equal to
 * @return the conjunction of the per-candidate predicates
 */
private Expression candidatesToFinalResult(Map<Expression, Set<Literal>> candidates) {
    List<Expression> conjuncts = new ArrayList<>();
    for (Map.Entry<Expression, Set<Literal>> entry : candidates.entrySet()) {
        Expression key = entry.getKey();
        Set<Literal> literals = entry.getValue();
        if (literals.size() >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
            conjuncts.add(new InPredicate(key, ImmutableList.copyOf(literals)));
            continue;
        }
        List<Expression> equalities = new ArrayList<>(literals.size());
        for (Literal literal : literals) {
            equalities.add(new EqualTo(key, literal));
        }
        if (equalities.size() == 1) {
            conjuncts.add(equalities.get(0));
        } else if (!equalities.isEmpty()) {
            // The literals are alternatives for the same expression, so they must be OR-ed;
            // AND-ing "key = l1" with "key = l2" would yield a predicate that is never true.
            conjuncts.add(ExpressionUtils.or(equalities));
        }
    }
    return ExpressionUtils.and(conjuncts);
}

/**
 * Checks whether at least one top-level conjunct of the disjunct is an EqualTo or an InPredicate.
 * <p>
 * It is not necessary to rewrite "a like 'xyz' or a=1 or a=2" to "a like 'xyz' or a in (1, 2)",
 * because we cannot push "a in (1, 2)" into the storage layer.
 *
 * @param disjunct one disjunct of the OR being rewritten
 * @return {@code true} if such a conjunct exists
 */
private boolean hasInOrEqualChildren(Expression disjunct) {
    return ExpressionUtils.extractConjunction(disjunct).stream()
            .anyMatch(child -> child instanceof EqualTo || child instanceof InPredicate);
}

/**
 * Tells whether conjuncts.get(idx) shares no input slot with any other conjunct of the list.
 *
 * @param idx index of the conjunct under test
 * @param conjuncts the conjuncts of one disjunct
 * @return true if the list has a single element, or if the input slots of the conjunct under test
 *         are disjoint from the input slots of every other conjunct
 */
private boolean independentConjunct(int idx, List<Expression> conjuncts) {
Expression conjunct = conjuncts.get(idx);
Set<Slot> targetSlots = conjunct.getInputSlots();
if (conjuncts.size() == 1) {
return true;
}
for (int i = 0; i < conjuncts.size(); i++) {
if (i != idx) {
// Intersect on a copy so that the set returned by getInputSlots() is never modified.
Set<Slot> otherInput = Sets.newHashSet();
otherInput.addAll(conjuncts.get(i).getInputSlots());
otherInput.retainAll(targetSlots);
if (!otherInput.isEmpty()) {
return false;
}
}
// NOTE(review): the next statement uses disConjunctToSlot and inPredicate, neither of which is
// declared in this method; it looks like a leftover of handleInPredicate -- confirm and remove.
disConjunctToSlot.put(inPredicate, (NamedExpression) inPredicate.getCompareExpr());
}
return true;
}

private static void addSlotToLiteral(NamedExpression namedExpression, Literal literal,
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
Set<Literal> literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new LinkedHashSet<>());
literals.add(literal);
/**
 * Collects the candidates of one disjunct.
 * <p>
 * A candidate is produced by a conjunct of the form {@code A = 1} (literal on either side) or
 * {@code A in (1, 2, 3, ...)} with literal-only options, provided the conjunct is independent
 * of the other conjuncts (see independentConjunct). The result maps the compared expression to
 * its literals: {@code A -> [1]} / {@code A -> [1, 2, 3, ...]}.
 *
 * @param disjunct one disjunct of the OR being rewritten
 * @return an insertion-ordered map of candidates; empty if there is none, or if the same
 *         expression is constrained by more than one conjunct
 */
private Map<Expression, Set<Literal>> getCandidate(Expression disjunct) {
    List<Expression> conjuncts = ExpressionUtils.extractConjunction(disjunct);
    Map<Expression, Set<Literal>> candidates = new LinkedHashMap<>();
    for (int idx = 0; idx < conjuncts.size(); idx++) {
        if (!independentConjunct(idx, conjuncts)) {
            continue;
        }
        Expression conjunct = conjuncts.get(idx);
        Expression compareExpr = null;
        // LinkedHashSet keeps the literals in the order they were written, so the generated
        // predicate does not depend on hash iteration order.
        Set<Literal> literals = null;
        if (conjunct instanceof EqualTo) {
            EqualTo eq = (EqualTo) conjunct;
            Literal literal = null;
            if (!(eq.left() instanceof Literal) && eq.right() instanceof Literal) {
                compareExpr = eq.left();
                literal = (Literal) eq.right();
            } else if (!(eq.right() instanceof Literal) && eq.left() instanceof Literal) {
                compareExpr = eq.right();
                literal = (Literal) eq.left();
            }
            if (compareExpr != null) {
                literals = new LinkedHashSet<>();
                literals.add(literal);
            }
        } else if (conjunct instanceof InPredicate) {
            InPredicate inPredicate = (InPredicate) conjunct;
            // Test the options themselves rather than comparing sizes after de-duplication:
            // "A in (1, 1, 2)" has only literal options and is a valid candidate.
            if (inPredicate.getOptions().stream().allMatch(Literal.class::isInstance)) {
                compareExpr = inPredicate.getCompareExpr();
                literals = inPredicate.getOptions().stream()
                        .map(Literal.class::cast)
                        .collect(Collectors.toCollection(LinkedHashSet::new));
            }
        }
        if (compareExpr == null) {
            continue;
        }
        if (candidates.containsKey(compareExpr)) {
            // Patterns like (A=1 and A=2) or (A=1 and A in (1, 2)) should be processed by the
            // SimplifyRange rule; the OrToIn rule does not apply to this expression.
            candidates.clear();
            break;
        }
        candidates.put(compareExpr, literals);
    }
    return candidates;
}
}
Loading

0 comments on commit 7f06720

Please sign in to comment.