From 61503c6cdd6b6de7182af5f11ba16087795a2fa6 Mon Sep 17 00:00:00 2001 From: Jay White Date: Wed, 2 Oct 2024 11:51:17 -0400 Subject: [PATCH 1/2] refactor: move `Subclaim` and inline `Subclaim::create` --- .../src/proof_primitive/sumcheck/mod.rs | 3 -- .../src/proof_primitive/sumcheck/proof.rs | 34 +++++++++++--- .../src/proof_primitive/sumcheck/subclaim.rs | 45 ------------------- 3 files changed, 27 insertions(+), 55 deletions(-) delete mode 100644 crates/proof-of-sql/src/proof_primitive/sumcheck/subclaim.rs diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs index 8f7cefed2..36de38a5f 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs @@ -6,9 +6,6 @@ pub use proof::SumcheckProof; mod prover_state; use prover_state::ProverState; -mod subclaim; -pub use subclaim::Subclaim; - mod prover_round; use prover_round::prove_round; diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs index 010306e3e..116802ff9 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs @@ -1,10 +1,10 @@ use crate::{ base::{ - polynomial::{CompositePolynomial, CompositePolynomialInfo}, + polynomial::{interpolate_uni_poly, CompositePolynomial, CompositePolynomialInfo}, proof::{ProofError, Transcript}, scalar::Scalar, }, - proof_primitive::sumcheck::{prove_round, ProverState, Subclaim}, + proof_primitive::sumcheck::{prove_round, ProverState}, }; /** * Adopted from arkworks @@ -18,6 +18,10 @@ use serde::{Deserialize, Serialize}; pub struct SumcheckProof { pub(super) evaluations: Vec>, } +pub struct Subclaim { + pub evaluation_point: Vec, + pub expected_evaluation: S, +} impl SumcheckProof { #[tracing::instrument(name = "SumcheckProof::create", level = "debug", skip_all)] @@ -74,11 +78,27 @@ impl SumcheckProof { transcript.extend_scalars_as_be(&self.evaluations[round_index]); evaluation_point.push(transcript.scalar_challenge_as_be()); } - Subclaim::create( + + assert!(polynomial_info.max_multiplicands > 0); + let mut expected_evaluation = *claimed_sum; + for round_index in 0..polynomial_info.num_variables { + let round_evaluation = &self.evaluations[round_index]; + if round_evaluation.len() != polynomial_info.max_multiplicands + 1 { + return Err(ProofError::VerificationError { + error: "round evaluation length does not match max multiplicands", + }); + } + if expected_evaluation != round_evaluation[0] + round_evaluation[1] { + return Err(ProofError::VerificationError { + error: "round evaluation does not match claimed sum", + }); + } + expected_evaluation = + interpolate_uni_poly(round_evaluation, evaluation_point[round_index]); + } + Ok(Subclaim { evaluation_point, - &self.evaluations, - polynomial_info.max_multiplicands, - claimed_sum, - ) + expected_evaluation, + }) } } diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/subclaim.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/subclaim.rs deleted file mode 100644 index 47d7f15b6..000000000 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/subclaim.rs +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Adopted from arkworks - * - * See third_party/license/arkworks.LICENSE - */ -use crate::base::scalar::Scalar; -use crate::base::{polynomial::interpolate_uni_poly, proof::ProofError}; -use alloc::vec::Vec; - -pub struct Subclaim { - pub evaluation_point: Vec, - pub expected_evaluation: S, -} - -impl Subclaim { - pub(super) fn create( - evaluation_point: Vec, - evaluations: &[Vec], - max_multiplicands: usize, - claimed_sum: &S, - ) -> Result, ProofError> { - let num_vars = evaluation_point.len(); - assert!(max_multiplicands > 0); - assert_eq!(num_vars, evaluations.len()); - let mut expected_sum = *claimed_sum; - for round_index in 0..num_vars { - let round_evaluation = &evaluations[round_index]; - if round_evaluation.len() != max_multiplicands + 1 { - return Err(ProofError::VerificationError { - error: "round evaluation length does not match max multiplicands", - }); - } - if expected_sum != round_evaluation[0] + round_evaluation[1] { - return Err(ProofError::VerificationError { - error: "round evaluation does not match claimed sum", - }); - } - expected_sum = interpolate_uni_poly(round_evaluation, evaluation_point[round_index]); - } - Ok(Subclaim { - evaluation_point, - expected_evaluation: expected_sum, - }) - } -} From 4f20785dc79e2b785ddbaa7ec9e8df601c97665f Mon Sep 17 00:00:00 2001 From: Jay White Date: Wed, 2 Oct 2024 14:35:19 -0400 Subject: [PATCH 2/2] refactor!: switch sumcheck proof to return coefficients instead of evaluations --- .../src/base/polynomial/interpolate.rs | 2 +- .../src/proof_primitive/sumcheck/proof.rs | 54 +++++++++++-------- .../proof_primitive/sumcheck/proof_test.rs | 4 +- .../proof_primitive/sumcheck/test_cases.rs | 2 +- 4 files changed, 37 insertions(+), 25 deletions(-) diff --git a/crates/proof-of-sql/src/base/polynomial/interpolate.rs b/crates/proof-of-sql/src/base/polynomial/interpolate.rs index 44d046553..f5614b921 100644 --- a/crates/proof-of-sql/src/base/polynomial/interpolate.rs +++ b/crates/proof-of-sql/src/base/polynomial/interpolate.rs @@ -14,6 +14,7 @@ use num_traits::{Inv, One, Zero}; /// For any polynomial, `f(x)`, with degree less than or equal to `d`, we have that: /// `f(x) = sum_{i=0}^{d} (-1)^(d-i) * (f(i) / (i! * (d-i)! * (x-i))) * prod_{i=0}^{d} (x-i)` /// unless x is one of 0,1,...,d, in which case, f(x) is already known. +#[allow(dead_code)] pub fn interpolate_uni_poly(polynomial: &[F], x: F) -> F where F: Copy @@ -72,7 +73,6 @@ where /// Let d be the evals.len() - 1 and let f be the polynomial such that f(i) = evals[i]. /// The output of this function is the vector of coefficients of f, leading coefficient first. /// That is, `f(x) = evals[j]*x^(d-j)``. -#[allow(dead_code)] pub fn interpolate_evaluations_to_reverse_coefficients(evals: &[S]) -> Vec where S: Zero diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs index 116802ff9..7fcde240a 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs @@ -1,6 +1,9 @@ use crate::{ base::{ - polynomial::{interpolate_uni_poly, CompositePolynomial, CompositePolynomialInfo}, + polynomial::{ + interpolate_evaluations_to_reverse_coefficients, CompositePolynomial, + CompositePolynomialInfo, + }, proof::{ProofError, Transcript}, scalar::Scalar, }, @@ -16,7 +19,7 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SumcheckProof { - pub(super) evaluations: Vec>, + pub(super) coefficients: Vec, } pub struct Subclaim { pub evaluation_point: Vec, @@ -39,16 +42,18 @@ impl SumcheckProof { transcript.scalar_challenge_as_be::(); let mut r = None; let mut state = ProverState::create(polynomial); - let mut evaluations = Vec::with_capacity(polynomial.num_variables); + let mut coefficients = Vec::with_capacity(polynomial.num_variables); for scalar in evaluation_point.iter_mut().take(polynomial.num_variables) { let round_evaluations = prove_round(&mut state, &r); - transcript.extend_scalars_as_be(&round_evaluations); + let round_coefficients = + interpolate_evaluations_to_reverse_coefficients(&round_evaluations); + transcript.extend_scalars_as_be(&round_coefficients); + coefficients.extend(round_coefficients); *scalar = transcript.scalar_challenge_as_be(); - evaluations.push(round_evaluations); r = Some(*scalar); } - SumcheckProof { evaluations } + SumcheckProof { coefficients } } #[tracing::instrument( @@ -68,33 +73,40 @@ impl SumcheckProof { ]); // This challenge is in order to keep transcript messages grouped. (This simplifies the Solidity implementation.) transcript.scalar_challenge_as_be::(); - if self.evaluations.len() != polynomial_info.num_variables { + if self.coefficients.len() + != polynomial_info.num_variables * (polynomial_info.max_multiplicands + 1) + { return Err(ProofError::VerificationError { - error: "invalid number of evaluations", + error: "invalid proof size", }); } let mut evaluation_point = Vec::with_capacity(polynomial_info.num_variables); - for round_index in 0..polynomial_info.num_variables { - transcript.extend_scalars_as_be(&self.evaluations[round_index]); - evaluation_point.push(transcript.scalar_challenge_as_be()); - } - assert!(polynomial_info.max_multiplicands > 0); let mut expected_evaluation = *claimed_sum; for round_index in 0..polynomial_info.num_variables { - let round_evaluation = &self.evaluations[round_index]; - if round_evaluation.len() != polynomial_info.max_multiplicands + 1 { - return Err(ProofError::VerificationError { - error: "round evaluation length does not match max multiplicands", - }); + let start_index = round_index * (polynomial_info.max_multiplicands + 1); + transcript.extend_scalars_as_be( + &self.coefficients + [start_index..start_index + polynomial_info.max_multiplicands + 1], + ); + let round_evaluation_point = transcript.scalar_challenge_as_be(); + evaluation_point.push(round_evaluation_point); + let mut round_evaluation = self.coefficients[start_index]; + let mut actual_sum = round_evaluation + + self.coefficients[start_index + polynomial_info.max_multiplicands]; + for coefficient_index in + start_index + 1..start_index + polynomial_info.max_multiplicands + 1 + { + round_evaluation *= round_evaluation_point; + round_evaluation += self.coefficients[coefficient_index]; + actual_sum += self.coefficients[coefficient_index]; } - if expected_evaluation != round_evaluation[0] + round_evaluation[1] { + if actual_sum != expected_evaluation { return Err(ProofError::VerificationError { error: "round evaluation does not match claimed sum", }); } - expected_evaluation = - interpolate_uni_poly(round_evaluation, evaluation_point[round_index]); + expected_evaluation = round_evaluation; } Ok(Subclaim { evaluation_point, diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs index 69666d868..218cf4cfb 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs @@ -68,7 +68,7 @@ fn test_create_verify_proof() { assert!(subclaim.is_err()); // verify fails if evaluations are changed - proof.evaluations[0][1] += Curve25519Scalar::from(3u64); + proof.coefficients[0] += Curve25519Scalar::from(3u64); let subclaim = proof.verify_without_evaluation( &mut transcript, poly.info(), @@ -229,7 +229,7 @@ fn we_can_verify_many_random_test_cases() { ); let mut modified_proof = proof; - modified_proof.evaluations[0][0] += TestScalar::ONE; + modified_proof.coefficients[0] += TestScalar::ONE; let mut transcript = Transcript::new(b"sumchecktest"); assert!( modified_proof diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/test_cases.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/test_cases.rs index 863c7b9ec..dcc6f0079 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/test_cases.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/test_cases.rs @@ -48,7 +48,7 @@ pub fn sumcheck_test_cases( rng: &mut (impl ark_std::rand::Rng + ?Sized), ) -> impl Iterator> + '_ { (1..=8) - .cartesian_product(1..=5) + .cartesian_product(0..=5) .flat_map(|(num_vars, max_multiplicands)| { [ Some(vec![]),