diff --git a/src/RBTree.mo b/src/RBTree.mo index cbb6834a..e019dda1 100644 --- a/src/RBTree.mo +++ b/src/RBTree.mo @@ -72,6 +72,12 @@ module { /// iterator is persistent, like the tree itself public func entriesRev() : I.Iter<(X, Y)> { iter(tree, #bwd) }; + + /// An iterator for the key-value entries of the map, in ascending key order + /// over the keys greater than or equal to x. + /// + /// iterator is persistent, like the tree itself + public func entriesTail(x : X) : I.Iter<(X, Y)> { iterTail(tree, x, compareTo) }; }; @@ -108,6 +114,47 @@ module { } }; + /// An iterator for the entries of the map, in ascending order over the keys + /// greater than or equal to x. + public func iterTail(t : Tree, x : X, compareTo : (X, X) -> O.Order) + : I.Iter<(X, Y)> { + object { + var trees : IterRep = ?(#tr(t), null); + public func next() : ?(X, Y) { + switch trees { + case null { null }; + case (?(#tr(#leaf), ts)) { + trees := ts; + next() + }; + case (?(#xy(xy), ts)) { + trees := ts; + switch (xy.1) { + case null { next() }; + case (?y) { ?(xy.0, y) } + } + }; + case (?(#tr(#node(_, l, xy, r)), ts)) { + switch (compareTo(x, xy.0)) { + case (#less) { + trees := ?(#tr(l), ?(#xy(xy), ?(#tr(r), ts))); + next() + }; + case (#equal) { + trees := ?(#xy(xy), ?(#tr(r), ts)); + next() + }; + case (#greater) { + trees := ?(#tr(r), ts); + next() + }; + } + }; + } + }; + } + }; + /// Remove the value associated with a given key. func removeRec(x : X, compareTo : (X, X) -> O.Order, t : Tree) : (?Y, Tree) { diff --git a/test/RBTreeTest.mo b/test/RBTreeTest.mo index f4e39859..82dc7e13 100644 --- a/test/RBTreeTest.mo +++ b/test/RBTreeTest.mo @@ -1,6 +1,6 @@ import Debug "mo:base/Debug"; import Nat "mo:base/Nat"; -import I "mo:base/Iter"; +import Iter "mo:base/Iter"; import List "mo:base/List"; import RBT "mo:base/RBTree"; @@ -56,6 +56,20 @@ for ((num, lab) in t.entriesRev()) { assert RBT.size(t.share()) == 9; +do { var i = 4; +for ((num, lab) in t.entriesTail(4)) { + assert(num == i); + i += 1; +}}; + +do { var i = 1; +for ((num, lab) in t.entriesTail(0)) { + assert(num == i); + i += 1; +}}; + +assert Iter.size(t.entriesTail(10)) == 0; + t.delete(5); assert RBT.size(t.share()) == 8;