From 00589671b48421f0148b8b31481a8542c1500491 Mon Sep 17 00:00:00 2001 From: David Boslee Date: Mon, 11 Nov 2024 12:46:24 -0500 Subject: [PATCH] keystore: add support for custom kms tags (#48132) * keystore: add support for custom kms tags * remove utils.IsEmpty and use pointer for aws kms config * test tags len Co-authored-by: Nic Klaassen * add yaml struct tags Co-authored-by: Nic Klaassen * Update lib/auth/keystore/aws_kms_test.go Co-authored-by: Nic Klaassen * keep TeleportCluster tag if not specified * add delete used key test cases with tagging * fix unit tests * remove unused tag var * Update lib/auth/keystore/aws_kms.go Co-authored-by: Nic Klaassen --------- Co-authored-by: Nic Klaassen --- lib/auth/auth.go | 2 +- lib/auth/keystore/aws_kms.go | 40 ++++-- lib/auth/keystore/aws_kms_test.go | 213 +++++++++++++++++++---------- lib/auth/keystore/keystore_test.go | 6 +- lib/auth/keystore/manager.go | 4 +- lib/auth/keystore/testhelpers.go | 2 +- lib/config/configuration.go | 1 + lib/config/fileconf.go | 5 + lib/service/servicecfg/auth.go | 9 +- 9 files changed, 187 insertions(+), 95 deletions(-) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 5e5506e6cbb0a..7138e6e150750 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -379,7 +379,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { if !modules.GetModules().Features().GetEntitlement(entitlements.HSM).Enabled { return nil, fmt.Errorf("Google Cloud KMS support requires a license with the HSM feature enabled: %w", ErrRequiresEnterprise) } - } else if cfg.KeyStoreConfig.AWSKMS != (servicecfg.AWSKMSConfig{}) { + } else if cfg.KeyStoreConfig.AWSKMS != nil { if !modules.GetModules().Features().GetEntitlement(entitlements.HSM).Enabled { return nil, fmt.Errorf("AWS KMS support requires a license with the HSM feature enabled: %w", ErrRequiresEnterprise) } diff --git a/lib/auth/keystore/aws_kms.go b/lib/auth/keystore/aws_kms.go index 19d52069b0ea8..d8726b1ae0e48 100644 --- a/lib/auth/keystore/aws_kms.go +++ b/lib/auth/keystore/aws_kms.go @@ -64,10 +64,10 @@ type CloudClientProvider interface { type awsKMSKeystore struct { kms kmsiface.KMSAPI - clusterName types.ClusterName awsAccount string awsRegion string multiRegionEnabled bool + tags map[string]string clock clockwork.Clock logger *slog.Logger } @@ -89,14 +89,23 @@ func newAWSKMSKeystore(ctx context.Context, cfg *servicecfg.AWSKMSConfig, opts * if err != nil { return nil, trace.Wrap(err) } + + tags := cfg.Tags + if tags == nil { + tags = make(map[string]string, 1) + } + if _, ok := tags[clusterTagKey]; !ok { + tags[clusterTagKey] = opts.ClusterName.GetClusterName() + } + clock := opts.clockworkOverride if clock == nil { clock = clockwork.NewRealClock() } return &awsKMSKeystore{ - clusterName: opts.ClusterName, awsAccount: cfg.AWSAccount, awsRegion: cfg.AWSRegion, + tags: tags, multiRegionEnabled: cfg.MultiRegion.Enabled, kms: kmsClient, clock: clock, @@ -117,16 +126,18 @@ func (a *awsKMSKeystore) keyTypeDescription() string { // generateRSA creates a new RSA private key and returns its identifier and a crypto.Signer. The returned // identifier can be passed to getSigner later to get an equivalent crypto.Signer. func (a *awsKMSKeystore) generateRSA(ctx context.Context, _ ...rsaKeyOption) ([]byte, crypto.Signer, error) { + tags := make([]*kms.Tag, 0, len(a.tags)) + for k, v := range a.tags { + tags = append(tags, &kms.Tag{ + TagKey: aws.String(k), + TagValue: aws.String(v), + }) + } output, err := a.kms.CreateKey(&kms.CreateKeyInput{ Description: aws.String("Teleport CA key"), KeySpec: aws.String("RSA_2048"), KeyUsage: aws.String("SIGN_VERIFY"), - Tags: []*kms.Tag{ - { - TagKey: aws.String(clusterTagKey), - TagValue: aws.String(a.clusterName.GetClusterName()), - }, - }, + Tags: tags, MultiRegion: aws.Bool(a.multiRegionEnabled), }) if err != nil { @@ -351,11 +362,14 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by } return trace.Wrap(err, "failed to fetch tags for AWS KMS key %q", keyARN) } - if !slices.ContainsFunc(output.Tags, func(tag *kms.Tag) bool { - return aws.StringValue(tag.TagKey) == clusterTagKey && aws.StringValue(tag.TagValue) == a.clusterName.GetClusterName() - }) { - // This key was not created by this Teleport cluster, never delete it. - return nil + + // All tags must match for this key to be considered for deletion. + for k, v := range a.tags { + if !slices.ContainsFunc(output.Tags, func(tag *kms.Tag) bool { + return aws.StringValue(tag.TagKey) == k && aws.StringValue(tag.TagValue) == v + }) { + return nil + } } // Check if this key is not enabled or was created in the past 5 minutes. diff --git a/lib/auth/keystore/aws_kms_test.go b/lib/auth/keystore/aws_kms_test.go index 95fbc1ac3d9b2..a5b41677db205 100644 --- a/lib/auth/keystore/aws_kms_test.go +++ b/lib/auth/keystore/aws_kms_test.go @@ -56,80 +56,113 @@ func TestAWSKMS_DeleteUnusedKeys(t *testing.T) { ctx := context.Background() clock := clockwork.NewFakeClock() - const pageSize int = 4 - fakeKMS := newFakeAWSKMSService(t, clock, "123456789012", "us-west-2", pageSize) - cfg := servicecfg.KeystoreConfig{ - AWSKMS: servicecfg.AWSKMSConfig{ - AWSAccount: "123456789012", - AWSRegion: "us-west-2", + for _, tc := range []struct { + name string + tags map[string]string + }{ + { + name: "delete keys with default tags", }, - } - clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ClusterName: "test-cluster"}) - require.NoError(t, err) - opts := &Options{ - ClusterName: clusterName, - HostUUID: "uuid", - CloudClients: &cloud.TestCloudClients{ - KMS: fakeKMS, - STS: &fakeAWSSTSClient{ - account: "123456789012", + { + name: "delete keys with custom tags", + tags: map[string]string{ + "test-key-1": "test-value-1", }, }, - clockworkOverride: clock, - } - keyStore, err := NewManager(ctx, &cfg, opts) - require.NoError(t, err) + { + name: "delete keys with override cluster tag", + tags: map[string]string{ + "TeleportCluster": "test-cluster-2", + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + const pageSize int = 4 + fakeKMS := newFakeAWSKMSService(t, clock, "123456789012", "us-west-2", pageSize) + cfg := servicecfg.KeystoreConfig{ + AWSKMS: &servicecfg.AWSKMSConfig{ + AWSAccount: "123456789012", + AWSRegion: "us-west-2", + Tags: tc.tags, + }, + } + clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ClusterName: "test-cluster"}) + require.NoError(t, err) + opts := &Options{ + ClusterName: clusterName, + HostUUID: "uuid", + CloudClients: &cloud.TestCloudClients{ + KMS: fakeKMS, + STS: &fakeAWSSTSClient{ + account: "123456789012", + }, + }, + clockworkOverride: clock, + } + keyStore, err := NewManager(ctx, &cfg, opts) + require.NoError(t, err) - totalKeys := pageSize * 3 - for i := 0; i < totalKeys; i++ { - _, err := keyStore.NewSSHKeyPair(ctx) - require.NoError(t, err) - } + var otherTags []*kms.Tag + for k, v := range keyStore.backendForNewKeys.(*awsKMSKeystore).tags { + if k != clusterTagKey { + otherTags = append(otherTags, &kms.Tag{ + TagKey: aws.String(k), + TagValue: aws.String(v), + }) + } + } - // Newly created keys should not be deleted. - err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/) - require.NoError(t, err) - for _, key := range fakeKMS.keys { - assert.Equal(t, "Enabled", key.state) - } + totalKeys := pageSize * 3 + for i := 0; i < totalKeys; i++ { + _, err := keyStore.NewSSHKeyPair(ctx) + require.NoError(t, err, trace.DebugReport(err)) + } - // Keys created more than 5 minutes ago should be deleted. - clock.Advance(6 * time.Minute) - err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/) - require.NoError(t, err) - for _, key := range fakeKMS.keys { - assert.Equal(t, "PendingDeletion", key.state) - } + // Newly created keys should not be deleted. + err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/) + require.NoError(t, err) + for _, key := range fakeKMS.keys { + assert.Equal(t, kms.KeyStateEnabled, key.state) + } - // Insert a key created by a different Teleport cluster, it should not be - // deleted by the keystore. - output, err := fakeKMS.CreateKey(&kms.CreateKeyInput{ - Tags: []*kms.Tag{ - &kms.Tag{ - TagKey: aws.String(clusterTagKey), - TagValue: aws.String("other-cluster"), - }, - }, - }) - require.NoError(t, err) - otherClusterKeyARN := aws.StringValue(output.KeyMetadata.Arn) + // Keys created more than 5 minutes ago should be deleted. + clock.Advance(6 * time.Minute) + err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/) + require.NoError(t, err) + for _, key := range fakeKMS.keys { + assert.Equal(t, kms.KeyStatePendingDeletion, key.state) + } - clock.Advance(6 * time.Minute) - err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/) - require.NoError(t, err) - for _, key := range fakeKMS.keys { - if key.arn == otherClusterKeyARN { - assert.Equal(t, "Enabled", key.state) - } else { - assert.Equal(t, "PendingDeletion", key.state) - } + // Insert a key created by a different Teleport cluster, it should not be + // deleted by the keystore. + output, err := fakeKMS.CreateKey(&kms.CreateKeyInput{ + KeySpec: aws.String(kms.KeySpecEccNistP256), + Tags: append(otherTags, &kms.Tag{ + TagKey: aws.String(clusterTagKey), + TagValue: aws.String("other-cluster"), + }), + }) + require.NoError(t, err) + otherClusterKeyARN := aws.StringValue(output.KeyMetadata.Arn) + + clock.Advance(6 * time.Minute) + err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/) + require.NoError(t, err) + for _, key := range fakeKMS.keys { + if key.arn == otherClusterKeyARN { + assert.Equal(t, kms.KeyStateEnabled, key.state) + } else { + assert.Equal(t, kms.KeyStatePendingDeletion, key.state) + } + } + }) } } func TestAWSKMS_WrongAccount(t *testing.T) { clock := clockwork.NewFakeClock() cfg := &servicecfg.KeystoreConfig{ - AWSKMS: servicecfg.AWSKMSConfig{ + AWSKMS: &servicecfg.AWSKMSConfig{ AWSAccount: "111111111111", AWSRegion: "us-west-2", }, @@ -161,7 +194,7 @@ func TestAWSKMS_RetryWhilePending(t *testing.T) { pageLimit: 1000, } cfg := &servicecfg.KeystoreConfig{ - AWSKMS: servicecfg.AWSKMSConfig{ + AWSKMS: &servicecfg.AWSKMSConfig{ AWSAccount: "111111111111", AWSRegion: "us-west-2", }, @@ -214,13 +247,13 @@ func TestAWSKMS_RetryWhilePending(t *testing.T) { require.Error(t, err) } -// TestMultiRegionKeys asserts that a keystore created with multi-region enabled -// correctly passes this argument to the AWS client. This gives very little real -// coverage since the AWS KMS service here is faked, but at least we know the -// keystore passed the bool to the client correctly. TestBackends and -// TestManager are both able to run with a real AWS KMS client and you can -// confirm the keys are really multi-region there. -func TestMultiRegionKeys(t *testing.T) { +// TestKeyAWSKeyCreationParameters asserts that an AWS keystore created with a +// variety of parameters correctly passes these parameters to the AWS client. +// This gives very little real coverage since the AWS KMS service here is faked, +// but at least we know the keystore passed the parameters to the client correctly. +// TestBackends and TestManager are both able to run with a real AWS KMS client +// and you can confirm the keys are configured correctly there. +func TestAWSKeyCreationParameters(t *testing.T) { ctx := context.Background() clock := clockwork.NewFakeClock() @@ -240,15 +273,36 @@ func TestMultiRegionKeys(t *testing.T) { clockworkOverride: clock, } - for _, multiRegion := range []bool{false, true} { - t.Run(fmt.Sprint(multiRegion), func(t *testing.T) { + for _, tc := range []struct { + name string + multiRegion bool + tags map[string]string + }{ + { + name: "multi-region enabled with default tags", + multiRegion: true, + }, + { + name: "multi-region disabled with default tags", + multiRegion: false, + }, + { + name: "multi region disabled with custom tags", + multiRegion: false, + tags: map[string]string{ + "key": "value", + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { cfg := servicecfg.KeystoreConfig{ - AWSKMS: servicecfg.AWSKMSConfig{ + AWSKMS: &servicecfg.AWSKMSConfig{ AWSAccount: "123456789012", AWSRegion: "us-west-2", MultiRegion: struct{ Enabled bool }{ - Enabled: multiRegion, + Enabled: tc.multiRegion, }, + Tags: tc.tags, }, } keyStore, err := NewManager(ctx, &cfg, opts) @@ -260,11 +314,24 @@ func TestMultiRegionKeys(t *testing.T) { keyID, err := parseAWSKMSKeyID(sshKeyPair.PrivateKey) require.NoError(t, err) - if multiRegion { + if tc.multiRegion { assert.Contains(t, keyID.arn, "mrk-") } else { assert.NotContains(t, keyID.arn, "mrk-") } + + tagsOut, err := fakeKMS.ListResourceTags(&kms.ListResourceTagsInput{KeyId: &keyID.arn}) + require.NoError(t, err) + if len(tc.tags) == 0 { + tc.tags = map[string]string{ + "TeleportCluster": clusterName.GetClusterName(), + } + } + require.Equal(t, len(tc.tags), len(tagsOut.Tags)) + for _, tag := range tagsOut.Tags { + v := tc.tags[aws.StringValue(tag.TagKey)] + require.Equal(t, v, aws.StringValue(tag.TagValue)) + } }) } } diff --git a/lib/auth/keystore/keystore_test.go b/lib/auth/keystore/keystore_test.go index 86405cd3fc937..85610451c67b9 100644 --- a/lib/auth/keystore/keystore_test.go +++ b/lib/auth/keystore/keystore_test.go @@ -525,7 +525,7 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack { if config, ok := awsKMSTestConfig(t); ok { config.AWSKMS.MultiRegion.Enabled = multiRegion - backend, err := newAWSKMSKeystore(ctx, &config.AWSKMS, opts) + backend, err := newAWSKMSKeystore(ctx, config.AWSKMS, opts) require.NoError(t, err) name := "aws_kms" if multiRegion { @@ -552,7 +552,7 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack { // Always test with fake AWS client. fakeAWSKMSConfig := servicecfg.KeystoreConfig{ - AWSKMS: servicecfg.AWSKMSConfig{ + AWSKMS: &servicecfg.AWSKMSConfig{ AWSAccount: "123456789012", AWSRegion: "us-west-2", MultiRegion: struct{ Enabled bool }{ @@ -560,7 +560,7 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack { }, }, } - fakeAWSKMSBackend, err := newAWSKMSKeystore(ctx, &fakeAWSKMSConfig.AWSKMS, opts) + fakeAWSKMSBackend, err := newAWSKMSKeystore(ctx, fakeAWSKMSConfig.AWSKMS, opts) require.NoError(t, err) name := "fake_aws_kms" if multiRegion { diff --git a/lib/auth/keystore/manager.go b/lib/auth/keystore/manager.go index c0c712098f935..877cd3bbc4b03 100644 --- a/lib/auth/keystore/manager.go +++ b/lib/auth/keystore/manager.go @@ -215,8 +215,8 @@ func NewManager(ctx context.Context, cfg *servicecfg.KeystoreConfig, opts *Optio } backendForNewKeys = gcpBackend usableSigningBackends = []backend{gcpBackend, softwareBackend} - case cfg.AWSKMS != (servicecfg.AWSKMSConfig{}): - awsBackend, err := newAWSKMSKeystore(ctx, &cfg.AWSKMS, opts) + case cfg.AWSKMS != nil: + awsBackend, err := newAWSKMSKeystore(ctx, cfg.AWSKMS, opts) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/keystore/testhelpers.go b/lib/auth/keystore/testhelpers.go index 6c8e113a9af21..f88588548af7e 100644 --- a/lib/auth/keystore/testhelpers.go +++ b/lib/auth/keystore/testhelpers.go @@ -95,7 +95,7 @@ func awsKMSTestConfig(t *testing.T) (servicecfg.KeystoreConfig, bool) { return servicecfg.KeystoreConfig{}, false } return servicecfg.KeystoreConfig{ - AWSKMS: servicecfg.AWSKMSConfig{ + AWSKMS: &servicecfg.AWSKMSConfig{ AWSAccount: awsKMSAccount, AWSRegion: awsKMSRegion, }, diff --git a/lib/config/configuration.go b/lib/config/configuration.go index bea0b44d7d874..35bc67654a558 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -1171,6 +1171,7 @@ func applyAWSKMSConfig(kmsConfig *AWSKMS, cfg *servicecfg.Config) error { } cfg.Auth.KeyStore.AWSKMS.AWSRegion = kmsConfig.Region cfg.Auth.KeyStore.AWSKMS.MultiRegion = kmsConfig.MultiRegion + cfg.Auth.KeyStore.AWSKMS.Tags = kmsConfig.Tags return nil } diff --git a/lib/config/fileconf.go b/lib/config/fileconf.go index bb10a43d3085b..da2d7ee79463c 100644 --- a/lib/config/fileconf.go +++ b/lib/config/fileconf.go @@ -922,6 +922,11 @@ type AWSKMS struct { // Enabled configures new keys to be multi-region. Enabled bool } `yaml:"multi_region,omitempty"` + // Tags are key/value pairs used as AWS resource tags. The 'TeleportCluster' + // tag is added automatically if not specified in the set of tags. Changing tags + // after Teleport has already created KMS keys may require manually updating + // the tags of existing keys. + Tags map[string]string `yaml:"tags,omitempty"` } // TrustedCluster struct holds configuration values under "trusted_clusters" key diff --git a/lib/service/servicecfg/auth.go b/lib/service/servicecfg/auth.go index 1ecc416e3c453..810515f49a757 100644 --- a/lib/service/servicecfg/auth.go +++ b/lib/service/servicecfg/auth.go @@ -199,7 +199,7 @@ type KeystoreConfig struct { // GCPKMS holds configuration parameters specific to GCP KMS keystores. GCPKMS GCPKMSConfig // AWSKMS holds configuration parameter specific to AWS KMS keystores. - AWSKMS AWSKMSConfig + AWSKMS *AWSKMSConfig } // CheckAndSetDefaults checks that required parameters of the config are @@ -218,7 +218,7 @@ func (cfg *KeystoreConfig) CheckAndSetDefaults() error { } count++ } - if cfg.AWSKMS != (AWSKMSConfig{}) { + if cfg.AWSKMS != nil { if err := cfg.AWSKMS.CheckAndSetDefaults(); err != nil { return trace.Wrap(err, "validating aws_kms config") } @@ -294,6 +294,11 @@ type AWSKMSConfig struct { // Enabled configures new keys to be multi-region. Enabled bool } + // Tags are key/value pairs used as AWS resource tags. The 'TeleportCluster' + // tag is added automatically if not specified in the set of tags. Changing tags + // after Teleport has already created KMS keys may require manually updating + // the tags of existing keys. + Tags map[string]string } // CheckAndSetDefaults checks that required parameters of the config are