From 3ed6021c05117551cc79fe58a99020b699045526 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Tue, 4 Jun 2024 18:21:36 +1000 Subject: [PATCH] refactor: add support for nested transactions (#1622) --- backend/controller/dal/dal.go | 8 ++------ backend/controller/sql/conn.go | 28 ++++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/backend/controller/dal/dal.go b/backend/controller/dal/dal.go index e6170ee0f2..cf2126952b 100644 --- a/backend/controller/dal/dal.go +++ b/backend/controller/dal/dal.go @@ -282,16 +282,12 @@ func (t *Tx) Rollback(ctx context.Context) error { } func (d *DAL) Begin(ctx context.Context) (*Tx, error) { - db, ok := d.db.(*sql.DB) - if !ok { - return nil, fmt.Errorf("can't nest transactions") - } - stx, err := db.Begin(ctx) + tx, err := d.db.Begin(ctx) if err != nil { return nil, translatePGError(err) } return &Tx{&DAL{ - db: stx, + db: tx, DeploymentChanges: d.DeploymentChanges, }}, nil } diff --git a/backend/controller/sql/conn.go b/backend/controller/sql/conn.go index d5ad3f01ba..9f9af6c16f 100644 --- a/backend/controller/sql/conn.go +++ b/backend/controller/sql/conn.go @@ -2,6 +2,7 @@ package sql import ( "context" + "fmt" "github.com/jackc/pgx/v5" ) @@ -37,22 +38,41 @@ func (d *DB) Begin(ctx context.Context) (*Tx, error) { } type Tx struct { - tx pgx.Tx + tx pgx.Tx + savepoints []string *Queries } func (t *Tx) Conn() ConnI { return t.tx } func (t *Tx) Begin(ctx context.Context) (*Tx, error) { - panic("recursive transactions are not supported") + savepoint := fmt.Sprintf("savepoint_%d", len(t.savepoints)) + t.savepoints = append(t.savepoints, savepoint) + _, err := t.tx.Exec(ctx, `SAVEPOINT `+savepoint) + if err != nil { + return nil, err + } + return &Tx{tx: t.tx, savepoints: t.savepoints, Queries: t.Queries}, nil } func (t *Tx) Commit(ctx context.Context) error { - return t.tx.Commit(ctx) + if len(t.savepoints) == 0 { + return t.tx.Commit(ctx) + } + savepoint := t.savepoints[len(t.savepoints)-1] + t.savepoints = t.savepoints[:len(t.savepoints)-1] + _, err := t.tx.Exec(ctx, `RELEASE SAVEPOINT `+savepoint) + return err } func (t *Tx) Rollback(ctx context.Context) error { - return t.tx.Rollback(ctx) + if len(t.savepoints) == 0 { + return t.tx.Rollback(ctx) + } + savepoint := t.savepoints[len(t.savepoints)-1] + t.savepoints = t.savepoints[:len(t.savepoints)-1] + _, err := t.tx.Exec(ctx, `ROLLBACK TO SAVEPOINT `+savepoint) + return err } // CommitOrRollback can be used in a defer statement to commit or rollback a