From c29f786649f6a8e647fce6534a7ada094170d0cd Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Wed, 25 Oct 2023 16:27:01 -0700 Subject: [PATCH] algebra: scalar: Add interface to specify domain in FFT computation --- src/algebra/scalar/authenticated_scalar.rs | 35 +++++++++++++++------ src/algebra/scalar/scalar.rs | 36 ++++++++++++++++------ 2 files changed, 52 insertions(+), 19 deletions(-) diff --git a/src/algebra/scalar/authenticated_scalar.rs b/src/algebra/scalar/authenticated_scalar.rs index 05f22b7..d42cc34 100644 --- a/src/algebra/scalar/authenticated_scalar.rs +++ b/src/algebra/scalar/authenticated_scalar.rs @@ -1137,25 +1137,42 @@ where C::ScalarField: FftField, { /// Compute the FFT of a vector of `AuthenticatedScalarResult`s - pub fn fft>( + pub fn fft + Send>( x: &[AuthenticatedScalarResult], ) -> Vec> { - Self::fft_helper::(x, true /* is_forward */) + Self::fft_with_domain::(x, D::new(x.len()).unwrap()) + } + + /// Compute the FFT of a vector of `AuthenticatedScalarResult`s with a given domain + pub fn fft_with_domain + Send>( + x: &[AuthenticatedScalarResult], + domain: D, + ) -> Vec> { + Self::fft_helper::(x, true /* is_forward */, domain) } /// Compute the inverse FFT of a vector of `AuthenticatedScalarResult`s - pub fn ifft>( + pub fn ifft + Send>( + x: &[AuthenticatedScalarResult], + ) -> Vec> { + Self::fft_helper::(x, false /* is_forward */, D::new(x.len()).unwrap()) + } + + /// Compute the inverse FFT of a vector of `AuthenticatedScalarResult`s with a given domain + pub fn ifft_with_domain + Send>( x: &[AuthenticatedScalarResult], + domain: D, ) -> Vec> { - Self::fft_helper::(x, false /* is_forward */) + Self::fft_helper::(x, false /* is_forward */, domain) } /// An FFT/IFFT helper that encapsulates the setup and restructuring of an FFT regardless of direction /// /// If `is_forward` is set, an FFT is performed. Otherwise, an IFFT is performed - fn fft_helper>( + fn fft_helper + Send>( x: &[AuthenticatedScalarResult], is_forward: bool, + domain: D, ) -> Vec> { assert!(!x.is_empty(), "Cannot compute FFT of empty vector"); let fabric = x[0].fabric(); @@ -1172,13 +1189,13 @@ where let (share_fft, mac_fft) = if is_forward { ( - ScalarResult::fft::(&shares), - ScalarResult::fft::(&macs), + ScalarResult::fft_with_domain::(&shares, domain), + ScalarResult::fft_with_domain::(&macs, domain), ) } else { ( - ScalarResult::ifft::(&shares), - ScalarResult::ifft::(&macs), + ScalarResult::ifft_with_domain::(&shares, domain), + ScalarResult::ifft_with_domain::(&macs, domain), ) }; diff --git a/src/algebra/scalar/scalar.rs b/src/algebra/scalar/scalar.rs index 00c0c59..2542322 100644 --- a/src/algebra/scalar/scalar.rs +++ b/src/algebra/scalar/scalar.rs @@ -435,7 +435,17 @@ where C::ScalarField: FftField, { /// Compute the fft of a sequence of `ScalarResult`s - pub fn fft>(x: &[ScalarResult]) -> Vec> { + pub fn fft + Send>( + x: &[ScalarResult], + ) -> Vec> { + Self::fft_with_domain(x, D::new(x.len()).unwrap()) + } + + /// Compute the fft of a sequence of `ScalarResult`s with the given domain + pub fn fft_with_domain + Send>( + x: &[ScalarResult], + domain: D, + ) -> Vec> { assert!(!x.is_empty(), "Cannot compute fft of empty sequence"); let n = x.len().next_power_of_two(); @@ -449,18 +459,25 @@ where .map(|x| x.0) .collect_vec(); - let domain = D::new(n).unwrap(); - let res = domain.fft(&scalars); - - res.into_iter() + domain + .fft(&scalars) + .into_iter() .map(|x| ResultValue::Scalar(Scalar::new(x))) .collect_vec() }) } /// Compute the ifft of a sequence of `ScalarResult`s - pub fn ifft>( + pub fn ifft + Send>( x: &[ScalarResult], + ) -> Vec> { + Self::ifft_with_domain(x, D::new(x.len()).unwrap()) + } + + /// Compute the ifft of a sequence of `ScalarResult`s with the given domain + pub fn ifft_with_domain + Send>( + x: &[ScalarResult], + domain: D, ) -> Vec> { assert!(!x.is_empty(), "Cannot compute fft of empty sequence"); let n = x.len().next_power_of_two(); @@ -475,10 +492,9 @@ where .map(|x| x.0) .collect_vec(); - let domain = D::new(n).unwrap(); - let res = domain.ifft(&scalars); - - res.into_iter() + domain + .ifft(&scalars) + .into_iter() .map(|x| ResultValue::Scalar(Scalar::new(x))) .collect_vec() })