From a0c8e0cc7c83c66c6ac374f4e1fa129bc11abf61 Mon Sep 17 00:00:00 2001 From: Eric Leijonmarck Date: Tue, 16 Jan 2024 11:38:25 +0000 Subject: [PATCH] feat: unlink --- cmd/unlink.go | 33 +++ internal/unlink/unlink.go | 216 ++++++++++++++++ internal/unlink/unlink_test.go | 437 +++++++++++++++++++++++++++++++++ 3 files changed, 686 insertions(+) create mode 100644 cmd/unlink.go create mode 100644 internal/unlink/unlink.go create mode 100644 internal/unlink/unlink_test.go diff --git a/cmd/unlink.go b/cmd/unlink.go new file mode 100644 index 000000000..3147fa461 --- /dev/null +++ b/cmd/unlink.go @@ -0,0 +1,33 @@ +package cmd + +import ( + "os" + "os/signal" + + "github.com/spf13/afero" + "github.com/spf13/cobra" + "github.com/supabase/cli/internal/unlink" +) + +var ( + unlinkCmd = &cobra.Command{ + GroupID: groupLocalDev, + Use: "unlink", + Short: "Unlink to a Supabase project", + PreRunE: func(cmd *cobra.Command, args []string) error { + return unlink.PreRun("", afero.NewOsFs()) + }, + RunE: func(cmd *cobra.Command, args []string) error { + ctx, _ := signal.NotifyContext(cmd.Context(), os.Interrupt) + fsys := afero.NewOsFs() + return unlink.Run(ctx, dbPassword, fsys) + }, + PostRunE: func(cmd *cobra.Command, args []string) error { + return unlink.PostRun("", os.Stdout, afero.NewOsFs()) + }, + } +) + +func init() { + rootCmd.AddCommand(unlinkCmd) +} diff --git a/internal/unlink/unlink.go b/internal/unlink/unlink.go new file mode 100644 index 000000000..9ca1ab1d6 --- /dev/null +++ b/internal/unlink/unlink.go @@ -0,0 +1,216 @@ +package unlink + +import ( + "context" + "fmt" + "io" + "os" + "strconv" + "strings" + "sync" + + "github.com/BurntSushi/toml" + "github.com/go-errors/errors" + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" + "github.com/spf13/afero" + "github.com/spf13/viper" + "github.com/supabase/cli/internal/migration/repair" + "github.com/supabase/cli/internal/utils" + "github.com/supabase/cli/internal/utils/credentials" + "github.com/supabase/cli/internal/utils/tenant" + "github.com/supabase/cli/pkg/api" +) + +var updatedConfig ConfigCopy + +type ConfigCopy struct { + Api interface{} `toml:"api"` + Db interface{} `toml:"db"` + Pooler interface{} `toml:"db.pooler"` +} + +func (c ConfigCopy) IsEmpty() bool { + return c.Api == nil && c.Db == nil && c.Pooler == nil +} + +func PreRun(projectRef string, fsys afero.Fs) error { + return utils.LoadConfigFS(fsys) +} + +func Run(ctx context.Context, password string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { + unlinkServices(ctx, fsys) + + // 3. Save project ref + return utils.WriteFile(utils.ProjectRefPath, []byte(""), fsys) +} + +func PostRun(projectRef string, stdout io.Writer, fsys afero.Fs) error { + fmt.Fprintln(stdout, "Finished "+utils.Aqua("supabase unlink")+".") + if updatedConfig.IsEmpty() { + return nil + } + fmt.Fprintln(os.Stderr, "Local config differs from linked project. Try updating", utils.Bold(utils.ConfigPath)) + enc := toml.NewEncoder(stdout) + enc.Indent = "" + if err := enc.Encode(updatedConfig); err != nil { + return errors.Errorf("failed to marshal toml config: %w", err) + } + return nil +} + +func unlinkServices(ctx context.Context, fsys afero.Fs) { + // Ignore non-fatal errors linking services + var wg sync.WaitGroup + wg.Add(6) + go func() { + defer wg.Done() + if err := unlinkDatabaseVersion(ctx, fsys); err != nil && viper.GetBool("DEBUG") { + fmt.Fprintln(os.Stderr, err) + } + }() + go func() { + defer wg.Done() + if err := linkPostgrest(ctx, ""); err != nil && viper.GetBool("DEBUG") { + fmt.Fprintln(os.Stderr, err) + } + }() + go func() { + defer wg.Done() + if err := unlinkPostgrestVersion(ctx, fsys); err != nil && viper.GetBool("DEBUG") { + fmt.Fprintln(os.Stderr, err) + } + }() + go func() { + defer wg.Done() + if err := unlinkGotrueVersion(ctx, fsys); err != nil && viper.GetBool("DEBUG") { + fmt.Fprintln(os.Stderr, err) + } + }() + go func() { + defer wg.Done() + if err := unlinkStorageVersion(ctx, fsys); err != nil && viper.GetBool("DEBUG") { + fmt.Fprintln(os.Stderr, err) + } + }() + go func() { + defer wg.Done() + if err := unlinkPooler(ctx, fsys); err != nil && viper.GetBool("DEBUG") { + fmt.Fprintln(os.Stderr, err) + } + }() + wg.Wait() +} + +func linkPostgrest(ctx context.Context, projectRef string) error { + resp, err := utils.GetSupabase().GetPostgRESTConfigWithResponse(ctx, projectRef) + if err != nil { + return errors.Errorf("failed to get postgrest config: %w", err) + } + if resp.JSON200 == nil { + return errors.Errorf("%w: %s", tenant.ErrAuthToken, string(resp.Body)) + } + updateApiConfig(*resp.JSON200) + return nil +} + +func unlinkPostgrestVersion(ctx context.Context, fsys afero.Fs) error { + return utils.WriteFile(utils.RestVersionPath, []byte(""), fsys) +} + +func updateApiConfig(config api.PostgrestConfigWithJWTSecretResponse) { + copy := utils.Config.Api + copy.MaxRows = uint(config.MaxRows) + copy.ExtraSearchPath = readCsv(config.DbExtraSearchPath) + copy.Schemas = readCsv(config.DbSchema) + changed := utils.Config.Api.MaxRows != copy.MaxRows || + !utils.SliceEqual(utils.Config.Api.ExtraSearchPath, copy.ExtraSearchPath) || + !utils.SliceEqual(utils.Config.Api.Schemas, copy.Schemas) + if changed { + updatedConfig.Api = copy + } +} + +func readCsv(line string) []string { + var result []string + tokens := strings.Split(line, ",") + for _, t := range tokens { + trimmed := strings.TrimSpace(t) + if len(trimmed) > 0 { + result = append(result, trimmed) + } + } + return result +} + +func unlinkGotrueVersion(ctx context.Context, fsys afero.Fs) error { + return utils.WriteFile(utils.GotrueVersionPath, []byte(""), fsys) +} + +func unlinkStorageVersion(ctx context.Context, fsys afero.Fs) error { + return utils.WriteFile(utils.StorageVersionPath, []byte(""), fsys) +} + +func unlinkDatabase(ctx context.Context, config pgconn.Config, options ...func(*pgx.ConnConfig)) error { + conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + if err != nil { + return err + } + defer conn.Close(context.Background()) + updatePostgresConfig(conn) + // If `schema_migrations` doesn't exist on the remote database, create it. + return repair.CreateMigrationTable(ctx, conn) +} + +func unlinkDatabaseVersion(ctx context.Context, fsys afero.Fs) error { + return utils.WriteFile(utils.PostgresVersionPath, []byte(""), fsys) +} + +func updatePostgresConfig(conn *pgx.Conn) { + serverVersion := conn.PgConn().ParameterStatus("server_version") + // Safe to assume that supported Postgres version is 10.0 <= n < 100.0 + majorDigits := len(serverVersion) + if majorDigits > 2 { + majorDigits = 2 + } + dbMajorVersion, err := strconv.ParseUint(serverVersion[:majorDigits], 10, 7) + // Treat error as unchanged + if err == nil && uint64(utils.Config.Db.MajorVersion) != dbMajorVersion { + copy := utils.Config.Db + copy.MajorVersion = uint(dbMajorVersion) + updatedConfig.Db = copy + } +} + +func unlinkPooler(ctx context.Context, fsys afero.Fs) error { + return utils.WriteFile(utils.PoolerUrlPath, []byte(""), fsys) +} + +func updatePoolerConfig(config api.V1PgbouncerConfigResponse) { + copy := utils.Config.Db.Pooler + if config.PoolMode != nil { + copy.PoolMode = utils.PoolMode(*config.PoolMode) + } + if config.DefaultPoolSize != nil { + copy.DefaultPoolSize = uint(*config.DefaultPoolSize) + } + if config.MaxClientConn != nil { + copy.MaxClientConn = uint(*config.MaxClientConn) + } + changed := utils.Config.Db.Pooler.PoolMode != copy.PoolMode || + utils.Config.Db.Pooler.DefaultPoolSize != copy.DefaultPoolSize || + utils.Config.Db.Pooler.MaxClientConn != copy.MaxClientConn + if changed { + updatedConfig.Pooler = copy + } +} + +func PromptPassword(stdin *os.File) string { + fmt.Fprint(os.Stderr, "Enter your database password: ") + return credentials.PromptMasked(stdin) +} + +func PromptPasswordAllowBlank(stdin *os.File) string { + fmt.Fprint(os.Stderr, "Enter your database password (or leave blank to skip): ") + return credentials.PromptMasked(stdin) +} diff --git a/internal/unlink/unlink_test.go b/internal/unlink/unlink_test.go new file mode 100644 index 000000000..0a74937cd --- /dev/null +++ b/internal/unlink/unlink_test.go @@ -0,0 +1,437 @@ +package unlink + +import ( + "context" + "errors" + "os" + "strings" + "testing" + + "github.com/jackc/pgconn" + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v4" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/supabase/cli/internal/link" + "github.com/supabase/cli/internal/migration/history" + "github.com/supabase/cli/internal/testing/apitest" + "github.com/supabase/cli/internal/testing/pgtest" + "github.com/supabase/cli/internal/utils" + "github.com/supabase/cli/internal/utils/tenant" + "github.com/supabase/cli/pkg/api" + "github.com/zalando/go-keyring" + "gopkg.in/h2non/gock.v1" +) + +var dbConfig = pgconn.Config{ + Host: "127.0.0.1", + Port: 5432, + User: "admin", + Password: "password", + Database: "postgres", +} + +func TestPreRun(t *testing.T) { + // Reset global variable + copy := utils.Config + teardown := func() { + utils.Config = copy + } + + t.Run("passes sanity check", func(t *testing.T) { + defer teardown() + project := apitest.RandomProjectRef() + // Setup in-memory fs + fsys := afero.NewMemMapFs() + require.NoError(t, utils.WriteConfig(fsys, false)) + // Run test + err := PreRun(project, fsys) + // Check error + assert.NoError(t, err) + }) + + t.Run("throws error on invalid project ref", func(t *testing.T) { + defer teardown() + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Run test + err := PreRun("malformed", fsys) + // Check error + assert.ErrorIs(t, err, utils.ErrInvalidRef) + }) + + t.Run("throws error on missing config", func(t *testing.T) { + defer teardown() + project := apitest.RandomProjectRef() + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Run test + err := PreRun(project, fsys) + // Check error + assert.ErrorIs(t, err, os.ErrNotExist) + }) +} + +// Reset global variable +func teardown() { + updatedConfig.Api = nil + updatedConfig.Db = nil + updatedConfig.Pooler = nil +} + +func TestPostRun(t *testing.T) { + t.Run("prints completion message", func(t *testing.T) { + defer teardown() + project := "test-project" + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Run test + buf := &strings.Builder{} + err := PostRun(project, buf, fsys) + // Check error + assert.NoError(t, err) + assert.Equal(t, "Finished supabase link.\n", buf.String()) + }) + + t.Run("prints changed config", func(t *testing.T) { + defer teardown() + project := "test-project" + updatedConfig.Api = "test" + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Run test + buf := &strings.Builder{} + err := PostRun(project, buf, fsys) + // Check error + assert.NoError(t, err) + assert.Contains(t, buf.String(), `api = "test"`) + }) +} + +func TestUnlinkCommand(t *testing.T) { + project := "test-project" + // Setup valid access token + token := apitest.RandomAccessToken(t) + t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) + // Mock credentials store + keyring.MockInit() + + t.Run("unlink valid project", func(t *testing.T) { + defer teardown() + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + pgtest.MockMigrationHistory(conn) + // Flush pending mocks after test execution + defer gock.OffAll() + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/api-keys"). + Reply(200). + JSON([]api.ApiKeyResponse{{Name: "anon", ApiKey: "anon-key"}}) + // Link configs + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/postgrest"). + Reply(200). + JSON(api.PostgrestConfigResponse{}) + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/config/database/pgbouncer"). + Reply(200). + JSON(api.V1PgbouncerConfigResponse{}) + // Link versions + auth := tenant.HealthResponse{Version: "v2.74.2"} + gock.New("https://" + utils.GetSupabaseHost(project)). + Get("/auth/v1/health"). + Reply(200). + JSON(auth) + rest := tenant.SwaggerResponse{Info: tenant.SwaggerInfo{Version: "11.1.0"}} + gock.New("https://" + utils.GetSupabaseHost(project)). + Get("/rest/v1/"). + Reply(200). + JSON(rest) + gock.New("https://" + utils.GetSupabaseHost(project)). + Get("/storage/v1/version"). + Reply(200). + BodyString("0.40.4") + postgres := api.DatabaseResponse{ + Host: utils.GetSupabaseDbHost(project), + Version: "15.1.0.117", + } + gock.New(utils.DefaultApiHost). + Get("/v1/projects"). + Reply(200). + JSON([]api.ProjectResponse{ + { + Id: project, + Database: &postgres, + OrganizationId: "combined-fuchsia-lion", + Name: "Test Project", + Region: "us-west-1", + CreatedAt: "2022-04-25T02:14:55.906498Z", + }, + }) + // Run link + err := link.Run(context.Background(), project, dbConfig.Password, fsys, conn.Intercept) + // Check error + assert.NoError(t, err) + assert.Empty(t, apitest.ListUnmatchedRequests()) + // Run unlink test + err = Run(context.Background(), dbConfig.Password, fsys, conn.Intercept) + // Check error + assert.NoError(t, err) + assert.Empty(t, apitest.ListUnmatchedRequests()) + // Validate file contents + content, err := afero.ReadFile(fsys, utils.ProjectRefPath) + assert.NoError(t, err) + assert.Equal(t, []byte(project), content) + restVersion, err := afero.ReadFile(fsys, utils.RestVersionPath) + assert.NoError(t, err) + assert.Equal(t, []byte("v"+rest.Info.Version), restVersion) + authVersion, err := afero.ReadFile(fsys, utils.GotrueVersionPath) + assert.NoError(t, err) + assert.Equal(t, []byte(auth.Version), authVersion) + postgresVersion, err := afero.ReadFile(fsys, utils.PostgresVersionPath) + assert.NoError(t, err) + assert.Equal(t, []byte(postgres.Version), postgresVersion) + }) + + t.Run("throws error on network failure", func(t *testing.T) { + t.Skip() + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Flush pending mocks after test execution + defer gock.OffAll() + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/api-keys"). + ReplyError(errors.New("network error")) + // Run test + err := Run(context.Background(), dbConfig.Password, fsys) + // Check error + assert.ErrorContains(t, err, "network error") + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) + + t.Run("ignores error linking services", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Flush pending mocks after test execution + defer gock.OffAll() + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/api-keys"). + Reply(200). + JSON([]api.ApiKeyResponse{{Name: "anon", ApiKey: "anon-key"}}) + // Link configs + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/postgrest"). + ReplyError(errors.New("network error")) + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/config/database/pgbouncer"). + ReplyError(errors.New("network error")) + // Link versions + gock.New("https://" + utils.GetSupabaseHost(project)). + Get("/auth/v1/health"). + ReplyError(errors.New("network error")) + gock.New("https://" + utils.GetSupabaseHost(project)). + Get("/rest/v1/"). + ReplyError(errors.New("network error")) + gock.New("https://" + utils.GetSupabaseHost(project)). + Get("/storage/v1/version"). + ReplyError(errors.New("network error")) + gock.New(utils.DefaultApiHost). + Get("/v1/projects"). + ReplyError(errors.New("network error")) + // Run test + err := Run(context.Background(), dbConfig.Password, fsys, func(cc *pgx.ConnConfig) { + cc.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + return nil, errors.New("hostname resolving error") + } + }) + // Check error + assert.ErrorContains(t, err, "hostname resolving error") + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) + + t.Run("throws error on write failure", func(t *testing.T) { + defer teardown() + // Setup in-memory fs + fsys := afero.NewReadOnlyFs(afero.NewMemMapFs()) + // Flush pending mocks after test execution + defer gock.OffAll() + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/api-keys"). + Reply(200). + JSON([]api.ApiKeyResponse{{Name: "anon", ApiKey: "anon-key"}}) + // Link configs + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/postgrest"). + ReplyError(errors.New("network error")) + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/config/database/pgbouncer"). + ReplyError(errors.New("network error")) + // Link versions + gock.New("https://" + utils.GetSupabaseHost(project)). + Get("/auth/v1/health"). + ReplyError(errors.New("network error")) + gock.New("https://" + utils.GetSupabaseHost(project)). + Get("/rest/v1/"). + ReplyError(errors.New("network error")) + gock.New("https://" + utils.GetSupabaseHost(project)). + Get("/storage/v1/version"). + ReplyError(errors.New("network error")) + gock.New(utils.DefaultApiHost). + Get("/v1/projects"). + ReplyError(errors.New("network error")) + // Run test + err := Run(context.Background(), project, fsys) + // Check error + assert.ErrorContains(t, err, "operation not permitted") + assert.Empty(t, apitest.ListUnmatchedRequests()) + // Validate file contents + exists, err := afero.Exists(fsys, utils.ProjectRefPath) + assert.NoError(t, err) + assert.False(t, exists) + }) +} + +func TestLinkPostgrest(t *testing.T) { + project := "test-project" + // Setup valid access token + token := apitest.RandomAccessToken(t) + t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) + + t.Run("ignores matching config", func(t *testing.T) { + defer teardown() + // Flush pending mocks after test execution + defer gock.OffAll() + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/postgrest"). + Reply(200). + JSON(api.PostgrestConfigResponse{}) + // Run test + err := linkPostgrest(context.Background(), project) + // Check error + assert.NoError(t, err) + assert.Empty(t, apitest.ListUnmatchedRequests()) + assert.Empty(t, updatedConfig) + }) + + t.Run("updates api on newer config", func(t *testing.T) { + defer teardown() + // Flush pending mocks after test execution + defer gock.OffAll() + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/postgrest"). + Reply(200). + JSON(api.PostgrestConfigResponse{ + DbSchema: "public, storage, graphql_public", + DbExtraSearchPath: "public, extensions", + MaxRows: 1000, + }) + // Run test + err := linkPostgrest(context.Background(), project) + // Check error + assert.NoError(t, err) + assert.Empty(t, apitest.ListUnmatchedRequests()) + utils.Config.Api.Schemas = []string{"public", "storage", "graphql_public"} + utils.Config.Api.ExtraSearchPath = []string{"public", "extensions"} + utils.Config.Api.MaxRows = 1000 + assert.Equal(t, ConfigCopy{ + Api: utils.Config.Api, + }, updatedConfig) + }) + + t.Run("throws error on network failure", func(t *testing.T) { + defer teardown() + // Flush pending mocks after test execution + defer gock.OffAll() + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/postgrest"). + ReplyError(errors.New("network error")) + // Run test + err := linkPostgrest(context.Background(), project) + // Validate api + assert.ErrorContains(t, err, "network error") + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) + + t.Run("throws error on server unavailable", func(t *testing.T) { + defer teardown() + // Flush pending mocks after test execution + defer gock.OffAll() + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + project + "/postgrest"). + Reply(500). + JSON(map[string]string{"message": "unavailable"}) + // Run test + err := linkPostgrest(context.Background(), project) + // Validate api + assert.ErrorIs(t, err, tenant.ErrAuthToken) + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) +} + +func TestLinkDatabase(t *testing.T) { + t.Run("throws error on connect failure", func(t *testing.T) { + defer teardown() + // Run test + err := unlinkDatabase(context.Background(), pgconn.Config{}) + // Check error + assert.ErrorContains(t, err, "invalid port (outside range)") + assert.Empty(t, updatedConfig) + }) + + t.Run("ignores missing server version", func(t *testing.T) { + defer teardown() + // Setup mock postgres + conn := pgtest.NewWithStatus(map[string]string{ + "standard_conforming_strings": "on", + }) + defer conn.Close(t) + pgtest.MockMigrationHistory(conn) + // Run test + err := unlinkDatabase(context.Background(), dbConfig, conn.Intercept) + // Check error + assert.NoError(t, err) + assert.Empty(t, updatedConfig) + }) + + t.Run("updates config to newer db version", func(t *testing.T) { + defer teardown() + utils.Config.Db.MajorVersion = 14 + // Setup mock postgres + conn := pgtest.NewWithStatus(map[string]string{ + "standard_conforming_strings": "on", + "server_version": "15.0", + }) + defer conn.Close(t) + pgtest.MockMigrationHistory(conn) + // Run test + err := unlinkDatabase(context.Background(), dbConfig, conn.Intercept) + // Check error + assert.NoError(t, err) + utils.Config.Db.MajorVersion = 15 + assert.Equal(t, ConfigCopy{ + Db: utils.Config.Db, + }, updatedConfig) + }) + + t.Run("throws error on query failure", func(t *testing.T) { + defer teardown() + utils.Config.Db.MajorVersion = 14 + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + conn.Query(history.CREATE_VERSION_SCHEMA). + Reply("CREATE SCHEMA"). + Query(history.CREATE_VERSION_TABLE). + ReplyError(pgerrcode.InsufficientPrivilege, "permission denied for relation supabase_migrations"). + Query(history.ADD_STATEMENTS_COLUMN). + Query(history.ADD_NAME_COLUMN) + // Run test + err := unlinkDatabase(context.Background(), dbConfig, conn.Intercept) + // Check error + assert.ErrorContains(t, err, "ERROR: permission denied for relation supabase_migrations (SQLSTATE 42501)") + }) +}