diff --git a/apidef/api_definitions.go b/apidef/api_definitions.go index d1c17a4200b..9b49975d169 100644 --- a/apidef/api_definitions.go +++ b/apidef/api_definitions.go @@ -762,6 +762,35 @@ type APIDefinition struct { VersionName string `bson:"-" json:"-"` DetailedTracing bool `bson:"detailed_tracing" json:"detailed_tracing"` + + // UpstreamAuth stores information about authenticating against upstream. + UpstreamAuth UpstreamAuth `bson:"upstream_auth" json:"upstream_auth"` +} + +// UpstreamAuth holds the configurations related to upstream API authentication. +type UpstreamAuth struct { + // Enabled enables upstream API authentication. + Enabled bool `bson:"enabled" json:"enabled"` + // BasicAuth holds the basic authentication configuration for upstream API authentication. + BasicAuth UpstreamBasicAuth `bson:"basic_auth" json:"basic_auth"` +} + +// IsEnabled checks if UpstreamAuthentication is enabled for the API. +func (u *UpstreamAuth) IsEnabled() bool { + return u.Enabled && u.BasicAuth.Enabled +} + +// UpstreamBasicAuth holds upstream basic authentication configuration. +type UpstreamBasicAuth struct { + // Enabled enables upstream basic authentication. + Enabled bool `bson:"enabled" json:"enabled,omitempty"` + // Username is the username to be used for upstream basic authentication. + Username string `bson:"username" json:"username"` + // Password is the password to be used for upstream basic authentication. + Password string `bson:"password" json:"password"` + // HeaderName is the custom header name to be used for upstream basic authentication. + // Defaults to `Authorization`. + HeaderName string `bson:"header_name" json:"header_name"` } type AnalyticsPluginConfig struct { diff --git a/apidef/oas/schema/x-tyk-api-gateway.json b/apidef/oas/schema/x-tyk-api-gateway.json index 8a88d02000c..84b21cd6302 100644 --- a/apidef/oas/schema/x-tyk-api-gateway.json +++ b/apidef/oas/schema/x-tyk-api-gateway.json @@ -1349,6 +1349,9 @@ }, "rateLimit": { "$ref": "#/definitions/X-Tyk-RateLimit" + }, + "authentication": { + "$ref": "#/definitions/X-Tyk-UpstreamAuthentication" } }, "required": [ @@ -2065,6 +2068,40 @@ "X-Tyk-DomainDef": { "type": "string", "pattern": "^([*a-zA-Z0-9-]+(\\.[*a-zA-Z0-9-]+)*)(:\\d+)?$" + }, + "X-Tyk-UpstreamAuthentication": { + "type": "object", + "properties": { + "enabled": { + "type": "boolean" + }, + "basicAuth": { + "$ref": "#/definitions/X-Tyk-UpstreamBasicAuthentication" + } + }, + "required": [ + "enabled" + ] + }, + "X-Tyk-UpstreamBasicAuthentication": { + "type": "object", + "properties": { + "enabled": { + "type": "boolean" + }, + "headerName": { + "type": "string" + }, + "username": { + "type": "string" + }, + "password": { + "type": "string" + } + }, + "required": [ + "enabled" + ] } } } diff --git a/apidef/oas/upstream.go b/apidef/oas/upstream.go index 0077bdd3961..9d322f11af2 100644 --- a/apidef/oas/upstream.go +++ b/apidef/oas/upstream.go @@ -29,6 +29,9 @@ type Upstream struct { // RateLimit contains the configuration related to API level rate limit. RateLimit *RateLimit `bson:"rateLimit,omitempty" json:"rateLimit,omitempty"` + + // Authentication contains the configuration related to upstream authentication. + Authentication *UpstreamAuth `bson:"authentication,omitempty" json:"authentication,omitempty"` } // Fill fills *Upstream from apidef.APIDefinition. @@ -79,6 +82,15 @@ func (u *Upstream) Fill(api apidef.APIDefinition) { if ShouldOmit(u.RateLimit) { u.RateLimit = nil } + + if u.Authentication == nil { + u.Authentication = &UpstreamAuth{} + } + + u.Authentication.Fill(api.UpstreamAuth) + if ShouldOmit(u.Authentication) { + u.Authentication = nil + } } // ExtractTo extracts *Upstream into *apidef.APIDefinition. @@ -129,6 +141,15 @@ func (u *Upstream) ExtractTo(api *apidef.APIDefinition) { } u.RateLimit.ExtractTo(api) + + if u.Authentication == nil { + u.Authentication = &UpstreamAuth{} + defer func() { + u.Authentication = nil + }() + } + + u.Authentication.ExtractTo(&api.UpstreamAuth) } // ServiceDiscovery holds configuration required for service discovery. @@ -529,3 +550,69 @@ func (r *RateLimitEndpoint) ExtractTo(meta *apidef.RateLimitMeta) { meta.Rate = float64(r.Rate) meta.Per = r.Per.Seconds() } + +// UpstreamAuth holds the configurations related to upstream API authentication. +type UpstreamAuth struct { + // Enabled enables upstream API authentication. + Enabled bool `bson:"enabled" json:"enabled"` + // BasicAuth holds the basic authentication configuration for upstream API authentication. + BasicAuth *UpstreamBasicAuth `bson:"basicAuth,omitempty" json:"basicAuth,omitempty"` +} + +// Fill fills *UpstreamAuth from apidef.UpstreamAuth. +func (u *UpstreamAuth) Fill(api apidef.UpstreamAuth) { + u.Enabled = api.Enabled + + if u.BasicAuth == nil { + u.BasicAuth = &UpstreamBasicAuth{} + } + + u.BasicAuth.Fill(api.BasicAuth) + if ShouldOmit(u.BasicAuth) { + u.BasicAuth = nil + } +} + +// ExtractTo extracts *UpstreamAuth into *apidef.UpstreamAuth. +func (u *UpstreamAuth) ExtractTo(api *apidef.UpstreamAuth) { + api.Enabled = u.Enabled + + if u.BasicAuth == nil { + u.BasicAuth = &UpstreamBasicAuth{} + defer func() { + u.BasicAuth = nil + }() + } + + u.BasicAuth.ExtractTo(&api.BasicAuth) +} + +// UpstreamBasicAuth holds upstream basic authentication configuration. +type UpstreamBasicAuth struct { + // Enabled enables upstream basic authentication. + Enabled bool `bson:"enabled" json:"enabled"` + // HeaderName is the custom header name to be used for upstream basic authentication. + // Defaults to `Authorization`. + HeaderName string `bson:"headerName" json:"headerName"` + // Username is the username to be used for upstream basic authentication. + Username string `bson:"username" json:"username"` + // Password is the password to be used for upstream basic authentication. + Password string `bson:"password" json:"password"` +} + +// Fill fills *UpstreamBasicAuth from apidef.UpstreamBasicAuth. +func (u *UpstreamBasicAuth) Fill(api apidef.UpstreamBasicAuth) { + u.Enabled = api.Enabled + u.HeaderName = api.HeaderName + u.Username = api.Username + u.Password = api.Password +} + +// ExtractTo extracts *UpstreamBasicAuth into *apidef.UpstreamBasicAuth. +func (u *UpstreamBasicAuth) ExtractTo(api *apidef.UpstreamBasicAuth) { + api.Enabled = u.Enabled + api.Enabled = u.Enabled + api.HeaderName = u.HeaderName + api.Username = u.Username + api.Password = u.Password +} diff --git a/apidef/schema.go b/apidef/schema.go index 04c7eb0682f..c080bddec47 100644 --- a/apidef/schema.go +++ b/apidef/schema.go @@ -761,7 +761,32 @@ const Schema = `{ }, "detailed_tracing": { "type": "boolean" - } + }, + "upstream_auth": { + "type": "object", + "properties": { + "enabled": { + "type": "boolean" + }, + "basic_auth": { + "type": "object", + "properties": { + "enabled": { + "type": "boolean" + }, + "username": { + "type": "string" + }, + "password": { + "type": "string" + }, + "header_name": { + "type": "string" + } + } + } + } + } }, "required": [ "name", diff --git a/ctx/ctx.go b/ctx/ctx.go index 13b7fd76aeb..43a31a2d838 100644 --- a/ctx/ctx.go +++ b/ctx/ctx.go @@ -5,6 +5,8 @@ import ( "encoding/json" "net/http" + "github.com/TykTechnologies/tyk/internal/httputil" + "github.com/TykTechnologies/tyk/apidef/oas" "github.com/TykTechnologies/tyk/config" @@ -53,11 +55,6 @@ const ( OASDefinition ) -func setContext(r *http.Request, ctx context.Context) { - r2 := r.WithContext(ctx) - *r = *r2 -} - func ctxSetSession(r *http.Request, s *user.SessionState, scheduleUpdate bool, hashKey bool) { if s == nil { @@ -81,7 +78,7 @@ func ctxSetSession(r *http.Request, s *user.SessionState, scheduleUpdate bool, h s.Touch() } - setContext(r, ctx) + httputil.SetContext(r, ctx) } func GetAuthToken(r *http.Request) string { @@ -119,7 +116,7 @@ func SetSession(r *http.Request, s *user.SessionState, scheduleUpdate bool, hash func SetDefinition(r *http.Request, s *apidef.APIDefinition) { ctx := r.Context() ctx = context.WithValue(ctx, Definition, s) - setContext(r, ctx) + httputil.SetContext(r, ctx) } func GetDefinition(r *http.Request) *apidef.APIDefinition { diff --git a/gateway/api_loader.go b/gateway/api_loader.go index 01ad37e940e..8175e0dea0a 100644 --- a/gateway/api_loader.go +++ b/gateway/api_loader.go @@ -468,6 +468,8 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int, } } + gw.mwAppendEnabled(&chainArray, &UpstreamBasicAuth{BaseMiddleware: baseMid}) + chain = alice.New(chainArray...).Then(&DummyProxyHandler{SH: SuccessHandler{baseMid}, Gw: gw}) if !spec.UseKeylessAccess { diff --git a/gateway/mw_upstream_basic_auth.go b/gateway/mw_upstream_basic_auth.go new file mode 100644 index 00000000000..066d560d107 --- /dev/null +++ b/gateway/mw_upstream_basic_auth.go @@ -0,0 +1,64 @@ +package gateway + +import ( + "net/http" + + "github.com/TykTechnologies/tyk/internal/httputil" + + "github.com/TykTechnologies/tyk/header" +) + +// UpstreamBasicAuth is a middleware that will do basic authentication for upstream connections. +// UpstreamBasicAuth middleware is only supported in Tyk OAS API definitions. +type UpstreamBasicAuth struct { + *BaseMiddleware +} + +// Name returns the name of middleware. +func (t *UpstreamBasicAuth) Name() string { + return "UpstreamBasicAuth" +} + +// EnabledForSpec returns true if the middleware is enabled based on API Spec. +func (t *UpstreamBasicAuth) EnabledForSpec() bool { + if !t.Spec.UpstreamAuth.Enabled { + return false + } + + if !t.Spec.UpstreamAuth.BasicAuth.Enabled { + return false + } + + return true +} + +// ProcessRequest will inject basic auth info into request context so that it can be used during reverse proxy. +func (t *UpstreamBasicAuth) ProcessRequest(_ http.ResponseWriter, r *http.Request, _ interface{}) (error, int) { + basicAuthConfig := t.Spec.UpstreamAuth.BasicAuth + + upstreamBasicAuthProvider := UpstreamBasicAuthProvider{ + HeaderName: header.Authorization, + } + + if basicAuthConfig.HeaderName != "" { + upstreamBasicAuthProvider.HeaderName = basicAuthConfig.HeaderName + } + + upstreamBasicAuthProvider.AuthValue = httputil.AuthHeader(basicAuthConfig.Username, basicAuthConfig.Password) + + httputil.SetUpstreamAuth(r, upstreamBasicAuthProvider) + return nil, http.StatusOK +} + +// UpstreamBasicAuthProvider implements upstream auth provider. +type UpstreamBasicAuthProvider struct { + // HeaderName is the header name to be used to fill upstream auth with. + HeaderName string + // AuthValue is the value of auth header. + AuthValue string +} + +// Fill sets the request's HeaderName with AuthValue +func (u UpstreamBasicAuthProvider) Fill(r *http.Request) { + r.Header.Add(u.HeaderName, u.AuthValue) +} diff --git a/gateway/mw_upstream_basic_auth_test.go b/gateway/mw_upstream_basic_auth_test.go new file mode 100644 index 00000000000..a2b1870538a --- /dev/null +++ b/gateway/mw_upstream_basic_auth_test.go @@ -0,0 +1,143 @@ +package gateway + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/header" + "github.com/TykTechnologies/tyk/test" +) + +func TestUpstreamBasicAuthentication(t *testing.T) { + ts := StartTest(nil) + t.Cleanup(func() { + ts.Close() + }) + + userName, password, customAuthHeader := "user", "password", "Custom-Auth" + expectedAuth := fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(userName+":"+password))) + + ts.Gw.BuildAndLoadAPI( + func(spec *APISpec) { + spec.Proxy.ListenPath = "/upstream-basic-auth-enabled/" + spec.UseKeylessAccess = true + spec.UpstreamAuth = apidef.UpstreamAuth{ + Enabled: true, + BasicAuth: apidef.UpstreamBasicAuth{ + Enabled: true, + Username: userName, + Password: password, + }, + } + spec.Proxy.StripListenPath = true + }, func(spec *APISpec) { + spec.Proxy.ListenPath = "/upstream-basic-auth-custom-header/" + spec.UseKeylessAccess = true + spec.UpstreamAuth = apidef.UpstreamAuth{ + Enabled: true, + BasicAuth: apidef.UpstreamBasicAuth{ + Enabled: true, + Username: userName, + Password: password, + HeaderName: customAuthHeader, + }, + } + spec.Proxy.StripListenPath = true + }, + func(spec *APISpec) { + spec.Proxy.ListenPath = "/upstream-basic-auth-disabled/" + spec.UseKeylessAccess = true + spec.UpstreamAuth = apidef.UpstreamAuth{ + Enabled: true, + BasicAuth: apidef.UpstreamBasicAuth{ + Enabled: false, + Username: userName, + Password: password, + }, + } + spec.Proxy.StripListenPath = true + }, + func(spec *APISpec) { + spec.Proxy.ListenPath = "/upstream-auth-disabled/" + spec.UseKeylessAccess = true + spec.UpstreamAuth = apidef.UpstreamAuth{ + Enabled: false, + } + spec.Proxy.StripListenPath = true + }, + ) + + ts.Run(t, test.TestCases{ + { + Path: "/upstream-basic-auth-enabled/", + Code: http.StatusOK, + BodyMatchFunc: func(body []byte) bool { + resp := struct { + Headers map[string]string `json:"headers"` + }{} + err := json.Unmarshal(body, &resp) + assert.NoError(t, err) + + assert.Contains(t, resp.Headers, header.Authorization) + assert.NotEmpty(t, resp.Headers[header.Authorization]) + assert.Equal(t, expectedAuth, resp.Headers[header.Authorization]) + + return true + }, + }, + { + Path: "/upstream-basic-auth-custom-header/", + Code: http.StatusOK, + BodyMatchFunc: func(body []byte) bool { + resp := struct { + Headers map[string]string `json:"headers"` + }{} + err := json.Unmarshal(body, &resp) + assert.NoError(t, err) + + assert.Contains(t, resp.Headers, customAuthHeader) + assert.NotEmpty(t, resp.Headers[customAuthHeader]) + assert.Equal(t, expectedAuth, resp.Headers[customAuthHeader]) + + return true + }, + }, + { + Path: "/upstream-basic-auth-disabled/", + Code: http.StatusOK, + BodyMatchFunc: func(body []byte) bool { + resp := struct { + Headers map[string]string `json:"headers"` + }{} + err := json.Unmarshal(body, &resp) + assert.NoError(t, err) + + assert.NotContains(t, resp.Headers, header.Authorization) + + return true + }, + }, + { + Path: "/upstream-auth-disabled/", + Code: http.StatusOK, + BodyMatchFunc: func(body []byte) bool { + resp := struct { + Headers map[string]string `json:"headers"` + }{} + err := json.Unmarshal(body, &resp) + assert.NoError(t, err) + + assert.NotContains(t, resp.Headers, header.Authorization) + + return true + }, + }, + }...) + +} diff --git a/gateway/reverse_proxy.go b/gateway/reverse_proxy.go index a7755c63f5a..820c4201638 100644 --- a/gateway/reverse_proxy.go +++ b/gateway/reverse_proxy.go @@ -1219,6 +1219,8 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques } + p.addAuthInfo(outreq, req) + // do request round trip var ( res *http.Response @@ -1845,3 +1847,13 @@ func (p *ReverseProxy) IsUpgrade(req *http.Request) (string, bool) { return httputil.IsUpgrade(req) } + +func (p *ReverseProxy) addAuthInfo(outReq, req *http.Request) { + if !p.TykAPISpec.UpstreamAuth.IsEnabled() { + return + } + + if authProvider := httputil.GetUpstreamAuth(req); authProvider != nil { + authProvider.Fill(outReq) + } +} diff --git a/internal/httputil/context.go b/internal/httputil/context.go new file mode 100644 index 00000000000..2ac0b07ab2c --- /dev/null +++ b/internal/httputil/context.go @@ -0,0 +1,43 @@ +package httputil + +import ( + "context" + "net/http" + + "github.com/TykTechnologies/tyk/internal/model" +) + +// ContextKey is the key type to be used for context interactions. +type ContextKey string + +const ( + upstreamAuth = ContextKey("upstream-auth") +) + +// SetContext updates the context of a request. +func SetContext(r *http.Request, ctx context.Context) { + r2 := r.WithContext(ctx) + *r = *r2 +} + +// SetUpstreamAuth sets the header name to be used for upstream authentication. +func SetUpstreamAuth(r *http.Request, auth model.UpstreamAuthProvider) { + ctx := r.Context() + ctx = context.WithValue(ctx, upstreamAuth, auth) + SetContext(r, ctx) +} + +// GetUpstreamAuth returns the header name to be used for upstream authentication. +func GetUpstreamAuth(r *http.Request) model.UpstreamAuthProvider { + auth := r.Context().Value(upstreamAuth) + if auth == nil { + return nil + } + + provider, ok := auth.(model.UpstreamAuthProvider) + if !ok { + return nil + } + + return provider +} diff --git a/internal/httputil/context_test.go b/internal/httputil/context_test.go new file mode 100644 index 00000000000..2fffe60b167 --- /dev/null +++ b/internal/httputil/context_test.go @@ -0,0 +1,94 @@ +package httputil_test + +import ( + "context" + "net/http" + "testing" + + "github.com/TykTechnologies/tyk/internal/httputil" + + "github.com/TykTechnologies/tyk/internal/model" + + "github.com/stretchr/testify/assert" +) + +func createReq(tb testing.TB) *http.Request { + tb.Helper() + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com", nil) + assert.NoError(tb, err) + return req +} + +func TestUpstreamAuth(t *testing.T) { + t.Run("valid auth provider", func(t *testing.T) { + mockAuthProvider := &model.MockUpstreamAuthProvider{} + req := createReq(t) + + httputil.SetUpstreamAuth(req, mockAuthProvider) + + // Retrieve the auth provider from the request's context to verify it was set + retrievedAuth := httputil.GetUpstreamAuth(req) + assert.NotNil(t, retrievedAuth) + assert.Equal(t, mockAuthProvider, retrievedAuth) + }) + + t.Run("no auth provider", func(t *testing.T) { + req := createReq(t) + + retrievedAuth := httputil.GetUpstreamAuth(req) + assert.Nil(t, retrievedAuth) + }) + + t.Run("invalid auth provider", func(t *testing.T) { + req := createReq(t) + + // Set a context with a value that is not of type proxy.UpstreamAuthProvider + ctx := context.WithValue(req.Context(), httputil.ContextKey("upstream-auth"), "invalid-type") + httputil.SetContext(req, ctx) + + retrievedAuth := httputil.GetUpstreamAuth(req) + assert.Nil(t, retrievedAuth) + }) +} + +func TestSetContext(t *testing.T) { + t.Run("add key", func(t *testing.T) { + req := createReq(t) + + // Create a new context with a key-value pair + ctx := context.WithValue(context.Background(), httputil.ContextKey("key"), "value") + + // Call SetContext to update the request's context + httputil.SetContext(req, ctx) + + // Verify that the request's context has been updated + retrievedValue := req.Context().Value(httputil.ContextKey("key")) + assert.Equal(t, "value", retrievedValue) + }) + + t.Run("override key", func(t *testing.T) { + + req := createReq(t) + existingCtx := context.WithValue(context.Background(), httputil.ContextKey("existingKey"), "existingValue") + req = req.WithContext(existingCtx) + + // Create a new context to override the existing context + newCtx := context.WithValue(context.Background(), httputil.ContextKey("newKey"), "newValue") + + // Call SetContext to update the request's context with the new context + httputil.SetContext(req, newCtx) + + assert.Nil(t, req.Context().Value(httputil.ContextKey("existingKey"))) + assert.Equal(t, "newValue", req.Context().Value(httputil.ContextKey("newKey"))) + }) + + t.Run("empty context", func(t *testing.T) { + req := createReq(t) + + emptyCtx := context.Background() + + httputil.SetContext(req, emptyCtx) + + assert.Equal(t, emptyCtx, req.Context()) + }) +} diff --git a/internal/httputil/headers.go b/internal/httputil/headers.go new file mode 100644 index 00000000000..4aed5f6749e --- /dev/null +++ b/internal/httputil/headers.go @@ -0,0 +1,26 @@ +package httputil + +import ( + "encoding/base64" + "fmt" + "strings" +) + +// CORSHeaders is a list of CORS headers. +var CORSHeaders = []string{ + "Access-Control-Allow-Origin", + "Access-Control-Expose-Headers", + "Access-Control-Max-Age", + "Access-Control-Allow-Credentials", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Headers", +} + +// AuthHeader will take username and password and return +// "Basic " + base64 encoded `username:password` for use +// in an Authorization header. +func AuthHeader(username, password string) string { + toEncode := strings.Join([]string{username, password}, ":") + encodedPass := base64.StdEncoding.EncodeToString([]byte(toEncode)) + return fmt.Sprintf("Basic %s", encodedPass) +} diff --git a/internal/model/upstream_auth.go b/internal/model/upstream_auth.go new file mode 100644 index 00000000000..a064b42e212 --- /dev/null +++ b/internal/model/upstream_auth.go @@ -0,0 +1,16 @@ +package model + +import "net/http" + +// UpstreamAuthProvider is an interface that can fill in upstream authentication details to the request. +type UpstreamAuthProvider interface { + Fill(r *http.Request) +} + +// MockUpstreamAuthProvider is a mock implementation of UpstreamAuthProvider. +type MockUpstreamAuthProvider struct{} + +// Fill is a mock implementation to be used in tests. +func (m *MockUpstreamAuthProvider) Fill(_ *http.Request) { + // empty mock implementation. +} diff --git a/test/http.go b/test/http.go index 5c09b3e4389..e46463df46b 100644 --- a/test/http.go +++ b/test/http.go @@ -16,6 +16,7 @@ import ( "time" ) +type TestCases []TestCase type TestCase struct { Host string `json:",omitempty"` Method string `json:",omitempty"`