pt: avoid torch.tensor(constant) during forward (#3421)

`torch.tensor(constant)` creates the constant on the CPU and copies it to the GPU, so it is host-blocking and should be avoided in the `forward` method.
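
A minimal standalone sketch of the pattern (not the repository's code; the `device` and `dtype` below are illustrative):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Host-blocking: the Python float 0.0 lives in CPU memory, so PyTorch has to
# copy it to the GPU before the result can be used.
loss = torch.tensor(0.0, dtype=torch.float64, device=device)

# Pattern applied by this commit: allocate the scalar directly on the device
# and index out the first element, so no host-to-device copy is issued.
loss = torch.zeros(1, dtype=torch.float64, device=device)[0]
```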


Before this change, the CPU waited for the GPU via `cudaStreamSynchronize` whenever CPU memory had to be copied to the GPU, a.k.a. host-to-device (H2D), which blocked the CPU from launching the following operations.


![1709693858444](https://github.com/deepmodeling/deepmd-kit/assets/9496702/e6fb6281-245f-4620-82bd-dbcd02121e32)

After this PR, all ops in the energy loss are launched asynchronously, since no H2D copy occurs.

![1709694622120](https://github.com/deepmodeling/deepmd-kit/assets/9496702/172e1601-1a9c-4236-a1e2-a749edc25c50)
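
Traces like the ones above come from the PyTorch profiler; a minimal sketch of how to collect a comparable trace (the tensors and loss computation below are placeholders, not the actual training loop):

```python
import torch
from torch.profiler import ProfilerActivity, profile

device = torch.device("cuda")
pred = torch.randn(32, 1, device=device)
label = torch.randn(32, 1, device=device)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    # Old pattern: the constant is copied host-to-device and the stream is synchronized.
    loss_old = torch.tensor(0.0, dtype=pred.dtype, device=device)
    # New pattern: the scalar is created on the device, so no H2D copy is queued.
    loss_new = torch.zeros(1, dtype=pred.dtype, device=device)[0]
    loss_new = loss_new + torch.nn.functional.mse_loss(pred, label)

# Memcpy HtoD / cudaStreamSynchronize entries should show up only for the old pattern.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```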

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
njzjz and pre-commit-ci[bot] authored Mar 8, 2024
1 parent 66edd1f commit d3dd604
Showing 4 changed files with 17 additions and 15 deletions.
14 changes: 7 additions & 7 deletions deepmd/pt/loss/denoise.py
@@ -52,7 +52,7 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
coord_mask = label["coord_mask"]
type_mask = label["type_mask"]

- loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
+ loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if self.has_coord:
if self.mask_loss_coord:
@@ -66,9 +66,9 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
beta=self.beta,
)
else:
- coord_loss = torch.tensor(
- 0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
- )
+ coord_loss = torch.zeros(
+ 1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
+ )[0]
else:
coord_loss = F.smooth_l1_loss(
updated_coord.view(-1, 3),
@@ -89,9 +89,9 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
reduction="mean",
)
else:
- token_loss = torch.tensor(
- 0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
- )
+ token_loss = torch.zeros(
+ 1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
+ )[0]
else:
token_loss = F.nll_loss(
F.log_softmax(logits.view(-1, self.ntypes - 1), dim=-1),
2 changes: 1 addition & 1 deletion deepmd/pt/loss/ener.py
@@ -108,7 +108,7 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef
pref_v = self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * coef
- loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
+ loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
# more_loss['log_keys'] = [] # showed when validation on the fly
# more_loss['test_keys'] = [] # showed when doing dp test
2 changes: 1 addition & 1 deletion deepmd/pt/loss/tensor.py
@@ -83,7 +83,7 @@ def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
Other losses for display.
"""
del learning_rate, mae
- loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
+ loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if (
self.has_local_weight
14 changes: 8 additions & 6 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -78,6 +78,10 @@ def __init__(

self.atomic_bias = None
self.mixed_types_list = [model.mixed_types() for model in self.models]
+ self.rcuts = torch.tensor(
+ self.get_model_rcuts(), dtype=torch.float64, device=env.DEVICE
+ )
+ self.nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE)
BaseAtomicModel.__init__(self, **kwargs)

def mixed_types(self) -> bool:
@@ -117,14 +121,12 @@ def get_model_sels(self) -> List[List[int]]:
"""Get the sels for each individual models."""
return [model.get_sel() for model in self.models]

- def _sort_rcuts_sels(self, device: torch.device) -> Tuple[List[float], List[int]]:
+ def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]:
# sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
- rcuts = torch.tensor(self.get_model_rcuts(), dtype=torch.float64, device=device)
- nsels = torch.tensor(self.get_model_nsels(), device=device)
zipped = torch.stack(
[
- rcuts,
- nsels,
+ self.rcuts,
+ self.nsels,
],
dim=0,
).T
@@ -171,7 +173,7 @@ def forward_atomic(
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
- sorted_rcuts, sorted_sels = self._sort_rcuts_sels(device=extended_coord.device)
+ sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
nlists = build_multiple_neighbor_list(
extended_coord,
nlist,
