Skip to content

Commit

Permalink
Merge pull request #14 from choonkeat/postgres-schema
Browse files Browse the repository at this point in the history
Support postgres schema
  • Loading branch information
choonkeat authored Dec 3, 2023
2 parents d65f685 + 8a95b2f commit 23977e1
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 68 deletions.
10 changes: 6 additions & 4 deletions cmd/dbmigrate/cql.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ import (

func init() {
dbmigrate.Register("cql", dbmigrate.Adapter{
CreateVersionsTable: `CREATE TABLE IF NOT EXISTS dbmigrate_versions (version text, PRIMARY KEY (version));`,
SelectExistingVersions: `SELECT version FROM dbmigrate_versions`,
InsertNewVersion: `INSERT INTO dbmigrate_versions (version) VALUES (?)`,
DeleteOldVersion: `DELETE FROM dbmigrate_versions WHERE version = ?`,
CreateVersionsTable: func(_ *string) string {
return `CREATE TABLE IF NOT EXISTS dbmigrate_versions (version text, PRIMARY KEY (version));`
},
SelectExistingVersions: func(_ *string) string { return `SELECT version FROM dbmigrate_versions` },
InsertNewVersion: func(_ *string) string { return `INSERT INTO dbmigrate_versions (version) VALUES (?)` },
DeleteOldVersion: func(_ *string) string { return `DELETE FROM dbmigrate_versions WHERE version = ?` },
PingQuery: `SELECT gossip_generation FROM system.local`,
BaseDatabaseURL: func(databaseURL string) (string, string, error) {
u, err := url.Parse(databaseURL)
Expand Down
66 changes: 41 additions & 25 deletions cmd/dbmigrate/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,34 @@ import (
_ "github.com/lib/pq"
)

var (
serverReadyWait time.Duration
doCreateDB bool
doCreateMigration bool
doPendingVersions bool
doMigrateUp bool
doMigrateDown int
dirname string
databaseURL string
driverName string
timeout time.Duration
)

func main() {
if err := _main(); err != nil {
log.Fatalln(err.Error())
}
}

func _main() error {
var (
serverReadyWait time.Duration
doCreateDB bool
dbSchema *string
doCreateMigration bool
doPendingVersions bool
doMigrateUp bool
doMigrateDown int
dirname string
databaseURL string
driverName string
timeout time.Duration
errctx error
)

// options
flag.DurationVar(&serverReadyWait,
"server-ready", 0, "wait until database server is ready, then continue")
flag.BoolVar(&doCreateDB,
"create-db", false, "create postgres database (ignore errors), then continue")
dbSchema = flag.String("schema", "", "create schema if necessary (ignore errors), then continue")
flag.BoolVar(&doCreateMigration,
"create", false, "add new migration files into -dir")
flag.BoolVar(&doPendingVersions,
Expand Down Expand Up @@ -76,12 +79,12 @@ func _main() error {
return nil
}

driverName, databaseURL, _ = dbmigrate.SanitizeDriverNameURL(driverName, databaseURL)
driverName, databaseURL, errctx = dbmigrate.SanitizeDriverNameURL(driverName, databaseURL)

if doServerReadyWait := serverReadyWait > 0; doServerReadyWait || doCreateDB {
if doServerReadyWait := serverReadyWait > 0; doServerReadyWait || doCreateDB || dbSchema != nil {
adapter, err := dbmigrate.AdapterFor(driverName)
if err != nil {
return err
return errors.Wrap(err, errctx.Error())
}

if doServerReadyWait {
Expand All @@ -90,12 +93,12 @@ func _main() error {
}
connString, _, err := adapter.BaseDatabaseURL(databaseURL)
if err != nil {
return err
return errors.Wrap(err, errctx.Error())
}
ctx, cancel := context.WithTimeout(context.Background(), serverReadyWait)
defer cancel()
if err := dbmigrate.ReadyWait(ctx, driverName, []string{databaseURL, connString}, log.Println); err != nil {
return err
return errors.Wrap(err, errctx.Error())
}
}

Expand All @@ -108,44 +111,57 @@ func _main() error {
}
connString, dbName, err := adapter.BaseDatabaseURL(databaseURL)
if err != nil {
return err
return errors.Wrap(err, errctx.Error())
}
db, err := sql.Open(driverName, connString)
if err != nil {
return errors.Wrapf(err, "connect to db")
}
// leave errors for subsequent actions
_, _ = db.Exec(adapter.CreateDatabaseQuery(dbName))
_, errctx = db.Exec(adapter.CreateDatabaseQuery(dbName))
_ = db.Close()
}

if dbSchema != nil && *dbSchema != "" {
if adapter.CreateSchemaQuery == nil {
return errors.Errorf("%q does not support -schema", driverName)
}
db, err := sql.Open(driverName, databaseURL)
if err != nil {
return errors.Wrapf(err, "connect to db")
}
// leave errors for subsequent actions
_, errctx = db.Exec(adapter.CreateSchemaQuery(*dbSchema))
_ = db.Close()
}
}

m, err := dbmigrate.New(os.DirFS(dirname), driverName, databaseURL)
if err != nil {
return err
return errors.Wrap(err, errctx.Error())
}
defer m.CloseDB()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

// 2. SHOW pending versions; exit
if doPendingVersions {
versions, err := m.PendingVersions(ctx)
versions, err := m.PendingVersions(ctx, dbSchema)
if err != nil {
return err
return errors.Wrap(err, errctx.Error())
}
fmt.Println(strings.Join(versions, "\n"))
return nil
}

// 3. MIGRATE UP; exit
if doMigrateUp {
return m.MigrateUp(ctx, &sql.TxOptions{}, filenameLogger("[up]"))
return m.MigrateUp(ctx, &sql.TxOptions{}, dbSchema, filenameLogger("[up]"))
}

// 4. MIGRATE DOWN; exit
if doMigrateDown > 0 {
return m.MigrateDown(ctx, &sql.TxOptions{}, filenameLogger("[down]"), doMigrateDown)
return m.MigrateDown(ctx, &sql.TxOptions{}, dbSchema, filenameLogger("[down]"), doMigrateDown)
}

// None of the above, fail
Expand Down
10 changes: 6 additions & 4 deletions cmd/dbmigrate/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ import (

func init() {
dbmigrate.Register("sqlite3", dbmigrate.Adapter{
CreateVersionsTable: `CREATE TABLE dbmigrate_versions (version char(14) NOT NULL PRIMARY KEY)`,
SelectExistingVersions: `SELECT version FROM dbmigrate_versions ORDER BY version ASC`,
InsertNewVersion: `INSERT INTO dbmigrate_versions (version) VALUES (?)`,
DeleteOldVersion: `DELETE FROM dbmigrate_versions WHERE version = ?`,
CreateVersionsTable: func(_ *string) string {
return `CREATE TABLE dbmigrate_versions (version char(14) NOT NULL PRIMARY KEY)`
},
SelectExistingVersions: func(_ *string) string { return `SELECT version FROM dbmigrate_versions ORDER BY version ASC` },
InsertNewVersion: func(_ *string) string { return `INSERT INTO dbmigrate_versions (version) VALUES (?)` },
DeleteOldVersion: func(_ *string) string { return `DELETE FROM dbmigrate_versions WHERE version = ?` },
PingQuery: "SELECT 1",
BeginTx: func(ctx context.Context, db *sql.DB, opts *sql.TxOptions) (dbmigrate.ExecCommitRollbacker, error) {
return db.BeginTx(ctx, opts)
Expand Down
2 changes: 1 addition & 1 deletion examples/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func simpleDbmigrateUp() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()

return m.MigrateUp(ctx, &sql.TxOptions{}, func(currentFilename string) {
return m.MigrateUp(ctx, &sql.TxOptions{}, nil, func(currentFilename string) {
fmt.Println("[migrate up]", currentFilename) // optional print out of which file was migrated
})
}
12 changes: 7 additions & 5 deletions examples/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (

func sqlite3DbmigrateUp() error {
dbmigrate.Register("sqlite3", dbmigrate.Adapter{
CreateVersionsTable: `CREATE TABLE dbmigrate_versions (version char(14) NOT NULL PRIMARY KEY)`,
SelectExistingVersions: `SELECT version FROM dbmigrate_versions ORDER BY version ASC`,
InsertNewVersion: `INSERT INTO dbmigrate_versions (version) VALUES (?)`,
DeleteOldVersion: `DELETE FROM dbmigrate_versions WHERE version = ?`,
CreateVersionsTable: func(_ *string) string {
return `CREATE TABLE dbmigrate_versions (version char(14) NOT NULL PRIMARY KEY)`
},
SelectExistingVersions: func(_ *string) string { return `SELECT version FROM dbmigrate_versions ORDER BY version ASC` },
InsertNewVersion: func(_ *string) string { return `INSERT INTO dbmigrate_versions (version) VALUES (?)` },
DeleteOldVersion: func(_ *string) string { return `DELETE FROM dbmigrate_versions WHERE version = ?` },
})

// though we're using plain local file system in this example
Expand All @@ -36,7 +38,7 @@ func sqlite3DbmigrateUp() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()

return m.MigrateUp(ctx, &sql.TxOptions{}, func(currentFilename string) {
return m.MigrateUp(ctx, &sql.TxOptions{}, nil, func(currentFilename string) {
fmt.Println("[migrate up]", currentFilename) // optional print out of which file was migrated
})
}
71 changes: 46 additions & 25 deletions lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ func (c *Config) CloseDB() error {
return c.db.Close()
}

func (c *Config) existingVersions(ctx context.Context) (*trie.Trie, error) {
func (c *Config) existingVersions(ctx context.Context, schema *string) (*trie.Trie, error) {
// best effort create before we select; if the table is not there, next query will fail anyway
_, _ = c.db.ExecContext(ctx, c.adapter.CreateVersionsTable)
rows, err := c.db.QueryContext(ctx, c.adapter.SelectExistingVersions)
_, errctx := c.db.ExecContext(ctx, c.adapter.CreateVersionsTable(schema))
rows, err := c.db.QueryContext(ctx, c.adapter.SelectExistingVersions(schema))
if err != nil {
return nil, err
return nil, errors.Wrap(err, errctx.Error())
}
defer rows.Close()

Expand All @@ -149,8 +149,8 @@ func (c *Config) existingVersions(ctx context.Context) (*trie.Trie, error) {
}

// PendingVersions returns a slice of version strings that are not appled in the database yet
func (c *Config) PendingVersions(ctx context.Context) ([]string, error) {
migratedVersions, err := c.existingVersions(ctx)
func (c *Config) PendingVersions(ctx context.Context, schema *string) ([]string, error) {
migratedVersions, err := c.existingVersions(ctx, schema)
if err != nil {
return nil, errors.Wrapf(err, "unable to query existing versions")
}
Expand Down Expand Up @@ -186,8 +186,8 @@ type ExecCommitRollbacker interface {
//
// Transaction is committed on success, rollback on error. Different databases will behave
// differently, e.g. postgres & sqlite3 can rollback DDL changes but mysql cannot
func (c *Config) MigrateUp(ctx context.Context, txOpts *sql.TxOptions, logFilename func(string)) error {
migratedVersions, err := c.existingVersions(ctx)
func (c *Config) MigrateUp(ctx context.Context, txOpts *sql.TxOptions, schema *string, logFilename func(string)) error {
migratedVersions, err := c.existingVersions(ctx, schema)
if err != nil {
return errors.Wrapf(err, "unable to query existing versions")
}
Expand Down Expand Up @@ -224,7 +224,7 @@ func (c *Config) MigrateUp(ctx context.Context, txOpts *sql.TxOptions, logFilena
} else if _, err := tx.ExecContext(ctx, string(filecontent)); err != nil {
return errors.Wrapf(err, currName)
}
if _, err := tx.ExecContext(ctx, c.adapter.InsertNewVersion, currVer); err != nil {
if _, err := tx.ExecContext(ctx, c.adapter.InsertNewVersion(schema), currVer); err != nil {
return errors.Wrapf(err, "fail to register version %q", currVer)
}
logFilename(currName)
Expand All @@ -240,8 +240,8 @@ func (c *Config) MigrateUp(ctx context.Context, txOpts *sql.TxOptions, logFilena
//
// Transaction is committed on success, rollback on error. Different databases will behave
// differently, e.g. postgres & sqlite3 can rollback DDL changes but mysql cannot
func (c *Config) MigrateDown(ctx context.Context, txOpts *sql.TxOptions, logFilename func(string), downStep int) error {
migratedVersions, err := c.existingVersions(ctx)
func (c *Config) MigrateDown(ctx context.Context, txOpts *sql.TxOptions, schema *string, logFilename func(string), downStep int) error {
migratedVersions, err := c.existingVersions(ctx, schema)
if err != nil {
return errors.Wrapf(err, "unable to query existing versions")
}
Expand Down Expand Up @@ -283,7 +283,7 @@ func (c *Config) MigrateDown(ctx context.Context, txOpts *sql.TxOptions, logFile
} else if _, err := tx.ExecContext(ctx, string(filecontent)); err != nil {
return errors.Wrapf(err, currName)
}
if _, err := tx.ExecContext(ctx, c.adapter.DeleteOldVersion, currVer); err != nil {
if _, err := tx.ExecContext(ctx, c.adapter.DeleteOldVersion(schema), currVer); err != nil {
return errors.Wrapf(err, "fail to unregister version %q", currVer)
}
logFilename(currName)
Expand Down Expand Up @@ -315,23 +315,39 @@ func Register(name string, value Adapter) {

// Adapter defines raw sql statements to run for an sql.DB adapter
type Adapter struct {
CreateVersionsTable string
SelectExistingVersions string
InsertNewVersion string
DeleteOldVersion string
CreateVersionsTable func(*string) string
SelectExistingVersions func(*string) string
InsertNewVersion func(*string) string
DeleteOldVersion func(*string) string
PingQuery string // `""` means does NOT support -server-ready
CreateDatabaseQuery func(string) string // nil means does NOT support -create-db
CreateSchemaQuery func(string) string // nil means does NOT support -schema
BaseDatabaseURL func(string) (connString string, dbName string, err error) // nil means does not support -server-ready nor -create-db
BeginTx func(ctx context.Context, db *sql.DB, opts *sql.TxOptions) (ExecCommitRollbacker, error)
}

func fqName(schema *string, name string) string {
if schema == nil || *schema == "" {
return name
}
return *schema + "." + name
}

var adapters = map[string]Adapter{
"postgres": {
CreateVersionsTable: `CREATE TABLE IF NOT EXISTS dbmigrate_versions (version char(14) NOT NULL PRIMARY KEY)`,
SelectExistingVersions: `SELECT version FROM dbmigrate_versions ORDER BY version ASC`,
InsertNewVersion: `INSERT INTO dbmigrate_versions (version) VALUES ($1)`,
DeleteOldVersion: `DELETE FROM dbmigrate_versions WHERE version = $1`,
PingQuery: "SELECT 1",
CreateVersionsTable: func(schema *string) string {
return `CREATE TABLE IF NOT EXISTS ` + fqName(schema, "dbmigrate_versions") + ` (version char(14) NOT NULL PRIMARY KEY)`
},
SelectExistingVersions: func(schema *string) string {
return `SELECT version FROM ` + fqName(schema, "dbmigrate_versions") + ` ORDER BY version ASC`
},
InsertNewVersion: func(schema *string) string {
return `INSERT INTO ` + fqName(schema, "dbmigrate_versions") + ` (version) VALUES ($1)`
},
DeleteOldVersion: func(schema *string) string {
return `DELETE FROM ` + fqName(schema, "dbmigrate_versions") + ` WHERE version = $1`
},
PingQuery: "SELECT 1",
BaseDatabaseURL: func(databaseURL string) (string, string, error) {
paths := strings.Split(databaseURL, "/")
pathlen := len(paths)
Expand All @@ -346,15 +362,20 @@ var adapters = map[string]Adapter{
CreateDatabaseQuery: func(dbName string) string {
return "CREATE DATABASE " + dbName
},
CreateSchemaQuery: func(schemaName string) string {
return "CREATE SCHEMA IF NOT EXISTS " + schemaName
},
BeginTx: func(ctx context.Context, db *sql.DB, opts *sql.TxOptions) (ExecCommitRollbacker, error) {
return db.BeginTx(ctx, opts)
},
},
"mysql": {
CreateVersionsTable: `CREATE TABLE dbmigrate_versions (version char(14) NOT NULL PRIMARY KEY)`,
SelectExistingVersions: `SELECT version FROM dbmigrate_versions ORDER BY version ASC`,
InsertNewVersion: `INSERT INTO dbmigrate_versions (version) VALUES (?)`,
DeleteOldVersion: `DELETE FROM dbmigrate_versions WHERE version = ?`,
CreateVersionsTable: func(_ *string) string {
return `CREATE TABLE dbmigrate_versions (version char(14) NOT NULL PRIMARY KEY)`
},
SelectExistingVersions: func(_ *string) string { return `SELECT version FROM dbmigrate_versions ORDER BY version ASC` },
InsertNewVersion: func(_ *string) string { return `INSERT INTO dbmigrate_versions (version) VALUES (?)` },
DeleteOldVersion: func(_ *string) string { return `DELETE FROM dbmigrate_versions WHERE version = ?` },
PingQuery: "SELECT 1",
BaseDatabaseURL: func(databaseURL string) (string, string, error) {
paths := strings.Split(databaseURL, "/")
Expand Down
5 changes: 1 addition & 4 deletions tests/scenario.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

# abort on any failure
set -e
set -euxo pipefail
source `dirname $0`/lib.sh

# setup
Expand All @@ -12,9 +12,6 @@ trap finish EXIT
mkdir -p ${DB_MIGRATIONS_DIR}
echo "testing ${DATABASE_DRIVER}..."

# echo commands that we run
# set -x

# `-create` should work
assert "should create new migration" ${DBMIGRATE_CMD} -dir ${DB_MIGRATIONS_DIR} -create finally 'do! nothing??' 2>/dev/null
assert "should create .up.sql" test -f ${DB_MIGRATIONS_DIR}/*_finally-do-nothing.up.sql
Expand Down

0 comments on commit 23977e1

Please sign in to comment.