Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into feat/impl_rv_add
Browse files Browse the repository at this point in the history
  • Loading branch information
KimiWu123 committed Jul 17, 2024
2 parents 83cd8b0 + b230551 commit f26d9ba
Show file tree
Hide file tree
Showing 69 changed files with 2,651 additions and 2,212 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fn main() {
println!("cargo:rerun-if-env-changed=RAYON_NUM_THREADS");
}
1 change: 1 addition & 0 deletions gkr-graph/src/circuit_graph_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl<E: ExtensionField> CircuitGraphBuilder<E> {
),
};
let old_num_instances = self.witness.node_witnesses[*id].n_instances();
// TODO find way to avoid expensive clone for wit_in
let new_instances = match pred {
PredType::PredWire(_) => {
let new_size = (old_num_instances * out[0].len()) / num_instances;
Expand Down
93 changes: 68 additions & 25 deletions gkr-graph/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl<E: ExtensionField> IOPProverState<E> {
expected_max_thread_id: usize,
) -> Result<IOPProof<E>, GKRGraphError> {
assert_eq!(target_evals.0.len(), circuit.targets.len());
assert_eq!(circuit_witness.node_witnesses.len(), circuit.nodes.len());

let mut output_evals = vec![vec![]; circuit.nodes.len()];
let mut wit_out_evals = circuit
Expand All @@ -36,10 +37,42 @@ impl<E: ExtensionField> IOPProverState<E> {
let gkr_proofs = izip!(&circuit.nodes, &circuit_witness.node_witnesses)
.rev()
.map(|(node, witness)| {
// println!("expected_max_thread_id {:?}", expected_max_thread_id);
let max_thread_id = witness.n_instances().min(expected_max_thread_id);
// println!("max_thread_id {:?}", max_thread_id);
let timer = std::time::Instant::now();

// sanity check for witness poly evaluation
if cfg!(debug_assertions) {

// TODO figure out a way to do sanity check on output_evals
// it doens't work for now because output evaluation
// might only take partial range of output layer witness
// assert!(output_evals[node.id].len() <= 1);
// if !output_evals[node.id].is_empty() {
// debug_assert_eq!(
// witness
// .output_layer_witness_ref()
// .instances
// .as_slice()
// .original_mle()
// .evaluate(&point_and_eval.point),
// point_and_eval.eval,
// "node_id {} output eval failed",
// node.id,
// );
// }

for (witness_id, point_and_eval) in wit_out_evals[node.id].iter().enumerate() {
let mle = witness.witness_out_ref()[witness_id]
.instances
.as_slice()
.original_mle();
debug_assert_eq!(
mle.evaluate(&point_and_eval.point),
point_and_eval.eval,
"node_id {} output eval failed",
node.id,
);
}
}
let (proof, input_claim) = GKRProverState::prove_parallel(
&node.circuit,
witness,
Expand All @@ -48,6 +81,7 @@ impl<E: ExtensionField> IOPProverState<E> {
max_thread_id,
transcript,
);

// println!(
// "Proving node {}, label {}, num_instances:{}, took {}s",
// node.id,
Expand All @@ -56,52 +90,61 @@ impl<E: ExtensionField> IOPProverState<E> {
// timer.elapsed().as_secs_f64()
// );

izip!(&node.preds, input_claim.point_and_evals)
izip!(&node.preds, &input_claim.point_and_evals)
.enumerate()
.for_each(|(wire_id, (pred, point_and_eval))| match pred {
.for_each(|(wire_id, (pred_type, point_and_eval))| match pred_type {
PredType::Source => {
debug_assert_eq!(
witness.witness_in_ref()[wire_id as usize]
// sanity check for input poly evaluation
if cfg!(debug_assertions) {
let input_layer_poly = witness.witness_in_ref()[wire_id]
.instances
.as_slice()
.original_mle()
.evaluate(&point_and_eval.point),
point_and_eval.eval
);
.original_mle();
debug_assert_eq!(
input_layer_poly.evaluate(&point_and_eval.point),
point_and_eval.eval,
"mismatch at node.id {:?} wire_id {:?}, input_claim.point_and_evals.point {:?}, node.preds {:?}",
node.id,
wire_id,
input_claim.point_and_evals[0].point,
node.preds
);
}
}
PredType::PredWire(out) | PredType::PredWireDup(out) => {
let point = match pred {
PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => {
let point = match pred_type {
PredType::PredWire(_) => point_and_eval.point.clone(),
PredType::PredWireDup(out) => {
let node_id = match out {
let pred_node_id = match out {
NodeOutputType::OutputLayer(id) => id,
NodeOutputType::WireOut(id, _) => id,
};
// Suppose the new point is
// [single_instance_slice ||
// new_instance_index_slice]. The old point
// is [single_instance_slices ||
// new_instance_index_slices[(new_instance_num_vars
// - old_instance_num_vars)..]]
let old_instance_num_vars = circuit_witness.node_witnesses
[*node_id]
// new_instance_index_slices[(instance_num_vars
// - pred_instance_num_vars)..]]
let pred_instance_num_vars = circuit_witness.node_witnesses
[*pred_node_id]
.instance_num_vars();
let new_instance_num_vars = witness.instance_num_vars();
let num_vars =
point_and_eval.point.len() - new_instance_num_vars;
let instance_num_vars = witness.instance_num_vars();
let num_vars = point_and_eval.point.len() - instance_num_vars;
[
point_and_eval.point[..num_vars].to_vec(),
point_and_eval.point[num_vars
+ (new_instance_num_vars - old_instance_num_vars)..]
+ (instance_num_vars - pred_instance_num_vars)..]
.to_vec(),
]
.concat()
}
_ => unreachable!(),
};
match out {
NodeOutputType::OutputLayer(id) => output_evals[*id]
.push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)),
match pred_out {
NodeOutputType::OutputLayer(id) => {
output_evals[*id]
.push(PointAndEval::new_from_ref(&point, &point_and_eval.eval))
},
NodeOutputType::WireOut(id, wire_id) => {
let evals = &mut wit_out_evals[*id][*wire_id as usize];
assert!(
Expand Down
90 changes: 47 additions & 43 deletions gkr-graph/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,52 +50,56 @@ impl<E: ExtensionField> IOPVerifierState<E> {

let new_instance_num_vars = aux_info.instance_num_vars[node.id];

izip!(&node.preds, input_claim.point_and_evals).for_each(|(pred, point_and_eval)| {
match pred {
PredType::Source => {
// TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations` for later PCS open?
}
PredType::PredWire(out) | PredType::PredWireDup(out) => {
let old_point = match pred {
PredType::PredWire(_) => point_and_eval.point.clone(),
PredType::PredWireDup(out) => {
let node_id = match out {
NodeOutputType::OutputLayer(id) => *id,
NodeOutputType::WireOut(id, _) => *id,
};
// Suppose the new point is
// [single_instance_slice ||
// new_instance_index_slice]. The old point
// is [single_instance_slices ||
// new_instance_index_slices[(new_instance_num_vars
// - old_instance_num_vars)..]]
let old_instance_num_vars = aux_info.instance_num_vars[node_id];
let num_vars = point_and_eval.point.len() - new_instance_num_vars;
[
point_and_eval.point[..num_vars].to_vec(),
point_and_eval.point[num_vars
+ (new_instance_num_vars - old_instance_num_vars)..]
.to_vec(),
]
.concat()
}
_ => unreachable!(),
};
match out {
NodeOutputType::OutputLayer(id) => output_evals[*id]
.push(PointAndEval::new_from_ref(&old_point, &point_and_eval.eval)),
NodeOutputType::WireOut(id, wire_id) => {
let evals = &mut wit_out_evals[*id][*wire_id as usize];
assert!(
evals.point.is_empty() && evals.eval.is_zero_vartime(),
"unimplemented",
);
*evals = PointAndEval::new(old_point, point_and_eval.eval);
izip!(&node.preds, input_claim.point_and_evals).for_each(
|(pred_type, point_and_eval)| {
match pred_type {
PredType::Source => {
// TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations`
// for later PCS open?
}
PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => {
let point = match pred_type {
PredType::PredWire(_) => point_and_eval.point.clone(),
PredType::PredWireDup(out) => {
let node_id = match out {
NodeOutputType::OutputLayer(id) => *id,
NodeOutputType::WireOut(id, _) => *id,
};
// Suppose the new point is
// [single_instance_slice ||
// new_instance_index_slice]. The old point
// is [single_instance_slices ||
// new_instance_index_slices[(new_instance_num_vars
// - old_instance_num_vars)..]]
let old_instance_num_vars = aux_info.instance_num_vars[node_id];
let num_vars =
point_and_eval.point.len() - new_instance_num_vars;
[
point_and_eval.point[..num_vars].to_vec(),
point_and_eval.point[num_vars
+ (new_instance_num_vars - old_instance_num_vars)..]
.to_vec(),
]
.concat()
}
_ => unreachable!(),
};
match pred_out {
NodeOutputType::OutputLayer(id) => output_evals[*id]
.push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)),
NodeOutputType::WireOut(id, wire_id) => {
let evals = &mut wit_out_evals[*id][*wire_id as usize];
assert!(
evals.point.is_empty() && evals.eval.is_zero_vartime(),
"unimplemented",
);
*evals = PointAndEval::new(point, point_and_eval.eval);
}
}
}
}
}
});
},
);
}

Ok(())
Expand Down
4 changes: 1 addition & 3 deletions gkr/benches/keccak256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use gkr::gadgets::keccak256::{keccak256_circuit, prove_keccak256, verify_keccak2
use goldilocks::GoldilocksExt2;
use sumcheck::util::is_power_of_2;

// cargo bench --bench keccak256 --features parallel --features flamegraph --package gkr -- --profile-time <secs>
cfg_if::cfg_if! {
if #[cfg(feature = "flamegraph")] {
criterion_group! {
Expand Down Expand Up @@ -48,8 +47,7 @@ fn bench_keccak256(c: &mut Criterion) {

#[cfg(feature = "non_pow2_rayon_thread")]
{
use sumcheck::local_thread_pool::create_local_pool_once;
use sumcheck::util::ceil_log2;
use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2};
let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS);
create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true);
max_thread_id
Expand Down
3 changes: 1 addition & 2 deletions gkr/examples/keccak256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ fn main() {

#[cfg(feature = "non_pow2_rayon_thread")]
{
use sumcheck::local_thread_pool::create_local_pool_once;
use sumcheck::util::ceil_log2;
use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2};
max_thread_id = 1 << ceil_log2(max_thread_id);
create_local_pool_once(max_thread_id, true);
}
Expand Down
13 changes: 7 additions & 6 deletions gkr/src/circuit/circuit_layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ impl<E: ExtensionField> Circuit<E> {
});
let segment = (
wire_ids_in_layer[in_cell_ids[0]],
wire_ids_in_layer[in_cell_ids[in_cell_ids.len() - 1]] + 1,
wire_ids_in_layer[in_cell_ids[in_cell_ids.len() - 1]] + 1, /* + 1 for exclusive
* last index */
);
match ty {
InType::Witness(wit_id) => {
Expand Down Expand Up @@ -258,9 +259,10 @@ impl<E: ExtensionField> Circuit<E> {
.push(output_subsets.update_wire_id(old_layer_id, old_wire_id));
}
OutType::AssertConst(constant) => {
let new_wire_id = output_subsets.update_wire_id(old_layer_id, old_wire_id);
output_assert_const.push(GateCIn {
idx_in: [],
idx_out: output_subsets.update_wire_id(old_layer_id, old_wire_id),
idx_out: new_wire_id,
scalar: ConstantType::Field(i64_to_field(constant)),
});
}
Expand Down Expand Up @@ -288,8 +290,7 @@ impl<E: ExtensionField> Circuit<E> {
} else {
let last_layer = &layers[(layer_id - 1) as usize];
if !last_layer.is_linear() || !layer.copy_to.is_empty() {
curr_sc_steps
.extend([SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2]);
curr_sc_steps.extend([SumcheckStepType::Phase1Step1]);
}
}

Expand Down Expand Up @@ -900,7 +901,7 @@ mod tests {
// Single input witness, therefore no input phase 2 steps.
assert_eq!(
circuit.layers[2].sumcheck_steps,
vec![SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2,]
vec![SumcheckStepType::Phase1Step1]
);
// There are only one incoming evals since the last layer is linear, and
// no subset evals. Therefore, there are no phase1 steps.
Expand Down Expand Up @@ -931,7 +932,7 @@ mod tests {
// Single input witness, therefore no input phase 2 steps.
assert_eq!(
circuit.layers[1].sumcheck_steps,
vec![SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2]
vec![SumcheckStepType::Phase1Step1]
);
// Output layer, single output witness, therefore no output phase 1 steps.
assert_eq!(
Expand Down
Loading

0 comments on commit f26d9ba

Please sign in to comment.