Skip to content

Commit

Permalink
Merge pull request #25 from nrfta/feat/transaction-savepoints
Browse files Browse the repository at this point in the history
transaction savepoints helper functions
  • Loading branch information
strobus authored Feb 17, 2021
2 parents 0d3c973 + 5e04717 commit 43fcbb2
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 18 deletions.
39 changes: 21 additions & 18 deletions tests/integration/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,9 @@ import (
pgh "github.com/neighborly/go-pghelpers"
)

func getEnv(key string, fallback string) string {
value := os.Getenv(key)
if len(value) == 0 {
return fallback
}
return value
}

var _ = Describe("Connection Test", func() {
var port, _ = strconv.Atoi(getEnv("POSTGRES_PORT", "5432"))

var (
testConfig = pgh.PostgresConfig{
Host: getEnv("POSTGRES_HOST", "localhost"),
Port: port,
Username: getEnv("POSTGRES_USERNAME", "postgres"),
Password: getEnv("POSTGRES_PASSWORD", ""),
Database: getEnv("POSTGRES_DATABASE", "postgres"),
SSLEnabled: false,
}
testConfig = getTestConfig()
)

It("should connect to a database", func() {
Expand All @@ -39,3 +22,23 @@ var _ = Describe("Connection Test", func() {
Expect(db.Ping()).To(Succeed())
})
})

func getTestConfig() pgh.PostgresConfig {
var port, _ = strconv.Atoi(getEnv("POSTGRES_PORT", "5432"))
return pgh.PostgresConfig{
Host: getEnv("POSTGRES_HOST", "localhost"),
Port: port,
Username: getEnv("POSTGRES_USERNAME", "postgres"),
Password: getEnv("POSTGRES_PASSWORD", ""),
Database: getEnv("POSTGRES_DATABASE", "postgres"),
SSLEnabled: false,
}
}

func getEnv(key string, fallback string) string {
value := os.Getenv(key)
if len(value) == 0 {
return fallback
}
return value
}
150 changes: 150 additions & 0 deletions tests/integration/tx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package pghelpers_test

import (
"database/sql"

pgh "github.com/neighborly/go-pghelpers"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

var _ = Describe("Tx", func() {
var (
txDB *sql.DB
)
BeforeEach(func() {
db, err := pgh.ConnectPostgres(getTestConfig())
Expect(err).To(BeNil())
txDB = db
_, err = txDB.Exec("DROP TABLE IF EXISTS tx_test")
Expect(err).To(BeNil())
_, err = txDB.Exec("CREATE TABLE tx_test (value int)")
Expect(err).To(BeNil())
})
Context("ExecInTx", func() {
It("should commit tx", func() {
var txErr error
err := pgh.ExecInTx(txDB, func(tx *sql.Tx) bool {
_, txErr = tx.Exec("INSERT INTO tx_test (value) VALUES (1)")
return txErr == nil
})
Expect(err).To(BeNil())
Expect(txErr).To(BeNil())

row := txDB.QueryRow("SELECT value from tx_test")
var result int
err = row.Scan(&result)
Expect(err).To(BeNil())
Expect(result).To(Equal(1))
})

It("should rollback tx", func() {
var txErr error
err := pgh.ExecInTx(txDB, func(tx *sql.Tx) bool {
_, txErr = tx.Exec("INSERT INTO tx_test (value) VALUES (1)")
return false
})
Expect(err).To(BeNil())
Expect(txErr).To(BeNil())

row := txDB.QueryRow("SELECT count(*) from tx_test")
var result int
err = row.Scan(&result)
Expect(err).To(BeNil())
Expect(result).To(Equal(0))
})
})

Context("Savepoints", func() {
It("should commit release savepoint", func() {
var txErr error
err := pgh.ExecInTx(txDB, func(tx *sql.Tx) bool {
_, txErr = tx.Exec("INSERT INTO tx_test (value) VALUES (1)")
if txErr != nil {
return false
}
txErr = pgh.SetSavepoint("test", tx)
if txErr != nil {
return false
}
_, txErr = tx.Exec("INSERT INTO tx_test (value) VALUES (2)")
if txErr != nil {
return false
}
txErr = pgh.ReleaseSavepoint("test", tx)
return txErr == nil
})
Expect(err).To(BeNil())
Expect(txErr).To(BeNil())

expectSavepointRows(2, txDB)
})
It("should rollback savepoint", func() {
var txErr error
err := pgh.ExecInTx(txDB, func(tx *sql.Tx) bool {
_, txErr = tx.Exec("INSERT INTO tx_test (value) VALUES (1)")
if txErr != nil {
return false
}
txErr = pgh.SetSavepoint("test", tx)
if txErr != nil {
return false
}
_, txErr = tx.Exec("INSERT INTO tx_test (value) VALUES (2)")
if txErr != nil {
return false
}
txErr = pgh.RollbackToSavepoint("test", tx)
return txErr == nil
})
Expect(err).To(BeNil())
Expect(txErr).To(BeNil())

expectSavepointRows(1, txDB)
})
It("should rollback second savepoint", func() {
var txErr error
err := pgh.ExecInTx(txDB, func(tx *sql.Tx) bool {
_, txErr = tx.Exec("INSERT INTO tx_test (value) VALUES (1)")
if txErr != nil {
return false
}
txErr = pgh.SetSavepoint("test1", tx)
if txErr != nil {
return false
}
_, txErr = tx.Exec("INSERT INTO tx_test (value) VALUES (2)")
if txErr != nil {
return false
}
txErr = pgh.SetSavepoint("test2", tx)
if txErr != nil {
return false
}
_, txErr = tx.Exec("INSERT INTO tx_test (value) VALUES (3)")
if txErr != nil {
return false
}
txErr = pgh.RollbackToSavepoint("test2", tx)
return txErr == nil
})
Expect(err).To(BeNil())
Expect(txErr).To(BeNil())

expectSavepointRows(2, txDB)
})
})
})

func expectSavepointRows(numRows int, txDB *sql.DB) {
rows, err := txDB.Query("SELECT value from tx_test")
var result int
expectedValue := 1
for rows.Next() {
err = rows.Scan(&result)
Expect(err).To(BeNil())
Expect(result).To(Equal(expectedValue))
expectedValue++
}
Expect(expectedValue).To(Equal(numRows + 1))
}
20 changes: 20 additions & 0 deletions tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pghelpers
import (
"context"
"database/sql"
"fmt"
)

// ExecInTxFunc defines a function type for the ExecInTx function argument.
Expand All @@ -22,3 +23,22 @@ func ExecInTx(db *sql.DB, fn ExecInTxFunc) error {
}
return tx.Rollback()
}

// SetSavepoint sets a named savepoint in the current transaction.
func SetSavepoint(name string, tx *sql.Tx) error {
_, err := tx.Exec(fmt.Sprintf("SAVEPOINT %s", name))
return err
}

// ReleaseSavepoint releases a named savepoint previously set in the transaction. This allows the commands
// executed after the savepoint to be committed.
func ReleaseSavepoint(name string, tx *sql.Tx) error {
_, err := tx.Exec(fmt.Sprintf("RELEASE SAVEPOINT %s", name))
return err
}

// RollbackToSavepoint rolls back the transaction to the named savepoint.
func RollbackToSavepoint(name string, tx *sql.Tx) error {
_, err := tx.Exec(fmt.Sprintf("ROLLBACK TO SAVEPOINT %s", name))
return err
}

0 comments on commit 43fcbb2

Please sign in to comment.