Skip to content

Commit

Permalink
Relocate provider specific validation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
sporkmonger committed Oct 3, 2018
1 parent cb40882 commit 3b439f6
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 20 deletions.
15 changes: 14 additions & 1 deletion internal/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1635,8 +1635,9 @@ func TestHostHeader(t *testing.T) {
}
}

func TestDefaultProviderApiSettings(t *testing.T) {
func TestGoogleProviderApiSettings(t *testing.T) {
opts := testOpts("abced", "testtest")
opts.Provider = "google"
opts.Validate()
proxy, _ := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error {
p.Validator = func(string) bool { return true }
Expand All @@ -1653,6 +1654,7 @@ func TestDefaultProviderApiSettings(t *testing.T) {

func TestGoogleGroupInvalidFile(t *testing.T) {
opts := testOpts("abced", "testtest")
opts.Provider = "google"
opts.GoogleAdminEmail = "[email protected]"
opts.GoogleServiceAccountJSON = "file_doesnt_exist.json"
opts.Validate()
Expand All @@ -1663,3 +1665,14 @@ func TestGoogleGroupInvalidFile(t *testing.T) {
testutil.NotEqual(t, nil, err)
testutil.Equal(t, "invalid Google credentials file: file_doesnt_exist.json", err.Error())
}

func TestUnimplementedProvider(t *testing.T) {
opts := testOpts("abced", "testtest")
opts.Validate()
_, err := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error {
p.Validator = func(string) bool { return true }
return nil
})
testutil.NotEqual(t, nil, err)
testutil.Equal(t, "unimplemented provider: \"\"", err.Error())
}
25 changes: 10 additions & 15 deletions internal/auth/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package auth
import (
"crypto"
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -170,7 +169,7 @@ func (o *Options) Validate() error {

o.redirectURL, msgs = parseURL(o.RedirectURL, "redirect", msgs)

msgs = parseProviderInfo(o, msgs)
msgs = validateEndpoints(o, msgs)

decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
if err != nil {
Expand Down Expand Up @@ -198,15 +197,6 @@ func (o *Options) Validate() error {
o.CookieExpire.String()))
}

if o.GoogleAdminEmail != "" || o.GoogleServiceAccountJSON != "" {
if o.GoogleAdminEmail == "" {
msgs = append(msgs, "missing setting: google-admin-email")
}
if o.GoogleServiceAccountJSON == "" {
msgs = append(msgs, "missing setting: google-service-account-json")
}
}

msgs = validateCookieName(o, msgs)

if o.StatsdHost == "" {
Expand All @@ -224,7 +214,7 @@ func (o *Options) Validate() error {
return nil
}

func parseProviderInfo(o *Options, msgs []string) []string {
func validateEndpoints(o *Options, msgs []string) []string {
_, msgs = parseURL(o.SignInURL, "signin", msgs)
_, msgs = parseURL(o.RedeemURL, "redeem", msgs)
_, msgs = parseURL(o.ProfileURL, "profile", msgs)
Expand Down Expand Up @@ -267,18 +257,23 @@ func newProvider(o *Options) (providers.Provider, error) {

var singleFlightProvider providers.Provider
switch o.Provider {
default: // Google
case providers.GoogleProviderName: // Google
if o.GoogleServiceAccountJSON != "" {
_, err := os.Open(o.GoogleServiceAccountJSON)
if err != nil {
return nil, errors.New("invalid Google credentials file: " + o.GoogleServiceAccountJSON)
return nil, fmt.Errorf("invalid Google credentials file: %s", o.GoogleServiceAccountJSON)
}
}
googleProvider := providers.NewGoogleProvider(p, o.GoogleAdminEmail, o.GoogleServiceAccountJSON)
googleProvider, err := providers.NewGoogleProvider(p, o.GoogleAdminEmail, o.GoogleServiceAccountJSON)
if err != nil {
return nil, err
}
cache := groups.NewFillCache(googleProvider.PopulateMembers, o.GroupsCacheRefreshTTL)
googleProvider.GroupsCache = cache
o.GroupsCacheStopFunc = cache.Stop
singleFlightProvider = providers.NewSingleFlightProvider(googleProvider)
default:
return nil, fmt.Errorf("unimplemented provider: %q", o.Provider)
}

return singleFlightProvider, nil
Expand Down
15 changes: 12 additions & 3 deletions internal/auth/providers/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,16 @@ type GoogleProvider struct {
}

// NewGoogleProvider returns a new GoogleProvider and sets the provider url endpoints.
func NewGoogleProvider(p *ProviderData, adminEmail, credsFilePath string) *GoogleProvider {
func NewGoogleProvider(p *ProviderData, adminEmail, credsFilePath string) (*GoogleProvider, error) {
if adminEmail != "" || credsFilePath != "" {
if adminEmail == "" {
return nil, errors.New("missing setting: google-admin-email")
}
if credsFilePath == "" {
return nil, errors.New("missing setting: google-service-account-json")
}
}

p.ProviderName = "Google"
if p.SignInURL.String() == "" {
p.SignInURL = &url.URL{Scheme: "https",
Expand Down Expand Up @@ -77,14 +86,14 @@ func NewGoogleProvider(p *ProviderData, adminEmail, credsFilePath string) *Googl
if credsFilePath != "" {
credsReader, err := os.Open(credsFilePath)
if err != nil {
panic("could not read google credentials file")
return nil, errors.New("could not read google credentials file")
}
googleProvider.AdminService = &GoogleAdminService{
adminService: getAdminService(adminEmail, credsReader),
cb: googleProvider.cb,
}
}
return googleProvider
return googleProvider, nil
}

// SetStatsdClient sets the google provider and admin service statsd client
Expand Down
3 changes: 2 additions & 1 deletion internal/auth/providers/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ func newGoogleProvider(providerData *ProviderData) *GoogleProvider {
ValidateURL: &url.URL{},
Scope: ""}
}
return NewGoogleProvider(providerData, "", "")
provider, _ := NewGoogleProvider(providerData, "", "")
return provider
}

func TestGoogleProviderDefaults(t *testing.T) {
Expand Down
17 changes: 17 additions & 0 deletions internal/auth/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,23 @@ var (
ErrServiceUnavailable = errors.New("SERVICE_UNAVAILABLE")
)

const (
// AzureProviderName identifies the Azure AD provider
AzureProviderName = "azure"
// FacebookProviderName identifies the Facebook provider
FacebookProviderName = "facebook"
// GitHubProviderName identifies the GitHub provider
GitHubProviderName = "github"
// GitLabProviderName identifies the GitLab provider
GitLabProviderName = "gitlab"
// GoogleProviderName identifies the Google provider
GoogleProviderName = "google"
// LinkedInProviderName identifies the LinkedIn provider
LinkedInProviderName = "linkedin"
// OIDCProviderName identifies the generic OpenID Connect provider
OIDCProviderName = "oidc"
)

// Provider is an interface exposing functions necessary to authenticate with a given provider.
type Provider interface {
Data() *ProviderData
Expand Down

0 comments on commit 3b439f6

Please sign in to comment.