From 7f64b4ed404b9ebdb5da3e283a0e1c5f2d895788 Mon Sep 17 00:00:00 2001 From: Han Qiao Date: Thu, 26 Oct 2023 14:55:19 +0800 Subject: [PATCH] fix: refactor connect by config to identify localhost (#1505) * chore: refactor connect by config to identify localhost * chore: update unit tests --- cmd/migration.go | 2 +- internal/db/diff/migra.go | 19 +++--------------- internal/db/diff/migra_test.go | 17 ---------------- internal/db/lint/lint.go | 17 +--------------- internal/db/pull/pull.go | 3 +-- internal/db/pull/pull_test.go | 2 +- internal/db/push/push.go | 2 +- internal/db/push/push_test.go | 8 +++++++- internal/db/reset/reset.go | 2 +- internal/inspect/bloat/bloat.go | 2 +- internal/inspect/blocking/blocking.go | 2 +- internal/inspect/cache/cache.go | 2 +- internal/inspect/calls/calls.go | 2 +- internal/inspect/index_sizes/index_sizes.go | 2 +- internal/inspect/index_usage/index_usage.go | 2 +- internal/inspect/locks/locks.go | 2 +- .../long_running_queries.go | 2 +- internal/inspect/outliers/outliers.go | 2 +- .../replication_slots/replication_slots.go | 2 +- .../role_connections/role_connections.go | 2 +- internal/inspect/seq_scans/seq_scans.go | 2 +- .../table_index_sizes/table_index_sizes.go | 2 +- .../table_record_counts.go | 2 +- internal/inspect/table_sizes/table_sizes.go | 2 +- .../total_index_size/total_index_size.go | 2 +- .../total_table_sizes/total_table_sizes.go | 2 +- .../inspect/unused_indexes/unused_indexes.go | 2 +- internal/inspect/vacuum_stats/vacuum_stats.go | 2 +- internal/migration/list/list.go | 2 +- internal/migration/repair/repair.go | 2 +- internal/migration/squash/squash.go | 4 ++-- internal/migration/squash/squash_test.go | 4 ++-- internal/migration/up/up.go | 4 ++-- internal/utils/connect.go | 20 +++++++++++++++++++ 34 files changed, 63 insertions(+), 83 deletions(-) diff --git a/cmd/migration.go b/cmd/migration.go index 35d588ff9..b5eb04dcd 100644 --- a/cmd/migration.go +++ b/cmd/migration.go @@ -79,7 +79,7 @@ var ( Use: "up", Short: "Apply pending migrations to local database", RunE: func(cmd *cobra.Command, args []string) error { - return up.Run(cmd.Context(), includeAll, afero.NewOsFs()) + return up.Run(cmd.Context(), includeAll, flags.DbConfig, afero.NewOsFs()) }, PostRun: func(cmd *cobra.Command, args []string) { fmt.Println("Local database is up to date.") diff --git a/internal/db/diff/migra.go b/internal/db/diff/migra.go index 176bf4f91..c497acfab 100644 --- a/internal/db/diff/migra.go +++ b/internal/db/diff/migra.go @@ -34,14 +34,6 @@ func RunMigra(ctx context.Context, schema []string, file string, config pgconn.C if err := utils.LoadConfigFS(fsys); err != nil { return err } - if config.Host != "127.0.0.1" { - fmt.Fprintln(os.Stderr, "Connecting to remote database...") - } else { - fmt.Fprintln(os.Stderr, "Connecting to local database...") - if err := utils.AssertSupabaseDbIsRunning(); err != nil { - return err - } - } // 1. Load all user defined schemas if len(schema) == 0 { schema, err = loadSchema(ctx, config, options...) @@ -59,15 +51,10 @@ func RunMigra(ctx context.Context, schema []string, file string, config pgconn.C return SaveDiff(out, file, fsys) } -func loadSchema(ctx context.Context, config pgconn.Config, options ...func(*pgx.ConnConfig)) (schema []string, err error) { - var conn *pgx.Conn - if config.Host == "127.0.0.1" && config.Port == uint16(utils.Config.Db.Port) { - conn, err = utils.ConnectLocalPostgres(ctx, config, options...) - } else { - conn, err = utils.ConnectRemotePostgres(ctx, config, options...) - } +func loadSchema(ctx context.Context, config pgconn.Config, options ...func(*pgx.ConnConfig)) ([]string, error) { + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { - return schema, err + return nil, err } defer conn.Close(context.Background()) return LoadUserSchemas(ctx, conn) diff --git a/internal/db/diff/migra_test.go b/internal/db/diff/migra_test.go index 566f26549..387f862eb 100644 --- a/internal/db/diff/migra_test.go +++ b/internal/db/diff/migra_test.go @@ -92,23 +92,6 @@ func TestRunMigra(t *testing.T) { assert.ErrorIs(t, err, os.ErrNotExist) }) - t.Run("throws error on missing database", func(t *testing.T) { - // Setup in-memory fs - fsys := afero.NewMemMapFs() - require.NoError(t, utils.WriteConfig(fsys, false)) - // Setup mock docker - require.NoError(t, apitest.MockDocker(utils.Docker)) - defer gock.OffAll() - gock.New(utils.Docker.DaemonHost()). - Get("/v" + utils.Docker.ClientVersion() + "/containers/supabase_db_"). - ReplyError(errors.New("network error")) - // Run test - err := RunMigra(context.Background(), []string{"public"}, "", pgconn.Config{Host: "127.0.0.1"}, fsys) - // Check error - assert.ErrorIs(t, err, utils.ErrNotRunning) - assert.Empty(t, apitest.ListUnmatchedRequests()) - }) - t.Run("throws error on failure to load user schemas", func(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() diff --git a/internal/db/lint/lint.go b/internal/db/lint/lint.go index 91a4c36ea..8f169570e 100644 --- a/internal/db/lint/lint.go +++ b/internal/db/lint/lint.go @@ -40,7 +40,7 @@ func toEnum(level string) LintLevel { func Run(ctx context.Context, schema []string, level string, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { // Sanity checks. - conn, err := connect(ctx, config, fsys, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } @@ -57,21 +57,6 @@ func Run(ctx context.Context, schema []string, level string, config pgconn.Confi return printResultJSON(result, toEnum(level), os.Stdout) } -func connect(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) (*pgx.Conn, error) { - if config.Host != "127.0.0.1" { - fmt.Fprintln(os.Stderr, "Connecting to remote database...") - return utils.ConnectRemotePostgres(ctx, config, options...) - } - fmt.Fprintln(os.Stderr, "Connecting to local database...") - if err := utils.LoadConfigFS(fsys); err != nil { - return nil, err - } - if err := utils.AssertSupabaseDbIsRunning(); err != nil { - return nil, err - } - return utils.ConnectLocalPostgres(ctx, pgconn.Config{}, options...) -} - func filterResult(result []Result, minLevel LintLevel) (filtered []Result) { for _, r := range result { out := Result{Function: r.Function} diff --git a/internal/db/pull/pull.go b/internal/db/pull/pull.go index dad92c3c7..9285c5a2c 100644 --- a/internal/db/pull/pull.go +++ b/internal/db/pull/pull.go @@ -37,8 +37,7 @@ func Run(ctx context.Context, schema []string, config pgconn.Config, name string return err } // 2. Check postgres connection - fmt.Fprintln(os.Stderr, "Connecting to remote database...") - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/db/pull/pull_test.go b/internal/db/pull/pull_test.go index 2aba1615d..cfd068919 100644 --- a/internal/db/pull/pull_test.go +++ b/internal/db/pull/pull_test.go @@ -23,7 +23,7 @@ import ( ) var dbConfig = pgconn.Config{ - Host: "127.0.0.1", + Host: "db.supabase.co", Port: 5432, User: "admin", Password: "password", diff --git a/internal/db/push/push.go b/internal/db/push/push.go index ef2f6c08b..4bd2ef141 100644 --- a/internal/db/push/push.go +++ b/internal/db/push/push.go @@ -19,7 +19,7 @@ func Run(ctx context.Context, dryRun, ignoreVersionMismatch bool, includeRoles, if dryRun { fmt.Fprintln(os.Stderr, "DRY RUN: migrations will *not* be pushed to the database.") } - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/db/push/push_test.go b/internal/db/push/push_test.go index 78123d7cd..b8b843b8e 100644 --- a/internal/db/push/push_test.go +++ b/internal/db/push/push_test.go @@ -73,7 +73,13 @@ func TestMigrationPush(t *testing.T) { conn.Query(list.LIST_MIGRATION_VERSION). ReplyError(pgerrcode.InvalidCatalogName, `database "target" does not exist`) // Run test - err := Run(context.Background(), false, false, false, false, dbConfig, fsys, conn.Intercept) + err := Run(context.Background(), false, false, false, false, pgconn.Config{ + Host: "db.supabase.co", + Port: 5432, + User: "admin", + Password: "password", + Database: "postgres", + }, fsys, conn.Intercept) // Check error assert.ErrorContains(t, err, `ERROR: database "target" does not exist (SQLSTATE 3D000)`) }) diff --git a/internal/db/reset/reset.go b/internal/db/reset/reset.go index ea543ab4d..9aa807845 100644 --- a/internal/db/reset/reset.go +++ b/internal/db/reset/reset.go @@ -49,7 +49,7 @@ func Run(ctx context.Context, version string, config pgconn.Config, fsys afero.F return err } } - if config.Host != "127.0.0.1" { + if !utils.IsLoopback(config.Host) { if shouldReset := utils.PromptYesNo("Confirm resetting the remote database?", true, os.Stdin); !shouldReset { return context.Canceled } diff --git a/internal/inspect/bloat/bloat.go b/internal/inspect/bloat/bloat.go index c0fca2ee4..3310e70ad 100644 --- a/internal/inspect/bloat/bloat.go +++ b/internal/inspect/bloat/bloat.go @@ -83,7 +83,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/blocking/blocking.go b/internal/inspect/blocking/blocking.go index 1a5297fd3..817e97b7f 100644 --- a/internal/inspect/blocking/blocking.go +++ b/internal/inspect/blocking/blocking.go @@ -42,7 +42,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/cache/cache.go b/internal/inspect/cache/cache.go index e9ce4d3ee..0674bfab1 100644 --- a/internal/inspect/cache/cache.go +++ b/internal/inspect/cache/cache.go @@ -31,7 +31,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/calls/calls.go b/internal/inspect/calls/calls.go index 8969c34df..8d79555d1 100644 --- a/internal/inspect/calls/calls.go +++ b/internal/inspect/calls/calls.go @@ -34,7 +34,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/index_sizes/index_sizes.go b/internal/inspect/index_sizes/index_sizes.go index 4cf7fd20b..b916ff689 100644 --- a/internal/inspect/index_sizes/index_sizes.go +++ b/internal/inspect/index_sizes/index_sizes.go @@ -29,7 +29,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/index_usage/index_usage.go b/internal/inspect/index_usage/index_usage.go index 3e77f8ae0..6060a6542 100644 --- a/internal/inspect/index_usage/index_usage.go +++ b/internal/inspect/index_usage/index_usage.go @@ -38,7 +38,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/locks/locks.go b/internal/inspect/locks/locks.go index 0a1e3a98c..ef8cafb2a 100644 --- a/internal/inspect/locks/locks.go +++ b/internal/inspect/locks/locks.go @@ -38,7 +38,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/long_running_queries/long_running_queries.go b/internal/inspect/long_running_queries/long_running_queries.go index 3708aba89..72272fcbb 100644 --- a/internal/inspect/long_running_queries/long_running_queries.go +++ b/internal/inspect/long_running_queries/long_running_queries.go @@ -33,7 +33,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/outliers/outliers.go b/internal/inspect/outliers/outliers.go index 09b489003..f5421e28a 100644 --- a/internal/inspect/outliers/outliers.go +++ b/internal/inspect/outliers/outliers.go @@ -34,7 +34,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/replication_slots/replication_slots.go b/internal/inspect/replication_slots/replication_slots.go index 05436e0b5..d5d5a3da0 100644 --- a/internal/inspect/replication_slots/replication_slots.go +++ b/internal/inspect/replication_slots/replication_slots.go @@ -35,7 +35,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/role_connections/role_connections.go b/internal/inspect/role_connections/role_connections.go index 413637e9f..290b9f24c 100644 --- a/internal/inspect/role_connections/role_connections.go +++ b/internal/inspect/role_connections/role_connections.go @@ -37,7 +37,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/seq_scans/seq_scans.go b/internal/inspect/seq_scans/seq_scans.go index 95df29e0f..6e361a837 100644 --- a/internal/inspect/seq_scans/seq_scans.go +++ b/internal/inspect/seq_scans/seq_scans.go @@ -25,7 +25,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/table_index_sizes/table_index_sizes.go b/internal/inspect/table_index_sizes/table_index_sizes.go index 26cac0767..c22cb1179 100644 --- a/internal/inspect/table_index_sizes/table_index_sizes.go +++ b/internal/inspect/table_index_sizes/table_index_sizes.go @@ -28,7 +28,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/table_record_counts/table_record_counts.go b/internal/inspect/table_record_counts/table_record_counts.go index fc8055f7a..116a52346 100644 --- a/internal/inspect/table_record_counts/table_record_counts.go +++ b/internal/inspect/table_record_counts/table_record_counts.go @@ -27,7 +27,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/table_sizes/table_sizes.go b/internal/inspect/table_sizes/table_sizes.go index 979ac482c..a15f34644 100644 --- a/internal/inspect/table_sizes/table_sizes.go +++ b/internal/inspect/table_sizes/table_sizes.go @@ -28,7 +28,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/total_index_size/total_index_size.go b/internal/inspect/total_index_size/total_index_size.go index 60b1f4b91..8e6bf65e2 100644 --- a/internal/inspect/total_index_size/total_index_size.go +++ b/internal/inspect/total_index_size/total_index_size.go @@ -26,7 +26,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/total_table_sizes/total_table_sizes.go b/internal/inspect/total_table_sizes/total_table_sizes.go index b429aae45..e87092fbb 100644 --- a/internal/inspect/total_table_sizes/total_table_sizes.go +++ b/internal/inspect/total_table_sizes/total_table_sizes.go @@ -29,7 +29,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/unused_indexes/unused_indexes.go b/internal/inspect/unused_indexes/unused_indexes.go index cb304268a..f6c77ef93 100644 --- a/internal/inspect/unused_indexes/unused_indexes.go +++ b/internal/inspect/unused_indexes/unused_indexes.go @@ -32,7 +32,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/inspect/vacuum_stats/vacuum_stats.go b/internal/inspect/vacuum_stats/vacuum_stats.go index 3f4246a27..db6606036 100644 --- a/internal/inspect/vacuum_stats/vacuum_stats.go +++ b/internal/inspect/vacuum_stats/vacuum_stats.go @@ -71,7 +71,7 @@ type Result struct { } func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/migration/list/list.go b/internal/migration/list/list.go index 017e2360a..0db6ffb4f 100644 --- a/internal/migration/list/list.go +++ b/internal/migration/list/list.go @@ -36,7 +36,7 @@ func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...fu } func loadRemoteVersions(ctx context.Context, config pgconn.Config, options ...func(*pgx.ConnConfig)) ([]string, error) { - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return nil, err } diff --git a/internal/migration/repair/repair.go b/internal/migration/repair/repair.go index aded29dad..59a352740 100644 --- a/internal/migration/repair/repair.go +++ b/internal/migration/repair/repair.go @@ -38,7 +38,7 @@ func Run(ctx context.Context, config pgconn.Config, version, status string, fsys if _, err := strconv.Atoi(version); err != nil { return ErrInvalidVersion } - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/migration/squash/squash.go b/internal/migration/squash/squash.go index 36fd1a79a..449b49c10 100644 --- a/internal/migration/squash/squash.go +++ b/internal/migration/squash/squash.go @@ -37,7 +37,7 @@ func Run(ctx context.Context, version string, config pgconn.Config, fsys afero.F return err } // 2. Update migration history - if len(config.Host) == 0 || !utils.PromptYesNo("Update remote migration history table?", true, os.Stdin) { + if utils.IsLoopback(config.Host) || !utils.PromptYesNo("Update remote migration history table?", true, os.Stdin) { return nil } return baselineMigrations(ctx, config, version, fsys, options...) @@ -110,7 +110,7 @@ func baselineMigrations(ctx context.Context, config pgconn.Config, version strin } } fmt.Fprintln(os.Stderr, "Baselining migration history to", version) - conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/migration/squash/squash_test.go b/internal/migration/squash/squash_test.go index 3bf8240df..661eebc88 100644 --- a/internal/migration/squash/squash_test.go +++ b/internal/migration/squash/squash_test.go @@ -22,7 +22,7 @@ import ( ) var dbConfig = pgconn.Config{ - Host: "127.0.0.1", + Host: "db.supabase.co", Port: 5432, User: "admin", Password: "password", @@ -72,7 +72,7 @@ func TestSquashCommand(t *testing.T) { Query(repair.INSERT_MIGRATION_VERSION, "1", "target", "{}"). Reply("INSERT 1") // Run test - err := Run(context.Background(), "", pgconn.Config{}, fsys, conn.Intercept) + err := Run(context.Background(), "", pgconn.Config{Host: "127.0.0.1"}, fsys, conn.Intercept) // Check error assert.NoError(t, err) assert.Empty(t, apitest.ListUnmatchedRequests()) diff --git a/internal/migration/up/up.go b/internal/migration/up/up.go index 334b1c6eb..732e560f9 100644 --- a/internal/migration/up/up.go +++ b/internal/migration/up/up.go @@ -19,11 +19,11 @@ var ( errMissingLocal = errors.New("Remote migration versions not found in " + utils.MigrationsDir + " directory.") ) -func Run(ctx context.Context, includeAll bool, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { +func Run(ctx context.Context, includeAll bool, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { if err := utils.LoadConfigFS(fsys); err != nil { return err } - conn, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{}, options...) + conn, err := utils.ConnectByConfig(ctx, config, options...) if err != nil { return err } diff --git a/internal/utils/connect.go b/internal/utils/connect.go index 492d85ce6..d9a1f5a33 100644 --- a/internal/utils/connect.go +++ b/internal/utils/connect.go @@ -7,6 +7,7 @@ import ( "net" "net/url" "os" + "strings" "time" "github.com/jackc/pgconn" @@ -104,3 +105,22 @@ func ConnectByUrl(ctx context.Context, url string, options ...func(*pgx.ConnConf // Connect to database return pgx.ConnectConfig(ctx, config) } + +func ConnectByConfig(ctx context.Context, config pgconn.Config, options ...func(*pgx.ConnConfig)) (*pgx.Conn, error) { + if IsLoopback(config.Host) { + fmt.Fprintln(os.Stderr, "Connecting to local database...") + return ConnectLocalPostgres(ctx, config, options...) + } + fmt.Fprintln(os.Stderr, "Connecting to remote database...") + return ConnectRemotePostgres(ctx, config, options...) +} + +func IsLoopback(host string) bool { + if strings.ToLower(host) == "localhost" { + return true + } + if ip := net.ParseIP(host); ip != nil { + return ip.IsLoopback() + } + return false +}