Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup process_uc_wildcards #193

Merged
merged 24 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
ef30184
Remove rows with duplicate query cols and cleanup handling of TFM var…
SamRWest Feb 16, 2024
b892397
made output parsing regex more robust
SamRWest Feb 16, 2024
0eb5bdd
fixed parse_result
SamRWest Feb 16, 2024
158ce33
Add loguru, poe and poe shortcuts
SamRWest Feb 16, 2024
783cb88
support merging tables (as VEDA appears to) where some columns are op…
SamRWest Feb 16, 2024
df82670
Fixed indentation
SamRWest Feb 18, 2024
6cfe42e
WIP prototype of more efficient uc_wildcards transform
SamRWest Feb 19, 2024
ed82d3d
Working prototype, ~10-20x speedup
SamRWest Feb 20, 2024
e12dfe6
Switched to ireland for unit test data
SamRWest Feb 20, 2024
dc07fae
formatting
SamRWest Feb 20, 2024
d199eeb
cleanup
SamRWest Feb 21, 2024
5cdda5f
extra --debug logic
SamRWest Feb 21, 2024
f5da625
Merge branch 'main' into feature/wildcard_speedup
SamRWest Feb 21, 2024
a4d43fa
post merge fixes
SamRWest Feb 21, 2024
e388130
fix import
SamRWest Feb 21, 2024
33fb41f
Corrected debug logic
SamRWest Feb 21, 2024
fcd11cd
remove shell=True in dd_to_csv run on non-windows OSes
SamRWest Feb 21, 2024
ff268dc
addressed review comments from @olejandro
SamRWest Feb 22, 2024
a2cbeb2
switched to lru_cache
SamRWest Feb 22, 2024
f015daa
logging tweaks
SamRWest Feb 23, 2024
b00e68e
Merged with main
SamRWest Feb 23, 2024
c7a1ab7
Merge branch 'main' into feature/wildcard_speedup
SamRWest Feb 23, 2024
ab76d8d
added extra check in matchers
SamRWest Feb 23, 2024
8dea49f
Merge remote-tracking branch 'origin/feature/wildcard_speedup' into f…
SamRWest Feb 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ docs/api/
.coverage
/out.txt
*.log
/profile.*
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ xl2times = "xl2times.__main__:main"
# don't print runtime warnings
filterwarnings = ["ignore::DeprecationWarning", "ignore::UserWarning", "ignore::FutureWarning"]
# show output, print test coverage report
addopts = '-s --durations=0 --durations-min=5.0 --tb=native --cov-report term --cov-report html --cov=xl2times --cov=utils'
addopts = '-s --durations=0 --durations-min=5.0 --tb=native'

[tool.poe.tasks]
# Automation of common dev tasks etc.
# Run with: `poe <target>`, e.g. `poe lint` or `poe benchmark Ireland`.
# See https://github.com/nat-n/poethepoet for details.
benchmark = { cmd = "python utils/run_benchmarks.py benchmarks.yml --verbose --run", help = "Run a single benchmark. Usage: poe benchmark <benchmark_name>" }
benchmark = { cmd = "python utils/run_benchmarks.py benchmarks.yml --run", help = "Run a single benchmark. Usage: poe benchmark <benchmark_name>" }
benchmark_all = { shell = "python utils/run_benchmarks.py benchmarks.yml --verbose | tee out.txt", help = "Run the project", interpreter = "posix" }
lint = { shell = "git add .pre-commit-config.yaml & pre-commit run", help = "Run pre-commit hooks", interpreter = "posix"}
test = { cmd = "pytest", help = "Run unit tests with pytest" }
lint = { shell = "git add .pre-commit-config.yaml & pre-commit run", help = "Run pre-commit hooks", interpreter = "posix" }
test = { cmd = "pytest --cov-report term --cov-report html --cov=xl2times --cov=utils", help = "Run unit tests with pytest" }
Binary file not shown.
Binary file added tests/data/process_uc_wildcards_ireland_dict.pkl
Binary file not shown.
105 changes: 103 additions & 2 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@

import pandas as pd

from xl2times import transforms
from xl2times import transforms, utils, datatypes
from xl2times.transforms import (
_process_comm_groups_vectorised,
_count_comm_group_vectorised,
expand_rows,
get_matching_commodities,
get_matching_processes,
_match_uc_wildcards,
process_map,
commodity_map,
)

logger = utils.get_logger()

pd.set_option(
"display.max_rows",
20,
Expand All @@ -22,7 +30,99 @@
)


def _match_uc_wildcards_old(
    df: pd.DataFrame, dictionary: dict[str, pd.DataFrame]
) -> pd.DataFrame:
    """Row-by-row reference implementation of the process_uc_wildcards matching logic.

    Kept so that the new vectorised version can be compared against it.
    Note that ``df`` is modified in place: "process" and "commodity" columns
    are assigned on it before the query columns are dropped from a copy.
    TODO remove this function once validated.
    """

    def _join_first_column(matches):
        # Comma-join the unique values of the first column; None when nothing matched.
        if matches is None or len(matches) == 0:
            return None
        return ",".join(matches.iloc[:, 0].unique())

    def _processes_for(row):
        return _join_first_column(get_matching_processes(row, dictionary))

    def _commodities_for(row):
        return _join_first_column(get_matching_commodities(row, dictionary))

    df["process"] = df.apply(_processes_for, axis=1)
    df["commodity"] = df.apply(_commodities_for, axis=1)

    # The wildcard query columns are consumed by the matching above and are
    # dropped before the comma-separated results are expanded.
    query_columns = transforms.process_map.keys() | transforms.commodity_map.keys()
    without_queries = df.drop(
        columns=[name for name in df.columns if name in query_columns]
    )

    table = datatypes.EmbeddedXlTable(
        tag="",
        uc_sets={},
        sheetname="",
        range="",
        filename="",
        dataframe=without_queries,
    )
    return expand_rows(query_columns, table).dataframe


class TestTransforms:
def test_uc_wildcards(self):
    """
    Checks the wildcard matching used by the process_uc_wildcards transform.

    Runs the new matcher and the old row-wise matcher on the same stored
    Ireland inputs, logs both timings, and asserts that the outputs agree.

    Results on Ireland model:
    Old method took 0:00:08.42 seconds
    New method took 0:00:00.18 seconds, speedup: 46.5x
    """
    import pickle

    df_in = pd.read_parquet("tests/data/process_uc_wildcards_ireland_data.parquet")
    with open("tests/data/process_uc_wildcards_ireland_dict.pkl", "rb") as f:
        dictionary = pickle.load(f)
    df = df_in.copy()

    started = datetime.now()

    # optimised functions: match processes first, then commodities on that result
    df_new = df
    for query_map, matcher, result_col in (
        (process_map, get_matching_processes, "process"),
        (commodity_map, get_matching_commodities, "commodity"),
    ):
        df_new = _match_uc_wildcards(
            df_new, query_map, dictionary, matcher, result_col
        )

    new_done = datetime.now()

    # Unoptimised function
    df_old = _match_uc_wildcards_old(df, dictionary)

    old_done = datetime.now()

    new_elapsed = new_done - started
    old_elapsed = old_done - new_done
    logger.info(f"Old method took {old_elapsed} seconds")
    logger.info(
        f"New method took {new_elapsed} seconds, speedup: {(old_elapsed / new_elapsed):.1f}x"
    )

    # unit tests
    assert df_new is not None
    assert not df_new.empty
    assert len(df_new) >= len(
        df_in
    ), "should have more rows after processing uc_wildcards"
    assert len(df_new.columns) < len(
        df_in.columns
    ), "should have fewer columns after processing uc_wildcards"
    assert "process" in df_new.columns, "should have added process column"
    assert "commodity" in df_new.columns, "should have added commodity column"

    # consistency checks with old method
    assert set(df_new.columns) == set(df_old.columns)
    assert df_new.fillna(-1).equals(
        df_old.fillna(-1)
    ), "Dataframes should be equal (ignoring Nones and NaNs)"

def test_generate_commodity_groups(self):
"""
Tests that the _count_comm_group_vectorised function works as expected.
Expand Down Expand Up @@ -64,4 +164,5 @@ def test_default_pcg_vectorised(self):


if __name__ == "__main__":
TestTransforms().test_default_pcg_vectorised()
# TestTransforms().test_default_pcg_vectorised()
TestTransforms().test_uc_wildcards()
55 changes: 29 additions & 26 deletions utils/run_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
import git
import pandas as pd
import yaml
from loguru import logger
from tabulate import tabulate

from dd_to_csv import main
from xl2times import utils
from xl2times.__main__ import parse_args, run
from xl2times.utils import max_workers

logger = utils.get_logger()
Expand Down Expand Up @@ -146,7 +147,8 @@ def run_benchmark(
# First convert ground truth DD to csv
if not skip_csv:
shutil.rmtree(csv_folder, ignore_errors=True)
if os.name != "nt":
if not debug:
# run as subprocess if not in --debug mode
res = subprocess.run(
[
"python",
Expand All @@ -157,6 +159,7 @@ def run_benchmark(
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True if os.name == "nt" else False,
)
if res.returncode != 0:
# Remove partial outputs
Expand All @@ -166,9 +169,12 @@ def run_benchmark(
sys.exit(5)
else:
# If debug option is set, run as a function call to allow stepping with a debugger.
from dd_to_csv import main

main([dd_folder, csv_folder])
try:
main([dd_folder, csv_folder])
except Exception:
logger.exception(f"dd_to_csv failed on {benchmark['name']}")
shutil.rmtree(csv_folder, ignore_errors=True)
sys.exit(5)

elif not path.exists(csv_folder):
logger.error(f"--skip_csv is true but {csv_folder} does not exist")
Expand All @@ -189,22 +195,12 @@ def run_benchmark(
else:
args.append(xl_folder)
start = time.time()
res = None
if not debug:
res = subprocess.run(
["xl2times"] + args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
else:
# If debug option is set, run as a function call to allow stepping with a debugger.
from xl2times.__main__ import run, parse_args

summary = run(parse_args(args))
# Call the conversion function directly
summary = run(parse_args(args))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, thinking about it again, perhaps there is one reason to use subprocess -- at least in the CI should we check that xl2times works as expected from the command line? But on the other hand, run(parse_args( is pretty much the same as the CLI invocation, and CI is probably faster without subprocess...

I'm undecided, so would love your thoughts. And we can leave it as is in this PR and discuss in an issue, maybe?


# pack the results into a namedtuple pretending to be a return value from a subprocess call (as above).
res = namedtuple("stdout", ["stdout", "stderr", "returncode"])(summary, "", 0)
# pack the results into a namedtuple pretending to be a return value from a subprocess call (as above).
res = namedtuple("stdout", ["stdout", "stderr", "returncode"])(summary, "", 0)

runtime = time.time() - start

Expand Down Expand Up @@ -255,8 +251,13 @@ def run_all_benchmarks(
debug=debug,
)

with ProcessPoolExecutor(max_workers=max_workers) as executor:
results = list(executor.map(run_a_benchmark, benchmarks))
if debug:
# bypass process pool and call benchmarks directly if --debug is set.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would love to have this documented in the CLI help for --debug, thanks!

results = [run_a_benchmark(b) for b in benchmarks]
else:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
results = list(executor.map(run_a_benchmark, benchmarks))

logger.info("\n\n" + tabulate(results, headers, floatfmt=".1f") + "\n")

if skip_regression:
Expand Down Expand Up @@ -302,9 +303,10 @@ def run_all_benchmarks(
)
sys.exit(8)

# Re-run benchmarks on main
# Re-run benchmarks on main - check it out and pull
repo.heads.main.checkout()
logger.info("Running benchmarks on main", end="", flush=True)
origin.pull("main") # if main already exists, make sure it's up to date
logger.info("Running benchmarks on main")
run_a_benchmark = partial(
run_benchmark,
benchmarks_folder=benchmarks_folder,
Expand Down Expand Up @@ -441,19 +443,20 @@ def run_all_benchmarks(
"--debug",
action="store_true",
default=False,
help="Run each benchmark as a function call to allow a debugger to stop at breakpoints in benchmark runs.",
help="Run each benchmark as a direct function call (disables subprocesses) to allow a debugger to stop at breakpoints "
"in benchmark runs.",
)
args = args_parser.parse_args()

spec = yaml.safe_load(open(args.benchmarks_yaml))
benchmarks_folder = spec["benchmarks_folder"]
benchmark_names = [b["name"] for b in spec["benchmarks"]]
if len(set(benchmark_names)) != len(benchmark_names):
logger.error(f"Found duplicate name in benchmarks YAML file")
logger.error("Found duplicate name in benchmarks YAML file")
sys.exit(11)

if args.dd and args.times_dir is None:
logger.error(f"--times_dir is required when using --dd")
logger.error("--times_dir is required when using --dd")
sys.exit(12)

if args.run is not None:
Expand Down
15 changes: 9 additions & 6 deletions xl2times/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime

from pandas.core.frame import DataFrame
import pandas as pd
import pickle
Expand Down Expand Up @@ -27,9 +29,10 @@ def convert_xl_to_times(
stop_after_read: bool = False,
) -> Dict[str, DataFrame]:
pickle_file = "raw_tables.pkl"
t0 = datetime.now()
if use_pkl and os.path.isfile(pickle_file):
raw_tables = pickle.load(open(pickle_file, "rb"))
logger.warning(f"Using pickled data not xlsx")
logger.warning("Using pickled data not xlsx")
else:
raw_tables = []

Expand All @@ -40,12 +43,12 @@ def convert_xl_to_times(
raw_tables.extend(result)
else:
for f in input_files:
result = excel.extract_tables(f)
result = excel.extract_tables(str(Path(f).absolute()))
raw_tables.extend(result)
pickle.dump(raw_tables, open(pickle_file, "wb"))
logger.info(
f"Extracted {len(raw_tables)} tables,"
f" {sum(table.dataframe.shape[0] for table in raw_tables)} rows"
f" {sum(table.dataframe.shape[0] for table in raw_tables)} rows in {datetime.now() - t0}"
)

if stop_after_read:
Expand Down Expand Up @@ -248,7 +251,7 @@ def produce_times_tables(
result = {}
used_tables = set()
for mapping in config.times_xl_maps:
if not mapping.xl_name in input:
if mapping.xl_name not in input:
logger.warning(
f"Cannot produce table {mapping.times_name} because"
f" {mapping.xl_name} does not exist"
Expand Down Expand Up @@ -281,7 +284,7 @@ def produce_times_tables(
# Excel columns can be duplicated into multiple Times columns
for times_col, xl_col in mapping.col_map.items():
df[times_col] = df[xl_col]
cols_to_drop = [x for x in df.columns if not x in mapping.times_cols]
cols_to_drop = [x for x in df.columns if x not in mapping.times_cols]
df.drop(columns=cols_to_drop, inplace=True)
df.drop_duplicates(inplace=True)
df.reset_index(drop=True, inplace=True)
Expand Down Expand Up @@ -392,7 +395,7 @@ def dump_tables(tables: List, filename: str) -> List:
return tables


def run(args) -> str | None:
def run(args: argparse.Namespace) -> str | None:
"""
Runs the xl2times conversion.
Args:
Expand Down
5 changes: 2 additions & 3 deletions xl2times/excel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Dict, List
import time
from pandas.core.frame import DataFrame
import pandas as pd
import numpy
import re
from . import datatypes
Expand Down Expand Up @@ -43,8 +42,8 @@ def extract_tables(filename: str) -> List[datatypes.EmbeddedXlTable]:
if len(parts) == 2:
uc_sets[parts[0].strip()] = parts[1].strip()
else:
logger.info(
f"WARNING: Malformed UC_SET in {sheet.title}, {filename}"
logger.warning(
f"Malformed UC_SET in {sheet.title}, {filename}"
)
else:
col_index = df.columns.get_loc(colname)
Expand Down
Loading
Loading