diff --git a/ctx/ctx.go b/ctx/ctx.go index 875df130eb3..3d261d05540 100644 --- a/ctx/ctx.go +++ b/ctx/ctx.go @@ -50,6 +50,7 @@ const ( // CacheOptions holds cache options required for cache writer middleware. CacheOptions OASDefinition + SelfLooping ) func setContext(r *http.Request, ctx context.Context) { diff --git a/docs/diagrams/middleware-looping.png b/docs/diagrams/middleware-looping.png new file mode 100644 index 00000000000..93895b052cc Binary files /dev/null and b/docs/diagrams/middleware-looping.png differ diff --git a/gateway/api.go b/gateway/api.go index 780f504fa59..a9b4f2e1c26 100644 --- a/gateway/api.go +++ b/gateway/api.go @@ -43,6 +43,8 @@ import ( "sync" "time" + "github.com/TykTechnologies/tyk/internal/httpctx" + "github.com/getkin/kin-openapi/openapi3" "github.com/TykTechnologies/tyk/config" @@ -3218,6 +3220,11 @@ func ctxSetCheckLoopLimits(r *http.Request, b bool) { // Should we check Rate limits and Quotas? func ctxCheckLimits(r *http.Request) bool { + // If this is a self loop, do not need to check the limits and quotas. + if httpctx.IsSelfLooping(r) { + return false + } + // If looping disabled, allow all if !ctxLoopingEnabled(r) { return true diff --git a/gateway/api_loader.go b/gateway/api_loader.go index 8c0fe06bed2..aaf635001e4 100644 --- a/gateway/api_loader.go +++ b/gateway/api_loader.go @@ -27,6 +27,7 @@ import ( "github.com/TykTechnologies/tyk/storage" "github.com/TykTechnologies/tyk/trace" + "github.com/TykTechnologies/tyk/internal/httpctx" "github.com/TykTechnologies/tyk/internal/httputil" "github.com/TykTechnologies/tyk/internal/otel" ) @@ -589,6 +590,7 @@ func (d *DummyProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var handler http.Handler if r.URL.Hostname() == "self" { + httpctx.SetSelfLooping(r, true) if h, found := d.Gw.apisHandlesByID.Load(d.SH.Spec.APIID); found { if chain, ok := h.(*ChainObject); ok { handler = chain.ThisHandler diff --git a/gateway/looping_test.go b/gateway/looping_test.go index ea573350d80..835083b0eaa 100644 --- a/gateway/looping_test.go +++ b/gateway/looping_test.go @@ -6,6 +6,7 @@ package gateway import ( "encoding/json" + "net/http" "sync" "testing" @@ -324,7 +325,101 @@ func TestLooping(t *testing.T) { {Path: "/external/", Code: 200}, }...) }) +} + +func TestLooping_AnotherAPIWithAuthTokens(t *testing.T) { + ts := StartTest(nil) + defer ts.Close() + + // Looping to another api with auth tokens + specs := ts.Gw.BuildAndLoadAPI(func(spec *APISpec) { + spec.APIDefinition.APIID = "apia" + spec.APIDefinition.Name = "ApiA" + spec.APIDefinition.Proxy.ListenPath = "/apia" + spec.APIDefinition.UseKeylessAccess = false + spec.APIDefinition.AuthConfigs = map[string]apidef.AuthConfig{ + "authToken": { + AuthHeaderName: "Authorization", + }, + } + + UpdateAPIVersion(spec, "v1", func(v *apidef.VersionInfo) { + v.UseExtendedPaths = true + v.ExtendedPaths.URLRewrite = []apidef.URLRewriteMeta{{ + Path: "/", + Method: http.MethodGet, + MatchPattern: ".*", + RewriteTo: "tyk://apib", + }} + }) + }, func(spec *APISpec) { + spec.APIDefinition.APIID = "apib" + spec.APIDefinition.Name = "ApiB" + spec.APIDefinition.Proxy.ListenPath = "/apib" + spec.APIDefinition.UseKeylessAccess = false + spec.APIDefinition.AuthConfigs = map[string]apidef.AuthConfig{ + "authToken": { + AuthHeaderName: "X-Api-Key", + }, + } + }) + specApiA := specs[0] + specApiB := specs[1] + + _, authKeyForApiA := ts.CreateSession(func(s *user.SessionState) { + s.AccessRights = map[string]user.AccessDefinition{ + specApiA.APIDefinition.APIID: { + APIName: specApiA.APIDefinition.Name, + APIID: specApiA.APIDefinition.APIID, + Versions: []string{"default"}, + AllowanceScope: specApiA.APIDefinition.APIID, + }, + } + s.OrgID = specApiA.APIDefinition.OrgID + }) + _, authKeyForApiB := ts.CreateSession(func(s *user.SessionState) { + s.AccessRights = map[string]user.AccessDefinition{ + specApiB.APIDefinition.APIID: { + APIName: specApiB.APIDefinition.Name, + APIID: specApiB.APIDefinition.APIID, + Versions: []string{"default"}, + AllowanceScope: specApiB.APIDefinition.APIID, + }, + } + s.OrgID = specApiB.APIDefinition.OrgID + }) + + headersWithApiBToken := map[string]string{ + "Authorization": authKeyForApiA, + "X-Api-Key": authKeyForApiB, + } + headersWithoutApiBToken := map[string]string{ + "Authorization": authKeyForApiA, + "X-Api-Key": "some-string", + } + headersWithOnlyApiAToken := map[string]string{ + "Authorization": authKeyForApiA, + } + _, _ = ts.Run(t, []test.TestCase{ + { + Headers: headersWithApiBToken, + Path: "/apia", + Code: http.StatusOK, + }, + { + Headers: headersWithoutApiBToken, + Path: "/apia", + Code: http.StatusForbidden, + BodyMatch: "Access to this API has been disallowed", + }, + { + Headers: headersWithOnlyApiAToken, + Path: "/apia", + Code: http.StatusUnauthorized, + BodyMatch: "Authorization field missing", + }, + }...) } func TestConcurrencyReloads(t *testing.T) { diff --git a/gateway/middleware_test.go b/gateway/middleware_test.go index fa7eb5242fd..5941dc814cd 100644 --- a/gateway/middleware_test.go +++ b/gateway/middleware_test.go @@ -354,7 +354,8 @@ func TestQuotaNotAppliedWithURLRewrite(t *testing.T) { spec := ts.Gw.BuildAndLoadAPI(func(spec *APISpec) { spec.Proxy.ListenPath = "/quota-test" spec.UseKeylessAccess = false - UpdateAPIVersion(spec, "Default", func(v *apidef.VersionInfo) { + UpdateAPIVersion(spec, "v1", func(v *apidef.VersionInfo) { + v.UseExtendedPaths = true v.ExtendedPaths.URLRewrite = []apidef.URLRewriteMeta{{ Path: "/abc", Method: http.MethodGet, diff --git a/gateway/mw_auth_key.go b/gateway/mw_auth_key.go index 0c7ecfda9df..8e6bf59f867 100644 --- a/gateway/mw_auth_key.go +++ b/gateway/mw_auth_key.go @@ -6,16 +6,15 @@ import ( "strings" "time" - "github.com/TykTechnologies/tyk/internal/crypto" - "github.com/TykTechnologies/tyk/internal/otel" - "github.com/TykTechnologies/tyk/storage" - - "github.com/TykTechnologies/tyk/user" - "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/config" + "github.com/TykTechnologies/tyk/internal/crypto" + "github.com/TykTechnologies/tyk/internal/httpctx" + "github.com/TykTechnologies/tyk/internal/otel" "github.com/TykTechnologies/tyk/request" signaturevalidator "github.com/TykTechnologies/tyk/signature_validator" + "github.com/TykTechnologies/tyk/storage" + "github.com/TykTechnologies/tyk/user" ) const ( @@ -95,7 +94,7 @@ func (k *AuthKey) ProcessRequest(_ http.ResponseWriter, r *http.Request, _ inter } // skip auth key check if the request is looped. - if ses := ctxGetSession(r); ses != nil && !ctxCheckLimits(r) { + if ses := ctxGetSession(r); ses != nil && httpctx.IsSelfLooping(r) { return nil, http.StatusOK } diff --git a/internal/httpctx/context.go b/internal/httpctx/context.go new file mode 100644 index 00000000000..a139172e9a5 --- /dev/null +++ b/internal/httpctx/context.go @@ -0,0 +1,28 @@ +package httpctx + +import ( + "context" + "net/http" +) + +type Value[T any] struct { + Key any +} + +func NewValue[T any](key any) *Value[T] { + return &Value[T]{Key: key} +} + +func (v *Value[T]) Get(r *http.Request) (res T) { + if val := r.Context().Value(v.Key); val != nil { + res, _ = val.(T) + } + return +} + +func (v *Value[T]) Set(r *http.Request, val T) *http.Request { + ctx := context.WithValue(r.Context(), v.Key, val) + h := r.WithContext(ctx) + *r = *h + return h +} diff --git a/internal/httpctx/context_test.go b/internal/httpctx/context_test.go new file mode 100644 index 00000000000..758fc60b37c --- /dev/null +++ b/internal/httpctx/context_test.go @@ -0,0 +1,64 @@ +package httpctx_test + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/TykTechnologies/tyk/internal/httpctx" +) + +func TestValue_SetAndGet(t *testing.T) { + // Define a key and instantiate a new Value with type map[string]any + key := "testKey" + value := httpctx.NewValue[map[string]any](key) + + // Prepare a map to store in context + expectedData := map[string]any{ + "userID": 123, + "userRole": "admin", + } + + // Create a new HTTP request using httptest + req := httptest.NewRequest("GET", "/", nil) + + // Set the value in the request's context + req = value.Set(req, expectedData) + + // Retrieve the value from the context + retrievedData := value.Get(req) + assert.Equal(t, expectedData, retrievedData, "Retrieved data does not match expected data") +} + +func TestValue_GetWithMissingKey(t *testing.T) { + // Define a key and instantiate a new Value with type map[string]any + key := "missingKey" + value := httpctx.NewValue[map[string]any](key) + + // Create a new HTTP request using httptest + req := httptest.NewRequest("GET", "/", nil) + + // Try to retrieve the value from the context + retrievedData := value.Get(req) + + // Expect not to find any data + assert.Nil(t, retrievedData, "Expected retrieved data to be nil for a missing key") +} + +func TestValue_SetDifferentTypes(t *testing.T) { + // Test using a different type for Value, e.g., int + intKey := "intKey" + intValue := httpctx.NewValue[int](intKey) + + // Create a new HTTP request using httptest + req := httptest.NewRequest("GET", "/", nil) + + // Set an int value in the context + expectedInt := 42 + req = intValue.Set(req, expectedInt) + + // Retrieve the int value from the context + retrievedInt := intValue.Get(req) + assert.Equal(t, expectedInt, retrievedInt, "Retrieved int value does not match expected value") +} diff --git a/internal/httpctx/looping.go b/internal/httpctx/looping.go new file mode 100644 index 00000000000..708cd5eb3c5 --- /dev/null +++ b/internal/httpctx/looping.go @@ -0,0 +1,19 @@ +package httpctx + +import ( + "net/http" + + "github.com/TykTechnologies/tyk/ctx" +) + +var selfLoopingValue = NewValue[bool](ctx.SelfLooping) + +// SetSelfLooping updates the request context with a boolean value indicating whether the request is in a self-looping state. +func SetSelfLooping(r *http.Request, value bool) { + selfLoopingValue.Set(r, value) +} + +// IsSelfLooping returns true if the request is flagged as self-looping, indicating it originates and targets the same service. +func IsSelfLooping(r *http.Request) bool { + return selfLoopingValue.Get(r) +} diff --git a/internal/httpctx/looping_test.go b/internal/httpctx/looping_test.go new file mode 100644 index 00000000000..878984fee54 --- /dev/null +++ b/internal/httpctx/looping_test.go @@ -0,0 +1,19 @@ +package httpctx_test + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/TykTechnologies/tyk/internal/httpctx" +) + +func TestSetSelfLooping(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + assert.False(t, httpctx.IsSelfLooping(req)) + httpctx.SetSelfLooping(req, true) + assert.True(t, httpctx.IsSelfLooping(req)) + httpctx.SetSelfLooping(req, false) + assert.False(t, httpctx.IsSelfLooping(req)) +}