Skip to content

Commit

Permalink
Merge pull request #33 from adamgall/better-signature-node-id-error-h…
Browse files Browse the repository at this point in the history
…andling

Return descriptive error if recovered address doesn't match given address
  • Loading branch information
jshufro authored Oct 24, 2024
2 parents 850e2fd + 852c2ea commit 56b714d
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 24 deletions.
5 changes: 3 additions & 2 deletions api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/Rocket-Rescue-Node/rescue-api/services"
"github.com/ethereum/go-ethereum/common"
"github.com/gorilla/mux"
"github.com/rs/cors"
"go.uber.org/zap"
Expand Down Expand Up @@ -52,7 +53,7 @@ func (ar *apiRouter) CreateCredential(w http.ResponseWriter, r *http.Request) er
}

// Create the credential
cred, err := ar.svc.CreateCredentialWithRetry([]byte(req.Msg), *sig, req.operatorType)
cred, err := ar.svc.CreateCredentialWithRetry([]byte(req.Msg), *sig, common.HexToAddress(req.Address), req.operatorType)
if err != nil {
return writeJSONError(w, err)
}
Expand Down Expand Up @@ -89,7 +90,7 @@ func (ar *apiRouter) GetOperatorInfo(w http.ResponseWriter, r *http.Request) err
req := (*OperatorInfoRequest)(credReq)

// Get operator info
operatorInfo, err := ar.svc.GetOperatorInfo([]byte(req.Msg), *sig, req.operatorType)
operatorInfo, err := ar.svc.GetOperatorInfo([]byte(req.Msg), *sig, common.HexToAddress(req.Address), req.operatorType)
if err != nil {
return writeJSONError(w, err)
}
Expand Down
9 changes: 5 additions & 4 deletions services/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/Rocket-Rescue-Node/credentials"
"github.com/Rocket-Rescue-Node/credentials/pb"
"github.com/Rocket-Rescue-Node/rescue-api/models"
"github.com/ethereum/go-ethereum/common"

"github.com/mattn/go-sqlite3"

Expand Down Expand Up @@ -102,15 +103,15 @@ func GetQuotaJSON(ot credentials.OperatorType) (json.RawMessage, error) {

// Creates a new credential for a node. If a valid credential already exists, it will be returned instead.
// This method will retry if creating a credential fails.
func (s *Service) CreateCredentialWithRetry(msg []byte, sig []byte, ot credentials.OperatorType) (*models.AuthenticatedCredential, error) {
func (s *Service) CreateCredentialWithRetry(msg []byte, sig []byte, expectedNodeId common.Address, ot credentials.OperatorType) (*models.AuthenticatedCredential, error) {
var cred *models.AuthenticatedCredential
var err error

var try int
s.m.Counter("create_credential_with_retry").Inc()
for try = range dbTryDelayMs {
// Try to create the credential.
if cred, err = s.CreateCredential(msg, sig, ot); err == nil {
if cred, err = s.CreateCredential(msg, sig, expectedNodeId, ot); err == nil {
break
}

Expand Down Expand Up @@ -148,11 +149,11 @@ func (s *Service) CreateCredentialWithRetry(msg []byte, sig []byte, ot credentia

// Creates a new credential for a node. If a valid credential exists, it will be returned instead.
// No retry logic is implemented, so it is up to the caller to retry if it does not succeed.
func (s *Service) CreateCredential(msg []byte, sig []byte, ot credentials.OperatorType) (*models.AuthenticatedCredential, error) {
func (s *Service) CreateCredential(msg []byte, sig []byte, expectedNodeId common.Address, ot credentials.OperatorType) (*models.AuthenticatedCredential, error) {
var err error

// Validate request
nodeID, err := s.validateSignedRequest(&msg, &sig, ot)
nodeID, err := s.validateSignedRequest(&msg, &sig, expectedNodeId, ot)
if err != nil {
return nil, err
}
Expand Down
31 changes: 17 additions & 14 deletions services/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/Rocket-Rescue-Node/credentials"
"github.com/Rocket-Rescue-Node/credentials/pb"
"github.com/Rocket-Rescue-Node/rescue-api/util"
"github.com/ethereum/go-ethereum/common"
"github.com/jonboulle/clockwork"
)

Expand All @@ -25,7 +26,7 @@ func createValidCredential(svc *Service, node *util.Wallet) (*credentials.Authen
return nil, fmt.Errorf("Could not sign message: %v", err)
}
// Create credential.
cred, err := svc.CreateCredentialWithRetry(msg, sig, pb.OperatorType_OT_ROCKETPOOL)
cred, err := svc.CreateCredentialWithRetry(msg, sig, *node.Address, pb.OperatorType_OT_ROCKETPOOL)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -230,25 +231,27 @@ func TestCreateCredentialRequests(t *testing.T) {
name string
msg []byte
sig []byte
adr common.Address
ot credentials.OperatorType
err error
}{
{"valid", msg, sig, pb.OperatorType_OT_ROCKETPOOL, nil},
{"valid_solo", msg, soloSig, pb.OperatorType_OT_SOLO, nil},
{"solo_masquerading_rp", msg, soloSig, pb.OperatorType_OT_ROCKETPOOL, &AuthorizationError{}},
{"rp_masquerading_solo", msg, sig, pb.OperatorType_OT_SOLO, &AuthorizationError{}},
{"malformed_signature", msg, []byte("invalid"), pb.OperatorType_OT_ROCKETPOOL, &AuthenticationError{}},
{"invalid_signature", msg, invalidSig, pb.OperatorType_OT_ROCKETPOOL, &AuthenticationError{}},
{"malformed_message", badMsg, badMsgSig, pb.OperatorType_OT_ROCKETPOOL, &ValidationError{}},
{"expired_timestamp", oldMsg, oldMsgSig, pb.OperatorType_OT_ROCKETPOOL, &ValidationError{}},
{"empty_message", []byte{}, sig, pb.OperatorType_OT_ROCKETPOOL, &ValidationError{}},
{"empty_signature", msg, []byte{}, pb.OperatorType_OT_ROCKETPOOL, &AuthenticationError{}},
{"unknown_node", otherMsg, otherSig, pb.OperatorType_OT_ROCKETPOOL, &AuthorizationError{}},
{"valid", msg, sig, *node.Address, pb.OperatorType_OT_ROCKETPOOL, nil},
{"valid_solo", msg, soloSig, *withdrawalAddress.Address, pb.OperatorType_OT_SOLO, nil},
{"solo_masquerading_rp", msg, soloSig, *withdrawalAddress.Address, pb.OperatorType_OT_ROCKETPOOL, &AuthorizationError{}},
{"rp_masquerading_solo", msg, sig, *node.Address, pb.OperatorType_OT_SOLO, &AuthorizationError{}},
{"malformed_signature", msg, []byte("invalid"), *node.Address, pb.OperatorType_OT_ROCKETPOOL, &AuthenticationError{}},
{"invalid_signature", msg, invalidSig, *node.Address, pb.OperatorType_OT_ROCKETPOOL, &AuthenticationError{}},
{"malformed_message", badMsg, badMsgSig, *node.Address, pb.OperatorType_OT_ROCKETPOOL, &ValidationError{}},
{"expired_timestamp", oldMsg, oldMsgSig, *node.Address, pb.OperatorType_OT_ROCKETPOOL, &ValidationError{}},
{"empty_message", []byte{}, sig, *node.Address, pb.OperatorType_OT_ROCKETPOOL, &ValidationError{}},
{"empty_signature", msg, []byte{}, *node.Address, pb.OperatorType_OT_ROCKETPOOL, &AuthenticationError{}},
{"unknown_node", otherMsg, otherSig, *otherNode.Address, pb.OperatorType_OT_ROCKETPOOL, &AuthorizationError{}},
{"mismatched_address", msg, sig, *otherNode.Address, pb.OperatorType_OT_ROCKETPOOL, &AuthenticationError{}},
}

for _, d := range data {
t.Run(d.name, func(t *testing.T) {
_, err := svc.CreateCredentialWithRetry(d.msg, d.sig, d.ot)
_, err := svc.CreateCredentialWithRetry(d.msg, d.sig, d.adr, d.ot)
if !errors.Is(err, d.err) {
t.Fatalf("Expected error %v, got %v", d.err, err)
}
Expand Down Expand Up @@ -299,7 +302,7 @@ func TestCreateCredentialConcurrent(t *testing.T) {
errChan <- err
return
}
_, err = svc.CreateCredentialWithRetry(msg, sig, pb.OperatorType_OT_ROCKETPOOL)
_, err = svc.CreateCredentialWithRetry(msg, sig, *nodes[i].Address, pb.OperatorType_OT_ROCKETPOOL)
if err != nil {
t.Errorf("Could not create credential %d: %v", i, err)
errChan <- err
Expand Down
5 changes: 3 additions & 2 deletions services/operator_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@ import (

"github.com/Rocket-Rescue-Node/credentials"
"github.com/Rocket-Rescue-Node/rescue-api/models"
"github.com/ethereum/go-ethereum/common"
"go.uber.org/zap"
)

type OperatorInfo struct {
CredentialEvents []int64 `json:"credentialEvents"`
}

func (s *Service) GetOperatorInfo(msg []byte, sig []byte, ot credentials.OperatorType) (*OperatorInfo, error) {
func (s *Service) GetOperatorInfo(msg []byte, sig []byte, expectedNodeId common.Address, ot credentials.OperatorType) (*OperatorInfo, error) {
var err error

// Validate request
nodeID, err := s.validateSignedRequest(&msg, &sig, ot)
nodeID, err := s.validateSignedRequest(&msg, &sig, expectedNodeId, ot)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion services/operator_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func getOperatorInfo(svc *Service, node *util.Wallet) (*OperatorInfo, error) {
}

// Get operator info
info, err := svc.GetOperatorInfo(msg, sig, pb.OperatorType_OT_ROCKETPOOL)
info, err := svc.GetOperatorInfo(msg, sig, *node.Address, pb.OperatorType_OT_ROCKETPOOL)
if err != nil {
return nil, err
}
Expand Down
8 changes: 7 additions & 1 deletion services/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package services
import (
"context"
"database/sql"
"fmt"
"regexp"
"time"

Expand Down Expand Up @@ -318,7 +319,7 @@ func (s *Service) checkNodeAuthorization(nodeID *models.NodeID, ot creds.Operato
return nil
}

func (s *Service) validateSignedRequest(msg *[]byte, sig *[]byte, ot pb.OperatorType) (*common.Address, error) {
func (s *Service) validateSignedRequest(msg *[]byte, sig *[]byte, expectedNodeId common.Address, ot pb.OperatorType) (*common.Address, error) {
// Check request age
if err := s.checkRequestAge(msg); err != nil {
return nil, err
Expand All @@ -330,6 +331,11 @@ func (s *Service) validateSignedRequest(msg *[]byte, sig *[]byte, ot pb.Operator
return nil, err
}

// Check if the nodeID matches the expected nodeID
if *nodeID != expectedNodeId {
return nil, &AuthenticationError{fmt.Sprintf("provided node id (%s) did not match address (%s) which signed the message", expectedNodeId.Hex(), nodeID.Hex())}
}

// Check node authz
if err := s.checkNodeAuthorization(nodeID, ot); err != nil {
return nil, err
Expand Down

0 comments on commit 56b714d

Please sign in to comment.