From a2931b4d43ac8127fe35b4dae32796286789986b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 1 Nov 2024 17:23:04 -0400 Subject: [PATCH 1/5] feat(jax): zbl Signed-off-by: Jinzhe Zeng --- .../atomic_model/linear_atomic_model.py | 47 ++++++------ .../atomic_model/pairtab_atomic_model.py | 74 +++++++++++-------- deepmd/dpmodel/model/make_model.py | 11 ++- deepmd/dpmodel/utils/nlist.py | 28 ++++--- .../jax/atomic_model/linear_atomic_model.py | 38 ++++++++++ .../jax/atomic_model/pairtab_atomic_model.py | 27 +++++++ deepmd/jax/model/__init__.py | 8 +- deepmd/jax/model/dp_zbl_model.py | 50 +++++++++++++ deepmd/jax/model/model.py | 53 +++++++++++++ .../tests/consistent/model/test_zbl_ener.py | 11 ++- 10 files changed, 279 insertions(+), 68 deletions(-) create mode 100644 deepmd/jax/atomic_model/linear_atomic_model.py create mode 100644 deepmd/jax/atomic_model/pairtab_atomic_model.py create mode 100644 deepmd/jax/model/dp_zbl_model.py diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 224fdd145c..08bb249812 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -5,6 +5,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel.utils.nlist import ( @@ -69,7 +70,7 @@ def __init__( self.models = models sub_model_type_maps = [md.get_type_map() for md in models] err_msg = [] - self.mapping_list = [] + mapping_list = [] common_type_map = set(type_map) self.type_map = type_map for tpmp in sub_model_type_maps: @@ -77,7 +78,8 @@ def __init__( err_msg.append( f"type_map {tpmp} is not a subset of type_map {type_map}" ) - self.mapping_list.append(self.remap_atype(tpmp, self.type_map)) + mapping_list.append(self.remap_atype(tpmp, self.type_map)) + self.mapping_list = mapping_list assert len(err_msg) == 0, "\n".join(err_msg) self.mixed_types_list = [model.mixed_types() for model in self.models] @@ -180,8 +182,9 @@ def forward_atomic( result_dict the result dict, defined by the fitting net output def. """ + xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist) nframes, nloc, nnei = nlist.shape - extended_coord = extended_coord.reshape(nframes, -1, 3) + extended_coord = xp.reshape(extended_coord, (nframes, -1, 3)) sorted_rcuts, sorted_sels = self._sort_rcuts_sels() nlists = build_multiple_neighbor_list( extended_coord, @@ -212,10 +215,10 @@ def forward_atomic( aparam, )["energy"] ) - self.weights = self._compute_weight(extended_coord, extended_atype, nlists_) + weights = self._compute_weight(extended_coord, extended_atype, nlists_) fit_ret = { - "energy": np.sum(np.stack(ener_list) * np.stack(self.weights), axis=0), + "energy": xp.sum(xp.stack(ener_list) * xp.stack(weights), axis=0), } # (nframes, nloc, 1) return fit_ret @@ -288,11 +291,12 @@ def _compute_weight( nlists_: list[np.ndarray], ) -> list[np.ndarray]: """This should be a list of user defined weights that matches the number of models to be combined.""" + xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlists_) nmodels = len(self.models) nframes, nloc, _ = nlists_[0].shape # the dtype of weights is the interface data type. return [ - np.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels + xp.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels for _ in range(nmodels) ] @@ -410,6 +414,7 @@ def _compute_weight( self.sw_rmax > self.sw_rmin ), "The upper boundary `sw_rmax` must be greater than the lower boundary `sw_rmin`." + xp = array_api_compat.array_namespace(extended_coord, extended_atype) dp_nlist = nlists_[0] zbl_nlist = nlists_[1] @@ -418,40 +423,40 @@ def _compute_weight( # use the larger rr based on nlist nlist_larger = zbl_nlist if zbl_nnei >= dp_nnei else dp_nlist - masked_nlist = np.clip(nlist_larger, 0, None) + masked_nlist = xp.clip(nlist_larger, 0, None) pairwise_rr = PairTabAtomicModel._get_pairwise_dist( extended_coord, masked_nlist ) - numerator = np.sum( - np.where( + numerator = xp.sum( + xp.where( nlist_larger != -1, - pairwise_rr * np.exp(-pairwise_rr / self.smin_alpha), - np.zeros_like(nlist_larger), + pairwise_rr * xp.exp(-pairwise_rr / self.smin_alpha), + xp.zeros_like(nlist_larger), ), axis=-1, ) # masked nnei will be zero, no need to handle - denominator = np.sum( - np.where( + denominator = xp.sum( + xp.where( nlist_larger != -1, - np.exp(-pairwise_rr / self.smin_alpha), - np.zeros_like(nlist_larger), + xp.exp(-pairwise_rr / self.smin_alpha), + xp.zeros_like(nlist_larger), ), axis=-1, ) # handle masked nnei. with np.errstate(divide="ignore", invalid="ignore"): sigma = numerator / denominator u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin) - coef = np.zeros_like(u) + coef = xp.zeros_like(u) left_mask = sigma < self.sw_rmin mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax) right_mask = sigma >= self.sw_rmax - coef[left_mask] = 1 + coef = xp.where(left_mask, xp.ones_like(coef), coef) with np.errstate(invalid="ignore"): smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1 - coef[mid_mask] = smooth[mid_mask] - coef[right_mask] = 0 + coef = xp.where(mid_mask, smooth, coef) + coef = xp.where(right_mask, xp.zeros_like(coef), coef) # to handle masked atoms - coef = np.where(sigma != 0, coef, np.zeros_like(coef)) + coef = xp.where(sigma != 0, coef, xp.zeros_like(coef)) self.zbl_weight = coef - return [1 - np.expand_dims(coef, -1), np.expand_dims(coef, -1)] + return [1 - xp.expand_dims(coef, -1), xp.expand_dims(coef, -1)] diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index 2899f106bc..7def899c38 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -5,8 +5,12 @@ Union, ) +import array_api_compat import numpy as np +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) from deepmd.dpmodel.output_def import ( FittingOutputDef, OutputVariableDef, @@ -74,9 +78,10 @@ def __init__( self.atom_ener = atom_ener if self.tab_file is not None: - self.tab_info, self.tab_data = self.tab.get() - nspline, ntypes_tab = self.tab_info[-2:].astype(int) - self.tab_data = self.tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4) + tab_info, tab_data = self.tab.get() + nspline, ntypes_tab = tab_info[-2:].astype(int) + self.tab_info = tab_info + self.tab_data = tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4) if self.ntypes != ntypes_tab: raise ValueError( "The `type_map` provided does not match the number of columns in the table." @@ -189,8 +194,9 @@ def forward_atomic( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, ) -> dict[str, np.ndarray]: + xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist) nframes, nloc, nnei = nlist.shape - extended_coord = extended_coord.reshape(nframes, -1, 3) + extended_coord = xp.reshape(extended_coord, (nframes, -1, 3)) # this will mask all -1 in the nlist mask = nlist >= 0 @@ -200,23 +206,21 @@ def forward_atomic( pairwise_rr = self._get_pairwise_dist( extended_coord, masked_nlist ) # (nframes, nloc, nnei) - self.tab_data = self.tab_data.reshape( - self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4 - ) # (nframes, nloc, nnei), index type is int64. j_type = extended_atype[ - np.arange(extended_atype.shape[0], dtype=np.int64)[:, None, None], + xp.arange(extended_atype.shape[0], dtype=xp.int64)[:, None, None], masked_nlist, ] raw_atomic_energy = self._pair_tabulated_inter( nlist, atype, j_type, pairwise_rr ) - atomic_energy = 0.5 * np.sum( - np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)), + atomic_energy = 0.5 * xp.sum( + xp.where(nlist != -1, raw_atomic_energy, xp.zeros_like(raw_atomic_energy)), axis=-1, - ).reshape(nframes, nloc, 1) + ) + atomic_energy = xp.reshape(atomic_energy, (nframes, nloc, 1)) return {"energy": atomic_energy} @@ -255,36 +259,42 @@ def _pair_tabulated_inter( This function is used to calculate the pairwise energy between two atoms. It uses a table containing cubic spline coefficients calculated in PairTab. """ + xp = array_api_compat.array_namespace(nlist, i_type, j_type, rr) nframes, nloc, nnei = nlist.shape rmin = self.tab_info[0] hh = self.tab_info[1] hi = 1.0 / hh - nspline = int(self.tab_info[2] + 0.1) + # jax jit does not support convert to a Python int, so we need to convert to xp.int64. + nspline = (self.tab_info[2] + 0.1).astype(xp.int64) uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei) # if nnei of atom 0 has -1 in the nlist, uu would be 0. # this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms. - uu = np.where(nlist != -1, uu, nspline + 1) + uu = xp.where(nlist != -1, uu, nspline + 1) - if np.any(uu < 0): - raise Exception("coord go beyond table lower boundary") + # unsupported by jax + # if xp.any(uu < 0): + # raise Exception("coord go beyond table lower boundary") - idx = uu.astype(int) + idx = xp.astype(uu, xp.int64) uu -= idx table_coef = self._extract_spline_coefficient( i_type, j_type, idx, self.tab_data, nspline ) - table_coef = table_coef.reshape(nframes, nloc, nnei, 4) + table_coef = xp.reshape(table_coef, (nframes, nloc, nnei, 4)) ener = self._calculate_ener(table_coef, uu) # here we need to overwrite energy to zero at rcut and beyond. mask_beyond_rcut = rr >= self.rcut # also overwrite values beyond extrapolation to zero extrapolation_mask = rr >= self.tab.rmin + nspline * self.tab.hh - ener[mask_beyond_rcut] = 0 - ener[extrapolation_mask] = 0 + ener = xp.where( + xp.logical_or(mask_beyond_rcut, extrapolation_mask), + xp.zeros_like(ener), + ener, + ) return ener @@ -304,12 +314,13 @@ def _get_pairwise_dist(coords: np.ndarray, nlist: np.ndarray) -> np.ndarray: np.ndarray The pairwise distance between the atoms (nframes, nloc, nnei). """ + xp = array_api_compat.array_namespace(coords, nlist) # index type is int64 - batch_indices = np.arange(nlist.shape[0], dtype=np.int64)[:, None, None] + batch_indices = xp.arange(nlist.shape[0], dtype=xp.int64)[:, None, None] neighbor_atoms = coords[batch_indices, nlist] loc_atoms = coords[:, : nlist.shape[1], :] pairwise_dr = loc_atoms[:, :, None, :] - neighbor_atoms - pairwise_rr = np.sqrt(np.sum(np.power(pairwise_dr, 2), axis=-1)) + pairwise_rr = xp.sqrt(xp.sum(xp.power(pairwise_dr, 2), axis=-1)) return pairwise_rr @@ -319,7 +330,7 @@ def _extract_spline_coefficient( j_type: np.ndarray, idx: np.ndarray, tab_data: np.ndarray, - nspline: int, + nspline: np.int64, ) -> np.ndarray: """Extract the spline coefficient from the table. @@ -341,9 +352,10 @@ def _extract_spline_coefficient( np.ndarray The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed. """ + xp = array_api_compat.array_namespace(i_type, j_type, idx, tab_data) # (nframes, nloc, nnei) - expanded_i_type = np.broadcast_to( - i_type[:, :, np.newaxis], + expanded_i_type = xp.broadcast_to( + i_type[:, :, xp.newaxis], (i_type.shape[0], i_type.shape[1], j_type.shape[-1]), ) @@ -351,18 +363,20 @@ def _extract_spline_coefficient( expanded_tab_data = tab_data[expanded_i_type, j_type] # (nframes, nloc, nnei, 1, 4) - expanded_idx = np.broadcast_to( - idx[..., np.newaxis, np.newaxis], (*idx.shape, 1, 4) + expanded_idx = xp.broadcast_to( + idx[..., xp.newaxis, xp.newaxis], (*idx.shape, 1, 4) ) - clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int) + clipped_indices = xp.clip(expanded_idx, 0, nspline - 1).astype(int) # (nframes, nloc, nnei, 4) - final_coef = np.squeeze( - np.take_along_axis(expanded_tab_data, clipped_indices, 3) + final_coef = xp.squeeze( + xp_take_along_axis(expanded_tab_data, clipped_indices, 3) ) # when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`. - final_coef[expanded_idx.squeeze() > nspline] = 0 + final_coef = xp.where( + expanded_idx.squeeze() > nspline, xp.zeros_like(final_coef), final_coef + ) return final_coef @staticmethod diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index b6379573e1..9803b224af 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -475,11 +475,14 @@ def _format_nlist( index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2) coord1 = xp.take_along_axis(extended_coord, index, axis=1) coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3) - rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) - rr = xp.where(m_real_nei, rr, float("inf")) - rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1) + # jax raises NaN error using norm + # but note: we don't actually need to sqrt here; the squared value is enough + # rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) + rr2 = xp.sum(xp.square(coord0[:, :, None, :] - coord1), axis=-1) + rr2 = xp.where(m_real_nei, rr2, float("inf")) + rr2, ret_mapping = xp.sort(rr2, axis=-1), xp.argsort(rr2, axis=-1) ret = xp.take_along_axis(ret, ret_mapping, axis=2) - ret = xp.where(rr > rcut, -1, ret) + ret = xp.where(rr2 > rcut * rcut, -1, ret) ret = ret[..., :nnei] # not extra_nlist_sort and n_nnei <= nnei: elif n_nnei == nnei: diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index b827032588..239260ff1f 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -215,30 +215,36 @@ def build_multiple_neighbor_list( value being the corresponding nlist. """ + xp = array_api_compat.array_namespace(coord, nlist) assert len(rcuts) == len(nsels) if len(rcuts) == 0: return {} nb, nloc, nsel = nlist.shape if nsel < nsels[-1]: - pad = -1 * np.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype) - nlist = np.concatenate([nlist, pad], axis=-1) + pad = -1 * xp.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype) + nlist = xp.concat([nlist, pad], axis=-1) nsel = nsels[-1] - coord1 = coord.reshape(nb, -1, 3) + coord1 = xp.reshape(coord, (nb, -1, 3)) nall = coord1.shape[1] coord0 = coord1[:, :nloc, :] nlist_mask = nlist == -1 - tnlist_0 = nlist.copy() - tnlist_0[nlist_mask] = 0 - index = np.tile(tnlist_0.reshape(nb, nloc * nsel, 1), [1, 1, 3]) - coord2 = np.take_along_axis(coord1, index, axis=1).reshape(nb, nloc, nsel, 3) + tnlist_0 = xp.where(nlist_mask, xp.zeros_like(nlist), nlist) + index = xp.tile(xp.reshape(tnlist_0, (nb, nloc * nsel, 1)), (1, 1, 3)) + coord2 = xp_take_along_axis(coord1, index, axis=1) + coord2 = xp.reshape(coord2, (nb, nloc, nsel, 3)) diff = coord2 - coord0[:, :, None, :] - rr = np.linalg.norm(diff, axis=-1) - rr = np.where(nlist_mask, float("inf"), rr) + # jax raises NaN error using norm + # but note: we don't actually need to sqrt here; the squared value is enough + # rr = xp.linalg.vector_norm(diff, axis=-1) + rr2 = xp.sum(xp.square(diff), axis=-1) + rr2 = xp.where(nlist_mask, xp.full_like(rr2, float("inf")), rr2) nlist0 = nlist ret = {} for rc, ns in zip(rcuts[::-1], nsels[::-1]): - tnlist_1 = np.copy(nlist0[:, :, :ns]) - tnlist_1[rr[:, :, :ns] > rc] = -1 + tnlist_1 = nlist0[:, :, :ns] + tnlist_1 = xp.where( + rr2[:, :, :ns] > rc * rc, xp.full_like(tnlist_1, -1), tnlist_1 + ) ret[get_multiple_nlist_key(rc, ns)] = tnlist_1 return ret diff --git a/deepmd/jax/atomic_model/linear_atomic_model.py b/deepmd/jax/atomic_model/linear_atomic_model.py new file mode 100644 index 0000000000..c86e3ef02c --- /dev/null +++ b/deepmd/jax/atomic_model/linear_atomic_model.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.atomic_model.linear_atomic_model import ( + DPZBLLinearEnergyAtomicModel as DPZBLLinearEnergyAtomicModelDP, +) +from deepmd.jax.atomic_model.base_atomic_model import ( + base_atomic_model_set_attr, +) +from deepmd.jax.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.jax.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel, +) +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) + + +@flax_module +class DPZBLLinearEnergyAtomicModel(DPZBLLinearEnergyAtomicModelDP): + def __setattr__(self, name: str, value: Any) -> None: + value = base_atomic_model_set_attr(name, value) + if name == "mapping_list": + value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value] + elif name == "zbl_weight": + value = ArrayAPIVariable(to_jax_array(value)) + elif name == "models": + value = [ + DPAtomicModel.deserialize(value[0].serialize()), + PairTabAtomicModel.deserialize(value[1].serialize()), + ] + return super().__setattr__(name, value) diff --git a/deepmd/jax/atomic_model/pairtab_atomic_model.py b/deepmd/jax/atomic_model/pairtab_atomic_model.py new file mode 100644 index 0000000000..2401362b63 --- /dev/null +++ b/deepmd/jax/atomic_model/pairtab_atomic_model.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel as PairTabAtomicModelDP, +) +from deepmd.jax.atomic_model.base_atomic_model import ( + base_atomic_model_set_attr, +) +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) + + +@flax_module +class PairTabAtomicModel(PairTabAtomicModelDP): + def __setattr__(self, name: str, value: Any) -> None: + value = base_atomic_model_set_attr(name, value) + if name in {"tab_info", "tab_data"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + return super().__setattr__(name, value) diff --git a/deepmd/jax/model/__init__.py b/deepmd/jax/model/__init__.py index 05a60c4ffe..bba5bc766a 100644 --- a/deepmd/jax/model/__init__.py +++ b/deepmd/jax/model/__init__.py @@ -1,6 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .dp_zbl_model import ( + DPZBLLinearEnergyAtomicModel, +) from .ener_model import ( EnergyModel, ) -__all__ = ["EnergyModel"] +__all__ = [ + "EnergyModel", + "DPZBLLinearEnergyAtomicModel", +] diff --git a/deepmd/jax/model/dp_zbl_model.py b/deepmd/jax/model/dp_zbl_model.py new file mode 100644 index 0000000000..028fa8593b --- /dev/null +++ b/deepmd/jax/model/dp_zbl_model.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP +from deepmd.jax.atomic_model.linear_atomic_model import ( + DPZBLLinearEnergyAtomicModel, +) +from deepmd.jax.common import ( + flax_module, +) +from deepmd.jax.env import ( + jnp, +) +from deepmd.jax.model.base_model import ( + BaseModel, + forward_common_atomic, +) + + +@BaseModel.register("zbl") +@flax_module +class DPZBLModel(DPZBLModelDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "atomic_model": + value = DPZBLLinearEnergyAtomicModel.deserialize(value.serialize()) + return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ): + return forward_common_atomic( + self, + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) diff --git a/deepmd/jax/model/model.py b/deepmd/jax/model/model.py index 7fa3efda6e..e636eba4c6 100644 --- a/deepmd/jax/model/model.py +++ b/deepmd/jax/model/model.py @@ -3,15 +3,27 @@ deepcopy, ) +from deepmd.jax.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.jax.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel, +) from deepmd.jax.descriptor.base_descriptor import ( BaseDescriptor, ) from deepmd.jax.fitting.base_fitting import ( BaseFitting, ) +from deepmd.jax.fitting.fitting import ( + EnergyFittingNet, +) from deepmd.jax.model.base_model import ( BaseModel, ) +from deepmd.jax.model.dp_zbl_model import ( + DPZBLModel, +) def get_standard_model(data: dict): @@ -45,6 +57,45 @@ def get_standard_model(data: dict): ) +def get_zbl_model(data: dict) -> DPZBLModel: + data["descriptor"]["ntypes"] = len(data["type_map"]) + descriptor_type = data["descriptor"].pop("type") + descriptor = BaseDescriptor.get_class_by_type(descriptor_type)(**data["descriptor"]) + fitting_type = data["fitting_net"].pop("type") + if fitting_type == "ener": + fitting = EnergyFittingNet( + ntypes=descriptor.get_ntypes(), + dim_descrpt=descriptor.get_dim_out(), + mixed_types=descriptor.mixed_types(), + **data["fitting_net"], + ) + else: + raise ValueError(f"Unknown fitting type {fitting_type}") + + dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"]) + # pairtab + filepath = data["use_srtab"] + pt_model = PairTabAtomicModel( + filepath, + data["descriptor"]["rcut"], + data["descriptor"]["sel"], + type_map=data["type_map"], + ) + rmin = data["sw_rmin"] + rmax = data["sw_rmax"] + atom_exclude_types = data.get("atom_exclude_types", []) + pair_exclude_types = data.get("pair_exclude_types", []) + return DPZBLModel( + dp_model, + pt_model, + rmin, + rmax, + type_map=data["type_map"], + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + ) + + def get_model(data: dict): """Get a model from a dictionary. @@ -57,6 +108,8 @@ def get_model(data: dict): if model_type == "standard": if "spin" in data: raise NotImplementedError("Spin model is not implemented yet.") + elif "use_srtab" in data: + return get_zbl_model(data) else: return get_standard_model(data) else: diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index f37bee0c90..a63543ab74 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -13,6 +13,7 @@ ) from ..common import ( + INSTALLED_JAX, INSTALLED_PT, SKIP_FLAG, CommonTest, @@ -27,6 +28,11 @@ from deepmd.pt.model.model.dp_zbl_model import DPZBLModel as DPZBLModelPT else: DPZBLModelPT = None +if INSTALLED_JAX: + from deepmd.jax.model.dp_zbl_model import DPZBLModel as DPZBLModelJAX + from deepmd.jax.model.model import get_model as get_model_jax +else: + DPZBLModelJAX = None import os from deepmd.utils.argcheck import ( @@ -86,6 +92,7 @@ def data(self) -> dict: dp_class = DPZBLModelDP pt_class = DPZBLModelPT + jax_class = DPZBLModelJAX args = model_args() def get_reference_backend(self): @@ -109,7 +116,7 @@ def skip_tf(self): @property def skip_jax(self): - return True + return not INSTALLED_JAX def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" @@ -118,6 +125,8 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is DPZBLModelPT: return get_model_pt(data) + elif cls is DPZBLModelJAX: + return get_model_jax(data) return cls(**data, **self.additional_data) def setUp(self): From 1b3ea6b4ac4ce0d94ea6bb936bd34814ec318cda Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 1 Nov 2024 17:47:37 -0400 Subject: [PATCH 2/5] fix grad issue Signed-off-by: Jinzhe Zeng --- .../atomic_model/pairtab_atomic_model.py | 5 ++- deepmd/dpmodel/descriptor/dpa1.py | 5 ++- deepmd/dpmodel/utils/safe_gradient.py | 32 +++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) create mode 100644 deepmd/dpmodel/utils/safe_gradient.py diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index 7def899c38..c927089daf 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -15,6 +15,9 @@ FittingOutputDef, OutputVariableDef, ) +from deepmd.dpmodel.utils.safe_gradient import ( + safe_for_sqrt, +) from deepmd.utils.pair_tab import ( PairTab, ) @@ -320,7 +323,7 @@ def _get_pairwise_dist(coords: np.ndarray, nlist: np.ndarray) -> np.ndarray: neighbor_atoms = coords[batch_indices, nlist] loc_atoms = coords[:, : nlist.shape[1], :] pairwise_dr = loc_atoms[:, :, None, :] - neighbor_atoms - pairwise_rr = xp.sqrt(xp.sum(xp.power(pairwise_dr, 2), axis=-1)) + pairwise_rr = safe_for_sqrt(xp.sum(xp.power(pairwise_dr, 2), axis=-1)) return pairwise_rr diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 2f2b12e03c..b033811507 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -27,6 +27,9 @@ LayerNorm, NativeLayer, ) +from deepmd.dpmodel.utils.safe_gradient import ( + safe_for_vector_norm, +) from deepmd.dpmodel.utils.seed import ( child_seed, ) @@ -943,7 +946,7 @@ def call( else: raise NotImplementedError - normed = xp.linalg.vector_norm( + normed = safe_for_vector_norm( xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4], axis=-1, keepdims=True ) input_r = xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4] / xp.maximum( diff --git a/deepmd/dpmodel/utils/safe_gradient.py b/deepmd/dpmodel/utils/safe_gradient.py new file mode 100644 index 0000000000..a7783b318b --- /dev/null +++ b/deepmd/dpmodel/utils/safe_gradient.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Safe versions of some functions that have problematic gradients. + +Check https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where +for more information. +""" + +import array_api_compat + + +def safe_for_sqrt(x): + """Safe version of sqrt that has a gradient of 0 at x = 0.""" + xp = array_api_compat.array_namespace(x) + mask = x > 0.0 + return xp.where(mask, xp.sqrt(xp.where(mask, x, xp.ones_like(x))), xp.zeros_like(x)) + + +def safe_for_vector_norm(x, /, *, axis=None, keepdims=False, ord=2): + """Safe version of sqrt that has a gradient of 0 at x = 0.""" + xp = array_api_compat.array_namespace(x) + mask = xp.sum(xp.square(x), axis=axis, keepdims=True) > 0 + if keepdims: + mask_squeezed = mask + else: + mask_squeezed = xp.squeeze(mask, axis=axis) + return xp.where( + mask_squeezed, + xp.linalg.vector_norm( + xp.where(mask, x, xp.ones_like(x)), axis=axis, keepdims=keepdims, ord=ord + ), + xp.zeros_like(mask_squeezed), + ) From 41863c9b074924625fe3844bb28607e11abe15b1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 1 Nov 2024 19:23:50 -0400 Subject: [PATCH 3/5] set dtype Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/utils/safe_gradient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/utils/safe_gradient.py b/deepmd/dpmodel/utils/safe_gradient.py index a7783b318b..2baf530c08 100644 --- a/deepmd/dpmodel/utils/safe_gradient.py +++ b/deepmd/dpmodel/utils/safe_gradient.py @@ -28,5 +28,5 @@ def safe_for_vector_norm(x, /, *, axis=None, keepdims=False, ord=2): xp.linalg.vector_norm( xp.where(mask, x, xp.ones_like(x)), axis=axis, keepdims=keepdims, ord=ord ), - xp.zeros_like(mask_squeezed), + xp.zeros_like(mask_squeezed, dtype=x.dtype), ) From 84803b847e641cf5a688bf2b6b9c16f3632b25c1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 2 Nov 2024 01:25:48 -0400 Subject: [PATCH 4/5] stop_gradient for nlist Signed-off-by: Jinzhe Zeng --- deepmd/jax/atomic_model/dp_atomic_model.py | 23 +++++++++++++++++++ .../jax/atomic_model/linear_atomic_model.py | 23 +++++++++++++++++++ .../jax/atomic_model/pairtab_atomic_model.py | 23 +++++++++++++++++++ 3 files changed, 69 insertions(+) diff --git a/deepmd/jax/atomic_model/dp_atomic_model.py b/deepmd/jax/atomic_model/dp_atomic_model.py index 077209e29a..5898fd3ff8 100644 --- a/deepmd/jax/atomic_model/dp_atomic_model.py +++ b/deepmd/jax/atomic_model/dp_atomic_model.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, + Optional, ) from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP @@ -13,6 +14,10 @@ from deepmd.jax.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.jax.env import ( + jax, + jnp, +) from deepmd.jax.fitting.base_fitting import ( BaseFitting, ) @@ -28,3 +33,21 @@ class DPAtomicModel(DPAtomicModelDP): def __setattr__(self, name: str, value: Any) -> None: value = base_atomic_model_set_attr(name, value) return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + ) -> dict[str, jnp.ndarray]: + return super().forward_common_atomic( + extended_coord, + extended_atype, + jax.lax.stop_gradient(nlist), + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) diff --git a/deepmd/jax/atomic_model/linear_atomic_model.py b/deepmd/jax/atomic_model/linear_atomic_model.py index c86e3ef02c..6ce82fa07c 100644 --- a/deepmd/jax/atomic_model/linear_atomic_model.py +++ b/deepmd/jax/atomic_model/linear_atomic_model.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, + Optional, ) from deepmd.dpmodel.atomic_model.linear_atomic_model import ( @@ -20,6 +21,10 @@ flax_module, to_jax_array, ) +from deepmd.jax.env import ( + jax, + jnp, +) @flax_module @@ -36,3 +41,21 @@ def __setattr__(self, name: str, value: Any) -> None: PairTabAtomicModel.deserialize(value[1].serialize()), ] return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + ) -> dict[str, jnp.ndarray]: + return super().forward_common_atomic( + extended_coord, + extended_atype, + jax.lax.stop_gradient(nlist), + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) diff --git a/deepmd/jax/atomic_model/pairtab_atomic_model.py b/deepmd/jax/atomic_model/pairtab_atomic_model.py index 2401362b63..023f4e886a 100644 --- a/deepmd/jax/atomic_model/pairtab_atomic_model.py +++ b/deepmd/jax/atomic_model/pairtab_atomic_model.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, + Optional, ) from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( @@ -14,6 +15,10 @@ flax_module, to_jax_array, ) +from deepmd.jax.env import ( + jax, + jnp, +) @flax_module @@ -25,3 +30,21 @@ def __setattr__(self, name: str, value: Any) -> None: if value is not None: value = ArrayAPIVariable(value) return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + ) -> dict[str, jnp.ndarray]: + return super().forward_common_atomic( + extended_coord, + extended_atype, + jax.lax.stop_gradient(nlist), + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) From 0d7c740807d641f7be68d942fd8dd87bd2cf260a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 04:55:02 -0500 Subject: [PATCH 5/5] revert make_model.py Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/model/make_model.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 98a93c7500..95d97262df 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -503,14 +503,11 @@ def _format_nlist( index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2) coord1 = xp.take_along_axis(extended_coord, index, axis=1) coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3) - # jax raises NaN error using norm - # but note: we don't actually need to sqrt here; the squared value is enough - # rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) - rr2 = xp.sum(xp.square(coord0[:, :, None, :] - coord1), axis=-1) - rr2 = xp.where(m_real_nei, rr2, float("inf")) - rr2, ret_mapping = xp.sort(rr2, axis=-1), xp.argsort(rr2, axis=-1) + rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) + rr = xp.where(m_real_nei, rr, float("inf")) + rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1) ret = xp.take_along_axis(ret, ret_mapping, axis=2) - ret = xp.where(rr2 > rcut * rcut, -1, ret) + ret = xp.where(rr > rcut, -1, ret) ret = ret[..., :nnei] # not extra_nlist_sort and n_nnei <= nnei: elif n_nnei == nnei: