Skip to content

Commit

Permalink
Satisfy mypy with pandas-stubs
Browse files Browse the repository at this point in the history
- as_str_list() always returns list.
- Remove "pragma: no cover" on TYPE_CHECKING blocks.
  • Loading branch information
khaeru committed Mar 27, 2024
1 parent c1b5909 commit dbfcd3c
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 50 deletions.
12 changes: 6 additions & 6 deletions ixmp/core/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ixmp.backend import BACKENDS, FIELDS, ItemType
from ixmp.util import as_str_list

if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from ixmp.backend.base import Backend


Expand Down Expand Up @@ -235,11 +235,11 @@ def export_timeseries_data(
"model or scenario."
)
filters = {
"model": as_str_list(model) or [],
"scenario": as_str_list(scenario) or [],
"variable": as_str_list(variable) or [],
"unit": as_str_list(unit) or [],
"region": as_str_list(region) or [],
"model": as_str_list(model),
"scenario": as_str_list(scenario),
"variable": as_str_list(variable),
"unit": as_str_list(unit),
"region": as_str_list(region),
"default": default,
"export_all_runs": export_all_runs,
}
Expand Down
39 changes: 27 additions & 12 deletions ixmp/core/scenario.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import logging
from functools import partialmethod
from itertools import repeat, zip_longest
from itertools import zip_longest
from numbers import Real
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
MutableSequence,
Optional,
Sequence,
Union,
)
from warnings import warn

import pandas as pd
Expand Down Expand Up @@ -213,6 +223,11 @@ def add_set( # noqa: C901
# Get index names for set *name*, may raise KeyError
idx_names = self.idx_names(name)

# List of keys
keys: MutableSequence[Union[str, MutableSequence[str]]] = []
# List of comments for each key
comments: List[Optional[str]] = [comment] if comment else []

# Check arguments and convert to two lists: keys and comments
if len(idx_names) == 0:
# Basic set. Keys must be strings.
Expand All @@ -223,7 +238,7 @@ def add_set( # noqa: C901
)

# Ensure keys is a list of str
keys = as_str_list(key)
keys.extend(as_str_list(key))
else:
# Set defined over 1+ other sets

Expand All @@ -235,37 +250,37 @@ def add_set( # noqa: C901
# DataFrame of key values and perhaps comments
try:
# Pop a 'comment' column off the DataFrame, convert to list
comment = key.pop("comment").to_list()
comments.extend(key.pop("comment"))
except KeyError:
pass

# Convert key to list of list of key values
keys = []
for row in key.to_dict(orient="records"):
keys.append(as_str_list(row, idx_names=idx_names))
elif isinstance(key, dict):
# Dict of lists of key values

# Pop a 'comment' list from the dict
comment = key.pop("comment", None)
comments.extend(key.pop("comment", []))

# Convert to list of list of key values
keys = list(map(as_str_list, zip(*[key[i] for i in idx_names])))
keys.extend(map(as_str_list, zip(*[key[i] for i in idx_names])))
elif isinstance(key[0], str):
# List of key values; wrap
keys = [as_str_list(key)]
keys.append(as_str_list(key))
elif isinstance(key[0], list):
# List of lists of key values; convert to list of list of str
keys = list(map(as_str_list, key))
keys.extend(map(as_str_list, key))
elif isinstance(key, str) and len(idx_names) == 1:
# Bare key given for a 1D set; wrap for convenience
keys = [[key]]
keys.append([key])
else:
# Other, invalid value
raise ValueError(key)

# Process comments to a list of str, or let them all be None
comments = as_str_list(comment) if comment else repeat(None, len(keys))
if not comments:
comments = [None] * len(keys)

# Combine iterators to tuples. If the lengths are mismatched, the sentinel
# value 'False' is filled in
Expand Down Expand Up @@ -569,7 +584,7 @@ def add_par(
keys = [keys]

# Use the same value for all keys
values = [float(value)] * len(keys)
values: List[Any] = [float(value)] * len(keys)
else:
# Multiple values
values = value
Expand Down
31 changes: 15 additions & 16 deletions ixmp/core/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,16 +365,15 @@ def add_timeseries(
df.columns = df.columns.astype(int)

# Identify columns to drop
to_drop = set()
if year_lim[0]:
to_drop |= set(filter(lambda y: y < year_lim[0], df.columns))
if year_lim[1]:
to_drop |= set(filter(lambda y: y > year_lim[1], df.columns))
def predicate(y: Any) -> bool:
return y < (year_lim[0] or y) or (year_lim[1] or y) < y

df.drop(to_drop, axis=1, inplace=True)
df.drop(list(filter(predicate, df.columns)), axis=1, inplace=True)

# Add one time series per row
for (r, v, u, sa), data in df.iterrows():
for key, data in df.iterrows():
assert isinstance(key, tuple)
r, v, u, sa = key
# Values as float; exclude NA
self._backend(
"set_data", r, v, data.astype(float).dropna().to_dict(), u, sa, meta
Expand Down Expand Up @@ -429,19 +428,16 @@ def timeseries(
year if isinstance(year, Sequence) else [] if year is None else [year],
),
columns=FIELDS["ts_get"],
)
df["model"] = self.model
df["scenario"] = self.scenario
).assign(model=self.model, scenario=self.scenario)

# drop `subannual` column if not requested (False) or required ('auto')
if subannual is not True:
has_subannual = not all(df["subannual"] == "Year")
if subannual is False and has_subannual:
msg = (
"timeseries data has subannual values, ",
"use `subannual=True or 'auto'`",
raise ValueError(
"timeseries data has subannual values, use `subannual=True or "
"'auto'`"
)
raise ValueError(msg)
if not has_subannual:
df.drop("subannual", axis=1, inplace=True)

Expand All @@ -450,8 +446,11 @@ def timeseries(
index = IAMC_IDX
if "subannual" in df.columns:
index = index + ["subannual"]
df = df.pivot_table(index=index, columns="year")["value"].reset_index()
df.columns.names = [None]
df = (
df.pivot_table(index=index, columns="year")["value"]
.reset_index()
.rename_axis(columns=None)
)

return df

Expand Down
11 changes: 6 additions & 5 deletions ixmp/tests/core/test_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import re
from sys import getrefcount
from typing import TYPE_CHECKING
from weakref import getweakrefcount

import pandas as pd
Expand All @@ -14,6 +15,9 @@
from ixmp.backend import FIELDS
from ixmp.testing import DATA, assert_logs, models

if TYPE_CHECKING:
from ixmp import Platform


class TestPlatform:
def test_init(self):
Expand Down Expand Up @@ -68,14 +72,11 @@ def test_scenario_list(mp):
assert scenario[0] == "Hitchhiker"


def test_export_timeseries_data(test_mp, tmp_path):
def test_export_timeseries_data(mp: "Platform", tmp_path) -> None:
path = tmp_path / "export.csv"
test_mp.export_timeseries_data(
path, model="Douglas Adams", unit="???", region="World"
)
mp.export_timeseries_data(path, model="Douglas Adams", unit="???", region="World")

obs = pd.read_csv(path, index_col=False, header=0)

exp = (
DATA[0]
.assign(**models["h2g2"], version=1, subannual="Year", meta=0)
Expand Down
6 changes: 4 additions & 2 deletions ixmp/tests/report/test_operator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import re
from functools import partial
from typing import cast

import genno
import pandas as pd
Expand Down Expand Up @@ -123,7 +124,8 @@ def test_update_scenario(caplog, test_mp) -> None:
c.add("target", scen)

# Create a pd.DataFrame suitable for Scenario.add_par()
data = dantzig_data["d"].query("j == 'chicago'").assign(j="toronto")
d = cast(pd.DataFrame, dantzig_data["d"])
data = d.query("j == 'chicago'").assign(j="toronto")
data["value"] += 1.0

# Add to the Reporter
Expand All @@ -141,7 +143,7 @@ def test_update_scenario(caplog, test_mp) -> None:
assert len(scen.par("d")) == N_before + len(data)

# Modify the data
data = pd.concat([dantzig_data["d"], data]).reset_index(drop=True)
data = pd.concat([d, data]).reset_index(drop=True)
data["value"] *= 2.0

# Convert to a Quantity object and re-add
Expand Down
16 changes: 8 additions & 8 deletions ixmp/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Dict,
Any,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -57,7 +57,7 @@ def logger():
return logging.getLogger("ixmp")


def as_str_list(arg, idx_names: Optional[Iterable[str]] = None):
def as_str_list(arg, idx_names: Optional[Iterable[str]] = None) -> List[str]:
"""Convert various `arg` to list of str.
Several types of arguments are handled:
Expand All @@ -70,11 +70,11 @@ def as_str_list(arg, idx_names: Optional[Iterable[str]] = None):
"""
if arg is None:
return None
return []
elif idx_names is None:
# arg must be iterable
# NB narrower ABC Sequence does not work here; e.g. test_excel_io()
# fails via Scenario.add_set().
# NB narrower ABC Sequence does not work here; e.g. test_excel_io() fails via
# Scenario.add_set().
if isinstance(arg, Iterable) and not isinstance(arg, str):
return list(map(str, arg))
else:
Expand Down Expand Up @@ -120,7 +120,7 @@ def diff(a, b, filters=None) -> Iterator[Tuple[str, pd.DataFrame]]:
]
# State variables for loop
name = ["", ""]
data: List[pd.DataFrame] = [None, None]
data: List[pd.DataFrame] = [pd.DataFrame(), pd.DataFrame()]

# Elements for first iteration
name[0], data[0] = next(items[0])
Expand All @@ -146,7 +146,7 @@ def diff(a, b, filters=None) -> Iterator[Tuple[str, pd.DataFrame]]:

# Either merge on remaining columns; or, for scalars, on the indices
on = sorted(set(left.columns) - {"value", "unit"})
on_arg: Dict[str, object] = (
on_arg: Mapping[str, Any] = (
dict(on=on) if on else dict(left_index=True, right_index=True)
)

Expand All @@ -173,7 +173,7 @@ def diff(a, b, filters=None) -> Iterator[Tuple[str, pd.DataFrame]]:
except StopIteration:
# No more data for this iterator.
# Use "~" because it sorts after all ASCII characters
name[i], data[i] = "~ end", None
name[i], data[i] = "~ end", pd.DataFrame()

if name[0] == name[1] == "~ end":
break
Expand Down
2 changes: 1 addition & 1 deletion ixmp/util/sphinx_linkcode_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from sphinx.util import logging

if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
import sphinx.application

log = logging.getLogger(__name__)
Expand Down

0 comments on commit dbfcd3c

Please sign in to comment.