diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index bb2dbacab..729e2f430 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -5,10 +5,11 @@ use ceno_zkvm::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::{riscv::addsub::AddInstruction, Instruction}, scheme::prover::ZKVMProver, + UIntValue, }; use const_env::from_env; -use ceno_emul::StepRecord; +use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM}; use ceno_zkvm::{ circuit_builder::ZKVMConstraintSystem, scheme::verifier::ZKVMVerifier, @@ -16,6 +17,7 @@ use ceno_zkvm::{ }; use ff_ext::ff::Field; use goldilocks::GoldilocksExt2; +use itertools::Itertools; use sumcheck::util::is_power_of_2; use tracing_flame::FlameLayer; use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry}; @@ -24,6 +26,20 @@ use transcript::Transcript; #[from_env] const RAYON_NUM_THREADS: usize = 8; +// For now, we assume registers +// - x0 is not touched, +// - x1 is initialized to 1, +// - x2 is initialized to -1, +// - x3 is initialized to loop bound. +// we use x4 to hold the acc_sum. +const PROGRAM_ADD_LOOP: [u32; 4] = [ + // func7 rs2 rs1 f3 rd opcode + 0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1 + 0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1 + 0b_1_111111_00000_00011_001_1100_1_1100011, // bne x3, x0, -8 + 0b_000000000000_00000_000_00000_1110011, // ecall halt +]; + fn main() { type E = GoldilocksExt2; @@ -89,30 +105,63 @@ fn main() { ); (cs, config) }; - let pk = zkvm_cs.key_gen(zkvm_fixed_traces); + let pk = zkvm_cs.key_gen(zkvm_fixed_traces).expect("keygen failed"); let vk = pk.get_vk(); // proving let prover = ZKVMProver::new(pk); let verifier = ZKVMVerifier::new(vk); - for instance_num_vars in 15..22 { - // TODO: witness generation from step records emitted by tracer + for instance_num_vars in 8..22 { let num_instances = 1 << instance_num_vars; + let mut vm = VMState::new(CENO_PLATFORM); + let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); + + // init vm.x1 = 1, vm.x2 = -1, vm.x3 = num_instances + // vm.x4 += vm.x1 + vm.init_register_unsafe(1usize, 1); + vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement + vm.init_register_unsafe(3usize, num_instances as u32); + for (i, inst) in PROGRAM_ADD_LOOP.iter().enumerate() { + vm.init_memory(pc_start + i, *inst); + } + let records = vm + .iter_until_success() + .collect::, _>>() + .expect("vm exec failed") + .into_iter() + .filter(|record| record.insn().kind == ADD) + .collect::>(); + tracing::info!("tracer generated {} ADD records", records.len()); + + // TODO: generate range check inputs from opcode_circuit::assign_instances() + let rc_inputs = records + .iter() + .flat_map(|record| { + let rs1 = UIntValue::new(record.rs1().unwrap().value); + let rs2 = UIntValue::new(record.rs2().unwrap().value); + + let rd_prev = UIntValue::new(record.rd().unwrap().value.before); + let rd = UIntValue::new(record.rd().unwrap().value.after); + let carries = rs1 + .add_u16_carries(&rs2) + .into_iter() + .map(|c| c as u16) + .collect_vec(); + + [rd_prev.limbs, rd.limbs, carries].concat() + }) + .map(|x| x as usize) + .collect::>(); + let mut zkvm_witness = BTreeMap::default(); - let add_witness = AddInstruction::assign_instances( - &add_config, - add_cs.num_witin as usize, - vec![StepRecord::default(); num_instances], - ) - .unwrap(); + let add_witness = + AddInstruction::assign_instances(&add_config, add_cs.num_witin as usize, records) + .unwrap(); let range_witness = RangeTableCircuit::::assign_instances( &range_config, range_cs.num_witin as usize, - // TODO: use real data - vec![vec![0; num_instances * 2], vec![4; num_instances * 6]] - .concat() - .as_slice(), + &rc_inputs, ) .unwrap(); diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 517f1c04f..be2139f05 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -141,13 +141,8 @@ impl ConstraintSystem { // TODO: commit to fixed_traces // transpose from row-major to column-major - let fixed_traces = fixed_traces.map(|t| { - t.de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec() - }); + let fixed_traces = + fixed_traces.map(|t| t.de_interleaving().into_mles().into_iter().collect_vec()); ProvingKey { fixed_traces, diff --git a/ceno_zkvm/src/error.rs b/ceno_zkvm/src/error.rs index ea59969f8..b7791def9 100644 --- a/ceno_zkvm/src/error.rs +++ b/ceno_zkvm/src/error.rs @@ -7,6 +7,9 @@ pub enum UtilError { pub enum ZKVMError { CircuitError, UtilError(UtilError), + WitnessNotFound(String), + VKNotFound(String), + FixedTraceNotFound(String), VerifyError(String), } diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 6224d016c..2b873d50d 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -161,9 +161,11 @@ impl Instruction for AddInstruction { set_val!(instance, config.ts, 2); let addend_0 = UIntValue::new(step.rs1().unwrap().value); let addend_1 = UIntValue::new(step.rs2().unwrap().value); + let outcome = UIntValue::new(step.rd().unwrap().value.after); + let rd_prev = UIntValue::new(step.rd().unwrap().value.before); config .prev_rd_value - .assign_limbs(instance, [0, 0].iter().map(E::BaseField::from).collect()); + .assign_limbs(instance, rd_prev.u16_fields()); config .addend_0 .assign_limbs(instance, addend_0.u16_fields()); @@ -178,6 +180,7 @@ impl Instruction for AddInstruction { .map(|carry| E::BaseField::from(carry as u64)) .collect_vec(), ); + config.outcome.assign_limbs(instance, outcome.u16_fields()); // TODO #167 set_val!(instance, config.rs1_id, 2); set_val!(instance, config.rs2_id, 2); diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index c475878bf..3bce91d11 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -1,5 +1,6 @@ use crate::{ circuit_builder::{ZKVMConstraintSystem, ZKVMProvingKey}, + error::ZKVMError, witness::RowMajorMatrix, }; use ff_ext::ExtensionField; @@ -9,22 +10,18 @@ impl ZKVMConstraintSystem { pub fn key_gen( self, mut vm_fixed_traces: BTreeMap>>, - ) -> ZKVMProvingKey { + ) -> Result, ZKVMError> { let mut vm_pk = ZKVMProvingKey::default(); for (c_name, cs) in self.circuit_css.into_iter() { - let fixed_traces = vm_fixed_traces.remove(&c_name).expect( - format!( - "circuit {}'s trace is not present in vm_fixed_traces", - c_name - ) - .as_str(), - ); + let fixed_traces = vm_fixed_traces + .remove(&c_name) + .ok_or(ZKVMError::FixedTraceNotFound(c_name.clone()))?; let circuit_pk = cs.key_gen(fixed_traces); assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); } - vm_pk + Ok(vm_pk) } } diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index c29f030ba..dbe64e06e 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -15,3 +15,5 @@ mod uint; mod utils; mod virtual_polys; mod witness; + +pub use uint::UIntValue; diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 685e6f1cb..6784029b0 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -57,7 +57,7 @@ impl ZKVMProver { for (circuit_name, pk) in self.pk.circuit_pks.iter() { let witness = witnesses .remove(circuit_name) - .expect(format!("witness for circuit {} is not found", circuit_name).as_str()); + .ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?; // TODO: add an enum for circuit type either in constraint_system or vk let cs = pk.get_cs(); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 92c4905f7..4f6803a24 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -54,7 +54,7 @@ impl ZKVMVerifier { .vk .circuit_vks .get(&name) - .expect(format!("vk of opcode circuit {} is not present", name).as_str()); + .ok_or(ZKVMError::VKNotFound(name.clone()))?; let _rand_point = self.verify_opcode_proof( circuit_vk, &opcode_proof, @@ -88,7 +88,7 @@ impl ZKVMVerifier { .vk .circuit_vks .get(&name) - .expect(format!("vk of table circuit {} is not present", name).as_str()); + .ok_or(ZKVMError::VKNotFound(name.clone()))?; let _rand_point = self.verify_table_proof( circuit_vk, &table_proof,