From 8fb39634bdf4baebce16f48e75cc6e019584513a Mon Sep 17 00:00:00 2001 From: minghong Date: Tue, 24 Sep 2024 17:34:32 +0800 Subject: [PATCH] or-to-in --- .../expression/ExpressionOptimization.java | 26 +- .../rules/expression/rules/OrToIn.java | 236 ++++++++++++++---- .../nereids/rules/rewrite/OrToInTest.java | 63 +++-- 3 files changed, 242 insertions(+), 83 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index abf57057601dc8f..752fe156cc22590 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -41,19 +41,19 @@ public class ExpressionOptimization extends ExpressionRewrite { public static final List OPTIMIZE_REWRITE_RULES = ImmutableList.of( bottomUp( - ExtractCommonFactorRule.INSTANCE, - DistinctPredicatesRule.INSTANCE, - SimplifyComparisonPredicate.INSTANCE, - SimplifyInPredicate.INSTANCE, - SimplifyDecimalV3Comparison.INSTANCE, - OrToIn.INSTANCE, - SimplifyRange.INSTANCE, - DateFunctionRewrite.INSTANCE, - ArrayContainToArrayOverlap.INSTANCE, - CaseWhenToIf.INSTANCE, - TopnToMax.INSTANCE, - NullSafeEqualToEqual.INSTANCE, - LikeToEqualRewrite.INSTANCE + ExtractCommonFactorRule.INSTANCE, + DistinctPredicatesRule.INSTANCE, + SimplifyComparisonPredicate.INSTANCE, + SimplifyInPredicate.INSTANCE, + SimplifyDecimalV3Comparison.INSTANCE, + OrToIn.INSTANCE, + SimplifyRange.INSTANCE, + DateFunctionRewrite.INSTANCE, + ArrayContainToArrayOverlap.INSTANCE, + CaseWhenToIf.INSTANCE, + TopnToMax.INSTANCE, + NullSafeEqualToEqual.INSTANCE, + LikeToEqualRewrite.INSTANCE ) ); private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java index 83da8055037242b..4e99101b29d6628 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java @@ -22,24 +22,28 @@ import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewrite; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Maps; +import com.google.common.collect.Sets; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; /** + * dependends on SimplifyRange rule + * * Used to convert multi equalTo which has same slot and compare to a literal of disjunction to a InPredicate so that * it could be push down to storage engine. * example: @@ -65,7 +69,7 @@ public class OrToIn implements ExpressionPatternRuleFactory { @Override public List> buildRules() { return ImmutableList.of( - matchesTopType(Or.class).then(OrToIn::rewrite) + matchesTopType(Or.class).then(OrToIn.INSTANCE::rewrite) ); } @@ -74,73 +78,205 @@ public Expression rewriteTree(Expression expr, ExpressionRewriteContext context) return bottomUpRewriter.rewrite(expr, context); } - private static Expression rewrite(Or or) { - // NOTICE: use linked hash map to avoid unstable order or entry. - // unstable order entry lead to dead loop since return expression always un-equals to original one. - Map> slotNameToLiteral = Maps.newLinkedHashMap(); - Map disConjunctToSlot = Maps.newLinkedHashMap(); - List expressions = ExpressionUtils.extractDisjunction(or); - for (Expression expression : expressions) { - if (expression instanceof EqualTo) { - handleEqualTo((EqualTo) expression, slotNameToLiteral, disConjunctToSlot); - } else if (expression instanceof InPredicate) { - handleInPredicate((InPredicate) expression, slotNameToLiteral, disConjunctToSlot); + private Expression rewrite(Or or) { + if (or.getMutableState("OrToIn").isPresent()) { + return or; + } + + Expression simple = SimplifyRange.rewrite(or); + if (!(simple instanceof Or)) { + simple.setMutableState("OrToIn", 1); + return simple; + } + + or = (Or) simple; + or.setMutableState("OrToIn", 1); + + List disjuncts = ExpressionUtils.extractDisjunction(or); + for (Expression disjunct : disjuncts) { + if (!hasInOrEqualChildren(disjunct)) { + return or; } } - if (disConjunctToSlot.isEmpty()) { + + Map> candidates = getCandidate(disjuncts.get(0)); + if (candidates.isEmpty()) { return or; } - List rewrittenOr = new ArrayList<>(); - for (Map.Entry> entry : slotNameToLiteral.entrySet()) { - Set literals = entry.getValue(); - if (literals.size() >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { - InPredicate inPredicate = new InPredicate(entry.getKey(), ImmutableList.copyOf(entry.getValue())); - rewrittenOr.add(inPredicate); + // verify each candidate + for (int i = 1; i < disjuncts.size(); i++) { + Map> otherCandidates = getCandidate(disjuncts.get(i)); + if (otherCandidates.isEmpty()) { + return or; + } + candidates = mergeCandidates(candidates, otherCandidates); + if (candidates.isEmpty()) { + return or; } } - for (Expression expression : expressions) { - if (disConjunctToSlot.get(expression) == null) { - rewrittenOr.add(expression); + if (!candidates.isEmpty()) { + Expression conjunct = candidatesToFinalResult(candidates); + boolean hasOtherExpr = hasOtherExpressionExceptCandidates(disjuncts, candidates.keySet()); + if (hasOtherExpr) { + return new And(conjunct, or); } else { - Set literals = slotNameToLiteral.get(disConjunctToSlot.get(expression)); - if (literals.size() < REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { - rewrittenOr.add(expression); + return conjunct; + } + } + return or; + } + + private boolean hasOtherExpressionExceptCandidates(List disjuncts, Set candidateKeys) { + for (Expression disjunct : disjuncts) { + List conjuncts = ExpressionUtils.extractConjunction(disjunct); + for (Expression conjunct : conjuncts) { + if (!containsAny(conjunct.getInputSlots(), candidateKeys)) { + return true; } } + } + return false; + } - return ExpressionUtils.or(rewrittenOr); + private boolean containsAny(Set a, Set b) { + for (Object x : a) { + if (b.contains(x)) { + return true; + } + } + return false; } - private static void handleEqualTo(EqualTo equal, Map> slotNameToLiteral, - Map disConjunctToSlot) { - Expression left = equal.left(); - Expression right = equal.right(); - if (left instanceof NamedExpression && right instanceof Literal) { - addSlotToLiteral((NamedExpression) left, (Literal) right, slotNameToLiteral); - disConjunctToSlot.put(equal, (NamedExpression) left); - } else if (right instanceof NamedExpression && left instanceof Literal) { - addSlotToLiteral((NamedExpression) right, (Literal) left, slotNameToLiteral); - disConjunctToSlot.put(equal, (NamedExpression) right); + private Map> mergeCandidates( + Map> a, + Map> b) { + Map> result = new LinkedHashMap<>(); + for (Expression expr : a.keySet()) { + Set otherLiterals = b.get(expr); + if (otherLiterals != null) { + Set literals = a.get(expr); + literals.addAll(otherLiterals); + if (!literals.isEmpty()) { + result.put(expr, literals); + } + } } + return result; } - private static void handleInPredicate(InPredicate inPredicate, Map> slotNameToLiteral, - Map disConjunctToSlot) { - // TODO a+b in (1,2,3...) is not supported now - if (inPredicate.getCompareExpr() instanceof NamedExpression - && inPredicate.getOptions().stream().allMatch(opt -> opt instanceof Literal)) { - for (Expression opt : inPredicate.getOptions()) { - addSlotToLiteral((NamedExpression) inPredicate.getCompareExpr(), (Literal) opt, slotNameToLiteral); + private Expression candidatesToFinalResult(Map> candidates) { + List conjuncts = new ArrayList<>(); + for (Expression key : candidates.keySet()) { + Set literals = candidates.get(key); + if (literals.size() < REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { + for (Literal literal : literals) { + conjuncts.add(new EqualTo(key, literal)); + } + } else { + conjuncts.add(new InPredicate(key, ImmutableList.copyOf(literals))); } - disConjunctToSlot.put(inPredicate, (NamedExpression) inPredicate.getCompareExpr()); } + return ExpressionUtils.and(conjuncts); } - private static void addSlotToLiteral(NamedExpression namedExpression, Literal literal, - Map> slotNameToLiteral) { - Set literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new LinkedHashSet<>()); - literals.add(literal); + /* + it is not necessary to rewrite "a like 'xyz' or a=1 or a=2" to "a like 'xyz' or a in (1, 2)", + because we cannot push "a in (1, 2)" into storage layer + */ + private boolean hasInOrEqualChildren(Expression disjunct) { + List conjuncts = ExpressionUtils.extractConjunction(disjunct); + for (Expression conjunct : conjuncts) { + if (conjunct instanceof EqualTo || conjunct instanceof InPredicate) { + return true; + } + } + return false; + } + + // conjuncts.get(idx) has different input slots + private boolean independentConjunct(int idx, List conjuncts) { + Expression conjunct = conjuncts.get(idx); + Set targetSlots = conjunct.getInputSlots(); + if (conjuncts.size() == 1) { + return true; + } + for (int i = 0; i < conjuncts.size(); i++) { + if (i != idx) { + Set otherInput = Sets.newHashSet(); + otherInput.addAll(conjuncts.get(i).getInputSlots()); + otherInput.retainAll(targetSlots); + if (!otherInput.isEmpty()) { + return false; + } + } + } + return true; + } + + private Map> getCandidate(Expression disjunct) { + List conjuncts = ExpressionUtils.extractConjunction(disjunct); + Map> candidates = new LinkedHashMap<>(); + // collect candidates from the first disjunction + for (int idx = 0; idx < conjuncts.size(); idx++) { + if (!independentConjunct(idx, conjuncts)) { + continue; + } + // find pattern: A=1 / A in (1, 2, 3 ...) + // candidates: A->[1] / A -> [1, 2, 3, ...] + Expression conjunct = conjuncts.get(idx); + Expression compareExpr = null; + if (conjunct instanceof EqualTo) { + EqualTo eq = (EqualTo) conjunct; + Literal literal = null; + if (!(eq.left() instanceof Literal) && eq.right() instanceof Literal) { + compareExpr = eq.left(); + literal = (Literal) eq.right(); + } else if (!(eq.right() instanceof Literal) && eq.left() instanceof Literal) { + compareExpr = eq.right(); + literal = (Literal) eq.left(); + } + if (compareExpr != null) { + Set literals = candidates.get(compareExpr); + if (literals == null) { + literals = Sets.newHashSet(); + literals.add(literal); + candidates.put(compareExpr, literals); + } else { + // pattern like (A=1 and A=2) should be processed by SimplifyRange rule + // OrToIn rule does apply to this expression + candidates.clear(); + break; + + } + } + } else if (conjunct instanceof InPredicate) { + InPredicate inPredicate = (InPredicate) conjunct; + Set literalOptions = new LinkedHashSet<>(); + boolean allLiteralOpts = true; + for (Expression opt : inPredicate.getOptions()) { + if (opt instanceof Literal) { + literalOptions.add((Literal) opt); + } else { + allLiteralOpts = false; + break; + } + } + + if (allLiteralOpts) { + Set alreadyMappedLiterals = candidates.get(inPredicate.getCompareExpr()); + if (alreadyMappedLiterals == null) { + candidates.put(inPredicate.getCompareExpr(), literalOptions); + } else { + // pattern like (A=1 and A in (1, 2)) should be processed by SimplifyRange rule + // OrToIn rule does apply to this expression + candidates.clear(); + break; + } + } + } + } + return candidates; } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java index 98eac158185c5af..3062b99610ab48d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java @@ -19,7 +19,6 @@ import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import org.apache.doris.nereids.rules.expression.rules.OrToIn; -import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -39,22 +38,8 @@ void test1() { String expr = "col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)"; Expression expression = PARSER.parseExpression(expr); Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); - Set inPredicates = rewritten.collect(e -> e instanceof InPredicate); - Assertions.assertEquals(1, inPredicates.size()); - InPredicate inPredicate = inPredicates.iterator().next(); - NamedExpression namedExpression = (NamedExpression) inPredicate.getCompareExpr(); - Assertions.assertEquals("col1", namedExpression.getName()); - List options = inPredicate.getOptions(); - Assertions.assertEquals(2, options.size()); - Set opVals = ImmutableSet.of(1, 2); - for (Expression op : options) { - Literal literal = (Literal) op; - Assertions.assertTrue(opVals.contains(((Byte) literal.getValue()).intValue())); - } - Set ands = rewritten.collect(e -> e instanceof And); - Assertions.assertEquals(1, ands.size()); - And and = ands.iterator().next(); - Assertions.assertEquals("((col1 = 3) AND (col2 = 4))", and.toSql()); + Assertions.assertEquals("(col1 IN (1, 2, 3) AND (col1 IN (1, 2) OR ((col1 = 3) AND (col2 = 4))))", + rewritten.toSql()); } @Test @@ -62,7 +47,7 @@ void test2() { String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4"; Expression expression = PARSER.parseExpression(expr); Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); - Assertions.assertEquals("((((col1 = 1) AND (col1 = 3)) AND (col2 = 3)) OR (col2 = 4))", + Assertions.assertEquals("(col2 = 4)", rewritten.toSql()); } @@ -104,7 +89,7 @@ void test5() { String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 5))"; Expression expression = PARSER.parseExpression(expr); Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); - Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 5)))", + Assertions.assertEquals("(col = 1)", rewritten.toSql()); } @@ -121,7 +106,7 @@ void test7() { String expr = "A = 1 or A = 2 or abs(A)=5 or A in (1, 2, 3) or B = 1 or B = 2 or B in (1, 2, 3) or B+1 in (4, 5, 7)"; Expression expression = PARSER.parseExpression(expr); Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); - Assertions.assertEquals("(((A IN (1, 2, 3) OR B IN (1, 2, 3)) OR (abs(A) = 5)) OR (B + 1) IN (4, 5, 7))", rewritten.toSql()); + Assertions.assertEquals("(((A IN (1, 2, 3) OR (abs(A) = 5)) OR B IN (1, 2, 3)) OR (B + 1) IN (4, 5, 7))", rewritten.toSql()); } @Test @@ -142,4 +127,42 @@ void testEnsureOrder() { Assertions.assertEquals("(col1 IN (1, 2) OR col2 IN (1, 2))", rewritten.toSql()); } + + @Test + void test9() { + String expr = "col1=1 and (col2=1 or col2=2)"; + Expression expression = PARSER.parseExpression(expr); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); + Assertions.assertEquals("((col1 = 1) AND col2 IN (1, 2))", + rewritten.toSql()); + } + + @Test + void test10() { + String expr = "col1=1 or (col2 = 2 and (col3=4 or col3=5))"; + Expression expression = PARSER.parseExpression(expr); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); + Assertions.assertEquals("((col1 = 1) OR ((col2 = 2) AND col3 IN (4, 5)))", + rewritten.toSql()); + } + + @Test + void test11() { + String expr = "(a=1 and b=2 and c=3) or (a=2 and b=2 and c=4)"; + Expression expression = PARSER.parseExpression(expr); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); + Assertions.assertEquals("((a IN (1, 2) AND (b = 2)) AND c IN (3, 4))", + rewritten.toSql()); + } + // recursive + // col1=1 and (col2=1 or col2=2) => col1=1 and col2 in (1, 2) + // col1=1 or (col2 = 2 and (col3=4 or col3=5)) => + // col1=1 or (col1=2 and col2=3) => X col1 in (1, 2) and (col1=1 or col2=3) + // (a=1 and b=2 and c=3) or (a=2 and b=2 and c=4) + // a in (1, 2) and a in (3, 4) + // a in (1, 2) or (b=3 and (a=4 or a=5)) + // (a =1 and a in (3, 4)) or (a =5) + // a like 'xyz%' or a=1 or a=2: no extract + // (a=1 and f(a)=2) or a=3: no extract + // x=1 or (a=1 and b=2) or (a=2 and c=3) no extract }