feat(jax): zbl
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz committed Nov 1, 2024
1 parent 5c32147 commit a2931b4
Showing 10 changed files with 279 additions and 68 deletions.
47 changes: 26 additions & 21 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
@@ -5,6 +5,7 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.utils.nlist import (
@@ -69,15 +70,16 @@ def __init__(
self.models = models
sub_model_type_maps = [md.get_type_map() for md in models]
err_msg = []
self.mapping_list = []
mapping_list = []
common_type_map = set(type_map)
self.type_map = type_map
for tpmp in sub_model_type_maps:
if not common_type_map.issubset(set(tpmp)):
err_msg.append(
f"type_map {tpmp} is not a subset of type_map {type_map}"
)
self.mapping_list.append(self.remap_atype(tpmp, self.type_map))
mapping_list.append(self.remap_atype(tpmp, self.type_map))
self.mapping_list = mapping_list
assert len(err_msg) == 0, "\n".join(err_msg)
self.mixed_types_list = [model.mixed_types() for model in self.models]

@@ -180,8 +182,9 @@ def forward_atomic(
result_dict
the result dict, defined by the fitting net output def.
"""
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.reshape(nframes, -1, 3)
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
nlists = build_multiple_neighbor_list(
extended_coord,
@@ -212,10 +215,10 @@
aparam,
)["energy"]
)
self.weights = self._compute_weight(extended_coord, extended_atype, nlists_)
weights = self._compute_weight(extended_coord, extended_atype, nlists_)

fit_ret = {
"energy": np.sum(np.stack(ener_list) * np.stack(self.weights), axis=0),
"energy": xp.sum(xp.stack(ener_list) * xp.stack(weights), axis=0),
} # (nframes, nloc, 1)
return fit_ret

@@ -288,11 +291,12 @@ def _compute_weight(
nlists_: list[np.ndarray],
) -> list[np.ndarray]:
"""This should be a list of user defined weights that matches the number of models to be combined."""
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlists_)
nmodels = len(self.models)
nframes, nloc, _ = nlists_[0].shape
# the dtype of weights is the interface data type.
return [
np.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels
xp.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels
for _ in range(nmodels)
]

@@ -410,6 +414,7 @@ def _compute_weight(
self.sw_rmax > self.sw_rmin
), "The upper boundary `sw_rmax` must be greater than the lower boundary `sw_rmin`."

xp = array_api_compat.array_namespace(extended_coord, extended_atype)
dp_nlist = nlists_[0]
zbl_nlist = nlists_[1]

@@ -418,40 +423,40 @@

# use the larger rr based on nlist
nlist_larger = zbl_nlist if zbl_nnei >= dp_nnei else dp_nlist
masked_nlist = np.clip(nlist_larger, 0, None)
masked_nlist = xp.clip(nlist_larger, 0, None)
pairwise_rr = PairTabAtomicModel._get_pairwise_dist(
extended_coord, masked_nlist
)

numerator = np.sum(
np.where(
numerator = xp.sum(
xp.where(
nlist_larger != -1,
pairwise_rr * np.exp(-pairwise_rr / self.smin_alpha),
np.zeros_like(nlist_larger),
pairwise_rr * xp.exp(-pairwise_rr / self.smin_alpha),
xp.zeros_like(nlist_larger),
),
axis=-1,
) # masked nnei will be zero, no need to handle
denominator = np.sum(
np.where(
denominator = xp.sum(
xp.where(
nlist_larger != -1,
np.exp(-pairwise_rr / self.smin_alpha),
np.zeros_like(nlist_larger),
xp.exp(-pairwise_rr / self.smin_alpha),
xp.zeros_like(nlist_larger),
),
axis=-1,
) # handle masked nnei.
with np.errstate(divide="ignore", invalid="ignore"):
sigma = numerator / denominator
u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin)
coef = np.zeros_like(u)
coef = xp.zeros_like(u)
left_mask = sigma < self.sw_rmin
mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax)
right_mask = sigma >= self.sw_rmax
coef[left_mask] = 1
coef = xp.where(left_mask, xp.ones_like(coef), coef)
with np.errstate(invalid="ignore"):
smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
coef[mid_mask] = smooth[mid_mask]
coef[right_mask] = 0
coef = xp.where(mid_mask, smooth, coef)
coef = xp.where(right_mask, xp.zeros_like(coef), coef)
# to handle masked atoms
coef = np.where(sigma != 0, coef, np.zeros_like(coef))
coef = xp.where(sigma != 0, coef, xp.zeros_like(coef))
self.zbl_weight = coef
return [1 - np.expand_dims(coef, -1), np.expand_dims(coef, -1)]
return [1 - xp.expand_dims(coef, -1), xp.expand_dims(coef, -1)]
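Note on the pattern in this file: instead of importing NumPy directly, the model derives the array namespace from its inputs via array_api_compat.array_namespace and replaces in-place boolean-mask assignment (e.g. coef[left_mask] = 1) with functional xp.where updates, so the same code runs on NumPy arrays and on immutable JAX arrays. A minimal, self-contained sketch of that switching-function pattern (the helper name smooth_switch and the toy inputs are illustrative, not part of this commit):

import array_api_compat
import numpy as np


def smooth_switch(sigma, rmin, rmax):
    """Interpolate smoothly from 1 (sigma < rmin) to 0 (sigma >= rmax)."""
    xp = array_api_compat.array_namespace(sigma)
    u = (sigma - rmin) / (rmax - rmin)
    smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
    coef = xp.zeros_like(sigma)
    coef = xp.where(sigma < rmin, xp.ones_like(coef), coef)
    coef = xp.where((rmin <= sigma) & (sigma < rmax), smooth, coef)
    # masked atoms have sigma == 0; zero their weight explicitly
    coef = xp.where(sigma != 0, coef, xp.zeros_like(coef))
    return coef


print(smooth_switch(np.array([0.0, 0.3, 0.7, 1.5]), 0.5, 1.0))  # approximately [0., 1., 0.683, 0.]

Because every update is expressed through xp.where, the same function also works when sigma is a jax.numpy array, where item assignment is not available.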
74 changes: 44 additions & 30 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
@@ -5,8 +5,12 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
@@ -74,9 +78,10 @@ def __init__(
self.atom_ener = atom_ener

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()
nspline, ntypes_tab = self.tab_info[-2:].astype(int)
self.tab_data = self.tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
tab_info, tab_data = self.tab.get()
nspline, ntypes_tab = tab_info[-2:].astype(int)
self.tab_info = tab_info
self.tab_data = tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
if self.ntypes != ntypes_tab:
raise ValueError(
"The `type_map` provided does not match the number of columns in the table."
@@ -189,8 +194,9 @@ def forward_atomic(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> dict[str, np.ndarray]:
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.reshape(nframes, -1, 3)
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))

# this will mask all -1 in the nlist
mask = nlist >= 0
@@ -200,23 +206,21 @@
pairwise_rr = self._get_pairwise_dist(
extended_coord, masked_nlist
) # (nframes, nloc, nnei)
self.tab_data = self.tab_data.reshape(
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
)

# (nframes, nloc, nnei), index type is int64.
j_type = extended_atype[
np.arange(extended_atype.shape[0], dtype=np.int64)[:, None, None],
xp.arange(extended_atype.shape[0], dtype=xp.int64)[:, None, None],
masked_nlist,
]

raw_atomic_energy = self._pair_tabulated_inter(
nlist, atype, j_type, pairwise_rr
)
atomic_energy = 0.5 * np.sum(
np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)),
atomic_energy = 0.5 * xp.sum(
xp.where(nlist != -1, raw_atomic_energy, xp.zeros_like(raw_atomic_energy)),
axis=-1,
).reshape(nframes, nloc, 1)
)
atomic_energy = xp.reshape(atomic_energy, (nframes, nloc, 1))

return {"energy": atomic_energy}

@@ -255,36 +259,42 @@ def _pair_tabulated_inter(
This function is used to calculate the pairwise energy between two atoms.
It uses a table containing cubic spline coefficients calculated in PairTab.
"""
xp = array_api_compat.array_namespace(nlist, i_type, j_type, rr)
nframes, nloc, nnei = nlist.shape
rmin = self.tab_info[0]
hh = self.tab_info[1]
hi = 1.0 / hh

nspline = int(self.tab_info[2] + 0.1)
# jax jit does not support converting a traced value to a Python int, so we convert to xp.int64 instead.
nspline = (self.tab_info[2] + 0.1).astype(xp.int64)

uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei)

# if nnei of atom 0 has -1 in the nlist, uu would be 0.
# this is to handle the nlist where the mask is set to 0, so that we don't raise an exception for those atoms.
uu = np.where(nlist != -1, uu, nspline + 1)
uu = xp.where(nlist != -1, uu, nspline + 1)

if np.any(uu < 0):
raise Exception("coord go beyond table lower boundary")
# unsupported by jax
# if xp.any(uu < 0):
# raise Exception("coord go beyond table lower boundary")

idx = uu.astype(int)
idx = xp.astype(uu, xp.int64)

uu -= idx
table_coef = self._extract_spline_coefficient(
i_type, j_type, idx, self.tab_data, nspline
)
table_coef = table_coef.reshape(nframes, nloc, nnei, 4)
table_coef = xp.reshape(table_coef, (nframes, nloc, nnei, 4))
ener = self._calculate_ener(table_coef, uu)
# here we need to overwrite energy to zero at rcut and beyond.
mask_beyond_rcut = rr >= self.rcut
# also overwrite values beyond extrapolation to zero
extrapolation_mask = rr >= self.tab.rmin + nspline * self.tab.hh
ener[mask_beyond_rcut] = 0
ener[extrapolation_mask] = 0
ener = xp.where(
xp.logical_or(mask_beyond_rcut, extrapolation_mask),
xp.zeros_like(ener),
ener,
)

return ener

@@ -304,12 +314,13 @@ def _get_pairwise_dist(coords: np.ndarray, nlist: np.ndarray) -> np.ndarray:
np.ndarray
The pairwise distance between the atoms (nframes, nloc, nnei).
"""
xp = array_api_compat.array_namespace(coords, nlist)
# index type is int64
batch_indices = np.arange(nlist.shape[0], dtype=np.int64)[:, None, None]
batch_indices = xp.arange(nlist.shape[0], dtype=xp.int64)[:, None, None]
neighbor_atoms = coords[batch_indices, nlist]
loc_atoms = coords[:, : nlist.shape[1], :]
pairwise_dr = loc_atoms[:, :, None, :] - neighbor_atoms
pairwise_rr = np.sqrt(np.sum(np.power(pairwise_dr, 2), axis=-1))
pairwise_rr = xp.sqrt(xp.sum(xp.power(pairwise_dr, 2), axis=-1))

return pairwise_rr

@@ -319,7 +330,7 @@ def _extract_spline_coefficient(
j_type: np.ndarray,
idx: np.ndarray,
tab_data: np.ndarray,
nspline: int,
nspline: np.int64,
) -> np.ndarray:
"""Extract the spline coefficient from the table.
@@ -341,28 +352,31 @@
np.ndarray
The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed.
"""
xp = array_api_compat.array_namespace(i_type, j_type, idx, tab_data)
# (nframes, nloc, nnei)
expanded_i_type = np.broadcast_to(
i_type[:, :, np.newaxis],
expanded_i_type = xp.broadcast_to(
i_type[:, :, xp.newaxis],
(i_type.shape[0], i_type.shape[1], j_type.shape[-1]),
)

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = tab_data[expanded_i_type, j_type]

# (nframes, nloc, nnei, 1, 4)
expanded_idx = np.broadcast_to(
idx[..., np.newaxis, np.newaxis], (*idx.shape, 1, 4)
expanded_idx = xp.broadcast_to(
idx[..., xp.newaxis, xp.newaxis], (*idx.shape, 1, 4)
)
clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int)
clipped_indices = xp.clip(expanded_idx, 0, nspline - 1).astype(int)

# (nframes, nloc, nnei, 4)
final_coef = np.squeeze(
np.take_along_axis(expanded_tab_data, clipped_indices, 3)
final_coef = xp.squeeze(
xp_take_along_axis(expanded_tab_data, clipped_indices, 3)
)

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() > nspline] = 0
final_coef = xp.where(
expanded_idx.squeeze() > nspline, xp.zeros_like(final_coef), final_coef
)
return final_coef

@staticmethod
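Two of the changes in this file exist specifically because of jax.jit tracing: calling int() on a traced value raises a ConcretizationTypeError, so nspline stays a 0-d integer array, and item assignment such as ener[mask_beyond_rcut] = 0 is not allowed on immutable JAX arrays, so it becomes an xp.where. A small sketch of both constraints (the function tabulate_idx and the toy table values are illustrative only, not from this commit):

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # use double precision, as deepmd does


@jax.jit
def tabulate_idx(tab_info, rr, rcut):
    rmin, hh = tab_info[0], tab_info[1]
    # int(tab_info[2]) would fail under jit; keep nspline as an int64 array.
    nspline = (tab_info[2] + 0.1).astype(jnp.int64)
    uu = (rr - rmin) / hh
    idx = jnp.clip(uu.astype(jnp.int64), 0, nspline - 1)
    # functional masking instead of `idx[rr >= rcut] = -1`
    return jnp.where(rr >= rcut, jnp.full_like(idx, -1), idx)


tab_info = jnp.asarray([0.0, 0.01, 500.0])  # rmin, hh, nspline
print(tabulate_idx(tab_info, jnp.asarray([0.003, 2.0, 9.0]), rcut=6.0))  # [0, 200, -1]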
11 changes: 7 additions & 4 deletions deepmd/dpmodel/model/make_model.py
@@ -475,11 +475,14 @@ def _format_nlist(
index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2)
coord1 = xp.take_along_axis(extended_coord, index, axis=1)
coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3)
rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1)
rr = xp.where(m_real_nei, rr, float("inf"))
rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1)
# jax raises a NaN error when using norm
# but note: we don't actually need the sqrt here; the squared value is enough
# rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1)
rr2 = xp.sum(xp.square(coord0[:, :, None, :] - coord1), axis=-1)
rr2 = xp.where(m_real_nei, rr2, float("inf"))
rr2, ret_mapping = xp.sort(rr2, axis=-1), xp.argsort(rr2, axis=-1)
ret = xp.take_along_axis(ret, ret_mapping, axis=2)
ret = xp.where(rr > rcut, -1, ret)
ret = xp.where(rr2 > rcut * rcut, -1, ret)
ret = ret[..., :nnei]
# not extra_nlist_sort and n_nnei <= nnei:
elif n_nnei == nnei:
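The switch from xp.linalg.norm to a squared distance is not only cosmetic: the derivative of sqrt(x) diverges at x = 0, so under JAX autodiff the norm of a zero displacement (a padded or self-paired neighbor) yields NaN gradients, while the sum of squares stays finite; and since sqrt is monotonic, sorting by rr2 and testing rr2 > rcut * rcut select exactly the same neighbors. A small demonstration (function names are illustrative):

import jax
import jax.numpy as jnp


def sum_norm(coords):
    diff = coords[:, None, :] - coords[None, :, :]
    return jnp.sum(jnp.linalg.norm(diff, axis=-1))


def sum_sq(coords):
    diff = coords[:, None, :] - coords[None, :, :]
    return jnp.sum(jnp.sum(jnp.square(diff), axis=-1))


coords = jnp.zeros((2, 3))  # two coincident points -> zero displacements
print(jax.grad(sum_norm)(coords))  # contains NaN: d/dx sqrt(x) blows up at 0
print(jax.grad(sum_sq)(coords))    # all zeros, finite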
28 changes: 17 additions & 11 deletions deepmd/dpmodel/utils/nlist.py
@@ -215,30 +215,36 @@ def build_multiple_neighbor_list(
value being the corresponding nlist.
"""
xp = array_api_compat.array_namespace(coord, nlist)
assert len(rcuts) == len(nsels)
if len(rcuts) == 0:
return {}
nb, nloc, nsel = nlist.shape
if nsel < nsels[-1]:
pad = -1 * np.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype)
nlist = np.concatenate([nlist, pad], axis=-1)
pad = -1 * xp.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype)
nlist = xp.concat([nlist, pad], axis=-1)
nsel = nsels[-1]
coord1 = coord.reshape(nb, -1, 3)
coord1 = xp.reshape(coord, (nb, -1, 3))
nall = coord1.shape[1]
coord0 = coord1[:, :nloc, :]
nlist_mask = nlist == -1
tnlist_0 = nlist.copy()
tnlist_0[nlist_mask] = 0
index = np.tile(tnlist_0.reshape(nb, nloc * nsel, 1), [1, 1, 3])
coord2 = np.take_along_axis(coord1, index, axis=1).reshape(nb, nloc, nsel, 3)
tnlist_0 = xp.where(nlist_mask, xp.zeros_like(nlist), nlist)
index = xp.tile(xp.reshape(tnlist_0, (nb, nloc * nsel, 1)), (1, 1, 3))
coord2 = xp_take_along_axis(coord1, index, axis=1)
coord2 = xp.reshape(coord2, (nb, nloc, nsel, 3))
diff = coord2 - coord0[:, :, None, :]
rr = np.linalg.norm(diff, axis=-1)
rr = np.where(nlist_mask, float("inf"), rr)
# jax raises a NaN error when using norm
# but note: we don't actually need the sqrt here; the squared value is enough
# rr = xp.linalg.vector_norm(diff, axis=-1)
rr2 = xp.sum(xp.square(diff), axis=-1)
rr2 = xp.where(nlist_mask, xp.full_like(rr2, float("inf")), rr2)
nlist0 = nlist
ret = {}
for rc, ns in zip(rcuts[::-1], nsels[::-1]):
tnlist_1 = np.copy(nlist0[:, :, :ns])
tnlist_1[rr[:, :, :ns] > rc] = -1
tnlist_1 = nlist0[:, :, :ns]
tnlist_1 = xp.where(
rr2[:, :, :ns] > rc * rc, xp.full_like(tnlist_1, -1), tnlist_1
)
ret[get_multiple_nlist_key(rc, ns)] = tnlist_1
return ret
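The same functional-update rule applies here: the in-place pattern tnlist_1 = np.copy(...); tnlist_1[rr > rc] = -1 becomes a single xp.where on the squared distances. A minimal sketch of filtering a superset neighbor list down to a smaller cutoff (filter_nlist and the toy arrays are illustrative, not from this commit):

import array_api_compat
import numpy as np


def filter_nlist(nlist, rr2, rc):
    """Mask out neighbors whose squared distance exceeds rc**2."""
    xp = array_api_compat.array_namespace(nlist, rr2)
    return xp.where(rr2 > rc * rc, xp.full_like(nlist, -1), nlist)


nlist = np.array([[[3, 5, 7, -1]]])          # (nframes, nloc, nsel)
rr2 = np.array([[[1.0, 4.0, 9.0, np.inf]]])  # inf marks padded (-1) entries
print(filter_nlist(nlist, rr2, rc=2.5))      # [[[3, 5, -1, -1]]]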

