Skip to content

Commit

Permalink
Added option to export dos data into csv,txt,dat,json
Browse files Browse the repository at this point in the history
  • Loading branch information
lllangWV committed Jul 18, 2024
1 parent f310a7f commit 954e6c9
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 21 deletions.
148 changes: 134 additions & 14 deletions pyprocar/plotter/dos_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

import os
import yaml
import json
from typing import List

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pylab as plt
from matplotlib.collections import LineCollection
Expand Down Expand Up @@ -55,6 +57,7 @@ def __init__(self,
self.handles = []
self.labels = []
self.orientation = orientation
self.values_dict={}


if ax is None:
Expand All @@ -76,14 +79,19 @@ def plot_dos(self, spins: List[int] = None):
energies = self.dos.energies
dos_total = self.dos.total



self._set_plot_limits(spin_channels)
for ispin, spin_channel in enumerate(spin_channels):

# flip the sign of the total dos if there are 2 spin channels
dos_total_spin = dos_total[spin_channel, :] * (-1 if ispin > 0 else 1)
values_dict['dos_total_spin-'+str(spin_channel)]=dos_total_spin

self._plot_total_dos(energies, dos_total_spin, spin_channel)
values_dict['energies']=energies
values_dict['dosTotalSpin-'+str(spin_channel)]=dos_total_spin


self.values_dict=values_dict
return values_dict

def plot_parametric(self,
Expand All @@ -98,7 +106,11 @@ def plot_parametric(self,
spin_projections,
principal_q_numbers)


orbital_string=':'.join([str(orbital) for orbital in orbitals])
atom_string=':'.join([str(atom) for atom in atoms])
spin_string=':'.join([str(spin_projection) for spin_projection in spin_projections])


self._setup_colorbar(dos_projected, dos_total_projected)
self._set_plot_limits(spin_channels)

Expand All @@ -114,10 +126,15 @@ def plot_parametric(self,
if self.config.plot_total:
self._plot_total_dos(energies, dos_spin_total, spin_channel)

values_dict['spin-'+str(spin_channel)+'_projections']=normalized_dos_spin_projected
values_dict['energies']=energies
values_dict['dos_total_spin-'+str(spin_channel)]=dos_spin_total
values_dict['dosTotalSpin-'+str(spin_channel)]=dos_spin_total
values_dict['spinChannel-'+str(spin_channel) +
f'_orbitals-{orbital_string}' +
f'_atoms-{atom_string}' +
f'_spinProjection-{spin_string}'] =normalized_dos_spin_projected


self.values_dict=values_dict
return values_dict

def plot_parametric_line(self,
Expand All @@ -132,6 +149,10 @@ def plot_parametric_line(self,
spin_projections,
principal_q_numbers)

orbital_string=':'.join([str(orbital) for orbital in orbitals])
atom_string=':'.join([str(atom) for atom in atoms])
spin_string=':'.join([str(spin_projection) for spin_projection in spin_projections])

self._setup_colorbar(dos_projected, dos_total_projected)
self._set_plot_limits(spin_channels)

Expand All @@ -145,10 +166,14 @@ def plot_parametric_line(self,

self._plot_spin_data_parametric_line(energies, dos_spin_total, normalized_dos_spin_projected, spin_channel)

values_dict['spin-'+str(spin_channel)+'_projections']=normalized_dos_spin_projected
values_dict['energies']=energies
values_dict['dos_total_spin-'+str(spin_channel)]=dos_spin_total
values_dict['dosTotalSpin-'+str(spin_channel)]=dos_spin_total
values_dict['spinChannel-'+str(spin_channel) +
f'_orbitals-{orbital_string}' +
f'_atoms-{atom_string}' +
f'_spinProjection-{spin_string}']=normalized_dos_spin_projected

self.values_dict=values_dict
return values_dict

def plot_stack_species(
Expand All @@ -168,13 +193,22 @@ def plot_stack_species(
for specie in range(len(self.structure.species)):
idx = (np.array(self.structure.atoms) == self.structure.species[specie])
atoms = list(np.where(idx)[0])

orbital_string=':'.join([str(orbital) for orbital in orbitals])
atom_string=':'.join([str(atom) for atom in atoms])
spin_string=':'.join([str(spin_projection) for spin_projection in spin_projections])


dos_total, dos_total_projected, dos_projected = self._calculate_parametric_dos(
atoms,
orbitals,
spin_projections,
principal_q_numbers)

color=self.config.colors[specie]



for ispin, spin_channel in enumerate(spin_channels):
energies, dos_spin_total, scaled_dos_spin_projected = self._prepare_parametric_spin_data(spin_channel,
ispin,
Expand All @@ -198,16 +232,24 @@ def plot_stack_species(
bottom_value+=top_value

label=self.structure.species[specie] + orbital_label
values_dict[label+'_spin-'+str(spin_channel)+'_projections']=scaled_dos_spin_projected

values_dict['energies']=energies
values_dict['dosTotalSpin-'+str(spin_channel)]=dos_spin_total
values_dict['spinChannel-'+str(spin_channel) +
f'_orbitals-{orbital_string}' +
f'_atoms-{atom_string}' +
f'_spinProjection-{spin_string}']=scaled_dos_spin_projected




self.handles.append(handle)
self.labels.append(label)

if self.config.plot_total:
total_values_dict=self.plot_dos(spin_channels)
values_dict.update(total_values_dict)

self.values_dict=values_dict
return values_dict

def plot_stack_orbitals(
Expand All @@ -226,6 +268,10 @@ def plot_stack_orbitals(
bottom_value=0
for iorb in range(len(orb_l)):

orbital_string=':'.join([str(orbital) for orbital in orb_l[iorb]])
atom_string=':'.join([str(atom) for atom in atoms])
spin_string=':'.join([str(spin_projection) for spin_projection in spin_projections])

dos_total, dos_total_projected, dos_projected = self._calculate_parametric_dos(
atoms=atoms,
orbitals=orb_l[iorb],
Expand Down Expand Up @@ -256,15 +302,21 @@ def plot_stack_orbitals(
bottom_value+=top_value

label=atom_names + orb_names[iorb]# + self.config.spin_labels[ispin]
values_dict[label+'_spin-'+str(spin_channel)+'_projections']=scaled_dos_spin_projected

values_dict['energies']=energies
values_dict['dosTotalSpin-'+str(spin_channel)]=dos_spin_total
values_dict['spinChannel-'+str(spin_channel) +
f'_orbitals-{orbital_string}' +
f'_atoms-{atom_string}' +
f'_spinProjection-{spin_string}']=scaled_dos_spin_projected

self.handles.append(handle)
self.labels.append(label)

if self.config.plot_total:
total_values_dict=self.plot_dos(spin_channels)
values_dict.update(total_values_dict)

self.values_dict=values_dict
return values_dict

def plot_stack(
Expand Down Expand Up @@ -297,6 +349,10 @@ def plot_stack(
orbitals = items[specie]
orbital_label=self._get_stack_labels(orbitals)

orbital_string=':'.join([str(orbital) for orbital in orbitals])
atom_string=':'.join([str(atom) for atom in atoms])
spin_string=':'.join([str(spin_projection) for spin_projection in spin_projections])

dos_total, dos_total_projected, dos_projected = self._calculate_parametric_dos(
atoms=atoms,
orbitals=orbitals,
Expand Down Expand Up @@ -327,17 +383,22 @@ def plot_stack(
bottom_value+=top_value

label=specie + orbital_label

values_dict[label+'_spin-'+str(spin_channel)+'_projections']=scaled_dos_spin_projected
values_dict['energies']=energies
values_dict['dosTotalSpin-'+str(spin_channel)]=dos_spin_total
values_dict['spinChannel-'+str(spin_channel) +
f'_orbitals-{orbital_string}' +
f'_atoms-{atom_string}' +
f'_spinProjection-{spin_string}']=scaled_dos_spin_projected



self.handles.append(handle)
self.labels.append(label)

if self.config.plot_total:
total_values_dict=self.plot_dos(spin_channels)
values_dict.update(total_values_dict)

self.values_dict=values_dict

return values_dict

Expand Down Expand Up @@ -916,4 +977,63 @@ def update_config(self, config_dict):
for key,value in config_dict.items():
self.config[key]['value']=value

def export_data(self,filename):
"""
This method will export the data to a csv file
Parameters
----------
filename : str
The file name to export the data to
Returns
-------
None
None
"""
possible_file_types=['csv','txt','json','dat']
file_type=filename.split('.')[-1]
if file_type not in possible_file_types:
raise ValueError(f"The file type must be {possible_file_types}")
if self.values_dict is None:
raise ValueError("The data has not been plotted yet")

column_names=list(self.values_dict.keys())
sorted_column_names=[None]*len(column_names)
index=0
for column_name in column_names:
if 'energies' in column_name.split('_')[0]:
sorted_column_names[index]=column_name
index+=1

for column_name in column_names:
if 'dosTotalSpin' in column_name.split('_')[0]:
sorted_column_names[index]=column_name
index+=1
for ispin in range(2):
for column_name in column_names:

if 'spinChannel-0' in column_name.split('_')[0] and ispin==0:
sorted_column_names[index]=column_name
index+=1
if 'spinChannel-1' in column_name.split('_')[0] and ispin==1:
sorted_column_names[index]=column_name
index+=1



column_names.sort()
if file_type=='csv':
df=pd.DataFrame(self.values_dict)
df.to_csv(filename, columns=sorted_column_names, index=False)
elif file_type=='txt':
df=pd.DataFrame(self.values_dict)
df.to_csv(filename, columns=sorted_column_names, sep='\t', index=False)
elif file_type=='json':
with open(filename, 'w') as outfile:
for key,value in self.values_dict.items():
self.values_dict[key]=value.tolist()
json.dump(self.values_dict, outfile)
elif file_type=='dat':
df=pd.DataFrame(self.values_dict)
df.to_csv(filename, columns=sorted_column_names, sep=' ', index=False)
34 changes: 27 additions & 7 deletions pyprocar/scripts/scriptDosplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def dosplot(
ax:plt.Axes=None,
show:bool=True,
print_plot_opts:bool=False,
export_data_file:str=None,
export_append_mode:bool=True,
**kwargs
):

Expand Down Expand Up @@ -230,6 +232,14 @@ def dosplot(
e.g. ``plt_show=True``
export_data_file : str, optional
The file name to export the data to. If not provided the
data will not be exported.
export_append_mode : bool, optional
Boolean to append the mode to the file name. If not provided the
data will be overwritten.
print_plot_opts: bool, optional
Boolean to print the plotting options
Expand Down Expand Up @@ -323,43 +333,43 @@ def dosplot(
if orbitals is None:
orbitals = list(np.arange(len(edos_plot.dos.projected[0][0]), dtype=int))

values_dict = edos_plot.plot_parametric_line(
edos_plot.plot_parametric_line(
atoms=atoms,
principal_q_numbers=[-1],
spins=spins,
orbitals=orbitals,
)

elif mode == "stack_species":
values_dict = edos_plot.plot_stack_species(
edos_plot.plot_stack_species(
spins=spins,
orbitals=orbitals,
)
elif mode == "stack_orbitals":
values_dict = edos_plot.plot_stack_orbitals(
edos_plot.plot_stack_orbitals(
spins=spins,
atoms=atoms,
)
elif mode == "stack":
values_dict = edos_plot.plot_stack(
edos_plot.plot_stack(
spins=spins,
items=items,
)
elif mode == "overlay_species":
values_dict = edos_plot.plot_stack_species(
edos_plot.plot_stack_species(
spins=spins,
orbitals=orbitals,

overlay_mode=True
)
elif mode == "overlay_orbitals":
values_dict = edos_plot.plot_stack_orbitals(
edos_plot.plot_stack_orbitals(
spins=spins,
atoms=atoms,
overlay_mode=True
)
elif mode == "overlay":
values_dict = edos_plot.plot_stack(
edos_plot.plot_stack(
spins=spins,
items=items,
overlay_mode=True
Expand Down Expand Up @@ -399,4 +409,14 @@ def dosplot(
edos_plot.save(savefig)
if show:
edos_plot.show()

if export_data_file is not None:
if export_append_mode:
file_basename,file_type=export_data_file.split('.')
filename=f"{file_basename}_{mode}.{file_type}"
else:
filename=export_data_file
edos_plot.export_data(filename)


return edos_plot.fig, edos_plot.ax
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ scikit-image
scipy
seekpath
spglib
pandas
trimesh
ase
sympy
Expand Down

0 comments on commit 954e6c9

Please sign in to comment.