Skip to content

Commit

Permalink
Merge pull request #429 from positiveblue/fix-426
Browse files Browse the repository at this point in the history
rpc: populate allowed/not allowed node ids when listing orders
  • Loading branch information
positiveblue authored Jan 9, 2023
2 parents 15dfbfb + a6aea9c commit cf4ca47
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 21 deletions.
61 changes: 47 additions & 14 deletions order/rpc_parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,23 +127,19 @@ func ParseRPCOrder(version, leaseDuration uint32,
"at the same time")
}

kit.AllowedNodeIDs = make([][33]byte, len(details.AllowedNodeIds))
for idx, nodeID := range details.AllowedNodeIds {
if _, err := btcec.ParsePubKey(nodeID); err != nil {
return nil, fmt.Errorf("invalid allowed_node_id: %x",
nodeID)
}
copy(kit.AllowedNodeIDs[idx][:], nodeID)
allowedNodeIDs, err := UnmarshalNodeIDSlice(details.AllowedNodeIds)
if err != nil {
return nil, fmt.Errorf("invalid allowed_node_ids: %v", err)
}
kit.AllowedNodeIDs = allowedNodeIDs

kit.NotAllowedNodeIDs = make([][33]byte, len(details.NotAllowedNodeIds))
for idx, nodeID := range details.NotAllowedNodeIds {
if _, err := btcec.ParsePubKey(nodeID); err != nil {
return nil, fmt.Errorf("invalid not_allowed_node_id: "+
"%x", nodeID)
}
copy(kit.NotAllowedNodeIDs[idx][:], nodeID)
notAllowedNodeIDs, err := UnmarshalNodeIDSlice(
details.NotAllowedNodeIds,
)
if err != nil {
return nil, fmt.Errorf("invalid not_allowed_node_ids: %v", err)
}
kit.NotAllowedNodeIDs = notAllowedNodeIDs

kit.IsPublic = details.IsPublic

Expand Down Expand Up @@ -524,6 +520,43 @@ func ParseRPCSign(signMsg *auctioneerrpc.OrderMatchSignBegin) (AccountNonces,
return nonces, prevOutputs, nil
}

// MarshalNodeIDSlice returns a flattened version of an slice of node ids to be
// used in rpc serialization.
func MarshalNodeIDSlice(nodeIDs [][33]byte) [][]byte {
res := make([][]byte, 0, len(nodeIDs))

for i := range nodeIDs {
nodeID := make([]byte, 33)
copy(nodeID, nodeIDs[i][:])

res = append(res, nodeID)
}

return res
}

// UnmarshalNodeIDSlice returns a slice of node ids from a flatten version.
func UnmarshalNodeIDSlice(slice [][]byte) ([][33]byte, error) {
nodeIDs := make([][33]byte, len(slice))
for idx := range slice {
// Check that the node id pub key is in the correct format.
if len(slice[idx]) != 33 {
return nil, fmt.Errorf("invalid node_id length: %x",
slice[idx])
}

// Check that the node id pub key is a valid key.
if _, err := btcec.ParsePubKey(slice[idx]); err != nil {
return nil, fmt.Errorf("invalid node_id: %x",
slice[idx])
}

copy(nodeIDs[idx][:], slice[idx])
}

return nodeIDs, nil
}

// randomPreimage creates a new preimage from a random number generator.
func randomPreimage() ([]byte, error) {
var nonce Nonce
Expand Down
94 changes: 94 additions & 0 deletions order/rpc_parse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package order

import (
"encoding/hex"
"testing"

"github.com/stretchr/testify/require"
)

var nodeIDSerializationTestCases = []struct {
name string
nodeIDs func() [][33]byte
invalidSerializedData func() [][]byte
expectedErr string
}{{
name: "empty slice",
nodeIDs: func() [][33]byte {
return [][33]byte{}
},
}, {
name: "single node id",
nodeIDs: func() [][33]byte {
return [][33]byte{
nodePubkey,
}
},
}, {
name: "multiple node ids",
nodeIDs: func() [][33]byte {
nodeID, _ := hex.DecodeString("036b51e0cc2d9e5988ee4967e0ba67" +
"ef3727bb633fea21a0af58e0c9395446ba09")
var nodePubKey2 [33]byte
copy(nodePubKey2[:], nodeID)

return [][33]byte{
nodePubkey,
nodePubKey2,
}
},
}, {
name: "invalid length",
invalidSerializedData: func() [][]byte {
return [][]byte{
{1, 2},
}
},
expectedErr: "invalid node_id length",
}, {
name: "invalid pub key",
invalidSerializedData: func() [][]byte {
return MarshalNodeIDSlice([][33]byte{
{1, 2},
})
},
expectedErr: "invalid node_id:",
}}

// TestNodeIDSliceSerialization tests that we can properly serialize and
// deserialize a slice of node ids.
func TestNodeIDSliceSerialization(t *testing.T) {
for _, tc := range nodeIDSerializationTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

switch {
// Marshal and Unmarshal valid node ids.
case tc.nodeIDs != nil:
nodeIDs := tc.nodeIDs()
marshaled := MarshalNodeIDSlice(nodeIDs)
require.Equal(t, len(nodeIDs), len(marshaled))

unmarshaled, err := UnmarshalNodeIDSlice(
marshaled,
)

require.NoError(t, err)
require.Equal(t, tc.nodeIDs(), unmarshaled)

// Unmarshal invalid marshaled node ids.
case tc.invalidSerializedData != nil:
marshaled := tc.invalidSerializedData()

_, err := UnmarshalNodeIDSlice(marshaled)
require.Error(t, err)
require.Contains(t, err.Error(), tc.expectedErr)

default:
require.Fail(t, "invalid test case")
}
})
}
}
15 changes: 12 additions & 3 deletions rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1700,6 +1700,14 @@ func (s *rpcServer) ListOrders(ctx context.Context,
}
}

allowedNodeIDs := order.MarshalNodeIDSlice(
dbOrder.Details().AllowedNodeIDs,
)

notAllowedNodeIDs := order.MarshalNodeIDSlice(
dbOrder.Details().NotAllowedNodeIDs,
)

details := &poolrpc.Order{
TraderKey: dbDetails.AcctKey[:],
RateFixed: dbDetails.FixedRate,
Expand All @@ -1723,7 +1731,9 @@ func (s *rpcServer) ListOrders(ctx context.Context,
AuctionType: auctioneerrpc.AuctionType(
dbOrder.Details().AuctionType,
),
IsPublic: dbOrder.Details().IsPublic,
AllowedNodeIds: allowedNodeIDs,
NotAllowedNodeIds: notAllowedNodeIDs,
IsPublic: dbOrder.Details().IsPublic,
}

switch o := dbOrder.(type) {
Expand Down Expand Up @@ -1780,8 +1790,7 @@ func (s *rpcServer) ListOrders(ctx context.Context,
bids = append(bids, rpcBid)

default:
return nil, fmt.Errorf("unknown order type: %v",
o)
return nil, fmt.Errorf("unknown order type: %v", o)
}
}
return &poolrpc.ListOrdersResponse{
Expand Down
5 changes: 1 addition & 4 deletions tools/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ RUN cd /tmp \
&& mkdir -p /tmp/build/.modcache \
&& cd /tmp/tools \
&& go install -trimpath -tags=tools github.com/golangci/golangci-lint/cmd/golangci-lint \
&& chmod -R 777 /tmp/build/ \
&& git config --global --add safe.directory /build
# The last line is needed to ensure that go build is able to gather
# information from the vsc used in the builds to get the commit hash.
&& chmod -R 777 /tmp/build/

WORKDIR /build

0 comments on commit cf4ca47

Please sign in to comment.