Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add e2e prover #188

Merged
merged 22 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ceno_emul/src/rv32im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ pub struct Instruction {
pub func7: u32,
}

impl Default for Instruction {
fn default() -> Self {
insn(InsnKind::INVALID, InsnCategory::Invalid, 0x00, 0x0, 0x00)
}
}

impl DecodedInstruction {
pub fn new(insn: u32) -> Self {
Self {
Expand Down
10 changes: 10 additions & 0 deletions ceno_emul/src/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
rv32im::DecodedInstruction,
CENO_PLATFORM,
};
use crate::rv32im::Instruction;

/// An instruction and its context in an execution trace. That is concrete values of registers and memory.
///
Expand All @@ -22,6 +23,7 @@ pub struct StepRecord {
pub cycle: Cycle,
pub pc: Change<ByteAddr>,
pub insn_code: Word,
pub insn: Instruction,

pub rs1: Option<ReadOp>,
pub rs2: Option<ReadOp>,
Expand Down Expand Up @@ -69,6 +71,10 @@ impl StepRecord {
DecodedInstruction::new(self.insn_code)
}

pub fn insn(&self) -> Instruction {
self.insn
}

pub fn rs1(&self) -> Option<ReadOp> {
self.rs1.clone()
}
Expand Down Expand Up @@ -141,6 +147,10 @@ impl Tracer {
self.record.insn_code = value;
}

pub fn store_insn(&mut self, insn: Instruction) {
self.record.insn = insn;
}

pub fn load_register(&mut self, idx: RegIdx, value: Word) {
let addr = CENO_PLATFORM.register_vma(idx).into();

Expand Down
8 changes: 7 additions & 1 deletion ceno_emul/src/vm_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ impl VMState {
Ok(step)
}
}

pub fn init_register_unsafe(&mut self, idx: RegIdx, value: Word) {
self.registers[idx] = value;
}
}

impl EmuContext for VMState {
Expand Down Expand Up @@ -109,7 +113,9 @@ impl EmuContext for VMState {
Err(anyhow!("Trap {:?}", cause)) // Crash.
}

fn on_insn_decoded(&mut self, _kind: &Instruction, _decoded: &DecodedInstruction) {}
fn on_insn_decoded(&mut self, insn: &Instruction, _decoded: &DecodedInstruction) {
self.tracer.store_insn(*insn);
}

fn on_normal_end(&mut self, _kind: &Instruction, _decoded: &DecodedInstruction) {
self.tracer.store_pc(ByteAddr(self.pc));
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ fn bench_add(c: &mut Criterion) {
.collect_vec();
let timer = Instant::now();
let _ = prover
.create_proof(
.create_opcode_proof(
wits_in,
num_instances,
max_threads,
Expand Down
130 changes: 91 additions & 39 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use std::time::Instant;

use ark_std::test_rng;
use ceno_zkvm::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::{riscv::addsub::AddInstruction, Instruction},
scheme::prover::ZKVMProver,
};
use ceno_zkvm::{instructions::riscv::addsub::AddInstruction, scheme::prover::ZKVMProver};
use const_env::from_env;

use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM};
use ceno_zkvm::{
scheme::verifier::ZKVMVerifier,
structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
tables::RangeTableCircuit,
};
use ff_ext::ff::Field;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLE;
use goldilocks::GoldilocksExt2;
use sumcheck::util::is_power_of_2;
use tracing_flame::FlameLayer;
use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry};
Expand All @@ -20,7 +20,23 @@ use transcript::Transcript;
#[from_env]
const RAYON_NUM_THREADS: usize = 8;

// For now, we assume registers
// - x0 is not touched,
// - x1 is initialized to 1,
// - x2 is initialized to -1,
// - x3 is initialized to loop bound.
// we use x4 to hold the acc_sum.
const PROGRAM_ADD_LOOP: [u32; 4] = [
// func7 rs2 rs1 f3 rd opcode
0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1
0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1
0b_1_111111_00000_00011_001_1100_1_1100011, // bne x3, x0, -8
0b_000000000000_00000_000_00000_1110011, // ecall halt
];

fn main() {
type E = GoldilocksExt2;

let max_threads = {
if !is_power_of_2(RAYON_NUM_THREADS) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
Expand All @@ -41,16 +57,6 @@ fn main() {
RAYON_NUM_THREADS
}
};
let mut cs = ConstraintSystem::new(|| "risv_add");
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);
let _ = AddInstruction::construct_circuit(&mut circuit_builder);
let pk = cs.key_gen(None);
let num_witin = pk.get_cs().num_witin;

let prover = ZKVMProver::new(pk);
let mut transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap();
let subscriber = Registry::default()
Expand All @@ -64,34 +70,80 @@ fn main() {
.with(flame_layer.with_threads_collapsed(true));
tracing::subscriber::set_global_default(subscriber).unwrap();

for instance_num_vars in 20..22 {
// generate mock witness
// keygen
let mut zkvm_cs = ZKVMConstraintSystem::default();
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
let range_config = zkvm_cs.register_table_circuit::<RangeTableCircuit<E>>();

let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);
zkvm_fixed_traces
.register_table_circuit::<RangeTableCircuit<E>>(&zkvm_cs, range_config.clone());

let pk = zkvm_cs
.clone()
.key_gen(zkvm_fixed_traces)
.expect("keygen failed");
let vk = pk.get_vk();

// proving
let prover = ZKVMProver::new(pk);
let verifier = ZKVMVerifier::new(vk);

for instance_num_vars in 8..22 {
let num_instances = 1 << instance_num_vars;
let wits_in = (0..num_witin as usize)
.map(|_| {
(0..num_instances)
.map(|_| Goldilocks::random(&mut rng))
.collect::<Vec<Goldilocks>>()
.into_mle()
.into()
})
.collect_vec();
let mut vm = VMState::new(CENO_PLATFORM);
let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr();

// init vm.x1 = 1, vm.x2 = -1, vm.x3 = num_instances
// vm.x4 += vm.x1
vm.init_register_unsafe(1usize, 1);
vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement
vm.init_register_unsafe(3usize, num_instances as u32);
for (i, inst) in PROGRAM_ADD_LOOP.iter().enumerate() {
vm.init_memory(pc_start + i, *inst);
}
let records = vm
.iter_until_success()
.collect::<Result<Vec<StepRecord>, _>>()
.expect("vm exec failed")
.into_iter()
.filter(|record| record.insn().kind == ADD)
.collect::<Vec<_>>();
tracing::info!("tracer generated {} ADD records", records.len());

let mut zkvm_witness = ZKVMWitnesses::default();
// assign opcode circuits
zkvm_witness
.assign_opcode_circuit::<AddInstruction<E>>(&zkvm_cs, &add_config, records)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Future TODO)
Just an further refine of api , if we make instruction config implementing a trait, then we can also keep xxx_config in zkvm_cs. With that, the api can be further simplified to

zkvm_cs.register_opcode_circuit::<AddInstruction<E>>(); // no return value, config in maintained in zkvm_cs
...
zkvm_witness.assign_opcode_circuit::<AddInstruction<E>>(&zkvm_cs, records)
...

.unwrap();
zkvm_witness.finalize_lk_multiplicities();
// assign table circuits
zkvm_witness
.assign_table_circuit::<RangeTableCircuit<E>>(&zkvm_cs, &range_config)
.unwrap();

let timer = Instant::now();
let _ = prover
.create_proof(
wits_in,
num_instances,
max_threads,
&mut transcript,
&real_challenges,
)

let mut transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let zkvm_proof = prover
.create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges)
.expect("create_proof failed");

let mut transcript = Transcript::new(b"riscv");
assert!(
verifier
.verify_proof(zkvm_proof, &mut transcript, &real_challenges)
.expect("verify proof return with error"),
);

println!(
"AddInstruction::create_proof, instance_num_vars = {}, time = {}",
instance_num_vars,
timer.elapsed().as_secs_f64()
);
}

type E = GoldilocksExt2;
}
40 changes: 11 additions & 29 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use itertools::Itertools;
use std::marker::PhantomData;

use ff_ext::ExtensionField;
use multilinear_extensions::mle::DenseMultilinearExtension;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
error::ZKVMError,
expression::{Expression, Fixed, WitIn},
structs::WitnessId,
structs::{ProvingKey, VerifyingKey, WitnessId},
witness::RowMajorMatrix,
};

/// namespace used for annotation, preserve meta info during circuit construction
Expand Down Expand Up @@ -135,7 +137,13 @@ impl<E: ExtensionField> ConstraintSystem<E> {
}
}

pub fn key_gen(self, fixed_traces: Option<Vec<DenseMultilinearExtension<E>>>) -> ProvingKey<E> {
pub fn key_gen(self, fixed_traces: Option<RowMajorMatrix<E::BaseField>>) -> ProvingKey<E> {
// TODO: commit to fixed_traces

// transpose from row-major to column-major
let fixed_traces =
fixed_traces.map(|t| t.de_interleaving().into_mles().into_iter().collect_vec());

ProvingKey {
fixed_traces,
vk: VerifyingKey { cs: self },
Expand Down Expand Up @@ -293,29 +301,3 @@ impl<E: ExtensionField> ConstraintSystem<E> {
pub struct CircuitBuilder<'a, E: ExtensionField> {
pub(crate) cs: &'a mut ConstraintSystem<E>,
}

#[derive(Clone, Debug)]
pub struct ProvingKey<E: ExtensionField> {
pub fixed_traces: Option<Vec<DenseMultilinearExtension<E>>>,
pub vk: VerifyingKey<E>,
}

impl<E: ExtensionField> ProvingKey<E> {
// pub fn create_pk(vk: VerifyingKey<E>) -> Self {
// Self { vk }
// }
pub fn get_cs(&self) -> &ConstraintSystem<E> {
self.vk.get_cs()
}
}

#[derive(Clone, Debug)]
pub struct VerifyingKey<E: ExtensionField> {
cs: ConstraintSystem<E>,
}

impl<E: ExtensionField> VerifyingKey<E> {
pub fn get_cs(&self) -> &ConstraintSystem<E> {
&self.cs
}
}
5 changes: 4 additions & 1 deletion ceno_zkvm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ pub enum UtilError {
pub enum ZKVMError {
CircuitError,
UtilError(UtilError),
VerifyError(&'static str),
WitnessNotFound(String),
VKNotFound(String),
FixedTraceNotFound(String),
VerifyError(String),
}

impl From<UtilError> for ZKVMError {
Expand Down
2 changes: 2 additions & 0 deletions ceno_zkvm/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub mod riscv;

pub trait Instruction<E: ExtensionField> {
type InstructionConfig: Send + Sync;

fn name() -> String;
fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError>;
Expand Down
Loading
Loading