Skip to content

Commit

Permalink
feat(jax): zbl (#4301)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced new classes: `DPZBLLinearEnergyAtomicModel` and
`PairTabAtomicModel`, enhancing atomic model functionalities.
- Added `get_zbl_model` function for constructing `DPZBLModel` from
input data.
- Improved error handling in vector normalization with
`safe_for_vector_norm` and `safe_for_sqrt`.

- **Bug Fixes**
- Enhanced distance calculations in `format_nlist` to prevent NaN
errors.

- **Documentation**
	- Updated comments and docstrings for clarity on recent changes.

- **Tests**
	- Enhanced test support for JAX backend in `test_zbl_ener.py`.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 4, 2024
1 parent bfbe2ed commit 7aaf284
Show file tree
Hide file tree
Showing 11 changed files with 363 additions and 54 deletions.
47 changes: 26 additions & 21 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.utils.nlist import (
Expand Down Expand Up @@ -69,15 +70,16 @@ def __init__(
self.models = models
sub_model_type_maps = [md.get_type_map() for md in models]
err_msg = []
self.mapping_list = []
mapping_list = []
common_type_map = set(type_map)
self.type_map = type_map
for tpmp in sub_model_type_maps:
if not common_type_map.issubset(set(tpmp)):
err_msg.append(
f"type_map {tpmp} is not a subset of type_map {type_map}"
)
self.mapping_list.append(self.remap_atype(tpmp, self.type_map))
mapping_list.append(self.remap_atype(tpmp, self.type_map))
self.mapping_list = mapping_list
assert len(err_msg) == 0, "\n".join(err_msg)
self.mixed_types_list = [model.mixed_types() for model in self.models]

Expand Down Expand Up @@ -212,8 +214,9 @@ def forward_atomic(
result_dict
the result dict, defined by the fitting net output def.
"""
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.reshape(nframes, -1, 3)
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
nlists = build_multiple_neighbor_list(
extended_coord,
Expand Down Expand Up @@ -244,10 +247,10 @@ def forward_atomic(
aparam,
)["energy"]
)
self.weights = self._compute_weight(extended_coord, extended_atype, nlists_)
weights = self._compute_weight(extended_coord, extended_atype, nlists_)

fit_ret = {
"energy": np.sum(np.stack(ener_list) * np.stack(self.weights), axis=0),
"energy": xp.sum(xp.stack(ener_list) * xp.stack(weights), axis=0),
} # (nframes, nloc, 1)
return fit_ret

Expand Down Expand Up @@ -320,11 +323,12 @@ def _compute_weight(
nlists_: list[np.ndarray],
) -> list[np.ndarray]:
"""This should be a list of user defined weights that matches the number of models to be combined."""
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlists_)
nmodels = len(self.models)
nframes, nloc, _ = nlists_[0].shape
# the dtype of weights is the interface data type.
return [
np.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels
xp.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels
for _ in range(nmodels)
]

Expand Down Expand Up @@ -442,6 +446,7 @@ def _compute_weight(
self.sw_rmax > self.sw_rmin
), "The upper boundary `sw_rmax` must be greater than the lower boundary `sw_rmin`."

xp = array_api_compat.array_namespace(extended_coord, extended_atype)
dp_nlist = nlists_[0]
zbl_nlist = nlists_[1]

Expand All @@ -450,40 +455,40 @@ def _compute_weight(

# use the larger rr based on nlist
nlist_larger = zbl_nlist if zbl_nnei >= dp_nnei else dp_nlist
masked_nlist = np.clip(nlist_larger, 0, None)
masked_nlist = xp.clip(nlist_larger, 0, None)
pairwise_rr = PairTabAtomicModel._get_pairwise_dist(
extended_coord, masked_nlist
)

numerator = np.sum(
np.where(
numerator = xp.sum(
xp.where(
nlist_larger != -1,
pairwise_rr * np.exp(-pairwise_rr / self.smin_alpha),
np.zeros_like(nlist_larger),
pairwise_rr * xp.exp(-pairwise_rr / self.smin_alpha),
xp.zeros_like(nlist_larger),
),
axis=-1,
) # masked nnei will be zero, no need to handle
denominator = np.sum(
np.where(
denominator = xp.sum(
xp.where(
nlist_larger != -1,
np.exp(-pairwise_rr / self.smin_alpha),
np.zeros_like(nlist_larger),
xp.exp(-pairwise_rr / self.smin_alpha),
xp.zeros_like(nlist_larger),
),
axis=-1,
) # handle masked nnei.
with np.errstate(divide="ignore", invalid="ignore"):
sigma = numerator / denominator
u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin)
coef = np.zeros_like(u)
coef = xp.zeros_like(u)
left_mask = sigma < self.sw_rmin
mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax)
right_mask = sigma >= self.sw_rmax
coef[left_mask] = 1
coef = xp.where(left_mask, xp.ones_like(coef), coef)
with np.errstate(invalid="ignore"):
smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
coef[mid_mask] = smooth[mid_mask]
coef[right_mask] = 0
coef = xp.where(mid_mask, smooth, coef)
coef = xp.where(right_mask, xp.zeros_like(coef), coef)
# to handle masked atoms
coef = np.where(sigma != 0, coef, np.zeros_like(coef))
coef = xp.where(sigma != 0, coef, xp.zeros_like(coef))
self.zbl_weight = coef
return [1 - np.expand_dims(coef, -1), np.expand_dims(coef, -1)]
return [1 - xp.expand_dims(coef, -1), xp.expand_dims(coef, -1)]
77 changes: 47 additions & 30 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,19 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
)
from deepmd.dpmodel.utils.safe_gradient import (
safe_for_sqrt,
)
from deepmd.utils.pair_tab import (
PairTab,
)
Expand Down Expand Up @@ -74,9 +81,10 @@ def __init__(
self.atom_ener = atom_ener

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()
nspline, ntypes_tab = self.tab_info[-2:].astype(int)
self.tab_data = self.tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
tab_info, tab_data = self.tab.get()
nspline, ntypes_tab = tab_info[-2:].astype(int)
self.tab_info = tab_info
self.tab_data = tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
if self.ntypes != ntypes_tab:
raise ValueError(
"The `type_map` provided does not match the number of columns in the table."
Expand Down Expand Up @@ -189,8 +197,9 @@ def forward_atomic(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> dict[str, np.ndarray]:
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.reshape(nframes, -1, 3)
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))

# this will mask all -1 in the nlist
mask = nlist >= 0
Expand All @@ -200,23 +209,21 @@ def forward_atomic(
pairwise_rr = self._get_pairwise_dist(
extended_coord, masked_nlist
) # (nframes, nloc, nnei)
self.tab_data = self.tab_data.reshape(
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
)

# (nframes, nloc, nnei), index type is int64.
j_type = extended_atype[
np.arange(extended_atype.shape[0], dtype=np.int64)[:, None, None],
xp.arange(extended_atype.shape[0], dtype=xp.int64)[:, None, None],
masked_nlist,
]

raw_atomic_energy = self._pair_tabulated_inter(
nlist, atype, j_type, pairwise_rr
)
atomic_energy = 0.5 * np.sum(
np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)),
atomic_energy = 0.5 * xp.sum(
xp.where(nlist != -1, raw_atomic_energy, xp.zeros_like(raw_atomic_energy)),
axis=-1,
).reshape(nframes, nloc, 1)
)
atomic_energy = xp.reshape(atomic_energy, (nframes, nloc, 1))

return {"energy": atomic_energy}

Expand Down Expand Up @@ -255,36 +262,42 @@ def _pair_tabulated_inter(
This function is used to calculate the pairwise energy between two atoms.
It uses a table containing cubic spline coefficients calculated in PairTab.
"""
xp = array_api_compat.array_namespace(nlist, i_type, j_type, rr)
nframes, nloc, nnei = nlist.shape
rmin = self.tab_info[0]
hh = self.tab_info[1]
hi = 1.0 / hh

nspline = int(self.tab_info[2] + 0.1)
# jax jit does not support convert to a Python int, so we need to convert to xp.int64.
nspline = (self.tab_info[2] + 0.1).astype(xp.int64)

uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei)

# if nnei of atom 0 has -1 in the nlist, uu would be 0.
# this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms.
uu = np.where(nlist != -1, uu, nspline + 1)
uu = xp.where(nlist != -1, uu, nspline + 1)

if np.any(uu < 0):
raise Exception("coord go beyond table lower boundary")
# unsupported by jax
# if xp.any(uu < 0):
# raise Exception("coord go beyond table lower boundary")

idx = uu.astype(int)
idx = xp.astype(uu, xp.int64)

uu -= idx
table_coef = self._extract_spline_coefficient(
i_type, j_type, idx, self.tab_data, nspline
)
table_coef = table_coef.reshape(nframes, nloc, nnei, 4)
table_coef = xp.reshape(table_coef, (nframes, nloc, nnei, 4))
ener = self._calculate_ener(table_coef, uu)
# here we need to overwrite energy to zero at rcut and beyond.
mask_beyond_rcut = rr >= self.rcut
# also overwrite values beyond extrapolation to zero
extrapolation_mask = rr >= self.tab.rmin + nspline * self.tab.hh
ener[mask_beyond_rcut] = 0
ener[extrapolation_mask] = 0
ener = xp.where(
xp.logical_or(mask_beyond_rcut, extrapolation_mask),
xp.zeros_like(ener),
ener,
)

return ener

Expand All @@ -304,12 +317,13 @@ def _get_pairwise_dist(coords: np.ndarray, nlist: np.ndarray) -> np.ndarray:
np.ndarray
The pairwise distance between the atoms (nframes, nloc, nnei).
"""
xp = array_api_compat.array_namespace(coords, nlist)
# index type is int64
batch_indices = np.arange(nlist.shape[0], dtype=np.int64)[:, None, None]
batch_indices = xp.arange(nlist.shape[0], dtype=xp.int64)[:, None, None]
neighbor_atoms = coords[batch_indices, nlist]
loc_atoms = coords[:, : nlist.shape[1], :]
pairwise_dr = loc_atoms[:, :, None, :] - neighbor_atoms
pairwise_rr = np.sqrt(np.sum(np.power(pairwise_dr, 2), axis=-1))
pairwise_rr = safe_for_sqrt(xp.sum(xp.power(pairwise_dr, 2), axis=-1))

return pairwise_rr

Expand All @@ -319,7 +333,7 @@ def _extract_spline_coefficient(
j_type: np.ndarray,
idx: np.ndarray,
tab_data: np.ndarray,
nspline: int,
nspline: np.int64,
) -> np.ndarray:
"""Extract the spline coefficient from the table.
Expand All @@ -341,28 +355,31 @@ def _extract_spline_coefficient(
np.ndarray
The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed.
"""
xp = array_api_compat.array_namespace(i_type, j_type, idx, tab_data)
# (nframes, nloc, nnei)
expanded_i_type = np.broadcast_to(
i_type[:, :, np.newaxis],
expanded_i_type = xp.broadcast_to(
i_type[:, :, xp.newaxis],
(i_type.shape[0], i_type.shape[1], j_type.shape[-1]),
)

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = tab_data[expanded_i_type, j_type]

# (nframes, nloc, nnei, 1, 4)
expanded_idx = np.broadcast_to(
idx[..., np.newaxis, np.newaxis], (*idx.shape, 1, 4)
expanded_idx = xp.broadcast_to(
idx[..., xp.newaxis, xp.newaxis], (*idx.shape, 1, 4)
)
clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int)
clipped_indices = xp.clip(expanded_idx, 0, nspline - 1).astype(int)

# (nframes, nloc, nnei, 4)
final_coef = np.squeeze(
np.take_along_axis(expanded_tab_data, clipped_indices, 3)
final_coef = xp.squeeze(
xp_take_along_axis(expanded_tab_data, clipped_indices, 3)
)

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() > nspline] = 0
final_coef = xp.where(
expanded_idx.squeeze() > nspline, xp.zeros_like(final_coef), final_coef
)
return final_coef

@staticmethod
Expand Down
5 changes: 4 additions & 1 deletion deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
LayerNorm,
NativeLayer,
)
from deepmd.dpmodel.utils.safe_gradient import (
safe_for_vector_norm,
)
from deepmd.dpmodel.utils.seed import (
child_seed,
)
Expand Down Expand Up @@ -943,7 +946,7 @@ def call(
else:
raise NotImplementedError

normed = xp.linalg.vector_norm(
normed = safe_for_vector_norm(
xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4], axis=-1, keepdims=True
)
input_r = xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4] / xp.maximum(
Expand Down
32 changes: 32 additions & 0 deletions deepmd/dpmodel/utils/safe_gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Safe versions of some functions that have problematic gradients.
Check https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
for more information.
"""

import array_api_compat


def safe_for_sqrt(x):
"""Safe version of sqrt that has a gradient of 0 at x = 0."""
xp = array_api_compat.array_namespace(x)
mask = x > 0.0
return xp.where(mask, xp.sqrt(xp.where(mask, x, xp.ones_like(x))), xp.zeros_like(x))


def safe_for_vector_norm(x, /, *, axis=None, keepdims=False, ord=2):
"""Safe version of sqrt that has a gradient of 0 at x = 0."""
xp = array_api_compat.array_namespace(x)
mask = xp.sum(xp.square(x), axis=axis, keepdims=True) > 0
if keepdims:
mask_squeezed = mask
else:
mask_squeezed = xp.squeeze(mask, axis=axis)
return xp.where(
mask_squeezed,
xp.linalg.vector_norm(
xp.where(mask, x, xp.ones_like(x)), axis=axis, keepdims=keepdims, ord=ord
),
xp.zeros_like(mask_squeezed, dtype=x.dtype),
)
Loading

0 comments on commit 7aaf284

Please sign in to comment.