pt: avoid torch.tensor(constant) during forward
torch.tensor(constant) copies memory from the CPU to the GPU, so it is host-blocking and should be avoided in the `forward` method.
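
A minimal sketch of the pattern this commit applies, in isolation (illustrative only; `device` and `dtype` below stand in for deepmd's env.DEVICE and env.GLOBAL_PT_FLOAT_PRECISION):

import torch

# Stand-ins for env.DEVICE and env.GLOBAL_PT_FLOAT_PRECISION.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64

# Before: the Python scalar 0.0 lives in host memory, so torch.tensor
# performs a blocking host-to-device copy on every forward call.
loss = torch.tensor(0.0, dtype=dtype, device=device)

# After: torch.zeros allocates directly on the target device, and
# indexing with [0] recovers a 0-dim scalar tensor like the original.
loss = torch.zeros(1, dtype=dtype, device=device)[0]

Both forms behave the same under accumulation with `loss += ...`; only the allocation path differs.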

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz committed Mar 6, 2024
1 parent b0171ce commit 62003db
Showing 4 changed files with 17 additions and 13 deletions.
10 changes: 5 additions & 5 deletions deepmd/pt/loss/denoise.py
@@ -52,7 +52,7 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
coord_mask = label["coord_mask"]
type_mask = label["type_mask"]

loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if self.has_coord:
if self.mask_loss_coord:
@@ -66,8 +66,8 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
beta=self.beta,
)
else:
coord_loss = torch.tensor(
0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
coord_loss = torch.zeros(
1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
else:
coord_loss = F.smooth_l1_loss(
@@ -89,8 +89,8 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
reduction="mean",
)
else:
token_loss = torch.tensor(
0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
token_loss = torch.zeros(
1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
else:
token_loss = F.nll_loss(
2 changes: 1 addition & 1 deletion deepmd/pt/loss/ener.py
@@ -108,7 +108,7 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef
pref_v = self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * coef
loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
# more_loss['log_keys'] = [] # showed when validation on the fly
# more_loss['test_keys'] = [] # showed when doing dp test
4 changes: 3 additions & 1 deletion deepmd/pt/loss/tensor.py
@@ -83,7 +83,9 @@ def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
Other losses for display.
"""
del learning_rate, mae
loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
loss = torch.zeros(
1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
more_loss = {}
if (
self.has_local_weight
14 changes: 8 additions & 6 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -72,6 +72,10 @@ def __init__(
self.type_map = type_map
self.atomic_bias = None
self.mixed_types_list = [model.mixed_types() for model in self.models]
self.rcuts = torch.tensor(
self.get_model_rcuts(), dtype=torch.float64, device=env.DEVICE
)
self.nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE)
BaseAtomicModel.__init__(self, **kwargs)

def mixed_types(self) -> bool:
@@ -113,12 +117,10 @@ def get_model_sels(self) -> List[List[int]]:

def _sort_rcuts_sels(self, device: torch.device) -> Tuple[List[float], List[int]]:
def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]:
# sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
rcuts = torch.tensor(self.get_model_rcuts(), dtype=torch.float64, device=device)
nsels = torch.tensor(self.get_model_nsels(), device=device)
zipped = torch.stack(
[
torch.tensor(rcuts, device=device),
torch.tensor(nsels, device=device),
self.rcuts,
self.nsels,
],
dim=0,
).T
@@ -165,7 +167,7 @@ def forward_atomic(
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels(device=extended_coord.device)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
nlists = build_multiple_neighbor_list(
extended_coord,
nlist,
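
The linear_atomic_model.py change applies the same idea one level up: the constant rcut and sel lists are tensorized once in `__init__` and reused in `forward`. A hedged sketch of that idiom, with a hypothetical module standing in for the linear atomic model:

import torch

class PairTable(torch.nn.Module):
    """Hypothetical module illustrating the tensorize-in-__init__ idiom."""

    def __init__(self, rcuts, nsels, device):
        super().__init__()
        # One host-to-device copy at construction time, not per call.
        self.rcuts = torch.tensor(rcuts, dtype=torch.float64, device=device)
        self.nsels = torch.tensor(nsels, device=device)

    def forward(self):
        # Device-side only: pair up the cached tensors and sort by nsel,
        # roughly as _sort_rcuts_sels does after this commit.
        zipped = torch.stack([self.rcuts, self.nsels.to(torch.float64)], dim=0).T
        order = torch.argsort(zipped[:, 1], dim=0)
        return zipped[order]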
