Skip to content

Commit

Permalink
[v15] add a hidden --tunnel flag to tsh db connect to force local…
Browse files Browse the repository at this point in the history
… proxy tunnel (#48319)

* add a hidden `--tunnel` flag to `tsh db connect` to force local proxy with tunnel

* move msg to const
  • Loading branch information
greedy52 authored Nov 4, 2024
1 parent 17b0483 commit 8dee20d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
15 changes: 13 additions & 2 deletions tool/tsh/common/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ func onDatabaseConnect(cf *CLIConf) error {
return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, dbInfo.RouteToDatabase))
}

requires := getDBConnectLocalProxyRequirement(cf.Context, tc, dbInfo.RouteToDatabase)
requires := getDBConnectLocalProxyRequirement(cf.Context, tc, dbInfo.RouteToDatabase, cf.LocalProxyTunnel)
if err := maybeDatabaseLogin(cf, tc, profile, dbInfo, requires); err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1675,8 +1675,13 @@ func getDBLocalProxyRequirement(tc *client.TeleportClient, route tlsca.RouteToDa
return &out
}

func getDBConnectLocalProxyRequirement(ctx context.Context, tc *client.TeleportClient, route tlsca.RouteToDatabase) *dbLocalProxyRequirement {
func getDBConnectLocalProxyRequirement(ctx context.Context, tc *client.TeleportClient, route tlsca.RouteToDatabase, tunnelFlag bool) *dbLocalProxyRequirement {
r := getDBLocalProxyRequirement(tc, route)
// Forces local proxy tunnel when --tunnel is on.
if !r.tunnel && tunnelFlag {
r.addLocalProxyWithTunnel(dbConnectRequireReasonTunnelFlag)
}
// Forces local proxy when cluster has TLS routing enabled.
if !r.localProxy && tc.TLSRoutingEnabled {
r.addLocalProxy(formatTLSRoutingReason(tc.SiteName))
}
Expand Down Expand Up @@ -1852,6 +1857,12 @@ const (
dbFormatYAML = "yaml"
)

const (
// dbConnectRequireReasonTunnelFlag is the reason used in local proxy
// requirement calculation when --tunnel flag is specified.
dbConnectRequireReasonTunnelFlag = "--tunnel flag is specified"
)

var (
// dbCmdUnsupportedTemplate is the error message printed when some
// database subcommands are not supported.
Expand Down
40 changes: 31 additions & 9 deletions tool/tsh/common/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,11 +570,13 @@ func TestLocalProxyRequirement(t *testing.T) {
defaultAuthPref, err := authServer.GetAuthPreference(ctx)
require.NoError(t, err)
tests := map[string]struct {
clusterAuthPref types.AuthPreference
route *tlsca.RouteToDatabase
setupTC func(*client.TeleportClient)
wantLocalProxy bool
wantTunnel bool
clusterAuthPref types.AuthPreference
route *tlsca.RouteToDatabase
setupTC func(*client.TeleportClient)
tunnelFlag bool
wantLocalProxy bool
wantTunnel bool
wantTunnelReason string
}{
"tunnel not required": {
clusterAuthPref: defaultAuthPref,
Expand All @@ -592,8 +594,9 @@ func TestLocalProxyRequirement(t *testing.T) {
RequireMFAType: types.RequireMFAType_SESSION,
},
},
wantLocalProxy: true,
wantTunnel: true,
wantLocalProxy: true,
wantTunnel: true,
wantTunnelReason: "MFA is required",
},
"local proxy not required for separate port": {
clusterAuthPref: defaultAuthPref,
Expand All @@ -614,6 +617,25 @@ func TestLocalProxyRequirement(t *testing.T) {
wantLocalProxy: true,
wantTunnel: false,
},
"tunnel required by tunnel flag": {
clusterAuthPref: defaultAuthPref,
tunnelFlag: true,
wantLocalProxy: true,
wantTunnel: true,
wantTunnelReason: dbConnectRequireReasonTunnelFlag,
},
"tunnel required for separate port by tunnel flag": {
clusterAuthPref: defaultAuthPref,
setupTC: func(tc *client.TeleportClient) {
tc.TLSRoutingEnabled = false
tc.TLSRoutingConnUpgradeRequired = false
tc.PostgresProxyAddr = "separate.postgres.hostport:8888"
},
tunnelFlag: true,
wantLocalProxy: true,
wantTunnel: true,
wantTunnelReason: dbConnectRequireReasonTunnelFlag,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
Expand All @@ -640,12 +662,12 @@ func TestLocalProxyRequirement(t *testing.T) {
Username: "alice",
Database: "postgres",
}
requires := getDBConnectLocalProxyRequirement(ctx, tc, route)
requires := getDBConnectLocalProxyRequirement(ctx, tc, route, tt.tunnelFlag)
require.Equal(t, tt.wantLocalProxy, requires.localProxy)
require.Equal(t, tt.wantTunnel, requires.tunnel)
if requires.tunnel {
require.Len(t, requires.tunnelReasons, 1)
require.Contains(t, requires.tunnelReasons[0], "MFA is required")
require.Contains(t, requires.tunnelReasons[0], tt.wantTunnelReason)
}
})
}
Expand Down
1 change: 1 addition & 0 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
dbConnect.Flag("query", queryHelp).StringVar(&cf.PredicateExpression)
dbConnect.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason)
dbConnect.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest)
dbConnect.Flag("tunnel", "Open authenticated tunnel using database's client certificate so clients don't need to authenticate").Hidden().BoolVar(&cf.LocalProxyTunnel)

// join
join := app.Command("join", "Join the active SSH or Kubernetes session.")
Expand Down

0 comments on commit 8dee20d

Please sign in to comment.