diff --git a/deepmd/pt/model/model/pair_tab.py b/deepmd/pt/model/model/pair_tab.py new file mode 100644 index 0000000000..6f0782289a --- /dev/null +++ b/deepmd/pt/model/model/pair_tab.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, + Union, +) + +import torch +from torch import ( + nn, +) + +from deepmd.model_format import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.utils.pair_tab import ( + PairTab, +) + +from .atomic_model import ( + AtomicModel, +) + + +class PairTabModel(nn.Module, AtomicModel): + """Pairwise tabulation energy model. + + This model can be used to tabulate the pairwise energy between atoms for either + short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not + be used alone, but rather as one submodel of a linear (sum) model, such as + DP+D3. + + Do not put the model on the first model of a linear model, since the linear + model fetches the type map from the first model. + + At this moment, the model does not smooth the energy at the cutoff radius, so + one needs to make sure the energy has been smoothed to zero. + + Parameters + ---------- + tab_file : str + The path to the tabulation file. + rcut : float + The cutoff radius. + sel : int or list[int] + The maxmum number of atoms in the cut-off radius. + """ + + def __init__( + self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs + ): + super().__init__() + self.tab_file = tab_file + self.rcut = rcut + + self.tab = PairTab(self.tab_file, rcut=rcut) + self.ntypes = self.tab.ntypes + + tab_info, tab_data = self.tab.get() # this returns -> Tuple[np.array, np.array] + self.tab_info = torch.from_numpy(tab_info) + self.tab_data = torch.from_numpy(tab_data) + + # self.model_type = "ener" + # self.model_version = MODEL_VERSION ## this shoud be in the parent class + + if isinstance(sel, int): + self.sel = sel + elif isinstance(sel, list): + self.sel = sum(sel) + else: + raise TypeError("sel must be int or list[int]") + + def get_fitting_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", shape=[1], reduciable=True, differentiable=True + ) + ] + ) + + def get_rcut(self) -> float: + return self.rcut + + def get_sel(self) -> int: + return self.sel + + def distinguish_types(self) -> bool: + # to match DPA1 and DPA2. + return False + + def forward_atomic( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + self.nframes, self.nloc, self.nnei = nlist.shape + + # this will mask all -1 in the nlist + masked_nlist = torch.clamp(nlist, 0) + + atype = extended_atype[:, : self.nloc] # (nframes, nloc) + pairwise_dr = self._get_pairwise_dist( + extended_coord + ) # (nframes, nall, nall, 3) + pairwise_rr = pairwise_dr.pow(2).sum(-1).sqrt() # (nframes, nall, nall) + + self.tab_data = self.tab_data.reshape( + self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4 + ) + + # to calculate the atomic_energy, we need 3 tensors, i_type, j_type, rr + # i_type : (nframes, nloc), this is atype. + # j_type : (nframes, nloc, nnei) + j_type = extended_atype[ + torch.arange(extended_atype.size(0))[:, None, None], masked_nlist + ] + + # slice rr to get (nframes, nloc, nnei) + rr = torch.gather(pairwise_rr[:, : self.nloc, :], 2, masked_nlist) + + raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr) + + atomic_energy = 0.5 * torch.sum( + torch.where( + nlist != -1, raw_atomic_energy, torch.zeros_like(raw_atomic_energy) + ), + dim=-1, + ) + + return {"energy": atomic_energy} + + def _pair_tabulated_inter( + self, + nlist: torch.Tensor, + i_type: torch.Tensor, + j_type: torch.Tensor, + rr: torch.Tensor, + ) -> torch.Tensor: + """Pairwise tabulated energy. + + Parameters + ---------- + nlist : torch.Tensor + The unmasked neighbour list. (nframes, nloc) + i_type : torch.Tensor + The integer representation of atom type for all local atoms for all frames. (nframes, nloc) + j_type : torch.Tensor + The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei) + rr : torch.Tensor + The salar distance vector between two atoms. (nframes, nloc, nnei) + + Returns + ------- + torch.Tensor + The masked atomic energy for all local atoms for all frames. (nframes, nloc, nnei) + + Raises + ------ + Exception + If the distance is beyond the table. + + Notes + ----- + This function is used to calculate the pairwise energy between two atoms. + It uses a table containing cubic spline coefficients calculated in PairTab. + """ + rmin = self.tab_info[0] + hh = self.tab_info[1] + hi = 1.0 / hh + + self.nspline = int(self.tab_info[2] + 0.1) + + uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei) + + # if nnei of atom 0 has -1 in the nlist, uu would be 0. + # this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms. + uu = torch.where(nlist != -1, uu, self.nspline + 1) + + if torch.any(uu < 0): + raise Exception("coord go beyond table lower boundary") + + idx = uu.to(torch.int) + + uu -= idx + + table_coef = self._extract_spline_coefficient( + i_type, j_type, idx, self.tab_data, self.nspline + ) + table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4) + ener = self._calcualte_ener(table_coef, uu) + + # here we need to overwrite energy to zero at rcut and beyond. + mask_beyond_rcut = rr >= self.rcut + # also overwrite values beyond extrapolation to zero + extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh + ener[mask_beyond_rcut] = 0 + ener[extrapolation_mask] = 0 + + return ener + + @staticmethod + def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor: + """Get pairwise distance `dr`. + + Parameters + ---------- + coords : torch.Tensor + The coordinate of the atoms shape of (nframes * nall * 3). + + Returns + ------- + torch.Tensor + The pairwise distance between the atoms (nframes * nall * nall * 3). + + Examples + -------- + coords = torch.tensor([[ + [0,0,0], + [1,3,5], + [2,4,6] + ]]) + + dist = tensor([[ + [[ 0, 0, 0], + [-1, -3, -5], + [-2, -4, -6]], + + [[ 1, 3, 5], + [ 0, 0, 0], + [-1, -1, -1]], + + [[ 2, 4, 6], + [ 1, 1, 1], + [ 0, 0, 0]] + ]]) + """ + return coords.unsqueeze(2) - coords.unsqueeze(1) + + @staticmethod + def _extract_spline_coefficient( + i_type: torch.Tensor, + j_type: torch.Tensor, + idx: torch.Tensor, + tab_data: torch.Tensor, + nspline: int, + ) -> torch.Tensor: + """Extract the spline coefficient from the table. + + Parameters + ---------- + i_type : torch.Tensor + The integer representation of atom type for all local atoms for all frames. (nframes, nloc) + j_type : torch.Tensor + The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei) + idx : torch.Tensor + The index of the spline coefficient. (nframes, nloc, nnei) + tab_data : torch.Tensor + The table storing all the spline coefficient. (ntype, ntype, nspline, 4) + nspline : int + The number of splines in the table. + + Returns + ------- + torch.Tensor + The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed. + + """ + # (nframes, nloc, nnei) + expanded_i_type = i_type.unsqueeze(-1).expand(-1, -1, j_type.shape[-1]) + + # (nframes, nloc, nnei, nspline, 4) + expanded_tab_data = tab_data[expanded_i_type, j_type] + + # (nframes, nloc, nnei, 1, 4) + expanded_idx = idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, -1, 4) + + # handle the case where idx is beyond the number of splines + clipped_indices = torch.clamp(expanded_idx, 0, nspline - 1).to(torch.int64) + + # (nframes, nloc, nnei, 4) + final_coef = torch.gather(expanded_tab_data, 3, clipped_indices).squeeze() + + # when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`. + final_coef[expanded_idx.squeeze() > nspline] = 0 + return final_coef + + @staticmethod + def _calcualte_ener(coef: torch.Tensor, uu: torch.Tensor) -> torch.Tensor: + """Calculate energy using spline coeeficients. + + Parameters + ---------- + coef : torch.Tensor + The spline coefficients. (nframes, nloc, nnei, 4) + uu : torch.Tensor + The atom displancemnt used in interpolation and extrapolation (nframes, nloc, nnei) + + Returns + ------- + torch.Tensor + The atomic energy for all local atoms for all frames. (nframes, nloc, nnei) + """ + a3, a2, a1, a0 = torch.unbind(coef, dim=-1) + etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations. + ener = etmp * uu + a0 # this energy has the extrapolated value when rcut > rmax + return ener diff --git a/deepmd/utils/pair_tab.py b/deepmd/utils/pair_tab.py index 4451f53379..56f8e618df 100644 --- a/deepmd/utils/pair_tab.py +++ b/deepmd/utils/pair_tab.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: LGPL-3.0-or-later +import logging from typing import ( + Optional, Tuple, ) @@ -25,11 +27,11 @@ class PairTab: The columes from 2nd to 4th are for 0-0, 0-1 and 1-1 correspondingly. """ - def __init__(self, filename: str) -> None: + def __init__(self, filename: str, rcut: Optional[float] = None) -> None: """Constructor.""" - self.reinit(filename) + self.reinit(filename, rcut) - def reinit(self, filename: str) -> None: + def reinit(self, filename: str, rcut: Optional[float] = None) -> None: """Initialize the tabulated interaction. Parameters @@ -44,8 +46,8 @@ def reinit(self, filename: str) -> None: """ self.vdata = np.loadtxt(filename) self.rmin = self.vdata[0][0] + self.rmax = self.vdata[-1][0] self.hh = self.vdata[1][0] - self.vdata[0][0] - self.nspline = self.vdata.shape[0] - 1 ncol = self.vdata.shape[1] - 1 n0 = (-1 + np.sqrt(1 + 8 * ncol)) * 0.5 self.ntypes = int(n0 + 0.1) @@ -53,13 +55,155 @@ def reinit(self, filename: str) -> None: "number of volumes provided in %s does not match guessed number of types %d" % (filename, self.ntypes) ) + + # check table data against rcut and update tab_file if needed, table upper boundary is used as rcut if not provided. + self.rcut = rcut if rcut is not None else self.rmax + self._check_table_upper_boundary() + self.nspline = ( + self.vdata.shape[0] - 1 + ) # this nspline is updated based on the expanded table. self.tab_info = np.array([self.rmin, self.hh, self.nspline, self.ntypes]) self.tab_data = self._make_data() + def _check_table_upper_boundary(self) -> None: + """Update User Provided Table Based on `rcut`. + + This function checks the upper boundary provided in the table against rcut. + If the table upper boundary values decay to zero before rcut, padding zeros will + be added to the table to cover rcut; if the table upper boundary values do not decay to zero + before ruct, extrapolation will be performed till rcut. + + Examples + -------- + table = [[0.005 1. 2. 3. ] + [0.01 0.8 1.6 2.4 ] + [0.015 0. 1. 1.5 ]] + + rcut = 0.022 + + new_table = [[0.005 1. 2. 3. ] + [0.01 0.8 1.6 2.4 ] + [0.015 0. 1. 1.5 ] + [0.02 0. 0. 0. ] + + ---------------------------------------------- + + table = [[0.005 1. 2. 3. ] + [0.01 0.8 1.6 2.4 ] + [0.015 0.5 1. 1.5 ] + [0.02 0.25 0.4 0.75 ] + [0.025 0. 0.1 0. ] + [0.03 0. 0. 0. ]] + + rcut = 0.031 + + new_table = [[0.005 1. 2. 3. ] + [0.01 0.8 1.6 2.4 ] + [0.015 0.5 1. 1.5 ] + [0.02 0.25 0.4 0.75 ] + [0.025 0. 0.1 0. ] + [0.03 0. 0. 0. ] + [0.035 0. 0. 0. ]] + """ + upper_val = self.vdata[-1][1:] + upper_idx = self.vdata.shape[0] - 1 + self.ncol = self.vdata.shape[1] + + # the index in table for the grid point of rcut, always give the point after rcut. + rcut_idx = int(np.ceil(self.rcut / self.hh - self.rmin / self.hh)) + if np.all(upper_val == 0): + # if table values decay to `0` after rcut + if self.rcut < self.rmax and np.any(self.vdata[rcut_idx - 1][1:] != 0): + logging.warning( + "The energy provided in the table does not decay to 0 at rcut." + ) + # if table values decay to `0` at rcut, do nothing + + # if table values decay to `0` before rcut, pad table with `0`s. + elif self.rcut > self.rmax: + pad_zero = np.zeros((rcut_idx - upper_idx, self.ncol)) + pad_zero[:, 0] = np.linspace( + self.rmax + self.hh, + self.rmax + self.hh * (rcut_idx - upper_idx), + rcut_idx - upper_idx, + ) + self.vdata = np.concatenate((self.vdata, pad_zero), axis=0) + else: + # if table values do not decay to `0` at rcut + if self.rcut <= self.rmax: + logging.warning( + "The energy provided in the table does not decay to 0 at rcut." + ) + # if rcut goes beyond table upper bond, need extrapolation, ensure values decay to `0` before rcut. + else: + logging.warning( + "The rcut goes beyond table upper boundary, performing extrapolation." + ) + pad_extrapolation = np.zeros((rcut_idx - upper_idx, self.ncol)) + + pad_extrapolation[:, 0] = np.linspace( + self.rmax + self.hh, + self.rmax + self.hh * (rcut_idx - upper_idx), + rcut_idx - upper_idx, + ) + # need to calculate table values to fill in with cubic spline + pad_extrapolation = self._extrapolate_table(pad_extrapolation) + + self.vdata = np.concatenate((self.vdata, pad_extrapolation), axis=0) + def get(self) -> Tuple[np.array, np.array]: """Get the serialized table.""" return self.tab_info, self.tab_data + def _extrapolate_table(self, pad_extrapolation: np.array) -> np.array: + """Soomth extrapolation between table upper boundary and rcut. + + This method should only be used when the table upper boundary `rmax` is smaller than `rcut`, and + the table upper boundary values are not zeros. To simplify the problem, we use a single + cubic spline between `rmax` and `rcut` for each pair of atom types. One can substitute this extrapolation + to higher order polynomials if needed. + + There are two scenarios: + 1. `ruct` - `rmax` >= hh: + Set values at the grid point right before `rcut` to 0, and perform exterapolation between + the grid point and `rmax`, this allows smooth decay to 0 at `rcut`. + 2. `rcut` - `rmax` < hh: + Set values at `rmax + hh` to 0, and perform extrapolation between `rmax` and `rmax + hh`. + + Parameters + ---------- + pad_extrapolation : np.array + The emepty grid that holds the extrapolation values. + + Returns + ------- + np.array + The cubic spline extrapolation. + """ + # in theory we should check if the table has at least two rows. + slope = self.vdata[-1, 1:] - self.vdata[-2, 1:] # shape of (ncol-1, ) + + # for extrapolation, we want values decay to `0` prior to `ruct` if possible + # here we try to find the grid point prior to `rcut` + grid_point = ( + -2 if pad_extrapolation[-1, 0] / self.hh - self.rmax / self.hh >= 2 else -1 + ) + temp_grid = np.stack((self.vdata[-1, :], pad_extrapolation[grid_point, :])) + vv = temp_grid[:, 1:] + xx = temp_grid[:, 0] + cs = CubicSpline(xx, vv, bc_type=((1, slope), (1, np.zeros_like(slope)))) + xx_grid = pad_extrapolation[:, 0] + res = cs(xx_grid) + + pad_extrapolation[:, 1:] = res + + # Note: when doing cubic spline, if we want to ensure values decay to zero prior to `rcut` + # this may cause values be positive post `rcut`, we need to overwrite those values to zero + pad_extrapolation = ( + pad_extrapolation if grid_point == -1 else pad_extrapolation[:-1, :] + ) + return pad_extrapolation + def _make_data(self): data = np.zeros([self.ntypes * self.ntypes * 4 * self.nspline]) stride = 4 * self.nspline @@ -68,7 +212,7 @@ def _make_data(self): for t0 in range(self.ntypes): for t1 in range(t0, self.ntypes): vv = self.vdata[:, 1 + idx_iter] - cs = CubicSpline(xx, vv) + cs = CubicSpline(xx, vv, bc_type="clamped") dd = cs(xx, 1) dd *= self.hh dtmp = np.zeros(stride) diff --git a/source/tests/common/test_pairtab_preprocess.py b/source/tests/common/test_pairtab_preprocess.py new file mode 100644 index 0000000000..a866c42236 --- /dev/null +++ b/source/tests/common/test_pairtab_preprocess.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from unittest.mock import ( + patch, +) + +import numpy as np + +from deepmd.utils.pair_tab import ( + PairTab, +) + + +class TestPairTabPreprocessExtrapolate(unittest.TestCase): + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt) -> None: + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + ] + ) + + self.tab1 = PairTab(filename=file_path, rcut=0.028) + self.tab2 = PairTab(filename=file_path, rcut=0.02) + self.tab3 = PairTab(filename=file_path, rcut=0.022) + self.tab4 = PairTab(filename=file_path, rcut=0.03) + self.tab5 = PairTab(filename=file_path, rcut=0.032) + + def test_preprocess(self): + np.testing.assert_allclose( + self.tab1.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + np.testing.assert_allclose( + self.tab2.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + # for this test case, the table does not decay to zero at rcut = 0.22, + # in the cubic spline code, we use a fixed size grid, if will be a problem if we introduce variable gird size. + # we will do post process to overwrite spline coefficient `a3`,`a2`,`a1`,`a0`, to ensure energy decays to `0`. + np.testing.assert_allclose( + self.tab3.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + np.testing.assert_allclose( + self.tab4.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + np.testing.assert_allclose( + self.tab5.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.12468, 0.1992, 0.3741], + [0.03, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + +class TestPairTabPreprocessZero(unittest.TestCase): + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt) -> None: + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ) + + self.tab1 = PairTab(filename=file_path, rcut=0.023) + self.tab2 = PairTab(filename=file_path, rcut=0.025) + self.tab3 = PairTab(filename=file_path, rcut=0.028) + self.tab4 = PairTab(filename=file_path, rcut=0.033) + + def test_preprocess(self): + np.testing.assert_allclose( + self.tab1.vdata, + np.array( + [ + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ), + ) + np.testing.assert_allclose( + self.tab2.vdata, + np.array( + [ + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ), + ) + + np.testing.assert_allclose( + self.tab3.vdata, + np.array( + [ + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + [0.03, 0.0, 0.0, 0.0], + ] + ), + ) + + np.testing.assert_allclose( + self.tab4.vdata, + np.array( + [ + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + [0.03, 0.0, 0.0, 0.0], + [0.035, 0.0, 0.0, 0.0], + ] + ), + ) + + +class TestPairTabPreprocessUneven(unittest.TestCase): + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt) -> None: + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.1, 0.0], + ] + ) + + self.tab1 = PairTab(filename=file_path, rcut=0.025) + self.tab2 = PairTab(filename=file_path, rcut=0.028) + self.tab3 = PairTab(filename=file_path, rcut=0.03) + self.tab4 = PairTab(filename=file_path, rcut=0.037) + + def test_preprocess(self): + np.testing.assert_allclose( + self.tab1.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.1, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + np.testing.assert_allclose( + self.tab2.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.1, 0.0], + [0.03, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + np.testing.assert_allclose( + self.tab3.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.1, 0.0], + [0.03, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + np.testing.assert_allclose( + self.tab4.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.1, 0.0], + [0.03, 0.0, 0.04963, 0.0], + [0.035, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-03, + atol=1e-03, + ) diff --git a/source/tests/pt/test_pairtab.py b/source/tests/pt/test_pairtab.py new file mode 100644 index 0000000000..b4dbda6702 --- /dev/null +++ b/source/tests/pt/test_pairtab.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from unittest.mock import ( + patch, +) + +import numpy as np +import torch + +from deepmd.pt.model.model.pair_tab import ( + PairTabModel, +) + + +class TestPairTab(unittest.TestCase): + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt) -> None: + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + ] + ) + + self.model = PairTabModel(tab_file=file_path, rcut=0.02, sel=2) + + self.extended_coord = torch.tensor( + [ + [ + [0.01, 0.01, 0.01], + [0.01, 0.02, 0.01], + [0.01, 0.01, 0.02], + [0.02, 0.01, 0.01], + ], + [ + [0.01, 0.01, 0.01], + [0.01, 0.02, 0.01], + [0.01, 0.01, 0.02], + [0.05, 0.01, 0.01], + ], + ] + ) + + # nframes=2, nall=4 + self.extended_atype = torch.tensor([[0, 1, 0, 1], [0, 0, 1, 1]]) + + # nframes=2, nloc=2, nnei=2 + self.nlist = torch.tensor([[[1, 2], [0, 2]], [[1, 2], [0, 3]]]) + + def test_without_mask(self): + result = self.model.forward_atomic( + self.extended_coord, self.extended_atype, self.nlist + ) + expected_result = torch.tensor([[1.2000, 1.3614], [1.2000, 0.4000]]) + + torch.testing.assert_allclose(result["energy"], expected_result, 0.0001, 0.0001) + + def test_with_mask(self): + self.nlist = torch.tensor([[[1, -1], [0, 2]], [[1, 2], [0, 3]]]) + + result = self.model.forward_atomic( + self.extended_coord, self.extended_atype, self.nlist + ) + expected_result = torch.tensor([[0.8000, 1.3614], [1.2000, 0.4000]]) + + torch.testing.assert_allclose(result["energy"], expected_result, 0.0001, 0.0001) + + def test_jit(self): + model = torch.jit.script(self.model) + + +class TestPairTabTwoAtoms(unittest.TestCase): + @patch("numpy.loadtxt") + def test_extrapolation_nonzero_rmax(self, mock_loadtxt) -> None: + """Scenarios to test. + + rcut < rmax: + rr < rcut: use table values, or interpolate. + rr == rcut: use table values, or interpolate. + rr > rcut: should be 0 + rcut == rmax: + rr < rcut: use table values, or interpolate. + rr == rcut: use table values, or interpolate. + rr > rcut: should be 0 + rcut > rmax: + rr < rmax: use table values, or interpolate. + rr == rmax: use table values, or interpolate. + rmax < rr < rcut: extrapolate + rr >= rcut: should be 0 + + """ + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0], + [0.01, 0.8], + [0.015, 0.5], + [0.02, 0.25], + ] + ) + + # nframes=1, nall=2 + extended_atype = torch.tensor([[0, 0]]) + + # nframes=1, nloc=2, nnei=1 + nlist = torch.tensor([[[1], [-1]]]) + + results = [] + + for dist, rcut in zip( + [ + 0.01, + 0.015, + 0.020, + 0.015, + 0.02, + 0.021, + 0.015, + 0.02, + 0.021, + 0.025, + 0.026, + 0.025, + 0.025, + 0.0216161, + ], + [ + 0.015, + 0.015, + 0.015, + 0.02, + 0.02, + 0.02, + 0.022, + 0.022, + 0.022, + 0.025, + 0.025, + 0.03, + 0.035, + 0.025, + ], + ): + extended_coord = torch.tensor( + [ + [ + [0.0, 0.0, 0.0], + [0.0, dist, 0.0], + ], + ] + ) + + model = PairTabModel(tab_file=file_path, rcut=rcut, sel=2) + results.append( + model.forward_atomic(extended_coord, extended_atype, nlist)["energy"] + ) + + expected_result = torch.stack( + [ + torch.tensor( + [ + [ + [0.4, 0], + [0.0, 0], + [0.0, 0], + [0.25, 0], + [0, 0], + [0, 0], + [0.25, 0], + [0.125, 0], + [0.0922, 0], + [0, 0], + [0, 0], + [0, 0], + [0.0923, 0], + [0.0713, 0], + ] + ] + ) + ] + ).reshape(14, 2) + results = torch.stack(results).reshape(14, 2) + + torch.testing.assert_allclose(results, expected_result, 0.0001, 0.0001) + + if __name__ == "__main__": + unittest.main()