diff --git a/cmd/sso-auth/main.go b/cmd/sso-auth/main.go index e62dc223..776980cd 100644 --- a/cmd/sso-auth/main.go +++ b/cmd/sso-auth/main.go @@ -37,7 +37,7 @@ func main() { return nil } - authenticator, err := auth.NewAuthenticator(opts, emailValidator, auth.SetCookieStore(opts), auth.AssignStatsdClient(opts)) + authenticator, err := auth.NewAuthenticator(opts, emailValidator, auth.AssignProvider(opts), auth.SetCookieStore(opts), auth.AssignStatsdClient(opts)) if err != nil { logger.Error(err, "error creating new Authenticator") os.Exit(1) diff --git a/internal/auth/authenticator.go b/internal/auth/authenticator.go index 75ee1a37..c7035c99 100644 --- a/internal/auth/authenticator.go +++ b/internal/auth/authenticator.go @@ -144,7 +144,6 @@ func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error) Host: opts.Host, CookieSecure: opts.CookieSecure, - provider: opts.provider, redirectURL: redirectURL, SetXAuthRequest: opts.SetXAuthRequest, PassUserHeaders: opts.PassUserHeaders, diff --git a/internal/auth/authenticator_test.go b/internal/auth/authenticator_test.go index c3117450..e37115b3 100644 --- a/internal/auth/authenticator_test.go +++ b/internal/auth/authenticator_test.go @@ -1541,7 +1541,9 @@ func TestOAuthStart(t *testing.T) { opts := testOpts("abced", "testtest") opts.RedirectURL = "https://example.com/oauth2/callback" opts.Validate() - proxy, _ := NewAuthenticator(opts, func(p *Authenticator) error { + u, _ := url.Parse("http://example.com") + provider := providers.NewTestProvider(u) + proxy, _ := NewAuthenticator(opts, setTestProvider(provider), func(p *Authenticator) error { p.Validator = func(string) bool { return true } return nil }, setMockCSRFStore(&sessions.MockCSRFStore{})) @@ -1632,3 +1634,45 @@ func TestHostHeader(t *testing.T) { }) } } + +func TestGoogleProviderApiSettings(t *testing.T) { + opts := testOpts("abced", "testtest") + opts.Provider = "google" + opts.Validate() + proxy, _ := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error { + p.Validator = func(string) bool { return true } + return nil + }) + p := proxy.provider.Data() + testutil.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", + p.SignInURL.String()) + testutil.Equal(t, "https://www.googleapis.com/oauth2/v3/token", + p.RedeemURL.String()) + testutil.Equal(t, "", p.ProfileURL.String()) + testutil.Equal(t, "profile email", p.Scope) +} + +func TestGoogleGroupInvalidFile(t *testing.T) { + opts := testOpts("abced", "testtest") + opts.Provider = "google" + opts.GoogleAdminEmail = "admin@example.com" + opts.GoogleServiceAccountJSON = "file_doesnt_exist.json" + opts.Validate() + _, err := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error { + p.Validator = func(string) bool { return true } + return nil + }) + testutil.NotEqual(t, nil, err) + testutil.Equal(t, "invalid Google credentials file: file_doesnt_exist.json", err.Error()) +} + +func TestUnimplementedProvider(t *testing.T) { + opts := testOpts("abced", "testtest") + opts.Validate() + _, err := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error { + p.Validator = func(string) bool { return true } + return nil + }) + testutil.NotEqual(t, nil, err) + testutil.Equal(t, "unimplemented provider: \"\"", err.Error()) +} diff --git a/internal/auth/options.go b/internal/auth/options.go index ee590eb4..6a03cd3b 100644 --- a/internal/auth/options.go +++ b/internal/auth/options.go @@ -106,7 +106,6 @@ type Options struct { // internal values that are set after config validation redirectURL *url.URL - provider providers.Provider decodedCookieSecret []byte GroupsCacheStopFunc func() } @@ -170,7 +169,7 @@ func (o *Options) Validate() error { o.redirectURL, msgs = parseURL(o.RedirectURL, "redirect", msgs) - msgs = parseProviderInfo(o, msgs) + msgs = validateEndpoints(o, msgs) decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret) if err != nil { @@ -198,15 +197,6 @@ func (o *Options) Validate() error { o.CookieExpire.String())) } - if o.GoogleAdminEmail != "" || o.GoogleServiceAccountJSON != "" { - if o.GoogleAdminEmail == "" { - msgs = append(msgs, "missing setting: google-admin-email") - } - if o.GoogleServiceAccountJSON == "" { - msgs = append(msgs, "missing setting: google-service-account-json") - } - } - msgs = validateCookieName(o, msgs) if o.StatsdHost == "" { @@ -224,7 +214,24 @@ func (o *Options) Validate() error { return nil } -func parseProviderInfo(o *Options, msgs []string) []string { +func validateEndpoints(o *Options, msgs []string) []string { + _, msgs = parseURL(o.SignInURL, "signin", msgs) + _, msgs = parseURL(o.RedeemURL, "redeem", msgs) + _, msgs = parseURL(o.ProfileURL, "profile", msgs) + _, msgs = parseURL(o.ValidateURL, "validate", msgs) + + return msgs +} + +func validateCookieName(o *Options, msgs []string) []string { + cookie := &http.Cookie{Name: o.CookieName} + if cookie.String() == "" { + return append(msgs, fmt.Sprintf("invalid cookie name: %q", o.CookieName)) + } + return msgs +} + +func newProvider(o *Options) (providers.Provider, error) { p := &providers.ProviderData{ Scope: o.Scope, ClientID: o.ClientID, @@ -232,35 +239,55 @@ func parseProviderInfo(o *Options, msgs []string) []string { ApprovalPrompt: o.ApprovalPrompt, SessionLifetimeTTL: o.SessionLifetimeTTL, } - p.SignInURL, msgs = parseURL(o.SignInURL, "signin", msgs) - p.RedeemURL, msgs = parseURL(o.RedeemURL, "redeem", msgs) + + var err error + if p.SignInURL, err = url.Parse(o.SignInURL); err != nil { + return nil, err + } + if p.RedeemURL, err = url.Parse(o.RedeemURL); err != nil { + return nil, err + } p.RevokeURL = &url.URL{} - p.ProfileURL, msgs = parseURL(o.ProfileURL, "profile", msgs) - p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) + if p.ProfileURL, err = url.Parse(o.ProfileURL); err != nil { + return nil, err + } + if p.ValidateURL, err = url.Parse(o.ValidateURL); err != nil { + return nil, err + } - if o.GoogleServiceAccountJSON != "" { - _, err := os.Open(o.GoogleServiceAccountJSON) + var singleFlightProvider providers.Provider + switch o.Provider { + case providers.GoogleProviderName: // Google + if o.GoogleServiceAccountJSON != "" { + _, err := os.Open(o.GoogleServiceAccountJSON) + if err != nil { + return nil, fmt.Errorf("invalid Google credentials file: %s", o.GoogleServiceAccountJSON) + } + } + googleProvider, err := providers.NewGoogleProvider(p, o.GoogleAdminEmail, o.GoogleServiceAccountJSON) if err != nil { - msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON) + return nil, err } + cache := groups.NewFillCache(googleProvider.PopulateMembers, o.GroupsCacheRefreshTTL) + googleProvider.GroupsCache = cache + o.GroupsCacheStopFunc = cache.Stop + singleFlightProvider = providers.NewSingleFlightProvider(googleProvider) + default: + return nil, fmt.Errorf("unimplemented provider: %q", o.Provider) } - googleProvider := providers.NewGoogleProvider(p, o.GoogleAdminEmail, o.GoogleServiceAccountJSON) - cache := groups.NewFillCache(googleProvider.PopulateMembers, o.GroupsCacheRefreshTTL) - googleProvider.GroupsCache = cache - o.GroupsCacheStopFunc = cache.Stop - singleFlightProvider := providers.NewSingleFlightProvider(googleProvider) - - o.provider = singleFlightProvider - return msgs + return singleFlightProvider, nil } -func validateCookieName(o *Options, msgs []string) []string { - cookie := &http.Cookie{Name: o.CookieName} - if cookie.String() == "" { - return append(msgs, fmt.Sprintf("invalid cookie name: %q", o.CookieName)) +// AssignProvider is a function that takes an Options struct and assigns the +// appropriate provider to the proxy. Should be called prior to +// AssignStatsdClient. +func AssignProvider(opts *Options) func(*Authenticator) error { + return func(proxy *Authenticator) error { + var err error + proxy.provider, err = newProvider(opts) + return err } - return msgs } // AssignStatsdClient is function that takes in an Options struct and assigns a statsd client diff --git a/internal/auth/options_test.go b/internal/auth/options_test.go index 842acee0..85cf1ae0 100644 --- a/internal/auth/options_test.go +++ b/internal/auth/options_test.go @@ -53,21 +53,6 @@ func TestNewOptions(t *testing.T) { testutil.Equal(t, expected, err.Error()) } -func TestGoogleGroupInvalidFile(t *testing.T) { - defer func() { - r := recover() - panicMsg := "could not read google credentials file" - if r != panicMsg { - t.Errorf("expected panic with message %s but got %s", panicMsg, r) - } - }() - o := testOptions() - o.GoogleAdminEmail = "admin@example.com" - o.GoogleServiceAccountJSON = "file_doesnt_exist.json" - o.Validate() - -} - func TestInitializedOptions(t *testing.T) { o := testOptions() testutil.Equal(t, nil, o.Validate()) @@ -84,18 +69,6 @@ func TestRedirectURL(t *testing.T) { testutil.Equal(t, expected, o.redirectURL) } -func TestDefaultProviderApiSettings(t *testing.T) { - o := testOptions() - testutil.Equal(t, nil, o.Validate()) - p := o.provider.Data() - testutil.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", - p.SignInURL.String()) - testutil.Equal(t, "https://www.googleapis.com/oauth2/v3/token", - p.RedeemURL.String()) - testutil.Equal(t, "", p.ProfileURL.String()) - testutil.Equal(t, "profile email", p.Scope) -} - func TestCookieRefreshMustBeLessThanCookieExpire(t *testing.T) { o := testOptions() testutil.Equal(t, nil, o.Validate()) diff --git a/internal/auth/providers/google.go b/internal/auth/providers/google.go index 2f07c662..7c637f4c 100644 --- a/internal/auth/providers/google.go +++ b/internal/auth/providers/google.go @@ -31,7 +31,16 @@ type GoogleProvider struct { } // NewGoogleProvider returns a new GoogleProvider and sets the provider url endpoints. -func NewGoogleProvider(p *ProviderData, adminEmail, credsFilePath string) *GoogleProvider { +func NewGoogleProvider(p *ProviderData, adminEmail, credsFilePath string) (*GoogleProvider, error) { + if adminEmail != "" || credsFilePath != "" { + if adminEmail == "" { + return nil, errors.New("missing setting: google-admin-email") + } + if credsFilePath == "" { + return nil, errors.New("missing setting: google-service-account-json") + } + } + p.ProviderName = "Google" if p.SignInURL.String() == "" { p.SignInURL = &url.URL{Scheme: "https", @@ -77,14 +86,14 @@ func NewGoogleProvider(p *ProviderData, adminEmail, credsFilePath string) *Googl if credsFilePath != "" { credsReader, err := os.Open(credsFilePath) if err != nil { - panic("could not read google credentials file") + return nil, errors.New("could not read google credentials file") } googleProvider.AdminService = &GoogleAdminService{ adminService: getAdminService(adminEmail, credsReader), cb: googleProvider.cb, } } - return googleProvider + return googleProvider, nil } // SetStatsdClient sets the google provider and admin service statsd client diff --git a/internal/auth/providers/google_test.go b/internal/auth/providers/google_test.go index 3a35966c..2367cc30 100644 --- a/internal/auth/providers/google_test.go +++ b/internal/auth/providers/google_test.go @@ -37,7 +37,8 @@ func newGoogleProvider(providerData *ProviderData) *GoogleProvider { ValidateURL: &url.URL{}, Scope: ""} } - return NewGoogleProvider(providerData, "", "") + provider, _ := NewGoogleProvider(providerData, "", "") + return provider } func TestGoogleProviderDefaults(t *testing.T) { diff --git a/internal/auth/providers/providers.go b/internal/auth/providers/providers.go index 1881214f..6075145f 100644 --- a/internal/auth/providers/providers.go +++ b/internal/auth/providers/providers.go @@ -24,6 +24,11 @@ var ( ErrServiceUnavailable = errors.New("SERVICE_UNAVAILABLE") ) +const ( + // GoogleProviderName identifies the Google provider + GoogleProviderName = "google" +) + // Provider is an interface exposing functions necessary to authenticate with a given provider. type Provider interface { Data() *ProviderData