Skip to content

Commit

Permalink
Merge pull request #76 from Remitly/refactor-provider-options
Browse files Browse the repository at this point in the history
Refactor provider options
  • Loading branch information
Shraya Ramani authored Oct 5, 2018
2 parents 9776b52 + 5a35c4f commit d1f0d32
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 66 deletions.
2 changes: 1 addition & 1 deletion cmd/sso-auth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func main() {
return nil
}

authenticator, err := auth.NewAuthenticator(opts, emailValidator, auth.SetCookieStore(opts), auth.AssignStatsdClient(opts))
authenticator, err := auth.NewAuthenticator(opts, emailValidator, auth.AssignProvider(opts), auth.SetCookieStore(opts), auth.AssignStatsdClient(opts))
if err != nil {
logger.Error(err, "error creating new Authenticator")
os.Exit(1)
Expand Down
1 change: 0 additions & 1 deletion internal/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ func NewAuthenticator(opts *Options, optionFuncs ...func(*Authenticator) error)
Host: opts.Host,
CookieSecure: opts.CookieSecure,

provider: opts.provider,
redirectURL: redirectURL,
SetXAuthRequest: opts.SetXAuthRequest,
PassUserHeaders: opts.PassUserHeaders,
Expand Down
46 changes: 45 additions & 1 deletion internal/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1541,7 +1541,9 @@ func TestOAuthStart(t *testing.T) {
opts := testOpts("abced", "testtest")
opts.RedirectURL = "https://example.com/oauth2/callback"
opts.Validate()
proxy, _ := NewAuthenticator(opts, func(p *Authenticator) error {
u, _ := url.Parse("http://example.com")
provider := providers.NewTestProvider(u)
proxy, _ := NewAuthenticator(opts, setTestProvider(provider), func(p *Authenticator) error {
p.Validator = func(string) bool { return true }
return nil
}, setMockCSRFStore(&sessions.MockCSRFStore{}))
Expand Down Expand Up @@ -1632,3 +1634,45 @@ func TestHostHeader(t *testing.T) {
})
}
}

func TestGoogleProviderApiSettings(t *testing.T) {
opts := testOpts("abced", "testtest")
opts.Provider = "google"
opts.Validate()
proxy, _ := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error {
p.Validator = func(string) bool { return true }
return nil
})
p := proxy.provider.Data()
testutil.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline",
p.SignInURL.String())
testutil.Equal(t, "https://www.googleapis.com/oauth2/v3/token",
p.RedeemURL.String())
testutil.Equal(t, "", p.ProfileURL.String())
testutil.Equal(t, "profile email", p.Scope)
}

func TestGoogleGroupInvalidFile(t *testing.T) {
opts := testOpts("abced", "testtest")
opts.Provider = "google"
opts.GoogleAdminEmail = "[email protected]"
opts.GoogleServiceAccountJSON = "file_doesnt_exist.json"
opts.Validate()
_, err := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error {
p.Validator = func(string) bool { return true }
return nil
})
testutil.NotEqual(t, nil, err)
testutil.Equal(t, "invalid Google credentials file: file_doesnt_exist.json", err.Error())
}

func TestUnimplementedProvider(t *testing.T) {
opts := testOpts("abced", "testtest")
opts.Validate()
_, err := NewAuthenticator(opts, AssignProvider(opts), func(p *Authenticator) error {
p.Validator = func(string) bool { return true }
return nil
})
testutil.NotEqual(t, nil, err)
testutil.Equal(t, "unimplemented provider: \"\"", err.Error())
}
91 changes: 59 additions & 32 deletions internal/auth/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ type Options struct {

// internal values that are set after config validation
redirectURL *url.URL
provider providers.Provider
decodedCookieSecret []byte
GroupsCacheStopFunc func()
}
Expand Down Expand Up @@ -170,7 +169,7 @@ func (o *Options) Validate() error {

o.redirectURL, msgs = parseURL(o.RedirectURL, "redirect", msgs)

msgs = parseProviderInfo(o, msgs)
msgs = validateEndpoints(o, msgs)

decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
if err != nil {
Expand Down Expand Up @@ -198,15 +197,6 @@ func (o *Options) Validate() error {
o.CookieExpire.String()))
}

if o.GoogleAdminEmail != "" || o.GoogleServiceAccountJSON != "" {
if o.GoogleAdminEmail == "" {
msgs = append(msgs, "missing setting: google-admin-email")
}
if o.GoogleServiceAccountJSON == "" {
msgs = append(msgs, "missing setting: google-service-account-json")
}
}

msgs = validateCookieName(o, msgs)

if o.StatsdHost == "" {
Expand All @@ -224,43 +214,80 @@ func (o *Options) Validate() error {
return nil
}

func parseProviderInfo(o *Options, msgs []string) []string {
func validateEndpoints(o *Options, msgs []string) []string {
_, msgs = parseURL(o.SignInURL, "signin", msgs)
_, msgs = parseURL(o.RedeemURL, "redeem", msgs)
_, msgs = parseURL(o.ProfileURL, "profile", msgs)
_, msgs = parseURL(o.ValidateURL, "validate", msgs)

return msgs
}

func validateCookieName(o *Options, msgs []string) []string {
cookie := &http.Cookie{Name: o.CookieName}
if cookie.String() == "" {
return append(msgs, fmt.Sprintf("invalid cookie name: %q", o.CookieName))
}
return msgs
}

func newProvider(o *Options) (providers.Provider, error) {
p := &providers.ProviderData{
Scope: o.Scope,
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
ApprovalPrompt: o.ApprovalPrompt,
SessionLifetimeTTL: o.SessionLifetimeTTL,
}
p.SignInURL, msgs = parseURL(o.SignInURL, "signin", msgs)
p.RedeemURL, msgs = parseURL(o.RedeemURL, "redeem", msgs)

var err error
if p.SignInURL, err = url.Parse(o.SignInURL); err != nil {
return nil, err
}
if p.RedeemURL, err = url.Parse(o.RedeemURL); err != nil {
return nil, err
}
p.RevokeURL = &url.URL{}
p.ProfileURL, msgs = parseURL(o.ProfileURL, "profile", msgs)
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
if p.ProfileURL, err = url.Parse(o.ProfileURL); err != nil {
return nil, err
}
if p.ValidateURL, err = url.Parse(o.ValidateURL); err != nil {
return nil, err
}

if o.GoogleServiceAccountJSON != "" {
_, err := os.Open(o.GoogleServiceAccountJSON)
var singleFlightProvider providers.Provider
switch o.Provider {
case providers.GoogleProviderName: // Google
if o.GoogleServiceAccountJSON != "" {
_, err := os.Open(o.GoogleServiceAccountJSON)
if err != nil {
return nil, fmt.Errorf("invalid Google credentials file: %s", o.GoogleServiceAccountJSON)
}
}
googleProvider, err := providers.NewGoogleProvider(p, o.GoogleAdminEmail, o.GoogleServiceAccountJSON)
if err != nil {
msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON)
return nil, err
}
cache := groups.NewFillCache(googleProvider.PopulateMembers, o.GroupsCacheRefreshTTL)
googleProvider.GroupsCache = cache
o.GroupsCacheStopFunc = cache.Stop
singleFlightProvider = providers.NewSingleFlightProvider(googleProvider)
default:
return nil, fmt.Errorf("unimplemented provider: %q", o.Provider)
}
googleProvider := providers.NewGoogleProvider(p, o.GoogleAdminEmail, o.GoogleServiceAccountJSON)
cache := groups.NewFillCache(googleProvider.PopulateMembers, o.GroupsCacheRefreshTTL)
googleProvider.GroupsCache = cache
o.GroupsCacheStopFunc = cache.Stop
singleFlightProvider := providers.NewSingleFlightProvider(googleProvider)

o.provider = singleFlightProvider

return msgs
return singleFlightProvider, nil
}

func validateCookieName(o *Options, msgs []string) []string {
cookie := &http.Cookie{Name: o.CookieName}
if cookie.String() == "" {
return append(msgs, fmt.Sprintf("invalid cookie name: %q", o.CookieName))
// AssignProvider is a function that takes an Options struct and assigns the
// appropriate provider to the proxy. Should be called prior to
// AssignStatsdClient.
func AssignProvider(opts *Options) func(*Authenticator) error {
return func(proxy *Authenticator) error {
var err error
proxy.provider, err = newProvider(opts)
return err
}
return msgs
}

// AssignStatsdClient is function that takes in an Options struct and assigns a statsd client
Expand Down
27 changes: 0 additions & 27 deletions internal/auth/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,6 @@ func TestNewOptions(t *testing.T) {
testutil.Equal(t, expected, err.Error())
}

func TestGoogleGroupInvalidFile(t *testing.T) {
defer func() {
r := recover()
panicMsg := "could not read google credentials file"
if r != panicMsg {
t.Errorf("expected panic with message %s but got %s", panicMsg, r)
}
}()
o := testOptions()
o.GoogleAdminEmail = "[email protected]"
o.GoogleServiceAccountJSON = "file_doesnt_exist.json"
o.Validate()

}

func TestInitializedOptions(t *testing.T) {
o := testOptions()
testutil.Equal(t, nil, o.Validate())
Expand All @@ -84,18 +69,6 @@ func TestRedirectURL(t *testing.T) {
testutil.Equal(t, expected, o.redirectURL)
}

func TestDefaultProviderApiSettings(t *testing.T) {
o := testOptions()
testutil.Equal(t, nil, o.Validate())
p := o.provider.Data()
testutil.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline",
p.SignInURL.String())
testutil.Equal(t, "https://www.googleapis.com/oauth2/v3/token",
p.RedeemURL.String())
testutil.Equal(t, "", p.ProfileURL.String())
testutil.Equal(t, "profile email", p.Scope)
}

func TestCookieRefreshMustBeLessThanCookieExpire(t *testing.T) {
o := testOptions()
testutil.Equal(t, nil, o.Validate())
Expand Down
15 changes: 12 additions & 3 deletions internal/auth/providers/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,16 @@ type GoogleProvider struct {
}

// NewGoogleProvider returns a new GoogleProvider and sets the provider url endpoints.
func NewGoogleProvider(p *ProviderData, adminEmail, credsFilePath string) *GoogleProvider {
func NewGoogleProvider(p *ProviderData, adminEmail, credsFilePath string) (*GoogleProvider, error) {
if adminEmail != "" || credsFilePath != "" {
if adminEmail == "" {
return nil, errors.New("missing setting: google-admin-email")
}
if credsFilePath == "" {
return nil, errors.New("missing setting: google-service-account-json")
}
}

p.ProviderName = "Google"
if p.SignInURL.String() == "" {
p.SignInURL = &url.URL{Scheme: "https",
Expand Down Expand Up @@ -77,14 +86,14 @@ func NewGoogleProvider(p *ProviderData, adminEmail, credsFilePath string) *Googl
if credsFilePath != "" {
credsReader, err := os.Open(credsFilePath)
if err != nil {
panic("could not read google credentials file")
return nil, errors.New("could not read google credentials file")
}
googleProvider.AdminService = &GoogleAdminService{
adminService: getAdminService(adminEmail, credsReader),
cb: googleProvider.cb,
}
}
return googleProvider
return googleProvider, nil
}

// SetStatsdClient sets the google provider and admin service statsd client
Expand Down
3 changes: 2 additions & 1 deletion internal/auth/providers/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ func newGoogleProvider(providerData *ProviderData) *GoogleProvider {
ValidateURL: &url.URL{},
Scope: ""}
}
return NewGoogleProvider(providerData, "", "")
provider, _ := NewGoogleProvider(providerData, "", "")
return provider
}

func TestGoogleProviderDefaults(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions internal/auth/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ var (
ErrServiceUnavailable = errors.New("SERVICE_UNAVAILABLE")
)

const (
// GoogleProviderName identifies the Google provider
GoogleProviderName = "google"
)

// Provider is an interface exposing functions necessary to authenticate with a given provider.
type Provider interface {
Data() *ProviderData
Expand Down

0 comments on commit d1f0d32

Please sign in to comment.