diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index c7a3b09..323aefc 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -3,9 +3,9 @@ name: CI
 on:
   # Triggers the workflow on push or pull request events but only for the main branch
   push:
-    branches: [main]
+    branches: [ main ]
   pull_request:
-    branches: [main]
+    branches: [ main ]
 
   # Allows you to run this workflow manually from the Actions tab
   workflow_dispatch:
@@ -34,6 +34,12 @@ jobs:
           pre-commit install
           pre-commit run --all-files
 
+      - name: Run unit tests
+        working-directory: xl2times
+        run: |
+          source .venv/bin/activate
+          pytest
+
       # ---------- Prepare ETSAP Demo models
 
       - uses: actions/checkout@v3
@@ -69,6 +75,9 @@
 
       # ---------- Install GAMS
 
       - name: Install GAMS
+        env:
+          GAMS_LICENSE: ${{ secrets.GAMS_LICENSE }}
+        if: ${{ env.GAMS_LICENSE != '' }}
        run: |
          curl https://d37drm4t2jghv5.cloudfront.net/distributions/44.1.0/linux/linux_x64_64_sfx.exe -o linux_x64_64_sfx.exe
          chmod +x linux_x64_64_sfx.exe
@@ -81,17 +90,18 @@ jobs:
          mkdir -p $HOME/.local/share/GAMS
          echo "$GAMS_LICENSE" > $HOME/.local/share/GAMS/gamslice.txt
          ls -l $HOME/.local/share/GAMS/
-        env:
-          GAMS_LICENSE: ${{ secrets.GAMS_LICENSE }}
+
 
       # ---------- Run tool, check for regressions
 
       - name: Run tool on all benchmarks
+        env:
+          GAMS_LICENSE: ${{ secrets.GAMS_LICENSE }}
+        if: ${{ env.GAMS_LICENSE != '' }}
         working-directory: xl2times
         # Use tee to also save the output to out.txt so that the summary table can be
         # printed again in the next step.
         # Save the return code to retcode.txt so that the next step can fail the action
-        # if run_benchmarks.py failed.
         run: |
           source .venv/bin/activate
           export PATH=$PATH:$GITHUB_WORKSPACE/GAMS/gams44.1_linux_x64_64_sfx
@@ -101,6 +111,22 @@
           | tee out.txt; \
           echo ${PIPESTATUS[0]} > retcode.txt)
 
+      - name: Run CSV-only regression tests (no GAMS license)
+        env:
+          GAMS_LICENSE: ${{ secrets.GAMS_LICENSE }}
+        if: ${{ env.GAMS_LICENSE == '' }}
+        working-directory: xl2times
+        # Run without the --dd flag if the GAMS license secret doesn't exist.
+        # Useful for testing for (CSV) regressions in forks before creating PRs.
+        run: |
+          source .venv/bin/activate
+          export PATH=$PATH:$GITHUB_WORKSPACE/GAMS/gams44.1_linux_x64_64_sfx
+          (python utils/run_benchmarks.py benchmarks.yml \
+            --times_dir $GITHUB_WORKSPACE/TIMES_model \
+            --verbose \
+            | tee out.txt; \
+            echo ${PIPESTATUS[0]} > retcode.txt)
+
       - name: Print summary
         working-directory: xl2times
         run: |
diff --git a/.gitignore b/.gitignore
index a707f93..2bb5d07 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,7 +13,11 @@ ground_truth/*
 *.pyproj.*
 speedscope.json
 *.pkl
-.venv/
+.venv*/
 benchmarks/
+.idea/
+.python-version
 docs/_build/
 docs/api/
+.coverage
+/out.txt
diff --git a/README.md b/README.md
index ee2c1e7..00416c7 100644
--- a/README.md
+++ b/README.md
@@ -72,6 +72,45 @@ git commit --no-verify
 See our GitHub Actions CI `.github/workflows/ci.yml` and the utility script
 `utils/run_benchmarks.py` to see how to run the tool on the DemoS models.
 
+In short, use the commands below to clone the benchmarks data into your local `benchmarks` dir.
+Note that this assumes you have access to all these repositories (some are private and
+you'll have to request access) - if not, comment out the inaccessible benchmarks in `benchmarks.yml` before running.
+
+```bash
+mkdir benchmarks
+# Get VEDA example models and reference DD files
+# XLSX files are in a private repo for licensing reasons; please request access or replace them with your own licensed VEDA example files.
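+# (The SSH URLs below assume your GitHub SSH keys are set up; use the https:// equivalents otherwise.)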
+git clone git@github.com:olejandro/demos-xlsx.git benchmarks/xlsx/
+git clone git@github.com:olejandro/demos-dd.git benchmarks/dd/
+
+# Get Ireland model and reference DD files
+git clone git@github.com:esma-cgep/tim.git benchmarks/xlsx/Ireland
+git clone git@github.com:esma-cgep/tim-gams.git benchmarks/dd/Ireland
+```
+Then to run the benchmarks:
+```bash
+# Run only a single benchmark by name (see benchmarks.yml for the list of names)
+python utils/run_benchmarks.py benchmarks.yml --verbose --run DemoS_001-all | tee out.txt
+
+# Run all benchmarks (without GAMS run, just comparing CSV data)
+python utils/run_benchmarks.py benchmarks.yml --verbose | tee out.txt
+
+
+# Run benchmarks with regression tests vs the main branch
+git checkout -b feature/your_new_changes
+# ... make your code changes here ...
+git commit -a -m "your commit message" # code must be committed for the comparison against the `main` branch to run
+python utils/run_benchmarks.py benchmarks.yml --verbose | tee out.txt
+```
+At this point, if you haven't broken anything you should see something like:
+```
+Change in runtime: +2.97s
+Change in correct rows: +0
+Change in additional rows: +0
+No regressions. You're awesome!
+```
+If you see a large increase in runtime, a decrease in correct rows, or an increase in additional rows, then you've broken something and will need to figure out how to fix it.
+
 ### Debugging Regressions
 
 If your change is causing regressions on one of the benchmarks, a useful way to
 debug and find the difference is to run the tool in verbose mode and compare the
 intermediate tables. For example, if your branch has regressions on Demo 1:
@@ -97,6 +136,7 @@ python -m build
 python -m twine upload dist/*
 ```
 
+
 ## Contributing
 
 This project welcomes contributions and suggestions. Most contributions require you to agree to a
diff --git a/pyproject.toml b/pyproject.toml
index 3b678e1..ab7f0f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,20 +14,28 @@ requires-python = ">=3.10"
 license = { file = "LICENSE" }
 keywords = []
 classifiers = [
-  "Development Status :: 4 - Beta",
-  "License :: OSI Approved :: MIT License",
-  "Programming Language :: Python",
-  "Programming Language :: Python :: 3",
+    "Development Status :: 4 - Beta",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python",
+    "Programming Language :: Python :: 3",
 ]
 dependencies = [
-  "GitPython >= 3.1.31, < 3.2",
-  "more-itertools",
-  "openpyxl >= 3.0, < 3.1",
-  "pandas >= 2.1",
+    "GitPython >= 3.1.31, < 3.2",
+    "more-itertools",
+    "openpyxl >= 3.0, < 3.1",
+    "pandas >= 2.1",
+    "pyarrow",
+    "tqdm",
 ]
 
 [project.optional-dependencies]
-dev = ["black", "pre-commit", "tabulate"]
+dev = [
+    "black",
+    "pre-commit",
+    "tabulate",
+    "pytest",
+    "pytest-cov"
+]
 
 [project.urls]
 Documentation = "https://github.com/etsap-TIMES/xl2times#readme"
@@ -36,3 +44,9 @@ Source = "https://github.com/etsap-TIMES/xl2times"
 
 [project.scripts]
 xl2times = "xl2times.__main__:main"
+
+[tool.pytest.ini_options]
+# don't print runtime warnings
+filterwarnings = ["ignore::DeprecationWarning", "ignore::UserWarning", "ignore::FutureWarning"]
+# show output, print test coverage report
+addopts = '-s --durations=0 --durations-min=5.0 --tb=native --cov-report term --cov-report html --cov=xl2times --cov=utils'
diff --git a/tests/data/austimes_pcg_test_data.parquet b/tests/data/austimes_pcg_test_data.parquet
new file mode 100644
index 0000000..c3346d9
Binary files /dev/null and b/tests/data/austimes_pcg_test_data.parquet differ
diff --git a/tests/data/comm_groups_austimes_test_data.parquet b/tests/data/comm_groups_austimes_test_data.parquet
new file mode 100644
index 0000000..84cfc51
Binary files /dev/null and b/tests/data/comm_groups_austimes_test_data.parquet differ
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
new file mode 100644
index 0000000..af77b6b
--- /dev/null
+++ b/tests/test_transforms.py
@@ -0,0 +1,67 @@
+from datetime import datetime
+
+import pandas as pd
+
+from xl2times import transforms
+from xl2times.transforms import (
+    _process_comm_groups_vectorised,
+    _count_comm_group_vectorised,
+)
+
+pd.set_option(
+    "display.max_rows",
+    20,
+    "display.max_columns",
+    20,
+    "display.width",
+    300,
+    "display.max_colwidth",
+    75,
+    "display.precision",
+    3,
+)
+
+
+class TestTransforms:
+    def test_generate_commodity_groups(self):
+        """
+        Tests that the _count_comm_group_vectorised function works as expected.
+        Full austimes run:
+            Vectorised version took 0.021999 seconds
+            looped version took 966.653371 seconds
+            43958x speedup
+        """
+        # data extracted immediately before the original for loops
+        comm_groups = pd.read_parquet(
+            "tests/data/comm_groups_austimes_test_data.parquet"
+        ).drop(columns=["commoditygroup"])
+
+        # filter data so the test runs faster
+        comm_groups = comm_groups.query("region in ['ACT', 'NSW']")
+
+        comm_groups2 = comm_groups.copy()
+        _count_comm_group_vectorised(comm_groups2)
+        assert comm_groups2.drop(columns=["commoditygroup"]).equals(comm_groups)
+        assert comm_groups2.shape == (comm_groups.shape[0], comm_groups.shape[1] + 1)
+
+    def test_default_pcg_vectorised(self):
+        """Tests that the default primary commodity group identification logic runs correctly.
+        Full austimes run:
+            Looped version took 1107.66 seconds
+            Vectorised version took 62.85 seconds
+        """
+        # data extracted immediately before the original for loops
+        comm_groups = pd.read_parquet("tests/data/austimes_pcg_test_data.parquet")
+
+        comm_groups = comm_groups[(comm_groups["region"].isin(["ACT", "NT"]))]
+        comm_groups2 = _process_comm_groups_vectorised(
+            comm_groups.copy(), transforms.csets_ordered_for_pcg
+        )
+        assert comm_groups2 is not None and not comm_groups2.empty
+        assert comm_groups2.shape == (comm_groups.shape[0], comm_groups.shape[1] + 1)
+        assert comm_groups2.drop(columns=["DefaultVedaPCG"]).equals(comm_groups)
+
+
+if __name__ == "__main__":
+    TestTransforms().test_default_pcg_vectorised()
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/dd_to_csv.py b/utils/dd_to_csv.py
index df444e7..a9e1132 100644
--- a/utils/dd_to_csv.py
+++ b/utils/dd_to_csv.py
@@ -1,4 +1,5 @@
 import argparse
+import sys
 from collections import defaultdict
 import json
 import os
@@ -216,7 +217,7 @@ def convert_dd_to_tabular(
     return
 
 
-if __name__ == "__main__":
+def main(arg_list: None | list[str] = None):
     args_parser = argparse.ArgumentParser()
     args_parser.add_argument(
         "input_dir", type=str, help="Input directory containing .dd files."
     )
@@ -224,5 +225,9 @@ def convert_dd_to_tabular(
     args_parser.add_argument(
         "output_dir", type=str, help="Output directory to save the .csv files in."
     )
-    args = args_parser.parse_args()
+    args = args_parser.parse_args(arg_list)
     convert_dd_to_tabular(args.input_dir, args.output_dir, generate_headers_by_attr())
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
diff --git a/utils/run_benchmarks.py b/utils/run_benchmarks.py
index f2a9a91..74e0e53 100644
--- a/utils/run_benchmarks.py
+++ b/utils/run_benchmarks.py
@@ -1,4 +1,6 @@
 import argparse
+import os
+from collections import namedtuple
 from concurrent.futures import ProcessPoolExecutor
 from functools import partial
 import git
@@ -13,6 +15,8 @@ from typing import Any, Tuple
 
 import yaml
 
+from xl2times.utils import max_workers
+
 
 def parse_result(lastline):
     m = match(
@@ -22,7 +26,7 @@ def parse_result(lastline):
     )
     if not m:
         print(f"ERROR: could not parse output of run:\n{lastline}")
-        sys.exit(1)
+        sys.exit(2)
     # return (accuracy, num_correct_rows, num_additional_rows)
     return (float(m.groups()[0]), int(m.groups()[1]), int(m.groups()[3]))
 
@@ -58,7 +62,7 @@ def run_gams_gdxdiff(
         print(res.stdout)
         print(res.stderr if res.stderr is not None else "")
         print(f"ERROR: GAMS failed on {benchmark['name']}")
-        sys.exit(1)
+        sys.exit(3)
     if "error" in res.stdout.lower():
         print(res.stdout)
         print(f"ERROR: GAMS errored on {benchmark['name']}")
@@ -89,7 +93,7 @@ def run_gams_gdxdiff(
         print(res.stdout)
         print(res.stderr if res.stderr is not None else "")
         print(f"ERROR: GAMS failed on {benchmark['name']} ground truth")
-        sys.exit(1)
+        sys.exit(4)
     if "error" in res.stdout.lower():
         print(res.stdout)
         print(f"ERROR: GAMS errored on {benchmark['name']}")
@@ -125,6 +129,7 @@ def run_benchmark(
     skip_csv: bool = False,
     out_folder: str = "out",
     verbose: bool = False,
+    debug: bool = False,
 ) -> Tuple[str, float, str, float, int, int]:
     xl_folder = path.join(benchmarks_folder, "xlsx", benchmark["input_folder"])
     dd_folder = path.join(benchmarks_folder, "dd", benchmark["dd_folder"])
@@ -134,26 +139,33 @@ def run_benchmark(
     # First convert ground truth DD to csv
     if not skip_csv:
         shutil.rmtree(csv_folder, ignore_errors=True)
-        res = subprocess.run(
-            [
-                "python",
-                "utils/dd_to_csv.py",
-                dd_folder,
-                csv_folder,
-            ],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            text=True,
-        )
-        if res.returncode != 0:
-            # Remove partial outputs
-            shutil.rmtree(csv_folder, ignore_errors=True)
-            print(res.stdout)
-            print(f"ERROR: dd_to_csv failed on {benchmark['name']}")
-            sys.exit(1)
+        if os.name != "nt":
+            res = subprocess.run(
+                [
+                    "python",
+                    "utils/dd_to_csv.py",
+                    dd_folder,
+                    csv_folder,
+                ],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                text=True,
+            )
+            if res.returncode != 0:
+                # Remove partial outputs
+                shutil.rmtree(csv_folder, ignore_errors=True)
+                print(res.stdout)
+                print(f"ERROR: dd_to_csv failed on {benchmark['name']}")
+                sys.exit(5)
+        else:
+            # On Windows, call dd_to_csv in-process instead of via a subprocess (also allows stepping with a debugger).
+            from dd_to_csv import main
+
+            main([dd_folder, csv_folder])
+
     elif not path.exists(csv_folder):
         print(f"ERROR: --skip_csv is true but {csv_folder} does not exist")
-        sys.exit(1)
+        sys.exit(6)
 
     # Then run the tool
     args = [
@@ -170,12 +182,23 @@ def run_benchmark(
     else:
         args.append(xl_folder)
     start = time.time()
-    res = subprocess.run(
-        ["xl2times"] + args,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT,
-        text=True,
-    )
+    res = None
+    if not debug:
+        res = subprocess.run(
+            ["xl2times"] + args,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            text=True,
+        )
+    else:
+        # If the debug option is set, run as a function call to allow stepping with a debugger.
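+        # (run() executes in this interpreter rather than in a subprocess, so breakpoints set inside xl2times are hit.)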
+        from xl2times.__main__ import run, parse_args
+
+        summary = run(parse_args(args))
+
+        # pack the results into a namedtuple pretending to be a return value from a subprocess call (as above).
+        res = namedtuple("stdout", ["stdout", "stderr", "returncode"])(summary, "", 0)
+
     runtime = time.time() - start
 
     if verbose:
@@ -188,7 +211,7 @@
     if res.returncode != 0:
         print(res.stdout)
         print(f"ERROR: tool failed on {benchmark['name']}")
-        sys.exit(1)
+        sys.exit(7)
     with open(path.join(out_folder, "stdout"), "w") as f:
         f.write(res.stdout)
 
@@ -211,6 +234,7 @@ def run_all_benchmarks(
     skip_main=False,
     skip_regression=False,
     verbose=False,
+    debug: bool = False,
 ):
     print("Running benchmarks", end="", flush=True)
     headers = ["Benchmark", "Time (s)", "GDX Diff", "Accuracy", "Correct", "Additional"]
 
     run_a_benchmark = partial(
         run_benchmark,
         benchmarks_folder=benchmarks_folder,
         skip_csv=skip_csv,
         run_gams=run_gams,
         verbose=verbose,
+        debug=debug,
     )
 
-    with ProcessPoolExecutor() as executor:
+    with ProcessPoolExecutor(max_workers=max_workers) as executor:
         results = list(executor.map(run_a_benchmark, benchmarks))
 
     print("\n\n" + tabulate(results, headers, floatfmt=".1f") + "\n")
 
@@ -234,7 +259,9 @@
     # The rest of this script checks regressions against main
     # so skip it if we're already on main
     repo = git.Repo(".")  # pyright: ignore
-    origin = repo.remotes.origin
+    origin = (
+        repo.remotes.origin if "origin" in repo.remotes else repo.remotes[0]
+    )  # don't assume remote is called 'origin'
     origin.fetch("main")
     if "main" not in repo.heads:
         repo.create_head("main", origin.refs.main).set_tracking_branch(origin.refs.main)
@@ -264,7 +291,7 @@
     else:
         if repo.is_dirty():
             print("Your working directory is not clean. Skipping regression tests.")
-            sys.exit(1)
+            sys.exit(8)
 
         # Re-run benchmarks on main
         repo.heads.main.checkout()
@@ -277,9 +304,10 @@
             run_gams=run_gams,
             out_folder="out-main",
             verbose=verbose,
+            debug=debug,
         )
 
-        with ProcessPoolExecutor() as executor:
+        with ProcessPoolExecutor(max_workers) as executor:
            results_main = list(executor.map(run_a_benchmark, benchmarks))
 
         # Print table with combined results to make comparison easier
@@ -310,17 +338,33 @@
         )
         if df.isna().values.any():
             print(f"ERROR: number of benchmarks changed:\n{df}")
-            sys.exit(1)
+            sys.exit(9)
 
         accu_regressions = df[df["Correct"] < df["M Correct"]]["Benchmark"]
         addi_regressions = df[df["Additional"] > df["M Additional"]]["Benchmark"]
         time_regressions = df[df["Time (s)"] > 2 * df["M Time (s)"]]["Benchmark"]
 
-        runtime_change = df["Time (s)"].sum() - df["M Time (s)"].sum()
-        print(f"Change in runtime: {runtime_change:+.2f}")
-        correct_change = df["Correct"].sum() - df["M Correct"].sum()
-        print(f"Change in correct rows: {correct_change:+d}")
-        additional_change = df["Additional"].sum() - df["M Additional"].sum()
-        print(f"Change in additional rows: {additional_change:+d}")
+        our_time = df["Time (s)"].sum()
+        main_time = df["M Time (s)"].sum()
+        runtime_change = our_time - main_time
+
+        print(f"Total runtime: {our_time:.2f}s (main: {main_time:.2f}s)")
+        print(
+            f"Change in runtime (negative == faster): {runtime_change:+.2f}s ({100*runtime_change/main_time:+.1f}%)"
+        )
+
+        our_correct = df["Correct"].sum()
+        main_correct = df["M Correct"].sum()
+        correct_change = our_correct - main_correct
+        print(
+            f"Change in correct rows (higher == better): {correct_change:+d} ({100*correct_change/main_correct:+.1f}%)"
+        )
+
+        our_additional_rows = df["Additional"].sum()
+        main_additional_rows = df["M Additional"].sum()
+        additional_change = our_additional_rows - main_additional_rows
+        print(
+            f"Change in additional rows: {additional_change:+d} ({100*additional_change/main_additional_rows:+.1f}%)"
+        )
 
         if len(accu_regressions) + len(addi_regressions) + len(time_regressions) > 0:
             print()
@@ -330,7 +374,7 @@
             print(f"ERROR: additional rows regressed on: {', '.join(addi_regressions)}")
             if not time_regressions.empty:
                 print(f"ERROR: runtime regressed on: {', '.join(time_regressions)}")
-            sys.exit(1)
+            sys.exit(10)
 
         # TODO also check if any new tables are missing?
 
         print("No regressions. You're awesome!")
@@ -385,6 +429,12 @@
         default=False,
         help="Print output of run on each benchmark",
     )
+    args_parser.add_argument(
+        "--debug",
+        action="store_true",
+        default=False,
+        help="Run each benchmark as a function call to allow a debugger to stop at breakpoints in benchmark runs.",
+    )
     args = args_parser.parse_args()
 
     spec = yaml.safe_load(open(args.benchmarks_yaml))
@@ -392,17 +442,17 @@
 
     benchmark_names = [b["name"] for b in spec["benchmarks"]]
     if len(set(benchmark_names)) != len(benchmark_names):
         print("ERROR: Found duplicate name in benchmarks YAML file")
-        sys.exit(1)
+        sys.exit(11)
 
     if args.dd and args.times_dir is None:
         print("ERROR: --times_dir is required when using --dd")
-        sys.exit(1)
+        sys.exit(12)
 
     if args.run is not None:
         benchmark = next((b for b in spec["benchmarks"] if b["name"] == args.run), None)
         if benchmark is None:
             print(f"ERROR: could not find {args.run} in {args.benchmarks_yaml}")
-            sys.exit(1)
+            sys.exit(13)
 
         _, runtime, gms, acc, cor, add = run_benchmark(
             benchmark,
@@ -411,6 +461,7 @@
             run_gams=args.dd,
             skip_csv=args.skip_csv,
             verbose=args.verbose,
+            debug=args.debug,
         )
         print(
             f"Ran {args.run} in {runtime:.2f}s. {acc}% ({cor} correct, {add} additional).\n"
@@ -426,4 +477,5 @@
             skip_main=args.skip_main,
             skip_regression=args.skip_regression,
             verbose=args.verbose,
+            debug=args.debug,
         )
diff --git a/xl2times/__main__.py b/xl2times/__main__.py
index 0ca48d2..f6d4d76 100644
--- a/xl2times/__main__.py
+++ b/xl2times/__main__.py
@@ -8,6 +8,8 @@ import sys
 import time
 from typing import Dict, List
 
+
+from xl2times.utils import max_workers
 from . import datatypes
 from . import excel
 from . import transforms
@@ -31,7 +33,7 @@ def convert_xl_to_times(
     use_pool = True
     if use_pool:
-        with ProcessPoolExecutor() as executor:
+        with ProcessPoolExecutor(max_workers) as executor:
             for result in executor.map(excel.extract_tables, input_files):
                 raw_tables.extend(result)
     else:
@@ -77,6 +79,7 @@ def convert_xl_to_times(
         transforms.process_flexible_import_tables,  # slow
         transforms.process_user_constraint_tables,
         transforms.process_commodity_emissions,
+        transforms.generate_uc_properties,
         transforms.process_commodities,
         transforms.process_transform_availability,
         transforms.fill_in_missing_values,
@@ -116,7 +119,7 @@ def convert_xl_to_times(
         end_time = time.time()
         sep = "\n\n" + "=" * 80 + "\n" if verbose else ""
         print(
-            f"{sep}transform {transform.__code__.co_name} took {end_time-start_time:.2f} seconds"
+            f"{sep}transform {transform.__code__.co_name} took {end_time - start_time:.2f} seconds"
         )
         if verbose:
             if isinstance(output, list):
@@ -159,7 +162,7 @@ def read_csv_tables(input_dir: str) -> Dict[str, DataFrame]:
 
 def compare(
     data: Dict[str, DataFrame], ground_truth: Dict[str, DataFrame], output_dir: str
-):
+) -> str:
     print(
         f"Ground truth contains {len(ground_truth)} tables,"
         f" {sum(df.shape[0] for _, df in ground_truth.items())} rows"
@@ -222,13 +225,15 @@ def compare(
                 os.path.join(output_dir, table_name + "_missing.csv"),
                 index=False,
             )
-
-    print(
+    result = (
         f"{total_correct_rows / total_gt_rows :.1%} of ground truth rows present"
         f" in output ({total_correct_rows}/{total_gt_rows})"
         f", {total_additional_rows} additional rows"
     )
+    print(result)
+
+    return result
 
 
 def produce_times_tables(
     config: datatypes.Config, input: Dict[str, DataFrame]
@@ -242,7 +247,7 @@ def produce_times_tables(
     for mapping in config.times_xl_maps:
         if not mapping.xl_name in input:
             print(
-                f"WARNING: Cannot produce table {mapping.times_name} because input table"
+                f"WARNING: Cannot produce table {mapping.times_name} because"
                 f" {mapping.xl_name} does not exist"
             )
         else:
@@ -252,8 +257,8 @@ def produce_times_tables(
             for filter_col, filter_val in mapping.filter_rows.items():
                 if filter_col not in df.columns:
                     print(
-                        f"WARNING: Cannot produce table {mapping.times_name} because input"
-                        f" table {mapping.xl_name} does not contain column {filter_col}"
+                        f"WARNING: Cannot produce table {mapping.times_name} because"
+                        f" {mapping.xl_name} does not contain column {filter_col}"
                     )
                     # TODO break this loop and continue outer loop?
                     filter = set(x.lower() for x in {filter_val})
@@ -265,8 +270,8 @@ def produce_times_tables(
             if not all(c in df.columns for c in mapping.xl_cols):
                 missing = set(mapping.xl_cols) - set(df.columns)
                 print(
-                    f"WARNING: Cannot produce table {mapping.times_name} because input"
-                    f" table {mapping.xl_name} does not contain the required columns"
+                    f"WARNING: Cannot produce table {mapping.times_name} because"
+                    f" {mapping.xl_name} does not contain the required columns"
                     f" - {', '.join(missing)}"
                 )
             else:
@@ -388,42 +393,14 @@ def dump_tables(tables: List, filename: str) -> List:
     return tables
 
 
-def main():
-    args_parser = argparse.ArgumentParser()
-    args_parser.add_argument(
-        "input",
-        nargs="*",
-        help="Either an input directory, or a list of input xlsx/xlsm files to process",
-    )
-    args_parser.add_argument(
-        "--regions",
-        type=str,
-        default="",
-        help="Comma-separated list of regions to include in the model",
-    )
-    args_parser.add_argument(
-        "--output_dir", type=str, default="output", help="Output directory"
-    )
-    args_parser.add_argument(
-        "--ground_truth_dir",
-        type=str,
-        help="Ground truth directory to compare with output",
-    )
-    args_parser.add_argument("--dd", action="store_true", help="Output DD files")
-    args_parser.add_argument(
-        "--only_read",
-        action="store_true",
-        help="Read xlsx/xlsm files and stop after outputting raw_tables.txt",
-    )
-    args_parser.add_argument("--use_pkl", action="store_true")
-    args_parser.add_argument(
-        "-v",
-        "--verbose",
-        action="store_true",
-        help="Verbose mode: print tables after every transform",
-    )
-    args = args_parser.parse_args()
-
+def run(args) -> str | None:
+    """
+    Runs the xl2times conversion.
+
+    Args:
+        args: pre-parsed command line arguments
+
+    Returns:
+        comparison with ground-truth string if `ground_truth_dir` is provided, else None.
+    """
     config = datatypes.Config(
         "times_mapping.txt",
         "times-info.json",
@@ -436,7 +413,7 @@ def run(args) -> str | None:
 
     if not isinstance(args.input, list) or len(args.input) < 1:
         print(f"ERROR: expected at least 1 input. Got {args.input}")
-        sys.exit(1)
+        sys.exit(-1)
     elif len(args.input) == 1:
         assert os.path.isdir(args.input[0])
         input_files = [
@@ -471,8 +448,67 @@ def run(args) -> str | None:
 
     if args.ground_truth_dir:
         ground_truth = read_csv_tables(args.ground_truth_dir)
-        compare(tables, ground_truth, args.output_dir)
+        comparison = compare(tables, ground_truth, args.output_dir)
+        return comparison
+    else:
+        return None
+
+
+def parse_args(arg_list: None | list[str]) -> argparse.Namespace:
+    """Parses command line arguments.
+
+    Args:
+        arg_list: List of command line arguments. Uses sys.argv (default argparse behaviour) if `None`.
+
+    Returns:
+        argparse.Namespace: Parsed arguments.
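+
+    Example (values illustrative):
+        >>> parse_args(["--output_dir", "out"]).output_dir
+        'out'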
+ """ + args_parser = argparse.ArgumentParser() + args_parser.add_argument( + "input", + nargs="*", + help="Either an input directory, or a list of input xlsx/xlsm files to process", + ) + args_parser.add_argument( + "--regions", + type=str, + default="", + help="Comma-separated list of regions to include in the model", + ) + args_parser.add_argument( + "--output_dir", type=str, default="output", help="Output directory" + ) + args_parser.add_argument( + "--ground_truth_dir", + type=str, + help="Ground truth directory to compare with output", + ) + args_parser.add_argument("--dd", action="store_true", help="Output DD files") + args_parser.add_argument( + "--only_read", + action="store_true", + help="Read xlsx/xlsm files and stop after outputting raw_tables.txt", + ) + args_parser.add_argument("--use_pkl", action="store_true") + args_parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Verbose mode: print tables after every transform", + ) + args = args_parser.parse_args(arg_list) + return args + + +def main(arg_list: None | list[str] = None) -> None: + """Main entry point for the xl2times package + Returns: + None. + """ + args = parse_args(arg_list) + run(args) if __name__ == "__main__": - main() + main(sys.argv[1:]) + sys.exit(0) diff --git a/xl2times/config/times_mapping.txt b/xl2times/config/times_mapping.txt index 279999e..8aba747 100644 --- a/xl2times/config/times_mapping.txt +++ b/xl2times/config/times_mapping.txt @@ -33,9 +33,9 @@ TOP_IRE[ALL_REG,COM,ALL_R,C,PRC] = Trade(Origin,IN,Destination,OUT,Process) TS_GROUP[REG,TSLVL,TS] = TimeSlices(Region,TSLVL,TS) TS_MAP[REG,PARENT,TS_MAP] = TimeSliceMap(Region,Parent,TimesliceMap) UC_ATTR[REG,UC_N,SIDE,UC_GRPTYPE,UC_NAME] = ~TODO(Region,UC_N,Side,UC_GRPTYPE,TODO-UC_NAME) -UC_N[UC_N,TEXT] = UserConstraints(UC_N,UC_Desc) -UC_R_EACH[REG,UC_N] = ~TODO(Region,UC_N) -UC_R_SUM[REG,UC_N] = ~TODO(Region,UC_N) +UC_N[UC_N,TEXT] = UserConstraints(Name,Description) +UC_R_EACH[REG,UC_N] = UserConstraints(Region,Name,Reg_Action:R_E) +UC_R_SUM[REG,UC_N] = UserConstraints(Region,Name,Reg_Action:R_S) UNITS[UNITS] = Units(Unit) UNITS_ACT[UNITS] = Units(Unit,Type:activity) UNITS_CAP[UNITS] = Units(Unit,Type:capacity) diff --git a/xl2times/datatypes.py b/xl2times/datatypes.py index 01402d4..13f2d31 100644 --- a/xl2times/datatypes.py +++ b/xl2times/datatypes.py @@ -154,6 +154,7 @@ class TimesModel: trade: DataFrame = field(default_factory=DataFrame) attributes: DataFrame = field(default_factory=DataFrame) user_constraints: DataFrame = field(default_factory=DataFrame) + uc_attributes: DataFrame = field(default_factory=DataFrame) ts_tslvl: DataFrame = field(default_factory=DataFrame) ts_map: DataFrame = field(default_factory=DataFrame) time_periods: DataFrame = field(default_factory=DataFrame) @@ -250,8 +251,12 @@ def create_mapping(entity): times_cols = entity["indexes"] + ["VALUE"] xl_cols = entity["mapping"] + ["value"] # TODO map in json col_map = dict(zip(times_cols, xl_cols)) - # If tag starts with UC, then the data is in UC_T, else FI_T - xl_name = Tag.uc_t if entity["name"].lower().startswith("uc") else Tag.fi_t + # If tag starts with UC, then the data is in UCAttributes, else Attributes + xl_name = ( + "UCAttributes" + if entity["name"].lower().startswith("uc") + else "Attributes" + ) return TimesXlMap( times_name=entity["name"], times_cols=times_cols, diff --git a/xl2times/transforms.py b/xl2times/transforms.py index b7d130e..2e52825 100644 --- a/xl2times/transforms.py +++ b/xl2times/transforms.py @@ -10,6 +10,10 @@ from 
 import time
 from functools import reduce
 
+
+from tqdm import tqdm
+
+from .utils import max_workers
 from . import datatypes
 from . import utils
@@ -594,6 +598,76 @@ def process_user_constraint_table(
     return [process_user_constraint_table(t) for t in tables]
 
 
+def generate_uc_properties(
+    config: datatypes.Config,
+    tables: List[datatypes.EmbeddedXlTable],
+    model: datatypes.TimesModel,
+) -> List[datatypes.EmbeddedXlTable]:
+    """
+    Generate a DataFrame containing User Constraint properties.
+    """
+    uc_tables = [table for table in tables if table.tag == datatypes.Tag.uc_t]
+    columns = ["uc_n", "uc_desc", "region", "reg_action", "ts_action"]
+    user_constraints = pd.DataFrame(columns=columns)
+    # Create df_list to hold DataFrames that will be concatenated later on
+    df_list = list()
+    for uc_table in uc_tables:
+        # Single-column DataFrame with unique UC names
+        df = uc_table.dataframe.loc[:, ["uc_n"]].drop_duplicates(keep="first")
+        # Supplement UC names with descriptions, if they exist
+        df = df.merge(
+            uc_table.dataframe.loc[:, ["uc_n", "uc_desc"]]
+            .drop_duplicates(keep="first")
+            .dropna(),
+            how="left",
+        )
+        # Add info on how regions and timeslices should be treated by the UCs
+        for key in uc_table.uc_sets.keys():
+            if key.startswith("R_"):
+                df["reg_action"] = key
+                df["region"] = uc_table.uc_sets[key].upper().strip()
+            elif key.startswith("T_"):
+                df["ts_action"] = key
+
+        df_list.append(df)
+
+    # Do further processing if df_list is not empty
+    if df_list:
+        # Create a single DataFrame with all UCs
+        user_constraints = pd.concat(df_list, ignore_index=True)
+
+        # Use name to populate description if it is missing
+        index = user_constraints["uc_desc"].isna()
+        if any(index):
+            user_constraints["uc_desc"][index] = user_constraints["uc_n"][index]
+
+        # TODO: Can this (until user_constraints.explode) become a utility function?
+        # Handle allregions by substituting it with a list of internal regions
+        index = user_constraints["region"].str.lower() == "allregions"
+        if any(index):
+            user_constraints["region"][index] = [model.internal_regions]
+
+        # Handle comma-separated regions
+        index = user_constraints["region"].str.contains(",").fillna(value=False)
+        if any(index):
+            user_constraints["region"][index] = user_constraints.apply(
+                lambda row: [
+                    region
+                    for region in str(row["region"]).split(",")
+                    if region in model.internal_regions
+                ],
+                axis=1,
+            )
+        # Explode regions
+        user_constraints = user_constraints.explode("region", ignore_index=True)
+
+    model.user_constraints = user_constraints.rename(
+        columns={"uc_n": "name", "uc_desc": "description"}
+    )
+
+    return tables
+
+
 def fill_in_missing_values(
     config: datatypes.Config,
     tables: List[datatypes.EmbeddedXlTable],
@@ -671,15 +745,15 @@ def fill_in_missing_values_table(table):
             if matches is not None:
                 book = matches.group(1)
                 if book in vt_regions:
-                    df.fillna({colname: ",".join(vt_regions[book])}, inplace=True)
+                    df = df.fillna({colname: ",".join(vt_regions[book])})
                 else:
                     print(f"WARNING: book name {book} not in BookRegions_Map")
             else:
-                df.fillna({colname: ",".join(model.internal_regions)}, inplace=True)
+                df = df.fillna({colname: ",".join(model.internal_regions)})
         elif colname == "year":
-            df.fillna({colname: start_year}, inplace=True)
+            df = df.fillna({colname: start_year})
         elif colname == "currency":
-            df.fillna({colname: currency}, inplace=True)
+            df = df.fillna({colname: currency})
 
     return replace(table, dataframe=df)
 
@@ -869,9 +943,11 @@ def complete_dictionary(
         "TimeSlices": model.ts_tslvl,
         "TimeSliceMap": model.ts_map,
         "UserConstraints": model.user_constraints,
+        "UCAttributes": model.uc_attributes,
         "Units": model.units,
     }.items():
-        tables[k] = v
+        if not v.empty:
+            tables[k] = v
 
     return tables
 
@@ -1005,19 +1081,9 @@ def generate_commodity_groups(
 
     # Commodity groups by process, region and commodity
     comm_groups = pd.merge(prc_top, comm_set, on=["region", "commodity"])
-    comm_groups["commoditygroup"] = 0
-    # Store the number of IN/OUT commodities of the same type per Region and Process in CommodityGroup
-    for region in comm_groups["region"].unique():
-        i_reg = comm_groups["region"] == region
-        for process in comm_groups[i_reg]["process"].unique():
-            i_reg_prc = i_reg & (comm_groups["process"] == process)
-            for cset in comm_groups[i_reg_prc]["csets"].unique():
-                i_reg_prc_cset = i_reg_prc & (comm_groups["csets"] == cset)
-                for io in ["IN", "OUT"]:
-                    i_reg_prc_cset_io = i_reg_prc_cset & (comm_groups["io"] == io)
-                    comm_groups.loc[i_reg_prc_cset_io, "commoditygroup"] = sum(
-                        i_reg_prc_cset_io
-                    )
+
+    # Add columns for the number of IN/OUT commodities of each type
+    _count_comm_group_vectorised(comm_groups)
 
     def name_comm_group(df):
         """
@@ -1034,24 +1100,8 @@ def name_comm_group(df):
     # Replace commodity group member count with the name
     comm_groups["commoditygroup"] = comm_groups.apply(name_comm_group, axis=1)
 
-    # Determine default PCG according to Veda
-    comm_groups["DefaultVedaPCG"] = None
-    for region in comm_groups["region"].unique():
-        i_reg = comm_groups["region"] == region
-        for process in comm_groups[i_reg]["process"]:
-            i_reg_prc = i_reg & (comm_groups["process"] == process)
-            default_set = False
-            for io in ["OUT", "IN"]:
-                if default_set:
-                    break
-                i_reg_prc_io = i_reg_prc & (comm_groups["io"] == io)
-                for cset in csets_ordered_for_pcg:
-                    i_reg_prc_io_cset = i_reg_prc_io & (comm_groups["csets"] == cset)
-                    df = comm_groups[i_reg_prc_io_cset]
-                    if not df.empty:
-                        comm_groups.loc[i_reg_prc_io_cset, "DefaultVedaPCG"] = True
-                        default_set = True
-                        break
+    # Determine default PCG according to Veda's logic
+    comm_groups = _process_comm_groups_vectorised(comm_groups, csets_ordered_for_pcg)
 
     # Add standard Veda PCGS named contrary to name_comm_group
     if reg_prc_veda_pcg.shape[0]:
@@ -1085,6 +1135,62 @@ def name_comm_group(df):
     return tables
 
 
+def _count_comm_group_vectorised(comm_groups: pd.DataFrame) -> None:
+    """
+    Store the number of IN/OUT commodities of the same type per Region and Process in CommodityGroup.
+    `comm_groups` is modified in-place.
+
+    Args:
+        comm_groups: 'Process' DataFrame, to which a "commoditygroup" column is added in-place
+    """
+    comm_groups["commoditygroup"] = 0
+
+    comm_groups["commoditygroup"] = (
+        comm_groups.groupby(["region", "process", "csets", "io"]).transform("count")
+    )["commoditygroup"]
+    # set commoditygroup to 0 for io rows that aren't IN or OUT
+    comm_groups.loc[~comm_groups["io"].isin(["IN", "OUT"]), "commoditygroup"] = 0
+
+
+def _process_comm_groups_vectorised(
+    comm_groups: pd.DataFrame, csets_ordered_for_pcg: list[str]
+) -> pd.DataFrame:
+    """Sets the first commodity group in the list of csets_ordered_for_pcg as the default pcg for each region/process/io combination,
+    considering the io="OUT" subset before "IN".
+
+    See:
+        Section 3.7.2.2, pg 80. of `TIMES Documentation PART IV` for details.
+    Args:
+        comm_groups: 'Process' DataFrame with columns ["region", "process", "io", "csets", "commoditygroup"]
+        csets_ordered_for_pcg: List of csets in the order they should be considered for default pcg
+    Returns:
+        Processed DataFrame with a new column "DefaultVedaPCG" set to True for the default pcg in each region/process/io combination.
+    """
+
+    def _set_default_veda_pcg(group):
+        """For a given [region, process] group, the default group is the first cset in the
+        `csets_ordered_for_pcg` list that is an output, if one exists, otherwise the first input."""
+        if not group["csets"].isin(csets_ordered_for_pcg).all():
+            return group
+
+        for io in ["OUT", "IN"]:
+            for cset in csets_ordered_for_pcg:
+                group.loc[
+                    (group["io"] == io) & (group["csets"] == cset), "DefaultVedaPCG"
+                ] = True
+                if group["DefaultVedaPCG"].any():
+                    break
+        return group
+
+    comm_groups["DefaultVedaPCG"] = None
+    comm_groups_subset = comm_groups.groupby(
+        ["region", "process"], sort=False, as_index=False
+    ).apply(_set_default_veda_pcg)
+    comm_groups_subset = comm_groups_subset.reset_index(
+        level=0, drop=True
+    ).sort_index()  # back to the original index and row order
+    return comm_groups_subset
+
+
 def complete_commodity_groups(
     config: datatypes.Config,
     tables: Dict[str, DataFrame],
@@ -2288,7 +2394,7 @@ def convert_aliases(
     # TODO: do this earlier
     model.attributes = tables[datatypes.Tag.fi_t]
     if datatypes.Tag.uc_t in tables.keys():
-        model.user_constraints = tables[datatypes.Tag.uc_t]
+        model.uc_attributes = tables[datatypes.Tag.uc_t]
 
     return tables
 
@@ -2432,19 +2538,6 @@ def apply_more_fixups(
         )
         tables[datatypes.Tag.fi_t] = df
 
-    df = tables.get(datatypes.Tag.uc_t)
-    if df is not None:
-        # TODO: Handle defaults in a general way.
-        # Use uc_n value if uc_desc is missing
-        for uc_n in df["uc_n"].unique():
-            index = df["uc_n"] == uc_n
-            if all(df["uc_desc"][index].isna()):
-                # Populate the first row only
-                if any(index):
-                    df.at[list(index).index(True), "uc_desc"] = uc_n
-
-        tables[datatypes.Tag.uc_t] = df
-
     return tables
 
 
@@ -2453,5 +2546,5 @@ def expand_rows_parallel(
     tables: List[datatypes.EmbeddedXlTable],
     model: datatypes.TimesModel,
 ) -> List[datatypes.EmbeddedXlTable]:
-    with ProcessPoolExecutor() as executor:
+    with ProcessPoolExecutor(max_workers) as executor:
         return list(executor.map(expand_rows, tables))
diff --git a/xl2times/utils.py b/xl2times/utils.py
index def64cc..bec4cb9 100644
--- a/xl2times/utils.py
+++ b/xl2times/utils.py
@@ -1,3 +1,4 @@
+import os
 import re
 from dataclasses import replace
 from math import log10, floor
@@ -10,6 +11,10 @@ from . import datatypes
 
 
+# prevent an excessive number of processes on Windows and on high-cpu-count machines
+# TODO: make this a CLI param or global setting?
+max_workers: int = 4 if os.name == "nt" else min(16, os.cpu_count() or 16)
+
 
 def apply_composite_tag(table: datatypes.EmbeddedXlTable) -> datatypes.EmbeddedXlTable:
     """
@@ -31,7 +36,7 @@ def apply_composite_tag(table: datatypes.EmbeddedXlTable) -> datatypes.EmbeddedXlTable:
         (newtag, varname) = table.tag.split(":")
         varname = varname.strip()
         df = table.dataframe.copy()
-        df["attribute"].fillna(varname, inplace=True)
+        df["attribute"] = df["attribute"].fillna(varname)
         return replace(table, tag=newtag, dataframe=df)
     else:
         return table
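
The split of `main()` into `parse_args()` and `run()` in `xl2times/__main__.py` also makes the converter callable in-process, which is what the new `--debug` flag in `utils/run_benchmarks.py` relies on. A minimal sketch of driving it from Python, assuming `xl2times` is installed in the active environment (the input path below is illustrative):

```python
# Sketch: run the converter in-process instead of shelling out to the `xl2times` CLI.
from xl2times.__main__ import parse_args, run

# Point the positional argument at a real model directory; this path is illustrative.
args = parse_args(["benchmarks/xlsx/DemoS_001", "--output_dir", "out"])

# run() returns the ground-truth comparison string only when --ground_truth_dir
# is given; otherwise it returns None.
summary = run(args)
print(summary)
```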