diff --git a/src/certified/HashTree.mo b/src/certified/HashTree.mo index 97dd5e4..d26bca0 100644 --- a/src/certified/HashTree.mo +++ b/src/certified/HashTree.mo @@ -6,28 +6,28 @@ import CBOR "../cbor/CBOR"; import SHA256 "../crypto/SHA256"; module HashTree { - type Hash = [Nat8]; - type Key = [Nat8]; - type Value = [Nat8]; - - type HashTree = { - #empty; - #fork : (HashTree, HashTree); - #labeled : (Key, HashTree); - #leaf : Value; - #pruned : Hash; + public type Hash = [Nat8]; + public type Key = [Nat8]; + public type Value = [Nat8]; + + public type HashTree = { + #Empty; + #Fork : (HashTree, HashTree); + #Labeled : (Key, HashTree); + #Leaf : Value; + #Pruned : Hash; }; // Well-formed trees have the property that labeled subtrees appear in // strictly increasing order of labels, and are not mixed with leaves. public func wellFormed(t : HashTree) : Bool { switch (t) { - case (#empty or #leaf(_) or #pruned(_)) true; + case (#Empty or #Leaf(_) or #Pruned(_)) true; case (_) { var lbl = [] : [Nat8]; for (t in flatten(t)) switch (t) { - case (#leaf(_)) return false; - case (#labeled(l, t)) { + case (#Leaf(_)) return false; + case (#Labeled(l, t)) { if (not strictlyIncreasing(lbl, l)) return false; if (not wellFormed(t)) return false; lbl := l; @@ -55,11 +55,11 @@ module HashTree { public func next() : ?HashTree { switch (stack) { case (null) null; - case (?(#empty, r) or ?(#pruned(_), r)) { + case (?(#Empty, r) or ?(#Pruned(_), r)) { stack := r; next(); }; - case (?(#fork(left, right), r)) { + case (?(#Fork(left, right), r)) { stack := ?(left, ?(right, r)); next(); }; @@ -73,11 +73,11 @@ module HashTree { public func reconstruct(t : HashTree) : Hash { switch (t) { - case (#empty) { hash.empty() }; - case (#fork(l, r)) { hash.fork(reconstruct(l), reconstruct(r)) }; - case (#labeled(k, t)) { hash.labeled(k, reconstruct(t)) }; - case (#leaf(v)) { hash.leaf(v) }; - case (#pruned(h)) { h }; + case (#Empty) { hash.empty() }; + case (#Fork(l, r)) { hash.fork(reconstruct(l), reconstruct(r)) }; + case (#Labeled(k, t)) { hash.labeled(k, reconstruct(t)) }; + case (#Leaf(v)) { hash.leaf(v) }; + case (#Pruned(h)) { h }; }; }; @@ -116,11 +116,11 @@ module HashTree { private module cbor = { public func tree(t : HashTree) : CBOR.Value = switch (t) { - case (#empty) { empty() }; - case (#fork(l, r)) { fork(l, r) }; - case (#labeled(k, t)) { labeled(k, t) }; - case (#leaf(v)) { leaf(v) }; - case (#pruned(h)) { pruned(h) }; + case (#Empty) { empty() }; + case (#Fork(l, r)) { fork(l, r) }; + case (#Labeled(k, t)) { labeled(k, t) }; + case (#Leaf(v)) { leaf(v) }; + case (#Pruned(h)) { pruned(h) }; }; public func empty() : CBOR.Value = #Array([ diff --git a/src/certified/RBTree.mo b/src/certified/RBTree.mo new file mode 100644 index 0000000..7e2daae --- /dev/null +++ b/src/certified/RBTree.mo @@ -0,0 +1,202 @@ +import P "mo:base/Prelude"; + +import HashTree "HashTree"; + +module { + public type Node = ( + key : [Nat8], // 0 : key + value : [Nat8], // 1 : value + left : ?Node, // 2 : left + right : ?Node, // 3 : right + color : Color, // 4 : color + hash : HashTree.Hash, // 5 : hash + ); + + private type Color = { + #Red; + #Black; + }; + + private func flip(c : Color) : Color { + switch (c) { + case (#Red) { #Black }; + case (#Black) { #Red }; + }; + }; + + public func insertRoot(root : ?Node, k : [Nat8], v : [Nat8]) : (Node, ?[Nat8]) { + let ((nk, nv, l, r, _, h), ov) = insert(root, k, v); + ((nk, nv, l, r, #Black, h), ov); + }; + + public func insert(t : ?Node, k : [Nat8], v : [Nat8]) : (Node, ?[Nat8]) { + switch (t) { + case (null) { (newNode(k, v), null) }; + case (?n) { + let (nk, kv, l, r, c, h) = n; + let (nn, ov) : (Node, ?[Nat8]) = switch (compare(k, nk)) { + case (#less) { + let (nl, ov) = insert(l, k, v); + ((nk, kv, ?nl, r, c, h), ov); + }; + case (#equal) { + ((nk, v, l, r, c, h), ?kv); + }; + case (#greater) { + let (nr, ov) = insert(r, k, v); + ((nk, kv, l, ?nr, c, h), ov); + }; + }; + (balance(update(nn)), ov); + }; + }; + }; + + public func get(t : ?Node, k : [Nat8]) : ?[Nat8] { + var root = t; + label l loop { + let (key, v, l, r, _, _) = switch (root) { + case (null) { break l }; + case (?v) { v }; + }; + switch (compare(k, key)) { + case (#less) { + root := l; + }; + case (#equal) { + return ?v; + }; + case (#greater) { + root := r; + }; + }; + }; + null; + }; + + private func compare(xs : [Nat8], ys : [Nat8]) : { #less; #equal; #greater } { + if (xs.size() < ys.size()) return #less; + if (xs.size() > ys.size()) return #greater; + var i = 0; + while (i < xs.size()) { + let x = xs[i]; + let y = ys[i]; + if (x < y) return #less; + if (y < x) return #greater; + i += 1; + }; + #equal; + }; + + private func isRed(n : ?Node) : Bool { + switch (n) { + case (?(_, _, _, _, #Red, _)) { true }; + case (_) { false }; + }; + }; + + private func balance(n : Node) : Node { + switch (n) { + case (k, v, ?l, ?r, c, h) { + if (not isRed(?l) and isRed(?r)) return rotateLeft(n); + if (isRed(?l) and isRed(l.2)) return rotateRight(n); + if (isRed(?l) and isRed(?r)) return (k, v, ?flipColor(l), ?flipColor(r), flip(c), h); + }; + case (_) {}; + }; + n; + }; + + private func rotateRight(n : Node) : Node { + assert (isRed(n.2)); + var l = unwrap(n.2); + // n.l = n.l.r; + let h = update((n.0, n.1, l.3, n.3, n.4, n.5)); + // r.r = h; + // r.c = h.c; + // r.r.c = #Red; + update((l.0, l.1, l.2, ?(h.0, h.1, h.2, h.3, #Red, h.5), h.4, l.5)); + }; + + private func rotateLeft(n : Node) : Node { + assert (isRed(n.3)); + var r = unwrap(n.3); + // n.r = n.r.l; + let h = update((n.0, n.1, n.2, r.2, n.4, n.5)); + // r.l = h; + // r.c = h.c; + // r.l.c = #Red; + update((r.0, r.1, ?(h.0, h.1, h.2, h.3, #Red, h.5), r.3, h.4, r.5)); + }; + + private func flipColor((k, v, l, r, c, h) : Node) : Node { + (k, v, l, r, flip(c), h); + }; + + // NOTE: do use with caution! + private func unwrap(x : ?T) : T { + switch x { + case (null) { P.unreachable() }; + case (?x) { x }; + }; + }; + + // Returns a new node based on the given key and value. + public func newNode(key : [Nat8], value : [Nat8]) : Node { + (key, value, null, null, #Red, HashTree.reconstruct(#Labeled(key, #Leaf(value)))); + }; + + // Updates the hashes of the given node. + private func update(n : Node) : Node { + let (k, v, l, r, c, _) = n; + (k, v, l, r, c, subHashTree(n)); + }; + + private func subHashTree(n : Node) : HashTree.Hash { + let h = dataHash(n); + let (_, _, l, r, _, _) = n; + switch (l, r) { + case (null, null) h; + case (?l, null) HashTree.reconstruct(#Fork(#Pruned(l.5), #Pruned(h))); + case (null, ?r) HashTree.reconstruct(#Fork(#Pruned(h), #Pruned(r.5))); + case (?l, ?r) HashTree.reconstruct(#Fork(#Pruned(l.5), #Fork(#Pruned(h), #Pruned(r.5)))); + }; + }; + + // Returns the Hash corresponding to the node. + public func getHash(n : ?Node) : ?HashTree.Hash { + switch (n) { + case (?n) ?n.5; + case (null) null; + }; + }; + + // Returns the HashTree corresponding to the node. + public func getHashTree(n : ?Node) : HashTree.HashTree { + switch (n) { + case (null) #Empty; + case (?(v, k, null, null, _, _)) { + if (v.size() == 0) return #Leaf(k); + return #Labeled(v, #Leaf(k)); + }; + case (?(v, _, l, r, _, _)) { + if (v.size() == 0) return #Fork( + getHashTree(l), + getHashTree(r), + ); + return #Labeled( + v, + #Fork( + getHashTree(l), + getHashTree(r), + ), + ); + }; + }; + }; + + // Hashes the data contained within the node. + private func dataHash((k, v, _, _, _, _) : Node) : HashTree.Hash { + HashTree.reconstruct(#Labeled(k, #Leaf(v))); + }; +}; diff --git a/src/certified/README.md b/src/certified/README.md index 04c4636..5c32360 100644 --- a/src/certified/README.md +++ b/src/certified/README.md @@ -5,6 +5,8 @@ declaring a special "certified variable". [Read more...](https://internetcomputer.org/how-it-works/response-certification/) +## HashTree + ```motoko import HashTree "mo:core/certified/HashTree"; ``` @@ -18,3 +20,17 @@ module { wellFormed : HashTree -> Bool; }; ``` + +## RBTree + +```motoko +module { + type Node = ([Nat8], [Nat8], ?Node, ?Node, Color, Hash); + get : (?Node, [Nat8]) -> ?[Nat8]; + getHash : ?Node -> ?Hash; + getHashTree : ?Node -> HashTree; + insert : (?Node, [Nat8], [Nat8]) -> (Node, ?[Nat8]); + insertRoot : (?Node, [Nat8], [Nat8]) -> (Node, ?[Nat8]); + newNode : ([Nat8], [Nat8]) -> Node; +}; +``` diff --git a/test/certified/HashTree.mo b/test/certified/HashTree.mo index cedc27d..1f835d9 100644 --- a/test/certified/HashTree.mo +++ b/test/certified/HashTree.mo @@ -11,40 +11,40 @@ func xb(t : Text) : [Nat8] = switch (decode(t)) { }; }; -let prunedTree = #fork( - #fork( - #labeled(b("a"), #fork( - #pruned(xb("1b4feff9bef8131788b0c9dc6dbad6e81e524249c879e9f10f71ce3749f5a638")), - #labeled(b("y"), #leaf(b("world"))), +let prunedTree = #Fork( + #Fork( + #Labeled(b("a"), #Fork( + #Pruned(xb("1b4feff9bef8131788b0c9dc6dbad6e81e524249c879e9f10f71ce3749f5a638")), + #Labeled(b("y"), #Leaf(b("world"))), )), - #labeled(b("b"), #pruned(xb("7b32ac0c6ba8ce35ac82c255fc7906f7fc130dab2a090f80fe12f9c2cae83ba6"))), + #Labeled(b("b"), #Pruned(xb("7b32ac0c6ba8ce35ac82c255fc7906f7fc130dab2a090f80fe12f9c2cae83ba6"))), ), - #fork( - #pruned(xb("ec8324b8a1f1ac16bd2e806edba78006479c9877fed4eb464a25485465af601d")), - #labeled(b("d"), #leaf(b("morning"))), + #Fork( + #Pruned(xb("ec8324b8a1f1ac16bd2e806edba78006479c9877fed4eb464a25485465af601d")), + #Labeled(b("d"), #Leaf(b("morning"))), ), ); -let tree = #fork( - #fork( - #labeled(b("a"), #fork( - #fork( - #labeled(b("x"), #leaf(b("hello"))), - #empty, +let tree = #Fork( + #Fork( + #Labeled(b("a"), #Fork( + #Fork( + #Labeled(b("x"), #Leaf(b("hello"))), + #Empty, ), - #labeled(b("y"), #leaf(b("world"))), + #Labeled(b("y"), #Leaf(b("world"))), )), - #labeled(b("b"), #leaf(b("good"))), + #Labeled(b("b"), #Leaf(b("good"))), ), - #fork( - #labeled(b("c"), #empty), - #labeled(b("d"), #leaf(b("morning"))), + #Fork( + #Labeled(b("c"), #Empty), + #Labeled(b("d"), #Leaf(b("morning"))), ), ); assert(wellFormed(prunedTree)); assert(wellFormed(tree)); -assert(not wellFormed(#fork(#leaf(b("a")), #empty))); +assert(not wellFormed(#Fork(#Leaf(b("a")), #Empty))); assert(Hex.encode(reconstruct(prunedTree)) == "eb5c5b2195e62d996b84c9bcc8259d19a83786a2f59e0878cec84c811f669aa0"); assert(Hex.encode(reconstruct(prunedTree)) == Hex.encode(reconstruct(tree))); diff --git a/test/certified/RBTree.mo b/test/certified/RBTree.mo new file mode 100644 index 0000000..a5d1771 --- /dev/null +++ b/test/certified/RBTree.mo @@ -0,0 +1,85 @@ +import Nat8 "mo:base/Nat8"; + +import HashTree "mo:core/certified/HashTree"; +import RBTree "mo:core/certified/RBTree"; + +func isRed(n : ?RBTree.Node) : Bool { + switch (n) { + case (?(_, _, _, _, #Red, _)) { true }; + case (_) { false }; + }; +}; + +func isBalanced(t : ?RBTree.Node) : Bool { + func _isBalanced(n : ?RBTree.Node, nrBlack : Nat) : Bool { + var _nrBlack = nrBlack; + switch (n) { + case (null) { + _nrBlack == 0; + }; + case (?n) { + if (not isRed(?n)) { + _nrBlack -= 1; + } else { + assert (not isRed(n.2)); + assert (not isRed(n.3)); + }; + _isBalanced(n.2, _nrBlack) and _isBalanced(n.3, _nrBlack); + }; + }; + }; + + // Calculate number of black nodes by following left. + var nrBlack = 0; + var current = t; + label l loop { + switch (current) { + case (null) { break l }; + case (?n) { + if (not isRed(?n)) nrBlack += 1; + current := n.2; + }; + }; + }; + _isBalanced(t, nrBlack); +}; + +var tree : ?RBTree.Node = null; + +func insert(n : Nat8) { + let kv = [n]; + let (nt, ov) = RBTree.insertRoot(tree, kv, kv); + assert (ov == null); + assert (isBalanced(?nt)); + tree := ?nt; +}; + +insert(10); +insert(8); +insert(12); +insert(9); +insert(11); + +let ht = RBTree.getHashTree(tree); +assert (HashTree.wellFormed(ht)); +assert ( + ht == #Labeled( + [10], + #Fork( + #Labeled( + [8], + #Fork( + #Empty, + #Labeled([9], #Leaf([9])), + ), + ), + #Labeled( + [12], + #Fork( + #Labeled([11], #Leaf([11])), + #Empty, + ), + ), + ), + ) +);