diff --git a/tfhe/docs/getting_started/operations.md b/tfhe/docs/getting_started/operations.md index 6ff281be76..9cdf466051 100644 --- a/tfhe/docs/getting_started/operations.md +++ b/tfhe/docs/getting_started/operations.md @@ -204,17 +204,17 @@ fn main() -> Result<(), Box> { let lower_or_equal = a.le(&b); let equal = a.eq(&b); - let dec_gt: i8 = greater.decrypt(&keys); - let dec_ge: i8 = greater_or_equal.decrypt(&keys); - let dec_lt: i8 = lower.decrypt(&keys); - let dec_le: i8 = lower_or_equal.decrypt(&keys); - let dec_eq: i8 = equal.decrypt(&keys); - - assert_eq!(dec_gt, (clear_a > clear_b ) as i8); - assert_eq!(dec_ge, (clear_a >= clear_b) as i8); - assert_eq!(dec_lt, (clear_a < clear_b ) as i8); - assert_eq!(dec_le, (clear_a <= clear_b) as i8); - assert_eq!(dec_eq, (clear_a == clear_b) as i8); + let dec_gt = greater.decrypt(&keys); + let dec_ge = greater_or_equal.decrypt(&keys); + let dec_lt = lower.decrypt(&keys); + let dec_le = lower_or_equal.decrypt(&keys); + let dec_eq = equal.decrypt(&keys); + + assert_eq!(dec_gt, clear_a > clear_b); + assert_eq!(dec_ge, clear_a >= clear_b); + assert_eq!(dec_lt, clear_a < clear_b); + assert_eq!(dec_le, clear_a <= clear_b); + assert_eq!(dec_eq, clear_a == clear_b); Ok(()) } @@ -292,11 +292,11 @@ fn main() -> Result<(), Box> { // Clear equivalent computations: 32 > -45 let encrypted_comp = &encrypted_a.gt(&encrypted_b); - let clear_res: i32 = encrypted_comp.decrypt(&client_key); - assert_eq!(clear_res, (clear_a > clear_b) as i32); + let clear_res = encrypted_comp.decrypt(&client_key); + assert_eq!(clear_res, clear_a > clear_b); - // `encrypted_comp` contains the result of the comparison, i.e., - // a boolean value. This acts as a condition on which the + // `encrypted_comp` is a FheBool, thus it encrypts a boolean value. + // This acts as a condition on which the // `if_then_else` function can be applied on. // Clear equivalent computations: // if 32 > -45 {result = 32} else {result = -45} diff --git a/tfhe/docs/tutorials/ascii_fhe_string.md b/tfhe/docs/tutorials/ascii_fhe_string.md index d1ab0c8e60..b08f5c9423 100644 --- a/tfhe/docs/tutorials/ascii_fhe_string.md +++ b/tfhe/docs/tutorials/ascii_fhe_string.md @@ -69,7 +69,7 @@ use tfhe::FheUint8; pub const UP_LOW_DISTANCE: u8 = 32; fn to_lower(c: &FheUint8) -> FheUint8 { - c + (c.gt(64) & c.lt(91)) * UP_LOW_DISTANCE + c + FheUint8::cast_from(c.gt(64) & c.lt(91)) * UP_LOW_DISTANCE } ``` @@ -86,11 +86,11 @@ struct FheAsciiString { } fn to_upper(c: &FheUint8) -> FheUint8 { - c - (c.gt(96) & c.lt(123)) * UP_LOW_DISTANCE + c - FheUint8::cast_from(c.gt(96) & c.lt(123)) * UP_LOW_DISTANCE } fn to_lower(c: &FheUint8) -> FheUint8 { - c + (c.gt(64) & c.lt(91)) * UP_LOW_DISTANCE + c + FheUint8::cast_from(c.gt(64) & c.lt(91)) * UP_LOW_DISTANCE } impl FheAsciiString { diff --git a/tfhe/src/c_api/high_level_api/booleans.rs b/tfhe/src/c_api/high_level_api/booleans.rs index 76c7ee83ce..94f76e41b2 100644 --- a/tfhe/src/c_api/high_level_api/booleans.rs +++ b/tfhe/src/c_api/high_level_api/booleans.rs @@ -2,7 +2,7 @@ use crate::high_level_api::prelude::*; use std::ops::{BitAnd, BitOr, BitXor, Not}; -pub struct FheBool(crate::high_level_api::FheBool); +pub struct FheBool(pub(in crate::c_api) crate::high_level_api::FheBool); impl_destroy_on_type!(FheBool); impl_clone_on_type!(FheBool); diff --git a/tfhe/src/c_api/high_level_api/integers.rs b/tfhe/src/c_api/high_level_api/integers.rs index 00dc67ff25..c3b68bc389 100644 --- a/tfhe/src/c_api/high_level_api/integers.rs +++ b/tfhe/src/c_api/high_level_api/integers.rs @@ -5,6 +5,7 @@ use std::ops::{ Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }; +use crate::c_api::high_level_api::booleans::FheBool; use crate::c_api::high_level_api::i128::I128; use crate::c_api::high_level_api::i256::I256; use crate::c_api::high_level_api::u128::U128; @@ -33,18 +34,26 @@ macro_rules! impl_operations_for_integer_type { bitand, bitor, bitxor, - eq, - ne, - ge, - gt, - le, - lt, min, max, div, rem, ); + // Handle comparisons separately as they return FheBool + impl_comparison_fn_on_type!( + lhs_type: $name, + rhs_type: $name, + comparison_fn_names: eq, ne, ge, gt, le, lt, + ); + + // Handle comparisons separately as they return FheBool + impl_scalar_comparison_fn_on_type!( + lhs_type: $name, + clear_type: $clear_scalar_type, + comparison_fn_names: eq, ne, ge, gt, le, lt, + ); + // handle shift separately as they require // rhs to be an unsigned type impl_binary_fn_on_type!( @@ -79,12 +88,6 @@ macro_rules! impl_operations_for_integer_type { bitand, bitor, bitxor, - eq, - ne, - ge, - gt, - le, - lt, min, max, div, @@ -165,10 +168,12 @@ macro_rules! impl_operations_for_integer_type { } } + // Even though if_then_else/cmux is a method of FheBool, it still takes as + // integers inputs, so its easier to keep the definition here ::paste::paste! { #[no_mangle] pub unsafe extern "C" fn [<$name:snake _if_then_else>]( - condition_ct: *const $name, + condition_ct: *const FheBool, then_ct: *const $name, else_ct: *const $name, result: *mut *mut $name, @@ -186,7 +191,7 @@ macro_rules! impl_operations_for_integer_type { // map cmux to if_then_else pub unsafe extern "C" fn [<$name:snake _cmux>]( - condition_ct: *const $name, + condition_ct: *const FheBool, then_ct: *const $name, else_ct: *const $name, result: *mut *mut $name, diff --git a/tfhe/src/c_api/high_level_api/utils.rs b/tfhe/src/c_api/high_level_api/utils.rs index 7c3b85aacf..5bd3b8f624 100644 --- a/tfhe/src/c_api/high_level_api/utils.rs +++ b/tfhe/src/c_api/high_level_api/utils.rs @@ -366,13 +366,76 @@ macro_rules! impl_binary_fn_on_type { // Usual binary fn case, where lhs, rhs and result are all of the same type ($wrapper_type:ty => $($binary_fn_name:ident),* $(,)?) => { impl_binary_fn_on_type!( - lhs_type: $wrapper_type, + lhs_type: $wrapper_type, rhs_type: $wrapper_type, binary_fn_names: $($binary_fn_name),* ); }; } +// Comparisons returns FheBool so we use a specialized +// macro for them +macro_rules! impl_comparison_fn_on_type { + ( + lhs_type: $lhs_type:ty, + rhs_type: $rhs_type:ty, + comparison_fn_names: $($comparison_fn_name:ident),* + $(,)? + ) => { + $( // unroll comparison_fn_names + ::paste::paste! { + #[no_mangle] + pub unsafe extern "C" fn [<$lhs_type:snake _ $comparison_fn_name>]( + lhs: *const $lhs_type, + rhs: *const $rhs_type, + result: *mut *mut $crate::c_api::high_level_api::booleans::FheBool, + ) -> ::std::os::raw::c_int { + $crate::c_api::utils::catch_panic(|| { + let lhs = $crate::c_api::utils::get_ref_checked(lhs).unwrap(); + let rhs = $crate::c_api::utils::get_ref_checked(rhs).unwrap(); + + let inner = (&lhs.0).$comparison_fn_name(&rhs.0); + + let inner = $crate::c_api::high_level_api::booleans::FheBool(inner); + *result = Box::into_raw(Box::new(inner)); + }) + } + } + )* + }; +} + +macro_rules! impl_scalar_comparison_fn_on_type { + ( + lhs_type: $lhs_type:ty, + clear_type: $scalar_type:ty, + comparison_fn_names: $($comparison_fn_name:ident),* + $(,)? + ) => { + $( // unroll comparison_fn_names + ::paste::paste! { + #[no_mangle] + pub unsafe extern "C" fn [<$lhs_type:snake _scalar_ $comparison_fn_name>]( + lhs: *const $lhs_type, + rhs: $scalar_type, + result: *mut *mut $crate::c_api::high_level_api::booleans::FheBool, + ) -> ::std::os::raw::c_int { + $crate::c_api::utils::catch_panic(|| { + let lhs = $crate::c_api::utils::get_ref_checked(lhs).unwrap(); + let rhs = <$scalar_type as $crate::c_api::high_level_api::utils::CApiIntegerType>::to_rust(rhs); + + + let inner = (&lhs.0).$comparison_fn_name(rhs); + + let inner = $crate::c_api::high_level_api::booleans::FheBool(inner); + *result = Box::into_raw(Box::new(inner)); + }) + } + } + )* + }; +} + macro_rules! impl_unary_fn_on_type { ($wrapper_type:ty => $($unary_fn_name:ident),* $(,)?) => { $( diff --git a/tfhe/src/high_level_api/booleans/mod.rs b/tfhe/src/high_level_api/booleans/mod.rs index bf55db9c14..2bdb2428e9 100644 --- a/tfhe/src/high_level_api/booleans/mod.rs +++ b/tfhe/src/high_level_api/booleans/mod.rs @@ -3,11 +3,13 @@ use std::ops::{BitAnd, BitOr, BitXor}; use crate::errors::Type; use crate::high_level_api::global_state::WithGlobalKey; +use crate::high_level_api::integers::{GenericInteger, IntegerId}; use crate::high_level_api::internal_traits::TypeIdentifier; use crate::high_level_api::keys::{ClientKey, PublicKey}; use crate::high_level_api::traits::{ FheDecrypt, FheEq, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, }; +use crate::integer::BooleanBlock; use crate::shortint::{Ciphertext, CompressedCiphertext}; use crate::CompressedPublicKey; use serde::{Deserialize, Serialize}; @@ -60,46 +62,80 @@ impl WithGlobalKey for FheBoolId { /// ``` #[derive(Clone, Serialize, Deserialize)] pub struct FheBool { - pub(in crate::high_level_api::booleans) ciphertext: Ciphertext, + pub(in crate::high_level_api) ciphertext: BooleanBlock, id: FheBoolId, } #[derive(Clone, Serialize, Deserialize)] pub struct CompressedFheBool { - pub(in crate::high_level_api::booleans) ciphertext: CompressedCiphertext, + pub(in crate::high_level_api) ciphertext: CompressedCiphertext, } impl FheBool { - pub(in crate::high_level_api::booleans) fn new(ciphertext: Ciphertext) -> Self { + pub(in crate::high_level_api) fn new(ciphertext: BooleanBlock) -> Self { Self { ciphertext, id: FheBoolId, } } + + /// Conditional selection. + /// + /// The output value returned depends on the value of `self`. + /// + /// `self` has to encrypt 0 or 1. + /// + /// - if `self` is true (1), the output will have the value of `ct_then` + /// - if `self` is false (0), the output will have the value of `ct_else` + pub fn if_then_else( + &self, + ct_then: &GenericInteger, + ct_else: &GenericInteger, + ) -> GenericInteger { + let ct_condition = self; + let new_ct = ct_condition.id.with_unwrapped_global(|integer_key| { + integer_key.pbs_key().if_then_else_parallelized( + &ct_condition.ciphertext, + &ct_then.ciphertext, + &ct_else.ciphertext, + ) + }); + + GenericInteger::new(new_ct, Id::default()) + } + + /// Conditional selection. + /// + /// cmux is another name for (if_then_else)[Self::if_then_else] + pub fn cmux( + &self, + ct_then: &GenericInteger, + ct_else: &GenericInteger, + ) -> GenericInteger { + self.if_then_else(ct_then, ct_else) + } } impl FheEq for FheBool where B: Borrow, { - type Output = Self; - fn eq(&self, other: B) -> Self { let ciphertext = self.id.with_unwrapped_global(|key| { key.pbs_key() .key - .equal(&self.ciphertext, &other.borrow().ciphertext) + .equal(self.ciphertext.as_ref(), other.borrow().ciphertext.as_ref()) }); - Self::new(ciphertext) + Self::new(BooleanBlock::new_unchecked(ciphertext)) } fn ne(&self, other: B) -> Self { let ciphertext = self.id.with_unwrapped_global(|key| { key.pbs_key() .key - .not_equal(&self.ciphertext, &other.borrow().ciphertext) + .not_equal(self.ciphertext.as_ref(), other.borrow().ciphertext.as_ref()) }); - Self::new(ciphertext) + Self::new(BooleanBlock::new_unchecked(ciphertext)) } } @@ -112,7 +148,7 @@ impl CompressedFheBool { impl From for FheBool { fn from(value: CompressedFheBool) -> Self { let block: Ciphertext = value.ciphertext.into(); - Self::new(block) + Self::new(BooleanBlock::new_unchecked(block)) } } @@ -131,7 +167,7 @@ impl FheTryEncrypt for FheBool { fn try_encrypt(value: bool, key: &ClientKey) -> Result { let integer_client_key = &key.key.key; - let ciphertext = integer_client_key.encrypt_one_block(u64::from(value)); + let ciphertext = integer_client_key.encrypt_bool(value); Ok(Self::new(ciphertext)) } } @@ -141,7 +177,7 @@ impl FheTryTrivialEncrypt for FheBool { fn try_encrypt_trivial(value: bool) -> Result { let ciphertext = FheBoolId - .with_unwrapped_global(|key| key.pbs_key().key.create_trivial(u64::from(value))); + .with_unwrapped_global(|key| key.pbs_key().create_trivial_boolean_block(value)); Ok(Self::new(ciphertext)) } } @@ -158,7 +194,7 @@ impl FheTryEncrypt for crate::FheBool { fn try_encrypt(value: bool, key: &CompressedPublicKey) -> Result { let key = &key.key; - let ciphertext = key.key.encrypt(u64::from(value)); + let ciphertext = key.encrypt_bool(value); Ok(Self::new(ciphertext)) } } @@ -168,7 +204,7 @@ impl FheTryEncrypt for FheBool { fn try_encrypt(value: bool, key: &PublicKey) -> Result { let key = &key.key; - let ciphertext = key.key.encrypt(u64::from(value)); + let ciphertext = key.encrypt_bool(value); Ok(Self::new(ciphertext)) } } @@ -176,7 +212,7 @@ impl FheTryEncrypt for FheBool { impl FheDecrypt for FheBool { fn decrypt(&self, key: &ClientKey) -> bool { let integer_client_key = &key.key.key; - integer_client_key.decrypt_one_block(&self.ciphertext) != 0 + integer_client_key.decrypt_bool(&self.ciphertext) } } @@ -199,9 +235,9 @@ macro_rules! fhe_bool_impl_operation( fn $trait_method(self, rhs: B) -> Self::Output { let ciphertext = self.id.with_unwrapped_global(|key| { - key.pbs_key().key.$key_method(&self.ciphertext, &rhs.borrow().ciphertext) + key.pbs_key().key.$key_method(self.ciphertext.as_ref(), rhs.borrow().ciphertext.as_ref()) }); - FheBool::new(ciphertext) + FheBool::new(BooleanBlock::new_unchecked(ciphertext)) } } }; @@ -215,10 +251,10 @@ impl ::std::ops::Not for FheBool { type Output = Self; fn not(self) -> Self::Output { - let ciphertext = self - .id - .with_unwrapped_global(|key| key.pbs_key().key.scalar_bitxor(&self.ciphertext, 1)); - Self::new(ciphertext) + let ciphertext = self.id.with_unwrapped_global(|key| { + key.pbs_key().key.scalar_bitxor(self.ciphertext.as_ref(), 1) + }); + Self::new(BooleanBlock::new_unchecked(ciphertext)) } } @@ -226,9 +262,9 @@ impl ::std::ops::Not for &FheBool { type Output = FheBool; fn not(self) -> Self::Output { - let ciphertext = self - .id - .with_unwrapped_global(|key| key.pbs_key().key.scalar_bitxor(&self.ciphertext, 1)); - FheBool::new(ciphertext) + let ciphertext = self.id.with_unwrapped_global(|key| { + key.pbs_key().key.scalar_bitxor(self.ciphertext.as_ref(), 1) + }); + FheBool::new(BooleanBlock::new_unchecked(ciphertext)) } } diff --git a/tfhe/src/high_level_api/integers/mod.rs b/tfhe/src/high_level_api/integers/mod.rs index 7ef83ac29a..20da82fba8 100644 --- a/tfhe/src/high_level_api/integers/mod.rs +++ b/tfhe/src/high_level_api/integers/mod.rs @@ -10,6 +10,9 @@ pub(in crate::high_level_api) use keys::{ IntegerCompressedServerKey, IntegerConfig, IntegerServerKey, }; +pub(in crate::high_level_api) use parameters::IntegerId; +pub(in crate::high_level_api) use types::GenericInteger; + mod client_key; mod keys; mod parameters; diff --git a/tfhe/src/high_level_api/integers/tests_signed.rs b/tfhe/src/high_level_api/integers/tests_signed.rs index d810585bab..59007e145c 100644 --- a/tfhe/src/high_level_api/integers/tests_signed.rs +++ b/tfhe/src/high_level_api/integers/tests_signed.rs @@ -52,86 +52,86 @@ fn test_int32_compare() { // Test comparing encrypted with encrypted { let result = &a.eq(&b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a == clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a == clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.eq(&a); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a == clear_a); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a == clear_a; assert_eq!(decrypted_result, clear_result); let result = &a.ne(&b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a != clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a != clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.ne(&a); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a != clear_a); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a != clear_a; assert_eq!(decrypted_result, clear_result); let result = &a.le(&b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a <= clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a <= clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.lt(&b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a < clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a < clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.ge(&b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a >= clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a >= clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.gt(&b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a > clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a > clear_b; assert_eq!(decrypted_result, clear_result); } // Test comparing encrypted with clear { let result = &a.eq(clear_b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a == clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a == clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.eq(clear_a); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a == clear_a); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a == clear_a; assert_eq!(decrypted_result, clear_result); let result = &a.ne(clear_b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a != clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a != clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.ne(clear_a); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a != clear_a); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a != clear_a; assert_eq!(decrypted_result, clear_result); let result = &a.le(clear_b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a <= clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a <= clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.lt(clear_b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a < clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a < clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.ge(clear_b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a >= clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a >= clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.gt(clear_b); - let decrypted_result: i32 = result.decrypt(&client_key); - let clear_result = i32::from(clear_a > clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a > clear_b; assert_eq!(decrypted_result, clear_result); } } diff --git a/tfhe/src/high_level_api/integers/tests_unsigned.rs b/tfhe/src/high_level_api/integers/tests_unsigned.rs index 83e2dda1e0..a933470dfa 100644 --- a/tfhe/src/high_level_api/integers/tests_unsigned.rs +++ b/tfhe/src/high_level_api/integers/tests_unsigned.rs @@ -49,86 +49,86 @@ fn test_uint8_compare() { // Test comparing encrypted with encrypted { let result = &a.eq(&b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a == clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a == clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.eq(&a); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a == clear_a); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a == clear_a; assert_eq!(decrypted_result, clear_result); let result = &a.ne(&b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a != clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a != clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.ne(&a); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a != clear_a); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a != clear_a; assert_eq!(decrypted_result, clear_result); let result = &a.le(&b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a <= clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a <= clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.lt(&b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a < clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a < clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.ge(&b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a >= clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a >= clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.gt(&b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a > clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a > clear_b; assert_eq!(decrypted_result, clear_result); } // Test comparing encrypted with clear { let result = &a.eq(clear_b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a == clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a == clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.eq(clear_a); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a == clear_a); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a == clear_a; assert_eq!(decrypted_result, clear_result); let result = &a.ne(clear_b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a != clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a != clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.ne(clear_a); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a != clear_a); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a != clear_a; assert_eq!(decrypted_result, clear_result); let result = &a.le(clear_b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a <= clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a <= clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.lt(clear_b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a < clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a < clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.ge(clear_b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a >= clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a >= clear_b; assert_eq!(decrypted_result, clear_result); let result = &a.gt(clear_b); - let decrypted_result: u8 = result.decrypt(&client_key); - let clear_result = u8::from(clear_a > clear_b); + let decrypted_result = result.decrypt(&client_key); + let clear_result = clear_a > clear_b; assert_eq!(decrypted_result, clear_result); } } diff --git a/tfhe/src/high_level_api/integers/types/base.rs b/tfhe/src/high_level_api/integers/types/base.rs index 38041c132e..285b35ec13 100644 --- a/tfhe/src/high_level_api/integers/types/base.rs +++ b/tfhe/src/high_level_api/integers/types/base.rs @@ -5,6 +5,7 @@ use std::ops::{ }; use crate::conformance::ParameterSetConformant; +use crate::core_crypto::prelude::CastFrom; use crate::high_level_api::global_state::WithGlobalKey; use crate::high_level_api::integers::parameters::IntegerId; use crate::high_level_api::integers::IntegerServerKey; @@ -17,12 +18,11 @@ use crate::high_level_api::traits::{ }; use crate::high_level_api::{ClientKey, PublicKey}; use crate::integer::block_decomposition::DecomposableInto; -use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::{IntegerRadixCiphertext, RadixCiphertext}; use crate::integer::parameters::RadixCiphertextConformanceParams; use crate::integer::{IntegerCiphertext, SignedRadixCiphertext, I256, U256}; use crate::named::Named; -use crate::CompactPublicKey; +use crate::{CompactPublicKey, FheBool}; #[derive(Debug)] pub enum GenericIntegerBlockError { @@ -80,7 +80,7 @@ impl std::fmt::Display for GenericIntegerBlockError { #[cfg_attr(all(doc, not(doctest)), doc(cfg(feature = "integer")))] #[derive(Clone, serde::Deserialize, serde::Serialize)] pub struct GenericInteger { - pub(in crate::high_level_api::integers) ciphertext: Id::InnerCiphertext, + pub(in crate::high_level_api) ciphertext: Id::InnerCiphertext, pub(in crate::high_level_api::integers) id: Id, } @@ -102,10 +102,7 @@ impl GenericInteger where Id: IntegerId, { - pub(in crate::high_level_api::integers) fn new( - ciphertext: Id::InnerCiphertext, - id: Id, - ) -> Self { + pub(in crate::high_level_api) fn new(ciphertext: Id::InnerCiphertext, id: Id) -> Self { Self { ciphertext, id } } @@ -181,37 +178,18 @@ where } } -impl GenericInteger +impl CastFrom for GenericInteger where - Id: IntegerId + WithGlobalKey, + Id: IntegerId, { - /// Conditional selection. - /// - /// The output value returned depends on the value of `self`. - /// - /// `self` has to encrypt 0 or 1. - /// - /// - if `self` is true (1), the output will have the value of `ct_then` - /// - if `self` is false (0), the output will have the value of `ct_else` - pub fn if_then_else(&self, ct_then: &Self, ct_else: &Self) -> Self { - let ct_condition = self; - let new_ct = ct_condition.id.with_unwrapped_global(|integer_key| { - integer_key.pbs_key().if_then_else_parallelized( - &BooleanBlock::try_new(&ct_condition.ciphertext) - .expect("if_then_else requires a boolean value"), - &ct_then.ciphertext, - &ct_else.ciphertext, - ) + fn cast_from(input: FheBool) -> Self { + let ciphertext = crate::high_level_api::global_state::with_internal_keys(|keys| { + input + .ciphertext + .into_radix(Id::num_blocks(), keys.integer_key.pbs_key()) }); - Self::new(new_ct, ct_condition.id) - } - - /// Conditional selection. - /// - /// cmux is another name for (if_then_else)[Self::if_then_else] - pub fn cmux(&self, ct_then: &Self, ct_else: &Self) -> Self { - self.if_then_else(ct_then, ct_else) + Self::new(ciphertext, Id::default()) } } @@ -460,26 +438,20 @@ where Id: IntegerId + WithGlobalKey, Self: Clone, { - type Output = Self; - - fn eq(&self, rhs: Self) -> Self::Output { + fn eq(&self, rhs: Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .eq_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.eq_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn ne(&self, rhs: Self) -> Self::Output { + fn ne(&self, rhs: Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .ne_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.ne_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } } @@ -488,26 +460,20 @@ where Id: IntegerId + WithGlobalKey, Self: Clone, { - type Output = Self; - - fn eq(&self, rhs: &Self) -> Self::Output { + fn eq(&self, rhs: &Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .eq_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.eq_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn ne(&self, rhs: &Self) -> Self::Output { + fn ne(&self, rhs: &Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .ne_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.ne_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } } @@ -517,26 +483,20 @@ where Id: IntegerId + WithGlobalKey, Self: Clone, { - type Output = Self; - - fn eq(&self, rhs: Clear) -> Self::Output { + fn eq(&self, rhs: Clear) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .scalar_eq_parallelized(&self.ciphertext, rhs) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.scalar_eq_parallelized(&self.ciphertext, rhs) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn ne(&self, rhs: Clear) -> Self::Output { + fn ne(&self, rhs: Clear) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .scalar_ne_parallelized(&self.ciphertext, rhs) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.scalar_ne_parallelized(&self.ciphertext, rhs) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } } @@ -545,46 +505,36 @@ where Id: IntegerId + WithGlobalKey, Self: Clone, { - type Output = Self; - - fn lt(&self, rhs: Self) -> Self::Output { + fn lt(&self, rhs: Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .lt_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.lt_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn le(&self, rhs: Self) -> Self::Output { + fn le(&self, rhs: Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .le_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.le_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn gt(&self, rhs: Self) -> Self::Output { + fn gt(&self, rhs: Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .gt_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.gt_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn ge(&self, rhs: Self) -> Self::Output { + fn ge(&self, rhs: Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .ge_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.ge_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } } @@ -593,46 +543,36 @@ where Id: IntegerId + WithGlobalKey, Self: Clone, { - type Output = Self; - - fn lt(&self, rhs: &Self) -> Self::Output { + fn lt(&self, rhs: &Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .lt_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.lt_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn le(&self, rhs: &Self) -> Self::Output { + fn le(&self, rhs: &Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .le_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.le_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn gt(&self, rhs: &Self) -> Self::Output { + fn gt(&self, rhs: &Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .gt_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.gt_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn ge(&self, rhs: &Self) -> Self::Output { + fn ge(&self, rhs: &Self) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .ge_parallelized(&self.ciphertext, &rhs.ciphertext) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.ge_parallelized(&self.ciphertext, &rhs.ciphertext) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } } @@ -642,46 +582,36 @@ where Clear: DecomposableInto, Self: Clone, { - type Output = Self; - - fn lt(&self, rhs: Clear) -> Self::Output { + fn lt(&self, rhs: Clear) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .scalar_lt_parallelized(&self.ciphertext, rhs) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.scalar_lt_parallelized(&self.ciphertext, rhs) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn le(&self, rhs: Clear) -> Self::Output { + fn le(&self, rhs: Clear) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .scalar_le_parallelized(&self.ciphertext, rhs) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.scalar_le_parallelized(&self.ciphertext, rhs) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn gt(&self, rhs: Clear) -> Self::Output { + fn gt(&self, rhs: Clear) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .scalar_gt_parallelized(&self.ciphertext, rhs) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.scalar_gt_parallelized(&self.ciphertext, rhs) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } - fn ge(&self, rhs: Clear) -> Self::Output { + fn ge(&self, rhs: Clear) -> FheBool { let inner_result = self.id.with_unwrapped_global(|integer_key| { let pbs_key = integer_key.pbs_key(); - pbs_key - .scalar_ge_parallelized(&self.ciphertext, rhs) - .into_radix(Id::num_blocks(), pbs_key) + pbs_key.scalar_ge_parallelized(&self.ciphertext, rhs) }); - Self::new(inner_result, self.id) + FheBool::new(inner_result) } } diff --git a/tfhe/src/high_level_api/traits.rs b/tfhe/src/high_level_api/traits.rs index 10709dd66a..9b04a4d352 100644 --- a/tfhe/src/high_level_api/traits.rs +++ b/tfhe/src/high_level_api/traits.rs @@ -1,4 +1,5 @@ use crate::high_level_api::ClientKey; +use crate::FheBool; /// Trait used to have a generic way of creating a value of a FHE type /// from a native value. @@ -87,11 +88,9 @@ pub trait FheDecrypt { /// for equality, one cannot use the standard operator `==` but rather, use /// the function directly. pub trait FheEq { - type Output; - - fn eq(&self, other: Rhs) -> Self::Output; + fn eq(&self, other: Rhs) -> FheBool; - fn ne(&self, other: Rhs) -> Self::Output; + fn ne(&self, other: Rhs) -> FheBool; } /// Trait for fully homomorphic comparisons. @@ -103,12 +102,10 @@ pub trait FheEq { /// one cannot use the standard operators (`>`, `<`, etc) and must use /// the functions directly. pub trait FheOrd { - type Output; - - fn lt(&self, other: Rhs) -> Self::Output; - fn le(&self, other: Rhs) -> Self::Output; - fn gt(&self, other: Rhs) -> Self::Output; - fn ge(&self, other: Rhs) -> Self::Output; + fn lt(&self, other: Rhs) -> FheBool; + fn le(&self, other: Rhs) -> FheBool; + fn gt(&self, other: Rhs) -> FheBool; + fn ge(&self, other: Rhs) -> FheBool; } pub trait FheMin { diff --git a/tfhe/src/integer/public_key/compressed.rs b/tfhe/src/integer/public_key/compressed.rs index 9a1c08a499..740f5a5f07 100644 --- a/tfhe/src/integer/public_key/compressed.rs +++ b/tfhe/src/integer/public_key/compressed.rs @@ -2,7 +2,7 @@ use crate::integer::block_decomposition::DecomposableInto; use crate::integer::ciphertext::{CrtCiphertext, RadixCiphertext}; use crate::integer::client_key::ClientKey; use crate::integer::encryption::{encrypt_crt, encrypt_words_radix_impl}; -use crate::integer::SignedRadixCiphertext; +use crate::integer::{BooleanBlock, SignedRadixCiphertext}; use crate::shortint::parameters::MessageModulus; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -46,9 +46,7 @@ impl CompressedPublicKey { { encrypt_crt(&self.key, message, base_vec, encrypt_block) } -} -impl CompressedPublicKey { pub fn parameters(&self) -> crate::shortint::PBSParameters { self.key.parameters.pbs_parameters().unwrap() } @@ -77,6 +75,10 @@ impl CompressedPublicKey { ) } + pub fn encrypt_bool(&self, message: bool) -> BooleanBlock { + BooleanBlock::new_unchecked(self.key.encrypt(u64::from(message))) + } + pub fn encrypt_radix_without_padding( &self, message: u64, diff --git a/tfhe/src/integer/public_key/standard.rs b/tfhe/src/integer/public_key/standard.rs index f25da886c8..5b913227ec 100644 --- a/tfhe/src/integer/public_key/standard.rs +++ b/tfhe/src/integer/public_key/standard.rs @@ -4,7 +4,7 @@ use crate::integer::ciphertext::{CrtCiphertext, RadixCiphertext}; use crate::integer::client_key::ClientKey; use crate::integer::encryption::{encrypt_crt, encrypt_words_radix_impl}; use crate::integer::public_key::compressed::CompressedPublicKey; -use crate::integer::SignedRadixCiphertext; +use crate::integer::{BooleanBlock, SignedRadixCiphertext}; use crate::shortint::parameters::MessageModulus; use crate::shortint::PublicKey as ShortintPublicKey; @@ -68,6 +68,10 @@ impl PublicKey { encrypt_words_radix_impl(&self.key, message, num_blocks, ShortintPublicKey::encrypt) } + pub fn encrypt_bool(&self, message: bool) -> BooleanBlock { + BooleanBlock::new_unchecked(self.key.encrypt(u64::from(message))) + } + pub fn encrypt_radix_without_padding( &self, message: u64,