Skip to content

Commit

Permalink
Merge pull request #6 from BrandonRoehl/single-transaction
Browse files Browse the repository at this point in the history
Single transaction
  • Loading branch information
BrandonRoehl authored Oct 1, 2019
2 parents cd198d0 + 5c7d255 commit f25e26f
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 89 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
language: go

go:
- 1.11.x
- 1.13.x
- master

script:
Expand Down
54 changes: 39 additions & 15 deletions dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mysqldump

import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
Expand All @@ -14,9 +15,11 @@ import (
/*
Data struct to configure dump behavior
Out: Stream to wite to
Connection: Database connection to dump
IgnoreTables: Mark sensitive tables to ignore
Out: Stream to wite to
Connection: Database connection to dump
IgnoreTables: Mark sensitive tables to ignore
MaxAllowedPacket: Sets the largest packet size to use in backups
LockTables: Lock all tables for the duration of the dump
*/
type Data struct {
Out io.Writer
Expand All @@ -25,6 +28,7 @@ type Data struct {
MaxAllowedPacket int
LockTables bool

tx *sql.Tx
headerTmpl *template.Template
tableTmpl *template.Template
footerTmpl *template.Template
Expand Down Expand Up @@ -123,11 +127,19 @@ func (data *Data) Dump() error {
data.MaxAllowedPacket = defaultMaxAllowedPacket
}

if err := meta.updateServerVersion(data.Connection); err != nil {
if err := data.getTemplates(); err != nil {
return err
}

if err := data.getTemplates(); err != nil {
// Start the read only transaction and defer the rollback until the end
// This way the database will have the exact state it did at the begining of
// the backup and nothing can be accidentally committed
if err := data.begin(); err != nil {
return err
}
defer data.rollback()

if err := meta.updateServerVersion(data); err != nil {
return err
}

Expand Down Expand Up @@ -173,6 +185,21 @@ func (data *Data) Dump() error {

// MARK: - Private methods

// begin starts a read only transaction that will be whatever the database was
// when it was called
func (data *Data) begin() (err error) {
data.tx, err = data.Connection.BeginTx(context.Background(), &sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
ReadOnly: true,
})
return
}

// rollback cancels the transaction
func (data *Data) rollback() error {
return data.tx.Rollback()
}

// MARK: writter methods

func (data *Data) dumpTable(name string) error {
Expand Down Expand Up @@ -214,7 +241,7 @@ func (data *Data) getTemplates() (err error) {
func (data *Data) getTables() ([]string, error) {
tables := make([]string, 0)

rows, err := data.Connection.Query("SHOW TABLES")
rows, err := data.tx.Query("SHOW TABLES")
if err != nil {
return tables, err
}
Expand All @@ -241,10 +268,10 @@ func (data *Data) isIgnoredTable(name string) bool {
return false
}

func (data *metaData) updateServerVersion(db *sql.DB) (err error) {
func (meta *metaData) updateServerVersion(data *Data) (err error) {
var serverVersion sql.NullString
err = db.QueryRow("SELECT version()").Scan(&serverVersion)
data.ServerVersion = serverVersion.String
err = data.tx.QueryRow("SELECT version()").Scan(&serverVersion)
meta.ServerVersion = serverVersion.String
return
}

Expand All @@ -263,7 +290,7 @@ func (table *table) NameEsc() string {

func (table *table) CreateSQL() (string, error) {
var tableReturn, tableSQL sql.NullString
if err := table.data.Connection.QueryRow("SHOW CREATE TABLE "+table.NameEsc()).Scan(&tableReturn, &tableSQL); err != nil {
if err := table.data.tx.QueryRow("SHOW CREATE TABLE "+table.NameEsc()).Scan(&tableReturn, &tableSQL); err != nil {
return "", err
}

Expand All @@ -274,13 +301,12 @@ func (table *table) CreateSQL() (string, error) {
return tableSQL.String, nil
}

// defer rows.Close()
func (table *table) Init() (err error) {
if len(table.types) != 0 {
return errors.New("can't init twice")
}

table.rows, err = table.data.Connection.Query("SELECT * FROM " + table.NameEsc())
table.rows, err = table.data.tx.Query("SELECT * FROM " + table.NameEsc())
if err != nil {
return err
}
Expand All @@ -299,6 +325,7 @@ func (table *table) Init() (err error) {
}

table.types = make([]reflect.Type, len(tt))
table.values = make([]interface{}, len(tt))
for i, tp := range tt {
st := tp.ScanType()
if tp.DatabaseTypeName() == "BLOB" {
Expand All @@ -312,9 +339,6 @@ func (table *table) Init() (err error) {
} else {
table.types[i] = reflect.TypeOf(sql.NullString{})
}
}
table.values = make([]interface{}, len(tt))
for i := range table.values {
table.values[i] = reflect.New(table.types[i]).Interface()
}
return nil
Expand Down
Loading

0 comments on commit f25e26f

Please sign in to comment.