diff --git a/cmd/db.go b/cmd/db.go index 9ce00a41b..89bcedd7e 100644 --- a/cmd/db.go +++ b/cmd/db.go @@ -132,10 +132,14 @@ var ( } dbPullCmd = &cobra.Command{ - Use: "pull", + Use: "pull [migration name]", Short: "Pull schema from the remote database", RunE: func(cmd *cobra.Command, args []string) error { - return pull.Run(cmd.Context(), schema, flags.DbConfig, afero.NewOsFs()) + name := "remote_schema" + if len(args) > 0 { + name = args[0] + } + return pull.Run(cmd.Context(), schema, flags.DbConfig, name, afero.NewOsFs()) }, PostRun: func(cmd *cobra.Command, args []string) { fmt.Println("Finished " + utils.Aqua("supabase db pull") + ".") diff --git a/internal/db/diff/diff.go b/internal/db/diff/diff.go index c60d571a9..31b2f5207 100644 --- a/internal/db/diff/diff.go +++ b/internal/db/diff/diff.go @@ -20,7 +20,7 @@ func SaveDiff(out, file string, fsys afero.Fs) error { if len(out) < 2 { fmt.Fprintln(os.Stderr, "No schema changes found") } else if len(file) > 0 { - path := new.GetMigrationPath(file) + path := new.GetMigrationPath(utils.GetCurrentTimestamp(), file) if err := afero.WriteFile(fsys, path, []byte(out), 0644); err != nil { return err } diff --git a/internal/db/pull/pull.go b/internal/db/pull/pull.go index 74a587329..dad92c3c7 100644 --- a/internal/db/pull/pull.go +++ b/internal/db/pull/pull.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "os" - "path/filepath" "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" @@ -15,6 +14,7 @@ import ( "github.com/supabase/cli/internal/db/diff" "github.com/supabase/cli/internal/db/dump" "github.com/supabase/cli/internal/migration/list" + "github.com/supabase/cli/internal/migration/new" "github.com/supabase/cli/internal/migration/repair" "github.com/supabase/cli/internal/utils" ) @@ -28,7 +28,7 @@ var ( errInSync = errors.New("no schema changes found") ) -func Run(ctx context.Context, schema []string, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { +func Run(ctx context.Context, schema []string, config pgconn.Config, name string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { // 1. Sanity checks. if err := utils.AssertDockerIsRunning(ctx); err != nil { return err @@ -45,7 +45,7 @@ func Run(ctx context.Context, schema []string, config pgconn.Config, fsys afero. defer conn.Close(context.Background()) // 3. Pull schema timestamp := utils.GetCurrentTimestamp() - path := filepath.Join(utils.MigrationsDir, timestamp+"_remote_schema.sql") + path := new.GetMigrationPath(timestamp, name) if err := utils.RunProgram(ctx, func(p utils.Program, ctx context.Context) error { return run(p, ctx, schema, path, conn, fsys) }); err != nil { diff --git a/internal/db/pull/pull_test.go b/internal/db/pull/pull_test.go index 6a128bf19..2aba1615d 100644 --- a/internal/db/pull/pull_test.go +++ b/internal/db/pull/pull_test.go @@ -44,7 +44,7 @@ func TestPullCommand(t *testing.T) { Get("/_ping"). ReplyError(errors.New("network error")) // Run test - err := Run(context.Background(), nil, pgconn.Config{}, fsys) + err := Run(context.Background(), nil, pgconn.Config{}, "", fsys) // Check error assert.ErrorContains(t, err, "network error") assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -67,7 +67,7 @@ func TestPullCommand(t *testing.T) { SetHeader("API-Version", utils.Docker.ClientVersion()). SetHeader("OSType", "linux") // Run test - err := Run(context.Background(), nil, pgconn.Config{}, fsys) + err := Run(context.Background(), nil, pgconn.Config{}, "", fsys) // Check error assert.ErrorIs(t, err, os.ErrNotExist) assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -91,7 +91,7 @@ func TestPullCommand(t *testing.T) { SetHeader("API-Version", utils.Docker.ClientVersion()). SetHeader("OSType", "linux") // Run test - err := Run(context.Background(), nil, pgconn.Config{}, fsys) + err := Run(context.Background(), nil, pgconn.Config{}, "", fsys) // Check error assert.ErrorContains(t, err, "invalid port (outside range)") assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -120,7 +120,7 @@ func TestPullCommand(t *testing.T) { conn.Query(list.LIST_MIGRATION_VERSION). ReplyError(pgerrcode.InvalidCatalogName, `database "postgres" does not exist`) // Run test - err := Run(context.Background(), nil, dbConfig, fsys, conn.Intercept) + err := Run(context.Background(), nil, dbConfig, "", fsys, conn.Intercept) // Check error assert.ErrorContains(t, err, `ERROR: database "postgres" does not exist (SQLSTATE 3D000)`) assert.Empty(t, apitest.ListUnmatchedRequests()) diff --git a/internal/migration/new/new.go b/internal/migration/new/new.go index 0e14b9c50..9c93cfe82 100644 --- a/internal/migration/new/new.go +++ b/internal/migration/new/new.go @@ -11,7 +11,7 @@ import ( ) func Run(migrationName string, stdin afero.File, fsys afero.Fs) error { - path := GetMigrationPath(migrationName) + path := GetMigrationPath(utils.GetCurrentTimestamp(), migrationName) if err := utils.MkdirIfNotExistFS(fsys, filepath.Dir(path)); err != nil { return err } @@ -34,7 +34,7 @@ func Run(migrationName string, stdin afero.File, fsys afero.Fs) error { return nil } -func GetMigrationPath(migrationName string) string { - name := utils.GetCurrentTimestamp() + "_" + migrationName + ".sql" - return filepath.Join(utils.MigrationsDir, name) +func GetMigrationPath(timestamp, name string) string { + fullName := fmt.Sprintf("%s_%s.sql", timestamp, name) + return filepath.Join(utils.MigrationsDir, fullName) }