From a6e9482c04d29593e6d251981fd593964bb849b4 Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Thu, 11 Apr 2024 15:01:22 -0700 Subject: [PATCH] offline-phase: lowgear: triplets: Implement triplet authentication --- mp-spdz-rs/src/ffi.rs | 4 + mp-spdz-rs/src/fhe/plaintext.rs | 3 + mp-spdz-rs/src/lib.rs | 2 +- offline-phase/src/lib.rs | 22 ++- offline-phase/src/lowgear/mod.rs | 44 +++++- offline-phase/src/lowgear/setup.rs | 8 +- offline-phase/src/lowgear/triplets.rs | 202 ++++++++++++++++++++++++-- 7 files changed, 259 insertions(+), 26 deletions(-) diff --git a/mp-spdz-rs/src/ffi.rs b/mp-spdz-rs/src/ffi.rs index 952883b..eac5cca 100644 --- a/mp-spdz-rs/src/ffi.rs +++ b/mp-spdz-rs/src/ffi.rs @@ -169,9 +169,13 @@ pub use ffi_inner::*; unsafe impl Send for FHE_Params {} unsafe impl Send for FHE_KeyPair {} unsafe impl Send for FHE_PK {} +unsafe impl Sync for FHE_PK {} unsafe impl Send for Ciphertext {} +unsafe impl Sync for Ciphertext {} unsafe impl Send for CiphertextVector {} +unsafe impl Sync for CiphertextVector {} unsafe impl Send for CiphertextWithProof {} +unsafe impl Sync for CiphertextWithProof {} unsafe impl Send for Plaintext_mod_prime {} unsafe impl Send for PlaintextVector {} unsafe impl Sync for PlaintextVector {} diff --git a/mp-spdz-rs/src/fhe/plaintext.rs b/mp-spdz-rs/src/fhe/plaintext.rs index 2c70460..08f9be7 100644 --- a/mp-spdz-rs/src/fhe/plaintext.rs +++ b/mp-spdz-rs/src/fhe/plaintext.rs @@ -231,6 +231,7 @@ impl Add for &PlaintextVector { type Output = PlaintextVector; fn add(self, other: Self) -> Self::Output { + assert_eq!(self.len(), other.len(), "Vectors must be the same length"); let mut result = PlaintextVector::empty(); for i in 0..self.len() { let element = &self.get(i) + &other.get(i); @@ -244,6 +245,7 @@ impl Sub for &PlaintextVector { type Output = PlaintextVector; fn sub(self, other: Self) -> Self::Output { + assert_eq!(self.len(), other.len(), "Vectors must be the same length"); let mut result = PlaintextVector::empty(); for i in 0..self.len() { let element = &self.get(i) - &other.get(i); @@ -257,6 +259,7 @@ impl Mul for &PlaintextVector { type Output = PlaintextVector; fn mul(self, other: Self) -> Self::Output { + assert_eq!(self.len(), other.len(), "Vectors must be the same length"); let mut result = PlaintextVector::empty(); for i in 0..self.len() { let element = &self.get(i) * &other.get(i); diff --git a/mp-spdz-rs/src/lib.rs b/mp-spdz-rs/src/lib.rs index eedbfe1..1f4d7b3 100644 --- a/mp-spdz-rs/src/lib.rs +++ b/mp-spdz-rs/src/lib.rs @@ -10,7 +10,7 @@ #![feature(inherent_associated_types)] #![feature(stmt_expr_attributes)] -pub mod ffi; +mod ffi; pub mod fhe; /// A trait for serializing to bytes diff --git a/offline-phase/src/lib.rs b/offline-phase/src/lib.rs index b7ea548..69fd9ff 100644 --- a/offline-phase/src/lib.rs +++ b/offline-phase/src/lib.rs @@ -48,6 +48,14 @@ pub(crate) mod test_helpers { pt } + /// Get a plaintext with a single value in all slots + pub fn plaintext_all(val: Scalar, params: &BGVParams) -> Plaintext { + let mut pt = Plaintext::new(params); + pt.set_all(val); + + pt + } + /// Encrypt a single value using the BGV cryptosystem /// /// Places the element in the zeroth slot of the plaintext @@ -60,6 +68,16 @@ pub(crate) mod test_helpers { key.encrypt(&pt) } + /// Encrypt a single value in all slots of a plaintext + pub fn encrypt_all( + val: Scalar, + key: &BGVPublicKey, + params: &BGVParams, + ) -> Ciphertext { + let pt = plaintext_all(val, params); + key.encrypt(&pt) + } + /// Run a two-party method with a `LowGear` instance setup and in scope pub async fn mock_lowgear(f: F) -> (T, T) where @@ -108,9 +126,9 @@ pub(crate) mod test_helpers { // Set the exchanged values lowgear1.other_pk = Some(keypair2.public_key()); - lowgear1.other_mac_enc = Some(encrypt_val(mac_share2, &keypair2.public_key(), ¶ms)); + lowgear1.other_mac_enc = Some(encrypt_all(mac_share2, &keypair2.public_key(), ¶ms)); lowgear2.other_pk = Some(keypair1.public_key()); - lowgear2.other_mac_enc = Some(encrypt_val(mac_share1, &keypair1.public_key(), ¶ms)); + lowgear2.other_mac_enc = Some(encrypt_all(mac_share1, &keypair1.public_key(), ¶ms)); run_mock_lowgear(f, lowgear1, lowgear2).await } diff --git a/offline-phase/src/lowgear/mod.rs b/offline-phase/src/lowgear/mod.rs index ac5c9ea..bfdc84c 100644 --- a/offline-phase/src/lowgear/mod.rs +++ b/offline-phase/src/lowgear/mod.rs @@ -13,9 +13,10 @@ use ark_mpc::{ use futures::{SinkExt, StreamExt}; use mp_spdz_rs::{ fhe::{ - ciphertext::Ciphertext, + ciphertext::{Ciphertext, CiphertextVector}, keys::{BGVKeypair, BGVPublicKey}, params::BGVParams, + plaintext::{Plaintext, PlaintextVector}, }, FromBytesWithParams, ToBytes, }; @@ -37,6 +38,8 @@ pub struct LowGear> { pub other_mac_enc: Option>, /// The Beaver triples generated during the offline phase pub triples: Vec<(Scalar, Scalar, Scalar)>, + /// The mac values for the triples generated during the offline phase + pub triple_macs: Vec<(Scalar, Scalar, Scalar)>, /// A reference to the underlying network connection pub network: N, } @@ -57,6 +60,7 @@ impl + Unpin> LowGear { other_pk: None, other_mac_enc: None, triples: vec![], + triple_macs: vec![], network, } } @@ -72,8 +76,40 @@ impl + Unpin> LowGear { }) } + /// Get a plaintext with the local mac share in all slots + pub fn get_mac_plaintext(&self) -> Plaintext { + let mut pt = Plaintext::new(&self.params); + pt.set_all(self.mac_share); + + pt + } + + /// Get a plaintext vector wherein each plaintext created with the local mac + /// share in all slots + pub fn get_mac_plaintext_vector(&self, n: usize) -> PlaintextVector { + let mut vec = PlaintextVector::new(n, &self.params); + let pt = self.get_mac_plaintext(); + for i in 0..n { + vec.set(i, &pt); + } + + vec + } + + /// Get a ciphertext vector wherein each ciphertext is an encryption of the + /// counterparty's mac share + pub fn get_other_mac_enc(&self, n: usize) -> CiphertextVector { + let mut vec = CiphertextVector::new(n, &self.params); + let ct = self.other_mac_enc.as_ref().unwrap(); + for i in 0..n { + vec.set(i, ct); + } + + vec + } + /// Send a message to the counterparty - pub async fn send_message(&mut self, message: T) -> Result<(), LowGearError> { + pub async fn send_message(&mut self, message: &T) -> Result<(), LowGearError> { let payload = NetworkPayload::::Bytes(message.to_bytes()); let msg = NetworkOutbound { result_id: 0, payload }; @@ -145,12 +181,12 @@ mod test { let party = lowgear.network.party_id(); if party == PARTY0 { let msg = TestMessage(MSG1.to_string()); - lowgear.send_message(msg).await.unwrap(); + lowgear.send_message(&msg).await.unwrap(); lowgear.receive_message::().await.unwrap() } else { let msg = TestMessage(MSG2.to_string()); let recv = lowgear.receive_message::().await.unwrap(); - lowgear.send_message(msg).await.unwrap(); + lowgear.send_message(&msg).await.unwrap(); recv } diff --git a/offline-phase/src/lowgear/setup.rs b/offline-phase/src/lowgear/setup.rs index f727c83..62960f4 100644 --- a/offline-phase/src/lowgear/setup.rs +++ b/offline-phase/src/lowgear/setup.rs @@ -12,7 +12,7 @@ impl + Unpin> LowGear { /// Exchange BGV public keys and mac shares with the counterparty pub async fn run_key_exchange(&mut self) -> Result<(), LowGearError> { // First, share the public key - self.send_message(self.local_keypair.public_key()).await?; + self.send_message(&self.local_keypair.public_key()).await?; let counterparty_pk: BGVPublicKey = self.receive_message().await?; // Encrypt my mac share under my public key @@ -21,7 +21,7 @@ impl + Unpin> LowGear { let ct = self.local_keypair.encrypt_and_prove(&pt); // Send and receive - self.send_message(ct).await?; + self.send_message(&ct).await?; let mut counterparty_mac_pok: CiphertextPoK = self.receive_message().await?; let counterparty_mac_enc = counterparty_pk.verify_proof(&mut counterparty_mac_pok); @@ -60,7 +60,7 @@ mod test { let encrypted_val = encrypt_val(my_val, lowgear.other_pk.as_ref().unwrap(), &lowgear.params); - lowgear.send_message(encrypted_val).await.unwrap(); + lowgear.send_message(&encrypted_val).await.unwrap(); let received_val: Ciphertext = lowgear.receive_message().await.unwrap(); let decrypted_val = lowgear.local_keypair.decrypt(&received_val); @@ -72,7 +72,7 @@ mod test { let ct = lowgear.other_mac_enc.as_ref().unwrap() * &pt; // Send the result to the other party - lowgear.send_message(ct).await.unwrap(); + lowgear.send_message(&ct).await.unwrap(); let received_val: Ciphertext = lowgear.receive_message().await.unwrap(); let decrypted_val = lowgear.local_keypair.decrypt(&received_val); diff --git a/offline-phase/src/lowgear/triplets.rs b/offline-phase/src/lowgear/triplets.rs index 2afa193..1685e7f 100644 --- a/offline-phase/src/lowgear/triplets.rs +++ b/offline-phase/src/lowgear/triplets.rs @@ -32,7 +32,10 @@ impl + Unpin> LowGear { let other_a_enc = self.exchange_a_values(&mut a).await?; // Generate shares of the product and exchange - let c_shares = self.share_product(other_a_enc, &b, c).await?; + let c_shares = self.share_product(&other_a_enc, &b, c).await?; + + // Authenticate the triplets + let (a_mac, b_mac, c_mac) = self.authenticate_triplets(&a, &b, &c_shares).await?; // Increase the size of self.triples by self.params.ciphertext_pok_batch_size self.triples.reserve(self.params.ciphertext_pok_batch_size()); @@ -40,13 +43,20 @@ impl + Unpin> LowGear { let plaintext_a = a.get(pt_idx); let plaintext_b = b.get(pt_idx); let plaintext_c = c_shares.get(pt_idx); + let mac_pt_a = a_mac.get(pt_idx); + let mac_pt_b = b_mac.get(pt_idx); + let mac_pt_c = c_mac.get(pt_idx); for slot_idx in 0..plaintext_a.num_slots() as usize { let a = plaintext_a.get_element(slot_idx); let b = plaintext_b.get_element(slot_idx); let c = plaintext_c.get_element(slot_idx); + let mac_a = mac_pt_a.get_element(slot_idx); + let mac_b = mac_pt_b.get_element(slot_idx); + let mac_c = mac_pt_c.get_element(slot_idx); self.triples.push((a, b, c)); + self.triple_macs.push((mac_a, mac_b, mac_c)); } } @@ -62,7 +72,7 @@ impl + Unpin> LowGear { ) -> Result, LowGearError> { // Encrypt `a` and send it to the counterparty let my_proof = self.local_keypair.encrypt_and_prove_vector(a); - self.send_message(my_proof).await?; + self.send_message(&my_proof).await?; let mut other_proof: CiphertextPoK = self.receive_message().await?; let other_pk = self.other_pk.as_ref().expect("setup not run"); @@ -71,11 +81,37 @@ impl + Unpin> LowGear { Ok(other_a_enc) } + /// Authenticate triplets with the counterparty + /// + /// Returns the mac shares for each triplet + pub(crate) async fn authenticate_triplets( + &mut self, + a: &PlaintextVector, + b: &PlaintextVector, + c: &PlaintextVector, + ) -> Result<(PlaintextVector, PlaintextVector, PlaintextVector), LowGearError> { + let n = a.len(); + + // Multiply my share of `\alpha` with each share + let mac_vec = self.get_mac_plaintext_vector(n); + let a_macs = &mac_vec * a; + let b_macs = &mac_vec * b; + let c_macs = &mac_vec * c; + + // Compute cross terms, share them, then sum into the result + let other_mac_enc = self.get_other_mac_enc(n); + let a_macs = self.share_product(&other_mac_enc, a, a_macs).await?; + let b_macs = self.share_product(&other_mac_enc, b, b_macs).await?; + let c_macs = self.share_product(&other_mac_enc, c, c_macs).await?; + + Ok((a_macs, b_macs, c_macs)) + } + /// Create shares of the product `c` by exchanging homomorphically evaluated /// encryptions of `my_b * other_a` async fn share_product( &mut self, - other_enc_a: CiphertextVector, + other_enc_a: &CiphertextVector, my_b_share: &PlaintextVector, my_c_share: PlaintextVector, ) -> Result, LowGearError> { @@ -83,15 +119,14 @@ impl + Unpin> LowGear { // Compute the cross products then share them with the counterparty and compute // local shares of `c` - let cross_products = - self.compute_triplet_cross_products(&other_enc_a, my_b_share, &mut c_res); + let cross_products = self.compute_cross_products(other_enc_a, my_b_share, &mut c_res); self.exchange_cross_products(cross_products, &mut c_res).await?; Ok(c_res) } /// Compute the cross products in the triplet generation - fn compute_triplet_cross_products( + fn compute_cross_products( &mut self, other_a: &CiphertextVector, my_b: &PlaintextVector, @@ -135,7 +170,7 @@ impl + Unpin> LowGear { let n = cross_products.len(); // Send and receive cross products to/from the counterparty - self.send_message(cross_products).await?; + self.send_message(&cross_products).await?; let other_cross_products: CiphertextVector = self.receive_message().await?; // Add each cross product to the local party's share of `c` @@ -157,10 +192,107 @@ impl + Unpin> LowGear { #[cfg(test)] mod test { - use ark_mpc::algebra::Scalar; + use ark_mpc::{ + algebra::Scalar, + network::{MpcNetwork, NetworkPayload}, + }; use itertools::izip; + use mp_spdz_rs::fhe::{ + params::BGVParams, + plaintext::{Plaintext, PlaintextVector}, + }; + use rand::rngs::OsRng; + + use crate::{ + error::LowGearError, + lowgear::LowGear, + test_helpers::{mock_lowgear_with_keys, TestCurve}, + }; + + // ----------- + // | Helpers | + // ----------- + + /// Generate a vector of random scalar values + fn random_scalars(n: usize) -> Vec> { + (0..n).map(|_| Scalar::random(&mut OsRng)).collect() + } - use crate::test_helpers::{mock_lowgear_with_keys, TestCurve}; + /// Generate a plaintext vector with a single element from a vector of + /// scalars + fn scalars_to_plaintext_vec( + scalars: &[Scalar], + params: &BGVParams, + ) -> PlaintextVector { + let mut pt = Plaintext::new(params); + for (i, s) in scalars.iter().enumerate() { + pt.set_element(i, *s); + } + + PlaintextVector::from(&pt) + } + + /// Get a vector of scalars from a plaintext vector + fn plaintext_vec_to_scalars(pt_vec: &PlaintextVector) -> Vec> { + if pt_vec.is_empty() { + return vec![]; + } + + let n = pt_vec.len(); + let slots = pt_vec.get(0).num_slots() as usize; + let mut vec = Vec::with_capacity(n * slots); + + for i in 0..n { + let pt = pt_vec.get(i); + for j in 0..slots { + vec.push(pt.get_element(j)); + } + } + + vec + } + + /// Send and receive a payload between two `LowGear` instances + async fn send_receive_payload( + my_val: T, + lowgear: &mut LowGear, + ) -> Result + where + T: Into> + From> + Send + 'static, + N: MpcNetwork + Unpin + Send, + { + lowgear.send_network_payload(my_val).await?; + let their_val: T = lowgear.receive_network_payload().await?; + + Ok(their_val) + } + + /// Verify the macs on a set of values given the opened shares from both + /// parties + fn verify_macs( + my_share: &[Scalar], + their_share: &[Scalar], + my_mac: &[Scalar], + their_mac: &[Scalar], + mac_key: Scalar, + ) { + let n = my_share.len(); + assert_eq!(their_share.len(), n); + assert_eq!(my_mac.len(), n); + assert_eq!(their_mac.len(), n); + + for (a1, a2, mac1, mac2) in izip!(my_share, their_share, my_mac, their_mac) { + let val = a1 + a2; + let expected = mac_key * val; + let actual = mac1 + mac2; + + assert_eq!(expected, actual); + } + } + + // --------- + // | Tests | + // --------- /// Tests the basic triplet generation flow #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -179,12 +311,9 @@ mod test { my_c.push(*c); } - lowgear.send_network_payload(my_a.clone()).await.unwrap(); - lowgear.send_network_payload(my_b.clone()).await.unwrap(); - lowgear.send_network_payload(my_c.clone()).await.unwrap(); - let their_a: Vec> = lowgear.receive_network_payload().await.unwrap(); - let their_b: Vec> = lowgear.receive_network_payload().await.unwrap(); - let their_c: Vec> = lowgear.receive_network_payload().await.unwrap(); + let their_a = send_receive_payload(my_a.clone(), &mut lowgear).await.unwrap(); + let their_b = send_receive_payload(my_b.clone(), &mut lowgear).await.unwrap(); + let their_c = send_receive_payload(my_c.clone(), &mut lowgear).await.unwrap(); // Add together all the shares to get the final values for (a_1, a_2, b_1, b_2, c_1, c_2) in izip!( @@ -204,4 +333,47 @@ mod test { }) .await; } + + /// Tests authenticating the triples in a batch + #[tokio::test] + async fn test_triplet_auth() { + // The number of plaintext vectors to test + mock_lowgear_with_keys(|mut lowgear| async move { + // Generate values for the triplets + let n_slots = lowgear.params.plaintext_slots() as usize; + let my_a = random_scalars(n_slots); + let my_b = random_scalars(n_slots); + let my_c = random_scalars(n_slots); + + // Convert to plaintexts + let a = scalars_to_plaintext_vec(&my_a, &lowgear.params); + let b = scalars_to_plaintext_vec(&my_b, &lowgear.params); + let c = scalars_to_plaintext_vec(&my_c, &lowgear.params); + + // Authenticate the triplets + let (a_mac, b_mac, c_mac) = lowgear.authenticate_triplets(&a, &b, &c).await.unwrap(); + let a_mac = plaintext_vec_to_scalars(&a_mac); + let b_mac = plaintext_vec_to_scalars(&b_mac); + let c_mac = plaintext_vec_to_scalars(&c_mac); + + // Share the scalars, macs, and mac keys with the counterparty then verify + let their_a = send_receive_payload(my_a.clone(), &mut lowgear).await.unwrap(); + let their_b = send_receive_payload(my_b.clone(), &mut lowgear).await.unwrap(); + let their_c = send_receive_payload(my_c.clone(), &mut lowgear).await.unwrap(); + + let their_a_mac = &send_receive_payload(a_mac.clone(), &mut lowgear).await.unwrap(); + let their_b_mac = &send_receive_payload(b_mac.clone(), &mut lowgear).await.unwrap(); + let their_c_mac = &send_receive_payload(c_mac.clone(), &mut lowgear).await.unwrap(); + + let their_mac_key = + send_receive_payload(lowgear.mac_share, &mut lowgear).await.unwrap(); + let mac_key = lowgear.mac_share + their_mac_key; + + // Verify the macs + verify_macs(&my_a, &their_a, &a_mac, their_a_mac, mac_key); + verify_macs(&my_b, &their_b, &b_mac, their_b_mac, mac_key); + verify_macs(&my_c, &their_c, &c_mac, their_c_mac, mac_key); + }) + .await; + } }