Skip to content

Commit

Permalink
4424
Browse files Browse the repository at this point in the history
  • Loading branch information
SumGuo-88 committed Dec 23, 2024
1 parent e695a91 commit 32da243
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
12 changes: 12 additions & 0 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ def __getitem__(self, index):
b_data["natoms"] = self._natoms_vec
return b_data

def _build_element_to_frames(self):
"""Mapping element types to frame indexes"""
element_to_frames = {element: [] for element in range(self._ntypes)}
for frame_idx in range(len(self)):
frame_data = self._data_system.get_item_torch(frame_idx)

elements = frame_data["atype"]
for element in set(elements):
if len(element_to_frames[element]) < 10:
element_to_frames[element].append(frame_idx)
return element_to_frames

def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
"""Add data requirement for this data system."""
for data_item in data_requirement:
Expand Down
35 changes: 35 additions & 0 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,41 @@ def make_stat_input(datasets, dataloaders, nbatches):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)
lst.append(sys_stat)

all_elements = set()
if datasets and hasattr(datasets[0], 'element_to_frames'):
all_elements.update(datasets[0].element_to_frames.keys())
print('we want', all_elements)

collected_elements = set()
for sys_stat in lst:
if 'atype' in sys_stat:
collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy()))
missing_elements = all_elements - collected_elements

for missing_element in missing_elements:
for i, dataset in enumerate(datasets):
if hasattr(dataset, 'element_to_frames'):
frame_indices = dataset.element_to_frames.get(missing_element, [])
for frame_idx in frame_indices:
if len(lst[i]['atype']) >= nbatches:
break
frame_data = dataset[frame_idx]
for key in frame_data:
if key not in lst[i]:
lst[i][key] = []
lst[i][key].append(frame_data[key])

collected_elements = set()
for sys_stat in lst:
if 'atype' in sys_stat:
collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy()))

for sys_stat in lst:
for key in sys_stat:
if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)

return lst


Expand Down

0 comments on commit 32da243

Please sign in to comment.