diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2d83068 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +coverage.out diff --git a/Makefile b/Makefile index c869c8a..17266af 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,12 @@ -.PHONY: test test-ledger gen gen-ic fmt +.PHONY: test test-cover test-ledger gen gen-ic fmt test: go test -v -cover ./... - + +test-cover: + go test -v -coverprofile=coverage.out ./... + go tool cover -html=coverage.out + test-ledger: cd ic; dfx start --background --clean cd ic/testdata; dfx deploy --no-wallet diff --git a/agent.go b/agent.go index b8e9bec..5668ea2 100644 --- a/agent.go +++ b/agent.go @@ -2,6 +2,7 @@ package agent import ( "encoding/binary" + "encoding/hex" "fmt" "net/url" "time" @@ -41,10 +42,11 @@ type Agent struct { client Client identity identity.Identity ingressExpiry time.Duration + rootKey []byte } // New returns a new Agent based on the given configuration. -func New(cfg Config) Agent { +func New(cfg Config) (*Agent, error) { if cfg.IngressExpiry == 0 { cfg.IngressExpiry = 10 * time.Second } @@ -59,11 +61,21 @@ func New(cfg Config) Agent { if cfg.ClientConfig != nil { ccfg = *cfg.ClientConfig } - return Agent{ - client: NewClient(ccfg), + client := NewClient(ccfg) + rootKey, _ := hex.DecodeString(certificate.RootKey) + if cfg.FetchRootKey { + status, err := client.Status() + if err != nil { + return nil, err + } + rootKey = status.RootKey + } + return &Agent{ + client: client, identity: id, ingressExpiry: cfg.IngressExpiry, - } + rootKey: rootKey, + }, nil } // Call calls a method on a canister and unmarshals the result into the given values. @@ -230,18 +242,14 @@ func (a Agent) RequestStatus(canisterID principal.Principal, requestID RequestID if err := cbor.Unmarshal(c, &state); err != nil { return nil, nil, err } - status, err := a.client.Status() // TODO: fetch status once. - if err != nil { - return nil, nil, err - } - cert, err := certificate.New(canisterID, status.RootKey[len(status.RootKey)-96:], c) + cert, err := certificate.New(canisterID, a.rootKey[len(a.rootKey)-96:], c) if err != nil { return nil, nil, err } if err := cert.Verify(); err != nil { return nil, nil, err } - node, err := certificate.DeserializeNode(state["tree"].([]interface{})) + node, err := certificate.DeserializeNode(state["tree"].([]any)) if err != nil { return nil, nil, err } @@ -343,4 +351,5 @@ type Config struct { Identity identity.Identity IngressExpiry time.Duration ClientConfig *ClientConfig + FetchRootKey bool } diff --git a/agent_test.go b/agent_test.go index d2bb66d..ddc5b59 100644 --- a/agent_test.go +++ b/agent_test.go @@ -12,7 +12,7 @@ import ( func Example_anonymous_query() { ledgerID, _ := principal.Decode("ryjl3-tyaaa-aaaaa-aaaba-cai") - a := agent.New(agent.Config{}) + a, _ := agent.New(agent.Config{}) args, err := candid.EncodeValueString("record { account = \"9523dc824aa062dcd9c91b98f4594ff9c6af661ac96747daef2090b7fe87037d\" }") if err != nil { fmt.Println(err) @@ -26,7 +26,7 @@ func Example_query() { publicKey, privateKey, _ := ed25519.GenerateKey(rand.Reader) id, _ := identity.NewEd25519Identity(publicKey, privateKey) ledgerID, _ := principal.Decode("ryjl3-tyaaa-aaaaa-aaaba-cai") - a := agent.New(agent.Config{ + a, _ := agent.New(agent.Config{ Identity: id, }) args, err := candid.EncodeValueString("record { account = \"9523dc824aa062dcd9c91b98f4594ff9c6af661ac96747daef2090b7fe87037d\" }") diff --git a/certificate/bls/bls.go b/certificate/bls/bls.go index 9280051..f0ebeab 100644 --- a/certificate/bls/bls.go +++ b/certificate/bls/bls.go @@ -46,6 +46,15 @@ func PublicKeyFromHexString(s string) (*PublicKey, error) { return &pub, pub.DeserializeHexStr(s) } +type SecretKey = bls.SecretKey + +// NewSecretKeyByCSPRNG returns a new SecretKey generated by CSPRNG. +func NewSecretKeyByCSPRNG() *SecretKey { + var sk bls.SecretKey + sk.SetByCSPRNG() + return &sk +} + type Signature = bls.Sign // SignatureFromBytes returns a Signature from a byte slice. diff --git a/certificate/bls/bls_test.go b/certificate/bls/bls_test.go index 0e6864a..5811088 100644 --- a/certificate/bls/bls_test.go +++ b/certificate/bls/bls_test.go @@ -1,10 +1,43 @@ package bls import ( + "encoding/hex" "testing" ) +func TestSecretKey(t *testing.T) { + sk := NewSecretKeyByCSPRNG() + s := sk.Sign("hello") + if !s.Verify(sk.GetPublicKey(), "hello") { + t.Error() + } +} + func TestVerify(t *testing.T) { + // SOURCE: https://github.com/dfinity/agent-js/blob/5214dc1fc4b9b41f023a88b1228f04d2f2536987/packages/bls-verify/src/index.test.ts#L101 + publicKeyHex := "a7623a93cdb56c4d23d99c14216afaab3dfd6d4f9eb3db23d038280b6d5cb2caaee2a19dd92c9df7001dede23bf036bc0f33982dfb41e8fa9b8e96b5dc3e83d55ca4dd146c7eb2e8b6859cb5a5db815db86810b8d12cee1588b5dbf34a4dc9a5" + publicKeyRaw, _ := hex.DecodeString(publicKeyHex) + publicKey, err := PublicKeyFromBytes(publicKeyRaw) + if err != nil { + t.Fatal(err) + } + + signatureHex := "b89e13a212c830586eaa9ad53946cd968718ebecc27eda849d9232673dcd4f440e8b5df39bf14a88048c15e16cbcaabe" + signatureHexRaw, _ := hex.DecodeString(signatureHex) + signature, err := SignatureFromBytes(signatureHexRaw) + if err != nil { + t.Fatal(err) + } + + if signature.Verify(publicKey, "bye") { + t.Error() + } + if !signature.Verify(publicKey, "hello") { + t.Error() + } +} + +func TestVerify_hex(t *testing.T) { // SOURCE: https://github.com/dfinity/agent-js/blob/5214dc1fc4b9b41f023a88b1228f04d2f2536987/packages/bls-verify/src/index.test.ts#L101 publicKeyHex := "a7623a93cdb56c4d23d99c14216afaab3dfd6d4f9eb3db23d038280b6d5cb2caaee2a19dd92c9df7001dede23bf036bc0f33982dfb41e8fa9b8e96b5dc3e83d55ca4dd146c7eb2e8b6859cb5a5db815db86810b8d12cee1588b5dbf34a4dc9a5" publicKey, err := PublicKeyFromHexString(publicKeyHex) diff --git a/certificate/certificate.go b/certificate/certificate.go index 633157d..6812743 100644 --- a/certificate/certificate.go +++ b/certificate/certificate.go @@ -1,7 +1,6 @@ package certificate import ( - "encoding/hex" "fmt" "github.com/aviate-labs/agent-go/certificate/bls" "github.com/aviate-labs/agent-go/principal" @@ -42,7 +41,7 @@ func New(canisterID principal.Principal, rootKey []byte, certificate []byte) (*C // Verify verifies the certificate. func (c Certificate) Verify() error { - signature, err := bls.SignatureFromHexString(hex.EncodeToString(c.cert.Signature)) + signature, err := bls.SignatureFromBytes(c.cert.Signature) if err != nil { return err } @@ -51,7 +50,7 @@ func (c Certificate) Verify() error { return err } rootHash := c.cert.Tree.Digest() - message := append(domainSeparator("ic-state-root"), rootHash[:]...) + message := append(DomainSeparator("ic-state-root"), rootHash[:]...) if !signature.Verify(publicKey, string(message)) { return fmt.Errorf("signature verification failed") } diff --git a/certificate/node.go b/certificate/node.go index b0fa661..d8414f0 100644 --- a/certificate/node.go +++ b/certificate/node.go @@ -8,17 +8,17 @@ import ( "github.com/fxamacker/cbor/v2" ) -func Serialize(node Node) ([]byte, error) { - return cbor.Marshal(serialize(node)) -} - -func domainSeparator(t string) []byte { +func DomainSeparator(t string) []byte { return append( []byte{uint8(len(t))}, []byte(t)..., ) } +func Serialize(node Node) ([]byte, error) { + return cbor.Marshal(serialize(node)) +} + func serialize(node Node) []any { switch n := node.(type) { case Empty: @@ -52,7 +52,7 @@ func serialize(node Node) []any { type Empty struct{} func (e Empty) Reconstruct() [32]byte { - return sha256.Sum256(domainSeparator("ic-hashtree-empty")) + return sha256.Sum256(DomainSeparator("ic-hashtree-empty")) } func (e Empty) String() string { @@ -68,7 +68,7 @@ func (f Fork) Reconstruct() [32]byte { l := f.LeftTree.Reconstruct() r := f.RightTree.Reconstruct() return sha256.Sum256(append( - domainSeparator("ic-hashtree-fork"), + DomainSeparator("ic-hashtree-fork"), append(l[:], r[:]...)..., )) } @@ -91,7 +91,7 @@ type Labeled struct { func (l Labeled) Reconstruct() [32]byte { t := l.Tree.Reconstruct() return sha256.Sum256(append( - domainSeparator("ic-hashtree-labeled"), + DomainSeparator("ic-hashtree-labeled"), append(l.Label, t[:]...)..., )) } @@ -104,7 +104,7 @@ type Leaf []byte func (l Leaf) Reconstruct() [32]byte { return sha256.Sum256(append( - domainSeparator("ic-hashtree-leaf"), + DomainSeparator("ic-hashtree-leaf"), l..., )) } diff --git a/client.go b/client.go index 7d44a9b..d54cb29 100644 --- a/client.go +++ b/client.go @@ -37,7 +37,25 @@ func (c Client) Status() (*Status, error) { } func (c Client) call(canisterID principal.Principal, data []byte) ([]byte, error) { - return c.post("call", canisterID, data, 202) + u := c.url(fmt.Sprintf("/api/v2/canister/%s/call", canisterID.Encode())) + resp, err := c.client.Post(u, "application/cbor", bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + switch resp.StatusCode { + case http.StatusAccepted: + return io.ReadAll(resp.Body) + case http.StatusOK: + body, _ := io.ReadAll(resp.Body) + var err preprocessingError + if err := cbor.Unmarshal(body, &err); err != nil { + return nil, err + } + return nil, fmt.Errorf("(%d) %s: %s", err.RejectCode, err.Message, err.ErrorCode) + default: + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("(%d) %s: %s", resp.StatusCode, resp.Status, body) + } } func (c Client) get(path string) ([]byte, error) { @@ -48,14 +66,14 @@ func (c Client) get(path string) ([]byte, error) { return io.ReadAll(resp.Body) } -func (c Client) post(path string, canisterID principal.Principal, data []byte, statusCorePass int) ([]byte, error) { +func (c Client) post(path string, canisterID principal.Principal, data []byte) ([]byte, error) { u := c.url(fmt.Sprintf("/api/v2/canister/%s/%s", canisterID.Encode(), path)) resp, err := c.client.Post(u, "application/cbor", bytes.NewBuffer(data)) if err != nil { return nil, err } switch resp.StatusCode { - case statusCorePass: + case http.StatusOK: return io.ReadAll(resp.Body) default: body, _ := io.ReadAll(resp.Body) @@ -64,11 +82,11 @@ func (c Client) post(path string, canisterID principal.Principal, data []byte, s } func (c Client) query(canisterID principal.Principal, data []byte) ([]byte, error) { - return c.post("query", canisterID, data, 200) + return c.post("query", canisterID, data) } func (c Client) readState(canisterID principal.Principal, data []byte) ([]byte, error) { - return c.post("read_state", canisterID, data, 200) + return c.post("read_state", canisterID, data) } func (c Client) url(p string) string { @@ -81,3 +99,12 @@ func (c Client) url(p string) string { type ClientConfig struct { Host *url.URL } + +type preprocessingError struct { + // The reject code. + RejectCode uint64 `cbor:"reject_code"` + // A textual diagnostic message. + Message string `cbor:"reject_message"` + // An optional implementation-specific textual error code. + ErrorCode string `cbor:"error_code"` +} diff --git a/gen/templates/agent.gotmpl b/gen/templates/agent.gotmpl index 054cab3..b0aa22a 100644 --- a/gen/templates/agent.gotmpl +++ b/gen/templates/agent.gotmpl @@ -16,16 +16,20 @@ type {{ .Name }} = {{ .Type }} // Agent is a client for the "{{ .CanisterName }}" canister. type Agent struct { - a agent.Agent + a *agent.Agent canisterId principal.Principal } // NewAgent creates a new agent for the "{{ .CanisterName }}" canister. -func NewAgent(canisterId principal.Principal, config agent.Config) Agent { - return Agent{ - a: agent.New(config), - canisterId: canisterId, +func NewAgent(canisterId principal.Principal, config agent.Config) (*Agent, error) { + a, err := agent.New(config) + if err != nil { + return nil, err } + return &Agent{ + a: a, + canisterId: canisterId, + }, nil } {{- range .Methods }} diff --git a/go.mod b/go.mod index eac3539..e1c0423 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/di-wu/parser v0.3.0 github.com/fxamacker/cbor/v2 v2.4.0 github.com/herumi/bls-go-binary v1.28.2 - golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 + golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc ) require github.com/x448/float16 v0.8.4 // indirect diff --git a/go.sum b/go.sum index ba2abf8..ce40b82 100644 --- a/go.sum +++ b/go.sum @@ -10,5 +10,5 @@ github.com/herumi/bls-go-binary v1.28.2 h1:F0AezsC0M1a9aZjk7g0l2hMb1F56Xtpfku97p github.com/herumi/bls-go-binary v1.28.2/go.mod h1:O4Vp1AfR4raRGwFeQpr9X/PQtncEicMoOe6BQt1oX0Y= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo= -golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc h1:mCRnTeVUjcrhlRmO0VK8a6k6Rrf6TF9htwo2pJVSjIU= +golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= diff --git a/ic/assetstorage/agent.go b/ic/assetstorage/agent.go index 61bafa5..741ff00 100755 --- a/ic/assetstorage/agent.go +++ b/ic/assetstorage/agent.go @@ -11,16 +11,20 @@ import ( // Agent is a client for the "assetstorage" canister. type Agent struct { - a agent.Agent + a *agent.Agent canisterId principal.Principal } // NewAgent creates a new agent for the "assetstorage" canister. -func NewAgent(canisterId principal.Principal, config agent.Config) Agent { - return Agent{ - a: agent.New(config), - canisterId: canisterId, +func NewAgent(canisterId principal.Principal, config agent.Config) (*Agent, error) { + a, err := agent.New(config) + if err != nil { + return nil, err } + return &Agent{ + a: a, + canisterId: canisterId, + }, nil } // ApiVersion calls the "api_version" method on the "assetstorage" canister. diff --git a/ic/cmc/agent.go b/ic/cmc/agent.go index b139be8..310cbc1 100755 --- a/ic/cmc/agent.go +++ b/ic/cmc/agent.go @@ -15,16 +15,20 @@ type AccountIdentifier = struct { // Agent is a client for the "cmc" canister. type Agent struct { - a agent.Agent + a *agent.Agent canisterId principal.Principal } // NewAgent creates a new agent for the "cmc" canister. -func NewAgent(canisterId principal.Principal, config agent.Config) Agent { - return Agent{ - a: agent.New(config), - canisterId: canisterId, +func NewAgent(canisterId principal.Principal, config agent.Config) (*Agent, error) { + a, err := agent.New(config) + if err != nil { + return nil, err } + return &Agent{ + a: a, + canisterId: canisterId, + }, nil } // GetIcpXdrConversionRate calls the "get_icp_xdr_conversion_rate" method on the "cmc" canister. diff --git a/ic/icpledger/agent.go b/ic/icpledger/agent.go index 9f5a719..f71eeea 100755 --- a/ic/icpledger/agent.go +++ b/ic/icpledger/agent.go @@ -17,16 +17,20 @@ type AccountIdentifier = []byte // Agent is a client for the "icpledger" canister. type Agent struct { - a agent.Agent + a *agent.Agent canisterId principal.Principal } // NewAgent creates a new agent for the "icpledger" canister. -func NewAgent(canisterId principal.Principal, config agent.Config) Agent { - return Agent{ - a: agent.New(config), - canisterId: canisterId, +func NewAgent(canisterId principal.Principal, config agent.Config) (*Agent, error) { + a, err := agent.New(config) + if err != nil { + return nil, err } + return &Agent{ + a: a, + canisterId: canisterId, + }, nil } // AccountBalance calls the "account_balance" method on the "icpledger" canister. diff --git a/ic/icpledger_test.go b/ic/icpledger_test.go index 672889e..141be01 100644 --- a/ic/icpledger_test.go +++ b/ic/icpledger_test.go @@ -22,8 +22,9 @@ var ( func Example_accountBalance() { host, _ := url.Parse("https://icp0.io") - a := icpledger.NewAgent(ic.LEDGER_PRINCIPAL, agent.Config{ + a, _ := icpledger.NewAgent(ic.LEDGER_PRINCIPAL, agent.Config{ ClientConfig: &agent.ClientConfig{Host: host}, + FetchRootKey: true, }) name, _ := a.Name() fmt.Println(name.Name) @@ -41,11 +42,12 @@ func TestAgent(t *testing.T) { t.Run("account_balance ed25519", func(t *testing.T) { id, _ := identity.NewRandomEd25519Identity() - a := icpledger.NewAgent(canisterId, agent.Config{ + a, _ := icpledger.NewAgent(canisterId, agent.Config{ Identity: id, ClientConfig: &agent.ClientConfig{ Host: host, }, + FetchRootKey: true, }) tokens, err := a.AccountBalance(icpledger.AccountBalanceArgs{ Account: defaultAccount[:], @@ -60,11 +62,12 @@ func TestAgent(t *testing.T) { t.Run("account_balance secp256k1", func(t *testing.T) { id, _ := identity.NewRandomSecp256k1Identity() - a := icpledger.NewAgent(canisterId, agent.Config{ + a, _ := icpledger.NewAgent(canisterId, agent.Config{ Identity: id, ClientConfig: &agent.ClientConfig{ Host: host, }, + FetchRootKey: true, }) tokens, err := a.AccountBalance(icpledger.AccountBalanceArgs{ Account: defaultAccount[:], @@ -77,10 +80,11 @@ func TestAgent(t *testing.T) { } }) - a := icpledger.NewAgent(canisterId, agent.Config{ + a, _ := icpledger.NewAgent(canisterId, agent.Config{ ClientConfig: &agent.ClientConfig{ Host: host, }, + FetchRootKey: true, }) t.Run("account_balance", func(t *testing.T) { tokens, err := a.AccountBalance(icpledger.AccountBalanceArgs{ diff --git a/ic/wallet/agent.go b/ic/wallet/agent.go index 8e56371..90b3130 100755 --- a/ic/wallet/agent.go +++ b/ic/wallet/agent.go @@ -11,16 +11,20 @@ import ( // Agent is a client for the "wallet" canister. type Agent struct { - a agent.Agent + a *agent.Agent canisterId principal.Principal } // NewAgent creates a new agent for the "wallet" canister. -func NewAgent(canisterId principal.Principal, config agent.Config) Agent { - return Agent{ - a: agent.New(config), - canisterId: canisterId, +func NewAgent(canisterId principal.Principal, config agent.Config) (*Agent, error) { + a, err := agent.New(config) + if err != nil { + return nil, err } + return &Agent{ + a: a, + canisterId: canisterId, + }, nil } // ApiVersion calls the "api_version" method on the "wallet" canister. diff --git a/mock/replica.go b/mock/replica.go new file mode 100644 index 0000000..b2da92d --- /dev/null +++ b/mock/replica.go @@ -0,0 +1,242 @@ +package mock + +import ( + "bytes" + "encoding/hex" + "github.com/aviate-labs/agent-go" + "github.com/aviate-labs/agent-go/candid/marshal" + "github.com/aviate-labs/agent-go/certificate" + "github.com/aviate-labs/agent-go/certificate/bls" + "github.com/aviate-labs/agent-go/principal" + "github.com/fxamacker/cbor/v2" + "io" + "log" + "net/http" + "strings" +) + +type Canister struct { + Id principal.Principal + Handler HandlerFunc +} + +type HandlerFunc func(request Request) ([]any, error) + +type Replica struct { + rootKey *bls.SecretKey + Canisters map[string]Canister + Requests map[string]agent.Request +} + +func NewReplica() *Replica { + return &Replica{ + rootKey: bls.NewSecretKeyByCSPRNG(), + Canisters: make(map[string]Canister), + Requests: make(map[string]agent.Request), + } +} + +// AddCanister adds a canister to the replica. +func (r *Replica) AddCanister(id principal.Principal, handler HandlerFunc) { + r.Canisters[id.String()] = Canister{ + Id: id, + Handler: handler, + } +} + +func (r *Replica) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if !strings.HasPrefix(request.URL.Path, "/api/v2/") { + writer.WriteHeader(http.StatusNotFound) + return + } + + path := strings.Split(request.URL.Path, "/")[3:] + switch path[0] { + case "canister": + if request.Method != http.MethodPost { + writer.WriteHeader(http.StatusMethodNotAllowed) + return + } + body, _ := io.ReadAll(request.Body) + r.handleCanister(writer, path[1], path[2], body) + case "status": + if request.Method != http.MethodGet { + writer.WriteHeader(http.StatusMethodNotAllowed) + return + } + r.handleStatus(writer) + default: + writer.WriteHeader(http.StatusNotFound) + } +} + +func (r *Replica) handleCanister(writer http.ResponseWriter, canisterId, typ string, body []byte) { + canister, ok := r.Canisters[canisterId] + if !ok { + writer.WriteHeader(http.StatusNotFound) + writer.Write([]byte("canister not found: " + canisterId)) + return + } + var envelope agent.Envelope + if err := cbor.Unmarshal(body, &envelope); err != nil { + writer.WriteHeader(http.StatusBadRequest) + writer.Write([]byte(err.Error())) + return + } + // TODO: validate sender + signatures, ... + req := envelope.Content + + switch typ { + case "call": + if req.Type != agent.RequestTypeCall { + writer.WriteHeader(http.StatusBadRequest) + writer.Write([]byte("expected call request")) + return + } + requestId := agent.NewRequestID(req) + requestIdHex := hex.EncodeToString(requestId[:]) + log.Println("received call request", requestIdHex) + r.Requests[requestIdHex] = req + writer.WriteHeader(http.StatusAccepted) + case "query": + if req.Type != agent.RequestTypeQuery { + writer.WriteHeader(http.StatusBadRequest) + writer.Write([]byte("expected query request")) + return + } + requestId := agent.NewRequestID(req) + requestIdHex := hex.EncodeToString(requestId[:]) + log.Println("received query request", requestIdHex) + + values, err := canister.Handler(fromAgentRequest(req)) + if err != nil { + writer.WriteHeader(http.StatusInternalServerError) + writer.Write([]byte(err.Error())) + return + } + + rawReply, err := marshal.Marshal(values) + if err != nil { + writer.WriteHeader(http.StatusInternalServerError) + writer.Write([]byte(err.Error())) + return + } + + reply := make(map[string][]byte) + reply["arg"] = rawReply + resp := agent.Response{ + Status: "replied", + Reply: reply, + } + + writer.WriteHeader(http.StatusOK) + raw, err := cbor.Marshal(resp) + if err != nil { + writer.WriteHeader(http.StatusInternalServerError) + writer.Write([]byte(err.Error())) + return + } + writer.Write(raw) + case "read_state": + if !bytes.Equal(req.Paths[0][0], []byte("request_status")) { + writer.WriteHeader(http.StatusBadRequest) + writer.Write([]byte("expected request_status")) + return + } + requestId := req.Paths[0][1] + requestIdHex := hex.EncodeToString(requestId) + log.Println("received read_state request", requestIdHex) + req, ok := r.Requests[requestIdHex] + if !ok { + writer.WriteHeader(http.StatusNotFound) + writer.Write([]byte("request not found: " + requestIdHex)) + return + } + values, err := canister.Handler(fromAgentRequest(req)) + if err != nil { + writer.WriteHeader(http.StatusInternalServerError) + writer.Write([]byte(err.Error())) + return + } + + rawReply, err := marshal.Marshal(values) + if err != nil { + writer.WriteHeader(http.StatusInternalServerError) + writer.Write([]byte(err.Error())) + return + } + + t := certificate.NewHashTree(certificate.Fork{ + LeftTree: certificate.Labeled{ + Label: []byte("request_status"), + Tree: certificate.Labeled{ + Label: requestId, + Tree: certificate.Fork{ + LeftTree: certificate.Labeled{ + Label: []byte("status"), + Tree: certificate.Leaf("replied"), + }, + RightTree: certificate.Labeled{ + Label: []byte("reply"), + Tree: certificate.Leaf(rawReply), + }, + }, + }, + }, + RightTree: certificate.Empty{}, + }) + d := t.Digest() + m := make(map[string][]byte) + s := r.rootKey.Sign(string(append(certificate.DomainSeparator("ic-state-root"), d[:]...))) + cert := certificate.Cert{ + Tree: t, + Signature: s.Serialize(), + } + rawCert, _ := cbor.Marshal(cert) + m["certificate"] = rawCert + + rawTree, _ := cbor.Marshal(t) + m["tree"] = rawTree + + writer.WriteHeader(http.StatusOK) + raw, err := cbor.Marshal(m) + if err != nil { + writer.WriteHeader(http.StatusInternalServerError) + writer.Write([]byte(err.Error())) + return + } + writer.Write(raw) + default: + writer.WriteHeader(http.StatusNotFound) + } +} + +func (r *Replica) handleStatus(writer http.ResponseWriter) { + log.Println("getting status") + publicKey := r.rootKey.GetPublicKey().Serialize() + status := agent.Status{ + Version: "golang-mock", + RootKey: publicKey, + } + raw, _ := cbor.Marshal(status) + writer.WriteHeader(http.StatusOK) + writer.Write(raw) +} + +type Request struct { + Type agent.RequestType + Sender principal.Principal + MethodName string + Arguments []any +} + +func fromAgentRequest(request agent.Request) Request { + var arguments []any + _ = marshal.Unmarshal(request.Arguments, arguments) + return Request{ + Type: request.Type, + Sender: request.Sender, + MethodName: request.MethodName, + Arguments: arguments, + } +} diff --git a/mock/replica_test.go b/mock/replica_test.go new file mode 100644 index 0000000..c859230 --- /dev/null +++ b/mock/replica_test.go @@ -0,0 +1,58 @@ +package mock_test + +import ( + "bytes" + "github.com/aviate-labs/agent-go" + "github.com/aviate-labs/agent-go/mock" + "github.com/aviate-labs/agent-go/principal" + "net/http/httptest" + "net/url" + "testing" +) + +func TestAgent(t *testing.T) { + replica := mock.NewReplica() + var canisterId principal.Principal + replica.AddCanister( + canisterId, + func(request mock.Request) ([]any, error) { + if !bytes.Equal(request.Sender.Raw, principal.AnonymousID.Raw) { + t.Error("unexpected sender") + } + if request.MethodName != "test" { + t.Error("unexpected method name") + } + if len(request.Arguments) != 0 { + t.Error("unexpected arguments") + } + return []any{"hello"}, nil + }, + ) + + s := httptest.NewServer(replica) + u, _ := url.Parse(s.URL) + a, _ := agent.New(agent.Config{ + ClientConfig: &agent.ClientConfig{Host: u}, + FetchRootKey: true, + }) + + t.Run("call", func(t *testing.T) { + var result string + if err := a.Call(canisterId, "test", nil, []any{&result}); err != nil { + t.Error(err) + } + if result != "hello" { + t.Error("unexpected result") + } + }) + + t.Run("query", func(t *testing.T) { + var result string + if err := a.Query(canisterId, "test", nil, []any{&result}); err != nil { + t.Error(err) + } + if result != "hello" { + t.Error("unexpected result") + } + }) +} diff --git a/status.go b/status.go index e5982ab..7750612 100644 --- a/status.go +++ b/status.go @@ -24,6 +24,19 @@ type Status struct { RootKey []byte } +func (s *Status) MarshalCBOR() ([]byte, error) { + m := map[string]any{ + "ic_api_version": s.Version, + "root_key": s.RootKey, + } + if s.Impl != nil { + m["impl_source"] = s.Impl.Source + m["impl_version"] = s.Impl.Version + m["impl_revision"] = s.Impl.Revision + } + return cbor.Marshal(m) +} + // UnmarshalCBOR implements the CBOR unmarshaler interface. func (s *Status) UnmarshalCBOR(data []byte) error { var status struct {