Skip to content

Commit

Permalink
Revert "feat: transaction support powerful more"
Browse files Browse the repository at this point in the history
This reverts commit cf1908f.
  • Loading branch information
lqs committed Jul 25, 2024
1 parent 1c7e60d commit 7d5de3b
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 427 deletions.
23 changes: 0 additions & 23 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,29 +140,6 @@ func commaOrderBys(scope scope, orderBys []OrderBy) (string, error) {
}

func getCallerInfo(db database, retry bool) string {
if !db.enableCallerInfo {
return ""
}
extraInfo := ""
if retry {
extraInfo += " (retry)"
}
for i := 0; true; i++ {
_, file, line, ok := runtime.Caller(i)
if !ok {
break
}
if file == "" || strings.Contains(file, "/sqlingo@v") {
continue
}
segs := strings.Split(file, "/")
name := segs[len(segs)-1]
return fmt.Sprintf("/* %s:%d%s */ ", name, line, extraInfo)
}
return ""
}

func getTxCallerInfo(db transaction, retry bool) string {
if !db.enableCallerInfo {
return ""
}
Expand Down
24 changes: 15 additions & 9 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,11 @@ type Database interface {
Update(table Table) updateWithSet
// Initiate a DELETE FROM statement
DeleteFrom(table Table) deleteWithTable
}

// Begin Start a new transaction and returning a Transaction object.
// the DDL operations using the returned Transaction object will
// regard as one time transaction.
// User must manually call Commit() or Rollback() to end the transaction,
// after that, more DDL operations or TCL will return error.
Begin() (Transaction, error)
type txOrDB interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

var (
Expand All @@ -74,6 +72,7 @@ var (

type database struct {
db *sql.DB
tx *sql.Tx
logger LoggerFunc
dialect dialect
retryPolicy func(error) bool
Expand Down Expand Up @@ -186,6 +185,13 @@ func (d database) GetDB() *sql.DB {
return d.db
}

func (d database) getTxOrDB() txOrDB {
if d.tx != nil {
return d.tx
}
return d.db
}

func (d database) Query(sqlString string) (Cursor, error) {
return d.QueryContext(context.Background(), sqlString)
}
Expand All @@ -196,7 +202,7 @@ func (d database) QueryContext(ctx context.Context, sqlString string) (Cursor, e
sqlStringWithCallerInfo := getCallerInfo(d, isRetry) + sqlString
rows, err := d.queryContextOnce(ctx, sqlStringWithCallerInfo, isRetry)
if err != nil {
isRetry = d.retryPolicy != nil && d.retryPolicy(err)
isRetry = d.tx == nil && d.retryPolicy != nil && d.retryPolicy(err)
if isRetry {
continue
}
Expand All @@ -221,7 +227,7 @@ func (d database) queryContextOnce(ctx context.Context, sqlString string, retry
interceptor := d.interceptor
var rows *sql.Rows
invoker := func(ctx context.Context, sql string) (err error) {
rows, err = d.GetDB().QueryContext(ctx, sql)
rows, err = d.getTxOrDB().QueryContext(ctx, sql)
return
}

Expand Down Expand Up @@ -258,7 +264,7 @@ func (d database) ExecuteContext(ctx context.Context, sqlString string) (sql.Res

var result sql.Result
invoker := func(ctx context.Context, sql string) (err error) {
result, err = d.GetDB().ExecContext(ctx, sql)
result, err = d.getTxOrDB().ExecContext(ctx, sql)
return
}
var err error
Expand Down
8 changes: 3 additions & 5 deletions expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,9 @@ func (e expression) GetTable() Table {
}

type scope struct {
// Transaction should be nil if without transaction begin
Transaction *transaction
Database *database
Tables []Table
lastJoin *join
Database *database
Tables []Table
lastJoin *join
}

func staticExpression(sql string, priority priority, isBool bool) expression {
Expand Down
4 changes: 1 addition & 3 deletions field.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ func newField(table Table, fieldName string) actualField {
expression: expression{
builder: func(scope scope) (string, error) {
dialect := dialectUnknown
if scope.Transaction != nil {
dialect = scope.Transaction.dialect
} else if scope.Database != nil {
if scope.Database != nil {
dialect = scope.Database.dialect
}
if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName {
Expand Down
10 changes: 2 additions & 8 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,17 +610,11 @@ func (s selectStatus) FetchCursor() (Cursor, error) {
return nil, err
}

var c Cursor
if s.base.scope.Transaction != nil {
c, err = s.base.scope.Transaction.QueryContext(s.ctx, sqlString)
} else {
c, err = s.base.scope.Database.QueryContext(s.ctx, sqlString)
}

cursor, err := s.base.scope.Database.QueryContext(s.ctx, sqlString)
if err != nil {
return nil, err
}
return c, nil
return cursor, nil
}

func (s selectStatus) FetchFirst(dest ...interface{}) (ok bool, err error) {
Expand Down
3 changes: 0 additions & 3 deletions table.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ func (t table) GetName() string {
}

func (t table) GetSQL(scope scope) string {
if scope.Transaction != nil {
return t.sqlDialects[scope.Transaction.dialect]
}
return t.sqlDialects[scope.Database.dialect]
}

Expand Down
Loading

0 comments on commit 7d5de3b

Please sign in to comment.