diff --git a/go.mod b/go.mod index fa63fe6eb8695..3778726f94195 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43 github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.1 github.com/aws/aws-sdk-go-v2/service/athena v1.49.0 + github.com/aws/aws-sdk-go-v2/service/dax v1.23.7 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.0 github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.8 github.com/aws/aws-sdk-go-v2/service/ec2 v1.195.0 diff --git a/go.sum b/go.sum index 6a73eda617b84..aadfbf4b65238 100644 --- a/go.sum +++ b/go.sum @@ -883,6 +883,8 @@ github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.1 h1:8EwNbY+A/ github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.1/go.mod h1:2mMP2R86zLPAUz0TpJdsKW8XawHgs9Nk97fYJomO3o8= github.com/aws/aws-sdk-go-v2/service/athena v1.49.0 h1:D+iatX9gV6gCuNd6BnUkfwfZJw/cXlEk+LwwDdSMdtw= github.com/aws/aws-sdk-go-v2/service/athena v1.49.0/go.mod h1:27ljwDsnZvfrZKsLzWD4WFjI4OZutEFIjvVtYfj9gHc= +github.com/aws/aws-sdk-go-v2/service/dax v1.23.7 h1:hZg1sHhWXGZShzHGpwcaOT8HZfx26kkbRDNZgZda4xI= +github.com/aws/aws-sdk-go-v2/service/dax v1.23.7/go.mod h1:fYBjETTq8hZfirBEgXM1xIMy+tvCGYZTeWpjeKKp0bU= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.0 h1:isKhHsjpQR3CypQJ4G1g8QWx7zNpiC/xKw1zjgJYVno= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.0/go.mod h1:xDvUyIkwBwNtVZJdHEwAuhFly3mezwdEWkbJ5oNYwIw= github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.8 h1:ntqHwZb+ZyVz0CFYUG0sQ02KMMJh+iXeV3bXoba+s4A= diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index 1b7cf7ecfecec..d16422a0120fe 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -812,6 +812,8 @@ github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.1 h1:8EwNbY+A/ github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.1/go.mod h1:2mMP2R86zLPAUz0TpJdsKW8XawHgs9Nk97fYJomO3o8= github.com/aws/aws-sdk-go-v2/service/athena v1.49.0 h1:D+iatX9gV6gCuNd6BnUkfwfZJw/cXlEk+LwwDdSMdtw= github.com/aws/aws-sdk-go-v2/service/athena v1.49.0/go.mod h1:27ljwDsnZvfrZKsLzWD4WFjI4OZutEFIjvVtYfj9gHc= +github.com/aws/aws-sdk-go-v2/service/dax v1.23.7 h1:hZg1sHhWXGZShzHGpwcaOT8HZfx26kkbRDNZgZda4xI= +github.com/aws/aws-sdk-go-v2/service/dax v1.23.7/go.mod h1:fYBjETTq8hZfirBEgXM1xIMy+tvCGYZTeWpjeKKp0bU= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.0 h1:isKhHsjpQR3CypQJ4G1g8QWx7zNpiC/xKw1zjgJYVno= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.0/go.mod h1:xDvUyIkwBwNtVZJdHEwAuhFly3mezwdEWkbJ5oNYwIw= github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.8 h1:ntqHwZb+ZyVz0CFYUG0sQ02KMMJh+iXeV3bXoba+s4A= diff --git a/lib/srv/db/dynamodb/engine.go b/lib/srv/db/dynamodb/engine.go index 940b0b315a724..d877741dc628b 100644 --- a/lib/srv/db/dynamodb/engine.go +++ b/lib/srv/db/dynamodb/engine.go @@ -30,10 +30,10 @@ import ( "strconv" "strings" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/service/dax" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodbstreams" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dax" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams" "github.com/gravitational/trace" "github.com/prometheus/client_golang/prometheus" @@ -43,6 +43,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/common/role" "github.com/gravitational/teleport/lib/utils" @@ -54,6 +55,7 @@ func NewEngine(ec common.EngineConfig) common.Engine { return &Engine{ EngineConfig: ec, RoundTrippers: make(map[string]http.RoundTripper), + UseFIPS: modules.GetModules().IsBoringBinary(), } } @@ -71,6 +73,8 @@ type Engine struct { RoundTrippers map[string]http.RoundTripper // CredentialsGetter is used to obtain STS credentials. CredentialsGetter libaws.CredentialsGetter + // UseFIPS will ensure FIPS endpoint resolution. + UseFIPS bool } var _ common.Engine = (*Engine)(nil) @@ -194,7 +198,7 @@ func (e *Engine) process(ctx context.Context, req *http.Request, signer *libaws. // emit an audit event regardless of failure, but using the resolved endpoint. var responseStatusCode uint32 defer func() { - e.emitAuditEvent(req, re.URL, responseStatusCode, err) + e.emitAuditEvent(req, re.URL.String(), responseStatusCode, err) }() // try to read, close, and replace the incoming request body. @@ -319,8 +323,8 @@ func (e *Engine) checkAccess(ctx context.Context, sessionCtx *common.Session) er } // getRoundTripper makes an HTTP round tripper with TLS config based on the given URL. -func (e *Engine) getRoundTripper(ctx context.Context, URL string) (http.RoundTripper, error) { - if rt, ok := e.RoundTrippers[URL]; ok { +func (e *Engine) getRoundTripper(ctx context.Context, u *url.URL) (http.RoundTripper, error) { + if rt, ok := e.RoundTrippers[u.String()]; ok { return rt, nil } tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx.GetExpiry(), e.sessionCtx.Database, e.sessionCtx.DatabaseUser) @@ -329,55 +333,136 @@ func (e *Engine) getRoundTripper(ctx context.Context, URL string) (http.RoundTri } // We need to set the ServerName here because the AWS endpoint service prefix is not known in advance, // and the TLS config we got does not set it. - host, err := getURLHostname(URL) - if err != nil { - return nil, trace.Wrap(err) - } - tlsConfig.ServerName = host + tlsConfig.ServerName = u.Hostname() out, err := defaults.Transport() if err != nil { return nil, trace.Wrap(err) } out.TLSClientConfig = tlsConfig - e.RoundTrippers[URL] = out + e.RoundTrippers[u.String()] = out return out, nil } -// resolveEndpoint returns a resolved endpoint for either the configured URI or the AWS target service and region. -func (e *Engine) resolveEndpoint(req *http.Request) (*endpoints.ResolvedEndpoint, error) { - endpointID, err := extractEndpointID(req) +type endpoint struct { + URL *url.URL + SigningName string + SigningRegion string +} + +// resolveEndpoint returns a resolved endpoint for either the configured URI or +// the AWS target service and region. +// For a target operation, the appropriate AWS service resolver is used. +// Targets look like one of DynamoDB_$version.$operation, +// DynamoDBStreams_$version.$operation, or AmazonDAX$version.$operation. +// For example: DynamoDBStreams_20120810.ListStreams +func (e *Engine) resolveEndpoint(req *http.Request) (*endpoint, error) { + target, err := getTargetHeader(req) if err != nil { return nil, trace.Wrap(err) } - opts := func(opts *endpoints.Options) { - opts.ResolveUnknownService = true + + awsMeta := e.sessionCtx.Database.GetAWS() + + var re *endpoint + switch target := strings.ToLower(target); { + case strings.HasPrefix(target, "dynamodbstreams"): + re, err = resolveDynamoDBStreamsEndpoint(req.Context(), awsMeta.Region, e.UseFIPS) + case strings.HasPrefix(target, "dynamodb"): + re, err = resolveDynamoDBEndpoint(req.Context(), awsMeta.Region, awsMeta.AccountID, e.UseFIPS) + case strings.HasPrefix(target, "amazondax"): + re, err = resolveDaxEndpoint(req.Context(), awsMeta.Region, e.UseFIPS) + default: + return nil, trace.BadParameter("DynamoDB API target %q is not recognized", target) } - re, err := endpoints.DefaultResolver().EndpointFor(endpointID, e.sessionCtx.Database.GetAWS().Region, opts) if err != nil { return nil, trace.Wrap(err) } uri := e.sessionCtx.Database.GetURI() - if uri != "" && uri != apiaws.DynamoDBURIForRegion(e.sessionCtx.Database.GetAWS().Region) { + if uri != "" && uri != apiaws.DynamoDBURIForRegion(awsMeta.Region) { + // Add a temporary schema to make a valid URL for url.Parse. + if !strings.Contains(uri, "://") { + uri = "schema://" + uri + } + u, err := url.Parse(uri) + if err != nil { + return nil, trace.Wrap(err) + } // override the resolved endpoint URL with the user-configured URI. - re.URL = uri + re.URL = u } - if !strings.Contains(re.URL, "://") { - re.URL = "https://" + re.URL + // Force HTTPS + re.URL.Scheme = "https" + return re, nil +} + +func resolveDynamoDBStreamsEndpoint(ctx context.Context, region string, useFIPS bool) (*endpoint, error) { + params := dynamodbstreams.EndpointParameters{ + Region: aws.String(region), + UseFIPS: aws.Bool(useFIPS), } - return &re, nil + ep, err := dynamodbstreams.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) + if err != nil { + return nil, trace.Wrap(err) + } + return &endpoint{ + URL: &ep.URI, + SigningRegion: region, + // DynamoDB Streams uses the same signing name as DynamoDB. + SigningName: "dynamodb", + }, nil } -// rewriteRequest clones a request, modifies the clone to rewrite its URL, and returns the modified request clone. -func rewriteRequest(ctx context.Context, r *http.Request, re *endpoints.ResolvedEndpoint, body []byte) (*http.Request, error) { - resolvedURL, err := url.Parse(re.URL) +func resolveDynamoDBEndpoint(ctx context.Context, region, accountID string, useFIPS bool) (*endpoint, error) { + params := dynamodb.EndpointParameters{ + Region: aws.String(region), + // Preferred means if we have an account ID available, then use an + // account ID based endpoint. + // We should always have the account ID available anyway. + // If we didn't then it would just resolve the regional endpoint like + // dynamodb..amazonaws.com. + // AWS documents that account-based routing provides better request + // performance for some services. + // See: https://docs.aws.amazon.com/sdkref/latest/guide/feature-account-endpoints.html + AccountIdEndpointMode: aws.String(aws.AccountIDEndpointModePreferred), + UseFIPS: aws.Bool(useFIPS), + } + if accountID != "" { + params.AccountId = aws.String(accountID) + } + ep, err := dynamodb.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) if err != nil { return nil, trace.Wrap(err) } + return &endpoint{ + URL: &ep.URI, + SigningRegion: region, + SigningName: "dynamodb", + }, nil +} + +func resolveDaxEndpoint(ctx context.Context, region string, useFIPS bool) (*endpoint, error) { + params := dax.EndpointParameters{ + Region: aws.String(region), + UseFIPS: aws.Bool(useFIPS), + } + ep, err := dax.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) + if err != nil { + return nil, trace.Wrap(err) + } + return &endpoint{ + URL: &ep.URI, + SigningRegion: region, + SigningName: "dax", + }, nil +} + +// rewriteRequest clones a request, modifies the clone to rewrite its URL, and returns the modified request clone. +func rewriteRequest(ctx context.Context, r *http.Request, re *endpoint, body []byte) (*http.Request, error) { reqCopy := r.Clone(ctx) // set url and host header to match the resolved endpoint. - reqCopy.URL = resolvedURL - reqCopy.Host = resolvedURL.Host + reqCopy.URL = re.URL + reqCopy.Host = re.URL.Host if body == nil { // no body is fine, skip copying it. return reqCopy, nil @@ -388,42 +473,13 @@ func rewriteRequest(ctx context.Context, r *http.Request, re *endpoints.Resolved return reqCopy, nil } -// extractEndpointID extracts the AWS endpoint ID from the request header X-Amz-Target. -func extractEndpointID(req *http.Request) (string, error) { +// getTargetHeader gets the X-Amz-Target header or returns an error if it is not +// present, as we rely on this header for endpoint resolution. +// See X-Amz-Target: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Programming.LowLevelAPI.html +func getTargetHeader(req *http.Request) (string, error) { target := req.Header.Get(libaws.AmzTargetHeader) if target == "" { return "", trace.BadParameter("missing %q header in http request", libaws.AmzTargetHeader) } - endpointID, err := endpointIDForTarget(target) - return endpointID, trace.Wrap(err) -} - -// endpointIDForTarget converts a target operation into the appropriate the AWS endpoint ID. -// Target looks like one of DynamoDB_$version.$operation, DynamoDBStreams_$version.$operation, AmazonDAX$version.$operation, -// for example: DynamoDBStreams_20120810.ListStreams -// See X-Amz-Target: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Programming.LowLevelAPI.html -func endpointIDForTarget(target string) (string, error) { - t := strings.ToLower(target) - switch { - case strings.HasPrefix(t, "dynamodbstreams"): - return dynamodbstreams.EndpointsID, nil - case strings.HasPrefix(t, "dynamodb"): - return dynamodb.EndpointsID, nil - case strings.HasPrefix(t, "amazondax"): - return dax.EndpointsID, nil - default: - return "", trace.BadParameter("DynamoDB API target %q is not recognized", target) - } -} - -// getURLHostname parses a URL to extract its hostname. -func getURLHostname(uri string) (string, error) { - if !strings.Contains(uri, "://") { - uri = "schema://" + uri - } - parsed, err := url.Parse(uri) - if err != nil { - return "", trace.Wrap(err) - } - return parsed.Hostname(), nil + return target, nil } diff --git a/lib/srv/db/dynamodb/engine_test.go b/lib/srv/db/dynamodb/engine_test.go index faeaf0536ed1d..57b1fe170a409 100644 --- a/lib/srv/db/dynamodb/engine_test.go +++ b/lib/srv/db/dynamodb/engine_test.go @@ -38,7 +38,8 @@ func TestResolveEndpoint(t *testing.T) { desc string target string // from X-Amz-Target in requests region string - wantEndpointID string + useFIPS bool + unsetAccountID bool wantSigningName string wantURL string wantErrMsg string @@ -47,15 +48,21 @@ func TestResolveEndpoint(t *testing.T) { desc: "dynamodb target in us west", target: "DynamoDB_20120810.Scan", region: "us-west-1", - wantEndpointID: "dynamodb", + wantSigningName: "dynamodb", + wantURL: "https://123456789012.ddb.us-west-1.amazonaws.com", + }, + { + desc: "dynamodb target in us west with no account id", + target: "DynamoDB_20120810.Scan", + region: "us-west-1", wantSigningName: "dynamodb", wantURL: "https://dynamodb.us-west-1.amazonaws.com", + unsetAccountID: true, }, { desc: "dynamodb target in china", target: "DynamoDB_20120810.Scan", region: "cn-north-1", - wantEndpointID: "dynamodb", wantSigningName: "dynamodb", wantURL: "https://dynamodb.cn-north-1.amazonaws.com.cn", }, @@ -63,7 +70,6 @@ func TestResolveEndpoint(t *testing.T) { desc: "dynamodb streams target in us west", target: "DynamoDBStreams_20120810.ListStreams", region: "us-west-1", - wantEndpointID: "streams.dynamodb", wantSigningName: "dynamodb", wantURL: "https://streams.dynamodb.us-west-1.amazonaws.com", }, @@ -71,7 +77,6 @@ func TestResolveEndpoint(t *testing.T) { desc: "dynamodb streams target in china", target: "DynamoDBStreams_20120810.ListStreams", region: "cn-north-1", - wantEndpointID: "streams.dynamodb", wantSigningName: "dynamodb", wantURL: "https://streams.dynamodb.cn-north-1.amazonaws.com.cn", }, @@ -79,7 +84,6 @@ func TestResolveEndpoint(t *testing.T) { desc: "dax target in us west", target: "AmazonDAXV3.ListTags", region: "us-west-1", - wantEndpointID: "dax", wantSigningName: "dax", wantURL: "https://dax.us-west-1.amazonaws.com", }, @@ -87,10 +91,33 @@ func TestResolveEndpoint(t *testing.T) { desc: "dax target in china", target: "AmazonDAXV3.ListTags", region: "cn-north-1", - wantEndpointID: "dax", wantSigningName: "dax", wantURL: "https://dax.cn-north-1.amazonaws.com.cn", }, + { + desc: "dynamodb target in us west with FIPS required", + target: "DynamoDB_20120810.Scan", + region: "us-west-1", + wantSigningName: "dynamodb", + wantURL: "https://dynamodb-fips.us-west-1.amazonaws.com", + useFIPS: true, + }, + { + desc: "dynamodb streams target in us west with FIPS required", + target: "DynamoDBStreams_20120810.ListStreams", + region: "us-west-1", + wantSigningName: "dynamodb", + wantURL: "https://streams.dynamodb-fips.us-west-1.amazonaws.com", + useFIPS: true, + }, + { + desc: "dax target in us west with FIPS required", + target: "AmazonDAXV3.ListTags", + region: "us-west-1", + wantSigningName: "dax", + wantURL: "https://dax-fips.us-west-1.amazonaws.com", + useFIPS: true, + }, { desc: "unrecognizable target", target: "DDB.Scan", @@ -105,25 +132,19 @@ func TestResolveEndpoint(t *testing.T) { req := &http.Request{Header: make(http.Header)} req.Header.Set(libaws.AmzTargetHeader, tt.target) - // check that the correct endpoint ID is extracted. - endpointID, err := extractEndpointID(req) - if tt.wantErrMsg != "" { - require.Error(t, err) - require.ErrorContains(t, err, tt.wantErrMsg) - return - } - require.Equal(t, tt.wantEndpointID, endpointID) - // check that the engine resolves the correct URL. db := &types.DatabaseV3{ Spec: types.DatabaseSpecV3{ URI: apiaws.DynamoDBURIForRegion(tt.region), AWS: types.AWS{ Region: tt.region, - AccountID: "12345", + AccountID: "123456789012", }, }, } + if tt.unsetAccountID { + db.Spec.AWS.AccountID = "" + } engine := &Engine{ EngineConfig: common.EngineConfig{ Log: slog.Default(), @@ -131,18 +152,26 @@ func TestResolveEndpoint(t *testing.T) { sessionCtx: &common.Session{ Database: db, }, + UseFIPS: tt.useFIPS, } re, err := engine.resolveEndpoint(req) + if tt.wantErrMsg != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErrMsg) + return + } require.NoError(t, err) - require.Equal(t, tt.wantURL, re.URL) + require.Equal(t, tt.wantURL, re.URL.String()) require.Equal(t, tt.wantSigningName, re.SigningName) + require.Equal(t, tt.region, re.SigningRegion) // now use a custom URI and check that it overrides the resolved URL. db.Spec.URI = "foo.com" re, err = engine.resolveEndpoint(req) require.NoError(t, err) - require.Equal(t, "https://foo.com", re.URL) + require.Equal(t, "https://foo.com", re.URL.String()) require.Equal(t, tt.wantSigningName, re.SigningName) + require.Equal(t, tt.region, re.SigningRegion) }) } } diff --git a/lib/srv/db/dynamodb/test.go b/lib/srv/db/dynamodb/test.go index 462dc743782ab..cf7661dc044b4 100644 --- a/lib/srv/db/dynamodb/test.go +++ b/lib/srv/db/dynamodb/test.go @@ -27,10 +27,9 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -41,7 +40,10 @@ import ( ) // Client alias for easier use. -type Client = dynamodb.DynamoDB +type Client struct { + *dynamodb.Client + HTTPClient *http.Client +} // ClientOptionsParams is a struct for client configuration options. type ClientOptionsParams struct { @@ -54,19 +56,19 @@ type ClientOptions func(*ClientOptionsParams) // MakeTestClient returns DynamoDB client connection according to the provided // parameters. func MakeTestClient(_ context.Context, config common.TestClientConfig, opts ...ClientOptions) (*Client, error) { - provider := session.Must(session.NewSession(&aws.Config{ - Credentials: credentials.NewCredentials(&credentials.StaticProvider{Value: credentials.Value{ - AccessKeyID: "fakeClientKeyID", - SecretAccessKey: "fakeClientSecret", - }}), - Region: aws.String("local"), - })) - dynamoClient := dynamodb.New(provider, &aws.Config{ - Endpoint: aws.String("http://" + config.Address), - MaxRetries: aws.Int(0), // disable automatic retries in tests - HTTPClient: &http.Client{Timeout: 5 * time.Second}, + httpClt := &http.Client{Timeout: 5 * time.Second} + dynamoClient := dynamodb.New(dynamodb.Options{ + Region: "local", + Credentials: credentials.NewStaticCredentialsProvider( + "fakeClientKeyID", + "fakeClientSecret", + "", + ), + BaseEndpoint: aws.String("http://" + config.Address), + RetryMaxAttempts: 0, // disable automatic retries in tests + HTTPClient: httpClt, }) - return dynamoClient, nil + return &Client{Client: dynamoClient, HTTPClient: httpClt}, nil } // TestServerOption allows setting test server options. @@ -107,7 +109,7 @@ func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (*T mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - err := awsutils.VerifyAWSSignature(r, credentials.NewStaticCredentials("AKIDl", "SECRET", "SESSION")) + err := awsutils.VerifyAWSSignatureV2(r, credentials.NewStaticCredentialsProvider("AKIDl", "SECRET", "SESSION")) if err != nil { code := trace.ErrorToCode(err) body, _ := json.Marshal(jsonErr{ diff --git a/lib/srv/db/dynamodb_test.go b/lib/srv/db/dynamodb_test.go index 0ed6355983381..f7a2b259e110b 100644 --- a/lib/srv/db/dynamodb_test.go +++ b/lib/srv/db/dynamodb_test.go @@ -25,9 +25,8 @@ import ( "net/http" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - awsdynamodb "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/credentials" + awsdynamodb "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -36,6 +35,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/dynamodb" awsutils "github.com/gravitational/teleport/lib/utils/aws" + "github.com/gravitational/teleport/lib/utils/aws/migration" ) func registerTestDynamoDBEngine() { @@ -50,7 +50,9 @@ func newTestDynamoDBEngine(ec common.EngineConfig) common.Engine { RoundTrippers: make(map[string]http.RoundTripper), // inject mock AWS credentials. CredentialsGetter: awsutils.NewStaticCredentialsGetter( - credentials.NewStaticCredentials("AKIDl", "SECRET", "SESSION"), + migration.NewCredentialsAdapter( + credentials.NewStaticCredentialsProvider("AKIDl", "SECRET", "SESSION"), + ), ), } } @@ -127,14 +129,14 @@ func TestAccessDynamoDB(t *testing.T) { require.NoError(t, err) // Execute a dynamodb query. - out, err := clt.ListTables(&awsdynamodb.ListTablesInput{}) + out, err := clt.ListTables(ctx, &awsdynamodb.ListTablesInput{}) if test.wantErrMsg != "" { require.Error(t, err) require.ErrorContains(t, err, test.wantErrMsg) return } require.NoError(t, err) - require.ElementsMatch(t, mockTables, aws.StringValueSlice(out.TableNames)) + require.ElementsMatch(t, mockTables, out.TableNames) }) } } @@ -159,7 +161,7 @@ func TestAuditDynamoDB(t *testing.T) { require.NoError(t, err) // Execute a dynamodb query. - _, err = clt.ListTables(&awsdynamodb.ListTablesInput{}) + _, err = clt.ListTables(ctx, &awsdynamodb.ListTablesInput{}) require.Error(t, err) require.ErrorContains(t, err, "access to db denied") requireEvent(t, testCtx, libevents.DatabaseSessionStartFailureCode) @@ -176,21 +178,21 @@ func TestAuditDynamoDB(t *testing.T) { require.NoError(t, err) t.Run("session starts and emits a request event", func(t *testing.T) { - _, err := clt.ListTables(&awsdynamodb.ListTablesInput{}) + _, err := clt.ListTables(ctx, &awsdynamodb.ListTablesInput{}) require.NoError(t, err) requireEvent(t, testCtx, libevents.DatabaseSessionStartCode) requireEvent(t, testCtx, libevents.DynamoDBRequestCode) }) t.Run("session ends when client closes the connection", func(t *testing.T) { - clt.Config.HTTPClient.CloseIdleConnections() + clt.HTTPClient.CloseIdleConnections() requireEvent(t, testCtx, libevents.DatabaseSessionEndCode) }) t.Run("session ends when local proxy closes the connection", func(t *testing.T) { // closing local proxy and canceling the context used to start it should trigger session end event. // without this cancel, the session will not end until the smaller of client_idle_timeout or the testCtx closes. - _, err := clt.ListTables(&awsdynamodb.ListTablesInput{}) + _, err := clt.ListTables(ctx, &awsdynamodb.ListTablesInput{}) require.NoError(t, err) requireEvent(t, testCtx, libevents.DatabaseSessionStartCode) requireEvent(t, testCtx, libevents.DynamoDBRequestCode)