Skip to content

Commit

Permalink
pt: non-blocking copying training data to device
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Mar 6, 2024
1 parent b0171ce commit f1e1ba6
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,9 +954,11 @@ def get_data(self, is_train=True, task_key="Default"):
continue
elif not isinstance(batch_data[key], list):
if batch_data[key] is not None:
batch_data[key] = batch_data[key].to(DEVICE)
batch_data[key] = batch_data[key].to(DEVICE, non_blocking=True)
else:
batch_data[key] = [item.to(DEVICE) for item in batch_data[key]]
batch_data[key] = [

Check warning on line 959 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L959

Added line #L959 was not covered by tests
item.to(DEVICE, non_blocking=True) for item in batch_data[key]
]
# we may need a better way to classify which are inputs and which are labels
# now wrapper only supports the following inputs:
input_keys = [
Expand Down

0 comments on commit f1e1ba6

Please sign in to comment.