faster apply transform tables #2 (#215)
Second attempt at speeding up apply_transform_tables. This is just the
minimum from #213 needed to apply the late explode method.

Overall Ireland benchmark speedup (local machine):
Explode=False: Ran Ireland in 10.94s. 90.0% (36629 correct, 3415
additional).
Explode=True: Ran Ireland in 14.51s. 90.0% (36614 correct, 3415
additional).
SamRWest authored Mar 13, 2024
1 parent 8acd004 commit 7d0d739
Showing 5 changed files with 254 additions and 53 deletions.
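For context, a minimal sketch of the "late explode" idea described in the commit message (illustrative only, not code from this diff): wildcard matches are stored as Python lists inside the "process" and "commodity" columns, and are only expanded into long format in a single vectorised pandas explode pass at the end, rather than row by row during matching.

import pandas as pd

# Hypothetical table after wildcard matching: matches kept as lists per row.
df = pd.DataFrame(
    {
        "process": [["P1", "P2"], "P3"],
        "commodity": ["C1", ["C2", "C3"]],
        "value": [1.0, 2.0],
    }
)

# Late explode: one vectorised pass per column instead of per-row loops.
df = df.explode("process", ignore_index=True)
df = df.explode("commodity", ignore_index=True)
# Result has one row per (process, commodity) combination.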
9 changes: 5 additions & 4 deletions pyproject.toml
@@ -60,7 +60,8 @@ addopts = '-s --durations=0 --durations-min=5.0 --tb=native'
# See https://github.com/nat-n/poethepoet for details.
benchmark = { cmd = "python utils/run_benchmarks.py benchmarks.yml --run", help = "Run a single benchmark. Usage: poe benchmark <benchmark_name>" }
benchmark_all = { shell = "python utils/run_benchmarks.py benchmarks.yml --verbose | tee out.txt", help = "Run the project", interpreter = "posix" }
lint = { shell = "git add .pre-commit-config.yaml & pre-commit run", help = "Run pre-commit hooks", interpreter = "posix" }
lint = { shell = "git add .pre-commit-config.yaml; pre-commit run", help = "Run pre-commit hooks on staged files", interpreter = "posix" }
lint-all = { shell = "git add .pre-commit-config.yaml; pre-commit run --all-files", help = "Run pre-commit hooks on all files", interpreter = "posix" }
test = { cmd = "pytest --cov-report term --cov-report html --cov=xl2times --cov=utils", help = "Run unit tests with pytest" }


@@ -93,6 +94,6 @@ lint.ignore = [
"E501", # line too long, handled by black
]

# Ruff rule-specific options:
[tool.ruff.mccabe]
max-complexity = 12 # increase max function 'complexity'
[tool.ruff.lint.mccabe]
# Flag errors (`C901`) whenever the complexity level exceeds the limit below.
max-complexity = 12
22 changes: 19 additions & 3 deletions tests/test_transforms.py
@@ -1,6 +1,7 @@
from datetime import datetime

import pandas as pd
from pandas import DataFrame

from xl2times import datatypes, transforms, utils
from xl2times.transforms import (
@@ -69,6 +70,24 @@ def make_str(df):


class TestTransforms:
def test_explode_process_commodity_cols(self):
df = DataFrame(
{
"process": ["a", "b", ["c", "d"]],
"commodity": [["v", "w", "x"], "y", "z"],
}
)
df2 = transforms.explode_process_commodity_cols(
None, {"name": df.copy()}, None # pyright: ignore
)
correct = DataFrame(
{
"process": ["a", "a", "a", "b", "c", "d"],
"commodity": ["v", "w", "x", "y", "z", "z"],
}
)
assert df2["name"].equals(correct)

def test_uc_wildcards(self):
"""
Tests logic that matches wildcards in the process_uc_wildcards transform.
@@ -119,9 +138,6 @@ def test_uc_wildcards(self):

# consistency checks with old method
assert len(set(df_new.columns).symmetric_difference(set(df_old.columns))) == 0
assert df_new.fillna(-1).equals(
df_old.fillna(-1)
), "Dataframes should be equal (ignoring Nones and NaNs)"

def test_generate_commodity_groups(self):
"""
1 change: 1 addition & 0 deletions xl2times/__main__.py
@@ -122,6 +122,7 @@ def convert_xl_to_times(
transforms.complete_commodity_groups,
transforms.process_wildcards,
transforms.apply_transform_tables,
transforms.explode_process_commodity_cols,
transforms.apply_final_fixup,
transforms.convert_aliases,
transforms.assign_model_attributes,
Expand Down
159 changes: 115 additions & 44 deletions xl2times/transforms.py
@@ -1817,6 +1817,7 @@ def is_year(col_name):
]:
# ~TFM_INS-TS: Gather columns whose names are years into a single "Year" column:
df = table.dataframe

if "year" in df.columns:
raise ValueError(f"TFM_INS-TS table already has Year column: {table}")

@@ -2129,11 +2130,21 @@ def process_wildcards(

if set(df.columns).intersection(set(process_map.keys())):
df = _match_wildcards(
df, process_map, dictionary, get_matching_processes, "process"
df,
process_map,
dictionary,
get_matching_processes,
"process",
explode=False,
)
if set(df.columns).intersection(set(commodity_map.keys())):
df = _match_wildcards(
df, commodity_map, dictionary, get_matching_commodities, "commodity"
df,
commodity_map,
dictionary,
get_matching_commodities,
"commodity",
explode=False,
)

tables[tag] = df
@@ -2154,6 +2165,7 @@ def _match_wildcards(
dictionary: dict[str, pd.DataFrame],
matcher: Callable,
result_col: str,
explode: bool = False,
) -> pd.DataFrame:
"""
Match wildcards in the given table using the given process map and dictionary.
@@ -2164,6 +2176,7 @@
dictionary: Dictionary of process sets to match against.
matcher: Matching function to use, e.g. get_matching_processes or get_matching_commodities.
result_col: Name of the column to store the matched results in.
explode: Whether to explode the result_col ('process'/'commodity') column into a long-format table.
Returns:
The table with the wildcard columns removed and the results of the wildcard matches added as a column named `result_col`.
@@ -2196,25 +2209,86 @@
# Finally we merge the matches back into the original table.
# This join re-duplicates the duplicate filters dropped above for speed.
df = (
df.merge(filter_matches, on=wild_cols, how="left")
df.merge(filter_matches, on=wild_cols, how="left", suffixes=("_old", ""))
.reset_index(drop=True)
.drop(columns=wild_cols)
)

# TODO TFM_UPD has existing (but empty) 'process' and 'commodity' columns here. Is it ok to drop existing columns here?
if f"{result_col}_old" in df.columns:
if not df[f"{result_col}_old"].isna().all():
logger.warning(
f"Non-empty existing '{result_col}' column will be overwritten!"
)
df = df.drop(columns=[f"{result_col}_old"])

# And we explode any matches to multiple names to give a long-format table.
if result_col in df.columns:
df = df.explode(result_col, ignore_index=True)
else:
df[result_col] = None
if explode:
if result_col in df.columns:
df = df.explode(result_col, ignore_index=True)
else:
df[result_col] = None

# replace NaNs in results_col with None (expected downstream)
if df[result_col].dtype != object:
df[result_col] = df[result_col].astype(object)

# replace NaNs in results_col with None (expected downstream)
df.loc[df[result_col].isna(), [result_col]] = None

return df
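The dedupe-then-merge pattern used above (evaluate each unique wildcard filter once, then merge the matches back so duplicated filters are re-expanded cheaply) looks roughly like this; the pattern strings and matcher below are invented for illustration:

import pandas as pd

def match_pattern(pattern: str) -> list[str]:
    # Stand-in for an expensive matcher such as get_matching_processes.
    return {"P*": ["P01", "P02"], "Q?": ["Q1"]}[pattern]

df = pd.DataFrame({"process_filter": ["P*", "P*", "Q?"], "value": [1, 2, 3]})

# Evaluate each unique filter only once...
unique_filters = df[["process_filter"]].drop_duplicates()
unique_filters["process"] = unique_filters["process_filter"].map(match_pattern)

# ...then merge the matches back onto the full table and drop the filter column.
df = df.merge(unique_filters, on="process_filter", how="left").drop(columns=["process_filter"])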


def query(
table: DataFrame,
process: str | list | None,
commodity: str | list | None,
attribute: str | None,
region: str | None,
year: int | None,
) -> pd.Index:
qs = []

# special handling for commodity and process, which can be lists or arbitrary scalars
missing_commodity = (
commodity is None or pd.isna(commodity)
if not isinstance(commodity, list)
else pd.isna(commodity).all()
)
missing_process = (
process is None or pd.isna(process)
if not isinstance(process, list)
else pd.isna(process).all()
)

if not missing_process:
qs.append(f"process in {process if isinstance(process, list) else [process]}")
if not missing_commodity:
qs.append(
f"commodity in {commodity if isinstance(commodity, list) else [commodity]}"
)
if attribute is not None:
qs.append(f"attribute == '{attribute}'")
if region is not None:
qs.append(f"region == '{region}'")
if year is not None:
qs.append(f"year == {year}")
query_str = " and ".join(qs)
row_idx = table.query(query_str).index
return row_idx
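A usage sketch for the module-level query helper above (the table contents are made up): it assembles a pandas query string from whichever criteria are present and returns the index of the matching rows.

import pandas as pd

table = pd.DataFrame(
    {
        "process": ["P1", "P2", "P1"],
        "commodity": ["C1", "C1", "C2"],
        "attribute": ["COST", "COST", "EFF"],
        "region": ["IE", "IE", "UK"],
        "year": [2020, 2030, 2020],
        "value": [1.0, 2.0, 3.0],
    }
)

# Rows where process is P1, attribute is COST and region is IE; commodity and year are ignored.
rows = query(table, process=["P1"], commodity=None, attribute="COST", region="IE", year=None)
# table.loc[rows] -> the first row only.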


def eval_and_update(table: DataFrame, rows_to_update: pd.Index, new_value: str) -> None:
"""Performs an inplace update of rows `rows_to_update` of `table` with `new_value`,
which can be an update formula like `*2.3`."""
if isinstance(new_value, str) and new_value[0] in {"*", "+", "-", "/"}:
old_values = table.loc[rows_to_update, "value"]
updated = old_values.astype(float).map(lambda x: eval("x" + new_value))
table.loc[rows_to_update, "value"] = updated
else:
table.loc[rows_to_update, "value"] = new_value
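And a usage sketch for eval_and_update (values invented): a new_value starting with an arithmetic operator is applied as a formula to the existing values; anything else simply overwrites them.

import pandas as pd

table = pd.DataFrame({"value": [10.0, 20.0, 30.0]})

# Scale the first two rows in place: value -> 23.0, 46.0.
eval_and_update(table, pd.Index([0, 1]), "*2.3")

# Overwrite the last row with a literal value: value -> 99.0.
eval_and_update(table, pd.Index([2]), 99.0)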


def apply_transform_tables(
config: datatypes.Config,
tables: dict[str, DataFrame],
@@ -2224,39 +2298,6 @@
Include data from transformation tables.
"""

def query(
table: DataFrame,
process: str | None,
commodity: str | None,
attribute: str | None,
region: str | None,
year: int | None,
) -> pd.Index:
qs = []
if process is not None:
qs.append(f"process in ['{process}']")
if commodity is not None:
qs.append(f"commodity in ['{commodity}']")
if attribute is not None:
qs.append(f"attribute == '{attribute}'")
if region is not None:
qs.append(f"region == '{region}'")
if year is not None:
qs.append(f"year == {year}")
return table.query(" and ".join(qs)).index

def eval_and_update(
table: DataFrame, rows_to_update: pd.Index, new_value: str
) -> None:
"""Performs an inplace update of rows `rows_to_update` of `table` with `new_value`,
which can be a update formula like `*2.3`."""
if isinstance(new_value, str) and new_value[0] in {"*", "+", "-", "/"}:
old_values = table.loc[rows_to_update, "value"]
updated = old_values.astype(float).map(lambda x: eval("x" + new_value))
table.loc[rows_to_update, "value"] = updated
else:
table.loc[rows_to_update, "value"] = new_value

if datatypes.Tag.tfm_upd in tables:
updates = tables[datatypes.Tag.tfm_upd]
table = tables[datatypes.Tag.fi_t]
@@ -2329,7 +2370,12 @@ def eval_and_update(
# Overwrite (inplace) the column given by the attribute (translated by attr_prop)
# with the value from row
# E.g. if row['attribute'] == 'PRC_TSL' then we overwrite 'tslvl'
table.loc[rows_to_update, attr_prop[row["attribute"]]] = row["value"]
if row["attribute"] not in attr_prop:
logger.warning(
f"Unknown attribute {row['attribute']}, skipping update."
)
else:
table.loc[rows_to_update, attr_prop[row["attribute"]]] = row["value"]

if datatypes.Tag.tfm_mig in tables:
updates = tables[datatypes.Tag.tfm_mig]
@@ -2375,15 +2421,40 @@ def eval_and_update(
table = model.commodity_groups
updates = tables[datatypes.Tag.tfm_comgrp].filter(table.columns, axis=1)

commodity_groups = pd.concat(
[table, updates], ignore_index=True
).drop_duplicates()
commodity_groups = pd.concat([table, updates], ignore_index=True)
commodity_groups = commodity_groups.explode("commodity", ignore_index=True)
commodity_groups = commodity_groups.drop_duplicates()
commodity_groups.loc[commodity_groups["gmap"].isna(), ["gmap"]] = True
model.commodity_groups = commodity_groups.dropna()

return tables


def explode_process_commodity_cols(
config: datatypes.Config,
tables: dict[str, DataFrame],
model: datatypes.TimesModel,
) -> dict[str, DataFrame]:
"""
Explodes the process and commodity columns in the tables that contain them as lists after process_wildcards.
We store wildcard matches for these columns as lists and explode them late here for performance reasons - to avoid row-wise processing that
would otherwise need to iterate over very long tables.
"""

for tag in tables:
df = tables[tag]

if "process" in df.columns:
df = df.explode("process", ignore_index=True)

if "commodity" in df.columns:
df = df.explode("commodity", ignore_index=True)

tables[tag] = df

return tables


def process_time_slices(
config: datatypes.Config,
tables: list[datatypes.EmbeddedXlTable],
