Skip to content

Commit

Permalink
algebra: scalar: Add interface to specify domain in FFT computation
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Oct 25, 2023
1 parent 2b49548 commit c29f786
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 19 deletions.
35 changes: 26 additions & 9 deletions src/algebra/scalar/authenticated_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1137,25 +1137,42 @@ where
C::ScalarField: FftField,
{
/// Compute the FFT of a vector of `AuthenticatedScalarResult`s
pub fn fft<D: EvaluationDomain<C::ScalarField>>(
pub fn fft<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
x: &[AuthenticatedScalarResult<C>],
) -> Vec<AuthenticatedScalarResult<C>> {
Self::fft_helper::<D>(x, true /* is_forward */)
Self::fft_with_domain::<D>(x, D::new(x.len()).unwrap())
}

/// Compute the FFT of a vector of `AuthenticatedScalarResult`s with a given domain
pub fn fft_with_domain<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
x: &[AuthenticatedScalarResult<C>],
domain: D,
) -> Vec<AuthenticatedScalarResult<C>> {
Self::fft_helper::<D>(x, true /* is_forward */, domain)
}

/// Compute the inverse FFT of a vector of `AuthenticatedScalarResult`s
pub fn ifft<D: EvaluationDomain<C::ScalarField>>(
pub fn ifft<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
x: &[AuthenticatedScalarResult<C>],
) -> Vec<AuthenticatedScalarResult<C>> {
Self::fft_helper::<D>(x, false /* is_forward */, D::new(x.len()).unwrap())
}

/// Compute the inverse FFT of a vector of `AuthenticatedScalarResult`s with a given domain
pub fn ifft_with_domain<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
x: &[AuthenticatedScalarResult<C>],
domain: D,
) -> Vec<AuthenticatedScalarResult<C>> {
Self::fft_helper::<D>(x, false /* is_forward */)
Self::fft_helper::<D>(x, false /* is_forward */, domain)
}

/// An FFT/IFFT helper that encapsulates the setup and restructuring of an FFT regardless of direction
///
/// If `is_forward` is set, an FFT is performed. Otherwise, an IFFT is performed
fn fft_helper<D: EvaluationDomain<C::ScalarField>>(
fn fft_helper<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
x: &[AuthenticatedScalarResult<C>],
is_forward: bool,
domain: D,
) -> Vec<AuthenticatedScalarResult<C>> {
assert!(!x.is_empty(), "Cannot compute FFT of empty vector");
let fabric = x[0].fabric();
Expand All @@ -1172,13 +1189,13 @@ where

let (share_fft, mac_fft) = if is_forward {
(
ScalarResult::fft::<D>(&shares),
ScalarResult::fft::<D>(&macs),
ScalarResult::fft_with_domain::<D>(&shares, domain),
ScalarResult::fft_with_domain::<D>(&macs, domain),
)
} else {
(
ScalarResult::ifft::<D>(&shares),
ScalarResult::ifft::<D>(&macs),
ScalarResult::ifft_with_domain::<D>(&shares, domain),
ScalarResult::ifft_with_domain::<D>(&macs, domain),
)
};

Expand Down
36 changes: 26 additions & 10 deletions src/algebra/scalar/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,17 @@ where
C::ScalarField: FftField,
{
/// Compute the fft of a sequence of `ScalarResult`s
pub fn fft<D: EvaluationDomain<C::ScalarField>>(x: &[ScalarResult<C>]) -> Vec<ScalarResult<C>> {
pub fn fft<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
x: &[ScalarResult<C>],
) -> Vec<ScalarResult<C>> {
Self::fft_with_domain(x, D::new(x.len()).unwrap())
}

/// Compute the fft of a sequence of `ScalarResult`s with the given domain
pub fn fft_with_domain<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
x: &[ScalarResult<C>],
domain: D,
) -> Vec<ScalarResult<C>> {
assert!(!x.is_empty(), "Cannot compute fft of empty sequence");
let n = x.len().next_power_of_two();

Expand All @@ -449,18 +459,25 @@ where
.map(|x| x.0)
.collect_vec();

let domain = D::new(n).unwrap();
let res = domain.fft(&scalars);

res.into_iter()
domain
.fft(&scalars)
.into_iter()
.map(|x| ResultValue::Scalar(Scalar::new(x)))
.collect_vec()
})
}

/// Compute the ifft of a sequence of `ScalarResult`s
pub fn ifft<D: EvaluationDomain<C::ScalarField>>(
pub fn ifft<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
x: &[ScalarResult<C>],
) -> Vec<ScalarResult<C>> {
Self::ifft_with_domain(x, D::new(x.len()).unwrap())
}

/// Compute the ifft of a sequence of `ScalarResult`s with the given domain
pub fn ifft_with_domain<D: 'static + EvaluationDomain<C::ScalarField> + Send>(
x: &[ScalarResult<C>],
domain: D,
) -> Vec<ScalarResult<C>> {
assert!(!x.is_empty(), "Cannot compute fft of empty sequence");
let n = x.len().next_power_of_two();
Expand All @@ -475,10 +492,9 @@ where
.map(|x| x.0)
.collect_vec();

let domain = D::new(n).unwrap();
let res = domain.ifft(&scalars);

res.into_iter()
domain
.ifft(&scalars)
.into_iter()
.map(|x| ResultValue::Scalar(Scalar::new(x)))
.collect_vec()
})
Expand Down

0 comments on commit c29f786

Please sign in to comment.