Skip to content

Commit

Permalink
CA suite support
Browse files Browse the repository at this point in the history
  • Loading branch information
nklaassen committed Jun 20, 2024
1 parent 8ce35ee commit 4754767
Show file tree
Hide file tree
Showing 16 changed files with 479 additions and 284 deletions.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/testplan.md
Original file line number Diff line number Diff line change
Expand Up @@ -821,8 +821,8 @@ $
$ # test AWS KMS
$ # login in to AWS locally
$ AWS_ACCOUNT="$(aws sts get-caller-identity | jq -r '.Account')"
$ TELEPORT_TEST_AWS_KMS_ACCOUNT="${AWS_ACCOUNT}" TELEPORT_TEST_AWS_REGION=us-west-2 go test ./lib/auth/keystore -v --count 1
$ TELEPORT_TEST_AWS_KMS_ACCOUNT="${AWS_ACCOUNT}" TELEPORT_TEST_AWS_REGION=us-west-2 TELEPORT_ETCD_TEST=1 go test ./integration/hsm -v --count 1
$ TELEPORT_TEST_AWS_KMS_ACCOUNT="${AWS_ACCOUNT}" TELEPORT_TEST_AWS_KMS_REGION=us-west-2 go test ./lib/auth/keystore -v --count 1
$ TELEPORT_TEST_AWS_KMS_ACCOUNT="${AWS_ACCOUNT}" TELEPORT_TEST_AWS_KMS_REGION=us-west-2 TELEPORT_ETCD_TEST=1 go test ./integration/hsm -v --count 1
$
$ # test AWS CloudHSM
$ # set up the CloudHSM cluster and run this on an EC2 that can reach it
Expand Down
8 changes: 8 additions & 0 deletions api/types/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ type AuthPreference interface {

// GetSignatureAlgorithmSuite gets the signature algorithm suite.
GetSignatureAlgorithmSuite() SignatureAlgorithmSuite
// SetSignatureAlgorithmSuite sets the signature algorithm suite.
SetSignatureAlgorithmSuite(SignatureAlgorithmSuite)

// String represents a human readable version of authentication settings.
String() string
Expand Down Expand Up @@ -543,10 +545,16 @@ func (c *AuthPreferenceV2) setStaticFields() {
c.Metadata.Name = MetaNameClusterAuthPreference
}

// GetSignatureAlgorithmSuite gets the signature algorithm suite.
func (c *AuthPreferenceV2) GetSignatureAlgorithmSuite() SignatureAlgorithmSuite {
return c.Spec.SignatureAlgorithmSuite
}

// SetSignatureAlgorithmSuite sets the signature algorithm suite.
func (c *AuthPreferenceV2) SetSignatureAlgorithmSuite(suite SignatureAlgorithmSuite) {
c.Spec.SignatureAlgorithmSuite = suite
}

// CheckAndSetDefaults verifies the constraints for AuthPreference.
func (c *AuthPreferenceV2) CheckAndSetDefaults() error {
c.setStaticFields()
Expand Down
2 changes: 1 addition & 1 deletion integration/hsm/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func newAuthConfig(t *testing.T, log utils.Logger) *servicecfg.Config {
}
var err error
config.Auth.ClusterName, err = services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{
ClusterName: "testcluster",
ClusterName: "test-cluster",
})
require.NoError(t, err)
config.SetAuthServerAddress(config.Auth.ListenAddr)
Expand Down
1 change: 1 addition & 0 deletions integration/hsm/hsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func newHSMAuthConfig(t *testing.T, storageConfig *backend.Config, log utils.Log
config := newAuthConfig(t, log)
config.Auth.StorageConfig = *storageConfig
config.Auth.KeyStore = keystore.HSMTestConfig(t)
config.Auth.Preference.SetSignatureAlgorithmSuite(types.SignatureAlgorithmSuite_HSM_DEV)
return config
}

Expand Down
106 changes: 63 additions & 43 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ import (
"github.com/gravitational/teleport/lib/cache"
"github.com/gravitational/teleport/lib/circleci"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/gcp"
Expand Down Expand Up @@ -337,9 +338,10 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
}

keystoreOpts := &keystore.Options{
HostUUID: cfg.HostUUID,
ClusterName: cfg.ClusterName,
CloudClients: cfg.CloudClients,
HostUUID: cfg.HostUUID,
ClusterName: cfg.ClusterName,
CloudClients: cfg.CloudClients,
AuthPreferenceGetter: cfg.ClusterConfiguration,
}
if cfg.KeyStoreConfig.PKCS11 != (servicecfg.PKCS11Config{}) {
if !modules.GetModules().Features().HSM {
Expand Down Expand Up @@ -6707,62 +6709,80 @@ func (a *Server) addAdditionalTrustedKeysAtomic(ctx context.Context, ca types.Ce
// Keep this function in sync with lib/service/suite/suite.go:NewTestCAWithConfig().
func newKeySet(ctx context.Context, keyStore *keystore.Manager, caID types.CertAuthID) (types.CAKeySet, error) {
var keySet types.CAKeySet

// Add SSH keys if necessary.
switch caID.Type {
case types.UserCA, types.HostCA:
sshKeyPair, err := keyStore.NewSSHKeyPair(ctx)
if err != nil {
return keySet, trace.Wrap(err)
}
tlsKeyPair, err := keyStore.NewTLSKeyPair(ctx, caID.DomainName)
case types.UserCA, types.HostCA, types.OpenSSHCA:
sshKeyPair, err := keyStore.NewSSHKeyPair(ctx, sshCAKeyPurpose(caID.Type))
if err != nil {
return keySet, trace.Wrap(err)
}
keySet.SSH = append(keySet.SSH, sshKeyPair)
keySet.TLS = append(keySet.TLS, tlsKeyPair)
case types.DatabaseCA, types.DatabaseClientCA:
// Database CA only contains TLS cert.
tlsKeyPair, err := keyStore.NewTLSKeyPair(ctx, caID.DomainName)
}

// Add TLS keys if necessary.
switch caID.Type {
case types.UserCA, types.HostCA, types.DatabaseCA, types.DatabaseClientCA, types.SAMLIDPCA, types.SPIFFECA:
tlsKeyPair, err := keyStore.NewTLSKeyPair(ctx, caID.DomainName, tlsCAKeyPurpose(caID.Type))
if err != nil {
return keySet, trace.Wrap(err)
}
keySet.TLS = append(keySet.TLS, tlsKeyPair)
case types.OpenSSHCA:
// OpenSSH CA only contains a SSH key pair.
sshKeyPair, err := keyStore.NewSSHKeyPair(ctx)
if err != nil {
return keySet, trace.Wrap(err)
}
keySet.SSH = append(keySet.SSH, sshKeyPair)
case types.JWTSigner, types.OIDCIdPCA:
jwtKeyPair, err := keyStore.NewJWTKeyPair(ctx)
}

// Add JWT keys if necessary.
switch caID.Type {
case types.JWTSigner, types.OIDCIdPCA, types.SPIFFECA:
jwtKeyPair, err := keyStore.NewJWTKeyPair(ctx, jwtCAKeyPurpose(caID.Type))
if err != nil {
return keySet, trace.Wrap(err)
}
keySet.JWT = append(keySet.JWT, jwtKeyPair)
}

return keySet, nil
}

func sshCAKeyPurpose(caType types.CertAuthType) cryptosuites.KeyPurpose {
switch caType {
case types.UserCA:
return cryptosuites.UserCASSH
case types.HostCA:
return cryptosuites.HostCASSH
case types.OpenSSHCA:
return cryptosuites.OpenSSHCASSH
}
return cryptosuites.KeyPurposeUnspecified
}

func tlsCAKeyPurpose(caType types.CertAuthType) cryptosuites.KeyPurpose {
switch caType {
case types.UserCA:
return cryptosuites.UserCATLS
case types.HostCA:
return cryptosuites.HostCATLS
case types.DatabaseCA:
return cryptosuites.DatabaseCATLS
case types.DatabaseClientCA:
return cryptosuites.DatabaseClientCATLS
case types.SAMLIDPCA:
// SAML IDP CA only contains TLS certs.
tlsKeyPair, err := keyStore.NewTLSKeyPair(ctx, caID.DomainName)
if err != nil {
return keySet, trace.Wrap(err)
}
keySet.TLS = append(keySet.TLS, tlsKeyPair)
return cryptosuites.SAMLIdPCATLS
case types.SPIFFECA:
tlsKeyPair, err := keyStore.NewTLSKeyPair(ctx, caID.DomainName)
if err != nil {
return keySet, trace.Wrap(err)
}
keySet.TLS = append(keySet.TLS, tlsKeyPair)
// Whilst we don't currently support JWT-SVIDs, we will eventually. So
// generate a JWT keypair.
jwtKeyPair, err := keyStore.NewJWTKeyPair(ctx)
if err != nil {
return keySet, trace.Wrap(err)
}
keySet.JWT = append(keySet.JWT, jwtKeyPair)
default:
return keySet, trace.BadParameter("unknown ca type: %s", caID.Type)
return cryptosuites.SPIFFECATLS
}
return keySet, nil
return cryptosuites.KeyPurposeUnspecified
}

func jwtCAKeyPurpose(caType types.CertAuthType) cryptosuites.KeyPurpose {
switch caType {
case types.JWTSigner:
return cryptosuites.JWTCAJWT
case types.OIDCIdPCA:
return cryptosuites.OIDCIdPCAJWT
case types.SPIFFECA:
return cryptosuites.SPIFFECAJWT
}
return cryptosuites.KeyPurposeUnspecified
}

// ensureLocalAdditionalKeys adds additional trusted keys to the CA if they are not
Expand Down
45 changes: 37 additions & 8 deletions lib/auth/keystore/aws_kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package keystore
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"errors"
"fmt"
Expand All @@ -43,6 +45,7 @@ import (
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/cloud"
awslib "github.com/gravitational/teleport/lib/cloud/aws"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/service/servicecfg"
)

Expand Down Expand Up @@ -108,13 +111,20 @@ func (a *awsKMSKeystore) keyTypeDescription() string {
return fmt.Sprintf("AWS KMS keys in account %s and region %s", a.awsAccount, a.awsRegion)
}

// generateRSA creates a new RSA private key and returns its identifier and a crypto.Signer. The returned
// generateKey creates a new private key and returns its identifier and a crypto.Signer. The returned
// identifier can be passed to getSigner later to get an equivalent crypto.Signer.
func (a *awsKMSKeystore) generateRSA(ctx context.Context, _ ...rsaKeyOption) ([]byte, crypto.Signer, error) {
func (a *awsKMSKeystore) generateKey(ctx context.Context, algorithm cryptosuites.Algorithm, opts ...rsaKeyOption) ([]byte, crypto.Signer, error) {
alg, err := awsAlgorithm(algorithm)
if err != nil {
return nil, nil, trace.Wrap(err)
}

a.logger.InfoContext(ctx, "Creating new AWS KMS keypair.", "algorithm", alg)

output, err := a.kms.CreateKey(&kms.CreateKeyInput{
Description: aws.String("Teleport CA key"),
KeySpec: aws.String("RSA_2048"),
KeyUsage: aws.String("SIGN_VERIFY"),
KeySpec: &alg,
KeyUsage: aws.String(kms.KeyUsageTypeSignVerify),
Tags: []*kms.Tag{
{
TagKey: aws.String(clusterTagKey),
Expand All @@ -141,6 +151,16 @@ func (a *awsKMSKeystore) generateRSA(ctx context.Context, _ ...rsaKeyOption) ([]
return keyID, signer, nil
}

func awsAlgorithm(alg cryptosuites.Algorithm) (string, error) {
switch alg {
case cryptosuites.RSA2048:
return kms.KeySpecRsa2048, nil
case cryptosuites.ECDSAP256:
return kms.KeySpecEccNistP256, nil
}
return "", trace.BadParameter("unsupported algorithm: %v", alg)
}

// getSigner returns a crypto.Signer for the given key identifier, if it is found.
func (a *awsKMSKeystore) getSigner(ctx context.Context, rawKey []byte, publicKey crypto.PublicKey) (crypto.Signer, error) {
keyID, err := parseAWSKMSKeyID(rawKey)
Expand Down Expand Up @@ -236,16 +256,23 @@ func (a *awsKMSSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpt
var signingAlg string
switch opts.HashFunc() {
case crypto.SHA256:
signingAlg = "RSASSA_PKCS1_V1_5_SHA_256"
switch a.pub.(type) {
case *rsa.PublicKey:
signingAlg = kms.SigningAlgorithmSpecRsassaPkcs1V15Sha256
case *ecdsa.PublicKey:
signingAlg = kms.SigningAlgorithmSpecEcdsaSha256
default:
return nil, trace.BadParameter("unsupported hash func %q for AWS KMS key type %T", opts.HashFunc(), a.pub)
}
case crypto.SHA512:
signingAlg = "RSASSA_PKCS1_V1_5_SHA_512"
signingAlg = kms.SigningAlgorithmSpecRsassaPkcs1V15Sha512
default:
return nil, trace.BadParameter("unsupported hash func %q for AWS KMS key", opts.HashFunc())
}
output, err := a.kms.Sign(&kms.SignInput{
KeyId: aws.String(a.keyARN),
Message: digest,
MessageType: aws.String("DIGEST"),
MessageType: aws.String(kms.MessageTypeDigest),
SigningAlgorithm: aws.String(signingAlg),
})
if err != nil {
Expand Down Expand Up @@ -344,8 +371,10 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by
}
return trace.Wrap(err, "failed to fetch tags for AWS KMS key %q", keyARN)
}

clusterName := a.clusterName.GetClusterName()
if !slices.ContainsFunc(output.Tags, func(tag *kms.Tag) bool {
return aws.StringValue(tag.TagKey) == clusterTagKey && aws.StringValue(tag.TagValue) == a.clusterName.GetClusterName()
return aws.StringValue(tag.TagKey) == clusterTagKey && aws.StringValue(tag.TagValue) == clusterName
}) {
// This key was not created by this Teleport cluster, never delete it.
return nil
Expand Down
Loading

0 comments on commit 4754767

Please sign in to comment.