Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New functions: fold_on_nonequal_inter and fold_on_nonequal_union #6

Merged
merged 14 commits into from
May 10, 2024
159 changes: 159 additions & 0 deletions patriciaTree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ module type BASE_MAP = sig
type ('acc,'map) polyfold = { f: 'a. 'a key -> ('a,'map) value -> 'acc -> 'acc } [@@unboxed]
val fold : ('acc,'map) polyfold -> 'map t -> 'acc -> 'acc

type ('acc,'map) polyfold2 = { f: 'a. 'a key -> ('a,'map) value -> ('a,'map) value -> 'acc -> 'acc } [@@unboxed]
val fold_on_nonequal_inter : ('acc,'map) polyfold2 -> 'map t -> 'map t -> 'acc -> 'acc

type ('acc,'map) polyfold2_union = { f: 'a. 'a key -> ('a,'map) value option -> ('a,'map) value option -> 'acc -> 'acc } [@@unboxed]
val fold_on_nonequal_union : ('acc,'map) polyfold2_union -> 'map t -> 'map t -> 'acc -> 'acc

type 'map polypredicate = { f: 'a. 'a key -> ('a,'map) value -> bool; } [@@unboxed]
val filter : 'map polypredicate -> 'map t -> 'map t
val for_all : 'map polypredicate -> 'map t -> bool
Expand Down Expand Up @@ -299,6 +305,11 @@ module type MAP = sig
val split : key -> 'a t -> 'a t * 'a option * 'a t
val iter : (key -> 'a -> unit) -> 'a t -> unit
val fold : (key -> 'a -> 'acc -> 'acc) -> 'a t -> 'acc -> 'acc
val fold_on_nonequal_inter : (key -> 'a -> 'a -> 'acc -> 'acc) ->
'a t -> 'a t -> 'acc -> 'acc
val fold_on_nonequal_union :
(key -> 'a option -> 'a option -> 'acc -> 'acc) ->
'a t -> 'a t -> 'acc -> 'acc
val filter : (key -> 'a -> bool) -> 'a t -> 'a t
val for_all : (key -> 'a -> bool) -> 'a t -> bool
val map : ('a -> 'a) -> 'a t -> 'a t
Expand Down Expand Up @@ -1172,6 +1183,144 @@ module MakeCustomHeterogeneous
let acc = fold f tree0 acc in
fold f tree1 acc


type ('acc,'map) polyfold2 = { f: 'a. 'a key -> ('a,'map) value -> ('a,'map) value -> 'acc -> 'acc } [@@unboxed]
let rec fold_on_nonequal_inter f ta tb acc =
if ta == tb then acc
else match NODE.view ta,NODE.view tb with
| Empty, _ | _, Empty -> acc
| Leaf{key;value},_ ->
(try let valueb = find key tb in
if valueb == value then acc else
f.f key value valueb acc
with Not_found -> acc)
| _,Leaf{key;value} ->
(try let valuea = find key ta in
if valuea == value then acc else
f.f key valuea value acc
with Not_found -> acc)
| Branch{prefix=pa;branching_bit=ma;tree0=ta0;tree1=ta1},
Branch{prefix=pb;branching_bit=mb;tree0=tb0;tree1=tb1} ->
if ma == mb && pa == pb
(* Same prefix: fold on each subtrees *)
then
let acc = fold_on_nonequal_inter f ta0 tb0 acc in
let acc = fold_on_nonequal_inter f ta1 tb1 acc in
acc
else if unsigned_lt mb ma && match_prefix pb pa ma
then if ma land pb == 0
then fold_on_nonequal_inter f ta0 tb acc
else fold_on_nonequal_inter f ta1 tb acc
else if unsigned_lt ma mb && match_prefix pa pb mb
then if mb land pa == 0
then fold_on_nonequal_inter f ta tb0 acc
else fold_on_nonequal_inter f ta tb1 acc
else acc


type ('acc,'map) polyfold2_union =
{ f: 'a. 'a key -> ('a,'map) value option -> ('a,'map) value option ->
'acc -> 'acc } [@@unboxed]
let rec fold_on_nonequal_union:
'm 'acc. ('acc,'m) polyfold2_union -> 'm t -> 'm t -> 'acc -> 'acc =
fun (type m) f (ta:m t) (tb:m t) acc ->
if ta == tb then acc
else
let fleft:(_,_) polyfold =
{f=fun key value acc -> f.f key (Some value) None acc} in
let fright:(_,_)polyfold =
{f=fun key value acc -> f.f key None (Some value) acc} in
match NODE.view ta,NODE.view tb with
| Empty, _ -> fold fright tb acc
| _, Empty -> fold fleft ta acc
| Leaf{key;value},_ ->
let ida = Key.to_int key in
(* Fold on the rest, knowing that ida may or may not be in b. So we fold and use
did_a to remember if we already did the call to a. *)
let g (type b) (keyb:b key) (valueb:(b,m) value) (acc,did_a) =
let default() = (f.f keyb None (Some valueb) acc,did_a) in
if did_a then default()
else
let idb = Key.to_int keyb in
if unsigned_lt idb ida then default()
else if unsigned_lt ida idb then
let acc = f.f key (Some value) None acc in
let acc = f.f keyb None (Some valueb) acc in
(acc,true)
else match Key.polyeq key keyb with
| Eq ->
if value == valueb then (acc,true)
else (f.f key (Some value) (Some valueb) acc,true)
| Diff ->
raise (Invalid_argument "Keys with same to_int value are not equal by polyeq")
in
let (acc,found) = fold{f=fun keyb valueb acc -> g keyb valueb acc} tb (acc,false) in
if found then acc
else f.f key (Some value) None acc
| _,Leaf{key;value} ->
let idb = Key.to_int key in
let g (type a) (keya: a key) (valuea:(a,m) value) (acc,did_b) =
let default() = (f.f keya (Some valuea) None acc,did_b) in
if did_b then default()
else
let ida = Key.to_int keya in
if unsigned_lt ida idb then default()
else if unsigned_lt idb ida then
let acc = f.f key None (Some value) acc in
let acc = f.f keya (Some valuea) None acc in
(acc,true)
else match Key.polyeq keya key with
| Eq ->
if valuea == value then (acc,true)
else (f.f keya (Some valuea) (Some value) acc,true)
| Diff ->
raise (Invalid_argument "Keys with same to_int value are not equal by polyeq")
in
let (acc,found) = fold{f=fun keya valuea acc -> g keya valuea acc} ta (acc,false) in
if found then acc
else f.f key None (Some value) acc
| Branch{prefix=pa;branching_bit=ma;tree0=ta0;tree1=ta1},
Branch{prefix=pb;branching_bit=mb;tree0=tb0;tree1=tb1} ->
if ma == mb && pa == pb
(* Same prefix: merge the subtrees *)
then
let acc = fold_on_nonequal_union f ta0 tb0 acc in
let acc = fold_on_nonequal_union f ta1 tb1 acc in
acc
else if unsigned_lt mb ma && match_prefix pb pa ma
then if ma land pb == 0
then
let acc = fold_on_nonequal_union f ta0 tb acc in
let acc = fold fleft ta1 acc in
acc
else
let acc = fold fleft ta0 acc in
let acc = fold_on_nonequal_union f ta1 tb acc in
acc
else if unsigned_lt ma mb && match_prefix pa pb mb
then if mb land pa == 0
then
let acc = fold_on_nonequal_union f ta tb0 acc in
let acc = fold fright tb1 acc in
acc
else
let acc = fold fright tb0 acc in
let acc = fold_on_nonequal_union f ta tb1 acc in
acc
else
(* Distinct subtrees: process them in increasing order of keys. *)
if unsigned_lt pa pb then
let acc = fold fleft ta acc in
let acc = fold fright tb acc in
acc
else
let acc = fold fright tb acc in
let acc = fold fleft ta acc in
acc
;;



type 'map polypredicate = { f: 'a. 'a key -> ('a,'map) value -> bool; } [@@unboxed]
let filter f m = filter_map {f = fun k v -> if f.f k v then Some v else None } m
let rec for_all f m = match NODE.view m with
Expand Down Expand Up @@ -1512,6 +1661,16 @@ module MakeCustom
let slow_merge (f : key -> 'a option -> 'b option -> 'c option) a b = BaseMap.slow_merge {f=fun k v1 v2 -> snd_opt (f k (opt_snd v1) (opt_snd v2))} a b
let iter (f: key -> 'a -> unit) a = BaseMap.iter {f=fun k (Snd v) -> f k v} a
let fold (f: key -> 'a -> 'acc) m acc = BaseMap.fold {f=fun k (Snd v) acc -> f k v acc} m acc
let fold_on_nonequal_inter (f: key -> 'a -> 'b -> 'acc) ma mb acc =
let f k (Snd va) (Snd vb) acc = f k va vb acc in
BaseMap.fold_on_nonequal_inter {f} ma mb acc
let fold_on_nonequal_union
(f: key -> 'a option -> 'b option -> 'acc) ma mb acc =
let f k va vb acc =
let va = Option.map (fun (Snd v) -> v) va in
let vb = Option.map (fun (Snd v) -> v) vb in
f k va vb acc in
BaseMap.fold_on_nonequal_union {f} ma mb acc

let pretty ?pp_sep (f: Format.formatter -> key -> 'a -> unit) fmt m =
BaseMap.pretty ?pp_sep {f=fun fmt k (Snd v) -> f fmt k v} fmt m
Expand Down
36 changes: 35 additions & 1 deletion patriciaTree.mli
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,24 @@ module type BASE_MAP = sig
where [(key_1, value_1) ... (key_n, value_n)] are the bindings of [m], in
the {{!unsigned_lt}unsigned order} on [Key.to_int]. *)

type ('acc,'map) polyfold2 = { f: 'a. 'a key -> ('a,'map) value -> ('a,'map) value -> 'acc -> 'acc } [@@unboxed]
val fold_on_nonequal_inter : ('acc,'map) polyfold2 -> 'map t -> 'map t -> 'acc -> 'acc
(** [fold_on_nonequal_inter f m1 m2 acc] returns
[f.f key_n value1_n value2n (... (f.f key_1 value1_1 value2_1 acc))] where
[(key_1, value1_1, value2_1) ... (key_n, value1_n, value2_n)] are the
bindings that exist in both maps ([m1 ∩ m2]) whose values are physically different.
Calls to [f.f] are performed in the {{!unsigned_lt}unsigned order} of [Key.to_int]. *)


type ('acc,'map) polyfold2_union = { f: 'a. 'a key -> ('a,'map) value option -> ('a,'map) value option -> 'acc -> 'acc } [@@unboxed]
val fold_on_nonequal_union : ('acc,'map) polyfold2_union -> 'map t -> 'map t -> 'acc -> 'acc
(** [fold_on_nonequal_union f m1 m2 acc] returns
[f.f key_n value1_n value2n (... (f.f key_1 value1_1 value2_1 acc))] where
[(key_1, value1_1, value2_1) ... (key_n, value1_n, value2_n)] are the
bindings that exists in either map ([m1 ∪ m2]) whose values are physically
different.
Calls to [f.f] are performed in the {{!unsigned_lt}unsigned order} of [Key.to_int]. *)

type 'map polypredicate = { f: 'a. 'a key -> ('a,'map) value -> bool; } [@@unboxed]
val filter : 'map polypredicate -> 'map t -> 'map t
(** [filter f m] returns the submap of [m] containing the bindings [k->v]
Expand Down Expand Up @@ -899,6 +917,23 @@ module type MAP = sig
val fold : (key -> 'a -> 'acc -> 'acc) -> 'a t -> 'acc -> 'acc
(** Fold on each (key,value) pair of the map, in increasing {{!unsigned_lt}unsigned order} of keys. *)

val fold_on_nonequal_inter : (key -> 'a -> 'a -> 'acc -> 'acc) ->
'a t -> 'a t -> 'acc -> 'acc
(** [fold_on_nonequal_inter f m1 m2 acc] returns
[f key_n value1_n value2n (... (f key_1 value1_1 value2_1 acc))] where
[(key_1, value1_1, value2_1) ... (key_n, value1_n, value2_n)] are the
bindings that exist in both maps ([m1 ∩ m2]) whose values are physically different.
Calls to [f] are performed in the {{!unsigned_lt}unsigned order} of [Key.to_int]. *)

val fold_on_nonequal_union: (key -> 'a option -> 'a option -> 'acc -> 'acc) ->
'a t -> 'a t -> 'acc -> 'acc
(** [fold_on_nonequal_union f m1 m2 acc] returns
[f key_n value1_n value2n (... (f key_1 value1_1 value2_1 acc))] where
[(key_1, value1_1, value2_1) ... (key_n, value1_n, value2_n)] are the
bindings that exists in either map ([m1 ∪ m2]) whose values are physically
different.
Calls to [f.f] are performed in the {{!unsigned_lt}unsigned order} of [Key.to_int]. *)

val filter : (key -> 'a -> bool) -> 'a t -> 'a t
(** Returns the submap containing only the key->value pairs satisfying the
given predicate. [f] is called in increasing {{!unsigned_lt}unsigned order} of keys. *)
Expand Down Expand Up @@ -1115,7 +1150,6 @@ module type MAP = sig
in increasing {{!unsigned_lt}unsigned order} of [Key.to_int] *)
end


(** {1 Keys} *)
(** Keys are the functor arguments used to build the maps. *)

Expand Down
53 changes: 53 additions & 0 deletions patriciaTreeTest.ml
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,27 @@ let%test_module _ = (module struct
| None, _ | _, None -> None
| Some a, Some b -> (f key a b)) m1 m2

let fold_on_nonequal_inter f m1 m2 acc =
let racc = ref acc in
ignore @@ M.merge (fun key a b ->
match a,b with
| None, _ | _, None -> None
| Some a, Some b ->
if a != b
then racc := f key a b !racc;
None) m1 m2;
!racc

let fold_on_nonequal_union f ma mb acc =
let union = M.merge (fun _key a b ->
match a,b with
| None, None -> assert false
| Some a, Some b when a == b -> None
| None, Some _ | Some _, None | Some _, Some _ -> Some(a,b)) ma mb in
let elts = M.bindings union in
let elts = List.sort (fun (key1,_val1) (key2,_val2) -> unsigned_compare key1 key2) elts in
List.fold_left (fun acc (key,(val1,val2)) -> f key val1 val2 acc) acc elts

let pop_unsigned_minimum m =
match M.min_binding m with
| exception Not_found -> None
Expand Down Expand Up @@ -529,6 +550,38 @@ let%test_module _ = (module struct
modelres == myres)
let () = QCheck.Test.check_exn test_disjoint

let test_fold_on_nonequal_inter = QCheck.Test.make ~count:1000 ~name:"fold_on_nonequal_inter"
gen (fun x ->
let (m1,model1,m2,model2) = model_from_gen x in
let orig_f key v1 v2 acc = sdbm key @@ sdbm v1 @@ sdbm v2 acc in
let chk_calls = check_increases () in
let f key v1 v2 acc =
chk_calls key;
orig_f key v1 v2 acc
in
let myres = MyMap.fold_on_nonequal_inter f m1 m2 117 in
let modelres = IntMap.fold_on_nonequal_inter orig_f model1 model2 117 in
modelres == myres)
let () = QCheck.Test.check_exn test_fold_on_nonequal_inter

let test_fold_on_nonequal_union = QCheck.Test.make ~count:1000 ~name:"fold_on_nonequal_union"
gen (fun x ->
let (m1,model1,m2,model2) = model_from_gen x in
let orig_f key v1 v2 acc =
(* Printf.printf "Calling f key=%d v1=%s v2=%s acc=%d\n%!" *)
(* key (match v1 with None -> "None" | Some v -> string_of_int v) *)
(* (match v2 with None -> "None" | Some v -> string_of_int v) acc; *)
(* chk_calls key; *)
let v1 = match v1 with None -> 421 | Some v -> v in
let v2 = match v2 with None -> 567 | Some v -> v in
sdbm key @@ sdbm v1 @@ sdbm v2 acc in
let chk_calls = check_increases () in
let f key v1 v2 acc = chk_calls key; orig_f key v1 v2 acc in
let myres = MyMap.fold_on_nonequal_union f m1 m2 117 in
let modelres = IntMap.fold_on_nonequal_union orig_f model1 model2 117 in
modelres == myres)
let () = QCheck.Test.check_exn test_fold_on_nonequal_union

let%test "negative_keys" =
let map = MyMap.add 0 0 MyMap.empty in
let _pp_l fmt = Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "; ")
Expand Down