From 437c41093111b586474ba15c84377c88572a0ed3 Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Wed, 18 Oct 2023 18:27:21 -0700 Subject: [PATCH] algebra: authenticated-scalar: Implement FFT and IFFT on shared values --- Cargo.toml | 3 +- src/algebra/mod.rs | 2 - src/algebra/scalar/authenticated_scalar.rs | 153 ++++++++++++++++++++- src/algebra/scalar/scalar.rs | 125 ++++++++++++++++- 4 files changed, 273 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a51ed71..5d9b3fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,6 @@ path = "src/lib.rs" benchmarks = [] stats = ["benchmarks"] test_helpers = ["ark-bn254"] -poly = ["ark-poly"] [[test]] name = "integration" @@ -77,7 +76,7 @@ tokio = { version = "1.12", features = ["macros", "rt-multi-thread"] } ark-bn254 = { version = "0.4", optional = true } ark-ec = { version = "0.4", features = ["parallel"] } ark-ff = "0.4" -ark-poly = { version = "0.4", optional = true, features = ["std", "parallel"] } +ark-poly = { version = "0.4", features = ["std", "parallel"] } ark-serialize = "0.4" ark-std = "0.4" digest = "0.10" diff --git a/src/algebra/mod.rs b/src/algebra/mod.rs index 89780b7..5328c97 100644 --- a/src/algebra/mod.rs +++ b/src/algebra/mod.rs @@ -4,9 +4,7 @@ mod curve; mod macros; mod scalar; -#[cfg(feature = "poly")] mod poly; -#[cfg(feature = "poly")] pub use poly::*; pub use curve::*; diff --git a/src/algebra/scalar/authenticated_scalar.rs b/src/algebra/scalar/authenticated_scalar.rs index 519db72..efc0ab8 100644 --- a/src/algebra/scalar/authenticated_scalar.rs +++ b/src/algebra/scalar/authenticated_scalar.rs @@ -2,13 +2,15 @@ use std::{ fmt::Debug, - iter::Sum, + iter::{self, Sum}, ops::{Add, Div, Mul, Neg, Sub}, pin::Pin, task::{Context, Poll}, }; use ark_ec::CurveGroup; +use ark_ff::FftField; +use ark_poly::EvaluationDomain; use futures::{Future, FutureExt}; use itertools::{izip, Itertools}; @@ -139,6 +141,11 @@ impl AuthenticatedScalarResult { self.share.to_scalar() } + /// Get the raw share of the MAC as a `ScalarResult` + pub fn mac_share(&self) -> ScalarResult { + self.mac.to_scalar() + } + /// Get a reference to the underlying MPC fabric pub fn fabric(&self) -> &MpcFabric { self.share.fabric() @@ -1088,6 +1095,79 @@ impl Mul<&AuthenticatedScalarResult> for &CurvePointResult impl_borrow_variants!(CurvePointResult, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedPointResult, C: CurveGroup); impl_commutative!(CurvePointResult, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedPointResult, C: CurveGroup); +// === FFT and IFFT === // +impl AuthenticatedScalarResult +where + C::ScalarField: FftField, +{ + /// Compute the FFT of a vector of `AuthenticatedScalarResult`s + pub fn fft>( + x: &[AuthenticatedScalarResult], + ) -> Vec> { + Self::fft_helper::(x, true /* is_forward */) + } + + /// Compute the inverse FFT of a vector of `AuthenticatedScalarResult`s + pub fn ifft>( + x: &[AuthenticatedScalarResult], + ) -> Vec> { + Self::fft_helper::(x, false /* is_forward */) + } + + /// An FFT/IFFT helper that encapsulates the setup and restructuring of an FFT regardless of direction + /// + /// If `is_forward` is set, an FFT is performed. Otherwise, an IFFT is performed + fn fft_helper>( + x: &[AuthenticatedScalarResult], + is_forward: bool, + ) -> Vec> { + assert!(!x.is_empty(), "Cannot compute FFT of empty vector"); + let fabric = x[0].fabric(); + + // Extend to the next power of two + let n = x.len(); + let padding_length = n.next_power_of_two() - n; + let pad = fabric.zeros_authenticated(padding_length); + let padded_input = [x, &pad].concat(); + + // Take the FFT of the shares and the macs separately + let shares = padded_input.iter().map(|v| v.share()).collect_vec(); + let macs = padded_input.iter().map(|v| v.mac_share()).collect_vec(); + + let (share_fft, mac_fft) = if is_forward { + ( + ScalarResult::fft::(&shares), + ScalarResult::fft::(&macs), + ) + } else { + ( + ScalarResult::ifft::(&shares), + ScalarResult::ifft::(&macs), + ) + }; + + // No public values are added in an FFT, so the public modifier remain unchanged + let n = mac_fft.len(); + let modifiers = padded_input + .iter() + .map(|v| v.public_modifier.clone()) + .chain(iter::repeat(fabric.one())) + .take(n) + .collect_vec(); + + let mut res = Vec::with_capacity(n); + for (share, mac, modifier) in izip!(share_fft, mac_fft, modifiers) { + res.push(AuthenticatedScalarResult { + share: MpcScalarResult::new_shared(share), + mac: MpcScalarResult::new_shared(mac), + public_modifier: modifier, + }) + } + + res + } +} + // ---------------- // | Test Helpers | // ---------------- @@ -1126,12 +1206,13 @@ pub mod test_helpers { #[cfg(test)] mod tests { + use ark_poly::{EvaluationDomain, Radix2EvaluationDomain}; use futures::future; use itertools::Itertools; - use rand::thread_rng; + use rand::{thread_rng, Rng}; use crate::{ - algebra::{scalar::Scalar, AuthenticatedScalarResult}, + algebra::{poly_test_helpers::TestPolyField, scalar::Scalar, AuthenticatedScalarResult}, test_helpers::{execute_mock_mpc, TestCurve}, PARTY0, }; @@ -1258,4 +1339,70 @@ mod tests { assert_eq!(res.unwrap(), expected_res) } + + #[tokio::test] + async fn test_fft() { + let mut rng = thread_rng(); + let n: usize = rng.gen_range(0..100); + + let values = (0..n) + .map(|_| Scalar::::random(&mut rng)) + .collect_vec(); + + let domain = Radix2EvaluationDomain::::new(n).unwrap(); + let fft_res = domain.fft(&values.iter().map(Scalar::inner).collect_vec()); + let expected_res = fft_res.into_iter().map(Scalar::new).collect_vec(); + + let (res, _) = execute_mock_mpc(|fabric| { + let values = values.clone(); + async move { + let shared_values = fabric.batch_share_scalar(values, PARTY0 /* sender */); + let fft = AuthenticatedScalarResult::fft::>( + &shared_values, + ); + + let opening = AuthenticatedScalarResult::open_authenticated_batch(&fft); + future::join_all(opening.into_iter()) + .await + .into_iter() + .collect::, _>>() + } + }) + .await; + + assert_eq!(res.unwrap(), expected_res) + } + + #[tokio::test] + async fn test_ifft() { + let mut rng = thread_rng(); + let n: usize = rng.gen_range(0..100); + + let values = (0..n) + .map(|_| Scalar::::random(&mut rng)) + .collect_vec(); + + let domain = Radix2EvaluationDomain::::new(n).unwrap(); + let ifft_res = domain.ifft(&values.iter().map(Scalar::inner).collect_vec()); + let expected_res = ifft_res.into_iter().map(Scalar::new).collect_vec(); + + let (res, _) = execute_mock_mpc(|fabric| { + let values = values.clone(); + async move { + let shared_values = fabric.batch_share_scalar(values, PARTY0 /* sender */); + let ifft = AuthenticatedScalarResult::ifft::>( + &shared_values, + ); + + let opening = AuthenticatedScalarResult::open_authenticated_batch(&ifft); + future::join_all(opening.into_iter()) + .await + .into_iter() + .collect::, _>>() + } + }) + .await; + + assert_eq!(res.unwrap(), expected_res) + } } diff --git a/src/algebra/scalar/scalar.rs b/src/algebra/scalar/scalar.rs index 315e6ef..62111d9 100644 --- a/src/algebra/scalar/scalar.rs +++ b/src/algebra/scalar/scalar.rs @@ -11,7 +11,8 @@ use std::{ }; use ark_ec::CurveGroup; -use ark_ff::{batch_inversion, Field, PrimeField}; +use ark_ff::{batch_inversion, FftField, Field, PrimeField}; +use ark_poly::EvaluationDomain; use ark_std::UniformRand; use itertools::Itertools; use num_bigint::BigUint; @@ -412,6 +413,62 @@ impl MulAssign for Scalar { } } +// === FFT and IFFT === // +impl ScalarResult +where + C::ScalarField: FftField, +{ + /// Compute the fft of a sequence of `ScalarResult`s + pub fn fft>(x: &[ScalarResult]) -> Vec> { + assert!(!x.is_empty(), "Cannot compute fft of empty sequence"); + let n = x.len().next_power_of_two(); + + let fabric = x[0].fabric(); + let ids = x.iter().map(|v| v.id).collect_vec(); + + fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| { + let scalars = args + .into_iter() + .map(Scalar::from) + .map(|x| x.0) + .collect_vec(); + + let domain = D::new(n).unwrap(); + let res = domain.fft(&scalars); + + res.into_iter() + .map(|x| ResultValue::Scalar(Scalar::new(x))) + .collect_vec() + }) + } + + /// Compute the ifft of a sequence of `ScalarResult`s + pub fn ifft>( + x: &[ScalarResult], + ) -> Vec> { + assert!(!x.is_empty(), "Cannot compute fft of empty sequence"); + let n = x.len().next_power_of_two(); + + let fabric = x[0].fabric(); + let ids = x.iter().map(|v| v.id).collect_vec(); + + fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| { + let scalars = args + .into_iter() + .map(Scalar::from) + .map(|x| x.0) + .collect_vec(); + + let domain = D::new(n).unwrap(); + let res = domain.ifft(&scalars); + + res.into_iter() + .map(|x| ResultValue::Scalar(Scalar::new(x))) + .collect_vec() + }) + } +} + // --------------- // | Conversions | // --------------- @@ -476,8 +533,14 @@ impl Product for Scalar { #[cfg(test)] mod test { - use crate::{algebra::scalar::Scalar, test_helpers::mock_fabric}; - use rand::thread_rng; + use crate::{ + algebra::{poly_test_helpers::TestPolyField, scalar::Scalar, ScalarResult}, + test_helpers::{execute_mock_mpc, mock_fabric, TestCurve}, + }; + use ark_poly::{EvaluationDomain, Radix2EvaluationDomain}; + use futures::future; + use itertools::Itertools; + use rand::{thread_rng, Rng}; /// Tests addition of raw scalars in a circuit #[tokio::test] @@ -560,4 +623,60 @@ mod test { assert_eq!(res_final, expected_res); fabric.shutdown(); } + + /// Tests fft of scalars allocated in a circuit + #[tokio::test] + async fn test_circuit_fft() { + let mut rng = thread_rng(); + let n: usize = rng.gen_range(1..=100); + + let seq = (0..n) + .map(|_| Scalar::::random(&mut rng)) + .collect_vec(); + + let domain = Radix2EvaluationDomain::::new(n).unwrap(); + let fft_res = domain.fft(&seq.iter().map(|s| s.inner()).collect_vec()); + let expected_res = fft_res.into_iter().map(Scalar::new).collect_vec(); + + let (res, _) = execute_mock_mpc(|fabric| { + let seq = seq.clone(); + async move { + let seq_alloc = seq.iter().map(|x| fabric.allocate_scalar(*x)).collect_vec(); + + let res = ScalarResult::fft::>(&seq_alloc); + future::join_all(res.into_iter()).await + } + }) + .await; + + assert_eq!(res, expected_res); + } + + /// Tests the ifft of scalars allocated in a circuit + #[tokio::test] + async fn test_circuit_ifft() { + let mut rng = thread_rng(); + let n: usize = rng.gen_range(1..=100); + + let seq = (0..n) + .map(|_| Scalar::::random(&mut rng)) + .collect_vec(); + + let domain = Radix2EvaluationDomain::::new(n).unwrap(); + let ifft_res = domain.ifft(&seq.iter().map(|s| s.inner()).collect_vec()); + let expected_res = ifft_res.into_iter().map(Scalar::new).collect_vec(); + + let (res, _) = execute_mock_mpc(|fabric| { + let seq = seq.clone(); + async move { + let seq_alloc = seq.iter().map(|x| fabric.allocate_scalar(*x)).collect_vec(); + + let res = ScalarResult::ifft::>(&seq_alloc); + future::join_all(res.into_iter()).await + } + }) + .await; + + assert_eq!(res, expected_res); + } }