diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index 915429ea..38f6a063 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -5,6 +5,7 @@ package sql2pgroll import ( "fmt" + "github.com/oapi-codegen/nullable" pgq "github.com/pganalyze/pg_query_go/v6" "github.com/xataio/pgroll/pkg/migrations" @@ -38,6 +39,8 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err op, err = convertAlterTableAddConstraint(stmt, alterTableCmd) case pgq.AlterTableType_AT_DropColumn: op, err = convertAlterTableDropColumn(stmt, alterTableCmd) + case pgq.AlterTableType_AT_ColumnDefault: + op, err = convertAlterTableSetColumnDefault(stmt, alterTableCmd) } if err != nil { @@ -158,11 +161,27 @@ func convertAlterTableAddUniqueConstraint(stmt *pgq.AlterTableStmt, constraint * }, nil } -// convertAlterTableDropColumn converts SQL statements like: +// convertAlterTableSetColumnDefault converts SQL statements like: // -// `ALTER TABLE foo DROP COLUMN bar +// `ALTER TABLE foo COLUMN bar SET DEFAULT 'foo' +// `ALTER TABLE foo COLUMN bar SET DEFAULT null +// `ALTER TABLE foo COLUMN bar DROP DEFAULT // // to an OpDropColumn operation. +func convertAlterTableSetColumnDefault(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { + def := nullable.NewNullNullable[string]() + if val := cmd.GetDef().GetAConst().GetSval(); val != nil { + def.Set(val.Sval) + } + return &migrations.OpAlterColumn{ + Table: stmt.GetRelation().GetRelname(), + Column: cmd.GetName(), + Default: def, + Down: PlaceHolderSQL, + Up: PlaceHolderSQL, + }, nil +} + func convertAlterTableDropColumn(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { if !canConvertDropColumn(cmd) { return nil, nil diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go index 144b220b..68a4b52f 100644 --- a/pkg/sql2pgroll/alter_table_test.go +++ b/pkg/sql2pgroll/alter_table_test.go @@ -36,6 +36,18 @@ func TestConvertAlterTableStatements(t *testing.T) { sql: "ALTER TABLE foo ALTER COLUMN a TYPE text", expectedOp: expect.AlterColumnOp3, }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT 'baz'", + expectedOp: expect.AlterColumnOp5, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar DROP DEFAULT", + expectedOp: expect.AlterColumnOp6, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DEFAULT null", + expectedOp: expect.AlterColumnOp6, + }, { sql: "ALTER TABLE foo ADD CONSTRAINT bar UNIQUE (a)", expectedOp: expect.CreateConstraintOp1, diff --git a/pkg/sql2pgroll/expect/alter_column.go b/pkg/sql2pgroll/expect/alter_column.go index ec9b0acc..7b2af1eb 100644 --- a/pkg/sql2pgroll/expect/alter_column.go +++ b/pkg/sql2pgroll/expect/alter_column.go @@ -3,6 +3,8 @@ package expect import ( + "github.com/oapi-codegen/nullable" + "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/sql2pgroll" ) @@ -37,6 +39,22 @@ var AlterColumnOp4 = &migrations.OpAlterColumn{ Name: ptr("b"), } +var AlterColumnOp5 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullableWithValue("baz"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + +var AlterColumnOp6 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "bar", + Default: nullable.NewNullNullable[string](), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + func ptr[T any](v T) *T { return &v }