diff --git a/ipa-core/src/ff/boolean.rs b/ipa-core/src/ff/boolean.rs index ec6d9f6add..b503082bf2 100644 --- a/ipa-core/src/ff/boolean.rs +++ b/ipa-core/src/ff/boolean.rs @@ -5,7 +5,10 @@ use super::Gf32Bit; use crate::{ ff::{Field, Serializable}, protocol::prss::FromRandomU128, - secret_sharing::{replicated::malicious::ExtendableField, Block, SharedValue}, + secret_sharing::{ + replicated::malicious::ExtendableField, Block, FieldVectorizable, SharedValue, StdArray, + Vectorizable, + }, }; impl Block for bool { @@ -30,6 +33,14 @@ impl SharedValue for Boolean { const ZERO: Self = Self(false); } +impl Vectorizable<1> for Boolean { + type Array = StdArray; +} + +impl FieldVectorizable<1> for Boolean { + type T = StdArray; +} + ///conversion to Scalar struct of `curve25519_dalek` impl From for bool { fn from(s: Boolean) -> Self { diff --git a/ipa-core/src/ff/boolean_array.rs b/ipa-core/src/ff/boolean_array.rs index ea93cabb93..b2fffcd6cd 100644 --- a/ipa-core/src/ff/boolean_array.rs +++ b/ipa-core/src/ff/boolean_array.rs @@ -8,7 +8,7 @@ use typenum::{U14, U2, U32, U8}; use crate::{ ff::{boolean::Boolean, ArrayAccess, Field, Serializable}, protocol::prss::{FromRandom, FromRandomU128}, - secret_sharing::{Block, SharedValue}, + secret_sharing::{Block, FieldVectorizable, SharedValue, StdArray, Vectorizable}, }; /// The implementation below cannot be constrained without breaking Rust's @@ -152,6 +152,10 @@ macro_rules! boolean_array_impl_small { Field::truncate_from(src) } } + + impl FieldVectorizable<1> for $name { + type T = StdArray<$name, 1>; + } }; } @@ -272,6 +276,10 @@ macro_rules! 
boolean_array_impl { } } + impl Vectorizable<1> for $name { + type Array = StdArray<$name, 1>; + } + impl std::ops::Mul for $name { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { diff --git a/ipa-core/src/ff/curve_points.rs b/ipa-core/src/ff/curve_points.rs index 3e5a97b691..614e27029c 100644 --- a/ipa-core/src/ff/curve_points.rs +++ b/ipa-core/src/ff/curve_points.rs @@ -7,7 +7,7 @@ use typenum::U32; use crate::{ ff::{ec_prime_field::Fp25519, Serializable}, - secret_sharing::{Block, SharedValue}, + secret_sharing::{Block, SharedValue, StdArray, Vectorizable}, }; impl Block for CompressedRistretto { @@ -35,6 +35,10 @@ impl SharedValue for RP25519 { const ZERO: Self = Self(CompressedRistretto([0_u8; 32])); } +impl Vectorizable<1> for RP25519 { + type Array = StdArray; +} + impl Serializable for RP25519 { type Size = <::Storage as Block>::Size; diff --git a/ipa-core/src/ff/ec_prime_field.rs b/ipa-core/src/ff/ec_prime_field.rs index 0077f8720b..d63929fa97 100644 --- a/ipa-core/src/ff/ec_prime_field.rs +++ b/ipa-core/src/ff/ec_prime_field.rs @@ -7,7 +7,7 @@ use typenum::U32; use crate::{ ff::{boolean_array::BA256, Field, Serializable}, protocol::prss::FromRandomU128, - secret_sharing::{Block, SharedValue}, + secret_sharing::{Block, FieldVectorizable, SharedValue, StdArray, Vectorizable}, }; impl Block for Scalar { @@ -172,6 +172,14 @@ macro_rules! 
sc_hash_impl { #[cfg(test)] sc_hash_impl!(u64); +impl Vectorizable<1> for Fp25519 { + type Array = StdArray; +} + +impl FieldVectorizable<1> for Fp25519 { + type T = StdArray; +} + ///implement Field because required by PRSS impl Field for Fp25519 { const ONE: Fp25519 = Fp25519::ONE; diff --git a/ipa-core/src/ff/field.rs b/ipa-core/src/ff/field.rs index 5535ed8333..58ac91b030 100644 --- a/ipa-core/src/ff/field.rs +++ b/ipa-core/src/ff/field.rs @@ -8,7 +8,7 @@ use typenum::{U1, U4}; use crate::{ error, protocol::prss::FromRandomU128, - secret_sharing::{Block, SharedValue}, + secret_sharing::{Block, FieldVectorizable, SharedValue, Vectorizable}, }; impl Block for u8 { @@ -29,6 +29,8 @@ pub trait Field: + FromRandomU128 + TryFrom + Into + + Vectorizable<1> + + FieldVectorizable<1, T = >::Array> { /// Multiplicative identity element const ONE: Self; diff --git a/ipa-core/src/ff/galois_field.rs b/ipa-core/src/ff/galois_field.rs index 0808e307a5..c068569ebb 100644 --- a/ipa-core/src/ff/galois_field.rs +++ b/ipa-core/src/ff/galois_field.rs @@ -14,7 +14,7 @@ use super::ArrayAccess; use crate::{ ff::{Field, Serializable}, protocol::prss::FromRandomU128, - secret_sharing::{Block, SharedValue}, + secret_sharing::{Block, FieldVectorizable, SharedValue, Vectorizable}, }; /// Trait for data types storing arbitrary number of bits. @@ -168,6 +168,14 @@ macro_rules! 
bit_array_impl { const ZERO: Self = Self(<$store>::ZERO); } + impl Vectorizable<1> for $name { + type Array = crate::secret_sharing::StdArray<$name, 1>; + } + + impl FieldVectorizable<1> for $name { + type T = crate::secret_sharing::StdArray<$name, 1>; + } + impl Field for $name { const ONE: Self = Self($one); @@ -685,5 +693,31 @@ bit_array_impl!( v } } + + impl From for bool { + fn from(value: Gf2) -> Self { + value != Gf2::ZERO + } + } + + impl From for Gf2 { + fn from(value: crate::ff::boolean::Boolean) -> Self { + bool::from(value).into() + } + } + + impl From for crate::ff::boolean::Boolean { + fn from(value: Gf2) -> Self { + bool::from(value).into() + } + } + + impl std::ops::Not for Gf2 { + type Output = Self; + + fn not(self) -> Self { + (!bool::from(self)).into() + } + } } ); diff --git a/ipa-core/src/ff/prime_field.rs b/ipa-core/src/ff/prime_field.rs index d528ac9294..a67a2309f4 100644 --- a/ipa-core/src/ff/prime_field.rs +++ b/ipa-core/src/ff/prime_field.rs @@ -4,7 +4,7 @@ use super::Field; use crate::{ ff::Serializable, protocol::prss::FromRandomU128, - secret_sharing::{Block, SharedValue}, + secret_sharing::{Block, FieldVectorizable, SharedValue, StdArray, Vectorizable}, }; pub trait PrimeField: Field { @@ -43,6 +43,14 @@ macro_rules! field_impl { const ZERO: Self = $field(0); } + impl Vectorizable<1> for $field { + type Array = StdArray<$field, 1>; + } + + impl FieldVectorizable<1> for $field { + type T = StdArray<$field, 1>; + } + impl Field for $field { const ONE: Self = $field(1); @@ -270,6 +278,14 @@ mod fp31 { mod fp32bit { field_impl! 
{ Fp32BitPrime, u32, 32, 4_294_967_291 } + impl Vectorizable<32> for Fp32BitPrime { + type Array = StdArray; + } + + impl FieldVectorizable<32> for Fp32BitPrime { + type T = StdArray; + } + #[cfg(all(test, unit_test))] mod specialized_tests { use super::*; diff --git a/ipa-core/src/protocol/basics/mul/mod.rs b/ipa-core/src/protocol/basics/mul/mod.rs index ed98e9d0b2..b8924343a7 100644 --- a/ipa-core/src/protocol/basics/mul/mod.rs +++ b/ipa-core/src/protocol/basics/mul/mod.rs @@ -7,9 +7,12 @@ use crate::{ context::{Context, UpgradedMaliciousContext}, RecordId, }, - secret_sharing::replicated::{ - malicious::{AdditiveShare as MaliciousReplicated, ExtendableField}, - semi_honest::AdditiveShare as Replicated, + secret_sharing::{ + replicated::{ + malicious::{AdditiveShare as MaliciousReplicated, ExtendableField}, + semi_honest::AdditiveShare as Replicated, + }, + FieldSimd, }, }; @@ -52,7 +55,11 @@ use {malicious::multiply as malicious_mul, semi_honest::multiply as semi_honest_ /// Implement secure multiplication for semi-honest contexts with replicated secret sharing. #[async_trait] -impl SecureMul for Replicated { +impl SecureMul for Replicated +where + C: Context, + F: Field + FieldSimd, +{ async fn multiply_sparse<'fut>( &self, rhs: &Self, diff --git a/ipa-core/src/protocol/basics/mul/semi_honest.rs b/ipa-core/src/protocol/basics/mul/semi_honest.rs index 25de869460..da4753350f 100644 --- a/ipa-core/src/protocol/basics/mul/semi_honest.rs +++ b/ipa-core/src/protocol/basics/mul/semi_honest.rs @@ -8,8 +8,9 @@ use crate::{ prss::SharedRandomness, RecordId, }, - secret_sharing::replicated::{ - semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing, + secret_sharing::{ + replicated::semi_honest::AdditiveShare as Replicated, FieldSimd, SharedValueArray, + Vectorizable, }, }; @@ -26,16 +27,16 @@ use crate::{ /// ## Errors /// Lots of things may go wrong here, from timeouts to bad output. 
They will be signalled /// back via the error response -pub async fn multiply( +pub async fn multiply( ctx: C, record_id: RecordId, - a: &Replicated, - b: &Replicated, + a: &Replicated, + b: &Replicated, zeros: MultiplyZeroPositions, -) -> Result, Error> +) -> Result, Error> where C: Context, - F: Field, + F: Field + FieldSimd, { let role = ctx.role(); let [need_to_recv, need_to_send, need_random_right] = zeros.work_for(role); @@ -43,19 +44,27 @@ where zeros.1.check(role, "b", b); // Shared randomness used to mask the values that are sent. - let (s0, s1) = ctx.prss().generate(record_id); + let (s0, s1) = ctx + .prss() + .generate::<(>::Array, _), _>(record_id); + + let mut rhs = a.right_arr().clone() * b.right_arr(); - let mut rhs = a.right() * b.right(); if need_to_send { // Compute the value (d_i) we want to send to the right helper (i+1). - let right_d = a.left() * b.right() + a.right() * b.left() - s0; + let right_d = + a.left_arr().clone() * b.right_arr() + a.right_arr().clone() * b.left_arr() - &s0; + // TODO: can we make `send` take a reference to the message? ctx.send_channel(role.peer(Direction::Right)) - .send(record_id, right_d) + .send(record_id, right_d.clone()) .await?; rhs += right_d; } else { - debug_assert_eq!(a.left() * b.right() + a.right() * b.left(), F::ZERO); + debug_assert_eq!( + a.left_arr().clone() * b.right_arr() + a.right_arr().clone() * b.left_arr(), + <>::Array as SharedValueArray>::ZERO + ); } // Add randomness to this value whether we sent or not, depending on whether the // peer to the right needed to send. If they send, they subtract randomness, @@ -65,9 +74,9 @@ where } // Sleep until helper on the left sends us their (d_i-1) value. 
- let mut lhs = a.left() * b.left(); + let mut lhs = a.left_arr().clone() * b.left_arr(); if need_to_recv { - let left_d = ctx + let left_d: >::Array = ctx .recv_channel(role.peer(Direction::Left)) .receive(record_id) .await?; @@ -78,21 +87,32 @@ where lhs += s0; } - Ok(Replicated::new(lhs, rhs)) + Ok(Replicated::new_arr(lhs, rhs)) } #[cfg(all(test, unit_test))] mod test { - use std::iter::{repeat, zip}; + use std::{ + array, + iter::{repeat, zip}, + time::Instant, + }; use rand::distributions::{Distribution, Standard}; + use super::multiply; use crate::{ - ff::{Field, Fp31}, - protocol::{basics::SecureMul, context::Context, RecordId}, + ff::{Field, Fp31, Fp32BitPrime}, + helpers::TotalRecords, + protocol::{ + basics::{SecureMul, ZeroPositions}, + context::Context, + RecordId, + }, rand::{thread_rng, Rng}, + secret_sharing::replicated::semi_honest::AdditiveShare, seq_join::SeqJoin, - test_fixture::{Reconstruct, Runner, TestWorld}, + test_fixture::{Reconstruct, ReconstructArr, Runner, TestWorld}, }; #[tokio::test] @@ -182,4 +202,136 @@ mod test { result.reconstruct().as_u128() } + + const MANYMULT_ITERS: usize = 16384; + const MANYMULT_WIDTH: usize = 32; + + #[tokio::test] + pub async fn wide_mul() { + const COUNT: usize = 32; + let world = TestWorld::default(); + + let mut rng = thread_rng(); + let a: [Fp32BitPrime; COUNT] = (0..COUNT) + .map(|_| rng.gen::()) + .collect::>() + .try_into() + .unwrap(); + let b: [Fp32BitPrime; COUNT] = (0..COUNT) + .map(|_| rng.gen::()) + .collect::>() + .try_into() + .unwrap(); + let expected: [Fp32BitPrime; COUNT] = zip(a.iter(), b.iter()) + .map(|(&a, &b)| a * b) + .collect::>() + .try_into() + .unwrap(); + let results = world + .semi_honest((a, b), |ctx, (a_shares, b_shares)| async move { + multiply( + ctx.set_total_records(1), + RecordId::from(0), + &a_shares, + &b_shares, + ZeroPositions::NONE, + ) + .await + .unwrap() + }) + .await; + assert_eq!(expected, results.reconstruct_arr()); + } + + #[tokio::test] + pub async fn 
manymult_novec() { + let world = TestWorld::default(); + let mut rng = thread_rng(); + let mut inputs = Vec::>::new(); + for _ in 0..MANYMULT_ITERS { + inputs.push( + (0..MANYMULT_WIDTH) + .map(|_| Fp32BitPrime::try_from(rng.gen_range(0u32..100) as u128).unwrap()) + .collect::>(), + ); + } + let expected = inputs + .iter() + .fold(None, |acc: Option>, b| match acc { + Some(a) => Some(a.iter().zip(b.iter()).map(|(&a, &b)| a * b).collect()), + None => Some(b.to_vec()), + }) + .unwrap(); + + let begin = Instant::now(); + let result = world + .semi_honest( + inputs.into_iter().map(IntoIterator::into_iter), + |ctx, share: Vec>>| async move { + let ctx = ctx.set_total_records(MANYMULT_ITERS * MANYMULT_WIDTH); + let mut iter = share.iter(); + let mut val = iter.next().unwrap().clone(); + for i in 1..MANYMULT_ITERS.try_into().unwrap() { + let cur = iter.next().unwrap(); + let mut res = Vec::with_capacity(MANYMULT_WIDTH); + for j in 0..MANYMULT_WIDTH { + //res.push(ctx.clone().multiply(RecordId::from(MANYMULT_WIDTH * (i - 1) + j), &val[j], &cur[j])); + res.push(val[j].multiply( + &cur[j], + ctx.clone(), + RecordId::from(MANYMULT_WIDTH * (i - 1) + j), + )); + } + val = ctx.parallel_join(res).await.unwrap(); + } + val + }, + ) + .await; + tracing::info!("Protocol execution time: {:?}", begin.elapsed()); + assert_eq!(expected, result.reconstruct()); + } + + #[tokio::test] + pub async fn manymult_vec() { + let world = TestWorld::default(); + let mut rng = thread_rng(); + let mut inputs = Vec::<[Fp32BitPrime; MANYMULT_WIDTH]>::new(); + for _ in 0..MANYMULT_ITERS { + inputs.push(array::from_fn(|_| rng.gen())); + } + let expected = inputs + .iter() + .fold(None, |acc: Option>, b| match acc { + Some(a) => Some(a.iter().zip(b.iter()).map(|(&a, &b)| a * b).collect()), + None => Some(b.to_vec()), + }) + .unwrap(); + + let begin = Instant::now(); + let result = world + .semi_honest( + inputs.into_iter(), + |ctx, share: Vec>| async move { + let ctx = 
ctx.set_total_records(TotalRecords::Indeterminate); + let mut iter = share.iter(); + let mut val = iter.next().unwrap().clone(); + for i in 1..MANYMULT_ITERS.try_into().unwrap() { + val = multiply( + ctx.clone(), + RecordId::from(i - 1), + &val, + iter.next().unwrap(), + ZeroPositions::NONE, + ) + .await + .unwrap(); + } + val + }, + ) + .await; + tracing::info!("Protocol execution time: {:?}", begin.elapsed()); + assert_eq!(expected, result.reconstruct_arr()); + } } diff --git a/ipa-core/src/protocol/basics/mul/sparse.rs b/ipa-core/src/protocol/basics/mul/sparse.rs index 9f1ad99431..878199b237 100644 --- a/ipa-core/src/protocol/basics/mul/sparse.rs +++ b/ipa-core/src/protocol/basics/mul/sparse.rs @@ -1,5 +1,8 @@ +#[cfg_attr(not(debug_assertions), allow(unused_variables))] +use crate::secret_sharing::Vectorizable; use crate::{ - ff::Field, helpers::Role, secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, + helpers::Role, + secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, SharedValue}, }; /// A description of a replicated secret sharing, with zero values at known positions. @@ -105,25 +108,28 @@ impl ZeroPositions { /// # Panics /// When the input value includes a non-zero value in a position marked as having a zero. 
#[cfg_attr(not(debug_assertions), allow(unused_variables))] - pub fn check(self, role: Role, which: &str, v: &Replicated) { + pub fn check, const N: usize>( + self, + role: Role, + which: &str, + v: &Replicated, + ) { #[cfg(debug_assertions)] { - use crate::{ - helpers::Direction::Right, secret_sharing::replicated::ReplicatedSecretSharing, - }; + use crate::{helpers::Direction::Right, secret_sharing::SharedValueArray}; let flags = <[bool; 3]>::from(self); if flags[role as usize] { assert_eq!( - F::ZERO, - v.left(), + &>::Array::ZERO, + v.left_arr(), "expected a zero on the left for input {which}" ); } if flags[role.peer(Right) as usize] { assert_eq!( - F::ZERO, - v.right(), + &>::Array::ZERO, + v.right_arr(), "expected a zero on the right for input {which}" ); } diff --git a/ipa-core/src/protocol/basics/share_known_value.rs b/ipa-core/src/protocol/basics/share_known_value.rs index bd811a3dfc..cd36f33e95 100644 --- a/ipa-core/src/protocol/basics/share_known_value.rs +++ b/ipa-core/src/protocol/basics/share_known_value.rs @@ -12,6 +12,10 @@ use crate::{ }, }; +/// Produce a share of some pre-determined constant. +/// +/// The context is only used to determine the helper role. It is not used for communication or PRSS, +/// and it is not necessary to use a uniquely narrowed context. 
pub trait ShareKnownValue { fn share_known_value(ctx: &C, value: V) -> Self; } diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/comparison_and_subtraction_sequential.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/comparison_and_subtraction_sequential.rs index 54815ff13f..2b54d294fc 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/comparison_and_subtraction_sequential.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/comparison_and_subtraction_sequential.rs @@ -6,7 +6,12 @@ use crate::ff::Expand; use crate::{ error::Error, ff::{ArrayAccess, CustomArray, Field}, - protocol::{basics::SecureMul, context::Context, step::BitOpStep, RecordId}, + protocol::{ + basics::{SecureMul, ShareKnownValue}, + context::Context, + step::BitOpStep, + RecordId, + }, secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue}, }; @@ -32,11 +37,11 @@ where C: Context, YS: SharedValue + CustomArray, XS: SharedValue + CustomArray + Field, - XS::Element: Field + std::ops::Not, + XS::Element: Field, + AdditiveShare: std::ops::Not>, { // we need to initialize carry to 1 for x>=y, - // since there are three shares 1+1+1 = 1 mod 2, so setting left = 1 and right = 1 works - let mut carry = AdditiveShare(XS::Element::ONE, XS::Element::ONE); + let mut carry = AdditiveShare::::share_known_value(&ctx, XS::Element::ONE); // we don't care about the subtraction, we just want the carry let _ = subtraction_circuit(ctx, record_id, x, y, &mut carry).await; Ok(carry) @@ -56,7 +61,8 @@ where C: Context, YS: SharedValue + CustomArray, XS: SharedValue + CustomArray + Field, - XS::Element: Field + std::ops::Not, + XS::Element: Field, + AdditiveShare: std::ops::Not>, { // we need to initialize carry to 0 for x>y let mut carry = AdditiveShare::::ZERO; @@ -80,10 +86,11 @@ where C: Context, YS: SharedValue + CustomArray, XS: SharedValue + CustomArray + Field, - XS::Element: Field + std::ops::Not, + XS::Element: Field, + AdditiveShare: std::ops::Not>, { // we need to initialize carry to 1 for a 
subtraction - let mut carry = AdditiveShare(XS::Element::ONE, XS::Element::ONE); + let mut carry = AdditiveShare::::share_known_value(&ctx, XS::Element::ONE); subtraction_circuit(ctx, record_id, x, y, &mut carry).await } @@ -102,9 +109,10 @@ pub async fn integer_sat_sub( where C: Context, S: CustomArray + Field, - S::Element: Field + std::ops::Not, + S::Element: Field, + AdditiveShare: std::ops::Not>, { - let mut carry = AdditiveShare(S::Element::ONE, S::Element::ONE); + let mut carry = AdditiveShare::::share_known_value(&ctx, S::Element::ONE); let result = subtraction_circuit( ctx.narrow(&Step::SaturatedSubtraction), record_id, @@ -139,7 +147,8 @@ where C: Context, XS: SharedValue + CustomArray, YS: SharedValue + CustomArray, - XS::Element: Field + std::ops::Not, + XS::Element: Field, + AdditiveShare: std::ops::Not>, { let mut result = AdditiveShare::::ZERO; for (i, v) in x.iter().enumerate() { @@ -182,7 +191,8 @@ async fn bit_subtractor( ) -> Result, Error> where C: Context, - S: Field + std::ops::Not, + S: Field, + AdditiveShare: std::ops::Not>, { let output = x + !(y.unwrap_or(&AdditiveShare::::ZERO) + &*carry); @@ -228,27 +238,27 @@ mod test { assert_eq!(::ONE, !(::ZERO)); assert_eq!(::ZERO, !(::ONE)); assert_eq!( - AdditiveShare(::ZERO, ::ZERO), - !AdditiveShare(::ONE, ::ONE) + AdditiveShare::([::ZERO].into(), [::ZERO].into()), + !AdditiveShare([::ONE].into(), [::ONE].into()) ); assert_eq!( - AdditiveShare( - ::expand(&::ZERO), - ::expand(&::ZERO) + AdditiveShare::( + [::expand(&::ZERO)].into(), + [::expand(&::ZERO)].into(), ), - !AdditiveShare( - ::expand(&::ONE), - ::expand(&::ONE) + !AdditiveShare::( + [::expand(&::ONE)].into(), + [::expand(&::ONE)].into(), ) ); assert_eq!( - !AdditiveShare( - ::expand(&::ZERO), - ::expand(&::ZERO) + !AdditiveShare::( + [::expand(&::ZERO)].into(), + [::expand(&::ZERO)].into(), ), - AdditiveShare( - ::expand(&::ONE), - ::expand(&::ONE) + AdditiveShare::( + [::expand(&::ONE)].into(), + [::expand(&::ONE)].into(), ) ); } diff 
--git a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs index b5f06147d6..ae30efa231 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs @@ -15,7 +15,7 @@ use crate::{ }, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, - SharedValue, + SharedValue, SharedValueArray, StdArray, }, }; @@ -122,16 +122,16 @@ where // sh_s: H1: (r1,0), H2: (0,0), H3: (0, r1) match ctx.role() { Role::H1 => ( - AdditiveShare(::ZERO, ::ZERO), - AdditiveShare(r.0, ::ZERO), + AdditiveShare(>::ZERO, >::ZERO), + AdditiveShare(r.0.into(), >::ZERO), ), Role::H2 => ( - AdditiveShare(::ZERO, r.1), - AdditiveShare(::ZERO, ::ZERO), + AdditiveShare(>::ZERO, r.1.into()), + AdditiveShare(>::ZERO, >::ZERO), ), Role::H3 => ( - AdditiveShare(r.0, ::ZERO), - AdditiveShare(::ZERO, r.1), + AdditiveShare(r.0.into(), >::ZERO), + AdditiveShare(>::ZERO, r.1.into()), ), } }; @@ -319,7 +319,8 @@ mod tests { let a = rng.gen::(); - let shared_a = AdditiveShare::(rng.gen::(), rng.gen::()); + let shared_a = + AdditiveShare::([rng.gen::()].into(), [rng.gen::()].into()); let b = expand_array::<_, BA256>(&a, None); diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index ccaa667814..606907eea0 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -35,7 +35,7 @@ use crate::{ malicious::ExtendableField, semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing, }, - BitDecomposed, Linear as LinearSecretSharing, SharedValue, + BitDecomposed, Linear as LinearSecretSharing, SharedValue, SharedValueArray, }, seq_join::{seq_join, SeqJoin}, }; @@ -74,8 +74,14 @@ where ); let mut offset = BA7::BITS as usize; - self.sort_key.0.set(offset, self.is_trigger_bit.left()); - 
self.sort_key.1.set(offset, self.is_trigger_bit.right()); + self.sort_key + .0 + .index(0) + .set(offset, self.is_trigger_bit.left()); + self.sort_key + .1 + .index(0) + .set(offset, self.is_trigger_bit.right()); offset += 1; expand_shared_array_in_place(&mut self.sort_key, &self.timestamp, offset); @@ -262,15 +268,35 @@ impl< if i < bk_bits { BitConversionTriple::new( role, - self.attributed_breakdown_key_bits.0.get(i).unwrap() == Boolean::ONE, - self.attributed_breakdown_key_bits.1.get(i).unwrap() == Boolean::ONE, + self.attributed_breakdown_key_bits + .0 + .index(0) + .get(i) + .unwrap() + == Boolean::ONE, + self.attributed_breakdown_key_bits + .1 + .index(0) + .get(i) + .unwrap() + == Boolean::ONE, ) } else { let i = i - bk_bits; BitConversionTriple::new( role, - self.capped_attributed_trigger_value.0.get(i).unwrap() == Boolean::ONE, - self.capped_attributed_trigger_value.1.get(i).unwrap() == Boolean::ONE, + self.capped_attributed_trigger_value + .0 + .index(0) + .get(i) + .unwrap() + == Boolean::ONE, + self.capped_attributed_trigger_value + .1 + .index(0) + .get(i) + .unwrap() + == Boolean::ONE, ) } } diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index f9432db6ae..805fad3883 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -11,7 +11,7 @@ use crate::{ report::OprfReport, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, - SharedValue, + SharedValue, SharedValueArray, }, }; @@ -57,8 +57,8 @@ where let mut offset = BA64::BITS as usize; - y.0.set(offset, input.is_trigger.left()); - y.1.set(offset, input.is_trigger.right()); + y.0.index(0).set(offset, input.is_trigger.left()); + y.1.index(0).set(offset, input.is_trigger.right()); offset += 1; diff --git a/ipa-core/src/secret_sharing/array.rs b/ipa-core/src/secret_sharing/array.rs new file mode 100644 index 0000000000..18c9b7d04f --- /dev/null +++ 
b/ipa-core/src/secret_sharing/array.rs @@ -0,0 +1,283 @@ +use std::{ + array, + fmt::Debug, + ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}, +}; + +use generic_array::{ArrayLength, GenericArray}; +use typenum::{U1, U32, U8}; + +use crate::{ + ff::{boolean::Boolean, boolean_array::BA64, Field, Fp32BitPrime, Serializable}, + helpers::Message, + protocol::prss::FromRandom, + secret_sharing::{FieldArray, SharedValue, SharedValueArray}, +}; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct StdArray([V; N]); + +impl From> for [V; N] { + fn from(value: StdArray) -> Self { + value.0 + } +} + +impl From<[V; N]> for StdArray { + fn from(value: [V; N]) -> Self { + Self(value) + } +} + +impl SharedValueArray for StdArray { + const ZERO: Self = Self([V::ZERO; N]); + + fn index(&self, index: usize) -> V { + self.0[index] + } + + fn from_item(item: V) -> Self { + let mut res = Self::ZERO; + res.0[0] = item; + res + } +} + +impl FieldArray for StdArray {} + +impl TryFrom> for StdArray { + type Error = (); + fn try_from(value: Vec) -> Result { + value.try_into().map(Self).map_err(|_| ()) + } +} + +impl<'a, 'b, V: SharedValue, const N: usize> Add<&'b StdArray> for &'a StdArray { + type Output = StdArray; + + fn add(self, rhs: &'b StdArray) -> Self::Output { + StdArray(array::from_fn(|i| self.0[i] + rhs.0[i])) + } +} + +impl Add for StdArray { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Add::add(&self, &rhs) + } +} + +// add(owned, ref) should be preferred over this. 
+impl Add> for &StdArray { + type Output = StdArray; + + fn add(self, rhs: StdArray) -> Self::Output { + Add::add(self, &rhs) + } +} + +impl Add<&StdArray> for StdArray { + type Output = Self; + + fn add(self, rhs: &Self) -> Self::Output { + Add::add(&self, rhs) + } +} + +impl AddAssign<&Self> for StdArray { + fn add_assign(&mut self, rhs: &Self) { + for (a, b) in self.0.iter_mut().zip(rhs.0.iter()) { + *a += *b; + } + } +} + +impl AddAssign for StdArray { + fn add_assign(&mut self, rhs: Self) { + AddAssign::add_assign(self, &rhs); + } +} + +impl Neg for &StdArray { + type Output = StdArray; + + fn neg(self) -> Self::Output { + StdArray(array::from_fn(|i| -self.0[i])) + } +} + +impl Neg for StdArray { + type Output = Self; + + fn neg(self) -> Self::Output { + Neg::neg(&self) + } +} + +impl Sub for &StdArray { + type Output = StdArray; + + fn sub(self, rhs: Self) -> Self::Output { + StdArray(array::from_fn(|i| self.0[i] - rhs.0[i])) + } +} + +impl Sub for StdArray { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Sub::sub(&self, &rhs) + } +} + +impl Sub<&Self> for StdArray { + type Output = Self; + + fn sub(self, rhs: &Self) -> Self::Output { + Sub::sub(&self, rhs) + } +} + +impl Sub> for &StdArray { + type Output = StdArray; + + fn sub(self, rhs: StdArray) -> Self::Output { + Sub::sub(self, &rhs) + } +} + +impl SubAssign<&Self> for StdArray { + fn sub_assign(&mut self, rhs: &Self) { + for (a, b) in self.0.iter_mut().zip(rhs.0.iter()) { + *a -= *b; + } + } +} + +impl SubAssign for StdArray { + fn sub_assign(&mut self, rhs: Self) { + SubAssign::sub_assign(self, &rhs); + } +} + +impl<'a, 'b, F: Field, const N: usize> Mul<&'b F> for &'a StdArray { + type Output = StdArray; + + fn mul(self, rhs: &'b F) -> Self::Output { + StdArray(array::from_fn(|i| self.0[i] * *rhs)) + } +} + +impl Mul for StdArray { + type Output = Self; + + fn mul(self, rhs: F) -> Self::Output { + Mul::mul(&self, &rhs) + } +} + +impl Mul<&F> for StdArray { + type Output = Self; 
+ + fn mul(self, rhs: &F) -> Self::Output { + Mul::mul(&self, rhs) + } +} + +impl Mul for &StdArray { + type Output = StdArray; + + fn mul(self, rhs: F) -> Self::Output { + Mul::mul(self, &rhs) + } +} + +impl<'a, F: Field, const N: usize> Mul<&'a StdArray> for StdArray { + type Output = StdArray; + + fn mul(self, rhs: &'a StdArray) -> Self::Output { + StdArray(array::from_fn(|i| self.0[i] * rhs.0[i])) + } +} + +impl std::ops::Not for StdArray { + type Output = StdArray; + + fn not(self) -> Self::Output { + StdArray(array::from_fn(|i| !self.0[i])) + } +} + +impl std::ops::Not for StdArray { + type Output = StdArray; + + fn not(self) -> Self::Output { + StdArray([!self.0[0]]) + } +} + +impl> FromRandom for StdArray { + type SourceLength = U1; + fn from_random(src: GenericArray) -> Self { + Self([F::from_random(src)]) + } +} + +impl FromRandom for StdArray { + type SourceLength = U8; + + fn from_random(src: GenericArray) -> Self { + // TODO: reduce mod p + const WORDS_PER_U128: u32 = 4; + const WORDS: usize = 32; + let mut res = Vec::with_capacity(WORDS); + for word in src { + for j in 0..WORDS_PER_U128 { + res.push(Fp32BitPrime::truncate_from::( + (word >> (j * Fp32BitPrime::BITS)) & u128::from(u32::MAX), + )); + } + } + res.try_into().unwrap() + } +} + +impl Serializable for StdArray { + type Size = ::Size; + + fn serialize(&self, buf: &mut GenericArray) { + self.0[0].serialize(buf); + } + + fn deserialize(buf: &GenericArray) -> Self { + StdArray([V::deserialize(buf)]) + } +} + +impl Serializable for StdArray +where + V: SharedValue, + ::Size: Mul, + <::Size as Mul>::Output: ArrayLength, +{ + type Size = <::Size as Mul>::Output; + + fn serialize(&self, buf: &mut GenericArray) { + let sz: usize = (::BITS / 8).try_into().unwrap(); + for i in 0..32 { + self.0[i].serialize( + &mut GenericArray::try_from_mut_slice(&mut buf[sz * i..sz * (i + 1)]).unwrap(), + ); + } + } + + fn deserialize(buf: &GenericArray) -> Self { + let sz: usize = (::BITS / 8).try_into().unwrap(); + 
Self(array::from_fn(|i| { + V::deserialize(&GenericArray::from_slice(&buf[sz * i..sz * (i + 1)])) + })) + } +} + +impl Message for StdArray where Self: Serializable {} diff --git a/ipa-core/src/secret_sharing/mod.rs b/ipa-core/src/secret_sharing/mod.rs index 42a62ca555..2c6157cce1 100644 --- a/ipa-core/src/secret_sharing/mod.rs +++ b/ipa-core/src/secret_sharing/mod.rs @@ -1,14 +1,16 @@ pub mod replicated; +mod array; mod decomposed; mod into_shares; mod scheme; use std::{ fmt::Debug, - ops::{Mul, MulAssign, Neg}, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +pub use array::StdArray; pub use decomposed::BitDecomposed; use generic_array::ArrayLength; pub use into_shares::IntoShares; @@ -21,7 +23,11 @@ use rand::{ use replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}; pub use scheme::{Bitwise, Linear, LinearRefOps, SecretSharing}; -use crate::ff::{AddSub, AddSubAssign, Serializable}; +use crate::{ + ff::{boolean::Boolean, AddSub, AddSubAssign, Field, Fp32BitPrime, Gf2, Serializable}, + helpers::Message, + protocol::prss::FromRandom, +}; /// Operations supported for weak shared values. pub trait Additive: @@ -57,7 +63,17 @@ pub trait Block: Sized + Copy + Debug { /// (capable of supporting addition and multiplication) is desired, the `Field` trait extends /// `SharedValue` to require multiplication. pub trait SharedValue: - Clone + Copy + Eq + Debug + Send + Sync + Sized + Additive + Serializable + 'static + Clone + + Copy + + Eq + + Debug + + Send + + Sync + + Sized + + Additive + + Serializable + + Vectorizable<1> + + 'static { type Storage: Block; @@ -66,6 +82,162 @@ pub trait SharedValue: const ZERO: Self; } +// # Vectorization +// +// Vectorization refers to adapting an implementation that previously operated on one value at a +// time, to instead operate on `N` values at a time. Vectorization improves performance in two ways: +// +// 1. 
Vectorized code can make use of special CPU instructions (Intel AVX, ARM NEON) that operate +// on multiple values at a time. This reduces the CPU time required to perform computations. +// We also use vectorization to refer to "bit packing" of boolean values, i.e., packing +// 64 boolean values into a single u64 rather than using a byte (or even a word) for each +// value. +// 2. Aside from the core arithmetic operations that are involved in our MPC, a substantial +// amount of other code is needed to send values between helpers, schedule futures for +// execution, etc. Vectorization can result in a greater amount of arithmetic work being +// performed for a given amount of overhead work, thus increasing the efficiency of the +// implementation. +// +// ## Vectorization traits +// +// There are two sets of traits related to vectorization. +// +// If you are writing protocols, the traits of interest are `SharedValueSimd` and `FieldSimd`. +// These can be specified in a trait bound, something like `F: Field + FieldSimd`. +// +// The other traits are `Vectorizable` (for `SharedValue`s) and `FieldVectorizable`. These traits +// are needed to work around a limitation in the rust type system. +// +// ## Adding a new supported vectorization +// +// Currently, each vectorization configuration (combination of data type being vectorized and +// vectorization width) must be explicitly implemented. The primary reason this is necessary +// is that Rust doesn't yet support evaluating expressions involving const parameters at compile +// time. +// +// 1. Add `FieldSimd` impl (secret_sharing/mod.rs) +// 2. Add `FromRandom` impl (array.rs or gf2_array.rs) +// 3. Add `Serializable` impl (array.rs or gf2_array.rs) +// 4. Add `Into<[Gf2; N]>` impl (array.rs or gf2_array.rs) +// 4. Add `Vectorizable` and `FieldVectorizable` impl (primitive type def e.g. galois_field.rs) + +/// Trait for `SharedValue`s supporting operations on `N`-wide vectors. 
+pub trait Vectorizable: Sized { + // There are two (three?) kinds of bounds here: + // 1. Bounds that apply to the array type for vectorized operation, but not universally to + // `SharedValue::Array`. + // 2. Bounds that apply universally to `SharedValue::Array`, but are replicated here due + // to a compiler limitation. + // 3. Field vs. SharedValue + // https://github.com/rust-lang/rust/issues/41118 + type Array: Message + SharedValueArray + Clone + Eq + Send + Sync; +} + +// TODO: Question: What to do with this? +// When SharedValue had the Array associated type, both Vectorizable and FieldVectorizable +// had an associated type T, which was only used to impose further trait bounds on the array. +// Now, Vectorizable::Array is the canonical array type. +/// Trait for `Field`s supporting operations on `N`-wide vectors. +pub trait FieldVectorizable: SharedValue { + // There are two (three?) kinds of bounds here: + // 1. Bounds that apply to the array type for vectorized operation, but not universally to + // `SharedValue::Array`. + // 2. Bounds that apply universally to `SharedValue::Array`, but are replicated here due + // to a compiler limitation. + // 3. Field vs. SharedValue + // https://github.com/rust-lang/rust/issues/41118 + type T: Message + FromRandom + FieldArray + Into<[Self; N]> + Clone + Eq + Send + Sync; + // TODO: do we really want the Into bound here? +} + +// The purpose of this trait is to avoid placing a `Message` trait bound on `SharedValueArray`, or +// similar. Doing so would require either (1) a generic impl of `Serializable` for any `N`, which +// is hard to write, or (2) additional trait bounds of something like `F::Array<1>: Message` +// throughout many of our protocols. +// +// Writing `impl Vectorizable<1> for F` means that the compiler will always see that it +// is available anywhere an `F: Field` trait bound is effective. 
+ +pub trait SharedValueSimd: SharedValue {} + +pub trait FieldSimd: + Field + + SharedValueSimd + + Vectorizable>::T> + + FieldVectorizable +{ +} + +// Portions of the implementation treat non-vectorized operations as a vector with `N = 1`. +// These blanket impls are important in allowing code that writes `F: Field` to continue +// working without modification. + +impl SharedValueSimd for F {} + +impl + FieldVectorizable<1, T = >::Array>> + FieldSimd<1> for F +{ +} + +// Supported vectorizations + +impl FieldSimd<32> for Fp32BitPrime {} + +/* +impl FieldSimd<64> for Gf2 { } + +impl FieldSimd<256> for Gf2 { } + +impl FieldSimd<1024> for Gf2 { } + +impl FieldSimd<4096> for Gf2 { } +*/ + +pub trait SharedValueArray: + Clone + + Eq + + Debug + + Send + + Sync + + Sized + + TryFrom, Error = ()> + + Add + + for<'a> Add<&'a Self, Output = Self> + + AddAssign + + for<'a> AddAssign<&'a Self> + + Neg + + Sub + + for<'a> Sub<&'a Self, Output = Self> + + SubAssign + + for<'a> SubAssign<&'a Self> +{ + const ZERO: Self; + + fn index(&self, index: usize) -> V; + + fn from_item(item: V) -> Self; +} + +impl SharedValueArray for T +where + T: SharedValueArray + TryFrom, Error = ()>, +{ + const ZERO: Self = >::ZERO; + + fn index(&self, index: usize) -> Boolean { + >::index(self, index).into() + } + + fn from_item(item: Boolean) -> Self { + >::from_item(item.into()) + } +} + +pub trait FieldArray: + SharedValueArray + for<'a> Mul<&'a F, Output = Self> + for<'a> Mul<&'a Self, Output = Self> +{ +} + #[cfg(any(test, feature = "test-fixture", feature = "cli"))] impl IntoShares> for V where @@ -85,6 +257,29 @@ where } } +#[cfg(any(test, feature = "test-fixture", feature = "cli"))] +impl IntoShares> for [V; N] +where + V: SharedValue + Vectorizable, + Standard: Distribution, +{ + fn share_with(self, rng: &mut R) -> [AdditiveShare; 3] { + // For arrays large enough that the compiler doesn't just unroll everything, it might be + // more efficient to avoid the intermediate vector by 
implementing this as a specialized + // hybrid of the impls for `F as IntoShares>` and ` as + // IntoShares>`. Not bothering since this is test-support functionality. + let [v1, v2, v3] = self.into_iter().share_with(rng); + let (v1l, v1r): (Vec, Vec) = v1.iter().map(AdditiveShare::as_tuple).unzip(); + let (v2l, v2r): (Vec, Vec) = v2.iter().map(AdditiveShare::as_tuple).unzip(); + let (v3l, v3r): (Vec, Vec) = v3.iter().map(AdditiveShare::as_tuple).unzip(); + [ + AdditiveShare::new_arr(v1l.try_into().unwrap(), v1r.try_into().unwrap()), + AdditiveShare::new_arr(v2l.try_into().unwrap(), v2r.try_into().unwrap()), + AdditiveShare::new_arr(v3l.try_into().unwrap(), v3r.try_into().unwrap()), + ] + } +} + #[cfg(all(test, unit_test))] mod tests { use crate::{ diff --git a/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs b/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs index dff9428ab9..c77f5057f3 100644 --- a/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs +++ b/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs @@ -7,25 +7,36 @@ use generic_array::{ArrayLength, GenericArray}; use typenum::Unsigned; use crate::{ - ff::{ArrayAccess, Expand, Field, Serializable}, + ff::{ArrayAccess, CustomArray, Expand, Field, GaloisField, Gf2, Serializable}, secret_sharing::{ - replicated::ReplicatedSecretSharing, Linear as LinearSecretSharing, SecretSharing, - SharedValue, + replicated::ReplicatedSecretSharing, FieldSimd, Linear as LinearSecretSharing, + SecretSharing, SharedValue, SharedValueArray, Vectorizable, }, }; +/// Additive secret sharing. +/// +/// `AdditiveShare` holds two out of three shares of an additive secret sharing, either of a single +/// value with type `V`, or a vector of such values. 
#[derive(Clone, PartialEq, Eq)] -pub struct AdditiveShare(pub V, pub V); +pub struct AdditiveShare, const N: usize = 1>( + pub >::Array, + pub >::Array, +); #[derive(Clone, PartialEq, Eq)] pub struct ASIterator(pub T, pub T); -impl SecretSharing for AdditiveShare { - const ZERO: Self = AdditiveShare::ZERO; +impl, const N: usize> SecretSharing for AdditiveShare { + const ZERO: Self = Self( + >::Array::ZERO, + >::Array::ZERO, + ); } -impl LinearSecretSharing for AdditiveShare {} -impl Debug for AdditiveShare { +impl LinearSecretSharing for AdditiveShare where F: Field + FieldSimd {} + +impl + Debug, const N: usize> Debug for AdditiveShare { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "({:?}, {:?})", self.0, self.1) } @@ -37,26 +48,45 @@ impl Default for AdditiveShare { } } -impl AdditiveShare { - /// Replicated secret share where both left and right values are `F::ZERO` - pub const ZERO: Self = Self(V::ZERO, V::ZERO); +impl, const N: usize> AdditiveShare { + /// Replicated secret share where both left and right values are `V::ZERO` + pub const ZERO: Self = Self( + >::Array::ZERO, + >::Array::ZERO, + ); +} +impl AdditiveShare { pub fn as_tuple(&self) -> (V, V) { - (self.0, self.1) + (self.0.index(0), self.1.index(0)) } } impl ReplicatedSecretSharing for AdditiveShare { fn new(a: V, b: V) -> Self { - Self(a, b) + Self(V::Array::from_item(a), V::Array::from_item(b)) } fn left(&self) -> V { - self.0 + self.0.index(0) } fn right(&self) -> V { - self.1 + self.1.index(0) + } +} + +impl, const N: usize> AdditiveShare { + pub fn new_arr(a: >::Array, b: >::Array) -> Self { + Self(a, b) + } + + pub fn left_arr(&self) -> &>::Array { + &self.0 + } + + pub fn right_arr(&self) -> &>::Array { + &self.1 } } @@ -75,15 +105,20 @@ where } } -impl<'a, 'b, V: SharedValue> Add<&'b AdditiveShare> for &'a AdditiveShare { - type Output = AdditiveShare; +impl<'a, 'b, V: SharedValue + Vectorizable, const N: usize> Add<&'b AdditiveShare> + for &'a AdditiveShare +{ + type 
Output = AdditiveShare; - fn add(self, rhs: &'b AdditiveShare) -> Self::Output { - AdditiveShare(self.0 + rhs.0, self.1 + rhs.1) + fn add(self, rhs: &'b AdditiveShare) -> Self::Output { + AdditiveShare( + Add::add(self.0.clone(), &rhs.0), + Add::add(self.1.clone(), &rhs.1), + ) } } -impl Add for AdditiveShare { +impl, const N: usize> Add for AdditiveShare { type Output = Self; fn add(self, rhs: Self) -> Self::Output { @@ -91,15 +126,19 @@ impl Add for AdditiveShare { } } -impl Add> for &AdditiveShare { - type Output = AdditiveShare; +impl, const N: usize> Add> + for &AdditiveShare +{ + type Output = AdditiveShare; - fn add(self, rhs: AdditiveShare) -> Self::Output { + fn add(self, rhs: AdditiveShare) -> Self::Output { Add::add(self, &rhs) } } -impl Add<&AdditiveShare> for AdditiveShare { +impl, const N: usize> Add<&AdditiveShare> + for AdditiveShare +{ type Output = Self; fn add(self, rhs: &Self) -> Self::Output { @@ -107,28 +146,28 @@ impl Add<&AdditiveShare> for AdditiveShare { } } -impl AddAssign<&Self> for AdditiveShare { +impl, const N: usize> AddAssign<&Self> for AdditiveShare { fn add_assign(&mut self, rhs: &Self) { - self.0 += rhs.0; - self.1 += rhs.1; + self.0 += &rhs.0; + self.1 += &rhs.1; } } -impl AddAssign for AdditiveShare { +impl, const N: usize> AddAssign for AdditiveShare { fn add_assign(&mut self, rhs: Self) { AddAssign::add_assign(self, &rhs); } } -impl Neg for &AdditiveShare { - type Output = AdditiveShare; +impl, const N: usize> Neg for &AdditiveShare { + type Output = AdditiveShare; fn neg(self) -> Self::Output { - AdditiveShare(-self.0, -self.1) + AdditiveShare(-self.0.clone(), -self.1.clone()) } } -impl Neg for AdditiveShare { +impl, const N: usize> Neg for AdditiveShare { type Output = Self; fn neg(self) -> Self::Output { @@ -136,15 +175,18 @@ impl Neg for AdditiveShare { } } -impl Sub for &AdditiveShare { - type Output = AdditiveShare; +impl, const N: usize> Sub for &AdditiveShare { + type Output = AdditiveShare; fn sub(self, rhs: Self) -> 
Self::Output { - AdditiveShare(self.0 - rhs.0, self.1 - rhs.1) + AdditiveShare( + Sub::sub(self.0.clone(), &rhs.0), + Sub::sub(self.1.clone(), &rhs.1), + ) } } -impl Sub for AdditiveShare { +impl, const N: usize> Sub for AdditiveShare { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { @@ -152,7 +194,7 @@ impl Sub for AdditiveShare { } } -impl Sub<&Self> for AdditiveShare { +impl, const N: usize> Sub<&Self> for AdditiveShare { type Output = Self; fn sub(self, rhs: &Self) -> Self::Output { @@ -160,56 +202,67 @@ impl Sub<&Self> for AdditiveShare { } } -impl Sub> for &AdditiveShare { - type Output = AdditiveShare; +impl, const N: usize> Sub> + for &AdditiveShare +{ + type Output = AdditiveShare; - fn sub(self, rhs: AdditiveShare) -> Self::Output { + fn sub(self, rhs: AdditiveShare) -> Self::Output { Sub::sub(self, &rhs) } } -impl SubAssign<&Self> for AdditiveShare { +impl, const N: usize> SubAssign<&Self> for AdditiveShare { fn sub_assign(&mut self, rhs: &Self) { - self.0 -= rhs.0; - self.1 -= rhs.1; + self.0 -= &rhs.0; + self.1 -= &rhs.1; } } -impl SubAssign for AdditiveShare { +impl, const N: usize> SubAssign for AdditiveShare { fn sub_assign(&mut self, rhs: Self) { SubAssign::sub_assign(self, &rhs); } } -impl<'a, 'b, F: Field> Mul<&'b F> for &'a AdditiveShare { - type Output = AdditiveShare; +impl<'a, 'b, F, const N: usize> Mul<&'b F> for &'a AdditiveShare +where + F: Field + FieldSimd, +{ + type Output = AdditiveShare; fn mul(self, rhs: &'b F) -> Self::Output { - AdditiveShare(self.0 * *rhs, self.1 * *rhs) + AdditiveShare(self.0.clone() * rhs, self.1.clone() * rhs) } } -impl Mul for AdditiveShare { +impl Mul for AdditiveShare +where + F: Field + FieldSimd, +{ type Output = Self; fn mul(self, rhs: F) -> Self::Output { - Mul::mul(&self, &rhs) + Mul::mul(&self, rhs) } } -impl Mul<&F> for AdditiveShare { +impl<'a, F: Field + FieldSimd, const N: usize> Mul<&'a F> for AdditiveShare { type Output = Self; fn mul(self, rhs: &F) -> Self::Output { - 
Mul::mul(&self, rhs) + Mul::mul(&self, *rhs) } } -impl Mul for &AdditiveShare { - type Output = AdditiveShare; +impl Mul for &AdditiveShare +where + F: Field + FieldSimd, +{ + type Output = AdditiveShare; fn mul(self, rhs: F) -> Self::Output { - Mul::mul(self, &rhs) + Mul::mul(self, rhs) } } @@ -219,11 +272,15 @@ impl From<(V, V)> for AdditiveShare { } } -impl + SharedValue> std::ops::Not for AdditiveShare { +impl std::ops::Not for AdditiveShare +where + V: SharedValue + Vectorizable, + >::Array: std::ops::Not>::Array>, +{ type Output = Self; fn not(self) -> Self::Output { - AdditiveShare(!(self.0), !(self.1)) + AdditiveShare(!self.0, !self.1) } } @@ -249,53 +306,62 @@ where } /// Implement `ArrayAccess` for `AdditiveShare` over `SharedValue` that implements `ArrayAccess` -impl ArrayAccess for AdditiveShare +impl ArrayAccess for AdditiveShare where - S: ArrayAccess + SharedValue, - ::Output: SharedValue, + S: ArrayAccess + SharedValue, + T: SharedValue + Vectorizable<1, Array = A>, + A: SharedValueArray, { - type Output = AdditiveShare<::Output>; + type Output = AdditiveShare; type Iter<'a> = ASIterator>; fn get(&self, index: usize) -> Option { self.0 + .index(0) .get(index) - .zip(self.1.get(index)) - .map(|v| AdditiveShare(v.0, v.1)) + .zip(self.1.index(0).get(index)) + .map(|v| AdditiveShare(A::from_item(v.0), A::from_item(v.1))) } fn set(&mut self, index: usize, e: Self::Output) { - self.0.set(index, e.0); - self.1.set(index, e.1); + self.0.index(0).set(index, e.0.index(0)); + self.1.index(0).set(index, e.1.index(0)); } fn iter(&self) -> Self::Iter<'_> { - ASIterator(self.0.iter(), self.1.iter()) + ASIterator(self.0.index(0).iter(), self.1.index(0).iter()) } } -impl Expand for AdditiveShare +impl Expand for AdditiveShare where - S: Expand + SharedValue, + S: Expand + SharedValue + Vectorizable<1, Array = A>, ::Input: SharedValue, + A: SharedValueArray, { type Input = AdditiveShare<::Input>; fn expand(v: &Self::Input) -> Self { - AdditiveShare(S::expand(&v.0), 
S::expand(&v.1)) + AdditiveShare( + A::from_item(S::expand(&v.0.index(0))), + A::from_item(S::expand(&v.1.index(0))), + ) } } -impl Iterator for ASIterator +impl Iterator for ASIterator where T: Iterator, - T::Item: SharedValue, + T::Item: SharedValue + Vectorizable<1, Array = A>, + A: SharedValueArray, { type Item = AdditiveShare; fn next(&mut self) -> Option { match (self.0.next(), self.1.next()) { - (Some(left), Some(right)) => Some(AdditiveShare(left, right)), + (Some(left), Some(right)) => { + Some(AdditiveShare(A::from_item(left), A::from_item(right))) + } _ => None, } } @@ -358,8 +424,14 @@ mod tests { a3: &AdditiveShare, expected_value: u128, ) { - assert_eq!(a1.0 + a2.0 + a3.0, Fp31::truncate_from(expected_value)); - assert_eq!(a1.1 + a2.1 + a3.1, Fp31::truncate_from(expected_value)); + assert_eq!( + a1.left() + a2.left() + a3.left(), + Fp31::truncate_from(expected_value) + ); + assert_eq!( + a1.right() + a2.right() + a3.right(), + Fp31::truncate_from(expected_value) + ); } fn addition_test_case(a: (u8, u8, u8), b: (u8, u8, u8), expected_output: u128) { diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs index acfb8f853f..e383d4db00 100644 --- a/ipa-core/src/test_fixture/mod.rs +++ b/ipa-core/src/test_fixture/mod.rs @@ -23,7 +23,7 @@ pub use event_gen::{Config as EventGeneratorConfig, EventGenerator}; use futures::TryFuture; use rand::{distributions::Standard, prelude::Distribution, rngs::mock::StepRng}; use rand_core::{CryptoRng, RngCore}; -pub use sharing::{get_bits, into_bits, Reconstruct}; +pub use sharing::{get_bits, into_bits, Reconstruct, ReconstructArr}; #[cfg(feature = "in-memory-infra")] pub use world::{Runner, TestWorld, TestWorldConfig}; diff --git a/ipa-core/src/test_fixture/sharing.rs b/ipa-core/src/test_fixture/sharing.rs index a9ac85cf30..5e64331567 100644 --- a/ipa-core/src/test_fixture/sharing.rs +++ b/ipa-core/src/test_fixture/sharing.rs @@ -9,7 +9,7 @@ use crate::{ semi_honest::AdditiveShare as Replicated, 
ReplicatedSecretSharing, }, - BitDecomposed, SecretSharing, + BitDecomposed, FieldSimd, SecretSharing, }, }; @@ -20,7 +20,7 @@ pub fn into_bits(v: F) -> BitDecomposed { }) } -/// Deconstructs a value into N values, one for each bi3t. +/// Deconstructs a value into N values, one for each bit. /// # Panics /// It won't #[must_use] @@ -37,22 +37,39 @@ pub trait Reconstruct { fn reconstruct(&self) -> T; } -impl Reconstruct for [&Replicated; 3] { - fn reconstruct(&self) -> F { - let s0 = &self[0]; - let s1 = &self[1]; - let s2 = &self[2]; +/// Alternate version of `Reconstruct` for vectors. +/// +/// There is no difference in the traits, but this avoids having to add +/// type annotations everywhere to disambiguate whether a single-bit +/// result should be reconstructed as `F` or `[F; 1]`. +pub trait ReconstructArr { + /// Validates correctness of the secret sharing scheme. + /// + /// # Panics + /// Panics if the given input is not a valid replicated secret share. + fn reconstruct_arr(&self) -> T; +} - assert_eq!( - s0.left() + s1.left() + s2.left(), - s0.right() + s1.right() + s2.right(), - ); +fn raw_reconstruct(s0l: F, s0r: F, s1l: F, s1r: F, s2l: F, s2r: F) -> F { + assert_eq!(s0l + s1l + s2l, s0r + s1r + s2r,); - assert_eq!(s0.right(), s1.left()); - assert_eq!(s1.right(), s2.left()); - assert_eq!(s2.right(), s0.left()); + assert_eq!(s0r, s1l); + assert_eq!(s1r, s2l); + assert_eq!(s2r, s0l); + + s0l + s1l + s2l +} - s0.left() + s1.left() + s2.left() +impl Reconstruct for [&Replicated; 3] { + fn reconstruct(&self) -> F { + raw_reconstruct( + self[0].left(), + self[0].right(), + self[1].left(), + self[1].right(), + self[2].left(), + self[2].right(), + ) } } @@ -62,6 +79,25 @@ impl Reconstruct for [Replicated; 3] { } } +impl, const N: usize> ReconstructArr<[F; N]> for [Replicated; 3] { + fn reconstruct_arr(&self) -> [F; N] { + let s0l = self[0].left_arr(); + let s0r = self[0].right_arr(); + let s1l = self[1].left_arr(); + let s1r = self[1].right_arr(); + let s2l = 
self[2].left_arr(); + let s2r = self[2].right_arr(); + + assert_eq!(s0l.clone() + s1l + s2l, s0r.clone() + s1r + s2r); + + assert_eq!(s0r, s1l); + assert_eq!(s1r, s2l); + assert_eq!(s2r, s0l); + + (s0l.clone() + s1l + s2l).into() + } +} + impl Reconstruct<(V, W)> for [(T, U); 3] where for<'t> [&'t T; 3]: Reconstruct,