diff --git a/lib/auth/access_request_test.go b/lib/auth/access_request_test.go index fde1c7d4df323..d6c4cb287e963 100644 --- a/lib/auth/access_request_test.go +++ b/lib/auth/access_request_test.go @@ -1637,17 +1637,17 @@ func TestUpdateAccessRequestWithAdditionalReviewers(t *testing.T) { name: "with ownership through nested list", req: mustRequest("rev1"), accessLists: []*accesslist.AccessList{ + mustAccessList("nested1", "owner1"), mustAccessListWithMembershipKind( - "root", + "nested", testAccessListOwner{"owner1", accesslist.MembershipKindUser}, - testAccessListOwner{"nested", accesslist.MembershipKindList}, + testAccessListOwner{"nested1", accesslist.MembershipKindList}, ), mustAccessListWithMembershipKind( - "nested", + "root", testAccessListOwner{"owner1", accesslist.MembershipKindUser}, - testAccessListOwner{"nested1", accesslist.MembershipKindList}, + testAccessListOwner{"nested", accesslist.MembershipKindList}, ), - mustAccessList("nested1", "owner1"), }, accessListMembers: []struct { Header header.Metadata diff --git a/lib/auth/userloginstate/generator.go b/lib/auth/userloginstate/generator.go index 391b3b88b8157..b40acb14f5da5 100644 --- a/lib/auth/userloginstate/generator.go +++ b/lib/auth/userloginstate/generator.go @@ -177,22 +177,16 @@ func (g *Generator) addAccessListsToState(ctx context.Context, user types.User, return nil, nil, trace.Wrap(err) } - accessListHierarchy, err := accesslists.NewHierarchy(ctx, accesslists.HierarchyConfig{ - AccessLists: accessLists, - Locks: g.accessLists, - Members: g.accessLists, - Clock: g.clock, - }) - if err != nil { - return nil, nil, trace.Wrap(err) - } - var allInheritedRoles []string allInheritedTraits := make(map[string][]string) for _, accessList := range accessLists { // Grants are inherited if the user is a member of the access list, explicitly or via inheritance. - if membershipKind, err := accessListHierarchy.IsAccessListMember(ctx, user, accessList.GetName()); err == nil && membershipKind != accesslists.MembershipOrOwnershipTypeNone { + membershipKind, err := accesslists.IsAccessListMember(ctx, user, accessList, g.accessLists, g.accessLists, g.clock) + if err != nil { + g.log.WithError(err).Warn("checking access list membership") + } + if membershipKind != accesslists.MembershipOrOwnershipTypeNone { g.grantRolesAndTraits(accessList.Spec.Grants, state) if membershipKind == accesslists.MembershipOrOwnershipTypeInherited { allInheritedRoles = append(allInheritedRoles, accessList.Spec.Grants.Roles...) @@ -202,7 +196,11 @@ func (g *Generator) addAccessListsToState(ctx context.Context, user types.User, } } // OwnerGrants are inherited if the user is an owner of the access list, explicitly or via inheritance. - if ownershipType, err := accessListHierarchy.IsAccessListOwner(ctx, user, accessList.GetName()); err == nil && ownershipType != accesslists.MembershipOrOwnershipTypeNone { + ownershipType, err := accesslists.IsAccessListOwner(ctx, user, accessList, g.accessLists, g.accessLists, g.clock) + if err != nil { + g.log.WithError(err).Warn("checking access list ownership") + } + if ownershipType != accesslists.MembershipOrOwnershipTypeNone { g.grantRolesAndTraits(accessList.Spec.OwnerGrants, state) if ownershipType == accesslists.MembershipOrOwnershipTypeInherited { allInheritedRoles = append(allInheritedRoles, accessList.Spec.OwnerGrants.Roles...) diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 4b3da7f8a9405..a6495d9e1d8b4 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -3227,7 +3227,7 @@ func (accessListExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets } func (accessListExecutor) upsert(ctx context.Context, cache *Cache, resource *accesslist.AccessList) error { - _, err := cache.accessListCache.UpsertAccessList(ctx, resource) + _, err := cache.accessListCache.UnconditionalUpsertAccessList(ctx, resource) return trace.Wrap(err) } @@ -3236,7 +3236,7 @@ func (accessListExecutor) deleteAll(ctx context.Context, cache *Cache) error { } func (accessListExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.accessListCache.DeleteAccessList(ctx, resource.GetName()) + return cache.accessListCache.UnconditionalDeleteAccessList(ctx, resource.GetName()) } func (accessListExecutor) isSingleton() bool { return false } @@ -3277,7 +3277,7 @@ func (accessListMemberExecutor) getAll(ctx context.Context, cache *Cache, loadSe } func (accessListMemberExecutor) upsert(ctx context.Context, cache *Cache, resource *accesslist.AccessListMember) error { - _, err := cache.accessListCache.UpsertAccessListMember(ctx, resource) + _, err := cache.accessListCache.UnconditionalUpsertAccessListMember(ctx, resource) return trace.Wrap(err) } @@ -3286,7 +3286,7 @@ func (accessListMemberExecutor) deleteAll(ctx context.Context, cache *Cache) err } func (accessListMemberExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.accessListCache.DeleteAccessListMember(ctx, + return cache.accessListCache.UnconditionalDeleteAccessListMember(ctx, resource.GetMetadata().Description, // Cache passes access ID via description field. resource.GetName()) } diff --git a/lib/services/local/access_list.go b/lib/services/local/access_list.go index d7f2c8c78a738..c23e7ef6c5649 100644 --- a/lib/services/local/access_list.go +++ b/lib/services/local/access_list.go @@ -20,6 +20,7 @@ package local import ( "context" + "slices" "time" "github.com/google/go-cmp/cmp" @@ -60,6 +61,11 @@ const ( // creation or update of AccessLists in order to enforce the license limit // on the number AccessLists in a cluster. createAccessListLimitLockName = "createAccessListLimitLock" + // accessListResourceLockName is the lock used to prevent simultaneous + // writing to any AccessList resources (AccessLists, AccessListMembers). + // it shares the same string as createAccessListLimitLockName to ensure + // backwards compatibility. + accessListResourceLockName = createAccessListLimitLockName ) // AccessListService manages Access List resources in the Backend. The AccessListService's @@ -75,6 +81,18 @@ type AccessListService struct { reviewService *generic.Service[*accesslist.Review] } +type accessListAndMembersGetter struct { + service *generic.Service[*accesslist.AccessList] + memberService *generic.Service[*accesslist.AccessListMember] +} + +func (s *accessListAndMembersGetter) ListAccessListMembers(ctx context.Context, accessListName string, pageSize int, pageToken string) ([]*accesslist.AccessListMember, string, error) { + return s.memberService.WithPrefix(accessListName).ListResources(ctx, pageSize, pageToken) +} +func (s *accessListAndMembersGetter) GetAccessList(ctx context.Context, name string) (*accesslist.AccessList, error) { + return s.service.GetResource(ctx, name) +} + // compile-time assertion that the AccessListService implements the AccessLists // interface var _ services.AccessLists = (*AccessListService)(nil) @@ -168,73 +186,195 @@ func (a *AccessListService) GetAccessListsToReview(ctx context.Context) ([]*acce // UpsertAccessList creates or updates an access list resource. func (a *AccessListService) UpsertAccessList(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { - op := a.service.UpsertResource - return a.runOpWithLock(ctx, accessList, op) + return a.runOpWithLock(ctx, accessList, opTypeUpsert) } // UpdateAccessList updates an access list resource. func (a *AccessListService) UpdateAccessList(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { - op := a.service.ConditionalUpdateResource - return a.runOpWithLock(ctx, accessList, op) + return a.runOpWithLock(ctx, accessList, opTypeUpdate) } -type opFunc func(context.Context, *accesslist.AccessList) (*accesslist.AccessList, error) +type opType int -func (a *AccessListService) runOpWithLock(ctx context.Context, accessList *accesslist.AccessList, op opFunc) (*accesslist.AccessList, error) { +const ( + opTypeUpsert opType = iota + opTypeUpdate +) + +func (a *AccessListService) runOpWithLock(ctx context.Context, accessList *accesslist.AccessList, op opType) (*accesslist.AccessList, error) { if err := accessList.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } var upserted *accesslist.AccessList + var existingList *accesslist.AccessList - updateAccessList := func() error { - return a.service.RunWhileLocked(ctx, lockName(accessList.GetName()), accessListLockTTL, - func(ctx context.Context, _ backend.Backend) error { - var err error - upserted, err = op(ctx, accessList) + opFn := a.service.UpsertResource + if op == opTypeUpdate { + opFn = a.service.ConditionalUpdateResource + } + + validateAccessList := func() error { + var err error + + expectedMemberOf := []string{} + expectedOwnerOf := []string{} + + if op == opTypeUpdate { + existingList, err = a.service.GetResource(ctx, accessList.GetName()) + if err != nil { return trace.Wrap(err) - }) + } + expectedMemberOf = existingList.Spec.MemberOf + expectedOwnerOf = existingList.Spec.OwnerOf + } + + // If the memberOf or ownerOf fields have been updated, return an error – these fields are managed by the service. + if !cmp.Equal(expectedMemberOf, accessList.Spec.MemberOf) { + return trace.BadParameter("memberOf field may not be set explicitly") + } + if !cmp.Equal(expectedOwnerOf, accessList.Spec.OwnerOf) { + return trace.BadParameter("ownerOf field may not be set explicitly") + } + + listMembers, err := a.memberService.WithPrefix(accessList.GetName()).GetResources(ctx) + if err != nil { + return trace.Wrap(err) + } + + return accesslists.ValidateAccessListWithMembers(ctx, accessList, listMembers, &accessListAndMembersGetter{a.service, a.memberService}) + } + + updateAccessList := func() error { + var err error + upserted, err = opFn(ctx, accessList) + return trace.Wrap(err) + } + + reconcileOwners := func() error { + // Create map to store owners for efficient lookup + originalOwnersMap := make(map[string]struct{}) + if existingList != nil { + for _, owner := range existingList.Spec.Owners { + if owner.MembershipKind == accesslist.MembershipKindList { + originalOwnersMap[owner.Name] = struct{}{} + } + } + } + + currentOwnersMap := make(map[string]struct{}) + for _, owner := range accessList.Spec.Owners { + if owner.MembershipKind == accesslist.MembershipKindList { + currentOwnersMap[owner.Name] = struct{}{} + } + } + + // update references for new owners + for ownerName := range currentOwnersMap { + if _, exists := originalOwnersMap[ownerName]; !exists { + if err := a.updateAccessListOwnerOf(ctx, accessList.GetName(), ownerName, true); err != nil { + return trace.Wrap(err) + } + } + } + + // update references for old owners + for ownerName := range originalOwnersMap { + if _, exists := currentOwnersMap[ownerName]; !exists { + if err := a.updateAccessListOwnerOf(ctx, accessList.GetName(), ownerName, false); err != nil { + return trace.Wrap(err) + } + } + } + + return nil } + var actions []func() error + // If IGS is not enabled for this cluster we need to wrap the whole // operation inside *another* lock so that we can accurately count the // access lists in the cluster in order to prevent un-authorized use of // the AccessList feature - - action := updateAccessList if !modules.GetModules().Features().GetEntitlement(entitlements.Identity).Enabled { - action = func() error { - err := a.service.RunWhileLocked(ctx, []string{createAccessListLimitLockName}, accessListLockTTL, - func(ctx context.Context, _ backend.Backend) error { - if err := a.VerifyAccessListCreateLimit(ctx, accessList.GetName()); err != nil { - return trace.Wrap(err) - } - return trace.Wrap(updateAccessList()) - }) - return trace.Wrap(err) - } + actions = append(actions, func() error { return a.VerifyAccessListCreateLimit(ctx, accessList.GetName()) }) } - if err := action(); err != nil { - return nil, trace.Wrap(err) - } + actions = append(actions, validateAccessList, updateAccessList, reconcileOwners) + + err := a.service.RunWhileLocked(ctx, []string{accessListResourceLockName}, accessListLockTTL, + func(ctx context.Context, _ backend.Backend) error { + return a.service.RunWhileLocked(ctx, lockName(accessList.GetName()), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + for _, action := range actions { + if err := action(); err != nil { + return trace.Wrap(err) + } + } + return nil + }) + }) - return upserted, nil + return upserted, trace.Wrap(err) } // DeleteAccessList removes the specified access list resource. func (a *AccessListService) DeleteAccessList(ctx context.Context, name string) error { - err := a.service.RunWhileLocked(ctx, lockName(name), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + action := func(ctx context.Context, _ backend.Backend) error { + // Get list resource. + accessList, err := a.service.GetResource(ctx, name) + if err != nil { + return trace.Wrap(err) + } + + // Check if the access list is a member or owner of any other access lists. + if len(accessList.Spec.MemberOf) > 0 { + return trace.AccessDenied("Cannot delete '%s', as it is a member of one or more other Access Lists", accessList.Spec.Title) + } + if len(accessList.Spec.OwnerOf) > 0 { + return trace.AccessDenied("Cannot delete '%s', as it is an owner of one or more other Access Lists", accessList.Spec.Title) + } + // Delete all associated members. - err := a.memberService.WithPrefix(name).DeleteAllResources(ctx) + members, err := a.memberService.WithPrefix(name).GetResources(ctx) if err != nil { return trace.Wrap(err) } - return trace.Wrap(a.service.DeleteResource(ctx, name)) - }) + if err := a.memberService.WithPrefix(name).DeleteAllResources(ctx); err != nil { + return trace.Wrap(err) + } - return trace.Wrap(err) + // Update memberOf refs. + for _, member := range members { + if member.Spec.MembershipKind != accesslist.MembershipKindList { + continue + } + if err := a.updateAccessListMemberOf(ctx, name, member.GetName(), false); err != nil { + return trace.Wrap(err) + } + } + + // Delete list itself. + if err := a.service.DeleteResource(ctx, name); err != nil { + return trace.Wrap(err) + } + + // Update ownerOf refs. + for _, owner := range accessList.Spec.Owners { + if owner.MembershipKind != accesslist.MembershipKindList { + continue + } + if err := a.updateAccessListOwnerOf(ctx, name, owner.Name, false); err != nil { + return trace.Wrap(err) + } + } + + return nil + } + + return trace.Wrap(a.service.RunWhileLocked(ctx, []string{accessListResourceLockName}, accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + return a.service.RunWhileLocked(ctx, lockName(name), accessListLockTTL, action) + })) } // DeleteAllAccessLists removes all access lists. @@ -321,72 +461,156 @@ func (a *AccessListService) GetAccessListMember(ctx context.Context, accessList return member, trace.Wrap(err) } +// updateAccessListRefField is a helper that updates the specified field (memberOf or ownerOf) of an Access List, +// adding or removing the specified accessListName to the field of targetName. +func (a *AccessListService) updateAccessListRefField( + ctx context.Context, + accessListName string, + targetName string, + new bool, + fieldSelector func(spec *accesslist.Spec) *[]string, +) error { + targetAccessList, err := a.service.GetResource(ctx, targetName) + if err != nil { + if trace.IsNotFound(err) { + // If list is not found, it's possible that it was deleted. Regardless, there's nothing to update. + return nil + } + return trace.Wrap(err) + } + + field := fieldSelector(&targetAccessList.Spec) + + // If the field already contains the Access List, and we're adding, + // or doesn't contain it, and we're removing, there's nothing to do. + if slices.Contains(*field, accessListName) == new { + return nil + } + + if new { + *field = append(*field, accessListName) + } else { + *field = slices.DeleteFunc(*field, func(e string) bool { + return e == accessListName + }) + } + + if _, err := a.service.UpdateResource(ctx, targetAccessList); err != nil { + return trace.Wrap(err) + } + + return nil +} + +// updateAccessListMemberOf updates the memberOf field for the specified memberName and accessListName. +// Should only be called after the relevant member has been successfully upserted or deleted. +func (a *AccessListService) updateAccessListMemberOf(ctx context.Context, accessListName, memberName string, new bool) error { + return a.updateAccessListRefField(ctx, accessListName, memberName, new, func(spec *accesslist.Spec) *[]string { + return &spec.MemberOf + }) +} + +// updateAccessListOwnerOf updates the ownerOf field for the specified ownerName and accessListName. +// Should only be called after the relevant owner has been successfully upserted or deleted. +func (a *AccessListService) updateAccessListOwnerOf(ctx context.Context, accessListName, ownerName string, new bool) error { + return a.updateAccessListRefField(ctx, accessListName, ownerName, new, func(spec *accesslist.Spec) *[]string { + return &spec.OwnerOf + }) +} + // GetAccessListOwners returns a list of all owners in an Access List, including those inherited from nested Access Lists. // // Returned Owners are not validated for ownership requirements – use `IsAccessListOwner` for validation. func (a *AccessListService) GetAccessListOwners(ctx context.Context, accessListName string) ([]*accesslist.Owner, error) { - accessLists, err := a.GetAccessLists(ctx) - if err != nil { - return nil, trace.Wrap(err, "getting access lists") - } - hierarchy, err := accesslists.NewHierarchy(ctx, accesslists.HierarchyConfig{ - AccessLists: accessLists, - Members: a, - Locks: nil, - Clock: a.clock, + var owners []*accesslist.Owner + err := a.service.RunWhileLocked(ctx, lockName(accessListName), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + accessList, err := a.service.GetResource(ctx, accessListName) + if err != nil { + return trace.Wrap(err) + } + owners, err = accesslists.GetOwnersFor(ctx, accessList, &accessListAndMembersGetter{a.service, a.memberService}) + return trace.Wrap(err) }) - if err != nil { - return nil, trace.Wrap(err, "creating hierarchy") - } - owners, err := hierarchy.GetOwners(accessListName) - if err != nil { - return nil, trace.Wrap(err, "getting owners") - } - return owners, nil + return owners, trace.Wrap(err) } // UpsertAccessListMember creates or updates an access list member resource. func (a *AccessListService) UpsertAccessListMember(ctx context.Context, member *accesslist.AccessListMember) (*accesslist.AccessListMember, error) { var upserted *accesslist.AccessListMember - err := a.service.RunWhileLocked(ctx, lockName(member.Spec.AccessList), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { - _, err := a.service.GetResource(ctx, member.Spec.AccessList) + action := func(ctx context.Context, _ backend.Backend) error { + memberList, err := a.service.GetResource(ctx, member.Spec.AccessList) if err != nil { return trace.Wrap(err) } + + if err := accesslists.ValidateAccessListMember(ctx, memberList, member, &accessListAndMembersGetter{a.service, a.memberService}); err != nil { + return trace.Wrap(err) + } + upserted, err = a.memberService.WithPrefix(member.Spec.AccessList).UpsertResource(ctx, member) + + if err == nil && member.Spec.MembershipKind == accesslist.MembershipKindList { + if err := a.updateAccessListMemberOf(ctx, member.Spec.AccessList, member.Spec.Name, true); err != nil { + return trace.Wrap(err) + } + } + return trace.Wrap(err) - }) - if err != nil { - return nil, trace.Wrap(err) } - return upserted, nil + + err := a.service.RunWhileLocked(ctx, []string{accessListResourceLockName}, accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + return a.service.RunWhileLocked(ctx, lockName(member.Spec.AccessList), accessListLockTTL, action) + }) + return upserted, trace.Wrap(err) } // UpdateAccessListMember conditionally updates an access list member resource. func (a *AccessListService) UpdateAccessListMember(ctx context.Context, member *accesslist.AccessListMember) (*accesslist.AccessListMember, error) { var updated *accesslist.AccessListMember - err := a.service.RunWhileLocked(ctx, lockName(member.Spec.AccessList), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { - _, err := a.service.GetResource(ctx, member.Spec.AccessList) - if err != nil { + err := a.service.RunWhileLocked(ctx, []string{accessListResourceLockName}, accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + return a.service.RunWhileLocked(ctx, lockName(member.Spec.AccessList), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + memberList, err := a.service.GetResource(ctx, member.Spec.AccessList) + if err != nil { + return trace.Wrap(err) + } + + if err := accesslists.ValidateAccessListMember(ctx, memberList, member, &accessListAndMembersGetter{a.service, a.memberService}); err != nil { + return trace.Wrap(err) + } + + updated, err = a.memberService.WithPrefix(member.Spec.AccessList).ConditionalUpdateResource(ctx, member) return trace.Wrap(err) - } - updated, err = a.memberService.WithPrefix(member.Spec.AccessList).ConditionalUpdateResource(ctx, member) - return trace.Wrap(err) + }) }) - if err != nil { - return nil, trace.Wrap(err) - } - return updated, nil + return updated, trace.Wrap(err) } // DeleteAccessListMember hard deletes the specified access list member resource. func (a *AccessListService) DeleteAccessListMember(ctx context.Context, accessList string, memberName string) error { - err := a.service.RunWhileLocked(ctx, lockName(accessList), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { - _, err := a.service.GetResource(ctx, accessList) - if err != nil { - return trace.Wrap(err) - } - return trace.Wrap(a.memberService.WithPrefix(accessList).DeleteResource(ctx, memberName)) + err := a.service.RunWhileLocked(ctx, []string{accessListResourceLockName}, accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + return a.service.RunWhileLocked(ctx, lockName(accessList), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + _, err := a.service.GetResource(ctx, accessList) + if err != nil { + return trace.Wrap(err) + } + + member, err := a.memberService.WithPrefix(accessList).GetResource(ctx, memberName) + if err != nil { + return trace.Wrap(err) + } + + if err := a.memberService.WithPrefix(accessList).DeleteResource(ctx, memberName); err != nil { + return trace.Wrap(err) + } + + if member.Spec.MembershipKind == accesslist.MembershipKindList { + if err := a.updateAccessListMemberOf(ctx, accessList, memberName, false); err != nil { + return trace.Wrap(err) + } + } + + return nil + }) }) return trace.Wrap(err) } @@ -396,12 +620,33 @@ func (a *AccessListService) DeleteAccessListMember(ctx context.Context, accessLi // allowed on a list with implicit membership, as it provides a mechanism for // cleaning out the user list if a list is converted from explicit to implicit. func (a *AccessListService) DeleteAllAccessListMembersForAccessList(ctx context.Context, accessList string) error { - err := a.service.RunWhileLocked(ctx, lockName(accessList), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { - _, err := a.service.GetResource(ctx, accessList) - if err != nil { - return trace.Wrap(err) - } - return trace.Wrap(a.memberService.WithPrefix(accessList).DeleteAllResources(ctx)) + err := a.service.RunWhileLocked(ctx, []string{accessListResourceLockName}, accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + return a.service.RunWhileLocked(ctx, lockName(accessList), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + _, err := a.service.GetResource(ctx, accessList) + if err != nil { + return trace.Wrap(err) + } + + allMembers, err := a.memberService.WithPrefix(accessList).GetResources(ctx) + if err != nil { + return trace.Wrap(err) + } + + if err := a.memberService.WithPrefix(accessList).DeleteAllResources(ctx); err != nil { + return trace.Wrap(err) + } + + for _, member := range allMembers { + if member.Spec.MembershipKind != accesslist.MembershipKindList { + continue + } + if err := a.updateAccessListMemberOf(ctx, accessList, member.Spec.Name, false); err != nil { + return trace.Wrap(err) + } + } + + return nil + }) }) return trace.Wrap(err) } @@ -424,98 +669,143 @@ func (a *AccessListService) UpsertAccessListWithMembers(ctx context.Context, acc } } + validateAccessList := func() error { + // If the memberOf or ownerOf fields are already set, return an error – these fields are managed by the service. + if len(accessList.Spec.MemberOf) > 0 { + return trace.BadParameter("memberOf field may not be set explicitly") + } + if len(accessList.Spec.OwnerOf) > 0 { + return trace.BadParameter("ownerOf field may not be set explicitly") + } + + if err := accesslists.ValidateAccessListWithMembers(ctx, accessList, membersIn, &accessListAndMembersGetter{a.service, a.memberService}); err != nil { + return trace.Wrap(err) + } + + return nil + } + reconcileMembers := func() error { - return a.service.RunWhileLocked(ctx, lockName(accessList.GetName()), 2*accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { - // Convert the members slice to a map for easier lookup. - membersMap := utils.FromSlice(membersIn, types.GetName) - - var ( - members []*accesslist.AccessListMember - membersToken string - ) - - for { - // List all members for the access list. - var err error - members, membersToken, err = a.memberService.WithPrefix(accessList.GetName()).ListResources(ctx, 0 /* default size */, membersToken) - if err != nil { - return trace.Wrap(err) - } + // Convert the members slice to a map for easier lookup. + membersMap := utils.FromSlice(membersIn, types.GetName) - for _, existingMember := range members { - // If the member is not in the new members map (request), delete it. - if newMember, ok := membersMap[existingMember.GetName()]; !ok { - err = a.memberService.WithPrefix(accessList.GetName()).DeleteResource(ctx, existingMember.GetName()) - if err != nil { - return trace.Wrap(err) - } - } else { - // Preserve the membership metadata for any existing members - // to suppress member records flipping back and forth due - // due SCIM pushes or Sync Service updates. - if !existingMember.Spec.Expires.IsZero() { - newMember.Spec.Expires = existingMember.Spec.Expires - } - if existingMember.Spec.Reason != "" { - newMember.Spec.Reason = existingMember.Spec.Reason - } - newMember.Spec.AddedBy = existingMember.Spec.AddedBy + var ( + members []*accesslist.AccessListMember + membersToken string + ) - // Compare members and update if necessary. - if !cmp.Equal(newMember, existingMember) { - // Update the member. - upserted, err := a.memberService.WithPrefix(accessList.GetName()).UpsertResource(ctx, newMember) - if err != nil { - return trace.Wrap(err) - } + for { + // List all members for the access list. + var err error + members, membersToken, err = a.memberService.WithPrefix(accessList.GetName()).ListResources(ctx, 0 /* default size */, membersToken) + if err != nil { + return trace.Wrap(err) + } - existingMember.SetRevision(upserted.GetRevision()) + for _, existingMember := range members { + // If the member is not in the new members map (request), delete it. + if newMember, ok := membersMap[existingMember.GetName()]; !ok { + err = a.memberService.WithPrefix(accessList.GetName()).DeleteResource(ctx, existingMember.GetName()) + if err != nil { + return trace.Wrap(err) + } + // Update memberOf field if nested list. + if existingMember.Spec.MembershipKind == accesslist.MembershipKindList { + if err := a.updateAccessListMemberOf(ctx, accessList.GetName(), existingMember.GetName(), false); err != nil { + return trace.Wrap(err) } } + } else { + // Preserve the membership metadata for any existing members + // to suppress member records flipping back and forth due + // due SCIM pushes or Sync Service updates. + if !existingMember.Spec.Expires.IsZero() { + newMember.Spec.Expires = existingMember.Spec.Expires + } + if existingMember.Spec.Reason != "" { + newMember.Spec.Reason = existingMember.Spec.Reason + } + newMember.Spec.AddedBy = existingMember.Spec.AddedBy + + // Compare members and update if necessary. + if !cmp.Equal(newMember, existingMember) { + // Update the member. + upserted, err := a.memberService.WithPrefix(accessList.GetName()).UpsertResource(ctx, newMember) + if err != nil { + return trace.Wrap(err) + } - // Remove the member from the map. - delete(membersMap, existingMember.GetName()) + existingMember.SetRevision(upserted.GetRevision()) + } } - if membersToken == "" { - break + // Remove the member from the map. + delete(membersMap, existingMember.GetName()) + } + + if membersToken == "" { + break + } + } + + // Add any remaining members to the access list. + for _, member := range membersMap { + upserted, err := a.memberService.WithPrefix(accessList.GetName()).UpsertResource(ctx, member) + if err != nil { + return trace.Wrap(err) + } + // Update memberOf field if nested list. + if member.Spec.MembershipKind == accesslist.MembershipKindList { + if err := a.updateAccessListMemberOf(ctx, accessList.GetName(), member.Spec.Name, true); err != nil { + return trace.Wrap(err) } } + member.SetRevision(upserted.GetRevision()) + } - // Add any remaining members to the access list. - for _, member := range membersMap { - upserted, err := a.memberService.WithPrefix(accessList.GetName()).UpsertResource(ctx, member) - if err != nil { + return nil + } + + reconcileOwners := func() error { + // update references for new owners + for _, owner := range accessList.Spec.Owners { + if owner.MembershipKind == accesslist.MembershipKindList { + if err := a.updateAccessListOwnerOf(ctx, accessList.GetName(), owner.Name, true); err != nil { return trace.Wrap(err) } - member.SetRevision(upserted.GetRevision()) } + } + return nil + } - var err error - accessList, err = a.service.UpsertResource(ctx, accessList) - return trace.Wrap(err) - }) + updateAccessList := func() error { + var err error + accessList, err = a.service.UpsertResource(ctx, accessList) + return trace.Wrap(err) } + var actions []func() error + // If IGS is not enabled for this cluster we need to wrap the whole update and // member reconciliation in *another* lock so that we can accurately count the // access lists in the cluster in order to prevent un-authorized use of the // AccessList feature - - action := reconcileMembers if !modules.GetModules().Features().GetEntitlement(entitlements.Identity).Enabled { - action = func() error { - return a.service.RunWhileLocked(ctx, []string{createAccessListLimitLockName}, 2*accessListLockTTL, - func(ctx context.Context, _ backend.Backend) error { - if err := a.VerifyAccessListCreateLimit(ctx, accessList.GetName()); err != nil { - return trace.Wrap(err) - } - return trace.Wrap(reconcileMembers()) - }) - } + actions = append(actions, func() error { return a.VerifyAccessListCreateLimit(ctx, accessList.GetName()) }) } - if err := action(); err != nil { + actions = append(actions, validateAccessList, reconcileMembers, updateAccessList, reconcileOwners) + + if err := a.service.RunWhileLocked(ctx, []string{accessListResourceLockName}, 2*accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + return a.service.RunWhileLocked(ctx, lockName(accessList.GetName()), 2*accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + for _, action := range actions { + if err := action(); err != nil { + return trace.Wrap(err) + } + } + return nil + }) + }); err != nil { return nil, nil, trace.Wrap(err) } @@ -704,7 +994,7 @@ func (a *AccessListService) VerifyAccessListCreateLimit(ctx context.Context, tar return nil // unlimited } - lists, err := a.GetAccessLists(ctx) + lists, err := a.service.GetResources(ctx) if err != nil { return trace.Wrap(err) } diff --git a/lib/services/simple/access_list.go b/lib/services/simple/access_list.go index 12ac35af488de..9effb4475c973 100644 --- a/lib/services/simple/access_list.go +++ b/lib/services/simple/access_list.go @@ -204,3 +204,27 @@ func (a *AccessListService) ListAllAccessListMembers(ctx context.Context, pageSi members, nextToken, err := a.memberService.ListResources(ctx, pageSize, pageToken) return members, nextToken, trace.Wrap(err) } + +// UnconditionalUpsertAccessList creates or updates an Access List resource without any validation. +// It should only ever be used by the cache. +func (a *AccessListService) UnconditionalUpsertAccessList(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + return a.service.UpsertResource(ctx, accessList) +} + +// UnconditionalDeleteAccessList removes the specified Access List resource without any validation. +// It should only ever be used by the cache. +func (a *AccessListService) UnconditionalDeleteAccessList(ctx context.Context, name string) error { + return a.service.DeleteResource(ctx, name) +} + +// UnconditionalUpsertAccessListMember creates or updates an Access List Member resource without any validation. +// It should only ever be used by the cache. +func (a *AccessListService) UnconditionalUpsertAccessListMember(ctx context.Context, member *accesslist.AccessListMember) (*accesslist.AccessListMember, error) { + return a.memberService.WithPrefix(member.Spec.AccessList).UpsertResource(ctx, member) +} + +// UnconditionalDeleteAccessListMember removes the specified Access List Member resource without any validation. +// It should only ever be used by the cache. +func (a *AccessListService) UnconditionalDeleteAccessListMember(ctx context.Context, accessList, memberName string) error { + return a.memberService.WithPrefix(accessList).DeleteResource(ctx, memberName) +}