diff --git a/deepmd/dpmodel/__init__.py b/deepmd/dpmodel/__init__.py new file mode 100644 index 0000000000..5a83bb7bd4 --- /dev/null +++ b/deepmd/dpmodel/__init__.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .common import ( + DEFAULT_PRECISION, + PRECISION_DICT, + NativeOP, +) +from .model import ( + DPAtomicModel, + DPModel, +) +from .output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, + fitting_check_output, + get_deriv_name, + get_reduce_name, + model_check_output, +) + +__all__ = [ + "DPModel", + "DPAtomicModel", + "PRECISION_DICT", + "DEFAULT_PRECISION", + "NativeOP", + "ModelOutputDef", + "FittingOutputDef", + "OutputVariableDef", + "model_check_output", + "fitting_check_output", + "get_reduce_name", + "get_deriv_name", +] diff --git a/deepmd/model_format/common.py b/deepmd/dpmodel/common.py similarity index 85% rename from deepmd/model_format/common.py rename to deepmd/dpmodel/common.py index d032e5d5df..1e35bd4d49 100644 --- a/deepmd/model_format/common.py +++ b/deepmd/dpmodel/common.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from abc import ( ABC, + abstractmethod, ) import numpy as np @@ -12,6 +13,8 @@ "half": np.float16, "single": np.float32, "double": np.float64, + "int32": np.int32, + "int64": np.int64, } DEFAULT_PRECISION = "float64" @@ -19,9 +22,10 @@ class NativeOP(ABC): """The unit operation of a native model.""" + @abstractmethod def call(self, *args, **kwargs): """Forward pass in NumPy implementation.""" - raise NotImplementedError + pass def __call__(self, *args, **kwargs): """Forward pass in NumPy implementation.""" diff --git a/deepmd/dpmodel/descriptor/__init__.py b/deepmd/dpmodel/descriptor/__init__.py new file mode 100644 index 0000000000..5eca26acc5 --- /dev/null +++ b/deepmd/dpmodel/descriptor/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .make_base_descriptor import ( + make_base_descriptor, +) +from .se_e2_a import ( + DescrptSeA, +) + +__all__ = [ + "DescrptSeA", + "make_base_descriptor", +] diff --git a/deepmd/dpmodel/descriptor/base_descriptor.py b/deepmd/dpmodel/descriptor/base_descriptor.py new file mode 100644 index 0000000000..ca403d7f8e --- /dev/null +++ b/deepmd/dpmodel/descriptor/base_descriptor.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from .make_base_descriptor import ( + make_base_descriptor, +) + +BaseDescriptor = make_base_descriptor(np.ndarray, "call") diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py new file mode 100644 index 0000000000..2b0025af07 --- /dev/null +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractclassmethod, + abstractmethod, +) +from typing import ( + List, + Optional, +) + + +def make_base_descriptor( + t_tensor, + fwd_method_name: str = "forward", +): + """Make the base class for the descriptor. + + Parameters + ---------- + t_tensor + The type of the tensor. used in the type hint. + fwd_method_name + Name of the forward method. For dpmodels, it should be "call". + For torch models, it should be "forward". + + """ + + class BD(ABC): + """Base descriptor provides the interfaces of descriptor.""" + + @abstractmethod + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + pass + + @abstractmethod + def get_sel(self) -> List[int]: + """Returns the number of selected neighboring atoms for each type.""" + pass + + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return sum(self.get_sel()) + + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.get_nsel() + + @abstractmethod + def get_ntypes(self) -> int: + """Returns the number of element types.""" + pass + + @abstractmethod + def get_dim_out(self) -> int: + """Returns the output descriptor dimension.""" + pass + + @abstractmethod + def get_dim_emb(self) -> int: + """Returns the embedding dimension of g2.""" + pass + + @abstractmethod + def distinguish_types(self) -> bool: + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + pass + + @abstractmethod + def compute_input_stats(self, merged): + """Update mean and stddev for descriptor elements.""" + pass + + @abstractmethod + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + """Initialize the model bias by the statistics.""" + pass + + @abstractmethod + def fwd( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[t_tensor] = None, + ): + """Calculate descriptor.""" + pass + + @abstractmethod + def serialize(self) -> dict: + """Serialize the obj to dict.""" + pass + + @abstractclassmethod + def deserialize(cls): + """Deserialize from a dict.""" + pass + + setattr(BD, fwd_method_name, BD.fwd) + delattr(BD, "fwd") + + return BD diff --git a/deepmd/model_format/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py similarity index 88% rename from deepmd/model_format/se_e2_a.py rename to deepmd/dpmodel/descriptor/se_e2_a.py index f179b10ac3..1cbaf69c49 100644 --- a/deepmd/model_format/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -13,20 +13,22 @@ Optional, ) -from .common import ( +from deepmd.dpmodel import ( DEFAULT_PRECISION, NativeOP, ) -from .env_mat import ( - EnvMat, -) -from .network import ( +from deepmd.dpmodel.utils import ( EmbeddingNet, + EnvMat, NetworkCollection, ) +from .base_descriptor import ( + BaseDescriptor, +) + -class DescrptSeA(NativeOP): +class DescrptSeA(NativeOP, BaseDescriptor): r"""DeepPot-SE constructed from all information (both angular and radial) of atomic configurations. The embedding takes the distance between atoms as input. @@ -193,9 +195,43 @@ def __getitem__(self, key): @property def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.get_dim_out() + + def get_dim_out(self): """Returns the output dimension of this descriptor.""" return self.neuron[-1] * self.axis_neuron + def get_dim_emb(self): + """Returns the embedding (g2) dimension of this descriptor.""" + return self.neuron[-1] + + def get_rcut(self): + """Returns cutoff radius.""" + return self.rcut + + def get_sel(self): + """Returns cutoff radius.""" + return self.sel + + def distinguish_types(self): + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + return True + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def compute_input_stats(self, merged): + """Update mean and stddev for descriptor elements.""" + raise NotImplementedError + + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + """Initialize the model bias by the statistics.""" + raise NotImplementedError + def cal_g( self, ss, @@ -212,6 +248,7 @@ def call( coord_ext, atype_ext, nlist, + mapping: Optional[np.ndarray] = None, ): """Compute the descriptor. @@ -223,6 +260,8 @@ def call( The extended aotm types. shape: nf x nall nlist The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping from extended to lcoal region. not used by this descriptor. Returns ------- @@ -240,6 +279,7 @@ def call( sw The smooth switch function. """ + del mapping # nf x nloc x nnei x 4 rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd) nf, nloc, nnei, _ = rr.shape diff --git a/deepmd/dpmodel/fitting/__init__.py b/deepmd/dpmodel/fitting/__init__.py new file mode 100644 index 0000000000..2bd5e23f5b --- /dev/null +++ b/deepmd/dpmodel/fitting/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .invar_fitting import ( + InvarFitting, +) +from .make_base_fitting import ( + make_base_fitting, +) + +__all__ = [ + "InvarFitting", + "make_base_fitting", +] diff --git a/deepmd/dpmodel/fitting/base_fitting.py b/deepmd/dpmodel/fitting/base_fitting.py new file mode 100644 index 0000000000..bb1853a4a0 --- /dev/null +++ b/deepmd/dpmodel/fitting/base_fitting.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from .make_base_fitting import ( + make_base_fitting, +) + +BaseFitting = make_base_fitting(np.ndarray, fwd_method_name="call") diff --git a/deepmd/model_format/fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py similarity index 96% rename from deepmd/model_format/fitting.py rename to deepmd/dpmodel/fitting/invar_fitting.py index 904fb42b76..efe2771323 100644 --- a/deepmd/model_format/fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -2,30 +2,35 @@ import copy from typing import ( Any, + Dict, List, Optional, ) import numpy as np -from .common import ( +from deepmd.dpmodel import ( DEFAULT_PRECISION, NativeOP, ) -from .network import ( - FittingNet, - NetworkCollection, -) -from .output_def import ( +from deepmd.dpmodel.output_def import ( FittingOutputDef, OutputVariableDef, fitting_check_output, ) +from deepmd.dpmodel.utils import ( + FittingNet, + NetworkCollection, +) + +from .base_fitting import ( + BaseFitting, +) @fitting_check_output -class InvarFitting(NativeOP): - r"""Fitting the energy (or a porperty of `dim_out`) of the system. The force and the virial can also be trained. +class InvarFitting(NativeOP, BaseFitting): + r"""Fitting the energy (or a rotationally invariant porperty of `dim_out`) of the system. The force and the virial can also be trained. Lets take the energy fitting task as an example. The potential energy :math:`E` is a fitting network function of the descriptor :math:`\mathcal{D}`: @@ -279,7 +284,7 @@ def call( h2: Optional[np.array] = None, fparam: Optional[np.array] = None, aparam: Optional[np.array] = None, - ): + ) -> Dict[str, np.array]: """Calculate the fitting. Parameters @@ -320,7 +325,7 @@ def call( "which is not consistent with {self.numb_fparam}.", ) fparam = (fparam - self.fparam_avg) * self.fparam_inv_std - fparam = np.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1]) + fparam = np.tile(fparam.reshape([nf, 1, self.numb_fparam]), [1, nloc, 1]) xx = np.concatenate( [xx, fparam], axis=-1, @@ -333,6 +338,7 @@ def call( "get an input aparam of dim {aparam.shape[-1]}, ", "which is not consistent with {self.numb_aparam}.", ) + aparam = aparam.reshape([nf, nloc, self.numb_aparam]) aparam = (aparam - self.aparam_avg) * self.aparam_inv_std xx = np.concatenate( [xx, aparam], diff --git a/deepmd/dpmodel/fitting/make_base_fitting.py b/deepmd/dpmodel/fitting/make_base_fitting.py new file mode 100644 index 0000000000..719ac6169e --- /dev/null +++ b/deepmd/dpmodel/fitting/make_base_fitting.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractclassmethod, + abstractmethod, +) +from typing import ( + Dict, + Optional, +) + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, +) + + +def make_base_fitting( + t_tensor, + fwd_method_name: str = "forward", +): + """Make the base class for the fitting. + + Parameters + ---------- + t_tensor + The type of the tensor. used in the type hint. + fwd_method_name + Name of the forward method. For dpmodels, it should be "call". + For torch models, it should be "forward". + + """ + + class BF(ABC): + """Base fitting provides the interfaces of fitting net.""" + + @abstractmethod + def output_def(self) -> FittingOutputDef: + """Returns the output def of the fitting net.""" + pass + + @abstractmethod + def fwd( + self, + descriptor: t_tensor, + atype: t_tensor, + gr: Optional[t_tensor] = None, + g2: Optional[t_tensor] = None, + h2: Optional[t_tensor] = None, + fparam: Optional[t_tensor] = None, + aparam: Optional[t_tensor] = None, + ) -> Dict[str, t_tensor]: + """Calculate fitting.""" + pass + + @abstractmethod + def serialize(self) -> dict: + """Serialize the obj to dict.""" + pass + + @abstractclassmethod + def deserialize(cls): + """Deserialize from a dict.""" + pass + + setattr(BF, fwd_method_name, BF.fwd) + delattr(BF, "fwd") + + return BF diff --git a/deepmd/dpmodel/model/__init__.py b/deepmd/dpmodel/model/__init__.py new file mode 100644 index 0000000000..5c0a32673d --- /dev/null +++ b/deepmd/dpmodel/model/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .dp_atomic_model import ( + DPAtomicModel, +) +from .dp_model import ( + DPModel, +) +from .make_base_atomic_model import ( + make_base_atomic_model, +) + +__all__ = [ + "DPModel", + "DPAtomicModel", + "make_base_atomic_model", +] diff --git a/deepmd/dpmodel/model/base_atomic_model.py b/deepmd/dpmodel/model/base_atomic_model.py new file mode 100644 index 0000000000..b9521cde8e --- /dev/null +++ b/deepmd/dpmodel/model/base_atomic_model.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from .make_base_atomic_model import ( + make_base_atomic_model, +) + +BaseAtomicModel = make_base_atomic_model(np.ndarray) diff --git a/deepmd/dpmodel/model/dp_atomic_model.py b/deepmd/dpmodel/model/dp_atomic_model.py new file mode 100644 index 0000000000..63c44aa1f8 --- /dev/null +++ b/deepmd/dpmodel/model/dp_atomic_model.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import sys +from typing import ( + Dict, + List, + Optional, +) + +import numpy as np + +from deepmd.dpmodel.descriptor import ( # noqa # TODO: should import all descriptors! + DescrptSeA, +) +from deepmd.dpmodel.fitting import ( # noqa # TODO: should import all fittings! + InvarFitting, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, +) + +from .base_atomic_model import ( + BaseAtomicModel, +) + + +class DPAtomicModel(BaseAtomicModel): + """Model give atomic prediction of some physical property. + + Parameters + ---------- + descriptor + Descriptor + fitting_net + Fitting net + type_map + Mapping atom type to the name (str) of the type. + For example `type_map[1]` gives the name of the type 1. + + """ + + def __init__( + self, + descriptor, + fitting, + type_map: Optional[List[str]] = None, + ): + super().__init__() + self.type_map = type_map + self.descriptor = descriptor + self.fitting = fitting + + def fitting_output_def(self) -> FittingOutputDef: + """Get the output def of the fitting net.""" + return self.fitting.output_def() + + def get_rcut(self) -> float: + """Get the cut-off radius.""" + return self.descriptor.get_rcut() + + def get_sel(self) -> List[int]: + """Get the neighbor selection.""" + return self.descriptor.get_sel() + + def distinguish_types(self) -> bool: + """Returns if model requires a neighbor list that distinguish different + atomic types or not. + """ + return self.descriptor.distinguish_types() + + def forward_atomic( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlist: np.ndarray, + mapping: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + ) -> Dict[str, np.ndarray]: + """Models' atomic predictions. + + Parameters + ---------- + extended_coord + coodinates in extended region + extended_atype + atomic type in extended region + nlist + neighbor list. nf x nloc x nsel + mapping + mapps the extended indices to local indices. nf x nall + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + + Returns + ------- + result_dict + the result dict, defined by the `FittingOutputDef`. + + """ + nframes, nloc, nnei = nlist.shape + atype = extended_atype[:, :nloc] + descriptor, rot_mat, g2, h2, sw = self.descriptor( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + ret = self.fitting( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + return ret + + def serialize(self) -> dict: + return { + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + "fitting": self.fitting.serialize(), + "descriptor_name": self.descriptor.__class__.__name__, + "fitting_name": self.fitting.__class__.__name__, + } + + @classmethod + def deserialize(cls, data) -> "DPAtomicModel": + data = copy.deepcopy(data) + descriptor_obj = getattr( + sys.modules[__name__], data["descriptor_name"] + ).deserialize(data["descriptor"]) + fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize( + data["fitting"] + ) + obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"]) + return obj diff --git a/deepmd/dpmodel/model/dp_model.py b/deepmd/dpmodel/model/dp_model.py new file mode 100644 index 0000000000..819d46450e --- /dev/null +++ b/deepmd/dpmodel/model/dp_model.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .dp_atomic_model import ( + DPAtomicModel, +) +from .make_model import ( + make_model, +) + +DPModel = make_model(DPAtomicModel) diff --git a/deepmd/dpmodel/model/make_base_atomic_model.py b/deepmd/dpmodel/model/make_base_atomic_model.py new file mode 100644 index 0000000000..c057cd25f1 --- /dev/null +++ b/deepmd/dpmodel/model/make_base_atomic_model.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractclassmethod, + abstractmethod, +) +from typing import ( + Dict, + List, + Optional, +) + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, +) + + +def make_base_atomic_model( + t_tensor, + fwd_method_name: str = "forward_atomic", +): + """Make the base class for the atomic model. + + Parameters + ---------- + t_tensor + The type of the tensor. used in the type hint. + fwd_method_name + Name of the forward method. For dpmodels, it should be "call". + For torch models, it should be "forward". + + """ + + class BAM(ABC): + """Base Atomic Model provides the interfaces of an atomic model.""" + + @abstractmethod + def fitting_output_def(self) -> FittingOutputDef: + """Get the fitting output def.""" + pass + + @abstractmethod + def get_rcut(self) -> float: + """Get the cut-off radius.""" + pass + + @abstractmethod + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + pass + + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return sum(self.get_sel()) + + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.get_nsel() + + @abstractmethod + def distinguish_types(self) -> bool: + """Returns if the model requires a neighbor list that distinguish different + atomic types or not. + """ + pass + + @abstractmethod + def fwd( + self, + extended_coord: t_tensor, + extended_atype: t_tensor, + nlist: t_tensor, + mapping: Optional[t_tensor] = None, + fparam: Optional[t_tensor] = None, + aparam: Optional[t_tensor] = None, + ) -> Dict[str, t_tensor]: + pass + + @abstractmethod + def serialize(self) -> dict: + pass + + @abstractclassmethod + def deserialize(cls): + pass + + def do_grad( + self, + var_name: Optional[str] = None, + ) -> bool: + """Tell if the output variable `var_name` is differentiable. + if var_name is None, returns if any of the variable is differentiable. + + """ + odef = self.fitting_output_def() + if var_name is None: + require: List[bool] = [] + for vv in odef.keys(): + require.append(self.do_grad_(vv)) + return any(require) + else: + return self.do_grad_(var_name) + + def do_grad_( + self, + var_name: str, + ) -> bool: + """Tell if the output variable `var_name` is differentiable.""" + assert var_name is not None + return self.fitting_output_def()[var_name].differentiable + + setattr(BAM, fwd_method_name, BAM.fwd) + delattr(BAM, "fwd") + + return BAM diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py new file mode 100644 index 0000000000..fec04255fa --- /dev/null +++ b/deepmd/dpmodel/model/make_model.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + Optional, +) + +import numpy as np + +from deepmd.dpmodel.output_def import ( + ModelOutputDef, +) +from deepmd.dpmodel.utils import ( + build_neighbor_list, + extend_coord_with_ghosts, + nlist_distinguish_types, + normalize_coord, +) + +from .transform_output import ( + communicate_extended_output, + fit_output_to_model_output, +) + + +def make_model(T_AtomicModel): + """Make a model as a derived class of an atomic model. + + The model provide two interfaces. + + 1. the `call_lower`, that takes extended coordinates, atyps and neighbor list, + and outputs the atomic and property and derivatives (if required) on the extended region. + + 2. the `call`, that takes coordinates, atypes and cell and predicts + the atomic and reduced property, and derivatives (if required) on the local region. + + Parameters + ---------- + T_AtomicModel + The atomic model. + + Returns + ------- + CM + The model. + + """ + + class CM(T_AtomicModel): + def __init__( + self, + *args, + **kwargs, + ): + super().__init__( + *args, + **kwargs, + ) + + def model_output_def(self): + """Get the output def for the model.""" + return ModelOutputDef(self.fitting_output_def()) + + def call( + self, + coord, + atype, + box: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, np.ndarray]: + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type Dict[str,np.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + nframes, nloc = atype.shape[:2] + if box is not None: + coord_normalized = normalize_coord( + coord.reshape(nframes, nloc, 3), + box.reshape(nframes, 3, 3), + ) + else: + coord_normalized = coord.copy() + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, box, self.get_rcut() + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + self.get_rcut(), + self.get_sel(), + distinguish_types=self.distinguish_types(), + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + model_predict_lower = self.call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = communicate_extended_output( + model_predict_lower, + self.model_output_def(), + mapping, + do_atomic_virial=do_atomic_virial, + ) + return model_predict + + def call_lower( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlist: np.ndarray, + mapping: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + do_atomic_virial: bool = False, + ): + """Return model prediction. Lower interface that takes + extended atomic coordinates and types, nlist, and mapping + as input, and returns the predictions on the extended region. + The predictions are not reduced. + + Parameters + ---------- + extended_coord + coodinates in extended region + extended_atype + atomic type in extended region + nlist + neighbor list. nf x nloc x nsel + mapping + mapps the extended indices to local indices + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + whether calculate atomic virial + + Returns + ------- + result_dict + the result dict, defined by the `FittingOutputDef`. + + """ + nframes, nall = extended_atype.shape[:2] + extended_coord = extended_coord.reshape(nframes, -1, 3) + nlist = self.format_nlist(extended_coord, extended_atype, nlist) + atomic_ret = self.forward_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + model_predict = fit_output_to_model_output( + atomic_ret, + self.fitting_output_def(), + extended_coord, + do_atomic_virial=do_atomic_virial, + ) + return model_predict + + def format_nlist( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlist: np.ndarray, + ): + """Format the neighbor list. + + 1. If the number of neighbors in the `nlist` is equal to sum(self.sel), + it does nothong + + 2. If the number of neighbors in the `nlist` is smaller than sum(self.sel), + the `nlist` is pad with -1. + + 3. If the number of neighbors in the `nlist` is larger than sum(self.sel), + the nearest sum(sel) neighbors will be preseved. + + Known limitations: + + In the case of self.distinguish_types, the nlist is always formatted. + May have side effact on the efficiency. + + Parameters + ---------- + extended_coord + coodinates in extended region. nf x nall x 3 + extended_atype + atomic type in extended region. nf x nall + nlist + neighbor list. nf x nloc x nsel + + Returns + ------- + formated_nlist + the formated nlist. + + """ + n_nf, n_nloc, n_nnei = nlist.shape + distinguish_types = self.distinguish_types() + ret = self._format_nlist(extended_coord, nlist, sum(self.get_sel())) + if distinguish_types: + ret = nlist_distinguish_types(ret, extended_atype, self.get_sel()) + return ret + + def _format_nlist( + self, + extended_coord: np.ndarray, + nlist: np.ndarray, + nnei: int, + ): + n_nf, n_nloc, n_nnei = nlist.shape + extended_coord = extended_coord.reshape([n_nf, -1, 3]) + nall = extended_coord.shape[1] + rcut = self.get_rcut() + + if n_nnei < nnei: + # make a copy before revise + ret = np.concatenate( + [ + nlist, + -1 * np.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype), + ], + axis=-1, + ) + elif n_nnei > nnei: + # make a copy before revise + m_real_nei = nlist >= 0 + ret = np.where(m_real_nei, nlist, 0) + coord0 = extended_coord[:, :n_nloc, :] + index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2) + coord1 = np.take_along_axis(extended_coord, index, axis=1) + coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3) + rr = np.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) + rr = np.where(m_real_nei, rr, float("inf")) + rr, ret_mapping = np.sort(rr, axis=-1), np.argsort(rr, axis=-1) + ret = np.take_along_axis(ret, ret_mapping, axis=2) + ret = np.where(rr > rcut, -1, ret) + ret = ret[..., :nnei] + else: # n_nnei == nnei: + # copy anyway... + ret = nlist + assert ret.shape[-1] == nnei + return ret + + return CM diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py new file mode 100644 index 0000000000..3c7917d847 --- /dev/null +++ b/deepmd/dpmodel/model/transform_output.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, +) + +import numpy as np + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + get_deriv_name, + get_reduce_name, +) + + +def fit_output_to_model_output( + fit_ret: Dict[str, np.ndarray], + fit_output_def: FittingOutputDef, + coord_ext: np.ndarray, + do_atomic_virial: bool = False, +) -> Dict[str, np.ndarray]: + """Transform the output of the fitting network to + the model output. + + """ + model_ret = dict(fit_ret.items()) + for kk, vv in fit_ret.items(): + vdef = fit_output_def[kk] + shap = vdef.shape + atom_axis = -(len(shap) + 1) + if vdef.reduciable: + kk_redu = get_reduce_name(kk) + model_ret[kk_redu] = np.sum(vv, axis=atom_axis) + if vdef.differentiable: + kk_derv_r, kk_derv_c = get_deriv_name(kk) + # name-holders + model_ret[kk_derv_r] = None + model_ret[kk_derv_c] = None + return model_ret + + +def communicate_extended_output( + model_ret: Dict[str, np.ndarray], + model_output_def: ModelOutputDef, + mapping: np.ndarray, # nf x nloc + do_atomic_virial: bool = False, +) -> Dict[str, np.ndarray]: + """Transform the output of the model network defined on + local and ghost (extended) atoms to local atoms. + + """ + new_ret = {} + for kk in model_output_def.keys_outp(): + vv = model_ret[kk] + vdef = model_output_def[kk] + new_ret[kk] = vv + if vdef.reduciable: + kk_redu = get_reduce_name(kk) + new_ret[kk_redu] = model_ret[kk_redu] + if vdef.differentiable: + kk_derv_r, kk_derv_c = get_deriv_name(kk) + # name holders + new_ret[kk_derv_r] = None + new_ret[kk_derv_c] = None + new_ret[kk_derv_c + "_redu"] = None + if not do_atomic_virial: + # pop atomic virial, because it is not correctly calculated. + new_ret.pop(kk_derv_c) + return new_ret diff --git a/deepmd/model_format/output_def.py b/deepmd/dpmodel/output_def.py similarity index 98% rename from deepmd/model_format/output_def.py rename to deepmd/dpmodel/output_def.py index 268dc21ea6..583f88491e 100644 --- a/deepmd/model_format/output_def.py +++ b/deepmd/dpmodel/output_def.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import functools from typing import ( Dict, List, @@ -42,6 +43,7 @@ def model_check_output(cls): """ + @functools.wraps(cls, updated=()) class wrapper(cls): def __init__( self, @@ -81,6 +83,7 @@ def fitting_check_output(cls): """ + @functools.wraps(cls, updated=()) class wrapper(cls): def __init__( self, diff --git a/deepmd/model_format/__init__.py b/deepmd/dpmodel/utils/__init__.py similarity index 54% rename from deepmd/model_format/__init__.py rename to deepmd/dpmodel/utils/__init__.py index e15f73758e..d3c31ae246 100644 --- a/deepmd/model_format/__init__.py +++ b/deepmd/dpmodel/utils/__init__.py @@ -1,15 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from .common import ( - DEFAULT_PRECISION, - PRECISION_DICT, - NativeOP, -) from .env_mat import ( EnvMat, ) -from .fitting import ( - InvarFitting, -) from .network import ( EmbeddingNet, FittingNet, @@ -23,22 +15,21 @@ save_dp_model, traverse_model_dict, ) -from .output_def import ( - FittingOutputDef, - ModelOutputDef, - OutputVariableDef, - fitting_check_output, - get_deriv_name, - get_reduce_name, - model_check_output, +from .nlist import ( + build_multiple_neighbor_list, + build_neighbor_list, + extend_coord_with_ghosts, + get_multiple_nlist_key, + nlist_distinguish_types, ) -from .se_e2_a import ( - DescrptSeA, +from .region import ( + inter2phys, + normalize_coord, + phys2inter, + to_face_distance, ) __all__ = [ - "InvarFitting", - "DescrptSeA", "EnvMat", "make_multilayer_network", "make_embedding_network", @@ -48,17 +39,18 @@ "NativeLayer", "NativeNet", "NetworkCollection", - "NativeOP", "load_dp_model", "save_dp_model", "traverse_model_dict", "PRECISION_DICT", "DEFAULT_PRECISION", - "ModelOutputDef", - "FittingOutputDef", - "OutputVariableDef", - "model_check_output", - "fitting_check_output", - "get_reduce_name", - "get_deriv_name", + "build_neighbor_list", + "nlist_distinguish_types", + "get_multiple_nlist_key", + "build_multiple_neighbor_list", + "extend_coord_with_ghosts", + "normalize_coord", + "inter2phys", + "phys2inter", + "to_face_distance", ] diff --git a/deepmd/model_format/env_mat.py b/deepmd/dpmodel/utils/env_mat.py similarity index 99% rename from deepmd/model_format/env_mat.py rename to deepmd/dpmodel/utils/env_mat.py index 7822bd7d0c..739b06208c 100644 --- a/deepmd/model_format/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -6,7 +6,7 @@ import numpy as np -from .common import ( +from deepmd.dpmodel import ( NativeOP, ) diff --git a/deepmd/model_format/network.py b/deepmd/dpmodel/utils/network.py similarity index 99% rename from deepmd/model_format/network.py rename to deepmd/dpmodel/utils/network.py index f2056c0b95..17b3043612 100644 --- a/deepmd/model_format/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -22,7 +22,7 @@ except ImportError: __version__ = "unknown" -from .common import ( +from deepmd.dpmodel import ( DEFAULT_PRECISION, PRECISION_DICT, NativeOP, diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py new file mode 100644 index 0000000000..bc6592d52b --- /dev/null +++ b/deepmd/dpmodel/utils/nlist.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, + Union, +) + +import numpy as np + +from .region import ( + to_face_distance, +) + + +## translated from torch implemantation by chatgpt +def build_neighbor_list( + coord1: np.ndarray, + atype: np.ndarray, + nloc: int, + rcut: float, + sel: Union[int, List[int]], + distinguish_types: bool = True, +) -> np.ndarray: + """Build neightbor list for a single frame. keeps nsel neighbors. + + Parameters + ---------- + coord1 : np.ndarray + exptended coordinates of shape [batch_size, nall x 3] + atype : np.ndarray + extended atomic types of shape [batch_size, nall] + nloc : int + number of local atoms. + rcut : float + cut-off radius + sel : int or List[int] + maximal number of neighbors (of each type). + if distinguish_types==True, nsel should be list and + the length of nsel should be equal to number of + types. + distinguish_types : bool + distinguish different types. + + Returns + ------- + neighbor_list : np.ndarray + Neighbor list of shape [batch_size, nloc, nsel], the neighbors + are stored in an ascending order. If the number of + neighbors is less than nsel, the positions are masked + with -1. The neighbor list of an atom looks like + |------ nsel ------| + xx xx xx xx -1 -1 -1 + if distinguish_types==True and we have two types + |---- nsel[0] -----| |---- nsel[1] -----| + xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1 + + """ + batch_size = coord1.shape[0] + coord1 = coord1.reshape(batch_size, -1) + nall = coord1.shape[1] // 3 + if isinstance(sel, int): + sel = [sel] + nsel = sum(sel) + coord0 = coord1[:, : nloc * 3] + diff = ( + coord1.reshape([batch_size, -1, 3])[:, None, :, :] + - coord0.reshape([batch_size, -1, 3])[:, :, None, :] + ) + assert list(diff.shape) == [batch_size, nloc, nall, 3] + rr = np.linalg.norm(diff, axis=-1) + nlist = np.argsort(rr, axis=-1) + rr = np.sort(rr, axis=-1) + rr = rr[:, :, 1:] + nlist = nlist[:, :, 1:] + nnei = rr.shape[2] + if nsel <= nnei: + rr = rr[:, :, :nsel] + nlist = nlist[:, :, :nsel] + else: + rr = np.concatenate( + [rr, np.ones([batch_size, nloc, nsel - nnei]) + rcut], axis=-1 + ) + nlist = np.concatenate( + [nlist, np.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)], + axis=-1, + ) + assert list(nlist.shape) == [batch_size, nloc, nsel] + nlist = np.where((rr > rcut), -1, nlist) + + if distinguish_types: + return nlist_distinguish_types(nlist, atype, sel) + else: + return nlist + + +def nlist_distinguish_types( + nlist: np.ndarray, + atype: np.ndarray, + sel: List[int], +): + """Given a nlist that does not distinguish atom types, return a nlist that + distinguish atom types. + + """ + nf, nloc, _ = nlist.shape + ret_nlist = [] + tmp_atype = np.tile(atype[:, None], [1, nloc, 1]) + mask = nlist == -1 + tnlist_0 = nlist.copy() + tnlist_0[mask] = 0 + tnlist = np.take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze() + tnlist = np.where(mask, -1, tnlist) + snsel = tnlist.shape[2] + for ii, ss in enumerate(sel): + pick_mask = (tnlist == ii).astype(np.int32) + sorted_indices = np.argsort(-pick_mask, kind="stable", axis=-1) + pick_mask_sorted = -np.sort(-pick_mask, axis=-1) + inlist = np.take_along_axis(nlist, sorted_indices, axis=2) + inlist = np.where(~pick_mask_sorted.astype(bool), -1, inlist) + ret_nlist.append(np.split(inlist, [ss, snsel - ss], axis=-1)[0]) + ret = np.concatenate(ret_nlist, axis=-1) + return ret + + +def get_multiple_nlist_key(rcut: float, nsel: int) -> str: + return str(rcut) + "_" + str(nsel) + + +## translated from torch implemantation by chatgpt +def build_multiple_neighbor_list( + coord: np.ndarray, + nlist: np.ndarray, + rcuts: List[float], + nsels: List[int], +) -> Dict[str, np.ndarray]: + """Input one neighbor list, and produce multiple neighbor lists with + different cutoff radius and numbers of selection out of it. The + required rcuts and nsels should be smaller or equal to the input nlist. + + Parameters + ---------- + coord : np.ndarray + exptended coordinates of shape [batch_size, nall x 3] + nlist : np.ndarray + Neighbor list of shape [batch_size, nloc, nsel], the neighbors + should be stored in an ascending order. + rcuts : List[float] + list of cut-off radius in ascending order. + nsels : List[int] + maximal number of neighbors in ascending order. + + Returns + ------- + nlist_dict : Dict[str, np.ndarray] + A dict of nlists, key given by get_multiple_nlist_key(rc, nsel) + value being the corresponding nlist. + + """ + assert len(rcuts) == len(nsels) + if len(rcuts) == 0: + return {} + nb, nloc, nsel = nlist.shape + if nsel < nsels[-1]: + pad = -1 * np.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype) + nlist = np.concatenate([nlist, pad], axis=-1) + nsel = nsels[-1] + coord1 = coord.reshape(nb, -1, 3) + nall = coord1.shape[1] + coord0 = coord1[:, :nloc, :] + nlist_mask = nlist == -1 + tnlist_0 = nlist + tnlist_0[nlist_mask] = 0 + index = np.tile(tnlist_0.reshape(nb, nloc * nsel, 1), [1, 1, 3]) + coord2 = np.take_along_axis(coord1, index, axis=1).reshape(nb, nloc, nsel, 3) + diff = coord2 - coord0[:, :, None, :] + rr = np.linalg.norm(diff, axis=-1) + rr = np.where(nlist_mask, float("inf"), rr) + nlist0 = nlist + ret = {} + for rc, ns in zip(rcuts[::-1], nsels[::-1]): + tnlist_1 = np.copy(nlist0[:, :, :ns]) + tnlist_1[rr[:, :, :ns] > rc] = int(-1) + ret[get_multiple_nlist_key(rc, ns)] = tnlist_1 + return ret + + +## translated from torch implemantation by chatgpt +def extend_coord_with_ghosts( + coord: np.ndarray, + atype: np.ndarray, + cell: Optional[np.ndarray], + rcut: float, +): + """Extend the coordinates of the atoms by appending peridoc images. + The number of images is large enough to ensure all the neighbors + within rcut are appended. + + Parameters + ---------- + coord : np.ndarray + original coordinates of shape [-1, nloc*3]. + atype : np.ndarray + atom type of shape [-1, nloc]. + cell : np.ndarray + simulation cell tensor of shape [-1, 9]. + rcut : float + the cutoff radius + + Returns + ------- + extended_coord: np.ndarray + extended coordinates of shape [-1, nall*3]. + extended_atype: np.ndarray + extended atom type of shape [-1, nall]. + index_mapping: np.ndarray + maping extended index to the local index + + """ + nf, nloc = atype.shape + aidx = np.tile(np.arange(nloc)[np.newaxis, :], (nf, 1)) + if cell is None: + nall = nloc + extend_coord = coord.copy() + extend_atype = atype.copy() + extend_aidx = aidx.copy() + else: + coord = coord.reshape((nf, nloc, 3)) + cell = cell.reshape((nf, 3, 3)) + to_face = to_face_distance(cell) + nbuff = np.ceil(rcut / to_face).astype(int) + nbuff = np.max(nbuff, axis=0) + xi = np.arange(-nbuff[0], nbuff[0] + 1, 1) + yi = np.arange(-nbuff[1], nbuff[1] + 1, 1) + zi = np.arange(-nbuff[2], nbuff[2] + 1, 1) + xyz = np.outer(xi, np.array([1, 0, 0]))[:, np.newaxis, np.newaxis, :] + xyz = xyz + np.outer(yi, np.array([0, 1, 0]))[np.newaxis, :, np.newaxis, :] + xyz = xyz + np.outer(zi, np.array([0, 0, 1]))[np.newaxis, np.newaxis, :, :] + xyz = xyz.reshape(-1, 3) + shift_idx = xyz[np.argsort(np.linalg.norm(xyz, axis=1))] + ns, _ = shift_idx.shape + nall = ns * nloc + shift_vec = np.einsum("sd,fdk->fsk", shift_idx, cell) + extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :] + extend_atype = np.tile(atype[:, :, np.newaxis], (1, ns, 1)) + extend_aidx = np.tile(aidx[:, :, np.newaxis], (1, ns, 1)) + + return ( + extend_coord.reshape((nf, nall * 3)), + extend_atype.reshape((nf, nall)), + extend_aidx.reshape((nf, nall)), + ) diff --git a/deepmd/dpmodel/utils/region.py b/deepmd/dpmodel/utils/region.py new file mode 100644 index 0000000000..ddbc4b29b8 --- /dev/null +++ b/deepmd/dpmodel/utils/region.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + + +def phys2inter( + coord: np.ndarray, + cell: np.ndarray, +) -> np.ndarray: + """Convert physical coordinates to internal(direct) coordinates. + + Parameters + ---------- + coord : np.ndarray + physical coordinates of shape [*, na, 3]. + cell : np.ndarray + simulation cell tensor of shape [*, 3, 3]. + + Returns + ------- + inter_coord: np.ndarray + the internal coordinates + + """ + rec_cell = np.linalg.inv(cell) + return np.matmul(coord, rec_cell) + + +def inter2phys( + coord: np.ndarray, + cell: np.ndarray, +) -> np.ndarray: + """Convert internal(direct) coordinates to physical coordinates. + + Parameters + ---------- + coord : np.ndarray + internal coordinates of shape [*, na, 3]. + cell : np.ndarray + simulation cell tensor of shape [*, 3, 3]. + + Returns + ------- + phys_coord: np.ndarray + the physical coordinates + + """ + return np.matmul(coord, cell) + + +def normalize_coord( + coord: np.ndarray, + cell: np.ndarray, +) -> np.ndarray: + """Apply PBC according to the atomic coordinates. + + Parameters + ---------- + coord : np.ndarray + orignal coordinates of shape [*, na, 3]. + cell : np.ndarray + simulation cell shape [*, 3, 3]. + + Returns + ------- + wrapped_coord: np.ndarray + wrapped coordinates of shape [*, na, 3]. + + """ + icoord = phys2inter(coord, cell) + icoord = np.remainder(icoord, 1.0) + return inter2phys(icoord, cell) + + +def to_face_distance( + cell: np.ndarray, +) -> np.ndarray: + """Compute the to-face-distance of the simulation cell. + + Parameters + ---------- + cell : np.ndarray + simulation cell tensor of shape [*, 3, 3]. + + Returns + ------- + dist: np.ndarray + the to face distances of shape [*, 3] + + """ + cshape = cell.shape + dist = b_to_face_distance(cell.reshape([-1, 3, 3])) + return dist.reshape(list(cshape[:-2]) + [3]) # noqa:RUF005 + + +def b_to_face_distance(cell): + volume = np.linalg.det(cell) + c_yz = np.cross(cell[:, 1], cell[:, 2], axis=-1) + _h2yz = volume / np.linalg.norm(c_yz, axis=-1) + c_zx = np.cross(cell[:, 2], cell[:, 0], axis=-1) + _h2zx = volume / np.linalg.norm(c_zx, axis=-1) + c_xy = np.cross(cell[:, 0], cell[:, 1], axis=-1) + _h2xy = volume / np.linalg.norm(c_xy, axis=-1) + return np.stack([_h2yz, _h2zx, _h2xy], axis=1) diff --git a/deepmd/pt/model/descriptor/base_descriptor.py b/deepmd/pt/model/descriptor/base_descriptor.py new file mode 100644 index 0000000000..aa142b3acb --- /dev/null +++ b/deepmd/pt/model/descriptor/base_descriptor.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + +from deepmd.dpmodel.descriptor import ( + make_base_descriptor, +) + +BaseDescriptor = make_base_descriptor(torch.Tensor, "forward") diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index bb98e8dc15..b4e866bb11 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -19,8 +19,12 @@ Plugin, ) +from .base_descriptor import ( + BaseDescriptor, +) + -class Descriptor(torch.nn.Module, ABC): +class Descriptor(torch.nn.Module, BaseDescriptor): """The descriptor. Given the atomic coordinates, atomic types and neighbor list, calculate the descriptor. @@ -29,52 +33,6 @@ class Descriptor(torch.nn.Module, ABC): __plugins = Plugin() local_cluster = False - @abstractmethod - def get_rcut(self) -> float: - """Returns the cut-off radius.""" - raise NotImplementedError - - @abstractmethod - def get_nsel(self) -> int: - """Returns the number of selected atoms in the cut-off radius.""" - raise NotImplementedError - - @abstractmethod - def get_sel(self) -> List[int]: - """Returns the number of selected atoms for each type.""" - raise NotImplementedError - - @abstractmethod - def get_ntype(self) -> int: - """Returns the number of element types.""" - raise NotImplementedError - - @abstractmethod - def get_dim_out(self) -> int: - """Returns the output dimension.""" - raise NotImplementedError - - @abstractmethod - def compute_input_stats(self, merged): - """Update mean and stddev for descriptor elements.""" - raise NotImplementedError - - @abstractmethod - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): - """Initialize the model bias by the statistics.""" - raise NotImplementedError - - @abstractmethod - def forward( - self, - extended_coord, - extended_atype, - nlist, - mapping: Optional[torch.Tensor] = None, - ): - """Calculate descriptor.""" - raise NotImplementedError - @staticmethod def register(key: str) -> Callable: """Register a descriptor plugin. @@ -166,42 +124,47 @@ def __new__(cls, *args, **kwargs): @abstractmethod def get_rcut(self) -> float: """Returns the cut-off radius.""" - raise NotImplementedError + pass @abstractmethod def get_nsel(self) -> int: """Returns the number of selected atoms in the cut-off radius.""" - raise NotImplementedError + pass @abstractmethod def get_sel(self) -> List[int]: """Returns the number of selected atoms for each type.""" - raise NotImplementedError + pass @abstractmethod - def get_ntype(self) -> int: + def get_ntypes(self) -> int: """Returns the number of element types.""" - raise NotImplementedError + pass @abstractmethod def get_dim_out(self) -> int: """Returns the output dimension.""" - raise NotImplementedError + pass @abstractmethod def get_dim_in(self) -> int: """Returns the output dimension.""" - raise NotImplementedError + pass + + @abstractmethod + def get_dim_emb(self) -> int: + """Returns the embedding dimension.""" + pass @abstractmethod def compute_input_stats(self, merged): """Update mean and stddev for DescriptorBlock elements.""" - raise NotImplementedError + pass @abstractmethod def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): """Initialize the model bias by the statistics.""" - raise NotImplementedError + pass def share_params(self, base_class, shared_level, resume=False): assert ( diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 23f521b6d8..914c37ed51 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -91,9 +91,9 @@ def get_sel(self) -> List[int]: """Returns the number of selected atoms for each type.""" return self.se_atten.get_sel() - def get_ntype(self) -> int: + def get_ntypes(self) -> int: """Returns the number of element types.""" - return self.se_atten.get_ntype() + return self.se_atten.get_ntypes() def get_dim_out(self) -> int: """Returns the output dimension.""" @@ -102,13 +102,22 @@ def get_dim_out(self) -> int: ret += self.tebd_dim return ret + def get_dim_emb(self) -> int: + return self.se_atten.dim_emb + + def distinguish_types(self) -> bool: + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + return False + @property def dim_out(self): return self.get_dim_out() @property def dim_emb(self): - return self.se_atten.dim_emb + return self.get_dim_emb() def compute_input_stats(self, merged): return self.se_atten.compute_input_stats(merged) @@ -128,6 +137,15 @@ def get_data_process_key(cls, config): assert descrpt_type in ["dpa1", "se_atten"] return {"sel": config["sel"], "rcut": config["rcut"]} + def serialize(self) -> dict: + """Serialize the obj to dict.""" + raise NotImplementedError + + @classmethod + def deserialize(cls) -> "DescrptDPA1": + """Deserialize from a dict.""" + raise NotImplementedError + def forward( self, extended_coord: torch.Tensor, diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 409b999262..b40e466ed4 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -256,7 +256,7 @@ def get_sel(self) -> List[int]: """Returns the number of selected atoms for each type.""" return self.sel - def get_ntype(self) -> int: + def get_ntypes(self) -> int: """Returns the number of element types.""" return self.ntypes @@ -267,6 +267,16 @@ def get_dim_out(self) -> int: ret += self.tebd_dim return ret + def get_dim_emb(self) -> int: + """Returns the embedding dimension of this descriptor.""" + return self.repformers.dim_emb + + def distinguish_types(self) -> bool: + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + return False + @property def dim_out(self): return self.get_dim_out() @@ -274,7 +284,7 @@ def dim_out(self): @property def dim_emb(self): """Returns the embedding dimension g2.""" - return self.repformers.dim_emb + return self.get_dim_emb() def compute_input_stats(self, merged): sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] @@ -322,6 +332,15 @@ def get_data_process_key(cls, config): "rcut": [config["repinit_rcut"], config["repformer_rcut"]], } + def serialize(self) -> dict: + """Serialize the obj to dict.""" + raise NotImplementedError + + @classmethod + def deserialize(cls) -> "DescrptDPA2": + """Deserialize from a dict.""" + raise NotImplementedError + def forward( self, extended_coord: torch.Tensor, diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index 11bbc80729..0698992659 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -88,7 +88,7 @@ def get_sel(self) -> List[int]: """Returns the number of selected atoms for each type.""" return self.sel - def get_ntype(self) -> int: + def get_ntypes(self) -> int: """Returns the number of element types.""" return self.ntypes @@ -100,6 +100,9 @@ def get_dim_in(self) -> int: """Returns the input dimension.""" return self.dim_in + def get_dim_emb(self): + return self.dim_emb + @property def dim_out(self): """Returns the output dimension of this descriptor.""" diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 141b5dc745..853962de69 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -162,7 +162,7 @@ def get_sel(self) -> List[int]: """Returns the number of selected atoms for each type.""" return self.sel - def get_ntype(self) -> int: + def get_ntypes(self) -> int: """Returns the number of element types.""" return self.ntypes @@ -174,6 +174,10 @@ def get_dim_in(self) -> int: """Returns the input dimension.""" return self.dim_in + def get_dim_emb(self) -> int: + """Returns the embedding dimension g2.""" + return self.g2_dim + @property def dim_out(self): """Returns the output dimension of this descriptor.""" @@ -187,7 +191,7 @@ def dim_in(self): @property def dim_emb(self): """Returns the embedding dimension g2.""" - return self.g2_dim + return self.get_dim_emb() def forward( self, diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 3f42736dca..23b78dcf34 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -28,7 +28,7 @@ except ImportError: from torch.jit import Final -from deepmd.model_format import EnvMat as DPEnvMat +from deepmd.dpmodel.utils import EnvMat as DPEnvMat from deepmd.pt.model.network.mlp import ( EmbeddingNet, NetworkCollection, @@ -81,14 +81,24 @@ def get_sel(self) -> List[int]: """Returns the number of selected atoms for each type.""" return self.sea.get_sel() - def get_ntype(self) -> int: + def get_ntypes(self) -> int: """Returns the number of element types.""" - return self.sea.get_ntype() + return self.sea.get_ntypes() def get_dim_out(self) -> int: """Returns the output dimension.""" return self.sea.get_dim_out() + def get_dim_emb(self) -> int: + """Returns the output dimension.""" + return self.sea.get_dim_emb() + + def distinguish_types(self): + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + return True + @property def dim_out(self): """Returns the output dimension of this descriptor.""" @@ -295,7 +305,7 @@ def get_sel(self) -> List[int]: """Returns the number of selected atoms for each type.""" return self.sel - def get_ntype(self) -> int: + def get_ntypes(self) -> int: """Returns the number of element types.""" return self.ntypes @@ -303,6 +313,10 @@ def get_dim_out(self) -> int: """Returns the output dimension.""" return self.dim_out + def get_dim_emb(self) -> int: + """Returns the output dimension.""" + return self.neuron[-1] + def get_dim_in(self) -> int: """Returns the input dimension.""" return self.dim_in diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 78cba59da7..5d6e16fb96 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -145,7 +145,7 @@ def get_sel(self) -> List[int]: """Returns the number of selected atoms for each type.""" return self.sel - def get_ntype(self) -> int: + def get_ntypes(self) -> int: """Returns the number of element types.""" return self.ntypes @@ -157,6 +157,10 @@ def get_dim_out(self) -> int: """Returns the output dimension.""" return self.dim_out + def get_dim_emb(self) -> int: + """Returns the output dimension of embedding.""" + return self.filter_neuron[-1] + @property def dim_out(self): """Returns the output dimension of this descriptor.""" @@ -170,7 +174,7 @@ def dim_in(self): @property def dim_emb(self): """Returns the output dimension of embedding.""" - return self.filter_neuron[-1] + return self.get_dim_emb() def compute_input_stats(self, merged): """Update mean and stddev for descriptor elements.""" diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index a3db3dbdec..c4de02ed20 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -1,4 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy + +from deepmd.pt.model.descriptor.descriptor import ( + Descriptor, +) +from deepmd.pt.model.task import ( + Fitting, +) + from .ener import ( EnergyModel, ) @@ -8,9 +17,27 @@ def get_model(model_params, sampled=None): + model_params = copy.deepcopy(model_params) + ntypes = len(model_params["type_map"]) + # descriptor + model_params["descriptor"]["ntypes"] = ntypes + descriptor = Descriptor(**model_params["descriptor"]) + # fitting + fitting_net = model_params.get("fitting_net", None) + fitting_net["type"] = fitting_net.get("type", "ener") + fitting_net["ntypes"] = descriptor.get_ntypes() + fitting_net["distinguish_types"] = descriptor.distinguish_types() + fitting_net["embedding_width"] = descriptor.get_dim_out() + grad_force = "direct" not in fitting_net["type"] + if not grad_force: + fitting_net["out_dim"] = descriptor.get_dim_emb() + if "ener" in fitting_net["type"]: + fitting_net["return_energy"] = True + fitting = Fitting(**fitting_net) + return EnergyModel( - descriptor=model_params["descriptor"], - fitting_net=model_params.get("fitting_net", None), + descriptor, + fitting, type_map=model_params["type_map"], type_embedding=model_params.get("type_embedding", None), resuming=model_params.get("resuming", False), diff --git a/deepmd/pt/model/model/atomic_model.py b/deepmd/pt/model/model/atomic_model.py deleted file mode 100644 index 9720bfa57d..0000000000 --- a/deepmd/pt/model/model/atomic_model.py +++ /dev/null @@ -1,70 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -from abc import ( - ABC, - abstractmethod, -) -from typing import ( - Dict, - List, - Optional, -) - -import torch - -from deepmd.model_format import ( - FittingOutputDef, -) - - -class AtomicModel(ABC): - @abstractmethod - def get_fitting_output_def(self) -> FittingOutputDef: - raise NotImplementedError - - @abstractmethod - def get_rcut(self) -> float: - raise NotImplementedError - - @abstractmethod - def get_sel(self) -> List[int]: - raise NotImplementedError - - @abstractmethod - def distinguish_types(self) -> bool: - raise NotImplementedError - - @abstractmethod - def forward_atomic( - self, - extended_coord, - extended_atype, - nlist, - mapping: Optional[torch.Tensor] = None, - do_atomic_virial: bool = False, - ) -> Dict[str, torch.Tensor]: - raise NotImplementedError - - def do_grad( - self, - var_name: Optional[str] = None, - ) -> bool: - """Tell if the output variable `var_name` is differentiable. - if var_name is None, returns if any of the variable is differentiable. - - """ - odef = self.get_fitting_output_def() - if var_name is None: - require: List[bool] = [] - for vv in odef.keys(): - require.append(self.do_grad_(vv)) - return any(require) - else: - return self.do_grad_(var_name) - - def do_grad_( - self, - var_name: str, - ) -> bool: - """Tell if the output variable `var_name` is differentiable.""" - assert var_name is not None - return self.get_fitting_output_def()[var_name].differentiable diff --git a/deepmd/pt/model/model/base_atomic_model.py b/deepmd/pt/model/model/base_atomic_model.py new file mode 100644 index 0000000000..3f3e14257b --- /dev/null +++ b/deepmd/pt/model/model/base_atomic_model.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import torch + +from deepmd.dpmodel.model import ( + make_base_atomic_model, +) + +BaseAtomicModel = make_base_atomic_model(torch.Tensor) diff --git a/deepmd/pt/model/model/dp_atomic_model.py b/deepmd/pt/model/model/dp_atomic_model.py index a222c8e6f6..b2ae48628b 100644 --- a/deepmd/pt/model/model/dp_atomic_model.py +++ b/deepmd/pt/model/model/dp_atomic_model.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import sys from typing import ( Dict, List, @@ -7,25 +9,25 @@ import torch -from deepmd.model_format import ( +from deepmd.dpmodel import ( FittingOutputDef, ) -from deepmd.pt.model.descriptor.descriptor import ( - Descriptor, +from deepmd.pt.model.descriptor.se_a import ( # noqa # TODO: should import all descriptors!!! + DescrptSeA, ) -from deepmd.pt.model.task import ( - Fitting, +from deepmd.pt.model.task.ener import ( # noqa # TODO: should import all fittings! + InvarFitting, ) -from .atomic_model import ( - AtomicModel, +from .base_atomic_model import ( + BaseAtomicModel, ) from .model import ( BaseModel, ) -class DPAtomicModel(BaseModel, AtomicModel): +class DPAtomicModel(BaseModel, BaseAtomicModel): """Model give atomic prediction of some physical property. Parameters @@ -49,10 +51,11 @@ class DPAtomicModel(BaseModel, AtomicModel): Sampled frames to compute the statistics. """ + # I am enough with the shit interface! def __init__( self, - descriptor: dict, - fitting_net: dict, + descriptor, + fitting, type_map: Optional[List[str]], type_embedding: Optional[dict] = None, resuming: bool = False, @@ -62,26 +65,15 @@ def __init__( **kwargs, ): super().__init__() - # Descriptor + Type Embedding Net (Optional) ntypes = len(type_map) self.type_map = type_map self.ntypes = ntypes - descriptor["ntypes"] = ntypes - self.combination = descriptor.get("combination", False) - if self.combination: - self.prefactor = descriptor.get("prefactor", [0.5, 0.5]) - self.descriptor_type = descriptor["type"] - - self.type_split = True - if self.descriptor_type not in ["se_e2_a"]: - self.type_split = False - - self.descriptor = Descriptor(**descriptor) + self.descriptor = descriptor self.rcut = self.descriptor.get_rcut() self.sel = self.descriptor.get_sel() - self.split_nlist = False - + self.fitting_net = fitting # Statistics + fitting_net = None # TODO: hack!!! not sure if it is correct. self.compute_or_load_stat( fitting_net, ntypes, @@ -92,22 +84,7 @@ def __init__( sampled=sampled, ) - fitting_net["type"] = fitting_net.get("type", "ener") - fitting_net["ntypes"] = self.descriptor.get_ntype() - if self.descriptor_type in ["se_e2_a"]: - fitting_net["distinguish_types"] = True - else: - fitting_net["distinguish_types"] = False - fitting_net["embedding_width"] = self.descriptor.dim_out - - self.grad_force = "direct" not in fitting_net["type"] - if not self.grad_force: - fitting_net["out_dim"] = self.descriptor.dim_emb - if "ener" in fitting_net["type"]: - fitting_net["return_energy"] = True - self.fitting_net = Fitting(**fitting_net) - - def get_fitting_output_def(self) -> FittingOutputDef: + def fitting_output_def(self) -> FittingOutputDef: """Get the output def of the fitting net.""" return ( self.fitting_net.output_def() @@ -125,7 +102,34 @@ def get_sel(self) -> List[int]: def distinguish_types(self) -> bool: """If distinguish different types by sorting.""" - return self.type_split + return self.descriptor.distinguish_types() + + def serialize(self) -> dict: + return { + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + "fitting": self.fitting_net.serialize(), + "descriptor_name": self.descriptor.__class__.__name__, + "fitting_name": self.fitting_net.__class__.__name__, + } + + @classmethod + def deserialize(cls, data) -> "DPAtomicModel": + data = copy.deepcopy(data) + descriptor_obj = getattr( + sys.modules[__name__], data["descriptor_name"] + ).deserialize(data["descriptor"]) + fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize( + data["fitting"] + ) + # TODO: dirty hack to provide type_map and avoid data stat!!! + obj = cls( + descriptor_obj, + fitting_obj, + type_map=data["type_map"], + resuming=True, + ) + return obj def forward_atomic( self, @@ -133,6 +137,8 @@ def forward_atomic( extended_atype, nlist, mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: """Return atomic prediction. @@ -146,11 +152,15 @@ def forward_atomic( neighbor list. nf x nloc x nsel mapping mapps the extended indices to local indices + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda Returns ------- result_dict - the result dict, defined by the fitting net output def. + the result dict, defined by the `FittingOutputDef`. """ nframes, nloc, nnei = nlist.shape @@ -165,5 +175,13 @@ def forward_atomic( ) assert descriptor is not None # energy, force - fit_ret = self.fitting_net(descriptor, atype, gr=rot_mat) + fit_ret = self.fitting_net( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) return fit_ret diff --git a/deepmd/pt/model/model/ener.py b/deepmd/pt/model/model/ener.py index c316c99a86..a408689d8d 100644 --- a/deepmd/pt/model/model/ener.py +++ b/deepmd/pt/model/model/ener.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Dict, - List, Optional, ) @@ -32,6 +31,8 @@ def forward( coord, atype, box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ) -> Dict[str, torch.Tensor]: model_ret = self.forward_common( @@ -86,66 +87,3 @@ def forward_lower( else: model_predict = model_ret return model_predict - - -# should be a stand-alone function!!!! -def process_nlist( - nlist, - extended_atype, - mapping: Optional[torch.Tensor] = None, -): - # process the nlist_type and nlist_loc - nframes, nloc = nlist.shape[:2] - nmask = nlist == -1 - nlist[nmask] = 0 - if mapping is not None: - nlist_loc = torch.gather( - mapping, - dim=1, - index=nlist.reshape(nframes, -1), - ).reshape(nframes, nloc, -1) - nlist_loc[nmask] = -1 - else: - nlist_loc = None - nlist_type = torch.gather( - extended_atype, - dim=1, - index=nlist.reshape(nframes, -1), - ).reshape(nframes, nloc, -1) - nlist_type[nmask] = -1 - nlist[nmask] = -1 - return nlist_loc, nlist_type, nframes, nloc - - -def process_nlist_gathered( - nlist, - extended_atype, - split_sel: List[int], - mapping: Optional[torch.Tensor] = None, -): - nlist_list = list(torch.split(nlist, split_sel, -1)) - nframes, nloc = nlist_list[0].shape[:2] - nlist_type_list = [] - nlist_loc_list = [] - for nlist_item in nlist_list: - nmask = nlist_item == -1 - nlist_item[nmask] = 0 - if mapping is not None: - nlist_loc_item = torch.gather( - mapping, dim=1, index=nlist_item.reshape(nframes, -1) - ).reshape(nframes, nloc, -1) - nlist_loc_item[nmask] = -1 - nlist_loc_list.append(nlist_loc_item) - nlist_type_item = torch.gather( - extended_atype, dim=1, index=nlist_item.reshape(nframes, -1) - ).reshape(nframes, nloc, -1) - nlist_type_item[nmask] = -1 - nlist_type_list.append(nlist_type_item) - nlist_item[nmask] = -1 - - if mapping is not None: - nlist_loc = torch.cat(nlist_loc_list, -1) - else: - nlist_loc = None - nlist_type = torch.cat(nlist_type_list, -1) - return nlist_loc, nlist_type, nframes, nloc diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 3ddd21fbb8..c8c1e9450b 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -6,7 +6,7 @@ import torch -from deepmd.model_format import ( +from deepmd.dpmodel import ( ModelOutputDef, ) from deepmd.pt.model.model.transform_output import ( @@ -16,6 +16,7 @@ from deepmd.pt.utils.nlist import ( build_neighbor_list, extend_coord_with_ghosts, + nlist_distinguish_types, ) from deepmd.pt.utils.region import ( normalize_coord, @@ -23,6 +24,28 @@ def make_model(T_AtomicModel): + """Make a model as a derived class of an atomic model. + + The model provide two interfaces. + + 1. the `forward_common_lower`, that takes extended coordinates, atyps and neighbor list, + and outputs the atomic and property and derivatives (if required) on the extended region. + + 2. the `forward_common`, that takes coordinates, atypes and cell and predicts + the atomic and reduced property, and derivatives (if required) on the local region. + + Parameters + ---------- + T_AtomicModel + The atomic model. + + Returns + ------- + CM + The model. + + """ + class CM(T_AtomicModel): def __init__( self, @@ -34,8 +57,9 @@ def __init__( **kwargs, ) - def get_model_output_def(self): - return ModelOutputDef(self.get_fitting_output_def()) + def model_output_def(self): + """Get the output def for the model.""" + return ModelOutputDef(self.fitting_output_def()) # cannot use the name forward. torch script does not work def forward_common( @@ -43,24 +67,37 @@ def forward_common( coord, atype, box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ) -> Dict[str, torch.Tensor]: - """Return total energy of the system. - Args: - - coord: Atom coordinates with shape [nframes, natoms[1]*3]. - - atype: Atom types with shape [nframes, natoms[1]]. - - natoms: Atom statisics with shape [self.ntypes+2]. - - box: Simulation box with shape [nframes, 9]. - - atomic_virial: Whether or not compoute the atomic virial. + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + do_atomic_virial + If calculate the atomic virial. Returns ------- - - energy: Energy per atom. - - force: XYZ force per atom. + ret_dict + The result dict of type Dict[str,torch.Tensor]. + The keys are defined by the `ModelOutputDef`. + """ nframes, nloc = atype.shape[:2] if box is not None: - coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3)) + coord_normalized = normalize_coord( + coord.view(nframes, nloc, 3), + box.reshape(nframes, 3, 3), + ) else: coord_normalized = coord.clone() extended_coord, extended_atype, mapping = extend_coord_with_ghosts( @@ -74,17 +111,19 @@ def forward_common( self.get_sel(), distinguish_types=self.distinguish_types(), ) - extended_coord = extended_coord.reshape(nframes, -1, 3) + extended_coord = extended_coord.view(nframes, -1, 3) model_predict_lower = self.forward_common_lower( extended_coord, extended_atype, nlist, mapping, do_atomic_virial=do_atomic_virial, + fparam=fparam, + aparam=aparam, ) model_predict = communicate_extended_output( model_predict_lower, - self.get_model_output_def(), + self.model_output_def(), mapping, do_atomic_virial=do_atomic_virial, ) @@ -96,9 +135,14 @@ def forward_common_lower( extended_atype, nlist, mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ): - """Return model prediction. + """Return model prediction. Lower interface that takes + extended atomic coordinates and types, nlist, and mapping + as input, and returns the predictions on the extended region. + The predictions are not reduced. Parameters ---------- @@ -111,26 +155,118 @@ def forward_common_lower( mapping mapps the extended indices to local indices do_atomic_virial - whether do atomic virial + whether calculate atomic virial Returns ------- result_dict - the result dict, defined by the fitting net output def. + the result dict, defined by the `FittingOutputDef`. """ + nframes, nall = extended_atype.shape[:2] + extended_coord = extended_coord.view(nframes, -1, 3) + nlist = self.format_nlist(extended_coord, extended_atype, nlist) atomic_ret = self.forward_atomic( extended_coord, extended_atype, nlist, mapping=mapping, + fparam=fparam, + aparam=aparam, ) model_predict = fit_output_to_model_output( atomic_ret, - self.get_fitting_output_def(), + self.fitting_output_def(), extended_coord, do_atomic_virial=do_atomic_virial, ) return model_predict + def format_nlist( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + ): + """Format the neighbor list. + + 1. If the number of neighbors in the `nlist` is equal to sum(self.sel), + it does nothong + + 2. If the number of neighbors in the `nlist` is smaller than sum(self.sel), + the `nlist` is pad with -1. + + 3. If the number of neighbors in the `nlist` is larger than sum(self.sel), + the nearest sum(sel) neighbors will be preseved. + + Known limitations: + + In the case of self.distinguish_types, the nlist is always formatted. + May have side effact on the efficiency. + + Parameters + ---------- + extended_coord + coodinates in extended region. nf x nall x 3 + extended_atype + atomic type in extended region. nf x nall + nlist + neighbor list. nf x nloc x nsel + + Returns + ------- + formated_nlist + the formated nlist. + + """ + distinguish_types = self.distinguish_types() + nlist = self._format_nlist(extended_coord, nlist, sum(self.get_sel())) + if distinguish_types: + nlist = nlist_distinguish_types(nlist, extended_atype, self.get_sel()) + return nlist + + def _format_nlist( + self, + extended_coord: torch.Tensor, + nlist: torch.Tensor, + nnei: int, + ): + n_nf, n_nloc, n_nnei = nlist.shape + # nf x nall x 3 + extended_coord = extended_coord.view([n_nf, -1, 3]) + rcut = self.get_rcut() + + if n_nnei < nnei: + nlist = torch.cat( + [ + nlist, + -1 + * torch.ones( + [n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype + ).to(nlist.device), + ], + dim=-1, + ) + elif n_nnei > nnei: + m_real_nei = nlist >= 0 + nlist = torch.where(m_real_nei, nlist, 0) + # nf x nloc x 3 + coord0 = extended_coord[:, :n_nloc, :] + # nf x (nloc x nnei) x 3 + index = nlist.view(n_nf, n_nloc * n_nnei, 1).expand(-1, -1, 3) + coord1 = torch.gather(extended_coord, 1, index) + # nf x nloc x nnei x 3 + coord1 = coord1.view(n_nf, n_nloc, n_nnei, 3) + # nf x nloc x nnei + rr = torch.linalg.norm(coord0[:, :, None, :] - coord1, dim=-1) + rr = torch.where(m_real_nei, rr, float("inf")) + rr, nlist_mapping = torch.sort(rr, dim=-1) + nlist = torch.gather(nlist, 2, nlist_mapping) + nlist = torch.where(rr > rcut, -1, nlist) + nlist = nlist[..., :nnei] + else: # n_nnei == nnei: + pass # great! + assert nlist.shape[-1] == nnei + return nlist + return CM diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index 139744c1e9..01c2d7b9d6 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -18,10 +18,6 @@ def __init__(self): """Construct a basic model for different tasks.""" super().__init__() - def forward(self, *args, **kwargs): - """Model output.""" - raise NotImplementedError - def compute_or_load_stat( self, fitting_param, diff --git a/deepmd/pt/model/model/pair_tab.py b/deepmd/pt/model/model/pair_tab.py index 6f0782289a..430d090eb0 100644 --- a/deepmd/pt/model/model/pair_tab.py +++ b/deepmd/pt/model/model/pair_tab.py @@ -11,7 +11,7 @@ nn, ) -from deepmd.model_format import ( +from deepmd.dpmodel import ( FittingOutputDef, OutputVariableDef, ) @@ -19,12 +19,12 @@ PairTab, ) -from .atomic_model import ( - AtomicModel, +from .base_atomic_model import ( + BaseAtomicModel, ) -class PairTabModel(nn.Module, AtomicModel): +class PairTabModel(nn.Module, BaseAtomicModel): """Pairwise tabulation energy model. This model can be used to tabulate the pairwise energy between atoms for either @@ -72,7 +72,7 @@ def __init__( else: raise TypeError("sel must be int or list[int]") - def get_fitting_output_def(self) -> FittingOutputDef: + def fitting_output_def(self) -> FittingOutputDef: return FittingOutputDef( [ OutputVariableDef( @@ -91,6 +91,14 @@ def distinguish_types(self) -> bool: # to match DPA1 and DPA2. return False + def serialize(self) -> dict: + # place holder, implemantated in future PR + raise NotImplementedError + + def deserialize(cls): + # place holder, implemantated in future PR + raise NotImplementedError + def forward_atomic( self, extended_coord, diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index a14518e8a0..d942ed3ae8 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -7,7 +7,7 @@ import torch -from deepmd.model_format import ( +from deepmd.dpmodel import ( FittingOutputDef, ModelOutputDef, OutputVariableDef, @@ -152,6 +152,7 @@ def fit_output_to_model_output( ) model_ret[kk_derv_r] = dr model_ret[kk_derv_c] = dc + model_ret[kk_derv_c + "_redu"] = torch.sum(model_ret[kk_derv_c], dim=1) return model_ret diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index d76abd82f9..251150f945 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -15,11 +15,11 @@ device = env.DEVICE -from deepmd.model_format import ( +from deepmd.dpmodel.utils import ( NativeLayer, ) -from deepmd.model_format import NetworkCollection as DPNetworkCollection -from deepmd.model_format import ( +from deepmd.dpmodel.utils import NetworkCollection as DPNetworkCollection +from deepmd.dpmodel.utils import ( make_embedding_network, make_fitting_network, make_multilayer_network, diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index fcf46632f3..0b21033d31 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -2,6 +2,9 @@ from .atten_lcc import ( FittingNetAttenLcc, ) +from .base_fitting import ( + BaseFitting, +) from .denoise import ( DenoiseNet, ) @@ -15,9 +18,6 @@ from .fitting import ( Fitting, ) -from .task import ( - TaskBaseMethod, -) from .type_predict import ( TypePredictNet, ) @@ -29,6 +29,6 @@ "EnergyFittingNet", "EnergyFittingNetDirect", "Fitting", - "TaskBaseMethod", + "BaseFitting", "TypePredictNet", ] diff --git a/deepmd/pt/model/task/atten_lcc.py b/deepmd/pt/model/task/atten_lcc.py index 41ccf99330..e5961335ec 100644 --- a/deepmd/pt/model/task/atten_lcc.py +++ b/deepmd/pt/model/task/atten_lcc.py @@ -6,15 +6,15 @@ EnergyHead, NodeTaskHead, ) -from deepmd.pt.model.task.task import ( - TaskBaseMethod, +from deepmd.pt.model.task.fitting import ( + Fitting, ) from deepmd.pt.utils import ( env, ) -class FittingNetAttenLcc(TaskBaseMethod): +class FittingNetAttenLcc(Fitting): def __init__( self, embedding_width, bias_atom_e, pair_embed_dim, attention_heads, **kwargs ): diff --git a/deepmd/pt/model/task/base_fitting.py b/deepmd/pt/model/task/base_fitting.py new file mode 100644 index 0000000000..884a1bfe57 --- /dev/null +++ b/deepmd/pt/model/task/base_fitting.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + +from deepmd.dpmodel.fitting import ( + make_base_fitting, +) + +BaseFitting = make_base_fitting(torch.Tensor, fwd_method_name="forward") diff --git a/deepmd/pt/model/task/denoise.py b/deepmd/pt/model/task/denoise.py index 7e6b6dcdb6..35846ed231 100644 --- a/deepmd/pt/model/task/denoise.py +++ b/deepmd/pt/model/task/denoise.py @@ -5,7 +5,7 @@ import torch -from deepmd.model_format import ( +from deepmd.dpmodel import ( FittingOutputDef, OutputVariableDef, fitting_check_output, @@ -14,8 +14,8 @@ MaskLMHead, NonLinearHead, ) -from deepmd.pt.model.task.task import ( - TaskBaseMethod, +from deepmd.pt.model.task.fitting import ( + Fitting, ) from deepmd.pt.utils import ( env, @@ -23,7 +23,7 @@ @fitting_check_output -class DenoiseNet(TaskBaseMethod): +class DenoiseNet(Fitting): def __init__( self, feature_dim, diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 8511c7dc29..4906987bf8 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -6,12 +6,12 @@ from deepmd.pt.model.network.network import ( ResidualDeep, ) -from deepmd.pt.model.task.task import ( - TaskBaseMethod, +from deepmd.pt.model.task.fitting import ( + Fitting, ) -class DipoleFittingNetType(TaskBaseMethod): +class DipoleFittingNetType(Fitting): def __init__( self, ntypes, embedding_width, neuron, out_dim, resnet_dt=True, **kwargs ): diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index e40a6bda44..484e477b6a 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -10,7 +10,7 @@ import numpy as np import torch -from deepmd.model_format import ( +from deepmd.dpmodel import ( FittingOutputDef, OutputVariableDef, fitting_check_output, @@ -292,6 +292,7 @@ def forward( "get an input fparam of dim {fparam.shape[-1]}, ", "which is not consistent with {self.numb_fparam}.", ) + fparam = fparam.view([nf, self.numb_fparam]) nb, _ = fparam.shape t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb) t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb) @@ -311,6 +312,7 @@ def forward( "get an input aparam of dim {aparam.shape[-1]}, ", "which is not consistent with {self.numb_aparam}.", ) + aparam = aparam.view([nf, nloc, self.numb_aparam]) nb, nloc, _ = aparam.shape t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc) t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc) @@ -396,7 +398,7 @@ def __init__( ntypes, embedding_width, neuron, - bias_atom_e, + bias_atom_e=None, out_dim=1, resnet_dt=True, use_tebd=True, @@ -417,6 +419,8 @@ def __init__( self.dim_descrpt = embedding_width self.use_tebd = use_tebd self.out_dim = out_dim + if bias_atom_e is None: + bias_atom_e = np.zeros([self.ntypes]) if not use_tebd: assert self.ntypes == len(bias_atom_e), "Element count mismatches!" bias_atom_e = torch.tensor(bias_atom_e) @@ -460,11 +464,21 @@ def output_def(self): ] ) + def serialize(self) -> dict: + raise NotImplementedError + + def deserialize(cls) -> "EnergyFittingNetDirect": + raise NotImplementedError + def forward( self, inputs: torch.Tensor, atype: torch.Tensor, gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, None]: """Based on embedding net output, alculate total energy. diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index c6fb6b27e1..551fb9640b 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -7,8 +7,8 @@ import numpy as np import torch -from deepmd.pt.model.task.task import ( - TaskBaseMethod, +from deepmd.pt.model.task.base_fitting import ( + BaseFitting, ) from deepmd.pt.utils.dataloader import ( DpLoaderSet, @@ -24,7 +24,7 @@ ) -class Fitting(TaskBaseMethod): +class Fitting(torch.nn.Module, BaseFitting): __plugins = Plugin() @staticmethod diff --git a/deepmd/pt/model/task/task.py b/deepmd/pt/model/task/task.py index b2dc03e4bd..6ceb116d85 100644 --- a/deepmd/pt/model/task/task.py +++ b/deepmd/pt/model/task/task.py @@ -1,18 +1 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from abc import ( - ABC, - abstractmethod, -) - -import torch - -from deepmd.model_format import ( - FittingOutputDef, -) - - -class TaskBaseMethod(torch.nn.Module, ABC): - @abstractmethod - def output_def(self) -> FittingOutputDef: - """Definition for the task Output.""" - raise NotImplementedError diff --git a/deepmd/pt/model/task/type_predict.py b/deepmd/pt/model/task/type_predict.py index 57227004d0..c696590043 100644 --- a/deepmd/pt/model/task/type_predict.py +++ b/deepmd/pt/model/task/type_predict.py @@ -9,11 +9,11 @@ MaskLMHead, ) from deepmd.pt.model.task import ( - TaskBaseMethod, + Fitting, ) -class TypePredictNet(TaskBaseMethod): +class TypePredictNet(Fitting): def __init__(self, feature_dim, ntypes, activation_function="gelu", **kwargs): """Construct a type predict net. diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index b51b03fdc2..a679ccf1fa 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -40,6 +40,8 @@ "half": torch.float16, "single": torch.float32, "double": torch.float64, + "int32": torch.int32, + "int64": torch.int64, } DEFAULT_PRECISION = "float64" diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index 23a11684a5..fdb2627f04 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -16,143 +16,6 @@ ) -def _build_neighbor_list( - coord1: torch.Tensor, - nloc: int, - rcut: float, - nsel: int, - rmin: float = 1e-10, - cut_nearest: bool = True, -) -> torch.Tensor: - """Build neightbor list for a single frame. keeps nsel neighbors. - coord1 : [nall x 3]. - - ret: [nloc x nsel] stores indexes of coord1. - """ - nall = coord1.shape[-1] // 3 - coord0 = torch.split(coord1, [nloc * 3, (nall - nloc) * 3])[0] - # nloc x nall x 3 - diff = coord1.view([-1, 3])[None, :, :] - coord0.view([-1, 3])[:, None, :] - assert list(diff.shape) == [nloc, nall, 3] - # nloc x nall - rr = torch.linalg.norm(diff, dim=-1) - rr, nlist = torch.sort(rr, dim=-1) - if cut_nearest: - # nloc x (nall-1) - rr = torch.split(rr, [1, nall - 1], dim=-1)[-1] - nlist = torch.split(nlist, [1, nall - 1], dim=-1)[-1] - # nloc x nsel - nnei = rr.shape[1] - rr = torch.split(rr, [nsel, nnei - nsel], dim=-1)[0] - nlist = torch.split(nlist, [nsel, nnei - nsel], dim=-1)[0] - nlist = nlist.masked_fill((rr > rcut), -1) - return nlist - - -def build_neighbor_list_lower( - coord1: torch.Tensor, - atype: torch.Tensor, - nloc: int, - rcut: float, - sel: Union[int, List[int]], - distinguish_types: bool = True, -) -> torch.Tensor: - """Build neightbor list for a single frame. keeps nsel neighbors. - - Parameters - ---------- - coord1 : torch.Tensor - exptended coordinates of shape [nall x 3] - atype : torch.Tensor - extended atomic types of shape [nall] - nloc : int - number of local atoms. - rcut : float - cut-off radius - sel : int or List[int] - maximal number of neighbors (of each type). - if distinguish_types==True, nsel should be list and - the length of nsel should be equal to number of - types. - distinguish_types : bool - distinguish different types. - - Returns - ------- - neighbor_list : torch.Tensor - Neighbor list of shape [nloc, nsel], the neighbors - are stored in an ascending order. If the number of - neighbors is less than nsel, the positions are masked - with -1. The neighbor list of an atom looks like - |------ nsel ------| - xx xx xx xx -1 -1 -1 - if distinguish_types==True and we have two types - |---- nsel[0] -----| |---- nsel[1] -----| - xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1 - - """ - nall = coord1.shape[0] // 3 - if isinstance(sel, int): - sel = [sel] - nsel = sum(sel) - # nloc x 3 - coord0 = coord1[: nloc * 3] - # nloc x nall x 3 - diff = coord1.view([-1, 3]).unsqueeze(0) - coord0.view([-1, 3]).unsqueeze(1) - assert list(diff.shape) == [nloc, nall, 3] - # nloc x nall - rr = torch.linalg.norm(diff, dim=-1) - rr, nlist = torch.sort(rr, dim=-1) - # nloc x (nall-1) - rr = rr[:, 1:] - nlist = nlist[:, 1:] - # nloc x nsel - nnei = rr.shape[1] - if nsel <= nnei: - rr = rr[:, :nsel] - nlist = nlist[:, :nsel] - else: - rr = torch.cat( - [rr, torch.ones([nloc, nsel - nnei]).to(rr.device) + rcut], dim=-1 - ) - nlist = torch.cat( - [nlist, torch.ones([nloc, nsel - nnei], dtype=torch.long).to(rr.device)], - dim=-1, - ) - assert list(nlist.shape) == [nloc, nsel] - nlist = nlist.masked_fill((rr > rcut), -1) - - if not distinguish_types: - return nlist - else: - ret_nlist = [] - # nloc x nall - tmp_atype = torch.tile(atype.unsqueeze(0), [nloc, 1]) - mask = nlist == -1 - # nloc x s(nsel) - tnlist = torch.gather( - tmp_atype, - 1, - nlist.masked_fill(mask, 0), - ) - tnlist = tnlist.masked_fill(mask, -1) - snsel = tnlist.shape[1] - for ii, ss in enumerate(sel): - # nloc x s(nsel) - # to int because bool cannot be sort on GPU - pick_mask = (tnlist == ii).to(torch.int32) - # nloc x s(nsel), stable sort, nearer neighbors first - pick_mask, imap = torch.sort( - pick_mask, dim=-1, descending=True, stable=True - ) - # nloc x s(nsel) - inlist = torch.gather(nlist, 1, imap) - inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1) - # nloc x nsel[ii] - ret_nlist.append(torch.split(inlist, [ss, snsel - ss], dim=-1)[0]) - return torch.concat(ret_nlist, dim=-1) - - def build_neighbor_list( coord1: torch.Tensor, atype: torch.Tensor, @@ -227,7 +90,7 @@ def build_neighbor_list( nlist = torch.cat( [ nlist, - torch.ones([batch_size, nloc, nsel - nnei], dtype=torch.long).to( + torch.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype).to( rr.device ), ], @@ -236,35 +99,46 @@ def build_neighbor_list( assert list(nlist.shape) == [batch_size, nloc, nsel] nlist = nlist.masked_fill((rr > rcut), -1) - if not distinguish_types: - return nlist + if distinguish_types: + return nlist_distinguish_types(nlist, atype, sel) else: - ret_nlist = [] - # nloc x nall - tmp_atype = torch.tile(atype.unsqueeze(1), [1, nloc, 1]) - mask = nlist == -1 + return nlist + + +def nlist_distinguish_types( + nlist: torch.Tensor, + atype: torch.Tensor, + sel: List[int], +): + """Given a nlist that does not distinguish atom types, return a nlist that + distinguish atom types. + + """ + nf, nloc, nnei = nlist.shape + ret_nlist = [] + # nloc x nall + tmp_atype = torch.tile(atype.unsqueeze(1), [1, nloc, 1]) + mask = nlist == -1 + # nloc x s(nsel) + tnlist = torch.gather( + tmp_atype, + 2, + nlist.masked_fill(mask, 0), + ) + tnlist = tnlist.masked_fill(mask, -1) + snsel = tnlist.shape[2] + for ii, ss in enumerate(sel): # nloc x s(nsel) - tnlist = torch.gather( - tmp_atype, - 2, - nlist.masked_fill(mask, 0), - ) - tnlist = tnlist.masked_fill(mask, -1) - snsel = tnlist.shape[2] - for ii, ss in enumerate(sel): - # nloc x s(nsel) - # to int because bool cannot be sort on GPU - pick_mask = (tnlist == ii).to(torch.int32) - # nloc x s(nsel), stable sort, nearer neighbors first - pick_mask, imap = torch.sort( - pick_mask, dim=-1, descending=True, stable=True - ) - # nloc x s(nsel) - inlist = torch.gather(nlist, 2, imap) - inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1) - # nloc x nsel[ii] - ret_nlist.append(torch.split(inlist, [ss, snsel - ss], dim=-1)[0]) - return torch.concat(ret_nlist, dim=-1) + # to int because bool cannot be sort on GPU + pick_mask = (tnlist == ii).to(torch.int32) + # nloc x s(nsel), stable sort, nearer neighbors first + pick_mask, imap = torch.sort(pick_mask, dim=-1, descending=True, stable=True) + # nloc x s(nsel) + inlist = torch.gather(nlist, 2, imap) + inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1) + # nloc x nsel[ii] + ret_nlist.append(torch.split(inlist, [ss, snsel - ss], dim=-1)[0]) + return torch.concat(ret_nlist, dim=-1) # build_neighbor_list = torch.vmap( @@ -369,6 +243,8 @@ def extend_coord_with_ghosts( atom type of shape [-1, nloc]. cell : torch.Tensor simulation cell tensor of shape [-1, 9]. + rcut : float + the cutoff radius Returns ------- diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index e83e12f608..2b96925a51 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F -from deepmd.model_format.common import PRECISION_DICT as NP_PRECISION_DICT +from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT from .env import ( DEVICE, diff --git a/source/tests/common/test_model_format_utils.py b/source/tests/common/test_model_format_utils.py index cb85fd2bb2..18a40ffdd9 100644 --- a/source/tests/common/test_model_format_utils.py +++ b/source/tests/common/test_model_format_utils.py @@ -8,17 +8,31 @@ import numpy as np -from deepmd.model_format import ( +from deepmd.dpmodel.descriptor import ( DescrptSeA, +) +from deepmd.dpmodel.fitting import ( + InvarFitting, +) +from deepmd.dpmodel.model import ( + DPAtomicModel, + DPModel, +) +from deepmd.dpmodel.utils import ( EmbeddingNet, EnvMat, FittingNet, - InvarFitting, NativeLayer, NativeNet, NetworkCollection, + build_multiple_neighbor_list, + build_neighbor_list, + extend_coord_with_ghosts, + get_multiple_nlist_key, + inter2phys, load_dp_model, save_dp_model, + to_face_distance, ) @@ -266,7 +280,7 @@ def test_zero_dim(self): ) -class TestDPModel(unittest.TestCase): +class TestSaveLoadDPModel(unittest.TestCase): def setUp(self) -> None: self.w = np.full((3, 2), 3.0) self.b = np.full((3,), 4.0) @@ -285,7 +299,7 @@ def setUp(self) -> None: }, ], } - self.filename = "test_dp_model_format.dp" + self.filename = "test_dp_dpmodel.dp" def test_save_load_model(self): save_dp_model(self.filename, deepcopy(self.model_dict)) @@ -321,7 +335,7 @@ def setUp(self): [ [1, 3, -1, -1, -1, 2, -1], [0, -1, -1, -1, -1, 2, -1], - [0, 1, -1, -1, -1, 0, -1], + [0, 1, -1, -1, -1, -1, -1], ], dtype=int, ).reshape([1, self.nloc, sum(self.sel)]) @@ -490,3 +504,386 @@ def test_get_set(self): ]: ifn0[ii] = foo np.testing.assert_allclose(foo, ifn0[ii]) + + +class TestDPAtomicModel(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_self_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ) + type_map = ["foo", "bar"] + md0 = DPAtomicModel(ds, ft, type_map=type_map) + md1 = DPAtomicModel.deserialize(md0.serialize()) + + ret0 = md0.forward_atomic(self.coord_ext, self.atype_ext, self.nlist) + ret1 = md1.forward_atomic(self.coord_ext, self.atype_ext, self.nlist) + + np.testing.assert_allclose(ret0["energy"], ret1["energy"]) + + +class TestDPModel(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_self_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ) + type_map = ["foo", "bar"] + md0 = DPModel(ds, ft, type_map=type_map) + md1 = DPModel.deserialize(md0.serialize()) + + ret0 = md0.call_lower(self.coord_ext, self.atype_ext, self.nlist) + ret1 = md1.call_lower(self.coord_ext, self.atype_ext, self.nlist) + + np.testing.assert_allclose(ret0["energy"], ret1["energy"]) + np.testing.assert_allclose(ret0["energy_redu"], ret1["energy_redu"]) + + +class TestDPModelFormatNlist(unittest.TestCase): + def setUp(self): + # nloc == 3, nall == 4 + self.nloc = 3 + self.nall = 5 + self.nf, self.nt = 1, 2 + self.coord_ext = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, -2, 0], + [2.3, 0, 0], + ], + dtype=np.float64, + ).reshape([1, self.nall * 3]) + # sel = [5, 2] + self.sel = [5, 2] + self.expected_nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, -1, -1], + ], + dtype=int, + ).reshape([1, self.nloc, sum(self.sel)]) + self.atype_ext = np.array([0, 0, 1, 0, 1], dtype=int).reshape([1, self.nall]) + self.rcut_smth = 0.4 + self.rcut = 2.1 + + nf, nloc, nnei = self.expected_nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ) + type_map = ["foo", "bar"] + self.md = DPModel(ds, ft, type_map=type_map) + + def test_nlist_eq(self): + # n_nnei == nnei + nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, -1, -1], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + self.coord_ext, + self.atype_ext, + nlist, + ) + np.testing.assert_allclose(self.expected_nlist, nlist1) + + def test_nlist_st(self): + # n_nnei < nnei + nlist = np.array( + [ + [1, 3, -1, 2], + [0, -1, -1, 2], + [0, 1, -1, -1], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + self.coord_ext, + self.atype_ext, + nlist, + ) + np.testing.assert_allclose(self.expected_nlist, nlist1) + + def test_nlist_lt(self): + # n_nnei > nnei + nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1, -1, 4], + [0, -1, 4, -1, -1, 2, -1, 3, -1], + [0, 1, -1, -1, -1, 4, -1, -1, 3], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + self.coord_ext, + self.atype_ext, + nlist, + ) + np.testing.assert_allclose(self.expected_nlist, nlist1) + + +class TestRegion(unittest.TestCase): + def setUp(self): + self.cell = np.array( + [[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], + ) + self.cell = np.reshape(self.cell, [1, 1, -1, 3]) + self.cell = np.tile(self.cell, [4, 5, 1, 1]) + self.prec = 1e-8 + + def test_inter_to_phys(self): + rng = np.random.default_rng() + inter = rng.normal(size=[4, 5, 3, 3]) + phys = inter2phys(inter, self.cell) + for ii in range(4): + for jj in range(5): + expected_phys = np.matmul(inter[ii, jj], self.cell[ii, jj]) + np.testing.assert_allclose( + phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec + ) + + def test_to_face_dist(self): + cell0 = self.cell[0][0] + vol = np.linalg.det(cell0) + # area of surfaces xy, xz, yz + sxy = np.linalg.norm(np.cross(cell0[0], cell0[1])) + sxz = np.linalg.norm(np.cross(cell0[0], cell0[2])) + syz = np.linalg.norm(np.cross(cell0[1], cell0[2])) + # vol / area gives distance + dz = vol / sxy + dy = vol / sxz + dx = vol / syz + expected = np.array([dx, dy, dz]) + dists = to_face_distance(self.cell) + for ii in range(4): + for jj in range(5): + np.testing.assert_allclose( + dists[ii][jj], expected, rtol=self.prec, atol=self.prec + ) + + +dtype = np.float64 + + +class TestNeighList(unittest.TestCase): + def setUp(self): + self.nf = 3 + self.nloc = 2 + self.ns = 5 * 5 * 3 + self.nall = self.ns * self.nloc + self.cell = np.array([[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype) + self.icoord = np.array([[0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype) + self.atype = np.array([0, 1], dtype=np.int32) + [self.cell, self.icoord, self.atype] = [ + np.expand_dims(ii, 0) for ii in [self.cell, self.icoord, self.atype] + ] + self.coord = inter2phys(self.icoord, self.cell).reshape([-1, self.nloc * 3]) + self.cell = self.cell.reshape([-1, 9]) + [self.cell, self.coord, self.atype] = [ + np.tile(ii, [self.nf, 1]) for ii in [self.cell, self.coord, self.atype] + ] + self.rcut = 1.01 + self.prec = 1e-10 + self.nsel = [10, 10] + self.ref_nlist = np.array( + [ + [0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1], + ] + ) + + def test_build_notype(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + nlist = build_neighbor_list( + ecoord, + eatype, + self.nloc, + self.rcut, + sum(self.nsel), + distinguish_types=False, + ) + np.testing.assert_allclose(nlist[0], nlist[1]) + nlist_mask = nlist[0] == -1 + nlist_loc = mapping[0][nlist[0]] + nlist_loc[nlist_mask] = -1 + np.testing.assert_allclose( + np.sort(nlist_loc, axis=-1), + np.sort(self.ref_nlist, axis=-1), + ) + + def test_build_type(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + nlist = build_neighbor_list( + ecoord, + eatype, + self.nloc, + self.rcut, + self.nsel, + distinguish_types=True, + ) + np.testing.assert_allclose(nlist[0], nlist[1]) + nlist_mask = nlist[0] == -1 + nlist_loc = mapping[0][nlist[0]] + nlist_loc[nlist_mask] = -1 + for ii in range(2): + np.testing.assert_allclose( + np.sort(np.split(nlist_loc, self.nsel, axis=-1)[ii], axis=-1), + np.sort(np.split(self.ref_nlist, self.nsel, axis=-1)[ii], axis=-1), + ) + + def test_build_multiple_nlist(self): + rcuts = [1.01, 2.01] + nsels = [20, 80] + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, max(rcuts) + ) + nlist1 = build_neighbor_list( + ecoord, + eatype, + self.nloc, + rcuts[1], + nsels[1] - 1, + distinguish_types=False, + ) + pad = -1 * np.ones([self.nf, self.nloc, 1], dtype=nlist1.dtype) + nlist2 = np.concatenate([nlist1, pad], axis=-1) + nlist0 = build_neighbor_list( + ecoord, + eatype, + self.nloc, + rcuts[0], + nsels[0], + distinguish_types=False, + ) + nlists = build_multiple_neighbor_list(ecoord, nlist1, rcuts, nsels) + for dd in range(2): + self.assertEqual( + nlists[get_multiple_nlist_key(rcuts[dd], nsels[dd])].shape[-1], + nsels[dd], + ) + np.testing.assert_allclose( + nlists[get_multiple_nlist_key(rcuts[0], nsels[0])], + nlist0, + ) + np.testing.assert_allclose( + nlists[get_multiple_nlist_key(rcuts[1], nsels[1])], + nlist2, + ) + + def test_extend_coord(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + # expected ncopy x nloc + self.assertEqual(list(ecoord.shape), [self.nf, self.nall * 3]) + self.assertEqual(list(eatype.shape), [self.nf, self.nall]) + self.assertEqual(list(mapping.shape), [self.nf, self.nall]) + # check the nloc part is identical with original coord + np.testing.assert_allclose( + ecoord[:, : self.nloc * 3], self.coord, rtol=self.prec, atol=self.prec + ) + # check the shift vectors are aligned with grid + shift_vec = ( + ecoord.reshape([-1, self.ns, self.nloc, 3]) + - self.coord.reshape([-1, self.nloc, 3])[:, None, :, :] + ) + shift_vec = shift_vec.reshape([-1, self.nall, 3]) + # hack!!! assumes identical cell across frames + shift_vec = np.matmul( + shift_vec, np.linalg.inv(self.cell.reshape([self.nf, 3, 3])[0]) + ) + # nf x nall x 3 + shift_vec = np.round(shift_vec) + # check: identical shift vecs + np.testing.assert_allclose( + shift_vec[0], shift_vec[1], rtol=self.prec, atol=self.prec + ) + # check: shift idx aligned with grid + mm, cc = np.unique(shift_vec[0][:, 0], return_counts=True) + np.testing.assert_allclose( + mm, + np.array([-2, -1, 0, 1, 2], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + np.testing.assert_allclose( + cc, + np.array([30, 30, 30, 30, 30], dtype=np.int32), + rtol=self.prec, + atol=self.prec, + ) + mm, cc = np.unique(shift_vec[1][:, 1], return_counts=True) + np.testing.assert_allclose( + mm, + np.array([-2, -1, 0, 1, 2], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + np.testing.assert_allclose( + cc, + np.array([30, 30, 30, 30, 30], dtype=np.int32), + rtol=self.prec, + atol=self.prec, + ) + mm, cc = np.unique(shift_vec[1][:, 2], return_counts=True) + np.testing.assert_allclose( + mm, + np.array([-1, 0, 1], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + np.testing.assert_allclose( + cc, + np.array([50, 50, 50], dtype=np.int32), + rtol=self.prec, + atol=self.prec, + ) diff --git a/source/tests/common/test_output_def.py b/source/tests/common/test_output_def.py index 4316fa5982..d0cf822247 100644 --- a/source/tests/common/test_output_def.py +++ b/source/tests/common/test_output_def.py @@ -6,7 +6,7 @@ import numpy as np -from deepmd.model_format import ( +from deepmd.dpmodel import ( FittingOutputDef, ModelOutputDef, NativeOP, @@ -14,7 +14,7 @@ fitting_check_output, model_check_output, ) -from deepmd.model_format.output_def import ( +from deepmd.dpmodel.output_def import ( check_var, ) diff --git a/source/tests/pt/test_descriptor_dpa1.py b/source/tests/pt/test_descriptor_dpa1.py index 725369d68d..21a43803c9 100644 --- a/source/tests/pt/test_descriptor_dpa1.py +++ b/source/tests/pt/test_descriptor_dpa1.py @@ -277,7 +277,7 @@ def test_descriptor_block(self): self.assertEqual(descriptor.shape[-1], des.get_dim_out()) self.assertAlmostEqual(6.0, des.get_rcut()) self.assertEqual(30, des.get_nsel()) - self.assertEqual(2, des.get_ntype()) + self.assertEqual(2, des.get_ntypes()) torch.testing.assert_close( descriptor.view(-1), self.ref_d, atol=1e-10, rtol=1e-10 ) @@ -329,7 +329,7 @@ def test_descriptor(self): self.assertEqual(descriptor.shape[-1], des.get_dim_out()) self.assertAlmostEqual(6.0, des.get_rcut()) self.assertEqual(30, des.get_nsel()) - self.assertEqual(2, des.get_ntype()) + self.assertEqual(2, des.get_ntypes()) torch.testing.assert_close( descriptor.view(-1), self.ref_d, atol=1e-10, rtol=1e-10 ) diff --git a/source/tests/pt/test_descriptor_dpa2.py b/source/tests/pt/test_descriptor_dpa2.py index aa6b16964e..e614e64c2f 100644 --- a/source/tests/pt/test_descriptor_dpa2.py +++ b/source/tests/pt/test_descriptor_dpa2.py @@ -224,7 +224,7 @@ def test_descriptor(self): self.assertEqual(descriptor.shape[-1], des.get_dim_out()) self.assertAlmostEqual(6.0, des.get_rcut()) self.assertEqual(30, des.get_nsel()) - self.assertEqual(2, des.get_ntype()) + self.assertEqual(2, des.get_ntypes()) torch.testing.assert_close( descriptor.view(-1), self.ref_d, atol=1e-10, rtol=1e-10 ) diff --git a/source/tests/pt/test_dp_atomic_model.py b/source/tests/pt/test_dp_atomic_model.py new file mode 100644 index 0000000000..2960cb97cc --- /dev/null +++ b/source/tests/pt/test_dp_atomic_model.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel import DPAtomicModel as DPDPAtomicModel +from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA +from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.pt.model.task.ener import ( + InvarFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDPAtomicModel(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_self_consistency(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + # TODO: dirty hack to avoid data stat!!! + md0 = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE) + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + ret0 = md0.forward_atomic(*args) + ret1 = md1.forward_atomic(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) + + def test_dp_consistency(self): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + ds = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = DPInvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ) + type_map = ["foo", "bar"] + md0 = DPDPAtomicModel(ds, ft, type_map=type_map) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE) + args0 = [self.coord_ext, self.atype_ext, self.nlist] + args1 = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + ret0 = md0.forward_atomic(*args0) + ret1 = md1.forward_atomic(*args1) + np.testing.assert_allclose( + ret0["energy"], + to_numpy_array(ret1["energy"]), + ) + + def test_jit(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + # TODO: dirty hack to avoid data stat!!! + md0 = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + torch.jit.script(md0) diff --git a/source/tests/pt/test_dp_model.py b/source/tests/pt/test_dp_model.py new file mode 100644 index 0000000000..79f65d26d6 --- /dev/null +++ b/source/tests/pt/test_dp_model.py @@ -0,0 +1,388 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel import DPModel as DPDPModel +from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA +from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model.ener import ( + DPModel, +) +from deepmd.pt.model.task.ener import ( + InvarFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, + TestCaseSingleFrameWithoutNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDPModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist): + def setUp(self): + TestCaseSingleFrameWithoutNlist.setUp(self) + + def test_self_consistency(self): + nf, nloc = self.atype.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + # TODO: dirty hack to avoid data stat!!! + md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] + ret0 = md0.forward_common(*args) + ret1 = md1.forward_common(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_redu"]), + to_numpy_array(ret1["energy_redu"]), + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_r"]), + to_numpy_array(ret1["energy_derv_r"]), + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c_redu"]), + to_numpy_array(ret1["energy_derv_c_redu"]), + ) + ret0 = md0.forward_common(*args, do_atomic_virial=True) + ret1 = md1.forward_common(*args, do_atomic_virial=True) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c"]), + to_numpy_array(ret1["energy_derv_c"]), + ) + + coord_ext, atype_ext, mapping = extend_coord_with_ghosts( + to_torch_tensor(self.coord), + to_torch_tensor(self.atype), + to_torch_tensor(self.cell), + self.rcut, + ) + nlist = build_neighbor_list( + coord_ext, + atype_ext, + self.nloc, + self.rcut, + self.sel, + distinguish_types=md0.distinguish_types(), + ) + args = [coord_ext, atype_ext, nlist] + ret2 = md0.forward_common_lower(*args, do_atomic_virial=True) + # check the consistency between the reduced virial from + # forward_common and forward_common_lower + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c_redu"]), + to_numpy_array(ret2["energy_derv_c_redu"]), + ) + + def test_dp_consistency(self): + nf, nloc = self.atype.shape + nfp, nap = 2, 3 + ds = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = DPInvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + numb_fparam=nfp, + numb_aparam=nap, + ) + type_map = ["foo", "bar"] + md0 = DPDPModel(ds, ft, type_map=type_map) + md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + + rng = np.random.default_rng() + fparam = rng.normal(size=[self.nf, nfp]) + aparam = rng.normal(size=[self.nf, nloc, nap]) + args0 = [self.coord, self.atype, self.cell] + args1 = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] + kwargs0 = {"fparam": fparam, "aparam": aparam} + kwargs1 = {kk: to_torch_tensor(vv) for kk, vv in kwargs0.items()} + ret0 = md0.call(*args0, **kwargs0) + ret1 = md1.forward_common(*args1, **kwargs1) + np.testing.assert_allclose( + ret0["energy"], + to_numpy_array(ret1["energy"]), + ) + np.testing.assert_allclose( + ret0["energy_redu"], + to_numpy_array(ret1["energy_redu"]), + ) + + def test_dp_consistency_nopbc(self): + nf, nloc = self.atype.shape + nfp, nap = 2, 3 + ds = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = DPInvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + numb_fparam=nfp, + numb_aparam=nap, + ) + type_map = ["foo", "bar"] + md0 = DPDPModel(ds, ft, type_map=type_map) + md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + + rng = np.random.default_rng() + fparam = rng.normal(size=[self.nf, nfp]) + aparam = rng.normal(size=[self.nf, self.nloc, nap]) + args0 = [self.coord, self.atype] + args1 = [to_torch_tensor(ii) for ii in args0] + kwargs0 = {"fparam": fparam, "aparam": aparam} + kwargs1 = {kk: to_torch_tensor(vv) for kk, vv in kwargs0.items()} + ret0 = md0.call(*args0, **kwargs0) + ret1 = md1.forward_common(*args1, **kwargs1) + np.testing.assert_allclose( + ret0["energy"], + to_numpy_array(ret1["energy"]), + ) + np.testing.assert_allclose( + ret0["energy_redu"], + to_numpy_array(ret1["energy_redu"]), + ) + + +class TestDPModelLower(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_self_consistency(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + # TODO: dirty hack to avoid data stat!!! + md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + ret0 = md0.forward_common_lower(*args) + ret1 = md1.forward_common_lower(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_redu"]), + to_numpy_array(ret1["energy_redu"]), + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_r"]), + to_numpy_array(ret1["energy_derv_r"]), + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c_redu"]), + to_numpy_array(ret1["energy_derv_c_redu"]), + ) + ret0 = md0.forward_common_lower(*args, do_atomic_virial=True) + ret1 = md1.forward_common_lower(*args, do_atomic_virial=True) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c"]), + to_numpy_array(ret1["energy_derv_c"]), + ) + + def test_dp_consistency(self): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + ds = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = DPInvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ) + type_map = ["foo", "bar"] + md0 = DPDPModel(ds, ft, type_map=type_map) + md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + args0 = [self.coord_ext, self.atype_ext, self.nlist] + args1 = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + ret0 = md0.call_lower(*args0) + ret1 = md1.forward_common_lower(*args1) + np.testing.assert_allclose( + ret0["energy"], + to_numpy_array(ret1["energy"]), + ) + np.testing.assert_allclose( + ret0["energy_redu"], + to_numpy_array(ret1["energy_redu"]), + ) + + def test_jit(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + # TODO: dirty hack to avoid data stat!!! + md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + torch.jit.script(md0) + + +class TestDPModelFormatNlist(unittest.TestCase): + def setUp(self): + # nloc == 3, nall == 4 + self.nloc = 3 + self.nall = 5 + self.nf, self.nt = 1, 2 + self.coord_ext = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, -2, 0], + [2.3, 0, 0], + ], + dtype=np.float64, + ).reshape([1, self.nall * 3]) + # sel = [5, 2] + self.sel = [5, 2] + self.expected_nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, -1, -1], + ], + dtype=int, + ).reshape([1, self.nloc, sum(self.sel)]) + self.atype_ext = np.array([0, 0, 1, 0, 1], dtype=int).reshape([1, self.nall]) + self.rcut_smth = 0.4 + self.rcut = 2.0 + + nf, nloc, nnei = self.expected_nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + # TODO: dirty hack to avoid data stat!!! + self.md = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + + def test_nlist_eq(self): + # n_nnei == nnei + nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, -1, -1], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + to_torch_tensor(self.coord_ext), + to_torch_tensor(self.atype_ext), + to_torch_tensor(nlist), + ) + np.testing.assert_allclose(self.expected_nlist, to_numpy_array(nlist1)) + + def test_nlist_st(self): + # n_nnei < nnei + nlist = np.array( + [ + [1, 3, -1, 2], + [0, -1, -1, 2], + [0, 1, -1, -1], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + to_torch_tensor(self.coord_ext), + to_torch_tensor(self.atype_ext), + to_torch_tensor(nlist), + ) + np.testing.assert_allclose(self.expected_nlist, to_numpy_array(nlist1)) + + def test_nlist_lt(self): + # n_nnei > nnei + nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1, -1, 4], + [0, -1, 4, -1, -1, 2, -1, 3, -1], + [0, 1, -1, -1, -1, 4, -1, -1, 3], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + to_torch_tensor(self.coord_ext), + to_torch_tensor(self.atype_ext), + to_torch_tensor(nlist), + ) + np.testing.assert_allclose(self.expected_nlist, to_numpy_array(nlist1)) diff --git a/source/tests/pt/test_ener_fitting.py b/source/tests/pt/test_ener_fitting.py index eece8447df..cbddf34dd6 100644 --- a/source/tests/pt/test_ener_fitting.py +++ b/source/tests/pt/test_ener_fitting.py @@ -5,7 +5,7 @@ import numpy as np import torch -from deepmd.model_format import InvarFitting as DPInvarFitting +from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting from deepmd.pt.model.descriptor.se_a import ( DescrptSeA, ) diff --git a/source/tests/pt/test_env_mat.py b/source/tests/pt/test_env_mat.py index f4931e9ecc..b9f0ff1981 100644 --- a/source/tests/pt/test_env_mat.py +++ b/source/tests/pt/test_env_mat.py @@ -5,7 +5,7 @@ import torch try: - from deepmd.model_format import ( + from deepmd.dpmodel import ( EnvMat, ) @@ -47,7 +47,7 @@ def setUp(self): [ [1, 3, -1, -1, -1, 2, -1], [0, -1, -1, -1, -1, 2, -1], - [0, 1, -1, -1, -1, 0, -1], + [0, 1, -1, -1, -1, -1, -1], ], dtype=int, ).reshape([1, self.nloc, sum(self.sel)]) @@ -55,6 +55,27 @@ def setUp(self): self.rcut_smth = 2.2 +class TestCaseSingleFrameWithoutNlist: + def setUp(self): + # nloc == 3, nall == 4 + self.nloc = 3 + self.nf, self.nt = 1, 2 + self.coord = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], + dtype=np.float64, + ).reshape([1, self.nloc * 3]) + self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc]) + self.cell = 2.0 * np.eye(3).reshape([1, 9]) + # sel = [5, 2] + self.sel = [5, 2] + self.rcut = 0.4 + self.rcut_smth = 2.2 + + # to be merged with the tf test case @unittest.skipIf(not support_env_mat, "EnvMat not supported") class TestEnvMat(unittest.TestCase, TestCaseSingleFrameWithNlist): diff --git a/source/tests/pt/test_mlp.py b/source/tests/pt/test_mlp.py index 26f0041bf9..3a78b8294d 100644 --- a/source/tests/pt/test_mlp.py +++ b/source/tests/pt/test_mlp.py @@ -42,7 +42,7 @@ try: - from deepmd.model_format import ( + from deepmd.dpmodel import ( NativeLayer, NativeNet, ) @@ -54,7 +54,7 @@ support_native_net = False try: - from deepmd.model_format import EmbeddingNet as DPEmbeddingNet + from deepmd.dpmodel import EmbeddingNet as DPEmbeddingNet support_embedding_net = True except ModuleNotFoundError: @@ -63,7 +63,7 @@ support_embedding_net = False try: - from deepmd.model_format import FittingNet as DPFittingNet + from deepmd.dpmodel import FittingNet as DPFittingNet support_fitting_net = True except ModuleNotFoundError: diff --git a/source/tests/pt/test_rotation.py b/source/tests/pt/test_rotation.py index 58ec80e0d6..a62e04eb89 100644 --- a/source/tests/pt/test_rotation.py +++ b/source/tests/pt/test_rotation.py @@ -111,22 +111,18 @@ def test_rotation(self): result1 = self.model(**get_data(self.origin_batch)) result2 = self.model(**get_data(self.rotated_batch)) rotation = torch.from_numpy(self.rotation).to(env.DEVICE) - self.assertTrue(result1["energy"] == result2["energy"]) + torch.testing.assert_close(result1["energy"], result2["energy"]) if "force" in result1: - self.assertTrue( - torch.allclose( - result2["force"][0], torch.matmul(rotation, result1["force"][0].T).T - ) + torch.testing.assert_close( + result2["force"][0], torch.matmul(rotation, result1["force"][0].T).T ) if "virial" in result1: - self.assertTrue( - torch.allclose( - result2["virial"][0].view([3, 3]), - torch.matmul( - torch.matmul(rotation, result1["virial"][0].view([3, 3]).T), - rotation.T, - ), - ) + torch.testing.assert_close( + result2["virial"][0].view([3, 3]), + torch.matmul( + torch.matmul(rotation, result1["virial"][0].view([3, 3]).T), + rotation.T, + ), ) diff --git a/source/tests/pt/test_se_e2_a.py b/source/tests/pt/test_se_e2_a.py index 0da80ea1ea..ec49725929 100644 --- a/source/tests/pt/test_se_e2_a.py +++ b/source/tests/pt/test_se_e2_a.py @@ -6,8 +6,8 @@ import torch try: - # from deepmd.model_format import PRECISION_DICT as DP_PRECISION_DICT - from deepmd.model_format import DescrptSeA as DPDescrptSeA + # from deepmd.dpmodel import PRECISION_DICT as DP_PRECISION_DICT + from deepmd.dpmodel import DescrptSeA as DPDescrptSeA support_se_e2_a = True except ModuleNotFoundError: diff --git a/source/tests/pt/test_utils.py b/source/tests/pt/test_utils.py index 9c9a9479ad..145fe6c510 100644 --- a/source/tests/pt/test_utils.py +++ b/source/tests/pt/test_utils.py @@ -24,7 +24,7 @@ def test_to_numpy(self): onk = to_numpy_array(bar) self.assertEqual(onk.dtype, npp) with self.assertRaises(ValueError) as ee: - foo = foo.astype(np.int32) + foo = foo.astype(np.int8) bar = to_torch_tensor(foo) with self.assertRaises(ValueError) as ee: bar = to_torch_tensor(foo)