From ec518610aecc2ae67a0f5b3fd11d1914bd1e23a3 Mon Sep 17 00:00:00 2001 From: Venkat Venkatasubramanian Date: Wed, 28 Aug 2024 18:15:33 -0700 Subject: [PATCH 1/4] First pass at downmigrate --- cmd/migrate/config.go | 56 +- internal/cli/commands.go | 358 ++++----- internal/cli/main.go | 728 +++++++++--------- migrate.go | 1543 ++++++++++++++++++++------------------ migrate_goto_temp.go | 163 ++++ 5 files changed, 1536 insertions(+), 1312 deletions(-) create mode 100644 migrate_goto_temp.go diff --git a/cmd/migrate/config.go b/cmd/migrate/config.go index e5a946922..ff4c3c46d 100644 --- a/cmd/migrate/config.go +++ b/cmd/migrate/config.go @@ -3,36 +3,40 @@ package main import "github.com/spf13/pflag" const ( - // configuration defaults support local development (i.e. "go run ...") - defaultDatabaseDSN = "" - defaultDatabaseDriver = "postgres" - defaultDatabaseAddress = "0.0.0.0:5432" - defaultDatabaseName = "" - defaultDatabaseUser = "postgres" - defaultDatabasePassword = "postgres" - defaultDatabaseSSL = "disable" - defaultConfigDirectory = "/cli/config" + // configuration defaults support local development (i.e. "go run ...") + defaultDatabaseDSN = "" + defaultDatabaseDriver = "postgres" + defaultDatabaseAddress = "0.0.0.0:5432" + defaultDatabaseName = "" + defaultDatabaseUser = "postgres" + defaultDatabasePassword = "postgres" + defaultDatabaseSSL = "disable" + defaultConfigDirectory = "/cli/config" ) var ( - // define flag overrides - flagHelp = pflag.Bool("help", false, "Print usage") - flagVersion = pflag.String("version", Version, "Print version") - flagLoggingVerbose = pflag.Bool("verbose", true, "Print verbose logging") - flagPrefetch = pflag.Uint("prefetch", 10, "Number of migrations to load in advance before executing") - flaglockTimeout = pflag.Uint("lock-timeout", 15, "Allow N seconds to acquire database lock") + // define flag overrides + flagHelp = pflag.Bool("help", false, "Print usage") + flagVersion = pflag.String("version", Version, "Print version") + flagLoggingVerbose = pflag.Bool("verbose", true, "Print verbose logging") + flagPrefetch = pflag.Uint("prefetch", 10, "Number of migrations to load in advance before executing") + flaglockTimeout = pflag.Uint("lock-timeout", 15, "Allow N seconds to acquire database lock") - flagDatabaseDSN = pflag.String("database.dsn", defaultDatabaseDSN, "database connection string") - flagDatabaseDriver = pflag.String("database.driver", defaultDatabaseDriver, "database driver") - flagDatabaseAddress = pflag.String("database.address", defaultDatabaseAddress, "address of the database") - flagDatabaseName = pflag.String("database.name", defaultDatabaseName, "name of the database") - flagDatabaseUser = pflag.String("database.user", defaultDatabaseUser, "database username") - flagDatabasePassword = pflag.String("database.password", defaultDatabasePassword, "database password") - flagDatabaseSSL = pflag.String("database.ssl", defaultDatabaseSSL, "database ssl mode") + flagDatabaseDSN = pflag.String("database.dsn", defaultDatabaseDSN, "database connection string") + flagDatabaseDriver = pflag.String("database.driver", defaultDatabaseDriver, "database driver") + flagDatabaseAddress = pflag.String("database.address", defaultDatabaseAddress, "address of the database") + flagDatabaseName = pflag.String("database.name", defaultDatabaseName, "name of the database") + flagDatabaseUser = pflag.String("database.user", defaultDatabaseUser, "database username") + flagDatabasePassword = pflag.String("database.password", defaultDatabasePassword, "database password") + flagDatabaseSSL = pflag.String("database.ssl", defaultDatabaseSSL, "database ssl mode") - flagSource = pflag.String("source", "", "Location of the migrations (driver://url)") - flagPath = pflag.String("path", "", "Shorthand for -source=file://path") + flagSource = pflag.String("source", "", "Location of the migrations (driver://url)") + flagPath = pflag.String("path", "", "Shorthand for -source=file://path") - flagConfigDirectory = pflag.String("config.source", defaultConfigDirectory, "directory of the configuration file") - flagConfigFile = pflag.String("config.file", "", "configuration file name without extension") + flagConfigDirectory = pflag.String("config.source", defaultConfigDirectory, "directory of the configuration file") + flagConfigFile = pflag.String("config.file", "", "configuration file name without extension") + + // goto command flags + flagDirty = pflag.Bool("dirty", false, "migration is dirty") + flagPVCPath = pflag.String("intermediate-path", "", "path to the mounted volume which is used to copy the migration files") ) diff --git a/internal/cli/commands.go b/internal/cli/commands.go index 7adec2f84..868938f8e 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -1,248 +1,248 @@ package cli import ( - "errors" - "fmt" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/golang-migrate/migrate/v4" - _ "github.com/golang-migrate/migrate/v4/database/stub" // TODO remove again - _ "github.com/golang-migrate/migrate/v4/source/file" + "errors" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/golang-migrate/migrate/v4" + _ "github.com/golang-migrate/migrate/v4/database/stub" // TODO remove again + _ "github.com/golang-migrate/migrate/v4/source/file" ) var ( - errInvalidSequenceWidth = errors.New("Digits must be positive") - errIncompatibleSeqAndFormat = errors.New("The seq and format options are mutually exclusive") - errInvalidTimeFormat = errors.New("Time format may not be empty") + errInvalidSequenceWidth = errors.New("Digits must be positive") + errIncompatibleSeqAndFormat = errors.New("The seq and format options are mutually exclusive") + errInvalidTimeFormat = errors.New("Time format may not be empty") ) func nextSeqVersion(matches []string, seqDigits int) (string, error) { - if seqDigits <= 0 { - return "", errInvalidSequenceWidth - } + if seqDigits <= 0 { + return "", errInvalidSequenceWidth + } - nextSeq := uint64(1) + nextSeq := uint64(1) - if len(matches) > 0 { - filename := matches[len(matches)-1] - matchSeqStr := filepath.Base(filename) - idx := strings.Index(matchSeqStr, "_") + if len(matches) > 0 { + filename := matches[len(matches)-1] + matchSeqStr := filepath.Base(filename) + idx := strings.Index(matchSeqStr, "_") - if idx < 1 { // Using 1 instead of 0 since there should be at least 1 digit - return "", fmt.Errorf("Malformed migration filename: %s", filename) - } + if idx < 1 { // Using 1 instead of 0 since there should be at least 1 digit + return "", fmt.Errorf("Malformed migration filename: %s", filename) + } - var err error - matchSeqStr = matchSeqStr[0:idx] - nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64) + var err error + matchSeqStr = matchSeqStr[0:idx] + nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64) - if err != nil { - return "", err - } + if err != nil { + return "", err + } - nextSeq++ - } + nextSeq++ + } - version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits) + version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits) - if len(version) > seqDigits { - return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits) - } + if len(version) > seqDigits { + return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits) + } - return version, nil + return version, nil } func timeVersion(startTime time.Time, format string) (version string, err error) { - switch format { - case "": - err = errInvalidTimeFormat - case "unix": - version = strconv.FormatInt(startTime.Unix(), 10) - case "unixNano": - version = strconv.FormatInt(startTime.UnixNano(), 10) - default: - version = startTime.Format(format) - } - - return + switch format { + case "": + err = errInvalidTimeFormat + case "unix": + version = strconv.FormatInt(startTime.Unix(), 10) + case "unixNano": + version = strconv.FormatInt(startTime.UnixNano(), 10) + default: + version = startTime.Format(format) + } + + return } // createCmd (meant to be called via a CLI command) creates a new migration func createCmd(dir string, startTime time.Time, format string, name string, ext string, seq bool, seqDigits int, print bool) error { - if seq && format != defaultTimeFormat { - return errIncompatibleSeqAndFormat - } + if seq && format != defaultTimeFormat { + return errIncompatibleSeqAndFormat + } - var version string - var err error + var version string + var err error - dir = filepath.Clean(dir) - ext = "." + strings.TrimPrefix(ext, ".") + dir = filepath.Clean(dir) + ext = "." + strings.TrimPrefix(ext, ".") - if seq { - matches, err := filepath.Glob(filepath.Join(dir, "*"+ext)) + if seq { + matches, err := filepath.Glob(filepath.Join(dir, "*"+ext)) - if err != nil { - return err - } + if err != nil { + return err + } - version, err = nextSeqVersion(matches, seqDigits) + version, err = nextSeqVersion(matches, seqDigits) - if err != nil { - return err - } - } else { - version, err = timeVersion(startTime, format) + if err != nil { + return err + } + } else { + version, err = timeVersion(startTime, format) - if err != nil { - return err - } - } + if err != nil { + return err + } + } - versionGlob := filepath.Join(dir, version+"_*"+ext) - matches, err := filepath.Glob(versionGlob) + versionGlob := filepath.Join(dir, version+"_*"+ext) + matches, err := filepath.Glob(versionGlob) - if err != nil { - return err - } + if err != nil { + return err + } - if len(matches) > 0 { - return fmt.Errorf("duplicate migration version: %s", version) - } + if len(matches) > 0 { + return fmt.Errorf("duplicate migration version: %s", version) + } - if err = os.MkdirAll(dir, os.ModePerm); err != nil { - return err - } + if err = os.MkdirAll(dir, os.ModePerm); err != nil { + return err + } - for _, direction := range []string{"up", "down"} { - basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext) - filename := filepath.Join(dir, basename) + for _, direction := range []string{"up", "down"} { + basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext) + filename := filepath.Join(dir, basename) - if err = createFile(filename); err != nil { - return err - } + if err = createFile(filename); err != nil { + return err + } - if print { - absPath, _ := filepath.Abs(filename) - log.Println(absPath) - } - } + if print { + absPath, _ := filepath.Abs(filename) + log.Println(absPath) + } + } - return nil + return nil } func createFile(filename string) error { - // create exclusive (fails if file already exists) - // os.Create() specifies 0666 as the FileMode, so we're doing the same - f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666) + // create exclusive (fails if file already exists) + // os.Create() specifies 0666 as the FileMode, so we're doing the same + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666) - if err != nil { - return err - } + if err != nil { + return err + } - return f.Close() + return f.Close() } func gotoCmd(m *migrate.Migrate, v uint) error { - if err := m.Migrate(v); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - return nil + if err := m.Migrate(v); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + return nil } func upCmd(m *migrate.Migrate, limit int) error { - if limit >= 0 { - if err := m.Steps(limit); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } else { - if err := m.Up(); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } - return nil + if limit >= 0 { + if err := m.Steps(limit); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } else { + if err := m.Up(); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } + return nil } func downCmd(m *migrate.Migrate, limit int) error { - if limit >= 0 { - if err := m.Steps(-limit); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } else { - if err := m.Down(); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } - return nil + if limit >= 0 { + if err := m.Steps(-limit); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } else { + if err := m.Down(); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } + return nil } func dropCmd(m *migrate.Migrate) error { - if err := m.Drop(); err != nil { - return err - } - return nil + if err := m.Drop(); err != nil { + return err + } + return nil } func forceCmd(m *migrate.Migrate, v int) error { - if err := m.Force(v); err != nil { - return err - } - return nil + if err := m.Force(v); err != nil { + return err + } + return nil } func versionCmd(m *migrate.Migrate) error { - v, dirty, err := m.Version() - if err != nil { - return err - } - if dirty { - log.Printf("%v (dirty)\n", v) - } else { - log.Println(v) - } - return nil + v, dirty, err := m.Version() + if err != nil { + return err + } + if dirty { + log.Printf("%v (dirty)\n", v) + } else { + log.Println(v) + } + return nil } // numDownMigrationsFromArgs returns an int for number of migrations to apply // and a bool indicating if we need a confirm before applying func numDownMigrationsFromArgs(applyAll bool, args []string) (int, bool, error) { - if applyAll { - if len(args) > 0 { - return 0, false, errors.New("-all cannot be used with other arguments") - } - return -1, false, nil - } - - switch len(args) { - case 0: - return -1, true, nil - case 1: - downValue := args[0] - n, err := strconv.ParseUint(downValue, 10, 64) - if err != nil { - return 0, false, errors.New("can't read limit argument N") - } - return int(n), false, nil - default: - return 0, false, errors.New("too many arguments") - } + if applyAll { + if len(args) > 0 { + return 0, false, errors.New("-all cannot be used with other arguments") + } + return -1, false, nil + } + + switch len(args) { + case 0: + return -1, true, nil + case 1: + downValue := args[0] + n, err := strconv.ParseUint(downValue, 10, 64) + if err != nil { + return 0, false, errors.New("can't read limit argument N") + } + return int(n), false, nil + default: + return 0, false, errors.New("too many arguments") + } } diff --git a/internal/cli/main.go b/internal/cli/main.go index ece7eff0b..a0084dbd0 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -1,96 +1,103 @@ package cli import ( - "database/sql" - "fmt" - "net/url" - "os" - "os/signal" - "strconv" - "strings" - "syscall" - "time" - - flag "github.com/spf13/pflag" - "github.com/spf13/viper" - - "github.com/golang-migrate/migrate/v4" - "github.com/golang-migrate/migrate/v4/database" - "github.com/golang-migrate/migrate/v4/database/postgres" - "github.com/golang-migrate/migrate/v4/source" + "database/sql" + "fmt" + "net/url" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + "time" + + flag "github.com/spf13/pflag" + "github.com/spf13/viper" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database" + "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source" ) const ( - defaultTimeFormat = "20060102150405" - defaultTimezone = "UTC" - createUsage = `create [-ext E] [-dir D] [-seq] [-digits N] [-format] [-tz] NAME + defaultTimeFormat = "20060102150405" + defaultTimezone = "UTC" + createUsage = `create [-ext E] [-dir D] [-seq] [-digits N] [-format] [-tz] NAME Create a set of timestamped up/down migrations titled NAME, in directory D with extension E. Use -seq option to generate sequential up/down migrations with N digits. Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error. Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC). ` - gotoUsage = `goto V Migrate to version V` - upUsage = `up [N] Apply all or N up migrations` - downUsage = `down [N] [-all] Apply all or N down migrations + gotoUsage = `goto V [-dirty] Migrate to version V` + upUsage = `up [N] Apply all or N up migrations` + downUsage = `down [N] [-all] Apply all or N down migrations Use -all to apply all down migrations` - dropUsage = `drop [-f] Drop everything inside database + dropUsage = `drop [-f] Drop everything inside database Use -f to bypass confirmation` - forceUsage = `force V Set version V but don't run migration (ignores dirty state)` + forceUsage = `force V Set version V but don't run migration (ignores dirty state)` ) func handleSubCmdHelp(help bool, usage string, flagSet *flag.FlagSet) { - if help { - fmt.Fprintln(os.Stderr, usage) - flagSet.PrintDefaults() - os.Exit(0) - } + if help { + fmt.Fprintln(os.Stderr, usage) + flagSet.PrintDefaults() + os.Exit(0) + } } func newFlagSetWithHelp(name string) (*flag.FlagSet, *bool) { - flagSet := flag.NewFlagSet(name, flag.ExitOnError) - helpPtr := flagSet.Bool("help", false, "Print help information") - return flagSet, helpPtr + flagSet := flag.NewFlagSet(name, flag.ExitOnError) + helpPtr := flagSet.Bool("help", false, "Print help information") + return flagSet, helpPtr +} + +func newGoToFlagSetWithHelp(name string) (*flag.FlagSet, *bool) { + flagSet := flag.NewFlagSet(name, flag.ExitOnError) + flagSet.Bool("dirty", false, "Migration in dirty state") + helpPtr := flagSet.Bool("help", false, "Print help information") + return flagSet, helpPtr } // set main log var log = &Log{} func printUsageAndExit() { - flag.Usage() + flag.Usage() - // If a command is not found we exit with a status 2 to match the behavior - // of flag.Parse() with flag.ExitOnError when parsing an invalid flag. - os.Exit(2) + // If a command is not found we exit with a status 2 to match the behavior + // of flag.Parse() with flag.ExitOnError when parsing an invalid flag. + os.Exit(2) } func dbMakeConnectionString(driver, user, password, address, name, ssl string) string { - return fmt.Sprintf("%s://%s:%s@%s/%s?sslmode=%s", - driver, url.QueryEscape(user), url.QueryEscape(password), address, name, ssl, - ) + return fmt.Sprintf("%s://%s:%s@%s/%s?sslmode=%s", + driver, url.QueryEscape(user), url.QueryEscape(password), address, name, ssl, + ) } // Main function of a cli application. It is public for backwards compatibility with `cli` package func Main(version string) { - help := viper.GetBool("help") - version = viper.GetString("version") - verbose := viper.GetBool("verbose") - prefetch := viper.GetInt("prefetch") - lockTimeout := viper.GetInt("lock-timeout") - path := viper.GetString("path") - sourcePtr := viper.GetString("source") - - databasePtr := viper.GetString("database.dsn") - if databasePtr == "" { - databasePtr = dbMakeConnectionString( - viper.GetString("database.driver"), viper.GetString("database.user"), - viper.GetString("database.password"), viper.GetString("database.address"), - viper.GetString("database.name"), viper.GetString("database.ssl"), - ) - } - - flag.Usage = func() { - fmt.Fprintf(os.Stderr, - `Usage: migrate OPTIONS COMMAND [arg...] + help := viper.GetBool("help") + version = viper.GetString("version") + verbose := viper.GetBool("verbose") + prefetch := viper.GetInt("prefetch") + lockTimeout := viper.GetInt("lock-timeout") + path := viper.GetString("path") + sourcePtr := viper.GetString("source") + + databasePtr := viper.GetString("database.dsn") + if databasePtr == "" { + databasePtr = dbMakeConnectionString( + viper.GetString("database.driver"), viper.GetString("database.user"), + viper.GetString("database.password"), viper.GetString("database.address"), + viper.GetString("database.name"), viper.GetString("database.ssl"), + ) + } + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, + `Usage: migrate OPTIONS COMMAND [arg...] migrate [ -version | -help ] Options: @@ -125,301 +132,308 @@ Commands: Source drivers: `+strings.Join(source.List(), ", ")+` Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoUsage, upUsage, downUsage, dropUsage, forceUsage) - } - - // initialize logger - log.verbose = verbose - - // show cli version - if version == "" { - fmt.Fprintln(os.Stderr, version) - os.Exit(0) - } - - // show help - if help { - flag.Usage() - os.Exit(0) - } - - // translate -path into -source if given - if sourcePtr == "" && path != "" { - sourcePtr = fmt.Sprintf("file://%v", path) - } - - // initialize migrate - // don't catch migraterErr here and let each command decide - // how it wants to handle the error - var migrater *migrate.Migrate - var migraterErr error - - if driver := viper.GetString("database.driver"); driver == "hotload" { - db, err := sql.Open(driver, databasePtr) - if err != nil { - log.fatalErr(fmt.Errorf("could not open hotload dsn %s: %s", databasePtr, err)) - } - var dbname, user string - if err := db.QueryRow("SELECT current_database(), user").Scan(&dbname, &user); err != nil { - log.fatalErr(fmt.Errorf("could not get current_database: %s", err.Error())) - } - // dbname is not needed since it gets filled in by the driver but we want to be complete - migrateDriver, err := postgres.WithInstance(db, &postgres.Config{DatabaseName: dbname}) - if err != nil { - log.fatalErr(fmt.Errorf("could not create migrate driver: %s", err)) - } - migrater, migraterErr = migrate.NewWithDatabaseInstance(sourcePtr, dbname, migrateDriver) - } else { - migrater, migraterErr = migrate.New(sourcePtr, databasePtr) - } - defer func() { - if migraterErr == nil { - if _, err := migrater.Close(); err != nil { - log.Println(err) - } - } - }() - if migraterErr == nil { - migrater.Log = log - migrater.PrefetchMigrations = uint(prefetch) - migrater.LockTimeout = time.Duration(int64(lockTimeout)) * time.Second - - // handle Ctrl+c - signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGINT) - go func() { - for range signals { - log.Println("Stopping after this running migration ...") - migrater.GracefulStop <- true - return - } - }() - } - - startTime := time.Now() - - if len(flag.Args()) < 1 { - printUsageAndExit() - } - args := flag.Args()[1:] - - switch flag.Arg(0) { - case "create": - - seq := false - seqDigits := 6 - - createFlagSet, help := newFlagSetWithHelp("create") - extPtr := createFlagSet.String("ext", "", "File extension") - dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)") - formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`) - timezoneName := createFlagSet.String("tz", defaultTimezone, `The timezone that will be used for generating timestamps (default: utc)`) - createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)") - createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)") - - if err := createFlagSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*help, createUsage, createFlagSet) - - if createFlagSet.NArg() == 0 { - log.fatal("error: please specify name") - } - name := createFlagSet.Arg(0) - - if *extPtr == "" { - log.fatal("error: -ext flag must be specified") - } - - timezone, err := time.LoadLocation(*timezoneName) - if err != nil { - log.fatal(err) - } - - if err := createCmd(*dirPtr, startTime.In(timezone), *formatPtr, name, *extPtr, seq, seqDigits, true); err != nil { - log.fatalErr(err) - } - - case "goto": - - gotoSet, helpPtr := newFlagSetWithHelp("goto") - - if err := gotoSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*helpPtr, gotoUsage, gotoSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if gotoSet.NArg() == 0 { - log.fatal("error: please specify version argument V") - } - - v, err := strconv.ParseUint(gotoSet.Arg(0), 10, 64) - if err != nil { - log.fatal("error: can't read version argument V") - } - - if err := gotoCmd(migrater, uint(v)); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "up": - upSet, helpPtr := newFlagSetWithHelp("up") - - if err := upSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*helpPtr, upUsage, upSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - limit := -1 - if upSet.NArg() > 0 { - n, err := strconv.ParseUint(upSet.Arg(0), 10, 64) - if err != nil { - log.fatal("error: can't read limit argument N") - } - limit = int(n) - } - - if err := upCmd(migrater, limit); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "down": - downFlagSet, helpPtr := newFlagSetWithHelp("down") - applyAll := downFlagSet.Bool("all", false, "Apply all down migrations") - - if err := downFlagSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*helpPtr, downUsage, downFlagSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - downArgs := downFlagSet.Args() - num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs) - if err != nil { - log.fatalErr(err) - } - if needsConfirm { - log.Println("Are you sure you want to apply all down migrations? [y/N]") - var response string - _, _ = fmt.Scanln(&response) - response = strings.ToLower(strings.TrimSpace(response)) - - if response == "y" { - log.Println("Applying all down migrations") - } else { - log.fatal("Not applying all down migrations") - } - } - - if err := downCmd(migrater, num); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "drop": - dropFlagSet, help := newFlagSetWithHelp("drop") - forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt") - - if err := dropFlagSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*help, dropUsage, dropFlagSet) - - if !*forceDrop { - log.Println("Are you sure you want to drop the entire database schema? [y/N]") - var response string - _, _ = fmt.Scanln(&response) - response = strings.ToLower(strings.TrimSpace(response)) - - if response == "y" { - log.Println("Dropping the entire database schema") - } else { - log.fatal("Aborted dropping the entire database schema") - } - } - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if err := dropCmd(migrater); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "force": - forceSet, helpPtr := newFlagSetWithHelp("force") - - if err := forceSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*helpPtr, forceUsage, forceSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if forceSet.NArg() == 0 { - log.fatal("error: please specify version argument V") - } - - v, err := strconv.ParseInt(forceSet.Arg(0), 10, 64) - if err != nil { - log.fatal("error: can't read version argument V") - } - - if v < -1 { - log.fatal("error: argument V must be >= -1") - } - - if err := forceCmd(migrater, int(v)); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "version": - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if err := versionCmd(migrater); err != nil { - log.fatalErr(err) - } - - default: - printUsageAndExit() - } + } + + // initialize logger + log.verbose = verbose + + // show cli version + if version == "" { + fmt.Fprintln(os.Stderr, version) + os.Exit(0) + } + + // show help + if help { + flag.Usage() + os.Exit(0) + } + + // translate -path into -source if given + if sourcePtr == "" && path != "" { + sourcePtr = fmt.Sprintf("file://%v", path) + } + + // initialize migrate + // don't catch migraterErr here and let each command decide + // how it wants to handle the error + var migrater *migrate.Migrate + var migraterErr error + + if driver := viper.GetString("database.driver"); driver == "hotload" { + db, err := sql.Open(driver, databasePtr) + if err != nil { + log.fatalErr(fmt.Errorf("could not open hotload dsn %s: %s", databasePtr, err)) + } + var dbname, user string + if err := db.QueryRow("SELECT current_database(), user").Scan(&dbname, &user); err != nil { + log.fatalErr(fmt.Errorf("could not get current_database: %s", err.Error())) + } + // dbname is not needed since it gets filled in by the driver but we want to be complete + migrateDriver, err := postgres.WithInstance(db, &postgres.Config{DatabaseName: dbname}) + if err != nil { + log.fatalErr(fmt.Errorf("could not create migrate driver: %s", err)) + } + migrater, migraterErr = migrate.NewWithDatabaseInstance(sourcePtr, dbname, migrateDriver) + } else { + migrater, migraterErr = migrate.New(sourcePtr, databasePtr) + } + defer func() { + if migraterErr == nil { + if _, err := migrater.Close(); err != nil { + log.Println(err) + } + } + }() + if migraterErr == nil { + migrater.Log = log + migrater.PrefetchMigrations = uint(prefetch) + migrater.LockTimeout = time.Duration(int64(lockTimeout)) * time.Second + + // handle Ctrl+c + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGINT) + go func() { + for range signals { + log.Println("Stopping after this running migration ...") + migrater.GracefulStop <- true + return + } + }() + } + + startTime := time.Now() + + if len(flag.Args()) < 1 { + printUsageAndExit() + } + args := flag.Args()[1:] + + switch flag.Arg(0) { + case "create": + + seq := false + seqDigits := 6 + + createFlagSet, help := newFlagSetWithHelp("create") + extPtr := createFlagSet.String("ext", "", "File extension") + dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)") + formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`) + timezoneName := createFlagSet.String("tz", defaultTimezone, `The timezone that will be used for generating timestamps (default: utc)`) + createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)") + createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)") + + if err := createFlagSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*help, createUsage, createFlagSet) + + if createFlagSet.NArg() == 0 { + log.fatal("error: please specify name") + } + name := createFlagSet.Arg(0) + + if *extPtr == "" { + log.fatal("error: -ext flag must be specified") + } + + timezone, err := time.LoadLocation(*timezoneName) + if err != nil { + log.fatal(err) + } + + if err := createCmd(*dirPtr, startTime.In(timezone), *formatPtr, name, *extPtr, seq, seqDigits, true); err != nil { + log.fatalErr(err) + } + + case "goto": + + gotoSet, helpPtr := newFlagSetWithHelp("goto") + + if err := gotoSet.Parse(args); err != nil { + log.fatalErr(err) + } + handleSubCmdHelp(*helpPtr, gotoUsage, gotoSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if gotoSet.NArg() == 0 { + log.fatal("error: please specify version argument V") + } + + v, err := strconv.ParseUint(gotoSet.Arg(0), 10, 64) + if err != nil { + log.fatal("error: can't read version argument V") + } + handleDirty := viper.GetBool("dirty") + destPath := viper.GetString("intermediate-path") + + if handleDirty && destPath == "" { + log.fatal("error: intermediate-path must be specified when dirty is set") + } + migrater.WithDirtyStateHandler(path, destPath, handleDirty) + if err = gotoCmd(migrater, uint(v)); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "up": + upSet, helpPtr := newFlagSetWithHelp("up") + + if err := upSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*helpPtr, upUsage, upSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + limit := -1 + if upSet.NArg() > 0 { + n, err := strconv.ParseUint(upSet.Arg(0), 10, 64) + if err != nil { + log.fatal("error: can't read limit argument N") + } + limit = int(n) + } + + if err := upCmd(migrater, limit); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "down": + downFlagSet, helpPtr := newFlagSetWithHelp("down") + applyAll := downFlagSet.Bool("all", false, "Apply all down migrations") + if err := downFlagSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*helpPtr, downUsage, downFlagSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + downArgs := downFlagSet.Args() + + log.Println(*applyAll, downArgs) + + num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs) + if err != nil { + log.fatalErr(err) + } + if needsConfirm { + log.Println("Are you sure you want to apply all down migrations? [y/N]") + var response string + _, _ = fmt.Scanln(&response) + response = strings.ToLower(strings.TrimSpace(response)) + + if response == "y" { + log.Println("Applying all down migrations") + } else { + log.fatal("Not applying all down migrations") + } + } + + if err := downCmd(migrater, num); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "drop": + dropFlagSet, help := newFlagSetWithHelp("drop") + forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt") + + if err := dropFlagSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*help, dropUsage, dropFlagSet) + + if !*forceDrop { + log.Println("Are you sure you want to drop the entire database schema? [y/N]") + var response string + _, _ = fmt.Scanln(&response) + response = strings.ToLower(strings.TrimSpace(response)) + + if response == "y" { + log.Println("Dropping the entire database schema") + } else { + log.fatal("Aborted dropping the entire database schema") + } + } + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if err := dropCmd(migrater); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "force": + forceSet, helpPtr := newFlagSetWithHelp("force") + + if err := forceSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*helpPtr, forceUsage, forceSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if forceSet.NArg() == 0 { + log.fatal("error: please specify version argument V") + } + + v, err := strconv.ParseInt(forceSet.Arg(0), 10, 64) + if err != nil { + log.fatal("error: can't read version argument V") + } + + if v < -1 { + log.fatal("error: argument V must be >= -1") + } + + if err := forceCmd(migrater, int(v)); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "version": + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if err := versionCmd(migrater); err != nil { + log.fatalErr(err) + } + + default: + printUsageAndExit() + } } diff --git a/migrate.go b/migrate.go index 7763782a0..ccbd449f0 100644 --- a/migrate.go +++ b/migrate.go @@ -5,17 +5,17 @@ package migrate import ( - "errors" - "fmt" - "os" - "sync" - "time" + "errors" + "fmt" + "os" + "sync" + "time" - "github.com/hashicorp/go-multierror" + "github.com/hashicorp/go-multierror" - "github.com/golang-migrate/migrate/v4/database" - iurl "github.com/golang-migrate/migrate/v4/internal/url" - "github.com/golang-migrate/migrate/v4/source" + "github.com/golang-migrate/migrate/v4/database" + iurl "github.com/golang-migrate/migrate/v4/internal/url" + "github.com/golang-migrate/migrate/v4/source" ) // DefaultPrefetchMigrations sets the number of migrations to pre-read @@ -29,89 +29,98 @@ var DefaultPrefetchMigrations = uint(10) var DefaultLockTimeout = 15 * time.Second var ( - ErrNoChange = errors.New("no change") - ErrNilVersion = errors.New("no migration") - ErrInvalidVersion = errors.New("version must be >= -1") - ErrLocked = errors.New("database locked") - ErrLockTimeout = errors.New("timeout: can't acquire database lock") + ErrNoChange = errors.New("no change") + ErrNilVersion = errors.New("no migration") + ErrInvalidVersion = errors.New("version must be >= -1") + ErrLocked = errors.New("database locked") + ErrLockTimeout = errors.New("timeout: can't acquire database lock") ) // ErrShortLimit is an error returned when not enough migrations // can be returned by a source for a given limit. type ErrShortLimit struct { - Short uint + Short uint } // Error implements the error interface. func (e ErrShortLimit) Error() string { - return fmt.Sprintf("limit %v short", e.Short) + return fmt.Sprintf("limit %v short", e.Short) } type ErrDirty struct { - Version int + Version int } func (e ErrDirty) Error() string { - return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version) + return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version) } type Migrate struct { - sourceName string - sourceDrv source.Driver - databaseName string - databaseDrv database.Driver + sourceName string + sourceDrv source.Driver + databaseName string + databaseDrv database.Driver - // Log accepts a Logger interface - Log Logger + // Log accepts a Logger interface + Log Logger - // GracefulStop accepts `true` and will stop executing migrations - // as soon as possible at a safe break point, so that the database - // is not corrupted. - GracefulStop chan bool - isLockedMu *sync.Mutex + // GracefulStop accepts `true` and will stop executing migrations + // as soon as possible at a safe break point, so that the database + // is not corrupted. + GracefulStop chan bool + isLockedMu *sync.Mutex - isGracefulStop bool - isLocked bool + isGracefulStop bool + isLocked bool - // PrefetchMigrations defaults to DefaultPrefetchMigrations, - // but can be set per Migrate instance. - PrefetchMigrations uint + // PrefetchMigrations defaults to DefaultPrefetchMigrations, + // but can be set per Migrate instance. + PrefetchMigrations uint - // LockTimeout defaults to DefaultLockTimeout, - // but can be set per Migrate instance. - LockTimeout time.Duration + // LockTimeout defaults to DefaultLockTimeout, + // but can be set per Migrate instance. + LockTimeout time.Duration + + // DirtyStateHandler is used to handle dirty state of the database + ds *dirtyStateHandler +} + +type dirtyStateHandler struct { + srcPath string + destPath string + isDirty bool } // New returns a new Migrate instance from a source URL and a database URL. // The URL scheme is defined by each driver. func New(sourceURL, databaseURL string) (*Migrate, error) { - m := newCommon() - - sourceName, err := iurl.SchemeFromURL(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from source URL: %w", err) - } - m.sourceName = sourceName - - databaseName, err := iurl.SchemeFromURL(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) - } - m.databaseName = databaseName - - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv - - databaseDrv, err := database.Open(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) - } - m.databaseDrv = databaseDrv - - return m, nil + m := newCommon() + + sourceName, err := iurl.SchemeFromURL(sourceURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from source URL: %w", err) + } + m.sourceName = sourceName + + databaseName, err := iurl.SchemeFromURL(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) + } + m.databaseName = databaseName + + sourceDrv, err := source.Open(sourceURL) + if err != nil { + return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) + } + m.sourceDrv = sourceDrv + + databaseDrv, err := database.Open(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) + } + m.databaseDrv = databaseDrv + + return m, nil } // NewWithDatabaseInstance returns a new Migrate instance from a source URL @@ -119,25 +128,25 @@ func New(sourceURL, databaseURL string) (*Migrate, error) { // Use any string that can serve as an identifier during logging as databaseName. // You are responsible for closing the underlying database client if necessary. func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() + m := newCommon() - sourceName, err := iurl.SchemeFromURL(sourceURL) - if err != nil { - return nil, err - } - m.sourceName = sourceName + sourceName, err := iurl.SchemeFromURL(sourceURL) + if err != nil { + return nil, err + } + m.sourceName = sourceName - m.databaseName = databaseName + m.databaseName = databaseName - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv + sourceDrv, err := source.Open(sourceURL) + if err != nil { + return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) + } + m.sourceDrv = sourceDrv - m.databaseDrv = databaseInstance + m.databaseDrv = databaseInstance - return m, nil + return m, nil } // NewWithSourceInstance returns a new Migrate instance from an existing source instance @@ -145,25 +154,25 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst // Use any string that can serve as an identifier during logging as sourceName. // You are responsible for closing the underlying source client if necessary. func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) { - m := newCommon() + m := newCommon() - databaseName, err := iurl.SchemeFromURL(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) - } - m.databaseName = databaseName + databaseName, err := iurl.SchemeFromURL(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) + } + m.databaseName = databaseName - m.sourceName = sourceName + m.sourceName = sourceName - databaseDrv, err := database.Open(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) - } - m.databaseDrv = databaseDrv + databaseDrv, err := database.Open(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) + } + m.databaseDrv = databaseDrv - m.sourceDrv = sourceInstance + m.sourceDrv = sourceInstance - return m, nil + return m, nil } // NewWithInstance returns a new Migrate instance from an existing source and @@ -171,149 +180,183 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data // as sourceName and databaseName. You are responsible for closing down // the underlying source and database client if necessary. func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() + m := newCommon() - m.sourceName = sourceName - m.databaseName = databaseName + m.sourceName = sourceName + m.databaseName = databaseName - m.sourceDrv = sourceInstance - m.databaseDrv = databaseInstance + m.sourceDrv = sourceInstance + m.databaseDrv = databaseInstance + + return m, nil +} - return m, nil +func (m *Migrate) WithDirtyStateHandler(srcPath, destPath string, isDirty bool) { + m.ds = &dirtyStateHandler{ + srcPath: srcPath, + destPath: destPath, + isDirty: isDirty, + } } func newCommon() *Migrate { - return &Migrate{ - GracefulStop: make(chan bool, 1), - PrefetchMigrations: DefaultPrefetchMigrations, - LockTimeout: DefaultLockTimeout, - isLockedMu: &sync.Mutex{}, - } + return &Migrate{ + GracefulStop: make(chan bool, 1), + PrefetchMigrations: DefaultPrefetchMigrations, + LockTimeout: DefaultLockTimeout, + isLockedMu: &sync.Mutex{}, + } } // Close closes the source and the database. func (m *Migrate) Close() (source error, database error) { - databaseSrvClose := make(chan error) - sourceSrvClose := make(chan error) + databaseSrvClose := make(chan error) + sourceSrvClose := make(chan error) - m.logVerbosePrintf("Closing source and database\n") + m.logVerbosePrintf("Closing source and database\n") - go func() { - databaseSrvClose <- m.databaseDrv.Close() - }() + go func() { + databaseSrvClose <- m.databaseDrv.Close() + }() - go func() { - sourceSrvClose <- m.sourceDrv.Close() - }() + go func() { + sourceSrvClose <- m.sourceDrv.Close() + }() - return <-sourceSrvClose, <-databaseSrvClose + return <-sourceSrvClose, <-databaseSrvClose } // Migrate looks at the currently active migration version, // then migrates either up or down to the specified version. func (m *Migrate) Migrate(version uint) error { - if err := m.lock(); err != nil { - return err - } - - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } - - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } - - ret := make(chan interface{}, m.PrefetchMigrations) - go m.read(curVersion, int(version), ret) - - return m.unlockErr(m.runMigrations(ret)) + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return err + } + if err = m.CopyFiles(); err != nil { + return err + } + // if the dirty flag is passed to the 'goto' command, handle the dirty state + if m.ds.isDirty && dirty { + m.Log.Printf("Version: %d, handle dirty: %t\n", version, m.ds.isDirty) + if err = m.HandleDirtyState(); err != nil { + return err + } + } + + if err = m.lock(); err != nil { + return err + } + + if dirty { + // default behaviour + m.Log.Printf("Database is set to dirty for version: %v\n", curVersion) + return m.unlockErr(ErrDirty{curVersion}) + } + + ret := make(chan interface{}, m.PrefetchMigrations) + go m.read(curVersion, int(version), ret) + + err = m.runMigrations(ret) + if err != nil { + if m.ds.isDirty { + // Handle failure: store last successful migration version and exit + if err = m.HandleMigrationFailure(curVersion, version); err != nil { + return err + } + } + return m.unlockErr(err) + } + // Success: Clean up and confirm + if err = m.CleanupFiles(version); err != nil { + return m.unlockErr(err) + } + return nil } // Steps looks at the currently active migration version. // It will migrate up if n > 0, and down if n < 0. func (m *Migrate) Steps(n int) error { - if n == 0 { - return ErrNoChange - } + if n == 0 { + return ErrNoChange + } - if err := m.lock(); err != nil { - return err - } + if err := m.lock(); err != nil { + return err + } - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } - ret := make(chan interface{}, m.PrefetchMigrations) + ret := make(chan interface{}, m.PrefetchMigrations) - if n > 0 { - go m.readUp(curVersion, n, ret) - } else { - go m.readDown(curVersion, -n, ret) - } + if n > 0 { + go m.readUp(curVersion, n, ret) + } else { + go m.readDown(curVersion, -n, ret) + } - return m.unlockErr(m.runMigrations(ret)) + return m.unlockErr(m.runMigrations(ret)) } // Up looks at the currently active migration version // and will migrate all the way up (applying all up migrations). func (m *Migrate) Up() error { - if err := m.lock(); err != nil { - return err - } + if err := m.lock(); err != nil { + return err + } - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } - ret := make(chan interface{}, m.PrefetchMigrations) + ret := make(chan interface{}, m.PrefetchMigrations) - go m.readUp(curVersion, -1, ret) - return m.unlockErr(m.runMigrations(ret)) + go m.readUp(curVersion, -1, ret) + return m.unlockErr(m.runMigrations(ret)) } // Down looks at the currently active migration version // and will migrate all the way down (applying all down migrations). func (m *Migrate) Down() error { - if err := m.lock(); err != nil { - return err - } - - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } - - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } - - ret := make(chan interface{}, m.PrefetchMigrations) - go m.readDown(curVersion, -1, ret) - return m.unlockErr(m.runMigrations(ret)) + if err := m.lock(); err != nil { + return err + } + + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } + + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } + + ret := make(chan interface{}, m.PrefetchMigrations) + go m.readDown(curVersion, -1, ret) + return m.unlockErr(m.runMigrations(ret)) } // Drop deletes everything in the database. func (m *Migrate) Drop() error { - if err := m.lock(); err != nil { - return err - } - if err := m.databaseDrv.Drop(); err != nil { - return m.unlockErr(err) - } - return m.unlock() + if err := m.lock(); err != nil { + return err + } + if err := m.databaseDrv.Drop(); err != nil { + return m.unlockErr(err) + } + return m.unlock() } // Run runs any migration provided by you against the database. @@ -321,78 +364,78 @@ func (m *Migrate) Drop() error { // Usually you don't need this function at all. Use Migrate, // Steps, Up or Down instead. func (m *Migrate) Run(migration ...*Migration) error { - if len(migration) == 0 { - return ErrNoChange - } - - if err := m.lock(); err != nil { - return err - } - - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } - - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } - - ret := make(chan interface{}, m.PrefetchMigrations) - - go func() { - defer close(ret) - for _, migr := range migration { - if m.PrefetchMigrations > 0 && migr.Body != nil { - m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) - } else { - m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) - } - - ret <- migr - go func(migr *Migration) { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }(migr) - } - }() - - return m.unlockErr(m.runMigrations(ret)) + if len(migration) == 0 { + return ErrNoChange + } + + if err := m.lock(); err != nil { + return err + } + + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } + + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } + + ret := make(chan interface{}, m.PrefetchMigrations) + + go func() { + defer close(ret) + for _, migr := range migration { + if m.PrefetchMigrations > 0 && migr.Body != nil { + m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) + } else { + m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) + } + + ret <- migr + go func(migr *Migration) { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }(migr) + } + }() + + return m.unlockErr(m.runMigrations(ret)) } // Force sets a migration version. // It does not check any currently active version in database. // It resets the dirty state to false. func (m *Migrate) Force(version int) error { - if version < -1 { - return ErrInvalidVersion - } + if version < -1 { + return ErrInvalidVersion + } - if err := m.lock(); err != nil { - return err - } + if err := m.lock(); err != nil { + return err + } - if err := m.databaseDrv.SetVersion(version, false); err != nil { - return m.unlockErr(err) - } + if err := m.databaseDrv.SetVersion(version, false); err != nil { + return m.unlockErr(err) + } - return m.unlock() + return m.unlock() } // Version returns the currently active migration version. // If no migration has been applied, yet, it will return ErrNilVersion. func (m *Migrate) Version() (version uint, dirty bool, err error) { - v, d, err := m.databaseDrv.Version() - if err != nil { - return 0, false, err - } + v, d, err := m.databaseDrv.Version() + if err != nil { + return 0, false, err + } - if v == database.NilVersion { - return 0, false, ErrNilVersion - } + if v == database.NilVersion { + return 0, false, ErrNilVersion + } - return suint(v), d, nil + return suint(v), d, nil } // read reads either up or down migrations from source `from` to `to`. @@ -400,130 +443,130 @@ func (m *Migrate) Version() (version uint, dirty bool, err error) { // If an error occurs during reading, that error is written to the ret channel, too. // Once read is done reading it will close the ret channel. func (m *Migrate) read(from int, to int, ret chan<- interface{}) { - defer close(ret) - - // check if from version exists - if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { - ret <- err - return - } - } - - // check if to version exists - if to >= 0 { - if err := m.versionExists(suint(to)); err != nil { - ret <- err - return - } - } - - // no change? - if from == to { - ret <- ErrNoChange - return - } - - if from < to { - // it's going up - // apply first migration if from is nil version - if from == -1 { - firstVersion, err := m.sourceDrv.First() - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(firstVersion, int(firstVersion)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - from = int(firstVersion) - } - - // run until we reach target ... - for from < to { - if m.stop() { - return - } - - next, err := m.sourceDrv.Next(suint(from)) - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(next, int(next)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - from = int(next) - } - - } else { - // it's going down - // run until we reach target ... - for from > to && from >= 0 { - if m.stop() { - return - } - - prev, err := m.sourceDrv.Prev(suint(from)) - if errors.Is(err, os.ErrNotExist) && to == -1 { - // apply nil migration - migr, err := m.newMigration(suint(from), -1) - if err != nil { - ret <- err - return - } - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - return - - } else if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(suint(from), int(prev)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - from = int(prev) - } - } + defer close(ret) + + // check if from version exists + if from >= 0 { + if err := m.versionExists(suint(from)); err != nil { + ret <- err + return + } + } + + // check if to version exists + if to >= 0 { + if err := m.versionExists(suint(to)); err != nil { + ret <- err + return + } + } + + // no change? + if from == to { + ret <- ErrNoChange + return + } + + if from < to { + // it's going up + // apply first migration if from is nil version + if from == -1 { + firstVersion, err := m.sourceDrv.First() + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(firstVersion, int(firstVersion)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + from = int(firstVersion) + } + + // run until we reach target ... + for from < to { + if m.stop() { + return + } + + next, err := m.sourceDrv.Next(suint(from)) + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(next, int(next)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + from = int(next) + } + + } else { + // it's going down + // run until we reach target ... + for from > to && from >= 0 { + if m.stop() { + return + } + + prev, err := m.sourceDrv.Prev(suint(from)) + if errors.Is(err, os.ErrNotExist) && to == -1 { + // apply nil migration + migr, err := m.newMigration(suint(from), -1) + if err != nil { + ret <- err + return + } + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + return + + } else if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(suint(from), int(prev)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + from = int(prev) + } + } } // readUp reads up migrations from `from` limitted by `limit`. @@ -532,98 +575,98 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) { // If an error occurs during reading, that error is written to the ret channel, too. // Once readUp is done reading it will close the ret channel. func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { - defer close(ret) - - // check if from version exists - if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { - ret <- err - return - } - } - - if limit == 0 { - ret <- ErrNoChange - return - } - - count := 0 - for count < limit || limit == -1 { - if m.stop() { - return - } - - // apply first migration if from is nil version - if from == -1 { - firstVersion, err := m.sourceDrv.First() - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(firstVersion, int(firstVersion)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - from = int(firstVersion) - count++ - continue - } - - // apply next migration - next, err := m.sourceDrv.Next(suint(from)) - if errors.Is(err, os.ErrNotExist) { - // no limit, but no migrations applied? - if limit == -1 && count == 0 { - ret <- ErrNoChange - return - } - - // no limit, reached end - if limit == -1 { - return - } - - // reached end, and didn't apply any migrations - if limit > 0 && count == 0 { - ret <- os.ErrNotExist - return - } - - // applied less migrations than limit? - if count < limit { - ret <- ErrShortLimit{suint(limit - count)} - return - } - } - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(next, int(next)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - from = int(next) - count++ - } + defer close(ret) + + // check if from version exists + if from >= 0 { + if err := m.versionExists(suint(from)); err != nil { + ret <- err + return + } + } + + if limit == 0 { + ret <- ErrNoChange + return + } + + count := 0 + for count < limit || limit == -1 { + if m.stop() { + return + } + + // apply first migration if from is nil version + if from == -1 { + firstVersion, err := m.sourceDrv.First() + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(firstVersion, int(firstVersion)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + from = int(firstVersion) + count++ + continue + } + + // apply next migration + next, err := m.sourceDrv.Next(suint(from)) + if errors.Is(err, os.ErrNotExist) { + // no limit, but no migrations applied? + if limit == -1 && count == 0 { + ret <- ErrNoChange + return + } + + // no limit, reached end + if limit == -1 { + return + } + + // reached end, and didn't apply any migrations + if limit > 0 && count == 0 { + ret <- os.ErrNotExist + return + } + + // applied less migrations than limit? + if count < limit { + ret <- ErrShortLimit{suint(limit - count)} + return + } + } + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(next, int(next)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + from = int(next) + count++ + } } // readDown reads down migrations from `from` limitted by `limit`. @@ -632,88 +675,88 @@ func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { // If an error occurs during reading, that error is written to the ret channel, too. // Once readDown is done reading it will close the ret channel. func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { - defer close(ret) - - // check if from version exists - if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { - ret <- err - return - } - } - - if limit == 0 { - ret <- ErrNoChange - return - } - - // no change if already at nil version - if from == -1 && limit == -1 { - ret <- ErrNoChange - return - } - - // can't go over limit if already at nil version - if from == -1 && limit > 0 { - ret <- os.ErrNotExist - return - } - - count := 0 - for count < limit || limit == -1 { - if m.stop() { - return - } - - prev, err := m.sourceDrv.Prev(suint(from)) - if errors.Is(err, os.ErrNotExist) { - // no limit or haven't reached limit, apply "first" migration - if limit == -1 || limit-count > 0 { - firstVersion, err := m.sourceDrv.First() - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(firstVersion, -1) - if err != nil { - ret <- err - return - } - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - count++ - } - - if count < limit { - ret <- ErrShortLimit{suint(limit - count)} - } - return - } - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(suint(from), int(prev)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - from = int(prev) - count++ - } + defer close(ret) + + // check if from version exists + if from >= 0 { + if err := m.versionExists(suint(from)); err != nil { + ret <- err + return + } + } + + if limit == 0 { + ret <- ErrNoChange + return + } + + // no change if already at nil version + if from == -1 && limit == -1 { + ret <- ErrNoChange + return + } + + // can't go over limit if already at nil version + if from == -1 && limit > 0 { + ret <- os.ErrNotExist + return + } + + count := 0 + for count < limit || limit == -1 { + if m.stop() { + return + } + + prev, err := m.sourceDrv.Prev(suint(from)) + if errors.Is(err, os.ErrNotExist) { + // no limit or haven't reached limit, apply "first" migration + if limit == -1 || limit-count > 0 { + firstVersion, err := m.sourceDrv.First() + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(firstVersion, -1) + if err != nil { + ret <- err + return + } + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + count++ + } + + if count < limit { + ret <- ErrShortLimit{suint(limit - count)} + } + return + } + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(suint(from), int(prev)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + from = int(prev) + count++ + } } // runMigrations reads *Migration and error from a channel. Any other type @@ -723,259 +766,259 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { // to stop execution because it might have received a stop signal on the // GracefulStop channel. func (m *Migrate) runMigrations(ret <-chan interface{}) error { - for r := range ret { - - if m.stop() { - return nil - } - - switch r := r.(type) { - case error: - return r - - case *Migration: - migr := r - - // set version with dirty state - if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil { - return err - } - - if migr.Body != nil { - m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) - if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { - return err - } - } - - // set clean state - if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil { - return err - } - - endTime := time.Now() - readTime := migr.FinishedReading.Sub(migr.StartedBuffering) - runTime := endTime.Sub(migr.FinishedReading) - - // log either verbose or normal - if m.Log != nil { - if m.Log.Verbose() { - m.logPrintf("Finished %v (read %v, ran %v)\n", migr.LogString(), readTime, runTime) - } else { - m.logPrintf("%v (%v)\n", migr.LogString(), readTime+runTime) - } - } - - default: - return fmt.Errorf("unknown type: %T with value: %+v", r, r) - } - } - return nil + for r := range ret { + + if m.stop() { + return nil + } + + switch r := r.(type) { + case error: + return r + + case *Migration: + migr := r + + // set version with dirty state + if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil { + return err + } + + if migr.Body != nil { + m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) + if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { + return err + } + } + + // set clean state + if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil { + return err + } + + endTime := time.Now() + readTime := migr.FinishedReading.Sub(migr.StartedBuffering) + runTime := endTime.Sub(migr.FinishedReading) + + // log either verbose or normal + if m.Log != nil { + if m.Log.Verbose() { + m.logPrintf("Finished %v (read %v, ran %v)\n", migr.LogString(), readTime, runTime) + } else { + m.logPrintf("%v (%v)\n", migr.LogString(), readTime+runTime) + } + } + + default: + return fmt.Errorf("unknown type: %T with value: %+v", r, r) + } + } + return nil } // versionExists checks the source if either the up or down migration for // the specified migration version exists. func (m *Migrate) versionExists(version uint) (result error) { - // try up migration first - up, _, err := m.sourceDrv.ReadUp(version) - if err == nil { - defer func() { - if errClose := up.Close(); errClose != nil { - result = multierror.Append(result, errClose) - } - }() - } - if errors.Is(err, os.ErrExist) { - return nil - } else if !errors.Is(err, os.ErrNotExist) { - return err - } - - // then try down migration - down, _, err := m.sourceDrv.ReadDown(version) - if err == nil { - defer func() { - if errClose := down.Close(); errClose != nil { - result = multierror.Append(result, errClose) - } - }() - } - if errors.Is(err, os.ErrExist) { - return nil - } else if !errors.Is(err, os.ErrNotExist) { - return err - } - - err = fmt.Errorf("no migration found for version %d: %w", version, err) - m.logErr(err) - return err + // try up migration first + up, _, err := m.sourceDrv.ReadUp(version) + if err == nil { + defer func() { + if errClose := up.Close(); errClose != nil { + result = multierror.Append(result, errClose) + } + }() + } + if errors.Is(err, os.ErrExist) { + return nil + } else if !errors.Is(err, os.ErrNotExist) { + return err + } + + // then try down migration + down, _, err := m.sourceDrv.ReadDown(version) + if err == nil { + defer func() { + if errClose := down.Close(); errClose != nil { + result = multierror.Append(result, errClose) + } + }() + } + if errors.Is(err, os.ErrExist) { + return nil + } else if !errors.Is(err, os.ErrNotExist) { + return err + } + + err = fmt.Errorf("no migration found for version %d: %w", version, err) + m.logErr(err) + return err } // stop returns true if no more migrations should be run against the database // because a stop signal was received on the GracefulStop channel. // Calls are cheap and this function is not blocking. func (m *Migrate) stop() bool { - if m.isGracefulStop { - return true - } - - select { - case <-m.GracefulStop: - m.isGracefulStop = true - return true - - default: - return false - } + if m.isGracefulStop { + return true + } + + select { + case <-m.GracefulStop: + m.isGracefulStop = true + return true + + default: + return false + } } // newMigration is a helper func that returns a *Migration for the // specified version and targetVersion. func (m *Migrate) newMigration(version uint, targetVersion int) (*Migration, error) { - var migr *Migration - - if targetVersion >= int(version) { - r, identifier, err := m.sourceDrv.ReadUp(version) - if errors.Is(err, os.ErrNotExist) { - // create "empty" migration - migr, err = NewMigration(nil, "", version, targetVersion) - if err != nil { - return nil, err - } - - } else if err != nil { - return nil, err - - } else { - // create migration from up source - migr, err = NewMigration(r, identifier, version, targetVersion) - if err != nil { - return nil, err - } - } - - } else { - r, identifier, err := m.sourceDrv.ReadDown(version) - if errors.Is(err, os.ErrNotExist) { - // create "empty" migration - migr, err = NewMigration(nil, "", version, targetVersion) - if err != nil { - return nil, err - } - - } else if err != nil { - return nil, err - - } else { - // create migration from down source - migr, err = NewMigration(r, identifier, version, targetVersion) - if err != nil { - return nil, err - } - } - } - - if m.PrefetchMigrations > 0 && migr.Body != nil { - m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) - } else { - m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) - } - - return migr, nil + var migr *Migration + + if targetVersion >= int(version) { + r, identifier, err := m.sourceDrv.ReadUp(version) + if errors.Is(err, os.ErrNotExist) { + // create "empty" migration + migr, err = NewMigration(nil, "", version, targetVersion) + if err != nil { + return nil, err + } + + } else if err != nil { + return nil, err + + } else { + // create migration from up source + migr, err = NewMigration(r, identifier, version, targetVersion) + if err != nil { + return nil, err + } + } + + } else { + r, identifier, err := m.sourceDrv.ReadDown(version) + if errors.Is(err, os.ErrNotExist) { + // create "empty" migration + migr, err = NewMigration(nil, "", version, targetVersion) + if err != nil { + return nil, err + } + + } else if err != nil { + return nil, err + + } else { + // create migration from down source + migr, err = NewMigration(r, identifier, version, targetVersion) + if err != nil { + return nil, err + } + } + } + + if m.PrefetchMigrations > 0 && migr.Body != nil { + m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) + } else { + m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) + } + + return migr, nil } // lock is a thread safe helper function to lock the database. // It should be called as late as possible when running migrations. func (m *Migrate) lock() error { - m.isLockedMu.Lock() - defer m.isLockedMu.Unlock() - - if m.isLocked { - return ErrLocked - } - - // create done channel, used in the timeout goroutine - done := make(chan bool, 1) - defer func() { - done <- true - }() - - // use errchan to signal error back to this context - errchan := make(chan error, 2) - - // start timeout goroutine - timeout := time.After(m.LockTimeout) - go func() { - for { - select { - case <-done: - return - case <-timeout: - errchan <- ErrLockTimeout - return - } - } - }() - - // now try to acquire the lock - go func() { - if err := m.databaseDrv.Lock(); err != nil { - errchan <- err - } else { - errchan <- nil - } - }() - - // wait until we either receive ErrLockTimeout or error from Lock operation - err := <-errchan - if err == nil { - m.isLocked = true - } - return err + m.isLockedMu.Lock() + defer m.isLockedMu.Unlock() + + if m.isLocked { + return ErrLocked + } + + // create done channel, used in the timeout goroutine + done := make(chan bool, 1) + defer func() { + done <- true + }() + + // use errchan to signal error back to this context + errchan := make(chan error, 2) + + // start timeout goroutine + timeout := time.After(m.LockTimeout) + go func() { + for { + select { + case <-done: + return + case <-timeout: + errchan <- ErrLockTimeout + return + } + } + }() + + // now try to acquire the lock + go func() { + if err := m.databaseDrv.Lock(); err != nil { + errchan <- err + } else { + errchan <- nil + } + }() + + // wait until we either receive ErrLockTimeout or error from Lock operation + err := <-errchan + if err == nil { + m.isLocked = true + } + return err } // unlock is a thread safe helper function to unlock the database. // It should be called as early as possible when no more migrations are // expected to be executed. func (m *Migrate) unlock() error { - m.isLockedMu.Lock() - defer m.isLockedMu.Unlock() + m.isLockedMu.Lock() + defer m.isLockedMu.Unlock() - if err := m.databaseDrv.Unlock(); err != nil { - // BUG: Can potentially create a deadlock. Add a timeout. - return err - } + if err := m.databaseDrv.Unlock(); err != nil { + // BUG: Can potentially create a deadlock. Add a timeout. + return err + } - m.isLocked = false - return nil + m.isLocked = false + return nil } // unlockErr calls unlock and returns a combined error // if a prevErr is not nil. func (m *Migrate) unlockErr(prevErr error) error { - if err := m.unlock(); err != nil { - return multierror.Append(prevErr, err) - } - return prevErr + if err := m.unlock(); err != nil { + return multierror.Append(prevErr, err) + } + return prevErr } // logPrintf writes to m.Log if not nil func (m *Migrate) logPrintf(format string, v ...interface{}) { - if m.Log != nil { - m.Log.Printf(format, v...) - } + if m.Log != nil { + m.Log.Printf(format, v...) + } } // logVerbosePrintf writes to m.Log if not nil. Use for verbose logging output. func (m *Migrate) logVerbosePrintf(format string, v ...interface{}) { - if m.Log != nil && m.Log.Verbose() { - m.Log.Printf(format, v...) - } + if m.Log != nil && m.Log.Verbose() { + m.Log.Printf(format, v...) + } } // logErr writes error to m.Log if not nil func (m *Migrate) logErr(err error) { - if m.Log != nil { - m.Log.Printf("error: %v", err) - } + if m.Log != nil { + m.Log.Printf("error: %v", err) + } } diff --git a/migrate_goto_temp.go b/migrate_goto_temp.go new file mode 100644 index 000000000..7e5f181d1 --- /dev/null +++ b/migrate_goto_temp.go @@ -0,0 +1,163 @@ +package migrate + +import ( + "io" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/pkg/errors" +) + +// Define a constant for the migration file name +const lastSuccessfulMigrationFile = "lastSuccessfulMigration" + +func (m *Migrate) HandleDirtyState() error { + // Perform actions when the database state is dirty + lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) + lastVersionBytes, err := os.ReadFile(lastSuccessfulMigrationPath) + if err != nil { + return err + } + lastVersionStr := strings.TrimSpace(string(lastVersionBytes)) + lastVersion, err := strconv.ParseUint(lastVersionStr, 10, 64) + if err != nil { + return errors.Wrap(err, "failed to parse last successful migration version") + } + + if err = m.Force(int(lastVersion)); err != nil { + return errors.Wrap(err, "failed to apply last successful migration") + } + + m.Log.Printf("Successfully applied migration: %s", lastVersionStr) + + if err = os.Remove(lastSuccessfulMigrationPath); err != nil { + return err + } + + m.Log.Printf("Successfully deleted file: %s", lastSuccessfulMigrationPath) + return nil +} + +func (m *Migrate) HandleMigrationFailure(curVersion int, v uint) error { + failedVersion, _, err := m.databaseDrv.Version() + if err != nil { + return err + } + + // Determine the last successful migration + lastSuccessfulMigration := strconv.Itoa(curVersion) + ret := make(chan interface{}, m.PrefetchMigrations) + go m.read(curVersion, int(v), ret) + + for r := range ret { + mig, ok := r.(*Migration) + if ok { + if mig.Version == uint(failedVersion) { + break + } + lastSuccessfulMigration = strconv.Itoa(int(mig.Version)) + } + } + + m.Log.Printf("migration failed, last successful migration version: %s", lastSuccessfulMigration) + lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) + if err = os.WriteFile(lastSuccessfulMigrationPath, []byte(lastSuccessfulMigration), 0644); err != nil { + return err + } + + return nil +} + +func (m *Migrate) CleanupFiles(v uint) error { + if m.ds.destPath == "" { + return nil + } + files, err := os.ReadDir(m.ds.destPath) + if err != nil { + return err + } + + targetVersion := uint64(v) + + for _, file := range files { + fileName := file.Name() + + // Check if file is a migration file we want to process + if !strings.HasSuffix(fileName, "down.sql") && !strings.HasSuffix(fileName, "up.sql") { + continue + } + + // Extract version and compare + versionEnd := strings.Index(fileName, "_") + if versionEnd == -1 { + // Skip files that don't match the expected naming pattern + continue + } + + fileVersion, err := strconv.ParseUint(fileName[:versionEnd], 10, 64) + if err != nil { + m.Log.Printf("Skipping file %s due to version parse error: %v", fileName, err) + continue + } + + // Delete file if version is greater than targetVersion + if fileVersion > targetVersion { + if err = os.Remove(filepath.Join(m.ds.destPath, fileName)); err != nil { + m.Log.Printf("Failed to delete file %s: %v", fileName, err) + continue + } + m.Log.Printf("Deleted file: %s", fileName) + } + } + + return nil +} + +// CopyFiles copies all files from srcDir to destDir. +func (m *Migrate) CopyFiles() error { + if m.ds.destPath == "" { + return nil + } + _, err := os.ReadDir(m.ds.destPath) + if err != nil { + // If the directory does not exist + return err + } + + return filepath.Walk(m.ds.srcPath, func(src string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // ignore sub-directories in the migration directory + if info.IsDir() { + return nil + } + var ( + srcFile *os.File + destFile *os.File + ) + dest := filepath.Join(m.ds.destPath, info.Name()) + if srcFile, err = os.Open(src); err != nil { + return err + } + defer func(srcFile *os.File) { + if err = srcFile.Close(); err != nil { + m.Log.Printf("failed to close file %s: %s", srcFile.Name, err) + } + }(srcFile) + + // Create the destination file + if destFile, err = os.Create(dest); err != nil { + return err + } + + // Copy the file + if _, err = io.Copy(destFile, srcFile); err == nil { + return err + } + return os.Chmod(dest, info.Mode()) + }) +} From 98be6781ebd8db4d73497af21298ad1a66c31f4c Mon Sep 17 00:00:00 2001 From: Venkat Venkatasubramanian Date: Fri, 30 Aug 2024 09:34:39 -0700 Subject: [PATCH 2/4] First pass at downmigrate --- internal/cli/main.go | 17 ++++++++++++++++- migrate.go | 38 ++++++++++++++++++++++++++++---------- migrate_goto_temp.go | 13 +++++++++++++ test/main.go | 1 + 4 files changed, 58 insertions(+), 11 deletions(-) create mode 100644 test/main.go diff --git a/internal/cli/main.go b/internal/cli/main.go index a0084dbd0..74459da29 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -270,11 +270,26 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU } handleDirty := viper.GetBool("dirty") destPath := viper.GetString("intermediate-path") + srcPath := "" + // if sourcePtr is set, use it to get the source path + // otherwise, use the path flag + if path != "" { + srcPath = path + } + if sourcePtr != "" { + // parse the source path from the source argument + parse, err := url.Parse(sourcePtr) + if err != nil { + log.fatal("error: can't parse the source path from the source argument") + } + srcPath = parse.Path + } if handleDirty && destPath == "" { log.fatal("error: intermediate-path must be specified when dirty is set") } - migrater.WithDirtyStateHandler(path, destPath, handleDirty) + log.Printf("running goto with handleDirty: %t, destPath: %s, srcPath: %s\n", handleDirty, destPath, srcPath) + migrater.WithDirtyStateHandler(srcPath, destPath, handleDirty) if err = gotoCmd(migrater, uint(v)); err != nil { log.fatalErr(err) } diff --git a/migrate.go b/migrate.go index ccbd449f0..e7631c843 100644 --- a/migrate.go +++ b/migrate.go @@ -123,6 +123,15 @@ func New(sourceURL, databaseURL string) (*Migrate, error) { return m, nil } +func (m *Migrate) updateSourceDrv(sourceURL string) error { + sourceDrv, err := source.Open(sourceURL) + if err != nil { + return fmt.Errorf("failed to open source, %q: %w", sourceURL, err) + } + m.sourceDrv = sourceDrv + return nil +} + // NewWithDatabaseInstance returns a new Migrate instance from a source URL // and an existing database instance. The source URL scheme is defined by each driver. // Use any string that can serve as an identifier during logging as databaseName. @@ -231,16 +240,30 @@ func (m *Migrate) Close() (source error, database error) { func (m *Migrate) Migrate(version uint) error { curVersion, dirty, err := m.databaseDrv.Version() if err != nil { + m.Log.Printf("******************Failed to get current version: %v\n", err) return err } + if err = m.CopyFiles(); err != nil { return err } + + m.Log.Printf("Current version: %d, dirty: %t\n", curVersion, dirty) // if the dirty flag is passed to the 'goto' command, handle the dirty state - if m.ds.isDirty && dirty { - m.Log.Printf("Version: %d, handle dirty: %t\n", version, m.ds.isDirty) - if err = m.HandleDirtyState(); err != nil { - return err + if dirty { + if m.ds.isDirty { + m.Log.Printf("Version: %d, handle dirty: %t\n", version, m.ds.isDirty) + if err = m.HandleDirtyState(); err != nil { + return err + } + if err = m.updateSourceDrv(fmt.Sprintf("file://%s", m.ds.destPath)); err != nil { + return err + } + + } else { + // default behaviour + m.Log.Printf("Database is set to dirty for version: %v\n", curVersion) + return ErrDirty{curVersion} } } @@ -248,12 +271,6 @@ func (m *Migrate) Migrate(version uint) error { return err } - if dirty { - // default behaviour - m.Log.Printf("Database is set to dirty for version: %v\n", curVersion) - return m.unlockErr(ErrDirty{curVersion}) - } - ret := make(chan interface{}, m.PrefetchMigrations) go m.read(curVersion, int(version), ret) @@ -766,6 +783,7 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { // to stop execution because it might have received a stop signal on the // GracefulStop channel. func (m *Migrate) runMigrations(ret <-chan interface{}) error { + m.Log.Printf("Starting %s migrations\n", m.sourceDrv) for r := range ret { if m.stop() { diff --git a/migrate_goto_temp.go b/migrate_goto_temp.go index 7e5f181d1..517d63adb 100644 --- a/migrate_goto_temp.go +++ b/migrate_goto_temp.go @@ -126,6 +126,8 @@ func (m *Migrate) CopyFiles() error { return err } + m.Log.Printf("Copying files from %s to %s", m.ds.srcPath, m.ds.destPath) + return filepath.Walk(m.ds.srcPath, func(src string, info os.FileInfo, err error) error { if err != nil { return err @@ -133,8 +135,19 @@ func (m *Migrate) CopyFiles() error { // ignore sub-directories in the migration directory if info.IsDir() { + // Skip the tests directory and its files + if info.Name() == "tests" { + m.Log.Printf("Ignoring directory %s", info.Name()) + return filepath.SkipDir + } return nil } + // Ignore the current.sql file + if info.Name() == "current.sql" { + m.Log.Printf("Ignoring file %s", info.Name()) + return nil + } + var ( srcFile *os.File destFile *os.File diff --git a/test/main.go b/test/main.go new file mode 100644 index 000000000..56e540407 --- /dev/null +++ b/test/main.go @@ -0,0 +1 @@ +package test From 8ab5a721469a9ba06c27b090580dac6a2f17c64d Mon Sep 17 00:00:00 2001 From: Venkat Venkatasubramanian Date: Mon, 23 Sep 2024 08:03:36 -0700 Subject: [PATCH 3/4] downmigrate changes + UTs + cleanups --- cmd/migrate/config.go | 58 +- internal/cli/commands.go | 358 ++++---- internal/cli/main.go | 764 ++++++++-------- migrate.go | 1774 +++++++++++++++++++++----------------- migrate_goto_temp.go | 176 ---- migrate_test.go | 321 +++++++ test/main.go | 1 - 7 files changed, 1878 insertions(+), 1574 deletions(-) delete mode 100644 migrate_goto_temp.go delete mode 100644 test/main.go diff --git a/cmd/migrate/config.go b/cmd/migrate/config.go index ff4c3c46d..de812156a 100644 --- a/cmd/migrate/config.go +++ b/cmd/migrate/config.go @@ -3,40 +3,40 @@ package main import "github.com/spf13/pflag" const ( - // configuration defaults support local development (i.e. "go run ...") - defaultDatabaseDSN = "" - defaultDatabaseDriver = "postgres" - defaultDatabaseAddress = "0.0.0.0:5432" - defaultDatabaseName = "" - defaultDatabaseUser = "postgres" - defaultDatabasePassword = "postgres" - defaultDatabaseSSL = "disable" - defaultConfigDirectory = "/cli/config" + // configuration defaults support local development (i.e. "go run ...") + defaultDatabaseDSN = "" + defaultDatabaseDriver = "postgres" + defaultDatabaseAddress = "0.0.0.0:5432" + defaultDatabaseName = "" + defaultDatabaseUser = "postgres" + defaultDatabasePassword = "postgres" + defaultDatabaseSSL = "disable" + defaultConfigDirectory = "/cli/config" ) var ( - // define flag overrides - flagHelp = pflag.Bool("help", false, "Print usage") - flagVersion = pflag.String("version", Version, "Print version") - flagLoggingVerbose = pflag.Bool("verbose", true, "Print verbose logging") - flagPrefetch = pflag.Uint("prefetch", 10, "Number of migrations to load in advance before executing") - flaglockTimeout = pflag.Uint("lock-timeout", 15, "Allow N seconds to acquire database lock") + // define flag overrides + flagHelp = pflag.Bool("help", false, "Print usage") + flagVersion = pflag.String("version", Version, "Print version") + flagLoggingVerbose = pflag.Bool("verbose", true, "Print verbose logging") + flagPrefetch = pflag.Uint("prefetch", 10, "Number of migrations to load in advance before executing") + flaglockTimeout = pflag.Uint("lock-timeout", 15, "Allow N seconds to acquire database lock") - flagDatabaseDSN = pflag.String("database.dsn", defaultDatabaseDSN, "database connection string") - flagDatabaseDriver = pflag.String("database.driver", defaultDatabaseDriver, "database driver") - flagDatabaseAddress = pflag.String("database.address", defaultDatabaseAddress, "address of the database") - flagDatabaseName = pflag.String("database.name", defaultDatabaseName, "name of the database") - flagDatabaseUser = pflag.String("database.user", defaultDatabaseUser, "database username") - flagDatabasePassword = pflag.String("database.password", defaultDatabasePassword, "database password") - flagDatabaseSSL = pflag.String("database.ssl", defaultDatabaseSSL, "database ssl mode") + flagDatabaseDSN = pflag.String("database.dsn", defaultDatabaseDSN, "database connection string") + flagDatabaseDriver = pflag.String("database.driver", defaultDatabaseDriver, "database driver") + flagDatabaseAddress = pflag.String("database.address", defaultDatabaseAddress, "address of the database") + flagDatabaseName = pflag.String("database.name", defaultDatabaseName, "name of the database") + flagDatabaseUser = pflag.String("database.user", defaultDatabaseUser, "database username") + flagDatabasePassword = pflag.String("database.password", defaultDatabasePassword, "database password") + flagDatabaseSSL = pflag.String("database.ssl", defaultDatabaseSSL, "database ssl mode") - flagSource = pflag.String("source", "", "Location of the migrations (driver://url)") - flagPath = pflag.String("path", "", "Shorthand for -source=file://path") + flagSource = pflag.String("source", "", "Location of the migrations (driver://url)") + flagPath = pflag.String("path", "", "Shorthand for -source=file://path") - flagConfigDirectory = pflag.String("config.source", defaultConfigDirectory, "directory of the configuration file") - flagConfigFile = pflag.String("config.file", "", "configuration file name without extension") + flagConfigDirectory = pflag.String("config.source", defaultConfigDirectory, "directory of the configuration file") + flagConfigFile = pflag.String("config.file", "", "configuration file name without extension") - // goto command flags - flagDirty = pflag.Bool("dirty", false, "migration is dirty") - flagPVCPath = pflag.String("intermediate-path", "", "path to the mounted volume which is used to copy the migration files") + // goto command flags + flagDirty = pflag.Bool("dirty", false, "migration is dirty") + flagPVCPath = pflag.String("intermediate-path", "", "path to the mounted volume which is used to copy the migration files") ) diff --git a/internal/cli/commands.go b/internal/cli/commands.go index 868938f8e..7adec2f84 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -1,248 +1,248 @@ package cli import ( - "errors" - "fmt" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/golang-migrate/migrate/v4" - _ "github.com/golang-migrate/migrate/v4/database/stub" // TODO remove again - _ "github.com/golang-migrate/migrate/v4/source/file" + "errors" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/golang-migrate/migrate/v4" + _ "github.com/golang-migrate/migrate/v4/database/stub" // TODO remove again + _ "github.com/golang-migrate/migrate/v4/source/file" ) var ( - errInvalidSequenceWidth = errors.New("Digits must be positive") - errIncompatibleSeqAndFormat = errors.New("The seq and format options are mutually exclusive") - errInvalidTimeFormat = errors.New("Time format may not be empty") + errInvalidSequenceWidth = errors.New("Digits must be positive") + errIncompatibleSeqAndFormat = errors.New("The seq and format options are mutually exclusive") + errInvalidTimeFormat = errors.New("Time format may not be empty") ) func nextSeqVersion(matches []string, seqDigits int) (string, error) { - if seqDigits <= 0 { - return "", errInvalidSequenceWidth - } + if seqDigits <= 0 { + return "", errInvalidSequenceWidth + } - nextSeq := uint64(1) + nextSeq := uint64(1) - if len(matches) > 0 { - filename := matches[len(matches)-1] - matchSeqStr := filepath.Base(filename) - idx := strings.Index(matchSeqStr, "_") + if len(matches) > 0 { + filename := matches[len(matches)-1] + matchSeqStr := filepath.Base(filename) + idx := strings.Index(matchSeqStr, "_") - if idx < 1 { // Using 1 instead of 0 since there should be at least 1 digit - return "", fmt.Errorf("Malformed migration filename: %s", filename) - } + if idx < 1 { // Using 1 instead of 0 since there should be at least 1 digit + return "", fmt.Errorf("Malformed migration filename: %s", filename) + } - var err error - matchSeqStr = matchSeqStr[0:idx] - nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64) + var err error + matchSeqStr = matchSeqStr[0:idx] + nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64) - if err != nil { - return "", err - } + if err != nil { + return "", err + } - nextSeq++ - } + nextSeq++ + } - version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits) + version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits) - if len(version) > seqDigits { - return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits) - } + if len(version) > seqDigits { + return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits) + } - return version, nil + return version, nil } func timeVersion(startTime time.Time, format string) (version string, err error) { - switch format { - case "": - err = errInvalidTimeFormat - case "unix": - version = strconv.FormatInt(startTime.Unix(), 10) - case "unixNano": - version = strconv.FormatInt(startTime.UnixNano(), 10) - default: - version = startTime.Format(format) - } - - return + switch format { + case "": + err = errInvalidTimeFormat + case "unix": + version = strconv.FormatInt(startTime.Unix(), 10) + case "unixNano": + version = strconv.FormatInt(startTime.UnixNano(), 10) + default: + version = startTime.Format(format) + } + + return } // createCmd (meant to be called via a CLI command) creates a new migration func createCmd(dir string, startTime time.Time, format string, name string, ext string, seq bool, seqDigits int, print bool) error { - if seq && format != defaultTimeFormat { - return errIncompatibleSeqAndFormat - } + if seq && format != defaultTimeFormat { + return errIncompatibleSeqAndFormat + } - var version string - var err error + var version string + var err error - dir = filepath.Clean(dir) - ext = "." + strings.TrimPrefix(ext, ".") + dir = filepath.Clean(dir) + ext = "." + strings.TrimPrefix(ext, ".") - if seq { - matches, err := filepath.Glob(filepath.Join(dir, "*"+ext)) + if seq { + matches, err := filepath.Glob(filepath.Join(dir, "*"+ext)) - if err != nil { - return err - } + if err != nil { + return err + } - version, err = nextSeqVersion(matches, seqDigits) + version, err = nextSeqVersion(matches, seqDigits) - if err != nil { - return err - } - } else { - version, err = timeVersion(startTime, format) + if err != nil { + return err + } + } else { + version, err = timeVersion(startTime, format) - if err != nil { - return err - } - } + if err != nil { + return err + } + } - versionGlob := filepath.Join(dir, version+"_*"+ext) - matches, err := filepath.Glob(versionGlob) + versionGlob := filepath.Join(dir, version+"_*"+ext) + matches, err := filepath.Glob(versionGlob) - if err != nil { - return err - } + if err != nil { + return err + } - if len(matches) > 0 { - return fmt.Errorf("duplicate migration version: %s", version) - } + if len(matches) > 0 { + return fmt.Errorf("duplicate migration version: %s", version) + } - if err = os.MkdirAll(dir, os.ModePerm); err != nil { - return err - } + if err = os.MkdirAll(dir, os.ModePerm); err != nil { + return err + } - for _, direction := range []string{"up", "down"} { - basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext) - filename := filepath.Join(dir, basename) + for _, direction := range []string{"up", "down"} { + basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext) + filename := filepath.Join(dir, basename) - if err = createFile(filename); err != nil { - return err - } + if err = createFile(filename); err != nil { + return err + } - if print { - absPath, _ := filepath.Abs(filename) - log.Println(absPath) - } - } + if print { + absPath, _ := filepath.Abs(filename) + log.Println(absPath) + } + } - return nil + return nil } func createFile(filename string) error { - // create exclusive (fails if file already exists) - // os.Create() specifies 0666 as the FileMode, so we're doing the same - f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666) + // create exclusive (fails if file already exists) + // os.Create() specifies 0666 as the FileMode, so we're doing the same + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666) - if err != nil { - return err - } + if err != nil { + return err + } - return f.Close() + return f.Close() } func gotoCmd(m *migrate.Migrate, v uint) error { - if err := m.Migrate(v); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - return nil + if err := m.Migrate(v); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + return nil } func upCmd(m *migrate.Migrate, limit int) error { - if limit >= 0 { - if err := m.Steps(limit); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } else { - if err := m.Up(); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } - return nil + if limit >= 0 { + if err := m.Steps(limit); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } else { + if err := m.Up(); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } + return nil } func downCmd(m *migrate.Migrate, limit int) error { - if limit >= 0 { - if err := m.Steps(-limit); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } else { - if err := m.Down(); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } - return nil + if limit >= 0 { + if err := m.Steps(-limit); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } else { + if err := m.Down(); err != nil { + if err != migrate.ErrNoChange { + return err + } + log.Println(err) + } + } + return nil } func dropCmd(m *migrate.Migrate) error { - if err := m.Drop(); err != nil { - return err - } - return nil + if err := m.Drop(); err != nil { + return err + } + return nil } func forceCmd(m *migrate.Migrate, v int) error { - if err := m.Force(v); err != nil { - return err - } - return nil + if err := m.Force(v); err != nil { + return err + } + return nil } func versionCmd(m *migrate.Migrate) error { - v, dirty, err := m.Version() - if err != nil { - return err - } - if dirty { - log.Printf("%v (dirty)\n", v) - } else { - log.Println(v) - } - return nil + v, dirty, err := m.Version() + if err != nil { + return err + } + if dirty { + log.Printf("%v (dirty)\n", v) + } else { + log.Println(v) + } + return nil } // numDownMigrationsFromArgs returns an int for number of migrations to apply // and a bool indicating if we need a confirm before applying func numDownMigrationsFromArgs(applyAll bool, args []string) (int, bool, error) { - if applyAll { - if len(args) > 0 { - return 0, false, errors.New("-all cannot be used with other arguments") - } - return -1, false, nil - } - - switch len(args) { - case 0: - return -1, true, nil - case 1: - downValue := args[0] - n, err := strconv.ParseUint(downValue, 10, 64) - if err != nil { - return 0, false, errors.New("can't read limit argument N") - } - return int(n), false, nil - default: - return 0, false, errors.New("too many arguments") - } + if applyAll { + if len(args) > 0 { + return 0, false, errors.New("-all cannot be used with other arguments") + } + return -1, false, nil + } + + switch len(args) { + case 0: + return -1, true, nil + case 1: + downValue := args[0] + n, err := strconv.ParseUint(downValue, 10, 64) + if err != nil { + return 0, false, errors.New("can't read limit argument N") + } + return int(n), false, nil + default: + return 0, false, errors.New("too many arguments") + } } diff --git a/internal/cli/main.go b/internal/cli/main.go index 74459da29..5158de116 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -1,103 +1,96 @@ package cli import ( - "database/sql" - "fmt" - "net/url" - "os" - "os/signal" - "strconv" - "strings" - "syscall" - "time" - - flag "github.com/spf13/pflag" - "github.com/spf13/viper" - - "github.com/golang-migrate/migrate/v4" - "github.com/golang-migrate/migrate/v4/database" - "github.com/golang-migrate/migrate/v4/database/postgres" - "github.com/golang-migrate/migrate/v4/source" + "database/sql" + "fmt" + "net/url" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + "time" + + flag "github.com/spf13/pflag" + "github.com/spf13/viper" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database" + "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source" ) const ( - defaultTimeFormat = "20060102150405" - defaultTimezone = "UTC" - createUsage = `create [-ext E] [-dir D] [-seq] [-digits N] [-format] [-tz] NAME + defaultTimeFormat = "20060102150405" + defaultTimezone = "UTC" + createUsage = `create [-ext E] [-dir D] [-seq] [-digits N] [-format] [-tz] NAME Create a set of timestamped up/down migrations titled NAME, in directory D with extension E. Use -seq option to generate sequential up/down migrations with N digits. Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error. Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC). ` - gotoUsage = `goto V [-dirty] Migrate to version V` - upUsage = `up [N] Apply all or N up migrations` - downUsage = `down [N] [-all] Apply all or N down migrations + gotoUsage = `goto V [-dirty] [-intermediate-path] Migrate to version V` + upUsage = `up [N] Apply all or N up migrations` + downUsage = `down [N] [-all] Apply all or N down migrations Use -all to apply all down migrations` - dropUsage = `drop [-f] Drop everything inside database + dropUsage = `drop [-f] Drop everything inside database Use -f to bypass confirmation` - forceUsage = `force V Set version V but don't run migration (ignores dirty state)` + forceUsage = `force V Set version V but don't run migration (ignores dirty state)` ) func handleSubCmdHelp(help bool, usage string, flagSet *flag.FlagSet) { - if help { - fmt.Fprintln(os.Stderr, usage) - flagSet.PrintDefaults() - os.Exit(0) - } + if help { + fmt.Fprintln(os.Stderr, usage) + flagSet.PrintDefaults() + os.Exit(0) + } } func newFlagSetWithHelp(name string) (*flag.FlagSet, *bool) { - flagSet := flag.NewFlagSet(name, flag.ExitOnError) - helpPtr := flagSet.Bool("help", false, "Print help information") - return flagSet, helpPtr -} - -func newGoToFlagSetWithHelp(name string) (*flag.FlagSet, *bool) { - flagSet := flag.NewFlagSet(name, flag.ExitOnError) - flagSet.Bool("dirty", false, "Migration in dirty state") - helpPtr := flagSet.Bool("help", false, "Print help information") - return flagSet, helpPtr + flagSet := flag.NewFlagSet(name, flag.ExitOnError) + helpPtr := flagSet.Bool("help", false, "Print help information") + return flagSet, helpPtr } // set main log var log = &Log{} func printUsageAndExit() { - flag.Usage() + flag.Usage() - // If a command is not found we exit with a status 2 to match the behavior - // of flag.Parse() with flag.ExitOnError when parsing an invalid flag. - os.Exit(2) + // If a command is not found we exit with a status 2 to match the behavior + // of flag.Parse() with flag.ExitOnError when parsing an invalid flag. + os.Exit(2) } func dbMakeConnectionString(driver, user, password, address, name, ssl string) string { - return fmt.Sprintf("%s://%s:%s@%s/%s?sslmode=%s", - driver, url.QueryEscape(user), url.QueryEscape(password), address, name, ssl, - ) + return fmt.Sprintf("%s://%s:%s@%s/%s?sslmode=%s", + driver, url.QueryEscape(user), url.QueryEscape(password), address, name, ssl, + ) } // Main function of a cli application. It is public for backwards compatibility with `cli` package func Main(version string) { - help := viper.GetBool("help") - version = viper.GetString("version") - verbose := viper.GetBool("verbose") - prefetch := viper.GetInt("prefetch") - lockTimeout := viper.GetInt("lock-timeout") - path := viper.GetString("path") - sourcePtr := viper.GetString("source") - - databasePtr := viper.GetString("database.dsn") - if databasePtr == "" { - databasePtr = dbMakeConnectionString( - viper.GetString("database.driver"), viper.GetString("database.user"), - viper.GetString("database.password"), viper.GetString("database.address"), - viper.GetString("database.name"), viper.GetString("database.ssl"), - ) - } - - flag.Usage = func() { - fmt.Fprintf(os.Stderr, - `Usage: migrate OPTIONS COMMAND [arg...] + help := viper.GetBool("help") + version = viper.GetString("version") + verbose := viper.GetBool("verbose") + prefetch := viper.GetInt("prefetch") + lockTimeout := viper.GetInt("lock-timeout") + path := viper.GetString("path") + sourcePtr := viper.GetString("source") + + databasePtr := viper.GetString("database.dsn") + if databasePtr == "" { + databasePtr = dbMakeConnectionString( + viper.GetString("database.driver"), viper.GetString("database.user"), + viper.GetString("database.password"), viper.GetString("database.address"), + viper.GetString("database.name"), viper.GetString("database.ssl"), + ) + } + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, + `Usage: migrate OPTIONS COMMAND [arg...] migrate [ -version | -help ] Options: @@ -132,323 +125,322 @@ Commands: Source drivers: `+strings.Join(source.List(), ", ")+` Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoUsage, upUsage, downUsage, dropUsage, forceUsage) - } - - // initialize logger - log.verbose = verbose - - // show cli version - if version == "" { - fmt.Fprintln(os.Stderr, version) - os.Exit(0) - } - - // show help - if help { - flag.Usage() - os.Exit(0) - } - - // translate -path into -source if given - if sourcePtr == "" && path != "" { - sourcePtr = fmt.Sprintf("file://%v", path) - } - - // initialize migrate - // don't catch migraterErr here and let each command decide - // how it wants to handle the error - var migrater *migrate.Migrate - var migraterErr error - - if driver := viper.GetString("database.driver"); driver == "hotload" { - db, err := sql.Open(driver, databasePtr) - if err != nil { - log.fatalErr(fmt.Errorf("could not open hotload dsn %s: %s", databasePtr, err)) - } - var dbname, user string - if err := db.QueryRow("SELECT current_database(), user").Scan(&dbname, &user); err != nil { - log.fatalErr(fmt.Errorf("could not get current_database: %s", err.Error())) - } - // dbname is not needed since it gets filled in by the driver but we want to be complete - migrateDriver, err := postgres.WithInstance(db, &postgres.Config{DatabaseName: dbname}) - if err != nil { - log.fatalErr(fmt.Errorf("could not create migrate driver: %s", err)) - } - migrater, migraterErr = migrate.NewWithDatabaseInstance(sourcePtr, dbname, migrateDriver) - } else { - migrater, migraterErr = migrate.New(sourcePtr, databasePtr) - } - defer func() { - if migraterErr == nil { - if _, err := migrater.Close(); err != nil { - log.Println(err) - } - } - }() - if migraterErr == nil { - migrater.Log = log - migrater.PrefetchMigrations = uint(prefetch) - migrater.LockTimeout = time.Duration(int64(lockTimeout)) * time.Second - - // handle Ctrl+c - signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGINT) - go func() { - for range signals { - log.Println("Stopping after this running migration ...") - migrater.GracefulStop <- true - return - } - }() - } - - startTime := time.Now() - - if len(flag.Args()) < 1 { - printUsageAndExit() - } - args := flag.Args()[1:] - - switch flag.Arg(0) { - case "create": - - seq := false - seqDigits := 6 - - createFlagSet, help := newFlagSetWithHelp("create") - extPtr := createFlagSet.String("ext", "", "File extension") - dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)") - formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`) - timezoneName := createFlagSet.String("tz", defaultTimezone, `The timezone that will be used for generating timestamps (default: utc)`) - createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)") - createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)") - - if err := createFlagSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*help, createUsage, createFlagSet) - - if createFlagSet.NArg() == 0 { - log.fatal("error: please specify name") - } - name := createFlagSet.Arg(0) - - if *extPtr == "" { - log.fatal("error: -ext flag must be specified") - } - - timezone, err := time.LoadLocation(*timezoneName) - if err != nil { - log.fatal(err) - } - - if err := createCmd(*dirPtr, startTime.In(timezone), *formatPtr, name, *extPtr, seq, seqDigits, true); err != nil { - log.fatalErr(err) - } - - case "goto": - - gotoSet, helpPtr := newFlagSetWithHelp("goto") - - if err := gotoSet.Parse(args); err != nil { - log.fatalErr(err) - } - handleSubCmdHelp(*helpPtr, gotoUsage, gotoSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if gotoSet.NArg() == 0 { - log.fatal("error: please specify version argument V") - } - - v, err := strconv.ParseUint(gotoSet.Arg(0), 10, 64) - if err != nil { - log.fatal("error: can't read version argument V") - } - handleDirty := viper.GetBool("dirty") - destPath := viper.GetString("intermediate-path") - srcPath := "" - // if sourcePtr is set, use it to get the source path - // otherwise, use the path flag - if path != "" { - srcPath = path - } - if sourcePtr != "" { - // parse the source path from the source argument - parse, err := url.Parse(sourcePtr) - if err != nil { - log.fatal("error: can't parse the source path from the source argument") - } - srcPath = parse.Path - } - - if handleDirty && destPath == "" { - log.fatal("error: intermediate-path must be specified when dirty is set") - } - log.Printf("running goto with handleDirty: %t, destPath: %s, srcPath: %s\n", handleDirty, destPath, srcPath) - migrater.WithDirtyStateHandler(srcPath, destPath, handleDirty) - if err = gotoCmd(migrater, uint(v)); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "up": - upSet, helpPtr := newFlagSetWithHelp("up") - - if err := upSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*helpPtr, upUsage, upSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - limit := -1 - if upSet.NArg() > 0 { - n, err := strconv.ParseUint(upSet.Arg(0), 10, 64) - if err != nil { - log.fatal("error: can't read limit argument N") - } - limit = int(n) - } - - if err := upCmd(migrater, limit); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "down": - downFlagSet, helpPtr := newFlagSetWithHelp("down") - applyAll := downFlagSet.Bool("all", false, "Apply all down migrations") - if err := downFlagSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*helpPtr, downUsage, downFlagSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - downArgs := downFlagSet.Args() - - log.Println(*applyAll, downArgs) - - num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs) - if err != nil { - log.fatalErr(err) - } - if needsConfirm { - log.Println("Are you sure you want to apply all down migrations? [y/N]") - var response string - _, _ = fmt.Scanln(&response) - response = strings.ToLower(strings.TrimSpace(response)) - - if response == "y" { - log.Println("Applying all down migrations") - } else { - log.fatal("Not applying all down migrations") - } - } - - if err := downCmd(migrater, num); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "drop": - dropFlagSet, help := newFlagSetWithHelp("drop") - forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt") - - if err := dropFlagSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*help, dropUsage, dropFlagSet) - - if !*forceDrop { - log.Println("Are you sure you want to drop the entire database schema? [y/N]") - var response string - _, _ = fmt.Scanln(&response) - response = strings.ToLower(strings.TrimSpace(response)) - - if response == "y" { - log.Println("Dropping the entire database schema") - } else { - log.fatal("Aborted dropping the entire database schema") - } - } - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if err := dropCmd(migrater); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "force": - forceSet, helpPtr := newFlagSetWithHelp("force") - - if err := forceSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*helpPtr, forceUsage, forceSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if forceSet.NArg() == 0 { - log.fatal("error: please specify version argument V") - } - - v, err := strconv.ParseInt(forceSet.Arg(0), 10, 64) - if err != nil { - log.fatal("error: can't read version argument V") - } - - if v < -1 { - log.fatal("error: argument V must be >= -1") - } - - if err := forceCmd(migrater, int(v)); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } - - case "version": - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if err := versionCmd(migrater); err != nil { - log.fatalErr(err) - } - - default: - printUsageAndExit() - } + } + + // initialize logger + log.verbose = verbose + + // show cli version + if version == "" { + fmt.Fprintln(os.Stderr, version) + os.Exit(0) + } + + // show help + if help { + flag.Usage() + os.Exit(0) + } + + // translate -path into -source if given + if sourcePtr == "" && path != "" { + sourcePtr = fmt.Sprintf("file://%v", path) + } + + // initialize migrate + // don't catch migraterErr here and let each command decide + // how it wants to handle the error + var migrater *migrate.Migrate + var migraterErr error + + if driver := viper.GetString("database.driver"); driver == "hotload" { + db, err := sql.Open(driver, databasePtr) + if err != nil { + log.fatalErr(fmt.Errorf("could not open hotload dsn %s: %s", databasePtr, err)) + } + var dbname, user string + if err := db.QueryRow("SELECT current_database(), user").Scan(&dbname, &user); err != nil { + log.fatalErr(fmt.Errorf("could not get current_database: %s", err.Error())) + } + // dbname is not needed since it gets filled in by the driver but we want to be complete + migrateDriver, err := postgres.WithInstance(db, &postgres.Config{DatabaseName: dbname}) + if err != nil { + log.fatalErr(fmt.Errorf("could not create migrate driver: %s", err)) + } + migrater, migraterErr = migrate.NewWithDatabaseInstance(sourcePtr, dbname, migrateDriver) + } else { + migrater, migraterErr = migrate.New(sourcePtr, databasePtr) + } + defer func() { + if migraterErr == nil { + if _, err := migrater.Close(); err != nil { + log.Println(err) + } + } + }() + if migraterErr == nil { + migrater.Log = log + migrater.PrefetchMigrations = uint(prefetch) + migrater.LockTimeout = time.Duration(int64(lockTimeout)) * time.Second + + // handle Ctrl+c + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGINT) + go func() { + for range signals { + log.Println("Stopping after this running migration ...") + migrater.GracefulStop <- true + return + } + }() + } + + startTime := time.Now() + + if len(flag.Args()) < 1 { + printUsageAndExit() + } + args := flag.Args()[1:] + + switch flag.Arg(0) { + case "create": + + seq := false + seqDigits := 6 + + createFlagSet, help := newFlagSetWithHelp("create") + extPtr := createFlagSet.String("ext", "", "File extension") + dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)") + formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`) + timezoneName := createFlagSet.String("tz", defaultTimezone, `The timezone that will be used for generating timestamps (default: utc)`) + createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)") + createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)") + + if err := createFlagSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*help, createUsage, createFlagSet) + + if createFlagSet.NArg() == 0 { + log.fatal("error: please specify name") + } + name := createFlagSet.Arg(0) + + if *extPtr == "" { + log.fatal("error: -ext flag must be specified") + } + + timezone, err := time.LoadLocation(*timezoneName) + if err != nil { + log.fatal(err) + } + + if err := createCmd(*dirPtr, startTime.In(timezone), *formatPtr, name, *extPtr, seq, seqDigits, true); err != nil { + log.fatalErr(err) + } + + case "goto": + + gotoSet, helpPtr := newFlagSetWithHelp("goto") + + if err := gotoSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*helpPtr, gotoUsage, gotoSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if gotoSet.NArg() == 0 { + log.fatal("error: please specify version argument V") + } + + v, err := strconv.ParseUint(gotoSet.Arg(0), 10, 64) + if err != nil { + log.fatal("error: can't read version argument V") + } + handleDirty := viper.GetBool("dirty") + destPath := viper.GetString("intermediate-path") + srcPath := "" + // if sourcePtr is set, use it to get the source path + // otherwise, use the path flag + if path != "" { + srcPath = path + } + if sourcePtr != "" { + // parse the source path from the source argument + parse, err := url.Parse(sourcePtr) + if err != nil { + log.fatal("error: can't parse the source path from the source argument") + } + srcPath = parse.Path + } + + if handleDirty && destPath == "" { + log.fatal("error: intermediate-path must be specified when dirty is set") + } + + migrater.WithDirtyStateHandler(srcPath, destPath, handleDirty) + if err = gotoCmd(migrater, uint(v)); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "up": + upSet, helpPtr := newFlagSetWithHelp("up") + + if err := upSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*helpPtr, upUsage, upSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + limit := -1 + if upSet.NArg() > 0 { + n, err := strconv.ParseUint(upSet.Arg(0), 10, 64) + if err != nil { + log.fatal("error: can't read limit argument N") + } + limit = int(n) + } + + if err := upCmd(migrater, limit); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "down": + downFlagSet, helpPtr := newFlagSetWithHelp("down") + applyAll := downFlagSet.Bool("all", false, "Apply all down migrations") + + if err := downFlagSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*helpPtr, downUsage, downFlagSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + downArgs := downFlagSet.Args() + num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs) + if err != nil { + log.fatalErr(err) + } + if needsConfirm { + log.Println("Are you sure you want to apply all down migrations? [y/N]") + var response string + _, _ = fmt.Scanln(&response) + response = strings.ToLower(strings.TrimSpace(response)) + + if response == "y" { + log.Println("Applying all down migrations") + } else { + log.fatal("Not applying all down migrations") + } + } + + if err := downCmd(migrater, num); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "drop": + dropFlagSet, help := newFlagSetWithHelp("drop") + forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt") + + if err := dropFlagSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*help, dropUsage, dropFlagSet) + + if !*forceDrop { + log.Println("Are you sure you want to drop the entire database schema? [y/N]") + var response string + _, _ = fmt.Scanln(&response) + response = strings.ToLower(strings.TrimSpace(response)) + + if response == "y" { + log.Println("Dropping the entire database schema") + } else { + log.fatal("Aborted dropping the entire database schema") + } + } + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if err := dropCmd(migrater); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "force": + forceSet, helpPtr := newFlagSetWithHelp("force") + + if err := forceSet.Parse(args); err != nil { + log.fatalErr(err) + } + + handleSubCmdHelp(*helpPtr, forceUsage, forceSet) + + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if forceSet.NArg() == 0 { + log.fatal("error: please specify version argument V") + } + + v, err := strconv.ParseInt(forceSet.Arg(0), 10, 64) + if err != nil { + log.fatal("error: can't read version argument V") + } + + if v < -1 { + log.fatal("error: argument V must be >= -1") + } + + if err := forceCmd(migrater, int(v)); err != nil { + log.fatalErr(err) + } + + if log.verbose { + log.Println("Finished after", time.Since(startTime)) + } + + case "version": + if migraterErr != nil { + log.fatalErr(migraterErr) + } + + if err := versionCmd(migrater); err != nil { + log.fatalErr(err) + } + + default: + printUsageAndExit() + } } diff --git a/migrate.go b/migrate.go index e7631c843..0793924c1 100644 --- a/migrate.go +++ b/migrate.go @@ -5,17 +5,21 @@ package migrate import ( - "errors" - "fmt" - "os" - "sync" - "time" - - "github.com/hashicorp/go-multierror" - - "github.com/golang-migrate/migrate/v4/database" - iurl "github.com/golang-migrate/migrate/v4/internal/url" - "github.com/golang-migrate/migrate/v4/source" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + + "github.com/golang-migrate/migrate/v4/database" + iurl "github.com/golang-migrate/migrate/v4/internal/url" + "github.com/golang-migrate/migrate/v4/source" ) // DefaultPrefetchMigrations sets the number of migrations to pre-read @@ -29,107 +33,110 @@ var DefaultPrefetchMigrations = uint(10) var DefaultLockTimeout = 15 * time.Second var ( - ErrNoChange = errors.New("no change") - ErrNilVersion = errors.New("no migration") - ErrInvalidVersion = errors.New("version must be >= -1") - ErrLocked = errors.New("database locked") - ErrLockTimeout = errors.New("timeout: can't acquire database lock") + ErrNoChange = errors.New("no change") + ErrNilVersion = errors.New("no migration") + ErrInvalidVersion = errors.New("version must be >= -1") + ErrLocked = errors.New("database locked") + ErrLockTimeout = errors.New("timeout: can't acquire database lock") ) +// Define a constant for the migration file name +const lastSuccessfulMigrationFile = "lastSuccessfulMigration" + // ErrShortLimit is an error returned when not enough migrations // can be returned by a source for a given limit. type ErrShortLimit struct { - Short uint + Short uint } // Error implements the error interface. func (e ErrShortLimit) Error() string { - return fmt.Sprintf("limit %v short", e.Short) + return fmt.Sprintf("limit %v short", e.Short) } type ErrDirty struct { - Version int + Version int } func (e ErrDirty) Error() string { - return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version) + return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version) } type Migrate struct { - sourceName string - sourceDrv source.Driver - databaseName string - databaseDrv database.Driver + sourceName string + sourceDrv source.Driver + databaseName string + databaseDrv database.Driver - // Log accepts a Logger interface - Log Logger + // Log accepts a Logger interface + Log Logger - // GracefulStop accepts `true` and will stop executing migrations - // as soon as possible at a safe break point, so that the database - // is not corrupted. - GracefulStop chan bool - isLockedMu *sync.Mutex + // GracefulStop accepts `true` and will stop executing migrations + // as soon as possible at a safe break point, so that the database + // is not corrupted. + GracefulStop chan bool + isLockedMu *sync.Mutex - isGracefulStop bool - isLocked bool + isGracefulStop bool + isLocked bool - // PrefetchMigrations defaults to DefaultPrefetchMigrations, - // but can be set per Migrate instance. - PrefetchMigrations uint + // PrefetchMigrations defaults to DefaultPrefetchMigrations, + // but can be set per Migrate instance. + PrefetchMigrations uint - // LockTimeout defaults to DefaultLockTimeout, - // but can be set per Migrate instance. - LockTimeout time.Duration + // LockTimeout defaults to DefaultLockTimeout, + // but can be set per Migrate instance. + LockTimeout time.Duration - // DirtyStateHandler is used to handle dirty state of the database - ds *dirtyStateHandler + // DirtyStateHandler is used to handle dirty state of the database + ds *dirtyStateHandler } type dirtyStateHandler struct { - srcPath string - destPath string - isDirty bool + srcPath string + destPath string + isDirty bool } // New returns a new Migrate instance from a source URL and a database URL. // The URL scheme is defined by each driver. func New(sourceURL, databaseURL string) (*Migrate, error) { - m := newCommon() - - sourceName, err := iurl.SchemeFromURL(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from source URL: %w", err) - } - m.sourceName = sourceName - - databaseName, err := iurl.SchemeFromURL(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) - } - m.databaseName = databaseName - - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv - - databaseDrv, err := database.Open(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) - } - m.databaseDrv = databaseDrv - - return m, nil + m := newCommon() + + sourceName, err := iurl.SchemeFromURL(sourceURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from source URL: %w", err) + } + m.sourceName = sourceName + + databaseName, err := iurl.SchemeFromURL(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) + } + m.databaseName = databaseName + + sourceDrv, err := source.Open(sourceURL) + if err != nil { + return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) + } + m.sourceDrv = sourceDrv + + databaseDrv, err := database.Open(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) + } + m.databaseDrv = databaseDrv + + return m, nil } func (m *Migrate) updateSourceDrv(sourceURL string) error { - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv - return nil + sourceDrv, err := source.Open(sourceURL) + if err != nil { + return fmt.Errorf("failed to open source, %q: %w", sourceURL, err) + } + m.sourceDrv = sourceDrv + return nil } // NewWithDatabaseInstance returns a new Migrate instance from a source URL @@ -137,25 +144,25 @@ func (m *Migrate) updateSourceDrv(sourceURL string) error { // Use any string that can serve as an identifier during logging as databaseName. // You are responsible for closing the underlying database client if necessary. func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() + m := newCommon() - sourceName, err := iurl.SchemeFromURL(sourceURL) - if err != nil { - return nil, err - } - m.sourceName = sourceName + sourceName, err := iurl.SchemeFromURL(sourceURL) + if err != nil { + return nil, err + } + m.sourceName = sourceName - m.databaseName = databaseName + m.databaseName = databaseName - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv + sourceDrv, err := source.Open(sourceURL) + if err != nil { + return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) + } + m.sourceDrv = sourceDrv - m.databaseDrv = databaseInstance + m.databaseDrv = databaseInstance - return m, nil + return m, nil } // NewWithSourceInstance returns a new Migrate instance from an existing source instance @@ -163,25 +170,25 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst // Use any string that can serve as an identifier during logging as sourceName. // You are responsible for closing the underlying source client if necessary. func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) { - m := newCommon() + m := newCommon() - databaseName, err := iurl.SchemeFromURL(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) - } - m.databaseName = databaseName + databaseName, err := iurl.SchemeFromURL(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) + } + m.databaseName = databaseName - m.sourceName = sourceName + m.sourceName = sourceName - databaseDrv, err := database.Open(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) - } - m.databaseDrv = databaseDrv + databaseDrv, err := database.Open(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) + } + m.databaseDrv = databaseDrv - m.sourceDrv = sourceInstance + m.sourceDrv = sourceInstance - return m, nil + return m, nil } // NewWithInstance returns a new Migrate instance from an existing source and @@ -189,191 +196,194 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data // as sourceName and databaseName. You are responsible for closing down // the underlying source and database client if necessary. func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() + m := newCommon() - m.sourceName = sourceName - m.databaseName = databaseName + m.sourceName = sourceName + m.databaseName = databaseName - m.sourceDrv = sourceInstance - m.databaseDrv = databaseInstance + m.sourceDrv = sourceInstance + m.databaseDrv = databaseInstance - return m, nil + return m, nil } func (m *Migrate) WithDirtyStateHandler(srcPath, destPath string, isDirty bool) { - m.ds = &dirtyStateHandler{ - srcPath: srcPath, - destPath: destPath, - isDirty: isDirty, - } + m.ds = &dirtyStateHandler{ + srcPath: srcPath, + destPath: destPath, + isDirty: isDirty, + } } func newCommon() *Migrate { - return &Migrate{ - GracefulStop: make(chan bool, 1), - PrefetchMigrations: DefaultPrefetchMigrations, - LockTimeout: DefaultLockTimeout, - isLockedMu: &sync.Mutex{}, - } + return &Migrate{ + GracefulStop: make(chan bool, 1), + PrefetchMigrations: DefaultPrefetchMigrations, + LockTimeout: DefaultLockTimeout, + isLockedMu: &sync.Mutex{}, + } } // Close closes the source and the database. func (m *Migrate) Close() (source error, database error) { - databaseSrvClose := make(chan error) - sourceSrvClose := make(chan error) + databaseSrvClose := make(chan error) + sourceSrvClose := make(chan error) - m.logVerbosePrintf("Closing source and database\n") + m.logVerbosePrintf("Closing source and database\n") - go func() { - databaseSrvClose <- m.databaseDrv.Close() - }() + go func() { + databaseSrvClose <- m.databaseDrv.Close() + }() - go func() { - sourceSrvClose <- m.sourceDrv.Close() - }() + go func() { + sourceSrvClose <- m.sourceDrv.Close() + }() - return <-sourceSrvClose, <-databaseSrvClose + return <-sourceSrvClose, <-databaseSrvClose } // Migrate looks at the currently active migration version, // then migrates either up or down to the specified version. func (m *Migrate) Migrate(version uint) error { - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - m.Log.Printf("******************Failed to get current version: %v\n", err) - return err - } - - if err = m.CopyFiles(); err != nil { - return err - } - - m.Log.Printf("Current version: %d, dirty: %t\n", curVersion, dirty) - // if the dirty flag is passed to the 'goto' command, handle the dirty state - if dirty { - if m.ds.isDirty { - m.Log.Printf("Version: %d, handle dirty: %t\n", version, m.ds.isDirty) - if err = m.HandleDirtyState(); err != nil { - return err - } - if err = m.updateSourceDrv(fmt.Sprintf("file://%s", m.ds.destPath)); err != nil { - return err - } - - } else { - // default behaviour - m.Log.Printf("Database is set to dirty for version: %v\n", curVersion) - return ErrDirty{curVersion} - } - } - - if err = m.lock(); err != nil { - return err - } - - ret := make(chan interface{}, m.PrefetchMigrations) - go m.read(curVersion, int(version), ret) - - err = m.runMigrations(ret) - if err != nil { - if m.ds.isDirty { - // Handle failure: store last successful migration version and exit - if err = m.HandleMigrationFailure(curVersion, version); err != nil { - return err - } - } - return m.unlockErr(err) - } - // Success: Clean up and confirm - if err = m.CleanupFiles(version); err != nil { - return m.unlockErr(err) - } - return nil + if err := m.lock(); err != nil { + return err + } + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } + + // if the dirty flag is passed to the 'goto' command, handle the dirty state + if dirty { + if m.ds != nil && m.ds.isDirty { + if err = m.unlock(); err != nil { + return m.unlockErr(err) + } + if err = m.HandleDirtyState(); err != nil { + return m.unlockErr(err) + } + if err = m.lock(); err != nil { + return err + } + if err = m.updateSourceDrv("file://" + m.ds.destPath); err != nil { + return m.unlockErr(err) + } + + } else { + // default behaviour + return m.unlockErr(ErrDirty{curVersion}) + } + } + + // Copy migrations to the destination directory, + // if state was dirty when Migrate was called, we should handle the dirty state first before copying the migrations + if err = m.CopyFiles(); err != nil { + return m.unlockErr(err) + } + + ret := make(chan interface{}, m.PrefetchMigrations) + go m.read(curVersion, int(version), ret) + + if err = m.runMigrations(ret); err != nil { + if m.ds != nil && m.ds.isDirty { + // Handle failure: store last successful migration version and exit + if err = m.HandleMigrationFailure(curVersion, version); err != nil { + return m.unlockErr(err) + } + } + return m.unlockErr(err) + } + // Success: Clean up and confirm + if err = m.CleanupFiles(version); err != nil { + return m.unlockErr(err) + } + // unlock the database + return m.unlock() } // Steps looks at the currently active migration version. // It will migrate up if n > 0, and down if n < 0. func (m *Migrate) Steps(n int) error { - if n == 0 { - return ErrNoChange - } + if n == 0 { + return ErrNoChange + } - if err := m.lock(); err != nil { - return err - } + if err := m.lock(); err != nil { + return err + } - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } - ret := make(chan interface{}, m.PrefetchMigrations) + ret := make(chan interface{}, m.PrefetchMigrations) - if n > 0 { - go m.readUp(curVersion, n, ret) - } else { - go m.readDown(curVersion, -n, ret) - } + if n > 0 { + go m.readUp(curVersion, n, ret) + } else { + go m.readDown(curVersion, -n, ret) + } - return m.unlockErr(m.runMigrations(ret)) + return m.unlockErr(m.runMigrations(ret)) } // Up looks at the currently active migration version // and will migrate all the way up (applying all up migrations). func (m *Migrate) Up() error { - if err := m.lock(); err != nil { - return err - } + if err := m.lock(); err != nil { + return err + } - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } - ret := make(chan interface{}, m.PrefetchMigrations) + ret := make(chan interface{}, m.PrefetchMigrations) - go m.readUp(curVersion, -1, ret) - return m.unlockErr(m.runMigrations(ret)) + go m.readUp(curVersion, -1, ret) + return m.unlockErr(m.runMigrations(ret)) } // Down looks at the currently active migration version // and will migrate all the way down (applying all down migrations). func (m *Migrate) Down() error { - if err := m.lock(); err != nil { - return err - } - - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } - - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } - - ret := make(chan interface{}, m.PrefetchMigrations) - go m.readDown(curVersion, -1, ret) - return m.unlockErr(m.runMigrations(ret)) + if err := m.lock(); err != nil { + return err + } + + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } + + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } + + ret := make(chan interface{}, m.PrefetchMigrations) + go m.readDown(curVersion, -1, ret) + return m.unlockErr(m.runMigrations(ret)) } // Drop deletes everything in the database. func (m *Migrate) Drop() error { - if err := m.lock(); err != nil { - return err - } - if err := m.databaseDrv.Drop(); err != nil { - return m.unlockErr(err) - } - return m.unlock() + if err := m.lock(); err != nil { + return err + } + if err := m.databaseDrv.Drop(); err != nil { + return m.unlockErr(err) + } + return m.unlock() } // Run runs any migration provided by you against the database. @@ -381,78 +391,78 @@ func (m *Migrate) Drop() error { // Usually you don't need this function at all. Use Migrate, // Steps, Up or Down instead. func (m *Migrate) Run(migration ...*Migration) error { - if len(migration) == 0 { - return ErrNoChange - } - - if err := m.lock(); err != nil { - return err - } - - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } - - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } - - ret := make(chan interface{}, m.PrefetchMigrations) - - go func() { - defer close(ret) - for _, migr := range migration { - if m.PrefetchMigrations > 0 && migr.Body != nil { - m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) - } else { - m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) - } - - ret <- migr - go func(migr *Migration) { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }(migr) - } - }() - - return m.unlockErr(m.runMigrations(ret)) + if len(migration) == 0 { + return ErrNoChange + } + + if err := m.lock(); err != nil { + return err + } + + curVersion, dirty, err := m.databaseDrv.Version() + if err != nil { + return m.unlockErr(err) + } + + if dirty { + return m.unlockErr(ErrDirty{curVersion}) + } + + ret := make(chan interface{}, m.PrefetchMigrations) + + go func() { + defer close(ret) + for _, migr := range migration { + if m.PrefetchMigrations > 0 && migr.Body != nil { + m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) + } else { + m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) + } + + ret <- migr + go func(migr *Migration) { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }(migr) + } + }() + + return m.unlockErr(m.runMigrations(ret)) } // Force sets a migration version. // It does not check any currently active version in database. // It resets the dirty state to false. func (m *Migrate) Force(version int) error { - if version < -1 { - return ErrInvalidVersion - } + if version < -1 { + return ErrInvalidVersion + } - if err := m.lock(); err != nil { - return err - } + if err := m.lock(); err != nil { + return err + } - if err := m.databaseDrv.SetVersion(version, false); err != nil { - return m.unlockErr(err) - } + if err := m.databaseDrv.SetVersion(version, false); err != nil { + return m.unlockErr(err) + } - return m.unlock() + return m.unlock() } // Version returns the currently active migration version. // If no migration has been applied, yet, it will return ErrNilVersion. func (m *Migrate) Version() (version uint, dirty bool, err error) { - v, d, err := m.databaseDrv.Version() - if err != nil { - return 0, false, err - } + v, d, err := m.databaseDrv.Version() + if err != nil { + return 0, false, err + } - if v == database.NilVersion { - return 0, false, ErrNilVersion - } + if v == database.NilVersion { + return 0, false, ErrNilVersion + } - return suint(v), d, nil + return suint(v), d, nil } // read reads either up or down migrations from source `from` to `to`. @@ -460,130 +470,130 @@ func (m *Migrate) Version() (version uint, dirty bool, err error) { // If an error occurs during reading, that error is written to the ret channel, too. // Once read is done reading it will close the ret channel. func (m *Migrate) read(from int, to int, ret chan<- interface{}) { - defer close(ret) - - // check if from version exists - if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { - ret <- err - return - } - } - - // check if to version exists - if to >= 0 { - if err := m.versionExists(suint(to)); err != nil { - ret <- err - return - } - } - - // no change? - if from == to { - ret <- ErrNoChange - return - } - - if from < to { - // it's going up - // apply first migration if from is nil version - if from == -1 { - firstVersion, err := m.sourceDrv.First() - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(firstVersion, int(firstVersion)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - from = int(firstVersion) - } - - // run until we reach target ... - for from < to { - if m.stop() { - return - } - - next, err := m.sourceDrv.Next(suint(from)) - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(next, int(next)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - from = int(next) - } - - } else { - // it's going down - // run until we reach target ... - for from > to && from >= 0 { - if m.stop() { - return - } - - prev, err := m.sourceDrv.Prev(suint(from)) - if errors.Is(err, os.ErrNotExist) && to == -1 { - // apply nil migration - migr, err := m.newMigration(suint(from), -1) - if err != nil { - ret <- err - return - } - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - return - - } else if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(suint(from), int(prev)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - - from = int(prev) - } - } + defer close(ret) + + // check if from version exists + if from >= 0 { + if err := m.versionExists(suint(from)); err != nil { + ret <- err + return + } + } + + // check if to version exists + if to >= 0 { + if err := m.versionExists(suint(to)); err != nil { + ret <- err + return + } + } + + // no change? + if from == to { + ret <- ErrNoChange + return + } + + if from < to { + // it's going up + // apply first migration if from is nil version + if from == -1 { + firstVersion, err := m.sourceDrv.First() + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(firstVersion, int(firstVersion)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + from = int(firstVersion) + } + + // run until we reach target ... + for from < to { + if m.stop() { + return + } + + next, err := m.sourceDrv.Next(suint(from)) + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(next, int(next)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + from = int(next) + } + + } else { + // it's going down + // run until we reach target ... + for from > to && from >= 0 { + if m.stop() { + return + } + + prev, err := m.sourceDrv.Prev(suint(from)) + if errors.Is(err, os.ErrNotExist) && to == -1 { + // apply nil migration + migr, err := m.newMigration(suint(from), -1) + if err != nil { + ret <- err + return + } + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + return + + } else if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(suint(from), int(prev)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + + from = int(prev) + } + } } // readUp reads up migrations from `from` limitted by `limit`. @@ -592,98 +602,98 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) { // If an error occurs during reading, that error is written to the ret channel, too. // Once readUp is done reading it will close the ret channel. func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { - defer close(ret) - - // check if from version exists - if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { - ret <- err - return - } - } - - if limit == 0 { - ret <- ErrNoChange - return - } - - count := 0 - for count < limit || limit == -1 { - if m.stop() { - return - } - - // apply first migration if from is nil version - if from == -1 { - firstVersion, err := m.sourceDrv.First() - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(firstVersion, int(firstVersion)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - from = int(firstVersion) - count++ - continue - } - - // apply next migration - next, err := m.sourceDrv.Next(suint(from)) - if errors.Is(err, os.ErrNotExist) { - // no limit, but no migrations applied? - if limit == -1 && count == 0 { - ret <- ErrNoChange - return - } - - // no limit, reached end - if limit == -1 { - return - } - - // reached end, and didn't apply any migrations - if limit > 0 && count == 0 { - ret <- os.ErrNotExist - return - } - - // applied less migrations than limit? - if count < limit { - ret <- ErrShortLimit{suint(limit - count)} - return - } - } - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(next, int(next)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - from = int(next) - count++ - } + defer close(ret) + + // check if from version exists + if from >= 0 { + if err := m.versionExists(suint(from)); err != nil { + ret <- err + return + } + } + + if limit == 0 { + ret <- ErrNoChange + return + } + + count := 0 + for count < limit || limit == -1 { + if m.stop() { + return + } + + // apply first migration if from is nil version + if from == -1 { + firstVersion, err := m.sourceDrv.First() + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(firstVersion, int(firstVersion)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + from = int(firstVersion) + count++ + continue + } + + // apply next migration + next, err := m.sourceDrv.Next(suint(from)) + if errors.Is(err, os.ErrNotExist) { + // no limit, but no migrations applied? + if limit == -1 && count == 0 { + ret <- ErrNoChange + return + } + + // no limit, reached end + if limit == -1 { + return + } + + // reached end, and didn't apply any migrations + if limit > 0 && count == 0 { + ret <- os.ErrNotExist + return + } + + // applied less migrations than limit? + if count < limit { + ret <- ErrShortLimit{suint(limit - count)} + return + } + } + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(next, int(next)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + from = int(next) + count++ + } } // readDown reads down migrations from `from` limitted by `limit`. @@ -692,88 +702,88 @@ func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { // If an error occurs during reading, that error is written to the ret channel, too. // Once readDown is done reading it will close the ret channel. func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { - defer close(ret) - - // check if from version exists - if from >= 0 { - if err := m.versionExists(suint(from)); err != nil { - ret <- err - return - } - } - - if limit == 0 { - ret <- ErrNoChange - return - } - - // no change if already at nil version - if from == -1 && limit == -1 { - ret <- ErrNoChange - return - } - - // can't go over limit if already at nil version - if from == -1 && limit > 0 { - ret <- os.ErrNotExist - return - } - - count := 0 - for count < limit || limit == -1 { - if m.stop() { - return - } - - prev, err := m.sourceDrv.Prev(suint(from)) - if errors.Is(err, os.ErrNotExist) { - // no limit or haven't reached limit, apply "first" migration - if limit == -1 || limit-count > 0 { - firstVersion, err := m.sourceDrv.First() - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(firstVersion, -1) - if err != nil { - ret <- err - return - } - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - count++ - } - - if count < limit { - ret <- ErrShortLimit{suint(limit - count)} - } - return - } - if err != nil { - ret <- err - return - } - - migr, err := m.newMigration(suint(from), int(prev)) - if err != nil { - ret <- err - return - } - - ret <- migr - go func() { - if err := migr.Buffer(); err != nil { - m.logErr(err) - } - }() - from = int(prev) - count++ - } + defer close(ret) + + // check if from version exists + if from >= 0 { + if err := m.versionExists(suint(from)); err != nil { + ret <- err + return + } + } + + if limit == 0 { + ret <- ErrNoChange + return + } + + // no change if already at nil version + if from == -1 && limit == -1 { + ret <- ErrNoChange + return + } + + // can't go over limit if already at nil version + if from == -1 && limit > 0 { + ret <- os.ErrNotExist + return + } + + count := 0 + for count < limit || limit == -1 { + if m.stop() { + return + } + + prev, err := m.sourceDrv.Prev(suint(from)) + if errors.Is(err, os.ErrNotExist) { + // no limit or haven't reached limit, apply "first" migration + if limit == -1 || limit-count > 0 { + firstVersion, err := m.sourceDrv.First() + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(firstVersion, -1) + if err != nil { + ret <- err + return + } + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + count++ + } + + if count < limit { + ret <- ErrShortLimit{suint(limit - count)} + } + return + } + if err != nil { + ret <- err + return + } + + migr, err := m.newMigration(suint(from), int(prev)) + if err != nil { + ret <- err + return + } + + ret <- migr + go func() { + if err := migr.Buffer(); err != nil { + m.logErr(err) + } + }() + from = int(prev) + count++ + } } // runMigrations reads *Migration and error from a channel. Any other type @@ -783,260 +793,418 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { // to stop execution because it might have received a stop signal on the // GracefulStop channel. func (m *Migrate) runMigrations(ret <-chan interface{}) error { - m.Log.Printf("Starting %s migrations\n", m.sourceDrv) - for r := range ret { - - if m.stop() { - return nil - } - - switch r := r.(type) { - case error: - return r - - case *Migration: - migr := r - - // set version with dirty state - if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil { - return err - } - - if migr.Body != nil { - m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) - if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { - return err - } - } - - // set clean state - if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil { - return err - } - - endTime := time.Now() - readTime := migr.FinishedReading.Sub(migr.StartedBuffering) - runTime := endTime.Sub(migr.FinishedReading) - - // log either verbose or normal - if m.Log != nil { - if m.Log.Verbose() { - m.logPrintf("Finished %v (read %v, ran %v)\n", migr.LogString(), readTime, runTime) - } else { - m.logPrintf("%v (%v)\n", migr.LogString(), readTime+runTime) - } - } - - default: - return fmt.Errorf("unknown type: %T with value: %+v", r, r) - } - } - return nil + for r := range ret { + + if m.stop() { + return nil + } + + switch r := r.(type) { + case error: + return r + + case *Migration: + migr := r + + // set version with dirty state + if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil { + return err + } + + if migr.Body != nil { + m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) + if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { + return err + } + } + + // set clean state + if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil { + return err + } + + endTime := time.Now() + readTime := migr.FinishedReading.Sub(migr.StartedBuffering) + runTime := endTime.Sub(migr.FinishedReading) + + // log either verbose or normal + if m.Log != nil { + if m.Log.Verbose() { + m.logPrintf("Finished %v (read %v, ran %v)\n", migr.LogString(), readTime, runTime) + } else { + m.logPrintf("%v (%v)\n", migr.LogString(), readTime+runTime) + } + } + + default: + return fmt.Errorf("unknown type: %T with value: %+v", r, r) + } + } + return nil } // versionExists checks the source if either the up or down migration for // the specified migration version exists. func (m *Migrate) versionExists(version uint) (result error) { - // try up migration first - up, _, err := m.sourceDrv.ReadUp(version) - if err == nil { - defer func() { - if errClose := up.Close(); errClose != nil { - result = multierror.Append(result, errClose) - } - }() - } - if errors.Is(err, os.ErrExist) { - return nil - } else if !errors.Is(err, os.ErrNotExist) { - return err - } - - // then try down migration - down, _, err := m.sourceDrv.ReadDown(version) - if err == nil { - defer func() { - if errClose := down.Close(); errClose != nil { - result = multierror.Append(result, errClose) - } - }() - } - if errors.Is(err, os.ErrExist) { - return nil - } else if !errors.Is(err, os.ErrNotExist) { - return err - } - - err = fmt.Errorf("no migration found for version %d: %w", version, err) - m.logErr(err) - return err + // try up migration first + up, _, err := m.sourceDrv.ReadUp(version) + if err == nil { + defer func() { + if errClose := up.Close(); errClose != nil { + result = multierror.Append(result, errClose) + } + }() + } + if errors.Is(err, os.ErrExist) { + return nil + } else if !errors.Is(err, os.ErrNotExist) { + return err + } + + // then try down migration + down, _, err := m.sourceDrv.ReadDown(version) + if err == nil { + defer func() { + if errClose := down.Close(); errClose != nil { + result = multierror.Append(result, errClose) + } + }() + } + if errors.Is(err, os.ErrExist) { + return nil + } else if !errors.Is(err, os.ErrNotExist) { + return err + } + + err = fmt.Errorf("no migration found for version %d: %w", version, err) + m.logErr(err) + return err } // stop returns true if no more migrations should be run against the database // because a stop signal was received on the GracefulStop channel. // Calls are cheap and this function is not blocking. func (m *Migrate) stop() bool { - if m.isGracefulStop { - return true - } - - select { - case <-m.GracefulStop: - m.isGracefulStop = true - return true - - default: - return false - } + if m.isGracefulStop { + return true + } + + select { + case <-m.GracefulStop: + m.isGracefulStop = true + return true + + default: + return false + } } // newMigration is a helper func that returns a *Migration for the // specified version and targetVersion. func (m *Migrate) newMigration(version uint, targetVersion int) (*Migration, error) { - var migr *Migration - - if targetVersion >= int(version) { - r, identifier, err := m.sourceDrv.ReadUp(version) - if errors.Is(err, os.ErrNotExist) { - // create "empty" migration - migr, err = NewMigration(nil, "", version, targetVersion) - if err != nil { - return nil, err - } - - } else if err != nil { - return nil, err - - } else { - // create migration from up source - migr, err = NewMigration(r, identifier, version, targetVersion) - if err != nil { - return nil, err - } - } - - } else { - r, identifier, err := m.sourceDrv.ReadDown(version) - if errors.Is(err, os.ErrNotExist) { - // create "empty" migration - migr, err = NewMigration(nil, "", version, targetVersion) - if err != nil { - return nil, err - } - - } else if err != nil { - return nil, err - - } else { - // create migration from down source - migr, err = NewMigration(r, identifier, version, targetVersion) - if err != nil { - return nil, err - } - } - } - - if m.PrefetchMigrations > 0 && migr.Body != nil { - m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) - } else { - m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) - } - - return migr, nil + var migr *Migration + + if targetVersion >= int(version) { + r, identifier, err := m.sourceDrv.ReadUp(version) + if errors.Is(err, os.ErrNotExist) { + // create "empty" migration + migr, err = NewMigration(nil, "", version, targetVersion) + if err != nil { + return nil, err + } + + } else if err != nil { + return nil, err + + } else { + // create migration from up source + migr, err = NewMigration(r, identifier, version, targetVersion) + if err != nil { + return nil, err + } + } + + } else { + r, identifier, err := m.sourceDrv.ReadDown(version) + if errors.Is(err, os.ErrNotExist) { + // create "empty" migration + migr, err = NewMigration(nil, "", version, targetVersion) + if err != nil { + return nil, err + } + + } else if err != nil { + return nil, err + + } else { + // create migration from down source + migr, err = NewMigration(r, identifier, version, targetVersion) + if err != nil { + return nil, err + } + } + } + + if m.PrefetchMigrations > 0 && migr.Body != nil { + m.logVerbosePrintf("Start buffering %v\n", migr.LogString()) + } else { + m.logVerbosePrintf("Scheduled %v\n", migr.LogString()) + } + + return migr, nil } // lock is a thread safe helper function to lock the database. // It should be called as late as possible when running migrations. func (m *Migrate) lock() error { - m.isLockedMu.Lock() - defer m.isLockedMu.Unlock() - - if m.isLocked { - return ErrLocked - } - - // create done channel, used in the timeout goroutine - done := make(chan bool, 1) - defer func() { - done <- true - }() - - // use errchan to signal error back to this context - errchan := make(chan error, 2) - - // start timeout goroutine - timeout := time.After(m.LockTimeout) - go func() { - for { - select { - case <-done: - return - case <-timeout: - errchan <- ErrLockTimeout - return - } - } - }() - - // now try to acquire the lock - go func() { - if err := m.databaseDrv.Lock(); err != nil { - errchan <- err - } else { - errchan <- nil - } - }() - - // wait until we either receive ErrLockTimeout or error from Lock operation - err := <-errchan - if err == nil { - m.isLocked = true - } - return err + m.isLockedMu.Lock() + defer m.isLockedMu.Unlock() + + if m.isLocked { + return ErrLocked + } + + // create done channel, used in the timeout goroutine + done := make(chan bool, 1) + defer func() { + done <- true + }() + + // use errchan to signal error back to this context + errchan := make(chan error, 2) + + // start timeout goroutine + timeout := time.After(m.LockTimeout) + go func() { + for { + select { + case <-done: + return + case <-timeout: + errchan <- ErrLockTimeout + return + } + } + }() + + // now try to acquire the lock + go func() { + if err := m.databaseDrv.Lock(); err != nil { + errchan <- err + } else { + errchan <- nil + } + }() + + // wait until we either receive ErrLockTimeout or error from Lock operation + err := <-errchan + if err == nil { + m.isLocked = true + } + return err } // unlock is a thread safe helper function to unlock the database. // It should be called as early as possible when no more migrations are // expected to be executed. func (m *Migrate) unlock() error { - m.isLockedMu.Lock() - defer m.isLockedMu.Unlock() + m.isLockedMu.Lock() + defer m.isLockedMu.Unlock() - if err := m.databaseDrv.Unlock(); err != nil { - // BUG: Can potentially create a deadlock. Add a timeout. - return err - } + if err := m.databaseDrv.Unlock(); err != nil { + // BUG: Can potentially create a deadlock. Add a timeout. + return err + } - m.isLocked = false - return nil + m.isLocked = false + return nil } // unlockErr calls unlock and returns a combined error // if a prevErr is not nil. func (m *Migrate) unlockErr(prevErr error) error { - if err := m.unlock(); err != nil { - return multierror.Append(prevErr, err) - } - return prevErr + if err := m.unlock(); err != nil { + return multierror.Append(prevErr, err) + } + return prevErr } // logPrintf writes to m.Log if not nil func (m *Migrate) logPrintf(format string, v ...interface{}) { - if m.Log != nil { - m.Log.Printf(format, v...) - } + if m.Log != nil { + m.Log.Printf(format, v...) + } } // logVerbosePrintf writes to m.Log if not nil. Use for verbose logging output. func (m *Migrate) logVerbosePrintf(format string, v ...interface{}) { - if m.Log != nil && m.Log.Verbose() { - m.Log.Printf(format, v...) - } + if m.Log != nil && m.Log.Verbose() { + m.Log.Printf(format, v...) + } } // logErr writes error to m.Log if not nil func (m *Migrate) logErr(err error) { - if m.Log != nil { - m.Log.Printf("error: %v", err) - } + if m.Log != nil { + m.Log.Printf("error: %v", err) + } +} + +func (m *Migrate) HandleDirtyState() error { + // Perform actions when the database state is dirty + lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) + lastVersionBytes, err := os.ReadFile(lastSuccessfulMigrationPath) + if err != nil { + return err + } + lastVersionStr := strings.TrimSpace(string(lastVersionBytes)) + lastVersion, err := strconv.ParseInt(lastVersionStr, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse last successful migration version: %w", err) + } + + if err = m.Force(int(lastVersion)); err != nil { + return fmt.Errorf("failed to apply last successful migration: %w", err) + } + + m.logPrintf("Successfully applied migration: %s", lastVersionStr) + + if err = os.Remove(lastSuccessfulMigrationPath); err != nil { + return err + } + + m.logPrintf("Successfully deleted file: %s", lastSuccessfulMigrationPath) + return nil +} + +func (m *Migrate) HandleMigrationFailure(curVersion int, v uint) error { + failedVersion, _, err := m.databaseDrv.Version() + if err != nil { + return err + } + // Determine the last successful migration + lastSuccessfulMigration := strconv.Itoa(curVersion) + ret := make(chan interface{}, m.PrefetchMigrations) + go m.read(curVersion, int(v), ret) + + for r := range ret { + mig, ok := r.(*Migration) + if ok { + if mig.Version == uint(failedVersion) { + break + } + lastSuccessfulMigration = strconv.Itoa(int(mig.Version)) + } + } + + lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) + return os.WriteFile(lastSuccessfulMigrationPath, []byte(lastSuccessfulMigration), 0644) +} + +func (m *Migrate) CleanupFiles(v uint) error { + if m.ds == nil || m.ds.destPath == "" { + return nil + } + files, err := os.ReadDir(m.ds.destPath) + if err != nil { + return err + } + + targetVersion := uint64(v) + + for _, file := range files { + fileName := file.Name() + + // Check if file is a migration file we want to process + if !strings.HasSuffix(fileName, "down.sql") && !strings.HasSuffix(fileName, "up.sql") { + continue + } + + // Extract version and compare + versionEnd := strings.Index(fileName, "_") + if versionEnd == -1 { + // Skip files that don't match the expected naming pattern + continue + } + + fileVersion, err := strconv.ParseUint(fileName[:versionEnd], 10, 64) + if err != nil { + m.logErr(fmt.Errorf("skipping file %s due to version parse error: %v", fileName, err)) + continue + } + + // Delete file if version is greater than targetVersion + if fileVersion > targetVersion { + if err = os.Remove(filepath.Join(m.ds.destPath, fileName)); err != nil { + m.logErr(fmt.Errorf("failed to delete file %s: %v", fileName, err)) + continue + } + m.logPrintf("Deleted file: %s", fileName) + } + } + + return nil +} + +// CopyFiles copies all files from srcDir to destDir. +func (m *Migrate) CopyFiles() error { + if m.ds == nil || m.ds.destPath == "" { + return nil + } + _, err := os.ReadDir(m.ds.destPath) + if err != nil { + // If the directory does not exist + return err + } + + m.logPrintf("Copying files from %s to %s", m.ds.srcPath, m.ds.destPath) + + return filepath.Walk(m.ds.srcPath, func(src string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // ignore sub-directories in the migration directory + if info.IsDir() { + // Skip the tests directory and its files + if info.Name() == "tests" { + return filepath.SkipDir + } + return nil + } + // Ignore the current.sql file + if info.Name() == "current.sql" { + return nil + } + + var ( + srcFile *os.File + destFile *os.File + ) + dest := filepath.Join(m.ds.destPath, info.Name()) + if srcFile, err = os.Open(src); err != nil { + return err + } + defer func(srcFile *os.File) { + if err = srcFile.Close(); err != nil { + m.logErr(fmt.Errorf("failed to close file %s: %v", destFile.Name(), err)) + } + }(srcFile) + + // Create the destination file + if destFile, err = os.Create(dest); err != nil { + return err + } + defer func(destFile *os.File) { + if err = destFile.Close(); err != nil { + m.logErr(fmt.Errorf("failed to close file %s: %v", destFile.Name(), err)) + } + }(destFile) + + // Copy the file + if _, err = io.Copy(destFile, srcFile); err != nil { + return err + } + return os.Chmod(dest, info.Mode()) + }) } diff --git a/migrate_goto_temp.go b/migrate_goto_temp.go deleted file mode 100644 index 517d63adb..000000000 --- a/migrate_goto_temp.go +++ /dev/null @@ -1,176 +0,0 @@ -package migrate - -import ( - "io" - "os" - "path/filepath" - "strconv" - "strings" - - "github.com/pkg/errors" -) - -// Define a constant for the migration file name -const lastSuccessfulMigrationFile = "lastSuccessfulMigration" - -func (m *Migrate) HandleDirtyState() error { - // Perform actions when the database state is dirty - lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) - lastVersionBytes, err := os.ReadFile(lastSuccessfulMigrationPath) - if err != nil { - return err - } - lastVersionStr := strings.TrimSpace(string(lastVersionBytes)) - lastVersion, err := strconv.ParseUint(lastVersionStr, 10, 64) - if err != nil { - return errors.Wrap(err, "failed to parse last successful migration version") - } - - if err = m.Force(int(lastVersion)); err != nil { - return errors.Wrap(err, "failed to apply last successful migration") - } - - m.Log.Printf("Successfully applied migration: %s", lastVersionStr) - - if err = os.Remove(lastSuccessfulMigrationPath); err != nil { - return err - } - - m.Log.Printf("Successfully deleted file: %s", lastSuccessfulMigrationPath) - return nil -} - -func (m *Migrate) HandleMigrationFailure(curVersion int, v uint) error { - failedVersion, _, err := m.databaseDrv.Version() - if err != nil { - return err - } - - // Determine the last successful migration - lastSuccessfulMigration := strconv.Itoa(curVersion) - ret := make(chan interface{}, m.PrefetchMigrations) - go m.read(curVersion, int(v), ret) - - for r := range ret { - mig, ok := r.(*Migration) - if ok { - if mig.Version == uint(failedVersion) { - break - } - lastSuccessfulMigration = strconv.Itoa(int(mig.Version)) - } - } - - m.Log.Printf("migration failed, last successful migration version: %s", lastSuccessfulMigration) - lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) - if err = os.WriteFile(lastSuccessfulMigrationPath, []byte(lastSuccessfulMigration), 0644); err != nil { - return err - } - - return nil -} - -func (m *Migrate) CleanupFiles(v uint) error { - if m.ds.destPath == "" { - return nil - } - files, err := os.ReadDir(m.ds.destPath) - if err != nil { - return err - } - - targetVersion := uint64(v) - - for _, file := range files { - fileName := file.Name() - - // Check if file is a migration file we want to process - if !strings.HasSuffix(fileName, "down.sql") && !strings.HasSuffix(fileName, "up.sql") { - continue - } - - // Extract version and compare - versionEnd := strings.Index(fileName, "_") - if versionEnd == -1 { - // Skip files that don't match the expected naming pattern - continue - } - - fileVersion, err := strconv.ParseUint(fileName[:versionEnd], 10, 64) - if err != nil { - m.Log.Printf("Skipping file %s due to version parse error: %v", fileName, err) - continue - } - - // Delete file if version is greater than targetVersion - if fileVersion > targetVersion { - if err = os.Remove(filepath.Join(m.ds.destPath, fileName)); err != nil { - m.Log.Printf("Failed to delete file %s: %v", fileName, err) - continue - } - m.Log.Printf("Deleted file: %s", fileName) - } - } - - return nil -} - -// CopyFiles copies all files from srcDir to destDir. -func (m *Migrate) CopyFiles() error { - if m.ds.destPath == "" { - return nil - } - _, err := os.ReadDir(m.ds.destPath) - if err != nil { - // If the directory does not exist - return err - } - - m.Log.Printf("Copying files from %s to %s", m.ds.srcPath, m.ds.destPath) - - return filepath.Walk(m.ds.srcPath, func(src string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - // ignore sub-directories in the migration directory - if info.IsDir() { - // Skip the tests directory and its files - if info.Name() == "tests" { - m.Log.Printf("Ignoring directory %s", info.Name()) - return filepath.SkipDir - } - return nil - } - // Ignore the current.sql file - if info.Name() == "current.sql" { - m.Log.Printf("Ignoring file %s", info.Name()) - return nil - } - - var ( - srcFile *os.File - destFile *os.File - ) - dest := filepath.Join(m.ds.destPath, info.Name()) - if srcFile, err = os.Open(src); err != nil { - return err - } - defer func(srcFile *os.File) { - if err = srcFile.Close(); err != nil { - m.Log.Printf("failed to close file %s: %s", srcFile.Name, err) - } - }(srcFile) - - // Create the destination file - if destFile, err = os.Create(dest); err != nil { - return err - } - - // Copy the file - if _, err = io.Copy(destFile, srcFile); err == nil { - return err - } - return os.Chmod(dest, info.Mode()) - }) -} diff --git a/migrate_test.go b/migrate_test.go index f2728179e..61c6b7d73 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -4,9 +4,12 @@ import ( "bytes" "database/sql" "errors" + "fmt" "io" "log" "os" + "path/filepath" + "strconv" "strings" "testing" @@ -1414,3 +1417,321 @@ func equalDbSeq(t *testing.T, i int, expected migrationSequence, got *dStub.Stub t.Fatalf("\nexpected sequence %v,\ngot %v, in %v", bs, got.MigrationSequence, i) } } + +// Setting up temp directory to be used as a PVC mount +func setupTempDir(t *testing.T) (string, func()) { + tempDir, err := os.MkdirTemp("", "migrate_test") + if err != nil { + t.Fatal(err) + } + return tempDir, func() { + if err = os.RemoveAll(tempDir); err != nil { + t.Fatal(err) + } + } +} + +func setupMigrateInstance(tempDir string) (*Migrate, *dStub.Stub) { + m, _ := New("stub://", "stub://") + m.ds = &dirtyStateHandler{ + destPath: tempDir, + isDirty: true, + } + return m, m.databaseDrv.(*dStub.Stub) +} + +func setupSourceStubMigrations() *source.Migrations { + migrations := source.NewMigrations() + migrations.Append(&source.Migration{Version: 1, Direction: source.Up, Identifier: "CREATE 1"}) + migrations.Append(&source.Migration{Version: 1, Direction: source.Down, Identifier: "DROP 1"}) + migrations.Append(&source.Migration{Version: 2, Direction: source.Up, Identifier: "CREATE 2"}) + migrations.Append(&source.Migration{Version: 2, Direction: source.Down, Identifier: "DROP 2"}) + migrations.Append(&source.Migration{Version: 3, Direction: source.Up, Identifier: "CREATE 3"}) + migrations.Append(&source.Migration{Version: 3, Direction: source.Down, Identifier: "DROP 3"}) + migrations.Append(&source.Migration{Version: 4, Direction: source.Up, Identifier: "CREATE 4"}) + migrations.Append(&source.Migration{Version: 4, Direction: source.Down, Identifier: "DROP 4"}) + migrations.Append(&source.Migration{Version: 5, Direction: source.Up, Identifier: "CREATE 5"}) + migrations.Append(&source.Migration{Version: 5, Direction: source.Down, Identifier: "DROP 5"}) + migrations.Append(&source.Migration{Version: 6, Direction: source.Up, Identifier: "CREATE 6"}) + migrations.Append(&source.Migration{Version: 6, Direction: source.Down, Identifier: "DROP 6"}) + migrations.Append(&source.Migration{Version: 7, Direction: source.Up, Identifier: "CREATE 7"}) + migrations.Append(&source.Migration{Version: 7, Direction: source.Down, Identifier: "DROP 7"}) + return migrations +} + +func TestHandleDirtyState(t *testing.T) { + tempDir, cleanup := setupTempDir(t) + defer cleanup() + + m, dbDrv := setupMigrateInstance(tempDir) + m.sourceDrv.(*sStub.Stub).Migrations = setupSourceStubMigrations() + + tests := []struct { + lastSuccessful int + currentVersion int + err error + setupFailure bool + }{ + {lastSuccessful: 1, currentVersion: 2, err: nil, setupFailure: false}, + {lastSuccessful: 4, currentVersion: 5, err: nil, setupFailure: false}, + {lastSuccessful: 3, currentVersion: 4, err: nil, setupFailure: false}, + {lastSuccessful: -3, currentVersion: 4, err: ErrInvalidVersion, setupFailure: false}, + {lastSuccessful: 4, currentVersion: 3, err: fmt.Errorf("open %s: no such file or directory", filepath.Join(tempDir, lastSuccessfulMigrationFile)), setupFailure: true}, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + var lastSuccessfulMigrationPath string + // setupFailure tests scenario where the 'lastSuccessfulMigrationFile' doesn't exist + if !test.setupFailure { + lastSuccessfulMigrationPath = filepath.Join(tempDir, lastSuccessfulMigrationFile) + if err := os.WriteFile(lastSuccessfulMigrationPath, []byte(strconv.Itoa(test.lastSuccessful)), 0644); err != nil { + t.Fatal(err) + } + } + // Setting the DB version as dirty + if err := dbDrv.SetVersion(test.currentVersion, true); err != nil { + t.Fatal(err) + } + + // Quick check to see if set correctly + version, b, err := dbDrv.Version() + if err != nil { + t.Fatal(err) + } + if version != test.currentVersion { + t.Fatalf("expected version %d, got %d", test.currentVersion, version) + } + + if !b { + t.Fatalf("expected false, got true") + } + + // Handle dirty state + if err = m.HandleDirtyState(); err != nil { + if strings.Contains(err.Error(), test.err.Error()) { + t.Logf("expected error %v, got %v", test.err, err) + if !test.setupFailure { + if err = os.Remove(lastSuccessfulMigrationPath); err != nil { + t.Fatal(err) + } + } + return + } else { + t.Fatal(err) + } + } + // Check 1: DB should no longer be dirty + if dbDrv.IsDirty { + t.Fatalf("expected dirty to be false, got true") + } + // Check 2: Current version should be the last successful version + if dbDrv.CurrentVersion != test.lastSuccessful { + t.Fatalf("expected version %d, got %d", test.lastSuccessful, dbDrv.CurrentVersion) + } + // Check 3: The lastSuccessfulMigration file shouldn't exists + if _, err = os.Stat(lastSuccessfulMigrationPath); !os.IsNotExist(err) { + t.Fatalf("expected file to be deleted, but it still exists") + } + }) + } +} + +func TestHandleMigrationFailure(t *testing.T) { + tempDir, cleanup := setupTempDir(t) + defer cleanup() + + m, dbDrv := setupMigrateInstance(tempDir) + m.sourceDrv.(*sStub.Stub).Migrations = setupSourceStubMigrations() + + tests := []struct { + curVersion int + targetVersion uint + dirtyVersion int + }{ + {curVersion: 1, targetVersion: 7, dirtyVersion: 4}, + {curVersion: 4, targetVersion: 6, dirtyVersion: 5}, + {curVersion: 3, targetVersion: 7, dirtyVersion: 6}, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + t.Cleanup(func() { + m.sourceDrv.(*sStub.Stub).Migrations = setupSourceStubMigrations() + dbDrv = m.databaseDrv.(*dStub.Stub) + }) + + // Setup: Simulate a migration failure by setting the dirty version in the DB + if err := dbDrv.SetVersion(test.dirtyVersion, true); err != nil { + t.Fatal(err) + } + + // Test + if err := m.HandleMigrationFailure(test.curVersion, test.targetVersion); err != nil { + t.Fatal(err) + } + + // Check 1: Should no longer be dirty + if !dbDrv.IsDirty { + t.Fatalf("expected dirty to be true, got false") + } + + // Check 2: last successful Migration version should be stored in a file + lastSuccessfulMigrationPath := filepath.Join(tempDir, lastSuccessfulMigrationFile) + if _, err := os.Stat(lastSuccessfulMigrationPath); os.IsNotExist(err) { + t.Fatalf("expected file to be created, but it does not exist") + } + + // Check 3: Check if the content of last successful migration has the correct version + content, err := os.ReadFile(lastSuccessfulMigrationPath) + if err != nil { + t.Fatal(err) + } + + if string(content) != strconv.Itoa(test.dirtyVersion-1) { + t.Fatalf("expected %d, got %s", test.dirtyVersion-1, string(content)) + } + }) + } +} + +func TestCleanupFiles(t *testing.T) { + tempDir, cleanup := setupTempDir(t) + defer cleanup() + + m, _ := setupMigrateInstance(tempDir) + + tests := []struct { + migrationFiles []string + targetVersion uint + remainingFiles []string + emptyDestPath bool + }{ + { + migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, + targetVersion: 2, + remainingFiles: []string{"1_up.sql", "2_up.sql"}, + }, + { + migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql", "4_up.sql", "5_up.sql"}, + targetVersion: 3, + remainingFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, + }, + { + migrationFiles: []string{}, + targetVersion: 1, + remainingFiles: []string{}, + emptyDestPath: true, + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + for _, file := range test.migrationFiles { + if err := os.WriteFile(filepath.Join(tempDir, file), []byte(""), 0644); err != nil { + t.Fatal(err) + } + } + + if test.emptyDestPath { + m.ds.destPath = "" + } + + if err := m.CleanupFiles(test.targetVersion); err != nil { + t.Fatal(err) + } + + // check 1: only files upto the target version should exist + for _, file := range test.remainingFiles { + if _, err := os.Stat(filepath.Join(tempDir, file)); os.IsNotExist(err) { + t.Fatalf("expected file %s to exist, but it does not", file) + } + } + + // check 2: the files removed are as expected + deletedFiles := diff(test.migrationFiles, test.remainingFiles) + for _, deletedFile := range deletedFiles { + if _, err := os.Stat(filepath.Join(tempDir, deletedFile)); !os.IsNotExist(err) { + t.Fatalf("expected file %s to be deleted, but it still exists", deletedFile) + } + } + }) + } +} + +func TestCopyFiles(t *testing.T) { + srcDir, cleanupSrc := setupTempDir(t) + defer cleanupSrc() + + destDir, cleanupDest := setupTempDir(t) + defer cleanupDest() + + m, _ := New("stub://", "stub://") + m.ds = &dirtyStateHandler{ + srcPath: srcDir, + destPath: destDir, + } + + tests := []struct { + migrationFiles []string + copiedFiles []string + emptyDestPath bool + }{ + { + migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, + copiedFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, + }, + { + migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql", "4_up.sql", "current.sql"}, + copiedFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql", "4_up.sql"}, + }, + { + emptyDestPath: true, + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + for _, file := range test.migrationFiles { + if err := os.WriteFile(filepath.Join(srcDir, file), []byte(""), 0644); err != nil { + t.Fatal(err) + } + } + if test.emptyDestPath { + m.ds.destPath = "" + } + + if err := m.CopyFiles(); err != nil { + t.Fatal(err) + } + + for _, file := range test.copiedFiles { + if _, err := os.Stat(filepath.Join(destDir, file)); os.IsNotExist(err) { + t.Fatalf("expected file %s to be copied, but it does not exist", file) + } + } + }) + } +} + +/* + diff returns an array containing the elements in Array A and not in B +*/ + +func diff(a, b []string) []string { + temp := map[string]int{} + for _, s := range a { + temp[s]++ + } + for _, s := range b { + temp[s]-- + } + + var result []string + for s, v := range temp { + if v != 0 { + result = append(result, s) + } + } + return result +} diff --git a/test/main.go b/test/main.go deleted file mode 100644 index 56e540407..000000000 --- a/test/main.go +++ /dev/null @@ -1 +0,0 @@ -package test From 6bebd334c8ed0c0db2a09511d93b2672c2b29666 Mon Sep 17 00:00:00 2001 From: Venkat Venkatasubramanian Date: Wed, 2 Oct 2024 16:41:28 -0700 Subject: [PATCH 4/4] address comments --- cmd/migrate/config.go | 4 +- internal/cli/main.go | 30 ++--- migrate.go | 258 ++++++++++++++++++++---------------------- migrate_test.go | 137 ++++++++-------------- 4 files changed, 184 insertions(+), 245 deletions(-) diff --git a/cmd/migrate/config.go b/cmd/migrate/config.go index de812156a..905b5af5d 100644 --- a/cmd/migrate/config.go +++ b/cmd/migrate/config.go @@ -37,6 +37,6 @@ var ( flagConfigFile = pflag.String("config.file", "", "configuration file name without extension") // goto command flags - flagDirty = pflag.Bool("dirty", false, "migration is dirty") - flagPVCPath = pflag.String("intermediate-path", "", "path to the mounted volume which is used to copy the migration files") + flagDirty = pflag.Bool("force-dirty-handling", false, "force the handling of dirty database state") + flagMountPath = pflag.String("cache-dir", "", "path to the mounted volume which is used to copy the migration files") ) diff --git a/internal/cli/main.go b/internal/cli/main.go index 5158de116..d15e33064 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -29,7 +29,9 @@ const ( Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error. Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC). ` - gotoUsage = `goto V [-dirty] [-intermediate-path] Migrate to version V` + gotoUsage = `goto V [-force-dirty-handling] [-cache-dir P] Migrate to version V + Use -force-dirty-handling to handle dirty database state + Use -cache-dir to specify the intermediate path P for storing migrations` upUsage = `up [N] Apply all or N up migrations` downUsage = `down [N] [-all] Apply all or N down migrations Use -all to apply all down migrations` @@ -262,28 +264,18 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU if err != nil { log.fatal("error: can't read version argument V") } - handleDirty := viper.GetBool("dirty") - destPath := viper.GetString("intermediate-path") - srcPath := "" - // if sourcePtr is set, use it to get the source path - // otherwise, use the path flag - if path != "" { - srcPath = path - } - if sourcePtr != "" { - // parse the source path from the source argument - parse, err := url.Parse(sourcePtr) - if err != nil { - log.fatal("error: can't parse the source path from the source argument") + handleDirty := viper.GetBool("force-dirty-handling") + if handleDirty { + destPath := viper.GetString("cache-dir") + if destPath == "" { + log.fatal("error: cache-dir must be specified when force-dirty-handling is set") } - srcPath = parse.Path - } - if handleDirty && destPath == "" { - log.fatal("error: intermediate-path must be specified when dirty is set") + if err = migrater.WithDirtyStateHandler(sourcePtr, destPath, handleDirty); err != nil { + log.fatalErr(err) + } } - migrater.WithDirtyStateHandler(srcPath, destPath, handleDirty) if err = gotoCmd(migrater, uint(v)); err != nil { log.fatalErr(err) } diff --git a/migrate.go b/migrate.go index 0793924c1..1d1187254 100644 --- a/migrate.go +++ b/migrate.go @@ -7,7 +7,7 @@ package migrate import ( "errors" "fmt" - "io" + "net/url" "os" "path/filepath" "strconv" @@ -89,13 +89,19 @@ type Migrate struct { LockTimeout time.Duration // DirtyStateHandler is used to handle dirty state of the database - ds *dirtyStateHandler + dirtyStateConf *dirtyStateHandler } type dirtyStateHandler struct { - srcPath string - destPath string - isDirty bool + srcScheme string + srcPath string + destScheme string + destPath string + enable bool +} + +func (m *Migrate) IsDirtyHandlingEnabled() bool { + return m.dirtyStateConf != nil && m.dirtyStateConf.enable && m.dirtyStateConf.destPath != "" } // New returns a new Migrate instance from a source URL and a database URL. @@ -131,6 +137,11 @@ func New(sourceURL, databaseURL string) (*Migrate, error) { } func (m *Migrate) updateSourceDrv(sourceURL string) error { + sourceName, err := iurl.SchemeFromURL(sourceURL) + if err != nil { + return fmt.Errorf("failed to parse scheme from source URL: %w", err) + } + m.sourceName = sourceName sourceDrv, err := source.Open(sourceURL) if err != nil { return fmt.Errorf("failed to open source, %q: %w", sourceURL, err) @@ -207,12 +218,40 @@ func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseNa return m, nil } -func (m *Migrate) WithDirtyStateHandler(srcPath, destPath string, isDirty bool) { - m.ds = &dirtyStateHandler{ - srcPath: srcPath, - destPath: destPath, - isDirty: isDirty, +func (m *Migrate) WithDirtyStateHandler(srcPath, destPath string, isDirty bool) error { + parser := func(path string) (string, string, error) { + var scheme, p string + uri, err := url.Parse(path) + if err != nil { + return "", "", err + } + scheme = uri.Scheme + p = uri.Path + // if no scheme is provided, assume it's a file path + if scheme == "" { + scheme = "file://" + } + return scheme, p, nil + } + + sScheme, sPath, err := parser(srcPath) + if err != nil { + return err + } + + dScheme, dPath, err := parser(destPath) + if err != nil { + return err } + + m.dirtyStateConf = &dirtyStateHandler{ + srcScheme: sScheme, + destScheme: dScheme, + srcPath: sPath, + destPath: dPath, + enable: isDirty, + } + return nil } func newCommon() *Migrate { @@ -255,29 +294,19 @@ func (m *Migrate) Migrate(version uint) error { // if the dirty flag is passed to the 'goto' command, handle the dirty state if dirty { - if m.ds != nil && m.ds.isDirty { - if err = m.unlock(); err != nil { - return m.unlockErr(err) - } - if err = m.HandleDirtyState(); err != nil { + if m.IsDirtyHandlingEnabled() { + if err = m.handleDirtyState(); err != nil { return m.unlockErr(err) } - if err = m.lock(); err != nil { - return err - } - if err = m.updateSourceDrv("file://" + m.ds.destPath); err != nil { - return m.unlockErr(err) - } - } else { - // default behaviour + // default behavior return m.unlockErr(ErrDirty{curVersion}) } } // Copy migrations to the destination directory, // if state was dirty when Migrate was called, we should handle the dirty state first before copying the migrations - if err = m.CopyFiles(); err != nil { + if err = m.copyFiles(); err != nil { return m.unlockErr(err) } @@ -285,16 +314,11 @@ func (m *Migrate) Migrate(version uint) error { go m.read(curVersion, int(version), ret) if err = m.runMigrations(ret); err != nil { - if m.ds != nil && m.ds.isDirty { - // Handle failure: store last successful migration version and exit - if err = m.HandleMigrationFailure(curVersion, version); err != nil { - return m.unlockErr(err) - } - } return m.unlockErr(err) } // Success: Clean up and confirm - if err = m.CleanupFiles(version); err != nil { + // Files are cleaned up after the migration is successful + if err = m.cleanupFiles(version); err != nil { return m.unlockErr(err) } // unlock the database @@ -793,6 +817,7 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { // to stop execution because it might have received a stop signal on the // GracefulStop channel. func (m *Migrate) runMigrations(ret <-chan interface{}) error { + var lastCleanMigrationApplied int for r := range ret { if m.stop() { @@ -814,6 +839,15 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { if migr.Body != nil { m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { + if m.dirtyStateConf != nil && m.dirtyStateConf.enable { + // this condition is required if the first migration fails + if lastCleanMigrationApplied == 0 { + lastCleanMigrationApplied = migr.TargetVersion + } + if e := m.handleMigrationFailure(lastCleanMigrationApplied); e != nil { + return multierror.Append(err, e) + } + } return err } } @@ -822,7 +856,7 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil { return err } - + lastCleanMigrationApplied = migr.TargetVersion endTime := time.Now() readTime := migr.FinishedReading.Sub(migr.StartedBuffering) runTime := endTime.Sub(migr.FinishedReading) @@ -1050,9 +1084,20 @@ func (m *Migrate) logErr(err error) { } } -func (m *Migrate) HandleDirtyState() error { - // Perform actions when the database state is dirty - lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) +func (m *Migrate) handleDirtyState() error { + // Perform the following actions when the database state is dirty + /* + 1. Update the source driver to read the migrations from the mounted volume + 2. Read the last successful migration version from the file + 3. Set the last successful migration version in the schema_migrations table + 4. Delete the last successful migration file + */ + // the source driver should read the migrations from the mounted volume + // as the DB is dirty and last applied migrations to the database are not present in the source path + if err := m.updateSourceDrv(m.dirtyStateConf.destScheme + m.dirtyStateConf.destPath); err != nil { + return err + } + lastSuccessfulMigrationPath := filepath.Join(m.dirtyStateConf.destPath, lastSuccessfulMigrationFile) lastVersionBytes, err := os.ReadFile(lastSuccessfulMigrationPath) if err != nil { return err @@ -1063,11 +1108,12 @@ func (m *Migrate) HandleDirtyState() error { return fmt.Errorf("failed to parse last successful migration version: %w", err) } - if err = m.Force(int(lastVersion)); err != nil { + // Set the last successful migration version in the schema_migrations table + if err = m.databaseDrv.SetVersion(int(lastVersion), false); err != nil { return fmt.Errorf("failed to apply last successful migration: %w", err) } - m.logPrintf("Successfully applied migration: %s", lastVersionStr) + m.logPrintf("Successfully set last successful migration version: %s on the DB", lastVersionStr) if err = os.Remove(lastSuccessfulMigrationPath); err != nil { return err @@ -1077,134 +1123,74 @@ func (m *Migrate) HandleDirtyState() error { return nil } -func (m *Migrate) HandleMigrationFailure(curVersion int, v uint) error { - failedVersion, _, err := m.databaseDrv.Version() - if err != nil { - return err - } - // Determine the last successful migration - lastSuccessfulMigration := strconv.Itoa(curVersion) - ret := make(chan interface{}, m.PrefetchMigrations) - go m.read(curVersion, int(v), ret) - - for r := range ret { - mig, ok := r.(*Migration) - if ok { - if mig.Version == uint(failedVersion) { - break - } - lastSuccessfulMigration = strconv.Itoa(int(mig.Version)) - } +func (m *Migrate) handleMigrationFailure(lastSuccessfulMigration int) error { + if !m.IsDirtyHandlingEnabled() { + return nil } - - lastSuccessfulMigrationPath := filepath.Join(m.ds.destPath, lastSuccessfulMigrationFile) - return os.WriteFile(lastSuccessfulMigrationPath, []byte(lastSuccessfulMigration), 0644) + lastSuccessfulMigrationPath := filepath.Join(m.dirtyStateConf.destPath, lastSuccessfulMigrationFile) + return os.WriteFile(lastSuccessfulMigrationPath, []byte(strconv.Itoa(lastSuccessfulMigration)), 0644) } -func (m *Migrate) CleanupFiles(v uint) error { - if m.ds == nil || m.ds.destPath == "" { +func (m *Migrate) cleanupFiles(targetVersion uint) error { + if !m.IsDirtyHandlingEnabled() { return nil } - files, err := os.ReadDir(m.ds.destPath) + + files, err := os.ReadDir(m.dirtyStateConf.destPath) if err != nil { - return err + // If the directory does not exist + return fmt.Errorf("failed to read directory %s: %w", m.dirtyStateConf.destPath, err) } - targetVersion := uint64(v) - for _, file := range files { fileName := file.Name() - - // Check if file is a migration file we want to process - if !strings.HasSuffix(fileName, "down.sql") && !strings.HasSuffix(fileName, "up.sql") { - continue - } - - // Extract version and compare - versionEnd := strings.Index(fileName, "_") - if versionEnd == -1 { - // Skip files that don't match the expected naming pattern - continue - } - - fileVersion, err := strconv.ParseUint(fileName[:versionEnd], 10, 64) + migration, err := source.Parse(fileName) if err != nil { - m.logErr(fmt.Errorf("skipping file %s due to version parse error: %v", fileName, err)) - continue + return err } - // Delete file if version is greater than targetVersion - if fileVersion > targetVersion { - if err = os.Remove(filepath.Join(m.ds.destPath, fileName)); err != nil { + if migration.Version > targetVersion { + if err = os.Remove(filepath.Join(m.dirtyStateConf.destPath, fileName)); err != nil { m.logErr(fmt.Errorf("failed to delete file %s: %v", fileName, err)) continue } - m.logPrintf("Deleted file: %s", fileName) + m.logPrintf("Migration file: %s removed during cleanup", fileName) } } return nil } -// CopyFiles copies all files from srcDir to destDir. -func (m *Migrate) CopyFiles() error { - if m.ds == nil || m.ds.destPath == "" { +// copyFiles copies all files from source to destination volume. +func (m *Migrate) copyFiles() error { + // this is the case when the dirty handling is disabled + if !m.IsDirtyHandlingEnabled() { return nil } - _, err := os.ReadDir(m.ds.destPath) + + files, err := os.ReadDir(m.dirtyStateConf.srcPath) if err != nil { // If the directory does not exist - return err + return fmt.Errorf("failed to read directory %s: %w", m.dirtyStateConf.srcPath, err) } - - m.logPrintf("Copying files from %s to %s", m.ds.srcPath, m.ds.destPath) - - return filepath.Walk(m.ds.srcPath, func(src string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - // ignore sub-directories in the migration directory - if info.IsDir() { - // Skip the tests directory and its files - if info.Name() == "tests" { - return filepath.SkipDir + m.logPrintf("Copying files from %s to %s", m.dirtyStateConf.srcPath, m.dirtyStateConf.destPath) + for _, file := range files { + fileName := file.Name() + if source.Regex.MatchString(fileName) { + fileContentBytes, err := os.ReadFile(filepath.Join(m.dirtyStateConf.srcPath, fileName)) + if err != nil { + return err } - return nil - } - // Ignore the current.sql file - if info.Name() == "current.sql" { - return nil - } - - var ( - srcFile *os.File - destFile *os.File - ) - dest := filepath.Join(m.ds.destPath, info.Name()) - if srcFile, err = os.Open(src); err != nil { - return err - } - defer func(srcFile *os.File) { - if err = srcFile.Close(); err != nil { - m.logErr(fmt.Errorf("failed to close file %s: %v", destFile.Name(), err)) + info, err := file.Info() + if err != nil { + return err } - }(srcFile) - - // Create the destination file - if destFile, err = os.Create(dest); err != nil { - return err - } - defer func(destFile *os.File) { - if err = destFile.Close(); err != nil { - m.logErr(fmt.Errorf("failed to close file %s: %v", destFile.Name(), err)) + if err = os.WriteFile(filepath.Join(m.dirtyStateConf.destPath, fileName), fileContentBytes, info.Mode().Perm()); err != nil { + return err } - }(destFile) - - // Copy the file - if _, err = io.Copy(destFile, srcFile); err != nil { - return err } - return os.Chmod(dest, info.Mode()) - }) + } + + m.logPrintf("Successfully Copied files from %s to %s", m.dirtyStateConf.srcPath, m.dirtyStateConf.destPath) + return nil } diff --git a/migrate_test.go b/migrate_test.go index 61c6b7d73..33bcb2cd9 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -1418,7 +1418,7 @@ func equalDbSeq(t *testing.T, i int, expected migrationSequence, got *dStub.Stub } } -// Setting up temp directory to be used as a PVC mount +// Setting up temp directory to be used as the volume mount func setupTempDir(t *testing.T) (string, func()) { tempDir, err := os.MkdirTemp("", "migrate_test") if err != nil { @@ -1432,60 +1432,43 @@ func setupTempDir(t *testing.T) (string, func()) { } func setupMigrateInstance(tempDir string) (*Migrate, *dStub.Stub) { - m, _ := New("stub://", "stub://") - m.ds = &dirtyStateHandler{ - destPath: tempDir, - isDirty: true, + scheme := "stub://" + m, _ := New(scheme, scheme) + m.dirtyStateConf = &dirtyStateHandler{ + destScheme: scheme, + destPath: tempDir, + enable: true, } return m, m.databaseDrv.(*dStub.Stub) } -func setupSourceStubMigrations() *source.Migrations { - migrations := source.NewMigrations() - migrations.Append(&source.Migration{Version: 1, Direction: source.Up, Identifier: "CREATE 1"}) - migrations.Append(&source.Migration{Version: 1, Direction: source.Down, Identifier: "DROP 1"}) - migrations.Append(&source.Migration{Version: 2, Direction: source.Up, Identifier: "CREATE 2"}) - migrations.Append(&source.Migration{Version: 2, Direction: source.Down, Identifier: "DROP 2"}) - migrations.Append(&source.Migration{Version: 3, Direction: source.Up, Identifier: "CREATE 3"}) - migrations.Append(&source.Migration{Version: 3, Direction: source.Down, Identifier: "DROP 3"}) - migrations.Append(&source.Migration{Version: 4, Direction: source.Up, Identifier: "CREATE 4"}) - migrations.Append(&source.Migration{Version: 4, Direction: source.Down, Identifier: "DROP 4"}) - migrations.Append(&source.Migration{Version: 5, Direction: source.Up, Identifier: "CREATE 5"}) - migrations.Append(&source.Migration{Version: 5, Direction: source.Down, Identifier: "DROP 5"}) - migrations.Append(&source.Migration{Version: 6, Direction: source.Up, Identifier: "CREATE 6"}) - migrations.Append(&source.Migration{Version: 6, Direction: source.Down, Identifier: "DROP 6"}) - migrations.Append(&source.Migration{Version: 7, Direction: source.Up, Identifier: "CREATE 7"}) - migrations.Append(&source.Migration{Version: 7, Direction: source.Down, Identifier: "DROP 7"}) - return migrations -} - func TestHandleDirtyState(t *testing.T) { tempDir, cleanup := setupTempDir(t) defer cleanup() m, dbDrv := setupMigrateInstance(tempDir) - m.sourceDrv.(*sStub.Stub).Migrations = setupSourceStubMigrations() + m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations tests := []struct { - lastSuccessful int - currentVersion int - err error - setupFailure bool + lastSuccessfulVersion int + currentVersion int + err error + setupFailure bool }{ - {lastSuccessful: 1, currentVersion: 2, err: nil, setupFailure: false}, - {lastSuccessful: 4, currentVersion: 5, err: nil, setupFailure: false}, - {lastSuccessful: 3, currentVersion: 4, err: nil, setupFailure: false}, - {lastSuccessful: -3, currentVersion: 4, err: ErrInvalidVersion, setupFailure: false}, - {lastSuccessful: 4, currentVersion: 3, err: fmt.Errorf("open %s: no such file or directory", filepath.Join(tempDir, lastSuccessfulMigrationFile)), setupFailure: true}, + {lastSuccessfulVersion: 1, currentVersion: 3, err: nil, setupFailure: false}, + {lastSuccessfulVersion: 4, currentVersion: 7, err: nil, setupFailure: false}, + {lastSuccessfulVersion: 3, currentVersion: 4, err: nil, setupFailure: false}, + {lastSuccessfulVersion: -3, currentVersion: 4, err: ErrInvalidVersion, setupFailure: false}, + {lastSuccessfulVersion: 4, currentVersion: 3, err: fmt.Errorf("open %s: no such file or directory", filepath.Join(tempDir, lastSuccessfulMigrationFile)), setupFailure: true}, } for _, test := range tests { t.Run("", func(t *testing.T) { var lastSuccessfulMigrationPath string - // setupFailure tests scenario where the 'lastSuccessfulMigrationFile' doesn't exist + // setupFailure flag helps with testing scenario where the 'lastSuccessfulMigrationFile' doesn't exist if !test.setupFailure { lastSuccessfulMigrationPath = filepath.Join(tempDir, lastSuccessfulMigrationFile) - if err := os.WriteFile(lastSuccessfulMigrationPath, []byte(strconv.Itoa(test.lastSuccessful)), 0644); err != nil { + if err := os.WriteFile(lastSuccessfulMigrationPath, []byte(strconv.Itoa(test.lastSuccessfulVersion)), 0644); err != nil { t.Fatal(err) } } @@ -1504,11 +1487,11 @@ func TestHandleDirtyState(t *testing.T) { } if !b { - t.Fatalf("expected false, got true") + t.Fatalf("expected DB to be dirty, got false") } // Handle dirty state - if err = m.HandleDirtyState(); err != nil { + if err = m.handleDirtyState(); err != nil { if strings.Contains(err.Error(), test.err.Error()) { t.Logf("expected error %v, got %v", test.err, err) if !test.setupFailure { @@ -1526,10 +1509,10 @@ func TestHandleDirtyState(t *testing.T) { t.Fatalf("expected dirty to be false, got true") } // Check 2: Current version should be the last successful version - if dbDrv.CurrentVersion != test.lastSuccessful { - t.Fatalf("expected version %d, got %d", test.lastSuccessful, dbDrv.CurrentVersion) + if dbDrv.CurrentVersion != test.lastSuccessfulVersion { + t.Fatalf("expected version %d, got %d", test.lastSuccessfulVersion, dbDrv.CurrentVersion) } - // Check 3: The lastSuccessfulMigration file shouldn't exists + // Check 3: The lastSuccessfulMigration file shouldn't exist if _, err = os.Stat(lastSuccessfulMigrationPath); !os.IsNotExist(err) { t.Fatalf("expected file to be deleted, but it still exists") } @@ -1541,55 +1524,35 @@ func TestHandleMigrationFailure(t *testing.T) { tempDir, cleanup := setupTempDir(t) defer cleanup() - m, dbDrv := setupMigrateInstance(tempDir) - m.sourceDrv.(*sStub.Stub).Migrations = setupSourceStubMigrations() + m, _ := setupMigrateInstance(tempDir) tests := []struct { - curVersion int - targetVersion uint - dirtyVersion int + lastSuccessFulVersion int }{ - {curVersion: 1, targetVersion: 7, dirtyVersion: 4}, - {curVersion: 4, targetVersion: 6, dirtyVersion: 5}, - {curVersion: 3, targetVersion: 7, dirtyVersion: 6}, + {lastSuccessFulVersion: 3}, + {lastSuccessFulVersion: 4}, + {lastSuccessFulVersion: 5}, } for _, test := range tests { t.Run("", func(t *testing.T) { - t.Cleanup(func() { - m.sourceDrv.(*sStub.Stub).Migrations = setupSourceStubMigrations() - dbDrv = m.databaseDrv.(*dStub.Stub) - }) - - // Setup: Simulate a migration failure by setting the dirty version in the DB - if err := dbDrv.SetVersion(test.dirtyVersion, true); err != nil { - t.Fatal(err) - } - - // Test - if err := m.HandleMigrationFailure(test.curVersion, test.targetVersion); err != nil { + if err := m.handleMigrationFailure(test.lastSuccessFulVersion); err != nil { t.Fatal(err) } - - // Check 1: Should no longer be dirty - if !dbDrv.IsDirty { - t.Fatalf("expected dirty to be true, got false") - } - - // Check 2: last successful Migration version should be stored in a file + // Check 1: last successful Migration version should be stored in a file lastSuccessfulMigrationPath := filepath.Join(tempDir, lastSuccessfulMigrationFile) if _, err := os.Stat(lastSuccessfulMigrationPath); os.IsNotExist(err) { t.Fatalf("expected file to be created, but it does not exist") } - // Check 3: Check if the content of last successful migration has the correct version + // Check 2: Check if the content of last successful migration has the correct version content, err := os.ReadFile(lastSuccessfulMigrationPath) if err != nil { t.Fatal(err) } - if string(content) != strconv.Itoa(test.dirtyVersion-1) { - t.Fatalf("expected %d, got %s", test.dirtyVersion-1, string(content)) + if string(content) != strconv.Itoa(test.lastSuccessFulVersion) { + t.Fatalf("expected %d, got %s", test.lastSuccessFulVersion, string(content)) } }) } @@ -1600,6 +1563,7 @@ func TestCleanupFiles(t *testing.T) { defer cleanup() m, _ := setupMigrateInstance(tempDir) + m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations tests := []struct { migrationFiles []string @@ -1608,14 +1572,14 @@ func TestCleanupFiles(t *testing.T) { emptyDestPath bool }{ { - migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, targetVersion: 2, - remainingFiles: []string{"1_up.sql", "2_up.sql"}, + remainingFiles: []string{"1_name.up.sql", "2_name.up.sql"}, }, { - migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql", "4_up.sql", "5_up.sql"}, + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql", "4_name.up.sql", "5_name.up.sql"}, targetVersion: 3, - remainingFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, + remainingFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, }, { migrationFiles: []string{}, @@ -1634,10 +1598,10 @@ func TestCleanupFiles(t *testing.T) { } if test.emptyDestPath { - m.ds.destPath = "" + m.dirtyStateConf.destPath = "" } - if err := m.CleanupFiles(test.targetVersion); err != nil { + if err := m.cleanupFiles(test.targetVersion); err != nil { t.Fatal(err) } @@ -1666,11 +1630,8 @@ func TestCopyFiles(t *testing.T) { destDir, cleanupDest := setupTempDir(t) defer cleanupDest() - m, _ := New("stub://", "stub://") - m.ds = &dirtyStateHandler{ - srcPath: srcDir, - destPath: destDir, - } + m, _ := setupMigrateInstance(destDir) + m.dirtyStateConf.srcPath = srcDir tests := []struct { migrationFiles []string @@ -1678,12 +1639,12 @@ func TestCopyFiles(t *testing.T) { emptyDestPath bool }{ { - migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, - copiedFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql"}, + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, + copiedFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, }, { - migrationFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql", "4_up.sql", "current.sql"}, - copiedFiles: []string{"1_up.sql", "2_up.sql", "3_up.sql", "4_up.sql"}, + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql", "4_name.up.sql", "current.sql"}, + copiedFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql", "4_name.up.sql"}, }, { emptyDestPath: true, @@ -1698,10 +1659,10 @@ func TestCopyFiles(t *testing.T) { } } if test.emptyDestPath { - m.ds.destPath = "" + m.dirtyStateConf.destPath = "" } - if err := m.CopyFiles(); err != nil { + if err := m.copyFiles(); err != nil { t.Fatal(err) }