Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LEM match values #1190

Merged
merged 3 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions src/lem/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,9 @@ impl Block {
def.alloc_consts(cs, store, g, lang);
}
}
Ctrl::MatchSymbol(_, cases, def) => {
g.alloc_tag(cs, &Sym);
Ctrl::MatchValue(_, lit_type, cases, def) => {
let tag = lit_type.tag();
g.alloc_tag(cs, &tag);
for block in cases.values() {
block.alloc_consts(cs, store, g, lang);
}
Expand Down Expand Up @@ -1307,14 +1308,15 @@ fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
// The number of slots the match used is the max number of slots of each branch
*next_slot = next_slot.fold_max(branch_slots);
}
Ctrl::MatchSymbol(match_var, cases, def) => {
Ctrl::MatchValue(match_var, lit_type, cases, def) => {
let tag = lit_type.tag();
let match_var_ptr = bound_allocations.get_ptr(match_var)?.clone();

let mut cases_vec = Vec::with_capacity(cases.len());
for (sym, block) in cases {
let sym_ptr = ctx.store.intern_symbol(sym);
let sym_hash = *ctx.store.hash_ptr(&sym_ptr).value();
cases_vec.push((sym_hash, block));
for (lit, block) in cases {
let lit_ptr = lit.to_ptr(ctx.store);
let lit_hash = *ctx.store.hash_ptr(&lit_ptr).value();
cases_vec.push((lit_hash, block));
}

let branch_slots = synthesize_match(
Expand All @@ -1325,13 +1327,13 @@ fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
ctx,
)?;

// Now we enforce `MatchSymbol`'s tag
let sym_tag = ctx.global_allocator.alloc_tag(cs, &Sym);
// Now we enforce `MatchValue`'s tag
let lit_tag = ctx.global_allocator.alloc_tag(cs, &tag);
implies_equal(
ns!(cs, format!("implies equal {match_var}.tag")),
not_dummy,
match_var_ptr.tag(),
sym_tag,
lit_tag,
);

// The number of slots the match used is the max number of slots of each branch
Expand Down Expand Up @@ -1643,11 +1645,11 @@ impl Func {
}
num_constraints
}
Ctrl::MatchSymbol(_, cases, def) => {
// First we enforce that the tag of the pointer being matched on
// is Sym
Ctrl::MatchValue(_, lit_type, cases, def) => {
let tag = lit_type.tag();
// First we enforce the tag of the pointer being matched
num_constraints += 1;
globals.insert(FWrap(Sym.to_field()));
globals.insert(FWrap(tag.to_field()));
// We allocate one boolean per case and constrain it once
// per case. Then we add 1 constraint to enforce only one
// case was selected
Expand Down
109 changes: 57 additions & 52 deletions src/lem/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use super::{
interpreter::{Frame, Hints},
pointers::{Ptr, RawPtr},
store::{fetch_ptrs, Store},
Ctrl, Func, Op, Tag, Var,
Ctrl, Func, Lit, LitType, Op, Tag, Var,
};

static EVAL_STEP: OnceCell<Func> = OnceCell::new();
Expand Down Expand Up @@ -336,7 +336,7 @@ fn car_cdr() -> Func {
/// match cproc.tag {
/// Expr::Cproc => {
/// let (cproc_name, evaluated_args) = decons2(cproc);
/// match symbol cproc_name {
/// match cproc_name.value {
/// // `x` is the name of the coprocessor being called
/// x => {
/// // `n` is the arity of the coprocessor `x`
Expand Down Expand Up @@ -426,13 +426,13 @@ fn run_cproc(cproc_sym: Symbol, arity: usize) -> Func {
}
}

// MatchSymbol
// MatchValue
block = Block {
ops: vec![Op::Decons2(
[cproc_name.clone(), evaluated_args],
cproc.clone(),
)],
ctrl: Ctrl::match_symbol(cproc_name, vec![(cproc_sym, block)], None),
ctrl: Ctrl::match_symbol(cproc_name, vec![(Lit::Symbol(cproc_sym), block)], None),
};

// MatchTag
Expand Down Expand Up @@ -473,7 +473,7 @@ pub fn make_cprocs_funcs_from_lang<F: LurkField, C: Coprocessor<F>>(
/// let nil = Symbol("nil");
/// let nil = cast(nil, Expr::Nil);
/// let t = Symbol("t");
/// match symbol head {
/// match head.value {
/// // one arm for each coprocessor in the `Lang`
/// ... => {
/// return (t)
Expand All @@ -498,7 +498,12 @@ fn is_cproc(cprocs: &[(&Symbol, usize)]) -> Func {
];
let match_symbol_cases = cprocs
.iter()
.map(|(cproc, _)| ((*cproc).clone(), Block::ctrl(ctrl!(return (t)))))
.map(|(cproc, _)| {
(
Lit::Symbol((*cproc).clone()),
Block::ctrl(ctrl!(return (t))),
)
})
.collect();
let def = Some(Block::ctrl(ctrl!(return (nil))));
let ctrl = Ctrl::match_symbol(head.clone(), match_symbol_cases, def);
Expand All @@ -525,7 +530,7 @@ fn is_cproc(cprocs: &[(&Symbol, usize)]) -> Func {
/// let makethunk = Symbol("make-thunk");
/// let errctrl = Symbol("error");
/// let ret = Symbol("return");
/// match symbol cproc_name {
/// match cproc_name.value {
/// x => {
/// // `n` is the arity of the coprocessor `x`
/// let is_nil = eq_tag(evaluated_args, nil);
Expand Down Expand Up @@ -632,9 +637,9 @@ fn match_and_run_cproc(cprocs: &[(&Symbol, usize)]) -> Func {
ctrl: Ctrl::if_(is_nil.clone(), err_block.clone(), block),
}
}
match_symbol_map.insert(cproc.clone(), block);
match_symbol_map.insert(Lit::Symbol(cproc.clone()), block);
}
let ctrl = Ctrl::MatchSymbol(cproc_name.clone(), match_symbol_map, None);
let ctrl = Ctrl::MatchValue(cproc_name.clone(), LitType::Symbol, match_symbol_map, None);
let func_inp = vec![cproc_name, evaluated_args, env, cont];
let ops = if max_arity == 0 {
vec![
Expand Down Expand Up @@ -677,48 +682,48 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func {
let get_unop = aux_func!(get_unop(head): 1 => {
let nil = Symbol("nil");
let nil = cast(nil, Expr::Nil);
match symbol head {
"car" => {
match head.value {
Symbol("car") => {
let op: Op1::Car;
return (op);
}
"cdr" => {
Symbol("cdr") => {
let op: Op1::Cdr;
return (op);
}
"commit" => {
Symbol("commit") => {
let op: Op1::Commit;
return (op);
}
"num" => {
Symbol("num") => {
let op: Op1::Num;
return (op);
}
"u64" => {
Symbol("u64") => {
let op: Op1::U64;
return (op);
}
"comm" => {
Symbol("comm") => {
let op: Op1::Comm;
return (op);
}
"char" => {
Symbol("char") => {
let op: Op1::Char;
return (op);
}
"open" => {
Symbol("open") => {
let op: Op1::Open;
return (op);
}
"secret" => {
Symbol("secret") => {
let op: Op1::Secret;
return (op);
}
"atom" => {
Symbol("atom") => {
let op: Op1::Atom;
return (op);
}
"emit" => {
Symbol("emit") => {
let op: Op1::Emit;
return (op);
}
Expand All @@ -728,60 +733,60 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func {
let get_binop = aux_func!(get_binop(head): 1 => {
let nil = Symbol("nil");
let nil = cast(nil, Expr::Nil);
match symbol head {
"cons" => {
match head.value {
Symbol("cons") => {
let op: Op2::Cons;
return (op);
}
"strcons" => {
Symbol("strcons") => {
let op: Op2::StrCons;
return (op);
}
"hide" => {
Symbol("hide") => {
let op: Op2::Hide;
return (op);
}
"+" => {
Symbol("+") => {
let op: Op2::Sum;
return (op);
}
"-" => {
Symbol("-") => {
let op: Op2::Diff;
return (op);
}
"*" => {
Symbol("*") => {
let op: Op2::Product;
return (op);
}
"/" => {
Symbol("/") => {
let op: Op2::Quotient;
return (op);
}
"%" => {
Symbol("%") => {
let op: Op2::Modulo;
return (op);
}
"=" => {
Symbol("=") => {
let op: Op2::NumEqual;
return (op);
}
"eq" => {
Symbol("eq") => {
let op: Op2::Equal;
return (op);
}
"<" => {
Symbol("<") => {
let op: Op2::Less;
return (op);
}
">" => {
Symbol(">") => {
let op: Op2::Greater;
return (op);
}
"<=" => {
Symbol("<=") => {
let op: Op2::LessEqual;
return (op);
}
">=" => {
Symbol(">=") => {
let op: Op2::GreaterEqual;
return (op);
}
Expand Down Expand Up @@ -891,11 +896,11 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func {
let (res, res_env, state) = lookup(res, res_env, state);
let (res, res_env, state) = lookup(res, res_env, state);
let (res, res_env, state) = lookup(res, res_env, state);
match symbol state {
"error" => {
match state.value {
Symbol("error") => {
return (expr, env, err, errctrl)
}
"found" => {
Symbol("found") => {
match res.tag {
// if `val2` is a recursive closure, then extend its environment
Expr::Rec => {
Expand All @@ -909,7 +914,7 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func {
};
return (res, res_env, cont, apply)
}
"not_found" => {
Symbol("not_found") => {
// if it's not yet found, we must keep reducing
return (res, res_env, cont, ret)
}
Expand Down Expand Up @@ -971,8 +976,8 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func {
};
return (expr, env, err, errctrl)
}
match symbol head {
"lambda" => {
match head.value {
Symbol("lambda") => {
let (vars, rest) = car_cdr(rest);
let rest_nil = eq_tag(rest, nil);
if rest_nil {
Expand Down Expand Up @@ -1001,7 +1006,7 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func {
};
return (expr, env, err, errctrl)
}
"quote" => {
Symbol("quote") => {
let (quoted, end) = car_cdr(rest);

match end.tag {
Expand All @@ -1011,7 +1016,7 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func {
};
return (expr, env, err, errctrl)
}
"begin" => {
Symbol("begin") => {
let (arg1, more) = car_cdr(rest);
match more.tag {
Expr::Nil => {
Expand All @@ -1022,7 +1027,7 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func {
let cont: Cont::Binop = cons4(op, env, more, cont);
return (arg1, env, cont, ret)
}
"eval" => {
Symbol("eval") => {
match rest.tag {
Expr::Nil => {
return (expr, env, err, errctrl)
Expand All @@ -1040,7 +1045,7 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func {
let cont: Cont::Binop = cons4(op, env, more, cont);
return (arg1, env, cont, ret)
}
"if" => {
Symbol("if") => {
let (condition, more) = car_cdr(rest);
match more.tag {
Expr::Nil => {
Expand All @@ -1050,7 +1055,7 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func {
let cont: Cont::If = cons4(more, env, cont, foo);
return (condition, env, cont, ret)
}
"empty-env" => {
Symbol("empty-env") => {
match rest.tag {
Expr::Nil => {
let empty_env: Expr::Env;
Expand All @@ -1059,7 +1064,7 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func {
};
return (expr, env, err, errctrl)
}
"current-env" => {
Symbol("current-env") => {
match rest.tag {
Expr::Nil => {
return (env, env, cont, apply)
Expand Down Expand Up @@ -1213,8 +1218,8 @@ fn apply_cont(cprocs: &[(&Symbol, usize)], ivc: bool) -> Func {
});
let choose_cproc_call = choose_cproc_call(cprocs, ivc);
aux_func!(apply_cont(result, env, cont, ctrl): 4 => {
match symbol ctrl {
"apply-continuation" => {
match ctrl.value {
Symbol("apply-continuation") => {
let makethunk = Symbol("make-thunk");

let errctrl = Symbol("error");
Expand Down Expand Up @@ -1723,8 +1728,8 @@ fn apply_cont(cprocs: &[(&Symbol, usize)], ivc: bool) -> Func {

fn make_thunk() -> Func {
aux_func!(make_thunk(expr, env, cont, ctrl): 3 => {
match symbol ctrl {
"make-thunk" => {
match ctrl.value {
Symbol("make-thunk") => {
match cont.tag {
Cont::Outermost => {
let empty_env: Expr::Env;
Expand Down
Loading