From bd43a42005ebdb5991a08da795f8b709e17f67b0 Mon Sep 17 00:00:00 2001 From: Han Qiao Date: Tue, 20 Feb 2024 19:10:30 +0800 Subject: [PATCH] fix: load db config after fetching pooler url (#1968) --- cmd/link.go | 3 +-- internal/link/link.go | 10 +++++--- internal/link/link_test.go | 50 +++++++++++++++++++++++--------------- 3 files changed, 38 insertions(+), 25 deletions(-) diff --git a/cmd/link.go b/cmd/link.go index 91e5d24a5..facaa5320 100644 --- a/cmd/link.go +++ b/cmd/link.go @@ -34,8 +34,7 @@ var ( if err := utils.LoadConfigFS(fsys); err != nil { return err } - config := flags.GetDbConfigOptionalPassword(flags.ProjectRef) - return link.Run(ctx, flags.ProjectRef, config, fsys) + return link.Run(ctx, flags.ProjectRef, fsys) }, PostRunE: func(cmd *cobra.Command, args []string) error { return link.PostRun(flags.ProjectRef, os.Stdout, afero.NewOsFs()) diff --git a/internal/link/link.go b/internal/link/link.go index 9a4f3a8d8..f21079e25 100644 --- a/internal/link/link.go +++ b/internal/link/link.go @@ -18,6 +18,7 @@ import ( "github.com/supabase/cli/internal/migration/repair" "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/credentials" + "github.com/supabase/cli/internal/utils/flags" "github.com/supabase/cli/internal/utils/tenant" "github.com/supabase/cli/pkg/api" ) @@ -34,7 +35,7 @@ func (c ConfigCopy) IsEmpty() bool { return c.Api == nil && c.Db == nil && c.Pooler == nil } -func Run(ctx context.Context, projectRef string, dbConfig pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { +func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { // 1. Check service config if _, err := tenant.GetApiKeys(ctx, projectRef); err != nil { return err @@ -42,12 +43,13 @@ func Run(ctx context.Context, projectRef string, dbConfig pgconn.Config, fsys af linkServices(ctx, projectRef, fsys) // 2. Check database connection - if len(dbConfig.Password) > 0 { - if err := linkDatabase(ctx, dbConfig, options...); err != nil { + config := flags.GetDbConfigOptionalPassword(projectRef) + if len(config.Password) > 0 { + if err := linkDatabase(ctx, config, options...); err != nil { return err } // Save database password - if err := credentials.Set(projectRef, dbConfig.Password); err != nil { + if err := credentials.Set(projectRef, config.Password); err != nil { fmt.Fprintln(os.Stderr, "Failed to save database password:", err) } } diff --git a/internal/link/link_test.go b/internal/link/link_test.go index 9dc42cd2d..81cf643e9 100644 --- a/internal/link/link_test.go +++ b/internal/link/link_test.go @@ -3,6 +3,7 @@ package link import ( "context" "errors" + "os" "strings" "testing" @@ -11,6 +12,7 @@ import ( "github.com/jackc/pgx/v4" "github.com/spf13/afero" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/supabase/cli/internal/migration/history" "github.com/supabase/cli/internal/testing/apitest" "github.com/supabase/cli/internal/testing/pgtest" @@ -75,6 +77,19 @@ func TestLinkCommand(t *testing.T) { t.Run("link valid project", func(t *testing.T) { defer teardown() + // Change stdin to read from a file + stdin, err := os.CreateTemp("", "") + require.NoError(t, err) + defer os.Remove(stdin.Name()) + + _, err = stdin.Write([]byte{'\n'}) + require.NoError(t, err) + _, err = stdin.Seek(0, 0) + require.NoError(t, err) + + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + os.Stdin = stdin // Setup in-memory fs fsys := afero.NewMemMapFs() // Setup mock postgres @@ -129,7 +144,7 @@ func TestLinkCommand(t *testing.T) { }, }) // Run test - err := Run(context.Background(), project, dbConfig, fsys, conn.Intercept) + err = Run(context.Background(), project, fsys, conn.Intercept) // Check error assert.NoError(t, err) assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -148,23 +163,20 @@ func TestLinkCommand(t *testing.T) { assert.Equal(t, []byte(postgres.Version), postgresVersion) }) - t.Run("throws error on network failure", func(t *testing.T) { - t.Skip() - // Setup in-memory fs - fsys := afero.NewMemMapFs() - // Flush pending mocks after test execution - defer gock.OffAll() - gock.New(utils.DefaultApiHost). - Get("/v1/projects/" + project + "/api-keys"). - ReplyError(errors.New("network error")) - // Run test - err := Run(context.Background(), project, dbConfig, fsys) - // Check error - assert.ErrorContains(t, err, "network error") - assert.Empty(t, apitest.ListUnmatchedRequests()) - }) - t.Run("ignores error linking services", func(t *testing.T) { + // Change stdin to read from a file + stdin, err := os.CreateTemp("", "") + require.NoError(t, err) + defer os.Remove(stdin.Name()) + + _, err = stdin.Write([]byte{'\n'}) + require.NoError(t, err) + _, err = stdin.Seek(0, 0) + require.NoError(t, err) + + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + os.Stdin = stdin // Setup in-memory fs fsys := afero.NewMemMapFs() // Flush pending mocks after test execution @@ -194,7 +206,7 @@ func TestLinkCommand(t *testing.T) { Get("/v1/projects"). ReplyError(errors.New("network error")) // Run test - err := Run(context.Background(), project, dbConfig, fsys, func(cc *pgx.ConnConfig) { + err = Run(context.Background(), project, fsys, func(cc *pgx.ConnConfig) { cc.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { return nil, errors.New("hostname resolving error") } @@ -235,7 +247,7 @@ func TestLinkCommand(t *testing.T) { Get("/v1/projects"). ReplyError(errors.New("network error")) // Run test - err := Run(context.Background(), project, pgconn.Config{}, fsys) + err := Run(context.Background(), project, fsys) // Check error assert.ErrorContains(t, err, "operation not permitted") assert.Empty(t, apitest.ListUnmatchedRequests())