diff --git a/.clippy.toml b/.clippy.toml index 9c71707ab2..0538654263 100644 --- a/.clippy.toml +++ b/.clippy.toml @@ -5,3 +5,4 @@ disallowed-methods = [ { path = "pasta_curves::Fp", reason = "use pasta_curves::pallas::Base or pasta_curves::vesta::Scalar instead to communicate your intent" }, { path = "pasta_curves::Fq", reason = "use pasta_curves::pallas::Scalar or pasta_curves::vesta::Base instead to communicate your intent" }, ] +allow-dbg-in-tests = true diff --git a/src/circuit/gadgets/constraints.rs b/src/circuit/gadgets/constraints.rs index 98a01e12d4..9788b804b0 100644 --- a/src/circuit/gadgets/constraints.rs +++ b/src/circuit/gadgets/constraints.rs @@ -310,6 +310,28 @@ pub(crate) fn div>( Ok(res) } +pub(crate) fn invert>( + mut cs: CS, + a: &AllocatedNum, +) -> Result, SynthesisError> { + let inv = AllocatedNum::alloc(cs.namespace(|| "invert"), || { + let inv = (a.get_value().ok_or(SynthesisError::AssignmentMissing)?).invert(); + + let inv_opt: Option<_> = inv.into(); + inv_opt.ok_or(SynthesisError::DivisionByZero) + })?; + + // inv * a = 1 + cs.enforce( + || "inversion", + |lc| lc + inv.get_variable(), + |lc| lc + a.get_variable(), + |lc| lc + CS::one(), + ); + + Ok(inv) +} + /// Select the nth element of `from`, where `path_bits` represents n, least-significant bit first. /// The returned result contains the selected element, and constraints are enforced. /// `from.len()` must be a power of two. 
diff --git a/src/circuit/gadgets/pointer.rs b/src/circuit/gadgets/pointer.rs index 458c07e3da..b3a6072d5d 100644 --- a/src/circuit/gadgets/pointer.rs +++ b/src/circuit/gadgets/pointer.rs @@ -132,6 +132,14 @@ impl AllocatedPtr { &self.hash } + pub fn get_value(&self) -> Option> { + self.tag.get_value().and_then(|tag| { + self.hash + .get_value() + .map(|hash| ZPtr::from_parts(Tag::from_field(&tag).expect("bad tag"), hash)) + }) + } + pub fn enforce_equal>(&self, cs: &mut CS, other: &Self) { // debug_assert_eq!(self.tag.get_value(), other.tag.get_value()); enforce_equal(cs, || "tags equal", &self.tag, &other.tag); diff --git a/src/coprocessor/memoset/mod.rs b/src/coprocessor/memoset/mod.rs new file mode 100644 index 0000000000..ced6ee354f --- /dev/null +++ b/src/coprocessor/memoset/mod.rs @@ -0,0 +1,850 @@ +//! The `memoset` module implements a `MemoSet`. +//! +//! A `MemoSet` is an abstraction we use to memoize deferred proof of (potentially mutually-recursive) query results. +//! Whenever a computation being proved needs the result of a query, the prover non-deterministically supplies the +//! correct response. The resulting key-value pair is then added to a multiset representing deferred proofs. The +//! dependent proof now must not be accepted until every element in the deferred-proof multiset has been proved. +//! +//! Implementation depends on a cryptographic multiset -- for example, ECMH or LogUp (implemented here). This allows us +//! to prove that every element added to the multiset is later removed only after having been proved. The +//! cryptographic assumption is that it is infeasible to fraudulently demonstrate multiset equality. +//! +//! Our use of the LogUp (logarithmic derivative) technique in the `LogMemo` implementation of `MemoSet` unfortunately +//! requires that the entire history of insertions and removals be committed to in advance -- so that Fiat-Shamir +//!
randomness derived from the transcript can be used when mapping field elements to multiset elements. We use Lurk +//! data to assemble the transcript, so that the final randomness is the hash/value component of a `ZPtr` to the +//! content-addressed data structure representing the transcript as assembled. +//! +//! Transcript elements represent deferred proofs that are either added to (when their results are used) or removed from +//! (when correctness of those results is proved) the 'deferred proof' multiset. Insertions are recorded in the +//! transcript as key-value pairs (Lurk data: `(key . value)`); and removals further include the removal multiplicity +//! (Lurk data: `((key . value) . multiplicity)`). It is critical that the multiplicity be included in the transcript, +//! since if free to choose it after the randomness has been derived, the prover can trivially falsify the contents of +//! the multiset -- decoupling claimed truths from those actually proved. +//! +//! Bookkeeping required to correctly build the transcript after evaluation but before proving is maintained by the +//! `Scope`. This allows us to accumulate queries and the subqueries on which they depend, along with the memoized query +//! results computed 'naturally' during evaluation. We then separate and sort in an order matching that which the NIVC +//! prover will follow when provably maintaining the multiset accumulator and Fiat-Shamir transcript in the circuit. 
+ +use std::collections::HashMap; +use std::marker::PhantomData; + +use bellpepper_core::{boolean::Boolean, num::AllocatedNum, ConstraintSystem, SynthesisError}; +use itertools::Itertools; +use once_cell::sync::OnceCell; + +use super::gadgets::construct_cons; +use crate::circuit::gadgets::{ + constraints::{enforce_equal, enforce_equal_zero, invert, sub}, + pointer::AllocatedPtr, +}; +use crate::field::LurkField; +use crate::lem::circuit::GlobalAllocator; +use crate::lem::tag::Tag; +use crate::lem::{pointers::Ptr, store::Store}; +use crate::tag::{ExprTag, Tag as XTag}; +use crate::z_ptr::ZPtr; + +use multiset::MultiSet; +use query::{CircuitQuery, DemoCircuitQuery, Query}; + +mod multiset; +mod query; + +type ScopeCircuitQuery = DemoCircuitQuery; +type ScopeQuery = as CircuitQuery>::Q; + +#[derive(Clone, Debug)] +pub struct Transcript { + acc: Ptr, + _p: PhantomData, +} + +impl Transcript { + fn new(s: &Store) -> Self { + let nil = s.intern_nil(); + Self { + acc: nil, + _p: Default::default(), + } + } + + fn add(&mut self, s: &Store, item: Ptr) { + self.acc = s.cons(item, self.acc); + } + + fn make_kv(s: &Store, key: Ptr, value: Ptr) -> Ptr { + s.cons(key, value) + } + + fn make_kv_count(s: &Store, kv: Ptr, count: usize) -> Ptr { + let count_num = s.num(F::from_u64(count as u64)); + s.cons(kv, count_num) + } + + /// Since the transcript is just a content-addressed Lurk list, its randomness is the hash value of the associated + /// top-level `Cons`. This function sanity-checks the type and extracts that field element. 
+ fn r(&self, s: &Store) -> F { + let z_ptr = s.hash_ptr(&self.acc); + assert_eq!(Tag::Expr(ExprTag::Cons), *z_ptr.tag()); + *z_ptr.value() + } + + #[allow(dead_code)] + fn dbg(&self, s: &Store) { + //dbg!(self.acc.fmt_to_string_simple(s)); + tracing::debug!("transcript: {}", self.acc.fmt_to_string_simple(s)); + } + + #[allow(dead_code)] + fn fmt_to_string_simple(&self, s: &Store) -> String { + self.acc.fmt_to_string_simple(s) + } +} + +#[derive(Clone, Debug)] +pub struct CircuitTranscript { + acc: AllocatedPtr, +} + +impl CircuitTranscript { + fn new>(cs: &mut CS, g: &mut GlobalAllocator, s: &Store) -> Self { + let nil = s.intern_nil(); + let allocated_nil = g.alloc_ptr(cs, &nil, s); + Self { + acc: allocated_nil.clone(), + } + } + + pub fn pick>( + cs: &mut CS, + condition: &Boolean, + a: &Self, + b: &Self, + ) -> Result { + let picked = AllocatedPtr::pick(cs, condition, &a.acc, &b.acc)?; + Ok(Self { acc: picked }) + } + + fn add>( + &self, + cs: &mut CS, + g: &GlobalAllocator, + s: &Store, + item: &AllocatedPtr, + ) -> Result { + Ok(Self { + acc: construct_cons(cs, g, s, item, &self.acc)?, + }) + } + + fn make_kv>( + cs: &mut CS, + g: &GlobalAllocator, + s: &Store, + key: &AllocatedPtr, + value: &AllocatedPtr, + ) -> Result, SynthesisError> { + construct_cons(cs, g, s, key, value) + } + + fn make_kv_count>( + cs: &mut CS, + g: &GlobalAllocator, + s: &Store, + kv: &AllocatedPtr, + count: u64, + ) -> Result<(AllocatedPtr, AllocatedNum), SynthesisError> { + let allocated_count = + { AllocatedNum::alloc(&mut cs.namespace(|| "count"), || Ok(F::from_u64(count)))? 
}; + let count_ptr = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "count_ptr"), + ExprTag::Num.to_field(), + allocated_count.clone(), + )?; + + Ok((construct_cons(cs, g, s, kv, &count_ptr)?, allocated_count)) + } + + fn r(&self) -> &AllocatedNum { + self.acc.hash() + } + + #[allow(dead_code)] + fn dbg(&self, s: &Store) { + let z = self.acc.get_value::().unwrap(); + let transcript = s.to_ptr(&z); + // dbg!(transcript.fmt_to_string_simple(s)); + tracing::debug!("transcript: {}", transcript.fmt_to_string_simple(s)); + } +} + +#[derive(Debug)] +/// A `Scope` tracks the queries made while evaluating, including the subqueries that result from evaluating other +/// queries -- then makes use of the bookkeeping performed at evaluation time to synthesize proof of each query +/// performed. +pub struct Scope { + memoset: M, + /// k => v + queries: HashMap, + /// k => ordered subqueries + dependencies: HashMap>, + /// kv pairs + toplevel_insertions: Vec, + /// internally-inserted keys + internal_insertions: Vec, + /// unique keys + all_insertions: Vec, + _p: PhantomData<(F, Q)>, +} + +impl Default for Scope, LogMemo> { + fn default() -> Self { + Self { + memoset: Default::default(), + queries: Default::default(), + dependencies: Default::default(), + toplevel_insertions: Default::default(), + internal_insertions: Default::default(), + all_insertions: Default::default(), + _p: Default::default(), + } + } +} + +pub struct CircuitScope, M: MemoSet> { + memoset: M, + /// k -> v + queries: HashMap, ZPtr>, + /// k -> allocated v + transcript: CircuitTranscript, + acc: Option>, + _p: PhantomData, +} + +impl Scope, LogMemo> { + pub fn query(&mut self, s: &Store, form: Ptr) -> Ptr { + let (response, kv_ptr) = self.query_aux(s, form); + + self.toplevel_insertions.push(kv_ptr); + + response + } + + fn query_recursively( + &mut self, + s: &Store, + parent: &ScopeQuery, + child: ScopeQuery, + ) -> Ptr { + let form = child.to_ptr(s); + self.internal_insertions.push(form); + + let 
(response, _) = self.query_aux(s, form); + + self.dependencies + .entry(parent.to_ptr(s)) + .and_modify(|children| children.push(child.clone())) + .or_insert_with(|| vec![child]); + + response + } + + fn query_aux(&mut self, s: &Store, form: Ptr) -> (Ptr, Ptr) { + let response = self.queries.get(&form).cloned().unwrap_or_else(|| { + let query = ScopeQuery::from_ptr(s, &form).expect("invalid query"); + + let evaluated = query.eval(s, self); + + self.queries.insert(form, evaluated); + evaluated + }); + + let kv = Transcript::make_kv(s, form, response); + self.memoset.add(kv); + + (response, kv) + } + + fn finalize_transcript(&mut self, s: &Store) -> Transcript { + let (transcript, insertions) = self.build_transcript(s); + self.memoset.finalize_transcript(s, transcript.clone()); + self.all_insertions = insertions; + transcript + } + + fn ensure_transcript_finalized(&mut self, s: &Store) { + if !self.memoset.is_finalized() { + self.finalize_transcript(s); + } + } + + fn build_transcript(&self, s: &Store) -> (Transcript, Vec) { + let mut transcript = Transcript::new(s); + + let internal_insertions_kv = self.internal_insertions.iter().map(|key| { + let value = self.queries.get(key).expect("value missing for key"); + Transcript::make_kv(s, *key, *value) + }); + + let mut insertions = + Vec::with_capacity(self.toplevel_insertions.len() + self.internal_insertions.len()); + insertions.extend(&self.toplevel_insertions); + insertions.extend(internal_insertions_kv); + + // Sort insertions by query type (index) for processing. This is because the transcript will be constructed + // sequentially by the circuits, and we potentially batch queries of the same type in a single coprocessor + // circuit. 
+ insertions.sort_by_key(|kv| { + let (key, _) = s.car_cdr(kv).unwrap(); + + ScopeQuery::::from_ptr(s, &key) + .expect("invalid query") + .index() + }); + + for kv in self.toplevel_insertions.iter() { + transcript.add(s, *kv); + } + + // Then add insertions and removals interleaved, sorted by query type. We interleave insertions and removals + // because when proving later, each query's proof must record that its subquery proofs are being deferred + // (insertions) before then proving itself (making use of any subquery results) and removing the now-proved + // deferral from the MemoSet. + let unique_keys = insertions + .iter() + .unique() // Process each key once; the index-only sort above need not leave duplicates adjacent. + .map(|kv| { + let key = s.car_cdr(kv).unwrap().0; + + if let Some(dependencies) = self.dependencies.get(&key) { + dependencies + .iter() + .for_each(|dependency| { + let k = dependency.to_ptr(s); + let v = self + .queries + .get(&k) + .expect("value missing for dependency key"); + // Add an insertion for each dependency (subquery) of the query identified by `key`. Notice + // that these keys might already have been inserted before, but we need to repeat if so + // because the proof must do so each time a query is used. + let kv = Transcript::make_kv(s, k, *v); + transcript.add(s, kv) + }) + }; + let count = self.memoset.count(kv); + let kv_count = Transcript::make_kv_count(s, *kv, count); + + // Add removal for the query identified by `key`. The queries being removed here were deduplicated + // above, so each is removed only once. However, we freely choose the multiplicity (`count`) of the + // removal to match the total number of insertions actually made (considering dependencies).
+ transcript.add(s, kv_count); + + key + }) + .collect::>(); + + (transcript, unique_keys) + } + + pub fn synthesize>( + &mut self, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + ) -> Result<(), SynthesisError> { + self.ensure_transcript_finalized(s); + + { + let circuit_scope = + &mut CircuitScope::from_scope(&mut cs.namespace(|| "transcript"), g, s, self); + circuit_scope.init(cs, g, s); + { + self.synthesize_insert_toplevel_queries(circuit_scope, cs, g, s)?; + self.synthesize_prove_all_queries(circuit_scope, cs, g, s)?; + } + circuit_scope.finalize(cs, g); + Ok(()) + } + } + + fn synthesize_insert_toplevel_queries>( + &mut self, + circuit_scope: &mut CircuitScope, LogMemo>, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + ) -> Result<(), SynthesisError> { + for (i, kv) in self.toplevel_insertions.iter().enumerate() { + circuit_scope.synthesize_toplevel_query(cs, g, s, i, kv)?; + } + Ok(()) + } + + fn synthesize_prove_all_queries>( + &mut self, + circuit_scope: &mut CircuitScope, LogMemo>, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + ) -> Result<(), SynthesisError> { + for (i, kv) in self.all_insertions.iter().enumerate() { + circuit_scope.synthesize_prove_query(cs, g, s, i, kv)?; + } + Ok(()) + } +} + +impl> CircuitScope> { + fn from_scope>( + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + scope: &Scope>, + ) -> Self { + let queries = scope + .queries + .iter() + .map(|(k, v)| (s.hash_ptr(k), s.hash_ptr(v))) + .collect(); + Self { + memoset: scope.memoset.clone(), + queries, + transcript: CircuitTranscript::new(cs, g, s), + acc: Default::default(), + _p: Default::default(), + } + } + + fn init>( + &mut self, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + ) { + self.acc = Some( + AllocatedPtr::alloc_constant(&mut cs.namespace(|| "acc"), s.hash_ptr(&s.num_u64(0))) + .unwrap(), + ); + + self.transcript = CircuitTranscript::new(cs, g, s); + } + + fn synthesize_insert_query>( + &self, + cs: &mut CS, + g: 
&GlobalAllocator, + s: &Store, + acc: &AllocatedPtr, + transcript: &CircuitTranscript, + key: &AllocatedPtr, + value: &AllocatedPtr, + ) -> Result<(AllocatedPtr, CircuitTranscript), SynthesisError> { + let kv = CircuitTranscript::make_kv(&mut cs.namespace(|| "kv"), g, s, key, value)?; + let new_transcript = transcript.add(&mut cs.namespace(|| "new_transcript"), g, s, &kv)?; + + let acc_v = acc.hash(); + + let new_acc_v = + self.memoset + .synthesize_add(&mut cs.namespace(|| "new_acc_v"), acc_v, &kv)?; + + let new_acc = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "new_acc"), + ExprTag::Num.to_field(), + new_acc_v, + )?; + + Ok((new_acc, new_transcript.clone())) + } + + fn synthesize_remove>( + &self, + cs: &mut CS, + g: &GlobalAllocator, + s: &Store, + acc: &AllocatedPtr, + transcript: &CircuitTranscript, + key: &AllocatedPtr, + value: &AllocatedPtr, + ) -> Result<(AllocatedPtr, CircuitTranscript), SynthesisError> { + let kv = CircuitTranscript::make_kv(&mut cs.namespace(|| "kv"), g, s, key, value)?; + let zptr = kv.get_value().unwrap_or(s.hash_ptr(&s.intern_nil())); // dummy case: use nil + let raw_count = self.memoset.count(&s.to_ptr(&zptr)) as u64; // dummy case: count is meaningless + + let (kv_count, count) = CircuitTranscript::make_kv_count( + &mut cs.namespace(|| "kv_count"), + g, + s, + &kv, + raw_count, + )?; + let new_transcript = transcript.add( + &mut cs.namespace(|| "new_removal_transcript"), + g, + s, + &kv_count, + )?; + + let new_acc_v = self.memoset.synthesize_remove_n( + &mut cs.namespace(|| "new_acc_v"), + acc.hash(), + &kv, + &count, + )?; + + let new_acc = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "new_acc"), + ExprTag::Num.to_field(), + new_acc_v, + )?; + Ok((new_acc, new_transcript)) + } + + fn finalize>(&mut self, cs: &mut CS, _g: &mut GlobalAllocator) { + let r = self.memoset.allocated_r(cs); + enforce_equal(cs, || "r_matches_transcript", self.transcript.r(), &r); + enforce_equal_zero(cs, || "acc_is_zero", 
self.acc.clone().unwrap().hash()); + } + + fn synthesize_query>( + &mut self, + cs: &mut CS, + g: &GlobalAllocator, + store: &Store, + key: &AllocatedPtr, + acc: &AllocatedPtr, + transcript: &CircuitTranscript, + not_dummy: &Boolean, // TODO: use this more deeply? + ) -> Result<(AllocatedPtr, AllocatedPtr, CircuitTranscript), SynthesisError> { + let value = AllocatedPtr::alloc(&mut cs.namespace(|| "value"), || { + Ok(if not_dummy.get_value() == Some(true) { + *key.get_value() + .and_then(|k| self.queries.get(&k)) + .ok_or(SynthesisError::AssignmentMissing)? + } else { + // Dummy value that will not be used. + store.hash_ptr(&store.intern_nil()) + }) + })?; + + let (new_acc, new_insertion_transcript) = + self.synthesize_insert_query(cs, g, store, acc, transcript, key, &value)?; + + Ok((value, new_acc, new_insertion_transcript)) + } +} + +impl CircuitScope, LogMemo> { + fn synthesize_toplevel_query>( + &mut self, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + i: usize, + kv: &Ptr, + ) -> Result<(), SynthesisError> { + let (key, value) = s.car_cdr(kv).unwrap(); + let cs = &mut cs.namespace(|| format!("toplevel-{i}")); + let allocated_key = AllocatedPtr::alloc(&mut cs.namespace(|| "allocated_key"), || { + Ok(s.hash_ptr(&key)) + }) + .unwrap(); + + let acc = self.acc.clone().unwrap(); + let insertion_transcript = self.transcript.clone(); + + let (val, new_acc, new_transcript) = self.synthesize_query( + cs, + g, + s, + &allocated_key, + &acc, + &insertion_transcript, + &Boolean::Constant(true), + )?; + + if let Some(val_ptr) = val.get_value().map(|x| s.to_ptr(&x)) { + assert_eq!(value, val_ptr); + } + + self.acc = Some(new_acc); + self.transcript = new_transcript; + Ok(()) + } + + fn synthesize_prove_query>( + &mut self, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + i: usize, + key: &Ptr, + ) -> Result<(), SynthesisError> { + let cs = &mut cs.namespace(|| format!("internal-{i}")); + + let allocated_key = + AllocatedPtr::alloc( + &mut cs.namespace(|| 
"allocated_key"), + || Ok(s.hash_ptr(key)), + ) + .unwrap(); + + let circuit_query = + ScopeCircuitQuery::from_ptr(&mut cs.namespace(|| "circuit_query"), s, key).unwrap(); + + let acc = self.acc.clone().unwrap(); + let transcript = self.transcript.clone(); + + let (val, new_acc, new_transcript) = circuit_query + .expect("not a query form") + .synthesize_eval(&mut cs.namespace(|| "eval"), g, s, self, &acc, &transcript) + .unwrap(); + + let (new_acc, new_transcript) = + self.synthesize_remove(cs, g, s, &new_acc, &new_transcript, &allocated_key, &val)?; + + self.acc = Some(new_acc); + self.transcript = new_transcript; + + Ok(()) + } + + #[allow(dead_code)] + fn dbg_transcript(&self, s: &Store) { + self.transcript.dbg(s); + } +} + +pub trait MemoSet: Clone { + fn is_finalized(&self) -> bool; + fn finalize_transcript(&mut self, s: &Store, transcript: Transcript); + fn r(&self) -> Option<&F>; + fn map_to_element(&self, x: F) -> Option; + fn add(&mut self, kv: Ptr); + fn synthesize_remove_n>( + &self, + cs: &mut CS, + acc: &AllocatedNum, + kv: &AllocatedPtr, + count: &AllocatedNum, + ) -> Result, SynthesisError>; + fn count(&self, form: &Ptr) -> usize; + + // Circuit + + fn allocated_r>(&self, cs: &mut CS) -> AllocatedNum; + + // x is H(k,v) = hash part of (cons k v) + fn synthesize_map_to_element>( + &self, + cs: &mut CS, + x: AllocatedNum, + ) -> Result, SynthesisError>; + + fn synthesize_add>( + &self, + cs: &mut CS, + acc: &AllocatedNum, + kv: &AllocatedPtr, + ) -> Result, SynthesisError>; +} + +#[derive(Debug, Clone)] +pub struct LogMemo { + multiset: MultiSet, + r: OnceCell, + transcript: OnceCell>, + + allocated_r: OnceCell>>, +} + +impl Default for LogMemo { + fn default() -> Self { + // Be explicit. 
+ Self { + multiset: MultiSet::new(), + r: Default::default(), + transcript: Default::default(), + allocated_r: Default::default(), + } + } +} + +impl MemoSet for LogMemo { + fn count(&self, form: &Ptr) -> usize { + self.multiset.get(form).unwrap_or(0) + } + + fn is_finalized(&self) -> bool { + self.transcript.get().is_some() + } + fn finalize_transcript(&mut self, s: &Store, transcript: Transcript) { + let r = transcript.r(s); + + self.r.set(r).expect("r has already been set"); + + self.transcript + .set(transcript) + .expect("transcript already finalized"); + } + + fn r(&self) -> Option<&F> { + self.r.get() + } + + fn allocated_r>(&self, cs: &mut CS) -> AllocatedNum { + self.allocated_r + .get_or_init(|| { + self.r() + .map(|r| AllocatedNum::alloc_infallible(&mut cs.namespace(|| "r"), || *r)) + }) + .clone() + .unwrap() + } + + // x is H(k,v) = hash part of (cons k v) + fn map_to_element(&self, x: F) -> Option { + self.r().and_then(|r| { + let d = *r + x; + d.invert().into() + }) + } + + // x is H(k,v) = hash part of (cons k v) + // 1 / (r + x) + fn synthesize_map_to_element>( + &self, + cs: &mut CS, + x: AllocatedNum, + ) -> Result, SynthesisError> { + let r = self.allocated_r(cs); + let r_plus_x = r.add(&mut cs.namespace(|| "r+x"), &x)?; + + invert(&mut cs.namespace(|| "invert(r+x)"), &r_plus_x) + } + + fn add(&mut self, kv: Ptr) { + self.multiset.add(kv); + } + + fn synthesize_add>( + &self, + cs: &mut CS, + acc: &AllocatedNum, + kv: &AllocatedPtr, + ) -> Result, SynthesisError> { + let kv_num = kv.hash().clone(); + let element = self.synthesize_map_to_element(&mut cs.namespace(|| "element"), kv_num)?; + acc.add(&mut cs.namespace(|| "add to acc"), &element) + } + + fn synthesize_remove_n>( + &self, + cs: &mut CS, + acc: &AllocatedNum, + kv: &AllocatedPtr, + count: &AllocatedNum, + ) -> Result, SynthesisError> { + let kv_num = kv.hash().clone(); + let element = self.synthesize_map_to_element(&mut cs.namespace(|| "element"), kv_num)?; + let scaled =
element.mul(&mut cs.namespace(|| "scaled"), count)?; + sub(&mut cs.namespace(|| "add to acc"), acc, &scaled) + } +} + +#[cfg(test)] +mod test { + use super::*; + + use crate::state::State; + use bellpepper_core::{test_cs::TestConstraintSystem, Comparable}; + use expect_test::{expect, Expect}; + use pasta_curves::pallas::Scalar as F; + use std::default::Default; + + #[test] + fn test_query() { + let s = &Store::::default(); + let mut scope: Scope, LogMemo> = Scope::default(); + let state = State::init_lurk_state(); + + let fact_4 = s.read_with_default_state("(factorial 4)").unwrap(); + let fact_3 = s.read_with_default_state("(factorial 3)").unwrap(); + + let expect_eq = |computed: usize, expected: Expect| { + expected.assert_eq(&computed.to_string()); + }; + + { + scope.query(s, fact_4); + + for (k, v) in scope.queries.iter() { + println!("k: {}", k.fmt_to_string(s, &state)); + println!("v: {}", v.fmt_to_string(s, &state)); + } + // Factorial 4 will memoize calls to: + // fact(4), fact(3), fact(2), fact(1), and fact(0) + assert_eq!(5, scope.queries.len()); + assert_eq!(1, scope.toplevel_insertions.len()); + assert_eq!(4, scope.internal_insertions.len()); + + scope.finalize_transcript(s); + + let cs = &mut TestConstraintSystem::new(); + let g = &mut GlobalAllocator::default(); + + scope.synthesize(cs, g, s).unwrap(); + + println!( + "transcript: {}", + scope + .memoset + .transcript + .get() + .unwrap() + .fmt_to_string_simple(s) + ); + + expect_eq(cs.num_constraints(), expect!["10826"]); + expect_eq(cs.aux().len(), expect!["10859"]); + + let unsat = cs.which_is_unsatisfied(); + + if unsat.is_some() { + dbg!(unsat); + } + assert!(cs.is_satisfied()); + } + { + let mut scope: Scope, LogMemo> = Scope::default(); + scope.query(s, fact_4); + scope.query(s, fact_3); + + // // No new queries. + assert_eq!(5, scope.queries.len()); + // // One new top-level insertion. + assert_eq!(2, scope.toplevel_insertions.len()); + // // No new internal insertions. 
+ assert_eq!(4, scope.internal_insertions.len()); + + scope.finalize_transcript(s); + + let cs = &mut TestConstraintSystem::new(); + let g = &mut GlobalAllocator::default(); + + scope.synthesize(cs, g, s).unwrap(); + + expect_eq(cs.num_constraints(), expect!["11408"]); + expect_eq(cs.aux().len(), expect!["11445"]); + + let unsat = cs.which_is_unsatisfied(); + if unsat.is_some() { + dbg!(unsat); + } + assert!(cs.is_satisfied()); + } + } +} diff --git a/src/coprocessor/memoset/multiset.rs b/src/coprocessor/memoset/multiset.rs new file mode 100644 index 0000000000..253cf2c5b0 --- /dev/null +++ b/src/coprocessor/memoset/multiset.rs @@ -0,0 +1,59 @@ +use std::collections::HashMap; +use std::default::Default; +use std::hash::Hash; + +#[derive(PartialEq, Eq, Debug, Default, Clone)] +pub(crate) struct MultiSet { + map: HashMap, + cardinality: usize, +} + +impl MultiSet { + pub(crate) fn new() -> Self { + Self { + map: Default::default(), + cardinality: 0, + } + } + pub(crate) fn add(&mut self, element: T) { + *self.map.entry(element).or_insert(0) += 1; + self.cardinality += 1; + } + + pub(crate) fn get(&self, element: &T) -> Option { + self.map.get(element).copied() + } + + #[allow(dead_code)] + pub(crate) fn cardinality(&self) -> usize { + self.cardinality + } + + #[allow(dead_code)] + pub(crate) fn len(&self) -> usize { + self.map.len() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_multiset() { + let mut m = MultiSet::::new(); + let mut c = 0; + let n = 5; + + for i in 1..n { + for _ in 0..i { + m.add(i); + } + c += i; + assert_eq!(i, m.len()); + assert_eq!(c, m.cardinality()); + assert_eq!(Some(i), m.get(&i)); + assert_eq!(None, m.get(&(i + n))); + } + } +} diff --git a/src/coprocessor/memoset/query.rs b/src/coprocessor/memoset/query.rs new file mode 100644 index 0000000000..fe85db927e --- /dev/null +++ b/src/coprocessor/memoset/query.rs @@ -0,0 +1,306 @@ +use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; + +use 
super::{CircuitScope, CircuitTranscript, LogMemo, Scope}; +use crate::circuit::gadgets::constraints::alloc_is_zero; +use crate::coprocessor::gadgets::construct_list; +use crate::coprocessor::AllocatedPtr; +use crate::field::LurkField; +use crate::lem::circuit::GlobalAllocator; +use crate::lem::{pointers::Ptr, store::Store}; +use crate::symbol::Symbol; +use crate::tag::{ExprTag, Tag}; + +pub trait Query +where + Self: Sized, +{ + fn eval(&self, s: &Store, scope: &mut Scope>) -> Ptr; + fn recursive_eval( + &self, + scope: &mut Scope>, + s: &Store, + subquery: Self, + ) -> Ptr; + fn from_ptr(s: &Store, ptr: &Ptr) -> Option; + fn to_ptr(&self, s: &Store) -> Ptr; + fn symbol(&self) -> Symbol; + fn symbol_ptr(&self, s: &Store) -> Ptr { + s.intern_symbol(&self.symbol()) + } + + fn index(&self) -> usize; +} + +#[allow(unreachable_pub)] +pub trait CircuitQuery +where + Self: Sized, +{ + type Q: Query; + + fn synthesize_eval>( + &self, + cs: &mut CS, + g: &GlobalAllocator, + store: &Store, + scope: &mut CircuitScope>, + acc: &AllocatedPtr, + transcript: &CircuitTranscript, + ) -> Result<(AllocatedPtr, AllocatedPtr, CircuitTranscript), SynthesisError>; + + fn symbol(&self, s: &Store) -> Symbol { + self.dummy_query_variant(s).symbol() + } + + fn symbol_ptr(&self, s: &Store) -> Ptr { + self.dummy_query_variant(s).symbol_ptr(s) + } + + fn from_ptr>( + cs: &mut CS, + s: &Store, + ptr: &Ptr, + ) -> Result, SynthesisError>; + + fn dummy_query_variant(&self, s: &Store) -> Self::Q; +} + +#[derive(Debug, Clone)] +pub enum DemoQuery { + Factorial(Ptr), + Phantom(F), +} + +pub enum DemoCircuitQuery { + Factorial(AllocatedPtr), +} + +impl Query for DemoQuery { + // DemoQuery and Scope depend on each other. 
+ fn eval(&self, s: &Store, scope: &mut Scope>) -> Ptr { + match self { + Self::Factorial(n) => { + let n_zptr = s.hash_ptr(n); + let n = n_zptr.value(); + + if *n == F::ZERO { + s.num(F::ONE) + } else { + let m_ptr = self.recursive_eval(scope, s, Self::Factorial(s.num(*n - F::ONE))); + let m_zptr = s.hash_ptr(&m_ptr); + let m = m_zptr.value(); + + s.num(*n * m) + } + } + _ => unreachable!(), + } + } + + fn recursive_eval( + &self, + scope: &mut Scope>, + s: &Store, + subquery: Self, + ) -> Ptr { + scope.query_recursively(s, self, subquery) + } + + fn symbol(&self) -> Symbol { + match self { + Self::Factorial(_) => Symbol::sym(&["lurk", "user", "factorial"]), + _ => unreachable!(), + } + } + + fn from_ptr(s: &Store, ptr: &Ptr) -> Option { + let (head, body) = s.car_cdr(ptr).expect("query should be cons"); + let sym = s.fetch_sym(&head).expect("head should be sym"); + + if sym == Symbol::sym(&["lurk", "user", "factorial"]) { + let (num, _) = s.car_cdr(&body).expect("query body should be cons"); + Some(Self::Factorial(num)) + } else { + None + } + } + + fn to_ptr(&self, s: &Store) -> Ptr { + match self { + Self::Factorial(n) => { + let factorial = s.intern_symbol(&self.symbol()); + + s.list(vec![factorial, *n]) + } + _ => unreachable!(), + } + } + + fn index(&self) -> usize { + match self { + Self::Factorial(_) => 0, + _ => unreachable!(), + } + } +} + +impl CircuitQuery for DemoCircuitQuery { + type Q = DemoQuery; + + fn dummy_query_variant(&self, s: &Store) -> Self::Q { + match self { + Self::Factorial(_) => Self::Q::Factorial(s.num(F::ZERO)), + } + } + + fn synthesize_eval>( + &self, + cs: &mut CS, + g: &GlobalAllocator, + store: &Store, + scope: &mut CircuitScope>, + acc: &AllocatedPtr, + transcript: &CircuitTranscript, + ) -> Result<(AllocatedPtr, AllocatedPtr, CircuitTranscript), SynthesisError> { + match self { + // TODO: Factor out the recursive boilerplate so individual queries can just implement their distinct logic + // using a sane framework. 
+ Self::Factorial(n) => { + // FIXME: Check n tag or decide not to. + let base_case_f = g.alloc_const(cs, F::ONE); + let base_case = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "base_case"), + ExprTag::Num.to_field(), + base_case_f.clone(), + )?; + + let n_is_zero = alloc_is_zero(&mut cs.namespace(|| "n_is_zero"), n.hash())?; + + let (recursive_result, recursive_acc, recursive_transcript) = { + let new_n = AllocatedNum::alloc(&mut cs.namespace(|| "new_n"), || { + n.hash() + .get_value() + .map(|n| n - F::ONE) + .ok_or(SynthesisError::AssignmentMissing) + })?; + + // new_n * 1 = n - 1 + cs.enforce( + || "enforce_new_n", + |lc| lc + new_n.get_variable(), + |lc| lc + CS::one(), + |lc| lc + n.hash().get_variable() - CS::one(), + ); + + let subquery = { + let symbol = + g.alloc_ptr(cs, &store.intern_symbol(&self.symbol(store)), store); + + let new_num = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "new_num"), + ExprTag::Num.to_field(), + new_n, + )?; + construct_list( + &mut cs.namespace(|| "subquery"), + g, + store, + &[&symbol, &new_num], + None, + )? 
+ }; + + let (sub_result, new_acc, new_transcript) = scope.synthesize_query( + &mut cs.namespace(|| "recursive query"), + g, + store, + &subquery, + acc, + transcript, + &n_is_zero.not(), + )?; + + let result_f = n.hash().mul( + &mut cs.namespace(|| "incremental multiplication"), + sub_result.hash(), + )?; + + let result = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "result"), + ExprTag::Num.to_field(), + result_f, + )?; + + (result, new_acc, new_transcript) + }; + + let value = AllocatedPtr::pick( + &mut cs.namespace(|| "pick value"), + &n_is_zero, + &base_case, + &recursive_result, + )?; + + let acc = AllocatedPtr::pick( + &mut cs.namespace(|| "pick acc"), + &n_is_zero, + acc, + &recursive_acc, + )?; + + let transcript = CircuitTranscript::pick( + &mut cs.namespace(|| "pick recursive_transcript"), + &n_is_zero, + transcript, + &recursive_transcript, + )?; + + Ok((value, acc, transcript)) + } + } + } + + fn from_ptr>( + cs: &mut CS, + s: &Store, + ptr: &Ptr, + ) -> Result, SynthesisError> { + let query = Self::Q::from_ptr(s, ptr); + Ok(if let Some(q) = query { + match q { + Self::Q::Factorial(n) => Some(Self::Factorial(AllocatedPtr::alloc(cs, || { + Ok(s.hash_ptr(&n)) + })?)), + _ => unreachable!(), + } + } else { + None + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + + use ff::Field; + use pasta_curves::pallas::Scalar as F; + + #[test] + fn test_factorial() { + let s = Store::default(); + let mut scope: Scope, LogMemo> = Scope::default(); + let zero = s.num(F::ZERO); + let one = s.num(F::ONE); + let two = s.num(F::from_u64(2)); + let three = s.num(F::from_u64(3)); + let four = s.num(F::from_u64(4)); + let six = s.num(F::from_u64(6)); + let twenty_four = s.num(F::from_u64(24)); + assert_eq!(one, DemoQuery::Factorial(zero).eval(&s, &mut scope)); + assert_eq!(one, DemoQuery::Factorial(one).eval(&s, &mut scope)); + assert_eq!(two, DemoQuery::Factorial(two).eval(&s, &mut scope)); + assert_eq!(six, DemoQuery::Factorial(three).eval(&s, &mut scope)); + 
assert_eq!(twenty_four, DemoQuery::Factorial(four).eval(&s, &mut scope)); + } +} diff --git a/src/coprocessor/mod.rs b/src/coprocessor/mod.rs index abcdd1b903..ec942d4971 100644 --- a/src/coprocessor/mod.rs +++ b/src/coprocessor/mod.rs @@ -10,6 +10,7 @@ use crate::{ pub mod circom; pub mod gadgets; +pub mod memoset; pub mod sha256; pub mod trie; diff --git a/src/lem/store.rs b/src/lem/store.rs index dd29c05e1a..29cf756ad6 100644 --- a/src/lem/store.rs +++ b/src/lem/store.rs @@ -1181,6 +1181,10 @@ impl Ptr { } } + pub fn fmt_to_string_simple(&self, store: &Store) -> String { + self.fmt_to_string(store, crate::state::initial_lurk_state()) + } + fn fmt_cont2_to_string( &self, name: &str,