Skip to content

Commit

Permalink
Merge pull request #6 from codex-semantics-library/feature/fold_on_di…
Browse files Browse the repository at this point in the history
…fferences

New functions: fold_on_nonequal_inter and fold_on_nonequal_union
  • Loading branch information
mlemerre authored May 10, 2024
2 parents 115a5c9 + cb16f86 commit bef2736
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 1 deletion.
159 changes: 159 additions & 0 deletions patriciaTree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ module type BASE_MAP = sig
(* Callback record for [fold]: [f.f key value acc] returns the updated
   accumulator. The field is polymorphic in the key type parameter ['a],
   which is why the function is wrapped in a record. *)
type ('acc,'map) polyfold = { f: 'a. 'a key -> ('a,'map) value -> 'acc -> 'acc } [@@unboxed]
(* [fold f m acc] threads [acc] through [f.f] over the bindings of [m]. *)
val fold : ('acc,'map) polyfold -> 'map t -> 'acc -> 'acc

(* Callback record for [fold_on_nonequal_inter]: [f.f key va vb acc] receives
   the values bound to the same [key] in the first and in the second map. *)
type ('acc,'map) polyfold2 = { f: 'a. 'a key -> ('a,'map) value -> ('a,'map) value -> 'acc -> 'acc } [@@unboxed]
(* [fold_on_nonequal_inter f m1 m2 acc] folds [f.f] over the keys bound in
   both maps whose two values are physically different ([!=]). *)
val fold_on_nonequal_inter : ('acc,'map) polyfold2 -> 'map t -> 'map t -> 'acc -> 'acc

(* Callback record for [fold_on_nonequal_union]: [f.f key va vb acc] receives
   [Some v] for a map in which [key] is bound to [v], and [None] for a map in
   which [key] is unbound. *)
type ('acc,'map) polyfold2_union = { f: 'a. 'a key -> ('a,'map) value option -> ('a,'map) value option -> 'acc -> 'acc } [@@unboxed]
(* [fold_on_nonequal_union f m1 m2 acc] folds [f.f] over the keys bound in at
   least one of the maps, skipping keys bound in both maps to physically
   equal values. *)
val fold_on_nonequal_union : ('acc,'map) polyfold2_union -> 'map t -> 'map t -> 'acc -> 'acc

type 'map polypredicate = { f: 'a. 'a key -> ('a,'map) value -> bool; } [@@unboxed]
val filter : 'map polypredicate -> 'map t -> 'map t
val for_all : 'map polypredicate -> 'map t -> bool
Expand Down Expand Up @@ -299,6 +305,11 @@ module type MAP = sig
val split : key -> 'a t -> 'a t * 'a option * 'a t
val iter : (key -> 'a -> unit) -> 'a t -> unit
(* Fold on each (key,value) pair of the map, in increasing unsigned order of
   keys. *)
val fold : (key -> 'a -> 'acc -> 'acc) -> 'a t -> 'acc -> 'acc
(* [fold_on_nonequal_inter f m1 m2 acc] folds [f key v1 v2] over the keys
   bound in both maps whose values [v1] and [v2] are physically different. *)
val fold_on_nonequal_inter : (key -> 'a -> 'a -> 'acc -> 'acc) ->
'a t -> 'a t -> 'acc -> 'acc
(* [fold_on_nonequal_union f m1 m2 acc] folds [f key v1 v2] over the keys
   bound in at least one map; [v1] (resp. [v2]) is [None] when the key is
   unbound in [m1] (resp. [m2]). Keys bound in both maps to physically equal
   values are skipped. *)
val fold_on_nonequal_union :
(key -> 'a option -> 'a option -> 'acc -> 'acc) ->
'a t -> 'a t -> 'acc -> 'acc
(* Returns the submap containing only the key->value pairs satisfying the
   given predicate. *)
val filter : (key -> 'a -> bool) -> 'a t -> 'a t
val for_all : (key -> 'a -> bool) -> 'a t -> bool
val map : ('a -> 'a) -> 'a t -> 'a t
Expand Down Expand Up @@ -1172,6 +1183,144 @@ module MakeCustomHeterogeneous
let acc = fold f tree0 acc in
fold f tree1 acc


(* Callback record for [fold_on_nonequal_inter]: [f.f key va vb acc] receives
   the values bound to [key] in the first and in the second map and returns
   the updated accumulator. The field is polymorphic in the key type
   parameter ['a], hence the record wrapper. *)
type ('acc,'map) polyfold2 = { f: 'a. 'a key -> ('a,'map) value -> ('a,'map) value -> 'acc -> 'acc } [@@unboxed]

(* [fold_on_nonequal_inter f ta tb acc] folds [f.f] over every key bound in
   both [ta] and [tb] whose two values are physically different.

   Physically equal subtrees are skipped without being traversed.

   Exceptions raised by [f.f] propagate to the caller. In particular a
   [Not_found] raised by the callback is no longer mistaken for a key missing
   from the other map: the handler below only covers the call to [find]. *)
let rec fold_on_nonequal_inter f ta tb acc =
  if ta == tb then acc
  else match NODE.view ta, NODE.view tb with
    | Empty, _ | _, Empty -> acc
    | Leaf{key;value}, _ ->
      (* [ta] holds a single binding: look its key up in [tb]. *)
      (match find key tb with
       | valueb ->
         if valueb == value then acc
         else f.f key value valueb acc
       | exception Not_found -> acc)
    | _, Leaf{key;value} ->
      (* Symmetric case: [tb] holds a single binding. *)
      (match find key ta with
       | valuea ->
         if valuea == value then acc
         else f.f key valuea value acc
       | exception Not_found -> acc)
    | Branch{prefix=pa;branching_bit=ma;tree0=ta0;tree1=ta1},
      Branch{prefix=pb;branching_bit=mb;tree0=tb0;tree1=tb1} ->
      if ma == mb && pa == pb
      then
        (* Same prefix: fold on each pair of subtrees, [tree0] first. *)
        let acc = fold_on_nonequal_inter f ta0 tb0 acc in
        fold_on_nonequal_inter f ta1 tb1 acc
      else if unsigned_lt mb ma && match_prefix pb pa ma
      then
        (* [tb] can only share keys with one child of [ta]; the branching
           bit of [ta] tested on the prefix of [tb] selects that child. *)
        if ma land pb == 0
        then fold_on_nonequal_inter f ta0 tb acc
        else fold_on_nonequal_inter f ta1 tb acc
      else if unsigned_lt ma mb && match_prefix pa pb mb
      then
        (* Symmetric case: [ta] can only share keys with one child of [tb]. *)
        if mb land pa == 0
        then fold_on_nonequal_inter f ta tb0 acc
        else fold_on_nonequal_inter f ta tb1 acc
      else
        (* Neither tree can contain the other: no common key. *)
        acc


(* Callback record for [fold_on_nonequal_union]. [f.f key va vb acc] receives,
   for each of the two maps, [Some v] when [key] is bound to [v] in that map
   and [None] when it is unbound there. The field is polymorphic in the key
   type parameter ['a], hence the record wrapper. *)
type ('acc,'map) polyfold2_union =
{ f: 'a. 'a key -> ('a,'map) value option -> ('a,'map) value option ->
'acc -> 'acc } [@@unboxed]
(* [fold_on_nonequal_union f ta tb acc] folds [f.f] over every key bound in
   [ta] or in [tb], except the keys bound in both maps to physically equal
   values. Physically equal subtrees are skipped without being traversed.

   Raises [Invalid_argument] when a key of one map has the same [Key.to_int]
   as the key of a leaf of the other map while [Key.polyeq] answers [Diff].

   The explicit polymorphic annotation and the locally abstract type [m] are
   needed by the local functions [g] below, which are polymorphic in the key
   type. *)
let rec fold_on_nonequal_union:
'm 'acc. ('acc,'m) polyfold2_union -> 'm t -> 'm t -> 'acc -> 'acc =
fun (type m) f (ta:m t) (tb:m t) acc ->
(* Physically equal maps: every binding is shared, nothing to report. *)
if ta == tb then acc
else
(* [fleft] reports a binding found only in the first map. *)
let fleft:(_,_) polyfold =
{f=fun key value acc -> f.f key (Some value) None acc} in
(* [fright] reports a binding found only in the second map. *)
let fright:(_,_)polyfold =
{f=fun key value acc -> f.f key None (Some value) acc} in
match NODE.view ta,NODE.view tb with
(* One side is empty: every binding of the other side is reported. *)
| Empty, _ -> fold fright tb acc
| _, Empty -> fold fleft ta acc
| Leaf{key;value},_ ->
let ida = Key.to_int key in
(* [ta] is the single binding [key -> value], which may or may not also be
   bound in [tb]. We fold over [tb] and insert the call for that binding at
   its place in the unsigned order of identifiers; the boolean [did_a] of
   the accumulator records whether that call has already been accounted
   for. *)
let g (type b) (keyb:b key) (valueb:(b,m) value) (acc,did_a) =
(* Report [keyb] as bound in [tb] only. *)
let default() = (f.f keyb None (Some valueb) acc,did_a) in
if did_a then default()
else
let idb = Key.to_int keyb in
if unsigned_lt idb ida then default()
else if unsigned_lt ida idb then
(* First key of [tb] past [ida]: [key] is unbound in [tb], so report the
   binding of [ta] first, then the current binding of [tb]. *)
let acc = f.f key (Some value) None acc in
let acc = f.f keyb None (Some valueb) acc in
(acc,true)
else match Key.polyeq key keyb with
| Eq ->
(* Same key on both sides: report it only if the values differ physically. *)
if value == valueb then (acc,true)
else (f.f key (Some value) (Some valueb) acc,true)
| Diff ->
raise (Invalid_argument "Keys with same to_int value are not equal by polyeq")
in
let (acc,found) = fold{f=fun keyb valueb acc -> g keyb valueb acc} tb (acc,false) in
(* When no key of [tb] reached [ida], the binding of [ta] is still to be
   reported, after all those of [tb]. *)
if found then acc
else f.f key (Some value) None acc
| _,Leaf{key;value} ->
(* Symmetric case: [tb] is the single binding [key -> value] and we fold
   over [ta], with [did_b] playing the role of [did_a] above. *)
let idb = Key.to_int key in
let g (type a) (keya: a key) (valuea:(a,m) value) (acc,did_b) =
(* Report [keya] as bound in [ta] only. *)
let default() = (f.f keya (Some valuea) None acc,did_b) in
if did_b then default()
else
let ida = Key.to_int keya in
if unsigned_lt ida idb then default()
else if unsigned_lt idb ida then
let acc = f.f key None (Some value) acc in
let acc = f.f keya (Some valuea) None acc in
(acc,true)
else match Key.polyeq keya key with
| Eq ->
if valuea == value then (acc,true)
else (f.f keya (Some valuea) (Some value) acc,true)
| Diff ->
raise (Invalid_argument "Keys with same to_int value are not equal by polyeq")
in
let (acc,found) = fold{f=fun keya valuea acc -> g keya valuea acc} ta (acc,false) in
if found then acc
else f.f key None (Some value) acc
| Branch{prefix=pa;branching_bit=ma;tree0=ta0;tree1=ta1},
Branch{prefix=pb;branching_bit=mb;tree0=tb0;tree1=tb1} ->
if ma == mb && pa == pb
(* Same prefix: fold on each pair of subtrees, [tree0] first. *)
then
let acc = fold_on_nonequal_union f ta0 tb0 acc in
let acc = fold_on_nonequal_union f ta1 tb1 acc in
acc
else if unsigned_lt mb ma && match_prefix pb pa ma
(* [tb] can only overlap one child of [ta]: recurse on that child and report
   the other child of [ta] entirely as left-only, [tree0] side first. *)
then if ma land pb == 0
then
let acc = fold_on_nonequal_union f ta0 tb acc in
let acc = fold fleft ta1 acc in
acc
else
let acc = fold fleft ta0 acc in
let acc = fold_on_nonequal_union f ta1 tb acc in
acc
else if unsigned_lt ma mb && match_prefix pa pb mb
(* Symmetric case: [ta] can only overlap one child of [tb]. *)
then if mb land pa == 0
then
let acc = fold_on_nonequal_union f ta tb0 acc in
let acc = fold fright tb1 acc in
acc
else
let acc = fold fright tb0 acc in
let acc = fold_on_nonequal_union f ta tb1 acc in
acc
else
(* Distinct subtrees: process them in increasing order of keys. *)
if unsigned_lt pa pb then
let acc = fold fleft ta acc in
let acc = fold fright tb acc in
acc
else
let acc = fold fright tb acc in
let acc = fold fleft ta acc in
acc
;;



(* Predicate on bindings. The field is polymorphic in the key type parameter
   ['a], hence the record wrapper. *)
type 'map polypredicate = { f: 'a. 'a key -> ('a,'map) value -> bool; } [@@unboxed]
(* [filter f m] is expressed through [filter_map]: each binding is mapped to
   [Some v] (its own value, unchanged) when [f.f] holds on it and to [None]
   otherwise. *)
let filter f m = filter_map {f = fun k v -> if f.f k v then Some v else None } m
let rec for_all f m = match NODE.view m with
Expand Down Expand Up @@ -1512,6 +1661,16 @@ module MakeCustom
(* Homogeneous wrapper around [BaseMap.slow_merge]: the optional values are
   converted with [opt_snd] before calling [f] and its result is converted
   back with [snd_opt] (both helpers are defined outside this view;
   presumably they unwrap and rewrap the [Snd] constructor under the option). *)
let slow_merge (f : key -> 'a option -> 'b option -> 'c option) a b = BaseMap.slow_merge {f=fun k v1 v2 -> snd_opt (f k (opt_snd v1) (opt_snd v2))} a b
(* Homogeneous wrapper around [BaseMap.iter]: unwraps [Snd] before calling [f]. *)
let iter (f: key -> 'a -> unit) a = BaseMap.iter {f=fun k (Snd v) -> f k v} a
(* Homogeneous wrapper around [BaseMap.fold]: unwraps [Snd] before calling [f]. *)
let fold (f: key -> 'a -> 'acc) m acc = BaseMap.fold {f=fun k (Snd v) acc -> f k v acc} m acc
(* Homogeneous wrapper around [BaseMap.fold_on_nonequal_inter]: both values
   are taken out of their [Snd] constructor before being handed to [f]. *)
let fold_on_nonequal_inter (f: key -> 'a -> 'b -> 'acc) ma mb acc =
  BaseMap.fold_on_nonequal_inter
    {f = fun k (Snd left) (Snd right) acc -> f k left right acc}
    ma mb acc
(* Homogeneous wrapper around [BaseMap.fold_on_nonequal_union]: each optional
   value is taken out of its [Snd] constructor before being handed to [f];
   [None] (key unbound in that map) is passed through unchanged. *)
let fold_on_nonequal_union
    (f: key -> 'a option -> 'b option -> 'acc) ma mb acc =
  let strip = function
    | Some (Snd v) -> Some v
    | None -> None
  in
  BaseMap.fold_on_nonequal_union
    {f = fun k left right acc -> f k (strip left) (strip right) acc}
    ma mb acc

(* Homogeneous wrapper around [BaseMap.pretty]: the optional [pp_sep] is
   forwarded unchanged, and each value is taken out of its [Snd] constructor
   before being handed to the binding printer [f]. *)
let pretty ?pp_sep (f: Format.formatter -> key -> 'a -> unit) fmt m =
BaseMap.pretty ?pp_sep {f=fun fmt k (Snd v) -> f fmt k v} fmt m
Expand Down
36 changes: 35 additions & 1 deletion patriciaTree.mli
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,24 @@ module type BASE_MAP = sig
where [(key_1, value_1) ... (key_n, value_n)] are the bindings of [m], in
the {{!unsigned_lt}unsigned order} on [Key.to_int]. *)

(** Callback of {!fold_on_nonequal_inter}: [f.f key value1 value2 acc] receives
    the values bound to [key] in the first and in the second map. *)
type ('acc,'map) polyfold2 = { f: 'a. 'a key -> ('a,'map) value -> ('a,'map) value -> 'acc -> 'acc } [@@unboxed]
val fold_on_nonequal_inter : ('acc,'map) polyfold2 -> 'map t -> 'map t -> 'acc -> 'acc
(** [fold_on_nonequal_inter f m1 m2 acc] returns
    [f.f key_n value1_n value2_n (... (f.f key_1 value1_1 value2_1 acc))] where
    [(key_1, value1_1, value2_1) ... (key_n, value1_n, value2_n)] are the
    bindings that exist in both maps ([m1 ∩ m2]) whose values are physically different.
    Calls to [f.f] are performed in the {{!unsigned_lt}unsigned order} of [Key.to_int]. *)


(** Callback of {!fold_on_nonequal_union}: [f.f key value1 value2 acc] receives
    [Some v] for a map in which [key] is bound to [v] and [None] for a map in
    which [key] is unbound. *)
type ('acc,'map) polyfold2_union = { f: 'a. 'a key -> ('a,'map) value option -> ('a,'map) value option -> 'acc -> 'acc } [@@unboxed]
val fold_on_nonequal_union : ('acc,'map) polyfold2_union -> 'map t -> 'map t -> 'acc -> 'acc
(** [fold_on_nonequal_union f m1 m2 acc] returns
    [f.f key_n value1_n value2_n (... (f.f key_1 value1_1 value2_1 acc))] where
    [(key_1, value1_1, value2_1) ... (key_n, value1_n, value2_n)] are the
    bindings that exist in either map ([m1 ∪ m2]) whose values are physically
    different.
    [value1_i] (resp. [value2_i]) is [Some v] when [key_i] is bound to [v] in
    [m1] (resp. [m2]) and [None] when it is unbound there, so a key bound in
    only one of the maps is always reported.
    Calls to [f.f] are performed in the {{!unsigned_lt}unsigned order} of [Key.to_int]. *)

type 'map polypredicate = { f: 'a. 'a key -> ('a,'map) value -> bool; } [@@unboxed]
val filter : 'map polypredicate -> 'map t -> 'map t
(** [filter f m] returns the submap of [m] containing the bindings [k->v]
Expand Down Expand Up @@ -899,6 +917,23 @@ module type MAP = sig
val fold : (key -> 'a -> 'acc -> 'acc) -> 'a t -> 'acc -> 'acc
(** Fold on each (key,value) pair of the map, in increasing {{!unsigned_lt}unsigned order} of keys. *)

val fold_on_nonequal_inter : (key -> 'a -> 'a -> 'acc -> 'acc) ->
'a t -> 'a t -> 'acc -> 'acc
(** [fold_on_nonequal_inter f m1 m2 acc] returns
    [f key_n value1_n value2_n (... (f key_1 value1_1 value2_1 acc))] where
    [(key_1, value1_1, value2_1) ... (key_n, value1_n, value2_n)] are the
    bindings that exist in both maps ([m1 ∩ m2]) whose values are physically different.
    Calls to [f] are performed in the {{!unsigned_lt}unsigned order} of [Key.to_int]. *)

val fold_on_nonequal_union: (key -> 'a option -> 'a option -> 'acc -> 'acc) ->
'a t -> 'a t -> 'acc -> 'acc
(** [fold_on_nonequal_union f m1 m2 acc] returns
    [f key_n value1_n value2_n (... (f key_1 value1_1 value2_1 acc))] where
    [(key_1, value1_1, value2_1) ... (key_n, value1_n, value2_n)] are the
    bindings that exist in either map ([m1 ∪ m2]) whose values are physically
    different.
    [value1_i] (resp. [value2_i]) is [Some v] when [key_i] is bound to [v] in
    [m1] (resp. [m2]) and [None] when it is unbound there, so a key bound in
    only one of the maps is always reported.
    Calls to [f] are performed in the {{!unsigned_lt}unsigned order} of [Key.to_int]. *)

val filter : (key -> 'a -> bool) -> 'a t -> 'a t
(** Returns the submap containing only the key->value pairs satisfying the
given predicate. [f] is called in increasing {{!unsigned_lt}unsigned order} of keys. *)
Expand Down Expand Up @@ -1115,7 +1150,6 @@ module type MAP = sig
in increasing {{!unsigned_lt}unsigned order} of [Key.to_int] *)
end


(** {1 Keys} *)
(** Keys are the functor arguments used to build the maps. *)

Expand Down
53 changes: 53 additions & 0 deletions patriciaTreeTest.ml
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,27 @@ let%test_module _ = (module struct
| None, _ | _, None -> None
| Some a, Some b -> (f key a b)) m1 m2

(* Reference model of [fold_on_nonequal_inter].

   The bindings common to both maps whose values are physically different are
   first collected, then explicitly sorted with [unsigned_compare] before
   being folded. This mirrors the model of [fold_on_nonequal_union] and makes
   the order of the calls to [f] match the documented unsigned order, instead
   of depending on the order in which [M.merge] happens to invoke its
   callback, which is not specified. *)
let fold_on_nonequal_inter f m1 m2 acc =
  let differing =
    M.merge (fun _key a b ->
        match a, b with
        | None, _ | _, None -> None
        | Some a, Some b -> if a != b then Some (a, b) else None)
      m1 m2
  in
  M.bindings differing
  |> List.sort (fun (key1, _) (key2, _) -> unsigned_compare key1 key2)
  |> List.fold_left (fun acc (key, (a, b)) -> f key a b acc) acc

(* Reference model of [fold_on_nonequal_union]: keep every key bound in at
   least one map, unless both maps bind it to physically equal values, then
   fold over the survivors in unsigned key order. *)
let fold_on_nonequal_union f ma mb acc =
  let keep_if_different _key left right =
    match left, right with
    | None, None -> assert false
    | Some va, Some vb -> if va == vb then None else Some (left, right)
    | None, Some _ | Some _, None -> Some (left, right)
  in
  M.merge keep_if_different ma mb
  |> M.bindings
  |> List.sort (fun (k1, _) (k2, _) -> unsigned_compare k1 k2)
  |> List.fold_left (fun acc (key, (left, right)) -> f key left right acc) acc

let pop_unsigned_minimum m =
match M.min_binding m with
| exception Not_found -> None
Expand Down Expand Up @@ -529,6 +550,38 @@ let%test_module _ = (module struct
modelres == myres)
let () = QCheck.Test.check_exn test_disjoint

(* Property: folding an order-sensitive hash with
   [MyMap.fold_on_nonequal_inter] gives the same result as the reference
   model, and the callback is handed to [check_increases] for each key
   visited by the implementation. *)
let test_fold_on_nonequal_inter =
  QCheck.Test.make ~count:1000 ~name:"fold_on_nonequal_inter" gen (fun x ->
      let (m1, model1, m2, model2) = model_from_gen x in
      let hash_step key v1 v2 acc = sdbm key (sdbm v1 (sdbm v2 acc)) in
      let check_key = check_increases () in
      let checked_step key v1 v2 acc =
        check_key key;
        hash_step key v1 v2 acc
      in
      let actual = MyMap.fold_on_nonequal_inter checked_step m1 m2 117 in
      let expected = IntMap.fold_on_nonequal_inter hash_step model1 model2 117 in
      expected == actual)
let () = QCheck.Test.check_exn test_fold_on_nonequal_inter

(* Property: folding an order-sensitive hash with
   [MyMap.fold_on_nonequal_union] gives the same result as the reference
   model, and the callback is handed to [check_increases] for each key
   visited by the implementation. An absent value is hashed as a fixed
   constant (421 on the left, 567 on the right). *)
let test_fold_on_nonequal_union =
  QCheck.Test.make ~count:1000 ~name:"fold_on_nonequal_union" gen (fun x ->
      let (m1, model1, m2, model2) = model_from_gen x in
      let or_default default = function
        | None -> default
        | Some v -> v
      in
      let hash_step key v1 v2 acc =
        let left = or_default 421 v1 in
        let right = or_default 567 v2 in
        sdbm key (sdbm left (sdbm right acc))
      in
      let check_key = check_increases () in
      let checked_step key v1 v2 acc =
        check_key key;
        hash_step key v1 v2 acc
      in
      let actual = MyMap.fold_on_nonequal_union checked_step m1 m2 117 in
      let expected = IntMap.fold_on_nonequal_union hash_step model1 model2 117 in
      expected == actual)
let () = QCheck.Test.check_exn test_fold_on_nonequal_union

let%test "negative_keys" =
let map = MyMap.add 0 0 MyMap.empty in
let _pp_l fmt = Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "; ")
Expand Down

0 comments on commit bef2736

Please sign in to comment.