From 9eaff01c7078f001c26906299318bc9c54a7b397 Mon Sep 17 00:00:00 2001 From: Cameron Low Date: Wed, 30 Oct 2024 15:27:46 +0000 Subject: [PATCH] Improve code position (match + extended assignments) (#654) - `^match` pattern now allowed within a code pos - `#cname` can be used to select the appropriate sub-branch of a match, e.g. `^match#Some.1` - `^lv<@` and `^lv<$` are now permitted --- src/ecMatching.ml | 56 ++++++++++++++++++++++++++++++++---------- src/ecMatching.mli | 19 ++++++++++---- src/ecParser.mly | 14 ++++++++--- src/ecParsetree.ml | 8 +++--- src/ecPrinting.ml | 21 +++++++++++----- src/ecTyping.ml | 15 +++++++++-- src/phl/ecPhlInline.ml | 4 +++ tests/match_codepos.ec | 40 ++++++++++++++++++++++++++++++ 8 files changed, 144 insertions(+), 33 deletions(-) create mode 100644 tests/match_codepos.ec diff --git a/src/ecMatching.ml b/src/ecMatching.ml index 867cf67ad..fd7f25633 100644 --- a/src/ecMatching.ml +++ b/src/ecMatching.ml @@ -19,8 +19,9 @@ module Position = struct | `If | `While | `Assign of lvmatch - | `Sample - | `Call + | `Sample of lvmatch + | `Call of lvmatch + | `Match ] and lvmatch = [ `LvmNone | `LvmVar of EcTypes.prog_var ] @@ -30,9 +31,10 @@ module Position = struct | `ByMatch of int option * cp_match ] - type codepos1 = int * cp_base - type codepos = (codepos1 * int) list * codepos1 - type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1] + type codepos_brsel = [`Cond of bool | `Match of EcSymbols.symbol] + type codepos1 = int * cp_base + type codepos = (codepos1 * codepos_brsel) list * codepos1 + type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1] let shift1 ~(offset : int) ((o, p) : codepos1) : codepos1 = (o + offset, p) @@ -57,12 +59,19 @@ module Zipper = struct type ('a, 'state) folder = 'a -> 'state -> instr -> 'state * instr list + type spath_match_ctxt = { + locals : (EcIdent.t * ty) list; + prebr : ((EcIdent.t * ty) list * stmt) list; + postbr : ((EcIdent.t * ty) list * stmt) list; + } + type ipath = | ZTop | ZWhile of expr * spath | ZIfThen of expr * spath * stmt | ZIfElse of expr * stmt * spath - + | ZMatch of expr * spath * spath_match_ctxt + and spath = (instr list * instr list) * ipath type zipper = { @@ -95,9 +104,12 @@ module Zipper = struct match ir.i_node, cm with | Swhile _, `While -> i-1 | Sif _, `If -> i-1 - | Srnd _, `Sample -> i-1 - | Scall _, `Call -> i-1 + | Smatch _, `Match -> i-1 + + | Scall (None, _, _), `Call `LvmNone -> i-1 + | Scall (Some lv, _, _), `Call lvm + | Srnd (lv, _), `Sample lvm | Sasgn (lv, _), `Assign lvm -> begin match lv, lvm with | _, `LvmNone -> i-1 @@ -178,23 +190,34 @@ module Zipper = struct let zipper_at_nm_cpos1 (env : EcEnv.env) - ((cp1, sub) : codepos1 * int) + ((cp1, sub) : codepos1 * codepos_brsel) (s : stmt) (zpr : ipath) - : (ipath * stmt) * (codepos1 * int) + : (ipath * stmt) * (codepos1 * codepos_brsel) = let (s1, i, s2) = find_by_cpos1 env cp1 s in let zpr = match i.i_node, sub with - | Swhile (e, sw), 0 -> + | Swhile (e, sw), `Cond true -> (ZWhile (e, ((s1, s2), zpr)), sw) - | Sif (e, ifs1, ifs2), 0 -> + | Sif (e, ifs1, ifs2), `Cond true -> (ZIfThen (e, ((s1, s2), zpr), ifs2), ifs1) - | Sif (e, ifs1, ifs2), 1 -> + | Sif (e, ifs1, ifs2), `Cond false -> (ZIfElse (e, ifs1, ((s1, s2), zpr)), ifs2) + | Smatch (e, bs), `Match cn -> + let _, indt, _ = oget (EcEnv.Ty.get_top_decl e.e_ty env) in + let indt = oget (EcDecl.tydecl_as_datatype indt) in + let cnames = List.fst indt.tydt_ctors in + let ix, _ = + try List.findi (fun _ n -> EcSymbols.sym_equal cn n) cnames + with Not_found -> raise InvalidCPos + in + let prebr, (locals, body), postbr = List.pivot_at ix bs in + (ZMatch (e, ((s1, s2), zpr), { locals; prebr; postbr; }), body) + | _ -> raise InvalidCPos in zpr, ((0, `ByPos (1 + List.length s1)), sub) @@ -228,6 +251,8 @@ module Zipper = struct | ZWhile (e, sp) -> zip (Some (i_while (e, s))) sp | ZIfThen (e, sp, se) -> zip (Some (i_if (e, s, se))) sp | ZIfElse (e, se, sp) -> zip (Some (i_if (e, se, s))) sp + | ZMatch (e, sp, mpi) -> + zip (Some (i_match (e, mpi.prebr @ (mpi.locals, s) :: mpi.postbr))) sp let zip zpr = zip None ((zpr.z_head, zpr.z_tail), zpr.z_path) @@ -238,6 +263,7 @@ module Zipper = struct | ZWhile (_, ((_, is), ip)) -> doit (is :: acc) ip | ZIfThen (_, ((_, is), ip), _) -> doit (is :: acc) ip | ZIfElse (_, _, ((_, is), ip)) -> doit (is :: acc) ip + | ZMatch (_, ((_, is), ip), _) -> doit (is :: acc) ip in let after = @@ -1298,6 +1324,10 @@ module RegexpBaseInstr = struct let z' = zipper head tail path in next_zipper z' + | ZMatch (_, ((head, tail), path), _) -> + let z' = zipper head tail path in + next_zipper z' + let next (e : engine) = next_zipper e.e_zipper |> omap (fun z -> { e with e_zipper = z; e_pos = List.length z.z_head }) diff --git a/src/ecMatching.mli b/src/ecMatching.mli index c792f6ef7..9961f1c24 100644 --- a/src/ecMatching.mli +++ b/src/ecMatching.mli @@ -14,9 +14,10 @@ module Position : sig type cp_match = [ | `If | `While + | `Match | `Assign of lvmatch - | `Sample - | `Call + | `Sample of lvmatch + | `Call of lvmatch ] and lvmatch = [ `LvmNone | `LvmVar of EcTypes.prog_var ] @@ -26,9 +27,10 @@ module Position : sig | `ByMatch of int option * cp_match ] - type codepos1 = int * cp_base - type codepos = (codepos1 * int) list * codepos1 - type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1] + type codepos_brsel = [`Cond of bool | `Match of EcSymbols.symbol] + type codepos1 = int * cp_base + type codepos = (codepos1 * codepos_brsel) list * codepos1 + type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1] val shift1 : offset:int -> codepos1 -> codepos1 val shift : offset:int -> codepos -> codepos @@ -40,11 +42,18 @@ end module Zipper : sig open Position + type spath_match_ctxt = { + locals : (EcIdent.t * ty) list; + prebr : ((EcIdent.t * ty) list * stmt) list; + postbr : ((EcIdent.t * ty) list * stmt) list; + } + type ipath = | ZTop | ZWhile of expr * spath | ZIfThen of expr * spath * stmt | ZIfElse of expr * stmt * spath + | ZMatch of expr * spath * spath_match_ctxt and spath = (instr list * instr list) * ipath diff --git a/src/ecParser.mly b/src/ecParser.mly index 95fef6a9d..71dfbc9b2 100644 --- a/src/ecParser.mly +++ b/src/ecParser.mly @@ -2606,9 +2606,10 @@ tac_dir: icodepos_r: | IF { (`If :> pcp_match) } | WHILE { (`While :> pcp_match) } -| LESAMPLE { (`Sample :> pcp_match) } -| LEAT { (`Call :> pcp_match) } +| MATCH { (`Match :> pcp_match) } +| lvm=lvmatch LESAMPLE { (`Sample lvm :> pcp_match) } +| lvm=lvmatch LEAT { (`Call lvm :> pcp_match) } | lvm=lvmatch LARROW { (`Assign lvm :> pcp_match) } lvmatch: @@ -2631,9 +2632,14 @@ codepos1: | cp=codepos1_wo_off AMP PLUS i=word { ( i, cp) } | cp=codepos1_wo_off AMP MINUS i=word { (-i, cp) } +branch_select: +| SHARP s=boident DOT {`Match s} +| DOT { `Cond true } +| QUESTION { `Cond false } + %inline nm1_codepos: -| i=codepos1 k=ID(DOT { 0 } | QUESTION { 1 } ) - { (i, k) } +| i=codepos1 bs=branch_select + { (i, bs) } codepos: | nm=rlist0(nm1_codepos, empty) i=codepos1 diff --git a/src/ecParsetree.ml b/src/ecParsetree.ml index 5c12aabea..53eda26b8 100644 --- a/src/ecParsetree.ml +++ b/src/ecParsetree.ml @@ -490,17 +490,19 @@ type preduction = { type pcp_match = [ | `If | `While + | `Match | `Assign of plvmatch - | `Sample - | `Call + | `Sample of plvmatch + | `Call of plvmatch ] and plvmatch = [ `LvmNone | `LvmVar of pqsymbol ] type pcp_base = [ `ByPos of int | `ByMatch of int option * pcp_match ] +type pbranch_select = [`Cond of bool | `Match of psymbol] type pcodepos1 = int * pcp_base -type pcodepos = (pcodepos1 * int) list * pcodepos1 +type pcodepos = (pcodepos1 * pbranch_select) list * pcodepos1 type pdocodepos1 = pcodepos1 doption option type pcodeoffset1 = [ diff --git a/src/ecPrinting.ml b/src/ecPrinting.ml index 4f51dc7a1..9b5232c5b 100644 --- a/src/ecPrinting.ml +++ b/src/ecPrinting.ml @@ -2124,9 +2124,12 @@ let pp_codepos1 (ppe : PPEnv.t) (fmt : Format.formatter) ((off, cp) : CP.codepos let k = match k with | `If -> "if" + | `Match -> "match" | `While -> "while" - | `Sample -> "<$" - | `Call -> "<@" + | `Sample `LvmNone -> "<$" + | `Sample (`LvmVar pv) -> Format.asprintf "%a<$" (pp_pv ppe) pv + | `Call `LvmNone -> "<@" + | `Call (`LvmVar pv) -> Format.asprintf "%a<@" (pp_pv ppe) pv | `Assign `LvmNone -> "<-" | `Assign (`LvmVar pv) -> Format.asprintf "%a<-" (pp_pv ppe) pv in Format.asprintf "^%s" k in @@ -2146,14 +2149,20 @@ let pp_codeoffset1 (ppe : PPEnv.t) (fmt : Format.formatter) (offset : CP.codeoff match offset with | `ByPosition p -> Format.fprintf fmt "%a" (pp_codepos1 ppe) p | `ByOffset o -> Format.fprintf fmt "%d" o - + (* -------------------------------------------------------------------- *) let pp_codepos (ppe : PPEnv.t) (fmt : Format.formatter) ((nm, cp1) : CP.codepos) = - let pp_nm (fmt : Format.formatter) ((cp, i) : CP.codepos1 * int) = - Format.eprintf "%a%s" (pp_codepos1 ppe) cp (if i = 0 then "." else "?") + let pp_nm (fmt : Format.formatter) ((cp, bs) : CP.codepos1 * CP.codepos_brsel) = + let bs = + match bs with + | `Cond true -> "." + | `Cond false -> "?" + | `Match cp -> Format.sprintf "#%s." cp + in + Format.fprintf fmt "%a%s" (pp_codepos1 ppe) cp bs in - Format.eprintf "%a%a" (pp_list "" pp_nm) nm (pp_codepos1 ppe) cp1 + Format.fprintf fmt "%a%a" (pp_list "" pp_nm) nm (pp_codepos1 ppe) cp1 (* -------------------------------------------------------------------- *) let pp_opdecl_pr (ppe : PPEnv.t) fmt (basename, ts, ty, op) = diff --git a/src/ecTyping.ml b/src/ecTyping.ml index 8c8d0fc09..b22c2a63d 100644 --- a/src/ecTyping.ml +++ b/src/ecTyping.ml @@ -3455,8 +3455,12 @@ let trans_lv_match ?(memory : memory option) (env : EcEnv.env) (p : plvmatch) : (* -------------------------------------------------------------------- *) let trans_cp_match ?(memory : memory option) (env : EcEnv.env) (p : pcp_match) : cp_match = match p with - | (`Sample | `While | `Call | `If) as p -> + | (`While | `If | `Match) as p -> (p :> cp_match) + | `Sample lv -> + `Sample (trans_lv_match ?memory env lv) + | `Call lv -> + `Call (trans_lv_match ?memory env lv) | `Assign lv -> `Assign (trans_lv_match ?memory env lv) (* -------------------------------------------------------------------- *) @@ -3464,13 +3468,20 @@ let trans_cp_base ?(memory : memory option) (env : EcEnv.env) (p : pcp_base) : c match p with | `ByPos _ as p -> (p :> cp_base) | `ByMatch (i, p) -> `ByMatch (i, trans_cp_match ?memory env p) + (* -------------------------------------------------------------------- *) let trans_codepos1 ?(memory : memory option) (env : EcEnv.env) (p : pcodepos1) : codepos1 = snd_map (trans_cp_base ?memory env) p +(* -------------------------------------------------------------------- *) +let trans_codepos_brsel (bs : pbranch_select) : codepos_brsel = + match bs with + | `Cond b -> `Cond b + | `Match { pl_desc = x } -> `Match x + (* -------------------------------------------------------------------- *) let trans_codepos ?(memory : memory option) (env : EcEnv.env) ((nm, p) : pcodepos) : codepos = - let nm = List.map (fst_map (trans_codepos1 ?memory env)) nm in + let nm = List.map (fun (cp1, bs) -> (trans_codepos1 ?memory env cp1, trans_codepos_brsel bs)) nm in let p = trans_codepos1 ?memory env p in (nm, p) diff --git a/src/phl/ecPhlInline.ml b/src/phl/ecPhlInline.ml index 25a74be2f..4e7f6d027 100644 --- a/src/phl/ecPhlInline.ml +++ b/src/phl/ecPhlInline.ml @@ -309,6 +309,10 @@ module HiInternal = struct | Zp.ZWhile (_, sp) -> aux_s (IPwhile aout) sp | Zp.ZIfThen (_, sp, _) -> aux_s (IPif (aout, [])) sp | Zp.ZIfElse (_, _, sp) -> aux_s (IPif ([], aout)) sp + | Zp.ZMatch (_, sp, mpi) -> + let prebr = List.map (fun _ -> []) mpi.prebr in + let postbr = List.map (fun _ -> []) mpi.postbr in + aux_s (IPmatch (prebr @ aout :: postbr)) sp and aux_s aout ((sl, _), ip) = aux_i [(List.length sl, aout)] ip diff --git a/tests/match_codepos.ec b/tests/match_codepos.ec new file mode 100644 index 000000000..1d5a34c4f --- /dev/null +++ b/tests/match_codepos.ec @@ -0,0 +1,40 @@ +(* -------------------------------------------------------------------- *) +require import Distr. + +(* -------------------------------------------------------------------- *) +module M = { + proc f(x : bool option) = { + var y; + y <- false; + match x with + | None => {} + | Some v => { + if (v) { + y <$ dunit ((y || true) && true); + } + } + end; + return y; + } + proc g(x : bool option) = { + var z; + z <- false; + match x with + | None => {} + | Some v => { + if (v) { + z <$ dunit true; + } + } + end; + return z; + } +}. + +(* -------------------------------------------------------------------- *) +equiv l: M.f ~ M.g: ={arg} ==> ={res}. +proof. +proc. +proc rewrite {1} ^match#Some.^if.^y<$ /=. +by sim. +qed.