Skip to content

Commit

Permalink
offline-phase: lowgear: input-masks: Generate input masks in offline
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Apr 17, 2024
1 parent 9c4622e commit 3641783
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 4 deletions.
12 changes: 12 additions & 0 deletions mp-spdz-rs/src/fhe/plaintext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ impl<C: CurveGroup> Plaintext<C> {
self.inner.num_slots() as usize
}

/// Create a plaintext given a batch of scalars
pub fn from_scalars(scalars: &[Scalar<C>], params: &BGVParams<C>) -> Self {
assert!(scalars.len() < params.plaintext_slots(), "not enough plaintext slots");

let mut pt = Self::new(params);
for (i, scalar) in scalars.iter().enumerate() {
pt.set_element(i, *scalar);
}

pt
}

/// Get a vector of scalars from the plaintext slots
pub fn to_scalars(&self) -> Vec<Scalar<C>> {
let mut scalars = Vec::with_capacity(self.num_slots());
Expand Down
92 changes: 92 additions & 0 deletions offline-phase/src/lowgear/input_masks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//! Generates input masks for a party
use ark_ec::CurveGroup;
use ark_mpc::{algebra::Scalar, network::MpcNetwork};
use itertools::Itertools;
use mp_spdz_rs::fhe::{ciphertext::Ciphertext, plaintext::Plaintext};
use rand::rngs::OsRng;

use crate::{error::LowGearError, structs::ValueMacBatch};

use super::LowGear;

impl<C: CurveGroup, N: MpcNetwork<C> + Unpin + Send> LowGear<C, N> {
/// Generate input masks for the given party
pub async fn generate_input_masks(&mut self, n: usize) -> Result<(), LowGearError> {
assert!(
n <= self.params.plaintext_slots(),
"can only generate input masks for {} slots",
self.params.plaintext_slots()
);

// Each party generates their values, shares, and mac shares
let mut rng = OsRng;
let my_values = (0..n).map(|_| Scalar::<C>::random(&mut rng)).collect_vec();
let my_share = (0..n).map(|_| Scalar::<C>::random(&mut rng)).collect_vec();

let mut mac_mask = Plaintext::new(&self.params);
mac_mask.randomize();
let my_key = self.mac_share;
let my_mac_shares =
my_values.iter().zip(mac_mask.to_scalars()).map(|(x, y)| my_key * x - y).collect_vec();

let my_values_shares = ValueMacBatch::from_parts(&my_share, &my_mac_shares);
self.input_masks.add_local_masks(my_values.clone(), my_values_shares.into_inner());

// Compute the counterparty's shares and mac shares of my values
let their_share = my_values.iter().zip(my_share.iter()).map(|(x, y)| x - y).collect_vec();
let other_key_enc = self.other_mac_enc.as_ref().unwrap();
let values_plaintext = Plaintext::from_scalars(&my_values, &self.params);
let mut mac_product = other_key_enc * &values_plaintext;
mac_product.rerandomize(self.other_pk.as_ref().unwrap());

let their_mac = &mac_product + &mac_mask;

// Exchange shares and macs
self.send_network_payload(their_share).await?;
let my_shares: Vec<Scalar<C>> = self.receive_network_payload().await?;

self.send_message(&their_mac).await?;
let my_counterparty_macs: Ciphertext<C> = self.receive_message().await?;
let mut my_macs = self.local_keypair.decrypt(&my_counterparty_macs).to_scalars();
my_macs.truncate(n);

let my_counterparty_shares = ValueMacBatch::from_parts(&my_shares, &my_macs);
self.input_masks.add_counterparty_masks(my_counterparty_shares);

Ok(())
}
}

#[cfg(test)]
mod test {
use ark_mpc::PARTY0;

use crate::test_helpers::mock_lowgear_with_keys;

/// Tests generating input masks
#[tokio::test]
async fn test_generate_input_masks() {
const N: usize = 100;
let (party0_res, _) = mock_lowgear_with_keys(|mut lowgear| async move {
lowgear.generate_input_masks(N).await.unwrap();

// Open the first party's input masks, verify that they're the same as party 0's
// cleartext values
if lowgear.party_id() == PARTY0 {
let (cleartext, shares) = lowgear.input_masks.get_local_mask_batch(N);
let opened = lowgear.open_and_check_macs(&shares).await.unwrap();

cleartext == opened
} else {
let shares = lowgear.input_masks.get_counterparty_mask_batch(N);
lowgear.open_and_check_macs(&shares).await.unwrap();

true
}
})
.await;

assert!(party0_res);
}
}
10 changes: 9 additions & 1 deletion offline-phase/src/lowgear/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//! keys, authenticating inputs, etc
pub mod commit_reveal;
pub mod input_masks;
pub mod inverse_tuples;
pub mod mac_check;
pub mod multiplication;
Expand Down Expand Up @@ -29,7 +30,7 @@ use rand::thread_rng;

use crate::{
error::LowGearError,
structs::{LowGearParams, LowGearPrep, ValueMacBatch},
structs::{InputMasks, LowGearParams, LowGearPrep, ValueMacBatch},
};

/// A type implementing Lowgear protocol logic
Expand All @@ -52,6 +53,12 @@ pub struct LowGear<C: CurveGroup, N: MpcNetwork<C>> {
pub shared_bits: ValueMacBatch<C>,
/// The shared random values generated during the offline phase
pub shared_randomness: ValueMacBatch<C>,
/// The input masks generated during the offline phase
///
/// An input mask is party specific, that is, each party has a set of input
/// values wherein it holds a random value in the cleartext
/// and the parties collectively hold a sharing of the value
pub input_masks: InputMasks<C>,
/// A reference to the underlying network connection
pub network: N,
}
Expand All @@ -75,6 +82,7 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {
inverse_tuples: Default::default(),
shared_bits: Default::default(),
shared_randomness: Default::default(),
input_masks: Default::default(),
network,
}
}
Expand Down
77 changes: 74 additions & 3 deletions offline-phase/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::ops::{Add, Mul, Sub};

use ark_ec::CurveGroup;
use ark_mpc::algebra::{Scalar, ScalarShare};
use ark_mpc::offline_prep::OfflinePhase;
use ark_mpc::offline_prep::PreprocessingPhase;
use mp_spdz_rs::fhe::ciphertext::Ciphertext;
use mp_spdz_rs::fhe::keys::{BGVKeypair, BGVPublicKey};
use mp_spdz_rs::fhe::params::BGVParams;
Expand Down Expand Up @@ -93,7 +93,7 @@ impl<C: CurveGroup> LowGearPrep<C> {
}
}

impl<C: CurveGroup> OfflinePhase<C> for LowGearPrep<C> {
impl<C: CurveGroup> PreprocessingPhase<C> for LowGearPrep<C> {
fn next_shared_bit(&mut self) -> ScalarShare<C> {
self.bits.split_off(1).into_inner()[0]
}
Expand Down Expand Up @@ -172,6 +172,11 @@ impl<C: CurveGroup> ValueMacBatch<C> {
self.inner.is_empty()
}

/// Pop the last value and mac from the batch
pub fn pop(&mut self) -> Option<ScalarShare<C>> {
self.inner.pop()
}

/// Append the given batch to this one
pub fn append(&mut self, other: &mut Self) {
self.inner.append(&mut other.inner);
Expand Down Expand Up @@ -296,6 +301,71 @@ impl<C: CurveGroup> Mul<&[Scalar<C>]> for &ValueMacBatch<C> {
}
}

// ---------------
// | Input Masks |
// ---------------

/// The input mask values held by the local party
///
/// Each party holds a set of random cleartext values used to mask inputs to the
/// MPC. The other parties collectively hold a sharing of the values
///
/// So, this struct holds the local party's cleartext values and the local
/// party's shares of their own and others' cleartext masks
#[derive(Clone, Default)]
pub struct InputMasks<C: CurveGroup> {
/// The local party's cleartext mask values
pub my_masks: Vec<Scalar<C>>,
/// The local party's shares of their own mask values
pub my_mask_shares: ValueMacBatch<C>,
/// The shares of the cleartext values
///
/// Index `i` is a set of shares for party i's masks
pub their_masks: ValueMacBatch<C>,
}

impl<C: CurveGroup> InputMasks<C> {
/// Append values to `my_masks`
pub fn add_local_masks(&mut self, values: Vec<Scalar<C>>, masks: Vec<ScalarShare<C>>) {
assert_eq!(values.len(), masks.len());
self.my_masks.extend(values);
self.my_mask_shares.append(&mut ValueMacBatch::new(masks));
}

/// Add values to `their_masks`
pub fn add_counterparty_masks(&mut self, mut masks: ValueMacBatch<C>) {
self.their_masks.append(&mut masks);
}

/// Get the local party's next mask and share of the mask
pub fn get_local_mask(&mut self) -> (Scalar<C>, ScalarShare<C>) {
assert!(!self.my_masks.is_empty(), "no local masks left");
let mask = self.my_masks.pop().unwrap();
let mask_share = self.my_mask_shares.pop().unwrap();

(mask, mask_share)
}

/// Get a batch of local masks and shares of the masks
pub fn get_local_mask_batch(&mut self, num_masks: usize) -> (Vec<Scalar<C>>, ValueMacBatch<C>) {
let split_idx = self.my_masks.len() - num_masks;
let masks = self.my_masks.split_off(split_idx);
let mask_shares = self.my_mask_shares.split_off(num_masks);

(masks, mask_shares)
}

/// Get the local party's share of the counterparty's next mask
pub fn get_counterparty_mask(&mut self) -> ScalarShare<C> {
self.their_masks.split_off(1).into_inner()[0]
}

/// Get a batch of the local party's shares of the counterparty's masks
pub fn get_counterparty_mask_batch(&mut self, num_masks: usize) -> ValueMacBatch<C> {
self.their_masks.split_off(num_masks)
}
}

#[cfg(test)]
mod test {
use ark_mpc::{
Expand All @@ -305,7 +375,8 @@ mod test {

use crate::test_helpers::mock_lowgear_with_triples;

/// Tests the use of the `LowGear` type as an `OfflinePhase` implementation
/// Tests the use of the `LowGear` type as an `PreprocessingPhase`
/// implementation
#[tokio::test]
async fn test_lowgear_offline_phase() {
// Setup the mock offline phase
Expand Down

0 comments on commit 3641783

Please sign in to comment.