From cb28d543614c0a446184d62fbf7b3a05e38add40 Mon Sep 17 00:00:00 2001 From: Quint Daenen Date: Sat, 23 Mar 2024 17:13:31 +0100 Subject: [PATCH] Rewrite certificate package. --- README.md | 11 +- agent.go | 69 ++++-- agent_test.go | 8 +- certificate/lookup.go | 70 ------ certificate/lookup_test.go | 19 -- certificate/node_test.go | 152 ------------- {certificate => certification}/README.md | 0 {certificate => certification}/bls/bls.go | 0 .../bls/bls_test.go | 0 {certificate => certification}/certificate.go | 34 +-- .../certificate_test.go | 8 +- .../hashtree/hashtree.go | 12 +- certification/hashtree/hashtree_test.go | 124 +++++++++++ certification/hashtree/lookup.go | 207 ++++++++++++++++++ .../hashtree}/node.go | 2 +- certification/hashtree/node_test.go | 152 +++++++++++++ {certificate => certification}/http/README.md | 0 {certificate => certification}/http/agent.go | 0 .../http/agent_test.go | 7 +- .../http/certexp/header.go | 0 .../http/certexp/header_test.go | 0 .../http/certexp/ir.go | 0 {certificate => certification}/http/verify.go | 96 ++++---- {certificate => certification}/rootKey.go | 2 +- identity/anonymous.go | 16 +- identity/ed25519.go | 10 +- identity/ed25519_test.go | 22 +- identity/prime256v1.go | 80 +++---- identity/prime256v1_test.go | 22 +- identity/secp256k1.go | 4 +- identity/secp256k1_test.go | 22 +- mock/replica.go | 39 ++-- request.go | 22 +- request_test.go | 7 +- 34 files changed, 758 insertions(+), 459 deletions(-) delete mode 100644 certificate/lookup.go delete mode 100644 certificate/lookup_test.go delete mode 100644 certificate/node_test.go rename {certificate => certification}/README.md (100%) rename {certificate => certification}/bls/bls.go (100%) rename {certificate => certification}/bls/bls_test.go (100%) rename {certificate => certification}/certificate.go (79%) rename {certificate => certification}/certificate_test.go (94%) rename certificate/tree.go => certification/hashtree/hashtree.go (63%) create mode 100644 certification/hashtree/hashtree_test.go create mode 100644 certification/hashtree/lookup.go rename {certificate => certification/hashtree}/node.go (99%) create mode 100644 certification/hashtree/node_test.go rename {certificate => certification}/http/README.md (100%) rename {certificate => certification}/http/agent.go (100%) rename {certificate => certification}/http/agent_test.go (96%) rename {certificate => certification}/http/certexp/header.go (100%) rename {certificate => certification}/http/certexp/header_test.go (100%) rename {certificate => certification}/http/certexp/ir.go (100%) rename {certificate => certification}/http/verify.go (79%) rename {certificate => certification}/rootKey.go (96%) diff --git a/README.md b/README.md index 33e3377..e39079f 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Supported identities are `Ed25519` and `Secp256k1`. By default, the agent uses t ```go id, _ := identity.NewEd25519Identity(publicKey, privateKey) config := agent.Config{ - Identity: id, +Identity: id, } ``` @@ -65,8 +65,8 @@ If you are running a local replica, you can use the `FetchRootKey` option to fet ```go u, _ := url.Parse("http://localhost:8000") config := agent.Config{ - ClientConfig: &agent.ClientConfig{Host: u}, - FetchRootKey: true, +ClientConfig: &agent.ClientConfig{Host: u}, +FetchRootKey: true, } ``` @@ -103,3 +103,8 @@ installed then those tests will be ignored. ```shell go test -v ./... ``` + +## Reference Implementations + +- [Rust Agent](https://github.com/dfinity/agent-rs/) +- [JavaScript Agent](https://github.com/dfinity/agent-js/) diff --git a/agent.go b/agent.go index d73fd44..8243572 100644 --- a/agent.go +++ b/agent.go @@ -8,9 +8,11 @@ import ( "time" "github.com/aviate-labs/agent-go/candid/idl" - "github.com/aviate-labs/agent-go/certificate" + "github.com/aviate-labs/agent-go/certification" + "github.com/aviate-labs/agent-go/certification/hashtree" "github.com/aviate-labs/agent-go/identity" "github.com/aviate-labs/agent-go/principal" + "github.com/fxamacker/cbor/v2" ) @@ -63,7 +65,7 @@ func New(cfg Config) (*Agent, error) { ccfg = *cfg.ClientConfig } client := NewClient(ccfg) - rootKey, _ := hex.DecodeString(certificate.RootKey) + rootKey, _ := hex.DecodeString(certification.RootKey) if cfg.FetchRootKey { status, err := client.Status() if err != nil { @@ -129,8 +131,8 @@ func (a Agent) GetCanisterControllers(canisterID principal.Principal) ([]princip // GetCanisterInfo returns the raw certificate for the given canister based on the given sub-path. func (a Agent) GetCanisterInfo(canisterID principal.Principal, subPath string) ([]byte, error) { - path := [][]byte{[]byte("canister"), canisterID.Raw, []byte(subPath)} - c, err := a.readStateCertificate(canisterID, [][][]byte{path}) + path := []hashtree.Label{hashtree.Label("canister"), canisterID.Raw, hashtree.Label(subPath)} + c, err := a.readStateCertificate(canisterID, [][]hashtree.Label{path}) if err != nil { return nil, err } @@ -138,16 +140,20 @@ func (a Agent) GetCanisterInfo(canisterID principal.Principal, subPath string) ( if err := cbor.Unmarshal(c, &state); err != nil { return nil, err } - node, err := certificate.DeserializeNode(state["tree"].([]any)) + node, err := hashtree.DeserializeNode(state["tree"].([]any)) if err != nil { return nil, err } - return certificate.Lookup(path, node), nil + result := hashtree.NewHashTree(node).Lookup(path...) + if err := result.Found(); err != nil { + return nil, err + } + return result.Value, nil } func (a Agent) GetCanisterMetadata(canisterID principal.Principal, subPath string) ([]byte, error) { - path := [][]byte{[]byte("canister"), canisterID.Raw, []byte("metadata"), []byte(subPath)} - c, err := a.readStateCertificate(canisterID, [][][]byte{path}) + path := []hashtree.Label{hashtree.Label("canister"), canisterID.Raw, hashtree.Label("metadata"), hashtree.Label(subPath)} + c, err := a.readStateCertificate(canisterID, [][]hashtree.Label{path}) if err != nil { return nil, err } @@ -155,11 +161,15 @@ func (a Agent) GetCanisterMetadata(canisterID principal.Principal, subPath strin if err := cbor.Unmarshal(c, &state); err != nil { return nil, err } - node, err := certificate.DeserializeNode(state["tree"].([]any)) + node, err := hashtree.DeserializeNode(state["tree"].([]any)) if err != nil { return nil, err } - return certificate.Lookup(path, node), nil + result := hashtree.NewHashTree(node).Lookup(path...) + if err := result.Found(); err != nil { + return nil, err + } + return result.Value, nil } // GetCanisterModuleHash returns the module hash for the given canister. @@ -208,9 +218,9 @@ func (a Agent) Query(canisterID principal.Principal, methodName string, args []a } // RequestStatus returns the status of the request with the given ID. -func (a Agent) RequestStatus(canisterID principal.Principal, requestID RequestID) ([]byte, certificate.Node, error) { - path := [][]byte{[]byte("request_status"), requestID[:]} - c, err := a.readStateCertificate(canisterID, [][][]byte{path}) +func (a Agent) RequestStatus(canisterID principal.Principal, requestID RequestID) ([]byte, hashtree.Node, error) { + path := []hashtree.Label{hashtree.Label("request_status"), requestID[:]} + c, err := a.readStateCertificate(canisterID, [][]hashtree.Label{path}) if err != nil { return nil, nil, err } @@ -218,18 +228,22 @@ func (a Agent) RequestStatus(canisterID principal.Principal, requestID RequestID if err := cbor.Unmarshal(c, &state); err != nil { return nil, nil, err } - cert, err := certificate.New(canisterID, a.rootKey[len(a.rootKey)-96:], c) + cert, err := certification.New(canisterID, a.rootKey[len(a.rootKey)-96:], c) if err != nil { return nil, nil, err } if err := cert.Verify(); err != nil { return nil, nil, err } - node, err := certificate.DeserializeNode(state["tree"].([]any)) + node, err := hashtree.DeserializeNode(state["tree"].([]any)) if err != nil { return nil, nil, err } - return certificate.Lookup(append(path, []byte("status")), node), node, nil + result := hashtree.NewHashTree(node).Lookup(append(path, hashtree.Label("status"))...) + if err := result.Found(); err != nil { + return nil, nil, err + } + return result.Value, node, nil } // Sender returns the principal that is sending the requests. @@ -256,15 +270,24 @@ func (a Agent) poll(canisterID principal.Principal, requestID RequestID, delay, return nil, err } if len(data) != 0 { - path := [][]byte{[]byte("request_status"), requestID[:]} + path := []hashtree.Label{hashtree.Label("request_status"), requestID[:]} switch string(data) { case "rejected": - code := certificate.Lookup(append(path, []byte("reject_code")), node) - rejectMessage := certificate.Lookup(append(path, []byte("reject_message")), node) - return nil, fmt.Errorf("(%d) %s", uint64FromBytes(code), string(rejectMessage)) + tree := hashtree.NewHashTree(node) + codeResult := tree.Lookup(append(path, hashtree.Label("reject_code"))...) + messageResult := tree.Lookup(append(path, hashtree.Label("reject_message"))...) + if codeResult.Found() != nil || messageResult.Found() != nil { + return nil, fmt.Errorf("no reject code or message found") + } + return nil, fmt.Errorf("(%d) %s", uint64FromBytes(codeResult.Value), string(messageResult.Value)) case "replied": - path := [][]byte{[]byte("request_status"), requestID[:]} - return certificate.Lookup(append(path, []byte("reply")), node), nil + fmt.Println(node) + repliedResult := hashtree.NewHashTree(node).Lookup(append(path, hashtree.Label("reply"))...) + fmt.Println(repliedResult) + if repliedResult.Found() != nil { + return nil, fmt.Errorf("no reply found") + } + return repliedResult.Value, nil } } case <-timer.C: @@ -291,7 +314,7 @@ func (a Agent) readState(canisterID principal.Principal, data []byte) (map[strin return m, cbor.Unmarshal(resp, &m) } -func (a Agent) readStateCertificate(canisterID principal.Principal, paths [][][]byte) ([]byte, error) { +func (a Agent) readStateCertificate(canisterID principal.Principal, paths [][]hashtree.Label) ([]byte, error) { _, data, err := a.sign(Request{ Type: RequestTypeReadState, Sender: a.Sender(), diff --git a/agent_test.go b/agent_test.go index 67fd07c..37fb556 100644 --- a/agent_test.go +++ b/agent_test.go @@ -49,8 +49,8 @@ func Example_json() { // {"e8s":0} } -func Example_query_prime256v1() { - id, _ := identity.NewRandomPrime256v1Identity() +func Example_query_ed25519() { + id, _ := identity.NewRandomEd25519Identity() ledgerID, _ := principal.Decode("ryjl3-tyaaa-aaaaa-aaaba-cai") a, _ := agent.New(agent.Config{Identity: id}) var balance struct { @@ -64,8 +64,8 @@ func Example_query_prime256v1() { // 0 } -func Example_query_ed25519() { - id, _ := identity.NewRandomEd25519Identity() +func Example_query_prime256v1() { + id, _ := identity.NewRandomPrime256v1Identity() ledgerID, _ := principal.Decode("ryjl3-tyaaa-aaaaa-aaaba-cai") a, _ := agent.New(agent.Config{Identity: id}) var balance struct { diff --git a/certificate/lookup.go b/certificate/lookup.go deleted file mode 100644 index 1ff2b62..0000000 --- a/certificate/lookup.go +++ /dev/null @@ -1,70 +0,0 @@ -package certificate - -import ( - "bytes" -) - -// Lookup looks up the given path in the certificate tree. -func Lookup(path [][]byte, node Node) []byte { - if len(path) == 0 { - switch n := node.(type) { - case Leaf: - return n - default: - return nil - } - } - - n := findLabel(flattenNode(node), path[0]) - if n != nil { - return Lookup(path[1:], *n) - } - return nil -} - -// LookupPath returns a path from the given labels. -func LookupPath(p ...string) [][]byte { - var path [][]byte - for _, p := range p { - path = append(path, []byte(p)) - } - return path -} - -func LookupNode(path [][]byte, node Node) *Node { - if len(path) == 0 { - return &node - } - - n := findLabel(flattenNode(node), path[0]) - if n != nil { - return LookupNode(path[1:], *n) - } - return nil -} - -func findLabel(nodes []Node, label Label) *Node { - for _, node := range nodes { - switch n := node.(type) { - case Labeled: - if bytes.Equal(label, n.Label) { - return &n.Tree - } - } - } - return nil -} - -func flattenNode(node Node) []Node { - switch n := node.(type) { - case Empty: - return nil - case Fork: - return append( - flattenNode(n.LeftTree), - flattenNode(n.RightTree)..., - ) - default: - return []Node{node} - } -} diff --git a/certificate/lookup_test.go b/certificate/lookup_test.go deleted file mode 100644 index fed235f..0000000 --- a/certificate/lookup_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package certificate_test - -import ( - "fmt" - - "github.com/aviate-labs/agent-go/certificate" -) - -func ExampleLookup() { - fmt.Println(string(certificate.Lookup(certificate.LookupPath("a", "x"), tree))) - fmt.Println(string(certificate.Lookup(certificate.LookupPath("a", "y"), tree))) - fmt.Println(string(certificate.Lookup(certificate.LookupPath("b"), tree))) - fmt.Println(string(certificate.Lookup(certificate.LookupPath("d"), tree))) - // Output: - // hello - // world - // good - // morning -} diff --git a/certificate/node_test.go b/certificate/node_test.go deleted file mode 100644 index a349e76..0000000 --- a/certificate/node_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package certificate_test - -import ( - "encoding/hex" - "fmt" - "testing" - - cert "github.com/aviate-labs/agent-go/certificate" -) - -var pruned = cert.Fork{ - LeftTree: cert.Fork{ - LeftTree: cert.Labeled{ - Label: []byte("a"), - Tree: cert.Fork{ - LeftTree: cert.Pruned(h2b("1B4FEFF9BEF8131788B0C9DC6DBAD6E81E524249C879E9F10F71CE3749F5A638")), - RightTree: cert.Labeled{ - Label: []byte("y"), - Tree: cert.Leaf("world"), - }, - }, - }, - RightTree: cert.Labeled{ - Label: []byte("b"), - Tree: cert.Pruned(h2b("7B32AC0C6BA8CE35AC82C255FC7906F7FC130DAB2A090F80FE12F9C2CAE83BA6")), - }, - }, - RightTree: cert.Fork{ - LeftTree: cert.Pruned(h2b("EC8324B8A1F1AC16BD2E806EDBA78006479C9877FED4EB464A25485465AF601D")), - RightTree: cert.Labeled{ - Label: []byte("d"), - Tree: cert.Leaf("morning"), - }, - }, -} - -var tree = cert.Fork{ - LeftTree: cert.Fork{ - LeftTree: cert.Labeled{ - Label: []byte("a"), - Tree: cert.Fork{ - LeftTree: cert.Fork{ - LeftTree: cert.Labeled{ - Label: []byte("x"), - Tree: cert.Leaf("hello"), - }, - RightTree: cert.Empty{}, - }, - RightTree: cert.Labeled{ - Label: []byte("y"), - Tree: cert.Leaf("world"), - }, - }, - }, - RightTree: cert.Labeled{ - Label: []byte("b"), - Tree: cert.Leaf("good"), - }, - }, - RightTree: cert.Fork{ - LeftTree: cert.Labeled{ - Label: []byte("c"), - Tree: cert.Empty{}, - }, - RightTree: cert.Labeled{ - Label: []byte("d"), - Tree: cert.Leaf("morning"), - }, - }, -} - -func ExampleDeserialize() { - data, _ := hex.DecodeString("8301830183024161830183018302417882034568656c6c6f810083024179820345776f726c6483024162820344676f6f648301830241638100830241648203476d6f726e696e67") - fmt.Println(cert.Deserialize(data)) - // Output: - // {{a:{{x:hello|∅}|y:world}|b:good}|{c:∅|d:morning}} -} - -func ExamplePruned() { - fmt.Printf("%X", pruned.Reconstruct()) - // Output: - // EB5C5B2195E62D996B84C9BCC8259D19A83786A2F59E0878CEC84C811F669AA0 -} - -func ExampleSerialize() { - b, _ := cert.Serialize(tree) - fmt.Printf("%x", b) - // Output: - // 8301830183024161830183018302417882034568656c6c6f810083024179820345776f726c6483024162820344676f6f648301830241638100830241648203476d6f726e696e67 -} - -func Example_b() { - fmt.Printf("%X", cert.Leaf("good").Reconstruct()) - // Output: - // 7B32AC0C6BA8CE35AC82C255FC7906F7FC130DAB2A090F80FE12F9C2CAE83BA6 -} - -func Example_c() { - fmt.Printf("%X", cert.Labeled{ - Label: []byte("c"), - Tree: cert.Empty{}, - }.Reconstruct()) - // Output: - // EC8324B8A1F1AC16BD2E806EDBA78006479C9877FED4EB464A25485465AF601D -} - -func Example_root() { - fmt.Printf("%X", tree.Reconstruct()) - // Output: - // EB5C5B2195E62D996B84C9BCC8259D19A83786A2F59E0878CEC84C811F669AA0 -} - -// Source: https://sdk.dfinity.org/docs/interface-spec/index.html#_example -// ─┬─┬╴"a" ─┬─┬╴"x" ─╴"hello" -// -// │ │ │ └╴Empty -// │ │ └╴ "y" ─╴"world" -// │ └╴"b" ──╴"good" -// └─┬╴"c" ──╴Empty -// └╴"d" ──╴"morning" -func Example_x() { - fmt.Printf("%X", cert.Fork{ - LeftTree: cert.Labeled{ - Label: []byte("x"), - Tree: cert.Leaf("hello"), - }, - RightTree: cert.Empty{}, - }.Reconstruct()) - // Output: - // 1B4FEFF9BEF8131788B0C9DC6DBAD6E81E524249C879E9F10F71CE3749F5A638 -} - -func TestUFT8Leaf(t *testing.T) { - // []byte{0x90, 0xe4, 0xcf, 0xfc, 0xda, 0x94, 0x83, 0xec, 0x16} is the unsigned leb128 encoding of 1646079569558762000. - // Which is Mon Feb 28 2022 20:19:29 GMT+0000 in nanoseconds unix time. - l := cert.Leaf([]byte{0x90, 0xe4, 0xcf, 0xfc, 0xda, 0x94, 0x83, 0xec, 0x16}) - if l.String() != "0x90e4cffcda9483ec16" { - t.Error(l) - } - - s := cert.Leaf([]byte("some string")) - if s.String() != "some string" { - t.Error(l) - } -} - -func h2b(s string) [32]byte { - var bs [32]byte - b, _ := hex.DecodeString(s) - copy(bs[:], b) - return bs -} diff --git a/certificate/README.md b/certification/README.md similarity index 100% rename from certificate/README.md rename to certification/README.md diff --git a/certificate/bls/bls.go b/certification/bls/bls.go similarity index 100% rename from certificate/bls/bls.go rename to certification/bls/bls.go diff --git a/certificate/bls/bls_test.go b/certification/bls/bls_test.go similarity index 100% rename from certificate/bls/bls_test.go rename to certification/bls/bls_test.go diff --git a/certificate/certificate.go b/certification/certificate.go similarity index 79% rename from certificate/certificate.go rename to certification/certificate.go index 26bdee4..fd7d40d 100644 --- a/certificate/certificate.go +++ b/certification/certificate.go @@ -1,25 +1,28 @@ -package certificate +package certification import ( "fmt" - "github.com/aviate-labs/agent-go/certificate/bls" + "slices" + + "github.com/aviate-labs/agent-go/certification/bls" + "github.com/aviate-labs/agent-go/certification/hashtree" "github.com/aviate-labs/agent-go/principal" + "github.com/fxamacker/cbor/v2" - "slices" ) // Cert is a certificate gets returned by the IC. type Cert struct { // Tree is the certificate tree. - Tree HashTree `cbor:"tree"` + Tree hashtree.HashTree `cbor:"tree"` // Signature is the signature of the certificate tree. Signature []byte `cbor:"signature"` // Delegation is the delegation of the certificate. Delegation *Delegation `cbor:"delegation"` } -// Certificate is a certificate gets returned by the IC and can be used to verify -// the state root based on the root key and canister ID. +// Certificate is a certificate that gets returned by the IC and can be used to verify the state root based on the root +// key and canister ID. type Certificate struct { Cert Cert RootKey []byte @@ -50,7 +53,7 @@ func (c Certificate) Verify() error { return err } rootHash := c.Cert.Tree.Digest() - message := append(DomainSeparator("ic-state-root"), rootHash[:]...) + message := append(hashtree.DomainSeparator("ic-state-root"), rootHash[:]...) if !signature.Verify(publicKey, string(message)) { return fmt.Errorf("signature verification failed") } @@ -64,15 +67,14 @@ func (c Certificate) getPublicKey() (*bls.PublicKey, error) { } cert := c.Cert.Delegation - canisterRanges := Lookup( - LookupPath("subnet", string(cert.SubnetId.Raw), "canister_ranges"), - cert.Certificate.Cert.Tree.Root, + canisterRangesResult := cert.Certificate.Cert.Tree.Lookup( + hashtree.Label("subnet"), cert.SubnetId.Raw, hashtree.Label("canister_ranges"), ) - if canisterRanges == nil { + if canisterRangesResult.Found() != nil { return nil, fmt.Errorf("no canister ranges found for subnet %s", cert.SubnetId) } var rawRanges [][][]byte - if err := cbor.Unmarshal(canisterRanges, &rawRanges); err != nil { + if err := cbor.Unmarshal(canisterRangesResult.Value, &rawRanges); err != nil { return nil, err } @@ -90,14 +92,14 @@ func (c Certificate) getPublicKey() (*bls.PublicKey, error) { return nil, fmt.Errorf("canister %s is not in range", c.CanisterID) } - publicKey := Lookup( - LookupPath("subnet", string(cert.SubnetId.Raw), "public_key"), - cert.Certificate.Cert.Tree.Root, + publicKeyResult := cert.Certificate.Cert.Tree.Lookup( + hashtree.Label("subnet"), cert.SubnetId.Raw, hashtree.Label("public_key"), ) - if publicKey == nil { + if publicKeyResult.Found() != nil { return nil, fmt.Errorf("no public key found for subnet %s", cert.SubnetId) } + publicKey := publicKeyResult.Value if len(publicKey) != len(derPrefix)+96 { return nil, fmt.Errorf("invalid public key length: %d", len(publicKey)) } diff --git a/certificate/certificate_test.go b/certification/certificate_test.go similarity index 94% rename from certificate/certificate_test.go rename to certification/certificate_test.go index 01c7fc4..2e67205 100644 --- a/certificate/certificate_test.go +++ b/certification/certificate_test.go @@ -1,8 +1,8 @@ -package certificate_test +package certification_test import ( "encoding/hex" - "github.com/aviate-labs/agent-go/certificate" + "github.com/aviate-labs/agent-go/certification" "github.com/aviate-labs/agent-go/principal" "testing" ) @@ -16,11 +16,11 @@ func TestSampleCert(t *testing.T) { "00000000002FFFFF0101", } { t.Run(s, func(t *testing.T) { - c, err := certificate.New( + c, err := certification.New( principal.Principal{ Raw: hexToBytes(s), }, - hexToBytes(certificate.RootKey), + hexToBytes(certification.RootKey), hexToBytes(SampleCert), ) if err != nil { diff --git a/certificate/tree.go b/certification/hashtree/hashtree.go similarity index 63% rename from certificate/tree.go rename to certification/hashtree/hashtree.go index 90c9e5a..28a75f4 100644 --- a/certificate/tree.go +++ b/certification/hashtree/hashtree.go @@ -1,4 +1,4 @@ -package certificate +package hashtree // HashTree is a hash tree. type HashTree struct { @@ -15,6 +15,16 @@ func (t HashTree) Digest() [32]byte { return t.Root.Reconstruct() } +// Lookup looks up a path in the hash tree. +func (t HashTree) Lookup(path ...Label) LookupResult { + return lookupPath(t.Root, path...) +} + +// LookupSubTree looks up a path in the hash tree and returns the sub-tree. +func (t HashTree) LookupSubTree(path ...Label) LookupSubTreeResult { + return lookupSubTree(t.Root, path...) +} + // MarshalCBOR marshals a hash tree. func (t HashTree) MarshalCBOR() ([]byte, error) { return Serialize(t.Root) diff --git a/certification/hashtree/hashtree_test.go b/certification/hashtree/hashtree_test.go new file mode 100644 index 0000000..38a1e52 --- /dev/null +++ b/certification/hashtree/hashtree_test.go @@ -0,0 +1,124 @@ +package hashtree + +import ( + "bytes" + "encoding/hex" + "fmt" + "testing" +) + +func TestHashTree_Lookup(t *testing.T) { + t.Run("Empty Nodes", func(t *testing.T) { + tree := NewHashTree(Fork{ + LeftTree: Labeled{ + Label: Label("label 1"), + Tree: Empty{}, + }, + RightTree: Fork{ + LeftTree: Pruned{}, + RightTree: Fork{ + LeftTree: Labeled{ + Label: Label("label 3"), + Tree: Leaf{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}, + }, + RightTree: Labeled{ + Label: Label("label 5"), + Tree: Empty{}, + }, + }, + }, + }) + + for _, i := range []int{0, 1} { + if result := tree.Lookup(Label(fmt.Sprintf("label %d", i))); result.Type != LookupResultAbsent { + t.Fatalf("unexpected lookup result") + } + } + if result := tree.Lookup(Label("label 2")); result.Type != LookupResultUnknown { + t.Fatalf("unexpected lookup result") + } + if result := tree.Lookup(Label("label 3")); result.Type != LookupResultFound { + t.Fatalf("unexpected lookup result") + } else { + if !bytes.Equal(result.Value, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) { + t.Fatalf("unexpected node value") + } + } + for _, i := range []int{4, 5, 6} { + if result := tree.Lookup(Label(fmt.Sprintf("label %d", i))); result.Type != LookupResultAbsent { + t.Fatalf("unexpected lookup result") + } + } + }) + t.Run("Nil Nodes", func(t *testing.T) { + // let tree: HashTree> = fork( + // label("label 1", empty()), + // fork( + // fork( + // label("label 3", leaf(vec![1, 2, 3, 4, 5, 6])), + // label("label 5", empty()), + // ), + // pruned([1; 32]), + // ), + // ); + tree := NewHashTree(Fork{ + LeftTree: Labeled{ + Label: Label("label 1"), + }, + RightTree: Fork{ + LeftTree: Fork{ + LeftTree: Labeled{ + Label: Label("label 3"), + Tree: Leaf{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}, + }, + RightTree: Labeled{ + Label: Label("label 5"), + }, + }, + RightTree: Pruned{}, + }, + }) + for _, i := range []int{0, 1, 2} { + if result := tree.Lookup(Label(fmt.Sprintf("label %d", i))); result.Type != LookupResultAbsent { + t.Fatalf("unexpected lookup result") + } + } + if result := tree.Lookup(Label("label 3")); result.Type != LookupResultFound { + t.Fatalf("unexpected lookup result") + } else { + if !bytes.Equal(result.Value, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) { + t.Fatalf("unexpected node value") + } + } + for _, i := range []int{4, 5} { + if result := tree.Lookup(Label(fmt.Sprintf("label %d", i))); result.Type != LookupResultAbsent { + t.Fatalf("unexpected lookup result") + } + } + if result := tree.Lookup(Label("label 6")); result.Type != LookupResultUnknown { + t.Fatalf("unexpected lookup result") + } + }) +} + +func TestHashTree_simple(t *testing.T) { + tree := NewHashTree(Fork{ + LeftTree: Labeled{ + Label: Label("label 1"), + Tree: Empty{}, + }, + RightTree: Fork{ + LeftTree: Pruned{ + 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + }, + RightTree: Leaf{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}, + }, + }) + digest := tree.Digest() + if hex.EncodeToString(digest[:]) != "69cf325d0f20505b261821a7e77ff72fb9a8753a7964f0b587553bfb44e72532" { + t.Fatalf("unexpected digest: %x", digest) + } +} diff --git a/certification/hashtree/lookup.go b/certification/hashtree/lookup.go new file mode 100644 index 0000000..153d611 --- /dev/null +++ b/certification/hashtree/lookup.go @@ -0,0 +1,207 @@ +package hashtree + +import ( + "bytes" + "fmt" +) + +// LookupResult is the result of a lookup. +type LookupResult struct { + // Type is the type of the lookup result. + Type LookupResultType + // Value is the value of the leaf. Can be nil if the type is not LookupResultFound. + Value []byte +} + +func lookupPath(n Node, path ...Label) LookupResult { + switch { + case len(path) == 0: + switch n := n.(type) { + case Leaf: + return LookupResult{ + Type: LookupResultFound, + Value: n, + } + case nil, Empty: + return LookupResult{ + Type: LookupResultAbsent, + } + case Pruned: + return LookupResult{ + Type: LookupResultUnknown, + } + default: + // Labeled, Fork + return LookupResult{ + Type: LookupResultError, + } + } + default: + switch l := lookupLabel(n, path[0]); l.Type { + case lookupLabelResultFound: + return lookupPath(l.Node, path[1:]...) + case lookupLabelResultUnknown: + return LookupResult{ + Type: LookupResultUnknown, + } + default: + return LookupResult{ + Type: LookupResultAbsent, + } + } + } +} + +// Found returns an error if the lookup result is not found. +func (r LookupResult) Found() error { + switch r.Type { + case LookupResultAbsent: + return fmt.Errorf("not found") + case LookupResultUnknown: + return fmt.Errorf("unknown") + case LookupResultError: + return fmt.Errorf("error") + default: + return nil + } +} + +// LookupResultType is the type of the lookup result. +// It indicates whether the result is guaranteed to be absent, unknown or found. +type LookupResultType int + +const ( + // LookupResultAbsent means that the result is guaranteed to be absent. + LookupResultAbsent LookupResultType = iota + // LookupResultUnknown means that the result is unknown, some leaves were pruned. + LookupResultUnknown + // LookupResultFound means that the result is found. + LookupResultFound + // LookupResultError means that the result is an error, the path is not valid in this context. + LookupResultError +) + +// LookupSubTreeResult is the result of a lookup sub-tree. +type LookupSubTreeResult struct { + // Type is the type of the lookup sub-tree result. + Type LookupResultType + // Node is the node that was found. Can be nil if the type is not LookupResultFound. + Node Node +} + +func lookupSubTree(n Node, path ...Label) LookupSubTreeResult { + switch { + case len(path) == 0: + return LookupSubTreeResult{ + Type: LookupResultFound, + Node: n, + } + default: + switch l := lookupLabel(n, path[0]); l.Type { + case lookupLabelResultFound: + return lookupSubTree(l.Node, path[1:]...) + case lookupLabelResultUnknown: + return LookupSubTreeResult{ + Type: LookupResultUnknown, + } + default: + return LookupSubTreeResult{ + Type: LookupResultAbsent, + } + } + } +} + +// Found returns an error if the lookup sub-tree result is not found. +func (r LookupSubTreeResult) Found() error { + switch r.Type { + case LookupResultAbsent: + return fmt.Errorf("not found") + case LookupResultUnknown: + return fmt.Errorf("unknown") + case LookupResultError: + return fmt.Errorf("error") + default: + return nil + } +} + +// lookupLabelResult is the result of a lookup label. +type lookupLabelResult struct { + // Type is the type of the lookup label result. + Type lookupLabelResultType + // Node is the node that was found. Can be nil. + Node Node +} + +func lookupLabel(n Node, label Label) lookupLabelResult { + switch n := n.(type) { + case Labeled: + c := bytes.Compare(label, n.Label) + switch { + case c < 0: + return lookupLabelResult{ + Type: lookupLabelResultLess, + } + case c > 0: + return lookupLabelResult{ + Type: lookupLabelResultGreater, + } + default: + return lookupLabelResult{ + Type: lookupLabelResultFound, + Node: n.Tree, + } + } + case Pruned: + return lookupLabelResult{ + Type: lookupLabelResultUnknown, + } + case Fork: + switch ll := lookupLabel(n.LeftTree, label); ll.Type { + case lookupLabelResultGreater: + // Continue looking in the right tree. + switch rl := lookupLabel(n.RightTree, label); rl.Type { + case lookupLabelResultLess: + return lookupLabelResult{ + Type: lookupLabelResultAbsent, + } + default: + return rl + } + case lookupLabelResultUnknown: + // Continue looking in the right tree. + switch rl := lookupLabel(n.RightTree, label); rl.Type { + case lookupLabelResultLess: + return lookupLabelResult{ + Type: lookupLabelResultUnknown, + } + default: + return rl + } + default: + return ll + } + default: + return lookupLabelResult{ + Type: lookupLabelResultAbsent, + } + } +} + +// lookupLabelResultType is the type of the lookup label result. +// It indicates whether the label is guaranteed to be absent, unknown, less, greater or found. +type lookupLabelResultType int + +const ( + // lookupLabelResultAbsent means that the label is absent. + lookupLabelResultAbsent lookupLabelResultType = iota + // lookupLabelResultUnknown means that the label is unknown, some leaves were pruned. + lookupLabelResultUnknown + // lookupLabelResultLess means that the label was not found, but could be on the left side. + lookupLabelResultLess + // lookupLabelResultGreater means that the label was not found, but could be on the right side. + lookupLabelResultGreater + // lookupLabelResultFound means that the label was found. + lookupLabelResultFound +) diff --git a/certificate/node.go b/certification/hashtree/node.go similarity index 99% rename from certificate/node.go rename to certification/hashtree/node.go index d8414f0..84b6abb 100644 --- a/certificate/node.go +++ b/certification/hashtree/node.go @@ -1,4 +1,4 @@ -package certificate +package hashtree import ( "crypto/sha256" diff --git a/certification/hashtree/node_test.go b/certification/hashtree/node_test.go new file mode 100644 index 0000000..8ff5c24 --- /dev/null +++ b/certification/hashtree/node_test.go @@ -0,0 +1,152 @@ +package hashtree_test + +import ( + "encoding/hex" + "fmt" + "testing" + + "github.com/aviate-labs/agent-go/certification/hashtree" +) + +var pruned = hashtree.Fork{ + LeftTree: hashtree.Fork{ + LeftTree: hashtree.Labeled{ + Label: []byte("a"), + Tree: hashtree.Fork{ + LeftTree: hashtree.Pruned(h2b("1b4feff9bef8131788b0c9dc6dbad6e81e524249c879e9f10f71ce3749f5a638")), + RightTree: hashtree.Labeled{ + Label: []byte("y"), + Tree: hashtree.Leaf("world"), + }, + }, + }, + RightTree: hashtree.Labeled{ + Label: []byte("b"), + Tree: hashtree.Pruned(h2b("7b32ac0c6ba8ce35ac82c255fc7906f7fc130dab2a090f80fe12f9c2cae83ba6")), + }, + }, + RightTree: hashtree.Fork{ + LeftTree: hashtree.Pruned(h2b("ec8324b8a1f1ac16bd2e806edba78006479c9877fed4eb464a25485465af601d")), + RightTree: hashtree.Labeled{ + Label: []byte("d"), + Tree: hashtree.Leaf("morning"), + }, + }, +} + +var tree = hashtree.Fork{ + LeftTree: hashtree.Fork{ + LeftTree: hashtree.Labeled{ + Label: []byte("a"), + Tree: hashtree.Fork{ + LeftTree: hashtree.Fork{ + LeftTree: hashtree.Labeled{ + Label: []byte("x"), + Tree: hashtree.Leaf("hello"), + }, + RightTree: hashtree.Empty{}, + }, + RightTree: hashtree.Labeled{ + Label: []byte("y"), + Tree: hashtree.Leaf("world"), + }, + }, + }, + RightTree: hashtree.Labeled{ + Label: []byte("b"), + Tree: hashtree.Leaf("good"), + }, + }, + RightTree: hashtree.Fork{ + LeftTree: hashtree.Labeled{ + Label: []byte("c"), + Tree: hashtree.Empty{}, + }, + RightTree: hashtree.Labeled{ + Label: []byte("d"), + Tree: hashtree.Leaf("morning"), + }, + }, +} + +func ExampleDeserialize() { + data, _ := hex.DecodeString("8301830183024161830183018302417882034568656c6c6f810083024179820345776f726c6483024162820344676f6f648301830241638100830241648203476d6f726e696e67") + fmt.Println(hashtree.Deserialize(data)) + // Output: + // {{a:{{x:hello|∅}|y:world}|b:good}|{c:∅|d:morning}} +} + +func ExamplePruned() { + fmt.Printf("%x", pruned.Reconstruct()) + // Output: + // eb5c5b2195e62d996b84c9bcc8259d19a83786a2f59e0878cec84c811f669aa0 +} + +func ExampleSerialize() { + b, _ := hashtree.Serialize(tree) + fmt.Printf("%x", b) + // Output: + // 8301830183024161830183018302417882034568656c6c6f810083024179820345776f726c6483024162820344676f6f648301830241638100830241648203476d6f726e696e67 +} + +func Example_b() { + fmt.Printf("%x", hashtree.Leaf("good").Reconstruct()) + // Output: + // 7b32ac0c6ba8ce35ac82c255fc7906f7fc130dab2a090f80fe12f9c2cae83ba6 +} + +func Example_c() { + fmt.Printf("%x", hashtree.Labeled{ + Label: []byte("c"), + Tree: hashtree.Empty{}, + }.Reconstruct()) + // Output: + // ec8324b8a1f1ac16bd2e806edba78006479c9877fed4eb464a25485465af601d +} + +func Example_root() { + fmt.Printf("%x", tree.Reconstruct()) + // Output: + // eb5c5b2195e62d996b84c9bcc8259d19a83786a2f59e0878cec84c811f669aa0 +} + +// Source: https://sdk.dfinity.org/docs/interface-spec/index.html#_example +// ─┬─┬╴"a" ─┬─┬╴"x" ─╴"hello" +// +// │ │ │ └╴Empty +// │ │ └╴ "y" ─╴"world" +// │ └╴"b" ──╴"good" +// └─┬╴"c" ──╴Empty +// └╴"d" ──╴"morning" +func Example_x() { + fmt.Printf("%x", hashtree.Fork{ + LeftTree: hashtree.Labeled{ + Label: []byte("x"), + Tree: hashtree.Leaf("hello"), + }, + RightTree: hashtree.Empty{}, + }.Reconstruct()) + // Output: + // 1b4feff9bef8131788b0c9dc6dbad6e81e524249c879e9f10f71ce3749f5a638 +} + +func TestUFT8Leaf(t *testing.T) { + // []byte{0x90, 0xe4, 0xcf, 0xfc, 0xda, 0x94, 0x83, 0xec, 0x16} is the unsigned leb128 encoding of 1646079569558762000. + // Which is Mon Feb 28 2022 20:19:29 GMT+0000 in nanoseconds unix time. + l := hashtree.Leaf([]byte{0x90, 0xe4, 0xcf, 0xfc, 0xda, 0x94, 0x83, 0xec, 0x16}) + if l.String() != "0x90e4cffcda9483ec16" { + t.Error(l) + } + + s := hashtree.Leaf("some string") + if s.String() != "some string" { + t.Error(l) + } +} + +func h2b(s string) [32]byte { + var bs [32]byte + b, _ := hex.DecodeString(s) + copy(bs[:], b) + return bs +} diff --git a/certificate/http/README.md b/certification/http/README.md similarity index 100% rename from certificate/http/README.md rename to certification/http/README.md diff --git a/certificate/http/agent.go b/certification/http/agent.go similarity index 100% rename from certificate/http/agent.go rename to certification/http/agent.go diff --git a/certificate/http/agent_test.go b/certification/http/agent_test.go similarity index 96% rename from certificate/http/agent_test.go rename to certification/http/agent_test.go index 7f63584..2d5ba46 100644 --- a/certificate/http/agent_test.go +++ b/certification/http/agent_test.go @@ -2,11 +2,12 @@ package http_test import ( "fmt" + "testing" + "github.com/aviate-labs/agent-go" - "github.com/aviate-labs/agent-go/certificate/http" - "github.com/aviate-labs/agent-go/certificate/http/certexp" + "github.com/aviate-labs/agent-go/certification/http" + "github.com/aviate-labs/agent-go/certification/http/certexp" "github.com/aviate-labs/agent-go/principal" - "testing" ) func TestAgent_HttpRequest(t *testing.T) { diff --git a/certificate/http/certexp/header.go b/certification/http/certexp/header.go similarity index 100% rename from certificate/http/certexp/header.go rename to certification/http/certexp/header.go diff --git a/certificate/http/certexp/header_test.go b/certification/http/certexp/header_test.go similarity index 100% rename from certificate/http/certexp/header_test.go rename to certification/http/certexp/header_test.go diff --git a/certificate/http/certexp/ir.go b/certification/http/certexp/ir.go similarity index 100% rename from certificate/http/certexp/ir.go rename to certification/http/certexp/ir.go diff --git a/certificate/http/verify.go b/certification/http/verify.go similarity index 79% rename from certificate/http/verify.go rename to certification/http/verify.go index 1c84204..577d3b9 100644 --- a/certificate/http/verify.go +++ b/certification/http/verify.go @@ -5,16 +5,19 @@ import ( "crypto/sha256" "encoding/base64" "fmt" - "github.com/aviate-labs/agent-go/certificate" - "github.com/aviate-labs/agent-go/certificate/http/certexp" - "github.com/aviate-labs/leb128" - "github.com/fxamacker/cbor/v2" + "github.com/aviate-labs/agent-go/certification" "math/big" "net/url" "slices" "strconv" "strings" "time" + + "github.com/aviate-labs/agent-go/certification/hashtree" + "github.com/aviate-labs/agent-go/certification/http/certexp" + "github.com/aviate-labs/leb128" + + "github.com/fxamacker/cbor/v2" ) func CalculateRequestHash(r *Request, reqCert *certexp.CertificateExpressionRequestCertification) ([32]byte, error) { @@ -167,7 +170,7 @@ func (a Agent) VerifyResponse(path string, req *Request, resp *Response) error { } // Validate the certificate. - if err := (certificate.Certificate{ + if err := (certification.Certificate{ Cert: certificateHeader.Certificate, RootKey: a.GetRootKey(), CanisterID: a.canisterId, @@ -176,8 +179,11 @@ func (a Agent) VerifyResponse(path string, req *Request, resp *Response) error { } // The timestamp at the /time path must be recent, e.g. 5 minutes. - rawTime := certificate.Lookup(certificate.LookupPath("time"), certificateHeader.Certificate.Tree.Root) - t, err := leb128.DecodeUnsigned(bytes.NewReader(rawTime)) + rawTimeResult := certificateHeader.Certificate.Tree.Lookup(hashtree.Label("time")) + if rawTimeResult.Found() != nil { + return fmt.Errorf("no time found") + } + t, err := leb128.DecodeUnsigned(bytes.NewReader(rawTimeResult.Value)) if err != nil { return err } @@ -220,13 +226,13 @@ func (a *Agent) verify(req *Request, resp *Response, certificateHeader *Certific return err } - exprPathNode := certificate.LookupNode(certificate.LookupPath(exprPath.GetPath()...), certificateHeader.Tree.Root) - if exprPathNode == nil { + exprPathNodeResult := certificateHeader.Tree.LookupSubTree(exprPath.GetPath()...) + if exprPathNodeResult.Found() != nil { return fmt.Errorf("no expression path found") } - var exprHash certificate.Labeled - switch n := (*exprPathNode).(type) { - case certificate.Labeled: + var exprHash hashtree.Labeled + switch n := (exprPathNodeResult.Node).(type) { + case hashtree.Labeled: exprHash = n certExprHash := sha256.Sum256([]byte(certificateExpression)) if !bytes.Equal(exprHash.Label, certExprHash[:]) { @@ -245,12 +251,12 @@ func (a *Agent) verify(req *Request, resp *Response, certificateHeader *Certific return err } if certExpr.Certification.RequestCertification == nil { - n := certificate.LookupNode(certificate.LookupPath("", string(respHash[:])), exprHash.Tree) - if n == nil { + nResult := hashtree.NewHashTree(exprHash.Tree).LookupSubTree(hashtree.Label(""), respHash[:]) + if nResult.Found() != nil { return fmt.Errorf("response hash not found") } - switch n := (*n).(type) { - case certificate.Leaf: + switch n := (nResult.Node).(type) { + case hashtree.Leaf: if len(n) != 0 { return fmt.Errorf("invalid response hash: not empty") } @@ -263,12 +269,12 @@ func (a *Agent) verify(req *Request, resp *Response, certificateHeader *Certific if err != nil { return err } - n := certificate.LookupNode(certificate.LookupPath(string(reqHash[:]), string(respHash[:])), exprHash.Tree) - if n == nil { + nResult := hashtree.NewHashTree(exprHash.Tree).LookupSubTree(reqHash[:], respHash[:]) + if nResult.Found() != nil { return fmt.Errorf("response hash not found") } - switch n := (*n).(type) { - case certificate.Leaf: + switch n := (nResult.Node).(type) { + case hashtree.Leaf: if len(n) != 0 { return fmt.Errorf("invalid response hash: not empty") } @@ -284,22 +290,22 @@ func (a *Agent) verifyLegacy(path string, hash [32]byte, certificateHeader *Cert return fmt.Errorf("certificate version 2 is supported") } - witness := certificate.Lookup(certificate.LookupPath("canister", string(a.canisterId.Raw), "certified_data"), certificateHeader.Certificate.Tree.Root) - if len(witness) != 32 { + witnessResult := certificateHeader.Certificate.Tree.Lookup(hashtree.Label("canister"), a.canisterId.Raw, hashtree.Label("certified_data")) + if witnessResult.Found() != nil || len(witnessResult.Value) != 32 { return fmt.Errorf("no witness found") } reconstruct := certificateHeader.Tree.Root.Reconstruct() - if !bytes.Equal(witness, reconstruct[:32]) { + if !bytes.Equal(witnessResult.Value, reconstruct[:32]) { return fmt.Errorf("invalid witness") } - treeHash := certificate.Lookup(certificate.LookupPath("http_assets", path), certificateHeader.Tree.Root) - if len(treeHash) == 0 { - treeHash = certificate.Lookup(certificate.LookupPath("http_assets"), certificateHeader.Tree.Root) + treeHashResult := certificateHeader.Tree.Lookup(hashtree.Label("http_assets"), hashtree.Label(path)) + if treeHashResult.Found() != nil || len(treeHashResult.Value) == 0 { + treeHashResult = certificateHeader.Tree.Lookup(hashtree.Label("http_assets")) } - if !bytes.Equal(hash[:], treeHash) { + if treeHashResult.Found() != nil || !bytes.Equal(hash[:], treeHashResult.Value) { return fmt.Errorf("invalid hash") } @@ -307,10 +313,10 @@ func (a *Agent) verifyLegacy(path string, hash [32]byte, certificateHeader *Cert } type CertificateHeader struct { - Certificate certificate.Cert - Tree certificate.HashTree + Certificate certification.Cert + Tree hashtree.HashTree Version int - ExprPath []string + ExprPath []hashtree.Label } func ParseCertificateHeader(header string) (*CertificateHeader, error) { @@ -326,7 +332,7 @@ func ParseCertificateHeader(header string) (*CertificateHeader, error) { if err != nil { return nil, err } - var cert certificate.Cert + var cert certification.Cert if err := cbor.Unmarshal(raw, &cert); err != nil { return nil, err } @@ -336,7 +342,7 @@ func ParseCertificateHeader(header string) (*CertificateHeader, error) { if err != nil { return nil, err } - var tree certificate.HashTree + var tree hashtree.HashTree if err := cbor.Unmarshal(raw, &tree); err != nil { return nil, err } @@ -352,10 +358,14 @@ func ParseCertificateHeader(header string) (*CertificateHeader, error) { if err != nil { return nil, err } - var path []string - if err := cbor.Unmarshal(raw, &path); err != nil { + var strPath []string + if err := cbor.Unmarshal(raw, &strPath); err != nil { return nil, err } + var path []hashtree.Label + for _, s := range strPath { + path = append(path, hashtree.Label(s)) + } certificateHeader.ExprPath = path default: return nil, fmt.Errorf("invalid header") @@ -366,15 +376,15 @@ func ParseCertificateHeader(header string) (*CertificateHeader, error) { type ExpressionPath struct { Wildcard bool - Path []string + Path []hashtree.Label } -func ParseExpressionPath(path []string) (*ExpressionPath, error) { - if len(path) < 2 || path[0] != "http_expr" { +func ParseExpressionPath(path []hashtree.Label) (*ExpressionPath, error) { + if len(path) < 2 || !bytes.Equal(path[0], hashtree.Label("http_expr")) { return nil, fmt.Errorf("invalid expression path") } var wilcard bool - switch path[len(path)-1] { + switch string(path[len(path)-1]) { case "<*>": wilcard = true case "<$>": @@ -387,14 +397,14 @@ func ParseExpressionPath(path []string) (*ExpressionPath, error) { }, nil } -func (e ExpressionPath) GetPath() []string { - path := make([]string, len(e.Path)+2) +func (e ExpressionPath) GetPath() []hashtree.Label { + path := make([]hashtree.Label, len(e.Path)+2) copy(path[1:], e.Path) - path[0] = "http_expr" + path[0] = hashtree.Label("http_expr") if e.Wildcard { - path[len(path)-1] = "<*>" + path[len(path)-1] = hashtree.Label("<*>") } else { - path[len(path)-1] = "<$>" + path[len(path)-1] = hashtree.Label("<$>") } return path } diff --git a/certificate/rootKey.go b/certification/rootKey.go similarity index 96% rename from certificate/rootKey.go rename to certification/rootKey.go index 0c17982..36314cd 100644 --- a/certificate/rootKey.go +++ b/certification/rootKey.go @@ -1,4 +1,4 @@ -package certificate +package certification import "encoding/hex" diff --git a/identity/anonymous.go b/identity/anonymous.go index 1b57bd4..5d8d3e9 100644 --- a/identity/anonymous.go +++ b/identity/anonymous.go @@ -7,14 +7,6 @@ import ( // AnonymousIdentity is an identity that does not sign messages. type AnonymousIdentity struct{} -func (id AnonymousIdentity) Verify(_, _ []byte) bool { - return true -} - -func (id AnonymousIdentity) ToPEM() ([]byte, error) { - return nil, nil -} - // PublicKey returns the public key of the identity. func (id AnonymousIdentity) PublicKey() []byte { return nil @@ -29,3 +21,11 @@ func (id AnonymousIdentity) Sender() principal.Principal { func (id AnonymousIdentity) Sign(_ []byte) []byte { return nil } + +func (id AnonymousIdentity) ToPEM() ([]byte, error) { + return nil, nil +} + +func (id AnonymousIdentity) Verify(_, _ []byte) bool { + return true +} diff --git a/identity/ed25519.go b/identity/ed25519.go index 9813bc4..ba6a86e 100644 --- a/identity/ed25519.go +++ b/identity/ed25519.go @@ -85,11 +85,6 @@ func (id Ed25519Identity) Sign(data []byte) []byte { return ed25519.Sign(id.privateKey, data) } -// Verify verifies the given signature. -func (id Ed25519Identity) Verify(data, signature []byte) bool { - return ed25519.Verify(id.publicKey, data, signature) -} - // ToPEM returns the PEM representation of the identity. func (id Ed25519Identity) ToPEM() ([]byte, error) { data, err := x509.MarshalPKCS8PrivateKey(id.privateKey) @@ -101,3 +96,8 @@ func (id Ed25519Identity) ToPEM() ([]byte, error) { Bytes: data, }), nil } + +// Verify verifies the given signature. +func (id Ed25519Identity) Verify(data, signature []byte) bool { + return ed25519.Verify(id.publicKey, data, signature) +} diff --git a/identity/ed25519_test.go b/identity/ed25519_test.go index cf315f3..cca16a5 100644 --- a/identity/ed25519_test.go +++ b/identity/ed25519_test.go @@ -5,6 +5,17 @@ import ( "testing" ) +func TestEd25519Identity_Sign(t *testing.T) { + id, err := NewRandomEd25519Identity() + if err != nil { + t.Fatal(err) + } + data := []byte("hello") + if !id.Verify(data, id.Sign(data)) { + t.Error() + } +} + func TestNewEd25519Identity(t *testing.T) { id, _ := NewRandomEd25519Identity() data, err := id.ToPEM() @@ -22,14 +33,3 @@ func TestNewEd25519Identity(t *testing.T) { t.Error() } } - -func TestEd25519Identity_Sign(t *testing.T) { - id, err := NewRandomEd25519Identity() - if err != nil { - t.Fatal(err) - } - data := []byte("hello") - if !id.Verify(data, id.Sign(data)) { - t.Error() - } -} diff --git a/identity/prime256v1.go b/identity/prime256v1.go index 9b78072..6b841cc 100644 --- a/identity/prime256v1.go +++ b/identity/prime256v1.go @@ -15,41 +15,23 @@ import ( var prime256v1OID = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7} +func derEncodePrime256v1PublicKey(key *ecdsa.PublicKey) ([]byte, error) { + return asn1.Marshal(ecPublicKey{ + Metadata: []asn1.ObjectIdentifier{ + ecPublicKeyOID, + prime256v1OID, + }, + PublicKey: asn1.BitString{ + Bytes: marshal(elliptic.P256(), key.X, key.Y), + }, + }) +} + type Prime256v1Identity struct { privateKey *ecdsa.PrivateKey publicKey *ecdsa.PublicKey } -func NewRandomPrime256v1Identity() (*Prime256v1Identity, error) { - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return nil, err - } - return NewPrime256v1Identity(privateKey), nil -} - -func (id Prime256v1Identity) Sender() principal.Principal { - return principal.NewSelfAuthenticating(id.PublicKey()) -} - -func (id Prime256v1Identity) Sign(msg []byte) []byte { - hashData := sha256.Sum256(msg) - sigR, sigS, _ := ecdsa.Sign(rand.Reader, id.privateKey, hashData[:]) - var buffer [64]byte - r := sigR.Bytes() - s := sigS.Bytes() - copy(buffer[(32-len(r)):], r) - copy(buffer[(64-len(s)):], s) - return buffer[:] -} - -func (id Prime256v1Identity) Verify(msg, sig []byte) bool { - r := new(big.Int).SetBytes(sig[:32]) - s := new(big.Int).SetBytes(sig[32:]) - hashData := sha256.Sum256(msg) - return ecdsa.Verify(id.publicKey, hashData[:], r, s) -} - func NewPrime256v1Identity(privateKey *ecdsa.PrivateKey) *Prime256v1Identity { return &Prime256v1Identity{ privateKey: privateKey, @@ -69,16 +51,12 @@ func NewPrime256v1IdentityFromPEM(data []byte) (*Prime256v1Identity, error) { return NewPrime256v1Identity(privateKey), nil } -func derEncodePrime256v1PublicKey(key *ecdsa.PublicKey) ([]byte, error) { - return asn1.Marshal(ecPublicKey{ - Metadata: []asn1.ObjectIdentifier{ - ecPublicKeyOID, - prime256v1OID, - }, - PublicKey: asn1.BitString{ - Bytes: marshal(elliptic.P256(), key.X, key.Y), - }, - }) +func NewRandomPrime256v1Identity() (*Prime256v1Identity, error) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + return NewPrime256v1Identity(privateKey), nil } func (id Prime256v1Identity) PublicKey() []byte { @@ -86,6 +64,21 @@ func (id Prime256v1Identity) PublicKey() []byte { return der } +func (id Prime256v1Identity) Sender() principal.Principal { + return principal.NewSelfAuthenticating(id.PublicKey()) +} + +func (id Prime256v1Identity) Sign(msg []byte) []byte { + hashData := sha256.Sum256(msg) + sigR, sigS, _ := ecdsa.Sign(rand.Reader, id.privateKey, hashData[:]) + var buffer [64]byte + r := sigR.Bytes() + s := sigS.Bytes() + copy(buffer[(32-len(r)):], r) + copy(buffer[(64-len(s)):], s) + return buffer[:] +} + func (id Prime256v1Identity) ToPEM() ([]byte, error) { data, err := x509.MarshalECPrivateKey(id.privateKey) if err != nil { @@ -96,3 +89,10 @@ func (id Prime256v1Identity) ToPEM() ([]byte, error) { Bytes: data, }), nil } + +func (id Prime256v1Identity) Verify(msg, sig []byte) bool { + r := new(big.Int).SetBytes(sig[:32]) + s := new(big.Int).SetBytes(sig[32:]) + hashData := sha256.Sum256(msg) + return ecdsa.Verify(id.publicKey, hashData[:], r, s) +} diff --git a/identity/prime256v1_test.go b/identity/prime256v1_test.go index 88f5942..7ea4795 100644 --- a/identity/prime256v1_test.go +++ b/identity/prime256v1_test.go @@ -24,17 +24,6 @@ func TestNewPrime256v1Identity(t *testing.T) { } } -func TestPrime256v1Identity_Sign(t *testing.T) { - id, err := NewRandomPrime256v1Identity() - if err != nil { - t.Fatal(err) - } - data := []byte("hello") - if !id.Verify(data, id.Sign(data)) { - t.Error() - } -} - func TestNewPrime256v1IdentityFromPEM(t *testing.T) { pem := ` -----BEGIN EC PRIVATE KEY----- @@ -52,3 +41,14 @@ Sks4xGbA/ZbazsrMl4v446U5UIVxCGGaKw== t.Fatal("public key mismatch") } } + +func TestPrime256v1Identity_Sign(t *testing.T) { + id, err := NewRandomPrime256v1Identity() + if err != nil { + t.Fatal(err) + } + data := []byte("hello") + if !id.Verify(data, id.Sign(data)) { + t.Error() + } +} diff --git a/identity/secp256k1.go b/identity/secp256k1.go index 3f90384..ddf9dfd 100644 --- a/identity/secp256k1.go +++ b/identity/secp256k1.go @@ -12,6 +12,8 @@ import ( "slices" ) +var ecPublicKeyOID = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} + var secp256k1OID = asn1.ObjectIdentifier{1, 3, 132, 0, 10} func derEncodeSecp256k1PublicKey(key *secp256k1.PublicKey) ([]byte, error) { @@ -130,8 +132,6 @@ func (id Secp256k1Identity) Verify(msg, sig []byte) bool { return signature.Verify(hashData[:], id.publicKey) } -var ecPublicKeyOID = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} - type ecPrivateKey struct { Version int PrivateKey []byte diff --git a/identity/secp256k1_test.go b/identity/secp256k1_test.go index 5ac4a6a..65de51d 100644 --- a/identity/secp256k1_test.go +++ b/identity/secp256k1_test.go @@ -24,17 +24,6 @@ func TestNewSecp256k1Identity(t *testing.T) { } } -func TestSecp256k1Identity_Sign(t *testing.T) { - id, err := NewRandomSecp256k1Identity() - if err != nil { - t.Fatal(err) - } - data := []byte("hello") - if !id.Verify(data, id.Sign(data)) { - t.Error() - } -} - func TestNewSecp256k1IdentityFromPEM(t *testing.T) { pem := ` -----BEGIN EC PARAMETERS----- @@ -55,3 +44,14 @@ N3d26cRxD99TPtm8uo2OuzKhSiq6EQ== t.Fatal("public key mismatch") } } + +func TestSecp256k1Identity_Sign(t *testing.T) { + id, err := NewRandomSecp256k1Identity() + if err != nil { + t.Fatal(err) + } + data := []byte("hello") + if !id.Verify(data, id.Sign(data)) { + t.Error() + } +} diff --git a/mock/replica.go b/mock/replica.go index 93a3586..408ec7e 100644 --- a/mock/replica.go +++ b/mock/replica.go @@ -3,15 +3,18 @@ package mock import ( "bytes" "encoding/hex" + "io" + "net/http" + "strings" + "github.com/aviate-labs/agent-go" "github.com/aviate-labs/agent-go/candid/idl" - "github.com/aviate-labs/agent-go/certificate" - "github.com/aviate-labs/agent-go/certificate/bls" + "github.com/aviate-labs/agent-go/certification" + "github.com/aviate-labs/agent-go/certification/bls" + "github.com/aviate-labs/agent-go/certification/hashtree" "github.com/aviate-labs/agent-go/principal" + "github.com/fxamacker/cbor/v2" - "io" - "net/http" - "strings" ) type Canister struct { @@ -188,29 +191,29 @@ func (r *Replica) handleCanister(writer http.ResponseWriter, canisterId, typ str return } - t := certificate.NewHashTree(certificate.Fork{ - LeftTree: certificate.Labeled{ + t := hashtree.NewHashTree(hashtree.Fork{ + LeftTree: hashtree.Labeled{ Label: []byte("request_status"), - Tree: certificate.Labeled{ + Tree: hashtree.Labeled{ Label: requestId, - Tree: certificate.Fork{ - LeftTree: certificate.Labeled{ - Label: []byte("status"), - Tree: certificate.Leaf("replied"), - }, - RightTree: certificate.Labeled{ + Tree: hashtree.Fork{ + LeftTree: hashtree.Labeled{ Label: []byte("reply"), - Tree: certificate.Leaf(rawReply), + Tree: hashtree.Leaf(rawReply), + }, + RightTree: hashtree.Labeled{ + Label: []byte("status"), + Tree: hashtree.Leaf("replied"), }, }, }, }, - RightTree: certificate.Empty{}, + RightTree: hashtree.Empty{}, }) d := t.Digest() m := make(map[string][]byte) - s := r.rootKey.Sign(string(append(certificate.DomainSeparator("ic-state-root"), d[:]...))) - cert := certificate.Cert{ + s := r.rootKey.Sign(string(append(hashtree.DomainSeparator("ic-state-root"), d[:]...))) + cert := certification.Cert{ Tree: t, Signature: s.Serialize(), } diff --git a/request.go b/request.go index 4b17347..04b75e4 100644 --- a/request.go +++ b/request.go @@ -6,9 +6,11 @@ import ( "math/big" "sort" + "github.com/aviate-labs/agent-go/certification/hashtree" "github.com/aviate-labs/agent-go/identity" "github.com/aviate-labs/agent-go/principal" "github.com/aviate-labs/leb128" + "github.com/fxamacker/cbor/v2" ) @@ -29,7 +31,7 @@ func encodeLEB128(i uint64) []byte { return e } -func hashPaths(paths [][][]byte) [32]byte { +func hashPaths(paths [][]hashtree.Label) [32]byte { var hash []byte for _, path := range paths { var rawPathHash []byte @@ -63,7 +65,7 @@ type Request struct { // Argument to pass to the canister method. Arguments []byte // A list of paths, where a path is itself a sequence of blobs. - Paths [][][]byte + Paths [][]hashtree.Label } // MarshalCBOR implements the CBOR marshaler interface. @@ -164,12 +166,12 @@ const ( ) type requestRaw struct { - Type RequestType `cbor:"request_type"` - Sender []byte `cbor:"sender"` - Nonce []byte `cbor:"nonce"` - IngressExpiry uint64 `cbor:"ingress_expiry"` - CanisterID []byte `cbor:"canister_id"` - MethodName string `cbor:"method_name"` - Arguments []byte `cbor:"arg"` - Paths [][][]byte `cbor:"paths,omitempty"` + Type RequestType `cbor:"request_type"` + Sender []byte `cbor:"sender"` + Nonce []byte `cbor:"nonce"` + IngressExpiry uint64 `cbor:"ingress_expiry"` + CanisterID []byte `cbor:"canister_id"` + MethodName string `cbor:"method_name"` + Arguments []byte `cbor:"arg"` + Paths [][]hashtree.Label `cbor:"paths,omitempty"` } diff --git a/request_test.go b/request_test.go index 3b29b0b..1167f98 100644 --- a/request_test.go +++ b/request_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/aviate-labs/agent-go" + "github.com/aviate-labs/agent-go/certification/hashtree" "github.com/aviate-labs/agent-go/principal" ) @@ -21,7 +22,7 @@ func TestNewRequestID(t *testing.T) { if h := fmt.Sprintf("%x", agent.NewRequestID(agent.Request{ Sender: principal.Principal{Raw: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xD2}}, - Paths: [][][]byte{ + Paths: [][]hashtree.Label{ {}, {[]byte("")}, {[]byte("hello"), []byte("world")}, @@ -31,13 +32,13 @@ func TestNewRequestID(t *testing.T) { } if h := fmt.Sprintf("%x", agent.NewRequestID(agent.Request{ - Paths: [][][]byte{}, + Paths: [][]hashtree.Label{}, })); h != "99daa8c80a61e87ac1fdf9dd49e39963bfe4dafb2a45095ebf4cad72d916d5be" { t.Error(h) } if h := fmt.Sprintf("%x", agent.NewRequestID(agent.Request{ - Paths: [][][]byte{{}}, + Paths: [][]hashtree.Label{{}}, })); h != "ea01a9c3d3830db108e0a87995ea0d4183dc9c6e51324e9818fced5c57aa64f5" { t.Error(h) }