Skip to content

Commit

Permalink
feat(hlapi): move if_then_else/cmux to FheBool
Browse files Browse the repository at this point in the history
- This makes FheBool use integer::BooleanBlock internally.
- It makes comparisons (eq, ne, le, etc) return a FheBool instead of
  FheUint/FheInt.
- It also moves the if_then_else and cmux methods to FheBool.
- Adds casting from FheBool to FheUint/FheInt (but not from
  FheUint/FheInt to FheBool as we expect users to do `a.ne(0)`
  as its matches Rust)

BREAKING CHANGE:
    - Comparisons now return FheBool
    - if_then_else/cmux are now methods of FheBool.
  • Loading branch information
tmontaigu committed Nov 15, 2023
1 parent 20cb064 commit 916bd8a
Show file tree
Hide file tree
Showing 13 changed files with 313 additions and 273 deletions.
30 changes: 15 additions & 15 deletions tfhe/docs/getting_started/operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,17 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let lower_or_equal = a.le(&b);
let equal = a.eq(&b);

let dec_gt: i8 = greater.decrypt(&keys);
let dec_ge: i8 = greater_or_equal.decrypt(&keys);
let dec_lt: i8 = lower.decrypt(&keys);
let dec_le: i8 = lower_or_equal.decrypt(&keys);
let dec_eq: i8 = equal.decrypt(&keys);

assert_eq!(dec_gt, (clear_a > clear_b ) as i8);
assert_eq!(dec_ge, (clear_a >= clear_b) as i8);
assert_eq!(dec_lt, (clear_a < clear_b ) as i8);
assert_eq!(dec_le, (clear_a <= clear_b) as i8);
assert_eq!(dec_eq, (clear_a == clear_b) as i8);
let dec_gt = greater.decrypt(&keys);
let dec_ge = greater_or_equal.decrypt(&keys);
let dec_lt = lower.decrypt(&keys);
let dec_le = lower_or_equal.decrypt(&keys);
let dec_eq = equal.decrypt(&keys);

assert_eq!(dec_gt, clear_a > clear_b);
assert_eq!(dec_ge, clear_a >= clear_b);
assert_eq!(dec_lt, clear_a < clear_b);
assert_eq!(dec_le, clear_a <= clear_b);
assert_eq!(dec_eq, clear_a == clear_b);

Ok(())
}
Expand Down Expand Up @@ -292,11 +292,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

// Clear equivalent computations: 32 > -45
let encrypted_comp = &encrypted_a.gt(&encrypted_b);
let clear_res: i32 = encrypted_comp.decrypt(&client_key);
assert_eq!(clear_res, (clear_a > clear_b) as i32);
let clear_res = encrypted_comp.decrypt(&client_key);
assert_eq!(clear_res, clear_a > clear_b);

// `encrypted_comp` contains the result of the comparison, i.e.,
// a boolean value. This acts as a condition on which the
// `encrypted_comp` is a FheBool, thus it encrypts a boolean value.
// This acts as a condition on which the
// `if_then_else` function can be applied on.
// Clear equivalent computations:
// if 32 > -45 {result = 32} else {result = -45}
Expand Down
6 changes: 3 additions & 3 deletions tfhe/docs/tutorials/ascii_fhe_string.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ use tfhe::FheUint8;
pub const UP_LOW_DISTANCE: u8 = 32;

fn to_lower(c: &FheUint8) -> FheUint8 {
c + (c.gt(64) & c.lt(91)) * UP_LOW_DISTANCE
c + FheUint8::cast_from(c.gt(64) & c.lt(91)) * UP_LOW_DISTANCE
}
```

Expand All @@ -86,11 +86,11 @@ struct FheAsciiString {
}

fn to_upper(c: &FheUint8) -> FheUint8 {
c - (c.gt(96) & c.lt(123)) * UP_LOW_DISTANCE
c - FheUint8::cast_from(c.gt(96) & c.lt(123)) * UP_LOW_DISTANCE
}

fn to_lower(c: &FheUint8) -> FheUint8 {
c + (c.gt(64) & c.lt(91)) * UP_LOW_DISTANCE
c + FheUint8::cast_from(c.gt(64) & c.lt(91)) * UP_LOW_DISTANCE
}

impl FheAsciiString {
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/c_api/high_level_api/booleans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::high_level_api::prelude::*;

use std::ops::{BitAnd, BitOr, BitXor, Not};

pub struct FheBool(crate::high_level_api::FheBool);
pub struct FheBool(pub(in crate::c_api) crate::high_level_api::FheBool);

impl_destroy_on_type!(FheBool);
impl_clone_on_type!(FheBool);
Expand Down
33 changes: 19 additions & 14 deletions tfhe/src/c_api/high_level_api/integers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::ops::{
Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};

use crate::c_api::high_level_api::booleans::FheBool;
use crate::c_api::high_level_api::i128::I128;
use crate::c_api::high_level_api::i256::I256;
use crate::c_api::high_level_api::u128::U128;
Expand Down Expand Up @@ -33,18 +34,26 @@ macro_rules! impl_operations_for_integer_type {
bitand,
bitor,
bitxor,
eq,
ne,
ge,
gt,
le,
lt,
min,
max,
div,
rem,
);

// Handle comparisons separately as they return FheBool
impl_comparison_fn_on_type!(
lhs_type: $name,
rhs_type: $name,
comparison_fn_names: eq, ne, ge, gt, le, lt,
);

// Handle comparisons separately as they return FheBool
impl_scalar_comparison_fn_on_type!(
lhs_type: $name,
clear_type: $clear_scalar_type,
comparison_fn_names: eq, ne, ge, gt, le, lt,
);

// handle shift separately as they require
// rhs to be an unsigned type
impl_binary_fn_on_type!(
Expand Down Expand Up @@ -79,12 +88,6 @@ macro_rules! impl_operations_for_integer_type {
bitand,
bitor,
bitxor,
eq,
ne,
ge,
gt,
le,
lt,
min,
max,
div,
Expand Down Expand Up @@ -165,10 +168,12 @@ macro_rules! impl_operations_for_integer_type {
}
}

// Even though if_then_else/cmux is a method of FheBool, it still takes as
// integers inputs, so its easier to keep the definition here
::paste::paste! {
#[no_mangle]
pub unsafe extern "C" fn [<$name:snake _if_then_else>](
condition_ct: *const $name,
condition_ct: *const FheBool,
then_ct: *const $name,
else_ct: *const $name,
result: *mut *mut $name,
Expand All @@ -186,7 +191,7 @@ macro_rules! impl_operations_for_integer_type {

// map cmux to if_then_else
pub unsafe extern "C" fn [<$name:snake _cmux>](
condition_ct: *const $name,
condition_ct: *const FheBool,
then_ct: *const $name,
else_ct: *const $name,
result: *mut *mut $name,
Expand Down
65 changes: 64 additions & 1 deletion tfhe/src/c_api/high_level_api/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,76 @@ macro_rules! impl_binary_fn_on_type {
// Usual binary fn case, where lhs, rhs and result are all of the same type
($wrapper_type:ty => $($binary_fn_name:ident),* $(,)?) => {
impl_binary_fn_on_type!(
lhs_type: $wrapper_type,
lhs_type: $wrapper_type,
rhs_type: $wrapper_type,
binary_fn_names: $($binary_fn_name),*
);
};
}

// Comparisons returns FheBool so we use a specialized
// macro for them
macro_rules! impl_comparison_fn_on_type {
(
lhs_type: $lhs_type:ty,
rhs_type: $rhs_type:ty,
comparison_fn_names: $($comparison_fn_name:ident),*
$(,)?
) => {
$( // unroll comparison_fn_names
::paste::paste! {
#[no_mangle]
pub unsafe extern "C" fn [<$lhs_type:snake _ $comparison_fn_name>](
lhs: *const $lhs_type,
rhs: *const $rhs_type,
result: *mut *mut $crate::c_api::high_level_api::booleans::FheBool,
) -> ::std::os::raw::c_int {
$crate::c_api::utils::catch_panic(|| {
let lhs = $crate::c_api::utils::get_ref_checked(lhs).unwrap();
let rhs = $crate::c_api::utils::get_ref_checked(rhs).unwrap();

let inner = (&lhs.0).$comparison_fn_name(&rhs.0);

let inner = $crate::c_api::high_level_api::booleans::FheBool(inner);
*result = Box::into_raw(Box::new(inner));
})
}
}
)*
};
}

macro_rules! impl_scalar_comparison_fn_on_type {
(
lhs_type: $lhs_type:ty,
clear_type: $scalar_type:ty,
comparison_fn_names: $($comparison_fn_name:ident),*
$(,)?
) => {
$( // unroll comparison_fn_names
::paste::paste! {
#[no_mangle]
pub unsafe extern "C" fn [<$lhs_type:snake _scalar_ $comparison_fn_name>](
lhs: *const $lhs_type,
rhs: $scalar_type,
result: *mut *mut $crate::c_api::high_level_api::booleans::FheBool,
) -> ::std::os::raw::c_int {
$crate::c_api::utils::catch_panic(|| {
let lhs = $crate::c_api::utils::get_ref_checked(lhs).unwrap();
let rhs = <$scalar_type as $crate::c_api::high_level_api::utils::CApiIntegerType>::to_rust(rhs);


let inner = (&lhs.0).$comparison_fn_name(rhs);

let inner = $crate::c_api::high_level_api::booleans::FheBool(inner);
*result = Box::into_raw(Box::new(inner));
})
}
}
)*
};
}

macro_rules! impl_unary_fn_on_type {
($wrapper_type:ty => $($unary_fn_name:ident),* $(,)?) => {
$(
Expand Down
Loading

0 comments on commit 916bd8a

Please sign in to comment.