Skip to content

Commit

Permalink
feat: add pair table model to pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
Anyang Peng authored and Anyang Peng committed Jan 28, 2024
1 parent 3e4715f commit 47cff4b
Show file tree
Hide file tree
Showing 2 changed files with 425 additions and 0 deletions.
351 changes: 351 additions & 0 deletions deepmd/pt/model/model/pair_tab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
from .atomic_model import AtomicModel
from deepmd.utils.pair_tab import (

Check warning on line 2 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L1-L2

Added lines #L1 - L2 were not covered by tests
PairTab,
)
import logging
import torch
from torch import nn
import numpy as np
from typing import Dict, List, Optional, Union

Check warning on line 9 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L5-L9

Added lines #L5 - L9 were not covered by tests

from deepmd.model_format import FittingOutputDef, OutputVariableDef

Check warning on line 11 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L11

Added line #L11 was not covered by tests

class PairTabModel(nn.Module, AtomicModel):

Check warning on line 13 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L13

Added line #L13 was not covered by tests
"""Pairwise tabulation energy model.
This model can be used to tabulate the pairwise energy between atoms for either
short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not
be used alone, but rather as one submodel of a linear (sum) model, such as
DP+D3.
Do not put the model on the first model of a linear model, since the linear
model fetches the type map from the first model.
At this moment, the model does not smooth the energy at the cutoff radius, so
one needs to make sure the energy has been smoothed to zero.
Parameters
----------
tab_file : str
The path to the tabulation file.
rcut : float
The cutoff radius.
sel : int or list[int]
The maxmum number of atoms in the cut-off radius.
"""

def __init__(

Check warning on line 37 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L37

Added line #L37 was not covered by tests
self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs
):
super().__init__()
self.tab_file = tab_file
self.rcut = rcut

Check warning on line 42 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L40-L42

Added lines #L40 - L42 were not covered by tests

# check table data against rcut and update tab_file if needed.
self._check_table_upper_boundary()

Check warning on line 45 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L45

Added line #L45 was not covered by tests

self.tab = PairTab(self.tab_file)
self.ntypes = self.tab.ntypes

Check warning on line 48 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L47-L48

Added lines #L47 - L48 were not covered by tests


tab_info, tab_data = self.tab.get() # this returns -> Tuple[np.array, np.array]
self.tab_info = torch.from_numpy(tab_info)
self.tab_data = torch.from_numpy(tab_data)

Check warning on line 53 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L51-L53

Added lines #L51 - L53 were not covered by tests

# self.model_type = "ener"
# self.model_version = MODEL_VERSION ## this shoud be in the parent class

if isinstance(sel, int):
self.sel = sel
elif isinstance(sel, list):
self.sel = sum(sel)

Check warning on line 61 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L58-L61

Added lines #L58 - L61 were not covered by tests
else:
raise TypeError("sel must be int or list[int]")

Check warning on line 63 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L63

Added line #L63 was not covered by tests

def get_fitting_output_def(self)->FittingOutputDef:
return FittingOutputDef(

Check warning on line 66 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L65-L66

Added lines #L65 - L66 were not covered by tests
[OutputVariableDef(
name="energy",
shape=[1],
reduciable=True,
differentiable=True
)]
)

def get_rcut(self)->float:
return self.rcut

Check warning on line 76 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L75-L76

Added lines #L75 - L76 were not covered by tests

def get_sel(self)->int:
return self.sel

Check warning on line 79 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L78-L79

Added lines #L78 - L79 were not covered by tests

def distinguish_types(self)->bool:

Check warning on line 81 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L81

Added line #L81 was not covered by tests
# to match DPA1 and DPA2.
return False

Check warning on line 83 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L83

Added line #L83 was not covered by tests

def forward_atomic(

Check warning on line 85 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L85

Added line #L85 was not covered by tests
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:


nframes, nloc, nnei = nlist.shape

Check warning on line 95 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L95

Added line #L95 was not covered by tests

#this will mask all -1 in the nlist
masked_nlist = torch.clamp(nlist,0)

Check warning on line 98 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L98

Added line #L98 was not covered by tests

atype = extended_atype[:, :nloc] #(nframes, nloc)
pairwise_dr = self._get_pairwise_dist(extended_coord) # (nframes, nall, nall, 3)
pairwise_rr = pairwise_dr.pow(2).sum(-1).sqrt() # (nframes, nall, nall)

Check warning on line 102 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L100-L102

Added lines #L100 - L102 were not covered by tests

self.tab_data = self.tab_data.reshape(self.tab.ntypes,self.tab.ntypes,self.tab.nspline,4)

Check warning on line 104 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L104

Added line #L104 was not covered by tests

#to calculate the atomic_energy, we need 3 tensors, i_type, j_type, rr
#i_type : (nframes, nloc), this is atype.
#j_type : (nframes, nloc, nnei)
j_type = extended_atype[torch.arange(extended_atype.size(0))[:, None, None], masked_nlist]

Check warning on line 109 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L109

Added line #L109 was not covered by tests

#slice rr to get (nframes, nloc, nnei)
rr = torch.gather(pairwise_rr[:, :nloc, :],2, masked_nlist)

Check warning on line 112 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L112

Added line #L112 was not covered by tests

raw_atomic_energy = self._pair_tabulated_inter(atype, j_type, rr)

Check failure

Code scanning / CodeQL

Wrong number of arguments in a call Error

Call to
method PairTabModel._pair_tabulated_inter
with too few arguments; should be no fewer than 4.

Check warning on line 114 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L114

Added line #L114 was not covered by tests

atomic_energy = 0.5 * torch.sum(torch.where(nlist != -1, raw_atomic_energy, torch.zeros_like(raw_atomic_energy)) ,dim=-1)

Check warning on line 116 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L116

Added line #L116 was not covered by tests

return {"energy": atomic_energy}

Check warning on line 118 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L118

Added line #L118 was not covered by tests

def _check_table_upper_boundary(self):

Check warning on line 120 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L120

Added line #L120 was not covered by tests
"""Update User Provided Table Based on `rcut`
This function checks the upper boundary provided in the table against rcut.
If the table upper boundary values decay to zero before rcut, padding zeros will
be add to the table to cover rcut; if the table upper boundary values do not decay to zero
before ruct, linear extrapolation will be performed to rcut. In both cases, the table file
will be overwritten.
Example
-------
table = [[0.005 1. 2. 3. ]
[0.01 0.8 1.6 2.4 ]
[0.015 0. 1. 1.5 ]]
rcut = 0.022
new_table = [[0.005 1. 2. 3. ]
[0.01 0.8 1.6 2.4 ]
[0.015 0. 1. 1.5 ]
[0.02 0. 0.5 0.75 ]
[0.025 0. 0. 0. ]]
----------------------------------------------
table = [[0.005 1. 2. 3. ]
[0.01 0.8 1.6 2.4 ]
[0.015 0.5 1. 1.5 ]
[0.02 0.25 0.4 0.75 ]
[0.025 0. 0.1 0. ]
[0.03 0. 0. 0. ]]
rcut = 0.031
new_table = [[0.005 1. 2. 3. ]
[0.01 0.8 1.6 2.4 ]
[0.015 0.5 1. 1.5 ]
[0.02 0.25 0.4 0.75 ]
[0.025 0. 0.1 0. ]
[0.03 0. 0. 0. ]
[0.035 0. 0. 0. ]]
"""

raw_data = np.loadtxt(self.tab_file)
upper = raw_data[-1][0]
upper_val = raw_data[-1][1:]
upper_idx = raw_data.shape[0] - 1
increment = raw_data[1][0] - raw_data[0][0]

Check warning on line 168 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L164-L168

Added lines #L164 - L168 were not covered by tests

#the index of table for the grid point right after rcut
rcut_idx = int(self.rcut/increment)

Check warning on line 171 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L171

Added line #L171 was not covered by tests

if np.all(upper_val == 0):

Check warning on line 173 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L173

Added line #L173 was not covered by tests
# if table values decay to `0` after rcut
if self.rcut < upper and np.any(raw_data[rcut_idx-1][1:]!=0):
logging.warning(

Check warning on line 176 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L175-L176

Added lines #L175 - L176 were not covered by tests
"The energy provided in the table does not decay to 0 at rcut."
)
# if table values decay to `0` at rcut, do nothing

# if table values decay to `0` before rcut, pad table with `0`s.
elif self.rcut > upper:
pad_zero = np.zeros((rcut_idx - upper_idx,4))
pad_zero[:,0] = np.linspace(upper + increment, increment*(rcut_idx+1), rcut_idx-upper_idx)
raw_data = np.concatenate((raw_data,pad_zero),axis=0)

Check warning on line 185 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L182-L185

Added lines #L182 - L185 were not covered by tests
else:
# if table values do not decay to `0` at rcut
if self.rcut <= upper:
logging.warning(

Check warning on line 189 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L188-L189

Added lines #L188 - L189 were not covered by tests
"The energy provided in the table does not decay to 0 at rcut."
)
# if rcut goes beyond table upper bond, need extrapolation.
else:
logging.warning(

Check warning on line 194 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L194

Added line #L194 was not covered by tests
"The rcut goes beyond table upper boundary, performing linear extrapolation."
)
pad_linear = np.zeros((rcut_idx - upper_idx+1,4))
pad_linear[:,0] = np.linspace(upper, increment*(rcut_idx+1), rcut_idx-upper_idx+1)
pad_linear[:,1:] = np.array([

Check warning on line 199 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L197-L199

Added lines #L197 - L199 were not covered by tests
np.linspace(start, 0, rcut_idx - upper_idx+1) for start in upper_val
]).T
raw_data = np.concatenate((raw_data[:-1,:],pad_zero),axis=0)

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error

Local variable 'pad_zero' may be used before it is initialized.

Check warning on line 202 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L202

Added line #L202 was not covered by tests

#over writing file with padding if applicable.
with open(self.tab_file, 'wb') as f:
np.savetxt(f, raw_data)

Check warning on line 206 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L205-L206

Added lines #L205 - L206 were not covered by tests

def _pair_tabulated_inter(self, nlist: torch.Tensor,i_type: torch.Tensor, j_type: torch.Tensor, rr: torch.Tensor) -> torch.Tensor:

Check warning on line 208 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L208

Added line #L208 was not covered by tests
"""Pairwise tabulated energy.
Parameters
----------
nlist : torch.Tensor
The unmasked neighbour list. (nframes, nloc)
i_type : torch.Tensor
The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
j_type : torch.Tensor
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
rr : torch.Tensor
The salar distance vector between two atoms. (nframes, nloc, nnei)
Returns
-------
torch.Tensor
The masked atomic energy for all local atoms for all frames. (nframes, nloc, nnei)
Raises
------
Exception
If the distance is beyond the table.
Notes
-----
This function is used to calculate the pairwise energy between two atoms.
It uses a table containing cubic spline coefficients calculated in PairTab.
"""

rmin = self.tab_info[0]
hh = self.tab_info[1]
hi = 1. / hh

Check warning on line 243 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L241-L243

Added lines #L241 - L243 were not covered by tests

self.nspline = int(self.tab_info[2] + 0.1)

Check warning on line 245 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L245

Added line #L245 was not covered by tests

uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei)

Check warning on line 247 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L247

Added line #L247 was not covered by tests


# if nnei of atom 0 has -1 in the nlist, uu would be 0.
# this is to handel the nlist where the mask is set to 0.
# by replacing the values wiht nspline + 1, the energy contribution will be 0
uu = torch.where(nlist != -1, uu, self.nspline+1)

Check warning on line 253 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L253

Added line #L253 was not covered by tests

if torch.any(uu < 0):
raise Exception("coord go beyond table lower boundary")

Check warning on line 256 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L255-L256

Added lines #L255 - L256 were not covered by tests

idx = uu.to(torch.int)

Check warning on line 258 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L258

Added line #L258 was not covered by tests

uu -= idx

Check warning on line 260 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L260

Added line #L260 was not covered by tests


final_coef = self._extract_spline_coefficient(i_type, j_type, idx)

Check warning on line 263 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L263

Added line #L263 was not covered by tests

a3, a2, a1, a0 = torch.unbind(final_coef, dim=-1) # 4 * (nframes, nloc, nnei)

Check warning on line 265 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L265

Added line #L265 was not covered by tests

etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations.
ener = etmp * uu + a0
return ener

Check warning on line 269 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L267-L269

Added lines #L267 - L269 were not covered by tests

@staticmethod
def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:

Check warning on line 272 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L271-L272

Added lines #L271 - L272 were not covered by tests
"""Get pairwise distance `dr`.
Parameters
----------
coords : torch.Tensor
The coordinate of the atoms shape of (nframes * nall * 3).
Returns
-------
torch.Tensor
The pairwise distance between the atoms (nframes * nall * nall * 3).
Examples
--------
coords = torch.tensor([[
[0,0,0],
[1,3,5],
[2,4,6]
]])
dist = tensor([[
[[ 0, 0, 0],
[-1, -3, -5],
[-2, -4, -6]],
[[ 1, 3, 5],
[ 0, 0, 0],
[-1, -1, -1]],
[[ 2, 4, 6],
[ 1, 1, 1],
[ 0, 0, 0]]
]])
"""
return coords.unsqueeze(2) - coords.unsqueeze(1)

Check warning on line 307 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L307

Added line #L307 was not covered by tests

def _extract_spline_coefficient(self, i_type: torch.Tensor, j_type: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:

Check warning on line 309 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L309

Added line #L309 was not covered by tests
"""Extract the spline coefficient from the table.
Parameters
----------
i_type : torch.Tensor
The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
j_type : torch.Tensor
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
idx : torch.Tensor
The index of the spline coefficient. (nframes, nloc, nnei)
Returns
-------
torch.Tensor
The spline coefficient. (nframes, nloc, nnei, 4)
Example
-------
"""

# (nframes, nloc, nnei)
expanded_i_type = i_type.unsqueeze(-1).expand(-1, -1, j_type.shape[-1])

Check warning on line 334 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L334

Added line #L334 was not covered by tests

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = self.tab_data[expanded_i_type, j_type]

Check warning on line 337 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L337

Added line #L337 was not covered by tests

# (nframes, nloc, nnei, 1, 4)
expanded_idx = idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1,-1, -1, 4)

Check warning on line 340 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L340

Added line #L340 was not covered by tests

#handle the case where idx is beyond the number of splines
clipped_indices = torch.clamp(expanded_idx, 0, self.nspline - 1).to(torch.int64)

Check warning on line 343 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L343

Added line #L343 was not covered by tests

# (nframes, nloc, nnei, 4)
final_coef = torch.gather(expanded_tab_data, 3, clipped_indices).squeeze()

Check warning on line 346 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L346

Added line #L346 was not covered by tests

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() >= self.nspline] = 0

Check warning on line 349 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L349

Added line #L349 was not covered by tests

return final_coef

Check warning on line 351 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L351

Added line #L351 was not covered by tests
Loading

0 comments on commit 47cff4b

Please sign in to comment.