diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index 34fc9c56f0f5f..e766bb9ae4dc8 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -531,12 +531,7 @@ func mustStartALPNLocalProxyWithConfig(t *testing.T, config alpnproxy.LocalProxy }) go func() { - var err error - if config.HTTPMiddleware == nil { - err = lp.Start(context.Background()) - } else { - err = lp.StartHTTPAccessProxy(context.Background()) - } + err := lp.Start(context.Background()) assert.NoError(t, err) }() return lp diff --git a/lib/client/api.go b/lib/client/api.go index 173654cabb1af..26d8eaff141b7 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -593,7 +593,7 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error, case utils.IsPredicateError(fnErr): return trace.Wrap(utils.PredicateError{Err: fnErr}) case tc.NonInteractive: - return trace.Wrap(fnErr) + return trace.Wrap(fnErr, "cannot relogin in non-interactive session") case !IsErrorResolvableWithRelogin(fnErr): return trace.Wrap(fnErr) } diff --git a/lib/client/client_store.go b/lib/client/client_store.go index ab55d629bba54..0b0b4f10b120c 100644 --- a/lib/client/client_store.go +++ b/lib/client/client_store.go @@ -82,14 +82,21 @@ func (s *Store) AddKey(key *Key) error { return nil } -// ErrNoCredentials is returned by the client store when a specific key is not found. -// This error can be used to determine whether a client should retrieve new credentials, -// like how it is used with lib/client.RetryWithRelogin. -var ErrNoCredentials = trace.NotFound("no credentials") +var ( + // ErrNoCredentials is returned by the client store when a specific key is not found. + // This error can be used to determine whether a client should retrieve new credentials, + // like how it is used with lib/client.RetryWithRelogin. + ErrNoCredentials = &trace.NotFoundError{Message: "no credentials"} -// IsNoCredentialsError returns whether the given error is an ErrNoCredentials error. + // ErrNoProfile is returned by the client store when a specific profile is not found. + // This error can be used to determine whether a client should retrieve new credentials, + // like how it is used with lib/client.RetryWithRelogin. + ErrNoProfile = &trace.NotFoundError{Message: "no profile"} +) + +// IsNoCredentialsError returns whether the given error implies that the user should retrieve new credentials. func IsNoCredentialsError(err error) bool { - return errors.Is(err, ErrNoCredentials) + return errors.Is(err, ErrNoCredentials) || errors.Is(err, ErrNoProfile) } // GetKey gets the requested key with trusted the requested certificates. The key's @@ -161,6 +168,9 @@ func (s *Store) ReadProfileStatus(profileName string) (*ProfileStatus, error) { profile, err := s.GetProfile(profileName) if err != nil { + if trace.IsNotFound(err) { + return nil, trace.Wrap(ErrNoProfile, err.Error()) + } return nil, trace.Wrap(err) } idx := KeyIndex{ diff --git a/lib/client/local_proxy_middleware.go b/lib/client/local_proxy_middleware.go index aae23ba25996d..e093cdb6c6060 100644 --- a/lib/client/local_proxy_middleware.go +++ b/lib/client/local_proxy_middleware.go @@ -20,16 +20,26 @@ package client import ( "context" + "crypto" + "crypto/rand" + "crypto/rsa" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "net" + "os" + "sync" "time" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/mfa" + "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/srv/alpnproxy" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -38,26 +48,31 @@ import ( // CertChecker is a local proxy middleware that ensures certs are valid // on start up and on each new connection. type CertChecker struct { - // CertReissuer checks and reissues certs. - CertReissuer CertIssuer + // certIssuer checks and issues certs. + certIssuer CertIssuer // clock specifies the time provider. Will be used to override the time anchor // for TLS certificate verification. Defaults to real clock if unspecified clock clockwork.Clock + + cert tls.Certificate + certMu sync.Mutex } var _ alpnproxy.LocalProxyMiddleware = (*CertChecker)(nil) +// NewCertChecker creates a new CertChecker with the given CertIssuer. func NewCertChecker(certIssuer CertIssuer, clock clockwork.Clock) *CertChecker { if clock == nil { clock = clockwork.NewRealClock() } + return &CertChecker{ - CertReissuer: certIssuer, - clock: clock, + certIssuer: certIssuer, + clock: clock, } } -// Create a new CertChecker for the given database. +// NewDBCertChecker creates a new CertChecker for the given database. func NewDBCertChecker(tc *TeleportClient, dbRoute tlsca.RouteToDatabase, clock clockwork.Clock) *CertChecker { return NewCertChecker(&DBCertIssuer{ Client: tc, @@ -65,7 +80,7 @@ func NewDBCertChecker(tc *TeleportClient, dbRoute tlsca.RouteToDatabase, clock c }, clock) } -// Create a new CertChecker for the given app. +// NewAppCertChecker creates a new CertChecker for the given app. func NewAppCertChecker(tc *TeleportClient, appRoute proto.RouteToApp, clock clockwork.Clock) *CertChecker { return NewCertChecker(&AppCertIssuer{ Client: tc, @@ -76,37 +91,72 @@ func NewAppCertChecker(tc *TeleportClient, appRoute proto.RouteToApp, clock cloc // OnNewConnection is a callback triggered when a new downstream connection is // accepted by the local proxy. func (c *CertChecker) OnNewConnection(ctx context.Context, lp *alpnproxy.LocalProxy) error { - return trace.Wrap(c.ensureValidCerts(ctx, lp)) + cert, err := c.GetOrIssueCert(ctx) + if err != nil { + return trace.Wrap(err) + } + + lp.SetCert(cert) + return nil } // OnStart is a callback triggered when the local proxy starts. func (c *CertChecker) OnStart(ctx context.Context, lp *alpnproxy.LocalProxy) error { - return trace.Wrap(c.ensureValidCerts(ctx, lp)) + cert, err := c.GetOrIssueCert(ctx) + if err != nil { + return trace.Wrap(err) + } + + lp.SetCert(cert) + return nil } -// ensureValidCerts ensures that the local proxy is configured with valid certs. -func (c *CertChecker) ensureValidCerts(ctx context.Context, lp *alpnproxy.LocalProxy) error { - if err := lp.CheckCert(c.CertReissuer.CheckCert); err != nil { - log.WithError(err).Debug("local proxy tunnel certificates need to be reissued") - } else { - return nil +// SetCert sets the CertChecker's certificate. +func (c *CertChecker) SetCert(cert tls.Certificate) { + c.certMu.Lock() + defer c.certMu.Unlock() + c.cert = cert +} + +// GetOrIssueCert gets the CertChecker's certificate, or issues a new +// certificate if the it is invalid (e.g. expired) or missing. +func (c *CertChecker) GetOrIssueCert(ctx context.Context) (tls.Certificate, error) { + c.certMu.Lock() + defer c.certMu.Unlock() + + if err := c.checkCert(); err == nil { + return c.cert, nil } - cert, err := c.CertReissuer.IssueCert(ctx) + cert, err := c.certIssuer.IssueCert(ctx) if err != nil { - return trace.Wrap(err) + return tls.Certificate{}, trace.Wrap(err) } // reduce per-handshake processing by setting the parsed leaf. if err := utils.InitCertLeaf(&cert); err != nil { - return trace.Wrap(err) + return tls.Certificate{}, trace.Wrap(err) } certTTL := cert.Leaf.NotAfter.Sub(c.clock.Now()).Round(time.Minute) log.Debugf("Certificate renewed: valid until %s [valid for %v]", cert.Leaf.NotAfter.Format(time.RFC3339), certTTL) - lp.SetCert(cert) - return nil + c.cert = cert + return c.cert, nil +} + +func (c *CertChecker) checkCert() error { + leaf, err := utils.TLSCertLeaf(c.cert) + if err != nil { + return trace.Wrap(err) + } + + // Check for cert expiration. + if err := utils.VerifyCertificateExpiry(leaf, c.clock); err != nil { + return trace.Wrap(err) + } + + return trace.Wrap(c.certIssuer.CheckCert(leaf)) } // CertIssuer checks and issues certs. @@ -139,7 +189,7 @@ func (c *DBCertIssuer) IssueCert(ctx context.Context) (tls.Certificate, error) { var key *Key if err := RetryWithRelogin(ctx, c.Client, func() error { - newKey, err := c.Client.IssueUserCertsWithMFA(ctx, ReissueParams{ + dbCertParams := ReissueParams{ RouteToCluster: c.Client.SiteName, RouteToDatabase: proto.RouteToDatabase{ ServiceName: c.RouteToApp.ServiceName, @@ -149,7 +199,26 @@ func (c *DBCertIssuer) IssueCert(ctx context.Context) (tls.Certificate, error) { }, AccessRequests: accessRequests, RequesterName: proto.UserCertsRequest_TSH_DB_LOCAL_PROXY_TUNNEL, - }, mfa.WithPromptReasonSessionMFA("database", c.RouteToApp.ServiceName)) + } + + clusterClient, err := c.Client.ConnectToCluster(ctx) + if err != nil { + return trace.Wrap(err) + } + + newKey, mfaRequired, err := clusterClient.IssueUserCertsWithMFA(ctx, dbCertParams, c.Client.NewMFAPrompt(mfa.WithPromptReasonSessionMFA("database", c.RouteToApp.ServiceName))) + if err != nil { + return trace.Wrap(err) + } + + // If MFA was not required, we do not require certs be stored solely in memory. + // Save it to disk to avoid additional roundtrips for future requests. + if mfaRequired == proto.MFARequired_MFA_REQUIRED_NO { + if err := c.Client.LocalAgent().AddDatabaseKey(newKey); err != nil { + return trace.Wrap(err) + } + } + key = newKey return trace.Wrap(err) }); err != nil { @@ -193,11 +262,12 @@ func (c *AppCertIssuer) IssueCert(ctx context.Context) (tls.Certificate, error) RequesterName: proto.UserCertsRequest_TSH_APP_LOCAL_PROXY, } - // TODO (Joerger): DELETE IN v17.0.0 clusterClient, err := c.Client.ConnectToCluster(ctx) if err != nil { return trace.Wrap(err) } + + // TODO (Joerger): DELETE IN v17.0.0 rootClient, err := clusterClient.ConnectToRootCluster(ctx) if err != nil { return trace.Wrap(err) @@ -207,7 +277,19 @@ func (c *AppCertIssuer) IssueCert(ctx context.Context) (tls.Certificate, error) return trace.Wrap(err) } - newKey, _, err := clusterClient.IssueUserCertsWithMFA(ctx, appCertParams, c.Client.NewMFAPrompt(mfa.WithPromptReasonSessionMFA("application", c.RouteToApp.Name))) + newKey, mfaRequired, err := clusterClient.IssueUserCertsWithMFA(ctx, appCertParams, c.Client.NewMFAPrompt(mfa.WithPromptReasonSessionMFA("application", c.RouteToApp.Name))) + if err != nil { + return trace.Wrap(err) + } + + // If MFA was not required, we do not require certs be stored solely in memory. + // Save it to disk to avoid additional roundtrips for future requests. + if mfaRequired == proto.MFARequired_MFA_REQUIRED_NO { + if err := c.Client.LocalAgent().AddAppKey(newKey); err != nil { + return trace.Wrap(err) + } + } + key = newKey return trace.Wrap(err) }); err != nil { @@ -220,3 +302,196 @@ func (c *AppCertIssuer) IssueCert(ctx context.Context) (tls.Certificate, error) } return appCert, nil } + +// LocalCertGenerator is a TLS Certificate generator used to inject +// valid TLS certificates based on SNI during local HTTPS handshakes. +type LocalCertGenerator struct { + certChecker *CertChecker + caPath string + + mu sync.Mutex + // ca is the certificate authority for signing certificates. + ca tls.Certificate + // certsByHost is a cache of certs for hosts generated with the local CA. + // The key is the host's servername SNI. + certsByHost map[string]*tls.Certificate +} + +// NewLocalCertGenerator creates a new LocalCertGenerator and listens to the +// configured listen address. +func NewLocalCertGenerator(ctx context.Context, certChecker *CertChecker, caPath string) (*LocalCertGenerator, error) { + r := &LocalCertGenerator{ + certChecker: certChecker, + caPath: caPath, + } + + if err := r.ensureValidCA(ctx); err != nil { + return nil, trace.Wrap(err) + } + + return r, nil +} + +// GetCertificate generates and returns TLS certificate for incoming +// connection. Implements tls.Config.GetCertificate. +func (r *LocalCertGenerator) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + if err := r.ensureValidCA(clientHello.Context()); err != nil { + return nil, trace.Wrap(err) + } + + cert, err := r.generateCert(clientHello.ServerName) + if err != nil { + return nil, trace.WrapWithMessage(err, "failed to generate certificate for %q: %v", clientHello.ServerName, err) + } + + return cert, nil +} + +// generateCert generates a new certificate for the specified host. +func (r *LocalCertGenerator) generateCert(host string) (*tls.Certificate, error) { + r.mu.Lock() + defer r.mu.Unlock() + if cert, found := r.certsByHost[host]; found { + return cert, nil + } + + certKey, err := rsa.GenerateKey(rand.Reader, constants.RSAKeySize) + if err != nil { + return nil, trace.Wrap(err) + } + + certAuthority, err := tlsca.FromTLSCertificate(r.ca) + if err != nil { + return nil, trace.Wrap(err) + } + + subject := certAuthority.Cert.Subject + subject.CommonName = host + + certPem, err := certAuthority.GenerateCertificate(tlsca.CertificateRequest{ + PublicKey: &certKey.PublicKey, + Subject: subject, + NotAfter: certAuthority.Cert.NotAfter, + DNSNames: []string{host}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + cert, err := tls.X509KeyPair(certPem, tlsca.MarshalPrivateKeyPEM(certKey)) + if err != nil { + return nil, trace.Wrap(err) + } + + if err := utils.InitCertLeaf(&cert); err != nil { + return nil, trace.Wrap(err) + } + + r.certsByHost[host] = &cert + return &cert, nil +} + +// ensureValidCA checks if the CA is valid. If it is no longer valid, generate a new +// CA and clear the host cert cache. +func (r *LocalCertGenerator) ensureValidCA(ctx context.Context) error { + r.mu.Lock() + defer r.mu.Unlock() + + // Check if the CA is invalid (expired) + if err := r.checkCA(); err == nil { + return nil + } + + // Generate a new CA from a valid remote cert. + remoteTLSCert, err := r.certChecker.GetOrIssueCert(ctx) + if err != nil { + return trace.Wrap(err) + } + + caTLSCert, err := generateSelfSignedCAFromCert(remoteTLSCert, r.caPath) + if err != nil { + return trace.Wrap(err) + } + + if err := utils.InitCertLeaf(&caTLSCert); err != nil { + return trace.Wrap(err) + } + + certTTL := time.Until(caTLSCert.Leaf.NotAfter).Round(time.Minute) + log.Debugf("Local CA renewed: valid until %s [valid for %v]", caTLSCert.Leaf.NotAfter.Format(time.RFC3339), certTTL) + + // Clear cert cache and use CA for hostnames in the CA. + r.certsByHost = make(map[string]*tls.Certificate) + for _, host := range caTLSCert.Leaf.DNSNames { + r.certsByHost[host] = &caTLSCert + } + + // Requests to IPs have no server names. Default to CA. + r.certsByHost[""] = &caTLSCert + + r.ca = caTLSCert + return nil +} + +func (r *LocalCertGenerator) checkCA() error { + caCert, err := utils.TLSCertLeaf(r.ca) + if err != nil { + return trace.Wrap(err) + } + + err = utils.VerifyCertificateExpiry(caCert, nil /*real clock*/) + return trace.Wrap(err) +} + +// generateSelfSignedCA generates a new self-signed CA for localhost +// and saves/overwrites the local CA file in the given path. +func generateSelfSignedCAFromCert(cert tls.Certificate, caPath string) (tls.Certificate, error) { + certExpiry, err := getTLSCertExpireTime(cert) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return tls.Certificate{}, trace.BadParameter("private key type %T does not implement crypto.Signer", cert.PrivateKey) + } + + certPem, err := tlsca.GenerateSelfSignedCAWithConfig(tlsca.GenerateCAConfig{ + Entity: pkix.Name{ + CommonName: "localhost", + Organization: []string{"Teleport"}, + }, + Signer: signer, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP(defaults.Localhost)}, + TTL: time.Until(certExpiry), + }) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + if _, err := utils.EnsureLocalPath(caPath, "", ""); err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + if err = os.WriteFile(caPath, certPem, 0o600); err != nil { + return tls.Certificate{}, trace.ConvertSystemError(err) + } + + keyPem, err := keys.MarshalPrivateKey(signer) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + caCert, err := tls.X509KeyPair(certPem, keyPem) + return caCert, trace.Wrap(err) +} + +// getTLSCertExpireTime returns the certificate NotAfter time. +func getTLSCertExpireTime(cert tls.Certificate) (time.Time, error) { + x509cert, err := utils.TLSCertLeaf(cert) + if err != nil { + return time.Time{}, trace.Wrap(err) + } + return x509cert.NotAfter, nil +} diff --git a/lib/client/local_proxy_middleware_test.go b/lib/client/local_proxy_middleware_test.go new file mode 100644 index 0000000000000..bc32ab4b16e32 --- /dev/null +++ b/lib/client/local_proxy_middleware_test.go @@ -0,0 +1,173 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package client_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "path/filepath" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/auth/native" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/tlsca" +) + +func TestCertChecker(t *testing.T) { + ctx := context.Background() + clock := clockwork.NewFakeClock() + certIssuer := newMockCertIssuer(t, clock) + certChecker := client.NewCertChecker(certIssuer, clock) + + // certChecker should issue a new cert on first request. + cert, err := certChecker.GetOrIssueCert(ctx) + require.NoError(t, err) + + // subsequent calls should return the same cert. + sameCert, err := certChecker.GetOrIssueCert(ctx) + require.NoError(t, err) + require.Equal(t, cert, sameCert) + + // If the current cert expires it should be reissued. + clock.Advance(2 * time.Minute) + expiredCert := cert + + cert, err = certChecker.GetOrIssueCert(ctx) + require.NoError(t, err) + require.NotEqual(t, cert, expiredCert) + + // If the current cert fails certIssuer checks, a new one should be issued. + certIssuer.checkErr = trace.BadParameter("bad cert") + badCert := cert + + cert, err = certChecker.GetOrIssueCert(ctx) + require.NoError(t, err) + require.NotEqual(t, cert, badCert) + + // If issuing a new cert fails, an error is returned. + certIssuer.issueErr = trace.BadParameter("failed to issue cert") + _, err = certChecker.GetOrIssueCert(ctx) + require.ErrorIs(t, err, certIssuer.issueErr, "expected error %v but got %v", certIssuer.issueErr, err) +} + +func TestLocalCertGenerator(t *testing.T) { + ctx := context.Background() + clock := clockwork.NewFakeClock() + certIssuer := newMockCertIssuer(t, clock) + certChecker := client.NewCertChecker(certIssuer, clock) + caPath := filepath.Join(t.TempDir(), "localca.pem") + + localCertGenerator, err := client.NewLocalCertGenerator(ctx, certChecker, caPath) + require.NoError(t, err) + + // The cert generator should return the local CA cert for SNIs "localhost" or empty (plain ip). + caCert, err := localCertGenerator.GetCertificate(&tls.ClientHelloInfo{ + ServerName: "localhost", + }) + require.NoError(t, err) + require.Equal(t, []string{"localhost"}, caCert.Leaf.DNSNames) + + cert, err := localCertGenerator.GetCertificate(&tls.ClientHelloInfo{ + ServerName: "", + }) + require.NoError(t, err) + require.Equal(t, caCert, cert) + + // The cert generator should issue new certs from the local CA for other SNIs. + exampleCert, err := localCertGenerator.GetCertificate(&tls.ClientHelloInfo{ + ServerName: "example.com", + }) + require.NoError(t, err) + require.Equal(t, []string{"example.com"}, exampleCert.Leaf.DNSNames) +} + +type mockCertIssuer struct { + ca *tlsca.CertAuthority + clock clockwork.Clock + checkErr error + issueErr error +} + +func newMockCertIssuer(t *testing.T, clock clockwork.Clock) *mockCertIssuer { + certIssuer := &mockCertIssuer{ + clock: clock, + } + + certIssuer.initCA(t) + return certIssuer +} + +func (c *mockCertIssuer) initCA(t *testing.T) { + priv, err := native.GeneratePrivateKey() + require.NoError(t, err) + + cert, err := tlsca.GenerateSelfSignedCAWithConfig(tlsca.GenerateCAConfig{ + Signer: priv, + Entity: pkix.Name{ + CommonName: "root", + Organization: []string{"teleport"}, + }, + TTL: defaults.CATTL, + Clock: c.clock, + }) + require.NoError(t, err) + + c.ca, err = tlsca.FromCertAndSigner(cert, priv) + require.NoError(t, err) +} + +func (c *mockCertIssuer) CheckCert(cert *x509.Certificate) error { + return trace.Wrap(c.checkErr) +} + +func (c *mockCertIssuer) IssueCert(ctx context.Context) (tls.Certificate, error) { + if c.issueErr != nil { + return tls.Certificate{}, trace.Wrap(c.issueErr) + } + + priv, err := native.GeneratePrivateKey() + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + certPem, err := c.ca.GenerateCertificate(tlsca.CertificateRequest{ + PublicKey: priv.Public(), + Subject: pkix.Name{ + CommonName: "user", + Organization: []string{"teleport"}, + }, + NotAfter: c.clock.Now().Add(time.Minute), + }) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + tlsCert, err := tls.X509KeyPair(certPem, priv.PrivateKeyPEM()) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + return tlsCert, nil +} diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index f76e421c073be..0aa7e75a193f2 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -165,20 +165,21 @@ func NewLocalProxy(cfg LocalProxyConfig, opts ...LocalProxyConfigOpt) (*LocalPro // Start starts the LocalProxy. func (l *LocalProxy) Start(ctx context.Context) error { + if l.cfg.Middleware != nil { + if err := l.cfg.Middleware.OnStart(ctx, l); err != nil { + return trace.Wrap(err) + } + } + if l.cfg.HTTPMiddleware != nil { - return trace.Wrap(l.StartHTTPAccessProxy(ctx)) + return trace.Wrap(l.startHTTPAccessProxy(ctx)) } + return trace.Wrap(l.start(ctx)) } // start starts the LocalProxy for raw TCP or raw TLS (non-HTTP) connections. func (l *LocalProxy) start(ctx context.Context) error { - if l.cfg.Middleware != nil { - err := l.cfg.Middleware.OnStart(ctx, l) - if err != nil { - return trace.Wrap(err) - } - } for { select { case <-ctx.Done(): @@ -342,23 +343,26 @@ func (l *LocalProxy) makeHTTPReverseProxy(certs ...tls.Certificate) *httputil.Re } } -// StartHTTPAccessProxy starts the local HTTP access proxy. -func (l *LocalProxy) StartHTTPAccessProxy(ctx context.Context) error { - if l.cfg.HTTPMiddleware == nil { - return trace.BadParameter("Missing HTTPMiddleware in configuration") - } - +// startHTTPAccessProxy starts the local HTTP access proxy. +func (l *LocalProxy) startHTTPAccessProxy(ctx context.Context) error { if err := l.cfg.HTTPMiddleware.CheckAndSetDefaults(); err != nil { return trace.Wrap(err) } l.cfg.Log.Info("Starting HTTP access proxy") defer l.cfg.Log.Info("HTTP access proxy stopped") - defaultProxy := l.makeHTTPReverseProxy(l.getCert()) server := &http.Server{ ReadHeaderTimeout: defaults.ReadHeadersTimeout, Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if l.cfg.Middleware != nil { + if err := l.cfg.Middleware.OnNewConnection(ctx, l); err != nil { + l.cfg.Log.WithError(err).Error("Middleware failed to handle client request.") + trace.WriteError(rw, trace.Wrap(err)) + return + } + } + if l.cfg.HTTPMiddleware.HandleRequest(rw, req) { return } @@ -371,7 +375,7 @@ func (l *LocalProxy) StartHTTPAccessProxy(ctx context.Context) error { req.Header.Del("X-Forwarded-Host") } - proxy, err := l.getHTTPReverseProxyForReq(req, defaultProxy) + proxy, err := l.getHTTPReverseProxyForReq(req) if err != nil { l.cfg.Log.Warnf("Failed to get reverse proxy: %v.", err) trace.WriteError(rw, trace.Wrap(err)) @@ -396,49 +400,25 @@ func (l *LocalProxy) StartHTTPAccessProxy(ctx context.Context) error { return nil } -func (l *LocalProxy) getHTTPReverseProxyForReq(req *http.Request, defaultProxy *httputil.ReverseProxy) (*httputil.ReverseProxy, error) { +func (l *LocalProxy) getHTTPReverseProxyForReq(req *http.Request) (*httputil.ReverseProxy, error) { certs, err := l.cfg.HTTPMiddleware.OverwriteClientCerts(req) - if err != nil { - if trace.IsNotImplemented(err) { - return defaultProxy, nil - } + if trace.IsNotImplemented(err) { + return l.makeHTTPReverseProxy(l.getCert()), nil + } else if err != nil { return nil, trace.Wrap(err) } + + l.cfg.Log.Debug("overwrote certs") return l.makeHTTPReverseProxy(certs...), nil } -// getCerts returns the local proxy's configured TLS certificates. -// For thread-safety, it is important that the returned slice and its contents -// are not be mutated by callers, therefore this method is not exported. +// getCert returns the local proxy's configured TLS certificate. func (l *LocalProxy) getCert() tls.Certificate { l.certMu.RLock() defer l.certMu.RUnlock() return l.cfg.Cert } -// CheckCertExpiry checks the proxy certificates for expiration and runs given checking function. -func (l *LocalProxy) CheckCert(checkCert func(cert *x509.Certificate) error) error { - l.cfg.Log.Debug("checking local proxy certs") - l.certMu.RLock() - defer l.certMu.RUnlock() - - if len(l.cfg.Cert.Certificate) == 0 { - return trace.NotFound("local proxy has no TLS certificates configured") - } - - cert, err := utils.TLSCertLeaf(l.cfg.Cert) - if err != nil { - return trace.Wrap(err) - } - - // Check for cert expiration. - if err := utils.VerifyCertificateExpiry(cert, l.cfg.Clock); err != nil { - return trace.Wrap(err) - } - - return trace.Wrap(checkCert(cert)) -} - // CheckDBCert checks the proxy certificates for expiration and that the cert subject matches a database route. func (l *LocalProxy) CheckDBCert(dbRoute tlsca.RouteToDatabase) error { l.cfg.Log.Debug("checking local proxy database certs") diff --git a/lib/srv/alpnproxy/local_proxy_config_opt.go b/lib/srv/alpnproxy/local_proxy_config_opt.go index 8ff16404e5b72..7b520ce37611f 100644 --- a/lib/srv/alpnproxy/local_proxy_config_opt.go +++ b/lib/srv/alpnproxy/local_proxy_config_opt.go @@ -76,7 +76,9 @@ func WithClusterCAs(ctx context.Context, getClusterCertPool GetClusterCACertPool } // WithClientCert is a LocalProxyConfigOpt that sets the client certs used to -// connect to the remote Teleport Proxy. +// connect to the remote Teleport Proxy. Note that when paired with middleware +// that overwrites the cert, like the CertChecker middleware, this cert will +// not have a chance to be used. func WithClientCert(cert tls.Certificate) LocalProxyConfigOpt { return func(config *LocalProxyConfig) error { config.Cert = cert diff --git a/lib/srv/alpnproxy/local_proxy_test.go b/lib/srv/alpnproxy/local_proxy_test.go index b5ba5b1582b92..8054622c5b303 100644 --- a/lib/srv/alpnproxy/local_proxy_test.go +++ b/lib/srv/alpnproxy/local_proxy_test.go @@ -595,7 +595,7 @@ func createAWSAccessProxySuite(t *testing.T, cred *credentials.Credentials) *Loc hs.Close() }) go func() { - err := lp.StartHTTPAccessProxy(context.Background()) + err := lp.Start(context.Background()) assert.NoError(t, err) }() return lp diff --git a/tool/tsh/common/app.go b/tool/tsh/common/app.go index 385ad05d0c24f..cb40d148d6111 100644 --- a/tool/tsh/common/app.go +++ b/tool/tsh/common/app.go @@ -20,16 +20,13 @@ package common import ( "context" - "crypto/tls" - "crypto/x509/pkix" "fmt" "io" - "net" "os" "os/exec" "strings" + "sync" "text/template" - "time" "github.com/ghodss/yaml" "github.com/gravitational/trace" @@ -38,11 +35,9 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/asciitable" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -54,25 +49,20 @@ func onAppLogin(cf *CLIConf) error { return trace.Wrap(err) } - profile, err := tc.ProfileStatus() - if err != nil { - return trace.Wrap(err) - } - - app, err := getRegisteredApp(cf, tc) + appInfo, err := getAppInfo(cf, tc, nil /*matchRouteToApp*/) if err != nil { return trace.Wrap(err) } - routeToApp, err := getRouteToApp(cf, tc, profile, app) + app, err := appInfo.GetApp(cf.Context, tc) if err != nil { return trace.Wrap(err) } appCertParams := client.ReissueParams{ RouteToCluster: tc.SiteName, - RouteToApp: routeToApp, - AccessRequests: profile.ActiveRequests.AccessRequests, + RouteToApp: appInfo.RouteToApp, + AccessRequests: appInfo.profile.ActiveRequests.AccessRequests, } clusterClient, err := tc.ConnectToCluster(cf.Context) @@ -93,7 +83,7 @@ func onAppLogin(cf *CLIConf) error { return trace.Wrap(err) } - if err := printAppCommand(cf, tc, app, routeToApp); err != nil { + if err := printAppCommand(cf, tc, app, appInfo.RouteToApp); err != nil { return trace.Wrap(err) } @@ -119,46 +109,6 @@ func appLogin( return key, trace.Wrap(err) } -func getRouteToApp(cf *CLIConf, tc *client.TeleportClient, profile *client.ProfileStatus, app types.Application) (proto.RouteToApp, error) { - var awsRoleARN string - if app.IsAWSConsole() { - var err error - awsRoleARN, err = getARNFromFlags(cf, profile, app) - if err != nil { - return proto.RouteToApp{}, trace.Wrap(err) - } - } - - var azureIdentity string - if app.IsAzureCloud() { - var err error - azureIdentity, err = getAzureIdentityFromFlags(cf, profile) - if err != nil { - return proto.RouteToApp{}, trace.Wrap(err) - } - log.Debugf("Azure identity is %q", azureIdentity) - } - - var gcpServiceAccount string - if app.IsGCP() { - var err error - gcpServiceAccount, err = getGCPServiceAccountFromFlags(cf, profile) - if err != nil { - return proto.RouteToApp{}, trace.Wrap(err) - } - log.Debugf("GCP service account is %q", gcpServiceAccount) - } - - return proto.RouteToApp{ - Name: app.GetName(), - PublicAddr: app.GetPublicAddr(), - ClusterName: tc.SiteName, - AWSRoleARN: awsRoleARN, - AzureIdentity: azureIdentity, - GCPServiceAccount: gcpServiceAccount, - }, nil -} - func localProxyRequiredForApp(tc *client.TeleportClient) bool { return tc.TLSRoutingConnUpgradeRequired } @@ -313,26 +263,6 @@ Your service account: {{.serviceAccount}} Example command: tsh gcloud compute instances list `)) -// getRegisteredApp returns the registered application with the specified name. -func getRegisteredApp(cf *CLIConf, tc *client.TeleportClient) (app types.Application, err error) { - var apps []types.Application - err = client.RetryWithRelogin(cf.Context, tc, func() error { - apps, err = tc.ListApps(cf.Context, &proto.ListResourcesRequest{ - Namespace: tc.Namespace, - ResourceType: types.KindAppServer, - PredicateExpression: fmt.Sprintf(`name == "%s"`, cf.AppName), - }) - return trace.Wrap(err) - }) - if err != nil { - return nil, trace.Wrap(err) - } - if len(apps) == 0 { - return nil, trace.NotFound("app %q not found, use `tsh apps ls` to see registered apps", cf.AppName) - } - return apps[0], nil -} - // onAppLogout implements "tsh apps logout" command. func onAppLogout(cf *CLIConf) error { tc, err := makeClient(cf) @@ -424,6 +354,23 @@ func onAppConfig(cf *CLIConf) error { return nil } +const ( + // appFormatURI prints app URI. + appFormatURI = "uri" + // appFormatCA prints app CA cert path. + appFormatCA = "ca" + // appFormatCert prints app cert path. + appFormatCert = "cert" + // appFormatKey prints app key path. + appFormatKey = "key" + // appFormatCURL prints app curl command. + appFormatCURL = "curl" + // appFormatJSON prints app URI, CA cert path, cert path, key path, and curl command in JSON format. + appFormatJSON = "json" + // appFormatYAML prints app URI, CA cert path, cert path, key path, and curl command in YAML format. + appFormatYAML = "yaml" +) + func formatAppConfig(tc *client.TeleportClient, profile *client.ProfileStatus, routeToApp proto.RouteToApp, format string) (string, error) { var uri string if port := tc.WebProxyPort(); port == teleport.StandardHTTPSPort { @@ -534,126 +481,207 @@ func serializeAppConfig(configInfo *appConfigInfo, format string) (string, error return string(out), trace.Wrap(err) } -// pickActiveApp returns the app the current profile is logged into. -// -// If logged into multiple apps, returns an error unless one was specified -// explicitly on CLI. -func pickActiveApp(cf *CLIConf, activeRoutes []tlsca.RouteToApp) (*tlsca.RouteToApp, error) { - if cf.AppName == "" { - switch len(activeRoutes) { - case 0: - return nil, trace.NotFound("please login using 'tsh apps login' first") - case 1: - return &activeRoutes[0], nil - default: - var appNames []string - for _, r := range activeRoutes { - appNames = append(appNames, r.Name) +// getAppInfo fetches app information using the user's tsh profile, +// command line args, and the ListApps endpoint if necessary. If +// provided, the matcher will be used to filter active apps in the +// tsh profile. getAppInfo will also perform re-login if necessary. +func getAppInfo(cf *CLIConf, tc *client.TeleportClient, matchRouteToApp func(tlsca.RouteToApp) bool) (*appInfo, error) { + var profile *client.ProfileStatus + if err := client.RetryWithRelogin(cf.Context, tc, func() error { + var err error + profile, err = tc.ProfileStatus() + return trace.Wrap(err) + }); err != nil { + return nil, trace.Wrap(err) + } + + activeRoutes := profile.Apps + if matchRouteToApp != nil { + var filteredRoutes []tlsca.RouteToApp + for _, route := range profile.Apps { + if matchRouteToApp(route) { + filteredRoutes = append(filteredRoutes, route) } - return nil, trace.BadParameter("multiple apps are available (%v), please specify one via CLI argument", - strings.Join(appNames, ", ")) } + activeRoutes = filteredRoutes } - for _, r := range activeRoutes { - if r.Name == cf.AppName { - return &r, nil - } + if route, err := pickActiveApp(cf, activeRoutes); err == nil { + info := &appInfo{RouteToApp: route, isActive: true} + return info, info.checkAndSetDefaults(cf, tc, profile) + } else if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) } - return nil, trace.NotFound("not logged into app %q", cf.AppName) -} -// loadAppSelfSignedCA loads self-signed CA for provided app, or tries to -// generate a new CA if first load fails. -func loadAppSelfSignedCA(profile *client.ProfileStatus, tc *client.TeleportClient, appName string) (tls.Certificate, error) { - appCerts, _, err := loadAppCertificate(tc, appName) + // If we didn't find an active profile for the app, get info from server. + app, err := getApp(cf.Context, tc, cf.AppName) if err != nil { - return tls.Certificate{}, trace.Wrap(err) + return nil, trace.Wrap(err) } - appCertsExpireAt, err := getTLSCertExpireTime(appCerts) - if err != nil { - return tls.Certificate{}, trace.Wrap(err) + + info := &appInfo{ + RouteToApp: proto.RouteToApp{ + Name: app.GetName(), + PublicAddr: app.GetPublicAddr(), + ClusterName: tc.SiteName, + }, + app: app, } - cert, err := loadSelfSignedCA(profile.AppLocalCAPath(tc.SiteName, appName), profile.KeyPath(), appCertsExpireAt, "localhost") - return cert, trace.Wrap(err) + return info, info.checkAndSetDefaults(cf, tc, profile) +} + +// appInfo wraps a RouteToApp and the corresponding app. +// Its purpose is to prevent repeated fetches of the same app, +// by lazily fetching and caching the app for use as needed. +type appInfo struct { + proto.RouteToApp + // app corresponds to the app route and may be nil, so use GetApp + // instead of accessing it directly. + app types.Application + // isActive indicates an active app matched this app info. + isActive bool + mu sync.Mutex + + // profile is a cached profile status for the current login session. + profile *client.ProfileStatus } -func loadSelfSignedCA(caPath, keyPath string, validUntil time.Time, dnsNames ...string) (tls.Certificate, error) { - caTLSCert, err := keys.LoadX509KeyPair(caPath, keyPath) - if err == nil { - if expire, err := getTLSCertExpireTime(caTLSCert); err == nil && time.Now().Before(expire) { - return caTLSCert, nil +// checkAndSetDefaults checks the app route, applies cli flags, and sets defaults. +func (a *appInfo) checkAndSetDefaults(cf *CLIConf, tc *client.TeleportClient, profile *client.ProfileStatus) error { + a.profile = profile + + switch { + case a.IsAWSConsole(): + app, err := a.GetApp(cf.Context, tc) + if err != nil { + return trace.Wrap(err) } + + awsRoleARN, err := getARNFromFlags(cf, profile, app) + if err != nil { + return trace.Wrap(err) + } + a.AWSRoleARN = awsRoleARN + + case a.IsAzureCloud(): + azureIdentity, err := getAzureIdentityFromFlags(cf, profile) + if err != nil { + return trace.Wrap(err) + } + log.Debugf("Azure identity is %q", azureIdentity) + a.AzureIdentity = azureIdentity + + case a.IsGCP(): + gcpServiceAccount, err := getGCPServiceAccountFromFlags(cf, profile) + if err != nil { + return trace.Wrap(err) + } + log.Debugf("GCP service account is %q", gcpServiceAccount) + a.GCPServiceAccount = gcpServiceAccount } - if err != nil && !trace.IsNotFound(err) { - log.WithError(err).Debugf("Failed to load certificate from %v.", caPath) - } - // Generate and load again. - if err = generateSelfSignedCA(caPath, keyPath, validUntil, dnsNames...); err != nil { - return tls.Certificate{}, err + return nil +} + +func (a *appInfo) IsAWSConsole() bool { + if a.app != nil { + return a.app.IsAWSConsole() } + return a.RouteToApp.AWSRoleARN != "" +} - caTLSCert, err = keys.LoadX509KeyPair(caPath, keyPath) - if err != nil { - return tls.Certificate{}, trace.Wrap(err) +func (a *appInfo) IsAzureCloud() bool { + if a.app != nil { + return a.app.IsAzureCloud() } - return caTLSCert, nil + return a.RouteToApp.AzureIdentity != "" } -// generateSelfSignedCA generates a new self-signed CA for provided dnsNames -// and saves/overwrites the local CA file in the profile directory. -func generateSelfSignedCA(caPath, keyPath string, validUntil time.Time, dnsNames ...string) error { - log.Debugf("Generating local self signed CA at %v", caPath) - keyPem, err := utils.ReadPath(keyPath) - if err != nil { - return trace.Wrap(err) +func (a *appInfo) IsGCP() bool { + if a.app != nil { + return a.app.IsGCP() } + return a.RouteToApp.GCPServiceAccount != "" +} - key, err := keys.ParsePrivateKey(keyPem) +func (a *appInfo) appLocalCAPath(cluster string) string { + return a.profile.AppLocalCAPath(cluster, a.RouteToApp.Name) +} + +// GetApp returns the cached app or fetches it using the app route and +// caches the result. +func (a *appInfo) GetApp(ctx context.Context, tc *client.TeleportClient) (types.Application, error) { + a.mu.Lock() + defer a.mu.Unlock() + if a.app != nil { + return a.app.Copy(), nil + } + // holding mutex across the api call to avoid multiple redundant api calls. + app, err := getApp(ctx, tc, a.Name) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } + a.app = app + return a.app.Copy(), nil +} - certPem, err := tlsca.GenerateSelfSignedCAWithConfig(tlsca.GenerateCAConfig{ - Entity: pkix.Name{ - CommonName: "localhost", - Organization: []string{"Teleport"}, - }, - Signer: key, - DNSNames: dnsNames, - IPAddresses: []net.IP{net.ParseIP(defaults.Localhost)}, - TTL: time.Until(validUntil), +// getApp returns the registered application with the specified name. +func getApp(ctx context.Context, tc *client.TeleportClient, name string) (app types.Application, err error) { + var apps []types.Application + err = client.RetryWithRelogin(ctx, tc, func() error { + apps, err = tc.ListApps(ctx, &proto.ListResourcesRequest{ + Namespace: tc.Namespace, + ResourceType: types.KindAppServer, + PredicateExpression: fmt.Sprintf(`name == "%s"`, name), + }) + return trace.Wrap(err) }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) + } + if len(apps) == 0 { + return nil, trace.NotFound("app %q not found, use `tsh apps ls` to see registered apps", name) } + return apps[0], nil +} - if _, err := utils.EnsureLocalPath(caPath, "", ""); err != nil { - return trace.Wrap(err) +// pickActiveApp returns the app the current profile is logged into. +// +// If logged into multiple apps, returns an error unless one was specified +// explicitly on CLI. +func pickActiveApp(cf *CLIConf, activeRoutes []tlsca.RouteToApp) (proto.RouteToApp, error) { + if cf.AppName == "" { + switch len(activeRoutes) { + case 0: + return proto.RouteToApp{}, trace.NotFound("please login using 'tsh apps login' first") + case 1: + return tlscaRouteToAppToProto(activeRoutes[0]), nil + default: + var appNames []string + for _, r := range activeRoutes { + appNames = append(appNames, r.Name) + } + return proto.RouteToApp{}, trace.BadParameter("multiple apps are available (%v), please specify one via CLI argument", + strings.Join(appNames, ", ")) + } } - // WriteFile truncates existing file before writing. - if err = os.WriteFile(caPath, certPem, 0o600); err != nil { - return trace.ConvertSystemError(err) + for _, r := range activeRoutes { + if r.Name == cf.AppName { + return tlscaRouteToAppToProto(r), nil + } } - return nil + return proto.RouteToApp{}, trace.NotFound("not logged into app %q", cf.AppName) } -const ( - // appFormatURI prints app URI. - appFormatURI = "uri" - // appFormatCA prints app CA cert path. - appFormatCA = "ca" - // appFormatCert prints app cert path. - appFormatCert = "cert" - // appFormatKey prints app key path. - appFormatKey = "key" - // appFormatCURL prints app curl command. - appFormatCURL = "curl" - // appFormatJSON prints app URI, CA cert path, cert path, key path, and curl command in JSON format. - appFormatJSON = "json" - // appFormatYAML prints app URI, CA cert path, cert path, key path, and curl command in YAML format. - appFormatYAML = "yaml" -) +func tlscaRouteToAppToProto(route tlsca.RouteToApp) proto.RouteToApp { + return proto.RouteToApp{ + Name: route.Name, + PublicAddr: route.PublicAddr, + ClusterName: route.ClusterName, + AWSRoleARN: route.AWSRoleARN, + AzureIdentity: route.AzureIdentity, + GCPServiceAccount: route.GCPServiceAccount, + } +} diff --git a/tool/tsh/common/app_aws.go b/tool/tsh/common/app_aws.go index 3bedd2a613f3e..7a05bdb271916 100644 --- a/tool/tsh/common/app_aws.go +++ b/tool/tsh/common/app_aws.go @@ -19,8 +19,8 @@ package common import ( + "context" "fmt" - "net" "os" "os/exec" "strings" @@ -54,7 +54,7 @@ func onAWS(cf *CLIConf) error { cf.AWSEndpointURLMode = true } - err = awsApp.StartLocalProxies() + err = awsApp.StartLocalProxies(cf.Context) if err != nil { return trace.Wrap(err) } @@ -130,28 +130,25 @@ func isAWSFlag(args []string, i int) bool { // awsApp is an AWS app that can start local proxies to serve AWS APIs. type awsApp struct { - cf *CLIConf - profile *client.ProfileStatus - appName string - - localALPNProxy *alpnproxy.LocalProxy - localForwardProxy *alpnproxy.ForwardProxy - credentials *credentials.Credentials - credentialsOnce sync.Once + *localProxyApp + + cf *CLIConf + + credentials *credentials.Credentials + credentialsOnce sync.Once } // newAWSApp creates a new AWS app. -func newAWSApp(cf *CLIConf, profile *client.ProfileStatus, route tlsca.RouteToApp) (*awsApp, error) { +func newAWSApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*awsApp, error) { return &awsApp{ - cf: cf, - profile: profile, - appName: route.Name, + localProxyApp: newLocalProxyApp(tc, appInfo, cf.LocalProxyPort, cf.InsecureSkipVerify), + cf: cf, }, nil } // GetAppName returns the app name. func (a *awsApp) GetAppName() string { - return a.appName + return a.appInfo.RouteToApp.Name } // StartLocalProxies sets up local proxies for serving AWS clients. @@ -168,36 +165,25 @@ func (a *awsApp) GetAppName() string { // // The first method is always preferred as the original hostname is preserved // through forward proxy. -func (a *awsApp) StartLocalProxies() error { - // AWS endpoint URL mode - if a.cf.AWSEndpointURLMode { - if err := a.startLocalALPNProxy(a.cf.LocalProxyPort); err != nil { - return trace.Wrap(err) - } - - return nil +func (a *awsApp) StartLocalProxies(ctx context.Context, opts ...alpnproxy.LocalProxyConfigOpt) error { + cred, err := a.GetAWSCredentials() + if err != nil { + return trace.Wrap(err) } - // HTTPS proxy mode - if err := a.startLocalALPNProxy(""); err != nil { - return trace.Wrap(err) + awsMiddleware := &alpnproxy.AWSAccessMiddleware{ + AWSCredentials: cred, } - if err := a.startLocalForwardProxy(a.cf.LocalProxyPort); err != nil { + + // AWS endpoint URL mode + if a.cf.AWSEndpointURLMode { + err := a.StartLocalProxyWithTLS(ctx, alpnproxy.WithHTTPMiddleware(awsMiddleware)) return trace.Wrap(err) } - return nil -} -// close makes all necessary close calls. -func (a *awsApp) Close() error { - var errs []error - if a.localALPNProxy != nil { - errs = append(errs, a.localALPNProxy.Close()) - } - if a.localForwardProxy != nil { - errs = append(errs, a.localForwardProxy.Close()) - } - return trace.NewAggregate(errs...) + // HTTPS proxy mode + err = a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAWSRequests, alpnproxy.WithHTTPMiddleware(awsMiddleware)) + return trace.Wrap(err) } // GetAWSCredentials generates fake AWS credentials that are used for @@ -248,7 +234,7 @@ func (a *awsApp) GetEnvVars() (map[string]string, error) { // https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html "AWS_ACCESS_KEY_ID": credValues.AccessKeyID, "AWS_SECRET_ACCESS_KEY": credValues.SecretAccessKey, - "AWS_CA_BUNDLE": a.profile.AppLocalCAPath(a.cf.SiteName, a.appName), + "AWS_CA_BUNDLE": a.appInfo.appLocalCAPath(a.cf.SiteName), } // Set proxy settings. @@ -298,114 +284,6 @@ func (a *awsApp) RunCommand(cmd *exec.Cmd) error { return nil } -// startLocalALPNProxy starts the local ALPN proxy. -func (a *awsApp) startLocalALPNProxy(port string) error { - tc, err := makeClient(a.cf) - if err != nil { - return trace.Wrap(err) - } - - appCert, err := loadAppCertificateWithAppLogin(a.cf, tc, a.appName) - if err != nil { - return trace.Wrap(err) - } - - localCA, err := loadAppSelfSignedCA(a.profile, tc, a.appName) - if err != nil { - return trace.Wrap(err) - } - - cred, err := a.GetAWSCredentials() - if err != nil { - return trace.Wrap(err) - } - - listenAddr := "localhost:0" - if port != "" { - listenAddr = fmt.Sprintf("localhost:%s", port) - } - - // Create a listener that is able to sign certificates when receiving AWS - // requests tunneled from the local forward proxy. - listener, err := alpnproxy.NewCertGenListener(alpnproxy.CertGenListenerConfig{ - ListenAddr: listenAddr, - CA: localCA, - }) - if err != nil { - return trace.Wrap(err) - } - - a.localALPNProxy, err = alpnproxy.NewLocalProxy( - makeBasicLocalProxyConfig(a.cf, tc, listener), - alpnproxy.WithClientCert(appCert), - alpnproxy.WithClusterCAsIfConnUpgrade(a.cf.Context, tc.RootClusterCACertPool), - alpnproxy.WithHTTPMiddleware(&alpnproxy.AWSAccessMiddleware{ - AWSCredentials: cred, - }), - ) - if err != nil { - if cerr := listener.Close(); cerr != nil { - return trace.NewAggregate(err, cerr) - } - return trace.Wrap(err) - } - - go func() { - if err := a.localALPNProxy.StartHTTPAccessProxy(a.cf.Context); err != nil { - log.WithError(err).Errorf("Failed to start local ALPN proxy.") - } - }() - return nil -} - -// startLocalForwardProxy starts the local forward proxy. -func (a *awsApp) startLocalForwardProxy(port string) error { - listenAddr := "localhost:0" - if port != "" { - listenAddr = fmt.Sprintf("localhost:%s", port) - } - - // Note that the created forward proxy serves HTTP instead of HTTPS, to - // eliminate the need to install temporary CA for various AWS clients. - listener, err := net.Listen("tcp", listenAddr) - if err != nil { - return trace.Wrap(err) - } - - a.localForwardProxy, err = alpnproxy.NewForwardProxy(alpnproxy.ForwardProxyConfig{ - Listener: listener, - CloseContext: a.cf.Context, - Handlers: []alpnproxy.ConnectRequestHandler{ - // Forward AWS requests to ALPN proxy. - alpnproxy.NewForwardToHostHandler(alpnproxy.ForwardToHostHandlerConfig{ - MatchFunc: alpnproxy.MatchAWSRequests, - Host: a.localALPNProxy.GetAddr(), - }), - - // Forward non-AWS requests to user's system proxy, if configured. - alpnproxy.NewForwardToSystemProxyHandler(alpnproxy.ForwardToSystemProxyHandlerConfig{ - InsecureSystemProxy: a.cf.InsecureSkipVerify, - }), - - // Forward non-AWS requests to their original hosts. - alpnproxy.NewForwardToOriginalHostHandler(), - }, - }) - if err != nil { - if cerr := listener.Close(); cerr != nil { - return trace.NewAggregate(err, cerr) - } - return trace.Wrap(err) - } - - go func() { - if err := a.localForwardProxy.Start(); err != nil { - log.WithError(err).Errorf("Failed to start local forward proxy.") - } - }() - return nil -} - func printAWSRoles(roles awsutils.Roles) { if len(roles) == 0 { return @@ -467,6 +345,15 @@ func matchAWSApp(app tlsca.RouteToApp) bool { } func pickAWSApp(cf *CLIConf) (*awsApp, error) { - app, err := pickCloudApp(cf, types.CloudAWS, matchAWSApp, newAWSApp) - return app, trace.Wrap(err) + tc, err := makeClient(cf) + if err != nil { + return nil, trace.Wrap(err) + } + + appInfo, err := getAppInfo(cf, tc, matchAWSApp) + if err != nil { + return nil, trace.Wrap(err) + } + + return newAWSApp(tc, cf, appInfo) } diff --git a/tool/tsh/common/app_azure.go b/tool/tsh/common/app_azure.go index a3c615ea52b34..c547650e667d7 100644 --- a/tool/tsh/common/app_azure.go +++ b/tool/tsh/common/app_azure.go @@ -19,9 +19,9 @@ package common import ( + "context" "crypto" "fmt" - "net" "os" "os/exec" "path" @@ -50,7 +50,7 @@ func onAzure(cf *CLIConf) error { return trace.Wrap(err) } - err = app.StartLocalProxies() + err = app.StartLocalProxies(cf.Context) if err != nil { return trace.Wrap(err) } @@ -69,26 +69,30 @@ func onAzure(cf *CLIConf) error { // azureApp is an Azure app that can start local proxies to serve Azure APIs. type azureApp struct { + *localProxyApp + cf *CLIConf - profile *client.ProfileStatus - app tlsca.RouteToApp + signer crypto.Signer msiSecret string - - localALPNProxy *alpnproxy.LocalProxy - localForwardProxy *alpnproxy.ForwardProxy } // newAzureApp creates a new Azure app. -func newAzureApp(cf *CLIConf, profile *client.ProfileStatus, app tlsca.RouteToApp) (*azureApp, error) { +func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azureApp, error) { + key, err := tc.LocalAgent().GetCoreKey() + if err != nil { + return nil, trace.Wrap(err) + } + msiSecret, err := getMSISecret() if err != nil { return nil, err } + return &azureApp{ - cf: cf, - profile: profile, - app: app, - msiSecret: msiSecret, + localProxyApp: newLocalProxyApp(tc, appInfo, cf.LocalProxyPort, cf.InsecureSkipVerify), + cf: cf, + signer: key.PrivateKey, + msiSecret: msiSecret, }, nil } @@ -127,27 +131,20 @@ func getMSISecret() (string, error) { // However, with MSI_ENDPOINT variable set, clients will reach out to this address for tokens. // We intercept calls to https://azure-msi.teleport.dev using alpnproxy.AzureMSIMiddleware. // These calls are served entirely locally, which helps the overall performance experienced by the user. -func (a *azureApp) StartLocalProxies() error { - // HTTPS proxy mode - if err := a.startLocalALPNProxy(""); err != nil { - return trace.Wrap(err) +func (a *azureApp) StartLocalProxies(ctx context.Context) error { + azureMiddleware := &alpnproxy.AzureMSIMiddleware{ + Key: a.signer, + Secret: a.msiSecret, + // we could, in principle, get the actual TenantID either from live data or from static configuration, + // but at this moment there is no clear advantage over simply issuing a new random identifier. + TenantID: uuid.New().String(), + ClientID: uuid.New().String(), + Identity: a.appInfo.RouteToApp.AzureIdentity, } - if err := a.startLocalForwardProxy(a.cf.LocalProxyPort); err != nil { - return trace.Wrap(err) - } - return nil -} -// Close makes all necessary close calls. -func (a *azureApp) Close() error { - var errs []error - if a.localALPNProxy != nil { - errs = append(errs, a.localALPNProxy.Close()) - } - if a.localForwardProxy != nil { - errs = append(errs, a.localForwardProxy.Close()) - } - return trace.NewAggregate(errs...) + // HTTPS proxy mode + err := a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAzureRequests, alpnproxy.WithHTTPMiddleware(azureMiddleware)) + return trace.Wrap(err) } // GetEnvVars returns required environment variables to configure the @@ -158,7 +155,7 @@ func (a *azureApp) GetEnvVars() (map[string]string, error) { // 1. `tsh az login` in one console // 2. `az ...` in another console // without custom config dir the second invocation will hang, attempting to connect to (inaccessible without configuration) MSI. - "AZURE_CONFIG_DIR": path.Join(profile.FullProfilePath(a.cf.HomePath), "azure", a.app.ClusterName, a.app.Name), + "AZURE_CONFIG_DIR": path.Join(profile.FullProfilePath(a.cf.HomePath), "azure", a.appInfo.RouteToApp.ClusterName, a.appInfo.RouteToApp.Name), // setting MSI_ENDPOINT instructs Azure CLI to make managed identity calls on this address. // the requests will be handled by tsh proxy. "MSI_ENDPOINT": "https://" + types.TeleportAzureMSIEndpoint + "/" + a.msiSecret, @@ -167,7 +164,7 @@ func (a *azureApp) GetEnvVars() (map[string]string, error) { // This isn't portable and applications other than az CLI may have to set different env variables, // add the application cert to system root store (not recommended, ultimate fallback) // or use equivalent of --insecure flag. - "REQUESTS_CA_BUNDLE": a.profile.AppLocalCAPath(a.cf.SiteName, a.app.Name), + "REQUESTS_CA_BUNDLE": a.appInfo.appLocalCAPath(a.cf.SiteName), } // Set proxy settings. @@ -200,120 +197,6 @@ func (a *azureApp) RunCommand(cmd *exec.Cmd) error { return nil } -// startLocalALPNProxy starts the local ALPN proxy. -func (a *azureApp) startLocalALPNProxy(port string) error { - tc, err := makeClient(a.cf) - if err != nil { - return trace.Wrap(err) - } - - appCert, err := loadAppCertificateWithAppLogin(a.cf, tc, a.app.Name) - if err != nil { - return trace.Wrap(err) - } - - localCA, err := loadAppSelfSignedCA(a.profile, tc, a.app.Name) - if err != nil { - return trace.Wrap(err) - } - - listenAddr := "localhost:0" - if port != "" { - listenAddr = fmt.Sprintf("localhost:%s", port) - } - - // Create a listener that is able to sign certificates when receiving Azure - // requests tunneled from the local forward proxy. - listener, err := alpnproxy.NewCertGenListener(alpnproxy.CertGenListenerConfig{ - ListenAddr: listenAddr, - CA: localCA, - }) - if err != nil { - return trace.Wrap(err) - } - - signer, ok := appCert.PrivateKey.(crypto.Signer) - if !ok { - return trace.BadParameter("private key type %T does not implement crypto.Signer (this is a bug)", appCert.PrivateKey) - } - - a.localALPNProxy, err = alpnproxy.NewLocalProxy( - makeBasicLocalProxyConfig(a.cf, tc, listener), - alpnproxy.WithClientCert(appCert), - alpnproxy.WithClusterCAsIfConnUpgrade(a.cf.Context, tc.RootClusterCACertPool), - alpnproxy.WithHTTPMiddleware(&alpnproxy.AzureMSIMiddleware{ - Key: signer, - Secret: a.msiSecret, - // we could, in principle, get the actual TenantID either from live data or from static configuration, - // but at this moment there is no clear advantage over simply issuing a new random identifier. - TenantID: uuid.New().String(), - ClientID: uuid.New().String(), - Identity: a.app.AzureIdentity, - }), - ) - if err != nil { - if cerr := listener.Close(); cerr != nil { - return trace.NewAggregate(err, cerr) - } - return trace.Wrap(err) - } - - go func() { - if err := a.localALPNProxy.StartHTTPAccessProxy(a.cf.Context); err != nil { - log.WithError(err).Errorf("Failed to start local ALPN proxy.") - } - }() - return nil -} - -// startLocalForwardProxy starts the local forward proxy. -func (a *azureApp) startLocalForwardProxy(port string) error { - listenAddr := "localhost:0" - if port != "" { - listenAddr = fmt.Sprintf("localhost:%s", port) - } - - // Note that the created forward proxy serves HTTP instead of HTTPS, to - // eliminate the need to install temporary CA for various Azure clients. - listener, err := net.Listen("tcp", listenAddr) - if err != nil { - return trace.Wrap(err) - } - - a.localForwardProxy, err = alpnproxy.NewForwardProxy(alpnproxy.ForwardProxyConfig{ - Listener: listener, - CloseContext: a.cf.Context, - Handlers: []alpnproxy.ConnectRequestHandler{ - // Forward Azure requests to ALPN proxy. - alpnproxy.NewForwardToHostHandler(alpnproxy.ForwardToHostHandlerConfig{ - MatchFunc: alpnproxy.MatchAzureRequests, - Host: a.localALPNProxy.GetAddr(), - }), - - // Forward non-Azure requests to user's system proxy, if configured. - alpnproxy.NewForwardToSystemProxyHandler(alpnproxy.ForwardToSystemProxyHandlerConfig{ - InsecureSystemProxy: a.cf.InsecureSkipVerify, - }), - - // Forward non-Azure requests to their original hosts. - alpnproxy.NewForwardToOriginalHostHandler(), - }, - }) - if err != nil { - if cerr := listener.Close(); cerr != nil { - return trace.NewAggregate(err, cerr) - } - return trace.Wrap(err) - } - - go func() { - if err := a.localForwardProxy.Start(); err != nil { - log.WithError(err).Errorf("Failed to start local forward proxy.") - } - }() - return nil -} - func printAzureIdentities(identities []string) { fmt.Println(formatAzureIdentities(identities)) } @@ -386,6 +269,15 @@ func matchAzureApp(app tlsca.RouteToApp) bool { } func pickAzureApp(cf *CLIConf) (*azureApp, error) { - app, err := pickCloudApp(cf, types.CloudAzure, matchAzureApp, newAzureApp) - return app, trace.Wrap(err) + tc, err := makeClient(cf) + if err != nil { + return nil, trace.Wrap(err) + } + + appInfo, err := getAppInfo(cf, tc, matchAzureApp) + if err != nil { + return nil, trace.Wrap(err) + } + + return newAzureApp(tc, cf, appInfo) } diff --git a/tool/tsh/common/app_cloud.go b/tool/tsh/common/app_cloud.go deleted file mode 100644 index 1594b6ce563c8..0000000000000 --- a/tool/tsh/common/app_cloud.go +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package common - -import ( - "strings" - - "github.com/gravitational/trace" - - "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/tlsca" -) - -func defaultValue[t any]() t { - var def t - return def -} - -// pickCloudApp will attempt to find an active cloud app, automatically logging the user to the selected application if possible. -func pickCloudApp[cloudApp any](cf *CLIConf, cloudFriendlyName string, matchRouteToApp func(tlsca.RouteToApp) bool, newCloudApp func(cf *CLIConf, profile *client.ProfileStatus, appRoute tlsca.RouteToApp) (cloudApp, error)) (cloudApp, error) { - app, needLogin, err := pickActiveCloudApp[cloudApp](cf, cloudFriendlyName, matchRouteToApp, newCloudApp) - if err != nil { - if !needLogin { - return defaultValue[cloudApp](), trace.Wrap(err) - } - log.WithError(err).Debugf("Failed to pick an active %v app, attempting to login into app %q", cloudFriendlyName, cf.AppName) - quiet := cf.Quiet - cf.Quiet = true - errLogin := onAppLogin(cf) - cf.Quiet = quiet - if errLogin != nil { - log.WithError(errLogin).Debugf("App login attempt failed") - // combine errors - return defaultValue[cloudApp](), trace.NewAggregate(err, errLogin) - } - // another attempt - app, _, err = pickActiveCloudApp[cloudApp](cf, cloudFriendlyName, matchRouteToApp, newCloudApp) - return app, trace.Wrap(err) - } - return app, nil -} - -func pickActiveCloudApp[cloudApp any](cf *CLIConf, cloudFriendlyName string, matchRouteToApp func(tlsca.RouteToApp) bool, newCloudApp func(cf *CLIConf, profile *client.ProfileStatus, appRoute tlsca.RouteToApp) (cloudApp, error)) (cApp cloudApp, needLogin bool, err error) { - profile, err := cf.ProfileStatus() - if err != nil { - return defaultValue[cloudApp](), false, trace.Wrap(err) - } - if len(profile.Apps) == 0 { - if cf.AppName == "" { - return defaultValue[cloudApp](), false, trace.NotFound("please login to %v app using 'tsh apps login' first", cloudFriendlyName) - } - return defaultValue[cloudApp](), true, trace.NotFound("please login to %v app using 'tsh apps login %v' first", cloudFriendlyName, cf.AppName) - } - name := cf.AppName - if name != "" { - app, err := findApp(profile.Apps, name) - if err != nil { - if trace.IsNotFound(err) { - return defaultValue[cloudApp](), true, trace.NotFound("please login to %v app using 'tsh apps login %v' first", cloudFriendlyName, name) - } - return defaultValue[cloudApp](), false, trace.Wrap(err) - } - if !matchRouteToApp(*app) { - return defaultValue[cloudApp](), false, trace.BadParameter( - "selected app %q is not an %v application", name, cloudFriendlyName, - ) - } - - cApp, err := newCloudApp(cf, profile, *app) - return cApp, false, trace.Wrap(err) - } - - filteredApps := filterApps(matchRouteToApp, profile.Apps) - if len(filteredApps) == 0 { - // no app name to use for attempted login. - return defaultValue[cloudApp](), false, trace.NotFound("please login to %v App using 'tsh apps login' first", cloudFriendlyName) - } - if len(filteredApps) > 1 { - names := strings.Join(getAppNames(filteredApps), ", ") - return defaultValue[cloudApp](), false, trace.BadParameter( - "multiple %v apps are available (%v), please specify one using --app CLI argument", cloudFriendlyName, names, - ) - } - cApp, err = newCloudApp(cf, profile, filteredApps[0]) - return cApp, false, trace.Wrap(err) -} - -func filterApps(matchRouteToApp func(tlsca.RouteToApp) bool, apps []tlsca.RouteToApp) []tlsca.RouteToApp { - var out []tlsca.RouteToApp - for _, app := range apps { - if matchRouteToApp(app) { - out = append(out, app) - } - } - return out -} - -func getAppNames(apps []tlsca.RouteToApp) []string { - var out []string - for _, app := range apps { - out = append(out, app.Name) - } - return out -} - -func findApp(apps []tlsca.RouteToApp, name string) (*tlsca.RouteToApp, error) { - for _, app := range apps { - if app.Name == name { - return &app, nil - } - } - return nil, trace.NotFound("failed to find app with %q name", name) -} diff --git a/tool/tsh/common/app_cloud_test.go b/tool/tsh/common/app_cloud_test.go deleted file mode 100644 index b2a10a7e5cdd5..0000000000000 --- a/tool/tsh/common/app_cloud_test.go +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package common - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/gravitational/teleport/lib/tlsca" -) - -func Test_filterApps(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - matchRouteToApp func(tlsca.RouteToApp) bool - apps []tlsca.RouteToApp - want []tlsca.RouteToApp - }{ - { - name: "aws", - matchRouteToApp: matchAWSApp, - apps: []tlsca.RouteToApp{ - {Name: "none"}, - {Name: "aws1", AWSRoleARN: "dummy"}, - {Name: "aws2", AWSRoleARN: "dummy"}, - {Name: "aws3", AWSRoleARN: "dummy"}, - {Name: "azure", AzureIdentity: "dummy"}, - {Name: "gcp", GCPServiceAccount: "dummy"}, - }, - want: []tlsca.RouteToApp{ - {Name: "aws1", AWSRoleARN: "dummy"}, - {Name: "aws2", AWSRoleARN: "dummy"}, - {Name: "aws3", AWSRoleARN: "dummy"}, - }, - }, - { - name: "azure", - matchRouteToApp: matchAzureApp, - apps: []tlsca.RouteToApp{ - {Name: "none"}, - {Name: "aws", AWSRoleARN: "dummy"}, - {Name: "azure1", AzureIdentity: "dummy"}, - {Name: "azure2", AzureIdentity: "dummy"}, - {Name: "azure3", AzureIdentity: "dummy"}, - {Name: "gcp", GCPServiceAccount: "dummy"}, - }, - want: []tlsca.RouteToApp{ - {Name: "azure1", AzureIdentity: "dummy"}, - {Name: "azure2", AzureIdentity: "dummy"}, - {Name: "azure3", AzureIdentity: "dummy"}, - }, - }, - { - name: "gcp", - matchRouteToApp: matchGCPApp, - apps: []tlsca.RouteToApp{ - {Name: "none"}, - {Name: "aws", AWSRoleARN: "dummy"}, - {Name: "azure", AzureIdentity: "dummy"}, - {Name: "gcp1", GCPServiceAccount: "dummy"}, - {Name: "gcp2", GCPServiceAccount: "dummy"}, - {Name: "gcp3", GCPServiceAccount: "dummy"}, - }, - want: []tlsca.RouteToApp{ - {Name: "gcp1", GCPServiceAccount: "dummy"}, - {Name: "gcp2", GCPServiceAccount: "dummy"}, - {Name: "gcp3", GCPServiceAccount: "dummy"}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - require.Equal(t, tt.want, filterApps(tt.matchRouteToApp, tt.apps)) - }) - } -} diff --git a/tool/tsh/common/app_gcp.go b/tool/tsh/common/app_gcp.go index ffa194b4619ff..641a13d4a336d 100644 --- a/tool/tsh/common/app_gcp.go +++ b/tool/tsh/common/app_gcp.go @@ -19,9 +19,9 @@ package common import ( + "context" "fmt" "hash/fnv" - "net" "os" "os/exec" "path" @@ -32,7 +32,6 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/profile" - "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/asciitable" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" @@ -53,7 +52,7 @@ func onGcloud(cf *CLIConf) error { return trace.Wrap(err) } - err = app.StartLocalProxies() + err = app.StartLocalProxies(cf.Context) if err != nil { return trace.Wrap(err) } @@ -76,7 +75,7 @@ func onGsutil(cf *CLIConf) error { return trace.Wrap(err) } - err = app.StartLocalProxies() + err = app.StartLocalProxies(cf.Context) if err != nil { return trace.Wrap(err) } @@ -95,20 +94,18 @@ func onGsutil(cf *CLIConf) error { // gcpApp is an GCP app that can start local proxies to serve GCP APIs. type gcpApp struct { - cf *CLIConf - profile *client.ProfileStatus - app tlsca.RouteToApp - secret string + *localProxyApp + + cf *CLIConf + + secret string // prefix is a prefix added to the name of configuration files, allowing two instances of gcpApp // to run concurrently without overwriting each other files. prefix string - - localALPNProxy *alpnproxy.LocalProxy - localForwardProxy *alpnproxy.ForwardProxy } // newGCPApp creates a new GCP app. -func newGCPApp(cf *CLIConf, profile *client.ProfileStatus, app tlsca.RouteToApp) (*gcpApp, error) { +func newGCPApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*gcpApp, error) { secret, err := getGCPSecret() if err != nil { return nil, err @@ -119,11 +116,10 @@ func newGCPApp(cf *CLIConf, profile *client.ProfileStatus, app tlsca.RouteToApp) prefix := fmt.Sprintf("%x", h.Sum32()) return &gcpApp{ - cf: cf, - profile: profile, - app: app, - secret: secret, - prefix: prefix, + localProxyApp: newLocalProxyApp(tc, appInfo, cf.LocalProxyPort, cf.InsecureSkipVerify), + cf: cf, + secret: secret, + prefix: prefix, }, nil } @@ -144,39 +140,29 @@ func getGCPSecret() (string, error) { // // The request flow to remote server (i.e. GCP APIs) looks like this: // clients -> local forward proxy -> local ALPN proxy -> remote server -func (a *gcpApp) StartLocalProxies() error { +func (a *gcpApp) StartLocalProxies(ctx context.Context) error { // configuration files if err := a.writeBotoConfig(); err != nil { return trace.Wrap(err) } - // HTTPS proxy mode - if err := a.startLocalALPNProxy(""); err != nil { - return trace.Wrap(err) + gcpMiddleware := &alpnproxy.AuthorizationCheckerMiddleware{ + Secret: a.secret, } - if err := a.startLocalForwardProxy(a.cf.LocalProxyPort); err != nil { - return trace.Wrap(err) - } - return nil + + // HTTPS proxy mode + err := a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchGCPRequests, alpnproxy.WithHTTPMiddleware(gcpMiddleware)) + return trace.Wrap(err) } // Close makes all necessary close calls. func (a *gcpApp) Close() error { - var errs []error - // close proxies - if a.localALPNProxy != nil { - errs = append(errs, a.localALPNProxy.Close()) - } - if a.localForwardProxy != nil { - errs = append(errs, a.localForwardProxy.Close()) - } - // remove boto config - errs = append(errs, a.removeBotoConfig()...) + errs := append([]error{a.localProxyApp.Close()}, a.removeBotoConfig()...) return trace.NewAggregate(errs...) } func (a *gcpApp) getGcloudConfigPath() string { - return path.Join(profile.FullProfilePath(a.cf.HomePath), "gcp", a.app.ClusterName, a.app.Name, "gcloud") + return path.Join(profile.FullProfilePath(a.cf.HomePath), "gcp", a.appInfo.RouteToApp.ClusterName, a.appInfo.RouteToApp.Name, "gcloud") } // removeBotoConfig removes config files written by WriteBotoConfig. @@ -189,7 +175,7 @@ func (a *gcpApp) removeBotoConfig() []error { } func (a *gcpApp) getBotoConfigDir() string { - return path.Join(profile.FullProfilePath(a.cf.HomePath), "gcp", a.app.ClusterName, a.app.Name) + return path.Join(profile.FullProfilePath(a.cf.HomePath), "gcp", a.appInfo.RouteToApp.ClusterName, a.appInfo.RouteToApp.Name) } func (a *gcpApp) getBotoConfigPath() string { @@ -238,7 +224,7 @@ func (a *gcpApp) writeBotoConfig() error { // GetEnvVars returns required environment variables to configure the // clients. func (a *gcpApp) GetEnvVars() (map[string]string, error) { - projectID, err := gcp.ProjectIDFromServiceAccountName(a.app.GCPServiceAccount) + projectID, err := gcp.ProjectIDFromServiceAccountName(a.appInfo.RouteToApp.GCPServiceAccount) if err != nil { return nil, trace.Wrap(err) } @@ -250,7 +236,7 @@ func (a *gcpApp) GetEnvVars() (map[string]string, error) { // Set core.custom_ca_certs_file via env variable, customizing the path to CA certs file. // https://cloud.google.com/sdk/gcloud/reference/config/set#:~:text=custom_ca_certs_file - "CLOUDSDK_CORE_CUSTOM_CA_CERTS_FILE": a.profile.AppLocalCAPath(a.cf.SiteName, a.app.Name), + "CLOUDSDK_CORE_CUSTOM_CA_CERTS_FILE": a.appInfo.appLocalCAPath(a.cf.SiteName), // We need to set project ID. This is sourced from the account name. // https://cloud.google.com/sdk/gcloud/reference/config#GROUP:~:text=authentication%20to%20gsutil.-,project,-Project%20ID%20of @@ -295,110 +281,6 @@ func (a *gcpApp) RunCommand(cmd *exec.Cmd) error { return nil } -// startLocalALPNProxy starts the local ALPN proxy. -func (a *gcpApp) startLocalALPNProxy(port string) error { - tc, err := makeClient(a.cf) - if err != nil { - return trace.Wrap(err) - } - - appCert, err := loadAppCertificateWithAppLogin(a.cf, tc, a.app.Name) - if err != nil { - return trace.Wrap(err) - } - - localCA, err := loadAppSelfSignedCA(a.profile, tc, a.app.Name) - if err != nil { - return trace.Wrap(err) - } - - listenAddr := "localhost:0" - if port != "" { - listenAddr = fmt.Sprintf("localhost:%s", port) - } - - // Create a listener that is able to sign certificates when receiving GCP - // requests tunneled from the local forward proxy. - listener, err := alpnproxy.NewCertGenListener(alpnproxy.CertGenListenerConfig{ - ListenAddr: listenAddr, - CA: localCA, - }) - if err != nil { - return trace.Wrap(err) - } - - a.localALPNProxy, err = alpnproxy.NewLocalProxy( - makeBasicLocalProxyConfig(a.cf, tc, listener), - alpnproxy.WithClientCert(appCert), - alpnproxy.WithClusterCAsIfConnUpgrade(a.cf.Context, tc.RootClusterCACertPool), - alpnproxy.WithHTTPMiddleware(&alpnproxy.AuthorizationCheckerMiddleware{ - Secret: a.secret, - }), - ) - - if err != nil { - if cerr := listener.Close(); cerr != nil { - return trace.NewAggregate(err, cerr) - } - return trace.Wrap(err) - } - - go func() { - if err := a.localALPNProxy.StartHTTPAccessProxy(a.cf.Context); err != nil { - log.WithError(err).Errorf("Failed to start local ALPN proxy.") - } - }() - return nil -} - -// startLocalForwardProxy starts the local forward proxy. -func (a *gcpApp) startLocalForwardProxy(port string) error { - listenAddr := "localhost:0" - if port != "" { - listenAddr = fmt.Sprintf("localhost:%s", port) - } - - // Note that the created forward proxy serves HTTP instead of HTTPS, to - // eliminate the need to install temporary CA for various GCP clients. - listener, err := net.Listen("tcp", listenAddr) - if err != nil { - return trace.Wrap(err) - } - - a.localForwardProxy, err = alpnproxy.NewForwardProxy(alpnproxy.ForwardProxyConfig{ - Listener: listener, - CloseContext: a.cf.Context, - Handlers: []alpnproxy.ConnectRequestHandler{ - // Forward GCP requests to ALPN proxy. - alpnproxy.NewForwardToHostHandler(alpnproxy.ForwardToHostHandlerConfig{ - MatchFunc: alpnproxy.MatchGCPRequests, - Host: a.localALPNProxy.GetAddr(), - }), - - // Forward non-GCP requests to user's system proxy, if configured. - alpnproxy.NewForwardToSystemProxyHandler(alpnproxy.ForwardToSystemProxyHandlerConfig{ - InsecureSystemProxy: a.cf.InsecureSkipVerify, - }), - - // Forward non-GCP requests to their original hosts. - alpnproxy.NewForwardToOriginalHostHandler(), - }, - }) - if err != nil { - if cerr := listener.Close(); cerr != nil { - return trace.NewAggregate(err, cerr) - } - return trace.Wrap(err) - } - - go func() { - if err := a.localForwardProxy.Start(); err != nil { - log.WithError(err).Errorf("Failed to start local forward proxy.") - } - }() - return nil -} - func printGCPServiceAccounts(accounts []string) { fmt.Println(formatGCPServiceAccounts(accounts)) } @@ -482,6 +364,15 @@ func matchGCPApp(app tlsca.RouteToApp) bool { } func pickGCPApp(cf *CLIConf) (*gcpApp, error) { - app, err := pickCloudApp(cf, types.CloudGCP, matchGCPApp, newGCPApp) - return app, trace.Wrap(err) + tc, err := makeClient(cf) + if err != nil { + return nil, trace.Wrap(err) + } + + appInfo, err := getAppInfo(cf, tc, matchGCPApp) + if err != nil { + return nil, trace.Wrap(err) + } + + return newGCPApp(tc, cf, appInfo) } diff --git a/tool/tsh/common/app_gcp_test.go b/tool/tsh/common/app_gcp_test.go index 0e277fd8f836b..5a84950327486 100644 --- a/tool/tsh/common/app_gcp_test.go +++ b/tool/tsh/common/app_gcp_test.go @@ -26,8 +26,8 @@ import ( "github.com/gravitational/trace" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/tlsca" ) func Test_getGCPServiceAccountFromFlags(t *testing.T) { @@ -186,7 +186,7 @@ test-0@other-999999.iam.gserviceaccount.com func Test_gcpApp_Config(t *testing.T) { cf := &CLIConf{HomePath: t.TempDir()} profile := &client.ProfileStatus{} - route := tlsca.RouteToApp{ + route := proto.RouteToApp{ ClusterName: "test.teleport.io", Name: "myapp", GCPServiceAccount: "test@myproject-123456.iam.gserviceaccount.com", @@ -194,7 +194,10 @@ func Test_gcpApp_Config(t *testing.T) { t.Setenv("TELEPORT_GCLOUD_SECRET", "my_secret") - app, err := newGCPApp(cf, profile, route) + app, err := newGCPApp(nil, cf, &appInfo{ + RouteToApp: route, + profile: profile, + }) require.NoError(t, err) require.NotNil(t, app) diff --git a/tool/tsh/common/app_local_proxy.go b/tool/tsh/common/app_local_proxy.go new file mode 100644 index 0000000000000..ab7fba8de5a86 --- /dev/null +++ b/tool/tsh/common/app_local_proxy.go @@ -0,0 +1,208 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package common + +import ( + "cmp" + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/srv/alpnproxy" +) + +// localProxyApp is a generic app that can start local proxies. +type localProxyApp struct { + tc *client.TeleportClient + appInfo *appInfo + insecure bool + port string + + localALPNProxy *alpnproxy.LocalProxy + localForwardProxy *alpnproxy.ForwardProxy +} + +type requestMatcher func(req *http.Request) bool + +// newLocalProxyApp creates a new generic app. +func newLocalProxyApp(tc *client.TeleportClient, appInfo *appInfo, port string, insecure bool) *localProxyApp { + return &localProxyApp{ + tc: tc, + appInfo: appInfo, + port: port, + insecure: insecure, + } +} + +// StartLocalProxy sets up local proxies for serving app clients. +func (a *localProxyApp) StartLocalProxy(ctx context.Context, opts ...alpnproxy.LocalProxyConfigOpt) error { + if err := a.startLocalALPNProxy(ctx, a.port, false /*withTLS*/, opts...); err != nil { + return trace.Wrap(err) + } + + if a.port == "" { + fmt.Println("To avoid port randomization, you can choose the listening port using the --port flag.") + } + return nil +} + +// StartLocalProxy sets up local proxies for serving app clients. +func (a *localProxyApp) StartLocalProxyWithTLS(ctx context.Context, opts ...alpnproxy.LocalProxyConfigOpt) error { + if err := a.startLocalALPNProxy(ctx, a.port, true /*withTLS*/, opts...); err != nil { + return trace.Wrap(err) + } + + if a.port == "" { + fmt.Println("To avoid port randomization, you can choose the listening port using the --port flag.") + } + return nil +} + +// StartLocalProxy sets up local proxies for serving app clients. +func (a *localProxyApp) StartLocalProxyWithForwarder(ctx context.Context, forwardMatcher requestMatcher, opts ...alpnproxy.LocalProxyConfigOpt) error { + if err := a.startLocalALPNProxy(ctx, "", true /*withTLS*/, opts...); err != nil { + return trace.Wrap(err) + } + + if err := a.startLocalForwardProxy(ctx, a.port, forwardMatcher); err != nil { + return trace.Wrap(err) + } + + if a.port == "" { + fmt.Println("To avoid port randomization, you can choose the listening port using the --port flag.") + } + return nil +} + +// Close makes all necessary close calls. +func (a *localProxyApp) Close() error { + var errs []error + if a.localALPNProxy != nil { + errs = append(errs, a.localALPNProxy.Close()) + } + if a.localForwardProxy != nil { + errs = append(errs, a.localForwardProxy.Close()) + } + return trace.NewAggregate(errs...) +} + +// startLocalALPNProxy starts the local ALPN proxy. +func (a *localProxyApp) startLocalALPNProxy(ctx context.Context, port string, withTLS bool, opts ...alpnproxy.LocalProxyConfigOpt) error { + // Create an app cert checker to check and reissue app certs for the local app proxy. + appCertChecker := client.NewAppCertChecker(a.tc, a.appInfo.RouteToApp, nil) + + // If a stored cert is found for the app, try using it. + // Otherwise, let the checker reissue one as needed. + cert, err := loadAppCertificate(a.tc, a.appInfo.RouteToApp.Name) + if err == nil { + appCertChecker.SetCert(cert) + } + + listenAddr := fmt.Sprintf("localhost:%s", cmp.Or(port, "0")) + + var listener net.Listener + if withTLS { + appLocalCAPath := a.appInfo.appLocalCAPath(a.tc.SiteName) + localCertGenerator, err := client.NewLocalCertGenerator(ctx, appCertChecker, appLocalCAPath) + if err != nil { + return trace.Wrap(err) + } + + if listener, err = tls.Listen("tcp", listenAddr, &tls.Config{ + GetCertificate: localCertGenerator.GetCertificate, + }); err != nil { + return trace.Wrap(err) + } + } else { + if listener, err = net.Listen("tcp", listenAddr); err != nil { + return trace.Wrap(err) + } + } + + a.localALPNProxy, err = alpnproxy.NewLocalProxy( + makeBasicLocalProxyConfig(ctx, a.tc, listener, a.insecure), + append(opts, + alpnproxy.WithClusterCAsIfConnUpgrade(ctx, a.tc.RootClusterCACertPool), + alpnproxy.WithMiddleware(appCertChecker), + )..., + ) + if err != nil { + if cerr := listener.Close(); cerr != nil { + return trace.NewAggregate(err, cerr) + } + return trace.Wrap(err) + } + + fmt.Printf("Proxying connections to %s on %v\n", a.appInfo.RouteToApp.Name, a.localALPNProxy.GetAddr()) + + go func() { + if err = a.localALPNProxy.Start(ctx); err != nil { + log.WithError(err).Errorf("Failed to start local ALPN proxy.") + } + }() + return nil +} + +// startLocalForwardProxy starts a local forward proxy that forwards matching requests +// to the local ALPN proxy and unmatched requests to their original hosts. +func (a *localProxyApp) startLocalForwardProxy(ctx context.Context, port string, forwardMatcher requestMatcher) error { + listenAddr := fmt.Sprintf("localhost:%s", cmp.Or(port, "0")) + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + return trace.Wrap(err) + } + + a.localForwardProxy, err = alpnproxy.NewForwardProxy(alpnproxy.ForwardProxyConfig{ + Listener: listener, + CloseContext: ctx, + Handlers: []alpnproxy.ConnectRequestHandler{ + // Forward matched requests to ALPN proxy. + alpnproxy.NewForwardToHostHandler(alpnproxy.ForwardToHostHandlerConfig{ + MatchFunc: forwardMatcher, + Host: a.localALPNProxy.GetAddr(), + }), + + // Forward unmatched requests to user's system proxy, if configured. + alpnproxy.NewForwardToSystemProxyHandler(alpnproxy.ForwardToSystemProxyHandlerConfig{ + InsecureSystemProxy: a.insecure, + }), + + // Forward unmatched requests to their original hosts. + alpnproxy.NewForwardToOriginalHostHandler(), + }, + }) + if err != nil { + if cerr := listener.Close(); cerr != nil { + return trace.NewAggregate(err, cerr) + } + return trace.Wrap(err) + } + + go func() { + if err := a.localForwardProxy.Start(); err != nil { + log.WithError(err).Errorf("Failed to start local forward proxy.") + } + }() + return nil +} diff --git a/tool/tsh/common/app_test.go b/tool/tsh/common/app_test.go index 1fd57a2234b5d..76c6e54eedc1e 100644 --- a/tool/tsh/common/app_test.go +++ b/tool/tsh/common/app_test.go @@ -60,7 +60,7 @@ func startDummyHTTPServer(t *testing.T, name string) string { return srv.URL } -func testDummyAppConn(t require.TestingT, name string, addr string, tlsCerts ...tls.Certificate) { +func testDummyAppConn(t require.TestingT, addr string, tlsCerts ...tls.Certificate) (resp *http.Response) { clt := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ @@ -72,12 +72,7 @@ func testDummyAppConn(t require.TestingT, name string, addr string, tlsCerts ... resp, err := clt.Get(addr) assert.NoError(t, err) - if err != nil { - return - } - assert.Equal(t, 200, resp.StatusCode) - assert.Equal(t, name, resp.Header.Get("Server")) - _ = resp.Body.Close() + return resp } // TestAppCommands tests the following basic app command functionality for registered root and leaf apps. @@ -261,7 +256,10 @@ func TestAppCommands(t *testing.T) { clientCert, err := tls.LoadX509KeyPair(info.Cert, info.Key) require.NoError(t, err) - testDummyAppConn(t, app.name, fmt.Sprintf("https://%v", rootProxyAddr.Addr), clientCert) + resp := testDummyAppConn(t, fmt.Sprintf("https://%v", rootProxyAddr.Addr), clientCert) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, app.name, resp.Header.Get("Server")) // app logout. err = Run(ctx, []string{ @@ -291,20 +289,25 @@ func TestAppCommands(t *testing.T) { }() assert.EventuallyWithT(t, func(t *assert.CollectT) { - testDummyAppConn(t, app.name, fmt.Sprintf("http://127.0.0.1:%v", localProxyPort)) + resp := testDummyAppConn(t, fmt.Sprintf("http://127.0.0.1:%v", localProxyPort)) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, app.name, resp.Header.Get("Server")) + resp.Body.Close() }, 10*time.Second, time.Second) proxyCancel() assert.NoError(t, <-errC) - // proxy certs should not be saved to disk. - err = Run(context.Background(), []string{ - "app", - "config", - app.name, - "--cluster", app.cluster, - }, setHomePath(loginPath)) - assert.True(t, trace.IsNotFound(err), "expected not found error but got: %v", err) + // proxy certs should not be saved to disk if mfa was used.. + if requireMFAType == types.RequireMFAType_SESSION { + err = Run(context.Background(), []string{ + "app", + "config", + app.name, + "--cluster", app.cluster, + }, setHomePath(loginPath)) + assert.True(t, trace.IsNotFound(err), "expected not found error but got: %v", err) + } }) }) } diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index 0acf41d1bc1f8..1301f6fb34f44 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -628,7 +628,7 @@ func maybeStartLocalProxy(ctx context.Context, cf *CLIConf, return nil, trace.Wrap(err) } - lp, err := alpnproxy.NewLocalProxy(makeBasicLocalProxyConfig(cf, tc, listener), opts...) + lp, err := alpnproxy.NewLocalProxy(makeBasicLocalProxyConfig(cf.Context, tc, listener, cf.InsecureSkipVerify), opts...) if err != nil { return nil, trace.Wrap(err) } @@ -706,7 +706,7 @@ func prepareLocalProxyOptions(arg *localProxyConfig) ([]alpnproxy.LocalProxyConf // proxy starts instead. cert, err := loadDBCertificate(arg.tc, arg.dbInfo.ServiceName) if err == nil { - opts = append(opts, alpnproxy.WithClientCert(cert)) + cc.SetCert(cert) } return opts, nil } @@ -1020,7 +1020,7 @@ func (d *databaseInfo) GetDatabase(ctx context.Context, tc *client.TeleportClien d.mu.Lock() defer d.mu.Unlock() if d.database != nil { - return d.database, nil + return d.database.Copy(), nil } // holding mutex across the api call to avoid multiple redundant api calls. database, err := getDatabase(ctx, tc, d.ServiceName) @@ -1028,7 +1028,7 @@ func (d *databaseInfo) GetDatabase(ctx context.Context, tc *client.TeleportClien return nil, trace.Wrap(err) } d.database = database - return d.database, nil + return d.database.Copy(), nil } // chooseOneDatabase is a helper func that returns either the only database in a diff --git a/tool/tsh/common/hardware_key_test.go b/tool/tsh/common/hardware_key_test.go index 1fb819d55b9d8..7e4bcd4cbc233 100644 --- a/tool/tsh/common/hardware_key_test.go +++ b/tool/tsh/common/hardware_key_test.go @@ -26,9 +26,9 @@ import ( "crypto/tls" "encoding/json" "fmt" + "net/http" "os/user" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -350,7 +350,7 @@ func TestHardwareKeyApp(t *testing.T) { require.NoError(t, err) registerDeviceForUser(t, authServer, device, accessUser.GetName(), origin) - // Login before adding hardware key requirement + // Login before adding hardware key requirement and verify we can connect to the app. tmpHomePath := t.TempDir() err = Run(ctx, []string{ "login", @@ -359,7 +359,36 @@ func TestHardwareKeyApp(t *testing.T) { }, setHomePath(tmpHomePath), setMockSSOLogin(authServer, accessUser, connector.GetName())) require.NoError(t, err) - // Require hardware key touch for the user. + err = Run(ctx, []string{ + "app", + "login", + "myapp", + "--insecure", + "--proxy", proxyAddr.String(), + }, setHomePath(tmpHomePath)) + require.NoError(t, err) + + confOut := new(bytes.Buffer) + err = Run(ctx, []string{ + "app", + "config", + "myapp", + "--format", "json", + }, setHomePath(tmpHomePath), setOverrideStdout(confOut)) + require.NoError(t, err) + + var info appConfigInfo + require.NoError(t, json.Unmarshal(confOut.Bytes(), &info)) + + clientCert, err := tls.LoadX509KeyPair(info.Cert, info.Key) + require.NoError(t, err) + + resp := testDummyAppConn(t, fmt.Sprintf("https://%v", proxyAddr.Addr), clientCert) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "myapp", resp.Header.Get("Server")) + resp.Body.Close() + + // Require hardware key touch for the user. The user's current app certs should fail. accessRole, err := authServer.GetRole(ctx, "access") require.NoError(t, err) accessRole.SetOptions(types.RoleOptions{ @@ -369,7 +398,10 @@ func TestHardwareKeyApp(t *testing.T) { require.NoError(t, err) testModules.MockAttestationData = nil - tmpHomePath = t.TempDir() + + resp = testDummyAppConn(t, fmt.Sprintf("https://%v", proxyAddr.Addr), clientCert) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + resp.Body.Close() // App login fails without an attested hardware key login. err = Run(ctx, []string{ @@ -381,25 +413,12 @@ func TestHardwareKeyApp(t *testing.T) { }, setHomePath(tmpHomePath), setMockSSOLogin(authServer, accessUser, connector.GetName())) require.Error(t, err) - // Proxy app fails without an attested hardware key login. - proxyCtx, proxyCancel := context.WithTimeout(ctx, 10*time.Second) - defer proxyCancel() - - err = Run(proxyCtx, []string{ - "proxy", - "app", - "myapp", - "--insecure", - "--proxy", proxyAddr.String(), - }, setHomePath(tmpHomePath), setMockSSOLogin(authServer, accessUser, connector.GetName())) - require.Error(t, err) - // Set MockAttestationData to attest the expected key policy and try again. testModules.MockAttestationData = &keys.AttestationData{ PrivateKeyPolicy: keys.PrivateKeyPolicyHardwareKeyTouch, } - // App commands will still fail without MFA, since the app sessions will + // App Login will still fail without MFA, since the app sessions will // only be attested as "web_session". webauthnLoginOpt := setupWebAuthnChallengeSolver(device, false /* success */) @@ -412,22 +431,10 @@ func TestHardwareKeyApp(t *testing.T) { }, setHomePath(tmpHomePath), setMockSSOLogin(authServer, accessUser, connector.GetName()), webauthnLoginOpt) require.Error(t, err) - proxyCtx, proxyCancel = context.WithTimeout(ctx, 10*time.Second) - defer proxyCancel() - - err = Run(proxyCtx, []string{ - "proxy", - "app", - "myapp", - "--insecure", - "--proxy", proxyAddr.String(), - }, setHomePath(tmpHomePath), setMockSSOLogin(authServer, accessUser, connector.GetName()), webauthnLoginOpt) - require.Error(t, err) - // App commands will succeed with MFA. webauthnLoginOpt = setupWebAuthnChallengeSolver(device, true /* success */) - // Test App login success. + // Test App login success and connect. err = Run(ctx, []string{ "app", "login", @@ -437,8 +444,7 @@ func TestHardwareKeyApp(t *testing.T) { }, setHomePath(tmpHomePath), setMockSSOLogin(authServer, accessUser, connector.GetName()), webauthnLoginOpt) require.NoError(t, err) - // Retrieve the app login config (private key, ca, and cert). - confOut := new(bytes.Buffer) + confOut = new(bytes.Buffer) err = Run(ctx, []string{ "app", "config", @@ -447,36 +453,13 @@ func TestHardwareKeyApp(t *testing.T) { }, setHomePath(tmpHomePath), setOverrideStdout(confOut)) require.NoError(t, err) - // Verify that we can connect to the app using the generated app cert. - var info appConfigInfo require.NoError(t, json.Unmarshal(confOut.Bytes(), &info)) - clientCert, err := tls.LoadX509KeyPair(info.Cert, info.Key) + clientCert, err = tls.LoadX509KeyPair(info.Cert, info.Key) require.NoError(t, err) - testDummyAppConn(t, "myapp", fmt.Sprintf("https://%v", proxyAddr.Addr), clientCert) - - // Test Proxy app success. - localProxyPort := ports.Pop() - proxyCtx, proxyCancel = context.WithTimeout(ctx, 10*time.Second) - defer proxyCancel() - - errC := make(chan error) - go func() { - errC <- Run(proxyCtx, []string{ - "proxy", - "app", - "myapp", - "--port", localProxyPort, - "--insecure", - "--proxy", proxyAddr.String(), - }, setHomePath(tmpHomePath), setMockSSOLogin(authServer, accessUser, connector.GetName()), webauthnLoginOpt) - }() - - assert.EventuallyWithT(t, func(t *assert.CollectT) { - testDummyAppConn(t, "myapp", fmt.Sprintf("http://127.0.0.1:%v", localProxyPort)) - }, 10*time.Second, time.Second) - - proxyCancel() - assert.NoError(t, <-errC) + resp = testDummyAppConn(t, fmt.Sprintf("https://%v", proxyAddr.Addr), clientCert) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "myapp", resp.Header.Get("Server")) + resp.Body.Close() } diff --git a/tool/tsh/common/kube_proxy.go b/tool/tsh/common/kube_proxy.go index d2b23b1f73612..4dc8845b4b23d 100644 --- a/tool/tsh/common/kube_proxy.go +++ b/tool/tsh/common/kube_proxy.go @@ -345,7 +345,7 @@ func makeKubeLocalProxy(cf *CLIConf, tc *client.TeleportClient, clusters kubecon }) localProxy, err := alpnproxy.NewLocalProxy( - makeBasicLocalProxyConfig(cf, tc, lpListener), + makeBasicLocalProxyConfig(cf.Context, tc, lpListener, cf.InsecureSkipVerify), alpnproxy.WithHTTPMiddleware(kubeMiddleware), alpnproxy.WithSNI(client.GetKubeTLSServerName(tc.WebProxyHost())), alpnproxy.WithClusterCAs(cf.Context, tc.RootClusterCACertPool), @@ -388,7 +388,7 @@ func (k *kubeLocalProxy) Start(ctx context.Context) error { errChan <- k.forwardProxy.Start() }() go func() { - errChan <- k.localProxy.StartHTTPAccessProxy(ctx) + errChan <- k.localProxy.Start(ctx) }() select { diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index abe25adb60976..53866dede32ec 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -30,7 +30,6 @@ import ( "strconv" "strings" "text/template" - "time" "unicode" "github.com/gravitational/trace" @@ -211,7 +210,7 @@ func onProxyCommandDB(cf *CLIConf) error { return trace.Wrap(err) } - lp, err := alpnproxy.NewLocalProxy(makeBasicLocalProxyConfig(cf, tc, listener), proxyOpts...) + lp, err := alpnproxy.NewLocalProxy(makeBasicLocalProxyConfig(cf.Context, tc, listener, cf.InsecureSkipVerify), proxyOpts...) if err != nil { return trace.Wrap(err) } @@ -382,80 +381,29 @@ func onProxyCommandApp(cf *CLIConf) error { return trace.Wrap(err) } - app, err := getRegisteredApp(cf, tc) + appInfo, err := getAppInfo(cf, tc, nil /*matchRouteToApp*/) if err != nil { return trace.Wrap(err) } - profile, err := tc.ProfileStatus() + app, err := appInfo.GetApp(cf.Context, tc) if err != nil { return trace.Wrap(err) } - routeToApp, err := getRouteToApp(cf, tc, profile, app) - if err != nil { + proxyApp := newLocalProxyApp(tc, appInfo, cf.LocalProxyPort, cf.InsecureSkipVerify) + if err := proxyApp.StartLocalProxy(cf.Context, alpnproxy.WithALPNProtocol(alpnProtocolForApp(app))); err != nil { return trace.Wrap(err) } - opts := []alpnproxy.LocalProxyConfigOpt{ - alpnproxy.WithALPNProtocol(alpnProtocolForApp(app)), - alpnproxy.WithClusterCAsIfConnUpgrade(cf.Context, tc.RootClusterCACertPool), - alpnproxy.WithMiddleware(libclient.NewAppCertChecker(tc, routeToApp, nil)), - } - - // Virtual profiles (e.g. indirect use via `tbot proxy app`) will attempt - // relogin which is not possible. For these, we'll need to load the app - // certificate manually and prepend the config option. - // TODO(timothyb89): Remove this workaround in favor of - // https://github.com/gravitational/teleport/pull/40985 once it is merged. - if profile.IsVirtual { - cert, needLogin, err := loadAppCertificate(tc, app.GetName()) - if err != nil { - return trace.Wrap(err) - } - - if needLogin { - return trace.BadParameter("app identity requires relogin but this is impossible with a virtual profile") - } - - opts = append([]alpnproxy.LocalProxyConfigOpt{ - alpnproxy.WithClientCert(cert), - }, opts...) - } - - addr := "localhost:0" - if cf.LocalProxyPort != "" { - addr = fmt.Sprintf("127.0.0.1:%s", cf.LocalProxyPort) - } - - listener, err := net.Listen("tcp", addr) - if err != nil { - return trace.Wrap(err) - } - - lp, err := alpnproxy.NewLocalProxy(makeBasicLocalProxyConfig(cf, tc, listener), opts...) - if err != nil { - if cerr := listener.Close(); cerr != nil { - return trace.NewAggregate(err, cerr) + defer func() { + if err := proxyApp.Close(); err != nil { + log.WithError(err).Error("Failed to close app proxy.") } - return trace.Wrap(err) - } - - fmt.Printf("Proxying connections to %s on %v\n", cf.AppName, lp.GetAddr()) - if cf.LocalProxyPort == "" { - fmt.Println("To avoid port randomization, you can choose the listening port using the --port flag.") - } - - go func() { - <-cf.Context.Done() - lp.Close() }() - defer lp.Close() - if err = lp.Start(cf.Context); err != nil { - return trace.Wrap(err) - } - + // Proxy connections until the client terminates the command. + <-cf.Context.Done() return nil } @@ -470,7 +418,7 @@ func onProxyCommandAWS(cf *CLIConf) error { return trace.Wrap(err) } - err = awsApp.StartLocalProxies() + err = awsApp.StartLocalProxies(cf.Context) if err != nil { return trace.Wrap(err) } @@ -484,6 +432,7 @@ func onProxyCommandAWS(cf *CLIConf) error { if err := printProxyAWSTemplate(cf, awsApp); err != nil { return trace.Wrap(err) } + <-cf.Context.Done() return nil } @@ -562,7 +511,7 @@ func onProxyCommandAzure(cf *CLIConf) error { return trace.Wrap(err) } - err = azApp.StartLocalProxies() + err = azApp.StartLocalProxies(cf.Context) if err != nil { return trace.Wrap(err) } @@ -593,7 +542,7 @@ func onProxyCommandGCloud(cf *CLIConf) error { return trace.Wrap(err) } - err = gcpApp.StartLocalProxies() + err = gcpApp.StartLocalProxies(cf.Context) if err != nil { return trace.Wrap(err) } @@ -617,59 +566,20 @@ func onProxyCommandGCloud(cf *CLIConf) error { return nil } -// loadAppCertificateWithAppLogin is a wrapper around loadAppCertificate that will attempt to login the user to -// the app of choice at most once, if the return value from loadAppCertificate call indicates that app login -// should fix the problem. -func loadAppCertificateWithAppLogin(cf *CLIConf, tc *libclient.TeleportClient, appName string) (tls.Certificate, error) { - cert, needLogin, err := loadAppCertificate(tc, appName) - if err != nil { - if !needLogin { - return tls.Certificate{}, trace.Wrap(err) - } - log.WithError(err).Debugf("Loading app certificate failed, attempting to login to app %q", appName) - quiet := cf.Quiet - cf.Quiet = true - errLogin := onAppLogin(cf) - cf.Quiet = quiet - if errLogin != nil { - log.WithError(errLogin).Debugf("Login attempt failed") - // combine errors - return tls.Certificate{}, trace.NewAggregate(err, errLogin) - } - // another attempt - cert, _, err = loadAppCertificate(tc, appName) - return cert, trace.Wrap(err) - } - return cert, nil -} - -// loadAppCertificate loads the app certificate for the provided app. -// Returns tuple (certificate, needLogin, err). -// The boolean `needLogin` will be true if the error returned should go away with successful `tsh app login `. -func loadAppCertificate(tc *libclient.TeleportClient, appName string) (certificate tls.Certificate, needLogin bool, err error) { +func loadAppCertificate(tc *libclient.TeleportClient, appName string) (tls.Certificate, error) { key, err := tc.LocalAgent().GetKey(tc.SiteName, libclient.WithAppCerts{}) if err != nil { - return tls.Certificate{}, false, trace.Wrap(err) + return tls.Certificate{}, trace.Wrap(err) } appCert, err := key.AppTLSCert(appName) if trace.IsNotFound(err) { - return tls.Certificate{}, true, trace.NotFound("please login into the application first: 'tsh apps login %v'", appName) + return tls.Certificate{}, trace.NotFound("please login into the application first: 'tsh apps login %v'", appName) } else if err != nil { - return tls.Certificate{}, false, trace.Wrap(err) - } - - expiresAt, err := getTLSCertExpireTime(appCert) - if err != nil { - return tls.Certificate{}, true, trace.WrapWithMessage(err, "invalid certificate - please login to the application again: 'tsh apps login %v'", appName) - } - if time.Until(expiresAt) < 5*time.Second { - return tls.Certificate{}, true, trace.BadParameter( - "application %s certificate has expired, please re-login to the app using 'tsh apps login %v'", appName, - appName) + return tls.Certificate{}, trace.Wrap(err) } - return appCert, false, nil + return appCert, nil } func loadDBCertificate(tc *libclient.TeleportClient, dbName string) (tls.Certificate, error) { @@ -688,15 +598,6 @@ func loadDBCertificate(tc *libclient.TeleportClient, dbName string) (tls.Certifi return dbCert, nil } -// getTLSCertExpireTime returns the certificate NotAfter time. -func getTLSCertExpireTime(cert tls.Certificate) (time.Time, error) { - x509cert, err := utils.TLSCertLeaf(cert) - if err != nil { - return time.Time{}, trace.Wrap(err) - } - return x509cert.NotAfter, nil -} - func getEnvOrDefault(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value @@ -704,11 +605,11 @@ func getEnvOrDefault(key, defaultValue string) string { return defaultValue } -func makeBasicLocalProxyConfig(cf *CLIConf, tc *libclient.TeleportClient, listener net.Listener) alpnproxy.LocalProxyConfig { +func makeBasicLocalProxyConfig(ctx context.Context, tc *libclient.TeleportClient, listener net.Listener, insecure bool) alpnproxy.LocalProxyConfig { return alpnproxy.LocalProxyConfig{ RemoteProxyAddr: tc.WebProxyAddr, - InsecureSkipVerify: cf.InsecureSkipVerify, - ParentContext: cf.Context, + InsecureSkipVerify: insecure, + ParentContext: ctx, Listener: listener, ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired, }