Skip to content

Commit

Permalink
chore: update IAM join method to use aws-sdk-go-v2 (#47044)
Browse files Browse the repository at this point in the history
* chore: update IAM join method to use aws-sdk-go-v2

* use constant

Co-authored-by: Zac Bergquist <[email protected]>

* fix lint

---------

Co-authored-by: Zac Bergquist <[email protected]>
  • Loading branch information
nklaassen and zmb3 authored Oct 2, 2024
1 parent 90ca3eb commit 5924c97
Show file tree
Hide file tree
Showing 8 changed files with 342 additions and 148 deletions.
43 changes: 22 additions & 21 deletions integration/ec2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ import (

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -61,16 +60,16 @@ func newSilentLogger() utils.Logger {
return logger
}

func newNodeConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinMethod types.JoinMethod) *servicecfg.Config {
func newNodeConfig(t *testing.T, tokenName string, joinMethod types.JoinMethod) *servicecfg.Config {
config := servicecfg.MakeDefaultConfig()
config.Version = defaults.TeleportConfigVersionV3
config.SetToken(tokenName)
config.JoinMethod = joinMethod
config.SSH.Enabled = true
config.SSH.Addr.Addr = helpers.NewListener(t, service.ListenerNodeSSH, &config.FileDescriptors)
config.Auth.Enabled = false
config.Proxy.Enabled = false
config.DataDir = t.TempDir()
config.SetAuthServerAddress(authAddr)
config.Log = newSilentLogger()
config.CircuitBreakerConfig = breaker.NoopBreakerConfig()
config.InstanceMetadataClient = cloudimds.NewDisabledIMDSClient()
Expand All @@ -79,7 +78,7 @@ func newNodeConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinM

func newProxyConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinMethod types.JoinMethod) *servicecfg.Config {
config := servicecfg.MakeDefaultConfig()
config.Version = defaults.TeleportConfigVersionV2
config.Version = defaults.TeleportConfigVersionV3
config.SetToken(tokenName)
config.JoinMethod = joinMethod
config.SSH.Enabled = false
Expand Down Expand Up @@ -109,6 +108,7 @@ func newAuthConfig(t *testing.T, clock clockwork.Clock) *servicecfg.Config {
}

config := servicecfg.MakeDefaultConfig()
config.Version = defaults.TeleportConfigVersionV3
config.DataDir = t.TempDir()
config.Auth.ListenAddr.Addr = helpers.NewListener(t, service.ListenerAuth, &config.FileDescriptors)
config.Auth.ClusterName, err = services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{
Expand Down Expand Up @@ -140,13 +140,11 @@ func getIID(ctx context.Context, t *testing.T) imds.InstanceIdentityDocument {
return output.InstanceIdentityDocument
}

func getCallerIdentity(t *testing.T) *sts.GetCallerIdentityOutput {
sess, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
})
func getCallerIdentity(ctx context.Context, t *testing.T) *sts.GetCallerIdentityOutput {
cfg, err := config.LoadDefaultConfig(ctx)
require.NoError(t, err)
stsService := sts.New(sess)
output, err := stsService.GetCallerIdentity(nil /*input*/)
stsClient := sts.NewFromConfig(cfg)
output, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
require.NoError(t, err)
return output
}
Expand Down Expand Up @@ -201,7 +199,8 @@ func TestEC2NodeJoin(t *testing.T) {
require.Empty(t, nodes)

// create and start the node
nodeConfig := newNodeConfig(t, authConfig.Auth.ListenAddr, tokenName, types.JoinMethodEC2)
nodeConfig := newNodeConfig(t, tokenName, types.JoinMethodEC2)
nodeConfig.SetAuthServerAddress(authConfig.Auth.ListenAddr)
nodeSvc, err := service.NewTeleport(nodeConfig)
require.NoError(t, err)
require.NoError(t, nodeSvc.Start())
Expand All @@ -214,7 +213,7 @@ func TestEC2NodeJoin(t *testing.T) {
require.Eventually(t, func() bool {
nodes, _ := authServer.GetNodes(ctx, apidefaults.Namespace)
return len(nodes) > 0
}, time.Minute, time.Second, "waiting for node to join cluster")
}, 10*time.Second, 50*time.Millisecond, "waiting for node to join cluster")
}

// TestIAMNodeJoin is an integration test which asserts that the IAM method for
Expand All @@ -225,6 +224,7 @@ func TestIAMNodeJoin(t *testing.T) {
if os.Getenv("TELEPORT_TEST_EC2") == "" {
t.Skipf("Skipping TestIAMNodeJoin because TELEPORT_TEST_EC2 is not set")
}
ctx := context.Background()

// create and start the auth server
authConfig := newAuthConfig(t, nil /*clock*/)
Expand All @@ -236,7 +236,7 @@ func TestIAMNodeJoin(t *testing.T) {
authServer := authSvc.GetAuthServer()

// fetch the caller identity to find the AWS account and create the token
id := getCallerIdentity(t)
id := getCallerIdentity(ctx, t)

tokenName := "test_token"
token, err := types.NewProvisionTokenFromSpec(
Expand All @@ -253,7 +253,7 @@ func TestIAMNodeJoin(t *testing.T) {
})
require.NoError(t, err)

err = authServer.UpsertToken(context.Background(), token)
err = authServer.UpsertToken(ctx, token)
require.NoError(t, err)

// sanity check there are no proxies to start with
Expand All @@ -274,31 +274,32 @@ func TestIAMNodeJoin(t *testing.T) {
proxies, err := authServer.GetProxies()
assert.NoError(t, err)
assert.NotEmpty(t, proxies)
}, time.Minute, time.Second, "waiting for proxy to join cluster")
}, 10*time.Second, 50*time.Millisecond, "waiting for proxy to join cluster")
// InsecureDevMode needed for node to trust proxy
wasInsecureDevMode := lib.IsInsecureDevMode()
t.Cleanup(func() { lib.SetInsecureDevMode(wasInsecureDevMode) })
lib.SetInsecureDevMode(true)

// sanity check there are no nodes to start with
nodes, err := authServer.GetNodes(context.Background(), apidefaults.Namespace)
nodes, err := authServer.GetNodes(ctx, apidefaults.Namespace)
require.NoError(t, err)
require.Empty(t, nodes)

// create and start a node, with use the IAM method to join in IoT mode by
// create and start a node, will use the IAM method to join in IoT mode by
// connecting to the proxy
nodeConfig := newNodeConfig(t, proxyConfig.Proxy.WebAddr, tokenName, types.JoinMethodIAM)
nodeConfig := newNodeConfig(t, tokenName, types.JoinMethodIAM)
nodeConfig.ProxyServer = proxyConfig.Proxy.WebAddr
nodeSvc, err := service.NewTeleport(nodeConfig)
require.NoError(t, err)
require.NoError(t, nodeSvc.Start())
t.Cleanup(func() { require.NoError(t, nodeSvc.Close()) })

// the node should eventually join the cluster and heartbeat
require.EventuallyWithT(t, func(t *assert.CollectT) {
nodes, err := authServer.GetNodes(context.Background(), apidefaults.Namespace)
nodes, err := authServer.GetNodes(ctx, apidefaults.Namespace)
assert.NoError(t, err)
assert.NotEmpty(t, nodes)
}, time.Minute, time.Second, "waiting for node to join cluster")
}, 10*time.Second, 50*time.Millisecond, "waiting for node to join cluster")
}

type mockIMDSClient struct {
Expand Down
6 changes: 4 additions & 2 deletions integration/teleterm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,8 @@ func testWaitForConnectMyComputerNodeJoin(t *testing.T, pack *dbhelpers.Database
}()

// Start the new node.
nodeConfig := newNodeConfig(t, pack.Root.Cluster.Config.Auth.ListenAddr, "token", types.JoinMethodToken)
nodeConfig := newNodeConfig(t, "token", types.JoinMethodToken)
nodeConfig.SetAuthServerAddress(pack.Root.Cluster.Config.Auth.ListenAddr)
nodeConfig.DataDir = filepath.Join(agentsDir, profileName, "data")
nodeConfig.Log = libutils.NewLoggerForTests()
nodeSvc, err := service.NewTeleport(nodeConfig)
Expand Down Expand Up @@ -1031,7 +1032,8 @@ func testDeleteConnectMyComputerNode(t *testing.T, pack *dbhelpers.DatabasePack)
require.NoError(t, err)

// Start the new node.
nodeConfig := newNodeConfig(t, pack.Root.Cluster.Config.Auth.ListenAddr, "token", types.JoinMethodToken)
nodeConfig := newNodeConfig(t, "token", types.JoinMethodToken)
nodeConfig.SetAuthServerAddress(pack.Root.Cluster.Config.Auth.ListenAddr)
nodeConfig.DataDir = filepath.Join(agentsDir, profileName, "data")
nodeConfig.Log = libutils.NewLoggerForTests()
nodeSvc, err := service.NewTeleport(nodeConfig)
Expand Down
25 changes: 11 additions & 14 deletions lib/auth/join/iam/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ package iam
import "sync"

var (
// ValidSTSEndpoints holds a sorted list of all known valid public endpoints for
// the AWS STS service. You can generate this list by running
// $ go run github.com/nklaassen/sts-endpoints@latest --go-list
// Update aws-sdk-go in that package to learn about new endpoints.
// ValidSTSEndpoints returns a sorted list of all known valid public endpoints for
// the AWS STS service.
//
// TODO(nklaassen): find a better way to validate STS endpoints or generate
// this list and get notified when it needs to be updated. The original
// solution was https://github.com/nklaassen/sts-endpoints which is based on
// aws-sdk-go v1 which no longer gets updates for new regions.
ValidSTSEndpoints = sync.OnceValue(func() []string {
return []string{
"sts-fips.us-east-1.amazonaws.com",
Expand Down Expand Up @@ -69,18 +72,10 @@ var (
}
})

GlobalSTSEndpoints = sync.OnceValue(func() []string {
return []string{
"sts.amazonaws.com",
// This is not a real endpoint, but the SDK will select it if
// AWS_USE_FIPS_ENDPOINT is set and a region is not.
"sts-fips.aws-global.amazonaws.com",
}
})

// FIPSSTSEndpoints returns the set of known valid FIPS AWS STS endpoints.
FIPSSTSEndpoints = sync.OnceValue(func() []string {
return []string{
"sts-fips.us-east-1.amazonaws.com",
fipsSTSEndpointUSEast1,
"sts-fips.us-east-2.amazonaws.com",
"sts-fips.us-west-1.amazonaws.com",
"sts-fips.us-west-2.amazonaws.com",
Expand All @@ -89,3 +84,5 @@ var (
}
})
)

const fipsSTSEndpointUSEast1 = "sts-fips.us-east-1.amazonaws.com"
Loading

0 comments on commit 5924c97

Please sign in to comment.