From 7d0d739292a6ec7b0f88d3f7630e8dd6e2719a16 Mon Sep 17 00:00:00 2001 From: Sam West Date: Thu, 14 Mar 2024 10:12:47 +1100 Subject: [PATCH] faster apply transform tables #2 (#215) Second attempt at speeding up apply_transform_tables. This is just the minimum from #213 needed to apply the late explode method. Overall Ireland benchmark speedup (local machine): Explode=False: Ran Ireland in 10.94s. 90.0% (36629 correct, 3415 additional). Explode=True: Ran Ireland in 14.51s. 90.0% (36614 correct, 3415 additional). --- pyproject.toml | 9 ++- tests/test_transforms.py | 22 +++++- xl2times/__main__.py | 1 + xl2times/transforms.py | 159 ++++++++++++++++++++++++++++----------- xl2times/utils.py | 116 +++++++++++++++++++++++++++- 5 files changed, 254 insertions(+), 53 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5c3e575..3fb1011 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,8 @@ addopts = '-s --durations=0 --durations-min=5.0 --tb=native' # See https://github.com/nat-n/poethepoet for details. benchmark = { cmd = "python utils/run_benchmarks.py benchmarks.yml --run", help = "Run a single benchmark. Usage: poe benchmark " } benchmark_all = { shell = "python utils/run_benchmarks.py benchmarks.yml --verbose | tee out.txt", help = "Run the project", interpreter = "posix" } -lint = { shell = "git add .pre-commit-config.yaml & pre-commit run", help = "Run pre-commit hooks", interpreter = "posix" } +lint = { shell = "git add .pre-commit-config.yaml; pre-commit run", help = "Run pre-commit hooks on staged files", interpreter = "posix" } +lint-all = { shell = "git add .pre-commit-config.yaml; pre-commit run --all-files", help = "Run pre-commit hooks on all files", interpreter = "posix" } test = { cmd = "pytest --cov-report term --cov-report html --cov=xl2times --cov=utils", help = "Run unit tests with pytest" } @@ -93,6 +94,6 @@ lint.ignore = [ "E501", # line too long, handled by black ] -# Ruff rule-specific options: -[tool.ruff.mccabe] -max-complexity = 12 # increase max function 'complexity' +[tool.ruff.lint.mccabe] +# Flag errors (`C901`) whenever the complexity level exceeds 5. +max-complexity = 12 diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 63b8cc1..16fc0bb 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,6 +1,7 @@ from datetime import datetime import pandas as pd +from pandas import DataFrame from xl2times import datatypes, transforms, utils from xl2times.transforms import ( @@ -69,6 +70,24 @@ def make_str(df): class TestTransforms: + def test_explode_process_commodity_cols(self): + df = DataFrame( + { + "process": ["a", "b", ["c", "d"]], + "commodity": [["v", "w", "x"], "y", "z"], + } + ) + df2 = transforms.explode_process_commodity_cols( + None, {"name": df.copy()}, None # pyright: ignore + ) + correct = DataFrame( + { + "process": ["a", "a", "a", "b", "c", "d"], + "commodity": ["v", "w", "x", "y", "z", "z"], + } + ) + assert df2["name"].equals(correct) + def test_uc_wildcards(self): """ Tests logic that matches wildcards in the process_uc_wildcards transform . 
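A minimal, standalone sketch of the "late explode" idea described in the commit message (column values here are illustrative, not taken from the benchmark models): wildcard matches are kept as list-valued cells while the transforms run, and only expanded into one row per match at the end via DataFrame.explode:

    import pandas as pd

    df = pd.DataFrame(
        {
            "process": ["a", ["b", "c"]],      # a cell may hold a list of matched names
            "commodity": [["x", "y"], "z"],
            "value": [1.0, 2.0],
        }
    )
    # Explode once per column at the end to get the long-format table.
    for col in ("process", "commodity"):
        df = df.explode(col, ignore_index=True)
    # -> 4 rows: (a, x), (a, y), (b, z), (c, z)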
@@ -119,9 +138,6 @@ def test_uc_wildcards(self): # consistency checks with old method assert len(set(df_new.columns).symmetric_difference(set(df_old.columns))) == 0 - assert df_new.fillna(-1).equals( - df_old.fillna(-1) - ), "Dataframes should be equal (ignoring Nones and NaNs)" def test_generate_commodity_groups(self): """ diff --git a/xl2times/__main__.py b/xl2times/__main__.py index fcd3326..b64a190 100644 --- a/xl2times/__main__.py +++ b/xl2times/__main__.py @@ -122,6 +122,7 @@ def convert_xl_to_times( transforms.complete_commodity_groups, transforms.process_wildcards, transforms.apply_transform_tables, + transforms.explode_process_commodity_cols, transforms.apply_final_fixup, transforms.convert_aliases, transforms.assign_model_attributes, diff --git a/xl2times/transforms.py b/xl2times/transforms.py index a51e121..c62eef4 100644 --- a/xl2times/transforms.py +++ b/xl2times/transforms.py @@ -1817,6 +1817,7 @@ def is_year(col_name): ]: # ~TFM_INS-TS: Gather columns whose names are years into a single "Year" column: df = table.dataframe + if "year" in df.columns: raise ValueError(f"TFM_INS-TS table already has Year column: {table}") @@ -2129,11 +2130,21 @@ def process_wildcards( if set(df.columns).intersection(set(process_map.keys())): df = _match_wildcards( - df, process_map, dictionary, get_matching_processes, "process" + df, + process_map, + dictionary, + get_matching_processes, + "process", + explode=False, ) if set(df.columns).intersection(set(commodity_map.keys())): df = _match_wildcards( - df, commodity_map, dictionary, get_matching_commodities, "commodity" + df, + commodity_map, + dictionary, + get_matching_commodities, + "commodity", + explode=False, ) tables[tag] = df @@ -2154,6 +2165,7 @@ def _match_wildcards( dictionary: dict[str, pd.DataFrame], matcher: Callable, result_col: str, + explode: bool = False, ) -> pd.DataFrame: """ Match wildcards in the given table using the given process map and dictionary. @@ -2164,6 +2176,7 @@ def _match_wildcards( dictionary: Dictionary of process sets to match against. matcher: Matching function to use, e.g. get_matching_processes or get_matching_commodities. result_col: Name of the column to store the matched results in. + explode: Whether to explode the results_col ('process'/'commodities') column into a long-format table. Returns: The table with the wildcard columns removed and the results of the wildcard matches added as a column named `results_col` @@ -2196,25 +2209,86 @@ def _match_wildcards( # Finally we merge the matches back into the original table. # This join re-duplicates the duplicate filters dropped above for speed. df = ( - df.merge(filter_matches, on=wild_cols, how="left") + df.merge(filter_matches, on=wild_cols, how="left", suffixes=("_old", "")) .reset_index(drop=True) .drop(columns=wild_cols) ) + # TODO TFM_UPD has existing (but empty) 'process' and 'commodity' columns here. Is it ok to drop existing columns here? + if f"{result_col}_old" in df.columns: + if not df[f"{result_col}_old"].isna().all(): + logger.warning( + f"Non-empty existing '{result_col}' column will be overwritten!" + ) + df = df.drop(columns=[f"{result_col}_old"]) + # And we explode any matches to multiple names to give a long-format table. 
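    # (Illustrative only: a row whose matched 'process' cell is ['P1', 'P2'] becomes two
    # identical rows, one per process name; with explode=False the list is kept as-is and
    # expanded later in explode_process_commodity_cols.)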
- if result_col in df.columns: - df = df.explode(result_col, ignore_index=True) - else: - df[result_col] = None + if explode: + if result_col in df.columns: + df = df.explode(result_col, ignore_index=True) + else: + df[result_col] = None # replace NaNs in results_col with None (expected downstream) if df[result_col].dtype != object: df[result_col] = df[result_col].astype(object) + + # replace NaNs in results_col with None (expected downstream) df.loc[df[result_col].isna(), [result_col]] = None return df +def query( + table: DataFrame, + process: str | list | None, + commodity: str | list | None, + attribute: str | None, + region: str | None, + year: int | None, +) -> pd.Index: + qs = [] + + # special handling for commodity and process, which can be lists or arbitrary scalars + missing_commodity = ( + commodity is None or pd.isna(commodity) + if not isinstance(commodity, list) + else pd.isna(commodity).all() + ) + missing_process = ( + process is None or pd.isna(process) + if not isinstance(process, list) + else pd.isna(process).all() + ) + + if not missing_process: + qs.append(f"process in {process if isinstance(process, list) else [process]}") + if not missing_commodity: + qs.append( + f"commodity in {commodity if isinstance(commodity, list) else [commodity]}" + ) + if attribute is not None: + qs.append(f"attribute == '{attribute}'") + if region is not None: + qs.append(f"region == '{region}'") + if year is not None: + qs.append(f"year == {year}") + query_str = " and ".join(qs) + row_idx = table.query(query_str).index + return row_idx + + +def eval_and_update(table: DataFrame, rows_to_update: pd.Index, new_value: str) -> None: + """Performs an inplace update of rows `rows_to_update` of `table` with `new_value`, + which can be a update formula like `*2.3`.""" + if isinstance(new_value, str) and new_value[0] in {"*", "+", "-", "/"}: + old_values = table.loc[rows_to_update, "value"] + updated = old_values.astype(float).map(lambda x: eval("x" + new_value)) + table.loc[rows_to_update, "value"] = updated + else: + table.loc[rows_to_update, "value"] = new_value + + def apply_transform_tables( config: datatypes.Config, tables: dict[str, DataFrame], @@ -2224,39 +2298,6 @@ def apply_transform_tables( Include data from transformation tables. 
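    A minimal illustration of the module-level helpers `query` and `eval_and_update`
    used below (hypothetical data; assumes `pd` and both helpers are in scope):

    >>> df = pd.DataFrame({"process": ["P1", "P2"], "attribute": ["ACT_COST"] * 2,
    ...                    "region": ["IE"] * 2, "year": [2020] * 2, "value": [1.0, 3.0]})
    >>> rows = query(df, "P1", None, "ACT_COST", "IE", 2020)
    >>> eval_and_update(df, rows, "*2")
    >>> df["value"].tolist()
    [2.0, 3.0]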
""" - def query( - table: DataFrame, - process: str | None, - commodity: str | None, - attribute: str | None, - region: str | None, - year: int | None, - ) -> pd.Index: - qs = [] - if process is not None: - qs.append(f"process in ['{process}']") - if commodity is not None: - qs.append(f"commodity in ['{commodity}']") - if attribute is not None: - qs.append(f"attribute == '{attribute}'") - if region is not None: - qs.append(f"region == '{region}'") - if year is not None: - qs.append(f"year == {year}") - return table.query(" and ".join(qs)).index - - def eval_and_update( - table: DataFrame, rows_to_update: pd.Index, new_value: str - ) -> None: - """Performs an inplace update of rows `rows_to_update` of `table` with `new_value`, - which can be a update formula like `*2.3`.""" - if isinstance(new_value, str) and new_value[0] in {"*", "+", "-", "/"}: - old_values = table.loc[rows_to_update, "value"] - updated = old_values.astype(float).map(lambda x: eval("x" + new_value)) - table.loc[rows_to_update, "value"] = updated - else: - table.loc[rows_to_update, "value"] = new_value - if datatypes.Tag.tfm_upd in tables: updates = tables[datatypes.Tag.tfm_upd] table = tables[datatypes.Tag.fi_t] @@ -2329,7 +2370,12 @@ def eval_and_update( # Overwrite (inplace) the column given by the attribute (translated by attr_prop) # with the value from row # E.g. if row['attribute'] == 'PRC_TSL' then we overwrite 'tslvl' - table.loc[rows_to_update, attr_prop[row["attribute"]]] = row["value"] + if row["attribute"] not in attr_prop: + logger.warning( + f"Unknown attribute {row['attribute']}, skipping update." + ) + else: + table.loc[rows_to_update, attr_prop[row["attribute"]]] = row["value"] if datatypes.Tag.tfm_mig in tables: updates = tables[datatypes.Tag.tfm_mig] @@ -2375,15 +2421,40 @@ def eval_and_update( table = model.commodity_groups updates = tables[datatypes.Tag.tfm_comgrp].filter(table.columns, axis=1) - commodity_groups = pd.concat( - [table, updates], ignore_index=True - ).drop_duplicates() + commodity_groups = pd.concat([table, updates], ignore_index=True) + commodity_groups = commodity_groups.explode("commodity", ignore_index=True) + commodity_groups = commodity_groups.drop_duplicates() commodity_groups.loc[commodity_groups["gmap"].isna(), ["gmap"]] = True model.commodity_groups = commodity_groups.dropna() return tables +def explode_process_commodity_cols( + config: datatypes.Config, + tables: dict[str, DataFrame], + model: datatypes.TimesModel, +) -> dict[str, DataFrame]: + """ + Explodes the process and commodity columns in the tables that contain them as lists after process_wildcards. + We store wildcard matches for these columns as lists and explode them late here for performance reasons - to avoid row-wise processing that + would otherwise need to iterate over very long tables. 
+ """ + + for tag in tables: + df = tables[tag] + + if "process" in df.columns: + df = df.explode("process", ignore_index=True) + + if "commodity" in df.columns: + df = df.explode("commodity", ignore_index=True) + + tables[tag] = df + + return tables + + def process_time_slices( config: datatypes.Config, tables: list[datatypes.EmbeddedXlTable], diff --git a/xl2times/utils.py b/xl2times/utils.py index 9f2a825..77b310a 100644 --- a/xl2times/utils.py +++ b/xl2times/utils.py @@ -4,7 +4,9 @@ # see https://loguru.readthedocs.io/en/stable/api/type_hints.html#module-autodoc_stub_file.loguru import functools +import gzip import os +import pickle import re import sys from collections.abc import Iterable @@ -15,6 +17,7 @@ import loguru import numpy import pandas as pd +from loguru import logger from more_itertools import one from pandas.core.frame import DataFrame @@ -189,14 +192,20 @@ def get_scalar(table_tag: str, tables: list[datatypes.EmbeddedXlTable]): def has_negative_patterns(pattern): + if len(pattern) == 0: + return False return pattern[0] == "-" or ",-" in pattern def remove_negative_patterns(pattern): + if len(pattern) == 0: + return pattern return ",".join([word for word in pattern.split(",") if word[0] != "-"]) def remove_positive_patterns(pattern): + if len(pattern) == 0: + return pattern return ",".join([word[1:] for word in pattern.split(",") if word[0] == "-"]) @@ -260,7 +269,7 @@ def get_logger(log_name: str = default_log_name, log_dir: str = ".") -> loguru.L "handlers": [ { "sink": sys.stdout, - "diagnose": False, + "diagnose": True, "format": "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} : {message} ({name}:{" 'thread.name}:pid-{process} "{' 'file.path}:{line}")', @@ -272,7 +281,7 @@ def get_logger(log_name: str = default_log_name, log_dir: str = ".") -> loguru.L "level": "DEBUG", "colorize": False, "serialize": False, - "diagnose": True, + "diagnose": False, "rotation": "20 MB", "compression": "zip", }, @@ -280,3 +289,106 @@ def get_logger(log_name: str = default_log_name, log_dir: str = ".") -> loguru.L } logger.configure(**log_conf) return logger + + +def save_state( + config: datatypes.Config, + tables: dict[str, DataFrame], + model: datatypes.TimesModel, + filename: str, +) -> None: + """Saves the state from a transform step to a single pickle file. + Useful for troubleshooting regressions by diffing with state from another branch. + """ + pickle.dump({"tables": tables, "model": model}, gzip.open(filename, "wb")) + logger.debug(f"State saved to {filename}") + + +def compare_df_dict( + df_before: dict[str, DataFrame], + df_after: dict[str, DataFrame], + sort_cols: bool = True, + context_rows: int = 2, +) -> None: + """ + Simple function to compare two dictionaries of DataFrames. + + Args: + df_before: the first dictionary of DataFrames to compare + df_after: the second dictionary of DataFrames to compare + sort_cols: whether to sort the columns before comparing. Set True if the column order is unimportant. 
+ context_rows: number of rows to show around the first difference + """ + + for key in df_before: + + before = df_before[key] + after = df_after[key] + + if sort_cols: + before = before.sort_index(axis="columns") + after = after.sort_index(axis="columns") + + if not before.equals(after): + + # print first line that is different, and its surrounding lines + for i in range(len(before)): + if not before.columns.equals(after.columns): + logger.warning( + f"Table {key} has different columns (or column order):\n" + f"BEFORE: {before.columns}\n" + f"AFTER: {after.columns}" + ) + break + if not before.iloc[i].equals(after.iloc[i]): + logger.warning( + f"Table {key} is different, first difference at row {i}:\n" + f"BEFORE:\n{before.iloc[i - context_rows:i + context_rows + 1]}\n" + f"AFTER: \n{after.iloc[i - context_rows:i + context_rows + 1]}" + ) + break + else: + logger.success(f"Table {key} is the same") + + +def diff_state( + filename_before: str, filename_after: str, sort_cols: bool = False +) -> None: + """ + Diffs dataframes from two persisted state files created with save_state(). + + Typical usage: + - Save the state from a branch with a regression at some point in the transforms: + - Switch to `main` branch and save the state from the same point: + - Diff the two states: + + For example: + >>> from utils import save_state, diff_state + >>> save_state(config, tables, model, "branch.pkl.gz") + >>> save_state(config, tables, model, "main.pkl.gz") + >>> diff_state("branch.pkl.gz", "main.pkl.gz") + + TODO also compare config and non-dataframe model attributes? + """ + before = pickle.load(gzip.open(filename_before, "rb")) + after = pickle.load(gzip.open(filename_after, "rb")) + + # Compare DFs in the tables dict + logger.info("Comparing `table` dataframes...") + compare_df_dict(before["tables"], after["tables"], sort_cols=sort_cols) + + # Compare DFs on the model object + model_before = before["model"] + model_after = after["model"] + dfs_before = { + a: getattr(model_before, a) + for a in dir(model_before) + if isinstance(getattr(model_before, a), pd.DataFrame) + } + dfs_after = { + a: getattr(model_after, a) + for a in dir(model_after) + if isinstance(getattr(model_after, a), pd.DataFrame) + } + logger.info("Comparing `model` dataframes...") + compare_df_dict(dfs_before, dfs_after, sort_cols=sort_cols)
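A minimal sketch of how save_state/diff_state might be used to track down a regression
(the save point and file names are illustrative):

    from xl2times import utils

    # On each branch, dump the intermediate state at the same point in the transform
    # pipeline (e.g. right after process_wildcards):
    utils.save_state(config, tables, model, "branch.pkl.gz")  # on the feature branch
    utils.save_state(config, tables, model, "main.pkl.gz")    # same point on main

    # Then diff the two snapshots; tables that differ are logged along with the rows
    # around the first difference.
    utils.diff_state("branch.pkl.gz", "main.pkl.gz", sort_cols=True)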