diff --git a/src/pymatgen/analysis/elasticity/elastic.py b/src/pymatgen/analysis/elasticity/elastic.py index 4a236e24d3f..d67d6d87c22 100644 --- a/src/pymatgen/analysis/elasticity/elastic.py +++ b/src/pymatgen/analysis/elasticity/elastic.py @@ -479,8 +479,7 @@ def from_pseudoinverse(cls, strains, stresses) -> Self: "questionable results from vasp data, use with caution." ) stresses = np.array([Stress(stress).voigt for stress in stresses]) - with warnings.catch_warnings(): - strains = np.array([Strain(strain).voigt for strain in strains]) + strains = np.array([Strain(strain).voigt for strain in strains]) voigt_fit = np.transpose(np.dot(np.linalg.pinv(strains), stresses)) return cls.from_voigt(voigt_fit) diff --git a/src/pymatgen/command_line/mcsqs_caller.py b/src/pymatgen/command_line/mcsqs_caller.py index 78647df8591..aa10083ec63 100644 --- a/src/pymatgen/command_line/mcsqs_caller.py +++ b/src/pymatgen/command_line/mcsqs_caller.py @@ -6,7 +6,6 @@ import os import tempfile -import warnings from pathlib import Path from shutil import which from subprocess import Popen, TimeoutExpired @@ -32,7 +31,7 @@ class Sqs(NamedTuple): @requires( which("mcsqs") and which("str2cif"), - "run_mcsqs requires first installing AT-AT, see https://www.brown.edu/Departments/Engineering/Labs/avdw/atat/", + "run_mcsqs requires ATAT, see https://www.brown.edu/Departments/Engineering/Labs/avdw/atat/", ) def run_mcsqs( structure: Structure, @@ -146,7 +145,7 @@ def run_mcsqs( raise RuntimeError("mcsqs exited before timeout reached") - except TimeoutExpired: + except TimeoutExpired as exc: for process in mcsqs_find_sqs_processes: process.kill() process.communicate() @@ -157,7 +156,7 @@ def run_mcsqs( raise RuntimeError( "mcsqs did not generate output files, " "is search_time sufficient or are number of instances too high?" - ) + ) from exc process = Popen(["mcsqs", "-best"]) process.communicate() @@ -166,7 +165,7 @@ def run_mcsqs( return _parse_sqs_path(".") os.chdir(original_directory) - raise TimeoutError("Cluster expansion took too long.") + raise TimeoutError("Cluster expansion took too long.") from exc def _parse_sqs_path(path) -> Sqs: @@ -191,9 +190,7 @@ def _parse_sqs_path(path) -> Sqs: process = Popen(["str2cif"], stdin=input_file, stdout=output_file, cwd=path) process.communicate() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - best_sqs = Structure.from_file(path / "bestsqs.out") + best_sqs = Structure.from_file(path / "bestsqs.out") # Get best SQS objective function with open(path / "bestcorr.out") as file: diff --git a/src/pymatgen/entries/correction_calculator.py b/src/pymatgen/entries/correction_calculator.py index 0053b7e5d23..a196ac52b76 100644 --- a/src/pymatgen/entries/correction_calculator.py +++ b/src/pymatgen/entries/correction_calculator.py @@ -226,7 +226,7 @@ def compute_corrections(self, exp_entries: list, calc_entries: dict) -> dict: with warnings.catch_warnings(): # numpy raises warning if the entire array is nan values - warnings.simplefilter("ignore", category=RuntimeWarning) + warnings.filterwarnings("ignore", message="Mean of empty slice", category=RuntimeWarning) mean_uncert = np.nanmean(sigma) sigma = np.where(np.isnan(sigma), mean_uncert, sigma) diff --git a/src/pymatgen/transformations/advanced_transformations.py b/src/pymatgen/transformations/advanced_transformations.py index 2537241b653..1a71eb1b95b 100644 --- a/src/pymatgen/transformations/advanced_transformations.py +++ b/src/pymatgen/transformations/advanced_transformations.py @@ -923,7 +923,7 @@ def find_codopant( for sym in symbols: try: with warnings.catch_warnings(): - warnings.simplefilter("ignore") + warnings.filterwarnings("ignore", message=r"No (default )?ionic radius for .+") sp = Species(sym, oxidation_state) radius = sp.ionic_radius if radius is not None: diff --git a/src/pymatgen/util/testing/__init__.py b/src/pymatgen/util/testing/__init__.py index acf83e32c93..5abd470b2ed 100644 --- a/src/pymatgen/util/testing/__init__.py +++ b/src/pymatgen/util/testing/__init__.py @@ -12,13 +12,13 @@ import string from pathlib import Path from typing import TYPE_CHECKING -from unittest import TestCase import pytest from monty.json import MontyDecoder, MontyEncoder, MSONable from monty.serialization import loadfn from pymatgen.core import ROOT, SETTINGS, Structure +from pymatgen.util.testing._temp_testcase import _TempTestCase4Migrate if TYPE_CHECKING: from collections.abc import Sequence @@ -35,7 +35,7 @@ FAKE_POTCAR_DIR = f"{VASP_IN_DIR}/fake_potcars" -class PymatgenTest(TestCase): +class PymatgenTest(_TempTestCase4Migrate): """Extends unittest.TestCase with several assert methods for array and str comparison.""" # dict of lazily-loaded test structures (initialized to None) diff --git a/src/pymatgen/util/testing/_temp_testcase.py b/src/pymatgen/util/testing/_temp_testcase.py new file mode 100644 index 00000000000..d87ade89a46 --- /dev/null +++ b/src/pymatgen/util/testing/_temp_testcase.py @@ -0,0 +1,84 @@ +"""Temporary TestCase for migration to `pytest` framework, +inserted FutureWarning for unittest.TestCase-specific methods. + +TODO: remove entire module after migration +""" + +# ruff: noqa: PT009, PT027 + +from __future__ import annotations + +import warnings +from unittest import TestCase + + +class _TempTestCase4Migrate(TestCase): + @staticmethod + def _issue_warning(method_name): + warnings.warn( + f"unittest {method_name=} will not be supported by pytest after migration by 2026-01-01, see PR4209.", + FutureWarning, + stacklevel=2, + ) + + def setUp(self, *args, **kwargs): + self._issue_warning("setUp") + super().setUp(*args, **kwargs) + + def tearDown(self, *args, **kwargs): + self._issue_warning("tearDown") + super().tearDown(*args, **kwargs) + + @classmethod + def setUpClass(cls, *args, **kwargs): + cls._issue_warning("setUpClass") + super().setUpClass(*args, **kwargs) + + @classmethod + def tearDownClass(cls, *args, **kwargs): + cls._issue_warning("tearDownClass") + super().tearDownClass(*args, **kwargs) + + def assertEqual(self, *args, **kwargs): + self._issue_warning("assertEqual") + return super().assertEqual(*args, **kwargs) + + def assertNotEqual(self, *args, **kwargs): + self._issue_warning("assertNotEqual") + return super().assertNotEqual(*args, **kwargs) + + def assertTrue(self, *args, **kwargs): + self._issue_warning("assertTrue") + return super().assertTrue(*args, **kwargs) + + def assertFalse(self, *args, **kwargs): + self._issue_warning("assertFalse") + return super().assertFalse(*args, **kwargs) + + def assertIsNone(self, *args, **kwargs): + self._issue_warning("assertIsNone") + return super().assertIsNone(*args, **kwargs) + + def assertIsNotNone(self, *args, **kwargs): + self._issue_warning("assertIsNotNone") + return super().assertIsNotNone(*args, **kwargs) + + def assertIn(self, *args, **kwargs): # codespell:ignore + self._issue_warning("assertIn") # codespell:ignore + return super().assertIn(*args, **kwargs) # codespell:ignore + + def assertNotIn(self, *args, **kwargs): + self._issue_warning("assertNotIn") + return super().assertNotIn(*args, **kwargs) + + def assertIsInstance(self, *args, **kwargs): + self._issue_warning("assertIsInstance") + return super().assertIsInstance(*args, **kwargs) + + def assertNotIsInstance(self, *args, **kwargs): + self._issue_warning("assertNotIsInstance") + return super().assertNotIsInstance(*args, **kwargs) + + def assertRaises(self, *args, **kwargs): + self._issue_warning("assertRaises") + return super().assertRaises(*args, **kwargs) diff --git a/tests/analysis/test_phase_diagram.py b/tests/analysis/test_phase_diagram.py index 934ba59fa9d..dbabecaa3b5 100644 --- a/tests/analysis/test_phase_diagram.py +++ b/tests/analysis/test_phase_diagram.py @@ -2,7 +2,6 @@ import collections import unittest -import unittest.mock from itertools import combinations from numbers import Number from unittest import TestCase diff --git a/tests/command_line/test_bader_caller.py b/tests/command_line/test_bader_caller.py index ad39f8c6b61..90ec2df89d1 100644 --- a/tests/command_line/test_bader_caller.py +++ b/tests/command_line/test_bader_caller.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from shutil import which import numpy as np @@ -17,9 +16,6 @@ @pytest.mark.skipif(not which("bader"), reason="bader executable not present") class TestBaderAnalysis(PymatgenTest): - def setUp(self): - warnings.catch_warnings() - def test_init(self): # test with reference file analysis = BaderAnalysis( diff --git a/tests/command_line/test_mcsqs_caller.py b/tests/command_line/test_mcsqs_caller.py index ca9f8a018d6..4ff60fdd0a4 100644 --- a/tests/command_line/test_mcsqs_caller.py +++ b/tests/command_line/test_mcsqs_caller.py @@ -17,7 +17,7 @@ TEST_DIR = f"{TEST_FILES_DIR}/io/atat/mcsqs" -@pytest.mark.skipif(not (which("mcsqs") and which("str2cif")), reason="mcsqs executable not present") +@pytest.mark.skipif(not (which("mcsqs") and which("str2cif")), reason="mcsqs or str2cif executable not present") class TestMcsqsCaller(PymatgenTest): def setUp(self): self.pzt_structs = loadfn(f"{TEST_DIR}/pzt-structs.json") @@ -103,5 +103,6 @@ def test_mcsqs_caller_runtime_error(self): struct.replace_species({"Ti": {"Ti": 0.5, "Zr": 0.5}, "Zr": {"Ti": 0.5, "Zr": 0.5}}) struct.replace_species({"Pb": {"Ti": 0.2, "Pb": 0.8}}) struct.replace_species({"O": {"F": 0.8, "O": 0.2}}) - with pytest.raises(RuntimeError, match="mcsqs exited before timeout reached"): + + with pytest.raises(RuntimeError, match="mcsqs did not generate output files"): run_mcsqs(struct, {2: 6, 3: 4}, 10, 0.000001) diff --git a/tests/core/test_units.py b/tests/core/test_units.py index 88e06d92517..7334f08f6c3 100644 --- a/tests/core/test_units.py +++ b/tests/core/test_units.py @@ -110,14 +110,14 @@ def test_memory(self): assert mega_0 == mega_1 == mega_2 == mega_3 def test_deprecated_memory(self): - # TODO: remove after 2025-01-01 + """TODO: remove entire test method after 2025-01-01""" for unit in ("Kb", "kb", "Mb", "mb", "Gb", "gb", "Tb", "tb"): with pytest.warns(DeprecationWarning, match=f"Unit {unit} is deprecated"): Memory(1, unit) with warnings.catch_warnings(): - warnings.simplefilter("error") for unit in ("KB", "MB", "GB", "TB"): + warnings.filterwarnings("error", f"Unit {unit} is deprecated", DeprecationWarning) Memory(1, unit) def test_unitized(self): diff --git a/tests/io/vasp/test_inputs.py b/tests/io/vasp/test_inputs.py index 1616ff4e5e7..bd4d4eeec7d 100644 --- a/tests/io/vasp/test_inputs.py +++ b/tests/io/vasp/test_inputs.py @@ -208,9 +208,8 @@ def test_from_file_potcar_overwrite_elements(self): copyfile(f"{VASP_IN_DIR}/POSCAR_C2", tmp_poscar_path := f"{self.tmp_path}/POSCAR") copyfile(f"{VASP_IN_DIR}/POTCAR_C2.gz", f"{self.tmp_path}/POTCAR.gz") - with warnings.catch_warnings(record=True) as record: - _poscar = Poscar.from_file(tmp_poscar_path) - assert not any("Elements in POSCAR would be overwritten" in str(warning.message) for warning in record) + warnings.filterwarnings("error", message="Elements in POSCAR would be overwritten") + _poscar = Poscar.from_file(tmp_poscar_path) def test_from_str_default_names(self): """Similar to test_from_file_bad_potcar, ensure "default_names" @@ -237,18 +236,19 @@ def test_from_str_default_names(self): assert poscar.site_symbols == ["Si", "O"] # Assert no warning if using the same elements (or superset) - with warnings.catch_warnings(record=True) as record: + with warnings.catch_warnings(): + warnings.filterwarnings("error", message="Elements in POSCAR would be overwritten") + poscar = Poscar.from_str(poscar_str, default_names=["Si", "F"]) assert poscar.site_symbols == ["Si", "F"] poscar = Poscar.from_str(poscar_str, default_names=["Si", "F", "O"]) assert poscar.site_symbols == ["Si", "F"] - assert not any("Elements in POSCAR would be overwritten" in str(warning.message) for warning in record) # Make sure it could be bypassed (by using None, when not check_for_potcar) - with warnings.catch_warnings(record=True) as record: + with warnings.catch_warnings(): + warnings.filterwarnings("error", message="Elements in POSCAR would be overwritten") _poscar = Poscar.from_str(poscar_str, default_names=None) - assert not any("Elements in POSCAR would be overwritten" in str(warning.message) for warning in record) def test_from_str_default_names_vasp4(self): """Poscar.from_str with default_names given could also be used to diff --git a/tests/util/test_testing_migrate.py b/tests/util/test_testing_migrate.py new file mode 100644 index 00000000000..d99f4fe06f1 --- /dev/null +++ b/tests/util/test_testing_migrate.py @@ -0,0 +1,82 @@ +"""This is not a functional test but a utility to verify behaviors specific to +`unittest.TestCase`. It ensures we're aware the side effects from the migration. + +TODO: remove this test module after migration (2026-01-01), see PR 4209. +""" + +# ruff: noqa: PT009, PT027, FBT003 + +from __future__ import annotations + +import pytest + +from pymatgen.util.testing import PymatgenTest + + +@pytest.mark.filterwarnings("ignore", message="will not be supported by pytest after migration", category=FutureWarning) +class TestPymatgenTestTestCase(PymatgenTest): + """Baseline inspector for migration side effects.""" + + def test_unittest_testcase_specific_funcs(self): + """Make sure TestCase-specific methods are available until migration, + and FutureWarning is emitted. + """ + msg = "will not be supported by pytest after migration" + + # Testing setUp and tearDown methods + with pytest.warns(FutureWarning, match=msg): + self.setUp() + with pytest.warns(FutureWarning, match=msg): + self.tearDown() + + # Testing class-level setUp and tearDown methods + with pytest.warns(FutureWarning, match=msg): + self.setUpClass() + with pytest.warns(FutureWarning, match=msg): + self.tearDownClass() + + # Test the assertion methods + with pytest.warns(FutureWarning, match=msg): + self.assertEqual(1, 1) + + with pytest.warns(FutureWarning, match=msg): + self.assertNotEqual(1, 2) + + with pytest.warns(FutureWarning, match=msg): + self.assertTrue(True) + + with pytest.warns(FutureWarning, match=msg): + self.assertFalse(False) + + with pytest.warns(FutureWarning, match=msg): + self.assertIsNone(None) + + with pytest.warns(FutureWarning, match=msg): + self.assertIsNotNone("hello") + + with pytest.warns(FutureWarning, match=msg), self.assertRaises(ValueError): + raise ValueError("hi") + + +class TestPymatgenTestPytest: + def test_unittest_testcase_specific_funcs(self): + """Test unittest.TestCase-specific methods for migration to pytest.""" + # Testing setUp and tearDown methods + with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"): + self.setUp() + with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"): + self.tearDown() + + # Testing class-level setUp and tearDown methods + with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"): + self.setUpClass() + with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"): + self.tearDownClass() + + # Test the assertion methods + with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"): + self.assertTrue(True) + with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"): + self.assertFalse(False) + with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"): + self.assertEqual(1, 1)