Skip to content

Commit

Permalink
refactor(strings): comparisons take a GenericPattern
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Nov 20, 2024
1 parent 11a0fe2 commit 46cf465
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 126 deletions.
60 changes: 36 additions & 24 deletions tfhe/src/strings/server_key/comp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::strings::ciphertext::{FheString, GenericPatternRef};
use crate::strings::server_key::{FheStringIsEmpty, ServerKey};

impl ServerKey {
fn eq_length_checks(&self, lhs: &FheString, rhs: &FheString) -> Option<BooleanBlock> {
fn string_eq_length_checks(&self, lhs: &FheString, rhs: &FheString) -> Option<BooleanBlock> {
// If lhs is empty, rhs must also be empty in order to be equal (the case where lhs is
// empty with > 1 padding zeros is handled next)
if lhs.is_empty() {
Expand Down Expand Up @@ -67,9 +67,9 @@ impl ServerKey {
pub fn string_eq(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock {
let early_return = match rhs {
GenericPatternRef::Clear(rhs) => {
self.eq_length_checks(lhs, &FheString::trivial(self, rhs.str()))
self.string_eq_length_checks(lhs, &FheString::trivial(self, rhs.str()))
}
GenericPatternRef::Enc(rhs) => self.eq_length_checks(lhs, rhs),
GenericPatternRef::Enc(rhs) => self.string_eq_length_checks(lhs, rhs),
};

if let Some(val) = early_return {
Expand Down Expand Up @@ -135,23 +135,27 @@ impl ServerKey {
/// ```rust
/// use tfhe::integer::{ClientKey, ServerKey};
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::strings::ciphertext::FheString;
/// use tfhe::strings::ciphertext::{FheString, GenericPattern};
///
/// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let sk = ServerKey::new_radix_server_key(&ck);
/// let (s1, s2) = ("apple", "banana");
///
/// let enc_s1 = FheString::new(&ck, s1, None);
/// let enc_s2 = FheString::new(&ck, s2, None);
/// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None));
///
/// let result = sk.string_lt(&enc_s1, &enc_s2);
/// let result = sk.string_lt(&enc_s1, enc_s2.as_ref());
/// let is_lt = ck.decrypt_bool(&result);
///
/// assert!(is_lt); // "apple" is less than "banana"
/// ```
pub fn string_lt(&self, lhs: &FheString, rhs: &FheString) -> BooleanBlock {
pub fn string_lt(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock {
let mut lhs_uint = lhs.to_uint();
let mut rhs_uint = rhs.to_uint();

let mut rhs_uint = match rhs {
GenericPatternRef::Clear(rhs) => FheString::trivial(self, rhs.str()).to_uint(),
GenericPatternRef::Enc(rhs) => rhs.to_uint(),
};

self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint);

Expand All @@ -167,23 +171,26 @@ impl ServerKey {
/// ```rust
/// use tfhe::integer::{ClientKey, ServerKey};
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::strings::ciphertext::FheString;
/// use tfhe::strings::ciphertext::{FheString, GenericPattern};
///
/// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let sk = ServerKey::new_radix_server_key(&ck);
/// let (s1, s2) = ("banana", "apple");
///
/// let enc_s1 = FheString::new(&ck, s1, None);
/// let enc_s2 = FheString::new(&ck, s2, None);
/// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None));
///
/// let result = sk.string_gt(&enc_s1, &enc_s2);
/// let result = sk.string_gt(&enc_s1, enc_s2.as_ref());
/// let is_gt = ck.decrypt_bool(&result);
///
/// assert!(is_gt); // "banana" is greater than "apple"
/// ```
pub fn string_gt(&self, lhs: &FheString, rhs: &FheString) -> BooleanBlock {
pub fn string_gt(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock {
let mut lhs_uint = lhs.to_uint();
let mut rhs_uint = rhs.to_uint();
let mut rhs_uint = match rhs {
GenericPatternRef::Clear(rhs) => FheString::trivial(self, rhs.str()).to_uint(),
GenericPatternRef::Enc(rhs) => rhs.to_uint(),
};

self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint);

Expand All @@ -200,24 +207,26 @@ impl ServerKey {
/// ```rust
/// use tfhe::integer::{ClientKey, ServerKey};
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::strings::ciphertext::FheString;
/// use tfhe::strings::ciphertext::{FheString, GenericPattern};
///
/// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let sk = ServerKey::new_radix_server_key(&ck);
/// let (s1, s2) = ("apple", "banana");
///
/// let enc_s1 = FheString::new(&ck, s1, None);
/// let enc_s2 = FheString::new(&ck, s2, None);
/// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None));
///
/// let result = sk.string_le(&enc_s1, &enc_s2);
/// let result = sk.string_le(&enc_s1, enc_s2.as_ref());
/// let is_le = ck.decrypt_bool(&result);
///
/// assert!(is_le); // "apple" is less than or equal to "banana"
/// ```
pub fn string_le(&self, lhs: &FheString, rhs: &FheString) -> BooleanBlock {
pub fn string_le(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock {
let mut lhs_uint = lhs.to_uint();
let mut rhs_uint = rhs.to_uint();

let mut rhs_uint = match rhs {
GenericPatternRef::Clear(rhs) => FheString::trivial(self, rhs.str()).to_uint(),
GenericPatternRef::Enc(rhs) => rhs.to_uint(),
};
self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint);

self.le_parallelized(&lhs_uint, &rhs_uint)
Expand All @@ -233,23 +242,26 @@ impl ServerKey {
/// ```rust
/// use tfhe::integer::{ClientKey, ServerKey};
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::strings::ciphertext::FheString;
/// use tfhe::strings::ciphertext::{FheString, GenericPattern};
///
/// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let sk = ServerKey::new_radix_server_key(&ck);
/// let (s1, s2) = ("banana", "apple");
///
/// let enc_s1 = FheString::new(&ck, s1, None);
/// let enc_s2 = FheString::new(&ck, s2, None);
/// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None));
///
/// let result = sk.string_ge(&enc_s1, &enc_s2);
/// let result = sk.string_ge(&enc_s1, enc_s2.as_ref());
/// let is_ge = ck.decrypt_bool(&result);
///
/// assert!(is_ge); // "banana" is greater than or equal to "apple"
/// ```
pub fn string_ge(&self, lhs: &FheString, rhs: &FheString) -> BooleanBlock {
pub fn string_ge(&self, lhs: &FheString, rhs: GenericPatternRef<'_>) -> BooleanBlock {
let mut lhs_uint = lhs.to_uint();
let mut rhs_uint = rhs.to_uint();
let mut rhs_uint = match rhs {
GenericPatternRef::Clear(rhs) => FheString::trivial(self, rhs.str()).to_uint(),
GenericPatternRef::Enc(rhs) => rhs.to_uint(),
};

self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint);

Expand Down
146 changes: 44 additions & 102 deletions tfhe/src/strings/test_functions/test_common.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::integer::{BooleanBlock, ServerKey};
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::strings::ciphertext::{ClearString, GenericPattern};
use crate::strings::ciphertext::{ClearString, FheString, GenericPattern, GenericPatternRef};
use crate::strings::server_key::{FheStringIsEmpty, FheStringLen};
use crate::strings::test::TestKind;
use crate::strings::test_functions::{
Expand Down Expand Up @@ -303,106 +304,47 @@ impl TestKeys {
let enc_rhs = GenericPattern::Enc(self.encrypt_string(rhs, rhs_pad));
let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string()));

// Equal
let expected_eq = str == rhs;

let start = Instant::now();
let result_eq = self.sk.string_eq(&enc_lhs, enc_rhs.as_ref());
let end = Instant::now();

let dec_eq = self.ck.decrypt_bool(&result_eq);

println!("\n\x1b[1mEq:\x1b[0m");
result_message_rhs(str, rhs, expected_eq, dec_eq, end.duration_since(start));
assert_eq!(dec_eq, expected_eq);

// Clear rhs
let start = Instant::now();
let result_eq = self.sk.string_eq(&enc_lhs, clear_rhs.as_ref());
let end = Instant::now();

let dec_eq = self.ck.decrypt_bool(&result_eq);

println!("\n\x1b[1mEq:\x1b[0m");
result_message_clear_rhs(str, rhs, expected_eq, dec_eq, end.duration_since(start));
assert_eq!(dec_eq, expected_eq);

// Not equal
let expected_ne = str != rhs;

let start = Instant::now();
let result_ne = self.sk.string_ne(&enc_lhs, enc_rhs.as_ref());
let end = Instant::now();

let dec_ne = self.ck.decrypt_bool(&result_ne);

println!("\n\x1b[1mNe:\x1b[0m");
result_message_rhs(str, rhs, expected_ne, dec_ne, end.duration_since(start));
assert_eq!(dec_ne, expected_ne);

// Clear rhs
let start = Instant::now();
let result_ne = self.sk.string_ne(&enc_lhs, clear_rhs.as_ref());
let end = Instant::now();

let dec_ne = self.ck.decrypt_bool(&result_ne);

println!("\n\x1b[1mNe:\x1b[0m");
result_message_clear_rhs(str, rhs, expected_ne, dec_ne, end.duration_since(start));
assert_eq!(dec_ne, expected_ne);

let enc_rhs = self.encrypt_string(rhs, rhs_pad);

// Greater or equal
let expected_ge = str >= rhs;

let start = Instant::now();
let result_ge = self.sk.string_ge(&enc_lhs, &enc_rhs);
let end = Instant::now();

let dec_ge = self.ck.decrypt_bool(&result_ge);

println!("\n\x1b[1mGe:\x1b[0m");
result_message_rhs(str, rhs, expected_ge, dec_ge, end.duration_since(start));
assert_eq!(dec_ge, expected_ge);

// Less or equal
let expected_le = str <= rhs;

let start = Instant::now();
let result_le = self.sk.string_le(&enc_lhs, &enc_rhs);
let end = Instant::now();

let dec_le = self.ck.decrypt_bool(&result_le);

println!("\n\x1b[1mLe:\x1b[0m");
result_message_rhs(str, rhs, expected_le, dec_le, end.duration_since(start));
assert_eq!(dec_le, expected_le);

// Greater than
let expected_gt = str > rhs;

let start = Instant::now();
let result_gt = self.sk.string_gt(&enc_lhs, &enc_rhs);
let end = Instant::now();

let dec_gt = self.ck.decrypt_bool(&result_gt);

println!("\n\x1b[1mGt:\x1b[0m");
result_message_rhs(str, rhs, expected_gt, dec_gt, end.duration_since(start));
assert_eq!(dec_gt, expected_gt);

// Less than
let expected_lt = str < rhs;

let start = Instant::now();
let result_lt = self.sk.string_lt(&enc_lhs, &enc_rhs);
let end = Instant::now();

let dec_lt = self.ck.decrypt_bool(&result_lt);

println!("\n\x1b[1mLt:\x1b[0m");
result_message_rhs(str, rhs, expected_lt, dec_lt, end.duration_since(start));
assert_eq!(dec_lt, expected_lt);
#[allow(clippy::type_complexity)]
let ops: [(
bool,
fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> BooleanBlock,
); 6] = [
(str == rhs, ServerKey::string_eq),
(str != rhs, ServerKey::string_ne),
(str >= rhs, ServerKey::string_ge),
(str <= rhs, ServerKey::string_le),
(str > rhs, ServerKey::string_gt),
(str < rhs, ServerKey::string_lt),
];

for (expected_result, encrypted_op) in ops {
// Encrypted rhs
let start = Instant::now();
let result = encrypted_op(&self.sk, &enc_lhs, enc_rhs.as_ref());
let end = Instant::now();

let dec_result = self.ck.decrypt_bool(&result);

println!("\n\x1b[1mEq:\x1b[0m");
result_message_rhs(
str,
rhs,
expected_result,
dec_result,
end.duration_since(start),
);
assert_eq!(dec_result, expected_result);

// Clear rhs
let start = Instant::now();
let result_eq = encrypted_op(&self.sk, &enc_lhs, clear_rhs.as_ref());
let end = Instant::now();

let dec_eq = self.ck.decrypt_bool(&result_eq);

println!("\n\x1b[1mEq:\x1b[0m");
result_message_clear_rhs(str, rhs, expected_result, dec_eq, end.duration_since(start));
assert_eq!(dec_eq, expected_result);
}
}
}

0 comments on commit 46cf465

Please sign in to comment.