diff --git a/internal/store/source.go b/internal/store/source.go index fea624d..87b7349 100644 --- a/internal/store/source.go +++ b/internal/store/source.go @@ -19,7 +19,7 @@ type Source interface { Get(ctx context.Context, id uuid.UUID) (*api.Source, error) Delete(ctx context.Context, id uuid.UUID) error Update(ctx context.Context, id uuid.UUID, status, statusInfo, credUrl *string, inventory *api.Inventory) (*api.Source, error) - InitialMigration() error + InitialMigration(context.Context) error } type SourceStore struct { @@ -34,13 +34,13 @@ func NewSource(db *gorm.DB, log logrus.FieldLogger) Source { return &SourceStore{db: db, log: log} } -func (s *SourceStore) InitialMigration() error { - return s.db.AutoMigrate(&model.Source{}) +func (s *SourceStore) InitialMigration(ctx context.Context) error { + return s.getDB(ctx).AutoMigrate(&model.Source{}) } func (s *SourceStore) List(ctx context.Context) (*api.SourceList, error) { var sources model.SourceList - result := s.db.Model(&sources).Order("id").Find(&sources) + result := s.getDB(ctx).Model(&sources).Order("id").Find(&sources) if result.Error != nil { return nil, result.Error } @@ -50,7 +50,7 @@ func (s *SourceStore) List(ctx context.Context) (*api.SourceList, error) { func (s *SourceStore) Create(ctx context.Context, sourceCreate api.SourceCreate) (*api.Source, error) { source := model.NewSourceFromApiCreateResource(&sourceCreate) - result := s.db.Create(source) + result := s.getDB(ctx).Create(source) if result.Error != nil { return nil, result.Error } @@ -59,13 +59,13 @@ func (s *SourceStore) Create(ctx context.Context, sourceCreate api.SourceCreate) } func (s *SourceStore) DeleteAll(ctx context.Context) error { - result := s.db.Unscoped().Exec("DELETE FROM sources") + result := s.getDB(ctx).Unscoped().Exec("DELETE FROM sources") return result.Error } func (s *SourceStore) Get(ctx context.Context, id uuid.UUID) (*api.Source, error) { source := model.NewSourceFromId(id) - result := s.db.First(&source) + result := s.getDB(ctx).First(&source) if result.Error != nil { return nil, result.Error } @@ -75,7 +75,7 @@ func (s *SourceStore) Get(ctx context.Context, id uuid.UUID) (*api.Source, error func (s *SourceStore) Delete(ctx context.Context, id uuid.UUID) error { source := model.NewSourceFromId(id) - result := s.db.Unscoped().Delete(&source) + result := s.getDB(ctx).Unscoped().Delete(&source) if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) { s.log.Infof("ERROR: %v", result.Error) return result.Error @@ -103,7 +103,7 @@ func (s *SourceStore) Update(ctx context.Context, id uuid.UUID, status, statusIn selectFields = append(selectFields, "cred_url") } - result := s.db.Model(source).Clauses(clause.Returning{}).Select(selectFields).Updates(&source) + result := s.getDB(ctx).Model(source).Clauses(clause.Returning{}).Select(selectFields).Updates(&source) if result.Error != nil { return nil, result.Error } @@ -111,3 +111,11 @@ func (s *SourceStore) Update(ctx context.Context, id uuid.UUID, status, statusIn apiSource := source.ToApiResource() return &apiSource, nil } + +func (s *SourceStore) getDB(ctx context.Context) *gorm.DB { + tx := FromContext(ctx) + if tx != nil { + return tx + } + return s.db +} diff --git a/internal/store/store.go b/internal/store/store.go index 3310f4c..ab98a93 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1,37 +1,54 @@ package store import ( + "context" + "github.com/sirupsen/logrus" "gorm.io/gorm" ) type Store interface { + NewTransactionContext(ctx context.Context) (context.Context, error) Source() Source InitialMigration() error Close() error } type DataStore struct { - source Source db *gorm.DB + source Source + log logrus.FieldLogger } func NewStore(db *gorm.DB, log logrus.FieldLogger) Store { return &DataStore{ - source: NewSource(db, log), db: db, + log: log, + source: NewSource(db, log), } } +func (s *DataStore) NewTransactionContext(ctx context.Context) (context.Context, error) { + return newTransactionContext(ctx, s.db, s.log) +} + func (s *DataStore) Source() Source { return s.source } func (s *DataStore) InitialMigration() error { - if err := s.Source().InitialMigration(); err != nil { + ctx, err := s.NewTransactionContext(context.Background()) + if err != nil { + return err + } + + if err := s.Source().InitialMigration(ctx); err != nil { + _, _ = Rollback(ctx) return err } - return nil + + _, err = Commit(ctx) + return err } func (s *DataStore) Close() error { diff --git a/internal/store/store_suite_test.go b/internal/store/store_suite_test.go new file mode 100644 index 0000000..e79bb60 --- /dev/null +++ b/internal/store/store_suite_test.go @@ -0,0 +1,13 @@ +package store_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestStore(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Store Suite") +} diff --git a/internal/store/store_test.go b/internal/store/store_test.go new file mode 100644 index 0000000..86c8e40 --- /dev/null +++ b/internal/store/store_test.go @@ -0,0 +1,88 @@ +package store_test + +import ( + "context" + + api "github.com/kubev2v/migration-planner/api/v1alpha1" + "github.com/kubev2v/migration-planner/internal/config" + st "github.com/kubev2v/migration-planner/internal/store" + "github.com/kubev2v/migration-planner/pkg/log" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +var _ = Describe("Store", Ordered, func() { + var ( + store st.Store + gormDB *gorm.DB + ) + + BeforeAll(func() { + log := log.InitLogs() + cfg := config.NewDefault() + db, err := st.InitDB(cfg, log) + Expect(err).To(BeNil()) + gormDB = db + + store = st.NewStore(db, log.WithField("test", "store")) + Expect(store).ToNot(BeNil()) + + // migrate + err = store.InitialMigration() + Expect(err).To(BeNil()) + }) + + AfterAll(func() { + gormDB.Exec("DROP TABLE sources;") + store.Close() + }) + + Context("transaction", func() { + It("insert a source successfully", func() { + ctx, err := store.NewTransactionContext(context.TODO()) + Expect(err).To(BeNil()) + + source, err := store.Source().Create(ctx, api.SourceCreate{Name: "test", SshKey: "some key"}) + Expect(source).ToNot(BeNil()) + Expect(err).To(BeNil()) + + // commit + _, cerr := st.Commit(ctx) + Expect(cerr).To(BeNil()) + + count := 0 + err = gormDB.Raw("SELECT COUNT(*) from sources;").Scan(&count).Error + Expect(err).To(BeNil()) + Expect(count).To(Equal(1)) + }) + + It("rollback a source successfully", func() { + ctx, err := store.NewTransactionContext(context.TODO()) + Expect(err).To(BeNil()) + + source, err := store.Source().Create(ctx, api.SourceCreate{Name: "test", SshKey: "some key"}) + Expect(source).ToNot(BeNil()) + Expect(err).To(BeNil()) + + // count in the same transaction + sources, err := store.Source().List(ctx) + Expect(err).To(BeNil()) + Expect(sources).NotTo(BeNil()) + Expect(*sources).To(HaveLen(1)) + + // rollback + _, cerr := st.Rollback(ctx) + Expect(cerr).To(BeNil()) + + count := 0 + err = gormDB.Raw("SELECT COUNT(*) from sources;").Scan(&count).Error + Expect(err).To(BeNil()) + Expect(count).To(Equal(0)) + }) + + AfterEach(func() { + gormDB.Exec("DELETE from sources;") + }) + }) +}) diff --git a/internal/store/transaction.go b/internal/store/transaction.go new file mode 100644 index 0000000..859fdc3 --- /dev/null +++ b/internal/store/transaction.go @@ -0,0 +1,126 @@ +package store + +import ( + "context" + "errors" + + "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +type contextKey int + +const ( + transactionKey contextKey = iota +) + +type Tx struct { + txId int64 + tx *gorm.DB + log logrus.FieldLogger +} + +func Commit(ctx context.Context) (context.Context, error) { + tx, ok := ctx.Value(transactionKey).(*Tx) + if !ok { + return ctx, nil + } + + newCtx := context.WithValue(ctx, transactionKey, nil) + return newCtx, tx.Commit() +} + +func Rollback(ctx context.Context) (context.Context, error) { + tx, ok := ctx.Value(transactionKey).(*Tx) + if !ok { + return ctx, nil + } + + newCtx := context.WithValue(ctx, transactionKey, nil) + return newCtx, tx.Rollback() +} + +func FromContext(ctx context.Context) *gorm.DB { + if tx, found := ctx.Value(transactionKey).(*Tx); found { + if dbTx, err := tx.Db(); err == nil { + return dbTx + } + } + return nil +} + +func newTransactionContext(ctx context.Context, db *gorm.DB, log logrus.FieldLogger) (context.Context, error) { + //look into the context to see if we have another tx + _, found := ctx.Value(transactionKey).(*Tx) + if found { + return ctx, nil + } + + // create a new session + conn := db.Session(&gorm.Session{ + Context: ctx, + }) + + tx, err := newTransaction(conn, log) + if err != nil { + return ctx, err + } + + ctx = context.WithValue(ctx, transactionKey, tx) + return ctx, nil +} + +func newTransaction(db *gorm.DB, log logrus.FieldLogger) (*Tx, error) { + // must call begin on 'db', which is Gorm. + tx := db.Begin() + if tx.Error != nil { + return nil, tx.Error + } + + // current transaction ID set by postgres. these are *not* distinct across time + // and do get reset after postgres performs "vacuuming" to reclaim used IDs. + var txid struct{ ID int64 } + tx.Raw("select txid_current() as id").Scan(&txid) + + return &Tx{ + txId: txid.ID, + tx: tx, + log: log, + }, nil +} + +func (t *Tx) Db() (*gorm.DB, error) { + if t.tx != nil { + return t.tx, nil + } + return nil, errors.New("transaction hasn't started yet") +} + +func (t *Tx) Commit() error { + if t.tx == nil { + return errors.New("transaction hasn't started yet") + } + + if err := t.tx.Commit().Error; err != nil { + t.log.Errorf("failed to commit transaction %d: %w", t.txId, err) + return err + } + t.log.Debugf("transaction %d commited", t.txId) + t.tx = nil // in case we call commit twice + return nil +} + +func (t *Tx) Rollback() error { + if t.tx == nil { + return errors.New("transaction hasn't started yet") + } + + if err := t.tx.Rollback().Error; err != nil { + t.log.Errorf("failed to rollback transaction %d: %w", t.txId, err) + return err + } + t.tx = nil // in case we call commit twice + + t.log.Debugf("transaction %d rollback", t.txId) + return nil +}