diff --git a/pkg/inventory/model/client.go b/pkg/inventory/model/client.go index 4dbed74e..9ab563ca 100644 --- a/pkg/inventory/model/client.go +++ b/pkg/inventory/model/client.go @@ -2,7 +2,6 @@ package model import ( "database/sql" - "errors" liberr "github.com/konveyor/controller/pkg/error" "os" "reflect" @@ -13,13 +12,6 @@ const ( Pragma = "PRAGMA foreign_keys = ON" ) -// -// Tx.Commit() -// Tx.End() -// Called and the transaction is not in progress by -// the associated Client. -var TxInvalidError = errors.New("transaction not valid") - // // Database client. type DB interface { @@ -29,8 +21,6 @@ type DB interface { Close(bool) error // Get the specified model. Get(Model) error - // Get for update of the specified model. - GetForUpdate(Model) (*Tx, error) // List models based on the type of slice. List(interface{}, ListOptions) error // Count based on the specified model. @@ -52,8 +42,7 @@ type DB interface { // // Database client. type Client struct { - // Protect internal state. - sync.RWMutex + labeler Labeler // The sqlite3 database will not support // concurrent write operations. dbMutex sync.Mutex @@ -63,8 +52,6 @@ type Client struct { models []interface{} // Database connection. db *sql.DB - // Current database transaction. - tx *sql.Tx // Journal journal Journal } @@ -128,24 +115,6 @@ func (r *Client) Get(model Model) error { return Table{r.db}.Get(model) } -// -// Get the model for update. -// Locks the DB by beginning a transaction. -// The caller MUST commit/end the returned Tx. -func (r *Client) GetForUpdate(model Model) (*Tx, error) { - tx, err := r.Begin() - if err != nil { - return nil, liberr.Wrap(err) - } - err = Table{r.db}.Get(model) - if err != nil { - tx.End() - tx = nil - } - - return tx, err -} - // // List models. // The `list` must be: *[]Model. @@ -168,42 +137,36 @@ func (r *Client) Count(model Model, predicate Predicate) (int64, error) { // client.Insert(model) // tx.Commit() func (r *Client) Begin() (*Tx, error) { - r.Lock() - defer r.Unlock() r.dbMutex.Lock() - tx, err := r.db.Begin() + real, err := r.db.Begin() if err != nil { return nil, err } - r.tx = tx - return &Tx{client: r, ref: tx}, nil + tx := &Tx{ + dbMutex: &r.dbMutex, + journal: &r.journal, + real: real, + } + + return tx, nil } // // Insert the model. func (r *Client) Insert(model Model) error { - r.Lock() - defer r.Unlock() - table := Table{} - if r.tx == nil { - r.dbMutex.Lock() - defer r.dbMutex.Unlock() - table.DB = r.db - } else { - table.DB = r.tx - } + r.dbMutex.Lock() + defer r.dbMutex.Unlock() + table := Table{r.db} err := table.Insert(model) if err != nil { return liberr.Wrap(err) } - err = r.insertLabels(table, model) + err = r.labeler.Insert(table, model) if err != nil { return liberr.Wrap(err) } r.journal.Created(model) - if r.tx == nil { - r.journal.Commit() - } + r.journal.Commit() return nil } @@ -211,16 +174,9 @@ func (r *Client) Insert(model Model) error { // // Update the model. func (r *Client) Update(model Model) error { - r.Lock() - defer r.Unlock() - table := Table{} - if r.tx == nil { - r.dbMutex.Lock() - defer r.dbMutex.Unlock() - table.DB = r.db - } else { - table.DB = r.tx - } + r.dbMutex.Lock() + defer r.dbMutex.Unlock() + table := Table{r.db} current := r.journal.copy(model) err := table.Get(current) if err != nil { @@ -230,14 +186,12 @@ func (r *Client) Update(model Model) error { if err != nil { return liberr.Wrap(err) } - err = r.replaceLabels(table, model) + err = r.labeler.Replace(table, model) if err != nil { return liberr.Wrap(err) } r.journal.Updated(current, model) - if r.tx == nil { - r.journal.Commit() - } + r.journal.Commit() return nil } @@ -245,28 +199,19 @@ func (r *Client) Update(model Model) error { // // Delete the model. func (r *Client) Delete(model Model) error { - r.Lock() - defer r.Unlock() - table := Table{} - if r.tx == nil { - r.dbMutex.Lock() - defer r.dbMutex.Unlock() - table.DB = r.db - } else { - table.DB = r.tx - } + r.dbMutex.Lock() + defer r.dbMutex.Unlock() + table := Table{r.db} err := table.Delete(model) if err != nil { return liberr.Wrap(err) } - err = r.deleteLabels(table, model) + err = r.labeler.Delete(table, model) if err != nil { return liberr.Wrap(err) } r.journal.Deleted(model) - if r.tx == nil { - r.journal.Commit() - } + r.journal.Commit() return nil } @@ -274,8 +219,6 @@ func (r *Client) Delete(model Model) error { // // Watch model events. func (r *Client) Watch(model Model, handler EventHandler) (*Watch, error) { - r.Lock() - defer r.Unlock() mt := reflect.TypeOf(model) switch mt.Kind() { case reflect.Ptr: @@ -312,132 +255,199 @@ func (r *Client) Journal() *Journal { } // -// Insert labels for the model into the DB. -func (r *Client) insertLabels(table Table, model Model) error { - for l, v := range model.Labels() { - label := &Label{ - Parent: model.Pk(), - Kind: table.Name(model), - Name: l, - Value: v, - } - err := table.Insert(label) - if err != nil { - return liberr.Wrap(err) - } +// Database transaction. +type Tx struct { + labeler Labeler + // Associated client. + dbMutex *sync.Mutex + // Journal + journal *Journal + // Reference to real sql.Tx. + real *sql.Tx + // Ended + ended bool +} + +// +// Get the model. +func (r *Tx) Get(model Model) error { + return Table{r.real}.Get(model) +} + +// +// List models. +// The `list` must be: *[]Model. +func (r *Tx) List(list interface{}, options ListOptions) error { + return Table{r.real}.List(list, options) +} + +// +// Count models. +func (r *Tx) Count(model Model, predicate Predicate) (int64, error) { + return Table{r.real}.Count(model, predicate) +} + +// +// Insert the model. +func (r *Tx) Insert(model Model) error { + table := Table{r.real} + err := table.Insert(model) + if err != nil { + return liberr.Wrap(err) } + err = r.labeler.Insert(table, model) + if err != nil { + return liberr.Wrap(err) + } + r.journal.Created(model) return nil } // -// Delete labels for a model in the DB. -func (r *Client) deleteLabels(table Table, model Model) error { - list := []Label{} - err := table.List( - &list, - ListOptions{ - Predicate: And( - Eq("Kind", table.Name(model)), - Eq("Parent", model.Pk())), - }) +// Update the model. +func (r *Tx) Update(model Model) error { + table := Table{r.real} + current := r.journal.copy(model) + err := table.Get(current) if err != nil { return liberr.Wrap(err) } - for _, label := range list { - err := table.Delete(&label) - if err != nil { - return liberr.Wrap(err) - } + err = table.Update(model) + if err != nil { + return liberr.Wrap(err) } + err = r.labeler.Replace(table, model) + if err != nil { + return liberr.Wrap(err) + } + r.journal.Updated(current, model) return nil } // -// Replace labels. -func (r *Client) replaceLabels(table Table, model Model) error { - err := r.deleteLabels(table, model) +// Delete the model. +func (r *Tx) Delete(model Model) error { + table := Table{r.real} + err := table.Delete(model) if err != nil { return liberr.Wrap(err) } - err = r.insertLabels(table, model) + err = r.labeler.Delete(table, model) if err != nil { return liberr.Wrap(err) } + r.journal.Deleted(model) return nil } // // Commit a transaction. -// This MUST be preceeded by Begin() which returns -// the `tx` transaction. This will end the transaction. -func (r *Client) commit(tx *Tx) error { - r.Lock() - defer r.Unlock() - if r.tx == nil || r.tx != tx.ref { - return liberr.Wrap(TxInvalidError) +// Staged changes are committed in the DB. +// This will end the transaction. +func (r *Tx) Commit() (err error) { + if r.ended { + return } defer func() { r.dbMutex.Unlock() - r.tx = nil + r.ended = true }() - err := r.tx.Commit() + err = r.real.Commit() if err != nil { - return liberr.Wrap(err) + err = liberr.Wrap(err) + return } r.journal.Commit() - return nil + return } // // End a transaction. -// This MUST be preceeded by Begin() which returns -// the `tx` transaction. -func (r *Client) end(tx *Tx) error { - r.Lock() - defer r.Unlock() - if r.tx == nil || r.tx != tx.ref { - return liberr.Wrap(TxInvalidError) +// Staged changes are discarded. +// See: Commit(). +func (r *Tx) End() (err error) { + if r.ended { + return } defer func() { r.dbMutex.Unlock() - r.tx = nil + r.ended = true }() - err := r.tx.Rollback() + err = r.real.Rollback() if err != nil { - return liberr.Wrap(err) + err = liberr.Wrap(err) + return } r.journal.Unstage() - return nil + return } // -// Database transaction. -type Tx struct { - // Associated client. - client *Client - // Reference to sql.Tx. - ref *sql.Tx +// Labeler. +type Labeler struct { } // -// Commit a transaction. -// Staged changes are committed in the DB. -// This will end the transaction. -func (r *Tx) Commit() error { - return r.client.commit(r) +// Insert labels for the model into the DB. +func (r *Labeler) Insert(table Table, model Model) error { + for l, v := range model.Labels() { + label := &Label{ + Parent: model.Pk(), + Kind: table.Name(model), + Name: l, + Value: v, + } + err := table.Insert(label) + if err != nil { + return liberr.Wrap(err) + } + } + + return nil } // -// End a transaction. -// Staged changes are discarded. -// See: Commit(). -func (r *Tx) End() error { - return r.client.end(r) +// Delete labels for a model in the DB. +func (r *Labeler) Delete(table Table, model Model) error { + list := []Label{} + err := table.List( + &list, + ListOptions{ + Predicate: And( + Eq("Kind", table.Name(model)), + Eq("Parent", model.Pk())), + }) + if err != nil { + return liberr.Wrap(err) + } + for _, label := range list { + err := table.Delete(&label) + if err != nil { + return liberr.Wrap(err) + } + } + + return nil +} + +// +// Replace labels. +func (r *Labeler) Replace(table Table, model Model) error { + err := r.Delete(table, model) + if err != nil { + return liberr.Wrap(err) + } + err = r.Insert(table, model) + if err != nil { + return liberr.Wrap(err) + } + + return nil } diff --git a/pkg/inventory/model/model_test.go b/pkg/inventory/model/model_test.go index 0b64387c..37fb89d8 100644 --- a/pkg/inventory/model/model_test.go +++ b/pkg/inventory/model/model_test.go @@ -74,6 +74,41 @@ func (w *TestHandler) Error(err error) { func (w *TestHandler) End() { } +type MutatingHandler struct { + DB + name string + created []int + updated []int +} + +func (w *MutatingHandler) Created(e Event) { + tx, _ := w.DB.Begin() + tx.Get(e.Model) + e.Model.(*TestObject).Age++ + tx.Update(e.Model) + tx.Commit() + w.created = append(w.created, e.Model.(*TestObject).ID) +} + +func (w *MutatingHandler) Updated(e Event) { + tx, _ := w.DB.Begin() + tx.Get(e.Model) + e.Model.(*TestObject).Age++ + tx.Update(e.Model) + tx.Commit() + w.updated = append(w.updated, e.Model.(*TestObject).ID) +} + +func (w *MutatingHandler) Deleted(e Event) { +} + +func (w *MutatingHandler) Error(err error) { + return +} + +func (w *MutatingHandler) End() { +} + func TestCRUD(t *testing.T) { var err error g := gomega.NewGomegaWithT(t) @@ -156,12 +191,13 @@ func TestTransactions(t *testing.T) { tx, err := DB.Begin() defer tx.End() g.Expect(err).To(gomega.BeNil()) - g.Expect(tx.ref).To(gomega.Equal(DB.(*Client).tx)) + g.Expect(tx.dbMutex).To(gomega.Equal(&DB.(*Client).dbMutex)) + g.Expect(tx.journal).To(gomega.Equal(&DB.(*Client).journal)) object := &TestObject{ ID: 0, Name: "Elmer", } - err = DB.Insert(object) + err = tx.Insert(object) g.Expect(err).To(gomega.BeNil()) // Get (not found) object = &TestObject{ID: object.ID} @@ -174,27 +210,6 @@ func TestTransactions(t *testing.T) { g.Expect(err).To(gomega.BeNil()) } -func TestGetForUpdate(t *testing.T) { - g := gomega.NewGomegaWithT(t) - DB := New( - "/tmp/test.db", - &Label{}, - &TestObject{}) - err := DB.Open(true) - g.Expect(err).To(gomega.BeNil()) - // Insert - object := &TestObject{ - ID: 0, - Name: "Elmer", - } - err = DB.Insert(object) - g.Expect(err).To(gomega.BeNil()) - tx, err := DB.GetForUpdate(object) - g.Expect(err).To(gomega.BeNil()) - g.Expect(tx.ref).To(gomega.Equal(DB.(*Client).tx)) - tx.Commit() -} - func TestList(t *testing.T) { var err error g := gomega.NewGomegaWithT(t) @@ -375,6 +390,50 @@ func TestWatch(t *testing.T) { } } +func TestMutatingWatch(t *testing.T) { + g := gomega.NewGomegaWithT(t) + DB := New( + "/tmp/test.db", + &Label{}, + &TestObject{}) + err := DB.Open(true) + g.Expect(err).To(gomega.BeNil()) + DB.Journal().Enable() + // Handler A + handlerA := &MutatingHandler{ + name: "A", + DB: DB, + } + watchA, err := DB.Watch(&TestObject{}, handlerA) + g.Expect(err).To(gomega.BeNil()) + g.Expect(watchA).ToNot(gomega.BeNil()) + // Handler B + handlerB := &MutatingHandler{ + name: "A", + DB: DB, + } + watchB, err := DB.Watch(&TestObject{}, handlerB) + g.Expect(err).To(gomega.BeNil()) + g.Expect(watchB).ToNot(gomega.BeNil()) + N := 10 + // Insert + for i := 0; i < N; i++ { + object := &TestObject{ + ID: i, + Name: "Elmer", + } + err = DB.Insert(object) + g.Expect(err).To(gomega.BeNil()) + } + + for { + time.Sleep(time.Millisecond * 10) + if len(handlerA.updated) > 100 { + break + } + } +} + // // Remove leading __ to enable. func __TestConcurrency(t *testing.T) { @@ -418,7 +477,7 @@ func __TestConcurrency(t *testing.T) { } fmt.Printf("read|%d\n", i) }() - time.Sleep(time.Millisecond * time.Duration(100)) + time.Sleep(time.Millisecond * 100) } done <- 0 } @@ -439,7 +498,7 @@ func __TestConcurrency(t *testing.T) { } fmt.Printf("del|%d\n", i) }() - time.Sleep(time.Millisecond * time.Duration(300)) + time.Sleep(time.Millisecond * 300) } done <- 0 } @@ -460,15 +519,24 @@ func __TestConcurrency(t *testing.T) { } fmt.Printf("update|%d\n", i) }() - time.Sleep(time.Millisecond * time.Duration(20)) + time.Sleep(time.Millisecond * 20) } done <- 0 } transaction := func(done chan int) { + time.Sleep(time.Millisecond * 100) var tx *Tx + defer func() { + if tx != nil { + err := tx.Commit() + if err != nil { + panic(err) + } + } + }() threshold := float64(10) for i := N; i < N*2; i++ { - if i == N || math.Mod(float64(i), threshold) == 0 { + if tx == nil { tx, err = DB.Begin() if err != nil { panic(err) @@ -478,7 +546,7 @@ func __TestConcurrency(t *testing.T) { ID: i, Name: "transaction", } - DB.Insert(m) + err = tx.Insert(m) if err != nil { panic(err) } @@ -488,10 +556,11 @@ func __TestConcurrency(t *testing.T) { if err != nil { panic(err) } + tx = nil fmt.Printf("commit|%d\n", i) } fmt.Printf("transaction|%d\n", i) - time.Sleep(time.Millisecond * time.Duration(100)) + time.Sleep(time.Millisecond * 100) } done <- 0 }