From 93d6615ac0ca7d486ce87be409837b7f07e3c43e Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Wed, 17 Apr 2024 16:36:53 -0700 Subject: [PATCH] online-phase: fabric: Authenticate inputs with input masks --- README.md | 2 +- online-phase/src/fabric.rs | 301 +++++++++++------------------- online-phase/src/fabric/result.rs | 41 ++++ online-phase/src/lib.rs | 4 +- online-phase/src/offline_prep.rs | 69 +++++-- 5 files changed, 213 insertions(+), 204 deletions(-) diff --git a/README.md b/README.md index 19022ce..1329521 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ `ark-mpc` provides a malicious secure [SPDZ](https://eprint.iacr.org/2011/535.pdf) style framework for two party secure computation. The circuit is constructed on the fly, by overloading arithmetic operators of MPC types, see the example below in which each of the parties shares a value and together they compute the product: ```rust use ark_mpc::{ - algebra::scalar::Scalar, beaver::OfflinePhase, network::QuicTwoPartyNet, MpcFabric, + algebra::scalar::Scalar, beaver::PreprocessingPhase, network::QuicTwoPartyNet, MpcFabric, PARTY0, PARTY1, }; use ark_curve25519::EdwardsProjective as Curve25519Projective; diff --git a/online-phase/src/fabric.rs b/online-phase/src/fabric.rs index 2ccdb1f..d2e5e51 100644 --- a/online-phase/src/fabric.rs +++ b/online-phase/src/fabric.rs @@ -4,7 +4,7 @@ //! MpcFabric creates and manages dependencies needed to allocate network //! values. This provides a cleaner interface for consumers of the library; i.e. //! clients do not have to hold onto references of the network layer or the -//! beaver sources to allocate values. +//! offline phase implementation to allocate values. mod executor; mod network_sender; @@ -16,7 +16,6 @@ pub use executor::ExecutorSizeHints; use executor::{single_threaded::SerialExecutor, ExecutorMessage}; #[cfg(feature = "benchmarks")] pub use executor::{single_threaded::SerialExecutor, ExecutorMessage, GrowableBuffer}; -use rand::thread_rng; pub use result::{ResultHandle, ResultId, ResultValue}; use futures::executor::block_on; @@ -37,12 +36,11 @@ use itertools::Itertools; use crate::{ algebra::{ - AuthenticatedPointResult, AuthenticatedScalarResult, BatchCurvePointResult, - BatchScalarResult, CurvePoint, CurvePointResult, PointShare, Scalar, ScalarResult, - ScalarShare, + AuthenticatedPointResult, AuthenticatedScalarResult, CurvePoint, CurvePointResult, + PointShare, Scalar, ScalarResult, ScalarShare, }, network::{MpcNetwork, NetworkOutbound, NetworkPayload, PartyId}, - offline_prep::OfflinePhase, + offline_prep::PreprocessingPhase, PARTY0, }; @@ -173,8 +171,6 @@ pub struct MpcFabric { /// The local party's share of the global MAC key /// /// The parties collectively hold an additive sharing of the global key - /// - /// We wrap in a reference counting structure to avoid recursive type issues #[cfg(not(feature = "benchmarks"))] mac_key: Scalar, /// The MAC key, accessible publicly for benchmark mocking @@ -209,7 +205,7 @@ pub struct FabricInner { /// The underlying queue to the network outbound_queue: KanalSender>, /// The underlying shared randomness source - beaver_source: Arc>>>, + offline_phase: Arc>>>, } impl Debug for FabricInner { @@ -220,12 +216,12 @@ impl Debug for FabricInner { impl FabricInner { /// Constructor - pub fn new>( + pub fn new>( party_id: u64, mac_key: Scalar, execution_queue: ExecutorJobQueue, outbound_queue: KanalSender>, - beaver_source: S, + offline_phase: S, ) -> Self { // Allocate a zero and a one as well as the curve identity in the fabric to // begin, for convenience @@ -264,7 +260,7 @@ impl FabricInner { next_op_id, execution_queue, outbound_queue, - beaver_source: Arc::new(Mutex::new(Box::new(beaver_source))), + offline_phase: Arc::new(Mutex::new(Box::new(offline_phase))), } } @@ -357,26 +353,6 @@ impl FabricInner { ids } - /// Allocate a secret shared value in the network - pub(crate) fn allocate_shared_value( - &self, - my_share: ResultValue, - their_share: ResultValue, - ) -> ResultId { - // Forward the local party's share to the executor - let id = self.new_result_id(); - self.execution_queue.push(ExecutorMessage::Result(OpResult { id, value: my_share })); - - // Send the counterparty their share - if let Err(e) = - self.outbound_queue.send(NetworkOutbound { result_id: id, payload: their_share.into() }) - { - log::error!("error sending share to counterparty: {e:?}"); - } - - id - } - /// Receive a value from a network operation initiated by a peer /// /// The peer will already send the value with the corresponding ID, so all @@ -402,6 +378,7 @@ impl FabricInner { } // Allocate IDs for the results + assert!(output_arity > 0, "output arity must be greater than 0"); let ids = self.new_result_id_batch(output_arity); // Build the operation @@ -422,24 +399,24 @@ impl FabricInner { impl MpcFabric { /// Constructor - pub fn new, S: 'static + OfflinePhase>( + pub fn new, S: 'static + PreprocessingPhase>( network: N, - beaver_source: S, + offline_phase: S, ) -> Self { - Self::new_with_size_hint(ExecutorSizeHints::default(), network, beaver_source) + Self::new_with_size_hint(ExecutorSizeHints::default(), network, offline_phase) } /// Constructor that takes an additional size hint, indicating how much /// buffer space the fabric should allocate for results. The size is /// given in number of gates - pub fn new_with_size_hint, S: 'static + OfflinePhase>( + pub fn new_with_size_hint, S: 'static + PreprocessingPhase>( size_hints: ExecutorSizeHints, network: N, - beaver_source: S, + offline_phase: S, ) -> Self { // Build an executor queue and a fabric around it let executor_queue = Arc::new(SegQueue::new()); - let self_ = Self::new_with_executor(network, beaver_source, executor_queue.clone()); + let self_ = Self::new_with_executor(network, offline_phase, executor_queue.clone()); // Spawn the executor let outbound_queue = self_.inner.outbound_queue.clone(); @@ -455,9 +432,9 @@ impl MpcFabric { /// Constructor that takes an additional size hint as well as a queue for /// the executor - pub fn new_with_executor, S: 'static + OfflinePhase>( + pub fn new_with_executor, S: 'static + PreprocessingPhase>( network: N, - beaver_source: S, + offline_phase: S, executor_queue: ExecutorJobQueue, ) -> Self { // Build communication primitives @@ -466,14 +443,13 @@ impl MpcFabric { // Build a fabric let party_id = network.party_id(); - // TODO: Use offline phase params - let mac_key = Scalar::from(party_id); + let mac_key = offline_phase.get_mac_key_share(); let fabric = FabricInner::new( party_id, mac_key, executor_queue.clone(), outbound_sender.to_sync(), - beaver_source, + offline_phase, ); // Start a network sender and operator executor @@ -573,43 +549,29 @@ impl MpcFabric { // | Wire Allocation | // ------------------- - /// Allocate a shared value in the fabric - fn allocate_shared_value>>( - &self, - my_share: ResultValue, - their_share: ResultValue, - ) -> ResultHandle { - let id = self.inner.allocate_shared_value(my_share, their_share); - ResultHandle::new(id, self.clone()) - } - /// Share a `Scalar` value with the counterparty - /// - /// TODO: Input authentication pub fn share_scalar>>( &self, val: T, sender: PartyId, ) -> AuthenticatedScalarResult { - let scalar: ScalarResult = if self.party_id() == sender { - let scalar_val = val.into(); - let mut rng = thread_rng(); - let random = Scalar::random(&mut rng); - - let (my_share, their_share) = (scalar_val - random, random); - self.allocate_shared_value( - ResultValue::Scalar(my_share), - ResultValue::Scalar(their_share), - ) + // Sample an input mask from the offline phase + let mut offline = self.inner.offline_phase.lock().expect("offline phase poisoned"); + let (masked_val, shared_mask) = if self.party_id() == sender { + let (mask, mask_share) = offline.next_local_input_mask(); + let masked = Into::>::into(val) - mask; + let masked_val = self.share_plaintext(masked, sender); + + (masked_val, mask_share) } else { - self.receive_value() + let mask_share = offline.next_counterparty_input_mask(); + let masked_val = self.share_plaintext(Scalar::zero(), sender); + + (masked_val, mask_share) }; - // TODO: Proper input authentication - self.new_gate_op(vec![scalar.id()], |mut args| { - let share: Scalar = args.next().unwrap().into(); - ResultValue::ScalarShare(ScalarShare::new(share, share /* mac */)) - }) + // Unmask the value in the MPC circuit + self.allocate_scalar_share(shared_mask) + masked_val } /// Share a batch of `Scalar` values with the counterparty @@ -619,62 +581,41 @@ impl MpcFabric { sender: PartyId, ) -> Vec> { let n = vals.len(); - let shares: BatchScalarResult = if self.party_id() == sender { - let vals = vals.into_iter().map(|val| val.into()).collect_vec(); - let mut rng = thread_rng(); - - let peer_shares = (0..vals.len()).map(|_| Scalar::random(&mut rng)).collect_vec(); - let my_shares = - vals.iter().zip(peer_shares.iter()).map(|(val, share)| val - share).collect_vec(); - - self.allocate_shared_value( - ResultValue::ScalarBatch(my_shares), - ResultValue::ScalarBatch(peer_shares), - ) - } else { - self.receive_value() - }; + let mut offline = self.inner.offline_phase.lock().expect("offline phase poisoned"); + let (masked_vals, mask_shares) = if self.party_id() == sender { + let (masks, mask_shares) = offline.next_local_input_mask_batch(n); + let masked = vals.into_iter().zip(masks).map(|(val, mask)| val.into() - mask).collect(); + let masked_vals = self.batch_share_plaintext(masked, sender); - // TODO: Proper input authentication - self.new_batch_gate_op(vec![shares.id()], n, |mut args| { - let shares: Vec> = args.next().unwrap().into(); - let mut res = Vec::with_capacity(shares.len()); + (masked_vals, mask_shares) + } else { + let mask_shares = offline.next_counterparty_input_mask_batch(n); + let masked_vals = self.batch_share_plaintext(vec![Scalar::zero(); n], sender); - for share in shares.into_iter() { - res.push(ResultValue::ScalarShare(ScalarShare::new(share, share /* mac */))); - } + (masked_vals, mask_shares) + }; - res - }) + let shares = self.allocate_scalar_shares(mask_shares); + AuthenticatedScalarResult::batch_add_public(&shares, &masked_vals) } /// Share a `CurvePoint` value with the counterparty pub fn share_point(&self, val: CurvePoint, sender: PartyId) -> AuthenticatedPointResult { - let point: CurvePointResult = if self.party_id() == sender { - // As mentioned in https://eprint.iacr.org/2009/226.pdf - // it is okay to sample a random point by sampling a random `Scalar` and - // multiplying by the generator in the case that the discrete log of - // the output may be leaked with respect to the generator. Leaking - // the discrete log (i.e. the random `Scalar`) is okay when it is - // used to generate secret shares - let mut rng = thread_rng(); - let random = Scalar::random(&mut rng); - let random_point = random * CurvePoint::generator(); - - let (my_share, their_share) = (val - random_point, random_point); - self.allocate_shared_value( - ResultValue::Point(my_share), - ResultValue::Point(their_share), - ) + let mut offline = self.inner.offline_phase.lock().expect("offline phase poisoned"); + let (masked_point, mask_share) = if self.party_id() == sender { + let (mask, mask_share) = offline.next_local_input_mask(); + let masked = val - mask * CurvePoint::generator(); + let masked_point = self.share_plaintext(masked, sender); + + (masked_point, mask_share) } else { - self.receive_value() + let mask_share = offline.next_counterparty_input_mask(); + let masked_point = self.share_plaintext(CurvePoint::generator(), sender); + + (masked_point, mask_share) }; - // TODO: Proper input authentication - self.new_gate_op(vec![point.id()], |mut args| { - let share: CurvePoint = args.next().unwrap().into(); - ResultValue::PointShare(PointShare::new(share, share /* mac */)) - }) + self.allocate_scalar_share(mask_share) * CurvePoint::generator() + masked_point } /// Share a batch of `CurvePoint`s with the counterparty @@ -684,37 +625,27 @@ impl MpcFabric { sender: PartyId, ) -> Vec> { let n = vals.len(); - let shares: BatchCurvePointResult = if self.party_id() == sender { - let mut rng = thread_rng(); - let generator = CurvePoint::generator(); - let peer_shares = (0..vals.len()) - .map(|_| { - let discrete_log = Scalar::random(&mut rng); - discrete_log * generator - }) - .collect_vec(); - let my_shares = - vals.iter().zip(peer_shares.iter()).map(|(val, share)| val - share).collect_vec(); - - self.allocate_shared_value( - ResultValue::PointBatch(my_shares), - ResultValue::PointBatch(peer_shares), - ) + let mut offline = self.inner.offline_phase.lock().expect("offline phase poisoned"); + let (masked_vals, mask_shares) = if self.party_id() == sender { + let (masks, mask_shares) = offline.next_local_input_mask_batch(n); + let mask_times_gen = + masks.into_iter().map(|mask| mask * CurvePoint::generator()).collect_vec(); + let masked = + vals.into_iter().zip(mask_times_gen).map(|(val, mask)| val - mask).collect(); + let masked_vals = self.batch_share_plaintext(masked, sender); + + (masked_vals, mask_shares) } else { - self.receive_value() - }; + let mask_shares = offline.next_counterparty_input_mask_batch(n); + let masked_vals = self.batch_share_plaintext(vec![CurvePoint::generator(); n], sender); - // TODO: Proper input authentication - self.new_batch_gate_op(vec![shares.id()], n, |mut args| { - let shares: Vec> = args.next().unwrap().into(); - let mut res = Vec::with_capacity(shares.len()); + (masked_vals, mask_shares) + }; - for share in shares.into_iter() { - res.push(ResultValue::PointShare(PointShare::new(share, share /* mac */))); - } + let shares = self.allocate_scalar_shares(mask_shares); + let masks = AuthenticatedPointResult::batch_mul_generator(&shares); - res - }) + AuthenticatedPointResult::batch_add_public(&masks, &masked_vals) } /// Allocate a public value in the fabric @@ -754,38 +685,23 @@ impl MpcFabric { .collect_vec() } - /// Allocate a scalar as a secret share of an already shared value - pub fn allocate_preshared_scalar>>( - &self, - value: T, - ) -> AuthenticatedScalarResult { - let allocated = self.allocate_scalar(value); - - // TODO: Proper input authentication - self.new_gate_op(vec![allocated.id()], |mut args| { - let share: Scalar = args.next().unwrap().into(); - ResultValue::ScalarShare(ScalarShare::new(share, share /* mac */)) - }) + /// Allocate a point secret share in the fabric + pub fn allocate_point_share(&self, share: PointShare) -> AuthenticatedPointResult { + let id = self.inner.allocate_value(ResultValue::PointShare(share)); + ResultHandle::new(id, self.clone()) } - /// Allocate a batch of scalars as secret shares of already shared values - pub fn batch_allocate_preshared_scalar>>( + /// Allocate a batch of point secret shares in the fabric + pub fn allocate_point_shares( &self, - values: Vec, - ) -> Vec> { - let values = self.allocate_scalars(values); - - // TODO: Proper input authentication - let ids = values.iter().map(|v| v.id()).collect_vec(); - self.new_batch_gate_op(ids, values.len(), |args| { - let shares: Vec> = args.into_iter().map(|s| s.into()).collect_vec(); - let mut res = Vec::new(); - for share in shares.into_iter() { - res.push(ResultValue::ScalarShare(ScalarShare::new(share, share /* mac */))); - } - - res - }) + shares: Vec>, + ) -> Vec> { + let result_values = shares.into_iter().map(ResultValue::PointShare).collect_vec(); + self.inner + .allocate_values(result_values) + .into_iter() + .map(|id| ResultHandle::new(id, self.clone())) + .collect_vec() } /// Allocate a public curve point in the fabric @@ -882,12 +798,19 @@ impl MpcFabric { &self, values: Vec, sender: PartyId, - ) -> ResultHandle> + ) -> Vec> where - T: 'static + From> + Send + Sync, + T: 'static + From> + Into> + Send + Sync, Vec: Into> + From>, { - self.share_plaintext(values, sender) + let n = values.len(); + let res = self.share_plaintext(values, sender); + + // Split the vec into a result of values + self.new_batch_gate_op(vec![res.id()], n, |mut args| { + let values: Vec = args.next().unwrap().into(); + values.into_iter().map(Into::into).collect_vec() + }) } // ------------------- @@ -947,7 +870,7 @@ impl MpcFabric { } // ----------------- - // | Beaver Source | + // | Offline Phase | // ----------------- /// Sample the next beaver triplet with MACs from the beaver source @@ -956,7 +879,7 @@ impl MpcFabric { ) -> (AuthenticatedScalarResult, AuthenticatedScalarResult, AuthenticatedScalarResult) { let (a, b, c) = - self.inner.beaver_source.lock().expect("beaver source poisoned").next_triplet(); + self.inner.offline_phase.lock().expect("beaver source poisoned").next_triplet(); let mut abc = self.allocate_scalar_shares(vec![a, b, c]); let c_val = abc.pop().unwrap(); @@ -977,7 +900,7 @@ impl MpcFabric { Vec>, ) { let (a_vals, b_vals, c_vals) = - self.inner.beaver_source.lock().expect("beaver source poisoned").next_triplet_batch(n); + self.inner.offline_phase.lock().expect("beaver source poisoned").next_triplet_batch(n); // Concatenate and allocate all the values let vals = a_vals.into_iter().chain(b_vals).chain(c_vals).collect_vec(); @@ -991,14 +914,14 @@ impl MpcFabric { (a_vals, b_vals, c_vals) } - /// Sample a batch of random shared values from the beaver source and + /// Sample a batch of random shared values from the offline phase and /// allocate them as `AuthenticatedScalars` pub fn random_shared_scalars(&self, n: usize) -> Vec> { let values_raw = self .inner - .beaver_source + .offline_phase .lock() - .expect("beaver source poisoned") + .expect("offline phase poisoned") .next_shared_value_batch(n); self.allocate_scalar_shares(values_raw) @@ -1008,7 +931,7 @@ impl MpcFabric { pub fn random_inverse_pair( &self, ) -> (AuthenticatedScalarResult, AuthenticatedScalarResult) { - let (l, r) = self.inner.beaver_source.lock().unwrap().next_shared_inverse_pair(); + let (l, r) = self.inner.offline_phase.lock().unwrap().next_shared_inverse_pair(); let mut lr = self.allocate_scalar_shares(vec![l, r]); let r = lr.pop().unwrap(); let l = lr.pop().unwrap(); @@ -1022,7 +945,7 @@ impl MpcFabric { n: usize, ) -> (Vec>, Vec>) { let (left, right) = - self.inner.beaver_source.lock().unwrap().next_shared_inverse_pair_batch(n); + self.inner.offline_phase.lock().unwrap().next_shared_inverse_pair_batch(n); let left_right = left.into_iter().chain(right).collect_vec(); let mut allocated_left_right = self.allocate_scalar_shares(left_right); @@ -1034,21 +957,21 @@ impl MpcFabric { (left, right) } - /// Sample a random shared bit from the beaver source + /// Sample a random shared bit from the offline phase pub fn random_shared_bit(&self) -> AuthenticatedScalarResult { let bit = - self.inner.beaver_source.lock().expect("beaver source poisoned").next_shared_bit(); + self.inner.offline_phase.lock().expect("offline phase poisoned").next_shared_bit(); self.allocate_scalar_share(bit) } - /// Sample a batch of random shared bits from the beaver source + /// Sample a batch of random shared bits from the offline phase pub fn random_shared_bits(&self, n: usize) -> Vec> { let bits = self .inner - .beaver_source + .offline_phase .lock() - .expect("beaver source poisoned") + .expect("offline phase poisoned") .next_shared_bit_batch(n); self.allocate_scalar_shares(bits) diff --git a/online-phase/src/fabric/result.rs b/online-phase/src/fabric/result.rs index 247543c..c89a866 100644 --- a/online-phase/src/fabric/result.rs +++ b/online-phase/src/fabric/result.rs @@ -118,6 +118,11 @@ impl From> for Vec { } } } +impl From> for ResultValue { + fn from(value: Vec) -> Self { + ResultValue::Bytes(value) + } +} impl From> for Scalar { fn from(value: ResultValue) -> Self { @@ -128,6 +133,12 @@ impl From> for Scalar { } } +impl From> for ResultValue { + fn from(value: Scalar) -> Self { + ResultValue::Scalar(value) + } +} + impl From<&ResultValue> for Scalar { fn from(value: &ResultValue) -> Self { match value { @@ -146,6 +157,12 @@ impl From> for Vec> { } } +impl From>> for ResultValue { + fn from(value: Vec>) -> Self { + ResultValue::ScalarBatch(value) + } +} + impl From> for ScalarShare { fn from(value: ResultValue) -> Self { match value { @@ -155,6 +172,12 @@ impl From> for ScalarShare { } } +impl From> for ResultValue { + fn from(value: ScalarShare) -> Self { + ResultValue::ScalarShare(value) + } +} + impl From> for CurvePoint { fn from(value: ResultValue) -> Self { match value { @@ -164,6 +187,12 @@ impl From> for CurvePoint { } } +impl From> for ResultValue { + fn from(value: CurvePoint) -> Self { + ResultValue::Point(value) + } +} + impl From<&ResultValue> for CurvePoint { fn from(value: &ResultValue) -> Self { match value { @@ -182,6 +211,12 @@ impl From> for Vec> { } } +impl From>> for ResultValue { + fn from(value: Vec>) -> Self { + ResultValue::PointBatch(value) + } +} + impl From> for PointShare { fn from(value: ResultValue) -> Self { match value { @@ -191,6 +226,12 @@ impl From> for PointShare { } } +impl From> for ResultValue { + fn from(value: PointShare) -> Self { + ResultValue::PointShare(value) + } +} + // --------------- // | Handle Type | // --------------- diff --git a/online-phase/src/lib.rs b/online-phase/src/lib.rs index a4347f4..714f9f1 100644 --- a/online-phase/src/lib.rs +++ b/online-phase/src/lib.rs @@ -57,7 +57,7 @@ pub mod test_helpers { algebra::{AuthenticatedPointResult, AuthenticatedScalarResult, CurvePoint, Scalar}, fabric::ExecutorSizeHints, network::{MockNetwork, NoRecvNetwork, UnboundedDuplexStream}, - offline_prep::{OfflinePhase, PartyIDBeaverSource}, + offline_prep::{PartyIDBeaverSource, PreprocessingPhase}, MpcFabric, PARTY0, PARTY1, }; @@ -149,7 +149,7 @@ pub mod test_helpers { party1_beaver: B, ) -> (T, T) where - B: 'static + OfflinePhase, + B: 'static + PreprocessingPhase, T: Send + 'static, S: Future + Send + 'static, F: FnMut(MpcFabric) -> S, diff --git a/online-phase/src/offline_prep.rs b/online-phase/src/offline_prep.rs index 8c397df..ae1346c 100644 --- a/online-phase/src/offline_prep.rs +++ b/online-phase/src/offline_prep.rs @@ -6,12 +6,36 @@ use itertools::Itertools; use crate::algebra::{Scalar, ScalarShare}; -/// OfflinePhase implements both the functionality for: -/// 1. Single additively shared values [x] where party 1 holds x_1 and party -/// 2 holds x_2 such that x_1 + x_2 = x -/// 2. Beaver triplets; additively shared values [a], [b], [c] such that a * -/// b = c -pub trait OfflinePhase: Send + Sync { +/// PreprocessingPhase implements both the functionality for: +/// 1. Input authentication and sharing +/// 2. Shared values from the pre-processing phase +pub trait PreprocessingPhase: Send + Sync { + // === Input Authentication === // + /// Get the local party's share of the mac key + fn get_mac_key_share(&self) -> Scalar; + /// Get an input mask value for the local party + /// + /// That is, a cleartext random value and the local party's share of the + /// value + fn next_local_input_mask(&mut self) -> (Scalar, ScalarShare); + /// Get a batch of input mask values for the local party + fn next_local_input_mask_batch( + &mut self, + num_values: usize, + ) -> (Vec>, Vec>) { + (0..num_values).map(|_| self.next_local_input_mask()).unzip() + } + /// Get an input mask share for the counterparty + /// + /// That is, a share of a random value for which the counterparty holds the + /// cleartext + fn next_counterparty_input_mask(&mut self) -> ScalarShare; + /// Get a batch of input mask shares for the counterparty + fn next_counterparty_input_mask_batch(&mut self, num_values: usize) -> Vec> { + (0..num_values).map(|_| self.next_counterparty_input_mask()).collect_vec() + } + + // === Shared Values === // /// Fetch the next shared single bit fn next_shared_bit(&mut self) -> ScalarShare; /// Fetch the next shared batch of bits @@ -70,6 +94,7 @@ pub struct PartyIDBeaverSource { impl PartyIDBeaverSource { /// Create a new beaver source given the local party_id pub fn new(party_id: u64) -> Self { + assert!(party_id == 0 || party_id == 1); Self { party_id } } } @@ -81,10 +106,30 @@ impl PartyIDBeaverSource { /// We also assume the MAC key is a secret sharing of 1 with each party holding /// their own party id as a mac key share #[cfg(any(feature = "test_helpers", test))] -impl OfflinePhase for PartyIDBeaverSource { +impl PreprocessingPhase for PartyIDBeaverSource { + fn get_mac_key_share(&self) -> Scalar { + Scalar::from(self.party_id) + } + + fn next_local_input_mask(&mut self) -> (Scalar, ScalarShare) { + let party = Scalar::from(self.party_id); + let value = Scalar::from(3u8); + let share = party * value; + let mac = party * value; + + (value, ScalarShare::new(share, mac)) + } + + fn next_counterparty_input_mask(&mut self) -> ScalarShare { + let party = Scalar::from(self.party_id); + let value = Scalar::from(3u8) * party; + let mac = party * value; + + ScalarShare::new(value, mac) + } + fn next_shared_bit(&mut self) -> ScalarShare { // Simply output partyID, assume partyID \in {0, 1} - assert!(self.party_id == 0 || self.party_id == 1); let value = Scalar::from(self.party_id); ScalarShare::new(value, value) } @@ -94,10 +139,10 @@ impl OfflinePhase for PartyIDBeaverSource { let b = Scalar::from(3u8); let c = Scalar::from(6u8); - let party_id = Scalar::from(self.party_id); - let a_mac = party_id * a; - let b_mac = party_id * b; - let c_mac = party_id * c; + let key = self.get_mac_key_share(); + let a_mac = key * a; + let b_mac = key * b; + let c_mac = key * c; let (a_share, b_share, c_share) = if self.party_id == 0 { (Scalar::from(1u64), Scalar::from(3u64), Scalar::from(2u64))