Skip to content

Commit

Permalink
fix(UpdatableAdapter(add method UpdateFilteredPolicies)): upgrade cab… (
Browse files Browse the repository at this point in the history
casbin#103)

* fix(UpdatableAdapter(add method UpdateFilteredPolicies)): upgrade cabin/v2 to v2.28.3, add the new method UpdateFilteredPolicies, otherwise e.UpdatePolicy(...) will cause panic.

Signed-off-by: gaozhihui <[email protected]>

* fix(adapter_test):

in the test case, after e.UpdateFilteredPlicies, the order of policies maybe change, so add a new func to fix it.
  • Loading branch information
adsian authored Apr 26, 2021
1 parent 53e10dc commit 40a8c16
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 4 deletions.
117 changes: 116 additions & 1 deletion adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func openDBConnection(driverName, dataSourceName string) (*gorm.DB, error) {
//} else if driverName == "sqlite3" {
// db, err = gorm.Open(sqlite.Open(dataSourceName), &gorm.Config{})
} else {
return nil, errors.New("Database dialect '"+driverName+"' is not supported. Supported databases are postgres, mysql and sqlserver")
return nil, errors.New("Database dialect '" + driverName + "' is not supported. Supported databases are postgres, mysql and sqlserver")
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -640,3 +640,118 @@ func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules []
}
return tx.Commit().Error
}

func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) {
// UpdateFilteredPolicies deletes old rules and adds new rules.
line := a.getTableInstance()

line.Ptype = ptype
if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
line.V0 = fieldValues[0-fieldIndex]
}
if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
line.V1 = fieldValues[1-fieldIndex]
}
if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
line.V2 = fieldValues[2-fieldIndex]
}
if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
line.V3 = fieldValues[3-fieldIndex]
}
if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
line.V4 = fieldValues[4-fieldIndex]
}
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
line.V5 = fieldValues[5-fieldIndex]
}

newP := make([]CasbinRule, 0, len(newPolicies))
oldP := make([]CasbinRule, 0)
for _, newRule := range newPolicies {
newP = append(newP, a.savePolicyLine(ptype, newRule))
}

tx := a.db.Begin()

for i := range newP {
str, args := line.queryString()
if err := tx.Where(str, args...).Find(&oldP).Error; err != nil {
tx.Rollback()
return nil, err
}
if err := tx.Where(str, args...).Delete([]CasbinRule{}).Error; err != nil {
tx.Rollback()
return nil, err
}
if err := tx.Create(&newP[i]).Error; err != nil {
tx.Rollback()
return nil, err
}
}

// return deleted rulues
oldPolicies := make([][]string, 0)
for _, v := range oldP {
oldPolicy := v.toStringPolicy()
oldPolicies = append(oldPolicies, oldPolicy)
}
return oldPolicies, tx.Commit().Error
}

func (c *CasbinRule) queryString() (interface{}, []interface{}) {
queryArgs := []interface{}{c.Ptype}

queryStr := "ptype = ?"
if c.V0 != "" {
queryStr += " and v0 = ?"
queryArgs = append(queryArgs, c.V0)
}
if c.V1 != "" {
queryStr += " and v1 = ?"
queryArgs = append(queryArgs, c.V1)
}
if c.V2 != "" {
queryStr += " and v2 = ?"
queryArgs = append(queryArgs, c.V2)
}
if c.V3 != "" {
queryStr += " and v3 = ?"
queryArgs = append(queryArgs, c.V3)
}
if c.V4 != "" {
queryStr += " and v4 = ?"
queryArgs = append(queryArgs, c.V4)
}
if c.V5 != "" {
queryStr += " and v5 = ?"
queryArgs = append(queryArgs, c.V5)
}

return queryStr, queryArgs
}

func (c *CasbinRule) toStringPolicy() []string {
policy := make([]string, 0)
if c.Ptype != "" {
policy = append(policy, c.Ptype)
}
if c.V0 != "" {
policy = append(policy, c.V0)
}
if c.V1 != "" {
policy = append(policy, c.V1)
}
if c.V2 != "" {
policy = append(policy, c.V2)
}
if c.V3 != "" {
policy = append(policy, c.V3)
}
if c.V4 != "" {
policy = append(policy, c.V4)
}
if c.V5 != "" {
policy = append(policy, c.V5)
}
return policy
}
57 changes: 57 additions & 0 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,50 @@ func testGetPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
}
}

func testGetPolicyWithoutOrder(t *testing.T, e *casbin.Enforcer, res [][]string) {
myRes := e.GetPolicy()
log.Print("Policy: ", myRes)

if !arrayEqualsWithoutOrder(myRes, res) {
t.Error("Policy: ", myRes, ", supposed to be ", res)
}
}

func arrayEqualsWithoutOrder(a [][]string, b [][]string) bool {
if len(a) != len(b) {
return false
}

mapA := make(map[int]string)
mapB := make(map[int]string)
order := make(map[int]struct{})
l := len(a)

for i := 0; i < l; i++ {
mapA[i] = util.ArrayToString(a[i])
mapB[i] = util.ArrayToString(b[i])
}

for i := 0; i < l; i++ {
for j := 0; j < l; j++ {
if _, ok := order[j]; ok {
if j == l-1 {
return false
} else {
continue
}
}
if mapA[i] == mapB[j] {
order[j] = struct{}{}
break
} else if j == l-1 {
return false
}
}
}
return true
}

func initPolicy(t *testing.T, a *Adapter) {
// Because the DB is empty at first,
// so we need to load the policy from the file adapter (.CSV) first.
Expand Down Expand Up @@ -256,6 +300,17 @@ func testUpdatePolicies(t *testing.T, a *Adapter) {
testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}

func testUpdateFilteredPolicies(t *testing.T, a *Adapter) {
// NewEnforcer() will load the policy automatically.
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)

e.EnableAutoSave(true)
e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read")
e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write")
e.LoadPolicy()
testGetPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}})
}

func TestAdapterWithCustomTable(t *testing.T) {
db, err := gorm.Open(postgres.Open("user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable"), &gorm.Config{})
if err != nil {
Expand Down Expand Up @@ -371,10 +426,12 @@ func TestAdapters(t *testing.T) {
a = initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/", "casbin", "casbin_rule")
testUpdatePolicy(t, a)
testUpdatePolicies(t, a)
testUpdateFilteredPolicies(t, a)

a = initAdapter(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable")
testUpdatePolicy(t, a)
testUpdatePolicies(t, a)
testUpdateFilteredPolicies(t, a)

//a = initAdapter(t, "sqlite3", "casbin.db")
//testUpdatePolicy(t, a)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/casbin/gorm-adapter/v3
go 1.14

require (
github.com/casbin/casbin/v2 v2.25.5
github.com/casbin/casbin/v2 v2.28.3
github.com/go-sql-driver/mysql v1.5.0
github.com/jackc/pgconn v1.8.0
github.com/lib/pq v1.8.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw=
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0=
github.com/casbin/casbin/v2 v2.25.5 h1:TPKaoGu1gqAVJtQ2MaTfdHn2zgnCaulLylbNXbY6TYo=
github.com/casbin/casbin/v2 v2.25.5/go.mod h1:wUgota0cQbTXE6Vd+KWpg41726jFRi7upxio0sR+Xd0=
github.com/casbin/casbin/v2 v2.28.3 h1:iHxxEsNHwSciRoYh+54etVUA8AXKS9OKzNy6/39UWvY=
github.com/casbin/casbin/v2 v2.28.3/go.mod h1:vByNa/Fchek0KZUgG5wEsl7iFsiviAYKRtgrQfcJqHg=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
Expand Down

0 comments on commit 40a8c16

Please sign in to comment.