diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 99b764364..3c37b117a 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -1,6 +1,8 @@ -use crate::{Change, EmuContext, Platform, VMState, WORD_SIZE, WordAddr, WriteOp}; +use crate::{ + Change, EmuContext, Platform, RegIdx, Tracer, VMState, WORD_SIZE, Word, WordAddr, WriteOp, +}; use anyhow::Result; -use itertools::{Itertools, izip}; +use itertools::{Itertools, chain, izip}; use tiny_keccak::keccakf; /// A syscall event, available to the circuit witness generators. @@ -13,10 +15,39 @@ pub struct SyscallWitness { /// The effects of a syscall to apply on the VM. #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct SyscallEffects { - pub witness: SyscallWitness, + /// The witness being built. Get it with `finalize`. + witness: SyscallWitness, + + /// The next PC after the syscall. Defaults to the next instruction. pub next_pc: Option, } +impl SyscallEffects { + /// Iterate over the register values after the syscall. + pub fn iter_reg_values(&self) -> impl Iterator + '_ { + self.witness + .reg_accesses + .iter() + .map(|op| (op.register_index(), op.value.after)) + } + + /// Iterate over the memory values after the syscall. + pub fn iter_mem_values(&self) -> impl Iterator + '_ { + self.witness + .mem_writes + .iter() + .map(|op| (op.addr, op.value.after)) + } + + /// Keep track of the cycles of registers and memory accesses. + pub fn finalize(mut self, tracer: &mut Tracer) -> SyscallWitness { + for op in chain(&mut self.witness.reg_accesses, &mut self.witness.mem_writes) { + op.previous_cycle = tracer.track_access(op.addr, 0); + } + self.witness + } +} + pub const KECCAK_PERMUTE: u32 = 0x00_01_01_09; /// Trace the inputs and effects of a syscall. diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 88ef444cc..fd9a12cf0 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -1,7 +1,5 @@ use std::{collections::HashMap, fmt, mem}; -use itertools::chain; - use crate::{ CENO_PLATFORM, InsnKind, PC_STEP_SIZE, Platform, addr::{ByteAddr, Cycle, RegIdx, Word, WordAddr}, @@ -405,29 +403,18 @@ impl Tracer { }); } - pub fn track_syscall(&mut self, mut effects: SyscallEffects) { - // Keep track of the cycles of registers and memory accesses. - for op in chain( - &mut effects.witness.reg_accesses, - &mut effects.witness.mem_writes, - ) { - op.previous_cycle = self.track_access(op.addr, 0); - assert_ne!( - op.previous_cycle, self.record.cycle, - "Address {:?} was accessed twice in the same cycle", - op.addr - ); - } + pub fn track_syscall(&mut self, effects: SyscallEffects) { + let witness = effects.finalize(self); assert!(self.record.syscall.is_none(), "Only one syscall per step"); - self.record.syscall = Some(effects.witness); + self.record.syscall = Some(witness); } /// - Return the cycle when an address was last accessed. /// - Return 0 if this is the first access. /// - Record the current instruction as the origin of the latest access. /// - Accesses within the same instruction are distinguished by `subcycle ∈ [0, 3]`. - fn track_access(&mut self, addr: WordAddr, subcycle: Cycle) -> Cycle { + pub fn track_access(&mut self, addr: WordAddr, subcycle: Cycle) -> Cycle { self.latest_accesses .insert(addr, self.record.cycle + subcycle) .unwrap_or(0) diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 0eba5fe9c..536d77d3e 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -108,12 +108,12 @@ impl VMState { } fn apply_syscall(&mut self, effects: SyscallEffects) -> Result<()> { - for write_op in &effects.witness.mem_writes { - self.memory.insert(write_op.addr, write_op.value.after); + for (addr, value) in effects.iter_mem_values() { + self.memory.insert(addr, value); } - for reg_access in &effects.witness.reg_accesses { - self.registers[reg_access.register_index()] = reg_access.value.after; + for (idx, value) in effects.iter_reg_values() { + self.registers[idx] = value; } let next_pc = effects.next_pc.unwrap_or(self.pc + PC_STEP_SIZE as u32);