Skip to content

Commit

Permalink
update AWS KMS to aws-sdk-go-v2
Browse files Browse the repository at this point in the history
  • Loading branch information
nklaassen committed Sep 24, 2024
1 parent 3b80232 commit d305dae
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 165 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/eks v1.48.2
github.com/aws/aws-sdk-go-v2/service/glue v1.95.0
github.com/aws/aws-sdk-go-v2/service/iam v1.35.0
github.com/aws/aws-sdk-go-v2/service/kms v1.35.3
github.com/aws/aws-sdk-go-v2/service/rds v1.82.2
github.com/aws/aws-sdk-go-v2/service/redshift v1.46.6
github.com/aws/aws-sdk-go-v2/service/s3 v1.61.0
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
keystoreOpts := &keystore.Options{
HostUUID: cfg.HostUUID,
ClusterName: cfg.ClusterName,
CloudClients: cfg.CloudClients,
AuthPreferenceGetter: cfg.ClusterConfiguration,
FIPS: cfg.FIPS,
}
if cfg.KeyStoreConfig.PKCS11 != (servicecfg.PKCS11Config{}) {
if !modules.GetModules().Features().GetEntitlement(entitlements.HSM).Enabled {
Expand Down
154 changes: 84 additions & 70 deletions lib/auth/keystore/aws_kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,18 @@ import (
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/kms"
kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"golang.org/x/sync/errgroup"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/cloud"
awslib "github.com/gravitational/teleport/lib/cloud/aws"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/service/servicecfg"
)
Expand All @@ -58,15 +56,8 @@ const (
pendingKeyTimeout = 30 * time.Second
)

type CloudClientProvider interface {
// GetAWSSTSClient returns AWS STS client for the specified region.
GetAWSSTSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (stsiface.STSAPI, error)
// GetAWSKMSClient returns AWS KMS client for the specified region.
GetAWSKMSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (kmsiface.KMSAPI, error)
}

type awsKMSKeystore struct {
kms kmsiface.KMSAPI
kms kmsClient
clusterName types.ClusterName
awsAccount string
awsRegion string
Expand All @@ -75,21 +66,33 @@ type awsKMSKeystore struct {
}

func newAWSKMSKeystore(ctx context.Context, cfg *servicecfg.AWSKMSConfig, opts *Options) (*awsKMSKeystore, error) {
stsClient, err := opts.CloudClients.GetAWSSTSClient(ctx, cfg.AWSRegion, cloud.WithAmbientCredentials())
if err != nil {
return nil, trace.Wrap(err)
stsClient, kmsClient := opts.awsSTSClient, opts.awsKMSClient
if stsClient == nil || kmsClient == nil {
useFIPSEndpoint := aws.FIPSEndpointStateUnset
if opts.FIPS {
useFIPSEndpoint = aws.FIPSEndpointStateEnabled
}
awsCfg, err := config.LoadDefaultConfig(ctx,
config.WithRegion(cfg.AWSRegion),
config.WithUseFIPSEndpoint(useFIPSEndpoint),
)
if err != nil {
return nil, trace.Wrap(err, "loading default AWS config")
}
if stsClient == nil {
stsClient = sts.NewFromConfig(awsCfg)
}
if kmsClient == nil {
kmsClient = kms.NewFromConfig(awsCfg)
}
}
id, err := stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{})
id, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return nil, trace.Wrap(err)
return nil, trace.Wrap(err, "checking AWS account of local credentials for AWS KMS")
}
if aws.StringValue(id.Account) != cfg.AWSAccount {
if aws.ToString(id.Account) != cfg.AWSAccount {
return nil, trace.BadParameter("configured AWS KMS account %q does not match AWS account of ambient credentials %q",
cfg.AWSAccount, aws.StringValue(id.Account))
}
kmsClient, err := opts.CloudClients.GetAWSKMSClient(ctx, cfg.AWSRegion, cloud.WithAmbientCredentials())
if err != nil {
return nil, trace.Wrap(err)
cfg.AWSAccount, aws.ToString(id.Account))
}
clock := opts.clockworkOverride
if clock == nil {
Expand Down Expand Up @@ -125,11 +128,11 @@ func (a *awsKMSKeystore) generateKey(ctx context.Context, algorithm cryptosuites

a.logger.InfoContext(ctx, "Creating new AWS KMS keypair.", "algorithm", alg)

output, err := a.kms.CreateKey(&kms.CreateKeyInput{
output, err := a.kms.CreateKey(ctx, &kms.CreateKeyInput{
Description: aws.String("Teleport CA key"),
KeySpec: &alg,
KeyUsage: aws.String(kms.KeyUsageTypeSignVerify),
Tags: []*kms.Tag{
KeySpec: alg,
KeyUsage: kmstypes.KeyUsageTypeSignVerify,
Tags: []kmstypes.Tag{
{
TagKey: aws.String(clusterTagKey),
TagValue: aws.String(a.clusterName.GetClusterName()),
Expand All @@ -142,7 +145,7 @@ func (a *awsKMSKeystore) generateKey(ctx context.Context, algorithm cryptosuites
if output.KeyMetadata == nil {
return nil, nil, trace.Errorf("KeyMetadata of generated key is nil")
}
keyARN := aws.StringValue(output.KeyMetadata.Arn)
keyARN := aws.ToString(output.KeyMetadata.Arn)
signer, err := a.newSigner(ctx, keyARN)
if err != nil {
return nil, nil, trace.Wrap(err)
Expand All @@ -155,14 +158,14 @@ func (a *awsKMSKeystore) generateKey(ctx context.Context, algorithm cryptosuites
return keyID, signer, nil
}

func awsAlgorithm(alg cryptosuites.Algorithm) (string, error) {
func awsAlgorithm(alg cryptosuites.Algorithm) (kmstypes.KeySpec, error) {
switch alg {
case cryptosuites.RSA2048:
return kms.KeySpecRsa2048, nil
return kmstypes.KeySpecRsa2048, nil
case cryptosuites.ECDSAP256:
return kms.KeySpecEccNistP256, nil
return kmstypes.KeySpecEccNistP256, nil
}
return "", trace.BadParameter("unsupported algorithm: %v", alg)
return "", trace.BadParameter("unsupported algorithm for AWS KMS: %v", alg)
}

// getSigner returns a crypto.Signer for the given key identifier, if it is found.
Expand All @@ -177,7 +180,7 @@ func (a *awsKMSKeystore) getSigner(ctx context.Context, rawKey []byte, publicKey
type awsKMSSigner struct {
keyARN string
pub crypto.PublicKey
kms kmsiface.KMSAPI
kms kmsClient
}

func (a *awsKMSKeystore) newSigner(ctx context.Context, keyARN string) (*awsKMSSigner, error) {
Expand Down Expand Up @@ -211,7 +214,7 @@ func (a *awsKMSKeystore) getPublicKeyDER(ctx context.Context, keyARN string) ([]
timeout := a.clock.NewTimer(pendingKeyTimeout)
defer timeout.Stop()
for {
output, err := a.kms.GetPublicKeyWithContext(ctx, &kms.GetPublicKeyInput{
output, err := a.kms.GetPublicKey(ctx, &kms.GetPublicKeyInput{
KeyId: aws.String(keyARN),
})
if err == nil {
Expand All @@ -222,8 +225,8 @@ func (a *awsKMSKeystore) getPublicKeyDER(ctx context.Context, keyARN string) ([]
// error types
// https://docs.aws.amazon.com/kms/latest/developerguide/programming-eventual-consistency.html
var (
notFound *kms.NotFoundException
invalidState *kms.InvalidStateException
notFound *kmstypes.NotFoundException
invalidState *kmstypes.KMSInvalidStateException
)
if !errors.As(err, &notFound) && !errors.As(err, &invalidState) {
return nil, trace.Wrap(err, "unexpected error fetching AWS KMS public key")
Expand Down Expand Up @@ -257,34 +260,34 @@ func (a *awsKMSSigner) Public() crypto.PublicKey {

// Sign signs the message digest.
func (a *awsKMSSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
var signingAlg string
var signingAlg kmstypes.SigningAlgorithmSpec
switch opts.HashFunc() {
case crypto.SHA256:
switch a.pub.(type) {
case *rsa.PublicKey:
signingAlg = kms.SigningAlgorithmSpecRsassaPkcs1V15Sha256
signingAlg = kmstypes.SigningAlgorithmSpecRsassaPkcs1V15Sha256
case *ecdsa.PublicKey:
signingAlg = kms.SigningAlgorithmSpecEcdsaSha256
signingAlg = kmstypes.SigningAlgorithmSpecEcdsaSha256
default:
return nil, trace.BadParameter("unsupported hash func %q for AWS KMS key type %T", opts.HashFunc(), a.pub)
}
case crypto.SHA512:
switch a.pub.(type) {
case *rsa.PublicKey:
signingAlg = kms.SigningAlgorithmSpecRsassaPkcs1V15Sha512
signingAlg = kmstypes.SigningAlgorithmSpecRsassaPkcs1V15Sha512
case *ecdsa.PublicKey:
signingAlg = kms.SigningAlgorithmSpecEcdsaSha512
signingAlg = kmstypes.SigningAlgorithmSpecEcdsaSha512
default:
return nil, trace.BadParameter("unsupported hash func %q for AWS KMS key type %T", opts.HashFunc(), a.pub)
}
default:
return nil, trace.BadParameter("unsupported hash func %q for AWS KMS key", opts.HashFunc())
}
output, err := a.kms.Sign(&kms.SignInput{
output, err := a.kms.Sign(context.TODO(), &kms.SignInput{
KeyId: aws.String(a.keyARN),
Message: digest,
MessageType: aws.String(kms.MessageTypeDigest),
SigningAlgorithm: aws.String(signingAlg),
MessageType: kmstypes.MessageTypeDigest,
SigningAlgorithm: signingAlg,
})
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -298,9 +301,9 @@ func (a *awsKMSKeystore) deleteKey(ctx context.Context, rawKey []byte) error {
if err != nil {
return trace.Wrap(err)
}
_, err = a.kms.ScheduleKeyDeletion(&kms.ScheduleKeyDeletionInput{
_, err = a.kms.ScheduleKeyDeletion(ctx, &kms.ScheduleKeyDeletionInput{
KeyId: aws.String(keyID.arn),
PendingWindowInDays: aws.Int64(7),
PendingWindowInDays: aws.Int32(7),
})
return trace.Wrap(err, "error deleting AWS KMS key")
}
Expand Down Expand Up @@ -370,29 +373,26 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by
}

// Check if this key was created by this Teleport cluster.
output, err := a.kms.ListResourceTagsWithContext(ctx, &kms.ListResourceTagsInput{
output, err := a.kms.ListResourceTags(ctx, &kms.ListResourceTagsInput{
KeyId: aws.String(keyARN),
})
if err != nil {
err = awslib.ConvertRequestFailureError(err)
if trace.IsAccessDenied(err) {
// It's entirely expected that we'll not be allowed to fetch
// tags for some keys, don't worry about deleting those.
return nil
}
return trace.Wrap(err, "failed to fetch tags for AWS KMS key %q", keyARN)
// It's entirely expected that we won't be allowed to fetch
// tags for some keys, don't worry about deleting those.
a.logger.DebugContext(ctx, "failed to fetch tags for AWS KMS key, skipping", "key_arn", keyARN, "error", err)
return nil
}

clusterName := a.clusterName.GetClusterName()
if !slices.ContainsFunc(output.Tags, func(tag *kms.Tag) bool {
return aws.StringValue(tag.TagKey) == clusterTagKey && aws.StringValue(tag.TagValue) == clusterName
if !slices.ContainsFunc(output.Tags, func(tag kmstypes.Tag) bool {
return aws.ToString(tag.TagKey) == clusterTagKey && aws.ToString(tag.TagValue) == clusterName
}) {
// This key was not created by this Teleport cluster, never delete it.
return nil
}

// Check if this key is not enabled or was created in the past 5 minutes.
describeOutput, err := a.kms.DescribeKeyWithContext(ctx, &kms.DescribeKeyInput{
describeOutput, err := a.kms.DescribeKey(ctx, &kms.DescribeKeyInput{
KeyId: aws.String(keyARN),
})
if err != nil {
Expand All @@ -401,12 +401,12 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by
if describeOutput.KeyMetadata == nil {
return trace.Errorf("failed to describe AWS KMS key %q", keyARN)
}
if keyState := aws.StringValue(describeOutput.KeyMetadata.KeyState); keyState != "Enabled" {
if keyState := describeOutput.KeyMetadata.KeyState; keyState != kmstypes.KeyStateEnabled {
a.logger.InfoContext(ctx, "deleteUnusedKeys skipping AWS KMS key which is not in enabled state.",
"key_arn", keyARN, "key_state", keyState)
return nil
}
creationDate := aws.TimeValue(describeOutput.KeyMetadata.CreationDate)
creationDate := aws.ToTime(describeOutput.KeyMetadata.CreationDate)
if a.clock.Now().Sub(creationDate).Abs() < 5*time.Minute {
// Never delete keys created in the last 5 minutes in case they were
// created by a different auth server and just haven't been added to
Expand Down Expand Up @@ -438,9 +438,9 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by

for _, keyARN := range keysToDelete {
a.logger.InfoContext(ctx, "Deleting unused AWS KMS key.", "key_arn", keyARN)
if _, err := a.kms.ScheduleKeyDeletion(&kms.ScheduleKeyDeletionInput{
if _, err := a.kms.ScheduleKeyDeletion(ctx, &kms.ScheduleKeyDeletionInput{
KeyId: aws.String(keyARN),
PendingWindowInDays: aws.Int64(7),
PendingWindowInDays: aws.Int32(7),
}); err != nil {
return trace.Wrap(err, "failed to schedule AWS KMS key %q for deletion", keyARN)
}
Expand All @@ -459,17 +459,17 @@ func (a *awsKMSKeystore) forEachKey(ctx context.Context, fn func(ctx context.Con
if marker != "" {
markerInput = aws.String(marker)
}
output, err := a.kms.ListKeysWithContext(ctx, &kms.ListKeysInput{
output, err := a.kms.ListKeys(ctx, &kms.ListKeysInput{
Marker: markerInput,
Limit: aws.Int64(1000),
Limit: aws.Int32(1000),
})
if err != nil {
return trace.Wrap(err, "failed to list AWS KMS keys")
}
marker = aws.StringValue(output.NextMarker)
more = aws.BoolValue(output.Truncated)
marker = aws.ToString(output.NextMarker)
more = output.Truncated
for _, keyEntry := range output.Keys {
keyArn := aws.StringValue(keyEntry.KeyArn)
keyArn := aws.ToString(keyEntry.KeyArn)
errGroup.Go(func() error {
return trace.Wrap(fn(ctx, keyArn))
})
Expand Down Expand Up @@ -501,3 +501,17 @@ func parseAWSKMSKeyID(raw []byte) (awsKMSKeyID, error) {
region: parsedARN.Region,
}, nil
}

type kmsClient interface {
CreateKey(context.Context, *kms.CreateKeyInput, ...func(*kms.Options)) (*kms.CreateKeyOutput, error)
GetPublicKey(context.Context, *kms.GetPublicKeyInput, ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error)
ListKeys(context.Context, *kms.ListKeysInput, ...func(*kms.Options)) (*kms.ListKeysOutput, error)
ScheduleKeyDeletion(context.Context, *kms.ScheduleKeyDeletionInput, ...func(*kms.Options)) (*kms.ScheduleKeyDeletionOutput, error)
DescribeKey(context.Context, *kms.DescribeKeyInput, ...func(*kms.Options)) (*kms.DescribeKeyOutput, error)
ListResourceTags(context.Context, *kms.ListResourceTagsInput, ...func(*kms.Options)) (*kms.ListResourceTagsOutput, error)
Sign(context.Context, *kms.SignInput, ...func(*kms.Options)) (*kms.SignOutput, error)
}

type stsClient interface {
GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error)
}
Loading

0 comments on commit d305dae

Please sign in to comment.