Skip to content

Commit

Permalink
TT-9366 added new field groupsIds in SSOAccessData (#428)
Browse files Browse the repository at this point in the history
* added new field groupsIds in SSOAccessData

* goimports -w

---------

Co-authored-by: sredny buitrago <[email protected]>
  • Loading branch information
sredxny and sredny buitrago authored Jan 15, 2025
1 parent 97d9d63 commit c0e0c06
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 40 deletions.
61 changes: 43 additions & 18 deletions tap/identity-handlers/tyk_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type SSOAccessData struct {
EmailAddress string
DisplayName string
GroupID string
GroupsIDs []string
SSOOnlyForRegisteredUsers bool
}

Expand Down Expand Up @@ -157,7 +158,8 @@ func (t *TykIdentityHandler) CreateIdentity(i interface{}) (string, error) {
gUser, ok := i.(goth.User)
email := ""
displayName := ""
groupID := ""
var groupsIDs []string
var groupID string
if ok {
email = GetEmail(gUser, t.profile.CustomEmailField)

Expand All @@ -174,16 +176,18 @@ func (t *TykIdentityHandler) CreateIdentity(i interface{}) (string, error) {
displayName = email
}

groupID = GetGroupId(gUser, t.profile.CustomUserGroupField, t.profile.DefaultUserGroupID, t.profile.UserGroupMapping, t.profile.UserGroupSeparator)
groupsIDs = GetGroupId(gUser, t.profile.CustomUserGroupField, t.profile.DefaultUserGroupID, t.profile.UserGroupMapping, t.profile.UserGroupSeparator)
if len(groupsIDs) > 0 {
groupID = groupsIDs[0]
}
}

tykHandlerLogger.Debugf("The GroupID %s is used for SSO: ", groupID)

tykHandlerLogger.Debugf("The GroupIDs %s used for SSO: ", groupsIDs)
accessRequest := SSOAccessData{
ForSection: thisModule,
OrgID: t.profile.OrgID,
EmailAddress: email,
DisplayName: displayName,
GroupsIDs: groupsIDs,
GroupID: groupID,
SSOOnlyForRegisteredUsers: t.profile.SSOOnlyForRegisteredUsers,
}
Expand Down Expand Up @@ -537,20 +541,41 @@ func GetUserID(gUser goth.User, CustomUserIDField string) string {
return gUser.UserID
}

func GetGroupId(gUser goth.User, CustomUserGroupField, DefaultUserGroup string, userGroupMapping map[string]string, userGroupSeparator string) string {
groupID := DefaultUserGroup
if CustomUserGroupField != "" {
groups := make([]string, 0)
if gUser.RawData[CustomUserGroupField] != nil {
groups = groupsStringer(gUser.RawData[CustomUserGroupField], userGroupSeparator)
}
// Helper function to return either [DefaultUserGroup] or an empty slice
func defaultOrEmptyGroupIDs(DefaultUserGroup string) []string {
if DefaultUserGroup == "" {
return []string{}
}
return []string{DefaultUserGroup}
}

for _, group := range groups {
if gid, ok := userGroupMapping[group]; ok {
groupID = gid
log.Debug(groupID)
}
func GetGroupId(gUser goth.User, CustomUserGroupField, DefaultUserGroup string, userGroupMapping map[string]string, userGroupSeparator string) []string {
if CustomUserGroupField == "" {
return defaultOrEmptyGroupIDs(DefaultUserGroup)
}

rawGroups, exists := gUser.RawData[CustomUserGroupField]
if !exists || rawGroups == nil {
return defaultOrEmptyGroupIDs(DefaultUserGroup)
}

groups := groupsStringer(gUser.RawData[CustomUserGroupField], userGroupSeparator)
if len(groups) == 0 {
return defaultOrEmptyGroupIDs(DefaultUserGroup)
}

var groupsIDs []string

// if empty then return []string{defaultUserGroup}
for _, group := range groups {
if gid, ok := userGroupMapping[group]; ok {
groupsIDs = append(groupsIDs, gid)
}
}
return groupID

if len(groupsIDs) == 0 {
return defaultOrEmptyGroupIDs(DefaultUserGroup)
}

return groupsIDs
}
80 changes: 58 additions & 22 deletions tap/identity-handlers/tyk_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package identityHandlers
import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/markbates/goth"
)

Expand Down Expand Up @@ -120,7 +122,7 @@ func TestGetGroupId(t *testing.T) {
TestName string
CustomGroupIDField string
user goth.User
ExpectedGroupID string
ExpectedGroupsIDs []string
DefaultGroupID string
UserGroupMapping map[string]string
UserGroupSeparator string
Expand All @@ -129,31 +131,31 @@ func TestGetGroupId(t *testing.T) {
TestName: "Custom group id field empty",
CustomGroupIDField: "",
user: goth.User{},
ExpectedGroupID: "",
ExpectedGroupsIDs: []string{},
DefaultGroupID: "",
UserGroupMapping: UserGroupMapping,
},
{
TestName: "Custom group id field empty & default group set",
CustomGroupIDField: "",
user: goth.User{},
ExpectedGroupID: DefaultGroupId,
ExpectedGroupsIDs: []string{DefaultGroupId},
DefaultGroupID: DefaultGroupId,
UserGroupMapping: UserGroupMapping,
},
{
TestName: "Custom group id field not empty but invalid & default group set",
CustomGroupIDField: "my-custom-group-id-field",
user: goth.User{},
ExpectedGroupID: DefaultGroupId,
DefaultGroupID: DefaultGroupId,
ExpectedGroupsIDs: []string{DefaultGroupId},
UserGroupMapping: UserGroupMapping,
},
{
TestName: "Custom group id field not empty but invalid & default group not set",
CustomGroupIDField: "my-custom-group-id-field",
user: goth.User{},
ExpectedGroupID: "",
ExpectedGroupsIDs: []string{},
DefaultGroupID: "",
UserGroupMapping: UserGroupMapping,
},
Expand All @@ -165,9 +167,9 @@ func TestGetGroupId(t *testing.T) {
"my-custom-group-id-field": "admins",
},
},
ExpectedGroupID: "admins-group",
DefaultGroupID: "",
UserGroupMapping: UserGroupMapping,
ExpectedGroupsIDs: []string{"admins-group"},
DefaultGroupID: "",
UserGroupMapping: UserGroupMapping,
},
{
TestName: "Receive many groups from idp with blank space separated",
Expand All @@ -177,9 +179,9 @@ func TestGetGroupId(t *testing.T) {
"my-custom-group-id-field": "devs admins",
},
},
ExpectedGroupID: "admins-group",
DefaultGroupID: "",
UserGroupMapping: UserGroupMapping,
ExpectedGroupsIDs: []string{"devs-group", "admins-group"},
DefaultGroupID: "",
UserGroupMapping: UserGroupMapping,
},
{
TestName: "Receive many groups from idp with comma separated",
Expand All @@ -189,7 +191,7 @@ func TestGetGroupId(t *testing.T) {
"my-custom-group-id-field": "devs,admins",
},
},
ExpectedGroupID: "admins-group",
ExpectedGroupsIDs: []string{"devs-group", "admins-group"},
DefaultGroupID: "",
UserGroupMapping: UserGroupMapping,
UserGroupSeparator: ",",
Expand All @@ -202,9 +204,9 @@ func TestGetGroupId(t *testing.T) {
"my-custom-group-id-field": "admins",
},
},
ExpectedGroupID: "admins-group",
DefaultGroupID: "devs",
UserGroupMapping: UserGroupMapping,
ExpectedGroupsIDs: []string{"admins-group"},
DefaultGroupID: "devs",
UserGroupMapping: UserGroupMapping,
},
{
TestName: "Custom group id field not empty, and the claim being an array",
Expand All @@ -218,18 +220,52 @@ func TestGetGroupId(t *testing.T) {
"CN=Normal Group,OU=Security Groups,OU=GenericOrg,DC=GenericOrg,DC=COM,DC=GEN",
},
}},
ExpectedGroupID: "tyk-admin",
DefaultGroupID: "devs",
UserGroupMapping: UserGroupMapping,
ExpectedGroupsIDs: []string{"tyk-admin"},
DefaultGroupID: "devs",
UserGroupMapping: UserGroupMapping,
},
}

for _, tc := range cases {
t.Run(tc.TestName, func(t *testing.T) {
id := GetGroupId(tc.user, tc.CustomGroupIDField, tc.DefaultGroupID, tc.UserGroupMapping, tc.UserGroupSeparator)
if id != tc.ExpectedGroupID {
t.Errorf("group id incorrect. Expected:%v got:%v", tc.ExpectedGroupID, id)
}
ids := GetGroupId(tc.user, tc.CustomGroupIDField, tc.DefaultGroupID, tc.UserGroupMapping, tc.UserGroupSeparator)
assert.Equal(t, tc.ExpectedGroupsIDs, ids)
})
}
}

func Test_defaultOrEmptyGroupIDs(t *testing.T) {
tests := []struct {
name string
defaultUserGroup string
expectedGroupIDs []string
}{
{
name: "Empty default user group",
defaultUserGroup: "",
expectedGroupIDs: []string{},
},
{
name: "Non-empty default user group",
defaultUserGroup: "defaultGroup",
expectedGroupIDs: []string{"defaultGroup"},
},
{
name: "Default user group with spaces",
defaultUserGroup: "default group",
expectedGroupIDs: []string{"default group"},
},
{
name: "Default user group with special characters",
defaultUserGroup: "group@123",
expectedGroupIDs: []string{"group@123"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := defaultOrEmptyGroupIDs(tt.defaultUserGroup)
assert.Equal(t, tt.expectedGroupIDs, result, "The group IDs should match")
})
}
}

0 comments on commit c0e0c06

Please sign in to comment.