Skip to content

Commit

Permalink
Brillplot: work with Agg backend, create 3d prettyplot base
Browse files Browse the repository at this point in the history
It's important that an Axis passed to the Pymatgen Brillouin plotter
is already 3D or we get some rather confusing error messages.
  • Loading branch information
ajjackson committed Aug 27, 2019
1 parent 386da78 commit ff1402d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 22 deletions.
26 changes: 20 additions & 6 deletions sumo/cli/brillplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@

from pymatgen.io.vasp.outputs import BSVasprun
from pymatgen.electronic_structure.bandstructure import get_reconstructed_band_structure
from pymatgen.electronic_structure.plotter import BSPlotter
from pymatgen.electronic_structure.plotter import (BSPlotter,
plot_brillouin_zone)
import matplotlib as mpl

mpl.use("Agg")
from sumo.plotting import pretty_plot_3d

__author__ = "Arthur Youd"
__version__ = "1.0"
__maintainer__ = "Alex Ganose"
Expand All @@ -35,6 +37,7 @@


def brillplot(filenames=None, prefix=None, directory=None,
width=6, height=6, fonts=None,
image_format="pdf", dpi=400):
"""Generate plot of first brillouin zone from a band-structure calculation.
Args:
Expand All @@ -57,17 +60,28 @@ def brillplot(filenames=None, prefix=None, directory=None,
bs = vr.get_band_structure(line_mode=True)
bandstructures.append(bs)
bs = get_reconstructed_band_structure(bandstructures)
plotter = BSPlotter(bs)
plt = plotter.plot_brillouin()

labels = {}
for k in bs.kpoints:
if k.label:
labels[k.label] = k.frac_coords

lines = []
for b in bs.branches:
lines.append([bs.kpoints[b['start_index']].frac_coords,
bs.kpoints[b['end_index']].frac_coords])

plt = pretty_plot_3d(width, height, dpi=dpi, fonts=fonts)
fig = plot_brillouin_zone(bs.lattice_rec, lines=lines, labels=labels,
ax=plt.gca())

basename = "brillouin.{}".format(image_format)
filename = "{}_{}".format(prefix, basename) if prefix else basename
if directory:
filename = os.path.join(directory, filename)
plt.savefig(filename, format=image_format, dpi=dpi, bbox_inches="tight")
fig.savefig(filename, format=image_format, dpi=dpi, bbox_inches="tight")
return plt


def find_vasprun_files():
"""Search for vasprun files from the current directory.
Expand Down
48 changes: 32 additions & 16 deletions sumo/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

from cycler import cycler
from itertools import cycle
from matplotlib import rc
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d import Axes3D

default_colours = [[240, 163, 255], [0, 117, 220], [153, 63, 0], [76, 0, 92],
[66, 102, 0], [255, 0, 16], [157, 204, 0], [194, 0, 136],
Expand All @@ -30,6 +32,18 @@
_ticksize = 15
_linewidth = 1.3

def _setup_fonts(fonts=None):
"""Set some font and rendering parameters which are common to all plots"""
if type(fonts) is str:
fonts = [fonts]
fonts = default_fonts if fonts is None else fonts + default_fonts

rc('font', **{'family': 'sans-serif', 'sans-serif': fonts})
rc('text', usetex=False)
rc('pdf', fonttype=42)
rc('mathtext', fontset='stixsans')
rc('legend', handlelength=2)


def pretty_plot(width=5, height=5, plt=None, dpi=None, fonts=None):
"""Get a :obj:`matplotlib.pyplot` object with publication ready defaults.
Expand All @@ -49,7 +63,6 @@ def pretty_plot(width=5, height=5, plt=None, dpi=None, fonts=None):
:obj:`matplotlib.pyplot`: A :obj:`matplotlib.pyplot` object with
publication ready defaults set.
"""
from matplotlib import rc

if plt is None:
import matplotlib.pyplot as plt
Expand All @@ -73,18 +86,27 @@ def pretty_plot(width=5, height=5, plt=None, dpi=None, fonts=None):
ax.set_xlabel(ax.get_xlabel(), size=_labelsize)
ax.set_ylabel(ax.get_ylabel(), size=_labelsize)

if type(fonts) is str:
fonts = [fonts]
fonts = default_fonts if fonts is None else fonts + default_fonts
_setup_fonts(fonts=fonts)

rc('font', **{'family': 'sans-serif', 'sans-serif': fonts})
rc('text', usetex=False)
rc('pdf', fonttype=42)
rc('mathtext', fontset='stixsans')
rc('legend', handlelength=2)
return plt


def pretty_plot_3d(width=5, height=5, plt=None, dpi=None, fonts=None):
if plt is None:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(width, height), facecolor="w", dpi=dpi)
ax = plt.gca()
ax.set_prop_cycle(colour_cycler())
else:
fig = plt.gcf()
ax = plt.gca()

ax = fig.add_subplot(111, projection='3d')

_setup_fonts(fonts=fonts)

return plt

def pretty_subplot(nrows, ncols, width=5, height=5, sharex=True,
sharey=True, dpi=None, fonts=None, plt=None,
gridspec_kw=None):
Expand Down Expand Up @@ -114,7 +136,6 @@ def pretty_subplot(nrows, ncols, width=5, height=5, sharex=True,
:obj:`matplotlib.pyplot`: A :obj:`matplotlib.pyplot` subplot object
with publication ready defaults set.
"""
from matplotlib import rc

# TODO: Make this work if plt is already set...
if plt is None:
Expand All @@ -139,12 +160,7 @@ def pretty_subplot(nrows, ncols, width=5, height=5, sharex=True,
ax.set_xlabel(ax.get_xlabel(), size=_labelsize)
ax.set_ylabel(ax.get_ylabel(), size=_labelsize)

fonts = default_fonts if fonts is None else fonts + default_fonts

rc('font', **{'family': 'sans-serif', 'sans-serif': fonts})
rc('text', usetex=False)
rc('pdf', fonttype=42)
rc('mathtext', fontset='stixsans')
_setup_fonts(fonts=fonts)
rc('legend', handlelength=1.5)
return plt

Expand Down

0 comments on commit ff1402d

Please sign in to comment.