Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DON'T MERGE: experiments #84

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions gkr/examples/keccak256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn main() {
#[allow(unused_mut)]
let mut max_thread_id: usize = env::var("RAYON_NUM_THREADS")
.map(|v| str::parse::<usize>(&v).unwrap_or(1))
.unwrap();
.unwrap_or(1);

if !is_power_of_2(max_thread_id) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
Expand Down Expand Up @@ -85,15 +85,15 @@ fn main() {
let subscriber = Registry::default()
.with(
fmt::layer()
.compact()
.compact()
.with_thread_ids(false)
.with_thread_names(false),
)
.with(EnvFilter::from_default_env())
.with(flame_layer.with_threads_collapsed(true));
tracing::subscriber::set_global_default(subscriber).unwrap();

for log2_n in 0..12 {
for log2_n in 10..11 {
let Some((proof, output_mle)) =
prove_keccak256::<GoldilocksExt2>(log2_n, &circuit, (1 << log2_n).min(max_thread_id))
else {
Expand Down
58 changes: 51 additions & 7 deletions gkr/src/prover.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::mem;
use std::{
mem,
ops::AddAssign,
time::{Duration, Instant},
};

use ark_std::{end_timer, start_timer};
use ff_ext::ExtensionField;
Expand Down Expand Up @@ -59,12 +63,21 @@ impl<E: ExtensionField> IOPProverState<E> {
)
});

let sumcheck_prove_dur = Instant::now();
let mut phase2_build_mle_durs = vec![];
let mut phase2_sumcheck_prove_durs = vec![];
let mut phase1_prove_sumcheck_durs = vec![];
let mut build_eq_durs = vec![];
let sumcheck_proofs = (0..circuit.layers.len() as LayerId)
.map(|layer_id| {
let timer = start_timer!(|| format!("Prove layer {}", layer_id));

prover_state.layer_id = layer_id;

let mut phase2_build_mle_dur = Duration::ZERO;
let mut phase2_sumcheck_dur = Duration::ZERO;
let mut phase1_prove_sumcheck_dur = Duration::ZERO;
let mut build_eq_dur = Duration::ZERO;
let dummy_step = SumcheckStepType::Undefined;
let proofs = circuit.layers[layer_id as usize]
.sumcheck_steps
Expand All @@ -85,8 +98,9 @@ impl<E: ExtensionField> IOPProverState<E> {
transcript,
)].to_vec()
},
(SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2, _) =>
[
(SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2, _) => {
let dur = Instant::now();
let proofs = [
prover_state
.prove_and_update_state_phase1_step1(
circuit,
Expand All @@ -99,8 +113,10 @@ impl<E: ExtensionField> IOPProverState<E> {
circuit_witness,
transcript,
),
].to_vec()
,
].to_vec();
phase1_prove_sumcheck_dur.add_assign(dur.elapsed());
proofs
},
(SumcheckStepType::Phase2Step1, step2, _) => {
let span = entered_span!("phase2_gkr");
let max_steps = match step2 {
Expand All @@ -113,9 +129,12 @@ impl<E: ExtensionField> IOPProverState<E> {
let mut layer_polys = (0..max_thread_id).map(|_| ArcDenseMultilinearExtension::default()).collect::<Vec<ArcDenseMultilinearExtension<E>>>();
let mut res = vec![];
for step in 0..max_steps {
let dur = Instant::now();
let bounded_eval_point = prover_state.to_next_step_point.clone();
eqs.push(build_eq_x_r_vec(&bounded_eval_point));
build_eq_dur.add_assign(dur.elapsed());
// build step round poly
let build_mle_dur = Instant::now();
let virtual_polys: Vec<VirtualPolynomial<E>> = (0..max_thread_id).into_par_iter().zip(layer_polys.par_iter_mut()).map(|(thread_id, layer_poly)| {
let span = entered_span!("build_poly");
let (next_layer_poly_step1, virtual_poly) = match step {
Expand Down Expand Up @@ -154,18 +173,23 @@ impl<E: ExtensionField> IOPProverState<E> {
},
_ => unimplemented!(),
};

if let Some(next_layer_poly_step1) = next_layer_poly_step1 {
let _ = mem::replace(layer_poly, next_layer_poly_step1);
}
exit_span!(span);
virtual_poly
}).collect();
let build_mle_dur = build_mle_dur.elapsed();
phase2_build_mle_dur.add_assign(build_mle_dur);

let sumcheck_dur = Instant::now();
let (sumcheck_proof, sumcheck_prover_state) = sumcheck::structs::IOPProverState::<E>::prove_batch_polys(
max_thread_id,
virtual_polys.try_into().unwrap(),
transcript,
);
phase2_sumcheck_dur.add_assign(sumcheck_dur.elapsed());

let iop_prover_step =
match step {
Expand Down Expand Up @@ -221,14 +245,34 @@ impl<E: ExtensionField> IOPProverState<E> {
})
.collect_vec();
end_timer!(timer);

println!("layer {} | phase2 build mle: {:?}", layer_id, phase2_build_mle_dur);
phase2_build_mle_durs.push(phase2_build_mle_dur);
phase2_sumcheck_prove_durs.push(phase2_sumcheck_dur);
phase1_prove_sumcheck_durs.push(phase1_prove_sumcheck_dur);
build_eq_durs.push(build_eq_dur);
proofs
})
.flatten()
.collect_vec();
end_timer!(timer);
exit_span!(span);

println!(
"phase2 build mle in total: {:?}",
phase2_build_mle_durs.iter().sum::<Duration>()
);
println!(
"phase2 prove sumcheck in total: {:?}",
phase2_sumcheck_prove_durs.iter().sum::<Duration>()
);
println!(
"phase1 prove sumcheck in total: {:?}",
phase1_prove_sumcheck_durs.iter().sum::<Duration>()
);
println!(
"build eq in total: {:?}",
build_eq_durs.iter().sum::<Duration>()
);
println!("prove sumcheck took {:?}", sumcheck_prove_dur.elapsed());
(
IOPProof { sumcheck_proofs },
GKRInputClaims {
Expand Down
7 changes: 7 additions & 0 deletions multilinear_extensions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,12 @@ goldilocks.workspace = true
rayon.workspace = true
serde.workspace = true

[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }

[[bench]]
name = "mle"
harness = false

[features]
parallel = [ ]
110 changes: 110 additions & 0 deletions multilinear_extensions/benches/mle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
use ark_std::rand::thread_rng;
use criterion::*;
use ff::Field;
use goldilocks::GoldilocksExt2;
use multilinear_extensions::{
mle::DenseMultilinearExtension,
virtual_poly::{build_eq_x_r, build_eq_x_r_sequential},
};

fn fix_var(c: &mut Criterion) {
let mut rng = thread_rng();

const NUM_SAMPLES: usize = 10;
for nv in 16..24 {
let mut group = c.benchmark_group("mle");
group.sample_size(NUM_SAMPLES);

for i in 0..nv {
group.bench_function(
BenchmarkId::new("fix_var", format!("({},{})", nv, nv - i)),
|b| {
b.iter_with_setup(
|| {
let mut v =
DenseMultilinearExtension::<GoldilocksExt2>::random(nv, &mut rng);
let r = GoldilocksExt2::random(&mut rng);
for _ in 0..i {
v.fix_variables_in_place(&[r]);
}
(v, r)
},
|(mut v, r)| v.fix_variables_in_place(&[r]),
);
},
);
}
group.finish();
}
}

fn fix_var_par(c: &mut Criterion) {
let mut rng = thread_rng();

const NUM_SAMPLES: usize = 10;
for nv in 16..24 {
let mut group = c.benchmark_group("mle");
group.sample_size(NUM_SAMPLES);

for i in 0..nv {
group.bench_function(
BenchmarkId::new("fix_var_par", format!("({},{})", nv, nv - i)),
|b| {
b.iter_with_setup(
|| {
let mut v =
DenseMultilinearExtension::<GoldilocksExt2>::random(nv, &mut rng);
let r = GoldilocksExt2::random(&mut rng);
for _ in 0..i {
v.fix_variables_in_place(&[r]);
}
(v, r)
},
|(mut v, r)| v.fix_variables_in_place_parallel(&[r]),
);
},
);
}
group.finish();
}
}

fn bench_build_eq_internal(c: &mut Criterion, use_par: bool) {
const NUM_SAMPLES: usize = 10;
let mut rng = thread_rng();
let group_name = if use_par {
"build_eq_par"
} else {
"build_eq_seq"
};
let mut group = c.benchmark_group(group_name);
group.sample_size(NUM_SAMPLES);

for num_vars in 15..24 {
group.bench_function(format!("{}", num_vars), |b| {
b.iter_batched(
|| {
(0..num_vars)
.map(|_| GoldilocksExt2::random(&mut rng))
.collect::<Vec<GoldilocksExt2>>()
},
|r| {
if use_par {
build_eq_x_r(&r)
} else {
build_eq_x_r_sequential(&r)
}
},
BatchSize::SmallInput,
);
});
}
}

fn bench_build_eq(c: &mut Criterion) {
bench_build_eq_internal(c, false);
bench_build_eq_internal(c, true);
}

criterion_group!(benches, bench_build_eq, fix_var, fix_var_par,);
criterion_main!(benches);
9 changes: 4 additions & 5 deletions multilinear_extensions/src/virtual_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ pub fn eq_eval<F: PrimeField>(x: &[F], y: &[F]) -> F {
/// Evaluate
/// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
/// over r, which is
/// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
/// eq(x,r) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
pub fn build_eq_x_r_sequential<E: ExtensionField>(r: &[E]) -> ArcDenseMultilinearExtension<E> {
let evals = build_eq_x_r_vec_sequential(r);
let mle = DenseMultilinearExtension::from_evaluations_ext_vec(r.len(), evals);
Expand Down Expand Up @@ -438,8 +438,7 @@ pub fn build_eq_x_r<E: ExtensionField>(r: &[E]) -> ArcDenseMultilinearExtension<
/// Evaluate
/// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
/// over r, which is
/// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))

/// eq(x,r) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
#[tracing::instrument(skip_all, name = "multilinear_extensions::build_eq_x_r_vec")]
pub fn build_eq_x_r_vec<E: ExtensionField>(r: &[E]) -> Vec<E> {
// avoid unnecessary allocation
Expand Down Expand Up @@ -479,8 +478,8 @@ fn build_eq_x_r_helper<E: ExtensionField>(r: &[E], buf: &mut [Vec<E>; 2]) {
for (i, r) in r.iter().rev().enumerate() {
let [current, next] = buf;
let (cur_size, next_size) = (1 << i, 1 << (i + 1));
// suppose at the previous step we processed buf [0..size]
// for the current step we are populating new buf[0..2*size]
// suppose at the previous step we processed current_buf[0..size]
// for the current step we are populating new_buf[0..2*size]
// for j travese 0..size
// buf[2*j + 1] = r * buf[j]
// buf[2*j] = (1 - r) * buf[j]
Expand Down
4 changes: 4 additions & 0 deletions sumcheck/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,9 @@ criterion = { version = "0.5", features = ["html_reports"] }
name = "devirgo_sumcheck"
harness = false

[[bench]]
name = "eval_partial_poly"
harness = false

[features]
non_pow2_rayon_thread = [ ]
Loading