From 262e773cfdfd736d51717f0f4d89f39be35880f4 Mon Sep 17 00:00:00 2001 From: Guillermo Oyarzun Date: Mon, 16 Dec 2024 19:31:25 +0100 Subject: [PATCH] feat(gpu): implement subarray search --- .../cuda/include/integer/integer.h | 36 + .../cuda/src/integer/comparison.cu | 88 + .../cuda/src/integer/comparison.cuh | 35 +- backends/tfhe-cuda-backend/src/bindings.rs | 86 + tfhe/src/integer/gpu/ciphertext/info.rs | 37 + tfhe/src/integer/gpu/mod.rs | 210 +++ tfhe/src/integer/gpu/server_key/radix/mod.rs | 1 + .../gpu/server_key/radix/scalar_comparison.rs | 160 ++ .../server_key/radix/tests_unsigned/mod.rs | 265 ++- .../radix/tests_unsigned/test_vector_find.rs | 220 +++ .../gpu/server_key/radix/vector_find.rs | 1668 +++++++++++++++++ .../radix_parallel/tests_cases_unsigned.rs | 15 + .../server_key/radix_parallel/vector_find.rs | 4 + 13 files changed, 2822 insertions(+), 3 deletions(-) create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_vector_find.rs create mode 100644 tfhe/src/integer/gpu/server_key/radix/vector_find.rs diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h index 325891b860..062d30755a 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h @@ -440,5 +440,41 @@ void cleanup_cuda_integer_abs_inplace(void *const *streams, uint32_t gpu_count, int8_t **mem_ptr_void); +void scratch_cuda_integer_are_all_comparisons_block_true_kb_64( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, + uint32_t big_lwe_dimension, uint32_t small_lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t num_radix_blocks, + uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type, + bool allocate_gpu_memory); + +void cuda_integer_are_all_comparisons_block_true_kb_64( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + void *lwe_array_out, void const *lwe_array_in, int8_t *mem_ptr, + void *const *bsks, void *const *ksks, uint32_t num_radix_blocks); + +void cleanup_cuda_integer_are_all_comparisons_block_true( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr_void); + +void scratch_cuda_integer_is_at_least_one_comparisons_block_true_kb_64( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, + uint32_t big_lwe_dimension, uint32_t small_lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t num_radix_blocks, + uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type, + bool allocate_gpu_memory); + +void cuda_integer_is_at_least_one_comparisons_block_true_kb_64( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + void *lwe_array_out, void const *lwe_array_in, int8_t *mem_ptr, + void *const *bsks, void *const *ksks, uint32_t num_radix_blocks); + +void cleanup_cuda_integer_is_at_least_one_comparisons_block_true( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr_void); + } // extern C #endif // CUDA_INTEGER_H diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu index 3e5c7fb683..0071eedae7 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu @@ -94,3 +94,91 @@ void cleanup_cuda_integer_comparison(void *const *streams, (int_comparison_buffer *)(*mem_ptr_void); mem_ptr->release((cudaStream_t *)(streams), gpu_indexes, gpu_count); } + +void scratch_cuda_integer_are_all_comparisons_block_true_kb_64( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, + uint32_t big_lwe_dimension, uint32_t small_lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t num_radix_blocks, + uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type, + bool allocate_gpu_memory) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus); + + scratch_cuda_integer_radix_comparison_check_kb( + (cudaStream_t *)(streams), gpu_indexes, gpu_count, + (int_comparison_buffer **)mem_ptr, num_radix_blocks, params, EQ, + false, allocate_gpu_memory); +} + +void cuda_integer_are_all_comparisons_block_true_kb_64( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + void *lwe_array_out, void const *lwe_array_in, int8_t *mem_ptr, + void *const *bsks, void *const *ksks, uint32_t num_radix_blocks) { + + int_comparison_buffer *buffer = + (int_comparison_buffer *)mem_ptr; + + host_integer_are_all_comparisons_block_true_kb( + (cudaStream_t *)(streams), gpu_indexes, gpu_count, + static_cast(lwe_array_out), + static_cast(lwe_array_in), buffer, bsks, + (uint64_t **)(ksks), num_radix_blocks); +} + +void cleanup_cuda_integer_are_all_comparisons_block_true( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr_void) { + + int_comparison_buffer *mem_ptr = + (int_comparison_buffer *)(*mem_ptr_void); + mem_ptr->release((cudaStream_t *)(streams), gpu_indexes, gpu_count); +} + +void scratch_cuda_integer_is_at_least_one_comparisons_block_true_kb_64( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, + uint32_t big_lwe_dimension, uint32_t small_lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t num_radix_blocks, + uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type, + bool allocate_gpu_memory) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus); + + scratch_cuda_integer_radix_comparison_check_kb( + (cudaStream_t *)(streams), gpu_indexes, gpu_count, + (int_comparison_buffer **)mem_ptr, num_radix_blocks, params, EQ, + false, allocate_gpu_memory); +} + +void cuda_integer_is_at_least_one_comparisons_block_true_kb_64( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + void *lwe_array_out, void const *lwe_array_in, int8_t *mem_ptr, + void *const *bsks, void *const *ksks, uint32_t num_radix_blocks) { + + int_comparison_buffer *buffer = + (int_comparison_buffer *)mem_ptr; + + host_integer_is_at_least_one_comparisons_block_true_kb( + (cudaStream_t *)(streams), gpu_indexes, gpu_count, + static_cast(lwe_array_out), + static_cast(lwe_array_in), buffer, bsks, + (uint64_t **)(ksks), num_radix_blocks); +} + +void cleanup_cuda_integer_is_at_least_one_comparisons_block_true( + void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr_void) { + + int_comparison_buffer *mem_ptr = + (int_comparison_buffer *)(*mem_ptr_void); + mem_ptr->release((cudaStream_t *)(streams), gpu_indexes, gpu_count); +} diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh index f53096b5fc..3f82be99d6 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh @@ -58,7 +58,7 @@ __host__ void accumulate_all_blocks(cudaStream_t stream, uint32_t gpu_index, template __host__ void are_all_comparisons_block_true( cudaStream_t const *streams, uint32_t const *gpu_indexes, - uint32_t gpu_count, Torus *lwe_array_out, Torus *lwe_array_in, + uint32_t gpu_count, Torus *lwe_array_out, Torus const *lwe_array_in, int_comparison_buffer *mem_ptr, void *const *bsks, Torus *const *ksks, uint32_t num_radix_blocks) { @@ -167,7 +167,7 @@ __host__ void are_all_comparisons_block_true( template __host__ void is_at_least_one_comparisons_block_true( cudaStream_t const *streams, uint32_t const *gpu_indexes, - uint32_t gpu_count, Torus *lwe_array_out, Torus *lwe_array_in, + uint32_t gpu_count, Torus *lwe_array_out, Torus const *lwe_array_in, int_comparison_buffer *mem_ptr, void *const *bsks, Torus *const *ksks, uint32_t num_radix_blocks) { @@ -626,4 +626,35 @@ __host__ void host_integer_radix_maxmin_kb( mem_ptr->cmux_buffer, bsks, ksks, total_num_radix_blocks); } +template +__host__ void host_integer_are_all_comparisons_block_true_kb( + cudaStream_t const *streams, uint32_t const *gpu_indexes, + uint32_t gpu_count, Torus *lwe_array_out, Torus const *lwe_array_in, + int_comparison_buffer *mem_ptr, void *const *bsks, + Torus *const *ksks, uint32_t num_radix_blocks) { + + auto eq_buffer = mem_ptr->eq_buffer; + + // It returns a block encrypting 1 if all input blocks are 1 + // otherwise the block encrypts 0 + are_all_comparisons_block_true(streams, gpu_indexes, gpu_count, + lwe_array_out, lwe_array_in, mem_ptr, + bsks, ksks, num_radix_blocks); +} + +template +__host__ void host_integer_is_at_least_one_comparisons_block_true_kb( + cudaStream_t const *streams, uint32_t const *gpu_indexes, + uint32_t gpu_count, Torus *lwe_array_out, Torus const *lwe_array_in, + int_comparison_buffer *mem_ptr, void *const *bsks, + Torus *const *ksks, uint32_t num_radix_blocks) { + + auto eq_buffer = mem_ptr->eq_buffer; + + // It returns a block encrypting 1 if all input blocks are 1 + // otherwise the block encrypts 0 + is_at_least_one_comparisons_block_true( + streams, gpu_indexes, gpu_count, lwe_array_out, lwe_array_in, mem_ptr, + bsks, ksks, num_radix_blocks); +} #endif diff --git a/backends/tfhe-cuda-backend/src/bindings.rs b/backends/tfhe-cuda-backend/src/bindings.rs index d6bf96fe14..3484d33bbf 100644 --- a/backends/tfhe-cuda-backend/src/bindings.rs +++ b/backends/tfhe-cuda-backend/src/bindings.rs @@ -1083,6 +1083,92 @@ extern "C" { mem_ptr_void: *mut *mut i8, ); } +extern "C" { + pub fn scratch_cuda_integer_are_all_comparisons_block_true_kb_64( + streams: *const *mut ffi::c_void, + gpu_indexes: *const u32, + gpu_count: u32, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + num_radix_blocks: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + ); +} +extern "C" { + pub fn cuda_integer_are_all_comparisons_block_true_kb_64( + streams: *const *mut ffi::c_void, + gpu_indexes: *const u32, + gpu_count: u32, + lwe_array_out: *mut ffi::c_void, + lwe_array_in: *const ffi::c_void, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + num_radix_blocks: u32, + ); +} +extern "C" { + pub fn cleanup_cuda_integer_are_all_comparisons_block_true( + streams: *const *mut ffi::c_void, + gpu_indexes: *const u32, + gpu_count: u32, + mem_ptr_void: *mut *mut i8, + ); +} +extern "C" { + pub fn scratch_cuda_integer_is_at_least_one_comparisons_block_true_kb_64( + streams: *const *mut ffi::c_void, + gpu_indexes: *const u32, + gpu_count: u32, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + num_radix_blocks: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + ); +} +extern "C" { + pub fn cuda_integer_is_at_least_one_comparisons_block_true_kb_64( + streams: *const *mut ffi::c_void, + gpu_indexes: *const u32, + gpu_count: u32, + lwe_array_out: *mut ffi::c_void, + lwe_array_in: *const ffi::c_void, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + num_radix_blocks: u32, + ); +} +extern "C" { + pub fn cleanup_cuda_integer_is_at_least_one_comparisons_block_true( + streams: *const *mut ffi::c_void, + gpu_indexes: *const u32, + gpu_count: u32, + mem_ptr_void: *mut *mut i8, + ); +} extern "C" { pub fn cuda_keyswitch_lwe_ciphertext_vector_32( stream: *mut ffi::c_void, diff --git a/tfhe/src/integer/gpu/ciphertext/info.rs b/tfhe/src/integer/gpu/ciphertext/info.rs index 260dfcffd5..0a6316a02c 100644 --- a/tfhe/src/integer/gpu/ciphertext/info.rs +++ b/tfhe/src/integer/gpu/ciphertext/info.rs @@ -525,6 +525,43 @@ impl CudaRadixCiphertextInfo { } } + pub(crate) fn after_block_comparisons(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .enumerate() + .map(|(i, block)| CudaBlockInfo { + degree: if i == 0 { + Degree::new(1) + } else { + Degree::new(0) + }, + message_modulus: block.message_modulus, + carry_modulus: block.carry_modulus, + pbs_order: block.pbs_order, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + + pub(crate) fn after_aggregate_one_hot_vector(&self) -> Self { + Self { + blocks: self + .blocks + .iter() + .map(|left| CudaBlockInfo { + degree: Degree::new(left.message_modulus.0 - 1), + message_modulus: left.message_modulus, + carry_modulus: left.carry_modulus, + pbs_order: left.pbs_order, + noise_level: NoiseLevel::NOMINAL, + }) + .collect(), + } + } + pub(crate) fn after_ne(&self) -> Self { Self { blocks: self diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index 6b43204968..fc5f68c122 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -3223,3 +3223,213 @@ pub unsafe fn unchecked_signed_abs_radix_kb_assign_async( + streams: &CudaStreams, + radix_lwe_out: &mut CudaVec, + radix_lwe_in: &CudaVec, + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + big_lwe_dimension: LweDimension, + small_lwe_dimension: LweDimension, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + num_blocks: u32, + pbs_type: PBSType, + grouping_factor: LweBskGroupingFactor, +) { + assert_eq!( + streams.gpu_indexes[0], + radix_lwe_out.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + radix_lwe_in.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + bootstrapping_key.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + keyswitch_key.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + scratch_cuda_integer_is_at_least_one_comparisons_block_true_kb_64( + streams.ptr.as_ptr(), + streams + .gpu_indexes + .iter() + .map(|i| i.0) + .collect::>() + .as_ptr(), + streams.len() as u32, + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + ); + + cuda_integer_is_at_least_one_comparisons_block_true_kb_64( + streams.ptr.as_ptr(), + streams + .gpu_indexes + .iter() + .map(|i| i.0) + .collect::>() + .as_ptr(), + streams.len() as u32, + radix_lwe_out.as_mut_c_ptr(0), + radix_lwe_in.as_c_ptr(0), + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + num_blocks, + ); + + cleanup_cuda_integer_is_at_least_one_comparisons_block_true( + streams.ptr.as_ptr(), + streams + .gpu_indexes + .iter() + .map(|i| i.0) + .collect::>() + .as_ptr(), + streams.len() as u32, + std::ptr::addr_of_mut!(mem_ptr), + ); +} + +#[allow(clippy::too_many_arguments)] +/// # Safety +/// +/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization +/// is required +pub unsafe fn unchecked_are_all_comparisons_block_true_integer_radix_kb_async< + T: UnsignedInteger, + B: Numeric, +>( + streams: &CudaStreams, + radix_lwe_out: &mut CudaVec, + radix_lwe_in: &CudaVec, + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + big_lwe_dimension: LweDimension, + small_lwe_dimension: LweDimension, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + num_blocks: u32, + pbs_type: PBSType, + grouping_factor: LweBskGroupingFactor, +) { + assert_eq!( + streams.gpu_indexes[0], + radix_lwe_out.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + radix_lwe_in.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + bootstrapping_key.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + keyswitch_key.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + scratch_cuda_integer_are_all_comparisons_block_true_kb_64( + streams.ptr.as_ptr(), + streams + .gpu_indexes + .iter() + .map(|i| i.0) + .collect::>() + .as_ptr(), + streams.len() as u32, + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + ); + + cuda_integer_are_all_comparisons_block_true_kb_64( + streams.ptr.as_ptr(), + streams + .gpu_indexes + .iter() + .map(|i| i.0) + .collect::>() + .as_ptr(), + streams.len() as u32, + radix_lwe_out.as_mut_c_ptr(0), + radix_lwe_in.as_c_ptr(0), + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + num_blocks, + ); + + cleanup_cuda_integer_are_all_comparisons_block_true( + streams.ptr.as_ptr(), + streams + .gpu_indexes + .iter() + .map(|i| i.0) + .collect::>() + .as_ptr(), + streams.len() as u32, + std::ptr::addr_of_mut!(mem_ptr), + ); +} diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index 0ea12c791d..9af0c24feb 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -47,6 +47,7 @@ mod scalar_shift; mod scalar_sub; mod shift; mod sub; +mod vector_find; #[cfg(all(test, feature = "__long_run_tests"))] mod tests_long_run; diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs index ff6049cbad..aebd1adacb 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs @@ -8,6 +8,8 @@ use crate::integer::gpu::ciphertext::info::CudaRadixCiphertextInfo; use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaRadixCiphertext}; use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey}; use crate::integer::gpu::{ + unchecked_are_all_comparisons_block_true_integer_radix_kb_async, + unchecked_is_at_least_one_comparisons_block_true_integer_radix_kb_async, unchecked_scalar_comparison_integer_radix_kb_async, ComparisonType, PBSType, }; use crate::shortint::ciphertext::Degree; @@ -398,6 +400,164 @@ impl CudaServerKey { result.as_mut().info = ct.as_ref().info.after_min_max(); result } + /// # Safety + /// + /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must + /// not be dropped until streams is synchronised + pub unsafe fn unchecked_are_all_comparisons_block_true( + &self, + ct: &T, + streams: &CudaStreams, + ) -> CudaBooleanBlock + where + T: CudaIntegerRadixCiphertext, + { + let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count(); + + let ct_res: T = self.create_trivial_radix(0, 1, streams); + let mut boolean_res = CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner()); + + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + unchecked_are_all_comparisons_block_true_integer_radix_kb_async( + streams, + &mut boolean_res.as_mut().ciphertext.d_blocks.0.d_vec, + &ct.as_ref().d_blocks.0.d_vec, + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::Classical, + LweBskGroupingFactor(0), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + unchecked_are_all_comparisons_block_true_integer_radix_kb_async( + streams, + &mut boolean_res.as_mut().ciphertext.d_blocks.0.d_vec, + &ct.as_ref().d_blocks.0.d_vec, + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + ); + } + } + boolean_res.as_mut().ciphertext.info = boolean_res + .as_ref() + .ciphertext + .info + .after_block_comparisons(); + streams.synchronize(); + boolean_res + } + /// # Safety + /// + /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must + /// not be dropped until streams is synchronised + pub unsafe fn unchecked_is_at_least_one_comparisons_block_true( + &self, + ct: &T, + streams: &CudaStreams, + ) -> CudaBooleanBlock + where + T: CudaIntegerRadixCiphertext, + { + let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count(); + + let ct_res: T = self.create_trivial_radix(0, 1, streams); + let mut boolean_res = CudaBooleanBlock::from_cuda_radix_ciphertext(ct_res.into_inner()); + + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + unchecked_is_at_least_one_comparisons_block_true_integer_radix_kb_async( + streams, + &mut boolean_res.as_mut().ciphertext.d_blocks.0.d_vec, + &ct.as_ref().d_blocks.0.d_vec, + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::Classical, + LweBskGroupingFactor(0), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + unchecked_is_at_least_one_comparisons_block_true_integer_radix_kb_async( + streams, + &mut boolean_res.as_mut().ciphertext.d_blocks.0.d_vec, + &ct.as_ref().d_blocks.0.d_vec, + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key + .input_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + ); + } + } + boolean_res.as_mut().ciphertext.info = boolean_res + .as_ref() + .ciphertext + .info + .after_block_comparisons(); + streams.synchronize(); + boolean_res + } /// # Safety /// diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs index d2ed604fbb..eaee727179 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs @@ -17,15 +17,16 @@ pub(crate) mod test_scalar_shift; pub(crate) mod test_scalar_sub; pub(crate) mod test_shift; pub(crate) mod test_sub; +pub(crate) mod test_vector_find; use crate::core_crypto::gpu::CudaStreams; use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; use crate::integer::gpu::CudaServerKey; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::*; +pub use crate::integer::server_key::radix_parallel::MatchValues; use crate::integer::{BooleanBlock, RadixCiphertext, RadixClientKey, ServerKey, U256}; use std::sync::Arc; - // Macro to generate tests for all parameter sets macro_rules! create_gpu_parameterized_test{ ($name:ident { $($param:ident),* $(,)? }) => { @@ -602,3 +603,265 @@ where d_res.to_radix_ciphertext(&context.streams) } } + +impl<'a, F> + FunctionExecutor<(&'a RadixCiphertext, &'a MatchValues), (RadixCiphertext, BooleanBlock)> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &CudaUnsignedRadixCiphertext, + &MatchValues, + &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock), +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute( + &mut self, + input: (&'a RadixCiphertext, &'a MatchValues), + ) -> (RadixCiphertext, BooleanBlock) { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt: CudaUnsignedRadixCiphertext = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams); + + let (d_res, d_block) = (self.func)(&context.sks, &d_ctxt, input.1, &context.streams); + + let res = d_res.to_radix_ciphertext(&context.streams); + let block = d_block.to_boolean_block(&context.streams); + (res, block) + } +} + +impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a MatchValues, u64), RadixCiphertext> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &CudaUnsignedRadixCiphertext, + &MatchValues, + u64, + &CudaStreams, + ) -> CudaUnsignedRadixCiphertext, +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute( + &mut self, + input: (&'a RadixCiphertext, &'a MatchValues, u64), + ) -> RadixCiphertext { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt: CudaUnsignedRadixCiphertext = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams); + + let d_res = (self.func)(&context.sks, &d_ctxt, input.1, input.2, &context.streams); + + d_res.to_radix_ciphertext(&context.streams) + } +} + +impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a [u64]), BooleanBlock> + for GpuFunctionExecutor +where + F: Fn(&CudaServerKey, &CudaUnsignedRadixCiphertext, &[u64], &CudaStreams) -> CudaBooleanBlock, +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute(&mut self, input: (&'a RadixCiphertext, &'a [u64])) -> BooleanBlock { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt: CudaUnsignedRadixCiphertext = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams); + + let d_block = (self.func)(&context.sks, &d_ctxt, input.1, &context.streams); + d_block.to_boolean_block(&context.streams) + } +} + +impl<'a, F> FunctionExecutor<(&'a [RadixCiphertext], &'a RadixCiphertext), BooleanBlock> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &[CudaUnsignedRadixCiphertext], + &CudaUnsignedRadixCiphertext, + &CudaStreams, + ) -> CudaBooleanBlock, +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute(&mut self, input: (&'a [RadixCiphertext], &'a RadixCiphertext)) -> BooleanBlock { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let mut d_ctxs = Vec::::with_capacity(input.0.len()); + for ctx in input.0 { + d_ctxs.push(CudaUnsignedRadixCiphertext::from_radix_ciphertext( + ctx, + &context.streams, + )); + } + let d_ctxt2: CudaUnsignedRadixCiphertext = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams); + + let d_block = (self.func)(&context.sks, &d_ctxs, &d_ctxt2, &context.streams); + d_block.to_boolean_block(&context.streams) + } +} + +impl<'a, F> FunctionExecutor<(&'a [RadixCiphertext], u64), BooleanBlock> for GpuFunctionExecutor +where + F: Fn(&CudaServerKey, &[CudaUnsignedRadixCiphertext], u64, &CudaStreams) -> CudaBooleanBlock, +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute(&mut self, input: (&'a [RadixCiphertext], u64)) -> BooleanBlock { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let mut d_ctxs = Vec::::with_capacity(input.0.len()); + for ctx in input.0 { + d_ctxs.push(CudaUnsignedRadixCiphertext::from_radix_ciphertext( + ctx, + &context.streams, + )); + } + + let d_block = (self.func)(&context.sks, &d_ctxs, input.1, &context.streams); + d_block.to_boolean_block(&context.streams) + } +} + +impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a [u64]), (RadixCiphertext, BooleanBlock)> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &CudaUnsignedRadixCiphertext, + &[u64], + &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock), +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute( + &mut self, + input: (&'a RadixCiphertext, &'a [u64]), + ) -> (RadixCiphertext, BooleanBlock) { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt: CudaUnsignedRadixCiphertext = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams); + + let (d_res, d_block) = (self.func)(&context.sks, &d_ctxt, input.1, &context.streams); + let res = d_res.to_radix_ciphertext(&context.streams); + let block = d_block.to_boolean_block(&context.streams); + (res, block) + } +} + +impl<'a, F> + FunctionExecutor<(&'a [RadixCiphertext], &'a RadixCiphertext), (RadixCiphertext, BooleanBlock)> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &[CudaUnsignedRadixCiphertext], + &CudaUnsignedRadixCiphertext, + &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock), +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute( + &mut self, + input: (&'a [RadixCiphertext], &'a RadixCiphertext), + ) -> (RadixCiphertext, BooleanBlock) { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let mut d_ctxs = Vec::::with_capacity(input.0.len()); + for ctx in input.0 { + d_ctxs.push(CudaUnsignedRadixCiphertext::from_radix_ciphertext( + ctx, + &context.streams, + )); + } + let d_ctxt2: CudaUnsignedRadixCiphertext = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams); + + let (d_res, d_block) = (self.func)(&context.sks, &d_ctxs, &d_ctxt2, &context.streams); + let res = d_res.to_radix_ciphertext(&context.streams); + let block = d_block.to_boolean_block(&context.streams); + (res, block) + } +} + +impl<'a, F> FunctionExecutor<(&'a [RadixCiphertext], u64), (RadixCiphertext, BooleanBlock)> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &[CudaUnsignedRadixCiphertext], + u64, + &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock), +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute(&mut self, input: (&'a [RadixCiphertext], u64)) -> (RadixCiphertext, BooleanBlock) { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let mut d_ctxs = Vec::::with_capacity(input.0.len()); + for ctx in input.0 { + d_ctxs.push(CudaUnsignedRadixCiphertext::from_radix_ciphertext( + ctx, + &context.streams, + )); + } + + let (d_res, d_block) = (self.func)(&context.sks, &d_ctxs, input.1, &context.streams); + let res = d_res.to_radix_ciphertext(&context.streams); + let block = d_block.to_boolean_block(&context.streams); + (res, block) + } +} diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_vector_find.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_vector_find.rs new file mode 100644 index 0000000000..ae862a15b7 --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_vector_find.rs @@ -0,0 +1,220 @@ +use crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parameterized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{ + default_contains_clear_test_case, default_contains_test_case, + default_first_index_in_clears_test_case, default_first_index_of_clear_test_case, + default_first_index_of_test_case, default_index_in_clears_test_case, + default_index_of_clear_test_case, default_index_of_test_case, default_is_in_clears_test_case, + default_match_value_or_test_case, default_match_value_test_case, + unchecked_contains_clear_test_case, unchecked_contains_test_case, + unchecked_first_index_in_clears_test_case, unchecked_first_index_of_clear_test_case, + unchecked_first_index_of_test_case, unchecked_index_in_clears_test_case, + unchecked_index_of_clear_test_case, unchecked_index_of_test_case, + unchecked_is_in_clears_test_case, unchecked_match_value_or_test_case, + unchecked_match_value_test_case, +}; + +use crate::shortint::parameters::*; + +create_gpu_parameterized_test!(integer_unchecked_match_value); +create_gpu_parameterized_test!(integer_unchecked_match_value_or); +create_gpu_parameterized_test!(integer_unchecked_contains); +create_gpu_parameterized_test!(integer_unchecked_contains_clear); +create_gpu_parameterized_test!(integer_unchecked_is_in_clears); +create_gpu_parameterized_test!(integer_unchecked_index_in_clears); +create_gpu_parameterized_test!(integer_unchecked_first_index_in_clears); +create_gpu_parameterized_test!(integer_unchecked_index_of); +create_gpu_parameterized_test!(integer_unchecked_index_of_clear); +create_gpu_parameterized_test!(integer_unchecked_first_index_of); +create_gpu_parameterized_test!(integer_unchecked_first_index_of_clear); + +create_gpu_parameterized_test!(integer_default_match_value); +create_gpu_parameterized_test!(integer_default_match_value_or); +create_gpu_parameterized_test!(integer_default_contains); +create_gpu_parameterized_test!(integer_default_contains_clear); +create_gpu_parameterized_test!(integer_default_is_in_clears); +create_gpu_parameterized_test!(integer_default_index_in_clears); +create_gpu_parameterized_test!(integer_default_first_index_in_clears); +create_gpu_parameterized_test!(integer_default_index_of); +create_gpu_parameterized_test!(integer_default_index_of_clear); +create_gpu_parameterized_test!(integer_default_first_index_of); +create_gpu_parameterized_test!(integer_default_first_index_of_clear); + +fn integer_unchecked_match_value

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_match_value); + unchecked_match_value_test_case(param, executor); +} + +fn integer_unchecked_match_value_or

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_match_value_or); + unchecked_match_value_or_test_case(param, executor); +} + +fn integer_unchecked_contains

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_contains); + unchecked_contains_test_case(param, executor); +} + +fn integer_unchecked_contains_clear

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_contains_clear); + unchecked_contains_clear_test_case(param, executor); +} + +fn integer_unchecked_is_in_clears

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_is_in_clears); + unchecked_is_in_clears_test_case(param, executor); +} + +fn integer_unchecked_index_in_clears

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_index_in_clears); + unchecked_index_in_clears_test_case(param, executor); +} + +fn integer_unchecked_first_index_in_clears

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_first_index_in_clears); + unchecked_first_index_in_clears_test_case(param, executor); +} +fn integer_unchecked_index_of

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_index_of); + unchecked_index_of_test_case(param, executor); +} + +fn integer_unchecked_index_of_clear

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_index_of_clear); + unchecked_index_of_clear_test_case(param, executor); +} + +fn integer_unchecked_first_index_of_clear

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_first_index_of_clear); + unchecked_first_index_of_clear_test_case(param, executor); +} + +fn integer_unchecked_first_index_of

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_first_index_of); + unchecked_first_index_of_test_case(param, executor); +} + +// Default tests + +fn integer_default_match_value

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::match_value); + default_match_value_test_case(param, executor); +} + +fn integer_default_match_value_or

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::match_value_or); + default_match_value_or_test_case(param, executor); +} + +fn integer_default_contains

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::contains); + default_contains_test_case(param, executor); +} + +fn integer_default_contains_clear

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::contains_clear); + default_contains_clear_test_case(param, executor); +} + +fn integer_default_is_in_clears

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::is_in_clears); + default_is_in_clears_test_case(param, executor); +} + +fn integer_default_index_in_clears

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::index_in_clears); + default_index_in_clears_test_case(param, executor); +} + +fn integer_default_first_index_in_clears

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::first_index_in_clears); + default_first_index_in_clears_test_case(param, executor); +} + +fn integer_default_index_of

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::index_of); + default_index_of_test_case(param, executor); +} + +fn integer_default_index_of_clear

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::index_of_clear); + default_index_of_clear_test_case(param, executor); +} + +fn integer_default_first_index_of

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::first_index_of); + default_first_index_of_test_case(param, executor); +} + +fn integer_default_first_index_of_clear

(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::first_index_of_clear); + default_first_index_of_clear_test_case(param, executor); +} diff --git a/tfhe/src/integer/gpu/server_key/radix/vector_find.rs b/tfhe/src/integer/gpu/server_key/radix/vector_find.rs new file mode 100644 index 0000000000..7213dfd74f --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/vector_find.rs @@ -0,0 +1,1668 @@ +use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList; +use crate::core_crypto::gpu::CudaStreams; +use crate::core_crypto::prelude::{LweBskGroupingFactor, UnsignedInteger}; +use crate::integer::block_decomposition::{BlockDecomposer, Decomposable, DecomposableInto}; +use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; +use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; +use crate::integer::gpu::server_key::radix::CudaRadixCiphertext; +use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey}; +use crate::integer::gpu::{ + apply_univariate_lut_kb_async, compute_prefix_sum_hillis_steele_async, PBSType, +}; +pub use crate::integer::server_key::radix_parallel::MatchValues; +use crate::prelude::CastInto; +use itertools::Itertools; +use rayon::prelude::*; +use std::hash::Hash; +// /// MatchValues for the `match_value_parallelized` family of function + +impl CudaServerKey { + pub fn num_bits_to_represent_unsigned_value(&self, clear: Clear) -> usize + where + Clear: UnsignedInteger, + { + if clear == Clear::MAX { + Clear::BITS + } else { + (clear + Clear::ONE).ceil_ilog2() as usize + } + } + + pub fn convert_selectors_to_unsigned_radix_ciphertext( + &self, + selectors: &[CudaBooleanBlock], + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + let packed_list = CudaLweCiphertextList::from_vec_cuda_lwe_ciphertexts_list( + selectors + .iter() + .map(|ciphertext| &ciphertext.0.ciphertext.d_blocks), + streams, + ); + + let blocks_ct: CudaUnsignedRadixCiphertext = CudaUnsignedRadixCiphertext { + ciphertext: CudaRadixCiphertext { + d_blocks: packed_list, + info: selectors[0].0.ciphertext.info.clone(), + }, + }; + blocks_ct + } + + pub fn convert_radixes_vec_to_single_radix_ciphertext( + &self, + selectors: &[CudaRadixCiphertext], + streams: &CudaStreams, + ) -> T + where + T: CudaIntegerRadixCiphertext, + { + let packed_list = CudaLweCiphertextList::from_vec_cuda_lwe_ciphertexts_list( + selectors.iter().map(|ciphertext| &ciphertext.d_blocks), + streams, + ); + + CudaIntegerRadixCiphertext::from(CudaRadixCiphertext { + d_blocks: packed_list, + info: selectors[0].info.clone(), + }) + } + + pub fn convert_unsigned_radix_ciphertext_to_selectors( + &self, + ct: &mut CudaUnsignedRadixCiphertext, + streams: &CudaStreams, + ) -> Vec { + let num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0; + let lwe_size = ct.as_ref().d_blocks.lwe_dimension().to_lwe_size().0; + let mut unpacked_selectors = Vec::::with_capacity(num_blocks); + for i in 0..num_blocks { + let mut radix_ct: CudaUnsignedRadixCiphertext = + self.create_trivial_radix(0, 1, streams); + let slice_in = ct + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(i * lwe_size..(i + 1) * lwe_size, 0) + .unwrap(); + let mut slice_out = radix_ct + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size, 0) + .unwrap(); + unsafe { slice_out.copy_from_gpu_async(&slice_in, streams, 0) }; + let boolean_block = CudaBooleanBlock::from_cuda_radix_ciphertext(radix_ct.into_inner()); + + unpacked_selectors.push(boolean_block); + } + unpacked_selectors + } + + /// Returns how many blocks a radix ciphertext should have to + /// be able to represent the given unsigned integer + pub fn num_blocks_to_represent_unsigned_value(&self, clear: Clear) -> usize + where + Clear: UnsignedInteger, + { + let num_bits_to_represent_output_value = self.num_bits_to_represent_unsigned_value(clear); + let num_bits_in_message = self.message_modulus.0.ilog2(); + num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize) + } + + /// `match` an input value to an output value + /// + /// - Input values are not required to span all possible values that `ct` could hold. + /// + /// - The output radix has a number of blocks that depends on the maximum possible output value + /// from the `MatchValues` + /// + /// Returns a boolean block that encrypts `true` if the input `ct` + /// matched one of the possible inputs + pub fn unchecked_match_value( + &self, + ct: &CudaUnsignedRadixCiphertext, + matches: &MatchValues, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + Clear: UnsignedInteger + DecomposableInto + CastInto, + { + if matches.get_values().is_empty() { + let trivial_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams); + let trivial_bool = CudaBooleanBlock::from_cuda_radix_ciphertext( + trivial_ct.duplicate(streams).into_inner(), + ); + return (trivial_ct, trivial_bool); + } + + let selectors = self.compute_equality_selectors( + ct, + matches + .get_values() + .par_iter() + .map(|(input, _output)| *input), + streams, + ); + + let max_output_value = matches + .get_values() + .iter() + .copied() + .max_by(|(_, outputl), (_, outputr)| outputl.cmp(outputr)) + .expect("luts is not empty at this point") + .1; + + let num_blocks_to_represent_values = + self.num_blocks_to_represent_unsigned_value(max_output_value); + + let blocks_ct = self.convert_selectors_to_unsigned_radix_ciphertext(&selectors, streams); + + let possible_results_to_be_aggregated = self.create_possible_results( + num_blocks_to_represent_values, + selectors.into_par_iter().zip( + matches + .get_values() + .par_iter() + .map(|(_input, output)| *output), + ), + streams, + ); + + if max_output_value == Clear::ZERO { + // If the max output value is zero, it means 0 is the only output possible + // and in the case where none of the input matches the ct, the returned value is 0 + // + // Thus in that case, the returned value is always 0 regardless of ct's value, + // but we still have to see if the input matched something + let zero_ct: CudaUnsignedRadixCiphertext = unsafe { + self.create_trivial_zero_radix_async(num_blocks_to_represent_values, streams) + }; + let out_block = unsafe { + self.unchecked_is_at_least_one_comparisons_block_true(&blocks_ct, streams) + }; + return (zero_ct, out_block); + } + let result = self.aggregate_one_hot_vector(&possible_results_to_be_aggregated, streams); + let out_ct = + unsafe { self.unchecked_is_at_least_one_comparisons_block_true(&blocks_ct, streams) }; + (result, out_ct) + } + + /// `match` an input value to an output value + /// + /// - Input values are not required to span all possible values that `ct` could hold. + /// + /// - The output radix has a number of blocks that depends on the maximum possible output value + /// from the `MatchValues` + /// + /// Returns a boolean block that encrypts `true` if the input `ct` + /// matched one of the possible inputs + pub fn match_value( + &self, + ct: &CudaUnsignedRadixCiphertext, + matches: &MatchValues, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + Clear: UnsignedInteger + DecomposableInto + CastInto, + { + if ct.block_carries_are_empty() { + self.unchecked_match_value(ct, matches, streams) + } else { + let mut clone = ct.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut clone, streams); + } + self.unchecked_match_value(&clone, matches, streams) + } + } + + /// `match` an input value to an output value + /// + /// - Input values are not required to span all possible values that `ct` could hold. + /// + /// - The output radix has a number of blocks that depends on the maximum possible output value + /// from the `MatchValues` + /// + /// + /// If none of the input matched the `ct` then, `ct` will encrypt the + /// value given to `or_value` + pub fn unchecked_match_value_or( + &self, + ct: &CudaUnsignedRadixCiphertext, + matches: &MatchValues, + or_value: Clear, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext + where + Clear: UnsignedInteger + DecomposableInto + CastInto, + { + if matches.get_values().is_empty() { + let ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix( + or_value, + self.num_blocks_to_represent_unsigned_value(or_value), + streams, + ); + return ct; + } + let (result, selected) = self.unchecked_match_value(ct, matches, streams); + + // The result must have as many block to represent either the result of the match or the + // or_value + let num_blocks_to_represent_or_value = + self.num_blocks_to_represent_unsigned_value(or_value); + let num_blocks = (result.as_ref().d_blocks.lwe_ciphertext_count().0) + .max(num_blocks_to_represent_or_value); + let or_value: CudaUnsignedRadixCiphertext = + self.create_trivial_radix(or_value, num_blocks, streams); + + // Note, this could be slightly faster when we have scalar if then_else + self.unchecked_if_then_else(&selected, &result, &or_value, streams) + } + + /// `match` an input value to an output value + /// + /// - Input values are not required to span all possible values that `ct` could hold. + /// + /// - The output radix has a number of blocks that depends on the maximum possible output value + /// from the `MatchValues` + /// + /// If none of the input matched the `ct` then, `ct` will encrypt the + /// value given to `or_value` + pub fn match_value_or( + &self, + ct: &CudaUnsignedRadixCiphertext, + matches: &MatchValues, + or_value: Clear, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext + where + Clear: UnsignedInteger + DecomposableInto + CastInto, + { + if ct.block_carries_are_empty() { + self.unchecked_match_value_or(ct, matches, or_value, streams) + } else { + let mut clone = ct.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut clone, streams); + } + self.unchecked_match_value_or(&clone, matches, or_value, streams) + } + } + + // /// Returns an encrypted `true` if the encrypted `value` is found in the encrypted slice + pub fn unchecked_contains( + &self, + cts: &[T], + value: &T, + streams: &CudaStreams, + ) -> CudaBooleanBlock + where + T: CudaIntegerRadixCiphertext, + { + if cts.is_empty() { + let d_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams); + return CudaBooleanBlock::from_cuda_radix_ciphertext(d_ct.ciphertext); + } + //Here It would be better to launch them in parallel maybe using different streams or + // packed them in a vector + let selectors = cts + .iter() + .map(|ct| self.eq(ct, value, streams)) + .collect::>(); + + let packed_ct = self.convert_selectors_to_unsigned_radix_ciphertext(&selectors, streams); + unsafe { self.unchecked_is_at_least_one_comparisons_block_true(&packed_ct, streams) } + } + + /// Returns an encrypted `true` if the encrypted `value` is found in the encrypted slice + pub fn contains(&self, cts: &[T], value: &T, streams: &CudaStreams) -> CudaBooleanBlock + where + T: CudaIntegerRadixCiphertext, + { + let mut tmp_cts = Vec::::with_capacity(cts.len()); + let mut tmp_value; + + let cts = if cts.iter().any(|ct| !ct.block_carries_are_empty()) { + // Need a way to parallelize this step + for ct in cts.iter() { + let mut temp_ct = unsafe { ct.duplicate_async(streams) }; + if !temp_ct.block_carries_are_empty() { + unsafe { + self.full_propagate_assign_async(&mut temp_ct, streams); + } + } + tmp_cts.push(temp_ct); + } + + &tmp_cts + } else { + cts + }; + + let value = if value.block_carries_are_empty() { + value + } else { + tmp_value = value.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp_value, streams); + } + &tmp_value + }; + + self.unchecked_contains(cts, value, streams) + } + + /// Returns an encrypted `true` if the clear `value` is found in the encrypted slice + pub fn unchecked_contains_clear( + &self, + cts: &[T], + clear: Clear, + streams: &CudaStreams, + ) -> CudaBooleanBlock + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto, + { + if cts.is_empty() { + let trivial_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams); + let trivial_bool = CudaBooleanBlock::from_cuda_radix_ciphertext( + trivial_ct.duplicate(streams).into_inner(), + ); + return trivial_bool; + } + let selectors = cts + .iter() + .map(|ct| self.scalar_eq(ct, clear, streams)) + .collect::>(); + + let packed_ct = self.convert_selectors_to_unsigned_radix_ciphertext(&selectors, streams); + unsafe { self.unchecked_is_at_least_one_comparisons_block_true(&packed_ct, streams) } + } + + /// Returns an encrypted `true` if the clear `value` is found in the encrypted slice + pub fn contains_clear( + &self, + cts: &[T], + clear: Clear, + streams: &CudaStreams, + ) -> CudaBooleanBlock + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto, + { + let mut tmp_cts = Vec::::with_capacity(cts.len()); + let cts = if cts.iter().any(|ct| !ct.block_carries_are_empty()) { + // Need a way to parallelize this step + for ct in cts.iter() { + let mut temp_ct = unsafe { ct.duplicate_async(streams) }; + if !temp_ct.block_carries_are_empty() { + unsafe { + self.full_propagate_assign_async(&mut temp_ct, streams); + } + } + tmp_cts.push(temp_ct); + } + &tmp_cts + } else { + cts + }; + + self.unchecked_contains_clear(cts, clear, streams) + } + + // /// Returns an encrypted `true` if the encrypted `value` is found in the clear slice + pub fn unchecked_is_in_clears( + &self, + ct: &T, + clears: &[Clear], + streams: &CudaStreams, + ) -> CudaBooleanBlock + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto + CastInto, + { + if clears.is_empty() { + let trivial_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams); + let trivial_bool = CudaBooleanBlock::from_cuda_radix_ciphertext( + trivial_ct.duplicate(streams).into_inner(), + ); + return trivial_bool; + } + let selectors = self.compute_equality_selectors(ct, clears.par_iter().copied(), streams); + + let blocks_ct = self.convert_selectors_to_unsigned_radix_ciphertext(&selectors, streams); + unsafe { self.unchecked_is_at_least_one_comparisons_block_true(&blocks_ct, streams) } + } + + /// Returns an encrypted `true` if the encrypted `value` is found in the clear slice + pub fn is_in_clears( + &self, + ct: &T, + clears: &[Clear], + streams: &CudaStreams, + ) -> CudaBooleanBlock + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto + CastInto, + { + let mut tmp_ct; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp_ct = ct.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp_ct, streams); + } + &tmp_ct + }; + self.unchecked_is_in_clears(ct, clears, streams) + } + + /// Returns the encrypted index of the encrypted `value` in the clear slice + /// also returns an encrypted boolean that is `true` if the encrypted value was found. + /// + /// # Notes + /// + /// - clear values in the slice must be unique (otherwise use + /// [Self::unchecked_first_index_in_clears]) + /// - If the encrypted value is not in the clear slice, the returned index is 0 + pub fn unchecked_index_in_clears( + &self, + ct: &T, + clears: &[Clear], + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto + CastInto, + { + if clears.is_empty() { + let trivial_ct2: CudaUnsignedRadixCiphertext = self.create_trivial_radix( + 0, + ct.as_ref().d_blocks.lwe_ciphertext_count().0, + streams, + ); + let trivial_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams); + let trivial_bool = CudaBooleanBlock::from_cuda_radix_ciphertext( + trivial_ct.duplicate(streams).into_inner(), + ); + return (trivial_ct2, trivial_bool); + } + let selectors = self.compute_equality_selectors(ct, clears.par_iter().copied(), streams); + self.compute_final_index_from_selectors(selectors, streams) + } + + /// Returns the encrypted index of the encrypted `value` in the clear slice + /// also returns an encrypted boolean that is `true` if the encrypted value was found. + /// + /// # Notes + /// + /// - clear values in the slice must be unique (otherwise use [Self::index_in_clears]) + /// - If the encrypted value is not in the clear slice, the returned index is 0 + pub fn index_in_clears( + &self, + ct: &T, + clears: &[Clear], + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto + CastInto, + { + let mut tmp_ct; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp_ct = ct.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp_ct, streams); + } + streams.synchronize(); + &tmp_ct + }; + + self.unchecked_index_in_clears(ct, clears, streams) + } + + /// Returns the encrypted index of the _first_ occurrence of encrypted `value` in the clear + /// slice also, it returns an encrypted boolean that is `true` if the encrypted value was + /// found. + /// + /// # Notes + /// + /// - If the encrypted value is not in the clear slice, the returned index is 0 + pub fn unchecked_first_index_in_clears( + &self, + ct: &T, + clears: &[Clear], + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto + CastInto + Hash, + { + if clears.is_empty() { + let trivial_ct2: CudaUnsignedRadixCiphertext = self.create_trivial_radix( + 0, + ct.as_ref().d_blocks.lwe_ciphertext_count().0, + streams, + ); + let trivial_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams); + let trivial_bool = CudaBooleanBlock::from_cuda_radix_ciphertext( + trivial_ct.duplicate(streams).into_inner(), + ); + return (trivial_ct2, trivial_bool); + } + let unique_clears = clears + .iter() + .copied() + .enumerate() + .unique_by(|&(_, value)| value) + .collect::>(); + let selectors = self.compute_equality_selectors( + ct, + unique_clears.par_iter().copied().map(|(_, value)| value), + streams, + ); + + let selectors2 = self.convert_selectors_to_unsigned_radix_ciphertext(&selectors, streams); + let num_blocks_result = + (clears.len().ilog2() + 1).div_ceil(self.message_modulus.0.ilog2()) as usize; + + let possible_values = self.create_possible_results( + num_blocks_result, + selectors + .into_par_iter() + .zip(unique_clears.into_par_iter().map(|(index, _)| index as u64)), + streams, + ); + + let out_ct = self.aggregate_one_hot_vector(&possible_values, streams); + + let block = + unsafe { self.unchecked_is_at_least_one_comparisons_block_true(&selectors2, streams) }; + (out_ct, block) + } + + /// Returns the encrypted index of the _first_ occurrence of encrypted `value` in the clear + /// slice also, it returns an encrypted boolean that is `true` if the encrypted value was + /// found. + /// + /// # Notes + /// + /// - If the encrypted value is not in the clear slice, the returned index is 0 + pub fn first_index_in_clears( + &self, + ct: &T, + clears: &[Clear], + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto + CastInto + Hash, + { + let mut tmp_ct; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp_ct = ct.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp_ct, streams); + } + streams.synchronize(); + &tmp_ct + }; + + self.unchecked_first_index_in_clears(ct, clears, streams) + } + + /// Returns the encrypted index of the of encrypted `value` in the ciphertext slice + /// also, it returns an encrypted boolean that is `true` if the encrypted value was found. + /// + /// # Notes + /// + /// - clear values in the slice must be unique (otherwise use [Self::unchecked_first_index_of]) + /// - If the encrypted value is not in the encrypted slice, the returned index is 0 + pub fn unchecked_index_of( + &self, + cts: &[T], + value: &T, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + { + if cts.is_empty() { + let trivial_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams); + let trivial_bool = CudaBooleanBlock::from_cuda_radix_ciphertext( + trivial_ct.duplicate(streams).into_inner(), + ); + return (trivial_ct, trivial_bool); + } + let selectors = cts + .iter() + .map(|ct| self.eq(ct, value, streams)) + .collect::>(); + + self.compute_final_index_from_selectors(selectors, streams) + } + + /// Returns the encrypted index of the of encrypted `value` in the ciphertext slice + /// also, it returns an encrypted boolean that is `true` if the encrypted value was found. + /// + /// # Notes + /// + /// - clear values in the slice must be unique (otherwise use [Self::first_index_of]) + /// - If the encrypted value is not in the encrypted slice, the returned index is 0 + pub fn index_of( + &self, + cts: &[T], + value: &T, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + { + let mut tmp_cts = Vec::::with_capacity(cts.len()); + let mut tmp_value; + + let cts = if cts.iter().any(|ct| !ct.block_carries_are_empty()) { + // Need a way to parallelize this step + for ct in cts.iter() { + let mut temp_ct = unsafe { ct.duplicate_async(streams) }; + if !temp_ct.block_carries_are_empty() { + unsafe { + self.full_propagate_assign_async(&mut temp_ct, streams); + } + } + tmp_cts.push(temp_ct); + } + + &tmp_cts + } else { + cts + }; + + let value = if value.block_carries_are_empty() { + value + } else { + tmp_value = value.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp_value, streams); + } + &tmp_value + }; + self.unchecked_index_of(cts, value, streams) + } + + /// Returns the encrypted index of the of clear `value` in the ciphertext slice + /// also, it returns an encrypted boolean that is `true` if the encrypted value was found. + /// + /// # Notes + /// + /// - clear values in the slice must be unique (otherwise use + /// [Self::unchecked_first_index_of_clear]) + /// - If the clear value is not in the encrypted slice, the returned index is 0 + pub fn unchecked_index_of_clear( + &self, + cts: &[T], + clear: Clear, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto + CastInto, + { + if cts.is_empty() { + let trivial_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams); + let trivial_bool = CudaBooleanBlock::from_cuda_radix_ciphertext( + trivial_ct.duplicate(streams).into_inner(), + ); + return (trivial_ct, trivial_bool); + } + let selectors = cts + .iter() + .map(|ct| self.scalar_eq(ct, clear, streams)) + .collect::>(); + + self.compute_final_index_from_selectors(selectors, streams) + } + + /// Returns the encrypted index of the of clear `value` in the ciphertext slice + /// also, it returns an encrypted boolean that is `true` if the encrypted value was found. + /// + /// # Notes + /// + /// - clear values in the slice must be unique (otherwise use [Self::first_index_of_clear]) + /// - If the clear value is not in the encrypted slice, the returned index is 0 + pub fn index_of_clear( + &self, + cts: &[T], + clear: Clear, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto + CastInto, + { + let mut tmp_cts = Vec::::with_capacity(cts.len()); + + let cts = if cts.iter().any(|ct| !ct.block_carries_are_empty()) { + // Need a way to parallelize this step + for ct in cts.iter() { + let mut temp_ct = unsafe { ct.duplicate_async(streams) }; + if !temp_ct.block_carries_are_empty() { + unsafe { + self.full_propagate_assign_async(&mut temp_ct, streams); + } + } + tmp_cts.push(temp_ct); + } + + &tmp_cts + } else { + cts + }; + self.unchecked_index_of_clear(cts, clear, streams) + } + + /// Returns the encrypted index of the _first_ occurrence of clear `value` in the ciphertext + /// slice also, it returns an encrypted boolean that is `true` if the encrypted value was + /// found. + /// + /// # Notes + /// + /// - If the clear value is not in the clear slice, the returned index is 0 + pub fn unchecked_first_index_of_clear( + &self, + cts: &[T], + clear: Clear, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto + CastInto, + { + if cts.is_empty() { + let trivial_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams); + let trivial_bool = CudaBooleanBlock::from_cuda_radix_ciphertext( + trivial_ct.duplicate(streams).into_inner(), + ); + return (trivial_ct, trivial_bool); + } + let num_blocks_result = + (cts.len().ilog2() + 1).div_ceil(self.message_modulus.0.ilog2()) as usize; + + let selectors = cts + .iter() + .map(|ct| self.scalar_eq(ct, clear, streams)) + .collect::>(); + + let packed_selectors = + self.convert_selectors_to_unsigned_radix_ciphertext(&selectors, streams); + let mut only_first_selectors = + unsafe { self.only_keep_first_true(packed_selectors, streams) }; + + let unpacked_selectors = + self.convert_unsigned_radix_ciphertext_to_selectors(&mut only_first_selectors, streams); + + let possible_values = self.create_possible_results( + num_blocks_result, + unpacked_selectors + .into_par_iter() + .enumerate() + .map(|(i, v)| (v, i as u64)), + streams, + ); + let out_ct = self.aggregate_one_hot_vector(&possible_values, streams); + + let block = unsafe { + self.unchecked_is_at_least_one_comparisons_block_true(&only_first_selectors, streams) + }; + (out_ct, block) + } + + /// Returns the encrypted index of the _first_ occurrence of clear `value` in the ciphertext + /// slice also, it returns an encrypted boolean that is `true` if the encrypted value was + /// found. + /// + /// # Notes + /// + /// - If the clear value is not in the clear slice, the returned index is 0 + pub fn first_index_of_clear( + &self, + cts: &[T], + clear: Clear, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + Clear: DecomposableInto + CastInto, + { + let mut tmp_cts = Vec::::with_capacity(cts.len()); + + let cts = if cts.iter().any(|ct| !ct.block_carries_are_empty()) { + // Need a way to parallelize this step + for ct in cts.iter() { + let mut temp_ct = unsafe { ct.duplicate_async(streams) }; + if !temp_ct.block_carries_are_empty() { + unsafe { + self.full_propagate_assign_async(&mut temp_ct, streams); + } + } + tmp_cts.push(temp_ct); + } + + &tmp_cts + } else { + cts + }; + self.unchecked_first_index_of_clear(cts, clear, streams) + } + + /// Returns the encrypted index of the _first_ occurrence of encrypted `value` in the ciphertext + /// slice also, it returns an encrypted boolean that is `true` if the encrypted value was found. + /// + /// # Notes + /// + /// - If the encrypted value is not in the clear slice, the returned index is 0 + pub fn unchecked_first_index_of( + &self, + cts: &[T], + value: &T, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + { + if cts.is_empty() { + let trivial_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(0, 1, streams); + + let trivial_bool = CudaBooleanBlock::from_cuda_radix_ciphertext( + trivial_ct.duplicate(streams).into_inner(), + ); + return (trivial_ct, trivial_bool); + } + + let num_blocks_result = + (cts.len().ilog2() + 1).div_ceil(self.message_modulus.0.ilog2()) as usize; + + let selectors = cts + .iter() + .map(|ct| self.eq(ct, value, streams)) + .collect::>(); + + let packed_selectors = + self.convert_selectors_to_unsigned_radix_ciphertext(&selectors, streams); + + let mut only_first_selectors = + unsafe { self.only_keep_first_true(packed_selectors, streams) }; + + let unpacked_selectors = + self.convert_unsigned_radix_ciphertext_to_selectors(&mut only_first_selectors, streams); + + let possible_values = self.create_possible_results( + num_blocks_result, + unpacked_selectors + .into_par_iter() + .enumerate() + .map(|(i, v)| (v, i as u64)), + streams, + ); + let out_ct = self.aggregate_one_hot_vector(&possible_values, streams); + + let block = unsafe { + self.unchecked_is_at_least_one_comparisons_block_true(&only_first_selectors, streams) + }; + (out_ct, block) + } + + /// Returns the encrypted index of the _first_ occurrence of encrypted `value` in the ciphertext + /// slice also, it returns an encrypted boolean that is `true` if the encrypted value was found. + /// + /// # Notes + /// + /// - If the encrypted value is not in the clear slice, the returned index is 0 + pub fn first_index_of( + &self, + cts: &[T], + value: &T, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + { + let mut tmp_cts = Vec::::with_capacity(cts.len()); + let mut tmp_value; + + let cts = if cts.iter().any(|ct| !ct.block_carries_are_empty()) { + // Need a way to parallelize this step + for ct in cts.iter() { + let mut temp_ct = unsafe { ct.duplicate_async(streams) }; + if !temp_ct.block_carries_are_empty() { + unsafe { + self.full_propagate_assign_async(&mut temp_ct, streams); + } + } + tmp_cts.push(temp_ct); + } + + &tmp_cts + } else { + cts + }; + + let value = if value.block_carries_are_empty() { + value + } else { + tmp_value = value.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp_value, streams); + } + &tmp_value + }; + self.unchecked_first_index_of(cts, value, streams) + } + + fn compute_final_index_from_selectors( + &self, + selectors: Vec, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) { + let num_blocks_result = + (selectors.len().ilog2() + 1).div_ceil(self.message_modulus.0.ilog2()) as usize; + + let selectors2 = self.convert_selectors_to_unsigned_radix_ciphertext(&selectors, streams); + let possible_values = self.create_possible_results( + num_blocks_result, + selectors + .into_par_iter() + .enumerate() + .map(|(i, v)| (v, i as u64)), + streams, + ); + let one_hot_vector = self.aggregate_one_hot_vector(&possible_values, streams); + + let block = + unsafe { self.unchecked_is_at_least_one_comparisons_block_true(&selectors2, streams) }; + + (one_hot_vector, block) + } + + /// Computes the vector of selectors from an input iterator of clear values and an encrypted + /// value + /// + /// Given an iterator of clear values, and an encrypted radix ciphertext, + /// this method will return a vector of encrypted boolean values where + /// each value is either 1 if the ct is equal to the corresponding clear in the iterator + /// otherwise it will be 0. + /// On the GPU after applying many luts the result is stored differently than on the CPU. + /// If we have 4 many luts result is stored contiguosly in memory as follows: + /// [result many lut 1][result many lut 2][result many lut 3][result many lut 4] + /// In this case we need to jump between the results of the many luts to build the final result + /// + /// Requires ct to have empty carries + fn compute_equality_selectors( + &self, + ct: &T, + possible_input_values: Iter, + streams: &CudaStreams, + ) -> Vec + where + T: CudaIntegerRadixCiphertext, + Iter: ParallelIterator, + Clear: Decomposable + CastInto, + { + assert!( + ct.block_carries_are_empty(), + "internal error: ciphertext carries must be empty" + ); + assert!( + self.carry_modulus.0 >= self.message_modulus.0, + "This function uses many LUTs in a way that requires to have at least as much carry \ + space as message space ({:?} vs {:?})", + self.carry_modulus, + self.message_modulus + ); + // Contains the LUTs used to compare a block with scalar block values + // in many LUTs format for efficiency + let luts = { + let scalar_block_cmp_fns = (0..self.message_modulus.0) + .map(|msg_value| move |block: u64| u64::from(block == msg_value)) + .collect::>(); + + let fns = scalar_block_cmp_fns + .iter() + .map(|func| func as &dyn Fn(u64) -> u64) + .collect::>(); + + self.generate_many_lookup_table(fns.as_slice()) + }; + + let blocks_cmps = + unsafe { self.apply_many_lookup_table_async(ct.as_ref(), &luts, streams) }; + + let num_bits_in_message = self.message_modulus.0.ilog2(); + let num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0; + let lwe_dimension = ct.as_ref().d_blocks.lwe_dimension(); + let lwe_size = lwe_dimension.to_lwe_size().0; + possible_input_values + .map(|input_value| { + let cmps: Vec = BlockDecomposer::new(input_value, num_bits_in_message) + .take(num_blocks) + .map(|block_value| block_value.cast_into()) + .collect::>(); + + let mut d_ct_res: CudaUnsignedRadixCiphertext = + unsafe { self.create_trivial_zero_radix_async(num_blocks, streams) }; + // Here we jump between the results of the many luts to build the final result + for (block_index, block_value) in cmps.iter().enumerate() { + let mut dest_slice = d_ct_res + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(block_index * lwe_size..lwe_size * (block_index + 1), 0) + .unwrap(); + // block_value gives us the index of the many lut we need to use for each block + let mut copy_ct = blocks_cmps[*block_value].duplicate(streams); + let src_slice = copy_ct + .d_blocks + .0 + .d_vec + .as_mut_slice(block_index * lwe_size..lwe_size * (block_index + 1), 0) + .unwrap(); + unsafe { dest_slice.copy_from_gpu_async(&src_slice, streams, 0) }; + } + unsafe { self.unchecked_are_all_comparisons_block_true(&d_ct_res, streams) } + }) + .collect::>() + } + + /// Creates a vector of radix ciphertext from an iterator that associates encrypted boolean + /// values to clear values. + /// + /// The elements of the resulting vector are zero if the corresponding BooleanBlock encrypted 0, + /// otherwise it encrypts the associated clear value. + /// + /// This is only really useful if only one of the boolean block is known to be non-zero. + /// + /// `num_blocks`: number of blocks (unpacked) needed to represent the biggest clear value + /// + /// - Resulting radix ciphertexts have their block packed, thus they will have ceil (numb_blocks + /// / 2) elements + fn create_possible_results( + &self, + num_blocks: usize, + possible_outputs: Iter, + streams: &CudaStreams, + ) -> Vec + where + T: CudaIntegerRadixCiphertext, + Iter: ParallelIterator, + Clear: Decomposable + CastInto, + { + assert!( + self.carry_modulus.0 >= self.message_modulus.0, + "As this function packs blocks, it requires to have at least as much carry \ + space as message space ({:?} vs {:?})", + self.carry_modulus, + self.message_modulus + ); + // Vector of functions that returns function, that will be used to create LUTs later + let scalar_block_cmp_fns = (0..(self.message_modulus.0 * self.message_modulus.0)) + .map(|packed_block_value| { + move |is_selected: u64| { + if is_selected == 1 { + packed_block_value + } else { + 0 + } + } + }) + .collect::>(); + + // How "many LUTs" we can apply, since we are going to apply luts on boolean values + // (Degree(1), Modulus(2)) + // Equivalent to (2^(msg_bits + carry_bits - 1) + let max_num_many_luts = ((self.message_modulus.0 * self.carry_modulus.0) / 2) as usize; + + let num_bits_in_message = self.message_modulus.0.ilog2(); + let vec_cts = possible_outputs + .map(|(selector, output_value)| { + let decomposed_value = BlockDecomposer::new(output_value, 2 * num_bits_in_message) + .take(num_blocks.div_ceil(2)) + .collect::>(); + + // Since there is a limit in the number of how many lut we can apply in one PBS + // we pre-chunk LUTs according to that amount + let blocks = decomposed_value + .par_chunks(max_num_many_luts) + .flat_map(|chunk_of_packed_value| { + let fns = chunk_of_packed_value + .iter() + .map(|packed_value| { + &(scalar_block_cmp_fns[(*packed_value).cast_into()]) + as &dyn Fn(u64) -> u64 + }) + .collect::>(); + let luts = self.generate_many_lookup_table(fns.as_slice()); + unsafe { + self.apply_many_lookup_table_async( + &selector.0.ciphertext, + &luts, + streams, + ) + } + }) + .collect::>(); + //Ideally in the previous step we would have operated all blocks at once, but since + // we didn't, we have this The result here will be Vec + //To do all of them at once, we need to create an apply many lut vector interface, + // and give a vector of many luts This is not implemented yet, so we + // will just unpack the blocks here They are already in order in the + // Vec but we want to have them in a Vec + blocks + }) + .collect::>(); + + //Brute force way to wrap the blocks in a single CudaIntegerCiphertext + let mut outputs = Vec::::with_capacity(vec_cts.len()); + for ct_vec in vec_cts.iter() { + let ct: T = self.convert_radixes_vec_to_single_radix_ciphertext(ct_vec, streams); + outputs.push(ct); + } + outputs + } + + /// Aggregate/combines a vec of one-hot vector of radix ciphertexts + /// (i.e. at most one of the vector element is non-zero) into single ciphertext + /// containing the non-zero value. + /// + /// The elements in the one hot vector have their block packed. + /// + /// The returned result has non packed blocks + fn aggregate_one_hot_vector(&self, one_hot_vector: &[T], streams: &CudaStreams) -> T + where + T: CudaIntegerRadixCiphertext, + { + // Used to clean the noise + let identity_lut = self.generate_lookup_table(|x| x); + + let total_modulus = (self.message_modulus.0 * self.carry_modulus.0) as usize; + let chunk_size = (total_modulus - 1) / (self.message_modulus.0 as usize - 1); + + let num_chunks = one_hot_vector.len().div_ceil(chunk_size); + let num_ct_blocks = one_hot_vector[0].as_ref().d_blocks.lwe_ciphertext_count().0; + let lwe_size = one_hot_vector[0] + .as_ref() + .d_blocks + .0 + .lwe_dimension + .to_lwe_size() + .0; + let mut aggregated_vector: T = + unsafe { self.create_trivial_zero_radix_async(num_ct_blocks, streams) }; + + //iterate over num_chunks + for chunk_idx in 0..(num_chunks - 1) { + for ct_idx in 0..chunk_size { + let one_hot_idx = chunk_idx * chunk_size + ct_idx; + self.unchecked_add_assign( + &mut aggregated_vector, + &one_hot_vector[one_hot_idx], + streams, + ); + } + let mut temp = unsafe { aggregated_vector.duplicate_async(streams) }; + let mut aggregated_mut_slice = aggregated_vector + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size * num_ct_blocks, 0) + .unwrap(); + + unsafe { + let aggregated_slice = temp + .as_mut() + .d_blocks + .0 + .d_vec + .as_slice(0..lwe_size * num_ct_blocks, 0) + .unwrap(); + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut aggregated_mut_slice, + &aggregated_slice, + identity_lut.acc.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut aggregated_mut_slice, + &aggregated_slice, + identity_lut.acc.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + ); + } + } + } + } + let last_chunk_size = one_hot_vector.len() - (num_chunks - 1) * chunk_size; + for ct_idx in 0..last_chunk_size { + let one_hot_idx = (num_chunks - 1) * chunk_size + ct_idx; + unsafe { + self.unchecked_add_assign_async( + &mut aggregated_vector, + &one_hot_vector[one_hot_idx], + streams, + ); + } + } + + let message_extract_lut = + self.generate_lookup_table(|x| (x % self.message_modulus.0) % self.message_modulus.0); + let carry_extract_lut = self.generate_lookup_table(|x| (x / self.message_modulus.0)); + let mut message_ct: T = + unsafe { self.create_trivial_zero_radix_async(num_ct_blocks, streams) }; + let mut message_mut_slice = message_ct + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size * num_ct_blocks, 0) + .unwrap(); + + let mut carry_ct: T = + unsafe { self.create_trivial_zero_radix_async(num_ct_blocks, streams) }; + + let mut carry_mut_slice = carry_ct + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size * num_ct_blocks, 0) + .unwrap(); + unsafe { + let mut temp = aggregated_vector.duplicate_async(streams); + let aggregated_slice = temp + .as_mut() + .d_blocks + .0 + .d_vec + .as_slice(0..lwe_size * num_ct_blocks, 0) + .unwrap(); + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut carry_mut_slice, + &aggregated_slice, + carry_extract_lut.acc.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut carry_mut_slice, + &aggregated_slice, + carry_extract_lut.acc.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + ); + } + } + } + unsafe { + let mut temp = aggregated_vector.duplicate_async(streams); + let aggregated_slice = temp + .as_mut() + .d_blocks + .0 + .d_vec + .as_slice(0..lwe_size * num_ct_blocks, 0) + .unwrap(); + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut message_mut_slice, + &aggregated_slice, + message_extract_lut.acc.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut message_mut_slice, + &aggregated_slice, + message_extract_lut.acc.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + ); + } + } + } + + let mut output_ct: T = + unsafe { self.create_trivial_zero_radix_async(num_ct_blocks * 2, streams) }; + + // unpacked_blocks + for index in 0..num_ct_blocks { + let mut output_mut_slice1 = output_ct + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(2 * index * lwe_size..(2 * (index) * lwe_size + lwe_size), 0) + .unwrap(); + + let message_mut_slice = message_ct + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(index * lwe_size..(index + 1) * lwe_size, 0) + .unwrap(); + + unsafe { output_mut_slice1.copy_from_gpu_async(&message_mut_slice, streams, 0) }; + } + for index in 0..num_ct_blocks { + let mut output_mut_slice2 = output_ct + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice( + (2 * index * lwe_size + lwe_size)..(2 * (index + 1) * lwe_size), + 0, + ) + .unwrap(); + + let carry_mut_slice = carry_ct + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(index * lwe_size..(index + 1) * lwe_size, 0) + .unwrap(); + + unsafe { output_mut_slice2.copy_from_gpu_async(&carry_mut_slice, streams, 0) }; + } + output_ct.as_mut().info = output_ct.as_ref().info.after_aggregate_one_hot_vector(); + output_ct + } + + /// Only keeps at most one Ciphertext that encrypts 1 + /// + /// Given a Vec of Ciphertexts where each Ciphertext encrypts 0 or 1 + /// This function will return a Vec of Ciphertext where at most one encryption of 1 is present + /// The first encryption of one is kept + unsafe fn only_keep_first_true(&self, values: T, streams: &CudaStreams) -> T + where + T: CudaIntegerRadixCiphertext, + { + let num_ct_blocks = values.as_ref().d_blocks.lwe_ciphertext_count().0; + if num_ct_blocks <= 1 { + return values; + } + const ALREADY_SEEN: u64 = 2; + let lut_fn = self.generate_lookup_table_bivariate(|current, previous| { + if previous == 1 || previous == ALREADY_SEEN { + ALREADY_SEEN + } else { + current + } + }); + + let lwe_size = values.as_ref().d_blocks.0.lwe_dimension.to_lwe_size().0; + let mut first_true: T = self.create_trivial_zero_radix_async(num_ct_blocks, streams); + + let mut clone_ct = values.duplicate_async(streams); + let mut slice_in = clone_ct + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size * num_ct_blocks, 0) + .unwrap(); + { + let mut slice_out = first_true + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size * num_ct_blocks, 0) + .unwrap(); + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + compute_prefix_sum_hillis_steele_async( + streams, + &mut slice_out, + &mut slice_in, + lut_fn.acc.acc.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + 0u32, + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + compute_prefix_sum_hillis_steele_async( + streams, + &mut slice_out, + &mut slice_in, + lut_fn.acc.acc.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + 0u32, + ); + } + } + } + + let lut = self.generate_lookup_table(|x| { + let x = x % self.message_modulus.0; + if x == ALREADY_SEEN { + 0 + } else { + x + } + }); + + let cloned_ct = first_true.duplicate_async(streams); + let slice_in_final = cloned_ct + .as_ref() + .d_blocks + .0 + .d_vec + .as_slice(0..lwe_size * num_ct_blocks, 0) + .unwrap(); + let mut slice_out = first_true + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size * num_ct_blocks, 0) + .unwrap(); + + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut slice_out, + &slice_in_final, + lut.acc.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut slice_out, + &slice_in_final, + lut.acc.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + ); + } + } + streams.synchronize(); + first_true + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index 55d52b80dc..2095584d59 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -50,7 +50,22 @@ pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_ pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_sub::unchecked_sub_test; #[cfg(feature = "gpu")] pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_sum::default_sum_ciphertexts_vec_test; +#[cfg(feature = "gpu")] +pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_vector_find::unchecked_match_value_test_case; +pub(crate) use crate::integer::server_key::radix_parallel::tests_unsigned::test_vector_find::{ + default_contains_clear_test_case, default_contains_test_case, + default_first_index_in_clears_test_case, default_first_index_of_clear_test_case, + default_first_index_of_test_case, default_index_in_clears_test_case, + default_index_of_clear_test_case, default_index_of_test_case, default_is_in_clears_test_case, + default_match_value_or_test_case, default_match_value_test_case, + unchecked_contains_clear_test_case, unchecked_contains_test_case, + unchecked_first_index_in_clears_test_case, unchecked_first_index_of_clear_test_case, + unchecked_first_index_of_test_case, unchecked_index_in_clears_test_case, + unchecked_index_of_clear_test_case, unchecked_index_of_test_case, + unchecked_is_in_clears_test_case, unchecked_match_value_or_test_case, +}; use crate::shortint::server_key::CiphertextNoiseDegree; + //============================================================================= // Unchecked Tests //============================================================================= diff --git a/tfhe/src/integer/server_key/radix_parallel/vector_find.rs b/tfhe/src/integer/server_key/radix_parallel/vector_find.rs index f54f66093f..0b34f7448e 100644 --- a/tfhe/src/integer/server_key/radix_parallel/vector_find.rs +++ b/tfhe/src/integer/server_key/radix_parallel/vector_find.rs @@ -50,6 +50,10 @@ impl MatchValues { let matches = range.map(|input| (input, func(input))).collect(); Self(matches) } + // Public method to access the private field + pub fn get_values(&self) -> &Vec<(Clear, Clear)> { + &self.0 + } } impl ServerKey {