From 32da24366e5dbb6a7c271540cf4fe5198ba2837a Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 23 Dec 2024 20:30:40 +0800 Subject: [PATCH 1/2] 4424 --- deepmd/pt/utils/dataset.py | 12 ++++++++++++ deepmd/pt/utils/stat.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 3043839308..df4f1fc6cd 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -40,6 +40,18 @@ def __getitem__(self, index): b_data["natoms"] = self._natoms_vec return b_data + def _build_element_to_frames(self): + """Mapping element types to frame indexes""" + element_to_frames = {element: [] for element in range(self._ntypes)} + for frame_idx in range(len(self)): + frame_data = self._data_system.get_item_torch(frame_idx) + + elements = frame_data["atype"] + for element in set(elements): + if len(element_to_frames[element]) < 10: + element_to_frames[element].append(frame_idx) + return element_to_frames + def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" for data_item in data_requirement: diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 1c5e3f1c52..82cb816f7b 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -82,6 +82,41 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) lst.append(sys_stat) + + all_elements = set() + if datasets and hasattr(datasets[0], 'element_to_frames'): + all_elements.update(datasets[0].element_to_frames.keys()) + print('we want', all_elements) + + collected_elements = set() + for sys_stat in lst: + if 'atype' in sys_stat: + collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + missing_elements = all_elements - collected_elements + + for missing_element in missing_elements: + for i, dataset in enumerate(datasets): + if hasattr(dataset, 'element_to_frames'): + frame_indices = dataset.element_to_frames.get(missing_element, []) + for frame_idx in frame_indices: + if len(lst[i]['atype']) >= nbatches: + break + frame_data = dataset[frame_idx] + for key in frame_data: + if key not in lst[i]: + lst[i][key] = [] + lst[i][key].append(frame_data[key]) + + collected_elements = set() + for sys_stat in lst: + if 'atype' in sys_stat: + collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + + for sys_stat in lst: + for key in sys_stat: + if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor): + sys_stat[key] = torch.cat(sys_stat[key], dim=0) + return lst From adf2315d57730d2cdfe2a4244a2d2e167bcb45e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Dec 2024 12:35:16 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/dataset.py | 6 +++--- deepmd/pt/utils/stat.py | 26 ++++++++++++++++---------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index df4f1fc6cd..b5fd8c58a0 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -42,13 +42,13 @@ def __getitem__(self, index): def _build_element_to_frames(self): """Mapping element types to frame indexes""" - element_to_frames = {element: [] for element in range(self._ntypes)} + element_to_frames = {element: [] for element in range(self._ntypes)} for frame_idx in range(len(self)): frame_data = self._data_system.get_item_torch(frame_idx) - elements = frame_data["atype"] + elements = frame_data["atype"] for element in set(elements): - if len(element_to_frames[element]) < 10: + if len(element_to_frames[element]) < 10: element_to_frames[element].append(frame_idx) return element_to_frames diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 82cb816f7b..e3e7410a21 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -84,22 +84,24 @@ def make_stat_input(datasets, dataloaders, nbatches): lst.append(sys_stat) all_elements = set() - if datasets and hasattr(datasets[0], 'element_to_frames'): + if datasets and hasattr(datasets[0], "element_to_frames"): all_elements.update(datasets[0].element_to_frames.keys()) - print('we want', all_elements) + print("we want", all_elements) collected_elements = set() for sys_stat in lst: - if 'atype' in sys_stat: - collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + if "atype" in sys_stat: + collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy())) missing_elements = all_elements - collected_elements for missing_element in missing_elements: for i, dataset in enumerate(datasets): - if hasattr(dataset, 'element_to_frames'): - frame_indices = dataset.element_to_frames.get(missing_element, []) + if hasattr(dataset, "element_to_frames"): + frame_indices = dataset.element_to_frames.get( + missing_element, [] + ) for frame_idx in frame_indices: - if len(lst[i]['atype']) >= nbatches: + if len(lst[i]["atype"]) >= nbatches: break frame_data = dataset[frame_idx] for key in frame_data: @@ -109,12 +111,16 @@ def make_stat_input(datasets, dataloaders, nbatches): collected_elements = set() for sys_stat in lst: - if 'atype' in sys_stat: - collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + if "atype" in sys_stat: + collected_elements.update( + np.unique(sys_stat["atype"].cpu().numpy()) + ) for sys_stat in lst: for key in sys_stat: - if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor): + if isinstance(sys_stat[key], list) and isinstance( + sys_stat[key][0], torch.Tensor + ): sys_stat[key] = torch.cat(sys_stat[key], dim=0) return lst