diff --git a/tfhe/src/integer/server_key/comparator.rs b/tfhe/src/integer/server_key/comparator.rs index 7c7682a4e0..2f39ccd7f2 100644 --- a/tfhe/src/integer/server_key/comparator.rs +++ b/tfhe/src/integer/server_key/comparator.rs @@ -1,11 +1,8 @@ use super::ServerKey; -use crate::core_crypto::prelude::Plaintext; -use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::IntegerRadixCiphertext; use crate::shortint::server_key::LookupTableOwned; use crate::shortint::Ciphertext; -use rayon::prelude::*; /// Used for compare_blocks_with_zero #[derive(Clone, Copy)] @@ -199,24 +196,6 @@ impl<'a> Comparator<'a> { self.server_key.key.unchecked_scalar_add_assign(lhs, 1); } - // lhs will be assigned - // - 0 if lhs < rhs - // - 1 if lhs == rhs - // - 2 if lhs > rhs - fn scalar_compare_block_assign(&self, lhs: &mut crate::shortint::Ciphertext, scalar: u8) { - // Same logic as compare_block_assign - // but rhs is a scalar - let delta = - (1u64 << (u64::BITS as u64 - 1)) / (lhs.carry_modulus.0 * lhs.message_modulus.0) as u64; - let plaintext = Plaintext((scalar as u64) * delta); - crate::core_crypto::algorithms::lwe_ciphertext_plaintext_sub_assign(&mut lhs.ct, plaintext); - self.server_key - .key - .apply_lookup_table_assign(lhs, &self.sign_lut); - - self.server_key.key.unchecked_scalar_add_assign(lhs, 1); - } - fn reduce_two_sign_blocks_assign( &self, msb_sign: &mut crate::shortint::Ciphertext, @@ -289,65 +268,6 @@ impl<'a> Comparator<'a> { } } - /// Reduces a vec containing shortint blocks that encrypts a sign - /// (inferior, equal, superior) to one single shortint block containing the - /// final sign - fn reduce_signs_parallelized( - &self, - mut sign_blocks: Vec, - sign_result_handler_fn: F, - ) -> crate::shortint::Ciphertext - where - F: Fn(u64) -> u64, - { - let mut sign_blocks_2 = Vec::with_capacity(sign_blocks.len() / 2); - while sign_blocks.len() > 2 { - 
sign_blocks - .par_chunks_exact(2) - .map(|chunk| { - let (low, high) = (&chunk[0], &chunk[1]); - let mut high = high.clone(); - self.reduce_two_sign_blocks_assign(&mut high, low); - high - }) - .collect_into_vec(&mut sign_blocks_2); - - if (sign_blocks.len() % 2) == 1 { - sign_blocks_2.push(sign_blocks[sign_blocks.len() - 1].clone()); - } - - std::mem::swap(&mut sign_blocks_2, &mut sign_blocks); - } - - if sign_blocks.len() == 2 { - let final_lut = self.server_key.key.generate_lookup_table(|x| { - let final_sign = reduce_two_orderings_function(x); - sign_result_handler_fn(final_sign) - }); - // We don't use pack_block_assign as the offset '4' does not depend on params - let mut result = self.server_key.key.unchecked_scalar_mul(&sign_blocks[1], 4); - self.server_key - .key - .unchecked_add_assign(&mut result, &sign_blocks[0]); - self.server_key - .key - .apply_lookup_table_assign(&mut result, &final_lut); - result - } else { - let final_lut = self.server_key.key.generate_lookup_table(|x| { - // sign blocks have values in the set {0, 1, 2} - // here we force apply that modulus explicitly - // so that generate_lookup_table has the correct - // degree estimation - let final_sign = x % 3; - sign_result_handler_fn(final_sign) - }); - self.server_key - .key - .apply_lookup_table(&sign_blocks[0], &final_lut) - } - } - /// returns: /// /// - 0 if lhs < rhs @@ -449,523 +369,6 @@ impl<'a> Comparator<'a> { self.reduce_signs(comparisons, sign_result_handler_fn) } - /// Expects the carry buffers to be empty - /// - /// Requires that the RadixCiphertext block have 4 bits minimum (carry + message) - /// - /// This functions takes two integer ciphertext: - /// - /// It returns a Vec of block that will contain the sign of the comparison - /// (Self::IS_INFERIOR, Self::IS_EQUAL, Self::IS_SUPERIOR) - /// - /// The output len may be shorter as blocks may be packed - fn unchecked_compare_parallelized( - &self, - lhs: &T, - rhs: &T, - sign_result_handler_fn: F, - ) -> 
crate::shortint::Ciphertext - where - T: IntegerRadixCiphertext, - F: Fn(u64) -> u64, - { - assert_eq!(lhs.blocks().len(), rhs.blocks().len()); - - let num_block = lhs.blocks().len(); - - // false positive as compare_blocks does not mean the same in both branches - #[allow(clippy::branches_sharing_code)] - let compare_blocks_fn = - if lhs.blocks()[0].carry_modulus.0 < lhs.blocks()[0].message_modulus.0 { - /// Compares blocks in parallel - fn compare_blocks( - comparator: &Comparator, - lhs_blocks: &[crate::shortint::Ciphertext], - rhs_blocks: &[crate::shortint::Ciphertext], - out_comparisons: &mut Vec, - ) { - lhs_blocks - .par_iter() - .zip(rhs_blocks.par_iter()) - .map(|(lhs, rhs)| { - let mut lhs = lhs.clone(); - comparator.compare_block_assign(&mut lhs, rhs); - lhs - }) - .collect_into_vec(out_comparisons); - } - - compare_blocks - } else { - /// Compares blocks in parallel, using the fact that they can be packed - fn compare_blocks( - comparator: &Comparator, - lhs_blocks: &[crate::shortint::Ciphertext], - rhs_blocks: &[crate::shortint::Ciphertext], - out_comparisons: &mut Vec, - ) { - // After packing we have to clean the noise - let identity = comparator.server_key.key.generate_lookup_table(|x| x); - lhs_blocks - .par_chunks(2) - .zip(rhs_blocks.par_chunks(2)) - .map(|(lhs_chunk, rhs_chunk)| { - let (mut packed_lhs, packed_rhs) = rayon::join( - || { - let mut block = comparator.pack_block_chunk(lhs_chunk); - comparator - .server_key - .key - .apply_lookup_table_assign(&mut block, &identity); - block - }, - || { - let mut block = comparator.pack_block_chunk(rhs_chunk); - comparator - .server_key - .key - .apply_lookup_table_assign(&mut block, &identity); - block - }, - ); - - comparator.compare_block_assign(&mut packed_lhs, &packed_rhs); - packed_lhs - }) - .collect_into_vec(out_comparisons); - } - compare_blocks - }; - - let mut comparisons = Vec::with_capacity(num_block); - - if T::IS_SIGNED { - let (lhs_last_block, lhs_ls_blocks) = 
lhs.blocks().split_last().unwrap(); - let (rhs_last_block, rhs_ls_blocks) = rhs.blocks().split_last().unwrap(); - let (_, last_block_cmp) = rayon::join( - || { - compare_blocks_fn(self, lhs_ls_blocks, rhs_ls_blocks, &mut comparisons); - }, - || self.compare_blocks_with_sign_bit(lhs_last_block, rhs_last_block), - ); - - comparisons.push(last_block_cmp); - } else { - compare_blocks_fn(self, lhs.blocks(), rhs.blocks(), &mut comparisons); - } - - self.reduce_signs_parallelized(comparisons, sign_result_handler_fn) - } - - /// This functions takes two slices: - /// - /// - one of encrypted blocks - /// - the other of scalar to compare to each encrypted block - /// - /// It returns a Vec of block that will contain the sign of the comparison - /// (Self::IS_INFERIOR, Self::IS_EQUAL, Self::IS_SUPERIOR) - /// - /// The output len is half the input len as blocks will be packed - fn unchecked_scalar_block_slice_compare_parallelized( - &self, - lhs_blocks: &[Ciphertext], - scalar_blocks: &[u8], - ) -> Vec { - assert_eq!(lhs_blocks.len(), scalar_blocks.len()); - let num_blocks = lhs_blocks.len(); - let num_blocks_halved = (num_blocks / 2) + (num_blocks % 2); - - let message_modulus = self.server_key.key.message_modulus.0; - let mut signs = Vec::with_capacity(num_blocks_halved); - lhs_blocks - .par_chunks(2) - .zip(scalar_blocks.par_chunks(2)) - .map(|(lhs_chunk, scalar_chunk)| { - let packed_scalar = scalar_chunk[0] - + (scalar_chunk.get(1).copied().unwrap_or(0) * message_modulus as u8); - let mut packed_lhs = self.pack_block_chunk(lhs_chunk); - self.scalar_compare_block_assign(&mut packed_lhs, packed_scalar); - packed_lhs - }) - .collect_into_vec(&mut signs); - - signs - } - - /// Computes the sign of an integer ciphertext with a clear value - /// - /// * The ciphertext can be unsigned or signed - /// * The clear value can be positive or negative - fn unchecked_scalar_compare_parallelized( - &self, - lhs: &T, - rhs: Scalar, - sign_result_handler_fn: F, - ) -> Ciphertext - where 
- T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - F: Fn(u64) -> u64, - { - if T::IS_SIGNED { - match self.server_key.is_scalar_out_of_bounds(lhs, rhs) { - Some(std::cmp::Ordering::Greater) => { - // Scalar is greater than the bounds, so ciphertext is smaller - return self - .server_key - .key - .create_trivial(sign_result_handler_fn(Self::IS_INFERIOR)); - } - Some(std::cmp::Ordering::Less) => { - // Scalar is smaller than the bounds, so ciphertext is bigger - return self - .server_key - .key - .create_trivial(sign_result_handler_fn(Self::IS_SUPERIOR)); - } - Some(std::cmp::Ordering::Equal) => unreachable!("Internal error: invalid value"), - None => { - // scalar is in range, fallthrough - } - } - - if rhs >= Scalar::ZERO { - self.signed_unchecked_scalar_compare_with_positive_scalar_parallelized( - lhs.blocks(), - rhs, - sign_result_handler_fn, - ) - } else { - let scalar_as_trivial: T = self - .server_key - .create_trivial_radix(rhs, lhs.blocks().len()); - self.unchecked_compare_parallelized(lhs, &scalar_as_trivial, sign_result_handler_fn) - } - } else { - self.unsigned_unchecked_scalar_compare_blocks_parallelized( - lhs.blocks(), - rhs, - sign_result_handler_fn, - ) - } - } - - /// This function computes the sign of a signed integer ciphertext with - /// a positive clear value - /// - /// Scalar **must** be >= 0 - /// Scalar must be <= the max value possible for lhs - fn signed_unchecked_scalar_compare_with_positive_scalar_parallelized( - &self, - lhs_blocks: &[Ciphertext], - rhs: Scalar, - sign_result_handler_fn: F, - ) -> Ciphertext - where - Scalar: DecomposableInto, - F: Fn(u64) -> u64, - { - assert!(!lhs_blocks.is_empty()); - assert!(rhs >= Scalar::ZERO); - - let message_modulus = self.server_key.key.message_modulus.0; - - let scalar_blocks = BlockDecomposer::with_early_stop_at_zero(rhs, message_modulus.ilog2()) - .iter_as::() - .map(|x| x as u8) - .take(lhs_blocks.len()) - .collect::>(); - - let (least_significant_blocks, most_significant_blocks) = 
- lhs_blocks.split_at(scalar_blocks.len()); - - match ( - least_significant_blocks.is_empty(), - most_significant_blocks.is_empty(), - ) { - (false, false) => { - // We have to handle both part of the work - // And the sign bit is located in the most_significant_blocks - let (lsb_sign, msb_sign) = rayon::join( - || { - let lsb_signs = self.unchecked_scalar_block_slice_compare_parallelized( - least_significant_blocks, - &scalar_blocks[..least_significant_blocks.len()], - ); - self.reduce_signs_parallelized(lsb_signs, |x| x) - }, - || { - // most_significant_blocks has the sign block, which must be handled - let are_all_msb_equal_to_zero = self.server_key.are_all_blocks_zero( - &most_significant_blocks[..most_significant_blocks.len() - 1], - ); - let sign_bit_pos = self.server_key.key.message_modulus.0.ilog2() - 1; - - // This LUT is the fusion of manu LUT. - // It defines the ordering given by the sign block (known scalar is >= 0) - // It defines the ordering given by the result of are_all_msb_blocks_zeros - // and finally reduces these two previous ordering to produce one final - // ordering for the considered blocks - let lut = self.server_key.key.generate_lookup_table_bivariate( - |sign_block, msb_are_zeros| { - let sign_bit_is_set = (sign_block >> sign_bit_pos) == 1; - let sign_block_ordering = if sign_bit_is_set { - Self::IS_INFERIOR - } else if sign_block != 0 { - Self::IS_SUPERIOR - } else { - Self::IS_EQUAL - }; - let msb_ordering = if msb_are_zeros == 1 { - Self::IS_EQUAL - } else { - Self::IS_SUPERIOR - }; - - reduce_two_orderings_function( - sign_block_ordering << 2 | msb_ordering, - ) - }, - ); - self.server_key.key.unchecked_apply_lookup_table_bivariate( - most_significant_blocks.last().as_ref().unwrap(), - &are_all_msb_equal_to_zero, - &lut, - ) - }, - ); - - // parallelized not needed as there are 2 blocks - self.reduce_signs(vec![lsb_sign, msb_sign], sign_result_handler_fn) - } - (false, true) => { - // This means lhs_blocks.len() == 
scalar_blocks.len() - // We have to do only the regular block comparisons. - // We again have to split in two, to handle the sign bit - let n = least_significant_blocks.len(); - let (mut signs, sign_block_sign) = rayon::join( - || { - self.unchecked_scalar_block_slice_compare_parallelized( - &least_significant_blocks[..n - 1], - &scalar_blocks[..n - 1], - ) - }, - || { - let trivial_sign_block = self - .server_key - .key - .create_trivial(scalar_blocks[n - 1] as u64); - self.compare_blocks_with_sign_bit( - &least_significant_blocks[n - 1], - &trivial_sign_block, - ) - }, - ); - - signs.push(sign_block_sign); - self.reduce_signs_parallelized(signs, sign_result_handler_fn) - } - (true, false) => { - // We only have to compare blocks with zero - // (means scalar is zero) - // This also means scalar_block.len() == 0, i.e scalar == 0 - - let n = most_significant_blocks.len(); - let are_all_msb_zeros = self - .server_key - .are_all_blocks_zero(&most_significant_blocks[..n - 1]); - - let sign_bit_pos = self.server_key.key.message_modulus.0.ilog2() - 1; - - let sign_block_ordering_with_respect_to_zero = |sign_block| { - let sign_block = sign_block % self.server_key.key.message_modulus.0 as u64; - let sign_bit_is_set = (sign_block >> sign_bit_pos) == 1; - if sign_bit_is_set { - Self::IS_INFERIOR - } else if sign_block != 0 { - Self::IS_SUPERIOR - } else { - Self::IS_EQUAL - } - }; - - // Fuse multiple LUT into one - // Use the previously defined function to get ordering of sign block - // Define ordering for comparison with zero - // reduce these two ordering to get the final one - let lut = self.server_key.key.generate_lookup_table_bivariate( - |are_all_zeros, sign_block| { - // "re-code" are_all_zeros as an ordering value - let are_all_zeros = if are_all_zeros == 1 { - Self::IS_EQUAL - } else { - Self::IS_SUPERIOR - }; - - let x = (sign_block_ordering_with_respect_to_zero(sign_block) << 2) - + are_all_zeros; - sign_result_handler_fn(reduce_two_orderings_function(x)) - }, 
- ); - self.server_key.key.unchecked_apply_lookup_table_bivariate( - &are_all_msb_zeros, - &most_significant_blocks[n - 1], - &lut, - ) - } - (true, true) => { - // assert should have been hit earlier - unreachable!("Empty input ciphertext") - } - } - } - - /// This function computes the sign of a unsigned integer ciphertext - /// with a clear value. - /// - /// * lhs_blocks **must** represent positive values - /// * rhs can be positive of negative - fn unsigned_unchecked_scalar_compare_blocks_parallelized( - &self, - lhs_blocks: &[Ciphertext], - rhs: Scalar, - sign_result_handler_fn: F, - ) -> Ciphertext - where - Scalar: DecomposableInto, - F: Fn(u64) -> u64, - { - assert!(!lhs_blocks.is_empty()); - - if rhs < Scalar::ZERO { - // lhs_blocks represent an unsigned (always >= 0) - return self - .server_key - .key - .create_trivial(sign_result_handler_fn(Self::IS_SUPERIOR)); - } - - let message_modulus = self.server_key.key.message_modulus.0; - - let mut scalar_blocks = - BlockDecomposer::with_early_stop_at_zero(rhs, message_modulus.ilog2()) - .iter_as::() - .map(|x| x as u8) - .collect::>(); - - // scalar is obviously bigger if it has non-zero - // blocks after lhs's last block - let is_scalar_obviously_bigger = scalar_blocks - .get(lhs_blocks.len()..) - .is_some_and(|sub_slice| sub_slice.iter().any(|&scalar_block| scalar_block != 0)); - if is_scalar_obviously_bigger { - return self - .server_key - .key - .create_trivial(sign_result_handler_fn(Self::IS_INFERIOR)); - } - // If we are sill here, that means scalar_blocks above - // num_blocks are 0s, we can remove them - // as we will handle them separately. - scalar_blocks.truncate(lhs_blocks.len()); - - let (least_significant_blocks, most_significant_blocks) = - lhs_blocks.split_at(scalar_blocks.len()); - - // Reducing the signs is the bottleneck of the comparison algorithms, - // however if the scalar case there is an improvement: - // - // The idea is to reduce the number of signs block we have to - // reduce. 
We can do that by splitting the comparison problem in two parts. - // - // - One part where we compute the signs block between the scalar with just enough blocks - // from the ciphertext that can represent the scalar value - // - // - The other part is to compare the ciphertext blocks not considered for the sign - // computation with zero, and create a single sign block from that. - // - // The smaller the scalar value is compared to the ciphertext num bits encrypted, - // the more the comparisons with zeros we have to do, - // and the less signs block we will have to reduce. - // - // This will create a speedup as comparing a bunch of blocks with 0 - // is faster - - match ( - least_significant_blocks.is_empty(), - most_significant_blocks.is_empty(), - ) { - (false, false) => { - // We have to handle both part of the work described above - - let (lsb_sign, are_all_msb_equal_to_zero) = rayon::join( - || { - let lsb_signs = self.unchecked_scalar_block_slice_compare_parallelized( - least_significant_blocks, - &scalar_blocks, - ); - self.reduce_signs_parallelized(lsb_signs, |x| x) - }, - || self.server_key.are_all_blocks_zero(most_significant_blocks), - ); - - // Reduce the two blocks into one final - let lut = self - .server_key - .key - .generate_lookup_table_bivariate(|lsb, msb| { - // "re-code" are_all_msb_equal_to_zero as an ordering value - let msb = if msb == 1 { - Self::IS_EQUAL - } else { - Self::IS_SUPERIOR - }; - - let x = (msb << 2) + lsb; - let final_sign = reduce_two_orderings_function(x); - sign_result_handler_fn(final_sign) - }); - - self.server_key.key.unchecked_apply_lookup_table_bivariate( - &lsb_sign, - &are_all_msb_equal_to_zero, - &lut, - ) - } - (false, true) => { - // We only have to do the regular comparison - // And not the part where we compare most significant blocks with zeros - let signs = self.unchecked_scalar_block_slice_compare_parallelized( - least_significant_blocks, - &scalar_blocks, - ); - self.reduce_signs_parallelized(signs, 
sign_result_handler_fn) - } - (true, false) => { - // We only have to compare blocks with zero - // means scalar is zero - let are_all_msb_equal_to_zero = - self.server_key.are_all_blocks_zero(most_significant_blocks); - let lut = self.server_key.key.generate_lookup_table(|x| { - // "re-code" x as an ordering value - let x = if x == 1 { - Self::IS_EQUAL - } else { - Self::IS_SUPERIOR - }; - sign_result_handler_fn(x) - }); - self.server_key - .key - .apply_lookup_table(&are_all_msb_equal_to_zero, &lut) - } - (true, true) => { - // assert should have been hit earlier - unreachable!("Empty input ciphertext") - } - } - } - fn smart_compare( &self, lhs: &mut T, @@ -1016,43 +419,6 @@ impl<'a> Comparator<'a> { T::from_blocks(result) } - fn unchecked_scalar_min_or_max_parallelized( - &self, - lhs: &T, - rhs: Scalar, - selector: MinMaxSelector, - ) -> T - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - let sign = self.unchecked_scalar_compare_parallelized(lhs, rhs, |x| x); - let rhs = self - .server_key - .create_trivial_radix(rhs, lhs.blocks().len()); - let do_clean_message = true; - match selector { - MinMaxSelector::Max => self - .server_key - .unchecked_programmable_if_then_else_parallelized( - &sign, - lhs, - &rhs, - |sign| sign == Self::IS_SUPERIOR, - do_clean_message, - ), - MinMaxSelector::Min => self - .server_key - .unchecked_programmable_if_then_else_parallelized( - &sign, - lhs, - &rhs, - |sign| sign == Self::IS_INFERIOR, - do_clean_message, - ), - } - } - fn smart_min_or_max(&self, lhs: &mut T, rhs: &mut T, selector: MinMaxSelector) -> T where T: IntegerRadixCiphertext, @@ -1120,52 +486,6 @@ impl<'a> Comparator<'a> { self.unchecked_min_or_max(lhs, rhs, MinMaxSelector::Min) } - //====================================== - // Unchecked Multi-Threaded operations - //====================================== - - pub fn unchecked_gt_parallelized(&self, lhs: &T, rhs: &T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - 
self.server_key.unchecked_gt_parallelized(lhs, rhs) - } - - pub fn unchecked_ge_parallelized(&self, lhs: &T, rhs: &T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - self.server_key.unchecked_ge_parallelized(lhs, rhs) - } - - pub fn unchecked_lt_parallelized(&self, lhs: &T, rhs: &T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - self.server_key.unchecked_lt_parallelized(lhs, rhs) - } - - pub fn unchecked_le_parallelized(&self, lhs: &T, rhs: &T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - self.server_key.unchecked_le_parallelized(lhs, rhs) - } - - pub fn unchecked_max_parallelized(&self, lhs: &T, rhs: &T) -> T - where - T: IntegerRadixCiphertext, - { - self.server_key.unchecked_max_parallelized(lhs, rhs) - } - - pub fn unchecked_min_parallelized(&self, lhs: &T, rhs: &T) -> T - where - T: IntegerRadixCiphertext, - { - self.server_key.unchecked_min_parallelized(lhs, rhs) - } - //====================================== // Smart Single-Threaded operations //====================================== @@ -1219,348 +539,4 @@ impl<'a> Comparator<'a> { { self.smart_min_or_max(lhs, rhs, MinMaxSelector::Min) } - - //====================================== - // Smart Multi-Threaded operations - //====================================== - - pub fn smart_gt_parallelized(&self, lhs: &mut T, rhs: &mut T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - self.server_key.smart_gt_parallelized(lhs, rhs) - } - - pub fn smart_ge_parallelized(&self, lhs: &mut T, rhs: &mut T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - self.server_key.smart_ge_parallelized(lhs, rhs) - } - - pub fn smart_lt_parallelized(&self, lhs: &mut T, rhs: &mut T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - self.server_key.smart_lt_parallelized(lhs, rhs) - } - - pub fn smart_le_parallelized(&self, lhs: &mut T, rhs: &mut T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - self.server_key.smart_le_parallelized(lhs, rhs) - } - - pub fn 
smart_max_parallelized(&self, lhs: &mut T, rhs: &mut T) -> T - where - T: IntegerRadixCiphertext, - { - self.server_key.smart_max_parallelized(lhs, rhs) - } - - pub fn smart_min_parallelized(&self, lhs: &mut T, rhs: &mut T) -> T - where - T: IntegerRadixCiphertext, - { - self.server_key.smart_min_parallelized(lhs, rhs) - } - - //====================================== - // "Default" Multi-Threaded operations - //====================================== - - pub fn gt_parallelized(&self, lhs: &T, rhs: &T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - self.server_key.gt_parallelized(lhs, rhs) - } - - pub fn ge_parallelized(&self, lhs: &T, rhs: &T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - self.server_key.ge_parallelized(lhs, rhs) - } - - pub fn lt_parallelized(&self, lhs: &T, rhs: &T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - self.server_key.lt_parallelized(lhs, rhs) - } - - pub fn le_parallelized(&self, lhs: &T, rhs: &T) -> BooleanBlock - where - T: IntegerRadixCiphertext, - { - self.server_key.le_parallelized(lhs, rhs) - } - - pub fn max_parallelized(&self, lhs: &T, rhs: &T) -> T - where - T: IntegerRadixCiphertext, - { - self.server_key.max_parallelized(lhs, rhs) - } - - pub fn min_parallelized(&self, lhs: &T, rhs: &T) -> T - where - T: IntegerRadixCiphertext, - { - self.server_key.min_parallelized(lhs, rhs) - } - - //=========================================== - // Unchecked Scalar Multi-Threaded operations - //=========================================== - - /// This functions calls the unchecked comparison function - /// which returns whether lhs is inferior, equal or greater than rhs, - /// and maps the result to a homomorphic bool value (0 or 1) using the provided function. 
- pub fn unchecked_scalar_compare_parallelized_handler( - &self, - lhs: &T, - rhs: Scalar, - sign_result_handler_fn: F, - ) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - F: Fn(u64) -> u64 + Sync, - { - let comparison = - self.unchecked_scalar_compare_parallelized(lhs, rhs, sign_result_handler_fn); - BooleanBlock::new_unchecked(comparison) - } - - pub fn unchecked_scalar_gt_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.unchecked_scalar_compare_parallelized_handler(lhs, rhs, |x| { - u64::from(x == Self::IS_SUPERIOR) - }) - } - - pub fn unchecked_scalar_ge_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.unchecked_scalar_compare_parallelized_handler(lhs, rhs, |x| { - u64::from(x == Self::IS_SUPERIOR || x == Self::IS_EQUAL) - }) - } - - pub fn unchecked_scalar_lt_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.unchecked_scalar_compare_parallelized_handler(lhs, rhs, |x| { - u64::from(x == Self::IS_INFERIOR) - }) - } - - pub fn unchecked_scalar_le_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.unchecked_scalar_compare_parallelized_handler(lhs, rhs, |x| { - u64::from(x == Self::IS_INFERIOR || x == Self::IS_EQUAL) - }) - } - - pub fn unchecked_scalar_max_parallelized(&self, lhs: &T, rhs: Scalar) -> T - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.unchecked_scalar_min_or_max_parallelized(lhs, rhs, MinMaxSelector::Max) - } - - pub fn unchecked_scalar_min_parallelized(&self, lhs: &T, rhs: Scalar) -> T - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.unchecked_scalar_min_or_max_parallelized(lhs, rhs, MinMaxSelector::Min) - } - - 
//======================================= - // Smart Scalar Multi-Threaded operations - //======================================= - - fn smart_scalar_compare_parallelized( - &self, - lhs: &mut T, - rhs: Scalar, - sign_result_handler_fn: F, - ) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - F: Fn(u64) -> u64 + Sync, - { - if !lhs.block_carries_are_empty() { - self.server_key.full_propagate_parallelized(lhs); - } - self.unchecked_scalar_compare_parallelized_handler(lhs, rhs, sign_result_handler_fn) - } - - pub fn smart_scalar_gt_parallelized(&self, lhs: &mut T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.smart_scalar_compare_parallelized(lhs, rhs, |x| u64::from(x == Self::IS_SUPERIOR)) - } - - pub fn smart_scalar_ge_parallelized(&self, lhs: &mut T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.smart_scalar_compare_parallelized(lhs, rhs, |x| { - u64::from(x == Self::IS_SUPERIOR || x == Self::IS_EQUAL) - }) - } - - pub fn smart_scalar_lt_parallelized(&self, lhs: &mut T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.smart_scalar_compare_parallelized(lhs, rhs, |x| u64::from(x == Self::IS_INFERIOR)) - } - - pub fn smart_scalar_le_parallelized(&self, lhs: &mut T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.smart_scalar_compare_parallelized(lhs, rhs, |x| { - u64::from(x == Self::IS_INFERIOR || x == Self::IS_EQUAL) - }) - } - - pub fn smart_scalar_max_parallelized(&self, lhs: &mut T, rhs: Scalar) -> T - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - if !lhs.block_carries_are_empty() { - self.server_key.full_propagate_parallelized(lhs); - } - self.unchecked_scalar_min_or_max_parallelized(lhs, rhs, MinMaxSelector::Max) - } - - pub fn smart_scalar_min_parallelized(&self, lhs: &mut T, 
rhs: Scalar) -> T - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - if !lhs.block_carries_are_empty() { - self.server_key.full_propagate_parallelized(lhs); - } - self.unchecked_scalar_min_or_max_parallelized(lhs, rhs, MinMaxSelector::Min) - } - - //====================================== - // "Default" Scalar Multi-Threaded operations - //====================================== - - fn default_scalar_compare_parallelized( - &self, - lhs: &T, - rhs: Scalar, - sign_result_handler_fn: F, - ) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - F: Fn(u64) -> u64 + Sync, - { - let mut tmp_lhs; - let lhs = if lhs.block_carries_are_empty() { - lhs - } else { - tmp_lhs = lhs.clone(); - self.server_key.full_propagate_parallelized(&mut tmp_lhs); - &tmp_lhs - }; - self.unchecked_scalar_compare_parallelized_handler(lhs, rhs, sign_result_handler_fn) - } - - pub fn scalar_gt_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.default_scalar_compare_parallelized(lhs, rhs, |x| u64::from(x == Self::IS_SUPERIOR)) - } - - pub fn scalar_ge_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.default_scalar_compare_parallelized(lhs, rhs, |x| { - u64::from(x == Self::IS_SUPERIOR || x == Self::IS_EQUAL) - }) - } - - pub fn scalar_lt_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.default_scalar_compare_parallelized(lhs, rhs, |x| u64::from(x == Self::IS_INFERIOR)) - } - - pub fn scalar_le_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - self.default_scalar_compare_parallelized(lhs, rhs, |x| { - u64::from(x == Self::IS_INFERIOR || x == Self::IS_EQUAL) - }) - } - - pub fn scalar_max_parallelized(&self, lhs: &T, rhs: 
Scalar) -> T - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - let mut tmp_lhs; - let lhs = if lhs.block_carries_are_empty() { - lhs - } else { - tmp_lhs = lhs.clone(); - self.server_key.full_propagate_parallelized(&mut tmp_lhs); - &tmp_lhs - }; - self.unchecked_scalar_min_or_max_parallelized(lhs, rhs, MinMaxSelector::Max) - } - - pub fn scalar_min_parallelized(&self, lhs: &T, rhs: Scalar) -> T - where - T: IntegerRadixCiphertext, - Scalar: DecomposableInto, - { - let mut tmp_lhs; - let lhs = if lhs.block_carries_are_empty() { - lhs - } else { - tmp_lhs = lhs.clone(); - self.server_key.full_propagate_parallelized(&mut tmp_lhs); - &tmp_lhs - }; - self.unchecked_scalar_min_or_max_parallelized(lhs, rhs, MinMaxSelector::Min) - } } diff --git a/tfhe/src/integer/server_key/radix_parallel/cmux.rs b/tfhe/src/integer/server_key/radix_parallel/cmux.rs index 49edb0bda7..16f49b7599 100644 --- a/tfhe/src/integer/server_key/radix_parallel/cmux.rs +++ b/tfhe/src/integer/server_key/radix_parallel/cmux.rs @@ -169,16 +169,15 @@ impl ServerKeyDefaultCMux<&BooleanBlock, &BooleanBlock> for ServerKey { }); let negated_cond = self.boolean_bitnot(condition); - let (mut lhs, rhs) = rayon::join( || { - let mut block = self.key.scalar_mul(&condition.0, 2); + let mut block = self.key.unchecked_scalar_mul(&condition.0, 2); self.key.unchecked_add_assign(&mut block, &true_ct.0); self.key.apply_lookup_table_assign(&mut block, &zero_lut); block }, || { - let mut block = self.key.scalar_mul(&negated_cond.0, 2); + let mut block = self.key.unchecked_scalar_mul(&negated_cond.0, 2); self.key.unchecked_add_assign(&mut block, &false_ct.0); self.key.apply_lookup_table_assign(&mut block, &zero_lut); block diff --git a/tfhe/src/integer/server_key/radix_parallel/comparison.rs b/tfhe/src/integer/server_key/radix_parallel/comparison.rs index b23ce60562..965f473e0a 100644 --- a/tfhe/src/integer/server_key/radix_parallel/comparison.rs +++ 
b/tfhe/src/integer/server_key/radix_parallel/comparison.rs @@ -2,30 +2,43 @@ use super::ServerKey; use crate::core_crypto::prelude::UnsignedInteger; use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::IntegerRadixCiphertext; -use crate::shortint::ciphertext::NoiseLevel; +use crate::integer::prelude::ServerKeyDefaultCMux; use crate::shortint::{Ciphertext, MessageModulus}; use rayon::prelude::*; #[derive(Debug, Copy, Clone)] -enum ComparisonKind { +pub(crate) enum ComparisonKind { Less, LessOrEqual, Greater, GreaterOrEqual, } -/// Given the last block of 2 _signed_ numbers x and y, and a borrow (0 or 1) +/// This blocks contains part of the information necessary to conclude, for signed ciphertext +/// it just needs some input borrow +/// +/// There are 2 possibilities: /// -/// Requires MessageModulus > 2 +/// * If a block can encrypt at least 4 bits (carry + msg) then the block contains information that +/// will allow determining the result of x < y in one PBS +/// * Otherwise the information is split in 2 blocks and a cmux will be required later +pub(crate) enum PreparedSignedCheck { + // The information could be coded on 1 block + // because a block store at least 4 bits of information + Unified(Ciphertext), + // The information had to be split + Split((Ciphertext, Ciphertext)), +} + +/// Given the last block of 2 _signed_ numbers x and y, and a borrow (0 or 1) /// /// returns whether x < y -fn is_x_less_than_y_given_input_borrow( +pub(crate) fn is_x_less_than_y_given_input_borrow( last_x_block: u64, last_y_block: u64, borrow: u64, message_modulus: MessageModulus, ) -> u64 { - assert!(message_modulus.0 > 2, "This requires MessageModulus > 2"); let last_bit_pos = message_modulus.0.ilog2() - 1; let mask = (1 << last_bit_pos) - 1; @@ -270,7 +283,7 @@ impl ServerKey { // group borrows and simulator of last block let ( - (mut group_borrows, use_sequential_algorithm_to_resolved_grouping_carries), + (group_borrows, 
use_sequential_algorithm_to_resolve_grouping_carries), maybe_prepared_signed_check, ) = rayon::join( || { @@ -301,205 +314,163 @@ impl ServerKey { (b1 << 1 | b0) << 2 }); - Some(self.key.apply_lookup_table_bivariate( - lhs.blocks().last().unwrap(), - rhs.blocks().last().unwrap(), - &lut, + Some(PreparedSignedCheck::Unified( + self.key.apply_lookup_table_bivariate( + lhs.blocks().last().unwrap(), + rhs.blocks().last().unwrap(), + &lut, + ), )) } else if T::IS_SIGNED { - // When we have just 2 bits (message and carry included) - // we will have to do more work. - // This step is preparing a block that will be used to compute the output borrow - // of the whole subtraction - let message_modulus = self.message_modulus().0 as u64; - let lut = self.key.generate_lookup_table_bivariate(|x, y| { - let value = x.wrapping_sub(y).wrapping_add(message_modulus); - - #[allow(clippy::comparison_chain)] - if value < message_modulus { - 2 << 1 - } else if value == message_modulus { - 1 << 1 - } else { - 0 - } - }); - - Some(self.key.apply_lookup_table_bivariate( - lhs.blocks().last().unwrap(), - rhs.blocks().last().unwrap(), - &lut, - )) + Some(PreparedSignedCheck::Split(rayon::join( + || { + let lut = self.key.generate_lookup_table_bivariate(|x, y| { + is_x_less_than_y_given_input_borrow(x, y, 1, self.message_modulus()) + }); + self.key.apply_lookup_table_bivariate( + lhs.blocks().last().unwrap(), + rhs.blocks().last().unwrap(), + &lut, + ) + }, + || { + let lut = self.key.generate_lookup_table_bivariate(|x, y| { + is_x_less_than_y_given_input_borrow(x, y, 0, self.message_modulus()) + }); + self.key.apply_lookup_table_bivariate( + lhs.blocks().last().unwrap(), + rhs.blocks().last().unwrap(), + &lut, + ) + }, + ))) } else { None } }, ); - // This blocks contains part of the information necessary to conclude, it just needs - // some input borrow - // There are 3 possibilities: - // - // * If the ciphertext is unsigned, it contains the information that will allow determining - // the 
output borrow - // * If the ciphertext is signed and a block can encrypt at least 4 bits (carry + msg) then - // the block contains information that will allow determining the result of x < y in one - // PBS - // * If the ciphertext is signed and a block can encrypt 2 bits (msg + carry) then the block - // contains information that will allow computing the output borrow, which will then be - // used to get the overflow flag then the final result - let mut result_block = group_borrows.pop().unwrap(); - if let Some(block) = maybe_prepared_signed_check { - self.key.unchecked_add_assign(&mut result_block, &block); - } + self.finish_comparison( + group_borrows, + grouping_size, + use_sequential_algorithm_to_resolve_grouping_carries, + maybe_prepared_signed_check, + invert_subtraction_result, + ) + } + + pub(crate) fn finish_comparison( + &self, + mut group_borrows: Vec, + grouping_size: usize, + use_sequential_algorithm_to_resolve_grouping_carries: bool, + maybe_prepared_signed_check: Option, + invert_result: bool, + ) -> BooleanBlock { + let mut last_group_borrow_state = group_borrows.pop().unwrap(); // Third step: resolving borrow propagation between the groups let resolved_borrows = if group_borrows.is_empty() { // There was only one group, and the borrow generated by this group // has already been added to the `overflow_block`, just earlier - if T::IS_SIGNED { + if maybe_prepared_signed_check.is_some() { // There is still one step to determine the result of the comparison // being done further down. 
// It will require an input borrow for the last group // which is 0 here because there was only one group thus, // the last group is the same as the first group, // and the input borrow of the first group is 0 - vec![self.key.create_trivial(0)] + vec![] } else { // When unsigned, the result is already known at this point - return BooleanBlock::new_unchecked(result_block); + return BooleanBlock::new_unchecked(last_group_borrow_state); } - } else if use_sequential_algorithm_to_resolved_grouping_carries { + } else if use_sequential_algorithm_to_resolve_grouping_carries { self.resolve_carries_of_groups_sequentially(group_borrows, grouping_size) } else { self.resolve_carries_of_groups_using_hillis_steele(group_borrows) }; - if T::IS_SIGNED && self.message_modulus().0 > 2 { - // For signed numbers its less direct to do lhs < rhs using subtraction - // fortunately when we have at least 4 bits we can encode all the needed information - // in one block and conclude in 1 PBS - self.key - .unchecked_add_assign(&mut result_block, resolved_borrows.last().unwrap()); - let lut = self.key.generate_lookup_table(|block| { - // If `resolved_borrows.len() == 1`, then group_borrows was empty, - // This means 2 things: - // * The overflow block already contains the borrow - // * But the position of the borrow is one less bit further - let index = if resolved_borrows.len() == 1 { 0 } else { 1 }; - let input_borrow = (block >> index) & 1; - - // Here, depending on the input borrow, we retrieve - // the bit that tells us if lhs < rhs - let r = if input_borrow == 1 { - (block >> 3) & 1 - } else { - (block >> 2) & 1 - }; - u64::from(invert_subtraction_result) ^ r - }); - - self.key.apply_lookup_table_assign(&mut result_block, &lut); - - BooleanBlock::new_unchecked(result_block) - } else if T::IS_SIGNED { - // Here, message_modulus == 2 (1 bit of message), 2 bits in a block - // Se we don't have enough bits to store all the needed stuff, thus - // we have to do a few more PBS to get the 
result of lhs < rhs + match maybe_prepared_signed_check { + None => { + // For unsigned numbers, if the last block borrows, then the subtraction + // overflowed, which directly means lhs < rhs + self.key.unchecked_add_assign( + &mut last_group_borrow_state, + // For unsigned, we know that if we are here, + // resolved_borrows is not empty + resolved_borrows.last().unwrap(), + ); + let lut = self.key.generate_lookup_table(|block| { + let overflowed = (block >> 1) & 1; + u64::from(invert_result) ^ overflowed + }); - let input_borrow = resolved_borrows.last().unwrap(); - let (mut shifted_output_borrow, mut new_sign_bit) = rayon::join( - || { - self.key - .unchecked_add_assign(&mut result_block, input_borrow); - if resolved_borrows.len() == 1 { - // There was one group, so the input borrow is not properly positioned - // for the next steps to work, so we add the clear value 1, this will - // push the borrow bit if there was one - self.key.unchecked_scalar_add_assign(&mut result_block, 1); - } + self.key + .apply_lookup_table_assign(&mut last_group_borrow_state, &lut); - // This exploits the fact that the padding of the input bit will be set if - // a borrow is generated, the lut always returns -1: - // If the padding bit is set: it will return -(-1) = 1 - // If it's not set: it will return -1 - // - // We then add 1, so the possible values are: - // * 2 if a borrow was generated - // * 0 otherwise - // - // We use the fact that the borrow bit is at index 1 a bit later - let lut = self.key.generate_lookup_table(|_| { - // return -1 coded on 3 bits (1 message, 1 carry, 1 padding) - 0b111 - }); - let mut shifted_output_borrow = - self.key.apply_lookup_table(&result_block, &lut); + BooleanBlock::new_unchecked(last_group_borrow_state) + } + Some(PreparedSignedCheck::Unified(ct)) => { + // For signed numbers its less direct to do lhs < rhs using subtraction + // fortunately when we have at least 4 bits we can encode all the needed information + // in one block and conclude 
in 1 PBS + if let Some(input_borrow) = resolved_borrows.last() { self.key - .unchecked_scalar_add_assign(&mut shifted_output_borrow, 1); - shifted_output_borrow - }, - || { - let mut sub_of_last_blocks = sub_blocks.last().cloned().unwrap(); - crate::core_crypto::prelude::lwe_ciphertext_sub_assign( - &mut sub_of_last_blocks.ct, - &input_borrow.ct, - ); - // Degree does not change as we do a subtraction, so worst case we subtract 0 - // which does not change the degree - sub_of_last_blocks - .set_noise_level(sub_of_last_blocks.noise_level + input_borrow.noise_level); - self.key.message_extract_assign(&mut sub_of_last_blocks); - sub_of_last_blocks - }, - ); - - let overflow_flag_lut = self.key.generate_lookup_table(|x| { - let output_borrow = (x >> 1) & 1; - let input_borrow = x & 1; + .unchecked_add_assign(&mut last_group_borrow_state, input_borrow); + } - input_borrow ^ output_borrow - }); - self.key - .unchecked_add_assign(&mut shifted_output_borrow, input_borrow); - self.key - .apply_lookup_table_assign(&mut shifted_output_borrow, &overflow_flag_lut); - let overflow_flag = shifted_output_borrow; // Rename - - // Since blocks have one bit of message, the new last block is also the new sign bit - let lut = self - .key - .generate_lookup_table_bivariate(|new_sign_bit, overflow_flag| { - u64::from(invert_subtraction_result) ^ (new_sign_bit ^ overflow_flag) + self.key + .unchecked_add_assign(&mut last_group_borrow_state, &ct); + let lut = self.key.generate_lookup_table(|block| { + // The overflow block already contains the borrow, + // but the position of the borrow is one less bit further + let index = if resolved_borrows.is_empty() { 0 } else { 1 }; + let input_borrow = (block >> index) & 1; + + // Here, depending on the input borrow, we retrieve + // the bit that tells us if lhs < rhs + let r = if input_borrow == 1 { + (block >> 3) & 1 + } else { + (block >> 2) & 1 + }; + u64::from(invert_result) ^ r }); - assert!(new_sign_bit.noise_level <= NoiseLevel::NOMINAL); - 
assert!(overflow_flag.noise_level <= NoiseLevel::NOMINAL); - self.key.unchecked_apply_lookup_table_bivariate_assign( - &mut new_sign_bit, - &overflow_flag, - &lut, - ); - - BooleanBlock::new_unchecked(new_sign_bit) - } else { - // For unsigned numbers, if the last block borrows, then the subtraction - // overflowed, which directly means lhs < rhs - self.key - .unchecked_add_assign(&mut result_block, resolved_borrows.last().unwrap()); - let lut = self.key.generate_lookup_table(|block| { - let overflowed = (block >> 1) & 1; - u64::from(invert_subtraction_result) ^ overflowed - }); + self.key + .apply_lookup_table_assign(&mut last_group_borrow_state, &lut); - self.key.apply_lookup_table_assign(&mut result_block, &lut); + BooleanBlock::new_unchecked(last_group_borrow_state) + } + Some(PreparedSignedCheck::Split((if_input_borrow_is_1, if_input_borrow_is_0))) => { + if let Some(input_borrow) = resolved_borrows.last() { + self.key + .unchecked_add_assign(&mut last_group_borrow_state, input_borrow); + let lut = self.key.generate_lookup_table(|x| (x >> 1) & 1); + self.key + .apply_lookup_table_assign(&mut last_group_borrow_state, &lut); + } - BooleanBlock::new_unchecked(result_block) + let if_input_borrow_is_1 = BooleanBlock::new_unchecked(if_input_borrow_is_1); + let if_input_borrow_is_0 = BooleanBlock::new_unchecked(if_input_borrow_is_0); + let condition = BooleanBlock::new_unchecked(last_group_borrow_state); + let result = self.if_then_else_parallelized( + &condition, + &if_input_borrow_is_1, + &if_input_borrow_is_0, + ); + if invert_result { + self.boolean_bitnot(&result) + } else { + result + } + } } } /// The invert_result boolean is only used when there is one and only one group - fn compute_group_borrow_state( + pub(crate) fn compute_group_borrow_state( &self, invert_result: bool, grouping_size: usize, diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs index 
4d712f53f2..6c1a684dcc 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs @@ -1,28 +1,30 @@ use super::ServerKey; +use crate::core_crypto::prelude::{lwe_ciphertext_sub_assign, Numeric}; use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::IntegerRadixCiphertext; -use crate::integer::server_key::comparator::{Comparator, ZeroComparisonType}; +use crate::integer::server_key::comparator::ZeroComparisonType; +use crate::integer::server_key::radix_parallel::comparison::{ + is_x_less_than_y_given_input_borrow, ComparisonKind, PreparedSignedCheck, +}; use crate::shortint::ciphertext::Degree; use crate::shortint::server_key::LookupTableOwned; -use crate::shortint::Ciphertext; +use crate::shortint::{Ciphertext, MessageModulus}; use rayon::prelude::*; impl ServerKey { /// Returns whether the clear scalar is outside of the /// value range the ciphertext can hold. /// - /// - Returns None if the scalar is in the range of values that the ciphertext can represent - /// - /// - Returns Some(ordering) when the scalar is out of representable range of the ciphertext. 
- /// - Equal will never be returned - /// - Less means the scalar is less than the min value representable by the ciphertext - /// - Greater means the scalar is greater that the max value representable by the ciphertext + /// - Returns an ordering: + /// - Equal means the scalar is in the range of values that the ciphertext can represent + /// - Less means the scalar is less than the min value representable by the ciphertext + /// - Greater means the scalar is greater than the max value representable by the ciphertext pub(crate) fn is_scalar_out_of_bounds( &self, ct: &T, scalar: Scalar, - ) -> Option + ) -> std::cmp::Ordering where T: IntegerRadixCiphertext, Scalar: DecomposableInto, @@ -46,17 +48,17 @@ impl ServerKey { // it means scalar is bigger. // // This is checked in two step - // - If there a more scalar blocks than ct blocks then ct is trivially bigger + // - If there are more scalar blocks than ct blocks then scalar is trivially bigger // - If there are the same number of blocks but the "sign bit" / msb of st scalar is // set then, the scalar is trivially bigger - return Some(std::cmp::Ordering::Greater); + return std::cmp::Ordering::Greater; } else if scalar < Scalar::ZERO { // If scalar is negative, and that any bits above the ct's n-1 bits is not set // it means scalar is smaller.
if ct.blocks().len() > scalar_blocks.len() { // Ciphertext has more blocks, the scalar may be in range - return None; + return std::cmp::Ordering::Equal; } // (returns false for empty iter) @@ -71,14 +73,14 @@ impl ServerKey { if at_least_one_block_is_not_full_of_1s || sign_bit_is_unset { // Scalar is smaller than lowest value of T - return Some(std::cmp::Ordering::Less); + return std::cmp::Ordering::Less; } } } else { // T is unsigned if scalar < Scalar::ZERO { // ct represent an unsigned (always >= 0) - return Some(std::cmp::Ordering::Less); + return std::cmp::Ordering::Less; } else if scalar > Scalar::ZERO { // scalar is obviously bigger if it has non-zero // blocks after lhs's last block @@ -88,12 +90,12 @@ impl ServerKey { sub_slice.iter().any(|&scalar_block| scalar_block != 0) }); if is_scalar_obviously_bigger { - return Some(std::cmp::Ordering::Greater); + return std::cmp::Ordering::Greater; } } } - None + std::cmp::Ordering::Equal } /// Takes a chunk of 2 ciphertexts and packs them together in a new ciphertext @@ -394,19 +396,16 @@ impl ServerKey { debug_assert!(lhs.block_carries_are_empty()); if T::IS_SIGNED { - match self.is_scalar_out_of_bounds(lhs, rhs) { - Some(std::cmp::Ordering::Greater | std::cmp::Ordering::Less) => { + return match self.is_scalar_out_of_bounds(lhs, rhs) { + std::cmp::Ordering::Greater | std::cmp::Ordering::Less => { // Scalar is not within bounds so it cannot be equal - return self.create_trivial_boolean_block(false); - } - Some(std::cmp::Ordering::Equal) => { - unreachable!("Internal error: is_scalar_out_of_bounds returned Ordering::Equal") + self.create_trivial_boolean_block(false) } - None => { + std::cmp::Ordering::Equal => { let trivial = self.create_trivial_radix(rhs, lhs.blocks().len()); - return self.unchecked_eq_parallelized(lhs, &trivial); + self.unchecked_eq_parallelized(lhs, &trivial) } - } + }; } // Starting From here, we know lhs (T) is an unsigned ciphertext @@ -490,17 +489,16 @@ impl ServerKey { 
debug_assert!(lhs.block_carries_are_empty()); if T::IS_SIGNED { - match self.is_scalar_out_of_bounds(lhs, rhs) { - Some(std::cmp::Ordering::Greater | std::cmp::Ordering::Less) => { + return match self.is_scalar_out_of_bounds(lhs, rhs) { + std::cmp::Ordering::Greater | std::cmp::Ordering::Less => { // Scalar is not within bounds so its not equal - return self.create_trivial_boolean_block(true); + self.create_trivial_boolean_block(true) } - Some(std::cmp::Ordering::Equal) => unreachable!("Internal error: invalid value"), - None => { + std::cmp::Ordering::Equal => { let trivial = self.create_trivial_radix(rhs, lhs.blocks().len()); - return self.unchecked_ne_parallelized(lhs, &trivial); + self.unchecked_ne_parallelized(lhs, &trivial) } - } + }; } if rhs < Scalar::ZERO { @@ -638,12 +636,396 @@ impl ServerKey { // Unchecked <, >, <=, >=, min, max //=========================================================== + /// This implements all comparisons (<, <=, >, >=) for both signed and unsigned + /// + /// * inputs must have the same number of blocks + /// * block carries of both inputs must be empty + /// * carry modulus == message modulus + fn scalar_compare(&self, a: &T, b: Scalar, compare: ComparisonKind) -> BooleanBlock + where + T: IntegerRadixCiphertext, + Scalar: Numeric + DecomposableInto, + { + assert!(a.block_carries_are_empty(), "Block carries must be empty"); + assert_eq!( + self.carry_modulus().0, + self.message_modulus().0, + "The carry modulus must be == to the message modulus" + ); + + if a.blocks().is_empty() { + // We interpret empty as 0 + return match compare { + ComparisonKind::Less => self.create_trivial_boolean_block(Scalar::ZERO < b), + ComparisonKind::LessOrEqual => self.create_trivial_boolean_block(Scalar::ZERO < b), + ComparisonKind::Greater => self.create_trivial_boolean_block(Scalar::ZERO > b), + ComparisonKind::GreaterOrEqual => { + self.create_trivial_boolean_block(Scalar::ZERO >= b) + } + }; + } + + match self.is_scalar_out_of_bounds(a, b) { + 
std::cmp::Ordering::Less => { + // We have that `b < a` trivially + return match compare { + ComparisonKind::Less | ComparisonKind::LessOrEqual => { + // So `a < b` and `a <= b` are false + self.create_trivial_boolean_block(false) + } + ComparisonKind::Greater | ComparisonKind::GreaterOrEqual => { + // So `a > b` and `a >= b` are true + self.create_trivial_boolean_block(true) + } + }; + } + std::cmp::Ordering::Greater => { + // We have that `b > a` trivially + return match compare { + ComparisonKind::Less | ComparisonKind::LessOrEqual => { + // So `a < b` and `a <= b` are true + self.create_trivial_boolean_block(true) + } + ComparisonKind::Greater | ComparisonKind::GreaterOrEqual => { + // So `a > b` and `a >= b` are false + self.create_trivial_boolean_block(false) + } + }; + } + // We have to do the homomorphic algorithm + std::cmp::Ordering::Equal => {} + } + + // Some shortcuts for comparison with zero + if T::IS_SIGNED && b == Scalar::ZERO { + match compare { + ComparisonKind::Less => { + return if self.message_modulus().0 > 2 { + let sign_bit_lut = self.key.generate_lookup_table(|last_block| { + let modulus = self.key.message_modulus.0 as u64; + (last_block % modulus) / (modulus / 2) + }); + let sign_bit = self + .key + .apply_lookup_table(a.blocks().last().unwrap(), &sign_bit_lut); + BooleanBlock::new_unchecked(sign_bit) + } else { + BooleanBlock::new_unchecked(a.blocks().last().cloned().unwrap()) + } + } + ComparisonKind::GreaterOrEqual => { + let mut sign_bit = if self.message_modulus().0 > 2 { + let sign_bit_lut = self.key.generate_lookup_table(|last_block| { + let modulus = self.key.message_modulus.0 as u64; + (last_block % modulus) / (modulus / 2) + }); + let sign_bit = self + .key + .apply_lookup_table(a.blocks().last().unwrap(), &sign_bit_lut); + BooleanBlock::new_unchecked(sign_bit) + } else { + BooleanBlock::new_unchecked(a.blocks().last().cloned().unwrap()) + }; + self.boolean_bitnot_assign(&mut sign_bit); + return sign_bit; + } + 
ComparisonKind::LessOrEqual | ComparisonKind::Greater => {} + } + } else if !T::IS_SIGNED && b == Scalar::ZERO { + match compare { + ComparisonKind::Less => return self.create_trivial_boolean_block(false), + ComparisonKind::GreaterOrEqual => return self.create_trivial_boolean_block(true), + ComparisonKind::LessOrEqual | ComparisonKind::Greater => {} + } + } + + let packed_modulus = (self.key.message_modulus.0 * self.key.message_modulus.0) as u64; + + // We have that `a < b` <=> `does_sub_overflows(a, b)` and we know how to do this. + // Now, to have other comparisons, we will re-express them as less than (`<`) + // with some potential boolean negation + // + // Note that for signed ciphertext it's not the overflowing sub that is used, + // but it's still something that is based on the subtraction + // + // For both signed and unsigned, a subtraction with borrow is used + // (as opposed to adding the negation) + let num_block_is_even = (a.blocks().len() & 1) == 0; + let a = a + .blocks() + .chunks(2) + .map(|chunk_of_two| self.pack_block_chunk(chunk_of_two)) + .collect::>(); + + let padding_value = (packed_modulus - 1) * u64::from(b < Scalar::ZERO); + let mut b_blocks = BlockDecomposer::new(b, packed_modulus.ilog2()) + .iter_as::() + .chain(std::iter::repeat(padding_value)) + .take(a.len()) + .collect::>(); + + if !num_block_is_even && b < Scalar::ZERO { + let last_index = b_blocks.len() - 1; + // We blindly padded with the ones, but as the num block is not even + // the last packed block high part shall be 0 not 1s (i.e. 
no padding) + b_blocks[last_index] %= self.message_modulus().0 as u64; + } + + let b = b_blocks; + let block_modulus = packed_modulus; + let num_bits_in_block = block_modulus.ilog2(); + let grouping_size = num_bits_in_block as usize; + + let mut first_grouping_luts = Vec::with_capacity(grouping_size); + let (invert_operands, invert_subtraction_result) = match compare { + // The easiest case, nothing changes + ComparisonKind::Less => (false, false), + // `a <= b` + // <=> `not(b < a)` + // <=> `not(does_sub_overflows(b, a))` + ComparisonKind::LessOrEqual => (true, true), + // `a > b` + // <=> `b < a` + // <=> `does_sub_overflows(b, a)` + ComparisonKind::Greater => (true, false), + // `a >= b` + // <=> `b <= a` + // <=> `not(a < b)` + // <=> `not(does_sub_overflows(a, b))` + ComparisonKind::GreaterOrEqual => (false, true), + }; + + // There is 1 packed block (i.e. there was at most 2 blocks originally) + // we can take shortcut here + if a.len() == 1 { + let lut = if T::IS_SIGNED { + let modulus = if num_block_is_even { + MessageModulus(packed_modulus as usize) + } else { + self.message_modulus() + }; + self.key.generate_lookup_table(|x| { + let (x, y) = if invert_operands { + (b[0], x) + } else { + (x, b[0]) + }; + + u64::from(invert_subtraction_result) + ^ is_x_less_than_y_given_input_borrow(x, y, 0, modulus) + }) + } else { + self.key.generate_lookup_table(|x| { + let (x, y) = if invert_operands { + (b[0], x) + } else { + (x, b[0]) + }; + let overflowed = x < y; + u64::from(invert_subtraction_result ^ overflowed) + }) + }; + let result = self.key.apply_lookup_table(&a[0], &lut); + return BooleanBlock::new_unchecked(result); + } + + // Save some values for later + let first_scalar_block = b[0]; + let last_scalar_block = b[b.len() - 1]; + + let b: Vec<_> = b + .into_iter() + .map(|v| self.key.unchecked_create_trivial(v)) + .collect(); + + let mut sub_blocks = + if invert_operands { + first_grouping_luts.push(self.key.generate_lookup_table(|first_block| { + 
u64::from(first_scalar_block < first_block) + })); + + b.iter() + .zip(a.iter()) + .map(|(lhs_b, rhs_b)| { + let mut result = lhs_b.clone(); + // We don't want the correcting term + lwe_ciphertext_sub_assign(&mut result.ct, &rhs_b.ct); + result + }) + .collect::>() + } else { + first_grouping_luts.push(self.key.generate_lookup_table(|first_block| { + u64::from(first_block < first_scalar_block) + })); + + a.iter() + .zip(b.iter()) + .map(|(lhs_b, rhs_b)| { + let mut result = lhs_b.clone(); + // We don't want the correcting term + lwe_ciphertext_sub_assign(&mut result.ct, &rhs_b.ct); + result + }) + .collect::>() + }; + + // The first lut, needs the encrypted block of `a`, not the subtraction + // of `a[0]` and `b[0]` + sub_blocks[0].clone_from(&a[0]); + + // We are going to group blocks and compute how each group propagates/generates a borrow + // + // Again, in unsigned representation the output borrow of the whole operation (i.e. the + // borrow generated by the last group) tells us the result of the comparison. For signed + // representation we need to XOR the overflow flag and the sign bit of the result. + let block_states = { + for i in 1..grouping_size { + let state_fn = |block| { + let r = (u64::MAX * u64::from(block != 0)) % (packed_modulus * 2); + r << (i - 1) + }; + first_grouping_luts.push(self.key.generate_lookup_table(state_fn)); + } + + let other_block_state_luts = (0..grouping_size) + .map(|i| { + let state_fn = |block| { + let r = (u64::MAX * u64::from(block != 0)) % (packed_modulus * 2); + r << i + }; + self.key.generate_lookup_table(state_fn) + }) + .collect::>(); + + let block_states = + // With unsigned ciphertexts as, overflow (i.e. 
does the last block needs to borrow) + // directly translates to lhs < rhs we compute the blocks states for all the blocks + // + // For signed numbers, we need to do something more specific with the last block + // thus, we don't compute the last block state + sub_blocks[..sub_blocks.len() - usize::from(T::IS_SIGNED)] + .par_iter() + .enumerate() + .map(|(index, block)| { + let grouping_index = index / grouping_size; + let is_in_first_grouping = grouping_index == 0; + let index_in_grouping = index % (grouping_size); + + let (luts, corrector) = if is_in_first_grouping { + ( + &first_grouping_luts[index_in_grouping], + if index_in_grouping == 0 { 0 } else { 1 << (index_in_grouping - 1)} + ) + } else { + (&other_block_state_luts[index_in_grouping], 1 << (index_in_grouping)) + }; + + let mut result = self.key.apply_lookup_table(block, luts); + if index > 0 { + self.key.unchecked_scalar_add_assign(&mut result, corrector); + } + result + }) + .collect::>(); + + block_states + }; + + // group borrows and simulator of last block + let ( + (group_borrows, use_sequential_algorithm_to_resolve_grouping_carries), + maybe_prepared_signed_check, + ) = rayon::join( + || { + self.compute_group_borrow_state( + // May only invert if T is not signed + // As when there is only one group, in the unsigned case since overflow + // directly translate to lhs < rhs, we can ask the LUT used to do the + // inversion for us. + // + // In signed case as it's a bit more complex, we never want to + !T::IS_SIGNED && invert_subtraction_result, + grouping_size, + block_states, + ) + }, + || { + // When the ciphertexts are signed, finding whether lhs < rhs by doing a sub + // is less direct than in unsigned where we can check for overflow. + if T::IS_SIGNED && self.message_modulus().0 > 2 { + // Luckily, when the blocks have 4 bits, we can precompute and store in a block + // the 2 possible values for `lhs < rhs` depending on whether the last block + // will be borrowed from. 
+ let modulus = if num_block_is_even { + MessageModulus(packed_modulus as usize) + } else { + self.message_modulus() + }; + let lut = self.key.generate_lookup_table(|last_block| { + let (x, y) = if invert_operands { + (last_scalar_block, last_block) + } else { + (last_block, last_scalar_block) + }; + let b0 = is_x_less_than_y_given_input_borrow(x, y, 0, modulus); + let b1 = is_x_less_than_y_given_input_borrow(x, y, 1, modulus); + (b1 << 1 | b0) << 2 + }); + + Some(PreparedSignedCheck::Unified( + self.key.apply_lookup_table(a.last().unwrap(), &lut), + )) + } else if T::IS_SIGNED { + let modulus = if num_block_is_even { + MessageModulus(packed_modulus as usize) + } else { + self.message_modulus() + }; + Some(PreparedSignedCheck::Split(rayon::join( + || { + let lut = self.key.generate_lookup_table(|last_block| { + let (x, y) = if invert_operands { + (last_scalar_block, last_block) + } else { + (last_block, last_scalar_block) + }; + is_x_less_than_y_given_input_borrow(x, y, 1, modulus) + }); + self.key.apply_lookup_table(a.last().unwrap(), &lut) + }, + || { + let lut = self.key.generate_lookup_table(|last_block| { + let (x, y) = if invert_operands { + (last_scalar_block, last_block) + } else { + (last_block, last_scalar_block) + }; + is_x_less_than_y_given_input_borrow(x, y, 0, modulus) + }); + self.key.apply_lookup_table(a.last().unwrap(), &lut) + }, + ))) + } else { + None + } + }, + ); + + self.finish_comparison( + group_borrows, + grouping_size, + use_sequential_algorithm_to_resolve_grouping_carries, + maybe_prepared_signed_check, + invert_subtraction_result, + ) + } + pub fn unchecked_scalar_gt_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock where T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).unchecked_scalar_gt_parallelized(lhs, rhs) + self.scalar_compare(lhs, rhs, ComparisonKind::Greater) } pub fn unchecked_scalar_ge_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock @@ -651,7 +1033,7 @@ impl ServerKey { T: 
IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).unchecked_scalar_ge_parallelized(lhs, rhs) + self.scalar_compare(lhs, rhs, ComparisonKind::GreaterOrEqual) } pub fn unchecked_scalar_lt_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock @@ -659,7 +1041,7 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).unchecked_scalar_lt_parallelized(lhs, rhs) + self.scalar_compare(lhs, rhs, ComparisonKind::Less) } pub fn unchecked_scalar_le_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock @@ -667,7 +1049,7 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).unchecked_scalar_le_parallelized(lhs, rhs) + self.scalar_compare(lhs, rhs, ComparisonKind::LessOrEqual) } pub fn unchecked_scalar_max_parallelized(&self, lhs: &T, rhs: Scalar) -> T @@ -675,7 +1057,38 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).unchecked_scalar_max_parallelized(lhs, rhs) + let is_superior = self.unchecked_scalar_gt_parallelized(lhs, rhs); + let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2()) + .iter_as::() + .chain(std::iter::repeat(if rhs >= Scalar::ZERO { + 0u64 + } else { + self.message_modulus().0 as u64 - 1 + })) + .take(lhs.blocks().len()) + .map(|scalar_block| { + self.key + .generate_lookup_table_bivariate(|is_superior, block| { + if is_superior == 1 { + block + } else { + scalar_block + } + }) + }) + .collect::>(); + + let new_blocks = lhs + .blocks() + .par_iter() + .zip(luts.par_iter()) + .map(|(block, lut)| { + self.key + .unchecked_apply_lookup_table_bivariate(&is_superior.0, block, lut) + }) + .collect::>(); + + T::from(new_blocks) } pub fn unchecked_scalar_min_parallelized(&self, lhs: &T, rhs: Scalar) -> T @@ -683,7 +1096,38 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).unchecked_scalar_min_parallelized(lhs, rhs) + let is_inferior = 
self.unchecked_scalar_lt_parallelized(lhs, rhs); + let luts = BlockDecomposer::new(rhs, self.message_modulus().0.ilog2()) + .iter_as::() + .chain(std::iter::repeat(if rhs >= Scalar::ZERO { + 0u64 + } else { + self.message_modulus().0 as u64 - 1 + })) + .take(lhs.blocks().len()) + .map(|scalar_block| { + self.key + .generate_lookup_table_bivariate(|is_inferior, block| { + if is_inferior == 1 { + block + } else { + scalar_block + } + }) + }) + .collect::>(); + + let new_blocks = lhs + .blocks() + .par_iter() + .zip(luts.par_iter()) + .map(|(block, lut)| { + self.key + .unchecked_apply_lookup_table_bivariate(&is_inferior.0, block, lut) + }) + .collect::>(); + + T::from(new_blocks) } //=========================================================== @@ -695,7 +1139,11 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).smart_scalar_gt_parallelized(lhs, rhs) + if !lhs.block_carries_are_empty() { + self.full_propagate_parallelized(lhs); + } + + self.unchecked_scalar_gt_parallelized(lhs, rhs) } pub fn smart_scalar_ge_parallelized(&self, lhs: &mut T, rhs: Scalar) -> BooleanBlock @@ -703,7 +1151,11 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).smart_scalar_ge_parallelized(lhs, rhs) + if !lhs.block_carries_are_empty() { + self.full_propagate_parallelized(lhs); + } + + self.unchecked_scalar_ge_parallelized(lhs, rhs) } pub fn smart_scalar_lt_parallelized(&self, lhs: &mut T, rhs: Scalar) -> BooleanBlock @@ -711,7 +1163,11 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).smart_scalar_lt_parallelized(lhs, rhs) + if !lhs.block_carries_are_empty() { + self.full_propagate_parallelized(lhs); + } + + self.unchecked_scalar_lt_parallelized(lhs, rhs) } pub fn smart_scalar_le_parallelized(&self, lhs: &mut T, rhs: Scalar) -> BooleanBlock @@ -719,7 +1175,11 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - 
Comparator::new(self).smart_scalar_le_parallelized(lhs, rhs) + if !lhs.block_carries_are_empty() { + self.full_propagate_parallelized(lhs); + } + + self.unchecked_scalar_le_parallelized(lhs, rhs) } pub fn smart_scalar_max_parallelized(&self, lhs: &mut T, rhs: Scalar) -> T @@ -727,7 +1187,11 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).smart_scalar_max_parallelized(lhs, rhs) + if !lhs.block_carries_are_empty() { + self.full_propagate_parallelized(lhs); + } + + self.unchecked_scalar_max_parallelized(lhs, rhs) } pub fn smart_scalar_min_parallelized(&self, lhs: &mut T, rhs: Scalar) -> T @@ -735,7 +1199,11 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).smart_scalar_min_parallelized(lhs, rhs) + if !lhs.block_carries_are_empty() { + self.full_propagate_parallelized(lhs); + } + + self.unchecked_scalar_min_parallelized(lhs, rhs) } //=========================================================== @@ -747,7 +1215,15 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).scalar_gt_parallelized(lhs, rhs) + let mut tmp_lhs; + let lhs = if lhs.block_carries_are_empty() { + lhs + } else { + tmp_lhs = lhs.clone(); + self.full_propagate_parallelized(&mut tmp_lhs); + &tmp_lhs + }; + self.unchecked_scalar_gt_parallelized(lhs, rhs) } pub fn scalar_ge_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock @@ -755,7 +1231,15 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).scalar_ge_parallelized(lhs, rhs) + let mut tmp_lhs; + let lhs = if lhs.block_carries_are_empty() { + lhs + } else { + tmp_lhs = lhs.clone(); + self.full_propagate_parallelized(&mut tmp_lhs); + &tmp_lhs + }; + self.unchecked_scalar_ge_parallelized(lhs, rhs) } pub fn scalar_lt_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock @@ -763,7 +1247,15 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { 
- Comparator::new(self).scalar_lt_parallelized(lhs, rhs) + let mut tmp_lhs; + let lhs = if lhs.block_carries_are_empty() { + lhs + } else { + tmp_lhs = lhs.clone(); + self.full_propagate_parallelized(&mut tmp_lhs); + &tmp_lhs + }; + self.unchecked_scalar_lt_parallelized(lhs, rhs) } pub fn scalar_le_parallelized(&self, lhs: &T, rhs: Scalar) -> BooleanBlock @@ -771,7 +1263,15 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).scalar_le_parallelized(lhs, rhs) + let mut tmp_lhs; + let lhs = if lhs.block_carries_are_empty() { + lhs + } else { + tmp_lhs = lhs.clone(); + self.full_propagate_parallelized(&mut tmp_lhs); + &tmp_lhs + }; + self.unchecked_scalar_le_parallelized(lhs, rhs) } pub fn scalar_max_parallelized(&self, lhs: &T, rhs: Scalar) -> T @@ -779,7 +1279,15 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).scalar_max_parallelized(lhs, rhs) + let mut tmp_lhs; + let lhs = if lhs.block_carries_are_empty() { + lhs + } else { + tmp_lhs = lhs.clone(); + self.full_propagate_parallelized(&mut tmp_lhs); + &tmp_lhs + }; + self.unchecked_scalar_max_parallelized(lhs, rhs) } pub fn scalar_min_parallelized(&self, lhs: &T, rhs: Scalar) -> T @@ -787,6 +1295,14 @@ impl ServerKey { T: IntegerRadixCiphertext, Scalar: DecomposableInto, { - Comparator::new(self).scalar_min_parallelized(lhs, rhs) + let mut tmp_lhs; + let lhs = if lhs.block_carries_are_empty() { + lhs + } else { + tmp_lhs = lhs.clone(); + self.full_propagate_parallelized(&mut tmp_lhs); + &tmp_lhs + }; + self.unchecked_scalar_min_parallelized(lhs, rhs) } } diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_comparison.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_comparison.rs index ebc33b5831..3da9f5c214 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_comparison.rs +++ 
b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_comparison.rs @@ -267,6 +267,8 @@ macro_rules! define_signed_scalar_comparison_test_functions { create_parametrized_test!([] { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, @@ -279,6 +281,8 @@ macro_rules! define_signed_scalar_comparison_test_functions { create_parametrized_test!([] { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, // We don't use PARAM_MESSAGE_3_CARRY_3_KS_PBS, // as smart test might overflow values @@ -293,6 +297,8 @@ macro_rules! define_signed_scalar_comparison_test_functions { create_parametrized_test!([] { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, // We don't use PARAM_MESSAGE_3_CARRY_3_KS_PBS, // as default test might overflow values @@ -321,19 +327,19 @@ fn integer_signed_is_scalar_out_of_bounds(param: ClassicPBSParameters) { // This one is in range let scalar = I256::from(i128::MAX); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, None); + assert_eq!(res, std::cmp::Ordering::Equal); let scalar = I256::from(i128::MAX) + I256::ONE; let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Greater)); + assert_eq!(res, std::cmp::Ordering::Greater); let scalar = I256::from(i128::MAX) + I256::from(rng.gen_range(2i128..=i128::MAX)); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Greater)); + assert_eq!(res, std::cmp::Ordering::Greater); let scalar = I256::from(i128::MAX) + I256::from(i128::MAX); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Greater)); + assert_eq!(res, std::cmp::Ordering::Greater); } // Negative scalars @@ -341,27 +347,27 @@ fn integer_signed_is_scalar_out_of_bounds(param: ClassicPBSParameters) { // This one is in range let scalar = I256::from(i128::MIN); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - 
assert_eq!(res, None); + assert_eq!(res, std::cmp::Ordering::Equal); let scalar = I256::from(i128::MIN) - I256::ONE; let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Less)); + assert_eq!(res, std::cmp::Ordering::Less); let scalar = I256::from(i128::MIN) + I256::from(rng.gen_range(i128::MIN..=-2)); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Less)); + assert_eq!(res, std::cmp::Ordering::Less); let scalar = I256::from(i128::MIN) + I256::from(i128::MIN); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Less)); + assert_eq!(res, std::cmp::Ordering::Less); let scalar = I256::from(i128::MIN) - I256::from(rng.gen_range(2..=i128::MAX)); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Less)); + assert_eq!(res, std::cmp::Ordering::Less); let scalar = I256::from(i128::MIN) - I256::from(i128::MAX); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Less)); + assert_eq!(res, std::cmp::Ordering::Less); } } @@ -604,16 +610,19 @@ mod no_coverage { } create_parametrized_test!(integer_signed_unchecked_scalar_max_parallelized_i128 { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, PARAM_MESSAGE_4_CARRY_4_KS_PBS }); create_parametrized_test!(integer_signed_unchecked_scalar_min_parallelized_i128 { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, PARAM_MESSAGE_4_CARRY_4_KS_PBS }); create_parametrized_test!(integer_signed_smart_scalar_max_parallelized_i128 { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, // We don't use PARAM_MESSAGE_3_CARRY_3_KS_PBS, // as default test might overflow values @@ -622,6 +631,7 @@ mod no_coverage { PARAM_MESSAGE_4_CARRY_4_KS_PBS }); create_parametrized_test!(integer_signed_smart_scalar_min_parallelized_i128 { + 
PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, // We don't use PARAM_MESSAGE_3_CARRY_3_KS_PBS, // as default test might overflow values @@ -630,6 +640,7 @@ mod no_coverage { PARAM_MESSAGE_4_CARRY_4_KS_PBS }); create_parametrized_test!(integer_signed_scalar_max_parallelized_i128 { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, // We don't use PARAM_MESSAGE_3_CARRY_3_KS_PBS, // as default test might overflow values @@ -638,6 +649,7 @@ mod no_coverage { PARAM_MESSAGE_4_CARRY_4_KS_PBS }); create_parametrized_test!(integer_signed_scalar_min_parallelized_i128 { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, // We don't use PARAM_MESSAGE_3_CARRY_3_KS_PBS, // as default test might overflow values @@ -654,6 +666,7 @@ mod no_coverage { define_signed_scalar_comparison_test_functions!(ge, i128); create_parametrized_test!(integer_signed_is_scalar_out_of_bounds { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, // We don't use PARAM_MESSAGE_3_CARRY_3_KS_PBS, // as the test relies on the ciphertext to encrypt 128bits @@ -721,3 +734,139 @@ mod coverage { create_parametrized_test_classical_params!(integer_signed_is_scalar_out_of_bounds); } + +create_parametrized_test!(integer_extensive_trivial_signed_default_scalar_comparisons); + +fn integer_extensive_trivial_signed_default_scalar_comparisons(params: impl Into) { + let lt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_lt_parallelized); + let le_executor = CpuFunctionExecutor::new(&ServerKey::scalar_le_parallelized); + let gt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_gt_parallelized); + let ge_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ge_parallelized); + let min_executor = CpuFunctionExecutor::new(&ServerKey::scalar_min_parallelized); + let max_executor = CpuFunctionExecutor::new(&ServerKey::scalar_max_parallelized); + + extensive_trivial_signed_default_scalar_comparisons_test( + params, + lt_executor, + le_executor, + 
gt_executor, + ge_executor, + min_executor, + max_executor, + ) +} + +/// Although this uses the executor pattern and could be plugged in other backends, +/// It is not recommended to do so unless the backend is extremely fast on trivial ciphertexts +/// or extremely extremely fast in general, or if its plugged just as a one time thing. +#[allow(clippy::eq_op)] +pub(crate) fn extensive_trivial_signed_default_scalar_comparisons_test( + param: P, + mut lt_executor: E1, + mut le_executor: E2, + mut gt_executor: E3, + mut ge_executor: E4, + mut min_executor: E5, + mut max_executor: E6, +) where + P: Into, + E1: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i128), BooleanBlock>, + E2: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i128), BooleanBlock>, + E3: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i128), BooleanBlock>, + E4: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i128), BooleanBlock>, + E5: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i128), SignedRadixCiphertext>, + E6: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i128), SignedRadixCiphertext>, +{ + let params = param.into(); + let (cks, mut sks) = KEY_CACHE.get_from_params(params, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, 4)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = thread_rng(); + + lt_executor.setup(&cks, sks.clone()); + le_executor.setup(&cks, sks.clone()); + gt_executor.setup(&cks, sks.clone()); + ge_executor.setup(&cks, sks.clone()); + min_executor.setup(&cks, sks.clone()); + max_executor.setup(&cks, sks.clone()); + + for num_blocks in 1..=128 { + let Some(modulus) = (params.message_modulus().0 as i128).checked_pow(num_blocks as u32) + else { + break; + }; + if modulus == 2 { + continue; + } + let modulus = modulus / 2; + for _ in 0..25 { + let clear_a = rng.gen_range(0..modulus); + let clear_b = rng.gen_range(0..modulus); + + let a: SignedRadixCiphertext = 
sks.create_trivial_radix(clear_a, num_blocks); + + { + let result = lt_executor.execute((&a, clear_b)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a < clear_b, "{clear_a} < {clear_b}"); + + let result = lt_executor.execute((&a, clear_a)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a < clear_a, "{clear_a} < {clear_a}"); + } + + { + let result = le_executor.execute((&a, clear_b)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a <= clear_b, "{clear_a} <= {clear_b}"); + + let result = le_executor.execute((&a, clear_a)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a <= clear_a, "{clear_a} <= {clear_a}"); + } + + { + let result = gt_executor.execute((&a, clear_b)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a > clear_b, "{clear_a} > {clear_b}"); + + let result = gt_executor.execute((&a, clear_a)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a > clear_a, "{clear_a} > {clear_a}"); + } + + { + let result = ge_executor.execute((&a, clear_b)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a >= clear_b, "{clear_a} >= {clear_b}"); + + let result = ge_executor.execute((&a, clear_a)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a >= clear_a, "{clear_a} >= {clear_a}"); + } + + { + let result = min_executor.execute((&a, clear_b)); + let result: i128 = cks.decrypt_signed(&result); + assert_eq!(result, clear_a.min(clear_b), "{clear_a}.min({clear_b})"); + + let result = min_executor.execute((&a, clear_a)); + let result: i128 = cks.decrypt_signed(&result); + assert_eq!(result, clear_a.min(clear_a), "{clear_a}.min({clear_a})"); + } + + { + let result = max_executor.execute((&a, clear_b)); + let result: i128 = cks.decrypt_signed(&result); + assert_eq!(result, clear_a.max(clear_b), "{clear_a}.max({clear_b})"); + + let result = max_executor.execute((&a, clear_a)); + let result: i128 = 
cks.decrypt_signed(&result); + assert_eq!(result, clear_a.max(clear_a), "{clear_a}.max({clear_a})"); + } + } + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_comparison.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_comparison.rs index 733f1baa39..61667f63a0 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_comparison.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_comparison.rs @@ -3,7 +3,7 @@ use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom}; use crate::integer::ciphertext::RadixCiphertext; use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; -use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_unsigned::{CpuFunctionExecutor, NB_CTXT}; use crate::integer::tests::create_parametrized_test; use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey, I256, U256}; #[cfg(tarpaulin)] @@ -245,6 +245,8 @@ macro_rules! define_scalar_comparison_test_functions { create_parametrized_test!([] { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, @@ -257,6 +259,7 @@ macro_rules! define_scalar_comparison_test_functions { create_parametrized_test!([] { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, // We don't use PARAM_MESSAGE_3_CARRY_3_KS_PBS, // as smart test might overflow values @@ -271,6 +274,7 @@ macro_rules! 
define_scalar_comparison_test_functions { create_parametrized_test!([] { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, // We don't use PARAM_MESSAGE_3_CARRY_3_KS_PBS, // as default test might overflow values @@ -408,45 +412,45 @@ fn integer_is_scalar_out_of_bounds(param: ClassicPBSParameters) { // This one is in range let scalar = U256::from(u128::MAX); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, None); + assert_eq!(res, std::cmp::Ordering::Equal); let scalar = U256::from(u128::MAX) + U256::ONE; let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Greater)); + assert_eq!(res, std::cmp::Ordering::Greater); let scalar = U256::from(u128::MAX) + U256::from(rng.gen_range(2u128..=u128::MAX)); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Greater)); + assert_eq!(res, std::cmp::Ordering::Greater); let scalar = U256::from(u128::MAX) + U256::from(u128::MAX); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Greater)); + assert_eq!(res, std::cmp::Ordering::Greater); } // Negative scalars { let res = sks.is_scalar_out_of_bounds(&ct, -1i128); - assert_eq!(res, Some(std::cmp::Ordering::Less)); + assert_eq!(res, std::cmp::Ordering::Less); let scalar = I256::from(i128::MIN) - I256::ONE; let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Less)); + assert_eq!(res, std::cmp::Ordering::Less); let scalar = I256::from(i128::MIN) + I256::from(rng.gen_range(i128::MIN..=-2)); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Less)); + assert_eq!(res, std::cmp::Ordering::Less); let scalar = I256::from(i128::MIN) + I256::from(i128::MIN); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Less)); + assert_eq!(res, std::cmp::Ordering::Less); let scalar = I256::from(i128::MIN) - 
I256::from(rng.gen_range(2..=i128::MAX)); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Less)); + assert_eq!(res, std::cmp::Ordering::Less); let scalar = I256::from(i128::MIN) - I256::from(i128::MAX); let res = sks.is_scalar_out_of_bounds(&ct, scalar); - assert_eq!(res, Some(std::cmp::Ordering::Less)); + assert_eq!(res, std::cmp::Ordering::Less); } // Negative scalar @@ -456,7 +460,7 @@ fn integer_is_scalar_out_of_bounds(param: ClassicPBSParameters) { let bigger_ct = cks.encrypt_signed_radix(-1i128, num_block); let scalar = i64::MIN; let res = sks.is_scalar_out_of_bounds(&bigger_ct, scalar); - assert_eq!(res, None); + assert_eq!(res, std::cmp::Ordering::Equal); } } @@ -674,32 +678,38 @@ mod no_coverage { } create_parametrized_test!(integer_unchecked_scalar_min_parallelized_u256 { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, PARAM_MESSAGE_4_CARRY_4_KS_PBS }); create_parametrized_test!(integer_unchecked_scalar_max_parallelized_u256 { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, PARAM_MESSAGE_4_CARRY_4_KS_PBS }); create_parametrized_test!(integer_smart_scalar_min_parallelized_u256 { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, // No test for 3_3, see define_scalar_comparison_test_functions macro PARAM_MESSAGE_4_CARRY_4_KS_PBS }); create_parametrized_test!(integer_smart_scalar_max_parallelized_u256 { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, // No test for 3_3, see define_scalar_comparison_test_functions macro PARAM_MESSAGE_4_CARRY_4_KS_PBS }); create_parametrized_test!(integer_scalar_min_parallelized_u256 { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, // No test for 3_3, see define_scalar_comparison_test_functions macro PARAM_MESSAGE_4_CARRY_4_KS_PBS }); create_parametrized_test!(integer_scalar_max_parallelized_u256 { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, 
PARAM_MESSAGE_2_CARRY_2_KS_PBS, // No test for 3_3, see define_scalar_comparison_test_functions macro PARAM_MESSAGE_4_CARRY_4_KS_PBS @@ -713,6 +723,7 @@ mod no_coverage { define_scalar_comparison_test_functions!(ge, U256); create_parametrized_test!(integer_unchecked_scalar_comparisons_edge { + PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, PARAM_MESSAGE_4_CARRY_4_KS_PBS @@ -781,3 +792,136 @@ mod coverage { create_parametrized_test_classical_params!(integer_is_scalar_out_of_bounds); } + +create_parametrized_test!(integer_extensive_trivial_default_scalar_comparisons); + +fn integer_extensive_trivial_default_scalar_comparisons(params: impl Into) { + let lt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_lt_parallelized); + let le_executor = CpuFunctionExecutor::new(&ServerKey::scalar_le_parallelized); + let gt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_gt_parallelized); + let ge_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ge_parallelized); + let min_executor = CpuFunctionExecutor::new(&ServerKey::scalar_min_parallelized); + let max_executor = CpuFunctionExecutor::new(&ServerKey::scalar_max_parallelized); + + extensive_trivial_default_scalar_comparisons_test( + params, + lt_executor, + le_executor, + gt_executor, + ge_executor, + min_executor, + max_executor, + ) +} + +/// Although this uses the executor pattern and could be plugged in other backends, +/// It is not recommended to do so unless the backend is extremely fast on trivial ciphertexts +/// or extremely extremely fast in general, or if its plugged just as a one time thing. 
+#[allow(clippy::eq_op)] +pub(crate) fn extensive_trivial_default_scalar_comparisons_test( + param: P, + mut lt_executor: E1, + mut le_executor: E2, + mut gt_executor: E3, + mut ge_executor: E4, + mut min_executor: E5, + mut max_executor: E6, +) where + P: Into, + E1: for<'a> FunctionExecutor<(&'a RadixCiphertext, u128), BooleanBlock>, + E2: for<'a> FunctionExecutor<(&'a RadixCiphertext, u128), BooleanBlock>, + E3: for<'a> FunctionExecutor<(&'a RadixCiphertext, u128), BooleanBlock>, + E4: for<'a> FunctionExecutor<(&'a RadixCiphertext, u128), BooleanBlock>, + E5: for<'a> FunctionExecutor<(&'a RadixCiphertext, u128), RadixCiphertext>, + E6: for<'a> FunctionExecutor<(&'a RadixCiphertext, u128), RadixCiphertext>, +{ + let params = param.into(); + let (cks, mut sks) = KEY_CACHE.get_from_params(params, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = thread_rng(); + + lt_executor.setup(&cks, sks.clone()); + le_executor.setup(&cks, sks.clone()); + gt_executor.setup(&cks, sks.clone()); + ge_executor.setup(&cks, sks.clone()); + min_executor.setup(&cks, sks.clone()); + max_executor.setup(&cks, sks.clone()); + + for num_blocks in 1..=128 { + println!("num_blocks: {num_blocks}"); + let Some(modulus) = (params.message_modulus().0 as u128).checked_pow(num_blocks as u32) + else { + break; + }; + for _ in 0..25 { + let clear_a = rng.gen_range(0..modulus); + let clear_b = rng.gen_range(0..modulus); + + let a: RadixCiphertext = sks.create_trivial_radix(clear_a, num_blocks); + + { + let result = lt_executor.execute((&a, clear_b)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a < clear_b, "{clear_a} < {clear_b}"); + + let result = lt_executor.execute((&a, clear_a)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a < clear_a, "{clear_a} < {clear_a}"); + } + + { + let result = le_executor.execute((&a, clear_b)); + let result = 
cks.decrypt_bool(&result); + assert_eq!(result, clear_a <= clear_b, "{clear_a} <= {clear_b}"); + + let result = le_executor.execute((&a, clear_a)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a <= clear_a, "{clear_a} <= {clear_a}"); + } + + { + let result = gt_executor.execute((&a, clear_b)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a > clear_b, "{clear_a} > {clear_b}"); + + let result = gt_executor.execute((&a, clear_a)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a > clear_a, "{clear_a} > {clear_a}"); + } + + { + let result = ge_executor.execute((&a, clear_b)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a >= clear_b, "{clear_a} >= {clear_b}"); + + let result = ge_executor.execute((&a, clear_a)); + let result = cks.decrypt_bool(&result); + assert_eq!(result, clear_a >= clear_a, "{clear_a} >= {clear_a}"); + } + + { + let result = min_executor.execute((&a, clear_b)); + let result: u128 = cks.decrypt(&result); + assert_eq!(result, clear_a.min(clear_b), "{clear_a}.min({clear_b})"); + + let result = min_executor.execute((&a, clear_a)); + let result: u128 = cks.decrypt(&result); + assert_eq!(result, clear_a.min(clear_a), "{clear_a}.min({clear_a})"); + } + + { + let result = max_executor.execute((&a, clear_b)); + let result: u128 = cks.decrypt(&result); + assert_eq!(result, clear_a.max(clear_b), "{clear_a}.max({clear_b})"); + + let result = max_executor.execute((&a, clear_a)); + let result: u128 = cks.decrypt(&result); + assert_eq!(result, clear_a.max(clear_a), "{clear_a}.max({clear_a})"); + } + } + } +}