Skip to content

Commit

Permalink
Fix setting request reason for automatic ssh access requests. (#43076)
Browse files Browse the repository at this point in the history
* Fix setting request reason for automatic ssh access requests.

* Use slices.ContainsFunc for test check.
  • Loading branch information
AntonAM authored Jun 18, 2024
1 parent d6fe42b commit db7baed
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
35 changes: 18 additions & 17 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -3376,28 +3376,29 @@ func retryWithAccessRequest(
}

func promptUserForAccessRequestDetails(cf *CLIConf, req types.AccessRequest) error {
if cf.RequestMode != accessRequestModeRole {
return nil
}
// If this is a role access request, ensure that it only has one role.
switch len(req.GetRoles()) {
case 0:
return trace.AccessDenied("no roles to request that would grant access")
case 1:
return nil
default:
selectedRole, err := prompt.PickOne(
cf.Context, os.Stdout, prompt.NewContextReader(os.Stdin),
"Choose role to request",
req.GetRoles())
if err != nil {
return trace.Wrap(err)
if cf.RequestMode == accessRequestModeRole {
// If this is a role access request, ensure that it only has one role.
switch len(req.GetRoles()) {
case 0:
return trace.AccessDenied("no roles to request that would grant access")
case 1:
// No need to choose a role, just set request reason.
default:
selectedRole, err := prompt.PickOne(
cf.Context, os.Stdout, prompt.NewContextReader(os.Stdin),
"Choose role to request",
req.GetRoles())
if err != nil {
return trace.Wrap(err)
}
req.SetRoles([]string{selectedRole})
}
req.SetRoles([]string{selectedRole})
}

if err := setAccessRequestReason(cf, req); err != nil {
return trace.Wrap(err)
}

return nil
}

Expand Down
23 changes: 22 additions & 1 deletion tool/tsh/common/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import (
"github.com/ghodss/yaml"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -1975,6 +1976,20 @@ func TestSSHAccessRequest(t *testing.T) {
_, err = rootAuth.GetAuthServer().UpsertUser(ctx, alice)
require.NoError(t, err)

err = Run(ctx, []string{
"logout",
}, setHomePath(tmpHomePath))
require.NoError(t, err)

err = Run(ctx, []string{
"login",
"--insecure",
"--proxy", proxyAddr.String(),
"--user", "alice",
}, setHomePath(tmpHomePath), setMockSSOLogin(rootAuth.GetAuthServer(), alice, connector.GetName()))
require.NoError(t, err)

requestReason := uuid.New().String()
// the first ssh request can fail if the proxy node watcher doesn't know
// about the nodes yet, retry a few times until it works
require.Eventually(t, func() bool {
Expand All @@ -1984,7 +1999,7 @@ func TestSSHAccessRequest(t *testing.T) {
"--debug",
"--insecure",
"--request-mode", tc.requestMode,
"--request-reason", "reason here to bypass prompt",
"--request-reason", requestReason,
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"echo", "test",
}, setHomePath(tmpHomePath))
Expand All @@ -1994,6 +2009,12 @@ func TestSSHAccessRequest(t *testing.T) {
return err == nil
}, 10*time.Second, 100*time.Millisecond, "failed to ssh with retries")

requests, err := rootAuth.GetAuthServer().GetAccessRequests(ctx, types.AccessRequestFilter{})
require.NoError(t, err)
require.True(t, slices.ContainsFunc(requests, func(request types.AccessRequest) bool {
return request.GetRequestReason() == requestReason
}), "access request with the specified reason was not found")

// now that we have an approved access request, it should work without
// prompting for a request reason
err = Run(ctx, []string{
Expand Down

0 comments on commit db7baed

Please sign in to comment.