Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix) Make bias statistics complete for all elements #4496

Draft
wants to merge 2 commits into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ def __getitem__(self, index):
b_data["natoms"] = self._natoms_vec
return b_data

def _build_element_to_frames(self):

Check warning

Code scanning / CodeQL

Unreachable code Warning

This statement is unreachable.
"""Mapping element types to frame indexes"""
element_to_frames = {element: [] for element in range(self._ntypes)}
for frame_idx in range(len(self)):
frame_data = self._data_system.get_item_torch(frame_idx)

elements = frame_data["atype"]
for element in set(elements):
if len(element_to_frames[element]) < 10:
element_to_frames[element].append(frame_idx)
return element_to_frames

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix indentation error and consider making the max frame limit configurable.
Static analysis flags a SyntaxError likely due to insufficient indentation after defining the method. Ensure this block is indented so that Python recognizes it as part of the method. Also, the hard-coded limit of 10 frames per element may need to be a configurable parameter if you expect variation in your datasets.

Example indentation fix:

-        def _build_element_to_frames(self):
-        """Mapping element types to frame indexes"""
-        element_to_frames = {element: [] for element in range(self._ntypes)} 
+    def _build_element_to_frames(self):
+        """Mapping element types to frame indexes"""
+        element_to_frames = {element: [] for element in range(self._ntypes)}

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.8.2)

44-44: SyntaxError: Expected an indented block after function definition

def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
"""Add data requirement for this data system."""
for data_item in data_requirement:
Expand Down
35 changes: 35 additions & 0 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,41 @@ def make_stat_input(datasets, dataloaders, nbatches):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)
lst.append(sys_stat)

all_elements = set()
if datasets and hasattr(datasets[0], 'element_to_frames'):
all_elements.update(datasets[0].element_to_frames.keys())
print('we want', all_elements)

collected_elements = set()
for sys_stat in lst:
if 'atype' in sys_stat:
collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy()))
missing_elements = all_elements - collected_elements

for missing_element in missing_elements:
for i, dataset in enumerate(datasets):
if hasattr(dataset, 'element_to_frames'):
frame_indices = dataset.element_to_frames.get(missing_element, [])
for frame_idx in frame_indices:
if len(lst[i]['atype']) >= nbatches:
break
frame_data = dataset[frame_idx]
for key in frame_data:
if key not in lst[i]:
lst[i][key] = []
lst[i][key].append(frame_data[key])

collected_elements = set()
for sys_stat in lst:
if 'atype' in sys_stat:
collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy()))

Check notice

Code scanning / CodeQL

Nested loops with same variable Note

Nested for statement uses loop variable 'sys_stat' of enclosing
for statement
.
for sys_stat in lst:
for key in sys_stat:
if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)

return lst


Expand Down
Loading