diff --git a/lib/client/api.go b/lib/client/api.go index 9211fe1ddfc63..9a39617e3f49e 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -730,8 +730,42 @@ func WithMakeCurrentProfile(makeCurrentProfile bool) RetryWithReloginOption { } } +// NonRetryableError wraps an error to indicate that the error should fail +// IsErrorResolvableWithRelogin. This wrapper is used to workaround the false +// positives like trace.IsBadParameter check in IsErrorResolvableWithRelogin. +type NonRetryableError struct { + // Err is the original error. + Err error +} + +// Error returns the error text. +func (e *NonRetryableError) Error() string { + if e == nil || e.Err == nil { + return "" + } + return e.Err.Error() +} + +// Unwrap returns the original error. +func (e *NonRetryableError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +// IsNonRetryableError checks if the provided error is a NonRetryableError. +// Equivalent to `errors.As(err, new(*NonRetryableError))`. +func IsNonRetryableError(err error) bool { + return errors.As(err, new(*NonRetryableError)) +} + // IsErrorResolvableWithRelogin returns true if relogin is attempted on `err`. func IsErrorResolvableWithRelogin(err error) bool { + if IsNonRetryableError(err) { + return false + } + // Private key policy errors indicate that the user must login with an // unexpected private key policy requirement satisfied. This can occur // in the following cases: @@ -767,6 +801,8 @@ func IsErrorResolvableWithRelogin(err error) bool { // TODO(codingllama): Retrying BadParameter is a terrible idea. // We should fix this and remove the RemoteError condition above as well. // Any retriable error should be explicitly marked as such. + // Once trace.IsBadParameter check is removed, the nonRetryableError + // workaround can also be removed. return trace.IsBadParameter(err) || trace.IsTrustError(err) || utils.IsCertExpiredError(err) || diff --git a/lib/client/api_test.go b/lib/client/api_test.go index 670c3398d8c42..92e1a3676e85c 100644 --- a/lib/client/api_test.go +++ b/lib/client/api_test.go @@ -31,6 +31,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" @@ -1243,6 +1244,18 @@ func TestIsErrorResolvableWithRelogin(t *testing.T) { }, expectResolvable: true, }, + { + name: "trace.BadParameter should be resolvable", + err: trace.BadParameter("bad"), + expectResolvable: true, + }, + { + name: "nonRetryableError should not be resolvable", + err: trace.Wrap(&NonRetryableError{ + Err: trace.BadParameter("bad"), + }), + expectResolvable: false, + }, } { t.Run(tt.name, func(t *testing.T) { resolvable := IsErrorResolvableWithRelogin(tt.err) @@ -1343,3 +1356,15 @@ func TestGetTargetNodes(t *testing.T) { }) } } + +func TestNonRetryableError(t *testing.T) { + orgError := trace.AccessDenied("do not enter") + err := &NonRetryableError{ + Err: orgError, + } + require.Error(t, err) + assert.Equal(t, "do not enter", err.Error()) + assert.True(t, IsNonRetryableError(err)) + assert.True(t, trace.IsAccessDenied(err)) + assert.Equal(t, orgError, err.Unwrap()) +} diff --git a/tool/tsh/common/app.go b/tool/tsh/common/app.go index 2edaf86cdb3ce..6e13d18eec647 100644 --- a/tool/tsh/common/app.go +++ b/tool/tsh/common/app.go @@ -518,7 +518,8 @@ func getAppInfo(cf *CLIConf, clt authclient.ClientI, profile *client.ProfileStat isActive: true, }, nil } else if !trace.IsNotFound(err) { - return nil, trace.Wrap(err) + // pickActiveApp errors are non-retryable. + return nil, trace.Wrap(&client.NonRetryableError{Err: err}) } // If we didn't find an active profile for the app, get info from server. @@ -542,33 +543,45 @@ func getAppInfo(cf *CLIConf, clt authclient.ClientI, profile *client.ProfileStat app: app, } + // When getAppInfo gets called inside RetryWithRelogin, it will relogin on + // trace.BadParameter errors. Wrap errors from pickCloudAppLogin as they + // are not retryable. + if err := appInfo.pickCloudAppLogin(cf, logins); err != nil { + return nil, trace.Wrap(&client.NonRetryableError{Err: err}) + } + return appInfo, nil +} + +// pickCloudAppLogin picks the cloud identity for the app based on provided CLI +// flags and/or available logins of the Teleport user. +func (a *appInfo) pickCloudAppLogin(cf *CLIConf, logins []string) error { // If this is a cloud app, set additional applicable fields from CLI flags or roles. switch { - case app.IsAWSConsole(): - awsRoleARN, err := getARNFromFlags(cf, app, logins) + case a.app.IsAWSConsole(): + awsRoleARN, err := getARNFromFlags(cf, a.app, logins) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } - appInfo.AWSRoleARN = awsRoleARN + a.AWSRoleARN = awsRoleARN - case app.IsAzureCloud(): - azureIdentity, err := getAzureIdentityFromFlags(cf, profile) + case a.app.IsAzureCloud(): + azureIdentity, err := getAzureIdentityFromFlags(cf, a.profile) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } log.Debugf("Azure identity is %q", azureIdentity) - appInfo.AzureIdentity = azureIdentity + a.AzureIdentity = azureIdentity - case app.IsGCP(): - gcpServiceAccount, err := getGCPServiceAccountFromFlags(cf, profile) + case a.app.IsGCP(): + gcpServiceAccount, err := getGCPServiceAccountFromFlags(cf, a.profile) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } log.Debugf("GCP service account is %q", gcpServiceAccount) - appInfo.GCPServiceAccount = gcpServiceAccount + a.GCPServiceAccount = gcpServiceAccount } - return appInfo, nil + return nil } // appInfo wraps a RouteToApp and the corresponding app.