Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sumcheck protocol support mixed num_vars monomial form #235

Merged
merged 2 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 103 additions & 2 deletions ceno_zkvm/src/virtual_polys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,28 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> {
#[cfg(test)]
mod tests {

use ark_std::test_rng;
use ff_ext::ExtensionField;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
use multilinear_extensions::{mle::IntoMLE, virtual_poly_v2::ArcMultilinearExtension};
use multilinear_extensions::{
mle::IntoMLE,
virtual_poly::VPAuxInfo,
virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2},
};
use sumcheck::structs::{IOPProverStateV2, IOPVerifierState};
use transcript::Transcript;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
expression::{Expression, ToExpr},
virtual_polys::VirtualPolynomials,
};
use ff::Field;
type E = GoldilocksExt2;

#[test]
fn test_add_mle_list_by_expr() {
type E = GoldilocksExt2;
let mut cs = ConstraintSystem::new(|| "test_root");
let mut cb = CircuitBuilder::<E>::new(&mut cs);
let x = cb.create_witin(|| "x").unwrap();
Expand Down Expand Up @@ -218,4 +227,96 @@ mod tests {
assert!(distrinct_zerocheck_terms_set.len() == 1);
assert!(virtual_polys.degree() == 3);
}

#[test]
fn test_sumcheck_different_degree() {
let max_num_vars = 3;
let fn_eval = |fs: &[ArcMultilinearExtension<E>]| -> E {
let base_2 = <E as ExtensionField>::BaseField::from(2);

let evals = fs.iter().fold(
vec![<E as ExtensionField>::BaseField::ONE; 1 << fs[0].num_vars()],
|mut evals, f| {
evals
.iter_mut()
.zip(f.get_base_field_vec())
.for_each(|(e, v)| {
*e *= v;
});
evals
},
);

<<E as ExtensionField>::BaseField as std::convert::Into<E>>::into(
evals.iter().sum::<<E as ExtensionField>::BaseField>()
* base_2.pow([(max_num_vars - fs[0].num_vars()) as u64]),
)
};
let num_threads = 1;
let mut transcript = Transcript::new(b"test");

let mut rng = test_rng();

let f1: [ArcMultilinearExtension<E>; 2] = std::array::from_fn(|_| {
(0..1 << (max_num_vars - 2))
.map(|_| <E as ExtensionField>::BaseField::random(&mut rng))
.collect_vec()
.into_mle()
.into()
});
let f2: [ArcMultilinearExtension<E>; 1] = std::array::from_fn(|_| {
(0..1 << (max_num_vars))
.map(|_| <E as ExtensionField>::BaseField::random(&mut rng))
.collect_vec()
.into_mle()
.into()
});
let f3: [ArcMultilinearExtension<E>; 3] = std::array::from_fn(|_| {
(0..1 << (max_num_vars - 1))
.map(|_| <E as ExtensionField>::BaseField::random(&mut rng))
.collect_vec()
.into_mle()
.into()
});

let mut virtual_polys = VirtualPolynomials::<E>::new(num_threads, max_num_vars);

virtual_polys.add_mle_list(f1.iter().collect(), E::ONE);
virtual_polys.add_mle_list(f2.iter().collect(), E::ONE);
virtual_polys.add_mle_list(f3.iter().collect(), E::ONE);

let (sumcheck_proofs, _) = IOPProverStateV2::prove_batch_polys(
num_threads,
virtual_polys.get_batched_polys(),
&mut transcript,
);

let mut transcript = Transcript::new(b"test");
let subclaim = IOPVerifierState::<E>::verify(
fn_eval(&f1) + fn_eval(&f2) + fn_eval(&f3),
&sumcheck_proofs,
&VPAuxInfo {
max_degree: 3,
num_variables: max_num_vars,
phantom: std::marker::PhantomData,
},
&mut transcript,
);

let mut verifier_poly = VirtualPolynomialV2::new(max_num_vars);
verifier_poly.add_mle_list(f1.to_vec(), E::ONE);
verifier_poly.add_mle_list(f2.to_vec(), E::ONE);
verifier_poly.add_mle_list(f3.to_vec(), E::ONE);
assert!(
verifier_poly.evaluate(
subclaim
.point
.iter()
.map(|c| c.elements)
.collect::<Vec<_>>()
.as_ref()
) == subclaim.expected_evaluation,
"wrong subclaim"
);
}
}
1 change: 1 addition & 0 deletions multilinear_extensions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tracing = "0.1.40"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
tracing-flame = "0.2.0"
ff_ext = { path = "../ff_ext" }
itertools = "0.12.1"
ark-std.workspace = true
ff.workspace = true
goldilocks.workspace = true
Expand Down
10 changes: 3 additions & 7 deletions multilinear_extensions/src/mle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ pub trait IntoMLEs<T>: Sized {
fn into_mles(self) -> Vec<T>;
}

impl<F: Field, E: ExtensionField<BaseField = F>> IntoMLEs<DenseMultilinearExtension<E>> for Vec<Vec<F>> {
impl<F: Field, E: ExtensionField<BaseField = F>> IntoMLEs<DenseMultilinearExtension<E>>
for Vec<Vec<F>>
{
fn into_mles(self) -> Vec<DenseMultilinearExtension<E>> {
self.into_iter().map(|v| v.into_mle()).collect()
}
Expand Down Expand Up @@ -1000,12 +1002,6 @@ macro_rules! op_mle {
match &$a.evaluations() {
$crate::mle::FieldType::Base(a) => {
let $tmp_a = if let Some((start, offset)) = $a.evaluations_range() {
println!(
"op_mle start {}, offset {}, a.len {}",
start,
offset,
a.len()
);
&a[start..][..offset]
} else {
&a[..]
Expand Down
5 changes: 5 additions & 0 deletions multilinear_extensions/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@ pub fn create_uninit_vec<T: Sized>(len: usize) -> Vec<MaybeUninit<T>> {
unsafe { vec.set_len(len) };
vec
}

#[inline(always)]
pub fn largest_even_below(n: usize) -> usize {
if n % 2 == 0 { n } else { n.saturating_sub(1) }
}
53 changes: 29 additions & 24 deletions multilinear_extensions/src/virtual_poly_v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
};
use ark_std::{end_timer, start_timer};
use ff_ext::ExtensionField;
use itertools::Itertools;
use serde::{Deserialize, Serialize};

pub type ArcMultilinearExtension<'a, E> =
Expand Down Expand Up @@ -55,8 +56,8 @@ pub struct VirtualPolynomialV2<'a, E: ExtensionField> {
pub struct VPAuxInfo<E> {
/// max number of multiplicands in each product
pub max_degree: usize,
/// number of variables of the polynomial
pub num_variables: usize,
/// max number of variables of the polynomial
pub max_num_variables: usize,
/// Associated field
#[doc(hidden)]
pub phantom: PhantomData<E>,
Expand All @@ -69,12 +70,12 @@ impl<E: ExtensionField> AsRef<[u8]> for VPAuxInfo<E> {
}

impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {
/// Creates an empty virtual polynomial with `num_variables`.
pub fn new(num_variables: usize) -> Self {
/// Creates an empty virtual polynomial with `max_num_variables`.
pub fn new(max_num_variables: usize) -> Self {
VirtualPolynomialV2 {
aux_info: VPAuxInfo {
max_degree: 0,
num_variables,
max_num_variables,
phantom: PhantomData,
},
products: Vec::new(),
Expand All @@ -93,7 +94,7 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {
aux_info: VPAuxInfo {
// The max degree is the max degree of any individual variable
max_degree: 1,
num_variables: mle.num_vars(),
max_num_variables: mle.num_vars(),
phantom: PhantomData,
},
// here `0` points to the first polynomial of `flattened_ml_extensions`
Expand All @@ -104,8 +105,10 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {
}

/// Add a product of list of multilinear extensions to self
/// Returns an error if the list is empty, or the MLE has a different
/// `num_vars()` from self.
/// Returns an error if the list is empty.
///
/// mle in mle_list must be in same num_vars() in same product,
/// while different product can have different num_vars()
///
/// The MLEs will be multiplied together, and then multiplied by the scalar
/// `coefficient`.
Expand All @@ -114,18 +117,20 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {
let mut indexed_product = Vec::with_capacity(mle_list.len());

assert!(!mle_list.is_empty(), "input mle_list is empty");
// sanity check: all mle in mle_list must have same num_vars()
assert!(
mle_list
.iter()
.map(|m| {
assert!(m.num_vars() <= self.aux_info.max_num_variables);
m.num_vars()
})
.all_equal()
);

self.aux_info.max_degree = max(self.aux_info.max_degree, mle_list.len());

for mle in mle_list {
assert_eq!(
mle.num_vars(),
self.aux_info.num_variables,
"product has a multiplicand with wrong number of variables {} vs {}",
mle.num_vars(),
self.aux_info.num_variables
);

let mle_ptr: usize = Arc::as_ptr(&mle) as *const () as usize;
if let Some(index) = self.raw_pointers_lookup_table.get(&mle_ptr) {
indexed_product.push(*index)
Expand Down Expand Up @@ -163,10 +168,10 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {

assert_eq!(
mle.num_vars(),
self.aux_info.num_variables,
self.aux_info.max_num_variables,
"product has a multiplicand with wrong number of variables {} vs {}",
mle.num_vars(),
self.aux_info.num_variables
self.aux_info.max_num_variables
);

let mle_ptr = Arc::as_ptr(&mle) as *const () as usize;
Expand Down Expand Up @@ -200,17 +205,17 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {
let start = start_timer!(|| "evaluation");

assert_eq!(
self.aux_info.num_variables,
self.aux_info.max_num_variables,
point.len(),
"wrong number of variables {} vs {}",
self.aux_info.num_variables,
self.aux_info.max_num_variables,
point.len()
);

let evals: Vec<E> = self
.flattened_ml_extensions
.iter()
.map(|x| x.evaluate(point))
.map(|x| x.evaluate(&point[0..x.num_vars()]))
.collect();

let res = self
Expand All @@ -225,11 +230,11 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {

/// Print out the evaluation map for testing. Panic if the num_vars() > 5.
pub fn print_evals(&self) {
if self.aux_info.num_variables > 5 {
if self.aux_info.max_num_variables > 5 {
panic!("this function is used for testing only. cannot print more than 5 num_vars()")
}
for i in 0..1 << self.aux_info.num_variables {
let point = bit_decompose(i, self.aux_info.num_variables);
for i in 0..1 << self.aux_info.max_num_variables {
let point = bit_decompose(i, self.aux_info.max_num_variables);
let point_fr: Vec<E> = point.iter().map(|&x| E::from(x as u64)).collect();
println!("{} {:?}", i, self.evaluate(point_fr.as_ref()))
}
Expand Down
Loading
Loading