Skip to content

Commit

Permalink
Implement net/http.Handler for local development
Browse files Browse the repository at this point in the history
This commit modifies the lmdrouter.Router type to implement the
net/http.Handler interface (meaning the ServeHTTP method is now
available on *lmdrouter.Route objects).

This allows usage of applications using lmdrouter in environments other
than AWS Lambda, but is mostly useful for local development purposes.

It should be noted that this means applications will now have to use the
`netgo` build tag (`go build -tags netgo`) to make sure binaries are
still statically compiled.

Resolves: 1
  • Loading branch information
Ido Perlmuter committed Mar 9, 2022
1 parent d0b6229 commit f5483e1
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 14 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,17 @@ func loggerMiddleware(next lmdrouter.Handler) lmdrouter.Handler {
}
```

## Static Compilation for AWS Lambda

To ensure Lambda applications using lmdrouter (or any Lambda applications
written in Go, for that matter) will properly work in AWS's Go runtime, make
sure to compile your applications statically. You can either disable CGO
completely using `CGO_ENABLED=0`, or use the following build flags:

```sh
go build -tags netgo -ldflags "-s -w"
```

## License

This library is distributed under the terms of the [Apache License 2.0](LICENSE).
82 changes: 82 additions & 0 deletions http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package lmdrouter

import (
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/aws/aws-lambda-go/events"
)

// ServerHTTP implements the net/http.Handler interface in order to allow
// lmdrouter applications to be used outside of AWS Lambda environments, most
// likely for local development purposes
func (l *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// convert request into an events.APIGatewayProxyRequest object
singleValueHeaders := convertMap(map[string][]string(r.Header))
singleValueQuery := convertMap(
map[string][]string(r.URL.Query()),
)

body, err := io.ReadAll(r.Body)
if err != nil {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(500)
json.NewEncoder(w).Encode(map[string]interface{}{
"error": fmt.Sprintf("Failed reading request body: %s", err),
}) // nolint: errcheck
return
}

event := events.APIGatewayProxyRequest{
Path: r.URL.Path,
HTTPMethod: r.Method,
Headers: singleValueHeaders,
MultiValueHeaders: map[string][]string(r.Header),
QueryStringParameters: singleValueQuery,
MultiValueQueryStringParameters: map[string][]string(r.URL.Query()),
Body: string(body),
}

res, err := l.Handler(r.Context(), event)
if err != nil {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(500)
json.NewEncoder(w).Encode(map[string]interface{}{
"error": fmt.Sprintf("Failed executing handler: %s", err),
}) // nolint: errcheck
return
}

for header, values := range res.MultiValueHeaders {
for i, value := range values {
if i == 0 {
w.Header().Set(header, value)
} else {
w.Header().Add(header, value)
}
}
}

for header, value := range res.Headers {
if w.Header().Get(header) == "" {
w.Header().Set(header, value)
}
}

w.WriteHeader(res.StatusCode)
w.Write([]byte(res.Body)) // nolint: errcheck
}

func convertMap(in map[string][]string) map[string]string {
singleValue := make(map[string]string)

for key, value := range in {
if len(value) == 1 {
singleValue[key] = value[0]
}
}

return singleValue
}
99 changes: 99 additions & 0 deletions http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package lmdrouter

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/jgroeneveld/trial/assert"
)

func TestHTTPHandler(t *testing.T) {
lmd := NewRouter("/api", logger)
lmd.Route("GET", "/", listSomethings)
lmd.Route("POST", "/", postSomething, auth)
lmd.Route("GET", "/:id", getSomething)
lmd.Route("GET", "/:id/stuff", listStuff)
lmd.Route("GET", "/:id/stuff/:fake", listStuff)

ts := httptest.NewServer(http.HandlerFunc(lmd.ServeHTTP))

defer ts.Close()

t.Run("POST /api without auth", func(t *testing.T) {
res, err := http.Post(
ts.URL+"/api",
"application/json; charset=UTF-8",
nil,
)

assert.Equal(t, nil, err, "Error must not be nil")
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "Status code must be 401")
assert.True(t, len(log) > 0, "Log must have items")
})

t.Run("POST /api with auth", func(t *testing.T) {
req, err := http.NewRequest(
"POST",
ts.URL+"/api",
nil,
)
if err != nil {
t.Fatalf("Request creation unexpectedly failed: %s", err)
}

req.Header.Set("Authorization", "Bearer fake-token")

res, err := http.DefaultClient.Do(req)
assert.Equal(t, nil, err, "Error must not be nil")
assert.Equal(t, http.StatusBadRequest, res.StatusCode, "Status code must be 400")
})

t.Run("GET /api", func(t *testing.T) {
res, err := http.Get(ts.URL + "/api")
assert.Equal(t, nil, err, "Error must not be nil")
assert.Equal(t, http.StatusOK, res.StatusCode, "Status code must be 200")
assert.True(t, len(log) > 0, "Log must have items")
})

t.Run("GET /api/something/stuff", func(t *testing.T) {
req, _ := http.NewRequest(
"GET",
ts.URL+"/api/something/stuff?terms=one&terms=two&terms=three",
nil,
)
req.Header.Set("Accept-Language", "en-us")

res, err := http.DefaultClient.Do(req)
assert.Equal(t, nil, err, "Response error must be nil")
assert.Equal(t, http.StatusOK, res.StatusCode, "Status code must be 200")

var data []mockItem
err = json.NewDecoder(res.Body).Decode(&data)
assert.Equal(t, nil, err, "Decode error must be nil")
assert.DeepEqual(
t,
[]mockItem{
{
ID: "something",
Name: "one in en-us",
Date: time.Time{},
},
{
ID: "something",
Name: "two in en-us",
Date: time.Time{},
},
{
ID: "something",
Name: "three in en-us",
Date: time.Time{},
},
},
data,
"Response body must match",
)
})
}
3 changes: 3 additions & 0 deletions lmdrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@
// API Gateway response (only JSON responses are currently generated). See the
// MarshalResponse function for more information.
//
// * Implements net/http.Handler for local development and general usage outside
// of an AWS Lambda environment.
//
package lmdrouter

import (
Expand Down
62 changes: 48 additions & 14 deletions lmdrouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,18 @@ func TestRouter(t *testing.T) {
route, ok := lmd.routes["/:id/stuff/:fake"]
assert.True(t, ok, "Route must be created")
if ok {
assert.Equal(t, `^/api/([^/]+)/stuff/([^/]+)$`, route.re.String(), "Regex must be correct")
assert.DeepEqual(t, []string{"id", "fake"}, route.paramNames, "Param names must be correct")
assert.Equal(
t,
`^/api/([^/]+)/stuff/([^/]+)$`,
route.re.String(),
"Regex must be correct",
)
assert.DeepEqual(
t,
[]string{"id", "fake"},
route.paramNames,
"Param names must be correct",
)
}
})
})
Expand Down Expand Up @@ -131,7 +141,12 @@ func TestRouter(t *testing.T) {
assert.Equal(t, nil, err, "Error must not be nil")
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "Status code must be 401")
assert.True(t, len(log) > 0, "Log must have items")
assert.Equal(t, "[ERR] [POST /api] [401]", log[len(log)-1], "Last long line must be correct")
assert.Equal(
t,
"[ERR] [POST /api] [401]",
log[len(log)-1],
"Last long line must be correct",
)
})

t.Run("POST /api with auth", func(t *testing.T) {
Expand All @@ -147,7 +162,7 @@ func TestRouter(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, res.StatusCode, "Status code must be 400")
})

t.Run("GET /", func(t *testing.T) {
t.Run("GET /api", func(t *testing.T) {
req := events.APIGatewayProxyRequest{
HTTPMethod: "GET",
Path: "/api",
Expand All @@ -156,20 +171,33 @@ func TestRouter(t *testing.T) {
assert.Equal(t, nil, err, "Error must not be nil")
assert.Equal(t, http.StatusOK, res.StatusCode, "Status code must be 200")
assert.True(t, len(log) > 0, "Log must have items")
assert.Equal(t, "[INF] [GET /api] [200]", log[len(log)-1], "Last long line must be correct")
assert.Equal(
t,
"[INF] [GET /api] [200]",
log[len(log)-1],
"Last long line must be correct",
)
})
})

t.Run("Overlapping routes", func(t *testing.T) {
router := NewRouter("")
router.Route("GET", "/foo/:id", func(_ context.Context, _ events.APIGatewayProxyRequest) (res events.APIGatewayProxyResponse, err error) {
res.Body = "/foo/:id"
return res, nil
})
router.Route("POST", "/foo/bar", func(_ context.Context, _ events.APIGatewayProxyRequest) (res events.APIGatewayProxyResponse, err error) {
res.Body = "/foo/bar"
return res, nil
})
router.Route(
"GET",
"/foo/:id",
func(_ context.Context, _ events.APIGatewayProxyRequest) (res events.APIGatewayProxyResponse, err error) {
res.Body = "/foo/:id"
return res, nil
},
)
router.Route(
"POST",
"/foo/bar",
func(_ context.Context, _ events.APIGatewayProxyRequest) (res events.APIGatewayProxyResponse, err error) {
res.Body = "/foo/bar"
return res, nil
},
)

// call POST /foo/bar in a loop. We do this because the router iterates
// over a map to match routes, which is non-deterministic, meaning
Expand Down Expand Up @@ -270,7 +298,13 @@ func listStuff(ctx context.Context, req events.APIGatewayProxyRequest) (
return HandleError(err)
}

output := []mockItem{}
output := make([]mockItem, len(input.Terms))
for i, term := range input.Terms {
output[i] = mockItem{
ID: input.ID,
Name: fmt.Sprintf("%s in %s", term, input.Language),
}
}

return MarshalResponse(http.StatusOK, nil, output)
}
Expand Down

0 comments on commit f5483e1

Please sign in to comment.