Skip to content

Commit

Permalink
Using the common deduplication function, and unit testing azure recon…
Browse files Browse the repository at this point in the history
…ciliation
  • Loading branch information
mvbrock committed Nov 18, 2024
1 parent bfae554 commit 021423f
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 42 deletions.
42 changes: 22 additions & 20 deletions lib/srv/discovery/fetchers/aws-sync/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package aws_sync

import "github.com/gravitational/teleport/lib/srv/discovery/common"

// MergeResources merges multiple resources into a single Resources object.
// This is used to merge resources from multiple accounts and regions
// into a single object.
Expand Down Expand Up @@ -58,24 +60,24 @@ func MergeResources(results ...*Resources) *Resources {
}

func deduplicateResources(result *Resources) {
result.Users = deduplicateSlice(result.Users, usersKey)
result.UserInlinePolicies = deduplicateSlice(result.UserInlinePolicies, userInlinePolKey)
result.UserAttachedPolicies = deduplicateSlice(result.UserAttachedPolicies, userAttchPolKey)
result.UserGroups = deduplicateSlice(result.UserGroups, userGroupKey)
result.Groups = deduplicateSlice(result.Groups, groupKey)
result.GroupInlinePolicies = deduplicateSlice(result.GroupInlinePolicies, grpInlinePolKey)
result.GroupAttachedPolicies = deduplicateSlice(result.GroupAttachedPolicies, grpAttchPolKey)
result.Instances = deduplicateSlice(result.Instances, instanceKey)
result.Policies = deduplicateSlice(result.Policies, policyKey)
result.S3Buckets = deduplicateSlice(result.S3Buckets, s3bucketKey)
result.Roles = deduplicateSlice(result.Roles, roleKey)
result.RoleInlinePolicies = deduplicateSlice(result.RoleInlinePolicies, roleInlinePolKey)
result.RoleAttachedPolicies = deduplicateSlice(result.RoleAttachedPolicies, roleAttchPolKey)
result.InstanceProfiles = deduplicateSlice(result.InstanceProfiles, instanceProfKey)
result.AssociatedAccessPolicies = deduplicateSlice(result.AssociatedAccessPolicies, assocAccPolKey)
result.EKSClusters = deduplicateSlice(result.EKSClusters, eksClusterKey)
result.AccessEntries = deduplicateSlice(result.AccessEntries, accessEntryKey)
result.RDSDatabases = deduplicateSlice(result.RDSDatabases, rdsDbKey)
result.SAMLProviders = deduplicateSlice(result.SAMLProviders, samlProvKey)
result.OIDCProviders = deduplicateSlice(result.OIDCProviders, oidcProvKey)
result.Users = common.DeduplicateSlice(result.Users, usersKey)
result.UserInlinePolicies = common.DeduplicateSlice(result.UserInlinePolicies, userInlinePolKey)
result.UserAttachedPolicies = common.DeduplicateSlice(result.UserAttachedPolicies, userAttchPolKey)
result.UserGroups = common.DeduplicateSlice(result.UserGroups, userGroupKey)
result.Groups = common.DeduplicateSlice(result.Groups, groupKey)
result.GroupInlinePolicies = common.DeduplicateSlice(result.GroupInlinePolicies, grpInlinePolKey)
result.GroupAttachedPolicies = common.DeduplicateSlice(result.GroupAttachedPolicies, grpAttchPolKey)
result.Instances = common.DeduplicateSlice(result.Instances, instanceKey)
result.Policies = common.DeduplicateSlice(result.Policies, policyKey)
result.S3Buckets = common.DeduplicateSlice(result.S3Buckets, s3bucketKey)
result.Roles = common.DeduplicateSlice(result.Roles, roleKey)
result.RoleInlinePolicies = common.DeduplicateSlice(result.RoleInlinePolicies, roleInlinePolKey)
result.RoleAttachedPolicies = common.DeduplicateSlice(result.RoleAttachedPolicies, roleAttchPolKey)
result.InstanceProfiles = common.DeduplicateSlice(result.InstanceProfiles, instanceProfKey)
result.AssociatedAccessPolicies = common.DeduplicateSlice(result.AssociatedAccessPolicies, assocAccPolKey)
result.EKSClusters = common.DeduplicateSlice(result.EKSClusters, eksClusterKey)
result.AccessEntries = common.DeduplicateSlice(result.AccessEntries, accessEntryKey)
result.RDSDatabases = common.DeduplicateSlice(result.RDSDatabases, rdsDbKey)
result.SAMLProviders = common.DeduplicateSlice(result.SAMLProviders, samlProvKey)
result.OIDCProviders = common.DeduplicateSlice(result.OIDCProviders, oidcProvKey)
}
7 changes: 4 additions & 3 deletions lib/srv/discovery/fetchers/aws-sync/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package aws_sync

import (
"github.com/gravitational/teleport/lib/srv/discovery/common"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -45,9 +46,9 @@ func TestMergeResources(t *testing.T) {

result := MergeResources(&oldResults, &newResults)
expected := Resources{
Users: deduplicateSlice(append(oldUsers, newUsers...), usersKey),
Roles: deduplicateSlice(append(oldRoles, newRoles...), roleKey),
Instances: deduplicateSlice(append(oldEC2, newEC2...), instanceKey),
Users: common.DeduplicateSlice(append(oldUsers, newUsers...), usersKey),
Roles: common.DeduplicateSlice(append(oldRoles, newRoles...), roleKey),
Instances: common.DeduplicateSlice(append(oldEC2, newEC2...), instanceKey),
}
require.Empty(t, cmp.Diff(&expected, result, protocmp.Transform(), cmpopts.EquateEmpty()))
}
Expand Down
16 changes: 2 additions & 14 deletions lib/srv/discovery/fetchers/aws-sync/reconcile.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package aws_sync

import (
"fmt"
"github.com/gravitational/teleport/lib/srv/discovery/common"

"google.golang.org/protobuf/proto"

Expand Down Expand Up @@ -71,27 +72,14 @@ type reconcilePair struct {
upsert, delete *accessgraphv1alpha.AWSResourceList
}

func deduplicateSlice[T any](s []T, key func(T) string) []T {
out := make([]T, 0, len(s))
seen := make(map[string]struct{})
for _, v := range s {
if _, ok := seen[key(v)]; ok {
continue
}
seen[key(v)] = struct{}{}
out = append(out, v)
}
return out
}

func reconcile[T proto.Message](
oldItems []T,
newItems []T,
keyFn func(T) string,
wrapFn func(T) *accessgraphv1alpha.AWSResource,
) *reconcilePair {
// Remove duplicates from the new items
newItems = deduplicateSlice(newItems, keyFn)
newItems = common.DeduplicateSlice(newItems, keyFn)
upsertRes := newResourceList()
deleteRes := newResourceList()

Expand Down
2 changes: 1 addition & 1 deletion lib/srv/discovery/fetchers/azure-sync/azure-sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (a *Fetcher) Poll(ctx context.Context, feats Features) (*Resources, error)
return nil, err
}
res.VirtualMachines = common.DeduplicateSlice(res.VirtualMachines, azureVmKey)
res.Principals = common.DeduplicateSlice(res.Principals, azureUserKey)
res.Principals = common.DeduplicateSlice(res.Principals, azurePrincipalsKey)
res.RoleDefinitions = common.DeduplicateSlice(res.RoleDefinitions, azureRoleDefKey)
res.RoleAssignments = common.DeduplicateSlice(res.RoleAssignments, azureRoleAssignKey)
return res, trace.Wrap(err)
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/discovery/fetchers/azure-sync/reconcile.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func MergeResources(results ...*Resources) *Resources {
result.Principals = append(result.Principals, r.Principals...)
result.VirtualMachines = append(result.VirtualMachines, r.VirtualMachines...)
}
result.Principals = common.DeduplicateSlice(result.Principals, azureUserKey)
result.Principals = common.DeduplicateSlice(result.Principals, azurePrincipalsKey)
result.VirtualMachines = common.DeduplicateSlice(result.VirtualMachines, azureVmKey)
return result
}
Expand All @@ -52,7 +52,7 @@ func ReconcileResults(old *Resources, new *Resources) (upsert, delete *accessgra
upsert, delete = newResourceList(), newResourceList()
reconciledResources := []*reconcilePair{
reconcile(old.VirtualMachines, new.VirtualMachines, azureVmKey, azureVmWrap),
reconcile(old.Principals, new.Principals, azureUserKey, azureUsersWrap),
reconcile(old.Principals, new.Principals, azurePrincipalsKey, azurePrincipalsWrap),
reconcile(old.RoleDefinitions, new.RoleDefinitions, azureRoleDefKey, azureRoleDefWrap),
reconcile(old.RoleAssignments, new.RoleAssignments, azureRoleAssignKey, azureRoleAssignWrap),
}
Expand Down Expand Up @@ -128,11 +128,11 @@ func azureVmWrap(vm *accessgraphv1alpha.AzureVirtualMachine) *accessgraphv1alpha
return &accessgraphv1alpha.AzureResource{Resource: &accessgraphv1alpha.AzureResource_VirtualMachine{VirtualMachine: vm}}
}

func azureUserKey(user *accessgraphv1alpha.AzurePrincipal) string {
func azurePrincipalsKey(user *accessgraphv1alpha.AzurePrincipal) string {
return fmt.Sprintf("%s:%s", user.SubscriptionId, user.Id)
}

func azureUsersWrap(principal *accessgraphv1alpha.AzurePrincipal) *accessgraphv1alpha.AzureResource {
func azurePrincipalsWrap(principal *accessgraphv1alpha.AzurePrincipal) *accessgraphv1alpha.AzureResource {
return &accessgraphv1alpha.AzureResource{Resource: &accessgraphv1alpha.AzureResource_Principal{Principal: principal}}
}

Expand Down
189 changes: 189 additions & 0 deletions lib/srv/discovery/fetchers/azure-sync/reconcile_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package azure_sync

import (
accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
"github.com/stretchr/testify/require"
"testing"
)

func TestReconcileResults(t *testing.T) {
principals := generatePrincipals()
roleDefs := generateRoleDefs()
roleAssigns := generateRoleAssigns()
vms := generateVms()

tests := []struct {
oldResults *Resources
newResults *Resources
expectedUpserts *accessgraphv1alpha.AzureResourceList
expectedDeletes *accessgraphv1alpha.AzureResourceList
}{
// Overlapping old and new results
{
oldResults: &Resources{
Principals: principals[0:2],
RoleDefinitions: roleDefs[0:2],
RoleAssignments: roleAssigns[0:2],
VirtualMachines: vms[0:2],
},
newResults: &Resources{
Principals: principals[1:3],
RoleDefinitions: roleDefs[1:3],
RoleAssignments: roleAssigns[1:3],
VirtualMachines: vms[1:3],
},
expectedUpserts: generateExpected(principals[2:3], roleDefs[2:3], roleAssigns[2:3], vms[2:3]),
expectedDeletes: generateExpected(principals[0:1], roleDefs[0:1], roleAssigns[0:1], vms[0:1]),
},
// Completely new results
{
oldResults: &Resources{
Principals: nil,
RoleDefinitions: nil,
RoleAssignments: nil,
VirtualMachines: nil,
},
newResults: &Resources{
Principals: principals[1:3],
RoleDefinitions: roleDefs[1:3],
RoleAssignments: roleAssigns[1:3],
VirtualMachines: vms[1:3],
},
expectedUpserts: generateExpected(principals[1:3], roleDefs[1:3], roleAssigns[1:3], vms[1:3]),
expectedDeletes: generateExpected(nil, nil, nil, nil),
},
// No new results
{
oldResults: &Resources{
Principals: principals[1:3],
RoleDefinitions: roleDefs[1:3],
RoleAssignments: roleAssigns[1:3],
VirtualMachines: vms[1:3],
},
newResults: &Resources{
Principals: nil,
RoleDefinitions: nil,
RoleAssignments: nil,
VirtualMachines: nil,
},
expectedUpserts: generateExpected(nil, nil, nil, nil),
expectedDeletes: generateExpected(principals[1:3], roleDefs[1:3], roleAssigns[1:3], vms[1:3]),
},
}

for _, test := range tests {
upserts, deletes := ReconcileResults(test.oldResults, test.newResults)
require.ElementsMatch(t, upserts.Resources, test.expectedUpserts.Resources)
require.ElementsMatch(t, deletes.Resources, test.expectedDeletes.Resources)
}

}

func generateExpected(
principals []*accessgraphv1alpha.AzurePrincipal,
roleDefs []*accessgraphv1alpha.AzureRoleDefinition,
roleAssigns []*accessgraphv1alpha.AzureRoleAssignment,
vms []*accessgraphv1alpha.AzureVirtualMachine,
) *accessgraphv1alpha.AzureResourceList {
resList := &accessgraphv1alpha.AzureResourceList{
Resources: make([]*accessgraphv1alpha.AzureResource, 0),
}
for _, principal := range principals {
resList.Resources = append(resList.Resources, azurePrincipalsWrap(principal))
}
for _, roleDef := range roleDefs {
resList.Resources = append(resList.Resources, azureRoleDefWrap(roleDef))
}
for _, roleAssign := range roleAssigns {
resList.Resources = append(resList.Resources, azureRoleAssignWrap(roleAssign))
}
for _, vm := range vms {
resList.Resources = append(resList.Resources, azureVmWrap(vm))
}
return resList
}

func generatePrincipals() []*accessgraphv1alpha.AzurePrincipal {
return []*accessgraphv1alpha.AzurePrincipal{
{
Id: "/principals/foo",
DisplayName: "userFoo",
},
{
Id: "/principals/bar",
DisplayName: "userBar",
},
{
Id: "/principals/charles",
DisplayName: "userCharles",
},
}
}

func generateRoleDefs() []*accessgraphv1alpha.AzureRoleDefinition {
return []*accessgraphv1alpha.AzureRoleDefinition{
{
Id: "/roledefinitions/foo",
Name: "roleFoo",
},
{
Id: "/roledefinitions/bar",
Name: "roleBar",
},
{
Id: "/roledefinitions/charles",
Name: "roleCharles",
},
}
}

func generateRoleAssigns() []*accessgraphv1alpha.AzureRoleAssignment {
return []*accessgraphv1alpha.AzureRoleAssignment{
{
Id: "/roleassignments/foo",
PrincipalId: "userFoo",
},
{
Id: "/roleassignments/bar",
PrincipalId: "userBar",
},
{
Id: "/roleassignments/charles",
PrincipalId: "userCharles",
},
}
}

func generateVms() []*accessgraphv1alpha.AzureVirtualMachine {
return []*accessgraphv1alpha.AzureVirtualMachine{
{
Id: "/vms/foo",
Name: "userFoo",
},
{
Id: "/vms/bar",
Name: "userBar",
},
{
Id: "/vms/charles",
Name: "userCharles",
},
}
}

0 comments on commit 021423f

Please sign in to comment.