Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add pair table model to pytorch #3192

Merged
merged 31 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
47cff4b
feat: add pair table model to pytorch
Jan 28, 2024
04b6f57
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2024
eb59d87
fix: typo
Jan 28, 2024
b7cbbd5
fix: typo
Jan 28, 2024
a1a76bb
Merge branch 'devel' into devel
anyangml Jan 28, 2024
84767f3
fix: update ruct extrapolation
Jan 28, 2024
8fee8fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2024
ff08515
fix: update allclose precision
Jan 28, 2024
f4b3720
Merge branch 'devel' into devel
anyangml Jan 29, 2024
451916e
Merge branch 'devel' into devel
anyangml Jan 29, 2024
0968eaa
Merge branch 'devel' into devel
anyangml Jan 29, 2024
6b0559e
Merge branch 'devel' into devel
anyangml Jan 29, 2024
8cbb98c
chore: refactor common method to PairTab
Jan 29, 2024
a08092c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
d3090b9
fix: update unit tests
Jan 29, 2024
daf2fc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
399c278
fix: revert padding zero mask change
Jan 29, 2024
59abe43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
8f1cdc8
Merge branch 'devel' into devel
anyangml Jan 30, 2024
88936cc
Merge branch 'devel' into devel
anyangml Jan 30, 2024
1c4ee0d
feat: redo extrapolation with cubic spline for smoothness
Jan 30, 2024
5793828
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2024
27f3559
Merge branch 'devel' into devel
anyangml Jan 30, 2024
92dec18
chore: refactor _make_data in PairTab
Jan 30, 2024
bc04359
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2024
4433035
chore: move file
Jan 30, 2024
2ba0318
Merge branch 'devel' into devel
anyangml Jan 30, 2024
f2c40e6
Merge branch 'devel' into devel
anyangml Jan 31, 2024
4851a0a
chore: refactor extrapolation code
Jan 31, 2024
ddbe7db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 31, 2024
29d95db
Merge branch 'devel' into devel
anyangml Jan 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
374 changes: 374 additions & 0 deletions deepmd/pt/model/model/pair_tab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,374 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (

Check warning on line 3 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L2-L3

Added lines #L2 - L3 were not covered by tests
Dict,
List,
Optional,
Union,
)

import numpy as np
import torch
from torch import (

Check warning on line 12 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L10-L12

Added lines #L10 - L12 were not covered by tests
nn,
)

from deepmd.model_format import (

Check warning on line 16 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L16

Added line #L16 was not covered by tests
FittingOutputDef,
OutputVariableDef,
)
from deepmd.utils.pair_tab import (

Check warning on line 20 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L20

Added line #L20 was not covered by tests
PairTab,
)

from .atomic_model import (

Check warning on line 24 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L24

Added line #L24 was not covered by tests
AtomicModel,
)


class PairTabModel(nn.Module, AtomicModel):

Check warning on line 29 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L29

Added line #L29 was not covered by tests
"""Pairwise tabulation energy model.

This model can be used to tabulate the pairwise energy between atoms for either
short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not
be used alone, but rather as one submodel of a linear (sum) model, such as
DP+D3.

Do not put the model on the first model of a linear model, since the linear
model fetches the type map from the first model.

At this moment, the model does not smooth the energy at the cutoff radius, so
one needs to make sure the energy has been smoothed to zero.

Parameters
----------
tab_file : str
The path to the tabulation file.
rcut : float
The cutoff radius.
sel : int or list[int]
The maxmum number of atoms in the cut-off radius.
"""

def __init__(

Check warning on line 53 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L53

Added line #L53 was not covered by tests
self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs
):
super().__init__()
self.tab_file = tab_file
self.rcut = rcut

Check warning on line 58 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L56-L58

Added lines #L56 - L58 were not covered by tests

# check table data against rcut and update tab_file if needed.
self._check_table_upper_boundary()

Check warning on line 61 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L61

Added line #L61 was not covered by tests

self.tab = PairTab(self.tab_file)
self.ntypes = self.tab.ntypes

Check warning on line 64 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L63-L64

Added lines #L63 - L64 were not covered by tests

tab_info, tab_data = self.tab.get() # this returns -> Tuple[np.array, np.array]
self.tab_info = torch.from_numpy(tab_info)
self.tab_data = torch.from_numpy(tab_data)

Check warning on line 68 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L66-L68

Added lines #L66 - L68 were not covered by tests

# self.model_type = "ener"
# self.model_version = MODEL_VERSION ## this shoud be in the parent class

if isinstance(sel, int):
self.sel = sel
elif isinstance(sel, list):
self.sel = sum(sel)

Check warning on line 76 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L73-L76

Added lines #L73 - L76 were not covered by tests
else:
raise TypeError("sel must be int or list[int]")

Check warning on line 78 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L78

Added line #L78 was not covered by tests

def get_fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(

Check warning on line 81 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L80-L81

Added lines #L80 - L81 were not covered by tests
[
OutputVariableDef(
name="energy", shape=[1], reduciable=True, differentiable=True
)
]
)

def get_rcut(self) -> float:
return self.rcut

Check warning on line 90 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L89-L90

Added lines #L89 - L90 were not covered by tests

def get_sel(self) -> int:
return self.sel

Check warning on line 93 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L92-L93

Added lines #L92 - L93 were not covered by tests

def distinguish_types(self) -> bool:

Check warning on line 95 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L95

Added line #L95 was not covered by tests
# to match DPA1 and DPA2.
return False

Check warning on line 97 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L97

Added line #L97 was not covered by tests

def forward_atomic(

Check warning on line 99 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L99

Added line #L99 was not covered by tests
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
nframes, nloc, nnei = nlist.shape

Check warning on line 107 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L107

Added line #L107 was not covered by tests

# this will mask all -1 in the nlist
masked_nlist = torch.clamp(nlist, 0)

Check warning on line 110 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L110

Added line #L110 was not covered by tests

atype = extended_atype[:, :nloc] # (nframes, nloc)
pairwise_dr = self._get_pairwise_dist(

Check warning on line 113 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L112-L113

Added lines #L112 - L113 were not covered by tests
extended_coord
) # (nframes, nall, nall, 3)
pairwise_rr = pairwise_dr.pow(2).sum(-1).sqrt() # (nframes, nall, nall)

Check warning on line 116 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L116

Added line #L116 was not covered by tests

self.tab_data = self.tab_data.reshape(

Check warning on line 118 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L118

Added line #L118 was not covered by tests
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
)

# to calculate the atomic_energy, we need 3 tensors, i_type, j_type, rr
# i_type : (nframes, nloc), this is atype.
# j_type : (nframes, nloc, nnei)
j_type = extended_atype[

Check warning on line 125 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L125

Added line #L125 was not covered by tests
torch.arange(extended_atype.size(0))[:, None, None], masked_nlist
]

# slice rr to get (nframes, nloc, nnei)
rr = torch.gather(pairwise_rr[:, :nloc, :], 2, masked_nlist)

Check warning on line 130 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L130

Added line #L130 was not covered by tests

raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr)

Check warning on line 132 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L132

Added line #L132 was not covered by tests

atomic_energy = 0.5 * torch.sum(

Check warning on line 134 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L134

Added line #L134 was not covered by tests
torch.where(
nlist != -1, raw_atomic_energy, torch.zeros_like(raw_atomic_energy)
),
dim=-1,
)

return {"energy": atomic_energy}

Check warning on line 141 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L141

Added line #L141 was not covered by tests

def _check_table_upper_boundary(self):

Check warning on line 143 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L143

Added line #L143 was not covered by tests
anyangml marked this conversation as resolved.
Show resolved Hide resolved
"""Update User Provided Table Based on `rcut`.

This function checks the upper boundary provided in the table against rcut.
If the table upper boundary values decay to zero before rcut, padding zeros will
be add to the table to cover rcut; if the table upper boundary values do not decay to zero
before ruct, linear extrapolation will be performed to rcut. In both cases, the table file
will be overwritten.

Examples
--------
table = [[0.005 1. 2. 3. ]
[0.01 0.8 1.6 2.4 ]
[0.015 0. 1. 1.5 ]]

rcut = 0.022

new_table = [[0.005 1. 2. 3. ]
[0.01 0.8 1.6 2.4 ]
[0.015 0. 1. 1.5 ]
[0.02 0. 0.5 0.75 ]
[0.025 0. 0. 0. ]]
anyangml marked this conversation as resolved.
Show resolved Hide resolved

----------------------------------------------

table = [[0.005 1. 2. 3. ]
[0.01 0.8 1.6 2.4 ]
[0.015 0.5 1. 1.5 ]
[0.02 0.25 0.4 0.75 ]
[0.025 0. 0.1 0. ]
[0.03 0. 0. 0. ]]

rcut = 0.031

new_table = [[0.005 1. 2. 3. ]
[0.01 0.8 1.6 2.4 ]
[0.015 0.5 1. 1.5 ]
[0.02 0.25 0.4 0.75 ]
[0.025 0. 0.1 0. ]
[0.03 0. 0. 0. ]
[0.035 0. 0. 0. ]]
anyangml marked this conversation as resolved.
Show resolved Hide resolved
"""
raw_data = np.loadtxt(self.tab_file)
upper = raw_data[-1][0]
upper_val = raw_data[-1][1:]
upper_idx = raw_data.shape[0] - 1
increment = raw_data[1][0] - raw_data[0][0]

Check warning on line 189 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L185-L189

Added lines #L185 - L189 were not covered by tests

# the index of table for the grid point right after rcut
rcut_idx = int(self.rcut / increment)

Check warning on line 192 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L192

Added line #L192 was not covered by tests

if np.all(upper_val == 0):

Check warning on line 194 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L194

Added line #L194 was not covered by tests
# if table values decay to `0` after rcut
if self.rcut < upper and np.any(raw_data[rcut_idx - 1][1:] != 0):
logging.warning(

Check warning on line 197 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L196-L197

Added lines #L196 - L197 were not covered by tests
"The energy provided in the table does not decay to 0 at rcut."
)
# if table values decay to `0` at rcut, do nothing

# if table values decay to `0` before rcut, pad table with `0`s.
elif self.rcut > upper:
pad_zero = np.zeros((rcut_idx - upper_idx, 4))
pad_zero[:, 0] = np.linspace(

Check warning on line 205 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L203-L205

Added lines #L203 - L205 were not covered by tests
upper + increment, increment * (rcut_idx + 1), rcut_idx - upper_idx
)
raw_data = np.concatenate((raw_data, pad_zero), axis=0)

Check warning on line 208 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L208

Added line #L208 was not covered by tests
else:
# if table values do not decay to `0` at rcut
if self.rcut <= upper:
logging.warning(

Check warning on line 212 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L211-L212

Added lines #L211 - L212 were not covered by tests
"The energy provided in the table does not decay to 0 at rcut."
)
# if rcut goes beyond table upper bond, need extrapolation.
else:
logging.warning(

Check warning on line 217 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L217

Added line #L217 was not covered by tests
"The rcut goes beyond table upper boundary, performing linear extrapolation."
)
pad_linear = np.zeros((rcut_idx - upper_idx + 1, 4))
pad_linear[:, 0] = np.linspace(

Check warning on line 221 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L220-L221

Added lines #L220 - L221 were not covered by tests
upper, increment * (rcut_idx + 1), rcut_idx - upper_idx + 1
)
pad_linear[:, 1:] = np.array(

Check warning on line 224 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L224

Added line #L224 was not covered by tests
[
np.linspace(start, 0, rcut_idx - upper_idx + 1)
for start in upper_val
]
).T
raw_data = np.concatenate((raw_data[:-1, :], pad_linear), axis=0)

Check warning on line 230 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L230

Added line #L230 was not covered by tests

# over writing file with padding if applicable.
with open(self.tab_file, "wb") as f:
np.savetxt(f, raw_data)

Check warning on line 234 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L233-L234

Added lines #L233 - L234 were not covered by tests
anyangml marked this conversation as resolved.
Show resolved Hide resolved

def _pair_tabulated_inter(

Check warning on line 236 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L236

Added line #L236 was not covered by tests
self,
nlist: torch.Tensor,
i_type: torch.Tensor,
j_type: torch.Tensor,
rr: torch.Tensor,
) -> torch.Tensor:
"""Pairwise tabulated energy.

Parameters
----------
nlist : torch.Tensor
The unmasked neighbour list. (nframes, nloc)
i_type : torch.Tensor
The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
j_type : torch.Tensor
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
rr : torch.Tensor
The salar distance vector between two atoms. (nframes, nloc, nnei)

Returns
-------
torch.Tensor
The masked atomic energy for all local atoms for all frames. (nframes, nloc, nnei)

Raises
------
Exception
If the distance is beyond the table.

Notes
-----
This function is used to calculate the pairwise energy between two atoms.
It uses a table containing cubic spline coefficients calculated in PairTab.
"""
rmin = self.tab_info[0]
hh = self.tab_info[1]
hi = 1.0 / hh

Check warning on line 273 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L271-L273

Added lines #L271 - L273 were not covered by tests

self.nspline = int(self.tab_info[2] + 0.1)

Check warning on line 275 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L275

Added line #L275 was not covered by tests

uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei)

Check warning on line 277 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L277

Added line #L277 was not covered by tests

# if nnei of atom 0 has -1 in the nlist, uu would be 0.
# this is to handel the nlist where the mask is set to 0, so that we don't raise exception for those atoms.
uu = torch.where(nlist != -1, uu, self.nspline + 1)

Check warning on line 281 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L281

Added line #L281 was not covered by tests

if torch.any(uu < 0):
raise Exception("coord go beyond table lower boundary")

Check warning on line 284 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L283-L284

Added lines #L283 - L284 were not covered by tests

idx = uu.to(torch.int)

Check warning on line 286 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L286

Added line #L286 was not covered by tests

uu -= idx

Check warning on line 288 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L288

Added line #L288 was not covered by tests

final_coef = self._extract_spline_coefficient(i_type, j_type, idx)

Check warning on line 290 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L290

Added line #L290 was not covered by tests

a3, a2, a1, a0 = torch.unbind(final_coef, dim=-1) # 4 * (nframes, nloc, nnei)

Check warning on line 292 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L292

Added line #L292 was not covered by tests

etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations.
ener = etmp * uu + a0
return ener

Check warning on line 296 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L294-L296

Added lines #L294 - L296 were not covered by tests

@staticmethod
def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:

Check warning on line 299 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L298-L299

Added lines #L298 - L299 were not covered by tests
"""Get pairwise distance `dr`.

Parameters
----------
coords : torch.Tensor
The coordinate of the atoms shape of (nframes * nall * 3).

Returns
-------
torch.Tensor
The pairwise distance between the atoms (nframes * nall * nall * 3).

Examples
--------
coords = torch.tensor([[
[0,0,0],
[1,3,5],
[2,4,6]
]])

dist = tensor([[
[[ 0, 0, 0],
[-1, -3, -5],
[-2, -4, -6]],

[[ 1, 3, 5],
[ 0, 0, 0],
[-1, -1, -1]],

[[ 2, 4, 6],
[ 1, 1, 1],
[ 0, 0, 0]]
]])
"""
return coords.unsqueeze(2) - coords.unsqueeze(1)

Check warning on line 334 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L334

Added line #L334 was not covered by tests

def _extract_spline_coefficient(

Check warning on line 336 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L336

Added line #L336 was not covered by tests
self, i_type: torch.Tensor, j_type: torch.Tensor, idx: torch.Tensor
) -> torch.Tensor:
"""Extract the spline coefficient from the table.

Parameters
----------
i_type : torch.Tensor
The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
j_type : torch.Tensor
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
idx : torch.Tensor
The index of the spline coefficient. (nframes, nloc, nnei)

Returns
-------
torch.Tensor
The spline coefficient. (nframes, nloc, nnei, 4)

"""
# (nframes, nloc, nnei)
expanded_i_type = i_type.unsqueeze(-1).expand(-1, -1, j_type.shape[-1])

Check warning on line 357 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L357

Added line #L357 was not covered by tests

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = self.tab_data[expanded_i_type, j_type]

Check warning on line 360 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L360

Added line #L360 was not covered by tests

# (nframes, nloc, nnei, 1, 4)
expanded_idx = idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, -1, 4)

Check warning on line 363 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L363

Added line #L363 was not covered by tests

# handle the case where idx is beyond the number of splines
clipped_indices = torch.clamp(expanded_idx, 0, self.nspline - 1).to(torch.int64)

Check warning on line 366 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L366

Added line #L366 was not covered by tests

# (nframes, nloc, nnei, 4)
final_coef = torch.gather(expanded_tab_data, 3, clipped_indices).squeeze()

Check warning on line 369 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L369

Added line #L369 was not covered by tests

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() >= self.nspline] = 0

Check warning on line 372 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L372

Added line #L372 was not covered by tests

return final_coef

Check warning on line 374 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L374

Added line #L374 was not covered by tests
Loading
Loading