Skip to content

Commit

Permalink
refactor: simplify some memoset generics (#1045)
Browse files Browse the repository at this point in the history
* refactor: Refactor CircuitQuery and Scope type handling

Relies on associated type projection and method parameters rather than binding generic parameters.

* refactor: Refactor Scope struct to remove redundant field parameter

Query and Memoset already have an inner parameter
  • Loading branch information
huitseeker authored Jan 12, 2024
1 parent a9a5f10 commit 29ea314
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 27 deletions.
8 changes: 4 additions & 4 deletions src/coprocessor/memoset/demo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl<F: LurkField> Query<F> for DemoQuery<F> {
type CQ = DemoCircuitQuery<F>;

// DemoQuery and Scope depend on each other.
fn eval(&self, s: &Store<F>, scope: &mut Scope<F, Self, LogMemo<F>>) -> Ptr {
fn eval(&self, s: &Store<F>, scope: &mut Scope<Self, LogMemo<F>>) -> Ptr {
match self {
Self::Factorial(n) => {
let n_zptr = s.hash_ptr(n);
Expand All @@ -51,7 +51,7 @@ impl<F: LurkField> Query<F> for DemoQuery<F> {

fn recursive_eval(
&self,
scope: &mut Scope<F, Self, LogMemo<F>>,
scope: &mut Scope<Self, LogMemo<F>>,
s: &Store<F>,
subquery: Self,
) -> Ptr {
Expand Down Expand Up @@ -102,7 +102,7 @@ impl<F: LurkField> CircuitQuery<F> for DemoCircuitQuery<F> {
cs: &mut CS,
g: &GlobalAllocator<F>,
store: &Store<F>,
scope: &mut CircuitScope<F, Self, LogMemo<F>>,
scope: &mut CircuitScope<F, LogMemo<F>>,
acc: &AllocatedPtr<F>,
transcript: &CircuitTranscript<F>,
) -> Result<(AllocatedPtr<F>, AllocatedPtr<F>, CircuitTranscript<F>), SynthesisError> {
Expand Down Expand Up @@ -238,7 +238,7 @@ mod test {
#[test]
fn test_factorial() {
let s = Store::default();
let mut scope: Scope<F, DemoQuery<F>, LogMemo<F>> = Scope::default();
let mut scope: Scope<DemoQuery<F>, LogMemo<F>> = Scope::default();
let zero = s.num(F::ZERO);
let one = s.num(F::ONE);
let two = s.num(F::from_u64(2));
Expand Down
36 changes: 16 additions & 20 deletions src/coprocessor/memoset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use crate::tag::{ExprTag, Tag as XTag};
use crate::z_ptr::ZPtr;

use multiset::MultiSet;
use query::{CircuitQuery, Query};
pub use query::{CircuitQuery, Query};

mod demo;
mod multiset;
Expand Down Expand Up @@ -180,7 +180,7 @@ impl<F: LurkField> CircuitTranscript<F> {
/// A `Scope` tracks the queries made while evaluating, including the subqueries that result from evaluating other
/// queries -- then makes use of the bookkeeping performed at evaluation time to synthesize proof of each query
/// performed.
pub struct Scope<F, Q, M> {
pub struct Scope<Q, M> {
memoset: M,
/// k => v
queries: HashMap<Ptr, Ptr>,
Expand All @@ -192,10 +192,9 @@ pub struct Scope<F, Q, M> {
internal_insertions: Vec<Ptr>,
/// unique keys
all_insertions: Vec<Ptr>,
_p: PhantomData<F>,
}

impl<F: LurkField, Q> Default for Scope<F, Q, LogMemo<F>> {
impl<F: LurkField, Q> Default for Scope<Q, LogMemo<F>> {
fn default() -> Self {
Self {
memoset: Default::default(),
Expand All @@ -204,22 +203,20 @@ impl<F: LurkField, Q> Default for Scope<F, Q, LogMemo<F>> {
toplevel_insertions: Default::default(),
internal_insertions: Default::default(),
all_insertions: Default::default(),
_p: Default::default(),
}
}
}

pub struct CircuitScope<F: LurkField, CQ: CircuitQuery<F>, M: MemoSet<F>> {
pub struct CircuitScope<F: LurkField, M> {
memoset: M,
/// k -> v
queries: HashMap<ZPtr<Tag, F>, ZPtr<Tag, F>>,
/// k -> allocated v
transcript: CircuitTranscript<F>,
acc: Option<AllocatedPtr<F>>,
_p: PhantomData<CQ>,
}

impl<F: LurkField, Q: Query<F>> Scope<F, Q, LogMemo<F>> {
impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>> {
pub fn query(&mut self, s: &Store<F>, form: Ptr) -> Ptr {
let (response, kv_ptr) = self.query_aux(s, form);

Expand Down Expand Up @@ -357,12 +354,12 @@ impl<F: LurkField, Q: Query<F>> Scope<F, Q, LogMemo<F>> {
}
}

impl<F: LurkField, CQ: CircuitQuery<F>> CircuitScope<F, CQ, LogMemo<F>> {
fn from_scope<CS: ConstraintSystem<F>, Q: Query<F, CQ = CQ>>(
impl<F: LurkField> CircuitScope<F, LogMemo<F>> {
fn from_scope<CS: ConstraintSystem<F>, Q: Query<F>>(
cs: &mut CS,
g: &mut GlobalAllocator<F>,
s: &Store<F>,
scope: &Scope<F, Q, LogMemo<F>>,
scope: &Scope<Q, LogMemo<F>>,
) -> Self {
let queries = scope
.queries
Expand All @@ -374,7 +371,6 @@ impl<F: LurkField, CQ: CircuitQuery<F>> CircuitScope<F, CQ, LogMemo<F>> {
queries,
transcript: CircuitTranscript::new(cs, g, s),
acc: Default::default(),
_p: Default::default(),
}
}

Expand Down Expand Up @@ -496,9 +492,9 @@ impl<F: LurkField, CQ: CircuitQuery<F>> CircuitScope<F, CQ, LogMemo<F>> {
Ok((value, new_acc, new_insertion_transcript))
}

fn synthesize_insert_toplevel_queries<CS: ConstraintSystem<F>, Q: Query<F, CQ = CQ>>(
fn synthesize_insert_toplevel_queries<CS: ConstraintSystem<F>, Q: Query<F>>(
&mut self,
scope: &mut Scope<F, Q, LogMemo<F>>,
scope: &mut Scope<Q, LogMemo<F>>,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
s: &Store<F>,
Expand Down Expand Up @@ -546,20 +542,20 @@ impl<F: LurkField, CQ: CircuitQuery<F>> CircuitScope<F, CQ, LogMemo<F>> {
Ok(())
}

fn synthesize_prove_all_queries<CS: ConstraintSystem<F>, Q: Query<F, CQ = CQ>>(
fn synthesize_prove_all_queries<CS: ConstraintSystem<F>, Q: Query<F>>(
&mut self,
scope: &mut Scope<F, Q, LogMemo<F>>,
scope: &mut Scope<Q, LogMemo<F>>,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
s: &Store<F>,
) -> Result<(), SynthesisError> {
for (i, kv) in scope.all_insertions.iter().enumerate() {
self.synthesize_prove_query(cs, g, s, i, kv)?;
self.synthesize_prove_query::<_, Q::CQ>(cs, g, s, i, kv)?;
}
Ok(())
}

fn synthesize_prove_query<CS: ConstraintSystem<F>>(
fn synthesize_prove_query<CS: ConstraintSystem<F>, CQ: CircuitQuery<F>>(
&mut self,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
Expand Down Expand Up @@ -752,7 +748,7 @@ mod test {
#[test]
fn test_query() {
let s = &Store::<F>::default();
let mut scope: Scope<F, DemoQuery<F>, LogMemo<F>> = Scope::default();
let mut scope: Scope<DemoQuery<F>, LogMemo<F>> = Scope::default();
let state = State::init_lurk_state();

let fact_4 = s.read_with_default_state("(factorial 4)").unwrap();
Expand Down Expand Up @@ -803,7 +799,7 @@ mod test {
assert!(cs.is_satisfied());
}
{
let mut scope: Scope<F, DemoQuery<F>, LogMemo<F>> = Scope::default();
let mut scope: Scope<DemoQuery<F>, LogMemo<F>> = Scope::default();
scope.query(s, fact_4);
scope.query(s, fact_3);

Expand Down
6 changes: 3 additions & 3 deletions src/coprocessor/memoset/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ where
{
type CQ: CircuitQuery<F>;

fn eval(&self, s: &Store<F>, scope: &mut Scope<F, Self, LogMemo<F>>) -> Ptr;
fn eval(&self, s: &Store<F>, scope: &mut Scope<Self, LogMemo<F>>) -> Ptr;
fn recursive_eval(
&self,
scope: &mut Scope<F, Self, LogMemo<F>>,
scope: &mut Scope<Self, LogMemo<F>>,
s: &Store<F>,
subquery: Self,
) -> Ptr;
Expand All @@ -39,7 +39,7 @@ where
cs: &mut CS,
g: &GlobalAllocator<F>,
store: &Store<F>,
scope: &mut CircuitScope<F, Self, LogMemo<F>>,
scope: &mut CircuitScope<F, LogMemo<F>>,
acc: &AllocatedPtr<F>,
transcript: &CircuitTranscript<F>,
) -> Result<(AllocatedPtr<F>, AllocatedPtr<F>, CircuitTranscript<F>), SynthesisError>;
Expand Down

1 comment on commit 29ea314

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks

Table of Contents

Overview

This benchmark report shows the Fibonacci GPU benchmark.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
125.78 GB RAM
Workflow run: https://github.com/lurk-lab/lurk-rs/actions/runs/7505936732

Benchmark Results

LEM Fibonacci Prove - rc = 100

fib-ref=a9a5f10aeb79f22d7a07fccb9321209358ecc2b4 fib-ref=29ea314f590390d9a27bbcd8125adfabd318dcac
num-100 1.74 s (✅ 1.00x) 1.74 s (✅ 1.00x faster)
num-200 3.37 s (✅ 1.00x) 3.36 s (✅ 1.00x faster)

LEM Fibonacci Prove - rc = 600

fib-ref=a9a5f10aeb79f22d7a07fccb9321209358ecc2b4 fib-ref=29ea314f590390d9a27bbcd8125adfabd318dcac
num-100 2.03 s (✅ 1.00x) 2.03 s (✅ 1.00x slower)
num-200 3.38 s (✅ 1.00x) 3.39 s (✅ 1.00x slower)

Made with criterion-table

Please sign in to comment.