From 7d5de3b8fe58ad69633c97e960d60ba392c3af4d Mon Sep 17 00:00:00 2001 From: Qishuai Liu Date: Thu, 25 Jul 2024 18:10:54 +0900 Subject: [PATCH] Revert "feat: transaction support powerful more" This reverts commit cf1908f1e53280f181d588b741a47895d88a0535. --- common.go | 23 ----- database.go | 24 +++-- expression.go | 8 +- field.go | 4 +- select.go | 10 +- table.go | 3 - transaction.go | 242 ++------------------------------------------ transaction_test.go | 143 +------------------------- 8 files changed, 30 insertions(+), 427 deletions(-) diff --git a/common.go b/common.go index ff26294..6dc593e 100644 --- a/common.go +++ b/common.go @@ -140,29 +140,6 @@ func commaOrderBys(scope scope, orderBys []OrderBy) (string, error) { } func getCallerInfo(db database, retry bool) string { - if !db.enableCallerInfo { - return "" - } - extraInfo := "" - if retry { - extraInfo += " (retry)" - } - for i := 0; true; i++ { - _, file, line, ok := runtime.Caller(i) - if !ok { - break - } - if file == "" || strings.Contains(file, "/sqlingo@v") { - continue - } - segs := strings.Split(file, "/") - name := segs[len(segs)-1] - return fmt.Sprintf("/* %s:%d%s */ ", name, line, extraInfo) - } - return "" -} - -func getTxCallerInfo(db transaction, retry bool) string { if !db.enableCallerInfo { return "" } diff --git a/database.go b/database.go index 4477c1b..65d7bc5 100644 --- a/database.go +++ b/database.go @@ -58,13 +58,11 @@ type Database interface { Update(table Table) updateWithSet // Initiate a DELETE FROM statement DeleteFrom(table Table) deleteWithTable +} - // Begin Start a new transaction and returning a Transaction object. - // the DDL operations using the returned Transaction object will - // regard as one time transaction. - // User must manually call Commit() or Rollback() to end the transaction, - // after that, more DDL operations or TCL will return error. - Begin() (Transaction, error) +type txOrDB interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) } var ( @@ -74,6 +72,7 @@ var ( type database struct { db *sql.DB + tx *sql.Tx logger LoggerFunc dialect dialect retryPolicy func(error) bool @@ -186,6 +185,13 @@ func (d database) GetDB() *sql.DB { return d.db } +func (d database) getTxOrDB() txOrDB { + if d.tx != nil { + return d.tx + } + return d.db +} + func (d database) Query(sqlString string) (Cursor, error) { return d.QueryContext(context.Background(), sqlString) } @@ -196,7 +202,7 @@ func (d database) QueryContext(ctx context.Context, sqlString string) (Cursor, e sqlStringWithCallerInfo := getCallerInfo(d, isRetry) + sqlString rows, err := d.queryContextOnce(ctx, sqlStringWithCallerInfo, isRetry) if err != nil { - isRetry = d.retryPolicy != nil && d.retryPolicy(err) + isRetry = d.tx == nil && d.retryPolicy != nil && d.retryPolicy(err) if isRetry { continue } @@ -221,7 +227,7 @@ func (d database) queryContextOnce(ctx context.Context, sqlString string, retry interceptor := d.interceptor var rows *sql.Rows invoker := func(ctx context.Context, sql string) (err error) { - rows, err = d.GetDB().QueryContext(ctx, sql) + rows, err = d.getTxOrDB().QueryContext(ctx, sql) return } @@ -258,7 +264,7 @@ func (d database) ExecuteContext(ctx context.Context, sqlString string) (sql.Res var result sql.Result invoker := func(ctx context.Context, sql string) (err error) { - result, err = d.GetDB().ExecContext(ctx, sql) + result, err = d.getTxOrDB().ExecContext(ctx, sql) return } var err error diff --git a/expression.go b/expression.go index 2f6e8bd..74d5d57 100644 --- a/expression.go +++ b/expression.go @@ -148,11 +148,9 @@ func (e expression) GetTable() Table { } type scope struct { - // Transaction should be nil if without transaction begin - Transaction *transaction - Database *database - Tables []Table - lastJoin *join + Database *database + Tables []Table + lastJoin *join } func staticExpression(sql string, priority priority, isBool bool) expression { diff --git a/field.go b/field.go index a20551f..cbe5674 100644 --- a/field.go +++ b/field.go @@ -61,9 +61,7 @@ func newField(table Table, fieldName string) actualField { expression: expression{ builder: func(scope scope) (string, error) { dialect := dialectUnknown - if scope.Transaction != nil { - dialect = scope.Transaction.dialect - } else if scope.Database != nil { + if scope.Database != nil { dialect = scope.Database.dialect } if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName { diff --git a/select.go b/select.go index 5414334..097c356 100644 --- a/select.go +++ b/select.go @@ -610,17 +610,11 @@ func (s selectStatus) FetchCursor() (Cursor, error) { return nil, err } - var c Cursor - if s.base.scope.Transaction != nil { - c, err = s.base.scope.Transaction.QueryContext(s.ctx, sqlString) - } else { - c, err = s.base.scope.Database.QueryContext(s.ctx, sqlString) - } - + cursor, err := s.base.scope.Database.QueryContext(s.ctx, sqlString) if err != nil { return nil, err } - return c, nil + return cursor, nil } func (s selectStatus) FetchFirst(dest ...interface{}) (ok bool, err error) { diff --git a/table.go b/table.go index 09afdf9..dc363f3 100644 --- a/table.go +++ b/table.go @@ -24,9 +24,6 @@ func (t table) GetName() string { } func (t table) GetSQL(scope scope) string { - if scope.Transaction != nil { - return t.sqlDialects[scope.Transaction.dialect] - } return t.sqlDialects[scope.Database.dialect] } diff --git a/transaction.go b/transaction.go index a87a7ee..2ccd0a2 100644 --- a/transaction.go +++ b/transaction.go @@ -3,12 +3,12 @@ package sqlingo import ( "context" "database/sql" - "time" ) // Transaction is the interface of a transaction with underlying sql.Tx object. // It provides methods to execute DDL and TCL operations. type Transaction interface { + GetDB() *sql.DB GetTx() *sql.Tx Query(sql string) (Cursor, error) Execute(sql string) (sql.Result, error) @@ -19,13 +19,10 @@ type Transaction interface { InsertInto(table Table) insertWithTable Update(table Table) updateWithSet DeleteFrom(table Table) deleteWithTable - ReplaceInto(table Table) insertWithTable - // ReplaceInto(table Table) insertWithTable - Commit() error - Rollback() error - Savepoint(name string) error - RollbackTo(name string) error - ReleaseSavepoint(name string) error +} + +func (d *database) GetTx() *sql.Tx { + return d.tx } func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx Transaction) error) error { @@ -44,16 +41,8 @@ func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx T }() if f != nil { - db := transaction{ - tx: tx, - logger: d.logger, - dialect: d.dialect, - retryPolicy: d.retryPolicy, - enableCallerInfo: d.enableCallerInfo, - interceptor: d.interceptor, - } + db := *d db.tx = tx - err = f(&db) if err != nil { return err @@ -67,222 +56,3 @@ func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx T isCommitted = true return nil } - -// Begin starts a new transaction and returning a Transaction object. -// the DDL operations using the returned Transaction object will -// regard as one time transaction. -// User must manually call Commit() or Rollback() to end the transaction, -// after that, more DDL operations or TCL will return error. -func (d *database) Begin() (Transaction, error) { - var err error - tx, err := d.db.Begin() - if err != nil { - return nil, err - } - // copy extra to transaction - t := &transaction{ - tx: tx, - logger: d.logger, - dialect: d.dialect, - retryPolicy: d.retryPolicy, - enableCallerInfo: d.enableCallerInfo, - interceptor: d.interceptor, - } - return t, nil -} - -type transaction struct { - tx *sql.Tx - logger LoggerFunc - dialect dialect - retryPolicy func(error) bool - enableCallerInfo bool - interceptor InterceptorFunc -} - -func (t transaction) GetTx() *sql.Tx { - return t.tx -} - -func (t transaction) Query(sql string) (Cursor, error) { - return t.QueryContext(context.Background(), sql) -} - -func (t transaction) QueryContext(ctx context.Context, sqlString string) (Cursor, error) { - isRetry := false - for { - sqlStringWithCallerInfo := getTxCallerInfo(t, isRetry) + sqlString - - rows, err := t.queryContextOnce(ctx, sqlStringWithCallerInfo) - if err != nil { - isRetry = t.tx == nil && t.retryPolicy != nil && t.retryPolicy(err) - if isRetry { - continue - } - return nil, err - } - return cursor{rows: rows}, nil - } -} - -func (t transaction) queryContextOnce(ctx context.Context, sqlStringWithCallerInfo string) (*sql.Rows, error) { - if ctx == nil { - ctx = context.Background() - } - startTime := time.Now() - defer func() { - endTime := time.Now() - if t.logger != nil { - t.logger(sqlStringWithCallerInfo, endTime.Sub(startTime), true, false) - } - }() - - interceptor := t.interceptor - var rows *sql.Rows - invoker := func(ctx context.Context, sql string) (err error) { - rows, err = t.GetTx().QueryContext(ctx, sql) - return - } - - var err error - if interceptor == nil { - err = invoker(ctx, sqlStringWithCallerInfo) - } else { - err = interceptor(ctx, sqlStringWithCallerInfo, invoker) - } - if err != nil { - return nil, err - } - return rows, nil -} - -func (t transaction) Execute(sql string) (sql.Result, error) { - return t.ExecuteContext(context.Background(), sql) -} - -func (t transaction) ExecuteContext(ctx context.Context, sqlString string) (sql.Result, error) { - if ctx == nil { - ctx = context.Background() - } - sqlStringWithCallerInfo := getTxCallerInfo(t, false) + sqlString - startTime := time.Now() - defer func() { - endTime := time.Now() - if t.logger != nil { - t.logger(sqlStringWithCallerInfo, endTime.Sub(startTime), true, false) - } - }() - - var result sql.Result - invoker := func(ctx context.Context, sql string) (err error) { - result, err = t.GetTx().ExecContext(ctx, sql) - return - } - var err error - if t.interceptor == nil { - err = invoker(ctx, sqlStringWithCallerInfo) - } else { - err = t.interceptor(ctx, sqlStringWithCallerInfo, invoker) - } - if err != nil { - return nil, err - } - - return result, err -} - -func (t transaction) Select(fields ...interface{}) selectWithFields { - return selectStatus{ - base: selectBase{ - scope: scope{ - Transaction: &t, - }, - fields: getFields(fields), - }, - } -} - -func (t transaction) SelectDistinct(fields ...interface{}) selectWithFields { - return selectStatus{ - base: selectBase{ - scope: scope{ - Transaction: &t, - }, - fields: getFields(fields), - distinct: true, - }, - } -} - -func (t transaction) SelectFrom(tables ...Table) selectWithTables { - return selectStatus{ - base: selectBase{ - scope: scope{ - Transaction: &t, - Tables: tables, - }, - }, - } -} - -func (t transaction) InsertInto(table Table) insertWithTable { - return insertStatus{ - scope: scope{ - Transaction: &t, - Tables: []Table{table}, - }, - } -} - -func (t transaction) Update(table Table) updateWithSet { - return updateStatus{ - scope: scope{ - Transaction: &t, - Tables: []Table{table}}, - } -} - -func (t transaction) DeleteFrom(table Table) deleteWithTable { - return deleteStatus{ - scope: scope{ - Transaction: &t, - Tables: []Table{table}, - }, - } -} - -func (t transaction) ReplaceInto(table Table) insertWithTable { - return insertStatus{ - method: "REPLACE", - scope: scope{ - Transaction: &t, - Tables: []Table{table}, - }, - } -} - -func (t transaction) Commit() error { - return t.GetTx().Commit() -} - -func (t transaction) Rollback() error { - return t.GetTx().Rollback() -} - -// Savepoint todo defend sql injection -func (t transaction) Savepoint(name string) error { - _, err := t.GetTx().Exec("SAVEPOINT " + name) - return err -} - -// RollbackTo todo defend sql injection -func (t transaction) RollbackTo(name string) error { - _, err := t.GetTx().Exec("ROLLBACK TO " + name) - return err -} - -// ReleaseSavepoint todo defend sql injection -func (t transaction) ReleaseSavepoint(name string) error { - _, err := t.GetTx().Exec("RELEASE SAVEPOINT " + name) - return err -} diff --git a/transaction_test.go b/transaction_test.go index c1711da..8e5b64c 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -2,7 +2,6 @@ package sqlingo import ( "context" - "database/sql" "errors" "testing" ) @@ -29,6 +28,9 @@ func (m *mockTx) Rollback() error { func TestTransaction(t *testing.T) { db := newMockDatabase() err := db.BeginTx(nil, nil, func(tx Transaction) error { + if tx.GetDB() != db.GetDB() { + t.Error() + } if tx.GetTx() == nil { t.Error() } @@ -79,142 +81,3 @@ func TestTransaction(t *testing.T) { t.Error("should get error here") } } - -func TestTransaction_Commit(t *testing.T) { - db := newMockDatabase() - tx, err := db.Begin() - if err != nil { - t.Error(err) - } - - if err = tx.Commit(); err != nil { - t.Error(err) - } - - if !sharedMockConn.mockTx.isCommitted { - t.Error() - } -} - -func TestTransaction_Rollback(t *testing.T) { - db := newMockDatabase() - tx, err := db.Begin() - if err != nil { - t.Error(err) - } - - if err = tx.Rollback(); err != nil { - t.Error(err) - } - if !sharedMockConn.mockTx.isRolledBack { - t.Error() - } -} - -func TestTransaction_Done(t *testing.T) { - db := newMockDatabase() - tx, err := db.Begin() - if err != nil { - t.Error(err) - } - - if err = tx.Commit(); err != nil { - t.Error(err) - } - - if err = tx.Rollback(); !errors.Is(err, sql.ErrTxDone) { - t.Error(err) - } - - if err = tx.Commit(); !errors.Is(err, sql.ErrTxDone) { - t.Error(err) - } - - if _, err = tx.Select(1).FetchAll(); !errors.Is(err, sql.ErrTxDone) { - t.Error(err) - } -} - -func TestTransaction_Execute(t *testing.T) { - var sqlCount = make(map[string]int) - db := newMockDatabase() - - tx, err := db.Begin() - if err != nil { - t.Error(err) - } - db.SetInterceptor(func(ctx context.Context, sql string, invoker InvokerFunc) error { - sqlCount[sql]++ - return invoker(ctx, sql) - }) - - if _, err = tx.Execute("SQL 1 NOT SET INTERCEPTOR"); err != nil { - t.Error(err) - } - if sqlCount["SQL 1 NOT SET INTERCEPTOR"] != 0 { - t.Error() - } - - if err = tx.Rollback(); err != nil { - t.Error(err) - } - - tx, err = db.Begin() - if err != nil { - t.Error(err) - } - if _, err = tx.Execute("SQL 2 SET INTERCEPTOR"); err != nil { - t.Error(err) - } - if sqlCount["SQL 2 SET INTERCEPTOR"] != 1 { - t.Error() - } - - if err = tx.Commit(); err != nil { - t.Error(err) - } -} - -// TestTransaction_CRUD tests the CRUD operations in a transaction, cause sql build is tested on database, -// so we only insure there is no panic here. -func TestTransaction_CRUD(t *testing.T) { - db := newMockDatabase() - db.EnableCallerInfo(true) - tx, err := db.Begin() - if err != nil { - t.Error(err) - } - _, err = tx.Select().From(table1).FetchAll() - if err != nil { - t.Error(err) - } - - if _, err = tx.SelectFrom(table1).FetchAll(); err != nil { - t.Error(err) - } - - if _, err = tx.SelectDistinct(field2).From(table1).FetchAll(); err != nil { - t.Error(err) - } - - if _, err = tx.InsertInto(Test).Values(1, 2).Execute(); err != nil { - t.Error(err) - } - - if _, err = tx.ReplaceInto(Test).Values(1, 2).Execute(); err != nil { - t.Error(err) - } - - if _, err = tx.DeleteFrom(table1).Where().Execute(); err != nil { - t.Error(err) - } - - if _, err = tx.Update(table1).Set(field1, 1).Where().Execute(); err != nil { - t.Error(err) - } - - if err = tx.Rollback(); err != nil { - t.Error(err) - } - -}