Skip to content

Commit

Permalink
Merge pull request #9 from choonkeat/cassandra
Browse files Browse the repository at this point in the history
add Cassandra support
  • Loading branch information
choonkeat authored Jul 7, 2020
2 parents 67ea7a8 + c26b49d commit a4a64e6
Show file tree
Hide file tree
Showing 28 changed files with 407 additions and 70 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
cid.txt
./dbmigrate
tests/db/migrations
5 changes: 2 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
language: go
go:
- "1.11.x"
- "1.14.x"
services:
- docker
script:
- env GO111MODULE=on make test
- env GO111MODULE=on make test BUILD_TARGET="./cmd/dbmigrate" DATABASE_DRIVERS=sqlite3
- make test
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
DATABASE_DRIVERS=postgres mariadb mysql
BUILD_TARGET=./cmd/dbmigrate/main.go
DATABASE_DRIVERS=cql sqlite3 postgres mariadb mysql
BUILD_TARGET=./cmd/dbmigrate/*.go

test: build
go build -o /dev/null ./examples # verify examples can compile
Expand Down
56 changes: 56 additions & 0 deletions cmd/dbmigrate/cql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package main

// by default, Makefile `make build` compiles without this file
// if sqlite3 is required,
// env CGO_ENABLED=1 make build BUILD_TARGET="./cmd/dbmigrate"

import (
"context"
"database/sql"
"net/url"

_ "github.com/MichaelS11/go-cql-driver"
"github.com/choonkeat/dbmigrate"
"github.com/pkg/errors"
)

func init() {
dbmigrate.Register("cql", dbmigrate.Adapter{
CreateVersionsTable: `CREATE TABLE IF NOT EXISTS dbmigrate_versions (version text, PRIMARY KEY (version));`,
SelectExistingVersions: `SELECT version FROM dbmigrate_versions`,
InsertNewVersion: `INSERT INTO dbmigrate_versions (version) VALUES (?)`,
DeleteOldVersion: `DELETE FROM dbmigrate_versions WHERE version = ?`,
PingQuery: `SELECT gossip_generation FROM system.local`,
BaseDatabaseURL: func(databaseURL string) (string, string, error) {
u, err := url.Parse(databaseURL)
if err != nil {
return "", "", errors.Wrapf(err, "invalid cassandra dsn")
}
q := u.Query()
dbName := q.Get("keyspace")
q.Set("keyspace", "system") // default connection
u.RawQuery = q.Encode()
return u.String(), dbName, nil
},
BeginTx: func(ctx context.Context, db *sql.DB, opts *sql.TxOptions) (dbmigrate.ExecCommitRollbacker, error) {
return &noTx{db: db}, nil
},
})
}

// Implements dbmigrate.ExecCommitRollbacker
type noTx struct {
db *sql.DB
}

func (tx *noTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return tx.db.ExecContext(ctx, query, args...)
}

func (tx *noTx) Commit() error {
return nil
}

func (tx *noTx) Rollback() error {
return nil
}
30 changes: 24 additions & 6 deletions cmd/dbmigrate/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,28 +77,46 @@ func _main() error {
return nil
}

driverName, databaseURL, _ = dbmigrate.SanitizeDriverNameURL(driverName, databaseURL)

if doServerReadyWait := serverReadyWait > 0; doServerReadyWait || doCreateDB {
driverName, _, _ = dbmigrate.SanitizeDriverNameURL(driverName, databaseURL)
connString, dbName, err := dbmigrate.BaseDatabaseURL(driverName, databaseURL, "/"+driverName)
adapter, err := dbmigrate.AdapterFor(driverName)
if err != nil {
return errors.Wrapf(err, "database url without dbname")
return err
}

if doServerReadyWait && driverName != "sqlite3" {
if doServerReadyWait {
if adapter.BaseDatabaseURL == nil {
return errors.Errorf("%q does not support -server-ready", driverName)
}
connString, _, err := adapter.BaseDatabaseURL(databaseURL)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), serverReadyWait)
defer cancel()
if err := dbmigrate.ReadyWait(ctx, driverName, []string{databaseURL, connString}, log.Println); err != nil {
return err
}
}

if doCreateDB && driverName != "sqlite3" {
if doCreateDB {
if adapter.BaseDatabaseURL == nil {
return errors.Errorf("%q does not support -create-db", driverName)
}
if adapter.CreateDatabaseQuery == nil {
return errors.Errorf("%q does not support -create-db", driverName)
}
connString, dbName, err := adapter.BaseDatabaseURL(databaseURL)
if err != nil {
return err
}
db, err := sql.Open(driverName, connString)
if err != nil {
return errors.Wrapf(err, "connect to db")
}
// leave errors for subsequent actions
_, _ = db.Exec("CREATE DATABASE " + dbName)
_, _ = db.Exec(adapter.CreateDatabaseQuery(dbName))
_ = db.Close()
}
}
Expand Down
7 changes: 7 additions & 0 deletions cmd/dbmigrate/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ package main
// env CGO_ENABLED=1 make build BUILD_TARGET="./cmd/dbmigrate"

import (
"context"
"database/sql"

"github.com/choonkeat/dbmigrate"
_ "github.com/mattn/go-sqlite3"
)
Expand All @@ -15,5 +18,9 @@ func init() {
SelectExistingVersions: `SELECT version FROM dbmigrate_versions ORDER BY version ASC`,
InsertNewVersion: `INSERT INTO dbmigrate_versions (version) VALUES (?)`,
DeleteOldVersion: `DELETE FROM dbmigrate_versions WHERE version = ?`,
PingQuery: "SELECT 1",
BeginTx: func(ctx context.Context, db *sql.DB, opts *sql.TxOptions) (dbmigrate.ExecCommitRollbacker, error) {
return db.BeginTx(ctx, opts)
},
})
}
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
module github.com/choonkeat/dbmigrate

go 1.14

require (
github.com/MichaelS11/go-cql-driver v0.0.0-20190914174813-cf3b3196aa43
github.com/derekparker/trie v0.0.0-20180212171413-e608c2733dc7
github.com/go-sql-driver/mysql v1.4.1
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e // indirect
github.com/lib/pq v1.0.0
github.com/mattn/go-sqlite3 v1.10.0
github.com/pkg/errors v0.8.0
Expand Down
19 changes: 19 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
github.com/MichaelS11/go-cql-driver v0.0.0-20190914174813-cf3b3196aa43 h1:G4RmbeBfV1OXJYhmqcU7onWWwIEiLMr5RvsAe/1yIkA=
github.com/MichaelS11/go-cql-driver v0.0.0-20190914174813-cf3b3196aa43/go.mod h1:nW8K1gl1mu8o29Ns1Sv/EvYe9BBrh1T/GqucnYcO9PI=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/derekparker/trie v0.0.0-20180212171413-e608c2733dc7 h1:Cab9yoTQh1TxObKfis1DzZ6vFLK5kbeenMjRES/UE3o=
github.com/derekparker/trie v0.0.0-20180212171413-e608c2733dc7/go.mod h1:D6ICZm05D9VN1n/8iOtBxLpXtoGp6HDFUJ1RNVieOSE=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e h1:SroDcndcOU9BVAduPf/PXihXoR2ZYTQYLXbupbqxAyQ=
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk=
github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o=
Expand All @@ -20,3 +37,5 @@ golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
112 changes: 80 additions & 32 deletions lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"database/sql"
"io/ioutil"
"net/http"
"net/url"
"os"
"sort"
"strings"
Expand All @@ -15,42 +16,34 @@ import (
"github.com/pkg/errors"
)

// RequireDriverName to indicate explicit driver name
var RequireDriverName = errors.Errorf("Cannot discern db driver. Please set -driver flag or DATABASE_DRIVER environment variable.")

// SanitizeDriverNameURL sanitizes `driverName` and `databaseURL` values
func SanitizeDriverNameURL(driverName string, databaseURL string) (string, string, error) {
func SanitizeDriverNameURL(driverName string, databaseURL string) (dbdriver string, dburl string, err error) {
// ensure db and driverName is legit
databaseURL = strings.TrimSpace(databaseURL)
if databaseURL == "" {
return driverName, databaseURL, errors.Errorf("database url not set")
}
if driverName == "" {
// fall back to use the `scheme` part of the url as driverName
// e.g. `postgres://localhost:5432/dbmigrate_test` will thus be `postgres`
driverName = strings.Split(databaseURL, ":")[0]
driverName = strings.TrimSpace(driverName)
if driverName != "" {
return driverName, databaseURL, nil
}
return driverName, databaseURL, nil
}

// BaseDatabaseURL returns the connection string to connect to the server (without the database name)
func BaseDatabaseURL(driverName string, databaseURL string, defaultDbName string) (string, string, error) {
driverName, databaseURL, err := SanitizeDriverNameURL(driverName, databaseURL)
if err != nil {
return "", "", err
if u, err := url.Parse(databaseURL); strings.Contains(databaseURL, "://") && u != nil && err == nil {
return u.Scheme, databaseURL, nil
}

paths := strings.Split(databaseURL, "/")
pathlen := len(paths)
requestURI := strings.Split(paths[pathlen-1], "?")
basePaths := []string{strings.Join(paths[:pathlen-1], "/") + defaultDbName}

if len(requestURI) > 1 {
basePaths = append(basePaths, requestURI[1:]...)
}
return strings.Join(basePaths, "?"), requestURI[0], nil
return "", databaseURL, RequireDriverName
}

// ReadyWait for server to be ready, and try to create db and connect again
func ReadyWait(ctx context.Context, driverName string, databaseURLs []string, logger func(...interface{})) error {
logger(driverName, "checking connection")
adapter, err := AdapterFor(driverName)
if err != nil {
return err
}

count := len(databaseURLs)
curr := -1
for {
Expand All @@ -59,7 +52,7 @@ func ReadyWait(ctx context.Context, driverName string, databaseURLs []string, lo
if err == nil {
logger(driverName, "server up")
var num int
if err = db.QueryRow("SELECT 1").Scan(&num); err == nil {
if err = db.QueryRow(adapter.PingQuery).Scan(&num); err == nil {
logger(driverName, "connected")
return db.Close()
}
Expand Down Expand Up @@ -93,10 +86,9 @@ func New(dir http.FileSystem, driverName string, databaseURL string) (*Config, e
if err != nil {
return nil, errors.Wrapf(err, "see `--help` for more details.")
}
var ok bool
adapter, ok := adapters[driverName]
if !ok {
return nil, errors.Errorf("unsupported driver name %q", driverName)
adapter, err := AdapterFor(driverName)
if err != nil {
return nil, err
}
db, err := sql.Open(driverName, databaseURL)
if err != nil {
Expand Down Expand Up @@ -175,6 +167,13 @@ func (c *Config) PendingVersions(ctx context.Context) ([]string, error) {
return result, nil
}

// ExecCommitRollbacker interface for sql.Tx
type ExecCommitRollbacker interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Commit() error
Rollback() error
}

// MigrateUp applies pending migrations in ascending order, in a transaction
//
// Transaction is committed on success, rollback on error. Different databases will behave
Expand All @@ -185,7 +184,7 @@ func (c *Config) MigrateUp(ctx context.Context, txOpts *sql.TxOptions, logFilena
return errors.Wrapf(err, "unable to query existing versions")
}

tx, err := c.db.BeginTx(ctx, txOpts)
tx, err := c.adapter.BeginTx(ctx, c.db, txOpts)
if err != nil {
return errors.Wrapf(err, "unable to create transaction")
}
Expand Down Expand Up @@ -236,7 +235,7 @@ func (c *Config) MigrateDown(ctx context.Context, txOpts *sql.TxOptions, logFile
return errors.Wrapf(err, "unable to query existing versions")
}

tx, err := c.db.BeginTx(ctx, txOpts)
tx, err := c.adapter.BeginTx(ctx, c.db, txOpts)
if err != nil {
return errors.Wrapf(err, "unable to create transaction")
}
Expand Down Expand Up @@ -306,19 +305,68 @@ type Adapter struct {
SelectExistingVersions string
InsertNewVersion string
DeleteOldVersion string
PingQuery string // `""` means does NOT support -server-ready
CreateDatabaseQuery func(string) string // nil means does NOT support -create-db
BaseDatabaseURL func(string) (connString string, dbName string, err error) // nil means does not support -server-ready nor -create-db
BeginTx func(ctx context.Context, db *sql.DB, opts *sql.TxOptions) (ExecCommitRollbacker, error)
}

var adapters = map[string]Adapter{
"postgres": Adapter{
"postgres": {
CreateVersionsTable: `CREATE TABLE dbmigrate_versions (version char(14) NOT NULL PRIMARY KEY)`,
SelectExistingVersions: `SELECT version FROM dbmigrate_versions ORDER BY version ASC`,
InsertNewVersion: `INSERT INTO dbmigrate_versions (version) VALUES ($1)`,
DeleteOldVersion: `DELETE FROM dbmigrate_versions WHERE version = $1`,
PingQuery: "SELECT 1",
BaseDatabaseURL: func(databaseURL string) (string, string, error) {
paths := strings.Split(databaseURL, "/")
pathlen := len(paths)
requestURI := strings.Split(paths[pathlen-1], "?")
basePaths := []string{strings.Join(paths[:pathlen-1], "/") + "/postgres"}

if len(requestURI) > 1 {
basePaths = append(basePaths, requestURI[1:]...)
}
return strings.Join(basePaths, "?"), requestURI[0], nil
},
CreateDatabaseQuery: func(dbName string) string {
return "CREATE DATABASE " + dbName
},
BeginTx: func(ctx context.Context, db *sql.DB, opts *sql.TxOptions) (ExecCommitRollbacker, error) {
return db.BeginTx(ctx, opts)
},
},
"mysql": Adapter{
"mysql": {
CreateVersionsTable: `CREATE TABLE dbmigrate_versions (version char(14) NOT NULL PRIMARY KEY)`,
SelectExistingVersions: `SELECT version FROM dbmigrate_versions ORDER BY version ASC`,
InsertNewVersion: `INSERT INTO dbmigrate_versions (version) VALUES (?)`,
DeleteOldVersion: `DELETE FROM dbmigrate_versions WHERE version = ?`,
PingQuery: "SELECT 1",
BaseDatabaseURL: func(databaseURL string) (string, string, error) {
paths := strings.Split(databaseURL, "/")
pathlen := len(paths)
requestURI := strings.Split(paths[pathlen-1], "?")
basePaths := []string{strings.Join(paths[:pathlen-1], "/") + "/mysql"}

if len(requestURI) > 1 {
basePaths = append(basePaths, requestURI[1:]...)
}
return strings.Join(basePaths, "?"), requestURI[0], nil
},
CreateDatabaseQuery: func(dbName string) string {
return "CREATE DATABASE " + dbName
},
BeginTx: func(ctx context.Context, db *sql.DB, opts *sql.TxOptions) (ExecCommitRollbacker, error) {
return db.BeginTx(ctx, opts)
},
},
}

// AdapterFor returns Adapter for given driverName
func AdapterFor(driverName string) (Adapter, error) {
a, ok := adapters[driverName]
if !ok {
return a, errors.Errorf("unsupported driver name %q", driverName)
}
return a, nil
}
Loading

0 comments on commit a4a64e6

Please sign in to comment.