Skip to content

Commit

Permalink
track db_names for mysql and singlestore; breaking change for singles…
Browse files Browse the repository at this point in the history
…tore (#63)

* track db_names for mysql and singlestore; breaking change for singlestore

* lint
  • Loading branch information
muir authored Dec 25, 2022
1 parent 3d48dae commit d434f0d
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 32 deletions.
98 changes: 77 additions & 21 deletions lsmysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
// * uses /* -- and # for comments
// * supports advisory locks
// * has quoting modes (ANSI_QUOTES)
// * can use ALTER TABLE to modify the primary key of a table
//
// Because mysql DDL commands cause transactions to autocommit, tracking the schema changes in
// a secondary table (like libschema does) is inherently unsafe. The MySQL driver will
Expand All @@ -43,9 +44,9 @@ type MySQL struct {
lockTx *sql.Tx
lockStr string
db *sql.DB
databaseName string // used in skip.go only
databaseName string
lock sync.Mutex
trackingSchemaTable func(*libschema.Database) (string, string, error)
trackingSchemaTable func(*libschema.Database) (string, string, string, error)
skipDatabase bool
}

Expand Down Expand Up @@ -219,7 +220,7 @@ func (p *MySQL) DoOneMigration(ctx context.Context, log *internal.Log, d *libsch
// called internally which means that is safe to override
// in types that embed MySQL.
func (p *MySQL) CreateSchemaTableIfNotExists(ctx context.Context, _ *internal.Log, d *libschema.Database) error {
schema, tableName, err := p.trackingSchemaTable(d)
schema, tableName, _, err := p.trackingSchemaTable(d)
if err != nil {
return err
}
Expand Down Expand Up @@ -248,7 +249,10 @@ func (p *MySQL) CreateSchemaTableIfNotExists(ctx context.Context, _ *internal.Lo

var simpleIdentifierRE = regexp.MustCompile(`\A[A-Za-z][A-Za-z0-9_]*\z`)

func WithTrackingTableQuoter(f func(*libschema.Database) (schemaName string, tableName string, err error)) MySQLOpt {
// WithTrackingTableQuoter is a somewhat internal function -- used by lssinglestore.
// It replaces the private function that takes apart the name of the tracking
// table and provides the components.
func WithTrackingTableQuoter(f func(*libschema.Database) (schemaName, tableName, simpleTableName string, err error)) MySQLOpt {
return func(p *MySQL) {
p.trackingSchemaTable = f
}
Expand All @@ -259,34 +263,34 @@ func WithTrackingTableQuoter(f func(*libschema.Database) (schemaName string, tab
// mode, you could have a table called `table` (eg: `CREATE TABLE "table"`) but
// if you're not in ANSI_QUOTES mode then you cannot. We're going to assume
// that we're not in ANSI_QUOTES mode because we cannot assume that we are.
func trackingSchemaTable(d *libschema.Database) (string, string, error) {
func trackingSchemaTable(d *libschema.Database) (string, string, string, error) {
tableName := d.Options.TrackingTable
s := strings.Split(tableName, ".")
switch len(s) {
case 2:
schema := s[0]
if !simpleIdentifierRE.MatchString(schema) {
return "", "", errors.Errorf("Tracking table schema name must be a simple identifier, not '%s'", schema)
return "", "", "", errors.Errorf("Tracking table schema name must be a simple identifier, not '%s'", schema)
}
table := s[1]
if !simpleIdentifierRE.MatchString(table) {
return "", "", errors.Errorf("Tracking table table name must be a simple identifier, not '%s'", table)
return "", "", "", errors.Errorf("Tracking table table name must be a simple identifier, not '%s'", table)
}
return schema, schema + "." + table, nil
return schema, schema + "." + table, table, nil
case 1:
if !simpleIdentifierRE.MatchString(tableName) {
return "", "", errors.Errorf("Tracking table table name must be a simple identifier, not '%s'", tableName)
return "", "", "", errors.Errorf("Tracking table table name must be a simple identifier, not '%s'", tableName)
}
return "", tableName, nil
return "", tableName, tableName, nil
default:
return "", "", errors.Errorf("Tracking table '%s' is not valid", tableName)
return "", "", "", errors.Errorf("Tracking table '%s' is not valid", tableName)
}
}

// trackingTable returns the schema+table reference for the migration tracking table.
// The name is already quoted properly for use as a save mysql identifier.
func (p *MySQL) trackingTable(d *libschema.Database) string {
_, table, _ := p.trackingSchemaTable(d)
_, table, _, _ := p.trackingSchemaTable(d)
return table
}

Expand All @@ -301,9 +305,9 @@ func (p *MySQL) saveStatus(log *internal.Log, tx *sql.Tx, d *libschema.Database,
"error": migrationError,
})
q := fmt.Sprintf(`
REPLACE INTO %s (library, migration, done, error, updated_at)
VALUES (?, ?, ?, ?, now())`, p.trackingTable(d))
_, err := tx.Exec(q, m.Base().Name.Library, m.Base().Name.Name, done, estr)
REPLACE INTO %s (db_name, library, migration, done, error, updated_at)
VALUES (?, ?, ?, ?, ?, now())`, p.trackingTable(d))
_, err := tx.Exec(q, p.databaseName, m.Base().Name.Library, m.Base().Name.Name, done, estr)
if err != nil {
return errors.Wrapf(err, "Save status for %s", m.Base().Name)
}
Expand All @@ -321,14 +325,15 @@ func (p *MySQL) saveStatus(log *internal.Log, tx *sql.Tx, d *libschema.Database,
// does not release the lock. We'll use a transaction just to make sure that
// we're using the same connection. If LockMigrationsTable succeeds, be sure to
// call UnlockMigrationsTable.
func (p *MySQL) LockMigrationsTable(ctx context.Context, _ *internal.Log, d *libschema.Database) error {
// LockMigrationsTable is overridden for SingleStore
p.lock.Lock()
defer p.lock.Unlock()
_, tableName, err := p.trackingSchemaTable(d)
func (p *MySQL) LockMigrationsTable(ctx context.Context, _ *internal.Log, d *libschema.Database) (finalErr error) {
schema, tableName, simpleTableName, err := p.trackingSchemaTable(d)
if err != nil {
return err
}

// LockMigrationsTable is overridden for SingleStore
p.lock.Lock()
defer p.lock.Unlock()
if p.lockTx != nil {
return errors.Errorf("libschema migrations table, '%s' already locked", tableName)
}
Expand All @@ -343,6 +348,56 @@ func (p *MySQL) LockMigrationsTable(ctx context.Context, _ *internal.Log, d *lib
return errors.Wrapf(err, "Could not get lock for libschema migrations")
}
p.lockTx = tx

// This moment, after getting an exclusive lock on the migrations table, is
// the right moment to do any schema upgrades of the migrations table.

defer func() {
if finalErr != nil {
_, _ = tx.Exec(`SELECT RELEASE_LOCK(?)`, p.lockStr)
_ = tx.Rollback()
p.lockTx = nil
}
}()

currentDatabaseValue := p.databaseName
defer func() {
p.databaseName = currentDatabaseValue
}()
p.databaseName = schema

ok, err := p.DoesColumnExist(simpleTableName, "db_name")
if err != nil {
return errors.Wrapf(err, "could not check if %s has a db_name column", tableName)
}
if !ok {
_, err = d.DB().ExecContext(ctx, fmt.Sprintf(`
ALTER TABLE %s
ADD COLUMN db_name varchar(255)`, tableName))
if err != nil {
return errors.Wrapf(err, "could not add db_name column to %s", tableName)
}
}
ok, err = p.ColumnIsInPrimaryKey(simpleTableName, "db_name")
if err != nil {
return errors.Wrapf(err, "could not check if %s.db_name column is in the primary key", tableName)
}
if ok {
return nil
}
_, err = d.DB().ExecContext(ctx, fmt.Sprintf(`
UPDATE %s
SET db_name = ?
WHERE db_name IS NULL`, tableName), p.databaseName)
if err != nil {
return errors.Wrapf(err, "could not set %s.db_name column", tableName)
}
_, err = d.DB().ExecContext(ctx, fmt.Sprintf(`
ALTER TABLE %s
DROP PRIMARY KEY, ADD PRIMARY KEY (db_name, library, migration)`, tableName))
if err != nil {
return errors.Wrapf(err, "could change primary key for %s", tableName)
}
return nil
}

Expand Down Expand Up @@ -379,7 +434,8 @@ func (p *MySQL) LoadStatus(ctx context.Context, _ *internal.Log, d *libschema.Da
tableName := p.trackingTable(d)
rows, err := d.DB().QueryContext(ctx, fmt.Sprintf(`
SELECT library, migration, done
FROM %s`, tableName))
FROM %s
WHERE db_name = ?`, tableName), p.databaseName)
if err != nil {
return nil, errors.Wrap(err, "Cannot query migration status")
}
Expand Down
19 changes: 19 additions & 0 deletions lsmysql/skip.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,25 @@ func (p *MySQL) HasPrimaryKey(table string) (bool, error) {
return count != 0, errors.Wrapf(err, "has primary key %s.%s", database, table)
}

// ColumnIsInPrimaryKey returns true if the column part of the prmary key.
// The table is assumed to be in the current database unless m.UseDatabase() has been called.
func (p *MySQL) ColumnIsInPrimaryKey(table string, column string) (bool, error) {
database, err := p.DatabaseName()
if err != nil {
return false, err
}
var count int
err = p.db.QueryRow(`
SELECT COUNT(*)
FROM information_schema.columns
WHERE table_schema = ?
AND table_name = ?
AND column_name = ?
AND column_key = 'PRI'`,
database, table, column).Scan(&count)
return count != 0, errors.Wrapf(err, "column is in primary key %s.%s.%s", database, table, column)
}

// TableHasIndex returns true if there is an index matching the
// name given.
// The table is assumed to be in the current database unless m.UseDatabase() has been called.
Expand Down
21 changes: 11 additions & 10 deletions lssinglestore/singlestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (p *SingleStore) LockMigrationsTable(ctx context.Context, _ *internal.Log,
if p.lockTx != nil {
return errors.Errorf("migrations already locked")
}
_, tableName, err := trackingSchemaTable(d)
_, tableName, _, err := trackingSchemaTable(d)
if err != nil {
return err
}
Expand Down Expand Up @@ -152,34 +152,34 @@ func makeID(raw string) (string, error) {
}
}

func trackingSchemaTable(d *libschema.Database) (string, string, error) {
func trackingSchemaTable(d *libschema.Database) (string, string, string, error) {
tableName := d.Options.TrackingTable
s := strings.Split(tableName, ".")
switch len(s) {
case 2:
schema, err := makeID(s[0])
if err != nil {
return "", "", errors.Wrap(err, "cannot make tracking table schema name")
return "", "", "", errors.Wrap(err, "cannot make tracking table schema name")
}
table, err := makeID(s[1])
if err != nil {
return "", "", errors.Wrap(err, "cannot make tracking table table name")
return "", "", "", errors.Wrap(err, "cannot make tracking table table name")
}
return schema, schema + "." + table, nil
return schema, schema + "." + table, table, nil
case 1:
table, err := makeID(tableName)
if err != nil {
return "", "", errors.Wrap(err, "cannot make tracking table table name")
return "", "", "", errors.Wrap(err, "cannot make tracking table table name")
}
return "", table, nil
return "", table, table, nil
default:
return "", "", errors.Errorf("tracking table '%s' is not valid", tableName)
return "", "", "", errors.Errorf("tracking table '%s' is not valid", tableName)
}
}

// CreateSchemaTableIfNotExists creates the migration tracking table for libschema.
func (p *SingleStore) CreateSchemaTableIfNotExists(ctx context.Context, _ *internal.Log, d *libschema.Database) error {
schema, tableName, err := trackingSchemaTable(d)
schema, tableName, _, err := trackingSchemaTable(d)
if err != nil {
return err
}
Expand All @@ -193,14 +193,15 @@ func (p *SingleStore) CreateSchemaTableIfNotExists(ctx context.Context, _ *inter
}
_, err = d.DB().ExecContext(ctx, fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
db_name varchar(255) NOT NULL,
library varchar(255) NOT NULL,
migration varchar(255) NOT NULL,
done boolean NOT NULL,
error text NOT NULL,
updated_at timestamp DEFAULT now(),
SORT KEY (library, migration),
SHARD KEY (library, migration),
PRIMARY KEY (library, migration)
PRIMARY KEY (db_name, library, migration)
)`, tableName))
if err != nil {
return errors.Wrapf(err, "Could not create libschema migrations table '%s'", tableName)
Expand Down
6 changes: 5 additions & 1 deletion lssinglestore/unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ func TestTrackingSchemaTable(t *testing.T) {
tt string
err bool
schema string
simple string
table string
}{
{
tt: "`foo`.xk-z",
schema: "`foo`",
table: "`foo`.`xk-z`",
simple: "`xk-z`",
},
{
tt: "`foo.xk-z",
Expand All @@ -32,6 +34,7 @@ func TestTrackingSchemaTable(t *testing.T) {
tt: "foo",
schema: "",
table: "foo",
simple: "foo",
},
{
tt: "x.y.z",
Expand All @@ -50,13 +53,14 @@ func TestTrackingSchemaTable(t *testing.T) {
TrackingTable: tc.tt,
},
}
schema, table, err := trackingSchemaTable(d)
schema, table, simple, err := trackingSchemaTable(d)
if tc.err {
assert.Error(t, err)
} else {
if assert.NoError(t, err) {
assert.Equal(t, tc.schema, schema, "schema")
assert.Equal(t, tc.table, table, "table")
assert.Equal(t, tc.simple, simple, "simpleTable")
}
}
})
Expand Down

0 comments on commit d434f0d

Please sign in to comment.