Skip to content

Commit

Permalink
feat: redo extrapolation with cubic spline for smoothness
Browse files Browse the repository at this point in the history
  • Loading branch information
Anyang Peng authored and Anyang Peng committed Jan 30, 2024
1 parent 88936cc commit 1c4ee0d
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 26 deletions.
139 changes: 119 additions & 20 deletions deepmd/pt/model/model/pair_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
Union,
)

from scipy.interpolate import (
CubicSpline,
)

import numpy as np
import torch
from torch import (
nn,
Expand Down Expand Up @@ -99,12 +104,12 @@ def forward_atomic(
mapping: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
nframes, nloc, nnei = nlist.shape
self.nframes, self.nloc, self.nnei = nlist.shape

# this will mask all -1 in the nlist
masked_nlist = torch.clamp(nlist, 0)

atype = extended_atype[:, :nloc] # (nframes, nloc)
atype = extended_atype[:, :self.nloc] # (nframes, nloc)
pairwise_dr = self._get_pairwise_dist(
extended_coord
) # (nframes, nall, nall, 3)
Expand All @@ -122,7 +127,7 @@ def forward_atomic(
]

# slice rr to get (nframes, nloc, nnei)
rr = torch.gather(pairwise_rr[:, :nloc, :], 2, masked_nlist)
rr = torch.gather(pairwise_rr[:, :self.nloc, :], 2, masked_nlist)

raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr)

Expand Down Expand Up @@ -189,19 +194,102 @@ def _pair_tabulated_inter(

uu -= idx

final_coef = self._extract_spline_coefficient(i_type, j_type, idx)

# here we need to do postprocess to overwrite coefficients when table values are not zero at rcut in linear extrapolation.
if self.tab.rmax < self.rcut:
post_mask = rr >= self.rcut
final_coef[post_mask] = torch.zeros(4)

a3, a2, a1, a0 = torch.unbind(final_coef, dim=-1) # 4 * (nframes, nloc, nnei)

etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations.
ener = etmp * uu + a0
table_coef = self._extract_spline_coefficient(i_type, j_type, idx, self.tab_data, self.nspline)
table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4)
ener = self._calcualte_ener(table_coef, uu)

# here we need to do postprocess to overwrite energy to zero beyond rcut.
if self.tab.rmax <= self.rcut:
mask_beyond_rcut = rr > self.rcut
ener[mask_beyond_rcut] = 0

# here we use smooth extrapolation to replace linear extrapolation.
extrapolation = self._extrapolate_rmax_rcut()
if extrapolation is not None:
uu_extrapolate = (rr - self.tab.rmax) / (self.rcut - self.tab.rmax)
clipped_uu = torch.clamp(uu_extrapolate, 0, 1) # clip rr within rmax.
extrapolate_coef = self._extract_spline_coefficient(i_type, j_type, torch.zeros_like(idx), extrapolation, 1)
extrapolate_coef = extrapolate_coef.reshape(self.nframes, self.nloc, self.nnei, 4)
ener_extrpolate = self._calcualte_ener(extrapolate_coef, clipped_uu)
mask_rmax_to_rcut = (self.tab.rmax < rr) & (rr <= self.rcut)
ener[mask_rmax_to_rcut] = ener_extrpolate[mask_rmax_to_rcut]
return ener

def _extrapolate_rmax_rcut(self) -> torch.Tensor:

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
"""Soomth extrapolation between table upper boundary and rcut.
This method should only be used when the table upper boundary `rmax` is smaller than `rcut`, and
the table upper boundary values are not zeros. To simplify the problem, we use a single
cubic spline between `rmax` and `rcut` for each pair of atom types. One can substitute this extrapolation
to higher order polynomials if needed.
There are two scenarios:
1. `ruct` - `rmax` >= hh:
Set values at the grid point right before `rcut` to 0, and perform exterapolation between
the grid point and `rmax`, this allows smooth decay to 0 at `rcut`.
2. `rcut` - `rmax` < hh:
Set values at `rmax + hh` to 0, and perform extrapolation between `rmax` and `rmax + hh`,
the enery beyond `rcut` will be overwritten to `0` latter.
Returns
-------
torch.Tensor
The cubic spline coefficients for each pair of atom types. (ntype, ntype, 1, 4)
"""
#check if decays to `0` at rmax, if yes, no extrapolation is needed.
rmax_val = torch.from_numpy(self.tab.vdata[self.tab.vdata[:,0] == self.tab.rmax])
pre_rmax_val = torch.from_numpy(self.tab.vdata[self.tab.vdata[:,0] == self.tab.rmax - self.tab.hh])

if torch.all(rmax_val[:,1:] == 0):
return

Check warning on line 244 in deepmd/pt/model/model/pair_tab.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/pair_tab.py#L244

Added line #L244 was not covered by tests
else:
if self.rcut - self.tab.rmax >= self.tab.hh:
rcut_idx = int(self.rcut/self.tab.hh - self.tab.rmin/self.tab.hh)
rcut_val = torch.tensor(self.tab.vdata[rcut_idx,:]).reshape(1,-1)
grid = torch.concatenate([rmax_val, rcut_val],axis=0)
else:
# the last two rows will be the rmax, and rmax+hh
grid = torch.from_numpy(self.tab.vdata[-2:,:])
passin_slope = ((rmax_val - pre_rmax_val)/self.tab.hh)[:,1:].squeeze(0) if ~np.all(pre_rmax_val == None) else 0 # the slope at the end of table for each ntype pairs (ntypes,ntypes,1)

Check notice

Code scanning / CodeQL

Testing equality to None Note

Testing for None should use the 'is' operator.
extrapolate_coef = torch.from_numpy(self._calculate_spline_coef(grid, passin_slope)).reshape(self.ntypes,self.ntypes,4)
return extrapolate_coef.unsqueeze(2)

# might be able to refactor this, combine with PairTab
def _calculate_spline_coef(self, grid, passin_slope):
data = np.zeros([self.ntypes * self.ntypes * 4])
stride = 4
idx_iter = 0

xx = grid[:, 0]
for t0 in range(self.ntypes):
for t1 in range(t0, self.ntypes):
vv = grid[:, 1 + idx_iter]
slope_idx = [t0 * (2 * self.ntypes - t0 - 1)//2 + t1]

print(f"slope: {passin_slope[slope_idx]}")
cs = CubicSpline(xx, vv, bc_type=((1,passin_slope[slope_idx][0]),(1,0)))
dd = cs(xx, 1)
dd *= self.tab.hh
dtmp = np.zeros(stride)
dtmp[0] = (
2 * vv[0] - 2 * vv[1] + dd[0] + dd[1]
)
dtmp[1] = (
(-3 * vv[0] + 3 * vv[1] - 2 * dd[0] - dd[1])
)
dtmp[2] = dd[0]
dtmp[3] = vv[0]
data[
(t0 * self.ntypes + t1) * stride : (t0 * self.ntypes + t1) * stride
+ stride
] = dtmp
data[
(t1 * self.ntypes + t0) * stride : (t1 * self.ntypes + t0) * stride
+ stride
] = dtmp
idx_iter += 1
return data

@staticmethod
def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:
"""Get pairwise distance `dr`.
Expand Down Expand Up @@ -240,8 +328,8 @@ def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:
"""
return coords.unsqueeze(2) - coords.unsqueeze(1)

def _extract_spline_coefficient(
self, i_type: torch.Tensor, j_type: torch.Tensor, idx: torch.Tensor
@staticmethod
def _extract_spline_coefficient(i_type: torch.Tensor, j_type: torch.Tensor, idx: torch.Tensor, tab_data: torch.Tensor, nspline: int
) -> torch.Tensor:
"""Extract the spline coefficient from the table.
Expand All @@ -253,29 +341,40 @@ def _extract_spline_coefficient(
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
idx : torch.Tensor
The index of the spline coefficient. (nframes, nloc, nnei)
tab_data : torch.Tensor
The table storing all the spline coefficient. (ntype, ntype, nspline, 4)
nspline : int
The number of splines in the table.
Returns
-------
torch.Tensor
The spline coefficient. (nframes, nloc, nnei, 4)
The spline coefficient. (nframes, nloc, nnei, 4), shape maybe squeezed.
"""
# (nframes, nloc, nnei)
expanded_i_type = i_type.unsqueeze(-1).expand(-1, -1, j_type.shape[-1])

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = self.tab_data[expanded_i_type, j_type]
expanded_tab_data = tab_data[expanded_i_type, j_type]

# (nframes, nloc, nnei, 1, 4)
expanded_idx = idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, -1, 4)

# handle the case where idx is beyond the number of splines
clipped_indices = torch.clamp(expanded_idx, 0, self.nspline - 1).to(torch.int64)
clipped_indices = torch.clamp(expanded_idx, 0, nspline - 1).to(torch.int64)

# (nframes, nloc, nnei, 4)
final_coef = torch.gather(expanded_tab_data, 3, clipped_indices).squeeze()

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() >= self.nspline] = 0
final_coef[expanded_idx.squeeze() >= nspline] = 0

return final_coef

@staticmethod
def _calcualte_ener(coef, uu):
a3, a2, a1, a0 = torch.unbind(coef, dim=-1) # 4 * (nframes, nloc, nnei)
etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations.
ener = etmp * uu + a0 # this energy has the linear extrapolated value when rcut > rmax
return ener
11 changes: 5 additions & 6 deletions deepmd/utils/pair_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def reinit(self, filename: str, rcut: Optional[float] = None) -> None:
# check table data against rcut and update tab_file if needed, table upper boundary is used as rcut if not provided.
self.rcut = rcut if rcut is not None else self.rmax
self._check_table_upper_boundary()
self.nspline = self.vdata.shape[0] - 1
self.nspline = self.vdata.shape[0] - 1 # this nspline is updated based on the expanded table.
self.tab_info = np.array([self.rmin, self.hh, self.nspline, self.ntypes])
self.tab_data = self._make_data()

Expand Down Expand Up @@ -106,7 +106,7 @@ def _check_table_upper_boundary(self) -> None:
"""
upper_val = self.vdata[-1][1:]
upper_idx = self.vdata.shape[0] - 1

ncol = self.vdata.shape[1]
# the index of table for the grid point right after rcut
rcut_idx = int(self.rcut / self.hh)

Expand All @@ -120,7 +120,7 @@ def _check_table_upper_boundary(self) -> None:

# if table values decay to `0` before rcut, pad table with `0`s.
elif self.rcut > self.rmax:
pad_zero = np.zeros((rcut_idx - upper_idx, 4))
pad_zero = np.zeros((rcut_idx - upper_idx, ncol))
pad_zero[:, 0] = np.linspace(
self.rmax + self.hh, self.hh * (rcut_idx + 1), rcut_idx - upper_idx
)
Expand All @@ -134,9 +134,9 @@ def _check_table_upper_boundary(self) -> None:
# if rcut goes beyond table upper bond, need extrapolation, ensure values decay to `0` before rcut.
else:
logging.warning(
"The rcut goes beyond table upper boundary, performing linear extrapolation."
"The rcut goes beyond table upper boundary, performing extrapolation."
)
pad_linear = np.zeros((rcut_idx - upper_idx + 1, 4))
pad_linear = np.zeros((rcut_idx - upper_idx + 1, ncol))
pad_linear[:, 0] = np.linspace(
self.rmax, self.hh * (rcut_idx + 1), rcut_idx - upper_idx + 1
)
Expand All @@ -150,7 +150,6 @@ def get(self) -> Tuple[np.array, np.array]:
return self.tab_info, self.tab_data

def _make_data(self):
# here we need to do postprocess, to overwrite coefficients when padding zeros resulting in negative energies.
data = np.zeros([self.ntypes * self.ntypes * 4 * self.nspline])
stride = 4 * self.nspline
idx_iter = 0
Expand Down
63 changes: 63 additions & 0 deletions source/tests/pt/test_pairtab.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,68 @@ def test_with_mask(self):
def test_jit(self):
model = torch.jit.script(self.model)

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable model is not used.


class TestPairTabTwoAtoms(unittest.TestCase):
@patch("numpy.loadtxt")
def test_extrapolation_nonzero_rmax(self, mock_loadtxt) -> None:
"""Scenarios to test.
rcut < rmax:
rr < rcut: use table values, or interpolate.
rr == rcut: use table values, or interpolate.
rr > rcut: should be 0
rcut == rmax:
rr < rcut: use table values, or interpolate.
rr == rcut: use table values, or interpolate.
rr > rcut: should be 0
rcut > rmax:
rr < rmax: use table values, or interpolate.
rr == rmax: use table values, or interpolate.
rmax < rr < rcut: extrapolate
rr >= rcut: should be 0
"""
file_path = "dummy_path"
mock_loadtxt.return_value = np.array(
[
[0.005, 1.0],
[0.01, 0.8],
[0.015, 0.5],
[0.02, 0.25],
]
)

# nframes=1, nall=2
extended_atype = torch.tensor([[0, 0]])

# nframes=1, nloc=2, nnei=1
nlist = torch.tensor([[[1], [-1]]])

results = []

for dist, rcut in zip(
[0.010, 0.015, 0.020, 0.015, 0.020, 0.021, 0.015, 0.020, 0.021, 0.025, 0.026, 0.025, 0.02999],
[0.015, 0.015, 0.015, 0.02, 0.020, 0.020, 0.022, 0.022, 0.022, 0.025, 0.025, 0.030, 0.030]):
extended_coord = torch.tensor(
[
[
[0.0, 0.0, 0.0],
[0.0, dist, 0.0],
],
]
)

model = PairTabModel(tab_file=file_path, rcut=rcut, sel=2)
results.append(model.forward_atomic(
extended_coord, extended_atype, nlist
)["energy"])


expected_result = torch.stack([torch.tensor([[[0.4, 0], [0.25,0.], [0.,0], [0.25,0], [0.125,0], [0.,0], [0.25,0] ,[0.125,0], [0.0469,0], [0.,0], [0.,0],[0.0469,0],[0,0]]])]).reshape(13,2)
results = torch.stack(results).reshape(13,2)

torch.testing.assert_allclose(results, expected_result, 0.0001, 0.0001)


if __name__ == "__main__":
unittest.main()

0 comments on commit 1c4ee0d

Please sign in to comment.