Skip to content

Commit

Permalink
mp-spdz-rs: ffi: Add bindings for ciphertext <> plaintext ops
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Apr 2, 2024
1 parent 209b90d commit 11a5db9
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 11 deletions.
2 changes: 1 addition & 1 deletion mp-spdz-rs/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ fn main() {
}

// Build cache flags
println!("cargo:rerun-if-changed=../MP-SPDZ");
println!("cargo:rerun-if-changed=src/include/MP-SPDZ");
}

/// Get the vendor of the current host
Expand Down
75 changes: 65 additions & 10 deletions mp-spdz-rs/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ mod ffi_inner {

// `Ciphertext`
type Ciphertext;
fn add_plaintext(c0: &Ciphertext, p1: &Plaintext_mod_prime) -> UniquePtr<Ciphertext>;
fn mul_plaintext(c0: &Ciphertext, p1: &Plaintext_mod_prime) -> UniquePtr<Ciphertext>;
fn add_ciphertexts(c0: &Ciphertext, c1: &Ciphertext) -> UniquePtr<Ciphertext>;
fn mul_ciphertexts(c0: &Ciphertext, c1: &Ciphertext, pk: &FHE_PK) -> UniquePtr<Ciphertext>;
}
Expand All @@ -65,18 +67,71 @@ mod test {
(params, keypair)
}

/// Create a plaintext value with the given integer in the first slot
fn plaintext_int(val: u32, params: &FHE_Params) -> UniquePtr<Plaintext_mod_prime> {
let mut plaintext = new_plaintext(params);
set_element_int(plaintext.pin_mut(), 0 /* idx */, val);

plaintext
}

/// Create a ciphertext encrypting a single integer in the zero'th slot
fn encrypt_int(
params: &FHE_Params,
keypair: &FHE_KeyPair,
value: u32,
keypair: &FHE_KeyPair,
params: &FHE_Params,
) -> UniquePtr<Ciphertext> {
let mut plaintext = new_plaintext(params);
set_element_int(plaintext.pin_mut(), 0 /* idx */, value);

let mut plaintext = plaintext_int(value, params);
encrypt(keypair, &plaintext)
}

/// Tests addition of a plaintext to a ciphertext
#[test]
fn test_plaintext_addition() {
let mut rng = thread_rng();
let (params, mut keypair) = setup_fhe(0, 254);

// Add a plaintext to a ciphertext
let val1 = rng.next_u32() / 2;
let val2 = rng.next_u32() / 2;

let plaintext = plaintext_int(val1, &params);
let ciphertext = encrypt_int(val2, keypair.as_ref().unwrap(), &params);

let sum = add_plaintext(ciphertext.as_ref().unwrap(), plaintext.as_ref().unwrap());

// Decrypt the sum
let plaintext_res = decrypt(keypair.pin_mut(), &sum);
let pt_u32 = get_element_int(&plaintext_res, 0);
let expected = val1 + val2;

assert_eq!(pt_u32, expected);
}

/// Tests multiplication of a plaintext to a ciphertext
#[test]
fn test_plaintext_multiplication() {
let mut rng = thread_rng();
let (params, mut keypair) = setup_fhe(1, 254);

// Multiply a plaintext to a ciphertext
let range = 0..(2u32.pow(16));
let val1 = rng.gen_range(range.clone());
let val2 = rng.gen_range(range.clone());

let plaintext = plaintext_int(val1, &params);
let ciphertext = encrypt_int(val2, keypair.as_ref().unwrap(), &params);

let product = mul_plaintext(ciphertext.as_ref().unwrap(), plaintext.as_ref().unwrap());

// Decrypt the product
let plaintext_res = decrypt(keypair.pin_mut(), &product);
let pt_u32 = get_element_int(&plaintext_res, 0);
let expected = val1 * val2;

assert_eq!(pt_u32, expected);
}

/// Tests addition of two encrypted values
#[test]
fn test_encrypted_addition() {
Expand All @@ -87,8 +142,8 @@ mod test {
let val1 = rng.next_u32() / 2;
let val2 = rng.next_u32() / 2;

let cipher1 = encrypt_int(&params, &keypair, val1);
let cipher2 = encrypt_int(&params, &keypair, val2);
let cipher1 = encrypt_int(val1, &keypair, &params);
let cipher2 = encrypt_int(val2, &keypair, &params);

let sum = add_ciphertexts(cipher1.as_ref().unwrap(), cipher2.as_ref().unwrap());

Expand All @@ -107,12 +162,12 @@ mod test {
let (params, mut keypair) = setup_fhe(1, 254);

// Multiply two ciphertexts; capped bit length to avoid overflow
let range = (0..(2u32.pow(16)));
let range = 0..(2u32.pow(16));
let val1 = rng.gen_range(range.clone());
let val2 = rng.gen_range(range.clone());

let cipher1 = encrypt_int(&params, &keypair, val1);
let cipher2 = encrypt_int(&params, &keypair, val2);
let cipher1 = encrypt_int(val1, &keypair, &params);
let cipher2 = encrypt_int(val2, &keypair, &params);

let pk = get_pk(&keypair);
let product = mul_ciphertexts(
Expand Down

0 comments on commit 11a5db9

Please sign in to comment.