Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v15] add a hidden --tunnel flag to tsh db connect to force local proxy tunnel #48319

Merged
merged 2 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions tool/tsh/common/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ func onDatabaseConnect(cf *CLIConf) error {
return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, dbInfo.RouteToDatabase))
}

requires := getDBConnectLocalProxyRequirement(cf.Context, tc, dbInfo.RouteToDatabase)
requires := getDBConnectLocalProxyRequirement(cf.Context, tc, dbInfo.RouteToDatabase, cf.LocalProxyTunnel)
if err := maybeDatabaseLogin(cf, tc, profile, dbInfo, requires); err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1675,8 +1675,13 @@ func getDBLocalProxyRequirement(tc *client.TeleportClient, route tlsca.RouteToDa
return &out
}

func getDBConnectLocalProxyRequirement(ctx context.Context, tc *client.TeleportClient, route tlsca.RouteToDatabase) *dbLocalProxyRequirement {
func getDBConnectLocalProxyRequirement(ctx context.Context, tc *client.TeleportClient, route tlsca.RouteToDatabase, tunnelFlag bool) *dbLocalProxyRequirement {
r := getDBLocalProxyRequirement(tc, route)
// Forces local proxy tunnel when --tunnel is on.
if !r.tunnel && tunnelFlag {
r.addLocalProxyWithTunnel(dbConnectRequireReasonTunnelFlag)
}
// Forces local proxy when cluster has TLS routing enabled.
if !r.localProxy && tc.TLSRoutingEnabled {
r.addLocalProxy(formatTLSRoutingReason(tc.SiteName))
}
Expand Down Expand Up @@ -1852,6 +1857,12 @@ const (
dbFormatYAML = "yaml"
)

const (
// dbConnectRequireReasonTunnelFlag is the reason used in local proxy
// requirement calculation when --tunnel flag is specified.
dbConnectRequireReasonTunnelFlag = "--tunnel flag is specified"
)

var (
// dbCmdUnsupportedTemplate is the error message printed when some
// database subcommands are not supported.
Expand Down
40 changes: 31 additions & 9 deletions tool/tsh/common/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,11 +570,13 @@ func TestLocalProxyRequirement(t *testing.T) {
defaultAuthPref, err := authServer.GetAuthPreference(ctx)
require.NoError(t, err)
tests := map[string]struct {
clusterAuthPref types.AuthPreference
route *tlsca.RouteToDatabase
setupTC func(*client.TeleportClient)
wantLocalProxy bool
wantTunnel bool
clusterAuthPref types.AuthPreference
route *tlsca.RouteToDatabase
setupTC func(*client.TeleportClient)
tunnelFlag bool
wantLocalProxy bool
wantTunnel bool
wantTunnelReason string
}{
"tunnel not required": {
clusterAuthPref: defaultAuthPref,
Expand All @@ -592,8 +594,9 @@ func TestLocalProxyRequirement(t *testing.T) {
RequireMFAType: types.RequireMFAType_SESSION,
},
},
wantLocalProxy: true,
wantTunnel: true,
wantLocalProxy: true,
wantTunnel: true,
wantTunnelReason: "MFA is required",
},
"local proxy not required for separate port": {
clusterAuthPref: defaultAuthPref,
Expand All @@ -614,6 +617,25 @@ func TestLocalProxyRequirement(t *testing.T) {
wantLocalProxy: true,
wantTunnel: false,
},
"tunnel required by tunnel flag": {
clusterAuthPref: defaultAuthPref,
tunnelFlag: true,
wantLocalProxy: true,
wantTunnel: true,
wantTunnelReason: dbConnectRequireReasonTunnelFlag,
},
"tunnel required for separate port by tunnel flag": {
clusterAuthPref: defaultAuthPref,
setupTC: func(tc *client.TeleportClient) {
tc.TLSRoutingEnabled = false
tc.TLSRoutingConnUpgradeRequired = false
tc.PostgresProxyAddr = "separate.postgres.hostport:8888"
},
tunnelFlag: true,
wantLocalProxy: true,
wantTunnel: true,
wantTunnelReason: dbConnectRequireReasonTunnelFlag,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
Expand All @@ -640,12 +662,12 @@ func TestLocalProxyRequirement(t *testing.T) {
Username: "alice",
Database: "postgres",
}
requires := getDBConnectLocalProxyRequirement(ctx, tc, route)
requires := getDBConnectLocalProxyRequirement(ctx, tc, route, tt.tunnelFlag)
require.Equal(t, tt.wantLocalProxy, requires.localProxy)
require.Equal(t, tt.wantTunnel, requires.tunnel)
if requires.tunnel {
require.Len(t, requires.tunnelReasons, 1)
require.Contains(t, requires.tunnelReasons[0], "MFA is required")
require.Contains(t, requires.tunnelReasons[0], tt.wantTunnelReason)
}
})
}
Expand Down
1 change: 1 addition & 0 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
dbConnect.Flag("query", queryHelp).StringVar(&cf.PredicateExpression)
dbConnect.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason)
dbConnect.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest)
dbConnect.Flag("tunnel", "Open authenticated tunnel using database's client certificate so clients don't need to authenticate").Hidden().BoolVar(&cf.LocalProxyTunnel)

// join
join := app.Command("join", "Join the active SSH or Kubernetes session.")
Expand Down
Loading