diff --git a/adapter.go b/adapter.go index 1d6a1f4..5fd34d3 100755 --- a/adapter.go +++ b/adapter.go @@ -219,7 +219,7 @@ func openDBConnection(driverName, dataSourceName string) (*gorm.DB, error) { //} else if driverName == "sqlite3" { // db, err = gorm.Open(sqlite.Open(dataSourceName), &gorm.Config{}) } else { - return nil, errors.New("Database dialect '"+driverName+"' is not supported. Supported databases are postgres, mysql and sqlserver") + return nil, errors.New("Database dialect '" + driverName + "' is not supported. Supported databases are postgres, mysql and sqlserver") } if err != nil { return nil, err @@ -640,3 +640,118 @@ func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [] } return tx.Commit().Error } + +func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) { + // UpdateFilteredPolicies deletes old rules and adds new rules. + line := a.getTableInstance() + + line.Ptype = ptype + if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) { + line.V0 = fieldValues[0-fieldIndex] + } + if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) { + line.V1 = fieldValues[1-fieldIndex] + } + if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) { + line.V2 = fieldValues[2-fieldIndex] + } + if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) { + line.V3 = fieldValues[3-fieldIndex] + } + if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) { + line.V4 = fieldValues[4-fieldIndex] + } + if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) { + line.V5 = fieldValues[5-fieldIndex] + } + + newP := make([]CasbinRule, 0, len(newPolicies)) + oldP := make([]CasbinRule, 0) + for _, newRule := range newPolicies { + newP = append(newP, a.savePolicyLine(ptype, newRule)) + } + + tx := a.db.Begin() + + for i := range newP { + str, args := line.queryString() + if err := tx.Where(str, args...).Find(&oldP).Error; err != nil { + tx.Rollback() + return nil, err + } + if err := tx.Where(str, args...).Delete([]CasbinRule{}).Error; err != nil { + tx.Rollback() + return nil, err + } + if err := tx.Create(&newP[i]).Error; err != nil { + tx.Rollback() + return nil, err + } + } + + // return deleted rulues + oldPolicies := make([][]string, 0) + for _, v := range oldP { + oldPolicy := v.toStringPolicy() + oldPolicies = append(oldPolicies, oldPolicy) + } + return oldPolicies, tx.Commit().Error +} + +func (c *CasbinRule) queryString() (interface{}, []interface{}) { + queryArgs := []interface{}{c.Ptype} + + queryStr := "ptype = ?" + if c.V0 != "" { + queryStr += " and v0 = ?" + queryArgs = append(queryArgs, c.V0) + } + if c.V1 != "" { + queryStr += " and v1 = ?" + queryArgs = append(queryArgs, c.V1) + } + if c.V2 != "" { + queryStr += " and v2 = ?" + queryArgs = append(queryArgs, c.V2) + } + if c.V3 != "" { + queryStr += " and v3 = ?" + queryArgs = append(queryArgs, c.V3) + } + if c.V4 != "" { + queryStr += " and v4 = ?" + queryArgs = append(queryArgs, c.V4) + } + if c.V5 != "" { + queryStr += " and v5 = ?" + queryArgs = append(queryArgs, c.V5) + } + + return queryStr, queryArgs +} + +func (c *CasbinRule) toStringPolicy() []string { + policy := make([]string, 0) + if c.Ptype != "" { + policy = append(policy, c.Ptype) + } + if c.V0 != "" { + policy = append(policy, c.V0) + } + if c.V1 != "" { + policy = append(policy, c.V1) + } + if c.V2 != "" { + policy = append(policy, c.V2) + } + if c.V3 != "" { + policy = append(policy, c.V3) + } + if c.V4 != "" { + policy = append(policy, c.V4) + } + if c.V5 != "" { + policy = append(policy, c.V5) + } + return policy +} diff --git a/adapter_test.go b/adapter_test.go index a863e56..f847308 100755 --- a/adapter_test.go +++ b/adapter_test.go @@ -38,6 +38,50 @@ func testGetPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) { } } +func testGetPolicyWithoutOrder(t *testing.T, e *casbin.Enforcer, res [][]string) { + myRes := e.GetPolicy() + log.Print("Policy: ", myRes) + + if !arrayEqualsWithoutOrder(myRes, res) { + t.Error("Policy: ", myRes, ", supposed to be ", res) + } +} + +func arrayEqualsWithoutOrder(a [][]string, b [][]string) bool { + if len(a) != len(b) { + return false + } + + mapA := make(map[int]string) + mapB := make(map[int]string) + order := make(map[int]struct{}) + l := len(a) + + for i := 0; i < l; i++ { + mapA[i] = util.ArrayToString(a[i]) + mapB[i] = util.ArrayToString(b[i]) + } + + for i := 0; i < l; i++ { + for j := 0; j < l; j++ { + if _, ok := order[j]; ok { + if j == l-1 { + return false + } else { + continue + } + } + if mapA[i] == mapB[j] { + order[j] = struct{}{} + break + } else if j == l-1 { + return false + } + } + } + return true +} + func initPolicy(t *testing.T, a *Adapter) { // Because the DB is empty at first, // so we need to load the policy from the file adapter (.CSV) first. @@ -256,6 +300,17 @@ func testUpdatePolicies(t *testing.T, a *Adapter) { testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) } +func testUpdateFilteredPolicies(t *testing.T, a *Adapter) { + // NewEnforcer() will load the policy automatically. + e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a) + + e.EnableAutoSave(true) + e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read") + e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write") + e.LoadPolicy() + testGetPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}}) +} + func TestAdapterWithCustomTable(t *testing.T) { db, err := gorm.Open(postgres.Open("user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable"), &gorm.Config{}) if err != nil { @@ -371,10 +426,12 @@ func TestAdapters(t *testing.T) { a = initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/", "casbin", "casbin_rule") testUpdatePolicy(t, a) testUpdatePolicies(t, a) + testUpdateFilteredPolicies(t, a) a = initAdapter(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable") testUpdatePolicy(t, a) testUpdatePolicies(t, a) + testUpdateFilteredPolicies(t, a) //a = initAdapter(t, "sqlite3", "casbin.db") //testUpdatePolicy(t, a) diff --git a/go.mod b/go.mod index 86aefde..b2ca239 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/casbin/gorm-adapter/v3 go 1.14 require ( - github.com/casbin/casbin/v2 v2.25.5 + github.com/casbin/casbin/v2 v2.28.3 github.com/go-sql-driver/mysql v1.5.0 github.com/jackc/pgconn v1.8.0 github.com/lib/pq v1.8.0 diff --git a/go.sum b/go.sum index 1e2c755..00c128a 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= -github.com/casbin/casbin/v2 v2.25.5 h1:TPKaoGu1gqAVJtQ2MaTfdHn2zgnCaulLylbNXbY6TYo= -github.com/casbin/casbin/v2 v2.25.5/go.mod h1:wUgota0cQbTXE6Vd+KWpg41726jFRi7upxio0sR+Xd0= +github.com/casbin/casbin/v2 v2.28.3 h1:iHxxEsNHwSciRoYh+54etVUA8AXKS9OKzNy6/39UWvY= +github.com/casbin/casbin/v2 v2.28.3/go.mod h1:vByNa/Fchek0KZUgG5wEsl7iFsiviAYKRtgrQfcJqHg= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=