From 8dee20d4ba0a3421f35250b167acaa92e8f71d1f Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Mon, 4 Nov 2024 10:43:03 -0500 Subject: [PATCH] [v15] add a hidden `--tunnel` flag to `tsh db connect` to force local proxy tunnel (#48319) * add a hidden `--tunnel` flag to `tsh db connect` to force local proxy with tunnel * move msg to const --- tool/tsh/common/db.go | 15 ++++++++++++-- tool/tsh/common/db_test.go | 40 +++++++++++++++++++++++++++++--------- tool/tsh/common/tsh.go | 1 + 3 files changed, 45 insertions(+), 11 deletions(-) diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index 47cf93a49248b..6bef435c91b93 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -765,7 +765,7 @@ func onDatabaseConnect(cf *CLIConf) error { return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, dbInfo.RouteToDatabase)) } - requires := getDBConnectLocalProxyRequirement(cf.Context, tc, dbInfo.RouteToDatabase) + requires := getDBConnectLocalProxyRequirement(cf.Context, tc, dbInfo.RouteToDatabase, cf.LocalProxyTunnel) if err := maybeDatabaseLogin(cf, tc, profile, dbInfo, requires); err != nil { return trace.Wrap(err) } @@ -1675,8 +1675,13 @@ func getDBLocalProxyRequirement(tc *client.TeleportClient, route tlsca.RouteToDa return &out } -func getDBConnectLocalProxyRequirement(ctx context.Context, tc *client.TeleportClient, route tlsca.RouteToDatabase) *dbLocalProxyRequirement { +func getDBConnectLocalProxyRequirement(ctx context.Context, tc *client.TeleportClient, route tlsca.RouteToDatabase, tunnelFlag bool) *dbLocalProxyRequirement { r := getDBLocalProxyRequirement(tc, route) + // Forces local proxy tunnel when --tunnel is on. + if !r.tunnel && tunnelFlag { + r.addLocalProxyWithTunnel(dbConnectRequireReasonTunnelFlag) + } + // Forces local proxy when cluster has TLS routing enabled. if !r.localProxy && tc.TLSRoutingEnabled { r.addLocalProxy(formatTLSRoutingReason(tc.SiteName)) } @@ -1852,6 +1857,12 @@ const ( dbFormatYAML = "yaml" ) +const ( + // dbConnectRequireReasonTunnelFlag is the reason used in local proxy + // requirement calculation when --tunnel flag is specified. + dbConnectRequireReasonTunnelFlag = "--tunnel flag is specified" +) + var ( // dbCmdUnsupportedTemplate is the error message printed when some // database subcommands are not supported. diff --git a/tool/tsh/common/db_test.go b/tool/tsh/common/db_test.go index ba69a6c28dc9b..107dfdcb674d7 100644 --- a/tool/tsh/common/db_test.go +++ b/tool/tsh/common/db_test.go @@ -570,11 +570,13 @@ func TestLocalProxyRequirement(t *testing.T) { defaultAuthPref, err := authServer.GetAuthPreference(ctx) require.NoError(t, err) tests := map[string]struct { - clusterAuthPref types.AuthPreference - route *tlsca.RouteToDatabase - setupTC func(*client.TeleportClient) - wantLocalProxy bool - wantTunnel bool + clusterAuthPref types.AuthPreference + route *tlsca.RouteToDatabase + setupTC func(*client.TeleportClient) + tunnelFlag bool + wantLocalProxy bool + wantTunnel bool + wantTunnelReason string }{ "tunnel not required": { clusterAuthPref: defaultAuthPref, @@ -592,8 +594,9 @@ func TestLocalProxyRequirement(t *testing.T) { RequireMFAType: types.RequireMFAType_SESSION, }, }, - wantLocalProxy: true, - wantTunnel: true, + wantLocalProxy: true, + wantTunnel: true, + wantTunnelReason: "MFA is required", }, "local proxy not required for separate port": { clusterAuthPref: defaultAuthPref, @@ -614,6 +617,25 @@ func TestLocalProxyRequirement(t *testing.T) { wantLocalProxy: true, wantTunnel: false, }, + "tunnel required by tunnel flag": { + clusterAuthPref: defaultAuthPref, + tunnelFlag: true, + wantLocalProxy: true, + wantTunnel: true, + wantTunnelReason: dbConnectRequireReasonTunnelFlag, + }, + "tunnel required for separate port by tunnel flag": { + clusterAuthPref: defaultAuthPref, + setupTC: func(tc *client.TeleportClient) { + tc.TLSRoutingEnabled = false + tc.TLSRoutingConnUpgradeRequired = false + tc.PostgresProxyAddr = "separate.postgres.hostport:8888" + }, + tunnelFlag: true, + wantLocalProxy: true, + wantTunnel: true, + wantTunnelReason: dbConnectRequireReasonTunnelFlag, + }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { @@ -640,12 +662,12 @@ func TestLocalProxyRequirement(t *testing.T) { Username: "alice", Database: "postgres", } - requires := getDBConnectLocalProxyRequirement(ctx, tc, route) + requires := getDBConnectLocalProxyRequirement(ctx, tc, route, tt.tunnelFlag) require.Equal(t, tt.wantLocalProxy, requires.localProxy) require.Equal(t, tt.wantTunnel, requires.tunnel) if requires.tunnel { require.Len(t, requires.tunnelReasons, 1) - require.Contains(t, requires.tunnelReasons[0], "MFA is required") + require.Contains(t, requires.tunnelReasons[0], tt.wantTunnelReason) } }) } diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index f561bb41effa2..27c8dc98030da 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -987,6 +987,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { dbConnect.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) dbConnect.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason) dbConnect.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest) + dbConnect.Flag("tunnel", "Open authenticated tunnel using database's client certificate so clients don't need to authenticate").Hidden().BoolVar(&cf.LocalProxyTunnel) // join join := app.Command("join", "Join the active SSH or Kubernetes session.")