Skip to content

Commit

Permalink
fix p argument calling
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz authored Jan 28, 2024
1 parent b3f0b8d commit 893d904
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ def __len__(self):
def __getitem__(self, index=None):
"""Get a batch of frames from the selected system."""
if index is None:
index = dp_random.choice(np.arange(self.nsystems), self.probs)
index = dp_random.choice(np.arange(self.nsystems), p=self.probs)
b_data = self._data_systems[index].get_batch(self._batch_size)
b_data["natoms"] = torch.tensor(
self._natoms_vec[index], device=env.PREPROCESS_DEVICE
Expand All @@ -892,7 +892,7 @@ def __getitem__(self, index=None):
def get_training_batch(self, index=None):
"""Get a batch of frames from the selected system."""
if index is None:
index = dp_random.choice(np.arange(self.nsystems), self.probs)
index = dp_random.choice(np.arange(self.nsystems), p=self.probs)

Check warning on line 895 in deepmd/pt/utils/dataset.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/dataset.py#L895

Added line #L895 was not covered by tests
b_data = self._data_systems[index].get_batch_for_train(self._batch_size)
b_data["natoms"] = torch.tensor(
self._natoms_vec[index], device=env.PREPROCESS_DEVICE
Expand Down

0 comments on commit 893d904

Please sign in to comment.