diff --git a/gateway/mw_url_rewrite.go b/gateway/mw_url_rewrite.go index b3e75f37505..80a24537d22 100644 --- a/gateway/mw_url_rewrite.go +++ b/gateway/mw_url_rewrite.go @@ -2,7 +2,7 @@ package gateway import ( "fmt" - "io/ioutil" + "io" "net/http" "net/textproto" "net/url" @@ -692,22 +692,38 @@ func checkContextTrigger(r *http.Request, options map[string]apidef.StringRegexM func checkPayload(r *http.Request, options apidef.StringRegexMap, triggernum int) bool { contextData := ctxGetData(r) - bodyBytes, _ := ioutil.ReadAll(r.Body) + nopCloseRequestBody(r) + // Read the entire request body + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + log.WithError(err).Error("error reading request body") + return false + } + + // Perform regex matching on the request body matched, matches := options.FindAllStringSubmatch(string(bodyBytes), -1) if matched { kn := buildTriggerKey(triggernum, "payload") + if len(matches) == 0 { + // If there are no matches, simply return true return true } + + // Store the first match in the context data contextData[kn] = matches[0][0] + // Iterate over all matches and add them to the context data for i, match := range matches { if len(match) > 0 { addMatchToContextData(contextData, match, triggernum, "payload", i) } } + + // Update the context data with the modified map + ctxSetData(r, contextData) return true } diff --git a/gateway/mw_url_rewrite_test.go b/gateway/mw_url_rewrite_test.go index 1da2fbc3072..d3f1e03a467 100644 --- a/gateway/mw_url_rewrite_test.go +++ b/gateway/mw_url_rewrite_test.go @@ -2,6 +2,7 @@ package gateway import ( "bytes" + "io" "net/http" "net/http/httptest" "testing" @@ -158,11 +159,12 @@ func BenchmarkRewriter(b *testing.B) { func TestRewriterTriggers(t *testing.T) { type TestDef struct { - name string - pattern, to string - in, want string - triggerConf []apidef.RoutingTrigger - req *http.Request + name string + pattern, to string + in, want string + triggerConf []apidef.RoutingTrigger + req *http.Request + payloadTrigger bool } ts := StartTest(nil) @@ -194,6 +196,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -220,6 +223,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -250,6 +254,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -280,6 +285,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -310,6 +316,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -355,6 +362,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -385,6 +393,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -415,6 +424,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -439,6 +449,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -463,6 +474,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -490,6 +502,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -520,6 +533,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -550,6 +564,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -584,6 +599,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -618,6 +634,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -652,6 +669,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -675,6 +693,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + true, } }, func() TestDef { @@ -698,6 +717,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + true, } }, func() TestDef { @@ -721,6 +741,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + true, } }, func() TestDef { @@ -744,6 +765,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + true, } }, func() TestDef { @@ -773,6 +795,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + true, } }, func() TestDef { @@ -802,6 +825,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + true, } }, func() TestDef { @@ -825,6 +849,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -854,6 +879,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + true, } }, func() TestDef { @@ -883,6 +909,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + true, } }, func() TestDef { @@ -906,6 +933,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -929,6 +957,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -958,6 +987,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -987,6 +1017,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -1017,6 +1048,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -1044,6 +1076,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -1073,6 +1106,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, func() TestDef { @@ -1096,6 +1130,7 @@ func TestRewriterTriggers(t *testing.T) { }, }, r, + false, } }, } @@ -1113,6 +1148,14 @@ func TestRewriterTriggers(t *testing.T) { if err != nil { t.Error("compile failed:", err) } + + //added check to ensure that reading the payload to check for the trigger does not break the request + if tc.payloadTrigger { + body, err := io.ReadAll(tc.req.Body) + assert.NotEqual(t, "", string(body)) + assert.NoError(t, err) + } + if got != tc.want { t.Errorf("rewrite failed, want %q, got %q", tc.want, got) } diff --git a/gateway/mw_validate_json.go b/gateway/mw_validate_json.go index c0a9b172437..87f8e4de3c6 100644 --- a/gateway/mw_validate_json.go +++ b/gateway/mw_validate_json.go @@ -3,7 +3,7 @@ package gateway import ( "errors" "fmt" - "io/ioutil" + "io" "net/http" "github.com/TykTechnologies/gojsonschema" @@ -48,8 +48,9 @@ func (k *ValidateJSON) ProcessRequest(w http.ResponseWriter, r *http.Request, _ } } + nopCloseRequestBody(r) // Load input body into gojsonschema - bodyBytes, err := ioutil.ReadAll(r.Body) + bodyBytes, err := io.ReadAll(r.Body) if err != nil { return err, http.StatusBadRequest }