Skip to content

Commit

Permalink
use tracer to generate step records for riscv_add example
Browse files Browse the repository at this point in the history
  • Loading branch information
kunxian-xia committed Sep 9, 2024
1 parent 865e0cc commit fb1174d
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 34 deletions.
77 changes: 63 additions & 14 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@ use ceno_zkvm::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::{riscv::addsub::AddInstruction, Instruction},
scheme::prover::ZKVMProver,
UIntValue,
};
use const_env::from_env;

use ceno_emul::StepRecord;
use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM};
use ceno_zkvm::{
circuit_builder::ZKVMConstraintSystem,
scheme::verifier::ZKVMVerifier,
tables::{RangeTableCircuit, TableCircuit},
};
use ff_ext::ff::Field;
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use sumcheck::util::is_power_of_2;
use tracing_flame::FlameLayer;
use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry};
Expand All @@ -24,6 +26,20 @@ use transcript::Transcript;
#[from_env]
const RAYON_NUM_THREADS: usize = 8;

// For now, we assume registers
// - x0 is not touched,
// - x1 is initialized to 1,
// - x2 is initialized to -1,
// - x3 is initialized to loop bound.
// we use x4 to hold the acc_sum.
const PROGRAM_ADD_LOOP: [u32; 4] = [
// func7 rs2 rs1 f3 rd opcode
0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1
0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1
0b_1_111111_00000_00011_001_1100_1_1100011, // bne x3, x0, -8
0b_000000000000_00000_000_00000_1110011, // ecall halt
];

fn main() {
type E = GoldilocksExt2;

Expand Down Expand Up @@ -89,30 +105,63 @@ fn main() {
);
(cs, config)
};
let pk = zkvm_cs.key_gen(zkvm_fixed_traces);
let pk = zkvm_cs.key_gen(zkvm_fixed_traces).expect("keygen failed");
let vk = pk.get_vk();

// proving
let prover = ZKVMProver::new(pk);
let verifier = ZKVMVerifier::new(vk);

for instance_num_vars in 15..22 {
// TODO: witness generation from step records emitted by tracer
for instance_num_vars in 8..22 {
let num_instances = 1 << instance_num_vars;
let mut vm = VMState::new(CENO_PLATFORM);
let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr();

// init vm.x1 = 1, vm.x2 = -1, vm.x3 = num_instances
// vm.x4 += vm.x1
vm.init_register_unsafe(1usize, 1);
vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement
vm.init_register_unsafe(3usize, num_instances as u32);
for (i, inst) in PROGRAM_ADD_LOOP.iter().enumerate() {
vm.init_memory(pc_start + i, *inst);
}
let records = vm
.iter_until_success()
.collect::<Result<Vec<StepRecord>, _>>()
.expect("vm exec failed")
.into_iter()
.filter(|record| record.insn().kind == ADD)
.collect::<Vec<_>>();
tracing::info!("tracer generated {} ADD records", records.len());

// TODO: generate range check inputs from opcode_circuit::assign_instances()
let rc_inputs = records
.iter()
.flat_map(|record| {
let rs1 = UIntValue::new(record.rs1().unwrap().value);
let rs2 = UIntValue::new(record.rs2().unwrap().value);

let rd_prev = UIntValue::new(record.rd().unwrap().value.before);
let rd = UIntValue::new(record.rd().unwrap().value.after);
let carries = rs1
.add_u16_carries(&rs2)
.into_iter()
.map(|c| c as u16)
.collect_vec();

[rd_prev.limbs, rd.limbs, carries].concat()
})
.map(|x| x as usize)
.collect::<Vec<_>>();

let mut zkvm_witness = BTreeMap::default();
let add_witness = AddInstruction::assign_instances(
&add_config,
add_cs.num_witin as usize,
vec![StepRecord::default(); num_instances],
)
.unwrap();
let add_witness =
AddInstruction::assign_instances(&add_config, add_cs.num_witin as usize, records)
.unwrap();
let range_witness = RangeTableCircuit::<E>::assign_instances(
&range_config,
range_cs.num_witin as usize,
// TODO: use real data
vec![vec![0; num_instances * 2], vec![4; num_instances * 6]]
.concat()
.as_slice(),
&rc_inputs,
)
.unwrap();

Expand Down
9 changes: 2 additions & 7 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,8 @@ impl<E: ExtensionField> ConstraintSystem<E> {
// TODO: commit to fixed_traces

// transpose from row-major to column-major
let fixed_traces = fixed_traces.map(|t| {
t.de_interleaving()
.into_mles()
.into_iter()
.map(|v| v.into())
.collect_vec()
});
let fixed_traces =
fixed_traces.map(|t| t.de_interleaving().into_mles().into_iter().collect_vec());

ProvingKey {
fixed_traces,
Expand Down
3 changes: 3 additions & 0 deletions ceno_zkvm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ pub enum UtilError {
pub enum ZKVMError {
CircuitError,
UtilError(UtilError),
WitnessNotFound(String),
VKNotFound(String),
FixedTraceNotFound(String),
VerifyError(String),
}

Expand Down
5 changes: 4 additions & 1 deletion ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,11 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction<E> {
set_val!(instance, config.ts, 2);
let addend_0 = UIntValue::new(step.rs1().unwrap().value);
let addend_1 = UIntValue::new(step.rs2().unwrap().value);
let outcome = UIntValue::new(step.rd().unwrap().value.after);
let rd_prev = UIntValue::new(step.rd().unwrap().value.before);
config
.prev_rd_value
.assign_limbs(instance, [0, 0].iter().map(E::BaseField::from).collect());
.assign_limbs(instance, rd_prev.u16_fields());
config
.addend_0
.assign_limbs(instance, addend_0.u16_fields());
Expand All @@ -178,6 +180,7 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction<E> {
.map(|carry| E::BaseField::from(carry as u64))
.collect_vec(),
);
config.outcome.assign_limbs(instance, outcome.u16_fields());
// TODO #167
set_val!(instance, config.rs1_id, 2);
set_val!(instance, config.rs2_id, 2);
Expand Down
15 changes: 6 additions & 9 deletions ceno_zkvm/src/keygen.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{
circuit_builder::{ZKVMConstraintSystem, ZKVMProvingKey},
error::ZKVMError,
witness::RowMajorMatrix,
};
use ff_ext::ExtensionField;
Expand All @@ -9,22 +10,18 @@ impl<E: ExtensionField> ZKVMConstraintSystem<E> {
pub fn key_gen(
self,
mut vm_fixed_traces: BTreeMap<String, Option<RowMajorMatrix<E::BaseField>>>,
) -> ZKVMProvingKey<E> {
) -> Result<ZKVMProvingKey<E>, ZKVMError> {
let mut vm_pk = ZKVMProvingKey::default();

for (c_name, cs) in self.circuit_css.into_iter() {
let fixed_traces = vm_fixed_traces.remove(&c_name).expect(
format!(
"circuit {}'s trace is not present in vm_fixed_traces",
c_name
)
.as_str(),
);
let fixed_traces = vm_fixed_traces
.remove(&c_name)
.ok_or(ZKVMError::FixedTraceNotFound(c_name.clone()))?;

let circuit_pk = cs.key_gen(fixed_traces);
assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none());
}

vm_pk
Ok(vm_pk)
}
}
2 changes: 2 additions & 0 deletions ceno_zkvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ mod uint;
mod utils;
mod virtual_polys;
mod witness;

pub use uint::UIntValue;
2 changes: 1 addition & 1 deletion ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl<E: ExtensionField> ZKVMProver<E> {
for (circuit_name, pk) in self.pk.circuit_pks.iter() {
let witness = witnesses
.remove(circuit_name)
.expect(format!("witness for circuit {} is not found", circuit_name).as_str());
.ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?;

// TODO: add an enum for circuit type either in constraint_system or vk
let cs = pk.get_cs();
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
.vk
.circuit_vks
.get(&name)
.expect(format!("vk of opcode circuit {} is not present", name).as_str());
.ok_or(ZKVMError::VKNotFound(name.clone()))?;
let _rand_point = self.verify_opcode_proof(
circuit_vk,
&opcode_proof,
Expand Down Expand Up @@ -88,7 +88,7 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
.vk
.circuit_vks
.get(&name)
.expect(format!("vk of table circuit {} is not present", name).as_str());
.ok_or(ZKVMError::VKNotFound(name.clone()))?;
let _rand_point = self.verify_table_proof(
circuit_vk,
&table_proof,
Expand Down

0 comments on commit fb1174d

Please sign in to comment.