Skip to content

Commit

Permalink
offline-phase: lowgear: multiplication: Implement subprotocol
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Apr 13, 2024
1 parent 3dd547b commit 0bc09c2
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 19 deletions.
3 changes: 1 addition & 2 deletions mp-spdz-rs/src/fhe/plaintext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ impl<C: CurveGroup> PlaintextVector<C> {
/// Create a plaintext vector from a vector of scalars, packing them into
/// slots
pub fn from_scalars(scalars: &[Scalar<C>], params: &BGVParams<C>) -> Self {
let n_plaintexts = scalars.len() / params.plaintext_slots() + 1;
let mut pt = Self::new(n_plaintexts, params);
let mut pt = Self::empty();

for chunk in scalars.chunks(params.plaintext_slots()) {
let mut plaintext = Plaintext::new(params);
Expand Down
24 changes: 22 additions & 2 deletions offline-phase/src/beaver_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ pub struct LowGearParams<C: CurveGroup> {
#[derive(Default, Copy, Clone)]
pub struct ValueMac<C: CurveGroup> {
/// The value
value: Scalar<C>,
pub(crate) value: Scalar<C>,
/// The mac
mac: Scalar<C>,
pub(crate) mac: Scalar<C>,
}

impl<C: CurveGroup> ValueMac<C> {
Expand Down Expand Up @@ -87,6 +87,21 @@ impl<C: CurveGroup> ValueMacBatch<C> {
Self { inner }
}

/// Get the length of the batch
pub fn len(&self) -> usize {
self.inner.len()
}

/// Check if the batch is empty
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}

/// Get the inner vector
pub fn into_inner(self) -> Vec<ValueMac<C>> {
self.inner
}

/// Get all values
pub fn values(&self) -> Vec<Scalar<C>> {
self.inner.iter().map(|vm| vm.value).collect()
Expand All @@ -102,6 +117,11 @@ impl<C: CurveGroup> ValueMacBatch<C> {
self.inner.iter()
}

/// Get a mutable iterator over the vector
pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, ValueMac<C>> {
self.inner.iter_mut()
}

/// Create a new ValueMacBatch from a batch of values and macs
pub fn from_parts(values: &[Scalar<C>], macs: &[Scalar<C>]) -> Self {
assert_eq!(values.len(), macs.len());
Expand Down
2 changes: 1 addition & 1 deletion offline-phase/src/lowgear/inverse_tuples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin + Send> LowGear<C, N> {
let random_values = self.get_authenticated_randomness_vec(2 * n).await?;

// Split into halves that we will multiply using the Beaver trick
let (random_values1, random_values2) = random_values.split_at(n);
let (random_values1, random_values2) = random_values.into_inner().split_at(n);

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion offline-phase/src/lowgear/mac_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin + Send> LowGear<C, N> {
/// Returns the opened values
pub async fn open_and_check_macs(
&mut self,
x: ValueMacBatch<C>,
x: &ValueMacBatch<C>,
) -> Result<Vec<Scalar<C>>, LowGearError> {
// Open and reconstruct
let recovered_values = self.open_batch(&x.values()).await?;
Expand Down
179 changes: 179 additions & 0 deletions offline-phase/src/lowgear/multiplication.rs
Original file line number Diff line number Diff line change
@@ -1 +1,180 @@
//! Multiplication sub-protocol using the Beaver trick
use ark_ec::CurveGroup;
use ark_mpc::{algebra::Scalar, network::MpcNetwork, PARTY0};
use itertools::Itertools;

use crate::{beaver_source::ValueMacBatch, error::LowGearError};

use super::LowGear;

impl<C: CurveGroup, N: MpcNetwork<C> + Unpin + Send> LowGear<C, N> {
/// Multiply two batches of values using the Beaver trick
pub async fn beaver_mul(
&mut self,
lhs: &ValueMacBatch<C>,
rhs: &ValueMacBatch<C>,
) -> Result<ValueMacBatch<C>, LowGearError> {
let n = lhs.len();
assert_eq!(rhs.len(), n, "Batch sizes must match");
assert!(self.triples.len() >= n, "Not enough triples for batch size");

// Get triples for the beaver trick
let (a, b, c) = self.consume_triples(n);

// Open d = lhs - a and e = rhs - b
let d = self.open_and_check_macs(&(lhs - &a)).await?;
let e = self.open_and_check_macs(&(rhs - &b)).await?;

// Identity: [x * y] = de + d[b] + e[a] + [c]
let de = d.iter().zip(e.iter()).map(|(d, e)| d * e).collect_vec();
let db = &b * d.as_slice();
let ea = &a * e.as_slice();
let mut shared_sum = &(&db + &ea) + &c;

// Only the first party adds the public term to their shares
self.add_public_value(&de, &mut shared_sum);

Ok(shared_sum)
}

/// Get the next `n` triples from the beaver source
fn consume_triples(
&mut self,
n: usize,
) -> (ValueMacBatch<C>, ValueMacBatch<C>, ValueMacBatch<C>) {
let triples = self.triples.split_off(n);

let mut a_res = Vec::with_capacity(n);
let mut b_res = Vec::with_capacity(n);
let mut c_res = Vec::with_capacity(n);
for (a, b, c) in triples.iter() {
a_res.push(*a);
b_res.push(*b);
c_res.push(*c);
}

(ValueMacBatch::new(a_res), ValueMacBatch::new(b_res), ValueMacBatch::new(c_res))
}

/// Add a batch of public values to a batch of shared values
///
/// Only the first party adds the public term to their shares, both parties
/// add the corresponding mac term
fn add_public_value(&mut self, public: &[Scalar<C>], batch: &mut ValueMacBatch<C>) {
let is_party0 = self.party_id() == PARTY0;
for (val, public) in batch.iter_mut().zip(public.iter()) {
val.mac += self.mac_share * public;
if is_party0 {
val.value += *public;
}
}
}
}

#[cfg(test)]
mod tests {
use ark_mpc::{algebra::Scalar, PARTY0};
use itertools::{izip, Itertools};
use rand::thread_rng;

use crate::{
beaver_source::ValueMacBatch,
test_helpers::{encrypt_all, mock_lowgear_with_keys, TestCurve},
};

/// Generate random mock triples for the Beaver trick
#[allow(clippy::type_complexity)]
fn generate_triples(
n: usize,
) -> (Vec<Scalar<TestCurve>>, Vec<Scalar<TestCurve>>, Vec<Scalar<TestCurve>>) {
let mut rng = thread_rng();
let a = (0..n).map(|_| Scalar::<TestCurve>::random(&mut rng)).collect_vec();
let b = (0..n).map(|_| Scalar::<TestCurve>::random(&mut rng)).collect_vec();
let c = (0..n).map(|_| Scalar::<TestCurve>::random(&mut rng)).collect_vec();

(a, b, c)
}

/// Generate authenticated secret shares of a given set of values
fn generate_authenticated_secret_shares(
values: &[Scalar<TestCurve>],
mac_key: Scalar<TestCurve>,
) -> (ValueMacBatch<TestCurve>, ValueMacBatch<TestCurve>) {
let (shares1, shares2) = generate_secret_shares(values);
let macs = values.iter().map(|value| *value * mac_key).collect_vec();
let (macs1, macs2) = generate_secret_shares(&macs);

(ValueMacBatch::from_parts(&shares1, &macs1), ValueMacBatch::from_parts(&shares2, &macs2))
}

/// Generate secret shares of a set of values
fn generate_secret_shares(
values: &[Scalar<TestCurve>],
) -> (Vec<Scalar<TestCurve>>, Vec<Scalar<TestCurve>>) {
let mut rng = thread_rng();
let mut shares1 = Vec::with_capacity(values.len());
let mut shares2 = Vec::with_capacity(values.len());
for value in values {
let share1 = Scalar::<TestCurve>::random(&mut rng);
let share2 = value - share1;
shares1.push(share1);
shares2.push(share2);
}

(shares1, shares2)
}

#[tokio::test]
async fn test_beaver_mul() {
const N: usize = 1;
let mut rng = thread_rng();

// Setup mock keys and triplets
let mac_key = Scalar::<TestCurve>::random(&mut rng);
let mac_key1 = Scalar::<TestCurve>::random(&mut rng);
let mac_key2 = mac_key - mac_key1;

let (a, b, c) = generate_triples(N);
let (a1, a2) = generate_authenticated_secret_shares(&a, mac_key);
let (b1, b2) = generate_authenticated_secret_shares(&b, mac_key);
let (c1, c2) = generate_authenticated_secret_shares(&c, mac_key);

mock_lowgear_with_keys(|mut lowgear| {
// Setup the mac shares and counterparty mac share encryptions
let is_party0 = lowgear.party_id() == PARTY0;
lowgear.mac_share = if is_party0 { mac_key1 } else { mac_key2 };

let other_pk = lowgear.other_pk.as_ref().unwrap();
let other_share = if is_party0 { mac_key2 } else { mac_key1 };
lowgear.other_mac_enc = Some(encrypt_all(other_share, other_pk, &lowgear.params));

// Setup the mock triplets
let (my_a, my_b, my_c) = if is_party0 { (&a1, &b1, &c1) } else { (&a2, &b2, &c2) };
lowgear.triples = izip!(
my_a.clone().into_inner(),
my_b.clone().into_inner(),
my_c.clone().into_inner()
)
.collect_vec();

// Test the multiplication sub-protocol
async move {
let lhs = lowgear.get_authenticated_randomness_vec(N).await.unwrap();
let rhs = lowgear.get_authenticated_randomness_vec(N).await.unwrap();
let res = lowgear.beaver_mul(&lhs, &rhs).await.unwrap();

// Open all values
let lhs_open = lowgear.open_and_check_macs(&lhs).await.unwrap();
let rhs_open = lowgear.open_and_check_macs(&rhs).await.unwrap();
let res_open = lowgear.open_and_check_macs(&res).await.unwrap();

// Assert that the result is equal to the expected value
for (l, r, re) in izip!(lhs_open, rhs_open, res_open) {
assert_eq!(re, l * r);
}
}
})
.await;
}
}
16 changes: 5 additions & 11 deletions offline-phase/src/lowgear/shared_random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use itertools::Itertools;
use mp_spdz_rs::fhe::plaintext::PlaintextVector;
use rand::rngs::OsRng;

use crate::{beaver_source::ValueMac, error::LowGearError};
use crate::{beaver_source::ValueMacBatch, error::LowGearError};

use super::LowGear;

Expand Down Expand Up @@ -44,7 +44,7 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin + Send> LowGear<C, N> {
pub async fn get_authenticated_randomness_vec(
&mut self,
n: usize,
) -> Result<Vec<ValueMac<C>>, LowGearError> {
) -> Result<ValueMacBatch<C>, LowGearError> {
// Each party generates shares locally with the represented value implicitly
// defined as the sum of the shares
let mut rng = OsRng;
Expand All @@ -55,19 +55,13 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin + Send> LowGear<C, N> {

// Recombine into ValueMac pairs
macs.truncate(n);
let res =
my_shares.into_iter().zip(macs.into_iter()).map(|(v, m)| ValueMac::new(v, m)).collect();

Ok(res)
Ok(ValueMacBatch::from_parts(&my_shares, &macs))
}
}

#[cfg(test)]
mod tests {
use crate::{
beaver_source::ValueMacBatch,
test_helpers::{mock_lowgear, mock_lowgear_with_keys},
};
use crate::test_helpers::{mock_lowgear, mock_lowgear_with_keys};

use super::*;

Expand Down Expand Up @@ -99,7 +93,7 @@ mod tests {
assert_eq!(shares.len(), N);

// Check the macs on the shares
lowgear.open_and_check_macs(ValueMacBatch::new(shares)).await.unwrap();
lowgear.open_and_check_macs(&shares).await.unwrap();
})
.await;
}
Expand Down
4 changes: 2 additions & 2 deletions offline-phase/src/lowgear/triplets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,15 @@ impl<C: CurveGroup, N: MpcNetwork<C> + Unpin> LowGear<C, N> {

// Open r * b - b'
let my_rho = &(b * r) - b_prime;
let rho = self.open_and_check_macs(my_rho).await?;
let rho = self.open_and_check_macs(&my_rho).await?;

// Compute the expected rhs of the sacrifice identity
let rho_a = a * rho.as_slice();
let c_diff = &(c * r) - c_prime;
let my_tau = &c_diff - &rho_a;

// Open tau and check that all values are zero
let tau = self.open_and_check_macs(my_tau).await?;
let tau = self.open_and_check_macs(&my_tau).await?;

let zero = Scalar::zero();
if !tau.into_iter().all(|s| s == zero) {
Expand Down

0 comments on commit 0bc09c2

Please sign in to comment.