From a849534638d9d6bba56a8a5aae9bacb41d4d38df Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Wed, 18 Dec 2024 09:03:30 +0000 Subject: [PATCH] Handle unconvertible `CREATE TABLE` statements (#546) There are [many options](https://www.postgresql.org/docs/current/sql-createtable.html) for the `CREATE TABLE` statement in Postgres, most of which are not currently representable by the `pgroll` `OpCreateTable` operation. Add tests to ensure that `sql2proll.Convert`falls back to raw SQL operations when these unconvertible options are present in a SQL statement. Part of #504 --- pkg/sql2pgroll/create_table.go | 39 +++++++++++++++++++++++++ pkg/sql2pgroll/create_table_test.go | 45 +++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/pkg/sql2pgroll/create_table.go b/pkg/sql2pgroll/create_table.go index 5add694a..4ea6baea 100644 --- a/pkg/sql2pgroll/create_table.go +++ b/pkg/sql2pgroll/create_table.go @@ -12,6 +12,12 @@ import ( // convertCreateStmt converts a CREATE TABLE statement to a pgroll operation. func convertCreateStmt(stmt *pgq.CreateStmt) (migrations.Operations, error) { + // Check if the statement can be converted + if !canConvertCreateStatement(stmt) { + return nil, nil + } + + // Convert the column definitions columns := make([]migrations.Column, 0, len(stmt.TableElts)) for _, elt := range stmt.TableElts { column, err := convertColumnDef(elt.GetColumnDef()) @@ -29,6 +35,39 @@ func convertCreateStmt(stmt *pgq.CreateStmt) (migrations.Operations, error) { }, nil } +// canConvertCreateTableStatement returns true iff `stmt` can be converted to a +// pgroll operation. +func canConvertCreateStatement(stmt *pgq.CreateStmt) bool { + switch { + // Temporary and unlogged tables are not supported + case stmt.GetRelation().GetRelpersistence() != "p": + return false + // CREATE TABLE IF NOT EXISTS is not supported + case stmt.GetIfNotExists(): + return false + // Table inheritance is not supported + case len(stmt.GetInhRelations()) != 0: + return false + // Paritioned tables are not supported + case stmt.GetPartspec() != nil: + return false + // Specifying an access method is not supported + case stmt.GetAccessMethod() != "": + return false + // Specifying storage options is not supported + case len(stmt.GetOptions()) != 0: + return false + // ON COMMIT options are not supported + case stmt.GetOncommit() != pgq.OnCommitAction_ONCOMMIT_NOOP: + return false + // Setting a tablespace is not supported + case stmt.GetTablespacename() != "": + return false + default: + return true + } +} + func convertColumnDef(col *pgq.ColumnDef) (*migrations.Column, error) { // Convert the column type typeString, err := pgq.DeparseTypeName(col.TypeName) diff --git a/pkg/sql2pgroll/create_table_test.go b/pkg/sql2pgroll/create_table_test.go index acc980b0..62388a0f 100644 --- a/pkg/sql2pgroll/create_table_test.go +++ b/pkg/sql2pgroll/create_table_test.go @@ -72,3 +72,48 @@ func TestConvertCreateTableStatements(t *testing.T) { }) } } + +func TestUnconvertableCreateTableStatements(t *testing.T) { + t.Parallel() + + tests := []string{ + // Temporary and unlogged tables are not supported + "CREATE TEMPORARY TABLE foo(a int)", + "CREATE UNLOGGED TABLE foo(a int)", + + // The IF NOT EXISTS clause is not supported + "CREATE TABLE IF NOT EXISTS foo(a int)", + + // Table inheritance is not supported + "CREATE TABLE foo(a int) INHERITS (bar)", + + // Any kind of partitioning is not supported + "CREATE TABLE foo(a int) PARTITION BY RANGE (a)", + "CREATE TABLE foo(a int) PARTITION BY LIST (a)", + + // Specifying a table access method is not supported + "CREATE TABLE foo(a int) USING bar", + + // Specifying storage options is not supported + "CREATE TABLE foo(a int) WITH (fillfactor=70)", + + // ON COMMMIT options are not supported. These options are syntactically + // valid for all tables, but Postgres will reject them for non-temporary + // tables. We err on the side of caution and reject them for all tables. + "CREATE TABLE foo(a int) ON COMMIT DROP", + + // Specifying a tablespace is not supported + "CREATE TABLE foo(a int) TABLESPACE bar", + } + + for _, sql := range tests { + t.Run(sql, func(t *testing.T) { + ops, err := sql2pgroll.Convert(sql) + require.NoError(t, err) + + require.Len(t, ops, 1) + + assert.Equal(t, expect.RawSQLOp(sql), ops[0]) + }) + } +}