Skip to content

Commit

Permalink
support shortcircuit in ResponseModifier
Browse files Browse the repository at this point in the history
Signed-off-by: Billy Zaelani Malik <[email protected]>
  • Loading branch information
minizilla committed Aug 27, 2024
1 parent fb531bf commit 80b1009
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 25 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ generate:
go build -buildmode=plugin -o ./transport/http/server/plugin/tests/lura-server-example.so ./transport/http/server/plugin/tests
go build -buildmode=plugin -o ./proxy/plugin/tests/lura-request-modifier-example.so ./proxy/plugin/tests/logger
go build -buildmode=plugin -o ./proxy/plugin/tests/lura-error-example.so ./proxy/plugin/tests/error
go build -buildmode=plugin -o ./proxy/plugin/tests/lura-shortcircuit-example.so ./proxy/plugin/tests/shortcircuit

test: generate
go test -cover -race ./...
Expand Down
62 changes: 39 additions & 23 deletions proxy/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,58 +107,74 @@ func newPluginMiddleware(logger logging.Logger, tag, pattern string, cfg map[str

if totRespModifiers == 0 {
return func(ctx context.Context, r *Request) (*Response, error) {
var err error
r, err = executeRequestModifiers(ctx, reqModifiers, r)
req, resp, err := executeRequestModifiers(ctx, reqModifiers, r)
if err != nil {
return nil, err
}

return next[0](ctx, r)
if resp != nil {
return resp, nil
}

return next[0](ctx, req)
}
}

return func(ctx context.Context, r *Request) (*Response, error) {
var err error
r, err = executeRequestModifiers(ctx, reqModifiers, r)
req, resp, err := executeRequestModifiers(ctx, reqModifiers, r)
if err != nil {
return nil, err
}

resp, err := next[0](ctx, r)
if err != nil {
return resp, err
if resp == nil {
var err error
resp, err = next[0](ctx, req)
if err != nil {
return resp, err
}
}

return executeResponseModifiers(ctx, respModifiers, resp, newRequestWrapper(ctx, r))
return executeResponseModifiers(ctx, respModifiers, resp, newRequestWrapper(ctx, req))
}
}
}

func executeRequestModifiers(ctx context.Context, reqModifiers []func(interface{}) (interface{}, error), r *Request) (*Request, error) {
func executeRequestModifiers(ctx context.Context, reqModifiers []func(interface{}) (interface{}, error), req *Request) (*Request, *Response, error) {
var tmp RequestWrapper
tmp = newRequestWrapper(ctx, r)
tmp = newRequestWrapper(ctx, req)
var resp *Response

for _, f := range reqModifiers {
res, err := f(tmp)
if err != nil {
return nil, err
return nil, nil, err
}
t, ok := res.(RequestWrapper)
if !ok {
switch t := res.(type) {
case RequestWrapper:
tmp = t
case ResponseWrapper:
resp = new(Response)
resp.Data = t.Data()
resp.IsComplete = t.IsComplete()
resp.Io = t.Io()
resp.Metadata = Metadata{}
resp.Metadata.Headers = t.Headers()
resp.Metadata.StatusCode = t.StatusCode()
break
default:
continue
}
tmp = t
}

r.Method = tmp.Method()
r.URL = tmp.URL()
r.Query = tmp.Query()
r.Path = tmp.Path()
r.Body = tmp.Body()
r.Params = tmp.Params()
r.Headers = tmp.Headers()
req.Method = tmp.Method()
req.URL = tmp.URL()
req.Query = tmp.Query()
req.Path = tmp.Path()
req.Body = tmp.Body()
req.Params = tmp.Params()
req.Headers = tmp.Headers()

return r, nil
return req, resp, nil
}

func executeResponseModifiers(ctx context.Context, respModifiers []func(interface{}) (interface{}, error), r *Response, req RequestWrapper) (*Response, error) {
Expand Down
4 changes: 2 additions & 2 deletions proxy/plugin/modifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func ExampleLoadWithLoggerAndContext() {
fmt.Println(err.Error())
return
}
if total != 2 {
if total != 3 {
fmt.Printf("unexpected number of loaded plugins!. have %d, want 2\n", total)
return
}
Expand Down Expand Up @@ -92,7 +92,7 @@ func TestLoad(t *testing.T) {
t.Error(err.Error())
t.Fail()
}
if total != 2 {
if total != 3 {
t.Errorf("unexpected number of loaded plugins!. have %d, want 2", total)
}

Expand Down
99 changes: 99 additions & 0 deletions proxy/plugin/tests/shortcircuit/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package main

import (
"context"
"errors"
"io"
"net/http"
"net/url"
"strings"
)

func main() {}

var ModifierRegisterer = registerer("lura-shortcircuit-example")

type registerer string

func (r registerer) RegisterModifiers(f func(
name string,
modifierFactory func(map[string]interface{}) func(interface{}) (interface{}, error),
appliesToRequest bool,
appliesToResponse bool,
),
) {
f(string(r)+"-request", r.requestModifierFactory, true, false)
f(string(r)+"-response", r.reqsponseModifierFactory, false, true)
}

func (r registerer) requestModifierFactory(_ map[string]interface{}) func(interface{}) (interface{}, error) {
return func(input interface{}) (interface{}, error) {
req, ok := input.(RequestWrapper)
if !ok {
return nil, unknownTypeErr
}

header := make(http.Header)
header.Add("X-Plugin-Request", "shortcircuit")
return responseWrapper{
request: req,
io: strings.NewReader("shortcircuit"),
headers: header,
statusCode: http.StatusTeapot,
}, nil
}
}

func (r registerer) reqsponseModifierFactory(_ map[string]interface{}) func(interface{}) (interface{}, error) {
return func(input interface{}) (interface{}, error) {
resp, ok := input.(ResponseWrapper)
if !ok {
return nil, unknownTypeErr
}

header := http.Header(resp.Headers())
header.Add("X-Plugin-Response", "shortcircuit")
return resp, nil
}
}

type responseWrapper struct {
ctx context.Context
request interface{}
data map[string]interface{}
isComplete bool
headers map[string][]string
statusCode int
io io.Reader
}

func (r responseWrapper) Context() context.Context { return r.ctx }
func (r responseWrapper) Request() interface{} { return r.request }
func (r responseWrapper) Data() map[string]interface{} { return r.data }
func (r responseWrapper) IsComplete() bool { return r.isComplete }
func (r responseWrapper) Io() io.Reader { return r.io }
func (r responseWrapper) Headers() map[string][]string { return r.headers }
func (r responseWrapper) StatusCode() int { return r.statusCode }

var unknownTypeErr = errors.New("unknown request type")

type RequestWrapper interface {
Context() context.Context
Params() map[string]string
Headers() map[string][]string
Body() io.ReadCloser
Method() string
URL() *url.URL
Query() url.Values
Path() string
}

type ResponseWrapper interface {
Context() context.Context
Request() interface{}
Data() map[string]interface{}
IsComplete() bool
Io() io.Reader
Headers() map[string][]string
StatusCode() int
}
97 changes: 97 additions & 0 deletions proxy/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,103 @@ func TestNewPluginMiddleware_error_response(t *testing.T) {
}
}

func TestNewPluginMiddleware_shortcircuit_request(t *testing.T) {
plugin.LoadWithLogger("./plugin/tests", ".so", plugin.RegisterModifier, logging.NoOp)

validator := func(ctx context.Context, r *Request) (*Response, error) {
t.Helper()
t.Error("the backend should not be called")
return nil, nil
}

bknd := NewBackendPluginMiddleware(
logging.NoOp,
&config.Backend{},
)(validator)

p := NewPluginMiddleware(
logging.NoOp,
&config.EndpointConfig{
ExtraConfig: map[string]interface{}{
plugin.Namespace: map[string]interface{}{
"name": []interface{}{
"lura-shortcircuit-example-request",
},
},
},
},
)(bknd)

resp, err := p(context.Background(), &Request{Path: "/bar"})
if err != nil {
t.Error(err.Error())
}

if resp == nil {
t.Errorf("unexpected response: %v", resp)
return
}

if sc := resp.Metadata.StatusCode; sc != http.StatusTeapot {
t.Errorf("unexpected status code: %d", sc)
}

header := http.Header(resp.Metadata.Headers)
if h := header.Get("X-Plugin-Request"); h != "shortcircuit" {
t.Errorf("unexpected header: %s", h)
}
}

func TestNewPluginMiddleware_shortcircuit_request_response(t *testing.T) {
plugin.LoadWithLogger("./plugin/tests", ".so", plugin.RegisterModifier, logging.NoOp)

validator := func(ctx context.Context, r *Request) (*Response, error) {
t.Error("the backend should not be called")
return nil, nil
}

bknd := NewBackendPluginMiddleware(
logging.NoOp,
&config.Backend{},
)(validator)

p := NewPluginMiddleware(
logging.NoOp,
&config.EndpointConfig{
ExtraConfig: map[string]interface{}{
plugin.Namespace: map[string]interface{}{
"name": []interface{}{
"lura-shortcircuit-example-request",
"lura-shortcircuit-example-response",
},
},
},
},
)(bknd)

resp, err := p(context.Background(), &Request{Path: "/bar"})
if err != nil {
t.Error(err.Error())
}

if resp == nil {
t.Errorf("unexpected response: %v", resp)
return
}

if sc := resp.Metadata.StatusCode; sc != http.StatusTeapot {
t.Errorf("unexpected status code: %d", sc)
}

header := http.Header(resp.Metadata.Headers)
if h := header.Get("X-Plugin-Request"); h != "shortcircuit" {
t.Errorf("unexpected header: %s", h)
}
if h := header.Get("X-Plugin-Response"); h != "shortcircuit" {
t.Errorf("unexpected header: %s", h)
}
}

func TestNewPluginMiddleware_PoisonedPlugin(t *testing.T) {
plugin.RegisterModifier("poisoned", func(map[string]interface{}) func(interface{}) (interface{}, error) {
return nil
Expand Down

0 comments on commit 80b1009

Please sign in to comment.