Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Brillouin zone plotter #76

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def load_test_suite():
entry_points={'console_scripts': [
'sumo-bandplot = sumo.cli.bandplot:main',
'sumo-bandstats = sumo.cli.bandstats:main',
'sumo-brillplot = sumo.cli.brillplot:main',
'sumo-dosplot = sumo.cli.dosplot:main',
'sumo-kgen = sumo.cli.kgen:main',
'sumo-phonon-bandplot = sumo.cli.phonon_bandplot:main',
Expand Down
171 changes: 171 additions & 0 deletions sumo/cli/brillplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Copyright (c) Scanlon Materials Theory Group
# Distributed under the terms of the MIT License.

"""
Plot high symmetry points on the Brillouin Zone from calculated band structure

TODO:
- Connect the high symmetry points to make a path as it appears on the band
structure
- Incorporate an option to open a gui to inspect bz by eye
- Apply user styles / appearance options
- Modify the matplotlib backend such that both figures can be saved and a
gui can be used

"""

import os
import sys
import glob
import logging
import argparse
import warnings

from pymatgen.io.vasp.outputs import BSVasprun
from pymatgen.electronic_structure.bandstructure import get_reconstructed_band_structure
from pymatgen.electronic_structure.plotter import plot_brillouin_zone
import matplotlib as mpl
mpl.use("Agg")
from sumo.plotting import pretty_plot_3d
from sumo.plotting import (colour_cache,
styled_plot, sumo_base_style)


__author__ = "Arthur Youd"
__version__ = "1.0"
__maintainer__ = "Alex Ganose"
__email__ = "[email protected]"
__date__ = "August 21, 2019"

@styled_plot(sumo_base_style)
def brillplot(filenames=None, prefix=None, directory=None,
width=6, height=6, fonts=None,
image_format="pdf", dpi=400):
"""Generate plot of first brillouin zone from a band-structure calculation.
Args:
filenames (:obj:`str` or :obj:`list`, optional): Path to input files.
Vasp:
Use vasprun.xml or vasprun.xml.gz file.
image_format (:obj:`str`, optional): The image file format. Can be any
format supported by matplotlib, including: png, jpg, pdf, and svg.
Defaults to pdf.
dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for the
image.
"""
if not filenames:
filenames = find_vasprun_files()
elif isinstance(filenames, str):
filenames = [filenames]
bandstructures = []
for vr_file in filenames:
vr = BSVasprun(vr_file)
bs = vr.get_band_structure(line_mode=True)
bandstructures.append(bs)
bs = get_reconstructed_band_structure(bandstructures)

labels = {}
for k in bs.kpoints:
if k.label:
labels[k.label] = k.frac_coords

lines = []
for b in bs.branches:
lines.append([bs.kpoints[b['start_index']].frac_coords,
bs.kpoints[b['end_index']].frac_coords])

plt = pretty_plot_3d(width, height, dpi=dpi, fonts=fonts)
fig = plot_brillouin_zone(bs.lattice_rec, lines=lines, labels=labels,
ax=plt.gca())

basename = "brillouin.{}".format(image_format)
filename = "{}_{}".format(prefix, basename) if prefix else basename
if directory:
filename = os.path.join(directory, filename)
fig.savefig(filename, format=image_format, dpi=dpi, bbox_inches="tight")
return plt

def find_vasprun_files():
"""Search for vasprun files from the current directory.

The precedence order for file locations is:
1. First search for folders named: 'split-0*'
2. Else, look in the current directory.
The split folder names should always be zero based, therefore easily
sortable.

Returns: list of str

"""
folders = glob.glob("split-*")
folders = sorted(folders) if folders else ["."]

filenames = []
for fol in folders:
vr_file = os.path.join(fol, "vasprun.xml")
vr_file_gz = os.path.join(fol, "vasprun.xml.gz")
if os.path.exists(vr_file):
filenames.append(vr_file)
elif os.path.exists(vr_file_gz):
filenames.append(vr_file_gz)
else:
logging.error("ERROR: No vasprun.xml found in {}!".format(fol))
sys.exit()
return filenames


def _get_parser():
parser = argparse.ArgumentParser(description="""
brillplot is a script to produce publication-ready
brillouin zone diagrams""",
epilog="""
Author: {}
Version: {}
Last updated: {}""".format(__author__, __version__, __date__))
parser.add_argument(
"-f",
"--filenames",
default=None,
nargs="+",
metavar="F",
help="one or more vasprun.xml files to plot",
)
parser.add_argument(
"-d", "--directory", metavar="D", help="output directory for files"
)
parser.add_argument(
"--format",
type=str,
default="pdf",
dest="image_format",
metavar="FORMAT",
help="image file format (options: pdf, svg, jpg, png)",
)
parser.add_argument(
"--dpi", type=int, default=400, help="pixel density for image file"
)
return parser


def main():
args = _get_parser().parse_args()
logging.basicConfig(
filename="sumo-brillplot.log",
level=logging.INFO,
filemode="w",
format="%(message)s",
)
console = logging.StreamHandler()
logging.info(" ".join(sys.argv[:]))
logging.getLogger("").addHandler(console)
warnings.filterwarnings("ignore",
category=UserWarning, module="matplotlib")
warnings.filterwarnings("ignore",
category=UnicodeWarning, module="matplotlib")
warnings.filterwarnings("ignore", category=UserWarning, module="pymatgen")

brillplot(
filenames=args.filenames,
directory=args.directory,
image_format=args.image_format,
dpi=args.dpi,
)
24 changes: 22 additions & 2 deletions sumo/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from matplotlib.collections import LineCollection
from matplotlib import rc, rcParams
from pkg_resources import resource_filename
from mpl_toolkits.mplot3d import Axes3D

colour_cache = {}

Expand All @@ -38,7 +39,6 @@ def styled_plot(*style_sheets):
"""

def decorator(get_plot):

def wrapper(*args, fonts=None, style=None, no_base_style=False,
**kwargs):

Expand Down Expand Up @@ -96,6 +96,27 @@ def pretty_plot(width=None, height=None, plt=None, dpi=None):
return plt


def pretty_plot_3d(width=5, height=5, plt=None, dpi=None, fonts=None):
if plt is None:
plt = matplotlib.pyplot
if width is None:
width = matplotlib.rcParams['figure.figsize'][0]
if height is None:
height = matplotlib.rcParams['figure.figsize'][1]

if dpi is not None:
matplotlib.rcParams['figure.dpi'] = dpi

fig = plt.figure(figsize=(width, height), dpi=dpi)

else:
fig = plt.gcf()

ax = fig.add_subplot(111, projection='3d')

return plt


def pretty_subplot(nrows, ncols, width=None, height=None, sharex=True,
sharey=True, dpi=None, plt=None, gridspec_kw=None):
"""Get a :obj:`matplotlib.pyplot` subplot object with pretty defaults.
Expand Down Expand Up @@ -133,7 +154,6 @@ def pretty_subplot(nrows, ncols, width=None, height=None, sharex=True,
plt.subplots(nrows, ncols, sharex=sharex, sharey=sharey, dpi=dpi,
figsize=(width, height), facecolor='w',
gridspec_kw=gridspec_kw)

return plt


Expand Down