Skip to content

Commit

Permalink
pass example test
Browse files Browse the repository at this point in the history
  • Loading branch information
kunxian-xia committed Sep 9, 2024
1 parent 88f8761 commit e598e98
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 22 deletions.
16 changes: 9 additions & 7 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ use const_env::from_env;

use ceno_emul::StepRecord;
use ceno_zkvm::{
circuit_builder::{ZKVMConstraintSystem, ZKVMVerifyingKey},
circuit_builder::ZKVMConstraintSystem,
scheme::verifier::ZKVMVerifier,
tables::{RangeTableCircuit, TableCircuit},
};
use ff_ext::ff::Field;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLE;
use goldilocks::GoldilocksExt2;
use sumcheck::util::is_power_of_2;
use tracing_flame::FlameLayer;
use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry};
Expand Down Expand Up @@ -98,7 +96,7 @@ fn main() {
let prover = ZKVMProver::new(pk);
let verifier = ZKVMVerifier::new(vk);

for instance_num_vars in 20..22 {
for instance_num_vars in 15..22 {
// TODO: witness generation from step records emitted by tracer
let num_instances = 1 << instance_num_vars;
let mut zkvm_witness = BTreeMap::default();
Expand All @@ -111,7 +109,10 @@ fn main() {
let range_witness = RangeTableCircuit::<E>::assign_instances(
&range_config,
range_cs.num_witin as usize,
&[],
// TODO: use real data
vec![vec![0; num_instances * 2], vec![4; num_instances * 6]]
.concat()
.as_slice(),
)
.unwrap();

Expand All @@ -128,9 +129,10 @@ fn main() {
.create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges)
.expect("create_proof failed");

let mut transcript = Transcript::new(b"riscv");
assert!(
verifier
.verify_proof(zkvm_proof, &mut transcript, &real_challenges,)
.verify_proof(zkvm_proof, &mut transcript, &real_challenges)
.expect("verify proof return with error"),
);

Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub enum UtilError {
pub enum ZKVMError {
CircuitError,
UtilError(UtilError),
VerifyError(&'static str),
VerifyError(String),
}

impl From<UtilError> for ZKVMError {
Expand Down
23 changes: 22 additions & 1 deletion ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ff_ext::ExtensionField;
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
collections::{BTreeMap, BTreeSet},
sync::Arc,
};

Expand Down Expand Up @@ -65,6 +65,17 @@ impl<E: ExtensionField> ZKVMProver<E> {
let num_instances = witness.num_instances();

if is_opcode_circuit {
tracing::debug!(
"opcode circuit {} has {} witnesses, {} reads, {} writes, {} lookups",
circuit_name,
cs.num_witin,
cs.r_expressions.len(),
cs.w_expressions.len(),
cs.lk_expressions.len(),
);
for lk_s in cs.lk_expressions_namespace_map.iter() {
tracing::debug!("opcode circuit {}: {}", circuit_name, lk_s);
}
let opcode_proof = self.create_opcode_proof(
pk,
witness
Expand All @@ -78,6 +89,11 @@ impl<E: ExtensionField> ZKVMProver<E> {
transcript,
challenges,
)?;
tracing::info!(
"generated proof for opcode {} with num_instances={}",
circuit_name,
num_instances
);
vm_proof
.opcode_proofs
.insert(circuit_name.clone(), opcode_proof);
Expand All @@ -95,6 +111,11 @@ impl<E: ExtensionField> ZKVMProver<E> {
transcript,
challenges,
)?;
tracing::info!(
"generated proof for table {} with num_instances={}",
circuit_name,
num_instances
);
vm_proof
.table_proofs
.insert(circuit_name.clone(), table_proof);
Expand Down
51 changes: 38 additions & 13 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
let mut prod_r = E::ONE;
let mut prod_w = E::ONE;
let mut logup_sum = E::ZERO;
let dummy_table_item = challenges[0];
let point_eval = PointAndEval::default();
let mut dummy_table_item_multiplicity = 0;
for (name, opcode_proof) in vm_proof.opcode_proofs {
let circuit_vk = self
.vk
Expand All @@ -61,6 +63,16 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
&point_eval,
challenges,
)?;
tracing::info!("verified proof for opcode {}", name);

// getting the number of dummy padding item that we used in this opcode circuit
let num_lks = circuit_vk.get_cs().lk_expressions.len();
let num_padded_lks_per_instance = num_lks.next_power_of_two() - num_lks;
let num_padded_instance =
opcode_proof.num_instances.next_power_of_two() - opcode_proof.num_instances;
dummy_table_item_multiplicity += num_padded_lks_per_instance
* opcode_proof.num_instances
+ num_lks.next_power_of_two() * num_padded_instance;

prod_r *= opcode_proof.record_r_out_evals.iter().product::<E>();
prod_w *= opcode_proof.record_w_out_evals.iter().product::<E>();
Expand All @@ -85,18 +97,26 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
&point_eval,
challenges,
)?;
tracing::info!("verified proof for table {}", name);

logup_sum -= table_proof.lk_p1_out_eval * table_proof.lk_q1_out_eval.invert().unwrap();
logup_sum -= table_proof.lk_p2_out_eval * table_proof.lk_q2_out_eval.invert().unwrap();
}
logup_sum -=
E::from(dummy_table_item_multiplicity as u64) * dummy_table_item.invert().unwrap();

// check rw_set equality across all proofs
if prod_r != prod_w {
return Ok(false);
}
// TODO: enable this when we have cpu init/finalize and mem init/finalize
// if prod_r != prod_w {
// return Err(ZKVMError::VerifyError("prod_r != prod_w".into()));
// }

// check logup relation across all proofs
if logup_sum != E::ZERO {
return Ok(false);
return Err(ZKVMError::VerifyError(format!(
"logup_sum({:?}) != 0",
logup_sum
)));
}

Ok(true)
Expand Down Expand Up @@ -159,7 +179,7 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
// index 0 is LogUp witness for Fixed Lookup table
if logup_p_evals[0].eval != E::ONE {
return Err(ZKVMError::VerifyError(
"Lookup table witness p(x) != constant 1",
"Lookup table witness p(x) != constant 1".into(),
));
}

Expand Down Expand Up @@ -279,7 +299,7 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
.sum::<E>();
if computed_evals != expected_evaluation {
return Err(ZKVMError::VerifyError(
"main + sel evaluation verify failed",
"main + sel evaluation verify failed".into(),
));
}
// verify records (degree = 1) statement, thus no sumcheck
Expand All @@ -298,7 +318,9 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
eval_by_expr(&proof.wits_in_evals, challenges, expr) != *expected_evals
})
{
return Err(ZKVMError::VerifyError("record evaluate != expected_evals"));
return Err(ZKVMError::VerifyError(
"record evaluate != expected_evals".into(),
));
}

// verify zero expression (degree = 1) statement, thus no sumcheck
Expand Down Expand Up @@ -326,7 +348,6 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
let cs = circuit_vk.get_cs();
let lk_counts_per_instance = cs.lk_table_expressions.len();
let log2_lk_count = ceil_log2(lk_counts_per_instance);
let (chip_record_alpha, _) = (challenges[0], challenges[1]);

let num_instances = proof.num_instances;
let log2_num_instances = ceil_log2(num_instances);
Expand Down Expand Up @@ -394,8 +415,7 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
* ((0..lk_counts_per_instance)
.map(|i| proof.lk_d_in_evals[i] * eq_lk[i])
.sum::<E>()
+ chip_record_alpha
* (eq_lk[lk_counts_per_instance..].iter().sum::<E>() - E::ONE)),
+ (eq_lk[lk_counts_per_instance..].iter().sum::<E>() - E::ONE)),
*alpha_lk_n
* sel_lk
* ((0..lk_counts_per_instance)
Expand All @@ -405,7 +425,9 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
.iter()
.sum::<E>();
if computed_evals != expected_evaluation {
return Err(ZKVMError::VerifyError("sel evaluation verify failed"));
return Err(ZKVMError::VerifyError(
"sel evaluation verify failed".into(),
));
}
// verify records (degree = 1) statement, thus no sumcheck
if cs
Expand All @@ -427,7 +449,9 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
) != *expected_evals
})
{
return Err(ZKVMError::VerifyError("record evaluate != expected_evals"));
return Err(ZKVMError::VerifyError(
"record evaluate != expected_evals".into(),
));
}

Ok(input_opening_point)
Expand Down Expand Up @@ -524,6 +548,7 @@ impl TowerVerify {
},
transcript,
);
tracing::debug!("verified tower proof at layer {}/{}", round + 1, expected_max_round-1);

// check expected_evaluation
let rt: Point<E> = sumcheck_claim.point.iter().map(|c| c.elements).collect();
Expand Down Expand Up @@ -555,7 +580,7 @@ impl TowerVerify {
})
.sum::<E>();
if expected_evaluation != sumcheck_claim.expected_evaluation {
return Err(ZKVMError::VerifyError("mismatch tower evaluation"));
return Err(ZKVMError::VerifyError("mismatch tower evaluation".into()));
}

// derive single eval
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/src/tables/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit<E> {
for limb in inputs {
u16_mlt[*limb] += 1;
}
tracing::debug!("u16_mult[4] = {}", u16_mlt[4]);

let mut witness = RowMajorMatrix::<E::BaseField>::new(u16_mlt.len(), num_witin);
witness
Expand Down

0 comments on commit e598e98

Please sign in to comment.