Skip to content

Commit

Permalink
chore: simplify and homogenize proving code for Nova and SuperNova pr…
Browse files Browse the repository at this point in the history
…oofs (#1121)

* Remove code duplication in Nova proving
* Factor out and reuse debugging code that was triggered only for Nova proving without "parallel steps"
* Remove the cloning of recursive snarks on SuperNova
* Drop `crossbeam` dependency because `crossbeam::thread::scope` has been soft-deprecated
  in favor of `std::thread::scope` since Rust 1.63 due to performance reasons
  • Loading branch information
arthurpaulino authored Feb 13, 2024
1 parent 4eed503 commit 1d4d308
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 109 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ stable_deref_trait = "1.2.0"
thiserror = { workspace = true }
abomonation = { workspace = true }
abomonation_derive = { version = "0.1.0", package = "abomonation_derive_ng" }
crossbeam = "0.8.2"
byteorder = "1.4.3"
circom-scotia = { git = "https://github.com/lurk-lab/circom-scotia", branch = "dev" }
sha2 = { version = "0.10.2" }
Expand Down
128 changes: 58 additions & 70 deletions src/proof/nova.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use bellpepper_core::{num::AllocatedNum, ConstraintSystem};
use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError};
use halo2curves::bn256::Fr as Bn256Scalar;
use nova::{
errors::NovaError,
Expand All @@ -18,14 +18,15 @@ use std::{
marker::PhantomData,
sync::{Arc, Mutex},
};
use tracing::info;

use crate::{
config::lurk_config,
coprocessor::Coprocessor,
error::ProofError,
eval::lang::Lang,
field::LurkField,
lem::{interpreter::Frame, pointers::Ptr, store::Store},
lem::{interpreter::Frame, multiframe::MultiFrame, pointers::Ptr, store::Store},
proof::{supernova::FoldingConfig, FrameLike, Prover},
};

Expand Down Expand Up @@ -223,6 +224,29 @@ pub fn circuits<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a>(
)
}

/// For debugging purposes, synthesize the circuit and check that the constraint
/// system is satisfied
#[inline]
pub(crate) fn debug_step<F: LurkField, C: Coprocessor<F>>(
circuit: &MultiFrame<'_, F, C>,
store: &Store<F>,
) -> Result<(), SynthesisError> {
use bellpepper_core::test_cs::TestConstraintSystem;
let mut cs = TestConstraintSystem::<F>::new();

let zi = store.to_scalar_vector(circuit.input());
let zi_allocated: Vec<_> = zi
.iter()
.enumerate()
.map(|(i, x)| AllocatedNum::alloc(cs.namespace(|| format!("z{i}_1")), || Ok(*x)))
.collect::<Result<_, _>>()?;

circuit.synthesize(&mut cs, zi_allocated.as_slice())?;

assert!(cs.is_satisfied());
Ok(())
}

impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> RecursiveSNARKTrait<F, C1LEM<'a, F, C>>
for Proof<F, C1LEM<'a, F, C>>
{
Expand All @@ -237,31 +261,42 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> RecursiveSNARKTrait<F, C1LEM<
steps: Vec<C1LEM<'a, F, C>>,
store: &Store<F>,
) -> Result<Self, ProofError> {
assert!(!steps.is_empty());
assert_eq!(steps[0].arity(), z0.len());
let debug = false;
let z0_primary = z0;
let z0_secondary = Self::z0_secondary();
assert_eq!(steps[0].arity(), z0.len());

let circuit_secondary = TrivialCircuit::default();
let secondary_circuit = TrivialCircuit::default();

let num_steps = steps.len();
tracing::debug!("steps.len: {num_steps}");
info!("proving {num_steps} steps");

let mut recursive_snark_option: Option<RecursiveSNARK<E1<F>>> = None;

// produce a recursive SNARK
let mut recursive_snark: Option<RecursiveSNARK<E1<F>>> = None;
let prove_step =
|i: usize, step: &C1LEM<'a, F, C>, rs: &mut Option<RecursiveSNARK<E1<F>>>| {
if debug {
debug_step(step, store).unwrap();
}
let mut recursive_snark = rs.take().unwrap_or_else(|| {
RecursiveSNARK::new(&pp.pp, step, &secondary_circuit, z0, &Self::z0_secondary())
.expect("failed to construct initial recursive SNARK")
});
info!("prove_step {i}");
recursive_snark
.prove_step(&pp.pp, step, &secondary_circuit)
.unwrap();
*rs = Some(recursive_snark);
};

// the shadowing here is voluntary
let recursive_snark = if lurk_config(None, None)
recursive_snark_option = if lurk_config(None, None)
.perf
.parallelism
.recursive_steps
.is_parallel()
{
let cc = steps.into_iter().map(Mutex::new).collect::<Vec<_>>();

crossbeam::thread::scope(|s| {
s.spawn(|_| {
std::thread::scope(|s| {
s.spawn(|| {
// Skip the very first circuit's witness, so `prove_step` can begin immediately.
// That circuit's witness will not be cached and will just be computed on-demand.
cc.iter().skip(1).for_each(|mf| {
Expand All @@ -272,69 +307,22 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> RecursiveSNARKTrait<F, C1LEM<
});
});

for circuit_primary in cc.iter() {
let mut circuit_primary = circuit_primary.lock().unwrap();

let mut r_snark = recursive_snark.unwrap_or_else(|| {
RecursiveSNARK::new(
&pp.pp,
&*circuit_primary,
&circuit_secondary,
z0_primary,
&z0_secondary,
)
.expect("Failed to construct initial recursive snark")
});
r_snark
.prove_step(&pp.pp, &*circuit_primary, &circuit_secondary)
.expect("failure to prove Nova step");
circuit_primary.clear_cached_witness();
recursive_snark = Some(r_snark);
for (i, step) in cc.iter().enumerate() {
let mut step = step.lock().unwrap();
prove_step(i, &step, &mut recursive_snark_option);
step.clear_cached_witness();
}
recursive_snark
recursive_snark_option
})
.unwrap()
} else {
for circuit_primary in steps.iter() {
if debug {
// For debugging purposes, synthesize the circuit and check that the constraint system is satisfied.
use bellpepper_core::test_cs::TestConstraintSystem;
let mut cs = TestConstraintSystem::<F>::new();

let zi = store.to_scalar_vector(circuit_primary.input());
let zi_allocated: Vec<_> = zi
.iter()
.enumerate()
.map(|(i, x)| {
AllocatedNum::alloc(cs.namespace(|| format!("z{i}_1")), || Ok(*x))
})
.collect::<Result<_, _>>()?;

circuit_primary.synthesize(&mut cs, zi_allocated.as_slice())?;

assert!(cs.is_satisfied());
}

let mut r_snark = recursive_snark.unwrap_or_else(|| {
RecursiveSNARK::new(
&pp.pp,
circuit_primary,
&circuit_secondary,
z0_primary,
&z0_secondary,
)
.expect("Failed to construct initial recursive snark")
});
r_snark
.prove_step(&pp.pp, circuit_primary, &circuit_secondary)
.expect("failure to prove Nova step");
recursive_snark = Some(r_snark);
for (i, step) in steps.iter().enumerate() {
prove_step(i, step, &mut recursive_snark_option);
}
recursive_snark
recursive_snark_option
};

Ok(Self::Recursive(
Box::new(recursive_snark.unwrap()),
Box::new(recursive_snark_option.expect("RecursiveSNARK missing")),
num_steps,
PhantomData,
))
Expand Down
74 changes: 36 additions & 38 deletions src/proof/supernova.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
field::LurkField,
lem::{interpreter::Frame, pointers::Ptr, store::Store},
proof::{
nova::{CurveCycleEquipped, Dual, NovaCircuitShape, E1},
nova::{debug_step, CurveCycleEquipped, Dual, NovaCircuitShape, E1},
Prover, RecursiveSNARKTrait,
},
};
Expand Down Expand Up @@ -202,39 +202,37 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> RecursiveSNARKTrait<F, C1LEM<
steps: Vec<C1LEM<'a, F, C>>,
store: &Store<F>,
) -> Result<Self, ProofError> {
let mut recursive_snark_option: Option<RecursiveSNARK<E1<F>>> = None;

let z0_primary = z0;
let z0_secondary = Self::z0_secondary();

let mut prove_step = |i: usize, step: &C1LEM<'a, F, C>| {
info!("prove_recursively, step {i}");

let secondary_circuit = step.secondary_circuit();

let mut recursive_snark = recursive_snark_option.clone().unwrap_or_else(|| {
info!("RecursiveSnark::new {i}");
RecursiveSNARK::new(
&pp.pp,
step,
step,
&secondary_circuit,
z0_primary,
&z0_secondary,
)
.unwrap()
});

info!("prove_step {i}");
let debug = false;

recursive_snark
.prove_step(&pp.pp, step, &secondary_circuit)
.unwrap();
info!("proving {} steps", steps.len());

recursive_snark_option = Some(recursive_snark);
};
let mut recursive_snark_option: Option<RecursiveSNARK<E1<F>>> = None;

if lurk_config(None, None)
let prove_step =
|i: usize, step: &C1LEM<'a, F, C>, rs: &mut Option<RecursiveSNARK<E1<F>>>| {
if debug {
debug_step(step, store).unwrap();
}
let secondary_circuit = step.secondary_circuit();
let mut recursive_snark = rs.take().unwrap_or_else(|| {
RecursiveSNARK::new(
&pp.pp,
step,
step,
&secondary_circuit,
z0,
&Self::z0_secondary(),
)
.expect("failed to construct initial recursive SNARK")
});
info!("prove_step {i}");
recursive_snark
.prove_step(&pp.pp, step, &secondary_circuit)
.unwrap();
*rs = Some(recursive_snark);
};

recursive_snark_option = if lurk_config(None, None)
.perf
.parallelism
.recursive_steps
Expand All @@ -245,8 +243,8 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> RecursiveSNARKTrait<F, C1LEM<
.map(|mf| (mf.program_counter() == 0, Mutex::new(mf)))
.collect::<Vec<_>>();

crossbeam::thread::scope(|s| {
s.spawn(|_| {
std::thread::scope(|s| {
s.spawn(|| {
// Skip the very first circuit's witness, so `prove_step` can begin immediately.
// That circuit's witness will not be cached and will just be computed on-demand.

Expand Down Expand Up @@ -280,18 +278,18 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> RecursiveSNARKTrait<F, C1LEM<

for (i, (_, step)) in cc.iter().enumerate() {
let mut step = step.lock().unwrap();
prove_step(i, &step);
prove_step(i, &step, &mut recursive_snark_option);
step.clear_cached_witness();
}
recursive_snark_option
})
.unwrap()
} else {
for (i, step) in steps.iter().enumerate() {
prove_step(i, step);
prove_step(i, step, &mut recursive_snark_option);
}
}
recursive_snark_option
};

// This probably should be made unnecessary.
Ok(Self::Recursive(
Box::new(recursive_snark_option.expect("RecursiveSNARK missing")),
PhantomData,
Expand Down

1 comment on commit 1d4d308

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks

Table of Contents

Overview

This benchmark report shows the Fibonacci GPU benchmark.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
32 vCPUs
125 GB RAM
Workflow run: https://github.com/lurk-lab/lurk-rs/actions/runs/7892310674

Benchmark Results

LEM Fibonacci Prove - rc = 100

ref=4eed503a8f44e739c592bcb1ef07301e44437d86 ref=1d4d308e2bc12f5ab431ea210c0b722f9eb31825
num-100 1.45 s (✅ 1.00x) 1.45 s (✅ 1.00x faster)
num-200 2.77 s (✅ 1.00x) 2.78 s (✅ 1.00x slower)

LEM Fibonacci Prove - rc = 600

ref=4eed503a8f44e739c592bcb1ef07301e44437d86 ref=1d4d308e2bc12f5ab431ea210c0b722f9eb31825
num-100 1.83 s (✅ 1.00x) 1.84 s (✅ 1.00x slower)
num-200 3.04 s (✅ 1.00x) 3.03 s (✅ 1.00x faster)

Made with criterion-table

Please sign in to comment.