From 1bc0c4cc309c298a0ca7fd68e0b69de26121abd3 Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Thu, 9 Jan 2025 09:15:02 -0500 Subject: [PATCH 01/12] feat: add scalar and u256 conversion functions --- Cargo.toml | 1 + crates/proof-of-sql/Cargo.toml | 1 + crates/proof-of-sql/src/base/scalar/scalar_ext.rs | 14 ++++++++++++++ 3 files changed, 16 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 11f325628..9f5bc2fee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ bit-iter = { version = "1.1.1" } bigdecimal = { version = "0.4.5", default-features = false, features = ["serde"] } blake3 = { version = "1.3.3", default-features = false } blitzar = { version = "4.0.0" } +bnum = { version = "0.3.0" } bumpalo = { version = "3.11.0" } bytemuck = {version = "1.16.3", features = ["derive"]} byte-slice-cast = { version = "1.2.1", default-features = false } diff --git a/crates/proof-of-sql/Cargo.toml b/crates/proof-of-sql/Cargo.toml index bb02f7cac..b994c2998 100644 --- a/crates/proof-of-sql/Cargo.toml +++ b/crates/proof-of-sql/Cargo.toml @@ -27,6 +27,7 @@ bit-iter = { workspace = true } bigdecimal = { workspace = true } blake3 = { workspace = true } blitzar = { workspace = true, optional = true } +bnum = { workspace = true } bumpalo = { workspace = true, features = ["collections"] } bytemuck = { workspace = true } byte-slice-cast = { workspace = true } diff --git a/crates/proof-of-sql/src/base/scalar/scalar_ext.rs b/crates/proof-of-sql/src/base/scalar/scalar_ext.rs index 3739a1cc5..02f7eea38 100644 --- a/crates/proof-of-sql/src/base/scalar/scalar_ext.rs +++ b/crates/proof-of-sql/src/base/scalar/scalar_ext.rs @@ -1,4 +1,5 @@ use super::Scalar; +use bnum::types::U256; use core::cmp::Ordering; /// Extention trait for blanket implementations for `Scalar` types. @@ -17,7 +18,20 @@ pub trait ScalarExt: Scalar { _ => Ordering::Greater, } } + + #[must_use] + /// Converts a U256 to Scalar, wrapping as needed + fn from_wrapping(value: U256) -> Self { + let value_as_limbs: [u64; 4] = value.into(); + Self::from(value_as_limbs) + } + + /// Converts a Scalar to U256. Note that any values above `MAX_SIGNED` shall remain positive, even if they are representative of negative values. + fn into_u256_wrapping(self) -> U256 { + U256::from(Into::<[u64; 4]>::into(self)) + } } + impl ScalarExt for S {} #[cfg(test)] From e468ba2e5990d6048c94a943efbccdbbaf8b4e3e Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Thu, 9 Jan 2025 09:15:23 -0500 Subject: [PATCH 02/12] test: test scalar and u256 conversion functions with test scalar --- .../src/base/scalar/test_scalar_test.rs | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs b/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs index 439e833a6..46386dcc4 100644 --- a/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs +++ b/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs @@ -1,6 +1,12 @@ +use super::ScalarExt; use crate::base::scalar::{test_scalar::TestScalar, Scalar}; +use bnum::types::U256; +use core::str::FromStr; use num_traits::Inv; +const MAX_TEST_SCALAR_SIGNED_VALUE_AS_STRING: &str = + "3618502788666131106986593281521497120428558179689953803000975469142727125494"; + #[test] fn we_can_get_test_scalar_constants_from_z_p() { assert_eq!(TestScalar::from(0), TestScalar::ZERO); @@ -10,3 +16,55 @@ fn we_can_get_test_scalar_constants_from_z_p() { assert_eq!(-TestScalar::TWO.inv().unwrap(), TestScalar::MAX_SIGNED); assert_eq!(TestScalar::from(10), TestScalar::TEN); } + +#[test] +fn we_can_convert_u256_to_test_scalar_with_wrapping() { + // ARRANGE + let u256_value = U256::TWO; + + // ACT + let test_scalar = TestScalar::from_wrapping(u256_value); + + // ASSERT + assert_eq!(test_scalar, TestScalar::TWO); +} + +#[test] +fn we_can_convert_u256_to_test_scalar_with_wrapping_of_large_value() { + // ARRANGE + let u256_value = + U256::from_str(MAX_TEST_SCALAR_SIGNED_VALUE_AS_STRING).unwrap() * U256::TWO + U256::ONE; + + // ACT + let test_scalar = TestScalar::from_wrapping(u256_value); + + // ASSERT + assert_eq!(test_scalar, TestScalar::ZERO); +} + +#[test] +fn we_can_convert_test_scalar_to_u256_with_wrapping() { + // ARRANGE + let test_scalar = TestScalar::TWO; + + // ACT + let u256_value = test_scalar.into_u256_wrapping(); + + // ASSERT + assert_eq!(u256_value, U256::TWO); +} + +#[test] +fn we_can_convert_test_scalar_to_256_with_wrapping_of_negative_value() { + // ARRANGE + let test_scalar = -TestScalar::ONE; + + // ACT + let u256: bnum::BUint<4> = test_scalar.into_u256_wrapping(); + + // ASSERT + assert_eq!( + u256, + U256::from_str(MAX_TEST_SCALAR_SIGNED_VALUE_AS_STRING).unwrap() * U256::TWO + ); +} From 05a9cfb5c70a58de9e9088c992bb979ef44b5a39 Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Thu, 9 Jan 2025 14:51:30 -0500 Subject: [PATCH 03/12] test: add random u256 test to test wrapping --- .../src/base/scalar/test_scalar_test.rs | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs b/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs index 46386dcc4..a3ed78c38 100644 --- a/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs +++ b/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs @@ -3,10 +3,18 @@ use crate::base::scalar::{test_scalar::TestScalar, Scalar}; use bnum::types::U256; use core::str::FromStr; use num_traits::Inv; +use rand::{rngs::StdRng, Rng, SeedableRng}; const MAX_TEST_SCALAR_SIGNED_VALUE_AS_STRING: &str = "3618502788666131106986593281521497120428558179689953803000975469142727125494"; +fn random_u256(seed: u64) -> U256 { + let mut rng = StdRng::seed_from_u64(seed); + let mut bytes = [0u64; 4]; + rng.fill(&mut bytes); + U256::from(bytes) +} + #[test] fn we_can_get_test_scalar_constants_from_z_p() { assert_eq!(TestScalar::from(0), TestScalar::ZERO); @@ -55,7 +63,7 @@ fn we_can_convert_test_scalar_to_u256_with_wrapping() { } #[test] -fn we_can_convert_test_scalar_to_256_with_wrapping_of_negative_value() { +fn we_can_convert_test_scalar_to_u256_with_wrapping_of_negative_value() { // ARRANGE let test_scalar = -TestScalar::ONE; @@ -68,3 +76,19 @@ fn we_can_convert_test_scalar_to_256_with_wrapping_of_negative_value() { U256::from_str(MAX_TEST_SCALAR_SIGNED_VALUE_AS_STRING).unwrap() * U256::TWO ); } + +#[test] +fn we_can_convert_u256_to_test_scalar_with_wrapping_of_random_u256() { + // ARRANGE + let random_u256 = random_u256(100); + let random_u256_after_wrapping = random_u256 + % (U256::from_str(MAX_TEST_SCALAR_SIGNED_VALUE_AS_STRING).unwrap() * U256::TWO + U256::ONE); + assert_ne!(random_u256, random_u256_after_wrapping); + + // ACT + let test_scalar = TestScalar::from_wrapping(random_u256); + let expected_scalar = TestScalar::from_wrapping(random_u256_after_wrapping); + + // ASSERT + assert_eq!(test_scalar, expected_scalar); +} From 5d7fdfa6fa81e6ac9385c2cb767086650f834bbc Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Wed, 8 Jan 2025 12:42:44 -0500 Subject: [PATCH 04/12] refactor: add vary mask getter and use where appropriate --- .../proof-of-sql/src/base/bit/bit_distribution.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/crates/proof-of-sql/src/base/bit/bit_distribution.rs b/crates/proof-of-sql/src/base/bit/bit_distribution.rs index 81db9bd60..331f51219 100644 --- a/crates/proof-of-sql/src/base/bit/bit_distribution.rs +++ b/crates/proof-of-sql/src/base/bit/bit_distribution.rs @@ -37,12 +37,15 @@ impl BitDistribution { Self { or_all, vary_mask } } + pub fn vary_mask(&self) -> U256 { + U256::from(self.vary_mask) + } + + /// # Panics + /// + /// Panics if conversion from `ExpType` to `usize` fails pub fn num_varying_bits(&self) -> usize { - let mut res = 0_usize; - for xi in &self.vary_mask { - res += xi.count_ones() as usize; - } - res + self.vary_mask().count_ones() as usize } pub fn has_varying_sign_bit(&self) -> bool { From f3b0458888e50d2f8dd4865e6b6be5bca2ba2a64 Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Thu, 9 Jan 2025 09:53:52 -0500 Subject: [PATCH 05/12] feat: add leading bit mask and getters to retrieve leading_bit_mask and leading_bit_inverse_mask --- .../proof-of-sql/src/base/bit/bit_distribution.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/crates/proof-of-sql/src/base/bit/bit_distribution.rs b/crates/proof-of-sql/src/base/bit/bit_distribution.rs index 331f51219..6f2daa097 100644 --- a/crates/proof-of-sql/src/base/bit/bit_distribution.rs +++ b/crates/proof-of-sql/src/base/bit/bit_distribution.rs @@ -1,5 +1,6 @@ use crate::base::{bit::make_abs_bit_mask, scalar::Scalar}; use bit_iter::BitIter; +use bnum::types::U256; use core::convert::Into; use serde::{Deserialize, Serialize}; @@ -15,6 +16,8 @@ pub struct BitDistribution { /// 0 otherwise pub or_all: [u64; 4], pub vary_mask: [u64; 4], + /// Identifies all columns that are the identical to the lead column. The lead bit indicates the sign of the last row of data (only relevant if the sign is constant) + pub(crate) leading_bit_mask: [u64; 4], } impl BitDistribution { @@ -41,6 +44,16 @@ impl BitDistribution { U256::from(self.vary_mask) } + /// Identifies all columns that are the identical to the lead column. + pub fn leading_bit_mask(&self) -> U256 { + U256::from(self.leading_bit_mask) | (U256::ONE << 255) + } + + /// Identifies all columns that are the identical to the inverse of the lead column. + pub fn leading_bit_inverse_mask(&self) -> U256 { + (!self.vary_mask() ^ self.leading_bit_mask()) & (U256::MAX >> 1) + } + /// # Panics /// /// Panics if conversion from `ExpType` to `usize` fails From f794f32e5fa44f2c45afbccd6c10575017648db0 Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Wed, 8 Jan 2025 13:04:28 -0500 Subject: [PATCH 06/12] refactor: redesign construction and meaning of BitDistribution --- .../src/base/bit/bit_distribution.rs | 59 ++++++++++--------- .../src/base/bit/bit_mask_utils.rs | 17 ++++++ crates/proof-of-sql/src/base/bit/mod.rs | 2 + 3 files changed, 49 insertions(+), 29 deletions(-) create mode 100644 crates/proof-of-sql/src/base/bit/bit_mask_utils.rs diff --git a/crates/proof-of-sql/src/base/bit/bit_distribution.rs b/crates/proof-of-sql/src/base/bit/bit_distribution.rs index 6f2daa097..f508c2d24 100644 --- a/crates/proof-of-sql/src/base/bit/bit_distribution.rs +++ b/crates/proof-of-sql/src/base/bit/bit_distribution.rs @@ -1,7 +1,10 @@ -use crate::base::{bit::make_abs_bit_mask, scalar::Scalar}; +use super::bit_mask_utils::{is_bit_mask_negative_representation, make_bit_mask}; +use crate::base::scalar::{Scalar, ScalarExt}; +use ark_std::iterable::Iterable; use bit_iter::BitIter; use bnum::types::U256; use core::convert::Into; +use itertools::Itertools; use serde::{Deserialize, Serialize}; /// Describe the distribution of bit values in a table column @@ -15,29 +18,37 @@ pub struct BitDistribution { /// `1` if `x_s & (1 << i) != x_t & (1 << i)` for some `s != t` /// 0 otherwise pub or_all: [u64; 4], - pub vary_mask: [u64; 4], + /// Identifies all columns that are identical to the leading column (the sign column). The lead bit indicates if the sign column is constant + pub(crate) vary_mask: [u64; 4], /// Identifies all columns that are the identical to the lead column. The lead bit indicates the sign of the last row of data (only relevant if the sign is constant) pub(crate) leading_bit_mask: [u64; 4], } impl BitDistribution { pub fn new + Clone>(data: &[T]) -> Self { - if data.is_empty() { - return Self { - or_all: [0; 4], - vary_mask: [0; 4], - }; - } - let mut or_all = make_abs_bit_mask(data[0].clone().into()); - let mut vary_mask = [0; 4]; - for x in data.iter().skip(1) { - let mask = make_abs_bit_mask((*x).clone().into()); - for i in 0..4 { - vary_mask[i] |= or_all[i] ^ mask[i]; - or_all[i] |= mask[i]; - } + let bit_masks = data.iter().cloned().map(Into::::into).map(make_bit_mask); + let (sign_mask, inverse_sign_mask) = + bit_masks + .clone() + .fold((U256::MAX, U256::MAX), |acc, bit_mask| { + let bit_mask = if is_bit_mask_negative_representation(bit_mask) { + bit_mask ^ (U256::MAX >> 1) + } else { + bit_mask + }; + (acc.0 & bit_mask, acc.1 & !bit_mask) + }); + let vary_mask_bit = U256::from( + !bit_masks + .map(is_bit_mask_negative_representation) + .all_equal(), + ) << 255; + let vary_mask: U256 = !(sign_mask | inverse_sign_mask) | vary_mask_bit; + + Self { + leading_bit_mask: sign_mask.into(), + vary_mask: vary_mask.into(), } - Self { or_all, vary_mask } } pub fn vary_mask(&self) -> U256 { @@ -75,12 +86,7 @@ impl BitDistribution { /// can be used after deserializing a [`BitDistribution`] from an untrusted /// source. pub fn is_valid(&self) -> bool { - for (m, o) in self.vary_mask.iter().zip(self.or_all) { - if m & !o != 0 { - return false; - } - } - true + (self.vary_mask() & self.leading_bit_mask()) & (U256::MAX >> 1) == U256::ZERO } /// In order to avoid cases with large numbers where there can be both a positive and negative @@ -89,17 +95,12 @@ impl BitDistribution { /// Currently this is set to be the minimal value that will include the sum of two signed 128-bit /// integers. The range will likely be expanded in the future as we support additional expressions. pub fn is_within_acceptable_range(&self) -> bool { - // handle the case of everything zero - if self.num_varying_bits() == 0 && self.constant_part() == [0; 4] { - return true; - } - // signed 128 bit numbers range from // -2^127 to 2^127-1 // the maximum absolute value of the sum of two signed 128-integers is // then // 2 * (2^127) = 2^128 - self.most_significant_abs_bit() <= 128 + (self.leading_bit_inverse_mask() >> 128) == (U256::MAX >> 129) } /// If `{b_i}` represents the non-varying 1-bits of the absolute values, return the value diff --git a/crates/proof-of-sql/src/base/bit/bit_mask_utils.rs b/crates/proof-of-sql/src/base/bit/bit_mask_utils.rs new file mode 100644 index 000000000..4fa5282c0 --- /dev/null +++ b/crates/proof-of-sql/src/base/bit/bit_mask_utils.rs @@ -0,0 +1,17 @@ +use crate::base::scalar::ScalarExt; +use bnum::types::U256; + +pub fn make_bit_mask(x: S) -> U256 { + let x_as_u256 = x.into_u256_wrapping(); + if x > S::MAX_SIGNED { + x_as_u256 - S::into_u256_wrapping(S::MAX_SIGNED) + (U256::ONE << 255) + - S::into_u256_wrapping(S::MAX_SIGNED) + - U256::ONE + } else { + x_as_u256 + (U256::ONE << 255) + } +} + +pub fn is_bit_mask_negative_representation(bit_mask: U256) -> bool { + bit_mask & (U256::ONE << 255) == U256::ZERO +} diff --git a/crates/proof-of-sql/src/base/bit/mod.rs b/crates/proof-of-sql/src/base/bit/mod.rs index 18973dc62..5daaf35e0 100644 --- a/crates/proof-of-sql/src/base/bit/mod.rs +++ b/crates/proof-of-sql/src/base/bit/mod.rs @@ -1,3 +1,5 @@ +pub(crate) mod bit_mask_utils; + mod abs_bit_mask; pub use abs_bit_mask::*; From f6facaa4ce9499660e5d50df175624086be953eb Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Thu, 9 Jan 2025 09:38:43 -0500 Subject: [PATCH 07/12] refactor: modify sign_expr and inequality_expr to reflect changes in BitDistribution --- .../src/base/bit/bit_distribution.rs | 26 ++++ .../proof-of-sql/src/base/bit/bit_matrix.rs | 28 ++-- .../src/sql/proof_exprs/inequality_expr.rs | 18 +-- .../src/sql/proof_gadgets/sign_expr.rs | 122 +++++++----------- 4 files changed, 85 insertions(+), 109 deletions(-) diff --git a/crates/proof-of-sql/src/base/bit/bit_distribution.rs b/crates/proof-of-sql/src/base/bit/bit_distribution.rs index f508c2d24..11756c119 100644 --- a/crates/proof-of-sql/src/base/bit/bit_distribution.rs +++ b/crates/proof-of-sql/src/base/bit/bit_distribution.rs @@ -72,6 +72,19 @@ impl BitDistribution { self.vary_mask().count_ones() as usize } + /// # Panics + /// + /// Panics if lead bit varies but `bit_evals` is empty + pub fn leading_bit_eval(&self, bit_evals: &[S], one_eval: S) -> S { + if U256::from(self.vary_mask) & (U256::ONE << 255) != U256::ZERO { + *bit_evals.last().expect("bit_evals should be non-empty") + } else if U256::from(self.leading_bit_mask) & (U256::ONE << 255) == U256::ZERO { + S::ZERO + } else { + one_eval + } + } + pub fn has_varying_sign_bit(&self) -> bool { self.vary_mask[3] & (1 << 63) != 0 } @@ -103,6 +116,19 @@ impl BitDistribution { (self.leading_bit_inverse_mask() >> 128) == (U256::MAX >> 129) } + /// Iterate over each varying bit + /// + /// # Panics + /// + /// The panic shouldn't be mathematically possible + pub fn vary_mask_iter(&self) -> impl Iterator + '_ { + (0..4).flat_map(|i| { + BitIter::from(self.vary_mask[i]) + .iter() + .map(move |pos| u8::try_from(i * 64 + pos).expect("index greater than 255")) + }) + } + /// If `{b_i}` represents the non-varying 1-bits of the absolute values, return the value /// `sum_i b_i 2 ^ i` pub fn constant_part(&self) -> [u64; 4] { diff --git a/crates/proof-of-sql/src/base/bit/bit_matrix.rs b/crates/proof-of-sql/src/base/bit/bit_matrix.rs index bd53b229b..f8c784be0 100644 --- a/crates/proof-of-sql/src/base/bit/bit_matrix.rs +++ b/crates/proof-of-sql/src/base/bit/bit_matrix.rs @@ -1,8 +1,7 @@ -use crate::base::{ - bit::{make_abs_bit_mask, BitDistribution}, - scalar::Scalar, -}; +use super::bit_mask_utils::make_bit_mask; +use crate::base::{bit::BitDistribution, scalar::Scalar}; use alloc::vec::Vec; +use bnum::types::U256; use bumpalo::Bump; /// Let `x1, ..., xn` denote the values of a data column. Let @@ -17,25 +16,24 @@ pub fn compute_varying_bit_matrix<'a, S: Scalar>( vals: &[S], dist: &BitDistribution, ) -> Vec<&'a [bool]> { - let n = vals.len(); + let number_of_scalars = vals.len(); let num_varying_bits = dist.num_varying_bits(); - let data: &'a mut [bool] = alloc.alloc_slice_fill_default(n * num_varying_bits); + let data: &'a mut [bool] = alloc.alloc_slice_fill_default(number_of_scalars * num_varying_bits); // decompose - for (i, val) in vals.iter().enumerate() { - let mask = make_abs_bit_mask(*val); - let mut offset = i; - dist.for_each_varying_bit(|int_index: usize, bit_index: usize| { - data[offset] = (mask[int_index] & (1u64 << bit_index)) != 0; - offset += n; - }); + for (scalar_index, val) in vals.iter().enumerate() { + let mask = make_bit_mask(*val); + for (vary_index, bit_index) in dist.vary_mask_iter().enumerate() { + data[scalar_index + vary_index * number_of_scalars] = + (mask & (U256::ONE << bit_index)) != U256::ZERO; + } } // make result let mut res = Vec::with_capacity(num_varying_bits); for bit_index in 0..num_varying_bits { - let first = n * bit_index; - let last = n * (bit_index + 1); + let first = number_of_scalars * bit_index; + let last = number_of_scalars * (bit_index + 1); res.push(&data[first..last]); } res diff --git a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs index f3f416e10..750bb4e77 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs @@ -26,20 +26,12 @@ pub struct InequalityExpr { lhs: Box, rhs: Box, is_lte: bool, - #[cfg(test)] - pub(crate) treat_column_of_zeros_as_negative: bool, } impl InequalityExpr { /// Create a new less than or equal expression pub fn new(lhs: Box, rhs: Box, is_lte: bool) -> Self { - Self { - lhs, - rhs, - is_lte, - #[cfg(test)] - treat_column_of_zeros_as_negative: false, - } + Self { lhs, rhs, is_lte } } } @@ -108,13 +100,7 @@ impl ProofExpr for InequalityExpr { let equals_zero = prover_evaluate_equals_zero(table.num_rows(), builder, alloc, diff); // sign(diff) == -1 - let sign = prover_evaluate_sign( - builder, - alloc, - diff, - #[cfg(test)] - self.treat_column_of_zeros_as_negative, - ); + let sign = prover_evaluate_sign(builder, alloc, diff); // (diff == 0) || (sign(diff) == -1) let res = Column::Boolean(prover_evaluate_or(builder, alloc, equals_zero, sign)); diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs index cf591de76..1e879db77 100644 --- a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs @@ -1,9 +1,12 @@ use super::{verify_constant_abs_decomposition, verify_constant_sign_decomposition}; use crate::{ base::{ - bit::{compute_varying_bit_matrix, BitDistribution}, + bit::{ + bit_mask_utils::{is_bit_mask_negative_representation, make_bit_mask}, + compute_varying_bit_matrix, BitDistribution, + }, proof::ProofError, - scalar::Scalar, + scalar::{Scalar, ScalarExt}, }, sql::proof::{ FinalRoundBuilder, SumcheckSubpolynomialTerm, SumcheckSubpolynomialType, @@ -11,6 +14,7 @@ use crate::{ }, }; use alloc::{boxed::Box, vec, vec::Vec}; +use bnum::types::U256; use bumpalo::Bump; /// Compute the sign bit for a column of scalars. @@ -25,23 +29,13 @@ pub fn result_evaluate_sign<'a, S: Scalar>( expr: &'a [S], ) -> &'a [bool] { assert_eq!(table_length, expr.len()); - // bit_distribution - let dist = BitDistribution::new::(expr); - - // handle the constant case - if dist.num_varying_bits() == 0 { - return alloc.alloc_slice_fill_copy(table_length, dist.sign_bit()); - } - - // prove that the bits are binary - let bits = compute_varying_bit_matrix(alloc, expr, &dist); - if !dist.has_varying_sign_bit() { - return alloc.alloc_slice_fill_copy(table_length, dist.sign_bit()); - } - - let result = bits.last().unwrap(); - assert_eq!(table_length, result.len()); - result + let signs = expr + .iter() + .map(|s| make_bit_mask(*s)) + .map(is_bit_mask_negative_representation) + .collect::>(); + assert_eq!(table_length, signs.len()); + alloc.alloc_slice_copy(&signs) } /// Prove the sign decomposition for a column of scalars. @@ -59,39 +53,25 @@ pub fn prover_evaluate_sign<'a, S: Scalar>( builder: &mut FinalRoundBuilder<'a, S>, alloc: &'a Bump, expr: &'a [S], - #[cfg(test)] treat_column_of_zeros_as_negative: bool, ) -> &'a [bool] { - let table_length = expr.len(); // bit_distribution let dist = BitDistribution::new::(expr); - #[cfg(test)] - let dist = { - let mut dist = dist; - if treat_column_of_zeros_as_negative && dist.vary_mask == [0; 4] { - dist.or_all[3] = 1 << 63; - } - dist - }; builder.produce_bit_distribution(dist.clone()); - // handle the constant case - if dist.num_varying_bits() == 0 { - return alloc.alloc_slice_fill_copy(table_length, dist.sign_bit()); - } - - // prove that the bits are binary - let bits = compute_varying_bit_matrix(alloc, expr, &dist); - prove_bits_are_binary(builder, &bits); - if !dist.has_varying_sign_bit() { - return alloc.alloc_slice_fill_copy(table_length, dist.sign_bit()); - } - - if dist.num_varying_bits() > 1 { - prove_bit_decomposition(builder, alloc, expr, &bits, &dist); + if dist.num_varying_bits() > 0 { + // prove that the bits are binary + let bits = compute_varying_bit_matrix(alloc, expr, &dist); + prove_bits_are_binary(builder, &bits); } // This might panic if `bits.last()` returns `None`. - bits.last().unwrap() + + let signs = expr + .iter() + .map(|s| make_bit_mask(*s)) + .map(is_bit_mask_negative_representation) + .collect::>(); + alloc.alloc_slice_copy(&signs) } /// Verify the sign decomposition for a column of scalars. @@ -120,19 +100,11 @@ pub fn verifier_evaluate_sign( // establish that the bits are binary verify_bits_are_binary(builder, &bit_evals)?; - // handle the special case of the sign bit being constant - if !dist.has_varying_sign_bit() { - return verifier_const_sign_evaluate(&dist, eval, one_eval, &bit_evals); - } - - // handle the special case of the absolute part being constant - if dist.num_varying_bits() == 1 { - verify_constant_abs_decomposition(&dist, eval, one_eval, bit_evals[0])?; - } else { - verify_bit_decomposition(builder, eval, one_eval, &bit_evals, &dist)?; - } - - Ok(*bit_evals.last().unwrap()) + verify_bit_decomposition(eval, one_eval, &bit_evals, &dist) + .then(|| one_eval - dist.leading_bit_eval(&bit_evals, one_eval)) + .ok_or(ProofError::VerificationError { + error: "invalid bit_decomposition", + }) } fn verifier_const_sign_evaluate( @@ -220,29 +192,23 @@ fn prove_bit_decomposition<'a, S: Scalar>( /// Panics if `bit_evals.last()` returns `None`. /// /// This function checks the consistency of the bit evaluations with the expression evaluation. -fn verify_bit_decomposition( - builder: &mut VerificationBuilder, +fn verify_bit_decomposition( expr_eval: S, one_eval: S, bit_evals: &[S], dist: &BitDistribution, -) -> Result<(), ProofError> { - let mut eval = expr_eval; - let sign_eval = bit_evals.last().unwrap(); - let sign_eval = one_eval - S::TWO * *sign_eval; - let mut vary_index = 0; - eval -= sign_eval * S::from(dist.constant_part()); - dist.for_each_abs_varying_bit(|int_index: usize, bit_index: usize| { - let mut mult = [0u64; 4]; - mult[int_index] = 1u64 << bit_index; - let bit_eval = bit_evals[vary_index]; - eval -= S::from(mult) * sign_eval * bit_eval; - vary_index += 1; - }); - builder.try_produce_sumcheck_subpolynomial_evaluation( - SumcheckSubpolynomialType::Identity, - eval, - 2, - )?; - Ok(()) +) -> bool { + let sign_eval = dist.leading_bit_eval(bit_evals, one_eval); + let mut rhs = sign_eval * S::from_wrapping(dist.leading_bit_mask()) + + (one_eval - sign_eval) * S::from_wrapping(dist.leading_bit_inverse_mask()) + - one_eval * S::from_wrapping(U256::ONE << 255); + + for (vary_index, bit_index) in dist.vary_mask_iter().enumerate() { + if bit_index != 255 { + let mult = U256::ONE << bit_index; + let bit_eval = bit_evals[vary_index]; + rhs += S::from_wrapping(mult) * bit_eval; + } + } + rhs == expr_eval } From 3af99978b07249fe91bd4850dafca0436d614f72 Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Wed, 8 Jan 2025 13:32:57 -0500 Subject: [PATCH 08/12] test: modify tests as appropriate --- .../src/base/bit/bit_distribution_test.rs | 287 ++++++------------ .../src/base/bit/bit_matrix_test.rs | 14 +- .../src/sql/proof/query_proof_test.rs | 4 +- .../sql/proof_exprs/inequality_expr_test.rs | 9 +- .../src/sql/proof_gadgets/sign_expr_test.rs | 13 +- 5 files changed, 111 insertions(+), 216 deletions(-) diff --git a/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs b/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs index f0fc693c2..44900dcc8 100644 --- a/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs +++ b/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs @@ -1,34 +1,27 @@ use super::*; -use crate::base::scalar::test_scalar::TestScalar; -use num_traits::{One, Zero}; +use crate::base::scalar::{test_scalar::TestScalar, ScalarExt}; +use bnum::types::U256; #[test] fn we_can_compute_the_bit_distribution_of_an_empty_slice() { let data: Vec = vec![]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 0); - assert!(!dist.has_varying_sign_bit()); - assert!(!dist.sign_bit()); assert!(dist.is_valid()); - assert_eq!(TestScalar::from(dist.constant_part()), TestScalar::zero()); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping(U256::MAX) + ); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from(0) + ); + assert_eq!( + TestScalar::from_wrapping(dist.vary_mask()), + TestScalar::from(0) + ); - let mut cnt = 0; - dist.for_each_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); + assert_eq!(dist.vary_mask_iter().count(), 0); } #[test] @@ -37,34 +30,21 @@ fn we_can_compute_the_bit_distribution_of_a_slice_with_a_single_element() { let data: Vec = vec![val]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 0); - assert!(!dist.has_varying_sign_bit()); - assert!(!dist.sign_bit()); assert!(dist.is_valid()); assert_eq!( - TestScalar::from(dist.constant_part()), - TestScalar::from(val) + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping((U256::ONE << 2) | (U256::ONE << 10) | (U256::ONE << 255)) + ); + assert_eq!( + TestScalar::from_wrapping(dist.vary_mask()), + TestScalar::from(0) + ); + assert_eq!( + dist.leading_bit_inverse_mask(), + ((U256::ONE << 2) | (U256::ONE << 10) | (U256::ONE << 255)) ^ U256::MAX ); - assert_eq!(dist.most_significant_abs_bit(), 10); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert!(pos == 2 || pos == 10); - cnt += 1; - }); - assert_eq!(cnt, 2); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); - let mut cnt = 0; - dist.for_each_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); + assert_eq!(dist.vary_mask_iter().count(), 0); } #[test] @@ -72,38 +52,19 @@ fn we_can_compute_the_bit_distribution_of_a_slice_with_one_varying_bits() { let data: Vec = vec![(1 << 2) | (1 << 10), (1 << 2) | (1 << 10) | (1 << 21)]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 1); - assert!(!dist.has_varying_sign_bit()); - assert!(!dist.sign_bit()); assert!(dist.is_valid()); assert_eq!( - TestScalar::from(dist.constant_part()), - TestScalar::from((1 << 10) | (1 << 2)) + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping((U256::ONE << 2) | (U256::ONE << 10) | (U256::ONE << 255)) + ); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping( + (U256::FOUR | (U256::ONE << 10) | (U256::ONE << 21) | (U256::ONE << 255)) ^ U256::MAX + ) ); - assert_eq!(dist.most_significant_abs_bit(), 21); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert!(pos == 2 || pos == 10); - cnt += 1; - }); - assert_eq!(cnt, 2); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert_eq!(pos, 21); - cnt += 1; - }); - assert_eq!(cnt, 1); - let mut cnt = 0; - dist.for_each_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert_eq!(pos, 21); - cnt += 1; - }); - assert_eq!(cnt, 1); + assert_eq!(dist.vary_mask_iter().count(), 1); } #[test] @@ -115,38 +76,29 @@ fn we_can_compute_the_bit_distribution_of_a_slice_with_multiple_varying_bits() { ]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 4); - assert!(!dist.has_varying_sign_bit()); - assert!(!dist.sign_bit()); assert!(dist.is_valid()); + assert_eq!( - TestScalar::from(dist.constant_part()), - TestScalar::from(1 << 10) + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping((U256::ONE << 10) | (U256::ONE << 255)) + ); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping( + (U256::FOUR + | U256::EIGHT + | (U256::ONE << 10) + | (U256::ONE << 21) + | (U256::ONE << 50) + | (U256::ONE << 255)) + ^ U256::MAX + ) ); - assert_eq!(dist.most_significant_abs_bit(), 50); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert_eq!(pos, 10); - cnt += 1; - }); - assert_eq!(cnt, 1); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert!(pos == 2 || pos == 3 || pos == 21 || pos == 50); - cnt += 1; - }); - assert_eq!(cnt, 4); - let mut cnt = 0; - dist.for_each_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert!(pos == 2 || pos == 3 || pos == 21 || pos == 50); - cnt += 1; - }); - assert_eq!(cnt, 4); + for i in dist.vary_mask_iter() { + assert!(i == 2 || i == 3 || i == 21 || i == 50); + } + assert_eq!(dist.vary_mask_iter().count(), 4); } #[test] @@ -154,63 +106,34 @@ fn we_can_compute_the_bit_distribution_of_negative_values() { let data: Vec = vec![-1]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 0); - assert!(!dist.has_varying_sign_bit()); - assert!(dist.sign_bit()); assert!(dist.is_valid()); - assert_eq!(TestScalar::from(dist.constant_part()), TestScalar::one()); - assert_eq!(dist.most_significant_abs_bit(), 0); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert_eq!(pos, 0); - cnt += 1; - }); - assert_eq!(cnt, 1); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping(U256::ONE << 255) + ); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping(U256::MAX ^ (U256::ONE << 255)) + ); - let mut cnt = 0; - dist.for_each_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); + assert_eq!(dist.vary_mask_iter().count(), 0); } #[test] fn we_can_compute_the_bit_distribution_of_values_with_different_signs() { let data: Vec = vec![-1, 1]; let dist = BitDistribution::new::(&data); - assert_eq!(dist.num_varying_bits(), 1); - assert!(dist.has_varying_sign_bit()); - assert_eq!(TestScalar::from(dist.constant_part()), TestScalar::one()); - assert_eq!(dist.most_significant_abs_bit(), 0); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert_eq!(pos, 0); - cnt += 1; - }); - assert_eq!(cnt, 1); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); + assert_eq!(dist.num_varying_bits(), 2); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping(U256::ONE << 255) + ); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping(U256::MAX ^ (U256::ONE | (U256::ONE << 255))) + ); - let mut cnt = 0; - dist.for_each_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 3); - assert_eq!(pos, 63); - cnt += 1; - }); - assert_eq!(cnt, 1); + assert_eq!(dist.vary_mask_iter().count(), 2); } #[test] @@ -218,31 +141,17 @@ fn we_can_compute_the_bit_distribution_of_values_with_different_signs_and_values let data: Vec = vec![4, -1, 1]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 3); - assert!(dist.has_varying_sign_bit()); assert!(dist.is_valid()); - assert_eq!(TestScalar::from(dist.constant_part()), TestScalar::zero()); - assert_eq!(dist.most_significant_abs_bit(), 2); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|i: usize, pos: usize| { - assert_eq!(i, 0); - assert!(pos == 0 || pos == 2); - cnt += 1; - }); - assert_eq!(cnt, 2); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping(U256::ONE << 255) + ); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping(U256::MAX ^ (U256::FIVE | (U256::ONE << 255))) + ); - let mut cnt = 0; - dist.for_each_varying_bit(|i: usize, pos: usize| { - assert!((i == 0 && (pos == 0 || pos == 2)) || (i == 3 && pos == 63)); - cnt += 1; - }); - assert_eq!(cnt, 3); + assert_eq!(dist.vary_mask_iter().count(), 3); } #[test] @@ -252,40 +161,24 @@ fn we_can_compute_the_bit_distribution_of_values_larger_than_64_bit_integers() { let data: Vec = vec![TestScalar::from_bigint(val)]; let dist = BitDistribution::new::(&data); assert_eq!(dist.num_varying_bits(), 0); - assert!(!dist.has_varying_sign_bit()); assert!(dist.is_valid()); assert_eq!( - TestScalar::from(dist.constant_part()), - TestScalar::from_bigint(val) + TestScalar::from_wrapping(dist.leading_bit_mask()), + TestScalar::from_wrapping((U256::ONE << 203) | (U256::ONE << 255)) + ); + assert_eq!( + TestScalar::from_wrapping(dist.leading_bit_inverse_mask()), + TestScalar::from_wrapping(U256::MAX ^ ((U256::ONE << 203) | (U256::ONE << 255))) ); - assert_eq!(dist.most_significant_abs_bit(), 64 * 3 + 11); - - let mut cnt = 0; - dist.for_each_abs_constant_bit(|i: usize, pos: usize| { - assert_eq!(i, 3); - assert_eq!(pos, 11); - cnt += 1; - }); - assert_eq!(cnt, 1); - - let mut cnt = 0; - dist.for_each_abs_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); - let mut cnt = 0; - dist.for_each_varying_bit(|_i: usize, _pos: usize| { - cnt += 1; - }); - assert_eq!(cnt, 0); + assert_eq!(dist.vary_mask_iter().count(), 0); } #[test] fn we_can_detect_invalid_bit_distributions() { let dist = BitDistribution { - or_all: [0, 0, 0, 0], vary_mask: [1, 0, 0, 0], + leading_bit_mask: [1, 0, 0, 0], }; assert!(!dist.is_valid()); } diff --git a/crates/proof-of-sql/src/base/bit/bit_matrix_test.rs b/crates/proof-of-sql/src/base/bit/bit_matrix_test.rs index 25d0c0d37..511a92b66 100644 --- a/crates/proof-of-sql/src/base/bit/bit_matrix_test.rs +++ b/crates/proof-of-sql/src/base/bit/bit_matrix_test.rs @@ -38,9 +38,11 @@ fn we_can_compute_the_bit_matrix_for_data_with_a_varying_sign_bit() { let dist = BitDistribution::new::(&data); let alloc = Bump::new(); let matrix = compute_varying_bit_matrix(&alloc, &data, &dist); - assert_eq!(matrix.len(), 1); - let slice1 = vec![false, true]; + assert_eq!(matrix.len(), 2); + let slice1 = vec![true, true]; + let slice2 = vec![true, false]; assert_eq!(matrix[0], slice1); + assert_eq!(matrix[1], slice2); } #[test] @@ -62,11 +64,13 @@ fn we_can_compute_the_bit_matrix_for_data_with_varying_bits_and_constant_bits() let dist = BitDistribution::new::(&data); let alloc = Bump::new(); let matrix = compute_varying_bit_matrix(&alloc, &data, &dist); - assert_eq!(matrix.len(), 2); - let slice1 = vec![true, false]; - let slice2 = vec![false, true]; + assert_eq!(matrix.len(), 3); + let slice1 = vec![true, true]; + let slice2 = vec![true, true]; + let slice3 = vec![true, false]; assert_eq!(matrix[0], slice1); assert_eq!(matrix[1], slice2); + assert_eq!(matrix[2], slice3); } #[test] diff --git a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs index 878a5915b..856811d82 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs @@ -39,7 +39,7 @@ impl Default for TrivialTestProofPlan { evaluation: 0, produce_length: true, bit_distribution: Some(BitDistribution { - or_all: [0; 4], + leading_bit_mask: [0; 4], vary_mask: [0; 4], }), } @@ -229,7 +229,7 @@ fn verify_fails_if_the_number_of_bit_distributions_is_not_enough() { fn verify_fails_if_a_bit_distribution_is_invalid() { let expr = TrivialTestProofPlan { bit_distribution: Some(BitDistribution { - or_all: [1; 4], + leading_bit_mask: [1; 4], vary_mask: [1; 4], }), ..Default::default() diff --git a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs index fa83572d3..2ef62a83d 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs @@ -11,7 +11,7 @@ use crate::{ parse::ConversionError, proof::{exercise_verification, VerifiableQueryResult}, proof_exprs::{test_utility::*, DynProofExpr, ProofExpr}, - proof_plans::{test_utility::*, DynProofPlan}, + proof_plans::test_utility::*, }, }; use bumpalo::Bump; @@ -480,16 +480,11 @@ fn the_sign_can_be_0_or_1_for_a_constant_column_of_zeros() { let data = owned_table([bigint("a", [0_i64, 0, 0]), bigint("b", [1_i64, 2, 3])]); let t = "sxt.t".parse().unwrap(); let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); - let mut ast = filter( + let ast = filter( cols_expr_plan(t, &["b"], &accessor), tab(t), lte(column(t, "a", &accessor), const_bigint(0)), ); - if let DynProofPlan::Filter(filter) = &mut ast { - if let DynProofExpr::Inequality(lte) = &mut filter.where_clause { - lte.treat_column_of_zeros_as_negative = true; - } - } let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); exercise_verification(&verifiable_res, &ast, &accessor, t); let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr_test.rs b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr_test.rs index 73d7ccc4e..2ee86e44b 100644 --- a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr_test.rs @@ -1,12 +1,15 @@ use super::{prover_evaluate_sign, result_evaluate_sign, verifier_evaluate_sign}; use crate::{ - base::{bit::BitDistribution, polynomial::MultilinearExtension, scalar::Curve25519Scalar}, + base::{ + bit::BitDistribution, + polynomial::MultilinearExtension, + scalar::{Curve25519Scalar, Scalar}, + }, sql::proof::{ FinalRoundBuilder, SumcheckMleEvaluations, SumcheckRandomScalars, VerificationBuilder, }, }; use bumpalo::Bump; -use num_traits::Zero; #[test] fn prover_evaluation_generates_the_bit_distribution_of_a_constant_column() { @@ -15,7 +18,7 @@ fn prover_evaluation_generates_the_bit_distribution_of_a_constant_column() { let alloc = Bump::new(); let data: Vec = data.into_iter().map(Curve25519Scalar::from).collect(); let mut builder = FinalRoundBuilder::new(2, Vec::new()); - let sign = prover_evaluate_sign(&mut builder, &alloc, &data, false); + let sign = prover_evaluate_sign(&mut builder, &alloc, &data); assert_eq!(sign, [false; 3]); assert_eq!(builder.bit_distributions(), [dist]); } @@ -27,7 +30,7 @@ fn prover_evaluation_generates_the_bit_distribution_of_a_negative_constant_colum let alloc = Bump::new(); let data: Vec = data.into_iter().map(Curve25519Scalar::from).collect(); let mut builder = FinalRoundBuilder::new(2, Vec::new()); - let sign = prover_evaluate_sign(&mut builder, &alloc, &data, false); + let sign = prover_evaluate_sign(&mut builder, &alloc, &data); assert_eq!(sign, [true; 3]); assert_eq!(builder.bit_distributions(), [dist]); } @@ -62,7 +65,7 @@ fn we_can_verify_a_constant_decomposition() { ); let data_eval = (&data).evaluate_at_point(&evaluation_point); let eval = verifier_evaluate_sign(&mut builder, data_eval, *one_eval).unwrap(); - assert_eq!(eval, Curve25519Scalar::zero()); + assert_eq!(eval, Curve25519Scalar::ZERO); } #[test] From 0daa2ed849026c9edcea3bcb2ee19ee8c22fe122 Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Thu, 9 Jan 2025 10:01:19 -0500 Subject: [PATCH 09/12] refactor!: drop dead code --- .../proof-of-sql/src/base/bit/abs_bit_mask.rs | 8 -- .../src/base/bit/bit_distribution.rs | 94 -------------- crates/proof-of-sql/src/base/bit/mod.rs | 3 - .../sql/proof_gadgets/bitwise_verification.rs | 64 --------- .../bitwise_verification_test.rs | 122 ------------------ .../proof-of-sql/src/sql/proof_gadgets/mod.rs | 4 - .../src/sql/proof_gadgets/sign_expr.rs | 56 +------- 7 files changed, 1 insertion(+), 350 deletions(-) delete mode 100644 crates/proof-of-sql/src/base/bit/abs_bit_mask.rs delete mode 100644 crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification.rs delete mode 100644 crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification_test.rs diff --git a/crates/proof-of-sql/src/base/bit/abs_bit_mask.rs b/crates/proof-of-sql/src/base/bit/abs_bit_mask.rs deleted file mode 100644 index 205e563bc..000000000 --- a/crates/proof-of-sql/src/base/bit/abs_bit_mask.rs +++ /dev/null @@ -1,8 +0,0 @@ -use crate::base::scalar::Scalar; - -pub fn make_abs_bit_mask(x: S) -> [u64; 4] { - let (sign, x) = if S::MAX_SIGNED < x { (1, -x) } else { (0, x) }; - let mut res: [u64; 4] = x.into(); - res[3] |= sign << 63; - res -} diff --git a/crates/proof-of-sql/src/base/bit/bit_distribution.rs b/crates/proof-of-sql/src/base/bit/bit_distribution.rs index 11756c119..5939ecd12 100644 --- a/crates/proof-of-sql/src/base/bit/bit_distribution.rs +++ b/crates/proof-of-sql/src/base/bit/bit_distribution.rs @@ -10,14 +10,6 @@ use serde::{Deserialize, Serialize}; /// Describe the distribution of bit values in a table column #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct BitDistribution { - /// We use two arrays to track which bits vary - /// and the constant bit values. If - /// `{x_1, ..., x_n}` represents the values [`BitDistribution`] describes, then: - /// `or_all = abs(x_1) | abs(x_2) | ... | abs(x_n)` - /// `vary_mask & (1 << i) =` - /// `1` if `x_s & (1 << i) != x_t & (1 << i)` for some `s != t` - /// 0 otherwise - pub or_all: [u64; 4], /// Identifies all columns that are identical to the leading column (the sign column). The lead bit indicates if the sign column is constant pub(crate) vary_mask: [u64; 4], /// Identifies all columns that are the identical to the lead column. The lead bit indicates the sign of the last row of data (only relevant if the sign is constant) @@ -85,16 +77,6 @@ impl BitDistribution { } } - pub fn has_varying_sign_bit(&self) -> bool { - self.vary_mask[3] & (1 << 63) != 0 - } - - #[allow(clippy::missing_panics_doc)] - pub fn sign_bit(&self) -> bool { - assert!(!self.has_varying_sign_bit()); - self.or_all[3] & (1 << 63) != 0 - } - /// Check if this instance represents a valid bit distribution. `is_valid` /// can be used after deserializing a [`BitDistribution`] from an untrusted /// source. @@ -128,80 +110,4 @@ impl BitDistribution { .map(move |pos| u8::try_from(i * 64 + pos).expect("index greater than 255")) }) } - - /// If `{b_i}` represents the non-varying 1-bits of the absolute values, return the value - /// `sum_i b_i 2 ^ i` - pub fn constant_part(&self) -> [u64; 4] { - let mut val = [0; 4]; - self.for_each_abs_constant_bit(|i: usize, bit: usize| { - val[i] |= 1u64 << bit; - }); - val - } - - /// Iterate over each constant 1-bit for the absolute values - pub fn for_each_abs_constant_bit(&self, mut f: F) - where - F: FnMut(usize, usize), - { - for i in 0..4 { - let bitset = if i == 3 { - !(self.vary_mask[i] | (1 << 63)) - } else { - !self.vary_mask[i] - }; - let bitset = bitset & self.or_all[i]; - for pos in BitIter::from(bitset) { - f(i, pos); - } - } - } - - /// Iterate over each varying bit for the absolute values - pub fn for_each_abs_varying_bit(&self, mut f: F) - where - F: FnMut(usize, usize), - { - for i in 0..4 { - let bitset = if i == 3 { - self.vary_mask[i] & !(1 << 63) - } else { - self.vary_mask[i] - }; - for pos in BitIter::from(bitset) { - f(i, pos); - } - } - } - - /// Iterate over each varying bit for the absolute values and the sign bit - /// if it varies - pub fn for_each_varying_bit(&self, mut f: F) - where - F: FnMut(usize, usize), - { - for i in 0..4 { - let bitset = self.vary_mask[i]; - for pos in BitIter::from(bitset) { - f(i, pos); - } - } - } - - /// Return the position of the most significant bit of the absolute values - /// # Panics - /// Panics if no bits are set to 1 in the bit representation of `or_all`. - pub fn most_significant_abs_bit(&self) -> usize { - let mask = self.or_all[3] & !(1 << 63); - if mask != 0 { - return 64 - (mask.leading_zeros() as usize) - 1 + 3 * 64; - } - for i in (0..3).rev() { - let mask = self.or_all[i]; - if mask != 0 { - return 64 - (mask.leading_zeros() as usize) - 1 + 64 * i; - } - } - panic!("no bits are set"); - } } diff --git a/crates/proof-of-sql/src/base/bit/mod.rs b/crates/proof-of-sql/src/base/bit/mod.rs index 5daaf35e0..f6cc941a3 100644 --- a/crates/proof-of-sql/src/base/bit/mod.rs +++ b/crates/proof-of-sql/src/base/bit/mod.rs @@ -1,8 +1,5 @@ pub(crate) mod bit_mask_utils; -mod abs_bit_mask; -pub use abs_bit_mask::*; - mod bit_distribution; pub use bit_distribution::*; #[cfg(test)] diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification.rs b/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification.rs deleted file mode 100644 index fe4b772e1..000000000 --- a/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification.rs +++ /dev/null @@ -1,64 +0,0 @@ -use crate::base::{bit::BitDistribution, proof::ProofError, scalar::Scalar}; - -#[allow( - clippy::missing_panics_doc, - reason = "All assertions check for validity within the context, ensuring no panic can occur" -)] -/// Given a bit distribution for a column of data with a constant sign, the evaluation of a column -/// of ones, the constant column's evaluation, and the evaluation of varying absolute bits, verify -/// that the bit distribution is correct. -pub fn verify_constant_sign_decomposition( - dist: &BitDistribution, - eval: S, - one_eval: S, - bit_evals: &[S], -) -> Result<(), ProofError> { - assert!( - dist.is_valid() - && dist.is_within_acceptable_range() - && dist.num_varying_bits() == bit_evals.len() - && !dist.has_varying_sign_bit() - ); - let lhs = if dist.sign_bit() { -eval } else { eval }; - let mut rhs = S::from(dist.constant_part()) * one_eval; - let mut vary_index = 0; - dist.for_each_abs_varying_bit(|int_index: usize, bit_index: usize| { - let mut mult = [0u64; 4]; - mult[int_index] = 1u64 << bit_index; - rhs += S::from(mult) * bit_evals[vary_index]; - vary_index += 1; - }); - if lhs == rhs { - Ok(()) - } else { - Err(ProofError::VerificationError { - error: "constant sign bitwise decomposition is invalid", - }) - } -} - -#[allow( - clippy::missing_panics_doc, - reason = "The assertion checks ensure that conditions are valid, preventing panics" -)] -pub fn verify_constant_abs_decomposition( - dist: &BitDistribution, - eval: S, - one_eval: S, - sign_eval: S, -) -> Result<(), ProofError> { - assert!( - dist.is_valid() - && dist.is_within_acceptable_range() - && dist.num_varying_bits() == 1 - && dist.has_varying_sign_bit() - ); - let t = one_eval - S::TWO * sign_eval; - if S::from(dist.constant_part()) * t == eval { - Ok(()) - } else { - Err(ProofError::VerificationError { - error: "constant absolute bitwise decomposition is invalid", - }) - } -} diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification_test.rs b/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification_test.rs deleted file mode 100644 index 4d0c99ef3..000000000 --- a/crates/proof-of-sql/src/sql/proof_gadgets/bitwise_verification_test.rs +++ /dev/null @@ -1,122 +0,0 @@ -use super::{verify_constant_abs_decomposition, verify_constant_sign_decomposition}; -use crate::base::{ - bit::BitDistribution, - scalar::Curve25519Scalar, - slice_ops::{inner_product, slice_cast}, -}; -use ark_std::UniformRand; -use core::iter::repeat_with; - -fn rand_eval_vec(len: usize) -> Vec { - let rng = &mut ark_std::test_rng(); - repeat_with(|| Curve25519Scalar::rand(rng)) - .take(len) - .collect() -} - -#[test] -fn we_can_verify_the_decomposition_of_a_constant_column() { - let data: Vec = - vec![Curve25519Scalar::from(1234), Curve25519Scalar::from(1234)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &[]).is_ok()); -} - -#[test] -fn we_can_verify_the_decomposition_of_a_column_with_constant_sign() { - let data: Vec = - vec![Curve25519Scalar::from(123), Curve25519Scalar::from(122)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - let bits = [inner_product(&slice_cast(&[1, 0]), &eval_vec)]; - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &bits).is_ok()); -} - -#[test] -fn we_can_verify_the_decomposition_of_a_constant_column_with_negative_values() { - let data: Vec = - vec![Curve25519Scalar::from(-1234), Curve25519Scalar::from(-1234)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &[]).is_ok()); -} - -#[test] -fn constant_verification_fails_if_the_commitment_doesnt_match() { - let data: Vec = - vec![Curve25519Scalar::from(1234), Curve25519Scalar::from(1234)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data: Vec = - vec![Curve25519Scalar::from(1235), Curve25519Scalar::from(1234)]; - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &[]).is_err()); -} - -#[test] -fn constant_verification_fails_if_the_sign_bit_doesnt_match() { - let data: Vec = - vec![Curve25519Scalar::from(1234), Curve25519Scalar::from(1234)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data: Vec = - vec![Curve25519Scalar::from(-1234), Curve25519Scalar::from(-1234)]; - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &[]).is_err()); -} - -#[test] -fn constant_verification_fails_if_a_varying_bit_doesnt_match() { - let data: Vec = - vec![Curve25519Scalar::from(1234), Curve25519Scalar::from(1234)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data: Vec = - vec![Curve25519Scalar::from(234), Curve25519Scalar::from(1234)]; - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - assert!(verify_constant_sign_decomposition(&dist, data_eval, one_eval, &[]).is_err()); -} - -#[test] -fn we_can_verify_a_decomposition_with_only_a_varying_sign() { - let data: Vec = vec![Curve25519Scalar::from(-1), Curve25519Scalar::from(1)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - let sign_eval = inner_product(&slice_cast(&[1, 0]), &eval_vec); - assert!(verify_constant_abs_decomposition(&dist, data_eval, one_eval, sign_eval).is_ok()); -} - -#[test] -fn constant_abs_verification_fails_if_the_sign_and_data_dont_match() { - let data: Vec = vec![Curve25519Scalar::from(-1), Curve25519Scalar::from(1)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - let sign_eval = inner_product(&slice_cast(&[0, 1]), &eval_vec); - assert!(verify_constant_abs_decomposition(&dist, data_eval, one_eval, sign_eval).is_err()); -} - -#[test] -fn we_can_verify_a_decomposition_with_only_a_varying_sign_and_magnitude_greater_than_one() { - let data: Vec = - vec![Curve25519Scalar::from(-100), Curve25519Scalar::from(100)]; - let eval_vec = rand_eval_vec(data.len()); - let dist = BitDistribution::new::(&data); - let data_eval = inner_product(&data, &eval_vec); - let one_eval = eval_vec.iter().sum(); - let sign_eval = inner_product(&slice_cast(&[1, 0]), &eval_vec); - assert!(verify_constant_abs_decomposition(&dist, data_eval, one_eval, sign_eval).is_ok()); -} diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs b/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs index 0ec348767..4da2ca170 100644 --- a/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs +++ b/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs @@ -1,8 +1,4 @@ //! This module contains shared proof logic for multiple `ProofExpr` / `ProofPlan` implementations. -mod bitwise_verification; -use bitwise_verification::{verify_constant_abs_decomposition, verify_constant_sign_decomposition}; -#[cfg(test)] -mod bitwise_verification_test; mod sign_expr; pub(crate) use sign_expr::{prover_evaluate_sign, result_evaluate_sign, verifier_evaluate_sign}; pub mod range_check; diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs index 1e879db77..8996d9ffb 100644 --- a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs @@ -8,10 +8,7 @@ use crate::{ proof::ProofError, scalar::{Scalar, ScalarExt}, }, - sql::proof::{ - FinalRoundBuilder, SumcheckSubpolynomialTerm, SumcheckSubpolynomialType, - VerificationBuilder, - }, + sql::proof::{FinalRoundBuilder, SumcheckSubpolynomialType, VerificationBuilder}, }; use alloc::{boxed::Box, vec, vec::Vec}; use bnum::types::U256; @@ -107,20 +104,6 @@ pub fn verifier_evaluate_sign( }) } -fn verifier_const_sign_evaluate( - dist: &BitDistribution, - eval: S, - one_eval: S, - bit_evals: &[S], -) -> Result { - verify_constant_sign_decomposition(dist, eval, one_eval, bit_evals)?; - if dist.sign_bit() { - Ok(one_eval) - } else { - Ok(S::zero()) - } -} - fn prove_bits_are_binary<'a, S: Scalar>( builder: &mut FinalRoundBuilder<'a, S>, bits: &[&'a [bool]], @@ -151,43 +134,6 @@ fn verify_bits_are_binary( Ok(()) } -/// # Panics -/// Panics if `bits.last()` returns `None`. -/// -/// This function generates subpolynomial terms for sumcheck, involving the scalar expression and its bit decomposition. -fn prove_bit_decomposition<'a, S: Scalar>( - builder: &mut FinalRoundBuilder<'a, S>, - alloc: &'a Bump, - expr: &'a [S], - bits: &[&'a [bool]], - dist: &BitDistribution, -) { - let sign_mle = bits.last().unwrap(); - let sign_mle: &[_] = - alloc.alloc_slice_fill_with(sign_mle.len(), |i| 1 - 2 * i32::from(sign_mle[i])); - let mut terms: Vec> = Vec::new(); - - // expr - terms.push((S::one(), vec![Box::new(expr)])); - - // expr bit decomposition - let const_part = S::from(dist.constant_part()); - if !const_part.is_zero() { - terms.push((-const_part, vec![Box::new(sign_mle)])); - } - let mut vary_index = 0; - dist.for_each_abs_varying_bit(|int_index: usize, bit_index: usize| { - let mut mult = [0u64; 4]; - mult[int_index] = 1u64 << bit_index; - terms.push(( - -S::from(mult), - vec![Box::new(sign_mle), Box::new(bits[vary_index])], - )); - vary_index += 1; - }); - builder.produce_sumcheck_subpolynomial(SumcheckSubpolynomialType::Identity, terms); -} - /// # Panics /// Panics if `bit_evals.last()` returns `None`. /// From 7a91333b077e739a98039e0caf03f95d03613db3 Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Thu, 9 Jan 2025 10:03:49 -0500 Subject: [PATCH 10/12] test: add new tests to test sign_expr and bit_distribution --- .../src/base/bit/bit_distribution_test.rs | 204 +++++++++++++++++- .../src/base/bit/bit_mask_utils_test.rs | 56 +++++ crates/proof-of-sql/src/base/bit/mod.rs | 2 + .../src/sql/proof_gadgets/sign_expr.rs | 65 +++++- 4 files changed, 325 insertions(+), 2 deletions(-) create mode 100644 crates/proof-of-sql/src/base/bit/bit_mask_utils_test.rs diff --git a/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs b/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs index 44900dcc8..b63fb4716 100644 --- a/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs +++ b/crates/proof-of-sql/src/base/bit/bit_distribution_test.rs @@ -1,7 +1,209 @@ use super::*; -use crate::base::scalar::{test_scalar::TestScalar, ScalarExt}; +use crate::base::scalar::{test_scalar::TestScalar, Scalar, ScalarExt}; use bnum::types::U256; +// vary_mask function start + +#[test] +fn we_can_get_u256_version_of_vary_mask_for_zero() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0; 4], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_vary_mask = bit_distribution.vary_mask(); + + // ASSERT + assert_eq!(u256_vary_mask, U256::ZERO); +} + +#[test] +fn we_can_get_u256_version_of_vary_mask_for_one() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [1, 0, 0, 0], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_vary_mask = bit_distribution.vary_mask(); + + // ASSERT + assert_eq!(u256_vary_mask, U256::ONE); +} + +#[test] +fn we_can_get_u256_version_of_vary_mask_for_large_number() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [256, 0, 0, 256], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_vary_mask = bit_distribution.vary_mask(); + + // ASSERT + assert_eq!(u256_vary_mask, (U256::ONE << 8) + (U256::ONE << 200)); +} + +// vary_mask function end + +// leading_bit_mask function start + +#[test] +fn we_can_get_u256_version_of_leading_bit_mask_for_zero() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0; 4], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_leading_bit_mask = bit_distribution.leading_bit_mask(); + + // ASSERT + assert_eq!(u256_leading_bit_mask, U256::ONE << 255); +} + +#[test] +fn we_can_get_u256_version_of_leading_bit_mask_for_one() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0; 4], + leading_bit_mask: [1, 0, 0, 0], + }; + + // ACT + let u256_leading_bit_mask = bit_distribution.leading_bit_mask(); + + // ASSERT + assert_eq!(u256_leading_bit_mask, U256::ONE | (U256::ONE << 255)); +} + +#[test] +fn we_can_get_u256_version_of_leading_bit_mask_for_large_number() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0; 4], + leading_bit_mask: [256, 0, 0, 256], + }; + + // ACT + let u256_leading_bit_mask = bit_distribution.leading_bit_mask(); + + // ASSERT + assert_eq!( + u256_leading_bit_mask, + ((U256::ONE << 8) + (U256::ONE << 200)) | (U256::ONE << 255) + ); +} + +// leading_bit_mask function end + +// leading_bit_inverse_mask function start + +#[test] +fn we_can_get_u256_version_of_leading_bit_inverse_mask_for_zero() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [u64::MAX; 4], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_leading_bit_inverse_mask = bit_distribution.leading_bit_inverse_mask(); + + // ASSERT + assert_eq!(u256_leading_bit_inverse_mask, U256::ZERO); +} + +#[test] +fn we_can_get_u256_version_of_leading_bit_inverse_mask_for_one() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [u64::MAX - 1, u64::MAX, u64::MAX, u64::MAX], + leading_bit_mask: [0, 0, 0, 0], + }; + + // ACT + let u256_leading_bit_inverse_mask = bit_distribution.leading_bit_inverse_mask(); + + // ASSERT + assert_eq!(u256_leading_bit_inverse_mask, U256::ONE); +} + +#[test] +fn we_can_get_u256_version_of_leading_bit_inverse_mask_for_large_number() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [u64::MAX - 255, u64::MAX, u64::MAX, u64::MAX - 255], + leading_bit_mask: [0; 4], + }; + + // ACT + let u256_leading_bit_inverse_mask = bit_distribution.leading_bit_inverse_mask(); + + // ASSERT + assert_eq!( + u256_leading_bit_inverse_mask, + ((U256::ONE << 8) - U256::ONE) + (((U256::ONE << 8) - U256::ONE) << 192) + ); +} + +// leading_bit_inverse_mask function end + +#[test] +fn we_can_get_leading_bit_eval_while_varying() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0, 0, 0, 1 << 63], + leading_bit_mask: [0; 4], + }; + + // ACT + let bit_eval = bit_distribution.leading_bit_eval(&[TestScalar::ONE], TestScalar::TWO); + + // ASSERT + assert_eq!(bit_eval, TestScalar::ONE); +} + +#[test] +fn we_can_get_leading_bit_eval_while_constant_and_zero() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0, 0, 0, 0], + leading_bit_mask: [0; 4], + }; + + // ACT + let bit_eval = bit_distribution.leading_bit_eval(&[TestScalar::ONE], TestScalar::TWO); + + // ASSERT + assert_eq!(bit_eval, TestScalar::ZERO); +} + +#[test] +fn we_can_get_leading_bit_eval_while_constant_and_non_zero() { + // ARRANGE + let bit_distribution = BitDistribution { + vary_mask: [0, 0, 0, 0], + leading_bit_mask: [0, 0, 0, 1 << 63], + }; + + // ACT + let bit_eval = bit_distribution.leading_bit_eval(&[TestScalar::ONE], TestScalar::TWO); + + // ASSERT + assert_eq!(bit_eval, TestScalar::TWO); +} + +// leading_bit_eval functions start + +// leading_bit_eval functions end + #[test] fn we_can_compute_the_bit_distribution_of_an_empty_slice() { let data: Vec = vec![]; diff --git a/crates/proof-of-sql/src/base/bit/bit_mask_utils_test.rs b/crates/proof-of-sql/src/base/bit/bit_mask_utils_test.rs new file mode 100644 index 000000000..42f0f0ec6 --- /dev/null +++ b/crates/proof-of-sql/src/base/bit/bit_mask_utils_test.rs @@ -0,0 +1,56 @@ +use super::bit_mask_utils::make_bit_mask; +use crate::base::{ + bit::bit_mask_utils::is_bit_mask_negative_representation, + scalar::{test_scalar::TestScalar, Scalar}, +}; +use bnum::types::U256; + +#[test] +fn we_can_make_positive_bit_mask() { + // ARRANGE + let positive_scalar = TestScalar::TWO; + + // ACT + let bit_mask = make_bit_mask(positive_scalar); + + // ASSERT + assert_eq!(bit_mask, (U256::ONE << 255) + U256::TWO); +} + +#[test] +fn we_can_make_negative_bit_mask() { + // ARRANGE + let negative_scalar = -TestScalar::TWO; + + // ACT + let bit_mask = make_bit_mask(negative_scalar); + + // ASSERT + assert_eq!(bit_mask, (U256::ONE << 255) - U256::TWO); +} + +#[test] +fn we_can_verify_positive_bit_mask_is_positive_representation() { + // ARRANGE + let positive_scalar = TestScalar::TWO; + let bit_mask = make_bit_mask(positive_scalar); + + // ACT + let is_positive = !is_bit_mask_negative_representation(bit_mask); + + // ASSERT + assert!(is_positive); +} + +#[test] +fn we_can_verify_negative_bit_mask_is_negative_representation() { + // ARRANGE + let negative_scalar = -TestScalar::TWO; + let bit_mask = make_bit_mask(negative_scalar); + + // ACT + let is_negative = is_bit_mask_negative_representation(bit_mask); + + // ASSERT + assert!(is_negative); +} diff --git a/crates/proof-of-sql/src/base/bit/mod.rs b/crates/proof-of-sql/src/base/bit/mod.rs index f6cc941a3..5360b0c43 100644 --- a/crates/proof-of-sql/src/base/bit/mod.rs +++ b/crates/proof-of-sql/src/base/bit/mod.rs @@ -1,4 +1,6 @@ pub(crate) mod bit_mask_utils; +#[cfg(test)] +mod bit_mask_utils_test; mod bit_distribution; pub use bit_distribution::*; diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs index 8996d9ffb..554bcea82 100644 --- a/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_gadgets/sign_expr.rs @@ -1,4 +1,3 @@ -use super::{verify_constant_abs_decomposition, verify_constant_sign_decomposition}; use crate::{ base::{ bit::{ @@ -158,3 +157,67 @@ fn verify_bit_decomposition( } rhs == expr_eval } + +#[cfg(test)] +mod tests { + use crate::{ + base::{ + bit::BitDistribution, + scalar::{test_scalar::TestScalar, Scalar}, + }, + sql::proof_gadgets::sign_expr::verify_bit_decomposition, + }; + + #[test] + fn we_can_verify_bit_decomposition() { + let dist = BitDistribution { + vary_mask: [629, 0, 0, 0], + leading_bit_mask: [2, 0, 0, 9_223_372_036_854_775_808], + }; + let one_eval = TestScalar::ONE; + let bit_evals = [0, 0, 1, 1, 0, 1].map(TestScalar::from); + let expr_eval = TestScalar::from(562); + assert!(verify_bit_decomposition( + expr_eval, one_eval, &bit_evals, &dist, + )); + } + + #[test] + fn we_can_verify_bit_decomposition_constant_sign() { + let dist = BitDistribution { + vary_mask: [629, 0, 0, 0], + leading_bit_mask: [2, 0, 0, 9_223_372_036_854_775_808], + }; + let a = TestScalar::ONE; + let b = TestScalar::ONE; + let expr_eval = TestScalar::from(118) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(562) * a * (TestScalar::ONE - b) + + TestScalar::from(3) * (TestScalar::ONE - a) * b; + let one_eval = TestScalar::from(1) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(1) * a * (TestScalar::ONE - b) + + TestScalar::from(1) * (TestScalar::ONE - a) * b; + let bit_evals = [ + TestScalar::from(0) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(0) * a * (TestScalar::ONE - b) + + TestScalar::from(1) * (TestScalar::ONE - a) * b, + TestScalar::from(1) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(0) * a * (TestScalar::ONE - b) + + TestScalar::from(0) * (TestScalar::ONE - a) * b, + TestScalar::from(1) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(1) * a * (TestScalar::ONE - b) + + TestScalar::from(0) * (TestScalar::ONE - a) * b, + TestScalar::from(1) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(1) * a * (TestScalar::ONE - b) + + TestScalar::from(0) * (TestScalar::ONE - a) * b, + TestScalar::from(1) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(0) * a * (TestScalar::ONE - b) + + TestScalar::from(0) * (TestScalar::ONE - a) * b, + TestScalar::from(0) * (TestScalar::ONE - a) * (TestScalar::ONE - b) + + TestScalar::from(1) * a * (TestScalar::ONE - b) + + TestScalar::from(0) * (TestScalar::ONE - a) * b, + ]; + assert!(verify_bit_decomposition( + expr_eval, one_eval, &bit_evals, &dist, + )); + } +} From c259f426bc2e649920674092d865e89200ecfdcf Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Fri, 10 Jan 2025 08:36:28 -0500 Subject: [PATCH 11/12] one commit --- .../src/intermediate_ast.rs | 8 +- .../src/intermediate_ast_tests.rs | 4 +- crates/proof-of-sql-parser/src/sql.lalrpop | 16 ++-- crates/proof-of-sql-parser/src/sqlparser.rs | 4 +- crates/proof-of-sql-parser/src/utility.rs | 24 +++++- .../database/column_comparison_operation.rs | 36 ++++++--- .../base/database/expression_evaluation.rs | 4 +- crates/proof-of-sql/src/base/database/mod.rs | 4 +- .../base/database/owned_column_operation.rs | 80 +++++++++---------- .../src/sql/parse/dyn_proof_expr_builder.rs | 4 +- .../src/sql/parse/query_context_builder.rs | 6 +- .../src/sql/parse/where_expr_builder_tests.rs | 38 +++++---- .../src/sql/proof_exprs/comparison_util.rs | 4 +- .../src/sql/proof_exprs/dyn_proof_expr.rs | 6 +- .../src/sql/proof_exprs/inequality_expr.rs | 48 +++-------- .../proof-of-sql/src/sql/proof_exprs/mod.rs | 5 +- .../src/sql/proof_exprs/test_utility.rs | 4 +- 17 files changed, 153 insertions(+), 142 deletions(-) diff --git a/crates/proof-of-sql-parser/src/intermediate_ast.rs b/crates/proof-of-sql-parser/src/intermediate_ast.rs index d89696654..08dcd02c7 100644 --- a/crates/proof-of-sql-parser/src/intermediate_ast.rs +++ b/crates/proof-of-sql-parser/src/intermediate_ast.rs @@ -106,11 +106,11 @@ pub enum BinaryOperator { /// Comparison = Equal, - /// Comparison <= - LessThanOrEqual, + /// Comparison < + LessThan, - /// Comparison >= - GreaterThanOrEqual, + /// Comparison > + GreaterThan, } /// Possible unary operators for simple expressions diff --git a/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs b/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs index 4574de18b..418a9c1c4 100644 --- a/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs +++ b/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs @@ -987,7 +987,7 @@ fn we_can_parse_a_query_with_filter_lt() { query( cols_res(&["a"]), tab(None, "tab"), - not(ge(col("b"), lit(4))), + lt(col("b"), lit(4)), vec![], ), vec![], @@ -1023,7 +1023,7 @@ fn we_can_parse_a_query_with_filter_gt() { query( cols_res(&["a"]), tab(None, "tab"), - not(le(col("b"), lit(4))), + gt(col("b"), lit(4)), vec![], ), vec![], diff --git a/crates/proof-of-sql-parser/src/sql.lalrpop b/crates/proof-of-sql-parser/src/sql.lalrpop index ce6dca007..e838af055 100644 --- a/crates/proof-of-sql-parser/src/sql.lalrpop +++ b/crates/proof-of-sql-parser/src/sql.lalrpop @@ -238,35 +238,35 @@ Expression: Box = { }), #[precedence(level="4")] #[assoc(side="left")] - ">=" => + ">" => Box::new(intermediate_ast::Expression::Binary { - op: intermediate_ast::BinaryOperator::GreaterThanOrEqual, + op: intermediate_ast::BinaryOperator::GreaterThan, left, right, }), - "<=" => + "<" => Box::new(intermediate_ast::Expression::Binary { - op: intermediate_ast::BinaryOperator::LessThanOrEqual, + op: intermediate_ast::BinaryOperator::LessThan, left, right, }), - ">" => + ">=" => Box::new(intermediate_ast::Expression::Unary { op: intermediate_ast::UnaryOperator::Not, expr: Box::new(intermediate_ast::Expression::Binary { - op: intermediate_ast::BinaryOperator::LessThanOrEqual, + op: intermediate_ast::BinaryOperator::LessThan, left, right, }), }), - "<" => + "<=" => Box::new(intermediate_ast::Expression::Unary { op: intermediate_ast::UnaryOperator::Not, expr: Box::new(intermediate_ast::Expression::Binary { - op: intermediate_ast::BinaryOperator::GreaterThanOrEqual, + op: intermediate_ast::BinaryOperator::GreaterThan, left, right, }), diff --git a/crates/proof-of-sql-parser/src/sqlparser.rs b/crates/proof-of-sql-parser/src/sqlparser.rs index 72643b568..9a54bf3d9 100644 --- a/crates/proof-of-sql-parser/src/sqlparser.rs +++ b/crates/proof-of-sql-parser/src/sqlparser.rs @@ -90,8 +90,8 @@ impl From for BinaryOperator { PoSqlBinaryOperator::And => BinaryOperator::And, PoSqlBinaryOperator::Or => BinaryOperator::Or, PoSqlBinaryOperator::Equal => BinaryOperator::Eq, - PoSqlBinaryOperator::LessThanOrEqual => BinaryOperator::LtEq, - PoSqlBinaryOperator::GreaterThanOrEqual => BinaryOperator::GtEq, + PoSqlBinaryOperator::LessThan => BinaryOperator::Lt, + PoSqlBinaryOperator::GreaterThan => BinaryOperator::Gt, PoSqlBinaryOperator::Add => BinaryOperator::Plus, PoSqlBinaryOperator::Subtract => BinaryOperator::Minus, PoSqlBinaryOperator::Multiply => BinaryOperator::Multiply, diff --git a/crates/proof-of-sql-parser/src/utility.rs b/crates/proof-of-sql-parser/src/utility.rs index cf9d7781e..60188cf59 100644 --- a/crates/proof-of-sql-parser/src/utility.rs +++ b/crates/proof-of-sql-parser/src/utility.rs @@ -30,8 +30,18 @@ pub fn equal(left: Box, right: Box) -> Box { /// Construct a new boxed `Expression` A >= B #[must_use] pub fn ge(left: Box, right: Box) -> Box { + not(Box::new(Expression::Binary { + op: BinaryOperator::LessThan, + left, + right, + })) +} + +/// Construct a new boxed `Expression` A > B +#[must_use] +pub fn gt(left: Box, right: Box) -> Box { Box::new(Expression::Binary { - op: BinaryOperator::GreaterThanOrEqual, + op: BinaryOperator::GreaterThan, left, right, }) @@ -40,8 +50,18 @@ pub fn ge(left: Box, right: Box) -> Box { /// Construct a new boxed `Expression` A <= B #[must_use] pub fn le(left: Box, right: Box) -> Box { + not(Box::new(Expression::Binary { + op: BinaryOperator::GreaterThan, + left, + right, + })) +} + +/// Construct a new boxed `Expression` A < B +#[must_use] +pub fn lt(left: Box, right: Box) -> Box { Box::new(Expression::Binary { - op: BinaryOperator::LessThanOrEqual, + op: BinaryOperator::LessThan, left, right, }) diff --git a/crates/proof-of-sql/src/base/database/column_comparison_operation.rs b/crates/proof-of-sql/src/base/database/column_comparison_operation.rs index 1f5c32833..ac995c6be 100644 --- a/crates/proof-of-sql/src/base/database/column_comparison_operation.rs +++ b/crates/proof-of-sql/src/base/database/column_comparison_operation.rs @@ -280,13 +280,13 @@ impl ComparisonOp for EqualOp { } } -pub struct GreaterThanOrEqualOp {} -impl ComparisonOp for GreaterThanOrEqualOp { +pub struct GreaterThanOp {} +impl ComparisonOp for GreaterThanOp { fn op(l: &T, r: &T) -> bool where T: Debug + Ord, { - l >= r + l > r } fn decimal_op_left_upcast( @@ -299,7 +299,10 @@ impl ComparisonOp for GreaterThanOrEqualOp { S: Scalar, T: Copy + Debug + Ord + Zero + Into, { - ge_decimal_columns(lhs, rhs, left_column_type, right_column_type) + le_decimal_columns(lhs, rhs, left_column_type, right_column_type) + .iter() + .map(|b| !b) + .collect() } fn decimal_op_right_upcast( @@ -312,25 +315,28 @@ impl ComparisonOp for GreaterThanOrEqualOp { S: Scalar, T: Copy + Debug + Ord + Zero + Into, { - le_decimal_columns(rhs, lhs, right_column_type, left_column_type) + ge_decimal_columns(rhs, lhs, right_column_type, left_column_type) + .iter() + .map(|b| !b) + .collect() } fn string_op(_lhs: &[String], _rhs: &[String]) -> ColumnOperationResult> { Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: ">=".to_string(), + operator: ">".to_string(), left_type: ColumnType::VarChar, right_type: ColumnType::VarChar, }) } } -pub struct LessThanOrEqualOp {} -impl ComparisonOp for LessThanOrEqualOp { +pub struct LessThanOp {} +impl ComparisonOp for LessThanOp { fn op(l: &T, r: &T) -> bool where T: Debug + Ord, { - l <= r + l < r } fn decimal_op_left_upcast( @@ -343,7 +349,10 @@ impl ComparisonOp for LessThanOrEqualOp { S: Scalar, T: Copy + Debug + Ord + Zero + Into, { - le_decimal_columns(lhs, rhs, left_column_type, right_column_type) + ge_decimal_columns(lhs, rhs, left_column_type, right_column_type) + .iter() + .map(|b| !b) + .collect() } fn decimal_op_right_upcast( @@ -356,12 +365,15 @@ impl ComparisonOp for LessThanOrEqualOp { S: Scalar, T: Copy + Debug + Ord + Zero + Into, { - ge_decimal_columns(rhs, lhs, right_column_type, left_column_type) + le_decimal_columns(rhs, lhs, right_column_type, left_column_type) + .iter() + .map(|b| !b) + .collect() } fn string_op(_lhs: &[String], _rhs: &[String]) -> ColumnOperationResult> { Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: "<=".to_string(), + operator: "<".to_string(), left_type: ColumnType::VarChar, right_type: ColumnType::VarChar, }) diff --git a/crates/proof-of-sql/src/base/database/expression_evaluation.rs b/crates/proof-of-sql/src/base/database/expression_evaluation.rs index d9df43097..adab9d60b 100644 --- a/crates/proof-of-sql/src/base/database/expression_evaluation.rs +++ b/crates/proof-of-sql/src/base/database/expression_evaluation.rs @@ -90,8 +90,8 @@ impl OwnedTable { BinaryOperator::And => Ok(left.element_wise_and(&right)?), BinaryOperator::Or => Ok(left.element_wise_or(&right)?), BinaryOperator::Eq => Ok(left.element_wise_eq(&right)?), - BinaryOperator::GtEq => Ok(left.element_wise_ge(&right)?), - BinaryOperator::LtEq => Ok(left.element_wise_le(&right)?), + BinaryOperator::Gt => Ok(left.element_wise_gt(&right)?), + BinaryOperator::Lt => Ok(left.element_wise_lt(&right)?), BinaryOperator::Plus => Ok(left.element_wise_add(&right)?), BinaryOperator::Minus => Ok(left.element_wise_sub(&right)?), BinaryOperator::Multiply => Ok(left.element_wise_mul(&right)?), diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index 03f80f23c..a17f4836d 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -21,9 +21,7 @@ mod column_arithmetic_operation; pub(super) use column_arithmetic_operation::{AddOp, ArithmeticOp, DivOp, MulOp, SubOp}; mod column_comparison_operation; -pub(super) use column_comparison_operation::{ - ComparisonOp, EqualOp, GreaterThanOrEqualOp, LessThanOrEqualOp, -}; +pub(super) use column_comparison_operation::{ComparisonOp, EqualOp, GreaterThanOp, LessThanOp}; mod column_index_operation; pub(super) use column_index_operation::apply_column_to_indexes; diff --git a/crates/proof-of-sql/src/base/database/owned_column_operation.rs b/crates/proof-of-sql/src/base/database/owned_column_operation.rs index 48eca7027..369360f38 100644 --- a/crates/proof-of-sql/src/base/database/owned_column_operation.rs +++ b/crates/proof-of-sql/src/base/database/owned_column_operation.rs @@ -1,6 +1,6 @@ use super::{ AddOp, ArithmeticOp, ColumnOperationError, ColumnOperationResult, ComparisonOp, DivOp, EqualOp, - GreaterThanOrEqualOp, LessThanOrEqualOp, MulOp, SubOp, + GreaterThanOp, LessThanOp, MulOp, SubOp, }; use crate::base::{ database::{ @@ -65,13 +65,13 @@ impl OwnedColumn { } /// Element-wise less than or equal to check for two columns - pub fn element_wise_le(&self, rhs: &Self) -> ColumnOperationResult { - LessThanOrEqualOp::owned_column_element_wise_comparison(self, rhs) + pub fn element_wise_lt(&self, rhs: &Self) -> ColumnOperationResult { + LessThanOp::owned_column_element_wise_comparison(self, rhs) } /// Element-wise greater than or equal to check for two columns - pub fn element_wise_ge(&self, rhs: &Self) -> ColumnOperationResult { - GreaterThanOrEqualOp::owned_column_element_wise_comparison(self, rhs) + pub fn element_wise_gt(&self, rhs: &Self) -> ColumnOperationResult { + GreaterThanOp::owned_column_element_wise_comparison(self, rhs) } /// Element-wise addition for two columns @@ -118,13 +118,13 @@ mod test { Err(ColumnOperationError::DifferentColumnLength { .. }) )); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::DifferentColumnLength { .. }) )); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert!(matches!( result, Err(ColumnOperationError::DifferentColumnLength { .. }) @@ -316,31 +316,31 @@ mod test { } #[test] - fn we_can_do_le_operation_on_numeric_and_boolean_columns() { + fn we_can_do_lt_operation_on_numeric_and_boolean_columns() { // Booleans let lhs = OwnedColumn::::Boolean(vec![true, false, true]); let rhs = OwnedColumn::::Boolean(vec![true, true, false]); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![false, true, false])) ); // Integers let lhs = OwnedColumn::::SmallInt(vec![1, 3, 2]); let rhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, false, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); let lhs = OwnedColumn::::Int(vec![1, 3, 2]); let rhs = OwnedColumn::::SmallInt(vec![1, 2, 3]); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, false, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); // Decimals @@ -348,29 +348,29 @@ mod test { let rhs_scalars = [1, 24, -3].iter().map(TestScalar::from).collect(); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 3, lhs_scalars); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![false, true, false])) ); // Decimals and integers let lhs_scalars = [10, -2, -30].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::TinyInt(vec![1, -20, 3]); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), -1, lhs_scalars); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![false, true, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); let lhs_scalars = [10, -2, -30].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Int(vec![1, -20, 3]); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), -1, lhs_scalars); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![false, true, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); } @@ -379,27 +379,27 @@ mod test { // Booleans let lhs = OwnedColumn::::Boolean(vec![true, false, true]); let rhs = OwnedColumn::::Boolean(vec![true, true, false]); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, false, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); // Integers let lhs = OwnedColumn::::SmallInt(vec![1, 3, 2]); let rhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![false, true, false])) ); let lhs = OwnedColumn::::Int(vec![1, 3, 2]); let rhs = OwnedColumn::::SmallInt(vec![1, 2, 3]); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![false, true, false])) ); // Decimals @@ -407,29 +407,29 @@ mod test { let rhs_scalars = [1, 24, -3].iter().map(TestScalar::from).collect(); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 3, lhs_scalars); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, false, true])) + Ok(OwnedColumn::::Boolean(vec![false, false, true])) ); // Decimals and integers let lhs_scalars = [10, -2, -30].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::TinyInt(vec![1_i8, -20, 3]); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), -1, lhs_scalars); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![true, false, false])) ); let lhs_scalars = [10, -2, -30].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::BigInt(vec![1_i64, -20, 3]); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), -1, lhs_scalars); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert_eq!( result, - Ok(OwnedColumn::::Boolean(vec![true, true, false])) + Ok(OwnedColumn::::Boolean(vec![true, false, false])) ); } @@ -443,7 +443,7 @@ mod test { .map(ToString::to_string) .collect(), ); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) @@ -456,19 +456,19 @@ mod test { .map(ToString::to_string) .collect(), ); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) @@ -477,7 +477,7 @@ mod test { // Booleans can't be compared with other types let lhs = OwnedColumn::::Boolean(vec![true, false, true]); let rhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) @@ -485,7 +485,7 @@ mod test { let lhs = OwnedColumn::::Boolean(vec![true, false, true]); let rhs = OwnedColumn::::Int(vec![1, 2, 3]); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) @@ -504,13 +504,13 @@ mod test { .map(ToString::to_string) .collect(), ); - let result = lhs.element_wise_le(&rhs); + let result = lhs.element_wise_lt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); - let result = lhs.element_wise_ge(&rhs); + let result = lhs.element_wise_gt(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) diff --git a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs index a0e76a031..b712cdd5e 100644 --- a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs @@ -168,12 +168,12 @@ impl DynProofExprBuilder<'_> { let right = self.visit_expr(right); DynProofExpr::try_new_equals(left?, right?) } - BinaryOperator::GtEq => { + BinaryOperator::Gt => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_inequality(left?, right?, false) } - BinaryOperator::LtEq => { + BinaryOperator::Lt => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_inequality(left?, right?, true) diff --git a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs index 708cb8236..e7aa30151 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs @@ -171,8 +171,8 @@ impl<'a> QueryContextBuilder<'a> { BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Eq - | BinaryOperator::GtEq - | BinaryOperator::LtEq => Ok(ColumnType::Boolean), + | BinaryOperator::Gt + | BinaryOperator::Lt => Ok(ColumnType::Boolean), BinaryOperator::Multiply | BinaryOperator::Divide | BinaryOperator::Minus @@ -309,7 +309,7 @@ pub(crate) fn type_check_binary_operation( | (ColumnType::Scalar, _) ) || (left_dtype.is_numeric() && right_dtype.is_numeric()) } - BinaryOperator::GtEq | BinaryOperator::LtEq => { + BinaryOperator::Gt | BinaryOperator::Lt => { if left_dtype == ColumnType::VarChar || right_dtype == ColumnType::VarChar { return false; } diff --git a/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs b/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs index 560362df4..1df7f16c5 100644 --- a/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs @@ -143,14 +143,17 @@ fn we_can_directly_check_whether_bigint_columns_ge_int128() { .build(Some(expr_integer_to_integer)) .unwrap() .unwrap(); - let expected = DynProofExpr::try_new_inequality( - DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( - "sxt.sxt_tab".parse().unwrap(), - "bigint_column".into(), - ColumnType::BigInt, - ))), - DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))), - false, + let expected = DynProofExpr::try_new_not( + DynProofExpr::try_new_inequality( + DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( + "sxt.sxt_tab".parse().unwrap(), + "bigint_column".into(), + ColumnType::BigInt, + ))), + DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))), + true, + ) + .unwrap(), ) .unwrap(); assert_eq!(actual, expected); @@ -165,14 +168,17 @@ fn we_can_directly_check_whether_bigint_columns_le_int128() { .build(Some(expr_integer_to_integer)) .unwrap() .unwrap(); - let expected = DynProofExpr::try_new_inequality( - DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( - "sxt.sxt_tab".parse().unwrap(), - "bigint_column".into(), - ColumnType::BigInt, - ))), - DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))), - true, + let expected = DynProofExpr::try_new_not( + DynProofExpr::try_new_inequality( + DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( + "sxt.sxt_tab".parse().unwrap(), + "bigint_column".into(), + ColumnType::BigInt, + ))), + DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))), + false, + ) + .unwrap(), ) .unwrap(); assert_eq!(actual, expected); diff --git a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs index 92601f512..a3b28e35e 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs @@ -31,7 +31,7 @@ pub fn scale_and_subtract_literal( let operator = if is_equal { BinaryOperator::Eq } else { - BinaryOperator::LtEq + BinaryOperator::Lt }; if !type_check_binary_operation(lhs_type, rhs_type, &operator) { return Err(ConversionError::DataTypeMismatch { @@ -104,7 +104,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( let operator = if is_equal { BinaryOperator::Eq } else { - BinaryOperator::LtEq + BinaryOperator::Lt }; if !type_check_binary_operation(lhs_type, rhs_type, &operator) { return Err(ConversionError::DataTypeMismatch { diff --git a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs index efe540b52..40dd63627 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs @@ -89,15 +89,15 @@ impl DynProofExpr { pub fn try_new_inequality( lhs: DynProofExpr, rhs: DynProofExpr, - is_lte: bool, + is_lt: bool, ) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::LtEq) { + if type_check_binary_operation(lhs_datatype, rhs_datatype, &BinaryOperator::Lt) { Ok(Self::Inequality(InequalityExpr::new( Box::new(lhs), Box::new(rhs), - is_lte, + is_lt, ))) } else { Err(ConversionError::DataTypeMismatch { diff --git a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs index 750bb4e77..4a0ac56d9 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs @@ -1,8 +1,4 @@ -use super::{ - prover_evaluate_equals_zero, prover_evaluate_or, result_evaluate_equals_zero, - result_evaluate_or, scale_and_add_subtract_eval, scale_and_subtract, - verifier_evaluate_equals_zero, verifier_evaluate_or, DynProofExpr, ProofExpr, -}; +use super::{scale_and_add_subtract_eval, scale_and_subtract, DynProofExpr, ProofExpr}; use crate::{ base::{ database::{Column, ColumnRef, ColumnType, Table}, @@ -25,13 +21,13 @@ use serde::{Deserialize, Serialize}; pub struct InequalityExpr { lhs: Box, rhs: Box, - is_lte: bool, + is_lt: bool, } impl InequalityExpr { - /// Create a new less than or equal expression - pub fn new(lhs: Box, rhs: Box, is_lte: bool) -> Self { - Self { lhs, rhs, is_lte } + /// Create a new less than or equal + pub fn new(lhs: Box, rhs: Box, is_lt: bool) -> Self { + Self { lhs, rhs, is_lt } } } @@ -53,7 +49,7 @@ impl ProofExpr for InequalityExpr { let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); let table_length = table.num_rows(); - let diff = if self.is_lte { + let diff = if self.is_lt { scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, false) .expect("Failed to scale and subtract") } else { @@ -61,14 +57,8 @@ impl ProofExpr for InequalityExpr { .expect("Failed to scale and subtract") }; - // diff == 0 - let equals_zero = result_evaluate_equals_zero(table_length, alloc, diff); - - // sign(diff) == -1 - let sign = result_evaluate_sign(table_length, alloc, diff); - - // (diff == 0) || (sign(diff) == -1) - let res = Column::Boolean(result_evaluate_or(table_length, alloc, equals_zero, sign)); + // (sign(diff) == -1) + let res = Column::Boolean(result_evaluate_sign(table_length, alloc, diff)); log::log_memory_usage("End"); @@ -88,7 +78,7 @@ impl ProofExpr for InequalityExpr { let rhs_column = self.rhs.prover_evaluate(builder, alloc, table); let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); - let diff = if self.is_lte { + let diff = if self.is_lt { scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, false) .expect("Failed to scale and subtract") } else { @@ -96,14 +86,8 @@ impl ProofExpr for InequalityExpr { .expect("Failed to scale and subtract") }; - // diff == 0 - let equals_zero = prover_evaluate_equals_zero(table.num_rows(), builder, alloc, diff); - - // sign(diff) == -1 - let sign = prover_evaluate_sign(builder, alloc, diff); - - // (diff == 0) || (sign(diff) == -1) - let res = Column::Boolean(prover_evaluate_or(builder, alloc, equals_zero, sign)); + // (sign(diff) == -1) + let res = Column::Boolean(prover_evaluate_sign(builder, alloc, diff)); log::log_memory_usage("End"); @@ -120,20 +104,14 @@ impl ProofExpr for InequalityExpr { let rhs_eval = self.rhs.verifier_evaluate(builder, accessor, one_eval)?; let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); - let diff_eval = if self.is_lte { + let diff_eval = if self.is_lt { scale_and_add_subtract_eval(lhs_eval, rhs_eval, lhs_scale, rhs_scale, true) } else { scale_and_add_subtract_eval(rhs_eval, lhs_eval, rhs_scale, lhs_scale, true) }; - // diff == 0 - let equals_zero = verifier_evaluate_equals_zero(builder, diff_eval, one_eval)?; - // sign(diff) == -1 - let sign = verifier_evaluate_sign(builder, diff_eval, one_eval)?; - - // (diff == 0) || (sign(diff) == -1) - verifier_evaluate_or(builder, &equals_zero, &sign) + verifier_evaluate_sign(builder, diff_eval, one_eval) } fn get_column_references(&self, columns: &mut IndexSet) { diff --git a/crates/proof-of-sql/src/sql/proof_exprs/mod.rs b/crates/proof-of-sql/src/sql/proof_exprs/mod.rs index 298ad945d..a5717ee9b 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/mod.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/mod.rs @@ -39,7 +39,7 @@ use inequality_expr::InequalityExpr; mod inequality_expr_test; mod or_expr; -use or_expr::{prover_evaluate_or, result_evaluate_or, verifier_evaluate_or, OrExpr}; +use or_expr::OrExpr; #[cfg(all(test, feature = "blitzar"))] mod or_expr_test; @@ -58,9 +58,6 @@ pub(crate) use numerical_util::{ mod equals_expr; pub(crate) use equals_expr::EqualsExpr; -use equals_expr::{ - prover_evaluate_equals_zero, result_evaluate_equals_zero, verifier_evaluate_equals_zero, -}; #[cfg(all(test, feature = "blitzar"))] mod equals_expr_test; diff --git a/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs b/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs index baa8b44e5..e6d3c9b42 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs @@ -33,14 +33,14 @@ pub fn equal(left: DynProofExpr, right: DynProofExpr) -> DynProofExpr { /// Panics if: /// - `DynProofExpr::try_new_inequality()` returns an error. pub fn lte(left: DynProofExpr, right: DynProofExpr) -> DynProofExpr { - DynProofExpr::try_new_inequality(left, right, true).unwrap() + not(DynProofExpr::try_new_inequality(left, right, false).unwrap()) } /// # Panics /// Panics if: /// - `DynProofExpr::try_new_inequality()` returns an error. pub fn gte(left: DynProofExpr, right: DynProofExpr) -> DynProofExpr { - DynProofExpr::try_new_inequality(left, right, false).unwrap() + not(DynProofExpr::try_new_inequality(left, right, true).unwrap()) } /// # Panics From 0e1b2cb5a05eb917112c806ca4c130875b0ffd6d Mon Sep 17 00:00:00 2001 From: stuarttimwhite Date: Fri, 10 Jan 2025 10:27:59 -0500 Subject: [PATCH 12/12] test: change sqlparser test to use strict inequality --- crates/proof-of-sql-parser/src/sqlparser.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/proof-of-sql-parser/src/sqlparser.rs b/crates/proof-of-sql-parser/src/sqlparser.rs index 9a54bf3d9..efae449f5 100644 --- a/crates/proof-of-sql-parser/src/sqlparser.rs +++ b/crates/proof-of-sql-parser/src/sqlparser.rs @@ -291,7 +291,7 @@ mod test { ); check_posql_intermediate_ast_to_sqlparser_equality("select 1 as a, 'Meow' as d, b as b from namespace.table where c = 4 order by a desc limit 10 offset 0;"); check_posql_intermediate_ast_to_sqlparser_equality( - "select true as cons, a and b or c >= 4 as comp from tab where d = 'Space and Time';", + "select true as cons, a and b or c > 4 as comp from tab where d = 'Space and Time';", ); check_posql_intermediate_ast_to_sqlparser_equality( "select cat as cat, true as cons, max(meow) as max_meow from tab where d = 'Space and Time' group by cat;",