diff --git a/adapter.go b/adapter.go index 5fd34d3..c928a16 100755 --- a/adapter.go +++ b/adapter.go @@ -26,9 +26,11 @@ import ( "github.com/jackc/pgconn" "gorm.io/driver/mysql" "gorm.io/driver/postgres" + //"gorm.io/driver/sqlite" "gorm.io/driver/sqlserver" "gorm.io/gorm" + "gorm.io/gorm/logger" ) const ( @@ -278,6 +280,11 @@ func (a *Adapter) open() error { return a.createTable() } +// AddLogger adds logger to db +func (a *Adapter) AddLogger(l logger.Interface) { + a.db = a.db.Session(&gorm.Session{Logger: l, Context: a.db.Statement.Context}) +} + func (a *Adapter) close() error { a.db = nil return nil diff --git a/adapter_test.go b/adapter_test.go index f847308..f2692ab 100755 --- a/adapter_test.go +++ b/adapter_test.go @@ -16,6 +16,7 @@ package gormadapter import ( "log" + "os" "testing" "github.com/casbin/casbin/v2" @@ -27,6 +28,7 @@ import ( "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/gorm" + "gorm.io/gorm/logger" ) func testGetPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) { @@ -428,7 +430,19 @@ func TestAdapters(t *testing.T) { testUpdatePolicies(t, a) testUpdateFilteredPolicies(t, a) + a = initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/", "casbin", "casbin_rule") + a.AddLogger(logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{})) + testUpdatePolicy(t, a) + testUpdatePolicies(t, a) + testUpdateFilteredPolicies(t, a) + + a = initAdapter(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable") + testUpdatePolicy(t, a) + testUpdatePolicies(t, a) + testUpdateFilteredPolicies(t, a) + a = initAdapter(t, "postgres", "user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable") + a.AddLogger(logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{})) testUpdatePolicy(t, a) testUpdatePolicies(t, a) testUpdateFilteredPolicies(t, a)