diff --git a/.gitignore b/.gitignore index 63506f7d..4f16c973 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ docs/api/ *.log /profile.* xl2times/.cache/ +*.log.zip diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 06e225b8..c99e5011 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,11 +5,20 @@ repos: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace + - repo: https://github.com/psf/black rev: 22.8.0 hooks: - id: black + - repo: https://github.com/RobertCraigie/pyright-python rev: v1.1.304 hooks: - id: pyright + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.2 + hooks: + - id: ruff + types_or: [ python, pyi, jupyter ] + args: [ --fix, --exit-non-zero-on-fix ] diff --git a/pyproject.toml b/pyproject.toml index 01429a66..5c3e575f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,8 @@ dev = [ "tabulate", "pytest", "pytest-cov", - "poethepoet" + "poethepoet", + "ruff" ] [project.urls] @@ -61,3 +62,37 @@ benchmark = { cmd = "python utils/run_benchmarks.py benchmarks.yml --run", help benchmark_all = { shell = "python utils/run_benchmarks.py benchmarks.yml --verbose | tee out.txt", help = "Run the project", interpreter = "posix" } lint = { shell = "git add .pre-commit-config.yaml & pre-commit run", help = "Run pre-commit hooks", interpreter = "posix" } test = { cmd = "pytest --cov-report term --cov-report html --cov=xl2times --cov=utils", help = "Run unit tests with pytest" } + + +# Config for various pre-commit checks are below +# Ruff linting rules - see https://github.com/charliermarsh/ruff and https://beta.ruff.rs/docs/rules/ +[tool.ruff] +target-version = "py311" +line-length = 88 + +# Option 1: use basic rules only. +lint.select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "UP", # pyupgrade + "N", # pep8 naming + "I", # isort + "TID", # tidy imports + "UP", # pyupgrade + "NPY", # numpy style + "PL", # pylint +# "PD", # pandas style # TODO enable later +# "C90", # code complexity # TODO enable later +] + +# Add specific rule codes/groups here to ignore them, or add a '#noqa' comment to the line of code to skip all checks. +lint.ignore = [ + "PLR", # complexity rules + "PD901", "PD011", # pandas 'df'' + "E501", # line too long, handled by black +] + +# Ruff rule-specific options: +[tool.ruff.mccabe] +max-complexity = 12 # increase max function 'complexity' diff --git a/tests/test_transforms.py b/tests/test_transforms.py index e512cefa..63b8cc18 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -2,16 +2,16 @@ import pandas as pd -from xl2times import transforms, utils, datatypes +from xl2times import datatypes, transforms, utils from xl2times.transforms import ( - _process_comm_groups_vectorised, _count_comm_group_vectorised, + _match_wildcards, + _process_comm_groups_vectorised, + commodity_map, expand_rows, get_matching_commodities, get_matching_processes, - _match_wildcards, process_map, - commodity_map, ) logger = utils.get_logger() diff --git a/tests/test_utils.py b/tests/test_utils.py index 9ef1ec28..71a0f27e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,7 @@ -from xl2times import utils import pandas as pd +from xl2times import utils + class TestUtils: def test_explode(self): diff --git a/utils/dd_to_csv.py b/utils/dd_to_csv.py index f3037f17..7d262279 100644 --- a/utils/dd_to_csv.py +++ b/utils/dd_to_csv.py @@ -1,10 +1,9 @@ import argparse -import sys -from collections import defaultdict import json import os +import sys +from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple, Union import numpy as np import pandas as pd @@ -13,7 +12,7 @@ def parse_parameter_values_from_file( path: Path, -) -> Tuple[Dict[str, List], Dict[str, set]]: +) -> tuple[dict[str, list], dict[str, set]]: """ Parse *.dd to turn it into CSV format There are parameters and sets, and each has a slightly different format @@ -35,11 +34,11 @@ def parse_parameter_values_from_file( """ - data = list(open(path, "r")) + data = list(open(path)) data = [line.rstrip() for line in data] - param_value_dict: Dict[str, List] = dict() - set_data_dict: Dict[str, set] = dict() + param_value_dict: dict[str, list] = dict() + set_data_dict: dict[str, set] = dict() index = 0 while index < len(data): if data[index].startswith("PARAMETER"): @@ -124,8 +123,8 @@ def parse_parameter_values_from_file( def save_data_with_headers( - param_data_dict: Dict[str, Union[pd.DataFrame, List[str]]], - headers_data: Dict[str, List[str]], + param_data_dict: dict[str, pd.DataFrame | list[str]], + headers_data: dict[str, list[str]], save_dir: str, ) -> None: """ @@ -157,7 +156,7 @@ def save_data_with_headers( return -def generate_headers_by_attr() -> Dict[str, List[str]]: +def generate_headers_by_attr() -> dict[str, list[str]]: with open("xl2times/config/times-info.json") as f: attributes = json.load(f) @@ -173,7 +172,7 @@ def generate_headers_by_attr() -> Dict[str, List[str]]: def convert_dd_to_tabular( - basedir: str, output_dir: str, headers_by_attr: Dict[str, List[str]] + basedir: str, output_dir: str, headers_by_attr: dict[str, list[str]] ) -> None: dd_files = [p for p in Path(basedir).rglob("*.dd")] @@ -201,15 +200,15 @@ def convert_dd_to_tabular( os.makedirs(set_path, exist_ok=True) # Extract headers with key=param_name and value=List[attributes] - lines = list(open("xl2times/config/times_mapping.txt", "r")) + lines = list(open("xl2times/config/times_mapping.txt")) headers_data = headers_by_attr # The following will overwrite data obtained from headers_by_attr # TODO: Remove once migration is done? for line in lines: - line = line.strip() - if line != "": - param_name = line.split("[")[0] - attributes = line.split("[")[1].split("]")[0].split(",") + ln = line.strip() + if ln != "": + param_name = ln.split("[")[0] + attributes = ln.split("[")[1].split("]")[0].split(",") headers_data[param_name] = [*attributes] save_data_with_headers(all_parameters, headers_data, param_path) diff --git a/utils/run_benchmarks.py b/utils/run_benchmarks.py index 060a0fc8..bbf18e18 100644 --- a/utils/run_benchmarks.py +++ b/utils/run_benchmarks.py @@ -9,14 +9,14 @@ from concurrent.futures import ProcessPoolExecutor from functools import partial from os import path, symlink -from typing import Any, Tuple +from typing import Any import git import pandas as pd import yaml +from dd_to_csv import main from tabulate import tabulate -from dd_to_csv import main from xl2times import utils from xl2times.__main__ import parse_args, run from xl2times.utils import max_workers @@ -24,7 +24,7 @@ logger = utils.get_logger() -def parse_result(output: str) -> Tuple[float, int, int]: +def parse_result(output: str) -> tuple[float, int, int]: # find pattern in multiline string m = re.findall( r"(\d+\.\d)% of ground truth rows present in output \((\d+)/(\d+)\), (\d+) additional rows", @@ -65,6 +65,7 @@ def run_gams_gdxdiff( stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, + check=False, ) if res.returncode != 0: logger.info(res.stdout) @@ -96,6 +97,7 @@ def run_gams_gdxdiff( stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, + check=False, ) if res.returncode != 0: logger.info(res.stdout) @@ -119,6 +121,7 @@ def run_gams_gdxdiff( stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, + check=False, ) if verbose: logger.info(res.stdout) @@ -138,7 +141,7 @@ def run_benchmark( out_folder: str = "out", verbose: bool = False, debug: bool = False, -) -> Tuple[str, float, str, float, int, int]: +) -> tuple[str, float, str, float, int, int]: xl_folder = path.join(benchmarks_folder, "xlsx", benchmark["input_folder"]) dd_folder = path.join(benchmarks_folder, "dd", benchmark["dd_folder"]) csv_folder = path.join(benchmarks_folder, "csv", benchmark["name"]) @@ -160,6 +163,7 @@ def run_benchmark( stderr=subprocess.STDOUT, text=True, shell=True if os.name == "nt" else False, + check=False, ) if res.returncode != 0: # Remove partial outputs @@ -191,7 +195,7 @@ def run_benchmark( if "regions" in benchmark: args.extend(["--regions", benchmark["regions"]]) if "inputs" in benchmark: - args.extend((path.join(xl_folder, b) for b in benchmark["inputs"])) + args.extend(path.join(xl_folder, b) for b in benchmark["inputs"]) else: args.append(xl_folder) start = time.time() @@ -203,6 +207,7 @@ def run_benchmark( stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, + check=False, ) else: # If debug option is set, run as a function call to allow stepping with a debugger. @@ -295,7 +300,6 @@ def run_all_benchmarks( for benchmark in benchmarks: with open( path.join(benchmarks_folder, "out-main", benchmark["name"], "stdout"), - "r", ) as f: result = parse_result(f.readlines()[-1]) # Use a fake runtime and GAMS result @@ -330,7 +334,8 @@ def run_all_benchmarks( results_main = list(executor.map(run_a_benchmark, benchmarks)) # Print table with combined results to make comparison easier - trunc = lambda s: s[:10] + "\u2026" if len(s) > 10 else s + trunc = lambda s: s[:10] + "\u2026" if len(s) > 10 else s # noqa + combined_results = [ ( f"{b:<20}", diff --git a/xl2times/__main__.py b/xl2times/__main__.py index a1ed5d8b..fcd3326e 100644 --- a/xl2times/__main__.py +++ b/xl2times/__main__.py @@ -1,21 +1,20 @@ import argparse -from concurrent.futures import ProcessPoolExecutor -from datetime import datetime import hashlib -from pandas.core.frame import DataFrame -import pandas as pd -import pickle -from pathlib import Path import os +import pickle import sys import time -from typing import Dict, List +from concurrent.futures import ProcessPoolExecutor +from datetime import datetime +from pathlib import Path + +import pandas as pd +from pandas.core.frame import DataFrame from xl2times import __file__ as xl2times_file_path from xl2times.utils import max_workers -from . import datatypes, utils -from . import excel -from . import transforms + +from . import datatypes, excel, transforms, utils logger = utils.get_logger() @@ -24,7 +23,7 @@ os.makedirs(cache_dir, exist_ok=True) -def _read_xlsx_cached(filename: str) -> List[datatypes.EmbeddedXlTable]: +def _read_xlsx_cached(filename: str) -> list[datatypes.EmbeddedXlTable]: """Extract EmbeddedXlTables from xlsx file (cached). Since excel.extract_tables is quite slow, we cache its results in `cache_dir`. @@ -50,14 +49,14 @@ def _read_xlsx_cached(filename: str) -> List[datatypes.EmbeddedXlTable]: def convert_xl_to_times( - input_files: List[str], + input_files: list[str], output_dir: str, config: datatypes.Config, model: datatypes.TimesModel, no_cache: bool, verbose: bool = False, stop_after_read: bool = False, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: start_time = datetime.now() with ProcessPoolExecutor(max_workers) as executor: raw_tables = executor.map( @@ -167,7 +166,7 @@ def convert_xl_to_times( return output -def write_csv_tables(tables: Dict[str, DataFrame], output_dir: str): +def write_csv_tables(tables: dict[str, DataFrame], output_dir: str): os.makedirs(output_dir, exist_ok=True) for item in os.listdir(output_dir): if item.endswith(".csv"): @@ -176,7 +175,7 @@ def write_csv_tables(tables: Dict[str, DataFrame], output_dir: str): df.to_csv(os.path.join(output_dir, tablename + "_output.csv"), index=False) -def read_csv_tables(input_dir: str) -> Dict[str, DataFrame]: +def read_csv_tables(input_dir: str) -> dict[str, DataFrame]: result = {} for filename in os.listdir(input_dir): result[filename.split(".")[0]] = pd.read_csv( @@ -186,7 +185,7 @@ def read_csv_tables(input_dir: str) -> Dict[str, DataFrame]: def compare( - data: Dict[str, DataFrame], ground_truth: Dict[str, DataFrame], output_dir: str + data: dict[str, DataFrame], ground_truth: dict[str, DataFrame], output_dir: str ) -> str: logger.info( f"Ground truth contains {len(ground_truth)} tables," @@ -261,8 +260,8 @@ def compare( def produce_times_tables( - config: datatypes.Config, input: Dict[str, DataFrame] -) -> Dict[str, DataFrame]: + config: datatypes.Config, input: dict[str, DataFrame] +) -> dict[str, DataFrame]: logger.info( f"produce_times_tables: {len(input)} tables incoming," f" {sum(len(value) for (_, value) in input.items())} rows" @@ -286,7 +285,7 @@ def produce_times_tables( f" {mapping.xl_name} does not contain column {filter_col}" ) # TODO break this loop and continue outer loop? - filter = set(x.lower() for x in {filter_val}) + filter = set(x.lower() for x in (filter_val,)) i = df[filter_col].str.lower().isin(filter) df = df.loc[i, :] # TODO find the correct tech group @@ -330,7 +329,7 @@ def produce_times_tables( def write_dd_files( - tables: Dict[str, DataFrame], config: datatypes.Config, output_dir: str + tables: dict[str, DataFrame], config: datatypes.Config, output_dir: str ): encoding = "utf-8" os.makedirs(output_dir, exist_ok=True) @@ -397,10 +396,10 @@ def strip_filename_prefix(table, prefix): return table -def dump_tables(tables: List, filename: str) -> List: +def dump_tables(tables: list, filename: str) -> list: os.makedirs(os.path.dirname(filename), exist_ok=True) with open(filename, "w") as text_file: - for t in tables if isinstance(tables, List) else tables.items(): + for t in tables if isinstance(tables, list) else tables.items(): if isinstance(t, datatypes.EmbeddedXlTable): tag = t.tag text_file.write(f"sheetname: {t.sheetname}\n") diff --git a/xl2times/datatypes.py b/xl2times/datatypes.py index f1f69d8e..7d3c4336 100644 --- a/xl2times/datatypes.py +++ b/xl2times/datatypes.py @@ -1,11 +1,12 @@ +import json +import re from collections import defaultdict +from collections.abc import Iterable from dataclasses import dataclass, field +from enum import Enum from importlib import resources from itertools import chain -import json -import re -from typing import Dict, Iterable, List, Set, Tuple -from enum import Enum + from loguru import logger from pandas.core.frame import DataFrame @@ -85,7 +86,7 @@ class EmbeddedXlTable: """ tag: str - uc_sets: Dict[str, str] + uc_sets: dict[str, str] sheetname: str range: str filename: str @@ -132,11 +133,11 @@ class TimesXlMap: """ times_name: str - times_cols: List[str] + times_cols: list[str] xl_name: str # TODO once we move away from times_mapping.txt, make this type Tag - xl_cols: List[str] - col_map: Dict[str, str] - filter_rows: Dict[str, str] + xl_cols: list[str] + col_map: dict[str, str] + filter_rows: dict[str, str] @dataclass @@ -145,8 +146,8 @@ class TimesModel: This class contains all the information about the processed TIMES model. """ - internal_regions: Set[str] = field(default_factory=set) - all_regions: Set[str] = field(default_factory=set) + internal_regions: set[str] = field(default_factory=set) + all_regions: set[str] = field(default_factory=set) processes: DataFrame = field(default_factory=DataFrame) commodities: DataFrame = field(default_factory=DataFrame) commodity_groups: DataFrame = field(default_factory=DataFrame) @@ -160,14 +161,14 @@ class TimesModel: time_periods: DataFrame = field(default_factory=DataFrame) units: DataFrame = field(default_factory=DataFrame) start_year: int = field(default_factory=int) - files: Set[str] = field(default_factory=set) + files: set[str] = field(default_factory=set) @property - def external_regions(self) -> Set[str]: + def external_regions(self) -> set[str]: return self.all_regions.difference(self.internal_regions) @property - def data_years(self) -> Set[int]: + def data_years(self) -> set[int]: """ data_years are years for which there is data specified. """ @@ -179,14 +180,14 @@ def data_years(self) -> Set[int]: return {y for y in data_years if y >= 1000} @property - def past_years(self) -> Set[int]: + def past_years(self) -> set[int]: """ Pastyears is the set of all years before start_year. """ return {x for x in self.data_years if x < self.start_year} @property - def model_years(self) -> Set[int]: + def model_years(self) -> set[int]: """ model_years is the union of past_years and the representative years of the model (middleyears). """ @@ -198,26 +199,26 @@ class Config: the mapping betwen excel tables and output tables, categories of tables, etc. """ - times_xl_maps: List[TimesXlMap] + times_xl_maps: list[TimesXlMap] dd_table_order: Iterable[str] - all_attributes: Set[str] - attr_aliases: Set[str] + all_attributes: set[str] + attr_aliases: set[str] # For each tag, this dictionary maps each column alias to the normalized name - column_aliases: Dict[Tag, Dict[str, str]] + column_aliases: dict[Tag, dict[str, str]] # For each tag, this dictionary specifies comment row symbols by column name - row_comment_chars: Dict[Tag, Dict[str, list]] + row_comment_chars: dict[Tag, dict[str, list]] # List of tags for which empty tables should be discarded discard_if_empty: Iterable[Tag] - veda_attr_defaults: Dict[str, Dict[str, list]] + veda_attr_defaults: dict[str, dict[str, list]] # Known columns for each tag - known_columns: Dict[Tag, Set[str]] + known_columns: dict[Tag, set[str]] # Query columns for each tag - query_columns: Dict[Tag, Set[str]] + query_columns: dict[Tag, set[str]] # Required columns for each tag - required_columns: Dict[Tag, Set[str]] + required_columns: dict[Tag, set[str]] # Names of regions to include in the model; if empty, all regions are included. - filter_regions: Set[str] - times_sets: Dict[str, List[str]] + filter_regions: set[str] + times_sets: dict[str, list[str]] def __init__( self, @@ -256,7 +257,7 @@ def __init__( @staticmethod def _read_times_sets( times_sets_file: str, - ) -> Dict[str, List[str]]: + ) -> dict[str, list[str]]: # Read times_sets_file with resources.open_text("xl2times.config", times_sets_file) as f: times_sets = json.load(f) @@ -266,7 +267,7 @@ def _read_times_sets( @staticmethod def _process_times_info( times_info_file: str, - ) -> Tuple[Iterable[str], Set[str], List[TimesXlMap]]: + ) -> tuple[Iterable[str], set[str], list[TimesXlMap]]: # Read times_info_file and compute dd_table_order: # We output tables in order by categories: set, subset, subsubset, md-set, and parameter with resources.open_text("xl2times.config", times_info_file) as f: @@ -279,7 +280,7 @@ def _process_times_info( if unknown_cats: logger.warning(f"Unknown categories in times-info.json: {unknown_cats}") dd_table_order = chain.from_iterable( - (sorted(cat_to_tables[c]) for c in categories) + sorted(cat_to_tables[c]) for c in categories ) # Compute the set of all attributes, i.e. all entities with category = parameter @@ -320,7 +321,7 @@ def create_mapping(entity): return dd_table_order, attributes, param_mappings @staticmethod - def _read_mappings(filename: str) -> List[TimesXlMap]: + def _read_mappings(filename: str) -> list[TimesXlMap]: """ Function to load mappings from a text file between the excel sheets we use as input and the tables we give as output. The mappings have the following structure: @@ -393,13 +394,13 @@ def _read_mappings(filename: str) -> List[TimesXlMap]: @staticmethod def _read_veda_tags_info( veda_tags_file: str, - ) -> Tuple[ - Dict[Tag, Dict[str, str]], - Dict[Tag, Dict[str, list]], + ) -> tuple[ + dict[Tag, dict[str, str]], + dict[Tag, dict[str, list]], Iterable[Tag], - Dict[Tag, Set[str]], - Dict[Tag, Set[str]], - Dict[Tag, Set[str]], + dict[Tag, set[str]], + dict[Tag, set[str]], + dict[Tag, set[str]], ]: def to_tag(s: str) -> Tag: # The file stores the tag name in lowercase, and without the ~ @@ -483,7 +484,7 @@ def to_tag(s: str) -> Tag: @staticmethod def _read_veda_attr_defaults( veda_attr_defaults_file: str, - ) -> Tuple[Dict[str, Dict[str, list]], Set[str]]: + ) -> tuple[dict[str, dict[str, list]], set[str]]: # Read veda_tags_file with resources.open_text("xl2times.config", veda_attr_defaults_file) as f: defaults = json.load(f) @@ -528,7 +529,7 @@ def _read_veda_attr_defaults( return veda_attr_defaults, attr_aliases @staticmethod - def _read_regions_filter(regions_list: str) -> Set[str]: + def _read_regions_filter(regions_list: str) -> set[str]: if regions_list == "": return set() else: diff --git a/xl2times/excel.py b/xl2times/excel.py index befde29b..de71e320 100644 --- a/xl2times/excel.py +++ b/xl2times/excel.py @@ -1,16 +1,16 @@ +import re +import time + +import numpy from loguru import logger from openpyxl import load_workbook from openpyxl.worksheet.cell_range import CellRange -from typing import Dict, List -import time from pandas.core.frame import DataFrame -import numpy -import re -from . import datatypes -from . import utils + +from . import datatypes, utils -def extract_tables(filename: str) -> List[datatypes.EmbeddedXlTable]: +def extract_tables(filename: str) -> list[datatypes.EmbeddedXlTable]: """ This function calls the extract_table function on each individual table in each worksheet of the given excel file. @@ -68,7 +68,7 @@ def extract_tables(filename: str) -> List[datatypes.EmbeddedXlTable]: def extract_table( tag_row: int, tag_col: int, - uc_sets: Dict[str, str], + uc_sets: dict[str, str], df: DataFrame, sheetname: str, filename: str, diff --git a/xl2times/transforms.py b/xl2times/transforms.py index 18d8c69f..a51e1213 100644 --- a/xl2times/transforms.py +++ b/xl2times/transforms.py @@ -1,22 +1,20 @@ import re import time from collections import defaultdict +from collections.abc import Callable from concurrent.futures import ProcessPoolExecutor from dataclasses import replace from functools import reduce from itertools import groupby from pathlib import Path -from typing import Callable -from typing import Dict, List, Set import pandas as pd from loguru import logger -from more_itertools import locate, one +from more_itertools import locate from pandas.core.frame import DataFrame from tqdm import tqdm -from . import datatypes -from . import utils +from . import datatypes, utils from .utils import max_workers csets_ordered_for_pcg = ["DEM", "MAT", "NRG", "ENV", "FIN"] @@ -50,9 +48,9 @@ def remove_comment_rows( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Remove comment rows from all the tables. Assumes table dataframes are not empty. """ @@ -76,7 +74,7 @@ def remove_comment_rows( def _remove_df_comment_rows( df: pd.DataFrame, - comment_chars: Dict[str, list], + comment_chars: dict[str, list], ) -> None: """ Modify a dataframe in-place by deleting rows with cells starting with symbols @@ -130,9 +128,9 @@ def remove_comment_cols(table: datatypes.EmbeddedXlTable) -> datatypes.EmbeddedX def remove_exreg_cols( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Remove external region columns from all the tables except tradelinks. """ @@ -173,9 +171,9 @@ def remove_table_exreg_cols( def remove_tables_with_formulas( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Return a modified copy of 'tables' where tables with formulas (as identified by an initial '=') have deleted from the list. @@ -198,9 +196,9 @@ def has_formulas(table): def validate_input_tables( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Perform some basic validation (tag names are valid, no duplicate column labels), and remove empty tables (for recognized tags). @@ -236,9 +234,9 @@ def discard(table): def revalidate_input_tables( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Perform further validation of input tables by checking whether required columns are present / non-empty. Remove tables without required columns or if they are empty. @@ -277,9 +275,9 @@ def revalidate_input_tables( def normalize_tags_columns( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Normalize (uppercase) tags and (lowercase) column names. @@ -309,9 +307,9 @@ def normalize(table: datatypes.EmbeddedXlTable) -> datatypes.EmbeddedXlTable: def normalize_column_aliases( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: for table in tables: tag = table.tag.split(":")[0] if tag in config.column_aliases: @@ -329,9 +327,9 @@ def normalize_column_aliases( def include_tables_source( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Add a column specifying source filename to every table """ @@ -346,9 +344,9 @@ def include_table_source(table: datatypes.EmbeddedXlTable): def merge_tables( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: """ Merge all tables in 'tables' with the same table tag, as long as they share the same column field values. Print a warning for those that don't share the same column values. @@ -405,9 +403,9 @@ def merge_tables( def process_flexible_import_tables( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Attempt to process all flexible import tables in 'tables'. The processing includes: - Checking that the table is indeed a flexible import table. If not, return it unmodified. @@ -555,9 +553,9 @@ def process_flexible_import_table( def process_user_constraint_tables( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Process all user constraint tables in 'tables'. The processing includes: - Removing, adding and renaming columns as needed. @@ -675,9 +673,9 @@ def process_user_constraint_table( def generate_uc_properties( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Generate a dataframe containing User Constraint properties """ @@ -742,9 +740,9 @@ def generate_uc_properties( def fill_in_missing_values( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Attempt to fill in missing values for all tables except update tables (as these contain wildcards). How the value is filled in depends on the name of the column the empty values @@ -838,7 +836,7 @@ def fill_in_missing_values_table(table): def expand_rows( - query_columns: Set[str], table: datatypes.EmbeddedXlTable + query_columns: set[str], table: datatypes.EmbeddedXlTable ) -> datatypes.EmbeddedXlTable: """ Expand entries with commas into separate entries in the same column. Do this @@ -876,9 +874,9 @@ def split_by_commas(s): def remove_invalid_values( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Remove all entries of any dataframes that are considered invalid. The rules for allowing an entry can be seen in the 'constraints' dictionary below. @@ -924,9 +922,9 @@ def remove_table_invalid_values( def process_units( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: units_map = { "activity": model.processes["tact"].unique(), "capacity": model.processes["tcap"].unique(), @@ -943,9 +941,9 @@ def process_units( def process_time_periods( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: model.start_year = utils.get_scalar(datatypes.Tag.start_year, tables) active_pdef = utils.get_scalar(datatypes.Tag.active_p_def, tables) df = utils.single_table(tables, datatypes.Tag.time_periods).dataframe.copy() @@ -970,15 +968,15 @@ def process_time_periods( def process_regions( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Read model regions and update model.internal_regions and model.all_regions. Include IMPEXP and MINRNW in model.all_regions (defined by default by Veda). """ - model.all_regions.update((["IMPEXP", "MINRNW"])) + model.all_regions.update(["IMPEXP", "MINRNW"]) # Read region settings region_def = utils.single_table(tables, datatypes.Tag.book_regions_map).dataframe # Harmonise the dataframe @@ -1021,9 +1019,9 @@ def process_regions( def complete_dictionary( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: for k, v in [ ("AllRegions", model.all_regions), ("Regions", model.internal_regions), @@ -1061,9 +1059,9 @@ def complete_dictionary( def capitalise_some_values( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Ensure that all attributes and units are uppercase """ @@ -1088,9 +1086,9 @@ def capitalise_attributes_table(table: datatypes.EmbeddedXlTable): def apply_fixups( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: def apply_fixups_table(table: datatypes.EmbeddedXlTable): tag = datatypes.Tag.fi_t if not table.tag.startswith(tag): @@ -1129,7 +1127,7 @@ def _populate_defaults(dataframe: DataFrame, col_name: str): ] # Populate commodity and other_indexes based on defaults - for col in {"commodity", "other_indexes"}: + for col in ("commodity", "other_indexes"): _populate_defaults(df, col) # Fill other indexes for some attributes @@ -1146,9 +1144,9 @@ def _populate_defaults(dataframe: DataFrame, col_name: str): def generate_commodity_groups( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Generate commodity groups. """ @@ -1289,9 +1287,9 @@ def _set_default_veda_pcg(group): def complete_commodity_groups( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: """ Complete the list of commodity groups. """ @@ -1314,9 +1312,9 @@ def complete_commodity_groups( def generate_trade( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Generate inter-regional exchange topology """ @@ -1405,9 +1403,9 @@ def generate_trade( def fill_in_missing_pcgs( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Fill in missing primary commodity groups in FI_Process tables. Expand primary commodity groups specified in FI_Process tables by a suffix. @@ -1457,9 +1455,9 @@ def expand_pcg_from_suffix(df): def remove_fill_tables( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: # These tables collect data from elsewhere and update the table itself or a region below # The collected data is then presumably consumed via Excel references or vlookups # TODO: For the moment, assume that these tables are up-to-date. We will need a tool to do this. @@ -1474,9 +1472,9 @@ def remove_fill_tables( def process_commodity_emissions( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: result = [] for table in tables: if table.tag != datatypes.Tag.comemi: @@ -1512,9 +1510,9 @@ def process_commodity_emissions( def process_commodities( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Process commodities. """ @@ -1539,9 +1537,9 @@ def process_commodities( def process_processes( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Process processes. """ @@ -1591,9 +1589,9 @@ def process_processes( def process_topology( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Create topology. """ @@ -1653,10 +1651,10 @@ def process_topology( def generate_dummy_processes( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, include_dummy_processes=True, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Define dummy processes and specify default cost data for them to ensure that a TIMES model can always be solved. This covers situations when a commodity cannot be supplied @@ -1709,9 +1707,9 @@ def generate_dummy_processes( def process_tradelinks( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Transform tradelinks to tradelinks_dins """ @@ -1789,9 +1787,9 @@ def process_tradelinks( def process_transform_table_variants( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """Reduces variants of TFM_INS like TFM_INS-TS to TFM_INS.""" def has_no_wildcards(list): @@ -1819,7 +1817,6 @@ def is_year(col_name): ]: # ~TFM_INS-TS: Gather columns whose names are years into a single "Year" column: df = table.dataframe - query_columns = config.query_columns[datatypes.Tag(table.tag)] if "year" in df.columns: raise ValueError(f"TFM_INS-TS table already has Year column: {table}") @@ -1872,9 +1869,9 @@ def is_year(col_name): def process_transform_tables( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: """ Process transform tables. """ @@ -1983,9 +1980,9 @@ def process_transform_tables( def process_transform_availability( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: result = [] dropped = [] for table in tables: @@ -2024,7 +2021,7 @@ def intersect(acc, df): return acc.merge(df) -def get_matching_processes(row: pd.Series, topology: Dict[str, DataFrame]) -> pd.Series: +def get_matching_processes(row: pd.Series, topology: dict[str, DataFrame]) -> pd.Series: matching_processes = None for col, key in process_map.items(): if col in row.index and row[col] is not None: @@ -2039,7 +2036,7 @@ def get_matching_processes(row: pd.Series, topology: Dict[str, DataFrame]) -> pd return matching_processes -def get_matching_commodities(row: pd.Series, topology: Dict[str, DataFrame]): +def get_matching_commodities(row: pd.Series, topology: dict[str, DataFrame]): matching_commodities = None for col, key in commodity_map.items(): if col in row.index and row[col] is not None: @@ -2062,8 +2059,8 @@ def df_indexed_by_col(df, col): def generate_topology_dictionary( - tables: Dict[str, DataFrame], model: datatypes.TimesModel -) -> Dict[str, DataFrame]: + tables: dict[str, DataFrame], model: datatypes.TimesModel +) -> dict[str, DataFrame]: # We need to be able to fetch processes based on any combination of name, description, set, comm-in, or comm-out # So we construct tables whose indices are names, etc. and use pd.filter @@ -2111,9 +2108,9 @@ def generate_topology_dictionary( def process_wildcards( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: tags = [ datatypes.Tag.tfm_comgrp, datatypes.Tag.tfm_ins, @@ -2220,15 +2217,13 @@ def _match_wildcards( def apply_transform_tables( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: """ Include data from transformation tables. """ - topology = generate_topology_dictionary(tables, model) - def query( table: DataFrame, process: str | None, @@ -2391,13 +2386,13 @@ def eval_and_update( def process_time_slices( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: def timeslices_table( table: datatypes.EmbeddedXlTable, regions: list, - result: List[datatypes.EmbeddedXlTable], + result: list[datatypes.EmbeddedXlTable], ): # User-specified timeslices (ordered) user_ts_levels = ["SEASON", "WEEKLY", "DAYNITE"] @@ -2514,9 +2509,9 @@ def timeslices_table( def convert_to_string( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: for key, value in tables.items(): tables[key] = value.map( lambda x: str(int(x)) if isinstance(x, float) and x.is_integer() else str(x) @@ -2526,9 +2521,9 @@ def convert_to_string( def convert_aliases( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: # Ensure TIMES names for all attributes replacement_dict = {} for k, v in config.veda_attr_defaults["aliases"].items(): @@ -2553,9 +2548,9 @@ def convert_aliases( def assign_model_attributes( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: model.attributes = tables[datatypes.Tag.fi_t] if datatypes.Tag.uc_t in tables.keys(): @@ -2566,9 +2561,9 @@ def assign_model_attributes( def resolve_remaining_cgs( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: """ Resolve commodity group names in model.attributes specified as commodity type. Supplement model.commodity_groups with resolved commodity groups. @@ -2618,9 +2613,9 @@ def resolve_remaining_cgs( def fix_topology( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: mapping = {"IN-A": "IN", "OUT-A": "OUT"} model.topology.replace({"io": mapping}, inplace=True) @@ -2630,9 +2625,9 @@ def fix_topology( def complete_processes( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: """ Generate processes based on trade links if not defined elsewhere """ @@ -2697,9 +2692,9 @@ def complete_processes( def apply_final_fixup( config: datatypes.Config, - tables: Dict[str, DataFrame], + tables: dict[str, DataFrame], model: datatypes.TimesModel, -) -> Dict[str, DataFrame]: +) -> dict[str, DataFrame]: veda_process_sets = tables["VedaProcessSets"] reg_com_flows = tables["ProcessTopology"].drop(columns="io") @@ -2789,9 +2784,9 @@ def apply_final_fixup( def expand_rows_parallel( config: datatypes.Config, - tables: List[datatypes.EmbeddedXlTable], + tables: list[datatypes.EmbeddedXlTable], model: datatypes.TimesModel, -) -> List[datatypes.EmbeddedXlTable]: +) -> list[datatypes.EmbeddedXlTable]: query_columns_lists = [ ( config.query_columns[datatypes.Tag(table.tag)] diff --git a/xl2times/utils.py b/xl2times/utils.py index 32ad4abf..9f2a825c 100644 --- a/xl2times/utils.py +++ b/xl2times/utils.py @@ -1,15 +1,16 @@ from __future__ import ( annotations, -) # see https://loguru.readthedocs.io/en/stable/api/type_hints.html#module-autodoc_stub_file.loguru +) +# see https://loguru.readthedocs.io/en/stable/api/type_hints.html#module-autodoc_stub_file.loguru import functools import os import re import sys +from collections.abc import Iterable from dataclasses import replace -from math import log10, floor +from math import floor, log10 from pathlib import Path -from typing import Iterable, List import loguru import numpy @@ -84,7 +85,7 @@ def explode(df, data_columns): return df, names -def single_table(tables: List[datatypes.EmbeddedXlTable], tag: str): +def single_table(tables: list[datatypes.EmbeddedXlTable], tag: str): """ Make sure exactly one table in 'tables' has the given table tag, and return it. If there are none or more than one raise an error. @@ -96,7 +97,7 @@ def single_table(tables: List[datatypes.EmbeddedXlTable], tag: str): return one(table for table in tables if table.tag == tag) -def single_column(tables: List[datatypes.EmbeddedXlTable], tag: str, colname: str): +def single_column(tables: list[datatypes.EmbeddedXlTable], tag: str, colname: str): """ Make sure exactly one table in 'tables' has the given table tag, and return the values for the given column name. If there are none or more than one raise an error. @@ -109,7 +110,7 @@ def single_column(tables: List[datatypes.EmbeddedXlTable], tag: str, colname: st return single_table(tables, tag).dataframe[colname].values -def merge_columns(tables: List[datatypes.EmbeddedXlTable], tag: str, colname: str): +def merge_columns(tables: list[datatypes.EmbeddedXlTable], tag: str, colname: str): """ Return a list with all the values belonging to a column 'colname' from a table with the given tag. @@ -149,8 +150,8 @@ def apply_wildcards( current_list = [] for wildcard in wildcard_list: if wildcard.startswith("-"): - wildcard = wildcard[1:] - regexp = re.compile(wildcard.replace("*", ".*")) + w = wildcard[1:] + regexp = re.compile(w.replace("*", ".*")) current_list = [s for s in current_list if not regexp.match(s)] else: regexp = re.compile(wildcard.replace("*", ".*")) @@ -180,7 +181,7 @@ def missing_value_inherit(df: DataFrame, colname: str): last = value -def get_scalar(table_tag: str, tables: List[datatypes.EmbeddedXlTable]): +def get_scalar(table_tag: str, tables: list[datatypes.EmbeddedXlTable]): table = one(filter(lambda t: t.tag == table_tag, tables)) if table.dataframe.shape[0] != 1 or table.dataframe.shape[1] != 1: raise ValueError("Not scalar table")