Skip to content

Commit

Permalink
feat(gpu): implement vector comparisons gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
guillermo-oyarzun committed Dec 20, 2024
1 parent 70ff0f7 commit b25a448
Show file tree
Hide file tree
Showing 9 changed files with 664 additions and 34 deletions.
7 changes: 4 additions & 3 deletions backends/tfhe-cuda-backend/cuda/src/integer/integer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ void cuda_apply_many_univariate_lut_kb_64(

void scratch_cuda_apply_bivariate_lut_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
int8_t **mem_ptr, void *input_lut, uint32_t lwe_dimension,
int8_t **mem_ptr, void const *input_lut, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t ks_level,
uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t grouping_factor, uint32_t num_radix_blocks,
Expand All @@ -272,8 +272,9 @@ void scratch_cuda_apply_bivariate_lut_kb_64(

scratch_cuda_apply_bivariate_lut_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
(int_radix_lut<uint64_t> **)mem_ptr, static_cast<uint64_t *>(input_lut),
num_radix_blocks, params, allocate_gpu_memory);
(int_radix_lut<uint64_t> **)mem_ptr,
static_cast<const uint64_t *>(input_lut), num_radix_blocks, params,
allocate_gpu_memory);
}

void cuda_apply_bivariate_lut_kb_64(
Expand Down
97 changes: 68 additions & 29 deletions tfhe/src/high_level_api/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@ pub(in crate::high_level_api) mod traits;

use crate::array::traits::TensorSlice;
use crate::high_level_api::array::traits::HasClear;
use crate::high_level_api::global_state::with_cpu_internal_keys;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::integers::FheUintId;
use crate::high_level_api::keys::InternalServerKey;
#[cfg(feature = "gpu")]
use crate::{FheBool, FheId, FheUint};
use std::ops::RangeBounds;
use traits::{ArrayBackend, BackendDataContainer, BackendDataContainerMut};
Expand Down Expand Up @@ -345,40 +349,75 @@ declare_concrete_array_types!(
);

pub fn fhe_uint_array_eq<Id: FheUintId>(lhs: &[FheUint<Id>], rhs: &[FheUint<Id>]) -> FheBool {
with_cpu_internal_keys(|cpu_keys| {
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();
let tmp_rhs = rhs
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();

let result = cpu_keys
.pbs_key()
.all_eq_slices_parallelized(&tmp_lhs, &tmp_rhs);
FheBool::new(result, cpu_keys.tag.clone())
global_state::with_internal_keys(|sks| match sks {
InternalServerKey::Cpu(cpu_key) => {
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();
let tmp_rhs = rhs
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();

let result = cpu_key
.pbs_key()
.all_eq_slices_parallelized(&tmp_lhs, &tmp_rhs);
FheBool::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(gpu_key) => with_thread_local_cuda_streams(|streams| {
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu())
.collect::<Vec<_>>();
let tmp_rhs = rhs
.iter()
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu())
.collect::<Vec<_>>();

let result = gpu_key.key.key.all_eq_slices(&tmp_lhs, &tmp_rhs, streams);
FheBool::new(result, gpu_key.tag.clone())
}),
})
}

pub fn fhe_uint_array_contains_sub_slice<Id: FheUintId>(
lhs: &[FheUint<Id>],
pattern: &[FheUint<Id>],
) -> FheBool {
with_cpu_internal_keys(|cpu_keys| {
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();
let tmp_pattern = pattern
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();

let result = cpu_keys
.pbs_key()
.contains_sub_slice_parallelized(&tmp_lhs, &tmp_pattern);
FheBool::new(result, cpu_keys.tag.clone())
global_state::with_internal_keys(|sks| match sks {
InternalServerKey::Cpu(cpu_key) => {
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();
let tmp_pattern = pattern
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();

let result = cpu_key
.pbs_key()
.contains_sub_slice_parallelized(&tmp_lhs, &tmp_pattern);
FheBool::new(result, cpu_key.tag.clone())
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(gpu_key) => with_thread_local_cuda_streams(|streams| {
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu())
.collect::<Vec<_>>();
let tmp_pattern = pattern
.iter()
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu())
.collect::<Vec<_>>();

let result = gpu_key
.key
.key
.contains_sub_slice(&tmp_lhs, &tmp_pattern, streams);
FheBool::new(result, gpu_key.tag.clone())
}),
})
}
4 changes: 2 additions & 2 deletions tfhe/src/integer/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2713,8 +2713,8 @@ pub unsafe fn apply_many_univariate_lut_kb_async<T: UnsignedInteger, B: Numeric>
pub unsafe fn apply_bivariate_lut_kb_async<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
radix_lwe_output: &mut CudaSliceMut<T>,
radix_lwe_input_1: &CudaSlice<T>,
radix_lwe_input_2: &CudaSlice<T>,
radix_lwe_input_1: &CudaVec<T>,
radix_lwe_input_2: &CudaVec<T>,
input_lut: &[T],
bootstrapping_key: &CudaVec<B>,
keyswitch_key: &CudaVec<T>,
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/integer/gpu/server_key/radix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ mod scalar_shift;
mod scalar_sub;
mod shift;
mod sub;
mod vector_comparisons;
mod vector_find;

#[cfg(all(test, feature = "__long_run_tests"))]
Expand Down
44 changes: 44 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub(crate) mod test_scalar_shift;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_shift;
pub(crate) mod test_sub;
pub(crate) mod test_vector_comparisons;

use crate::core_crypto::gpu::CudaStreams;
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
Expand Down Expand Up @@ -565,3 +566,46 @@ where
)
}
}
impl<'a, F>
FunctionExecutor<(&'a [SignedRadixCiphertext], &'a [SignedRadixCiphertext]), BooleanBlock>
for GpuFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&[CudaSignedRadixCiphertext],
&[CudaSignedRadixCiphertext],
&CudaStreams,
) -> CudaBooleanBlock,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(
&mut self,
input: (&'a [SignedRadixCiphertext], &'a [SignedRadixCiphertext]),
) -> BooleanBlock {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let mut d_ctxs1 = Vec::<CudaSignedRadixCiphertext>::with_capacity(input.0.len());
for ctx in input.0 {
d_ctxs1.push(CudaSignedRadixCiphertext::from_signed_radix_ciphertext(
ctx,
&context.streams,
));
}
let mut d_ctxs2 = Vec::<CudaSignedRadixCiphertext>::with_capacity(input.0.len());
for ctx in input.1 {
d_ctxs2.push(CudaSignedRadixCiphertext::from_signed_radix_ciphertext(
ctx,
&context.streams,
));
}

let d_block = (self.func)(&context.sks, &d_ctxs1, &d_ctxs2, &context.streams);
d_block.to_boolean_block(&context.streams)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parameterized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_signed::test_vector_comparisons::{
default_all_eq_slices_test_case, unchecked_all_eq_slices_test_case,
};
use crate::shortint::parameters::*;

create_gpu_parameterized_test!(integer_signed_unchecked_all_eq_slices_test_case);
create_gpu_parameterized_test!(integer_signed_default_all_eq_slices_test_case);

fn integer_signed_unchecked_all_eq_slices_test_case<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_all_eq_slices);
unchecked_all_eq_slices_test_case(param, executor);
}

fn integer_signed_default_all_eq_slices_test_case<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::all_eq_slices);
default_all_eq_slices_test_case(param, executor);
}
42 changes: 42 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub(crate) mod test_scalar_shift;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_shift;
pub(crate) mod test_sub;
pub(crate) mod test_vector_comparisons;
pub(crate) mod test_vector_find;

use crate::core_crypto::gpu::CudaStreams;
Expand Down Expand Up @@ -865,3 +866,44 @@ where
(res, block)
}
}

impl<'a, F> FunctionExecutor<(&'a [RadixCiphertext], &'a [RadixCiphertext]), BooleanBlock>
for GpuFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&[CudaUnsignedRadixCiphertext],
&[CudaUnsignedRadixCiphertext],
&CudaStreams,
) -> CudaBooleanBlock,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(&mut self, input: (&'a [RadixCiphertext], &'a [RadixCiphertext])) -> BooleanBlock {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let mut d_ctxs1 = Vec::<CudaUnsignedRadixCiphertext>::with_capacity(input.0.len());
for ctx in input.0 {
d_ctxs1.push(CudaUnsignedRadixCiphertext::from_radix_ciphertext(
ctx,
&context.streams,
));
}
let mut d_ctxs2 = Vec::<CudaUnsignedRadixCiphertext>::with_capacity(input.0.len());
for ctx in input.1 {
d_ctxs2.push(CudaUnsignedRadixCiphertext::from_radix_ciphertext(
ctx,
&context.streams,
));
}

let d_block = (self.func)(&context.sks, &d_ctxs1, &d_ctxs2, &context.streams);
d_block.to_boolean_block(&context.streams)
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parameterized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::shortint::parameters::*;

use crate::integer::server_key::radix_parallel::tests_unsigned::test_vector_comparisons::{
default_all_eq_slices_test_case, unchecked_all_eq_slices_test_case,
unchecked_slice_contains_test_case,
};

create_gpu_parameterized_test!(integer_unchecked_all_eq_slices_test_case);
create_gpu_parameterized_test!(integer_default_all_eq_slices_test_case);
create_gpu_parameterized_test!(integer_unchecked_contains_slice_test_case);

fn integer_unchecked_all_eq_slices_test_case<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_all_eq_slices);
unchecked_all_eq_slices_test_case(param, executor);
}

fn integer_default_all_eq_slices_test_case<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::all_eq_slices);
default_all_eq_slices_test_case(param, executor);
}

fn integer_unchecked_contains_slice_test_case<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_contains_sub_slice);
unchecked_slice_contains_test_case(param, executor);
}
Loading

0 comments on commit b25a448

Please sign in to comment.