Skip to content

Commit

Permalink
refactor impl
Browse files Browse the repository at this point in the history
  • Loading branch information
greedy52 committed Mar 8, 2024
1 parent 5d54849 commit b935a92
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 53 deletions.
103 changes: 56 additions & 47 deletions tool/tsh/common/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ import (
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/api/utils/prompt"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/client"
dbprofile "github.com/gravitational/teleport/lib/client/db"
"github.com/gravitational/teleport/lib/client/db/dbcmd"
Expand Down Expand Up @@ -818,60 +816,50 @@ func onDatabaseConnect(cf *CLIConf) error {
return nil
}

func accessRequestForDB(cf *CLIConf, tc *client.TeleportClient, db types.Database) error {
func makeAccessRequestForDatabase(tc *client.TeleportClient, db types.Database) (types.AccessRequest, error) {
requestResourceIDs := []types.ResourceID{{
ClusterName: tc.SiteName,
Kind: types.KindDatabase,
Name: db.GetName(),
}}

req, err := services.NewAccessRequestWithResources(tc.Username, nil /* roles */, requestResourceIDs)
return req, trace.Wrap(err)
}

func makeDatabaseAccessRequestAndWaitForApproval(cf *CLIConf, tc *client.TeleportClient, db types.Database) error {
req, err := makeAccessRequestForDatabase(tc, db)
if err != nil {
return trace.Wrap(err)
}

cf.RequestID = req.GetName()

fmt.Fprintf(cf.Stdout(), "You do not currently have access to %q, attempting to request access.\n\n", db.GetName())

// Prompt for a request reason.
requestReason, err := prompt.Input(cf.Context, cf.Stdout(), prompt.Stdin(), "Enter request reason")
if err != nil {
if err := setAccessRequestReason(cf, req); err != nil {
return trace.Wrap(err)
}
req.SetRequestReason(requestReason)

fmt.Fprint(os.Stdout, "Creating request...\n")
// Always create access request against the root cluster.
if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error {
req, err = clt.CreateAccessRequestV2(cf.Context, req)
return trace.Wrap(err)
}); err != nil {
if err := sendAccessRequestAndWaitForApproval(cf, tc, req); err != nil {
return trace.Wrap(err)
}
return nil
}

if cf.Username == "" {
cf.Username = tc.Username
}
// re-fetch the request to display it with roles populated.
onRequestShow(cf)
fmt.Println("")
var dbCommandsWithAccessRequestSupport = []string{
"db login",
"proxy db",
"db connect",
}

// Wait for the request to be resolved.
fmt.Fprintf(os.Stdout, "Waiting for request approval...\n")
var resolvedReq types.AccessRequest
if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error {
resolvedReq, err = awaitRequestResolution(cf.Context, clt, req)
return trace.Wrap(err)
}); err != nil {
return trace.Wrap(err)
func shouldRetryGetDatabaseUsingSearchAsRoles(cf *CLIConf, tc *client.TeleportClient, getDatabaseError error) bool {
// Only retry when the database cannot be found without UseSearchAsRoles.
if !trace.IsNotFound(getDatabaseError) || tc.UseSearchAsRoles {
return false
}

// Handle resolution and update client certs if approved.
if err := onRequestResolution(cf, tc, resolvedReq); err != nil {
return trace.Wrap(err)
// Check if auto access request is disabled.
if cf.disableAccessRequest {
return false
}
return nil
// Check if the `tsh` command supports auto access request.
return slices.Contains(dbCommandsWithAccessRequestSupport, cf.command)
}

// getDatabaseInfo fetches information about the database from tsh profile if DB
Expand All @@ -890,20 +878,32 @@ func getDatabaseInfo(cf *CLIConf, tc *client.TeleportClient, routes []tlsca.Rout

db, err := getDatabaseByNameOrDiscoveredName(cf, tc, routes)
switch {
case trace.IsNotFound(err) && !cf.disableAccessRequest && !tc.UseSearchAsRoles:
// Try again with SearchAsRoles. TODO optimize to not do
// UseSearchAsRoles for "regular" users.
tc.UseSearchAsRoles = true
searchAsRolesDB, searchAsRolesErr := getDatabaseByNameOrDiscoveredName(cf, tc, nil)
if searchAsRolesErr != nil {
return nil, trace.Wrap(err)
// If the database cannot be found, try again with UseSearchAsRoles. If
// the database is then found with UseSearchAsRoles, make an access request
// for it and elevate the user with the request ID upon approval.
//
// Note that the access request must be made before the database connection
// is made to avoid mangling the request with the database client tools.
// Thus the flow for auto database access request is different from SSH.
//
// Performance considerations:
// - For common scenarios where UseSearchAsRoles is not desired, it would
// be rare that cf.DatabaseName would be not found in the first API call
// so there won't be a second call usually.
// - accessChecker.GetAllowedSearchAsRoles can be checked to avoid the
// second API call but creating the access checker requires more calls.
// - The db commands do provide "--disable-access-request" to bypass the
// second call. If needed, we can add it to `tsh login` and profile yaml
// in the future.
case shouldRetryGetDatabaseUsingSearchAsRoles(cf, tc, err):
orgErr := err
if db, err = getDatabaseByNameOrDiscoveredNameUsingSearchAsRoles(cf, tc); err != nil {
return nil, trace.Wrap(orgErr) // Returns the original not found error.
}

if err := accessRequestForDB(cf, tc, searchAsRolesDB); err != nil {
if err := makeDatabaseAccessRequestAndWaitForApproval(cf, tc, db); err != nil {
return nil, trace.Wrap(err)
}

db = searchAsRolesDB
routes = nil // reset routes after reissueWithRequests.

case err != nil:
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -1117,6 +1117,15 @@ func getDatabaseByNameOrDiscoveredName(cf *CLIConf, tc *client.TeleportClient, a
return chooseOneDatabase(cf, databases)
}

func getDatabaseByNameOrDiscoveredNameUsingSearchAsRoles(cf *CLIConf, tc *client.TeleportClient) (types.Database, error) {
tc.UseSearchAsRoles = true
defer func() {
tc.UseSearchAsRoles = false
}()
db, err := getDatabaseByNameOrDiscoveredName(cf, tc, nil)
return db, trace.Wrap(err)
}

func filterActiveDatabases(routes []tlsca.RouteToDatabase, databases types.Databases) types.Databases {
databasesByName := databases.ToMap()
var out types.Databases
Expand Down
5 changes: 5 additions & 0 deletions tool/tsh/common/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func TestTshDB(t *testing.T) {
t.Run("Login", testDatabaseLogin)
t.Run("List", testListDatabase)
t.Run("DatabaseSelection", testDatabaseSelection)
t.Run("access request", testDatabaseAccessRequest)
}

// testDatabaseLogin tests "tsh db login" command and verifies "tsh db
Expand Down Expand Up @@ -1758,3 +1759,7 @@ func testDatabaseSelection(t *testing.T) {
}
})
}

func testDatabaseAccessRequest(t *testing.T) {
// TODO
}
31 changes: 25 additions & 6 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
proxyDB.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName)
proxyDB.Flag("labels", labelHelp).StringVar(&cf.Labels)
proxyDB.Flag("query", queryHelp).StringVar(&cf.PredicateExpression)
proxyDB.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason)
proxyDB.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest)

proxyApp := proxy.Command("app", "Start local TLS proxy for app connection when using Teleport in single-port mode.")
proxyApp.Arg("app", "The name of the application to start local proxy for").Required().StringVar(&cf.AppName)
Expand Down Expand Up @@ -884,6 +886,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
dbLogin.Flag("db-user", "Database user to configure as default.").Short('u').StringVar(&cf.DatabaseUser)
dbLogin.Flag("db-name", "Database name to configure as default.").Short('n').StringVar(&cf.DatabaseName)
dbLogin.Flag("db-roles", "List of comma separate database roles to use for auto-provisioned user.").Short('r').StringVar(&cf.DatabaseRoles)
dbLogin.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason)
dbLogin.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest)
dbLogout := db.Command("logout", "Remove database credentials.")
dbLogout.Arg("db", "Database to remove credentials for.").StringVar(&cf.DatabaseService)
dbLogout.Flag("labels", labelHelp).StringVar(&cf.Labels)
Expand All @@ -910,6 +914,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
dbConnect.Flag("db-roles", "List of comma separate database roles to use for auto-provisioned user.").Short('r').StringVar(&cf.DatabaseRoles)
dbConnect.Flag("labels", labelHelp).StringVar(&cf.Labels)
dbConnect.Flag("query", queryHelp).StringVar(&cf.PredicateExpression)
dbConnect.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason)
dbConnect.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest)

// join
join := app.Command("join", "Join the active SSH or Kubernetes session.")
Expand Down Expand Up @@ -3267,12 +3273,25 @@ func retryWithAccessRequest(
log.Debugf("Not attempting to automatically request access, reason: %v", err)
return trace.Wrap(origErr)
}
cf.RequestID = req.GetName()

// Print and log the original AccessDenied error.
fmt.Fprintln(os.Stderr, utils.UserMessageFromError(origErr))
fmt.Fprintf(os.Stdout, "You do not currently have access to %q, attempting to request access.\n\n", resource)

if err := setAccessRequestReason(cf, req); err != nil {
return trace.Wrap(err)
}
if err := sendAccessRequestAndWaitForApproval(cf, tc, req); err != nil {
return trace.Wrap(err)
}

// Retry now that request has been approved and certs updated.
// Clear the original exit status.
tc.ExitStatus = 0
return trace.Wrap(fn())
}

func setAccessRequestReason(cf *CLIConf, req types.AccessRequest) (err error) {
requestReason := cf.RequestReason
if requestReason == "" {
// Prompt for a request reason.
Expand All @@ -3282,7 +3301,11 @@ func retryWithAccessRequest(
}
}
req.SetRequestReason(requestReason)
return nil
}

func sendAccessRequestAndWaitForApproval(cf *CLIConf, tc *client.TeleportClient, req types.AccessRequest) (err error) {
cf.RequestID = req.GetName()
fmt.Fprint(os.Stdout, "Creating request...\n")
// Always create access request against the root cluster.
if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error {
Expand Down Expand Up @@ -3313,11 +3336,7 @@ func retryWithAccessRequest(
if err := onRequestResolution(cf, tc, resolvedReq); err != nil {
return trace.Wrap(err)
}

// Retry now that request has been approved and certs updated.
// Clear the original exit status.
tc.ExitStatus = 0
return trace.Wrap(fn())
return nil
}

func onSSHLatency(cf *CLIConf) error {
Expand Down

0 comments on commit b935a92

Please sign in to comment.