Skip to content

Commit

Permalink
Extract convertTypeName function
Browse files Browse the repository at this point in the history
  • Loading branch information
andrew-farries committed Dec 3, 2024
1 parent 84661eb commit 9fde877
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 40 deletions.
43 changes: 3 additions & 40 deletions pkg/sql2pgroll/create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
package sql2pgroll

import (
"fmt"
"strings"

pgq "github.com/pganalyze/pg_query_go/v6"
"github.com/xataio/pgroll/pkg/migrations"
)
Expand All @@ -26,42 +23,8 @@ func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) {
}

func convertColumnDef(col *pgq.ColumnDef) migrations.Column {
ignoredTypeParts := map[string]bool{
"pg_catalog": true,
}

// Build the type name, including any schema qualifiers
typeParts := make([]string, 0, len(col.GetTypeName().Names))
for _, node := range col.GetTypeName().Names {
typePart := node.GetString_().GetSval()
if _, ok := ignoredTypeParts[typePart]; ok {
continue
}
typeParts = append(typeParts, typePart)
}

// Build the type modifiers, such as precision and scale for numeric types
var typeMods []string
for _, node := range col.GetTypeName().Typmods {
if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok {
typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval()))
}
}
var typeModifier string
if len(typeMods) > 0 {
typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ","))
}

// Build the array bounds for array types
var arrayBounds string
for _, node := range col.GetTypeName().ArrayBounds {
bound := node.GetInteger().GetIval()
if bound == -1 {
arrayBounds = "[]"
} else {
arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound)
}
}
// Convert the column type
typeString := convertTypeName(col.TypeName)

// Determine column nullability, uniqueness, and primary key status
var notNull, unique, pk bool
Expand All @@ -81,7 +44,7 @@ func convertColumnDef(col *pgq.ColumnDef) migrations.Column {

return migrations.Column{
Name: col.Colname,
Type: strings.Join(typeParts, ".") + typeModifier + arrayBounds,
Type: typeString,
Nullable: !notNull,
Unique: unique,
Default: defaultValue,
Expand Down
52 changes: 52 additions & 0 deletions pkg/sql2pgroll/typename.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// SPDX-License-Identifier: Apache-2.0

package sql2pgroll

import (
"fmt"
"strings"

pgq "github.com/pganalyze/pg_query_go/v6"
)

// convertTypeName converts a TypeName node to a string.
func convertTypeName(typeName *pgq.TypeName) string {
ignoredTypeParts := map[string]bool{
"pg_catalog": true,
}

// Build the type name, including any schema qualifiers
typeParts := make([]string, 0, len(typeName.Names))
for _, node := range typeName.Names {
typePart := node.GetString_().GetSval()
if _, ok := ignoredTypeParts[typePart]; ok {
continue
}
typeParts = append(typeParts, typePart)
}

// Build the type modifiers, such as precision and scale for numeric types
var typeMods []string
for _, node := range typeName.Typmods {
if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok {
typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval()))
}
}
var typeModifier string
if len(typeMods) > 0 {
typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ","))
}

// Build the array bounds for array types
var arrayBounds string
for _, node := range typeName.ArrayBounds {
bound := node.GetInteger().GetIval()
if bound == -1 {
arrayBounds = "[]"
} else {
arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound)
}
}

return strings.Join(typeParts, ".") + typeModifier + arrayBounds
}

0 comments on commit 9fde877

Please sign in to comment.