From 971b0cf0b677312739d1a1a1053f5ebd76387e93 Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Wed, 3 Apr 2024 17:57:58 +0200 Subject: [PATCH] feat(gpu): signed scalar rotate --- tfhe/benches/integer/signed_bench.rs | 28 ++ .../integers/unsigned/scalar_ops.rs | 4 +- .../gpu/server_key/radix/scalar_rotate.rs | 126 ++++--- .../gpu/server_key/radix/tests_signed/mod.rs | 1 + .../radix/tests_signed/test_scalar_rotate.rs | 46 +++ .../server_key/radix/tests_unsigned/mod.rs | 19 +- .../tests_unsigned/test_scalar_rotate.rs | 46 +++ .../radix_parallel/tests_cases_unsigned.rs | 272 --------------- .../radix_parallel/tests_signed/mod.rs | 204 +---------- .../tests_signed/test_scalar_rotate.rs | 269 +++++++++++++++ .../radix_parallel/tests_unsigned/mod.rs | 37 +- .../tests_unsigned/test_scalar_rotate.rs | 323 ++++++++++++++++++ 12 files changed, 777 insertions(+), 598 deletions(-) create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_rotate.rs create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_rotate.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_rotate.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_rotate.rs diff --git a/tfhe/benches/integer/signed_bench.rs b/tfhe/benches/integer/signed_bench.rs index ca4b4392be..855494d0f3 100644 --- a/tfhe/benches/integer/signed_bench.rs +++ b/tfhe/benches/integer/signed_bench.rs @@ -1722,6 +1722,18 @@ mod cuda { rng_func: shift_scalar ); + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: unchecked_scalar_rotate_right, + display_name: rotate_right, + rng_func: shift_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: unchecked_scalar_rotate_left, + display_name: rotate_left, + rng_func: shift_scalar + ); + //=========================================== // Default //=========================================== @@ -1874,6 +1886,18 @@ mod cuda { rng_func: shift_scalar ); + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: scalar_rotate_left, + display_name: rotate_left, + rng_func: shift_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: scalar_rotate_right, + display_name: rotate_right, + rng_func: shift_scalar + ); + criterion_group!( unchecked_cuda_ops, cuda_unchecked_add, @@ -1908,6 +1932,8 @@ mod cuda { cuda_unchecked_scalar_bitxor, cuda_unchecked_scalar_left_shift, cuda_unchecked_scalar_right_shift, + cuda_unchecked_scalar_rotate_left, + cuda_unchecked_scalar_rotate_right, ); criterion_group!( @@ -1944,6 +1970,8 @@ mod cuda { cuda_scalar_bitxor, cuda_scalar_left_shift, cuda_scalar_right_shift, + cuda_scalar_rotate_left, + cuda_scalar_rotate_right, ); } diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs index 075a63116a..c0bab8a68b 100644 --- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs @@ -770,7 +770,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_stream(|stream| { cuda_key.key.scalar_rotate_left( - &lhs.ciphertext.on_gpu(), u64::cast_from(rhs), stream + &*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), stream ) }); RadixCiphertext::Cuda(inner_result) @@ -808,7 +808,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_stream(|stream| { cuda_key.key.scalar_rotate_right( - &lhs.ciphertext.on_gpu(), u64::cast_from(rhs), stream + &*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), stream ) }); RadixCiphertext::Cuda(inner_result) diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs index 976e47f20d..78ce186755 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs @@ -1,6 +1,6 @@ use crate::core_crypto::gpu::CudaStream; use crate::core_crypto::prelude::CastFrom; -use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; +use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext; use crate::integer::gpu::server_key::CudaBootstrappingKey; use crate::integer::gpu::CudaServerKey; @@ -9,15 +9,16 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_rotate_left_async( + pub unsafe fn unchecked_scalar_rotate_left_async( &self, - ct: &CudaUnsignedRadixCiphertext, - n: T, + ct: &T, + n: Scalar, stream: &CudaStream, - ) -> CudaUnsignedRadixCiphertext + ) -> T where - T: CastFrom, - u32: CastFrom, + T: CudaIntegerRadixCiphertext, + Scalar: CastFrom, + u32: CastFrom, { let mut result = ct.duplicate_async(stream); self.unchecked_scalar_rotate_left_assign_async(&mut result, n, stream); @@ -28,14 +29,15 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_rotate_left_assign_async( + pub unsafe fn unchecked_scalar_rotate_left_assign_async( &self, - ct: &mut CudaUnsignedRadixCiphertext, - n: T, + ct: &mut T, + n: Scalar, stream: &CudaStream, ) where - T: CastFrom, - u32: CastFrom, + T: CudaIntegerRadixCiphertext, + Scalar: CastFrom, + u32: CastFrom, { let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count(); match &self.bootstrapping_key { @@ -89,15 +91,16 @@ impl CudaServerKey { } } - pub fn unchecked_scalar_rotate_left( + pub fn unchecked_scalar_rotate_left( &self, - ct: &CudaUnsignedRadixCiphertext, - n: T, + ct: &T, + n: Scalar, stream: &CudaStream, - ) -> CudaUnsignedRadixCiphertext + ) -> T where - T: CastFrom, - u32: CastFrom, + T: CudaIntegerRadixCiphertext, + Scalar: CastFrom, + u32: CastFrom, { let result = unsafe { self.unchecked_scalar_rotate_left_async(ct, n, stream) }; stream.synchronize(); @@ -108,15 +111,16 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_rotate_right_async( + pub unsafe fn unchecked_scalar_rotate_right_async( &self, - ct: &CudaUnsignedRadixCiphertext, - n: T, + ct: &T, + n: Scalar, stream: &CudaStream, - ) -> CudaUnsignedRadixCiphertext + ) -> T where - T: CastFrom, - u32: CastFrom, + T: CudaIntegerRadixCiphertext, + Scalar: CastFrom, + u32: CastFrom, { let mut result = ct.duplicate_async(stream); self.unchecked_scalar_rotate_right_assign_async(&mut result, n, stream); @@ -127,14 +131,15 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_scalar_rotate_right_assign_async( + pub unsafe fn unchecked_scalar_rotate_right_assign_async( &self, - ct: &mut CudaUnsignedRadixCiphertext, - n: T, + ct: &mut T, + n: Scalar, stream: &CudaStream, ) where - T: CastFrom, - u32: CastFrom, + T: CudaIntegerRadixCiphertext, + Scalar: CastFrom, + u32: CastFrom, { let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count(); match &self.bootstrapping_key { @@ -188,29 +193,27 @@ impl CudaServerKey { } } - pub fn unchecked_scalar_rotate_right( + pub fn unchecked_scalar_rotate_right( &self, - ct: &CudaUnsignedRadixCiphertext, - n: T, + ct: &T, + n: Scalar, stream: &CudaStream, - ) -> CudaUnsignedRadixCiphertext + ) -> T where - T: CastFrom, - u32: CastFrom, + T: CudaIntegerRadixCiphertext, + Scalar: CastFrom, + u32: CastFrom, { let result = unsafe { self.unchecked_scalar_rotate_right_async(ct, n, stream) }; stream.synchronize(); result } - pub fn scalar_rotate_left_assign( - &self, - ct: &mut CudaUnsignedRadixCiphertext, - n: T, - stream: &CudaStream, - ) where - T: CastFrom, - u32: CastFrom, + pub fn scalar_rotate_left_assign(&self, ct: &mut T, n: Scalar, stream: &CudaStream) + where + T: CudaIntegerRadixCiphertext, + Scalar: CastFrom, + u32: CastFrom, { if !ct.block_carries_are_empty() { unsafe { @@ -222,14 +225,11 @@ impl CudaServerKey { stream.synchronize(); } - pub fn scalar_rotate_right_assign( - &self, - ct: &mut CudaUnsignedRadixCiphertext, - n: T, - stream: &CudaStream, - ) where - T: CastFrom, - u32: CastFrom, + pub fn scalar_rotate_right_assign(&self, ct: &mut T, n: Scalar, stream: &CudaStream) + where + T: CudaIntegerRadixCiphertext, + Scalar: CastFrom, + u32: CastFrom, { if !ct.block_carries_are_empty() { unsafe { @@ -241,30 +241,22 @@ impl CudaServerKey { stream.synchronize(); } - pub fn scalar_rotate_left( - &self, - ct: &CudaUnsignedRadixCiphertext, - shift: T, - stream: &CudaStream, - ) -> CudaUnsignedRadixCiphertext + pub fn scalar_rotate_left(&self, ct: &T, shift: Scalar, stream: &CudaStream) -> T where - T: CastFrom, - u32: CastFrom, + T: CudaIntegerRadixCiphertext, + Scalar: CastFrom, + u32: CastFrom, { let mut result = unsafe { ct.duplicate_async(stream) }; self.scalar_rotate_left_assign(&mut result, shift, stream); result } - pub fn scalar_rotate_right( - &self, - ct: &CudaUnsignedRadixCiphertext, - shift: T, - stream: &CudaStream, - ) -> CudaUnsignedRadixCiphertext + pub fn scalar_rotate_right(&self, ct: &T, shift: Scalar, stream: &CudaStream) -> T where - T: CastFrom, - u32: CastFrom, + T: CudaIntegerRadixCiphertext, + Scalar: CastFrom, + u32: CastFrom, { let mut result = unsafe { ct.duplicate_async(stream) }; self.scalar_rotate_right_assign(&mut result, shift, stream); diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs index 14fb95f29b..778e61170a 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs @@ -7,6 +7,7 @@ pub(crate) mod test_rotate; pub(crate) mod test_scalar_add; pub(crate) mod test_scalar_bitwise_op; pub(crate) mod test_scalar_mul; +pub(crate) mod test_scalar_rotate; pub(crate) mod test_scalar_shift; pub(crate) mod test_scalar_sub; pub(crate) mod test_shift; diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_rotate.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_rotate.rs new file mode 100644 index 0000000000..24dd30655c --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_rotate.rs @@ -0,0 +1,46 @@ +use crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parametrized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_rotate::{ + signed_default_scalar_rotate_left_test, signed_default_scalar_rotate_right_test, + signed_unchecked_scalar_rotate_left_test, signed_unchecked_scalar_rotate_right_test, +}; +use crate::shortint::parameters::*; + +create_gpu_parametrized_test!(integer_signed_unchecked_scalar_rotate_left); +create_gpu_parametrized_test!(integer_signed_scalar_rotate_left); +create_gpu_parametrized_test!(integer_signed_unchecked_scalar_rotate_right); +create_gpu_parametrized_test!(integer_signed_scalar_rotate_right); + +fn integer_signed_unchecked_scalar_rotate_left

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_rotate_left); + signed_unchecked_scalar_rotate_left_test(param, executor); +} + +fn integer_signed_scalar_rotate_left

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_rotate_left); + signed_default_scalar_rotate_left_test(param, executor); +} + +fn integer_signed_unchecked_scalar_rotate_right

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_rotate_right); + signed_unchecked_scalar_rotate_right_test(param, executor); +} + +fn integer_signed_scalar_rotate_right

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_rotate_right); + signed_default_scalar_rotate_right_test(param, executor); +} diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs index ea627da165..58e10a9b85 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs @@ -8,6 +8,7 @@ pub(crate) mod test_scalar_add; pub(crate) mod test_scalar_bitwise_op; pub(crate) mod test_scalar_comparison; pub(crate) mod test_scalar_mul; +pub(crate) mod test_scalar_rotate; pub(crate) mod test_scalar_shift; pub(crate) mod test_scalar_sub; pub(crate) mod test_shift; @@ -85,8 +86,6 @@ impl GpuFunctionExecutor { // Unchecked operations create_gpu_parametrized_test!(integer_unchecked_if_then_else); -create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_left); -create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_right); // Default operations create_gpu_parametrized_test!(integer_if_then_else); @@ -477,22 +476,6 @@ where } } -fn integer_unchecked_scalar_rotate_left

(param: P) -where - P: Into + Copy, -{ - let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_rotate_left); - unchecked_scalar_rotate_left_test(param, executor); -} - -fn integer_unchecked_scalar_rotate_right

(param: P) -where - P: Into + Copy, -{ - let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_rotate_right); - unchecked_scalar_rotate_right_test(param, executor); -} - fn integer_if_then_else

(param: P) where P: Into + Copy, diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_rotate.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_rotate.rs new file mode 100644 index 0000000000..df19d3190d --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_rotate.rs @@ -0,0 +1,46 @@ +use crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parametrized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_unsigned::test_scalar_rotate::{ + default_scalar_rotate_left_test, default_scalar_rotate_right_test, + unchecked_scalar_rotate_left_test, unchecked_scalar_rotate_right_test, +}; +use crate::shortint::parameters::*; + +create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_left); +create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_right); +create_gpu_parametrized_test!(integer_scalar_rotate_left); +create_gpu_parametrized_test!(integer_scalar_rotate_right); + +fn integer_unchecked_scalar_rotate_right

(param: P) +where + P: Into + Copy, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_rotate_right); + unchecked_scalar_rotate_right_test(param, executor); +} + +fn integer_scalar_rotate_right

(param: P) +where + P: Into + Copy, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_rotate_right); + default_scalar_rotate_right_test(param, executor); +} + +fn integer_unchecked_scalar_rotate_left

(param: P) +where + P: Into + Copy, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_rotate_left); + unchecked_scalar_rotate_left_test(param, executor); +} + +fn integer_scalar_rotate_left

(param: P) +where + P: Into + Copy, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_rotate_left); + default_scalar_rotate_left_test(param, executor); +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index feb86341af..eafe2a50c1 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -877,134 +877,6 @@ where } } -pub(crate) fn unchecked_scalar_rotate_left_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let sks = Arc::new(sks); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - let nb_bits = modulus.ilog2(); - let bits_per_block = cks.parameters().message_modulus().0.ilog2(); - - executor.setup(&cks, sks); - - for _ in 0..(NB_TESTS / 3).max(1) { - let clear = rng.gen::() % modulus; - let scalar = rng.gen::(); - - let ct = cks.encrypt(clear); - - // Force case where n is multiple of block size - { - let scalar = scalar - (scalar % bits_per_block); - let encrypted_result = executor.execute((&ct, scalar as u64)); - assert!(encrypted_result.block_carries_are_empty()); - let decrypted_result: u64 = cks.decrypt(&encrypted_result); - let expected = rotate_left_helper(clear, scalar, nb_bits); - assert_eq!(expected, decrypted_result); - } - - // Force case where n is not multiple of block size - { - let rest = scalar % bits_per_block; - let scalar = if rest == 0 { - scalar + (rng.gen::() % bits_per_block) - } else { - scalar - }; - let encrypted_result = executor.execute((&ct, scalar as u64)); - assert!(encrypted_result.block_carries_are_empty()); - let decrypted_result: u64 = cks.decrypt(&encrypted_result); - let expected = rotate_left_helper(clear, scalar, nb_bits); - assert_eq!(expected, decrypted_result); - } - - // Force case where - // The value is non zero - // we rotate so that at least one non zero bit, cycle/wraps around - { - let value = rng.gen_range(1..=u32::MAX); - let scalar = value.leading_zeros() + rng.gen_range(1..nb_bits); - let encrypted_result = executor.execute((&ct, scalar as u64)); - assert!(encrypted_result.block_carries_are_empty()); - let decrypted_result: u64 = cks.decrypt(&encrypted_result); - let expected = rotate_left_helper(clear, scalar, nb_bits); - assert_eq!(expected, decrypted_result); - } - } -} - -pub(crate) fn unchecked_scalar_rotate_right_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, -{ - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let sks = Arc::new(sks); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - let nb_bits = modulus.ilog2(); - let bits_per_block = cks.parameters().message_modulus().0.ilog2(); - - executor.setup(&cks, sks); - - for _ in 0..(NB_TESTS / 3).max(1) { - let clear = rng.gen::() % modulus; - let scalar = rng.gen::(); - - let ct = cks.encrypt(clear); - - // Force case where n is multiple of block size - { - let scalar = scalar - (scalar % bits_per_block); - let encrypted_result = executor.execute((&ct, scalar as u64)); - assert!(encrypted_result.block_carries_are_empty()); - let decrypted_result: u64 = cks.decrypt(&encrypted_result); - let expected = rotate_right_helper(clear, scalar, nb_bits); - assert_eq!(expected, decrypted_result); - } - - // Force case where n is not multiple of block size - { - let rest = scalar % bits_per_block; - let scalar = if rest == 0 { - scalar + (rng.gen::() % bits_per_block) - } else { - scalar - }; - let encrypted_result = executor.execute((&ct, scalar as u64)); - assert!(encrypted_result.block_carries_are_empty()); - let decrypted_result: u64 = cks.decrypt(&encrypted_result); - let expected = rotate_right_helper(clear, scalar, nb_bits); - assert_eq!(expected, decrypted_result); - } - - // Force case where - // The value is non zero - // we rotate so that at least one non zero bit, cycle/wraps around - { - let value = rng.gen_range(1..=u32::MAX); - let scalar = value.trailing_zeros() + rng.gen_range(1..nb_bits); - let encrypted_result = executor.execute((&ct, scalar as u64)); - assert!(encrypted_result.block_carries_are_empty()); - let decrypted_result: u64 = cks.decrypt(&encrypted_result); - let expected = rotate_right_helper(clear, scalar, nb_bits); - assert_eq!(expected, decrypted_result); - } - } -} - //============================================================================= // Smart Tests //============================================================================= @@ -3251,150 +3123,6 @@ where } } -pub(crate) fn default_scalar_rotate_right_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - executor.setup(&cks, sks); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - let nb_bits = modulus.ilog2(); - let bits_per_block = cks.parameters().message_modulus().0.ilog2(); - - for _ in 0..(NB_TESTS / 2).max(1) { - let clear = rng.gen::() % modulus; - let scalar = rng.gen::(); - - let ct = cks.encrypt(clear); - - // Force case where n is multiple of block size - { - let scalar = scalar - (scalar % bits_per_block); - let ct_res = executor.execute((&ct, scalar as u64)); - let tmp = executor.execute((&ct, scalar as u64)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - let dec_res: u64 = cks.decrypt(&ct_res); - let expected = rotate_right_helper(clear, scalar, nb_bits); - assert_eq!(expected, dec_res); - } - - // Force case where n is not multiple of block size - { - let rest = scalar % bits_per_block; - let scalar = if rest == 0 { - scalar + (rng.gen::() % bits_per_block) - } else { - scalar - }; - let ct_res = executor.execute((&ct, scalar as u64)); - let tmp = executor.execute((&ct, scalar as u64)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - let dec_res: u64 = cks.decrypt(&ct_res); - let expected = rotate_right_helper(clear, scalar, nb_bits); - assert_eq!(expected, dec_res); - } - - // Force case where - // The value is non zero - // we rotate so that at least one non zero bit, cycle/wraps around - { - let value = rng.gen_range(1..=u32::MAX); - let scalar = value.trailing_zeros() + rng.gen_range(1..nb_bits); - let ct_res = executor.execute((&ct, scalar as u64)); - let tmp = executor.execute((&ct, scalar as u64)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - let dec_res: u64 = cks.decrypt(&ct_res); - let expected = rotate_right_helper(clear, scalar, nb_bits); - assert_eq!(expected, dec_res); - } - } -} - -pub(crate) fn default_scalar_rotate_left_test(param: P, mut executor: T) -where - P: Into, - T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - let nb_bits = modulus.ilog2(); - let bits_per_block = cks.parameters().message_modulus().0.ilog2(); - - executor.setup(&cks, sks); - - for _ in 0..(NB_TESTS / 3).max(1) { - let clear = rng.gen::() % modulus; - let scalar = rng.gen::(); - - let ct = cks.encrypt(clear); - - // Force case where n is multiple of block size - { - let scalar = scalar - (scalar % bits_per_block); - let ct_res = executor.execute((&ct, scalar as u64)); - let tmp = executor.execute((&ct, scalar as u64)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - let dec_res: u64 = cks.decrypt(&ct_res); - let expected = rotate_left_helper(clear, scalar, nb_bits); - assert_eq!(expected, dec_res); - } - - // Force case where n is not multiple of block size - { - let rest = scalar % bits_per_block; - let scalar = if rest == 0 { - scalar + (rng.gen::() % bits_per_block) - } else { - scalar - }; - let ct_res = executor.execute((&ct, scalar as u64)); - let tmp = executor.execute((&ct, scalar as u64)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - let dec_res: u64 = cks.decrypt(&ct_res); - let expected = rotate_left_helper(clear, scalar, nb_bits); - assert_eq!(expected, dec_res); - } - - // Force case where - // The value is non zero - // we rotate so that at least one non zero bit, cycle/wraps around - { - let value = rng.gen_range(1..=u32::MAX); - let scalar = value.leading_zeros() + rng.gen_range(1..nb_bits); - let ct_res = executor.execute((&ct, scalar as u64)); - let tmp = executor.execute((&ct, scalar as u64)); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - let dec_res: u64 = cks.decrypt(&ct_res); - let expected = rotate_left_helper(clear, scalar, nb_bits); - assert_eq!(expected, dec_res); - } - } -} - pub(crate) fn default_scalar_div_rem_test(param: P, mut executor: T) where P: Into, diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs index 12024e770d..42266495c2 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs @@ -9,6 +9,7 @@ pub(crate) mod test_scalar_add; pub(crate) mod test_scalar_bitwise_op; pub(crate) mod test_scalar_comparison; pub(crate) mod test_scalar_mul; +pub(crate) mod test_scalar_rotate; pub(crate) mod test_scalar_shift; pub(crate) mod test_scalar_sub; pub(crate) mod test_shift; @@ -511,84 +512,9 @@ where //================================================================================ // Unchecked Scalar Tests //================================================================================ - -create_parametrized_test!(integer_signed_unchecked_scalar_rotate_left); -create_parametrized_test!(integer_signed_unchecked_scalar_rotate_right); create_parametrized_test!(integer_signed_unchecked_scalar_div_rem); create_parametrized_test!(integer_signed_unchecked_scalar_div_rem_floor); -fn integer_signed_unchecked_scalar_rotate_left(param: impl Into) { - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - assert!(modulus > 0); - assert!((modulus as u64).is_power_of_two()); - let nb_bits = modulus.ilog2() + 1; // We are using signed numbers - - for _ in 0..NB_TESTS { - let clear = rng.gen::() % modulus; - let clear_shift = rng.gen::(); - - let ct = cks.encrypt_signed_radix(clear, NB_CTXT); - - // case when 0 <= rotate < nb_bits - { - let clear_shift = clear_shift % nb_bits; - let ct_res = sks.unchecked_scalar_rotate_left_parallelized(&ct, clear_shift); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - let expected = rotate_left_helper(clear, clear_shift, nb_bits); - assert_eq!(expected, dec_res); - } - - // case when rotate >= nb_bits - { - let clear_shift = clear_shift.saturating_add(nb_bits); - let ct_res = sks.unchecked_scalar_rotate_left_parallelized(&ct, clear_shift); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - let expected = rotate_left_helper(clear, clear_shift, nb_bits); - assert_eq!(expected, dec_res); - } - } -} - -fn integer_signed_unchecked_scalar_rotate_right(param: impl Into) { - let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - assert!(modulus > 0); - assert!((modulus as u64).is_power_of_two()); - let nb_bits = modulus.ilog2() + 1; // We are using signed numbers - - for _ in 0..NB_TESTS { - let clear = rng.gen::() % modulus; - let clear_shift = rng.gen::(); - - let ct = cks.encrypt_signed_radix(clear, NB_CTXT); - - // case when 0 <= rotate < nb_bits - { - let clear_shift = clear_shift % nb_bits; - let ct_res = sks.unchecked_scalar_rotate_right_parallelized(&ct, clear_shift); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - let expected = rotate_right_helper(clear, clear_shift, nb_bits); - assert_eq!(expected, dec_res); - } - - // case when rotate >= nb_bits - { - let clear_shift = clear_shift.saturating_add(nb_bits); - let ct_res = sks.unchecked_scalar_rotate_right_parallelized(&ct, clear_shift); - let dec_res: i64 = cks.decrypt_signed_radix(&ct_res); - let expected = rotate_right_helper(clear, clear_shift, nb_bits); - assert_eq!(expected, dec_res); - } - } -} - fn integer_signed_unchecked_scalar_div_rem(param: impl Into) { let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); @@ -846,8 +772,6 @@ fn integer_signed_unchecked_scalar_div_rem_floor(param: impl Into //================================================================================ create_parametrized_test!(integer_signed_default_scalar_div_rem); -create_parametrized_test!(integer_signed_default_scalar_rotate_right); -create_parametrized_test!(integer_signed_default_scalar_rotate_left); fn integer_signed_default_scalar_div_rem(param: impl Into) { let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); @@ -901,132 +825,6 @@ fn integer_signed_default_scalar_div_rem(param: impl Into) { } } -fn integer_signed_default_scalar_rotate_left

(param: P) -where - P: Into, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - assert!(modulus > 0); - assert!((modulus as u64).is_power_of_two()); - let nb_bits = modulus.ilog2() + 1; // We are using signed numbers - - for _ in 0..NB_TESTS_SMALLER { - let mut clear = rng.gen::() % modulus; - - let offset = random_non_zero_value(&mut rng, modulus); - - let mut ct = cks.encrypt_signed(clear); - sks.unchecked_scalar_add_assign(&mut ct, offset); - clear = signed_add_under_modulus(clear, offset, modulus); - - // case when 0 <= shift < nb_bits - { - let clear_shift = rng.gen::() % nb_bits; - let ct_res = sks.scalar_rotate_left_parallelized(&ct, clear_shift); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = rotate_left_helper(clear, clear_shift, nb_bits); - assert_eq!( - clear_res, dec_res, - "Invalid left shift result, for '{clear}.rotate_left({clear_shift})', \ - expected: {clear_res}, got: {dec_res}" - ); - - let ct_res2 = sks.scalar_rotate_left_parallelized(&ct, clear_shift); - assert_eq!(ct_res, ct_res2, "Failed determinism check"); - } - - // case when shift >= nb_bits - { - let clear_shift = rng.gen_range(nb_bits..=u32::MAX); - let ct_res = sks.scalar_rotate_left_parallelized(&ct, clear_shift); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - // We mimic wrapping_shl manually as we use a bigger type - // than the nb_bits we actually simulate in this test - let clear_res = rotate_left_helper(clear, clear_shift, nb_bits); - assert_eq!( - clear_res, - dec_res, - "Invalid rotate left result, for '{clear}.rotate_left({})', \ - expected: {clear_res}, got: {dec_res}", - clear_shift % nb_bits - ); - - let ct_res2 = sks.scalar_rotate_left_parallelized(&ct, clear_shift); - assert_eq!(ct_res, ct_res2, "Failed determinism check"); - } - } -} - -fn integer_signed_default_scalar_rotate_right

(param: P) -where - P: Into, -{ - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - - let mut rng = rand::thread_rng(); - - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - assert!(modulus > 0); - assert!((modulus as u64).is_power_of_two()); - let nb_bits = modulus.ilog2() + 1; // We are using signed numbers - - for _ in 0..NB_TESTS_SMALLER { - let mut clear = rng.gen::() % modulus; - - let offset = random_non_zero_value(&mut rng, modulus); - - let mut ct = cks.encrypt_signed(clear); - sks.unchecked_scalar_add_assign(&mut ct, offset); - clear = signed_add_under_modulus(clear, offset, modulus); - - // case when 0 <= shift < nb_bits - { - let clear_shift = rng.gen::() % nb_bits; - let ct_res = sks.scalar_rotate_right_parallelized(&ct, clear_shift); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - let clear_res = rotate_right_helper(clear, clear_shift, nb_bits); - assert_eq!( - clear_res, dec_res, - "Invalid right shift result, for '{clear}.rotate_right({clear_shift})', \ - expected: {clear_res}, got: {dec_res}" - ); - - let ct_res2 = sks.scalar_rotate_right_parallelized(&ct, clear_shift); - assert_eq!(ct_res, ct_res2, "Failed determinism check"); - } - - // case when shift >= nb_bits - { - let clear_shift = rng.gen_range(nb_bits..=u32::MAX); - let ct_res = sks.scalar_rotate_right_parallelized(&ct, clear_shift); - let dec_res: i64 = cks.decrypt_signed(&ct_res); - // We mimic wrapping_shl manually as we use a bigger type - // than the nb_bits we actually simulate in this test - let clear_res = rotate_right_helper(clear, clear_shift, nb_bits); - assert_eq!( - clear_res, dec_res, - "Invalid rotate right result, for '{clear}.rotate_right({clear_shift})', \ - expected: {clear_res}, got: {dec_res}" - ); - - let ct_res2 = sks.scalar_rotate_right_parallelized(&ct, clear_shift); - assert_eq!(ct_res, ct_res2, "Failed determinism check"); - } - } -} - //================================================================================ // Helper functions //================================================================================ diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_rotate.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_rotate.rs new file mode 100644 index 0000000000..0a7c375074 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_rotate.rs @@ -0,0 +1,269 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + random_non_zero_value, rotate_left_helper, rotate_right_helper, signed_add_under_modulus, + NB_CTXT, NB_TESTS, NB_TESTS_SMALLER, +}; +use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::tests::create_parametrized_test; +use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext}; +#[cfg(tarpaulin)] +use crate::shortint::parameters::coverage_parameters::*; +use crate::shortint::parameters::*; +use rand::Rng; +use std::sync::Arc; + +create_parametrized_test!(integer_signed_unchecked_scalar_rotate_left); +create_parametrized_test!(integer_signed_default_scalar_rotate_left); +create_parametrized_test!(integer_signed_unchecked_scalar_rotate_right); +create_parametrized_test!(integer_signed_default_scalar_rotate_right); + +fn integer_signed_unchecked_scalar_rotate_left

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_rotate_left_parallelized); + signed_unchecked_scalar_rotate_left_test(param, executor); +} + +fn integer_signed_default_scalar_rotate_left

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized); + signed_default_scalar_rotate_left_test(param, executor); +} + +fn integer_signed_unchecked_scalar_rotate_right

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_rotate_right_parallelized); + signed_unchecked_scalar_rotate_right_test(param, executor); +} + +fn integer_signed_default_scalar_rotate_right

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized); + signed_default_scalar_rotate_right_test(param, executor); +} + +pub(crate) fn signed_unchecked_scalar_rotate_left_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + executor.setup(&cks, sks); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + assert!(modulus > 0); + assert!((modulus as u64).is_power_of_two()); + let nb_bits = modulus.ilog2() + 1; // We are using signed numbers + + for _ in 0..NB_TESTS { + let clear = rng.gen::() % modulus; + let clear_shift = rng.gen::(); + + let ct = cks.encrypt_signed(clear); + + // case when 0 <= rotate < nb_bits + { + let clear_shift = clear_shift % nb_bits; + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let expected = rotate_left_helper(clear, clear_shift, nb_bits); + assert_eq!(expected, dec_res); + } + + // case when rotate >= nb_bits + { + let clear_shift = clear_shift.saturating_add(nb_bits); + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let expected = rotate_left_helper(clear, clear_shift, nb_bits); + assert_eq!(expected, dec_res); + } + } +} + +pub(crate) fn signed_unchecked_scalar_rotate_right_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + executor.setup(&cks, sks); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + assert!(modulus > 0); + assert!((modulus as u64).is_power_of_two()); + let nb_bits = modulus.ilog2() + 1; // We are using signed numbers + + for _ in 0..NB_TESTS { + let clear = rng.gen::() % modulus; + let clear_shift = rng.gen::(); + + let ct = cks.encrypt_signed(clear); + + // case when 0 <= shift < nb_bits + { + let clear_shift = clear_shift % nb_bits; + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let expected = rotate_right_helper(clear, clear_shift, nb_bits); + assert_eq!(expected, dec_res); + } + + // case when shift >= nb_bits + { + let clear_shift = clear_shift.saturating_add(nb_bits); + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let expected = rotate_right_helper(clear, clear_shift, nb_bits); + assert_eq!(expected, dec_res); + } + } +} + +pub(crate) fn signed_default_scalar_rotate_left_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + executor.setup(&cks, sks.clone()); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + assert!(modulus > 0); + assert!((modulus as u64).is_power_of_two()); + let nb_bits = modulus.ilog2() + 1; // We are using signed numbers + + for _ in 0..NB_TESTS_SMALLER { + let mut clear = rng.gen::() % modulus; + + let offset = random_non_zero_value(&mut rng, modulus); + + let mut ct = cks.encrypt_signed(clear); + sks.unchecked_scalar_add_assign(&mut ct, offset); + clear = signed_add_under_modulus(clear, offset, modulus); + + // case when 0 <= shift < nb_bits + { + let clear_shift = rng.gen::() % nb_bits; + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = rotate_left_helper(clear, clear_shift, nb_bits); + assert_eq!( + clear_res, dec_res, + "Invalid left shift result, for '{clear} << {clear_shift}', \ + expected: {clear_res}, got: {dec_res}" + ); + + let ct_res2 = executor.execute((&ct, clear_shift as i64)); + assert_eq!(ct_res, ct_res2, "Failed determinism check"); + } + + // case when shift >= nb_bits + { + let clear_shift = rng.gen_range(nb_bits..=u32::MAX); + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + // We mimic wrapping_shl manually as we use a bigger type + // than the nb_bits we actually simulate in this test + let clear_res = rotate_left_helper(clear, clear_shift, nb_bits); + assert_eq!( + clear_res, dec_res, + "Invalid left shift result, for '{clear} << {clear_shift}', \ + expected: {clear_res}, got: {dec_res}" + ); + + let ct_res2 = executor.execute((&ct, clear_shift as i64)); + assert_eq!(ct_res, ct_res2, "Failed determinism check"); + } + } +} + +pub(crate) fn signed_default_scalar_rotate_right_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + executor.setup(&cks, sks.clone()); + + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + assert!(modulus > 0); + assert!((modulus as u64).is_power_of_two()); + let nb_bits = modulus.ilog2() + 1; // We are using signed numbers + + for _ in 0..NB_TESTS_SMALLER { + let mut clear = rng.gen::() % modulus; + + let offset = random_non_zero_value(&mut rng, modulus); + + let mut ct = cks.encrypt_signed(clear); + sks.unchecked_scalar_add_assign(&mut ct, offset); + clear = signed_add_under_modulus(clear, offset, modulus); + + // case when 0 <= shift < nb_bits + { + let clear_shift = rng.gen::() % nb_bits; + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + let clear_res = rotate_right_helper(clear, clear_shift, nb_bits); + assert_eq!( + clear_res, dec_res, + "Invalid right shift result, for '{clear} >> {clear_shift}', \ + expected: {clear_res}, got: {dec_res}" + ); + + let ct_res2 = executor.execute((&ct, clear_shift as i64)); + assert_eq!(ct_res, ct_res2, "Failed determinism check"); + } + + // case when shift >= nb_bits + { + let clear_shift = rng.gen_range(nb_bits..=u32::MAX); + let ct_res = executor.execute((&ct, clear_shift as i64)); + let dec_res: i64 = cks.decrypt_signed(&ct_res); + // We mimic wrapping_shl manually as we use a bigger type + // than the nb_bits we actually simulate in this test + let clear_res = rotate_right_helper(clear, clear_shift, nb_bits); + assert_eq!( + clear_res, dec_res, + "Invalid right shift result, for '{clear} >> {clear_shift}', \ + expected: {clear_res}, got: {dec_res}" + ); + + let ct_res2 = executor.execute((&ct, clear_shift as i64)); + assert_eq!(ct_res, ct_res2, "Failed determinism check"); + } + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs index ba04c05b71..b0c3b76d52 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs @@ -9,6 +9,7 @@ pub(crate) mod test_scalar_add; pub(crate) mod test_scalar_bitwise_op; pub(crate) mod test_scalar_comparison; pub(crate) mod test_scalar_mul; +pub(crate) mod test_scalar_rotate; pub(crate) mod test_scalar_shift; pub(crate) mod test_scalar_sub; pub(crate) mod test_shift; @@ -498,10 +499,6 @@ create_parametrized_test!( } ); // left/right rotations -create_parametrized_test!(integer_unchecked_scalar_rotate_right); -create_parametrized_test!(integer_unchecked_scalar_rotate_left); -create_parametrized_test!(integer_default_scalar_rotate_right); -create_parametrized_test!(integer_default_scalar_rotate_left); create_parametrized_test!(integer_default_scalar_div_rem); create_parametrized_test!(integer_smart_if_then_else); create_parametrized_test!(integer_default_if_then_else); @@ -711,22 +708,6 @@ where // Unchecked Scalar Tests //============================================================================= -fn integer_unchecked_scalar_rotate_right

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_rotate_right_parallelized); - unchecked_scalar_rotate_right_test(param, executor); -} - -fn integer_unchecked_scalar_rotate_left

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_rotate_left_parallelized); - unchecked_scalar_rotate_left_test(param, executor); -} - //============================================================================= // Smart Tests //============================================================================= @@ -892,22 +873,6 @@ where default_checked_ilog2_test(param, executor); } -fn integer_default_scalar_rotate_right

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized); - default_scalar_rotate_right_test(param, executor); -} - -fn integer_default_scalar_rotate_left

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized); - default_scalar_rotate_left_test(param, executor); -} - fn integer_default_scalar_div_rem

(param: P) where P: Into, diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_rotate.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_rotate.rs new file mode 100644 index 0000000000..fbbdbb169f --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_rotate.rs @@ -0,0 +1,323 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{ + FunctionExecutor, NB_CTXT, NB_TESTS, +}; +use crate::integer::server_key::radix_parallel::tests_unsigned::{ + rotate_left_helper, rotate_right_helper, CpuFunctionExecutor, +}; +use crate::integer::tests::create_parametrized_test; +use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey}; +#[cfg(tarpaulin)] +use crate::shortint::parameters::coverage_parameters::*; +use crate::shortint::parameters::*; +use rand::Rng; +use std::sync::Arc; + +create_parametrized_test!(integer_unchecked_scalar_rotate_left); +create_parametrized_test!(integer_default_scalar_rotate_left); +create_parametrized_test!(integer_unchecked_scalar_rotate_right); +create_parametrized_test!(integer_default_scalar_rotate_right); + +fn integer_default_scalar_rotate_left

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized); + default_scalar_rotate_left_test(param, executor); +} + +fn integer_unchecked_scalar_rotate_left

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_rotate_left_parallelized); + unchecked_scalar_rotate_left_test(param, executor); +} + +fn integer_default_scalar_rotate_right

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized); + default_scalar_rotate_right_test(param, executor); +} + +fn integer_unchecked_scalar_rotate_right

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_rotate_right_parallelized); + unchecked_scalar_rotate_right_test(param, executor); +} + +pub(crate) fn unchecked_scalar_rotate_left_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + let nb_bits = modulus.ilog2(); + let bits_per_block = cks.parameters().message_modulus().0.ilog2(); + + executor.setup(&cks, sks); + + for _ in 0..(NB_TESTS / 3).max(1) { + let clear = rng.gen::() % modulus; + let scalar = rng.gen::(); + + let ct = cks.encrypt(clear); + + // Force case where n is multiple of block size + { + let scalar = scalar - (scalar % bits_per_block); + let encrypted_result = executor.execute((&ct, scalar as u64)); + assert!(encrypted_result.block_carries_are_empty()); + let decrypted_result: u64 = cks.decrypt(&encrypted_result); + let expected = rotate_left_helper(clear, scalar, nb_bits); + assert_eq!(expected, decrypted_result); + } + + // Force case where n is not multiple of block size + { + let rest = scalar % bits_per_block; + let scalar = if rest == 0 { + scalar + (rng.gen::() % bits_per_block) + } else { + scalar + }; + let encrypted_result = executor.execute((&ct, scalar as u64)); + assert!(encrypted_result.block_carries_are_empty()); + let decrypted_result: u64 = cks.decrypt(&encrypted_result); + let expected = rotate_left_helper(clear, scalar, nb_bits); + assert_eq!(expected, decrypted_result); + } + + // Force case where + // The value is non zero + // we rotate so that at least one non zero bit, cycle/wraps around + { + let value = rng.gen_range(1..=u32::MAX); + let scalar = value.leading_zeros() + rng.gen_range(1..nb_bits); + let encrypted_result = executor.execute((&ct, scalar as u64)); + assert!(encrypted_result.block_carries_are_empty()); + let decrypted_result: u64 = cks.decrypt(&encrypted_result); + let expected = rotate_left_helper(clear, scalar, nb_bits); + assert_eq!(expected, decrypted_result); + } + } +} + +pub(crate) fn unchecked_scalar_rotate_right_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + let nb_bits = modulus.ilog2(); + let bits_per_block = cks.parameters().message_modulus().0.ilog2(); + + executor.setup(&cks, sks); + + for _ in 0..(NB_TESTS / 3).max(1) { + let clear = rng.gen::() % modulus; + let scalar = rng.gen::(); + + let ct = cks.encrypt(clear); + + // Force case where n is multiple of block size + { + let scalar = scalar - (scalar % bits_per_block); + let encrypted_result = executor.execute((&ct, scalar as u64)); + assert!(encrypted_result.block_carries_are_empty()); + let decrypted_result: u64 = cks.decrypt(&encrypted_result); + let expected = rotate_right_helper(clear, scalar, nb_bits); + assert_eq!(expected, decrypted_result); + } + + // Force case where n is not multiple of block size + { + let rest = scalar % bits_per_block; + let scalar = if rest == 0 { + scalar + (rng.gen::() % bits_per_block) + } else { + scalar + }; + let encrypted_result = executor.execute((&ct, scalar as u64)); + assert!(encrypted_result.block_carries_are_empty()); + let decrypted_result: u64 = cks.decrypt(&encrypted_result); + let expected = rotate_right_helper(clear, scalar, nb_bits); + assert_eq!(expected, decrypted_result); + } + + // Force case where + // The value is non zero + // we rotate so that at least one non zero bit, cycle/wraps around + { + let value = rng.gen_range(1..=u32::MAX); + let scalar = value.trailing_zeros() + rng.gen_range(1..nb_bits); + let encrypted_result = executor.execute((&ct, scalar as u64)); + assert!(encrypted_result.block_carries_are_empty()); + let decrypted_result: u64 = cks.decrypt(&encrypted_result); + let expected = rotate_right_helper(clear, scalar, nb_bits); + assert_eq!(expected, decrypted_result); + } + } +} + +pub(crate) fn default_scalar_rotate_right_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + executor.setup(&cks, sks); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + let nb_bits = modulus.ilog2(); + let bits_per_block = cks.parameters().message_modulus().0.ilog2(); + + for _ in 0..(NB_TESTS / 2).max(1) { + let clear = rng.gen::() % modulus; + let scalar = rng.gen::(); + + let ct = cks.encrypt(clear); + + // Force case where n is multiple of block size + { + let scalar = scalar - (scalar % bits_per_block); + let ct_res = executor.execute((&ct, scalar as u64)); + let tmp = executor.execute((&ct, scalar as u64)); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + let dec_res: u64 = cks.decrypt(&ct_res); + let expected = rotate_right_helper(clear, scalar, nb_bits); + assert_eq!(expected, dec_res); + } + + // Force case where n is not multiple of block size + { + let rest = scalar % bits_per_block; + let scalar = if rest == 0 { + scalar + (rng.gen::() % bits_per_block) + } else { + scalar + }; + let ct_res = executor.execute((&ct, scalar as u64)); + let tmp = executor.execute((&ct, scalar as u64)); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + let dec_res: u64 = cks.decrypt(&ct_res); + let expected = rotate_right_helper(clear, scalar, nb_bits); + assert_eq!(expected, dec_res); + } + + // Force case where + // The value is non zero + // we rotate so that at least one non zero bit, cycle/wraps around + { + let value = rng.gen_range(1..=u32::MAX); + let scalar = value.trailing_zeros() + rng.gen_range(1..nb_bits); + let ct_res = executor.execute((&ct, scalar as u64)); + let tmp = executor.execute((&ct, scalar as u64)); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + let dec_res: u64 = cks.decrypt(&ct_res); + let expected = rotate_right_helper(clear, scalar, nb_bits); + assert_eq!(expected, dec_res); + } + } +} + +pub(crate) fn default_scalar_rotate_left_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, +{ + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + let nb_bits = modulus.ilog2(); + let bits_per_block = cks.parameters().message_modulus().0.ilog2(); + + executor.setup(&cks, sks); + + for _ in 0..(NB_TESTS / 3).max(1) { + let clear = rng.gen::() % modulus; + let scalar = rng.gen::(); + + let ct = cks.encrypt(clear); + + // Force case where n is multiple of block size + { + let scalar = scalar - (scalar % bits_per_block); + let ct_res = executor.execute((&ct, scalar as u64)); + let tmp = executor.execute((&ct, scalar as u64)); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + let dec_res: u64 = cks.decrypt(&ct_res); + let expected = rotate_left_helper(clear, scalar, nb_bits); + assert_eq!(expected, dec_res); + } + + // Force case where n is not multiple of block size + { + let rest = scalar % bits_per_block; + let scalar = if rest == 0 { + scalar + (rng.gen::() % bits_per_block) + } else { + scalar + }; + let ct_res = executor.execute((&ct, scalar as u64)); + let tmp = executor.execute((&ct, scalar as u64)); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + let dec_res: u64 = cks.decrypt(&ct_res); + let expected = rotate_left_helper(clear, scalar, nb_bits); + assert_eq!(expected, dec_res); + } + + // Force case where + // The value is non zero + // we rotate so that at least one non zero bit, cycle/wraps around + { + let value = rng.gen_range(1..=u32::MAX); + let scalar = value.leading_zeros() + rng.gen_range(1..nb_bits); + let ct_res = executor.execute((&ct, scalar as u64)); + let tmp = executor.execute((&ct, scalar as u64)); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + let dec_res: u64 = cks.decrypt(&ct_res); + let expected = rotate_left_helper(clear, scalar, nb_bits); + assert_eq!(expected, dec_res); + } + } +}