diff --git a/Makefile b/Makefile index ef9e8c1..db03715 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -all: format build test +all: format tidy build test build: env GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" ./... @@ -8,3 +8,6 @@ format: test: go test ./... + +tidy: + go mod tidy diff --git a/go.mod b/go.mod index 94ecabe..8dbab40 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,18 @@ module github.com/seantcanavan/lambda_jwt_router -go 1.18 +go 1.21 require ( + cloud.google.com/go v0.111.0 github.com/aws/aws-lambda-go v1.37.0 github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/jgroeneveld/trial v2.0.0+incompatible github.com/joho/godotenv v1.4.0 + github.com/stretchr/testify v1.8.4 + go.mongodb.org/mongo-driver v1.13.1 ) -require github.com/jgroeneveld/schema v1.0.0 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 4e7e90f..bd815ad 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,60 @@ +cloud.google.com/go v0.111.0 h1:YHLKNupSD1KqjDbQ3+LVdQ81h/UJbJyZG203cEfnQgM= +cloud.google.com/go v0.111.0/go.mod h1:0mibmpKP1TyOOFYQY5izo0LnT+ecvOQ0Sg3OdmMiNRU= github.com/aws/aws-lambda-go v1.37.0 h1:WXkQ/xhIcXZZ2P5ZBEw+bbAKeCEcb5NtiYpSwVVzIXg= github.com/aws/aws-lambda-go v1.37.0/go.mod h1:jwFe2KmMsHmffA1X2R09hH6lFzJQxzI8qK17ewzbQMM= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/jgroeneveld/schema v1.0.0 h1:J0E10CrOkiSEsw6dfb1IfrDJD14pf6QLVJ3tRPl/syI= -github.com/jgroeneveld/schema v1.0.0/go.mod h1:M14lv7sNMtGvo3ops1MwslaSYgDYxrSmbzWIQ0Mr5rs= -github.com/jgroeneveld/trial v2.0.0+incompatible h1:d59ctdgor+VqdZCAiUfVN8K13s0ALDioG5DWwZNtRuQ= -github.com/jgroeneveld/trial v2.0.0+incompatible/go.mod h1:I6INLW96EN8WysNBXUFI3M4RIC8ePg9ntAc/Wy+U/+M= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mongodb.org/mongo-driver v1.13.1 h1:YIc7HTYsKndGK4RFzJ3covLz1byri52x0IoMB0Pt/vk= +go.mongodb.org/mongo-driver v1.13.1/go.mod h1:wcDf1JBCXy2mOW0bWHwO/IOYqdca1MPCwDtFu/Z9+eo= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lambda_jwt/jwt_auth_lambda_test.go b/lambda_jwt/jwt_auth_lambda_test.go index 1e87d78..10cd997 100644 --- a/lambda_jwt/jwt_auth_lambda_test.go +++ b/lambda_jwt/jwt_auth_lambda_test.go @@ -5,9 +5,9 @@ import ( "errors" "github.com/aws/aws-lambda-go/events" "github.com/golang-jwt/jwt" - "github.com/jgroeneveld/trial/assert" "github.com/seantcanavan/lambda_jwt_router/lambda_router" "github.com/seantcanavan/lambda_jwt_router/lambda_util" + "github.com/stretchr/testify/require" "net/http" "testing" "time" @@ -24,15 +24,15 @@ func TestAllowOptionsMW(t *testing.T) { // we pass along an error handler but expect http.StatusOK because the AllowOptions handler should execute first jwtMiddlewareHandler := AllowOptionsMW(GenerateEmptySuccessHandler()) res, err := jwtMiddlewareHandler(nil, req) - assert.Nil(t, err) - assert.Equal(t, res.StatusCode, http.StatusOK) + require.Nil(t, err) + require.Equal(t, res.StatusCode, http.StatusOK) }) t.Run("verify OPTIONS req succeeds with invalid JWT for AllowOptions", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() signedJWT, err := Sign(nil) - assert.Nil(t, err) + require.Nil(t, err) signedJWT = signedJWT + "hi" // create an invalid JWT @@ -47,8 +47,8 @@ func TestAllowOptionsMW(t *testing.T) { // we pass along an error handler but expect http.StatusOK because the AllowOptions handler should execute first jwtMiddlewareHandler := AllowOptionsMW(GenerateEmptySuccessHandler()) res, err := jwtMiddlewareHandler(ctx, req) - assert.Nil(t, err) - assert.Equal(t, res.StatusCode, http.StatusOK) + require.Nil(t, err) + require.Equal(t, res.StatusCode, http.StatusOK) }) t.Run("verify OPTIONS req succeeds with no Authorization header for AllowOptions", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) @@ -63,8 +63,8 @@ func TestAllowOptionsMW(t *testing.T) { // we pass along an error handler but expect http.StatusOK because the AllowOptions handler should execute first jwtMiddlewareHandler := AllowOptionsMW(GenerateEmptySuccessHandler()) res, err := jwtMiddlewareHandler(ctx, req) - assert.Nil(t, err) - assert.Equal(t, res.StatusCode, http.StatusOK) + require.Nil(t, err) + require.Equal(t, res.StatusCode, http.StatusOK) }) } @@ -73,15 +73,15 @@ func TestDecodeAndInjectExpandedClaims(t *testing.T) { req := events.APIGatewayProxyRequest{} jwtMiddlewareHandler := DecodeExpanded(GenerateEmptyErrorHandler()) res, err := jwtMiddlewareHandler(nil, req) - assert.Nil(t, err) - assert.Equal(t, res.StatusCode, http.StatusBadRequest) + require.Nil(t, err) + require.Equal(t, res.StatusCode, http.StatusBadRequest) var responseBody lambda_router.HTTPError err = lambda_router.UnmarshalRes(res, &responseBody) - assert.Nil(t, err) + require.Nil(t, err) - assert.Equal(t, responseBody.Status, res.StatusCode) - assert.Equal(t, responseBody.Message, ErrNoAuthorizationHeader.Error()) + require.Equal(t, responseBody.Status, res.StatusCode) + require.Equal(t, responseBody.Message, ErrNoAuthorizationHeader.Error()) }) t.Run("verify context is returned by DecodeExpanded with a signed JWT", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) @@ -90,7 +90,7 @@ func TestDecodeAndInjectExpandedClaims(t *testing.T) { expandedClaims := generateExpandedMapClaims() signedJWT, err := Sign(expandedClaims) - assert.Nil(t, err) + require.Nil(t, err) req := events.APIGatewayProxyRequest{ HTTPMethod: http.MethodGet, @@ -102,29 +102,29 @@ func TestDecodeAndInjectExpandedClaims(t *testing.T) { jwtMiddlewareHandler := DecodeExpanded(generateSuccessHandlerAndMapExpandedContext()) res, err := jwtMiddlewareHandler(ctx, req) - assert.Nil(t, err) - assert.Equal(t, res.StatusCode, http.StatusOK) + require.Nil(t, err) + require.Equal(t, res.StatusCode, http.StatusOK) var returnedClaims ExpandedClaims err = lambda_router.UnmarshalRes(res, &returnedClaims) - assert.Nil(t, err) + require.Nil(t, err) // this verifies that the context gets set in the middleware inject function since the // dummy handler passed to it as the 'next' call injects the values from its passed // context object into the response body. The function doesn't work this way in practice // however it does allow me to fully unit test it to make sure the context setting is working. // It's hacky and I'm not proud of it but I'm not sure how else to do it. - assert.Equal(t, expandedClaims[AudienceKey], returnedClaims.Audience) - assert.Equal(t, expandedClaims[EmailKey], returnedClaims.Email) - assert.Equal(t, expandedClaims[ExpiresAtKey], returnedClaims.ExpiresAt) - assert.Equal(t, expandedClaims[FirstNameKey], returnedClaims.FirstName) - assert.Equal(t, expandedClaims[FullNameKey], returnedClaims.FullName) - assert.Equal(t, expandedClaims[IDKey], returnedClaims.ID) - assert.Equal(t, expandedClaims[IssuedAtKey], returnedClaims.IssuedAt) - assert.Equal(t, expandedClaims[IssuerKey], returnedClaims.Issuer) - assert.Equal(t, expandedClaims[LevelKey], returnedClaims.Level) - assert.Equal(t, expandedClaims[NotBeforeKey], returnedClaims.NotBefore) - assert.Equal(t, expandedClaims[SubjectKey], returnedClaims.Subject) - assert.Equal(t, expandedClaims[UserTypeKey], returnedClaims.UserType) + require.Equal(t, expandedClaims[AudienceKey], returnedClaims.Audience) + require.Equal(t, expandedClaims[EmailKey], returnedClaims.Email) + require.Equal(t, expandedClaims[ExpiresAtKey], returnedClaims.ExpiresAt) + require.Equal(t, expandedClaims[FirstNameKey], returnedClaims.FirstName) + require.Equal(t, expandedClaims[FullNameKey], returnedClaims.FullName) + require.Equal(t, expandedClaims[IDKey], returnedClaims.ID) + require.Equal(t, expandedClaims[IssuedAtKey], returnedClaims.IssuedAt) + require.Equal(t, expandedClaims[IssuerKey], returnedClaims.Issuer) + require.Equal(t, expandedClaims[LevelKey], returnedClaims.Level) + require.Equal(t, expandedClaims[NotBeforeKey], returnedClaims.NotBefore) + require.Equal(t, expandedClaims[SubjectKey], returnedClaims.Subject) + require.Equal(t, expandedClaims[UserTypeKey], returnedClaims.UserType) }) } @@ -133,15 +133,15 @@ func TestDecodeAndInjectStandardClaims(t *testing.T) { req := events.APIGatewayProxyRequest{} jwtMiddlewareHandler := DecodeStandard(GenerateEmptyErrorHandler()) res, err := jwtMiddlewareHandler(nil, req) - assert.Nil(t, err) - assert.Equal(t, res.StatusCode, http.StatusBadRequest) + require.Nil(t, err) + require.Equal(t, res.StatusCode, http.StatusBadRequest) var responseBody lambda_router.HTTPError err = lambda_router.UnmarshalRes(res, &responseBody) - assert.Nil(t, err) + require.Nil(t, err) - assert.Equal(t, responseBody.Status, res.StatusCode) - assert.Equal(t, responseBody.Message, ErrNoAuthorizationHeader.Error()) + require.Equal(t, responseBody.Status, res.StatusCode) + require.Equal(t, responseBody.Message, ErrNoAuthorizationHeader.Error()) }) t.Run("verify context is returned by DecodeStandard with a signed JWT", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) @@ -150,7 +150,7 @@ func TestDecodeAndInjectStandardClaims(t *testing.T) { standardClaims := generateStandardMapClaims() signedJWT, err := Sign(standardClaims) - assert.Nil(t, err) + require.Nil(t, err) req := events.APIGatewayProxyRequest{ HTTPMethod: http.MethodGet, @@ -162,103 +162,103 @@ func TestDecodeAndInjectStandardClaims(t *testing.T) { jwtMiddlewareHandler := DecodeStandard(generateSuccessHandlerAndMapStandardContext()) res, err := jwtMiddlewareHandler(ctx, req) - assert.Nil(t, err) - assert.Equal(t, res.StatusCode, http.StatusOK) + require.Nil(t, err) + require.Equal(t, res.StatusCode, http.StatusOK) var returnedClaims jwt.StandardClaims err = lambda_router.UnmarshalRes(res, &returnedClaims) - assert.Nil(t, err) + require.Nil(t, err) // this verifies that the context gets set in the middleware inject function since the // dummy handler passed to it as the 'next' call injects the values from its passed // context object into the response body. The function doesn't work this way in practice // however it does allow me to fully unit test it to make sure the context setting is working. // It's hacky and I'm not proud of it but I'm not sure how else to do it. - assert.Equal(t, returnedClaims.Audience, standardClaims[AudienceKey]) - assert.Equal(t, returnedClaims.ExpiresAt, standardClaims[ExpiresAtKey]) - assert.Equal(t, returnedClaims.Id, standardClaims[IDKey]) - assert.Equal(t, returnedClaims.IssuedAt, standardClaims[IssuedAtKey]) - assert.Equal(t, returnedClaims.Issuer, standardClaims[IssuerKey]) - assert.Equal(t, returnedClaims.NotBefore, standardClaims[NotBeforeKey]) - assert.Equal(t, returnedClaims.Subject, standardClaims[SubjectKey]) + require.Equal(t, returnedClaims.Audience, standardClaims[AudienceKey]) + require.Equal(t, returnedClaims.ExpiresAt, standardClaims[ExpiresAtKey]) + require.Equal(t, returnedClaims.Id, standardClaims[IDKey]) + require.Equal(t, returnedClaims.IssuedAt, standardClaims[IssuedAtKey]) + require.Equal(t, returnedClaims.Issuer, standardClaims[IssuerKey]) + require.Equal(t, returnedClaims.NotBefore, standardClaims[NotBeforeKey]) + require.Equal(t, returnedClaims.Subject, standardClaims[SubjectKey]) }) } func TestExtractJWT(t *testing.T) { standardClaims := generateStandardMapClaims() signedJWT, err := Sign(standardClaims) - assert.Nil(t, err) + require.Nil(t, err) t.Run("verify ExtractJWT returns err for empty Authorization header", func(t *testing.T) { headers := map[string]string{"Authorization": ""} mapClaims, httpStatus, extractErr := ExtractJWT(headers) - assert.True(t, len(mapClaims) == 0) - assert.Equal(t, http.StatusBadRequest, httpStatus) - assert.NotNil(t, extractErr) - assert.True(t, errors.Is(extractErr, ErrNoAuthorizationHeader)) + require.True(t, len(mapClaims) == 0) + require.Equal(t, http.StatusBadRequest, httpStatus) + require.NotNil(t, extractErr) + require.True(t, errors.Is(extractErr, ErrNoAuthorizationHeader)) }) t.Run("verify ExtractJWT returns err for Authorization header misspelled - all caps", func(t *testing.T) { headers := map[string]string{"AUTHORIZATION": signedJWT} mapClaims, httpStatus, extractErr := ExtractJWT(headers) - assert.True(t, len(mapClaims) == 0) - assert.Equal(t, http.StatusBadRequest, httpStatus) - assert.NotNil(t, extractErr) - assert.True(t, errors.Is(extractErr, ErrNoAuthorizationHeader)) + require.True(t, len(mapClaims) == 0) + require.Equal(t, http.StatusBadRequest, httpStatus) + require.NotNil(t, extractErr) + require.True(t, errors.Is(extractErr, ErrNoAuthorizationHeader)) }) t.Run("verify ExtractJWT returns err for Authorization header misspelled - lowercase", func(t *testing.T) { headers := map[string]string{"authorization": signedJWT} mapClaims, httpStatus, extractErr := ExtractJWT(headers) - assert.True(t, len(mapClaims) == 0) - assert.Equal(t, http.StatusBadRequest, httpStatus) - assert.NotNil(t, extractErr) - assert.True(t, errors.Is(extractErr, ErrNoAuthorizationHeader)) + require.True(t, len(mapClaims) == 0) + require.Equal(t, http.StatusBadRequest, httpStatus) + require.NotNil(t, extractErr) + require.True(t, errors.Is(extractErr, ErrNoAuthorizationHeader)) }) t.Run("verify ExtractJWT returns err for bearer prefix not used", func(t *testing.T) { headers := map[string]string{"Authorization": signedJWT} mapClaims, httpStatus, extractErr := ExtractJWT(headers) - assert.True(t, len(mapClaims) == 0) - assert.Equal(t, http.StatusBadRequest, httpStatus) - assert.NotNil(t, extractErr) - assert.True(t, errors.Is(extractErr, ErrNoBearerPrefix)) + require.True(t, len(mapClaims) == 0) + require.Equal(t, http.StatusBadRequest, httpStatus) + require.NotNil(t, extractErr) + require.True(t, errors.Is(extractErr, ErrNoBearerPrefix)) }) t.Run("verify ExtractJWT returns err for bearer not camel cased", func(t *testing.T) { headers := map[string]string{"Authorization": "bearer " + signedJWT} mapClaims, httpStatus, extractErr := ExtractJWT(headers) - assert.True(t, len(mapClaims) == 0) - assert.Equal(t, http.StatusBadRequest, httpStatus) - assert.NotNil(t, extractErr) - assert.True(t, errors.Is(extractErr, ErrNoBearerPrefix)) + require.True(t, len(mapClaims) == 0) + require.Equal(t, http.StatusBadRequest, httpStatus) + require.NotNil(t, extractErr) + require.True(t, errors.Is(extractErr, ErrNoBearerPrefix)) }) t.Run("verify ExtractJWT returns err for BEARER all caps", func(t *testing.T) { headers := map[string]string{"Authorization": "BEARER " + signedJWT} mapClaims, httpStatus, extractErr := ExtractJWT(headers) - assert.True(t, len(mapClaims) == 0) - assert.Equal(t, http.StatusBadRequest, httpStatus) - assert.NotNil(t, extractErr) - assert.True(t, errors.Is(extractErr, ErrNoBearerPrefix)) + require.True(t, len(mapClaims) == 0) + require.Equal(t, http.StatusBadRequest, httpStatus) + require.NotNil(t, extractErr) + require.True(t, errors.Is(extractErr, ErrNoBearerPrefix)) }) t.Run("verify ExtractJWT returns err for Bearer does not end with space", func(t *testing.T) { headers := map[string]string{"Authorization": "Bearer" + signedJWT} mapClaims, httpStatus, extractErr := ExtractJWT(headers) - assert.True(t, len(mapClaims) == 0) - assert.Equal(t, http.StatusBadRequest, httpStatus) - assert.NotNil(t, extractErr) - assert.True(t, errors.Is(extractErr, ErrNoBearerPrefix)) + require.True(t, len(mapClaims) == 0) + require.Equal(t, http.StatusBadRequest, httpStatus) + require.NotNil(t, extractErr) + require.True(t, errors.Is(extractErr, ErrNoBearerPrefix)) }) t.Run("verify ExtractJWT returns claims correctly with valid input", func(t *testing.T) { headers := map[string]string{"Authorization": "Bearer " + signedJWT} mapClaims, httpStatus, extractErr := ExtractJWT(headers) - assert.True(t, len(mapClaims) == 7) - assert.Equal(t, http.StatusOK, httpStatus) - assert.Nil(t, extractErr) - assert.Nil(t, extractErr) - - assert.Equal(t, mapClaims[AudienceKey], mapClaims[AudienceKey]) - assert.Equal(t, mapClaims[ExpiresAtKey], mapClaims[ExpiresAtKey]) - assert.Equal(t, mapClaims[IDKey], mapClaims[IDKey]) - assert.Equal(t, mapClaims[IssuedAtKey], mapClaims[IssuedAtKey]) - assert.Equal(t, mapClaims[IssuerKey], mapClaims[IssuerKey]) - assert.Equal(t, mapClaims[NotBeforeKey], mapClaims[NotBeforeKey]) - assert.Equal(t, mapClaims[SubjectKey], mapClaims[SubjectKey]) + require.True(t, len(mapClaims) == 7) + require.Equal(t, http.StatusOK, httpStatus) + require.Nil(t, extractErr) + require.Nil(t, extractErr) + + require.Equal(t, mapClaims[AudienceKey], mapClaims[AudienceKey]) + require.Equal(t, mapClaims[ExpiresAtKey], mapClaims[ExpiresAtKey]) + require.Equal(t, mapClaims[IDKey], mapClaims[IDKey]) + require.Equal(t, mapClaims[IssuedAtKey], mapClaims[IssuedAtKey]) + require.Equal(t, mapClaims[IssuerKey], mapClaims[IssuerKey]) + require.Equal(t, mapClaims[NotBeforeKey], mapClaims[NotBeforeKey]) + require.Equal(t, mapClaims[SubjectKey], mapClaims[SubjectKey]) }) } @@ -266,13 +266,13 @@ func TestGenerateEmptyErrorHandler(t *testing.T) { t.Run("verify empty error handler returns error", func(t *testing.T) { errHandler := GenerateEmptyErrorHandler() res, err := errHandler(nil, lambda_util.GenerateRandomAPIGatewayProxyRequest()) - assert.Nil(t, err) // err handler embeds the error in the response, not the golang stack - assert.Equal(t, res.StatusCode, http.StatusInternalServerError) + require.Nil(t, err) // err handler embeds the error in the response, not the golang stack + require.Equal(t, res.StatusCode, http.StatusInternalServerError) var httpError lambda_router.HTTPError err = lambda_router.UnmarshalRes(res, &httpError) - assert.Nil(t, err) - assert.Equal(t, httpError.Status, http.StatusInternalServerError) - assert.Equal(t, httpError.Message, "this error is simulated") + require.Nil(t, err) + require.Equal(t, httpError.Status, http.StatusInternalServerError) + require.Equal(t, httpError.Message, "this error is simulated") }) } @@ -280,9 +280,9 @@ func TestGenerateEmptySuccessHandler(t *testing.T) { t.Run("verify empty success handler returns success", func(t *testing.T) { successHandler := GenerateEmptySuccessHandler() res, err := successHandler(nil, lambda_util.GenerateRandomAPIGatewayProxyRequest()) - assert.Nil(t, err) - assert.Equal(t, res.StatusCode, http.StatusOK) - assert.Equal(t, res.Body, "{}") // empty struct response + require.Nil(t, err) + require.Equal(t, res.StatusCode, http.StatusOK) + require.Equal(t, res.Body, "{}") // empty struct response }) } diff --git a/lambda_jwt/lambda_jwt_test.go b/lambda_jwt/lambda_jwt_test.go index 1b374e3..9632ae5 100644 --- a/lambda_jwt/lambda_jwt_test.go +++ b/lambda_jwt/lambda_jwt_test.go @@ -3,9 +3,9 @@ package lambda_jwt import ( "errors" "github.com/golang-jwt/jwt" - "github.com/jgroeneveld/trial/assert" "github.com/joho/godotenv" "github.com/seantcanavan/lambda_jwt_router/lambda_util" + "github.com/stretchr/testify/require" "log" "strings" "testing" @@ -48,27 +48,27 @@ func TestExtendExpandedClaims(t *testing.T) { t.Run("verify sign and verify expanded and custom fields in claims", func(t *testing.T) { signedJWT, signErr := Sign(extendedClaims) - assert.Nil(t, signErr) + require.Nil(t, signErr) retrievedClaims, verifyErr := VerifyJWT(signedJWT) - assert.Nil(t, verifyErr) + require.Nil(t, verifyErr) // verify the expanded claims values first - assert.Equal(t, retrievedClaims[AudienceKey], expandedClaims.Audience) - assert.Equal(t, retrievedClaims[ExpiresAtKey], float64(expandedClaims.ExpiresAt)) - assert.Equal(t, retrievedClaims[FirstNameKey], expandedClaims.FirstName) - assert.Equal(t, retrievedClaims[IDKey], expandedClaims.ID) - assert.Equal(t, retrievedClaims[IssuedAtKey], float64(expandedClaims.IssuedAt)) - assert.Equal(t, retrievedClaims[IssuerKey], expandedClaims.Issuer) - assert.Equal(t, retrievedClaims[LevelKey], expandedClaims.Level) - assert.Equal(t, retrievedClaims[NotBeforeKey], float64(expandedClaims.NotBefore)) - assert.Equal(t, retrievedClaims[SubjectKey], expandedClaims.Subject) - assert.Equal(t, retrievedClaims[UserTypeKey], expandedClaims.UserType) + require.Equal(t, retrievedClaims[AudienceKey], expandedClaims.Audience) + require.Equal(t, retrievedClaims[ExpiresAtKey], float64(expandedClaims.ExpiresAt)) + require.Equal(t, retrievedClaims[FirstNameKey], expandedClaims.FirstName) + require.Equal(t, retrievedClaims[IDKey], expandedClaims.ID) + require.Equal(t, retrievedClaims[IssuedAtKey], float64(expandedClaims.IssuedAt)) + require.Equal(t, retrievedClaims[IssuerKey], expandedClaims.Issuer) + require.Equal(t, retrievedClaims[LevelKey], expandedClaims.Level) + require.Equal(t, retrievedClaims[NotBeforeKey], float64(expandedClaims.NotBefore)) + require.Equal(t, retrievedClaims[SubjectKey], expandedClaims.Subject) + require.Equal(t, retrievedClaims[UserTypeKey], expandedClaims.UserType) // verify the custom claim values second - assert.Equal(t, retrievedClaims["hi"], "sean") - assert.Equal(t, retrievedClaims["hello"], "there") - assert.Equal(t, retrievedClaims["number"], float64(34)) + require.Equal(t, retrievedClaims["hi"], "sean") + require.Equal(t, retrievedClaims["hello"], "there") + require.Equal(t, retrievedClaims["number"], float64(34)) }) } @@ -91,24 +91,24 @@ func TestExtendStandardClaims(t *testing.T) { t.Run("verify sign and verify standard and custom fields in claims", func(t *testing.T) { signedJWT, signErr := Sign(extendedClaims) - assert.Nil(t, signErr) + require.Nil(t, signErr) retrievedClaims, verifyErr := VerifyJWT(signedJWT) - assert.Nil(t, verifyErr) + require.Nil(t, verifyErr) // verify the expanded claims values first - assert.Equal(t, retrievedClaims[AudienceKey], standardClaims.Audience) - assert.Equal(t, retrievedClaims[ExpiresAtKey], float64(standardClaims.ExpiresAt)) - assert.Equal(t, retrievedClaims[IDKey], standardClaims.Id) - assert.Equal(t, retrievedClaims[IssuedAtKey], float64(standardClaims.IssuedAt)) - assert.Equal(t, retrievedClaims[IssuerKey], standardClaims.Issuer) - assert.Equal(t, retrievedClaims[NotBeforeKey], float64(standardClaims.NotBefore)) - assert.Equal(t, retrievedClaims[SubjectKey], standardClaims.Subject) + require.Equal(t, retrievedClaims[AudienceKey], standardClaims.Audience) + require.Equal(t, retrievedClaims[ExpiresAtKey], float64(standardClaims.ExpiresAt)) + require.Equal(t, retrievedClaims[IDKey], standardClaims.Id) + require.Equal(t, retrievedClaims[IssuedAtKey], float64(standardClaims.IssuedAt)) + require.Equal(t, retrievedClaims[IssuerKey], standardClaims.Issuer) + require.Equal(t, retrievedClaims[NotBeforeKey], float64(standardClaims.NotBefore)) + require.Equal(t, retrievedClaims[SubjectKey], standardClaims.Subject) // verify the custom claim values second - assert.Equal(t, retrievedClaims["hi"], "sean") - assert.Equal(t, retrievedClaims["hello"], "there") - assert.Equal(t, retrievedClaims["number"], float64(34)) + require.Equal(t, retrievedClaims["hi"], "sean") + require.Equal(t, retrievedClaims["hello"], "there") + require.Equal(t, retrievedClaims["number"], float64(34)) }) } @@ -121,28 +121,28 @@ func TestExtractCustomClaims(t *testing.T) { "exp": lambda_util.GenerateRandomString(10), // exp should be an integer }, &badClaims{}) - assert.NotNil(t, extractCustomErr) - assert.True(t, errors.Is(extractCustomErr, ErrBadClaimsObject)) + require.NotNil(t, extractCustomErr) + require.True(t, errors.Is(extractCustomErr, ErrBadClaimsObject)) }) t.Run("verify ExtractCustom works when called with the correct parameters", func(t *testing.T) { customClaims := generateExpandedMapClaims() var expandedClaims ExpandedClaims err := ExtractCustom(customClaims, &expandedClaims) - assert.Nil(t, err) - - assert.Equal(t, customClaims[AudienceKey], expandedClaims.Audience) - assert.Equal(t, customClaims[ExpiresAtKey], expandedClaims.ExpiresAt) - assert.Equal(t, customClaims[EmailKey], expandedClaims.Email) - assert.Equal(t, customClaims[FirstNameKey], expandedClaims.FirstName) - assert.Equal(t, customClaims[FullNameKey], expandedClaims.FullName) - assert.Equal(t, customClaims[IDKey], expandedClaims.ID) - assert.Equal(t, customClaims[IssuedAtKey], expandedClaims.IssuedAt) - assert.Equal(t, customClaims[IssuerKey], expandedClaims.Issuer) - assert.Equal(t, customClaims[LevelKey], expandedClaims.Level) - assert.Equal(t, customClaims[NotBeforeKey], expandedClaims.NotBefore) - assert.Equal(t, customClaims[SubjectKey], expandedClaims.Subject) - assert.Equal(t, customClaims[UserTypeKey], expandedClaims.UserType) + require.Nil(t, err) + + require.Equal(t, customClaims[AudienceKey], expandedClaims.Audience) + require.Equal(t, customClaims[ExpiresAtKey], expandedClaims.ExpiresAt) + require.Equal(t, customClaims[EmailKey], expandedClaims.Email) + require.Equal(t, customClaims[FirstNameKey], expandedClaims.FirstName) + require.Equal(t, customClaims[FullNameKey], expandedClaims.FullName) + require.Equal(t, customClaims[IDKey], expandedClaims.ID) + require.Equal(t, customClaims[IssuedAtKey], expandedClaims.IssuedAt) + require.Equal(t, customClaims[IssuerKey], expandedClaims.Issuer) + require.Equal(t, customClaims[LevelKey], expandedClaims.Level) + require.Equal(t, customClaims[NotBeforeKey], expandedClaims.NotBefore) + require.Equal(t, customClaims[SubjectKey], expandedClaims.Subject) + require.Equal(t, customClaims[UserTypeKey], expandedClaims.UserType) }) } @@ -152,22 +152,22 @@ func TestExtractStandardClaims(t *testing.T) { "exp": lambda_util.GenerateRandomString(10), // exp should be an integer }, &jwt.StandardClaims{}) - assert.NotNil(t, extractCustomErr) - assert.True(t, errors.Is(extractCustomErr, ErrBadClaimsObject)) + require.NotNil(t, extractCustomErr) + require.True(t, errors.Is(extractCustomErr, ErrBadClaimsObject)) }) t.Run("verify ExtractCustom works when called with the correct parameters", func(t *testing.T) { customClaims := generateExpandedMapClaims() var standardClaims jwt.StandardClaims err := ExtractCustom(customClaims, &standardClaims) - assert.Nil(t, err) - - assert.Equal(t, customClaims[AudienceKey], standardClaims.Audience) - assert.Equal(t, customClaims[ExpiresAtKey], standardClaims.ExpiresAt) - assert.Equal(t, customClaims[IssuedAtKey], standardClaims.IssuedAt) - assert.Equal(t, customClaims[IssuerKey], standardClaims.Issuer) - assert.Equal(t, customClaims[NotBeforeKey], standardClaims.NotBefore) - assert.Equal(t, customClaims[SubjectKey], standardClaims.Subject) + require.Nil(t, err) + + require.Equal(t, customClaims[AudienceKey], standardClaims.Audience) + require.Equal(t, customClaims[ExpiresAtKey], standardClaims.ExpiresAt) + require.Equal(t, customClaims[IssuedAtKey], standardClaims.IssuedAt) + require.Equal(t, customClaims[IssuerKey], standardClaims.Issuer) + require.Equal(t, customClaims[NotBeforeKey], standardClaims.NotBefore) + require.Equal(t, customClaims[SubjectKey], standardClaims.Subject) }) } @@ -175,28 +175,28 @@ func TestSign(t *testing.T) { t.Run("verify signed jwt secret with valid standard claim", func(t *testing.T) { customClaims := generateExpandedMapClaims() signedJWT, err := Sign(customClaims) - assert.Nil(t, err) - assert.True(t, len(signedJWT) > 1) - assert.True(t, strings.HasPrefix(signedJWT, "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9")) + require.Nil(t, err) + require.True(t, len(signedJWT) > 1) + require.True(t, strings.HasPrefix(signedJWT, "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9")) }) } func TestVerifyJWT(t *testing.T) { t.Run("verify err when parsing invalid jwt", func(t *testing.T) { _, err := VerifyJWT(lambda_util.GenerateRandomString(10)) - assert.NotNil(t, err) - assert.True(t, errors.Is(err, ErrInvalidJWT)) + require.NotNil(t, err) + require.True(t, errors.Is(err, ErrInvalidJWT)) }) t.Run("verify err when parsing expired token with valid jwt", func(t *testing.T) { customClaims := generateExpandedMapClaims() customClaims["exp"] = time.Now().Add(time.Hour * -10) expiredJWT, signErr := Sign(customClaims) - assert.Nil(t, signErr) + require.Nil(t, signErr) _, err := VerifyJWT(expiredJWT) - assert.NotNil(t, err) - assert.True(t, errors.Is(err, ErrInvalidJWT)) + require.NotNil(t, err) + require.True(t, errors.Is(err, ErrInvalidJWT)) }) } diff --git a/lambda_router/decoder.go b/lambda_router/decoder.go index f341df2..c42ad24 100644 --- a/lambda_router/decoder.go +++ b/lambda_router/decoder.go @@ -1,10 +1,12 @@ package lambda_router import ( + "cloud.google.com/go/civil" "encoding/base64" "encoding/json" "errors" "fmt" + "go.mongodb.org/mongo-driver/bson/primitive" "log" "net/http" "reflect" @@ -16,7 +18,7 @@ import ( "github.com/aws/aws-lambda-go/events" ) -var boolRegex = regexp.MustCompile(`^1|true|on|enabled$`) +var boolRegex = regexp.MustCompile(`^1|true|on|enabled|t$`) // MarshalReq will take an interface input, marshal it to JSON, and add the // JSON as a string to the events.APIGatewayProxyRequest body field before returning. @@ -177,85 +179,202 @@ func unmarshalField( multiParam map[string][]string, param string, ) error { + strVal, ok := params[param] + strVals, okMulti := multiParam[param] + + if !ok && !okMulti { + return nil + } + + //fmt.Println(fmt.Sprintf("param %s", param)) + //fmt.Println(fmt.Sprintf("params[param] %s", strVal)) + //fmt.Println(fmt.Sprintf("multiParam[param] %+v", strVals)) + //fmt.Println(fmt.Sprintf("typeField.Name() %s", typeField.Name())) + //fmt.Println(fmt.Sprintf("typeField.Kind() %s", typeField.Kind())) + // + //if typeField.Kind() == reflect.Array || + // typeField.Kind() == reflect.Chan || + // typeField.Kind() == reflect.Map || + // typeField.Kind() == reflect.Ptr || + // typeField.Kind() == reflect.Slice { + // fmt.Println(fmt.Sprintf("typeField.Elem() %s", typeField.Elem())) + // fmt.Println(fmt.Sprintf("typeField.Elem().Kind() %s", typeField.Elem().Kind())) + //} + + //fmt.Println(fmt.Sprintf("valueField.Type() %s", valueField.Type())) + //fmt.Println(fmt.Sprintf("valueField.Kind() %s", valueField.Kind())) + // + //fmt.Print("\n\n\n") + switch typeField.Kind() { + case reflect.Array: + objectID, err := primitive.ObjectIDFromHex(strVal) + if err != nil { + return fmt.Errorf("invalid ObjectID: %s", err) + } + valueField.Set(reflect.ValueOf(objectID)) + case reflect.String: - valueField.SetString(params[param]) + valueField.SetString(strVal) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str, ok := params[param] - value, err := parseInt64Param(param, str, ok) + value, err := parseInt64Param(param, strVal, ok) if err != nil { return err } valueField.SetInt(value) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str, ok := params[param] - value, err := parseUint64Param(param, str, ok) + value, err := parseUint64Param(param, strVal, ok) if err != nil { return err } valueField.SetUint(value) case reflect.Float32, reflect.Float64: - str, ok := params[param] - value, err := parseFloat64Param(param, str, ok) + value, err := parseFloat64Param(param, strVal, ok) if err != nil { return err } valueField.SetFloat(value) case reflect.Bool: - valueField.SetBool(boolRegex.MatchString(strings.ToLower(params[param]))) + valueField.SetBool(boolRegex.MatchString(strings.ToLower(strVal))) case reflect.Ptr: - if val, ok := params[param]; ok { + if ok { switch typeField.Elem().Kind() { - case reflect.Int, reflect.Int32, reflect.Int64, reflect.String, reflect.Float32, reflect.Float64: - valueField.Set(reflect.ValueOf(&val).Convert(typeField)) + case reflect.String: + valueField.Set(reflect.ValueOf(&strVal).Convert(typeField)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + value, err := parseInt64Param(param, strVal, ok) + if err != nil { + return err + } + // Create a new pointer to the integer type + intPtr := reflect.New(typeField.Elem()) + // Set the value to the newly created pointer + intPtr.Elem().SetInt(value) + // Set the field to the new pointer + valueField.Set(intPtr) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + value, err := parseUint64Param(param, strVal, ok) + if err != nil { + return err + } + // Create a new pointer to the integer type + intPtr := reflect.New(typeField.Elem()) + // Set the value to the newly created pointer + intPtr.Elem().SetUint(value) + // Set the field to the new pointer + valueField.Set(intPtr) + case reflect.Float32, reflect.Float64: + value, err := parseFloat64Param(param, strVal, ok) + if err != nil { + return err + } + // Create a new pointer to the integer type + intPtr := reflect.New(typeField.Elem()) + // Set the value to the newly created pointer + intPtr.Elem().SetFloat(value) + // Set the field to the new pointer + valueField.Set(intPtr) case reflect.Struct: - if typeField.Elem() == reflect.TypeOf(time.Now()) { - parsedTime, err := time.Parse(time.RFC3339, val) + if typeField.Elem() == reflect.TypeOf(civil.Date{}) { + parsedCivil, err := civil.ParseDate(strVal) + if err != nil { + return err + } + valueField.Set(reflect.ValueOf(&parsedCivil)) + } else if typeField.Elem() == reflect.TypeOf(time.Time{}) { + parsedTime, err := time.Parse(time.RFC3339, strVal) if err != nil { return err } valueField.Set(reflect.ValueOf(&parsedTime)) } case reflect.Bool: - b := boolRegex.MatchString(strings.ToLower(val)) + b := boolRegex.MatchString(strings.ToLower(strVal)) valueField.Set(reflect.ValueOf(&b)) + // Handling mongo DB ID types + default: + switch typeField.Elem() { + case reflect.TypeOf(primitive.ObjectID{}): + objectID, err := primitive.ObjectIDFromHex(strVal) + if err != nil { + return fmt.Errorf("invalid ObjectID: %s", err) + } + valueField.Set(reflect.ValueOf(&objectID)) + } } } case reflect.Slice: - // we'll be extracting values from multiParam, generating a slice and - // putting it in valueField - strValues, ok := multiParam[param] - if ok { - slice := reflect.MakeSlice(typeField, len(strValues), len(strValues)) - - for i, str := range strValues { - err := unmarshalField( - typeField.Elem(), - slice.Index(i), - map[string]string{"param": str}, - nil, - "param", - ) - if err != nil { - return err - } + if typeField.Elem().Kind() == reflect.Ptr && typeField.Elem().Elem().Kind() == reflect.String { + // Handling the slice of pointers to custom string type (like Number) + stringValues := strVals + if !okMulti { + stringValues = strings.Split(strVal, ",") + } + slice := reflect.MakeSlice(typeField, len(stringValues), len(stringValues)) + + for i, strVal := range stringValues { + // Create a new instance of the element type (which is a pointer) + newElemPtr := reflect.New(typeField.Elem().Elem()) + // Set the value of the new instance + newElemPtr.Elem().SetString(strVal) + // Set the slice element to the new instance + slice.Index(i).Set(newElemPtr) } valueField.Set(slice) } else { - str, ok := params[param] - if ok { - stringParts := strings.Split(str, ",") - slice := reflect.MakeSlice(typeField, len(stringParts), len(stringParts)) - - for i, p := range stringParts { - inVal := reflect.ValueOf(p) - asVal := inVal.Convert(typeField.Elem()) - slice.Index(i).Set(asVal) + // we'll be extracting values from multiParam, generating a slice and + // putting it in valueField + if okMulti { + slice := reflect.MakeSlice(typeField, len(strVals), len(strVals)) + + for i, str := range strVals { + err := unmarshalField( + typeField.Elem(), + slice.Index(i), + map[string]string{"param": str}, + nil, + "param", + ) + if err != nil { + return err + } } valueField.Set(slice) + } else { + if ok { + stringParts := strings.Split(strVal, ",") + if len(stringParts) < 1 { + return nil + } + slice := reflect.MakeSlice(typeField, len(stringParts), len(stringParts)) + + for i, p := range stringParts { + inVal := reflect.ValueOf(p) + asVal := inVal.Convert(typeField.Elem()) + slice.Index(i).Set(asVal) + } + + valueField.Set(slice) + } + } + } + + case reflect.Struct: + switch valueField.Type() { + case reflect.TypeOf(time.Time{}): + parsedTime, err := time.Parse(time.RFC3339, strVal) + if err != nil { + return err + } + valueField.Set(reflect.ValueOf(parsedTime)) + case reflect.TypeOf(civil.Date{}): + parsedCivil, err := civil.ParseDate(strVal) + if err != nil { + return err } + valueField.Set(reflect.ValueOf(parsedCivil)) } } diff --git a/lambda_router/decoder_test.go b/lambda_router/decoder_test.go index 7026ea1..4bd83a4 100644 --- a/lambda_router/decoder_test.go +++ b/lambda_router/decoder_test.go @@ -1,22 +1,20 @@ package lambda_router import ( - "errors" + "cloud.google.com/go/civil" "github.com/seantcanavan/lambda_jwt_router/lambda_util" - "net/http" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/bson/primitive" + "strings" "testing" "time" "github.com/aws/aws-lambda-go/events" - "github.com/jgroeneveld/trial/assert" ) type mockConst string -const ( - mockConstTwo mockConst = "two" -) - type Number string const ( @@ -37,21 +35,44 @@ type mockGetReq struct { } type mockListReq struct { - ID string `lambda:"path.id"` - Page int64 `lambda:"query.page"` - PageSize int64 `lambda:"query.page_size"` - Terms []string `lambda:"query.terms"` - Numbers []float64 `lambda:"query.numbers"` - Const mockConst `lambda:"query.const"` - Bool bool `lambda:"query.bool"` - PBoolOne *bool `lambda:"query.pbool1"` - PBoolTwo *bool `lambda:"query.pbool2"` - Time *time.Time `lambda:"query.time"` - Alias stringAliasExample `lambda:"query.alias"` - AliasPtr *stringAliasExample `lambda:"query.alias_ptr"` - CommaSplit []Number `lambda:"query.commaSplit"` - Language string `lambda:"header.Accept-Language"` - Encoding []string `lambda:"header.Accept-Encoding"` + Alias stringAliasExample `lambda:"query.alias"` + AliasPtr *stringAliasExample `lambda:"query.alias_ptr"` + Bool1 bool `lambda:"query.bool1"` + Bool2 bool `lambda:"query.bool2"` + Bool3 bool `lambda:"query.bool3"` + Bool4 bool `lambda:"query.bool4"` + Bool5 bool `lambda:"query.bool5"` + Bool6 bool `lambda:"query.bool6"` + Bool7 bool `lambda:"query.bool7"` + Bool8 bool `lambda:"query.bool8"` + Bool9 bool `lambda:"query.bool9"` + Civil civil.Date `lambda:"query.civil"` + CivilPtr *civil.Date `lambda:"query.civilPtr"` + CivilPtrNil *civil.Date `lambda:"query.civilPtrNil"` + CommaSplit []Number `lambda:"query.commaSplit"` + CommaSplitPtr []*Number `lambda:"query.commaSplitPtr"` + Const mockConst `lambda:"query.const"` + ConstPtr *mockConst `lambda:"query.constPtr"` + ConstPtrNil *mockConst `lambda:"query.constPtrNil"` + Encoding []string `lambda:"header.Accept-Encoding"` + ID string `lambda:"path.id"` + IDs []*string `lambda:"query.ids"` + Language string `lambda:"header.Accept-Language"` + MongoID primitive.ObjectID `lambda:"query.mongoId"` + MongoIDPtr *primitive.ObjectID `lambda:"query.mongoIdPtr"` + MongoIDPtrNil *primitive.ObjectID `lambda:"query.mongoIdPtrNil"` + MongoIDs []primitive.ObjectID `lambda:"query.mongoIds"` + MongoIDsPtr []*primitive.ObjectID `lambda:"query.mongoIdsPtr"` + Number *float32 `lambda:"query.number"` + Numbers []float64 `lambda:"query.numbers"` + PBoolOne *bool `lambda:"query.pbool1"` + PBoolTwo *bool `lambda:"query.pbool2"` + Page int64 `lambda:"query.page"` + PageSize *int64 `lambda:"query.page_size"` + Terms []string `lambda:"query.terms"` + Time time.Time `lambda:"query.time"` + TimePtr *time.Time `lambda:"query.timePtr"` + TimePtrNil *time.Time `lambda:"query.timePtrNil"` } type mockPostReq struct { @@ -76,13 +97,16 @@ func TestMarshalLambdaRequest(t *testing.T) { var miParsed mockItem err := UnmarshalReq(req, true, &miParsed) assert.Nil(t, err) - assert.Equal(t, mi.ID, miParsed.ID) - assert.Equal(t, mi.Name, miParsed.Name) + require.Equal(t, mi.ID, miParsed.ID) + require.Equal(t, mi.Name, miParsed.Name) }) } func Test_UnmarshalReq(t *testing.T) { t.Run("valid path&query input", func(t *testing.T) { + mongoID1 := primitive.NewObjectID() + mongoID2 := primitive.NewObjectID() + var input mockListReq err := UnmarshalReq( events.APIGatewayProxyRequest{ @@ -90,19 +114,39 @@ func Test_UnmarshalReq(t *testing.T) { "id": "fake-scan-id", }, QueryStringParameters: map[string]string{ - "page": "2", - "page_size": "30", - "const": "two", - "bool": "true", - "pbool1": "0", - "time": "2021-11-01T11:11:11.000Z", - "alias": "hello", - "alias_ptr": "world", - "commaSplit": "one,two,three", + "alias": "hello", + "alias_ptr": "world", + "bool1": "1", + "bool2": "true", + "bool3": "on", + "bool4": "enabled", + "bool5": "t", + "bool6": "TRUE", + "bool7": "ON", + "bool8": "ENABLED", + "bool9": "T", + "civil": "2023-12-22", + "civilPtr": "2024-12-22", + "commaSplit": "one,two,three", + "commaSplitPtr": "one,two,three", + "const": "twenty", + "constPtr": "thirty", + "mongoId": mongoID1.Hex(), + "mongoIdPtr": mongoID1.Hex(), + "number": "90.10982", + "page": "2", + "page_size": "30", + "pbool1": "0", + "time": "2021-11-01T11:11:11.000Z", + "timePtr": "2021-11-01T11:11:11.000Z", }, MultiValueQueryStringParameters: map[string][]string{ - "terms": {"one", "two"}, - "numbers": {"1.2", "3.5", "666.666"}, + "commaSplits": {"four,five,six"}, + "ids": {"7", "8", "9"}, + "mongoIds": {mongoID1.Hex(), mongoID2.Hex()}, + "mongoIdsPtr": {mongoID1.Hex(), mongoID2.Hex()}, + "numbers": {"1.2", "3.5", "666.666"}, + "terms": {"artist", "label"}, }, Headers: map[string]string{ "Accept-Language": "en-us", @@ -114,25 +158,69 @@ func Test_UnmarshalReq(t *testing.T) { false, &input, ) - assert.Equal(t, nil, err, "ErrorRes must be nil") - assert.Equal(t, "fake-scan-id", input.ID, "ID must be parsed from path") - assert.Equal(t, int64(2), input.Page, "Page must be parsed from query") - assert.Equal(t, int64(30), input.PageSize, "PageSize must be parsed from query") - assert.Equal(t, "en-us", input.Language, "Language must be parsed from headers") - assert.Equal(t, mockConstTwo, input.Const, "Const must be parsed from query") - assert.True(t, input.Bool, "Bool must be true") - assert.NotNil(t, input.PBoolOne, "PBoolOne must not be nil") - assert.False(t, *input.PBoolOne, "PBoolOne must be *false") - assert.NotNil(t, input.Time, "Time must not be nil") - assert.Equal(t, input.Time.Format(time.RFC3339), "2021-11-01T11:11:11Z") - assert.Equal(t, input.Alias, stringAliasExample("hello")) - assert.NotNil(t, input.AliasPtr) - assert.Equal(t, *input.AliasPtr, aliasExample) - assert.DeepEqual(t, []Number{numberOne, numberTwo, numberThree}, input.CommaSplit, "CommaSplit must have 2 items") - assert.Equal(t, (*bool)(nil), input.PBoolTwo, "PBoolTwo must be nil") - assert.DeepEqual(t, []string{"one", "two"}, input.Terms, "Terms must be parsed from multiple query params") - assert.DeepEqual(t, []float64{1.2, 3.5, 666.666}, input.Numbers, "Numbers must be parsed from multiple query params") - assert.DeepEqual(t, []string{"gzip", "deflate"}, input.Encoding, "Encoding must be parsed from multiple header params") + require.NoError(t, err) + + require.Equal(t, *input.AliasPtr, stringAliasExample("world")) + require.Equal(t, input.Alias, stringAliasExample("hello")) + require.Equal(t, input.Bool1, true) + require.Equal(t, input.Bool2, true) + require.Equal(t, input.Bool3, true) + require.Equal(t, input.Bool4, true) + require.Equal(t, input.Bool5, true) + require.Equal(t, input.Bool6, true) + require.Equal(t, input.Bool7, true) + require.Equal(t, input.Bool8, true) + require.Equal(t, input.Bool9, true) + require.Equal(t, input.Civil.String(), "2023-12-22") + require.Equal(t, input.CivilPtr.String(), "2024-12-22") + require.Equal(t, input.Const, mockConst("twenty")) + require.Equal(t, *input.ConstPtr, mockConst("thirty")) + require.Equal(t, input.ID, "fake-scan-id") + require.Equal(t, input.Language, "en-us") + require.Equal(t, input.MongoID, mongoID1) + require.Equal(t, input.MongoIDPtr.Hex(), mongoID1.Hex()) + require.Equal(t, input.Number, func() *float32 { a := float32(90.10982); return &a }()) + require.Equal(t, *input.PBoolOne, false) + require.Equal(t, input.Page, int64(2)) + require.Equal(t, input.PageSize, func() *int64 { a := int64(30); return &a }()) + require.Equal(t, input.Time, time.Date(2021, 11, 1, 11, 11, 11, 0, time.UTC)) + require.Equal(t, *input.TimePtr, time.Date(2021, 11, 1, 11, 11, 11, 0, time.UTC)) + + numberPtrs := []*Number{ + func() *Number { a := Number("one"); return &a }(), + func() *Number { a := Number("two"); return &a }(), + func() *Number { a := Number("three"); return &a }(), + } + + idPtrs := []*string{ + func() *string { a := "7"; return &a }(), + func() *string { a := "8"; return &a }(), + func() *string { a := "9"; return &a }(), + } + + require.EqualValues(t, input.CommaSplit, []Number{"one", "two", "three"}) + require.EqualValues(t, input.CommaSplitPtr, numberPtrs) + require.EqualValues(t, input.Encoding, []string{"gzip", "deflate"}) + require.EqualValues(t, input.IDs, idPtrs) + require.EqualValues(t, input.MongoIDs, []primitive.ObjectID{mongoID1, mongoID2}) + require.EqualValues(t, input.MongoIDsPtr, []*primitive.ObjectID{&mongoID1, &mongoID2}) + require.EqualValues(t, input.Numbers, []float64{1.2, 3.5, 666.666}) + require.EqualValues(t, input.Terms, []string{"artist", "label"}) + + require.Nil(t, input.CivilPtrNil) + require.Nil(t, input.ConstPtrNil) + require.Nil(t, input.MongoIDPtrNil) + require.Nil(t, input.PBoolTwo) + require.Nil(t, input.TimePtrNil) + }) + t.Run("valid empty input", func(t *testing.T) { + var input mockListReq + err := UnmarshalReq( + events.APIGatewayProxyRequest{}, + false, + &input, + ) + require.NoError(t, err) }) t.Run("invalid path&query input", func(t *testing.T) { @@ -149,11 +237,9 @@ func Test_UnmarshalReq(t *testing.T) { false, &input, ) - assert.NotEqual(t, nil, err, "ErrorRes must not be nil") - var httpErr HTTPError - ok := errors.As(err, &httpErr) - assert.True(t, ok, "ErrorRes must be an response.HTTPError") - assert.Equal(t, http.StatusBadRequest, httpErr.Status, "ErrorRes code must be 400") + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "page")) + require.True(t, strings.Contains(err.Error(), "must be a valid integer")) }) fakeDate := time.Date(2020, 3, 23, 11, 33, 0, 0, time.UTC) @@ -172,10 +258,10 @@ func Test_UnmarshalReq(t *testing.T) { &input, ) - assert.Equal(t, nil, err, "ErrorRes must be nil") - assert.Equal(t, "bla", input.ID, "ID must be parsed from path parameters") - assert.Equal(t, "Fake Post", input.Name, "Name must be parsed from body") - assert.Equal(t, fakeDate, input.Date, "Date must be parsed from body") + require.Equal(t, nil, err, "ErrorRes must be nil") + require.Equal(t, "bla", input.ID, "ID must be parsed from path parameters") + require.Equal(t, "Fake Post", input.Name, "Name must be parsed from body") + require.Equal(t, fakeDate, input.Date, "Date must be parsed from body") }) t.Run("invalid body input, not base64", func(t *testing.T) { @@ -203,9 +289,9 @@ func Test_UnmarshalReq(t *testing.T) { &input, ) - assert.Equal(t, nil, err, "ErrorRes must be nil") - assert.Equal(t, "Fake Post", input.Name, "Name must be parsed from body") - assert.Equal(t, fakeDate, input.Date, "Date must be parsed from body") + require.Equal(t, nil, err, "ErrorRes must be nil") + require.Equal(t, "Fake Post", input.Name, "Name must be parsed from body") + require.Equal(t, fakeDate, input.Date, "Date must be parsed from body") }) t.Run("invalid body input, base64", func(t *testing.T) { diff --git a/lambda_router/http_test.go b/lambda_router/http_test.go index 583be11..23a2231 100644 --- a/lambda_router/http_test.go +++ b/lambda_router/http_test.go @@ -2,12 +2,11 @@ package lambda_router import ( "encoding/json" + "github.com/stretchr/testify/require" "net/http" "net/http/httptest" "testing" "time" - - "github.com/jgroeneveld/trial/assert" ) func TestHTTPHandler(t *testing.T) { @@ -29,9 +28,9 @@ func TestHTTPHandler(t *testing.T) { nil, ) - assert.Equal(t, nil, err, "ErrorRes must not be nil") - assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "Status code must be 401") - assert.True(t, len(testLog) > 0, "Log must have items") + require.Equal(t, nil, err, "ErrorRes must not be nil") + require.Equal(t, http.StatusUnauthorized, res.StatusCode, "Status code must be 401") + require.True(t, len(testLog) > 0, "Log must have items") }) t.Run("POST /api with auth", func(t *testing.T) { @@ -47,15 +46,15 @@ func TestHTTPHandler(t *testing.T) { req.Header.Set("Authorization", "Bearer fake-token") res, err := http.DefaultClient.Do(req) - assert.Equal(t, nil, err, "ErrorRes must not be nil") - assert.Equal(t, http.StatusBadRequest, res.StatusCode, "Status code must be 400") + require.Equal(t, nil, err, "ErrorRes must not be nil") + require.Equal(t, http.StatusBadRequest, res.StatusCode, "Status code must be 400") }) t.Run("GET /api", func(t *testing.T) { res, err := http.Get(ts.URL + "/api") - assert.Equal(t, nil, err, "ErrorRes must not be nil") - assert.Equal(t, http.StatusOK, res.StatusCode, "Status code must be 200") - assert.True(t, len(testLog) > 0, "Log must have items") + require.Equal(t, nil, err, "ErrorRes must not be nil") + require.Equal(t, http.StatusOK, res.StatusCode, "Status code must be 200") + require.True(t, len(testLog) > 0, "Log must have items") }) t.Run("GET /api/something/stuff", func(t *testing.T) { @@ -67,13 +66,13 @@ func TestHTTPHandler(t *testing.T) { req.Header.Set("Accept-Language", "en-us") res, err := http.DefaultClient.Do(req) - assert.Equal(t, nil, err, "Response error must be nil") - assert.Equal(t, http.StatusOK, res.StatusCode, "Status code must be 200") + require.Equal(t, nil, err, "Response error must be nil") + require.Equal(t, http.StatusOK, res.StatusCode, "Status code must be 200") var data []mockItem err = json.NewDecoder(res.Body).Decode(&data) - assert.Equal(t, nil, err, "Decode error must be nil") - assert.DeepEqual( + require.Equal(t, nil, err, "Decode error must be nil") + require.EqualValues( t, []mockItem{ { diff --git a/lambda_router/response_test.go b/lambda_router/response_test.go index b3d8050..941980d 100644 --- a/lambda_router/response_test.go +++ b/lambda_router/response_test.go @@ -3,11 +3,10 @@ package lambda_router import ( "encoding/base64" "errors" + "github.com/stretchr/testify/require" "net/http" "os" "testing" - - "github.com/jgroeneveld/trial/assert" ) type customStruct struct { @@ -32,41 +31,41 @@ func TestCustomRes(t *testing.T) { } res, err := CustomRes(httpStatus, headers, structValue) - assert.Nil(t, err) + require.Nil(t, err) t.Run("verify CustomRes returns the struct in the response body", func(t *testing.T) { var returnedStruct customStruct err = UnmarshalRes(res, &returnedStruct) - assert.Nil(t, err) + require.Nil(t, err) - assert.Equal(t, structValue, returnedStruct) + require.Equal(t, structValue, returnedStruct) }) t.Run("verify CustomRes returns the key value pair in the response headers", func(t *testing.T) { - assert.Equal(t, res.Headers["key"], headers["key"]) + require.Equal(t, res.Headers["key"], headers["key"]) }) t.Run("verify CustomRes returns the correct status code", func(t *testing.T) { - assert.Equal(t, httpStatus, res.StatusCode) + require.Equal(t, httpStatus, res.StatusCode) }) t.Run("verify CustomRes returns CORS headers", func(t *testing.T) { - assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "headers-header-val") - assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "methods-header-val") - assert.Equal(t, res.Headers[CORSOriginHeaderKey], "origin-header-val") + require.Equal(t, res.Headers[CORSHeadersHeaderKey], "headers-header-val") + require.Equal(t, res.Headers[CORSMethodsHeaderKey], "methods-header-val") + require.Equal(t, res.Headers[CORSOriginHeaderKey], "origin-header-val") }) } func TestEmptyRes(t *testing.T) { res, err := EmptyRes() - assert.Equal(t, http.StatusOK, res.StatusCode) - assert.Nil(t, err) - assert.Equal(t, "{}", res.Body) + require.Equal(t, http.StatusOK, res.StatusCode) + require.Nil(t, err) + require.Equal(t, "{}", res.Body) t.Run("verify EmptyRes returns the correct status code", func(t *testing.T) { - assert.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, http.StatusOK, res.StatusCode) }) t.Run("verify EmptyRes returns CORS headers", func(t *testing.T) { - assert.Equal(t, res.Headers[CORSHeadersEnvKey], "") - assert.Equal(t, res.Headers[CORSMethodsEnvKey], "") - assert.Equal(t, res.Headers[CORSOriginEnvKey], "") + require.Equal(t, res.Headers[CORSHeadersEnvKey], "") + require.Equal(t, res.Headers[CORSMethodsEnvKey], "") + require.Equal(t, res.Headers[CORSOriginEnvKey], "") }) } @@ -83,12 +82,12 @@ func TestErrorRes(t *testing.T) { Status: http.StatusBadRequest, Message: "Invalid input", }) - assert.Equal(t, http.StatusBadRequest, res.StatusCode, "status status must be correct") - assert.Equal(t, `{"status":400,"message":"Invalid input"}`, res.Body, "body must be correct") + require.Equal(t, http.StatusBadRequest, res.StatusCode, "status status must be correct") + require.Equal(t, `{"status":400,"message":"Invalid input"}`, res.Body, "body must be correct") t.Run("verify ErrorRes returns CORS headers", func(t *testing.T) { - assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "*") - assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "*") - assert.Equal(t, res.Headers[CORSOriginHeaderKey], "*") + require.Equal(t, res.Headers[CORSHeadersHeaderKey], "*") + require.Equal(t, res.Headers[CORSMethodsHeaderKey], "*") + require.Equal(t, res.Headers[CORSOriginHeaderKey], "*") }) }) t.Run("Handle an HTTPError for ErrorRes when ExposeServerErrors is true", func(t *testing.T) { @@ -97,8 +96,8 @@ func TestErrorRes(t *testing.T) { Status: http.StatusInternalServerError, Message: "database down", }) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode, "status must be correct") - assert.Equal(t, `{"status":500,"message":"database down"}`, res.Body, "body must be correct") + require.Equal(t, http.StatusInternalServerError, res.StatusCode, "status must be correct") + require.Equal(t, `{"status":500,"message":"database down"}`, res.Body, "body must be correct") }) t.Run("Handle an HTTPError for ErrorRes when ExposeServerErrors is false", func(t *testing.T) { ExposeServerErrors = false @@ -106,20 +105,20 @@ func TestErrorRes(t *testing.T) { Status: http.StatusInternalServerError, Message: "database down", }) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode, "status must be correct") - assert.Equal(t, `{"status":500,"message":"Internal Server Error"}`, res.Body, "body must be correct") + require.Equal(t, http.StatusInternalServerError, res.StatusCode, "status must be correct") + require.Equal(t, `{"status":500,"message":"Internal Server Error"}`, res.Body, "body must be correct") }) t.Run("Handle a general error for ErrorRes when ExposeServerErrors is true", func(t *testing.T) { ExposeServerErrors = true res, _ := ErrorRes(errors.New("database down")) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode, "status must be correct") - assert.Equal(t, `{"status":500,"message":"database down"}`, res.Body, "body must be correct") + require.Equal(t, http.StatusInternalServerError, res.StatusCode, "status must be correct") + require.Equal(t, `{"status":500,"message":"database down"}`, res.Body, "body must be correct") }) t.Run("Handle a general error for ErrorRes when ExposeServerErrors is false", func(t *testing.T) { ExposeServerErrors = false res, _ := ErrorRes(errors.New("database down")) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode, "status must be correct") - assert.Equal(t, `{"status":500,"message":"Internal Server Error"}`, res.Body, "body must be correct") + require.Equal(t, http.StatusInternalServerError, res.StatusCode, "status must be correct") + require.Equal(t, `{"status":500,"message":"Internal Server Error"}`, res.Body, "body must be correct") }) } @@ -136,24 +135,24 @@ header1, header2 value1, value2 ` res, err := FileRes("text/csv", map[string]string{"key": "value"}, []byte(csvContent)) - assert.Nil(t, err) + require.Nil(t, err) t.Run("verify FileRes returns the correct status code", func(t *testing.T) { - assert.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, http.StatusOK, res.StatusCode) }) t.Run("verify FileRes marks the response as NOT base64 encoded", func(t *testing.T) { - assert.False(t, res.IsBase64Encoded) + require.False(t, res.IsBase64Encoded) }) t.Run("verify FileRes embeds the bytes correctly in the response object as a string", func(t *testing.T) { - assert.Equal(t, csvContent, res.Body) + require.Equal(t, csvContent, res.Body) }) t.Run("verify FileRes preserves the original header values", func(t *testing.T) { - assert.Equal(t, "value", res.Headers["key"]) + require.Equal(t, "value", res.Headers["key"]) }) t.Run("verify FileRes returns CORS headers", func(t *testing.T) { - assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "*") - assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "*") - assert.Equal(t, res.Headers[CORSOriginHeaderKey], "*") + require.Equal(t, res.Headers[CORSHeadersHeaderKey], "*") + require.Equal(t, res.Headers[CORSMethodsHeaderKey], "*") + require.Equal(t, res.Headers[CORSOriginHeaderKey], "*") }) } @@ -170,28 +169,28 @@ header1, header2 value1, value2 ` res, err := FileB64Res("text/csv", map[string]string{"key": "value"}, []byte(csvContent)) - assert.Nil(t, err) + require.Nil(t, err) t.Run("verify FileB64Res returns the correct status code", func(t *testing.T) { - assert.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, http.StatusOK, res.StatusCode) }) t.Run("verify FileB64Res marks the response as base64 encoded", func(t *testing.T) { - assert.True(t, res.IsBase64Encoded) + require.True(t, res.IsBase64Encoded) }) t.Run("verify FileB64Res embeds the bytes correctly in the response object as a byte64 encoded string", func(t *testing.T) { b64CSVContent := res.Body decodedCSVContent, decodeErr := base64.StdEncoding.DecodeString(b64CSVContent) - assert.Nil(t, decodeErr) - assert.Equal(t, csvContent, string(decodedCSVContent)) + require.Nil(t, decodeErr) + require.Equal(t, csvContent, string(decodedCSVContent)) }) t.Run("verify FileRes preserves the original header values", func(t *testing.T) { - assert.Equal(t, "value", res.Headers["key"]) + require.Equal(t, "value", res.Headers["key"]) }) t.Run("verify FileB64Res returns CORS headers", func(t *testing.T) { - assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "*") - assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "*") - assert.Equal(t, res.Headers[CORSOriginHeaderKey], "*") + require.Equal(t, res.Headers[CORSHeadersHeaderKey], "*") + require.Equal(t, res.Headers[CORSMethodsHeaderKey], "*") + require.Equal(t, res.Headers[CORSOriginHeaderKey], "*") }) } @@ -205,15 +204,15 @@ func TestStatusAndErrorRes(t *testing.T) { newErr := errors.New("hello there") res, err := StatusAndErrorRes(http.StatusTeapot, newErr) - assert.Nil(t, err) + require.Nil(t, err) t.Run("verify StatusAndErrorRes returns the correct status code", func(t *testing.T) { - assert.Equal(t, http.StatusTeapot, res.StatusCode) + require.Equal(t, http.StatusTeapot, res.StatusCode) }) t.Run("verify StatusAndErrorRes returns CORS headers", func(t *testing.T) { - assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "*") - assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "*") - assert.Equal(t, res.Headers[CORSOriginHeaderKey], "*") + require.Equal(t, res.Headers[CORSHeadersHeaderKey], "*") + require.Equal(t, res.Headers[CORSMethodsHeaderKey], "*") + require.Equal(t, res.Headers[CORSOriginHeaderKey], "*") }) } @@ -227,19 +226,19 @@ func TestSuccessRes(t *testing.T) { cs := customStruct{StructKey: "hello there"} res, err := SuccessRes(cs) - assert.Nil(t, err) + require.Nil(t, err) t.Run("verify SuccessRes returns the correct status code", func(t *testing.T) { - assert.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, http.StatusOK, res.StatusCode) }) t.Run("verify SuccessRes returns the struct in the response body", func(t *testing.T) { var returnedStruct customStruct unmarshalErr := UnmarshalRes(res, &returnedStruct) - assert.Nil(t, unmarshalErr) - assert.Equal(t, cs, returnedStruct) + require.Nil(t, unmarshalErr) + require.Equal(t, cs, returnedStruct) }) t.Run("verify SuccessRes returns CORS headers", func(t *testing.T) { - assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "*") - assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "*") - assert.Equal(t, res.Headers[CORSOriginHeaderKey], "*") + require.Equal(t, res.Headers[CORSHeadersHeaderKey], "*") + require.Equal(t, res.Headers[CORSMethodsHeaderKey], "*") + require.Equal(t, res.Headers[CORSOriginHeaderKey], "*") }) } diff --git a/lambda_router/router_test.go b/lambda_router/router_test.go index 4c491ff..7242b4d 100644 --- a/lambda_router/router_test.go +++ b/lambda_router/router_test.go @@ -4,13 +4,13 @@ import ( "context" "errors" "fmt" + "github.com/stretchr/testify/require" "net/http" "strings" "testing" "time" "github.com/aws/aws-lambda-go/events" - "github.com/jgroeneveld/trial/assert" ) var testLog []string @@ -26,34 +26,34 @@ func TestRouter(t *testing.T) { t.Run("Routes created correctly", func(t *testing.T) { t.Run("/", func(t *testing.T) { route, ok := lmd.routes["/"] - assert.True(t, ok, "Route must be created") + require.True(t, ok, "Route must be created") if ok { - assert.Equal(t, `^/api$`, route.re.String(), "Regex must be correct") - assert.NotEqual(t, nil, route.methods[http.MethodGet], "GET method must exist") - assert.NotEqual(t, nil, route.methods[http.MethodPost], "POST method must exist") - assert.NotEqual(t, nil, route.methods[http.MethodOptions], "OPTIONS method must exist") // auto generated for CORS support + require.Equal(t, `^/api$`, route.re.String(), "Regex must be correct") + require.NotEqual(t, nil, route.methods[http.MethodGet], "GET method must exist") + require.NotEqual(t, nil, route.methods[http.MethodPost], "POST method must exist") + require.NotEqual(t, nil, route.methods[http.MethodOptions], "OPTIONS method must exist") // auto generated for CORS support } }) t.Run("/:id", func(t *testing.T) { route, ok := lmd.routes["/:id"] - assert.True(t, ok, "Route must be created") + require.True(t, ok, "Route must be created") if ok { - assert.Equal(t, `^/api/([^/]+)$`, route.re.String(), "Regex must be correct") - assert.NotEqual(t, nil, route.methods[http.MethodGet], "GET method must exist") - assert.NotEqual(t, nil, route.methods[http.MethodOptions], "OPTIONS method must exist") // auto generated for CORS support + require.Equal(t, `^/api/([^/]+)$`, route.re.String(), "Regex must be correct") + require.NotEqual(t, nil, route.methods[http.MethodGet], "GET method must exist") + require.NotEqual(t, nil, route.methods[http.MethodOptions], "OPTIONS method must exist") // auto generated for CORS support } }) t.Run("/:id/stuff/:fake", func(t *testing.T) { route, ok := lmd.routes["/:id/stuff/:fake"] - assert.True(t, ok, "Route must be created") + require.True(t, ok, "Route must be created") if ok { - assert.Equal( + require.Equal( t, `^/api/([^/]+)/stuff/([^/]+)$`, route.re.String(), "Regex must be correct", ) - assert.DeepEqual( + require.EqualValues( t, []string{"id", "fake"}, route.paramNames, @@ -70,7 +70,7 @@ func TestRouter(t *testing.T) { Path: "/api", } _, err := lmd.matchReq(&req) - assert.Equal(t, nil, err, "ErrorRes must be nil") + require.Equal(t, nil, err, "ErrorRes must be nil") }) t.Run("POST /api/", func(t *testing.T) { @@ -80,7 +80,7 @@ func TestRouter(t *testing.T) { Path: "/api/", } _, err := lmd.matchReq(&req) - assert.Equal(t, nil, err, "ErrorRes must be nil") + require.Equal(t, nil, err, "ErrorRes must be nil") }) t.Run("DELETE /api", func(t *testing.T) { @@ -89,11 +89,11 @@ func TestRouter(t *testing.T) { Path: "/api", } _, err := lmd.matchReq(&req) - assert.NotEqual(t, nil, err, "ErrorRes must not be nil") + require.NotEqual(t, nil, err, "ErrorRes must not be nil") var httpErr HTTPError ok := errors.As(err, &httpErr) - assert.True(t, ok, "ErrorRes must be an HTTP error") - assert.Equal(t, http.StatusMethodNotAllowed, httpErr.Status, "ErrorRes code must be 405") + require.True(t, ok, "ErrorRes must be an HTTP error") + require.Equal(t, http.StatusMethodNotAllowed, httpErr.Status, "ErrorRes code must be 405") }) t.Run("GET /api/fake-id", func(t *testing.T) { @@ -102,8 +102,8 @@ func TestRouter(t *testing.T) { Path: "/api/fake-id", } _, err := lmd.matchReq(&req) - assert.Equal(t, nil, err, "ErrorRes must be nil") - assert.Equal(t, "fake-id", req.PathParameters["id"], "ID must be correct") + require.Equal(t, nil, err, "ErrorRes must be nil") + require.Equal(t, "fake-id", req.PathParameters["id"], "ID must be correct") }) t.Run("GET /api/fake-id/bla", func(t *testing.T) { @@ -112,11 +112,11 @@ func TestRouter(t *testing.T) { Path: "/api/fake-id/bla", } _, err := lmd.matchReq(&req) - assert.NotEqual(t, nil, err, "ErrorRes must not be nil") + require.NotEqual(t, nil, err, "ErrorRes must not be nil") var httpErr HTTPError ok := errors.As(err, &httpErr) - assert.True(t, ok, "ErrorRes must be an HTTP error") - assert.Equal(t, http.StatusNotFound, httpErr.Status, "ErrorRes code must be 404") + require.True(t, ok, "ErrorRes must be an HTTP error") + require.Equal(t, http.StatusNotFound, httpErr.Status, "ErrorRes code must be 404") }) t.Run("GET /api/fake-id/stuff/faked-fake", func(t *testing.T) { @@ -125,9 +125,9 @@ func TestRouter(t *testing.T) { Path: "/api/fake-id/stuff/faked-fake", } _, err := lmd.matchReq(&req) - assert.Equal(t, nil, err, "ErrorRes must be nil") - assert.Equal(t, "fake-id", req.PathParameters["id"], "'id' must be correct") - assert.Equal(t, "faked-fake", req.PathParameters["fake"], "'fake' must be correct") + require.Equal(t, nil, err, "ErrorRes must be nil") + require.Equal(t, "fake-id", req.PathParameters["id"], "'id' must be correct") + require.Equal(t, "faked-fake", req.PathParameters["fake"], "'fake' must be correct") }) }) @@ -138,10 +138,10 @@ func TestRouter(t *testing.T) { Path: "/api", } res, err := lmd.Handler(context.Background(), req) - assert.Equal(t, nil, err, "ErrorRes must not be nil") - assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "Status code must be 401") - assert.True(t, len(testLog) > 0, "Log must have items") - assert.Equal( + require.Equal(t, nil, err, "ErrorRes must not be nil") + require.Equal(t, http.StatusUnauthorized, res.StatusCode, "Status code must be 401") + require.True(t, len(testLog) > 0, "Log must have items") + require.Equal( t, "[ERR] [POST /api] [401]", testLog[len(testLog)-1], @@ -158,8 +158,8 @@ func TestRouter(t *testing.T) { }, } res, err := lmd.Handler(context.Background(), req) - assert.Equal(t, nil, err, "ErrorRes must not be nil") - assert.Equal(t, http.StatusBadRequest, res.StatusCode, "Status code must be 400") + require.Equal(t, nil, err, "ErrorRes must not be nil") + require.Equal(t, http.StatusBadRequest, res.StatusCode, "Status code must be 400") }) t.Run("GET /api", func(t *testing.T) { @@ -168,10 +168,10 @@ func TestRouter(t *testing.T) { Path: "/api", } res, err := lmd.Handler(context.Background(), req) - assert.Equal(t, nil, err, "ErrorRes must not be nil") - assert.Equal(t, http.StatusOK, res.StatusCode, "Status code must be 200") - assert.True(t, len(testLog) > 0, "Log must have items") - assert.Equal( + require.Equal(t, nil, err, "ErrorRes must not be nil") + require.Equal(t, http.StatusOK, res.StatusCode, "Status code must be 200") + require.True(t, len(testLog) > 0, "Log must have items") + require.Equal( t, "[INF] [GET /api] [200]", testLog[len(testLog)-1], @@ -207,20 +207,20 @@ func TestRouter(t *testing.T) { HTTPMethod: http.MethodPost, Path: "/foo/bar", }) - assert.Equal(t, "/foo/bar", res.Body, "req must match /foo/bar route") + require.Equal(t, "/foo/bar", res.Body, "req must match /foo/bar route") } res, _ := router.Handler(context.Background(), events.APIGatewayProxyRequest{ HTTPMethod: http.MethodDelete, Path: "/foo/bar", }) - assert.Equal(t, http.StatusMethodNotAllowed, res.StatusCode, "Status code must be 405") + require.Equal(t, http.StatusMethodNotAllowed, res.StatusCode, "Status code must be 405") res, _ = router.Handler(context.Background(), events.APIGatewayProxyRequest{ HTTPMethod: http.MethodGet, Path: "/foo/bar2", }) - assert.Equal(t, "/foo/:id", res.Body, "Body must match") + require.Equal(t, "/foo/:id", res.Body, "Body must match") }) }