From d434f0d7675734311549904dfa28e0126a29b6e6 Mon Sep 17 00:00:00 2001 From: David Sharnoff Date: Sat, 24 Dec 2022 20:14:58 -0800 Subject: [PATCH] track db_names for mysql and singlestore; breaking change for singlestore (#63) * track db_names for mysql and singlestore; breaking change for singlestore * lint --- lsmysql/mysql.go | 98 ++++++++++++++++++++++++++++-------- lsmysql/skip.go | 19 +++++++ lssinglestore/singlestore.go | 21 ++++---- lssinglestore/unit_test.go | 6 ++- 4 files changed, 112 insertions(+), 32 deletions(-) diff --git a/lsmysql/mysql.go b/lsmysql/mysql.go index 2ea5ce5..9a6528e 100644 --- a/lsmysql/mysql.go +++ b/lsmysql/mysql.go @@ -22,6 +22,7 @@ import ( // * uses /* -- and # for comments // * supports advisory locks // * has quoting modes (ANSI_QUOTES) +// * can use ALTER TABLE to modify the primary key of a table // // Because mysql DDL commands cause transactions to autocommit, tracking the schema changes in // a secondary table (like libschema does) is inherently unsafe. The MySQL driver will @@ -43,9 +44,9 @@ type MySQL struct { lockTx *sql.Tx lockStr string db *sql.DB - databaseName string // used in skip.go only + databaseName string lock sync.Mutex - trackingSchemaTable func(*libschema.Database) (string, string, error) + trackingSchemaTable func(*libschema.Database) (string, string, string, error) skipDatabase bool } @@ -219,7 +220,7 @@ func (p *MySQL) DoOneMigration(ctx context.Context, log *internal.Log, d *libsch // called internally which means that is safe to override // in types that embed MySQL. func (p *MySQL) CreateSchemaTableIfNotExists(ctx context.Context, _ *internal.Log, d *libschema.Database) error { - schema, tableName, err := p.trackingSchemaTable(d) + schema, tableName, _, err := p.trackingSchemaTable(d) if err != nil { return err } @@ -248,7 +249,10 @@ func (p *MySQL) CreateSchemaTableIfNotExists(ctx context.Context, _ *internal.Lo var simpleIdentifierRE = regexp.MustCompile(`\A[A-Za-z][A-Za-z0-9_]*\z`) -func WithTrackingTableQuoter(f func(*libschema.Database) (schemaName string, tableName string, err error)) MySQLOpt { +// WithTrackingTableQuoter is a somewhat internal function -- used by lssinglestore. +// It replaces the private function that takes apart the name of the tracking +// table and provides the components. +func WithTrackingTableQuoter(f func(*libschema.Database) (schemaName, tableName, simpleTableName string, err error)) MySQLOpt { return func(p *MySQL) { p.trackingSchemaTable = f } @@ -259,34 +263,34 @@ func WithTrackingTableQuoter(f func(*libschema.Database) (schemaName string, tab // mode, you could have a table called `table` (eg: `CREATE TABLE "table"`) but // if you're not in ANSI_QUOTES mode then you cannot. We're going to assume // that we're not in ANSI_QUOTES mode because we cannot assume that we are. -func trackingSchemaTable(d *libschema.Database) (string, string, error) { +func trackingSchemaTable(d *libschema.Database) (string, string, string, error) { tableName := d.Options.TrackingTable s := strings.Split(tableName, ".") switch len(s) { case 2: schema := s[0] if !simpleIdentifierRE.MatchString(schema) { - return "", "", errors.Errorf("Tracking table schema name must be a simple identifier, not '%s'", schema) + return "", "", "", errors.Errorf("Tracking table schema name must be a simple identifier, not '%s'", schema) } table := s[1] if !simpleIdentifierRE.MatchString(table) { - return "", "", errors.Errorf("Tracking table table name must be a simple identifier, not '%s'", table) + return "", "", "", errors.Errorf("Tracking table table name must be a simple identifier, not '%s'", table) } - return schema, schema + "." + table, nil + return schema, schema + "." + table, table, nil case 1: if !simpleIdentifierRE.MatchString(tableName) { - return "", "", errors.Errorf("Tracking table table name must be a simple identifier, not '%s'", tableName) + return "", "", "", errors.Errorf("Tracking table table name must be a simple identifier, not '%s'", tableName) } - return "", tableName, nil + return "", tableName, tableName, nil default: - return "", "", errors.Errorf("Tracking table '%s' is not valid", tableName) + return "", "", "", errors.Errorf("Tracking table '%s' is not valid", tableName) } } // trackingTable returns the schema+table reference for the migration tracking table. // The name is already quoted properly for use as a save mysql identifier. func (p *MySQL) trackingTable(d *libschema.Database) string { - _, table, _ := p.trackingSchemaTable(d) + _, table, _, _ := p.trackingSchemaTable(d) return table } @@ -301,9 +305,9 @@ func (p *MySQL) saveStatus(log *internal.Log, tx *sql.Tx, d *libschema.Database, "error": migrationError, }) q := fmt.Sprintf(` - REPLACE INTO %s (library, migration, done, error, updated_at) - VALUES (?, ?, ?, ?, now())`, p.trackingTable(d)) - _, err := tx.Exec(q, m.Base().Name.Library, m.Base().Name.Name, done, estr) + REPLACE INTO %s (db_name, library, migration, done, error, updated_at) + VALUES (?, ?, ?, ?, ?, now())`, p.trackingTable(d)) + _, err := tx.Exec(q, p.databaseName, m.Base().Name.Library, m.Base().Name.Name, done, estr) if err != nil { return errors.Wrapf(err, "Save status for %s", m.Base().Name) } @@ -321,14 +325,15 @@ func (p *MySQL) saveStatus(log *internal.Log, tx *sql.Tx, d *libschema.Database, // does not release the lock. We'll use a transaction just to make sure that // we're using the same connection. If LockMigrationsTable succeeds, be sure to // call UnlockMigrationsTable. -func (p *MySQL) LockMigrationsTable(ctx context.Context, _ *internal.Log, d *libschema.Database) error { - // LockMigrationsTable is overridden for SingleStore - p.lock.Lock() - defer p.lock.Unlock() - _, tableName, err := p.trackingSchemaTable(d) +func (p *MySQL) LockMigrationsTable(ctx context.Context, _ *internal.Log, d *libschema.Database) (finalErr error) { + schema, tableName, simpleTableName, err := p.trackingSchemaTable(d) if err != nil { return err } + + // LockMigrationsTable is overridden for SingleStore + p.lock.Lock() + defer p.lock.Unlock() if p.lockTx != nil { return errors.Errorf("libschema migrations table, '%s' already locked", tableName) } @@ -343,6 +348,56 @@ func (p *MySQL) LockMigrationsTable(ctx context.Context, _ *internal.Log, d *lib return errors.Wrapf(err, "Could not get lock for libschema migrations") } p.lockTx = tx + + // This moment, after getting an exclusive lock on the migrations table, is + // the right moment to do any schema upgrades of the migrations table. + + defer func() { + if finalErr != nil { + _, _ = tx.Exec(`SELECT RELEASE_LOCK(?)`, p.lockStr) + _ = tx.Rollback() + p.lockTx = nil + } + }() + + currentDatabaseValue := p.databaseName + defer func() { + p.databaseName = currentDatabaseValue + }() + p.databaseName = schema + + ok, err := p.DoesColumnExist(simpleTableName, "db_name") + if err != nil { + return errors.Wrapf(err, "could not check if %s has a db_name column", tableName) + } + if !ok { + _, err = d.DB().ExecContext(ctx, fmt.Sprintf(` + ALTER TABLE %s + ADD COLUMN db_name varchar(255)`, tableName)) + if err != nil { + return errors.Wrapf(err, "could not add db_name column to %s", tableName) + } + } + ok, err = p.ColumnIsInPrimaryKey(simpleTableName, "db_name") + if err != nil { + return errors.Wrapf(err, "could not check if %s.db_name column is in the primary key", tableName) + } + if ok { + return nil + } + _, err = d.DB().ExecContext(ctx, fmt.Sprintf(` + UPDATE %s + SET db_name = ? + WHERE db_name IS NULL`, tableName), p.databaseName) + if err != nil { + return errors.Wrapf(err, "could not set %s.db_name column", tableName) + } + _, err = d.DB().ExecContext(ctx, fmt.Sprintf(` + ALTER TABLE %s + DROP PRIMARY KEY, ADD PRIMARY KEY (db_name, library, migration)`, tableName)) + if err != nil { + return errors.Wrapf(err, "could change primary key for %s", tableName) + } return nil } @@ -379,7 +434,8 @@ func (p *MySQL) LoadStatus(ctx context.Context, _ *internal.Log, d *libschema.Da tableName := p.trackingTable(d) rows, err := d.DB().QueryContext(ctx, fmt.Sprintf(` SELECT library, migration, done - FROM %s`, tableName)) + FROM %s + WHERE db_name = ?`, tableName), p.databaseName) if err != nil { return nil, errors.Wrap(err, "Cannot query migration status") } diff --git a/lsmysql/skip.go b/lsmysql/skip.go index 790c3d1..012fabf 100644 --- a/lsmysql/skip.go +++ b/lsmysql/skip.go @@ -43,6 +43,25 @@ func (p *MySQL) HasPrimaryKey(table string) (bool, error) { return count != 0, errors.Wrapf(err, "has primary key %s.%s", database, table) } +// ColumnIsInPrimaryKey returns true if the column part of the prmary key. +// The table is assumed to be in the current database unless m.UseDatabase() has been called. +func (p *MySQL) ColumnIsInPrimaryKey(table string, column string) (bool, error) { + database, err := p.DatabaseName() + if err != nil { + return false, err + } + var count int + err = p.db.QueryRow(` + SELECT COUNT(*) + FROM information_schema.columns + WHERE table_schema = ? + AND table_name = ? + AND column_name = ? + AND column_key = 'PRI'`, + database, table, column).Scan(&count) + return count != 0, errors.Wrapf(err, "column is in primary key %s.%s.%s", database, table, column) +} + // TableHasIndex returns true if there is an index matching the // name given. // The table is assumed to be in the current database unless m.UseDatabase() has been called. diff --git a/lssinglestore/singlestore.go b/lssinglestore/singlestore.go index 7262bf3..090784c 100644 --- a/lssinglestore/singlestore.go +++ b/lssinglestore/singlestore.go @@ -95,7 +95,7 @@ func (p *SingleStore) LockMigrationsTable(ctx context.Context, _ *internal.Log, if p.lockTx != nil { return errors.Errorf("migrations already locked") } - _, tableName, err := trackingSchemaTable(d) + _, tableName, _, err := trackingSchemaTable(d) if err != nil { return err } @@ -152,34 +152,34 @@ func makeID(raw string) (string, error) { } } -func trackingSchemaTable(d *libschema.Database) (string, string, error) { +func trackingSchemaTable(d *libschema.Database) (string, string, string, error) { tableName := d.Options.TrackingTable s := strings.Split(tableName, ".") switch len(s) { case 2: schema, err := makeID(s[0]) if err != nil { - return "", "", errors.Wrap(err, "cannot make tracking table schema name") + return "", "", "", errors.Wrap(err, "cannot make tracking table schema name") } table, err := makeID(s[1]) if err != nil { - return "", "", errors.Wrap(err, "cannot make tracking table table name") + return "", "", "", errors.Wrap(err, "cannot make tracking table table name") } - return schema, schema + "." + table, nil + return schema, schema + "." + table, table, nil case 1: table, err := makeID(tableName) if err != nil { - return "", "", errors.Wrap(err, "cannot make tracking table table name") + return "", "", "", errors.Wrap(err, "cannot make tracking table table name") } - return "", table, nil + return "", table, table, nil default: - return "", "", errors.Errorf("tracking table '%s' is not valid", tableName) + return "", "", "", errors.Errorf("tracking table '%s' is not valid", tableName) } } // CreateSchemaTableIfNotExists creates the migration tracking table for libschema. func (p *SingleStore) CreateSchemaTableIfNotExists(ctx context.Context, _ *internal.Log, d *libschema.Database) error { - schema, tableName, err := trackingSchemaTable(d) + schema, tableName, _, err := trackingSchemaTable(d) if err != nil { return err } @@ -193,6 +193,7 @@ func (p *SingleStore) CreateSchemaTableIfNotExists(ctx context.Context, _ *inter } _, err = d.DB().ExecContext(ctx, fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( + db_name varchar(255) NOT NULL, library varchar(255) NOT NULL, migration varchar(255) NOT NULL, done boolean NOT NULL, @@ -200,7 +201,7 @@ func (p *SingleStore) CreateSchemaTableIfNotExists(ctx context.Context, _ *inter updated_at timestamp DEFAULT now(), SORT KEY (library, migration), SHARD KEY (library, migration), - PRIMARY KEY (library, migration) + PRIMARY KEY (db_name, library, migration) )`, tableName)) if err != nil { return errors.Wrapf(err, "Could not create libschema migrations table '%s'", tableName) diff --git a/lssinglestore/unit_test.go b/lssinglestore/unit_test.go index 8efb425..ac6d710 100644 --- a/lssinglestore/unit_test.go +++ b/lssinglestore/unit_test.go @@ -17,12 +17,14 @@ func TestTrackingSchemaTable(t *testing.T) { tt string err bool schema string + simple string table string }{ { tt: "`foo`.xk-z", schema: "`foo`", table: "`foo`.`xk-z`", + simple: "`xk-z`", }, { tt: "`foo.xk-z", @@ -32,6 +34,7 @@ func TestTrackingSchemaTable(t *testing.T) { tt: "foo", schema: "", table: "foo", + simple: "foo", }, { tt: "x.y.z", @@ -50,13 +53,14 @@ func TestTrackingSchemaTable(t *testing.T) { TrackingTable: tc.tt, }, } - schema, table, err := trackingSchemaTable(d) + schema, table, simple, err := trackingSchemaTable(d) if tc.err { assert.Error(t, err) } else { if assert.NoError(t, err) { assert.Equal(t, tc.schema, schema, "schema") assert.Equal(t, tc.table, table, "table") + assert.Equal(t, tc.simple, simple, "simpleTable") } } })