Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v15] Support detecting a credential expiry #44658

Merged
merged 1 commit into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 100 additions & 1 deletion api/client/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"crypto/x509"
"os"
"sync"
"time"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
Expand All @@ -40,6 +41,10 @@ import (
// also provide other functionality, such as automatic address discovery and
// ssh connectivity.
//
// Note: starting with v17, all Credentials must have an Expiry method.
// For compatibility guarantees, the future interface is optional and called
// CredentialsWithExpiry.
//
// See the examples below for an example of each loader.
type Credentials interface {
// TLSConfig returns TLS configuration used to authenticate the client.
Expand All @@ -49,6 +54,21 @@ type Credentials interface {
SSHClientConfig() (*ssh.ClientConfig, error)
}

// CredentialsWithExpiry are credentials implementing the Expiry() function.
// This interface is here to avoid breaking changes in v15 and v16. Starting with
// v17, Expiry() will be part of the Credentials interface.
type CredentialsWithExpiry interface {
Credentials

// Expiry returns the Credentials expiry if it's possible to know its expiry.
// When expiry can be determined returns true, else returns false.
// If the Credentials don't expire, returns the zero time.
// If the Credential is dynamically refreshed or reloaded, (e.g filesystem
// reload or tbot renewal), Expiry returns the expiry of the currently active
// Credentials.
Expiry() (time.Time, bool)
}

// CredentialsWithDefaultAddrs additionally provides default addresses sourced
// from the credential which are used when the client has not been explicitly
// configured with an address.
Expand All @@ -60,6 +80,17 @@ type CredentialsWithDefaultAddrs interface {
DefaultAddrs() ([]string, error)
}

// Expiry checks if the Credentials has an Expiry function and invokes it.
// Starting with v17, this is part of the Credentials interface but we must be backward compatible with v16 and below.
// If the Credentials don't implement Expiry, returns false.
func Expiry(c Credentials) (time.Time, bool) {
credsWithExpiry, ok := c.(CredentialsWithExpiry)
if !ok {
return time.Time{}, false
}
return credsWithExpiry.Expiry()
}

// LoadTLS is used to load Credentials directly from a *tls.Config.
//
// TLS creds can only be used to connect directly to a Teleport Auth server.
Expand Down Expand Up @@ -87,6 +118,12 @@ func (c *tlsConfigCreds) SSHClientConfig() (*ssh.ClientConfig, error) {
return nil, trace.NotImplemented("no ssh config")
}

// Expiry returns the credential expiry. As the tlsConfigCreds are built from an existing tlsConfig
// we have no way of knowing which certificate will be returned and if it's expired.
func (c *tlsConfigCreds) Expiry() (time.Time, bool) {
return time.Time{}, false
}

// LoadKeyPair is used to load Credentials from a certicate keypair on disk.
//
// KeyPair Credentials can only be used to connect directly to a Teleport Auth server.
Expand Down Expand Up @@ -115,7 +152,7 @@ type keypairCreds struct {

// TLSConfig returns TLS configuration.
func (c *keypairCreds) TLSConfig() (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(c.certFile, c.keyFile)
cert, err := keys.LoadX509KeyPair(c.certFile, c.keyFile)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -141,6 +178,19 @@ func (c *keypairCreds) SSHClientConfig() (*ssh.ClientConfig, error) {
return nil, trace.NotImplemented("no ssh config")
}

// Expiry returns the credential expiry.
func (c *keypairCreds) Expiry() (time.Time, bool) {
certPEMBlock, err := os.ReadFile(c.certFile)
if err != nil {
return time.Time{}, false
}
cert, _, err := keys.X509Certificate(certPEMBlock)
if err != nil {
return time.Time{}, false
}
return cert.NotAfter, true
}

// LoadIdentityFile is used to load Credentials from an identity file on disk.
//
// Identity Credentials can be used to connect to an auth server directly
Expand All @@ -162,12 +212,16 @@ func LoadIdentityFile(path string) Credentials {

// identityCredsFile use an identity file to provide client credentials.
type identityCredsFile struct {
// mutex protects identityFile
mutex sync.Mutex
identityFile *identityfile.IdentityFile
path string
}

// TLSConfig returns TLS configuration.
func (c *identityCredsFile) TLSConfig() (*tls.Config, error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if err := c.load(); err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -182,6 +236,8 @@ func (c *identityCredsFile) TLSConfig() (*tls.Config, error) {

// SSHClientConfig returns SSH configuration.
func (c *identityCredsFile) SSHClientConfig() (*ssh.ClientConfig, error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if err := c.load(); err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -194,6 +250,16 @@ func (c *identityCredsFile) SSHClientConfig() (*ssh.ClientConfig, error) {
return sshConfig, nil
}

// Expiry returns the credential expiry.
func (c *identityCredsFile) Expiry() (time.Time, bool) {
c.mutex.Lock()
defer c.mutex.Unlock()
if err := c.load(); err != nil {
return time.Time{}, false
}
return c.identityFile.Expiry()
}

// load is used to lazy load the identity file from persistent storage.
// This allows LoadIdentity to avoid possible errors for UX purposes.
func (c *identityCredsFile) load() error {
Expand Down Expand Up @@ -260,6 +326,15 @@ func (c *identityCredsString) SSHClientConfig() (*ssh.ClientConfig, error) {
return sshConfig, nil
}

// Expiry returns the credential expiry.
func (c *identityCredsString) Expiry() (time.Time, bool) {
if err := c.load(); err != nil {
return time.Time{}, false
}

return c.identityFile.Expiry()
}

// load is used to lazy load the identity file from a string.
func (c *identityCredsString) load() error {
if c.identityFile != nil {
Expand Down Expand Up @@ -329,6 +404,13 @@ func (c *profileCreds) SSHClientConfig() (*ssh.ClientConfig, error) {
return sshConfig, nil
}

func (c *profileCreds) Expiry() (time.Time, bool) {
if err := c.load(); err != nil {
return time.Time{}, false
}
return c.profile.Expiry()
}

// DefaultAddrs implements CredentialsWithDefaultAddrs by providing the
// WebProxyAddr from the credential
func (c *profileCreds) DefaultAddrs() ([]string, error) {
Expand Down Expand Up @@ -563,3 +645,20 @@ func (d *DynamicIdentityFileCreds) SSHClientConfig() (*ssh.ClientConfig, error)
}
return cfg, nil
}

// Expiry returns the current credential expiry.
func (d *DynamicIdentityFileCreds) Expiry() (time.Time, bool) {
d.mu.RLock()
defer d.mu.RUnlock()

if d.tlsCert == nil || len(d.tlsCert.Certificate) == 0 {
return time.Time{}, false
}

x509Cert, err := x509.ParseCertificate(d.tlsCert.Certificate[0])
if err != nil {
return time.Time{}, false
}

return x509Cert.NotAfter, true
}
56 changes: 56 additions & 0 deletions api/client/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ func TestLoadTLS(t *testing.T) {
require.NoError(t, err)
requireEqualTLSConfig(t, expectedTLSConfig, tlsConfig)

require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
expiry, ok := expiringCreds.Expiry()
require.False(t, ok, "Expiry should not be knows on creds built only from TLS Config")
require.True(t, expiry.IsZero(), "unknown expiry should be zero")

// Load invalid tls.Config.
invalidTLSCreds := LoadTLS(nil)
_, err = invalidTLSCreds.TLSConfig()
Expand Down Expand Up @@ -100,12 +106,22 @@ func TestLoadIdentityFile(t *testing.T) {
require.NoError(t, err)
requireEqualSSHConfig(t, expectedSSHConfig, sshConfig)

require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
expiry, ok := expiringCreds.Expiry()
require.True(t, ok, "Expiry should be known when we build creds from an identity file")
require.Equal(t, tlsCertNotAfter, expiry)

// Load invalid identity.
creds = LoadIdentityFile("invalid_path")
_, err = creds.TLSConfig()
require.Error(t, err)
_, err = creds.SSHClientConfig()
require.Error(t, err)
require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds = creds.(CredentialsWithExpiry)
_, ok = expiringCreds.Expiry()
require.False(t, ok, "expiry should be unknown on a broken id file")
}

func TestLoadIdentityFileFromString(t *testing.T) {
Expand Down Expand Up @@ -146,12 +162,22 @@ func TestLoadIdentityFileFromString(t *testing.T) {
require.NoError(t, err)
requireEqualSSHConfig(t, expectedSSHConfig, sshConfig)

require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
expiry, ok := expiringCreds.Expiry()
require.True(t, ok, "expiry should be known when we build creds from an identity file")
require.Equal(t, tlsCertNotAfter, expiry)

// Load invalid identity.
creds = LoadIdentityFileFromString("invalid_creds")
_, err = creds.TLSConfig()
require.Error(t, err)
_, err = creds.SSHClientConfig()
require.Error(t, err)
require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds = creds.(CredentialsWithExpiry)
_, ok = expiringCreds.Expiry()
require.False(t, ok, "expiry should be unknown on a broken id file")
}

func TestLoadKeyPair(t *testing.T) {
Expand All @@ -176,11 +202,20 @@ func TestLoadKeyPair(t *testing.T) {
tlsConfig, err := creds.TLSConfig()
require.NoError(t, err)
requireEqualTLSConfig(t, expectedTLSConfig, tlsConfig)
require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
expiry, ok := expiringCreds.Expiry()
require.True(t, ok, "expiry should be known when we build creds from cert files")
require.Equal(t, tlsCertNotAfter, expiry)

// Load invalid keypairs.
invalidIdentityCreds := LoadKeyPair("invalid_path", "invalid_path", "invalid_path")
_, err = invalidIdentityCreds.TLSConfig()
require.Error(t, err)
require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds = invalidIdentityCreds.(CredentialsWithExpiry)
_, ok = expiringCreds.Expiry()
require.False(t, ok, "expiry should be unknown on a broken credential")
}

func TestLoadProfile(t *testing.T) {
Expand All @@ -207,6 +242,10 @@ func TestLoadProfile(t *testing.T) {
require.Error(t, err)
_, err = creds.SSHClientConfig()
require.Error(t, err)
require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
_, ok := expiringCreds.Expiry()
require.False(t, ok, "expiry should be unknown on a broken profile")
})
}

Expand All @@ -226,6 +265,12 @@ func testProfileContents(t *testing.T, dir, name string) {
sshConfig, err := creds.SSHClientConfig()
require.NoError(t, err)
requireEqualSSHConfig(t, expectedSSHConfig, sshConfig)

require.Implements(t, (*CredentialsWithExpiry)(nil), creds)
expiringCreds := creds.(CredentialsWithExpiry)
expiry, ok := expiringCreds.Expiry()
require.True(t, ok, "expiry should be known when we build creds from a profile")
require.Equal(t, tlsCertNotAfter, expiry)
}

func writeProfile(t *testing.T, p *profile.Profile) {
Expand Down Expand Up @@ -306,6 +351,7 @@ m1gfG9yqEte7pxv3yWM+7X2bzEjCBds4feahuKPNxOAOSfLUZiTpmOVlRzrpRIhu
WQdM2NXAMABGAofGrVklPIiraUoHzr0Xxpia4vQwRewYXv8bCPHW+8g8vGBGvoG2
gtLit9DL5DR5ac/CRGJt
-----END CERTIFICATE-----`)
tlsCertNotAfter = time.Date(2021, 2, 18, 8, 28, 21, 0, time.UTC)

keyPEM = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAzkUVoJ4rn2XAi2HJeBIIxlsdMPGzLroJub9eHAVspAueDJLS
Expand Down Expand Up @@ -420,20 +466,26 @@ func TestDynamicIdentityFileCreds(t *testing.T) {
require.NoError(t, err)
require.Equal(t, wantTLSCert, *gotTLSCert)

expiry, ok := cred.Expiry()
require.True(t, ok, "expiry should be known when we build creds from an identity file")
require.Equal(t, tlsCertNotAfter, expiry)

tlsCACertPEM, _ := pem.Decode(tlsCACert)
tlsCACertDER, err := x509.ParseCertificate(tlsCACertPEM.Bytes)
require.NoError(t, err)
wantCertPool := x509.NewCertPool()
wantCertPool.AddCert(tlsCACertDER)
require.True(t, wantCertPool.Equal(tlsConfig.RootCAs), "tlsconfig.RootCAs mismatch")

newExpiry := tlsCertNotAfter.Add(24 * time.Hour)
// Generate a new TLS certificate that contains the same private key as
// the original.
template := &x509.Certificate{
SerialNumber: big.NewInt(0),
Subject: pkix.Name{
CommonName: "example",
},
NotAfter: newExpiry,
KeyUsage: x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
DNSNames: []string{constants.APIDomain},
Expand Down Expand Up @@ -471,6 +523,10 @@ func TestDynamicIdentityFileCreds(t *testing.T) {
wantTLSCert, err = tls.X509KeyPair(secondTLSCertPem, keyPEM)
require.NoError(t, err)
require.Equal(t, wantTLSCert, *gotTLSCert)

expiry, ok = cred.Expiry()
require.True(t, ok, "expiry should be known when we build creds from an identity file")
require.Equal(t, newExpiry, expiry)
}

func ExampleDynamicIdentityFileCreds() {
Expand Down
5 changes: 5 additions & 0 deletions api/client/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ func (mc insecureCredentials) SSHClientConfig() (*ssh.ClientConfig, error) {
return nil, trace.NotImplemented("no ssh config")
}

// Expiry returns the credential expiry. insecureCredentials never expire.
func (mc insecureCredentials) Expiry() (time.Time, bool) {
return time.Time{}, true
}

// Client is a client to the Teleport Proxy SSH server on behalf of a user.
// The Proxy SSH port used to serve only SSH, however portions of the api are
// being migrated to gRPC to reduce latency. The Client is capable of communicating
Expand Down
13 changes: 13 additions & 0 deletions api/identityfile/identityfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"io"
"os"
"strings"
"time"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -114,6 +115,18 @@ func (i *IdentityFile) SSHClientConfig() (*ssh.ClientConfig, error) {
return ssh, nil
}

// Expiry returns the credential expiry.
func (i *IdentityFile) Expiry() (time.Time, bool) {
if i.Certs.TLS == nil {
return time.Time{}, false
}
cert, _, err := keys.X509Certificate(i.Certs.TLS)
if err != nil {
return time.Time{}, false
}
return cert.NotAfter, true
}

// Write writes the given identityFile to the specified path.
func Write(idFile *IdentityFile, path string) error {
buf := new(bytes.Buffer)
Expand Down
Loading
Loading