diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index c29a76b3f1..6307b19f41 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -1,13 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import copy +import math from typing import ( Optional, ) +import array_api_compat import numpy as np from deepmd.dpmodel.common import ( NativeOP, + to_numpy_array, ) from deepmd.dpmodel.output_def import ( FittingOutputDef, @@ -172,17 +174,18 @@ def forward_common_atomic( ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual. """ + xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist) _, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] if self.pair_excl is not None: pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype) # exclude neighbors in the nlist - nlist = np.where(pair_mask == 1, nlist, -1) + nlist = xp.where(pair_mask == 1, nlist, -1) ext_atom_mask = self.make_atom_mask(extended_atype) ret_dict = self.forward_atomic( extended_coord, - np.where(ext_atom_mask, extended_atype, 0), + xp.where(ext_atom_mask, extended_atype, 0), nlist, mapping=mapping, fparam=fparam, @@ -191,13 +194,13 @@ def forward_common_atomic( ret_dict = self.apply_out_stat(ret_dict, atype) # nf x nloc - atom_mask = ext_atom_mask[:, :nloc].astype(np.int32) + atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32) if self.atom_excl is not None: atom_mask *= self.atom_excl.build_type_exclude_mask(atype) for kk in ret_dict.keys(): out_shape = ret_dict[kk].shape - out_shape2 = np.prod(out_shape[2:]) + out_shape2 = math.prod(out_shape[2:]) ret_dict[kk] = ( ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) * atom_mask[:, :, None] @@ -232,14 +235,15 @@ def serialize(self) -> dict: "rcond": self.rcond, "preset_out_bias": self.preset_out_bias, "@variables": { - "out_bias": self.out_bias, - 
"out_std": self.out_std, + "out_bias": to_numpy_array(self.out_bias), + "out_std": to_numpy_array(self.out_std), }, } @classmethod def deserialize(cls, data: dict) -> "BaseAtomicModel": - data = copy.deepcopy(data) + # do not deep copy Descriptor and Fitting class + data = data.copy() variables = data.pop("@variables") obj = cls(**data) for kk in variables.keys(): diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 7e576eb484..fe049021fe 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -169,14 +169,20 @@ def serialize(self) -> dict: ) return dd + # to be overridden by subclasses + base_descriptor_cls = BaseDescriptor + """The base descriptor class.""" + base_fitting_cls = BaseFitting + """The base fitting class.""" + @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 2, 2) data.pop("@class") data.pop("type") - descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) - fitting_obj = BaseFitting.deserialize(data.pop("fitting")) + descriptor_obj = cls.base_descriptor_cls.deserialize(data.pop("descriptor")) + fitting_obj = cls.base_fitting_cls.deserialize(data.pop("fitting")) data["descriptor"] = descriptor_obj data["fitting"] = fitting_obj obj = super().deserialize(data) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 8cdb7e1f25..dc90f10da7 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -3,6 +3,7 @@ Optional, ) +import array_api_compat import numpy as np from deepmd.dpmodel.atomic_model.base_atomic_model import ( @@ -75,7 +76,8 @@ def __init__( else: self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs) self.precision_dict = PRECISION_DICT - self.reverse_precision_dict = RESERVED_PRECISON_DICT + # not supported by flax + # self.reverse_precision_dict = 
RESERVED_PRECISON_DICT self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION self.global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION @@ -253,9 +255,7 @@ def input_type_cast( str, ]: """Cast the input data to global float type.""" - input_prec = self.reverse_precision_dict[ - self.precision_dict[coord.dtype.name] - ] + input_prec = RESERVED_PRECISON_DICT[self.precision_dict[coord.dtype.name]] ### ### type checking would not pass jit, convert to coord prec anyway ### @@ -264,10 +264,7 @@ def input_type_cast( for vv in [box, fparam, aparam] ] box, fparam, aparam = _lst - if ( - input_prec - == self.reverse_precision_dict[self.global_np_float_precision] - ): + if input_prec == RESERVED_PRECISON_DICT[self.global_np_float_precision]: return coord, box, fparam, aparam, input_prec else: pp = self.global_np_float_precision @@ -286,8 +283,7 @@ def output_type_cast( ) -> dict[str, np.ndarray]: """Convert the model output to the input prec.""" do_cast = ( - input_prec - != self.reverse_precision_dict[self.global_np_float_precision] + input_prec != RESERVED_PRECISON_DICT[self.global_np_float_precision] ) pp = self.precision_dict[input_prec] odef = self.model_output_def() @@ -366,6 +362,7 @@ def _format_nlist( nnei: int, extra_nlist_sort: bool = False, ): + xp = array_api_compat.array_namespace(extended_coord, nlist) n_nf, n_nloc, n_nnei = nlist.shape extended_coord = extended_coord.reshape([n_nf, -1, 3]) nall = extended_coord.shape[1] @@ -373,10 +370,10 @@ def _format_nlist( if n_nnei < nnei: # make a copy before revise - ret = np.concatenate( + ret = xp.concat( [ nlist, - -1 * np.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype), + -1 * xp.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype), ], axis=-1, ) @@ -385,16 +382,16 @@ def _format_nlist( n_nf, n_nloc, n_nnei = nlist.shape # make a copy before revise m_real_nei = nlist >= 0 - ret = np.where(m_real_nei, nlist, 0) + ret = xp.where(m_real_nei, nlist, 0) coord0 = extended_coord[:, :n_nloc, :] index = 
ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2) - coord1 = np.take_along_axis(extended_coord, index, axis=1) + coord1 = xp.take_along_axis(extended_coord, index, axis=1) coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3) - rr = np.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) - rr = np.where(m_real_nei, rr, float("inf")) - rr, ret_mapping = np.sort(rr, axis=-1), np.argsort(rr, axis=-1) - ret = np.take_along_axis(ret, ret_mapping, axis=2) - ret = np.where(rr > rcut, -1, ret) + rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) + rr = xp.where(m_real_nei, rr, float("inf")) + rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1) + ret = xp.take_along_axis(ret, ret_mapping, axis=2) + ret = xp.where(rr > rcut, -1, ret) ret = ret[..., :nnei] # not extra_nlist_sort and n_nnei <= nnei: elif n_nnei == nnei: diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py index 43c275b1be..928c33f3bd 100644 --- a/deepmd/dpmodel/model/transform_output.py +++ b/deepmd/dpmodel/model/transform_output.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat import numpy as np from deepmd.dpmodel.common import ( @@ -23,6 +24,7 @@ def fit_output_to_model_output( the model output. 
""" + xp = array_api_compat.get_namespace(coord_ext) model_ret = dict(fit_ret.items()) for kk, vv in fit_ret.items(): vdef = fit_output_def[kk] @@ -31,7 +33,7 @@ def fit_output_to_model_output( if vdef.reducible: kk_redu = get_reduce_name(kk) # cast to energy prec brefore reduction - model_ret[kk_redu] = np.sum( + model_ret[kk_redu] = xp.sum( vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis ) if vdef.r_differentiable: diff --git a/deepmd/jax/atomic_model/__init__.py b/deepmd/jax/atomic_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/atomic_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/atomic_model/base_atomic_model.py b/deepmd/jax/atomic_model/base_atomic_model.py new file mode 100644 index 0000000000..90920879c2 --- /dev/null +++ b/deepmd/jax/atomic_model/base_atomic_model.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.jax.common import ( + to_jax_array, +) +from deepmd.jax.utils.exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) + + +def base_atomic_model_set_attr(name, value): + if name in {"out_bias", "out_std"}: + value = to_jax_array(value) + elif name == "pair_excl" and value is not None: + value = PairExcludeMask(value.ntypes, value.exclude_types) + elif name == "atom_excl" and value is not None: + value = AtomExcludeMask(value.ntypes, value.exclude_types) + return value diff --git a/deepmd/jax/atomic_model/dp_atomic_model.py b/deepmd/jax/atomic_model/dp_atomic_model.py new file mode 100644 index 0000000000..077209e29a --- /dev/null +++ b/deepmd/jax/atomic_model/dp_atomic_model.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP +from deepmd.jax.atomic_model.base_atomic_model import ( + base_atomic_model_set_attr, +) +from deepmd.jax.common import ( + flax_module, +) 
+from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.fitting.base_fitting import ( + BaseFitting, +) + + +@flax_module +class DPAtomicModel(DPAtomicModelDP): + base_descriptor_cls = BaseDescriptor + """The base descriptor class.""" + base_fitting_cls = BaseFitting + """The base fitting class.""" + + def __setattr__(self, name: str, value: Any) -> None: + value = base_atomic_model_set_attr(name, value) + return super().__setattr__(name, value) diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py index 6ceb116d85..ed59493268 100644 --- a/deepmd/jax/descriptor/__init__.py +++ b/deepmd/jax/descriptor/__init__.py @@ -1 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.jax.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.jax.descriptor.se_e2_a import ( + DescrptSeA, +) + +__all__ = [ + "DescrptSeA", + "DescrptDPA1", +] diff --git a/deepmd/jax/descriptor/base_descriptor.py b/deepmd/jax/descriptor/base_descriptor.py new file mode 100644 index 0000000000..7dec3cd6d4 --- /dev/null +++ b/deepmd/jax/descriptor/base_descriptor.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.descriptor.make_base_descriptor import ( + make_base_descriptor, +) +from deepmd.jax.env import ( + jnp, +) + +BaseDescriptor = make_base_descriptor(jnp.ndarray) diff --git a/deepmd/jax/descriptor/dpa1.py b/deepmd/jax/descriptor/dpa1.py index a9b0404970..0528e4bb93 100644 --- a/deepmd/jax/descriptor/dpa1.py +++ b/deepmd/jax/descriptor/dpa1.py @@ -16,6 +16,9 @@ flax_module, to_jax_array, ) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.jax.utils.exclude_mask import ( PairExcludeMask, ) @@ -76,6 +79,8 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseDescriptor.register("dpa1") +@BaseDescriptor.register("se_atten") @flax_module class DescrptDPA1(DescrptDPA1DP): def __setattr__(self, name: 
str, value: Any) -> None: diff --git a/deepmd/jax/descriptor/se_e2_a.py b/deepmd/jax/descriptor/se_e2_a.py index a60a4e9af1..d1a6e9a8d9 100644 --- a/deepmd/jax/descriptor/se_e2_a.py +++ b/deepmd/jax/descriptor/se_e2_a.py @@ -8,6 +8,9 @@ flax_module, to_jax_array, ) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.jax.utils.exclude_mask import ( PairExcludeMask, ) @@ -16,6 +19,8 @@ ) +@BaseDescriptor.register("se_e2_a") +@BaseDescriptor.register("se_a") @flax_module class DescrptSeA(DescrptSeADP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/deepmd/jax/fitting/__init__.py b/deepmd/jax/fitting/__init__.py index 6ceb116d85..e72314dcab 100644 --- a/deepmd/jax/fitting/__init__.py +++ b/deepmd/jax/fitting/__init__.py @@ -1 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.jax.fitting.fitting import ( + DOSFittingNet, + EnergyFittingNet, +) + +__all__ = [ + "EnergyFittingNet", + "DOSFittingNet", +] diff --git a/deepmd/jax/fitting/base_fitting.py b/deepmd/jax/fitting/base_fitting.py new file mode 100644 index 0000000000..fd9f3a416d --- /dev/null +++ b/deepmd/jax/fitting/base_fitting.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.fitting.make_base_fitting import ( + make_base_fitting, +) +from deepmd.jax.env import ( + jnp, +) + +BaseFitting = make_base_fitting(jnp.ndarray) diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index 284213c70a..f979db4d41 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -9,6 +9,9 @@ flax_module, to_jax_array, ) +from deepmd.jax.fitting.base_fitting import ( + BaseFitting, +) from deepmd.jax.utils.exclude_mask import ( AtomExcludeMask, ) @@ -33,6 +36,7 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any: return value +@BaseFitting.register("ener") @flax_module class EnergyFittingNet(EnergyFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: @@ 
-40,6 +44,7 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseFitting.register("dos") @flax_module class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/deepmd/jax/model/__init__.py b/deepmd/jax/model/__init__.py new file mode 100644 index 0000000000..05a60c4ffe --- /dev/null +++ b/deepmd/jax/model/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .ener_model import ( + EnergyModel, +) + +__all__ = ["EnergyModel"] diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py new file mode 100644 index 0000000000..fee4855da3 --- /dev/null +++ b/deepmd/jax/model/base_model.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.model.base_model import ( + make_base_model, +) + +BaseModel = make_base_model() diff --git a/deepmd/jax/model/ener_model.py b/deepmd/jax/model/ener_model.py new file mode 100644 index 0000000000..79c5a29e88 --- /dev/null +++ b/deepmd/jax/model/ener_model.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.model import EnergyModel as EnergyModelDP +from deepmd.jax.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.jax.common import ( + flax_module, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) + + +@BaseModel.register("ener") +@flax_module +class EnergyModel(EnergyModelDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "atomic_model": + value = DPAtomicModel.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/deepmd/jax/model/model.py b/deepmd/jax/model/model.py new file mode 100644 index 0000000000..7fa3efda6e --- /dev/null +++ b/deepmd/jax/model/model.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from copy import ( + deepcopy, +) + +from deepmd.jax.descriptor.base_descriptor import ( + 
BaseDescriptor, +) +from deepmd.jax.fitting.base_fitting import ( + BaseFitting, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) + + +def get_standard_model(data: dict): + """Get a Model from a dictionary. + + Parameters + ---------- + data : dict + The data to construct the model. + """ + data = deepcopy(data) + descriptor_type = data["descriptor"].pop("type") + data["descriptor"]["type_map"] = data["type_map"] + fitting_type = data["fitting_net"].pop("type") + data["fitting_net"]["type_map"] = data["type_map"] + descriptor = BaseDescriptor.get_class_by_type(descriptor_type)( + **data["descriptor"], + ) + fitting = BaseFitting.get_class_by_type(fitting_type)( + ntypes=descriptor.get_ntypes(), + dim_descrpt=descriptor.get_dim_out(), + mixed_types=descriptor.mixed_types(), + **data["fitting_net"], + ) + return BaseModel.get_class_by_type(fitting_type)( + descriptor=descriptor, + fitting=fitting, + type_map=data["type_map"], + atom_exclude_types=data.get("atom_exclude_types", []), + pair_exclude_types=data.get("pair_exclude_types", []), + ) + + +def get_model(data: dict): + """Get a model from a dictionary. + + Parameters + ---------- + data : dict + The data to construct the model. 
+ """ + model_type = data.get("type", "standard") + if model_type == "standard": + if "spin" in data: + raise NotImplementedError("Spin model is not implemented yet.") + else: + return get_standard_model(data) + else: + return BaseModel.get_class_by_type(model_type).get_model(data) diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py index 294edec1d6..4112e09cff 100644 --- a/source/tests/consistent/model/common.py +++ b/source/tests/consistent/model/common.py @@ -6,8 +6,12 @@ from deepmd.common import ( make_default_mesh, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from ..common import ( + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, ) @@ -20,6 +24,11 @@ GLOBAL_TF_FLOAT_PRECISION, tf, ) +if INSTALLED_JAX: + from deepmd.jax.common import to_jax_array as numpy_to_jax + from deepmd.jax.env import ( + jnp, + ) class ModelTest: @@ -62,3 +71,17 @@ def eval_pt_model(self, pt_obj: Any, natoms, coords, atype, box) -> Any: box=numpy_to_torch(box), ).items() } + + def eval_jax_model(self, jax_obj: Any, natoms, coords, atype, box) -> Any: + def assert_jax_array(arr): + assert isinstance(arr, jnp.ndarray) or arr is None + return arr + + return { + kk: to_numpy_array(assert_jax_array(vv)) + for kk, vv in jax_obj( + numpy_to_jax(coords), + numpy_to_jax(atype), + box=numpy_to_jax(box), + ).items() + } diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 692e1287dc..78a2aac703 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -13,6 +13,7 @@ ) from ..common import ( + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -36,6 +37,12 @@ model_args, ) +if INSTALLED_JAX: + from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX + from deepmd.jax.model.model import get_model as get_model_jax +else: + EnergyModelJAX = None + @parameterized( ( @@ -84,14 +91,20 @@ def data(self) -> dict: tf_class = 
EnergyModelTF dp_class = EnergyModelDP pt_class = EnergyModelPT + jax_class = EnergyModelJAX args = model_args() + @property def skip_tf(self): return ( self.data["pair_exclude_types"] != [] or self.data["atom_exclude_types"] != [] ) + @property + def skip_jax(self): + return not INSTALLED_JAX + def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" data = data.copy() @@ -99,6 +112,8 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is EnergyModelPT: return get_model_pt(data) + elif cls is EnergyModelJAX: + return get_model_jax(data) return cls(**data, **self.addtional_data) def setUp(self): @@ -168,6 +183,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... if backend is self.RefBackend.DP: @@ -176,4 +200,6 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret["energy"].ravel(), ret["atom_energy"].ravel()) elif backend is self.RefBackend.TF: return (ret[0].ravel(), ret[1].ravel()) + elif backend is self.RefBackend.JAX: + return (ret["energy_redu"].ravel(), ret["energy"].ravel()) raise ValueError(f"Unknown backend: {backend}")