diff --git a/pkg/sql2pgroll/alter_table.go b/pkg/sql2pgroll/alter_table.go index 4a51c18e..33a7df8f 100644 --- a/pkg/sql2pgroll/alter_table.go +++ b/pkg/sql2pgroll/alter_table.go @@ -3,6 +3,8 @@ package sql2pgroll import ( + "fmt" + pgq "github.com/pganalyze/pg_query_go/v6" "github.com/xataio/pgroll/pkg/migrations" ) @@ -22,25 +24,50 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err continue } - switch alterTableCmd.Subtype { + var op migrations.Operation + var err error + switch alterTableCmd.GetSubtype() { case pgq.AlterTableType_AT_SetNotNull: - ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, true)) + op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, true) case pgq.AlterTableType_AT_DropNotNull: - ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, false)) + op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, false) + case pgq.AlterTableType_AT_AlterColumnType: + op, err = convertAlterTableAlterColumnType(stmt, alterTableCmd) + } + + if err != nil { + return nil, err } + + ops = append(ops, op) } return ops, nil } -func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) migrations.Operation { +func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) (migrations.Operation, error) { return &migrations.OpAlterColumn{ Table: stmt.GetRelation().GetRelname(), Column: cmd.GetName(), Nullable: ptr(!notNull), Up: PlaceHolderSQL, Down: PlaceHolderSQL, + }, nil +} + +func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) { + node, ok := cmd.GetDef().Node.(*pgq.Node_ColumnDef) + if !ok { + return nil, fmt.Errorf("expected column definition, got %T", cmd.GetDef().Node) } + + return &migrations.OpAlterColumn{ + Table: stmt.GetRelation().GetRelname(), + Column: cmd.GetName(), + Type: ptr(convertTypeName(node.ColumnDef.GetTypeName())), + Up: PlaceHolderSQL, + Down: PlaceHolderSQL, + }, nil } func ptr[T any](x T) *T { diff --git a/pkg/sql2pgroll/alter_table_test.go b/pkg/sql2pgroll/alter_table_test.go index 3def8c8f..1cdb6521 100644 --- a/pkg/sql2pgroll/alter_table_test.go +++ b/pkg/sql2pgroll/alter_table_test.go @@ -27,6 +27,14 @@ func TestConvertAlterTableStatements(t *testing.T) { sql: "ALTER TABLE foo ALTER COLUMN a DROP NOT NULL", expectedOp: expect.AlterTableOp2, }, + { + sql: "ALTER TABLE foo ALTER COLUMN a SET DATA TYPE text", + expectedOp: expect.AlterTableOp3, + }, + { + sql: "ALTER TABLE foo ALTER COLUMN a TYPE text", + expectedOp: expect.AlterTableOp3, + }, } for _, tc := range tests { diff --git a/pkg/sql2pgroll/expect/alter_table.go b/pkg/sql2pgroll/expect/alter_table.go index a235c2db..4f7aceea 100644 --- a/pkg/sql2pgroll/expect/alter_table.go +++ b/pkg/sql2pgroll/expect/alter_table.go @@ -23,6 +23,14 @@ var AlterTableOp2 = &migrations.OpAlterColumn{ Down: sql2pgroll.PlaceHolderSQL, } +var AlterTableOp3 = &migrations.OpAlterColumn{ + Table: "foo", + Column: "a", + Type: ptr("text"), + Up: sql2pgroll.PlaceHolderSQL, + Down: sql2pgroll.PlaceHolderSQL, +} + func ptr[T any](v T) *T { return &v }