diff --git a/src/pymatgen/util/testing.py b/src/pymatgen/util/testing.py index 1a0f42dea9a..73c71a06a20 100644 --- a/src/pymatgen/util/testing.py +++ b/src/pymatgen/util/testing.py @@ -8,7 +8,7 @@ from __future__ import annotations import json -import pickle # use pickle, not cPickle so that we get the traceback in case of errors +import pickle # use pickle over cPickle to get traceback in case of errors import string from pathlib import Path from typing import TYPE_CHECKING @@ -19,12 +19,15 @@ from monty.json import MontyDecoder, MontyEncoder, MSONable from monty.serialization import loadfn -from pymatgen.core import ROOT, SETTINGS, Structure +from pymatgen.core import ROOT, SETTINGS if TYPE_CHECKING: from collections.abc import Sequence from typing import Any, ClassVar + from pymatgen.core import Structure + from pymatgen.util.typing import PathLike + _MODULE_DIR: Path = Path(__file__).absolute().parent STRUCTURES_DIR: Path = _MODULE_DIR / "structures" @@ -33,10 +36,10 @@ VASP_IN_DIR: str = f"{TEST_FILES_DIR}/io/vasp/inputs" VASP_OUT_DIR: str = f"{TEST_FILES_DIR}/io/vasp/outputs" -# fake POTCARs have original header information, meaning properties like number of electrons, +# Fake POTCARs have original header information, meaning properties like number of electrons, # nuclear charge, core radii, etc. are unchanged (important for testing) while values of the and # pseudopotential kinetic energy corrections are scrambled to avoid VASP copyright infringement -FAKE_POTCAR_DIR = f"{VASP_IN_DIR}/fake_potcars" +FAKE_POTCAR_DIR: str = f"{VASP_IN_DIR}/fake_potcars" class MatSciTest: @@ -50,36 +53,53 @@ class MatSciTest: """ # dict of lazily-loaded test structures (initialized to None) - TEST_STRUCTURES: ClassVar[dict[str | Path, Structure | None]] = dict.fromkeys(STRUCTURES_DIR.glob("*")) + TEST_STRUCTURES: ClassVar[dict[PathLike, Structure | None]] = dict.fromkeys(STRUCTURES_DIR.glob("*")) - @pytest.fixture(autouse=True) # make all tests run a in a temporary directory accessible via self.tmp_path + @pytest.fixture(autouse=True) def _tmp_dir(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: - # https://pytest.org/en/latest/how-to/unittest.html#using-autouse-fixtures-and-accessing-other-fixtures - monkeypatch.chdir(tmp_path) # change to pytest-provided temporary directory - self.tmp_path = tmp_path + """Make all tests run a in a temporary directory accessible via self.tmp_path. - @classmethod - def get_structure(cls, name: str) -> Structure: + References: + https://docs.pytest.org/en/stable/how-to/tmp_path.html """ - Load a structure from `pymatgen.util.structures`. + monkeypatch.chdir(tmp_path) # change to temporary directory + self.tmp_path = tmp_path + + @staticmethod + def assert_msonable(obj: MSONable, test_is_subclass: bool = True) -> str: + """Test if an object is MSONable and verify the contract is fulfilled, + and return the serialized object. + + By default, the method tests whether obj is an instance of MSONable. + This check can be deactivated by setting `test_is_subclass` to False. Args: - name (str): Name of the structure file, for example "LiFePO4". + obj (Any): The object to be checked. + test_is_subclass (bool): Check if object is an instance of MSONable + or its subclasses. Returns: - Structure + str: Serialized object. """ - try: - struct = cls.TEST_STRUCTURES.get(name) or loadfn(f"{STRUCTURES_DIR}/{name}.json") - except FileNotFoundError as exc: - raise FileNotFoundError(f"structure for {name} doesn't exist") from exc + obj_name = obj.__class__.__name__ - cls.TEST_STRUCTURES[name] = struct + # Check if is an instance of MONable (or its subclasses) + if test_is_subclass and not isinstance(obj, MSONable): + raise TypeError(f"{obj_name} object is not MSONable") - return struct.copy() + # Check if the object can be accurately reconstructed from its dict representation + if obj.as_dict() != type(obj).from_dict(obj.as_dict()).as_dict(): + raise ValueError(f"{obj_name} object could not be reconstructed accurately from its dict representation.") + + # Verify that the deserialized object's class is a subclass of the original object's class + json_str = json.dumps(obj.as_dict(), cls=MontyEncoder) + round_trip = json.loads(json_str, cls=MontyDecoder) + if not issubclass(type(round_trip), type(obj)): + raise TypeError(f"The reconstructed {round_trip.__class__.__name__} object is not a subclass of {obj_name}") + return json_str @staticmethod - def assert_str_content_equal(actual, expected): + def assert_str_content_equal(actual: str, expected: str) -> None: """Test if two strings are equal, ignoring whitespaces. Args: @@ -99,7 +119,32 @@ def assert_str_content_equal(actual, expected): f"{expected}\n" ) - def serialize_with_pickle(self, objects: Any, protocols: Sequence[int] | None = None, test_eq: bool = True): + @classmethod + def get_structure(cls, name: str) -> Structure: + """ + Load a structure from `pymatgen.util.structures`. + + Args: + name (str): Name of the structure file, for example "LiFePO4". + + Returns: + Structure + """ + try: + struct = cls.TEST_STRUCTURES.get(name) or loadfn(f"{STRUCTURES_DIR}/{name}.json") + except FileNotFoundError as exc: + raise FileNotFoundError(f"structure for {name} doesn't exist") from exc + + cls.TEST_STRUCTURES[name] = struct + + return struct.copy() + + def serialize_with_pickle( + self, + objects: Any, + protocols: Sequence[int] | None = None, + test_eq: bool = True, + ): """Test whether the object(s) can be serialized and deserialized with `pickle`. This method tries to serialize the objects with `pickle` and the protocols specified in input. Then it deserializes the pickled format @@ -163,38 +208,6 @@ def serialize_with_pickle(self, objects: Any, protocols: Sequence[int] | None = return [o[0] for o in objects_by_protocol] return objects_by_protocol - def assert_msonable(self, obj: MSONable, test_is_subclass: bool = True) -> str: - """Test if an object is MSONable and verify the contract is fulfilled, - and return the serialized object. - - By default, the method tests whether obj is an instance of MSONable. - This check can be deactivated by setting `test_is_subclass` to False. - - Args: - obj (Any): The object to be checked. - test_is_subclass (bool): Check if object is an instance of MSONable - or its subclasses. - - Returns: - str: Serialized object. - """ - obj_name = obj.__class__.__name__ - - # Check if is an instance of MONable (or its subclasses) - if test_is_subclass and not isinstance(obj, MSONable): - raise TypeError(f"{obj_name} object is not MSONable") - - # Check if the object can be accurately reconstructed from its dict representation - if obj.as_dict() != type(obj).from_dict(obj.as_dict()).as_dict(): - raise ValueError(f"{obj_name} object could not be reconstructed accurately from its dict representation.") - - # Verify that the deserialized object's class is a subclass of the original object's class - json_str = json.dumps(obj.as_dict(), cls=MontyEncoder) - round_trip = json.loads(json_str, cls=MontyDecoder) - if not issubclass(type(round_trip), type(obj)): - raise TypeError(f"The reconstructed {round_trip.__class__.__name__} object is not a subclass of {obj_name}") - return json_str - @deprecated(MatSciTest, deadline=(2026, 1, 1)) class PymatgenTest(TestCase, MatSciTest):