diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 3043839308..df4f1fc6cd 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -40,6 +40,18 @@ def __getitem__(self, index): b_data["natoms"] = self._natoms_vec return b_data + def _build_element_to_frames(self): + """Mapping element types to frame indexes""" + element_to_frames = {element: [] for element in range(self._ntypes)} + for frame_idx in range(len(self)): + frame_data = self._data_system.get_item_torch(frame_idx) + + elements = frame_data["atype"] + for element in set(elements): + if len(element_to_frames[element]) < 10: + element_to_frames[element].append(frame_idx) + return element_to_frames + def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" for data_item in data_requirement: diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 1c5e3f1c52..82cb816f7b 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -82,6 +82,41 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) lst.append(sys_stat) + + all_elements = set() + if datasets and hasattr(datasets[0], 'element_to_frames'): + all_elements.update(datasets[0].element_to_frames.keys()) + print('we want', all_elements) + + collected_elements = set() + for sys_stat in lst: + if 'atype' in sys_stat: + collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + missing_elements = all_elements - collected_elements + + for missing_element in missing_elements: + for i, dataset in enumerate(datasets): + if hasattr(dataset, 'element_to_frames'): + frame_indices = dataset.element_to_frames.get(missing_element, []) + for frame_idx in frame_indices: + if len(lst[i]['atype']) >= nbatches: + break + frame_data = dataset[frame_idx] + for key in frame_data: + if key not in lst[i]: + lst[i][key] = [] + lst[i][key].append(frame_data[key]) + + collected_elements = set() + for sys_stat in lst: + if 'atype' in sys_stat: + collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + + for sys_stat in lst: + for key in sys_stat: + if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor): + sys_stat[key] = torch.cat(sys_stat[key], dim=0) + return lst