Skip to content

Commit

Permalink
Convert SET DATA TYPE SQL to pgroll operation (#506)
Browse files Browse the repository at this point in the history
Convert SQL statements of the form:

```sql
ALTER TABLE foo ALTER COLUMN a [SET DATA] TYPE text
```

to the equivalent `pgroll` migration:

```json
[
  {
    "alter_column": {
      "column": "a",
      "down": "TODO: Implement SQL data migration",
      "table": "foo",
      "type": "text",
      "up": "TODO: Implement SQL data migration"
    }
  }
]
```

Part of #504
  • Loading branch information
andrew-farries authored Dec 3, 2024
1 parent 84661eb commit c756988
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 44 deletions.
35 changes: 31 additions & 4 deletions pkg/sql2pgroll/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
package sql2pgroll

import (
"fmt"

pgq "github.com/pganalyze/pg_query_go/v6"
"github.com/xataio/pgroll/pkg/migrations"
)
Expand All @@ -22,25 +24,50 @@ func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, err
continue
}

switch alterTableCmd.Subtype {
var op migrations.Operation
var err error
switch alterTableCmd.GetSubtype() {
case pgq.AlterTableType_AT_SetNotNull:
ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, true))
op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, true)
case pgq.AlterTableType_AT_DropNotNull:
ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd, false))
op, err = convertAlterTableSetNotNull(stmt, alterTableCmd, false)
case pgq.AlterTableType_AT_AlterColumnType:
op, err = convertAlterTableAlterColumnType(stmt, alterTableCmd)
}

if err != nil {
return nil, err
}

ops = append(ops, op)
}

return ops, nil
}

func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) migrations.Operation {
func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd, notNull bool) (migrations.Operation, error) {
return &migrations.OpAlterColumn{
Table: stmt.GetRelation().GetRelname(),
Column: cmd.GetName(),
Nullable: ptr(!notNull),
Up: PlaceHolderSQL,
Down: PlaceHolderSQL,
}, nil
}

func convertAlterTableAlterColumnType(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) (migrations.Operation, error) {
node, ok := cmd.GetDef().Node.(*pgq.Node_ColumnDef)
if !ok {
return nil, fmt.Errorf("expected column definition, got %T", cmd.GetDef().Node)
}

return &migrations.OpAlterColumn{
Table: stmt.GetRelation().GetRelname(),
Column: cmd.GetName(),
Type: ptr(convertTypeName(node.ColumnDef.GetTypeName())),
Up: PlaceHolderSQL,
Down: PlaceHolderSQL,
}, nil
}

func ptr[T any](x T) *T {
Expand Down
8 changes: 8 additions & 0 deletions pkg/sql2pgroll/alter_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ func TestConvertAlterTableStatements(t *testing.T) {
sql: "ALTER TABLE foo ALTER COLUMN a DROP NOT NULL",
expectedOp: expect.AlterTableOp2,
},
{
sql: "ALTER TABLE foo ALTER COLUMN a SET DATA TYPE text",
expectedOp: expect.AlterTableOp3,
},
{
sql: "ALTER TABLE foo ALTER COLUMN a TYPE text",
expectedOp: expect.AlterTableOp3,
},
}

for _, tc := range tests {
Expand Down
43 changes: 3 additions & 40 deletions pkg/sql2pgroll/create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
package sql2pgroll

import (
"fmt"
"strings"

pgq "github.com/pganalyze/pg_query_go/v6"
"github.com/xataio/pgroll/pkg/migrations"
)
Expand All @@ -26,42 +23,8 @@ func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) {
}

func convertColumnDef(col *pgq.ColumnDef) migrations.Column {
ignoredTypeParts := map[string]bool{
"pg_catalog": true,
}

// Build the type name, including any schema qualifiers
typeParts := make([]string, 0, len(col.GetTypeName().Names))
for _, node := range col.GetTypeName().Names {
typePart := node.GetString_().GetSval()
if _, ok := ignoredTypeParts[typePart]; ok {
continue
}
typeParts = append(typeParts, typePart)
}

// Build the type modifiers, such as precision and scale for numeric types
var typeMods []string
for _, node := range col.GetTypeName().Typmods {
if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok {
typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval()))
}
}
var typeModifier string
if len(typeMods) > 0 {
typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ","))
}

// Build the array bounds for array types
var arrayBounds string
for _, node := range col.GetTypeName().ArrayBounds {
bound := node.GetInteger().GetIval()
if bound == -1 {
arrayBounds = "[]"
} else {
arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound)
}
}
// Convert the column type
typeString := convertTypeName(col.TypeName)

// Determine column nullability, uniqueness, and primary key status
var notNull, unique, pk bool
Expand All @@ -81,7 +44,7 @@ func convertColumnDef(col *pgq.ColumnDef) migrations.Column {

return migrations.Column{
Name: col.Colname,
Type: strings.Join(typeParts, ".") + typeModifier + arrayBounds,
Type: typeString,
Nullable: !notNull,
Unique: unique,
Default: defaultValue,
Expand Down
8 changes: 8 additions & 0 deletions pkg/sql2pgroll/expect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ var AlterTableOp2 = &migrations.OpAlterColumn{
Down: sql2pgroll.PlaceHolderSQL,
}

var AlterTableOp3 = &migrations.OpAlterColumn{
Table: "foo",
Column: "a",
Type: ptr("text"),
Up: sql2pgroll.PlaceHolderSQL,
Down: sql2pgroll.PlaceHolderSQL,
}

func ptr[T any](v T) *T {
return &v
}
52 changes: 52 additions & 0 deletions pkg/sql2pgroll/typename.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll

import (
"fmt"
"strings"

pgq "github.com/pganalyze/pg_query_go/v6"
)

// convertTypeName converts a TypeName node to a string.
func convertTypeName(typeName *pgq.TypeName) string {
ignoredTypeParts := map[string]bool{
"pg_catalog": true,
}

// Build the type name, including any schema qualifiers
typeParts := make([]string, 0, len(typeName.Names))
for _, node := range typeName.Names {
typePart := node.GetString_().GetSval()
if _, ok := ignoredTypeParts[typePart]; ok {
continue
}
typeParts = append(typeParts, typePart)
}

// Build the type modifiers, such as precision and scale for numeric types
var typeMods []string
for _, node := range typeName.Typmods {
if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok {
typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval()))
}
}
var typeModifier string
if len(typeMods) > 0 {
typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ","))
}

// Build the array bounds for array types
var arrayBounds string
for _, node := range typeName.ArrayBounds {
bound := node.GetInteger().GetIval()
if bound == -1 {
arrayBounds = "[]"
} else {
arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound)
}
}

return strings.Join(typeParts, ".") + typeModifier + arrayBounds
}

0 comments on commit c756988

Please sign in to comment.