Skip to content

Commit

Permalink
add a recursive test
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex McGrath committed May 10, 2024
1 parent bc26372 commit 44c2545
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 6 deletions.
38 changes: 38 additions & 0 deletions lib/auth/userloginstate/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,38 @@ func TestAccessLists(t *testing.T) {
expectedRoleCount: 1,
expectedTraitCount: 6,
},
{
name: "access lists member of sub list",
cloud: true,
user: userNoRolesOrTraits,
// user is member of acl 3, acl 1 includes acl 2, which includes acl 2 and 3
// so user will be granted role1 and 2, and trait1
accessLists: []*accesslist.AccessList{
newAccessListsWithMemberRefs(t, clock, "1", grants([]string{"role1"},
trait.Traits{
"trait1": {"value"},
}),
emptyGrants,
[]accesslist.AccessListRef{{Name: "2", Title: "title"}}),
newAccessListsWithMemberRefs(t, clock, "2", grants([]string{"role1"}, trait.Traits{}),
emptyGrants,
[]accesslist.AccessListRef{{Name: "3", Title: "title"}, {Name: "2", Title: "title"}}),
newAccessList(t, clock, "3", grants([]string{"role2"}, trait.Traits{}), emptyGrants),
},
members: newAccessListMembers(t, clock, "3", "user"),
roles: []string{"role1", "role2"},
wantErr: require.NoError,
expected: newUserLoginState(t, "user",
map[string]string{
userloginstate.OriginalRolesAndTraitsSet: "true",
},
nil,
nil,
[]string{"role1", "role2"},
trait.Traits{"trait1": {"value"}}),
expectedRoleCount: 2,
expectedTraitCount: 1,
},
}

for _, test := range tests {
Expand Down Expand Up @@ -499,6 +531,12 @@ func grants(roles []string, traits trait.Traits) accesslist.Grants {
}
}

func newAccessListsWithMemberRefs(t *testing.T, clock clockwork.Clock, name string, grants accesslist.Grants, ownerGrants accesslist.Grants, memberlists []accesslist.AccessListRef) *accesslist.AccessList {
acl := newAccessList(t, clock, name, grants, ownerGrants)
acl.Spec.MemberAccessLists = memberlists
return acl
}

func newAccessList(t *testing.T, clock clockwork.Clock, name string, grants accesslist.Grants, ownerGrants accesslist.Grants) *accesslist.AccessList {
t.Helper()

Expand Down
19 changes: 14 additions & 5 deletions lib/services/access_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,13 @@ func (a AccessListMembershipChecker) recursiveIsAccessListMemberCheck(ctx contex
}
if !UserMeetsRequirements(identity, subAccessList.Spec.MembershipRequires) {
return trace.AccessDenied("user %s is a member, but does not have the roles or traits required to be a member of this list", username)
} else {
return nil
}

return nil
}

}
return nil
return trace.NotFound("user %s is not a member of the access list", username)
}

// IsAccessListMember will return true if the user is a member for the current list.
Expand All @@ -349,12 +349,21 @@ func (a AccessListMembershipChecker) IsAccessListMember(ctx context.Context, ide
if err != nil && !trace.IsNotFound(err) {
return trace.Wrap(err)
}

if trace.IsNotFound(err) {
if err := a.recursiveIsAccessListMemberCheck(ctx, identity, accessList); err != nil {
// try find if the user could be a member of any lists by recursing
err := a.recursiveIsAccessListMemberCheck(ctx, identity, accessList)
if trace.IsNotFound(err) {
// The member has not been found, so we know they're not a member of this list.
return trace.NotFound("user %s is not a member of the access list", username)
}
if err != nil {
// Some other error has occurred
return trace.Wrap(err)
}
return nil
} else if err != nil {
// Some other error has occurred
return trace.Wrap(err)
}

expires := member.Spec.Expires
Expand Down
7 changes: 6 additions & 1 deletion lib/services/access_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,11 @@ type testMembersAndLockGetter struct {
locks map[string]types.Lock
}

// GetAccessList implements AccessListGetter.
func (t *testMembersAndLockGetter) GetAccessList(context.Context, string) (*accesslist.AccessList, error) {
return nil, trace.NotImplemented("not implemented")
}

// ListAccessListMembers returns a paginated list of all access list members.
func (t *testMembersAndLockGetter) ListAccessListMembers(ctx context.Context, accessList string, _ int, _ string) (members []*accesslist.AccessListMember, nextToken string, err error) {
for _, member := range t.members[accessList] {
Expand Down Expand Up @@ -489,7 +494,7 @@ func TestIsAccessListMemberChecker(t *testing.T) {
}
getter := &testMembersAndLockGetter{members: memberMap, locks: test.locks}

checker := NewAccessListMembershipChecker(clockwork.NewFakeClockAt(test.currentTime), getter, getter)
checker := NewAccessListMembershipChecker(clockwork.NewFakeClockAt(test.currentTime), getter, getter, getter)
test.errAssertionFunc(t, checker.IsAccessListMember(ctx, test.identity, accessList))
})
}
Expand Down

0 comments on commit 44c2545

Please sign in to comment.