From 8ad393ee4a610556cb6d6d8266649d40e6b06b0a Mon Sep 17 00:00:00 2001 From: Pengfei Zhan Date: Thu, 7 Mar 2024 15:51:13 +0800 Subject: [PATCH] [CALCITE-5843] Constant expression with nested casts causes a compiler crash --- .../calcite/linq4j/tree/Expressions.java | 10 ++- .../apache/calcite/linq4j/tree/Primitive.java | 71 +++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java index 5effc26d5b2..ac2359f40fa 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java @@ -569,7 +569,15 @@ public static ConstantExpression constant(@Nullable Object value, Type type) { value = new BigInteger(stringValue); } if (primitive != null) { - value = primitive.parse(stringValue); + if (value instanceof Number) { + Number valueNumber = (Number) value; + value = primitive.numberValue(valueNumber); + if (value == null) { + value = primitive.parse(stringValue); + } + } else { + value = primitive.parse(stringValue); + } } } } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java index 0e89bdeebed..fd7042b4550 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java @@ -22,6 +22,8 @@ import java.lang.reflect.Array; import java.lang.reflect.Field; import java.lang.reflect.Type; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.sql.ResultSet; import java.sql.SQLException; import java.util.AbstractList; @@ -366,6 +368,75 @@ public static List asList(double[] elements) { return (List) asList((Object) elements); } + /** + * Check if a value after rounding falls within a specified range. + * + * @param value Value to compare. + * @param min Minimum value allowed. + * @param max Maximum value allowed. + */ + static void checkRoundedRange(Number value, double min, double max) { + double dbl = value.doubleValue(); + // The equivalent of DOWN rounding for BigDecimal + dbl = dbl > 0 ? Math.floor(dbl) : Math.ceil(dbl); + if (dbl < min || dbl > max) { + throw new ArithmeticException("Value " + value + " out of range"); + } + } + + /** + * Converts a number into a value of the type specified by this primitive + * using the SQL CAST rules. If the value conversion causes loss of significant digits, + * an exception is thrown. + * + * @param value Value to convert. + * @return The converted value, or null if the type of the result is not a number. + */ + public @Nullable Object numberValue(Number value) { + switch (this) { + case BYTE: + checkRoundedRange(value, Byte.MIN_VALUE, Byte.MAX_VALUE); + return value.byteValue(); + case CHAR: + // No overflow checks for char values. + // For example, Postgres has this behavior. + return (char) value.intValue(); + case SHORT: + checkRoundedRange(value, Short.MIN_VALUE, Short.MAX_VALUE); + return value.shortValue(); + case INT: + checkRoundedRange(value, Integer.MIN_VALUE, Integer.MAX_VALUE); + return value.intValue(); + case LONG: + if (value instanceof Byte + || value instanceof Short + || value instanceof Integer + || value instanceof Long) { + return value.longValue(); + } + if (value instanceof Float + || value instanceof Double) { + // The value Long.MAX_VALUE cannot be represented exactly as a double, + // so we cannot use checkRoundedRange. + BigDecimal decimal = BigDecimal.valueOf(value.doubleValue()) + // Round to an integer + .setScale(0, RoundingMode.DOWN); + // longValueExact will throw ArithmeticException if out of range + return decimal.longValueExact(); + } + throw new AssertionError("Unexpected Number type " + + value.getClass().getSimpleName()); + case FLOAT: + // out of range values will be represented as infinities + return value.floatValue(); + case DOUBLE: + // out of range values will be represented as infinities + return value.doubleValue(); + default: + return null; + } + } + /** * Converts a collection of boxed primitives into an array of primitives. *