Skip to content

Commit

Permalink
common: introduce common package
Browse files Browse the repository at this point in the history
Introduced a new common package which includes commonly used variables,
functions and constants across our code. This package also includes the
respective unit tests. Also factored out the definition of a Cacheable
authenticator.

GitHub-PR: #105

Signed-off-by: Athanasios Markou <[email protected]>
  • Loading branch information
Athanasios Markou committed Jan 5, 2023
1 parent c751080 commit 0f5728c
Show file tree
Hide file tree
Showing 26 changed files with 284 additions and 259 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ COPY go.sum .
RUN go mod download
# Copy in the code and compile
COPY *.go ./
COPY common common
RUN CGO_ENABLED=0 GOOS=linux go build -a -ldflags '-extldflags "-static"' -o /go/bin/oidc-authservice


Expand Down
9 changes: 9 additions & 0 deletions authenticator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package main

import (
"net/http"
)

type Cacheable interface {
getCacheKey(r *http.Request) string
}
9 changes: 5 additions & 4 deletions authenticator_idtoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"net/http"

"github.com/arrikto/oidc-authservice/common"
oidc "github.com/coreos/go-oidc"
"k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/authentication/user"
Expand All @@ -18,16 +19,16 @@ type idTokenAuthenticator struct {
}

func (s *idTokenAuthenticator) AuthenticateRequest(r *http.Request) (*authenticator.Response, bool, error) {
logger := loggerForRequest(r, "idtoken authenticator")
logger := common.LoggerForRequest(r, "idtoken authenticator")

// get id-token from header
bearer := getBearerToken(r.Header.Get(s.header))
bearer := common.GetBearerToken(r.Header.Get(s.header))
if len(bearer) == 0 {
logger.Info("No bearer token found")
return nil, false, nil
}

ctx := setTLSContext(r.Context(), s.caBundle)
ctx := common.SetTLSContext(r.Context(), s.caBundle)

// Verifying received ID token
verifier := s.provider.Verifier(&oidc.Config{ClientID: s.clientID})
Expand All @@ -52,7 +53,7 @@ func (s *idTokenAuthenticator) AuthenticateRequest(r *http.Request) (*authentica
groups := []string{}
groupsClaim := claims[s.groupsClaim]
if groupsClaim != nil {
groups = interfaceSliceToStringSlice(groupsClaim.([]interface{}))
groups = common.InterfaceSliceToStringSlice(groupsClaim.([]interface{}))
}

// Authentication using header successfully completed
Expand Down
23 changes: 12 additions & 11 deletions authenticator_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"

"github.com/arrikto/oidc-authservice/common"
oidc "github.com/coreos/go-oidc"
"k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/authentication/user"
Expand All @@ -26,21 +27,21 @@ type jwtTokenAuthenticator struct {
}

type jwtLocalChecks struct {
Issuer string `json:"iss"`
Audiences audience `json:"aud"`
Issuer string `json:"iss"`
Audiences common.Audience `json:"aud"`
}

func (s *jwtTokenAuthenticator) AuthenticateRequest(r *http.Request) (*authenticator.Response, bool, error) {
logger := loggerForRequest(r, "JWT access token authenticator")
logger := common.LoggerForRequest(r, "JWT access token authenticator")

// Get JWT access token from header
bearer := getBearerToken(r.Header.Get(s.header))
bearer := common.GetBearerToken(r.Header.Get(s.header))
if len(bearer) == 0 {
logger.Info("No bearer token found")
return nil, false, nil
}

ctx := setTLSContext(r.Context(), s.caBundle)
ctx := common.SetTLSContext(r.Context(), s.caBundle)

// Verifying received JWT token
for _, aud := range s.audiences {
Expand All @@ -65,19 +66,19 @@ func (s *jwtTokenAuthenticator) AuthenticateRequest(r *http.Request) (*authentic

// Return the error of the go-oidc ID token verifier.
logger.Errorf("JWT-token verification failed: %v", err)
return nil, false, &authenticatorSpecificError{Err: err}
return nil, false, &common.AuthenticatorSpecificError{Err: err}
}

// Retrieve the USERID_CLAIM and the GROUPS_CLAIM
var claims map[string]interface{}
if claimErr := token.Claims(&claims); claimErr != nil {
logger.Errorf("Retrieving user claims failed: %v", claimErr)
return nil, false, &authenticatorSpecificError{Err: claimErr}
return nil, false, &common.AuthenticatorSpecificError{Err: claimErr}
}

userID, groups, claimErr := s.retrieveUserIDGroupsClaims(claims)
if claimErr != nil {
return nil, false, &authenticatorSpecificError{Err: claimErr}
return nil, false, &common.AuthenticatorSpecificError{Err: claimErr}
}

// Authentication using header successfully completed
Expand All @@ -102,7 +103,7 @@ func (s *jwtTokenAuthenticator) AuthenticateRequest(r *http.Request) (*authentic
func (s *jwtTokenAuthenticator) performLocalChecks(bearer string) (error){

// Verify that the retrieved Bearer token is a parsable JWT token
payload, localErr := parseJWT(bearer)
payload, localErr := common.ParseJWT(bearer)
if localErr != nil { // Check next authenticator
localErr = fmt.Errorf("Could not parse the inspected Bearer token.")
return localErr
Expand All @@ -123,7 +124,7 @@ func (s *jwtTokenAuthenticator) performLocalChecks(bearer string) (error){
}

// Check audiences
if !contains(s.audiences, tokenLocalChecks.Audiences){ // Check next authenticator
if !common.Contains(s.audiences, tokenLocalChecks.Audiences){ // Check next authenticator
localErr = fmt.Errorf("The retrieved \"aud\" did not match with any of the" +
" expected audiences.")
return localErr
Expand All @@ -149,7 +150,7 @@ func (s *jwtTokenAuthenticator) retrieveUserIDGroupsClaims(claims map[string]int
return "", []string{}, claimErr
}

groups = interfaceSliceToStringSlice(groupsClaim.([]interface{}))
groups = common.InterfaceSliceToStringSlice(groupsClaim.([]interface{}))

return claims[s.userIDClaim].(string), groups, nil
}
5 changes: 3 additions & 2 deletions authenticator_kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"k8s.io/apiserver/plugin/pkg/authenticator/token/webhook"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"github.com/arrikto/oidc-authservice/common"
)

const (
Expand Down Expand Up @@ -39,7 +40,7 @@ func (k8sauth *kubernetesAuthenticator) AuthenticateRequest(r *http.Request) (*a

// If the request contains an expired token, we stop trying and return 403
if err != nil && strings.Contains(err.Error(), bearerTokenExpiredMsg) {
return nil, false, &loginExpiredError{Err: err}
return nil, false, &common.LoginExpiredError{Err: err}
}

if found {
Expand All @@ -63,6 +64,6 @@ func (k8sauth *kubernetesAuthenticator) AuthenticateRequest(r *http.Request) (*a
// The Kubernetes Authenticator implements the Cacheable
// interface with the getCacheKey().
func (k8sauth *kubernetesAuthenticator) getCacheKey(r *http.Request) (string) {
return getBearerToken(r.Header.Get("Authorization"))
return common.GetBearerToken(r.Header.Get("Authorization"))

}
17 changes: 9 additions & 8 deletions authenticator_opaque.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"net/http"

"github.com/arrikto/oidc-authservice/common"
oidc "github.com/coreos/go-oidc"
"github.com/pkg/errors"
"golang.org/x/oauth2"
Expand All @@ -20,10 +21,10 @@ type opaqueTokenAuthenticator struct {
}

func (s *opaqueTokenAuthenticator) AuthenticateRequest(r *http.Request) (*authenticator.Response, bool, error) {
logger := loggerForRequest(r, "opaque access token authenticator")
logger := common.LoggerForRequest(r, "opaque access token authenticator")

// get id-token from header
bearer := getBearerToken(r.Header.Get(s.header))
bearer := common.GetBearerToken(r.Header.Get(s.header))
if len(bearer) == 0 {
logger.Info("No bearer token found")
return nil, false, nil
Expand All @@ -34,11 +35,11 @@ func (s *opaqueTokenAuthenticator) AuthenticateRequest(r *http.Request) (*authen
TokenType: "Bearer",
}

ctx := setTLSContext(r.Context(), s.caBundle)
ctx := common.SetTLSContext(r.Context(), s.caBundle)

userInfo, err := GetUserInfo(ctx, s.provider, s.oauth2Config.TokenSource(ctx, opaque))
if err != nil {
var reqErr *requestError
var reqErr *common.RequestError
if !errors.As(err, &reqErr) {
return nil, false, errors.Wrap(err, "UserInfo request failed unexpectedly")
}
Expand All @@ -50,12 +51,12 @@ func (s *opaqueTokenAuthenticator) AuthenticateRequest(r *http.Request) (*authen
var claims map[string]interface{}
if claimErr := userInfo.Claims(&claims); claimErr != nil {
logger.Errorf("Retrieving user claims failed: %v", claimErr)
return nil, false, &authenticatorSpecificError{Err: claimErr}
return nil, false, &common.AuthenticatorSpecificError{Err: claimErr}
}

userID, groups, claimErr := s.retrieveUserIDGroupsClaims(claims)
if claimErr != nil {
return nil, false, &authenticatorSpecificError{Err: claimErr}
return nil, false, &common.AuthenticatorSpecificError{Err: claimErr}
}

// Authentication using header successfully completed
Expand Down Expand Up @@ -86,14 +87,14 @@ func (s *opaqueTokenAuthenticator) retrieveUserIDGroupsClaims(claims map[string]
return "", []string{}, claimErr
}

groups = interfaceSliceToStringSlice(groupsClaim.([]interface{}))
groups = common.InterfaceSliceToStringSlice(groupsClaim.([]interface{}))

return claims[s.userIDClaim].(string), groups, nil
}

// The Opaque Access Token Authenticator implements the Cacheable
// interface with the getCacheKey().
func (s *opaqueTokenAuthenticator) getCacheKey(r *http.Request) (string) {
return getBearerToken(r.Header.Get("Authorization"))
return common.GetBearerToken(r.Header.Get("Authorization"))

}
51 changes: 0 additions & 51 deletions authenticator_opaque_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,6 @@ import (
"testing"
)

func TestValidAccessTokenAuthn(t *testing.T) {

tests := []struct {
testName string
AccessTokenAuthnEnabled bool
AccessTokenAuthn string
success bool
}{
{
testName: "Access Token Authenticator is set to JWT",
AccessTokenAuthnEnabled: true,
AccessTokenAuthn: "jwt",
success: true,
},
{
testName: "Access Token Authenticator is set to opaque",
AccessTokenAuthnEnabled: true,
AccessTokenAuthn: "opaque",
success: true,
},
{
testName: "Access Token Authenticator is disabled",
AccessTokenAuthnEnabled: false,
AccessTokenAuthn: "whatever",
success: true,
},
{
testName: "Access Token Authenticator envvar is invalid (JWT)",
AccessTokenAuthnEnabled: true,
AccessTokenAuthn: "JWT",
success: false,
},
{
testName: "Access Token Authenticator envvar is invalid (Opaque)",
AccessTokenAuthnEnabled: true,
AccessTokenAuthn: "Opaque",
success: false,
},
}

for _, c := range tests {
t.Run(c.testName, func(t *testing.T) {
result := validAccessTokenAuthn(c.AccessTokenAuthnEnabled, c.AccessTokenAuthn)

if result != c.success {
t.Errorf("validAccessTokenAuthn result for %v is not the expected one.", c)
}
})
}
}

func TestRetrieveUserIDGroupsUserInfo(t *testing.T) {

s := &opaqueTokenAuthenticator {
Expand Down
7 changes: 4 additions & 3 deletions authenticator_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net/http"
"net/http/httptest"

"github.com/arrikto/oidc-authservice/common"
oidc "github.com/coreos/go-oidc"
"github.com/gorilla/sessions"
"github.com/pkg/errors"
Expand Down Expand Up @@ -35,7 +36,7 @@ type sessionAuthenticator struct {
}

func (sa *sessionAuthenticator) AuthenticateRequest(r *http.Request) (*authenticator.Response, bool, error) {
logger := loggerForRequest(r, "session authenticator")
logger := common.LoggerForRequest(r, "session authenticator")

// Get session from header or cookie
session, authMethod, err := sessionFromRequest(r, sa.store, sa.cookie, sa.header)
Expand All @@ -51,12 +52,12 @@ func (sa *sessionAuthenticator) AuthenticateRequest(r *http.Request) (*authentic

// User is logged in
if sa.strictSessionValidation {
ctx := setTLSContext(r.Context(), sa.caBundle)
ctx := common.SetTLSContext(r.Context(), sa.caBundle)
token := session.Values[userSessionOAuth2Tokens].(oauth2.Token)
// TokenSource takes care of automatically renewing the access token.
_, err := GetUserInfo(ctx, sa.provider, sa.oauth2Config.TokenSource(ctx, &token))
if err != nil {
var reqErr *requestError
var reqErr *common.RequestError
if !errors.As(err, &reqErr) {
return nil, false, errors.Wrap(err, "UserInfo request failed unexpectedly")
}
Expand Down
8 changes: 5 additions & 3 deletions authorizer_external.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"strconv"
"strings"
"time"

"github.com/arrikto/oidc-authservice/common"
)

// ExternalAuthorizer is responsible for handling authorization in an external
Expand Down Expand Up @@ -48,7 +50,7 @@ type AuthorizationRequestInfo struct {

func (e ExternalAuthorizer) Authorize(r *http.Request, userinfo user.Info) (allowed bool, reason string, err error) {
// Collect data and create the AuthorizationRequestBody.
logger := loggerForRequest(r, "external authorizer")
logger := common.LoggerForRequest(r, "external authorizer")
logger = logger.WithField("user", userinfo)
authorizationUserInfo := e.getUserInfo(r, userinfo)

Expand Down Expand Up @@ -82,10 +84,10 @@ func (e ExternalAuthorizer) Authorize(r *http.Request, userinfo user.Info) (allo
// getUserInfo creates a AuthorizationUserInfo object for the current context.
func (e ExternalAuthorizer) getUserInfo(r *http.Request, userinfo user.Info) AuthorizationUserInfo {
// Parse the JWT token and add get the claims if it exists.
bearer := getBearerToken(r.Header.Get("Authorization"))
bearer := common.GetBearerToken(r.Header.Get("Authorization"))
var parsedJwt map[string]interface{} = nil
if bearer != "" {
jwt, err := parseJWT(bearer)
jwt, err := common.ParseJWT(bearer)
if err == nil {
// Unmarshal the JSON to the interface.
err = json.Unmarshal(jwt, &parsedJwt)
Expand Down
Loading

0 comments on commit 0f5728c

Please sign in to comment.