diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 7fcdea4efb..0f2046af25 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1052,7 +1052,9 @@ def get_data(self, is_train=True, task_key="Default"): except StopIteration: # Refresh the status of the dataloader to start from a new epoch with torch.device("cpu"): - self.training_data[task_key] = self.training_dataloader[task_key] + self.training_data[task_key] = self.training_dataloader[ + task_key + ] batch_data = next(iter(self.training_data[task_key])) else: if self.validation_data[task_key] is None: