diff --git a/adapter.go b/adapter.go index 8d9585e..6d656a5 100644 --- a/adapter.go +++ b/adapter.go @@ -49,11 +49,10 @@ type CasbinRule struct { // adapter represents the MongoDB adapter for policy storage. type adapter struct { - clientOption *options.ClientOptions - client *mongo.Client - collection *mongo.Collection - timeout time.Duration - filtered bool + client *mongo.Client + collection *mongo.Collection + timeout time.Duration + filtered bool } // finalizer is the destructor for adapter. @@ -103,9 +102,7 @@ func NewAdapterWithCollectionName(clientOption *options.ClientOptions, databaseN // baseNewAdapter is a base constructor for Adapter func baseNewAdapter(clientOption *options.ClientOptions, databaseName string, collectionName string, timeout ...interface{}) (persist.BatchAdapter, error) { - a := &adapter{ - clientOption: clientOption, - } + a := &adapter{} a.filtered = false if len(timeout) == 1 { @@ -117,7 +114,7 @@ func baseNewAdapter(clientOption *options.ClientOptions, databaseName string, co } // Open the DB, create it if not existed. - err := a.open(databaseName, collectionName) + err := a.open(clientOption, databaseName, collectionName) if err != nil { return nil, err } @@ -140,11 +137,45 @@ func NewFilteredAdapter(url string) (persist.FilteredAdapter, error) { return a.(*adapter), nil } -func (a *adapter) open(databaseName string, collectionName string) error { +type AdapterConfig struct { + DatabaseName string + CollectionName string + Timeout time.Duration + IsFiltered bool +} + +func NewAdapterByDB(client *mongo.Client, config *AdapterConfig) (persist.BatchAdapter, error) { + if config == nil { + config = &AdapterConfig{} + } + if config.CollectionName == "" { + config.CollectionName = defaultCollectionName + } + if config.DatabaseName == "" { + config.DatabaseName = defaultDatabaseName + } + if config.Timeout == 0 { + config.Timeout = defaultTimeout + } + + a := &adapter{ + client: client, + collection: client.Database(config.DatabaseName).Collection(config.CollectionName), + timeout: config.Timeout, + filtered: config.IsFiltered, + } + + // Call the destructor when the object is released. + runtime.SetFinalizer(a, finalizer) + + return a, nil +} + +func (a *adapter) open(clientOption *options.ClientOptions, databaseName string, collectionName string) error { ctx, cancel := context.WithTimeout(context.TODO(), a.timeout) defer cancel() - client, err := mongo.Connect(ctx, a.clientOption) + client, err := mongo.Connect(ctx, clientOption) if err != nil { return err } diff --git a/adapter_test.go b/adapter_test.go index fd5636f..e8cd6af 100644 --- a/adapter_test.go +++ b/adapter_test.go @@ -15,7 +15,9 @@ package mongodbadapter import ( + "context" "fmt" + "go.mongodb.org/mongo-driver/mongo" "os" "strings" "testing" @@ -481,6 +483,27 @@ func TestNewAdapterWithCollectionName(t *testing.T) { } } +func TestNewAdapterByDB(t *testing.T) { + uri := getDbURL() + if !strings.HasPrefix(uri, "mongodb+srv://") && !strings.HasPrefix(uri, "mongodb://") { + uri = fmt.Sprint("mongodb://" + uri) + } + mongoClientOption := mongooptions.Client().ApplyURI(uri) + client, err := mongo.Connect(context.Background(), mongoClientOption) + if err != nil { + panic(err) + } + + config := AdapterConfig{ + DatabaseName: "casbin_custom", + CollectionName: "casbin_rule_custom", + } + _, err = NewAdapterByDB(client, &config) + if err != nil { + panic(err) + } +} + func TestUpdatePolicy(t *testing.T) { initPolicy(t, getDbURL())