faster apply transform tables #2 #215

Merged 6 commits on Mar 13, 2024

9 changes: 5 additions & 4 deletions pyproject.toml
@@ -60,7 +60,8 @@ addopts = '-s --durations=0 --durations-min=5.0 --tb=native'
# See https://github.com/nat-n/poethepoet for details.
benchmark = { cmd = "python utils/run_benchmarks.py benchmarks.yml --run", help = "Run a single benchmark. Usage: poe benchmark <benchmark_name>" }
benchmark_all = { shell = "python utils/run_benchmarks.py benchmarks.yml --verbose | tee out.txt", help = "Run the project", interpreter = "posix" }
lint = { shell = "git add .pre-commit-config.yaml & pre-commit run", help = "Run pre-commit hooks", interpreter = "posix" }
lint = { shell = "git add .pre-commit-config.yaml; pre-commit run", help = "Run pre-commit hooks on staged files", interpreter = "posix" }
lint-all = { shell = "git add .pre-commit-config.yaml; pre-commit run --all-files", help = "Run pre-commit hooks on all files", interpreter = "posix" }
test = { cmd = "pytest --cov-report term --cov-report html --cov=xl2times --cov=utils", help = "Run unit tests with pytest" }


@@ -93,6 +94,6 @@ lint.ignore = [
"E501", # line too long, handled by black
]

# Ruff rule-specific options:
[tool.ruff.mccabe]
max-complexity = 12 # increase max function 'complexity'
[tool.ruff.lint.mccabe]
# Flag errors (`C901`) whenever the complexity level exceeds 12.
max-complexity = 12
22 changes: 19 additions & 3 deletions tests/test_transforms.py
@@ -1,6 +1,7 @@
from datetime import datetime

import pandas as pd
from pandas import DataFrame

from xl2times import datatypes, transforms, utils
from xl2times.transforms import (
@@ -69,6 +70,24 @@ def make_str(df):


class TestTransforms:
def test_explode_process_commodity_cols(self):
df = DataFrame(
{
"process": ["a", "b", ["c", "d"]],
"commodity": [["v", "w", "x"], "y", "z"],
}
)
df2 = transforms.explode_process_commodity_cols(
None, {"name": df.copy()}, None # pyright: ignore
)
correct = DataFrame(
{
"process": ["a", "a", "a", "b", "c", "d"],
"commodity": ["v", "w", "x", "y", "z", "z"],
}
)
assert df2["name"].equals(correct)

def test_uc_wildcards(self):
"""
Tests logic that matches wildcards in the process_uc_wildcards transform.
@@ -119,9 +138,6 @@ def test_uc_wildcards(self):

# consistency checks with old method
assert len(set(df_new.columns).symmetric_difference(set(df_old.columns))) == 0
assert df_new.fillna(-1).equals(
df_old.fillna(-1)
), "Dataframes should be equal (ignoring Nones and NaNs)"

def test_generate_commodity_groups(self):
"""
1 change: 1 addition & 0 deletions xl2times/__main__.py
@@ -122,6 +122,7 @@ def convert_xl_to_times(
transforms.complete_commodity_groups,
transforms.process_wildcards,
transforms.apply_transform_tables,
transforms.explode_process_commodity_cols,
transforms.apply_final_fixup,
transforms.convert_aliases,
transforms.assign_model_attributes,
159 changes: 115 additions & 44 deletions xl2times/transforms.py
@@ -1817,6 +1817,7 @@ def is_year(col_name):
]:
# ~TFM_INS-TS: Gather columns whose names are years into a single "Year" column:
df = table.dataframe

if "year" in df.columns:
raise ValueError(f"TFM_INS-TS table already has Year column: {table}")

@@ -2129,11 +2130,21 @@ def process_wildcards(

if set(df.columns).intersection(set(process_map.keys())):
df = _match_wildcards(
df, process_map, dictionary, get_matching_processes, "process"
df,
process_map,
dictionary,
get_matching_processes,
"process",
explode=False,
)
if set(df.columns).intersection(set(commodity_map.keys())):
df = _match_wildcards(
df, commodity_map, dictionary, get_matching_commodities, "commodity"
df,
commodity_map,
dictionary,
get_matching_commodities,
"commodity",
explode=False,
)

tables[tag] = df
@@ -2154,6 +2165,7 @@ def _match_wildcards(
dictionary: dict[str, pd.DataFrame],
matcher: Callable,
result_col: str,
explode: bool = False,
) -> pd.DataFrame:
"""
Match wildcards in the given table using the given process map and dictionary.
Expand All @@ -2164,6 +2176,7 @@ def _match_wildcards(
dictionary: Dictionary of process sets to match against.
matcher: Matching function to use, e.g. get_matching_processes or get_matching_commodities.
result_col: Name of the column to store the matched results in.
explode: Whether to explode the result_col ('process'/'commodity') column into a long-format table.

Returns:
The table with the wildcard columns removed and the results of the wildcard matches added as a column named `result_col`.
@@ -2196,25 +2209,86 @@ def _match_wildcards(
# Finally we merge the matches back into the original table.
# This join re-duplicates the duplicate filters dropped above for speed.
df = (
df.merge(filter_matches, on=wild_cols, how="left")
df.merge(filter_matches, on=wild_cols, how="left", suffixes=("_old", ""))
.reset_index(drop=True)
.drop(columns=wild_cols)
)

# TODO TFM_UPD has existing (but empty) 'process' and 'commodity' columns here. Is it ok to drop existing columns here?
if f"{result_col}_old" in df.columns:
if not df[f"{result_col}_old"].isna().all():
logger.warning(
f"Non-empty existing '{result_col}' column will be overwritten!"
)
df = df.drop(columns=[f"{result_col}_old"])

# And we explode any matches to multiple names to give a long-format table.
if result_col in df.columns:
df = df.explode(result_col, ignore_index=True)
else:
df[result_col] = None
if explode:
if result_col in df.columns:
df = df.explode(result_col, ignore_index=True)
else:
df[result_col] = None

# ensure result_col has object dtype so it can hold None values
if df[result_col].dtype != object:
df[result_col] = df[result_col].astype(object)

# replace NaNs in result_col with None (expected downstream)
df.loc[df[result_col].isna(), [result_col]] = None

return df
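
For context, a minimal sketch of what the deferred explode amounts to in plain pandas (the DataFrame contents here are hypothetical). With explode=False the matched names stay as list-valued cells; a later DataFrame.explode turns each list element into its own row, giving the long-format table downstream code expects.

import pandas as pd

# Matched wildcards kept as lists in the "process" column (explode=False):
df = pd.DataFrame({"process": [["P1", "P2"], ["P3"]], "value": [1.0, 2.0]})

# Exploding later produces one row per matched name:
long_df = df.explode("process", ignore_index=True)
# rows: ("P1", 1.0), ("P2", 1.0), ("P3", 2.0)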


def query(
table: DataFrame,
process: str | list | None,
commodity: str | list | None,
attribute: str | None,
region: str | None,
year: int | None,
) -> pd.Index:
qs = []

# special handling for commodity and process, which can be lists or arbitrary scalars
missing_commodity = (
commodity is None or pd.isna(commodity)
if not isinstance(commodity, list)
else pd.isna(commodity).all()
)
missing_process = (
process is None or pd.isna(process)
if not isinstance(process, list)
else pd.isna(process).all()
)

if not missing_process:
qs.append(f"process in {process if isinstance(process, list) else [process]}")
if not missing_commodity:
qs.append(
f"commodity in {commodity if isinstance(commodity, list) else [commodity]}"
)
if attribute is not None:
qs.append(f"attribute == '{attribute}'")
if region is not None:
qs.append(f"region == '{region}'")
if year is not None:
qs.append(f"year == {year}")
query_str = " and ".join(qs)
row_idx = table.query(query_str).index
return row_idx
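
A brief illustration of how `query` composes a pandas query string (the DataFrame and filter values below are hypothetical). Missing process/commodity filters are simply omitted, so only the supplied criteria constrain the result.

import pandas as pd

fi_t = pd.DataFrame(
    {
        "process": ["P1", "P2", "P1"],
        "commodity": ["C1", "C1", "C2"],
        "attribute": ["COST", "COST", "EFF"],
        "region": ["REG1", "REG1", "REG1"],
        "year": [2020, 2020, 2030],
        "value": [1.0, 2.0, 3.0],
    }
)

# Uses the `query` function defined above; commodity and year are missing, so the
# query string becomes "process in ['P1'] and attribute == 'COST' and region == 'REG1'".
rows = query(fi_t, "P1", None, "COST", "REG1", None)  # -> Index([0])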


def eval_and_update(table: DataFrame, rows_to_update: pd.Index, new_value: str) -> None:
"""Performs an inplace update of rows `rows_to_update` of `table` with `new_value`,
which can be an update formula like `*2.3`."""
if isinstance(new_value, str) and new_value[0] in {"*", "+", "-", "/"}:
old_values = table.loc[rows_to_update, "value"]
updated = old_values.astype(float).map(lambda x: eval("x" + new_value))
table.loc[rows_to_update, "value"] = updated
else:
table.loc[rows_to_update, "value"] = new_value
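
To illustrate the two update modes (table contents hypothetical): a new_value that starts with an arithmetic operator is applied to the existing values, while anything else simply overwrites the selected rows.

import pandas as pd

table = pd.DataFrame({"attribute": ["COST", "COST"], "value": ["10", "20"]})

# Formula: values are coerced to float and scaled, giving 23.0 and 46.0.
eval_and_update(table, pd.Index([0, 1]), "*2.3")

# Plain value: row 0's value is simply overwritten with "100".
eval_and_update(table, pd.Index([0]), "100")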


def apply_transform_tables(
config: datatypes.Config,
tables: dict[str, DataFrame],
@@ -2224,39 +2298,6 @@ def apply_transform_tables(
Include data from transformation tables.
"""

def query(
table: DataFrame,
process: str | None,
commodity: str | None,
attribute: str | None,
region: str | None,
year: int | None,
) -> pd.Index:
qs = []
if process is not None:
qs.append(f"process in ['{process}']")
if commodity is not None:
qs.append(f"commodity in ['{commodity}']")
if attribute is not None:
qs.append(f"attribute == '{attribute}'")
if region is not None:
qs.append(f"region == '{region}'")
if year is not None:
qs.append(f"year == {year}")
return table.query(" and ".join(qs)).index

def eval_and_update(
table: DataFrame, rows_to_update: pd.Index, new_value: str
) -> None:
"""Performs an inplace update of rows `rows_to_update` of `table` with `new_value`,
which can be a update formula like `*2.3`."""
if isinstance(new_value, str) and new_value[0] in {"*", "+", "-", "/"}:
old_values = table.loc[rows_to_update, "value"]
updated = old_values.astype(float).map(lambda x: eval("x" + new_value))
table.loc[rows_to_update, "value"] = updated
else:
table.loc[rows_to_update, "value"] = new_value

if datatypes.Tag.tfm_upd in tables:
updates = tables[datatypes.Tag.tfm_upd]
table = tables[datatypes.Tag.fi_t]
@@ -2329,7 +2370,12 @@ def eval_and_update(
# Overwrite (inplace) the column given by the attribute (translated by attr_prop)
# with the value from row
# E.g. if row['attribute'] == 'PRC_TSL' then we overwrite 'tslvl'
table.loc[rows_to_update, attr_prop[row["attribute"]]] = row["value"]
if row["attribute"] not in attr_prop:
logger.warning(
f"Unknown attribute {row['attribute']}, skipping update."
)
else:
table.loc[rows_to_update, attr_prop[row["attribute"]]] = row["value"]

if datatypes.Tag.tfm_mig in tables:
updates = tables[datatypes.Tag.tfm_mig]
@@ -2375,15 +2421,40 @@ def eval_and_update(
table = model.commodity_groups
updates = tables[datatypes.Tag.tfm_comgrp].filter(table.columns, axis=1)

commodity_groups = pd.concat(
[table, updates], ignore_index=True
).drop_duplicates()
commodity_groups = pd.concat([table, updates], ignore_index=True)
commodity_groups = commodity_groups.explode("commodity", ignore_index=True)
commodity_groups = commodity_groups.drop_duplicates()
commodity_groups.loc[commodity_groups["gmap"].isna(), ["gmap"]] = True
model.commodity_groups = commodity_groups.dropna()

return tables


def explode_process_commodity_cols(
config: datatypes.Config,
tables: dict[str, DataFrame],
model: datatypes.TimesModel,
) -> dict[str, DataFrame]:
"""
Explodes the process and commodity columns in tables that contain them as lists after process_wildcards.
Wildcard matches for these columns are stored as lists and only exploded here, late in the pipeline, for performance:
exploding early would produce very long tables that later row-wise processing would have to iterate over.
"""

for tag in tables:
df = tables[tag]

if "process" in df.columns:
df = df.explode("process", ignore_index=True)

if "commodity" in df.columns:
df = df.explode("commodity", ignore_index=True)

tables[tag] = df

return tables
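
Roughly how this transform behaves on a table dict (contents hypothetical; as in the unit test above, the unused config and model arguments can be None). List-valued cells become one row per element, while scalar cells pass through unchanged.

import pandas as pd

tables = {
    "fi_t": pd.DataFrame(
        {"process": [["P1", "P2"]], "commodity": ["C1"], "value": [1.0]}
    )
}
out = explode_process_commodity_cols(None, tables, None)
# out["fi_t"] has two rows: ("P1", "C1", 1.0) and ("P2", "C1", 1.0)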


def process_time_slices(
config: datatypes.Config,
tables: list[datatypes.EmbeddedXlTable],