Skip to content

Commit

Permalink
Add RBTree for certified data.
Browse files Browse the repository at this point in the history
  • Loading branch information
q-uint committed Apr 14, 2024
1 parent 746193f commit d6f3c31
Show file tree
Hide file tree
Showing 5 changed files with 349 additions and 46 deletions.
50 changes: 25 additions & 25 deletions src/certified/HashTree.mo
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,28 @@ import CBOR "../cbor/CBOR";
import SHA256 "../crypto/SHA256";

module HashTree {
type Hash = [Nat8];
type Key = [Nat8];
type Value = [Nat8];

type HashTree = {
#empty;
#fork : (HashTree, HashTree);
#labeled : (Key, HashTree);
#leaf : Value;
#pruned : Hash;
public type Hash = [Nat8];
public type Key = [Nat8];
public type Value = [Nat8];

public type HashTree = {
#Empty;
#Fork : (HashTree, HashTree);
#Labeled : (Key, HashTree);
#Leaf : Value;
#Pruned : Hash;
};

// Well-formed trees have the property that labeled subtrees appear in
// strictly increasing order of labels, and are not mixed with leaves.
public func wellFormed(t : HashTree) : Bool {
switch (t) {
case (#empty or #leaf(_) or #pruned(_)) true;
case (#Empty or #Leaf(_) or #Pruned(_)) true;
case (_) {
var lbl = [] : [Nat8];
for (t in flatten(t)) switch (t) {
case (#leaf(_)) return false;
case (#labeled(l, t)) {
case (#Leaf(_)) return false;
case (#Labeled(l, t)) {
if (not strictlyIncreasing(lbl, l)) return false;
if (not wellFormed(t)) return false;
lbl := l;
Expand Down Expand Up @@ -55,11 +55,11 @@ module HashTree {
public func next() : ?HashTree {
switch (stack) {
case (null) null;
case (?(#empty, r) or ?(#pruned(_), r)) {
case (?(#Empty, r) or ?(#Pruned(_), r)) {
stack := r;
next();
};
case (?(#fork(left, right), r)) {
case (?(#Fork(left, right), r)) {
stack := ?(left, ?(right, r));
next();
};
Expand All @@ -73,11 +73,11 @@ module HashTree {

public func reconstruct(t : HashTree) : Hash {
switch (t) {
case (#empty) { hash.empty() };
case (#fork(l, r)) { hash.fork(reconstruct(l), reconstruct(r)) };
case (#labeled(k, t)) { hash.labeled(k, reconstruct(t)) };
case (#leaf(v)) { hash.leaf(v) };
case (#pruned(h)) { h };
case (#Empty) { hash.empty() };
case (#Fork(l, r)) { hash.fork(reconstruct(l), reconstruct(r)) };
case (#Labeled(k, t)) { hash.labeled(k, reconstruct(t)) };
case (#Leaf(v)) { hash.leaf(v) };
case (#Pruned(h)) { h };
};
};

Expand Down Expand Up @@ -116,11 +116,11 @@ module HashTree {

private module cbor = {
public func tree(t : HashTree) : CBOR.Value = switch (t) {
case (#empty) { empty() };
case (#fork(l, r)) { fork(l, r) };
case (#labeled(k, t)) { labeled(k, t) };
case (#leaf(v)) { leaf(v) };
case (#pruned(h)) { pruned(h) };
case (#Empty) { empty() };
case (#Fork(l, r)) { fork(l, r) };
case (#Labeled(k, t)) { labeled(k, t) };
case (#Leaf(v)) { leaf(v) };
case (#Pruned(h)) { pruned(h) };
};

public func empty() : CBOR.Value = #Array([
Expand Down
202 changes: 202 additions & 0 deletions src/certified/RBTree.mo
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import P "mo:base/Prelude";

import HashTree "HashTree";

module {
public type Node = (
key : [Nat8], // 0 : key
value : [Nat8], // 1 : value
left : ?Node, // 2 : left
right : ?Node, // 3 : right
color : Color, // 4 : color
hash : HashTree.Hash, // 5 : hash
);

private type Color = {
#Red;
#Black;
};

private func flip(c : Color) : Color {
switch (c) {
case (#Red) { #Black };
case (#Black) { #Red };
};
};

public func insertRoot(root : ?Node, k : [Nat8], v : [Nat8]) : (Node, ?[Nat8]) {
let ((nk, nv, l, r, _, h), ov) = insert(root, k, v);
((nk, nv, l, r, #Black, h), ov);
};

public func insert(t : ?Node, k : [Nat8], v : [Nat8]) : (Node, ?[Nat8]) {
switch (t) {
case (null) { (newNode(k, v), null) };
case (?n) {
let (nk, kv, l, r, c, h) = n;
let (nn, ov) : (Node, ?[Nat8]) = switch (compare(k, nk)) {
case (#less) {
let (nl, ov) = insert(l, k, v);
((nk, kv, ?nl, r, c, h), ov);
};
case (#equal) {
((nk, v, l, r, c, h), ?kv);
};
case (#greater) {
let (nr, ov) = insert(r, k, v);
((nk, kv, l, ?nr, c, h), ov);
};
};
(balance(update(nn)), ov);
};
};
};

public func get(t : ?Node, k : [Nat8]) : ?[Nat8] {
var root = t;
label l loop {
let (key, v, l, r, _, _) = switch (root) {
case (null) { break l };
case (?v) { v };
};
switch (compare(k, key)) {
case (#less) {
root := l;
};
case (#equal) {
return ?v;
};
case (#greater) {
root := r;
};
};
};
null;
};

private func compare(xs : [Nat8], ys : [Nat8]) : { #less; #equal; #greater } {
if (xs.size() < ys.size()) return #less;
if (xs.size() > ys.size()) return #greater;
var i = 0;
while (i < xs.size()) {
let x = xs[i];
let y = ys[i];
if (x < y) return #less;
if (y < x) return #greater;
i += 1;
};
#equal;
};

private func isRed(n : ?Node) : Bool {
switch (n) {
case (?(_, _, _, _, #Red, _)) { true };
case (_) { false };
};
};

private func balance(n : Node) : Node {
switch (n) {
case (k, v, ?l, ?r, c, h) {
if (not isRed(?l) and isRed(?r)) return rotateLeft(n);
if (isRed(?l) and isRed(l.2)) return rotateRight(n);
if (isRed(?l) and isRed(?r)) return (k, v, ?flipColor(l), ?flipColor(r), flip(c), h);
};
case (_) {};
};
n;
};

private func rotateRight(n : Node) : Node {
assert (isRed(n.2));
var l = unwrap(n.2);
// n.l = n.l.r;
let h = update((n.0, n.1, l.3, n.3, n.4, n.5));
// r.r = h;
// r.c = h.c;
// r.r.c = #Red;
update((l.0, l.1, l.2, ?(h.0, h.1, h.2, h.3, #Red, h.5), h.4, l.5));
};

private func rotateLeft(n : Node) : Node {
assert (isRed(n.3));
var r = unwrap(n.3);
// n.r = n.r.l;
let h = update((n.0, n.1, n.2, r.2, n.4, n.5));
// r.l = h;
// r.c = h.c;
// r.l.c = #Red;
update((r.0, r.1, ?(h.0, h.1, h.2, h.3, #Red, h.5), r.3, h.4, r.5));
};

private func flipColor((k, v, l, r, c, h) : Node) : Node {
(k, v, l, r, flip(c), h);
};

// NOTE: do use with caution!
private func unwrap<T>(x : ?T) : T {
switch x {
case (null) { P.unreachable() };
case (?x) { x };
};
};

// Returns a new node based on the given key and value.
public func newNode(key : [Nat8], value : [Nat8]) : Node {
(key, value, null, null, #Red, HashTree.reconstruct(#Labeled(key, #Leaf(value))));
};

// Updates the hashes of the given node.
private func update(n : Node) : Node {
let (k, v, l, r, c, _) = n;
(k, v, l, r, c, subHashTree(n));
};

private func subHashTree(n : Node) : HashTree.Hash {
let h = dataHash(n);
let (_, _, l, r, _, _) = n;
switch (l, r) {
case (null, null) h;
case (?l, null) HashTree.reconstruct(#Fork(#Pruned(l.5), #Pruned(h)));
case (null, ?r) HashTree.reconstruct(#Fork(#Pruned(h), #Pruned(r.5)));
case (?l, ?r) HashTree.reconstruct(#Fork(#Pruned(l.5), #Fork(#Pruned(h), #Pruned(r.5))));
};
};

// Returns the Hash corresponding to the node.
public func getHash(n : ?Node) : ?HashTree.Hash {
switch (n) {
case (?n) ?n.5;
case (null) null;
};
};

// Returns the HashTree corresponding to the node.
public func getHashTree(n : ?Node) : HashTree.HashTree {
switch (n) {
case (null) #Empty;
case (?(v, k, null, null, _, _)) {
if (v.size() == 0) return #Leaf(k);
return #Labeled(v, #Leaf(k));
};
case (?(v, _, l, r, _, _)) {
if (v.size() == 0) return #Fork(
getHashTree(l),
getHashTree(r),
);
return #Labeled(
v,
#Fork(
getHashTree(l),
getHashTree(r),
),
);
};
};
};

// Hashes the data contained within the node.
private func dataHash((k, v, _, _, _, _) : Node) : HashTree.Hash {
HashTree.reconstruct(#Labeled(k, #Leaf(v)));
};
};
16 changes: 16 additions & 0 deletions src/certified/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ declaring a special "certified variable".

[Read more...](https://internetcomputer.org/how-it-works/response-certification/)

## HashTree

```motoko
import HashTree "mo:core/certified/HashTree";
```
Expand All @@ -18,3 +20,17 @@ module {
wellFormed : HashTree -> Bool;
};
```

## RBTree

```motoko
module {
type Node = ([Nat8], [Nat8], ?Node, ?Node, Color, Hash);
get : (?Node, [Nat8]) -> ?[Nat8];
getHash : ?Node -> ?Hash;
getHashTree : ?Node -> HashTree;
insert : (?Node, [Nat8], [Nat8]) -> (Node, ?[Nat8]);
insertRoot : (?Node, [Nat8], [Nat8]) -> (Node, ?[Nat8]);
newNode : ([Nat8], [Nat8]) -> Node;
};
```
42 changes: 21 additions & 21 deletions test/certified/HashTree.mo
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,40 @@ func xb(t : Text) : [Nat8] = switch (decode(t)) {
};
};

let prunedTree = #fork(
#fork(
#labeled(b("a"), #fork(
#pruned(xb("1b4feff9bef8131788b0c9dc6dbad6e81e524249c879e9f10f71ce3749f5a638")),
#labeled(b("y"), #leaf(b("world"))),
let prunedTree = #Fork(
#Fork(
#Labeled(b("a"), #Fork(
#Pruned(xb("1b4feff9bef8131788b0c9dc6dbad6e81e524249c879e9f10f71ce3749f5a638")),
#Labeled(b("y"), #Leaf(b("world"))),
)),
#labeled(b("b"), #pruned(xb("7b32ac0c6ba8ce35ac82c255fc7906f7fc130dab2a090f80fe12f9c2cae83ba6"))),
#Labeled(b("b"), #Pruned(xb("7b32ac0c6ba8ce35ac82c255fc7906f7fc130dab2a090f80fe12f9c2cae83ba6"))),
),
#fork(
#pruned(xb("ec8324b8a1f1ac16bd2e806edba78006479c9877fed4eb464a25485465af601d")),
#labeled(b("d"), #leaf(b("morning"))),
#Fork(
#Pruned(xb("ec8324b8a1f1ac16bd2e806edba78006479c9877fed4eb464a25485465af601d")),
#Labeled(b("d"), #Leaf(b("morning"))),
),
);

let tree = #fork(
#fork(
#labeled(b("a"), #fork(
#fork(
#labeled(b("x"), #leaf(b("hello"))),
#empty,
let tree = #Fork(
#Fork(
#Labeled(b("a"), #Fork(
#Fork(
#Labeled(b("x"), #Leaf(b("hello"))),
#Empty,
),
#labeled(b("y"), #leaf(b("world"))),
#Labeled(b("y"), #Leaf(b("world"))),
)),
#labeled(b("b"), #leaf(b("good"))),
#Labeled(b("b"), #Leaf(b("good"))),
),
#fork(
#labeled(b("c"), #empty),
#labeled(b("d"), #leaf(b("morning"))),
#Fork(
#Labeled(b("c"), #Empty),
#Labeled(b("d"), #Leaf(b("morning"))),
),
);

assert(wellFormed(prunedTree));
assert(wellFormed(tree));
assert(not wellFormed(#fork(#leaf(b("a")), #empty)));
assert(not wellFormed(#Fork(#Leaf(b("a")), #Empty)));

assert(Hex.encode(reconstruct(prunedTree)) == "eb5c5b2195e62d996b84c9bcc8259d19a83786a2f59e0878cec84c811f669aa0");
assert(Hex.encode(reconstruct(prunedTree)) == Hex.encode(reconstruct(tree)));
Expand Down
Loading

0 comments on commit d6f3c31

Please sign in to comment.