diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java index df0dd87ed9718a..e55cfe343d78b9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java @@ -43,6 +43,7 @@ import org.apache.doris.rewrite.FoldConstantsRule; import org.apache.doris.rewrite.NormalizeBinaryPredicatesRule; import org.apache.doris.rewrite.RewriteAliasFunctionRule; +import org.apache.doris.rewrite.RewriteBinaryPredicatesRule; import org.apache.doris.rewrite.RewriteEncryptKeyRule; import org.apache.doris.rewrite.RewriteFromUnixTimeRule; import org.apache.doris.rewrite.RewriteLikePredicateRule; @@ -268,6 +269,8 @@ public GlobalState(Catalog catalog, ConnectContext context) { // Binary predicates must be rewritten to a canonical form for both predicate // pushdown and Parquet row group pruning based on min/max statistics. rules.add(NormalizeBinaryPredicatesRule.INSTANCE); + // Put it after NormalizeBinaryPredicatesRule, make sure slotRef is on the left and Literal is on the right. + rules.add(RewriteBinaryPredicatesRule.INSTANCE); rules.add(FoldConstantsRule.INSTANCE); rules.add(RewriteFromUnixTimeRule.INSTANCE); rules.add(CompoundPredicateWriteRule.INSTANCE); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java index d64ea00ed2aba5..1cf3fcd7de3010 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java @@ -35,6 +35,7 @@ import java.io.DataOutput; import java.io.IOException; import java.math.BigDecimal; +import java.math.RoundingMode; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Objects; @@ -231,6 +232,14 @@ public int getFracValue() { return fracPart.intValue(); } + public void roundCeiling() { + value = value.setScale(0, RoundingMode.CEILING); + } + + public void roundFloor() { + value = value.setScale(0, RoundingMode.FLOOR); + } + @Override protected Expr uncheckedCastTo(Type targetType) throws AnalysisException { if (targetType.isDecimalV2()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java new file mode 100644 index 00000000000000..39d4180efc08b7 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.rewrite; + +import org.apache.doris.analysis.Analyzer; +import org.apache.doris.analysis.BinaryPredicate; +import org.apache.doris.analysis.BoolLiteral; +import org.apache.doris.analysis.CastExpr; +import org.apache.doris.analysis.DecimalLiteral; +import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.SlotRef; +import org.apache.doris.catalog.Type; +import org.apache.doris.common.AnalysisException; + +/** + * Rewrite binary predicate. + */ +public class RewriteBinaryPredicatesRule implements ExprRewriteRule { + public static ExprRewriteRule INSTANCE = new RewriteBinaryPredicatesRule(); + + /** + * Convert the binary predicate of the form > to the binary + * predicate of , thereby allowing the binary predicate + * The predicate pushes down and completes the bucket clipped. + * + * Examples & background + * For query "select * from T where t1 = 2.0", when the ResultType of column t1 is equal to BIGINT, in the binary + * predicate analyze, the type will be unified to DECIMALV2, so the binary predicate will be converted to + * > , because Cast wraps the t1 column, it cannot be pushed down, + * resulting in poor performance. + * We convert it to the equivalent query "select * from T where t1 = 2" to push down and improve performance. + * + * Applicable scene: + * The performance and results of the following scenarios are equivalent. + * 1) "select * from T where t1 = 2.0" is converted to "select * from T where t1 = 2" + * 2) "select * from T where t1 = 2.1" is converted to "select * from T where 2 = 2.1" (`EMPTY`) + * 3) "select * from T where t1 != 2.0" is converted to "select * from T where t1 != 2" + * 4) "select * from T where t1 != 2.1" is converted to "select * from T" + * 5) "select * from T where t1 <= 2.0" is converted to "select * from T where t1 <= 2" + * 6) "select * from T where t1 <= 2.1" is converted to "select * from T where t1 <3" + * 7) "select * from T where t1 >= 2.0" is converted to "select * from T where t1 >= 2" + * 8) "select * from T where t1 >= 2.1" is converted to "select * from T where t1> 2" + * 9) "select * from T where t1 <2.0" is converted to "select * from T where t1 <2" + * 10) "select * from T where t1 <2.1" is converted to "select * from T where t1 <3" + * 11) "select * from T where t1> 2.0" is converted to "select * from T where t1> 2" + * 12) "select * from T where t1> 2.1" is converted to "select * from T where t1> 2" + */ + private Expr rewriteBigintSlotRefCompareDecimalLiteral(Expr expr0, Expr expr1, BinaryPredicate.Operator op) + throws AnalysisException { + if (((DecimalLiteral) expr1).getDoubleValue() % (int) (((DecimalLiteral) expr1).getDoubleValue()) != 0) { + if (op == BinaryPredicate.Operator.EQ || op == BinaryPredicate.Operator.EQ_FOR_NULL) { + return new BoolLiteral(false); + } else if (op == BinaryPredicate.Operator.NE) { + return new BoolLiteral(true); + } else if (op == BinaryPredicate.Operator.LE) { + ((DecimalLiteral) expr1).roundCeiling(); + op = BinaryPredicate.Operator.LT; + } else if (op == BinaryPredicate.Operator.GE) { + ((DecimalLiteral) expr1).roundFloor(); + op = BinaryPredicate.Operator.GT; + } else if (op == BinaryPredicate.Operator.LT) { + ((DecimalLiteral) expr1).roundCeiling(); + } else if (op == BinaryPredicate.Operator.GT) { + ((DecimalLiteral) expr1).roundFloor(); + } + } + expr0 = expr0.getChild(0); + expr1 = expr1.castTo(Type.BIGINT); + return new BinaryPredicate(op, expr0, expr1); + } + + @Override + public Expr apply(Expr expr, Analyzer analyzer) throws AnalysisException { + if (!(expr instanceof BinaryPredicate)) return expr; + BinaryPredicate.Operator op = ((BinaryPredicate) expr).getOp(); + Expr expr0 = expr.getChild(0); + Expr expr1 = expr.getChild(1); + if (expr0 instanceof CastExpr && expr0.getType() == Type.DECIMALV2 && expr0.getChild(0) instanceof SlotRef + && expr0.getChild(0).getType().getResultType() == Type.BIGINT && expr1 instanceof DecimalLiteral) { + return rewriteBigintSlotRefCompareDecimalLiteral(expr0, expr1, op); + } + return expr; + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java b/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java index 084a536d2bb076..ff29dc29649ffd 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java @@ -394,4 +394,44 @@ public void testAnalyticSortNodeLeftJoin() throws Exception { } + @Test + public void testBigintSlotRefCompareDecimalLiteral() { + java.util.function.BiConsumer compare = (sql1, sql2) -> { + StmtExecutor stmtExecutor1 = new StmtExecutor(ctx, sql1); + try { + stmtExecutor1.execute(); + } catch (Exception e) { + e.printStackTrace(); + } + Planner planner1 = stmtExecutor1.planner(); + List fragments1 = planner1.getFragments(); + String plan1 = planner1.getExplainString(fragments1, new ExplainOptions(false, false)); + + StmtExecutor stmtExecutor2 = new StmtExecutor(ctx, sql2); + try { + stmtExecutor2.execute(); + } catch (Exception e) { + e.printStackTrace(); + } + Planner planner2 = stmtExecutor2.planner(); + List fragments2 = planner2.getFragments(); + String plan2 = planner2.getExplainString(fragments2, new ExplainOptions(false, false)); + + Assert.assertEquals(plan1, plan2); + }; + + compare.accept("select * from db1.tbl2 where k1 = 2.0", "select * from db1.tbl2 where k1 = 2"); + compare.accept("select * from db1.tbl2 where k1 = 2.1", "select * from db1.tbl2 where 2 = 2.1"); + compare.accept("select * from db1.tbl2 where k1 != 2.0", "select * from db1.tbl2 where k1 != 2"); + compare.accept("select * from db1.tbl2 where k1 != 2.1", "select * from db1.tbl2"); + compare.accept("select * from db1.tbl2 where k1 <= 2.0", "select * from db1.tbl2 where k1 <= 2"); + compare.accept("select * from db1.tbl2 where k1 <= 2.1", "select * from db1.tbl2 where k1 < 3"); + compare.accept("select * from db1.tbl2 where k1 >= 2.0", "select * from db1.tbl2 where k1 >= 2"); + compare.accept("select * from db1.tbl2 where k1 >= 2.1", "select * from db1.tbl2 where k1 > 2"); + compare.accept("select * from db1.tbl2 where k1 < 2.0", "select * from db1.tbl2 where k1 < 2"); + compare.accept("select * from db1.tbl2 where k1 < 2.1", "select * from db1.tbl2 where k1 < 3"); + compare.accept("select * from db1.tbl2 where k1 > 2.0", "select * from db1.tbl2 where k1 > 2"); + compare.accept("select * from db1.tbl2 where k1 > 2.1", "select * from db1.tbl2 where k1 > 2"); + } + }