Skip to content

Commit

Permalink
algebra: authenticated-scalar: Implement FFT and IFFT on shared values
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Oct 19, 2023
1 parent fd9c75a commit f1f4d45
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 10 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ path = "src/lib.rs"
benchmarks = []
stats = ["benchmarks"]
test_helpers = ["ark-bn254"]
poly = ["ark-poly"]

[[test]]
name = "integration"
Expand Down Expand Up @@ -77,7 +76,7 @@ tokio = { version = "1.12", features = ["macros", "rt-multi-thread"] }
ark-bn254 = { version = "0.4", optional = true }
ark-ec = { version = "0.4", features = ["parallel"] }
ark-ff = "0.4"
ark-poly = { version = "0.4", optional = true, features = ["std", "parallel"] }
ark-poly = { version = "0.4", features = ["std", "parallel"] }
ark-serialize = "0.4"
ark-std = "0.4"
digest = "0.10"
Expand Down
2 changes: 0 additions & 2 deletions src/algebra/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ mod curve;
mod macros;
mod scalar;

#[cfg(feature = "poly")]
mod poly;
#[cfg(feature = "poly")]
pub use poly::*;

pub use curve::*;
Expand Down
153 changes: 150 additions & 3 deletions src/algebra/scalar/authenticated_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
use std::{
fmt::Debug,
iter::Sum,
iter::{self, Sum},
ops::{Add, Div, Mul, Neg, Sub},
pin::Pin,
task::{Context, Poll},
};

use ark_ec::CurveGroup;
use ark_ff::FftField;
use ark_poly::EvaluationDomain;
use futures::{Future, FutureExt};
use itertools::{izip, Itertools};

Expand Down Expand Up @@ -139,6 +141,11 @@ impl<C: CurveGroup> AuthenticatedScalarResult<C> {
self.share.to_scalar()
}

/// Get the raw share of the MAC as a `ScalarResult`
pub fn mac_share(&self) -> ScalarResult<C> {
self.mac.to_scalar()
}

/// Get a reference to the underlying MPC fabric
pub fn fabric(&self) -> &MpcFabric<C> {
self.share.fabric()
Expand Down Expand Up @@ -1088,6 +1095,79 @@ impl<C: CurveGroup> Mul<&AuthenticatedScalarResult<C>> for &CurvePointResult<C>
impl_borrow_variants!(CurvePointResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, Output=AuthenticatedPointResult<C>, C: CurveGroup);
impl_commutative!(CurvePointResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, Output=AuthenticatedPointResult<C>, C: CurveGroup);

// === FFT and IFFT === //
impl<C: CurveGroup> AuthenticatedScalarResult<C>
where
C::ScalarField: FftField,
{
/// Compute the FFT of a vector of `AuthenticatedScalarResult`s
pub fn fft<D: EvaluationDomain<C::ScalarField>>(
x: &[AuthenticatedScalarResult<C>],
) -> Vec<AuthenticatedScalarResult<C>> {
Self::fft_helper::<D>(x, true /* is_forward */)
}

/// Compute the inverse FFT of a vector of `AuthenticatedScalarResult`s
pub fn ifft<D: EvaluationDomain<C::ScalarField>>(
x: &[AuthenticatedScalarResult<C>],
) -> Vec<AuthenticatedScalarResult<C>> {
Self::fft_helper::<D>(x, false /* is_forward */)
}

/// An FFT/IFFT helper that encapsulates the setup and restructuring of an FFT regardless of direction
///
/// If `is_forward` is set, an FFT is performed. Otherwise, an IFFT is performed
fn fft_helper<D: EvaluationDomain<C::ScalarField>>(
x: &[AuthenticatedScalarResult<C>],
is_forward: bool,
) -> Vec<AuthenticatedScalarResult<C>> {
assert!(!x.is_empty(), "Cannot compute FFT of empty vector");
let fabric = x[0].fabric();

// Extend to the next power of two
let n = x.len();
let padding_length = n.next_power_of_two() - n;
let pad = fabric.zeros_authenticated(padding_length);
let padded_input = [x, &pad].concat();

// Take the FFT of the shares and the macs separately
let shares = padded_input.iter().map(|v| v.share()).collect_vec();
let macs = padded_input.iter().map(|v| v.mac_share()).collect_vec();

let (share_fft, mac_fft) = if is_forward {
(
ScalarResult::fft::<D>(&shares),
ScalarResult::fft::<D>(&macs),
)
} else {
(
ScalarResult::ifft::<D>(&shares),
ScalarResult::ifft::<D>(&macs),
)
};

// No public values are added in an FFT, so the public modifier remain unchanged
let n = mac_fft.len();
let modifiers = padded_input
.iter()
.map(|v| v.public_modifier.clone())
.chain(iter::repeat(fabric.one()))
.take(n)
.collect_vec();

let mut res = Vec::with_capacity(n);
for (share, mac, modifier) in izip!(share_fft, mac_fft, modifiers) {
res.push(AuthenticatedScalarResult {
share: MpcScalarResult::new_shared(share),
mac: MpcScalarResult::new_shared(mac),
public_modifier: modifier,
})
}

res
}
}

// ----------------
// | Test Helpers |
// ----------------
Expand Down Expand Up @@ -1126,12 +1206,13 @@ pub mod test_helpers {

#[cfg(test)]
mod tests {
use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
use futures::future;
use itertools::Itertools;
use rand::thread_rng;
use rand::{thread_rng, Rng};

use crate::{
algebra::{scalar::Scalar, AuthenticatedScalarResult},
algebra::{poly_test_helpers::TestPolyField, scalar::Scalar, AuthenticatedScalarResult},
test_helpers::{execute_mock_mpc, TestCurve},
PARTY0,
};
Expand Down Expand Up @@ -1258,4 +1339,70 @@ mod tests {

assert_eq!(res.unwrap(), expected_res)
}

#[tokio::test]
async fn test_fft() {
let mut rng = thread_rng();
let n: usize = rng.gen_range(0..100);

let values = (0..n)
.map(|_| Scalar::<TestCurve>::random(&mut rng))
.collect_vec();

let domain = Radix2EvaluationDomain::<TestPolyField>::new(n).unwrap();
let fft_res = domain.fft(&values.iter().map(Scalar::inner).collect_vec());
let expected_res = fft_res.into_iter().map(Scalar::new).collect_vec();

let (res, _) = execute_mock_mpc(|fabric| {
let values = values.clone();
async move {
let shared_values = fabric.batch_share_scalar(values, PARTY0 /* sender */);
let fft = AuthenticatedScalarResult::fft::<Radix2EvaluationDomain<TestPolyField>>(
&shared_values,
);

let opening = AuthenticatedScalarResult::open_authenticated_batch(&fft);
future::join_all(opening.into_iter())
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
}
})
.await;

assert_eq!(res.unwrap(), expected_res)
}

#[tokio::test]
async fn test_ifft() {
let mut rng = thread_rng();
let n: usize = rng.gen_range(0..100);

let values = (0..n)
.map(|_| Scalar::<TestCurve>::random(&mut rng))
.collect_vec();

let domain = Radix2EvaluationDomain::<TestPolyField>::new(n).unwrap();
let ifft_res = domain.ifft(&values.iter().map(Scalar::inner).collect_vec());
let expected_res = ifft_res.into_iter().map(Scalar::new).collect_vec();

let (res, _) = execute_mock_mpc(|fabric| {
let values = values.clone();
async move {
let shared_values = fabric.batch_share_scalar(values, PARTY0 /* sender */);
let ifft = AuthenticatedScalarResult::ifft::<Radix2EvaluationDomain<TestPolyField>>(
&shared_values,
);

let opening = AuthenticatedScalarResult::open_authenticated_batch(&ifft);
future::join_all(opening.into_iter())
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
}
})
.await;

assert_eq!(res.unwrap(), expected_res)
}
}
125 changes: 122 additions & 3 deletions src/algebra/scalar/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use std::{
};

use ark_ec::CurveGroup;
use ark_ff::{batch_inversion, Field, PrimeField};
use ark_ff::{batch_inversion, FftField, Field, PrimeField};
use ark_poly::EvaluationDomain;
use ark_std::UniformRand;
use itertools::Itertools;
use num_bigint::BigUint;
Expand Down Expand Up @@ -412,6 +413,62 @@ impl<C: CurveGroup> MulAssign for Scalar<C> {
}
}

// === FFT and IFFT === //
impl<C: CurveGroup> ScalarResult<C>
where
C::ScalarField: FftField,
{
/// Compute the fft of a sequence of `ScalarResult`s
pub fn fft<D: EvaluationDomain<C::ScalarField>>(x: &[ScalarResult<C>]) -> Vec<ScalarResult<C>> {
assert!(!x.is_empty(), "Cannot compute fft of empty sequence");
let n = x.len().next_power_of_two();

let fabric = x[0].fabric();
let ids = x.iter().map(|v| v.id).collect_vec();

fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
let scalars = args
.into_iter()
.map(Scalar::from)
.map(|x| x.0)
.collect_vec();

let domain = D::new(n).unwrap();
let res = domain.fft(&scalars);

res.into_iter()
.map(|x| ResultValue::Scalar(Scalar::new(x)))
.collect_vec()
})
}

/// Compute the ifft of a sequence of `ScalarResult`s
pub fn ifft<D: EvaluationDomain<C::ScalarField>>(
x: &[ScalarResult<C>],
) -> Vec<ScalarResult<C>> {
assert!(!x.is_empty(), "Cannot compute fft of empty sequence");
let n = x.len().next_power_of_two();

let fabric = x[0].fabric();
let ids = x.iter().map(|v| v.id).collect_vec();

fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| {
let scalars = args
.into_iter()
.map(Scalar::from)
.map(|x| x.0)
.collect_vec();

let domain = D::new(n).unwrap();
let res = domain.ifft(&scalars);

res.into_iter()
.map(|x| ResultValue::Scalar(Scalar::new(x)))
.collect_vec()
})
}
}

// ---------------
// | Conversions |
// ---------------
Expand Down Expand Up @@ -476,8 +533,14 @@ impl<C: CurveGroup> Product for Scalar<C> {

#[cfg(test)]
mod test {
use crate::{algebra::scalar::Scalar, test_helpers::mock_fabric};
use rand::thread_rng;
use crate::{
algebra::{poly_test_helpers::TestPolyField, scalar::Scalar, ScalarResult},
test_helpers::{execute_mock_mpc, mock_fabric, TestCurve},
};
use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
use futures::future;
use itertools::Itertools;
use rand::{thread_rng, Rng};

/// Tests addition of raw scalars in a circuit
#[tokio::test]
Expand Down Expand Up @@ -560,4 +623,60 @@ mod test {
assert_eq!(res_final, expected_res);
fabric.shutdown();
}

/// Tests fft of scalars allocated in a circuit
#[tokio::test]
async fn test_circuit_fft() {
let mut rng = thread_rng();
let n: usize = rng.gen_range(1..=100);

let seq = (0..n)
.map(|_| Scalar::<TestCurve>::random(&mut rng))
.collect_vec();

let domain = Radix2EvaluationDomain::<TestPolyField>::new(n).unwrap();
let fft_res = domain.fft(&seq.iter().map(|s| s.inner()).collect_vec());
let expected_res = fft_res.into_iter().map(Scalar::new).collect_vec();

let (res, _) = execute_mock_mpc(|fabric| {
let seq = seq.clone();
async move {
let seq_alloc = seq.iter().map(|x| fabric.allocate_scalar(*x)).collect_vec();

let res = ScalarResult::fft::<Radix2EvaluationDomain<TestPolyField>>(&seq_alloc);
future::join_all(res.into_iter()).await
}
})
.await;

assert_eq!(res, expected_res);
}

/// Tests the ifft of scalars allocated in a circuit
#[tokio::test]
async fn test_circuit_ifft() {
let mut rng = thread_rng();
let n: usize = rng.gen_range(1..=100);

let seq = (0..n)
.map(|_| Scalar::<TestCurve>::random(&mut rng))
.collect_vec();

let domain = Radix2EvaluationDomain::<TestPolyField>::new(n).unwrap();
let ifft_res = domain.ifft(&seq.iter().map(|s| s.inner()).collect_vec());
let expected_res = ifft_res.into_iter().map(Scalar::new).collect_vec();

let (res, _) = execute_mock_mpc(|fabric| {
let seq = seq.clone();
async move {
let seq_alloc = seq.iter().map(|x| fabric.allocate_scalar(*x)).collect_vec();

let res = ScalarResult::ifft::<Radix2EvaluationDomain<TestPolyField>>(&seq_alloc);
future::join_all(res.into_iter()).await
}
})
.await;

assert_eq!(res, expected_res);
}
}

0 comments on commit f1f4d45

Please sign in to comment.