diff --git a/ixmp4/core/optimization/scalar.py b/ixmp4/core/optimization/scalar.py index b4262d80..5c6e43fe 100644 --- a/ixmp4/core/optimization/scalar.py +++ b/ixmp4/core/optimization/scalar.py @@ -8,6 +8,7 @@ from ixmp4.data.abstract import Docs as DocsModel from ixmp4.data.abstract import Run from ixmp4.data.abstract import Scalar as ScalarModel +from ixmp4.data.abstract import Unit as UnitModel class Scalar(BaseModelFacade): @@ -38,21 +39,19 @@ def value(self, value: float): ) @property - def unit(self): + def unit(self) -> UnitModel: """Associated unit.""" return self._model.unit @unit.setter - def unit(self, unit: str | Unit): - if isinstance(unit, Unit): - unit = unit - else: - unit_model = self.backend.units.get(unit) - unit = Unit(_backend=self.backend, _model=unit_model) + def unit(self, value: str | Unit): + if isinstance(value, str): + unit_model = self.backend.units.get(value) + value = Unit(_backend=self.backend, _model=unit_model) self._model = self.backend.optimization.scalars.update( id=self._model.id, value=self._model.value, - unit_id=unit.id, + unit_id=value.id, ) @property diff --git a/ixmp4/data/abstract/optimization/scalar.py b/ixmp4/data/abstract/optimization/scalar.py index e332d168..3e910abb 100644 --- a/ixmp4/data/abstract/optimization/scalar.py +++ b/ixmp4/data/abstract/optimization/scalar.py @@ -6,6 +6,7 @@ from .. import base from ..docs import DocsRepository +from ..unit import Unit class Scalar(base.BaseModel, Protocol): @@ -17,7 +18,7 @@ class Scalar(base.BaseModel, Protocol): """Value of the Scalar.""" unit__id: types.Integer "Foreign unique integer id of a unit." - unit: types.Mapped + unit: types.Mapped[Unit] "Associated unit." run__id: types.Integer "Foreign unique integer id of a run." diff --git a/ixmp4/data/auth/context.py b/ixmp4/data/auth/context.py index 4ff30672..250695b8 100644 --- a/ixmp4/data/auth/context.py +++ b/ixmp4/data/auth/context.py @@ -47,7 +47,7 @@ def apply(self, access_type: str, exc: db.sql.Select) -> db.sql.Select: if utils.is_joined(exc, Model): perms = self.tabulate_permissions() if perms.empty: - return exc.where(False) # type: ignore + return exc.where(db.false()) if access_type == "edit": perms = perms.where(perms["access_type"] == "EDIT").dropna() # `*` is used as wildcard in permission logic, replaced by sql-wildcard `%` diff --git a/ixmp4/db/__init__.py b/ixmp4/db/__init__.py index 90a3d67d..aec24061 100644 --- a/ixmp4/db/__init__.py +++ b/ixmp4/db/__init__.py @@ -42,6 +42,7 @@ UniqueConstraint, delete, exists, + false, func, insert, or_, diff --git a/tests/core/test_optimization_indexset.py b/tests/core/test_optimization_indexset.py index b96c33f4..109183f8 100644 --- a/tests/core/test_optimization_indexset.py +++ b/tests/core/test_optimization_indexset.py @@ -60,10 +60,12 @@ def test_get_indexset(self, platform: ixmp4.Platform): def test_add_data(self, platform: ixmp4.Platform): run = platform.runs.create("Model", "Scenario") - test_data = ["foo", "bar"] + # See https://mypy.readthedocs.io/en/stable/common_issues.html#variance for why + # a type hint is required here + test_data: list[float | int | str] = ["foo", "bar"] indexset_1 = run.optimization.indexsets.create("Indexset 1") - indexset_1.add(test_data) # type: ignore - run.optimization.indexsets.create("Indexset 2").add(test_data) # type: ignore + indexset_1.add(test_data) + run.optimization.indexsets.create("Indexset 2").add(test_data) indexset_2 = run.optimization.indexsets.get("Indexset 2") assert indexset_1.data == indexset_2.data diff --git a/tests/core/test_optimization_scalar.py b/tests/core/test_optimization_scalar.py index 470aa078..7f0ff526 100644 --- a/tests/core/test_optimization_scalar.py +++ b/tests/core/test_optimization_scalar.py @@ -52,9 +52,6 @@ def test_create_scalar(self, platform: ixmp4.Platform): "Scalar 1", value=20, unit=unit.name ) - with pytest.raises(TypeError): - _ = run.optimization.scalars.create("Scalar 2") # type: ignore - scalar_2 = run.optimization.scalars.create("Scalar 2", value=20, unit=unit) assert scalar_1.id != scalar_2.id @@ -86,7 +83,9 @@ def test_update_scalar(self, platform: ixmp4.Platform): _ = run.optimization.scalars.create("Scalar", value=20, unit=unit2.name) scalar.value = 30 - scalar.unit = "Test Unit" + # At the moment, mypy doesn't allow for different types in property getter and + # setter, see https://github.com/python/mypy/issues/3004 + scalar.unit = "Test Unit" # type: ignore # NOTE: doesn't work for some reason (but doesn't either for e.g. model.get()) # assert scalar == run.optimization.scalars.get("Scalar") result = run.optimization.scalars.get("Scalar") @@ -94,7 +93,7 @@ def test_update_scalar(self, platform: ixmp4.Platform): assert scalar.id == result.id assert scalar.name == result.name assert scalar.value == result.value == 30 - assert scalar.unit.id == result.unit.id == 1 # type: ignore + assert scalar.unit.id == result.unit.id == 1 def test_list_scalars(self, platform: ixmp4.Platform): run = platform.runs.create("Model", "Scenario") diff --git a/tests/core/test_optimization_table.py b/tests/core/test_optimization_table.py index 93f0c03e..c148a8f9 100644 --- a/tests/core/test_optimization_table.py +++ b/tests/core/test_optimization_table.py @@ -43,7 +43,7 @@ def test_create_table(self, platform: ixmp4.Platform): # Test normal creation indexset, indexset_2 = tuple( - IndexSet(_backend=platform.backend, _model=model) # type: ignore + IndexSet(_backend=platform.backend, _model=model) for model in create_indexsets_for_run(platform=platform, run_id=run.id) ) table = run.optimization.tables.create( @@ -121,7 +121,7 @@ def test_get_table(self, platform: ixmp4.Platform): def test_table_add_data(self, platform: ixmp4.Platform): run = platform.runs.create("Model", "Scenario") indexset, indexset_2 = tuple( - IndexSet(_backend=platform.backend, _model=model) # type: ignore + IndexSet(_backend=platform.backend, _model=model) for model in create_indexsets_for_run(platform=platform, run_id=run.id) ) indexset.add(data=["foo", "bar", ""]) @@ -266,7 +266,7 @@ def test_list_tables(self, platform: ixmp4.Platform): def test_tabulate_table(self, platform: ixmp4.Platform): run = platform.runs.create("Model", "Scenario") indexset, indexset_2 = tuple( - IndexSet(_backend=platform.backend, _model=model) # type: ignore + IndexSet(_backend=platform.backend, _model=model) for model in create_indexsets_for_run(platform=platform, run_id=run.id) ) table = run.optimization.tables.create( diff --git a/tests/data/test_meta.py b/tests/data/test_meta.py index 89ee08a6..75a1b9da 100644 --- a/tests/data/test_meta.py +++ b/tests/data/test_meta.py @@ -1,3 +1,5 @@ +from typing import Literal + import pandas as pd import pytest @@ -7,7 +9,18 @@ from ..utils import assert_unordered_equality -TEST_ENTRIES = [ +TEST_ENTRIES: list[ + tuple[ + str, + bool | float | int | str, + Literal[ + RunMetaEntry.Type.BOOL, + RunMetaEntry.Type.FLOAT, + RunMetaEntry.Type.INT, + RunMetaEntry.Type.STR, + ], + ] +] = [ ("Boolean", True, RunMetaEntry.Type.BOOL), ("Float", 0.2, RunMetaEntry.Type.FLOAT), ("Integer", 1, RunMetaEntry.Type.INT), @@ -26,7 +39,7 @@ def test_create_get_entry(self, platform: ixmp4.Platform): run.set_as_default() for key, value, type in TEST_ENTRIES: - entry = platform.backend.meta.create(run.id, key, value) # type:ignore + entry = platform.backend.meta.create(run.id, key, value) assert entry.key == key assert entry.value == value assert entry.type == type diff --git a/tests/data/test_optimization_indexset.py b/tests/data/test_optimization_indexset.py index 7ae99608..7de77d3e 100644 --- a/tests/data/test_optimization_indexset.py +++ b/tests/data/test_optimization_indexset.py @@ -126,14 +126,16 @@ def test_tabulate_indexsets(self, platform: ixmp4.Platform): ) def test_add_data(self, platform: ixmp4.Platform): - test_data = ["foo", "bar"] + # See https://mypy.readthedocs.io/en/stable/common_issues.html#variance for why + # a type hint is required here + test_data: list[float | int | str] = ["foo", "bar"] run = platform.backend.runs.create("Model", "Scenario") indexset_1, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id ) platform.backend.optimization.indexsets.add_data( indexset_id=indexset_1.id, - data=test_data, # type: ignore + data=test_data, ) indexset_1 = platform.backend.optimization.indexsets.get( run_id=run.id, name=indexset_1.name @@ -141,7 +143,7 @@ def test_add_data(self, platform: ixmp4.Platform): platform.backend.optimization.indexsets.add_data( indexset_id=indexset_2.id, - data=test_data, # type: ignore + data=test_data, ) assert ( diff --git a/tests/utils.py b/tests/utils.py index a35e7804..ec3df5e3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,6 +19,6 @@ def create_indexsets_for_run( return tuple( platform.backend.optimization.indexsets.create( run_id=run_id, name=f"Indexset {i}" - ) # type: ignore + ) for i in range(offset, offset + amount) )