Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 17, 2024
1 parent 5361493 commit 90dcd5b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 15 deletions.
20 changes: 13 additions & 7 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import (
List,
)
import numpy as np

import h5py
import torch
import torch.distributed as dist
Expand All @@ -22,12 +22,13 @@
Dataset,
WeightedRandomSampler,
)
from torch.utils.data._utils.collate import (

Check warning on line 25 in deepmd/pt/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/dataloader.py#L25

Added line #L25 was not covered by tests
collate_tensor_fn,
)
from torch.utils.data.distributed import (
DistributedSampler,
)
from torch.utils.data._utils.collate import (
collate_tensor_fn
)

from deepmd.pt.model.descriptor import (
Descriptor,
)
Expand Down Expand Up @@ -231,6 +232,7 @@ def __next__(self):
raise StopIteration
return item


def collate_batch(batch):
example = batch[0]
result = {}

Check warning on line 238 in deepmd/pt/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/dataloader.py#L238

Added line #L238 was not covered by tests
Expand All @@ -242,10 +244,14 @@ def collate_batch(batch):
result[key] = None
elif key == "fid":
result[key] = [d[key] for d in batch]
elif key == 'type':
result['atype'] = collate_tensor_fn([torch.as_tensor(d[key]) for d in batch])
elif key == "type":
result["atype"] = collate_tensor_fn(

Check warning on line 248 in deepmd/pt/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/dataloader.py#L247-L248

Added lines #L247 - L248 were not covered by tests
[torch.as_tensor(d[key]) for d in batch]
)
else:
result[key] = collate_tensor_fn([torch.as_tensor(d[key]) for d in batch])
result[key] = collate_tensor_fn(

Check warning on line 252 in deepmd/pt/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/dataloader.py#L252

Added line #L252 was not covered by tests
[torch.as_tensor(d[key]) for d in batch]
)
return result


Expand Down
5 changes: 2 additions & 3 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
DeepmdData,
)


class DeepmdDataSystem:
def __init__(
self,
Expand Down Expand Up @@ -574,9 +575,7 @@ def __init__(
"""
self._type_map = type_map
self._data_system = DeepmdData(

Check warning on line 577 in deepmd/pt/utils/dataset.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/dataset.py#L577

Added line #L577 was not covered by tests
sys_path=system,
shuffle_test=shuffle,
type_map=self._type_map
sys_path=system, shuffle_test=shuffle, type_map=self._type_map
)
self._data_system.add("energy", 1, atomic=False, must=False, high_prec=True)
self._data_system.add("force", 3, atomic=True, must=False, high_prec=False)
Expand Down
11 changes: 6 additions & 5 deletions deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def check_test_size(self, test_size):
return self.test_dir, tmpe.shape[0]
else:
return None
def get_item(self,index:int) -> dict:

def get_item(self, index: int) -> dict:
"""Get a single frame data . The frame is picked from the data system by index.
Parameters
Expand All @@ -275,7 +276,7 @@ def get_item(self,index:int) -> dict:
frame = self.preprocess(frame)
frame["fid"] = index
return frame

Check warning on line 278 in deepmd/utils/data.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/data.py#L272-L278

Added lines #L272 - L278 were not covered by tests

def get_batch(self, batch_size: int) -> dict:
"""Get a batch of data with `batch_size` frames. The frames are randomly picked from the data system.
Expand Down Expand Up @@ -466,8 +467,8 @@ def _shuffle_data(self, data):
else:
ret[kk] = data[kk]
return ret, idx
def _get_nframes(self,set_name:DPPath):

def _get_nframes(self, set_name: DPPath):
# get nframes
if not isinstance(set_name, DPPath):
set_name = DPPath(set_name)
Expand All @@ -480,7 +481,7 @@ def _get_nframes(self,set_name:DPPath):
coord = coord.reshape([1, -1])
nframes = coord.shape[0]
return nframes

Check warning on line 483 in deepmd/utils/data.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/data.py#L479-L483

Added lines #L479 - L483 were not covered by tests

def preprocess(self, data):
for kk in self.data_dict.keys():
if "find_" in kk:
Expand Down

0 comments on commit 90dcd5b

Please sign in to comment.