Skip to content

Commit

Permalink
Improve code position (match + extended assignments) (#654)
Browse files Browse the repository at this point in the history
- `^match` pattern now allowed within a code pos
 - `#cname` can be used to select the appropriate sub-branch of a match,
   e.g. `^match#Some.1`
 - `^lv<@` and `^lv<$` are now permitted
  • Loading branch information
Cameron-Low authored Oct 30, 2024
1 parent ed8f813 commit 9eaff01
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 33 deletions.
56 changes: 43 additions & 13 deletions src/ecMatching.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ module Position = struct
| `If
| `While
| `Assign of lvmatch
| `Sample
| `Call
| `Sample of lvmatch
| `Call of lvmatch
| `Match
]

and lvmatch = [ `LvmNone | `LvmVar of EcTypes.prog_var ]
Expand All @@ -30,9 +31,10 @@ module Position = struct
| `ByMatch of int option * cp_match
]

type codepos1 = int * cp_base
type codepos = (codepos1 * int) list * codepos1
type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1]
type codepos_brsel = [`Cond of bool | `Match of EcSymbols.symbol]
type codepos1 = int * cp_base
type codepos = (codepos1 * codepos_brsel) list * codepos1
type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1]

let shift1 ~(offset : int) ((o, p) : codepos1) : codepos1 =
(o + offset, p)
Expand All @@ -57,12 +59,19 @@ module Zipper = struct
type ('a, 'state) folder =
'a -> 'state -> instr -> 'state * instr list

type spath_match_ctxt = {
locals : (EcIdent.t * ty) list;
prebr : ((EcIdent.t * ty) list * stmt) list;
postbr : ((EcIdent.t * ty) list * stmt) list;
}

type ipath =
| ZTop
| ZWhile of expr * spath
| ZIfThen of expr * spath * stmt
| ZIfElse of expr * stmt * spath

| ZMatch of expr * spath * spath_match_ctxt

and spath = (instr list * instr list) * ipath

type zipper = {
Expand Down Expand Up @@ -95,9 +104,12 @@ module Zipper = struct
match ir.i_node, cm with
| Swhile _, `While -> i-1
| Sif _, `If -> i-1
| Srnd _, `Sample -> i-1
| Scall _, `Call -> i-1
| Smatch _, `Match -> i-1

| Scall (None, _, _), `Call `LvmNone -> i-1

| Scall (Some lv, _, _), `Call lvm
| Srnd (lv, _), `Sample lvm
| Sasgn (lv, _), `Assign lvm -> begin
match lv, lvm with
| _, `LvmNone -> i-1
Expand Down Expand Up @@ -178,23 +190,34 @@ module Zipper = struct

let zipper_at_nm_cpos1
(env : EcEnv.env)
((cp1, sub) : codepos1 * int)
((cp1, sub) : codepos1 * codepos_brsel)
(s : stmt)
(zpr : ipath)
: (ipath * stmt) * (codepos1 * int)
: (ipath * stmt) * (codepos1 * codepos_brsel)
=
let (s1, i, s2) = find_by_cpos1 env cp1 s in
let zpr =
match i.i_node, sub with
| Swhile (e, sw), 0 ->
| Swhile (e, sw), `Cond true ->
(ZWhile (e, ((s1, s2), zpr)), sw)

| Sif (e, ifs1, ifs2), 0 ->
| Sif (e, ifs1, ifs2), `Cond true ->
(ZIfThen (e, ((s1, s2), zpr), ifs2), ifs1)

| Sif (e, ifs1, ifs2), 1 ->
| Sif (e, ifs1, ifs2), `Cond false ->
(ZIfElse (e, ifs1, ((s1, s2), zpr)), ifs2)

| Smatch (e, bs), `Match cn ->
let _, indt, _ = oget (EcEnv.Ty.get_top_decl e.e_ty env) in
let indt = oget (EcDecl.tydecl_as_datatype indt) in
let cnames = List.fst indt.tydt_ctors in
let ix, _ =
try List.findi (fun _ n -> EcSymbols.sym_equal cn n) cnames
with Not_found -> raise InvalidCPos
in
let prebr, (locals, body), postbr = List.pivot_at ix bs in
(ZMatch (e, ((s1, s2), zpr), { locals; prebr; postbr; }), body)

| _ -> raise InvalidCPos
in zpr, ((0, `ByPos (1 + List.length s1)), sub)

Expand Down Expand Up @@ -228,6 +251,8 @@ module Zipper = struct
| ZWhile (e, sp) -> zip (Some (i_while (e, s))) sp
| ZIfThen (e, sp, se) -> zip (Some (i_if (e, s, se))) sp
| ZIfElse (e, se, sp) -> zip (Some (i_if (e, se, s))) sp
| ZMatch (e, sp, mpi) ->
zip (Some (i_match (e, mpi.prebr @ (mpi.locals, s) :: mpi.postbr))) sp

let zip zpr = zip None ((zpr.z_head, zpr.z_tail), zpr.z_path)

Expand All @@ -238,6 +263,7 @@ module Zipper = struct
| ZWhile (_, ((_, is), ip)) -> doit (is :: acc) ip
| ZIfThen (_, ((_, is), ip), _) -> doit (is :: acc) ip
| ZIfElse (_, _, ((_, is), ip)) -> doit (is :: acc) ip
| ZMatch (_, ((_, is), ip), _) -> doit (is :: acc) ip
in

let after =
Expand Down Expand Up @@ -1298,6 +1324,10 @@ module RegexpBaseInstr = struct
let z' = zipper head tail path in
next_zipper z'

| ZMatch (_, ((head, tail), path), _) ->
let z' = zipper head tail path in
next_zipper z'

let next (e : engine) =
next_zipper e.e_zipper |> omap (fun z ->
{ e with e_zipper = z; e_pos = List.length z.z_head })
Expand Down
19 changes: 14 additions & 5 deletions src/ecMatching.mli
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ module Position : sig
type cp_match = [
| `If
| `While
| `Match
| `Assign of lvmatch
| `Sample
| `Call
| `Sample of lvmatch
| `Call of lvmatch
]

and lvmatch = [ `LvmNone | `LvmVar of EcTypes.prog_var ]
Expand All @@ -26,9 +27,10 @@ module Position : sig
| `ByMatch of int option * cp_match
]

type codepos1 = int * cp_base
type codepos = (codepos1 * int) list * codepos1
type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1]
type codepos_brsel = [`Cond of bool | `Match of EcSymbols.symbol]
type codepos1 = int * cp_base
type codepos = (codepos1 * codepos_brsel) list * codepos1
type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1]

val shift1 : offset:int -> codepos1 -> codepos1
val shift : offset:int -> codepos -> codepos
Expand All @@ -40,11 +42,18 @@ end
module Zipper : sig
open Position

type spath_match_ctxt = {
locals : (EcIdent.t * ty) list;
prebr : ((EcIdent.t * ty) list * stmt) list;
postbr : ((EcIdent.t * ty) list * stmt) list;
}

type ipath =
| ZTop
| ZWhile of expr * spath
| ZIfThen of expr * spath * stmt
| ZIfElse of expr * stmt * spath
| ZMatch of expr * spath * spath_match_ctxt

and spath = (instr list * instr list) * ipath

Expand Down
14 changes: 10 additions & 4 deletions src/ecParser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -2606,9 +2606,10 @@ tac_dir:
icodepos_r:
| IF { (`If :> pcp_match) }
| WHILE { (`While :> pcp_match) }
| LESAMPLE { (`Sample :> pcp_match) }
| LEAT { (`Call :> pcp_match) }
| MATCH { (`Match :> pcp_match) }

| lvm=lvmatch LESAMPLE { (`Sample lvm :> pcp_match) }
| lvm=lvmatch LEAT { (`Call lvm :> pcp_match) }
| lvm=lvmatch LARROW { (`Assign lvm :> pcp_match) }

lvmatch:
Expand All @@ -2631,9 +2632,14 @@ codepos1:
| cp=codepos1_wo_off AMP PLUS i=word { ( i, cp) }
| cp=codepos1_wo_off AMP MINUS i=word { (-i, cp) }

branch_select:
| SHARP s=boident DOT {`Match s}
| DOT { `Cond true }
| QUESTION { `Cond false }

%inline nm1_codepos:
| i=codepos1 k=ID(DOT { 0 } | QUESTION { 1 } )
{ (i, k) }
| i=codepos1 bs=branch_select
{ (i, bs) }

codepos:
| nm=rlist0(nm1_codepos, empty) i=codepos1
Expand Down
8 changes: 5 additions & 3 deletions src/ecParsetree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -490,17 +490,19 @@ type preduction = {
type pcp_match = [
| `If
| `While
| `Match
| `Assign of plvmatch
| `Sample
| `Call
| `Sample of plvmatch
| `Call of plvmatch
]

and plvmatch = [ `LvmNone | `LvmVar of pqsymbol ]

type pcp_base = [ `ByPos of int | `ByMatch of int option * pcp_match ]

type pbranch_select = [`Cond of bool | `Match of psymbol]
type pcodepos1 = int * pcp_base
type pcodepos = (pcodepos1 * int) list * pcodepos1
type pcodepos = (pcodepos1 * pbranch_select) list * pcodepos1
type pdocodepos1 = pcodepos1 doption option

type pcodeoffset1 = [
Expand Down
21 changes: 15 additions & 6 deletions src/ecPrinting.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2124,9 +2124,12 @@ let pp_codepos1 (ppe : PPEnv.t) (fmt : Format.formatter) ((off, cp) : CP.codepos
let k =
match k with
| `If -> "if"
| `Match -> "match"
| `While -> "while"
| `Sample -> "<$"
| `Call -> "<@"
| `Sample `LvmNone -> "<$"
| `Sample (`LvmVar pv) -> Format.asprintf "%a<$" (pp_pv ppe) pv
| `Call `LvmNone -> "<@"
| `Call (`LvmVar pv) -> Format.asprintf "%a<@" (pp_pv ppe) pv
| `Assign `LvmNone -> "<-"
| `Assign (`LvmVar pv) -> Format.asprintf "%a<-" (pp_pv ppe) pv
in Format.asprintf "^%s" k in
Expand All @@ -2146,14 +2149,20 @@ let pp_codeoffset1 (ppe : PPEnv.t) (fmt : Format.formatter) (offset : CP.codeoff
match offset with
| `ByPosition p -> Format.fprintf fmt "%a" (pp_codepos1 ppe) p
| `ByOffset o -> Format.fprintf fmt "%d" o

(* -------------------------------------------------------------------- *)
let pp_codepos (ppe : PPEnv.t) (fmt : Format.formatter) ((nm, cp1) : CP.codepos) =
let pp_nm (fmt : Format.formatter) ((cp, i) : CP.codepos1 * int) =
Format.eprintf "%a%s" (pp_codepos1 ppe) cp (if i = 0 then "." else "?")
let pp_nm (fmt : Format.formatter) ((cp, bs) : CP.codepos1 * CP.codepos_brsel) =
let bs =
match bs with
| `Cond true -> "."
| `Cond false -> "?"
| `Match cp -> Format.sprintf "#%s." cp
in
Format.fprintf fmt "%a%s" (pp_codepos1 ppe) cp bs
in

Format.eprintf "%a%a" (pp_list "" pp_nm) nm (pp_codepos1 ppe) cp1
Format.fprintf fmt "%a%a" (pp_list "" pp_nm) nm (pp_codepos1 ppe) cp1

(* -------------------------------------------------------------------- *)
let pp_opdecl_pr (ppe : PPEnv.t) fmt (basename, ts, ty, op) =
Expand Down
15 changes: 13 additions & 2 deletions src/ecTyping.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3455,22 +3455,33 @@ let trans_lv_match ?(memory : memory option) (env : EcEnv.env) (p : plvmatch) :
(* -------------------------------------------------------------------- *)
let trans_cp_match ?(memory : memory option) (env : EcEnv.env) (p : pcp_match) : cp_match =
match p with
| (`Sample | `While | `Call | `If) as p ->
| (`While | `If | `Match) as p ->
(p :> cp_match)
| `Sample lv ->
`Sample (trans_lv_match ?memory env lv)
| `Call lv ->
`Call (trans_lv_match ?memory env lv)
| `Assign lv ->
`Assign (trans_lv_match ?memory env lv)
(* -------------------------------------------------------------------- *)
let trans_cp_base ?(memory : memory option) (env : EcEnv.env) (p : pcp_base) : cp_base =
match p with
| `ByPos _ as p -> (p :> cp_base)
| `ByMatch (i, p) -> `ByMatch (i, trans_cp_match ?memory env p)

(* -------------------------------------------------------------------- *)
let trans_codepos1 ?(memory : memory option) (env : EcEnv.env) (p : pcodepos1) : codepos1 =
snd_map (trans_cp_base ?memory env) p

(* -------------------------------------------------------------------- *)
let trans_codepos_brsel (bs : pbranch_select) : codepos_brsel =
match bs with
| `Cond b -> `Cond b
| `Match { pl_desc = x } -> `Match x

(* -------------------------------------------------------------------- *)
let trans_codepos ?(memory : memory option) (env : EcEnv.env) ((nm, p) : pcodepos) : codepos =
let nm = List.map (fst_map (trans_codepos1 ?memory env)) nm in
let nm = List.map (fun (cp1, bs) -> (trans_codepos1 ?memory env cp1, trans_codepos_brsel bs)) nm in
let p = trans_codepos1 ?memory env p in
(nm, p)

Expand Down
4 changes: 4 additions & 0 deletions src/phl/ecPhlInline.ml
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ module HiInternal = struct
| Zp.ZWhile (_, sp) -> aux_s (IPwhile aout) sp
| Zp.ZIfThen (_, sp, _) -> aux_s (IPif (aout, [])) sp
| Zp.ZIfElse (_, _, sp) -> aux_s (IPif ([], aout)) sp
| Zp.ZMatch (_, sp, mpi) ->
let prebr = List.map (fun _ -> []) mpi.prebr in
let postbr = List.map (fun _ -> []) mpi.postbr in
aux_s (IPmatch (prebr @ aout :: postbr)) sp

and aux_s aout ((sl, _), ip) =
aux_i [(List.length sl, aout)] ip
Expand Down
40 changes: 40 additions & 0 deletions tests/match_codepos.ec
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
(* -------------------------------------------------------------------- *)
require import Distr.

(* -------------------------------------------------------------------- *)
module M = {
proc f(x : bool option) = {
var y;
y <- false;
match x with
| None => {}
| Some v => {
if (v) {
y <$ dunit ((y || true) && true);
}
}
end;
return y;
}
proc g(x : bool option) = {
var z;
z <- false;
match x with
| None => {}
| Some v => {
if (v) {
z <$ dunit true;
}
}
end;
return z;
}
}.

(* -------------------------------------------------------------------- *)
equiv l: M.f ~ M.g: ={arg} ==> ={res}.
proof.
proc.
proc rewrite {1} ^match#Some.^if.^y<$ /=.
by sim.
qed.

0 comments on commit 9eaff01

Please sign in to comment.