From c59c6e8a44503b47548276447663579e58864d57 Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Tue, 9 Apr 2024 15:58:49 -0700 Subject: [PATCH] offline-phase: lowgear: triplets: Implement initial triplet gen phase This involves each party generating local shares of the triplet, encrypting their `a` value, generating a ciphertext PoK, and sending it to the counterparty. The counterparty then verifies the PoK. --- mp-spdz-rs/src/ffi.rs | 4 ++ mp-spdz-rs/src/fhe/params.rs | 5 ++ mp-spdz-rs/src/fhe/plaintext.rs | 84 ++++++++++++++++++++++++++- offline-phase/src/lib.rs | 59 ++++++++++++++++++- offline-phase/src/lowgear/mod.rs | 1 + offline-phase/src/lowgear/triplets.rs | 50 ++++++++++++++++ 6 files changed, 199 insertions(+), 4 deletions(-) create mode 100644 offline-phase/src/lowgear/triplets.rs diff --git a/mp-spdz-rs/src/ffi.rs b/mp-spdz-rs/src/ffi.rs index b8a3c3a..8dcf203 100644 --- a/mp-spdz-rs/src/ffi.rs +++ b/mp-spdz-rs/src/ffi.rs @@ -87,10 +87,12 @@ mod ffi_inner { // `PlaintextVector` type PlaintextVector; + fn new_empty_plaintext_vector() -> UniquePtr; fn new_plaintext_vector(size: usize, params: &FHE_Params) -> UniquePtr; fn new_plaintext_vector_single( plaintext: &Plaintext_mod_prime, ) -> UniquePtr; + fn random_plaintext_vector(size: usize, params: &FHE_Params) -> UniquePtr; fn get_plaintext_vector_element( vector: &PlaintextVector, index: usize, @@ -150,8 +152,10 @@ unsafe impl Send for FHE_Params {} unsafe impl Send for FHE_KeyPair {} unsafe impl Send for FHE_PK {} unsafe impl Send for Ciphertext {} +unsafe impl Send for CiphertextVector {} unsafe impl Send for CiphertextWithProof {} unsafe impl Send for Plaintext_mod_prime {} +unsafe impl Send for PlaintextVector {} #[cfg(test)] mod test { diff --git a/mp-spdz-rs/src/fhe/params.rs b/mp-spdz-rs/src/fhe/params.rs index ab7e1ae..dcb7d89 100644 --- a/mp-spdz-rs/src/fhe/params.rs +++ b/mp-spdz-rs/src/fhe/params.rs @@ -60,6 +60,11 @@ impl BGVParams { pub fn plaintext_slots(&self) -> u32 { self.as_ref().n_plaintext_slots() } + + /// Get the number of ciphertexts that may be proven together + pub fn ciphertext_pok_batch_size(&self) -> usize { + (self.plaintext_slots() as usize) * (DEFAULT_DROWN_SEC as usize) + } } impl Serialize for BGVParams { diff --git a/mp-spdz-rs/src/fhe/plaintext.rs b/mp-spdz-rs/src/fhe/plaintext.rs index 5218d21..2617c75 100644 --- a/mp-spdz-rs/src/fhe/plaintext.rs +++ b/mp-spdz-rs/src/fhe/plaintext.rs @@ -12,7 +12,11 @@ use cxx::UniquePtr; use crate::{ffi, FromBytesWithParams, ToBytes}; -use super::{ffi_bigint_to_scalar, params::BGVParams, scalar_to_ffi_bigint}; +use super::{ + ffi_bigint_to_scalar, + params::{BGVParams, DEFAULT_DROWN_SEC}, + scalar_to_ffi_bigint, +}; /// A plaintext in the BGV implementation /// @@ -142,20 +146,51 @@ impl From> for PlaintextVector impl PlaintextVector { /// Create a new `PlaintextVector` with a specified size pub fn new(size: usize, params: &BGVParams) -> Self { - let inner = crate::ffi::new_plaintext_vector(size, params.as_ref()); + let inner = ffi::new_plaintext_vector(size, params.as_ref()); + Self { inner, _phantom: PhantomData } + } + + /// Create a new empty `PlaintextVector` + pub fn empty() -> Self { + Self { inner: ffi::new_empty_plaintext_vector(), _phantom: PhantomData } + } + + /// Generate a random `PlaintextVector` with a specified size + pub fn random(size: usize, params: &BGVParams) -> Self { + let inner = ffi::random_plaintext_vector(size, params.as_ref()); Self { inner, _phantom: PhantomData } } + /// Get the total number of slots in the `PlaintextVector` + pub fn total_slots(&self) -> usize { + if self.is_empty() { + 0 + } else { + self.get(0).num_slots() as usize * self.len() + } + } + + /// Generate a random `PlaintextVector` of size equal to the batching width + /// of the plaintext PoK proof system + pub fn random_pok_batch(params: &BGVParams) -> Self { + Self::random(DEFAULT_DROWN_SEC as usize, params) + } + /// Get a pinned mutable reference to the inner `PlaintextVector` pub fn pin_mut(&mut self) -> Pin<&mut ffi::PlaintextVector> { self.inner.pin_mut() } /// Get the size of the `PlaintextVector` - pub fn size(&self) -> usize { + pub fn len(&self) -> usize { ffi::plaintext_vector_size(self.inner.as_ref().unwrap()) } + /// Whether the vector is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Add a `Plaintext` to the end of the `PlaintextVector` pub fn push(&mut self, plaintext: &Plaintext) { ffi::push_plaintext_vector(self.inner.pin_mut(), plaintext.as_ref()); @@ -178,6 +213,49 @@ impl PlaintextVector { } } +// ------------------------------- +// | Plaintext Vector Arithmetic | +// ------------------------------- + +impl Add for &PlaintextVector { + type Output = PlaintextVector; + + fn add(self, other: Self) -> Self::Output { + let mut result = PlaintextVector::empty(); + for i in 0..self.len() { + let element = &self.get(i) + &other.get(i); + result.push(&element); + } + result + } +} + +impl Sub for &PlaintextVector { + type Output = PlaintextVector; + + fn sub(self, other: Self) -> Self::Output { + let mut result = PlaintextVector::empty(); + for i in 0..self.len() { + let element = &self.get(i) - &other.get(i); + result.push(&element); + } + result + } +} + +impl Mul for &PlaintextVector { + type Output = PlaintextVector; + + fn mul(self, other: Self) -> Self::Output { + let mut result = PlaintextVector::empty(); + for i in 0..self.len() { + let element = &self.get(i) * &other.get(i); + result.push(&element); + } + result + } +} + #[cfg(test)] mod tests { use rand::thread_rng; diff --git a/offline-phase/src/lib.rs b/offline-phase/src/lib.rs index c8b7be4..a27fd8d 100644 --- a/offline-phase/src/lib.rs +++ b/offline-phase/src/lib.rs @@ -28,8 +28,12 @@ pub(crate) mod test_helpers { }; use futures::Future; use mp_spdz_rs::fhe::{ - ciphertext::Ciphertext, keys::BGVPublicKey, params::BGVParams, plaintext::Plaintext, + ciphertext::Ciphertext, + keys::{BGVKeypair, BGVPublicKey}, + params::BGVParams, + plaintext::Plaintext, }; + use rand::thread_rng; use crate::lowgear::LowGear; @@ -70,6 +74,59 @@ pub(crate) mod test_helpers { let lowgear1 = LowGear::new(net1); let lowgear2 = LowGear::new(net2); + run_mock_lowgear(f, lowgear1, lowgear2).await + } + + /// Run a two-party method with a `LowGear` instance, having run keygen and + /// setup + pub async fn mock_lowgear_with_keys(mut f: F) -> (T, T) + where + T: Send + 'static, + S: Future + Send + 'static, + F: FnMut(LowGear>) -> S, + { + let mut rng = thread_rng(); + let (stream1, stream2) = UnboundedDuplexStream::new_duplex_pair(); + let net1 = MockNetwork::new(PARTY0, stream1); + let net2 = MockNetwork::new(PARTY1, stream2); + + let mut lowgear1 = LowGear::new(net1); + let mut lowgear2 = LowGear::new(net2); + + // Setup the lowgear instances + let params = BGVParams::new_no_mults(); + let keypair1 = BGVKeypair::gen(¶ms); + let keypair2 = BGVKeypair::gen(¶ms); + + let mac_share1 = Scalar::random(&mut rng); + let mac_share2 = Scalar::random(&mut rng); + + // Set the local keypairs and mac shares + lowgear1.local_keypair = keypair1.clone(); + lowgear1.mac_share = mac_share1.clone(); + lowgear2.local_keypair = keypair2.clone(); + lowgear2.mac_share = mac_share2.clone(); + + // Set the exchanged values + lowgear1.other_pk = Some(keypair2.public_key()); + lowgear1.other_mac_enc = Some(encrypt_val(mac_share2, &keypair2.public_key(), ¶ms)); + lowgear2.other_pk = Some(keypair1.public_key()); + lowgear2.other_mac_enc = Some(encrypt_val(mac_share1, &keypair1.public_key(), ¶ms)); + + run_mock_lowgear(f, lowgear1, lowgear2).await + } + + /// Run a two-party protocol using the given `LowGear` instances + pub async fn run_mock_lowgear( + mut f: F, + lowgear1: LowGear>, + lowgear2: LowGear>, + ) -> (T, T) + where + T: Send + 'static, + S: Future + Send + 'static, + F: FnMut(LowGear>) -> S, + { let task1 = tokio::spawn(f(lowgear1)); let task2 = tokio::spawn(f(lowgear2)); let party0_out = task1.await.unwrap(); diff --git a/offline-phase/src/lowgear/mod.rs b/offline-phase/src/lowgear/mod.rs index a49fda8..829a2af 100644 --- a/offline-phase/src/lowgear/mod.rs +++ b/offline-phase/src/lowgear/mod.rs @@ -2,6 +2,7 @@ //! keys, authenticating inputs, etc pub mod setup; +pub mod triplets; use ark_ec::CurveGroup; use ark_mpc::{ diff --git a/offline-phase/src/lowgear/triplets.rs b/offline-phase/src/lowgear/triplets.rs new file mode 100644 index 0000000..957a5d9 --- /dev/null +++ b/offline-phase/src/lowgear/triplets.rs @@ -0,0 +1,50 @@ +//! Defines the logic for generating shared triples (a, b, c) which satisfy the +//! identity: +//! a * b = c +//! +//! These triples are used to define single-round multiplication in the SPDZ +//! protocol + +use ark_ec::CurveGroup; +use ark_mpc::network::MpcNetwork; +use mp_spdz_rs::fhe::{ciphertext::CiphertextPoK, plaintext::PlaintextVector}; + +use crate::error::LowGearError; + +use super::LowGear; + +impl + Unpin> LowGear { + /// Generate a single batch of shared triples + pub async fn generate_triples(&mut self) -> Result<(), LowGearError> { + // First step; generate random values a and b + let mut a = PlaintextVector::random_pok_batch(&self.params); + let b = PlaintextVector::random_pok_batch(&self.params); + + // Compute a plaintext multiplication + let c = &a * &b; + + // Encrypt `a` and send it to the counterparty + let my_proof = self.local_keypair.encrypt_and_prove_vector(&mut a); + self.send_message(my_proof).await?; + let mut other_proof: CiphertextPoK = self.receive_message().await?; + + let other_pk = self.other_pk.as_ref().expect("setup not run"); + let other_a_enc = other_pk.verify_proof(&mut other_proof); + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use crate::test_helpers::mock_lowgear_with_keys; + + /// Tests the basic triplet generation flow + #[tokio::test] + async fn test_triplet_gen() { + mock_lowgear_with_keys(|mut lowgear| async move { + lowgear.generate_triples().await.unwrap(); + }) + .await; + } +}