diff --git a/pkg/sql2pgroll/create_table.go b/pkg/sql2pgroll/create_table.go index 63604895..dbaeb504 100644 --- a/pkg/sql2pgroll/create_table.go +++ b/pkg/sql2pgroll/create_table.go @@ -3,9 +3,6 @@ package sql2pgroll import ( - "fmt" - "strings" - pgq "github.com/pganalyze/pg_query_go/v6" "github.com/xataio/pgroll/pkg/migrations" ) @@ -26,42 +23,8 @@ func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) { } func convertColumnDef(col *pgq.ColumnDef) migrations.Column { - ignoredTypeParts := map[string]bool{ - "pg_catalog": true, - } - - // Build the type name, including any schema qualifiers - typeParts := make([]string, 0, len(col.GetTypeName().Names)) - for _, node := range col.GetTypeName().Names { - typePart := node.GetString_().GetSval() - if _, ok := ignoredTypeParts[typePart]; ok { - continue - } - typeParts = append(typeParts, typePart) - } - - // Build the type modifiers, such as precision and scale for numeric types - var typeMods []string - for _, node := range col.GetTypeName().Typmods { - if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok { - typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval())) - } - } - var typeModifier string - if len(typeMods) > 0 { - typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ",")) - } - - // Build the array bounds for array types - var arrayBounds string - for _, node := range col.GetTypeName().ArrayBounds { - bound := node.GetInteger().GetIval() - if bound == -1 { - arrayBounds = "[]" - } else { - arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound) - } - } + // Convert the column type + typeString := convertTypeName(col.TypeName) // Determine column nullability, uniqueness, and primary key status var notNull, unique, pk bool @@ -81,7 +44,7 @@ func convertColumnDef(col *pgq.ColumnDef) migrations.Column { return migrations.Column{ Name: col.Colname, - Type: strings.Join(typeParts, ".") + typeModifier + arrayBounds, + Type: typeString, Nullable: !notNull, Unique: unique, Default: defaultValue, diff --git a/pkg/sql2pgroll/typename.go b/pkg/sql2pgroll/typename.go new file mode 100644 index 00000000..f2c5a11d --- /dev/null +++ b/pkg/sql2pgroll/typename.go @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 + +package sql2pgroll + +import ( + "fmt" + "strings" + + pgq "github.com/pganalyze/pg_query_go/v6" +) + +// convertTypeName converts a TypeName node to a string. +func convertTypeName(typeName *pgq.TypeName) string { + ignoredTypeParts := map[string]bool{ + "pg_catalog": true, + } + + // Build the type name, including any schema qualifiers + typeParts := make([]string, 0, len(typeName.Names)) + for _, node := range typeName.Names { + typePart := node.GetString_().GetSval() + if _, ok := ignoredTypeParts[typePart]; ok { + continue + } + typeParts = append(typeParts, typePart) + } + + // Build the type modifiers, such as precision and scale for numeric types + var typeMods []string + for _, node := range typeName.Typmods { + if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok { + typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval())) + } + } + var typeModifier string + if len(typeMods) > 0 { + typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ",")) + } + + // Build the array bounds for array types + var arrayBounds string + for _, node := range typeName.ArrayBounds { + bound := node.GetInteger().GetIval() + if bound == -1 { + arrayBounds = "[]" + } else { + arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound) + } + } + + return strings.Join(typeParts, ".") + typeModifier + arrayBounds +}