From 88f3d40603396a48a043b14e79135a6236c69dea Mon Sep 17 00:00:00 2001 From: Dustin Ray <40841027+drcapybara@users.noreply.github.com> Date: Tue, 25 Jun 2024 13:11:03 -0700 Subject: [PATCH] refactor: conversion error decimal variant (#26) # Rationale for this change The ```ConversionError``` error type has grown quite large with the addition of new features and types. This PR categorizes some timestamp and decimal errors into different modules and reduces the size and complexity of ```ConversionError```. # What changes are included in this PR? IntermediateDecimal, Decimal errors are moved to respective modules. Native thiserr ```#from``` is used in place of manual ```From``` conversions. # Are these changes tested? yes --- .../src/intermediate_decimal.rs | 43 +++++----- crates/proof-of-sql/src/base/math/decimal.rs | 80 ++++++++++++++----- .../src/base/scalar/mont_scalar.rs | 12 ++- .../src/sql/ast/comparison_util.rs | 19 ++++- .../src/sql/ast/numerical_util.rs | 12 ++- crates/proof-of-sql/src/sql/parse/error.rs | 43 ++-------- .../sql/parse/provable_expr_plan_builder.rs | 12 ++- 7 files changed, 126 insertions(+), 95 deletions(-) diff --git a/crates/proof-of-sql-parser/src/intermediate_decimal.rs b/crates/proof-of-sql-parser/src/intermediate_decimal.rs index 850637e16..bb613da0b 100644 --- a/crates/proof-of-sql-parser/src/intermediate_decimal.rs +++ b/crates/proof-of-sql-parser/src/intermediate_decimal.rs @@ -5,6 +5,7 @@ //! //! A decimal must have a decimal point. The lexer does not route //! whole integers to this contructor. +use crate::intermediate_decimal::IntermediateDecimalError::{LossyCast, OutOfRange, ParseError}; use bigdecimal::{num_bigint::BigInt, BigDecimal, ParseBigDecimalError, ToPrimitive}; use serde::{Deserialize, Serialize}; use std::{fmt, str::FromStr}; @@ -12,7 +13,7 @@ use thiserror::Error; /// Errors related to the processing of decimal values in proof-of-sql #[derive(Error, Debug, PartialEq)] -pub enum DecimalError { +pub enum IntermediateDecimalError { /// Represents an error encountered during the parsing of a decimal string. #[error(transparent)] ParseError(#[from] ParseBigDecimalError), @@ -27,6 +28,8 @@ pub enum DecimalError { ConversionFailure, } +impl Eq for IntermediateDecimalError {} + /// An intermediate placeholder for a decimal #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] pub struct IntermediateDecimal { @@ -55,10 +58,10 @@ impl IntermediateDecimal { &self, precision: u8, scale: i8, - ) -> Result { + ) -> Result { let scaled_decimal = self.value.with_scale(scale.into()); if scaled_decimal.digits() > precision.into() { - return Err(DecimalError::LossyCast); + return Err(LossyCast); } let (d, _) = scaled_decimal.into_bigint_and_exponent(); Ok(d) @@ -72,14 +75,14 @@ impl fmt::Display for IntermediateDecimal { } impl FromStr for IntermediateDecimal { - type Err = DecimalError; + type Err = IntermediateDecimalError; fn from_str(decimal_string: &str) -> Result { BigDecimal::from_str(decimal_string) .map(|value| IntermediateDecimal { value: value.normalized(), }) - .map_err(DecimalError::ParseError) + .map_err(ParseError) } } @@ -100,7 +103,7 @@ impl From for IntermediateDecimal { } impl TryFrom<&str> for IntermediateDecimal { - type Error = DecimalError; + type Error = IntermediateDecimalError; fn try_from(s: &str) -> Result { IntermediateDecimal::from_str(s) @@ -108,7 +111,7 @@ impl TryFrom<&str> for IntermediateDecimal { } impl TryFrom for IntermediateDecimal { - type Error = DecimalError; + type Error = IntermediateDecimalError; fn try_from(s: String) -> Result { IntermediateDecimal::from_str(&s) @@ -116,31 +119,31 @@ impl TryFrom for IntermediateDecimal { } impl TryFrom for i128 { - type Error = DecimalError; + type Error = IntermediateDecimalError; fn try_from(decimal: IntermediateDecimal) -> Result { if !decimal.value.is_integer() { - return Err(DecimalError::LossyCast); + return Err(LossyCast); } match decimal.value.to_i128() { Some(value) if (i128::MIN..=i128::MAX).contains(&value) => Ok(value), - _ => Err(DecimalError::OutOfRange), + _ => Err(OutOfRange), } } } impl TryFrom for i64 { - type Error = DecimalError; + type Error = IntermediateDecimalError; fn try_from(decimal: IntermediateDecimal) -> Result { if !decimal.value.is_integer() { - return Err(DecimalError::LossyCast); + return Err(LossyCast); } match decimal.value.to_i64() { Some(value) if (i64::MIN..=i64::MAX).contains(&value) => Ok(value), - _ => Err(DecimalError::OutOfRange), + _ => Err(OutOfRange), } } } @@ -195,10 +198,7 @@ mod tests { let overflow_decimal = IntermediateDecimal { value: BigDecimal::from_str("170141183460469231731687303715884105728").unwrap(), }; - assert_eq!( - i128::try_from(overflow_decimal), - Err(DecimalError::OutOfRange) - ); + assert_eq!(i128::try_from(overflow_decimal), Err(OutOfRange)); let valid_decimal_negative = IntermediateDecimal { value: BigDecimal::from_str("-170141183460469231731687303715884105728").unwrap(), @@ -211,7 +211,7 @@ mod tests { let non_integer = IntermediateDecimal { value: BigDecimal::from_str("100.5").unwrap(), }; - assert_eq!(i128::try_from(non_integer), Err(DecimalError::LossyCast)); + assert_eq!(i128::try_from(non_integer), Err(LossyCast)); } #[test] @@ -229,10 +229,7 @@ mod tests { let overflow_decimal = IntermediateDecimal { value: BigDecimal::from_str("9223372036854775808").unwrap(), }; - assert_eq!( - i64::try_from(overflow_decimal), - Err(DecimalError::OutOfRange) - ); + assert_eq!(i64::try_from(overflow_decimal), Err(OutOfRange)); let valid_decimal_negative = IntermediateDecimal { value: BigDecimal::from_str("-9223372036854775808").unwrap(), @@ -245,6 +242,6 @@ mod tests { let non_integer = IntermediateDecimal { value: BigDecimal::from_str("100.5").unwrap(), }; - assert_eq!(i64::try_from(non_integer), Err(DecimalError::LossyCast)); + assert_eq!(i64::try_from(non_integer), Err(LossyCast)); } } diff --git a/crates/proof-of-sql/src/base/math/decimal.rs b/crates/proof-of-sql/src/base/math/decimal.rs index 38f19013b..363cef87d 100644 --- a/crates/proof-of-sql/src/base/math/decimal.rs +++ b/crates/proof-of-sql/src/base/math/decimal.rs @@ -1,10 +1,51 @@ //! Module for parsing an `IntermediateDecimal` into a `Decimal75`. use crate::{ - base::scalar::Scalar, - sql::parse::{ConversionError, ConversionResult}, + base::{ + math::decimal::DecimalError::{ + IntermediateDecimalConversionError, InvalidPrecision, RoundingError, + }, + scalar::Scalar, + }, + sql::parse::{ + ConversionError::{self, DecimalConversionError}, + ConversionResult, + }, }; -use proof_of_sql_parser::intermediate_decimal::IntermediateDecimal; +use proof_of_sql_parser::intermediate_decimal::{IntermediateDecimal, IntermediateDecimalError}; use serde::{Deserialize, Deserializer, Serialize}; +use thiserror::Error; + +/// Errors related to decimal operations. +#[derive(Error, Debug, Eq, PartialEq)] +pub enum DecimalError { + #[error("Invalid decimal format or value: {0}")] + /// Error when a decimal format or value is incorrect, + /// the string isn't even a decimal e.g. "notastring", + /// "-21.233.122" etc aka InvalidDecimal + InvalidDecimal(String), + + #[error("Decimal precision is not valid: {0}")] + /// Decimal precision exceeds the allowed limit, + /// e.g. precision above 75/76/whatever set by Scalar + /// or non-positive aka InvalidPrecision + InvalidPrecision(String), + + #[error("Unsupported operation: cannot round decimal: {0}")] + /// This error occurs when attempting to scale a + /// decimal in such a way that a loss of precision occurs. + RoundingError(String), + + /// Errors that may occur when parsing an intermediate decimal + /// into a posql decimal + #[error("Intermediate decimal conversion error: {0}")] + IntermediateDecimalConversionError(IntermediateDecimalError), +} + +impl From for ConversionError { + fn from(err: IntermediateDecimalError) -> ConversionError { + DecimalConversionError(IntermediateDecimalConversionError(err)) + } +} #[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Copy)] /// limit-enforced precision @@ -15,10 +56,10 @@ impl Precision { /// Constructor for creating a Precision instance pub fn new(value: u8) -> Result { if value > MAX_SUPPORTED_PRECISION || value == 0 { - Err(ConversionError::PrecisionParseError(format!( + Err(DecimalConversionError(InvalidPrecision(format!( "Failed to parse precision. Value of {} exceeds max supported precision of {}", value, MAX_SUPPORTED_PRECISION - ))) + )))) } else { Ok(Precision(value)) } @@ -73,9 +114,9 @@ impl Decimal { ) -> ConversionResult> { let scale_factor = new_scale - self.scale; if scale_factor < 0 || new_precision.value() < self.precision.value() + scale_factor as u8 { - return Err(ConversionError::DecimalRoundingError( + return Err(DecimalConversionError(RoundingError( "Scale factor must be non-negative".to_string(), - )); + ))); } let scaled_value = scale_scalar(self.value, scale_factor)?; Ok(Decimal::new(scaled_value, new_precision, new_scale)) @@ -86,14 +127,14 @@ impl Decimal { const MINIMAL_PRECISION: u8 = 19; let raw_precision = precision.value(); if raw_precision < MINIMAL_PRECISION { - return Err(ConversionError::DecimalRoundingError( + return Err(DecimalConversionError(RoundingError( "Precision must be at least 19".to_string(), - )); + ))); } if scale < 0 || raw_precision < MINIMAL_PRECISION + scale as u8 { - return Err(ConversionError::DecimalRoundingError( + return Err(DecimalConversionError(RoundingError( "Can not scale down a decimal".to_string(), - )); + ))); } let scaled_value = scale_scalar(S::from(&value), scale)?; Ok(Decimal::new(scaled_value, precision, scale)) @@ -104,14 +145,14 @@ impl Decimal { const MINIMAL_PRECISION: u8 = 39; let raw_precision = precision.value(); if raw_precision < MINIMAL_PRECISION { - return Err(ConversionError::DecimalRoundingError( + return Err(DecimalConversionError(RoundingError( "Precision must be at least 19".to_string(), - )); + ))); } if scale < 0 || raw_precision < MINIMAL_PRECISION + scale as u8 { - return Err(ConversionError::DecimalRoundingError( + return Err(DecimalConversionError(RoundingError( "Can not scale down a decimal".to_string(), - )); + ))); } let scaled_value = scale_scalar(S::from(&value), scale)?; Ok(Decimal::new(scaled_value, precision, scale)) @@ -132,8 +173,9 @@ impl Decimal { /// * `target_scale` - The scale (number of decimal places) to use in the scalar. /// /// ## Errors -/// Returns `ConversionError::PrecisionParseError` if the number of digits in -/// the decimal exceeds the `target_precision` after adjusting for `target_scale`. +/// Returns `InvalidPrecision` error if the number of digits in +/// the decimal exceeds the `target_precision` before or after adjusting for +/// `target_scale`, or if the target precision is zero. pub(crate) fn try_into_to_scalar( d: &IntermediateDecimal, target_precision: Precision, @@ -147,9 +189,9 @@ pub(crate) fn try_into_to_scalar( /// Note that we do not check for overflow. pub(crate) fn scale_scalar(s: S, scale: i8) -> ConversionResult { if scale < 0 { - return Err(ConversionError::DecimalRoundingError( + return Err(DecimalConversionError(RoundingError( "Scale factor must be non-negative".to_string(), - )); + ))); } let ten = S::from(10); let mut res = s; diff --git a/crates/proof-of-sql/src/base/scalar/mont_scalar.rs b/crates/proof-of-sql/src/base/scalar/mont_scalar.rs index 2ef60e34c..7ac833139 100644 --- a/crates/proof-of-sql/src/base/scalar/mont_scalar.rs +++ b/crates/proof-of-sql/src/base/scalar/mont_scalar.rs @@ -1,5 +1,11 @@ use super::{scalar_conversion_to_int, Scalar, ScalarConversionError}; -use crate::{base::math::decimal::MAX_SUPPORTED_PRECISION, sql::parse::ConversionError}; +use crate::{ + base::{ + math::decimal::{DecimalError, MAX_SUPPORTED_PRECISION}, + scalar::mont_scalar::DecimalError::InvalidDecimal, + }, + sql::parse::{ConversionError, ConversionError::DecimalConversionError}, +}; use ark_ff::{BigInteger, Field, Fp, Fp256, MontBackend, MontConfig, PrimeField}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use bytemuck::TransparentWrapper; @@ -163,11 +169,11 @@ impl> TryFrom for MontScalar { // Check if the number of digits exceeds the maximum precision allowed if digits.len() > MAX_SUPPORTED_PRECISION.into() { - return Err(ConversionError::InvalidDecimal(format!( + return Err(DecimalConversionError(InvalidDecimal(format!( "Attempted to parse a number with {} digits, which exceeds the max supported precision of {}", digits.len(), MAX_SUPPORTED_PRECISION - ))); + )))); } // Continue with the previous logic diff --git a/crates/proof-of-sql/src/sql/ast/comparison_util.rs b/crates/proof-of-sql/src/sql/ast/comparison_util.rs index 4b1d28461..1911ac28a 100644 --- a/crates/proof-of-sql/src/sql/ast/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/ast/comparison_util.rs @@ -1,6 +1,16 @@ use crate::{ - base::{database::Column, math::decimal::Precision, scalar::Scalar}, - sql::parse::{type_check_binary_operation, ConversionError, ConversionResult}, + base::{ + database::Column, + math::decimal::{DecimalError, Precision}, + scalar::Scalar, + }, + sql::{ + ast::comparison_util::DecimalError::InvalidPrecision, + parse::{ + type_check_binary_operation, ConversionError, ConversionError::DecimalConversionError, + ConversionResult, + }, + }, }; use bumpalo::Bump; use proof_of_sql_parser::intermediate_ast::BinaryOperator; @@ -67,8 +77,9 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( rhs_precision_value + (max_scale - rhs_scale) as u8, ); // Check if the precision is valid - let _max_precision = Precision::new(max_precision_value) - .map_err(|_| ConversionError::InvalidPrecision(max_precision_value as i16))?; + let _max_precision = Precision::new(max_precision_value).map_err(|_| { + DecimalConversionError(InvalidPrecision(max_precision_value.to_string())) + })?; } unchecked_subtract_impl( alloc, diff --git a/crates/proof-of-sql/src/sql/ast/numerical_util.rs b/crates/proof-of-sql/src/sql/ast/numerical_util.rs index c87186425..42fa8eeec 100644 --- a/crates/proof-of-sql/src/sql/ast/numerical_util.rs +++ b/crates/proof-of-sql/src/sql/ast/numerical_util.rs @@ -1,10 +1,13 @@ use crate::{ base::{ database::{Column, ColumnType}, - math::decimal::{scale_scalar, Precision}, + math::decimal::{scale_scalar, DecimalError, Precision}, scalar::Scalar, }, - sql::parse::{ConversionError, ConversionResult}, + sql::{ + ast::numerical_util::DecimalError::InvalidPrecision, + parse::{ConversionError, ConversionError::DecimalConversionError, ConversionResult}, + }, }; use bumpalo::Bump; @@ -41,9 +44,10 @@ pub(crate) fn try_add_subtract_column_types( .max(right_precision_value - right_scale as i16) + 1_i16; let precision = u8::try_from(precision_value) - .map_err(|_| ConversionError::InvalidPrecision(precision_value)) + .map_err(|_| DecimalConversionError(InvalidPrecision(precision_value.to_string()))) .and_then(|p| { - Precision::new(p).map_err(|_| ConversionError::InvalidPrecision(p as i16)) + Precision::new(p) + .map_err(|_| DecimalConversionError(InvalidPrecision(p.to_string()))) })?; Ok(ColumnType::Decimal75(precision, scale)) } diff --git a/crates/proof-of-sql/src/sql/parse/error.rs b/crates/proof-of-sql/src/sql/parse/error.rs index 255522b3e..8d2892b51 100644 --- a/crates/proof-of-sql/src/sql/parse/error.rs +++ b/crates/proof-of-sql/src/sql/parse/error.rs @@ -1,5 +1,5 @@ -use crate::base::database::ColumnType; -use proof_of_sql_parser::{intermediate_decimal::DecimalError, Identifier, ResourceId}; +use crate::base::{database::ColumnType, math::decimal::DecimalError}; +use proof_of_sql_parser::{Identifier, ResourceId}; use thiserror::Error; /// Errors from converting an intermediate AST into a provable AST. @@ -50,50 +50,17 @@ pub enum ConversionError { /// General error for invalid expressions InvalidExpression(String), - #[error("Unsupported operation: cannot round decimal: {0}")] - /// Decimal rounding is not supported - DecimalRoundingError(String), - - #[error("Error while parsing precision from query: {0}")] - /// Error in parsing precision in a query - PrecisionParseError(String), - - #[error("Decimal precision is not valid: {0}")] - /// Decimal precision is an integer but exceeds the allowed limit. We use i16 here to include all kinds of invalid precision values. - InvalidPrecision(i16), - #[error("Encountered parsing error: {0}")] /// General parsing error ParseError(String), - #[error("Unsupported operation: cannot round literal: {0}")] - /// Error when a rounding operation is not supported - LiteralRoundDownError(String), + #[error(transparent)] + /// Errors related to decimal operations + DecimalConversionError(#[from] DecimalError), #[error("Query not provable because: {0}")] /// Query requires unprovable feature Unprovable(String), - - #[error("Invalid decimal format or value: {0}")] - /// Error when a decimal format or value is incorrect - InvalidDecimal(String), -} - -impl From for ConversionError { - fn from(error: DecimalError) -> Self { - match error { - DecimalError::ParseError(e) => ConversionError::ParseError(e.to_string()), - DecimalError::OutOfRange => ConversionError::ParseError( - "Intermediate decimal cannot be cast to primitive".into(), - ), - DecimalError::LossyCast => ConversionError::ParseError( - "Intermediate decimal has non-zero fractional part".into(), - ), - DecimalError::ConversionFailure => { - ConversionError::ParseError("Could not cast into intermediate decimal.".into()) - } - } - } } impl From for ConversionError { diff --git a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs index d33637533..d18069ecd 100644 --- a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs @@ -3,9 +3,12 @@ use crate::{ base::{ commitment::Commitment, database::{ColumnRef, LiteralValue}, - math::decimal::{try_into_to_scalar, Precision}, + math::decimal::{try_into_to_scalar, DecimalError::InvalidPrecision, Precision}, + }, + sql::{ + ast::{ColumnExpr, ProvableExprPlan}, + parse::ConversionError::DecimalConversionError, }, - sql::ast::{ColumnExpr, ProvableExprPlan}, }; use proof_of_sql_parser::{ intermediate_ast::{BinaryOperator, Expression, Literal, UnaryOperator}, @@ -72,8 +75,9 @@ impl ProvableExprPlanBuilder<'_> { Literal::Int128(i) => Ok(ProvableExprPlan::new_literal(LiteralValue::Int128(*i))), Literal::Decimal(d) => { let scale = d.scale(); - let precision = Precision::new(d.precision()) - .map_err(|_| ConversionError::InvalidPrecision(d.precision() as i16))?; + let precision = Precision::new(d.precision()).map_err(|_| { + DecimalConversionError(InvalidPrecision(d.precision().to_string())) + })?; Ok(ProvableExprPlan::new_literal(LiteralValue::Decimal75( precision, scale,