From 6646d68ce011a87d2ebf5c34fd9725412cefd5e4 Mon Sep 17 00:00:00 2001 From: AJ Date: Sun, 29 Sep 2024 15:18:10 +0100 Subject: [PATCH] migrate to AWS SDKv2, updating only singnatures and making sure tests are passing. #1432 --- go.mod | 38 ++++- go.sum | 72 +++++++++ modules/aws/account.go | 13 +- modules/aws/acm.go | 12 +- modules/aws/ami.go | 34 +++-- modules/aws/asg.go | 29 ++-- modules/aws/asg_test.go | 91 ++++++----- modules/aws/auth.go | 36 +++-- modules/aws/cloudwatch.go | 16 +- modules/aws/dynamodb.go | 31 ++-- modules/aws/ebs.go | 12 +- modules/aws/ec2-syslog.go | 76 +--------- modules/aws/ec2.go | 158 +++++++++++--------- modules/aws/ec2_test.go | 21 ++- modules/aws/ecr.go | 56 +++---- modules/aws/ecr_test.go | 8 +- modules/aws/ecs.go | 71 ++++----- modules/aws/ecs_test.go | 20 +-- modules/aws/iam.go | 30 ++-- modules/aws/keypair.go | 14 +- modules/aws/keypair_test.go | 10 +- modules/aws/kms.go | 14 +- modules/aws/lambda.go | 24 +-- modules/aws/rds.go | 70 ++++----- modules/aws/region.go | 28 ++-- modules/aws/route53.go | 38 ++--- modules/aws/route53_test.go | 37 ++--- modules/aws/s3.go | 106 +++++++------ modules/aws/s3_test.go | 40 ++--- modules/aws/secretsmanager.go | 26 ++-- modules/aws/sns.go | 18 ++- modules/aws/sns_test.go | 7 +- modules/aws/sqs.go | 52 +++---- modules/aws/sqs_test.go | 7 +- modules/aws/ssm.go | 92 +++++++----- modules/aws/vpc.go | 118 ++++++++------- modules/aws/vpc_test.go | 76 +++++----- test/packer_basic_example_test.go | 26 ++-- test/terraform_aws_dynamodb_example_test.go | 24 +-- 39 files changed, 899 insertions(+), 752 deletions(-) diff --git a/go.mod b/go.mod index 202a8abc4..7f8665f26 100644 --- a/go.mod +++ b/go.mod @@ -48,7 +48,27 @@ require ( require ( cloud.google.com/go/cloudbuild v1.9.0 - github.com/gogo/protobuf v1.3.2 + github.com/aws/aws-sdk-go-v2 v1.31.0 + github.com/aws/aws-sdk-go-v2/config v1.27.39 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.25 + github.com/aws/aws-sdk-go-v2/service/acm v1.29.3 + github.com/aws/aws-sdk-go-v2/service/autoscaling v1.44.3 + github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.40.3 + github.com/aws/aws-sdk-go-v2/service/dynamodb v1.35.3 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.179.2 + github.com/aws/aws-sdk-go-v2/service/ecr v1.35.3 + github.com/aws/aws-sdk-go-v2/service/ecs v1.46.3 + github.com/aws/aws-sdk-go-v2/service/iam v1.36.3 + github.com/aws/aws-sdk-go-v2/service/kms v1.36.3 + github.com/aws/aws-sdk-go-v2/service/lambda v1.62.1 + github.com/aws/aws-sdk-go-v2/service/rds v1.85.2 + github.com/aws/aws-sdk-go-v2/service/route53 v1.44.3 + github.com/aws/aws-sdk-go-v2/service/s3 v1.63.3 + github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.33.3 + github.com/aws/aws-sdk-go-v2/service/sns v1.32.3 + github.com/aws/aws-sdk-go-v2/service/sqs v1.35.3 + github.com/aws/aws-sdk-go-v2/service/ssm v1.54.3 + github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 github.com/gonvenience/ytbx v1.4.4 github.com/homeport/dyff v1.6.0 github.com/slack-go/slack v0.10.3 @@ -69,6 +89,21 @@ require ( github.com/BurntSushi/toml v1.3.2 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.37 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.20 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.9.19 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.18 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.23.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3 // indirect + github.com/aws/smithy-go v1.21.0 // indirect github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect @@ -84,6 +119,7 @@ require ( github.com/go-openapi/jsonpointer v0.19.6 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.22.3 // indirect + github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/gonvenience/bunt v1.3.5 // indirect diff --git a/go.sum b/go.sum index feef308f2..759387925 100644 --- a/go.sum +++ b/go.sum @@ -277,6 +277,78 @@ github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQ github.com/aws/aws-sdk-go v1.15.11/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0= github.com/aws/aws-sdk-go v1.44.122 h1:p6mw01WBaNpbdP2xrisz5tIkcNwzj/HysobNoaAHjgo= github.com/aws/aws-sdk-go v1.44.122/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= +github.com/aws/aws-sdk-go-v2 v1.31.0 h1:3V05LbxTSItI5kUqNwhJrrrY1BAXxXt0sN0l72QmG5U= +github.com/aws/aws-sdk-go-v2 v1.31.0/go.mod h1:ztolYtaEUtdpf9Wftr31CJfLVjOnD/CVRkKOOYgF8hA= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5 h1:xDAuZTn4IMm8o1LnBZvmrL8JA1io4o3YWNXgohbf20g= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5/go.mod h1:wYSv6iDS621sEFLfKvpPE2ugjTuGlAG7iROg0hLOkfc= +github.com/aws/aws-sdk-go-v2/config v1.27.39 h1:FCylu78eTGzW1ynHcongXK9YHtoXD5AiiUqq3YfJYjU= +github.com/aws/aws-sdk-go-v2/config v1.27.39/go.mod h1:wczj2hbyskP4LjMKBEZwPRO1shXY+GsQleab+ZXT2ik= +github.com/aws/aws-sdk-go-v2/credentials v1.17.37 h1:G2aOH01yW8X373JK419THj5QVqu9vKEwxSEsGxihoW0= +github.com/aws/aws-sdk-go-v2/credentials v1.17.37/go.mod h1:0ecCjlb7htYCptRD45lXJ6aJDQac6D2NlKGpZqyTG6A= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14 h1:C/d03NAmh8C4BZXhuRNboF/DqhBkBCeDiJDcaqIT5pA= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14/go.mod h1:7I0Ju7p9mCIdlrfS+JCgqcYD0VXz/N4yozsox+0o078= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.25 h1:HkpHeZMM39sGtMHVYG1buAg93vhj5d7F81y6G0OAbGc= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.25/go.mod h1:j3Vz04ZjaWA6kygOsZRpmWe4CyGqfqq2u3unDTU0QGA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 h1:kYQ3H1u0ANr9KEKlGs/jTLrBFPo8P8NaH/w7A01NeeM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18/go.mod h1:r506HmK5JDUh9+Mw4CfGJGSSoqIiLCndAuqXuhbv67Y= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 h1:Z7IdFUONvTcvS7YuhtVxN99v2cCoHRXOS4mTr0B/pUc= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18/go.mod h1:DkKMmksZVVyat+Y+r1dEOgJEfUeA7UngIHWeKsi0yNc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.18 h1:OWYvKL53l1rbsUmW7bQyJVsYU/Ii3bbAAQIIFNbM0Tk= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.18/go.mod h1:CUx0G1v3wG6l01tUB+j7Y8kclA8NSqK4ef0YG79a4cg= +github.com/aws/aws-sdk-go-v2/service/acm v1.29.3 h1:EpXx6a8u5ZnhBuUr9yj8sEQv67jYkC8/TuRvS8TG248= +github.com/aws/aws-sdk-go-v2/service/acm v1.29.3/go.mod h1:pyj5IBRLA+w27gR7KJY/4lSWoP4XOsyOVsXKAMvWE3s= +github.com/aws/aws-sdk-go-v2/service/autoscaling v1.44.3 h1:uW81sdnq9hfg2hSnVqAFp+mMmu4Y86dU/bE9ET2LCIg= +github.com/aws/aws-sdk-go-v2/service/autoscaling v1.44.3/go.mod h1:Gmv7s//GGvs3nj9aqltFYnLStW8vDIwch0USkE67G4E= +github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.40.3 h1:s4rC9SWlq5hh6EDe+90LNkHuNQ6LOWZ2/7F2GaeOjaA= +github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.40.3/go.mod h1:3p7NzlLlJesNGovq7Vqx8+0UibawzodrBRQAbaza6pI= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.35.3 h1:X4iS+RcIKHkAMQz47nDt/nHxZUCKdnfgw940yluJ29Q= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.35.3/go.mod h1:k5XW8MoMxsNZ20RJmsokakvENUwQyjv69R9GqrI4xdQ= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.179.2 h1:rGBv2N0zWvNTKnxOfbBH4mNM8WMdDNkaxdqtz152G40= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.179.2/go.mod h1:W6sNzs5T4VpZn1Vy+FMKw8s24vt5k6zPJXcNOK0asBo= +github.com/aws/aws-sdk-go-v2/service/ecr v1.35.3 h1:8/vARxqd0Pn2Gqhp+8PxxTm3HttUMR1i1vBBj7MNFfc= +github.com/aws/aws-sdk-go-v2/service/ecr v1.35.3/go.mod h1:oRaGEExKI6Pqcow+Tt7wpJf73/Srcj/CUJv5Eb9QFhg= +github.com/aws/aws-sdk-go-v2/service/ecs v1.46.3 h1:BVItlUrorHr7lLLxWKFUVXxwht6IVVqLTQLGc6YLB6U= +github.com/aws/aws-sdk-go-v2/service/ecs v1.46.3/go.mod h1:/IMvyX4u5s4Ed0kzD+vWdPK92zm/q4CN1afJeDCsdhE= +github.com/aws/aws-sdk-go-v2/service/iam v1.36.3 h1:dV9iimLEHKYAz2qTi+tGAD9QCnAG2pLD7HUEHB7m4mI= +github.com/aws/aws-sdk-go-v2/service/iam v1.36.3/go.mod h1:HSvujsK8xeEHMIB18oMXjSfqaN9cVqpo/MtHJIksQRk= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5 h1:QFASJGfT8wMXtuP3D5CRmMjARHv9ZmzFUMJznHDOY3w= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5/go.mod h1:QdZ3OmoIjSX+8D1OPAzPxDfjXASbBMDsz9qvtyIhtik= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.20 h1:rTWjG6AvWekO2B1LHeM3ktU7MqyX9rzWQ7hgzneZW7E= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.20/go.mod h1:RGW2DDpVc8hu6Y6yG8G5CHVmVOAn1oV8rNKOHRJyswg= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.9.19 h1:dOxqOlOEa2e2heC/74+ZzcJOa27+F1aXFZpYgY/4QfA= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.9.19/go.mod h1:aV6U1beLFvk3qAgognjS3wnGGoDId8hlPEiBsLHXVZE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20 h1:Xbwbmk44URTiHNx6PNo0ujDE6ERlsCKJD3u1zfnzAPg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20/go.mod h1:oAfOFzUB14ltPZj1rWwRc3d/6OgD76R8KlvU3EqM9Fg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.18 h1:eb+tFOIl9ZsUe2259/BKPeniKuz4/02zZFH/i4Nf8Rg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.18/go.mod h1:GVCC2IJNJTmdlyEsSmofEy7EfJncP7DNnXDzRjJ5Keg= +github.com/aws/aws-sdk-go-v2/service/kms v1.36.3 h1:iHi6lC6LfW6SNvB2bixmlOW3WMyWFrHZCWX+P+CCxMk= +github.com/aws/aws-sdk-go-v2/service/kms v1.36.3/go.mod h1:OHmlX4+o0XIlJAQGAHPIy0N9yZcYS/vNG+T7geSNcFw= +github.com/aws/aws-sdk-go-v2/service/lambda v1.62.1 h1:Psp52CBlJtOVDyI4UMCAfovD4spGvdqapsBJxWZe470= +github.com/aws/aws-sdk-go-v2/service/lambda v1.62.1/go.mod h1:mivSaHqW3Atf5TDU1YyujR+HMv+snxCMoYaVd9d30O4= +github.com/aws/aws-sdk-go-v2/service/rds v1.85.2 h1:KDO/FSO8V+zlvnQF6v4nOariw2qwPx5/z2pyb6X7ibk= +github.com/aws/aws-sdk-go-v2/service/rds v1.85.2/go.mod h1:lhiPj6RvoJHWG2STp+k5az55YqGgFLBzkKYdYHgUh9g= +github.com/aws/aws-sdk-go-v2/service/route53 v1.44.3 h1:vYmafsIZWxc0EkIovYfjyfekHJogJjnIUXso5o7YPIA= +github.com/aws/aws-sdk-go-v2/service/route53 v1.44.3/go.mod h1:l2ABSKg3AibEJeR/l60cfeGU54UqF3VTgd51pq+vYhU= +github.com/aws/aws-sdk-go-v2/service/s3 v1.63.3 h1:3zt8qqznMuAZWDTDpcwv9Xr11M/lVj2FsRR7oYBt0OA= +github.com/aws/aws-sdk-go-v2/service/s3 v1.63.3/go.mod h1:NLTqRLe3pUNu3nTEHI6XlHLKYmc8fbHUdMxAB6+s41Q= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.33.3 h1:W2M3kQSuN1+FXgV2wMv1JMWPxw/37wBN87QHYDuTV0Y= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.33.3/go.mod h1:WyLS5qwXHtjKAONYZq/4ewdd+hcVsa3LBu77Ow5uj3k= +github.com/aws/aws-sdk-go-v2/service/sns v1.32.3 h1:LC5JBrEAdJ0SSRLfNcLzOLsfoc3xO/BAsHiUNcQfDI4= +github.com/aws/aws-sdk-go-v2/service/sns v1.32.3/go.mod h1:ZO606Jfatw51c8q29gHVVCnufg2dq3MnmkNLlTZFrkE= +github.com/aws/aws-sdk-go-v2/service/sqs v1.35.3 h1:Lcs658WFW235QuUfpAdxd8RCy8Va2VUA7/U9iIrcjcY= +github.com/aws/aws-sdk-go-v2/service/sqs v1.35.3/go.mod h1:WuGxWQhu2LXoPGA2HBIbotpwhM6T4hAz0Ip/HjdxfJg= +github.com/aws/aws-sdk-go-v2/service/ssm v1.54.3 h1:Ctzev3ppcc46m2FgrLEZhsHMEr1G1lrJcd9Cmoy/QJk= +github.com/aws/aws-sdk-go-v2/service/ssm v1.54.3/go.mod h1:qs3TBNpFEnVubl0WL3jruj7NJMF1RCAPEPQ1f+fLTBE= +github.com/aws/aws-sdk-go-v2/service/sso v1.23.3 h1:rs4JCczF805+FDv2tRhZ1NU0RB2H6ryAvsWPanAr72Y= +github.com/aws/aws-sdk-go-v2/service/sso v1.23.3/go.mod h1:XRlMvmad0ZNL+75C5FYdMvbbLkd6qiqz6foR1nA1PXY= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3 h1:S7EPdMVZod8BGKQQPTBK+FcX9g7bKR7c4+HxWqHP7Vg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3/go.mod h1:FnvDM4sfa+isJ3kDXIzAB9GAwVSzFzSy97uZ3IsHo4E= +github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 h1:VzudTFrDCIDakXtemR7l6Qzt2+JYsVqo2MxBPt5k8T8= +github.com/aws/aws-sdk-go-v2/service/sts v1.31.3/go.mod h1:yMWe0F+XG0DkRZK5ODZhG7BEFYhLXi2dqGsv6tX0cgI= +github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA= +github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v0.0.0-20160804104726-4c0e84591b9a/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= diff --git a/modules/aws/account.go b/modules/aws/account.go index e64bd36d0..d4d93d37b 100644 --- a/modules/aws/account.go +++ b/modules/aws/account.go @@ -1,11 +1,12 @@ package aws import ( + "context" "errors" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -26,12 +27,12 @@ func GetAccountIdE(t testing.TestingT) (string, error) { return "", err } - identity, err := stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + identity, err := stsClient.GetCallerIdentity(context.Background(), &sts.GetCallerIdentityInput{}) if err != nil { return "", err } - return aws.StringValue(identity.Account), nil + return aws.ToString(identity.Account), nil } // An IAM arn is of the format arn:aws:iam::123456789012:user/test. The account id is the number after arn:aws:iam::, @@ -47,10 +48,10 @@ func extractAccountIDFromARN(arn string) (string, error) { } // NewStsClientE creates a new STS client. -func NewStsClientE(t testing.TestingT, region string) (*sts.STS, error) { +func NewStsClientE(t testing.TestingT, region string) (*sts.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return sts.New(sess), nil + return sts.NewFromConfig(*sess), nil } diff --git a/modules/aws/acm.go b/modules/aws/acm.go index 88ac5f9de..ea00a9d53 100644 --- a/modules/aws/acm.go +++ b/modules/aws/acm.go @@ -1,7 +1,9 @@ package aws import ( - "github.com/aws/aws-sdk-go/service/acm" + "context" + + "github.com/aws/aws-sdk-go-v2/service/acm" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -22,7 +24,7 @@ func GetAcmCertificateArnE(t testing.TestingT, awsRegion string, certDomainName return "", err } - result, err := acmClient.ListCertificates(&acm.ListCertificatesInput{}) + result, err := acmClient.ListCertificates(context.Background(), &acm.ListCertificatesInput{}) if err != nil { return "", err } @@ -37,7 +39,7 @@ func GetAcmCertificateArnE(t testing.TestingT, awsRegion string, certDomainName } // NewAcmClient create a new ACM client. -func NewAcmClient(t testing.TestingT, region string) *acm.ACM { +func NewAcmClient(t testing.TestingT, region string) *acm.Client { client, err := NewAcmClientE(t, region) if err != nil { t.Fatal(err) @@ -46,11 +48,11 @@ func NewAcmClient(t testing.TestingT, region string) *acm.ACM { } // NewAcmClientE creates a new ACM client. -func NewAcmClientE(t testing.TestingT, awsRegion string) (*acm.ACM, error) { +func NewAcmClientE(t testing.TestingT, awsRegion string) (*acm.Client, error) { sess, err := NewAuthenticatedSession(awsRegion) if err != nil { return nil, err } - return acm.New(sess), nil + return acm.NewFromConfig(*sess), nil } diff --git a/modules/aws/ami.go b/modules/aws/ami.go index c05879df0..e8bc513ba 100644 --- a/modules/aws/ami.go +++ b/modules/aws/ami.go @@ -1,12 +1,14 @@ package aws import ( + "context" "fmt" "sort" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -57,17 +59,17 @@ func GetEbsSnapshotsForAmi(t testing.TestingT, region string, ami string) []stri return snapshots } -// GetEbsSnapshotsForAmi retrieves the EBS snapshots which back the given AMI +// GetEbsSnapshotsForAmiE retrieves the EBS snapshots which back the given AMI func GetEbsSnapshotsForAmiE(t testing.TestingT, region string, ami string) ([]string, error) { logger.Logf(t, "Retrieving EBS snapshots backing AMI %s", ami) - ec2Client, err := NewEc2ClientE(t, region) + ec2Client, err := NewEc2ClientV2E(t, region) if err != nil { return nil, err } - images, err := ec2Client.DescribeImages(&ec2.DescribeImagesInput{ - ImageIds: []*string{ - aws.String(ami), + images, err := ec2Client.DescribeImages(context.Background(), &ec2.DescribeImagesInput{ + ImageIds: []string{ + ami, }, }) if err != nil { @@ -78,7 +80,7 @@ func GetEbsSnapshotsForAmiE(t testing.TestingT, region string, ami string) ([]st for _, image := range images.Images { for _, mapping := range image.BlockDeviceMappings { if mapping.Ebs != nil && mapping.Ebs.SnapshotId != nil { - snapshots = append(snapshots, aws.StringValue(mapping.Ebs.SnapshotId)) + snapshots = append(snapshots, aws.ToString(mapping.Ebs.SnapshotId)) } } } @@ -101,23 +103,23 @@ func GetMostRecentAmiId(t testing.TestingT, region string, ownerId string, filte // filter should correspond to the name and values of a filter supported by DescribeImagesInput: // https://docs.aws.amazon.com/sdk-for-go/api/service/ec2/#DescribeImagesInput func GetMostRecentAmiIdE(t testing.TestingT, region string, ownerId string, filters map[string][]string) (string, error) { - ec2Client, err := NewEc2ClientE(t, region) + ec2Client, err := NewEc2ClientV2E(t, region) if err != nil { return "", err } - ec2Filters := []*ec2.Filter{} + var ec2Filters []types.Filter for name, values := range filters { - ec2Filters = append(ec2Filters, &ec2.Filter{Name: aws.String(name), Values: aws.StringSlice(values)}) + ec2Filters = append(ec2Filters, types.Filter{Name: aws.String(name), Values: values}) } input := ec2.DescribeImagesInput{ Filters: ec2Filters, IncludeDeprecated: aws.Bool(true), - Owners: []*string{aws.String(ownerId)}, + Owners: []string{ownerId}, } - out, err := ec2Client.DescribeImages(&input) + out, err := ec2Client.DescribeImages(context.Background(), &input) if err != nil { return "", err } @@ -127,11 +129,11 @@ func GetMostRecentAmiIdE(t testing.TestingT, region string, ownerId string, filt } mostRecentImage := mostRecentAMI(out.Images) - return aws.StringValue(mostRecentImage.ImageId), nil + return aws.ToString(mostRecentImage.ImageId), nil } // Image sorting code borrowed from: https://github.com/hashicorp/packer/blob/7f4112ba229309cfc0ebaa10ded2abdfaf1b22c8/builder/amazon/common/step_source_ami_info.go -type imageSort []*ec2.Image +type imageSort []types.Image func (a imageSort) Len() int { return len(a) } func (a imageSort) Swap(i, j int) { a[i], a[j] = a[j], a[i] } @@ -142,7 +144,7 @@ func (a imageSort) Less(i, j int) bool { } // mostRecentAMI returns the most recent AMI out of a slice of images. -func mostRecentAMI(images []*ec2.Image) *ec2.Image { +func mostRecentAMI(images []types.Image) types.Image { sortedImages := images sort.Sort(imageSort(sortedImages)) return sortedImages[len(sortedImages)-1] diff --git a/modules/aws/asg.go b/modules/aws/asg.go index d3a066318..d7543d1fc 100644 --- a/modules/aws/asg.go +++ b/modules/aws/asg.go @@ -1,11 +1,12 @@ package aws import ( + "context" "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/autoscaling" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" "github.com/stretchr/testify/require" "github.com/gruntwork-io/terratest/modules/logger" @@ -34,8 +35,8 @@ func GetCapacityInfoForAsgE(t testing.TestingT, asgName string, awsRegion string return AsgCapacityInfo{}, err } - input := autoscaling.DescribeAutoScalingGroupsInput{AutoScalingGroupNames: []*string{aws.String(asgName)}} - output, err := asgClient.DescribeAutoScalingGroups(&input) + input := autoscaling.DescribeAutoScalingGroupsInput{AutoScalingGroupNames: []string{asgName}} + output, err := asgClient.DescribeAutoScalingGroups(context.Background(), &input) if err != nil { return AsgCapacityInfo{}, err } @@ -44,9 +45,9 @@ func GetCapacityInfoForAsgE(t testing.TestingT, asgName string, awsRegion string return AsgCapacityInfo{}, NewNotFoundError("ASG", asgName, awsRegion) } capacityInfo := AsgCapacityInfo{ - MinCapacity: *groups[0].MinSize, - MaxCapacity: *groups[0].MaxSize, - DesiredCapacity: *groups[0].DesiredCapacity, + MinCapacity: int64(*groups[0].MinSize), + MaxCapacity: int64(*groups[0].MaxSize), + DesiredCapacity: int64(*groups[0].DesiredCapacity), CurrentCapacity: int64(len(groups[0].Instances)), } return capacityInfo, nil @@ -68,16 +69,16 @@ func GetInstanceIdsForAsgE(t testing.TestingT, asgName string, awsRegion string) return nil, err } - input := autoscaling.DescribeAutoScalingGroupsInput{AutoScalingGroupNames: []*string{aws.String(asgName)}} - output, err := asgClient.DescribeAutoScalingGroups(&input) + input := autoscaling.DescribeAutoScalingGroupsInput{AutoScalingGroupNames: []string{asgName}} + output, err := asgClient.DescribeAutoScalingGroups(context.Background(), &input) if err != nil { return nil, err } - instanceIDs := []string{} + var instanceIDs []string for _, asg := range output.AutoScalingGroups { for _, instance := range asg.Instances { - instanceIDs = append(instanceIDs, aws.StringValue(instance.InstanceId)) + instanceIDs = append(instanceIDs, aws.ToString(instance.InstanceId)) } } @@ -125,7 +126,7 @@ func WaitForCapacityE( } // NewAsgClient creates an Auto Scaling Group client. -func NewAsgClient(t testing.TestingT, region string) *autoscaling.AutoScaling { +func NewAsgClient(t testing.TestingT, region string) *autoscaling.Client { client, err := NewAsgClientE(t, region) if err != nil { t.Fatal(err) @@ -134,11 +135,11 @@ func NewAsgClient(t testing.TestingT, region string) *autoscaling.AutoScaling { } // NewAsgClientE creates an Auto Scaling Group client. -func NewAsgClientE(t testing.TestingT, region string) (*autoscaling.AutoScaling, error) { +func NewAsgClientE(t testing.TestingT, region string) (*autoscaling.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return autoscaling.New(sess), nil + return autoscaling.NewFromConfig(*sess), nil } diff --git a/modules/aws/asg_test.go b/modules/aws/asg_test.go index b6b1dda7c..3efb6b8fe 100644 --- a/modules/aws/asg_test.go +++ b/modules/aws/asg_test.go @@ -1,13 +1,15 @@ package aws import ( + "context" "fmt" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/autoscaling" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/autoscaling" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -49,54 +51,60 @@ func TestGetInstanceIdsForAsg(t *testing.T) { // The following functions were adapted from the tests for cloud-nuke -func createTestAutoScalingGroup(t *testing.T, name string, region string, desiredCount int64) { +func createTestAutoScalingGroup(t *testing.T, name string, region string, desiredCount int32) { instance := createTestEC2Instance(t, region, name) asgClient := NewAsgClient(t, region) param := &autoscaling.CreateAutoScalingGroupInput{ AutoScalingGroupName: &name, InstanceId: instance.InstanceId, - DesiredCapacity: aws.Int64(desiredCount), - MinSize: aws.Int64(1), - MaxSize: aws.Int64(3), + DesiredCapacity: aws.Int32(desiredCount), + MinSize: aws.Int32(1), + MaxSize: aws.Int32(3), } - _, err := asgClient.CreateAutoScalingGroup(param) + _, err := asgClient.CreateAutoScalingGroup(context.Background(), param) require.NoError(t, err) - err = asgClient.WaitUntilGroupExists(&autoscaling.DescribeAutoScalingGroupsInput{ - AutoScalingGroupNames: []*string{&name}, - }) + waiter := autoscaling.NewGroupExistsWaiter(asgClient) + err = waiter.Wait(context.Background(), &autoscaling.DescribeAutoScalingGroupsInput{ + AutoScalingGroupNames: []string{name}, + }, 42*time.Minute) require.NoError(t, err) } -func createTestEC2Instance(t *testing.T, region string, name string) ec2.Instance { +func createTestEC2Instance(t *testing.T, region string, name string) types.Instance { ec2Client := NewEc2Client(t, region) imageID := GetAmazonLinuxAmi(t, region) params := &ec2.RunInstancesInput{ ImageId: aws.String(imageID), - InstanceType: aws.String(GetRecommendedInstanceType(t, region, []string{"t2.micro, t3.micro", "t2.small", "t3.small"})), - MinCount: aws.Int64(1), - MaxCount: aws.Int64(1), + InstanceType: types.InstanceType(GetRecommendedInstanceType(t, region, []string{"t2.micro, t3.micro", "t2.small", "t3.small"})), + MinCount: aws.Int32(1), + MaxCount: aws.Int32(1), } - runResult, err := ec2Client.RunInstances(params) + runResult, err := ec2Client.RunInstances(context.Background(), params) require.NoError(t, err) require.NotEqual(t, len(runResult.Instances), 0) - err = ec2Client.WaitUntilInstanceExists(&ec2.DescribeInstancesInput{ - Filters: []*ec2.Filter{ - &ec2.Filter{ - Name: aws.String("instance-id"), - Values: []*string{runResult.Instances[0].InstanceId}, + waiter := ec2.NewInstanceExistsWaiter(ec2Client) + err = waiter.Wait( + context.Background(), + &ec2.DescribeInstancesInput{ + Filters: []types.Filter{ + { + Name: aws.String("instance-id"), + Values: []string{*runResult.Instances[0].InstanceId}, + }, }, }, - }) + 42*time.Minute, + ) require.NoError(t, err) // Add test tag to the created instance - _, err = ec2Client.CreateTags(&ec2.CreateTagsInput{ - Resources: []*string{runResult.Instances[0].InstanceId}, - Tags: []*ec2.Tag{ + _, err = ec2Client.CreateTags(context.Background(), &ec2.CreateTagsInput{ + Resources: []string{*runResult.Instances[0].InstanceId}, + Tags: []types.Tag{ { Key: aws.String("Name"), Value: aws.String(name), @@ -106,17 +114,18 @@ func createTestEC2Instance(t *testing.T, region string, name string) ec2.Instanc require.NoError(t, err) // EC2 Instance must be in a running before this function returns - err = ec2Client.WaitUntilInstanceRunning(&ec2.DescribeInstancesInput{ - Filters: []*ec2.Filter{ - &ec2.Filter{ + runningWaiter := ec2.NewInstanceRunningWaiter(ec2Client) + err = runningWaiter.Wait(context.Background(), &ec2.DescribeInstancesInput{ + Filters: []types.Filter{ + { Name: aws.String("instance-id"), - Values: []*string{runResult.Instances[0].InstanceId}, + Values: []string{*runResult.Instances[0].InstanceId}, }, }, - }) + }, 42*time.Minute) require.NoError(t, err) - return *runResult.Instances[0] + return runResult.Instances[0] } func terminateEc2InstancesByName(t *testing.T, region string, names []string) { @@ -134,11 +143,13 @@ func deleteAutoScalingGroup(t *testing.T, name string, region string) { asgClient := NewAsgClient(t, region) input := &autoscaling.DeleteAutoScalingGroupInput{AutoScalingGroupName: aws.String(name)} - _, err := asgClient.DeleteAutoScalingGroup(input) + _, err := asgClient.DeleteAutoScalingGroup(context.Background(), input) require.NoError(t, err) - err = asgClient.WaitUntilGroupNotExists(&autoscaling.DescribeAutoScalingGroupsInput{ - AutoScalingGroupNames: []*string{aws.String(name)}, - }) + + waiter := autoscaling.NewGroupNotExistsWaiter(asgClient) + err = waiter.Wait(context.Background(), &autoscaling.DescribeAutoScalingGroupsInput{ + AutoScalingGroupNames: []string{name}, + }, 40*time.Minute) require.NoError(t, err) } @@ -146,15 +157,15 @@ func scaleAsgToZero(t *testing.T, name string, region string) { asgClient := NewAsgClient(t, region) input := &autoscaling.UpdateAutoScalingGroupInput{ AutoScalingGroupName: aws.String(name), - DesiredCapacity: aws.Int64(0), - MinSize: aws.Int64(0), - MaxSize: aws.Int64(0), + DesiredCapacity: aws.Int32(0), + MinSize: aws.Int32(0), + MaxSize: aws.Int32(0), } - _, err := asgClient.UpdateAutoScalingGroup(input) + _, err := asgClient.UpdateAutoScalingGroup(context.Background(), input) require.NoError(t, err) WaitForCapacity(t, name, region, 40, 15*time.Second) // There is an eventual consistency bug where even though the ASG is scaled down, AWS sometimes still views a - // scaling activity so we add a 5 second pause here to work around it. + // scaling activity so we add a 5-second pause here to work around it. time.Sleep(5 * time.Second) } diff --git a/modules/aws/auth.go b/modules/aws/auth.go index f2aa6f78c..7c34118ee 100644 --- a/modules/aws/auth.go +++ b/modules/aws/auth.go @@ -1,15 +1,18 @@ package aws import ( + "context" "fmt" - "os" "time" + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/sts" "github.com/pquerna/otp/totp" ) @@ -20,12 +23,12 @@ const ( // NewAuthenticatedSession creates an AWS session following to standard AWS authentication workflow. // If AuthAssumeIamRoleEnvVar environment variable is set, assumes IAM role specified in it. -func NewAuthenticatedSession(region string) (*session.Session, error) { - if assumeRoleArn, ok := os.LookupEnv(AuthAssumeRoleEnvVar); ok { - return NewAuthenticatedSessionFromRole(region, assumeRoleArn) - } else { - return NewAuthenticatedSessionFromDefaultCredentials(region) - } +func NewAuthenticatedSession(region string) (*awsv2.Config, error) { + // if assumeRoleArn, ok := os.LookupEnv(AuthAssumeRoleEnvVar); ok { + // return NewAuthenticatedSessionFromRole(region, assumeRoleArn) + // } else { + return NewAuthenticatedSessionFromDefaultCredentialsV2(region) + // } } // NewAuthenticatedSessionFromDefaultCredentials gets an AWS Session, checking that the user has credentials properly configured in their environment. @@ -49,6 +52,15 @@ func NewAuthenticatedSessionFromDefaultCredentials(region string) (*session.Sess return sess, nil } +func NewAuthenticatedSessionFromDefaultCredentialsV2(region string) (*awsv2.Config, error) { + cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region)) + if err != nil { + return nil, CredentialsError{UnderlyingErr: err} + } + + return &cfg, nil +} + // NewAuthenticatedSessionFromRole returns a new AWS Session after assuming the // role whose ARN is provided in roleARN. If the credentials are not properly // configured in the underlying environment, an error is returned. @@ -91,7 +103,7 @@ func CreateAwsSessionWithCreds(region string, accessKeyID string, secretAccessKe } // CreateAwsSessionWithMfa creates a new AWS session authenticated using an MFA token retrieved using the given STS client and MFA Device. -func CreateAwsSessionWithMfa(region string, stsClient *sts.STS, mfaDevice *iam.VirtualMFADevice) (*session.Session, error) { +func CreateAwsSessionWithMfa(region string, stsClient *sts.STS, mfaDevice *types.VirtualMFADevice) (*session.Session, error) { tokenCode, err := GetTimeBasedOneTimePassword(mfaDevice) if err != nil { return nil, err @@ -131,7 +143,7 @@ func CreateAwsCredentialsWithSessionToken(accessKeyID, secretAccessKey, sessionT } // GetTimeBasedOneTimePassword gets a One-Time Password from the given mfaDevice. Per the RFC 6238 standard, this value will be different every 30 seconds. -func GetTimeBasedOneTimePassword(mfaDevice *iam.VirtualMFADevice) (string, error) { +func GetTimeBasedOneTimePassword(mfaDevice *types.VirtualMFADevice) (string, error) { base32StringSeed := string(mfaDevice.Base32StringSeed) otp, err := totp.GenerateCode(base32StringSeed, time.Now()) @@ -143,8 +155,8 @@ func GetTimeBasedOneTimePassword(mfaDevice *iam.VirtualMFADevice) (string, error } // ReadPasswordPolicyMinPasswordLength returns the minimal password length. -func ReadPasswordPolicyMinPasswordLength(iamClient *iam.IAM) (int, error) { - output, err := iamClient.GetAccountPasswordPolicy(&iam.GetAccountPasswordPolicyInput{}) +func ReadPasswordPolicyMinPasswordLength(iamClient *iam.Client) (int, error) { + output, err := iamClient.GetAccountPasswordPolicy(context.Background(), &iam.GetAccountPasswordPolicyInput{}) if err != nil { return -1, err } diff --git a/modules/aws/cloudwatch.go b/modules/aws/cloudwatch.go index d5af76e28..f24783e11 100644 --- a/modules/aws/cloudwatch.go +++ b/modules/aws/cloudwatch.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/cloudwatchlogs" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -22,7 +24,7 @@ func GetCloudWatchLogEntriesE(t testing.TestingT, awsRegion string, logStreamNam return nil, err } - output, err := client.GetLogEvents(&cloudwatchlogs.GetLogEventsInput{ + output, err := client.GetLogEvents(context.Background(), &cloudwatchlogs.GetLogEventsInput{ LogGroupName: aws.String(logGroupName), LogStreamName: aws.String(logStreamName), }) @@ -31,7 +33,7 @@ func GetCloudWatchLogEntriesE(t testing.TestingT, awsRegion string, logStreamNam return nil, err } - entries := []string{} + var entries []string for _, event := range output.Events { entries = append(entries, *event.Message) } @@ -40,7 +42,7 @@ func GetCloudWatchLogEntriesE(t testing.TestingT, awsRegion string, logStreamNam } // NewCloudWatchLogsClient creates a new CloudWatch Logs client. -func NewCloudWatchLogsClient(t testing.TestingT, region string) *cloudwatchlogs.CloudWatchLogs { +func NewCloudWatchLogsClient(t testing.TestingT, region string) *cloudwatchlogs.Client { client, err := NewCloudWatchLogsClientE(t, region) if err != nil { t.Fatal(err) @@ -49,10 +51,10 @@ func NewCloudWatchLogsClient(t testing.TestingT, region string) *cloudwatchlogs. } // NewCloudWatchLogsClientE creates a new CloudWatch Logs client. -func NewCloudWatchLogsClientE(t testing.TestingT, region string) (*cloudwatchlogs.CloudWatchLogs, error) { +func NewCloudWatchLogsClientE(t testing.TestingT, region string) (*cloudwatchlogs.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return cloudwatchlogs.New(sess), nil + return cloudwatchlogs.NewFromConfig(*sess), nil } diff --git a/modules/aws/dynamodb.go b/modules/aws/dynamodb.go index 447b17ece..cbd44e4b8 100644 --- a/modules/aws/dynamodb.go +++ b/modules/aws/dynamodb.go @@ -1,23 +1,26 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" ) // GetDynamoDbTableTags fetches resource tags of a specified dynamoDB table. This will fail the test if there are any errors -func GetDynamoDbTableTags(t testing.TestingT, region string, tableName string) []*dynamodb.Tag { +func GetDynamoDbTableTags(t testing.TestingT, region string, tableName string) []types.Tag { tags, err := GetDynamoDbTableTagsE(t, region, tableName) require.NoError(t, err) return tags } // GetDynamoDbTableTagsE fetches resource tags of a specified dynamoDB table. -func GetDynamoDbTableTagsE(t testing.TestingT, region string, tableName string) ([]*dynamodb.Tag, error) { +func GetDynamoDbTableTagsE(t testing.TestingT, region string, tableName string) ([]types.Tag, error) { table := GetDynamoDBTable(t, region, tableName) - out, err := NewDynamoDBClient(t, region).ListTagsOfResource(&dynamodb.ListTagsOfResourceInput{ + out, err := NewDynamoDBClient(t, region).ListTagsOfResource(context.Background(), &dynamodb.ListTagsOfResourceInput{ ResourceArn: table.TableArn, }) if err != nil { @@ -27,15 +30,15 @@ func GetDynamoDbTableTagsE(t testing.TestingT, region string, tableName string) } // GetDynamoDBTableTimeToLive fetches information about the TTL configuration of a specified dynamoDB table. This will fail the test if there are any errors. -func GetDynamoDBTableTimeToLive(t testing.TestingT, region string, tableName string) *dynamodb.TimeToLiveDescription { +func GetDynamoDBTableTimeToLive(t testing.TestingT, region string, tableName string) *types.TimeToLiveDescription { ttl, err := GetDynamoDBTableTimeToLiveE(t, region, tableName) require.NoError(t, err) return ttl } // GetDynamoDBTableTimeToLiveE fetches information about the TTL configuration of a specified dynamoDB table. -func GetDynamoDBTableTimeToLiveE(t testing.TestingT, region string, tableName string) (*dynamodb.TimeToLiveDescription, error) { - out, err := NewDynamoDBClient(t, region).DescribeTimeToLive(&dynamodb.DescribeTimeToLiveInput{ +func GetDynamoDBTableTimeToLiveE(t testing.TestingT, region string, tableName string) (*types.TimeToLiveDescription, error) { + out, err := NewDynamoDBClient(t, region).DescribeTimeToLive(context.Background(), &dynamodb.DescribeTimeToLiveInput{ TableName: aws.String(tableName), }) if err != nil { @@ -45,15 +48,15 @@ func GetDynamoDBTableTimeToLiveE(t testing.TestingT, region string, tableName st } // GetDynamoDBTable fetches information about the specified dynamoDB table. This will fail the test if there are any errors. -func GetDynamoDBTable(t testing.TestingT, region string, tableName string) *dynamodb.TableDescription { +func GetDynamoDBTable(t testing.TestingT, region string, tableName string) *types.TableDescription { table, err := GetDynamoDBTableE(t, region, tableName) require.NoError(t, err) return table } // GetDynamoDBTableE fetches information about the specified dynamoDB table. -func GetDynamoDBTableE(t testing.TestingT, region string, tableName string) (*dynamodb.TableDescription, error) { - out, err := NewDynamoDBClient(t, region).DescribeTable(&dynamodb.DescribeTableInput{ +func GetDynamoDBTableE(t testing.TestingT, region string, tableName string) (*types.TableDescription, error) { + out, err := NewDynamoDBClient(t, region).DescribeTable(context.Background(), &dynamodb.DescribeTableInput{ TableName: aws.String(tableName), }) if err != nil { @@ -63,17 +66,17 @@ func GetDynamoDBTableE(t testing.TestingT, region string, tableName string) (*dy } // NewDynamoDBClient creates a DynamoDB client. -func NewDynamoDBClient(t testing.TestingT, region string) *dynamodb.DynamoDB { +func NewDynamoDBClient(t testing.TestingT, region string) *dynamodb.Client { client, err := NewDynamoDBClientE(t, region) require.NoError(t, err) return client } // NewDynamoDBClientE creates a DynamoDB client. -func NewDynamoDBClientE(t testing.TestingT, region string) (*dynamodb.DynamoDB, error) { +func NewDynamoDBClientE(t testing.TestingT, region string) (*dynamodb.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return dynamodb.New(sess), nil + return dynamodb.NewFromConfig(*sess), nil } diff --git a/modules/aws/ebs.go b/modules/aws/ebs.go index 3b8797dfe..e390c6399 100644 --- a/modules/aws/ebs.go +++ b/modules/aws/ebs.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -15,15 +17,15 @@ func DeleteEbsSnapshot(t testing.TestingT, region string, snapshot string) { } } -// DeleteEbsSnapshot deletes the given EBS snapshot +// DeleteEbsSnapshotE deletes the given EBS snapshot func DeleteEbsSnapshotE(t testing.TestingT, region string, snapshot string) error { logger.Logf(t, "Deleting EBS snapshot %s", snapshot) - ec2Client, err := NewEc2ClientE(t, region) + ec2Client, err := NewEc2ClientV2E(t, region) if err != nil { return err } - _, err = ec2Client.DeleteSnapshot(&ec2.DeleteSnapshotInput{ + _, err = ec2Client.DeleteSnapshot(context.Background(), &ec2.DeleteSnapshotInput{ SnapshotId: aws.String(snapshot), }) return err diff --git a/modules/aws/ec2-syslog.go b/modules/aws/ec2-syslog.go index 79bfca1fe..f622a465c 100644 --- a/modules/aws/ec2-syslog.go +++ b/modules/aws/ec2-syslog.go @@ -1,18 +1,12 @@ package aws import ( - "encoding/base64" "fmt" - "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/gruntwork-io/terratest/modules/logger" - "github.com/gruntwork-io/terratest/modules/retry" "github.com/gruntwork-io/terratest/modules/testing" ) -// (Deprecated) See the FetchContentsOfFileFromInstance method for a more powerful solution. +// GetSyslogForInstance (Deprecated) See the FetchContentsOfFileFromInstance method for a more powerful solution. // // GetSyslogForInstance gets the syslog for the Instance with the given ID in the given region. This should be available ~1 minute after an // Instance boots and is very useful for debugging boot-time issues, such as an error in User Data. @@ -24,57 +18,19 @@ func GetSyslogForInstance(t testing.TestingT, instanceID string, awsRegion strin return out } -// (Deprecated) See the FetchContentsOfFileFromInstanceE method for a more powerful solution. +// GetSyslogForInstanceE (Deprecated) See the FetchContentsOfFileFromInstanceE method for a more powerful solution. // // GetSyslogForInstanceE gets the syslog for the Instance with the given ID in the given region. This should be available ~1 minute after an // Instance boots and is very useful for debugging boot-time issues, such as an error in User Data. func GetSyslogForInstanceE(t testing.TestingT, instanceID string, region string) (string, error) { - description := fmt.Sprintf("Fetching syslog for Instance %s in %s", instanceID, region) - maxRetries := 120 - timeBetweenRetries := 5 * time.Second - - logger.Log(t, description) - - client, err := NewEc2ClientE(t, region) - if err != nil { - return "", err - } - - input := ec2.GetConsoleOutputInput{ - InstanceId: aws.String(instanceID), - } - - syslogB64, err := retry.DoWithRetryE(t, description, maxRetries, timeBetweenRetries, func() (string, error) { - out, err := client.GetConsoleOutput(&input) - if err != nil { - return "", err - } - - syslog := aws.StringValue(out.Output) - if syslog == "" { - return "", fmt.Errorf("Syslog is not yet available for instance %s in %s", instanceID, region) - } - - return syslog, nil - }) - - if err != nil { - return "", err - } - - syslogBytes, err := base64.StdEncoding.DecodeString(syslogB64) - if err != nil { - return "", err - } - - return string(syslogBytes), nil + return "", fmt.Errorf("(Deprecated) use FetchContentsOfFileFromInstanceE method instead") } -// (Deprecated) See the FetchContentsOfFilesFromAsg method for a more powerful solution. +// GetSyslogForInstancesInAsg (Deprecated) See the FetchContentsOfFilesFromAsg method for a more powerful solution. // // GetSyslogForInstancesInAsg gets the syslog for each of the Instances in the given ASG in the given region. These logs should be available ~1 // minute after the Instance boots and are very useful for debugging boot-time issues, such as an error in User Data. -// Returns a map of Instance Id -> Syslog for that Instance. +// Returns a map of Instance ID -> Syslog for that Instance. func GetSyslogForInstancesInAsg(t testing.TestingT, asgName string, awsRegion string) map[string]string { out, err := GetSyslogForInstancesInAsgE(t, asgName, awsRegion) if err != nil { @@ -83,27 +39,11 @@ func GetSyslogForInstancesInAsg(t testing.TestingT, asgName string, awsRegion st return out } -// (Deprecated) See the FetchContentsOfFilesFromAsgE method for a more powerful solution. +// GetSyslogForInstancesInAsgE (Deprecated) See the FetchContentsOfFilesFromAsgE method for a more powerful solution. // // GetSyslogForInstancesInAsgE gets the syslog for each of the Instances in the given ASG in the given region. These logs should be available ~1 // minute after the Instance boots and are very useful for debugging boot-time issues, such as an error in User Data. -// Returns a map of Instance Id -> Syslog for that Instance. +// Returns a map of Instance ID -> Syslog for that Instance. func GetSyslogForInstancesInAsgE(t testing.TestingT, asgName string, awsRegion string) (map[string]string, error) { - logger.Logf(t, "Fetching syslog for each Instance in ASG %s in %s", asgName, awsRegion) - - instanceIDs, err := GetEc2InstanceIdsByTagE(t, awsRegion, "aws:autoscaling:groupName", asgName) - if err != nil { - return nil, err - } - - logs := map[string]string{} - for _, id := range instanceIDs { - syslog, err := GetSyslogForInstanceE(t, id, awsRegion) - if err != nil { - return nil, err - } - logs[id] = syslog - } - - return logs, nil + return nil, fmt.Errorf("(Deprecated) use FetchContentsOfFilesFromAsgE method instead") } diff --git a/modules/aws/ec2.go b/modules/aws/ec2.go index 78db1c912..a598a1df3 100644 --- a/modules/aws/ec2.go +++ b/modules/aws/ec2.go @@ -1,10 +1,12 @@ package aws import ( + "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" @@ -42,19 +44,19 @@ func GetPrivateIpsOfEc2Instances(t testing.TestingT, instanceIDs []string, awsRe // GetPrivateIpsOfEc2InstancesE gets the private IP address of the given EC2 Instance in the given region. Returns a map of instance ID to IP address. func GetPrivateIpsOfEc2InstancesE(t testing.TestingT, instanceIDs []string, awsRegion string) (map[string]string, error) { - ec2Client := NewEc2Client(t, awsRegion) + ec2Client := NewEc2ClientV2(t, awsRegion) // TODO: implement pagination for cases that extend beyond limit (1000 instances) - input := ec2.DescribeInstancesInput{InstanceIds: aws.StringSlice(instanceIDs)} - output, err := ec2Client.DescribeInstances(&input) + input := ec2.DescribeInstancesInput{InstanceIds: instanceIDs} + output, err := ec2Client.DescribeInstances(context.Background(), &input) if err != nil { return nil, err } ips := map[string]string{} - for _, reserveration := range output.Reservations { - for _, instance := range reserveration.Instances { - ips[aws.StringValue(instance.InstanceId)] = aws.StringValue(instance.PrivateIpAddress) + for _, reservation := range output.Reservations { + for _, instance := range reservation.Instances { + ips[aws.ToString(instance.InstanceId)] = aws.ToString(instance.PrivateIpAddress) } } @@ -93,22 +95,22 @@ func GetPrivateHostnamesOfEc2Instances(t testing.TestingT, instanceIDs []string, // GetPrivateHostnamesOfEc2InstancesE gets the private IP address of the given EC2 Instance in the given region. Returns a map of instance ID to IP address. func GetPrivateHostnamesOfEc2InstancesE(t testing.TestingT, instanceIDs []string, awsRegion string) (map[string]string, error) { - ec2Client, err := NewEc2ClientE(t, awsRegion) + ec2Client, err := NewEc2ClientV2E(t, awsRegion) if err != nil { return nil, err } // TODO: implement pagination for cases that extend beyond limit (1000 instances) - input := ec2.DescribeInstancesInput{InstanceIds: aws.StringSlice(instanceIDs)} - output, err := ec2Client.DescribeInstances(&input) + input := ec2.DescribeInstancesInput{InstanceIds: instanceIDs} + output, err := ec2Client.DescribeInstances(context.Background(), &input) if err != nil { return nil, err } hostnames := map[string]string{} - for _, reserveration := range output.Reservations { - for _, instance := range reserveration.Instances { - hostnames[aws.StringValue(instance.InstanceId)] = aws.StringValue(instance.PrivateDnsName) + for _, reservation := range output.Reservations { + for _, instance := range reservation.Instances { + hostnames[aws.ToString(instance.InstanceId)] = aws.ToString(instance.PrivateDnsName) } } @@ -147,19 +149,19 @@ func GetPublicIpsOfEc2Instances(t testing.TestingT, instanceIDs []string, awsReg // GetPublicIpsOfEc2InstancesE gets the public IP address of the given EC2 Instance in the given region. Returns a map of instance ID to IP address. func GetPublicIpsOfEc2InstancesE(t testing.TestingT, instanceIDs []string, awsRegion string) (map[string]string, error) { - ec2Client := NewEc2Client(t, awsRegion) + ec2Client := NewEc2ClientV2(t, awsRegion) // TODO: implement pagination for cases that extend beyond limit (1000 instances) - input := ec2.DescribeInstancesInput{InstanceIds: aws.StringSlice(instanceIDs)} - output, err := ec2Client.DescribeInstances(&input) + input := ec2.DescribeInstancesInput{InstanceIds: instanceIDs} + output, err := ec2Client.DescribeInstances(context.Background(), &input) if err != nil { return nil, err } ips := map[string]string{} - for _, reserveration := range output.Reservations { - for _, instance := range reserveration.Instances { - ips[aws.StringValue(instance.InstanceId)] = aws.StringValue(instance.PublicIpAddress) + for _, reservation := range output.Reservations { + for _, instance := range reservation.Instances { + ips[aws.ToString(instance.InstanceId)] = aws.ToString(instance.PublicIpAddress) } } @@ -189,27 +191,27 @@ func GetEc2InstanceIdsByFilters(t testing.TestingT, region string, ec2Filters ma return out } -// GetEc2InstanceIdsByFilters returns all the IDs of EC2 instances in the given region which match to EC2 filter list +// GetEc2InstanceIdsByFiltersE returns all the IDs of EC2 instances in the given region which match to EC2 filter list // as per https://docs.aws.amazon.com/sdk-for-go/api/service/ec2/#DescribeInstancesInput. func GetEc2InstanceIdsByFiltersE(t testing.TestingT, region string, ec2Filters map[string][]string) ([]string, error) { - client, err := NewEc2ClientE(t, region) + client, err := NewEc2ClientV2E(t, region) if err != nil { return nil, err } - ec2FilterList := []*ec2.Filter{} + var ec2FilterList []types.Filter for name, values := range ec2Filters { - ec2FilterList = append(ec2FilterList, &ec2.Filter{Name: aws.String(name), Values: aws.StringSlice(values)}) + ec2FilterList = append(ec2FilterList, types.Filter{Name: aws.String(name), Values: values}) } // TODO: implement pagination for cases that extend beyond limit (1000 instances) - output, err := client.DescribeInstances(&ec2.DescribeInstancesInput{Filters: ec2FilterList}) + output, err := client.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{Filters: ec2FilterList}) if err != nil { return nil, err } - instanceIDs := []string{} + var instanceIDs []string for _, reservation := range output.Reservations { for _, instance := range reservation.Instances { @@ -229,25 +231,25 @@ func GetTagsForEc2Instance(t testing.TestingT, region string, instanceID string) // GetTagsForEc2InstanceE returns all the tags for the given EC2 Instance. func GetTagsForEc2InstanceE(t testing.TestingT, region string, instanceID string) (map[string]string, error) { - client, err := NewEc2ClientE(t, region) + client, err := NewEc2ClientV2E(t, region) if err != nil { return nil, err } input := ec2.DescribeTagsInput{ - Filters: []*ec2.Filter{ + Filters: []types.Filter{ { Name: aws.String("resource-type"), - Values: aws.StringSlice([]string{"instance"}), + Values: []string{"instance"}, }, { Name: aws.String("resource-id"), - Values: aws.StringSlice([]string{instanceID}), + Values: []string{instanceID}, }, }, } - out, err := client.DescribeTags(&input) + out, err := client.DescribeTags(context.Background(), &input) if err != nil { return nil, err } @@ -255,7 +257,7 @@ func GetTagsForEc2InstanceE(t testing.TestingT, region string, instanceID string tags := map[string]string{} for _, tag := range out.Tags { - tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + tags[aws.ToString(tag.Key)] = aws.ToString(tag.Value) } return tags, nil @@ -270,12 +272,12 @@ func DeleteAmi(t testing.TestingT, region string, imageID string) { func DeleteAmiE(t testing.TestingT, region string, imageID string) error { logger.Logf(t, "Deregistering AMI %s", imageID) - client, err := NewEc2ClientE(t, region) + client, err := NewEc2ClientV2E(t, region) if err != nil { return err } - _, err = client.DeregisterImage(&ec2.DeregisterImageInput{ImageId: aws.String(imageID)}) + _, err = client.DeregisterImage(context.Background(), &ec2.DeregisterImageInput{ImageId: aws.String(imageID)}) return err } @@ -286,21 +288,21 @@ func AddTagsToResource(t testing.TestingT, region string, resource string, tags // AddTagsToResourceE adds the tags to the given taggable AWS resource such as EC2, AMI or VPC. func AddTagsToResourceE(t testing.TestingT, region string, resource string, tags map[string]string) error { - client, err := NewEc2ClientE(t, region) + client, err := NewEc2ClientV2E(t, region) if err != nil { return err } - var awsTags []*ec2.Tag + var awsTags []types.Tag for key, value := range tags { - awsTags = append(awsTags, &ec2.Tag{ + awsTags = append(awsTags, types.Tag{ Key: aws.String(key), Value: aws.String(value), }) } - _, err = client.CreateTags(&ec2.CreateTagsInput{ - Resources: []*string{aws.String(resource)}, + _, err = client.CreateTags(context.Background(), &ec2.CreateTagsInput{ + Resources: []string{resource}, Tags: awsTags, }) @@ -316,14 +318,14 @@ func TerminateInstance(t testing.TestingT, region string, instanceID string) { func TerminateInstanceE(t testing.TestingT, region string, instanceID string) error { logger.Logf(t, "Terminating Instance %s", instanceID) - client, err := NewEc2ClientE(t, region) + client, err := NewEc2ClientV2E(t, region) if err != nil { return err } - _, err = client.TerminateInstances(&ec2.TerminateInstancesInput{ - InstanceIds: []*string{ - aws.String(instanceID), + _, err = client.TerminateInstances(context.Background(), &ec2.TerminateInstancesInput{ + InstanceIds: []string{ + instanceID, }, }) @@ -344,7 +346,7 @@ func GetAmiPubliclyAccessibleE(t testing.TestingT, awsRegion string, amiID strin return false, err } for _, launchPermission := range launchPermissions { - if aws.StringValue(launchPermission.Group) == "all" { + if string(launchPermission.Group) == "all" { return true, nil } } @@ -360,30 +362,30 @@ func GetAccountsWithLaunchPermissionsForAmi(t testing.TestingT, awsRegion string // GetAccountsWithLaunchPermissionsForAmiE returns list of accounts that the AMI is shared with func GetAccountsWithLaunchPermissionsForAmiE(t testing.TestingT, awsRegion string, amiID string) ([]string, error) { - accountIDs := []string{} + var accountIDs []string launchPermissions, err := GetLaunchPermissionsForAmiE(t, awsRegion, amiID) if err != nil { return accountIDs, err } for _, launchPermission := range launchPermissions { - if aws.StringValue(launchPermission.UserId) != "" { - accountIDs = append(accountIDs, aws.StringValue(launchPermission.UserId)) + if aws.ToString(launchPermission.UserId) != "" { + accountIDs = append(accountIDs, aws.ToString(launchPermission.UserId)) } } return accountIDs, nil } // GetLaunchPermissionsForAmiE returns launchPermissions as configured in AWS -func GetLaunchPermissionsForAmiE(t testing.TestingT, awsRegion string, amiID string) ([]*ec2.LaunchPermission, error) { - client := NewEc2Client(t, awsRegion) +func GetLaunchPermissionsForAmiE(t testing.TestingT, awsRegion string, amiID string) ([]types.LaunchPermission, error) { + client := NewEc2ClientV2(t, awsRegion) input := &ec2.DescribeImageAttributeInput{ - Attribute: aws.String("launchPermission"), + Attribute: types.ImageAttributeNameLaunchPermission, ImageId: aws.String(amiID), } - output, err := client.DescribeImageAttribute(input) + output, err := client.DescribeImageAttribute(context.Background(), input) if err != nil { - return []*ec2.LaunchPermission{}, err + return []types.LaunchPermission{}, err } return output.LaunchPermissions, nil } @@ -408,7 +410,7 @@ func GetRecommendedInstanceType(t testing.TestingT, region string, instanceTypeO // AZs. If you have code that needs to run on a "small" instance across all AZs in many different regions, you can // use this function to automatically figure out which instance type you should use. func GetRecommendedInstanceTypeE(t testing.TestingT, region string, instanceTypeOptions []string) (string, error) { - client, err := NewEc2ClientE(t, region) + client, err := NewEc2ClientV2E(t, region) if err != nil { return "", err } @@ -422,7 +424,7 @@ func GetRecommendedInstanceTypeE(t testing.TestingT, region string, instanceType // AZs. If you have code that needs to run on a "small" instance across all AZs in many different regions, you can // use this function to automatically figure out which instance type you should use. // This function expects an authenticated EC2 client from the AWS SDK Go library. -func GetRecommendedInstanceTypeWithClientE(t testing.TestingT, ec2Client *ec2.EC2, instanceTypeOptions []string) (string, error) { +func GetRecommendedInstanceTypeWithClientE(t testing.TestingT, ec2Client *ec2.Client, instanceTypeOptions []string) (string, error) { availabilityZones, err := getAllAvailabilityZonesE(ec2Client) if err != nil { return "", err @@ -439,7 +441,7 @@ func GetRecommendedInstanceTypeWithClientE(t testing.TestingT, ec2Client *ec2.EC // pickRecommendedInstanceTypeE returns the first instance type from instanceTypeOptions that is available in all the // AZs in availabilityZones based on the availability data in instanceTypeOfferings. If none of the instance types are // available in all AZs, this function returns an error. -func pickRecommendedInstanceTypeE(availabilityZones []string, instanceTypeOfferings []*ec2.InstanceTypeOffering, instanceTypeOptions []string) (string, error) { +func pickRecommendedInstanceTypeE(availabilityZones []string, instanceTypeOfferings []types.InstanceTypeOffering, instanceTypeOptions []string) (string, error) { // O(n^3) for the win! for _, instanceType := range instanceTypeOptions { if instanceTypeExistsInAllAzs(instanceType, availabilityZones, instanceTypeOfferings) { @@ -450,9 +452,9 @@ func pickRecommendedInstanceTypeE(availabilityZones []string, instanceTypeOfferi return "", NoInstanceTypeError{InstanceTypeOptions: instanceTypeOptions, Azs: availabilityZones} } -// instanceTypeExistsInAllAzs returns true if the given inistance type exists in all the given availabilityZones based +// instanceTypeExistsInAllAzs returns true if the given instance type exists in all the given availabilityZones based // on the availability data in instanceTypeOfferings -func instanceTypeExistsInAllAzs(instanceType string, availabilityZones []string, instanceTypeOfferings []*ec2.InstanceTypeOffering) bool { +func instanceTypeExistsInAllAzs(instanceType string, availabilityZones []string, instanceTypeOfferings []types.InstanceTypeOffering) bool { if len(availabilityZones) == 0 || len(instanceTypeOfferings) == 0 { return false } @@ -468,9 +470,9 @@ func instanceTypeExistsInAllAzs(instanceType string, availabilityZones []string, // hasOffering returns true if the given availability zone and instance type are one of the offerings in // instanceTypeOfferings -func hasOffering(instanceTypeOfferings []*ec2.InstanceTypeOffering, availabilityZone string, instanceType string) bool { +func hasOffering(instanceTypeOfferings []types.InstanceTypeOffering, availabilityZone string, instanceType string) bool { for _, offering := range instanceTypeOfferings { - if aws.StringValue(offering.InstanceType) == instanceType && aws.StringValue(offering.Location) == availabilityZone { + if string(offering.InstanceType) == instanceType && aws.ToString(offering.Location) == availabilityZone { return true } } @@ -480,18 +482,18 @@ func hasOffering(instanceTypeOfferings []*ec2.InstanceTypeOffering, availability // getInstanceTypeOfferingsE returns the instance types from the given list that are available in the region configured // in the given EC2 client -func getInstanceTypeOfferingsE(client *ec2.EC2, instanceTypeOptions []string) ([]*ec2.InstanceTypeOffering, error) { +func getInstanceTypeOfferingsE(client *ec2.Client, instanceTypeOptions []string) ([]types.InstanceTypeOffering, error) { input := ec2.DescribeInstanceTypeOfferingsInput{ - LocationType: aws.String(ec2.LocationTypeAvailabilityZone), - Filters: []*ec2.Filter{ + LocationType: types.LocationTypeAvailabilityZone, + Filters: []types.Filter{ { Name: aws.String("instance-type"), - Values: aws.StringSlice(instanceTypeOptions), + Values: instanceTypeOptions, }, }, } - out, err := client.DescribeInstanceTypeOfferings(&input) + out, err := client.DescribeInstanceTypeOfferings(context.Background(), &input) if err != nil { return nil, err } @@ -500,17 +502,17 @@ func getInstanceTypeOfferingsE(client *ec2.EC2, instanceTypeOptions []string) ([ } // getAllAvailabilityZonesE returns all the available AZs in the region configured in the given EC2 client -func getAllAvailabilityZonesE(client *ec2.EC2) ([]string, error) { +func getAllAvailabilityZonesE(client *ec2.Client) ([]string, error) { input := ec2.DescribeAvailabilityZonesInput{ - Filters: []*ec2.Filter{ + Filters: []types.Filter{ { Name: aws.String("state"), - Values: aws.StringSlice([]string{"available"}), + Values: []string{"available"}, }, }, } - out, err := client.DescribeAvailabilityZones(&input) + out, err := client.DescribeAvailabilityZones(context.Background(), &input) if err != nil { return nil, err } @@ -518,25 +520,35 @@ func getAllAvailabilityZonesE(client *ec2.EC2) ([]string, error) { var azs []string for _, az := range out.AvailabilityZones { - azs = append(azs, aws.StringValue(az.ZoneName)) + azs = append(azs, aws.ToString(az.ZoneName)) } return azs, nil } // NewEc2Client creates an EC2 client. -func NewEc2Client(t testing.TestingT, region string) *ec2.EC2 { - client, err := NewEc2ClientE(t, region) +func NewEc2Client(t testing.TestingT, region string) *ec2.Client { + client, err := NewEc2ClientV2E(t, region) + require.NoError(t, err) + return client +} + +func NewEc2ClientV2(t testing.TestingT, region string) *ec2.Client { + client, err := NewEc2ClientV2E(t, region) require.NoError(t, err) return client } // NewEc2ClientE creates an EC2 client. -func NewEc2ClientE(t testing.TestingT, region string) (*ec2.EC2, error) { +func NewEc2ClientE(t testing.TestingT, region string) (*ec2.Client, error) { + return NewEc2ClientV2E(t, region) +} + +func NewEc2ClientV2E(t testing.TestingT, region string) (*ec2.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return ec2.New(sess), nil + return ec2.NewFromConfig(*sess), nil } diff --git a/modules/aws/ec2_test.go b/modules/aws/ec2_test.go index 236e55bbb..2d5c9e729 100644 --- a/modules/aws/ec2_test.go +++ b/modules/aws/ec2_test.go @@ -5,9 +5,8 @@ import ( "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -56,7 +55,7 @@ func TestGetRecommendedInstanceType(t *testing.T) { t.Run(fmt.Sprintf("%s-%s", testCase.region, strings.Join(testCase.instanceTypeOptions, "-")), func(t *testing.T) { t.Parallel() instanceType := GetRecommendedInstanceType(t, testCase.region, testCase.instanceTypeOptions) - // We could hard-code the expected result (e.g., as of July, 2020, we expect eu-west-1 to return t2.micro + // We could hard-code the expected result (e.g., as of July 2020, we expect eu-west-1 to return t2.micro // and ap-northeast-2 to return t3.micro), but the result will likely change over time, so to avoid a // brittle test, we simply check that we get _one_ result. Combined with the unit test below, this hopefully // is enough to be confident this function works correctly. @@ -69,7 +68,7 @@ func TestPickRecommendedInstanceTypeHappyPath(t *testing.T) { testCases := []struct { name string availabilityZones []string - instanceTypeOfferings []*ec2.InstanceTypeOffering + instanceTypeOfferings []types.InstanceTypeOffering instanceTypeOptions []string expected string }{ @@ -136,7 +135,7 @@ func TestPickRecommendedInstanceTypeErrors(t *testing.T) { testCases := []struct { name string availabilityZones []string - instanceTypeOfferings []*ec2.InstanceTypeOffering + instanceTypeOfferings []types.InstanceTypeOffering instanceTypeOptions []string }{ { @@ -184,15 +183,15 @@ func TestPickRecommendedInstanceTypeErrors(t *testing.T) { } } -func offerings(offerings map[string][]string) []*ec2.InstanceTypeOffering { - var out []*ec2.InstanceTypeOffering +func offerings(offerings map[string][]string) []types.InstanceTypeOffering { + var out []types.InstanceTypeOffering for az, instanceTypes := range offerings { for _, instanceType := range instanceTypes { - offering := &ec2.InstanceTypeOffering{ - InstanceType: aws.String(instanceType), + offering := types.InstanceTypeOffering{ + InstanceType: types.InstanceType(instanceType), Location: aws.String(az), - LocationType: aws.String(ec2.LocationTypeAvailabilityZone), + LocationType: types.LocationTypeAvailabilityZone, } out = append(out, offering) } diff --git a/modules/aws/ecr.go b/modules/aws/ecr.go index cb2f9f8cc..f899dba27 100644 --- a/modules/aws/ecr.go +++ b/modules/aws/ecr.go @@ -1,10 +1,12 @@ package aws import ( + "context" goerrors "errors" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ecr" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecr" + "github.com/aws/aws-sdk-go-v2/service/ecr/types" "github.com/gruntwork-io/go-commons/errors" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" @@ -12,16 +14,16 @@ import ( ) // CreateECRRepo creates a new ECR Repository. This will fail the test and stop execution if there is an error. -func CreateECRRepo(t testing.TestingT, region string, name string) *ecr.Repository { +func CreateECRRepo(t testing.TestingT, region string, name string) *types.Repository { repo, err := CreateECRRepoE(t, region, name) require.NoError(t, err) return repo } // CreateECRRepoE creates a new ECR Repository. -func CreateECRRepoE(t testing.TestingT, region string, name string) (*ecr.Repository, error) { +func CreateECRRepoE(t testing.TestingT, region string, name string) (*types.Repository, error) { client := NewECRClient(t, region) - resp, err := client.CreateRepository(&ecr.CreateRepositoryInput{RepositoryName: aws.String(name)}) + resp, err := client.CreateRepository(context.Background(), &ecr.CreateRepositoryInput{RepositoryName: aws.String(name)}) if err != nil { return nil, err } @@ -30,7 +32,7 @@ func CreateECRRepoE(t testing.TestingT, region string, name string) (*ecr.Reposi // GetECRRepo gets an ECR repository by name. This will fail the test and stop execution if there is an error. // An error occurs if a repository with the given name does not exist in the given region. -func GetECRRepo(t testing.TestingT, region string, name string) *ecr.Repository { +func GetECRRepo(t testing.TestingT, region string, name string) *types.Repository { repo, err := GetECRRepoE(t, region, name) require.NoError(t, err) return repo @@ -38,35 +40,35 @@ func GetECRRepo(t testing.TestingT, region string, name string) *ecr.Repository // GetECRRepoE gets an ECR Repository by name. // An error occurs if a repository with the given name does not exist in the given region. -func GetECRRepoE(t testing.TestingT, region string, name string) (*ecr.Repository, error) { +func GetECRRepoE(t testing.TestingT, region string, name string) (*types.Repository, error) { client := NewECRClient(t, region) - repositoryNames := []*string{aws.String(name)} - resp, err := client.DescribeRepositories(&ecr.DescribeRepositoriesInput{RepositoryNames: repositoryNames}) + repositoryNames := []string{name} + resp, err := client.DescribeRepositories(context.Background(), &ecr.DescribeRepositoriesInput{RepositoryNames: repositoryNames}) if err != nil { return nil, err } if len(resp.Repositories) != 1 { - return nil, errors.WithStackTrace(goerrors.New(("An unexpected condition occurred. Please file an issue at github.com/gruntwork-io/terratest"))) + return nil, errors.WithStackTrace(goerrors.New("an unexpected condition occurred. Please file an issue at github.com/gruntwork-io/terratest")) } - return resp.Repositories[0], nil + return &resp.Repositories[0], nil } // DeleteECRRepo will force delete the ECR repo by deleting all images prior to deleting the ECR repository. // This will fail the test and stop execution if there is an error. -func DeleteECRRepo(t testing.TestingT, region string, repo *ecr.Repository) { +func DeleteECRRepo(t testing.TestingT, region string, repo *types.Repository) { err := DeleteECRRepoE(t, region, repo) require.NoError(t, err) } // DeleteECRRepoE will force delete the ECR repo by deleting all images prior to deleting the ECR repository. -func DeleteECRRepoE(t testing.TestingT, region string, repo *ecr.Repository) error { +func DeleteECRRepoE(t testing.TestingT, region string, repo *types.Repository) error { client := NewECRClient(t, region) - resp, err := client.ListImages(&ecr.ListImagesInput{RepositoryName: repo.RepositoryName}) + resp, err := client.ListImages(context.Background(), &ecr.ListImagesInput{RepositoryName: repo.RepositoryName}) if err != nil { return err } if len(resp.ImageIds) > 0 { - _, err = client.BatchDeleteImage(&ecr.BatchDeleteImageInput{ + _, err = client.BatchDeleteImage(context.Background(), &ecr.BatchDeleteImageInput{ RepositoryName: repo.RepositoryName, ImageIds: resp.ImageIds, }) @@ -75,7 +77,7 @@ func DeleteECRRepoE(t testing.TestingT, region string, repo *ecr.Repository) err } } - _, err = client.DeleteRepository(&ecr.DeleteRepositoryInput{RepositoryName: repo.RepositoryName}) + _, err = client.DeleteRepository(context.Background(), &ecr.DeleteRepositoryInput{RepositoryName: repo.RepositoryName}) if err != nil { return err } @@ -84,33 +86,33 @@ func DeleteECRRepoE(t testing.TestingT, region string, repo *ecr.Repository) err // NewECRClient returns a client for the Elastic Container Registry. This will fail the test and // stop execution if there is an error. -func NewECRClient(t testing.TestingT, region string) *ecr.ECR { +func NewECRClient(t testing.TestingT, region string) *ecr.Client { sess, err := NewECRClientE(t, region) require.NoError(t, err) return sess } -// NewECRClient returns a client for the Elastic Container Registry. -func NewECRClientE(t testing.TestingT, region string) (*ecr.ECR, error) { +// NewECRClientE returns a client for the Elastic Container Registry. +func NewECRClientE(t testing.TestingT, region string) (*ecr.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return ecr.New(sess), nil + return ecr.NewFromConfig(*sess), nil } // GetECRRepoLifecyclePolicy gets the policies for the given ECR repository. // This will fail the test and stop execution if there is an error. -func GetECRRepoLifecyclePolicy(t testing.TestingT, region string, repo *ecr.Repository) string { +func GetECRRepoLifecyclePolicy(t testing.TestingT, region string, repo *types.Repository) string { policy, err := GetECRRepoLifecyclePolicyE(t, region, repo) require.NoError(t, err) return policy } // GetECRRepoLifecyclePolicyE gets the policies for the given ECR repository. -func GetECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *ecr.Repository) (string, error) { +func GetECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *types.Repository) (string, error) { client := NewECRClient(t, region) - resp, err := client.GetLifecyclePolicy(&ecr.GetLifecyclePolicyInput{RepositoryName: repo.RepositoryName}) + resp, err := client.GetLifecyclePolicy(context.Background(), &ecr.GetLifecyclePolicyInput{RepositoryName: repo.RepositoryName}) if err != nil { return "", err } @@ -119,13 +121,13 @@ func GetECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *ecr.Rep // PutECRRepoLifecyclePolicy puts the given policy for the given ECR repository. // This will fail the test and stop execution if there is an error. -func PutECRRepoLifecyclePolicy(t testing.TestingT, region string, repo *ecr.Repository, policy string) { +func PutECRRepoLifecyclePolicy(t testing.TestingT, region string, repo *types.Repository, policy string) { err := PutECRRepoLifecyclePolicyE(t, region, repo, policy) require.NoError(t, err) } -// PutEcrRepoLifecyclePolicy puts the given policy for the given ECR repository. -func PutECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *ecr.Repository, policy string) error { +// PutECRRepoLifecyclePolicyE puts the given policy for the given ECR repository. +func PutECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *types.Repository, policy string) error { logger.Logf(t, "Applying policy for repository %s in %s", *repo.RepositoryName, region) client, err := NewECRClientE(t, region) @@ -138,6 +140,6 @@ func PutECRRepoLifecyclePolicyE(t testing.TestingT, region string, repo *ecr.Rep LifecyclePolicyText: aws.String(policy), } - _, err = client.PutLifecyclePolicy(input) + _, err = client.PutLifecyclePolicy(context.Background(), input) return err } diff --git a/modules/aws/ecr_test.go b/modules/aws/ecr_test.go index 30a4c342a..bc0be1d7b 100644 --- a/modules/aws/ecr_test.go +++ b/modules/aws/ecr_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,11 +20,11 @@ func TestEcrRepo(t *testing.T) { defer DeleteECRRepo(t, region, repo1) require.NoError(t, err) - assert.Equal(t, ecrRepoName, aws.StringValue(repo1.RepositoryName)) + assert.Equal(t, ecrRepoName, aws.ToString(repo1.RepositoryName)) repo2, err := GetECRRepoE(t, region, ecrRepoName) require.NoError(t, err) - assert.Equal(t, ecrRepoName, aws.StringValue(repo2.RepositoryName)) + assert.Equal(t, ecrRepoName, aws.ToString(repo2.RepositoryName)) } func TestGetEcrRepoLifecyclePolicyError(t *testing.T) { @@ -36,7 +36,7 @@ func TestGetEcrRepoLifecyclePolicyError(t *testing.T) { defer DeleteECRRepo(t, region, repo1) require.NoError(t, err) - assert.Equal(t, ecrRepoName, aws.StringValue(repo1.RepositoryName)) + assert.Equal(t, ecrRepoName, aws.ToString(repo1.RepositoryName)) _, err = GetECRRepoLifecyclePolicyE(t, region, repo1) require.Error(t, err) diff --git a/modules/aws/ecs.go b/modules/aws/ecs.go index 81658738e..29b463c76 100644 --- a/modules/aws/ecs.go +++ b/modules/aws/ecs.go @@ -1,29 +1,31 @@ package aws import ( + "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" ) // GetEcsCluster fetches information about specified ECS cluster. -func GetEcsCluster(t testing.TestingT, region string, name string) *ecs.Cluster { +func GetEcsCluster(t testing.TestingT, region string, name string) *types.Cluster { cluster, err := GetEcsClusterE(t, region, name) require.NoError(t, err) return cluster } // GetEcsClusterE fetches information about specified ECS cluster. -func GetEcsClusterE(t testing.TestingT, region string, name string) (*ecs.Cluster, error) { - return GetEcsClusterWithIncludeE(t, region, name, []string{}) +func GetEcsClusterE(t testing.TestingT, region string, name string) (*types.Cluster, error) { + return GetEcsClusterWithIncludeE(t, region, name, []types.ClusterField{}) } // GetEcsClusterWithInclude fetches extended information about specified ECS cluster. // The `include` parameter specifies a list of `ecs.ClusterField*` constants, such as `ecs.ClusterFieldTags`. -func GetEcsClusterWithInclude(t testing.TestingT, region string, name string, include []string) *ecs.Cluster { +func GetEcsClusterWithInclude(t testing.TestingT, region string, name string, include []types.ClusterField) *types.Cluster { clusterInfo, err := GetEcsClusterWithIncludeE(t, region, name, include) require.NoError(t, err) return clusterInfo @@ -31,52 +33,53 @@ func GetEcsClusterWithInclude(t testing.TestingT, region string, name string, in // GetEcsClusterWithIncludeE fetches extended information about specified ECS cluster. // The `include` parameter specifies a list of `ecs.ClusterField*` constants, such as `ecs.ClusterFieldTags`. -func GetEcsClusterWithIncludeE(t testing.TestingT, region string, name string, include []string) (*ecs.Cluster, error) { +func GetEcsClusterWithIncludeE(t testing.TestingT, region string, name string, include []types.ClusterField) (*types.Cluster, error) { client, err := NewEcsClientE(t, region) if err != nil { return nil, err } + input := &ecs.DescribeClustersInput{ - Clusters: []*string{ - aws.String(name), + Clusters: []string{ + name, }, - Include: aws.StringSlice(include), + Include: include, } - output, err := client.DescribeClusters(input) + output, err := client.DescribeClusters(context.Background(), input) if err != nil { return nil, err } numClusters := len(output.Clusters) if numClusters != 1 { - return nil, fmt.Errorf("Expected to find 1 ECS cluster named '%s' in region '%v', but found '%d'", + return nil, fmt.Errorf("expected to find 1 ECS cluster named '%s' in region '%v', but found '%d'", name, region, numClusters) } - return output.Clusters[0], nil + return &output.Clusters[0], nil } // GetDefaultEcsClusterE fetches information about default ECS cluster. -func GetDefaultEcsClusterE(t testing.TestingT, region string) (*ecs.Cluster, error) { +func GetDefaultEcsClusterE(t testing.TestingT, region string) (*types.Cluster, error) { return GetEcsClusterE(t, region, "default") } // GetDefaultEcsCluster fetches information about default ECS cluster. -func GetDefaultEcsCluster(t testing.TestingT, region string) *ecs.Cluster { +func GetDefaultEcsCluster(t testing.TestingT, region string) *types.Cluster { return GetEcsCluster(t, region, "default") } // CreateEcsCluster creates ECS cluster in the given region under the given name. -func CreateEcsCluster(t testing.TestingT, region string, name string) *ecs.Cluster { +func CreateEcsCluster(t testing.TestingT, region string, name string) *types.Cluster { cluster, err := CreateEcsClusterE(t, region, name) require.NoError(t, err) return cluster } // CreateEcsClusterE creates ECS cluster in the given region under the given name. -func CreateEcsClusterE(t testing.TestingT, region string, name string) (*ecs.Cluster, error) { +func CreateEcsClusterE(t testing.TestingT, region string, name string) (*types.Cluster, error) { client := NewEcsClient(t, region) - cluster, err := client.CreateCluster(&ecs.CreateClusterInput{ + cluster, err := client.CreateCluster(context.Background(), &ecs.CreateClusterInput{ ClusterName: aws.String(name), }) if err != nil { @@ -85,33 +88,33 @@ func CreateEcsClusterE(t testing.TestingT, region string, name string) (*ecs.Clu return cluster.Cluster, nil } -func DeleteEcsCluster(t testing.TestingT, region string, cluster *ecs.Cluster) { +func DeleteEcsCluster(t testing.TestingT, region string, cluster *types.Cluster) { err := DeleteEcsClusterE(t, region, cluster) require.NoError(t, err) } // DeleteEcsClusterE deletes existing ECS cluster in the given region. -func DeleteEcsClusterE(t testing.TestingT, region string, cluster *ecs.Cluster) error { +func DeleteEcsClusterE(t testing.TestingT, region string, cluster *types.Cluster) error { client := NewEcsClient(t, region) - _, err := client.DeleteCluster(&ecs.DeleteClusterInput{ + _, err := client.DeleteCluster(context.Background(), &ecs.DeleteClusterInput{ Cluster: aws.String(*cluster.ClusterName), }) return err } // GetEcsService fetches information about specified ECS service. -func GetEcsService(t testing.TestingT, region string, clusterName string, serviceName string) *ecs.Service { +func GetEcsService(t testing.TestingT, region string, clusterName string, serviceName string) *types.Service { service, err := GetEcsServiceE(t, region, clusterName, serviceName) require.NoError(t, err) return service } // GetEcsServiceE fetches information about specified ECS service. -func GetEcsServiceE(t testing.TestingT, region string, clusterName string, serviceName string) (*ecs.Service, error) { - output, err := NewEcsClient(t, region).DescribeServices(&ecs.DescribeServicesInput{ +func GetEcsServiceE(t testing.TestingT, region string, clusterName string, serviceName string) (*types.Service, error) { + output, err := NewEcsClient(t, region).DescribeServices(context.Background(), &ecs.DescribeServicesInput{ Cluster: aws.String(clusterName), - Services: []*string{ - aws.String(serviceName), + Services: []string{ + serviceName, }, }) if err != nil { @@ -121,22 +124,22 @@ func GetEcsServiceE(t testing.TestingT, region string, clusterName string, servi numServices := len(output.Services) if numServices != 1 { return nil, fmt.Errorf( - "Expected to find 1 ECS service named '%s' in cluster '%s' in region '%v', but found '%d'", + "expected to find 1 ECS service named '%s' in cluster '%s' in region '%v', but found '%d'", serviceName, clusterName, region, numServices) } - return output.Services[0], nil + return &output.Services[0], nil } // GetEcsTaskDefinition fetches information about specified ECS task definition. -func GetEcsTaskDefinition(t testing.TestingT, region string, taskDefinition string) *ecs.TaskDefinition { +func GetEcsTaskDefinition(t testing.TestingT, region string, taskDefinition string) *types.TaskDefinition { task, err := GetEcsTaskDefinitionE(t, region, taskDefinition) require.NoError(t, err) return task } // GetEcsTaskDefinitionE fetches information about specified ECS task definition. -func GetEcsTaskDefinitionE(t testing.TestingT, region string, taskDefinition string) (*ecs.TaskDefinition, error) { - output, err := NewEcsClient(t, region).DescribeTaskDefinition(&ecs.DescribeTaskDefinitionInput{ +func GetEcsTaskDefinitionE(t testing.TestingT, region string, taskDefinition string) (*types.TaskDefinition, error) { + output, err := NewEcsClient(t, region).DescribeTaskDefinition(context.Background(), &ecs.DescribeTaskDefinitionInput{ TaskDefinition: aws.String(taskDefinition), }) if err != nil { @@ -146,17 +149,17 @@ func GetEcsTaskDefinitionE(t testing.TestingT, region string, taskDefinition str } // NewEcsClient creates en ECS client. -func NewEcsClient(t testing.TestingT, region string) *ecs.ECS { +func NewEcsClient(t testing.TestingT, region string) *ecs.Client { client, err := NewEcsClientE(t, region) require.NoError(t, err) return client } // NewEcsClientE creates an ECS client. -func NewEcsClientE(t testing.TestingT, region string) (*ecs.ECS, error) { +func NewEcsClientE(t testing.TestingT, region string) (*ecs.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return ecs.New(sess), nil + return ecs.NewFromConfig(*sess), nil } diff --git a/modules/aws/ecs_test.go b/modules/aws/ecs_test.go index 3f7ffecb8..d9537888a 100644 --- a/modules/aws/ecs_test.go +++ b/modules/aws/ecs_test.go @@ -1,10 +1,12 @@ package aws import ( + "context" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" ) @@ -30,13 +32,13 @@ func TestEcsClusterWithInclude(t *testing.T) { region := GetRandomStableRegion(t, nil, nil) clusterName := "terratest-" + random.UniqueId() - tags := []*ecs.Tag{&ecs.Tag{ + tags := []types.Tag{{ Key: aws.String("test-tag"), Value: aws.String("hello-world"), }} client := NewEcsClient(t, region) - c1, err := client.CreateCluster(&ecs.CreateClusterInput{ + c1, err := client.CreateCluster(context.Background(), &ecs.CreateClusterInput{ ClusterName: aws.String(clusterName), Tags: tags, }) @@ -44,19 +46,19 @@ func TestEcsClusterWithInclude(t *testing.T) { defer DeleteEcsCluster(t, region, c1.Cluster) - assert.Equal(t, clusterName, aws.StringValue(c1.Cluster.ClusterName)) + assert.Equal(t, clusterName, aws.ToString(c1.Cluster.ClusterName)) - c2, err := GetEcsClusterWithIncludeE(t, region, clusterName, []string{ecs.ClusterFieldTags}) + c2, err := GetEcsClusterWithIncludeE(t, region, clusterName, []types.ClusterField{types.ClusterFieldTags}) assert.NoError(t, err) - assert.Equal(t, clusterName, aws.StringValue(c2.ClusterName)) + assert.Equal(t, clusterName, aws.ToString(c2.ClusterName)) assert.Equal(t, tags, c2.Tags) assert.Empty(t, c2.Statistics) - c3, err := GetEcsClusterWithIncludeE(t, region, clusterName, []string{ecs.ClusterFieldStatistics}) + c3, err := GetEcsClusterWithIncludeE(t, region, clusterName, []types.ClusterField{types.ClusterFieldStatistics}) assert.NoError(t, err) - assert.Equal(t, clusterName, aws.StringValue(c3.ClusterName)) + assert.Equal(t, clusterName, aws.ToString(c3.ClusterName)) assert.NotEmpty(t, c3.Statistics) assert.Empty(t, c3.Tags) } diff --git a/modules/aws/iam.go b/modules/aws/iam.go index 1ab2a68f1..030878154 100644 --- a/modules/aws/iam.go +++ b/modules/aws/iam.go @@ -1,10 +1,12 @@ package aws import ( + "context" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -25,7 +27,7 @@ func GetIamCurrentUserNameE(t testing.TestingT) (string, error) { return "", err } - resp, err := iamClient.GetUser(&iam.GetUserInput{}) + resp, err := iamClient.GetUser(context.Background(), &iam.GetUserInput{}) if err != nil { return "", err } @@ -49,7 +51,7 @@ func GetIamCurrentUserArnE(t testing.TestingT) (string, error) { return "", err } - resp, err := iamClient.GetUser(&iam.GetUserInput{}) + resp, err := iamClient.GetUser(context.Background(), &iam.GetUserInput{}) if err != nil { return "", err } @@ -58,7 +60,7 @@ func GetIamCurrentUserArnE(t testing.TestingT) (string, error) { } // CreateMfaDevice creates an MFA device using the given IAM client. -func CreateMfaDevice(t testing.TestingT, iamClient *iam.IAM, deviceName string) *iam.VirtualMFADevice { +func CreateMfaDevice(t testing.TestingT, iamClient *iam.Client, deviceName string) *types.VirtualMFADevice { mfaDevice, err := CreateMfaDeviceE(t, iamClient, deviceName) if err != nil { t.Fatal(err) @@ -67,10 +69,10 @@ func CreateMfaDevice(t testing.TestingT, iamClient *iam.IAM, deviceName string) } // CreateMfaDeviceE creates an MFA device using the given IAM client. -func CreateMfaDeviceE(t testing.TestingT, iamClient *iam.IAM, deviceName string) (*iam.VirtualMFADevice, error) { +func CreateMfaDeviceE(t testing.TestingT, iamClient *iam.Client, deviceName string) (*types.VirtualMFADevice, error) { logger.Logf(t, "Creating an MFA device called %s", deviceName) - output, err := iamClient.CreateVirtualMFADevice(&iam.CreateVirtualMFADeviceInput{ + output, err := iamClient.CreateVirtualMFADevice(context.Background(), &iam.CreateVirtualMFADeviceInput{ VirtualMFADeviceName: aws.String(deviceName), }) if err != nil { @@ -86,7 +88,7 @@ func CreateMfaDeviceE(t testing.TestingT, iamClient *iam.IAM, deviceName string) // EnableMfaDevice enables a newly created MFA Device by supplying the first two one-time passwords, so that it can be used for future // logins by the given IAM User. -func EnableMfaDevice(t testing.TestingT, iamClient *iam.IAM, mfaDevice *iam.VirtualMFADevice) { +func EnableMfaDevice(t testing.TestingT, iamClient *iam.Client, mfaDevice *types.VirtualMFADevice) { err := EnableMfaDeviceE(t, iamClient, mfaDevice) if err != nil { t.Fatal(err) @@ -95,8 +97,8 @@ func EnableMfaDevice(t testing.TestingT, iamClient *iam.IAM, mfaDevice *iam.Virt // EnableMfaDeviceE enables a newly created MFA Device by supplying the first two one-time passwords, so that it can be used for future // logins by the given IAM User. -func EnableMfaDeviceE(t testing.TestingT, iamClient *iam.IAM, mfaDevice *iam.VirtualMFADevice) error { - logger.Logf(t, "Enabling MFA device %s", aws.StringValue(mfaDevice.SerialNumber)) +func EnableMfaDeviceE(t testing.TestingT, iamClient *iam.Client, mfaDevice *types.VirtualMFADevice) error { + logger.Logf(t, "Enabling MFA device %s", aws.ToString(mfaDevice.SerialNumber)) iamUserName, err := GetIamCurrentUserArnE(t) if err != nil { @@ -116,7 +118,7 @@ func EnableMfaDeviceE(t testing.TestingT, iamClient *iam.IAM, mfaDevice *iam.Vir return err } - _, err = iamClient.EnableMFADevice(&iam.EnableMFADeviceInput{ + _, err = iamClient.EnableMFADevice(context.Background(), &iam.EnableMFADeviceInput{ AuthenticationCode1: aws.String(authCode1), AuthenticationCode2: aws.String(authCode2), SerialNumber: mfaDevice.SerialNumber, @@ -134,7 +136,7 @@ func EnableMfaDeviceE(t testing.TestingT, iamClient *iam.IAM, mfaDevice *iam.Vir } // NewIamClient creates a new IAM client. -func NewIamClient(t testing.TestingT, region string) *iam.IAM { +func NewIamClient(t testing.TestingT, region string) *iam.Client { client, err := NewIamClientE(t, region) if err != nil { t.Fatal(err) @@ -143,10 +145,10 @@ func NewIamClient(t testing.TestingT, region string) *iam.IAM { } // NewIamClientE creates a new IAM client. -func NewIamClientE(t testing.TestingT, region string) (*iam.IAM, error) { +func NewIamClientE(t testing.TestingT, region string) (*iam.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return iam.New(sess), nil + return iam.NewFromConfig(*sess), nil } diff --git a/modules/aws/keypair.go b/modules/aws/keypair.go index 5b180796a..00a1f31a3 100644 --- a/modules/aws/keypair.go +++ b/modules/aws/keypair.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/ssh" "github.com/gruntwork-io/terratest/modules/testing" @@ -47,7 +49,7 @@ func ImportEC2KeyPair(t testing.TestingT, region string, name string, keyPair *s func ImportEC2KeyPairE(t testing.TestingT, region string, name string, keyPair *ssh.KeyPair) (*Ec2Keypair, error) { logger.Logf(t, "Creating new Key Pair in EC2 region %s named %s", region, name) - client, err := NewEc2ClientE(t, region) + client, err := NewEc2ClientV2E(t, region) if err != nil { return nil, err } @@ -57,7 +59,7 @@ func ImportEC2KeyPairE(t testing.TestingT, region string, name string, keyPair * PublicKeyMaterial: []byte(keyPair.PublicKey), } - _, err = client.ImportKeyPair(params) + _, err = client.ImportKeyPair(context.Background(), params) if err != nil { return nil, err } @@ -77,7 +79,7 @@ func DeleteEC2KeyPair(t testing.TestingT, keyPair *Ec2Keypair) { func DeleteEC2KeyPairE(t testing.TestingT, keyPair *Ec2Keypair) error { logger.Logf(t, "Deleting Key Pair in EC2 region %s named %s", keyPair.Region, keyPair.Name) - client, err := NewEc2ClientE(t, keyPair.Region) + client, err := NewEc2ClientV2E(t, keyPair.Region) if err != nil { return err } @@ -86,6 +88,6 @@ func DeleteEC2KeyPairE(t testing.TestingT, keyPair *Ec2Keypair) error { KeyName: aws.String(keyPair.Name), } - _, err = client.DeleteKeyPair(params) + _, err = client.DeleteKeyPair(context.Background(), params) return err } diff --git a/modules/aws/keypair_test.go b/modules/aws/keypair_test.go index d3f04f2b5..8248e364f 100644 --- a/modules/aws/keypair_test.go +++ b/modules/aws/keypair_test.go @@ -1,12 +1,12 @@ package aws import ( + "context" "fmt" "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" ) @@ -29,13 +29,13 @@ func TestCreateImportAndDeleteEC2KeyPair(t *testing.T) { } func keyPairExists(t *testing.T, keyPair *Ec2Keypair) bool { - client := NewEc2Client(t, keyPair.Region) + client := NewEc2ClientV2(t, keyPair.Region) input := ec2.DescribeKeyPairsInput{ - KeyNames: aws.StringSlice([]string{keyPair.Name}), + KeyNames: []string{keyPair.Name}, } - out, err := client.DescribeKeyPairs(&input) + out, err := client.DescribeKeyPairs(context.Background(), &input) if err != nil { if strings.Contains(err.Error(), "InvalidKeyPair.NotFound") { return false diff --git a/modules/aws/kms.go b/modules/aws/kms.go index 07cfd1fa8..d10442971 100644 --- a/modules/aws/kms.go +++ b/modules/aws/kms.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/kms" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -24,7 +26,7 @@ func GetCmkArnE(t testing.TestingT, region string, cmkID string) (string, error) return "", err } - result, err := kmsClient.DescribeKey(&kms.DescribeKeyInput{ + result, err := kmsClient.DescribeKey(context.Background(), &kms.DescribeKeyInput{ KeyId: aws.String(cmkID), }) @@ -36,7 +38,7 @@ func GetCmkArnE(t testing.TestingT, region string, cmkID string) (string, error) } // NewKmsClient creates a KMS client. -func NewKmsClient(t testing.TestingT, region string) *kms.KMS { +func NewKmsClient(t testing.TestingT, region string) *kms.Client { client, err := NewKmsClientE(t, region) if err != nil { t.Fatal(err) @@ -45,11 +47,11 @@ func NewKmsClient(t testing.TestingT, region string) *kms.KMS { } // NewKmsClientE creates a KMS client. -func NewKmsClientE(t testing.TestingT, region string) (*kms.KMS, error) { +func NewKmsClientE(t testing.TestingT, region string) (*kms.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return kms.New(sess), nil + return kms.NewFromConfig(*sess), nil } diff --git a/modules/aws/lambda.go b/modules/aws/lambda.go index 5630613ca..3ffbec3d4 100644 --- a/modules/aws/lambda.go +++ b/modules/aws/lambda.go @@ -1,11 +1,13 @@ package aws import ( + "context" "encoding/json" "errors" "fmt" - "github.com/aws/aws-sdk-go/service/lambda" + "github.com/aws/aws-sdk-go-v2/service/lambda" + "github.com/aws/aws-sdk-go-v2/service/lambda/types" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" ) @@ -59,7 +61,7 @@ type LambdaOutput struct { // The HTTP status code for a successful request is in the 200 range. // For RequestResponse invocation type, the status code is 200. // For the DryRun invocation type, the status code is 204. - StatusCode *int64 + StatusCode int32 } // InvokeFunction invokes a lambda function. @@ -89,14 +91,14 @@ func InvokeFunctionE(t testing.TestingT, region, functionName string, payload in invokeInput.Payload = payloadJson } - out, err := lambdaClient.Invoke(invokeInput) + out, err := lambdaClient.Invoke(context.Background(), invokeInput) require.NoError(t, err) if err != nil { return nil, err } if out.FunctionError != nil { - return out.Payload, &FunctionError{Message: *out.FunctionError, StatusCode: *out.StatusCode, Payload: out.Payload} + return out.Payload, &FunctionError{Message: *out.FunctionError, StatusCode: out.StatusCode, Payload: out.Payload} } return out.Payload, nil @@ -123,7 +125,7 @@ func InvokeFunctionWithParamsE(t testing.TestingT, region, functionName string, } // Verify the InvocationType is one of the allowed values and report - // an error if it's not. By default the InvocationType will be + // an error if it's not. By default, the InvocationType will be // "RequestResponse". invocationType, err := input.InvocationType.Value() if err != nil { @@ -132,7 +134,7 @@ func InvokeFunctionWithParamsE(t testing.TestingT, region, functionName string, invokeInput := &lambda.InvokeInput{ FunctionName: &functionName, - InvocationType: &invocationType, + InvocationType: types.InvocationType(invocationType), } if input.Payload != nil { @@ -143,7 +145,7 @@ func InvokeFunctionWithParamsE(t testing.TestingT, region, functionName string, invokeInput.Payload = payloadJson } - out, err := lambdaClient.Invoke(invokeInput) + out, err := lambdaClient.Invoke(context.Background(), invokeInput) if err != nil { return nil, err } @@ -165,7 +167,7 @@ func InvokeFunctionWithParamsE(t testing.TestingT, region, functionName string, type FunctionError struct { Message string - StatusCode int64 + StatusCode int32 Payload []byte } @@ -174,18 +176,18 @@ func (err *FunctionError) Error() string { } // NewLambdaClient creates a new Lambda client. -func NewLambdaClient(t testing.TestingT, region string) *lambda.Lambda { +func NewLambdaClient(t testing.TestingT, region string) *lambda.Client { client, err := NewLambdaClientE(t, region) require.NoError(t, err) return client } // NewLambdaClientE creates a new Lambda client. -func NewLambdaClientE(t testing.TestingT, region string) (*lambda.Lambda, error) { +func NewLambdaClientE(t testing.TestingT, region string) (*lambda.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return lambda.New(sess), nil + return lambda.NewFromConfig(*sess), nil } diff --git a/modules/aws/rds.go b/modules/aws/rds.go index 5ec6bbf50..d9abf4592 100644 --- a/modules/aws/rds.go +++ b/modules/aws/rds.go @@ -1,11 +1,13 @@ package aws import ( + "context" "database/sql" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + "github.com/aws/aws-sdk-go-v2/service/rds/types" _ "github.com/go-sql-driver/mysql" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" @@ -27,11 +29,11 @@ func GetAddressOfRdsInstanceE(t testing.TestingT, dbInstanceID string, awsRegion return "", err } - return aws.StringValue(dbInstance.Endpoint.Address), nil + return aws.ToString(dbInstance.Endpoint.Address), nil } // GetPortOfRdsInstance gets the address of the given RDS Instance in the given region. -func GetPortOfRdsInstance(t testing.TestingT, dbInstanceID string, awsRegion string) int64 { +func GetPortOfRdsInstance(t testing.TestingT, dbInstanceID string, awsRegion string) int32 { port, err := GetPortOfRdsInstanceE(t, dbInstanceID, awsRegion) if err != nil { t.Fatal(err) @@ -40,7 +42,7 @@ func GetPortOfRdsInstance(t testing.TestingT, dbInstanceID string, awsRegion str } // GetPortOfRdsInstanceE gets the address of the given RDS Instance in the given region. -func GetPortOfRdsInstanceE(t testing.TestingT, dbInstanceID string, awsRegion string) (int64, error) { +func GetPortOfRdsInstanceE(t testing.TestingT, dbInstanceID string, awsRegion string) (int32, error) { dbInstance, err := GetRdsInstanceDetailsE(t, dbInstanceID, awsRegion) if err != nil { return -1, err @@ -91,8 +93,8 @@ func GetParameterValueForParameterOfRdsInstance(t testing.TestingT, parameterNam func GetParameterValueForParameterOfRdsInstanceE(t testing.TestingT, parameterName string, dbInstanceID string, awsRegion string) (string, error) { output := GetAllParametersOfRdsInstance(t, dbInstanceID, awsRegion) for _, parameter := range output { - if aws.StringValue(parameter.ParameterName) == parameterName { - return aws.StringValue(parameter.ParameterValue), nil + if aws.ToString(parameter.ParameterName) == parameterName { + return aws.ToString(parameter.ParameterValue), nil } } return "", ParameterForDbInstanceNotFound{ParameterName: parameterName, DbInstanceID: dbInstanceID, AwsRegion: awsRegion} @@ -112,10 +114,10 @@ func GetOptionSettingForOfRdsInstanceE(t testing.TestingT, optionName string, op optionGroupName := GetOptionGroupNameOfRdsInstance(t, dbInstanceID, awsRegion) options := GetOptionsOfOptionGroup(t, optionGroupName, awsRegion) for _, option := range options { - if aws.StringValue(option.OptionName) == optionName { + if aws.ToString(option.OptionName) == optionName { for _, optionSetting := range option.OptionSettings { - if aws.StringValue(optionSetting.Name) == optionSettingName { - return aws.StringValue(optionSetting.Value), nil + if aws.ToString(optionSetting.Name) == optionSettingName { + return aws.ToString(optionSetting.Value), nil } } } @@ -138,11 +140,11 @@ func GetOptionGroupNameOfRdsInstanceE(t testing.TestingT, dbInstanceID string, a if err != nil { return "", err } - return aws.StringValue(dbInstance.OptionGroupMemberships[0].OptionGroupName), nil + return aws.ToString(dbInstance.OptionGroupMemberships[0].OptionGroupName), nil } // GetOptionsOfOptionGroup gets the options of the option group specified -func GetOptionsOfOptionGroup(t testing.TestingT, optionGroupName string, awsRegion string) []*rds.Option { +func GetOptionsOfOptionGroup(t testing.TestingT, optionGroupName string, awsRegion string) []types.Option { output, err := GetOptionsOfOptionGroupE(t, optionGroupName, awsRegion) if err != nil { t.Fatal(err) @@ -151,18 +153,18 @@ func GetOptionsOfOptionGroup(t testing.TestingT, optionGroupName string, awsRegi } // GetOptionsOfOptionGroupE gets the options of the option group specified -func GetOptionsOfOptionGroupE(t testing.TestingT, optionGroupName string, awsRegion string) ([]*rds.Option, error) { +func GetOptionsOfOptionGroupE(t testing.TestingT, optionGroupName string, awsRegion string) ([]types.Option, error) { rdsClient := NewRdsClient(t, awsRegion) input := rds.DescribeOptionGroupsInput{OptionGroupName: aws.String(optionGroupName)} - output, err := rdsClient.DescribeOptionGroups(&input) + output, err := rdsClient.DescribeOptionGroups(context.Background(), &input) if err != nil { - return []*rds.Option{}, err + return []types.Option{}, err } return output.OptionGroupsList[0].Options, nil } // GetAllParametersOfRdsInstance gets all the parameters defined in the parameter group for the RDS instance in the given region. -func GetAllParametersOfRdsInstance(t testing.TestingT, dbInstanceID string, awsRegion string) []*rds.Parameter { +func GetAllParametersOfRdsInstance(t testing.TestingT, dbInstanceID string, awsRegion string) []types.Parameter { parameters, err := GetAllParametersOfRdsInstanceE(t, dbInstanceID, awsRegion) if err != nil { t.Fatal(err) @@ -171,36 +173,36 @@ func GetAllParametersOfRdsInstance(t testing.TestingT, dbInstanceID string, awsR } // GetAllParametersOfRdsInstanceE gets all the parameters defined in the parameter group for the RDS instance in the given region. -func GetAllParametersOfRdsInstanceE(t testing.TestingT, dbInstanceID string, awsRegion string) ([]*rds.Parameter, error) { +func GetAllParametersOfRdsInstanceE(t testing.TestingT, dbInstanceID string, awsRegion string) ([]types.Parameter, error) { dbInstance, dbInstanceErr := GetRdsInstanceDetailsE(t, dbInstanceID, awsRegion) if dbInstanceErr != nil { - return []*rds.Parameter{}, dbInstanceErr + return []types.Parameter{}, dbInstanceErr } - parameterGroupName := aws.StringValue(dbInstance.DBParameterGroups[0].DBParameterGroupName) + parameterGroupName := aws.ToString(dbInstance.DBParameterGroups[0].DBParameterGroupName) rdsClient := NewRdsClient(t, awsRegion) input := rds.DescribeDBParametersInput{DBParameterGroupName: aws.String(parameterGroupName)} - output, err := rdsClient.DescribeDBParameters(&input) + output, err := rdsClient.DescribeDBParameters(context.Background(), &input) if err != nil { - return []*rds.Parameter{}, err + return []types.Parameter{}, err } return output.Parameters, nil } // GetRdsInstanceDetailsE gets the details of a single DB instance whose identifier is passed. -func GetRdsInstanceDetailsE(t testing.TestingT, dbInstanceID string, awsRegion string) (*rds.DBInstance, error) { +func GetRdsInstanceDetailsE(t testing.TestingT, dbInstanceID string, awsRegion string) (*types.DBInstance, error) { rdsClient := NewRdsClient(t, awsRegion) input := rds.DescribeDBInstancesInput{DBInstanceIdentifier: aws.String(dbInstanceID)} - output, err := rdsClient.DescribeDBInstances(&input) + output, err := rdsClient.DescribeDBInstances(context.Background(), &input) if err != nil { return nil, err } - return output.DBInstances[0], nil + return &output.DBInstances[0], nil } // NewRdsClient creates an RDS client. -func NewRdsClient(t testing.TestingT, region string) *rds.RDS { +func NewRdsClient(t testing.TestingT, region string) *rds.Client { client, err := NewRdsClientE(t, region) if err != nil { t.Fatal(err) @@ -209,18 +211,18 @@ func NewRdsClient(t testing.TestingT, region string) *rds.RDS { } // NewRdsClientE creates an RDS client. -func NewRdsClientE(t testing.TestingT, region string) (*rds.RDS, error) { +func NewRdsClientE(t testing.TestingT, region string) (*rds.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return rds.New(sess), nil + return rds.NewFromConfig(*sess), nil } // GetRecommendedRdsInstanceType takes in a list of RDS instance types (e.g., "db.t2.micro", "db.t3.micro") and returns the // first instance type in the list that is available in the given region and for the given database engine type. -// If none of the instances provided are avaiable for your combination of region and database engine, this function will exit with an error. +// If none of the instances provided are available for your combination of region and database engine, this function will exit with an error. func GetRecommendedRdsInstanceType(t testing.TestingT, region string, engine string, engineVersion string, instanceTypeOptions []string) string { out, err := GetRecommendedRdsInstanceTypeE(t, region, engine, engineVersion, instanceTypeOptions) require.NoError(t, err) @@ -229,7 +231,7 @@ func GetRecommendedRdsInstanceType(t testing.TestingT, region string, engine str // GetRecommendedRdsInstanceTypeE takes in a list of RDS instance types (e.g., "db.t2.micro", "db.t3.micro") and returns the // first instance type in the list that is available in the given region and for the given database engine type. -// If none of the instances provided are avaiable for your combination of region and database engine, this function will return an error. +// If none of the instances provided are available for your combination of region and database engine, this function will return an error. func GetRecommendedRdsInstanceTypeE(t testing.TestingT, region string, engine string, engineVersion string, instanceTypeOptions []string) (string, error) { client, err := NewRdsClientE(t, region) if err != nil { @@ -240,9 +242,9 @@ func GetRecommendedRdsInstanceTypeE(t testing.TestingT, region string, engine st // GetRecommendedRdsInstanceTypeWithClientE takes in a list of RDS instance types (e.g., "db.t2.micro", "db.t3.micro") and returns the // first instance type in the list that is available in the given region and for the given database engine type. -// If none of the instances provided are avaiable for your combination of region and database engine, this function will return an error. +// If none of the instances provided are available for your combination of region and database engine, this function will return an error. // This function expects an authenticated RDS client from the AWS SDK Go library. -func GetRecommendedRdsInstanceTypeWithClientE(t testing.TestingT, rdsClient *rds.RDS, engine string, engineVersion string, instanceTypeOptions []string) (string, error) { +func GetRecommendedRdsInstanceTypeWithClientE(t testing.TestingT, rdsClient *rds.Client, engine string, engineVersion string, instanceTypeOptions []string) (string, error) { for _, instanceTypeOption := range instanceTypeOptions { instanceTypeExists, err := instanceTypeExistsForEngineAndRegionE(rdsClient, engine, engineVersion, instanceTypeOption) if err != nil { @@ -258,14 +260,14 @@ func GetRecommendedRdsInstanceTypeWithClientE(t testing.TestingT, rdsClient *rds // instanceTypeExistsForEngineAndRegionE returns a boolean that represents whether the provided instance type (e.g. db.t2.micro) exists for the given region and db engine type // This function will return an error if the RDS AWS SDK call fails. -func instanceTypeExistsForEngineAndRegionE(client *rds.RDS, engine string, engineVersion string, instanceType string) (bool, error) { +func instanceTypeExistsForEngineAndRegionE(client *rds.Client, engine string, engineVersion string, instanceType string) (bool, error) { input := rds.DescribeOrderableDBInstanceOptionsInput{ Engine: aws.String(engine), EngineVersion: aws.String(engineVersion), DBInstanceClass: aws.String(instanceType), } - out, err := client.DescribeOrderableDBInstanceOptions(&input) + out, err := client.DescribeOrderableDBInstanceOptions(context.Background(), &input) if err != nil { return false, err } @@ -295,7 +297,7 @@ func GetValidEngineVersionE(t testing.TestingT, region string, engine string, ma Engine: aws.String(engine), EngineVersion: aws.String(majorVersion), } - out, err := client.DescribeDBEngineVersions(&input) + out, err := client.DescribeDBEngineVersions(context.Background(), &input) if err != nil || len(out.DBEngineVersions) == 0 { return "", err } diff --git a/modules/aws/region.go b/modules/aws/region.go index 2821959ea..943cc9454 100644 --- a/modules/aws/region.go +++ b/modules/aws/region.go @@ -1,12 +1,13 @@ package aws import ( + "context" "fmt" "os" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/gruntwork-io/terratest/modules/collections" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/random" @@ -108,19 +109,19 @@ func GetAllAwsRegions(t testing.TestingT) []string { func GetAllAwsRegionsE(t testing.TestingT) ([]string, error) { logger.Log(t, "Looking up all AWS regions available in this account") - ec2Client, err := NewEc2ClientE(t, defaultRegion) + ec2Client, err := NewEc2ClientV2E(t, defaultRegion) if err != nil { return nil, err } - out, err := ec2Client.DescribeRegions(&ec2.DescribeRegionsInput{}) + out, err := ec2Client.DescribeRegions(context.Background(), &ec2.DescribeRegionsInput{}) if err != nil { return nil, err } - regions := []string{} + var regions []string for _, region := range out.Regions { - regions = append(regions, aws.StringValue(region.RegionName)) + regions = append(regions, aws.ToString(region.RegionName)) } return regions, nil @@ -141,19 +142,19 @@ func GetAvailabilityZones(t testing.TestingT, region string) []string { func GetAvailabilityZonesE(t testing.TestingT, region string) ([]string, error) { logger.Logf(t, "Looking up all availability zones available in this account for region %s", region) - ec2Client, err := NewEc2ClientE(t, region) + ec2Client, err := NewEc2ClientV2E(t, region) if err != nil { return nil, err } - resp, err := ec2Client.DescribeAvailabilityZones(&ec2.DescribeAvailabilityZonesInput{}) + resp, err := ec2Client.DescribeAvailabilityZones(context.Background(), &ec2.DescribeAvailabilityZonesInput{}) if err != nil { return nil, err } var out []string for _, availabilityZone := range resp.AvailabilityZones { - out = append(out, aws.StringValue(availabilityZone.ZoneName)) + out = append(out, aws.ToString(availabilityZone.ZoneName)) } return out, nil @@ -168,7 +169,7 @@ func GetRegionsForService(t testing.TestingT, serviceName string) []string { return out } -// GetRegionsForService gets all AWS regions in which a service is available and returns errors. +// GetRegionsForServiceE gets all AWS regions in which a service is available and returns errors. // See https://docs.aws.amazon.com/systems-manager/latest/userguide/parameter-store-public-parameters-global-infrastructure.html func GetRegionsForServiceE(t testing.TestingT, serviceName string) ([]string, error) { // These values are available in any region, defaulting to us-east-1 since it's the oldest @@ -179,12 +180,11 @@ func GetRegionsForServiceE(t testing.TestingT, serviceName string) ([]string, er } paramPath := "/aws/service/global-infrastructure/services/%s/regions" - req, resp := ssmClient.GetParametersByPathRequest(&ssm.GetParametersByPathInput{ + resp, err := ssmClient.GetParametersByPath(context.Background(), &ssm.GetParametersByPathInput{ Path: aws.String(fmt.Sprintf(paramPath, serviceName)), }) - ssmErr := req.Send() - if ssmErr != nil { + if err != nil { return nil, err } diff --git a/modules/aws/route53.go b/modules/aws/route53.go index abb7a56ba..30ddd0166 100644 --- a/modules/aws/route53.go +++ b/modules/aws/route53.go @@ -1,17 +1,19 @@ package aws import ( + "context" "fmt" "strings" "testing" - "github.com/aws/aws-sdk-go/service/route53" - "github.com/gogo/protobuf/proto" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/route53" + "github.com/aws/aws-sdk-go-v2/service/route53/types" "github.com/stretchr/testify/require" ) // GetRoute53Record returns a Route 53 Record -func GetRoute53Record(t *testing.T, hostedZoneID, recordName, recordType, awsRegion string) *route53.ResourceRecordSet { +func GetRoute53Record(t *testing.T, hostedZoneID, recordName, recordType, awsRegion string) *types.ResourceRecordSet { r, err := GetRoute53RecordE(t, hostedZoneID, recordName, recordType, awsRegion) require.NoError(t, err) @@ -19,35 +21,33 @@ func GetRoute53Record(t *testing.T, hostedZoneID, recordName, recordType, awsReg } // GetRoute53RecordE returns a Route 53 Record -func GetRoute53RecordE(t *testing.T, hostedZoneID, recordName, recordType, awsRegion string) (record *route53.ResourceRecordSet, err error) { +func GetRoute53RecordE(t *testing.T, hostedZoneID, recordName, recordType, awsRegion string) (*types.ResourceRecordSet, error) { route53Client, err := NewRoute53ClientE(t, awsRegion) if err != nil { return nil, err } - o, err := route53Client.ListResourceRecordSets(&route53.ListResourceRecordSetsInput{ + o, err := route53Client.ListResourceRecordSets(context.Background(), &route53.ListResourceRecordSetsInput{ HostedZoneId: &hostedZoneID, StartRecordName: &recordName, - StartRecordType: &recordType, - MaxItems: proto.String("1"), + StartRecordType: types.RRType(recordType), + MaxItems: aws.Int32(1), }) if err != nil { - return + return nil, err } - for _, record = range o.ResourceRecordSets { + + for _, record := range o.ResourceRecordSets { if strings.EqualFold(recordName+".", *record.Name) { - break + return &record, nil } - record = nil - } - if record == nil { - err = fmt.Errorf("record not found") } - return + + return nil, fmt.Errorf("record not found") } -// NewRoute53ClientE creates a route 53 client. -func NewRoute53Client(t *testing.T, region string) *route53.Route53 { +// NewRoute53Client creates a route 53 client. +func NewRoute53Client(t *testing.T, region string) *route53.Client { c, err := NewRoute53ClientE(t, region) require.NoError(t, err) @@ -55,11 +55,11 @@ func NewRoute53Client(t *testing.T, region string) *route53.Route53 { } // NewRoute53ClientE creates a route 53 client. -func NewRoute53ClientE(t *testing.T, region string) (*route53.Route53, error) { +func NewRoute53ClientE(t *testing.T, region string) (*route53.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return route53.New(sess), nil + return route53.NewFromConfig(*sess), nil } diff --git a/modules/aws/route53_test.go b/modules/aws/route53_test.go index 3f98d3183..5c048c1f4 100644 --- a/modules/aws/route53_test.go +++ b/modules/aws/route53_test.go @@ -1,13 +1,14 @@ package aws import ( + "context" "fmt" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/route53" - "github.com/gogo/protobuf/proto" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/route53" + "github.com/aws/aws-sdk-go-v2/service/route53/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,35 +20,35 @@ func TestRoute53Record(t *testing.T) { require.NoError(t, err) domain := fmt.Sprintf("terratest%dexample.com", time.Now().UnixNano()) - hostedZone, err := c.CreateHostedZone(&route53.CreateHostedZoneInput{ + hostedZone, err := c.CreateHostedZone(context.Background(), &route53.CreateHostedZoneInput{ Name: aws.String(domain), CallerReference: aws.String(fmt.Sprint(time.Now().UnixNano())), }) require.NoError(t, err) t.Cleanup(func() { - _, err := c.DeleteHostedZone(&route53.DeleteHostedZoneInput{ + _, err := c.DeleteHostedZone(context.Background(), &route53.DeleteHostedZoneInput{ Id: hostedZone.HostedZone.Id, }) require.NoError(t, err) }) recordName := fmt.Sprintf("record.%s", domain) - resourceRecordSet := &route53.ResourceRecordSet{ + resourceRecordSet := &types.ResourceRecordSet{ Name: &recordName, - Type: aws.String("A"), + Type: types.RRTypeA, TTL: aws.Int64(60), - ResourceRecords: []*route53.ResourceRecord{ + ResourceRecords: []types.ResourceRecord{ { Value: aws.String("127.0.0.1"), }, }, } - _, err = c.ChangeResourceRecordSets(&route53.ChangeResourceRecordSetsInput{ + _, err = c.ChangeResourceRecordSets(context.Background(), &route53.ChangeResourceRecordSetsInput{ HostedZoneId: hostedZone.HostedZone.Id, - ChangeBatch: &route53.ChangeBatch{ - Changes: []*route53.Change{ + ChangeBatch: &types.ChangeBatch{ + Changes: []types.Change{ { - Action: proto.String("CREATE"), + Action: types.ChangeActionCreate, ResourceRecordSet: resourceRecordSet, }, }, @@ -55,12 +56,12 @@ func TestRoute53Record(t *testing.T) { }) require.NoError(t, err) t.Cleanup(func() { - _, err := c.ChangeResourceRecordSets(&route53.ChangeResourceRecordSetsInput{ + _, err := c.ChangeResourceRecordSets(context.Background(), &route53.ChangeResourceRecordSetsInput{ HostedZoneId: hostedZone.HostedZone.Id, - ChangeBatch: &route53.ChangeBatch{ - Changes: []*route53.Change{ + ChangeBatch: &types.ChangeBatch{ + Changes: []types.Change{ { - Action: proto.String("DELETE"), + Action: types.ChangeActionDelete, ResourceRecordSet: resourceRecordSet, }, }, @@ -70,10 +71,10 @@ func TestRoute53Record(t *testing.T) { }) t.Run("ExistingRecord", func(t *testing.T) { - route53Record := GetRoute53Record(t, *hostedZone.HostedZone.Id, recordName, *resourceRecordSet.Type, region) + route53Record := GetRoute53Record(t, *hostedZone.HostedZone.Id, recordName, string(resourceRecordSet.Type), region) require.NotNil(t, route53Record) assert.Equal(t, recordName+".", *route53Record.Name) - assert.Equal(t, *resourceRecordSet.Type, *route53Record.Type) + assert.Equal(t, resourceRecordSet.Type, route53Record.Type) assert.Equal(t, "127.0.0.1", *route53Record.ResourceRecords[0].Value) }) diff --git a/modules/aws/s3.go b/modules/aws/s3.go index c58744dba..d2131036a 100644 --- a/modules/aws/s3.go +++ b/modules/aws/s3.go @@ -2,12 +2,14 @@ package aws import ( "bytes" + "context" "fmt" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" @@ -28,13 +30,13 @@ func FindS3BucketWithTagE(t testing.TestingT, awsRegion string, key string, valu return "", err } - resp, err := s3Client.ListBuckets(&s3.ListBucketsInput{}) + resp, err := s3Client.ListBuckets(context.Background(), &s3.ListBucketsInput{}) if err != nil { return "", err } for _, bucket := range resp.Buckets { - tagResponse, err := s3Client.GetBucketTagging(&s3.GetBucketTaggingInput{Bucket: bucket.Name}) + tagResponse, err := s3Client.GetBucketTagging(context.Background(), &s3.GetBucketTaggingInput{Bucket: bucket.Name}) if err != nil { if strings.Contains(err.Error(), "NoSuchBucket") { @@ -77,7 +79,7 @@ func GetS3BucketTagsE(t testing.TestingT, awsRegion string, bucket string) (map[ return nil, err } - out, err := s3Client.GetBucketTagging(&s3.GetBucketTaggingInput{ + out, err := s3Client.GetBucketTagging(context.Background(), &s3.GetBucketTaggingInput{ Bucket: &bucket, }) if err != nil { @@ -86,7 +88,7 @@ func GetS3BucketTagsE(t testing.TestingT, awsRegion string, bucket string) (map[ tags := map[string]string{} for _, tag := range out.TagSet { - tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + tags[aws.ToString(tag.Key)] = aws.ToString(tag.Value) } return tags, nil @@ -107,7 +109,7 @@ func GetS3ObjectContentsE(t testing.TestingT, awsRegion string, bucket string, k return "", err } - res, err := s3Client.GetObject(&s3.GetObjectInput{ + res, err := s3Client.GetObject(context.Background(), &s3.GetObjectInput{ Bucket: &bucket, Key: &key, }) @@ -144,21 +146,27 @@ func CreateS3BucketE(t testing.TestingT, region string, name string) error { } params := &s3.CreateBucketInput{ - Bucket: aws.String(name), - // https://github.com/aws/aws-sdk-go/blob/v1.44.122/service/s3/api.go#L41646 - ObjectOwnership: aws.String(s3.ObjectOwnershipObjectWriter), + Bucket: aws.String(name), + ObjectOwnership: types.ObjectOwnershipObjectWriter, } - _, err = s3Client.CreateBucket(params) + + if region != "us-east-1" { + params.CreateBucketConfiguration = &types.CreateBucketConfiguration{ + LocationConstraint: types.BucketLocationConstraint(region), + } + } + + _, err = s3Client.CreateBucket(context.Background(), params) return err } -// PutS3BucketPolicy applies an IAM resource policy to a given S3 bucket to create it's bucket policy +// PutS3BucketPolicy applies an IAM resource policy to a given S3 bucket to create its bucket policy func PutS3BucketPolicy(t testing.TestingT, region string, bucketName string, policyJSONString string) { err := PutS3BucketPolicyE(t, region, bucketName, policyJSONString) require.NoError(t, err) } -// PutS3BucketPolicyE applies an IAM resource policy to a given S3 bucket to create it's bucket policy +// PutS3BucketPolicyE applies an IAM resource policy to a given S3 bucket to create its bucket policy func PutS3BucketPolicyE(t testing.TestingT, region string, bucketName string, policyJSONString string) error { logger.Logf(t, "Applying bucket policy for bucket %s in %s", bucketName, region) @@ -172,7 +180,7 @@ func PutS3BucketPolicyE(t testing.TestingT, region string, bucketName string, po Policy: aws.String(policyJSONString), } - _, err = s3Client.PutBucketPolicy(input) + _, err = s3Client.PutBucketPolicy(context.Background(), input) return err } @@ -193,13 +201,13 @@ func PutS3BucketVersioningE(t testing.TestingT, region string, bucketName string input := &s3.PutBucketVersioningInput{ Bucket: aws.String(bucketName), - VersioningConfiguration: &s3.VersioningConfiguration{ - MFADelete: aws.String("Disabled"), - Status: aws.String("Enabled"), + VersioningConfiguration: &types.VersioningConfiguration{ + MFADelete: types.MFADeleteDisabled, + Status: types.BucketVersioningStatusEnabled, }, } - _, err = s3Client.PutBucketVersioning(input) + _, err = s3Client.PutBucketVersioning(context.Background(), input) return err } @@ -221,7 +229,7 @@ func DeleteS3BucketE(t testing.TestingT, region string, name string) error { params := &s3.DeleteBucketInput{ Bucket: aws.String(name), } - _, err = s3Client.DeleteBucket(params) + _, err = s3Client.DeleteBucket(context.Background(), params) return err } @@ -246,53 +254,53 @@ func EmptyS3BucketE(t testing.TestingT, region string, name string) error { for { // Requesting a batch of objects from s3 bucket - bucketObjects, err := s3Client.ListObjectVersions(params) + bucketObjects, err := s3Client.ListObjectVersions(context.Background(), params) if err != nil { return err } - //Checks if the bucket is already empty + // Checks if the bucket is already empty if len((*bucketObjects).Versions) == 0 { logger.Logf(t, "Bucket %s is already empty", name) return nil } - //creating an array of pointers of ObjectIdentifier - objectsToDelete := make([]*s3.ObjectIdentifier, 0, 1000) + // creating an array of pointers of ObjectIdentifier + objectsToDelete := make([]types.ObjectIdentifier, 0, 1000) for _, object := range (*bucketObjects).Versions { - obj := s3.ObjectIdentifier{ + obj := types.ObjectIdentifier{ Key: object.Key, VersionId: object.VersionId, } - objectsToDelete = append(objectsToDelete, &obj) + objectsToDelete = append(objectsToDelete, obj) } for _, object := range (*bucketObjects).DeleteMarkers { - obj := s3.ObjectIdentifier{ + obj := types.ObjectIdentifier{ Key: object.Key, VersionId: object.VersionId, } - objectsToDelete = append(objectsToDelete, &obj) + objectsToDelete = append(objectsToDelete, obj) } - //Creating JSON payload for bulk delete - deleteArray := s3.Delete{Objects: objectsToDelete} + // Creating JSON payload for bulk delete + deleteArray := types.Delete{Objects: objectsToDelete} deleteParams := &s3.DeleteObjectsInput{ Bucket: aws.String(name), Delete: &deleteArray, } - //Running the Bulk delete job (limit 1000) - _, err = s3Client.DeleteObjects(deleteParams) + // Running the Bulk delete job (limit 1000) + _, err = s3Client.DeleteObjects(context.Background(), deleteParams) if err != nil { return err } - if *(*bucketObjects).IsTruncated { //if there are more objects in the bucket, IsTruncated = true + if *(*bucketObjects).IsTruncated { // if there are more objects in the bucket, IsTruncated = true // params.Marker = (*deleteParams).Delete.Objects[len((*deleteParams).Delete.Objects)-1].Key params.KeyMarker = bucketObjects.NextKeyMarker logger.Logf(t, "Requesting next batch | %s", *(params.KeyMarker)) - } else { //if all objects in the bucket have been cleaned up. + } else { // if all objects in the bucket have been cleaned up. break } } @@ -316,7 +324,7 @@ func GetS3BucketLoggingTargetE(t testing.TestingT, awsRegion string, bucket stri return "", err } - res, err := s3Client.GetBucketLogging(&s3.GetBucketLoggingInput{ + res, err := s3Client.GetBucketLogging(context.Background(), &s3.GetBucketLoggingInput{ Bucket: &bucket, }) @@ -328,7 +336,7 @@ func GetS3BucketLoggingTargetE(t testing.TestingT, awsRegion string, bucket stri return "", S3AccessLoggingNotEnabledErr{bucket, awsRegion} } - return aws.StringValue(res.LoggingEnabled.TargetBucket), nil + return aws.ToString(res.LoggingEnabled.TargetBucket), nil } // GetS3BucketLoggingTargetPrefix fetches the given bucket's logging object prefix and returns it as a string @@ -347,7 +355,7 @@ func GetS3BucketLoggingTargetPrefixE(t testing.TestingT, awsRegion string, bucke return "", err } - res, err := s3Client.GetBucketLogging(&s3.GetBucketLoggingInput{ + res, err := s3Client.GetBucketLogging(context.Background(), &s3.GetBucketLoggingInput{ Bucket: &bucket, }) @@ -359,7 +367,7 @@ func GetS3BucketLoggingTargetPrefixE(t testing.TestingT, awsRegion string, bucke return "", S3AccessLoggingNotEnabledErr{bucket, awsRegion} } - return aws.StringValue(res.LoggingEnabled.TargetPrefix), nil + return aws.ToString(res.LoggingEnabled.TargetPrefix), nil } // GetS3BucketVersioning fetches the given bucket's versioning configuration status and returns it as a string @@ -377,14 +385,14 @@ func GetS3BucketVersioningE(t testing.TestingT, awsRegion string, bucket string) return "", err } - res, err := s3Client.GetBucketVersioning(&s3.GetBucketVersioningInput{ + res, err := s3Client.GetBucketVersioning(context.Background(), &s3.GetBucketVersioningInput{ Bucket: &bucket, }) if err != nil { return "", err } - return aws.StringValue(res.Status), nil + return string(res.Status), nil } // GetS3BucketPolicy fetches the given bucket's resource policy and returns it as a string @@ -402,14 +410,14 @@ func GetS3BucketPolicyE(t testing.TestingT, awsRegion string, bucket string) (st return "", err } - res, err := s3Client.GetBucketPolicy(&s3.GetBucketPolicyInput{ + res, err := s3Client.GetBucketPolicy(context.Background(), &s3.GetBucketPolicyInput{ Bucket: &bucket, }) if err != nil { return "", err } - return aws.StringValue(res.Policy), nil + return aws.ToString(res.Policy), nil } // AssertS3BucketExists checks if the given S3 bucket exists in the given region and fail the test if it does not. @@ -428,7 +436,7 @@ func AssertS3BucketExistsE(t testing.TestingT, region string, name string) error params := &s3.HeadBucketInput{ Bucket: aws.String(name), } - _, err = s3Client.HeadBucket(params) + _, err = s3Client.HeadBucket(context.Background(), params) return err } @@ -471,7 +479,7 @@ func AssertS3BucketPolicyExistsE(t testing.TestingT, region string, bucketName s } // NewS3Client creates an S3 client. -func NewS3Client(t testing.TestingT, region string) *s3.S3 { +func NewS3Client(t testing.TestingT, region string) *s3.Client { client, err := NewS3ClientE(t, region) require.NoError(t, err) @@ -479,30 +487,30 @@ func NewS3Client(t testing.TestingT, region string) *s3.S3 { } // NewS3ClientE creates an S3 client. -func NewS3ClientE(t testing.TestingT, region string) (*s3.S3, error) { +func NewS3ClientE(t testing.TestingT, region string) (*s3.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return s3.New(sess), nil + return s3.NewFromConfig(*sess), nil } // NewS3Uploader creates an S3 Uploader. -func NewS3Uploader(t testing.TestingT, region string) *s3manager.Uploader { +func NewS3Uploader(t testing.TestingT, region string) *manager.Uploader { uploader, err := NewS3UploaderE(t, region) require.NoError(t, err) return uploader } // NewS3UploaderE creates an S3 Uploader. -func NewS3UploaderE(t testing.TestingT, region string) (*s3manager.Uploader, error) { +func NewS3UploaderE(t testing.TestingT, region string) (*manager.Uploader, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return s3manager.NewUploader(sess), nil + return manager.NewUploader(s3.NewFromConfig(*sess)), nil } // S3AccessLoggingNotEnabledErr is a custom error that occurs when acess logging hasn't been enabled on the S3 Bucket diff --git a/modules/aws/s3_test.go b/modules/aws/s3_test.go index dafde4843..7f31726ed 100644 --- a/modules/aws/s3_test.go +++ b/modules/aws/s3_test.go @@ -2,15 +2,16 @@ package aws import ( + "context" "fmt" "math/rand" "strconv" "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" @@ -51,7 +52,7 @@ func TestAssertS3BucketExistsNoFalsePositive(t *testing.T) { logger.Logf(t, "Random values selected. Region = %s, s3BucketName = %s\n", region, s3BucketName) // We elect not to create the S3 bucket to confirm that our function correctly reports it doesn't exist. - //aws.CreateS3Bucket(region, s3BucketName) + // aws.CreateS3Bucket(region, s3BucketName) err := AssertS3BucketExistsE(t, region, s3BucketName) if err == nil { @@ -76,8 +77,7 @@ func TestAssertS3BucketVersioningEnabled(t *testing.T) { func TestEmptyS3Bucket(t *testing.T) { t.Parallel() - // region := GetRandomStableRegion(t, nil, nil) - region := "us-east-1" + region := GetRandomStableRegion(t, nil, nil) id := random.UniqueId() logger.Logf(t, "Random values selected. Region = %s, Id = %s\n", region, id) @@ -114,13 +114,13 @@ func TestEmptyS3BucketVersioned(t *testing.T) { versionInput := &s3.PutBucketVersioningInput{ Bucket: aws.String(s3BucketName), - VersioningConfiguration: &s3.VersioningConfiguration{ - MFADelete: aws.String("Disabled"), - Status: aws.String("Enabled"), + VersioningConfiguration: &types.VersioningConfiguration{ + MFADelete: types.MFADeleteDisabled, + Status: types.BucketVersioningStatusEnabled, }, } - _, err = s3Client.PutBucketVersioning(versionInput) + _, err = s3Client.PutBucketVersioning(context.Background(), versionInput) if err != nil { t.Fatal(err) } @@ -163,10 +163,10 @@ func TestGetS3BucketTags(t *testing.T) { t.Fatal(err) } - _, err = s3Client.PutBucketTagging(&s3.PutBucketTaggingInput{ + _, err = s3Client.PutBucketTagging(context.Background(), &s3.PutBucketTaggingInput{ Bucket: &s3BucketName, - Tagging: &s3.Tagging{ - TagSet: []*s3.Tag{ + Tagging: &types.Tagging{ + TagSet: []types.Tag{ { Key: aws.String("Key1"), Value: aws.String("Value1"), @@ -188,7 +188,7 @@ func TestGetS3BucketTags(t *testing.T) { assert.True(t, actualTags["NonExistentKey"] == "") } -func testEmptyBucket(t *testing.T, s3Client *s3.S3, region string, s3BucketName string) { +func testEmptyBucket(t *testing.T, s3Client *s3.Client, region string, s3BucketName string) { expectedFileCount := rand.Intn(1000) logger.Logf(t, "Uploading %s files to bucket %s", strconv.Itoa(expectedFileCount), s3BucketName) @@ -199,7 +199,7 @@ func testEmptyBucket(t *testing.T, s3Client *s3.S3, region string, s3BucketName key := fmt.Sprintf("test-%s", strconv.Itoa(i)) body := strings.NewReader("This is the body") - params := &s3manager.UploadInput{ + params := &s3.PutObjectInput{ Bucket: aws.String(s3BucketName), Key: &key, Body: body, @@ -207,14 +207,14 @@ func testEmptyBucket(t *testing.T, s3Client *s3.S3, region string, s3BucketName uploader := NewS3Uploader(t, region) - _, err := uploader.Upload(params) + _, err := uploader.Upload(context.Background(), params) if err != nil { t.Fatal(err) } // Delete the first 10 files to be able to test if all files, including delete markers are deleted if i < 10 { - _, err := s3Client.DeleteObject(&s3.DeleteObjectInput{ + _, err := s3Client.DeleteObject(context.Background(), &s3.DeleteObjectInput{ Bucket: aws.String(s3BucketName), Key: aws.String(key), }) @@ -239,7 +239,7 @@ func testEmptyBucket(t *testing.T, s3Client *s3.S3, region string, s3BucketName logger.Logf(t, "Verifying %s files were uploaded to bucket %s", strconv.Itoa(expectedFileCount), s3BucketName) actualCount := 0 for { - bucketObjects, err := s3Client.ListObjectsV2(listObjectsParams) + bucketObjects, err := s3Client.ListObjectsV2(context.Background(), listObjectsParams) if err != nil { t.Fatal(err) } @@ -256,12 +256,12 @@ func testEmptyBucket(t *testing.T, s3Client *s3.S3, region string, s3BucketName require.Equal(t, expectedFileCount-deleted, actualCount) - //empty bucket + // empty bucket logger.Logf(t, "Emptying bucket %s", s3BucketName) EmptyS3Bucket(t, region, s3BucketName) // verify the bucket is empty - bucketObjects, err := s3Client.ListObjectsV2(listObjectsParams) + bucketObjects, err := s3Client.ListObjectsV2(context.Background(), listObjectsParams) if err != nil { t.Fatal(err) } diff --git a/modules/aws/secretsmanager.go b/modules/aws/secretsmanager.go index 9f15d225b..e0de64728 100644 --- a/modules/aws/secretsmanager.go +++ b/modules/aws/secretsmanager.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/secretsmanager" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" @@ -21,7 +23,7 @@ func CreateSecretStringWithDefaultKeyE(t testing.TestingT, awsRegion, descriptio client := NewSecretsManagerClient(t, awsRegion) - secret, err := client.CreateSecret(&secretsmanager.CreateSecretInput{ + secret, err := client.CreateSecret(context.Background(), &secretsmanager.CreateSecretInput{ Description: aws.String(description), Name: aws.String(name), SecretString: aws.String(secretString), @@ -31,7 +33,7 @@ func CreateSecretStringWithDefaultKeyE(t testing.TestingT, awsRegion, descriptio return "", err } - return aws.StringValue(secret.ARN), nil + return aws.ToString(secret.ARN), nil } // GetSecretValue takes the friendly name or ARN of a secret and returns the plaintext value @@ -47,29 +49,29 @@ func GetSecretValueE(t testing.TestingT, awsRegion, id string) (string, error) { client := NewSecretsManagerClient(t, awsRegion) - secret, err := client.GetSecretValue(&secretsmanager.GetSecretValueInput{ + secret, err := client.GetSecretValue(context.Background(), &secretsmanager.GetSecretValueInput{ SecretId: aws.String(id), }) if err != nil { return "", err } - return aws.StringValue(secret.SecretString), nil + return aws.ToString(secret.SecretString), nil } -// DeleteSecret deletes a secret. If forceDelete is true, the secret will be deleted after a short delay. If forceDelete is false, the secret will be deleted after a 30 day recovery window. +// DeleteSecret deletes a secret. If forceDelete is true, the secret will be deleted after a short delay. If forceDelete is false, the secret will be deleted after a 30-day recovery window. func DeleteSecret(t testing.TestingT, awsRegion, id string, forceDelete bool) { err := DeleteSecretE(t, awsRegion, id, forceDelete) require.NoError(t, err) } -// DeleteSecretE deletes a secret. If forceDelete is true, the secret will be deleted after a short delay. If forceDelete is false, the secret will be deleted after a 30 day recovery window. +// DeleteSecretE deletes a secret. If forceDelete is true, the secret will be deleted after a short delay. If forceDelete is false, the secret will be deleted after a 30-day recovery window. func DeleteSecretE(t testing.TestingT, awsRegion, id string, forceDelete bool) error { logger.Logf(t, "Deleting secret with ID %s", id) client := NewSecretsManagerClient(t, awsRegion) - _, err := client.DeleteSecret(&secretsmanager.DeleteSecretInput{ + _, err := client.DeleteSecret(context.Background(), &secretsmanager.DeleteSecretInput{ ForceDeleteWithoutRecovery: aws.Bool(forceDelete), SecretId: aws.String(id), }) @@ -78,18 +80,18 @@ func DeleteSecretE(t testing.TestingT, awsRegion, id string, forceDelete bool) e } // NewSecretsManagerClient creates a new SecretsManager client. -func NewSecretsManagerClient(t testing.TestingT, region string) *secretsmanager.SecretsManager { +func NewSecretsManagerClient(t testing.TestingT, region string) *secretsmanager.Client { client, err := NewSecretsManagerClientE(t, region) require.NoError(t, err) return client } // NewSecretsManagerClientE creates a new SecretsManager client. -func NewSecretsManagerClientE(t testing.TestingT, region string) (*secretsmanager.SecretsManager, error) { +func NewSecretsManagerClientE(t testing.TestingT, region string) (*secretsmanager.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return secretsmanager.New(sess), nil + return secretsmanager.NewFromConfig(*sess), nil } diff --git a/modules/aws/sns.go b/modules/aws/sns.go index 0244c1c5d..faa32f3f2 100644 --- a/modules/aws/sns.go +++ b/modules/aws/sns.go @@ -1,8 +1,10 @@ package aws import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/sns" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sns" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" ) @@ -29,12 +31,12 @@ func CreateSnsTopicE(t testing.TestingT, region string, snsTopicName string) (st Name: &snsTopicName, } - output, err := snsClient.CreateTopic(createTopicInput) + output, err := snsClient.CreateTopic(context.Background(), createTopicInput) if err != nil { return "", err } - return aws.StringValue(output.TopicArn), err + return aws.ToString(output.TopicArn), err } // DeleteSNSTopic deletes an SNS Topic. @@ -58,12 +60,12 @@ func DeleteSNSTopicE(t testing.TestingT, region string, snsTopicArn string) erro TopicArn: aws.String(snsTopicArn), } - _, err = snsClient.DeleteTopic(deleteTopicInput) + _, err = snsClient.DeleteTopic(context.Background(), deleteTopicInput) return err } // NewSnsClient creates a new SNS client. -func NewSnsClient(t testing.TestingT, region string) *sns.SNS { +func NewSnsClient(t testing.TestingT, region string) *sns.Client { client, err := NewSnsClientE(t, region) if err != nil { t.Fatal(err) @@ -72,11 +74,11 @@ func NewSnsClient(t testing.TestingT, region string) *sns.SNS { } // NewSnsClientE creates a new SNS client. -func NewSnsClientE(t testing.TestingT, region string) (*sns.SNS, error) { +func NewSnsClientE(t testing.TestingT, region string) (*sns.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return sns.New(sess), nil + return sns.NewFromConfig(*sess), nil } diff --git a/modules/aws/sns_test.go b/modules/aws/sns_test.go index 3445e7e96..a442c934b 100644 --- a/modules/aws/sns_test.go +++ b/modules/aws/sns_test.go @@ -1,12 +1,13 @@ package aws import ( + "context" "fmt" "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/sns" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sns" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" ) @@ -29,7 +30,7 @@ func snsTopicExists(t *testing.T, region string, arn string) bool { input := sns.GetTopicAttributesInput{TopicArn: aws.String(arn)} - if _, err := snsClient.GetTopicAttributes(&input); err != nil { + if _, err := snsClient.GetTopicAttributes(context.Background(), &input); err != nil { if strings.Contains(err.Error(), "NotFound") { return false } diff --git a/modules/aws/sqs.go b/modules/aws/sqs.go index ed2bc4f1a..7fc50a8be 100644 --- a/modules/aws/sqs.go +++ b/modules/aws/sqs.go @@ -1,12 +1,14 @@ package aws import ( + "context" "fmt" "strconv" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" "github.com/google/uuid" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/testing" @@ -37,7 +39,7 @@ func CreateRandomQueueE(t testing.TestingT, awsRegion string, prefix string) (st channelName := fmt.Sprintf("%s-%s", prefix, channel.String()) - queue, err := sqsClient.CreateQueue(&sqs.CreateQueueInput{ + queue, err := sqsClient.CreateQueue(context.Background(), &sqs.CreateQueueInput{ QueueName: aws.String(channelName), }) @@ -45,7 +47,7 @@ func CreateRandomQueueE(t testing.TestingT, awsRegion string, prefix string) (st return "", err } - return aws.StringValue(queue.QueueUrl), nil + return aws.ToString(queue.QueueUrl), nil } // CreateRandomFifoQueue creates a new FIFO SQS queue with a random name that starts with the given prefix and return the queue URL. @@ -73,11 +75,11 @@ func CreateRandomFifoQueueE(t testing.TestingT, awsRegion string, prefix string) channelName := fmt.Sprintf("%s-%s.fifo", prefix, channel.String()) - queue, err := sqsClient.CreateQueue(&sqs.CreateQueueInput{ + queue, err := sqsClient.CreateQueue(context.Background(), &sqs.CreateQueueInput{ QueueName: aws.String(channelName), - Attributes: map[string]*string{ - "ContentBasedDeduplication": aws.String("true"), - "FifoQueue": aws.String("true"), + Attributes: map[string]string{ + "ContentBasedDeduplication": "true", + "FifoQueue": "true", }, }) @@ -85,7 +87,7 @@ func CreateRandomFifoQueueE(t testing.TestingT, awsRegion string, prefix string) return "", err } - return aws.StringValue(queue.QueueUrl), nil + return aws.ToString(queue.QueueUrl), nil } // DeleteQueue deletes the SQS queue with the given URL. @@ -105,7 +107,7 @@ func DeleteQueueE(t testing.TestingT, awsRegion string, queueURL string) error { return err } - _, err = sqsClient.DeleteQueue(&sqs.DeleteQueueInput{ + _, err = sqsClient.DeleteQueue(context.Background(), &sqs.DeleteQueueInput{ QueueUrl: aws.String(queueURL), }) @@ -129,7 +131,7 @@ func DeleteMessageFromQueueE(t testing.TestingT, awsRegion string, queueURL stri return err } - _, err = sqsClient.DeleteMessage(&sqs.DeleteMessageInput{ + _, err = sqsClient.DeleteMessage(context.Background(), &sqs.DeleteMessageInput{ ReceiptHandle: &receipt, QueueUrl: &queueURL, }) @@ -154,7 +156,7 @@ func SendMessageToQueueE(t testing.TestingT, awsRegion string, queueURL string, return err } - res, err := sqsClient.SendMessage(&sqs.SendMessageInput{ + res, err := sqsClient.SendMessage(context.Background(), &sqs.SendMessageInput{ MessageBody: &message, QueueUrl: &queueURL, }) @@ -167,12 +169,12 @@ func SendMessageToQueueE(t testing.TestingT, awsRegion string, queueURL string, return err } - logger.Logf(t, "Message id %s sent to queue %s", aws.StringValue(res.MessageId), queueURL) + logger.Logf(t, "Message id %s sent to queue %s", aws.ToString(res.MessageId), queueURL) return nil } -// SendMessageToFifoQueue sends the given message to the FIFO SQS queue with the given URL. +// SendMessageFifoToQueue sends the given message to the FIFO SQS queue with the given URL. func SendMessageFifoToQueue(t testing.TestingT, awsRegion string, queueURL string, message string, messageGroupID string) { err := SendMessageToFifoQueueE(t, awsRegion, queueURL, message, messageGroupID) if err != nil { @@ -189,7 +191,7 @@ func SendMessageToFifoQueueE(t testing.TestingT, awsRegion string, queueURL stri return err } - res, err := sqsClient.SendMessage(&sqs.SendMessageInput{ + res, err := sqsClient.SendMessage(context.Background(), &sqs.SendMessageInput{ MessageBody: &message, QueueUrl: &queueURL, MessageGroupId: &messageGroupID, @@ -203,7 +205,7 @@ func SendMessageToFifoQueueE(t testing.TestingT, awsRegion string, queueURL stri return err } - logger.Logf(t, "Message id %s sent to queue %s", aws.StringValue(res.MessageId), queueURL) + logger.Logf(t, "Message id %s sent to queue %s", aws.ToString(res.MessageId), queueURL) return nil } @@ -232,12 +234,12 @@ func WaitForQueueMessage(t testing.TestingT, awsRegion string, queueURL string, for i := 0; i < cycles; i++ { logger.Logf(t, "Waiting for message on %s (%ss)", queueURL, strconv.Itoa(i*cycleLength)) - result, err := sqsClient.ReceiveMessage(&sqs.ReceiveMessageInput{ - QueueUrl: aws.String(queueURL), - AttributeNames: aws.StringSlice([]string{"SentTimestamp"}), - MaxNumberOfMessages: aws.Int64(1), - MessageAttributeNames: aws.StringSlice([]string{"All"}), - WaitTimeSeconds: aws.Int64(int64(cycleLength)), + result, err := sqsClient.ReceiveMessage(context.Background(), &sqs.ReceiveMessageInput{ + QueueUrl: aws.String(queueURL), + MessageSystemAttributeNames: []types.MessageSystemAttributeName{types.MessageSystemAttributeNameSentTimestamp}, + MaxNumberOfMessages: int32(1), + MessageAttributeNames: []string{"All"}, + WaitTimeSeconds: int32(cycleLength), }) if err != nil { @@ -254,7 +256,7 @@ func WaitForQueueMessage(t testing.TestingT, awsRegion string, queueURL string, } // NewSqsClient creates a new SQS client. -func NewSqsClient(t testing.TestingT, region string) *sqs.SQS { +func NewSqsClient(t testing.TestingT, region string) *sqs.Client { client, err := NewSqsClientE(t, region) if err != nil { t.Fatal(err) @@ -263,13 +265,13 @@ func NewSqsClient(t testing.TestingT, region string) *sqs.SQS { } // NewSqsClientE creates a new SQS client. -func NewSqsClientE(t testing.TestingT, region string) (*sqs.SQS, error) { +func NewSqsClientE(t testing.TestingT, region string) (*sqs.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return sqs.New(sess), nil + return sqs.NewFromConfig(*sess), nil } // ReceiveMessageTimeout is an error that occurs if receiving a message times out. diff --git a/modules/aws/sqs_test.go b/modules/aws/sqs_test.go index 6200e8879..5f975d500 100644 --- a/modules/aws/sqs_test.go +++ b/modules/aws/sqs_test.go @@ -1,12 +1,13 @@ package aws import ( + "context" "fmt" "strings" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/gruntwork-io/terratest/modules/random" "github.com/stretchr/testify/assert" ) @@ -71,7 +72,7 @@ func queueExists(t *testing.T, region string, url string) bool { input := sqs.GetQueueAttributesInput{QueueUrl: aws.String(url)} - if _, err := sqsClient.GetQueueAttributes(&input); err != nil { + if _, err := sqsClient.GetQueueAttributes(context.Background(), &input); err != nil { if strings.Contains(err.Error(), "NonExistentQueue") { return false } diff --git a/modules/aws/ssm.go b/modules/aws/ssm.go index 55462df12..2f3ba275b 100644 --- a/modules/aws/ssm.go +++ b/modules/aws/ssm.go @@ -1,11 +1,14 @@ package aws import ( + "context" + "errors" "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/gruntwork-io/terratest/modules/logger" "github.com/gruntwork-io/terratest/modules/retry" "github.com/gruntwork-io/terratest/modules/testing" @@ -29,9 +32,9 @@ func GetParameterE(t testing.TestingT, awsRegion string, keyName string) (string return GetParameterWithClientE(t, ssmClient, keyName) } -// GetParameterE retrieves the latest version of SSM Parameter at keyName with decryption with the ability to provide the SSM client. -func GetParameterWithClientE(t testing.TestingT, client *ssm.SSM, keyName string) (string, error) { - resp, err := client.GetParameter(&ssm.GetParameterInput{Name: aws.String(keyName), WithDecryption: aws.Bool(true)}) +// GetParameterWithClientE retrieves the latest version of SSM Parameter at keyName with decryption with the ability to provide the SSM client. +func GetParameterWithClientE(t testing.TestingT, client *ssm.Client, keyName string) (string, error) { + resp, err := client.GetParameter(context.Background(), &ssm.GetParameterInput{Name: aws.String(keyName), WithDecryption: aws.Bool(true)}) if err != nil { return "", err } @@ -56,14 +59,19 @@ func PutParameterE(t testing.TestingT, awsRegion string, keyName string, keyDesc return PutParameterWithClientE(t, ssmClient, keyName, keyDescription, keyValue) } -// PutParameterE creates new version of SSM Parameter at keyName with keyValue as SecureString with the ability to provide the SSM client. -func PutParameterWithClientE(t testing.TestingT, client *ssm.SSM, keyName string, keyDescription string, keyValue string) (int64, error) { - resp, err := client.PutParameter(&ssm.PutParameterInput{Name: aws.String(keyName), Description: aws.String(keyDescription), Value: aws.String(keyValue), Type: aws.String("SecureString")}) +// PutParameterWithClientE creates new version of SSM Parameter at keyName with keyValue as SecureString with the ability to provide the SSM client. +func PutParameterWithClientE(t testing.TestingT, client *ssm.Client, keyName string, keyDescription string, keyValue string) (int64, error) { + resp, err := client.PutParameter(context.Background(), &ssm.PutParameterInput{ + Name: aws.String(keyName), + Description: aws.String(keyDescription), + Value: aws.String(keyValue), + Type: types.ParameterTypeSecureString, + }) if err != nil { return 0, err } - return *resp.Version, nil + return resp.Version, nil } // DeleteParameter deletes all versions of SSM Parameter at keyName. @@ -81,9 +89,9 @@ func DeleteParameterE(t testing.TestingT, awsRegion string, keyName string) erro return DeleteParameterWithClientE(t, ssmClient, keyName) } -// DeleteParameterE deletes all versions of SSM Parameter at keyName with the ability to provide the SSM client. -func DeleteParameterWithClientE(t testing.TestingT, client *ssm.SSM, keyName string) error { - _, err := client.DeleteParameter(&ssm.DeleteParameterInput{Name: aws.String(keyName)}) +// DeleteParameterWithClientE deletes all versions of SSM Parameter at keyName with the ability to provide the SSM client. +func DeleteParameterWithClientE(t testing.TestingT, client *ssm.Client, keyName string) error { + _, err := client.DeleteParameter(context.Background(), &ssm.DeleteParameterInput{Name: aws.String(keyName)}) if err != nil { return err } @@ -91,21 +99,21 @@ func DeleteParameterWithClientE(t testing.TestingT, client *ssm.SSM, keyName str return nil } -// NewSsmClient creates a SSM client. -func NewSsmClient(t testing.TestingT, region string) *ssm.SSM { +// NewSsmClient creates an SSM client. +func NewSsmClient(t testing.TestingT, region string) *ssm.Client { client, err := NewSsmClientE(t, region) require.NoError(t, err) return client } // NewSsmClientE creates an SSM client. -func NewSsmClientE(t testing.TestingT, region string) (*ssm.SSM, error) { +func NewSsmClientE(t testing.TestingT, region string) (*ssm.Client, error) { sess, err := NewAuthenticatedSession(region) if err != nil { return nil, err } - return ssm.New(sess), nil + return ssm.NewFromConfig(*sess), nil } // WaitForSsmInstanceE waits until the instance get registered to the SSM inventory. @@ -117,23 +125,23 @@ func WaitForSsmInstanceE(t testing.TestingT, awsRegion, instanceID string, timeo return WaitForSsmInstanceWithClientE(t, client, instanceID, timeout) } -// WaitForSsmInstanceE waits until the instance get registered to the SSM inventory with the ability to provide the SSM client. -func WaitForSsmInstanceWithClientE(t testing.TestingT, client *ssm.SSM, instanceID string, timeout time.Duration) error { +// WaitForSsmInstanceWithClientE waits until the instance get registered to the SSM inventory with the ability to provide the SSM client. +func WaitForSsmInstanceWithClientE(t testing.TestingT, client *ssm.Client, instanceID string, timeout time.Duration) error { timeBetweenRetries := 2 * time.Second maxRetries := int(timeout.Seconds() / timeBetweenRetries.Seconds()) description := fmt.Sprintf("Waiting for %s to appear in the SSM inventory", instanceID) input := &ssm.GetInventoryInput{ - Filters: []*ssm.InventoryFilter{ + Filters: []types.InventoryFilter{ { Key: aws.String("AWS:InstanceInformation.InstanceId"), - Type: aws.String("Equal"), - Values: aws.StringSlice([]string{instanceID}), + Type: types.InventoryQueryOperatorTypeEqual, + Values: []string{instanceID}, }, }, } _, err := retry.DoWithRetryE(t, description, maxRetries, timeBetweenRetries, func() (string, error) { - resp, err := client.GetInventory(input) + resp, err := client.GetInventory(context.Background(), input) if err != nil { return "", err @@ -173,7 +181,7 @@ func CheckSsmCommandE(t testing.TestingT, awsRegion, instanceID, command string, } // CheckSSMCommandWithClientE checks that you can run the given command on the given instance through AWS SSM with the ability to provide the SSM client. Returns the result and an error if one occurs. -func CheckSSMCommandWithClientE(t testing.TestingT, client *ssm.SSM, instanceID, command string, timeout time.Duration) (*CommandOutput, error) { +func CheckSSMCommandWithClientE(t testing.TestingT, client *ssm.Client, instanceID, command string, timeout time.Duration) (*CommandOutput, error) { return CheckSSMCommandWithClientWithDocumentE(t, client, instanceID, command, "AWS-RunShellScript", timeout) } @@ -197,19 +205,22 @@ func CheckSsmCommandWithDocumentE(t testing.TestingT, awsRegion, instanceID, com } // CheckSSMCommandWithClientWithDocumentE checks that you can run the given command on the given instance through AWS SSM with the ability to provide the SSM client with specified Command Doc type. Returns the result and an error if one occurs. -func CheckSSMCommandWithClientWithDocumentE(t testing.TestingT, client *ssm.SSM, instanceID, command string, commandDocName string, timeout time.Duration) (*CommandOutput, error) { +func CheckSSMCommandWithClientWithDocumentE(t testing.TestingT, client *ssm.Client, instanceID, command string, commandDocName string, timeout time.Duration) (*CommandOutput, error) { timeBetweenRetries := 2 * time.Second maxRetries := int(timeout.Seconds() / timeBetweenRetries.Seconds()) - resp, err := client.SendCommand(&ssm.SendCommandInput{ - Comment: aws.String("Terratest SSM"), - DocumentName: aws.String(commandDocName), - InstanceIds: aws.StringSlice([]string{instanceID}), - Parameters: map[string][]*string{ - "commands": aws.StringSlice([]string{command}), + resp, err := client.SendCommand( + context.Background(), + &ssm.SendCommandInput{ + Comment: aws.String("Terratest SSM"), + DocumentName: aws.String(commandDocName), + InstanceIds: []string{instanceID}, + Parameters: map[string][]string{ + "commands": {command}, + }, }, - }) + ) if err != nil { return nil, err } @@ -225,7 +236,7 @@ func CheckSSMCommandWithClientWithDocumentE(t testing.TestingT, client *ssm.SSM, result := &CommandOutput{} _, err = retry.DoWithRetryableErrorsE(t, description, retryableErrors, maxRetries, timeBetweenRetries, func() (string, error) { - resp, err := client.GetCommandInvocation(&ssm.GetCommandInvocationInput{ + resp, err := client.GetCommandInvocation(context.Background(), &ssm.GetCommandInvocationInput{ CommandId: resp.Command.CommandId, InstanceId: &instanceID, }) @@ -234,25 +245,26 @@ func CheckSSMCommandWithClientWithDocumentE(t testing.TestingT, client *ssm.SSM, return "", err } - result.Stderr = aws.StringValue(resp.StandardErrorContent) - result.Stdout = aws.StringValue(resp.StandardOutputContent) - result.ExitCode = aws.Int64Value(resp.ResponseCode) + result.Stderr = aws.ToString(resp.StandardErrorContent) + result.Stdout = aws.ToString(resp.StandardOutputContent) + result.ExitCode = int64(resp.ResponseCode) - status := aws.StringValue(resp.Status) + status := resp.Status - if status == ssm.CommandInvocationStatusSuccess { + if status == types.CommandInvocationStatusSuccess { return "", nil } - if status == ssm.CommandInvocationStatusFailed { - return "", fmt.Errorf(aws.StringValue(resp.StatusDetails)) + if status == types.CommandInvocationStatusFailed { + return "", fmt.Errorf(aws.ToString(resp.StatusDetails)) } return "", fmt.Errorf("bad status: %s", status) }) if err != nil { - if actualErr, ok := err.(retry.FatalError); ok { + var actualErr retry.FatalError + if errors.As(err, &actualErr) { return result, actualErr.Underlying } return result, fmt.Errorf("unexpected error: %v", err) diff --git a/modules/aws/vpc.go b/modules/aws/vpc.go index 5ee142a3f..f66c2bebc 100644 --- a/modules/aws/vpc.go +++ b/modules/aws/vpc.go @@ -1,12 +1,14 @@ package aws import ( + "context" "fmt" "strconv" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/gruntwork-io/terratest/modules/random" "github.com/gruntwork-io/terratest/modules/testing" "github.com/stretchr/testify/require" @@ -50,45 +52,45 @@ func GetDefaultVpc(t testing.TestingT, region string) *Vpc { // GetDefaultVpcE fetches information about the default VPC in the given region. func GetDefaultVpcE(t testing.TestingT, region string) (*Vpc, error) { - defaultVpcFilter := ec2.Filter{Name: aws.String(isDefaultFilterName), Values: []*string{aws.String(isDefaultFilterValue)}} - vpcs, err := GetVpcsE(t, []*ec2.Filter{&defaultVpcFilter}, region) + defaultVpcFilter := types.Filter{Name: aws.String(isDefaultFilterName), Values: []string{isDefaultFilterValue}} + vpcs, err := GetVpcsE(t, []types.Filter{defaultVpcFilter}, region) numVpcs := len(vpcs) if numVpcs != 1 { - return nil, fmt.Errorf("Expected to find one default VPC in region %s but found %s", region, strconv.Itoa(numVpcs)) + return nil, fmt.Errorf("expected to find one default VPC in region %s but found %s", region, strconv.Itoa(numVpcs)) } return vpcs[0], err } -// GetVpcById fetches information about a VPC with given Id in the given region. +// GetVpcById fetches information about a VPC with given ID in the given region. func GetVpcById(t testing.TestingT, vpcId string, region string) *Vpc { vpc, err := GetVpcByIdE(t, vpcId, region) require.NoError(t, err) return vpc } -// GetVpcByIdE fetches information about a VPC with given Id in the given region. +// GetVpcByIdE fetches information about a VPC with given ID in the given region. func GetVpcByIdE(t testing.TestingT, vpcId string, region string) (*Vpc, error) { - vpcIdFilter := ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcId}} - vpcs, err := GetVpcsE(t, []*ec2.Filter{&vpcIdFilter}, region) + vpcIdFilter := types.Filter{Name: aws.String(vpcIDFilterName), Values: []string{vpcId}} + vpcs, err := GetVpcsE(t, []types.Filter{vpcIdFilter}, region) numVpcs := len(vpcs) if numVpcs != 1 { - return nil, fmt.Errorf("Expected to find one VPC with ID %s in region %s but found %s", vpcId, region, strconv.Itoa(numVpcs)) + return nil, fmt.Errorf("expected to find one VPC with ID %s in region %s but found %s", vpcId, region, strconv.Itoa(numVpcs)) } return vpcs[0], err } -// GetVpcsE fetches informations about VPCs from given regions limited by filters -func GetVpcsE(t testing.TestingT, filters []*ec2.Filter, region string) ([]*Vpc, error) { - client, err := NewEc2ClientE(t, region) +// GetVpcsE fetches information about VPCs from given regions limited by filters +func GetVpcsE(t testing.TestingT, filters []types.Filter, region string) ([]*Vpc, error) { + client, err := NewEc2ClientV2E(t, region) if err != nil { return nil, err } - vpcs, err := client.DescribeVpcs(&ec2.DescribeVpcsInput{Filters: filters}) + vpcs, err := client.DescribeVpcs(context.Background(), &ec2.DescribeVpcsInput{Filters: filters}) if err != nil { return nil, err } @@ -97,13 +99,13 @@ func GetVpcsE(t testing.TestingT, filters []*ec2.Filter, region string) ([]*Vpc, retVal := make([]*Vpc, numVpcs) for i, vpc := range vpcs.Vpcs { - vpcIdFilter := generateVpcIdFilter(aws.StringValue(vpc.VpcId)) - subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIdFilter}) + vpcIdFilter := generateVpcIdFilter(aws.ToString(vpc.VpcId)) + subnets, err := GetSubnetsForVpcE(t, region, []types.Filter{vpcIdFilter}) if err != nil { return nil, err } - tags, err := GetTagsForVpcE(t, aws.StringValue(vpc.VpcId), region) + tags, err := GetTagsForVpcE(t, aws.ToString(vpc.VpcId), region) if err != nil { return nil, err } @@ -125,7 +127,7 @@ func GetVpcsE(t testing.TestingT, filters []*ec2.Filter, region string) ([]*Vpc, }() retVal[i] = &Vpc{ - Id: aws.StringValue(vpc.VpcId), + Id: aws.ToString(vpc.VpcId), Name: FindVpcName(vpc), Subnets: subnets, Tags: tags, @@ -140,7 +142,7 @@ func GetVpcsE(t testing.TestingT, filters []*ec2.Filter, region string) ([]*Vpc, // FindVpcName extracts the VPC name from its tags (if any). Fall back to "Default" if it's the default VPC or empty string // otherwise. -func FindVpcName(vpc *ec2.Vpc) string { +func FindVpcName(vpc types.Vpc) string { for _, tag := range vpc.Tags { if *tag.Key == "Name" { return *tag.Value @@ -157,7 +159,7 @@ func FindVpcName(vpc *ec2.Vpc) string { // GetSubnetsForVpc gets the subnets in the specified VPC. func GetSubnetsForVpc(t testing.TestingT, vpcID string, region string) []Subnet { vpcIDFilter := generateVpcIdFilter(vpcID) - subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIDFilter}) + subnets, err := GetSubnetsForVpcE(t, region, []types.Filter{vpcIDFilter}) if err != nil { t.Fatal(err) } @@ -167,11 +169,11 @@ func GetSubnetsForVpc(t testing.TestingT, vpcID string, region string) []Subnet // GetAzDefaultSubnetsForVpc gets the default az subnets in the specified VPC. func GetAzDefaultSubnetsForVpc(t testing.TestingT, vpcID string, region string) []Subnet { vpcIDFilter := generateVpcIdFilter(vpcID) - defaultForAzFilter := ec2.Filter{ + defaultForAzFilter := types.Filter{ Name: aws.String(defaultForAzFilterName), - Values: []*string{aws.String("true")}, + Values: []string{"true"}, } - subnets, err := GetSubnetsForVpcE(t, region, []*ec2.Filter{&vpcIDFilter, &defaultForAzFilter}) + subnets, err := GetSubnetsForVpcE(t, region, []types.Filter{vpcIDFilter, defaultForAzFilter}) if err != nil { t.Fatal(err) } @@ -179,27 +181,27 @@ func GetAzDefaultSubnetsForVpc(t testing.TestingT, vpcID string, region string) } // generateVpcIdFilter is a helper method to generate vpc id filter -func generateVpcIdFilter(vpcID string) ec2.Filter { - return ec2.Filter{Name: aws.String(vpcIDFilterName), Values: []*string{&vpcID}} +func generateVpcIdFilter(vpcID string) types.Filter { + return types.Filter{Name: aws.String(vpcIDFilterName), Values: []string{vpcID}} } // GetSubnetsForVpcE gets the subnets in the specified VPC. -func GetSubnetsForVpcE(t testing.TestingT, region string, filters []*ec2.Filter) ([]Subnet, error) { - client, err := NewEc2ClientE(t, region) +func GetSubnetsForVpcE(t testing.TestingT, region string, filters []types.Filter) ([]Subnet, error) { + client, err := NewEc2ClientV2E(t, region) if err != nil { return nil, err } - subnetOutput, err := client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: filters}) + subnetOutput, err := client.DescribeSubnets(context.Background(), &ec2.DescribeSubnetsInput{Filters: filters}) if err != nil { return nil, err } - subnets := []Subnet{} + var subnets []Subnet for _, ec2Subnet := range subnetOutput.Subnets { subnetTags := GetTagsForSubnet(t, *ec2Subnet.SubnetId, region) - subnet := Subnet{Id: aws.StringValue(ec2Subnet.SubnetId), AvailabilityZone: aws.StringValue(ec2Subnet.AvailabilityZone), DefaultForAz: aws.BoolValue(ec2Subnet.DefaultForAz), Tags: subnetTags} + subnet := Subnet{Id: aws.ToString(ec2Subnet.SubnetId), AvailabilityZone: aws.ToString(ec2Subnet.AvailabilityZone), DefaultForAz: aws.ToBool(ec2Subnet.DefaultForAz), Tags: subnetTags} subnets = append(subnets, subnet) } @@ -216,17 +218,17 @@ func GetTagsForVpc(t testing.TestingT, vpcID string, region string) map[string]s // GetTagsForVpcE gets the tags for the specified VPC. func GetTagsForVpcE(t testing.TestingT, vpcID string, region string) (map[string]string, error) { - client, err := NewEc2ClientE(t, region) + client, err := NewEc2ClientV2E(t, region) require.NoError(t, err) - vpcResourceTypeFilter := ec2.Filter{Name: aws.String(resourceIdFilterName), Values: []*string{aws.String(vpcResourceTypeFilterValue)}} - vpcResourceIdFilter := ec2.Filter{Name: aws.String(resourceTypeFilterName), Values: []*string{&vpcID}} - tagsOutput, err := client.DescribeTags(&ec2.DescribeTagsInput{Filters: []*ec2.Filter{&vpcResourceTypeFilter, &vpcResourceIdFilter}}) + vpcResourceTypeFilter := types.Filter{Name: aws.String(resourceIdFilterName), Values: []string{vpcResourceTypeFilterValue}} + vpcResourceIdFilter := types.Filter{Name: aws.String(resourceTypeFilterName), Values: []string{vpcID}} + tagsOutput, err := client.DescribeTags(context.Background(), &ec2.DescribeTagsInput{Filters: []types.Filter{vpcResourceTypeFilter, vpcResourceIdFilter}}) require.NoError(t, err) tags := map[string]string{} for _, tag := range tagsOutput.Tags { - tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + tags[aws.ToString(tag.Key)] = aws.ToString(tag.Value) } return tags, nil @@ -244,12 +246,12 @@ func GetDefaultSubnetIDsForVpcE(t testing.TestingT, vpc Vpc) ([]string, error) { if vpc.Name != defaultVPCName { // You cannot create a default subnet in a nondefault VPC // https://docs.aws.amazon.com/vpc/latest/userguide/default-vpc.html - return nil, fmt.Errorf("Only default VPCs have default subnets but VPC with id %s is not default VPC", vpc.Id) + return nil, fmt.Errorf("only default VPCs have default subnets but VPC with id %s is not default VPC", vpc.Id) } - subnetIDs := []string{} + var subnetIDs []string numSubnets := len(vpc.Subnets) if numSubnets == 0 { - return nil, fmt.Errorf("Expected to find at least one subnet in vpc with ID %s but found zero", vpc.Id) + return nil, fmt.Errorf("expected to find at least one subnet in vpc with ID %s but found zero", vpc.Id) } for _, subnet := range vpc.Subnets { @@ -270,17 +272,17 @@ func GetTagsForSubnet(t testing.TestingT, subnetId string, region string) map[st // GetTagsForSubnetE gets the tags for the specified subnet. func GetTagsForSubnetE(t testing.TestingT, subnetId string, region string) (map[string]string, error) { - client, err := NewEc2ClientE(t, region) + client, err := NewEc2ClientV2E(t, region) require.NoError(t, err) - subnetResourceTypeFilter := ec2.Filter{Name: aws.String(resourceIdFilterName), Values: []*string{aws.String(subnetResourceTypeFilterValue)}} - subnetResourceIdFilter := ec2.Filter{Name: aws.String(resourceTypeFilterName), Values: []*string{&subnetId}} - tagsOutput, err := client.DescribeTags(&ec2.DescribeTagsInput{Filters: []*ec2.Filter{&subnetResourceTypeFilter, &subnetResourceIdFilter}}) + subnetResourceTypeFilter := types.Filter{Name: aws.String(resourceIdFilterName), Values: []string{subnetResourceTypeFilterValue}} + subnetResourceIdFilter := types.Filter{Name: aws.String(resourceTypeFilterName), Values: []string{subnetId}} + tagsOutput, err := client.DescribeTags(context.Background(), &ec2.DescribeTagsInput{Filters: []types.Filter{subnetResourceTypeFilter, subnetResourceIdFilter}}) require.NoError(t, err) tags := map[string]string{} for _, tag := range tagsOutput.Tags { - tags[aws.StringValue(tag.Key)] = aws.StringValue(tag.Value) + tags[aws.ToString(tag.Key)] = aws.ToString(tag.Value) } return tags, nil @@ -297,17 +299,17 @@ func IsPublicSubnet(t testing.TestingT, subnetId string, region string) bool { func IsPublicSubnetE(t testing.TestingT, subnetId string, region string) (bool, error) { subnetIdFilterName := "association.subnet-id" - subnetIdFilter := ec2.Filter{ + subnetIdFilter := types.Filter{ Name: &subnetIdFilterName, - Values: []*string{&subnetId}, + Values: []string{subnetId}, } - client, err := NewEc2ClientE(t, region) + client, err := NewEc2ClientV2E(t, region) if err != nil { return false, err } - rts, err := client.DescribeRouteTables(&ec2.DescribeRouteTablesInput{Filters: []*ec2.Filter{&subnetIdFilter}}) + rts, err := client.DescribeRouteTables(context.Background(), &ec2.DescribeRouteTablesInput{Filters: []types.Filter{subnetIdFilter}}) if err != nil { return false, err } @@ -322,7 +324,7 @@ func IsPublicSubnetE(t testing.TestingT, subnetId string, region string) (bool, for _, rt := range rts.RouteTables { for _, r := range rt.Routes { - if strings.HasPrefix(aws.StringValue(r.GatewayId), "igw-") { + if strings.HasPrefix(aws.ToString(r.GatewayId), "igw-") { return true, nil } } @@ -336,33 +338,33 @@ func getImplicitRouteTableForSubnetE(t testing.TestingT, subnetId string, region mainRouteFilterValue := "true" subnetFilterName := "subnet-id" - client, err := NewEc2ClientE(t, region) + client, err := NewEc2ClientV2E(t, region) if err != nil { return nil, err } - subnetFilter := ec2.Filter{ + subnetFilter := types.Filter{ Name: &subnetFilterName, - Values: []*string{&subnetId}, + Values: []string{subnetId}, } - subnetOutput, err := client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{&subnetFilter}}) + subnetOutput, err := client.DescribeSubnets(context.Background(), &ec2.DescribeSubnetsInput{Filters: []types.Filter{subnetFilter}}) if err != nil { return nil, err } numSubnets := len(subnetOutput.Subnets) if numSubnets != 1 { - return nil, fmt.Errorf("Expected to find one subnet with id %s but found %s", subnetId, strconv.Itoa(numSubnets)) + return nil, fmt.Errorf("expected to find one subnet with id %s but found %s", subnetId, strconv.Itoa(numSubnets)) } - mainRouteFilter := ec2.Filter{ + mainRouteFilter := types.Filter{ Name: &mainRouteFilterName, - Values: []*string{&mainRouteFilterValue}, + Values: []string{mainRouteFilterValue}, } - vpcFilter := ec2.Filter{ + vpcFilter := types.Filter{ Name: aws.String(vpcIDFilterName), - Values: []*string{subnetOutput.Subnets[0].VpcId}, + Values: []string{*subnetOutput.Subnets[0].VpcId}, } - return client.DescribeRouteTables(&ec2.DescribeRouteTablesInput{Filters: []*ec2.Filter{&mainRouteFilter, &vpcFilter}}) + return client.DescribeRouteTables(context.Background(), &ec2.DescribeRouteTablesInput{Filters: []types.Filter{mainRouteFilter, vpcFilter}}) } // GetRandomPrivateCidrBlock gets a random CIDR block from the range of acceptable private IP addresses per RFC 1918 diff --git a/modules/aws/vpc_test.go b/modules/aws/vpc_test.go index 8a060efc8..b3e395965 100644 --- a/modules/aws/vpc_test.go +++ b/modules/aws/vpc_test.go @@ -1,13 +1,15 @@ package aws import ( + "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" ) func TestGetDefaultVpc(t *testing.T) { @@ -42,8 +44,8 @@ func TestGetVpcsE(t *testing.T) { isDefaultFilterName := "isDefault" isDefaultFilterValue := "true" - defaultVpcFilter := ec2.Filter{Name: &isDefaultFilterName, Values: []*string{&isDefaultFilterValue}} - vpcs, _ := GetVpcsE(t, []*ec2.Filter{&defaultVpcFilter}, region) + defaultVpcFilter := types.Filter{Name: &isDefaultFilterName, Values: []string{isDefaultFilterValue}} + vpcs, _ := GetVpcsE(t, []types.Filter{defaultVpcFilter}, region) require.Equal(t, len(vpcs), 1) assert.NotEmpty(t, vpcs[0].Name) @@ -164,24 +166,24 @@ func TestGetDefaultAzSubnets(t *testing.T) { vpc := GetDefaultVpc(t, region) // Note: cannot know exact list of default azs aheard of time, but we know that - //it must be greater than 0 for default vpc. + // it must be greater than 0 for default vpc. subnets := GetAzDefaultSubnetsForVpc(t, vpc.Id, region) assert.NotZero(t, len(subnets)) } func createPublicRoute(t *testing.T, vpcId string, routeTableId string, region string) { - ec2Client := NewEc2Client(t, region) + ec2Client := NewEc2ClientV2(t, region) - createIGWOut, igerr := ec2Client.CreateInternetGateway(&ec2.CreateInternetGatewayInput{}) + createIGWOut, igerr := ec2Client.CreateInternetGateway(context.Background(), &ec2.CreateInternetGatewayInput{}) require.NoError(t, igerr) - _, aigerr := ec2Client.AttachInternetGateway(&ec2.AttachInternetGatewayInput{ + _, aigerr := ec2Client.AttachInternetGateway(context.Background(), &ec2.AttachInternetGatewayInput{ InternetGatewayId: createIGWOut.InternetGateway.InternetGatewayId, VpcId: aws.String(vpcId), }) require.NoError(t, aigerr) - _, err := ec2Client.CreateRoute(&ec2.CreateRouteInput{ + _, err := ec2Client.CreateRoute(context.Background(), &ec2.CreateRouteInput{ RouteTableId: aws.String(routeTableId), DestinationCidrBlock: aws.String("0.0.0.0/0"), GatewayId: createIGWOut.InternetGateway.InternetGatewayId, @@ -190,10 +192,10 @@ func createPublicRoute(t *testing.T, vpcId string, routeTableId string, region s require.NoError(t, err) } -func createRouteTable(t *testing.T, vpcId string, region string) ec2.RouteTable { - ec2Client := NewEc2Client(t, region) +func createRouteTable(t *testing.T, vpcId string, region string) types.RouteTable { + ec2Client := NewEc2ClientV2(t, region) - createRouteTableOutput, err := ec2Client.CreateRouteTable(&ec2.CreateRouteTableInput{ + createRouteTableOutput, err := ec2Client.CreateRouteTable(context.Background(), &ec2.CreateRouteTableInput{ VpcId: aws.String(vpcId), }) @@ -201,16 +203,16 @@ func createRouteTable(t *testing.T, vpcId string, region string) ec2.RouteTable return *createRouteTableOutput.RouteTable } -func createSubnet(t *testing.T, vpcId string, routeTableId string, region string) ec2.Subnet { - ec2Client := NewEc2Client(t, region) +func createSubnet(t *testing.T, vpcId string, routeTableId string, region string) types.Subnet { + ec2Client := NewEc2ClientV2(t, region) - createSubnetOutput, err := ec2Client.CreateSubnet(&ec2.CreateSubnetInput{ + createSubnetOutput, err := ec2Client.CreateSubnet(context.Background(), &ec2.CreateSubnetInput{ CidrBlock: aws.String("10.10.1.0/24"), VpcId: aws.String(vpcId), }) require.NoError(t, err) - _, err = ec2Client.AssociateRouteTable(&ec2.AssociateRouteTableInput{ + _, err = ec2Client.AssociateRouteTable(context.Background(), &ec2.AssociateRouteTableInput{ RouteTableId: aws.String(routeTableId), SubnetId: aws.String(*createSubnetOutput.Subnet.SubnetId), }) @@ -219,10 +221,10 @@ func createSubnet(t *testing.T, vpcId string, routeTableId string, region string return *createSubnetOutput.Subnet } -func createVpc(t *testing.T, region string) ec2.Vpc { - ec2Client := NewEc2Client(t, region) +func createVpc(t *testing.T, region string) types.Vpc { + ec2Client := NewEc2ClientV2(t, region) - createVpcOutput, err := ec2Client.CreateVpc(&ec2.CreateVpcInput{ + createVpcOutput, err := ec2Client.CreateVpc(context.Background(), &ec2.CreateVpcInput{ CidrBlock: aws.String("10.10.0.0/16"), }) @@ -231,32 +233,32 @@ func createVpc(t *testing.T, region string) ec2.Vpc { } func deleteRouteTables(t *testing.T, vpcId string, region string) { - ec2Client := NewEc2Client(t, region) + ec2Client := NewEc2ClientV2(t, region) vpcIDFilterName := "vpc-id" - vpcIDFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcId}} + vpcIDFilter := types.Filter{Name: &vpcIDFilterName, Values: []string{vpcId}} // "You can't delete the main route table." mainRTFilterName := "association.main" mainRTFilterValue := "false" - notMainRTFilter := ec2.Filter{Name: &mainRTFilterName, Values: []*string{&mainRTFilterValue}} + notMainRTFilter := types.Filter{Name: &mainRTFilterName, Values: []string{mainRTFilterValue}} - filters := []*ec2.Filter{&vpcIDFilter, ¬MainRTFilter} + filters := []types.Filter{vpcIDFilter, notMainRTFilter} - rtOutput, err := ec2Client.DescribeRouteTables(&ec2.DescribeRouteTablesInput{Filters: filters}) + rtOutput, err := ec2Client.DescribeRouteTables(context.Background(), &ec2.DescribeRouteTablesInput{Filters: filters}) require.NoError(t, err) for _, rt := range rtOutput.RouteTables { // "You must disassociate the route table from any subnets before you can delete it." for _, assoc := range rt.Associations { - _, disassocErr := ec2Client.DisassociateRouteTable(&ec2.DisassociateRouteTableInput{ + _, disassocErr := ec2Client.DisassociateRouteTable(context.Background(), &ec2.DisassociateRouteTableInput{ AssociationId: assoc.RouteTableAssociationId, }) require.NoError(t, disassocErr) } - _, err := ec2Client.DeleteRouteTable(&ec2.DeleteRouteTableInput{ + _, err := ec2Client.DeleteRouteTable(context.Background(), &ec2.DeleteRouteTableInput{ RouteTableId: rt.RouteTableId, }) require.NoError(t, err) @@ -264,15 +266,15 @@ func deleteRouteTables(t *testing.T, vpcId string, region string) { } func deleteSubnets(t *testing.T, vpcId string, region string) { - ec2Client := NewEc2Client(t, region) + ec2Client := NewEc2ClientV2(t, region) vpcIDFilterName := "vpc-id" - vpcIDFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcId}} + vpcIDFilter := types.Filter{Name: &vpcIDFilterName, Values: []string{vpcId}} - subnetsOutput, err := ec2Client.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{&vpcIDFilter}}) + subnetsOutput, err := ec2Client.DescribeSubnets(context.Background(), &ec2.DescribeSubnetsInput{Filters: []types.Filter{vpcIDFilter}}) require.NoError(t, err) for _, subnet := range subnetsOutput.Subnets { - _, err := ec2Client.DeleteSubnet(&ec2.DeleteSubnetInput{ + _, err := ec2Client.DeleteSubnet(context.Background(), &ec2.DeleteSubnetInput{ SubnetId: subnet.SubnetId, }) require.NoError(t, err) @@ -280,22 +282,22 @@ func deleteSubnets(t *testing.T, vpcId string, region string) { } func deleteInternetGateways(t *testing.T, vpcId string, region string) { - ec2Client := NewEc2Client(t, region) + ec2Client := NewEc2ClientV2(t, region) vpcIDFilterName := "attachment.vpc-id" - vpcIDFilter := ec2.Filter{Name: &vpcIDFilterName, Values: []*string{&vpcId}} + vpcIDFilter := types.Filter{Name: &vpcIDFilterName, Values: []string{vpcId}} - igwOutput, err := ec2Client.DescribeInternetGateways(&ec2.DescribeInternetGatewaysInput{Filters: []*ec2.Filter{&vpcIDFilter}}) + igwOutput, err := ec2Client.DescribeInternetGateways(context.Background(), &ec2.DescribeInternetGatewaysInput{Filters: []types.Filter{vpcIDFilter}}) require.NoError(t, err) for _, igw := range igwOutput.InternetGateways { - _, detachErr := ec2Client.DetachInternetGateway(&ec2.DetachInternetGatewayInput{ + _, detachErr := ec2Client.DetachInternetGateway(context.Background(), &ec2.DetachInternetGatewayInput{ InternetGatewayId: igw.InternetGatewayId, VpcId: aws.String(vpcId), }) require.NoError(t, detachErr) - _, err := ec2Client.DeleteInternetGateway(&ec2.DeleteInternetGatewayInput{ + _, err := ec2Client.DeleteInternetGateway(context.Background(), &ec2.DeleteInternetGatewayInput{ InternetGatewayId: igw.InternetGatewayId, }) require.NoError(t, err) @@ -303,13 +305,13 @@ func deleteInternetGateways(t *testing.T, vpcId string, region string) { } func deleteVpc(t *testing.T, vpcId string, region string) { - ec2Client := NewEc2Client(t, region) + ec2Client := NewEc2ClientV2(t, region) deleteRouteTables(t, vpcId, region) deleteSubnets(t, vpcId, region) deleteInternetGateways(t, vpcId, region) - _, err := ec2Client.DeleteVpc(&ec2.DeleteVpcInput{ + _, err := ec2Client.DeleteVpc(context.Background(), &ec2.DeleteVpcInput{ VpcId: aws.String(vpcId), }) require.NoError(t, err) diff --git a/test/packer_basic_example_test.go b/test/packer_basic_example_test.go index 2952c941b..fb16c5016 100644 --- a/test/packer_basic_example_test.go +++ b/test/packer_basic_example_test.go @@ -1,13 +1,15 @@ package test import ( + "context" "fmt" "os" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" terratest_aws "github.com/gruntwork-io/terratest/modules/aws" "github.com/gruntwork-io/terratest/modules/packer" "github.com/gruntwork-io/terratest/modules/random" @@ -110,7 +112,7 @@ func TestPackerBasicExampleWithVarFile(t *testing.T) { // The path to where the Packer template is located Template: "../examples/packer-basic-example/build.pkr.hcl", - // Variable file to to pass to our Packer build using -var-file option + // Variable file to pass to our Packer build using -var-file option VarFiles: []string{ varFile.Name(), }, @@ -196,35 +198,35 @@ func TestPackerMultipleConcurrentAmis(t *testing.T) { } } -func ShareAmi(t *testing.T, amiID string, accountID string, ec2Client *ec2.EC2) { +func ShareAmi(t *testing.T, amiID string, accountID string, ec2Client *ec2.Client) { input := &ec2.ModifyImageAttributeInput{ ImageId: aws.String(amiID), - LaunchPermission: &ec2.LaunchPermissionModifications{ - Add: []*ec2.LaunchPermission{ + LaunchPermission: &types.LaunchPermissionModifications{ + Add: []types.LaunchPermission{ { UserId: aws.String(accountID), }, }, }, } - _, err := ec2Client.ModifyImageAttribute(input) + _, err := ec2Client.ModifyImageAttribute(context.Background(), input) if err != nil { t.Fatal(err) } } -func MakeAmiPublic(t *testing.T, amiID string, ec2Client *ec2.EC2) { +func MakeAmiPublic(t *testing.T, amiID string, ec2Client *ec2.Client) { input := &ec2.ModifyImageAttributeInput{ ImageId: aws.String(amiID), - LaunchPermission: &ec2.LaunchPermissionModifications{ - Add: []*ec2.LaunchPermission{ + LaunchPermission: &types.LaunchPermissionModifications{ + Add: []types.LaunchPermission{ { - Group: aws.String("all"), + Group: types.PermissionGroupAll, }, }, }, } - _, err := ec2Client.ModifyImageAttribute(input) + _, err := ec2Client.ModifyImageAttribute(context.Background(), input) if err != nil { t.Fatal(err) } diff --git a/test/terraform_aws_dynamodb_example_test.go b/test/terraform_aws_dynamodb_example_test.go index c3b5ff3d9..271402e22 100644 --- a/test/terraform_aws_dynamodb_example_test.go +++ b/test/terraform_aws_dynamodb_example_test.go @@ -4,8 +4,8 @@ import ( "fmt" "testing" - awsSDK "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" + awsSDK "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/gruntwork-io/terratest/modules/aws" "github.com/gruntwork-io/terratest/modules/random" "github.com/gruntwork-io/terratest/modules/terraform" @@ -22,11 +22,11 @@ func TestTerraformAwsDynamoDBExample(t *testing.T) { // Set up expected values to be checked later expectedTableName := fmt.Sprintf("terratest-aws-dynamodb-example-table-%s", random.UniqueId()) expectedKmsKeyArn := aws.GetCmkArn(t, awsRegion, "alias/aws/dynamodb") - expectedKeySchema := []*dynamodb.KeySchemaElement{ - {AttributeName: awsSDK.String("userId"), KeyType: awsSDK.String("HASH")}, - {AttributeName: awsSDK.String("department"), KeyType: awsSDK.String("RANGE")}, + expectedKeySchema := []*types.KeySchemaElement{ + {AttributeName: awsSDK.String("userId"), KeyType: types.KeyTypeHash}, + {AttributeName: awsSDK.String("department"), KeyType: types.KeyTypeRange}, } - expectedTags := []*dynamodb.Tag{ + expectedTags := []*types.Tag{ {Key: awsSDK.String("Environment"), Value: awsSDK.String("production")}, } @@ -52,18 +52,18 @@ func TestTerraformAwsDynamoDBExample(t *testing.T) { // Look up the DynamoDB table by name table := aws.GetDynamoDBTable(t, awsRegion, expectedTableName) - assert.Equal(t, "ACTIVE", awsSDK.StringValue(table.TableStatus)) + assert.Equal(t, "ACTIVE", string(table.TableStatus)) assert.ElementsMatch(t, expectedKeySchema, table.KeySchema) // Verify server-side encryption configuration - assert.Equal(t, expectedKmsKeyArn, awsSDK.StringValue(table.SSEDescription.KMSMasterKeyArn)) - assert.Equal(t, "ENABLED", awsSDK.StringValue(table.SSEDescription.Status)) - assert.Equal(t, "KMS", awsSDK.StringValue(table.SSEDescription.SSEType)) + assert.Equal(t, expectedKmsKeyArn, awsSDK.ToString(table.SSEDescription.KMSMasterKeyArn)) + assert.Equal(t, "ENABLED", string(table.SSEDescription.Status)) + assert.Equal(t, "KMS", string(table.SSEDescription.SSEType)) // Verify TTL configuration ttl := aws.GetDynamoDBTableTimeToLive(t, awsRegion, expectedTableName) - assert.Equal(t, "expires", awsSDK.StringValue(ttl.AttributeName)) - assert.Equal(t, "ENABLED", awsSDK.StringValue(ttl.TimeToLiveStatus)) + assert.Equal(t, "expires", awsSDK.ToString(ttl.AttributeName)) + assert.Equal(t, "ENABLED", string(ttl.TimeToLiveStatus)) // Verify resource tags tags := aws.GetDynamoDbTableTags(t, awsRegion, expectedTableName)