diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 6938db9b3c..93afc38575 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -973,9 +973,11 @@ def get_data(self, is_train=True, task_key="Default"): continue elif not isinstance(batch_data[key], list): if batch_data[key] is not None: - batch_data[key] = batch_data[key].to(DEVICE) + batch_data[key] = batch_data[key].to(DEVICE, non_blocking=True) else: - batch_data[key] = [item.to(DEVICE) for item in batch_data[key]] + batch_data[key] = [ + item.to(DEVICE, non_blocking=True) for item in batch_data[key] + ] # we may need a better way to classify which are inputs and which are labels # now wrapper only supports the following inputs: input_keys = [