From 7b9fb266670f85019450a9b139b80ccab872c660 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Thu, 15 Nov 2018 16:37:07 -0800 Subject: [PATCH] Fix group tests, add sign-in tests, and generate nonces --- internal/auth/providers/azure.go | 18 ++- internal/auth/providers/azure_test.go | 171 ++++++++++++++++---------- 2 files changed, 119 insertions(+), 70 deletions(-) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index cac4c638..1d24a145 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -2,6 +2,9 @@ package providers import ( "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" "errors" "fmt" "net/http" @@ -99,7 +102,6 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (s *sessions.SessionS // TODO: test this w/ an account that uses an alias and compare email claim // with UPN claim; UPN has usually been what you want, but I think it's not // rendered as a full email address here. - // FIXME: validate nonce against session s = &sessions.SessionState{ AccessToken: token.AccessToken, @@ -245,12 +247,24 @@ func (p *AzureV2Provider) GetSignInURL(redirectURI, state string) string { params.Add("scope", p.Scope) params.Add("state", state) params.Set("prompt", p.ApprovalPrompt) - params.Set("nonce", "FIXME") // FIXME, maybe change to session state struct + params.Set("nonce", p.calculateNonce(state)) // required parameter a.RawQuery = params.Encode() return a.String() } +// calculateNonce generates a deterministic nonce from the state value. +// We don't have a session state pointer but we need to generate a nonce +// that we can verify statelessly later. We can only use what's in the +// params and provider struct to assemble a nonce. State is guaranteed to be +// indistinguishable from random and will always change. +func (p *AzureV2Provider) calculateNonce(state string) string { + key := []byte(p.ClientID + p.ClientSecret) + h := hmac.New(sha256.New, key) + h.Write([]byte(state)) + return base64.URLEncoding.EncodeToString(h.Sum(nil))[:8] +} + // ValidateGroupMembership takes in an email and the allowed groups and returns the groups that the email is part of in that list. // If `allGroups` is an empty list it returns all the groups that the user belongs to. func (p *AzureV2Provider) ValidateGroupMembership(email string, allGroups []string) ([]string, error) { diff --git a/internal/auth/providers/azure_test.go b/internal/auth/providers/azure_test.go index 40dea165..df439a3e 100644 --- a/internal/auth/providers/azure_test.go +++ b/internal/auth/providers/azure_test.go @@ -15,7 +15,6 @@ import ( "testing" "time" - "github.com/buzzfeed/sso/internal/pkg/groups" "github.com/buzzfeed/sso/internal/pkg/sessions" "github.com/buzzfeed/sso/internal/pkg/testutil" @@ -369,96 +368,133 @@ func (c *groupsClientMock) Do(req *http.Request) (*http.Response, error) { return &http.Response{}, nil } -func TestAzureV2ValidateGroupMembers(t *testing.T) { +func TestAzureV2GetSignInURL(t *testing.T) { testCases := []struct { - name string - inputAllowedGroups []string - groups []string - groupsError error - getMembersFunc func(string) (groups.MemberSet, bool) - expectedGroups []string - expectedErrorString string + name string + redirectURI string + state string + expectedParams url.Values }{ { - name: "empty input groups should return an empty string", - inputAllowedGroups: []string{}, - groups: []string{"group1"}, - expectedGroups: []string{"group1"}, - getMembersFunc: func(string) (groups.MemberSet, bool) { return nil, false }, + name: "nonce values passed to azure should be deterministic, pass one", + redirectURI: "https://example.com/oauth/callback", + state: "1234", + expectedParams: url.Values{ + "redirect_uri": []string{"https://example.com/oauth/callback"}, + "response_mode": []string{"form_post"}, + "response_type": []string{"id_token code"}, + "scope": []string{"openid email profile offline_access"}, + "state": []string{"1234"}, + "client_id": []string{TestClientID}, + "nonce": []string{"KEB9Aopa"}, + "prompt": []string{"consent"}, + }, }, { - name: "empty inputs and error on groups resource should return error", - inputAllowedGroups: []string{}, - getMembersFunc: func(string) (groups.MemberSet, bool) { return nil, false }, - groupsError: fmt.Errorf("error"), - expectedErrorString: "error", + name: "nonce values passed to azure should be deterministic, pass two", + redirectURI: "https://example.com/oauth/callback", + state: "1234", + expectedParams: url.Values{ + "redirect_uri": []string{"https://example.com/oauth/callback"}, + "response_mode": []string{"form_post"}, + "response_type": []string{"id_token code"}, + "scope": []string{"openid email profile offline_access"}, + "state": []string{"1234"}, + "client_id": []string{TestClientID}, + "nonce": []string{"KEB9Aopa"}, + "prompt": []string{"consent"}, + }, }, { - name: "member exists in cache, should not call groups resource", - inputAllowedGroups: []string{"group1"}, - groupsError: fmt.Errorf("should not get here"), - getMembersFunc: func(string) (groups.MemberSet, bool) { return groups.MemberSet{"email": {}}, true }, - expectedGroups: []string{"group1"}, + name: "nonce values passed to azure should be deterministic, pass three", + redirectURI: "https://example.com/oauth/callback", + state: "4321", + expectedParams: url.Values{ + "redirect_uri": []string{"https://example.com/oauth/callback"}, + "response_mode": []string{"form_post"}, + "response_type": []string{"id_token code"}, + "scope": []string{"openid email profile offline_access"}, + "state": []string{"4321"}, + "client_id": []string{TestClientID}, + "nonce": []string{"x_PhEN0K"}, + "prompt": []string{"consent"}, + }, }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p := newAzureV2Provider(nil) + p.ClientID = TestClientID + p.ClientSecret = "456" + p.Scope = "openid email profile offline_access" + p.ApprovalPrompt = "consent" + + signInURL := p.GetSignInURL(tc.redirectURI, tc.state) + parsedURL, err := url.Parse(signInURL) + if err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(tc.expectedParams, parsedURL.Query()) { + t.Logf("expected params %+v", tc.expectedParams) + t.Logf("got params %+v", parsedURL.Query()) + t.Errorf("unexpected params returned") + } + }) + } +} + +func TestAzureV2ValidateGroupMembers(t *testing.T) { + testCases := []struct { + name string + allowedGroups []string + mockedGroups []string + mockedError error + expectedGroups []string + expectedErrorString string + }{ { - name: "member does not exist in cache, should still not call groups resource", - inputAllowedGroups: []string{"group1"}, - groupsError: fmt.Errorf("should not get here"), - getMembersFunc: func(string) (groups.MemberSet, bool) { return groups.MemberSet{}, true }, - expectedGroups: []string{}, + name: "allowed groups and groups resource output exactly match should return all groups", + allowedGroups: []string{"group1", "group2", "group3"}, + mockedGroups: []string{"group1", "group2", "group3"}, + expectedGroups: []string{"group1", "group2", "group3"}, }, { - name: "subset of groups are not cached, calls groups resource", - inputAllowedGroups: []string{"group1", "group2"}, - groups: []string{"group1", "group2", "group3"}, - groupsError: nil, - getMembersFunc: func(group string) (groups.MemberSet, bool) { - switch group { - case "group1": - return groups.MemberSet{"email": {}}, true - default: - return groups.MemberSet{}, false - } - }, + name: "allowed groups should restrict to subset of groups", + allowedGroups: []string{"group1", "group2"}, + mockedGroups: []string{"group1", "group2", "group3"}, expectedGroups: []string{"group1", "group2"}, }, { - name: "subset of groups are not cached, calls groups resource with error", - inputAllowedGroups: []string{"group1", "group2"}, - groupsError: fmt.Errorf("error"), - getMembersFunc: func(group string) (groups.MemberSet, bool) { - switch group { - case "group1": - return groups.MemberSet{"email": {}}, true - default: - return groups.MemberSet{}, false - } - }, - expectedErrorString: "error", + name: "allowed groups superset should not restrict to subset of groups", + allowedGroups: []string{"group1", "group2", "group3"}, + mockedGroups: []string{"group1", "group2"}, + expectedGroups: []string{"group1", "group2"}, }, { - name: "subset of groups not there, does not call groups resource", - inputAllowedGroups: []string{"group1", "group2"}, - groups: []string{"group1", "group2", "group3"}, - groupsError: fmt.Errorf("should not get here"), - getMembersFunc: func(group string) (groups.MemberSet, bool) { - switch group { - case "group1": - return groups.MemberSet{"email": {}}, true - default: - return groups.MemberSet{}, true - } - }, + name: "groups allowed zero value should default to return all groups", + allowedGroups: []string{}, + mockedGroups: []string{"group1"}, expectedGroups: []string{"group1"}, }, + { + name: "empty inputs and error on groups resource should return error", + allowedGroups: []string{}, + mockedError: fmt.Errorf("error"), + expectedErrorString: "error", + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { p := newAzureV2Provider(nil) - p.GraphService = &MockAzureGraphService{Groups: tc.groups, GroupsError: tc.groupsError} + p.GraphService = &MockAzureGraphService{ + Groups: tc.mockedGroups, + GroupsError: tc.mockedError, + } - groups, err := p.ValidateGroupMembership("email", tc.inputAllowedGroups) + groups, err := p.ValidateGroupMembership("test@example.com", tc.allowedGroups) if err != nil { if tc.expectedErrorString != err.Error() { @@ -470,7 +506,6 @@ func TestAzureV2ValidateGroupMembers(t *testing.T) { t.Logf("got groups %v", groups) t.Errorf("unexpected groups returned") } - }) } }