From a8aeed372b0f1e87c656ca0acec9a45ab2fe3387 Mon Sep 17 00:00:00 2001 From: Alex Ganose Date: Wed, 10 Jan 2024 12:05:39 +0000 Subject: [PATCH 1/2] Fix bs issue --- .pre-commit-config.yaml | 10 +++++----- sumo/plotting/phonon_bs_plotter.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 91a22a16..a1c3730a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,28 +1,28 @@ exclude: ^tests/data/ repos: - repo: https://github.com/myint/autoflake - rev: v1.4 + rev: v2.2.1 hooks: - id: autoflake args: [--in-place, --remove-all-unused-imports, --remove-unused-variable, --ignore-init-module-imports] - repo: https://github.com/psf/black - rev: 22.6.0 + rev: 23.12.1 hooks: - id: black - repo: https://github.com/pycqa/flake8 - rev: 3.9.2 + rev: 7.0.0 hooks: - id: flake8 args: [--max-line-length=125, "--extend-ignore=E203,W503,E402,F401"] language_version: python3 - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.5.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.13.2 hooks: - id: isort name: isort (python) diff --git a/sumo/plotting/phonon_bs_plotter.py b/sumo/plotting/phonon_bs_plotter.py index 5dcfccc6..1a032658 100644 --- a/sumo/plotting/phonon_bs_plotter.py +++ b/sumo/plotting/phonon_bs_plotter.py @@ -156,7 +156,7 @@ def _plot_lines(data, ax, color=None, alpha=1, zorder=1): # nd is branch index, nb is band index, nk is kpoint index for nd, nb in itertools.product( - range(len(data["distances"])), range(self._nb_bands) + range(len(data["distances"])), range(self.n_bands) ): f = freqs[nd][nb] @@ -176,7 +176,7 @@ def _plot_lines(data, ax, color=None, alpha=1, zorder=1): # raise Exception(bs.qpoints) json_plotter = PhononBSPlotter(bs) json_data = json_plotter.bs_plot_data() - if json_plotter._nb_bands != self._nb_bands: + if json_plotter.n_bands != self.n_bands: raise Exception( f"Number of bands in {bs_json} does not match main plot" ) From bd9c52fdc51c0608dd8aee7e93b94e0fa6666f60 Mon Sep 17 00:00:00 2001 From: Alex Ganose Date: Wed, 10 Jan 2024 12:11:45 +0000 Subject: [PATCH 2/2] Update pre-commit --- sumo/electronic_structure/bandstructure.py | 2 -- sumo/electronic_structure/effective_mass.py | 2 -- sumo/io/castep.py | 1 - sumo/io/questaal.py | 1 - sumo/plotting/bs_plotter.py | 2 -- sumo/plotting/phonon_bs_plotter.py | 1 - sumo/symmetry/seekpath_kpath.py | 2 +- tests/tests_plotting/test_band_plotter.py | 2 -- 8 files changed, 1 insertion(+), 12 deletions(-) diff --git a/sumo/electronic_structure/bandstructure.py b/sumo/electronic_structure/bandstructure.py index b89adfc8..988e930d 100644 --- a/sumo/electronic_structure/bandstructure.py +++ b/sumo/electronic_structure/bandstructure.py @@ -149,7 +149,6 @@ def get_projections(bs, selection, normalise=None): # store the projections for all elements and orbitals in a useable format for spin, element, orbital in it.product(spins, elements, all_orbitals): - # convert data to [nb][nk] el_orb_proj = [ [all_proj[spin][nb][nk][element][orbital] for nk in range(nkpts)] @@ -164,7 +163,6 @@ def get_projections(bs, selection, normalise=None): # now go through the selected orbitals and extract what's needed spec_proj = [] for spec in selection: - if isinstance(spec, str): # spec is just an element type, therefore sum all orbitals element = spec diff --git a/sumo/electronic_structure/effective_mass.py b/sumo/electronic_structure/effective_mass.py index a7a1f26a..191b9a9f 100644 --- a/sumo/electronic_structure/effective_mass.py +++ b/sumo/electronic_structure/effective_mass.py @@ -61,7 +61,6 @@ def get_fitting_data(bs, spin, band_id, kpoint_id, num_sample_points=3): # check to see if there are enough points to sample from first # check in the forward direction if kpoint_id + num_sample_points <= branch_data["end_index"]: - # calculate sampling limits start_id = kpoint_id end_id = kpoint_id + num_sample_points + 1 @@ -90,7 +89,6 @@ def get_fitting_data(bs, spin, band_id, kpoint_id, num_sample_points=3): # check in the backward direction if kpoint_id - num_sample_points >= branch_data["start_index"]: - # calculate sampling limits start_id = kpoint_id - num_sample_points end_id = kpoint_id + 1 diff --git a/sumo/io/castep.py b/sumo/io/castep.py index 6546624c..7e2b3d71 100644 --- a/sumo/io/castep.py +++ b/sumo/io/castep.py @@ -138,7 +138,6 @@ def to_file(self, filename): @classmethod def from_file(cls, filename): - with zopen(filename, "rt") as f: lines = [line.strip() for line in f] diff --git a/sumo/io/questaal.py b/sumo/io/questaal.py index 57476a07..8403af4d 100644 --- a/sumo/io/questaal.py +++ b/sumo/io/questaal.py @@ -308,7 +308,6 @@ def _get_structure_from_lattice(self): def to_file(self, filename): """Write QuestaalInit object to init file""" with open(filename, "w") as f: - f.write("LATTICE\n") for key, value in self.lattice.items(): if key == "PLAT": diff --git a/sumo/plotting/bs_plotter.py b/sumo/plotting/bs_plotter.py index d952e573..3ca1ead7 100644 --- a/sumo/plotting/bs_plotter.py +++ b/sumo/plotting/bs_plotter.py @@ -556,7 +556,6 @@ def get_projected_plot( # nd is branch index for spin, nd in it.product(spins, range(nbranches)): - # mask data to reduce plotting load bands = np.array(data["energy"][str(spin)][nd]) mask = np.where( @@ -597,7 +596,6 @@ def get_projected_plot( weights[weights < 0] = 0 if mode == "rgb": - # colours aren't used now but needed later for legend colours = [color1, color2, color3] diff --git a/sumo/plotting/phonon_bs_plotter.py b/sumo/plotting/phonon_bs_plotter.py index 1a032658..5e7cf1eb 100644 --- a/sumo/plotting/phonon_bs_plotter.py +++ b/sumo/plotting/phonon_bs_plotter.py @@ -246,7 +246,6 @@ def _makeplot( if dos is not None: self._plot_phonon_dos(dos, ax=fig.axes[1], color=color, dashline=dashline) else: - # keep correct aspect ratio; match axis to canvas x0, x1 = ax.get_xlim() y0, y1 = ax.get_ylim() diff --git a/sumo/symmetry/seekpath_kpath.py b/sumo/symmetry/seekpath_kpath.py index f150b8b0..59012248 100644 --- a/sumo/symmetry/seekpath_kpath.py +++ b/sumo/symmetry/seekpath_kpath.py @@ -78,7 +78,7 @@ def kpath_from_seekpath(cls, seekpath, point_coords): # convert from seekpath format e.g. [(l1, l2), (l2, l3), (l4, l5)] # to our preferred representation [[l1, l2, l3], [l4, l5]] path = [[seekpath[0][0]]] - for (k1, k2) in seekpath: + for k1, k2 in seekpath: if path[-1] and path[-1][-1] == k1: path[-1].append(k2) else: diff --git a/tests/tests_plotting/test_band_plotter.py b/tests/tests_plotting/test_band_plotter.py index d1663fcb..8754cc72 100644 --- a/tests/tests_plotting/test_band_plotter.py +++ b/tests/tests_plotting/test_band_plotter.py @@ -18,7 +18,6 @@ def test_sanitise_label(self): ("@X", None), ("@HEX", None), ): - self.assertEqual(SBSPlotter._sanitise_label(label_in), label_out) def test_sanitise_label_group(self): @@ -45,5 +44,4 @@ def test_sanitise_label_group(self): (r"X@$\mid$@Y", r"X"), (r"@X@$\mid$@Y", None), ): - self.assertEqual(SBSPlotter._sanitise_label_group(label_in), label_out)