diff --git a/src/Trie.mo b/src/Trie.mo index e4831228..a8793db5 100644 --- a/src/Trie.mo +++ b/src/Trie.mo @@ -252,7 +252,6 @@ module { /// ``` public func empty() : Trie { #empty }; - /// Get the size in O(1) time. /// /// For a more detailed overview of how to use a `Trie`, @@ -354,36 +353,79 @@ module { /// Purely-functional representation permits _O(1)_ copy, via persistent sharing. public func clone(t : Trie) : Trie = t; - /// Replace the given key's value option with the given one, returning the previous one + /// Combine two nodes that may have a reduced size after an entry deletion. + func combineReducedNodes(left : Trie, right : Trie) : Trie { + switch (left, right) { + case (#empty, #empty) { + #empty + }; + case (#leaf(leftLeaf), #empty) { + #leaf(leftLeaf) + }; + case (#empty, #leaf(rightLeaf)) { + #leaf(rightLeaf) + }; + case (#leaf(leftLeaf), #leaf(rightLeaf)) { + let size = leftLeaf.size + rightLeaf.size; + if (size <= MAX_LEAF_SIZE) { + let union = List.append(leftLeaf.keyvals, rightLeaf.keyvals); + #leaf({ size = size; keyvals = union }) + } else { + branch(left, right) + } + }; + case (left, right) { + branch(left, right) + } + } + }; + + /// Replace the given key's value option with the given value, returning the modified trie. + /// Also returns the replaced value if the key existed and `null` otherwise. + /// Compares keys using the provided function `k_eq`. + /// + /// Note: Replacing a key's value by `null` removes the key and also shrinks the trie. + /// + /// For a more detailed overview of how to use a `Trie`, + /// see the [User's Overview](#overview). + /// + /// Example: + /// ```motoko include=initialize + /// trie := Trie.put(trie, key "test", Text.equal, 1).0; + /// trie := Trie.replace(trie, key "test", Text.equal, 42).0; + /// assert (Trie.get(trie, key "hello", Text.equal) == ?42); + /// ``` public func replace(t : Trie, k : Key, k_eq : (K, K) -> Bool, v : ?V) : (Trie, ?V) { let key_eq = equalKey(k_eq); + var replacedValue: ?V = null; - func rec(t : Trie, bitpos : Nat) : (Trie, ?V) { + func recursiveReplace(t : Trie, bitpos : Nat) : Trie { switch t { case (#empty) { let (kvs, _) = AssocList.replace(null, k, key_eq, v); - (leaf(kvs, bitpos), null) + leaf(kvs, bitpos) }; case (#branch(b)) { let bit = Hash.bit(k.hash, bitpos); // rebuild either the left or right path with the (k, v) pair if (not bit) { - let (l, v_) = rec(b.left, bitpos + 1); - (branch(l, b.right), v_) + let l = recursiveReplace(b.left, bitpos + 1); + combineReducedNodes(l, b.right) } else { - let (r, v_) = rec(b.right, bitpos + 1); - (branch(b.left, r), v_) + let r = recursiveReplace(b.right, bitpos + 1); + combineReducedNodes(b.left, r) } }; case (#leaf(l)) { - let (kvs2, old_val) = AssocList.replace(l.keyvals, k, key_eq, v); - (leaf(kvs2, bitpos), old_val) + let (kvs2, oldValue) = AssocList.replace(l.keyvals, k, key_eq, v); + replacedValue := oldValue; + leaf(kvs2, bitpos) } } }; - let (to, vo) = rec(t, 0); - //assert(isValid(to, false)); - (to, vo) + let newTrie = recursiveReplace(t, 0); + //assert(isValid(newTrie, false)); + (newTrie, replacedValue) }; /// Put the given key's value in the trie; return the new trie, and the previous value associated with the key, if any. @@ -577,7 +619,7 @@ module { switch (x, y) { case (null, ?v) { v }; case (?v, null) { v }; - case (_, _) { Debug.trap "Trie.mergeDisjoint"} + case (_, _) { Debug.trap "Trie.mergeDisjoint" } } } ), @@ -1287,11 +1329,7 @@ module { case (#branch(b)) { let fl = rec(b.left, bitpos + 1); let fr = rec(b.right, bitpos + 1); - if (isEmpty(fl) and isEmpty(fr)) { - #empty - } else { - branch(fl, fr) - } + combineReducedNodes(fl, fr) } } }; @@ -1339,11 +1377,7 @@ module { case (#branch(b)) { let fl = rec(b.left, bitpos + 1); let fr = rec(b.right, bitpos + 1); - if (isEmpty(fl) and isEmpty(fr)) { - #empty - } else { - branch(fl, fr) - } + combineReducedNodes(fl, fr) } } }; @@ -1508,7 +1542,11 @@ module { updated_outer }; - /// Remove the given key's value in the trie; return the new trie + /// Remove the entry for the given key from the trie, by returning the reduced trie. + /// Also returns the removed value if the key existed and `null` otherwise. + /// Compares keys using the provided function `k_eq`. + /// + /// Note: The removal of an existing key shrinks the trie. /// /// For a more detailed overview of how to use a `Trie`, /// see the [User's Overview](#overview). @@ -1517,7 +1555,7 @@ module { /// ```motoko include=initialize /// trie := Trie.put(trie, key "hello", Text.equal, 42).0; /// trie := Trie.put(trie, key "bye", Text.equal, 32).0; - /// // remove the value associated with "hello" + /// // remove the entry associated with "hello" /// trie := Trie.remove(trie, key "hello", Text.equal).0; /// assert (Trie.get(trie, key "hello", Text.equal) == null); /// ``` diff --git a/src/TrieMap.mo b/src/TrieMap.mo index fd623f8e..29b00ce3 100644 --- a/src/TrieMap.mo +++ b/src/TrieMap.mo @@ -104,6 +104,8 @@ module { /// Delete the entry associated with key `key`, if it exists. If the key is /// absent, there is no effect. /// + /// Note: The deletion of an existing key shrinks the trie map. + /// /// Example: /// ```motoko include=initialize /// map.put(0, 10); @@ -121,6 +123,8 @@ module { /// Delete the entry associated with key `key`. Return the deleted value /// as an option if it exists, and `null` otherwise. /// + /// Note: The deletion of an existing key shrinks the trie map. + /// /// Example: /// ```motoko include=initialize /// map.put(0, 10); diff --git a/test/trieMapTest.mo b/test/trieMapTest.mo index eaa245be..f2bb0f02 100644 --- a/test/trieMapTest.mo +++ b/test/trieMapTest.mo @@ -4,6 +4,8 @@ import Iter "mo:base/Iter"; import Hash "mo:base/Hash"; import Text "mo:base/Text"; import Nat "mo:base/Nat"; +import Array "mo:base/Array"; +import Order "mo:base/Order"; import Suite "mo:matchers/Suite"; import T "mo:matchers/Testable"; @@ -356,6 +358,156 @@ let suite = Suite.suite( Suite.run(suite); +/* --------------------------------------- */ + +object Random { + var number = 4711; + public func next() : Nat { + number := (123138118391 * number + 133489131) % 9999; + number + } +}; + +func shuffle(array : [Nat]) : [Nat] { + let extended = Array.map(array, func(value) { (value, Random.next()) }); + let sorted = Array.sort<(Nat, Nat)>( + extended, + func(first, second) { + Nat.compare(first.1, second.1) + } + ); + Array.map<(Nat, Nat), Nat>( + sorted, + func(value) { + value.0 + } + ) +}; + +let testSize = 1_000; + +let testKeys = shuffle(Array.tabulate(testSize, func(index) { index })); + +func buildTestTrie() : TrieMap.TrieMap { + let trie = TrieMap.TrieMap(Nat.equal, Hash.hash); + for (key in testKeys.vals()) { + trie.put(key, debug_show (key)) + }; + trie +}; + +func expectedKeyValuePairs(keys : [Nat]) : [(Nat, Text)] { + Array.tabulate<(Nat, Text)>(keys.size(), func(index) { (keys[index], debug_show (keys[index])) }) +}; + +let expectedEntries = expectedKeyValuePairs(Array.sort(testKeys, Nat.compare)); +let expectedKeys = Array.sort(testKeys, Nat.compare); +let expectedValues = Array.sort(Array.map(expectedKeys, func(key) { debug_show (key) }), Text.compare); + +let entryTestable = T.tuple2Testable(T.natTestable, T.textTestable); + +func compareByKey(first : (Nat, Text), second : (Nat, Text)) : Order.Order { + Nat.compare(first.0, second.0) +}; + +func sortedEntries(trie : TrieMap.TrieMap) : [(Nat, Text)] { + Array.sort(Iter.toArray(trie.entries()), compareByKey) +}; + +class TrieMatcher(expected : [(Nat, Text)]) : M.Matcher> { + public func describeMismatch(actual : TrieMap.TrieMap, description : M.Description) { + Prim.debugPrint(debug_show (sortedEntries(actual)) # " should be " # debug_show (expected)) + }; + + public func matches(actual : TrieMap.TrieMap) : Bool { + sortedEntries(actual) == expected + } +}; + +let randomTestSuite = Suite.suite( + "random trie", + [ + Suite.test( + "size", + buildTestTrie().size(), + M.equals(T.nat(testSize)) + ), + Suite.test( + "iterate entries", + sortedEntries(buildTestTrie()), + M.equals(T.array<(Nat, Text)>(entryTestable, expectedEntries)) + ), + Suite.test( + "iterate keys", + Array.sort(Iter.toArray(buildTestTrie().keys()), Nat.compare), + M.equals(T.array(T.natTestable, expectedKeys)) + ), + Suite.test( + "iterate values", + Array.sort(Iter.toArray(buildTestTrie().vals()), Text.compare), + M.equals(T.array(T.textTestable, expectedValues)) + ), + Suite.test( + "get all", + do { + let trie = buildTestTrie(); + for (key in testKeys.vals()) { + let value = trie.get(key); + assert (value == ?debug_show (key)) + }; + trie + }, + TrieMatcher(expectedEntries) + ), + Suite.test( + "replace all", + do { + let trie = buildTestTrie(); + for (key in testKeys.vals()) { + let value = trie.replace(key, "TEST-" # debug_show (key)); + assert (value == ?debug_show (key)) + }; + trie + }, + TrieMatcher(Array.map(expectedKeys, func(key) { (key, "TEST-" # debug_show (key)) })) + ), + Suite.test( + "remove randomized", + do { + let trie = buildTestTrie(); + var count = 0; + for (key in testKeys.vals()) { + if (Random.next() % 2 == 0) { + let result = trie.remove(key); + assert (result == ?debug_show (key)); + count += 1 + } + }; + trie.size() == +testKeys.size() - count + }, + M.equals(T.bool(true)) + ), + Suite.test( + "clear", + do { + let trie = buildTestTrie(); + for ((key, value) in trie.entries()) { + // stable iteration + assert (debug_show (key) == value); + let result = trie.remove(key); + assert (result == ?debug_show (key)) + }; + trie + }, + TrieMatcher([]) + ) + ] +); + +Suite.run(randomTestSuite); + +/* --------------------------------------- */ + debug { let a = TrieMap.TrieMap(Text.equal, Text.hash);