diff --git a/cmd/migrate/config.go b/cmd/migrate/config.go index 905b5af5d..a03097618 100644 --- a/cmd/migrate/config.go +++ b/cmd/migrate/config.go @@ -38,5 +38,5 @@ var ( // goto command flags flagDirty = pflag.Bool("force-dirty-handling", false, "force the handling of dirty database state") - flagMountPath = pflag.String("cache-dir", "", "path to the mounted volume which is used to copy the migration files") + flagMountPath = pflag.String("cache-dir", "", "path to the cache-dir which is used to copy the migration files") ) diff --git a/internal/cli/main.go b/internal/cli/main.go index d15e33064..e1d21934d 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -271,7 +271,7 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU log.fatal("error: cache-dir must be specified when force-dirty-handling is set") } - if err = migrater.WithDirtyStateHandler(sourcePtr, destPath, handleDirty); err != nil { + if err = migrater.WithDirtyStateConfig(sourcePtr, destPath, handleDirty); err != nil { log.fatalErr(err) } } diff --git a/migrate.go b/migrate.go index 1d1187254..32a76145a 100644 --- a/migrate.go +++ b/migrate.go @@ -88,11 +88,11 @@ type Migrate struct { // but can be set per Migrate instance. LockTimeout time.Duration - // DirtyStateHandler is used to handle dirty state of the database - dirtyStateConf *dirtyStateHandler + // dirtyStateConfig is used to store the configuration required to handle dirty state of the database + dirtyStateConf *dirtyStateConfig } -type dirtyStateHandler struct { +type dirtyStateConfig struct { srcScheme string srcPath string destScheme string @@ -218,33 +218,30 @@ func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseNa return m, nil } -func (m *Migrate) WithDirtyStateHandler(srcPath, destPath string, isDirty bool) error { - parser := func(path string) (string, string, error) { - var scheme, p string +func (m *Migrate) WithDirtyStateConfig(srcPath, destPath string, isDirty bool) error { + parsePath := func(path string) (string, string, error) { uri, err := url.Parse(path) if err != nil { return "", "", err } - scheme = uri.Scheme - p = uri.Path - // if no scheme is provided, assume it's a file path - if scheme == "" { - scheme = "file://" + scheme := "file" + if uri.Scheme != "file" && uri.Scheme != "" { + return "", "", fmt.Errorf("unsupported scheme: %s", scheme) } - return scheme, p, nil + return scheme + "://", uri.Path, nil } - sScheme, sPath, err := parser(srcPath) + sScheme, sPath, err := parsePath(srcPath) if err != nil { return err } - dScheme, dPath, err := parser(destPath) + dScheme, dPath, err := parsePath(destPath) if err != nil { return err } - m.dirtyStateConf = &dirtyStateHandler{ + m.dirtyStateConf = &dirtyStateConfig{ srcScheme: sScheme, destScheme: dScheme, srcPath: sPath, @@ -839,7 +836,7 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { if migr.Body != nil { m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { - if m.dirtyStateConf != nil && m.dirtyStateConf.enable { + if m.IsDirtyHandlingEnabled() { // this condition is required if the first migration fails if lastCleanMigrationApplied == 0 { lastCleanMigrationApplied = migr.TargetVersion @@ -1087,12 +1084,12 @@ func (m *Migrate) logErr(err error) { func (m *Migrate) handleDirtyState() error { // Perform the following actions when the database state is dirty /* - 1. Update the source driver to read the migrations from the mounted volume + 1. Update the source driver to read the migrations from the destination path 2. Read the last successful migration version from the file 3. Set the last successful migration version in the schema_migrations table 4. Delete the last successful migration file */ - // the source driver should read the migrations from the mounted volume + // the source driver should read the migrations from the destination path // as the DB is dirty and last applied migrations to the database are not present in the source path if err := m.updateSourceDrv(m.dirtyStateConf.destScheme + m.dirtyStateConf.destPath); err != nil { return err diff --git a/migrate_test.go b/migrate_test.go index 33bcb2cd9..4b83b92a0 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -1418,23 +1418,10 @@ func equalDbSeq(t *testing.T, i int, expected migrationSequence, got *dStub.Stub } } -// Setting up temp directory to be used as the volume mount -func setupTempDir(t *testing.T) (string, func()) { - tempDir, err := os.MkdirTemp("", "migrate_test") - if err != nil { - t.Fatal(err) - } - return tempDir, func() { - if err = os.RemoveAll(tempDir); err != nil { - t.Fatal(err) - } - } -} - func setupMigrateInstance(tempDir string) (*Migrate, *dStub.Stub) { scheme := "stub://" m, _ := New(scheme, scheme) - m.dirtyStateConf = &dirtyStateHandler{ + m.dirtyStateConf = &dirtyStateConfig{ destScheme: scheme, destPath: tempDir, enable: true, @@ -1443,8 +1430,7 @@ func setupMigrateInstance(tempDir string) (*Migrate, *dStub.Stub) { } func TestHandleDirtyState(t *testing.T) { - tempDir, cleanup := setupTempDir(t) - defer cleanup() + tempDir := t.TempDir() m, dbDrv := setupMigrateInstance(tempDir) m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations @@ -1521,8 +1507,7 @@ func TestHandleDirtyState(t *testing.T) { } func TestHandleMigrationFailure(t *testing.T) { - tempDir, cleanup := setupTempDir(t) - defer cleanup() + tempDir := t.TempDir() m, _ := setupMigrateInstance(tempDir) @@ -1559,8 +1544,7 @@ func TestHandleMigrationFailure(t *testing.T) { } func TestCleanupFiles(t *testing.T) { - tempDir, cleanup := setupTempDir(t) - defer cleanup() + tempDir := t.TempDir() m, _ := setupMigrateInstance(tempDir) m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations @@ -1624,11 +1608,8 @@ func TestCleanupFiles(t *testing.T) { } func TestCopyFiles(t *testing.T) { - srcDir, cleanupSrc := setupTempDir(t) - defer cleanupSrc() - - destDir, cleanupDest := setupTempDir(t) - defer cleanupDest() + srcDir := t.TempDir() + destDir := t.TempDir() m, _ := setupMigrateInstance(destDir) m.dirtyStateConf.srcPath = srcDir @@ -1647,7 +1628,7 @@ func TestCopyFiles(t *testing.T) { copiedFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql", "4_name.up.sql"}, }, { - emptyDestPath: true, + emptyDestPath: true, // copyFiles should not do anything }, } @@ -1675,6 +1656,88 @@ func TestCopyFiles(t *testing.T) { } } +func TestWithDirtyStateConfig(t *testing.T) { + tests := []struct { + name string + srcPath string + destPath string + isDirty bool + wantErr bool + wantConf *dirtyStateConfig + }{ + { + name: "Valid file paths", + srcPath: "file:///src/path", + destPath: "file:///dest/path", + isDirty: true, + wantErr: false, + wantConf: &dirtyStateConfig{ + srcScheme: "file://", + destScheme: "file://", + srcPath: "/src/path", + destPath: "/dest/path", + enable: true, + }, + }, + { + name: "Invalid source scheme", + srcPath: "s3:///src/path", + destPath: "file:///dest/path", + isDirty: true, + wantErr: true, + }, + { + name: "Invalid destination scheme", + srcPath: "file:///src/path", + destPath: "s3:///dest/path", + isDirty: true, + wantErr: true, + }, + { + name: "Empty source scheme", + srcPath: "/src/path", + destPath: "file:///dest/path", + isDirty: true, + wantErr: false, + wantConf: &dirtyStateConfig{ + srcScheme: "file://", + destScheme: "file://", + srcPath: "/src/path", + destPath: "/dest/path", + enable: true, + }, + }, + { + name: "Empty destination scheme", + srcPath: "file:///src/path", + destPath: "/dest/path", + isDirty: true, + wantErr: false, + wantConf: &dirtyStateConfig{ + srcScheme: "file://", + destScheme: "file://", + srcPath: "/src/path", + destPath: "/dest/path", + enable: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &Migrate{} + err := m.WithDirtyStateConfig(tt.srcPath, tt.destPath, tt.isDirty) + if (err != nil) != tt.wantErr { + t.Errorf("error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && m.dirtyStateConf == tt.wantConf { + t.Errorf("dirtyStateConf = %v, want %v", m.dirtyStateConf, tt.wantConf) + } + }) + } +} + /* diff returns an array containing the elements in Array A and not in B */