Skip to content

Commit

Permalink
Only run deduplication when the request is being created
Browse files Browse the repository at this point in the history
  • Loading branch information
zmb3 committed Oct 2, 2024
1 parent da670f8 commit 86a4573
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 18 deletions.
26 changes: 13 additions & 13 deletions lib/services/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -1110,19 +1110,6 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest
return trace.BadParameter("only promoted requests can set the promoted access list title")
}

// deduplicate requested resource IDs
var deduplicated []types.ResourceID
seen := make(map[string]struct{})
for _, resource := range req.GetRequestedResourceIDs() {
id := types.ResourceIDToString(resource)
if _, isDuplicate := seen[id]; isDuplicate {
continue
}
seen[id] = struct{}{}
deduplicated = append(deduplicated, resource)
}
req.SetRequestedResourceIDs(deduplicated)

// check for "wildcard request" (`roles=*`). wildcard requests
// need to be expanded into a list consisting of all existing roles
// that the user does not hold and is allowed to request.
Expand Down Expand Up @@ -1168,6 +1155,19 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest
}

if m.opts.expandVars {
// deduplicate requested resource IDs
var deduplicated []types.ResourceID
seen := make(map[string]struct{})
for _, resource := range req.GetRequestedResourceIDs() {
id := types.ResourceIDToString(resource)
if _, isDuplicate := seen[id]; isDuplicate {
continue
}
seen[id] = struct{}{}
deduplicated = append(deduplicated, resource)
}
req.SetRequestedResourceIDs(deduplicated)

// determine the roles which should be requested for a resource access
// request, and write them to the request
if err := m.setRolesForResourceRequest(ctx, req); err != nil {
Expand Down
38 changes: 33 additions & 5 deletions lib/services/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2170,9 +2170,37 @@ func TestValidateDuplicateRequestedResources(t *testing.T) {
desktops: make(map[string]types.WindowsDesktop),
clusterName: "someCluster",
}
testRole, err := types.NewRole("testRole", types.RoleSpecV6{})

for i := 1; i < 3; i++ {
node, err := types.NewServerWithLabels(
fmt.Sprintf("resource%d", i),
types.KindNode,
types.ServerSpecV2{},
map[string]string{"foo": "bar"},
)
require.NoError(t, err)
g.nodes[node.GetName()] = node
}

searchAsRole, err := types.NewRole("searchAs", types.RoleSpecV6{
Allow: types.RoleConditions{
Logins: []string{"root"},
NodeLabels: types.Labels{"*": []string{"*"}},
},
})
require.NoError(t, err)
g.roles[searchAsRole.GetName()] = searchAsRole

testRole, err := types.NewRole("testRole", types.RoleSpecV6{
Allow: types.RoleConditions{
Request: &types.AccessRequestConditions{
SearchAsRoles: []string{searchAsRole.GetName()},
},
},
})
require.NoError(t, err)
g.roles[testRole.GetName()] = testRole

user := g.user(t, testRole.GetName())

clock := clockwork.NewFakeClock()
Expand All @@ -2183,15 +2211,15 @@ func TestValidateDuplicateRequestedResources(t *testing.T) {
req, err := types.NewAccessRequestWithResources("name", user, nil, /* roles */
[]types.ResourceID{
{ClusterName: "someCluster", Kind: "node", Name: "resource1"},
{ClusterName: "someCluster", Kind: "node", Name: "resource1"}, // a true duplicate
{ClusterName: "someCluster", Kind: "app", Name: "resource1"}, // not a duplicate
{ClusterName: "someCluster", Kind: "node", Name: "resource1"}, // a duplicate
{ClusterName: "someCluster", Kind: "node", Name: "resource2"}, // not a duplicate
})
require.NoError(t, err)

require.NoError(t, ValidateAccessRequestForUser(context.Background(), clock, g, req, identity))
require.NoError(t, ValidateAccessRequestForUser(context.Background(), clock, g, req, identity, ExpandVars(true)))
require.Len(t, req.GetRequestedResourceIDs(), 2)
require.Equal(t, "/someCluster/node/resource1", types.ResourceIDToString(req.GetRequestedResourceIDs()[0]))
require.Equal(t, "/someCluster/app/resource1", types.ResourceIDToString(req.GetRequestedResourceIDs()[1]))
require.Equal(t, "/someCluster/node/resource2", types.ResourceIDToString(req.GetRequestedResourceIDs()[1]))
}

func TestValidateAccessRequestClusterNames(t *testing.T) {
Expand Down

0 comments on commit 86a4573

Please sign in to comment.