Skip to content

Commit

Permalink
Merge pull request #4 from seantcanavan/do_not_default_cors
Browse files Browse the repository at this point in the history
default CORS to disabled for better security
  • Loading branch information
seantcanavan authored Jul 17, 2023
2 parents 96f4593 + 81aaced commit 381ef96
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ format:
gofmt -s -w -l .

test:
go test -v ./...
go test ./...
31 changes: 15 additions & 16 deletions lambda_router/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ import (
)

const ContentTypeKey = "Content-Type"
const CORSHeadersKey = "Access-Control-Allow-Headers"
const CORSMethodsKey = "Access-Control-Allow-Methods"
const CORSOriginKey = "Access-Control-Allow-Origin"
const CORSHeadersHeaderKey = "Access-Control-Allow-Headers"
const CORSMethodsHeaderKey = "Access-Control-Allow-Methods"
const CORSOriginHeaderKey = "Access-Control-Allow-Origin"
const CORSHeadersEnvKey = "LAMBDA_JWT_ROUTER_CORS_HEADERS"
const CORSMethodsEnvKey = "LAMBDA_JWT_ROUTER_CORS_METHODS"
const CORSOriginEnvKey = "LAMBDA_JWT_ROUTER_CORS_ORIGIN"

// ServerHTTP implements the net/http.Handler interface in order to allow
// lmdrouter applications to be used outside of AWS Lambda environments, most
Expand All @@ -28,26 +31,22 @@ func (l *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.URL.Query(),
)

corsHeaders := os.Getenv("LAMBDA_JWT_ROUTER_CORS_HEADERS")
corsMethods := os.Getenv("LAMBDA_JWT_ROUTER_CORS_METHODS")
corsOrigins := os.Getenv("LAMBDA_JWT_ROUTER_CORS_ORIGIN")
corsHeaders := os.Getenv(CORSHeadersEnvKey)
corsMethods := os.Getenv(CORSMethodsEnvKey)
corsOrigins := os.Getenv(CORSOriginEnvKey)

if corsHeaders == "" {
corsHeaders = "*"
if corsHeaders != "" {
w.Header().Set(CORSHeadersHeaderKey, corsHeaders)
}

if corsMethods == "" {
corsMethods = "*"
if corsMethods != "" {
w.Header().Set(CORSMethodsHeaderKey, corsMethods)
}

if corsOrigins == "" {
corsOrigins = "*"
if corsOrigins != "" {
w.Header().Set(CORSOriginHeaderKey, corsOrigins)
}

w.Header().Set(CORSHeadersKey, "*")
w.Header().Set(CORSMethodsKey, corsMethods)
w.Header().Set(CORSOriginKey, corsOrigins)

body, err := io.ReadAll(r.Body)
if err != nil {
w.Header().Set(ContentTypeKey, "application/json; charset=UTF-8")
Expand Down
22 changes: 9 additions & 13 deletions lambda_router/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,25 +148,21 @@ func (err HTTPError) Error() string {

// addCors injects CORS Origin and CORS Methods headers into the response object before it's returned.
func addCors(headers map[string]string) map[string]string {
corsHeaders := os.Getenv("LAMBDA_JWT_ROUTER_CORS_HEADERS")
corsMethods := os.Getenv("LAMBDA_JWT_ROUTER_CORS_METHODS")
corsOrigins := os.Getenv("LAMBDA_JWT_ROUTER_CORS_ORIGIN")
corsHeaders := os.Getenv(CORSHeadersEnvKey)
corsMethods := os.Getenv(CORSMethodsEnvKey)
corsOrigins := os.Getenv(CORSOriginEnvKey)

if corsHeaders == "" {
corsHeaders = "*"
if corsHeaders != "" {
headers[CORSHeadersHeaderKey] = corsHeaders
}

if corsMethods == "" {
corsMethods = "*"
if corsMethods != "" {
headers[CORSMethodsHeaderKey] = corsMethods
}

if corsOrigins == "" {
corsOrigins = "*"
if corsOrigins != "" {
headers[CORSOriginHeaderKey] = corsOrigins
}

headers[CORSHeadersKey] = corsHeaders
headers[CORSMethodsKey] = corsMethods
headers[CORSOriginKey] = corsOrigins

return headers
}
85 changes: 64 additions & 21 deletions lambda_router/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/base64"
"errors"
"net/http"
"os"
"testing"

"github.com/jgroeneveld/trial/assert"
Expand All @@ -14,6 +15,13 @@ type customStruct struct {
}

func TestCustomRes(t *testing.T) {
os.Setenv(CORSHeadersEnvKey, "headers-header-val")
defer os.Unsetenv(CORSHeadersEnvKey)
os.Setenv(CORSMethodsEnvKey, "methods-header-val")
defer os.Unsetenv(CORSMethodsEnvKey)
os.Setenv(CORSOriginEnvKey, "origin-header-val")
defer os.Unsetenv(CORSOriginEnvKey)

httpStatus := http.StatusTeapot
headers := map[string]string{
"key": "value",
Expand All @@ -40,9 +48,9 @@ func TestCustomRes(t *testing.T) {
assert.Equal(t, httpStatus, res.StatusCode)
})
t.Run("verify CustomRes returns CORS headers", func(t *testing.T) {
assert.Equal(t, res.Headers[CORSHeadersKey], "*")
assert.Equal(t, res.Headers[CORSMethodsKey], "*")
assert.Equal(t, res.Headers[CORSOriginKey], "*")
assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "headers-header-val")
assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "methods-header-val")
assert.Equal(t, res.Headers[CORSOriginHeaderKey], "origin-header-val")
})
}

Expand All @@ -56,13 +64,20 @@ func TestEmptyRes(t *testing.T) {
assert.Equal(t, http.StatusOK, res.StatusCode)
})
t.Run("verify EmptyRes returns CORS headers", func(t *testing.T) {
assert.Equal(t, res.Headers[CORSHeadersKey], "*")
assert.Equal(t, res.Headers[CORSMethodsKey], "*")
assert.Equal(t, res.Headers[CORSOriginKey], "*")
assert.Equal(t, res.Headers[CORSHeadersEnvKey], "")
assert.Equal(t, res.Headers[CORSMethodsEnvKey], "")
assert.Equal(t, res.Headers[CORSOriginEnvKey], "")
})
}

func TestErrorRes(t *testing.T) {
os.Setenv(CORSHeadersEnvKey, "*")
defer os.Unsetenv(CORSHeadersEnvKey)
os.Setenv(CORSMethodsEnvKey, "*")
defer os.Unsetenv(CORSMethodsEnvKey)
os.Setenv(CORSOriginEnvKey, "*")
defer os.Unsetenv(CORSOriginEnvKey)

t.Run("Handle an HTTPError ErrorRes without ExposeServerErrors set and verify CORS", func(t *testing.T) {
res, _ := ErrorRes(HTTPError{
Status: http.StatusBadRequest,
Expand All @@ -71,9 +86,9 @@ func TestErrorRes(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, res.StatusCode, "status status must be correct")
assert.Equal(t, `{"status":400,"message":"Invalid input"}`, res.Body, "body must be correct")
t.Run("verify ErrorRes returns CORS headers", func(t *testing.T) {
assert.Equal(t, res.Headers[CORSHeadersKey], "*")
assert.Equal(t, res.Headers[CORSMethodsKey], "*")
assert.Equal(t, res.Headers[CORSOriginKey], "*")
assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "*")
assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "*")
assert.Equal(t, res.Headers[CORSOriginHeaderKey], "*")
})
})
t.Run("Handle an HTTPError for ErrorRes when ExposeServerErrors is true", func(t *testing.T) {
Expand Down Expand Up @@ -109,6 +124,13 @@ func TestErrorRes(t *testing.T) {
}

func TestFileRes(t *testing.T) {
os.Setenv(CORSHeadersEnvKey, "*")
defer os.Unsetenv(CORSHeadersEnvKey)
os.Setenv(CORSMethodsEnvKey, "*")
defer os.Unsetenv(CORSMethodsEnvKey)
os.Setenv(CORSOriginEnvKey, "*")
defer os.Unsetenv(CORSOriginEnvKey)

csvContent := `
header1, header2
value1, value2
Expand All @@ -129,13 +151,20 @@ value1, value2
assert.Equal(t, "value", res.Headers["key"])
})
t.Run("verify FileRes returns CORS headers", func(t *testing.T) {
assert.Equal(t, res.Headers[CORSHeadersKey], "*")
assert.Equal(t, res.Headers[CORSMethodsKey], "*")
assert.Equal(t, res.Headers[CORSOriginKey], "*")
assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "*")
assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "*")
assert.Equal(t, res.Headers[CORSOriginHeaderKey], "*")
})
}

func TestFileB64Res(t *testing.T) {
os.Setenv(CORSHeadersEnvKey, "*")
defer os.Unsetenv(CORSHeadersEnvKey)
os.Setenv(CORSMethodsEnvKey, "*")
defer os.Unsetenv(CORSMethodsEnvKey)
os.Setenv(CORSOriginEnvKey, "*")
defer os.Unsetenv(CORSOriginEnvKey)

csvContent := `
header1, header2
value1, value2
Expand All @@ -160,13 +189,20 @@ value1, value2
assert.Equal(t, "value", res.Headers["key"])
})
t.Run("verify FileB64Res returns CORS headers", func(t *testing.T) {
assert.Equal(t, res.Headers[CORSHeadersKey], "*")
assert.Equal(t, res.Headers[CORSMethodsKey], "*")
assert.Equal(t, res.Headers[CORSOriginKey], "*")
assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "*")
assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "*")
assert.Equal(t, res.Headers[CORSOriginHeaderKey], "*")
})
}

func TestStatusAndErrorRes(t *testing.T) {
os.Setenv(CORSHeadersEnvKey, "*")
defer os.Unsetenv(CORSHeadersEnvKey)
os.Setenv(CORSMethodsEnvKey, "*")
defer os.Unsetenv(CORSMethodsEnvKey)
os.Setenv(CORSOriginEnvKey, "*")
defer os.Unsetenv(CORSOriginEnvKey)

newErr := errors.New("hello there")
res, err := StatusAndErrorRes(http.StatusTeapot, newErr)
assert.Nil(t, err)
Expand All @@ -175,13 +211,20 @@ func TestStatusAndErrorRes(t *testing.T) {
assert.Equal(t, http.StatusTeapot, res.StatusCode)
})
t.Run("verify StatusAndErrorRes returns CORS headers", func(t *testing.T) {
assert.Equal(t, res.Headers[CORSHeadersKey], "*")
assert.Equal(t, res.Headers[CORSMethodsKey], "*")
assert.Equal(t, res.Headers[CORSOriginKey], "*")
assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "*")
assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "*")
assert.Equal(t, res.Headers[CORSOriginHeaderKey], "*")
})
}

func TestSuccessRes(t *testing.T) {
os.Setenv(CORSHeadersEnvKey, "*")
defer os.Unsetenv(CORSHeadersEnvKey)
os.Setenv(CORSMethodsEnvKey, "*")
defer os.Unsetenv(CORSMethodsEnvKey)
os.Setenv(CORSOriginEnvKey, "*")
defer os.Unsetenv(CORSOriginEnvKey)

cs := customStruct{StructKey: "hello there"}
res, err := SuccessRes(cs)
assert.Nil(t, err)
Expand All @@ -195,8 +238,8 @@ func TestSuccessRes(t *testing.T) {
assert.Equal(t, cs, returnedStruct)
})
t.Run("verify SuccessRes returns CORS headers", func(t *testing.T) {
assert.Equal(t, res.Headers[CORSHeadersKey], "*")
assert.Equal(t, res.Headers[CORSMethodsKey], "*")
assert.Equal(t, res.Headers[CORSOriginKey], "*")
assert.Equal(t, res.Headers[CORSHeadersHeaderKey], "*")
assert.Equal(t, res.Headers[CORSMethodsHeaderKey], "*")
assert.Equal(t, res.Headers[CORSOriginHeaderKey], "*")
})
}

0 comments on commit 381ef96

Please sign in to comment.