Skip to content

Commit

Permalink
[v16] Support detecting a credential expiry (#44080) (#44658)
Browse files Browse the repository at this point in the history
* Support detecting a credential expiry (#43886)

* Support detecting a credential expiry

* Fix Expiry() and add tests

* protect identityFile cred with a lock

* add godoc

* Fix tests

* Do not break previous Credentials implementations

* fixup! Do not break previous Credentials implementations

* address edoardo's comments

* fix tests
  • Loading branch information
hugoShaka authored Jul 26, 2024
1 parent f4a20be commit 043d690
Show file tree
Hide file tree
Showing 10 changed files with 321 additions and 26 deletions.
101 changes: 100 additions & 1 deletion api/client/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"crypto/x509"
"os"
"sync"
"time"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
Expand All @@ -40,6 +41,10 @@ import (
// also provide other functionality, such as automatic address discovery and
// ssh connectivity.
//
// Note: starting with v17, all Credentials must have an Expiry method.
// For compatibility guarantees, the future interface is optional and called
// CredentialsWithExpiry.
//
// See the examples below for an example of each loader.
type Credentials interface {
// TLSConfig returns TLS configuration used to authenticate the client.
Expand All @@ -49,6 +54,21 @@ type Credentials interface {
SSHClientConfig() (*ssh.ClientConfig, error)
}

// CredentialsWithExpiry are credentials implementing the Expiry() function.
// This interface is here to avoid breaking changes in v15 and v16. Starting with
// v17, Expiry() will be part of the Credentials interface.
type CredentialsWithExpiry interface {
Credentials

// Expiry returns the Credentials expiry if it's possible to know its expiry.
// When expiry can be determined returns true, else returns false.
// If the Credentials don't expire, returns the zero time.
// If the Credential is dynamically refreshed or reloaded, (e.g filesystem
// reload or tbot renewal), Expiry returns the expiry of the currently active
// Credentials.
Expiry() (time.Time, bool)
}

// CredentialsWithDefaultAddrs additionally provides default addresses sourced
// from the credential which are used when the client has not been explicitly
// configured with an address.
Expand All @@ -60,6 +80,17 @@ type CredentialsWithDefaultAddrs interface {
DefaultAddrs() ([]string, error)
}

// Expiry checks if the Credentials has an Expiry function and invokes it.
// Starting with v17, this is part of the Credentials interface but we must be backward compatible with v16 and below.
// If the Credentials don't implement Expiry, returns false.
func Expiry(c Credentials) (time.Time, bool) {
credsWithExpiry, ok := c.(CredentialsWithExpiry)
if !ok {
return time.Time{}, false
}
return credsWithExpiry.Expiry()
}

// LoadTLS is used to load Credentials directly from a *tls.Config.
//
// TLS creds can only be used to connect directly to a Teleport Auth server.
Expand Down Expand Up @@ -87,6 +118,12 @@ func (c *tlsConfigCreds) SSHClientConfig() (*ssh.ClientConfig, error) {
return nil, trace.NotImplemented("no ssh config")
}

// Expiry returns the credential expiry. As the tlsConfigCreds are built from an existing tlsConfig
// we have no way of knowing which certificate will be returned and if it's expired.
func (c *tlsConfigCreds) Expiry() (time.Time, bool) {
return time.Time{}, false
}

// LoadKeyPair is used to load Credentials from a certicate keypair on disk.
//
// KeyPair Credentials can only be used to connect directly to a Teleport Auth server.
Expand Down Expand Up @@ -115,7 +152,7 @@ type keypairCreds struct {

// TLSConfig returns TLS configuration.
func (c *keypairCreds) TLSConfig() (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(c.certFile, c.keyFile)
cert, err := keys.LoadX509KeyPair(c.certFile, c.keyFile)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -141,6 +178,19 @@ func (c *keypairCreds) SSHClientConfig() (*ssh.ClientConfig, error) {
return nil, trace.NotImplemented("no ssh config")
}

// Expiry returns the credential expiry.
func (c *keypairCreds) Expiry() (time.Time, bool) {
certPEMBlock, err := os.ReadFile(c.certFile)
if err != nil {
return time.Time{}, false
}
cert, _, err := keys.X509Certificate(certPEMBlock)
if err != nil {
return time.Time{}, false
}
return cert.NotAfter, true
}

// LoadIdentityFile is used to load Credentials from an identity file on disk.
//
// Identity Credentials can be used to connect to an auth server directly
Expand All @@ -162,12 +212,16 @@ func LoadIdentityFile(path string) Credentials {

// identityCredsFile use an identity file to provide client credentials.
type identityCredsFile struct {
// mutex protects identityFile
mutex sync.Mutex
identityFile *identityfile.IdentityFile
path string
}

// TLSConfig returns TLS configuration.
func (c *identityCredsFile) TLSConfig() (*tls.Config, error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if err := c.load(); err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -182,6 +236,8 @@ func (c *identityCredsFile) TLSConfig() (*tls.Config, error) {

// SSHClientConfig returns SSH configuration.
func (c *identityCredsFile) SSHClientConfig() (*ssh.ClientConfig, error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if err := c.load(); err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -194,6 +250,16 @@ func (c *identityCredsFile) SSHClientConfig() (*ssh.ClientConfig, error) {
return sshConfig, nil
}

// Expiry returns the credential expiry.
func (c *identityCredsFile) Expiry() (time.Time, bool) {
c.mutex.Lock()
defer c.mutex.Unlock()
if err := c.load(); err != nil {
return time.Time{}, false
}
return c.identityFile.Expiry()
}

// load is used to lazy load the identity file from persistent storage.
// This allows LoadIdentity to avoid possible errors for UX purposes.
func (c *identityCredsFile) load() error {
Expand Down Expand Up @@ -260,6 +326,15 @@ func (c *identityCredsString) SSHClientConfig() (*ssh.ClientConfig, error) {
return sshConfig, nil
}

// Expiry returns the credential expiry.
func (c *identityCredsString) Expiry() (time.Time, bool) {
if err := c.load(); err != nil {
return time.Time{}, false
}

return c.identityFile.Expiry()
}

// load is used to lazy load the identity file from a string.
func (c *identityCredsString) load() error {
if c.identityFile != nil {
Expand Down Expand Up @@ -329,6 +404,13 @@ func (c *profileCreds) SSHClientConfig() (*ssh.ClientConfig, error) {
return sshConfig, nil
}

func (c *profileCreds) Expiry() (time.Time, bool) {
if err := c.load(); err != nil {
return time.Time{}, false
}
return c.profile.Expiry()
}

// DefaultAddrs implements CredentialsWithDefaultAddrs by providing the
// WebProxyAddr from the credential
func (c *profileCreds) DefaultAddrs() ([]string, error) {
Expand Down Expand Up @@ -563,3 +645,20 @@ func (d *DynamicIdentityFileCreds) SSHClientConfig() (*ssh.ClientConfig, error)
}
return cfg, nil
}

// Expiry returns the current credential expiry.
func (d *DynamicIdentityFileCreds) Expiry() (time.Time, bool) {
d.mu.RLock()
defer d.mu.RUnlock()

if d.tlsCert == nil || len(d.tlsCert.Certificate) == 0 {
return time.Time{}, false
}

x509Cert, err := x509.ParseCertificate(d.tlsCert.Certificate[0])
if err != nil {
return time.Time{}, false
}

return x509Cert.NotAfter, true
}
56 changes: 56 additions & 0 deletions api/client/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ func TestLoadTLS(t *testing.T) {
require.NoError(t, err)
requireEqualTLSConfig(t, expectedTLSConfig, tlsConfig)

require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
expiry, ok := expiringCreds.Expiry()
require.False(t, ok, "Expiry should not be knows on creds built only from TLS Config")
require.True(t, expiry.IsZero(), "unknown expiry should be zero")

// Load invalid tls.Config.
invalidTLSCreds := LoadTLS(nil)
_, err = invalidTLSCreds.TLSConfig()
Expand Down Expand Up @@ -100,12 +106,22 @@ func TestLoadIdentityFile(t *testing.T) {
require.NoError(t, err)
requireEqualSSHConfig(t, expectedSSHConfig, sshConfig)

require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
expiry, ok := expiringCreds.Expiry()
require.True(t, ok, "Expiry should be known when we build creds from an identity file")
require.Equal(t, tlsCertNotAfter, expiry)

// Load invalid identity.
creds = LoadIdentityFile("invalid_path")
_, err = creds.TLSConfig()
require.Error(t, err)
_, err = creds.SSHClientConfig()
require.Error(t, err)
require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds = creds.(CredentialsWithExpiry)
_, ok = expiringCreds.Expiry()
require.False(t, ok, "expiry should be unknown on a broken id file")
}

func TestLoadIdentityFileFromString(t *testing.T) {
Expand Down Expand Up @@ -146,12 +162,22 @@ func TestLoadIdentityFileFromString(t *testing.T) {
require.NoError(t, err)
requireEqualSSHConfig(t, expectedSSHConfig, sshConfig)

require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
expiry, ok := expiringCreds.Expiry()
require.True(t, ok, "expiry should be known when we build creds from an identity file")
require.Equal(t, tlsCertNotAfter, expiry)

// Load invalid identity.
creds = LoadIdentityFileFromString("invalid_creds")
_, err = creds.TLSConfig()
require.Error(t, err)
_, err = creds.SSHClientConfig()
require.Error(t, err)
require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds = creds.(CredentialsWithExpiry)
_, ok = expiringCreds.Expiry()
require.False(t, ok, "expiry should be unknown on a broken id file")
}

func TestLoadKeyPair(t *testing.T) {
Expand All @@ -176,11 +202,20 @@ func TestLoadKeyPair(t *testing.T) {
tlsConfig, err := creds.TLSConfig()
require.NoError(t, err)
requireEqualTLSConfig(t, expectedTLSConfig, tlsConfig)
require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
expiry, ok := expiringCreds.Expiry()
require.True(t, ok, "expiry should be known when we build creds from cert files")
require.Equal(t, tlsCertNotAfter, expiry)

// Load invalid keypairs.
invalidIdentityCreds := LoadKeyPair("invalid_path", "invalid_path", "invalid_path")
_, err = invalidIdentityCreds.TLSConfig()
require.Error(t, err)
require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds = invalidIdentityCreds.(CredentialsWithExpiry)
_, ok = expiringCreds.Expiry()
require.False(t, ok, "expiry should be unknown on a broken credential")
}

func TestLoadProfile(t *testing.T) {
Expand All @@ -207,6 +242,10 @@ func TestLoadProfile(t *testing.T) {
require.Error(t, err)
_, err = creds.SSHClientConfig()
require.Error(t, err)
require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
_, ok := expiringCreds.Expiry()
require.False(t, ok, "expiry should be unknown on a broken profile")
})
}

Expand All @@ -226,6 +265,12 @@ func testProfileContents(t *testing.T, dir, name string) {
sshConfig, err := creds.SSHClientConfig()
require.NoError(t, err)
requireEqualSSHConfig(t, expectedSSHConfig, sshConfig)

require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
expiry, ok := expiringCreds.Expiry()
require.True(t, ok, "expiry should be known when we build creds from a profile")
require.Equal(t, tlsCertNotAfter, expiry)
}

func writeProfile(t *testing.T, p *profile.Profile) {
Expand Down Expand Up @@ -306,6 +351,7 @@ m1gfG9yqEte7pxv3yWM+7X2bzEjCBds4feahuKPNxOAOSfLUZiTpmOVlRzrpRIhu
WQdM2NXAMABGAofGrVklPIiraUoHzr0Xxpia4vQwRewYXv8bCPHW+8g8vGBGvoG2
gtLit9DL5DR5ac/CRGJt
-----END CERTIFICATE-----`)
tlsCertNotAfter = time.Date(2021, 2, 18, 8, 28, 21, 0, time.UTC)

keyPEM = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAzkUVoJ4rn2XAi2HJeBIIxlsdMPGzLroJub9eHAVspAueDJLS
Expand Down Expand Up @@ -420,20 +466,26 @@ func TestDynamicIdentityFileCreds(t *testing.T) {
require.NoError(t, err)
require.Equal(t, wantTLSCert, *gotTLSCert)

expiry, ok := cred.Expiry()
require.True(t, ok, "expiry should be known when we build creds from an identity file")
require.Equal(t, tlsCertNotAfter, expiry)

tlsCACertPEM, _ := pem.Decode(tlsCACert)
tlsCACertDER, err := x509.ParseCertificate(tlsCACertPEM.Bytes)
require.NoError(t, err)
wantCertPool := x509.NewCertPool()
wantCertPool.AddCert(tlsCACertDER)
require.True(t, wantCertPool.Equal(tlsConfig.RootCAs), "tlsconfig.RootCAs mismatch")

newExpiry := tlsCertNotAfter.Add(24 * time.Hour)
// Generate a new TLS certificate that contains the same private key as
// the original.
template := &x509.Certificate{
SerialNumber: big.NewInt(0),
Subject: pkix.Name{
CommonName: "example",
},
NotAfter: newExpiry,
KeyUsage: x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
DNSNames: []string{constants.APIDomain},
Expand Down Expand Up @@ -471,6 +523,10 @@ func TestDynamicIdentityFileCreds(t *testing.T) {
wantTLSCert, err = tls.X509KeyPair(secondTLSCertPem, keyPEM)
require.NoError(t, err)
require.Equal(t, wantTLSCert, *gotTLSCert)

expiry, ok = cred.Expiry()
require.True(t, ok, "expiry should be known when we build creds from an identity file")
require.Equal(t, newExpiry, expiry)
}

func ExampleDynamicIdentityFileCreds() {
Expand Down
5 changes: 5 additions & 0 deletions api/client/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ func (mc insecureCredentials) SSHClientConfig() (*ssh.ClientConfig, error) {
return nil, trace.NotImplemented("no ssh config")
}

// Expiry returns the credential expiry. insecureCredentials never expire.
func (mc insecureCredentials) Expiry() (time.Time, bool) {
return time.Time{}, true
}

// Client is a client to the Teleport Proxy SSH server on behalf of a user.
// The Proxy SSH port used to serve only SSH, however portions of the api are
// being migrated to gRPC to reduce latency. The Client is capable of communicating
Expand Down
13 changes: 13 additions & 0 deletions api/identityfile/identityfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"io"
"os"
"strings"
"time"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -114,6 +115,18 @@ func (i *IdentityFile) SSHClientConfig() (*ssh.ClientConfig, error) {
return ssh, nil
}

// Expiry returns the credential expiry.
func (i *IdentityFile) Expiry() (time.Time, bool) {
if i.Certs.TLS == nil {
return time.Time{}, false
}
cert, _, err := keys.X509Certificate(i.Certs.TLS)
if err != nil {
return time.Time{}, false
}
return cert.NotAfter, true
}

// Write writes the given identityFile to the specified path.
func Write(idFile *IdentityFile, path string) error {
buf := new(bytes.Buffer)
Expand Down
Loading

0 comments on commit 043d690

Please sign in to comment.