diff --git a/.github/workflows/test-registry.yml b/.github/workflows/test-registry.yml new file mode 100644 index 0000000..9ead23a --- /dev/null +++ b/.github/workflows/test-registry.yml @@ -0,0 +1,13 @@ +on: + push: + paths: + - registry/** +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version: '1.22.1' + - run: make test-registry diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b503198..df5675a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,7 +6,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-go@v4 with: - go-version: '1.21.0' + go-version: '1.22.1' - uses: aviate-labs/setup-dfx@v0.3.2 with: dfx-version: 0.18.0 diff --git a/registry/client.go b/registry/client.go index 281698d..917c7da 100644 --- a/registry/client.go +++ b/registry/client.go @@ -16,7 +16,7 @@ type Client struct { func New() (*Client, error) { dp, err := NewDataProvider() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create data provider: %w", err) } return &Client{ dp: dp, @@ -26,15 +26,19 @@ func New() (*Client, error) { func (c *Client) GetNNSSubnetID() (*principal.Principal, error) { v, _, err := c.dp.GetValueUpdate([]byte("nns_subnet_id"), nil) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get NNS subnet ID: %w", err) } var nnsSubnetID v1.SubnetId if err := proto.Unmarshal(v, &nnsSubnetID); err != nil { - return nil, err + return nil, fmt.Errorf("failed to unmarshal NNS subnet ID: %w", err) } return &principal.Principal{Raw: nnsSubnetID.PrincipalId.Raw}, nil } +func (c *Client) GetLatestVersion() (uint64, error) { + return c.dp.GetLatestVersion() +} + func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) { nnsSubnetID, err := c.GetNNSSubnetID() if err != nil { @@ -56,7 +60,7 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) { for { records, _, err := c.dp.GetCertifiedChangesSince(currentVersion, nnsPublicKey) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get certified changes: %w", err) } currentVersion = records[len(records)-1].Version for _, record := range records { @@ -66,7 +70,7 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) { } else { var nodeRecord v1.NodeRecord if err := proto.Unmarshal(record.Value, &nodeRecord); err != nil { - return nil, err + return nil, fmt.Errorf("failed to unmarshal node record: %w", err) } nodeMap[strings.TrimPrefix(record.Key, "node_record_")] = &nodeRecord } @@ -76,7 +80,7 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) { } else { var nodeOperatorRecord v1.NodeOperatorRecord if err := proto.Unmarshal(record.Value, &nodeOperatorRecord); err != nil { - return nil, err + return nil, fmt.Errorf("failed to unmarshal node operator record: %w", err) } nodeOperatorMap[strings.TrimPrefix(record.Key, "node_operator_record_")] = &nodeOperatorRecord } @@ -124,11 +128,11 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) { func (c *Client) GetSubnetDetails(subnetID principal.Principal) (*v1.SubnetRecord, error) { v, _, err := c.dp.GetValueUpdate([]byte(fmt.Sprintf("subnet_record_%s", subnetID)), nil) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get subnet details: %w", err) } var record v1.SubnetRecord if err := proto.Unmarshal(v, &record); err != nil { - return nil, err + return nil, fmt.Errorf("failed to unmarshal subnet details: %w", err) } return &record, nil } @@ -136,11 +140,11 @@ func (c *Client) GetSubnetDetails(subnetID principal.Principal) (*v1.SubnetRecor func (c *Client) GetSubnetIDs() ([]principal.Principal, error) { v, _, err := c.dp.GetValueUpdate([]byte("subnet_list"), nil) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get subnet IDs: %w", err) } var list v1.SubnetListRecord if err := proto.Unmarshal(v, &list); err != nil { - return nil, err + return nil, fmt.Errorf("failed to unmarshal subnet IDs: %w", err) } var subnets []principal.Principal for _, subnet := range list.Subnets { @@ -152,11 +156,11 @@ func (c *Client) GetSubnetIDs() ([]principal.Principal, error) { func (c *Client) GetSubnetPublicKey(subnetID principal.Principal) ([]byte, error) { v, _, err := c.dp.GetValueUpdate([]byte(fmt.Sprintf("crypto_threshold_signing_public_key_%s", subnetID)), nil) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get subnet public key: %w", err) } var publicKey v1.PublicKey if err := proto.Unmarshal(v, &publicKey); err != nil { - return nil, err + return nil, fmt.Errorf("failed to unmarshal subnet public key: %w", err) } if publicKey.Algorithm != v1.AlgorithmId_ALGORITHM_ID_THRES_BLS12_381 { return nil, fmt.Errorf("unsupported public key algorithm") diff --git a/registry/client_test.go b/registry/client_test.go index a7dc6fc..61c0129 100644 --- a/registry/client_test.go +++ b/registry/client_test.go @@ -1,17 +1,25 @@ -package registry +package registry_test import ( + "github.com/aviate-labs/agent-go/registry" "os" "testing" ) func TestClient_GetNodeListSince(t *testing.T) { checkEnabled(t) - c, err := New() + + c, err := registry.New() + if err != nil { + t.Fatal(err) + } + + latestVersion, err := c.GetLatestVersion() if err != nil { t.Fatal(err) } - if _, err := c.GetNodeListSince(0); err != nil { + + if _, err := c.GetNodeListSince(latestVersion - 100); err != nil { t.Fatal(err) } } diff --git a/registry/dataprovider.go b/registry/dataprovider.go index d7514ec..42aa870 100644 --- a/registry/dataprovider.go +++ b/registry/dataprovider.go @@ -22,7 +22,7 @@ type DataProvider struct { func NewDataProvider() (*DataProvider, error) { a, err := agent.New(agent.DefaultConfig) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create agent: %w", err) } return &DataProvider{a: a}, nil } @@ -37,28 +37,28 @@ func (d DataProvider) GetCertifiedChangesSince(version uint64, publicKey []byte) }, &resp, ); err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to get certified changes: %w", err) } ht, err := NewHashTree(resp.HashTree) if err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to create hash tree: %w", err) } rawCurrentVersion, err := ht.Lookup(hashtree.Label("current_version")) if err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to lookup current version: %w", err) } currentVersion, err := leb128.DecodeUnsigned(bytes.NewReader(rawCurrentVersion)) if err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to decode current version: %w", err) } deltaNodes, err := ht.LookupSubTree(hashtree.Label("delta")) if err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to lookup delta nodes: %w", err) } rawDeltas, err := hashtree.AllChildren(deltaNodes) if err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to get all children: %w", err) } var deltas []VersionedRecord @@ -66,7 +66,7 @@ func (d DataProvider) GetCertifiedChangesSince(version uint64, publicKey []byte) for _, delta := range rawDeltas { req := new(v1.RegistryAtomicMutateRequest) if err := proto.Unmarshal(delta.Value, req); err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to unmarshal atomic mutate request: %w", err) } v := binary.BigEndian.Uint64(delta.Path[0]) @@ -99,7 +99,7 @@ func (d DataProvider) GetCertifiedChangesSince(version uint64, publicKey []byte) publicKey, digest[:], ); err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to verify certified data: %w", err) } return deltas, currentVersion.Uint64(), nil @@ -116,7 +116,7 @@ func (d DataProvider) GetChangesSince(version uint64) ([]*v1.RegistryDelta, uint }, &resp, ); err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to get changes since: %w", err) } if resp.Error != nil { return nil, 0, fmt.Errorf("error: %s", resp.Error.String()) @@ -132,7 +132,7 @@ func (d DataProvider) GetLatestVersion() (uint64, error) { nil, &resp, ); err != nil { - return 0, err + return 0, fmt.Errorf("failed to get latest version: %w", err) } return resp.Version, nil } @@ -154,7 +154,7 @@ func (d DataProvider) GetValue(key []byte, version *uint64) ([]byte, uint64, err }, &resp, ); err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to get value: %w", err) } if resp.Error != nil { return nil, 0, fmt.Errorf("error: %s", resp.Error.String()) @@ -178,7 +178,7 @@ func (d DataProvider) GetValueUpdate(key []byte, version *uint64) ([]byte, uint6 }, &resp, ); err != nil { - return nil, 0, err + return nil, 0, fmt.Errorf("failed to get value: %w", err) } if resp.Error != nil { return nil, 0, fmt.Errorf("error: %s", resp.Error.String()) diff --git a/registry/dataprovider_test.go b/registry/dataprovider_test.go new file mode 100644 index 0000000..98031e5 --- /dev/null +++ b/registry/dataprovider_test.go @@ -0,0 +1,18 @@ +package registry_test + +import ( + "github.com/aviate-labs/agent-go/registry" + "testing" +) + +func TestDataProvider_GetLatestVersion(t *testing.T) { + checkEnabled(t) + + dp, err := registry.NewDataProvider() + if err != nil { + t.Fatal(err) + } + if _, err := dp.GetLatestVersion(); err != nil { + t.Error(err) + } +} diff --git a/request.go b/request.go index fa62027..eb80020 100644 --- a/request.go +++ b/request.go @@ -79,7 +79,9 @@ func (r *Request) MarshalCBOR() ([]byte, error) { if len(r.MethodName) != 0 { m["method_name"] = r.MethodName } - if len(r.Arguments) != 0 { + if r.Arguments != nil { + // Some endpoints require the argument to be an empty array, not null. + // This is the case with the protobuf endpoints on the registry. m["arg"] = r.Arguments } if len(r.Sender.Raw) != 0 {