diff --git a/mp-spdz-rs/src/fhe/plaintext.rs b/mp-spdz-rs/src/fhe/plaintext.rs index dd65da6..14512b8 100644 --- a/mp-spdz-rs/src/fhe/plaintext.rs +++ b/mp-spdz-rs/src/fhe/plaintext.rs @@ -76,6 +76,18 @@ impl Plaintext { self.inner.num_slots() as usize } + /// Create a plaintext given a batch of scalars + pub fn from_scalars(scalars: &[Scalar], params: &BGVParams) -> Self { + assert!(scalars.len() < params.plaintext_slots(), "not enough plaintext slots"); + + let mut pt = Self::new(params); + for (i, scalar) in scalars.iter().enumerate() { + pt.set_element(i, *scalar); + } + + pt + } + /// Get a vector of scalars from the plaintext slots pub fn to_scalars(&self) -> Vec> { let mut scalars = Vec::with_capacity(self.num_slots()); diff --git a/offline-phase/src/lowgear/input_masks.rs b/offline-phase/src/lowgear/input_masks.rs new file mode 100644 index 0000000..985304e --- /dev/null +++ b/offline-phase/src/lowgear/input_masks.rs @@ -0,0 +1,92 @@ +//! Generates input masks for a party + +use ark_ec::CurveGroup; +use ark_mpc::{algebra::Scalar, network::MpcNetwork}; +use itertools::Itertools; +use mp_spdz_rs::fhe::{ciphertext::Ciphertext, plaintext::Plaintext}; +use rand::rngs::OsRng; + +use crate::{error::LowGearError, structs::ValueMacBatch}; + +use super::LowGear; + +impl + Unpin + Send> LowGear { + /// Generate input masks for the given party + pub async fn generate_input_masks(&mut self, n: usize) -> Result<(), LowGearError> { + assert!( + n <= self.params.plaintext_slots(), + "can only generate input masks for {} slots", + self.params.plaintext_slots() + ); + + // Each party generates their values, shares, and mac shares + let mut rng = OsRng; + let my_values = (0..n).map(|_| Scalar::::random(&mut rng)).collect_vec(); + let my_share = (0..n).map(|_| Scalar::::random(&mut rng)).collect_vec(); + + let mut mac_mask = Plaintext::new(&self.params); + mac_mask.randomize(); + let my_key = self.mac_share; + let my_mac_shares = + my_values.iter().zip(mac_mask.to_scalars()).map(|(x, y)| my_key * x - y).collect_vec(); + + let my_values_shares = ValueMacBatch::from_parts(&my_share, &my_mac_shares); + self.input_masks.add_local_masks(my_values.clone(), my_values_shares.into_inner()); + + // Compute the counterparty's shares and mac shares of my values + let their_share = my_values.iter().zip(my_share.iter()).map(|(x, y)| x - y).collect_vec(); + let other_key_enc = self.other_mac_enc.as_ref().unwrap(); + let values_plaintext = Plaintext::from_scalars(&my_values, &self.params); + let mut mac_product = other_key_enc * &values_plaintext; + mac_product.rerandomize(self.other_pk.as_ref().unwrap()); + + let their_mac = &mac_product + &mac_mask; + + // Exchange shares and macs + self.send_network_payload(their_share).await?; + let my_shares: Vec> = self.receive_network_payload().await?; + + self.send_message(&their_mac).await?; + let my_counterparty_macs: Ciphertext = self.receive_message().await?; + let mut my_macs = self.local_keypair.decrypt(&my_counterparty_macs).to_scalars(); + my_macs.truncate(n); + + let my_counterparty_shares = ValueMacBatch::from_parts(&my_shares, &my_macs); + self.input_masks.add_counterparty_masks(my_counterparty_shares); + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use ark_mpc::PARTY0; + + use crate::test_helpers::mock_lowgear_with_keys; + + /// Tests generating input masks + #[tokio::test] + async fn test_generate_input_masks() { + const N: usize = 100; + let (party0_res, _) = mock_lowgear_with_keys(|mut lowgear| async move { + lowgear.generate_input_masks(N).await.unwrap(); + + // Open the first party's input masks, verify that they're the same as party 0's + // cleartext values + if lowgear.party_id() == PARTY0 { + let (cleartext, shares) = lowgear.input_masks.get_local_mask_batch(N); + let opened = lowgear.open_and_check_macs(&shares).await.unwrap(); + + cleartext == opened + } else { + let shares = lowgear.input_masks.get_counterparty_mask_batch(N); + lowgear.open_and_check_macs(&shares).await.unwrap(); + + true + } + }) + .await; + + assert!(party0_res); + } +} diff --git a/offline-phase/src/lowgear/mod.rs b/offline-phase/src/lowgear/mod.rs index ee6f9f4..5c25c78 100644 --- a/offline-phase/src/lowgear/mod.rs +++ b/offline-phase/src/lowgear/mod.rs @@ -2,6 +2,7 @@ //! keys, authenticating inputs, etc pub mod commit_reveal; +pub mod input_masks; pub mod inverse_tuples; pub mod mac_check; pub mod multiplication; @@ -29,7 +30,7 @@ use rand::thread_rng; use crate::{ error::LowGearError, - structs::{LowGearParams, LowGearPrep, ValueMacBatch}, + structs::{InputMasks, LowGearParams, LowGearPrep, ValueMacBatch}, }; /// A type implementing Lowgear protocol logic @@ -52,6 +53,12 @@ pub struct LowGear> { pub shared_bits: ValueMacBatch, /// The shared random values generated during the offline phase pub shared_randomness: ValueMacBatch, + /// The input masks generated during the offline phase + /// + /// An input mask is party specific, that is, each party has a set of input + /// values wherein it holds a random value in the cleartext + /// and the parties collectively hold a sharing of the value + pub input_masks: InputMasks, /// A reference to the underlying network connection pub network: N, } @@ -75,6 +82,7 @@ impl + Unpin> LowGear { inverse_tuples: Default::default(), shared_bits: Default::default(), shared_randomness: Default::default(), + input_masks: Default::default(), network, } } diff --git a/offline-phase/src/structs.rs b/offline-phase/src/structs.rs index f08a93a..5c967be 100644 --- a/offline-phase/src/structs.rs +++ b/offline-phase/src/structs.rs @@ -4,7 +4,7 @@ use std::ops::{Add, Mul, Sub}; use ark_ec::CurveGroup; use ark_mpc::algebra::{Scalar, ScalarShare}; -use ark_mpc::offline_prep::OfflinePhase; +use ark_mpc::offline_prep::PreprocessingPhase; use mp_spdz_rs::fhe::ciphertext::Ciphertext; use mp_spdz_rs::fhe::keys::{BGVKeypair, BGVPublicKey}; use mp_spdz_rs::fhe::params::BGVParams; @@ -93,7 +93,7 @@ impl LowGearPrep { } } -impl OfflinePhase for LowGearPrep { +impl PreprocessingPhase for LowGearPrep { fn next_shared_bit(&mut self) -> ScalarShare { self.bits.split_off(1).into_inner()[0] } @@ -172,6 +172,11 @@ impl ValueMacBatch { self.inner.is_empty() } + /// Pop the last value and mac from the batch + pub fn pop(&mut self) -> Option> { + self.inner.pop() + } + /// Append the given batch to this one pub fn append(&mut self, other: &mut Self) { self.inner.append(&mut other.inner); @@ -296,6 +301,71 @@ impl Mul<&[Scalar]> for &ValueMacBatch { } } +// --------------- +// | Input Masks | +// --------------- + +/// The input mask values held by the local party +/// +/// Each party holds a set of random cleartext values used to mask inputs to the +/// MPC. The other parties collectively hold a sharing of the values +/// +/// So, this struct holds the local party's cleartext values and the local +/// party's shares of their own and others' cleartext masks +#[derive(Clone, Default)] +pub struct InputMasks { + /// The local party's cleartext mask values + pub my_masks: Vec>, + /// The local party's shares of their own mask values + pub my_mask_shares: ValueMacBatch, + /// The shares of the cleartext values + /// + /// Index `i` is a set of shares for party i's masks + pub their_masks: ValueMacBatch, +} + +impl InputMasks { + /// Append values to `my_masks` + pub fn add_local_masks(&mut self, values: Vec>, masks: Vec>) { + assert_eq!(values.len(), masks.len()); + self.my_masks.extend(values); + self.my_mask_shares.append(&mut ValueMacBatch::new(masks)); + } + + /// Add values to `their_masks` + pub fn add_counterparty_masks(&mut self, mut masks: ValueMacBatch) { + self.their_masks.append(&mut masks); + } + + /// Get the local party's next mask and share of the mask + pub fn get_local_mask(&mut self) -> (Scalar, ScalarShare) { + assert!(!self.my_masks.is_empty(), "no local masks left"); + let mask = self.my_masks.pop().unwrap(); + let mask_share = self.my_mask_shares.pop().unwrap(); + + (mask, mask_share) + } + + /// Get a batch of local masks and shares of the masks + pub fn get_local_mask_batch(&mut self, num_masks: usize) -> (Vec>, ValueMacBatch) { + let split_idx = self.my_masks.len() - num_masks; + let masks = self.my_masks.split_off(split_idx); + let mask_shares = self.my_mask_shares.split_off(num_masks); + + (masks, mask_shares) + } + + /// Get the local party's share of the counterparty's next mask + pub fn get_counterparty_mask(&mut self) -> ScalarShare { + self.their_masks.split_off(1).into_inner()[0] + } + + /// Get a batch of the local party's shares of the counterparty's masks + pub fn get_counterparty_mask_batch(&mut self, num_masks: usize) -> ValueMacBatch { + self.their_masks.split_off(num_masks) + } +} + #[cfg(test)] mod test { use ark_mpc::{ @@ -305,7 +375,8 @@ mod test { use crate::test_helpers::mock_lowgear_with_triples; - /// Tests the use of the `LowGear` type as an `OfflinePhase` implementation + /// Tests the use of the `LowGear` type as an `PreprocessingPhase` + /// implementation #[tokio::test] async fn test_lowgear_offline_phase() { // Setup the mock offline phase