diff --git a/internal/db/diff/diff.go b/internal/db/diff/diff.go index 9f90fe3c5..7abdb6f93 100644 --- a/internal/db/diff/diff.go +++ b/internal/db/diff/diff.go @@ -5,7 +5,9 @@ import ( _ "embed" "fmt" "io" + "io/fs" "os" + "path/filepath" "regexp" "strconv" "strings" @@ -24,6 +26,7 @@ import ( "github.com/supabase/cli/internal/gen/keys" "github.com/supabase/cli/internal/migration/apply" "github.com/supabase/cli/internal/migration/list" + "github.com/supabase/cli/internal/migration/repair" "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/parser" ) @@ -35,6 +38,42 @@ func Run(ctx context.Context, schema []string, file string, config pgconn.Config if err := utils.LoadConfigFS(fsys); err != nil { return err } + if utils.IsLocalDatabase(config) { + if exists, err := afero.DirExists(fsys, utils.SchemasDir); exists { + if err := utils.AssertSupabaseDbIsRunning(); errors.Is(err, utils.ErrNotRunning) { + fmt.Fprintf(os.Stderr, "Creating local database from %s...\n", utils.Bold(utils.SchemasDir)) + var declared []string + if err := afero.Walk(fsys, utils.SchemasDir, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return errors.Errorf("failed to walk dir: %w", err) + } + if !info.Mode().IsRegular() || filepath.Ext(info.Name()) != ".sql" { + return nil + } + declared = append(declared, path) + return nil + }); err != nil { + return err + } + if len(declared) > 0 { + container, err := CreateShadowDatabase(ctx, utils.Config.Db.Port) + if err != nil { + return err + } + defer utils.DockerRemove(container) + if !start.WaitForHealthyService(ctx, container, start.HealthTimeout) { + return errors.New(start.ErrDatabase) + } + if err := MigrateBaseDatabase(ctx, container, declared, fsys, options...); err != nil { + return err + } + } + } + } else if err != nil { + logger := utils.GetDebugLogger() + fmt.Fprintln(logger, err) + } + } // 1. Load all user defined schemas if len(schema) == 0 { schema, err = loadSchema(ctx, config, options...) @@ -103,9 +142,9 @@ func LoadUserSchemas(ctx context.Context, conn *pgx.Conn) ([]string, error) { return reset.ListSchemas(ctx, conn, exclude...) } -func CreateShadowDatabase(ctx context.Context) (string, error) { +func CreateShadowDatabase(ctx context.Context, port uint16) (string, error) { config := start.NewContainerConfig() - hostPort := strconv.FormatUint(uint64(utils.Config.Db.ShadowPort), 10) + hostPort := strconv.FormatUint(uint64(port), 10) hostConfig := container.HostConfig{ PortBindings: nat.PortMap{"5432/tcp": []nat.PortBinding{{HostPort: hostPort}}}, AutoRemove: true, @@ -144,9 +183,31 @@ func MigrateShadowDatabase(ctx context.Context, container string, fsys afero.Fs, return apply.MigrateUp(ctx, conn, migrations, fsys) } +func MigrateBaseDatabase(ctx context.Context, container string, migrations []string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { + conn, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{}, options...) + if err != nil { + return err + } + defer conn.Close(context.Background()) + if err := start.SetupDatabase(ctx, conn, container[:12], os.Stderr, fsys); err != nil { + return err + } + for _, path := range migrations { + fmt.Fprintln(os.Stderr, "Applying schema "+utils.Bold(path)+"...") + migration, err := repair.NewMigrationFromFile(path, fsys) + if err != nil { + return err + } + if err := migration.ExecBatch(ctx, conn); err != nil { + return err + } + } + return nil +} + func DiffDatabase(ctx context.Context, schema []string, config pgconn.Config, w io.Writer, fsys afero.Fs, differ func(context.Context, string, string, []string) (string, error), options ...func(*pgx.ConnConfig)) (string, error) { fmt.Fprintln(w, "Creating shadow database...") - shadow, err := CreateShadowDatabase(ctx) + shadow, err := CreateShadowDatabase(ctx, utils.Config.Db.ShadowPort) if err != nil { return "", err } diff --git a/internal/db/diff/pgadmin.go b/internal/db/diff/pgadmin.go index 6c94442f4..55df354e5 100644 --- a/internal/db/diff/pgadmin.go +++ b/internal/db/diff/pgadmin.go @@ -58,7 +58,7 @@ func run(p utils.Program, ctx context.Context, schema []string, config pgconn.Co p.Send(utils.StatusMsg("Creating shadow database...")) // 1. Create shadow db and run migrations - shadow, err := CreateShadowDatabase(ctx) + shadow, err := CreateShadowDatabase(ctx, utils.Config.Db.ShadowPort) if err != nil { return err } diff --git a/internal/gen/types/typescript/typescript_test.go b/internal/gen/types/typescript/typescript_test.go index 9e4ed81e4..31fec482b 100644 --- a/internal/gen/types/typescript/typescript_test.go +++ b/internal/gen/types/typescript/typescript_test.go @@ -25,7 +25,7 @@ func TestGenLocalCommand(t *testing.T) { dbConfig := pgconn.Config{ Host: utils.Config.Hostname, - Port: uint16(utils.Config.Db.Port), + Port: utils.Config.Db.Port, User: "admin", Password: "password", } diff --git a/internal/migration/squash/squash.go b/internal/migration/squash/squash.go index 7858f8063..2719b995d 100644 --- a/internal/migration/squash/squash.go +++ b/internal/migration/squash/squash.go @@ -80,7 +80,7 @@ func squashToVersion(ctx context.Context, version string, fsys afero.Fs, options func squashMigrations(ctx context.Context, migrations []string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { // 1. Start shadow database - shadow, err := diff.CreateShadowDatabase(ctx) + shadow, err := diff.CreateShadowDatabase(ctx, utils.Config.Db.ShadowPort) if err != nil { return err } diff --git a/internal/start/start.go b/internal/start/start.go index e25ed4c08..f8273619a 100644 --- a/internal/start/start.go +++ b/internal/start/start.go @@ -103,7 +103,7 @@ type kongConfig struct { EdgeRuntimeId string LogflareId string ApiHost string - ApiPort uint + ApiPort uint16 } var ( diff --git a/internal/utils/config.go b/internal/utils/config.go index 177751696..4df67de2f 100644 --- a/internal/utils/config.go +++ b/internal/utils/config.go @@ -254,7 +254,7 @@ type ( api struct { Enabled bool `toml:"enabled"` Image string `toml:"-"` - Port uint `toml:"port"` + Port uint16 `toml:"port"` Schemas []string `toml:"schemas"` ExtraSearchPath []string `toml:"extra_search_path"` MaxRows uint `toml:"max_rows"` @@ -262,8 +262,8 @@ type ( db struct { Image string `toml:"-"` - Port uint `toml:"port"` - ShadowPort uint `toml:"shadow_port"` + Port uint16 `toml:"port"` + ShadowPort uint16 `toml:"shadow_port"` MajorVersion uint `toml:"major_version"` Password string `toml:"-"` RootKey string `toml:"-" mapstructure:"root_key"` @@ -287,16 +287,16 @@ type ( studio struct { Enabled bool `toml:"enabled"` - Port uint `toml:"port"` + Port uint16 `toml:"port"` ApiUrl string `toml:"api_url"` OpenaiApiKey string `toml:"openai_api_key"` } inbucket struct { - Enabled bool `toml:"enabled"` - Port uint `toml:"port"` - SmtpPort uint `toml:"smtp_port"` - Pop3Port uint `toml:"pop3_port"` + Enabled bool `toml:"enabled"` + Port uint16 `toml:"port"` + SmtpPort uint16 `toml:"smtp_port"` + Pop3Port uint16 `toml:"pop3_port"` } storage struct { diff --git a/internal/utils/connect.go b/internal/utils/connect.go index cf7c5d9fc..b3c44ad42 100644 --- a/internal/utils/connect.go +++ b/internal/utils/connect.go @@ -98,7 +98,7 @@ func ConnectLocalPostgres(ctx context.Context, config pgconn.Config, options ... config.Host = Config.Hostname } if config.Port == 0 { - config.Port = uint16(Config.Db.Port) + config.Port = Config.Db.Port } if len(config.User) == 0 { config.User = "postgres" @@ -155,5 +155,5 @@ func ConnectByConfig(ctx context.Context, config pgconn.Config, options ...func( } func IsLocalDatabase(config pgconn.Config) bool { - return config.Host == Config.Hostname && config.Port == uint16(Config.Db.Port) + return config.Host == Config.Hostname && config.Port == Config.Db.Port } diff --git a/internal/utils/misc.go b/internal/utils/misc.go index 8d5384761..563b7cdd7 100644 --- a/internal/utils/misc.go +++ b/internal/utils/misc.go @@ -178,6 +178,7 @@ var ( RestVersionPath = filepath.Join(TempDir, "rest-version") StorageVersionPath = filepath.Join(TempDir, "storage-version") CurrBranchPath = filepath.Join(SupabaseDirPath, ".branches", "_current_branch") + SchemasDir = filepath.Join(SupabaseDirPath, "schemas") MigrationsDir = filepath.Join(SupabaseDirPath, "migrations") FunctionsDir = filepath.Join(SupabaseDirPath, "functions") FallbackImportMapPath = filepath.Join(FunctionsDir, "import_map.json")