Skip to content

Commit

Permalink
implement upstream basic authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffy-mathew committed Oct 1, 2024
1 parent 685c35f commit c4dfd99
Show file tree
Hide file tree
Showing 8 changed files with 343 additions and 0 deletions.
19 changes: 19 additions & 0 deletions apidef/api_definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,25 @@ type APIDefinition struct {
VersionName string `bson:"-" json:"-"`

DetailedTracing bool `bson:"detailed_tracing" json:"detailed_tracing"`

// UpstreamAuth stores information about authenticating against upstream.
UpstreamAuth UpstreamAuth `bson:"upstream_auth" json:"upstream_auth"`
}

type UpstreamAuth struct {
Enabled bool `bson:"enabled" json:"enabled"`
BasicAuth UpstreamBasicAuth `bson:"basic_auth" json:"basic_auth"`
}

func (u *UpstreamAuth) IsEnabled() bool {
return u.Enabled && u.BasicAuth.Enabled
}

type UpstreamBasicAuth struct {
Enabled bool `bson:"enabled" json:"enabled,omitempty"`
Username string `bson:"username" json:"username"`
Password string `bson:"password" json:"password"`
HeaderName string `bson:"auth_header_name" json:"authHeaderName"`
}

type AnalyticsPluginConfig struct {
Expand Down
78 changes: 78 additions & 0 deletions apidef/oas/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type Upstream struct {

// RateLimit contains the configuration related to API level rate limit.
RateLimit *RateLimit `bson:"rateLimit,omitempty" json:"rateLimit,omitempty"`

// Authentication contains the configuration related to upstream authentication.
Authentication *UpstreamAuth `bson:"authentication,omitempty" json:"authentication,omitempty"`
}

// Fill fills *Upstream from apidef.APIDefinition.
Expand Down Expand Up @@ -79,6 +82,15 @@ func (u *Upstream) Fill(api apidef.APIDefinition) {
if ShouldOmit(u.RateLimit) {
u.RateLimit = nil
}

if u.Authentication == nil {
u.Authentication = &UpstreamAuth{}
}

u.Authentication.Fill(api.UpstreamAuth)
if ShouldOmit(u.Authentication) {
u.Authentication = nil
}
}

// ExtractTo extracts *Upstream into *apidef.APIDefinition.
Expand Down Expand Up @@ -129,6 +141,15 @@ func (u *Upstream) ExtractTo(api *apidef.APIDefinition) {
}

u.RateLimit.ExtractTo(api)

if u.Authentication == nil {
u.Authentication = &UpstreamAuth{}
defer func() {
u.Authentication = nil
}()
}

u.Authentication.ExtractTo(&api.UpstreamAuth)
}

// ServiceDiscovery holds configuration required for service discovery.
Expand Down Expand Up @@ -529,3 +550,60 @@ func (r *RateLimitEndpoint) ExtractTo(meta *apidef.RateLimitMeta) {
meta.Rate = float64(r.Rate)
meta.Per = r.Per.Seconds()
}

type UpstreamAuth struct {
Enabled bool `bson:"enabled" json:"enabled"`
BasicAuth *UpstreamBasicAuth `bson:"basicAuth,omitempty" json:"basicAuth,omitempty"`
}

// Fill fills *UpstreamAuth from apidef.UpstreamAuth.
func (u *UpstreamAuth) Fill(api apidef.UpstreamAuth) {
u.Enabled = api.Enabled

if u.BasicAuth == nil {
u.BasicAuth = &UpstreamBasicAuth{}
}

u.BasicAuth.Fill(api.BasicAuth)
if ShouldOmit(u.BasicAuth) {
u.BasicAuth = nil
}
}

// ExtractTo extracts *UpstreamAuth into *apidef.UpstreamAuth.
func (u *UpstreamAuth) ExtractTo(api *apidef.UpstreamAuth) {
api.Enabled = u.Enabled

if u.BasicAuth == nil {
u.BasicAuth = &UpstreamBasicAuth{}
defer func() {
u.BasicAuth = nil
}()
}

u.BasicAuth.ExtractTo(&api.BasicAuth)
}

type UpstreamBasicAuth struct {
Enabled bool `bson:"enabled" json:"enabled"`
HeaderName string `bson:"headerName" json:"headerName"`
Username string `bson:"username" json:"username"`
Password string `bson:"password" json:"password"`
}

// Fill fills *UpstreamBasicAuth from apidef.UpstreamBasicAuth.
func (u *UpstreamBasicAuth) Fill(api apidef.UpstreamBasicAuth) {
u.Enabled = api.Enabled
u.HeaderName = api.HeaderName
u.Username = api.Username
u.Password = api.Password
}

// ExtractTo extracts *UpstreamBasicAuth into *apidef.UpstreamBasicAuth.
func (u *UpstreamBasicAuth) ExtractTo(api *apidef.UpstreamBasicAuth) {
api.Enabled = u.Enabled
api.Enabled = u.Enabled
api.HeaderName = u.HeaderName
api.Username = u.Username
api.Password = u.Password
}
35 changes: 35 additions & 0 deletions ctx/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ const (
// CacheOptions holds cache options required for cache writer middleware.
CacheOptions
OASDefinition

// UpstreamAuthHeader sets the header name to be used for upstream authentication.
UpstreamAuthHeader
// UpstreamAuthValue sets the value for upstream authentication.
UpstreamAuthValue
)

func setContext(r *http.Request, ctx context.Context) {
Expand Down Expand Up @@ -158,3 +163,33 @@ func GetOASDefinition(r *http.Request) *oas.OAS {

return ret
}

// SetUpstreamAuthHeader sets the header name to be used for upstream authentication.
func SetUpstreamAuthHeader(r *http.Request, name string) {
ctx := r.Context()
ctx = context.WithValue(ctx, UpstreamAuthHeader, name)
setContext(r, ctx)
}

// GetUpstreamAuthHeader returns the header name to be used for upstream authentication.
func GetUpstreamAuthHeader(r *http.Request) string {
if v := r.Context().Value(UpstreamAuthHeader); v != nil {
return v.(string)
}
return ""
}

// SetUpstreamAuthValue sets the auth header value to be used for upstream authentication.
func SetUpstreamAuthValue(r *http.Request, name string) {
ctx := r.Context()
ctx = context.WithValue(ctx, UpstreamAuthValue, name)
setContext(r, ctx)
}

// GetUpstreamAuthValue gets the auth header value to be used for upstream authentication.
func GetUpstreamAuthValue(r *http.Request) string {
if v := r.Context().Value(UpstreamAuthValue); v != nil {
return v.(string)
}
return ""
}
2 changes: 2 additions & 0 deletions gateway/api_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int,
}
}

gw.mwAppendEnabled(&chainArray, &UpstreamBasicAuth{BaseMiddleware: baseMid})

chain = alice.New(chainArray...).Then(&DummyProxyHandler{SH: SuccessHandler{baseMid}, Gw: gw})

if !spec.UseKeylessAccess {
Expand Down
49 changes: 49 additions & 0 deletions gateway/mw_upstream_basic_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package gateway

import (
"encoding/base64"
"fmt"
"net/http"

"github.com/TykTechnologies/tyk/ctx"
"github.com/TykTechnologies/tyk/header"
)

// UpstreamBasicAuth is a middleware that will do basic authentication for upstream connections.
// UpstreamBasicAuth middleware is only supported in Tyk OAS API definitions.
type UpstreamBasicAuth struct {
*BaseMiddleware
}

func (t *UpstreamBasicAuth) Name() string {
return "UpstreamBasicAuth"
}

func (t *UpstreamBasicAuth) EnabledForSpec() bool {
if !t.Spec.UpstreamAuth.Enabled {
return false
}

if !t.Spec.UpstreamAuth.BasicAuth.Enabled {
return false
}

return true
}

// ProcessRequest will inject basic auth info into request context so that it can be used during reverse proxy.
func (t *UpstreamBasicAuth) ProcessRequest(_ http.ResponseWriter, r *http.Request, _ interface{}) (error, int) {
basicAuthConfig := t.Spec.UpstreamAuth.BasicAuth

authHeaderName := header.Authorization
if basicAuthConfig.HeaderName != "" {
authHeaderName = basicAuthConfig.HeaderName
}
ctx.SetUpstreamAuthHeader(r, authHeaderName)

payload := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", basicAuthConfig.Username, basicAuthConfig.Password)))

ctx.SetUpstreamAuthValue(r, payload)

return nil, http.StatusOK
}
143 changes: 143 additions & 0 deletions gateway/mw_upstream_basic_auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package gateway

import (
"encoding/base64"
"encoding/json"
"net/http"
"testing"

"github.com/stretchr/testify/assert"

"github.com/TykTechnologies/tyk/apidef"
"github.com/TykTechnologies/tyk/header"
"github.com/TykTechnologies/tyk/test"
)

func TestUpstreamBasicAuthentication(t *testing.T) {

ts := StartTest(nil)
t.Cleanup(func() {
ts.Close()
})

userName, password, customAuthHeader := "user", "password", "Custom-Auth"
expectedAuth := base64.StdEncoding.EncodeToString([]byte(userName + ":" + password))

ts.Gw.BuildAndLoadAPI(
func(spec *APISpec) {
spec.Proxy.ListenPath = "/upstream-basic-auth-enabled/"
spec.UseKeylessAccess = true
spec.UpstreamAuth = apidef.UpstreamAuth{
Enabled: true,
BasicAuth: apidef.UpstreamBasicAuth{
Enabled: true,
Username: userName,
Password: password,
},
}
spec.Proxy.StripListenPath = true
}, func(spec *APISpec) {
spec.Proxy.ListenPath = "/upstream-basic-auth-custom-header/"
spec.UseKeylessAccess = true
spec.UpstreamAuth = apidef.UpstreamAuth{
Enabled: true,
BasicAuth: apidef.UpstreamBasicAuth{
Enabled: true,
Username: userName,
Password: password,
HeaderName: customAuthHeader,
},
}
spec.Proxy.StripListenPath = true
},
func(spec *APISpec) {
spec.Proxy.ListenPath = "/upstream-basic-auth-disabled/"
spec.UseKeylessAccess = true
spec.UpstreamAuth = apidef.UpstreamAuth{
Enabled: true,
BasicAuth: apidef.UpstreamBasicAuth{
Enabled: false,
Username: userName,
Password: password,
},
}
spec.Proxy.StripListenPath = true
},
func(spec *APISpec) {
spec.Proxy.ListenPath = "/upstream-auth-disabled/"
spec.UseKeylessAccess = true
spec.UpstreamAuth = apidef.UpstreamAuth{
Enabled: false,
}
spec.Proxy.StripListenPath = true
},
)

_, _ = ts.Run(t, test.TestCases{
{
Path: "/upstream-basic-auth-enabled/",
Code: http.StatusOK,
BodyMatchFunc: func(body []byte) bool {
resp := struct {
Headers map[string]string `json:"headers"`
}{}
err := json.Unmarshal(body, &resp)
assert.NoError(t, err)

assert.Contains(t, resp.Headers, header.Authorization)
assert.NotEmpty(t, resp.Headers[header.Authorization])
assert.Equal(t, expectedAuth, resp.Headers[header.Authorization])

return true
},
},
{
Path: "/upstream-basic-auth-custom-header/",
Code: http.StatusOK,
BodyMatchFunc: func(body []byte) bool {
resp := struct {
Headers map[string]string `json:"headers"`
}{}
err := json.Unmarshal(body, &resp)
assert.NoError(t, err)

assert.Contains(t, resp.Headers, customAuthHeader)
assert.NotEmpty(t, resp.Headers[customAuthHeader])
assert.Equal(t, expectedAuth, resp.Headers[customAuthHeader])

return true
},
},
{
Path: "/upstream-basic-auth-disabled/",
Code: http.StatusOK,
BodyMatchFunc: func(body []byte) bool {
resp := struct {
Headers map[string]string `json:"headers"`
}{}
err := json.Unmarshal(body, &resp)
assert.NoError(t, err)

assert.NotContains(t, resp.Headers, header.Authorization)

return true
},
},
{
Path: "/upstream-auth-disabled/",
Code: http.StatusOK,
BodyMatchFunc: func(body []byte) bool {
resp := struct {
Headers map[string]string `json:"headers"`
}{}
err := json.Unmarshal(body, &resp)
assert.NoError(t, err)

assert.NotContains(t, resp.Headers, header.Authorization)

return true
},
},
}...)

}
Loading

0 comments on commit c4dfd99

Please sign in to comment.