Skip to content

Commit

Permalink
feat: add numpy version
Browse files Browse the repository at this point in the history
  • Loading branch information
Anyang Peng authored and Anyang Peng committed Feb 1, 2024
1 parent 74002ed commit 0ce23f4
Showing 1 changed file with 286 additions and 0 deletions.
286 changes: 286 additions & 0 deletions deepmd/dpmodel/model/pair_tab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
from typing import (
Dict,
List,
Optional,
Union,
)

import numpy as np

from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
)

from .base_atomic_model import (
BaseAtomicModel,
)

from deepmd.utils.pair_tab import (
PairTab,
)

class PairTabModel(BaseAtomicModel):
    """Pairwise tabulation energy model.

    This model can be used to tabulate the pairwise energy between atoms for either
    short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not
    be used alone, but rather as one submodel of a linear (sum) model, such as
    DP+D3.

    Do not put the model on the first model of a linear model, since the linear
    model fetches the type map from the first model.

    At this moment, the model does not smooth the energy at the cutoff radius, so
    one needs to make sure the energy has been smoothed to zero.

    Parameters
    ----------
    tab_file : str
        The path to the tabulation file.
    rcut : float
        The cutoff radius.
    sel : int or list[int]
        The maximum number of atoms in the cut-off radius. A list is summed
        into a single integer.
    """

    def __init__(
        self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs
    ):
        super().__init__()
        self.tab_file = tab_file
        self.rcut = rcut

        self.tab = PairTab(self.tab_file, rcut=rcut)
        self.ntypes = self.tab.ntypes

        self.tab_info, self.tab_data = self.tab.get()

        if isinstance(sel, int):
            self.sel = sel
        elif isinstance(sel, list):
            self.sel = sum(sel)
        else:
            raise TypeError("sel must be int or list[int]")

    def get_fitting_output_def(self) -> FittingOutputDef:
        """Return the output definition: a single variable named ``energy``."""
        return FittingOutputDef(
            [
                OutputVariableDef(
                    name="energy", shape=[1], reduciable=True, differentiable=True
                )
            ]
        )

    def get_rcut(self) -> float:
        """Return the cutoff radius given at construction."""
        return self.rcut

    def get_sel(self) -> int:
        """Return the (summed) maximum number of neighbours."""
        return self.sel

    def distinguish_types(self) -> bool:
        """Return False; neighbours are not split by type."""
        # to match DPA1 and DPA2.
        return False

    def serialize(self) -> dict:
        """Return the constructor arguments needed to rebuild this model."""
        return {"tab_file": self.tab_file, "rcut": self.rcut, "sel": self.sel}

    @classmethod
    def deserialize(cls, data) -> "PairTabModel":
        """Rebuild a model from the dict produced by :meth:`serialize`."""
        tab_file = data["tab_file"]
        rcut = data["rcut"]
        sel = data["sel"]
        return cls(tab_file, rcut, sel)

    def forward_atomic(
        self,
        extended_coord,
        extended_atype,
        nlist,
        mapping: Optional[np.ndarray] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, np.ndarray]:
        """Compute the tabulated pairwise energy of every local atom.

        Parameters
        ----------
        extended_coord : np.ndarray
            The coordinates of all (extended) atoms. (nframes, nall, 3)
        extended_atype : np.ndarray
            The integer types of all (extended) atoms. (nframes, nall)
        nlist : np.ndarray
            The neighbour list, padded with ``-1``. (nframes, nloc, nnei)
        mapping : np.ndarray, optional
            Unused by this model.
        do_atomic_virial : bool
            Unused by this model.

        Returns
        -------
        dict[str, np.ndarray]
            ``{"energy": ...}`` with the energy summed over the neighbours of
            each local atom. (nframes, nloc)
        """
        self.nframes, self.nloc, self.nnei = nlist.shape

        # Replace the -1 padding by index 0 so the list can be used for
        # indexing; padded slots are zeroed again after the energy is computed.
        masked_nlist = np.clip(nlist, 0, None)

        atype = extended_atype[:, : self.nloc]  # (nframes, nloc)
        pairwise_dr = self._get_pairwise_dist(
            extended_coord
        )  # (nframes, nall, nall, 3)
        # (nframes, nall, nall)
        pairwise_rr = np.sqrt(np.sum(np.power(pairwise_dr, 2), axis=-1))

        self.tab_data = self.tab_data.reshape(
            self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
        )

        # (nframes, nloc, nnei)
        j_type = extended_atype[
            np.arange(extended_atype.shape[0])[:, None, None], masked_nlist
        ]

        # slice rr to get (nframes, nloc, nnei)
        rr = np.take_along_axis(pairwise_rr[:, : self.nloc, :], masked_nlist, 2)

        raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr)

        # Padded neighbours contribute nothing; the remaining pair energies
        # are summed per local atom and scaled by 0.5.
        atomic_energy = 0.5 * np.sum(
            np.where(
                nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)
            ),
            axis=-1,
        )

        return {"energy": atomic_energy}

    def _pair_tabulated_inter(
        self,
        nlist: np.ndarray,
        i_type: np.ndarray,
        j_type: np.ndarray,
        rr: np.ndarray,
    ) -> np.ndarray:
        """Pairwise tabulated energy.

        Parameters
        ----------
        nlist : np.ndarray
            The unmasked neighbour list. (nframes, nloc, nnei)
        i_type : np.ndarray
            The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
        j_type : np.ndarray
            The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
        rr : np.ndarray
            The scalar distance between two atoms. (nframes, nloc, nnei)

        Returns
        -------
        np.ndarray
            The masked atomic energy for all local atoms for all frames. (nframes, nloc, nnei)

        Raises
        ------
        Exception
            If the distance is below the lower boundary of the table.

        Notes
        -----
        This function is used to calculate the pairwise energy between two atoms.
        It uses a table containing cubic spline coefficients calculated in PairTab.
        It reads ``self.nframes``, ``self.nloc`` and ``self.nnei``, which are set
        by :meth:`forward_atomic`.
        """
        rmin = self.tab_info[0]
        hh = self.tab_info[1]
        hi = 1.0 / hh

        self.nspline = int(self.tab_info[2] + 0.1)

        uu = (rr - rmin) * hi  # this is broadcasted to (nframes,nloc,nnei)

        # Padded neighbours (-1 in nlist) were given the distance of atom 0;
        # push their spline position past the table so they neither trigger
        # the lower-boundary check nor pick up real coefficients.
        uu = np.where(nlist != -1, uu, self.nspline + 1)

        if np.any(uu < 0):
            raise Exception("coord go beyond table lower boundary")

        # integer spline index; uu keeps the fractional offset inside the spline
        idx = uu.astype(int)

        uu -= idx

        table_coef = self._extract_spline_coefficient(
            i_type, j_type, idx, self.tab_data, self.nspline
        )
        table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4)
        ener = self._calcualte_ener(table_coef, uu)

        # here we need to overwrite energy to zero at rcut and beyond.
        mask_beyond_rcut = rr >= self.rcut
        # also overwrite values beyond extrapolation to zero
        extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh
        ener[mask_beyond_rcut] = 0
        ener[extrapolation_mask] = 0

        return ener

    @staticmethod
    def _get_pairwise_dist(coords: np.ndarray) -> np.ndarray:
        """Get pairwise displacement `dr`.

        Parameters
        ----------
        coords : np.ndarray
            The coordinates of the atoms. (nframes, nall, 3)

        Returns
        -------
        np.ndarray
            The pairwise displacement between the atoms, where entry
            ``[f, i, j]`` is ``coords[f, i] - coords[f, j]``.
            (nframes, nall, nall, 3)
        """
        return np.expand_dims(coords, 2) - np.expand_dims(coords, 1)

    @staticmethod
    def _extract_spline_coefficient(
        i_type: np.ndarray,
        j_type: np.ndarray,
        idx: np.ndarray,
        tab_data: np.ndarray,
        nspline: int,
    ) -> np.ndarray:
        """Extract the spline coefficient from the table.

        Parameters
        ----------
        i_type : np.ndarray
            The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
        j_type : np.ndarray
            The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
        idx : np.ndarray
            The index of the spline coefficient. (nframes, nloc, nnei)
        tab_data : np.ndarray
            The table storing all the spline coefficient. (ntype, ntype, nspline, 4)
        nspline : int
            The number of splines in the table.

        Returns
        -------
        np.ndarray
            The spline coefficient. (nframes, nloc, nnei, 4)
        """
        # (nframes, nloc, nnei)
        expanded_i_type = np.broadcast_to(
            i_type[:, :, np.newaxis],
            (i_type.shape[0], i_type.shape[1], j_type.shape[-1]),
        )

        # (nframes, nloc, nnei, nspline, 4)
        expanded_tab_data = tab_data[expanded_i_type, j_type]

        # (nframes, nloc, nnei, 1, 4)
        expanded_idx = np.broadcast_to(
            idx[..., np.newaxis, np.newaxis], idx.shape + (1, 4)
        )
        # out-of-table indices are clipped so the gather below stays in bounds
        clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int)

        # (nframes, nloc, nnei, 4)
        # Only the gathered spline axis is dropped, so the result keeps its
        # rank even when nframes, nloc or nnei equals 1.
        final_coef = np.squeeze(
            np.take_along_axis(expanded_tab_data, clipped_indices, 3), axis=3
        )

        # when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
        final_coef[np.squeeze(expanded_idx, axis=3) > nspline] = 0
        return final_coef

    @staticmethod
    def _calcualte_ener(coef: np.ndarray, uu: np.ndarray) -> np.ndarray:
        """Calculate energy using spline coefficients.

        Parameters
        ----------
        coef : np.ndarray
            The spline coefficients. (nframes, nloc, nnei, 4)
        uu : np.ndarray
            The atom displacement used in interpolation and extrapolation (nframes, nloc, nnei)

        Returns
        -------
        np.ndarray
            The atomic energy for all local atoms for all frames. (nframes, nloc, nnei)
        """
        # split the last axis into the four cubic coefficients
        a3 = coef[..., 0]
        a2 = coef[..., 1]
        a1 = coef[..., 2]
        a0 = coef[..., 3]
        # Horner evaluation of a3*uu**3 + a2*uu**2 + a1*uu + a0, elementwise
        etmp = (a3 * uu + a2) * uu + a1
        ener = etmp * uu + a0  # this energy has the extrapolated value when rcut > rmax
        return ener

0 comments on commit 0ce23f4

Please sign in to comment.