From 0570250d4faa946c5c9e6afa051d4219d9bb90c9 Mon Sep 17 00:00:00 2001 From: Hugo Shaka Date: Thu, 25 Jul 2024 15:32:56 -0400 Subject: [PATCH] [v16] Support detecting a credential expiry (#44080) * Support detecting a credential expiry (#43886) * Support detecting a credential expiry * Fix Expiry() and add tests * protect identityFile cred with a lock * add godoc * Fix tests * Do not break previous Credentials implementations * fixup! Do not break previous Credentials implementations * address edoardo's comments * fix tests --- api/client/credentials.go | 101 +++++++++++++++++- api/client/credentials_test.go | 56 ++++++++++ api/client/proxy/client.go | 5 + api/identityfile/identityfile.go | 13 +++ api/profile/profile.go | 14 +++ api/utils/keys/privatekey.go | 59 ++++++---- api/utils/keys/privatekey_test.go | 66 ++++++++++++ .../lib/credentials/credentials_test.go | 4 + lib/tbot/config/service_client_credential.go | 15 ++- lib/tbot/identity/identity_facade.go | 14 +++ 10 files changed, 321 insertions(+), 26 deletions(-) diff --git a/api/client/credentials.go b/api/client/credentials.go index d6b824182f94a..f9f3b67cc45d1 100644 --- a/api/client/credentials.go +++ b/api/client/credentials.go @@ -22,6 +22,7 @@ import ( "crypto/x509" "os" "sync" + "time" "github.com/gravitational/trace" "golang.org/x/crypto/ssh" @@ -40,6 +41,10 @@ import ( // also provide other functionality, such as automatic address discovery and // ssh connectivity. // +// Note: starting with v17, all Credentials must have an Expiry method. +// For compatibility guarantees, the future interface is optional and called +// CredentialsWithExpiry. +// // See the examples below for an example of each loader. type Credentials interface { // TLSConfig returns TLS configuration used to authenticate the client. @@ -49,6 +54,21 @@ type Credentials interface { SSHClientConfig() (*ssh.ClientConfig, error) } +// CredentialsWithExpiry are credentials implementing the Expiry() function. +// This interface is here to avoid breaking changes in v15 and v16. Starting with +// v17, Expiry() will be part of the Credentials interface. +type CredentialsWithExpiry interface { + Credentials + + // Expiry returns the Credentials expiry if it's possible to know its expiry. + // When expiry can be determined returns true, else returns false. + // If the Credentials don't expire, returns the zero time. + // If the Credential is dynamically refreshed or reloaded, (e.g filesystem + // reload or tbot renewal), Expiry returns the expiry of the currently active + // Credentials. + Expiry() (time.Time, bool) +} + // CredentialsWithDefaultAddrs additionally provides default addresses sourced // from the credential which are used when the client has not been explicitly // configured with an address. @@ -60,6 +80,17 @@ type CredentialsWithDefaultAddrs interface { DefaultAddrs() ([]string, error) } +// Expiry checks if the Credentials has an Expiry function and invokes it. +// Starting with v17, this is part of the Credentials interface but we must be backward compatible with v16 and below. +// If the Credentials don't implement Expiry, returns false. +func Expiry(c Credentials) (time.Time, bool) { + credsWithExpiry, ok := c.(CredentialsWithExpiry) + if !ok { + return time.Time{}, false + } + return credsWithExpiry.Expiry() +} + // LoadTLS is used to load Credentials directly from a *tls.Config. // // TLS creds can only be used to connect directly to a Teleport Auth server. @@ -87,6 +118,12 @@ func (c *tlsConfigCreds) SSHClientConfig() (*ssh.ClientConfig, error) { return nil, trace.NotImplemented("no ssh config") } +// Expiry returns the credential expiry. As the tlsConfigCreds are built from an existing tlsConfig +// we have no way of knowing which certificate will be returned and if it's expired. +func (c *tlsConfigCreds) Expiry() (time.Time, bool) { + return time.Time{}, false +} + // LoadKeyPair is used to load Credentials from a certicate keypair on disk. // // KeyPair Credentials can only be used to connect directly to a Teleport Auth server. @@ -115,7 +152,7 @@ type keypairCreds struct { // TLSConfig returns TLS configuration. func (c *keypairCreds) TLSConfig() (*tls.Config, error) { - cert, err := tls.LoadX509KeyPair(c.certFile, c.keyFile) + cert, err := keys.LoadX509KeyPair(c.certFile, c.keyFile) if err != nil { return nil, trace.Wrap(err) } @@ -141,6 +178,19 @@ func (c *keypairCreds) SSHClientConfig() (*ssh.ClientConfig, error) { return nil, trace.NotImplemented("no ssh config") } +// Expiry returns the credential expiry. +func (c *keypairCreds) Expiry() (time.Time, bool) { + certPEMBlock, err := os.ReadFile(c.certFile) + if err != nil { + return time.Time{}, false + } + cert, _, err := keys.X509Certificate(certPEMBlock) + if err != nil { + return time.Time{}, false + } + return cert.NotAfter, true +} + // LoadIdentityFile is used to load Credentials from an identity file on disk. // // Identity Credentials can be used to connect to an auth server directly @@ -162,12 +212,16 @@ func LoadIdentityFile(path string) Credentials { // identityCredsFile use an identity file to provide client credentials. type identityCredsFile struct { + // mutex protects identityFile + mutex sync.Mutex identityFile *identityfile.IdentityFile path string } // TLSConfig returns TLS configuration. func (c *identityCredsFile) TLSConfig() (*tls.Config, error) { + c.mutex.Lock() + defer c.mutex.Unlock() if err := c.load(); err != nil { return nil, trace.Wrap(err) } @@ -182,6 +236,8 @@ func (c *identityCredsFile) TLSConfig() (*tls.Config, error) { // SSHClientConfig returns SSH configuration. func (c *identityCredsFile) SSHClientConfig() (*ssh.ClientConfig, error) { + c.mutex.Lock() + defer c.mutex.Unlock() if err := c.load(); err != nil { return nil, trace.Wrap(err) } @@ -194,6 +250,16 @@ func (c *identityCredsFile) SSHClientConfig() (*ssh.ClientConfig, error) { return sshConfig, nil } +// Expiry returns the credential expiry. +func (c *identityCredsFile) Expiry() (time.Time, bool) { + c.mutex.Lock() + defer c.mutex.Unlock() + if err := c.load(); err != nil { + return time.Time{}, false + } + return c.identityFile.Expiry() +} + // load is used to lazy load the identity file from persistent storage. // This allows LoadIdentity to avoid possible errors for UX purposes. func (c *identityCredsFile) load() error { @@ -260,6 +326,15 @@ func (c *identityCredsString) SSHClientConfig() (*ssh.ClientConfig, error) { return sshConfig, nil } +// Expiry returns the credential expiry. +func (c *identityCredsString) Expiry() (time.Time, bool) { + if err := c.load(); err != nil { + return time.Time{}, false + } + + return c.identityFile.Expiry() +} + // load is used to lazy load the identity file from a string. func (c *identityCredsString) load() error { if c.identityFile != nil { @@ -329,6 +404,13 @@ func (c *profileCreds) SSHClientConfig() (*ssh.ClientConfig, error) { return sshConfig, nil } +func (c *profileCreds) Expiry() (time.Time, bool) { + if err := c.load(); err != nil { + return time.Time{}, false + } + return c.profile.Expiry() +} + // DefaultAddrs implements CredentialsWithDefaultAddrs by providing the // WebProxyAddr from the credential func (c *profileCreds) DefaultAddrs() ([]string, error) { @@ -563,3 +645,20 @@ func (d *DynamicIdentityFileCreds) SSHClientConfig() (*ssh.ClientConfig, error) } return cfg, nil } + +// Expiry returns the current credential expiry. +func (d *DynamicIdentityFileCreds) Expiry() (time.Time, bool) { + d.mu.RLock() + defer d.mu.RUnlock() + + if d.tlsCert == nil || len(d.tlsCert.Certificate) == 0 { + return time.Time{}, false + } + + x509Cert, err := x509.ParseCertificate(d.tlsCert.Certificate[0]) + if err != nil { + return time.Time{}, false + } + + return x509Cert.NotAfter, true +} diff --git a/api/client/credentials_test.go b/api/client/credentials_test.go index 273707288bf93..896a873c19811 100644 --- a/api/client/credentials_test.go +++ b/api/client/credentials_test.go @@ -57,6 +57,12 @@ func TestLoadTLS(t *testing.T) { require.NoError(t, err) requireEqualTLSConfig(t, expectedTLSConfig, tlsConfig) + require.Implements(t, (*CredentialsWithExpiry)(nil), creds) + expiringCreds := creds.(CredentialsWithExpiry) + expiry, ok := expiringCreds.Expiry() + require.False(t, ok, "Expiry should not be knows on creds built only from TLS Config") + require.True(t, expiry.IsZero(), "unknown expiry should be zero") + // Load invalid tls.Config. invalidTLSCreds := LoadTLS(nil) _, err = invalidTLSCreds.TLSConfig() @@ -100,12 +106,22 @@ func TestLoadIdentityFile(t *testing.T) { require.NoError(t, err) requireEqualSSHConfig(t, expectedSSHConfig, sshConfig) + require.Implements(t, (*CredentialsWithExpiry)(nil), creds) + expiringCreds := creds.(CredentialsWithExpiry) + expiry, ok := expiringCreds.Expiry() + require.True(t, ok, "Expiry should be known when we build creds from an identity file") + require.Equal(t, tlsCertNotAfter, expiry) + // Load invalid identity. creds = LoadIdentityFile("invalid_path") _, err = creds.TLSConfig() require.Error(t, err) _, err = creds.SSHClientConfig() require.Error(t, err) + require.Implements(t, (*CredentialsWithExpiry)(nil), creds) + expiringCreds = creds.(CredentialsWithExpiry) + _, ok = expiringCreds.Expiry() + require.False(t, ok, "expiry should be unknown on a broken id file") } func TestLoadIdentityFileFromString(t *testing.T) { @@ -146,12 +162,22 @@ func TestLoadIdentityFileFromString(t *testing.T) { require.NoError(t, err) requireEqualSSHConfig(t, expectedSSHConfig, sshConfig) + require.Implements(t, (*CredentialsWithExpiry)(nil), creds) + expiringCreds := creds.(CredentialsWithExpiry) + expiry, ok := expiringCreds.Expiry() + require.True(t, ok, "expiry should be known when we build creds from an identity file") + require.Equal(t, tlsCertNotAfter, expiry) + // Load invalid identity. creds = LoadIdentityFileFromString("invalid_creds") _, err = creds.TLSConfig() require.Error(t, err) _, err = creds.SSHClientConfig() require.Error(t, err) + require.Implements(t, (*CredentialsWithExpiry)(nil), creds) + expiringCreds = creds.(CredentialsWithExpiry) + _, ok = expiringCreds.Expiry() + require.False(t, ok, "expiry should be unknown on a broken id file") } func TestLoadKeyPair(t *testing.T) { @@ -176,11 +202,20 @@ func TestLoadKeyPair(t *testing.T) { tlsConfig, err := creds.TLSConfig() require.NoError(t, err) requireEqualTLSConfig(t, expectedTLSConfig, tlsConfig) + require.Implements(t, (*CredentialsWithExpiry)(nil), creds) + expiringCreds := creds.(CredentialsWithExpiry) + expiry, ok := expiringCreds.Expiry() + require.True(t, ok, "expiry should be known when we build creds from cert files") + require.Equal(t, tlsCertNotAfter, expiry) // Load invalid keypairs. invalidIdentityCreds := LoadKeyPair("invalid_path", "invalid_path", "invalid_path") _, err = invalidIdentityCreds.TLSConfig() require.Error(t, err) + require.Implements(t, (*CredentialsWithExpiry)(nil), creds) + expiringCreds = invalidIdentityCreds.(CredentialsWithExpiry) + _, ok = expiringCreds.Expiry() + require.False(t, ok, "expiry should be unknown on a broken credential") } func TestLoadProfile(t *testing.T) { @@ -207,6 +242,10 @@ func TestLoadProfile(t *testing.T) { require.Error(t, err) _, err = creds.SSHClientConfig() require.Error(t, err) + require.Implements(t, (*CredentialsWithExpiry)(nil), creds) + expiringCreds := creds.(CredentialsWithExpiry) + _, ok := expiringCreds.Expiry() + require.False(t, ok, "expiry should be unknown on a broken profile") }) } @@ -226,6 +265,12 @@ func testProfileContents(t *testing.T, dir, name string) { sshConfig, err := creds.SSHClientConfig() require.NoError(t, err) requireEqualSSHConfig(t, expectedSSHConfig, sshConfig) + + require.Implements(t, (*CredentialsWithExpiry)(nil), creds) + expiringCreds := creds.(CredentialsWithExpiry) + expiry, ok := expiringCreds.Expiry() + require.True(t, ok, "expiry should be known when we build creds from a profile") + require.Equal(t, tlsCertNotAfter, expiry) } func writeProfile(t *testing.T, p *profile.Profile) { @@ -306,6 +351,7 @@ m1gfG9yqEte7pxv3yWM+7X2bzEjCBds4feahuKPNxOAOSfLUZiTpmOVlRzrpRIhu WQdM2NXAMABGAofGrVklPIiraUoHzr0Xxpia4vQwRewYXv8bCPHW+8g8vGBGvoG2 gtLit9DL5DR5ac/CRGJt -----END CERTIFICATE-----`) + tlsCertNotAfter = time.Date(2021, 2, 18, 8, 28, 21, 0, time.UTC) keyPEM = []byte(`-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEAzkUVoJ4rn2XAi2HJeBIIxlsdMPGzLroJub9eHAVspAueDJLS @@ -420,6 +466,10 @@ func TestDynamicIdentityFileCreds(t *testing.T) { require.NoError(t, err) require.Equal(t, wantTLSCert, *gotTLSCert) + expiry, ok := cred.Expiry() + require.True(t, ok, "expiry should be known when we build creds from an identity file") + require.Equal(t, tlsCertNotAfter, expiry) + tlsCACertPEM, _ := pem.Decode(tlsCACert) tlsCACertDER, err := x509.ParseCertificate(tlsCACertPEM.Bytes) require.NoError(t, err) @@ -427,6 +477,7 @@ func TestDynamicIdentityFileCreds(t *testing.T) { wantCertPool.AddCert(tlsCACertDER) require.True(t, wantCertPool.Equal(tlsConfig.RootCAs), "tlsconfig.RootCAs mismatch") + newExpiry := tlsCertNotAfter.Add(24 * time.Hour) // Generate a new TLS certificate that contains the same private key as // the original. template := &x509.Certificate{ @@ -434,6 +485,7 @@ func TestDynamicIdentityFileCreds(t *testing.T) { Subject: pkix.Name{ CommonName: "example", }, + NotAfter: newExpiry, KeyUsage: x509.KeyUsageDigitalSignature, BasicConstraintsValid: true, DNSNames: []string{constants.APIDomain}, @@ -471,6 +523,10 @@ func TestDynamicIdentityFileCreds(t *testing.T) { wantTLSCert, err = tls.X509KeyPair(secondTLSCertPem, keyPEM) require.NoError(t, err) require.Equal(t, wantTLSCert, *gotTLSCert) + + expiry, ok = cred.Expiry() + require.True(t, ok, "expiry should be known when we build creds from an identity file") + require.Equal(t, newExpiry, expiry) } func ExampleDynamicIdentityFileCreds() { diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index 78f5127e4f994..817da428928c0 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -147,6 +147,11 @@ func (mc insecureCredentials) SSHClientConfig() (*ssh.ClientConfig, error) { return nil, trace.NotImplemented("no ssh config") } +// Expiry returns the credential expiry. insecureCredentials never expire. +func (mc insecureCredentials) Expiry() (time.Time, bool) { + return time.Time{}, true +} + // Client is a client to the Teleport Proxy SSH server on behalf of a user. // The Proxy SSH port used to serve only SSH, however portions of the api are // being migrated to gRPC to reduce latency. The Client is capable of communicating diff --git a/api/identityfile/identityfile.go b/api/identityfile/identityfile.go index cbbd09d94598e..068534c92f7af 100644 --- a/api/identityfile/identityfile.go +++ b/api/identityfile/identityfile.go @@ -26,6 +26,7 @@ import ( "io" "os" "strings" + "time" "github.com/gravitational/trace" "golang.org/x/crypto/ssh" @@ -114,6 +115,18 @@ func (i *IdentityFile) SSHClientConfig() (*ssh.ClientConfig, error) { return ssh, nil } +// Expiry returns the credential expiry. +func (i *IdentityFile) Expiry() (time.Time, bool) { + if i.Certs.TLS == nil { + return time.Time{}, false + } + cert, _, err := keys.X509Certificate(i.Certs.TLS) + if err != nil { + return time.Time{}, false + } + return cert.NotAfter, true +} + // Write writes the given identityFile to the specified path. func Write(idFile *IdentityFile, path string) error { buf := new(bytes.Buffer) diff --git a/api/profile/profile.go b/api/profile/profile.go index 7475cdf5d1089..08d9911761e7a 100644 --- a/api/profile/profile.go +++ b/api/profile/profile.go @@ -25,6 +25,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/gravitational/trace" "golang.org/x/crypto/ssh" @@ -153,6 +154,19 @@ func (p *Profile) TLSConfig() (*tls.Config, error) { }, nil } +// Expiry returns the credential expiry. +func (p *Profile) Expiry() (time.Time, bool) { + certPEMBlock, err := os.ReadFile(p.TLSCertPath()) + if err != nil { + return time.Time{}, false + } + cert, _, err := keys.X509Certificate(certPEMBlock) + if err != nil { + return time.Time{}, false + } + return cert.NotAfter, true +} + // RequireKubeLocalProxy returns true if this profile indicates a local proxy // is required for kube access. func (p *Profile) RequireKubeLocalProxy() bool { diff --git a/api/utils/keys/privatekey.go b/api/utils/keys/privatekey.go index 8abcad2b22adf..adb7f358efe1c 100644 --- a/api/utils/keys/privatekey.go +++ b/api/utils/keys/privatekey.go @@ -98,32 +98,12 @@ func (k *PrivateKey) TLSCertificate(certPEMBlock []byte) (tls.Certificate, error PrivateKey: k.Signer, } - var skippedBlockTypes []string - for { - var certDERBlock *pem.Block - certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) - if certDERBlock == nil { - break - } - if certDERBlock.Type == "CERTIFICATE" { - cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) - } else { - skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) - } - } - - if len(cert.Certificate) == 0 { - if len(skippedBlockTypes) == 0 { - return tls.Certificate{}, trace.BadParameter("tls: failed to find any PEM data in certificate input") - } - return tls.Certificate{}, trace.BadParameter("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes) - } - - // Check that the certificate's public key matches this private key. - x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + // Parse the certificate and verify it is valid. + x509Cert, rawCerts, err := X509Certificate(certPEMBlock) if err != nil { return tls.Certificate{}, trace.Wrap(err) } + cert.Certificate = rawCerts if keyPub, ok := k.Public().(cryptoPublicKeyI); !ok { return tls.Certificate{}, trace.BadParameter("private key does not contain a valid public key") @@ -334,3 +314,36 @@ func IsRSAPrivateKey(privKey []byte) bool { return false } } + +// X509Certificate takes a PEM-encoded file containing one or more certificates, extracts all certificates, and parses +// the Leaf certificate (the first one in the chain). If you are loading both a certificate and a private key, you +// should use X509KeyPair instead. +func X509Certificate(certPEMBlock []byte) (*x509.Certificate, [][]byte, error) { + var skippedBlockTypes []string + var rawCerts [][]byte + for { + var certDERBlock *pem.Block + certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) + if certDERBlock == nil { + break + } + if certDERBlock.Type == "CERTIFICATE" { + rawCerts = append(rawCerts, certDERBlock.Bytes) + } else { + skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) + } + } + + if len(rawCerts) == 0 { + if len(skippedBlockTypes) == 0 { + return nil, nil, trace.BadParameter("tls: failed to find any PEM data in certificate input") + } + return nil, nil, trace.BadParameter("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes) + } + + x509Cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return nil, rawCerts, trace.Wrap(err, "failed to parse certificate") + } + return x509Cert, rawCerts, nil +} diff --git a/api/utils/keys/privatekey_test.go b/api/utils/keys/privatekey_test.go index ffba01c134cd5..ce84ff8141580 100644 --- a/api/utils/keys/privatekey_test.go +++ b/api/utils/keys/privatekey_test.go @@ -25,8 +25,10 @@ import ( "crypto/rand" "crypto/rsa" "crypto/tls" + "crypto/x509" "encoding/pem" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -90,6 +92,69 @@ func TestX509KeyPair(t *testing.T) { } } +func TestX509Certificate(t *testing.T) { + // Checking certificate expiry to see if the certificate got parsed and we did not get an empty struct. + hasExpiry := func(t require.TestingT, i interface{}, args ...interface{}) { + cert, ok := i.(*x509.Certificate) + require.True(t, ok) + require.NotNil(t, cert) + require.Equal(t, rsaCertExpiry, cert.NotAfter) + } + + nilCert := func(t require.TestingT, i interface{}, args ...interface{}) { + cert, ok := i.(*x509.Certificate) + require.True(t, ok) + require.Nil(t, cert) + } + + for _, tc := range []struct { + name string + keyPEM []byte + certPEM []byte + expectedLength int + expectedError require.ErrorAssertionFunc + validateResult require.ValueAssertionFunc + }{ + { + name: "rsa cert", + certPEM: rsaCertPEM, + expectedLength: 1, + expectedError: require.NoError, + validateResult: hasExpiry, + }, { + name: "rsa certs", + certPEM: func() []byte { + // encode two certs into certPEM. + rsaCertPEMDuplicated := new(bytes.Buffer) + der, _ := pem.Decode(rsaCertPEM) + pem.Encode(rsaCertPEMDuplicated, der) + pem.Encode(rsaCertPEMDuplicated, der) + return rsaCertPEMDuplicated.Bytes() + }(), + expectedLength: 2, + expectedError: require.NoError, + validateResult: hasExpiry, + }, + { + name: "no cert", + certPEM: []byte{}, + expectedLength: 0, + expectedError: require.Error, + validateResult: nilCert, + }, + } { + t.Run(tc.name, func(t *testing.T) { + cert, rawCerts, err := X509Certificate(tc.certPEM) + require.Len(t, rawCerts, tc.expectedLength) + + tc.expectedError(t, err) + + tc.validateResult(t, cert) + }) + } + +} + var ( // generated with `openssl req -x509 -out rsa.crt -keyout rsa.key -newkey rsa:2048 -nodes -sha256` rsaKeyPEM = []byte(`-----BEGIN PRIVATE KEY----- @@ -141,4 +206,5 @@ mg0exCUFW40aXpfm0z0dNNwoN+FPSefKMYMQ1LV87I6zGnmVTYH9Nix3REiuliIQ 7XXnJc7A6tsc6yXdVG6IpGnKXuTvl/r4iIbH+JDv3MDSvZSCE5kzAPFjgB3zMAZ8 Z0+424ERgom0Zdy75Y8I -----END CERTIFICATE-----`) + rsaCertExpiry = time.Date(2022, time.September, 21, 19, 1, 1, 0, time.UTC) ) diff --git a/integrations/lib/credentials/credentials_test.go b/integrations/lib/credentials/credentials_test.go index 971a316c363b4..a0c917323d726 100644 --- a/integrations/lib/credentials/credentials_test.go +++ b/integrations/lib/credentials/credentials_test.go @@ -54,6 +54,10 @@ func (mc *mockTLSCredentials) SSHClientConfig() (*ssh.ClientConfig, error) { return nil, trace.NotImplemented("no ssh config") } +func (mc *mockTLSCredentials) Expiry() (time.Time, bool) { + return time.Time{}, true +} + func TestCheckExpiredCredentials(t *testing.T) { // Setup the CA and sign the client certs ca := &x509.Certificate{ diff --git a/lib/tbot/config/service_client_credential.go b/lib/tbot/config/service_client_credential.go index 0f8abcb886d1c..d06b04dff4313 100644 --- a/lib/tbot/config/service_client_credential.go +++ b/lib/tbot/config/service_client_credential.go @@ -21,6 +21,7 @@ package config import ( "crypto/tls" "sync" + "time" "github.com/gravitational/trace" "golang.org/x/crypto/ssh" @@ -32,8 +33,8 @@ import ( const UnstableClientCredentialOutputType = "unstable_client_credential" var ( - _ ServiceConfig = &UnstableClientCredentialOutput{} - _ client.Credentials = &UnstableClientCredentialOutput{} + _ ServiceConfig = &UnstableClientCredentialOutput{} + _ client.CredentialsWithExpiry = &UnstableClientCredentialOutput{} ) // UnstableClientCredentialOutput is an experimental tbot output which is @@ -95,6 +96,16 @@ func (o *UnstableClientCredentialOutput) SSHClientConfig() (*ssh.ClientConfig, e return o.facade.SSHClientConfig() } +// Expiry returns the credential expiry. +func (o *UnstableClientCredentialOutput) Expiry() (time.Time, bool) { + o.mu.Lock() + defer o.mu.Unlock() + if o.facade == nil { + return time.Time{}, false + } + return o.facade.Expiry() +} + // Facade returns the underlying facade func (o *UnstableClientCredentialOutput) Facade() (*identity.Facade, error) { o.mu.Lock() diff --git a/lib/tbot/identity/identity_facade.go b/lib/tbot/identity/identity_facade.go index e2a7e00e5f37a..59010840b3a0d 100644 --- a/lib/tbot/identity/identity_facade.go +++ b/lib/tbot/identity/identity_facade.go @@ -22,6 +22,7 @@ import ( "crypto/tls" "crypto/x509" "sync" + "time" "github.com/gravitational/trace" "golang.org/x/crypto/ssh" @@ -30,6 +31,7 @@ import ( "github.com/gravitational/teleport/api/client" apidefaults "github.com/gravitational/teleport/api/defaults" apiutils "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" @@ -204,3 +206,15 @@ func (f *Facade) SSHClientConfig() (*ssh.ClientConfig, error) { } return cfg, nil } + +// Expiry returns the credential expiry. +func (f *Facade) Expiry() (time.Time, bool) { + if len(f.identity.TLSCert.Certificate) == 0 { + return time.Time{}, false + } + cert, _, err := keys.X509Certificate(f.identity.TLSCert.Certificate[0]) + if err != nil { + return time.Time{}, false + } + return cert.NotAfter, true +}