diff --git a/deepmd/dpmodel/__init__.py b/deepmd/dpmodel/__init__.py index 6f83f849a3..111c2d6ced 100644 --- a/deepmd/dpmodel/__init__.py +++ b/deepmd/dpmodel/__init__.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.utils.entry_point import ( + load_entry_point, +) + from .common import ( DEFAULT_PRECISION, PRECISION_DICT, @@ -32,3 +36,6 @@ "get_deriv_name", "get_hessian_name", ] + + +load_entry_point("deepmd.dpmodel") diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 5d86472674..224fdd145c 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -34,6 +34,7 @@ ) +@BaseAtomicModel.register("linear") class LinearEnergyAtomicModel(BaseAtomicModel): """Linear model make linear combinations of several existing models. @@ -324,6 +325,7 @@ def is_aparam_nall(self) -> bool: return False +@BaseAtomicModel.register("zbl") class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel): """Model linearly combine a list of AtomicModels. diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index 4eb14f29cf..0d89902e4a 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel.common import ( @@ -66,7 +67,7 @@ def __init__( ), f"number of atom types in {ii}th descriptor {self.descrpt_list[0].__class__.__name__} does not match others" # if hybrid sel is larger than sub sel, the nlist needs to be cut for each type hybrid_sel = self.get_sel() - self.nlist_cut_idx: list[np.ndarray] = [] + nlist_cut_idx: list[np.ndarray] = [] if self.mixed_types() and not all( descrpt.mixed_types() for descrpt in self.descrpt_list ): @@ -92,7 +93,8 @@ def __init__( cut_idx = np.concatenate( [range(ss, ee) for ss, ee in zip(start_idx, end_idx)] ) - self.nlist_cut_idx.append(cut_idx) + nlist_cut_idx.append(cut_idx) + self.nlist_cut_idx = nlist_cut_idx def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -242,6 +244,7 @@ def call( sw The smooth switch function. """ + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) out_descriptor = [] out_gr = [] out_g2 = None @@ -258,7 +261,7 @@ def call( for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx): # cut the nlist to the correct length if self.mixed_types() == descrpt.mixed_types(): - nl = nlist[:, :, nci] + nl = xp.take(nlist, nci, axis=2) else: # mixed_types is True, but descrpt.mixed_types is False assert nl_distinguish_types is not None @@ -268,8 +271,8 @@ def call( if gr is not None: out_gr.append(gr) - out_descriptor = np.concatenate(out_descriptor, axis=-1) - out_gr = np.concatenate(out_gr, axis=-2) if out_gr else None + out_descriptor = xp.concat(out_descriptor, axis=-1) + out_gr = xp.concat(out_gr, axis=-2) if out_gr else None return out_descriptor, out_gr, out_g2, out_h2, out_sw @classmethod diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 01bd60c777..cecba865d0 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( @@ -207,6 +208,7 @@ def call( The atomic parameter. shape: nf x nloc x nap. 
nap being `numb_aparam` """ + xp = array_api_compat.array_namespace(descriptor, atype) nframes, nloc, _ = descriptor.shape assert gr is not None, "Must provide the rotation matrix for dipole fitting." # (nframes, nloc, m1) @@ -214,9 +216,11 @@ def call( self.var_name ] # (nframes * nloc, 1, m1) - out = out.reshape(-1, 1, self.embedding_width) + out = xp.reshape(out, (-1, 1, self.embedding_width)) # (nframes * nloc, m1, 3) - gr = gr.reshape(nframes * nloc, -1, 3) + gr = xp.reshape(gr, (nframes * nloc, -1, 3)) # (nframes, nloc, 3) - out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3) + # out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3) + out = out @ gr + out = xp.reshape(out, (nframes, nloc, 3)) return {self.var_name: out} diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index e55f57c774..a027e1e59d 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -388,8 +388,8 @@ def _call_common( assert fparam is not None, "fparam should not be None" if fparam.shape[-1] != self.numb_fparam: raise ValueError( - "get an input fparam of dim {fparam.shape[-1]}, ", - "which is not consistent with {self.numb_fparam}.", + f"get an input fparam of dim {fparam.shape[-1]}, " + f"which is not consistent with {self.numb_fparam}." ) fparam = (fparam - self.fparam_avg) * self.fparam_inv_std fparam = xp.tile( @@ -409,8 +409,8 @@ def _call_common( assert aparam is not None, "aparam should not be None" if aparam.shape[-1] != self.numb_aparam: raise ValueError( - "get an input aparam of dim {aparam.shape[-1]}, ", - "which is not consistent with {self.numb_aparam}.", + f"get an input aparam of dim {aparam.shape[-1]}, " + f"which is not consistent with {self.numb_aparam}." ) aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam]) aparam = (aparam - self.aparam_avg) * self.aparam_inv_std diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 2d96eec580..b972b45971 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.common import ( @@ -14,6 +15,9 @@ from deepmd.dpmodel import ( DEFAULT_PRECISION, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.base_fitting import ( BaseFitting, ) @@ -124,23 +128,18 @@ def __init__( self.embedding_width = embedding_width self.fit_diag = fit_diag - self.scale = scale - if self.scale is None: - self.scale = [1.0 for _ in range(ntypes)] + if scale is None: + scale = [1.0 for _ in range(ntypes)] else: - if isinstance(self.scale, list): - assert ( - len(self.scale) == ntypes - ), "Scale should be a list of length ntypes." - elif isinstance(self.scale, float): - self.scale = [self.scale for _ in range(ntypes)] + if isinstance(scale, list): + assert len(scale) == ntypes, "Scale should be a list of length ntypes." + elif isinstance(scale, float): + scale = [scale for _ in range(ntypes)] else: raise ValueError( "Scale must be a list of float of length ntypes or a float." 
) - self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape( - ntypes, 1 - ) + self.scale = np.array(scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(ntypes, 1) self.shift_diag = shift_diag self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION) super().__init__( @@ -192,8 +191,8 @@ def serialize(self) -> dict: data["embedding_width"] = self.embedding_width data["fit_diag"] = self.fit_diag data["shift_diag"] = self.shift_diag - data["@variables"]["scale"] = self.scale - data["@variables"]["constant_matrix"] = self.constant_matrix + data["@variables"]["scale"] = to_numpy_array(self.scale) + data["@variables"]["constant_matrix"] = to_numpy_array(self.constant_matrix) return data @classmethod @@ -276,6 +275,7 @@ def call( The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam` """ + xp = array_api_compat.array_namespace(descriptor, atype) nframes, nloc, _ = descriptor.shape assert ( gr is not None @@ -284,28 +284,39 @@ def call( out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] - out = out * self.scale[atype] + # out = out * self.scale[atype, ...] + scale_atype = xp.reshape( + xp.take(self.scale, xp.reshape(atype, [-1]), axis=0), (*atype.shape, 1) + ) + out = out * scale_atype # (nframes * nloc, m1, 3) - gr = gr.reshape(nframes * nloc, -1, 3) + gr = xp.reshape(gr, (nframes * nloc, -1, 3)) if self.fit_diag: - out = out.reshape(-1, self.embedding_width) - out = np.einsum("ij,ijk->ijk", out, gr) + out = xp.reshape(out, (-1, self.embedding_width)) + # out = np.einsum("ij,ijk->ijk", out, gr) + out = out[:, :, None] * gr else: - out = out.reshape(-1, self.embedding_width, self.embedding_width) - out = (out + np.transpose(out, axes=(0, 2, 1))) / 2 - out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) - out = np.einsum( - "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out - ) # (nframes * nloc, 3, 3) - out = out.reshape(nframes, nloc, 3, 3) + out = xp.reshape(out, (-1, self.embedding_width, self.embedding_width)) + out = (out + xp.matrix_transpose(out)) / 2 + # out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) + out = out @ gr + # out = np.einsum( + # "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out + # ) # (nframes * nloc, 3, 3) + out = xp.matrix_transpose(gr) @ out + out = xp.reshape(out, (nframes, nloc, 3, 3)) if self.shift_diag: - bias = self.constant_matrix[atype] + # bias = self.constant_matrix[atype] + bias = xp.reshape( + xp.take(self.constant_matrix, xp.reshape(atype, [-1]), axis=0), + (nframes, nloc), + ) # (nframes, nloc, 1) - bias = np.expand_dims(bias, axis=-1) * self.scale[atype] - eye = np.eye(3, dtype=descriptor.dtype) - eye = np.tile(eye, (nframes, nloc, 1, 1)) + bias = bias[..., None] * scale_atype + eye = xp.eye(3, dtype=descriptor.dtype) + eye = xp.tile(eye, (nframes, nloc, 1, 1)) # (nframes, nloc, 3, 3) - bias = np.expand_dims(bias, axis=-1) * eye + bias = bias[..., None] * eye out = out + bias return {"polarizability": out} diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index c1f3e4630b..5463743ada 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -204,8 +204,6 @@ def eval( The output of the evaluation. The keys are the names of the output variables, and the values are the corresponding output arrays. 
""" - if fparam is not None or aparam is not None: - raise NotImplementedError # convert all of the input to numpy array atom_types = np.array(atom_types, dtype=np.int32) coords = np.array(coords) @@ -216,7 +214,7 @@ def eval( ) request_defs = self._get_request_defs(atomic) out = self._eval_func(self._eval_model, numb_test, natoms)( - coords, cells, atom_types, request_defs + coords, cells, atom_types, fparam, aparam, request_defs ) return dict( zip( @@ -306,6 +304,8 @@ def _eval_model( coords: np.ndarray, cells: Optional[np.ndarray], atom_types: np.ndarray, + fparam: Optional[np.ndarray], + aparam: Optional[np.ndarray], request_defs: list[OutputVariableDef], ): model = self.dp @@ -323,12 +323,25 @@ def _eval_model( box_input = cells.reshape([-1, 3, 3]) else: box_input = None + if fparam is not None: + fparam_input = fparam.reshape(nframes, self.get_dim_fparam()) + else: + fparam_input = None + if aparam is not None: + aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam()) + else: + aparam_input = None do_atomic_virial = any( x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs ) batch_output = model( - coord_input, type_input, box=box_input, do_atomic_virial=do_atomic_virial + coord_input, + type_input, + box=box_input, + fparam=fparam_input, + aparam=aparam_input, + do_atomic_virial=do_atomic_virial, ) if isinstance(batch_output, tuple): batch_output = batch_output[0] diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py new file mode 100644 index 0000000000..ba19785235 --- /dev/null +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + +from deepmd.dpmodel.atomic_model.linear_atomic_model import ( + DPZBLLinearEnergyAtomicModel, +) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) +from deepmd.dpmodel.model.dp_model import ( + DPModelCommon, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) + +from .make_model import ( + make_model, +) + +DPZBLModel_ = make_model(DPZBLLinearEnergyAtomicModel) + + +@BaseModel.register("zbl") +class DPZBLModel(DPZBLModel_): + model_type = "zbl" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. 
+ + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel( + train_data, type_map, local_jdata["dpmodel"] + ) + return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index cccd0732cd..c29240214c 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -1,4 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel, +) +from deepmd.dpmodel.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.dpmodel.descriptor.se_e2_a import ( DescrptSeA, ) @@ -8,6 +17,9 @@ from deepmd.dpmodel.model.base_model import ( BaseModel, ) +from deepmd.dpmodel.model.dp_zbl_model import ( + DPZBLModel, +) from deepmd.dpmodel.model.ener_model import ( EnergyModel, ) @@ -55,6 +67,45 @@ def get_standard_model(data: dict) -> EnergyModel: ) +def get_zbl_model(data: dict) -> DPZBLModel: + data["descriptor"]["ntypes"] = len(data["type_map"]) + descriptor = BaseDescriptor(**data["descriptor"]) + fitting_type = data["fitting_net"].pop("type") + if fitting_type == "ener": + fitting = EnergyFittingNet( + ntypes=descriptor.get_ntypes(), + dim_descrpt=descriptor.get_dim_out(), + mixed_types=descriptor.mixed_types(), + **data["fitting_net"], + ) + else: + raise ValueError(f"Unknown fitting type {fitting_type}") + + dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"]) + # pairtab + filepath = data["use_srtab"] + pt_model = PairTabAtomicModel( + filepath, + data["descriptor"]["rcut"], + data["descriptor"]["sel"], + type_map=data["type_map"], + ) + + rmin = data["sw_rmin"] + rmax = data["sw_rmax"] + atom_exclude_types = data.get("atom_exclude_types", []) + pair_exclude_types = data.get("pair_exclude_types", []) + return DPZBLModel( + dp_model, + pt_model, + rmin, + rmax, + type_map=data["type_map"], + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + ) + + def get_spin_model(data: dict) -> SpinModel: """Get a spin model from a dictionary. 
@@ -100,6 +151,8 @@ def get_model(data: dict): if model_type == "standard": if "spin" in data: return get_spin_model(data) + elif "use_srtab" in data: + return get_zbl_model(data) else: return get_standard_model(data) else: diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index d9ccf392f5..fd0393c914 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -124,7 +124,7 @@ def test( log.info(f"# testing system : {system}") # create data class - tmap = dp.get_type_map() if isinstance(dp, DeepPot) else None + tmap = dp.get_type_map() data = DeepmdData( system, set_prefix="set", diff --git a/deepmd/jax/__init__.py b/deepmd/jax/__init__.py index 2ff078e797..bb5c0a5206 100644 --- a/deepmd/jax/__init__.py +++ b/deepmd/jax/__init__.py @@ -1,2 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """JAX backend.""" + +from deepmd.utils.entry_point import ( + load_entry_point, +) + +load_entry_point("deepmd.jax") diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py index 3ed096f9c1..cabee5a189 100644 --- a/deepmd/jax/descriptor/__init__.py +++ b/deepmd/jax/descriptor/__init__.py @@ -2,6 +2,9 @@ from deepmd.jax.descriptor.dpa1 import ( DescrptDPA1, ) +from deepmd.jax.descriptor.hybrid import ( + DescrptHybrid, +) from deepmd.jax.descriptor.se_e2_a import ( DescrptSeA, ) @@ -13,4 +16,5 @@ "DescrptSeA", "DescrptSeR", "DescrptDPA1", + "DescrptHybrid", ] diff --git a/deepmd/jax/descriptor/hybrid.py b/deepmd/jax/descriptor/hybrid.py new file mode 100644 index 0000000000..20fc5f838b --- /dev/null +++ b/deepmd/jax/descriptor/hybrid.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("hybrid") +@flax_module +class DescrptHybrid(DescrptHybridDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"nlist_cut_idx"}: + value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value] + elif name in {"descrpt_list"}: + value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value] + + return super().__setattr__(name, value) diff --git a/deepmd/jax/fitting/__init__.py b/deepmd/jax/fitting/__init__.py index e72314dcab..226a6d5b43 100644 --- a/deepmd/jax/fitting/__init__.py +++ b/deepmd/jax/fitting/__init__.py @@ -1,10 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from deepmd.jax.fitting.fitting import ( + DipoleFittingNet, DOSFittingNet, EnergyFittingNet, + PolarFittingNet, ) __all__ = [ "EnergyFittingNet", "DOSFittingNet", + "DipoleFittingNet", + "PolarFittingNet", ] diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index cef1f667b3..d62681490c 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -3,8 +3,15 @@ Any, ) +from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.dpmodel.fitting.polarizability_fitting import ( + PolarFitting as PolarFittingNetDP, +) +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from deepmd.jax.common import ( ArrayAPIVariable, flax_module, @@ -47,9 +54,40 @@ def 
__setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseFitting.register("property") +@flax_module +class PropertyFittingNet(PropertyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + @BaseFitting.register("dos") @flax_module class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) return super().__setattr__(name, value) + + +@BaseFitting.register("dipole") +@flax_module +class DipoleFittingNet(DipoleFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + +@BaseFitting.register("polar") +@flax_module +class PolarFittingNet(PolarFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + if name in { + "scale", + "constant_matrix", + }: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + return super().__setattr__(name, value) diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 76f044a327..c1967fb0da 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -214,8 +214,6 @@ def eval( The output of the evaluation. The keys are the names of the output variables, and the values are the corresponding output arrays. """ - if fparam is not None or aparam is not None: - raise NotImplementedError # convert all of the input to numpy array atom_types = np.array(atom_types, dtype=np.int32) coords = np.array(coords) @@ -226,7 +224,7 @@ def eval( ) request_defs = self._get_request_defs(atomic) out = self._eval_func(self._eval_model, numb_test, natoms)( - coords, cells, atom_types, request_defs + coords, cells, atom_types, fparam, aparam, request_defs ) return dict( zip( @@ -316,6 +314,8 @@ def _eval_model( coords: np.ndarray, cells: Optional[np.ndarray], atom_types: np.ndarray, + fparam: Optional[np.ndarray], + aparam: Optional[np.ndarray], request_defs: list[OutputVariableDef], ): model = self.dp @@ -333,6 +333,14 @@ def _eval_model( box_input = cells.reshape([-1, 3, 3]) else: box_input = None + if fparam is not None: + fparam_input = fparam.reshape(nframes, self.get_dim_fparam()) + else: + fparam_input = None + if aparam is not None: + aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam()) + else: + aparam_input = None do_atomic_virial = any( x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs @@ -341,6 +349,8 @@ def _eval_model( to_jax_array(coord_input), to_jax_array(type_input), box=to_jax_array(box_input), + fparam=to_jax_array(fparam_input), + aparam=to_jax_array(aparam_input), do_atomic_virial=do_atomic_virial, ) if isinstance(batch_output, tuple): diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index fcfcc8a610..a7d57523e2 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -51,18 +51,16 @@ def deserialize_to_file(model_file: str, data: dict) -> None: model_def_script = data["model_def_script"] call_lower = model.call_lower - nf, nloc, nghost, nfp, nap = jax_export.symbolic_shape( - "nf, nloc, nghost, nfp, nap" - ) + nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost") exported = jax_export.export(jax.jit(call_lower))( jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # 
extended_coord jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping - jax.ShapeDtypeStruct((nf, nfp), jnp.float64) + jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64) if model.get_dim_fparam() else None, # fparam - jax.ShapeDtypeStruct((nf, nap), jnp.float64) + jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64) if model.get_dim_aparam() else None, # aparam False, # do_atomic_virial diff --git a/deepmd/pt/model/model/dp_linear_model.py b/deepmd/pt/model/model/dp_linear_model.py index d19070fc5b..4028d77228 100644 --- a/deepmd/pt/model/model/dp_linear_model.py +++ b/deepmd/pt/model/model/dp_linear_model.py @@ -30,7 +30,7 @@ @BaseModel.register("linear_ener") class LinearEnergyModel(DPLinearModel_): - model_type = "ener" + model_type = "linear_ener" def __init__( self, diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index e1ef00f5fe..0f05e3e56d 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -30,7 +30,7 @@ @BaseModel.register("zbl") class DPZBLModel(DPZBLModel_): - model_type = "ener" + model_type = "zbl" def __init__( self, diff --git a/deepmd/tf/utils/batch_size.py b/deepmd/tf/utils/batch_size.py index 33f1ec0da0..438bf36703 100644 --- a/deepmd/tf/utils/batch_size.py +++ b/deepmd/tf/utils/batch_size.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import os + from packaging.version import ( Version, ) @@ -11,9 +13,23 @@ OutOfMemoryError, ) from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase +from deepmd.utils.batch_size import ( + log, +) class AutoBatchSize(AutoBatchSizeBase): + def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None: + super().__init__(initial_batch_size, factor) + DP_INFER_BATCH_SIZE = int(os.environ.get("DP_INFER_BATCH_SIZE", 0)) + if not DP_INFER_BATCH_SIZE > 0: + if self.is_gpu_available(): + log.info( + "If you encounter the error 'an illegal memory access was encountered', this may be due to a TensorFlow issue. " + "To avoid this, set the environment variable DP_INFER_BATCH_SIZE to a smaller value than the last adjusted batch size. " + "The environment variable DP_INFER_BATCH_SIZE controls the inference batch size (nframes * natoms). " + ) + def is_gpu_available(self) -> bool: """Check if GPU is available. diff --git a/deepmd/utils/batch_size.py b/deepmd/utils/batch_size.py index 259fe93bdb..5ab06e55e2 100644 --- a/deepmd/utils/batch_size.py +++ b/deepmd/utils/batch_size.py @@ -61,11 +61,6 @@ def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None: self.maximum_working_batch_size = initial_batch_size if self.is_gpu_available(): self.minimal_not_working_batch_size = 2**31 - log.info( - "If you encounter the error 'an illegal memory access was encountered', this may be due to a TensorFlow issue. " - "To avoid this, set the environment variable DP_INFER_BATCH_SIZE to a smaller value than the last adjusted batch size. " - "The environment variable DP_INFER_BATCH_SIZE controls the inference batch size (nframes * natoms). " - ) else: self.minimal_not_working_batch_size = ( self.maximum_working_batch_size + 1 diff --git a/doc/model/dplr.md b/doc/model/dplr.md index 91c2251346..cf071d4029 100644 --- a/doc/model/dplr.md +++ b/doc/model/dplr.md @@ -198,7 +198,7 @@ fix ID group-ID style_name keyword value ... 
- three or more keyword/value pairs may be appended ``` -keyword = *model* or *type_associate* or *bond_type* or *efield* +keyword = *model* or *type_associate* or *bond_type* or *efield* or *pair_deepmd_index* *model* value = name name = name of DPLR model file (e.g. frozen_model.pb) (not DW model) *type_associate* values = NR1 NW1 NR2 NW2 ... @@ -208,6 +208,8 @@ keyword = *model* or *type_associate* or *bond_type* or *efield* NBi = bond type of i-th (real atom, Wannier centroid) pair *efield* (optional) values = Ex Ey Ez Ex/Ey/Ez = electric field along x/y/z direction + *pair_deepmd_index* (optional) values = idx + idx = the index of the pair_style deepmd, starting from 1, if more than one is used ``` **Examples** @@ -223,6 +225,8 @@ fix_modify 0 virial yes ``` The fix command `dplr` calculates the position of WCs by the DW model and back-propagates the long-range interaction on virtual atoms to real toms. +The fix command must be used after [pair_style `deepmd`](../third-party/lammps-command.md#pair_style-deepmd). +If more than one pair_style `deepmd` is used, `pair_deepmd_index` (starting from 1) must be set to specify which pair_style `deepmd` this fix is associated with. The atom names specified in [pair_style `deepmd`](../third-party/lammps-command.md#pair_style-deepmd) will be used to determine elements. If it is not set, the training parameter {ref}`type_map ` will be mapped to LAMMPS atom types. diff --git a/doc/model/train-fitting-tensor.md b/doc/model/train-fitting-tensor.md index c6b54c69ef..d4d546eccf 100644 --- a/doc/model/train-fitting-tensor.md +++ b/doc/model/train-fitting-tensor.md @@ -1,7 +1,7 @@ -# Fit `tensor` like `Dipole` and `Polarizability` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +# Fit `tensor` like `Dipole` and `Polarizability` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: Unlike `energy`, which is a scalar, one may want to fit some high dimensional physical quantity, like `dipole` (vector) and `polarizability` (matrix, shorted as `polar`). Deep Potential has provided different APIs to do this. In this example, we will show you how to train a model to fit a water system. A complete training input script of the examples can be found in diff --git a/doc/model/train-hybrid.md b/doc/model/train-hybrid.md index 1219d208a7..da3b40487b 100644 --- a/doc/model/train-hybrid.md +++ b/doc/model/train-hybrid.md @@ -1,7 +1,7 @@ -# Descriptor `"hybrid"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +# Descriptor `"hybrid"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: This descriptor hybridizes multiple descriptors to form a new descriptor. For example, we have a list of descriptors denoted by $\mathcal D_1$, $\mathcal D_2$, ..., $\mathcal D_N$, the hybrid descriptor this the concatenation of the list, i.e. $\mathcal D = (\mathcal D_1, \mathcal D_2, \cdots, \mathcal D_N)$.
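The `train-hybrid.md` text above describes the hybrid descriptor as the concatenation $\mathcal D = (\mathcal D_1, \cdots, \mathcal D_N)$, which is exactly the contract the earlier `hybrid.py` change keeps when it swaps `np.concatenate` for `xp.concat(out_descriptor, axis=-1)`. A minimal NumPy sketch of that behaviour follows; the frame count, atom count, and sub-descriptor widths are made up for illustration and are not part of the patch.

```python
# Illustrative sketch only -- nframes, nloc and the sub-descriptor widths are arbitrary.
import numpy as np

nframes, nloc = 2, 6
d1 = np.random.rand(nframes, nloc, 8)   # per-atom output of the first sub-descriptor
d2 = np.random.rand(nframes, nloc, 12)  # per-atom output of the second sub-descriptor

# The hybrid descriptor concatenates along the feature axis: D = (D_1, D_2).
hybrid = np.concatenate([d1, d2], axis=-1)
assert hybrid.shape == (nframes, nloc, 8 + 12)
```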
diff --git a/source/lmp/deepmd_version.h.in b/source/lmp/deepmd_version.h.in index 4b99bc7c33..0b74491778 100644 --- a/source/lmp/deepmd_version.h.in +++ b/source/lmp/deepmd_version.h.in @@ -3,8 +3,8 @@ #define GIT_BRANCH @GIT_BRANCH@ #define GIT_DATE @GIT_DATE@ #define DEEPMD_ROOT @CMAKE_INSTALL_PREFIX@ -#define TensorFlow_INCLUDE_DIRS @TensorFlow_INCLUDE_DIRS@ -#define TensorFlow_LIBRARY @TensorFlow_LIBRARY@ +#define BACKEND_INCLUDE_DIRS @BACKEND_INCLUDE_DIRS@ +#define BACKEND_LIBRARY_PATH @BACKEND_LIBRARY_PATH@ #define DPMD_CVT_STR(...) #__VA_ARGS__ #define DPMD_CVT_ASSTR(X) DPMD_CVT_STR(X) #define STR_GIT_SUMM DPMD_CVT_ASSTR(GIT_SUMM) @@ -13,5 +13,5 @@ #define STR_GIT_DATE DPMD_CVT_ASSTR(GIT_DATE) #define STR_FLOAT_PREC DPMD_CVT_ASSTR(FLOAT_PREC) #define STR_DEEPMD_ROOT DPMD_CVT_ASSTR(DEEPMD_ROOT) -#define STR_TensorFlow_INCLUDE_DIRS DPMD_CVT_ASSTR(TensorFlow_INCLUDE_DIRS) -#define STR_TensorFlow_LIBRARY DPMD_CVT_ASSTR(TensorFlow_LIBRARY) +#define STR_BACKEND_INCLUDE_DIRS DPMD_CVT_ASSTR(BACKEND_INCLUDE_DIRS) +#define STR_BACKEND_LIBRARY_PATH DPMD_CVT_ASSTR(BACKEND_LIBRARY_PATH) diff --git a/source/lmp/fix_dplr.cpp b/source/lmp/fix_dplr.cpp index 8a6be7d840..34fd2515ed 100644 --- a/source/lmp/fix_dplr.cpp +++ b/source/lmp/fix_dplr.cpp @@ -62,6 +62,7 @@ FixDPLR::FixDPLR(LAMMPS *lmp, int narg, char **arg) size_vector = 3; qe2f = force->qe2f; xstyle = ystyle = zstyle = NONE; + pair_deepmd_index = 0; if (strcmp(update->unit_style, "lj") == 0) { error->all(FLERR, @@ -125,6 +126,12 @@ FixDPLR::FixDPLR(LAMMPS *lmp, int narg, char **arg) } sort(bond_type.begin(), bond_type.end()); iarg = iend; + } else if (string(arg[iarg]) == string("pair_deepmd_index")) { + if (iarg + 1 >= narg) { + error->all(FLERR, "Illegal pair_deepmd_index, not provided"); + } + pair_deepmd_index = atoi(arg[iarg + 1]); + iarg += 2; } else { break; } @@ -141,7 +148,7 @@ FixDPLR::FixDPLR(LAMMPS *lmp, int narg, char **arg) error->one(FLERR, e.what()); } - pair_deepmd = (PairDeepMD *)force->pair_match("deepmd", 1); + pair_deepmd = (PairDeepMD *)force->pair_match("deepmd", 1, pair_deepmd_index); if (!pair_deepmd) { error->all(FLERR, "pair_style deepmd should be set before this fix\n"); } diff --git a/source/lmp/fix_dplr.h b/source/lmp/fix_dplr.h index a6822fe4fe..c43296e611 100644 --- a/source/lmp/fix_dplr.h +++ b/source/lmp/fix_dplr.h @@ -80,6 +80,9 @@ class FixDPLR : public Fix { void update_efield_variables(); enum { NONE, CONSTANT, EQUAL }; std::vector type_idx_map; + /* The index of deepmd pair index, which starts from 1. By default 0, which + * works only when there is one deepmd pair. 
*/ + int pair_deepmd_index; }; } // namespace LAMMPS_NS diff --git a/source/lmp/pair_deepmd.cpp b/source/lmp/pair_deepmd.cpp index 09d97fe460..d741814aa5 100644 --- a/source/lmp/pair_deepmd.cpp +++ b/source/lmp/pair_deepmd.cpp @@ -437,10 +437,8 @@ void PairDeepMD::print_summary(const string pre) const { cout << pre << "source branch: " << STR_GIT_BRANCH << endl; cout << pre << "source commit: " << STR_GIT_HASH << endl; cout << pre << "source commit at: " << STR_GIT_DATE << endl; - cout << pre << "build float prec: " << STR_FLOAT_PREC << endl; - cout << pre << "build with tf inc: " << STR_TensorFlow_INCLUDE_DIRS - << endl; - cout << pre << "build with tf lib: " << STR_TensorFlow_LIBRARY << endl; + cout << pre << "build with inc: " << STR_BACKEND_INCLUDE_DIRS << endl; + cout << pre << "build with lib: " << STR_BACKEND_LIBRARY_PATH << endl; std::cout.rdbuf(sbuf); utils::logmesg(lmp, buffer.str()); diff --git a/source/tests/array_api_strict/descriptor/__init__.py b/source/tests/array_api_strict/descriptor/__init__.py index 6ceb116d85..5667fed858 100644 --- a/source/tests/array_api_strict/descriptor/__init__.py +++ b/source/tests/array_api_strict/descriptor/__init__.py @@ -1 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .dpa1 import ( + DescrptDPA1, +) +from .hybrid import ( + DescrptHybrid, +) +from .se_e2_a import ( + DescrptSeA, +) +from .se_e2_r import ( + DescrptSeR, +) + +__all__ = [ + "DescrptSeA", + "DescrptSeR", + "DescrptDPA1", + "DescrptHybrid", +] diff --git a/source/tests/array_api_strict/descriptor/base_descriptor.py b/source/tests/array_api_strict/descriptor/base_descriptor.py new file mode 100644 index 0000000000..2a31895f55 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/base_descriptor.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.make_base_descriptor import ( + make_base_descriptor, +) + +# no type annotations standard in array api +BaseDescriptor = make_base_descriptor(Any) diff --git a/source/tests/array_api_strict/descriptor/dpa1.py b/source/tests/array_api_strict/descriptor/dpa1.py index ebd688e303..d14444f269 100644 --- a/source/tests/array_api_strict/descriptor/dpa1.py +++ b/source/tests/array_api_strict/descriptor/dpa1.py @@ -27,6 +27,9 @@ from ..utils.type_embed import ( TypeEmbedNet, ) +from .base_descriptor import ( + BaseDescriptor, +) class GatedAttentionLayer(GatedAttentionLayerDP): @@ -72,6 +75,8 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseDescriptor.register("dpa1") +@BaseDescriptor.register("se_atten") class DescrptDPA1(DescrptDPA1DP): def __setattr__(self, name: str, value: Any) -> None: if name == "se_atten": diff --git a/source/tests/array_api_strict/descriptor/hybrid.py b/source/tests/array_api_strict/descriptor/hybrid.py new file mode 100644 index 0000000000..aaaa24ed6b --- /dev/null +++ b/source/tests/array_api_strict/descriptor/hybrid.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP + +from ..common import ( + to_array_api_strict_array, +) +from .base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("hybrid") +class DescrptHybrid(DescrptHybridDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"nlist_cut_idx"}: + value = [to_array_api_strict_array(vv) for vv in value] + elif name in {"descrpt_list"}: + value = 
[BaseDescriptor.deserialize(vv.serialize()) for vv in value] + + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/descriptor/se_e2_a.py b/source/tests/array_api_strict/descriptor/se_e2_a.py index 654b9f8925..17da2aafbf 100644 --- a/source/tests/array_api_strict/descriptor/se_e2_a.py +++ b/source/tests/array_api_strict/descriptor/se_e2_a.py @@ -14,8 +14,13 @@ from ..utils.network import ( NetworkCollection, ) +from .base_descriptor import ( + BaseDescriptor, +) +@BaseDescriptor.register("se_e2_a") +@BaseDescriptor.register("se_a") class DescrptSeA(DescrptSeADP): def __setattr__(self, name: str, value: Any) -> None: if name in {"dstd", "davg"}: diff --git a/source/tests/array_api_strict/descriptor/se_e2_r.py b/source/tests/array_api_strict/descriptor/se_e2_r.py index 839e536cea..b499f4c4c9 100644 --- a/source/tests/array_api_strict/descriptor/se_e2_r.py +++ b/source/tests/array_api_strict/descriptor/se_e2_r.py @@ -14,8 +14,13 @@ from ..utils.network import ( NetworkCollection, ) +from .base_descriptor import ( + BaseDescriptor, +) +@BaseDescriptor.register("se_e2_r") +@BaseDescriptor.register("se_r") class DescrptSeR(DescrptSeRDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"dstd", "davg"}: diff --git a/source/tests/array_api_strict/fitting/fitting.py b/source/tests/array_api_strict/fitting/fitting.py index 8b65320203..323a49cfe8 100644 --- a/source/tests/array_api_strict/fitting/fitting.py +++ b/source/tests/array_api_strict/fitting/fitting.py @@ -3,8 +3,15 @@ Any, ) +from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.dpmodel.fitting.polarizability_fitting import ( + PolarFitting as PolarFittingNetDP, +) +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from ..common import ( to_array_api_strict_array, @@ -39,7 +46,30 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +class PropertyFittingNet(PropertyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) return super().__setattr__(name, value) + + +class DipoleFittingNet(DipoleFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + +class PolarFittingNet(PolarFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + if name in { + "scale", + "constant_matrix", + }: + value = to_array_api_strict_array(value) + return super().__setattr__(name, value) diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index bcad7c4502..734486becb 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -75,7 +75,7 @@ class CommonTest(ABC): data: ClassVar[dict] """Arguments data.""" - addtional_data: ClassVar[dict] = {} + additional_data: ClassVar[dict] = {} """Additional data that will not be checked.""" tf_class: ClassVar[Optional[type]] """TensorFlow model class.""" @@ -128,7 +128,7 @@ def 
init_backend_cls(self, cls) -> Any: def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" - return cls(**data, **self.addtional_data) + return cls(**data, **self.additional_data) @abstractmethod def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: diff --git a/source/tests/consistent/descriptor/test_hybrid.py b/source/tests/consistent/descriptor/test_hybrid.py index cd52eea5be..c43652b498 100644 --- a/source/tests/consistent/descriptor/test_hybrid.py +++ b/source/tests/consistent/descriptor/test_hybrid.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -28,6 +30,16 @@ from deepmd.tf.descriptor.hybrid import DescrptHybrid as DescrptHybridTF else: DescrptHybridTF = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.hybrid import DescrptHybrid as DescrptHybridJAX +else: + DescrptHybridJAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.hybrid import ( + DescrptHybrid as DescrptHybridStrict, + ) +else: + DescrptHybridStrict = None from deepmd.utils.argcheck import ( descrpt_hybrid_args, ) @@ -68,8 +80,13 @@ def data(self) -> dict: tf_class = DescrptHybridTF dp_class = DescrptHybridDP pt_class = DescrptHybridPT + jax_class = DescrptHybridJAX + array_api_strict_class = DescrptHybridStrict args = descrpt_hybrid_args() + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + def setUp(self): CommonTest.setUp(self) @@ -132,5 +149,23 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],) diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 5d7be1b0e5..60ee7322c1 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -32,6 +34,21 @@ from deepmd.tf.fit.dipole import DipoleFittingSeA as DipoleFittingTF else: DipoleFittingTF = object +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import DipoleFittingNet as DipoleFittingJAX +else: + DipoleFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + DipoleFittingNet as DipoleFittingArrayAPIStrict, + ) +else: + DipoleFittingArrayAPIStrict = object from deepmd.utils.argcheck import ( fitting_dipole, ) @@ -69,7 +86,11 @@ def skip_pt(self) -> bool: tf_class = DipoleFittingTF dp_class = DipoleFittingDP pt_class = DipoleFittingPT + jax_class = DipoleFittingJAX + array_api_strict_class = DipoleFittingArrayAPIStrict args = fitting_dipole() + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) @@ -83,7 +104,7 @@ def setUp(self): self.atype.sort() @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, @@ -143,6 +164,26 @@ def eval_dp(self, dp_obj: Any) -> Any: None, 
)["dipole"] + def eval_jax(self, jax_obj: Any) -> Any: + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + jnp.asarray(self.gr), + None, + )["dipole"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + array_api_strict.asarray(self.gr), + None, + )["dipole"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index 774e3f655e..d3de3ef151 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -124,7 +124,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index e32410a0ec..f4e78ce966 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -134,7 +134,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 6a3465ba24..bd9d013b8d 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -32,6 +34,21 @@ from deepmd.tf.fit.polar import PolarFittingSeA as PolarFittingTF else: PolarFittingTF = object +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import PolarFittingNet as PolarFittingJAX +else: + PolarFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + PolarFittingNet as PolarFittingArrayAPIStrict, + ) +else: + PolarFittingArrayAPIStrict = object from deepmd.utils.argcheck import ( fitting_polar, ) @@ -69,7 +86,11 @@ def skip_pt(self) -> bool: tf_class = PolarFittingTF dp_class = PolarFittingDP pt_class = PolarFittingPT + jax_class = PolarFittingJAX + array_api_strict_class = PolarFittingArrayAPIStrict args = fitting_polar() + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) @@ -83,7 +104,7 @@ def setUp(self): self.atype.sort() @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, @@ -143,6 +164,26 @@ def eval_dp(self, dp_obj: Any) -> Any: None, )["polarizability"] + def eval_jax(self, jax_obj: Any) -> Any: + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + jnp.asarray(self.gr), + None, + )["polarizability"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + array_api_strict.asarray(self.gr), + None, + )["polarizability"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same 
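The dipole and polarizability fittings exercised by the consistency tests above no longer go through `np.einsum`; the patch rewrites those batched contractions as plain matrix products so that the JAX and array-api-strict backends can execute them. The following NumPy check of the identities being relied on is a sketch only, with arbitrary batch size and embedding width, not part of the patch.

```python
# Sketch: verify that the removed einsum calls equal batched matrix products.
import numpy as np

rng = np.random.default_rng(0)
b, m = 4, 5                   # b = nframes * nloc, m = embedding width m1
out = rng.random((b, m, m))   # symmetrized polarizability fitting output
gr = rng.random((b, m, 3))    # rotation matrix, shape (nframes * nloc, m1, 3)

step1 = out @ gr              # replaces np.einsum("bim,bmj->bij", out, gr)
assert np.allclose(step1, np.einsum("bim,bmj->bij", out, gr))

# replaces np.einsum("bim,bmj->bij", np.transpose(gr, (0, 2, 1)), step1)
step2 = np.swapaxes(gr, -1, -2) @ step1   # xp.matrix_transpose(gr) @ out in the patch
assert np.allclose(
    step2, np.einsum("bim,bmj->bij", np.transpose(gr, (0, 2, 1)), step1)
)
assert step2.shape == (b, 3, 3)
```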
diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index beb21d9c04..a096d4dd68 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -17,6 +17,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, CommonTest, parameterized, @@ -32,6 +34,22 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: PropertyFittingPT = object +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import PropertyFittingNet as PropertyFittingJAX +else: + PropertyFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + PropertyFittingNet as PropertyFittingStrict, + ) +else: + PropertyFittingStrict = object + PropertyFittingTF = object @@ -84,9 +102,14 @@ def skip_pt(self) -> bool: def skip_tf(self) -> bool: return True + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + tf_class = PropertyFittingTF dp_class = PropertyFittingDP pt_class = PropertyFittingPT + jax_class = PropertyFittingJAX + array_api_strict_class = PropertyFittingStrict args = fitting_property() def setUp(self): @@ -104,7 +127,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, @@ -183,6 +206,45 @@ def eval_dp(self, dp_obj: Any) -> Any: aparam=self.aparam if numb_aparam else None, )["property"] + def eval_jax(self, jax_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + task_dim, + intensive, + ) = self.param + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + fparam=jnp.asarray(self.fparam) if numb_fparam else None, + aparam=jnp.asarray(self.aparam) if numb_aparam else None, + )["property"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + array_api_strict.set_array_api_strict_flags(api_version="2023.12") + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + task_dim, + intensive, + ) = self.param + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None, + aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None, + )["property"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index dc0f280d56..af26c41694 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -136,6 +136,8 @@ def test_deep_eval(self): [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], dtype=GLOBAL_NP_FLOAT_PRECISION, ).reshape(1, 9) + natoms = self.atype.shape[1] + nframes = self.atype.shape[0] prefix = "test_consistent_io_" + self.__class__.__name__.lower() rets = [] for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"): @@ -145,10 +147,20 @@ def test_deep_eval(self): reference_data = copy.deepcopy(self.data) self.save_data_to_model(prefix + backend.suffixes[0], reference_data) deep_eval = DeepEval(prefix + backend.suffixes[0]) + if deep_eval.get_dim_fparam() > 0: + fparam = np.ones((nframes, deep_eval.get_dim_fparam())) + else: + fparam = 
None + if deep_eval.get_dim_aparam() > 0: + aparam = np.ones((nframes, natoms, deep_eval.get_dim_aparam())) + else: + aparam = None ret = deep_eval.eval( self.coords, self.box, self.atype, + fparam=fparam, + aparam=aparam, ) rets.append(ret) for ret in rets[1:]: @@ -199,3 +211,47 @@ def setUp(self): def tearDown(self): IOTest.tearDown(self) + + +class TestDeepPotFparamAparam(unittest.TestCase, IOTest): + def setUp(self): + model_def_script = { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 3, + 6, + ], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "type": "ener", + "neuron": [ + 5, + 5, + ], + "resnet_dt": True, + "precision": "float64", + "atom_ener": [], + "seed": 1, + "numb_fparam": 2, + "numb_aparam": 2, + }, + } + model = get_model(copy.deepcopy(model_def_script)) + self.data = { + "model": model.serialize(), + "backend": "test", + "model_def_script": model_def_script, + } + + def tearDown(self): + IOTest.tearDown(self) diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 2a358ba7e0..98330ba849 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -130,7 +130,7 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_pt(data) elif cls is EnergyModelJAX: return get_model_jax(data) - return cls(**data, **self.addtional_data) + return cls(**data, **self.additional_data) def setUp(self): CommonTest.setUp(self) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py new file mode 100644 index 0000000000..f37bee0c90 --- /dev/null +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + SKIP_FLAG, + CommonTest, + parameterized, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.dp_zbl_model import DPZBLModel as DPZBLModelPT +else: + DPZBLModelPT = None +import os + +from deepmd.utils.argcheck import ( + model_args, +) + +TESTS_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + + +@parameterized( + ( + [], + [[0, 1]], + ), + ( + [], + [1], + ), +) +class TestEner(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + pair_exclude_types, atom_exclude_types = self.param + return { + "type_map": ["O", "H", "B"], + "use_srtab": f"{TESTS_DIR}/pt/water/data/zbl_tab_potential/H2O_tab_potential.txt", + "smin_alpha": 0.1, + "sw_rmin": 0.2, + "sw_rmax": 4.0, + "pair_exclude_types": pair_exclude_types, + "atom_exclude_types": atom_exclude_types, + "descriptor": { + "type": "se_atten", + "sel": 40, + "rcut_smth": 0.5, + "rcut": 4.0, + "neuron": [3, 6], + "axis_neuron": 2, + "attn": 8, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": False, + "temperature": 1.0, + "set_davg_zero": True, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [5, 5], + 
"resnet_dt": True, + "seed": 1, + }, + } + + dp_class = DPZBLModelDP + pt_class = DPZBLModelPT + args = model_args() + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_tf: + return self.RefBackend.TF + if not self.skip_jax: + return self.RefBackend.JAX + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + @property + def skip_tf(self): + return True + + @property + def skip_jax(self): + return True + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = data.copy() + if cls is DPZBLModelDP: + return get_model_dp(data) + elif cls is DPZBLModelPT: + return get_model_pt(data) + return cls(**data, **self.additional_data) + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + # TF requires the atype to be sort + idx_map = np.argsort(self.atype.ravel()) + self.atype = self.atype[:, idx_map] + self.coords = self.coords[:, idx_map] + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + return self.build_tf_model( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_model( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_model( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... 
+ if backend is self.RefBackend.DP: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + SKIP_FLAG, + SKIP_FLAG, + ) + elif backend is self.RefBackend.PT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["force"].ravel(), + ret["virial"].ravel(), + ) + elif backend is self.RefBackend.TF: + return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel()) + elif backend is self.RefBackend.JAX: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + ret["energy_derv_r"].ravel(), + ret["energy_derv_c_redu"].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index a4b516ef16..0dd17c841e 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -82,7 +82,7 @@ def data(self) -> dict: skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/pt/test_dp_test.py b/source/tests/pt/test_dp_test.py index c18c3286f6..0427f2b14a 100644 --- a/source/tests/pt/test_dp_test.py +++ b/source/tests/pt/test_dp_test.py @@ -152,6 +152,9 @@ def setUp(self): self.config["training"]["training_data"]["systems"] = data_file self.config["training"]["validation_data"]["systems"] = data_file self.config["model"] = deepcopy(model_property) + self.config["model"]["type_map"] = [ + self.config["model"]["type_map"][i] for i in [1, 0, 3, 2] + ] self.input_json = "test_dp_test_property.json" with open(self.input_json, "w") as fp: json.dump(self.config, fp, indent=4)
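With `fparam` and `aparam` now threaded through `_eval_model` in the dpmodel and JAX `DeepEval` backends, frame and atomic parameters can be supplied at inference time in the same way the new `TestDeepPotFparamAparam` IO test does. The usage sketch below is illustrative only: the import path and model file name are assumptions, the parameter values are placeholders, and the shapes mirror the reshapes added above (`fparam`: `nframes x dim_fparam`, `aparam`: `nframes x natoms x dim_aparam`).

```python
# Hypothetical usage sketch -- "model_file" is a placeholder for a serialized model.
import numpy as np

from deepmd.infer.deep_eval import DeepEval  # import path assumed from the test suite

dp = DeepEval("model_file")
nframes, natoms = 1, 6
coords = np.random.rand(nframes, natoms * 3)
box = np.diag([13.0, 13.0, 13.0]).reshape(1, 9)
atype = np.array([[0, 1, 1, 0, 1, 1]], dtype=np.int32)

fparam = np.ones((nframes, dp.get_dim_fparam())) if dp.get_dim_fparam() else None
aparam = (
    np.ones((nframes, natoms, dp.get_dim_aparam())) if dp.get_dim_aparam() else None
)
# Output structure depends on the requested variables and the model type.
results = dp.eval(coords, box, atype, fparam=fparam, aparam=aparam)
```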