Skip to content

Commit

Permalink
feat(c_api): quick 'n' dirty C API for some array fn
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed May 16, 2024
1 parent 0e1f24f commit dc3928f
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 2 deletions.
123 changes: 123 additions & 0 deletions tfhe/c_api_tests/test_high_level_array.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#include "tfhe.h"

#include <assert.h>
#include <inttypes.h>
#include <stdio.h>
#include <string.h>

// Encrypts a string in a FheUint array
// No error handling is made, it asserts on all error for demo purposes
FheUint8 **encrypt_str(const char *const str, const size_t str_len, const ClientKey *ck) {
assert(str != NULL && str_len > 0);

FheUint8 **result = malloc(sizeof(*result) * str_len);
assert(result != NULL);

for (size_t i = 0; i < str_len; ++i) {
assert(fhe_uint8_try_encrypt_with_client_key_u8(str[i], ck, &result[i]) == 0);
}
return result;
}

void destroy_fhe_uint8_array(FheUint8 **begin, const size_t len) {
for (size_t i = 0; i < len; ++i) {
fhe_uint8_destroy(begin[i]);
}
free(begin);
}

int main(void) {
int ok = 0;
ConfigBuilder *builder;
Config *config;

config_builder_default(&builder);
config_builder_build(builder, &config);

ClientKey *client_key = NULL;
ServerKey *server_key = NULL;

ok = generate_keys(config, &client_key, &server_key);
assert(ok == 0);

ok = set_server_key(server_key);
assert(ok == 0);

char const *const sentence = "The quick brown fox jumps over the lazy dog";
char const *const pattern_1 = "wn fox ";
char const *const pattern_2 = "tfhe-rs";

size_t sentence_len = strlen(sentence);
size_t pattern_1_len = strlen(pattern_1);
size_t pattern_2_len = strlen(pattern_2);

assert(pattern_1_len == pattern_2_len); // We use this later in the tests

FheUint8 **encrypted_sentence = encrypt_str(sentence, sentence_len, client_key);
FheUint8 **encrypted_pattern_1 = encrypt_str(pattern_1, pattern_1_len, client_key);
FheUint8 **encrypted_pattern_2 = encrypt_str(pattern_2, pattern_2_len, client_key);

// Equality
{
FheBool *result;
bool clear_result;

// This one is trivial as the length are not the same
ok = fhe_uint8_array_eq(encrypted_sentence, sentence_len, encrypted_pattern_1, pattern_1_len,
&result);
assert(ok == 0);
ok = fhe_bool_decrypt(result, client_key, &clear_result);
assert(ok == 0 && clear_result == false);
fhe_bool_destroy(result);

ok = fhe_uint8_array_eq(encrypted_pattern_2, pattern_2_len, encrypted_pattern_1, pattern_1_len,
&result);
assert(ok == 0);
ok = fhe_bool_decrypt(result, client_key, &clear_result);
assert(ok == 0 && clear_result == false);
fhe_bool_destroy(result);

ok = fhe_uint8_array_eq(encrypted_sentence, sentence_len, encrypted_sentence, sentence_len,
&result);
assert(ok == 0);
ok = fhe_bool_decrypt(result, client_key, &clear_result);
assert(ok == 0 && clear_result == true);
fhe_bool_destroy(result);
}

// contains sub slice
{
FheBool *result;
bool clear_result;

// This one is trivial as the length are not the same
ok = fhe_uint8_array_contains_sub_slice(encrypted_sentence, sentence_len, encrypted_pattern_1,
pattern_1_len, &result);
assert(ok == 0);
ok = fhe_bool_decrypt(result, client_key, &clear_result);
assert(ok == 0 && clear_result == true);
fhe_bool_destroy(result);

ok = fhe_uint8_array_contains_sub_slice(encrypted_sentence, sentence_len, encrypted_pattern_2,
pattern_2_len, &result);
assert(ok == 0);
ok = fhe_bool_decrypt(result, client_key, &clear_result);
assert(ok == 0 && clear_result == false);
fhe_bool_destroy(result);

ok = fhe_uint8_array_contains_sub_slice(encrypted_sentence, sentence_len, encrypted_sentence,
sentence_len, &result);
assert(ok == 0);
ok = fhe_bool_decrypt(result, client_key, &clear_result);
assert(ok == 0 && clear_result == true);
fhe_bool_destroy(result);
}

destroy_fhe_uint8_array(encrypted_sentence, sentence_len);
destroy_fhe_uint8_array(encrypted_pattern_1, pattern_1_len);
destroy_fhe_uint8_array(encrypted_pattern_2, pattern_2_len);

client_key_destroy(client_key);
server_key_destroy(server_key);
return 0;
}
54 changes: 54 additions & 0 deletions tfhe/src/c_api/high_level_api/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use crate::c_api::high_level_api::booleans::FheBool;
use crate::c_api::high_level_api::integers::{
FheUint10, FheUint12, FheUint128, FheUint14, FheUint16, FheUint2, FheUint256, FheUint32,
FheUint4, FheUint6, FheUint64, FheUint8,
};

macro_rules! impl_array_fn {
(
name: $name:ident,
inner_func: $inner_func:path,
output_type_name: $output_type_name:ty,
type_name: $($type_name:ty),*
$(,)?
) => {
$( // type_name
::paste::paste! {
#[no_mangle]
pub unsafe extern "C" fn [<$type_name:snake _ $name>](
lhs: *const *mut $type_name,
lhs_len: usize,
rhs: *const *mut $type_name,
rhs_len: usize,
result: *mut *mut $output_type_name,
) -> ::std::os::raw::c_int {
crate::c_api::utils::catch_panic(|| {
let lhs: &[*mut $type_name] = std::slice::from_raw_parts(lhs, lhs_len);
let rhs: &[*mut $type_name] = std::slice::from_raw_parts(rhs, rhs_len);

let cloned_lhs = lhs.iter().map(|e: &*mut $type_name| e.as_ref().unwrap().0.clone()).collect::<Vec<_>>();
let cloned_rhs = rhs.iter().map(|e: &*mut $type_name| e.as_ref().unwrap().0.clone()).collect::<Vec<_>>();

let inner = $inner_func(&cloned_lhs, &cloned_rhs);

*result = Box::into_raw(Box::new($output_type_name(inner)));
})
}
}
)*
};
}

impl_array_fn!(
name: array_eq,
inner_func: crate::high_level_api::array::fhe_uint_array_eq,
output_type_name: FheBool,
type_name: FheUint2, FheUint4, FheUint6, FheUint8, FheUint10, FheUint12, FheUint14, FheUint16, FheUint32, FheUint64, FheUint128, FheUint256,
);

impl_array_fn!(
name: array_contains_sub_slice,
inner_func: crate::high_level_api::array::fhe_uint_array_contains_sub_slice,
output_type_name: FheBool,
type_name: FheUint2, FheUint4, FheUint6, FheUint8, FheUint10, FheUint12, FheUint14, FheUint16, FheUint32, FheUint64, FheUint128, FheUint256,
);
2 changes: 1 addition & 1 deletion tfhe/src/c_api/high_level_api/integers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ macro_rules! create_integer_wrapper_type {
$(,)?
) => {

pub struct $name($crate::high_level_api::$name);
pub struct $name(pub(in $crate::c_api) $crate::high_level_api::$name);

impl_destroy_on_type!($name);

Expand Down
1 change: 1 addition & 0 deletions tfhe/src/c_api/high_level_api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#[macro_use]
mod utils;
mod array;
#[cfg(feature = "boolean")]
pub mod booleans;
pub mod config;
Expand Down
42 changes: 42 additions & 0 deletions tfhe/src/high_level_api/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use crate::high_level_api::global_state::with_cpu_internal_keys;
use crate::high_level_api::integers::FheUintId;
use crate::{FheBool, FheUint};

pub fn fhe_uint_array_eq<Id: FheUintId>(lhs: &[FheUint<Id>], rhs: &[FheUint<Id>]) -> FheBool {
with_cpu_internal_keys(|cpu_keys| {
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();
let tmp_rhs = rhs
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();

let result = cpu_keys
.pbs_key()
.all_eq_slices_parallelized(&tmp_lhs, &tmp_rhs);
FheBool::new(result)
})
}

pub fn fhe_uint_array_contains_sub_slice<Id: FheUintId>(
lhs: &[FheUint<Id>],
pattern: &[FheUint<Id>],
) -> FheBool {
with_cpu_internal_keys(|cpu_keys| {
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();
let tmp_pattern = pattern
.iter()
.map(|fhe_uint| fhe_uint.ciphertext.on_cpu().to_owned())
.collect::<Vec<_>>();

let result = cpu_keys
.pbs_key()
.contains_sub_slice_parallelized(&tmp_lhs, &tmp_pattern);
FheBool::new(result)
})
}
1 change: 1 addition & 0 deletions tfhe/src/high_level_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ mod booleans;
pub mod errors;
mod integers;

pub mod array;
pub(in crate::high_level_api) mod details;
/// The tfhe prelude.
pub mod prelude;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ impl ServerKey {
}

/// Returns a boolean ciphertext encrypting `true` if `lhs` contains `rhs`, `false` otherwise
pub fn contains_sub_slice_parallelized<T>(&self, lhs: &mut [T], rhs: &mut [T]) -> BooleanBlock
pub fn contains_sub_slice_parallelized<T>(&self, lhs: &[T], rhs: &[T]) -> BooleanBlock
where
T: IntegerRadixCiphertext,
{
Expand Down

0 comments on commit dc3928f

Please sign in to comment.