diff --git a/target_postgres/sql_base.py b/target_postgres/sql_base.py
index 12c2bbe2..9d879526 100644
--- a/target_postgres/sql_base.py
+++ b/target_postgres/sql_base.py
@@ -403,23 +403,30 @@ def upsert_table_helper(self, connection, schema, metadata, log_schema_changes=T
         self.add_key_properties(connection, table_name, schema.get('key_properties', None))
 
-        ## Only process columns which have single, nullable, types
-        single_type_columns = []
-        for column_path, column_schema in schema['schema']['properties'].items():
-            for sub_schema in column_schema['anyOf']:
-                single_type_columns.append((column_path, deepcopy(sub_schema)))
-
-        ## Process new columns against existing
-        raw_mappings = existing_schema.get('mappings', {})
-
+        ## Build up mappings to compare new columns against existing
         mappings = []
-        for to, m in raw_mappings.items():
+        for to, m in existing_schema.get('mappings', {}).items():
             mapping = json_schema.simple_type(m)
             mapping['from'] = tuple(m['from'])
             mapping['to'] = to
             mappings.append(mapping)
 
+        ## Only process columns which have single, nullable, types
+        column_paths_seen = set()
+        single_type_columns = []
+
+        for column_path, column_schema in schema['schema']['properties'].items():
+            column_paths_seen.add(column_path)
+            for sub_schema in column_schema['anyOf']:
+                single_type_columns.append((column_path, deepcopy(sub_schema)))
+
+        ### Add any columns missing from new schema
+        for m in mappings:
+            if not m['from'] in column_paths_seen:
+                single_type_columns.append((m['from'], json_schema.make_nullable(m)))
+
+        ## Process new columns against existing
         table_empty = self.is_table_empty(connection, table_name)
 
         for column_path, column_schema in single_type_columns:
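The hunk above, restated as a minimal standalone sketch (hypothetical helper names; make_nullable below only stands in for json_schema.make_nullable, and each property schema is assumed to already be normalized to an anyOf list of single types): column paths present in the persisted mappings but absent from the incoming schema are re-queued as nullable single-type columns, so the existing table column is relaxed to NULLable rather than dropped.

    from copy import deepcopy

    def make_nullable(schema):
        # Stand-in for json_schema.make_nullable: allow 'null' alongside
        # whatever type(s) the mapping already permits.
        out = deepcopy(schema)
        types = out.get('type', [])
        if isinstance(types, str):
            types = [types]
        if 'null' not in types:
            types = types + ['null']
        out['type'] = types
        return out

    def columns_to_upsert(new_properties, mappings):
        # Flatten each property's anyOf sub-schemas into (path, schema)
        # pairs, recording which paths the incoming schema declares.
        seen = set()
        columns = []
        for path, column_schema in new_properties.items():
            seen.add(path)
            for sub_schema in column_schema['anyOf']:
                columns.append((path, deepcopy(sub_schema)))
        # Columns known from earlier batches but missing now are kept and
        # made nullable, so new rows may simply omit them.
        for m in mappings:
            if m['from'] not in seen:
                columns.append((m['from'], make_nullable(m)))
        return columns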
diff --git a/tests/unit/test_postgres.py b/tests/unit/test_postgres.py
index 527fefce..1a90c176 100644
--- a/tests/unit/test_postgres.py
+++ b/tests/unit/test_postgres.py
@@ -1059,6 +1059,87 @@ def generate_record(self):
         assert cat_count == len([x for x in persisted_records if x[0] is None])
 
 
+def test_loading__column_type_change__nullable__missing_from_schema(db_cleanup):
+    cat_count = 20
+    main(CONFIG, input_stream=CatStream(cat_count))
+
+    with psycopg2.connect(**TEST_DB) as conn:
+        with conn.cursor() as cur:
+            assert_columns_equal(cur,
+                                 'cats',
+                                 {
+                                     ('_sdc_batched_at', 'timestamp with time zone', 'YES'),
+                                     ('_sdc_received_at', 'timestamp with time zone', 'YES'),
+                                     ('_sdc_sequence', 'bigint', 'YES'),
+                                     ('_sdc_table_version', 'bigint', 'YES'),
+                                     ('adoption__adopted_on', 'timestamp with time zone', 'YES'),
+                                     ('adoption__was_foster', 'boolean', 'YES'),
+                                     ('age', 'bigint', 'YES'),
+                                     ('id', 'bigint', 'NO'),
+                                     ('name', 'text', 'NO'),
+                                     ('bio', 'text', 'NO'),
+                                     ('paw_size', 'bigint', 'NO'),
+                                     ('paw_colour', 'text', 'NO'),
+                                     ('flea_check_complete', 'boolean', 'NO'),
+                                     ('pattern', 'text', 'YES')
+                                 })
+
+            cur.execute(sql.SQL('SELECT {} FROM {}').format(
+                sql.Identifier('name'),
+                sql.Identifier('cats')
+            ))
+            persisted_records = cur.fetchall()
+
+            ## Assert that the original data is present
+            assert cat_count == len(persisted_records)
+            assert cat_count == len([x for x in persisted_records if x[0] is not None])
+
+    class NameMissingCatStream(CatStream):
+        def generate_record(self):
+            record = CatStream.generate_record(self)
+            record['id'] = record['id'] + cat_count
+            del record['name']
+            return record
+
+    stream = NameMissingCatStream(cat_count)
+    stream.schema = deepcopy(stream.schema)
+    del stream.schema['schema']['properties']['name']
+
+    main(CONFIG, input_stream=stream)
+
+    with psycopg2.connect(**TEST_DB) as conn:
+        with conn.cursor() as cur:
+            assert_columns_equal(cur,
+                                 'cats',
+                                 {
+                                     ('_sdc_batched_at', 'timestamp with time zone', 'YES'),
+                                     ('_sdc_received_at', 'timestamp with time zone', 'YES'),
+                                     ('_sdc_sequence', 'bigint', 'YES'),
+                                     ('_sdc_table_version', 'bigint', 'YES'),
+                                     ('adoption__adopted_on', 'timestamp with time zone', 'YES'),
+                                     ('adoption__was_foster', 'boolean', 'YES'),
+                                     ('age', 'bigint', 'YES'),
+                                     ('id', 'bigint', 'NO'),
+                                     ('name', 'text', 'YES'),
+                                     ('bio', 'text', 'NO'),
+                                     ('paw_size', 'bigint', 'NO'),
+                                     ('paw_colour', 'text', 'NO'),
+                                     ('flea_check_complete', 'boolean', 'NO'),
+                                     ('pattern', 'text', 'YES')
+                                 })
+
+            cur.execute(sql.SQL('SELECT {} FROM {}').format(
+                sql.Identifier('name'),
+                sql.Identifier('cats')
+            ))
+            persisted_records = cur.fetchall()
+
+            ## Assert that the column has migrated data
+            assert 2 * cat_count == len(persisted_records)
+            assert cat_count == len([x for x in persisted_records if x[0] is not None])
+            assert cat_count == len([x for x in persisted_records if x[0] is None])
+
+
 def test_loading__multi_types_columns(db_cleanup):
     stream_count = 50
     main(CONFIG, input_stream=MultiTypeStream(stream_count))
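On the PostgreSQL side, the observable effect the new test pins down is a single constraint change: once 'name' disappears from the incoming schema, information_schema reports its is_nullable flag flipping from 'NO' to 'YES'. A hypothetical sketch of the DDL involved (the real statement is presumably generated by the target's schema-migration code, not shown in this patch; the function name here is illustrative and the identifiers mirror the test):

    import psycopg2
    from psycopg2 import sql

    def drop_not_null(conn, table, column):
        # Relax an existing NOT NULL column so later batches may omit it;
        # rows from those batches then persist with NULL in this column.
        with conn.cursor() as cur:
            cur.execute(sql.SQL('ALTER TABLE {} ALTER COLUMN {} DROP NOT NULL').format(
                sql.Identifier(table),
                sql.Identifier(column)
            ))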