diff --git a/tfawserr/awserr.go b/tfawserr/awserr.go index d919f400..54596ca6 100644 --- a/tfawserr/awserr.go +++ b/tfawserr/awserr.go @@ -7,6 +7,7 @@ import ( "strings" smithy "github.com/aws/smithy-go" + smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/hashicorp/aws-sdk-go-base/v2/internal/errs" ) @@ -34,3 +35,17 @@ func ErrMessageContains(err error, code string, message string) bool { } return false } + +// ErrHTTPStatusCodeEquals returns true if the error matches all these conditions: +// - err is of type smithyhttp.ResponseError +// - ResponseError.HTTPStatusCode() equals one of the passed status codes +func ErrHTTPStatusCodeEquals(err error, statusCodes ...int) bool { + if respErr, ok := errs.As[*smithyhttp.ResponseError](err); ok { + for _, statusCode := range statusCodes { + if respErr.HTTPStatusCode() == statusCode { + return true + } + } + } + return false +} diff --git a/tfawserr/awserr_test.go b/tfawserr/awserr_test.go index 7d1d83ae..9a59a94b 100644 --- a/tfawserr/awserr_test.go +++ b/tfawserr/awserr_test.go @@ -5,11 +5,13 @@ package tfawserr import ( "fmt" + "net/http" "testing" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/sts/types" smithy "github.com/aws/smithy-go" + smithyhttp "github.com/aws/smithy-go/transport/http" ) func TestErrCodeEquals(t *testing.T) { @@ -224,3 +226,63 @@ func TestErrMessageContains(t *testing.T) { }) } } + +func TestErrHTTPStatusCodeEquals(t *testing.T) { + testCases := map[string]struct { + Err error + Codes []int + Expected bool + }{ + "nil error": { + Err: nil, + Expected: false, + }, + "other error": { + Err: fmt.Errorf("other error"), + Expected: false, + }, + "Top-level smithyhttp.ResponseError matching first code": { + Err: &smithyhttp.ResponseError{Response: &smithyhttp.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}}, + Codes: []int{http.StatusNotFound}, + Expected: true, + }, + "Top-level smithyhttp.ResponseError matching last code": { + Err: &smithyhttp.ResponseError{Response: &smithyhttp.Response{Response: &http.Response{StatusCode: http.StatusOK}}}, + Codes: []int{http.StatusNotFound, http.StatusOK}, + Expected: true, + }, + "Top-level smithyhttp.ResponseError no code": { + Err: &smithyhttp.ResponseError{Response: &smithyhttp.Response{Response: &http.Response{StatusCode: http.StatusOK}}}, + }, + "Top-level smithyhttp.ResponseError non-matching codes": { + Err: &smithyhttp.ResponseError{Response: &smithyhttp.Response{Response: &http.Response{StatusCode: http.StatusOK}}}, + Codes: []int{http.StatusNotFound, http.StatusNoContent}, + }, + "Wrapped smithyhttp.ResponseError matching first code": { + Err: &smithy.OperationError{Err: &smithyhttp.ResponseError{Response: &smithyhttp.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}}}, + Codes: []int{http.StatusNotFound}, + Expected: true, + }, + "Wrapped smithyhttp.ResponseError matching last code": { + Err: &smithy.OperationError{Err: &smithyhttp.ResponseError{Response: &smithyhttp.Response{Response: &http.Response{StatusCode: http.StatusOK}}}}, + Codes: []int{http.StatusNotFound, http.StatusOK}, + Expected: true, + }, + "Wrapped smithyhttp.ResponseError non-matching codes": { + Err: &smithy.OperationError{Err: &smithyhttp.ResponseError{Response: &smithyhttp.Response{Response: &http.Response{StatusCode: http.StatusOK}}}}, + Codes: []int{http.StatusNotFound, http.StatusNoContent}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + + t.Run(name, func(t *testing.T) { + got := ErrHTTPStatusCodeEquals(testCase.Err, testCase.Codes...) + + if got != testCase.Expected { + t.Errorf("got %t, expected %t", got, testCase.Expected) + } + }) + } +}