diff --git a/libraries/math/src/pool_math.sw b/libraries/math/src/pool_math.sw index 14de7ef..7b5f314 100644 --- a/libraries/math/src/pool_math.sw +++ b/libraries/math/src/pool_math.sw @@ -92,9 +92,9 @@ pub fn get_amount_in( xy, reserve_in_adjusted, ) - reserve_in_adjusted; - y * pow_decimals_in / ONE_E_18 + rounding_up_division(y * pow_decimals_in, ONE_E_18) } else { - output_amount * reserve_in / (reserve_out - output_amount) + 1 + rounding_up_division(output_amount * reserve_in, (reserve_out - output_amount)) } } @@ -291,23 +291,15 @@ fn get_y(x_0: u256, xy: u256, y: u256) -> u256 { fn calculate_fee_to_subtract(amount: u64, feeBP: u64) -> u64 { let nominator = amount.as_u256() * feeBP.as_u256(); - let fee = u64::try_from(nominator / BASIS_POINTS_DENOMINATOR).unwrap(); - if nominator % BASIS_POINTS_DENOMINATOR != 0 { - fee + 1 - } else { - fee - } + let fee = rounding_up_division(nominator, BASIS_POINTS_DENOMINATOR); + u64::try_from(fee).unwrap() } fn calculate_fee_to_add(amount: u64, feeBP: u64) -> u64 { let nominator = amount.as_u256() * feeBP.as_u256(); let denominator = BASIS_POINTS_DENOMINATOR - feeBP.as_u256(); - let fee = u64::try_from(nominator / denominator).unwrap(); - if nominator % denominator != 0 { - fee + 1 - } else { - fee - } + let fee = rounding_up_division(nominator, denominator); + u64::try_from(fee).unwrap() } fn subtract_fee(amount: u64, fee: u64) -> u64 { @@ -318,6 +310,15 @@ fn add_fee(amount: u64, fee: u64) -> u64 { amount + calculate_fee_to_add(amount, fee) } +fn rounding_up_division(nominator: u256, denominator: u256) -> u256 { + let rounding_down_division_result = nominator / denominator; + if nominator % denominator == 0 { + rounding_down_division_result + } else { + rounding_down_division_result + 1 + } +} + // Tests #[test] fn test_pow_decimals() { @@ -379,3 +380,23 @@ fn test_calculate_fee_to_add() { i = i + 1; } } + +#[test] +fn test_rounding_up_division() { + assert_eq(rounding_up_division(1000, 1000), 1); + assert_eq(rounding_up_division(1000, 1), 1000); + assert_eq(rounding_up_division(1000, 5), 200); + assert_eq(rounding_up_division(1000, 2000), 1); + assert_eq(rounding_up_division(9, 3), 3); + assert_eq(rounding_up_division(10, 3), 4); + assert_eq(rounding_up_division(11, 3), 4); + assert_eq(rounding_up_division(12, 3), 4); + assert_eq( + rounding_up_division(pow_decimals(72), pow_decimals(12)), + pow_decimals(60), + ); + assert_eq( + rounding_up_division(pow_decimals(72) + 1, pow_decimals(12)), + pow_decimals(60) + 1, + ); +}