-
Notifications
You must be signed in to change notification settings - Fork 520
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat: add pair table model to pytorch (#3192)
Migrated from this [PR](dptech-corp/deepmd-pytorch#174). This is to reimplement the PairTab Model in Pytorch. Notes: 1. Different from the tensorflow version, the pytorch version abstracts away all the post energy conversion operations (force, virial). 2. Added extrapolation when `rcut` > `rmax`. The pytorch version overwrite energy beyond extrapolation endpoint to `0`. These features are not available in the tensorflow version. The extrapolation uses a cubic spline form, the 1st order derivation for the starting point is estimated using the last two rows in the user defined table. See example below: ![img_v3_027k_b50c690d-dc2d-4803-bd2c-2e73aa3c73fg](https://github.com/deepmodeling/deepmd-kit/assets/137014849/f3efa4d3-795e-4ff8-acdc-642227f0e19c) ![img_v3_027k_8de38597-ef4e-4e5b-989e-dbd13cc93fag](https://github.com/deepmodeling/deepmd-kit/assets/137014849/493da26d-f01d-4dd0-8520-ea2d84e7b548) ![img_v3_027k_f8268564-3f5d-49e6-91d6-169a61d9347g](https://github.com/deepmodeling/deepmd-kit/assets/137014849/b8ad4d4d-a4a4-40f0-94d1-810006e7175b) ![img_v3_027k_3966ef67-dd5e-4f48-992e-c2763311451g](https://github.com/deepmodeling/deepmd-kit/assets/137014849/27f31e79-13c8-4ce8-9911-b4cc0ac8188c) --------- Co-authored-by: Anyang Peng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
d2edb77
commit afb440a
Showing
4 changed files
with
914 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,312 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from typing import ( | ||
Dict, | ||
List, | ||
Optional, | ||
Union, | ||
) | ||
|
||
import torch | ||
from torch import ( | ||
nn, | ||
) | ||
|
||
from deepmd.model_format import ( | ||
FittingOutputDef, | ||
OutputVariableDef, | ||
) | ||
from deepmd.utils.pair_tab import ( | ||
PairTab, | ||
) | ||
|
||
from .atomic_model import ( | ||
AtomicModel, | ||
) | ||
|
||
|
||
class PairTabModel(nn.Module, AtomicModel): | ||
"""Pairwise tabulation energy model. | ||
This model can be used to tabulate the pairwise energy between atoms for either | ||
short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not | ||
be used alone, but rather as one submodel of a linear (sum) model, such as | ||
DP+D3. | ||
Do not put the model on the first model of a linear model, since the linear | ||
model fetches the type map from the first model. | ||
At this moment, the model does not smooth the energy at the cutoff radius, so | ||
one needs to make sure the energy has been smoothed to zero. | ||
Parameters | ||
---------- | ||
tab_file : str | ||
The path to the tabulation file. | ||
rcut : float | ||
The cutoff radius. | ||
sel : int or list[int] | ||
The maxmum number of atoms in the cut-off radius. | ||
""" | ||
|
||
def __init__( | ||
self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs | ||
): | ||
super().__init__() | ||
self.tab_file = tab_file | ||
self.rcut = rcut | ||
|
||
self.tab = PairTab(self.tab_file, rcut=rcut) | ||
self.ntypes = self.tab.ntypes | ||
|
||
tab_info, tab_data = self.tab.get() # this returns -> Tuple[np.array, np.array] | ||
self.tab_info = torch.from_numpy(tab_info) | ||
self.tab_data = torch.from_numpy(tab_data) | ||
|
||
# self.model_type = "ener" | ||
# self.model_version = MODEL_VERSION ## this shoud be in the parent class | ||
|
||
if isinstance(sel, int): | ||
self.sel = sel | ||
elif isinstance(sel, list): | ||
self.sel = sum(sel) | ||
else: | ||
raise TypeError("sel must be int or list[int]") | ||
|
||
def get_fitting_output_def(self) -> FittingOutputDef: | ||
return FittingOutputDef( | ||
[ | ||
OutputVariableDef( | ||
name="energy", shape=[1], reduciable=True, differentiable=True | ||
) | ||
] | ||
) | ||
|
||
def get_rcut(self) -> float: | ||
return self.rcut | ||
|
||
def get_sel(self) -> int: | ||
return self.sel | ||
|
||
def distinguish_types(self) -> bool: | ||
# to match DPA1 and DPA2. | ||
return False | ||
|
||
def forward_atomic( | ||
self, | ||
extended_coord, | ||
extended_atype, | ||
nlist, | ||
mapping: Optional[torch.Tensor] = None, | ||
do_atomic_virial: bool = False, | ||
) -> Dict[str, torch.Tensor]: | ||
self.nframes, self.nloc, self.nnei = nlist.shape | ||
|
||
# this will mask all -1 in the nlist | ||
masked_nlist = torch.clamp(nlist, 0) | ||
|
||
atype = extended_atype[:, : self.nloc] # (nframes, nloc) | ||
pairwise_dr = self._get_pairwise_dist( | ||
extended_coord | ||
) # (nframes, nall, nall, 3) | ||
pairwise_rr = pairwise_dr.pow(2).sum(-1).sqrt() # (nframes, nall, nall) | ||
|
||
self.tab_data = self.tab_data.reshape( | ||
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4 | ||
) | ||
|
||
# to calculate the atomic_energy, we need 3 tensors, i_type, j_type, rr | ||
# i_type : (nframes, nloc), this is atype. | ||
# j_type : (nframes, nloc, nnei) | ||
j_type = extended_atype[ | ||
torch.arange(extended_atype.size(0))[:, None, None], masked_nlist | ||
] | ||
|
||
# slice rr to get (nframes, nloc, nnei) | ||
rr = torch.gather(pairwise_rr[:, : self.nloc, :], 2, masked_nlist) | ||
|
||
raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr) | ||
|
||
atomic_energy = 0.5 * torch.sum( | ||
torch.where( | ||
nlist != -1, raw_atomic_energy, torch.zeros_like(raw_atomic_energy) | ||
), | ||
dim=-1, | ||
) | ||
|
||
return {"energy": atomic_energy} | ||
|
||
def _pair_tabulated_inter( | ||
self, | ||
nlist: torch.Tensor, | ||
i_type: torch.Tensor, | ||
j_type: torch.Tensor, | ||
rr: torch.Tensor, | ||
) -> torch.Tensor: | ||
"""Pairwise tabulated energy. | ||
Parameters | ||
---------- | ||
nlist : torch.Tensor | ||
The unmasked neighbour list. (nframes, nloc) | ||
i_type : torch.Tensor | ||
The integer representation of atom type for all local atoms for all frames. (nframes, nloc) | ||
j_type : torch.Tensor | ||
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei) | ||
rr : torch.Tensor | ||
The salar distance vector between two atoms. (nframes, nloc, nnei) | ||
Returns | ||
------- | ||
torch.Tensor | ||
The masked atomic energy for all local atoms for all frames. (nframes, nloc, nnei) | ||
Raises | ||
------ | ||
Exception | ||
If the distance is beyond the table. | ||
Notes | ||
----- | ||
This function is used to calculate the pairwise energy between two atoms. | ||
It uses a table containing cubic spline coefficients calculated in PairTab. | ||
""" | ||
rmin = self.tab_info[0] | ||
hh = self.tab_info[1] | ||
hi = 1.0 / hh | ||
|
||
self.nspline = int(self.tab_info[2] + 0.1) | ||
|
||
uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei) | ||
|
||
# if nnei of atom 0 has -1 in the nlist, uu would be 0. | ||
# this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms. | ||
uu = torch.where(nlist != -1, uu, self.nspline + 1) | ||
|
||
if torch.any(uu < 0): | ||
raise Exception("coord go beyond table lower boundary") | ||
|
||
idx = uu.to(torch.int) | ||
|
||
uu -= idx | ||
|
||
table_coef = self._extract_spline_coefficient( | ||
i_type, j_type, idx, self.tab_data, self.nspline | ||
) | ||
table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4) | ||
ener = self._calcualte_ener(table_coef, uu) | ||
|
||
# here we need to overwrite energy to zero at rcut and beyond. | ||
mask_beyond_rcut = rr >= self.rcut | ||
# also overwrite values beyond extrapolation to zero | ||
extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh | ||
ener[mask_beyond_rcut] = 0 | ||
ener[extrapolation_mask] = 0 | ||
|
||
return ener | ||
|
||
@staticmethod | ||
def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor: | ||
"""Get pairwise distance `dr`. | ||
Parameters | ||
---------- | ||
coords : torch.Tensor | ||
The coordinate of the atoms shape of (nframes * nall * 3). | ||
Returns | ||
------- | ||
torch.Tensor | ||
The pairwise distance between the atoms (nframes * nall * nall * 3). | ||
Examples | ||
-------- | ||
coords = torch.tensor([[ | ||
[0,0,0], | ||
[1,3,5], | ||
[2,4,6] | ||
]]) | ||
dist = tensor([[ | ||
[[ 0, 0, 0], | ||
[-1, -3, -5], | ||
[-2, -4, -6]], | ||
[[ 1, 3, 5], | ||
[ 0, 0, 0], | ||
[-1, -1, -1]], | ||
[[ 2, 4, 6], | ||
[ 1, 1, 1], | ||
[ 0, 0, 0]] | ||
]]) | ||
""" | ||
return coords.unsqueeze(2) - coords.unsqueeze(1) | ||
|
||
@staticmethod | ||
def _extract_spline_coefficient( | ||
i_type: torch.Tensor, | ||
j_type: torch.Tensor, | ||
idx: torch.Tensor, | ||
tab_data: torch.Tensor, | ||
nspline: int, | ||
) -> torch.Tensor: | ||
"""Extract the spline coefficient from the table. | ||
Parameters | ||
---------- | ||
i_type : torch.Tensor | ||
The integer representation of atom type for all local atoms for all frames. (nframes, nloc) | ||
j_type : torch.Tensor | ||
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei) | ||
idx : torch.Tensor | ||
The index of the spline coefficient. (nframes, nloc, nnei) | ||
tab_data : torch.Tensor | ||
The table storing all the spline coefficient. (ntype, ntype, nspline, 4) | ||
nspline : int | ||
The number of splines in the table. | ||
Returns | ||
------- | ||
torch.Tensor | ||
The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed. | ||
""" | ||
# (nframes, nloc, nnei) | ||
expanded_i_type = i_type.unsqueeze(-1).expand(-1, -1, j_type.shape[-1]) | ||
|
||
# (nframes, nloc, nnei, nspline, 4) | ||
expanded_tab_data = tab_data[expanded_i_type, j_type] | ||
|
||
# (nframes, nloc, nnei, 1, 4) | ||
expanded_idx = idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, -1, 4) | ||
|
||
# handle the case where idx is beyond the number of splines | ||
clipped_indices = torch.clamp(expanded_idx, 0, nspline - 1).to(torch.int64) | ||
|
||
# (nframes, nloc, nnei, 4) | ||
final_coef = torch.gather(expanded_tab_data, 3, clipped_indices).squeeze() | ||
|
||
# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`. | ||
final_coef[expanded_idx.squeeze() > nspline] = 0 | ||
return final_coef | ||
|
||
@staticmethod | ||
def _calcualte_ener(coef: torch.Tensor, uu: torch.Tensor) -> torch.Tensor: | ||
"""Calculate energy using spline coeeficients. | ||
Parameters | ||
---------- | ||
coef : torch.Tensor | ||
The spline coefficients. (nframes, nloc, nnei, 4) | ||
uu : torch.Tensor | ||
The atom displancemnt used in interpolation and extrapolation (nframes, nloc, nnei) | ||
Returns | ||
------- | ||
torch.Tensor | ||
The atomic energy for all local atoms for all frames. (nframes, nloc, nnei) | ||
""" | ||
a3, a2, a1, a0 = torch.unbind(coef, dim=-1) | ||
etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations. | ||
ener = etmp * uu + a0 # this energy has the extrapolated value when rcut > rmax | ||
return ener |
Oops, something went wrong.