diff --git a/db.go b/db.go index a573bb8f..c51015b0 100644 --- a/db.go +++ b/db.go @@ -183,7 +183,7 @@ func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name str } } - tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m} + tmap := NewTableMap(t, name, schema, m) var primaryKey []*ColumnMap tmap.Columns, primaryKey = m.readStructColumns(t) m.tables = append(m.tables, tmap) diff --git a/gorp_test.go b/gorp_test.go index 98ea4d3e..8ee040c9 100644 --- a/gorp_test.go +++ b/gorp_test.go @@ -1586,6 +1586,21 @@ func TestColumnFilter(t *testing.T) { if inv2.IsPaid { t.Error("IsPaid shouldn't have been updated") } + + // update isPaid field only + _updateColumns(dbmap, func(col *gorp.ColumnMap) bool { + return col.ColumnName == "IsPaid" + }, inv1) + + inv2 = &Invoice{} + inv2 = _get(dbmap, inv2, inv1.Id).(*Invoice) + if inv2.Memo != "c" { + t.Errorf("Expected column to be updated (%#v)", inv2) + } + if !inv2.IsPaid { + t.Error("IsPaid should have been updated") + } + } func TestTypeConversionExample(t *testing.T) { diff --git a/table.go b/table.go index 5c513909..cd117518 100644 --- a/table.go +++ b/table.go @@ -16,6 +16,7 @@ import ( "fmt" "reflect" "strings" + "sync" ) // TableMap represents a mapping between a Go struct and a database table @@ -31,10 +32,22 @@ type TableMap struct { uniqueTogether [][]string version *ColumnMap insertPlan bindPlan - updatePlan bindPlan deletePlan bindPlan getPlan bindPlan dbmap *DbMap + + updatePlan map[string]bindPlan + muForUpdate sync.Mutex +} + +func NewTableMap(t reflect.Type, name string, schema string, dbmap *DbMap) *TableMap { + return &TableMap{ + gotype: t, + TableName: name, + SchemaName: schema, + dbmap: dbmap, + updatePlan: map[string]bindPlan{}, + } } // ResetSql removes cached insert/update/select/delete SQL strings @@ -42,7 +55,7 @@ type TableMap struct { // any column names or the table name itself. func (t *TableMap) ResetSql() { t.insertPlan = bindPlan{} - t.updatePlan = bindPlan{} + t.updatePlan = map[string]bindPlan{} t.deletePlan = bindPlan{} t.getPlan = bindPlan{} } diff --git a/table_bindings.go b/table_bindings.go index 5b049a36..20d75762 100644 --- a/table_bindings.go +++ b/table_bindings.go @@ -15,6 +15,7 @@ import ( "bytes" "fmt" "reflect" + "strings" "sync" ) @@ -167,13 +168,35 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { return plan.createBindInstance(elem, t.dbmap.TypeConverter) } +func (t *TableMap) signatureForColumns(colFilter ColumnFilter) string { + var tokens []string + for y := range t.Columns { + col := t.Columns[y] + if colFilter(col) { + tokens = append(tokens, col.ColumnName) + } + } + return strings.Join(tokens, ",") +} func (t *TableMap) bindUpdate(elem reflect.Value, colFilter ColumnFilter) (bindInstance, error) { + var key string if colFilter == nil { colFilter = acceptAllFilter + key = "default" // if we wants to update all columns, make it simple + } else { + key = t.signatureForColumns(colFilter) + } + + t.muForUpdate.Lock() + _, ok := t.updatePlan[key] + if !ok { + t.updatePlan[key] = bindPlan{} } + updatePlan := t.updatePlan[key] + t.muForUpdate.Unlock() - plan := &t.updatePlan + plan := &updatePlan plan.once.Do(func() { s := bytes.Buffer{} s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))