faster apply transform tables #2 #215

Merged · 6 commits · Mar 13, 2024
Changes from 5 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -24,3 +24,4 @@ docs/api/
*.log
/profile.*
xl2times/.cache/
*.log.zip
22 changes: 19 additions & 3 deletions tests/test_transforms.py
@@ -1,6 +1,7 @@
from datetime import datetime

import pandas as pd
from pandas import DataFrame

from xl2times import transforms, utils, datatypes
from xl2times.transforms import (
@@ -69,6 +70,24 @@ def make_str(df):


class TestTransforms:
def test_explode_process_commodity_cols(self):
df = DataFrame(
{
"process": ["a", "b", ["c", "d"]],
"commodity": [["v", "w", "x"], "y", "z"],
}
)
df2 = transforms.explode_process_commodity_cols(
None, {"name": df.copy()}, None # pyright: ignore
)
correct = DataFrame(
{
"process": ["a", "a", "a", "b", "c", "d"],
"commodity": ["v", "w", "x", "y", "z", "z"],
}
)
assert df2["name"].equals(correct)

def test_uc_wildcards(self):
"""
Tests logic that matches wildcards in the process_uc_wildcards transform.
@@ -119,9 +138,6 @@ def test_uc_wildcards(self):

# consistency checks with old method
assert len(set(df_new.columns).symmetric_difference(set(df_old.columns))) == 0
assert df_new.fillna(-1).equals(
df_old.fillna(-1)
), "Dataframes should be equal (ignoring Nones and NaNs)"

def test_generate_commodity_groups(self):
"""
1 change: 1 addition & 0 deletions xl2times/__main__.py
@@ -123,6 +123,7 @@ def convert_xl_to_times(
transforms.complete_commodity_groups,
transforms.process_wildcards,
transforms.apply_transform_tables,
transforms.explode_process_commodity_cols,
transforms.apply_final_fixup,
transforms.convert_aliases,
transforms.assign_model_attributes,
162 changes: 115 additions & 47 deletions xl2times/transforms.py
@@ -1819,7 +1819,7 @@ def is_year(col_name):
]:
# ~TFM_INS-TS: Gather columns whose names are years into a single "Year" column:
df = table.dataframe
query_columns = config.query_columns[datatypes.Tag(table.tag)]

if "year" in df.columns:
raise ValueError(f"TFM_INS-TS table already has Year column: {table}")

@@ -2132,11 +2132,21 @@ def process_wildcards(

if set(df.columns).intersection(set(process_map.keys())):
df = _match_wildcards(
df, process_map, dictionary, get_matching_processes, "process"
df,
process_map,
dictionary,
get_matching_processes,
"process",
explode=False,
)
if set(df.columns).intersection(set(commodity_map.keys())):
df = _match_wildcards(
df, commodity_map, dictionary, get_matching_commodities, "commodity"
df,
commodity_map,
dictionary,
get_matching_commodities,
"commodity",
explode=False,
)

tables[tag] = df
@@ -2157,6 +2167,7 @@ def _match_wildcards(
dictionary: dict[str, pd.DataFrame],
matcher: Callable,
result_col: str,
explode: bool = False,
) -> pd.DataFrame:
"""
Match wildcards in the given table using the given process map and dictionary.
Expand All @@ -2167,6 +2178,7 @@ def _match_wildcards(
dictionary: Dictionary of process sets to match against.
matcher: Matching function to use, e.g. get_matching_processes or get_matching_commodities.
result_col: Name of the column to store the matched results in.
explode: Whether to explode the result_col ('process'/'commodity') column into a long-format table.

Returns:
The table with the wildcard columns removed and the results of the wildcard matches added as a column named `result_col`.
@@ -2199,25 +2211,86 @@ def _match_wildcards(
# Finally we merge the matches back into the original table.
# This join re-duplicates the duplicate filters dropped above for speed.
df = (
df.merge(filter_matches, on=wild_cols, how="left")
df.merge(filter_matches, on=wild_cols, how="left", suffixes=("_old", ""))
.reset_index(drop=True)
.drop(columns=wild_cols)
)

# TODO TFM_UPD has existing (but empty) 'process' and 'commodity' columns here. Is it ok to drop existing columns here?
if f"{result_col}_old" in df.columns:
if not df[f"{result_col}_old"].isna().all():
logger.warning(
f"Non-empty existing '{result_col}' column will be overwritten!"
)
df = df.drop(columns=[f"{result_col}_old"])

# And we explode any matches to multiple names to give a long-format table.
if result_col in df.columns:
df = df.explode(result_col, ignore_index=True)
else:
df[result_col] = None
if explode:
if result_col in df.columns:
df = df.explode(result_col, ignore_index=True)
else:
df[result_col] = None

# ensure result_col has object dtype so it can hold None values
if df[result_col].dtype != object:
df[result_col] = df[result_col].astype(object)

# replace NaNs in result_col with None (expected downstream)
df.loc[df[result_col].isna(), [result_col]] = None

return df
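
A minimal sketch of the dedupe-match-merge idea behind `_match_wildcards`, with a toy matcher standing in for `get_matching_processes`. The `pset_pn` column name follows the real wildcard tables; everything else here is illustrative:

```python
import pandas as pd

# toy universe of process names; the real code matches against topology dictionaries
processes = pd.Series(["coal_plant", "coal_mine", "gas_plant"])

def match(pattern: str) -> list[str]:
    # translate a '*' wildcard into a regex and return all matching names
    return processes[processes.str.fullmatch(pattern.replace("*", ".*"))].tolist()

df = pd.DataFrame({"pset_pn": ["coal*", "gas*", "coal*"], "value": [1, 2, 3]})

# evaluate each unique wildcard once, then merge the matches back;
# matches stay as lists (explode=False) and are exploded later in bulk
unique = df[["pset_pn"]].drop_duplicates()
unique["process"] = unique["pset_pn"].map(match)
out = df.merge(unique, on="pset_pn", how="left").drop(columns="pset_pn")
print(out)
#    value                  process
# 0      1  [coal_plant, coal_mine]
# 1      2              [gas_plant]
# 2      3  [coal_plant, coal_mine]
```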


def query(
table: DataFrame,
process: str | list | None,
commodity: str | list | None,
attribute: str | None,
region: str | None,
year: int | None,
) -> pd.Index:
qs = []

# special handling for commodity and process, which can be lists or arbitrary scalars
missing_commodity = (
commodity is None or pd.isna(commodity)
if not isinstance(commodity, list)
else pd.isna(commodity).all()
)
missing_process = (
process is None or pd.isna(process)
if not isinstance(process, list)
else pd.isna(process).all()
)

if not missing_process:
qs.append(f"process in {process if isinstance(process, list) else [process]}")
if not missing_commodity:
qs.append(
f"commodity in {commodity if isinstance(commodity, list) else [commodity]}"
)
if attribute is not None:
qs.append(f"attribute == '{attribute}'")
if region is not None:
qs.append(f"region == '{region}'")
if year is not None:
qs.append(f"year == {year}")
query_str = " and ".join(qs)
row_idx = table.query(query_str).index
return row_idx
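
A small usage sketch of this helper on made-up data (the import path is assumed; the PR makes `query` a module-level function):

```python
import pandas as pd

from xl2times.transforms import query  # assumed import path

table = pd.DataFrame(
    {
        "process": ["P1", "P2", "P1"],
        "attribute": ["COST", "COST", "EFF"],
        "region": ["REG1", "REG1", "REG2"],
        "value": [1.0, 2.0, 0.9],
    }
)

# scalar process, commodity/year omitted: builds and runs
# "process in ['P1'] and attribute == 'COST' and region == 'REG1'"
rows = query(table, "P1", None, "COST", "REG1", None)
print(list(rows))  # [0]
```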


def eval_and_update(table: DataFrame, rows_to_update: pd.Index, new_value: str) -> None:
"""Performs an inplace update of rows `rows_to_update` of `table` with `new_value`,
which can be a update formula like `*2.3`."""
if isinstance(new_value, str) and new_value[0] in {"*", "+", "-", "/"}:
old_values = table.loc[rows_to_update, "value"]
updated = old_values.astype(float).map(lambda x: eval("x" + new_value))
table.loc[rows_to_update, "value"] = updated
else:
table.loc[rows_to_update, "value"] = new_value
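
A minimal sketch of how `eval_and_update` behaves on toy data:

```python
import pandas as pd

df = pd.DataFrame({"value": [1.0, 2.0, 4.0]})

# apply a TFM_UPD-style relative update to rows 0 and 2
eval_and_update(df, pd.Index([0, 2]), "*2.3")
print(df["value"].tolist())  # [2.3, 2.0, 9.2] (up to float rounding)

# a plain value just overwrites
eval_and_update(df, pd.Index([1]), 5.0)
print(df["value"].tolist())  # [2.3, 5.0, 9.2] (up to float rounding)
```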


def apply_transform_tables(
config: datatypes.Config,
tables: Dict[str, DataFrame],
@@ -2227,41 +2300,6 @@ def apply_transform_tables(
Include data from transformation tables.
"""

topology = generate_topology_dictionary(tables, model)

def query(
table: DataFrame,
process: str | None,
commodity: str | None,
attribute: str | None,
region: str | None,
year: int | None,
) -> pd.Index:
qs = []
if process is not None:
qs.append(f"process in ['{process}']")
if commodity is not None:
qs.append(f"commodity in ['{commodity}']")
if attribute is not None:
qs.append(f"attribute == '{attribute}'")
if region is not None:
qs.append(f"region == '{region}'")
if year is not None:
qs.append(f"year == {year}")
return table.query(" and ".join(qs)).index

def eval_and_update(
table: DataFrame, rows_to_update: pd.Index, new_value: str
) -> None:
"""Performs an inplace update of rows `rows_to_update` of `table` with `new_value`,
which can be a update formula like `*2.3`."""
if isinstance(new_value, str) and new_value[0] in {"*", "+", "-", "/"}:
old_values = table.loc[rows_to_update, "value"]
updated = old_values.astype(float).map(lambda x: eval("x" + new_value))
table.loc[rows_to_update, "value"] = updated
else:
table.loc[rows_to_update, "value"] = new_value

if datatypes.Tag.tfm_upd in tables:
updates = tables[datatypes.Tag.tfm_upd]
table = tables[datatypes.Tag.fi_t]
Expand Down Expand Up @@ -2334,7 +2372,12 @@ def eval_and_update(
# Overwrite (inplace) the column given by the attribute (translated by attr_prop)
# with the value from row
# E.g. if row['attribute'] == 'PRC_TSL' then we overwrite 'tslvl'
table.loc[rows_to_update, attr_prop[row["attribute"]]] = row["value"]
if row["attribute"] not in attr_prop:
logger.warning(
f"Unknown attribute {row['attribute']}, skipping update."
)
else:
table.loc[rows_to_update, attr_prop[row["attribute"]]] = row["value"]

if datatypes.Tag.tfm_mig in tables:
updates = tables[datatypes.Tag.tfm_mig]
@@ -2380,15 +2423,40 @@ def eval_and_update(
table = model.commodity_groups
updates = tables[datatypes.Tag.tfm_comgrp].filter(table.columns, axis=1)

commodity_groups = pd.concat(
[table, updates], ignore_index=True
).drop_duplicates()
commodity_groups = pd.concat([table, updates], ignore_index=True)
commodity_groups = commodity_groups.explode("commodity", ignore_index=True)
commodity_groups = commodity_groups.drop_duplicates()
commodity_groups.loc[commodity_groups["gmap"].isna(), ["gmap"]] = True
model.commodity_groups = commodity_groups.dropna()

return tables


def explode_process_commodity_cols(
config: datatypes.Config,
tables: Dict[str, DataFrame],
model: datatypes.TimesModel,
) -> Dict[str, DataFrame]:
"""
Explodes the process and commodity columns in tables that contain them as lists after process_wildcards.
Wildcard matches are stored as lists and exploded late, here, for performance: this avoids the row-wise
processing that would otherwise have to iterate over very long tables.
"""

for tag in tables:
df = tables[tag]

if "process" in df.columns:
df = df.explode("process", ignore_index=True)

if "commodity" in df.columns:
df = df.explode("commodity", ignore_index=True)

tables[tag] = df

return tables
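
A toy illustration of the late explode (pandas repeats the row for each list element and leaves scalar entries untouched):

```python
import pandas as pd

df = pd.DataFrame(
    {
        "process": [["coal_plant", "coal_mine"], "gas_plant"],
        "value": [1.0, 2.0],
    }
)

# one row per matched process; scalar rows pass through unchanged
print(df.explode("process", ignore_index=True))
#       process  value
# 0  coal_plant    1.0
# 1   coal_mine    1.0
# 2   gas_plant    2.0
```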


def process_time_slices(
config: datatypes.Config,
tables: List[datatypes.EmbeddedXlTable],