Skip to content

Commit

Permalink
store: Add transactions
Browse files Browse the repository at this point in the history
Transaction are needed because we have to leave the db in a clean state
if something happen while mutating data. Currently, we have only one
resource _source_ but, with the new multi source model, we'll need to
mutate two resources _agent_ and _source_ in the same operation.
This requires that the op is done in a transaction.

This commit adds transaction support into the store. The code is largly
based on the transaction implementation found in other RH service
(e.g. uhc-service-account) but is has been made simplier to fit our
needs.

The transactions are wrapped into a context and passed to store when
services are asking for _source_ repository (for now, we have only one.
Soon to be more). The store checks the context and if it finds a
transaction it will use it to create the source repository. Otherwise,
it uses its own db connection.

Methods for _commit_ and _rollback_ are added so a service can rollback
or commit a transaction if an error occurs.

Signed-off-by: Cosmin Tupangiu <[email protected]>
  • Loading branch information
tupyy authored and machacekondra committed Nov 6, 2024
1 parent 408e000 commit 289dc36
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 13 deletions.
26 changes: 17 additions & 9 deletions internal/store/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Source interface {
Get(ctx context.Context, id uuid.UUID) (*api.Source, error)
Delete(ctx context.Context, id uuid.UUID) error
Update(ctx context.Context, id uuid.UUID, status, statusInfo, credUrl *string, inventory *api.Inventory) (*api.Source, error)
InitialMigration() error
InitialMigration(context.Context) error
}

type SourceStore struct {
Expand All @@ -34,13 +34,13 @@ func NewSource(db *gorm.DB, log logrus.FieldLogger) Source {
return &SourceStore{db: db, log: log}
}

func (s *SourceStore) InitialMigration() error {
return s.db.AutoMigrate(&model.Source{})
func (s *SourceStore) InitialMigration(ctx context.Context) error {
return s.getDB(ctx).AutoMigrate(&model.Source{})
}

func (s *SourceStore) List(ctx context.Context) (*api.SourceList, error) {
var sources model.SourceList
result := s.db.Model(&sources).Order("id").Find(&sources)
result := s.getDB(ctx).Model(&sources).Order("id").Find(&sources)
if result.Error != nil {
return nil, result.Error
}
Expand All @@ -50,7 +50,7 @@ func (s *SourceStore) List(ctx context.Context) (*api.SourceList, error) {

func (s *SourceStore) Create(ctx context.Context, sourceCreate api.SourceCreate) (*api.Source, error) {
source := model.NewSourceFromApiCreateResource(&sourceCreate)
result := s.db.Create(source)
result := s.getDB(ctx).Create(source)
if result.Error != nil {
return nil, result.Error
}
Expand All @@ -59,13 +59,13 @@ func (s *SourceStore) Create(ctx context.Context, sourceCreate api.SourceCreate)
}

func (s *SourceStore) DeleteAll(ctx context.Context) error {
result := s.db.Unscoped().Exec("DELETE FROM sources")
result := s.getDB(ctx).Unscoped().Exec("DELETE FROM sources")
return result.Error
}

func (s *SourceStore) Get(ctx context.Context, id uuid.UUID) (*api.Source, error) {
source := model.NewSourceFromId(id)
result := s.db.First(&source)
result := s.getDB(ctx).First(&source)
if result.Error != nil {
return nil, result.Error
}
Expand All @@ -75,7 +75,7 @@ func (s *SourceStore) Get(ctx context.Context, id uuid.UUID) (*api.Source, error

func (s *SourceStore) Delete(ctx context.Context, id uuid.UUID) error {
source := model.NewSourceFromId(id)
result := s.db.Unscoped().Delete(&source)
result := s.getDB(ctx).Unscoped().Delete(&source)
if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) {
s.log.Infof("ERROR: %v", result.Error)
return result.Error
Expand Down Expand Up @@ -103,11 +103,19 @@ func (s *SourceStore) Update(ctx context.Context, id uuid.UUID, status, statusIn
selectFields = append(selectFields, "cred_url")
}

result := s.db.Model(source).Clauses(clause.Returning{}).Select(selectFields).Updates(&source)
result := s.getDB(ctx).Model(source).Clauses(clause.Returning{}).Select(selectFields).Updates(&source)
if result.Error != nil {
return nil, result.Error
}

apiSource := source.ToApiResource()
return &apiSource, nil
}

func (s *SourceStore) getDB(ctx context.Context) *gorm.DB {
tx := FromContext(ctx)
if tx != nil {
return tx
}
return s.db
}
25 changes: 21 additions & 4 deletions internal/store/store.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,54 @@
package store

import (
"context"

"github.com/sirupsen/logrus"
"gorm.io/gorm"
)

type Store interface {
NewTransactionContext(ctx context.Context) (context.Context, error)
Source() Source
InitialMigration() error
Close() error
}

type DataStore struct {
source Source
db *gorm.DB
source Source
log logrus.FieldLogger
}

func NewStore(db *gorm.DB, log logrus.FieldLogger) Store {
return &DataStore{
source: NewSource(db, log),
db: db,
log: log,
source: NewSource(db, log),
}
}

func (s *DataStore) NewTransactionContext(ctx context.Context) (context.Context, error) {
return newTransactionContext(ctx, s.db, s.log)
}

func (s *DataStore) Source() Source {
return s.source
}

func (s *DataStore) InitialMigration() error {
if err := s.Source().InitialMigration(); err != nil {
ctx, err := s.NewTransactionContext(context.Background())
if err != nil {
return err
}

if err := s.Source().InitialMigration(ctx); err != nil {
_, _ = Rollback(ctx)
return err
}
return nil

_, err = Commit(ctx)
return err
}

func (s *DataStore) Close() error {
Expand Down
13 changes: 13 additions & 0 deletions internal/store/store_suite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package store_test

import (
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

func TestStore(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Store Suite")
}
88 changes: 88 additions & 0 deletions internal/store/store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package store_test

import (
"context"

api "github.com/kubev2v/migration-planner/api/v1alpha1"
"github.com/kubev2v/migration-planner/internal/config"
st "github.com/kubev2v/migration-planner/internal/store"
"github.com/kubev2v/migration-planner/pkg/log"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gorm.io/gorm"
)

var _ = Describe("Store", Ordered, func() {
var (
store st.Store
gormDB *gorm.DB
)

BeforeAll(func() {
log := log.InitLogs()
cfg := config.NewDefault()
db, err := st.InitDB(cfg, log)
Expect(err).To(BeNil())
gormDB = db

store = st.NewStore(db, log.WithField("test", "store"))
Expect(store).ToNot(BeNil())

// migrate
err = store.InitialMigration()
Expect(err).To(BeNil())
})

AfterAll(func() {
gormDB.Exec("DROP TABLE sources;")
store.Close()
})

Context("transaction", func() {
It("insert a source successfully", func() {
ctx, err := store.NewTransactionContext(context.TODO())
Expect(err).To(BeNil())

source, err := store.Source().Create(ctx, api.SourceCreate{Name: "test", SshKey: "some key"})
Expect(source).ToNot(BeNil())
Expect(err).To(BeNil())

// commit
_, cerr := st.Commit(ctx)
Expect(cerr).To(BeNil())

count := 0
err = gormDB.Raw("SELECT COUNT(*) from sources;").Scan(&count).Error
Expect(err).To(BeNil())
Expect(count).To(Equal(1))
})

It("rollback a source successfully", func() {
ctx, err := store.NewTransactionContext(context.TODO())
Expect(err).To(BeNil())

source, err := store.Source().Create(ctx, api.SourceCreate{Name: "test", SshKey: "some key"})
Expect(source).ToNot(BeNil())
Expect(err).To(BeNil())

// count in the same transaction
sources, err := store.Source().List(ctx)
Expect(err).To(BeNil())
Expect(sources).NotTo(BeNil())
Expect(*sources).To(HaveLen(1))

// rollback
_, cerr := st.Rollback(ctx)
Expect(cerr).To(BeNil())

count := 0
err = gormDB.Raw("SELECT COUNT(*) from sources;").Scan(&count).Error
Expect(err).To(BeNil())
Expect(count).To(Equal(0))
})

AfterEach(func() {
gormDB.Exec("DELETE from sources;")
})
})
})
126 changes: 126 additions & 0 deletions internal/store/transaction.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package store

import (
"context"
"errors"

"github.com/sirupsen/logrus"
"gorm.io/gorm"
)

type contextKey int

const (
transactionKey contextKey = iota
)

type Tx struct {
txId int64
tx *gorm.DB
log logrus.FieldLogger
}

func Commit(ctx context.Context) (context.Context, error) {
tx, ok := ctx.Value(transactionKey).(*Tx)
if !ok {
return ctx, nil
}

newCtx := context.WithValue(ctx, transactionKey, nil)
return newCtx, tx.Commit()
}

func Rollback(ctx context.Context) (context.Context, error) {
tx, ok := ctx.Value(transactionKey).(*Tx)
if !ok {
return ctx, nil
}

newCtx := context.WithValue(ctx, transactionKey, nil)
return newCtx, tx.Rollback()
}

func FromContext(ctx context.Context) *gorm.DB {
if tx, found := ctx.Value(transactionKey).(*Tx); found {
if dbTx, err := tx.Db(); err == nil {
return dbTx
}
}
return nil
}

func newTransactionContext(ctx context.Context, db *gorm.DB, log logrus.FieldLogger) (context.Context, error) {
//look into the context to see if we have another tx
_, found := ctx.Value(transactionKey).(*Tx)
if found {
return ctx, nil
}

// create a new session
conn := db.Session(&gorm.Session{
Context: ctx,
})

tx, err := newTransaction(conn, log)
if err != nil {
return ctx, err
}

ctx = context.WithValue(ctx, transactionKey, tx)
return ctx, nil
}

func newTransaction(db *gorm.DB, log logrus.FieldLogger) (*Tx, error) {
// must call begin on 'db', which is Gorm.
tx := db.Begin()
if tx.Error != nil {
return nil, tx.Error
}

// current transaction ID set by postgres. these are *not* distinct across time
// and do get reset after postgres performs "vacuuming" to reclaim used IDs.
var txid struct{ ID int64 }
tx.Raw("select txid_current() as id").Scan(&txid)

return &Tx{
txId: txid.ID,
tx: tx,
log: log,
}, nil
}

func (t *Tx) Db() (*gorm.DB, error) {
if t.tx != nil {
return t.tx, nil
}
return nil, errors.New("transaction hasn't started yet")
}

func (t *Tx) Commit() error {
if t.tx == nil {
return errors.New("transaction hasn't started yet")
}

if err := t.tx.Commit().Error; err != nil {
t.log.Errorf("failed to commit transaction %d: %w", t.txId, err)
return err
}
t.log.Debugf("transaction %d commited", t.txId)
t.tx = nil // in case we call commit twice
return nil
}

func (t *Tx) Rollback() error {
if t.tx == nil {
return errors.New("transaction hasn't started yet")
}

if err := t.tx.Rollback().Error; err != nil {
t.log.Errorf("failed to rollback transaction %d: %w", t.txId, err)
return err
}
t.tx = nil // in case we call commit twice

t.log.Debugf("transaction %d rollback", t.txId)
return nil
}

0 comments on commit 289dc36

Please sign in to comment.