From 6e05b66401445500f700168a0d22240408e716dd Mon Sep 17 00:00:00 2001 From: Alex Ganose Date: Wed, 11 Oct 2023 11:50:58 +0100 Subject: [PATCH 1/2] Fixes from #216 --- setup.py | 2 +- sumo/cli/bandplot.py | 14 ++++---------- sumo/cli/dosplot.py | 18 +++++------------- sumo/plotting/__init__.py | 14 +++++--------- sumo/symmetry/brad_crack_kpath.py | 2 +- 5 files changed, 16 insertions(+), 34 deletions(-) diff --git a/setup.py b/setup.py index bf2db4b4..2053086b 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ "h5py", "pymatgen>=2020.10.20", "phonopy>=2.1.3", - "matplotlib", + "matplotlib>=3.2.0", "seekpath", "castepxbin<1.0", "colormath", diff --git a/sumo/cli/bandplot.py b/sumo/cli/bandplot.py index 0a1e07f3..ee6483c8 100644 --- a/sumo/cli/bandplot.py +++ b/sumo/cli/bandplot.py @@ -762,24 +762,20 @@ def _get_parser(): "--orbitals", type=_el_orb, metavar="O", - help=( - "orbitals to split into lm-decomposed " - 'contributions (e.g. "Ru.d")' - ), + help="orbitals to split into lm-decomposed contributions (e.g. 'Ru.d')", ) parser.add_argument( "--atoms", type=_atoms, metavar="A", - help=('atoms to include (e.g. "O.1.2.3,Ru.1.2.3")'), + help='atoms to include (e.g. "O.1.2.3,Ru.1.2.3")', ) parser.add_argument( "--spin", type=str, default=None, help=( - "select only one spin channel for a " - "spin-polarised calculation " + "select only one spin channel for a spin-polarised calculation " "(options: up, 1; down, -1)" ), ) @@ -883,9 +879,7 @@ def main(): logging.getLogger("").addHandler(console) if args.config is None: - config_path = os.path.join( - ilr_files("sumo.plotting"), "orbital_colours.conf" - ) + config_path = ilr_files("sumo.plotting") / "orbital_colours.conf" else: config_path = args.config colours = configparser.ConfigParser() diff --git a/sumo/cli/dosplot.py b/sumo/cli/dosplot.py index f9c6d7f1..c062fa3d 100644 --- a/sumo/cli/dosplot.py +++ b/sumo/cli/dosplot.py @@ -441,9 +441,7 @@ def _get_parser(): "--code", default="vasp", metavar="C", - help=( - 'Input file format: "vasp" (vasprun.xml) or ' '"questaal" (opt.ext)' - ), + help='Input file format: "vasp" (vasprun.xml) or "questaal" (opt.ext)', ) parser.add_argument( "-p", "--prefix", metavar="P", help="prefix for the files generated" @@ -463,25 +461,21 @@ def _get_parser(): "--orbitals", type=_el_orb, metavar="O", - help=( - "orbitals to split into lm-decomposed " - 'contributions (e.g. "Ru.d")' - ), + help="orbitals to split into lm-decomposed contributions (e.g. 'Ru.d')", ) parser.add_argument( "-a", "--atoms", type=_atoms, metavar="A", - help=('atoms to include (e.g. "O.1.2.3,Ru.1.2.3")'), + help='atoms to include (e.g. "O.1.2.3,Ru.1.2.3")', ) parser.add_argument( "--spin", type=str, default=None, help=( - "select one spin channel only for a " - "spin-polarised calculation " + "select one spin channel only for a spin-polarised calculation " "(options: up, 1; down, -1)" ), ) @@ -634,9 +628,7 @@ def main(): logging.getLogger("").addHandler(console) if args.config is None: - config_path = os.path.join( - ilr_files("sumo.plotting"), "orbital_colours.conf" - ) + config_path = ilr_files("sumo.plotting") / "orbital_colours.conf" else: config_path = args.config colours = configparser.ConfigParser() diff --git a/sumo/plotting/__init__.py b/sumo/plotting/__init__.py index 3260e35a..da02928b 100644 --- a/sumo/plotting/__init__.py +++ b/sumo/plotting/__init__.py @@ -19,15 +19,11 @@ colour_cache = {} -sumo_base_style = os.path.join(ilr_files("sumo.plotting"), "sumo_base.mplstyle") -sumo_dos_style = os.path.join(ilr_files("sumo.plotting"), "sumo_dos.mplstyle") -sumo_bs_style = os.path.join(ilr_files("sumo.plotting"), "sumo_bs.mplstyle") -sumo_phonon_style = os.path.join( - ilr_files("sumo.plotting"), "sumo_phonon.mplstyle" -) -sumo_optics_style = os.path.join( - ilr_files("sumo.plotting"), "sumo_optics.mplstyle" -) +sumo_base_style = ilr_files("sumo.plotting") / "sumo_base.mplstyle" +sumo_dos_style = ilr_files("sumo.plotting") / "sumo_dos.mplstyle" +sumo_bs_style = ilr_files("sumo.plotting") / "sumo_bs.mplstyle" +sumo_phonon_style = ilr_files("sumo.plotting") / "sumo_phonon.mplstyle" +sumo_optics_style = ilr_files("sumo.plotting") / "sumo_optics.mplstyle" def styled_plot(*style_sheets): diff --git a/sumo/symmetry/brad_crack_kpath.py b/sumo/symmetry/brad_crack_kpath.py index b3349c95..98e32339 100644 --- a/sumo/symmetry/brad_crack_kpath.py +++ b/sumo/symmetry/brad_crack_kpath.py @@ -85,7 +85,7 @@ def _get_bradcrack_data(bravais): 'path': [['\Gamma', 'X', ..., 'P'], ['H', 'N', ...]]} """ - json_file = os.path.join(ilr_files("sumo.symmetry"), "bradcrack.json") + json_file = ilr_files("sumo.symmetry") / "bradcrack.json" with open(json_file) as f: bradcrack_data = load_json(f) return bradcrack_data[bravais] From 4ea323b64f939464d78eaff963aedba15e7d19a0 Mon Sep 17 00:00:00 2001 From: Alex Ganose Date: Wed, 11 Oct 2023 11:52:51 +0100 Subject: [PATCH 2/2] Fixes from #216 --- .github/workflows/tests.yml | 2 +- .pre-commit-config.yaml | 2 +- sumo/cli/bandplot.py | 20 ++---- sumo/cli/dosplot.py | 9 +-- sumo/io/castep.py | 2 +- sumo/plotting/__init__.py | 29 +++------ sumo/plotting/optics_plotter.py | 2 - sumo/symmetry/brad_crack_kpath.py | 5 +- .../tests_electronic_structure/test_optics.py | 16 ++--- tests/tests_io/test_castep.py | 63 +++++-------------- tests/tests_io/test_questaal.py | 18 ++---- tests/tests_phonon/test_phonopy.py | 3 +- tests/tests_plotting/test_dos_plotter.py | 7 +-- tests/tests_symmetry/test_bradcrack_kpath.py | 7 +-- tests/tests_symmetry/test_custom_kpath.py | 3 +- tests/tests_symmetry/test_pymatgen_kpath.py | 11 ++-- tests/tests_symmetry/test_seekpath_kpath.py | 3 +- 17 files changed, 65 insertions(+), 137 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2f3cf8ea..30c1b0d9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -28,4 +28,4 @@ jobs: python -m pip install -e '.[tests]' - name: Test - run: pytest \ No newline at end of file + run: pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4f1c1727..91a22a16 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: rev: 22.6.0 hooks: - id: black - - repo: https://gitlab.com/pycqa/flake8 + - repo: https://github.com/pycqa/flake8 rev: 3.9.2 hooks: - id: flake8 diff --git a/sumo/cli/bandplot.py b/sumo/cli/bandplot.py index ee6483c8..dfd12fc2 100644 --- a/sumo/cli/bandplot.py +++ b/sumo/cli/bandplot.py @@ -17,10 +17,9 @@ from importlib.resources import files as ilr_files except ImportError: # Python < 3.9 from importlib_resources import files as ilr_files + import matplotlib as mpl -from pymatgen.electronic_structure.bandstructure import ( - get_reconstructed_band_structure, -) +from pymatgen.electronic_structure.bandstructure import get_reconstructed_band_structure from pymatgen.electronic_structure.core import Spin from pymatgen.io.vasp.outputs import BSVasprun @@ -394,9 +393,7 @@ def bandplot( else: logging.info(f"Found PDOS file {pdos_file}") else: - logging.info( - f"Cell file {cell_file} does not exist, cannot plot PDOS." - ) + logging.info(f"Cell file {cell_file} does not exist, cannot plot PDOS.") dos, pdos = read_castep_dos( dos_file, @@ -620,8 +617,7 @@ def _get_parser(): "-c", "--code", default="vasp", - help="Electronic structure code (default: vasp)." - '"questaal" also supported.', + help="Electronic structure code (default: vasp)." '"questaal" also supported.', ) parser.add_argument( "-p", "--prefix", metavar="P", help="prefix for the files generated" @@ -825,9 +821,7 @@ def _get_parser(): parser.add_argument( "--height", type=float, default=None, help="height of the graph" ) - parser.add_argument( - "--width", type=float, default=None, help="width of the graph" - ) + parser.add_argument("--width", type=float, default=None, help="width of the graph") parser.add_argument( "--ymin", type=float, default=-6.0, help="minimum energy on the y-axis" ) @@ -886,9 +880,7 @@ def main(): colours.read(os.path.abspath(config_path)) warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib") - warnings.filterwarnings( - "ignore", category=UnicodeWarning, module="matplotlib" - ) + warnings.filterwarnings("ignore", category=UnicodeWarning, module="matplotlib") warnings.filterwarnings("ignore", category=UserWarning, module="pymatgen") bandplot( diff --git a/sumo/cli/dosplot.py b/sumo/cli/dosplot.py index c062fa3d..7c20a40d 100644 --- a/sumo/cli/dosplot.py +++ b/sumo/cli/dosplot.py @@ -18,6 +18,7 @@ import matplotlib as mpl import numpy as np + try: from importlib.resources import files as ilr_files except ImportError: # Python < 3.9 @@ -554,9 +555,7 @@ def _get_parser(): parser.add_argument( "--height", type=float, default=None, help="height of the graph" ) - parser.add_argument( - "--width", type=float, default=None, help="width of the graph" - ) + parser.add_argument("--width", type=float, default=None, help="width of the graph") parser.add_argument( "--xmin", type=float, default=-6.0, help="minimum energy on the x-axis" ) @@ -635,9 +634,7 @@ def main(): colours.read(os.path.abspath(config_path)) warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib") - warnings.filterwarnings( - "ignore", category=UnicodeWarning, module="matplotlib" - ) + warnings.filterwarnings("ignore", category=UnicodeWarning, module="matplotlib") warnings.filterwarnings("ignore", category=UserWarning, module="pymatgen") if args.zero_energy is not None: diff --git a/sumo/io/castep.py b/sumo/io/castep.py index 3bfab75c..6546624c 100644 --- a/sumo/io/castep.py +++ b/sumo/io/castep.py @@ -487,7 +487,7 @@ def labels_from_cell(cell_file, phonon=False): line = f.readline() # Skip past block start line while blockend.match(line.lower()) is None: # Do not parse break lines - if 'break' not in line.lower(): + if "break" not in line.lower(): kpt = tuple(map(float, line.split()[:3])) if len(line.split()) > 3: label = line.split()[-1] diff --git a/sumo/plotting/__init__.py b/sumo/plotting/__init__.py index da02928b..5365b6fa 100644 --- a/sumo/plotting/__init__.py +++ b/sumo/plotting/__init__.py @@ -4,8 +4,8 @@ """ Subpackage providing helper functions for generating publication ready plots. """ -from functools import wraps import os +from functools import wraps import matplotlib.pyplot import numpy as np @@ -43,9 +43,7 @@ def styled_plot(*style_sheets): def decorator(get_plot): @wraps(get_plot) - def wrapper( - *args, fonts=None, style=None, no_base_style=False, **kwargs - ): + def wrapper(*args, fonts=None, style=None, no_base_style=False, **kwargs): if no_base_style: list_style = [] else: @@ -58,9 +56,7 @@ def wrapper( list_style += [style] if fonts is not None: - list_style += [ - {"font.family": "sans-serif", "font.sans-serif": fonts} - ] + list_style += [{"font.family": "sans-serif", "font.sans-serif": fonts}] matplotlib.pyplot.style.use(list_style) return get_plot(*args, **kwargs) @@ -273,9 +269,7 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"): "xyz": XYZColor, } if colorspace not in list(colorspace_mapping.keys()): - raise ValueError( - f"colorspace must be one of {colorspace_mapping.keys()}" - ) + raise ValueError(f"colorspace must be one of {colorspace_mapping.keys()}") colorspace = colorspace_mapping[colorspace] @@ -286,19 +280,13 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"): # now convert to the colorspace basis for interpolation basis1 = np.array( - convert_color( - color1_rgb, colorspace, target_illuminant="d50" - ).get_value_tuple() + convert_color(color1_rgb, colorspace, target_illuminant="d50").get_value_tuple() ) basis2 = np.array( - convert_color( - color2_rgb, colorspace, target_illuminant="d50" - ).get_value_tuple() + convert_color(color2_rgb, colorspace, target_illuminant="d50").get_value_tuple() ) basis3 = np.array( - convert_color( - color3_rgb, colorspace, target_illuminant="d50" - ).get_value_tuple() + convert_color(color3_rgb, colorspace, target_illuminant="d50").get_value_tuple() ) # ensure weights is a numpy array @@ -313,8 +301,7 @@ def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"): # convert colors to RGB rgb_colors = [ - convert_color(colorspace(*c), sRGBColor).get_value_tuple() - for c in colors + convert_color(colorspace(*c), sRGBColor).get_value_tuple() for c in colors ] # ensure all rgb values are less than 1 (sometimes issues in interpolation diff --git a/sumo/plotting/optics_plotter.py b/sumo/plotting/optics_plotter.py index 5e4d3ac0..6e23e309 100644 --- a/sumo/plotting/optics_plotter.py +++ b/sumo/plotting/optics_plotter.py @@ -8,7 +8,6 @@ import numpy as np import scipy.constants as scpc from matplotlib import rcParams -from matplotlib.font_manager import FontProperties, findfont from matplotlib.ticker import AutoMinorLocator, FuncFormatter, MaxNLocator from sumo.plotting import ( @@ -242,7 +241,6 @@ def get_plot( ax.set_ylim(ymin, ymax) if spectrum_key == "absorption": - font = findfont(FontProperties(family=["sans-serif"])) ax.yaxis.set_major_formatter( FuncFormatter(curry_power_tick(times_sign=r"\times")) ) diff --git a/sumo/symmetry/brad_crack_kpath.py b/sumo/symmetry/brad_crack_kpath.py index 98e32339..534996f0 100644 --- a/sumo/symmetry/brad_crack_kpath.py +++ b/sumo/symmetry/brad_crack_kpath.py @@ -6,7 +6,6 @@ """ from json import load as load_json -import os import numpy as np @@ -66,9 +65,7 @@ def __init__(self, structure, symprec=1e-3, spg=None): spg_symbol = self.spg_symbol lattice_type = self.lattice_type - bravais = self._get_bravais_lattice( - spg_symbol, lattice_type, a, b, c, unique - ) + bravais = self._get_bravais_lattice(spg_symbol, lattice_type, a, b, c, unique) self._kpath = self._get_bradcrack_data(bravais) @staticmethod diff --git a/tests/tests_electronic_structure/test_optics.py b/tests/tests_electronic_structure/test_optics.py index d6c16642..a0167c78 100644 --- a/tests/tests_electronic_structure/test_optics.py +++ b/tests/tests_electronic_structure/test_optics.py @@ -1,26 +1,22 @@ import json -import unittest import os +import unittest try: from importlib.resources import files as ilr_files except ImportError: # Python < 3.9 from importlib_resources import files as ilr_files + import numpy as np from numpy.testing import assert_almost_equal from pymatgen.io.vasp import Vasprun -from sumo.electronic_structure.optics import ( - calculate_dielectric_properties, - kkr, -) +from sumo.electronic_structure.optics import calculate_dielectric_properties, kkr class AbsorptionTestCase(unittest.TestCase): def setUp(self): - diel_path = os.path.join( - ilr_files("tests"), "data", "Ge", "ge_diel.json" - ) + diel_path = os.path.join(ilr_files("tests"), "data", "Ge", "ge_diel.json") with open(diel_path) as f: self.ge_diel = json.load(f) @@ -35,9 +31,7 @@ def test_absorption(self): self.ge_diel, {"absorption"}, ) - self.assertIsNone( - assert_almost_equal(properties["absorption"], self.ge_abs) - ) + self.assertIsNone(assert_almost_equal(properties["absorption"], self.ge_abs)) class KramersKronigTestCase(unittest.TestCase): diff --git a/tests/tests_io/test_castep.py b/tests/tests_io/test_castep.py index 462b3367..d641808c 100644 --- a/tests/tests_io/test_castep.py +++ b/tests/tests_io/test_castep.py @@ -1,11 +1,12 @@ import json -import unittest import os +import unittest try: from importlib.resources import files as ilr_files except ImportError: # Python < 3.9 from importlib_resources import files as ilr_files + from monty.io import gzip from monty.json import MontyDecoder from numpy.testing import assert_array_almost_equal @@ -26,21 +27,15 @@ class CastepCellTestCase(unittest.TestCase): def setUp(self): - self.si_cell = os.path.join( - ilr_files("tests"), "data", "Si", "Si2.cell" - ) + self.si_cell = os.path.join(ilr_files("tests"), "data", "Si", "Si2.cell") self.si_cell_alt = os.path.join( ilr_files("tests"), "data", "Si", "Si2-alt.cell" ) - self.zns_band_cell = os.path.join( - ilr_files("tests"), "data", "ZnS", "zns.cell" - ) + self.zns_band_cell = os.path.join(ilr_files("tests"), "data", "ZnS", "zns.cell") self.zns_singlepoint_cell = os.path.join( ilr_files("tests"), "data", "ZnS", "zns-sp.cell" ) - si_structure_file = os.path.join( - ilr_files("tests"), "data", "Si", "Si8.json" - ) + si_structure_file = os.path.join(ilr_files("tests"), "data", "Si", "Si8.json") self.si_structure = Structure.from_file(si_structure_file) def test_castep_cell_null_init(self): @@ -105,9 +100,7 @@ def setUp(self): "cell_file": "NiO.cell", } for key, value in nio_files.items(): - nio_files[key] = os.path.join( - ilr_files("tests"), "data", "NiO", value - ) + nio_files[key] = os.path.join(ilr_files("tests"), "data", "NiO", value) self.read_dos_kwargs = nio_files @@ -155,12 +148,8 @@ def test_pdos(self): class CastepBandStructureTestCaseNoSpin(unittest.TestCase): def setUp(self): - self.si_bands = os.path.join( - ilr_files("tests"), "data", "Si", "Si2.bands" - ) - self.si_cell = os.path.join( - ilr_files("tests"), "data", "Si", "Si2.cell" - ) + self.si_bands = os.path.join(ilr_files("tests"), "data", "Si", "Si2.bands") + self.si_cell = os.path.join(ilr_files("tests"), "data", "Si", "Si2.cell") self.si_cell_alt = os.path.join( ilr_files("tests"), "data", "Si", "Si2-alt.cell" ) @@ -185,16 +174,12 @@ def test_castep_bands_read_header(self): def test_castep_bands_read_eigenvalues(self): with open(self.si_header_ref) as f: ref_header = json.load(f) - kpoints, weights, eigenvals = read_bands_eigenvalues( - self.si_bands, ref_header - ) + kpoints, weights, eigenvals = read_bands_eigenvalues(self.si_bands, ref_header) for i, k in enumerate([0.5, 0.36111111, 0.63888889]): self.assertAlmostEqual(kpoints[4][i], k) - self.assertAlmostEqual( - eigenvals[Spin.up][2, 4], 0.09500443 * _ry_to_ev * 2 - ) + self.assertAlmostEqual(eigenvals[Spin.up][2, 4], 0.09500443 * _ry_to_ev * 2) for weight in weights: self.assertAlmostEqual(weight, 0.02272727) @@ -212,9 +197,7 @@ def test_castep_cell_read_labels_alt_spelling(self): class CastepBandStructureTestCaseNickel(unittest.TestCase): def setUp(self): - self.ni_cell = os.path.join( - ilr_files("tests"), "data", "Ni", "ni-band.cell" - ) + self.ni_cell = os.path.join(ilr_files("tests"), "data", "Ni", "ni-band.cell") self.ref_labels = { r"\Gamma": (0.0, 0.0, 0.0), @@ -232,13 +215,9 @@ def test_castep_cell_read_labels_from_list(self): class CastepBandStructureTestCaseWithSpin(unittest.TestCase): def setUp(self): - self.fe_bands = os.path.join( - ilr_files("tests"), "data", "Fe", "Fe.bands" - ) + self.fe_bands = os.path.join(ilr_files("tests"), "data", "Fe", "Fe.bands") - self.fe_cell = os.path.join( - ilr_files("tests"), "data", "Fe", "Fe.cell" - ) + self.fe_cell = os.path.join(ilr_files("tests"), "data", "Fe", "Fe.cell") self.fe_header_ref = os.path.join( ilr_files("tests"), "data", "Fe", "Fe.bands_header.json" ) @@ -252,9 +231,7 @@ def test_castep_bands_read_header(self): class BandStructureTestCasePathBreak(unittest.TestCase): def setUp(self): - self.pt_cell = os.path.join( - ilr_files("tests"), "data", "Pt", "Pt.cell" - ) + self.pt_cell = os.path.join(ilr_files("tests"), "data", "Pt", "Pt.cell") self.ref_labels = { r"\Gamma": (0.0, 0.0, 0.0), "X": (0.5, 0.0, 0.5), @@ -275,12 +252,8 @@ def test_castep_parse_break_in_k_path(self): # sumo-phonon-bandplot -f zns.phonon --units cm-1 --to-json zns_phonon.json class CastepPhononTestCaseZincblende(unittest.TestCase): def setUp(self): - self.zns_phonon = os.path.join( - ilr_files("tests"), "data", "ZnS", "zns.phonon" - ) - self.zns_cell = os.path.join( - ilr_files("tests"), "data", "ZnS", "zns.cell" - ) + self.zns_phonon = os.path.join(ilr_files("tests"), "data", "ZnS", "zns.phonon") + self.zns_cell = os.path.join(ilr_files("tests"), "data", "ZnS", "zns.cell") self.zns_phonon_ref = os.path.join( ilr_files("tests"), "data", "ZnS", "zns_phonon.json" ) @@ -304,9 +277,7 @@ def test_castep_phonon_read_bands(self): ) assert_array_almost_equal(bs_dict["qpoints"], ref_dict["qpoints"]) self.assertEqual(bs_dict["labels_dict"], ref_dict["labels_dict"]) - self.assertEqual( - bs_dict["structure"]["sites"], ref_dict["structure"]["sites"] - ) + self.assertEqual(bs_dict["structure"]["sites"], ref_dict["structure"]["sites"]) assert_array_almost_equal( bs_dict["structure"]["lattice"]["matrix"], ref_dict["structure"]["lattice"]["matrix"], diff --git a/tests/tests_io/test_questaal.py b/tests/tests_io/test_questaal.py index d51b6fc1..8217571d 100644 --- a/tests/tests_io/test_questaal.py +++ b/tests/tests_io/test_questaal.py @@ -1,10 +1,11 @@ -import unittest import os +import unittest try: from importlib.resources import files as ilr_files except ImportError: # Python < 3.9 from importlib_resources import files as ilr_files + import numpy as np from pymatgen.core.lattice import Lattice @@ -13,9 +14,7 @@ class QuestaalOpticsTestCase(unittest.TestCase): def setUp(self): - self.bse_path = os.path.join( - ilr_files("tests"), "data", "SnO2", "eps_BSE.out" - ) + self.bse_path = os.path.join(ilr_files("tests"), "data", "SnO2", "eps_BSE.out") def test_optics_from_bethesalpeter(self): energy, real, imag = dielectric_from_file(self.bse_path) @@ -93,9 +92,7 @@ def test_init_from_python(self): self.assertFalse(init_plat.cartesian) init_plat_alat = QuestaalInit(lattice, site) - init_plat_alat.lattice["PLAT"] = ( - np.array(init_plat_alat.lattice["PLAT"]) * 0.1 - ) + init_plat_alat.lattice["PLAT"] = np.array(init_plat_alat.lattice["PLAT"]) * 0.1 init_plat_alat.lattice["ALAT"] = 10 self.assertEqual(init_plat.structure, init_plat_alat.structure) @@ -143,9 +140,7 @@ def test_init_from_python(self): bohr_init_noconvert = QuestaalInit( bohr_lattice, bohr_cart_sites, ignore_units=True ) - self.assertAlmostEqual( - bohr_init_noconvert.structure.lattice.abc[2], 9.74172715 - ) + self.assertAlmostEqual(bohr_init_noconvert.structure.lattice.abc[2], 9.74172715) self.assertAlmostEqual( bohr_init_noconvert.structure.distance_matrix[0, -1], 3.66441077 ) @@ -208,8 +203,7 @@ def test_init_from_file(self): self.assertLess( ( abs( - np.array(self.ref_pmg_lat.abc) - - np.array(sym_structure.lattice.abc) + np.array(self.ref_pmg_lat.abc) - np.array(sym_structure.lattice.abc) ) ).max(), 1e-5, diff --git a/tests/tests_phonon/test_phonopy.py b/tests/tests_phonon/test_phonopy.py index 0449af69..cf3060b0 100644 --- a/tests/tests_phonon/test_phonopy.py +++ b/tests/tests_phonon/test_phonopy.py @@ -1,10 +1,11 @@ -import unittest import os +import unittest try: from importlib.resources import files as ilr_files except ImportError: # Python < 3.9 from importlib_resources import files as ilr_files + import numpy as np from phonopy import Phonopy from pymatgen.io.vasp.inputs import Poscar diff --git a/tests/tests_plotting/test_dos_plotter.py b/tests/tests_plotting/test_dos_plotter.py index f7c2897b..79c603a5 100644 --- a/tests/tests_plotting/test_dos_plotter.py +++ b/tests/tests_plotting/test_dos_plotter.py @@ -1,5 +1,5 @@ -import unittest import os +import unittest try: import configparser @@ -9,6 +9,7 @@ from importlib.resources import files as ilr_files except ImportError: # Python < 3.9 from importlib_resources import files as ilr_files + import matplotlib.pyplot import sumo.plotting @@ -18,9 +19,7 @@ class GetColourTestCase(unittest.TestCase): def setUp(self): - config_path = os.path.join( - ilr_files("sumo.plotting"), "orbital_colours.conf" - ) + config_path = os.path.join(ilr_files("sumo.plotting"), "orbital_colours.conf") self.config = configparser.ConfigParser() self.config.read(os.path.abspath(config_path)) diff --git a/tests/tests_symmetry/test_bradcrack_kpath.py b/tests/tests_symmetry/test_bradcrack_kpath.py index c18290f4..081ec082 100644 --- a/tests/tests_symmetry/test_bradcrack_kpath.py +++ b/tests/tests_symmetry/test_bradcrack_kpath.py @@ -1,11 +1,12 @@ +import os import unittest import warnings -import os try: from importlib.resources import files as ilr_files except ImportError: # Python < 3.9 from importlib_resources import files as ilr_files + from pymatgen.core.structure import Structure from sumo.symmetry.brad_crack_kpath import BradCrackKpath @@ -30,9 +31,7 @@ def test_bravais_assignment(self): "triclinic", ) self.assertEqual( - BradCrackKpath._get_bravais_lattice( - "Fd-3m", "cubic", 5.66, 5.66, 5.66, 0 - ), + BradCrackKpath._get_bravais_lattice("Fd-3m", "cubic", 5.66, 5.66, 5.66, 0), "cubic_f", ) diff --git a/tests/tests_symmetry/test_custom_kpath.py b/tests/tests_symmetry/test_custom_kpath.py index 4bc1be64..de1eccbf 100644 --- a/tests/tests_symmetry/test_custom_kpath.py +++ b/tests/tests_symmetry/test_custom_kpath.py @@ -1,11 +1,12 @@ +import os import unittest import warnings -import os try: from importlib.resources import files as ilr_files except ImportError: # Python < 3.9 from importlib_resources import files as ilr_files + from pymatgen.core.structure import Structure from sumo.symmetry.custom_kpath import CustomKpath diff --git a/tests/tests_symmetry/test_pymatgen_kpath.py b/tests/tests_symmetry/test_pymatgen_kpath.py index add9e034..bdf0d6b2 100644 --- a/tests/tests_symmetry/test_pymatgen_kpath.py +++ b/tests/tests_symmetry/test_pymatgen_kpath.py @@ -1,11 +1,12 @@ +import os import unittest import warnings -import os try: from importlib.resources import files as ilr_files except ImportError: # Python < 3.9 from importlib_resources import files as ilr_files + from pymatgen.core.structure import Structure from sumo.symmetry.brad_crack_kpath import BradCrackKpath @@ -14,9 +15,7 @@ class SeekpathKpathTestCase(unittest.TestCase): def setUp(self): - zno_poscar = os.path.join( - ilr_files("tests"), "data", "ZnO", "POSCAR" - ) + zno_poscar = os.path.join(ilr_files("tests"), "data", "ZnO", "POSCAR") hgs_poscar = os.path.join(ilr_files("tests"), "data", "Ge", "POSCAR") with warnings.catch_warnings(): # Not interested in Pymatgen warnings @@ -44,6 +43,4 @@ def test_pymatgen_points(self): # Bradcrack kpoints should be a subset of pymatgen kpoints for label, position in kpath_bradcrack.kpoints.items(): self.assertIn(label, kpath_pymatgen.kpoints) - self.assertEqual( - list(position), list(kpath_pymatgen.kpoints[label]) - ) + self.assertEqual(list(position), list(kpath_pymatgen.kpoints[label])) diff --git a/tests/tests_symmetry/test_seekpath_kpath.py b/tests/tests_symmetry/test_seekpath_kpath.py index abd407e8..f428f711 100644 --- a/tests/tests_symmetry/test_seekpath_kpath.py +++ b/tests/tests_symmetry/test_seekpath_kpath.py @@ -1,11 +1,12 @@ +import os import unittest import warnings -import os try: from importlib.resources import files as ilr_files except ImportError: # Python < 3.9 from importlib_resources import files as ilr_files + from pymatgen.core.structure import Structure from sumo.symmetry.seekpath_kpath import SeekpathKpath