diff --git a/ixmp/core/platform.py b/ixmp/core/platform.py index 8218729dd..10b770b27 100644 --- a/ixmp/core/platform.py +++ b/ixmp/core/platform.py @@ -9,7 +9,7 @@ from ixmp.backend import BACKENDS, FIELDS, ItemType from ixmp.util import as_str_list -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: from ixmp.backend.base import Backend @@ -235,11 +235,11 @@ def export_timeseries_data( "model or scenario." ) filters = { - "model": as_str_list(model) or [], - "scenario": as_str_list(scenario) or [], - "variable": as_str_list(variable) or [], - "unit": as_str_list(unit) or [], - "region": as_str_list(region) or [], + "model": as_str_list(model), + "scenario": as_str_list(scenario), + "variable": as_str_list(variable), + "unit": as_str_list(unit), + "region": as_str_list(region), "default": default, "export_all_runs": export_all_runs, } diff --git a/ixmp/core/scenario.py b/ixmp/core/scenario.py index 5201e634f..b3b8a9a23 100644 --- a/ixmp/core/scenario.py +++ b/ixmp/core/scenario.py @@ -1,10 +1,20 @@ import logging from functools import partialmethod -from itertools import repeat, zip_longest +from itertools import zip_longest from numbers import Real from os import PathLike from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + MutableSequence, + Optional, + Sequence, + Union, +) from warnings import warn import pandas as pd @@ -213,6 +223,11 @@ def add_set( # noqa: C901 # Get index names for set *name*, may raise KeyError idx_names = self.idx_names(name) + # List of keys + keys: MutableSequence[Union[str, MutableSequence[str]]] = [] + # List of comments for each key + comments: List[Optional[str]] = [comment] if comment else [] + # Check arguments and convert to two lists: keys and comments if len(idx_names) == 0: # Basic set. Keys must be strings. @@ -223,7 +238,7 @@ def add_set( # noqa: C901 ) # Ensure keys is a list of str - keys = as_str_list(key) + keys.extend(as_str_list(key)) else: # Set defined over 1+ other sets @@ -235,37 +250,37 @@ def add_set( # noqa: C901 # DataFrame of key values and perhaps comments try: # Pop a 'comment' column off the DataFrame, convert to list - comment = key.pop("comment").to_list() + comments.extend(key.pop("comment")) except KeyError: pass # Convert key to list of list of key values - keys = [] for row in key.to_dict(orient="records"): keys.append(as_str_list(row, idx_names=idx_names)) elif isinstance(key, dict): # Dict of lists of key values # Pop a 'comment' list from the dict - comment = key.pop("comment", None) + comments.extend(key.pop("comment", [])) # Convert to list of list of key values - keys = list(map(as_str_list, zip(*[key[i] for i in idx_names]))) + keys.extend(map(as_str_list, zip(*[key[i] for i in idx_names]))) elif isinstance(key[0], str): # List of key values; wrap - keys = [as_str_list(key)] + keys.append(as_str_list(key)) elif isinstance(key[0], list): # List of lists of key values; convert to list of list of str - keys = list(map(as_str_list, key)) + keys.extend(map(as_str_list, key)) elif isinstance(key, str) and len(idx_names) == 1: # Bare key given for a 1D set; wrap for convenience - keys = [[key]] + keys.append([key]) else: # Other, invalid value raise ValueError(key) # Process comments to a list of str, or let them all be None - comments = as_str_list(comment) if comment else repeat(None, len(keys)) + if not comments: + comments = [None] * len(keys) # Combine iterators to tuples. If the lengths are mismatched, the sentinel # value 'False' is filled in @@ -569,7 +584,7 @@ def add_par( keys = [keys] # Use the same value for all keys - values = [float(value)] * len(keys) + values: List[Any] = [float(value)] * len(keys) else: # Multiple values values = value diff --git a/ixmp/core/timeseries.py b/ixmp/core/timeseries.py index 90c2fbf78..3eb6125ea 100644 --- a/ixmp/core/timeseries.py +++ b/ixmp/core/timeseries.py @@ -365,16 +365,15 @@ def add_timeseries( df.columns = df.columns.astype(int) # Identify columns to drop - to_drop = set() - if year_lim[0]: - to_drop |= set(filter(lambda y: y < year_lim[0], df.columns)) - if year_lim[1]: - to_drop |= set(filter(lambda y: y > year_lim[1], df.columns)) + def predicate(y: Any) -> bool: + return y < (year_lim[0] or y) or (year_lim[1] or y) < y - df.drop(to_drop, axis=1, inplace=True) + df.drop(list(filter(predicate, df.columns)), axis=1, inplace=True) # Add one time series per row - for (r, v, u, sa), data in df.iterrows(): + for key, data in df.iterrows(): + assert isinstance(key, tuple) + r, v, u, sa = key # Values as float; exclude NA self._backend( "set_data", r, v, data.astype(float).dropna().to_dict(), u, sa, meta @@ -429,19 +428,16 @@ def timeseries( year if isinstance(year, Sequence) else [] if year is None else [year], ), columns=FIELDS["ts_get"], - ) - df["model"] = self.model - df["scenario"] = self.scenario + ).assign(model=self.model, scenario=self.scenario) # drop `subannual` column if not requested (False) or required ('auto') if subannual is not True: has_subannual = not all(df["subannual"] == "Year") if subannual is False and has_subannual: - msg = ( - "timeseries data has subannual values, ", - "use `subannual=True or 'auto'`", + raise ValueError( + "timeseries data has subannual values, use `subannual=True or " + "'auto'`" ) - raise ValueError(msg) if not has_subannual: df.drop("subannual", axis=1, inplace=True) @@ -450,8 +446,11 @@ def timeseries( index = IAMC_IDX if "subannual" in df.columns: index = index + ["subannual"] - df = df.pivot_table(index=index, columns="year")["value"].reset_index() - df.columns.names = [None] + df = ( + df.pivot_table(index=index, columns="year")["value"] + .reset_index() + .rename_axis(columns=None) + ) return df diff --git a/ixmp/tests/core/test_platform.py b/ixmp/tests/core/test_platform.py index a0fe8c1e1..bca620420 100644 --- a/ixmp/tests/core/test_platform.py +++ b/ixmp/tests/core/test_platform.py @@ -3,6 +3,7 @@ import logging import re from sys import getrefcount +from typing import TYPE_CHECKING from weakref import getweakrefcount import pandas as pd @@ -14,6 +15,9 @@ from ixmp.backend import FIELDS from ixmp.testing import DATA, assert_logs, models +if TYPE_CHECKING: + from ixmp import Platform + class TestPlatform: def test_init(self): @@ -68,14 +72,11 @@ def test_scenario_list(mp): assert scenario[0] == "Hitchhiker" -def test_export_timeseries_data(test_mp, tmp_path): +def test_export_timeseries_data(mp: "Platform", tmp_path) -> None: path = tmp_path / "export.csv" - test_mp.export_timeseries_data( - path, model="Douglas Adams", unit="???", region="World" - ) + mp.export_timeseries_data(path, model="Douglas Adams", unit="???", region="World") obs = pd.read_csv(path, index_col=False, header=0) - exp = ( DATA[0] .assign(**models["h2g2"], version=1, subannual="Year", meta=0) diff --git a/ixmp/tests/report/test_operator.py b/ixmp/tests/report/test_operator.py index 0c30a6a9a..9e653c3b8 100644 --- a/ixmp/tests/report/test_operator.py +++ b/ixmp/tests/report/test_operator.py @@ -1,6 +1,7 @@ import logging import re from functools import partial +from typing import cast import genno import pandas as pd @@ -123,7 +124,8 @@ def test_update_scenario(caplog, test_mp) -> None: c.add("target", scen) # Create a pd.DataFrame suitable for Scenario.add_par() - data = dantzig_data["d"].query("j == 'chicago'").assign(j="toronto") + d = cast(pd.DataFrame, dantzig_data["d"]) + data = d.query("j == 'chicago'").assign(j="toronto") data["value"] += 1.0 # Add to the Reporter @@ -141,7 +143,7 @@ def test_update_scenario(caplog, test_mp) -> None: assert len(scen.par("d")) == N_before + len(data) # Modify the data - data = pd.concat([dantzig_data["d"], data]).reset_index(drop=True) + data = pd.concat([d, data]).reset_index(drop=True) data["value"] *= 2.0 # Convert to a Quantity object and re-add diff --git a/ixmp/util/__init__.py b/ixmp/util/__init__.py index 8c1214c78..34968bd16 100644 --- a/ixmp/util/__init__.py +++ b/ixmp/util/__init__.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import ( TYPE_CHECKING, - Dict, + Any, Iterable, Iterator, List, @@ -57,7 +57,7 @@ def logger(): return logging.getLogger("ixmp") -def as_str_list(arg, idx_names: Optional[Iterable[str]] = None): +def as_str_list(arg, idx_names: Optional[Iterable[str]] = None) -> List[str]: """Convert various `arg` to list of str. Several types of arguments are handled: @@ -70,11 +70,11 @@ def as_str_list(arg, idx_names: Optional[Iterable[str]] = None): """ if arg is None: - return None + return [] elif idx_names is None: # arg must be iterable - # NB narrower ABC Sequence does not work here; e.g. test_excel_io() - # fails via Scenario.add_set(). + # NB narrower ABC Sequence does not work here; e.g. test_excel_io() fails via + # Scenario.add_set(). if isinstance(arg, Iterable) and not isinstance(arg, str): return list(map(str, arg)) else: @@ -120,7 +120,7 @@ def diff(a, b, filters=None) -> Iterator[Tuple[str, pd.DataFrame]]: ] # State variables for loop name = ["", ""] - data: List[pd.DataFrame] = [None, None] + data: List[pd.DataFrame] = [pd.DataFrame(), pd.DataFrame()] # Elements for first iteration name[0], data[0] = next(items[0]) @@ -146,7 +146,7 @@ def diff(a, b, filters=None) -> Iterator[Tuple[str, pd.DataFrame]]: # Either merge on remaining columns; or, for scalars, on the indices on = sorted(set(left.columns) - {"value", "unit"}) - on_arg: Dict[str, object] = ( + on_arg: Mapping[str, Any] = ( dict(on=on) if on else dict(left_index=True, right_index=True) ) @@ -173,7 +173,7 @@ def diff(a, b, filters=None) -> Iterator[Tuple[str, pd.DataFrame]]: except StopIteration: # No more data for this iterator. # Use "~" because it sorts after all ASCII characters - name[i], data[i] = "~ end", None + name[i], data[i] = "~ end", pd.DataFrame() if name[0] == name[1] == "~ end": break diff --git a/ixmp/util/sphinx_linkcode_github.py b/ixmp/util/sphinx_linkcode_github.py index d3987082e..86cd6209f 100644 --- a/ixmp/util/sphinx_linkcode_github.py +++ b/ixmp/util/sphinx_linkcode_github.py @@ -10,7 +10,7 @@ from sphinx.util import logging -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: import sphinx.application log = logging.getLogger(__name__)