Skip to content

Commit

Permalink
feat: dynamic collection name
Browse files Browse the repository at this point in the history
Signed-off-by: 0xb4lamx <[email protected]>
  • Loading branch information
0xb4lamx committed Mar 23, 2021
1 parent 6e29879 commit 889de6f
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 22 deletions.
43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,49 @@ func main() {
e.SavePolicy()
}
```
## Advanced Example

```go
package main

import (
"github.com/casbin/casbin/v2"
"github.com/casbin/mongodb-adapter/v3"
mongooptions "go.mongodb.org/mongo-driver/mongo/options"
)

func main() {
// Initialize a MongoDB adapter with NewAdapterWithClientOption:
// The adapter will use custom mongo client options.
// custom database name.
// default collection name 'casbin_rule'.
mongoClientOption := mongooptions.Client().ApplyURI("mongodb://127.0.0.1:27017")
databaseName := "casbin"
a,err := mongodbadapter.NewAdapterWithClientOption(mongoClientOption, databaseName)
// Or you can use NewAdapterWithCollectionName for custom collection name.
if err != nil {
panic(err)
}

e, err := casbin.NewEnforcer("examples/rbac_model.conf", a)
if err != nil {
panic(err)
}

// Load the policy from DB.
e.LoadPolicy()

// Check the permission.
e.Enforce("alice", "data1", "read")

// Modify the policy.
// e.AddPolicy(...)
// e.RemovePolicy(...)

// Save the policy back to DB.
e.SavePolicy()
}
```

## Filtered Policies

Expand Down
26 changes: 20 additions & 6 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import (
)

const defaultTimeout time.Duration = 30 * time.Second
const defaultDatabaseName string = "casbin"
const defaultCollectionName string = "casbin_rule"

// CasbinRule represents a rule in Casbin.
type CasbinRule struct {
Expand Down Expand Up @@ -60,6 +62,7 @@ func finalizer(a *adapter) {

// NewAdapter is the constructor for Adapter. If database name is not provided
// in the Mongo URL, 'casbin' will be used as database name.
// 'casbin_rule' will be used as a collection name.
func NewAdapter(url string, timeout ...interface{}) (persist.BatchAdapter, error) {
if !strings.HasPrefix(url, "mongodb+srv://") && !strings.HasPrefix(url, "mongodb://") {
url = fmt.Sprint("mongodb://" + url)
Expand All @@ -79,15 +82,26 @@ func NewAdapter(url string, timeout ...interface{}) (persist.BatchAdapter, error
if connString.Database != "" {
databaseName = connString.Database
} else {
databaseName = "casbin_rule"
databaseName = defaultDatabaseName
}

return NewAdapterWithClientOption(clientOption, databaseName, timeout...)
return baseNewAdapter(clientOption, databaseName, defaultCollectionName, timeout...)
}

// NewAdapterWithClientOption is an alternative constructor for Adapter
// that does the same as NewAdapter, but uses mongo.ClientOption instead of a Mongo URL
// that does the same as NewAdapter, but uses mongo.ClientOption instead of a Mongo URL + a databaseName option
func NewAdapterWithClientOption(clientOption *options.ClientOptions, databaseName string, timeout ...interface{}) (persist.BatchAdapter, error) {
return baseNewAdapter(clientOption, databaseName, defaultCollectionName, timeout...)
}

// NewAdapterWithCollectionName is an alternative constructor for Adapter
// that does the same as NewAdapterWithClientOption, but with an extra collectionName option
func NewAdapterWithCollectionName(clientOption *options.ClientOptions, databaseName string, collectionName string, timeout ...interface{}) (persist.BatchAdapter, error) {
return baseNewAdapter(clientOption, databaseName, collectionName, timeout...)
}

// baseNewAdapter is a base constructor for Adapter
func baseNewAdapter(clientOption *options.ClientOptions, databaseName string, collectionName string, timeout ...interface{}) (persist.BatchAdapter, error) {
a := &adapter{
clientOption: clientOption,
}
Expand All @@ -102,7 +116,7 @@ func NewAdapterWithClientOption(clientOption *options.ClientOptions, databaseNam
}

// Open the DB, create it if not existed.
err := a.open(databaseName)
err := a.open(databaseName, collectionName)
if err != nil {
return nil, err
}
Expand All @@ -125,7 +139,7 @@ func NewFilteredAdapter(url string) (persist.FilteredAdapter, error) {
return a.(*adapter), nil
}

func (a *adapter) open(databaseName string) error {
func (a *adapter) open(databaseName string, collectionName string) error {
ctx, cancel := context.WithTimeout(context.TODO(), a.timeout)
defer cancel()

Expand All @@ -135,7 +149,7 @@ func (a *adapter) open(databaseName string) error {
}

db := client.Database(databaseName)
collection := db.Collection("casbin_rule")
collection := db.Collection(collectionName)

a.client = client
a.collection = collection
Expand Down
61 changes: 45 additions & 16 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ package mongodbadapter
import (
"fmt"
"os"
"strings"
"testing"

"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/util"
"go.mongodb.org/mongo-driver/bson"
mongooptions "go.mongodb.org/mongo-driver/mongo/options"
)

var testDbURL = os.Getenv("TEST_MONGODB_URL")
Expand Down Expand Up @@ -104,7 +106,7 @@ func TestAdapter(t *testing.T) {
{"bob", "data2", "write"},
{"data2_admin", "data2", "read"},
{"data2_admin", "data2", "write"},
},
},
)
// AutoSave is enabled by default.
// Now we disable it.
Expand All @@ -122,7 +124,7 @@ func TestAdapter(t *testing.T) {
{"bob", "data2", "write"},
{"data2_admin", "data2", "read"},
{"data2_admin", "data2", "write"},
},
},
)

// Now we enable the AutoSave.
Expand All @@ -142,8 +144,8 @@ func TestAdapter(t *testing.T) {
{"data2_admin", "data2", "read"},
{"data2_admin", "data2", "write"},
{"alice", "data1", "write"},
},
)
},
)

// Remove the added rule.
e.RemovePolicy("alice", "data1", "write")
Expand All @@ -158,7 +160,7 @@ func TestAdapter(t *testing.T) {
{"bob", "data2", "write"},
{"data2_admin", "data2", "read"},
{"data2_admin", "data2", "write"},
},
},
)

// Remove "data2_admin" related policy rules via a filter.
Expand All @@ -170,7 +172,7 @@ func TestAdapter(t *testing.T) {
testGetPolicy(t, e, [][]string{
{"alice", "data1", "read"},
{"bob", "data2", "write"},
},
},
)

e.RemoveFilteredPolicy(1, "data1")
Expand Down Expand Up @@ -204,15 +206,15 @@ func TestAddPolicies(t *testing.T) {
{"bob", "data2", "write"},
{"data2_admin", "data2", "read"},
{"data2_admin", "data2", "write"},
},
},
)
a.AddPolicies("p","p",[][]string{
a.AddPolicies("p", "p", [][]string{
{"bob", "data2", "read"},
{"alice", "data2", "write"},
{"alice", "data2", "read"},
{"bob", "data1", "write"},
{"bob", "data1", "read"},
},
},
)

if err := e.LoadPolicy(); err != nil {
Expand All @@ -222,14 +224,14 @@ func TestAddPolicies(t *testing.T) {
testGetPolicy(t, e, [][]string{
{"alice", "data1", "read"},
{"bob", "data2", "write"},
{"data2_admin", "data2", "read"},
{"data2_admin", "data2", "read"},
{"data2_admin", "data2", "write"},
{"bob", "data2", "read"},
{"alice", "data2", "write"},
{"alice", "data2", "read"},
{"bob", "data1", "write"},
{"bob", "data1", "read"},
},
},
)

// Remove the added rule.
Expand All @@ -248,10 +250,10 @@ func TestAddPolicies(t *testing.T) {
testGetPolicy(t, e, [][]string{
{"alice", "data1", "read"},
{"bob", "data2", "write"},
{"data2_admin", "data2", "read"},
{"data2_admin", "data2", "read"},
{"data2_admin", "data2", "write"},
},
)
},
)
}

func TestDeleteFilteredAdapter(t *testing.T) {
Expand Down Expand Up @@ -322,7 +324,7 @@ func TestFilteredAdapter(t *testing.T) {
if err != nil {
panic(err)
}

// Load filtered policies from the database.
e.AddPolicy("alice", "data1", "write")
e.AddPolicy("bob", "data2", "write")
Expand All @@ -343,7 +345,7 @@ func TestFilteredAdapter(t *testing.T) {
testGetPolicy(t, e, [][]string{
{"alice", "data1", "read"},
{"alice", "data1", "write"},
},
},
)

// Test safe handling of SavePolicy when using filtered policies.
Expand Down Expand Up @@ -400,3 +402,30 @@ func TestNewAdapterWithDatabase(t *testing.T) {
panic(err)
}
}

func TestNewAdapterWithClientOption(t *testing.T) {
uri := getDbURL()
if !strings.HasPrefix(uri, "mongodb+srv://") && !strings.HasPrefix(uri, "mongodb://") {
uri = fmt.Sprint("mongodb://" + uri)
}
mongoClientOption := mongooptions.Client().ApplyURI(uri)
databaseName := "casbin_custom"
_, err := NewAdapterWithClientOption(mongoClientOption, databaseName)
if err != nil {
panic(err)
}
}

func TestNewAdapterWithCollectionName(t *testing.T) {
uri := getDbURL()
if !strings.HasPrefix(uri, "mongodb+srv://") && !strings.HasPrefix(uri, "mongodb://") {
uri = fmt.Sprint("mongodb://" + uri)
}
mongoClientOption := mongooptions.Client().ApplyURI(uri)
databaseName := "casbin_custom"
collectionName := "casbin_rule_custom"
_, err := NewAdapterWithCollectionName(mongoClientOption, databaseName, collectionName)
if err != nil {
panic(err)
}
}

0 comments on commit 889de6f

Please sign in to comment.