From 79a9536bc5533e1159be7c571892d06cd5e9f980 Mon Sep 17 00:00:00 2001 From: iamdjones <7071112+iamdjones@users.noreply.github.com> Date: Wed, 18 Dec 2024 20:10:30 -0600 Subject: [PATCH] add check for host rewrite to DynamicMiddleware (js plugin) --- gateway/mw_js_plugin.go | 33 ++++++++++++-- gateway/mw_js_plugin_test.go | 88 ++++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 4 deletions(-) diff --git a/gateway/mw_js_plugin.go b/gateway/mw_js_plugin.go index e7bfc470c5c..6b98a81d7a7 100644 --- a/gateway/mw_js_plugin.go +++ b/gateway/mw_js_plugin.go @@ -16,6 +16,7 @@ import ( "time" "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/ctx" "github.com/robertkrimen/otto" _ "github.com/robertkrimen/otto/underscore" @@ -243,10 +244,7 @@ func (d *DynamicMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Reques // make sure request's body can be re-read again nopCloseRequestBody(r) - r.URL, err = url.Parse(newRequestData.Request.URL) - if err != nil { - return nil, http.StatusOK - } + d.SetUrlAndCheckHostRewrite(newRequestData.Request.URL, r) ignoreCanonical := d.Gw.GetConfig().IgnoreCanonicalMIMEHeaderKey // Delete and set headers @@ -323,6 +321,33 @@ func mapStrsToIfaces(m map[string]string) map[string]interface{} { return m2 } +func (d *DynamicMiddleware) SetUrlAndCheckHostRewrite(newTarget string, r *http.Request) error { + + // During looping target can be API name + // Need make it compatible with URL parser + if strings.HasPrefix(newTarget, LoopScheme) { + newTarget = LoopHostRE.ReplaceAllStringFunc(newTarget, func(match string) string { + host := strings.TrimPrefix(match, LoopScheme+"://") + return LoopingUrl(host) + }) + } + + newAsURL, errParseNew := url.Parse(newTarget) + if errParseNew != nil { + return errParseNew + } + + if shouldRewriteHost(r.URL, newAsURL) { + log.Debug("Detected a host rewrite in pattern!") + d.Spec.URLRewriteEnabled = true + setCtxValue(r, ctx.RetainHost, true) + } + + r.URL = newAsURL + + return nil +} + // --- Utility functions during startup to ensure a sane VM is present for each API Def ---- type JSVM struct { diff --git a/gateway/mw_js_plugin_test.go b/gateway/mw_js_plugin_test.go index f9cfce13840..9fcd2e79bbe 100644 --- a/gateway/mw_js_plugin_test.go +++ b/gateway/mw_js_plugin_test.go @@ -1049,3 +1049,91 @@ func testJSVM_Auth(t *testing.T, hashKeys bool) { }, }...) } + +func TestDynamicMiddleware_SetUrlAndCheckHostRewrite(t *testing.T) { + ts := StartTest(nil) + defer ts.Close() + + type args struct { + oldPath string + newTarget string + } + + tests := []struct { + name string + args args + errExpected bool + retainHostVal interface{} + }{ + { + name: "no host rewrite", + args: args{ + oldPath: "/hello", + newTarget: "/status", + }, + errExpected: false, + retainHostVal: nil, + }, + { + name: "invalid new path", + args: args{ + oldPath: "/hello", + newTarget: "http:// example.com/status", + }, + errExpected: true, + retainHostVal: nil, + }, + { + name: "host rewrite", + args: args{ + oldPath: "/hello", + newTarget: "http://example.com/status", + }, + errExpected: false, + retainHostVal: true, + }, + { + name: "scheme in oldPath - host rewrite", + args: args{ + oldPath: "http://tyk-gateway/hello", + newTarget: "http://example.com/status", + }, + errExpected: false, + retainHostVal: true, + }, + { + name: "scheme in oldPath - no host rewrite", + args: args{ + oldPath: "http://tyk-gateway/hello", + newTarget: "/status", + }, + errExpected: false, + retainHostVal: nil, + }, + { + name: "same host for new and old URL", + args: args{ + oldPath: "http://tyk-gateway/hello", + newTarget: "http://tyk-gateway/status", + }, + errExpected: false, + retainHostVal: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ts := StartTest(nil) + defer ts.Close() + + m := &DynamicMiddleware{ + BaseMiddleware: &BaseMiddleware{ + Spec: &APISpec{APIDefinition: &apidef.APIDefinition{}}, + Gw: ts.Gw, + }} + r := httptest.NewRequest("GET", tt.args.oldPath, nil) + err := m.SetUrlAndCheckHostRewrite(tt.args.newTarget, r) + assert.Equal(t, tt.errExpected, err != nil) + assert.Equal(t, tt.retainHostVal, r.Context().Value(ctx.RetainHost)) + }) + } +}