Skip to content

Commit

Permalink
Convert ALTER COLUMN SET DATA TYPE SQL
Browse files Browse the repository at this point in the history
Convert SQL statements of the form:

```sql
ALTER TABLE foo ALTER COLUMN a [SET DATA] TYPE text
```

to the equivalent `pgroll` migration:

```json
[
  {
    "alter_column": {
      "column": "a",
      "down": "TODO: Implement SQL data migration",
      "table": "foo",
      "type": "text",
      "up": "TODO: Implement SQL data migration"
    }
  }
]
```
  • Loading branch information
andrew-farries committed Dec 3, 2024
1 parent 9fde877 commit 5523009
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
35 changes: 31 additions & 4 deletions pkg/sql2pgroll/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
package sql2pgroll

import (
"fmt"

pgq "github.com/pganalyze/pg_query_go/v6"
"github.com/xataio/pgroll/pkg/migrations"
)
Expand All @@ -22,25 +24,50 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err
continue
}

switch alterTableCmd.Subtype {
var op migrations.Operation
var err error
switch alterTableCmd.GetSubtype() {
case pgq.AlterTableType_AT_SetNotNull:
ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, true))
op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, true)
case pgq.AlterTableType_AT_DropNotNull:
ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, false))
op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, false)
case pgq.AlterTableType_AT_AlterColumnType:
op, err = convertAlterTableAlterColumnType(stmt, alterTableCmd)
}

if err != nil {
return nil, err
}

ops = append(ops, op)
}

return ops, nil
}

func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) migrations.Operation {
func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) (migrations.Operation, error) {
return &migrations.OpAlterColumn{
Table: stmt.GetRelation().GetRelname(),
Column: cmd.GetName(),
Nullable: ptr(!notNull),
Up: PlaceHolderSQL,
Down: PlaceHolderSQL,
}, nil
}

func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
node, ok := cmd.GetDef().Node.(*pgq.Node_ColumnDef)
if !ok {
return nil, fmt.Errorf("expected column definition, got %T", cmd.GetDef().Node)
}

return &migrations.OpAlterColumn{
Table: stmt.GetRelation().GetRelname(),
Column: cmd.GetName(),
Type: ptr(convertTypeName(node.ColumnDef.GetTypeName())),
Up: PlaceHolderSQL,
Down: PlaceHolderSQL,
}, nil
}

func ptr[T any](x T) *T {
Expand Down
8 changes: 8 additions & 0 deletions pkg/sql2pgroll/alter_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ func TestConvertAlterTableStatements(t *testing.T) {
sql: "ALTER TABLE foo ALTER COLUMN a DROP NOT NULL",
expectedOp: expect.AlterTableOp2,
},
{
sql: "ALTER TABLE foo ALTER COLUMN a SET DATA TYPE text",
expectedOp: expect.AlterTableOp3,
},
{
sql: "ALTER TABLE foo ALTER COLUMN a TYPE text",
expectedOp: expect.AlterTableOp3,
},
}

for _, tc := range tests {
Expand Down
8 changes: 8 additions & 0 deletions pkg/sql2pgroll/expect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ var AlterTableOp2 = &migrations.OpAlterColumn{
Down: sql2pgroll.PlaceHolderSQL,
}

var AlterTableOp3 = &migrations.OpAlterColumn{
Table: "foo",
Column: "a",
Type: ptr("text"),
Up: sql2pgroll.PlaceHolderSQL,
Down: sql2pgroll.PlaceHolderSQL,
}

func ptr[T any](v T) *T {
return &v
}

0 comments on commit 5523009

Please sign in to comment.