Skip to content

Commit

Permalink
Implement MultilinearExtension for DensePolynomialPqx
Browse files Browse the repository at this point in the history
  • Loading branch information
darth-cy committed Jan 3, 2025
1 parent 901a56c commit 928be2e
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 18 deletions.
218 changes: 205 additions & 13 deletions spartan_parallel/src/custom_dense_mlpoly.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#![allow(clippy::too_many_arguments)]
use core::{unimplemented, unreachable};
use std::cmp::min;

use crate::dense_mlpoly::DensePolynomial;
use crate::{dense_mlpoly::DensePolynomial, instance};
use ff::Field;
use ff_ext::ExtensionField;
use multilinear_extensions::mle::DenseMultilinearExtension;
use multilinear_extensions::mle::{FieldType, MultilinearExtension, DenseMultilinearExtension, RangedMultilinearExtension};
use std::{any::TypeId, borrow::Cow, mem, sync::Arc};

use super::math::Math;

Expand All @@ -29,6 +32,7 @@ pub struct DensePolynomialPqx<E: ExtensionField> {
// Let Q_max = max_num_proofs, assume that for a given P, num_proofs[P] = Q_i, then let STEP = Q_max / Q_i,
// Z(P, y, .) is only non-zero if y is a multiple of STEP, so Z[P][j][.] actually stores Z(P, j*STEP, .)
// The same applies to X
pub dense_multilinear: Option<DenseMultilinearExtension<E>>,
}

// Reverse the bits in q or x
Expand All @@ -50,15 +54,18 @@ impl<E: ExtensionField> DensePolynomialPqx<E> {
) -> Self {
let num_instances = z_mat.len().next_power_of_two();
let num_witness_secs = z_mat[0][0].len().next_power_of_two();
DensePolynomialPqx {
let mut inst = DensePolynomialPqx {
num_instances,
num_proofs,
max_num_proofs,
num_witness_secs,
num_inputs,
max_num_inputs,
Z: z_mat,
}
dense_multilinear: None,
};
inst.fill_dense_Z_poly();
inst
}

// Assume z_mat is in its standard form of (p, q, x)
Expand Down Expand Up @@ -101,15 +108,18 @@ impl<E: ExtensionField> DensePolynomialPqx<E> {
}
}
}
DensePolynomialPqx {
let mut inst = DensePolynomialPqx {
num_instances: num_instances.next_power_of_two(),
num_proofs,
max_num_proofs,
num_witness_secs: num_witness_secs.next_power_of_two(),
num_inputs,
max_num_inputs,
Z,
}
dense_multilinear: None,
};
inst.fill_dense_Z_poly();
inst
}

pub fn len(&self) -> usize {
Expand Down Expand Up @@ -319,6 +329,14 @@ impl<E: ExtensionField> DensePolynomialPqx<E> {
}
}

pub fn flattened_len(&self) -> usize {
self.num_instances * self.max_num_proofs * self.num_witness_secs * self.max_num_inputs
}

pub fn num_flattened_vars(&self) -> usize {
self.flattened_len().log_2()
}

pub fn evaluate(&self, r_p: &Vec<E>, r_q: &Vec<E>, r_w: &Vec<E>, r_x: &Vec<E>) -> E {
let mut cl = self.clone();
cl.bound_poly_vars_rx(r_x);
Expand All @@ -328,11 +346,11 @@ impl<E: ExtensionField> DensePolynomialPqx<E> {
cl.index(0, 0, 0, 0)
}

fn to_dense_Z_poly(&self) -> Vec<E> {
fn fill_dense_Z_poly(&mut self) {
let mut Z_poly =
vec![
E::ZERO;
self.num_instances * self.max_num_proofs * self.num_witness_secs * self.max_num_inputs
self.flattened_len()
];
for p in 0..min(self.num_instances, self.Z.len()) {
let step_q = self.max_num_proofs / self.num_proofs[p];
Expand All @@ -351,17 +369,191 @@ impl<E: ExtensionField> DensePolynomialPqx<E> {
}
}

Z_poly
self.dense_multilinear = Some(DenseMultilinearExtension::from_evaluations_ext_vec(Z_poly.len().log_2(), Z_poly));
}

// Convert to a (p, q_rev, x_rev) regular dense poly of form (p, q, x)
// Convert to Ceno prover compatible multilinear poly
pub fn to_dense_poly(&self) -> DensePolynomial<E> {
DensePolynomial::new(self.to_dense_Z_poly())
match self.evaluations() {
FieldType::Ext(v) => DensePolynomial::new(v.to_vec()),
_ => { unreachable!() }
}
}

// Convert to Ceno prover compatible multilinear poly
pub fn to_ceno_multilinear(&self) -> DenseMultilinearExtension<E> {
let Z_poly = self.to_dense_Z_poly();
DenseMultilinearExtension::from_evaluations_ext_vec(Z_poly.len().log_2(), Z_poly)
match self.evaluations() {
FieldType::Ext(v) => DenseMultilinearExtension::from_evaluations_ext_vec(v.len().log_2(), v.to_vec()),
_ => { unreachable!() }
}
}
}

impl<E: ExtensionField> MultilinearExtension<E> for DensePolynomialPqx<E> {
type Output = DenseMultilinearExtension<E>;
/// Reduce the number of variables of `self` by fixing the
/// `partial_point.len()` variables at `partial_point`.
fn fix_variables(&self, partial_point: &[E]) -> Self::Output {
// TODO: return error.
assert!(
partial_point.len() <= self.num_vars(),
"invalid size of partial point"
);

let mut poly = self.clone();

for point in partial_point.iter() {
poly.fix_variables_in_place(&[*point])
}
assert!(poly.num_flattened_vars() == self.num_flattened_vars() - partial_point.len(),);
poly.to_ceno_multilinear()
}

/// Reduce the number of variables of `self` by fixing the
/// `partial_point.len()` variables at `partial_point` in place
fn fix_variables_in_place(&mut self, partial_point: &[E]) {
// TODO: return error.
assert!(
partial_point.len() <= self.num_flattened_vars(),
"partial point len {} >= num_vars {}",
partial_point.len(),
self.num_flattened_vars()
);

let mut instance_vars = self.num_instances.log_2();
let mut proofs_vars = self.max_num_proofs.log_2();
let mut witness_secs_vars = self.num_witness_secs.log_2();
let mut input_vars = self.max_num_inputs.log_2();

for point in partial_point.iter() {
if input_vars > 0 {
self.bound_poly_vars_rx(&vec![*point]);
input_vars /= 2;
} else if witness_secs_vars > 0 {
self.bound_poly_vars_rw(&vec![*point]);
witness_secs_vars /= 2;
} else if proofs_vars > 0 {
self.bound_poly_vars_rq(&vec![*point]);
proofs_vars /= 2;
} else {
self.bound_poly_vars_rp(&vec![*point]);
instance_vars /= 2;
}
}
}

/// Reduce the number of variables of `self` by fixing the
/// `partial_point.len()` variables at `partial_point` from high position
fn fix_high_variables(&self, _partial_point: &[E]) -> Self::Output {
unimplemented!()
}

/// Reduce the number of variables of `self` by fixing the
/// `partial_point.len()` variables at `partial_point` from high position in place
fn fix_high_variables_in_place(&mut self, _partial_point: &[E]) {
unimplemented!()
}

/// Evaluate the MLE at a give point.
/// Returns an error if the MLE length does not match the point.
fn evaluate(&self, point: &[E]) -> E {
// TODO: return error.
assert_eq!(
self.num_vars(),
point.len(),
"MLE size does not match the point"
);
let mle = self.fix_variables_parallel(point);

if let Some(f) = &self.dense_multilinear {
match &f.evaluations {
FieldType::Ext(v) => v[0],
_ => unreachable!()
}
} else {
unreachable!()
}
}

fn num_vars(&self) -> usize {
self.num_flattened_vars()
}

/// Reduce the number of variables of `self` by fixing the
/// `partial_point.len()` variables at `partial_point`.
fn fix_variables_parallel(&self, partial_point: &[E]) -> Self::Output {
self.fix_variables(partial_point)
}

/// Reduce the number of variables of `self` by fixing the
/// `partial_point.len()` variables at `partial_point` in place
fn fix_variables_in_place_parallel(&mut self, partial_point: &[E]) {
self.fix_variables_in_place(partial_point);
}

fn evaluations(&self) -> &FieldType<E> {
&self.dense_multilinear.as_ref().unwrap().evaluations
}

fn evaluations_to_owned(self) -> FieldType<E> {
unimplemented!()
}

fn evaluations_range(&self) -> Option<(usize, usize)> {
None
}

fn name(&self) -> &'static str {
"DensePolynomialPqx"
}

/// assert and get base field vector
/// panic if not the case
fn get_base_field_vec(&self) -> &[E::BaseField] {
if let Some(f) = &self.dense_multilinear {
match &f.evaluations {
FieldType::Base(evaluations) => &evaluations[..],
_ => unreachable!(),
}
} else {
unreachable!()
}
}

fn merge(&mut self, _rhs: DenseMultilinearExtension<E>) {
unimplemented!()
}

/// get ranged multiliear extention
fn get_ranged_mle(
&self,
num_range: usize,
range_index: usize,
) -> RangedMultilinearExtension<'_, E> {
assert!(num_range > 0);
// ranged_mle is exclusively used in multi-thread parallelism
// The number of ranges must be a power of 2
assert!(num_range.next_power_of_two() == num_range);
let offset = self.evaluations().len() / num_range;
let start = offset * range_index;
RangedMultilinearExtension::new(self.dense_multilinear.as_ref().unwrap(), start, offset)
}

/// resize to new size (num_instances * new_size_per_instance / num_range)
/// and selected by range_index
/// only support resize base fields, otherwise panic
fn resize_ranged(
&self,
_num_instances: usize,
_new_size_per_instance: usize,
_num_range: usize,
_range_index: usize,
) -> Self::Output {
unimplemented!()
}

/// dup to new size 1 << (self.num_vars + ceil_log2(num_dups))
fn dup(&self, _num_instances: usize, _num_dups: usize) -> Self::Output {
unimplemented!()
}
}
13 changes: 8 additions & 5 deletions spartan_parallel/src/r1csproof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use serde::Serialize;
use std::cmp::min;
use std::iter::zip;
use std::sync::Arc;
use std::cmp::max;
use multilinear_extensions::{
mle::{IntoMLE, MultilinearExtension, DenseMultilinearExtension},
virtual_poly::VPAuxInfo,
Expand Down Expand Up @@ -245,13 +246,13 @@ impl<'a, E: ExtensionField + Send + Sync> R1CSProof<E> {
timer_tmp.stop();

// == test: ceno_verifier_bench ==
let max_num_vars = poly_tau.get_num_vars();
let num_threads = 32;
let max_num_vars = poly_tau.get_num_vars();

let arc_A: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_tau.to_ceno_multilinear());
let arc_B: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Az.to_ceno_multilinear());
let arc_C: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Bz.to_ceno_multilinear());
let arc_D: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Cz.to_ceno_multilinear());
let arc_B: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Az);
let arc_C: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Bz);
let arc_D: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(poly_Cz);

let mut virtual_polys =
VirtualPolynomials::new(num_threads, max_num_vars);
Expand Down Expand Up @@ -430,7 +431,6 @@ impl<'a, E: ExtensionField + Send + Sync> R1CSProof<E> {
let mut eq_p_rp_poly = DensePolynomial::new(
tmp_rp_poly.into_iter().map(|i| vec![i; scale]).collect::<Vec<Vec<E>>>().concat()
);
let max_num_vars_phase2 = ABC_poly.get_num_vars();

let mut claimed_sum = E::ZERO;
let mut claimed_partial_sum = E::ZERO;
Expand All @@ -445,6 +445,9 @@ impl<'a, E: ExtensionField + Send + Sync> R1CSProof<E> {
c_sum += c.clone();
}

// debug_ceno_prover
let max_num_vars_phase2 = max(ABC_poly.get_num_vars(), max(Z_poly.get_num_vars(), eq_p_rp_poly.get_num_vars()));

let arc_A: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(ABC_poly.to_ceno_multilinear());
let arc_B: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(Z_poly.to_ceno_multilinear());
let arc_C: Arc<dyn MultilinearExtension<_, Output = DenseMultilinearExtension<E>>> = Arc::new(eq_p_rp_poly.to_ceno_multilinear());
Expand Down

0 comments on commit 928be2e

Please sign in to comment.