Skip to content

Commit

Permalink
keystore: add support for custom kms tags (#48132) (#48771)
Browse files Browse the repository at this point in the history
* keystore: add support for custom kms tags

* remove utils.IsEmpty and use pointer for aws kms config

* test tags len



* add yaml struct tags



* Update lib/auth/keystore/aws_kms_test.go



* keep TeleportCluster tag if not specified

* add delete used key test cases with tagging

* fix unit tests

* remove unused tag var

* Update lib/auth/keystore/aws_kms.go



---------

Co-authored-by: Nic Klaassen <[email protected]>
  • Loading branch information
dboslee and nklaassen authored Nov 12, 2024
1 parent 1cf82ed commit 00ec4da
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 100 deletions.
2 changes: 1 addition & 1 deletion lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
if !modules.GetModules().Features().GetEntitlement(entitlements.HSM).Enabled {
return nil, fmt.Errorf("Google Cloud KMS support requires a license with the HSM feature enabled: %w", ErrRequiresEnterprise)
}
} else if cfg.KeyStoreConfig.AWSKMS != (servicecfg.AWSKMSConfig{}) {
} else if cfg.KeyStoreConfig.AWSKMS != nil {
if !modules.GetModules().Features().GetEntitlement(entitlements.HSM).Enabled {
return nil, fmt.Errorf("AWS KMS support requires a license with the HSM feature enabled: %w", ErrRequiresEnterprise)
}
Expand Down
41 changes: 27 additions & 14 deletions lib/auth/keystore/aws_kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ const (

type awsKMSKeystore struct {
kms kmsClient
clusterName types.ClusterName
awsAccount string
awsRegion string
multiRegionEnabled bool
tags map[string]string
clock clockwork.Clock
logger *slog.Logger
}
Expand Down Expand Up @@ -95,14 +95,23 @@ func newAWSKMSKeystore(ctx context.Context, cfg *servicecfg.AWSKMSConfig, opts *
return nil, trace.BadParameter("configured AWS KMS account %q does not match AWS account of ambient credentials %q",
cfg.AWSAccount, aws.ToString(id.Account))
}

tags := cfg.Tags
if tags == nil {
tags = make(map[string]string, 1)
}
if _, ok := tags[clusterTagKey]; !ok {
tags[clusterTagKey] = opts.ClusterName.GetClusterName()
}

clock := opts.clockworkOverride
if clock == nil {
clock = clockwork.NewRealClock()
}
return &awsKMSKeystore{
clusterName: opts.ClusterName,
awsAccount: cfg.AWSAccount,
awsRegion: cfg.AWSRegion,
tags: tags,
multiRegionEnabled: cfg.MultiRegion.Enabled,
kms: kmsClient,
clock: clock,
Expand Down Expand Up @@ -132,16 +141,19 @@ func (a *awsKMSKeystore) generateKey(ctx context.Context, algorithm cryptosuites
slog.Any("algorithm", algorithm),
slog.Bool("multi_region", a.multiRegionEnabled))

tags := make([]kmstypes.Tag, 0, len(a.tags))
for k, v := range a.tags {
tags = append(tags, kmstypes.Tag{
TagKey: aws.String(k),
TagValue: aws.String(v),
})
}

output, err := a.kms.CreateKey(ctx, &kms.CreateKeyInput{
Description: aws.String("Teleport CA key"),
KeySpec: alg,
KeyUsage: kmstypes.KeyUsageTypeSignVerify,
Tags: []kmstypes.Tag{
{
TagKey: aws.String(clusterTagKey),
TagValue: aws.String(a.clusterName.GetClusterName()),
},
},
Tags: tags,
MultiRegion: aws.Bool(a.multiRegionEnabled),
})
if err != nil {
Expand Down Expand Up @@ -388,12 +400,13 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by
return nil
}

clusterName := a.clusterName.GetClusterName()
if !slices.ContainsFunc(output.Tags, func(tag kmstypes.Tag) bool {
return aws.ToString(tag.TagKey) == clusterTagKey && aws.ToString(tag.TagValue) == clusterName
}) {
// This key was not created by this Teleport cluster, never delete it.
return nil
// All tags must match for this key to be considered for deletion.
for k, v := range a.tags {
if !slices.ContainsFunc(output.Tags, func(tag kmstypes.Tag) bool {
return aws.ToString(tag.TagKey) == k && aws.ToString(tag.TagValue) == v
}) {
return nil
}
}

// Check if this key is not enabled or was created in the past 5 minutes.
Expand Down
215 changes: 141 additions & 74 deletions lib/auth/keystore/aws_kms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,80 +55,112 @@ func TestAWSKMS_DeleteUnusedKeys(t *testing.T) {
ctx := context.Background()
clock := clockwork.NewFakeClock()

const pageSize int = 4
fakeKMS := newFakeAWSKMSService(t, clock, "123456789012", "us-west-2", pageSize)
cfg := servicecfg.KeystoreConfig{
AWSKMS: servicecfg.AWSKMSConfig{
AWSAccount: "123456789012",
AWSRegion: "us-west-2",
for _, tc := range []struct {
name string
tags map[string]string
}{
{
name: "delete keys with default tags",
},
}
clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ClusterName: "test-cluster"})
require.NoError(t, err)
opts := &Options{
ClusterName: clusterName,
HostUUID: "uuid",
AuthPreferenceGetter: &fakeAuthPreferenceGetter{types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_HSM_V1},
awsKMSClient: fakeKMS,
awsSTSClient: &fakeAWSSTSClient{
account: "123456789012",
{
name: "delete keys with custom tags",
tags: map[string]string{
"test-key-1": "test-value-1",
},
},
clockworkOverride: clock,
}
keyStore, err := NewManager(ctx, &cfg, opts)
require.NoError(t, err)
{
name: "delete keys with override cluster tag",
tags: map[string]string{
"TeleportCluster": "test-cluster-2",
},
},
} {
t.Run(tc.name, func(t *testing.T) {
const pageSize int = 4
fakeKMS := newFakeAWSKMSService(t, clock, "123456789012", "us-west-2", pageSize)
cfg := servicecfg.KeystoreConfig{
AWSKMS: &servicecfg.AWSKMSConfig{
AWSAccount: "123456789012",
AWSRegion: "us-west-2",
Tags: tc.tags,
},
}
clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ClusterName: "test-cluster"})
require.NoError(t, err)
opts := &Options{
ClusterName: clusterName,
HostUUID: "uuid",
AuthPreferenceGetter: &fakeAuthPreferenceGetter{types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_HSM_V1},
awsKMSClient: fakeKMS,
awsSTSClient: &fakeAWSSTSClient{
account: "123456789012",
},
clockworkOverride: clock,
}
keyStore, err := NewManager(ctx, &cfg, opts)
require.NoError(t, err)

totalKeys := pageSize * 3
for i := 0; i < totalKeys; i++ {
_, err := keyStore.NewSSHKeyPair(ctx, cryptosuites.UserCASSH)
require.NoError(t, err, trace.DebugReport(err))
}
var otherTags []kmstypes.Tag
for k, v := range keyStore.backendForNewKeys.(*awsKMSKeystore).tags {
if k != clusterTagKey {
otherTags = append(otherTags, kmstypes.Tag{
TagKey: aws.String(k),
TagValue: aws.String(v),
})
}
}

// Newly created keys should not be deleted.
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
assert.Equal(t, kmstypes.KeyStateEnabled, key.state)
}
totalKeys := pageSize * 3
for i := 0; i < totalKeys; i++ {
_, err := keyStore.NewSSHKeyPair(ctx, cryptosuites.UserCASSH)
require.NoError(t, err, trace.DebugReport(err))
}

// Keys created more than 5 minutes ago should be deleted.
clock.Advance(6 * time.Minute)
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
assert.Equal(t, kmstypes.KeyStatePendingDeletion, key.state)
}
// Newly created keys should not be deleted.
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
assert.Equal(t, kmstypes.KeyStateEnabled, key.state)
}

// Insert a key created by a different Teleport cluster, it should not be
// deleted by the keystore.
output, err := fakeKMS.CreateKey(ctx, &kms.CreateKeyInput{
KeySpec: kmstypes.KeySpecEccNistP256,
Tags: []kmstypes.Tag{
kmstypes.Tag{
TagKey: aws.String(clusterTagKey),
TagValue: aws.String("other-cluster"),
},
},
})
require.NoError(t, err)
otherClusterKeyARN := aws.ToString(output.KeyMetadata.Arn)
// Keys created more than 5 minutes ago should be deleted.
clock.Advance(6 * time.Minute)
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
assert.Equal(t, kmstypes.KeyStatePendingDeletion, key.state)
}

clock.Advance(6 * time.Minute)
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
if key.arn == otherClusterKeyARN {
assert.Equal(t, kmstypes.KeyStateEnabled, key.state)
} else {
assert.Equal(t, kmstypes.KeyStatePendingDeletion, key.state)
}
// Insert a key created by a different Teleport cluster, it should not be
// deleted by the keystore.
output, err := fakeKMS.CreateKey(ctx, &kms.CreateKeyInput{
KeySpec: kmstypes.KeySpecEccNistP256,
Tags: append(otherTags, kmstypes.Tag{
TagKey: aws.String(clusterTagKey),
TagValue: aws.String("other-cluster"),
}),
})
require.NoError(t, err)
otherClusterKeyARN := aws.ToString(output.KeyMetadata.Arn)

clock.Advance(6 * time.Minute)
err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/)
require.NoError(t, err)
for _, key := range fakeKMS.keys {
if key.arn == otherClusterKeyARN {
assert.Equal(t, kmstypes.KeyStateEnabled, key.state)
} else {
assert.Equal(t, kmstypes.KeyStatePendingDeletion, key.state)
}
}
})
}
}

func TestAWSKMS_WrongAccount(t *testing.T) {
clock := clockwork.NewFakeClock()
cfg := &servicecfg.KeystoreConfig{
AWSKMS: servicecfg.AWSKMSConfig{
AWSKMS: &servicecfg.AWSKMSConfig{
AWSAccount: "111111111111",
AWSRegion: "us-west-2",
},
Expand Down Expand Up @@ -159,7 +191,7 @@ func TestAWSKMS_RetryWhilePending(t *testing.T) {
pageLimit: 1000,
}
cfg := &servicecfg.KeystoreConfig{
AWSKMS: servicecfg.AWSKMSConfig{
AWSKMS: &servicecfg.AWSKMSConfig{
AWSAccount: "111111111111",
AWSRegion: "us-west-2",
},
Expand Down Expand Up @@ -211,13 +243,13 @@ func TestAWSKMS_RetryWhilePending(t *testing.T) {
require.Error(t, err)
}

// TestMultiRegionKeys asserts that a keystore created with multi-region enabled
// correctly passes this argument to the AWS client. This gives very little real
// coverage since the AWS KMS service here is faked, but at least we know the
// keystore passed the bool to the client correctly. TestBackends and
// TestManager are both able to run with a real AWS KMS client and you can
// confirm the keys are really multi-region there.
func TestMultiRegionKeys(t *testing.T) {
// TestKeyAWSKeyCreationParameters asserts that an AWS keystore created with a
// variety of parameters correctly passes these parameters to the AWS client.
// This gives very little real coverage since the AWS KMS service here is faked,
// but at least we know the keystore passed the parameters to the client correctly.
// TestBackends and TestManager are both able to run with a real AWS KMS client
// and you can confirm the keys are configured correctly there.
func TestAWSKeyCreationParameters(t *testing.T) {
ctx := context.Background()
clock := clockwork.NewFakeClock()

Expand All @@ -236,15 +268,37 @@ func TestMultiRegionKeys(t *testing.T) {
clockworkOverride: clock,
}

for _, multiRegion := range []bool{false, true} {
t.Run(fmt.Sprint(multiRegion), func(t *testing.T) {
for _, tc := range []struct {
name string
multiRegion bool
tags map[string]string
}{
{
name: "multi-region enabled with default tags",
multiRegion: true,
},
{
name: "multi-region disabled with default tags",
multiRegion: false,
},
{
name: "multi region disabled with custom tags",
multiRegion: false,
tags: map[string]string{
"key": "value",
},
},
} {

t.Run(tc.name, func(t *testing.T) {
cfg := servicecfg.KeystoreConfig{
AWSKMS: servicecfg.AWSKMSConfig{
AWSKMS: &servicecfg.AWSKMSConfig{
AWSAccount: "123456789012",
AWSRegion: "us-west-2",
MultiRegion: struct{ Enabled bool }{
Enabled: multiRegion,
Enabled: tc.multiRegion,
},
Tags: tc.tags,
},
}
keyStore, err := NewManager(ctx, &cfg, opts)
Expand All @@ -256,11 +310,24 @@ func TestMultiRegionKeys(t *testing.T) {
keyID, err := parseAWSKMSKeyID(sshKeyPair.PrivateKey)
require.NoError(t, err)

if multiRegion {
if tc.multiRegion {
assert.Contains(t, keyID.arn, "mrk-")
} else {
assert.NotContains(t, keyID.arn, "mrk-")
}

tagsOut, err := fakeKMS.ListResourceTags(ctx, &kms.ListResourceTagsInput{KeyId: &keyID.arn})
require.NoError(t, err)
if len(tc.tags) == 0 {
tc.tags = map[string]string{
"TeleportCluster": clusterName.GetClusterName(),
}
}
require.Equal(t, len(tc.tags), len(tagsOut.Tags))
for _, tag := range tagsOut.Tags {
v := tc.tags[aws.ToString(tag.TagKey)]
require.Equal(t, v, aws.ToString(tag.TagValue))
}
})
}
}
Expand Down
6 changes: 3 additions & 3 deletions lib/auth/keystore/keystore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack {
opts.awsKMSClient = nil
opts.awsSTSClient = nil

backend, err := newAWSKMSKeystore(ctx, &config.AWSKMS, &opts)
backend, err := newAWSKMSKeystore(ctx, config.AWSKMS, &opts)
require.NoError(t, err)
name := "aws_kms"
if multiRegion {
Expand Down Expand Up @@ -680,15 +680,15 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack {

// Always test with fake AWS client.
fakeAWSKMSConfig := servicecfg.KeystoreConfig{
AWSKMS: servicecfg.AWSKMSConfig{
AWSKMS: &servicecfg.AWSKMSConfig{
AWSAccount: "123456789012",
AWSRegion: "us-west-2",
MultiRegion: struct{ Enabled bool }{
Enabled: multiRegion,
},
},
}
fakeAWSKMSBackend, err := newAWSKMSKeystore(ctx, &fakeAWSKMSConfig.AWSKMS, &baseOpts)
fakeAWSKMSBackend, err := newAWSKMSKeystore(ctx, fakeAWSKMSConfig.AWSKMS, &baseOpts)
require.NoError(t, err)
name := "fake_aws_kms"
if multiRegion {
Expand Down
Loading

0 comments on commit 00ec4da

Please sign in to comment.