Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schedule PymatgenTest for migration from unittest to pytest #4209

Closed
3 changes: 1 addition & 2 deletions src/pymatgen/analysis/elasticity/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,7 @@ def from_pseudoinverse(cls, strains, stresses) -> Self:
"questionable results from vasp data, use with caution."
)
stresses = np.array([Stress(stress).voigt for stress in stresses])
with warnings.catch_warnings():
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
strains = np.array([Strain(strain).voigt for strain in strains])
strains = np.array([Strain(strain).voigt for strain in strains])

voigt_fit = np.transpose(np.dot(np.linalg.pinv(strains), stresses))
return cls.from_voigt(voigt_fit)
Expand Down
13 changes: 5 additions & 8 deletions src/pymatgen/command_line/mcsqs_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import os
import tempfile
import warnings
from pathlib import Path
from shutil import which
from subprocess import Popen, TimeoutExpired
Expand All @@ -32,7 +31,7 @@ class Sqs(NamedTuple):

@requires(
which("mcsqs") and which("str2cif"),
"run_mcsqs requires first installing AT-AT, see https://www.brown.edu/Departments/Engineering/Labs/avdw/atat/",
"run_mcsqs requires ATAT, see https://www.brown.edu/Departments/Engineering/Labs/avdw/atat/",
)
def run_mcsqs(
structure: Structure,
Expand Down Expand Up @@ -146,7 +145,7 @@ def run_mcsqs(

raise RuntimeError("mcsqs exited before timeout reached")

except TimeoutExpired:
except TimeoutExpired as exc:
for process in mcsqs_find_sqs_processes:
process.kill()
process.communicate()
Expand All @@ -157,7 +156,7 @@ def run_mcsqs(
raise RuntimeError(
"mcsqs did not generate output files, "
"is search_time sufficient or are number of instances too high?"
)
) from exc

process = Popen(["mcsqs", "-best"])
process.communicate()
Expand All @@ -166,7 +165,7 @@ def run_mcsqs(
return _parse_sqs_path(".")

os.chdir(original_directory)
raise TimeoutError("Cluster expansion took too long.")
raise TimeoutError("Cluster expansion took too long.") from exc


def _parse_sqs_path(path) -> Sqs:
Expand All @@ -191,9 +190,7 @@ def _parse_sqs_path(path) -> Sqs:
process = Popen(["str2cif"], stdin=input_file, stdout=output_file, cwd=path)
process.communicate()

with warnings.catch_warnings():
warnings.simplefilter("ignore")
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
best_sqs = Structure.from_file(path / "bestsqs.out")
best_sqs = Structure.from_file(path / "bestsqs.out")

# Get best SQS objective function
with open(path / "bestcorr.out") as file:
Expand Down
2 changes: 1 addition & 1 deletion src/pymatgen/entries/correction_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def compute_corrections(self, exp_entries: list, calc_entries: dict) -> dict:

with warnings.catch_warnings():
# numpy raises warning if the entire array is nan values
warnings.simplefilter("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", message="Mean of empty slice", category=RuntimeWarning)
mean_uncert = np.nanmean(sigma)

sigma = np.where(np.isnan(sigma), mean_uncert, sigma)
Expand Down
2 changes: 1 addition & 1 deletion src/pymatgen/transformations/advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ def find_codopant(
for sym in symbols:
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
warnings.filterwarnings("ignore", message=r"No (default )?ionic radius for .+")
sp = Species(sym, oxidation_state)
radius = sp.ionic_radius
if radius is not None:
Expand Down
12 changes: 12 additions & 0 deletions src/pymatgen/util/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import json
import pickle # use pickle, not cPickle so that we get the traceback in case of errors
import string
import warnings
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import TestCase
Expand Down Expand Up @@ -41,6 +42,17 @@ class PymatgenTest(TestCase):
# dict of lazily-loaded test structures (initialized to None)
TEST_STRUCTURES: ClassVar[dict[str | Path, Structure | None]] = dict.fromkeys(STRUCTURES_DIR.glob("*"))

@classmethod
def setUpClass(cls):
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
"""Issue a FutureWarning, see PR 4209."""
warnings.warn(
"PymatgenTest is scheduled for migration to pytest after 2026-01-01. "
"Please adapt your tests accordingly.",
FutureWarning,
stacklevel=2,
)
super().setUpClass()

@pytest.fixture(autouse=True) # make all tests run a in a temporary directory accessible via self.tmp_path
def _tmp_dir(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
# https://pytest.org/en/latest/how-to/unittest.html#using-autouse-fixtures-and-accessing-other-fixtures
Expand Down
4 changes: 0 additions & 4 deletions tests/command_line/test_bader_caller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import warnings
from shutil import which

import numpy as np
Expand All @@ -17,9 +16,6 @@

@pytest.mark.skipif(not which("bader"), reason="bader executable not present")
class TestBaderAnalysis(PymatgenTest):
def setUp(self):
warnings.catch_warnings()

def test_init(self):
# test with reference file
analysis = BaderAnalysis(
Expand Down
5 changes: 3 additions & 2 deletions tests/command_line/test_mcsqs_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
TEST_DIR = f"{TEST_FILES_DIR}/io/atat/mcsqs"


@pytest.mark.skipif(not (which("mcsqs") and which("str2cif")), reason="mcsqs executable not present")
@pytest.mark.skipif(not (which("mcsqs") and which("str2cif")), reason="mcsqs or str2cif executable not present")
class TestMcsqsCaller(PymatgenTest):
def setUp(self):
self.pzt_structs = loadfn(f"{TEST_DIR}/pzt-structs.json")
Expand Down Expand Up @@ -103,5 +103,6 @@ def test_mcsqs_caller_runtime_error(self):
struct.replace_species({"Ti": {"Ti": 0.5, "Zr": 0.5}, "Zr": {"Ti": 0.5, "Zr": 0.5}})
struct.replace_species({"Pb": {"Ti": 0.2, "Pb": 0.8}})
struct.replace_species({"O": {"F": 0.8, "O": 0.2}})
with pytest.raises(RuntimeError, match="mcsqs exited before timeout reached"):
Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is failing as RuntimeError with a different message is raised:

>       with pytest.raises(RuntimeError, match="mcsqs exited before timeout reached"):
E       AssertionError: Regex pattern did not match.
E        Regex: 'mcsqs exited before timeout reached'
E        Input: 'mcsqs did not generate output files, is search_time sufficient or are number of instances too high?'

But looks like both exceptions are very similar (all pointing to timeout) so I would just change the error message here:

try:
for process in mcsqs_find_sqs_processes:
process.communicate(timeout=search_time * 60)
if instances and instances > 1:
process = Popen(["mcsqs", "-best"])
process.communicate()
if os.path.isfile("bestsqs.out") and os.path.isfile("bestcorr.out"):
return _parse_sqs_path(".")
raise RuntimeError("mcsqs exited before timeout reached")
except TimeoutExpired:
for process in mcsqs_find_sqs_processes:
process.kill()
process.communicate()
# Find the best sqs structures
if instances and instances > 1:
if not os.path.isfile("bestcorr1.out"):
raise RuntimeError(
"mcsqs did not generate output files, "
"is search_time sufficient or are number of instances too high?"
)


with pytest.raises(RuntimeError, match="mcsqs did not generate output files"):
run_mcsqs(struct, {2: 6, 3: 4}, 10, 0.000001)
4 changes: 2 additions & 2 deletions tests/core/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ def test_memory(self):
assert mega_0 == mega_1 == mega_2 == mega_3

def test_deprecated_memory(self):
# TODO: remove after 2025-01-01
"""TODO: remove entire test method after 2025-01-01"""
for unit in ("Kb", "kb", "Mb", "mb", "Gb", "gb", "Tb", "tb"):
with pytest.warns(DeprecationWarning, match=f"Unit {unit} is deprecated"):
Memory(1, unit)

with warnings.catch_warnings():
warnings.simplefilter("error")
for unit in ("KB", "MB", "GB", "TB"):
warnings.filterwarnings("error", f"Unit {unit} is deprecated", DeprecationWarning)
Memory(1, unit)

def test_unitized(self):
Expand Down
14 changes: 7 additions & 7 deletions tests/io/vasp/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,8 @@ def test_from_file_potcar_overwrite_elements(self):
copyfile(f"{VASP_IN_DIR}/POSCAR_C2", tmp_poscar_path := f"{self.tmp_path}/POSCAR")
copyfile(f"{VASP_IN_DIR}/POTCAR_C2.gz", f"{self.tmp_path}/POTCAR.gz")

with warnings.catch_warnings(record=True) as record:
_poscar = Poscar.from_file(tmp_poscar_path)
assert not any("Elements in POSCAR would be overwritten" in str(warning.message) for warning in record)
warnings.filterwarnings("error", message="Elements in POSCAR would be overwritten")
_poscar = Poscar.from_file(tmp_poscar_path)

def test_from_str_default_names(self):
"""Similar to test_from_file_bad_potcar, ensure "default_names"
Expand All @@ -237,18 +236,19 @@ def test_from_str_default_names(self):
assert poscar.site_symbols == ["Si", "O"]

# Assert no warning if using the same elements (or superset)
with warnings.catch_warnings(record=True) as record:
with warnings.catch_warnings():
warnings.filterwarnings("error", message="Elements in POSCAR would be overwritten")

poscar = Poscar.from_str(poscar_str, default_names=["Si", "F"])
assert poscar.site_symbols == ["Si", "F"]

poscar = Poscar.from_str(poscar_str, default_names=["Si", "F", "O"])
assert poscar.site_symbols == ["Si", "F"]
assert not any("Elements in POSCAR would be overwritten" in str(warning.message) for warning in record)

# Make sure it could be bypassed (by using None, when not check_for_potcar)
with warnings.catch_warnings(record=True) as record:
with warnings.catch_warnings():
warnings.filterwarnings("error", message="Elements in POSCAR would be overwritten")
_poscar = Poscar.from_str(poscar_str, default_names=None)
assert not any("Elements in POSCAR would be overwritten" in str(warning.message) for warning in record)

def test_from_str_default_names_vasp4(self):
"""Poscar.from_str with default_names given could also be used to
Expand Down
72 changes: 72 additions & 0 deletions tests/util/test_testing_migrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""This is not a functional test but a utility to verify behaviors specific to
`unittest.TestCase`. It ensures we're aware the side effects from the migration.

TODO: remove this test module after migration (2026-01-01), see PR 4209.

`unittest.TestCase`-specific features and brief migration guide:
- Setup/teardown methods (`setUp`, `setUpClass`, `tearDown`, `tearDownClass`):
1. Recommended approach in pytest: Use fixtures.
Documentation: https://docs.pytest.org/en/stable/reference/fixtures.html#fixture
OR
2. Use pytest's xUnit-style setup/teardown functions:
`[setup/teardown]_[class/method/function]`.
Documentation: https://docs.pytest.org/en/stable/how-to/xunit_setup.html

- Assertion methods (`assertTrue`, `assertFalse`, `assertEqual`, etc.):
Replace with direct Python `assert` statements.
"""

from __future__ import annotations

import pytest

from pymatgen.util.testing import PymatgenTest


@pytest.mark.filterwarnings("ignore:PymatgenTest is scheduled for migration to pytest")
class TestPymatgenTestTestCase(PymatgenTest):
"""Baseline inspector for migration side effects."""

def test_pmg_test_migration_warning(self):
"""Test PymatgenTest migration warning."""
with pytest.warns(FutureWarning, match="PymatgenTest is scheduled for migration to pytest after 2026-01-01"):
self.setUpClass() # invoke the setup phase

def test_unittest_testcase_specific_funcs(self):
"""Make sure TestCase-specific methods are available until migration."""
# Testing setUp and tearDown methods
self.setUp()
self.tearDown()

# Testing class-level setUp and tearDown methods
self.setUpClass()
self.tearDownClass()

# Test the assertion methods
self.assertTrue(True) # noqa: PT009, FBT003
self.assertFalse(False) # noqa: PT009, FBT003
self.assertEqual(1, 1) # noqa: PT009


class TestPymatgenTestPytest:
def test_unittest_testcase_specific_funcs(self):
"""Test unittest.TestCase-specific methods for migration to pytest."""
# Testing setUp and tearDown methods
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.setUp()
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.tearDown()

# Testing class-level setUp and tearDown methods
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.setUpClass()
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.tearDownClass()

# Test the assertion methods
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.assertTrue(True) # noqa: PT009, FBT003
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.assertFalse(False) # noqa: PT009, FBT003
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.assertEqual(1, 1) # noqa: PT009
Loading