diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 485f82cb72..2c7e029d53 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -5,6 +5,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel.utils.nlist import ( @@ -69,7 +70,7 @@ def __init__( self.models = models sub_model_type_maps = [md.get_type_map() for md in models] err_msg = [] - self.mapping_list = [] + mapping_list = [] common_type_map = set(type_map) self.type_map = type_map for tpmp in sub_model_type_maps: @@ -77,7 +78,8 @@ def __init__( err_msg.append( f"type_map {tpmp} is not a subset of type_map {type_map}" ) - self.mapping_list.append(self.remap_atype(tpmp, self.type_map)) + mapping_list.append(self.remap_atype(tpmp, self.type_map)) + self.mapping_list = mapping_list assert len(err_msg) == 0, "\n".join(err_msg) self.mixed_types_list = [model.mixed_types() for model in self.models] @@ -212,8 +214,9 @@ def forward_atomic( result_dict the result dict, defined by the fitting net output def. """ + xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist) nframes, nloc, nnei = nlist.shape - extended_coord = extended_coord.reshape(nframes, -1, 3) + extended_coord = xp.reshape(extended_coord, (nframes, -1, 3)) sorted_rcuts, sorted_sels = self._sort_rcuts_sels() nlists = build_multiple_neighbor_list( extended_coord, @@ -244,10 +247,10 @@ def forward_atomic( aparam, )["energy"] ) - self.weights = self._compute_weight(extended_coord, extended_atype, nlists_) + weights = self._compute_weight(extended_coord, extended_atype, nlists_) fit_ret = { - "energy": np.sum(np.stack(ener_list) * np.stack(self.weights), axis=0), + "energy": xp.sum(xp.stack(ener_list) * xp.stack(weights), axis=0), } # (nframes, nloc, 1) return fit_ret @@ -320,11 +323,12 @@ def _compute_weight( nlists_: list[np.ndarray], ) -> list[np.ndarray]: """This should be a list of user defined weights that matches the number of models to be combined.""" + xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlists_) nmodels = len(self.models) nframes, nloc, _ = nlists_[0].shape # the dtype of weights is the interface data type. return [ - np.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels + xp.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels for _ in range(nmodels) ] @@ -442,6 +446,7 @@ def _compute_weight( self.sw_rmax > self.sw_rmin ), "The upper boundary `sw_rmax` must be greater than the lower boundary `sw_rmin`." + xp = array_api_compat.array_namespace(extended_coord, extended_atype) dp_nlist = nlists_[0] zbl_nlist = nlists_[1] @@ -450,40 +455,40 @@ def _compute_weight( # use the larger rr based on nlist nlist_larger = zbl_nlist if zbl_nnei >= dp_nnei else dp_nlist - masked_nlist = np.clip(nlist_larger, 0, None) + masked_nlist = xp.clip(nlist_larger, 0, None) pairwise_rr = PairTabAtomicModel._get_pairwise_dist( extended_coord, masked_nlist ) - numerator = np.sum( - np.where( + numerator = xp.sum( + xp.where( nlist_larger != -1, - pairwise_rr * np.exp(-pairwise_rr / self.smin_alpha), - np.zeros_like(nlist_larger), + pairwise_rr * xp.exp(-pairwise_rr / self.smin_alpha), + xp.zeros_like(nlist_larger), ), axis=-1, ) # masked nnei will be zero, no need to handle - denominator = np.sum( - np.where( + denominator = xp.sum( + xp.where( nlist_larger != -1, - np.exp(-pairwise_rr / self.smin_alpha), - np.zeros_like(nlist_larger), + xp.exp(-pairwise_rr / self.smin_alpha), + xp.zeros_like(nlist_larger), ), axis=-1, ) # handle masked nnei. with np.errstate(divide="ignore", invalid="ignore"): sigma = numerator / denominator u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin) - coef = np.zeros_like(u) + coef = xp.zeros_like(u) left_mask = sigma < self.sw_rmin mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax) right_mask = sigma >= self.sw_rmax - coef[left_mask] = 1 + coef = xp.where(left_mask, xp.ones_like(coef), coef) with np.errstate(invalid="ignore"): smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1 - coef[mid_mask] = smooth[mid_mask] - coef[right_mask] = 0 + coef = xp.where(mid_mask, smooth, coef) + coef = xp.where(right_mask, xp.zeros_like(coef), coef) # to handle masked atoms - coef = np.where(sigma != 0, coef, np.zeros_like(coef)) + coef = xp.where(sigma != 0, coef, xp.zeros_like(coef)) self.zbl_weight = coef - return [1 - np.expand_dims(coef, -1), np.expand_dims(coef, -1)] + return [1 - xp.expand_dims(coef, -1), xp.expand_dims(coef, -1)] diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index 2899f106bc..c927089daf 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -5,12 +5,19 @@ Union, ) +import array_api_compat import numpy as np +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) from deepmd.dpmodel.output_def import ( FittingOutputDef, OutputVariableDef, ) +from deepmd.dpmodel.utils.safe_gradient import ( + safe_for_sqrt, +) from deepmd.utils.pair_tab import ( PairTab, ) @@ -74,9 +81,10 @@ def __init__( self.atom_ener = atom_ener if self.tab_file is not None: - self.tab_info, self.tab_data = self.tab.get() - nspline, ntypes_tab = self.tab_info[-2:].astype(int) - self.tab_data = self.tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4) + tab_info, tab_data = self.tab.get() + nspline, ntypes_tab = tab_info[-2:].astype(int) + self.tab_info = tab_info + self.tab_data = tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4) if self.ntypes != ntypes_tab: raise ValueError( "The `type_map` provided does not match the number of columns in the table." @@ -189,8 +197,9 @@ def forward_atomic( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, ) -> dict[str, np.ndarray]: + xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist) nframes, nloc, nnei = nlist.shape - extended_coord = extended_coord.reshape(nframes, -1, 3) + extended_coord = xp.reshape(extended_coord, (nframes, -1, 3)) # this will mask all -1 in the nlist mask = nlist >= 0 @@ -200,23 +209,21 @@ def forward_atomic( pairwise_rr = self._get_pairwise_dist( extended_coord, masked_nlist ) # (nframes, nloc, nnei) - self.tab_data = self.tab_data.reshape( - self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4 - ) # (nframes, nloc, nnei), index type is int64. j_type = extended_atype[ - np.arange(extended_atype.shape[0], dtype=np.int64)[:, None, None], + xp.arange(extended_atype.shape[0], dtype=xp.int64)[:, None, None], masked_nlist, ] raw_atomic_energy = self._pair_tabulated_inter( nlist, atype, j_type, pairwise_rr ) - atomic_energy = 0.5 * np.sum( - np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)), + atomic_energy = 0.5 * xp.sum( + xp.where(nlist != -1, raw_atomic_energy, xp.zeros_like(raw_atomic_energy)), axis=-1, - ).reshape(nframes, nloc, 1) + ) + atomic_energy = xp.reshape(atomic_energy, (nframes, nloc, 1)) return {"energy": atomic_energy} @@ -255,36 +262,42 @@ def _pair_tabulated_inter( This function is used to calculate the pairwise energy between two atoms. It uses a table containing cubic spline coefficients calculated in PairTab. """ + xp = array_api_compat.array_namespace(nlist, i_type, j_type, rr) nframes, nloc, nnei = nlist.shape rmin = self.tab_info[0] hh = self.tab_info[1] hi = 1.0 / hh - nspline = int(self.tab_info[2] + 0.1) + # jax jit does not support convert to a Python int, so we need to convert to xp.int64. + nspline = (self.tab_info[2] + 0.1).astype(xp.int64) uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei) # if nnei of atom 0 has -1 in the nlist, uu would be 0. # this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms. - uu = np.where(nlist != -1, uu, nspline + 1) + uu = xp.where(nlist != -1, uu, nspline + 1) - if np.any(uu < 0): - raise Exception("coord go beyond table lower boundary") + # unsupported by jax + # if xp.any(uu < 0): + # raise Exception("coord go beyond table lower boundary") - idx = uu.astype(int) + idx = xp.astype(uu, xp.int64) uu -= idx table_coef = self._extract_spline_coefficient( i_type, j_type, idx, self.tab_data, nspline ) - table_coef = table_coef.reshape(nframes, nloc, nnei, 4) + table_coef = xp.reshape(table_coef, (nframes, nloc, nnei, 4)) ener = self._calculate_ener(table_coef, uu) # here we need to overwrite energy to zero at rcut and beyond. mask_beyond_rcut = rr >= self.rcut # also overwrite values beyond extrapolation to zero extrapolation_mask = rr >= self.tab.rmin + nspline * self.tab.hh - ener[mask_beyond_rcut] = 0 - ener[extrapolation_mask] = 0 + ener = xp.where( + xp.logical_or(mask_beyond_rcut, extrapolation_mask), + xp.zeros_like(ener), + ener, + ) return ener @@ -304,12 +317,13 @@ def _get_pairwise_dist(coords: np.ndarray, nlist: np.ndarray) -> np.ndarray: np.ndarray The pairwise distance between the atoms (nframes, nloc, nnei). """ + xp = array_api_compat.array_namespace(coords, nlist) # index type is int64 - batch_indices = np.arange(nlist.shape[0], dtype=np.int64)[:, None, None] + batch_indices = xp.arange(nlist.shape[0], dtype=xp.int64)[:, None, None] neighbor_atoms = coords[batch_indices, nlist] loc_atoms = coords[:, : nlist.shape[1], :] pairwise_dr = loc_atoms[:, :, None, :] - neighbor_atoms - pairwise_rr = np.sqrt(np.sum(np.power(pairwise_dr, 2), axis=-1)) + pairwise_rr = safe_for_sqrt(xp.sum(xp.power(pairwise_dr, 2), axis=-1)) return pairwise_rr @@ -319,7 +333,7 @@ def _extract_spline_coefficient( j_type: np.ndarray, idx: np.ndarray, tab_data: np.ndarray, - nspline: int, + nspline: np.int64, ) -> np.ndarray: """Extract the spline coefficient from the table. @@ -341,9 +355,10 @@ def _extract_spline_coefficient( np.ndarray The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed. """ + xp = array_api_compat.array_namespace(i_type, j_type, idx, tab_data) # (nframes, nloc, nnei) - expanded_i_type = np.broadcast_to( - i_type[:, :, np.newaxis], + expanded_i_type = xp.broadcast_to( + i_type[:, :, xp.newaxis], (i_type.shape[0], i_type.shape[1], j_type.shape[-1]), ) @@ -351,18 +366,20 @@ def _extract_spline_coefficient( expanded_tab_data = tab_data[expanded_i_type, j_type] # (nframes, nloc, nnei, 1, 4) - expanded_idx = np.broadcast_to( - idx[..., np.newaxis, np.newaxis], (*idx.shape, 1, 4) + expanded_idx = xp.broadcast_to( + idx[..., xp.newaxis, xp.newaxis], (*idx.shape, 1, 4) ) - clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int) + clipped_indices = xp.clip(expanded_idx, 0, nspline - 1).astype(int) # (nframes, nloc, nnei, 4) - final_coef = np.squeeze( - np.take_along_axis(expanded_tab_data, clipped_indices, 3) + final_coef = xp.squeeze( + xp_take_along_axis(expanded_tab_data, clipped_indices, 3) ) # when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`. - final_coef[expanded_idx.squeeze() > nspline] = 0 + final_coef = xp.where( + expanded_idx.squeeze() > nspline, xp.zeros_like(final_coef), final_coef + ) return final_coef @staticmethod diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 2f2b12e03c..b033811507 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -27,6 +27,9 @@ LayerNorm, NativeLayer, ) +from deepmd.dpmodel.utils.safe_gradient import ( + safe_for_vector_norm, +) from deepmd.dpmodel.utils.seed import ( child_seed, ) @@ -943,7 +946,7 @@ def call( else: raise NotImplementedError - normed = xp.linalg.vector_norm( + normed = safe_for_vector_norm( xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4], axis=-1, keepdims=True ) input_r = xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4] / xp.maximum( diff --git a/deepmd/dpmodel/utils/safe_gradient.py b/deepmd/dpmodel/utils/safe_gradient.py new file mode 100644 index 0000000000..2baf530c08 --- /dev/null +++ b/deepmd/dpmodel/utils/safe_gradient.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Safe versions of some functions that have problematic gradients. + +Check https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where +for more information. +""" + +import array_api_compat + + +def safe_for_sqrt(x): + """Safe version of sqrt that has a gradient of 0 at x = 0.""" + xp = array_api_compat.array_namespace(x) + mask = x > 0.0 + return xp.where(mask, xp.sqrt(xp.where(mask, x, xp.ones_like(x))), xp.zeros_like(x)) + + +def safe_for_vector_norm(x, /, *, axis=None, keepdims=False, ord=2): + """Safe version of sqrt that has a gradient of 0 at x = 0.""" + xp = array_api_compat.array_namespace(x) + mask = xp.sum(xp.square(x), axis=axis, keepdims=True) > 0 + if keepdims: + mask_squeezed = mask + else: + mask_squeezed = xp.squeeze(mask, axis=axis) + return xp.where( + mask_squeezed, + xp.linalg.vector_norm( + xp.where(mask, x, xp.ones_like(x)), axis=axis, keepdims=keepdims, ord=ord + ), + xp.zeros_like(mask_squeezed, dtype=x.dtype), + ) diff --git a/deepmd/jax/atomic_model/dp_atomic_model.py b/deepmd/jax/atomic_model/dp_atomic_model.py index 077209e29a..5898fd3ff8 100644 --- a/deepmd/jax/atomic_model/dp_atomic_model.py +++ b/deepmd/jax/atomic_model/dp_atomic_model.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, + Optional, ) from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP @@ -13,6 +14,10 @@ from deepmd.jax.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.jax.env import ( + jax, + jnp, +) from deepmd.jax.fitting.base_fitting import ( BaseFitting, ) @@ -28,3 +33,21 @@ class DPAtomicModel(DPAtomicModelDP): def __setattr__(self, name: str, value: Any) -> None: value = base_atomic_model_set_attr(name, value) return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + ) -> dict[str, jnp.ndarray]: + return super().forward_common_atomic( + extended_coord, + extended_atype, + jax.lax.stop_gradient(nlist), + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) diff --git a/deepmd/jax/atomic_model/linear_atomic_model.py b/deepmd/jax/atomic_model/linear_atomic_model.py new file mode 100644 index 0000000000..6ce82fa07c --- /dev/null +++ b/deepmd/jax/atomic_model/linear_atomic_model.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +from deepmd.dpmodel.atomic_model.linear_atomic_model import ( + DPZBLLinearEnergyAtomicModel as DPZBLLinearEnergyAtomicModelDP, +) +from deepmd.jax.atomic_model.base_atomic_model import ( + base_atomic_model_set_attr, +) +from deepmd.jax.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.jax.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel, +) +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.env import ( + jax, + jnp, +) + + +@flax_module +class DPZBLLinearEnergyAtomicModel(DPZBLLinearEnergyAtomicModelDP): + def __setattr__(self, name: str, value: Any) -> None: + value = base_atomic_model_set_attr(name, value) + if name == "mapping_list": + value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value] + elif name == "zbl_weight": + value = ArrayAPIVariable(to_jax_array(value)) + elif name == "models": + value = [ + DPAtomicModel.deserialize(value[0].serialize()), + PairTabAtomicModel.deserialize(value[1].serialize()), + ] + return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + ) -> dict[str, jnp.ndarray]: + return super().forward_common_atomic( + extended_coord, + extended_atype, + jax.lax.stop_gradient(nlist), + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) diff --git a/deepmd/jax/atomic_model/pairtab_atomic_model.py b/deepmd/jax/atomic_model/pairtab_atomic_model.py new file mode 100644 index 0000000000..023f4e886a --- /dev/null +++ b/deepmd/jax/atomic_model/pairtab_atomic_model.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel as PairTabAtomicModelDP, +) +from deepmd.jax.atomic_model.base_atomic_model import ( + base_atomic_model_set_attr, +) +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.env import ( + jax, + jnp, +) + + +@flax_module +class PairTabAtomicModel(PairTabAtomicModelDP): + def __setattr__(self, name: str, value: Any) -> None: + value = base_atomic_model_set_attr(name, value) + if name in {"tab_info", "tab_data"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + ) -> dict[str, jnp.ndarray]: + return super().forward_common_atomic( + extended_coord, + extended_atype, + jax.lax.stop_gradient(nlist), + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) diff --git a/deepmd/jax/model/__init__.py b/deepmd/jax/model/__init__.py index 05a60c4ffe..bba5bc766a 100644 --- a/deepmd/jax/model/__init__.py +++ b/deepmd/jax/model/__init__.py @@ -1,6 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .dp_zbl_model import ( + DPZBLLinearEnergyAtomicModel, +) from .ener_model import ( EnergyModel, ) -__all__ = ["EnergyModel"] +__all__ = [ + "EnergyModel", + "DPZBLLinearEnergyAtomicModel", +] diff --git a/deepmd/jax/model/dp_zbl_model.py b/deepmd/jax/model/dp_zbl_model.py new file mode 100644 index 0000000000..028fa8593b --- /dev/null +++ b/deepmd/jax/model/dp_zbl_model.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP +from deepmd.jax.atomic_model.linear_atomic_model import ( + DPZBLLinearEnergyAtomicModel, +) +from deepmd.jax.common import ( + flax_module, +) +from deepmd.jax.env import ( + jnp, +) +from deepmd.jax.model.base_model import ( + BaseModel, + forward_common_atomic, +) + + +@BaseModel.register("zbl") +@flax_module +class DPZBLModel(DPZBLModelDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "atomic_model": + value = DPZBLLinearEnergyAtomicModel.deserialize(value.serialize()) + return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ): + return forward_common_atomic( + self, + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) diff --git a/deepmd/jax/model/model.py b/deepmd/jax/model/model.py index 7fa3efda6e..e636eba4c6 100644 --- a/deepmd/jax/model/model.py +++ b/deepmd/jax/model/model.py @@ -3,15 +3,27 @@ deepcopy, ) +from deepmd.jax.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.jax.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel, +) from deepmd.jax.descriptor.base_descriptor import ( BaseDescriptor, ) from deepmd.jax.fitting.base_fitting import ( BaseFitting, ) +from deepmd.jax.fitting.fitting import ( + EnergyFittingNet, +) from deepmd.jax.model.base_model import ( BaseModel, ) +from deepmd.jax.model.dp_zbl_model import ( + DPZBLModel, +) def get_standard_model(data: dict): @@ -45,6 +57,45 @@ def get_standard_model(data: dict): ) +def get_zbl_model(data: dict) -> DPZBLModel: + data["descriptor"]["ntypes"] = len(data["type_map"]) + descriptor_type = data["descriptor"].pop("type") + descriptor = BaseDescriptor.get_class_by_type(descriptor_type)(**data["descriptor"]) + fitting_type = data["fitting_net"].pop("type") + if fitting_type == "ener": + fitting = EnergyFittingNet( + ntypes=descriptor.get_ntypes(), + dim_descrpt=descriptor.get_dim_out(), + mixed_types=descriptor.mixed_types(), + **data["fitting_net"], + ) + else: + raise ValueError(f"Unknown fitting type {fitting_type}") + + dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"]) + # pairtab + filepath = data["use_srtab"] + pt_model = PairTabAtomicModel( + filepath, + data["descriptor"]["rcut"], + data["descriptor"]["sel"], + type_map=data["type_map"], + ) + rmin = data["sw_rmin"] + rmax = data["sw_rmax"] + atom_exclude_types = data.get("atom_exclude_types", []) + pair_exclude_types = data.get("pair_exclude_types", []) + return DPZBLModel( + dp_model, + pt_model, + rmin, + rmax, + type_map=data["type_map"], + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + ) + + def get_model(data: dict): """Get a model from a dictionary. @@ -57,6 +108,8 @@ def get_model(data: dict): if model_type == "standard": if "spin" in data: raise NotImplementedError("Spin model is not implemented yet.") + elif "use_srtab" in data: + return get_zbl_model(data) else: return get_standard_model(data) else: diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index f37bee0c90..a63543ab74 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -13,6 +13,7 @@ ) from ..common import ( + INSTALLED_JAX, INSTALLED_PT, SKIP_FLAG, CommonTest, @@ -27,6 +28,11 @@ from deepmd.pt.model.model.dp_zbl_model import DPZBLModel as DPZBLModelPT else: DPZBLModelPT = None +if INSTALLED_JAX: + from deepmd.jax.model.dp_zbl_model import DPZBLModel as DPZBLModelJAX + from deepmd.jax.model.model import get_model as get_model_jax +else: + DPZBLModelJAX = None import os from deepmd.utils.argcheck import ( @@ -86,6 +92,7 @@ def data(self) -> dict: dp_class = DPZBLModelDP pt_class = DPZBLModelPT + jax_class = DPZBLModelJAX args = model_args() def get_reference_backend(self): @@ -109,7 +116,7 @@ def skip_tf(self): @property def skip_jax(self): - return True + return not INSTALLED_JAX def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" @@ -118,6 +125,8 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is DPZBLModelPT: return get_model_pt(data) + elif cls is DPZBLModelJAX: + return get_model_jax(data) return cls(**data, **self.additional_data) def setUp(self):