diff --git a/dagster_mssql_bcp_tests/bcp_polars/test_bcp.py b/dagster_mssql_bcp_tests/bcp_polars/test_bcp.py
index 5664167..ad08f53 100644
--- a/dagster_mssql_bcp_tests/bcp_polars/test_bcp.py
+++ b/dagster_mssql_bcp_tests/bcp_polars/test_bcp.py
@@ -249,7 +249,7 @@ def test_replace_values(self, polars_io):
             {"a": ["1,000", "2", "3"], "b": [4000, 5, 6], "c": ["a", "b", "c"]}
         )
         expected = df = pl.DataFrame(
-            {"a": ["1000", "2", "3"], "b": [4000, 5, 6], "c": ["a", "b", "c"]}
+            {"a": ["1000", "2", "3"], "b": [4000, 5, 6], "c": ["a", "b", "c"], "should_process_replacements": [0, 0, 0]}
         )
         schema = polars_mssql_bcp.AssetSchema(
             [
@@ -264,7 +264,7 @@ def test_replace_values(self, polars_io):
         df = pl.DataFrame(
             {"c": ["nan", "NAN", "c", "abc\tdef", "abc\t\ndef", "abc\ndef", "nan", "somenanthing"]}
         )
-        expected = df = pl.DataFrame(
+        expected = pl.DataFrame(
             {
                 "c": [
                     "",
@@ -275,6 +275,9 @@ def test_replace_values(self, polars_io):
                     "abc__NEWLINE__def",
                     "",
                     "somenanthing"
+                ],
+                "should_process_replacements": [
+                    0, 0, 0, 1, 1, 1, 0, 0
                 ]
             }
         )
@@ -304,6 +307,7 @@ def test_replace_values(self, polars_io):
                 # "2021-01-01 00:00:00-05:00",
             ],
             "d": ["2021-01-01 00:00:00-05:00", "2021-01-01 00:00:00-05:00"],
+            "should_process_replacements": [0, 0]
         }
     )
diff --git a/pyproject.toml b/pyproject.toml
index 5490962..a38a24d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "dagster-mssql-bcp"
-version = "0.0.8"
+version = "0.0.9"
 
 dependencies = [
     "dagster",
diff --git a/src/dagster_mssql_bcp/bcp_core/bcp_core.py b/src/dagster_mssql_bcp/bcp_core/bcp_core.py
index 63ec336..040e85f 100644
--- a/src/dagster_mssql_bcp/bcp_core/bcp_core.py
+++ b/src/dagster_mssql_bcp/bcp_core/bcp_core.py
@@ -203,71 +203,97 @@ def load_bcp(
         if process_replacements is None:
             process_replacements = self.process_replacements
 
-        connection_config_dict = self.connection_config
-
         asset_schema = self._parse_asset_schema(schema, table, asset_schema)
         if not isinstance(asset_schema, AssetSchema):
             raise ValueError("Invalid Asset Schema provided")
-
-        data = self._rename_columns(data, asset_schema.get_rename_dict())
-
         if uuid is None:
             uuid = str(uuid4())
         uuid_table = uuid.replace("-", "_")
         staging_table = f"{table}_staging_{uuid_table}"
 
+        get_dagster_logger().debug('renaming columns')
+        data = self._rename_columns(data, asset_schema.get_rename_dict())
+
+        get_dagster_logger().debug('adding meta to asset schema')
         self._add_meta_to_asset_schema(
             asset_schema, add_row_hash, add_load_datetime, add_load_uuid
         )
 
-        with connect_mssql(connection_config_dict) as connection:
-            self._create_target_tables(
-                schema, table, asset_schema, staging_table, connection
-            )
-
-            data = self._pre_start_hook(
-                data
+        with connect_mssql(self.connection_config) as connection:
+            get_dagster_logger().debug('pre-bcp stage')
+            data, schema_deltas = self._pre_bcp_stage(
+                connection=connection,
+                data=data,
+                schema=schema,
+                table=table,
+                asset_schema=asset_schema,
+                add_row_hash=add_row_hash,
+                add_load_datetime=add_load_datetime,
+                add_load_uuid=add_load_uuid,
+                uuid=uuid,
+                process_datetime=process_datetime,
+                process_replacements=process_replacements,
+                staging_table=staging_table,
            )
 
-            data, schema_deltas = self._pre_bcp_stage(
+        self._bcp_stage(data, schema, staging_table)
+
+        with connect_mssql(self.connection_config) as connection:
+            new_line_count = self._post_bcp_stage(
                 connection,
                 data,
                 schema,
                 table,
+                staging_table,
                 asset_schema,
                 add_row_hash,
-                add_load_datetime,
-                add_load_uuid,
-                uuid,
-                process_datetime,
                 process_replacements,
             )
 
-            data = self._pre_bcp_stage_completed_hook(
-                data
-            )
-
-            self._bcp_stage(data, schema, staging_table)
+        return {
+            "uuid": uuid,
+            "schema_table_name": f"{schema}.{table}",
+            "row_count": new_line_count,
+            "schema_deltas": schema_deltas,
+        }
 
-            new_line_count = self._post_bcp_stage(
+    def _pre_bcp_stage(
+        self,
+        connection,
+        data,
+        schema,
+        table,
+        asset_schema,
+        add_row_hash,
+        add_load_datetime,
+        add_load_uuid,
+        uuid,
+        process_datetime,
+        process_replacements,
+        staging_table,
+    ):
+
+        data = self._pre_processing_start_hook(data)
+        self._create_target_tables(
+            schema, table, asset_schema, staging_table, connection
+        )
+        data, schema_deltas = self._standardize_input_data(
+            connection,
             data,
             schema,
             table,
-                staging_table,
             asset_schema,
             add_row_hash,
+            add_load_datetime,
+            add_load_uuid,
+            uuid,
+            process_datetime,
             process_replacements,
-                connection_config_dict,
         )
+        data = self._pre_processing_complete_hook(data)
 
-            return {
-                "uuid": uuid,
-                "row_count": new_line_count,
-                "schema_deltas": schema_deltas,
-            }
-
-    def _pre_bcp_stage(
+        return data, schema_deltas
+
+    def _standardize_input_data(
         self,
         connection: Connection,
         data,
@@ -303,25 +329,27 @@ def _pre_bcp_stage(
             frame_columns, asset_schema.get_columns(), sql_structure
         )
 
-        # Filter columns that are not in the json schema (evolution) 
+        # Filter columns that are not in the json schema (evolution)
         data = self._filter_columns(data, asset_schema.get_columns(True))
 
-        sql_structure = sql_structure or frame_columns
-        data = self._reorder_columns(data, sql_structure)
+        # sql_structure = sql_structure or frame_columns
+        data = self._reorder_columns(data, asset_schema.get_columns(True))
+        data = self._add_replacement_flag_column(data)
 
         if process_replacements:
             data = self._replace_values(data, asset_schema)
 
         if process_datetime:
             data = self._process_datetime(data, asset_schema)
 
         return data, schema_deltas
- 
+
     def _bcp_stage(self, data, schema, staging_table):
+        get_dagster_logger().debug('bcp stage')
         with TemporaryDirectory() as temp_dir:
             temp_dir = Path(temp_dir)
             format_file = temp_dir / f"{staging_table}_format_file.fmt"
             error_file = temp_dir / f"{staging_table}_error_file.err"
             csv_file = self._save_csv(data, temp_dir, f"{staging_table}.csv")
- 
+
             self._generate_format_file(schema, staging_table, format_file)
             self._insert_with_bcp(
                 schema,
@@ -333,6 +361,7 @@ def _bcp_stage(self, data, schema, staging_table):
 
     def _post_bcp_stage(
         self,
+        connection: Connection,
         data,
         schema,
         table,
@@ -340,34 +369,33 @@ def _post_bcp_stage(
         asset_schema,
         add_row_hash,
         process_replacements,
-        connection_config_dict,
     ):
-        with connect_mssql(connection_config_dict) as con:
-            # Validate loads (counts of tables match)
-            new_line_count = self._validate_bcp_load(
-                con, schema, staging_table, None
-            )
-
-            if process_replacements:
-                self._replace_temporary_tab_newline(
-                    con, schema, staging_table, asset_schema
-                )
+        get_dagster_logger().debug('post-bcp stage')
+        # Validate loads (counts of tables match)
+        new_line_count = self._validate_bcp_load(
+            connection, schema, staging_table, None
+        )
 
-            if add_row_hash:
-                self._calculate_row_hash(
-                    con, schema, staging_table, asset_schema.get_hash_columns()
-                )
+        if process_replacements:
+            self._replace_temporary_tab_newline(
+                connection, schema, staging_table, asset_schema
+            )
 
-            self._insert_and_drop_bcp_table(
-                con, schema, table, staging_table, asset_schema
+        if add_row_hash:
+            self._calculate_row_hash(
+                connection, schema, staging_table, asset_schema.get_hash_columns()
             )
+        self._insert_and_drop_bcp_table(
+            connection, schema, table, staging_table, asset_schema
+        )
+
         return new_line_count
 
-    def _pre_bcp_stage_completed_hook(self, dataframe):
+    def _pre_processing_complete_hook(self, dataframe):
         return dataframe
 
-    def _pre_start_hook(self, dataframe):
+    def _pre_processing_start_hook(self, dataframe):
         return dataframe
 
     def _parse_asset_schema(self, schema, table, asset_schema):
@@ -400,6 +428,14 @@ def _parse_asset_schema(self, schema, table, asset_schema):
             )
         elif isinstance(asset_schema, list):
             asset_schema = AssetSchema(asset_schema)
+        elif isinstance(asset_schema, AssetSchema):
+            pass
+        else:
+            asset_schema = None
+
+        if asset_schema is None:
+            raise ValueError("No data table provided in metadata")
+
         return asset_schema
 
     # region pre load
@@ -466,7 +502,7 @@ def _create_target_tables(
             connection,
             schema,
             staging_table,
-            asset_schema.get_sql_columns(True),
+            asset_schema.get_sql_columns(True) + ["should_process_replacements BIT"],
         )
 
     @abstractmethod
@@ -703,6 +739,7 @@ def _get_sql_columns(
         result = [row[0] for row in connection.execute(text(sql))]
         result = None if len(result) == 0 else result
         return result
+
     @abstractmethod
     def _add_identity_columns(self, data, asset_schema: AssetSchema):
         """
@@ -713,7 +750,7 @@ def _add_identity_columns(self, data, asset_schema: AssetSchema):
            asset_schema (AssetSchema): The schema that defines the identity columns.
         """
         raise NotImplementedError
- 
+
     @abstractmethod
     def _get_frame_columns(self, data) -> list[str]:
         """
@@ -922,6 +959,8 @@ def _replace_temporary_tab_newline(
         UPDATE {schema}.{table}
         SET
             {set_columns}
+        WHERE
+            should_process_replacements = 1
         """
 
         update_sql_str = update_sql.format(
@@ -1006,3 +1045,11 @@ def _calculate_row_hash(
         connection.execute(text(update_sql))
 
     # endregion
+
+    @abstractmethod
+    def _add_replacement_flag_column(self, data):
+        """
+        Adds a bit column, `should_process_replacements`, indicating whether the row
+        should have the REPLACE post-processing applied. Replacements are applied
+        for tabs and newlines only.
+        """
+        raise NotImplementedError
diff --git a/src/dagster_mssql_bcp/bcp_core/bcp_io_manager_core.py b/src/dagster_mssql_bcp/bcp_core/bcp_io_manager_core.py
index 21a3dd3..1495acc 100644
--- a/src/dagster_mssql_bcp/bcp_core/bcp_io_manager_core.py
+++ b/src/dagster_mssql_bcp/bcp_core/bcp_io_manager_core.py
@@ -1,6 +1,11 @@
 from uuid import uuid4
 
-from dagster import ConfigurableIOManager, InputContext, OutputContext, get_dagster_logger
+from dagster import (
+    ConfigurableIOManager,
+    InputContext,
+    OutputContext,
+    get_dagster_logger,
+)
 from abc import abstractmethod, ABC
 
 from .asset_schema import AssetSchema
@@ -9,6 +14,7 @@
 
 from .bcp_core import BCPCore
 
+
 class BCPIOManagerCore(ConfigurableIOManager, ABC):
     host: str
     port: str
@@ -63,9 +69,7 @@ def handle_output(self, context: OutputContext, obj):
                 get_dagster_logger().info("No data to load")
                 return
 
-        bcp_manager = self.get_bcp(
-            **self.config
-        )
+        bcp_manager = self.get_bcp(**self.config)
 
         metadata = (
             context.definition_metadata
@@ -74,29 +78,30 @@ def handle_output(self, context: OutputContext, obj):
         )
 
         if len(context.asset_key.path) < 2:
-            schema = 'dbo'
+            schema = "dbo"
             table = context.asset_key.path[-1]
         else:
             schema, table = context.asset_key.path[-2], context.asset_key.path[-1]
- 
+
         schema = metadata.get("schema", schema)
         table = metadata.get("table", table)
 
-        asset_schema = metadata.get("asset_schema")
-        if asset_schema is None:
-            raise ValueError("No data table provided in metadata")
-        asset_schema = AssetSchema(asset_schema)
+        asset_schema = AssetSchema(metadata.get("asset_schema"))
 
         add_row_hash = metadata.get("add_row_hash", True)
         add_load_datetime = metadata.get("add_load_datetime", True)
         add_load_uuid = metadata.get("add_load_uuid", True)
         process_datetime = metadata.get("process_datetime", self.process_datetime)
-        process_replacements = metadata.get("process_replacements", self.process_replacements)
+        process_replacements = metadata.get(
+            "process_replacements", self.process_replacements
+        )
 
         uuid = str(uuid4())
         uuid_table = uuid.replace("-", "_").split("_")[0]
-        io_table = f"{table}__io__{uuid_table}"
+        staging_table = f"{table}__io__{uuid_table}"
+
+        obj = bcp_manager._rename_columns(obj, asset_schema.get_rename_dict())
 
         asset_schema = bcp_manager._add_meta_to_asset_schema(
             asset_schema,
@@ -105,51 +110,37 @@ def handle_output(self, context: OutputContext, obj):
             add_load_uuid=add_load_uuid,
         )
 
-        # create the table
         with connect_mssql(bcp_manager.connection_config) as connection:
-            # if the table doesn't exist do this otherwise select 1=0
-            result = connection.exec_driver_sql(
-                f"""SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{schema}' AND TABLE_NAME = '{table}'"""
+            data, schema_deltas = bcp_manager._pre_bcp_stage(
+                connection=connection,
+                data=obj,
+                schema=schema,
+                table=table,
+                asset_schema=asset_schema,
+                add_row_hash=add_row_hash,
+                add_load_datetime=add_load_datetime,
+                add_load_uuid=add_load_uuid,
+                uuid=uuid,
+                process_datetime=process_datetime,
+                process_replacements=process_replacements,
+                staging_table=staging_table,
             )
-            if len(result.fetchall()) == 0:
-                bcp_manager._create_schema(connection, schema)
-                bcp_manager._create_table(
-                    connection=connection,
-                    schema=schema,
-                    table=table,
-                    columns=asset_schema.get_sql_columns(),
-                )
-
-        results = bcp_manager.load_bcp(
-            data=obj,
-            schema=schema,
-            table=io_table,
-            asset_schema=asset_schema,
-            add_row_hash=add_row_hash,
-            add_load_datetime=add_load_datetime,
-            add_load_uuid=add_load_uuid,
-            uuid=uuid,
-            process_datetime=process_datetime,
-            process_replacements=process_replacements,
-        )
-        uuid_value, row_count, deltas = (
-            results["uuid"],
-            results["row_count"],
-            results["schema_deltas"],
-        )
-        asset_schema_columns_str = ",".join(asset_schema.get_columns())
+        bcp_manager._bcp_stage(data, schema, staging_table)
+
+        with connect_mssql(bcp_manager.connection_config) as connection:
             cleanup_sql = get_cleanup_statement(table, schema, context)
             connection.exec_driver_sql(cleanup_sql)
-            connection.exec_driver_sql(
-                f"""
-                INSERT INTO {schema}.{table} ({asset_schema_columns_str})
-                SELECT {asset_schema_columns_str}
-                FROM {schema}.{io_table}"""
+            row_count = bcp_manager._post_bcp_stage(
+                connection=connection,
+                data=obj,
+                schema=schema,
+                table=table,
+                staging_table=staging_table,
+                asset_schema=asset_schema,
+                add_row_hash=add_row_hash,
+                process_replacements=process_replacements,
             )
-            connection.exec_driver_sql(f"DROP TABLE {schema}.{io_table}")
 
         context.add_output_metadata(
             dict(
@@ -159,10 +150,10 @@ def handle_output(self, context: OutputContext, obj):
                     context,
                     (context.definition_metadata or {}).get("columns"),
                 ),
-                uuid_query=f"SELECT * FROM {schema}.{table} WHERE load_uuid = '{uuid_value}'",
+                uuid_query=f"SELECT * FROM {schema}.{table} WHERE load_uuid = '{uuid}'",
                 row_count=row_count,
             )
-            | deltas
+            | schema_deltas
         )
 
     @abstractmethod
diff --git a/src/dagster_mssql_bcp/bcp_pandas/pandas_mssql_bcp.py b/src/dagster_mssql_bcp/bcp_pandas/pandas_mssql_bcp.py
index 96228ba..4789b29 100644
--- a/src/dagster_mssql_bcp/bcp_pandas/pandas_mssql_bcp.py
+++ b/src/dagster_mssql_bcp/bcp_pandas/pandas_mssql_bcp.py
@@ -12,7 +12,6 @@
 
 from dagster_mssql_bcp.bcp_core import AssetSchema, BCPCore
 
-
 class PandasBCP(BCPCore):
     def _add_meta_columns(
         self,
@@ -122,13 +121,17 @@ def _filter_columns(self, data: pd.DataFrame, columns: list[str]):
     def _rename_columns(self, data: pd.DataFrame, columns: dict) -> pd.DataFrame:
         return data.rename(columns=columns)
 
-
-    def _add_identity_columns(self, data: pd.DataFrame, asset_schema: AssetSchema) -> pd.DataFrame:
+    def _add_identity_columns(
+        self, data: pd.DataFrame, asset_schema: AssetSchema
+    ) -> pd.DataFrame:
         ident_cols = asset_schema.get_identity_columns()
-        missing_idents = [
-            _ for _ in ident_cols if _ not in data.columns
-        ]
+        missing_idents = [_ for _ in ident_cols if _ not in data.columns]
         for _ in missing_idents:
             data[_] = None
-
-        return data
\ No newline at end of file
+
+        return data
+
+    def _add_replacement_flag_column(self, data: pd.DataFrame):
+        # set the flag to 1 on every row so they all participate in the post-BCP REPLACE
+        data["should_process_replacements"] = 1
+        return data
diff --git a/src/dagster_mssql_bcp/bcp_polars/polars_mssql_bcp.py b/src/dagster_mssql_bcp/bcp_polars/polars_mssql_bcp.py
index 787857a..0efdf8d 100644
--- a/src/dagster_mssql_bcp/bcp_polars/polars_mssql_bcp.py
+++ b/src/dagster_mssql_bcp/bcp_polars/polars_mssql_bcp.py
@@ -1,6 +1,5 @@
 from pathlib import Path
 
-
 import pendulum
 
 try:
@@ -52,6 +51,27 @@ def _replace_values(self, data: pl.LazyFrame, asset_schema: AssetSchema):
             if _ in asset_schema.get_numeric_columns()
         ]
 
+        string_cols = data.select(cs.by_dtype(pl.String)).collect_schema().names()
+
+        if len(string_cols) > 0:
+            # flag only the rows that actually contain a tab or newline to replace
+            data = data.with_columns(
+                [
+                    pl.col(_)
+                    .str.contains("(\t)|(\n)")
+                    .alias(f"{_}__bcp__has_replacement_values")
+                    for _ in string_cols
+                ]
+            )
+
+            data = data.with_columns(
+                pl.any_horizontal(
+                    [f"{_}__bcp__has_replacement_values" for _ in string_cols]
+                ).alias("should_process_replacements")
+            )
+
+            data = data.drop([f"{_}__bcp__has_replacement_values" for _ in string_cols])
+
         data = data.with_columns(
             [
                 pl.col(_)
                 .str.replace_all("\t", "__TAB__")
                 .str.replace_all("\n", "__NEWLINE__")
                 .str.replace_all("^nan$", "")
                 .str.replace_all("^NAN$", "")
-                for _ in data.select(cs.by_dtype(pl.String)).collect_schema().names()
+                for _ in string_cols
                 if _ not in number_columns_that_are_strings
             ]
             + [
                 pl.col(_)
                 .str.replace_all("^nan$", "")
                 .str.replace_all("^NAN$", "")
                 for _ in number_columns_that_are_strings
             ]
-            + [pl.col(_).cast(pl.Int64) for _ in data.select(cs.boolean()).collect_schema().names()]
+            + [
+                pl.col(_).cast(pl.Int64)
+                for _ in data.select(cs.boolean()).collect_schema().names()
+            ]
         )
 
         return data
@@ -129,7 +152,9 @@ def _process_datetime(
 
     def _reorder_columns(self, data: pl.LazyFrame, column_list: list[str]):
         """Reorder the data frame to match the order of the columns in the SQL table."""
-        column_list = [column for column in column_list if column in data.collect_schema().names()]
+        column_list = [
+            column for column in column_list if column in data.collect_schema().names()
+        ]
         return data.select(column_list)
 
     def _save_csv(self, data: pl.LazyFrame, path: Path, file_name: str):
@@ -155,9 +180,15 @@ def _add_identity_columns(
         self, data: pl.LazyFrame, asset_schema: AssetSchema
     ) -> pl.LazyFrame:
         ident_cols = asset_schema.get_identity_columns()
-        missing_idents = [_ for _ in ident_cols if _ not in data.collect_schema().names()]
+        missing_idents = [
+            _ for _ in ident_cols if _ not in data.collect_schema().names()
+        ]
         data = data.with_columns([pl.lit(None).alias(_) for _ in missing_idents])
         return data
 
-    def _pre_start_hook(self, data: pl.DataFrame):
+    def _pre_processing_start_hook(self, data: pl.DataFrame):
         return data.lazy()
+
+    def _add_replacement_flag_column(self, data: pl.DataFrame):
+        data = data.with_columns(pl.lit(0).alias("should_process_replacements"))
+        return data
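
A note on the flag logic this patch introduces: the polars `_replace_values` swaps tabs and newlines for `__TAB__`/`__NEWLINE__` placeholders before the CSV/BCP round-trip and records, per row, whether anything was actually substituted; the post-BCP `UPDATE ... WHERE should_process_replacements = 1` then restores the real characters only on those rows, while the pandas implementation simply flags every row. Below is a minimal, self-contained sketch of the polars flagging step; the frame and column names are invented for illustration, and the `fill_null` is an added assumption rather than part of the patch:

```python
import polars as pl
import polars.selectors as cs

# Toy stand-in for an asset's dataframe (hypothetical data).
df = pl.DataFrame(
    {
        "c": ["plain", "has\ttab", "line1\nline2", None],
        "n": [1, 2, 3, 4],
    }
)

# Same column discovery the patch uses: every String-typed column.
string_cols = df.select(cs.by_dtype(pl.String)).columns

flagged = df.with_columns(
    pl.any_horizontal(
        # One boolean per string column: does this row contain a tab or newline?
        [pl.col(c).str.contains("(\t)|(\n)") for c in string_cols]
    )
    .fill_null(False)  # assumption: an all-null row yields null, treat it as "no"
    .cast(pl.Int64)
    .alias("should_process_replacements")
)
print(flagged)
# Only the "has\ttab" and "line1\nline2" rows carry flag 1, so the post-BCP
# UPDATE ... WHERE should_process_replacements = 1 rewrites just those rows.
```

The trade-off between the two implementations: the pandas path skips the scan and hard-codes the flag to 1, so the post-BCP REPLACE touches every row, while the polars scan costs one extra pass over the string columns but keeps the SQL UPDATE narrow.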