Skip to content

Commit

Permalink
block default key deletion,delete default key on network deletion
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishek9686 committed Oct 30, 2024
1 parent e1cc0a2 commit 56d5c85
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 19 deletions.
3 changes: 2 additions & 1 deletion controllers/enrollmentkeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) {
func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) {
params := mux.Vars(r)
keyID := params["keyID"]
err := logic.DeleteEnrollmentKey(keyID)
err := logic.DeleteEnrollmentKey(keyID, false)
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to remove enrollment key: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
Expand Down Expand Up @@ -159,6 +159,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
enrollmentKeyBody.Groups,
enrollmentKeyBody.Unlimited,
relayId,
false,
)
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
Expand Down
1 change: 1 addition & 0 deletions controllers/tags.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ func deleteTag(w http.ResponseWriter, r *http.Request) {

go func() {
logic.RemoveDeviceTagFromAclPolicies(tag.ID, tag.Network)
logic.RemoveTagFromEnrollmentKeys(tag.ID)
mq.PublishPeerUpdate(false)
}()
logic.ReturnSuccessResponse(w, r, "deleted tag "+tagID)
Expand Down
29 changes: 26 additions & 3 deletions logic/enrollmentkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ var (
)

// CreateEnrollmentKey - creates a new enrollment key in db
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) {
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, defaultKey bool) (*models.EnrollmentKey, error) {
newKeyID, err := getUniqueEnrollmentID()
if err != nil {
return nil, err
Expand Down Expand Up @@ -152,11 +152,14 @@ func deleteEnrollmentkeyFromCache(key string) {
}

// DeleteEnrollmentKey - delete's a given enrollment key by value
func DeleteEnrollmentKey(value string) error {
_, err := GetEnrollmentKey(value)
func DeleteEnrollmentKey(value string, force bool) error {
key, err := GetEnrollmentKey(value)
if err != nil {
return err
}
if key.Default && !force {
return errors.New("cannot delete default network key")
}
err = database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
if err == nil {
if servercfg.CacheEnabled() {
Expand Down Expand Up @@ -311,3 +314,23 @@ func getEnrollmentKeysMap() (map[string]models.EnrollmentKey, error) {
}
return currentKeys, nil
}

func RemoveTagFromEnrollmentKeys(deletedTagID models.TagID) {
keys, _ := GetAllEnrollmentKeys()
for _, key := range keys {
newTags := []models.TagID{}
update := false
for _, tagID := range key.Groups {
if tagID == deletedTagID {
update = true
continue
}
newTags = append(newTags, tagID)
}
if update {
key.Groups = newTags
upsertEnrollmentKey(&key)
}

}
}
30 changes: 15 additions & 15 deletions logic/enrollmentkey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,35 @@ func TestCreateEnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
t.Run("Can_Not_Create_Key", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil)
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
assert.Nil(t, newKey)
assert.NotNil(t, err)
assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey)
})
t.Run("Can_Create_Key_Uses", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil)
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
assert.Nil(t, err)
assert.Equal(t, 1, newKey.UsesRemaining)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_Time", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil)
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil, false)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil)
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil)
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
assert.True(t, len(newKey.Networks) == 2)
})
t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil)
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil, false)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
assert.True(t, len(newKey.Tags) == 2)
Expand All @@ -62,18 +62,18 @@ func TestCreateEnrollmentKey(t *testing.T) {
func TestDelete_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil)
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
t.Run("Can_Delete_Key", func(t *testing.T) {
assert.True(t, newKey.IsValid())
err := DeleteEnrollmentKey(newKey.Value)
err := DeleteEnrollmentKey(newKey.Value, false)
assert.Nil(t, err)
oldKey, err := GetEnrollmentKey(newKey.Value)
assert.Equal(t, oldKey, models.EnrollmentKey{})
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
})
t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) {
err := DeleteEnrollmentKey("notakey")
err := DeleteEnrollmentKey("notakey", false)
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
})
Expand All @@ -83,7 +83,7 @@ func TestDelete_EnrollmentKey(t *testing.T) {
func TestDecrement_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil)
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
t.Run("Check_initial_uses", func(t *testing.T) {
assert.True(t, newKey.IsValid())
assert.Equal(t, newKey.UsesRemaining, 1)
Expand All @@ -107,9 +107,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) {
func TestUsability_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil)
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil)
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil)
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil, false)
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false)
t.Run("Check if valid use key can be used", func(t *testing.T) {
assert.Equal(t, key1.UsesRemaining, 1)
ok := TryToUseEnrollmentKey(key1)
Expand Down Expand Up @@ -145,7 +145,7 @@ func removeAllEnrollments() {
func TestTokenize_EnrollmentKeys(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil)
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
const serverAddr = "api.myserver.com"
Expand Down Expand Up @@ -178,7 +178,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) {
func TestDeTokenize_EnrollmentKeys(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil)
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
const serverAddr = "api.myserver.com"

Expand Down
12 changes: 12 additions & 0 deletions logic/networks.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,17 @@ func DeleteNetwork(network string) error {
if err != nil {
logger.Log(1, "failed to remove the node acls during network delete for network,", network)
}
// Delete default network enrollment key
keys, _ := GetAllEnrollmentKeys()
for _, key := range keys {
if key.Tags[0] == network {
if key.Default {
DeleteEnrollmentKey(key.Value, true)
break
}

}
}
nodeCount, err := GetNetworkNonServerNodeCount(network)
if nodeCount == 0 || database.IsEmptyRecord(err) {
// delete server nodes first then db records
Expand Down Expand Up @@ -243,6 +254,7 @@ func CreateNetwork(network models.Network) (models.Network, error) {
[]models.TagID{},
true,
uuid.Nil,
true,
)

return network, nil
Expand Down
1 change: 1 addition & 0 deletions migrate/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ func updateEnrollmentKeys() {
[]models.TagID{},
true,
uuid.Nil,
true,
)

}
Expand Down
1 change: 1 addition & 0 deletions models/enrollment_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type EnrollmentKey struct {
Type KeyType `json:"type"`
Relay uuid.UUID `json:"relay"`
Groups []TagID `json:"groups"`
Default bool `json:"default"`
}

// APIEnrollmentKey - used to create enrollment keys via API
Expand Down

0 comments on commit 56d5c85

Please sign in to comment.