From 268591cf8e03a8d23256efe711e5b3c3a4b71ce6 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Mar 2024 03:24:49 -0500 Subject: [PATCH] pt: make get_data non-blocking (#3422) `to(DEVICE)` is CPU-blocking, but `to(DEVICE, non_blocking=True)` is not blocking. This improves performance by at least 0.1s/100 steps. Before, `get_data` was blocking: ![1709698811097](https://github.com/deepmodeling/deepmd-kit/assets/9496702/b86b3928-41e7-46d3-8692-ca96b3a6475a) ![1709698811150](https://github.com/deepmodeling/deepmd-kit/assets/9496702/c4365203-3f3d-4de8-aae6-d8587f0e95a0) After, `get_data` is not blocking: ![1709698811122](https://github.com/deepmodeling/deepmd-kit/assets/9496702/d991c8f0-35c8-4b5d-822e-77af961e9b6e) ![1709698811169](https://github.com/deepmodeling/deepmd-kit/assets/9496702/a56160c2-78c7-4a44-aa96-1df0b520a60a) The next blocking operation is `phys2inter` (via `torch.linalg.inv`). Signed-off-by: Jinzhe Zeng --- deepmd/pt/train/training.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 6938db9b3c..93afc38575 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -973,9 +973,11 @@ def get_data(self, is_train=True, task_key="Default"): continue elif not isinstance(batch_data[key], list): if batch_data[key] is not None: - batch_data[key] = batch_data[key].to(DEVICE) + batch_data[key] = batch_data[key].to(DEVICE, non_blocking=True) else: - batch_data[key] = [item.to(DEVICE) for item in batch_data[key]] + batch_data[key] = [ + item.to(DEVICE, non_blocking=True) for item in batch_data[key] + ] # we may need a better way to classify which are inputs and which are labels # now wrapper only supports the following inputs: input_keys = [