diff --git a/role_manager.go b/role_manager.go index b320c54..faf2d4f 100644 --- a/role_manager.go +++ b/role_manager.go @@ -15,6 +15,7 @@ package sessionrolemanager import ( + "errors" "sort" "github.com/casbin/casbin/rbac" @@ -48,16 +49,17 @@ func (rm *RoleManager) createRole(name string) *SessionRole { } // Clear clears all stored data and resets the role manager to the initial state. -func (rm *RoleManager) Clear() { +func (rm *RoleManager) Clear() error { rm.allRoles = make(map[string]*SessionRole) + return nil } // AddLink adds the inheritance link between role: name1 and role: name2. // aka role: name1 inherits role: name2. // timeRange is the time range when the role inheritance link is active. -func (rm *RoleManager) AddLink(name1 string, name2 string, timeRange ...string) { +func (rm *RoleManager) AddLink(name1 string, name2 string, timeRange ...string) error { if len(timeRange) != 2 { - return + return errors.New("error: timeRange should be 2 parameters") } startTime := timeRange[0] endTime := timeRange[1] @@ -67,62 +69,64 @@ func (rm *RoleManager) AddLink(name1 string, name2 string, timeRange ...string) session := Session{role2, startTime, endTime} role1.addSession(session) + return nil } // DeleteLink deletes the inheritance link between role: name1 and role: name2. // aka role: name1 does not inherit role: name2 any more. // unused is not used. -func (rm *RoleManager) DeleteLink(name1 string, name2 string, unused ...string) { +func (rm *RoleManager) DeleteLink(name1 string, name2 string, unused ...string) error { if !rm.hasRole(name1) || !rm.hasRole(name2) { - return + return errors.New("error: name1 or name2 does not exist") } role1 := rm.createRole(name1) role2 := rm.createRole(name2) role1.deleteSessions(role2.name) + return nil } // HasLink determines whether role: name1 inherits role: name2. // requestTime is the querying time for the role inheritance link. -func (rm *RoleManager) HasLink(name1 string, name2 string, requestTime ...string) bool { +func (rm *RoleManager) HasLink(name1 string, name2 string, requestTime ...string) (bool, error) { if len(requestTime) != 1 { - return false + return false, errors.New("error: requestTime should be 1 parameter") } if name1 == name2 { - return true + return true, nil } if !rm.hasRole(name1) || !rm.hasRole(name2) { - return false + return false, nil } role1 := rm.createRole(name1) - return role1.hasValidSession(name2, rm.maxHierarchyLevel, requestTime[0]) + return role1.hasValidSession(name2, rm.maxHierarchyLevel, requestTime[0]), nil } // GetRoles gets the roles that a subject inherits. // currentTime is the querying time for the role inheritance link. -func (rm *RoleManager) GetRoles(name string, currentTime ...string) []string { +func (rm *RoleManager) GetRoles(name string, currentTime ...string) ([]string, error) { if len(currentTime) != 1 { - return nil + return nil, errors.New("error: currentTime should be 1 parameter") } requestTime := currentTime[0] if !rm.hasRole(name) { - return nil + return nil, errors.New("error: name does not exist") } sessionRoles := rm.createRole(name).getSessionRoles(requestTime) - return sessionRoles + return sessionRoles, nil } // GetUsers gets the users that inherits a subject. // currentTime is the querying time for the role inheritance link. -func (rm *RoleManager) GetUsers(name string, currentTime ...string) []string { +func (rm *RoleManager) GetUsers(name string, currentTime ...string) ([]string, error) { if len(currentTime) != 1 { - return nil + return nil, errors.New("error: currentTime should be 1 parameter") } requestTime := currentTime[0] @@ -133,14 +137,15 @@ func (rm *RoleManager) GetUsers(name string, currentTime ...string) []string { } } sort.Strings(users) - return users + return users, nil } // PrintRoles prints all the roles to log. -func (rm *RoleManager) PrintRoles() { +func (rm *RoleManager) PrintRoles() error { for _, role := range rm.allRoles { util.LogPrint(role.toString()) } + return nil } // SessionRole is a modified version of the default role. diff --git a/role_manager_test.go b/role_manager_test.go index e1f46b8..3e0b725 100644 --- a/role_manager_test.go +++ b/role_manager_test.go @@ -20,7 +20,7 @@ import ( "time" "github.com/casbin/casbin" - "github.com/casbin/casbin/file-adapter" + "github.com/casbin/casbin/persist/file-adapter" "github.com/casbin/casbin/rbac" "github.com/casbin/casbin/util" ) @@ -34,7 +34,7 @@ func testEnforce(t *testing.T, e *casbin.Enforcer, sub string, obj interface{}, func testSessionRole(t *testing.T, rm rbac.RoleManager, name1 string, name2 string, requestTime string, res bool) { t.Helper() - myRes := rm.HasLink(name1, name2, requestTime) + myRes, _ := rm.HasLink(name1, name2, requestTime) if myRes != res { t.Errorf("%s < %s at time %s: %t, supposed to be %t", name1, name2, requestTime, !res, res) @@ -43,7 +43,7 @@ func testSessionRole(t *testing.T, rm rbac.RoleManager, name1 string, name2 stri func testPrintSessionRoles(t *testing.T, rm rbac.RoleManager, name1 string, requestTime string, res []string) { t.Helper() - myRes := rm.GetRoles(name1, requestTime) + myRes, _ := rm.GetRoles(name1, requestTime) if !util.ArrayEquals(myRes, res) { t.Errorf("%s should have the roles %s at time %s, but has: %s", name1, res, requestTime, myRes) @@ -137,18 +137,18 @@ func TestHasLink(t *testing.T) { alpha := "alpha" bravo := "bravo" - if rm.HasLink(alpha, bravo) { + if res, _ := rm.HasLink(alpha, bravo); res { t.Errorf("Role manager should not have link %s < %s", alpha, bravo) } - if !rm.HasLink(alpha, alpha, getCurrentTime()) { + if res, _ := rm.HasLink(alpha, alpha, getCurrentTime()); !res { t.Errorf("Role manager should have link %s < %s", alpha, alpha) } - if rm.HasLink(alpha, bravo, getCurrentTime()) { + if res, _ := rm.HasLink(alpha, bravo, getCurrentTime()); res { t.Errorf("Role manager should not have link %s < %s", alpha, bravo) } rm.AddLink(alpha, bravo, getCurrentTime(), getInOneHour()) - if !rm.HasLink(alpha, bravo, getCurrentTime()) { + if res, _ := rm.HasLink(alpha, bravo, getCurrentTime()); !res { t.Errorf("Role manager should have link %s < %s", alpha, bravo) } } @@ -163,14 +163,14 @@ func TestDeleteLink(t *testing.T) { rm.AddLink(alpha, charlie, getOneHourAgo(), getInOneHour()) rm.DeleteLink(alpha, bravo) - if rm.HasLink(alpha, bravo, getCurrentTime()) { + if res, _ := rm.HasLink(alpha, bravo, getCurrentTime()); res { t.Errorf("Role manager should not have link %s < %s", alpha, bravo) } rm.DeleteLink(alpha, "delta") rm.DeleteLink(bravo, charlie) - if !rm.HasLink(alpha, charlie, getCurrentTime()) { + if res, _ := rm.HasLink(alpha, charlie, getCurrentTime()); !res { t.Errorf("Role manager should have link %s < %s", alpha, charlie) } } @@ -180,7 +180,7 @@ func TestHierarchieLevel(t *testing.T) { rm.AddLink("alpha", "bravo", getOneHourAgo(), getInOneHour()) rm.AddLink("bravo", "charlie", getOneHourAgo(), getInOneHour()) - if rm.HasLink("alpha", "charlie", getCurrentTime()) { + if res, _ := rm.HasLink("alpha", "charlie", getCurrentTime()); res { t.Error("Role manager should not have link alpha < charlie") } } @@ -191,10 +191,10 @@ func TestOutdatedSessions(t *testing.T) { rm.AddLink("alpha", "bravo", getOneHourAgo(), getCurrentTime()) rm.AddLink("bravo", "charlie", getOneHourAgo(), getInOneHour()) - if rm.HasLink("alpha", "bravo", getInOneHour()) { + if res, _ := rm.HasLink("alpha", "bravo", getInOneHour()); res { t.Error("Role manager should not have link alpha < bravo") } - if !rm.HasLink("alpha", "charlie", getOneHourAgo()) { + if res, _ := rm.HasLink("alpha", "charlie", getOneHourAgo()); !res { t.Error("Role manager should have link alpha < charlie") } } @@ -202,11 +202,11 @@ func TestOutdatedSessions(t *testing.T) { func TestGetRoles(t *testing.T) { rm := NewRoleManager(3) - if rm.GetRoles("alpha") != nil { + if res, _ := rm.GetRoles("alpha"); res != nil { t.Error("Calling GetRoles without a time should return nil.") } - if rm.GetRoles("bravo", getCurrentTime()) != nil { + if res, _ := rm.GetRoles("bravo", getCurrentTime()); res != nil { t.Error("bravo should not exist") } @@ -236,27 +236,27 @@ func TestGetUsers(t *testing.T) { rm.AddLink("charlie", "alpha", getOneHourAgo(), getInOneHour()) rm.AddLink("delta", "alpha", getOneHourAgo(), getInOneHour()) - myRes := rm.GetUsers("alpha") + myRes, _ := rm.GetUsers("alpha") if myRes != nil { t.Errorf("Calling GetUsers without a time should return nil.") } - myRes = rm.GetUsers("alpha", getCurrentTime()) + myRes, _ = rm.GetUsers("alpha", getCurrentTime()) if !util.ArrayEquals(myRes, []string{"bravo", "charlie", "delta"}) { t.Errorf("Alpha should have the using roles [bravo charlie delta] but has %s", myRes) } - myRes = rm.GetUsers("alpha", getOneHourAgo()) + myRes, _ = rm.GetUsers("alpha", getOneHourAgo()) if !util.ArrayEquals(myRes, []string{"bravo", "charlie", "delta"}) { t.Errorf("Alpha should have the using roles [bravo charlie delta] but has %s", myRes) } - myRes = rm.GetUsers("alpha", getAfterOneHour()) + myRes, _ = rm.GetUsers("alpha", getAfterOneHour()) if !util.ArrayEquals(myRes, []string{}) { t.Errorf("Alpha should not have any using roles but has %s", myRes) } - myRes = rm.GetUsers("bravo", getCurrentTime()) + myRes, _ = rm.GetUsers("bravo", getCurrentTime()) if !util.ArrayEquals(myRes, []string{}) { t.Error("bravo should have no using roles") }