Skip to content

Commit

Permalink
fix: load db config after fetching pooler url (#1968)
Browse files Browse the repository at this point in the history
  • Loading branch information
sweatybridge authored Feb 20, 2024
1 parent 4ce8038 commit bd43a42
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 25 deletions.
3 changes: 1 addition & 2 deletions cmd/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ var (
if err := utils.LoadConfigFS(fsys); err != nil {
return err
}
config := flags.GetDbConfigOptionalPassword(flags.ProjectRef)
return link.Run(ctx, flags.ProjectRef, config, fsys)
return link.Run(ctx, flags.ProjectRef, fsys)
},
PostRunE: func(cmd *cobra.Command, args []string) error {
return link.PostRun(flags.ProjectRef, os.Stdout, afero.NewOsFs())
Expand Down
10 changes: 6 additions & 4 deletions internal/link/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/supabase/cli/internal/migration/repair"
"github.com/supabase/cli/internal/utils"
"github.com/supabase/cli/internal/utils/credentials"
"github.com/supabase/cli/internal/utils/flags"
"github.com/supabase/cli/internal/utils/tenant"
"github.com/supabase/cli/pkg/api"
)
Expand All @@ -34,20 +35,21 @@ func (c ConfigCopy) IsEmpty() bool {
return c.Api == nil && c.Db == nil && c.Pooler == nil
}

func Run(ctx context.Context, projectRef string, dbConfig pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
// 1. Check service config
if _, err := tenant.GetApiKeys(ctx, projectRef); err != nil {
return err
}
linkServices(ctx, projectRef, fsys)

// 2. Check database connection
if len(dbConfig.Password) > 0 {
if err := linkDatabase(ctx, dbConfig, options...); err != nil {
config := flags.GetDbConfigOptionalPassword(projectRef)
if len(config.Password) > 0 {
if err := linkDatabase(ctx, config, options...); err != nil {
return err
}
// Save database password
if err := credentials.Set(projectRef, dbConfig.Password); err != nil {
if err := credentials.Set(projectRef, config.Password); err != nil {
fmt.Fprintln(os.Stderr, "Failed to save database password:", err)
}
}
Expand Down
50 changes: 31 additions & 19 deletions internal/link/link_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package link
import (
"context"
"errors"
"os"
"strings"
"testing"

Expand All @@ -11,6 +12,7 @@ import (
"github.com/jackc/pgx/v4"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/supabase/cli/internal/migration/history"
"github.com/supabase/cli/internal/testing/apitest"
"github.com/supabase/cli/internal/testing/pgtest"
Expand Down Expand Up @@ -75,6 +77,19 @@ func TestLinkCommand(t *testing.T) {

t.Run("link valid project", func(t *testing.T) {
defer teardown()
// Change stdin to read from a file
stdin, err := os.CreateTemp("", "")
require.NoError(t, err)
defer os.Remove(stdin.Name())

_, err = stdin.Write([]byte{'\n'})
require.NoError(t, err)
_, err = stdin.Seek(0, 0)
require.NoError(t, err)

oldStdin := os.Stdin
defer func() { os.Stdin = oldStdin }()
os.Stdin = stdin
// Setup in-memory fs
fsys := afero.NewMemMapFs()
// Setup mock postgres
Expand Down Expand Up @@ -129,7 +144,7 @@ func TestLinkCommand(t *testing.T) {
},
})
// Run test
err := Run(context.Background(), project, dbConfig, fsys, conn.Intercept)
err = Run(context.Background(), project, fsys, conn.Intercept)
// Check error
assert.NoError(t, err)
assert.Empty(t, apitest.ListUnmatchedRequests())
Expand All @@ -148,23 +163,20 @@ func TestLinkCommand(t *testing.T) {
assert.Equal(t, []byte(postgres.Version), postgresVersion)
})

t.Run("throws error on network failure", func(t *testing.T) {
t.Skip()
// Setup in-memory fs
fsys := afero.NewMemMapFs()
// Flush pending mocks after test execution
defer gock.OffAll()
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + project + "/api-keys").
ReplyError(errors.New("network error"))
// Run test
err := Run(context.Background(), project, dbConfig, fsys)
// Check error
assert.ErrorContains(t, err, "network error")
assert.Empty(t, apitest.ListUnmatchedRequests())
})

t.Run("ignores error linking services", func(t *testing.T) {
// Change stdin to read from a file
stdin, err := os.CreateTemp("", "")
require.NoError(t, err)
defer os.Remove(stdin.Name())

_, err = stdin.Write([]byte{'\n'})
require.NoError(t, err)
_, err = stdin.Seek(0, 0)
require.NoError(t, err)

oldStdin := os.Stdin
defer func() { os.Stdin = oldStdin }()
os.Stdin = stdin
// Setup in-memory fs
fsys := afero.NewMemMapFs()
// Flush pending mocks after test execution
Expand Down Expand Up @@ -194,7 +206,7 @@ func TestLinkCommand(t *testing.T) {
Get("/v1/projects").
ReplyError(errors.New("network error"))
// Run test
err := Run(context.Background(), project, dbConfig, fsys, func(cc *pgx.ConnConfig) {
err = Run(context.Background(), project, fsys, func(cc *pgx.ConnConfig) {
cc.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) {
return nil, errors.New("hostname resolving error")
}
Expand Down Expand Up @@ -235,7 +247,7 @@ func TestLinkCommand(t *testing.T) {
Get("/v1/projects").
ReplyError(errors.New("network error"))
// Run test
err := Run(context.Background(), project, pgconn.Config{}, fsys)
err := Run(context.Background(), project, fsys)
// Check error
assert.ErrorContains(t, err, "operation not permitted")
assert.Empty(t, apitest.ListUnmatchedRequests())
Expand Down

0 comments on commit bd43a42

Please sign in to comment.