From dbe58a4f0e38c8772c9c9c25bbaef2a4d0f6d95a Mon Sep 17 00:00:00 2001 From: the_aceix Date: Mon, 11 Dec 2023 13:57:00 +0000 Subject: [PATCH] fix(NET-786): enhance enrollment key validation --- controllers/enrollmentkeys.go | 30 ++++++++++++++++++++++++++++++ logic/enrollmentkey.go | 6 +++--- logic/enrollmentkey_test.go | 2 +- models/enrollment_key.go | 31 +++++++++++++++++++++++++------ 4 files changed, 59 insertions(+), 10 deletions(-) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index c7de4f92a..5c586c476 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "github.com/go-playground/validator/v10" "github.com/google/uuid" "github.com/gorilla/mux" @@ -115,6 +116,35 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { if enrollmentKeyBody.Expiration > 0 { newTime = time.Unix(enrollmentKeyBody.Expiration, 0) } + v := validator.New() + err = v.Struct(enrollmentKeyBody) + if err != nil { + logger.Log(0, r.Header.Get("user"), "error validating request body: ", + err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("validation error: name length must be between 3 and 32: %w", err), "badrequest")) + return + } + + if existingKeys, err := logic.GetAllEnrollmentKeys(); err != nil { + logger.Log(0, r.Header.Get("user"), "error validating request body: ", + err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } else { + // check if any tags are duplicate + existingTags := make(map[string]struct{}) + for _, existingKey := range existingKeys { + for _, t := range existingKey.Tags { + existingTags[t] = struct{}{} + } + } + for _, t := range enrollmentKeyBody.Tags { + if _, ok := existingTags[t]; ok { + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("key names must be unique"), "badrequest")) + return + } + } + } relayId := uuid.Nil if enrollmentKeyBody.Relay != "" { diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index 5605bdac9..ae5d01d5f 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -22,7 +22,7 @@ var EnrollmentErrors = struct { FailedToTokenize error FailedToDeTokenize error }{ - InvalidCreate: fmt.Errorf("invalid enrollment key created"), + InvalidCreate: fmt.Errorf("failed to create enrollment key. paramters invalid"), NoKeyFound: fmt.Errorf("no enrollmentkey found"), InvalidKey: fmt.Errorf("invalid key provided"), NoUsesRemaining: fmt.Errorf("no uses remaining"), @@ -61,8 +61,8 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string if len(tags) > 0 { k.Tags = tags } - if ok := k.Validate(); !ok { - return nil, EnrollmentErrors.InvalidCreate + if err := k.Validate(); err != nil { + return nil, err } if relay != uuid.Nil { relayNode, err := GetNodeByID(relay.String()) diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index 3812b08a2..f91469ad0 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -17,7 +17,7 @@ func TestCreateEnrollmentKey(t *testing.T) { newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false, uuid.Nil) assert.Nil(t, newKey) assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentErrors.InvalidCreate) + assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey) }) t.Run("Can_Create_Key_Uses", func(t *testing.T) { newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil) diff --git a/models/enrollment_key.go b/models/enrollment_key.go index 982c5463b..e775344df 100644 --- a/models/enrollment_key.go +++ b/models/enrollment_key.go @@ -1,6 +1,8 @@ package models import ( + "errors" + "fmt" "time" "github.com/google/uuid" @@ -13,6 +15,14 @@ const ( Unlimited ) +var ( + ErrNilEnrollmentKey = errors.New("enrollment key is nil") + ErrNilNetworksEnrollmentKey = errors.New("enrollment key networks is nil") + ErrNilTagsEnrollmentKey = errors.New("enrollment key tags is nil") + ErrInvalidEnrollmentKey = errors.New("enrollment key is not valid") + ErrInvalidEnrollmentKeyValue = errors.New("enrollment key value is not valid") +) + // KeyType - the type of enrollment key type KeyType int @@ -50,7 +60,7 @@ type APIEnrollmentKey struct { UsesRemaining int `json:"uses_remaining"` Networks []string `json:"networks"` Unlimited bool `json:"unlimited"` - Tags []string `json:"tags"` + Tags []string `json:"tags" validate:"required,dive,min=3,max=32"` Type KeyType `json:"type"` Relay string `json:"relay"` } @@ -81,9 +91,18 @@ func (k *EnrollmentKey) IsValid() bool { // EnrollmentKey.Validate - validate's an EnrollmentKey // should be used during creation -func (k *EnrollmentKey) Validate() bool { - return k.Networks != nil && - k.Tags != nil && - len(k.Value) == EnrollmentKeyLength && - k.IsValid() +func (k *EnrollmentKey) Validate() error { + if k == nil { + return ErrNilEnrollmentKey + } + if k.Tags == nil { + return ErrNilTagsEnrollmentKey + } + if len(k.Value) != EnrollmentKeyLength { + return fmt.Errorf("%w: length not %d characters", ErrInvalidEnrollmentKeyValue, EnrollmentKeyLength) + } + if !k.IsValid() { + return fmt.Errorf("%w: uses remaining: %d, expiration: %s, unlimited: %t", ErrInvalidEnrollmentKey, k.UsesRemaining, k.Expiration, k.Unlimited) + } + return nil }