From 3148ebaa2e1219e537d9d79a6408ac2ff18d2abf Mon Sep 17 00:00:00 2001 From: Andrew Valleteau Date: Wed, 23 Oct 2024 11:00:30 +0200 Subject: [PATCH 1/5] feat: add db postgres settings to config (#2787) * feat: add db postgres settings to config * no comment * fix: add updater logic * chore: fix golang lint * chore: apply review comments * chore: refactor and apply PR comments * Apply suggestions from code review Co-authored-by: Han Qiao --------- Co-authored-by: Han Qiao --- pkg/cast/cast.go | 14 ++++ pkg/config/config.go | 43 ++---------- pkg/config/db.go | 160 ++++++++++++++++++++++++++++++++++++++++++ pkg/config/db_test.go | 155 ++++++++++++++++++++++++++++++++++++++++ pkg/config/updater.go | 42 ++++++++++- 5 files changed, 376 insertions(+), 38 deletions(-) create mode 100644 pkg/config/db.go create mode 100644 pkg/config/db_test.go diff --git a/pkg/cast/cast.go b/pkg/cast/cast.go index b72cadbbf..3c7067163 100644 --- a/pkg/cast/cast.go +++ b/pkg/cast/cast.go @@ -20,6 +20,20 @@ func IntToUint(value int) uint { return uint(value) } +func UintToIntPtr(value *uint) *int { + if value == nil { + return nil + } + return Ptr(UintToInt(*value)) +} + +func IntToUintPtr(value *int) *uint { + if value == nil { + return nil + } + return Ptr(IntToUint(*value)) +} + func Ptr[T any](v T) *T { return &v } diff --git a/pkg/config/config.go b/pkg/config/config.go index f6329f459..6e548111a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -57,13 +57,6 @@ const ( LogflareBigQuery LogflareBackend = "bigquery" ) -type PoolMode string - -const ( - TransactionMode PoolMode = "transaction" - SessionMode PoolMode = "session" -) - type AddressFamily string const ( @@ -146,36 +139,6 @@ type ( Remotes map[string]baseConfig `toml:"-"` } - db struct { - Image string `toml:"-"` - Port uint16 `toml:"port"` - ShadowPort uint16 `toml:"shadow_port"` - MajorVersion uint `toml:"major_version"` - Password string `toml:"-"` - RootKey string `toml:"-" mapstructure:"root_key"` - Pooler pooler `toml:"pooler"` - Seed seed `toml:"seed"` - } - - seed struct { - Enabled bool `toml:"enabled"` - GlobPatterns []string `toml:"sql_paths"` - SqlPaths []string `toml:"-"` - } - - pooler struct { - Enabled bool `toml:"enabled"` - Image string `toml:"-"` - Port uint16 `toml:"port"` - PoolMode PoolMode `toml:"pool_mode"` - DefaultPoolSize uint `toml:"default_pool_size"` - MaxClientConn uint `toml:"max_client_conn"` - ConnectionString string `toml:"-"` - TenantId string `toml:"-"` - EncryptionKey string `toml:"-"` - SecretKeyBase string `toml:"-"` - } - realtime struct { Enabled bool `toml:"enabled"` Image string `toml:"-"` @@ -775,6 +738,12 @@ func (c *baseConfig) Validate(fsys fs.FS) error { } } // Validate db config + if c.Db.Settings.SessionReplicationRole != nil { + allowedRoles := []SessionReplicationRole{SessionReplicationRoleOrigin, SessionReplicationRoleReplica, SessionReplicationRoleLocal} + if !sliceContains(allowedRoles, *c.Db.Settings.SessionReplicationRole) { + return errors.Errorf("Invalid config for db.session_replication_role: %s. Must be one of: %v", *c.Db.Settings.SessionReplicationRole, allowedRoles) + } + } if c.Db.Port == 0 { return errors.New("Missing required field in config: db.port") } diff --git a/pkg/config/db.go b/pkg/config/db.go new file mode 100644 index 000000000..89e5bfd24 --- /dev/null +++ b/pkg/config/db.go @@ -0,0 +1,160 @@ +package config + +import ( + "github.com/google/go-cmp/cmp" + v1API "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" + "github.com/supabase/cli/pkg/diff" +) + +type PoolMode string + +const ( + TransactionMode PoolMode = "transaction" + SessionMode PoolMode = "session" +) + +type SessionReplicationRole string + +const ( + SessionReplicationRoleOrigin SessionReplicationRole = "origin" + SessionReplicationRoleReplica SessionReplicationRole = "replica" + SessionReplicationRoleLocal SessionReplicationRole = "local" +) + +type ( + settings struct { + EffectiveCacheSize *string `toml:"effective_cache_size"` + LogicalDecodingWorkMem *string `toml:"logical_decoding_work_mem"` + MaintenanceWorkMem *string `toml:"maintenance_work_mem"` + MaxConnections *uint `toml:"max_connections"` + MaxLocksPerTransaction *uint `toml:"max_locks_per_transaction"` + MaxParallelMaintenanceWorkers *uint `toml:"max_parallel_maintenance_workers"` + MaxParallelWorkers *uint `toml:"max_parallel_workers"` + MaxParallelWorkersPerGather *uint `toml:"max_parallel_workers_per_gather"` + MaxReplicationSlots *uint `toml:"max_replication_slots"` + MaxSlotWalKeepSize *string `toml:"max_slot_wal_keep_size"` + MaxStandbyArchiveDelay *string `toml:"max_standby_archive_delay"` + MaxStandbyStreamingDelay *string `toml:"max_standby_streaming_delay"` + MaxWalSize *string `toml:"max_wal_size"` + MaxWalSenders *uint `toml:"max_wal_senders"` + MaxWorkerProcesses *uint `toml:"max_worker_processes"` + SessionReplicationRole *SessionReplicationRole `toml:"session_replication_role"` + SharedBuffers *string `toml:"shared_buffers"` + StatementTimeout *string `toml:"statement_timeout"` + WalKeepSize *string `toml:"wal_keep_size"` + WalSenderTimeout *string `toml:"wal_sender_timeout"` + WorkMem *string `toml:"work_mem"` + } + + db struct { + Image string `toml:"-"` + Port uint16 `toml:"port"` + ShadowPort uint16 `toml:"shadow_port"` + MajorVersion uint `toml:"major_version"` + Password string `toml:"-"` + RootKey string `toml:"-" mapstructure:"root_key"` + Pooler pooler `toml:"pooler"` + Seed seed `toml:"seed"` + Settings settings `toml:"settings"` + } + + seed struct { + Enabled bool `toml:"enabled"` + GlobPatterns []string `toml:"sql_paths"` + SqlPaths []string `toml:"-"` + } + + pooler struct { + Enabled bool `toml:"enabled"` + Image string `toml:"-"` + Port uint16 `toml:"port"` + PoolMode PoolMode `toml:"pool_mode"` + DefaultPoolSize uint `toml:"default_pool_size"` + MaxClientConn uint `toml:"max_client_conn"` + ConnectionString string `toml:"-"` + TenantId string `toml:"-"` + EncryptionKey string `toml:"-"` + SecretKeyBase string `toml:"-"` + } +) + +// Compare two db config, if changes requires restart return true, return false otherwise +func (a settings) requireDbRestart(b settings) bool { + return !cmp.Equal(a.MaxConnections, b.MaxConnections) || + !cmp.Equal(a.MaxWorkerProcesses, b.MaxWorkerProcesses) || + !cmp.Equal(a.MaxParallelWorkers, b.MaxParallelWorkers) || + !cmp.Equal(a.MaxWalSenders, b.MaxWalSenders) || + !cmp.Equal(a.MaxReplicationSlots, b.MaxReplicationSlots) || + !cmp.Equal(a.SharedBuffers, b.SharedBuffers) +} + +func (a *settings) ToUpdatePostgresConfigBody() v1API.UpdatePostgresConfigBody { + body := v1API.UpdatePostgresConfigBody{} + + // Parameters that require restart + body.MaxConnections = cast.UintToIntPtr(a.MaxConnections) + body.MaxWorkerProcesses = cast.UintToIntPtr(a.MaxWorkerProcesses) + body.MaxParallelWorkers = cast.UintToIntPtr(a.MaxParallelWorkers) + body.MaxWalSenders = cast.UintToIntPtr(a.MaxWalSenders) + body.MaxReplicationSlots = cast.UintToIntPtr(a.MaxReplicationSlots) + body.SharedBuffers = a.SharedBuffers + + // Parameters that can be changed without restart + body.EffectiveCacheSize = a.EffectiveCacheSize + body.LogicalDecodingWorkMem = a.LogicalDecodingWorkMem + body.MaintenanceWorkMem = a.MaintenanceWorkMem + body.MaxLocksPerTransaction = cast.UintToIntPtr(a.MaxLocksPerTransaction) + body.MaxParallelMaintenanceWorkers = cast.UintToIntPtr(a.MaxParallelMaintenanceWorkers) + body.MaxParallelWorkersPerGather = cast.UintToIntPtr(a.MaxParallelWorkersPerGather) + body.MaxSlotWalKeepSize = a.MaxSlotWalKeepSize + body.MaxStandbyArchiveDelay = a.MaxStandbyArchiveDelay + body.MaxStandbyStreamingDelay = a.MaxStandbyStreamingDelay + body.MaxWalSize = a.MaxWalSize + body.SessionReplicationRole = (*v1API.UpdatePostgresConfigBodySessionReplicationRole)(a.SessionReplicationRole) + body.StatementTimeout = a.StatementTimeout + body.WalKeepSize = a.WalKeepSize + body.WalSenderTimeout = a.WalSenderTimeout + body.WorkMem = a.WorkMem + return body +} + +func (a *settings) fromRemoteConfig(remoteConfig v1API.PostgresConfigResponse) settings { + result := *a + + result.EffectiveCacheSize = remoteConfig.EffectiveCacheSize + result.LogicalDecodingWorkMem = remoteConfig.LogicalDecodingWorkMem + result.MaintenanceWorkMem = remoteConfig.MaintenanceWorkMem + result.MaxConnections = cast.IntToUintPtr(remoteConfig.MaxConnections) + result.MaxLocksPerTransaction = cast.IntToUintPtr(remoteConfig.MaxLocksPerTransaction) + result.MaxParallelMaintenanceWorkers = cast.IntToUintPtr(remoteConfig.MaxParallelMaintenanceWorkers) + result.MaxParallelWorkers = cast.IntToUintPtr(remoteConfig.MaxParallelWorkers) + result.MaxParallelWorkersPerGather = cast.IntToUintPtr(remoteConfig.MaxParallelWorkersPerGather) + result.MaxReplicationSlots = cast.IntToUintPtr(remoteConfig.MaxReplicationSlots) + result.MaxSlotWalKeepSize = remoteConfig.MaxSlotWalKeepSize + result.MaxStandbyArchiveDelay = remoteConfig.MaxStandbyArchiveDelay + result.MaxStandbyStreamingDelay = remoteConfig.MaxStandbyStreamingDelay + result.MaxWalSenders = cast.IntToUintPtr(remoteConfig.MaxWalSenders) + result.MaxWalSize = remoteConfig.MaxWalSize + result.MaxWorkerProcesses = cast.IntToUintPtr(remoteConfig.MaxWorkerProcesses) + result.SessionReplicationRole = (*SessionReplicationRole)(remoteConfig.SessionReplicationRole) + result.SharedBuffers = remoteConfig.SharedBuffers + result.StatementTimeout = remoteConfig.StatementTimeout + result.WalKeepSize = remoteConfig.WalKeepSize + result.WalSenderTimeout = remoteConfig.WalSenderTimeout + result.WorkMem = remoteConfig.WorkMem + return result +} + +func (a *settings) DiffWithRemote(remoteConfig v1API.PostgresConfigResponse) ([]byte, error) { + // Convert the config values into easily comparable remoteConfig values + currentValue, err := ToTomlBytes(a) + if err != nil { + return nil, err + } + remoteCompare, err := ToTomlBytes(a.fromRemoteConfig(remoteConfig)) + if err != nil { + return nil, err + } + return diff.Diff("remote[db.settings]", remoteCompare, "local[db.settings]", currentValue), nil +} diff --git a/pkg/config/db_test.go b/pkg/config/db_test.go new file mode 100644 index 000000000..e7c573475 --- /dev/null +++ b/pkg/config/db_test.go @@ -0,0 +1,155 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" + v1API "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" +) + +func TestDbSettingsToUpdatePostgresConfigBody(t *testing.T) { + t.Run("converts all fields correctly", func(t *testing.T) { + db := &db{ + Settings: settings{ + EffectiveCacheSize: cast.Ptr("4GB"), + MaxConnections: cast.Ptr(uint(100)), + SharedBuffers: cast.Ptr("1GB"), + StatementTimeout: cast.Ptr("30s"), + SessionReplicationRole: cast.Ptr(SessionReplicationRoleReplica), + }, + } + + body := db.Settings.ToUpdatePostgresConfigBody() + + assert.Equal(t, "4GB", *body.EffectiveCacheSize) + assert.Equal(t, 100, *body.MaxConnections) + assert.Equal(t, "1GB", *body.SharedBuffers) + assert.Equal(t, "30s", *body.StatementTimeout) + assert.Equal(t, v1API.UpdatePostgresConfigBodySessionReplicationRoleReplica, *body.SessionReplicationRole) + }) + + t.Run("handles empty fields", func(t *testing.T) { + db := &db{} + + body := db.Settings.ToUpdatePostgresConfigBody() + + assert.Nil(t, body.EffectiveCacheSize) + assert.Nil(t, body.MaxConnections) + assert.Nil(t, body.SharedBuffers) + assert.Nil(t, body.StatementTimeout) + assert.Nil(t, body.SessionReplicationRole) + }) +} + +func TestDbSettingsDiffWithRemote(t *testing.T) { + t.Run("detects differences", func(t *testing.T) { + db := &db{ + Settings: settings{ + EffectiveCacheSize: cast.Ptr("4GB"), + MaxConnections: cast.Ptr(uint(100)), + SharedBuffers: cast.Ptr("1GB"), + }, + } + + remoteConfig := v1API.PostgresConfigResponse{ + EffectiveCacheSize: cast.Ptr("8GB"), + MaxConnections: cast.Ptr(200), + SharedBuffers: cast.Ptr("2GB"), + } + + diff, err := db.Settings.DiffWithRemote(remoteConfig) + assert.NoError(t, err) + + assert.Contains(t, string(diff), "-effective_cache_size = \"8GB\"") + assert.Contains(t, string(diff), "+effective_cache_size = \"4GB\"") + assert.Contains(t, string(diff), "-max_connections = 200") + assert.Contains(t, string(diff), "+max_connections = 100") + assert.Contains(t, string(diff), "-shared_buffers = \"2GB\"") + assert.Contains(t, string(diff), "+shared_buffers = \"1GB\"") + }) + + t.Run("handles no differences", func(t *testing.T) { + db := &db{ + Settings: settings{ + EffectiveCacheSize: cast.Ptr("4GB"), + MaxConnections: cast.Ptr(uint(100)), + SharedBuffers: cast.Ptr("1GB"), + }, + } + + remoteConfig := v1API.PostgresConfigResponse{ + EffectiveCacheSize: cast.Ptr("4GB"), + MaxConnections: cast.Ptr(100), + SharedBuffers: cast.Ptr("1GB"), + } + + diff, err := db.Settings.DiffWithRemote(remoteConfig) + assert.NoError(t, err) + + assert.Empty(t, diff) + }) + + t.Run("handles multiple schemas and search paths with spaces", func(t *testing.T) { + db := &db{ + Settings: settings{ + EffectiveCacheSize: cast.Ptr("4GB"), + MaxConnections: cast.Ptr(uint(100)), + SharedBuffers: cast.Ptr("1GB"), + }, + } + + remoteConfig := v1API.PostgresConfigResponse{ + EffectiveCacheSize: cast.Ptr("4GB"), + MaxConnections: cast.Ptr(100), + SharedBuffers: cast.Ptr("1GB"), + } + + diff, err := db.Settings.DiffWithRemote(remoteConfig) + assert.NoError(t, err) + + assert.Empty(t, diff) + }) + + t.Run("handles api disabled on remote side", func(t *testing.T) { + db := &db{ + Settings: settings{ + EffectiveCacheSize: cast.Ptr("4GB"), + MaxConnections: cast.Ptr(uint(100)), + SharedBuffers: cast.Ptr("1GB"), + }, + } + + remoteConfig := v1API.PostgresConfigResponse{ + // All fields are nil to simulate disabled API + } + + diff, err := db.Settings.DiffWithRemote(remoteConfig) + assert.NoError(t, err) + + assert.Contains(t, string(diff), "+effective_cache_size = \"4GB\"") + assert.Contains(t, string(diff), "+max_connections = 100") + assert.Contains(t, string(diff), "+shared_buffers = \"1GB\"") + }) + + t.Run("handles api disabled on local side", func(t *testing.T) { + db := &db{ + Settings: settings{ + // All fields are nil to simulate disabled API + }, + } + + remoteConfig := v1API.PostgresConfigResponse{ + EffectiveCacheSize: cast.Ptr("4GB"), + MaxConnections: cast.Ptr(100), + SharedBuffers: cast.Ptr("1GB"), + } + + diff, err := db.Settings.DiffWithRemote(remoteConfig) + assert.NoError(t, err) + + assert.Contains(t, string(diff), "-effective_cache_size = \"4GB\"") + assert.Contains(t, string(diff), "-max_connections = 100") + assert.Contains(t, string(diff), "-shared_buffers = \"1GB\"") + }) +} diff --git a/pkg/config/updater.go b/pkg/config/updater.go index 467b7bb63..ac97cc63d 100644 --- a/pkg/config/updater.go +++ b/pkg/config/updater.go @@ -21,7 +21,9 @@ func (u *ConfigUpdater) UpdateRemoteConfig(ctx context.Context, remote baseConfi if err := u.UpdateApiConfig(ctx, remote.ProjectId, remote.Api); err != nil { return err } - // TODO: implement other service configs, ie. auth + if err := u.UpdateDbConfig(ctx, remote.ProjectId, remote.Db); err != nil { + return err + } return nil } @@ -40,6 +42,7 @@ func (u *ConfigUpdater) UpdateApiConfig(ctx context.Context, projectRef string, return nil } fmt.Fprintln(os.Stderr, "Updating API service with config:", string(apiDiff)) + if resp, err := u.client.V1UpdatePostgrestServiceConfigWithResponse(ctx, projectRef, c.ToUpdatePostgrestConfigBody()); err != nil { return errors.Errorf("failed to update API config: %w", err) } else if resp.JSON200 == nil { @@ -47,3 +50,40 @@ func (u *ConfigUpdater) UpdateApiConfig(ctx context.Context, projectRef string, } return nil } + +func (u *ConfigUpdater) UpdateDbSettingsConfig(ctx context.Context, projectRef string, s settings) error { + dbConfig, err := u.client.V1GetPostgresConfigWithResponse(ctx, projectRef) + if err != nil { + return errors.Errorf("failed to read DB config: %w", err) + } else if dbConfig.JSON200 == nil { + return errors.Errorf("unexpected status %d: %s", dbConfig.StatusCode(), string(dbConfig.Body)) + } + dbDiff, err := s.DiffWithRemote(*dbConfig.JSON200) + if err != nil { + return err + } else if len(dbDiff) == 0 { + fmt.Fprintln(os.Stderr, "Remote DB config is up to date.") + return nil + } + fmt.Fprintln(os.Stderr, "Updating DB service with config:", string(dbDiff)) + remoteConfig := s.fromRemoteConfig(*dbConfig.JSON200) + restartRequired := s.requireDbRestart(remoteConfig) + if restartRequired { + fmt.Fprintln(os.Stderr, "Database will be restarted to apply config updates...") + } + updateBody := s.ToUpdatePostgresConfigBody() + updateBody.RestartDatabase = &restartRequired + if resp, err := u.client.V1UpdatePostgresConfigWithResponse(ctx, projectRef, updateBody); err != nil { + return errors.Errorf("failed to update DB config: %w", err) + } else if resp.JSON200 == nil { + return errors.Errorf("unexpected status %d: %s", resp.StatusCode(), string(resp.Body)) + } + return nil +} + +func (u *ConfigUpdater) UpdateDbConfig(ctx context.Context, projectRef string, c db) error { + if err := u.UpdateDbSettingsConfig(ctx, projectRef, c.Settings); err != nil { + return err + } + return nil +} From aaa0d7f619f48de373ef7dea28d396d7cd128660 Mon Sep 17 00:00:00 2001 From: Andrew Valleteau Date: Wed, 23 Oct 2024 11:22:18 +0200 Subject: [PATCH 2/5] chore: refactor use cast ptr (#2793) * go mod tidy * chore: refactor use cast.Ptr --- cmd/functions.go | 5 +++-- go.mod | 2 +- internal/branches/create/create_test.go | 7 ++++--- internal/functions/deploy/deploy.go | 3 ++- internal/functions/deploy/deploy_test.go | 3 ++- internal/functions/serve/serve_test.go | 3 ++- internal/init/init_test.go | 9 +++++---- internal/storage/cp/cp_test.go | 9 +++++---- internal/storage/ls/ls_test.go | 9 +++++---- internal/storage/mv/mv_test.go | 9 +++++---- internal/storage/rm/rm_test.go | 9 +++++---- internal/utils/api.go | 3 ++- internal/utils/console.go | 5 +++-- internal/utils/misc.go | 4 ---- internal/utils/release_test.go | 3 ++- 15 files changed, 46 insertions(+), 37 deletions(-) diff --git a/cmd/functions.go b/cmd/functions.go index 6555bae9b..36dc47f94 100644 --- a/cmd/functions.go +++ b/cmd/functions.go @@ -13,6 +13,7 @@ import ( "github.com/supabase/cli/internal/functions/serve" "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/flags" + "github.com/supabase/cli/pkg/cast" ) var ( @@ -106,9 +107,9 @@ var ( } if len(inspectMode.Value) > 0 { - runtimeOption.InspectMode = utils.Ptr(serve.InspectMode(inspectMode.Value)) + runtimeOption.InspectMode = cast.Ptr(serve.InspectMode(inspectMode.Value)) } else if inspectBrk { - runtimeOption.InspectMode = utils.Ptr(serve.InspectModeBrk) + runtimeOption.InspectMode = cast.Ptr(serve.InspectModeBrk) } if runtimeOption.InspectMode == nil && runtimeOption.InspectMain { return fmt.Errorf("--inspect-main must be used together with one of these flags: [inspect inspect-mode]") diff --git a/go.mod b/go.mod index cc84666d5..6f813394d 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/go-xmlfmt/xmlfmt v1.1.2 github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golangci/golangci-lint v1.61.0 + github.com/google/go-cmp v0.6.0 github.com/google/go-github/v62 v62.0.0 github.com/google/go-querystring v1.1.0 github.com/google/uuid v1.6.0 @@ -169,7 +170,6 @@ require ( github.com/golangci/plugin-module-register v0.1.1 // indirect github.com/golangci/revgrep v0.5.3 // indirect github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed // indirect - github.com/google/go-cmp v0.6.0 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/gordonklaus/ineffassign v0.1.0 // indirect github.com/gorilla/css v1.0.0 // indirect diff --git a/internal/branches/create/create_test.go b/internal/branches/create/create_test.go index 6f56b586f..e07e10387 100644 --- a/internal/branches/create/create_test.go +++ b/internal/branches/create/create_test.go @@ -14,6 +14,7 @@ import ( "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/flags" "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" ) func TestCreateCommand(t *testing.T) { @@ -36,7 +37,7 @@ func TestCreateCommand(t *testing.T) { }) // Run test err := Run(context.Background(), api.CreateBranchBody{ - Region: utils.Ptr("sin"), + Region: cast.Ptr("sin"), }, fsys) // Check error assert.NoError(t, err) @@ -53,7 +54,7 @@ func TestCreateCommand(t *testing.T) { ReplyError(net.ErrClosed) // Run test err := Run(context.Background(), api.CreateBranchBody{ - Region: utils.Ptr("sin"), + Region: cast.Ptr("sin"), }, fsys) // Check error assert.ErrorIs(t, err, net.ErrClosed) @@ -70,7 +71,7 @@ func TestCreateCommand(t *testing.T) { Reply(http.StatusServiceUnavailable) // Run test err := Run(context.Background(), api.CreateBranchBody{ - Region: utils.Ptr("sin"), + Region: cast.Ptr("sin"), }, fsys) // Check error assert.ErrorContains(t, err, "Unexpected error creating preview branch:") diff --git a/internal/functions/deploy/deploy.go b/internal/functions/deploy/deploy.go index 25ad3372a..b18565d20 100644 --- a/internal/functions/deploy/deploy.go +++ b/internal/functions/deploy/deploy.go @@ -10,6 +10,7 @@ import ( "github.com/go-errors/errors" "github.com/spf13/afero" "github.com/supabase/cli/internal/utils" + "github.com/supabase/cli/pkg/cast" "github.com/supabase/cli/pkg/config" "github.com/supabase/cli/pkg/function" ) @@ -86,7 +87,7 @@ func GetFunctionConfig(slugs []string, importMapPath string, noVerifyJWT *bool, function.ImportMap = utils.FallbackImportMapPath } if noVerifyJWT != nil { - function.VerifyJWT = utils.Ptr(!*noVerifyJWT) + function.VerifyJWT = cast.Ptr(!*noVerifyJWT) } functionConfig[name] = function } diff --git a/internal/functions/deploy/deploy_test.go b/internal/functions/deploy/deploy_test.go index adf49e63f..c92f210cc 100644 --- a/internal/functions/deploy/deploy_test.go +++ b/internal/functions/deploy/deploy_test.go @@ -15,6 +15,7 @@ import ( "github.com/supabase/cli/internal/testing/apitest" "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" "github.com/supabase/cli/pkg/config" ) @@ -323,7 +324,7 @@ func TestImportMapPath(t *testing.T) { fsys := afero.NewMemMapFs() require.NoError(t, afero.WriteFile(fsys, utils.FallbackImportMapPath, []byte("{}"), 0644)) // Run test - fc, err := GetFunctionConfig([]string{slug}, utils.FallbackImportMapPath, utils.Ptr(false), fsys) + fc, err := GetFunctionConfig([]string{slug}, utils.FallbackImportMapPath, cast.Ptr(false), fsys) // Check error assert.NoError(t, err) assert.Equal(t, utils.FallbackImportMapPath, fc[slug].ImportMap) diff --git a/internal/functions/serve/serve_test.go b/internal/functions/serve/serve_test.go index 5b98ece87..570c4b927 100644 --- a/internal/functions/serve/serve_test.go +++ b/internal/functions/serve/serve_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "github.com/supabase/cli/internal/testing/apitest" "github.com/supabase/cli/internal/utils" + "github.com/supabase/cli/pkg/cast" ) func TestServeCommand(t *testing.T) { @@ -100,7 +101,7 @@ func TestServeCommand(t *testing.T) { Reply(http.StatusOK). JSON(types.ContainerJSON{}) // Run test - err := Run(context.Background(), ".env", utils.Ptr(true), "import_map.json", RuntimeOption{}, fsys) + err := Run(context.Background(), ".env", cast.Ptr(true), "import_map.json", RuntimeOption{}, fsys) // Check error assert.ErrorIs(t, err, os.ErrNotExist) }) diff --git a/internal/init/init_test.go b/internal/init/init_test.go index 47e35b89e..99a96dce5 100644 --- a/internal/init/init_test.go +++ b/internal/init/init_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/supabase/cli/internal/testing/fstest" "github.com/supabase/cli/internal/utils" + "github.com/supabase/cli/pkg/cast" ) func TestInitCommand(t *testing.T) { @@ -70,7 +71,7 @@ func TestInitCommand(t *testing.T) { // Setup in-memory fs fsys := &afero.MemMapFs{} // Run test - assert.NoError(t, Run(context.Background(), fsys, utils.Ptr(true), nil, utils.InitParams{})) + assert.NoError(t, Run(context.Background(), fsys, cast.Ptr(true), nil, utils.InitParams{})) // Validate generated vscode settings exists, err := afero.Exists(fsys, settingsPath) assert.NoError(t, err) @@ -84,7 +85,7 @@ func TestInitCommand(t *testing.T) { // Setup in-memory fs fsys := &afero.MemMapFs{} // Run test - assert.NoError(t, Run(context.Background(), fsys, utils.Ptr(false), nil, utils.InitParams{})) + assert.NoError(t, Run(context.Background(), fsys, cast.Ptr(false), nil, utils.InitParams{})) // Validate vscode settings file isn't generated exists, err := afero.Exists(fsys, settingsPath) assert.NoError(t, err) @@ -98,7 +99,7 @@ func TestInitCommand(t *testing.T) { // Setup in-memory fs fsys := &afero.MemMapFs{} // Run test - assert.NoError(t, Run(context.Background(), fsys, nil, utils.Ptr(true), utils.InitParams{})) + assert.NoError(t, Run(context.Background(), fsys, nil, cast.Ptr(true), utils.InitParams{})) // Validate generated intellij deno config exists, err := afero.Exists(fsys, denoPath) assert.NoError(t, err) @@ -109,7 +110,7 @@ func TestInitCommand(t *testing.T) { // Setup in-memory fs fsys := &afero.MemMapFs{} // Run test - assert.NoError(t, Run(context.Background(), fsys, nil, utils.Ptr(false), utils.InitParams{})) + assert.NoError(t, Run(context.Background(), fsys, nil, cast.Ptr(false), utils.InitParams{})) // Validate intellij deno config file isn't generated exists, err := afero.Exists(fsys, denoPath) assert.NoError(t, err) diff --git a/internal/storage/cp/cp_test.go b/internal/storage/cp/cp_test.go index fe988c21c..75a0cf3cd 100644 --- a/internal/storage/cp/cp_test.go +++ b/internal/storage/cp/cp_test.go @@ -14,16 +14,17 @@ import ( "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/flags" "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" "github.com/supabase/cli/pkg/fetcher" "github.com/supabase/cli/pkg/storage" ) var mockFile = storage.ObjectResponse{ Name: "abstract.pdf", - Id: utils.Ptr("9b7f9f48-17a6-4ca8-b14a-39b0205a63e9"), - UpdatedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), - CreatedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), - LastAccessedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), + Id: cast.Ptr("9b7f9f48-17a6-4ca8-b14a-39b0205a63e9"), + UpdatedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), + CreatedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), + LastAccessedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), Metadata: &storage.ObjectMetadata{ ETag: `"887ea9be3c68e6f2fca7fd2d7c77d8fe"`, Size: 82702, diff --git a/internal/storage/ls/ls_test.go b/internal/storage/ls/ls_test.go index f1682759c..e0e2cd207 100644 --- a/internal/storage/ls/ls_test.go +++ b/internal/storage/ls/ls_test.go @@ -14,16 +14,17 @@ import ( "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/flags" "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" "github.com/supabase/cli/pkg/fetcher" "github.com/supabase/cli/pkg/storage" ) var mockFile = storage.ObjectResponse{ Name: "abstract.pdf", - Id: utils.Ptr("9b7f9f48-17a6-4ca8-b14a-39b0205a63e9"), - UpdatedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), - CreatedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), - LastAccessedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), + Id: cast.Ptr("9b7f9f48-17a6-4ca8-b14a-39b0205a63e9"), + UpdatedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), + CreatedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), + LastAccessedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), Metadata: &storage.ObjectMetadata{ ETag: `"887ea9be3c68e6f2fca7fd2d7c77d8fe"`, Size: 82702, diff --git a/internal/storage/mv/mv_test.go b/internal/storage/mv/mv_test.go index 2cb1c57e5..fd8ecfbcc 100644 --- a/internal/storage/mv/mv_test.go +++ b/internal/storage/mv/mv_test.go @@ -12,16 +12,17 @@ import ( "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/flags" "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" "github.com/supabase/cli/pkg/fetcher" "github.com/supabase/cli/pkg/storage" ) var mockFile = storage.ObjectResponse{ Name: "abstract.pdf", - Id: utils.Ptr("9b7f9f48-17a6-4ca8-b14a-39b0205a63e9"), - UpdatedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), - CreatedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), - LastAccessedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), + Id: cast.Ptr("9b7f9f48-17a6-4ca8-b14a-39b0205a63e9"), + UpdatedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), + CreatedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), + LastAccessedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), Metadata: &storage.ObjectMetadata{ ETag: `"887ea9be3c68e6f2fca7fd2d7c77d8fe"`, Size: 82702, diff --git a/internal/storage/rm/rm_test.go b/internal/storage/rm/rm_test.go index 6032c5b9b..46d204cf0 100644 --- a/internal/storage/rm/rm_test.go +++ b/internal/storage/rm/rm_test.go @@ -13,16 +13,17 @@ import ( "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/flags" "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" "github.com/supabase/cli/pkg/fetcher" "github.com/supabase/cli/pkg/storage" ) var mockFile = storage.ObjectResponse{ Name: "abstract.pdf", - Id: utils.Ptr("9b7f9f48-17a6-4ca8-b14a-39b0205a63e9"), - UpdatedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), - CreatedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), - LastAccessedAt: utils.Ptr("2023-10-13T18:08:22.068Z"), + Id: cast.Ptr("9b7f9f48-17a6-4ca8-b14a-39b0205a63e9"), + UpdatedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), + CreatedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), + LastAccessedAt: cast.Ptr("2023-10-13T18:08:22.068Z"), Metadata: &storage.ObjectMetadata{ ETag: `"887ea9be3c68e6f2fca7fd2d7c77d8fe"`, Size: 82702, diff --git a/internal/utils/api.go b/internal/utils/api.go index 3dc63d30d..7dc59a088 100644 --- a/internal/utils/api.go +++ b/internal/utils/api.go @@ -16,6 +16,7 @@ import ( "github.com/spf13/viper" "github.com/supabase/cli/internal/utils/cloudflare" supabase "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" ) const ( @@ -60,7 +61,7 @@ func FallbackLookupIP(ctx context.Context, host string) ([]string, error) { func ResolveCNAME(ctx context.Context, host string) (string, error) { // Ref: https://developers.cloudflare.com/1.1.1.1/encryption/dns-over-https/make-api-requests/dns-json cf := cloudflare.NewCloudflareAPI() - data, err := cf.DNSQuery(ctx, cloudflare.DNSParams{Name: host, Type: Ptr(cloudflare.TypeCNAME)}) + data, err := cf.DNSQuery(ctx, cloudflare.DNSParams{Name: host, Type: cast.Ptr(cloudflare.TypeCNAME)}) if err != nil { return "", err } diff --git a/internal/utils/console.go b/internal/utils/console.go index dfd014afc..85bebdc1e 100644 --- a/internal/utils/console.go +++ b/internal/utils/console.go @@ -10,6 +10,7 @@ import ( "time" "github.com/go-errors/errors" + "github.com/supabase/cli/pkg/cast" "golang.org/x/term" ) @@ -78,10 +79,10 @@ func (c *Console) PromptYesNo(ctx context.Context, label string, def bool) (bool func parseYesNo(s string) *bool { s = strings.ToLower(s) if s == "y" || s == "yes" { - return Ptr(true) + return cast.Ptr(true) } if s == "n" || s == "no" { - return Ptr(false) + return cast.Ptr(false) } return nil } diff --git a/internal/utils/misc.go b/internal/utils/misc.go index adb0efa9c..0993ae806 100644 --- a/internal/utils/misc.go +++ b/internal/utils/misc.go @@ -295,10 +295,6 @@ func ValidateFunctionSlug(slug string) error { return nil } -func Ptr[T any](v T) *T { - return &v -} - func GetHostname() string { host := Docker.DaemonHost() if parsed, err := client.ParseHostURL(host); err == nil && parsed.Scheme == "tcp" { diff --git a/internal/utils/release_test.go b/internal/utils/release_test.go index 00b44f49d..25aa920e8 100644 --- a/internal/utils/release_test.go +++ b/internal/utils/release_test.go @@ -10,6 +10,7 @@ import ( "github.com/h2non/gock" "github.com/stretchr/testify/assert" "github.com/supabase/cli/internal/testing/apitest" + "github.com/supabase/cli/pkg/cast" ) func TestLatestRelease(t *testing.T) { @@ -19,7 +20,7 @@ func TestLatestRelease(t *testing.T) { gock.New("https://api.github.com"). Get("/repos/supabase/cli/releases/latest"). Reply(http.StatusOK). - JSON(github.RepositoryRelease{TagName: Ptr("v2")}) + JSON(github.RepositoryRelease{TagName: cast.Ptr("v2")}) // Run test version, err := GetLatestRelease(context.Background()) // Check error From 56c2cc464eb08249c347e30d1d089cf16416cbbc Mon Sep 17 00:00:00 2001 From: Andrew Valleteau Date: Thu, 24 Oct 2024 11:12:13 +0200 Subject: [PATCH 3/5] fix(config): default value for seed remote (#2797) --- pkg/config/config.go | 2 ++ pkg/config/config_test.go | 3 +++ pkg/config/testdata/config.toml | 3 +++ 3 files changed, 8 insertions(+) diff --git a/pkg/config/config.go b/pkg/config/config.go index 6e548111a..7f0acf72c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -697,6 +697,8 @@ func (c *config) Load(path string, fsys fs.FS) error { c.Remotes = make(map[string]baseConfig, len(c.Overrides)) for name, remote := range c.Overrides { base := c.baseConfig.Clone() + // On remotes branches set seed as disabled by default + base.Db.Seed.Enabled = false // Encode a toml file with only config overrides var buf bytes.Buffer if err := toml.NewEncoder(&buf).Encode(remote); err != nil { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index c10669350..8e8f6a2e1 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -93,9 +93,12 @@ func TestConfigParsing(t *testing.T) { assert.Equal(t, false, production.Auth.EnableSignup) assert.Equal(t, false, production.Auth.External["azure"].Enabled) assert.Equal(t, "nope", production.Auth.External["azure"].ClientId) + // Check seed should be disabled by default for remote configs + assert.Equal(t, false, production.Db.Seed.Enabled) // Check the values for the staging override assert.Equal(t, "staging-project", staging.ProjectId) assert.Equal(t, []string{"image/png"}, staging.Storage.Buckets["images"].AllowedMimeTypes) + assert.Equal(t, true, staging.Db.Seed.Enabled) }) } diff --git a/pkg/config/testdata/config.toml b/pkg/config/testdata/config.toml index 0d591d979..2f1a8e90d 100644 --- a/pkg/config/testdata/config.toml +++ b/pkg/config/testdata/config.toml @@ -237,5 +237,8 @@ client_id = "nope" [remotes.staging] project_id = "staging-project" +[remotes.staging.db.seed] +enabled = true + [remotes.staging.storage.buckets.images] allowed_mime_types = ["image/png"] From 78bc7f795ab11fdcfe901901bb8db434fb39fee2 Mon Sep 17 00:00:00 2001 From: Andrew Valleteau Date: Thu, 24 Oct 2024 15:02:11 +0200 Subject: [PATCH 4/5] feat(config): allow local postgres configuration (#2796) * feat(cli): allow local postgres.conf configuration Closes: #2611 * chore: test ToPostgresConfig * chore: add start runtime test * chore: use less max_connections * fix: db start tests * chore: add prefix to postgres config * fix: test * chore: serialise postgres conf as toml --------- Co-authored-by: Qiao Han --- internal/db/start/start.go | 18 +++++++---- internal/db/start/start_test.go | 53 +++++++++++++++++++++++++++++++++ pkg/config/db.go | 13 ++++++++ pkg/config/db_test.go | 38 +++++++++++++++++++++++ 4 files changed, 116 insertions(+), 6 deletions(-) diff --git a/internal/db/start/start.go b/internal/db/start/start.go index 826d46c86..a300f5594 100644 --- a/internal/db/start/start.go +++ b/internal/db/start/start.go @@ -56,7 +56,6 @@ func NewContainerConfig() container.Config { env := []string{ "POSTGRES_PASSWORD=" + utils.Config.Db.Password, "POSTGRES_HOST=/var/run/postgresql", - "POSTGRES_INITDB_ARGS=--lc-ctype=C.UTF-8", "JWT_SECRET=" + utils.Config.Auth.JwtSecret, fmt.Sprintf("JWT_EXP=%d", utils.Config.Auth.JwtExpiry), } @@ -81,13 +80,18 @@ func NewContainerConfig() container.Config { Timeout: 2 * time.Second, Retries: 3, }, - Entrypoint: []string{"sh", "-c", `cat <<'EOF' > /etc/postgresql.schema.sql && cat <<'EOF' > /etc/postgresql-custom/pgsodium_root.key && docker-entrypoint.sh postgres -D /etc/postgresql + Entrypoint: []string{"sh", "-c", ` +cat <<'EOF' > /etc/postgresql.schema.sql && \ +cat <<'EOF' > /etc/postgresql-custom/pgsodium_root.key && \ +cat <<'EOF' >> /etc/postgresql/postgresql.conf && \ +docker-entrypoint.sh postgres -D /etc/postgresql ` + initialSchema + ` ` + _supabaseSchema + ` EOF ` + utils.Config.Db.RootKey + ` EOF -`}, +` + utils.Config.Db.Settings.ToPostgresConfig() + ` +EOF`}, } if utils.Config.Db.MajorVersion >= 14 { config.Cmd = []string{"postgres", @@ -124,11 +128,13 @@ func StartDatabase(ctx context.Context, fsys afero.Fs, w io.Writer, options ...f } if utils.Config.Db.MajorVersion <= 14 { config.Entrypoint = []string{"sh", "-c", ` - cat <<'EOF' > /docker-entrypoint-initdb.d/supabase_schema.sql +cat <<'EOF' > /docker-entrypoint-initdb.d/supabase_schema.sql && \ +cat <<'EOF' >> /etc/postgresql/postgresql.conf && \ +docker-entrypoint.sh postgres -D /etc/postgresql ` + _supabaseSchema + ` EOF - docker-entrypoint.sh postgres -D /etc/postgresql - `} +` + utils.Config.Db.Settings.ToPostgresConfig() + ` +EOF`} hostConfig.Tmpfs = map[string]string{"/docker-entrypoint-initdb.d": ""} } // Creating volume will not override existing volume, so we must inspect explicitly diff --git a/internal/db/start/start_test.go b/internal/db/start/start_test.go index 96d97da8c..475562f2f 100644 --- a/internal/db/start/start_test.go +++ b/internal/db/start/start_test.go @@ -17,6 +17,7 @@ import ( "github.com/supabase/cli/internal/testing/apitest" "github.com/supabase/cli/internal/testing/fstest" "github.com/supabase/cli/internal/utils" + "github.com/supabase/cli/pkg/cast" "github.com/supabase/cli/pkg/pgtest" ) @@ -308,3 +309,55 @@ func TestSetupDatabase(t *testing.T) { assert.Empty(t, apitest.ListUnmatchedRequests()) }) } +func TestStartDatabaseWithCustomSettings(t *testing.T) { + t.Run("starts database with custom MaxConnections", func(t *testing.T) { + // Setup + utils.Config.Db.MajorVersion = 15 + utils.DbId = "supabase_db_test" + utils.ConfigId = "supabase_config_test" + utils.Config.Db.Port = 5432 + utils.Config.Db.Settings.MaxConnections = cast.Ptr(uint(50)) + + // Setup in-memory fs + fsys := afero.NewMemMapFs() + + // Setup mock docker + require.NoError(t, apitest.MockDocker(utils.Docker)) + defer gock.OffAll() + gock.New(utils.Docker.DaemonHost()). + Get("/v" + utils.Docker.ClientVersion() + "/volumes/" + utils.DbId). + Reply(http.StatusNotFound). + JSON(volume.Volume{}) + apitest.MockDockerStart(utils.Docker, utils.GetRegistryImageUrl(utils.Config.Db.Image), utils.DbId) + gock.New(utils.Docker.DaemonHost()). + Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId + "/json"). + Reply(http.StatusOK). + JSON(types.ContainerJSON{ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + Health: &types.Health{Status: types.Healthy}, + }, + }}) + + apitest.MockDockerStart(utils.Docker, utils.GetRegistryImageUrl(utils.Config.Realtime.Image), "test-realtime") + require.NoError(t, apitest.MockDockerLogs(utils.Docker, "test-realtime", "")) + apitest.MockDockerStart(utils.Docker, utils.GetRegistryImageUrl(utils.Config.Storage.Image), "test-storage") + require.NoError(t, apitest.MockDockerLogs(utils.Docker, "test-storage", "")) + apitest.MockDockerStart(utils.Docker, utils.GetRegistryImageUrl(utils.Config.Auth.Image), "test-auth") + require.NoError(t, apitest.MockDockerLogs(utils.Docker, "test-auth", "")) + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + + // Run test + err := StartDatabase(context.Background(), fsys, io.Discard, conn.Intercept) + + // Check error + assert.NoError(t, err) + assert.Empty(t, apitest.ListUnmatchedRequests()) + + // Check if the custom MaxConnections setting was applied + config := NewContainerConfig() + assert.Contains(t, config.Entrypoint[2], "max_connections = 50") + }) +} diff --git a/pkg/config/db.go b/pkg/config/db.go index 89e5bfd24..e7c5f820b 100644 --- a/pkg/config/db.go +++ b/pkg/config/db.go @@ -1,6 +1,8 @@ package config import ( + "bytes" + "github.com/google/go-cmp/cmp" v1API "github.com/supabase/cli/pkg/api" "github.com/supabase/cli/pkg/cast" @@ -146,6 +148,17 @@ func (a *settings) fromRemoteConfig(remoteConfig v1API.PostgresConfigResponse) s return result } +const pgConfHeader = "\n# supabase [db.settings] configuration\n" + +// create a valid string to append to /etc/postgresql/postgresql.conf +func (a *settings) ToPostgresConfig() string { + // Assuming postgres settings is always a flat struct, we can serialise + // using toml, then replace double quotes with single. + data, _ := ToTomlBytes(*a) + body := bytes.ReplaceAll(data, []byte{'"'}, []byte{'\''}) + return pgConfHeader + string(body) +} + func (a *settings) DiffWithRemote(remoteConfig v1API.PostgresConfigResponse) ([]byte, error) { // Convert the config values into easily comparable remoteConfig values currentValue, err := ToTomlBytes(a) diff --git a/pkg/config/db_test.go b/pkg/config/db_test.go index e7c573475..8d70ec21b 100644 --- a/pkg/config/db_test.go +++ b/pkg/config/db_test.go @@ -153,3 +153,41 @@ func TestDbSettingsDiffWithRemote(t *testing.T) { assert.Contains(t, string(diff), "-shared_buffers = \"1GB\"") }) } + +func TestSettingsToPostgresConfig(t *testing.T) { + t.Run("Only set values should appear", func(t *testing.T) { + settings := settings{ + MaxConnections: cast.Ptr(uint(100)), + MaxLocksPerTransaction: cast.Ptr(uint(64)), + SharedBuffers: cast.Ptr("128MB"), + WorkMem: cast.Ptr("4MB"), + } + got := settings.ToPostgresConfig() + + assert.Contains(t, got, "max_connections = 100") + assert.Contains(t, got, "max_locks_per_transaction = 64") + assert.Contains(t, got, "shared_buffers = '128MB'") + assert.Contains(t, got, "work_mem = '4MB'") + + assert.NotContains(t, got, "effective_cache_size") + assert.NotContains(t, got, "maintenance_work_mem") + assert.NotContains(t, got, "max_parallel_workers") + }) + + t.Run("SessionReplicationRole should be handled correctly", func(t *testing.T) { + settings := settings{ + SessionReplicationRole: cast.Ptr(SessionReplicationRoleOrigin), + } + got := settings.ToPostgresConfig() + + assert.Contains(t, got, "session_replication_role = 'origin'") + }) + + t.Run("Empty settings should result in empty string", func(t *testing.T) { + settings := settings{} + got := settings.ToPostgresConfig() + + assert.Equal(t, got, "\n# supabase [db.settings] configuration\n") + assert.NotContains(t, got, "=") + }) +} From 8611ace19bc1fd62f5e4e96154402a208194940f Mon Sep 17 00:00:00 2001 From: Andrew Valleteau Date: Fri, 25 Oct 2024 07:15:10 +0200 Subject: [PATCH 5/5] feat(config): experimental config webhooks (#2794) --- pkg/config/config.go | 27 +++++-- pkg/config/updater.go | 16 ++++ pkg/config/updater_test.go | 153 +++++++++++++++++++++++++++++++++++++ 3 files changed, 191 insertions(+), 5 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 7f0acf72c..7fc8942b5 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -388,12 +388,17 @@ type ( VectorPort uint16 `toml:"vector_port"` } + webhooks struct { + Enabled bool `toml:"enabled"` + } + experimental struct { - OrioleDBVersion string `toml:"orioledb_version"` - S3Host string `toml:"s3_host"` - S3Region string `toml:"s3_region"` - S3AccessKey string `toml:"s3_access_key"` - S3SecretKey string `toml:"s3_secret_key"` + OrioleDBVersion string `toml:"orioledb_version"` + S3Host string `toml:"s3_host"` + S3Region string `toml:"s3_region"` + S3AccessKey string `toml:"s3_access_key"` + S3SecretKey string `toml:"s3_secret_key"` + Webhooks *webhooks `toml:"webhooks"` } ) @@ -986,6 +991,9 @@ func (c *baseConfig) Validate(fsys fs.FS) error { return errors.Errorf("Invalid config for analytics.backend. Must be one of: %v", allowed) } } + if err := c.Experimental.validateWebhooks(); err != nil { + return err + } return nil } @@ -1351,3 +1359,12 @@ func ToTomlBytes(config any) ([]byte, error) { } return buf.Bytes(), nil } + +func (e *experimental) validateWebhooks() error { + if e.Webhooks != nil { + if !e.Webhooks.Enabled { + return errors.Errorf("Webhooks cannot be deactivated. [experimental.webhooks] enabled can either be true or left undefined") + } + } + return nil +} diff --git a/pkg/config/updater.go b/pkg/config/updater.go index ac97cc63d..c7739eded 100644 --- a/pkg/config/updater.go +++ b/pkg/config/updater.go @@ -24,6 +24,9 @@ func (u *ConfigUpdater) UpdateRemoteConfig(ctx context.Context, remote baseConfi if err := u.UpdateDbConfig(ctx, remote.ProjectId, remote.Db); err != nil { return err } + if err := u.UpdateExperimentalConfig(ctx, remote.ProjectId, remote.Experimental); err != nil { + return err + } return nil } @@ -87,3 +90,16 @@ func (u *ConfigUpdater) UpdateDbConfig(ctx context.Context, projectRef string, c } return nil } + +func (u *ConfigUpdater) UpdateExperimentalConfig(ctx context.Context, projectRef string, exp experimental) error { + if exp.Webhooks != nil && exp.Webhooks.Enabled { + fmt.Fprintln(os.Stderr, "Enabling webhooks for the project...") + + if resp, err := u.client.V1EnableDatabaseWebhookWithResponse(ctx, projectRef); err != nil { + return errors.Errorf("failed to enable webhooks: %w", err) + } else if resp.StatusCode() < 200 || resp.StatusCode() >= 300 { + return errors.Errorf("unexpected enable webhook status %d: %s", resp.StatusCode(), string(resp.Body)) + } + } + return nil +} diff --git a/pkg/config/updater_test.go b/pkg/config/updater_test.go index 241d612e9..2ad07c998 100644 --- a/pkg/config/updater_test.go +++ b/pkg/config/updater_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" v1API "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" ) func TestUpdateApi(t *testing.T) { @@ -63,3 +64,155 @@ func TestUpdateApi(t *testing.T) { assert.True(t, gock.IsDone()) }) } + +func TestUpdateDbConfig(t *testing.T) { + server := "http://localhost" + client, err := v1API.NewClientWithResponses(server) + require.NoError(t, err) + + t.Run("updates remote DB config", func(t *testing.T) { + updater := NewConfigUpdater(*client) + // Setup mock server + defer gock.Off() + gock.New(server). + Get("/v1/projects/test-project/config/database"). + Reply(http.StatusOK). + JSON(v1API.PostgresConfigResponse{}) + gock.New(server). + Put("/v1/projects/test-project/config/database"). + Reply(http.StatusOK). + JSON(v1API.PostgresConfigResponse{ + MaxConnections: cast.Ptr(cast.UintToInt(100)), + }) + // Run test + err := updater.UpdateDbConfig(context.Background(), "test-project", db{ + Settings: settings{ + MaxConnections: cast.Ptr(cast.IntToUint(100)), + }, + }) + // Check result + assert.NoError(t, err) + assert.True(t, gock.IsDone()) + }) + + t.Run("skips update if no diff in DB config", func(t *testing.T) { + updater := NewConfigUpdater(*client) + // Setup mock server + defer gock.Off() + gock.New(server). + Get("/v1/projects/test-project/config/database"). + Reply(http.StatusOK). + JSON(v1API.PostgresConfigResponse{ + MaxConnections: cast.Ptr(cast.UintToInt(100)), + }) + // Run test + err := updater.UpdateDbConfig(context.Background(), "test-project", db{ + Settings: settings{ + MaxConnections: cast.Ptr(cast.IntToUint(100)), + }, + }) + // Check result + assert.NoError(t, err) + assert.True(t, gock.IsDone()) + }) +} + +func TestUpdateExperimentalConfig(t *testing.T) { + server := "http://localhost" + client, err := v1API.NewClientWithResponses(server) + require.NoError(t, err) + + t.Run("enables webhooks", func(t *testing.T) { + updater := NewConfigUpdater(*client) + // Setup mock server + defer gock.Off() + gock.New(server). + Post("/v1/projects/test-project/database/webhooks/enable"). + Reply(http.StatusOK). + JSON(map[string]interface{}{}) + // Run test + err := updater.UpdateExperimentalConfig(context.Background(), "test-project", experimental{ + Webhooks: &webhooks{ + Enabled: true, + }, + }) + // Check result + assert.NoError(t, err) + assert.True(t, gock.IsDone()) + }) + + t.Run("skips update if webhooks not enabled", func(t *testing.T) { + updater := NewConfigUpdater(*client) + // Run test + err := updater.UpdateExperimentalConfig(context.Background(), "test-project", experimental{ + Webhooks: &webhooks{ + Enabled: false, + }, + }) + // Check result + assert.NoError(t, err) + assert.True(t, gock.IsDone()) + }) +} + +func TestUpdateRemoteConfig(t *testing.T) { + server := "http://localhost" + client, err := v1API.NewClientWithResponses(server) + require.NoError(t, err) + + t.Run("updates all configs", func(t *testing.T) { + updater := NewConfigUpdater(*client) + // Setup mock server + defer gock.Off() + // API config + gock.New(server). + Get("/v1/projects/test-project/postgrest"). + Reply(http.StatusOK). + JSON(v1API.PostgrestConfigWithJWTSecretResponse{}) + gock.New(server). + Patch("/v1/projects/test-project/postgrest"). + Reply(http.StatusOK). + JSON(v1API.PostgrestConfigWithJWTSecretResponse{ + DbSchema: "public", + MaxRows: 1000, + }) + // DB config + gock.New(server). + Get("/v1/projects/test-project/config/database"). + Reply(http.StatusOK). + JSON(v1API.PostgresConfigResponse{}) + gock.New(server). + Put("/v1/projects/test-project/config/database"). + Reply(http.StatusOK). + JSON(v1API.PostgresConfigResponse{ + MaxConnections: cast.Ptr(cast.UintToInt(100)), + }) + // Experimental config + gock.New(server). + Post("/v1/projects/test-project/database/webhooks/enable"). + Reply(http.StatusOK). + JSON(map[string]interface{}{}) + // Run test + err := updater.UpdateRemoteConfig(context.Background(), baseConfig{ + ProjectId: "test-project", + Api: api{ + Enabled: true, + Schemas: []string{"public", "private"}, + MaxRows: 1000, + }, + Db: db{ + Settings: settings{ + MaxConnections: cast.Ptr(cast.IntToUint(100)), + }, + }, + Experimental: experimental{ + Webhooks: &webhooks{ + Enabled: true, + }, + }, + }) + // Check result + assert.NoError(t, err) + assert.True(t, gock.IsDone()) + }) +}