Skip to content

Commit

Permalink
added pipeline save/diff tools
Browse files Browse the repository at this point in the history
  • Loading branch information
SamRWest committed Mar 13, 2024
1 parent ccaf53b commit b8213d6
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
2 changes: 2 additions & 0 deletions xl2times/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2456,6 +2456,8 @@ def explode_process_commodity_cols(

tables[tag] = df

utils.save_state(config, tables, model, "exploded_process_commodity_cols.pkl.gz")

return tables


Expand Down
66 changes: 66 additions & 0 deletions xl2times/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
) # see https://loguru.readthedocs.io/en/stable/api/type_hints.html#module-autodoc_stub_file.loguru

import functools
import gzip
import os
import pickle
import re
import sys
from dataclasses import replace
Expand All @@ -18,6 +20,7 @@
from pandas.core.frame import DataFrame

from . import datatypes
from loguru import logger

# prevent excessive number of processes in Windows and high cpu-count machines
# TODO make this a cli param or global setting?
Expand Down Expand Up @@ -285,3 +288,66 @@ def get_logger(log_name: str = default_log_name, log_dir: str = ".") -> loguru.L
}
logger.configure(**log_conf)
return logger


def save_state(
config: datatypes.Config,
tables: dict[str, DataFrame],
model: datatypes.TimesModel,
filename: str,
) -> None:
"""Saves the state from a transform step to a single pickle file.
Useful for troubleshooting regressions by diffing with state from another branch.
"""
pickle.dump({"tables": tables, "model": model}, gzip.open(filename, "wb"))
logger.debug(f"State saved to {filename}")


def compare_df_dict(
df_main: dict[str, DataFrame], df_new: dict[str, DataFrame], sort_cols: bool = True
) -> None:
"""Simple function to compare two dictionaries of DataFrames, for troubleshooting model regressions etc."""
for key in df_main:

main = df_main[key]
new = df_new[key]

if sort_cols:
main = main.sort_index(axis="columns")
new = new.sort_index(axis="columns")

if not main.equals(new):
logger.error(f"Table {key} is different...")

# print first line that is different, and its surrounding lines
for i in range(len(main)):
if not main.iloc[i].equals(new.iloc[i]):
print(f"main: {main.iloc[i - 1:i + 2]}")
print(f"new: {new.iloc[i - 1:i + 2]}")
break
else:
logger.success(f"Table {key} is the same")


def diff_state(filename_before: str, filename_after: str) -> None:
"""Diffs two state files created with save_state()."""
before = pickle.load(gzip.open(filename_before, "rb"))
after = pickle.load(gzip.open(filename_after, "rb"))

# Compare DFs in the tables dict
compare_df_dict(before["tables"], after["tables"])

# Compare DFs on the model object
model_before = before["model"]
model_after = after["model"]
dfs_before = {
a: getattr(model_before, a)
for a in dir(model_before)
if isinstance(getattr(model_before, a), pd.DataFrame)
}
dfs_after = {
a: getattr(model_after, a)
for a in dir(model_after)
if isinstance(getattr(model_after, a), pd.DataFrame)
}
compare_df_dict(dfs_before, dfs_after)

0 comments on commit b8213d6

Please sign in to comment.