Skip to content

Commit

Permalink
feat: add db postgres settings to config
Browse files Browse the repository at this point in the history
  • Loading branch information
avallete committed Oct 21, 2024
1 parent 2da3861 commit c3a1871
Show file tree
Hide file tree
Showing 4 changed files with 321 additions and 32 deletions.
26 changes: 26 additions & 0 deletions pkg/cast/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,32 @@ func IntToUint(value int) uint {
return uint(value)
}

func UintToIntPtr(value *uint) *int {
if value == nil {
return nil
}
if *value <= math.MaxInt {
result := int(*value)
return &result
}
maxInt := math.MaxInt
return &maxInt
}

// IntToUint converts an int to a uint, handling negative values
func IntToUintPtr(value *int) *uint {
var result uint
result = 0
if value == nil {
return nil
}
if *value < 0 {
return &result
}
result = uint(*value)
return &result
}

func Ptr[T any](v T) *T {
return &v
}
38 changes: 6 additions & 32 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ const (
LogflareBigQuery LogflareBackend = "bigquery"
)

type PoolMode string

const (
TransactionMode PoolMode = "transaction"
SessionMode PoolMode = "session"
Expand Down Expand Up @@ -146,36 +144,6 @@ type (
Remotes map[string]baseConfig `toml:"-"`
}

db struct {
Image string `toml:"-"`
Port uint16 `toml:"port"`
ShadowPort uint16 `toml:"shadow_port"`
MajorVersion uint `toml:"major_version"`
Password string `toml:"-"`
RootKey string `toml:"-" mapstructure:"root_key"`
Pooler pooler `toml:"pooler"`
Seed seed `toml:"seed"`
}

seed struct {
Enabled bool `toml:"enabled"`
GlobPatterns []string `toml:"sql_paths"`
SqlPaths []string `toml:"-"`
}

pooler struct {
Enabled bool `toml:"enabled"`
Image string `toml:"-"`
Port uint16 `toml:"port"`
PoolMode PoolMode `toml:"pool_mode"`
DefaultPoolSize uint `toml:"default_pool_size"`
MaxClientConn uint `toml:"max_client_conn"`
ConnectionString string `toml:"-"`
TenantId string `toml:"-"`
EncryptionKey string `toml:"-"`
SecretKeyBase string `toml:"-"`
}

realtime struct {
Enabled bool `toml:"enabled"`
Image string `toml:"-"`
Expand Down Expand Up @@ -775,6 +743,12 @@ func (c *baseConfig) Validate(fsys fs.FS) error {
}
}
// Validate db config
if c.Db.remoteDb.SessionReplicationRole != nil {
allowedRoles := []string{"origin", "replica", "local"}
if !sliceContains(allowedRoles, *c.Db.remoteDb.SessionReplicationRole) {
return errors.Errorf("Invalid config for db.session_replication_role: %s. Must be one of: %v", *c.Db.remoteDb.SessionReplicationRole, allowedRoles)
}
}
if c.Db.Port == 0 {
return errors.New("Missing required field in config: db.port")
}
Expand Down
134 changes: 134 additions & 0 deletions pkg/config/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package config

import (
v1API "github.com/supabase/cli/pkg/api"
"github.com/supabase/cli/pkg/cast"
"github.com/supabase/cli/pkg/diff"
)

type (
PoolMode string

// All of thoses are remote only settings that'll apply to supabase hosted database
remoteDb struct {
EffectiveCacheSize *string `toml:"effective_cache_size"`
LogicalDecodingWorkMem *string `toml:"logical_decoding_work_mem"`
MaintenanceWorkMem *string `toml:"maintenance_work_mem"`
MaxConnections *uint `toml:"max_connections"`
MaxLocksPerTransaction *uint `toml:"max_locks_per_transaction"`
MaxParallelMaintenanceWorkers *uint `toml:"max_parallel_maintenance_workers"`
MaxParallelWorkers *uint `toml:"max_parallel_workers"`
MaxParallelWorkersPerGather *uint `toml:"max_parallel_workers_per_gather"`
MaxReplicationSlots *uint `toml:"max_replication_slots"`
MaxSlotWalKeepSize *string `toml:"max_slot_wal_keep_size"`
MaxStandbyArchiveDelay *string `toml:"max_standby_archive_delay"`
MaxStandbyStreamingDelay *string `toml:"max_standby_streaming_delay"`
MaxWalSize *string `toml:"max_wal_size"`
MaxWalSenders *uint `toml:"max_wal_senders"`
MaxWorkerProcesses *uint `toml:"max_worker_processes"`
SessionReplicationRole *string `toml:"session_replication_role"`
SharedBuffers *string `toml:"shared_buffers"`
StatementTimeout *string `toml:"statement_timeout"`
WalKeepSize *string `toml:"wal_keep_size"`
WalSenderTimeout *string `toml:"wal_sender_timeout"`
WorkMem *string `toml:"work_mem"`
}

db struct {
remoteDb
Image string `toml:"-"`
Port uint16 `toml:"port"`
ShadowPort uint16 `toml:"shadow_port"`
MajorVersion uint `toml:"major_version"`
Password string `toml:"-"`
RootKey string `toml:"-" mapstructure:"root_key"`
Pooler pooler `toml:"pooler"`
Seed seed `toml:"seed"`
}

seed struct {
Enabled bool `toml:"enabled"`
GlobPatterns []string `toml:"sql_paths"`
SqlPaths []string `toml:"-"`
}

pooler struct {
Enabled bool `toml:"enabled"`
Image string `toml:"-"`
Port uint16 `toml:"port"`
PoolMode PoolMode `toml:"pool_mode"`
DefaultPoolSize uint `toml:"default_pool_size"`
MaxClientConn uint `toml:"max_client_conn"`
ConnectionString string `toml:"-"`
TenantId string `toml:"-"`
EncryptionKey string `toml:"-"`
SecretKeyBase string `toml:"-"`
}
)

func (a *db) ToUpdatePostgresConfigBody() v1API.UpdatePostgresConfigBody {
body := v1API.UpdatePostgresConfigBody{}

body.EffectiveCacheSize = a.EffectiveCacheSize
body.LogicalDecodingWorkMem = a.LogicalDecodingWorkMem
body.MaintenanceWorkMem = a.MaintenanceWorkMem
body.MaxConnections = cast.UintToIntPtr(a.MaxConnections)
body.MaxLocksPerTransaction = cast.UintToIntPtr(a.MaxLocksPerTransaction)
body.MaxParallelMaintenanceWorkers = cast.UintToIntPtr(a.MaxParallelMaintenanceWorkers)
body.MaxParallelWorkers = cast.UintToIntPtr(a.MaxParallelWorkers)
body.MaxParallelWorkersPerGather = cast.UintToIntPtr(a.MaxParallelWorkersPerGather)
body.MaxReplicationSlots = cast.UintToIntPtr(a.MaxReplicationSlots)
body.MaxSlotWalKeepSize = a.MaxSlotWalKeepSize
body.MaxStandbyArchiveDelay = a.MaxStandbyArchiveDelay
body.MaxStandbyStreamingDelay = a.MaxStandbyStreamingDelay
body.MaxWalSenders = cast.UintToIntPtr(a.MaxWalSenders)
body.MaxWalSize = a.MaxWalSize
body.MaxWorkerProcesses = cast.UintToIntPtr(a.MaxWorkerProcesses)
body.SessionReplicationRole = (*v1API.UpdatePostgresConfigBodySessionReplicationRole)(a.SessionReplicationRole)
body.SharedBuffers = a.SharedBuffers
body.StatementTimeout = a.StatementTimeout
body.WalKeepSize = a.WalKeepSize
body.WalSenderTimeout = a.WalSenderTimeout
body.WorkMem = a.WorkMem
return body
}

func (a *db) fromRemoteApiConfig(remoteConfig v1API.PostgresConfigResponse) db {
result := *a

result.remoteDb.EffectiveCacheSize = remoteConfig.EffectiveCacheSize
result.remoteDb.LogicalDecodingWorkMem = remoteConfig.LogicalDecodingWorkMem
result.remoteDb.MaintenanceWorkMem = remoteConfig.MaintenanceWorkMem
result.remoteDb.MaxConnections = cast.IntToUintPtr(remoteConfig.MaxConnections)
result.remoteDb.MaxLocksPerTransaction = cast.IntToUintPtr(remoteConfig.MaxLocksPerTransaction)
result.remoteDb.MaxParallelMaintenanceWorkers = cast.IntToUintPtr(remoteConfig.MaxParallelMaintenanceWorkers)
result.remoteDb.MaxParallelWorkers = cast.IntToUintPtr(remoteConfig.MaxParallelWorkers)
result.remoteDb.MaxParallelWorkersPerGather = cast.IntToUintPtr(remoteConfig.MaxParallelWorkersPerGather)
result.remoteDb.MaxReplicationSlots = cast.IntToUintPtr(remoteConfig.MaxReplicationSlots)
result.remoteDb.MaxSlotWalKeepSize = remoteConfig.MaxSlotWalKeepSize
result.remoteDb.MaxStandbyArchiveDelay = remoteConfig.MaxStandbyArchiveDelay
result.remoteDb.MaxStandbyStreamingDelay = remoteConfig.MaxStandbyStreamingDelay
result.remoteDb.MaxWalSenders = cast.IntToUintPtr(remoteConfig.MaxWalSenders)
result.remoteDb.MaxWalSize = remoteConfig.MaxWalSize
result.remoteDb.MaxWorkerProcesses = cast.IntToUintPtr(remoteConfig.MaxWorkerProcesses)
result.remoteDb.SessionReplicationRole = (*string)(remoteConfig.SessionReplicationRole)
result.remoteDb.SharedBuffers = remoteConfig.SharedBuffers
result.remoteDb.StatementTimeout = remoteConfig.StatementTimeout
result.remoteDb.WalKeepSize = remoteConfig.WalKeepSize
result.remoteDb.WalSenderTimeout = remoteConfig.WalSenderTimeout
result.remoteDb.WorkMem = remoteConfig.WorkMem
return result
}

func (a *db) DiffWithRemote(remoteConfig v1API.PostgresConfigResponse) ([]byte, error) {
// Convert the config values into easily comparable remoteConfig values
currentValue, err := ToTomlBytes(a)
if err != nil {
return nil, err
}
remoteCompare, err := ToTomlBytes(a.fromRemoteApiConfig(remoteConfig))
if err != nil {
return nil, err
}
return diff.Diff("remote[db]", remoteCompare, "local[db]", currentValue), nil
}
155 changes: 155 additions & 0 deletions pkg/config/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package config

import (
"testing"

"github.com/stretchr/testify/assert"
v1API "github.com/supabase/cli/pkg/api"
"github.com/supabase/cli/pkg/cast"
)

func TestDbToUpdatePostgresConfigBody(t *testing.T) {
t.Run("converts all fields correctly", func(t *testing.T) {
db := &db{
remoteDb: remoteDb{
EffectiveCacheSize: cast.Ptr("4GB"),
MaxConnections: cast.Ptr(uint(100)),
SharedBuffers: cast.Ptr("1GB"),
StatementTimeout: cast.Ptr("30s"),
SessionReplicationRole: cast.Ptr("replica"),
},
}

body := db.ToUpdatePostgresConfigBody()

assert.Equal(t, "4GB", *body.EffectiveCacheSize)
assert.Equal(t, 100, *body.MaxConnections)
assert.Equal(t, "1GB", *body.SharedBuffers)
assert.Equal(t, "30s", *body.StatementTimeout)
assert.Equal(t, v1API.UpdatePostgresConfigBodySessionReplicationRoleReplica, *body.SessionReplicationRole)
})

t.Run("handles empty fields", func(t *testing.T) {
db := &db{}

body := db.ToUpdatePostgresConfigBody()

assert.Nil(t, body.EffectiveCacheSize)
assert.Nil(t, body.MaxConnections)
assert.Nil(t, body.SharedBuffers)
assert.Nil(t, body.StatementTimeout)
assert.Nil(t, body.SessionReplicationRole)
})
}

func TestDbDiffWithRemote(t *testing.T) {
t.Run("detects differences", func(t *testing.T) {
db := &db{
remoteDb: remoteDb{
EffectiveCacheSize: cast.Ptr("4GB"),
MaxConnections: cast.Ptr(uint(100)),
SharedBuffers: cast.Ptr("1GB"),
},
}

remoteConfig := v1API.PostgresConfigResponse{
EffectiveCacheSize: cast.Ptr("8GB"),
MaxConnections: cast.Ptr(200),
SharedBuffers: cast.Ptr("2GB"),
}

diff, err := db.DiffWithRemote(remoteConfig)
assert.NoError(t, err)

assert.Contains(t, string(diff), "-effective_cache_size = \"8GB\"")
assert.Contains(t, string(diff), "+effective_cache_size = \"4GB\"")
assert.Contains(t, string(diff), "-max_connections = 200")
assert.Contains(t, string(diff), "+max_connections = 100")
assert.Contains(t, string(diff), "-shared_buffers = \"2GB\"")
assert.Contains(t, string(diff), "+shared_buffers = \"1GB\"")
})

t.Run("handles no differences", func(t *testing.T) {
db := &db{
remoteDb: remoteDb{
EffectiveCacheSize: cast.Ptr("4GB"),
MaxConnections: cast.Ptr(uint(100)),
SharedBuffers: cast.Ptr("1GB"),
},
}

remoteConfig := v1API.PostgresConfigResponse{
EffectiveCacheSize: cast.Ptr("4GB"),
MaxConnections: cast.Ptr(100),
SharedBuffers: cast.Ptr("1GB"),
}

diff, err := db.DiffWithRemote(remoteConfig)
assert.NoError(t, err)

assert.Empty(t, diff)
})

t.Run("handles multiple schemas and search paths with spaces", func(t *testing.T) {
db := &db{
remoteDb: remoteDb{
EffectiveCacheSize: cast.Ptr("4GB"),
MaxConnections: cast.Ptr(uint(100)),
SharedBuffers: cast.Ptr("1GB"),
},
}

remoteConfig := v1API.PostgresConfigResponse{
EffectiveCacheSize: cast.Ptr("4GB"),
MaxConnections: cast.Ptr(100),
SharedBuffers: cast.Ptr("1GB"),
}

diff, err := db.DiffWithRemote(remoteConfig)
assert.NoError(t, err)

assert.Empty(t, diff)
})

t.Run("handles api disabled on remote side", func(t *testing.T) {
db := &db{
remoteDb: remoteDb{
EffectiveCacheSize: cast.Ptr("4GB"),
MaxConnections: cast.Ptr(uint(100)),
SharedBuffers: cast.Ptr("1GB"),
},
}

remoteConfig := v1API.PostgresConfigResponse{
// All fields are nil to simulate disabled API
}

diff, err := db.DiffWithRemote(remoteConfig)
assert.NoError(t, err)

assert.Contains(t, string(diff), "+effective_cache_size = \"4GB\"")
assert.Contains(t, string(diff), "+max_connections = 100")
assert.Contains(t, string(diff), "+shared_buffers = \"1GB\"")
})

t.Run("handles api disabled on local side", func(t *testing.T) {
db := &db{
remoteDb: remoteDb{
// All fields are nil to simulate disabled API
},
}

remoteConfig := v1API.PostgresConfigResponse{
EffectiveCacheSize: cast.Ptr("4GB"),
MaxConnections: cast.Ptr(100),
SharedBuffers: cast.Ptr("1GB"),
}

diff, err := db.DiffWithRemote(remoteConfig)
assert.NoError(t, err)

assert.Contains(t, string(diff), "-effective_cache_size = \"4GB\"")
assert.Contains(t, string(diff), "-max_connections = 100")
assert.Contains(t, string(diff), "-shared_buffers = \"1GB\"")
})
}

0 comments on commit c3a1871

Please sign in to comment.