From 209b90de4229ba74a4a694de5ec4600b68533cea Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Tue, 2 Apr 2024 14:03:10 -0700 Subject: [PATCH] mp-spdz-rs: ffi: Add bindings for ciphertext <> ciphertext operations --- mp-spdz-rs/Cargo.toml | 3 + mp-spdz-rs/build.rs | 6 +- mp-spdz-rs/src/ffi.rs | 131 ++++++++++++++++++++++++++++++++++++++++++ mp-spdz-rs/src/lib.rs | 32 +---------- 4 files changed, 139 insertions(+), 33 deletions(-) create mode 100644 mp-spdz-rs/src/ffi.rs diff --git a/mp-spdz-rs/Cargo.toml b/mp-spdz-rs/Cargo.toml index b6a15cb..93dd495 100644 --- a/mp-spdz-rs/Cargo.toml +++ b/mp-spdz-rs/Cargo.toml @@ -10,3 +10,6 @@ cxx = "1.0" cxx-build = "1.0" itertools = "0.12.0" pkg-config = "0.3" + +[dev-dependencies] +rand = "0.8.4" diff --git a/mp-spdz-rs/build.rs b/mp-spdz-rs/build.rs index 7dfd5e8..8f6603d 100644 --- a/mp-spdz-rs/build.rs +++ b/mp-spdz-rs/build.rs @@ -12,7 +12,7 @@ fn main() { ]; // Build the c++ bridge - cxx_build::bridge("src/lib.rs") + cxx_build::bridge("src/ffi.rs") .files(get_source_files("src/include/MP-SPDZ/FHE")) .files(get_source_files("src/include/MP-SPDZ/FHEOffline")) .files(get_source_files("src/include/MP-SPDZ/Math")) @@ -40,6 +40,9 @@ fn main() { add_link_path("ntl"); link_lib("ntl"); + add_link_path("libsodium"); + link_lib("sodium"); + add_link_path("gmp"); link_lib("gmp"); link_lib("gmpxx"); @@ -60,7 +63,6 @@ fn main() { } // Build cache flags - println!("cargo:rerun-if-changed=src/lib.rs"); println!("cargo:rerun-if-changed=../MP-SPDZ"); } diff --git a/mp-spdz-rs/src/ffi.rs b/mp-spdz-rs/src/ffi.rs new file mode 100644 index 0000000..e29ae4d --- /dev/null +++ b/mp-spdz-rs/src/ffi.rs @@ -0,0 +1,131 @@ +//! The FFI bindings for the MP-SPDZ library + +#[cxx::bridge] +mod ffi_inner { + unsafe extern "C++" { + include!("FHE/FHE_Params.h"); + include!("FHE/FHE_Keys.h"); + include!("FHE/Plaintext.h"); + include!("Math/bigint.h"); + + // `bigint` + type bigint; + fn print(self: &bigint); + + // `FHE_Params` + type FHE_Params; + fn new_fhe_params(n_mults: i32, drown_sec: i32) -> UniquePtr; + fn basic_generation_mod_prime(self: Pin<&mut FHE_Params>, plaintext_length: i32); + fn get_plaintext_mod(params: &FHE_Params) -> UniquePtr; + + // `FHE Keys` + type FHE_KeyPair; + type FHE_PK; + type FHE_SK; + fn new_keypair(params: &FHE_Params) -> UniquePtr; + fn get_pk(keypair: &FHE_KeyPair) -> UniquePtr; + fn get_sk(keypair: &FHE_KeyPair) -> UniquePtr; + fn encrypt(keypair: &FHE_KeyPair, plaintext: &Plaintext_mod_prime) + -> UniquePtr; + fn decrypt( + keypair: Pin<&mut FHE_KeyPair>, + ciphertext: &Ciphertext, + ) -> UniquePtr; + + // `Plaintext` + type Plaintext_mod_prime; + fn new_plaintext(params: &FHE_Params) -> UniquePtr; + fn set_element_int(plaintext: Pin<&mut Plaintext_mod_prime>, idx: usize, value: u32); + fn get_element_int(plaintext: &Plaintext_mod_prime, idx: usize) -> u32; + + // `Ciphertext` + type Ciphertext; + fn add_ciphertexts(c0: &Ciphertext, c1: &Ciphertext) -> UniquePtr; + fn mul_ciphertexts(c0: &Ciphertext, c1: &Ciphertext, pk: &FHE_PK) -> UniquePtr; + } +} +pub use ffi_inner::*; + +#[cfg(test)] +mod test { + use cxx::UniquePtr; + use rand::{thread_rng, Rng, RngCore}; + + use super::*; + + /// Generate a new set of FHE parameters and keypair + fn setup_fhe( + n_mults: i32, + plaintext_length: i32, + ) -> (UniquePtr, UniquePtr) { + let mut params = new_fhe_params(n_mults, 128 /* sec */); + params.pin_mut().basic_generation_mod_prime(plaintext_length); + + let keypair = new_keypair(¶ms); + (params, keypair) + } + + /// Create a ciphertext encrypting a single integer in the zero'th slot + fn encrypt_int( + params: &FHE_Params, + keypair: &FHE_KeyPair, + value: u32, + ) -> UniquePtr { + let mut plaintext = new_plaintext(params); + set_element_int(plaintext.pin_mut(), 0 /* idx */, value); + + encrypt(keypair, &plaintext) + } + + /// Tests addition of two encrypted values + #[test] + fn test_encrypted_addition() { + let mut rng = thread_rng(); + let (params, mut keypair) = setup_fhe(0, 254); + + // Add two ciphertexts, divide by two to avoid overflow + let val1 = rng.next_u32() / 2; + let val2 = rng.next_u32() / 2; + + let cipher1 = encrypt_int(¶ms, &keypair, val1); + let cipher2 = encrypt_int(¶ms, &keypair, val2); + + let sum = add_ciphertexts(cipher1.as_ref().unwrap(), cipher2.as_ref().unwrap()); + + // Decrypt the sum + let plaintext_res = decrypt(keypair.pin_mut(), &sum); + let pt_u32 = get_element_int(&plaintext_res, 0); + let expected = val1 + val2; + + assert_eq!(pt_u32, expected); + } + + /// Tests multiplication of two encrypted values + #[test] + fn test_encrypted_multiplication() { + let mut rng = thread_rng(); + let (params, mut keypair) = setup_fhe(1, 254); + + // Multiply two ciphertexts; capped bit length to avoid overflow + let range = (0..(2u32.pow(16))); + let val1 = rng.gen_range(range.clone()); + let val2 = rng.gen_range(range.clone()); + + let cipher1 = encrypt_int(¶ms, &keypair, val1); + let cipher2 = encrypt_int(¶ms, &keypair, val2); + + let pk = get_pk(&keypair); + let product = mul_ciphertexts( + cipher1.as_ref().unwrap(), + cipher2.as_ref().unwrap(), + pk.as_ref().unwrap(), + ); + + // Decrypt the product + let plaintext_res = decrypt(keypair.pin_mut(), &product); + let pt_u32 = get_element_int(&plaintext_res, 0); + let expected = val1 * val2; + + assert_eq!(pt_u32, expected); + } +} diff --git a/mp-spdz-rs/src/lib.rs b/mp-spdz-rs/src/lib.rs index 6373b83..8ebb0c3 100644 --- a/mp-spdz-rs/src/lib.rs +++ b/mp-spdz-rs/src/lib.rs @@ -4,34 +4,4 @@ //! This library is intended to be a thin wrapper around the MP-SPDZ library, //! and to internalize build and link procedure with the foreign ABI -#[cxx::bridge] -pub mod ffi { - unsafe extern "C++" { - include!("FHE/FHE_Params.h"); - include!("Math/bigint.h"); - - // `bigint` - type bigint; - fn print(self: &bigint); - - // `FHE_Params` - type FHE_Params; - fn new_fhe_params(n_mults: i32, drown_sec: i32) -> UniquePtr; - fn basic_generation_mod_prime(self: Pin<&mut FHE_Params>, plaintext_length: i32); - fn get_plaintext_mod(params: &FHE_Params) -> UniquePtr; - } -} - -#[cfg(test)] -mod test { - use super::ffi::*; - - #[test] - fn test_dummy() { - let mut params = new_fhe_params(0 /* mults */, 128 /* sec */); - params.pin_mut().basic_generation_mod_prime(255 /* bitlength */); - - let plaintext_modulus = get_plaintext_mod(¶ms); - plaintext_modulus.print(); - } -} +pub mod ffi;