diff --git a/cmd/migrate/config.go b/cmd/migrate/config.go index e5a946922..905b5af5d 100644 --- a/cmd/migrate/config.go +++ b/cmd/migrate/config.go @@ -35,4 +35,8 @@ var ( flagConfigDirectory = pflag.String("config.source", defaultConfigDirectory, "directory of the configuration file") flagConfigFile = pflag.String("config.file", "", "configuration file name without extension") + + // goto command flags + flagDirty = pflag.Bool("force-dirty-handling", false, "force the handling of dirty database state") + flagMountPath = pflag.String("cache-dir", "", "path to the mounted volume which is used to copy the migration files") ) diff --git a/internal/cli/main.go b/internal/cli/main.go index ece7eff0b..d15e33064 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -29,7 +29,9 @@ const ( Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error. Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC). ` - gotoUsage = `goto V Migrate to version V` + gotoUsage = `goto V [-force-dirty-handling] [-cache-dir P] Migrate to version V + Use -force-dirty-handling to handle dirty database state + Use -cache-dir to specify the intermediate path P for storing migrations` upUsage = `up [N] Apply all or N up migrations` downUsage = `down [N] [-all] Apply all or N down migrations Use -all to apply all down migrations` @@ -262,8 +264,19 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU if err != nil { log.fatal("error: can't read version argument V") } + handleDirty := viper.GetBool("force-dirty-handling") + if handleDirty { + destPath := viper.GetString("cache-dir") + if destPath == "" { + log.fatal("error: cache-dir must be specified when force-dirty-handling is set") + } + + if err = migrater.WithDirtyStateHandler(sourcePtr, destPath, handleDirty); err != nil { + log.fatalErr(err) + } + } - if err := gotoCmd(migrater, uint(v)); err != nil { + if err = gotoCmd(migrater, uint(v)); err != nil { log.fatalErr(err) } diff --git a/migrate.go b/migrate.go index 7763782a0..1d1187254 100644 --- a/migrate.go +++ b/migrate.go @@ -7,7 +7,11 @@ package migrate import ( "errors" "fmt" + "net/url" "os" + "path/filepath" + "strconv" + "strings" "sync" "time" @@ -36,6 +40,9 @@ var ( ErrLockTimeout = errors.New("timeout: can't acquire database lock") ) +// Define a constant for the migration file name +const lastSuccessfulMigrationFile = "lastSuccessfulMigration" + // ErrShortLimit is an error returned when not enough migrations // can be returned by a source for a given limit. type ErrShortLimit struct { @@ -80,6 +87,21 @@ type Migrate struct { // LockTimeout defaults to DefaultLockTimeout, // but can be set per Migrate instance. LockTimeout time.Duration + + // DirtyStateHandler is used to handle dirty state of the database + dirtyStateConf *dirtyStateHandler +} + +type dirtyStateHandler struct { + srcScheme string + srcPath string + destScheme string + destPath string + enable bool +} + +func (m *Migrate) IsDirtyHandlingEnabled() bool { + return m.dirtyStateConf != nil && m.dirtyStateConf.enable && m.dirtyStateConf.destPath != "" } // New returns a new Migrate instance from a source URL and a database URL. @@ -114,6 +136,20 @@ func New(sourceURL, databaseURL string) (*Migrate, error) { return m, nil } +func (m *Migrate) updateSourceDrv(sourceURL string) error { + sourceName, err := iurl.SchemeFromURL(sourceURL) + if err != nil { + return fmt.Errorf("failed to parse scheme from source URL: %w", err) + } + m.sourceName = sourceName + sourceDrv, err := source.Open(sourceURL) + if err != nil { + return fmt.Errorf("failed to open source, %q: %w", sourceURL, err) + } + m.sourceDrv = sourceDrv + return nil +} + // NewWithDatabaseInstance returns a new Migrate instance from a source URL // and an existing database instance. The source URL scheme is defined by each driver. // Use any string that can serve as an identifier during logging as databaseName. @@ -182,6 +218,42 @@ func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseNa return m, nil } +func (m *Migrate) WithDirtyStateHandler(srcPath, destPath string, isDirty bool) error { + parser := func(path string) (string, string, error) { + var scheme, p string + uri, err := url.Parse(path) + if err != nil { + return "", "", err + } + scheme = uri.Scheme + p = uri.Path + // if no scheme is provided, assume it's a file path + if scheme == "" { + scheme = "file://" + } + return scheme, p, nil + } + + sScheme, sPath, err := parser(srcPath) + if err != nil { + return err + } + + dScheme, dPath, err := parser(destPath) + if err != nil { + return err + } + + m.dirtyStateConf = &dirtyStateHandler{ + srcScheme: sScheme, + destScheme: dScheme, + srcPath: sPath, + destPath: dPath, + enable: isDirty, + } + return nil +} + func newCommon() *Migrate { return &Migrate{ GracefulStop: make(chan bool, 1), @@ -215,20 +287,42 @@ func (m *Migrate) Migrate(version uint) error { if err := m.lock(); err != nil { return err } - curVersion, dirty, err := m.databaseDrv.Version() if err != nil { return m.unlockErr(err) } + // if the dirty flag is passed to the 'goto' command, handle the dirty state if dirty { - return m.unlockErr(ErrDirty{curVersion}) + if m.IsDirtyHandlingEnabled() { + if err = m.handleDirtyState(); err != nil { + return m.unlockErr(err) + } + } else { + // default behavior + return m.unlockErr(ErrDirty{curVersion}) + } + } + + // Copy migrations to the destination directory, + // if state was dirty when Migrate was called, we should handle the dirty state first before copying the migrations + if err = m.copyFiles(); err != nil { + return m.unlockErr(err) } ret := make(chan interface{}, m.PrefetchMigrations) go m.read(curVersion, int(version), ret) - return m.unlockErr(m.runMigrations(ret)) + if err = m.runMigrations(ret); err != nil { + return m.unlockErr(err) + } + // Success: Clean up and confirm + // Files are cleaned up after the migration is successful + if err = m.cleanupFiles(version); err != nil { + return m.unlockErr(err) + } + // unlock the database + return m.unlock() } // Steps looks at the currently active migration version. @@ -723,6 +817,7 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { // to stop execution because it might have received a stop signal on the // GracefulStop channel. func (m *Migrate) runMigrations(ret <-chan interface{}) error { + var lastCleanMigrationApplied int for r := range ret { if m.stop() { @@ -744,6 +839,15 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { if migr.Body != nil { m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { + if m.dirtyStateConf != nil && m.dirtyStateConf.enable { + // this condition is required if the first migration fails + if lastCleanMigrationApplied == 0 { + lastCleanMigrationApplied = migr.TargetVersion + } + if e := m.handleMigrationFailure(lastCleanMigrationApplied); e != nil { + return multierror.Append(err, e) + } + } return err } } @@ -752,7 +856,7 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil { return err } - + lastCleanMigrationApplied = migr.TargetVersion endTime := time.Now() readTime := migr.FinishedReading.Sub(migr.StartedBuffering) runTime := endTime.Sub(migr.FinishedReading) @@ -979,3 +1083,114 @@ func (m *Migrate) logErr(err error) { m.Log.Printf("error: %v", err) } } + +func (m *Migrate) handleDirtyState() error { + // Perform the following actions when the database state is dirty + /* + 1. Update the source driver to read the migrations from the mounted volume + 2. Read the last successful migration version from the file + 3. Set the last successful migration version in the schema_migrations table + 4. Delete the last successful migration file + */ + // the source driver should read the migrations from the mounted volume + // as the DB is dirty and last applied migrations to the database are not present in the source path + if err := m.updateSourceDrv(m.dirtyStateConf.destScheme + m.dirtyStateConf.destPath); err != nil { + return err + } + lastSuccessfulMigrationPath := filepath.Join(m.dirtyStateConf.destPath, lastSuccessfulMigrationFile) + lastVersionBytes, err := os.ReadFile(lastSuccessfulMigrationPath) + if err != nil { + return err + } + lastVersionStr := strings.TrimSpace(string(lastVersionBytes)) + lastVersion, err := strconv.ParseInt(lastVersionStr, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse last successful migration version: %w", err) + } + + // Set the last successful migration version in the schema_migrations table + if err = m.databaseDrv.SetVersion(int(lastVersion), false); err != nil { + return fmt.Errorf("failed to apply last successful migration: %w", err) + } + + m.logPrintf("Successfully set last successful migration version: %s on the DB", lastVersionStr) + + if err = os.Remove(lastSuccessfulMigrationPath); err != nil { + return err + } + + m.logPrintf("Successfully deleted file: %s", lastSuccessfulMigrationPath) + return nil +} + +func (m *Migrate) handleMigrationFailure(lastSuccessfulMigration int) error { + if !m.IsDirtyHandlingEnabled() { + return nil + } + lastSuccessfulMigrationPath := filepath.Join(m.dirtyStateConf.destPath, lastSuccessfulMigrationFile) + return os.WriteFile(lastSuccessfulMigrationPath, []byte(strconv.Itoa(lastSuccessfulMigration)), 0644) +} + +func (m *Migrate) cleanupFiles(targetVersion uint) error { + if !m.IsDirtyHandlingEnabled() { + return nil + } + + files, err := os.ReadDir(m.dirtyStateConf.destPath) + if err != nil { + // If the directory does not exist + return fmt.Errorf("failed to read directory %s: %w", m.dirtyStateConf.destPath, err) + } + + for _, file := range files { + fileName := file.Name() + migration, err := source.Parse(fileName) + if err != nil { + return err + } + // Delete file if version is greater than targetVersion + if migration.Version > targetVersion { + if err = os.Remove(filepath.Join(m.dirtyStateConf.destPath, fileName)); err != nil { + m.logErr(fmt.Errorf("failed to delete file %s: %v", fileName, err)) + continue + } + m.logPrintf("Migration file: %s removed during cleanup", fileName) + } + } + + return nil +} + +// copyFiles copies all files from source to destination volume. +func (m *Migrate) copyFiles() error { + // this is the case when the dirty handling is disabled + if !m.IsDirtyHandlingEnabled() { + return nil + } + + files, err := os.ReadDir(m.dirtyStateConf.srcPath) + if err != nil { + // If the directory does not exist + return fmt.Errorf("failed to read directory %s: %w", m.dirtyStateConf.srcPath, err) + } + m.logPrintf("Copying files from %s to %s", m.dirtyStateConf.srcPath, m.dirtyStateConf.destPath) + for _, file := range files { + fileName := file.Name() + if source.Regex.MatchString(fileName) { + fileContentBytes, err := os.ReadFile(filepath.Join(m.dirtyStateConf.srcPath, fileName)) + if err != nil { + return err + } + info, err := file.Info() + if err != nil { + return err + } + if err = os.WriteFile(filepath.Join(m.dirtyStateConf.destPath, fileName), fileContentBytes, info.Mode().Perm()); err != nil { + return err + } + } + } + + m.logPrintf("Successfully Copied files from %s to %s", m.dirtyStateConf.srcPath, m.dirtyStateConf.destPath) + return nil +} diff --git a/migrate_test.go b/migrate_test.go index f2728179e..33bcb2cd9 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -4,9 +4,12 @@ import ( "bytes" "database/sql" "errors" + "fmt" "io" "log" "os" + "path/filepath" + "strconv" "strings" "testing" @@ -1414,3 +1417,282 @@ func equalDbSeq(t *testing.T, i int, expected migrationSequence, got *dStub.Stub t.Fatalf("\nexpected sequence %v,\ngot %v, in %v", bs, got.MigrationSequence, i) } } + +// Setting up temp directory to be used as the volume mount +func setupTempDir(t *testing.T) (string, func()) { + tempDir, err := os.MkdirTemp("", "migrate_test") + if err != nil { + t.Fatal(err) + } + return tempDir, func() { + if err = os.RemoveAll(tempDir); err != nil { + t.Fatal(err) + } + } +} + +func setupMigrateInstance(tempDir string) (*Migrate, *dStub.Stub) { + scheme := "stub://" + m, _ := New(scheme, scheme) + m.dirtyStateConf = &dirtyStateHandler{ + destScheme: scheme, + destPath: tempDir, + enable: true, + } + return m, m.databaseDrv.(*dStub.Stub) +} + +func TestHandleDirtyState(t *testing.T) { + tempDir, cleanup := setupTempDir(t) + defer cleanup() + + m, dbDrv := setupMigrateInstance(tempDir) + m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations + + tests := []struct { + lastSuccessfulVersion int + currentVersion int + err error + setupFailure bool + }{ + {lastSuccessfulVersion: 1, currentVersion: 3, err: nil, setupFailure: false}, + {lastSuccessfulVersion: 4, currentVersion: 7, err: nil, setupFailure: false}, + {lastSuccessfulVersion: 3, currentVersion: 4, err: nil, setupFailure: false}, + {lastSuccessfulVersion: -3, currentVersion: 4, err: ErrInvalidVersion, setupFailure: false}, + {lastSuccessfulVersion: 4, currentVersion: 3, err: fmt.Errorf("open %s: no such file or directory", filepath.Join(tempDir, lastSuccessfulMigrationFile)), setupFailure: true}, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + var lastSuccessfulMigrationPath string + // setupFailure flag helps with testing scenario where the 'lastSuccessfulMigrationFile' doesn't exist + if !test.setupFailure { + lastSuccessfulMigrationPath = filepath.Join(tempDir, lastSuccessfulMigrationFile) + if err := os.WriteFile(lastSuccessfulMigrationPath, []byte(strconv.Itoa(test.lastSuccessfulVersion)), 0644); err != nil { + t.Fatal(err) + } + } + // Setting the DB version as dirty + if err := dbDrv.SetVersion(test.currentVersion, true); err != nil { + t.Fatal(err) + } + + // Quick check to see if set correctly + version, b, err := dbDrv.Version() + if err != nil { + t.Fatal(err) + } + if version != test.currentVersion { + t.Fatalf("expected version %d, got %d", test.currentVersion, version) + } + + if !b { + t.Fatalf("expected DB to be dirty, got false") + } + + // Handle dirty state + if err = m.handleDirtyState(); err != nil { + if strings.Contains(err.Error(), test.err.Error()) { + t.Logf("expected error %v, got %v", test.err, err) + if !test.setupFailure { + if err = os.Remove(lastSuccessfulMigrationPath); err != nil { + t.Fatal(err) + } + } + return + } else { + t.Fatal(err) + } + } + // Check 1: DB should no longer be dirty + if dbDrv.IsDirty { + t.Fatalf("expected dirty to be false, got true") + } + // Check 2: Current version should be the last successful version + if dbDrv.CurrentVersion != test.lastSuccessfulVersion { + t.Fatalf("expected version %d, got %d", test.lastSuccessfulVersion, dbDrv.CurrentVersion) + } + // Check 3: The lastSuccessfulMigration file shouldn't exist + if _, err = os.Stat(lastSuccessfulMigrationPath); !os.IsNotExist(err) { + t.Fatalf("expected file to be deleted, but it still exists") + } + }) + } +} + +func TestHandleMigrationFailure(t *testing.T) { + tempDir, cleanup := setupTempDir(t) + defer cleanup() + + m, _ := setupMigrateInstance(tempDir) + + tests := []struct { + lastSuccessFulVersion int + }{ + {lastSuccessFulVersion: 3}, + {lastSuccessFulVersion: 4}, + {lastSuccessFulVersion: 5}, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + if err := m.handleMigrationFailure(test.lastSuccessFulVersion); err != nil { + t.Fatal(err) + } + // Check 1: last successful Migration version should be stored in a file + lastSuccessfulMigrationPath := filepath.Join(tempDir, lastSuccessfulMigrationFile) + if _, err := os.Stat(lastSuccessfulMigrationPath); os.IsNotExist(err) { + t.Fatalf("expected file to be created, but it does not exist") + } + + // Check 2: Check if the content of last successful migration has the correct version + content, err := os.ReadFile(lastSuccessfulMigrationPath) + if err != nil { + t.Fatal(err) + } + + if string(content) != strconv.Itoa(test.lastSuccessFulVersion) { + t.Fatalf("expected %d, got %s", test.lastSuccessFulVersion, string(content)) + } + }) + } +} + +func TestCleanupFiles(t *testing.T) { + tempDir, cleanup := setupTempDir(t) + defer cleanup() + + m, _ := setupMigrateInstance(tempDir) + m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations + + tests := []struct { + migrationFiles []string + targetVersion uint + remainingFiles []string + emptyDestPath bool + }{ + { + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, + targetVersion: 2, + remainingFiles: []string{"1_name.up.sql", "2_name.up.sql"}, + }, + { + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql", "4_name.up.sql", "5_name.up.sql"}, + targetVersion: 3, + remainingFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, + }, + { + migrationFiles: []string{}, + targetVersion: 1, + remainingFiles: []string{}, + emptyDestPath: true, + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + for _, file := range test.migrationFiles { + if err := os.WriteFile(filepath.Join(tempDir, file), []byte(""), 0644); err != nil { + t.Fatal(err) + } + } + + if test.emptyDestPath { + m.dirtyStateConf.destPath = "" + } + + if err := m.cleanupFiles(test.targetVersion); err != nil { + t.Fatal(err) + } + + // check 1: only files upto the target version should exist + for _, file := range test.remainingFiles { + if _, err := os.Stat(filepath.Join(tempDir, file)); os.IsNotExist(err) { + t.Fatalf("expected file %s to exist, but it does not", file) + } + } + + // check 2: the files removed are as expected + deletedFiles := diff(test.migrationFiles, test.remainingFiles) + for _, deletedFile := range deletedFiles { + if _, err := os.Stat(filepath.Join(tempDir, deletedFile)); !os.IsNotExist(err) { + t.Fatalf("expected file %s to be deleted, but it still exists", deletedFile) + } + } + }) + } +} + +func TestCopyFiles(t *testing.T) { + srcDir, cleanupSrc := setupTempDir(t) + defer cleanupSrc() + + destDir, cleanupDest := setupTempDir(t) + defer cleanupDest() + + m, _ := setupMigrateInstance(destDir) + m.dirtyStateConf.srcPath = srcDir + + tests := []struct { + migrationFiles []string + copiedFiles []string + emptyDestPath bool + }{ + { + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, + copiedFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, + }, + { + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql", "4_name.up.sql", "current.sql"}, + copiedFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql", "4_name.up.sql"}, + }, + { + emptyDestPath: true, + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + for _, file := range test.migrationFiles { + if err := os.WriteFile(filepath.Join(srcDir, file), []byte(""), 0644); err != nil { + t.Fatal(err) + } + } + if test.emptyDestPath { + m.dirtyStateConf.destPath = "" + } + + if err := m.copyFiles(); err != nil { + t.Fatal(err) + } + + for _, file := range test.copiedFiles { + if _, err := os.Stat(filepath.Join(destDir, file)); os.IsNotExist(err) { + t.Fatalf("expected file %s to be copied, but it does not exist", file) + } + } + }) + } +} + +/* + diff returns an array containing the elements in Array A and not in B +*/ + +func diff(a, b []string) []string { + temp := map[string]int{} + for _, s := range a { + temp[s]++ + } + for _, s := range b { + temp[s]-- + } + + var result []string + for s, v := range temp { + if v != 0 { + result = append(result, s) + } + } + return result +}