From 61ce35e15f0f9c31d926c121e6fe453a028d0f4e Mon Sep 17 00:00:00 2001 From: Pierre-Yves Strub Date: Thu, 5 Dec 2024 18:33:35 +0100 Subject: [PATCH] better overloading inference --- src/ecHiInductive.ml | 2 +- src/ecPrinting.ml | 6 +-- src/ecScope.ml | 2 +- src/ecTyping.ml | 119 +++++++++++++++++++++++++++++++------------ src/ecUnify.ml | 10 ++-- src/ecUnify.mli | 4 +- tests/overloading.ec | 24 +++++++++ 7 files changed, 122 insertions(+), 45 deletions(-) create mode 100644 tests/overloading.ec diff --git a/src/ecHiInductive.ml b/src/ecHiInductive.ml index bef40e9497..16243fee33 100644 --- a/src/ecHiInductive.ml +++ b/src/ecHiInductive.ml @@ -284,7 +284,7 @@ let trans_matchfix let filter = fun _ op -> EcDecl.is_ctor op in let PPApp ((cname, tvi), cargs) = pb.pop_pattern in let tvi = tvi |> omap (TT.transtvi env ue) in - let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in + let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in match cts with | [] -> diff --git a/src/ecPrinting.ml b/src/ecPrinting.ml index 9b5232c5b7..866f9dfec9 100644 --- a/src/ecPrinting.ml +++ b/src/ecPrinting.ml @@ -941,7 +941,7 @@ let pp_opapp (es : 'a list)) = let (nm, opname) = - PPEnv.op_symb ppe op (Some (pred, tvi, List.map t_ty es)) in + PPEnv.op_symb ppe op (Some (pred, tvi, (List.map t_ty es, None))) in let inm = if nm = [] then fst outer else nm in @@ -1250,7 +1250,7 @@ let pp_chained_orderings (ppe : PPEnv.t) t_ty pp_sub outer fmt (f, fs) = ignore (List.fold_left (fun fe (op, tvi, f) -> let (nm, opname) = - PPEnv.op_symb ppe op (Some (`Form, tvi, [t_ty fe; t_ty f])) + PPEnv.op_symb ppe op (Some (`Form, tvi, ([t_ty fe; t_ty f], None))) in Format.fprintf fmt " %t@ %a" (fun fmt -> @@ -1343,7 +1343,7 @@ let lower_left (ppe : PPEnv.t) (t_ty : form -> EcTypes.ty) (f : form) else l_l f2 onm e_bin_prio_rop4 | Fapp ({f_node = Fop (op, tys)}, [f1; f2]) -> (let (inm, opname) = - PPEnv.op_symb ppe op (Some (`Form, tys, List.map t_ty [f1; f2])) in + PPEnv.op_symb ppe op (Some (`Form, tys, (List.map t_ty [f1; f2], None))) in if inm <> [] && inm <> onm then None else match priority_of_binop opname with diff --git a/src/ecScope.ml b/src/ecScope.ml index b17295a4d5..8f5e4adfa9 100644 --- a/src/ecScope.ml +++ b/src/ecScope.ml @@ -1689,7 +1689,7 @@ module Ty = struct let tvi = List.map (TT.transty tp_tydecl env ue) tvi in let selected = EcUnify.select_op ~filter:(fun _ -> EcDecl.is_oper) - (Some (EcUnify.TVIunamed tvi)) env (unloc op) ue [] + (Some (EcUnify.TVIunamed tvi)) env (unloc op) ue ([], None) in let op = match selected with diff --git a/src/ecTyping.ml b/src/ecTyping.ml index b22c2a63d9..2f848b40d3 100644 --- a/src/ecTyping.ml +++ b/src/ecTyping.ml @@ -299,7 +299,7 @@ let select_local env (qs,s) = else None (* -------------------------------------------------------------------- *) -let select_pv env side name ue tvi psig = +let select_pv env side name ue tvi (psig, retty) = if tvi <> None then [] else @@ -307,7 +307,7 @@ let select_pv env side name ue tvi psig = let pvs = EcEnv.Var.lookup_progvar ?side name env in let select (pv,ty) = let subue = UE.copy ue in - let texpected = EcUnify.tfun_expected subue psig in + let texpected = EcUnify.tfun_expected subue ?retty psig in try EcUnify.unify env subue ty texpected; [(pv, ty, subue)] @@ -345,7 +345,7 @@ let gen_select_op (env : EcEnv.env) (name : EcSymbols.qsymbol) (ue : EcUnify.unienv) - (psig : EcTypes.dom) + (psig : EcTypes.dom * EcTypes.ty option) : OpSelect.gopsel list = @@ -431,7 +431,7 @@ let select_form_op env ~forcepv opsc name ue tvi psig = (* -------------------------------------------------------------------- *) let select_proj env opsc name ue tvi recty = let filter = (fun _ op -> EcDecl.is_proj op) in - let ops = EcUnify.select_op ~filter tvi env name ue [recty] in + let ops = EcUnify.select_op ~filter tvi env name ue ([recty], None) in let ops = List.map (fun (p, ty, ue, _) -> (p, ty, ue)) ops in match ops, opsc with @@ -1059,7 +1059,7 @@ let transpattern1 env ue (p : EcParsetree.plpattern) = let fields = let for1 (name, v) = let filter = fun _ op -> EcDecl.is_proj op in - let fds = EcUnify.select_op ~filter None env (unloc name) ue [] in + let fds = EcUnify.select_op ~filter None env (unloc name) ue ([], None) in match List.ohead fds with | None -> let exn = UnknownRecFieldName (unloc name) in @@ -1199,7 +1199,7 @@ let trans_record env ue (subtt, proj) (loc, b, fields) = let for1 rf = let filter = fun _ op -> EcDecl.is_proj op in let tvi = rf.rf_tvi |> omap (transtvi env ue) in - let fds = EcUnify.select_op ~filter tvi env (unloc rf.rf_name) ue [] in + let fds = EcUnify.select_op ~filter tvi env (unloc rf.rf_name) ue ([], None) in match List.ohead fds with | None -> let exn = UnknownRecFieldName (unloc rf.rf_name) in @@ -1288,7 +1288,7 @@ let trans_branch ~loc env ue gindty ((pb, body) : ppattern * _) = let filter = fun _ op -> EcDecl.is_ctor op in let PPApp ((cname, tvi), cargs) = pb in let tvi = tvi |> omap (transtvi env ue) in - let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in + let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in match cts with | [] -> @@ -1440,7 +1440,7 @@ let transexp (env : EcEnv.env) mode ue e = | PEident ({ pl_desc = name }, tvi) -> let tvi = tvi |> omap (transtvi env ue) in - let ops = select_exp_op env mode osc name ue tvi [] in + let ops = select_exp_op env mode osc name ue tvi ([], None) in begin match ops with | [] -> tyerror loc env (UnknownVarOrOp (name, [])) @@ -1460,7 +1460,7 @@ let transexp (env : EcEnv.env) mode ue e = let tvi = tvi |> omap (transtvi env ue) in let es = List.map (transexp env) pes in let esig = snd (List.split es) in - let ops = select_exp_op env mode osc name ue tvi esig in + let ops = select_exp_op env mode osc name ue tvi (esig, None) in begin match ops with | [] -> let uidmap = EcUnify.UniEnv.assubst ue in @@ -2745,7 +2745,7 @@ and translvalue ue (env : EcEnv.env) lvalue = let e, ety = e_tuple e, ttuple ety in let name = ([], EcCoreLib.s_set) in let esig = [xty; ety; codomty] in - let ops = select_exp_op env `InProc None name ue tvi esig in + let ops = select_exp_op env `InProc None name ue tvi (esig, None) in match ops with | [] -> @@ -2814,8 +2814,9 @@ and trans_gbinding env ue decl = and trans_form_or_pattern env ?mv ?ps ue pf tt = let state = PFS.create () in - let rec transf_r opsc env f = - let transf = transf_r opsc in + let rec transf_r_tyinfo opsc env ?tt f = + let transf env ?tt f = + transf_r opsc env ?tt f in match f.pl_desc with | PFhole -> begin @@ -3044,20 +3045,18 @@ and trans_form_or_pattern env ?mv ?ps ue pf tt = | PFdecimal (n, f) -> f_decimal (n, f) - | PFtuple args -> begin - let args = List.map (transf env) args in - match args with - | [] -> f_tt - | [f] -> f - | fs -> f_tuple fs - end + | PFtuple pes -> + let esig = List.map (fun _ -> EcUnify.UniEnv.fresh ue) pes in + tt |> oiter (fun tt -> unify_or_fail env ue f.pl_loc ~expct:tt (ttuple esig)); + let es = List.map2 (fun tt pe -> transf env ~tt pe) esig pes in + f_tuple es | PFident ({ pl_desc = name; pl_loc = loc }, tvi) -> let tvi = tvi |> omap (transtvi env ue) in let ops = select_form_op ~forcepv:(PFS.isforced state) - env opsc name ue tvi [] in + env opsc name ue tvi ([], tt) in begin match ops with | [] -> tyerror loc env (UnknownVarOrOp (name, [])) @@ -3183,13 +3182,43 @@ and trans_form_or_pattern env ?mv ?ps ue pf tt = check_mem f.pl_loc EcFol.mright; EcFol.f_ands (List.map (do1 (EcFol.mleft, EcFol.mright)) fs) - | PFapp ({pl_desc = PFident ({ pl_desc = name; pl_loc = loc }, tvi)}, pes) -> + | PFapp ({pl_desc = PFident ({ pl_desc = name; pl_loc = loc }, tvi)}, pes) -> begin + let try_trans ?tt pe = + let ue' = EcUnify.UniEnv.copy ue in + let ps' = Option.map (fun ps -> ref !ps) ps in + match transf env ?tt pe with + | e -> Some e + | exception TyError (_, _, MultipleOpMatch _) -> + Option.iter (fun ps -> ps := !(Option.get ps')) ps; + EcUnify.UniEnv.restore ~dst:ue ~src:ue'; + None + in + + match + let ue' = EcUnify.UniEnv.copy ue in + let ps' = Option.map (fun ps -> ref !ps) ps in + let es = List.map (fun pe -> try_trans pe) pes in + let tvi = tvi |> omap (transtvi env ue) in + let esig = List.map (fun e -> + match e with Some e -> e.f_ty | None -> EcUnify.UniEnv.fresh ue + ) es in + match + select_form_op ~forcepv:(PFS.isforced state) + env opsc name ue tvi (esig, tt) + with + | [sel] -> Some (sel, (es, esig, tvi)) + | _ -> + Option.iter (fun ps -> ps := !(Option.get ps')) ps; + EcUnify.UniEnv.restore ~dst:ue ~src:ue'; + None + with + | None -> begin let tvi = tvi |> omap (transtvi env ue) in let es = List.map (transf env) pes in let esig = List.map EcFol.f_ty es in let ops = select_form_op ~forcepv:(PFS.isforced state) - env opsc name ue tvi esig in + env opsc name ue tvi (esig, tt) in begin match ops with | [] -> @@ -3207,6 +3236,24 @@ and trans_form_or_pattern env ?mv ?ps ue pf tt = let matches = List.map (fun (_, _, subue, m) -> (m, subue)) ops in tyerror loc env (MultipleOpMatch (name, esig, matches)) end + end + + | Some ((_, _, subue, _) as sel, (es, esig, _tvi)) -> + EcUnify.UniEnv.restore ~dst:ue ~src:subue; + let es = + List.map2 ( + fun (e, ty) pe -> + match e with None -> try_trans ~tt:ty pe | Some e -> Some e + ) (List.combine es esig) pes in + let es = + List.map2 ( + fun (e, ty) pe -> + match e with None -> transf env ~tt:ty pe | Some e -> e + ) (List.combine es esig) pes in + let es = List.map2 (fun e l -> mk_loc l.pl_loc e) es pes in + EcUnify.UniEnv.restore ~src:ue ~dst:subue; + form_of_opselect (env, ue) loc sel es + end | PFapp (e, pes) -> let es = List.map (transf env) pes in @@ -3262,25 +3309,30 @@ and trans_form_or_pattern env ?mv ?ps ue pf tt = let f1 = transf env pf1 in unify_or_fail env ue pf1.pl_loc ~expct:pty f1.f_ty; aty |> oiter (fun aty-> unify_or_fail env ue pf1.pl_loc ~expct:pty aty); - let f2 = transf penv f2 in + let f2 = transf penv ?tt f2 in f_let p f1 f2 | PFforall (xs, pf) -> let env, xs = trans_gbinding env ue xs in let f = transf env pf in - unify_or_fail env ue pf.pl_loc ~expct:tbool f.f_ty; - f_forall xs f + unify_or_fail env ue pf.pl_loc ~expct:tbool f.f_ty; + f_forall xs f | PFexists (xs, f1) -> let env, xs = trans_gbinding env ue xs in let f = transf env f1 in - unify_or_fail env ue f1.pl_loc ~expct:tbool f.f_ty; - f_exists xs f + unify_or_fail env ue f1.pl_loc ~expct:tbool f.f_ty; + f_exists xs f | PFlambda (xs, f1) -> let env, xs = trans_binding env ue xs in - let f = transf env f1 in - f_lambda (List.map (fun (x,ty) -> (x,GTty ty)) xs) f + let subtt = tt |> Option.map (fun tt -> + let codom = EcUnify.UniEnv.fresh ue in + unify_or_fail env ue (loc f) ~expct:(toarrow (List.snd xs) codom) tt; + codom + ) in + let f = transf env ?tt:subtt f1 in + f_lambda (List.map (fun (x, ty) -> (x, GTty ty)) xs) f | PFrecord (b, fields) -> let (ctor, fields, (rtvi, reccty)) = @@ -3390,11 +3442,12 @@ and trans_form_or_pattern env ?mv ?ps ue pf tt = unify_or_fail qenv ue post.pl_loc ~expct:tbool post'.f_ty; f_eagerF pre' s1 fpath1 fpath2 s2 post' - in + and transf_r opsc env ?tt pf = + let f = transf_r_tyinfo opsc env ?tt pf in + let () = oiter (fun tt -> unify_or_fail env ue pf.pl_loc ~expct:tt f.f_ty) tt in + f - let f = transf_r None env pf in - tt |> oiter (fun tt -> unify_or_fail env ue pf.pl_loc ~expct:tt f.f_ty); - f + in transf_r None env ?tt pf (* Type-check a memtype. *) and trans_memtype env ue (pmemtype : pmemtype) : memtype = diff --git a/src/ecUnify.ml b/src/ecUnify.ml index cd557aadef..e5bb56299d 100644 --- a/src/ecUnify.ml +++ b/src/ecUnify.ml @@ -396,15 +396,15 @@ let hastc env ue ty tc = ue := { !ue with ue_uf = uf; } (* -------------------------------------------------------------------- *) -let tfun_expected ue psig = - let tres = UniEnv.fresh ue in - EcTypes.toarrow psig tres +let tfun_expected ue ?retty psig = + let retty = ofdfl (fun () -> UniEnv.fresh ue) retty in + EcTypes.toarrow psig retty (* -------------------------------------------------------------------- *) type sbody = ((EcIdent.t * ty) list * expr) Lazy.t (* -------------------------------------------------------------------- *) -let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig = +let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue (psig, retty) = ignore hidden; (* FIXME *) let module D = EcDecl in @@ -457,7 +457,7 @@ let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig let (tip, tvs) = UniEnv.openty_r subue op.D.op_tparams tvi in let top = ty_subst tip op.D.op_ty in - let texpected = tfun_expected subue psig in + let texpected = tfun_expected subue ?retty psig in (try unify env subue top texpected with UnificationFailure _ -> raise E.Failure); diff --git a/src/ecUnify.mli b/src/ecUnify.mli index 2c7fbdb1a1..90488fabc4 100644 --- a/src/ecUnify.mli +++ b/src/ecUnify.mli @@ -37,7 +37,7 @@ end val unify : EcEnv.env -> unienv -> ty -> ty -> unit val hastc : EcEnv.env -> unienv -> ty -> Sp.t -> unit -val tfun_expected : unienv -> EcTypes.ty list -> EcTypes.ty +val tfun_expected : unienv -> ?retty:ty -> EcTypes.ty list -> EcTypes.ty type sbody = ((EcIdent.t * ty) list * expr) Lazy.t @@ -48,5 +48,5 @@ val select_op : -> EcEnv.env -> qsymbol -> unienv - -> dom + -> dom * ty option -> ((EcPath.path * ty list) * ty * unienv * sbody option) list diff --git a/tests/overloading.ec b/tests/overloading.ec new file mode 100644 index 0000000000..2a026f7773 --- /dev/null +++ b/tests/overloading.ec @@ -0,0 +1,24 @@ +require import AllCore List. + +theory T. + op o : int. + op a : int -> int -> int. +end T. + +theory U. + op o : bool. + op a : bool -> bool -> bool. +end U. + +import T U. + +op foo : int -> unit. + +op bar = foo o. + +op plop1 = foldr a false []. + +op plop2 = foldr (fun x => a x) false []. + +op plop3 = foldr (fun x y => a x y) false []. +