From 9c4622e1c18ec25a881428f7fe8cd4f6c861cc62 Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Tue, 16 Apr 2024 18:46:51 -0700 Subject: [PATCH] offline-phase: structs: Implement `OfflinePhase` for `LowGearPrep` This implements the interface needed to be used in the online phase. --- mp-spdz-rs/src/ffi.rs | 2 + offline-phase/src/lib.rs | 2 +- offline-phase/src/structs.rs | 95 ++++++++++++++++++++++++++++++++ online-phase/src/offline_prep.rs | 1 + 4 files changed, 99 insertions(+), 1 deletion(-) diff --git a/mp-spdz-rs/src/ffi.rs b/mp-spdz-rs/src/ffi.rs index 02a14b6..7c6b044 100644 --- a/mp-spdz-rs/src/ffi.rs +++ b/mp-spdz-rs/src/ffi.rs @@ -168,7 +168,9 @@ mod ffi_inner { } pub use ffi_inner::*; unsafe impl Send for FHE_Params {} +unsafe impl Sync for FHE_Params {} unsafe impl Send for FHE_KeyPair {} +unsafe impl Sync for FHE_KeyPair {} unsafe impl Send for FHE_PK {} unsafe impl Sync for FHE_PK {} unsafe impl Send for Ciphertext {} diff --git a/offline-phase/src/lib.rs b/offline-phase/src/lib.rs index 1643284..f6a37d5 100644 --- a/offline-phase/src/lib.rs +++ b/offline-phase/src/lib.rs @@ -87,7 +87,7 @@ pub(crate) mod test_helpers { let mut rng = thread_rng(); let a = (0..n).map(|_| Scalar::::random(&mut rng)).collect_vec(); let b = (0..n).map(|_| Scalar::::random(&mut rng)).collect_vec(); - let c = (0..n).map(|_| Scalar::::random(&mut rng)).collect_vec(); + let c = a.iter().zip(b.iter()).map(|(a, b)| a * b).collect_vec(); (a, b, c) } diff --git a/offline-phase/src/structs.rs b/offline-phase/src/structs.rs index 6e24dba..f08a93a 100644 --- a/offline-phase/src/structs.rs +++ b/offline-phase/src/structs.rs @@ -4,6 +4,7 @@ use std::ops::{Add, Mul, Sub}; use ark_ec::CurveGroup; use ark_mpc::algebra::{Scalar, ScalarShare}; +use ark_mpc::offline_prep::OfflinePhase; use mp_spdz_rs::fhe::ciphertext::Ciphertext; use mp_spdz_rs::fhe::keys::{BGVKeypair, BGVPublicKey}; use mp_spdz_rs::fhe::params::BGVParams; @@ -92,6 +93,58 @@ impl LowGearPrep { } } +impl OfflinePhase for LowGearPrep { + fn next_shared_bit(&mut self) -> ScalarShare { + self.bits.split_off(1).into_inner()[0] + } + + fn next_shared_bit_batch(&mut self, num_values: usize) -> Vec> { + assert!(self.bits.len() >= num_values, "shared bits exhausted"); + self.bits.split_off(num_values).into_inner() + } + + fn next_shared_value(&mut self) -> ScalarShare { + self.shared_randomness.split_off(1).into_inner()[0] + } + + fn next_shared_value_batch(&mut self, num_values: usize) -> Vec> { + assert!(self.shared_randomness.len() >= num_values, "shared random values exhausted"); + self.shared_randomness.split_off(num_values).into_inner() + } + + fn next_shared_inverse_pair(&mut self) -> (ScalarShare, ScalarShare) { + let (lhs, rhs) = self.next_shared_inverse_pair_batch(1); + (lhs[0], rhs[0]) + } + + fn next_shared_inverse_pair_batch( + &mut self, + num_pairs: usize, + ) -> (Vec>, Vec>) { + assert!(self.inverse_pairs.0.len() >= num_pairs, "shared inverse pairs exhausted"); + let lhs = self.inverse_pairs.0.split_off(num_pairs); + let rhs = self.inverse_pairs.1.split_off(num_pairs); + (lhs.into_inner(), rhs.into_inner()) + } + + fn next_triplet(&mut self) -> (ScalarShare, ScalarShare, ScalarShare) { + let (a, b, c) = self.next_triplet_batch(1); + (a[0], b[0], c[0]) + } + + fn next_triplet_batch( + &mut self, + num_triplets: usize, + ) -> (Vec>, Vec>, Vec>) { + assert!(self.triplets.0.len() >= num_triplets, "shared triplets exhausted"); + let a = self.triplets.0.split_off(num_triplets); + let b = self.triplets.1.split_off(num_triplets); + let c = self.triplets.2.split_off(num_triplets); + + (a.into_inner(), b.into_inner(), c.into_inner()) + } +} + // ------------------------ // | Authenticated Shares | // ------------------------ @@ -242,3 +295,45 @@ impl Mul<&[Scalar]> for &ValueMacBatch { ValueMacBatch::new(self.inner.iter().zip(other.iter()).map(|(a, b)| a * *b).collect()) } } + +#[cfg(test)] +mod test { + use ark_mpc::{ + algebra::Scalar, test_helpers::execute_mock_mpc_with_beaver_source, PARTY0, PARTY1, + }; + use rand::thread_rng; + + use crate::test_helpers::mock_lowgear_with_triples; + + /// Tests the use of the `LowGear` type as an `OfflinePhase` implementation + #[tokio::test] + async fn test_lowgear_offline_phase() { + // Setup the mock offline phase + let (prep1, prep2) = mock_lowgear_with_triples( + 100, // num_triples + |mut lowgear| async move { lowgear.get_offline_result().unwrap() }, + ) + .await; + + // Run a mock mpc using the lowgear offline phase + let mut rng = thread_rng(); + let a = Scalar::random(&mut rng); + let b = Scalar::random(&mut rng); + let expected = a * b; + + let (res, _) = execute_mock_mpc_with_beaver_source( + |fabric| async move { + let a_shared = fabric.share_scalar(a, PARTY0); + let b_shared = fabric.share_scalar(b, PARTY1); + + let c = a_shared * b_shared; + c.open().await + }, + prep1, + prep2, + ) + .await; + + assert_eq!(res, expected); + } +} diff --git a/online-phase/src/offline_prep.rs b/online-phase/src/offline_prep.rs index c2c3d96..8c397df 100644 --- a/online-phase/src/offline_prep.rs +++ b/online-phase/src/offline_prep.rs @@ -56,6 +56,7 @@ pub trait OfflinePhase: Send + Sync { (a_vals, b_vals, c_vals) } } + /// An implementation of a beaver value source that returns /// beaver triples (0, 0, 0) for party 0 and (1, 1, 1) for party 1 #[cfg(any(feature = "test_helpers", test))]