Merge pull request #14 from cody-scott/update_only_changes
added update logic
cody-scott authored Nov 5, 2024
2 parents 9274325 + bdd5836 commit 245ffaa
Showing 6 changed files with 205 additions and 129 deletions.
8 changes: 6 additions & 2 deletions dagster_mssql_bcp_tests/bcp_polars/test_bcp.py
@@ -249,7 +249,7 @@ def test_replace_values(self, polars_io):
{"a": ["1,000", "2", "3"], "b": [4000, 5, 6], "c": ["a", "b", "c"]}
)
expected = df = pl.DataFrame(
{"a": ["1000", "2", "3"], "b": [4000, 5, 6], "c": ["a", "b", "c"]}
{"a": ["1000", "2", "3"], "b": [4000, 5, 6], "c": ["a", "b", "c"], 'should_process_replacements': [0, 0, 0]}
)
schema = polars_mssql_bcp.AssetSchema(
[
@@ -264,7 +264,7 @@ def test_replace_values(self, polars_io):
df = pl.DataFrame(
{"c": ["nan", "NAN", "c", "abc\tdef", "abc\t\ndef", "abc\ndef", "nan", "somenanthing"]}
)
expected = df = pl.DataFrame(
expected = pl.DataFrame(
{
"c": [
"",
@@ -275,6 +275,9 @@ def test_replace_values(self, polars_io):
"abc__NEWLINE__def",
"",
"somenanthing"
],
'should_process_replacements': [
0, 0, 0, 1, 1, 1, 0, 0
]
}
)
@@ -304,6 +307,7 @@ def test_replace_values(self, polars_io):
# "2021-01-01 00:00:00-05:00",
],
"d": ["2021-01-01 00:00:00-05:00", "2021-01-01 00:00:00-05:00"],
"should_process_replacements": [0, 0]
}
)
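The expected frames above pin down the new contract: whole-cell "nan"/"NAN" strings are cleared, tabs and newlines are swapped for __TAB__/__NEWLINE__ placeholders so they survive the BCP CSV round-trip, and a new should_process_replacements column flags exactly the rows that received a placeholder. A rough polars sketch of the replacement step (illustrative only; the library's actual _replace_values is not shown in this diff):

import polars as pl

df = pl.DataFrame({"c": ["nan", "abc\tdef", "abc\ndef", "somenanthing"]})

# Whole-string "nan" (any case) -> "", embedded tabs/newlines -> placeholders.
# Only full-cell matches are cleared, so "somenanthing" is left untouched.
df = df.with_columns(
    pl.col("c")
    .str.replace_all(r"(?i)^nan$", "")
    .str.replace_all("\t", "__TAB__", literal=True)
    .str.replace_all("\n", "__NEWLINE__", literal=True)
)
# -> ["", "abc__TAB__def", "abc__NEWLINE__def", "somenanthing"]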

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "dagster-mssql-bcp"
version = "0.0.8"
version = "0.0.9"
dependencies = [
"dagster",

167 changes: 107 additions & 60 deletions src/dagster_mssql_bcp/bcp_core/bcp_core.py
@@ -203,71 +203,97 @@ def load_bcp(
if process_replacements is None:
process_replacements = self.process_replacements

connection_config_dict = self.connection_config

asset_schema = self._parse_asset_schema(schema, table, asset_schema)

if not isinstance(asset_schema, AssetSchema):
raise ValueError("Invalid Asset Schema provided")

data = self._rename_columns(data, asset_schema.get_rename_dict())

if uuid is None:
uuid = str(uuid4())
uuid_table = uuid.replace("-", "_")
staging_table = f"{table}_staging_{uuid_table}"

get_dagster_logger().debug('renaming columns')
data = self._rename_columns(data, asset_schema.get_rename_dict())

get_dagster_logger().debug('adding meta to asset schema')
self._add_meta_to_asset_schema(
asset_schema, add_row_hash, add_load_datetime, add_load_uuid
)

with connect_mssql(connection_config_dict) as connection:
self._create_target_tables(
schema, table, asset_schema, staging_table, connection
)

data = self._pre_start_hook(
data
with connect_mssql(self.connection_config) as connection:
get_dagster_logger().debug('pre-bcp stage')
data, schema_deltas = self._pre_bcp_stage(
connection=connection,
data=data,
schema=schema,
table=table,
asset_schema=asset_schema,
add_row_hash=add_row_hash,
add_load_datetime=add_load_datetime,
add_load_uuid=add_load_uuid,
uuid=uuid,
process_datetime=process_datetime,
process_replacements=process_replacements,
staging_table=staging_table,
)

data, schema_deltas = self._pre_bcp_stage(
self._bcp_stage(data, schema, staging_table)

with connect_mssql(self.connection_config) as connection:
new_line_count = self._post_bcp_stage(
connection,
data,
schema,
table,
staging_table,
asset_schema,
add_row_hash,
add_load_datetime,
add_load_uuid,
uuid,
process_datetime,
process_replacements,
)

data = self._pre_bcp_stage_completed_hook(
data
)

self._bcp_stage(data, schema, staging_table)
return {
"uuid": uuid,
"schema_table_name": f"{schema}.{table}",
"row_count": new_line_count,
"schema_deltas": schema_deltas,
}
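The refactor splits load_bcp into three explicit phases: pre-BCP preparation on one connection, the BCP file load itself, then post-BCP finalization on a second connection, with the outcome summarized in the returned dict. A hedged caller sketch (the io object stands in for a configured BCP resource from the wider library; only the return keys come from the code above):

# Hypothetical caller; construction of the resource is outside this diff.
result = io.load_bcp(data=df, schema="dbo", table="my_table")
assert set(result) == {"uuid", "schema_table_name", "row_count", "schema_deltas"}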

new_line_count = self._post_bcp_stage(
def _pre_bcp_stage(
self,
connection,
data,
schema,
table,
asset_schema,
add_row_hash,
add_load_datetime,
add_load_uuid,
uuid,
process_datetime,
process_replacements,
staging_table,
):

data = self._pre_prcessing_start_hook(data)
self._create_target_tables(
schema, table, asset_schema, staging_table, connection
)
data, schema_deltas = self._standarize_input_data(
connection,
data,
schema,
table,
staging_table,
asset_schema,
add_row_hash,
add_load_datetime,
add_load_uuid,
uuid,
process_datetime,
process_replacements,
connection_config_dict,
)
data = self._pre_processing_complete_hook(data)

return {
"uuid": uuid,
"row_count": new_line_count,
"schema_deltas": schema_deltas,
}

def _pre_bcp_stage(
return data, schema_deltas

def _standarize_input_data(
self,
connection: Connection,
data,
@@ -303,25 +329,27 @@ def _pre_bcp_stage(
frame_columns, asset_schema.get_columns(), sql_structure
)

# Filter columns that are not in the json schema (evolution)
# Filter columns that are not in the json schema (evolution)
data = self._filter_columns(data, asset_schema.get_columns(True))
sql_structure = sql_structure or frame_columns
data = self._reorder_columns(data, sql_structure)
# sql_structure = sql_structure or frame_columns
data = self._reorder_columns(data, asset_schema.get_columns(True))

data = self._add_replacement_flag_column(data)
if process_replacements:
data = self._replace_values(data, asset_schema)
if process_datetime:
data = self._process_datetime(data, asset_schema)

return data, schema_deltas

def _bcp_stage(self, data, schema, staging_table):
get_dagster_logger().debug('bcp stage')
with TemporaryDirectory() as temp_dir:
temp_dir = Path(temp_dir)
format_file = temp_dir / f"{staging_table}_format_file.fmt"
error_file = temp_dir / f"{staging_table}_error_file.err"
csv_file = self._save_csv(data, temp_dir, f"{staging_table}.csv")

self._generate_format_file(schema, staging_table, format_file)
self._insert_with_bcp(
schema,
@@ -333,41 +361,41 @@ def _bcp_stage(self, data, schema, staging_table):

def _post_bcp_stage(
self,
connection: Connection,
data,
schema,
table,
staging_table,
asset_schema,
add_row_hash,
process_replacements,
connection_config_dict,
):
with connect_mssql(connection_config_dict) as con:
# Validate loads (counts of tables match)
new_line_count = self._validate_bcp_load(
con, schema, staging_table, None
)

if process_replacements:
self._replace_temporary_tab_newline(
con, schema, staging_table, asset_schema
)
get_dagster_logger().debug('post-bcp stage')
# Validate loads (counts of tables match)
new_line_count = self._validate_bcp_load(
connection, schema, staging_table, None
)

if add_row_hash:
self._calculate_row_hash(
con, schema, staging_table, asset_schema.get_hash_columns()
)
if process_replacements:
self._replace_temporary_tab_newline(
connection, schema, staging_table, asset_schema
)

self._insert_and_drop_bcp_table(
con, schema, table, staging_table, asset_schema
if add_row_hash:
self._calculate_row_hash(
connection, schema, staging_table, asset_schema.get_hash_columns()
)

self._insert_and_drop_bcp_table(
connection, schema, table, staging_table, asset_schema
)

return new_line_count
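_post_bcp_stage now reuses the connection opened by the caller instead of opening its own, and the placeholder clean-up is restricted to flagged rows, which is the point of this PR: only rows that actually contain a tab or newline are touched by the post-load UPDATE. Rendered, that statement looks roughly like this (schema, staging table name, and SET columns are made up for the example; the template itself appears in a later hunk):

# Illustrative rendering only; real names come from the asset schema.
update_sql = """
UPDATE dbo.my_table_staging_abc123
SET
    c = REPLACE(REPLACE(c, '__TAB__', CHAR(9)), '__NEWLINE__', CHAR(10))
WHERE
    should_process_replacements = 1
"""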

def _pre_bcp_stage_completed_hook(self, dataframe):
def _pre_processing_complete_hook(self, dataframe):
return dataframe

def _pre_start_hook(self, dataframe):
def _pre_prcessing_start_hook(self, dataframe):
return dataframe

def _parse_asset_schema(self, schema, table, asset_schema):
@@ -400,6 +428,14 @@ def _parse_asset_schema(self, schema, table, asset_schema):
)
elif isinstance(asset_schema, list):
asset_schema = AssetSchema(asset_schema)
elif isinstance(asset_schema, AssetSchema):
asset_schema = asset_schema
else:
asset_schema = None

if asset_schema is None:
raise ValueError("No data table provided in metadata")

return asset_schema
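_parse_asset_schema now normalizes every accepted input, dagster metadata, a plain list of column definitions, or an AssetSchema instance, and raises a ValueError when nothing usable is provided. For illustration, the list form mirrors the test usage earlier in this commit (the column-dict keys shown are assumptions; the actual list contents are elided in the diff above):

# Equivalent inputs: a ready-made AssetSchema, or the raw column list.
columns = [{"name": "a", "type": "NVARCHAR", "length": 10}]
schema = polars_mssql_bcp.AssetSchema(columns)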

# region pre load
@@ -466,7 +502,7 @@ def _create_target_tables(
connection,
schema,
staging_table,
asset_schema.get_sql_columns(True),
asset_schema.get_sql_columns(True) + ["should_process_replacements BIT"],
)

@abstractmethod
@@ -703,6 +739,7 @@ def _get_sql_columns(
result = [row[0] for row in connection.execute(text(sql))]
result = None if len(result) == 0 else result
return result

@abstractmethod
def _add_identity_columns(self, data, asset_schema: AssetSchema):
"""
@@ -713,7 +750,7 @@ def _add_identity_columns(self, data, asset_schema: AssetSchema):
asset_schema (AssetSchema): The schema that defines the identity columns.
"""
raise NotImplementedError

@abstractmethod
def _get_frame_columns(self, data) -> list[str]:
"""
@@ -922,6 +959,8 @@ def _replace_temporary_tab_newline(
UPDATE {schema}.{table}
SET
{set_columns}
WHERE
should_process_replacements = 1
"""

update_sql_str = update_sql.format(
@@ -1006,3 +1045,11 @@ def _calculate_row_hash(
connection.execute(text(update_sql))

# endregion

@abstractmethod
def _add_replacement_flag_column(self, data):
"""
Adds a BIT column, `should_process_replacements`, indicating whether the row needs the post-load REPLACE applied.
The replacement covers tabs and newlines only.
"""
raise NotImplementedError
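The flag itself is computed per frame implementation (polars, pandas), hence the abstract method. A possible polars version, consistent with the 0/1 values the tests above expect (a sketch under those assumptions, not the actual subclass code):

import polars as pl

def _add_replacement_flag_column(self, data: pl.DataFrame) -> pl.DataFrame:
    # Mark rows where any string column carries a tab or newline; these are
    # the only rows the post-load UPDATE needs to touch (BIT in SQL Server,
    # integer dtype here is illustrative).
    string_cols = data.select(pl.col(pl.Utf8)).columns
    if not string_cols:
        return data.with_columns(
            pl.lit(0, dtype=pl.Int8).alias("should_process_replacements")
        )
    return data.with_columns(
        pl.any_horizontal([pl.col(c).str.contains("\t|\n") for c in string_cols])
        .fill_null(False)
        .cast(pl.Int8)
        .alias("should_process_replacements")
    )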