Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schedule PymatgenTest for migration from unittest to pytest #4209

Closed
3 changes: 1 addition & 2 deletions src/pymatgen/analysis/elasticity/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,7 @@ def from_pseudoinverse(cls, strains, stresses) -> Self:
"questionable results from vasp data, use with caution."
)
stresses = np.array([Stress(stress).voigt for stress in stresses])
with warnings.catch_warnings():
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
strains = np.array([Strain(strain).voigt for strain in strains])
strains = np.array([Strain(strain).voigt for strain in strains])

voigt_fit = np.transpose(np.dot(np.linalg.pinv(strains), stresses))
return cls.from_voigt(voigt_fit)
Expand Down
13 changes: 5 additions & 8 deletions src/pymatgen/command_line/mcsqs_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import os
import tempfile
import warnings
from pathlib import Path
from shutil import which
from subprocess import Popen, TimeoutExpired
Expand All @@ -32,7 +31,7 @@ class Sqs(NamedTuple):

@requires(
which("mcsqs") and which("str2cif"),
"run_mcsqs requires first installing AT-AT, see https://www.brown.edu/Departments/Engineering/Labs/avdw/atat/",
"run_mcsqs requires ATAT, see https://www.brown.edu/Departments/Engineering/Labs/avdw/atat/",
)
def run_mcsqs(
structure: Structure,
Expand Down Expand Up @@ -146,7 +145,7 @@ def run_mcsqs(

raise RuntimeError("mcsqs exited before timeout reached")

except TimeoutExpired:
except TimeoutExpired as exc:
for process in mcsqs_find_sqs_processes:
process.kill()
process.communicate()
Expand All @@ -157,7 +156,7 @@ def run_mcsqs(
raise RuntimeError(
"mcsqs did not generate output files, "
"is search_time sufficient or are number of instances too high?"
)
) from exc

process = Popen(["mcsqs", "-best"])
process.communicate()
Expand All @@ -166,7 +165,7 @@ def run_mcsqs(
return _parse_sqs_path(".")

os.chdir(original_directory)
raise TimeoutError("Cluster expansion took too long.")
raise TimeoutError("Cluster expansion took too long.") from exc


def _parse_sqs_path(path) -> Sqs:
Expand All @@ -191,9 +190,7 @@ def _parse_sqs_path(path) -> Sqs:
process = Popen(["str2cif"], stdin=input_file, stdout=output_file, cwd=path)
process.communicate()

with warnings.catch_warnings():
warnings.simplefilter("ignore")
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
best_sqs = Structure.from_file(path / "bestsqs.out")
best_sqs = Structure.from_file(path / "bestsqs.out")

# Get best SQS objective function
with open(path / "bestcorr.out") as file:
Expand Down
2 changes: 1 addition & 1 deletion src/pymatgen/entries/correction_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def compute_corrections(self, exp_entries: list, calc_entries: dict) -> dict:

with warnings.catch_warnings():
# numpy raises warning if the entire array is nan values
warnings.simplefilter("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", message="Mean of empty slice", category=RuntimeWarning)
mean_uncert = np.nanmean(sigma)

sigma = np.where(np.isnan(sigma), mean_uncert, sigma)
Expand Down
2 changes: 1 addition & 1 deletion src/pymatgen/transformations/advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ def find_codopant(
for sym in symbols:
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
warnings.filterwarnings("ignore", message=r"No (default )?ionic radius for .+")
sp = Species(sym, oxidation_state)
radius = sp.ionic_radius
if radius is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/pymatgen/util/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
import string
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import TestCase

import pytest
from monty.json import MontyDecoder, MontyEncoder, MSONable
from monty.serialization import loadfn

from pymatgen.core import ROOT, SETTINGS, Structure
from pymatgen.util.testing._temp_testcase import _TempTestCase4Migrate

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -35,7 +35,7 @@
FAKE_POTCAR_DIR = f"{VASP_IN_DIR}/fake_potcars"


class PymatgenTest(TestCase):
class PymatgenTest(_TempTestCase4Migrate):
"""Extends unittest.TestCase with several assert methods for array and str comparison."""

# dict of lazily-loaded test structures (initialized to None)
Expand Down
84 changes: 84 additions & 0 deletions src/pymatgen/util/testing/_temp_testcase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Temporary TestCase for migration to `pytest` framework,
inserted FutureWarning for unittest.TestCase-specific methods.

TODO: remove entire module after migration
"""

# ruff: noqa: PT009, PT027

from __future__ import annotations

import warnings
from unittest import TestCase


class _TempTestCase4Migrate(TestCase):
@staticmethod
def _issue_warning(method_name):
warnings.warn(
f"unittest {method_name=} will not be supported by pytest after migration by 2026-01-01, see PR4209.",
FutureWarning,
stacklevel=2,
)

def setUp(self, *args, **kwargs):
self._issue_warning("setUp")
super().setUp(*args, **kwargs)

def tearDown(self, *args, **kwargs):
self._issue_warning("tearDown")
super().tearDown(*args, **kwargs)

@classmethod
def setUpClass(cls, *args, **kwargs):
cls._issue_warning("setUpClass")
super().setUpClass(*args, **kwargs)

@classmethod
def tearDownClass(cls, *args, **kwargs):
cls._issue_warning("tearDownClass")
super().tearDownClass(*args, **kwargs)

def assertEqual(self, *args, **kwargs):
self._issue_warning("assertEqual")
return super().assertEqual(*args, **kwargs)

def assertNotEqual(self, *args, **kwargs):
self._issue_warning("assertNotEqual")
return super().assertNotEqual(*args, **kwargs)

def assertTrue(self, *args, **kwargs):
self._issue_warning("assertTrue")
return super().assertTrue(*args, **kwargs)

def assertFalse(self, *args, **kwargs):
self._issue_warning("assertFalse")
return super().assertFalse(*args, **kwargs)

def assertIsNone(self, *args, **kwargs):
self._issue_warning("assertIsNone")
return super().assertIsNone(*args, **kwargs)

def assertIsNotNone(self, *args, **kwargs):
self._issue_warning("assertIsNotNone")
return super().assertIsNotNone(*args, **kwargs)

def assertIn(self, *args, **kwargs): # codespell:ignore
self._issue_warning("assertIn") # codespell:ignore
return super().assertIn(*args, **kwargs) # codespell:ignore

def assertNotIn(self, *args, **kwargs):
self._issue_warning("assertNotIn")
return super().assertNotIn(*args, **kwargs)

def assertIsInstance(self, *args, **kwargs):
self._issue_warning("assertIsInstance")
return super().assertIsInstance(*args, **kwargs)

def assertNotIsInstance(self, *args, **kwargs):
self._issue_warning("assertNotIsInstance")
return super().assertNotIsInstance(*args, **kwargs)

def assertRaises(self, *args, **kwargs):
self._issue_warning("assertRaises")
return super().assertRaises(*args, **kwargs)
1 change: 0 additions & 1 deletion tests/analysis/test_phase_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import collections
import unittest
import unittest.mock
from itertools import combinations
from numbers import Number
from unittest import TestCase
Expand Down
4 changes: 0 additions & 4 deletions tests/command_line/test_bader_caller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import warnings
from shutil import which

import numpy as np
Expand All @@ -17,9 +16,6 @@

@pytest.mark.skipif(not which("bader"), reason="bader executable not present")
class TestBaderAnalysis(PymatgenTest):
def setUp(self):
warnings.catch_warnings()

def test_init(self):
# test with reference file
analysis = BaderAnalysis(
Expand Down
5 changes: 3 additions & 2 deletions tests/command_line/test_mcsqs_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
TEST_DIR = f"{TEST_FILES_DIR}/io/atat/mcsqs"


@pytest.mark.skipif(not (which("mcsqs") and which("str2cif")), reason="mcsqs executable not present")
@pytest.mark.skipif(not (which("mcsqs") and which("str2cif")), reason="mcsqs or str2cif executable not present")
class TestMcsqsCaller(PymatgenTest):
def setUp(self):
self.pzt_structs = loadfn(f"{TEST_DIR}/pzt-structs.json")
Expand Down Expand Up @@ -103,5 +103,6 @@ def test_mcsqs_caller_runtime_error(self):
struct.replace_species({"Ti": {"Ti": 0.5, "Zr": 0.5}, "Zr": {"Ti": 0.5, "Zr": 0.5}})
struct.replace_species({"Pb": {"Ti": 0.2, "Pb": 0.8}})
struct.replace_species({"O": {"F": 0.8, "O": 0.2}})
with pytest.raises(RuntimeError, match="mcsqs exited before timeout reached"):
Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is failing as RuntimeError with a different message is raised:

>       with pytest.raises(RuntimeError, match="mcsqs exited before timeout reached"):
E       AssertionError: Regex pattern did not match.
E        Regex: 'mcsqs exited before timeout reached'
E        Input: 'mcsqs did not generate output files, is search_time sufficient or are number of instances too high?'

But looks like both exceptions are very similar (all pointing to timeout) so I would just change the error message here:

try:
for process in mcsqs_find_sqs_processes:
process.communicate(timeout=search_time * 60)
if instances and instances > 1:
process = Popen(["mcsqs", "-best"])
process.communicate()
if os.path.isfile("bestsqs.out") and os.path.isfile("bestcorr.out"):
return _parse_sqs_path(".")
raise RuntimeError("mcsqs exited before timeout reached")
except TimeoutExpired:
for process in mcsqs_find_sqs_processes:
process.kill()
process.communicate()
# Find the best sqs structures
if instances and instances > 1:
if not os.path.isfile("bestcorr1.out"):
raise RuntimeError(
"mcsqs did not generate output files, "
"is search_time sufficient or are number of instances too high?"
)


with pytest.raises(RuntimeError, match="mcsqs did not generate output files"):
run_mcsqs(struct, {2: 6, 3: 4}, 10, 0.000001)
4 changes: 2 additions & 2 deletions tests/core/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ def test_memory(self):
assert mega_0 == mega_1 == mega_2 == mega_3

def test_deprecated_memory(self):
# TODO: remove after 2025-01-01
"""TODO: remove entire test method after 2025-01-01"""
for unit in ("Kb", "kb", "Mb", "mb", "Gb", "gb", "Tb", "tb"):
with pytest.warns(DeprecationWarning, match=f"Unit {unit} is deprecated"):
Memory(1, unit)

with warnings.catch_warnings():
warnings.simplefilter("error")
for unit in ("KB", "MB", "GB", "TB"):
warnings.filterwarnings("error", f"Unit {unit} is deprecated", DeprecationWarning)
Memory(1, unit)

def test_unitized(self):
Expand Down
14 changes: 7 additions & 7 deletions tests/io/vasp/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,8 @@ def test_from_file_potcar_overwrite_elements(self):
copyfile(f"{VASP_IN_DIR}/POSCAR_C2", tmp_poscar_path := f"{self.tmp_path}/POSCAR")
copyfile(f"{VASP_IN_DIR}/POTCAR_C2.gz", f"{self.tmp_path}/POTCAR.gz")

with warnings.catch_warnings(record=True) as record:
_poscar = Poscar.from_file(tmp_poscar_path)
assert not any("Elements in POSCAR would be overwritten" in str(warning.message) for warning in record)
warnings.filterwarnings("error", message="Elements in POSCAR would be overwritten")
_poscar = Poscar.from_file(tmp_poscar_path)

def test_from_str_default_names(self):
"""Similar to test_from_file_bad_potcar, ensure "default_names"
Expand All @@ -237,18 +236,19 @@ def test_from_str_default_names(self):
assert poscar.site_symbols == ["Si", "O"]

# Assert no warning if using the same elements (or superset)
with warnings.catch_warnings(record=True) as record:
with warnings.catch_warnings():
warnings.filterwarnings("error", message="Elements in POSCAR would be overwritten")

poscar = Poscar.from_str(poscar_str, default_names=["Si", "F"])
assert poscar.site_symbols == ["Si", "F"]

poscar = Poscar.from_str(poscar_str, default_names=["Si", "F", "O"])
assert poscar.site_symbols == ["Si", "F"]
assert not any("Elements in POSCAR would be overwritten" in str(warning.message) for warning in record)

# Make sure it could be bypassed (by using None, when not check_for_potcar)
with warnings.catch_warnings(record=True) as record:
with warnings.catch_warnings():
warnings.filterwarnings("error", message="Elements in POSCAR would be overwritten")
_poscar = Poscar.from_str(poscar_str, default_names=None)
assert not any("Elements in POSCAR would be overwritten" in str(warning.message) for warning in record)

def test_from_str_default_names_vasp4(self):
"""Poscar.from_str with default_names given could also be used to
Expand Down
82 changes: 82 additions & 0 deletions tests/util/test_testing_migrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""This is not a functional test but a utility to verify behaviors specific to
`unittest.TestCase`. It ensures we're aware the side effects from the migration.

TODO: remove this test module after migration (2026-01-01), see PR 4209.
"""

# ruff: noqa: PT009, PT027, FBT003

from __future__ import annotations

import pytest

from pymatgen.util.testing import PymatgenTest


@pytest.mark.filterwarnings("ignore", message="will not be supported by pytest after migration", category=FutureWarning)
class TestPymatgenTestTestCase(PymatgenTest):
"""Baseline inspector for migration side effects."""

def test_unittest_testcase_specific_funcs(self):
"""Make sure TestCase-specific methods are available until migration,
and FutureWarning is emitted.
"""
msg = "will not be supported by pytest after migration"

# Testing setUp and tearDown methods
with pytest.warns(FutureWarning, match=msg):
self.setUp()
with pytest.warns(FutureWarning, match=msg):
self.tearDown()

# Testing class-level setUp and tearDown methods
with pytest.warns(FutureWarning, match=msg):
self.setUpClass()
with pytest.warns(FutureWarning, match=msg):
self.tearDownClass()

# Test the assertion methods
with pytest.warns(FutureWarning, match=msg):
self.assertEqual(1, 1)

with pytest.warns(FutureWarning, match=msg):
self.assertNotEqual(1, 2)

with pytest.warns(FutureWarning, match=msg):
self.assertTrue(True)

with pytest.warns(FutureWarning, match=msg):
self.assertFalse(False)

with pytest.warns(FutureWarning, match=msg):
self.assertIsNone(None)

with pytest.warns(FutureWarning, match=msg):
self.assertIsNotNone("hello")

with pytest.warns(FutureWarning, match=msg), self.assertRaises(ValueError):
raise ValueError("hi")


class TestPymatgenTestPytest:
def test_unittest_testcase_specific_funcs(self):
"""Test unittest.TestCase-specific methods for migration to pytest."""
# Testing setUp and tearDown methods
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.setUp()
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.tearDown()

# Testing class-level setUp and tearDown methods
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.setUpClass()
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.tearDownClass()

# Test the assertion methods
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.assertTrue(True)
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.assertFalse(False)
with pytest.raises(AttributeError, match="'TestPymatgenTestPytest' object has no attribute"):
self.assertEqual(1, 1)
Loading