diff --git a/src/aiidalab_qe/common/bandpdoswidget.py b/src/aiidalab_qe/common/bandpdoswidget.py index e76de7673..51b737318 100644 --- a/src/aiidalab_qe/common/bandpdoswidget.py +++ b/src/aiidalab_qe/common/bandpdoswidget.py @@ -628,7 +628,8 @@ def __init__(self, bands=None, pdos=None, **kwargs): # Set the event handlers self.download_button.on_click(self.download_data) self.update_plot_button.on_click(self._update_plot) - # self.proj_bands_width_slider.observe(self._update_plot, names='value') + self.proj_bands_width_slider.observe(self._update_plot, names="value") + self.project_bands_box.observe(self._update_plot, names="value") self.dos_atoms_group.observe(self._update_plot, names="value") self.dos_plot_group.observe(self._update_plot, names="value") @@ -888,6 +889,10 @@ def get_bands_projections_data( bands_data["projected_bands"] = _prepare_projections_to_plot( bands_data, projections, bands_width ) + if plot_tag != "total": + bands_data["projected_bands"] = update_pdos_labels( + bands_data["projected_bands"] + ) return bands_data @@ -1287,7 +1292,9 @@ def get_labels_radial_nodes(pdos_dict): original_labels = [] orbital_dict = {} - for label_data in pdos_dict["dos"]: + label_data_list = pdos_dict["dos"] if "dos" in pdos_dict else pdos_dict + for label_data in label_data_list: + # for label_data in pdos_dict["dos"]: label_str = label_data["label"] original_labels.append(label_str) @@ -1439,7 +1446,8 @@ def update_pdos_labels(pdos_data): orbital_assignment = assign_orbital_labels(orbital_dict) updated_labels = get_new_pdos_labels(original_labels, orbital_assignment) + label_data_list = pdos_data["dos"] if "dos" in pdos_data else pdos_data for idx, label in enumerate(updated_labels): - pdos_data["dos"][idx]["label"] = label + label_data_list[idx]["label"] = label return pdos_data