Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented periodic ping to keep connection of transaction alive #24

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions expression_ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,13 @@ func (db *DB) UpdateFields(fields ...string) *DB {
return db.clone().Set("gorm:save_associations", false).Set("gorm:association_save_reference", false).Update(sets)
}

// UpdateFieldsWithoutHooks updates the specified fields of the current model without calling any
// Update hooks and without touching the UpdatedAt column (if any exists).
// The specified fields have to be the names of the struct variables.
func (db *DB) UpdateFieldsWithoutHooks(fields ...string) *DB {
return db.clone().Set("gorm:update_column", true).UpdateFields(fields...)
}

func (db *DB) SelectFields(fields ...string) *DB {
selects := strings.Join(fields, ", ")

Expand Down
209 changes: 156 additions & 53 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorm

import (
"context"
"database/sql"
"errors"
"fmt"
Expand Down Expand Up @@ -32,15 +33,17 @@ type DB struct {

// Open initialize a new db connection, need to import driver first, e.g:
//
// import _ "github.com/go-sql-driver/mysql"
// func main() {
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
// }
// import _ "github.com/go-sql-driver/mysql"
// func main() {
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
// }
//
// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with
// import _ "github.com/jinzhu/gorm/dialects/mysql"
// // import _ "github.com/jinzhu/gorm/dialects/postgres"
// // import _ "github.com/jinzhu/gorm/dialects/sqlite"
// // import _ "github.com/jinzhu/gorm/dialects/mssql"
//
// import _ "github.com/jinzhu/gorm/dialects/mysql"
// // import _ "github.com/jinzhu/gorm/dialects/postgres"
// // import _ "github.com/jinzhu/gorm/dialects/sqlite"
// // import _ "github.com/jinzhu/gorm/dialects/mssql"
func Open(dialect string, args ...interface{}) (db *DB, err error) {
if len(args) == 0 {
err = errors.New("invalid database source")
Expand Down Expand Up @@ -121,7 +124,9 @@ func (s *DB) Dialect() Dialect {
}

// Callback return `Callbacks` container, you could add/change/delete callbacks with it
// db.Callback().Create().Register("update_created_at", updateCreated)
//
// db.Callback().Create().Register("update_created_at", updateCreated)
//
// Refer https://jinzhu.github.io/gorm/development.html#callbacks
func (s *DB) Callback() *Callback {
s.parent.callbacks = s.parent.callbacks.clone()
Expand Down Expand Up @@ -224,9 +229,10 @@ func (s *DB) Offset(offset interface{}) *DB {
}

// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
// db.Order("name DESC")
// db.Order("name DESC", true) // reorder
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
//
// db.Order("name DESC")
// db.Order("name DESC", true) // reorder
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
func (s *DB) Order(value interface{}, reorder ...bool) *DB {
return s.clone().search.Order(value, reorder...).db
}
Expand All @@ -253,23 +259,26 @@ func (s *DB) Having(query interface{}, values ...interface{}) *DB {
}

// Joins specify Joins conditions
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "[email protected]").Find(&user)
//
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "[email protected]").Find(&user)
func (s *DB) Joins(query interface{}, args ...interface{}) *DB {
return s.clone().search.Joins(query, args...).db
}

// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
//
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
//
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
// Refer https://jinzhu.github.io/gorm/crud.html#scopes
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
for _, f := range funcs {
Expand Down Expand Up @@ -356,8 +365,9 @@ func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
}

// Pluck used to query single column from a model as a map
// var ages []int64
// db.Find(&users).Pluck("age", &ages)
//
// var ages []int64
// db.Find(&users).Pluck("age", &ages)
func (s *DB) Pluck(column string, value interface{}) *DB {
return s.NewScope(s.Value).pluck(column, value).db
}
Expand Down Expand Up @@ -454,7 +464,8 @@ func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
}

// Raw use raw sql as conditions, won't run it unless invoked by other methods
// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
//
// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
func (s *DB) Raw(sql string, values ...interface{}) *DB {
return s.clone().search.Raw(true).Where(sql, values...).db
}
Expand All @@ -469,10 +480,11 @@ func (s *DB) Exec(sql string, values ...interface{}) *DB {
}

// Model specify the model you would like to run db operations
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
// db.Model(&user).Update("name", "hello")
//
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
// db.Model(&user).Update("name", "hello")
func (s *DB) Model(value interface{}) *DB {
c := s.clone()
c.Value = value
Expand Down Expand Up @@ -505,6 +517,30 @@ func (s *DB) Begin() *DB {
return c
}

func (s *DB) BeginFancy() (*DB, *sql.Conn) {
var conn *sql.Conn
var err error

c := s.clone()

if db, ok := c.db.(sqlDb); ok && db != nil {
conn, err = c.DB().Conn(context.Background())
if err != nil {
c.AddError(err)

return c, nil
}

tx, err := conn.BeginTx(context.Background(), nil)
c.db = interface{}(tx).(SQLCommon)
c.AddError(err)
} else {
c.AddError(ErrCantStartTransaction)
}

return c, conn
}

// Commit commit a transaction
func (s *DB) Commit() *DB {
if db, ok := s.db.(sqlTx); ok && db != nil {
Expand All @@ -530,32 +566,96 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) {
if _, ok := s.db.(*sql.Tx); ok {
// Already in a transaction
return f(s)
} else {
// Lets start a new transaction
tx := s.Begin()
if err = tx.Error; err != nil {
return
}

// sConn, err := s.DB().Conn(context.Background())
// if err != nil {
// return fmt.Errorf("Could not get database connection: %w", err)
// }

// defer sConn.Close()

// s.db

// Lets start a new transaction
tx, conn := s.BeginFancy()
if err = tx.Error; err != nil {
return err
}

// Create a channel to stop the ping goroutine.
stopTxPing := make(chan bool)
// Get the database connection for the transaction.
// txConn, err := tx.DB().Conn(context.Background())
// if err != nil {
// return fmt.Errorf("Could not get database connection for transaction: %w", err)
// }

// tx.DB().Ping()

// txBlub := tx.db.(*sql.Tx)
// txBlub.

// Start a goroutine that pings the database connection for a keep-alive.
go func() {
for {
select {
// Stop the goroutine when the stop channel receives a value ..
case <-stopTxPing:
return
// .. otherwise ping the database connection every 10 seconds.
case <-time.After(10 * time.Second):
// if s != nil && s.DB() != nil {
err := conn.PingContext(context.Background())
if err != nil {
tx.AddError(
fmt.Errorf(
"Could not ping database connection for transaction: %w",
err,
),
)

return
}
// }
}
}
panicked := true
defer func() {
if panicked || err != nil {
rollbackErr := tx.Rollback().Error
if rollbackErr != nil {
if err == nil {
err = rollbackErr
} else {
err = fmt.Errorf("Transacton code and rollback failed: %s; %s", err, rollbackErr)
}
}()

panicked := true

defer func() {
if panicked || err != nil {
rollbackErr := tx.Rollback().Error
if rollbackErr != nil {
if err == nil {
err = rollbackErr
} else {
err = fmt.Errorf("Transaction code and rollback failed: %s; %s", err, rollbackErr)
}
}
}()
err = f(tx)
if err == nil {
err = tx.Commit().Error
}
panicked = false
return

if conn != nil {
err = conn.Close()
}
}()

err = f(tx)

// As soon as the inner stack has returned, stop the ping goroutine. As the transaction will be
// only committed after this point, the ping would fail and the goroutine will exit.
stopTxPing <- true
// Last but not least, close the stop ping channel.
close(stopTxPing)

if err == nil {
err = tx.Commit().Error
}

panicked = false

return err
}

// SkipAssocSave disables saving of associations
Expand Down Expand Up @@ -674,15 +774,17 @@ func (s *DB) RemoveIndex(indexName string) *DB {
}

// AddForeignKey Add foreign key to the given scope, e.g:
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
//
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
scope := s.NewScope(s.Value)
scope.addForeignKey(field, dest, onDelete, onUpdate)
return scope.db
}

// RemoveForeignKey Remove foreign key from the given scope, e.g:
// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
//
// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
func (s *DB) RemoveForeignKey(field string, dest string) *DB {
scope := s.clone().NewScope(s.Value)
scope.removeForeignKey(field, dest)
Expand Down Expand Up @@ -712,7 +814,8 @@ func (s *DB) Association(column string) *Association {
}

// Preload preload associations with given conditions
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
//
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
return s.clone().search.Preload(column, conditions...).db
}
Expand Down