Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor filter to reduce closure size #84

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 131 additions & 117 deletions pkg/auth/iam/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,85 +147,89 @@ func (filter *Filter) AuthAllowEmptySubdomain(opts ...FilterOption) restful.Filt

func (filter *Filter) authFunc(allowEmptySubdomain bool, opts ...FilterOption) restful.FilterFunction {
return func(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) {
token, tokenFrom, err := parseAccessToken(req)
if err != nil {
logrus.Warn("unauthorized access: ", err)
logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{
ErrorCode: UnauthorizedAccess,
ErrorMessage: ErrorCodeMapping[UnauthorizedAccess],
}, restful.MIME_JSON))
filter.authFuncImpl(req, resp, chain, allowEmptySubdomain, opts...)
}
}

return
}
func (filter *Filter) authFuncImpl(req *restful.Request, resp *restful.Response, chain *restful.FilterChain, allowEmptySubdomain bool, opts ...FilterOption) {
token, tokenFrom, err := parseAccessToken(req)
if err != nil {
logrus.Warn("unauthorized access: ", err)
logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{
ErrorCode: UnauthorizedAccess,
ErrorMessage: ErrorCodeMapping[UnauthorizedAccess],
}, restful.MIME_JSON))

claims, err := filter.iamClient.ValidateAndParseClaims(token)
if err != nil {
logrus.Warn("unauthorized access: ", err)
if err.Error() == ErrorCodeMapping[TokenIsExpired] {
logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{
ErrorCode: TokenIsExpired,
ErrorMessage: ErrorCodeMapping[TokenIsExpired],
}, restful.MIME_JSON))
return
}
return
}

claims, err := filter.iamClient.ValidateAndParseClaims(token)
if err != nil {
logrus.Warn("unauthorized access: ", err)
if err.Error() == ErrorCodeMapping[TokenIsExpired] {
logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{
ErrorCode: UnauthorizedAccess,
ErrorMessage: ErrorCodeMapping[UnauthorizedAccess],
ErrorCode: TokenIsExpired,
ErrorMessage: ErrorCodeMapping[TokenIsExpired],
}, restful.MIME_JSON))
return
}
logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{
ErrorCode: UnauthorizedAccess,
ErrorMessage: ErrorCodeMapping[UnauthorizedAccess],
}, restful.MIME_JSON))
return
}

req.SetAttribute(ClaimsAttribute, claims)
req.SetAttribute(ClaimsAttribute, claims)

if tokenFrom == tokenFromCookie {
valid := filter.validateRefererHeader(req, claims, allowEmptySubdomain)
if !valid {
logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{
ErrorCode: InvalidRefererHeader,
ErrorMessage: ErrorCodeMapping[InvalidRefererHeader],
}, restful.MIME_JSON))
if tokenFrom == tokenFromCookie {
valid := filter.validateRefererHeader(req, claims, allowEmptySubdomain)
if !valid {
logIfErr(resp.WriteHeaderAndJson(http.StatusUnauthorized, ErrorResponse{
ErrorCode: InvalidRefererHeader,
ErrorMessage: ErrorCodeMapping[InvalidRefererHeader],
}, restful.MIME_JSON))

return
}
return
}
}

if filter.options.SubdomainValidationEnabled && !allowEmptySubdomain {
if valid := validateSubdomainAgainstNamespace(getHost(req.Request), claims.Namespace, filter.options.SubdomainValidationExcludedNamespaces); !valid {
logIfErr(resp.WriteHeaderAndJson(http.StatusNotFound, ErrorResponse{
ErrorCode: SubdomainMismatch,
ErrorMessage: "data not found: " + ErrorCodeMapping[SubdomainMismatch],
}, restful.MIME_JSON))
if filter.options.SubdomainValidationEnabled && !allowEmptySubdomain {
if valid := validateSubdomainAgainstNamespace(getHost(req.Request), claims.Namespace, filter.options.SubdomainValidationExcludedNamespaces); !valid {
logIfErr(resp.WriteHeaderAndJson(http.StatusNotFound, ErrorResponse{
ErrorCode: SubdomainMismatch,
ErrorMessage: "data not found: " + ErrorCodeMapping[SubdomainMismatch],
}, restful.MIME_JSON))

return
}
return
}
}

for _, opt := range opts {
if err = opt(req, filter.iamClient, claims); err != nil {
if svcErr, ok := err.(restful.ServiceError); ok {
logrus.Warn(svcErr.Message)

var respErr ErrorResponse
for _, opt := range opts {
if err = opt(req, filter.iamClient, claims); err != nil {
if svcErr, ok := err.(restful.ServiceError); ok {
logrus.Warn(svcErr.Message)

err = json.Unmarshal([]byte(svcErr.Message), &respErr)
if err == nil {
logIfErr(resp.WriteHeaderAndJson(svcErr.Code, respErr, restful.MIME_JSON))
} else {
logIfErr(resp.WriteErrorString(svcErr.Code, svcErr.Message))
}
var respErr ErrorResponse

return
err = json.Unmarshal([]byte(svcErr.Message), &respErr)
if err == nil {
logIfErr(resp.WriteHeaderAndJson(svcErr.Code, respErr, restful.MIME_JSON))
} else {
logIfErr(resp.WriteErrorString(svcErr.Code, svcErr.Message))
}

logrus.Warn(err)
logIfErr(resp.WriteErrorString(http.StatusUnauthorized, err.Error()))

return
}
}

chain.ProcessFilter(req, resp)
logrus.Warn(err)
logIfErr(resp.WriteErrorString(http.StatusUnauthorized, err.Error()))

return
}
}

chain.ProcessFilter(req, resp)
}

// PublicAuth returns a filter that allow unauthenticate request and request with valid access token in auth header or cookie
Expand All @@ -241,41 +245,45 @@ func (filter *Filter) authFunc(allowEmptySubdomain bool, opts ...FilterOption) r
// )
func (filter *Filter) PublicAuth(opts ...FilterOption) restful.FilterFunction {
return func(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) {
token, tokenFrom, err := parseAccessToken(req)
if err != nil {
filter.publicAuth(req, resp, chain, opts...)
}
}

func (filter *Filter) publicAuth(req *restful.Request, resp *restful.Response, chain *restful.FilterChain, opts ...FilterOption) {
token, tokenFrom, err := parseAccessToken(req)
if err != nil {
chain.ProcessFilter(req, resp)
return
}

claims, err := filter.iamClient.ValidateAndParseClaims(token)
if err != nil {
logrus.Warn("unauthorized access for public endpoint: ", err)
chain.ProcessFilter(req, resp)
return
}

req.SetAttribute(ClaimsAttribute, claims)

if tokenFrom == tokenFromCookie {
valid := filter.validateRefererHeader(req, claims, false)
if !valid {
req.SetAttribute(ClaimsAttribute, nil)
chain.ProcessFilter(req, resp)
return
}
}

claims, err := filter.iamClient.ValidateAndParseClaims(token)
if err != nil {
logrus.Warn("unauthorized access for public endpoint: ", err)
for _, opt := range opts {
if err = opt(req, filter.iamClient, claims); err != nil {
logrus.Warn(err)
req.SetAttribute(ClaimsAttribute, nil)
chain.ProcessFilter(req, resp)
return
}

req.SetAttribute(ClaimsAttribute, claims)

if tokenFrom == tokenFromCookie {
valid := filter.validateRefererHeader(req, claims, false)
if !valid {
req.SetAttribute(ClaimsAttribute, nil)
chain.ProcessFilter(req, resp)
return
}
}

for _, opt := range opts {
if err = opt(req, filter.iamClient, claims); err != nil {
logrus.Warn(err)
req.SetAttribute(ClaimsAttribute, nil)
chain.ProcessFilter(req, resp)
return
}
}

chain.ProcessFilter(req, resp)
}

chain.ProcessFilter(req, resp)
}

// RetrieveJWTClaims is a convenience function to retrieve JWT claims
Expand All @@ -301,32 +309,35 @@ func WithValidUser() FilterOption {
// WithPermission filters request with valid permission only
func WithPermission(permission *iam.Permission) FilterOption {
return func(req *restful.Request, iamClient iam.Client, claims *iam.JWTClaims) error {
requiredPermissionResources := make(map[string]string)
requiredPermissionResources["{namespace}"] = req.PathParameter("namespace")
requiredPermissionResources["{userId}"] = req.PathParameter("userId")
return withPermission(req, iamClient, claims, permission)
}
}

valid, err := iamClient.ValidatePermission(claims, *permission, requiredPermissionResources)
if err != nil {
return respondError(http.StatusInternalServerError, InternalServerError,
"unable to validate permission: "+err.Error())
}
func withPermission(req *restful.Request, iamClient iam.Client, claims *iam.JWTClaims, permission *iam.Permission) error {
requiredPermissionResources := make(map[string]string)
requiredPermissionResources["{namespace}"] = req.PathParameter("namespace")
requiredPermissionResources["{userId}"] = req.PathParameter("userId")

insufficientPermissionMessage := ErrorCodeMapping[InsufficientPermissions]
if DevStackTraceable {
action := ActionConverter(permission.Action)
insufficientPermissionMessage = fmt.Sprintf("%s. Required permission: %s [%s]", insufficientPermissionMessage,
permission.Resource, action)
}
if !valid {
return respondErrorWithRequiredPermission(http.StatusForbidden, InsufficientPermissions,
"access forbidden: "+insufficientPermissionMessage, Permission{
Resource: permission.Resource,
Action: permission.Action,
})
}
valid, err := iamClient.ValidatePermission(claims, *permission, requiredPermissionResources)
if err != nil {
return respondError(http.StatusInternalServerError, InternalServerError,
"unable to validate permission: "+err.Error())
}

return nil
insufficientPermissionMessage := ErrorCodeMapping[InsufficientPermissions]
if DevStackTraceable {
action := ActionConverter(permission.Action)
insufficientPermissionMessage = fmt.Sprintf("%s. Required permission: %s [%s]", insufficientPermissionMessage,
permission.Resource, action)
}
if !valid {
return respondErrorWithRequiredPermission(http.StatusForbidden, InsufficientPermissions,
"access forbidden: "+insufficientPermissionMessage, Permission{
Resource: permission.Resource,
Action: permission.Action,
})
}
return nil
}

// WithRole filters request with valid role only
Expand Down Expand Up @@ -386,19 +397,22 @@ func WithValidAudience() FilterOption {
// WithValidScope filters request from a user with verified scope
func WithValidScope(scope string) FilterOption {
return func(req *restful.Request, iamClient iam.Client, claims *iam.JWTClaims) error {
err := iamClient.ValidateScope(claims, scope)
insufficientScopeMessage := ErrorCodeMapping[InsufficientScope]
if DevStackTraceable {
insufficientScopeMessage = fmt.Sprintf("%s. Required scope: %s", insufficientScopeMessage,
scope)
}
if err != nil {
return respondError(http.StatusForbidden, InsufficientScope,
"access forbidden: "+insufficientScopeMessage)
}
return withValidScope(scope, iamClient, claims)
}
}

return nil
func withValidScope(scope string, iamClient iam.Client, claims *iam.JWTClaims) error {
err := iamClient.ValidateScope(claims, scope)
insufficientScopeMessage := ErrorCodeMapping[InsufficientScope]
if DevStackTraceable {
insufficientScopeMessage = fmt.Sprintf("%s. Required scope: %s", insufficientScopeMessage,
scope)
}
if err != nil {
return respondError(http.StatusForbidden, InsufficientScope,
"access forbidden: "+insufficientScopeMessage)
}
return nil
}

func validateSubdomainAgainstNamespace(host string, namespace string, excludedNamespaces []string) bool {
Expand Down