Skip to content

Commit

Permalink
cleanup debug log and test pass
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Sep 18, 2024
1 parent 7494bc8 commit ce2a5b1
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 90 deletions.
16 changes: 15 additions & 1 deletion ceno_zkvm/src/virtual_polys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,27 +254,40 @@ mod tests {
vec![<E as ExtensionField>::BaseField::ONE; 1 << (max_num_vars)]
.into_mle()
.into();

let f3: ArcMultilinearExtension<E> =
vec![<E as ExtensionField>::BaseField::ONE; 1 << (max_num_vars - 1)]
.into_mle()
.into();
let mut virtual_polys = VirtualPolynomials::<E>::new(num_threads, max_num_vars);

virtual_polys.add_mle_list(vec![&f1], E::ONE);
virtual_polys.add_mle_list(vec![&f2], E::ONE);
virtual_polys.add_mle_list(vec![&f3], E::ONE);

let (sumcheck_proofs, _) = IOPProverStateV2::prove_batch_polys(
num_threads,
virtual_polys.get_batched_polys(),
&mut transcript,
);

let base_2 = <E as ExtensionField>::BaseField::from(2);
let mut transcript = Transcript::new(b"test");
let subclaim = IOPVerifierState::<E>::verify(
E::ONE
* (f1
.get_base_field_vec()
.iter()
.sum::<<E as ExtensionField>::BaseField>()
* base_2.pow([(max_num_vars - f1.num_vars()) as u64])
+ f2.get_base_field_vec()
.iter()
.sum::<<E as ExtensionField>::BaseField>()),
.sum::<<E as ExtensionField>::BaseField>())
* base_2.pow([(max_num_vars - f2.num_vars()) as u64])
+ f3.get_base_field_vec()
.iter()
.sum::<<E as ExtensionField>::BaseField>()
* base_2.pow([(max_num_vars - f3.num_vars()) as u64]),
&sumcheck_proofs,
&VPAuxInfo {
max_degree: 1,
Expand All @@ -287,6 +300,7 @@ mod tests {
let mut verifier_poly = VirtualPolynomialV2::new(max_num_vars);
verifier_poly.add_mle_list(vec![f1.clone()], E::ONE);
verifier_poly.add_mle_list(vec![f2.clone()], E::ONE);
verifier_poly.add_mle_list(vec![f3.clone()], E::ONE);
assert!(
verifier_poly.evaluate(
subclaim
Expand Down
164 changes: 96 additions & 68 deletions sumcheck/src/prover_v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ macro_rules! pad_and_chunk_2 {
($f:ident $(,$fs:ident)* $(,)?) => {
{
$f.iter().chain(if $f.len() == 1 {
println!("go padding");
vec![$f.last().unwrap()]
} else {
vec![]
Expand Down Expand Up @@ -216,7 +215,6 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
exit_span!(span);
// pushing the last challenge point to the state
if let Some(p) = challenge {
println!("prover challenge {:?}", challenge);
prover_state.challenges.push(p);
// fix last challenge to collect final evaluation
prover_state
Expand Down Expand Up @@ -398,7 +396,6 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
&mut self,
challenge: &Option<Challenge<E>>,
) -> IOPProverMessage<E> {
println!("prover challenge {:?}", challenge);
let start =
start_timer!(|| format!("sum check prove {}-th round and update state", self.round));

Expand Down Expand Up @@ -436,9 +433,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
if self.challenges.len() == 1 {
self.poly.flattened_ml_extensions.iter_mut().for_each(|f| {
if f.num_vars() > 0 {
println!("round {} fix var", self.round);
*f = Arc::new(f.fix_variables(&[r.elements]));
println!("round {} after fix var {:?}", self.round, f.evaluations())
} else {
panic!("calling sumcheck on constant")
}
Expand All @@ -454,11 +449,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
.for_each(|f| {
if let Some(f) = f {
if f.num_vars() > 0 {
println!("round {} fix var", self.round);
f.fix_variables_in_place(&[r.elements]);
println!("round {} after fix var {:?}", self.round, f.evaluations())
} else {
println!("round {} skip!", self.round)
}
}
});
Expand All @@ -482,14 +473,17 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
let f = &self.poly.flattened_ml_extensions[products[0]];
op_mle! {
|f| {
let a = pad_and_chunk_2!(f).fold(AdditiveArray::<E, 2>(array::from_fn(|_| 0.into())), |mut acc, mut b| {
println!("inside");
acc.0[0] += b.next().unwrap();
acc.0[1] += b.next().unwrap();
let res = pad_and_chunk_2!(f).fold(AdditiveArray::<E, 2>(array::from_fn(|_| 0.into())), |mut acc, mut b| {
acc.0[0] += *b.next().unwrap();
acc.0[1] += *b.next().unwrap();
acc
});
println!("a {:?}", a);
a
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
},
|sum| AdditiveArray(sum.0.map(E::from))
}
Expand All @@ -501,18 +495,25 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
&self.poly.flattened_ml_extensions[products[1]],
);
commutative_op_mle_pair!(
|f, g| pad_and_chunk_2!(f, g).fold(
AdditiveArray::<E, 3>(array::from_fn(|_| 0.into())),
|mut acc, (mut f, mut g)| {
let (f0, f1) = (f.next().unwrap(), f.next().unwrap()) ;
let (g0, g1) = (g.next().unwrap(), g.next().unwrap()) ;
acc.0[0] += *f0 * g0;
acc.0[1] += *f1 * g1;
acc.0[2] +=
(*f1 + f1 - f0) * (*g1 + g1 - g0);
acc
|f, g| {
let res = pad_and_chunk_2!(f, g).fold(
AdditiveArray::<E, 3>(array::from_fn(|_| 0.into())),
|mut acc, (mut f, mut g)| {
let (f0, f1) = (f.next().unwrap(), f.next().unwrap()) ;
let (g0, g1) = (g.next().unwrap(), g.next().unwrap()) ;
acc.0[0] += *f0 * g0;
acc.0[1] += *f1 * g1;
acc.0[2] +=
(*f1 + f1 - f0) * (*g1 + g1 - g0);
acc
});
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
),
},
|sum| AdditiveArray(sum.0.map(E::from))
)
.to_vec()
Expand All @@ -524,22 +525,29 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
&self.poly.flattened_ml_extensions[products[2]],
);
op_mle_3!(
|f1, f2, f3|
pad_and_chunk_2!(f1, f2, f3).fold(
AdditiveArray::<E, 4>(array::from_fn(|_| 0.into())), |mut acc, ((mut f1,mut f2),mut f3)| {
let (f10, f11) = (f1.next().unwrap(), f1.next().unwrap()) ;
let (f20, f21) = (f2.next().unwrap(), f2.next().unwrap()) ;
let (f30, f31) = (f3.next().unwrap(), f3.next().unwrap()) ;

let c1 = *f11 - f10;
let c2 = *f21 - f20;
let c3 = *f31 - f30;
acc.0[0] += *f10 * (*f20 * f30);
acc.0[1] += *f11 * (*f21 * f31);
acc.0[2] += (c1 + f11) * ((c2 + f21) * (c3 + f31));
acc.0[3] += (c1 + c1 + f11) * ((c2 + c2 + f21) * (c3 + c3 + f31));
acc
}),
|f1, f2, f3| {
let res = pad_and_chunk_2!(f1, f2, f3).fold(
AdditiveArray::<E, 4>(array::from_fn(|_| 0.into())), |mut acc, ((mut f1,mut f2),mut f3)| {
let (f10, f11) = (f1.next().unwrap(), f1.next().unwrap()) ;
let (f20, f21) = (f2.next().unwrap(), f2.next().unwrap()) ;
let (f30, f31) = (f3.next().unwrap(), f3.next().unwrap()) ;

let c1 = *f11 - f10;
let c2 = *f21 - f20;
let c3 = *f31 - f30;
acc.0[0] += *f10 * (*f20 * f30);
acc.0[1] += *f11 * (*f21 * f31);
acc.0[2] += (c1 + f11) * ((c2 + f21) * (c3 + f31));
acc.0[3] += (c1 + c1 + f11) * ((c2 + c2 + f21) * (c3 + c3 + f31));
acc
});
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
},
|sum| AdditiveArray(sum.0.map(E::from))
)
.to_vec()
Expand Down Expand Up @@ -570,7 +578,6 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {

end_timer!(start);

println!("round {} final product sum {:?}", self.round, products_sum);
IOPProverMessage {
evaluations: products_sum,
}
Expand Down Expand Up @@ -800,12 +807,18 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
let f = &self.poly.flattened_ml_extensions[products[0]];
op_mle! {
|f| {
par_pad_and_chunk_2!(f).fold_with(AdditiveArray::<E, 2>(array::from_fn(|_| 0.into())), |mut acc, b| {
let res = par_pad_and_chunk_2!(f).fold_with(AdditiveArray::<E, 2>(array::from_fn(|_| 0.into())), |mut acc, b| {
acc.0[0] += b[0];
acc.0[1] += b[1];
acc
}).reduce_with(|acc, item| acc + item)
.unwrap()
.unwrap();
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
},
|sum| AdditiveArray(sum.0.map(E::from))
}
Expand All @@ -817,17 +830,25 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
&self.poly.flattened_ml_extensions[products[1]],
);
commutative_op_mle_pair!(
|f, g| par_pad_and_chunk_2!(f, g).fold_with(
AdditiveArray::<E, 3>(array::from_fn(|_| 0.into())),
|mut acc, (f, g)| {
acc.0[0] += *f[0] * g[0];
acc.0[1] += *f[1] * g[1];
acc.0[2] +=
(*f[1] + f[1] - f[0]) * (*g[1] + g[1] - g[0]);
acc
|f, g| {
let res = par_pad_and_chunk_2!(f, g).fold_with(
AdditiveArray::<E, 3>(array::from_fn(|_| 0.into())),
|mut acc, (f, g)| {
acc.0[0] += *f[0] * g[0];
acc.0[1] += *f[1] * g[1];
acc.0[2] +=
(*f[1] + f[1] - f[0]) * (*g[1] + g[1] - g[0]);
acc
}
).reduce_with(|acc, item| acc + item)
.unwrap();
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
).reduce_with(|acc, item| acc + item)
.unwrap(),
},
|sum| AdditiveArray(sum.0.map(E::from))
)
.to_vec()
Expand All @@ -839,19 +860,26 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
&self.poly.flattened_ml_extensions[products[2]],
);
op_mle_3!(
|f1, f2, f3|
par_pad_and_chunk_2!(f1, f2, f3).fold_with(
AdditiveArray::<E, 4>(array::from_fn(|_| 0.into())), |mut acc, ((f1, f2), f3)| {
let c1 = *f1[1] - f1[0];
let c2 = *f2[1] - f2[0];
let c3 = *f3[1] - f3[0];
acc.0[0] += *f1[0] * (*f2[0] * f3[0]);
acc.0[1] += *f1[1] * (*f2[1] * *f3[1]);
acc.0[2] += (c1 + f1[1]) * ((c2 + f2[1]) * (c3 + *f3[1]));
acc.0[3] += (c1 + c1 + f1[1]) * ((c2 + c2 + f2[1]) * (c3 + c3 + *f3[1]));
acc
}).reduce_with(|acc, item| acc + item)
.unwrap(),
|f1, f2, f3| {
let res = par_pad_and_chunk_2!(f1, f2, f3).fold_with(
AdditiveArray::<E, 4>(array::from_fn(|_| 0.into())), |mut acc, ((f1, f2), f3)| {
let c1 = *f1[1] - f1[0];
let c2 = *f2[1] - f2[0];
let c3 = *f3[1] - f3[0];
acc.0[0] += *f1[0] * (*f2[0] * f3[0]);
acc.0[1] += *f1[1] * (*f2[1] * *f3[1]);
acc.0[2] += (c1 + f1[1]) * ((c2 + f2[1]) * (c3 + *f3[1]));
acc.0[3] += (c1 + c1 + f1[1]) * ((c2 + c2 + f2[1]) * (c3 + c3 + *f3[1]));
acc
}).reduce_with(|acc, item| acc + item)
.unwrap();
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity)))
} else {
res
}
},
|sum| AdditiveArray(sum.0.map(E::from))
)
.to_vec()
Expand Down
Loading

0 comments on commit ce2a5b1

Please sign in to comment.