From fbab56a045e4b3908aa9ccbcbe1bf5d6d02992ff Mon Sep 17 00:00:00 2001 From: Jeffy Mathew Date: Tue, 29 Oct 2024 07:55:36 +0100 Subject: [PATCH] Add validation rule for Upstream auth --- apidef/api_definitions.go | 3 ++ apidef/validator.go | 45 +++++++++++++++++ apidef/validator_test.go | 92 ++++++++++++++++++++++++++++++++++ gateway/mw_oauth2_auth.go | 10 ++-- gateway/mw_oauth2_auth_test.go | 4 +- 5 files changed, 146 insertions(+), 8 deletions(-) diff --git a/apidef/api_definitions.go b/apidef/api_definitions.go index f2202ed1342..569aeb4d886 100644 --- a/apidef/api_definitions.go +++ b/apidef/api_definitions.go @@ -119,6 +119,9 @@ const ( OAuthType = "oauth" ExternalOAuthType = "externalOAuth" OIDCType = "oidc" + + OAuthAuthorizationTypeClientCredentials = "clientCredentials" + OAuthAuthorizationTypePassword = "password" ) var ( diff --git a/apidef/validator.go b/apidef/validator.go index 233d9633907..f61b4c6bcdb 100644 --- a/apidef/validator.go +++ b/apidef/validator.go @@ -57,6 +57,7 @@ var DefaultValidationRuleSet = ValidationRuleSet{ &RuleAtLeastEnableOneAuthSource{}, &RuleValidateIPList{}, &RuleValidateEnforceTimeout{}, + &RuleUpstreamAuth{}, } func Validate(definition *APIDefinition, ruleSet ValidationRuleSet) ValidationResult { @@ -199,3 +200,47 @@ func (r *RuleValidateEnforceTimeout) Validate(apiDef *APIDefinition, validationR } } } + +var ( + ErrMultipleUpstreamAuthEnabled = errors.New("multiple upstream authentication modes not allowed") + ErrMultipleUpstreamOAuthAuthorizationType = errors.New("multiple upstream OAuth authorization modes not allowed") + ErrUpstreamOAuthAuthorizationTypeRequired = errors.New("upstream OAuth authorization type is required") + ErrInvalidUpstreamOAuthAuthorizationType = errors.New("invalid OAuth authorization type") +) + +type RuleUpstreamAuth struct{} + +func (r *RuleUpstreamAuth) Validate(apiDef *APIDefinition, validationResult *ValidationResult) { + upstreamAuth := apiDef.UpstreamAuth + + if !upstreamAuth.IsEnabled() { + return + } + + if upstreamAuth.BasicAuth.Enabled && upstreamAuth.OAuth.Enabled { + validationResult.IsValid = false + validationResult.AppendError(ErrMultipleUpstreamAuthEnabled) + } + + upstreamOAuth := upstreamAuth.OAuth + // only OAuth checks moving forward + if !upstreamOAuth.IsEnabled() { + return + } + + if len(upstreamOAuth.AllowedAuthorizeTypes) == 0 { + validationResult.IsValid = false + validationResult.AppendError(ErrUpstreamOAuthAuthorizationTypeRequired) + return + } + + if len(upstreamAuth.OAuth.AllowedAuthorizeTypes) > 1 { + validationResult.IsValid = false + validationResult.AppendError(ErrMultipleUpstreamOAuthAuthorizationType) + } + + if authType := upstreamAuth.OAuth.AllowedAuthorizeTypes[0]; authType != OAuthAuthorizationTypeClientCredentials && authType != OAuthAuthorizationTypePassword { + validationResult.IsValid = false + validationResult.AppendError(ErrInvalidUpstreamOAuthAuthorizationType) + } +} diff --git a/apidef/validator_test.go b/apidef/validator_test.go index 3ef01e3105c..d750593840b 100644 --- a/apidef/validator_test.go +++ b/apidef/validator_test.go @@ -416,3 +416,95 @@ func TestRuleValidateEnforceTimeout_Validate(t *testing.T) { t.Run(tc.name, runValidationTest(tc.apiDef, ruleSet, tc.result)) } } + +func TestRuleUpstreamAuth_Validate(t *testing.T) { + ruleSet := ValidationRuleSet{ + &RuleUpstreamAuth{}, + } + + testCases := []struct { + name string + upstreamAuth UpstreamAuth + result ValidationResult + }{ + { + name: "not enabled", + upstreamAuth: UpstreamAuth{ + Enabled: false, + }, + result: ValidationResult{ + IsValid: true, + Errors: nil, + }, + }, + { + name: "basic auth and OAuth enabled", + upstreamAuth: UpstreamAuth{ + Enabled: true, + BasicAuth: UpstreamBasicAuth{ + Enabled: true, + }, + OAuth: UpstreamOAuth{ + Enabled: true, + AllowedAuthorizeTypes: []string{OAuthAuthorizationTypeClientCredentials}, + }, + }, + result: ValidationResult{ + IsValid: false, + Errors: []error{ + ErrMultipleUpstreamAuthEnabled, + }, + }, + }, + { + name: "no upstream OAuth authorization type specified", + upstreamAuth: UpstreamAuth{ + Enabled: true, + OAuth: UpstreamOAuth{ + Enabled: true, + AllowedAuthorizeTypes: []string{}, + }, + }, + result: ValidationResult{ + IsValid: false, + Errors: []error{ErrUpstreamOAuthAuthorizationTypeRequired}, + }, + }, + { + name: "multiple upstream OAuth authorization type specified", + upstreamAuth: UpstreamAuth{ + Enabled: true, + OAuth: UpstreamOAuth{ + Enabled: true, + AllowedAuthorizeTypes: []string{OAuthAuthorizationTypeClientCredentials, OAuthAuthorizationTypePassword}, + }, + }, + result: ValidationResult{ + IsValid: false, + Errors: []error{ErrMultipleUpstreamOAuthAuthorizationType}, + }, + }, + { + name: "invalid upstream OAuth authorization type specified", + upstreamAuth: UpstreamAuth{ + Enabled: true, + OAuth: UpstreamOAuth{ + Enabled: true, + AllowedAuthorizeTypes: []string{"auth-type1"}, + }, + }, + result: ValidationResult{ + IsValid: false, + Errors: []error{ErrInvalidUpstreamOAuthAuthorizationType}, + }, + }, + } + + for _, tc := range testCases { + apiDef := &APIDefinition{ + UpstreamAuth: tc.upstreamAuth, + } + + t.Run(tc.name, runValidationTest(apiDef, ruleSet, tc.result)) + } +} diff --git a/gateway/mw_oauth2_auth.go b/gateway/mw_oauth2_auth.go index 8f7208e1b6e..bc683edbdf1 100644 --- a/gateway/mw_oauth2_auth.go +++ b/gateway/mw_oauth2_auth.go @@ -21,10 +21,8 @@ import ( ) const ( - UpstreamOAuthErrorEventName = "UpstreamOAuthError" - UpstreamOAuthMiddlewareName = "UpstreamOAuth" - ClientCredentialsAuthorizeType = "clientCredentials" - PasswordAuthorizeType = "password" + UpstreamOAuthErrorEventName = "UpstreamOAuthError" + UpstreamOAuthMiddlewareName = "UpstreamOAuth" ) type OAuthHeaderProvider interface { @@ -172,9 +170,9 @@ func getOAuthHeaderProvider(oauthConfig apidef.UpstreamOAuth) (OAuthHeaderProvid return nil, fmt.Errorf("no OAuth configuration selected") case len(oauthConfig.AllowedAuthorizeTypes) > 1: return nil, fmt.Errorf("both client credentials and password authentication are provided") - case oauthConfig.AllowedAuthorizeTypes[0] == ClientCredentialsAuthorizeType: + case oauthConfig.AllowedAuthorizeTypes[0] == apidef.OAuthAuthorizationTypeClientCredentials: return &ClientCredentialsOAuthProvider{}, nil - case oauthConfig.AllowedAuthorizeTypes[0] == PasswordAuthorizeType: + case oauthConfig.AllowedAuthorizeTypes[0] == apidef.OAuthAuthorizationTypePassword: return &PasswordOAuthProvider{}, nil default: return nil, fmt.Errorf("no valid OAuth configuration provided") diff --git a/gateway/mw_oauth2_auth_test.go b/gateway/mw_oauth2_auth_test.go index bc510988767..dab2d48e2f9 100644 --- a/gateway/mw_oauth2_auth_test.go +++ b/gateway/mw_oauth2_auth_test.go @@ -66,7 +66,7 @@ func TestUpstreamOauth2(t *testing.T) { OAuth: apidef.UpstreamOAuth{ Enabled: true, ClientCredentials: cfg, - AllowedAuthorizeTypes: []string{ClientCredentialsAuthorizeType}, + AllowedAuthorizeTypes: []string{apidef.OAuthAuthorizationTypeClientCredentials}, }, } spec.Proxy.StripListenPath = true @@ -150,7 +150,7 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { OAuth: apidef.UpstreamOAuth{ Enabled: true, PasswordAuthentication: cfg, - AllowedAuthorizeTypes: []string{PasswordAuthorizeType}, + AllowedAuthorizeTypes: []string{apidef.OAuthAuthorizationTypePassword}, }, } spec.Proxy.StripListenPath = true