diff --git a/Cargo.toml b/Cargo.toml index 11f325628..9f5bc2fee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ bit-iter = { version = "1.1.1" } bigdecimal = { version = "0.4.5", default-features = false, features = ["serde"] } blake3 = { version = "1.3.3", default-features = false } blitzar = { version = "4.0.0" } +bnum = { version = "0.3.0" } bumpalo = { version = "3.11.0" } bytemuck = {version = "1.16.3", features = ["derive"]} byte-slice-cast = { version = "1.2.1", default-features = false } diff --git a/crates/proof-of-sql-parser/src/intermediate_ast.rs b/crates/proof-of-sql-parser/src/intermediate_ast.rs index d89696654..08dcd02c7 100644 --- a/crates/proof-of-sql-parser/src/intermediate_ast.rs +++ b/crates/proof-of-sql-parser/src/intermediate_ast.rs @@ -106,11 +106,11 @@ pub enum BinaryOperator { /// Comparison = Equal, - /// Comparison <= - LessThanOrEqual, + /// Comparison < + LessThan, - /// Comparison >= - GreaterThanOrEqual, + /// Comparison > + GreaterThan, } /// Possible unary operators for simple expressions diff --git a/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs b/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs index 4574de18b..418a9c1c4 100644 --- a/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs +++ b/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs @@ -987,7 +987,7 @@ fn we_can_parse_a_query_with_filter_lt() { query( cols_res(&["a"]), tab(None, "tab"), - not(ge(col("b"), lit(4))), + lt(col("b"), lit(4)), vec![], ), vec![], @@ -1023,7 +1023,7 @@ fn we_can_parse_a_query_with_filter_gt() { query( cols_res(&["a"]), tab(None, "tab"), - not(le(col("b"), lit(4))), + gt(col("b"), lit(4)), vec![], ), vec![], diff --git a/crates/proof-of-sql-parser/src/sql.lalrpop b/crates/proof-of-sql-parser/src/sql.lalrpop index ce6dca007..e838af055 100644 --- a/crates/proof-of-sql-parser/src/sql.lalrpop +++ b/crates/proof-of-sql-parser/src/sql.lalrpop @@ -238,35 +238,35 @@ Expression: Box = { }), #[precedence(level="4")] #[assoc(side="left")] - ">=" => + ">" => Box::new(intermediate_ast::Expression::Binary { - op: intermediate_ast::BinaryOperator::GreaterThanOrEqual, + op: intermediate_ast::BinaryOperator::GreaterThan, left, right, }), - "<=" => + "<" => Box::new(intermediate_ast::Expression::Binary { - op: intermediate_ast::BinaryOperator::LessThanOrEqual, + op: intermediate_ast::BinaryOperator::LessThan, left, right, }), - ">" => + ">=" => Box::new(intermediate_ast::Expression::Unary { op: intermediate_ast::UnaryOperator::Not, expr: Box::new(intermediate_ast::Expression::Binary { - op: intermediate_ast::BinaryOperator::LessThanOrEqual, + op: intermediate_ast::BinaryOperator::LessThan, left, right, }), }), - "<" => + "<=" => Box::new(intermediate_ast::Expression::Unary { op: intermediate_ast::UnaryOperator::Not, expr: Box::new(intermediate_ast::Expression::Binary { - op: intermediate_ast::BinaryOperator::GreaterThanOrEqual, + op: intermediate_ast::BinaryOperator::GreaterThan, left, right, }), diff --git a/crates/proof-of-sql-parser/src/sqlparser.rs b/crates/proof-of-sql-parser/src/sqlparser.rs index 72643b568..efae449f5 100644 --- a/crates/proof-of-sql-parser/src/sqlparser.rs +++ b/crates/proof-of-sql-parser/src/sqlparser.rs @@ -90,8 +90,8 @@ impl From for BinaryOperator { PoSqlBinaryOperator::And => BinaryOperator::And, PoSqlBinaryOperator::Or => BinaryOperator::Or, PoSqlBinaryOperator::Equal => BinaryOperator::Eq, - PoSqlBinaryOperator::LessThanOrEqual => BinaryOperator::LtEq, - PoSqlBinaryOperator::GreaterThanOrEqual => BinaryOperator::GtEq, + PoSqlBinaryOperator::LessThan => BinaryOperator::Lt, + PoSqlBinaryOperator::GreaterThan => BinaryOperator::Gt, PoSqlBinaryOperator::Add => BinaryOperator::Plus, PoSqlBinaryOperator::Subtract => BinaryOperator::Minus, PoSqlBinaryOperator::Multiply => BinaryOperator::Multiply, @@ -291,7 +291,7 @@ mod test { ); check_posql_intermediate_ast_to_sqlparser_equality("select 1 as a, 'Meow' as d, b as b from namespace.table where c = 4 order by a desc limit 10 offset 0;"); check_posql_intermediate_ast_to_sqlparser_equality( - "select true as cons, a and b or c >= 4 as comp from tab where d = 'Space and Time';", + "select true as cons, a and b or c > 4 as comp from tab where d = 'Space and Time';", ); check_posql_intermediate_ast_to_sqlparser_equality( "select cat as cat, true as cons, max(meow) as max_meow from tab where d = 'Space and Time' group by cat;", diff --git a/crates/proof-of-sql-parser/src/utility.rs b/crates/proof-of-sql-parser/src/utility.rs index cf9d7781e..60188cf59 100644 --- a/crates/proof-of-sql-parser/src/utility.rs +++ b/crates/proof-of-sql-parser/src/utility.rs @@ -30,8 +30,18 @@ pub fn equal(left: Box, right: Box) -> Box { /// Construct a new boxed `Expression` A >= B #[must_use] pub fn ge(left: Box, right: Box) -> Box { + not(Box::new(Expression::Binary { + op: BinaryOperator::LessThan, + left, + right, + })) +} + +/// Construct a new boxed `Expression` A > B +#[must_use] +pub fn gt(left: Box, right: Box) -> Box { Box::new(Expression::Binary { - op: BinaryOperator::GreaterThanOrEqual, + op: BinaryOperator::GreaterThan, left, right, }) @@ -40,8 +50,18 @@ pub fn ge(left: Box, right: Box) -> Box { /// Construct a new boxed `Expression` A <= B #[must_use] pub fn le(left: Box, right: Box) -> Box { + not(Box::new(Expression::Binary { + op: BinaryOperator::GreaterThan, + left, + right, + })) +} + +/// Construct a new boxed `Expression` A < B +#[must_use] +pub fn lt(left: Box, right: Box) -> Box { Box::new(Expression::Binary { - op: BinaryOperator::LessThanOrEqual, + op: BinaryOperator::LessThan, left, right, }) diff --git a/crates/proof-of-sql/Cargo.toml b/crates/proof-of-sql/Cargo.toml index bb02f7cac..b994c2998 100644 --- a/crates/proof-of-sql/Cargo.toml +++ b/crates/proof-of-sql/Cargo.toml @@ -27,6 +27,7 @@ bit-iter = { workspace = true } bigdecimal = { workspace = true } blake3 = { workspace = true } blitzar = { workspace = true, optional = true } +bnum = { workspace = true } bumpalo = { workspace = true, features = ["collections"] } bytemuck = { workspace = true } byte-slice-cast = { workspace = true } diff --git a/crates/proof-of-sql/src/base/bit/abs_bit_mask.rs b/crates/proof-of-sql/src/base/bit/abs_bit_mask.rs deleted file mode 100644 index 205e563bc..000000000 --- a/crates/proof-of-sql/src/base/bit/abs_bit_mask.rs +++ /dev/null @@ -1,8 +0,0 @@ -use crate::base::scalar::Scalar; - -pub fn make_abs_bit_mask(x: S) -> [u64; 4] { - let (sign, x) = if S::MAX_SIGNED < x { (1, -x) } else { (0, x) }; - let mut res: [u64; 4] = x.into(); - res[3] |= sign << 63; - res -} diff --git a/crates/proof-of-sql/src/base/bit/bit_distribution.rs b/crates/proof-of-sql/src/base/bit/bit_distribution.rs index 81db9bd60..5939ecd12 100644 --- a/crates/proof-of-sql/src/base/bit/bit_distribution.rs +++ b/crates/proof-of-sql/src/base/bit/bit_distribution.rs @@ -1,70 +1,87 @@ -use crate::base::{bit::make_abs_bit_mask, scalar::Scalar}; +use super::bit_mask_utils::{is_bit_mask_negative_representation, make_bit_mask}; +use crate::base::scalar::{Scalar, ScalarExt}; +use ark_std::iterable::Iterable; use bit_iter::BitIter; +use bnum::types::U256; use core::convert::Into; +use itertools::Itertools; use serde::{Deserialize, Serialize}; /// Describe the distribution of bit values in a table column #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct BitDistribution { - /// We use two arrays to track which bits vary - /// and the constant bit values. If - /// `{x_1, ..., x_n}` represents the values [`BitDistribution`] describes, then: - /// `or_all = abs(x_1) | abs(x_2) | ... | abs(x_n)` - /// `vary_mask & (1 << i) =` - /// `1` if `x_s & (1 << i) != x_t & (1 << i)` for some `s != t` - /// 0 otherwise - pub or_all: [u64; 4], - pub vary_mask: [u64; 4], + /// Identifies all columns that are identical to the leading column (the sign column). The lead bit indicates if the sign column is constant + pub(crate) vary_mask: [u64; 4], + /// Identifies all columns that are the identical to the lead column. The lead bit indicates the sign of the last row of data (only relevant if the sign is constant) + pub(crate) leading_bit_mask: [u64; 4], } impl BitDistribution { pub fn new + Clone>(data: &[T]) -> Self { - if data.is_empty() { - return Self { - or_all: [0; 4], - vary_mask: [0; 4], - }; - } - let mut or_all = make_abs_bit_mask(data[0].clone().into()); - let mut vary_mask = [0; 4]; - for x in data.iter().skip(1) { - let mask = make_abs_bit_mask((*x).clone().into()); - for i in 0..4 { - vary_mask[i] |= or_all[i] ^ mask[i]; - or_all[i] |= mask[i]; - } + let bit_masks = data.iter().cloned().map(Into::::into).map(make_bit_mask); + let (sign_mask, inverse_sign_mask) = + bit_masks + .clone() + .fold((U256::MAX, U256::MAX), |acc, bit_mask| { + let bit_mask = if is_bit_mask_negative_representation(bit_mask) { + bit_mask ^ (U256::MAX >> 1) + } else { + bit_mask + }; + (acc.0 & bit_mask, acc.1 & !bit_mask) + }); + let vary_mask_bit = U256::from( + !bit_masks + .map(is_bit_mask_negative_representation) + .all_equal(), + ) << 255; + let vary_mask: U256 = !(sign_mask | inverse_sign_mask) | vary_mask_bit; + + Self { + leading_bit_mask: sign_mask.into(), + vary_mask: vary_mask.into(), } - Self { or_all, vary_mask } } - pub fn num_varying_bits(&self) -> usize { - let mut res = 0_usize; - for xi in &self.vary_mask { - res += xi.count_ones() as usize; - } - res + pub fn vary_mask(&self) -> U256 { + U256::from(self.vary_mask) + } + + /// Identifies all columns that are the identical to the lead column. + pub fn leading_bit_mask(&self) -> U256 { + U256::from(self.leading_bit_mask) | (U256::ONE << 255) + } + + /// Identifies all columns that are the identical to the inverse of the lead column. + pub fn leading_bit_inverse_mask(&self) -> U256 { + (!self.vary_mask() ^ self.leading_bit_mask()) & (U256::MAX >> 1) } - pub fn has_varying_sign_bit(&self) -> bool { - self.vary_mask[3] & (1 << 63) != 0 + /// # Panics + /// + /// Panics if conversion from `ExpType` to `usize` fails + pub fn num_varying_bits(&self) -> usize { + self.vary_mask().count_ones() as usize } - #[allow(clippy::missing_panics_doc)] - pub fn sign_bit(&self) -> bool { - assert!(!self.has_varying_sign_bit()); - self.or_all[3] & (1 << 63) != 0 + /// # Panics + /// + /// Panics if lead bit varies but `bit_evals` is empty + pub fn leading_bit_eval(&self, bit_evals: &[S], one_eval: S) -> S { + if U256::from(self.vary_mask) & (U256::ONE << 255) != U256::ZERO { + *bit_evals.last().expect("bit_evals should be non-empty") + } else if U256::from(self.leading_bit_mask) & (U256::ONE << 255) == U256::ZERO { + S::ZERO + } else { + one_eval + } } /// Check if this instance represents a valid bit distribution. `is_valid` /// can be used after deserializing a [`BitDistribution`] from an untrusted /// source. pub fn is_valid(&self) -> bool { - for (m, o) in self.vary_mask.iter().zip(self.or_all) { - if m & !o != 0 { - return false; - } - } - true + (self.vary_mask() & self.leading_bit_mask()) & (U256::MAX >> 1) == U256::ZERO } /// In order to avoid cases with large numbers where there can be both a positive and negative @@ -73,92 +90,24 @@ impl BitDistribution { /// Currently this is set to be the minimal value that will include the sum of two signed 128-bit /// integers. The range will likely be expanded in the future as we support additional expressions. pub fn is_within_acceptable_range(&self) -> bool { - // handle the case of everything zero - if self.num_varying_bits() == 0 && self.constant_part() == [0; 4] { - return true; - } - // signed 128 bit numbers range from // -2^127 to 2^127-1 // the maximum absolute value of the sum of two signed 128-integers is // then // 2 * (2^127) = 2^128 - self.most_significant_abs_bit() <= 128 - } - - /// If `{b_i}` represents the non-varying 1-bits of the absolute values, return the value - /// `sum_i b_i 2 ^ i` - pub fn constant_part(&self) -> [u64; 4] { - let mut val = [0; 4]; - self.for_each_abs_constant_bit(|i: usize, bit: usize| { - val[i] |= 1u64 << bit; - }); - val - } - - /// Iterate over each constant 1-bit for the absolute values - pub fn for_each_abs_constant_bit(&self, mut f: F) - where - F: FnMut(usize, usize), - { - for i in 0..4 { - let bitset = if i == 3 { - !(self.vary_mask[i] | (1 << 63)) - } else { - !self.vary_mask[i] - }; - let bitset = bitset & self.or_all[i]; - for pos in BitIter::from(bitset) { - f(i, pos); - } - } - } - - /// Iterate over each varying bit for the absolute values - pub fn for_each_abs_varying_bit(&self, mut f: F) - where - F: FnMut(usize, usize), - { - for i in 0..4 { - let bitset = if i == 3 { - self.vary_mask[i] & !(1 << 63) - } else { - self.vary_mask[i] - }; - for pos in BitIter::from(bitset) { - f(i, pos); - } - } - } - - /// Iterate over each varying bit for the absolute values and the sign bit - /// if it varies - pub fn for_each_varying_bit(&self, mut f: F) - where - F: FnMut(usize, usize), - { - for i in 0..4 { - let bitset = self.vary_mask[i]; - for pos in BitIter::from(bitset) { - f(i, pos); - } - } + (self.leading_bit_inverse_mask() >> 128) == (U256::MAX >> 129) } - /// Return the position of the most significant bit of the absolute values + /// Iterate over each varying bit + /// /// # Panics - /// Panics if no bits are set to 1 in the bit representation of `or_all`. - pub fn most_significant_abs_bit(&self) -> usize { - let mask = self.or_all[3] & !(1 << 63); - if mask != 0 { - return 64 - (mask.leading_zeros() as usize) - 1 + 3 * 64; - } - for i in (0..3).rev() { - let mask = self.or_all[i]; - if mask != 0 { - return 64 - (mask.leading_zeros() as usize) - 1 + 64 * i; - } - } - panic!("no bits are set"); + /// + /// The panic shouldn't be mathematically possible + pub fn vary_mask_iter(&self) -> impl Iterator + '_ { + (0..4).flat_map(|i| { + BitIter::from(self.vary_mask[i]) + .iter() + .map(move |pos| u8::try_from(i * 64 + pos).expect("index greater than 255")) + }) } } diff --git a/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs b/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs index f0fc693c2..b63fb4716 100644 --- a/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs +++ b/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs @@ -1,34 +1,229 @@ use super::*; -use crate::base::scalar::test_scalar::TestScalar; -use num_traits::{One, Zero}; +use crate::base::scalar::{test_scalar::TestScalar, Scalar, ScalarExt}; +use bnum::types::U256; + +// vary_mask function start + +#[test] +fn we_can_get_u256_version_of_vary_mask_for_zero() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0; 4], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_vary_mask = bit_distribution.vary_mask(); + + // ASSERT + assert_eq!(u256_vary_mask, U256::ZERO); +} + +#[test] +fn we_can_get_u256_version_of_vary_mask_for_one() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [1, 0, 0, 0], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_vary_mask = bit_distribution.vary_mask(); + + // ASSERT + assert_eq!(u256_vary_mask, U256::ONE); +} + +#[test] +fn we_can_get_u256_version_of_vary_mask_for_large_number() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [256, 0, 0, 256], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_vary_mask = bit_distribution.vary_mask(); + + // ASSERT + assert_eq!(u256_vary_mask, (U256::ONE << 8) + (U256::ONE << 200)); +} + +// vary_mask function end + +// leading_bit_mask function start + +#[test] +fn we_can_get_u256_version_of_leading_bit_mask_for_zero() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0; 4], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_leading_bit_mask = bit_distribution.leading_bit_mask(); + + // ASSERT + assert_eq!(u256_leading_bit_mask, U256::ONE << 255); +} + +#[test] +fn we_can_get_u256_version_of_leading_bit_mask_for_one() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0; 4], + leading_bit_mask: [1, 0, 0, 0], + }; + + // ACT + let u256_leading_bit_mask = bit_distribution.leading_bit_mask(); + + // ASSERT + assert_eq!(u256_leading_bit_mask, U256::ONE | (U256::ONE << 255)); +} + +#[test] +fn we_can_get_u256_version_of_leading_bit_mask_for_large_number() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0; 4], + leading_bit_mask: [256, 0, 0, 256], + }; + + // ACT + let u256_leading_bit_mask = bit_distribution.leading_bit_mask(); + + // ASSERT + assert_eq!( + u256_leading_bit_mask, + ((U256::ONE << 8) + (U256::ONE << 200)) | (U256::ONE << 255) + ); +} + +// leading_bit_mask function end + +// leading_bit_inverse_mask function start + +#[test] +fn we_can_get_u256_version_of_leading_bit_inverse_mask_for_zero() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [u64::MAX; 4], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_leading_bit_inverse_mask = bit_distribution.leading_bit_inverse_mask(); + + // ASSERT + assert_eq!(u256_leading_bit_inverse_mask, U256::ZERO); +} + +#[test] +fn we_can_get_u256_version_of_leading_bit_inverse_mask_for_one() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [u64::MAX - 1, u64::MAX, u64::MAX, u64::MAX], + leading_bit_mask: [0, 0, 0, 0], + }; + + // ACT + let u256_leading_bit_inverse_mask = bit_distribution.leading_bit_inverse_mask(); + + // ASSERT + assert_eq!(u256_leading_bit_inverse_mask, U256::ONE); +} + +#[test] +fn we_can_get_u256_version_of_leading_bit_inverse_mask_for_large_number() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [u64::MAX - 255, u64::MAX, u64::MAX, u64::MAX - 255], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_leading_bit_inverse_mask = bit_distribution.leading_bit_inverse_mask(); + + // ASSERT + assert_eq!( + u256_leading_bit_inverse_mask, + ((U256::ONE << 8) - U256::ONE) + (((U256::ONE << 8) - U256::ONE) << 192) + ); +} + +// leading_bit_inverse_mask function end + +#[test] +fn we_can_get_leading_bit_eval_while_varying() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0, 0, 0, 1 << 63], + leading_bit_mask: [0; 4], + }; + + // ACT + let bit_eval = bit_distribution.leading_bit_eval(&[TestScalar::ONE], TestScalar::TWO); + + // ASSERT + assert_eq!(bit_eval, TestScalar::ONE); +} + +#[test] +fn we_can_get_leading_bit_eval_while_constant_and_zero() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0, 0, 0, 0], + leading_bit_mask: [0; 4], + }; + + // ACT + let bit_eval = bit_distribution.leading_bit_eval(&[TestScalar::ONE], TestScalar::TWO); + + // ASSERT + assert_eq!(bit_eval, TestScalar::ZERO); +} + +#[test] +fn we_can_get_leading_bit_eval_while_constant_and_non_zero() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0, 0, 0, 0], + leading_bit_mask: [0, 0, 0, 1 << 63], + }; + + // ACT + let bit_eval = bit_distribution.leading_bit_eval(&[TestScalar::ONE], TestScalar::TWO); + + // ASSERT + assert_eq!(bit_eval, TestScalar::TWO); +} + +// leading_bit_eval functions start + +// leading_bit_eval functions end #[test] fn we_can_compute_the_bit_distribution_of_an_empty_slice() { let data: Vec = vec![]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 0); - assert!(!dist.has_varying_sign_bit()); - assert!(!dist.sign_bit()); assert!(dist.is_valid()); - assert_eq!(TestScalar::from(dist.constant_part()), TestScalar::zero()); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); - - let mut cnt = 0; - dist.for_each_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping(U256::MAX) + ); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from(0) + ); + assert_eq!( + TestScalar::from_wrapping(dist.vary_mask()), + TestScalar::from(0) + ); + + assert_eq!(dist.vary_mask_iter().count(), 0); } #[test] @@ -37,34 +232,21 @@ fn we_can_compute_the_bit_distribution_of_a_slice_with_a_single_element() { let data: Vec = vec![val]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 0); - assert!(!dist.has_varying_sign_bit()); - assert!(!dist.sign_bit()); assert!(dist.is_valid()); assert_eq!( - TestScalar::from(dist.constant_part()), - TestScalar::from(val) + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping((U256::ONE << 2) | (U256::ONE << 10) | (U256::ONE << 255)) ); - assert_eq!(dist.most_significant_abs_bit(), 10); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert!(pos == 2 || pos == 10); - cnt += 1; - }); - assert_eq!(cnt, 2); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); - - let mut cnt = 0; - dist.for_each_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); + assert_eq!( + TestScalar::from_wrapping(dist.vary_mask()), + TestScalar::from(0) + ); + assert_eq!( + dist.leading_bit_inverse_mask(), + ((U256::ONE << 2) | (U256::ONE << 10) | (U256::ONE << 255)) ^ U256::MAX + ); + + assert_eq!(dist.vary_mask_iter().count(), 0); } #[test] @@ -72,38 +254,19 @@ fn we_can_compute_the_bit_distribution_of_a_slice_with_one_varying_bits() { let data: Vec = vec![(1 << 2) | (1 << 10), (1 << 2) | (1 << 10) | (1 << 21)]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 1); - assert!(!dist.has_varying_sign_bit()); - assert!(!dist.sign_bit()); assert!(dist.is_valid()); assert_eq!( - TestScalar::from(dist.constant_part()), - TestScalar::from((1 << 10) | (1 << 2)) + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping((U256::ONE << 2) | (U256::ONE << 10) | (U256::ONE << 255)) ); - assert_eq!(dist.most_significant_abs_bit(), 21); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert!(pos == 2 || pos == 10); - cnt += 1; - }); - assert_eq!(cnt, 2); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert_eq!(pos, 21); - cnt += 1; - }); - assert_eq!(cnt, 1); - - let mut cnt = 0; - dist.for_each_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert_eq!(pos, 21); - cnt += 1; - }); - assert_eq!(cnt, 1); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping( + (U256::FOUR | (U256::ONE << 10) | (U256::ONE << 21) | (U256::ONE << 255)) ^ U256::MAX + ) + ); + + assert_eq!(dist.vary_mask_iter().count(), 1); } #[test] @@ -115,38 +278,29 @@ fn we_can_compute_the_bit_distribution_of_a_slice_with_multiple_varying_bits() { ]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 4); - assert!(!dist.has_varying_sign_bit()); - assert!(!dist.sign_bit()); assert!(dist.is_valid()); + assert_eq!( - TestScalar::from(dist.constant_part()), - TestScalar::from(1 << 10) + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping((U256::ONE << 10) | (U256::ONE << 255)) ); - assert_eq!(dist.most_significant_abs_bit(), 50); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert_eq!(pos, 10); - cnt += 1; - }); - assert_eq!(cnt, 1); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert!(pos == 2 || pos == 3 || pos == 21 || pos == 50); - cnt += 1; - }); - assert_eq!(cnt, 4); - - let mut cnt = 0; - dist.for_each_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert!(pos == 2 || pos == 3 || pos == 21 || pos == 50); - cnt += 1; - }); - assert_eq!(cnt, 4); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping( + (U256::FOUR + | U256::EIGHT + | (U256::ONE << 10) + | (U256::ONE << 21) + | (U256::ONE << 50) + | (U256::ONE << 255)) + ^ U256::MAX + ) + ); + + for i in dist.vary_mask_iter() { + assert!(i == 2 || i == 3 || i == 21 || i == 50); + } + assert_eq!(dist.vary_mask_iter().count(), 4); } #[test] @@ -154,63 +308,34 @@ fn we_can_compute_the_bit_distribution_of_negative_values() { let data: Vec = vec![-1]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 0); - assert!(!dist.has_varying_sign_bit()); - assert!(dist.sign_bit()); assert!(dist.is_valid()); - assert_eq!(TestScalar::from(dist.constant_part()), TestScalar::one()); - assert_eq!(dist.most_significant_abs_bit(), 0); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert_eq!(pos, 0); - cnt += 1; - }); - assert_eq!(cnt, 1); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); - - let mut cnt = 0; - dist.for_each_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping(U256::ONE << 255) + ); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping(U256::MAX ^ (U256::ONE << 255)) + ); + + assert_eq!(dist.vary_mask_iter().count(), 0); } #[test] fn we_can_compute_the_bit_distribution_of_values_with_different_signs() { let data: Vec = vec![-1, 1]; let dist = BitDistribution::new::(&data); - assert_eq!(dist.num_varying_bits(), 1); - assert!(dist.has_varying_sign_bit()); - assert_eq!(TestScalar::from(dist.constant_part()), TestScalar::one()); - assert_eq!(dist.most_significant_abs_bit(), 0); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert_eq!(pos, 0); - cnt += 1; - }); - assert_eq!(cnt, 1); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); - - let mut cnt = 0; - dist.for_each_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 3); - assert_eq!(pos, 63); - cnt += 1; - }); - assert_eq!(cnt, 1); + assert_eq!(dist.num_varying_bits(), 2); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping(U256::ONE << 255) + ); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping(U256::MAX ^ (U256::ONE | (U256::ONE << 255))) + ); + + assert_eq!(dist.vary_mask_iter().count(), 2); } #[test] @@ -218,31 +343,17 @@ fn we_can_compute_the_bit_distribution_of_values_with_different_signs_and_values let data: Vec = vec![4, -1, 1]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 3); - assert!(dist.has_varying_sign_bit()); assert!(dist.is_valid()); - assert_eq!(TestScalar::from(dist.constant_part()), TestScalar::zero()); - assert_eq!(dist.most_significant_abs_bit(), 2); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert!(pos == 0 || pos == 2); - cnt += 1; - }); - assert_eq!(cnt, 2); - - let mut cnt = 0; - dist.for_each_varying_bit(|i: usize, pos: usize| { - assert!((i == 0 && (pos == 0 || pos == 2)) || (i == 3 && pos == 63)); - cnt += 1; - }); - assert_eq!(cnt, 3); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping(U256::ONE << 255) + ); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping(U256::MAX ^ (U256::FIVE | (U256::ONE << 255))) + ); + + assert_eq!(dist.vary_mask_iter().count(), 3); } #[test] @@ -252,40 +363,24 @@ fn we_can_compute_the_bit_distribution_of_values_larger_than_64_bit_integers() { let data: Vec = vec![TestScalar::from_bigint(val)]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 0); - assert!(!dist.has_varying_sign_bit()); assert!(dist.is_valid()); assert_eq!( - TestScalar::from(dist.constant_part()), - TestScalar::from_bigint(val) + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping((U256::ONE << 203) | (U256::ONE << 255)) ); - assert_eq!(dist.most_significant_abs_bit(), 64 * 3 + 11); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 3); - assert_eq!(pos, 11); - cnt += 1; - }); - assert_eq!(cnt, 1); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); - - let mut cnt = 0; - dist.for_each_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping(U256::MAX ^ ((U256::ONE << 203) | (U256::ONE << 255))) + ); + + assert_eq!(dist.vary_mask_iter().count(), 0); } #[test] fn we_can_detect_invalid_bit_distributions() { let dist = BitDistribution { - or_all: [0, 0, 0, 0], vary_mask: [1, 0, 0, 0], + leading_bit_mask: [1, 0, 0, 0], }; assert!(!dist.is_valid()); } diff --git a/crates/proof-of-sql/src/base/bit/bit_mask_utils.rs b/crates/proof-of-sql/src/base/bit/bit_mask_utils.rs new file mode 100644 index 000000000..4fa5282c0 --- /dev/null +++ b/crates/proof-of-sql/src/base/bit/bit_mask_utils.rs @@ -0,0 +1,17 @@ +use crate::base::scalar::ScalarExt; +use bnum::types::U256; + +pub fn make_bit_mask(x: S) -> U256 { + let x_as_u256 = x.into_u256_wrapping(); + if x > S::MAX_SIGNED { + x_as_u256 - S::into_u256_wrapping(S::MAX_SIGNED) + (U256::ONE << 255) + - S::into_u256_wrapping(S::MAX_SIGNED) + - U256::ONE + } else { + x_as_u256 + (U256::ONE << 255) + } +} + +pub fn is_bit_mask_negative_representation(bit_mask: U256) -> bool { + bit_mask & (U256::ONE << 255) == U256::ZERO +} diff --git a/crates/proof-of-sql/src/base/bit/bit_mask_utils_test.rs b/crates/proof-of-sql/src/base/bit/bit_mask_utils_test.rs new file mode 100644 index 000000000..42f0f0ec6 --- /dev/null +++ b/crates/proof-of-sql/src/base/bit/bit_mask_utils_test.rs @@ -0,0 +1,56 @@ +use super::bit_mask_utils::make_bit_mask; +use crate::base::{ + bit::bit_mask_utils::is_bit_mask_negative_representation, + scalar::{test_scalar::TestScalar, Scalar}, +}; +use bnum::types::U256; + +#[test] +fn we_can_make_positive_bit_mask() { + // ARRANGE + let positive_scalar = TestScalar::TWO; + + // ACT + let bit_mask = make_bit_mask(positive_scalar); + + // ASSERT + assert_eq!(bit_mask, (U256::ONE << 255) + U256::TWO); +} + +#[test] +fn we_can_make_negative_bit_mask() { + // ARRANGE + let negative_scalar = -TestScalar::TWO; + + // ACT + let bit_mask = make_bit_mask(negative_scalar); + + // ASSERT + assert_eq!(bit_mask, (U256::ONE << 255) - U256::TWO); +} + +#[test] +fn we_can_verify_positive_bit_mask_is_positive_representation() { + // ARRANGE + let positive_scalar = TestScalar::TWO; + let bit_mask = make_bit_mask(positive_scalar); + + // ACT + let is_positive = !is_bit_mask_negative_representation(bit_mask); + + // ASSERT + assert!(is_positive); +} + +#[test] +fn we_can_verify_negative_bit_mask_is_negative_representation() { + // ARRANGE + let negative_scalar = -TestScalar::TWO; + let bit_mask = make_bit_mask(negative_scalar); + + // ACT + let is_negative = is_bit_mask_negative_representation(bit_mask); + + // ASSERT + assert!(is_negative); +} diff --git a/crates/proof-of-sql/src/base/bit/bit_matrix.rs b/crates/proof-of-sql/src/base/bit/bit_matrix.rs index bd53b229b..f8c784be0 100644 --- a/crates/proof-of-sql/src/base/bit/bit_matrix.rs +++ b/crates/proof-of-sql/src/base/bit/bit_matrix.rs @@ -1,8 +1,7 @@ -use crate::base::{ - bit::{make_abs_bit_mask, BitDistribution}, - scalar::Scalar, -}; +use super::bit_mask_utils::make_bit_mask; +use crate::base::{bit::BitDistribution, scalar::Scalar}; use alloc::vec::Vec; +use bnum::types::U256; use bumpalo::Bump; /// Let `x1, ..., xn` denote the values of a data column. Let @@ -17,25 +16,24 @@ pub fn compute_varying_bit_matrix<'a, S: Scalar>( vals: &[S], dist: &BitDistribution, ) -> Vec<&'a [bool]> { - let n = vals.len(); + let number_of_scalars = vals.len(); let num_varying_bits = dist.num_varying_bits(); - let data: &'a mut [bool] = alloc.alloc_slice_fill_default(n * num_varying_bits); + let data: &'a mut [bool] = alloc.alloc_slice_fill_default(number_of_scalars * num_varying_bits); // decompose - for (i, val) in vals.iter().enumerate() { - let mask = make_abs_bit_mask(*val); - let mut offset = i; - dist.for_each_varying_bit(|int_index: usize, bit_index: usize| { - data[offset] = (mask[int_index] & (1u64 << bit_index)) != 0; - offset += n; - }); + for (scalar_index, val) in vals.iter().enumerate() { + let mask = make_bit_mask(*val); + for (vary_index, bit_index) in dist.vary_mask_iter().enumerate() { + data[scalar_index + vary_index * number_of_scalars] = + (mask & (U256::ONE << bit_index)) != U256::ZERO; + } } // make result let mut res = Vec::with_capacity(num_varying_bits); for bit_index in 0..num_varying_bits { - let first = n * bit_index; - let last = n * (bit_index + 1); + let first = number_of_scalars * bit_index; + let last = number_of_scalars * (bit_index + 1); res.push(&data[first..last]); } res diff --git a/crates/proof-of-sql/src/base/bit/bit_matrix_test.rs b/crates/proof-of-sql/src/base/bit/bit_matrix_test.rs index 25d0c0d37..511a92b66 100644 --- a/crates/proof-of-sql/src/base/bit/bit_matrix_test.rs +++ b/crates/proof-of-sql/src/base/bit/bit_matrix_test.rs @@ -38,9 +38,11 @@ fn we_can_compute_the_bit_matrix_for_data_with_a_varying_sign_bit() { let dist = BitDistribution::new::(&data); let alloc = Bump::new(); let matrix = compute_varying_bit_matrix(&alloc, &data, &dist); - assert_eq!(matrix.len(), 1); - let slice1 = vec![false, true]; + assert_eq!(matrix.len(), 2); + let slice1 = vec![true, true]; + let slice2 = vec![true, false]; assert_eq!(matrix[0], slice1); + assert_eq!(matrix[1], slice2); } #[test] @@ -62,11 +64,13 @@ fn we_can_compute_the_bit_matrix_for_data_with_varying_bits_and_constant_bits() let dist = BitDistribution::new::(&data); let alloc = Bump::new(); let matrix = compute_varying_bit_matrix(&alloc, &data, &dist); - assert_eq!(matrix.len(), 2); - let slice1 = vec![true, false]; - let slice2 = vec![false, true]; + assert_eq!(matrix.len(), 3); + let slice1 = vec![true, true]; + let slice2 = vec![true, true]; + let slice3 = vec![true, false]; assert_eq!(matrix[0], slice1); assert_eq!(matrix[1], slice2); + assert_eq!(matrix[2], slice3); } #[test] diff --git a/crates/proof-of-sql/src/base/bit/mod.rs b/crates/proof-of-sql/src/base/bit/mod.rs index 18973dc62..5360b0c43 100644 --- a/crates/proof-of-sql/src/base/bit/mod.rs +++ b/crates/proof-of-sql/src/base/bit/mod.rs @@ -1,5 +1,6 @@ -mod abs_bit_mask; -pub use abs_bit_mask::*; +pub(crate) mod bit_mask_utils; +#[cfg(test)] +mod bit_mask_utils_test; mod bit_distribution; pub use bit_distribution::*; diff --git a/crates/proof-of-sql/src/base/database/column_comparison_operation.rs b/crates/proof-of-sql/src/base/database/column_comparison_operation.rs index 1f5c32833..ac995c6be 100644 --- a/crates/proof-of-sql/src/base/database/column_comparison_operation.rs +++ b/crates/proof-of-sql/src/base/database/column_comparison_operation.rs @@ -280,13 +280,13 @@ impl ComparisonOp for EqualOp { } } -pub struct GreaterThanOrEqualOp {} -impl ComparisonOp for GreaterThanOrEqualOp { +pub struct GreaterThanOp {} +impl ComparisonOp for GreaterThanOp { fn op(l: &T, r: &T) -> bool where T: Debug + Ord, { - l >= r + l > r } fn decimal_op_left_upcast( @@ -299,7 +299,10 @@ impl ComparisonOp for GreaterThanOrEqualOp { S: Scalar, T: Copy + Debug + Ord + Zero + Into, { - ge_decimal_columns(lhs, rhs, left_column_type, right_column_type) + le_decimal_columns(lhs, rhs, left_column_type, right_column_type) + .iter() + .map(|b| !b) + .collect() } fn decimal_op_right_upcast( @@ -312,25 +315,28 @@ impl ComparisonOp for GreaterThanOrEqualOp { S: Scalar, T: Copy + Debug + Ord + Zero + Into, { - le_decimal_columns(rhs, lhs, right_column_type, left_column_type) + ge_decimal_columns(rhs, lhs, right_column_type, left_column_type) + .iter() + .map(|b| !b) + .collect() } fn string_op(_lhs: &[String], _rhs: &[String]) -> ColumnOperationResult> { Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: ">=".to_string(), + operator: ">".to_string(), left_type: ColumnType::VarChar, right_type: ColumnType::VarChar, }) } } -pub struct LessThanOrEqualOp {} -impl ComparisonOp for LessThanOrEqualOp { +pub struct LessThanOp {} +impl ComparisonOp for LessThanOp { fn op(l: &T, r: &T) -> bool where T: Debug + Ord, { - l <= r + l < r } fn decimal_op_left_upcast( @@ -343,7 +349,10 @@ impl ComparisonOp for LessThanOrEqualOp { S: Scalar, T: Copy + Debug + Ord + Zero + Into, { - le_decimal_columns(lhs, rhs, left_column_type, right_column_type) + ge_decimal_columns(lhs, rhs, left_column_type, right_column_type) + .iter() + .map(|b| !b) + .collect() } fn decimal_op_right_upcast( @@ -356,12 +365,15 @@ impl ComparisonOp for LessThanOrEqualOp { S: Scalar, T: Copy + Debug + Ord + Zero + Into, { - ge_decimal_columns(rhs, lhs, right_column_type, left_column_type) + le_decimal_columns(rhs, lhs, right_column_type, left_column_type) + .iter() + .map(|b| !b) + .collect() } fn string_op(_lhs: &[String], _rhs: &[String]) -> ColumnOperationResult> { Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: "<=".to_string(), + operator: "<".to_string(), left_type: ColumnType::VarChar, right_type: ColumnType::VarChar, }) diff --git a/crates/proof-of-sql/src/base/database/expression_evaluation.rs b/crates/proof-of-sql/src/base/database/expression_evaluation.rs index d9df43097..adab9d60b 100644 --- a/crates/proof-of-sql/src/base/database/expression_evaluation.rs +++ b/crates/proof-of-sql/src/base/database/expression_evaluation.rs @@ -90,8 +90,8 @@ impl OwnedTable { BinaryOperator::And => Ok(left.element_wise_and(&right)?), BinaryOperator::Or => Ok(left.element_wise_or(&right)?), BinaryOperator::Eq => Ok(left.element_wise_eq(&right)?), - BinaryOperator::GtEq => Ok(left.element_wise_ge(&right)?), - BinaryOperator::LtEq => Ok(left.element_wise_le(&right)?), + BinaryOperator::Gt => Ok(left.element_wise_gt(&right)?), + BinaryOperator::Lt => Ok(left.element_wise_lt(&right)?), BinaryOperator::Plus => Ok(left.element_wise_add(&right)?), BinaryOperator::Minus => Ok(left.element_wise_sub(&right)?), BinaryOperator::Multiply => Ok(left.element_wise_mul(&right)?), diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index 03f80f23c..a17f4836d 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -21,9 +21,7 @@ mod column_arithmetic_operation; pub(super) use column_arithmetic_operation::{AddOp, ArithmeticOp, DivOp, MulOp, SubOp}; mod column_comparison_operation; -pub(super) use column_comparison_operation::{ - ComparisonOp, EqualOp, GreaterThanOrEqualOp, LessThanOrEqualOp, -}; +pub(super) use column_comparison_operation::{ComparisonOp, EqualOp, GreaterThanOp, LessThanOp}; mod column_index_operation; pub(super) use column_index_operation::apply_column_to_indexes; diff --git a/crates/proof-of-sql/src/base/database/owned_column_operation.rs b/crates/proof-of-sql/src/base/database/owned_column_operation.rs index 48eca7027..369360f38 100644 --- a/crates/proof-of-sql/src/base/database/owned_column_operation.rs +++ b/crates/proof-of-sql/src/base/database/owned_column_operation.rs @@ -1,6 +1,6 @@ use super::{ AddOp, ArithmeticOp, ColumnOperationError, ColumnOperationResult, ComparisonOp, DivOp, EqualOp, - GreaterThanOrEqualOp, LessThanOrEqualOp, MulOp, SubOp, + GreaterThanOp, LessThanOp, MulOp, SubOp, }; use crate::base::{ database::{ @@ -65,13 +65,13 @@ impl OwnedColumn { } /// Element-wise less than or equal to check for two columns - pub fn element_wise_le(&self, rhs: &Self) -> ColumnOperationResult { - LessThanOrEqualOp::owned_column_element_wise_comparison(self, rhs) + pub fn element_wise_lt(&self, rhs: &Self) -> ColumnOperationResult { + LessThanOp::owned_column_element_wise_comparison(self, rhs) } /// Element-wise greater than or equal to check for two columns - pub fn element_wise_ge(&self, rhs: &Self) -> ColumnOperationResult { - GreaterThanOrEqualOp::owned_column_element_wise_comparison(self, rhs) + pub fn element_wise_gt(&self, rhs: &Self) -> ColumnOperationResult { + GreaterThanOp::owned_column_element_wise_comparison(self, rhs) } /// Element-wise addition for two columns @@ -118,13 +118,13 @@ mod test { Err(ColumnOperationError::DifferentColumnLength { .. }) )); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::DifferentColumnLength { .. }) )); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert!(matches!( result, Err(ColumnOperationError::DifferentColumnLength { .. }) @@ -316,31 +316,31 @@ mod test { } #[test] - fn we_can_do_le_operation_on_numeric_and_boolean_columns() { + fn we_can_do_lt_operation_on_numeric_and_boolean_columns() { // Booleans let lhs = OwnedColumn::::Boolean(vec![true, false, true]); let rhs = OwnedColumn::::Boolean(vec![true, true, false]); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![false, true, false])) ); // Integers let lhs = OwnedColumn::::SmallInt(vec![1, 3, 2]); let rhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, false, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); let lhs = OwnedColumn::::Int(vec![1, 3, 2]); let rhs = OwnedColumn::::SmallInt(vec![1, 2, 3]); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, false, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); // Decimals @@ -348,29 +348,29 @@ mod test { let rhs_scalars = [1, 24, -3].iter().map(TestScalar::from).collect(); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 3, lhs_scalars); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![false, true, false])) ); // Decimals and integers let lhs_scalars = [10, -2, -30].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::TinyInt(vec![1, -20, 3]); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), -1, lhs_scalars); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![false, true, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); let lhs_scalars = [10, -2, -30].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Int(vec![1, -20, 3]); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), -1, lhs_scalars); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![false, true, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); } @@ -379,27 +379,27 @@ mod test { // Booleans let lhs = OwnedColumn::::Boolean(vec![true, false, true]); let rhs = OwnedColumn::::Boolean(vec![true, true, false]); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, false, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); // Integers let lhs = OwnedColumn::::SmallInt(vec![1, 3, 2]); let rhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![false, true, false])) ); let lhs = OwnedColumn::::Int(vec![1, 3, 2]); let rhs = OwnedColumn::::SmallInt(vec![1, 2, 3]); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![false, true, false])) ); // Decimals @@ -407,29 +407,29 @@ mod test { let rhs_scalars = [1, 24, -3].iter().map(TestScalar::from).collect(); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 3, lhs_scalars); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, false, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); // Decimals and integers let lhs_scalars = [10, -2, -30].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::TinyInt(vec![1_i8, -20, 3]); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), -1, lhs_scalars); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![true, false, false])) ); let lhs_scalars = [10, -2, -30].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::BigInt(vec![1_i64, -20, 3]); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), -1, lhs_scalars); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![true, false, false])) ); } @@ -443,7 +443,7 @@ mod test { .map(ToString::to_string) .collect(), ); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) @@ -456,19 +456,19 @@ mod test { .map(ToString::to_string) .collect(), ); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) @@ -477,7 +477,7 @@ mod test { // Booleans can't be compared with other types let lhs = OwnedColumn::::Boolean(vec![true, false, true]); let rhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) @@ -485,7 +485,7 @@ mod test { let lhs = OwnedColumn::::Boolean(vec![true, false, true]); let rhs = OwnedColumn::::Int(vec![1, 2, 3]); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) @@ -504,13 +504,13 @@ mod test { .map(ToString::to_string) .collect(), ); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) diff --git a/crates/proof-of-sql/src/base/scalar/scalar_ext.rs b/crates/proof-of-sql/src/base/scalar/scalar_ext.rs index 3739a1cc5..02f7eea38 100644 --- a/crates/proof-of-sql/src/base/scalar/scalar_ext.rs +++ b/crates/proof-of-sql/src/base/scalar/scalar_ext.rs @@ -1,4 +1,5 @@ use super::Scalar; +use bnum::types::U256; use core::cmp::Ordering; /// Extention trait for blanket implementations for `Scalar` types. @@ -17,7 +18,20 @@ pub trait ScalarExt: Scalar { _ => Ordering::Greater, } } + + #[must_use] + /// Converts a U256 to Scalar, wrapping as needed + fn from_wrapping(value: U256) -> Self { + let value_as_limbs: [u64; 4] = value.into(); + Self::from(value_as_limbs) + } + + /// Converts a Scalar to U256. Note that any values above `MAX_SIGNED` shall remain positive, even if they are representative of negative values. + fn into_u256_wrapping(self) -> U256 { + U256::from(Into::<[u64; 4]>::into(self)) + } } + impl ScalarExt for S {} #[cfg(test)] diff --git a/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs b/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs index 439e833a6..a3ed78c38 100644 --- a/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs +++ b/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs @@ -1,5 +1,19 @@ +use super::ScalarExt; use crate::base::scalar::{test_scalar::TestScalar, Scalar}; +use bnum::types::U256; +use core::str::FromStr; use num_traits::Inv; +use rand::{rngs::StdRng, Rng, SeedableRng}; + +const MAX_TEST_SCALAR_SIGNED_VALUE_AS_STRING: &str = + "3618502788666131106986593281521497120428558179689953803000975469142727125494"; + +fn random_u256(seed: u64) -> U256 { + let mut rng = StdRng::seed_from_u64(seed); + let mut bytes = [0u64; 4]; + rng.fill(&mut bytes); + U256::from(bytes) +} #[test] fn we_can_get_test_scalar_constants_from_z_p() { @@ -10,3 +24,71 @@ fn we_can_get_test_scalar_constants_from_z_p() { assert_eq!(-TestScalar::TWO.inv().unwrap(), TestScalar::MAX_SIGNED); assert_eq!(TestScalar::from(10), TestScalar::TEN); } + +#[test] +fn we_can_convert_u256_to_test_scalar_with_wrapping() { + // ARRANGE + let u256_value = U256::TWO; + + // ACT + let test_scalar = TestScalar::from_wrapping(u256_value); + + // ASSERT + assert_eq!(test_scalar, TestScalar::TWO); +} + +#[test] +fn we_can_convert_u256_to_test_scalar_with_wrapping_of_large_value() { + // ARRANGE + let u256_value = + U256::from_str(MAX_TEST_SCALAR_SIGNED_VALUE_AS_STRING).unwrap() * U256::TWO + U256::ONE; + + // ACT + let test_scalar = TestScalar::from_wrapping(u256_value); + + // ASSERT + assert_eq!(test_scalar, TestScalar::ZERO); +} + +#[test] +fn we_can_convert_test_scalar_to_u256_with_wrapping() { + // ARRANGE + let test_scalar = TestScalar::TWO; + + // ACT + let u256_value = test_scalar.into_u256_wrapping(); + + // ASSERT + assert_eq!(u256_value, U256::TWO); +} + +#[test] +fn we_can_convert_test_scalar_to_u256_with_wrapping_of_negative_value() { + // ARRANGE + let test_scalar = -TestScalar::ONE; + + // ACT + let u256: bnum::BUint<4> = test_scalar.into_u256_wrapping(); + + // ASSERT + assert_eq!( + u256, + U256::from_str(MAX_TEST_SCALAR_SIGNED_VALUE_AS_STRING).unwrap() * U256::TWO + ); +} + +#[test] +fn we_can_convert_u256_to_test_scalar_with_wrapping_of_random_u256() { + // ARRANGE + let random_u256 = random_u256(100); + let random_u256_after_wrapping = random_u256 + % (U256::from_str(MAX_TEST_SCALAR_SIGNED_VALUE_AS_STRING).unwrap() * U256::TWO + U256::ONE); + assert_ne!(random_u256, random_u256_after_wrapping); + + // ACT + let test_scalar = TestScalar::from_wrapping(random_u256); + let expected_scalar = TestScalar::from_wrapping(random_u256_after_wrapping); + + // ASSERT + assert_eq!(test_scalar, expected_scalar); +} diff --git a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs index a0e76a031..b712cdd5e 100644 --- a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs @@ -168,12 +168,12 @@ impl DynProofExprBuilder<'_> { let right = self.visit_expr(right); DynProofExpr::try_new_equals(left?, right?) } - BinaryOperator::GtEq => { + BinaryOperator::Gt => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_inequality(left?, right?, false) } - BinaryOperator::LtEq => { + BinaryOperator::Lt => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_inequality(left?, right?, true) diff --git a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs index 708cb8236..e7aa30151 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs @@ -171,8 +171,8 @@ impl<'a> QueryContextBuilder<'a> { BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Eq - | BinaryOperator::GtEq - | BinaryOperator::LtEq => Ok(ColumnType::Boolean), + | BinaryOperator::Gt + | BinaryOperator::Lt => Ok(ColumnType::Boolean), BinaryOperator::Multiply | BinaryOperator::Divide | BinaryOperator::Minus @@ -309,7 +309,7 @@ pub(crate) fn type_check_binary_operation( | (ColumnType::Scalar, _) ) || (left_dtype.is_numeric() && right_dtype.is_numeric()) } - BinaryOperator::GtEq | BinaryOperator::LtEq => { + BinaryOperator::Gt | BinaryOperator::Lt => { if left_dtype == ColumnType::VarChar || right_dtype == ColumnType::VarChar { return false; } diff --git a/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs b/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs index 560362df4..1df7f16c5 100644 --- a/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs @@ -143,14 +143,17 @@ fn we_can_directly_check_whether_bigint_columns_ge_int128() { .build(Some(expr_integer_to_integer)) .unwrap() .unwrap(); - let expected = DynProofExpr::try_new_inequality( - DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( - "sxt.sxt_tab".parse().unwrap(), - "bigint_column".into(), - ColumnType::BigInt, - ))), - DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))), - false, + let expected = DynProofExpr::try_new_not( + DynProofExpr::try_new_inequality( + DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( + "sxt.sxt_tab".parse().unwrap(), + "bigint_column".into(), + ColumnType::BigInt, + ))), + DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))), + true, + ) + .unwrap(), ) .unwrap(); assert_eq!(actual, expected); @@ -165,14 +168,17 @@ fn we_can_directly_check_whether_bigint_columns_le_int128() { .build(Some(expr_integer_to_integer)) .unwrap() .unwrap(); - let expected = DynProofExpr::try_new_inequality( - DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( - "sxt.sxt_tab".parse().unwrap(), - "bigint_column".into(), - ColumnType::BigInt, - ))), - DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))), - true, + let expected = DynProofExpr::try_new_not( + DynProofExpr::try_new_inequality( + DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( + "sxt.sxt_tab".parse().unwrap(), + "bigint_column".into(), + ColumnType::BigInt, + ))), + DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))), + false, + ) + .unwrap(), ) .unwrap(); assert_eq!(actual, expected); diff --git a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs index 878a5915b..856811d82 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs @@ -39,7 +39,7 @@ impl Default for TrivialTestProofPlan { evaluation: 0, produce_length: true, bit_distribution: Some(BitDistribution { - or_all: [0; 4], + leading_bit_mask: [0; 4], vary_mask: [0; 4], }), } @@ -229,7 +229,7 @@ fn verify_fails_if_the_number_of_bit_distributions_is_not_enough() { fn verify_fails_if_a_bit_distribution_is_invalid() { let expr = TrivialTestProofPlan { bit_distribution: Some(BitDistribution { - or_all: [1; 4], + leading_bit_mask: [1; 4], vary_mask: [1; 4], }), ..Default::default() diff --git a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs index 92601f512..a3b28e35e 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs @@ -31,7 +31,7 @@ pub fn scale_and_subtract_literal( let operator = if is_equal { BinaryOperator::Eq } else { - BinaryOperator::LtEq + BinaryOperator::Lt }; if !type_check_binary_operation(lhs_type, rhs_type, &operator) { return Err(ConversionError::DataTypeMismatch { @@ -104,7 +104,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( let operator = if is_equal { BinaryOperator::Eq } else { - BinaryOperator::LtEq + BinaryOperator::Lt }; if !type_check_binary_operation(lhs_type, rhs_type, &operator) { return Err(ConversionError::DataTypeMismatch { diff --git a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs index efe540b52..40dd63627 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs @@ -89,15 +89,15 @@ impl DynProofExpr { pub fn try_new_inequality( lhs: DynProofExpr, rhs: DynProofExpr, - is_lte: bool, + is_lt: bool, ) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::LtEq) { + if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Lt) { Ok(Self::Inequality(InequalityExpr::new( Box::new(lhs), Box::new(rhs), - is_lte, + is_lt, ))) } else { Err(ConversionError::DataTypeMismatch { diff --git a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs index f3f416e10..4a0ac56d9 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs @@ -1,8 +1,4 @@ -use super::{ - prover_evaluate_equals_zero, prover_evaluate_or, result_evaluate_equals_zero, - result_evaluate_or, scale_and_add_subtract_eval, scale_and_subtract, - verifier_evaluate_equals_zero, verifier_evaluate_or, DynProofExpr, ProofExpr, -}; +use super::{scale_and_add_subtract_eval, scale_and_subtract, DynProofExpr, ProofExpr}; use crate::{ base::{ database::{Column, ColumnRef, ColumnType, Table}, @@ -25,21 +21,13 @@ use serde::{Deserialize, Serialize}; pub struct InequalityExpr { lhs: Box, rhs: Box, - is_lte: bool, - #[cfg(test)] - pub(crate) treat_column_of_zeros_as_negative: bool, + is_lt: bool, } impl InequalityExpr { - /// Create a new less than or equal expression - pub fn new(lhs: Box, rhs: Box, is_lte: bool) -> Self { - Self { - lhs, - rhs, - is_lte, - #[cfg(test)] - treat_column_of_zeros_as_negative: false, - } + /// Create a new less than or equal + pub fn new(lhs: Box, rhs: Box, is_lt: bool) -> Self { + Self { lhs, rhs, is_lt } } } @@ -61,7 +49,7 @@ impl ProofExpr for InequalityExpr { let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); let table_length = table.num_rows(); - let diff = if self.is_lte { + let diff = if self.is_lt { scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, false) .expect("Failed to scale and subtract") } else { @@ -69,14 +57,8 @@ impl ProofExpr for InequalityExpr { .expect("Failed to scale and subtract") }; - // diff == 0 - let equals_zero = result_evaluate_equals_zero(table_length, alloc, diff); - - // sign(diff) == -1 - let sign = result_evaluate_sign(table_length, alloc, diff); - - // (diff == 0) || (sign(diff) == -1) - let res = Column::Boolean(result_evaluate_or(table_length, alloc, equals_zero, sign)); + // (sign(diff) == -1) + let res = Column::Boolean(result_evaluate_sign(table_length, alloc, diff)); log::log_memory_usage("End"); @@ -96,7 +78,7 @@ impl ProofExpr for InequalityExpr { let rhs_column = self.rhs.prover_evaluate(builder, alloc, table); let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); - let diff = if self.is_lte { + let diff = if self.is_lt { scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, false) .expect("Failed to scale and subtract") } else { @@ -104,20 +86,8 @@ impl ProofExpr for InequalityExpr { .expect("Failed to scale and subtract") }; - // diff == 0 - let equals_zero = prover_evaluate_equals_zero(table.num_rows(), builder, alloc, diff); - - // sign(diff) == -1 - let sign = prover_evaluate_sign( - builder, - alloc, - diff, - #[cfg(test)] - self.treat_column_of_zeros_as_negative, - ); - - // (diff == 0) || (sign(diff) == -1) - let res = Column::Boolean(prover_evaluate_or(builder, alloc, equals_zero, sign)); + // (sign(diff) == -1) + let res = Column::Boolean(prover_evaluate_sign(builder, alloc, diff)); log::log_memory_usage("End"); @@ -134,20 +104,14 @@ impl ProofExpr for InequalityExpr { let rhs_eval = self.rhs.verifier_evaluate(builder, accessor, one_eval)?; let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); - let diff_eval = if self.is_lte { + let diff_eval = if self.is_lt { scale_and_add_subtract_eval(lhs_eval, rhs_eval, lhs_scale, rhs_scale, true) } else { scale_and_add_subtract_eval(rhs_eval, lhs_eval, rhs_scale, lhs_scale, true) }; - // diff == 0 - let equals_zero = verifier_evaluate_equals_zero(builder, diff_eval, one_eval)?; - // sign(diff) == -1 - let sign = verifier_evaluate_sign(builder, diff_eval, one_eval)?; - - // (diff == 0) || (sign(diff) == -1) - verifier_evaluate_or(builder, &equals_zero, &sign) + verifier_evaluate_sign(builder, diff_eval, one_eval) } fn get_column_references(&self, columns: &mut IndexSet) { diff --git a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs index fa83572d3..2ef62a83d 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs @@ -11,7 +11,7 @@ use crate::{ parse::ConversionError, proof::{exercise_verification, VerifiableQueryResult}, proof_exprs::{test_utility::*, DynProofExpr, ProofExpr}, - proof_plans::{test_utility::*, DynProofPlan}, + proof_plans::test_utility::*, }, }; use bumpalo::Bump; @@ -480,16 +480,11 @@ fn the_sign_can_be_0_or_1_for_a_constant_column_of_zeros() { let data = owned_table([bigint("a", [0_i64, 0, 0]), bigint("b", [1_i64, 2, 3])]); let t = "sxt.t".parse().unwrap(); let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); - let mut ast = filter( + let ast = filter( cols_expr_plan(t, &["b"], &accessor), tab(t), lte(column(t, "a", &accessor), const_bigint(0)), ); - if let DynProofPlan::Filter(filter) = &mut ast { - if let DynProofExpr::Inequality(lte) = &mut filter.where_clause { - lte.treat_column_of_zeros_as_negative = true; - } - } let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); exercise_verification(&verifiable_res, &ast, &accessor, t); let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; diff --git a/crates/proof-of-sql/src/sql/proof_exprs/mod.rs b/crates/proof-of-sql/src/sql/proof_exprs/mod.rs index 298ad945d..a5717ee9b 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/mod.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/mod.rs @@ -39,7 +39,7 @@ use inequality_expr::InequalityExpr; mod inequality_expr_test; mod or_expr; -use or_expr::{prover_evaluate_or, result_evaluate_or, verifier_evaluate_or, OrExpr}; +use or_expr::OrExpr; #[cfg(all(test, feature = "blitzar"))] mod or_expr_test; @@ -58,9 +58,6 @@ pub(crate) use numerical_util::{ mod equals_expr; pub(crate) use equals_expr::EqualsExpr; -use equals_expr::{ - prover_evaluate_equals_zero, result_evaluate_equals_zero, verifier_evaluate_equals_zero, -}; #[cfg(all(test, feature = "blitzar"))] mod equals_expr_test; diff --git a/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs b/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs index baa8b44e5..e6d3c9b42 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs @@ -33,14 +33,14 @@ pub fn equal(left: DynProofExpr, right: DynProofExpr) -> DynProofExpr { /// Panics if: /// - `DynProofExpr::try_new_inequality()` returns an error. pub fn lte(left: DynProofExpr, right: DynProofExpr) -> DynProofExpr { - DynProofExpr::try_new_inequality(left, right, true).unwrap() + not(DynProofExpr::try_new_inequality(left, right, false).unwrap()) } /// # Panics /// Panics if: /// - `DynProofExpr::try_new_inequality()` returns an error. pub fn gte(left: DynProofExpr, right: DynProofExpr) -> DynProofExpr { - DynProofExpr::try_new_inequality(left, right, false).unwrap() + not(DynProofExpr::try_new_inequality(left, right, true).unwrap()) } /// # Panics diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification.rs b/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification.rs deleted file mode 100644 index fe4b772e1..000000000 --- a/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification.rs +++ /dev/null @@ -1,64 +0,0 @@ -use crate::base::{bit::BitDistribution, proof::ProofError, scalar::Scalar}; - -#[allow( - clippy::missing_panics_doc, - reason = "All assertions check for validity within the context, ensuring no panic can occur" -)] -/// Given a bit distribution for a column of data with a constant sign, the evaluation of a column -/// of ones, the constant column's evaluation, and the evaluation of varying absolute bits, verify -/// that the bit distribution is correct. -pub fn verify_constant_sign_decomposition( - dist: &BitDistribution, - eval: S, - one_eval: S, - bit_evals: &[S], -) -> Result<(), ProofError> { - assert!( - dist.is_valid() - && dist.is_within_acceptable_range() - && dist.num_varying_bits() == bit_evals.len() - && !dist.has_varying_sign_bit() - ); - let lhs = if dist.sign_bit() { -eval } else { eval }; - let mut rhs = S::from(dist.constant_part()) * one_eval; - let mut vary_index = 0; - dist.for_each_abs_varying_bit(|int_index: usize, bit_index: usize| { - let mut mult = [0u64; 4]; - mult[int_index] = 1u64 << bit_index; - rhs += S::from(mult) * bit_evals[vary_index]; - vary_index += 1; - }); - if lhs == rhs { - Ok(()) - } else { - Err(ProofError::VerificationError { - error: "constant sign bitwise decomposition is invalid", - }) - } -} - -#[allow( - clippy::missing_panics_doc, - reason = "The assertion checks ensure that conditions are valid, preventing panics" -)] -pub fn verify_constant_abs_decomposition( - dist: &BitDistribution, - eval: S, - one_eval: S, - sign_eval: S, -) -> Result<(), ProofError> { - assert!( - dist.is_valid() - && dist.is_within_acceptable_range() - && dist.num_varying_bits() == 1 - && dist.has_varying_sign_bit() - ); - let t = one_eval - S::TWO * sign_eval; - if S::from(dist.constant_part()) * t == eval { - Ok(()) - } else { - Err(ProofError::VerificationError { - error: "constant absolute bitwise decomposition is invalid", - }) - } -} diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification_test.rs b/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification_test.rs deleted file mode 100644 index 4d0c99ef3..000000000 --- a/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification_test.rs +++ /dev/null @@ -1,122 +0,0 @@ -use super::{verify_constant_abs_decomposition, verify_constant_sign_decomposition}; -use crate::base::{ - bit::BitDistribution, - scalar::Curve25519Scalar, - slice_ops::{inner_product, slice_cast}, -}; -use ark_std::UniformRand; -use core::iter::repeat_with; - -fn rand_eval_vec(len: usize) -> Vec { - let rng = &mut ark_std::test_rng(); - repeat_with(|| Curve25519Scalar::rand(rng)) - .take(len) - .collect() -} - -#[test] -fn we_can_verify_the_decomposition_of_a_constant_column() { - let data: Vec = - vec![Curve25519Scalar::from(1234), Curve25519Scalar::from(1234)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &[]).is_ok()); -} - -#[test] -fn we_can_verify_the_decomposition_of_a_column_with_constant_sign() { - let data: Vec = - vec![Curve25519Scalar::from(123), Curve25519Scalar::from(122)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - let bits = [inner_product(&slice_cast(&[1, 0]), &eval_vec)]; - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &bits).is_ok()); -} - -#[test] -fn we_can_verify_the_decomposition_of_a_constant_column_with_negative_values() { - let data: Vec = - vec![Curve25519Scalar::from(-1234), Curve25519Scalar::from(-1234)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &[]).is_ok()); -} - -#[test] -fn constant_verification_fails_if_the_commitment_doesnt_match() { - let data: Vec = - vec![Curve25519Scalar::from(1234), Curve25519Scalar::from(1234)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data: Vec = - vec![Curve25519Scalar::from(1235), Curve25519Scalar::from(1234)]; - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &[]).is_err()); -} - -#[test] -fn constant_verification_fails_if_the_sign_bit_doesnt_match() { - let data: Vec = - vec![Curve25519Scalar::from(1234), Curve25519Scalar::from(1234)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data: Vec = - vec![Curve25519Scalar::from(-1234), Curve25519Scalar::from(-1234)]; - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &[]).is_err()); -} - -#[test] -fn constant_verification_fails_if_a_varying_bit_doesnt_match() { - let data: Vec = - vec![Curve25519Scalar::from(1234), Curve25519Scalar::from(1234)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data: Vec = - vec![Curve25519Scalar::from(234), Curve25519Scalar::from(1234)]; - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &[]).is_err()); -} - -#[test] -fn we_can_verify_a_decomposition_with_only_a_varying_sign() { - let data: Vec = vec![Curve25519Scalar::from(-1), Curve25519Scalar::from(1)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - let sign_eval = inner_product(&slice_cast(&[1, 0]), &eval_vec); - assert!(verify_constant_abs_decomposition(&dist, data_eval, one_eval, sign_eval).is_ok()); -} - -#[test] -fn constant_abs_verification_fails_if_the_sign_and_data_dont_match() { - let data: Vec = vec![Curve25519Scalar::from(-1), Curve25519Scalar::from(1)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - let sign_eval = inner_product(&slice_cast(&[0, 1]), &eval_vec); - assert!(verify_constant_abs_decomposition(&dist, data_eval, one_eval, sign_eval).is_err()); -} - -#[test] -fn we_can_verify_a_decomposition_with_only_a_varying_sign_and_magnitude_greater_than_one() { - let data: Vec = - vec![Curve25519Scalar::from(-100), Curve25519Scalar::from(100)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - let sign_eval = inner_product(&slice_cast(&[1, 0]), &eval_vec); - assert!(verify_constant_abs_decomposition(&dist, data_eval, one_eval, sign_eval).is_ok()); -} diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs b/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs index 0ec348767..4da2ca170 100644 --- a/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs +++ b/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs @@ -1,8 +1,4 @@ //! This module contains shared proof logic for multiple `ProofExpr` / `ProofPlan` implementations. -mod bitwise_verification; -use bitwise_verification::{verify_constant_abs_decomposition, verify_constant_sign_decomposition}; -#[cfg(test)] -mod bitwise_verification_test; mod sign_expr; pub(crate) use sign_expr::{prover_evaluate_sign, result_evaluate_sign, verifier_evaluate_sign}; pub mod range_check; diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs index cf591de76..554bcea82 100644 --- a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs @@ -1,16 +1,16 @@ -use super::{verify_constant_abs_decomposition, verify_constant_sign_decomposition}; use crate::{ base::{ - bit::{compute_varying_bit_matrix, BitDistribution}, + bit::{ + bit_mask_utils::{is_bit_mask_negative_representation, make_bit_mask}, + compute_varying_bit_matrix, BitDistribution, + }, proof::ProofError, - scalar::Scalar, - }, - sql::proof::{ - FinalRoundBuilder, SumcheckSubpolynomialTerm, SumcheckSubpolynomialType, - VerificationBuilder, + scalar::{Scalar, ScalarExt}, }, + sql::proof::{FinalRoundBuilder, SumcheckSubpolynomialType, VerificationBuilder}, }; use alloc::{boxed::Box, vec, vec::Vec}; +use bnum::types::U256; use bumpalo::Bump; /// Compute the sign bit for a column of scalars. @@ -25,23 +25,13 @@ pub fn result_evaluate_sign<'a, S: Scalar>( expr: &'a [S], ) -> &'a [bool] { assert_eq!(table_length, expr.len()); - // bit_distribution - let dist = BitDistribution::new::(expr); - - // handle the constant case - if dist.num_varying_bits() == 0 { - return alloc.alloc_slice_fill_copy(table_length, dist.sign_bit()); - } - - // prove that the bits are binary - let bits = compute_varying_bit_matrix(alloc, expr, &dist); - if !dist.has_varying_sign_bit() { - return alloc.alloc_slice_fill_copy(table_length, dist.sign_bit()); - } - - let result = bits.last().unwrap(); - assert_eq!(table_length, result.len()); - result + let signs = expr + .iter() + .map(|s| make_bit_mask(*s)) + .map(is_bit_mask_negative_representation) + .collect::>(); + assert_eq!(table_length, signs.len()); + alloc.alloc_slice_copy(&signs) } /// Prove the sign decomposition for a column of scalars. @@ -59,39 +49,25 @@ pub fn prover_evaluate_sign<'a, S: Scalar>( builder: &mut FinalRoundBuilder<'a, S>, alloc: &'a Bump, expr: &'a [S], - #[cfg(test)] treat_column_of_zeros_as_negative: bool, ) -> &'a [bool] { - let table_length = expr.len(); // bit_distribution let dist = BitDistribution::new::(expr); - #[cfg(test)] - let dist = { - let mut dist = dist; - if treat_column_of_zeros_as_negative && dist.vary_mask == [0; 4] { - dist.or_all[3] = 1 << 63; - } - dist - }; builder.produce_bit_distribution(dist.clone()); - // handle the constant case - if dist.num_varying_bits() == 0 { - return alloc.alloc_slice_fill_copy(table_length, dist.sign_bit()); - } - - // prove that the bits are binary - let bits = compute_varying_bit_matrix(alloc, expr, &dist); - prove_bits_are_binary(builder, &bits); - if !dist.has_varying_sign_bit() { - return alloc.alloc_slice_fill_copy(table_length, dist.sign_bit()); - } - - if dist.num_varying_bits() > 1 { - prove_bit_decomposition(builder, alloc, expr, &bits, &dist); + if dist.num_varying_bits() > 0 { + // prove that the bits are binary + let bits = compute_varying_bit_matrix(alloc, expr, &dist); + prove_bits_are_binary(builder, &bits); } // This might panic if `bits.last()` returns `None`. - bits.last().unwrap() + + let signs = expr + .iter() + .map(|s| make_bit_mask(*s)) + .map(is_bit_mask_negative_representation) + .collect::>(); + alloc.alloc_slice_copy(&signs) } /// Verify the sign decomposition for a column of scalars. @@ -120,33 +96,11 @@ pub fn verifier_evaluate_sign( // establish that the bits are binary verify_bits_are_binary(builder, &bit_evals)?; - // handle the special case of the sign bit being constant - if !dist.has_varying_sign_bit() { - return verifier_const_sign_evaluate(&dist, eval, one_eval, &bit_evals); - } - - // handle the special case of the absolute part being constant - if dist.num_varying_bits() == 1 { - verify_constant_abs_decomposition(&dist, eval, one_eval, bit_evals[0])?; - } else { - verify_bit_decomposition(builder, eval, one_eval, &bit_evals, &dist)?; - } - - Ok(*bit_evals.last().unwrap()) -} - -fn verifier_const_sign_evaluate( - dist: &BitDistribution, - eval: S, - one_eval: S, - bit_evals: &[S], -) -> Result { - verify_constant_sign_decomposition(dist, eval, one_eval, bit_evals)?; - if dist.sign_bit() { - Ok(one_eval) - } else { - Ok(S::zero()) - } + verify_bit_decomposition(eval, one_eval, &bit_evals, &dist) + .then(|| one_eval - dist.leading_bit_eval(&bit_evals, one_eval)) + .ok_or(ProofError::VerificationError { + error: "invalid bit_decomposition", + }) } fn prove_bits_are_binary<'a, S: Scalar>( @@ -179,70 +133,91 @@ fn verify_bits_are_binary( Ok(()) } -/// # Panics -/// Panics if `bits.last()` returns `None`. -/// -/// This function generates subpolynomial terms for sumcheck, involving the scalar expression and its bit decomposition. -fn prove_bit_decomposition<'a, S: Scalar>( - builder: &mut FinalRoundBuilder<'a, S>, - alloc: &'a Bump, - expr: &'a [S], - bits: &[&'a [bool]], - dist: &BitDistribution, -) { - let sign_mle = bits.last().unwrap(); - let sign_mle: &[_] = - alloc.alloc_slice_fill_with(sign_mle.len(), |i| 1 - 2 * i32::from(sign_mle[i])); - let mut terms: Vec> = Vec::new(); - - // expr - terms.push((S::one(), vec![Box::new(expr)])); - - // expr bit decomposition - let const_part = S::from(dist.constant_part()); - if !const_part.is_zero() { - terms.push((-const_part, vec![Box::new(sign_mle)])); - } - let mut vary_index = 0; - dist.for_each_abs_varying_bit(|int_index: usize, bit_index: usize| { - let mut mult = [0u64; 4]; - mult[int_index] = 1u64 << bit_index; - terms.push(( - -S::from(mult), - vec![Box::new(sign_mle), Box::new(bits[vary_index])], - )); - vary_index += 1; - }); - builder.produce_sumcheck_subpolynomial(SumcheckSubpolynomialType::Identity, terms); -} - /// # Panics /// Panics if `bit_evals.last()` returns `None`. /// /// This function checks the consistency of the bit evaluations with the expression evaluation. -fn verify_bit_decomposition( - builder: &mut VerificationBuilder, +fn verify_bit_decomposition( expr_eval: S, one_eval: S, bit_evals: &[S], dist: &BitDistribution, -) -> Result<(), ProofError> { - let mut eval = expr_eval; - let sign_eval = bit_evals.last().unwrap(); - let sign_eval = one_eval - S::TWO * *sign_eval; - let mut vary_index = 0; - eval -= sign_eval * S::from(dist.constant_part()); - dist.for_each_abs_varying_bit(|int_index: usize, bit_index: usize| { - let mut mult = [0u64; 4]; - mult[int_index] = 1u64 << bit_index; - let bit_eval = bit_evals[vary_index]; - eval -= S::from(mult) * sign_eval * bit_eval; - vary_index += 1; - }); - builder.try_produce_sumcheck_subpolynomial_evaluation( - SumcheckSubpolynomialType::Identity, - eval, - 2, - )?; - Ok(()) +) -> bool { + let sign_eval = dist.leading_bit_eval(bit_evals, one_eval); + let mut rhs = sign_eval * S::from_wrapping(dist.leading_bit_mask()) + + (one_eval - sign_eval) * S::from_wrapping(dist.leading_bit_inverse_mask()) + - one_eval * S::from_wrapping(U256::ONE << 255); + + for (vary_index, bit_index) in dist.vary_mask_iter().enumerate() { + if bit_index != 255 { + let mult = U256::ONE << bit_index; + let bit_eval = bit_evals[vary_index]; + rhs += S::from_wrapping(mult) * bit_eval; + } + } + rhs == expr_eval +} + +#[cfg(test)] +mod tests { + use crate::{ + base::{ + bit::BitDistribution, + scalar::{test_scalar::TestScalar, Scalar}, + }, + sql::proof_gadgets::sign_expr::verify_bit_decomposition, + }; + + #[test] + fn we_can_verify_bit_decomposition() { + let dist = BitDistribution { + vary_mask: [629, 0, 0, 0], + leading_bit_mask: [2, 0, 0, 9_223_372_036_854_775_808], + }; + let one_eval = TestScalar::ONE; + let bit_evals = [0, 0, 1, 1, 0, 1].map(TestScalar::from); + let expr_eval = TestScalar::from(562); + assert!(verify_bit_decomposition( + expr_eval, one_eval, &bit_evals, &dist, + )); + } + + #[test] + fn we_can_verify_bit_decomposition_constant_sign() { + let dist = BitDistribution { + vary_mask: [629, 0, 0, 0], + leading_bit_mask: [2, 0, 0, 9_223_372_036_854_775_808], + }; + let a = TestScalar::ONE; + let b = TestScalar::ONE; + let expr_eval = TestScalar::from(118) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(562) * a * (TestScalar::ONE - b) + + TestScalar::from(3) * (TestScalar::ONE - a) * b; + let one_eval = TestScalar::from(1) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(1) * a * (TestScalar::ONE - b) + + TestScalar::from(1) * (TestScalar::ONE - a) * b; + let bit_evals = [ + TestScalar::from(0) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(0) * a * (TestScalar::ONE - b) + + TestScalar::from(1) * (TestScalar::ONE - a) * b, + TestScalar::from(1) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(0) * a * (TestScalar::ONE - b) + + TestScalar::from(0) * (TestScalar::ONE - a) * b, + TestScalar::from(1) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(1) * a * (TestScalar::ONE - b) + + TestScalar::from(0) * (TestScalar::ONE - a) * b, + TestScalar::from(1) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(1) * a * (TestScalar::ONE - b) + + TestScalar::from(0) * (TestScalar::ONE - a) * b, + TestScalar::from(1) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(0) * a * (TestScalar::ONE - b) + + TestScalar::from(0) * (TestScalar::ONE - a) * b, + TestScalar::from(0) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(1) * a * (TestScalar::ONE - b) + + TestScalar::from(0) * (TestScalar::ONE - a) * b, + ]; + assert!(verify_bit_decomposition( + expr_eval, one_eval, &bit_evals, &dist, + )); + } } diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr_test.rs b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr_test.rs index 73d7ccc4e..2ee86e44b 100644 --- a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr_test.rs @@ -1,12 +1,15 @@ use super::{prover_evaluate_sign, result_evaluate_sign, verifier_evaluate_sign}; use crate::{ - base::{bit::BitDistribution, polynomial::MultilinearExtension, scalar::Curve25519Scalar}, + base::{ + bit::BitDistribution, + polynomial::MultilinearExtension, + scalar::{Curve25519Scalar, Scalar}, + }, sql::proof::{ FinalRoundBuilder, SumcheckMleEvaluations, SumcheckRandomScalars, VerificationBuilder, }, }; use bumpalo::Bump; -use num_traits::Zero; #[test] fn prover_evaluation_generates_the_bit_distribution_of_a_constant_column() { @@ -15,7 +18,7 @@ fn prover_evaluation_generates_the_bit_distribution_of_a_constant_column() { let alloc = Bump::new(); let data: Vec = data.into_iter().map(Curve25519Scalar::from).collect(); let mut builder = FinalRoundBuilder::new(2, Vec::new()); - let sign = prover_evaluate_sign(&mut builder, &alloc, &data, false); + let sign = prover_evaluate_sign(&mut builder, &alloc, &data); assert_eq!(sign, [false; 3]); assert_eq!(builder.bit_distributions(), [dist]); } @@ -27,7 +30,7 @@ fn prover_evaluation_generates_the_bit_distribution_of_a_negative_constant_colum let alloc = Bump::new(); let data: Vec = data.into_iter().map(Curve25519Scalar::from).collect(); let mut builder = FinalRoundBuilder::new(2, Vec::new()); - let sign = prover_evaluate_sign(&mut builder, &alloc, &data, false); + let sign = prover_evaluate_sign(&mut builder, &alloc, &data); assert_eq!(sign, [true; 3]); assert_eq!(builder.bit_distributions(), [dist]); } @@ -62,7 +65,7 @@ fn we_can_verify_a_constant_decomposition() { ); let data_eval = (&data).evaluate_at_point(&evaluation_point); let eval = verifier_evaluate_sign(&mut builder, data_eval, *one_eval).unwrap(); - assert_eq!(eval, Curve25519Scalar::zero()); + assert_eq!(eval, Curve25519Scalar::ZERO); } #[test]