diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index c5e033dbb4776..f4a5903532549 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -44,8 +44,6 @@ import ( "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keys" - "github.com/gravitational/teleport/api/utils/prompt" - "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" dbprofile "github.com/gravitational/teleport/lib/client/db" "github.com/gravitational/teleport/lib/client/db/dbcmd" @@ -818,7 +816,7 @@ func onDatabaseConnect(cf *CLIConf) error { return nil } -func accessRequestForDB(cf *CLIConf, tc *client.TeleportClient, db types.Database) error { +func makeAccessRequestForDatabase(tc *client.TeleportClient, db types.Database) (types.AccessRequest, error) { requestResourceIDs := []types.ResourceID{{ ClusterName: tc.SiteName, Kind: types.KindDatabase, @@ -826,52 +824,42 @@ func accessRequestForDB(cf *CLIConf, tc *client.TeleportClient, db types.Databas }} req, err := services.NewAccessRequestWithResources(tc.Username, nil /* roles */, requestResourceIDs) + return req, trace.Wrap(err) +} + +func makeDatabaseAccessRequestAndWaitForApproval(cf *CLIConf, tc *client.TeleportClient, db types.Database) error { + req, err := makeAccessRequestForDatabase(tc, db) if err != nil { return trace.Wrap(err) } - cf.RequestID = req.GetName() - fmt.Fprintf(cf.Stdout(), "You do not currently have access to %q, attempting to request access.\n\n", db.GetName()) - - // Prompt for a request reason. - requestReason, err := prompt.Input(cf.Context, cf.Stdout(), prompt.Stdin(), "Enter request reason") - if err != nil { + if err := setAccessRequestReason(cf, req); err != nil { return trace.Wrap(err) } - req.SetRequestReason(requestReason) - - fmt.Fprint(os.Stdout, "Creating request...\n") - // Always create access request against the root cluster. - if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { - req, err = clt.CreateAccessRequestV2(cf.Context, req) - return trace.Wrap(err) - }); err != nil { + if err := sendAccessRequestAndWaitForApproval(cf, tc, req); err != nil { return trace.Wrap(err) } + return nil +} - if cf.Username == "" { - cf.Username = tc.Username - } - // re-fetch the request to display it with roles populated. - onRequestShow(cf) - fmt.Println("") +var dbCommandsWithAccessRequestSupport = []string{ + "db login", + "proxy db", + "db connect", +} - // Wait for the request to be resolved. - fmt.Fprintf(os.Stdout, "Waiting for request approval...\n") - var resolvedReq types.AccessRequest - if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { - resolvedReq, err = awaitRequestResolution(cf.Context, clt, req) - return trace.Wrap(err) - }); err != nil { - return trace.Wrap(err) +func shouldRetryGetDatabaseUsingSearchAsRoles(cf *CLIConf, tc *client.TeleportClient, getDatabaseError error) bool { + // Only retry when the database cannot be found without UseSearchAsRoles. + if !trace.IsNotFound(getDatabaseError) || tc.UseSearchAsRoles { + return false } - - // Handle resolution and update client certs if approved. - if err := onRequestResolution(cf, tc, resolvedReq); err != nil { - return trace.Wrap(err) + // Check if auto access request is disabled. + if cf.disableAccessRequest { + return false } - return nil + // Check if the `tsh` command supports auto access request. + return slices.Contains(dbCommandsWithAccessRequestSupport, cf.command) } // getDatabaseInfo fetches information about the database from tsh profile if DB @@ -890,20 +878,32 @@ func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, routes []tlsca.Rout db, err := getDatabaseByNameOrDiscoveredName(cf, tc, routes) switch { - case trace.IsNotFound(err) && !cf.disableAccessRequest && !tc.UseSearchAsRoles: - // Try again with SearchAsRoles. TODO optimize to not do - // UseSearchAsRoles for "regular" users. - tc.UseSearchAsRoles = true - searchAsRolesDB, searchAsRolesErr := getDatabaseByNameOrDiscoveredName(cf, tc, nil) - if searchAsRolesErr != nil { - return nil, trace.Wrap(err) + // If the database cannot be found, try again with UseSearchAsRoles. If + // the database is then found with UseSearchAsRoles, make an access request + // for it and elevate the user with the request ID upon approval. + // + // Note that the access request must be made before the database connection + // is made to avoid mangling the request with the database client tools. + // Thus the flow for auto database access request is different from SSH. + // + // Performance considerations: + // - For common scenarios where UseSearchAsRoles is not desired, it would + // be rare that cf.DatabaseName would be not found in the first API call + // so there won't be a second call usually. + // - accessChecker.GetAllowedSearchAsRoles can be checked to avoid the + // second API call but creating the access checker requires more calls. + // - The db commands do provide "--disable-access-request" to bypass the + // second call. If needed, we can add it to `tsh login` and profile yaml + // in the future. + case shouldRetryGetDatabaseUsingSearchAsRoles(cf, tc, err): + orgErr := err + if db, err = getDatabaseByNameOrDiscoveredNameUsingSearchAsRoles(cf, tc); err != nil { + return nil, trace.Wrap(orgErr) // Returns the original not found error. } - - if err := accessRequestForDB(cf, tc, searchAsRolesDB); err != nil { + if err := makeDatabaseAccessRequestAndWaitForApproval(cf, tc, db); err != nil { return nil, trace.Wrap(err) } - - db = searchAsRolesDB + routes = nil // reset routes after reissueWithRequests. case err != nil: return nil, trace.Wrap(err) @@ -1117,6 +1117,15 @@ func getDatabaseByNameOrDiscoveredName(cf *CLIConf, tc *client.TeleportClient, a return chooseOneDatabase(cf, databases) } +func getDatabaseByNameOrDiscoveredNameUsingSearchAsRoles(cf *CLIConf, tc *client.TeleportClient) (types.Database, error) { + tc.UseSearchAsRoles = true + defer func() { + tc.UseSearchAsRoles = false + }() + db, err := getDatabaseByNameOrDiscoveredName(cf, tc, nil) + return db, trace.Wrap(err) +} + func filterActiveDatabases(routes []tlsca.RouteToDatabase, databases types.Databases) types.Databases { databasesByName := databases.ToMap() var out types.Databases diff --git a/tool/tsh/common/db_test.go b/tool/tsh/common/db_test.go index cbe13d26c232a..c6e6c1988e1b7 100644 --- a/tool/tsh/common/db_test.go +++ b/tool/tsh/common/db_test.go @@ -62,6 +62,7 @@ func TestTshDB(t *testing.T) { t.Run("Login", testDatabaseLogin) t.Run("List", testListDatabase) t.Run("DatabaseSelection", testDatabaseSelection) + t.Run("access request", testDatabaseAccessRequest) } // testDatabaseLogin tests "tsh db login" command and verifies "tsh db @@ -1758,3 +1759,7 @@ func testDatabaseSelection(t *testing.T) { } }) } + +func testDatabaseAccessRequest(t *testing.T) { + // TODO +} diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 024149d15a5e9..74661b8163040 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -841,6 +841,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { proxyDB.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName) proxyDB.Flag("labels", labelHelp).StringVar(&cf.Labels) proxyDB.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) + proxyDB.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason) + proxyDB.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest) proxyApp := proxy.Command("app", "Start local TLS proxy for app connection when using Teleport in single-port mode.") proxyApp.Arg("app", "The name of the application to start local proxy for").Required().StringVar(&cf.AppName) @@ -884,6 +886,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { dbLogin.Flag("db-user", "Database user to configure as default.").Short('u').StringVar(&cf.DatabaseUser) dbLogin.Flag("db-name", "Database name to configure as default.").Short('n').StringVar(&cf.DatabaseName) dbLogin.Flag("db-roles", "List of comma separate database roles to use for auto-provisioned user.").Short('r').StringVar(&cf.DatabaseRoles) + dbLogin.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason) + dbLogin.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest) dbLogout := db.Command("logout", "Remove database credentials.") dbLogout.Arg("db", "Database to remove credentials for.").StringVar(&cf.DatabaseService) dbLogout.Flag("labels", labelHelp).StringVar(&cf.Labels) @@ -910,6 +914,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { dbConnect.Flag("db-roles", "List of comma separate database roles to use for auto-provisioned user.").Short('r').StringVar(&cf.DatabaseRoles) dbConnect.Flag("labels", labelHelp).StringVar(&cf.Labels) dbConnect.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) + dbConnect.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason) + dbConnect.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest) // join join := app.Command("join", "Join the active SSH or Kubernetes session.") @@ -3267,12 +3273,25 @@ func retryWithAccessRequest( log.Debugf("Not attempting to automatically request access, reason: %v", err) return trace.Wrap(origErr) } - cf.RequestID = req.GetName() // Print and log the original AccessDenied error. fmt.Fprintln(os.Stderr, utils.UserMessageFromError(origErr)) fmt.Fprintf(os.Stdout, "You do not currently have access to %q, attempting to request access.\n\n", resource) + if err := setAccessRequestReason(cf, req); err != nil { + return trace.Wrap(err) + } + if err := sendAccessRequestAndWaitForApproval(cf, tc, req); err != nil { + return trace.Wrap(err) + } + + // Retry now that request has been approved and certs updated. + // Clear the original exit status. + tc.ExitStatus = 0 + return trace.Wrap(fn()) +} + +func setAccessRequestReason(cf *CLIConf, req types.AccessRequest) (err error) { requestReason := cf.RequestReason if requestReason == "" { // Prompt for a request reason. @@ -3282,7 +3301,11 @@ func retryWithAccessRequest( } } req.SetRequestReason(requestReason) + return nil +} +func sendAccessRequestAndWaitForApproval(cf *CLIConf, tc *client.TeleportClient, req types.AccessRequest) (err error) { + cf.RequestID = req.GetName() fmt.Fprint(os.Stdout, "Creating request...\n") // Always create access request against the root cluster. if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { @@ -3313,11 +3336,7 @@ func retryWithAccessRequest( if err := onRequestResolution(cf, tc, resolvedReq); err != nil { return trace.Wrap(err) } - - // Retry now that request has been approved and certs updated. - // Clear the original exit status. - tc.ExitStatus = 0 - return trace.Wrap(fn()) + return nil } func onSSHLatency(cf *CLIConf) error {