diff --git a/mp-spdz-rs/src/ffi.rs b/mp-spdz-rs/src/ffi.rs index 8dcf203..952883b 100644 --- a/mp-spdz-rs/src/ffi.rs +++ b/mp-spdz-rs/src/ffi.rs @@ -59,6 +59,7 @@ mod ffi_inner { // `Plaintext` type Plaintext_mod_prime; fn new_plaintext(params: &FHE_Params) -> UniquePtr; + fn randomize_plaintext(plaintext: Pin<&mut Plaintext_mod_prime>); fn clone(self: &Plaintext_mod_prime) -> UniquePtr; fn to_rust_bytes(self: &Plaintext_mod_prime) -> Vec; fn plaintext_from_rust_bytes( @@ -97,6 +98,11 @@ mod ffi_inner { vector: &PlaintextVector, index: usize, ) -> UniquePtr; + fn set_plaintext_vector_element( + vector: Pin<&mut PlaintextVector>, + index: usize, + plaintext: &Plaintext_mod_prime, + ); fn randomize_plaintext_vector(vector: Pin<&mut PlaintextVector>); fn push_plaintext_vector( vector: Pin<&mut PlaintextVector>, @@ -109,6 +115,7 @@ mod ffi_inner { type Ciphertext; fn clone(self: &Ciphertext) -> UniquePtr; fn to_rust_bytes(self: &Ciphertext) -> Vec; + fn rerandomize(self: Pin<&mut Ciphertext>, pk: &FHE_PK); fn ciphertext_from_rust_bytes(data: &[u8], params: &FHE_Params) -> UniquePtr; fn add_plaintext(c0: &Ciphertext, p1: &Plaintext_mod_prime) -> UniquePtr; @@ -120,10 +127,21 @@ mod ffi_inner { type CiphertextVector; fn new_ciphertext_vector(size: usize, params: &FHE_Params) -> UniquePtr; fn new_ciphertext_vector_single(ciphertext: &Ciphertext) -> UniquePtr; + fn ciphertext_vector_to_rust_bytes(vector: &CiphertextVector) -> Vec; + fn ciphertext_vector_from_rust_bytes( + data: &[u8], + params: &FHE_Params, + ) -> UniquePtr; + fn get_ciphertext_vector_element( vector: &CiphertextVector, index: usize, ) -> UniquePtr; + fn set_ciphertext_vector_element( + vector: Pin<&mut CiphertextVector>, + index: usize, + ciphertext: &Ciphertext, + ); fn push_ciphertext_vector(vector: Pin<&mut CiphertextVector>, ciphertext: &Ciphertext); fn pop_ciphertext_vector(vector: Pin<&mut CiphertextVector>); fn ciphertext_vector_size(vector: &CiphertextVector) -> usize; @@ -156,6 +174,7 @@ unsafe impl Send for CiphertextVector {} unsafe impl Send for CiphertextWithProof {} unsafe impl Send for Plaintext_mod_prime {} unsafe impl Send for PlaintextVector {} +unsafe impl Sync for PlaintextVector {} #[cfg(test)] mod test { diff --git a/mp-spdz-rs/src/fhe/ciphertext.rs b/mp-spdz-rs/src/fhe/ciphertext.rs index c8d6148..3146e7f 100644 --- a/mp-spdz-rs/src/fhe/ciphertext.rs +++ b/mp-spdz-rs/src/fhe/ciphertext.rs @@ -27,6 +27,11 @@ pub struct Ciphertext { } impl Ciphertext { + /// Rerandomize the ciphertext + pub fn rerandomize(&mut self, pk: &BGVPublicKey) { + self.inner.pin_mut().rerandomize(pk.as_ref()); + } + /// Multiply two ciphertexts pub fn mul_ciphertext(&self, other: &Self, pk: &BGVPublicKey) -> Self { ffi::mul_ciphertexts(self.as_ref(), other.as_ref(), pk.as_ref()).into() @@ -111,6 +116,24 @@ impl From> for CiphertextVector< } } +impl AsRef for CiphertextVector { + fn as_ref(&self) -> &ffi::CiphertextVector { + self.inner.as_ref().unwrap() + } +} + +impl ToBytes for CiphertextVector { + fn to_bytes(&self) -> Vec { + ffi::ciphertext_vector_to_rust_bytes(self.as_ref()) + } +} + +impl FromBytesWithParams for CiphertextVector { + fn from_bytes(data: &[u8], params: &BGVParams) -> Self { + ffi::ciphertext_vector_from_rust_bytes(data, params.as_ref()).into() + } +} + impl CiphertextVector { /// Create a new `CiphertextVector` with a specified size pub fn new(size: usize, params: &BGVParams) -> Self { @@ -124,10 +147,15 @@ impl CiphertextVector { } /// Get the size of the `CiphertextVector` - pub fn size(&self) -> usize { + pub fn len(&self) -> usize { ffi::ciphertext_vector_size(self.inner.as_ref().unwrap()) } + /// Whether the vector is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Add a `Ciphertext` to the end of the `CiphertextVector` pub fn push(&mut self, ciphertext: &Ciphertext) { ffi::push_ciphertext_vector(self.inner.pin_mut(), ciphertext.as_ref()); @@ -143,6 +171,11 @@ impl CiphertextVector { let ciphertext = ffi::get_ciphertext_vector_element(self.inner.as_ref().unwrap(), index); Ciphertext::from(ciphertext) } + + /// Set a `Ciphertext` at a specific index in the `CiphertextVector` + pub fn set(&mut self, index: usize, ciphertext: &Ciphertext) { + ffi::set_ciphertext_vector_element(self.inner.pin_mut(), index, ciphertext.as_ref()); + } } // ----------------- diff --git a/mp-spdz-rs/src/fhe/plaintext.rs b/mp-spdz-rs/src/fhe/plaintext.rs index 2617c75..2c70460 100644 --- a/mp-spdz-rs/src/fhe/plaintext.rs +++ b/mp-spdz-rs/src/fhe/plaintext.rs @@ -66,6 +66,11 @@ impl Plaintext { Self { inner, _phantom: PhantomData } } + /// Randomize the plaintext + pub fn randomize(&mut self) { + ffi::randomize_plaintext(self.inner.pin_mut()); + } + /// Get the number of slots in the plaintext pub fn num_slots(&self) -> u32 { self.inner.num_slots() @@ -211,6 +216,11 @@ impl PlaintextVector { let plaintext = ffi::get_plaintext_vector_element(self.inner.as_ref().unwrap(), index); Plaintext::from(plaintext) } + + /// Set a `Plaintext` at a specific index in the `PlaintextVector` + pub fn set(&mut self, index: usize, plaintext: &Plaintext) { + ffi::set_plaintext_vector_element(self.inner.pin_mut(), index, plaintext.as_ref()); + } } // ------------------------------- diff --git a/offline-phase/src/lib.rs b/offline-phase/src/lib.rs index a27fd8d..b7ea548 100644 --- a/offline-phase/src/lib.rs +++ b/offline-phase/src/lib.rs @@ -61,7 +61,7 @@ pub(crate) mod test_helpers { } /// Run a two-party method with a `LowGear` instance setup and in scope - pub async fn mock_lowgear(mut f: F) -> (T, T) + pub async fn mock_lowgear(f: F) -> (T, T) where T: Send + 'static, S: Future + Send + 'static, @@ -77,9 +77,8 @@ pub(crate) mod test_helpers { run_mock_lowgear(f, lowgear1, lowgear2).await } - /// Run a two-party method with a `LowGear` instance, having run keygen and - /// setup - pub async fn mock_lowgear_with_keys(mut f: F) -> (T, T) + /// Run a two-party method with a `LowGear` instance, mocking keygen setup + pub async fn mock_lowgear_with_keys(f: F) -> (T, T) where T: Send + 'static, S: Future + Send + 'static, @@ -103,9 +102,9 @@ pub(crate) mod test_helpers { // Set the local keypairs and mac shares lowgear1.local_keypair = keypair1.clone(); - lowgear1.mac_share = mac_share1.clone(); + lowgear1.mac_share = mac_share1; lowgear2.local_keypair = keypair2.clone(); - lowgear2.mac_share = mac_share2.clone(); + lowgear2.mac_share = mac_share2; // Set the exchanged values lowgear1.other_pk = Some(keypair2.public_key()); diff --git a/offline-phase/src/lowgear/triplets.rs b/offline-phase/src/lowgear/triplets.rs index 957a5d9..5684e85 100644 --- a/offline-phase/src/lowgear/triplets.rs +++ b/offline-phase/src/lowgear/triplets.rs @@ -4,10 +4,15 @@ //! //! These triples are used to define single-round multiplication in the SPDZ //! protocol +//! +//! Follows the protocol detailed in https://eprint.iacr.org/2017/1230.pdf (Figure 7) use ark_ec::CurveGroup; use ark_mpc::network::MpcNetwork; -use mp_spdz_rs::fhe::{ciphertext::CiphertextPoK, plaintext::PlaintextVector}; +use mp_spdz_rs::fhe::{ + ciphertext::{CiphertextPoK, CiphertextVector}, + plaintext::{Plaintext, PlaintextVector}, +}; use crate::error::LowGearError; @@ -24,13 +29,112 @@ impl + Unpin> LowGear { let c = &a * &b; // Encrypt `a` and send it to the counterparty - let my_proof = self.local_keypair.encrypt_and_prove_vector(&mut a); + let other_a_enc = self.exchange_a_values(&mut a).await?; + + // Generate shares of the product and exchange + let c_shares = self.share_product(other_a_enc, &b, c).await?; + + Ok(()) + } + + /// Exchange encryptions of the `a` value + /// + /// Returns the counterparty's encryption of `a` + async fn exchange_a_values( + &mut self, + a: &mut PlaintextVector, + ) -> Result, LowGearError> { + // Encrypt `a` and send it to the counterparty + let my_proof = self.local_keypair.encrypt_and_prove_vector(a); self.send_message(my_proof).await?; let mut other_proof: CiphertextPoK = self.receive_message().await?; let other_pk = self.other_pk.as_ref().expect("setup not run"); let other_a_enc = other_pk.verify_proof(&mut other_proof); + Ok(other_a_enc) + } + + /// Create shares of the product `c` by exchanging homomorphically evaluated + /// encryptions of `my_b * other_a` + async fn share_product( + &mut self, + other_enc_a: CiphertextVector, + my_b_share: &PlaintextVector, + my_c_share: PlaintextVector, + ) -> Result, LowGearError> { + let mut c_res = my_c_share; + + // Compute the cross products then share them with the counterparty and compute + // local shares of `c` + let cross_products = + self.compute_triplet_cross_products(&other_enc_a, my_b_share, &mut c_res); + self.exchange_cross_products(cross_products, &mut c_res).await?; + + Ok(c_res) + } + + /// Compute the cross products in the triplet generation + fn compute_triplet_cross_products( + &mut self, + other_a: &CiphertextVector, + my_b: &PlaintextVector, + my_c: &mut PlaintextVector, + ) -> CiphertextVector { + let n = other_a.len(); + let mut cross_products = CiphertextVector::new(n, &self.params); + + // Compute the cross products of the local party's `b` share and the encryption + // of the counterparty's `a` share + for i in 0..n { + let a_enc = other_a.get(i); + let b = my_b.get(i); + let c = my_c.get(i); + + // Compute the product of `my_b` and `other_enc_a` + let mut product = &a_enc * &b; + + // Rerandomize the product to add drowning noise and mask it with a random value + product.rerandomize(self.other_pk.as_ref().unwrap()); + let mut mask = Plaintext::new(&self.params); + mask.randomize(); + + let masked_product = &product + &mask; + + // Subtract the masked product from our share + let my_share = &c - &mask; + my_c.set(i, &my_share); + cross_products.set(i, &masked_product); + } + + cross_products + } + + /// Exchange cross products and compute final shares of `c` + async fn exchange_cross_products( + &mut self, + cross_products: CiphertextVector, + my_c_share: &mut PlaintextVector, + ) -> Result<(), LowGearError> { + let n = cross_products.len(); + + // Send and receive cross products to/from the counterparty + self.send_message(cross_products).await?; + let other_cross_products: CiphertextVector = self.receive_message().await?; + + // Add each cross product to the local party's share of `c` + for i in 0..n { + let cross_product = other_cross_products.get(i); + let c = my_c_share.get(i); + + // Decrypt the term + let cross_product = self.local_keypair.decrypt(&cross_product); + + // Add the cross product to the local party's share of `c` + let my_share = &c + &cross_product; + my_c_share.set(i, &my_share); + } + Ok(()) } }