Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(NET-786): enhance enrollment key validation #2726

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions controllers/enrollmentkeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"time"

"github.com/go-playground/validator/v10"
"github.com/google/uuid"
"github.com/gorilla/mux"

Expand Down Expand Up @@ -115,6 +116,35 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
if enrollmentKeyBody.Expiration > 0 {
newTime = time.Unix(enrollmentKeyBody.Expiration, 0)
}
v := validator.New()
err = v.Struct(enrollmentKeyBody)
if err != nil {
logger.Log(0, r.Header.Get("user"), "error validating request body: ",
err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("validation error: name length must be between 3 and 32: %w", err), "badrequest"))
return
}

if existingKeys, err := logic.GetAllEnrollmentKeys(); err != nil {
logger.Log(0, r.Header.Get("user"), "error validating request body: ",
err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
} else {
// check if any tags are duplicate
existingTags := make(map[string]struct{})
for _, existingKey := range existingKeys {
for _, t := range existingKey.Tags {
existingTags[t] = struct{}{}
}
}
for _, t := range enrollmentKeyBody.Tags {
if _, ok := existingTags[t]; ok {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("key names must be unique"), "badrequest"))
return
}
}
}

relayId := uuid.Nil
if enrollmentKeyBody.Relay != "" {
Expand Down
6 changes: 3 additions & 3 deletions logic/enrollmentkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ var EnrollmentErrors = struct {
FailedToTokenize error
FailedToDeTokenize error
}{
InvalidCreate: fmt.Errorf("invalid enrollment key created"),
InvalidCreate: fmt.Errorf("failed to create enrollment key. paramters invalid"),
NoKeyFound: fmt.Errorf("no enrollmentkey found"),
InvalidKey: fmt.Errorf("invalid key provided"),
NoUsesRemaining: fmt.Errorf("no uses remaining"),
Expand Down Expand Up @@ -61,8 +61,8 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
if len(tags) > 0 {
k.Tags = tags
}
if ok := k.Validate(); !ok {
return nil, EnrollmentErrors.InvalidCreate
if err := k.Validate(); err != nil {
return nil, err
}
if relay != uuid.Nil {
relayNode, err := GetNodeByID(relay.String())
Expand Down
2 changes: 1 addition & 1 deletion logic/enrollmentkey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false, uuid.Nil)
assert.Nil(t, newKey)
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.InvalidCreate)
assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey)
})
t.Run("Can_Create_Key_Uses", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
Expand Down
31 changes: 25 additions & 6 deletions models/enrollment_key.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package models

import (
"errors"
"fmt"
"time"

"github.com/google/uuid"
Expand All @@ -13,6 +15,14 @@ const (
Unlimited
)

var (
ErrNilEnrollmentKey = errors.New("enrollment key is nil")
ErrNilNetworksEnrollmentKey = errors.New("enrollment key networks is nil")
ErrNilTagsEnrollmentKey = errors.New("enrollment key tags is nil")
ErrInvalidEnrollmentKey = errors.New("enrollment key is not valid")
ErrInvalidEnrollmentKeyValue = errors.New("enrollment key value is not valid")
)

// KeyType - the type of enrollment key
type KeyType int

Expand Down Expand Up @@ -50,7 +60,7 @@ type APIEnrollmentKey struct {
UsesRemaining int `json:"uses_remaining"`
Networks []string `json:"networks"`
Unlimited bool `json:"unlimited"`
Tags []string `json:"tags"`
Tags []string `json:"tags" validate:"required,dive,min=3,max=32"`
Type KeyType `json:"type"`
Relay string `json:"relay"`
}
Expand Down Expand Up @@ -81,9 +91,18 @@ func (k *EnrollmentKey) IsValid() bool {

// EnrollmentKey.Validate - validate's an EnrollmentKey
// should be used during creation
func (k *EnrollmentKey) Validate() bool {
return k.Networks != nil &&
k.Tags != nil &&
len(k.Value) == EnrollmentKeyLength &&
k.IsValid()
func (k *EnrollmentKey) Validate() error {
if k == nil {
return ErrNilEnrollmentKey
}
if k.Tags == nil {
return ErrNilTagsEnrollmentKey
}
if len(k.Value) != EnrollmentKeyLength {
return fmt.Errorf("%w: length not %d characters", ErrInvalidEnrollmentKeyValue, EnrollmentKeyLength)
}
if !k.IsValid() {
return fmt.Errorf("%w: uses remaining: %d, expiration: %s, unlimited: %t", ErrInvalidEnrollmentKey, k.UsesRemaining, k.Expiration, k.Unlimited)
}
return nil
}
Loading