Skip to content

Commit

Permalink
preserve Identity Center imported group member label (#48785)
Browse files Browse the repository at this point in the history
* preserve Identity Center imported group member label

* test

* preserve label on udpate/upsert member methods

* address review comments

* error check one liner, remove unnecessary return
  • Loading branch information
flyinghermit authored Nov 13, 2024
1 parent f99c34a commit d712a37
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
26 changes: 26 additions & 0 deletions lib/services/local/access_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
accesslistv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accesslist/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/api/types/common"
"github.com/gravitational/teleport/api/types/header"
"github.com/gravitational/teleport/entitlements"
"github.com/gravitational/teleport/lib/accesslists"
Expand Down Expand Up @@ -544,6 +545,11 @@ func (a *AccessListService) UpsertAccessListMember(ctx context.Context, member *
if err != nil {
return trace.Wrap(err)
}
existingMember, err := a.memberService.GetResource(ctx, member.GetName())
if err != nil && !trace.IsNotFound(err) {
return trace.Wrap(err)
}
keepAWSIdentityCenterLabels(existingMember, member)

if err := accesslists.ValidateAccessListMember(ctx, memberList, member, &accessListAndMembersGetter{a.service, a.memberService}); err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -575,6 +581,11 @@ func (a *AccessListService) UpdateAccessListMember(ctx context.Context, member *
if err != nil {
return trace.Wrap(err)
}
existingMember, err := a.memberService.GetResource(ctx, member.GetName())
if err != nil && !trace.IsNotFound(err) {
return trace.Wrap(err)
}
keepAWSIdentityCenterLabels(existingMember, member)

if err := accesslists.ValidateAccessListMember(ctx, memberList, member, &accessListAndMembersGetter{a.service, a.memberService}); err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -732,6 +743,7 @@ func (a *AccessListService) UpsertAccessListWithMembers(ctx context.Context, acc
if existingMember.Spec.Reason != "" {
newMember.Spec.Reason = existingMember.Spec.Reason
}
keepAWSIdentityCenterLabels(existingMember, newMember)
newMember.Spec.AddedBy = existingMember.Spec.AddedBy

// Compare members and update if necessary.
Expand Down Expand Up @@ -1029,3 +1041,17 @@ func (a *AccessListService) VerifyAccessListCreateLimit(ctx context.Context, tar
const limitReachedMessage = "cluster has reached its limit for creating access lists, please contact the cluster administrator"
return trace.AccessDenied(limitReachedMessage)
}

// keepAWSIdentityCenterLabels preserves member labels if
// it originated from AWS Identity Center plugin.
// The Web UI does not currently preserve metadata labels so this function should be called
// in every update/upsert member calls.
// Remove this function once https://github.com/gravitational/teleport.e/issues/5415 is addressed.
func keepAWSIdentityCenterLabels(old, new *accesslist.AccessListMember) {
if old == nil || new == nil {
return
}
if old.Origin() == common.OriginAWSIdentityCenter {
new.Metadata.Labels = old.GetAllLabels()
}
}
46 changes: 46 additions & 0 deletions lib/services/local/access_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/api/types/common"
"github.com/gravitational/teleport/api/types/header"
"github.com/gravitational/teleport/api/types/trait"
"github.com/gravitational/teleport/entitlements"
Expand Down Expand Up @@ -662,6 +663,51 @@ func TestAccessListMembersCRUD(t *testing.T) {
require.ErrorIs(t, err, trace.NotFound("access_list %q doesn't exist", accessList2.GetName()))
}

func TestUpsertAndUpdateAccessListWithMembers_PreservesIdentityCenterLablesForExistingMembers(t *testing.T) {
ctx := context.Background()
clock := clockwork.NewFakeClock()
mem, err := memory.New(memory.Config{
Context: ctx,
Clock: clock,
})
require.NoError(t, err)
service := newAccessListService(t, mem, clock, true /* igsEnabled */)

accessList1 := newAccessList(t, "accessList1", clock)
_, err = service.UpsertAccessList(ctx, accessList1)
require.NoError(t, err)
accessList1Member1 := newAccessListMember(t, accessList1.GetName(), "aws-ic-user")
accessList1Member1.SetOrigin(common.OriginAWSIdentityCenter)
accessList1Member1.Metadata.Labels["foo"] = "bar"

_, err = service.UpsertAccessListMember(ctx, accessList1Member1)
require.NoError(t, err)

member, err := service.GetAccessListMember(ctx, accessList1.GetName(), accessList1Member1.GetName())
require.NoError(t, err)
require.Empty(
t,
cmp.Diff(
accessList1Member1,
member,
cmpopts.IgnoreFields(header.Metadata{}, "Revision"),
cmpopts.IgnoreFields(accesslist.AccessListMemberSpec{}, "Joined"),
))

dupeMemberButWithoutOriginLabel := newAccessListMember(t, accessList1.GetName(), "aws-ic-user")
_, updatedMembers, err := service.UpsertAccessListWithMembers(ctx, accessList1, []*accesslist.AccessListMember{dupeMemberButWithoutOriginLabel})
require.NoError(t, err)
require.Equal(t, "bar", updatedMembers[0].GetMetadata().Labels["foo"])

updatedMember, err := service.UpdateAccessListMember(ctx, dupeMemberButWithoutOriginLabel)
require.NoError(t, err)
require.Equal(t, "bar", updatedMember.GetMetadata().Labels["foo"])

upsertedMember, err := service.UpdateAccessListMember(ctx, dupeMemberButWithoutOriginLabel)
require.NoError(t, err)
require.Equal(t, "bar", upsertedMember.GetMetadata().Labels["foo"])
}

func TestAccessListReviewCRUD(t *testing.T) {
ctx := context.Background()
clock := clockwork.NewFakeClock()
Expand Down

0 comments on commit d712a37

Please sign in to comment.