diff --git a/integration/proxy/teleterm_test.go b/integration/proxy/teleterm_test.go index 59697ef84a958..774f541b913b4 100644 --- a/integration/proxy/teleterm_test.go +++ b/integration/proxy/teleterm_test.go @@ -51,6 +51,7 @@ import ( wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" "github.com/gravitational/teleport/lib/client" libclient "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/client/clientcache" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/service/servicecfg" @@ -58,7 +59,6 @@ import ( "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/daemon" "github.com/gravitational/teleport/lib/teleterm/gateway" - "github.com/gravitational/teleport/lib/teleterm/services/clientcache" "github.com/gravitational/teleport/lib/utils" ) @@ -200,8 +200,8 @@ func testGatewayCertRenewal(ctx context.Context, t *testing.T, params gatewayCer CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { return grpc.WithTransportCredentials(insecure.NewCredentials()), nil }, - CreateClientCacheFunc: func(resolveCluster daemon.ResolveClusterFunc) daemon.ClientCache { - return clientcache.NewNoCache(clientcache.ResolveClusterFunc(resolveCluster)) + CreateClientCacheFunc: func(newClient clientcache.NewClientFunc) (daemon.ClientCache, error) { + return clientcache.NewNoCache(newClient), nil }, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), diff --git a/lib/client/api.go b/lib/client/api.go index 5f046e5d3a15d..49e5aadc5212d 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -601,9 +601,9 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error, case !IsErrorResolvableWithRelogin(fnErr): return trace.Wrap(fnErr) } - opt := retryWithReloginOptions{} + opt := defaultRetryWithReloginOptions() for _, o := range opts { - o(&opt) + o(opt) } log.Debugf("Activating relogin on error=%q (type=%T)", fnErr, trace.Unwrap(fnErr)) @@ -650,11 +650,17 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error, } // Save profile to record proxy credentials - if err := tc.SaveProfile(true); err != nil { + if err := tc.SaveProfile(opt.makeCurrentProfile); err != nil { log.Warningf("Failed to save profile: %v", err) return trace.Wrap(err) } + if opt.afterLoginHook != nil { + if err := opt.afterLoginHook(); err != nil { + return trace.Wrap(err) + } + } + return fn() } @@ -666,9 +672,19 @@ type RetryWithReloginOption func(*retryWithReloginOptions) type retryWithReloginOptions struct { // beforeLoginHook is a function that will be called before the login attempt. beforeLoginHook func() error + // afterLoginHook is a function that will be called after a successful login. + afterLoginHook func() error + // makeCurrentProfile determines whether to update the current profile after login. + makeCurrentProfile bool } -// WithBeforeLogin is a functional option for configuring a function that will +func defaultRetryWithReloginOptions() *retryWithReloginOptions { + return &retryWithReloginOptions{ + makeCurrentProfile: true, + } +} + +// WithBeforeLoginHook is a functional option for configuring a function that will // be called before the login attempt. 
func WithBeforeLoginHook(fn func() error) RetryWithReloginOption { return func(o *retryWithReloginOptions) { @@ -676,6 +692,22 @@ func WithBeforeLoginHook(fn func() error) RetryWithReloginOption { } } +// WithAfterLoginHook is a functional option for configuring a function that will +// be called after a successful login. +func WithAfterLoginHook(fn func() error) RetryWithReloginOption { + return func(o *retryWithReloginOptions) { + o.afterLoginHook = fn + } +} + +// WithMakeCurrentProfile is a functional option for configuring whether to update the current profile after a +// successful login. +func WithMakeCurrentProfile(makeCurrentProfile bool) RetryWithReloginOption { + return func(o *retryWithReloginOptions) { + o.makeCurrentProfile = makeCurrentProfile + } +} + // IsErrorResolvableWithRelogin returns true if relogin is attempted on `err`. func IsErrorResolvableWithRelogin(err error) bool { // Private key policy errors indicate that the user must login with an diff --git a/lib/teleterm/services/clientcache/clientcache.go b/lib/client/clientcache/clientcache.go similarity index 50% rename from lib/teleterm/services/clientcache/clientcache.go rename to lib/client/clientcache/clientcache.go index 3ce3abd19ec68..5a9c4df29e7de 100644 --- a/lib/teleterm/services/clientcache/clientcache.go +++ b/lib/client/clientcache/clientcache.go @@ -27,66 +27,90 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/teleterm/api/uri" - "github.com/gravitational/teleport/lib/teleterm/clusters" ) -// Cache stores clients keyed by cluster URI. +// Cache stores clients keyed by profile name and leaf cluster name. // Safe for concurrent access. // Closes all clients and wipes the cache on Clear. type Cache struct { cfg Config - mu sync.Mutex - // clients keep mapping between cluster URI - // (both root and leaf) and cluster clients - clients map[uri.ResourceURI]*client.ClusterClient - // group prevents duplicate requests to create clients - // for a given cluster URI + mu sync.RWMutex + // clients keeps a mapping from key (profile name and leaf cluster name) to cluster client. + clients map[key]*client.ClusterClient + // group prevents duplicate requests to create clients for a given cluster. group singleflight.Group } -type ResolveClusterFunc func(uri uri.ResourceURI) (*clusters.Cluster, *client.TeleportClient, error) +// NewClientFunc is a function that will return a new [*client.TeleportClient] for a given profile and leaf +// cluster. [leafClusterName] may be empty, in which case implementations should return a client for the root cluster. +type NewClientFunc func(ctx context.Context, profileName, leafClusterName string) (*client.TeleportClient, error) + +// RetryWithReloginFunc is a function that should call [fn], and if it fails with an error that may be +// resolved with a cluster relogin, attempts the relogin and calls [fn] again if the relogin is successful. +type RetryWithReloginFunc func(ctx context.Context, tc *client.TeleportClient, fn func() error, opts ...client.RetryWithReloginOption) error // Config describes the client cache configuration. 
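// For illustration, a minimal, self-contained sketch of the functional-options
// pattern used above: defaults come from a constructor so that options such as
// makeCurrentProfile stay true for existing callers unless explicitly disabled.
// The names below are illustrative stand-ins, not part of this package.
package main

import "fmt"

type reloginOptions struct {
	beforeLoginHook    func() error
	afterLoginHook     func() error
	makeCurrentProfile bool
}

type ReloginOption func(*reloginOptions)

func defaultReloginOptions() *reloginOptions {
	// Preserve the historical behavior unless a caller opts out.
	return &reloginOptions{makeCurrentProfile: true}
}

func WithAfterLogin(fn func() error) ReloginOption {
	return func(o *reloginOptions) { o.afterLoginHook = fn }
}

func WithMakeCurrent(v bool) ReloginOption {
	return func(o *reloginOptions) { o.makeCurrentProfile = v }
}

func retryWithRelogin(fn func() error, opts ...ReloginOption) error {
	opt := defaultReloginOptions()
	for _, o := range opts {
		o(opt)
	}
	if err := fn(); err == nil {
		return nil
	}
	// ... a relogin would happen here ...
	if opt.afterLoginHook != nil {
		if err := opt.afterLoginHook(); err != nil {
			return err
		}
	}
	return fn()
}

func main() {
	_ = retryWithRelogin(func() error { return nil },
		WithMakeCurrent(false),
		WithAfterLogin(func() error { fmt.Println("cache cleared"); return nil }),
	)
}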
type Config struct { - ResolveClusterFunc ResolveClusterFunc - Log logrus.FieldLogger + NewClientFunc NewClientFunc + RetryWithReloginFunc RetryWithReloginFunc + Log logrus.FieldLogger } -func (c *Config) checkAndSetDefaults() { +func (c *Config) checkAndSetDefaults() error { + if c.NewClientFunc == nil { + return trace.BadParameter("NewClientFunc is required") + } + if c.RetryWithReloginFunc == nil { + return trace.BadParameter("RetryWithReloginFunc is required") + } if c.Log == nil { c.Log = logrus.WithField(teleport.ComponentKey, "clientcache") } + return nil +} + +type key struct { + profile string + leafCluster string +} + +func (k key) String() string { + if k.leafCluster != "" { + return k.profile + "/" + k.leafCluster + } + return k.profile } // New creates an instance of Cache. -func New(c Config) *Cache { - c.checkAndSetDefaults() +func New(c Config) (*Cache, error) { + if err := c.checkAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } return &Cache{ cfg: c, - clients: make(map[uri.ResourceURI]*client.ClusterClient), - } + clients: make(map[key]*client.ClusterClient), + }, nil } -// Get returns a client from the cache if there is one, -// otherwise it dials the remote server. +// Get returns a client from the cache if there is one, otherwise it dials the remote server. // The caller should not close the returned client. -func (c *Cache) Get(ctx context.Context, clusterURI uri.ResourceURI) (*client.ClusterClient, error) { - groupClt, err, _ := c.group.Do(clusterURI.String(), func() (any, error) { - if fromCache := c.getFromCache(clusterURI); fromCache != nil { - c.cfg.Log.WithField("cluster", clusterURI.String()).Info("Retrieved client from cache.") +func (c *Cache) Get(ctx context.Context, profileName, leafClusterName string) (*client.ClusterClient, error) { + k := key{profile: profileName, leafCluster: leafClusterName} + groupClt, err, _ := c.group.Do(k.String(), func() (any, error) { + if fromCache := c.getFromCache(k); fromCache != nil { + c.cfg.Log.WithField("cluster", k).Debug("Retrieved client from cache.") return fromCache, nil } - _, clusterClient, err := c.cfg.ResolveClusterFunc(clusterURI) + tc, err := c.cfg.NewClientFunc(ctx, profileName, leafClusterName) if err != nil { return nil, trace.Wrap(err) } var newClient *client.ClusterClient - if err := clusters.AddMetadataToRetryableError(ctx, func() error { - clt, err := clusterClient.ConnectToCluster(ctx) + if err := c.cfg.RetryWithReloginFunc(ctx, tc, func() error { + clt, err := tc.ConnectToCluster(ctx) if err != nil { return trace.Wrap(err) } @@ -96,12 +120,10 @@ func (c *Cache) Get(ctx context.Context, clusterURI uri.ResourceURI) (*client.Cl return nil, trace.Wrap(err) } - // We'll save the client in the cache, so we don't have to - // build a new connection next time. - // All cached clients will be closed when the daemon exits. - c.addToCache(clusterURI, newClient) + // Save the client in the cache, so we don't have to build a new connection next time. + c.addToCache(k, newClient) - c.cfg.Log.WithField("cluster", clusterURI.String()).Info("Added client to cache.") + c.cfg.Log.WithField("cluster", k).Info("Added client to cache.") return newClient, nil }) @@ -117,30 +139,28 @@ func (c *Cache) Get(ctx context.Context, clusterURI uri.ResourceURI) (*client.Cl return clt, nil } -// ClearForRoot closes and removes clients from the cache -// for the root cluster and its leaf clusters. 
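// A minimal, self-contained sketch (assuming only golang.org/x/sync/singleflight)
// of the pattern Get uses above: a map cache guarded by a mutex plus a
// singleflight.Group so that concurrent requests for the same profile/leaf key
// share one dial instead of racing to create duplicate clients. The dial field
// stands in for NewClientFunc followed by ConnectToCluster.
package main

import (
	"fmt"
	"sync"

	"golang.org/x/sync/singleflight"
)

type conn struct{ name string }

type cache struct {
	mu      sync.RWMutex
	entries map[string]*conn
	group   singleflight.Group
	dial    func(key string) (*conn, error)
}

func (c *cache) get(key string) (*conn, error) {
	v, err, _ := c.group.Do(key, func() (any, error) {
		c.mu.RLock()
		cached := c.entries[key]
		c.mu.RUnlock()
		if cached != nil {
			return cached, nil
		}
		newConn, err := c.dial(key)
		if err != nil {
			return nil, err
		}
		c.mu.Lock()
		c.entries[key] = newConn
		c.mu.Unlock()
		return newConn, nil
	})
	if err != nil {
		return nil, err
	}
	return v.(*conn), nil
}

func main() {
	c := &cache{
		entries: make(map[string]*conn),
		dial:    func(key string) (*conn, error) { return &conn{name: key}, nil },
	}
	clt, _ := c.get("root1.example.com/leaf1")
	fmt.Println(clt.name)
}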
-func (c *Cache) ClearForRoot(clusterURI uri.ResourceURI) error { +// ClearForRoot closes and removes clients from the cache for the root cluster and its leaf clusters. +func (c *Cache) ClearForRoot(profileName string) error { c.mu.Lock() defer c.mu.Unlock() - rootClusterURI := clusterURI.GetRootClusterURI() var ( errors []error deleted []string ) - for resourceURI, clt := range c.clients { - if resourceURI.GetRootClusterURI() == rootClusterURI { + for k, clt := range c.clients { + if k.profile == profileName { if err := clt.Close(); err != nil { errors = append(errors, err) } - deleted = append(deleted, resourceURI.GetClusterURI().String()) - delete(c.clients, resourceURI) + deleted = append(deleted, k.String()) + delete(c.clients, k) } } c.cfg.Log.WithFields( - logrus.Fields{"cluster": rootClusterURI.String(), "clients": deleted}, + logrus.Fields{"cluster": profileName, "clients": deleted}, ).Info("Invalidated cached clients for root cluster.") return trace.NewAggregate(errors...) @@ -163,18 +183,18 @@ func (c *Cache) Clear() error { return trace.NewAggregate(errors...) } -func (c *Cache) addToCache(clusterURI uri.ResourceURI, clusterClient *client.ClusterClient) { +func (c *Cache) addToCache(k key, clusterClient *client.ClusterClient) { c.mu.Lock() defer c.mu.Unlock() - c.clients[clusterURI] = clusterClient + c.clients[k] = clusterClient } -func (c *Cache) getFromCache(clusterURI uri.ResourceURI) *client.ClusterClient { - c.mu.Lock() - defer c.mu.Unlock() +func (c *Cache) getFromCache(k key) *client.ClusterClient { + c.mu.RLock() + defer c.mu.RUnlock() - clt := c.clients[clusterURI] + clt := c.clients[k] return clt } @@ -183,24 +203,24 @@ func (c *Cache) getFromCache(clusterURI uri.ResourceURI) *client.ClusterClient { // // ClearForRoot and Clear still work as expected. type NoCache struct { - mu sync.Mutex - resolveClusterFunc ResolveClusterFunc - clients []noCacheClient + mu sync.Mutex + newClientFunc NewClientFunc + clients []noCacheClient } type noCacheClient struct { - uri uri.ResourceURI + k key client *client.ClusterClient } -func NewNoCache(resolveClusterFunc ResolveClusterFunc) *NoCache { +func NewNoCache(newClientFunc NewClientFunc) *NoCache { return &NoCache{ - resolveClusterFunc: resolveClusterFunc, + newClientFunc: newClientFunc, } } -func (c *NoCache) Get(ctx context.Context, clusterURI uri.ResourceURI) (*client.ClusterClient, error) { - _, clusterClient, err := c.resolveClusterFunc(clusterURI) +func (c *NoCache) Get(ctx context.Context, profileName, leafClusterName string) (*client.ClusterClient, error) { + clusterClient, err := c.newClientFunc(ctx, profileName, leafClusterName) if err != nil { return nil, trace.Wrap(err) } @@ -212,7 +232,7 @@ func (c *NoCache) Get(ctx context.Context, clusterURI uri.ResourceURI) (*client. c.mu.Lock() c.clients = append(c.clients, noCacheClient{ - uri: clusterURI, + k: key{profile: profileName, leafCluster: leafClusterName}, client: newClient, }) c.mu.Unlock() @@ -220,17 +240,16 @@ func (c *NoCache) Get(ctx context.Context, clusterURI uri.ResourceURI) (*client. 
return newClient, nil } -func (c *NoCache) ClearForRoot(clusterURI uri.ResourceURI) error { +func (c *NoCache) ClearForRoot(profileName string) error { c.mu.Lock() defer c.mu.Unlock() - rootClusterURI := clusterURI.GetRootClusterURI() var ( errors []error ) c.clients = slices.DeleteFunc(c.clients, func(ncc noCacheClient) bool { - belongsToCluster := ncc.uri.GetRootClusterURI() == rootClusterURI + belongsToCluster := ncc.k.profile == profileName if belongsToCluster { if err := ncc.client.Close(); err != nil { diff --git a/lib/client/local_proxy_middleware.go b/lib/client/local_proxy_middleware.go index c35f649f63540..aae23ba25996d 100644 --- a/lib/client/local_proxy_middleware.go +++ b/lib/client/local_proxy_middleware.go @@ -22,7 +22,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "net" "time" "github.com/gravitational/trace" @@ -76,7 +75,7 @@ func NewAppCertChecker(tc *TeleportClient, appRoute proto.RouteToApp, clock cloc // OnNewConnection is a callback triggered when a new downstream connection is // accepted by the local proxy. -func (c *CertChecker) OnNewConnection(ctx context.Context, lp *alpnproxy.LocalProxy, conn net.Conn) error { +func (c *CertChecker) OnNewConnection(ctx context.Context, lp *alpnproxy.LocalProxy) error { return trace.Wrap(c.ensureValidCerts(ctx, lp)) } diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index 1d218344c2aed..f76e421c073be 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -99,7 +99,7 @@ type LocalProxyMiddleware interface { // OnNewConnection is a callback triggered when a new downstream connection is // accepted by the local proxy. If an error is returned, the connection will be closed // by the local proxy. - OnNewConnection(ctx context.Context, lp *LocalProxy, conn net.Conn) error + OnNewConnection(ctx context.Context, lp *LocalProxy) error // OnStart is a callback triggered when the local proxy starts. OnStart(ctx context.Context, lp *LocalProxy) error } @@ -197,7 +197,7 @@ func (l *LocalProxy) start(ctx context.Context) error { l.cfg.Log.Debug("Accepted downstream connection.") if l.cfg.Middleware != nil { - if err := l.cfg.Middleware.OnNewConnection(ctx, l, conn); err != nil { + if err := l.cfg.Middleware.OnNewConnection(ctx, l); err != nil { l.cfg.Log.WithError(err).Error("Middleware failed to handle client connection.") if err := conn.Close(); err != nil && !utils.IsUseOfClosedNetworkError(err) { l.cfg.Log.WithError(err).Debug("Failed to close client connection.") @@ -232,21 +232,59 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC return trace.Wrap(err) } - tlsConn, err := client.DialALPN(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(cert)) + upstreamConn, err := dialALPNMaybePing(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(cert)) if err != nil { return trace.Wrap(err) } - defer tlsConn.Close() + defer upstreamConn.Close() - var upstreamConn net.Conn = tlsConn - if common.IsPingProtocol(common.Protocol(tlsConn.ConnectionState().NegotiatedProtocol)) { - l.cfg.Log.Debug("Using ping connection") - upstreamConn = pingconn.NewTLS(tlsConn) + return trace.Wrap(utils.ProxyConn(ctx, downstreamConn, upstreamConn)) +} + +// HandleTCPConnector injects an inbound TCP connection (via [connector]) that doesn't come in through any +// net.Listener. It is used by VNet to share the common local proxy code. 
[connector] should be called as late +// as possible so that in case of error VNet clients get a failed TCP dial (with RST) rather than a successful +// dial with an immediately closed connection. +func (l *LocalProxy) HandleTCPConnector(ctx context.Context, connector func() (net.Conn, error)) error { + if l.cfg.Middleware != nil { + if err := l.cfg.Middleware.OnNewConnection(ctx, l); err != nil { + return trace.Wrap(err, "middleware failed to handle client connection") + } + } + + cert, err := l.getCertWithoutConn() + if err != nil { + return trace.Wrap(err) + } + + upstreamConn, err := dialALPNMaybePing(ctx, l.cfg.RemoteProxyAddr, l.getALPNDialerConfig(cert)) + if err != nil { + return trace.Wrap(err) } + defer upstreamConn.Close() + + downstreamConn, err := connector() + if err != nil { + return trace.Wrap(err, "getting downstream conn") + } + defer downstreamConn.Close() return trace.Wrap(utils.ProxyConn(ctx, downstreamConn, upstreamConn)) } +// dialALPNMaybePing is a helper to dial using an ALPNDialer, it wraps the tls conn in a ping conn if +// necessary, and returns a net.Conn if successful. +func dialALPNMaybePing(ctx context.Context, addr string, cfg client.ALPNDialerConfig) (net.Conn, error) { + tlsConn, err := client.DialALPN(ctx, addr, cfg) + if err != nil { + return nil, trace.Wrap(err) + } + if common.IsPingProtocol(common.Protocol(tlsConn.ConnectionState().NegotiatedProtocol)) { + return pingconn.NewTLS(tlsConn), nil + } + return tlsConn, nil +} + func (l *LocalProxy) Close() error { l.cancel() if l.cfg.Listener != nil { @@ -501,6 +539,13 @@ func (l *LocalProxy) getCertForConn(downstreamConn net.Conn) (tls.Certificate, n return tls.Certificate{}, downstreamConn, nil } +func (l *LocalProxy) getCertWithoutConn() (tls.Certificate, error) { + if l.cfg.CheckCertNeeded { + return tls.Certificate{}, trace.BadParameter("getCertWithoutConn called while CheckCertNeeded is true: this is a bug") + } + return l.getCert(), nil +} + func (l *LocalProxy) isPostgresProxy() bool { for _, proto := range common.ProtocolsToString(l.cfg.Protocols) { if strings.HasPrefix(proto, string(common.ProtocolPostgres)) { diff --git a/lib/srv/alpnproxy/local_proxy_test.go b/lib/srv/alpnproxy/local_proxy_test.go index 03972890fe309..b5ba5b1582b92 100644 --- a/lib/srv/alpnproxy/local_proxy_test.go +++ b/lib/srv/alpnproxy/local_proxy_test.go @@ -198,7 +198,7 @@ func (m *mockMiddlewareCounter) onStateChange() { } } -func (m *mockMiddlewareCounter) OnNewConnection(_ context.Context, _ *LocalProxy, _ net.Conn) error { +func (m *mockMiddlewareCounter) OnNewConnection(_ context.Context, _ *LocalProxy) error { m.Lock() defer m.Unlock() m.connCount++ @@ -292,7 +292,7 @@ type mockCertRenewer struct { cert tls.Certificate } -func (m *mockCertRenewer) OnNewConnection(_ context.Context, lp *LocalProxy, _ net.Conn) error { +func (m *mockCertRenewer) OnNewConnection(_ context.Context, lp *LocalProxy) error { lp.SetCert(m.cert) return nil } @@ -428,7 +428,7 @@ func TestCheckDBCerts(t *testing.T) { type mockMiddlewareConnUnauth struct { } -func (m *mockMiddlewareConnUnauth) OnNewConnection(_ context.Context, _ *LocalProxy, _ net.Conn) error { +func (m *mockMiddlewareConnUnauth) OnNewConnection(_ context.Context, _ *LocalProxy) error { return trace.AccessDenied("access denied.") } diff --git a/lib/tbot/service_database_tunnel.go b/lib/tbot/service_database_tunnel.go index f69b2f2cf9ef3..6cd7f18ff362a 100644 --- a/lib/tbot/service_database_tunnel.go +++ b/lib/tbot/service_database_tunnel.go @@ -23,7 +23,6 @@ import ( 
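// A minimal, self-contained sketch of the "late connector" idea behind
// HandleTCPConnector above: all fallible setup (upstream dial, cert work) runs
// first, and the downstream conn is only materialized via connector() once
// everything else has succeeded, so a failure surfaces to the caller as a
// failed dial rather than a connection that is accepted and immediately closed.
// handleConnector and proxyConn are illustrative stand-ins, not Teleport APIs.
package sketch

import (
	"context"
	"io"
	"net"
)

func proxyConn(ctx context.Context, a, b net.Conn) error {
	errCh := make(chan error, 2)
	go func() { _, err := io.Copy(a, b); errCh <- err }()
	go func() { _, err := io.Copy(b, a); errCh <- err }()
	select {
	case err := <-errCh:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}

func handleConnector(ctx context.Context, upstreamAddr string, connector func() (net.Conn, error)) error {
	// Do all fallible setup before asking for the downstream conn.
	upstream, err := (&net.Dialer{}).DialContext(ctx, "tcp", upstreamAddr)
	if err != nil {
		return err // the caller never sees a half-open downstream connection
	}
	defer upstream.Close()

	downstream, err := connector()
	if err != nil {
		return err
	}
	defer downstream.Close()

	return proxyConn(ctx, downstream, upstream)
}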
"crypto/tls" "fmt" "log/slog" - "net" "github.com/gravitational/trace" @@ -42,13 +41,13 @@ import ( var _ alpnproxy.LocalProxyMiddleware = (*alpnProxyMiddleware)(nil) type alpnProxyMiddleware struct { - onNewConnection func(ctx context.Context, lp *alpnproxy.LocalProxy, conn net.Conn) error + onNewConnection func(ctx context.Context, lp *alpnproxy.LocalProxy) error onStart func(ctx context.Context, lp *alpnproxy.LocalProxy) error } -func (a alpnProxyMiddleware) OnNewConnection(ctx context.Context, lp *alpnproxy.LocalProxy, conn net.Conn) error { +func (a alpnProxyMiddleware) OnNewConnection(ctx context.Context, lp *alpnproxy.LocalProxy) error { if a.onNewConnection != nil { - return a.onNewConnection(ctx, lp, conn) + return a.onNewConnection(ctx, lp) } return nil } @@ -123,7 +122,7 @@ func (s *DatabaseTunnelService) buildLocalProxyConfig(ctx context.Context) (lpCf s.log.DebugContext(ctx, "Issued initial certificate for local proxy.") middleware := alpnProxyMiddleware{ - onNewConnection: func(ctx context.Context, lp *alpnproxy.LocalProxy, conn net.Conn) error { + onNewConnection: func(ctx context.Context, lp *alpnproxy.LocalProxy) error { ctx, span := tracer.Start(ctx, "DatabaseTunnelService/OnNewConnection") defer span.End() diff --git a/lib/teleterm/api/uri/uri.go b/lib/teleterm/api/uri/uri.go index da7099f3b4562..7886bf478b59e 100644 --- a/lib/teleterm/api/uri/uri.go +++ b/lib/teleterm/api/uri/uri.go @@ -164,12 +164,7 @@ func (r ResourceURI) GetRootClusterURI() ResourceURI { // If called on a leaf cluster resource URI, it'll return the URI of the leaf cluster. // If called on a root cluster URI or a leaf cluster URI, it's a noop. func (r ResourceURI) GetClusterURI() ResourceURI { - clusterURI := r.GetRootClusterURI() - - if leafClusterName := r.GetLeafClusterName(); leafClusterName != "" { - clusterURI = clusterURI.AppendLeafCluster(leafClusterName) - } - return clusterURI + return r.GetRootClusterURI().AppendLeafCluster(r.GetLeafClusterName()) } // AppendServer appends server segment to the URI @@ -178,8 +173,12 @@ func (r ResourceURI) AppendServer(id string) ResourceURI { return r } -// AppendLeafCluster appends leaf cluster segment to the URI +// AppendLeafCluster appends leaf cluster segment to the URI if name is not empty. 
func (r ResourceURI) AppendLeafCluster(name string) ResourceURI { + if name == "" { + return r + } + r.path = fmt.Sprintf("%v/leaves/%v", r.path, name) return r } diff --git a/lib/teleterm/api/uri/uri_test.go b/lib/teleterm/api/uri/uri_test.go index 0d5568ca8aac4..e7850e9b7ee36 100644 --- a/lib/teleterm/api/uri/uri_test.go +++ b/lib/teleterm/api/uri/uri_test.go @@ -406,3 +406,29 @@ func TestIsLeaf(t *testing.T) { }) } } + +func TestAppendLeafCluster(t *testing.T) { + tests := []struct { + profileName string + leafName string + out string + }{ + { + profileName: "foo", + leafName: "bar", + out: "/clusters/foo/leaves/bar", + }, + { + profileName: "foo", + leafName: "", + out: "/clusters/foo", + }, + } + + for _, tt := range tests { + t.Run(tt.out, func(t *testing.T) { + actualOut := uri.NewClusterURI(tt.profileName).AppendLeafCluster(tt.leafName).String() + require.Equal(t, tt.out, actualOut) + }) + } +} diff --git a/lib/teleterm/daemon/config.go b/lib/teleterm/daemon/config.go index b047134c274c0..8e0c9f52ef95b 100644 --- a/lib/teleterm/daemon/config.go +++ b/lib/teleterm/daemon/config.go @@ -28,9 +28,9 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/client/clientcache" "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/teleterm/clusters" - "github.com/gravitational/teleport/lib/teleterm/services/clientcache" "github.com/gravitational/teleport/lib/teleterm/services/connectmycomputer" ) @@ -72,7 +72,7 @@ type Config struct { ConnectMyComputerNodeDelete *connectmycomputer.NodeDelete ConnectMyComputerNodeName *connectmycomputer.NodeName - CreateClientCacheFunc func(resolver ResolveClusterFunc) ClientCache + CreateClientCacheFunc func(resolver clientcache.NewClientFunc) (ClientCache, error) } // ResolveClusterFunc returns a cluster by URI. @@ -83,10 +83,10 @@ type ClientCache interface { // Get returns a client from the cache if there is one, // otherwise it dials the remote server. // The caller should not close the returned client. - Get(ctx context.Context, clusterURI uri.ResourceURI) (*client.ClusterClient, error) + Get(ctx context.Context, profileName, leafClusterName string) (*client.ClusterClient, error) // ClearForRoot closes and removes clients from the cache // for the root cluster and its leaf clusters. - ClearForRoot(clusterURI uri.ResourceURI) error + ClearForRoot(profileName string) error // Clear closes and removes all clients. 
Clear() error } @@ -161,10 +161,14 @@ func (c *Config) CheckAndSetDefaults() error { } if c.CreateClientCacheFunc == nil { - c.CreateClientCacheFunc = func(resolver ResolveClusterFunc) ClientCache { + c.CreateClientCacheFunc = func(newClientFunc clientcache.NewClientFunc) (ClientCache, error) { + retryWithRelogin := func(ctx context.Context, tc *client.TeleportClient, fn func() error, opts ...client.RetryWithReloginOption) error { + return clusters.AddMetadataToRetryableError(ctx, fn) + } return clientcache.New(clientcache.Config{ - Log: c.Log, - ResolveClusterFunc: clientcache.ResolveClusterFunc(resolver), + Log: c.Log, + NewClientFunc: newClientFunc, + RetryWithReloginFunc: clientcache.RetryWithReloginFunc(retryWithRelogin), }) } } diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index 1e176fd0069da..5972859769f3e 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -95,7 +95,11 @@ func New(cfg Config) (*Service, error) { // That's because Daemon.ResolveClusterURI sets a custom MFAPromptConstructor that // shows an MFA prompt in Connect. // At the level of Storage.ResolveClusterFunc we don't have access to it. - service.clientCache = cfg.CreateClientCacheFunc(service.ResolveClusterURI) + service.clientCache, err = cfg.CreateClientCacheFunc(service.NewClusterClient) + if err != nil { + return nil, trace.Wrap(err) + } + return service, nil } @@ -221,6 +225,14 @@ func (s *Service) RemoveCluster(ctx context.Context, uri string) error { return nil } +// NewClusterClient is a wrapper on ResolveClusterURI that can be passed as an argument to +// s.cfg.CreateClientCacheFunc. +func (s *Service) NewClusterClient(ctx context.Context, profileName, leafClusterName string) (*client.TeleportClient, error) { + uri := uri.NewClusterURI(profileName).AppendLeafCluster(leafClusterName) + _, clusterClient, err := s.ResolveClusterURI(uri) + return clusterClient, trace.Wrap(err) +} + // ResolveCluster resolves a cluster by URI by reading data stored on disk in the profile. // // It doesn't make network requests so the returned clusters.Cluster will not include full @@ -1119,14 +1131,17 @@ func (s *Service) findGatewayByTargetURI(targetURI uri.ResourceURI) (gateway.Gat // GetCachedClient returns a client from the cache if it exists, // otherwise it dials the remote server. func (s *Service) GetCachedClient(ctx context.Context, clusterURI uri.ResourceURI) (*client.ClusterClient, error) { - clt, err := s.clientCache.Get(ctx, clusterURI) + profileName := clusterURI.GetProfileName() + leafClusterName := clusterURI.GetLeafClusterName() + clt, err := s.clientCache.Get(ctx, profileName, leafClusterName) return clt, trace.Wrap(err) } // ClearCachedClientsForRoot closes and removes clients from the cache // for the root cluster and its leaf clusters. 
func (s *Service) ClearCachedClientsForRoot(clusterURI uri.ResourceURI) error { - return trace.Wrap(s.clientCache.ClearForRoot(clusterURI)) + profileName := clusterURI.GetProfileName() + return trace.Wrap(s.clientCache.ClearForRoot(profileName)) } // Service is the daemon service diff --git a/lib/teleterm/daemon/daemon_test.go b/lib/teleterm/daemon/daemon_test.go index 06618adb5ed7b..a887391cc6d87 100644 --- a/lib/teleterm/daemon/daemon_test.go +++ b/lib/teleterm/daemon/daemon_test.go @@ -40,6 +40,7 @@ import ( "github.com/gravitational/teleport/api/types" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/client/clientcache" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/teleterm/clusters" @@ -272,8 +273,8 @@ func TestGatewayCRUD(t *testing.T) { GatewayCreator: mockGatewayCreator, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), - CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache { - return fakeClientCache{} + CreateClientCacheFunc: func(newClientFunc clientcache.NewClientFunc) (ClientCache, error) { + return fakeClientCache{}, nil }, }) require.NoError(t, err) @@ -453,8 +454,8 @@ func TestRetryWithRelogin(t *testing.T) { }, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), - CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache { - return fakeClientCache{} + CreateClientCacheFunc: func(newClientFunc clientcache.NewClientFunc) (ClientCache, error) { + return fakeClientCache{}, nil }, }) require.NoError(t, err) @@ -506,8 +507,8 @@ func TestImportantModalSemaphore(t *testing.T) { }, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), - CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache { - return fakeClientCache{} + CreateClientCacheFunc: func(newClientFunc clientcache.NewClientFunc) (ClientCache, error) { + return fakeClientCache{}, nil }, }) require.NoError(t, err) @@ -657,8 +658,8 @@ func TestGetGatewayCLICommand(t *testing.T) { }, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), - CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache { - return fakeClientCache{} + CreateClientCacheFunc: func(newClientFunc clientcache.NewClientFunc) (ClientCache, error) { + return fakeClientCache{}, nil }, }) require.NoError(t, err) @@ -744,6 +745,6 @@ type fakeClientCache struct { ClientCache } -func (f fakeClientCache) Get(ctx context.Context, clusterURI uri.ResourceURI) (*client.ClusterClient, error) { +func (f fakeClientCache) Get(ctx context.Context, profileName, leafClusterName string) (*client.ClusterClient, error) { return &client.ClusterClient{}, nil } diff --git a/lib/teleterm/gateway/app_middleware.go b/lib/teleterm/gateway/app_middleware.go index 8116f65137503..2da5946018147 100644 --- a/lib/teleterm/gateway/app_middleware.go +++ b/lib/teleterm/gateway/app_middleware.go @@ -21,7 +21,6 @@ import ( "crypto/tls" "crypto/x509" "errors" - "net" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -38,7 +37,7 @@ type appMiddleware struct { // it on the local proxy. // Other middlewares typically also handle MFA here. App access doesn't support per-session MFA yet, // so detecting expired certs is all this middleware can do. 
-func (m *appMiddleware) OnNewConnection(ctx context.Context, lp *alpn.LocalProxy, conn net.Conn) error { +func (m *appMiddleware) OnNewConnection(ctx context.Context, lp *alpn.LocalProxy) error { err := lp.CheckCertExpiry() if err == nil { return nil diff --git a/lib/teleterm/gateway/db_middleware.go b/lib/teleterm/gateway/db_middleware.go index 8b8fcd2b33da6..8469334ba0175 100644 --- a/lib/teleterm/gateway/db_middleware.go +++ b/lib/teleterm/gateway/db_middleware.go @@ -22,7 +22,6 @@ import ( "context" "crypto/x509" "errors" - "net" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -43,7 +42,7 @@ type dbMiddleware struct { // // In the future, DBCertChecker is going to be extended so that it's used by both tsh and Connect // and this middleware will be removed. -func (m *dbMiddleware) OnNewConnection(ctx context.Context, lp *alpn.LocalProxy, conn net.Conn) error { +func (m *dbMiddleware) OnNewConnection(ctx context.Context, lp *alpn.LocalProxy) error { err := lp.CheckDBCert(m.dbRoute) if err == nil { return nil diff --git a/lib/teleterm/gateway/db_middleware_test.go b/lib/teleterm/gateway/db_middleware_test.go index a6b5b8776dc33..653f54ea2b58b 100644 --- a/lib/teleterm/gateway/db_middleware_test.go +++ b/lib/teleterm/gateway/db_middleware_test.go @@ -125,7 +125,7 @@ func TestDBMiddleware_OnNewConnection(t *testing.T) { localProxy.SetCert(tlsCert) - err = middleware.OnNewConnection(ctx, localProxy, nil /* net.Conn, not used by middleware */) + err = middleware.OnNewConnection(ctx, localProxy) tt.expectation(t, err, hasCalledOnExpiredCert) }) } diff --git a/lib/vnet/app_resolver.go b/lib/vnet/app_resolver.go index 0981b8eedc2c9..bb098f001c519 100644 --- a/lib/vnet/app_resolver.go +++ b/lib/vnet/app_resolver.go @@ -18,19 +18,24 @@ package vnet import ( "context" + "crypto/tls" + "crypto/x509" "fmt" "log/slog" "net" "strings" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "golang.org/x/sync/singleflight" "github.com/gravitational/teleport" apiclient "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/srv/alpnproxy" + alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" ) // AppProvider is an interface providing the necessary methods to log in to apps and get clients able to list @@ -44,12 +49,35 @@ type AppProvider interface { // [leafClusterName] may be empty when requesting a client for the root cluster. Returned clients are // expected to be cached, as this may be called frequently. GetCachedClient(ctx context.Context, profileName, leafClusterName string) (*client.ClusterClient, error) + + // ReissueAppCert returns a new app certificate for the given app in the named profile and leaf cluster. + // Implementations may trigger a re-login to the cluster, but if they do, they MUST clear all cached + // clients for that cluster so that new working clients will be returned from [GetCachedClient]. + ReissueAppCert(ctx context.Context, profileName, leafClusterName string, app types.Application) (tls.Certificate, error) + + // GetDialOptions returns ALPN dial options for the profile. + GetDialOptions(ctx context.Context, profileName string) (*DialOptions, error) +} + +// DialOptions holds ALPN dial options for dialing apps. +type DialOptions struct { + // WebProxyAddr is the address to dial. 
+ WebProxyAddr string + // ALPNConnUpgradeRequired specifies if ALPN connection upgrade is required. + ALPNConnUpgradeRequired bool + // SNI is a ServerName value set for upstream TLS connection. + SNI string + // RootClusterCACertPool overrides the x509 certificate pool used to verify the server. + RootClusterCACertPool *x509.CertPool + // InsecureSkipTLSVerify turns off verification for x509 upstream ALPN proxy service certificate. + InsecureSkipVerify bool } // TCPAppResolver implements [TCPHandlerResolver] for Teleport TCP apps. type TCPAppResolver struct { appProvider AppProvider slog *slog.Logger + clock clockwork.Clock } // NewTCPAppResolver returns a new *TCPAppResolver which will resolve full-qualified domain names to @@ -60,10 +88,23 @@ type TCPAppResolver struct { // handled. // // [appProvider] is also used to get app certificates used to dial the apps. -func NewTCPAppResolver(appProvider AppProvider) *TCPAppResolver { - return &TCPAppResolver{ +func NewTCPAppResolver(appProvider AppProvider, opts ...tcpAppResolverOption) *TCPAppResolver { + r := &TCPAppResolver{ appProvider: appProvider, slog: slog.With(teleport.ComponentKey, "VNet.AppResolver"), + clock: clockwork.NewRealClock(), + } + for _, opt := range opts { + opt(r) + } + return r +} + +type tcpAppResolverOption func(*TCPAppResolver) + +func withClock(clock clockwork.Clock) tcpAppResolverOption { + return func(r *TCPAppResolver) { + r.clock = clock } } @@ -118,7 +159,8 @@ func (r *TCPAppResolver) resolveTCPHandlerForCluster( Limit: 1, }) if err != nil { - // Don't return an error so we can try to find the app in different clusters or forward the request upstream. + // Don't return an error so we can try to find the app in different clusters or forward the request + // upstream. slog.InfoContext(ctx, "Failed to list application servers.", "error", err) return nil, false, nil } @@ -127,7 +169,7 @@ func (r *TCPAppResolver) resolveTCPHandlerForCluster( return nil, false, nil } app := resp.Resources[0].GetApp() - appHandler, err := newTCPAppHandler(ctx, r.appProvider, profileName, leafClusterName, app) + appHandler, err := r.newTCPAppHandler(ctx, profileName, leafClusterName, app) if err != nil { return nil, false, trace.Wrap(err) } @@ -138,29 +180,78 @@ type tcpAppHandler struct { profileName string leafClusterName string app types.Application + lp *alpnproxy.LocalProxy } -func newTCPAppHandler( +func (r *TCPAppResolver) newTCPAppHandler( ctx context.Context, - appProvider AppProvider, profileName string, leafClusterName string, app types.Application, ) (*tcpAppHandler, error) { + dialOpts, err := r.appProvider.GetDialOptions(ctx, profileName) + if err != nil { + return nil, trace.Wrap(err, "getting dial options for profile %q", profileName) + } + + appCertIssuer := &appCertIssuer{ + appProvider: r.appProvider, + profileName: profileName, + leafClusterName: leafClusterName, + app: app, + } + middleware := client.NewCertChecker(appCertIssuer, r.clock) + + localProxyConfig := alpnproxy.LocalProxyConfig{ + RemoteProxyAddr: dialOpts.WebProxyAddr, + Protocols: []alpncommon.Protocol{alpncommon.ProtocolTCP}, + ParentContext: ctx, + SNI: dialOpts.SNI, + RootCAs: dialOpts.RootClusterCACertPool, + ALPNConnUpgradeRequired: dialOpts.ALPNConnUpgradeRequired, + Middleware: middleware, + InsecureSkipVerify: dialOpts.InsecureSkipVerify, + Clock: r.clock, + } + + lp, err := alpnproxy.NewLocalProxy(localProxyConfig) + if err != nil { + return nil, trace.Wrap(err, "creating local proxy") + } + return &tcpAppHandler{ profileName: profileName, 
leafClusterName: leafClusterName, app: app, + lp: lp, }, nil } +// HandleTCPConnector handles an incoming TCP connection from VNet by passing it to the local alpn proxy, +// which is set up with middleware to automatically handler certificate renewal and re-logins. func (h *tcpAppHandler) HandleTCPConnector(ctx context.Context, connector func() (net.Conn, error)) error { - conn, err := connector() - if err != nil { - return trace.Wrap(err) - } - // HandleTCPConnector not implemented yet - just echo input back to output. - return trace.Wrap(utils.ProxyConn(ctx, conn, conn)) + return trace.Wrap(h.lp.HandleTCPConnector(ctx, connector), "handling TCP connector") +} + +// appCertIssuer implements [client.CertIssuer]. +type appCertIssuer struct { + appProvider AppProvider + profileName string + leafClusterName string + app types.Application + group singleflight.Group +} + +func (i *appCertIssuer) CheckCert(cert *x509.Certificate) error { + // appCertIssuer does not perform any additional certificate checks. + return nil +} + +func (i *appCertIssuer) IssueCert(ctx context.Context) (tls.Certificate, error) { + cert, err, _ := i.group.Do("", func() (any, error) { + return i.appProvider.ReissueAppCert(ctx, i.profileName, i.leafClusterName, i.app) + }) + return cert.(tls.Certificate), trace.Wrap(err) } func isSubdomain(appFQDN, proxyAddress string) bool { diff --git a/lib/vnet/vnet.go b/lib/vnet/vnet.go index eb3d0e148415a..0a5104f9b00b0 100644 --- a/lib/vnet/vnet.go +++ b/lib/vnet/vnet.go @@ -380,7 +380,6 @@ func (m *Manager) handleTCP(req *tcp.ForwarderRequest) { var wq waiter.Queue waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventErr | waiter.EventHUp) wq.EventRegister(&waitEntry) - defer wq.EventUnregister(&waitEntry) endpoint, err := req.CreateEndpoint(&wq) if err != nil { @@ -393,22 +392,23 @@ func (m *Manager) handleTCP(req *tcp.ForwarderRequest) { endpoint.SocketOptions().SetKeepAlive(true) - conn, connClosed := newConnWithCloseNotifier(gonet.NewTCPConn(&wq, endpoint)) + conn := gonet.NewTCPConn(&wq, endpoint) m.wg.Add(1) go func() { - defer m.wg.Done() + defer func() { + cancel() + conn.Close() + m.wg.Done() + }() select { - case <-connClosed: - // Conn is already being closed, nothing to do. 
- return case <-notifyCh: - slog.DebugContext(ctx, "Got HUP or ERR, closing TCP conn.") + slog.DebugContext(ctx, "Got HUP or ERR, canceling request context and closing TCP conn.") case <-m.destroyed: - slog.DebugContext(ctx, "VNet is being destroyed, closing TCP conn.") + slog.DebugContext(ctx, "VNet is being destroyed, canceling request context and closing TCP conn.") + case <-ctx.Done(): + slog.DebugContext(ctx, "Request context canceled, closing TCP conn.") } - cancel() - conn.Close() }() return conn, nil @@ -474,6 +474,9 @@ func (m *Manager) handleUDPConcurrent(req *udp.ForwarderRequest) { } var wq waiter.Queue + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventErr | waiter.EventHUp) + wq.EventRegister(&waitEntry) + endpoint, err := req.CreateEndpoint(&wq) if err != nil { slog.ErrorContext(ctx, "Failed to create UDP endpoint.", "error", err) @@ -483,6 +486,23 @@ func (m *Manager) handleUDPConcurrent(req *udp.ForwarderRequest) { conn := gonet.NewUDPConn(m.stack, &wq, endpoint) defer conn.Close() + m.wg.Add(1) + go func() { + defer func() { + cancel() + conn.Close() + m.wg.Done() + }() + select { + case <-notifyCh: + slog.DebugContext(ctx, "Got HUP or ERR, canceling request context and closing UDP conn.") + case <-m.destroyed: + slog.DebugContext(ctx, "VNet is being destroyed, canceling request context and closing UDP conn.") + case <-ctx.Done(): + slog.DebugContext(ctx, "Request context canceled, closing UDP conn.") + } + }() + if err := handler.HandleUDP(ctx, conn); err != nil { slog.DebugContext(ctx, "Error handling UDP conn.", "error", err) } @@ -684,22 +704,3 @@ func u32ToBytes(i uint32) []byte { bytes[3] = byte(i >> 0) return bytes } - -// newConnWithCloseNotifier returns a net.Conn and a channel that will be closed when the conn is closed. -func newConnWithCloseNotifier(conn *gonet.TCPConn) (net.Conn, <-chan struct{}) { - ch := make(chan struct{}) - return &connWithCloseNotifier{ - TCPConn: conn, - closeOnce: sync.OnceFunc(func() { close(ch) }), - }, ch -} - -type connWithCloseNotifier struct { - *gonet.TCPConn - closeOnce func() -} - -func (c *connWithCloseNotifier) Close() error { - c.closeOnce() - return c.TCPConn.Close() -} diff --git a/lib/vnet/vnet_test.go b/lib/vnet/vnet_test.go index b7a87fab2f3da..e1feb60643d66 100644 --- a/lib/vnet/vnet_test.go +++ b/lib/vnet/vnet_test.go @@ -17,11 +17,18 @@ package vnet import ( + "bytes" "context" + "crypto/ed25519" "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "errors" + "fmt" "io" "log/slog" + "math/big" "net" "os" "strings" @@ -30,6 +37,7 @@ import ( "time" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" @@ -59,7 +67,7 @@ type testPack struct { localAddress tcpip.Address } -func newTestPack(t *testing.T, ctx context.Context, appProvider AppProvider) *testPack { +func newTestPack(t *testing.T, ctx context.Context, clock clockwork.FakeClock, appProvider AppProvider) *testPack { // Create two sides of an emulated TUN interface: writes to one can be read on the other, and vice versa. tun1, tun2 := newSplitTUN() @@ -109,14 +117,14 @@ func newTestPack(t *testing.T, ctx context.Context, appProvider AppProvider) *te dnsIPv6 := ipv6WithSuffix(vnetIPv6Prefix, []byte{2}) - tcpHandlerResolver := NewTCPAppResolver(appProvider) + tcpAppResolver := NewTCPAppResolver(appProvider, withClock(clock)) // Create the VNet and connect it to the other side of the TUN. 
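// A minimal, self-contained sketch of the connection-lifecycle pattern used in
// handleTCP/handleUDPConcurrent above: one goroutine per connection waits for
// whichever signal fires first (peer HUP/ERR, global shutdown, or request
// context cancellation), then cancels the request context and closes the conn,
// which is what lets the removed close-notifier wrapper go away. The channel
// names are illustrative; the real code uses gVisor's waiter package.
package main

import (
	"context"
	"net"
	"sync"
)

func superviseConn(ctx context.Context, cancel context.CancelFunc, conn net.Conn, notifyCh, destroyed <-chan struct{}, wg *sync.WaitGroup) {
	wg.Add(1)
	go func() {
		defer func() {
			cancel()
			conn.Close()
			wg.Done()
		}()
		select {
		case <-notifyCh: // peer reported HUP or ERR
		case <-destroyed: // the whole stack is shutting down
		case <-ctx.Done(): // the request handler finished or was canceled
		}
	}()
}

func main() {
	a, b := net.Pipe()
	defer b.Close()

	ctx, cancel := context.WithCancel(context.Background())
	destroyed := make(chan struct{})
	notifyCh := make(chan struct{})

	var wg sync.WaitGroup
	superviseConn(ctx, cancel, a, notifyCh, destroyed, &wg)

	cancel() // simulate the handler finishing; the goroutine closes the conn
	wg.Wait()
}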
manager, err := NewManager(&Config{ TUNDevice: tun2, IPv6Prefix: vnetIPv6Prefix, DNSIPv6: dnsIPv6, - TCPHandlerResolver: tcpHandlerResolver, + TCPHandlerResolver: tcpAppResolver, upstreamNameserverSource: noUpstreamNameservers{}, }) require.NoError(t, err) @@ -218,14 +226,18 @@ func (n noUpstreamNameservers) UpstreamNameservers(ctx context.Context) ([]strin } type echoAppProvider struct { - profiles []string - clients map[string]map[string]*client.ClusterClient + profiles []string + clients map[string]map[string]*client.ClusterClient + dialOpts DialOptions + reissueAppCert func() tls.Certificate } -// newEchoAppProvider returns a fake app provider with the list of named apps in each profile and leaf cluster. -func newEchoAppProvider(apps map[string]map[string][]string) *echoAppProvider { +// newEchoAppProvider returns an app provider with the list of named apps in each profile and leaf cluster. +func newEchoAppProvider(apps map[string]map[string][]string, dialOpts DialOptions, reissueAppCert func() tls.Certificate) *echoAppProvider { p := &echoAppProvider{ - clients: make(map[string]map[string]*client.ClusterClient, len(apps)), + clients: make(map[string]map[string]*client.ClusterClient, len(apps)), + dialOpts: dialOpts, + reissueAppCert: reissueAppCert, } for profileName, leafClusters := range apps { p.profiles = append(p.profiles, profileName) @@ -262,6 +274,14 @@ func (p *echoAppProvider) GetCachedClient(ctx context.Context, profileName, leaf return c, nil } +func (p *echoAppProvider) ReissueAppCert(ctx context.Context, profileName, leafClusterName string, app types.Application) (tls.Certificate, error) { + return p.reissueAppCert(), nil +} + +func (p *echoAppProvider) GetDialOptions(ctx context.Context, profileName string) (*DialOptions, error) { + return &p.dialOpts, nil +} + // echoAppAuthClient is a fake auth client that answers GetResources requests with a static list of apps and // basic/faked predicate filtering. type echoAppAuthClient struct { @@ -291,7 +311,7 @@ func (c *echoAppAuthClient) GetResources(ctx context.Context, req *proto.ListRes Name: app, }, Spec: types.AppSpecV3{ - PublicAddr: appPublicAddr, + PublicAddr: app, }, }, }, @@ -308,6 +328,94 @@ func TestDialFakeApp(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) + clock := clockwork.NewFakeClockAt(time.Now()) + + ca := newSelfSignedCA(t) + + roots := x509.NewCertPool() + caX509, err := x509.ParseCertificate(ca.Certificate[0]) + require.NoError(t, err) + roots.AddCert(caX509) + + const proxyCN = "testproxy" + proxyCert := newServerCert(t, ca, proxyCN, clock.Now().Add(365*24*time.Hour)) + + proxyTLSConfig := &tls.Config{ + Certificates: []tls.Certificate{proxyCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: roots, + } + + listener, err := tls.Listen("tcp", "localhost:0", proxyTLSConfig) + require.NoError(t, err) + + // Run a fake web proxy that will accept any client connection and echo the input back. + utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{ + Name: "web proxy", + Task: func(ctx context.Context) error { + for { + conn, err := listener.Accept() + if err != nil { + if utils.IsOKNetworkError(err) { + return nil + } + return trace.Wrap(err) + } + go func() { + defer conn.Close() + + // Not using require/assert here and below because this is not in the right subtest or in + // the main test goroutine. The test will fail if the conn is not handled. 
+ tlsConn, ok := conn.(*tls.Conn) + if !ok { + t.Log("client conn is not TLS") + return + } + if err := tlsConn.Handshake(); err != nil { + t.Log("error completing tls handshake") + return + } + clientCerts := tlsConn.ConnectionState().PeerCertificates + if len(clientCerts) == 0 { + t.Log("client has no certs") + return + } + // Manually checking the cert expiry compared to the time of the fake clock, since the TLS + // library will only compare the cert expiry to the real clock. + // It's important that the fake clock is never far behind the real clock, and that the + // cert NotBefore is always at/before the real current time, so the TLS library is + // satisfied. + if clock.Now().After(clientCerts[0].NotAfter) { + t.Logf("client cert is expired: currentTime=%s expiry=%s", clock.Now(), clientCerts[0].NotAfter) + return + } + + _, err := io.Copy(conn, conn) + if err != nil && !utils.IsOKNetworkError(err) { + t.Logf("error in io.Copy for echo proxy server: %v", err) + } + }() + } + }, + Terminate: func() error { + if err := listener.Close(); !utils.IsOKNetworkError(err) { + return trace.Wrap(err) + } + return nil + }, + }) + + dialOpts := DialOptions{ + WebProxyAddr: listener.Addr().String(), + RootClusterCACertPool: roots, + SNI: proxyCN, + } + + const appCertLifetime = time.Hour + reissueClientCert := func() tls.Certificate { + return newClientCert(t, ca, "testclient", clock.Now().Add(appCertLifetime)) + } + appProvider := newEchoAppProvider(map[string]map[string][]string{ "root1.example.com": map[string][]string{ "": {"echo1", "echo2"}, @@ -317,7 +425,7 @@ func TestDialFakeApp(t *testing.T) { "": {"echo1", "echo2"}, "leaf2.example.com": {"echo1"}, }, - }) + }, dialOpts, reissueClientCert) validAppNames := []string{ "echo1.root1.example.com", @@ -333,51 +441,74 @@ func TestDialFakeApp(t *testing.T) { "echo1.leaf2.example.com.", } - p := newTestPack(t, ctx, appProvider) + p := newTestPack(t, ctx, clock, appProvider) t.Run("valid", func(t *testing.T) { t.Parallel() - for _, app := range validAppNames { - app := app - t.Run(app, func(t *testing.T) { - t.Parallel() - conn, err := p.dialHost(ctx, app) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, conn.Close()) }) - - testString := "Hello, World!" - writeBuf := []byte(testString) - _, err = conn.Write(writeBuf) - require.NoError(t, err) - - readBuf := make([]byte, len(writeBuf)) - _, err = io.ReadFull(conn, readBuf) - require.NoError(t, err) - require.Equal(t, string(writeBuf), string(readBuf)) + // Connect to each app 3 times, advancing the clock past the cert lifetime between each + // connection to trigger a cert refresh. + // + // It's important not to run these subtests which advance a shared clock in parallel. It's okay for + // the inner app dial/connection tests to run in parallel because they don't advance the clock. + for i := 0; i < 3; i++ { + t.Run(fmt.Sprint(i), func(t *testing.T) { + for _, app := range validAppNames { + app := app + t.Run(app, func(t *testing.T) { + t.Parallel() + + conn, err := p.dialHost(ctx, app) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, conn.Close()) }) + + testEchoConnection(t, conn) + }) + } }) + clock.Advance(2 * appCertLifetime) } }) - // Tests with invalid hostnames just check that we don't return an answer and nothing panics. t.Run("invalid", func(t *testing.T) { + // It's safe to run these invalid app tests in parallel because they fail the DNS lookup and don't + // even make it to a TCP dial, so the clock used for TLS cert expiry doesn't matter. 
t.Parallel() for _, app := range invalidAppNames { app := app t.Run("invalid/"+app, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) + ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond) defer cancel() _, err := p.lookupHost(ctx, app) - var dnsError *net.DNSError - require.ErrorAs(t, err, &dnsError) - require.True(t, dnsError.IsTimeout, "expected DNS timeout error, got %+v", dnsError) + require.Error(t, err, "asdf") }) } }) } +func testEchoConnection(t *testing.T, conn net.Conn) { + const testString = "1........." + writeBuf := bytes.Repeat([]byte(testString), 200) + readBuf := make([]byte, len(writeBuf)) + + for i := 0; i < 10; i++ { + written, err := conn.Write(writeBuf) + for written < len(writeBuf) && err == nil { + var n int + n, err = conn.Write(writeBuf[written:]) + written += n + } + require.NoError(t, err) + require.Equal(t, len(writeBuf), written) + + n, err := io.ReadFull(conn, readBuf) + require.NoError(t, err) + require.Equal(t, string(writeBuf), string(readBuf[:n])) + } +} + func randomULAAddress() (tcpip.Address, error) { var bytes [16]byte bytes[0] = 0xfd @@ -463,3 +594,63 @@ func (f *fakeTUN) Close() error { f.closeOnce() return nil } + +func newSelfSignedCA(t *testing.T) tls.Certificate { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "test-ca", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLenZero: true, + } + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv) + require.NoError(t, err) + + return tls.Certificate{ + Certificate: [][]byte{certBytes}, + PrivateKey: priv, + } +} + +func newServerCert(t *testing.T, ca tls.Certificate, cn string, expires time.Time) tls.Certificate { + return newLeafCert(t, ca, cn, expires, x509.ExtKeyUsageServerAuth) +} + +func newClientCert(t *testing.T, ca tls.Certificate, cn string, expires time.Time) tls.Certificate { + return newLeafCert(t, ca, cn, expires, x509.ExtKeyUsageClientAuth) +} + +func newLeafCert(t *testing.T, ca tls.Certificate, cn string, expires time.Time, keyUsage x509.ExtKeyUsage) tls.Certificate { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + caCert, err := x509.ParseCertificate(ca.Certificate[0]) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + CommonName: cn, + }, + NotBefore: time.Now(), + NotAfter: expires, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{keyUsage}, + DNSNames: []string{cn}, + } + certBytes, err := x509.CreateCertificate(rand.Reader, &template, caCert, pub, ca.PrivateKey) + require.NoError(t, err) + + return tls.Certificate{ + Certificate: [][]byte{certBytes}, + PrivateKey: priv, + } +} diff --git a/tool/tsh/common/app.go b/tool/tsh/common/app.go index 7396dc7d6491e..385ad05d0c24f 100644 --- a/tool/tsh/common/app.go +++ b/tool/tsh/common/app.go @@ -19,6 +19,7 @@ package common import ( + "context" "crypto/tls" "crypto/x509/pkix" "fmt" @@ -74,7 +75,6 @@ func onAppLogin(cf *CLIConf) error { AccessRequests: profile.ActiveRequests.AccessRequests, } - // TODO (Joerger): DELETE IN v17.0.0 clusterClient, err := tc.ConnectToCluster(cf.Context) if err != nil { return trace.Wrap(err) @@ -83,12 +83,8 @@ func 
onAppLogin(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - appCertParams.RouteToApp.SessionID, err = authclient.TryCreateAppSessionForClientCertV15(cf.Context, rootClient, tc.Username, appCertParams.RouteToApp) - if err != nil { - return trace.Wrap(err) - } - key, _, err := clusterClient.IssueUserCertsWithMFA(cf.Context, appCertParams, tc.NewMFAPrompt(mfa.WithPromptReasonSessionMFA("Application", app.GetName()))) + key, err := appLogin(cf.Context, tc, clusterClient, rootClient, appCertParams) if err != nil { return trace.Wrap(err) } @@ -104,6 +100,25 @@ func onAppLogin(cf *CLIConf) error { return nil } +func appLogin( + ctx context.Context, + tc *client.TeleportClient, + clusterClient *client.ClusterClient, + rootClient authclient.ClientI, + appCertParams client.ReissueParams, +) (*client.Key, error) { + // TODO (Joerger): DELETE IN v17.0.0 + var err error + appCertParams.RouteToApp.SessionID, err = authclient.TryCreateAppSessionForClientCertV15(ctx, rootClient, tc.Username, appCertParams.RouteToApp) + if err != nil { + return nil, trace.Wrap(err) + } + + key, _, err := clusterClient.IssueUserCertsWithMFA(ctx, appCertParams, + tc.NewMFAPrompt(mfa.WithPromptReasonSessionMFA("Application", appCertParams.RouteToApp.Name))) + return key, trace.Wrap(err) +} + func getRouteToApp(cf *CLIConf, tc *client.TeleportClient, profile *client.ProfileStatus, app types.Application) (proto.RouteToApp, error) { var awsRoleARN string if app.IsAWSConsole() { diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index 8946a15434d55..6509a54e9aebc 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -606,7 +606,7 @@ func loadAppCertificateWithAppLogin(cf *CLIConf, tc *libclient.TeleportClient, a if !needLogin { return tls.Certificate{}, trace.Wrap(err) } - log.WithError(err).Debugf("Loading app certificate failed, attempting to login into app %q", appName) + log.WithError(err).Debugf("Loading app certificate failed, attempting to login to app %q", appName) quiet := cf.Quiet cf.Quiet = true errLogin := onAppLogin(cf) diff --git a/tool/tsh/common/vnet_common.go b/tool/tsh/common/vnet_common.go index a42c5d024518c..dd61b75bab4e6 100644 --- a/tool/tsh/common/vnet_common.go +++ b/tool/tsh/common/vnet_common.go @@ -18,24 +18,50 @@ package common import ( "context" + "crypto/tls" + "crypto/x509" + "fmt" + "log/slog" + "sync" "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/client/clientcache" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/vnet" ) // vnetAppProvider implement [vnet.AppProvider] in order to provide the necessary methods to log in to apps // and get clients able to list apps in all clusters in all current profiles. 
 type vnetAppProvider struct {
+	cf          *CLIConf
 	clientStore *client.Store
+	clientCache *clientcache.Cache
+	loginMu     sync.Mutex
 }
 
 func newVnetAppProvider(cf *CLIConf) (*vnetAppProvider, error) {
 	clientStore := client.NewFSClientStore(cf.HomePath)
-	return &vnetAppProvider{
+	p := &vnetAppProvider{
+		cf:          cf,
 		clientStore: clientStore,
-	}, nil
+	}
+
+	clientCache, err := clientcache.New(clientcache.Config{
+		NewClientFunc:        clientcache.NewClientFunc(p.newTeleportClient),
+		RetryWithReloginFunc: clientcache.RetryWithReloginFunc(p.retryWithRelogin),
+	})
+	if err != nil {
+		return nil, trace.Wrap(err, "creating client cache")
+	}
+
+	p.clientCache = clientCache
+	return p, nil
 }
 
 // ListProfiles lists the names of all profiles saved for the user.
@@ -43,10 +69,144 @@ func (p *vnetAppProvider) ListProfiles() ([]string, error) {
 	return p.clientStore.ListProfiles()
 }
 
-// GetCachedClient returns a [*client.ClusterClient] for the given profile and leaf cluster.
+// GetCachedClient returns a cached [*client.ClusterClient] for the given profile and leaf cluster.
 // [leafClusterName] may be empty when requesting a client for the root cluster.
-// TODO: cache clients across calls.
 func (p *vnetAppProvider) GetCachedClient(ctx context.Context, profileName, leafClusterName string) (*client.ClusterClient, error) {
+	return p.clientCache.Get(ctx, profileName, leafClusterName)
+}
+
+// ReissueAppCert returns a new app certificate for the given app in the named profile and leaf cluster.
+// It uses retryWithRelogin to issue the new app cert. A relogin may not be necessary if the app cert lifetime
+// was shorter than the cluster cert lifetime, or if the user has already re-logged in to the cluster.
+// If a cluster relogin is completed, the cluster client cache will be cleared for the root cluster and all
+// leaf clusters of that root.
+func (p *vnetAppProvider) ReissueAppCert(ctx context.Context, profileName, leafClusterName string, app types.Application) (tls.Certificate, error) {
+	tc, err := p.newTeleportClient(ctx, profileName, leafClusterName)
+	if err != nil {
+		return tls.Certificate{}, trace.Wrap(err)
+	}
+
+	var cert tls.Certificate
+	err = p.retryWithRelogin(ctx, tc, func() error {
+		var err error
+		cert, err = p.reissueAppCert(ctx, tc, profileName, leafClusterName, app)
+		return trace.Wrap(err, "reissuing app cert")
+	})
+	return cert, trace.Wrap(err)
+}
+
+// GetDialOptions returns ALPN dial options for the profile.
+func (p *vnetAppProvider) GetDialOptions(ctx context.Context, profileName string) (*vnet.DialOptions, error) {
+	profile, err := p.clientStore.GetProfile(profileName)
+	if err != nil {
+		return nil, trace.Wrap(err, "loading user profile")
+	}
+	dialOpts := &vnet.DialOptions{
+		WebProxyAddr:            profile.WebProxyAddr,
+		ALPNConnUpgradeRequired: profile.TLSRoutingConnUpgradeRequired,
+	}
+	if dialOpts.ALPNConnUpgradeRequired {
+		dialOpts.RootClusterCACertPool, err = p.getRootClusterCACertPool(ctx, profileName)
+		if err != nil {
+			return nil, trace.Wrap(err)
+		}
+	}
+	return dialOpts, nil
+}
+
+// getRootClusterCACertPool returns a certificate pool for the root cluster of the given profile.
+func (p *vnetAppProvider) getRootClusterCACertPool(ctx context.Context, profileName string) (*x509.CertPool, error) {
+	tc, err := p.newTeleportClient(ctx, profileName, "")
+	if err != nil {
+		return nil, trace.Wrap(err, "creating new client")
+	}
+	certPool, err := tc.RootClusterCACertPool(ctx)
+	if err != nil {
+		return nil, trace.Wrap(err, "loading root cluster CA cert pool")
+	}
+	return certPool, nil
+}
+
+// retryWithRelogin wraps [client.RetryWithRelogin], serializing cluster logins and clearing the cached
+// clients for the root cluster after a successful relogin.
+func (p *vnetAppProvider) retryWithRelogin(ctx context.Context, tc *client.TeleportClient, fn func() error, opts ...client.RetryWithReloginOption) error {
+	profileName, err := utils.Host(tc.WebProxyAddr)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	// Make sure to release the login mutex if we end up acquiring it.
+	didLock := false
+	defer func() {
+		if didLock {
+			p.loginMu.Unlock()
+		}
+	}()
+
+	opts = append(opts,
+		client.WithBeforeLoginHook(func() error {
+			// Multiple concurrent logins in tsh would be bad UX, especially when MFA is involved, so we only
+			// allow one login at a time. If another login is already in progress, this just returns an error
+			// and no login will be attempted. Subsequent relogins can be attempted on the next client request
+			// after the current one finishes.
+			if p.loginMu.TryLock() {
+				didLock = true
+			} else {
+				return fmt.Errorf("not attempting re-login to cluster %s, another login is currently in progress", tc.SiteName)
+			}
+			fmt.Printf("Login for cluster %s expired, attempting to log in again.\n", tc.SiteName)
+			return nil
+		}),
+		client.WithAfterLoginHook(func() error {
+			return trace.Wrap(p.clientCache.ClearForRoot(profileName), "clearing client cache after relogin")
+		}),
+		client.WithMakeCurrentProfile(false),
+	)
+	return client.RetryWithRelogin(ctx, tc, fn, opts...)
+}
+
+func (p *vnetAppProvider) reissueAppCert(ctx context.Context, tc *client.TeleportClient, profileName, leafClusterName string, app types.Application) (tls.Certificate, error) {
+	slog.InfoContext(ctx, "Reissuing cert for app.", "app_name", app.GetName(), "profile", profileName, "leaf_cluster", leafClusterName)
+
+	routeToApp := proto.RouteToApp{
+		Name:        app.GetName(),
+		PublicAddr:  app.GetPublicAddr(),
+		ClusterName: tc.SiteName,
+	}
+
+	profile, err := tc.ProfileStatus()
+	if err != nil {
+		return tls.Certificate{}, trace.Wrap(err, "loading client profile")
+	}
+
+	appCertParams := client.ReissueParams{
+		RouteToCluster: tc.SiteName,
+		RouteToApp:     routeToApp,
+		AccessRequests: profile.ActiveRequests.AccessRequests,
+		RequesterName:  proto.UserCertsRequest_TSH_APP_LOCAL_PROXY,
+	}
+
+	clusterClient, err := p.GetCachedClient(ctx, profileName, leafClusterName)
+	if err != nil {
+		return tls.Certificate{}, trace.Wrap(err, "getting cached cluster client")
+	}
+	rootClient, err := p.GetCachedClient(ctx, profileName, "")
+	if err != nil {
+		return tls.Certificate{}, trace.Wrap(err, "getting cached root client")
+	}
+
+	key, err := appLogin(ctx, tc, clusterClient, rootClient.AuthClient, appCertParams)
+	if err != nil {
+		return tls.Certificate{}, trace.Wrap(err, "logging in to app")
+	}
+
+	cert, err := key.AppTLSCert(app.GetName())
+	if err != nil {
+		return tls.Certificate{}, trace.Wrap(err, "getting TLS cert from key")
+	}
+
+	return cert, nil
+}
+
+// newTeleportClient creates an unconnected [*client.TeleportClient] for the given profile and leaf cluster.
+func (p *vnetAppProvider) newTeleportClient(ctx context.Context, profileName, leafClusterName string) (*client.TeleportClient, error) {
 	cfg := &client.Config{
 		ClientStore: p.clientStore,
 	}
@@ -60,7 +220,5 @@ func (p *vnetAppProvider) GetCachedClient(ctx context.Context, profileName, leaf
 	if err != nil {
 		return nil, trace.Wrap(err, "creating new client")
 	}
-
-	clusterClient, err := tc.ConnectToCluster(ctx)
-	return clusterClient, trace.Wrap(err)
+	return tc, nil
 }