diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py
index ce1a6708e6..e3d6d8bcd1 100644
--- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py
+++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from abc import (
     ABC,
-    abstractclassmethod,
     abstractmethod,
 )
 from typing import (
@@ -13,6 +12,10 @@
 from deepmd.dpmodel.output_def import (
     FittingOutputDef,
 )
+from deepmd.utils.plugin import (
+    PluginVariant,
+    make_plugin_registry,
+)
 
 
 def make_base_atomic_model(
@@ -31,7 +34,7 @@ def make_base_atomic_model(
 
     """
 
-    class BAM(ABC):
+    class BAM(ABC, PluginVariant, make_plugin_registry("atomic model")):
         """Base Atomic Model provides the interfaces of an atomic model."""
 
         @abstractmethod
@@ -128,8 +131,9 @@ def fwd(
         def serialize(self) -> dict:
             pass
 
-        @abstractclassmethod
-        def deserialize(cls):
+        @classmethod
+        @abstractmethod
+        def deserialize(cls, data: dict):
             pass
 
         def do_grad_r(
diff --git a/deepmd/dpmodel/model/__init__.py b/deepmd/dpmodel/model/__init__.py
index cb796e6d35..c1ff15ab0d 100644
--- a/deepmd/dpmodel/model/__init__.py
+++ b/deepmd/dpmodel/model/__init__.py
@@ -8,6 +8,8 @@
 according to output variable definition
 `deepmd.dpmodel.OutputVariableDef`.
 
+All models should inherit from :class:`deepmd.dpmodel.model.base_model.BaseModel`.
+Models generated by `make_model` already do so.
 """
 
 from .dp_model import (
diff --git a/deepmd/dpmodel/model/dp_model.py b/deepmd/dpmodel/model/dp_model.py
index 15f9027d4c..8d84c435b4 100644
--- a/deepmd/dpmodel/model/dp_model.py
+++ b/deepmd/dpmodel/model/dp_model.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+
 from deepmd.dpmodel.atomic_model import (
     DPAtomicModel,
 )
@@ -17,7 +18,7 @@
 
 # use "class" to resolve "Variable not allowed in type expression"
 @BaseModel.register("standard")
-class DPModel(make_model(DPAtomicModel), BaseModel):
+class DPModel(make_model(DPAtomicModel)):
     @classmethod
     def update_sel(cls, global_jdata: dict, local_jdata: dict):
         """Update the selection and perform neighbor statistics.
diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py
index d1f671c8de..6022fd3e73 100644
--- a/deepmd/dpmodel/model/make_model.py
+++ b/deepmd/dpmodel/model/make_model.py
@@ -4,10 +4,14 @@
     List,
     Optional,
     Tuple,
+    Type,
 )
 
 import numpy as np
 
+from deepmd.dpmodel.atomic_model.base_atomic_model import (
+    BaseAtomicModel,
+)
 from deepmd.dpmodel.common import (
     GLOBAL_ENER_FLOAT_PRECISION,
     GLOBAL_NP_FLOAT_PRECISION,
@@ -15,7 +19,11 @@
     RESERVED_PRECISON_DICT,
     NativeOP,
 )
+from deepmd.dpmodel.model.base_model import (
+    BaseModel,
+)
 from deepmd.dpmodel.output_def import (
+    FittingOutputDef,
     ModelOutputDef,
     OutputVariableCategory,
     OutputVariableOperation,
@@ -34,7 +42,7 @@
 )
 
 
-def make_model(T_AtomicModel):
+def make_model(T_AtomicModel: Type[BaseAtomicModel]):
     """Make a model as a derived class of an atomic model.
 
     The model provide two interfaces.
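A note on the `deserialize` change above: `abc.abstractclassmethod` has been deprecated since Python 3.3, because stacking `@classmethod` on top of `@abstractmethod` now composes correctly (with `@classmethod` outermost). A minimal sketch of the replacement pattern; the `Serializable`/`Point` names are illustrative only, not part of this PR:

```python
from abc import ABC, abstractmethod


class Serializable(ABC):
    """Toy stand-in for the BAM interface."""

    @classmethod
    @abstractmethod  # @classmethod must be the outermost decorator
    def deserialize(cls, data: dict) -> "Serializable":
        """Reconstruct an instance from a plain dict."""


class Point(Serializable):
    def __init__(self, x: float) -> None:
        self.x = x

    @classmethod
    def deserialize(cls, data: dict) -> "Point":
        return cls(data["x"])


# The abstract classmethod is enforced as expected: Serializable() raises
# TypeError, while the concrete subclass round-trips normally.
p = Point.deserialize({"x": 1.0})
```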
@@ -57,16 +65,18 @@ def make_model(T_AtomicModel):
 
     """
 
-    class CM(T_AtomicModel, NativeOP):
+    class CM(NativeOP, BaseModel):
         def __init__(
             self,
             *args,
+            # underscore to prevent conflict with normal inputs
+            atomic_model_: Optional[T_AtomicModel] = None,
             **kwargs,
         ):
-            super().__init__(
-                *args,
-                **kwargs,
-            )
+            if atomic_model_ is not None:
+                self.atomic_model: T_AtomicModel = atomic_model_
+            else:
+                self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs)
             self.precision_dict = PRECISION_DICT
             self.reverse_precision_dict = RESERVED_PRECISON_DICT
             self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION
@@ -208,7 +218,7 @@ def call_lower(
                 extended_coord, fparam=fparam, aparam=aparam
             )
             del extended_coord, fparam, aparam
-            atomic_ret = self.forward_common_atomic(
+            atomic_ret = self.atomic_model.forward_common_atomic(
                 cc_ext,
                 extended_atype,
                 nlist,
@@ -377,4 +387,93 @@ def _format_nlist(
             assert ret.shape[-1] == nnei
             return ret
 
+        def do_grad_r(
+            self,
+            var_name: Optional[str] = None,
+        ) -> bool:
+            """Tell if the output variable `var_name` is r_differentiable.
+            If var_name is None, returns whether any of the variables is r_differentiable.
+            """
+            return self.atomic_model.do_grad_r(var_name)
+
+        def do_grad_c(
+            self,
+            var_name: Optional[str] = None,
+        ) -> bool:
+            """Tell if the output variable `var_name` is c_differentiable.
+            If var_name is None, returns whether any of the variables is c_differentiable.
+            """
+            return self.atomic_model.do_grad_c(var_name)
+
+        def serialize(self) -> dict:
+            return self.atomic_model.serialize()
+
+        @classmethod
+        def deserialize(cls, data) -> "CM":
+            return cls(atomic_model_=T_AtomicModel.deserialize(data))
+
+        def get_dim_fparam(self) -> int:
+            """Get the number (dimension) of frame parameters of this atomic model."""
+            return self.atomic_model.get_dim_fparam()
+
+        def get_dim_aparam(self) -> int:
+            """Get the number (dimension) of atomic parameters of this atomic model."""
+            return self.atomic_model.get_dim_aparam()
+
+        def get_sel_type(self) -> List[int]:
+            """Get the selected atom types of this model.
+
+            Only atoms with selected atom types have atomic contribution
+            to the result of the model.
+            If returning an empty list, all atom types are selected.
+            """
+            return self.atomic_model.get_sel_type()
+
+        def is_aparam_nall(self) -> bool:
+            """Check whether the shape of atomic parameters is (nframes, nall, ndim).
+
+            If False, the shape is (nframes, nloc, ndim).
+            """
+            return self.atomic_model.is_aparam_nall()
+
+        def get_rcut(self) -> float:
+            """Get the cut-off radius."""
+            return self.atomic_model.get_rcut()
+
+        def get_type_map(self) -> List[str]:
+            """Get the type map."""
+            return self.atomic_model.get_type_map()
+
+        def get_nsel(self) -> int:
+            """Returns the total number of selected neighboring atoms in the cut-off radius."""
+            return self.atomic_model.get_nsel()
+
+        def get_nnei(self) -> int:
+            """Returns the total number of selected neighboring atoms in the cut-off radius."""
+            return self.atomic_model.get_nnei()
+
+        def get_model_def_script(self) -> str:
+            """Get the model definition script."""
+            return self.atomic_model.get_model_def_script()
+
+        def get_sel(self) -> List[int]:
+            """Returns the number of selected atoms for each type."""
+            return self.atomic_model.get_sel()
+
+        def mixed_types(self) -> bool:
+            """If true, the model
+            1. assumes total number of atoms aligned across frames;
+            2. uses a neighbor list that does not distinguish different atomic types.
+
+            If false, the model
+            1. assumes total number of atoms of each atom type aligned across frames;
+            2. uses a neighbor list that distinguishes different atomic types.
+
+            """
+            return self.atomic_model.mixed_types()
+
+        def atomic_output_def(self) -> FittingOutputDef:
+            """Get the output def of the atomic model."""
+            return self.atomic_model.atomic_output_def()
+
     return CM
diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py
index 8180c48c81..d045220b6e 100644
--- a/deepmd/pt/model/atomic_model/base_atomic_model.py
+++ b/deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -21,6 +21,9 @@
     AtomExcludeMask,
     PairExcludeMask,
 )
+from deepmd.utils.path import (
+    DPPath,
+)
 
 BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)
 
@@ -55,12 +58,6 @@ def reinit_pair_exclude(
         else:
             self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)
 
-    # export public methods that are not abstract
-    get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel)
-    get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei)
-    get_ntypes = torch.jit.export(BaseAtomicModel_.get_ntypes)
-
-    @torch.jit.export
     def get_model_def_script(self) -> str:
         return self.model_def_script
 
@@ -126,3 +123,25 @@ def serialize(self) -> dict:
             "atom_exclude_types": self.atom_exclude_types,
             "pair_exclude_types": self.pair_exclude_types,
         }
+
+    def compute_or_load_stat(
+        self,
+        sampled_func,
+        stat_file_path: Optional[DPPath] = None,
+    ):
+        """
+        Compute or load the statistics parameters of the model,
+        such as mean and standard deviation of descriptors or the energy bias of the fitting net.
+        When `sampled_func` is provided, all the statistics parameters will be calculated (or re-calculated for update),
+        and saved in the `stat_file_path`(s).
+        When `sampled_func` is not provided, it will check the existence of `stat_file_path`(s)
+        and load the calculated statistics parameters.
+
+        Parameters
+        ----------
+        sampled_func
+            The sampled data frames from different data systems.
+        stat_file_path
+            The path to the statistics files.
+        """
+        raise NotImplementedError
diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py
index cad1e1cc88..ec08850524 100644
--- a/deepmd/pt/model/atomic_model/dp_atomic_model.py
+++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -223,17 +223,14 @@ def wrapped_sampler():
         if self.fitting_net is not None:
             self.fitting_net.compute_output_stats(wrapped_sampler, stat_file_path)
 
-    @torch.jit.export
     def get_dim_fparam(self) -> int:
         """Get the number (dimension) of frame parameters of this atomic model."""
         return self.fitting_net.get_dim_fparam()
 
-    @torch.jit.export
     def get_dim_aparam(self) -> int:
         """Get the number (dimension) of atomic parameters of this atomic model."""
         return self.fitting_net.get_dim_aparam()
 
-    @torch.jit.export
     def get_sel_type(self) -> List[int]:
         """Get the selected atom types of this model.
 
@@ -243,7 +240,6 @@ def get_sel_type(self) -> List[int]:
         """
         return self.fitting_net.get_sel_type()
 
-    @torch.jit.export
     def is_aparam_nall(self) -> bool:
         """Check whether the shape of atomic parameters is (nframes, nall, ndim).
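The hunks above are the heart of the refactor: `CM` no longer inherits from the atomic model, it owns one in `self.atomic_model` and forwards the shared getters, and `deserialize` rebuilds the wrapped instance through the keyword-only `atomic_model_` argument (the trailing underscore keeps it from colliding with normal constructor inputs, which continue to flow through `*args`/`**kwargs`). A stripped-down sketch of that wiring, using hypothetical `ToyAtomicModel`/`ToyModel` names in place of the generated classes:

```python
from typing import Optional


class ToyAtomicModel:
    def __init__(self, rcut: float) -> None:
        self.rcut = rcut

    def get_rcut(self) -> float:
        return self.rcut

    def serialize(self) -> dict:
        return {"rcut": self.rcut}

    @classmethod
    def deserialize(cls, data: dict) -> "ToyAtomicModel":
        return cls(**data)


class ToyModel:
    def __init__(
        self,
        *args,
        # underscore to prevent conflict with normal inputs
        atomic_model_: Optional[ToyAtomicModel] = None,
        **kwargs,
    ) -> None:
        if atomic_model_ is not None:
            self.atomic_model = atomic_model_
        else:
            self.atomic_model = ToyAtomicModel(*args, **kwargs)

    # explicit delegation replaces the old inherited method lookup
    def get_rcut(self) -> float:
        return self.atomic_model.get_rcut()

    def serialize(self) -> dict:
        return self.atomic_model.serialize()

    @classmethod
    def deserialize(cls, data: dict) -> "ToyModel":
        return cls(atomic_model_=ToyAtomicModel.deserialize(data))


m = ToyModel(rcut=6.0)
assert ToyModel.deserialize(m.serialize()).get_rcut() == 6.0
```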
diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py
index 66d19c0a02..3fb3ee90dd 100644
--- a/deepmd/pt/model/atomic_model/linear_atomic_model.py
+++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -96,12 +96,10 @@ def mixed_types(self) -> bool:
         """
         return True
 
-    @torch.jit.export
     def get_rcut(self) -> float:
         """Get the cut-off radius."""
         return max(self.get_model_rcuts())
 
-    @torch.jit.export
     def get_type_map(self) -> List[str]:
         """Get the type map."""
         return self.type_map
@@ -292,18 +290,15 @@ def _compute_weight(
         """This should be a list of user defined weights that matches the number of models to be combined."""
         raise NotImplementedError
 
-    @torch.jit.export
     def get_dim_fparam(self) -> int:
         """Get the number (dimension) of frame parameters of this atomic model."""
         # tricky...
         return max([model.get_dim_fparam() for model in self.models])
 
-    @torch.jit.export
     def get_dim_aparam(self) -> int:
         """Get the number (dimension) of atomic parameters of this atomic model."""
         return max([model.get_dim_aparam() for model in self.models])
 
-    @torch.jit.export
     def get_sel_type(self) -> List[int]:
         """Get the selected atom types of this model.
 
@@ -324,7 +319,6 @@
             )
         ).tolist()
 
-    @torch.jit.export
     def is_aparam_nall(self) -> bool:
         """Check whether the shape of atomic parameters is (nframes, nall, ndim).
diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py
index 19a67fc8ff..db0a2efa4a 100644
--- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py
+++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py
@@ -139,11 +139,9 @@ def fitting_output_def(self) -> FittingOutputDef:
             ]
         )
 
-    @torch.jit.export
     def get_rcut(self) -> float:
         return self.rcut
 
-    @torch.jit.export
     def get_type_map(self) -> List[str]:
         return self.type_map
 
@@ -454,17 +452,14 @@ def _calculate_ener(coef: torch.Tensor, uu: torch.Tensor) -> torch.Tensor:
         ener = etmp * uu + a0  # this energy has the extrapolated value when rcut > rmax
         return ener
 
-    @torch.jit.export
     def get_dim_fparam(self) -> int:
         """Get the number (dimension) of frame parameters of this atomic model."""
         return 0
 
-    @torch.jit.export
     def get_dim_aparam(self) -> int:
         """Get the number (dimension) of atomic parameters of this atomic model."""
         return 0
 
-    @torch.jit.export
     def get_sel_type(self) -> List[int]:
         """Get the selected atom types of this model.
 
@@ -474,7 +469,6 @@ def get_sel_type(self) -> List[int]:
         """
         return []
 
-    @torch.jit.export
     def is_aparam_nall(self) -> bool:
         """Check whether the shape of atomic parameters is (nframes, nall, ndim).
diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py
index 8e4352e60c..3098dc7677 100644
--- a/deepmd/pt/model/model/__init__.py
+++ b/deepmd/pt/model/model/__init__.py
@@ -7,6 +7,8 @@
 communication of the atomic properties
 according to output variable definition
 `deepmd.dpmodel.OutputVariableDef`.
 
+All models should inherit from :class:`deepmd.pt.model.model.model.BaseModel`.
+Models generated by `make_model` already do so.
""" import copy @@ -147,8 +149,8 @@ def get_standard_model(model_params): pair_exclude_types = model_params.get("pair_exclude_types", []) model = DPModel( - descriptor, - fitting, + descriptor=descriptor, + fitting=fitting, type_map=model_params["type_map"], atom_exclude_types=atom_exclude_types, pair_exclude_types=pair_exclude_types, diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index 8b6f2c47c1..45b120771b 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -38,7 +38,7 @@ def forward( aparam=aparam, do_atomic_virial=do_atomic_virial, ) - if self.fitting_net is not None: + if self.get_fitting_net() is not None: model_predict = {} model_predict["dipole"] = model_ret["dipole"] model_predict["global_dipole"] = model_ret["dipole_redu"] @@ -77,7 +77,7 @@ def forward_lower( aparam=aparam, do_atomic_virial=do_atomic_virial, ) - if self.fitting_net is not None: + if self.get_fitting_net() is not None: model_predict = {} model_predict["dipole"] = model_ret["dipole"] model_predict["global_dipole"] = model_ret["dipole_redu"] diff --git a/deepmd/pt/model/model/dp_model.py b/deepmd/pt/model/model/dp_model.py index 0df45d4f84..138398539a 100644 --- a/deepmd/pt/model/model/dp_model.py +++ b/deepmd/pt/model/model/dp_model.py @@ -1,4 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + Optional, +) + +import torch + from deepmd.pt.model.atomic_model import ( DPAtomicModel, ) @@ -25,8 +32,16 @@ @BaseModel.register("standard") -class DPModel(make_model(DPAtomicModel), BaseModel): - def __new__(cls, descriptor, fitting, *args, **kwargs): +class DPModel(make_model(DPAtomicModel)): + def __new__( + cls, + descriptor=None, + fitting=None, + *args, + # disallow positional atomic_model_ + atomic_model_: Optional[DPAtomicModel] = None, + **kwargs, + ): from deepmd.pt.model.model.dipole_model import ( DipoleModel, ) @@ -37,6 +52,11 @@ def __new__(cls, descriptor, fitting, *args, **kwargs): PolarModel, ) + if atomic_model_ is not None: + fitting = atomic_model_.fitting_net + else: + assert fitting is not None, "fitting network is not provided" + # according to the fitting network to decide the type of the model if cls is DPModel: # map fitting to model @@ -67,3 +87,30 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): global_jdata, local_jdata["descriptor"] ) return local_jdata_cpy + + def get_fitting_net(self): + """Get the fitting network.""" + return self.atomic_model.fitting_net + + def get_descriptor(self): + """Get the descriptor.""" + return self.atomic_model.descriptor + + def forward( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + # directly call the forward_common method when no specific transform rule + return self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index fdf9334119..bbc82b8d77 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -24,7 +24,7 @@ @BaseModel.register("zbl") -class DPZBLModel(DPZBLModel_, BaseModel): +class DPZBLModel(DPZBLModel_): model_type = "ener" def __init__( diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index cd4f78a2e2..5217293623 
--- a/deepmd/pt/model/model/ener_model.py
+++ b/deepmd/pt/model/model/ener_model.py
@@ -38,7 +38,7 @@ def forward(
             aparam=aparam,
             do_atomic_virial=do_atomic_virial,
         )
-        if self.fitting_net is not None:
+        if self.get_fitting_net() is not None:
             model_predict = {}
             model_predict["atom_energy"] = model_ret["energy"]
             model_predict["energy"] = model_ret["energy_redu"]
@@ -79,7 +79,7 @@ def forward_lower(
             aparam=aparam,
             do_atomic_virial=do_atomic_virial,
         )
-        if self.fitting_net is not None:
+        if self.get_fitting_net() is not None:
             model_predict = {}
             model_predict["atom_energy"] = model_ret["energy"]
             model_predict["energy"] = model_ret["energy_redu"]
diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py
index f9daa916a8..0a5f286040 100644
--- a/deepmd/pt/model/model/make_model.py
+++ b/deepmd/pt/model/model/make_model.py
@@ -4,6 +4,7 @@
     List,
     Optional,
     Tuple,
+    Type,
 )
 
 import torch
@@ -12,10 +13,17 @@
     ModelOutputDef,
 )
 from deepmd.dpmodel.output_def import (
+    FittingOutputDef,
     OutputVariableCategory,
     OutputVariableOperation,
     check_operation_applied,
 )
+from deepmd.pt.model.atomic_model.base_atomic_model import (
+    BaseAtomicModel,
+)
+from deepmd.pt.model.model.model import (
+    BaseModel,
+)
 from deepmd.pt.model.model.transform_output import (
     communicate_extended_output,
     fit_output_to_model_output,
@@ -30,9 +38,12 @@
     extend_input_and_build_neighbor_list,
     nlist_distinguish_types,
 )
+from deepmd.utils.path import (
+    DPPath,
+)
 
 
-def make_model(T_AtomicModel):
+def make_model(T_AtomicModel: Type[BaseAtomicModel]):
     """Make a model as a derived class of an atomic model.
 
     The model provide two interfaces.
 
@@ -55,16 +66,19 @@
 
     """
 
-    class CM(T_AtomicModel):
+    class CM(BaseModel):
         def __init__(
             self,
             *args,
+            # underscore to prevent conflict with normal inputs
+            atomic_model_: Optional[T_AtomicModel] = None,
             **kwargs,
         ):
-            super().__init__(
-                *args,
-                **kwargs,
-            )
+            super().__init__(*args, **kwargs)
+            if atomic_model_ is not None:
+                self.atomic_model: T_AtomicModel = atomic_model_
+            else:
+                self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs)
             self.precision_dict = PRECISION_DICT
             self.reverse_precision_dict = RESERVED_PRECISON_DICT
             self.global_pt_float_precision = GLOBAL_PT_FLOAT_PRECISION
@@ -203,7 +217,7 @@ def forward_common_lower(
                 extended_coord, fparam=fparam, aparam=aparam
             )
             del extended_coord, fparam, aparam
-            atomic_ret = self.forward_common_atomic(
+            atomic_ret = self.atomic_model.forward_common_atomic(
                 cc_ext,
                 extended_atype,
                 nlist,
@@ -382,4 +396,110 @@ def _format_nlist(
             assert nlist.shape[-1] == nnei
             return nlist
 
+        def do_grad_r(
+            self,
+            var_name: Optional[str] = None,
+        ) -> bool:
+            """Tell if the output variable `var_name` is r_differentiable.
+            If var_name is None, returns whether any of the variables is r_differentiable.
+            """
+            return self.atomic_model.do_grad_r(var_name)
+
+        def do_grad_c(
+            self,
+            var_name: Optional[str] = None,
+        ) -> bool:
+            """Tell if the output variable `var_name` is c_differentiable.
+            If var_name is None, returns whether any of the variables is c_differentiable.
+ """ + return self.atomic_model.do_grad_c(var_name) + + def serialize(self) -> dict: + return self.atomic_model.serialize() + + @classmethod + def deserialize(cls, data) -> "CM": + return cls(atomic_model_=T_AtomicModel.deserialize(data)) + + @torch.jit.export + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.atomic_model.get_dim_fparam() + + @torch.jit.export + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.atomic_model.get_dim_aparam() + + @torch.jit.export + def get_sel_type(self) -> List[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.atomic_model.get_sel_type() + + @torch.jit.export + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + + If False, the shape is (nframes, nloc, ndim). + """ + return self.atomic_model.is_aparam_nall() + + @torch.jit.export + def get_rcut(self) -> float: + """Get the cut-off radius.""" + return self.atomic_model.get_rcut() + + @torch.jit.export + def get_type_map(self) -> List[str]: + """Get the type map.""" + return self.atomic_model.get_type_map() + + @torch.jit.export + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.atomic_model.get_nsel() + + @torch.jit.export + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.atomic_model.get_nnei() + + @torch.jit.export + def get_model_def_script(self) -> str: + """Get the model definition script.""" + return self.atomic_model.get_model_def_script() + + def atomic_output_def(self) -> FittingOutputDef: + """Get the output def of the atomic model.""" + return self.atomic_model.atomic_output_def() + + def compute_or_load_stat( + self, + sampled_func, + stat_file_path: Optional[DPPath] = None, + ): + """Compute or load the statistics.""" + return self.atomic_model.compute_or_load_stat(sampled_func, stat_file_path) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.atomic_model.get_sel() + + def mixed_types(self) -> bool: + """If true, the model + 1. assumes total number of atoms aligned across frames; + 2. uses a neighbor list that does not distinguish different atomic types. + + If false, the model + 1. assumes total number of atoms of each atom type aligned across frames; + 2. uses a neighbor list that distinguishes different atomic types. + + """ + return self.atomic_model.mixed_types() + return CM diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index e32d2f307d..3d4618449a 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -3,6 +3,8 @@ Optional, ) +import torch + from deepmd.dpmodel.model.base_model import ( make_base_model, ) @@ -11,61 +13,14 @@ ) -# trick: torch.nn.Module should not be inherbited here, otherwise, -# the abstract method will override the method from the atomic model -# as Python resolves method lookups using the C3 linearisation. 
-# See https://stackoverflow.com/a/47117600/9567349
-# Take an example, this is the situation for only inheriting make_model():
-# torch.nn.Module        BaseAtomicModel        make_model()
-#        |                       |                    |
-#        -------------------------                    |
-#                    |                                |
-#             DPAtomicModel                       BaseModel
-#                    |                                |
-#         make_model(DPAtomicModel)                   |
-#                    |                                |
-#                    ----------------------------------
-#                                     |
-#                                  DPModel
-#
-# The order is: DPModel -> make_model(DPAtomicModel) -> DPAtomicModel ->
-# torch.nn.Module -> BaseAtomicModel -> BaseModel -> make_model()
-#
-# However, if BaseModel also inherbits from torch.nn.Module:
-# torch.nn.Module                      make_model()
-#        |                                  |
-#        |---------------------------       |
-#        |                          |       |
-#        |       BaseAtomicModel    |       |
-#        |               |          |       |
-#        |-------------          ----------
-#        |            |
-#   DPAtomicModel  BaseModel
-#        |            |
-#        |            |
-# make_model(DPAtomicModel)  |
-#        |            |
-#        |            |
-#        --------------------------------
-#                     |
-#                     |
-#                  DPModel
-#
-# The order is DPModel -> make_model(DPAtomicModel) -> DPAtomicModel ->
-# BaseModel -> torch.nn.Module -> BaseAtomicModel -> make_model()
-# BaseModel has higher proirity than BaseAtomicModel, which is not what
-# we want.
-# Alternatively, we can also make BaseAtomicModel in front of torch.nn.Module
-# in DPAtomicModel (and other classes), but this requires the developer aware
-# of it when developing it...
-class BaseModel(make_base_model()):
+class BaseModel(torch.nn.Module, make_base_model()):
     def __init__(self, *args, **kwargs):
         """Construct a basic model for different tasks."""
-        super().__init__(*args, **kwargs)
+        torch.nn.Module.__init__(self)
 
     def compute_or_load_stat(
         self,
-        sampled,
+        sampled_func,
         stat_file_path: Optional[DPPath] = None,
     ):
         """
@@ -78,7 +33,7 @@ def compute_or_load_stat(
 
         Parameters
         ----------
-        sampled
+        sampled_func
            The sampled data frames from different data systems.
         stat_file_path
             The path to the statistics files.
diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py
index bf430c6706..403058aa47 100644
--- a/deepmd/pt/model/model/polar_model.py
+++ b/deepmd/pt/model/model/polar_model.py
@@ -38,7 +38,7 @@ def forward(
             aparam=aparam,
             do_atomic_virial=do_atomic_virial,
         )
-        if self.fitting_net is not None:
+        if self.get_fitting_net() is not None:
             model_predict = {}
             model_predict["polar"] = model_ret["polar"]
             model_predict["global_polar"] = model_ret["polar_redu"]
@@ -69,7 +69,7 @@ def forward_lower(
             aparam=aparam,
             do_atomic_virial=do_atomic_virial,
         )
-        if self.fitting_net is not None:
+        if self.get_fitting_net() is not None:
             model_predict = {}
             model_predict["polar"] = model_ret["polar"]
             model_predict["global_polar"] = model_ret["polar_redu"]
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index fb28f0c4f2..b20d80c629 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -520,8 +520,11 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
                     model_params["type_map"],
                     model_params["new_type_map"],
                 )
-                if hasattr(self.model, "fitting_net"):
-                    self.model.fitting_net.change_energy_bias(
+                # TODO: need an interface instead of fetching fitting_net!!!!!!!!!
+                if hasattr(self.model, "atomic_model") and hasattr(
+                    self.model.atomic_model, "fitting_net"
+                ):
+                    self.model.atomic_model.fitting_net.change_energy_bias(
                         config,
                         self.model,
                         old_type_map,
@@ -531,7 +534,7 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
                     )
                 elif isinstance(self.model, DPZBLModel):
                     # need to updated
-                    self.model.change_energy_bias()
+                    self.model.atomic_model.change_energy_bias()
                 else:
                     raise NotImplementedError
         if init_frz_model is not None:
diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py
index c1040fb9e3..061cd777db 100644
--- a/deepmd/pt/train/wrapper.py
+++ b/deepmd/pt/train/wrapper.py
@@ -75,12 +75,12 @@ def share_params(self, shared_links, resume=False):
             shared_level_base = shared_base["shared_level"]
             if "descriptor" in class_type_base:
                 if class_type_base == "descriptor":
-                    base_class = self.model[model_key_base].__getattr__("descriptor")
+                    base_class = self.model[model_key_base].get_descriptor()
                 elif "hybrid" in class_type_base:
                     hybrid_index = int(class_type_base.split("_")[-1])
                     base_class = (
                         self.model[model_key_base]
-                        .__getattr__("descriptor")
+                        .get_descriptor()
                         .descriptor_list[hybrid_index]
                     )
                 else:
@@ -96,14 +96,12 @@ def share_params(self, shared_links, resume=False):
                         "descriptor" in class_type_link
                     ), f"Class type mismatched: {class_type_base} vs {class_type_link}!"
                     if class_type_link == "descriptor":
-                        link_class = self.model[model_key_link].__getattr__(
-                            "descriptor"
-                        )
+                        link_class = self.model[model_key_link].get_descriptor()
                     elif "hybrid" in class_type_link:
                         hybrid_index = int(class_type_link.split("_")[-1])
                         link_class = (
                             self.model[model_key_link]
-                            .__getattr__("descriptor")
+                            .get_descriptor()
                             .descriptor_list[hybrid_index]
                         )
                     else:
diff --git a/source/tests/pt/model/test_linear_atomic_model.py b/source/tests/pt/model/test_linear_atomic_model.py
index adc682a41f..7f24ffdc53 100644
--- a/source/tests/pt/model/test_linear_atomic_model.py
+++ b/source/tests/pt/model/test_linear_atomic_model.py
@@ -178,11 +178,13 @@ def test_self_consistency(self):
 
     def test_jit(self):
         md1 = torch.jit.script(self.md1)
-        self.assertEqual(md1.get_rcut(), self.rcut)
-        self.assertEqual(md1.get_type_map(), ["foo", "bar"])
+        # the atomic model no longer exports these methods
+        # self.assertEqual(md1.get_rcut(), self.rcut)
+        # self.assertEqual(md1.get_type_map(), ["foo", "bar"])
         md3 = torch.jit.script(self.md3)
-        self.assertEqual(md3.get_rcut(), self.rcut)
-        self.assertEqual(md3.get_type_map(), ["foo", "bar"])
+        # the atomic model no longer exports these methods
+        # self.assertEqual(md3.get_rcut(), self.rcut)
+        # self.assertEqual(md3.get_type_map(), ["foo", "bar"])
 
 
 class TestRemmapMethod(unittest.TestCase):
diff --git a/source/tests/pt/model/test_model.py b/source/tests/pt/model/test_model.py
index f42c11aa4c..5a30de7ac8 100644
--- a/source/tests/pt/model/test_model.py
+++ b/source/tests/pt/model/test_model.py
@@ -60,13 +60,13 @@ def torch2tf(torch_name, last_layer_id=None):
     fields = torch_name.split(".")
-    offset = int(fields[2] == "networks")
+    offset = int(fields[3] == "networks") + 1
     element_id = int(fields[2 + offset])
-    if fields[0] == "descriptor":
+    if fields[1] == "descriptor":
         layer_id = int(fields[4 + offset]) + 1
         weight_type = fields[5 + offset]
         ret = "filter_type_all/%s_%d_%d:0" % (weight_type, layer_id, element_id)
-    elif fields[0] == "fitting_net":
+    elif fields[1] == "fitting_net":
         layer_id = int(fields[4 + offset])
         weight_type = fields[5 + offset]
         if layer_id != last_layer_id:
@@ -301,7 +301,7 @@ def test_consistency(self):
         )
 
         # Keep statistics consistency between 2 implentations
-        my_em = my_model.descriptor
+        my_em = my_model.get_descriptor()
         mean = stat_dict["descriptor.mean"].reshape([self.ntypes, my_em.get_nsel(), 4])
         stddev = stat_dict["descriptor.stddev"].reshape(
             [self.ntypes, my_em.get_nsel(), 4]
         )
@@ -310,7 +310,7 @@
             torch.tensor(mean, device=DEVICE),
             torch.tensor(stddev, device=DEVICE),
         )
-        my_model.fitting_net.bias_atom_e = torch.tensor(
+        my_model.get_fitting_net().bias_atom_e = torch.tensor(
             stat_dict["fitting_net.bias_atom_e"], device=DEVICE
         )
diff --git a/source/tests/pt/model/test_pairtab_atomic_model.py b/source/tests/pt/model/test_pairtab_atomic_model.py
index 322de51a2c..165e3dead7 100644
--- a/source/tests/pt/model/test_pairtab_atomic_model.py
+++ b/source/tests/pt/model/test_pairtab_atomic_model.py
@@ -98,8 +98,9 @@ def test_with_mask(self):
 
     def test_jit(self):
         model = torch.jit.script(self.model)
-        self.assertEqual(model.get_rcut(), 0.02)
-        self.assertEqual(model.get_type_map(), ["H", "O"])
+        # the atomic model no longer exports these methods
+        # self.assertEqual(model.get_rcut(), 0.02)
+        # self.assertEqual(model.get_type_map(), ["H", "O"])
 
     def test_deserialize(self):
         model1 = PairTabAtomicModel.deserialize(self.model.serialize())
@@ -121,8 +122,9 @@ def test_deserialize(self):
         )
 
         model1 = torch.jit.script(model1)
-        self.assertEqual(model1.get_rcut(), 0.02)
-        self.assertEqual(model1.get_type_map(), ["H", "O"])
+        # the atomic model no longer exports these methods
+        # self.assertEqual(model1.get_rcut(), 0.02)
+        # self.assertEqual(model1.get_type_map(), ["H", "O"])
 
     def test_cross_deserialize(self):
         model_dict = self.model.serialize()  # pytorch model to dict
diff --git a/source/tests/pt/test_finetune.py b/source/tests/pt/test_finetune.py
index d21a44acc7..dd72eb4718 100644
--- a/source/tests/pt/test_finetune.py
+++ b/source/tests/pt/test_finetune.py
@@ -44,27 +44,29 @@ def test_finetune_change_energy_bias(self):
         else:
             model = get_model(self.model_config)
         if isinstance(model, EnergyModel):
-            model.fitting_net.bias_atom_e = torch.rand_like(
-                model.fitting_net.bias_atom_e
+            model.get_fitting_net().bias_atom_e = torch.rand_like(
+                model.get_fitting_net().bias_atom_e
             )
             energy_bias_before = deepcopy(
-                model.fitting_net.bias_atom_e.detach().cpu().numpy().reshape(-1)
+                model.get_fitting_net().bias_atom_e.detach().cpu().numpy().reshape(-1)
            )
            bias_atom_e_input = deepcopy(
-                model.fitting_net.bias_atom_e.detach().cpu().numpy().reshape(-1)
+                model.get_fitting_net().bias_atom_e.detach().cpu().numpy().reshape(-1)
            )
        elif isinstance(model, DPZBLModel):
-            model.dp_model.fitting_net.bias_atom_e = torch.rand_like(
-                model.dp_model.fitting_net.bias_atom_e
+            model.dp_model.get_fitting_net().bias_atom_e = torch.rand_like(
+                model.dp_model.get_fitting_net().bias_atom_e
            )
            energy_bias_before = deepcopy(
-                model.dp_model.fitting_net.bias_atom_e.detach()
+                model.dp_model.get_fitting_net()
+                .bias_atom_e.detach()
                .cpu()
                .numpy()
                .reshape(-1)
            )
            bias_atom_e_input = deepcopy(
-                model.dp_model.fitting_net.bias_atom_e.detach()
+                model.dp_model.get_fitting_net()
+                .bias_atom_e.detach()
                .cpu()
                .numpy()
                .reshape(-1)
            )
diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py
index d3b6bd67b5..76055c6f4a 100644
--- a/source/tests/pt/test_training.py
+++ b/source/tests/pt/test_training.py
@@ -52,11 +52,11 @@ def test_trainable(self):
             fix_params["model"]["descriptor"]["trainable"] = True
             trainer_fix = get_trainer(fix_params)
             model_dict_before_training = deepcopy(
-                trainer_fix.model.fitting_net.state_dict()
+                trainer_fix.model.get_fitting_net().state_dict()
             )
             trainer_fix.run()
             model_dict_after_training = deepcopy(
-                trainer_fix.model.fitting_net.state_dict()
+                trainer_fix.model.get_fitting_net().state_dict()
             )
         else:
             trainer_fix = get_trainer(fix_params)
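The commented-out assertions in the tests follow from where `@torch.jit.export` now lives: only methods exported on the top-level scripted module stay callable from Python after `torch.jit.script`, so the atomic models drop the decorator and the wrapper model exports the getters, reaching the atomic model through delegation. A minimal sketch of that TorchScript behavior; the `Inner`/`Wrapper` modules are illustrative, not part of this PR:

```python
import torch


class Inner(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.rcut = 6.0

    def get_rcut(self) -> float:
        # no @torch.jit.export needed here: this method is compiled
        # because the wrapper's exported method reaches it
        return self.rcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class Wrapper(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = Inner()

    @torch.jit.export
    def get_rcut(self) -> float:
        # exported so it remains callable on the scripted module,
        # even though forward() never uses it
        return self.inner.get_rcut()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x)


m = torch.jit.script(Wrapper())
assert m.get_rcut() == 6.0
```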