From fea126e79ed64816fad0aea3d4e845616bb2961d Mon Sep 17 00:00:00 2001 From: Paul Natsuo Kishimoto Date: Mon, 12 Aug 2024 10:51:15 +0200 Subject: [PATCH] Add .models.shift_period; tests --- message_ix/models.py | 64 ++++++++++++++++++++++++++++++++- message_ix/tests/test_models.py | 21 ++++++++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/message_ix/models.py b/message_ix/models.py index 5c9253f71..1717a699f 100644 --- a/message_ix/models.py +++ b/message_ix/models.py @@ -4,7 +4,7 @@ from dataclasses import InitVar, dataclass, field from functools import partial from pathlib import Path -from typing import Mapping, MutableMapping, Optional, Tuple +from typing import TYPE_CHECKING, Mapping, MutableMapping, Optional, Tuple from warnings import warn import ixmp.model.gams @@ -12,6 +12,9 @@ from ixmp.backend import ItemType from ixmp.util import maybe_check_out, maybe_commit +if TYPE_CHECKING: + from .core import Scenario + log = logging.getLogger(__name__) #: Solver options used by :meth:`.Scenario.solve`. @@ -1009,3 +1012,62 @@ def __init__(self, *args, **kwargs): def initialize(cls, scenario, with_data=False): MESSAGE.initialize(scenario) MACRO.initialize(scenario, with_data) + + +def shift_period(scenario: "Scenario", y0: int) -> None: + """Shift the first period of the model horizon.""" + from ixmp.backend.jdbc import JDBCBackend + + # Retrieve existing cat_year information, including the current 'firstmodelyear' + cat_year = scenario.set("cat_year") + y0_pre = cat_year.query("type_year == 'firstmodelyear'")["year"].item() + + if y0 == y0_pre: + log.info(f"First model period is already {y0!r}") + return + elif y0 < y0_pre: + raise NotImplementedError( + f"Shift first model period *earlier*, from {y0_pre!r} -> {y0}" + ) + + # Periods to be shifted from within to before the model horizon + periods = list( + filter(lambda y: y0_pre <= y < y0, map(int, sorted(cat_year["year"].unique()))) + ) + log.info(f"Shift data for period(s): {periods}") + + # Handle historical_* parameters for which the dimensions are a subset of the + # corresponding variable's dimensions + data = {} + for var_name, par_name, filter_dim in ( + ("ACT", "historical_activity", "year_act"), + ("CAP_NEW", "historical_new_capacity", "year_vtg"), + ("EXT", "historical_extraction", "year"), + ("GDP", "historical_gdp", "year"), + ("LAND", "historical_land", "year"), + ): + # - Filter data for `var_name` along the `filter_dim`, keeping only the periods + # to be shifted. + # - Drop the marginal column; rename the level column to "value". + # - Group according to the dimensions of the target `par_name`. + # - Sum within groups. + # - Restore index columns. + data[par_name] = ( + scenario.var(var_name, filters={filter_dim: periods}) + .drop("mrg", axis=1) + .rename(columns={"lvl": "value"}) + .groupby(list(MESSAGE.items[par_name].dims)) + .sum()["value"] + .reset_index() + ) + + # TODO Handle "EMISS:n-e-type_tec-y" → + # "historical_emission:n-type_emission-type_tec-type_year", in which dimension names + # are changed + + # TODO Adjust cat_year + + if isinstance(scenario.platform._backend, JDBCBackend): + raise NotImplementedError("Cannot set variable values with JDBCBackend") + + # TODO Store new data diff --git a/message_ix/tests/test_models.py b/message_ix/tests/test_models.py index f8c6dccd6..331ad2fde 100644 --- a/message_ix/tests/test_models.py +++ b/message_ix/tests/test_models.py @@ -3,7 +3,8 @@ import ixmp import pytest -from message_ix.models import MESSAGE, MESSAGE_MACRO +from message_ix.models import MESSAGE, MESSAGE_MACRO, shift_period +from message_ix.testing import make_dantzig def test_initialize(test_mp): @@ -52,3 +53,21 @@ class _MM(MESSAGE_MACRO): "--MAX_ITERATION=100", ] assert all(e in mm.solve_args for e in expected) + + +@pytest.mark.parametrize( + "y0", + ( + # Not implemented: shifting to an earlier period + pytest.param(1962, marks=pytest.mark.xfail(raises=NotImplementedError)), + # Does nothing + 1963, + # Not implemented with ixmp.JDBCBackend + pytest.param(1964, marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param(1965, marks=pytest.mark.xfail(raises=NotImplementedError)), + ), +) +def test_shift_period(test_mp, y0): + s = make_dantzig(test_mp, solve=True, multi_year=True) + + shift_period(s, y0)