Skip to content

Commit

Permalink
Update .transport.operator.broadcast{,_wildcard}()
Browse files Browse the repository at this point in the history
  • Loading branch information
khaeru committed Nov 18, 2024
1 parent 99ed0fa commit 48b75de
Showing 1 changed file with 48 additions and 13 deletions.
61 changes: 48 additions & 13 deletions message_ix_models/model/transport/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from message_ix_models.report.util import as_quantity
from message_ix_models.util import (
MappingAdapter,
broadcast,
datetime_now_with_tz,
nodes_ex_world,
show_versions,
Expand Down Expand Up @@ -203,16 +202,49 @@ def broadcast_advance(data: "AnyQuantity", y0: int, config: dict) -> "AnyQuantit
return result


def broadcast_n(qty: "AnyQuantity", n: List[str], *, dim: str = "n") -> "AnyQuantity":
"""Broadcast over nodes `n` along dimension `dim`."""
existing = sorted(qty.coords[dim].data)
missing = set(n) - set(existing)
def broadcast(q1: "AnyQuantity", q2: "AnyQuantity") -> "AnyQuantity":
import numpy as np

if missing:
n_map = [(n_, n_) for n_ in existing] + [("*", n_) for n_ in missing]
return MappingAdapter({dim: n_map})(qty)
else:
return qty
# Squeeze dimensions of q1 that are (a) in q2 and (b) contain only NaN or None
# labels
squeezed = q1
for d in q2.dims:
if set(q1.coords[d].data) <= {np.nan, None}:
squeezed = squeezed.squeeze(dim=d)

# TODO Use the following once supported by genno
# squeezed = q1.squeeze(
# dim=[d for d in q2.dims if set(q1.coords[d].data) <= {np.nan, None}]
# )

return squeezed * q2


def broadcast_wildcard(
qty: "AnyQuantity", coords: List[str], *, dim: str = "n"
) -> "AnyQuantity":
"""Broadcast over coordinates `coords` along dimension `dim`.
Any missing labels in `coords` are populated using values of `qty` that have the
‘wildcard’ label "*" for `dim`.
"""
# Identify existing, non-wildcard labels along `dim`
existing = set(qty.coords[dim].data) - {"*"}
# Identify missing labels along `dim`
missing = sorted(set(coords) - existing)

if not missing:
return qty # Nothing to do; `qty` is already complete

# Construct a MappingAdapter:
# - Each existing label (whether in ``) mapped to themselves.
# - "*" mapping to each missing label.
adapt = MappingAdapter(
{dim: [(x, x) for x in sorted(existing)] + [("*", x) for x in missing]}
)

# Apply the adapter to `qty`
return adapt(qty)


def broadcast_t_c_l(
Expand Down Expand Up @@ -455,6 +487,8 @@ def factor_fv(n: List[str], y: List[int], config: dict) -> "AnyQuantity":
Otherwise, the value is 1.0 for every (`n`, `y`).
"""
from message_ix_models.util import broadcast

# Empty data frame
df = pd.DataFrame(columns=["value"], index=pd.Index(y, name="y"))

Expand Down Expand Up @@ -546,6 +580,8 @@ def factor_pdt(n: List[str], y: List[int], t: List[str], config: dict) -> "AnyQu
Otherwise, the value is 1.0 for every (`n`, `t`, `y`).
"""
from message_ix_models.util import broadcast

# Empty data frame
df = pd.DataFrame(columns=t, index=pd.Index(y, name="y"))

Expand Down Expand Up @@ -1142,8 +1178,7 @@ def _add_transport_data(func, c: "Computer", name: str, *, key) -> None:
def transport_data(*args):
"""No action.
This exists to connect :func:`._add_transport_data` to
:meth:`genno.Computer.add`.
This exists to connect :func:`._add_transport_data` to :meth:`genno.Computer.add`.
"""
pass # pragma: no cover

Expand All @@ -1156,7 +1191,7 @@ def transport_check(scenario: "Scenario", ACT: "AnyQuantity") -> pd.Series:
checks = {}

# Correct number of outputs
ACT_lf = ACT.sel(t=["transport freight load factor", "transport pax load factor"])
ACT_lf = ACT.sel(t=["transport f load factor", "transport pax load factor"])
checks["'transport * load factor' technologies are active"] = len(
ACT_lf
) == 2 * len(info.Y) * (len(info.N) - 1)
Expand Down

0 comments on commit 48b75de

Please sign in to comment.