diff --git a/tfhe/src/strings/server_key/comp.rs b/tfhe/src/strings/server_key/comp.rs index 389fa9e0a1..4c8c759f88 100644 --- a/tfhe/src/strings/server_key/comp.rs +++ b/tfhe/src/strings/server_key/comp.rs @@ -3,7 +3,7 @@ use crate::strings::ciphertext::{FheString, GenericPatternRef}; use crate::strings::server_key::{FheStringIsEmpty, ServerKey}; impl ServerKey { - fn eq_length_checks(&self, lhs: &FheString, rhs: &FheString) -> Option { + fn string_eq_length_checks(&self, lhs: &FheString, rhs: &FheString) -> Option { // If lhs is empty, rhs must also be empty in order to be equal (the case where lhs is // empty with > 1 padding zeros is handled next) if lhs.is_empty() { @@ -67,9 +67,9 @@ impl ServerKey { pub fn string_eq(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { let early_return = match rhs { GenericPatternRef::Clear(rhs) => { - self.eq_length_checks(lhs, &FheString::trivial(self, rhs.str())) + self.string_eq_length_checks(lhs, &FheString::trivial(self, rhs.str())) } - GenericPatternRef::Enc(rhs) => self.eq_length_checks(lhs, rhs), + GenericPatternRef::Enc(rhs) => self.string_eq_length_checks(lhs, rhs), }; if let Some(val) = early_return { @@ -135,23 +135,27 @@ impl ServerKey { /// ```rust /// use tfhe::integer::{ClientKey, ServerKey}; /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; - /// use tfhe::strings::ciphertext::FheString; + /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s1, s2) = ("apple", "banana"); /// /// let enc_s1 = FheString::new(&ck, s1, None); - /// let enc_s2 = FheString::new(&ck, s2, None); + /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// - /// let result = sk.string_lt(&enc_s1, &enc_s2); + /// let result = sk.string_lt(&enc_s1, enc_s2.as_ref()); /// let is_lt = ck.decrypt_bool(&result); /// /// assert!(is_lt); // "apple" is less than "banana" /// ``` - pub fn string_lt(&self, lhs: &FheString, rhs: &FheString) -> BooleanBlock { + pub fn string_lt(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { let mut lhs_uint = lhs.to_uint(); - let mut rhs_uint = rhs.to_uint(); + + let mut rhs_uint = match rhs { + GenericPatternRef::Clear(rhs) => FheString::trivial(self, rhs.str()).to_uint(), + GenericPatternRef::Enc(rhs) => rhs.to_uint(), + }; self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); @@ -167,23 +171,26 @@ impl ServerKey { /// ```rust /// use tfhe::integer::{ClientKey, ServerKey}; /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; - /// use tfhe::strings::ciphertext::FheString; + /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s1, s2) = ("banana", "apple"); /// /// let enc_s1 = FheString::new(&ck, s1, None); - /// let enc_s2 = FheString::new(&ck, s2, None); + /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// - /// let result = sk.string_gt(&enc_s1, &enc_s2); + /// let result = sk.string_gt(&enc_s1, enc_s2.as_ref()); /// let is_gt = ck.decrypt_bool(&result); /// /// assert!(is_gt); // "banana" is greater than "apple" /// ``` - pub fn string_gt(&self, lhs: &FheString, rhs: &FheString) -> BooleanBlock { + pub fn string_gt(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { let mut lhs_uint = lhs.to_uint(); - let mut rhs_uint = rhs.to_uint(); + let mut rhs_uint = match rhs { + GenericPatternRef::Clear(rhs) => FheString::trivial(self, rhs.str()).to_uint(), + GenericPatternRef::Enc(rhs) => rhs.to_uint(), + }; self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); @@ -200,24 +207,26 @@ impl ServerKey { /// ```rust /// use tfhe::integer::{ClientKey, ServerKey}; /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; - /// use tfhe::strings::ciphertext::FheString; + /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s1, s2) = ("apple", "banana"); /// /// let enc_s1 = FheString::new(&ck, s1, None); - /// let enc_s2 = FheString::new(&ck, s2, None); + /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// - /// let result = sk.string_le(&enc_s1, &enc_s2); + /// let result = sk.string_le(&enc_s1, enc_s2.as_ref()); /// let is_le = ck.decrypt_bool(&result); /// /// assert!(is_le); // "apple" is less than or equal to "banana" /// ``` - pub fn string_le(&self, lhs: &FheString, rhs: &FheString) -> BooleanBlock { + pub fn string_le(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { let mut lhs_uint = lhs.to_uint(); - let mut rhs_uint = rhs.to_uint(); - + let mut rhs_uint = match rhs { + GenericPatternRef::Clear(rhs) => FheString::trivial(self, rhs.str()).to_uint(), + GenericPatternRef::Enc(rhs) => rhs.to_uint(), + }; self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); self.le_parallelized(&lhs_uint, &rhs_uint) @@ -233,23 +242,26 @@ impl ServerKey { /// ```rust /// use tfhe::integer::{ClientKey, ServerKey}; /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; - /// use tfhe::strings::ciphertext::FheString; + /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; /// /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s1, s2) = ("banana", "apple"); /// /// let enc_s1 = FheString::new(&ck, s1, None); - /// let enc_s2 = FheString::new(&ck, s2, None); + /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// - /// let result = sk.string_ge(&enc_s1, &enc_s2); + /// let result = sk.string_ge(&enc_s1, enc_s2.as_ref()); /// let is_ge = ck.decrypt_bool(&result); /// /// assert!(is_ge); // "banana" is greater than or equal to "apple" /// ``` - pub fn string_ge(&self, lhs: &FheString, rhs: &FheString) -> BooleanBlock { + pub fn string_ge(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock { let mut lhs_uint = lhs.to_uint(); - let mut rhs_uint = rhs.to_uint(); + let mut rhs_uint = match rhs { + GenericPatternRef::Clear(rhs) => FheString::trivial(self, rhs.str()).to_uint(), + GenericPatternRef::Enc(rhs) => rhs.to_uint(), + }; self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); diff --git a/tfhe/src/strings/test_functions/test_common.rs b/tfhe/src/strings/test_functions/test_common.rs index 6f5c9ad593..8b7ede655d 100644 --- a/tfhe/src/strings/test_functions/test_common.rs +++ b/tfhe/src/strings/test_functions/test_common.rs @@ -1,5 +1,6 @@ +use crate::integer::{BooleanBlock, ServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; -use crate::strings::ciphertext::{ClearString, GenericPattern}; +use crate::strings::ciphertext::{ClearString, FheString, GenericPattern, GenericPatternRef}; use crate::strings::server_key::{FheStringIsEmpty, FheStringLen}; use crate::strings::test::TestKind; use crate::strings::test_functions::{ @@ -303,106 +304,47 @@ impl TestKeys { let enc_rhs = GenericPattern::Enc(self.encrypt_string(rhs, rhs_pad)); let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - // Equal - let expected_eq = str == rhs; - - let start = Instant::now(); - let result_eq = self.sk.string_eq(&enc_lhs, enc_rhs.as_ref()); - let end = Instant::now(); - - let dec_eq = self.ck.decrypt_bool(&result_eq); - - println!("\n\x1b[1mEq:\x1b[0m"); - result_message_rhs(str, rhs, expected_eq, dec_eq, end.duration_since(start)); - assert_eq!(dec_eq, expected_eq); - - // Clear rhs - let start = Instant::now(); - let result_eq = self.sk.string_eq(&enc_lhs, clear_rhs.as_ref()); - let end = Instant::now(); - - let dec_eq = self.ck.decrypt_bool(&result_eq); - - println!("\n\x1b[1mEq:\x1b[0m"); - result_message_clear_rhs(str, rhs, expected_eq, dec_eq, end.duration_since(start)); - assert_eq!(dec_eq, expected_eq); - - // Not equal - let expected_ne = str != rhs; - - let start = Instant::now(); - let result_ne = self.sk.string_ne(&enc_lhs, enc_rhs.as_ref()); - let end = Instant::now(); - - let dec_ne = self.ck.decrypt_bool(&result_ne); - - println!("\n\x1b[1mNe:\x1b[0m"); - result_message_rhs(str, rhs, expected_ne, dec_ne, end.duration_since(start)); - assert_eq!(dec_ne, expected_ne); - - // Clear rhs - let start = Instant::now(); - let result_ne = self.sk.string_ne(&enc_lhs, clear_rhs.as_ref()); - let end = Instant::now(); - - let dec_ne = self.ck.decrypt_bool(&result_ne); - - println!("\n\x1b[1mNe:\x1b[0m"); - result_message_clear_rhs(str, rhs, expected_ne, dec_ne, end.duration_since(start)); - assert_eq!(dec_ne, expected_ne); - - let enc_rhs = self.encrypt_string(rhs, rhs_pad); - - // Greater or equal - let expected_ge = str >= rhs; - - let start = Instant::now(); - let result_ge = self.sk.string_ge(&enc_lhs, &enc_rhs); - let end = Instant::now(); - - let dec_ge = self.ck.decrypt_bool(&result_ge); - - println!("\n\x1b[1mGe:\x1b[0m"); - result_message_rhs(str, rhs, expected_ge, dec_ge, end.duration_since(start)); - assert_eq!(dec_ge, expected_ge); - - // Less or equal - let expected_le = str <= rhs; - - let start = Instant::now(); - let result_le = self.sk.string_le(&enc_lhs, &enc_rhs); - let end = Instant::now(); - - let dec_le = self.ck.decrypt_bool(&result_le); - - println!("\n\x1b[1mLe:\x1b[0m"); - result_message_rhs(str, rhs, expected_le, dec_le, end.duration_since(start)); - assert_eq!(dec_le, expected_le); - - // Greater than - let expected_gt = str > rhs; - - let start = Instant::now(); - let result_gt = self.sk.string_gt(&enc_lhs, &enc_rhs); - let end = Instant::now(); - - let dec_gt = self.ck.decrypt_bool(&result_gt); - - println!("\n\x1b[1mGt:\x1b[0m"); - result_message_rhs(str, rhs, expected_gt, dec_gt, end.duration_since(start)); - assert_eq!(dec_gt, expected_gt); - - // Less than - let expected_lt = str < rhs; - - let start = Instant::now(); - let result_lt = self.sk.string_lt(&enc_lhs, &enc_rhs); - let end = Instant::now(); - - let dec_lt = self.ck.decrypt_bool(&result_lt); - - println!("\n\x1b[1mLt:\x1b[0m"); - result_message_rhs(str, rhs, expected_lt, dec_lt, end.duration_since(start)); - assert_eq!(dec_lt, expected_lt); + #[allow(clippy::type_complexity)] + let ops: [( + bool, + fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> BooleanBlock, + ); 6] = [ + (str == rhs, ServerKey::string_eq), + (str != rhs, ServerKey::string_ne), + (str >= rhs, ServerKey::string_ge), + (str <= rhs, ServerKey::string_le), + (str > rhs, ServerKey::string_gt), + (str < rhs, ServerKey::string_lt), + ]; + + for (expected_result, encrypted_op) in ops { + // Encrypted rhs + let start = Instant::now(); + let result = encrypted_op(&self.sk, &enc_lhs, enc_rhs.as_ref()); + let end = Instant::now(); + + let dec_result = self.ck.decrypt_bool(&result); + + println!("\n\x1b[1mEq:\x1b[0m"); + result_message_rhs( + str, + rhs, + expected_result, + dec_result, + end.duration_since(start), + ); + assert_eq!(dec_result, expected_result); + + // Clear rhs + let start = Instant::now(); + let result_eq = encrypted_op(&self.sk, &enc_lhs, clear_rhs.as_ref()); + let end = Instant::now(); + + let dec_eq = self.ck.decrypt_bool(&result_eq); + + println!("\n\x1b[1mEq:\x1b[0m"); + result_message_clear_rhs(str, rhs, expected_result, dec_eq, end.duration_since(start)); + assert_eq!(dec_eq, expected_result); + } } }