Skip to content

Commit

Permalink
optimize sumcheck algo
Browse files Browse the repository at this point in the history
circuit witness: direct witness on mle

devirgo style on phase1_output
  • Loading branch information
hero78119 committed Jul 15, 2024
1 parent ad4a6d8 commit 4226a3e
Show file tree
Hide file tree
Showing 65 changed files with 3,961 additions and 2,329 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 10 additions & 12 deletions gkr-graph/examples/series_connection_alt.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use ff::Field;
use ff_ext::ExtensionField;
use gkr::{
structs::{Circuit, LayerWitness, PointAndEval},
utils::MultilinearExtensionFromVectors,
structs::{Circuit, PointAndEval},
util::ceil_log2,
};
use gkr_graph::{
error::GKRGraphError,
Expand All @@ -12,6 +12,7 @@ use gkr_graph::{
},
};
use goldilocks::{Goldilocks, GoldilocksExt2};
use multilinear_extensions::mle::DenseMultilinearExtension;
use simple_frontend::structs::{ChallengeId, CircuitBuilder, MixedCell};
use std::sync::Arc;
use transcript::Transcript;
Expand Down Expand Up @@ -153,7 +154,7 @@ fn main() -> Result<(), GKRGraphError> {
circuit: &Arc<Circuit<_>>,
preds: Vec<PredType>,
challenges: Vec<_>,
sources: Vec<LayerWitness<_>>,
sources: Vec<DenseMultilinearExtension<_>>,
num_instances: usize|
-> Result<usize, GKRGraphError> {
let prover_node_id = prover_graph_builder.add_node_with_witness(
Expand All @@ -174,10 +175,10 @@ fn main() -> Result<(), GKRGraphError> {
&input_circuit,
vec![PredType::Source],
challenge,
// input_circuit_wires_in.clone()
vec![LayerWitness {
instances: vec![input_circuit_wires_in.clone()],
}],
vec![DenseMultilinearExtension::from_evaluations_vec(
ceil_log2(input_circuit_wires_in.len()),
input_circuit_wires_in.clone(),
)],
1,
)?;
let selector = add_node_and_witness("selector", &prefix_selector, vec![], vec![], vec![], 1)?;
Expand All @@ -191,7 +192,7 @@ fn main() -> Result<(), GKRGraphError> {
PredType::PredWire(NodeOutputType::OutputLayer(selector)),
],
vec![],
vec![LayerWitness::default(); 2],
vec![DenseMultilinearExtension::default(); 2],
round_input_size >> 1,
)?;
round_input_size >>= 1;
Expand All @@ -203,7 +204,7 @@ fn main() -> Result<(), GKRGraphError> {
&frac_sum_circuit,
vec![PredType::PredWire(frac_sum_input)],
vec![],
vec![LayerWitness::default(); 1],
vec![DenseMultilinearExtension::default(); 1],
round_input_size >> 1,
)?,
0,
Expand Down Expand Up @@ -237,9 +238,6 @@ fn main() -> Result<(), GKRGraphError> {
.last()
.unwrap()
.output_layer_witness_ref()
.instances
.as_slice()
.original_mle()
.evaluate(&output_point);
let proof = IOPProverState::prove(
&prover_graph,
Expand Down
25 changes: 9 additions & 16 deletions gkr-graph/src/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
use ff_ext::ExtensionField;
use gkr::{
structs::{Point, PointAndEval},
utils::MultilinearExtensionFromVectors,
};
use gkr::structs::{Point, PointAndEval};
use itertools::Itertools;

use crate::structs::{CircuitGraph, CircuitGraphWitness, NodeOutputType, TargetEvaluations};

impl<E: ExtensionField> CircuitGraph<E> {
pub fn target_evals(
&self,
witness: &CircuitGraphWitness<E::BaseField>,
witness: &CircuitGraphWitness<E>,
point: &Point<E>,
) -> TargetEvaluations<E> {
// println!("targets: {:?}, point: {:?}", self.targets, point);
Expand All @@ -19,19 +16,15 @@ impl<E: ExtensionField> CircuitGraph<E> {
.iter()
.map(|target| {
let poly = match target {
NodeOutputType::OutputLayer(node_id) => witness.node_witnesses[*node_id]
.output_layer_witness_ref()
.instances
.as_slice()
.original_mle(),
NodeOutputType::WireOut(node_id, wit_id) => witness.node_witnesses[*node_id]
.witness_out_ref()[*wit_id as usize]
.instances
.as_slice()
.original_mle(),
NodeOutputType::OutputLayer(node_id) => {
witness.node_witnesses[*node_id].output_layer_witness_ref()
}
NodeOutputType::WireOut(node_id, wit_id) => {
&witness.node_witnesses[*node_id].witness_out_ref()[*wit_id as usize]
}
};
// println!("target: {:?}, poly.num_vars: {:?}", target, poly.num_vars);
let p = point[..poly.num_vars].to_vec();
let p = point[..poly.num_vars()].to_vec();
PointAndEval::new_from_ref(&p, &poly.evaluate(&p))
})
.collect_vec();
Expand Down
76 changes: 28 additions & 48 deletions gkr-graph/src/circuit_graph_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ use std::{collections::BTreeSet, sync::Arc};

use ark_std::Zero;
use ff_ext::ExtensionField;
use gkr::structs::{Circuit, CircuitWitness, LayerWitness};
use gkr::structs::{Circuit, CircuitWitness};
use itertools::{chain, izip, Itertools};
use multilinear_extensions::{
mle::DenseMultilinearExtension, virtual_poly_v2::ArcMultilinearExtension,
};
use simple_frontend::structs::WitnessId;

use crate::{
Expand All @@ -14,7 +17,7 @@ use crate::{
},
};

impl<E: ExtensionField> CircuitGraphBuilder<E> {
impl<'a, E: ExtensionField> CircuitGraphBuilder<'a, E> {
pub fn new() -> Self {
Self {
graph: Default::default(),
Expand All @@ -32,7 +35,7 @@ impl<E: ExtensionField> CircuitGraphBuilder<E> {
circuit: &Arc<Circuit<E>>,
preds: Vec<PredType>,
challenges: Vec<E>,
sources: Vec<LayerWitness<E::BaseField>>,
sources: Vec<DenseMultilinearExtension<E>>,
num_instances: usize,
) -> Result<usize, GKRGraphError> {
let id = self.graph.nodes.len();
Expand All @@ -45,82 +48,61 @@ impl<E: ExtensionField> CircuitGraphBuilder<E> {
assert!(num_instances.is_power_of_two());
assert_eq!(sources.len(), circuit.n_witness_in);
assert!(
!sources.iter().any(
|source| source.instances.len() != 0 && source.instances.len() != num_instances
),
sources
.iter()
.all(|source| source.evaluations.len() % num_instances == 0),
"node_id: {}, num_instances: {}, sources_num_instances: {:?}",
id,
num_instances,
sources
.iter()
.map(|source| source.instances.len())
.map(|source| source.evaluations.len())
.collect_vec()
);

let mut witness = CircuitWitness::new(circuit, challenges);
let wits_in = izip!(preds.iter(), sources.into_iter())
.map(|(pred, source)| match pred {
PredType::Source => source,
PredType::Source => source.into(),
PredType::PredWire(out) | PredType::PredWireDup(out) => {
let (id, out) = &match out {
let (id, out) = match out {
NodeOutputType::OutputLayer(id) => (
*id,
&self.witness.node_witnesses[*id]
self.witness.node_witnesses[*id]
.output_layer_witness_ref()
.instances,
.clone(),
),
NodeOutputType::WireOut(id, wit_id) => (
*id,
&self.witness.node_witnesses[*id].witness_out_ref()[*wit_id as usize]
.instances,
self.witness.node_witnesses[*id].witness_out_ref()[*wit_id as usize]
.clone(),
),
};
let old_num_instances = self.witness.node_witnesses[*id].n_instances();
// TODO find way to avoid expensive clone for wit_in
let new_instances = match pred {
PredType::PredWire(_) => {
let new_size = (old_num_instances * out[0].len()) / num_instances;
out.iter()
.cloned()
.flatten()
.chunks(new_size)
.into_iter()
.map(|c| c.collect_vec())
.collect_vec()
}
let old_num_instances = self.witness.node_witnesses[id].n_instances();
let new_instances: ArcMultilinearExtension<'a, E> = match pred {
PredType::PredWire(_) => out,
PredType::PredWireDup(_) => {
let num_dups = num_instances / old_num_instances;
let old_size = out[0].len();
out.iter()
.cloned()
.flat_map(|single_instance| {
single_instance
.into_iter()
.cycle()
.take(num_dups * old_size)
})
.chunks(old_size)
.into_iter()
.map(|c| c.collect_vec())
.collect_vec()
let new: ArcMultilinearExtension<E> =
out.dup(old_num_instances, num_dups).into();
new
}
_ => unreachable!(),
};
LayerWitness {
instances: new_instances,
}
new_instances
}
})
.collect_vec();
witness.add_instances(circuit, wits_in, num_instances);

witness.set_instances(circuit, wits_in, num_instances);
self.witness.node_witnesses.push(Arc::new(witness));

self.graph.nodes.push(CircuitNode {
id,
label,
circuit: circuit.clone(),
preds,
});
self.witness.node_witnesses.push(witness);

Ok(id)
}
Expand All @@ -146,9 +128,7 @@ impl<E: ExtensionField> CircuitGraphBuilder<E> {
}

/// Collect the information of `self.sources` and `self.targets`.
pub fn finalize_graph_and_witness(
mut self,
) -> (CircuitGraph<E>, CircuitGraphWitness<E::BaseField>) {
pub fn finalize_graph_and_witness(mut self) -> (CircuitGraph<E>, CircuitGraphWitness<'a, E>) {
// Generate all possible graph output
let outs = self
.graph
Expand Down Expand Up @@ -203,7 +183,7 @@ impl<E: ExtensionField> CircuitGraphBuilder<E> {
pub fn finalize_graph_and_witness_with_targets(
mut self,
targets: &[NodeOutputType],
) -> (CircuitGraph<E>, CircuitGraphWitness<E::BaseField>) {
) -> (CircuitGraph<E>, CircuitGraphWitness<'a, E>) {
// Generate all possible graph output
let outs = self
.graph
Expand Down
27 changes: 11 additions & 16 deletions gkr-graph/src/prover.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
use ff_ext::ExtensionField;
use gkr::{structs::PointAndEval, utils::MultilinearExtensionFromVectors};
use itertools::{izip, Itertools};
use std::mem;
use transcript::Transcript;

use crate::{
error::GKRGraphError,
structs::{
CircuitGraph, CircuitGraphWitness, GKRProverState, IOPProof, IOPProverState,
NodeOutputType, PredType, TargetEvaluations,
},
};
use ff_ext::ExtensionField;
use gkr::structs::PointAndEval;
use itertools::{izip, Itertools};
use std::mem;
use transcript::Transcript;

impl<E: ExtensionField> IOPProverState<E> {
pub fn prove(
circuit: &CircuitGraph<E>,
circuit_witness: &CircuitGraphWitness<E::BaseField>,
circuit_witness: &CircuitGraphWitness<E>,
target_evals: &TargetEvaluations<E>,
transcript: &mut Transcript<E>,
expected_max_thread_id: usize,
Expand All @@ -31,7 +30,9 @@ impl<E: ExtensionField> IOPProverState<E> {
.collect_vec();
izip!(&circuit.targets, &target_evals.0).for_each(|(target, eval)| match target {
NodeOutputType::OutputLayer(id) => output_evals[*id].push(eval.clone()),
NodeOutputType::WireOut(id, _) => wit_out_evals[*id].push(eval.clone()),
NodeOutputType::WireOut(id, wire_out_id) => {
wit_out_evals[*id][*wire_out_id as usize] = eval.clone()
}
});

let gkr_proofs = izip!(&circuit.nodes, &circuit_witness.node_witnesses)
Expand Down Expand Up @@ -61,10 +62,7 @@ impl<E: ExtensionField> IOPProverState<E> {
// }

for (witness_id, point_and_eval) in wit_out_evals[node.id].iter().enumerate() {
let mle = witness.witness_out_ref()[witness_id]
.instances
.as_slice()
.original_mle();
let mle = &witness.witness_out_ref()[witness_id];
debug_assert_eq!(
mle.evaluate(&point_and_eval.point),
point_and_eval.eval,
Expand Down Expand Up @@ -96,10 +94,7 @@ impl<E: ExtensionField> IOPProverState<E> {
PredType::Source => {
// sanity check for input poly evaluation
if cfg!(debug_assertions) {
let input_layer_poly = witness.witness_in_ref()[wire_id]
.instances
.as_slice()
.original_mle();
let input_layer_poly = &witness.witness_in_ref()[wire_id];
debug_assert_eq!(
input_layer_poly.evaluate(&point_and_eval.point),
point_and_eval.eval,
Expand Down
11 changes: 5 additions & 6 deletions gkr-graph/src/structs.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use ff_ext::ExtensionField;
use gkr::structs::{Circuit, CircuitWitness, PointAndEval};
use goldilocks::SmallField;
use simple_frontend::structs::WitnessId;
use std::{marker::PhantomData, sync::Arc};

Expand Down Expand Up @@ -60,13 +59,13 @@ pub struct CircuitGraph<E: ExtensionField> {
}

#[derive(Default)]
pub struct CircuitGraphWitness<F: SmallField> {
pub node_witnesses: Vec<CircuitWitness<F>>,
pub struct CircuitGraphWitness<'a, E: ExtensionField> {
pub node_witnesses: Vec<Arc<CircuitWitness<'a, E>>>,
}

pub struct CircuitGraphBuilder<E: ExtensionField> {
pub struct CircuitGraphBuilder<'a, E: ExtensionField> {
pub(crate) graph: CircuitGraph<E>,
pub(crate) witness: CircuitGraphWitness<E::BaseField>,
pub(crate) witness: CircuitGraphWitness<'a, E>,
}

#[derive(Clone, Debug, Default)]
Expand All @@ -75,5 +74,5 @@ pub struct CircuitGraphAuxInfo {
}

/// Evaluations corresponds to the circuit targets.
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct TargetEvaluations<F>(pub Vec<PointAndEval<F>>);
4 changes: 3 additions & 1 deletion gkr-graph/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ impl<E: ExtensionField> IOPVerifierState<E> {
.collect_vec();
izip!(&circuit.targets, &target_evals.0).for_each(|(target, eval)| match target {
NodeOutputType::OutputLayer(id) => output_evals[*id].push(eval.clone()),
NodeOutputType::WireOut(id, _) => wit_out_evals[*id].push(eval.clone()),
NodeOutputType::WireOut(id, wire_out_id) => {
wit_out_evals[*id][*wire_out_id as usize] = eval.clone()
}
});

for ((node, instance_num_vars), proof) in izip!(
Expand Down
Loading

0 comments on commit 4226a3e

Please sign in to comment.