Skip to content

Commit

Permalink
Improve plotting SHEAP output
Browse files Browse the repository at this point in the history
  • Loading branch information
zhubonan committed Jul 3, 2022
1 parent e4a1110 commit 6b18da2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ max-bool-expr=5
max-branches=12

# Maximum number of locals for function / method body
max-locals=20
max-locals=40

# Maximum number of parents for a class (see R0901).
max-parents=20
Expand Down
25 changes: 18 additions & 7 deletions disp/cli/cmd_tools.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""
Collection of useful tools
"""
from json import dumps
from json import dumps, load
from pathlib import Path
import click
from ase.io import read
from disp.tools.modcell import modify_cell
from disp.tools.sheapio import sheap_to_dict, parse_sheap_output
from disp.tools.sheapio import SheapOut, sheap_to_dict, parse_sheap_output

# pylint: disable=import-outside-toplevel

Expand Down Expand Up @@ -93,21 +94,27 @@ def cmd_sheap2json(sheapout, path):
@click.option('--vmax',
help='Color scale maximum relative to the minimum value',
default=0.25)
def plot_sheap(sheapout, vmax):
@click.option('--savename', default='sheap-map.pdf')
@click.option('--plot/--no-plot', default=True)
def plot_sheap(sheapout, vmax, savename, plot):
"""
Plot the output of SHEAP as spheres, respecting the specification of radius output.
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

with open(sheapout) as handle:
parsed = parse_sheap_output(handle)
if 'json' in sheapout:
with open(sheapout) as handle:
parsed = SheapOut(**load(handle))
else:
with open(sheapout) as handle:
parsed = parse_sheap_output(handle)
coords = np.array(parsed.coords)
radius = np.array(parsed.radius)

# Plot the output
_, axes = plt.subplots(1, 1)
fig, axes = plt.subplots(1, 1)
axes.set_aspect('equal')

# Compute the colours
Expand All @@ -126,4 +133,8 @@ def plot_sheap(sheapout, vmax):
axes.set_xlim(min(xmin, ymin) - 0.1, max(xmax, ymax) + 0.1)
axes.set_ylim(min(xmin, ymin) - 0.1, max(xmax, ymax) + 0.1)
plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap))
plt.show()
if plot:
plt.show()
fig.savefig(savename, dpi=200)
# Save the raw data as json
Path(savename).with_suffix('.json').write_text(dumps(parsed._asdict()))

0 comments on commit 6b18da2

Please sign in to comment.