From 2ee511ecf3055d13c7194183916efd3388fb1563 Mon Sep 17 00:00:00 2001 From: Genki Sugawara Date: Tue, 26 Nov 2024 20:01:48 +0900 Subject: [PATCH 1/2] Support driver.Connector --- db.go | 24 ++++++++++++++++++++++++ db_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/db.go b/db.go index 142ceee..c0f85ef 100644 --- a/db.go +++ b/db.go @@ -110,6 +110,30 @@ type txDriver struct { dsn string } +type txConnector struct { + driver *txDriver + dsn string +} + +func (c *txConnector) Driver() driver.Driver { + return c.driver +} + +func (c *txConnector) Connect(ctx context.Context) (driver.Conn, error) { + return c.driver.Open(c.dsn) +} + +func NewConnector(srcDrv, srcDsn, dsn string) driver.Connector { + return &txConnector{ + driver: &txDriver{ + dsn: srcDsn, + drv: srcDrv, + conns: make(map[string]*conn), + }, + dsn: dsn, + } +} + func (d *txDriver) Open(dsn string) (driver.Conn, error) { d.Lock() defer d.Unlock() diff --git a/db_test.go b/db_test.go index fe84eae..e3ea04f 100644 --- a/db_test.go +++ b/db_test.go @@ -51,6 +51,36 @@ func TestShouldRunWithinTransaction(t *testing.T) { } } +func TestShouldRunWithinTransactionForOpenDB(t *testing.T) { + t.Parallel() + var count int + db1 := sql.OpenDB(pgtxdb.NewConnector("pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable", "one")) + defer db1.Close() + + _, err := db1.Exec(`INSERT INTO app_user(username, email) VALUES('txdb', 'txdb@test.com')`) + if err != nil { + t.Fatalf("failed to insert an app_user: %s", err) + } + err = db1.QueryRow("SELECT COUNT(id) FROM app_user").Scan(&count) + if err != nil { + t.Fatalf("failed to count users: %s", err) + } + if count != 1 { + t.Fatalf("expected 1 user to be in database, but got %d", count) + } + + db2 := sql.OpenDB(pgtxdb.NewConnector("pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable", "two")) + defer db2.Close() + + err = db2.QueryRow("SELECT COUNT(id) FROM app_user").Scan(&count) + if err != nil { + t.Fatalf("failed to count app_user: %s", err) + } + if count != 0 { + t.Fatalf("expected 0 user to be in database, but got %d", count) + } +} + func TestShouldNotHoldConnectionForRows(t *testing.T) { t.Parallel() db, err := sql.Open("pgtxdb", "three") From 9004e4678201369c9f346b1546c976a18104edea Mon Sep 17 00:00:00 2001 From: Genki Sugawara Date: Tue, 26 Nov 2024 20:24:35 +0900 Subject: [PATCH 2/2] Fix arg order --- db.go | 2 +- db_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/db.go b/db.go index c0f85ef..d5efea2 100644 --- a/db.go +++ b/db.go @@ -123,7 +123,7 @@ func (c *txConnector) Connect(ctx context.Context) (driver.Conn, error) { return c.driver.Open(c.dsn) } -func NewConnector(srcDrv, srcDsn, dsn string) driver.Connector { +func NewConnector(dsn, srcDrv, srcDsn string) driver.Connector { return &txConnector{ driver: &txDriver{ dsn: srcDsn, diff --git a/db_test.go b/db_test.go index e3ea04f..5f58b59 100644 --- a/db_test.go +++ b/db_test.go @@ -54,7 +54,7 @@ func TestShouldRunWithinTransaction(t *testing.T) { func TestShouldRunWithinTransactionForOpenDB(t *testing.T) { t.Parallel() var count int - db1 := sql.OpenDB(pgtxdb.NewConnector("pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable", "one")) + db1 := sql.OpenDB(pgtxdb.NewConnector("one", "pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable")) defer db1.Close() _, err := db1.Exec(`INSERT INTO app_user(username, email) VALUES('txdb', 'txdb@test.com')`) @@ -69,7 +69,7 @@ func TestShouldRunWithinTransactionForOpenDB(t *testing.T) { t.Fatalf("expected 1 user to be in database, but got %d", count) } - db2 := sql.OpenDB(pgtxdb.NewConnector("pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable", "two")) + db2 := sql.OpenDB(pgtxdb.NewConnector("two", "pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable")) defer db2.Close() err = db2.QueryRow("SELECT COUNT(id) FROM app_user").Scan(&count)