diff --git a/tfhe/src/error.rs b/tfhe/src/error.rs index 51667eb481..58d7ceaf93 100644 --- a/tfhe/src/error.rs +++ b/tfhe/src/error.rs @@ -65,3 +65,12 @@ impl From for Error { unreachable!() } } + +/// Error returned when the provided range for a slice is invalid +#[derive(Debug)] +pub enum InvalidRangeError { + /// The upper bound of the range is greater than the size of the integer + SliceTooBig, + /// The upper gound is smaller than the lower bound + WrongOrder, +} diff --git a/tfhe/src/integer/server_key/radix/mod.rs b/tfhe/src/integer/server_key/radix/mod.rs index 1a25bd7659..6e5348a585 100644 --- a/tfhe/src/integer/server_key/radix/mod.rs +++ b/tfhe/src/integer/server_key/radix/mod.rs @@ -7,6 +7,7 @@ mod scalar_add; pub(super) mod scalar_mul; pub(super) mod scalar_sub; mod shift; +pub(super) mod slice; mod sub; use super::ServerKey; diff --git a/tfhe/src/integer/server_key/radix/slice.rs b/tfhe/src/integer/server_key/radix/slice.rs new file mode 100644 index 0000000000..c5f244dd04 --- /dev/null +++ b/tfhe/src/integer/server_key/radix/slice.rs @@ -0,0 +1,533 @@ +use std::ops::{Bound, RangeBounds}; + +use crate::error::InvalidRangeError; +use crate::integer::{RadixCiphertext, ServerKey}; +use crate::prelude::{CastFrom, CastInto}; +use crate::shortint; + +/// Normalize a rust bound object, and check that it is valid for the source integer +pub(crate) fn parse_bounds( + range: &R, + nb_bits: usize, +) -> Result<(usize, usize), InvalidRangeError> +where + R: RangeBounds, + B: CastFrom + CastInto + Copy, +{ + let start = match range.start_bound() { + Bound::Included(inc) => (*inc).cast_into(), + Bound::Excluded(excl) => (*excl).cast_into() - 1, + Bound::Unbounded => 0, + }; + + let end = match range.end_bound() { + Bound::Included(inc) => (*inc).cast_into() + 1, + Bound::Excluded(excl) => (*excl).cast_into(), + Bound::Unbounded => nb_bits, + }; + + if end > nb_bits { + Err(InvalidRangeError::SliceTooBig) + } else if start > end { + Err(InvalidRangeError::WrongOrder) + } else { + Ok((start, end)) + } +} + +/// This is the operation to extract a non-aligned block, on the clear. +/// For example, with a 2x4bits integer: |abcd|efgh|, extracting the block +/// at offset 2 will return |cdef|. This function should be used inside a LUT. +pub(in crate::integer) fn slice_oneblock_clear_unaligned( + cur_block: u64, + next_block: u64, + offset: usize, + block_size: usize, +) -> u64 { + cur_block >> (offset) | ((next_block << (block_size - offset)) % (1 << block_size)) +} + +impl ServerKey { + /// Extract a slice of blocks from a ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_block = 1; + /// let end_block = 2; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks.blockslice(&ct, start_block, end_block); + /// + /// let blocksize = cks.parameters().message_modulus().0.ilog2() as u64; + /// let start_bit = (start_block as u64) * blocksize; + /// let end_bit = (end_block as u64) * blocksize; + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn blockslice( + &self, + ctxt: &RadixCiphertext, + start_block: usize, + end_block: usize, + ) -> RadixCiphertext { + let limit = end_block - start_block; + + let mut result: RadixCiphertext = self.create_trivial_zero_radix(limit); + + for (res_i, c_i) in result.blocks[..limit] + .iter_mut() + .zip(ctxt.blocks[start_block..].iter()) + { + res_i.clone_from(c_i); + } + + result + } + + /// Extract a slice of blocks from a ciphertext. + /// + /// The result is assigned in the input ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_block = 1; + /// let end_block = 2; + /// + /// // Encrypt the message: + /// let mut ct = cks.encrypt(msg); + /// + /// sks.blockslice_assign(&mut ct, start_block, end_block); + /// + /// let blocksize = cks.parameters().message_modulus().0.ilog2() as u64; + /// let start_bit = (start_block as u64) * blocksize; + /// let end_bit = (end_block as u64) * blocksize; + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn blockslice_assign( + &self, + ctxt: &mut RadixCiphertext, + start_block: usize, + end_block: usize, + ) { + *ctxt = self.blockslice(ctxt, start_block, end_block); + } + + /// Return the unaligned remainder of a slice after all the unaligned full blocks have been + /// extracted. This is similar to what [`slice_interblock`] does on each block except that the + /// remainder is not a full block, so it will be truncated to `count` bits. + pub(in crate::integer) fn bitslice_remainder_unaligned( + &self, + ctxt: &RadixCiphertext, + block_idx: usize, + offset: usize, + count: usize, + ) -> shortint::Ciphertext { + let lut = self + .key + .generate_lookup_table_bivariate(|current_block, next_block| { + slice_oneblock_clear_unaligned( + current_block, + next_block, + offset, + self.message_modulus().0.ilog2() as usize, + ) % (1 << count) + }); + + self.key.apply_lookup_table_bivariate( + &ctxt.blocks[block_idx], + &ctxt + .blocks + .get(block_idx + 1) + .cloned() + .unwrap_or_else(|| self.key.create_trivial(0)), + &lut, + ) + } + + /// Returnsthe remainder of a slice after all the full blocks have been extracted. This will + /// simply truncate the block value to `count` bits. + pub(in crate::integer) fn bitslice_remainder( + &self, + ctxt: &RadixCiphertext, + block_idx: usize, + count: usize, + ) -> shortint::Ciphertext { + let lut = self.key.generate_lookup_table(|block| block % (1 << count)); + + self.key.apply_lookup_table(&ctxt.blocks[block_idx], &lut) + } + + /// Extract a slice from a ciphertext. The size of the slice is a multiple of the block + /// size but it is not aligned on block boundaries, so we need to mix block n and (n+1) toG + /// create a new block, using the lut function `slice_oneblock_clear_unaligned`. + fn blockslice_unaligned( + &self, + ctxt: &RadixCiphertext, + start_block: usize, + block_count: usize, + offset: usize, + ) -> RadixCiphertext { + let mut blocks = Vec::new(); + + let lut = self + .key + .generate_lookup_table_bivariate(|current_block, next_block| { + slice_oneblock_clear_unaligned( + current_block, + next_block, + offset, + self.message_modulus().0.ilog2() as usize, + ) + }); + + for idx in 0..block_count { + let block = self.key.apply_lookup_table_bivariate( + &ctxt.blocks[idx + start_block], + &ctxt.blocks[idx + start_block + 1], + &lut, + ); + + blocks.push(block); + } + + RadixCiphertext::from(blocks) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is returned as a new ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks + /// .unchecked_scalar_bitslice(&ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn unchecked_scalar_bitslice( + &self, + ctxt: &RadixCiphertext, + range: R, + ) -> Result + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + let block_width = self.message_modulus().0.ilog2() as usize; + let (start, end) = parse_bounds(&range, block_width * ctxt.blocks.len())?; + + let slice_width = end - start; + + // If the starting bit is block aligned, we can do most of the slicing with block copies. + // If it's not we must extract the bits with PBS. In either cases, we must extract the last + // bits with a PBS if the slice size is not a multiple of the block size. + let mut sliced = if start % block_width != 0 { + let mut sliced = self.blockslice_unaligned( + ctxt, + start / block_width, + slice_width / block_width, + start % block_width, + ); + + if slice_width % block_width != 0 { + let last_block = self.bitslice_remainder_unaligned( + ctxt, + start / block_width + slice_width / block_width, + start % block_width, + slice_width % block_width, + ); + sliced.blocks.push(last_block); + } + + sliced + } else { + let mut sliced = self.blockslice(ctxt, start / block_width, end / block_width); + if slice_width % block_width != 0 { + let last_block = + self.bitslice_remainder(ctxt, end / block_width, slice_width % block_width); + sliced.blocks.push(last_block); + } + + sliced + }; + + // Extend with trivial zeroes to return an integer of the same size as the input one. + self.extend_radix_with_trivial_zero_blocks_msb_assign(&mut sliced, ctxt.blocks.len()); + Ok(sliced) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is assigned to the input ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let mut ct = cks.encrypt(msg); + /// + /// sks.unchecked_scalar_bitslice_assign(&mut ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn unchecked_scalar_bitslice_assign( + &self, + ctxt: &mut RadixCiphertext, + range: R, + ) -> Result<(), InvalidRangeError> + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + *ctxt = self.unchecked_scalar_bitslice(ctxt, range)?; + Ok(()) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is returned as a new ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks.scalar_bitslice(&ct, start_bit..end_bit).unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn scalar_bitslice( + &self, + ctxt: &RadixCiphertext, + range: R, + ) -> Result + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + if ctxt.block_carries_are_empty() { + self.unchecked_scalar_bitslice(ctxt, range) + } else { + let mut ctxt = ctxt.clone(); + self.full_propagate(&mut ctxt); + self.unchecked_scalar_bitslice(&ctxt, range) + } + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is assigned to the input ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks.scalar_bitslice(&ct, start_bit..end_bit).unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn scalar_bitslice_assign( + &self, + ctxt: &mut RadixCiphertext, + range: R, + ) -> Result<(), InvalidRangeError> + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + if !ctxt.block_carries_are_empty() { + self.full_propagate(ctxt); + } + + self.unchecked_scalar_bitslice_assign(ctxt, range) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is returned as a new ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let mut ct = cks.encrypt(msg); + /// + /// let ct_res = sks + /// .smart_scalar_bitslice(&mut ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn smart_scalar_bitslice( + &self, + ctxt: &mut RadixCiphertext, + range: R, + ) -> Result + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + if !ctxt.block_carries_are_empty() { + self.full_propagate(ctxt); + } + + self.unchecked_scalar_bitslice(ctxt, range) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is assigned to the input ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let mut ct = cks.encrypt(msg); + /// + /// sks.smart_scalar_bitslice_assign(&mut ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn smart_scalar_bitslice_assign( + &self, + ctxt: &mut RadixCiphertext, + range: R, + ) -> Result<(), InvalidRangeError> + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + if !ctxt.block_carries_are_empty() { + self.full_propagate(ctxt); + } + + self.unchecked_scalar_bitslice_assign(ctxt, range) + } +} diff --git a/tfhe/src/integer/server_key/radix/tests.rs b/tfhe/src/integer/server_key/radix/tests.rs index 9192f3c6e1..64dc9dcb04 100644 --- a/tfhe/src/integer/server_key/radix/tests.rs +++ b/tfhe/src/integer/server_key/radix/tests.rs @@ -2,6 +2,11 @@ use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::*; use crate::integer::server_key::radix_parallel::tests_unsigned::test_add::smart_add_test; use crate::integer::server_key::radix_parallel::tests_unsigned::test_neg::smart_neg_test; +use crate::integer::server_key::radix_parallel::tests_unsigned::test_slice::{ + blockslice_assign_test, blockslice_test, default_scalar_bitslice_assign_test, + default_scalar_bitslice_test, smart_scalar_bitslice_assign_test, smart_scalar_bitslice_test, + unchecked_scalar_bitslice_assign_test, unchecked_scalar_bitslice_test, +}; use crate::integer::server_key::radix_parallel::tests_unsigned::test_sub::{ default_overflowing_sub_test, smart_sub_test, }; @@ -103,6 +108,14 @@ create_parametrized_test!( create_parametrized_test_classical_params!(integer_create_trivial_min_max); create_parametrized_test_classical_params!(integer_signed_decryption_correctly_sign_extend); +create_parametrized_test_classical_params!(integer_blockslice); +create_parametrized_test_classical_params!(integer_blockslice_assign); +create_parametrized_test_classical_params!(integer_unchecked_scalar_slice); +create_parametrized_test_classical_params!(integer_unchecked_scalar_slice_assign); +create_parametrized_test_classical_params!(integer_default_scalar_slice); +create_parametrized_test_classical_params!(integer_default_scalar_slice_assign); +create_parametrized_test_classical_params!(integer_smart_scalar_slice); +create_parametrized_test_classical_params!(integer_smart_scalar_slice_assign); fn integer_encrypt_decrypt(param: ClassicPBSParameters) { let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); @@ -1068,3 +1081,43 @@ fn integer_signed_decryption_correctly_sign_extend(param: impl Into().unwrap(), value as i128); } + +fn integer_blockslice(param: ClassicPBSParameters) { + let executor = CpuFunctionExecutor::new(&ServerKey::blockslice); + blockslice_test(param, executor); +} + +fn integer_blockslice_assign(param: ClassicPBSParameters) { + let executor = CpuFunctionExecutor::new(&ServerKey::blockslice_assign); + blockslice_assign_test(param, executor); +} + +fn integer_unchecked_scalar_slice(param: ClassicPBSParameters) { + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_bitslice); + unchecked_scalar_bitslice_test(param, executor); +} + +fn integer_unchecked_scalar_slice_assign(param: ClassicPBSParameters) { + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_bitslice_assign); + unchecked_scalar_bitslice_assign_test(param, executor); +} + +fn integer_default_scalar_slice(param: ClassicPBSParameters) { + let executor = CpuFunctionExecutor::new(&ServerKey::scalar_bitslice); + default_scalar_bitslice_test(param, executor); +} + +fn integer_default_scalar_slice_assign(param: ClassicPBSParameters) { + let executor = CpuFunctionExecutor::new(&ServerKey::scalar_bitslice_assign); + default_scalar_bitslice_assign_test(param, executor); +} + +fn integer_smart_scalar_slice(param: ClassicPBSParameters) { + let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_bitslice); + smart_scalar_bitslice_test(param, executor); +} + +fn integer_smart_scalar_slice_assign(param: ClassicPBSParameters) { + let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_bitslice_assign); + smart_scalar_bitslice_assign_test(param, executor); +} diff --git a/tfhe/src/integer/server_key/radix_parallel/mod.rs b/tfhe/src/integer/server_key/radix_parallel/mod.rs index 621a21fbba..2540b1868a 100644 --- a/tfhe/src/integer/server_key/radix_parallel/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/mod.rs @@ -23,6 +23,7 @@ mod sum; mod ilog2; mod reverse_bits; +mod slice; #[cfg(test)] pub(crate) mod tests_cases_unsigned; #[cfg(test)] diff --git a/tfhe/src/integer/server_key/radix_parallel/slice.rs b/tfhe/src/integer/server_key/radix_parallel/slice.rs new file mode 100644 index 0000000000..2d07199d90 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/slice.rs @@ -0,0 +1,361 @@ +use std::ops::RangeBounds; + +use rayon::prelude::*; + +use crate::error::InvalidRangeError; +use crate::integer::server_key::radix::slice::{parse_bounds, slice_oneblock_clear_unaligned}; +use crate::integer::{RadixCiphertext, ServerKey}; +use crate::prelude::{CastFrom, CastInto}; + +impl ServerKey { + /// Extract a slice from a ciphertext. The size of the slice is a multiple of the block + /// size but it is not aligned on block boundaries, so we need to mix block n and (n+1) to + /// create a new block, using the lut function `slice_oneblock_clear_unaligned`. + fn blockslice_unaligned_parallelized( + &self, + ctxt: &RadixCiphertext, + start_block: usize, + block_count: usize, + offset: usize, + ) -> RadixCiphertext { + let mut out: RadixCiphertext = self.create_trivial_zero_radix(block_count); + + let lut = self + .key + .generate_lookup_table_bivariate(|current_block, next_block| { + slice_oneblock_clear_unaligned( + current_block, + next_block, + offset, + self.message_modulus().0.ilog2() as usize, + ) + }); + + out.blocks + .par_iter_mut() + .enumerate() + .for_each(|(idx, block)| { + *block = self.key.apply_lookup_table_bivariate( + &ctxt.blocks[idx + start_block], + &ctxt.blocks[idx + start_block + 1], + &lut, + ); + }); + + out + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is returned as a new ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks + /// .unchecked_scalar_bitslice_parallelized(&ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn unchecked_scalar_bitslice_parallelized( + &self, + ctxt: &RadixCiphertext, + range: R, + ) -> Result + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + let block_width = self.message_modulus().0.ilog2() as usize; + let (start, end) = parse_bounds(&range, block_width * ctxt.blocks.len())?; + + let slice_width = end - start; + + // If the starting bit is block aligned, we can do most of the slicing with block copies. + // If it's not we must extract the bits with PBS. In either cases, we must extract the last + // bits with a PBS if the slice size is not a multiple of the block size. + let mut sliced = if start % block_width != 0 { + let mut sliced = self.blockslice_unaligned_parallelized( + ctxt, + start / block_width, + slice_width / block_width, + start % block_width, + ); + + if slice_width % block_width != 0 { + let last_block = self.bitslice_remainder_unaligned( + ctxt, + start / block_width + slice_width / block_width, + start % block_width, + slice_width % block_width, + ); + sliced.blocks.push(last_block); + } + + sliced + } else { + let mut sliced = self.blockslice(ctxt, start / block_width, end / block_width); + if slice_width % block_width != 0 { + let last_block = + self.bitslice_remainder(ctxt, end / block_width, slice_width % block_width); + sliced.blocks.push(last_block); + } + + sliced + }; + + // Extend with trivial zeroes to return an integer of the same size as the input one. + self.extend_radix_with_trivial_zero_blocks_msb_assign(&mut sliced, ctxt.blocks.len()); + Ok(sliced) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is assigned to the input ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let mut ct = cks.encrypt(msg); + /// + /// sks.unchecked_scalar_bitslice_assign_parallelized(&mut ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn unchecked_scalar_bitslice_assign_parallelized( + &self, + ctxt: &mut RadixCiphertext, + range: R, + ) -> Result<(), InvalidRangeError> + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + *ctxt = self.unchecked_scalar_bitslice_parallelized(ctxt, range)?; + Ok(()) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is returned as a new ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks + /// .scalar_bitslice_parallelized(&ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn scalar_bitslice_parallelized( + &self, + ctxt: &RadixCiphertext, + range: R, + ) -> Result + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + if ctxt.block_carries_are_empty() { + self.unchecked_scalar_bitslice_parallelized(ctxt, range) + } else { + let mut ctxt = ctxt.clone(); + self.full_propagate(&mut ctxt); + self.unchecked_scalar_bitslice_parallelized(&ctxt, range) + } + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is assigned to the input ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let mut ct = cks.encrypt(msg); + /// + /// sks.scalar_bitslice_assign_parallelized(&mut ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn scalar_bitslice_assign_parallelized( + &self, + ctxt: &mut RadixCiphertext, + range: R, + ) -> Result<(), InvalidRangeError> + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + if !ctxt.block_carries_are_empty() { + self.full_propagate(ctxt); + } + + self.unchecked_scalar_bitslice_assign_parallelized(ctxt, range) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is returned as a new ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let mut ct = cks.encrypt(msg); + /// + /// let ct_res = sks + /// .smart_scalar_bitslice_parallelized(&mut ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn smart_scalar_bitslice_parallelized( + &self, + ctxt: &mut RadixCiphertext, + range: R, + ) -> Result + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + if !ctxt.block_carries_are_empty() { + self.full_propagate(ctxt); + } + + self.unchecked_scalar_bitslice_parallelized(ctxt, range) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is assigned to the input ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let mut ct = cks.encrypt(msg); + /// + /// sks.smart_scalar_bitslice_assign(&mut ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear); + /// ``` + pub fn smart_scalar_bitslice_assign_parallelized( + &self, + ctxt: &mut RadixCiphertext, + range: R, + ) -> Result<(), InvalidRangeError> + where + R: RangeBounds, + B: CastFrom + CastInto + Copy, + { + if !ctxt.block_carries_are_empty() { + self.full_propagate(ctxt); + } + + self.unchecked_scalar_bitslice_assign_parallelized(ctxt, range) + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs index ed23a5580d..5a4bd65e61 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs @@ -16,6 +16,7 @@ pub(crate) mod test_scalar_rotate; pub(crate) mod test_scalar_shift; pub(crate) mod test_scalar_sub; pub(crate) mod test_shift; +pub(crate) mod test_slice; pub(crate) mod test_sub; pub(crate) mod test_sum; pub(crate) mod test_vector_comparisons; diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_slice.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_slice.rs new file mode 100644 index 0000000000..74b3c3a90d --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_slice.rs @@ -0,0 +1,443 @@ +use std::ops::{Range, RangeBounds}; +use std::sync::Arc; + +use rand::prelude::*; + +use crate::error::InvalidRangeError; +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix::slice::parse_bounds; +use crate::integer::tests::create_parametrized_test; +use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey}; +use crate::prelude::CastFrom; +#[cfg(tarpaulin)] +use crate::shortint::parameters::coverage_parameters::*; +use crate::shortint::parameters::*; + +use super::{nb_tests_for_params, CpuFunctionExecutor, FunctionExecutor, PBSParameters, NB_CTXT}; + +create_parametrized_test!(integer_unchecked_scalar_slice); +create_parametrized_test!(integer_unchecked_scalar_slice_assign); +create_parametrized_test!(integer_default_scalar_slice); +create_parametrized_test!(integer_default_scalar_slice_assign); +create_parametrized_test!(integer_smart_scalar_slice); +create_parametrized_test!(integer_smart_scalar_slice_assign); + +// Reference implementation of the slice using a conversion into a string of 0/1 to do the slicing. +fn slice_reference_impl(value: u64, bounds: B, modulus: u64) -> u64 +where + B: RangeBounds, + T: CastFrom + Copy, + usize: CastFrom, +{ + let (start, end) = parse_bounds(&bounds, modulus as usize).unwrap(); + + let bin: String = format!("{value:064b}").chars().rev().collect(); + + let out_bin: String = bin[start..end].chars().rev().collect(); + u64::from_str_radix(&out_bin, 2).unwrap_or_default() +} + +//============================================================================= +// Unchecked Tests +//============================================================================= + +pub(crate) fn blockslice_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a RadixCiphertext, usize, usize), RadixCiphertext>, +{ + let param = param.into(); + let nb_tests = nb_tests_for_params(param); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks); + + for _ in 0..nb_tests { + let clear = rng.gen::() % modulus; + + let range_a = rng.gen::() % (NB_CTXT as u32); + let range_b = rng.gen::() % (NB_CTXT as u32); + + let (block_start, block_end) = if range_a < range_b { + (range_a, range_b) + } else { + (range_b, range_a) + }; + + let bit_start = block_start * param.message_modulus().0.ilog2(); + let bit_end = block_end * param.message_modulus().0.ilog2(); + + let ct = cks.encrypt(clear); + + let ct_res = executor.execute((&ct, block_start as usize, block_end as usize)); + let dec_res: u64 = cks.decrypt(&ct_res); + assert_eq!( + slice_reference_impl(clear, bit_start..bit_end, modulus), + dec_res, + ); + } +} + +pub(crate) fn blockslice_assign_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, usize, usize), ()>, +{ + let param = param.into(); + let nb_tests = nb_tests_for_params(param); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks); + + for _ in 0..nb_tests { + let clear = rng.gen::() % modulus; + + let range_a = rng.gen::() % (NB_CTXT as u32); + let range_b = rng.gen::() % (NB_CTXT as u32); + + let (block_start, block_end) = if range_a < range_b { + (range_a, range_b) + } else { + (range_b, range_a) + }; + + let bit_start = block_start * param.message_modulus().0.ilog2(); + let bit_end = block_end * param.message_modulus().0.ilog2(); + + let mut ct = cks.encrypt(clear); + + executor.execute((&mut ct, block_start as usize, block_end as usize)); + let dec_res: u64 = cks.decrypt(&ct); + assert_eq!( + slice_reference_impl(clear, bit_start..bit_end, modulus), + dec_res, + ); + } +} + +pub(crate) fn unchecked_scalar_bitslice_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a RadixCiphertext, Range), + Result, + >, +{ + let param = param.into(); + let nb_tests = nb_tests_for_params(param); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks); + + for _ in 0..nb_tests { + let clear = rng.gen::() % modulus; + + let range_a = rng.gen::() % modulus.ilog2(); + let range_b = rng.gen::() % modulus.ilog2(); + + let (range_start, range_end) = if range_a < range_b { + (range_a, range_b) + } else { + (range_b, range_a) + }; + + let ct = cks.encrypt(clear); + + let ct_res = executor.execute((&ct, range_start..range_end)).unwrap(); + let dec_res: u64 = cks.decrypt(&ct_res); + assert_eq!( + slice_reference_impl(clear, range_start..range_end, modulus), + dec_res, + ); + } +} + +pub(crate) fn unchecked_scalar_bitslice_assign_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a mut RadixCiphertext, Range), + Result<(), InvalidRangeError>, + >, +{ + let param = param.into(); + let nb_tests = nb_tests_for_params(param); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks); + + for _ in 0..nb_tests { + let clear = rng.gen::() % modulus; + + let range_a = rng.gen::() % modulus.ilog2(); + let range_b = rng.gen::() % modulus.ilog2(); + + let (range_start, range_end) = if range_a < range_b { + (range_a, range_b) + } else { + (range_b, range_a) + }; + + let mut ct = cks.encrypt(clear); + + executor.execute((&mut ct, range_start..range_end)).unwrap(); + let dec_res: u64 = cks.decrypt(&ct); + assert_eq!( + slice_reference_impl(clear, range_start..range_end, modulus), + dec_res, + ); + } +} + +pub(crate) fn default_scalar_bitslice_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a RadixCiphertext, Range), + Result, + >, +{ + let param = param.into(); + let nb_tests = nb_tests_for_params(param); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks); + + for _ in 0..nb_tests { + let clear = rng.gen::() % modulus; + + let range_a = rng.gen::() % modulus.ilog2(); + let range_b = rng.gen::() % modulus.ilog2(); + + let (range_start, range_end) = if range_a < range_b { + (range_a, range_b) + } else { + (range_b, range_a) + }; + + let ct = cks.encrypt(clear); + + let ct_res = executor.execute((&ct, range_start..range_end)).unwrap(); + let dec_res: u64 = cks.decrypt(&ct_res); + assert_eq!( + slice_reference_impl(clear, range_start..range_end, modulus), + dec_res, + ); + } +} + +pub(crate) fn default_scalar_bitslice_assign_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a mut RadixCiphertext, Range), + Result<(), InvalidRangeError>, + >, +{ + let param = param.into(); + let nb_tests = nb_tests_for_params(param); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks); + + for _ in 0..nb_tests { + let clear = rng.gen::() % modulus; + + let range_a = rng.gen::() % modulus.ilog2(); + let range_b = rng.gen::() % modulus.ilog2(); + + let (range_start, range_end) = if range_a < range_b { + (range_a, range_b) + } else { + (range_b, range_a) + }; + + let mut ct = cks.encrypt(clear); + + executor.execute((&mut ct, range_start..range_end)).unwrap(); + let dec_res: u64 = cks.decrypt(&ct); + assert_eq!( + slice_reference_impl(clear, range_start..range_end, modulus), + dec_res, + ); + } +} + +pub(crate) fn smart_scalar_bitslice_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a mut RadixCiphertext, Range), + Result, + >, +{ + let param = param.into(); + let nb_tests = nb_tests_for_params(param); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks); + + for _ in 0..nb_tests { + let clear = rng.gen::() % modulus; + + let range_a = rng.gen::() % modulus.ilog2(); + let range_b = rng.gen::() % modulus.ilog2(); + + let (range_start, range_end) = if range_a < range_b { + (range_a, range_b) + } else { + (range_b, range_a) + }; + + let mut ct = cks.encrypt(clear); + + let ct_res = executor.execute((&mut ct, range_start..range_end)).unwrap(); + let dec_res: u64 = cks.decrypt(&ct_res); + assert_eq!( + slice_reference_impl(clear, range_start..range_end, modulus), + dec_res, + ); + } +} + +pub(crate) fn smart_scalar_bitslice_assign_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + (&'a mut RadixCiphertext, Range), + Result<(), InvalidRangeError>, + >, +{ + let param = param.into(); + let nb_tests = nb_tests_for_params(param); + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks); + + for _ in 0..nb_tests { + let clear = rng.gen::() % modulus; + + let range_a = rng.gen::() % modulus.ilog2(); + let range_b = rng.gen::() % modulus.ilog2(); + + let (range_start, range_end) = if range_a < range_b { + (range_a, range_b) + } else { + (range_b, range_a) + }; + + let mut ct = cks.encrypt(clear); + + executor.execute((&mut ct, range_start..range_end)).unwrap(); + let dec_res: u64 = cks.decrypt(&ct); + assert_eq!( + slice_reference_impl(clear, range_start..range_end, modulus), + dec_res, + ); + } +} + +fn integer_unchecked_scalar_slice

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_bitslice_parallelized); + unchecked_scalar_bitslice_test(param, executor); +} + +fn integer_unchecked_scalar_slice_assign

(param: P) +where + P: Into, +{ + let executor = + CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_bitslice_assign_parallelized); + unchecked_scalar_bitslice_assign_test(param, executor); +} + +fn integer_default_scalar_slice

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::scalar_bitslice_parallelized); + default_scalar_bitslice_test(param, executor); +} + +fn integer_default_scalar_slice_assign

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::scalar_bitslice_assign_parallelized); + default_scalar_bitslice_assign_test(param, executor); +} + +fn integer_smart_scalar_slice

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_bitslice_parallelized); + smart_scalar_bitslice_test(param, executor); +} + +fn integer_smart_scalar_slice_assign

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_bitslice_assign_parallelized); + smart_scalar_bitslice_assign_test(param, executor); +}