Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

store: Add transactions #48

Merged
merged 1 commit into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions internal/store/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Source interface {
Get(ctx context.Context, id uuid.UUID) (*api.Source, error)
Delete(ctx context.Context, id uuid.UUID) error
Update(ctx context.Context, id uuid.UUID, status, statusInfo, credUrl *string, inventory *api.Inventory) (*api.Source, error)
InitialMigration() error
InitialMigration(context.Context) error
}

type SourceStore struct {
Expand All @@ -34,13 +34,13 @@ func NewSource(db *gorm.DB, log logrus.FieldLogger) Source {
return &SourceStore{db: db, log: log}
}

func (s *SourceStore) InitialMigration() error {
return s.db.AutoMigrate(&model.Source{})
func (s *SourceStore) InitialMigration(ctx context.Context) error {
return s.getDB(ctx).AutoMigrate(&model.Source{})
}

func (s *SourceStore) List(ctx context.Context) (*api.SourceList, error) {
var sources model.SourceList
result := s.db.Model(&sources).Order("id").Find(&sources)
result := s.getDB(ctx).Model(&sources).Order("id").Find(&sources)
if result.Error != nil {
return nil, result.Error
}
Expand All @@ -50,7 +50,7 @@ func (s *SourceStore) List(ctx context.Context) (*api.SourceList, error) {

func (s *SourceStore) Create(ctx context.Context, sourceCreate api.SourceCreate) (*api.Source, error) {
source := model.NewSourceFromApiCreateResource(&sourceCreate)
result := s.db.Create(source)
result := s.getDB(ctx).Create(source)
if result.Error != nil {
return nil, result.Error
}
Expand All @@ -59,13 +59,13 @@ func (s *SourceStore) Create(ctx context.Context, sourceCreate api.SourceCreate)
}

func (s *SourceStore) DeleteAll(ctx context.Context) error {
result := s.db.Unscoped().Exec("DELETE FROM sources")
result := s.getDB(ctx).Unscoped().Exec("DELETE FROM sources")
return result.Error
}

func (s *SourceStore) Get(ctx context.Context, id uuid.UUID) (*api.Source, error) {
source := model.NewSourceFromId(id)
result := s.db.First(&source)
result := s.getDB(ctx).First(&source)
if result.Error != nil {
return nil, result.Error
}
Expand All @@ -75,7 +75,7 @@ func (s *SourceStore) Get(ctx context.Context, id uuid.UUID) (*api.Source, error

func (s *SourceStore) Delete(ctx context.Context, id uuid.UUID) error {
source := model.NewSourceFromId(id)
result := s.db.Unscoped().Delete(&source)
result := s.getDB(ctx).Unscoped().Delete(&source)
if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) {
s.log.Infof("ERROR: %v", result.Error)
return result.Error
Expand Down Expand Up @@ -103,11 +103,19 @@ func (s *SourceStore) Update(ctx context.Context, id uuid.UUID, status, statusIn
selectFields = append(selectFields, "cred_url")
}

result := s.db.Model(source).Clauses(clause.Returning{}).Select(selectFields).Updates(&source)
result := s.getDB(ctx).Model(source).Clauses(clause.Returning{}).Select(selectFields).Updates(&source)
if result.Error != nil {
return nil, result.Error
}

apiSource := source.ToApiResource()
return &apiSource, nil
}

func (s *SourceStore) getDB(ctx context.Context) *gorm.DB {
tx := FromContext(ctx)
if tx != nil {
return tx
}
return s.db
}
25 changes: 21 additions & 4 deletions internal/store/store.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,54 @@
package store

import (
"context"

"github.com/sirupsen/logrus"
"gorm.io/gorm"
)

type Store interface {
NewTransactionContext(ctx context.Context) (context.Context, error)
Source() Source
InitialMigration() error
Close() error
}

type DataStore struct {
source Source
db *gorm.DB
source Source
log logrus.FieldLogger
}

func NewStore(db *gorm.DB, log logrus.FieldLogger) Store {
return &DataStore{
source: NewSource(db, log),
db: db,
log: log,
source: NewSource(db, log),
}
}

func (s *DataStore) NewTransactionContext(ctx context.Context) (context.Context, error) {
return newTransactionContext(ctx, s.db, s.log)
}

func (s *DataStore) Source() Source {
return s.source
}

func (s *DataStore) InitialMigration() error {
if err := s.Source().InitialMigration(); err != nil {
ctx, err := s.NewTransactionContext(context.Background())
if err != nil {
return err
}

if err := s.Source().InitialMigration(ctx); err != nil {
_, _ = Rollback(ctx)
return err
}
return nil

_, err = Commit(ctx)
return err
}

func (s *DataStore) Close() error {
Expand Down
13 changes: 13 additions & 0 deletions internal/store/store_suite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package store_test

import (
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

func TestStore(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Store Suite")
}
88 changes: 88 additions & 0 deletions internal/store/store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package store_test

import (
"context"

api "github.com/kubev2v/migration-planner/api/v1alpha1"
"github.com/kubev2v/migration-planner/internal/config"
st "github.com/kubev2v/migration-planner/internal/store"
"github.com/kubev2v/migration-planner/pkg/log"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gorm.io/gorm"
)

var _ = Describe("Store", Ordered, func() {
var (
store st.Store
gormDB *gorm.DB
)

BeforeAll(func() {
log := log.InitLogs()
cfg := config.NewDefault()
db, err := st.InitDB(cfg, log)
Expect(err).To(BeNil())
gormDB = db

store = st.NewStore(db, log.WithField("test", "store"))
Expect(store).ToNot(BeNil())

// migrate
err = store.InitialMigration()
Expect(err).To(BeNil())
})

AfterAll(func() {
gormDB.Exec("DROP TABLE sources;")
store.Close()
})

Context("transaction", func() {
It("insert a source successfully", func() {
ctx, err := store.NewTransactionContext(context.TODO())
Expect(err).To(BeNil())

source, err := store.Source().Create(ctx, api.SourceCreate{Name: "test", SshKey: "some key"})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really need to add the validation to SSH key ;)

Expect(source).ToNot(BeNil())
Expect(err).To(BeNil())

// commit
_, cerr := st.Commit(ctx)
Expect(cerr).To(BeNil())

count := 0
err = gormDB.Raw("SELECT COUNT(*) from sources;").Scan(&count).Error
Expect(err).To(BeNil())
Expect(count).To(Equal(1))
})

It("rollback a source successfully", func() {
ctx, err := store.NewTransactionContext(context.TODO())
Expect(err).To(BeNil())

source, err := store.Source().Create(ctx, api.SourceCreate{Name: "test", SshKey: "some key"})
Expect(source).ToNot(BeNil())
Expect(err).To(BeNil())

// count in the same transaction
sources, err := store.Source().List(ctx)
Expect(err).To(BeNil())
Expect(sources).NotTo(BeNil())
Expect(*sources).To(HaveLen(1))

// rollback
_, cerr := st.Rollback(ctx)
Expect(cerr).To(BeNil())

count := 0
err = gormDB.Raw("SELECT COUNT(*) from sources;").Scan(&count).Error
Expect(err).To(BeNil())
Expect(count).To(Equal(0))
})

AfterEach(func() {
gormDB.Exec("DELETE from sources;")
})
})
})
126 changes: 126 additions & 0 deletions internal/store/transaction.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package store

import (
"context"
"errors"

"github.com/sirupsen/logrus"
"gorm.io/gorm"
)

type contextKey int

const (
transactionKey contextKey = iota
)

type Tx struct {
txId int64
tx *gorm.DB
log logrus.FieldLogger
}

func Commit(ctx context.Context) (context.Context, error) {
tx, ok := ctx.Value(transactionKey).(*Tx)
if !ok {
return ctx, nil
}

newCtx := context.WithValue(ctx, transactionKey, nil)
return newCtx, tx.Commit()
}

func Rollback(ctx context.Context) (context.Context, error) {
tx, ok := ctx.Value(transactionKey).(*Tx)
if !ok {
return ctx, nil
}

newCtx := context.WithValue(ctx, transactionKey, nil)
return newCtx, tx.Rollback()
}

func FromContext(ctx context.Context) *gorm.DB {
if tx, found := ctx.Value(transactionKey).(*Tx); found {
if dbTx, err := tx.Db(); err == nil {
return dbTx
}
}
return nil
}

func newTransactionContext(ctx context.Context, db *gorm.DB, log logrus.FieldLogger) (context.Context, error) {
//look into the context to see if we have another tx
_, found := ctx.Value(transactionKey).(*Tx)
if found {
return ctx, nil
}

// create a new session
conn := db.Session(&gorm.Session{
Context: ctx,
})

tx, err := newTransaction(conn, log)
if err != nil {
return ctx, err
}

ctx = context.WithValue(ctx, transactionKey, tx)
return ctx, nil
}

func newTransaction(db *gorm.DB, log logrus.FieldLogger) (*Tx, error) {
// must call begin on 'db', which is Gorm.
tx := db.Begin()
if tx.Error != nil {
return nil, tx.Error
}

// current transaction ID set by postgres. these are *not* distinct across time
// and do get reset after postgres performs "vacuuming" to reclaim used IDs.
var txid struct{ ID int64 }
tx.Raw("select txid_current() as id").Scan(&txid)

return &Tx{
txId: txid.ID,
tx: tx,
log: log,
}, nil
}

func (t *Tx) Db() (*gorm.DB, error) {
if t.tx != nil {
return t.tx, nil
}
return nil, errors.New("transaction hasn't started yet")
}

func (t *Tx) Commit() error {
if t.tx == nil {
return errors.New("transaction hasn't started yet")
}

if err := t.tx.Commit().Error; err != nil {
t.log.Errorf("failed to commit transaction %d: %w", t.txId, err)
return err
}
t.log.Debugf("transaction %d commited", t.txId)
t.tx = nil // in case we call commit twice
return nil
}

func (t *Tx) Rollback() error {
if t.tx == nil {
return errors.New("transaction hasn't started yet")
}

if err := t.tx.Rollback().Error; err != nil {
t.log.Errorf("failed to rollback transaction %d: %w", t.txId, err)
return err
}
t.tx = nil // in case we call commit twice

t.log.Debugf("transaction %d rollback", t.txId)
return nil
}
Loading