Skip to content

Commit

Permalink
Make CommitmentSchemeProver::prove_values take ownership (#852)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored Nov 24, 2024
1 parent cd8b37b commit 3e5a81d
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 89 deletions.
2 changes: 1 addition & 1 deletion crates/prover/src/core/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ impl<B: FriOps + MerkleOps<H>, H: MerkleHasher> FriLayerProver<B, H> {
let commitment = self.merkle_tree.root();
// TODO(andrew): Use _evals.
let (_evals, decommitment) = self.merkle_tree.decommit(
[(self.evaluation.len().ilog2(), decommit_positions)]
&[(self.evaluation.len().ilog2(), decommit_positions)]
.into_iter()
.collect(),
self.evaluation.values.columns.iter().collect_vec(),
Expand Down
8 changes: 5 additions & 3 deletions crates/prover/src/core/pcs/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
}

pub fn prove_values(
&self,
self,
sampled_points: TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>,
channel: &mut MC::C,
) -> CommitmentSchemeProof<MC::H> {
Expand Down Expand Up @@ -133,13 +133,14 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
.iter()
.map(|(&log_size, domain)| (log_size, domain.flatten()))
.collect();
tree.decommit(queries)
tree.decommit(&queries)
});

let queried_values = decommitment_results.as_ref().map(|(v, _)| v.clone());
let decommitments = decommitment_results.map(|(_, d)| d);

CommitmentSchemeProof {
commitments: self.roots(),
sampled_values,
decommitments,
queried_values,
Expand All @@ -151,6 +152,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,

#[derive(Debug, Serialize, Deserialize)]
pub struct CommitmentSchemeProof<H: MerkleHasher> {
pub commitments: TreeVec<H::Hash>,
pub sampled_values: TreeVec<ColumnVec<Vec<SecureField>>>,
pub decommitments: TreeVec<MerkleDecommitment<H>>,
pub queried_values: TreeVec<ColumnVec<Vec<BaseField>>>,
Expand Down Expand Up @@ -231,7 +233,7 @@ impl<B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentTreeProver<B, MC> {
/// positions on each column of that size.
fn decommit(
&self,
queries: BTreeMap<u32, Vec<usize>>,
queries: &BTreeMap<u32, Vec<usize>>,
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<MC::H>) {
let eval_vec = self
.evaluations
Expand Down
103 changes: 51 additions & 52 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::ops::Deref;
use std::{array, mem};

use serde::{Deserialize, Serialize};
Expand All @@ -9,7 +10,7 @@ use super::backend::BackendForChannel;
use super::channel::MerkleChannel;
use super::fields::secure_column::SECURE_EXTENSION_DEGREE;
use super::fri::FriVerificationError;
use super::pcs::{CommitmentSchemeProof, TreeVec};
use super::pcs::CommitmentSchemeProof;
use super::vcs::ops::MerkleHasher;
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::channel::Channel;
Expand All @@ -22,17 +23,11 @@ use crate::core::vcs::hash::Hash;
use crate::core::vcs::prover::MerkleDecommitment;
use crate::core::vcs::verifier::MerkleVerificationError;

#[derive(Debug, Serialize, Deserialize)]
pub struct StarkProof<H: MerkleHasher> {
pub commitments: TreeVec<H::Hash>,
pub commitment_scheme_proof: CommitmentSchemeProof<H>,
}

#[instrument(skip_all)]
pub fn prove<B: BackendForChannel<MC>, MC: MerkleChannel>(
components: &[&dyn ComponentProver<B>],
channel: &mut MC::C,
commitment_scheme: &mut CommitmentSchemeProver<'_, B, MC>,
mut commitment_scheme: CommitmentSchemeProver<'_, B, MC>,
) -> Result<StarkProof<MC::H>, ProvingError> {
let n_preprocessed_columns = commitment_scheme.trees[PREPROCESSED_TRACE_IDX]
.polynomials
Expand Down Expand Up @@ -67,25 +62,19 @@ pub fn prove<B: BackendForChannel<MC>, MC: MerkleChannel>(

// Prove the trace and composition OODS values, and retrieve them.
let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel);

let sampled_oods_values = &commitment_scheme_proof.sampled_values;
let composition_oods_eval = extract_composition_eval(sampled_oods_values).unwrap();
let proof = StarkProof(commitment_scheme_proof);
info!(proof_size_estimate = proof.size_estimate());

// Evaluate composition polynomial at OODS point and check that it matches the trace OODS
// values. This is a sanity check.
if composition_oods_eval
if proof.extract_composition_oods_eval().unwrap()
!= component_provers
.components()
.eval_composition_polynomial_at_point(oods_point, sampled_oods_values, random_coeff)
.eval_composition_polynomial_at_point(oods_point, &proof.sampled_values, random_coeff)
{
return Err(ProvingError::ConstraintsNotSatisfied);
}

let proof = StarkProof {
commitments: commitment_scheme.roots(),
commitment_scheme_proof,
};
info!(proof_size_estimate = proof.size_estimate());
Ok(proof)
}

Expand Down Expand Up @@ -120,42 +109,21 @@ pub fn verify<MC: MerkleChannel>(
// Add the composition polynomial mask points.
sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]);

let sampled_oods_values = &proof.commitment_scheme_proof.sampled_values;
let composition_oods_eval = extract_composition_eval(sampled_oods_values).map_err(|_| {
let composition_oods_eval = proof.extract_composition_oods_eval().map_err(|_| {
VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string())
})?;

if composition_oods_eval
!= components.eval_composition_polynomial_at_point(
oods_point,
sampled_oods_values,
&proof.sampled_values,
random_coeff,
)
{
return Err(VerificationError::OodsNotMatching);
}

commitment_scheme.verify_values(sample_points, proof.commitment_scheme_proof, channel)
}

/// Extracts the composition trace evaluation from the mask.
fn extract_composition_eval(
mask: &TreeVec<Vec<Vec<SecureField>>>,
) -> Result<SecureField, InvalidOodsSampleStructure> {
let mut composition_cols = mask.last().into_iter().flatten();

let coordinate_evals = array::try_from_fn(|_| {
let col = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?;
let [eval] = col.try_into().map_err(|_| InvalidOodsSampleStructure)?;
Ok(eval)
})?;

// Too many columns.
if composition_cols.next().is_some() {
return Err(InvalidOodsSampleStructure);
}

Ok(SecureField::from_partial_evals(coordinate_evals))
commitment_scheme.verify_values(sample_points, proof.0, channel)
}

/// Error when the sampled values have an invalid structure.
Expand Down Expand Up @@ -187,20 +155,44 @@ pub enum VerificationError {
ProofOfWork,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct StarkProof<H: MerkleHasher>(pub CommitmentSchemeProof<H>);

impl<H: MerkleHasher> StarkProof<H> {
/// Extracts the composition trace Out-Of-Domain-Sample evaluation from the mask.
fn extract_composition_oods_eval(&self) -> Result<SecureField, InvalidOodsSampleStructure> {
// TODO(andrew): `[.., composition_mask, _quotients_mask]` when add quotients commitment.
let [.., composition_mask] = &**self.sampled_values else {
return Err(InvalidOodsSampleStructure);
};

let mut composition_cols = composition_mask.iter();

let coordinate_evals = array::try_from_fn(|_| {
let col = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?;
let [eval] = col.try_into().map_err(|_| InvalidOodsSampleStructure)?;
Ok(eval)
})?;

// Too many columns.
if composition_cols.next().is_some() {
return Err(InvalidOodsSampleStructure);
}

Ok(SecureField::from_partial_evals(coordinate_evals))
}

/// Returns the estimate size (in bytes) of the proof.
pub fn size_estimate(&self) -> usize {
SizeEstimate::size_estimate(self)
}

/// Returns size estimates (in bytes) for different parts of the proof.
pub fn size_breakdown_estimate(&self) -> StarkProofSizeBreakdown {
let Self {
commitments,
commitment_scheme_proof,
} = self;
let Self(commitment_scheme_proof) = self;

let CommitmentSchemeProof {
commitments,
sampled_values,
decommitments,
queried_values,
Expand Down Expand Up @@ -236,6 +228,14 @@ impl<H: MerkleHasher> StarkProof<H> {
}
}

impl<H: MerkleHasher> Deref for StarkProof<H> {
type Target = CommitmentSchemeProof<H>;

fn deref(&self) -> &CommitmentSchemeProof<H> {
&self.0
}
}

/// Size estimate (in bytes) for different parts of the proof.
pub struct StarkProofSizeBreakdown {
pub oods_samples: usize,
Expand Down Expand Up @@ -313,13 +313,15 @@ impl<H: MerkleHasher> SizeEstimate for FriProof<H> {
impl<H: MerkleHasher> SizeEstimate for CommitmentSchemeProof<H> {
fn size_estimate(&self) -> usize {
let Self {
commitments,
sampled_values,
decommitments,
queried_values,
proof_of_work,
fri_proof,
} = self;
sampled_values.size_estimate()
commitments.size_estimate()
+ sampled_values.size_estimate()
+ decommitments.size_estimate()
+ queried_values.size_estimate()
+ mem::size_of_val(proof_of_work)
Expand All @@ -329,11 +331,8 @@ impl<H: MerkleHasher> SizeEstimate for CommitmentSchemeProof<H> {

impl<H: MerkleHasher> SizeEstimate for StarkProof<H> {
fn size_estimate(&self) -> usize {
let Self {
commitments,
commitment_scheme_proof,
} = self;
commitments.size_estimate() + commitment_scheme_proof.size_estimate()
let Self(commitment_scheme_proof) = self;
commitment_scheme_proof.size_estimate()
}
}

Expand Down
11 changes: 1 addition & 10 deletions crates/prover/src/core/vcs/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,9 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
/// * A `MerkleDecommitment` containing the hash and column witnesses.
pub fn decommit(
&self,
queries_per_log_size: BTreeMap<u32, Vec<usize>>,
queries_per_log_size: &BTreeMap<u32, Vec<usize>>,
columns: Vec<&Col<B, BaseField>>,
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<H>) {
// Check that queries are sorted and deduped.
// TODO(andrew): Consider using a Queries struct to prevent this.
for queries in queries_per_log_size.values() {
assert!(
queries.windows(2).all(|w| w[0] < w[1]),
"Queries are not sorted."
);
}

// Prepare output buffers.
let mut queried_values_by_layer = vec![];
let mut decommitment = MerkleDecommitment::empty();
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/vcs/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ where
queries.insert(log_size, layer_queries);
}

let (values, decommitment) = merkle.decommit(queries.clone(), cols.iter().collect_vec());
let (values, decommitment) = merkle.decommit(&queries, cols.iter().collect_vec());

let verifier = MerkleVerifier {
root: merkle.root(),
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ where

// Setup protocol.
let channel = &mut MC::C::default();
let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles);
let mut commitment_scheme = CommitmentSchemeProver::new(config, &twiddles);

// Preprocessed trace.
// TODO(ShaharS): share is_first column between components when constant columns support this.
Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ pub fn prove_fibonacci_plonk(

// Setup protocol.
let channel = &mut Blake2sChannel::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);
let mut commitment_scheme =
CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);

// Preprocessed trace.
let span = span!(Level::INFO, "Constant").entered();
Expand Down Expand Up @@ -298,7 +298,7 @@ mod tests {
// Retrieve the expected column sizes in each commitment interaction, from the AIR.
let sizes = component.trace_log_degree_bounds();

// Constant columns.
// Preprocessed columns.
commitment_scheme.commit(proof.commitments[0], &sizes[0], channel);

// Trace columns.
Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,8 @@ pub fn prove_poseidon(

// Setup protocol.
let channel = &mut Blake2sChannel::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);
let mut commitment_scheme =
CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);

// Preprocessed trace.
let span = span!(Level::INFO, "Constant").entered();
Expand Down Expand Up @@ -512,7 +512,7 @@ mod tests {
// Retrieve the expected column sizes in each commitment interaction, from the AIR.
let sizes = component.trace_log_degree_bounds();

// Constant columns.
// Preprocessed columns.
commitment_scheme.commit(proof.commitments[0], &sizes[0], channel);
// Trace columns.
commitment_scheme.commit(proof.commitments[1], &sizes[1], channel);
Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ pub fn prove_state_machine(
);

// Setup protocol.
let commitment_scheme =
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);
let mut commitment_scheme =
CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);

// Preprocessed trace.
let mut tree_builder = commitment_scheme.tree_builder();
Expand Down Expand Up @@ -142,7 +142,7 @@ pub fn verify_state_machine(
// Retrieve the expected column sizes in each commitment interaction, from the AIR.
let sizes = proof.stmt0.log_sizes();

// Constant columns.
// Preprocessed columns.
commitment_scheme.commit(proof.stark_proof.commitments[0], &sizes[0], channel);
// Trace columns.
proof.stmt0.mix_into(channel);
Expand Down
12 changes: 4 additions & 8 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,8 @@ mod tests {

// Setup protocol.
let prover_channel = &mut Blake2sChannel::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<SimdBackend, Blake2sMerkleChannel>::new(
config, &twiddles,
);
let mut commitment_scheme =
CommitmentSchemeProver::<SimdBackend, Blake2sMerkleChannel>::new(config, &twiddles);

// Preprocessed trace
let mut tree_builder = commitment_scheme.tree_builder();
Expand Down Expand Up @@ -242,10 +240,8 @@ mod tests {

// Setup protocol.
let prover_channel = &mut Poseidon252Channel::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<SimdBackend, Poseidon252MerkleChannel>::new(
config, &twiddles,
);
let mut commitment_scheme =
CommitmentSchemeProver::<SimdBackend, Poseidon252MerkleChannel>::new(config, &twiddles);

// TODO(ilya): remove the following once preproccessed columns are not mandatory.
// Preprocessed trace
Expand Down
Loading

1 comment on commit 3e5a81d

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 3e5a81d Previous: cd8b37b Ratio
merkle throughput/simd merkle 29986729 ns/iter (± 519529) 13712527 ns/iter (± 579195) 2.19

This comment was automatically generated by workflow using github-action-benchmark.

CC: @shaharsamocha7

Please sign in to comment.