Skip to content

Commit

Permalink
feat: add api NewAdapterByDB (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
JalinWang authored Sep 15, 2022
1 parent 293a343 commit e114641
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 11 deletions.
53 changes: 42 additions & 11 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ type CasbinRule struct {

// adapter represents the MongoDB adapter for policy storage.
type adapter struct {
clientOption *options.ClientOptions
client *mongo.Client
collection *mongo.Collection
timeout time.Duration
filtered bool
client *mongo.Client
collection *mongo.Collection
timeout time.Duration
filtered bool
}

// finalizer is the destructor for adapter.
Expand Down Expand Up @@ -103,9 +102,7 @@ func NewAdapterWithCollectionName(clientOption *options.ClientOptions, databaseN

// baseNewAdapter is a base constructor for Adapter
func baseNewAdapter(clientOption *options.ClientOptions, databaseName string, collectionName string, timeout ...interface{}) (persist.BatchAdapter, error) {
a := &adapter{
clientOption: clientOption,
}
a := &adapter{}
a.filtered = false

if len(timeout) == 1 {
Expand All @@ -117,7 +114,7 @@ func baseNewAdapter(clientOption *options.ClientOptions, databaseName string, co
}

// Open the DB, create it if not existed.
err := a.open(databaseName, collectionName)
err := a.open(clientOption, databaseName, collectionName)
if err != nil {
return nil, err
}
Expand All @@ -140,11 +137,45 @@ func NewFilteredAdapter(url string) (persist.FilteredAdapter, error) {
return a.(*adapter), nil
}

func (a *adapter) open(databaseName string, collectionName string) error {
type AdapterConfig struct {
DatabaseName string
CollectionName string
Timeout time.Duration
IsFiltered bool
}

func NewAdapterByDB(client *mongo.Client, config *AdapterConfig) (persist.BatchAdapter, error) {
if config == nil {
config = &AdapterConfig{}
}
if config.CollectionName == "" {
config.CollectionName = defaultCollectionName
}
if config.DatabaseName == "" {
config.DatabaseName = defaultDatabaseName
}
if config.Timeout == 0 {
config.Timeout = defaultTimeout
}

a := &adapter{
client: client,
collection: client.Database(config.DatabaseName).Collection(config.CollectionName),
timeout: config.Timeout,
filtered: config.IsFiltered,
}

// Call the destructor when the object is released.
runtime.SetFinalizer(a, finalizer)

return a, nil
}

func (a *adapter) open(clientOption *options.ClientOptions, databaseName string, collectionName string) error {
ctx, cancel := context.WithTimeout(context.TODO(), a.timeout)
defer cancel()

client, err := mongo.Connect(ctx, a.clientOption)
client, err := mongo.Connect(ctx, clientOption)
if err != nil {
return err
}
Expand Down
23 changes: 23 additions & 0 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
package mongodbadapter

import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/mongo"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -481,6 +483,27 @@ func TestNewAdapterWithCollectionName(t *testing.T) {
}
}

func TestNewAdapterByDB(t *testing.T) {
uri := getDbURL()
if !strings.HasPrefix(uri, "mongodb+srv://") && !strings.HasPrefix(uri, "mongodb://") {
uri = fmt.Sprint("mongodb://" + uri)
}
mongoClientOption := mongooptions.Client().ApplyURI(uri)
client, err := mongo.Connect(context.Background(), mongoClientOption)
if err != nil {
panic(err)
}

config := AdapterConfig{
DatabaseName: "casbin_custom",
CollectionName: "casbin_rule_custom",
}
_, err = NewAdapterByDB(client, &config)
if err != nil {
panic(err)
}
}

func TestUpdatePolicy(t *testing.T) {
initPolicy(t, getDbURL())

Expand Down

0 comments on commit e114641

Please sign in to comment.