Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): zbl #4301

Merged
merged 6 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Union,
)

import array_api_compat
njzjz marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np

from deepmd.dpmodel.utils.nlist import (
Expand Down Expand Up @@ -69,15 +70,16 @@
self.models = models
sub_model_type_maps = [md.get_type_map() for md in models]
err_msg = []
self.mapping_list = []
mapping_list = []
common_type_map = set(type_map)
self.type_map = type_map
for tpmp in sub_model_type_maps:
if not common_type_map.issubset(set(tpmp)):
err_msg.append(
f"type_map {tpmp} is not a subset of type_map {type_map}"
)
self.mapping_list.append(self.remap_atype(tpmp, self.type_map))
mapping_list.append(self.remap_atype(tpmp, self.type_map))
self.mapping_list = mapping_list
assert len(err_msg) == 0, "\n".join(err_msg)
self.mixed_types_list = [model.mixed_types() for model in self.models]

Expand Down Expand Up @@ -212,8 +214,9 @@
result_dict
the result dict, defined by the fitting net output def.
"""
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.reshape(nframes, -1, 3)
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
nlists = build_multiple_neighbor_list(
extended_coord,
Expand Down Expand Up @@ -244,10 +247,10 @@
aparam,
)["energy"]
)
self.weights = self._compute_weight(extended_coord, extended_atype, nlists_)
weights = self._compute_weight(extended_coord, extended_atype, nlists_)

fit_ret = {
"energy": np.sum(np.stack(ener_list) * np.stack(self.weights), axis=0),
"energy": xp.sum(xp.stack(ener_list) * xp.stack(weights), axis=0),
} # (nframes, nloc, 1)
return fit_ret

Expand Down Expand Up @@ -320,11 +323,12 @@
nlists_: list[np.ndarray],
) -> list[np.ndarray]:
"""This should be a list of user defined weights that matches the number of models to be combined."""
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlists_)

Check warning on line 326 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L326

Added line #L326 was not covered by tests
nmodels = len(self.models)
nframes, nloc, _ = nlists_[0].shape
# the dtype of weights is the interface data type.
return [
np.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels
xp.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels
njzjz marked this conversation as resolved.
Show resolved Hide resolved
for _ in range(nmodels)
]

Expand Down Expand Up @@ -442,6 +446,7 @@
self.sw_rmax > self.sw_rmin
), "The upper boundary `sw_rmax` must be greater than the lower boundary `sw_rmin`."

xp = array_api_compat.array_namespace(extended_coord, extended_atype)
dp_nlist = nlists_[0]
zbl_nlist = nlists_[1]

Expand All @@ -450,40 +455,40 @@

# use the larger rr based on nlist
nlist_larger = zbl_nlist if zbl_nnei >= dp_nnei else dp_nlist
masked_nlist = np.clip(nlist_larger, 0, None)
masked_nlist = xp.clip(nlist_larger, 0, None)
pairwise_rr = PairTabAtomicModel._get_pairwise_dist(
extended_coord, masked_nlist
)

numerator = np.sum(
np.where(
numerator = xp.sum(
xp.where(
nlist_larger != -1,
pairwise_rr * np.exp(-pairwise_rr / self.smin_alpha),
np.zeros_like(nlist_larger),
pairwise_rr * xp.exp(-pairwise_rr / self.smin_alpha),
xp.zeros_like(nlist_larger),
),
axis=-1,
) # masked nnei will be zero, no need to handle
denominator = np.sum(
np.where(
denominator = xp.sum(
xp.where(
nlist_larger != -1,
np.exp(-pairwise_rr / self.smin_alpha),
np.zeros_like(nlist_larger),
xp.exp(-pairwise_rr / self.smin_alpha),
xp.zeros_like(nlist_larger),
),
axis=-1,
njzjz marked this conversation as resolved.
Show resolved Hide resolved
) # handle masked nnei.
with np.errstate(divide="ignore", invalid="ignore"):
sigma = numerator / denominator
u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin)
coef = np.zeros_like(u)
coef = xp.zeros_like(u)
left_mask = sigma < self.sw_rmin
mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax)
right_mask = sigma >= self.sw_rmax
coef[left_mask] = 1
coef = xp.where(left_mask, xp.ones_like(coef), coef)
with np.errstate(invalid="ignore"):
smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
coef[mid_mask] = smooth[mid_mask]
coef[right_mask] = 0
coef = xp.where(mid_mask, smooth, coef)
coef = xp.where(right_mask, xp.zeros_like(coef), coef)
# to handle masked atoms
coef = np.where(sigma != 0, coef, np.zeros_like(coef))
coef = xp.where(sigma != 0, coef, xp.zeros_like(coef))
self.zbl_weight = coef
return [1 - np.expand_dims(coef, -1), np.expand_dims(coef, -1)]
return [1 - xp.expand_dims(coef, -1), xp.expand_dims(coef, -1)]
77 changes: 47 additions & 30 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,19 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
)
from deepmd.dpmodel.utils.safe_gradient import (
safe_for_sqrt,
)
from deepmd.utils.pair_tab import (
PairTab,
)
Expand Down Expand Up @@ -74,9 +81,10 @@
self.atom_ener = atom_ener

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()
nspline, ntypes_tab = self.tab_info[-2:].astype(int)
self.tab_data = self.tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
tab_info, tab_data = self.tab.get()
nspline, ntypes_tab = tab_info[-2:].astype(int)
self.tab_info = tab_info
self.tab_data = tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
if self.ntypes != ntypes_tab:
raise ValueError(
"The `type_map` provided does not match the number of columns in the table."
Expand Down Expand Up @@ -189,8 +197,9 @@
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> dict[str, np.ndarray]:
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.reshape(nframes, -1, 3)
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))

# this will mask all -1 in the nlist
mask = nlist >= 0
Expand All @@ -200,23 +209,21 @@
pairwise_rr = self._get_pairwise_dist(
extended_coord, masked_nlist
) # (nframes, nloc, nnei)
self.tab_data = self.tab_data.reshape(
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
)

# (nframes, nloc, nnei), index type is int64.
j_type = extended_atype[
np.arange(extended_atype.shape[0], dtype=np.int64)[:, None, None],
xp.arange(extended_atype.shape[0], dtype=xp.int64)[:, None, None],
masked_nlist,
]

raw_atomic_energy = self._pair_tabulated_inter(
nlist, atype, j_type, pairwise_rr
)
atomic_energy = 0.5 * np.sum(
np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)),
atomic_energy = 0.5 * xp.sum(
xp.where(nlist != -1, raw_atomic_energy, xp.zeros_like(raw_atomic_energy)),
axis=-1,
).reshape(nframes, nloc, 1)
)
atomic_energy = xp.reshape(atomic_energy, (nframes, nloc, 1))

return {"energy": atomic_energy}

Expand Down Expand Up @@ -255,36 +262,42 @@
This function is used to calculate the pairwise energy between two atoms.
It uses a table containing cubic spline coefficients calculated in PairTab.
"""
xp = array_api_compat.array_namespace(nlist, i_type, j_type, rr)
nframes, nloc, nnei = nlist.shape
rmin = self.tab_info[0]
hh = self.tab_info[1]
hi = 1.0 / hh

nspline = int(self.tab_info[2] + 0.1)
# jax jit does not support convert to a Python int, so we need to convert to xp.int64.
nspline = (self.tab_info[2] + 0.1).astype(xp.int64)
njzjz marked this conversation as resolved.
Show resolved Hide resolved

uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei)

# if nnei of atom 0 has -1 in the nlist, uu would be 0.
# this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms.
uu = np.where(nlist != -1, uu, nspline + 1)
uu = xp.where(nlist != -1, uu, nspline + 1)

if np.any(uu < 0):
raise Exception("coord go beyond table lower boundary")
# unsupported by jax
# if xp.any(uu < 0):
# raise Exception("coord go beyond table lower boundary")
Comment on lines +281 to +282

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
njzjz marked this conversation as resolved.
Show resolved Hide resolved

idx = uu.astype(int)
idx = xp.astype(uu, xp.int64)

uu -= idx
table_coef = self._extract_spline_coefficient(
i_type, j_type, idx, self.tab_data, nspline
)
table_coef = table_coef.reshape(nframes, nloc, nnei, 4)
table_coef = xp.reshape(table_coef, (nframes, nloc, nnei, 4))
ener = self._calculate_ener(table_coef, uu)
# here we need to overwrite energy to zero at rcut and beyond.
mask_beyond_rcut = rr >= self.rcut
# also overwrite values beyond extrapolation to zero
extrapolation_mask = rr >= self.tab.rmin + nspline * self.tab.hh
ener[mask_beyond_rcut] = 0
ener[extrapolation_mask] = 0
ener = xp.where(
xp.logical_or(mask_beyond_rcut, extrapolation_mask),
xp.zeros_like(ener),
ener,
)

return ener

Expand All @@ -304,12 +317,13 @@
np.ndarray
The pairwise distance between the atoms (nframes, nloc, nnei).
"""
xp = array_api_compat.array_namespace(coords, nlist)
# index type is int64
batch_indices = np.arange(nlist.shape[0], dtype=np.int64)[:, None, None]
batch_indices = xp.arange(nlist.shape[0], dtype=xp.int64)[:, None, None]
neighbor_atoms = coords[batch_indices, nlist]
loc_atoms = coords[:, : nlist.shape[1], :]
pairwise_dr = loc_atoms[:, :, None, :] - neighbor_atoms
pairwise_rr = np.sqrt(np.sum(np.power(pairwise_dr, 2), axis=-1))
pairwise_rr = safe_for_sqrt(xp.sum(xp.power(pairwise_dr, 2), axis=-1))

return pairwise_rr

Expand All @@ -319,7 +333,7 @@
j_type: np.ndarray,
idx: np.ndarray,
tab_data: np.ndarray,
nspline: int,
nspline: np.int64,
) -> np.ndarray:
"""Extract the spline coefficient from the table.

Expand All @@ -341,28 +355,31 @@
np.ndarray
The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed.
"""
xp = array_api_compat.array_namespace(i_type, j_type, idx, tab_data)
# (nframes, nloc, nnei)
expanded_i_type = np.broadcast_to(
i_type[:, :, np.newaxis],
expanded_i_type = xp.broadcast_to(
i_type[:, :, xp.newaxis],
(i_type.shape[0], i_type.shape[1], j_type.shape[-1]),
)

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = tab_data[expanded_i_type, j_type]

# (nframes, nloc, nnei, 1, 4)
expanded_idx = np.broadcast_to(
idx[..., np.newaxis, np.newaxis], (*idx.shape, 1, 4)
expanded_idx = xp.broadcast_to(
idx[..., xp.newaxis, xp.newaxis], (*idx.shape, 1, 4)
)
clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int)
clipped_indices = xp.clip(expanded_idx, 0, nspline - 1).astype(int)

# (nframes, nloc, nnei, 4)
final_coef = np.squeeze(
np.take_along_axis(expanded_tab_data, clipped_indices, 3)
final_coef = xp.squeeze(
xp_take_along_axis(expanded_tab_data, clipped_indices, 3)
)

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() > nspline] = 0
final_coef = xp.where(
expanded_idx.squeeze() > nspline, xp.zeros_like(final_coef), final_coef
)
return final_coef

@staticmethod
Expand Down
5 changes: 4 additions & 1 deletion deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
LayerNorm,
NativeLayer,
)
from deepmd.dpmodel.utils.safe_gradient import (
safe_for_vector_norm,
)
from deepmd.dpmodel.utils.seed import (
child_seed,
)
Expand Down Expand Up @@ -943,7 +946,7 @@ def call(
else:
raise NotImplementedError

normed = xp.linalg.vector_norm(
normed = safe_for_vector_norm(
njzjz marked this conversation as resolved.
Show resolved Hide resolved
xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4], axis=-1, keepdims=True
)
input_r = xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4] / xp.maximum(
Expand Down
32 changes: 32 additions & 0 deletions deepmd/dpmodel/utils/safe_gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Safe versions of some functions that have problematic gradients.
Check https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
for more information.
"""

import array_api_compat


def safe_for_sqrt(x):
    """Compute an element-wise square root whose gradient is 0 at x = 0.

    Entries of ``x`` that are not strictly positive are swapped for 1 before
    the square root is evaluated and are set to 0 in the returned array, so
    the square root itself is only ever evaluated on positive values.
    """
    namespace = array_api_compat.array_namespace(x)
    is_positive = x > 0.0
    # Feed sqrt a harmless placeholder wherever the input is not positive.
    sqrt_input = namespace.where(is_positive, x, namespace.ones_like(x))
    roots = namespace.sqrt(sqrt_input)
    return namespace.where(is_positive, roots, namespace.zeros_like(x))
njzjz marked this conversation as resolved.
Show resolved Hide resolved


def safe_for_vector_norm(x, /, *, axis=None, keepdims=False, ord=2):
"""Safe version of sqrt that has a gradient of 0 at x = 0."""
njzjz marked this conversation as resolved.
Show resolved Hide resolved
xp = array_api_compat.array_namespace(x)
mask = xp.sum(xp.square(x), axis=axis, keepdims=True) > 0
if keepdims:
mask_squeezed = mask
else:
mask_squeezed = xp.squeeze(mask, axis=axis)

Check warning on line 25 in deepmd/dpmodel/utils/safe_gradient.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/safe_gradient.py#L25

Added line #L25 was not covered by tests
return xp.where(
mask_squeezed,
xp.linalg.vector_norm(
xp.where(mask, x, xp.ones_like(x)), axis=axis, keepdims=keepdims, ord=ord
),
xp.zeros_like(mask_squeezed, dtype=x.dtype),
)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
Loading