diff --git a/dlt/destinations/impl/sqlalchemy/alter_table.py b/dlt/destinations/impl/sqlalchemy/alter_table.py new file mode 100644 index 0000000000..f85101a740 --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/alter_table.py @@ -0,0 +1,38 @@ +from typing import List + +import sqlalchemy as sa +from alembic.runtime.migration import MigrationContext +from alembic.operations import Operations + + +class ListBuffer: + """A partial implementation of string IO to use with alembic. + SQL statements are stored in a list instead of file/stdio + """ + + def __init__(self) -> None: + self._buf = "" + self.sql_lines: List[str] = [] + + def write(self, data: str) -> None: + self._buf += data + + def flush(self) -> None: + if self._buf: + self.sql_lines.append(self._buf) + self._buf = "" + + +class MigrationMaker: + def __init__(self, dialect: sa.engine.Dialect) -> None: + self._buf = ListBuffer() + self.ctx = MigrationContext(dialect, None, {"as_sql": True, "output_buffer": self._buf}) + self.ops = Operations(self.ctx) + + def add_column(self, table_name: str, column: sa.Column, schema: str) -> None: + self.ops.add_column(table_name, column, schema=schema) + + def consume_statements(self) -> List[str]: + lines = self._buf.sql_lines[:] + self._buf.sql_lines.clear() + return lines diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index 13b28385d7..8c3df409aa 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -29,6 +29,7 @@ from dlt.destinations.typing import DBTransaction, DBApiCursor from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyCredentials +from dlt.destinations.impl.sqlalchemy.alter_table import MigrationMaker from dlt.common.typing import TFun @@ -97,6 +98,7 @@ class SqlalchemyClient(SqlClientBase[Connection]): dialect: sa.engine.interfaces.Dialect dialect_name: str dbapi = DbApiProps # type: ignore[assignment] + migrations: Optional[MigrationMaker] = None # lazy init as needed def __init__( self, @@ -316,17 +318,16 @@ def fully_qualified_dataset_name(self, escape: bool = True, staging: bool = Fals raise NotImplementedError("Staging not supported") return self.dialect.identifier_preparer.format_schema(self.dataset_name) # type: ignore[attr-defined, no-any-return] - def alter_table_add_column(self, column: sa.Column) -> None: - """Execute an ALTER TABLE ... ADD COLUMN ... statement for the given column. - The column must be fully defined and attached to a table. - """ - # TODO: May need capability to override ALTER TABLE statement for different dialects - alter_tmpl = "ALTER TABLE {table} ADD COLUMN {column};" - statement = alter_tmpl.format( - table=self._make_qualified_table_name(self._make_qualified_table_name(column.table)), # type: ignore[arg-type] - column=self.compile_column_def(column), - ) - self.execute_sql(statement) + def alter_table_add_columns(self, columns: Sequence[sa.Column]) -> None: + if not columns: + return + if self.migrations is None: + self.migrations = MigrationMaker(self.dialect) + for column in columns: + self.migrations.add_column(column.table.name, column, self.dataset_name) + statements = self.migrations.consume_statements() + for statement in statements: + self.execute_sql(statement) def escape_column_name(self, column_name: str, escape: bool = True) -> str: if self.dialect.requires_name_normalize: # type: ignore[attr-defined] diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index bd163c2c53..33f00870c1 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -113,7 +113,7 @@ def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.ty elif sc_t == "bool": return sa.Boolean() elif sc_t == "timestamp": - return self._create_date_time_type(sc_t, precision) + return self._create_date_time_type(sc_t, precision, column.get("timezone")) elif sc_t == "bigint": return self._db_integer_type(precision) elif sc_t == "binary": @@ -128,7 +128,7 @@ def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.ty elif sc_t == "date": return sa.Date() elif sc_t == "time": - return self._create_date_time_type(sc_t, precision) + return self._create_date_time_type(sc_t, precision, column.get("timezone")) raise TerminalValueError(f"Unsupported data type: {sc_t}") def _from_db_integer_type(self, db_type: sa.Integer) -> TColumnType: @@ -400,13 +400,7 @@ def update_stored_schema( with self.sql_client.begin_transaction(): for table_obj in tables_to_create: self.sql_client.create_table(table_obj) - for col in columns_to_add: - alter = "ALTER TABLE {} ADD COLUMN {}".format( - self.sql_client.make_qualified_table_name(col.table.name), - self.sql_client.compile_column_def(col), - ) - self.sql_client.execute_sql(alter) - + self.sql_client.alter_table_add_columns(columns_to_add) self._update_schema_in_storage(self.schema) return schema_update diff --git a/poetry.lock b/poetry.lock index 6d99143f00..56ceeb03f8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -216,13 +216,13 @@ frozenlist = ">=1.1.0" [[package]] name = "alembic" -version = "1.12.0" +version = "1.13.2" description = "A database migration tool for SQLAlchemy." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "alembic-1.12.0-py3-none-any.whl", hash = "sha256:03226222f1cf943deee6c85d9464261a6c710cd19b4fe867a3ad1f25afda610f"}, - {file = "alembic-1.12.0.tar.gz", hash = "sha256:8e7645c32e4f200675e69f0745415335eb59a3663f5feb487abfa0b30c45888b"}, + {file = "alembic-1.13.2-py3-none-any.whl", hash = "sha256:6b8733129a6224a9a711e17c99b08462dbf7cc9670ba8f2e2ae9af860ceb1953"}, + {file = "alembic-1.13.2.tar.gz", hash = "sha256:1ff0ae32975f4fd96028c39ed9bb3c867fe3af956bd7bb37343b54c9fe7445ef"}, ] [package.dependencies] @@ -233,7 +233,7 @@ SQLAlchemy = ">=1.3.0" typing-extensions = ">=4" [package.extras] -tz = ["python-dateutil"] +tz = ["backports.zoneinfo"] [[package]] name = "alive-progress" @@ -9716,11 +9716,11 @@ qdrant = ["qdrant-client"] redshift = ["psycopg2-binary", "psycopg2cffi"] s3 = ["botocore", "s3fs"] snowflake = ["snowflake-connector-python"] -sqlalchemy = ["sqlalchemy"] +sqlalchemy = ["alembic", "sqlalchemy"] synapse = ["adlfs", "pyarrow", "pyodbc"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "0389931f2b6085e8d8824e2de09f779a4a6725dd86a85ce0a170cb0749d730a3" +content-hash = "bcbb266042a35c3bcb9c82510dac2345ae572c68e0b0d3d2523469db6ddbbb00" diff --git a/pyproject.toml b/pyproject.toml index 72b213c070..ce5d523dd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= ' tantivy = { version = ">= 0.22.0", optional = true } deltalake = { version = ">=0.19.0", optional = true } sqlalchemy = {version = ">=1.4", optional = true} +alembic = {version = "^1.13.2", optional = true} [tool.poetry.extras] gcp = ["grpcio", "google-cloud-bigquery", "db-dtypes", "gcsfs"] @@ -109,7 +110,7 @@ clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs dremio = ["pyarrow"] lancedb = ["lancedb", "pyarrow", "tantivy"] deltalake = ["deltalake", "pyarrow"] -sqlalchemy = ["sqlalchemy"] +sqlalchemy = ["sqlalchemy", "alembic"] [tool.poetry.scripts]