From dc3928f52c9695a4c90a330584eba2969fa15a47 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Tue, 14 May 2024 15:20:50 +0200 Subject: [PATCH] feat(c_api): quick 'n' dirty C API for some array fn --- tfhe/c_api_tests/test_high_level_array.c | 123 ++++++++++++++++++ tfhe/src/c_api/high_level_api/array.rs | 54 ++++++++ tfhe/src/c_api/high_level_api/integers.rs | 2 +- tfhe/src/c_api/high_level_api/mod.rs | 1 + tfhe/src/high_level_api/array.rs | 42 ++++++ tfhe/src/high_level_api/mod.rs | 1 + .../radix_parallel/vector_comparisons.rs | 2 +- 7 files changed, 223 insertions(+), 2 deletions(-) create mode 100644 tfhe/c_api_tests/test_high_level_array.c create mode 100644 tfhe/src/c_api/high_level_api/array.rs create mode 100644 tfhe/src/high_level_api/array.rs diff --git a/tfhe/c_api_tests/test_high_level_array.c b/tfhe/c_api_tests/test_high_level_array.c new file mode 100644 index 0000000000..bce467d4c4 --- /dev/null +++ b/tfhe/c_api_tests/test_high_level_array.c @@ -0,0 +1,123 @@ +#include "tfhe.h" + +#include +#include +#include +#include + +// Encrypts a string in a FheUint array +// No error handling is made, it asserts on all error for demo purposes +FheUint8 **encrypt_str(const char *const str, const size_t str_len, const ClientKey *ck) { + assert(str != NULL && str_len > 0); + + FheUint8 **result = malloc(sizeof(*result) * str_len); + assert(result != NULL); + + for (size_t i = 0; i < str_len; ++i) { + assert(fhe_uint8_try_encrypt_with_client_key_u8(str[i], ck, &result[i]) == 0); + } + return result; +} + +void destroy_fhe_uint8_array(FheUint8 **begin, const size_t len) { + for (size_t i = 0; i < len; ++i) { + fhe_uint8_destroy(begin[i]); + } + free(begin); +} + +int main(void) { + int ok = 0; + ConfigBuilder *builder; + Config *config; + + config_builder_default(&builder); + config_builder_build(builder, &config); + + ClientKey *client_key = NULL; + ServerKey *server_key = NULL; + + ok = generate_keys(config, &client_key, &server_key); + assert(ok == 0); + + ok = set_server_key(server_key); + assert(ok == 0); + + char const *const sentence = "The quick brown fox jumps over the lazy dog"; + char const *const pattern_1 = "wn fox "; + char const *const pattern_2 = "tfhe-rs"; + + size_t sentence_len = strlen(sentence); + size_t pattern_1_len = strlen(pattern_1); + size_t pattern_2_len = strlen(pattern_2); + + assert(pattern_1_len == pattern_2_len); // We use this later in the tests + + FheUint8 **encrypted_sentence = encrypt_str(sentence, sentence_len, client_key); + FheUint8 **encrypted_pattern_1 = encrypt_str(pattern_1, pattern_1_len, client_key); + FheUint8 **encrypted_pattern_2 = encrypt_str(pattern_2, pattern_2_len, client_key); + + // Equality + { + FheBool *result; + bool clear_result; + + // This one is trivial as the length are not the same + ok = fhe_uint8_array_eq(encrypted_sentence, sentence_len, encrypted_pattern_1, pattern_1_len, + &result); + assert(ok == 0); + ok = fhe_bool_decrypt(result, client_key, &clear_result); + assert(ok == 0 && clear_result == false); + fhe_bool_destroy(result); + + ok = fhe_uint8_array_eq(encrypted_pattern_2, pattern_2_len, encrypted_pattern_1, pattern_1_len, + &result); + assert(ok == 0); + ok = fhe_bool_decrypt(result, client_key, &clear_result); + assert(ok == 0 && clear_result == false); + fhe_bool_destroy(result); + + ok = fhe_uint8_array_eq(encrypted_sentence, sentence_len, encrypted_sentence, sentence_len, + &result); + assert(ok == 0); + ok = fhe_bool_decrypt(result, client_key, &clear_result); + assert(ok == 0 && clear_result == true); + fhe_bool_destroy(result); + } + + // contains sub slice + { + FheBool *result; + bool clear_result; + + // This one is trivial as the length are not the same + ok = fhe_uint8_array_contains_sub_slice(encrypted_sentence, sentence_len, encrypted_pattern_1, + pattern_1_len, &result); + assert(ok == 0); + ok = fhe_bool_decrypt(result, client_key, &clear_result); + assert(ok == 0 && clear_result == true); + fhe_bool_destroy(result); + + ok = fhe_uint8_array_contains_sub_slice(encrypted_sentence, sentence_len, encrypted_pattern_2, + pattern_2_len, &result); + assert(ok == 0); + ok = fhe_bool_decrypt(result, client_key, &clear_result); + assert(ok == 0 && clear_result == false); + fhe_bool_destroy(result); + + ok = fhe_uint8_array_contains_sub_slice(encrypted_sentence, sentence_len, encrypted_sentence, + sentence_len, &result); + assert(ok == 0); + ok = fhe_bool_decrypt(result, client_key, &clear_result); + assert(ok == 0 && clear_result == true); + fhe_bool_destroy(result); + } + + destroy_fhe_uint8_array(encrypted_sentence, sentence_len); + destroy_fhe_uint8_array(encrypted_pattern_1, pattern_1_len); + destroy_fhe_uint8_array(encrypted_pattern_2, pattern_2_len); + + client_key_destroy(client_key); + server_key_destroy(server_key); + return 0; +} diff --git a/tfhe/src/c_api/high_level_api/array.rs b/tfhe/src/c_api/high_level_api/array.rs new file mode 100644 index 0000000000..5a66754827 --- /dev/null +++ b/tfhe/src/c_api/high_level_api/array.rs @@ -0,0 +1,54 @@ +use crate::c_api::high_level_api::booleans::FheBool; +use crate::c_api::high_level_api::integers::{ + FheUint10, FheUint12, FheUint128, FheUint14, FheUint16, FheUint2, FheUint256, FheUint32, + FheUint4, FheUint6, FheUint64, FheUint8, +}; + +macro_rules! impl_array_fn { + ( + name: $name:ident, + inner_func: $inner_func:path, + output_type_name: $output_type_name:ty, + type_name: $($type_name:ty),* + $(,)? + ) => { + $( // type_name + ::paste::paste! { + #[no_mangle] + pub unsafe extern "C" fn [<$type_name:snake _ $name>]( + lhs: *const *mut $type_name, + lhs_len: usize, + rhs: *const *mut $type_name, + rhs_len: usize, + result: *mut *mut $output_type_name, + ) -> ::std::os::raw::c_int { + crate::c_api::utils::catch_panic(|| { + let lhs: &[*mut $type_name] = std::slice::from_raw_parts(lhs, lhs_len); + let rhs: &[*mut $type_name] = std::slice::from_raw_parts(rhs, rhs_len); + + let cloned_lhs = lhs.iter().map(|e: &*mut $type_name| e.as_ref().unwrap().0.clone()).collect::>(); + let cloned_rhs = rhs.iter().map(|e: &*mut $type_name| e.as_ref().unwrap().0.clone()).collect::>(); + + let inner = $inner_func(&cloned_lhs, &cloned_rhs); + + *result = Box::into_raw(Box::new($output_type_name(inner))); + }) + } + } + )* + }; +} + +impl_array_fn!( + name: array_eq, + inner_func: crate::high_level_api::array::fhe_uint_array_eq, + output_type_name: FheBool, + type_name: FheUint2, FheUint4, FheUint6, FheUint8, FheUint10, FheUint12, FheUint14, FheUint16, FheUint32, FheUint64, FheUint128, FheUint256, +); + +impl_array_fn!( + name: array_contains_sub_slice, + inner_func: crate::high_level_api::array::fhe_uint_array_contains_sub_slice, + output_type_name: FheBool, + type_name: FheUint2, FheUint4, FheUint6, FheUint8, FheUint10, FheUint12, FheUint14, FheUint16, FheUint32, FheUint64, FheUint128, FheUint256, +); diff --git a/tfhe/src/c_api/high_level_api/integers.rs b/tfhe/src/c_api/high_level_api/integers.rs index 9fb8696e96..c1c906b5c0 100644 --- a/tfhe/src/c_api/high_level_api/integers.rs +++ b/tfhe/src/c_api/high_level_api/integers.rs @@ -250,7 +250,7 @@ macro_rules! create_integer_wrapper_type { $(,)? ) => { - pub struct $name($crate::high_level_api::$name); + pub struct $name(pub(in $crate::c_api) $crate::high_level_api::$name); impl_destroy_on_type!($name); diff --git a/tfhe/src/c_api/high_level_api/mod.rs b/tfhe/src/c_api/high_level_api/mod.rs index 1fd5909b44..40c29188cf 100644 --- a/tfhe/src/c_api/high_level_api/mod.rs +++ b/tfhe/src/c_api/high_level_api/mod.rs @@ -1,5 +1,6 @@ #[macro_use] mod utils; +mod array; #[cfg(feature = "boolean")] pub mod booleans; pub mod config; diff --git a/tfhe/src/high_level_api/array.rs b/tfhe/src/high_level_api/array.rs new file mode 100644 index 0000000000..6bfc125ed9 --- /dev/null +++ b/tfhe/src/high_level_api/array.rs @@ -0,0 +1,42 @@ +use crate::high_level_api::global_state::with_cpu_internal_keys; +use crate::high_level_api::integers::FheUintId; +use crate::{FheBool, FheUint}; + +pub fn fhe_uint_array_eq(lhs: &[FheUint], rhs: &[FheUint]) -> FheBool { + with_cpu_internal_keys(|cpu_keys| { + let tmp_lhs = lhs + .iter() + .map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned()) + .collect::>(); + let tmp_rhs = rhs + .iter() + .map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned()) + .collect::>(); + + let result = cpu_keys + .pbs_key() + .all_eq_slices_parallelized(&tmp_lhs, &tmp_rhs); + FheBool::new(result) + }) +} + +pub fn fhe_uint_array_contains_sub_slice( + lhs: &[FheUint], + pattern: &[FheUint], +) -> FheBool { + with_cpu_internal_keys(|cpu_keys| { + let tmp_lhs = lhs + .iter() + .map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned()) + .collect::>(); + let tmp_pattern = pattern + .iter() + .map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned()) + .collect::>(); + + let result = cpu_keys + .pbs_key() + .contains_sub_slice_parallelized(&tmp_lhs, &tmp_pattern); + FheBool::new(result) + }) +} diff --git a/tfhe/src/high_level_api/mod.rs b/tfhe/src/high_level_api/mod.rs index d5e6ba88ae..3c1e825872 100644 --- a/tfhe/src/high_level_api/mod.rs +++ b/tfhe/src/high_level_api/mod.rs @@ -65,6 +65,7 @@ mod booleans; pub mod errors; mod integers; +pub mod array; pub(in crate::high_level_api) mod details; /// The tfhe prelude. pub mod prelude; diff --git a/tfhe/src/integer/server_key/radix_parallel/vector_comparisons.rs b/tfhe/src/integer/server_key/radix_parallel/vector_comparisons.rs index 44028326c9..afae55e2a0 100644 --- a/tfhe/src/integer/server_key/radix_parallel/vector_comparisons.rs +++ b/tfhe/src/integer/server_key/radix_parallel/vector_comparisons.rs @@ -146,7 +146,7 @@ impl ServerKey { } /// Returns a boolean ciphertext encrypting `true` if `lhs` contains `rhs`, `false` otherwise - pub fn contains_sub_slice_parallelized(&self, lhs: &mut [T], rhs: &mut [T]) -> BooleanBlock + pub fn contains_sub_slice_parallelized(&self, lhs: &[T], rhs: &[T]) -> BooleanBlock where T: IntegerRadixCiphertext, {