Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix non-empty arg for protobuf endpoints. #32

Merged
merged 1 commit into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/test-registry.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
on:
push:
paths:
- registry/**
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.22.1'
- run: make test-registry
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.21.0'
go-version: '1.22.1'
- uses: aviate-labs/[email protected]
with:
dfx-version: 0.18.0
Expand Down
28 changes: 16 additions & 12 deletions registry/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type Client struct {
func New() (*Client, error) {
dp, err := NewDataProvider()
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create data provider: %w", err)
}
return &Client{
dp: dp,
Expand All @@ -26,15 +26,19 @@ func New() (*Client, error) {
func (c *Client) GetNNSSubnetID() (*principal.Principal, error) {
v, _, err := c.dp.GetValueUpdate([]byte("nns_subnet_id"), nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get NNS subnet ID: %w", err)
}
var nnsSubnetID v1.SubnetId
if err := proto.Unmarshal(v, &nnsSubnetID); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal NNS subnet ID: %w", err)
}
return &principal.Principal{Raw: nnsSubnetID.PrincipalId.Raw}, nil
}

func (c *Client) GetLatestVersion() (uint64, error) {
return c.dp.GetLatestVersion()
}

func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) {
nnsSubnetID, err := c.GetNNSSubnetID()
if err != nil {
Expand All @@ -56,7 +60,7 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) {
for {
records, _, err := c.dp.GetCertifiedChangesSince(currentVersion, nnsPublicKey)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get certified changes: %w", err)
}
currentVersion = records[len(records)-1].Version
for _, record := range records {
Expand All @@ -66,7 +70,7 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) {
} else {
var nodeRecord v1.NodeRecord
if err := proto.Unmarshal(record.Value, &nodeRecord); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal node record: %w", err)
}
nodeMap[strings.TrimPrefix(record.Key, "node_record_")] = &nodeRecord
}
Expand All @@ -76,7 +80,7 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) {
} else {
var nodeOperatorRecord v1.NodeOperatorRecord
if err := proto.Unmarshal(record.Value, &nodeOperatorRecord); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal node operator record: %w", err)
}
nodeOperatorMap[strings.TrimPrefix(record.Key, "node_operator_record_")] = &nodeOperatorRecord
}
Expand Down Expand Up @@ -124,23 +128,23 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) {
func (c *Client) GetSubnetDetails(subnetID principal.Principal) (*v1.SubnetRecord, error) {
v, _, err := c.dp.GetValueUpdate([]byte(fmt.Sprintf("subnet_record_%s", subnetID)), nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get subnet details: %w", err)
}
var record v1.SubnetRecord
if err := proto.Unmarshal(v, &record); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal subnet details: %w", err)
}
return &record, nil
}

func (c *Client) GetSubnetIDs() ([]principal.Principal, error) {
v, _, err := c.dp.GetValueUpdate([]byte("subnet_list"), nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get subnet IDs: %w", err)
}
var list v1.SubnetListRecord
if err := proto.Unmarshal(v, &list); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal subnet IDs: %w", err)
}
var subnets []principal.Principal
for _, subnet := range list.Subnets {
Expand All @@ -152,11 +156,11 @@ func (c *Client) GetSubnetIDs() ([]principal.Principal, error) {
func (c *Client) GetSubnetPublicKey(subnetID principal.Principal) ([]byte, error) {
v, _, err := c.dp.GetValueUpdate([]byte(fmt.Sprintf("crypto_threshold_signing_public_key_%s", subnetID)), nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get subnet public key: %w", err)
}
var publicKey v1.PublicKey
if err := proto.Unmarshal(v, &publicKey); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal subnet public key: %w", err)
}
if publicKey.Algorithm != v1.AlgorithmId_ALGORITHM_ID_THRES_BLS12_381 {
return nil, fmt.Errorf("unsupported public key algorithm")
Expand Down
14 changes: 11 additions & 3 deletions registry/client_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
package registry
package registry_test

import (
"github.com/aviate-labs/agent-go/registry"
"os"
"testing"
)

func TestClient_GetNodeListSince(t *testing.T) {
checkEnabled(t)
c, err := New()

c, err := registry.New()
if err != nil {
t.Fatal(err)
}

latestVersion, err := c.GetLatestVersion()
if err != nil {
t.Fatal(err)
}
if _, err := c.GetNodeListSince(0); err != nil {

if _, err := c.GetNodeListSince(latestVersion - 100); err != nil {
t.Fatal(err)
}
}
Expand Down
26 changes: 13 additions & 13 deletions registry/dataprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type DataProvider struct {
func NewDataProvider() (*DataProvider, error) {
a, err := agent.New(agent.DefaultConfig)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create agent: %w", err)
}
return &DataProvider{a: a}, nil
}
Expand All @@ -37,36 +37,36 @@ func (d DataProvider) GetCertifiedChangesSince(version uint64, publicKey []byte)
},
&resp,
); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to get certified changes: %w", err)
}
ht, err := NewHashTree(resp.HashTree)
if err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to create hash tree: %w", err)
}
rawCurrentVersion, err := ht.Lookup(hashtree.Label("current_version"))
if err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to lookup current version: %w", err)
}
currentVersion, err := leb128.DecodeUnsigned(bytes.NewReader(rawCurrentVersion))
if err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to decode current version: %w", err)
}

deltaNodes, err := ht.LookupSubTree(hashtree.Label("delta"))
if err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to lookup delta nodes: %w", err)
}
rawDeltas, err := hashtree.AllChildren(deltaNodes)
if err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to get all children: %w", err)
}

var deltas []VersionedRecord
lastVersion := version
for _, delta := range rawDeltas {
req := new(v1.RegistryAtomicMutateRequest)
if err := proto.Unmarshal(delta.Value, req); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to unmarshal atomic mutate request: %w", err)
}

v := binary.BigEndian.Uint64(delta.Path[0])
Expand Down Expand Up @@ -99,7 +99,7 @@ func (d DataProvider) GetCertifiedChangesSince(version uint64, publicKey []byte)
publicKey,
digest[:],
); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to verify certified data: %w", err)
}

return deltas, currentVersion.Uint64(), nil
Expand All @@ -116,7 +116,7 @@ func (d DataProvider) GetChangesSince(version uint64) ([]*v1.RegistryDelta, uint
},
&resp,
); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to get changes since: %w", err)
}
if resp.Error != nil {
return nil, 0, fmt.Errorf("error: %s", resp.Error.String())
Expand All @@ -132,7 +132,7 @@ func (d DataProvider) GetLatestVersion() (uint64, error) {
nil,
&resp,
); err != nil {
return 0, err
return 0, fmt.Errorf("failed to get latest version: %w", err)
}
return resp.Version, nil
}
Expand All @@ -154,7 +154,7 @@ func (d DataProvider) GetValue(key []byte, version *uint64) ([]byte, uint64, err
},
&resp,
); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to get value: %w", err)
}
if resp.Error != nil {
return nil, 0, fmt.Errorf("error: %s", resp.Error.String())
Expand All @@ -178,7 +178,7 @@ func (d DataProvider) GetValueUpdate(key []byte, version *uint64) ([]byte, uint6
},
&resp,
); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to get value: %w", err)
}
if resp.Error != nil {
return nil, 0, fmt.Errorf("error: %s", resp.Error.String())
Expand Down
18 changes: 18 additions & 0 deletions registry/dataprovider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package registry_test

import (
"github.com/aviate-labs/agent-go/registry"
"testing"
)

func TestDataProvider_GetLatestVersion(t *testing.T) {
checkEnabled(t)

dp, err := registry.NewDataProvider()
if err != nil {
t.Fatal(err)
}
if _, err := dp.GetLatestVersion(); err != nil {
t.Error(err)
}
}
4 changes: 3 additions & 1 deletion request.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ func (r *Request) MarshalCBOR() ([]byte, error) {
if len(r.MethodName) != 0 {
m["method_name"] = r.MethodName
}
if len(r.Arguments) != 0 {
if r.Arguments != nil {
// Some endpoints require the argument to be an empty array, not null.
// This is the case with the protobuf endpoints on the registry.
m["arg"] = r.Arguments
}
if len(r.Sender.Raw) != 0 {
Expand Down
Loading