Skip to content

Commit

Permalink
Merge pull request #1409 from akoshelev/lagrange-perf
Browse files Browse the repository at this point in the history
Lagrange evaluation performance improvements
  • Loading branch information
akoshelev authored Nov 12, 2024
2 parents 3d5ba63 + a0c6a78 commit 0a34110
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
50 changes: 28 additions & 22 deletions ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{borrow::Borrow, fmt::Debug};
use std::fmt::Debug;

use typenum::Unsigned;

Expand Down Expand Up @@ -79,8 +79,7 @@ pub struct LagrangeTable<F: Field, const N: usize, const M: usize> {

impl<F, const N: usize> LagrangeTable<F, N, 1>
where
F: Field + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
F: PrimeField,
{
/// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point
/// The "x coordinate" of the output point is `x_output`.
Expand All @@ -95,31 +94,16 @@ where

impl<F, const N: usize, const M: usize> LagrangeTable<F, N, M>
where
F: Field,
F: PrimeField,
{
/// This function uses the `LagrangeTable` to evaluate `polynomial` on the _output_ "x coordinates"
/// that were used to generate this table.
/// It is assumed that the `y_coordinates` provided to this function correspond the values of the _input_ "x coordinates"
/// that were used to generate this table.
pub fn eval<I>(&self, y_coordinates: I) -> [F; M]
where
I: IntoIterator + Copy,
I::IntoIter: ExactSizeIterator,
I::Item: Borrow<F>,
{
debug_assert_eq!(y_coordinates.into_iter().len(), N);

pub fn eval(&self, y_coordinates: &[F; N]) -> [F; M] {
self.table
.iter()
.map(|table_row| {
table_row
.iter()
.zip(y_coordinates)
.fold(F::ZERO, |acc, (&base, y)| acc + base * (*y.borrow()))
})
.collect::<Vec<F>>()
.try_into()
.unwrap()
.each_ref()
.map(|row| dot_product(row, y_coordinates))
}

/// helper function to compute a single row of `LagrangeTable`
Expand Down Expand Up @@ -176,6 +160,28 @@ where
}
}

/// Computes the dot product of two arrays of the same size.
/// It is isolated from Lagrange because there could be potential SIMD optimizations used
fn dot_product<F: PrimeField, const N: usize>(a: &[F; N], b: &[F; N]) -> F {
// Staying in integers allows rustc to optimize this code properly, but puts a restriction
// on how large the prime field can be
debug_assert!(
2 * F::BITS + N.next_power_of_two().ilog2() <= 128,
"The prime field {} is too large for this dot product implementation",
F::PRIME.into()
);

let mut sum = 0;

// I am cautious about using zip in hot code
// https://github.com/rust-lang/rust/issues/103555
for i in 0..N {
sum += a[i].as_u128() * b[i].as_u128();
}

F::truncate_from(sum)
}

#[cfg(all(test, unit_test))]
mod test {
use std::{borrow::Borrow, fmt::Debug};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ where
last_array[1..last_u_or_v_values.len()].copy_from_slice(&last_u_or_v_values[1..]);

// compute and output p_or_q
tables.last().unwrap().eval(last_array)[0]
tables.last().unwrap().eval(&last_array)[0]
}

#[cfg(all(test, unit_test))]
Expand Down

0 comments on commit 0a34110

Please sign in to comment.