From b740b2de8b82f5c6ab0c409c4d944e26fd5ace71 Mon Sep 17 00:00:00 2001
From: Duo <50307526+iProzd@users.noreply.github.com>
Date: Tue, 20 Feb 2024 18:50:55 +0800
Subject: [PATCH] Improve the model inheritance

---
 deepmd/pt/model/atomic_model/__init__.py      |   6 -
 .../atomic_model/wrapper_atomic_model.py      | 180 ------------------
 deepmd/pt/model/model/__init__.py             |  11 +-
 deepmd/pt/model/model/dp_spin_model.py        | 123 +++++++++---
 deepmd/pt/model/model/dp_zbl_model.py         |   1 -
 deepmd/pt/model/model/ener_model.py           |   1 -
 deepmd/pt/train/training.py                   |   2 +-
 deepmd/pt/train/wrapper.py                    |  14 +-
 8 files changed, 110 insertions(+), 228 deletions(-)
 delete mode 100644 deepmd/pt/model/atomic_model/wrapper_atomic_model.py

diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py
index f41d631565..75c1ce3c2e 100644
--- a/deepmd/pt/model/atomic_model/__init__.py
+++ b/deepmd/pt/model/atomic_model/__init__.py
@@ -27,10 +27,6 @@
 from .pairtab_atomic_model import (
     PairTabAtomicModel,
 )
-from .wrapper_atomic_model import (
-    DPSpinWrapperAtomicModel,
-    WrapperAtomicModel,
-)
 
 __all__ = [
     "BaseAtomicModel",
@@ -38,6 +34,4 @@
     "PairTabAtomicModel",
     "LinearAtomicModel",
     "DPZBLLinearAtomicModel",
-    "WrapperAtomicModel",
-    "DPSpinWrapperAtomicModel",
 ]
diff --git a/deepmd/pt/model/atomic_model/wrapper_atomic_model.py b/deepmd/pt/model/atomic_model/wrapper_atomic_model.py
deleted file mode 100644
index ea2b991d2f..0000000000
--- a/deepmd/pt/model/atomic_model/wrapper_atomic_model.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# SPDX-License-Identifier: LGPL-3.0-or-later
-import sys
-from typing import (
-    Dict,
-    List,
-    Optional,
-)
-
-import torch
-
-from deepmd.dpmodel import (
-    FittingOutputDef,
-)
-from deepmd.dpmodel.utils import (
-    Spin,
-)
-
-from .base_atomic_model import (
-    BaseAtomicModel,
-)
-from .dp_atomic_model import (
-    DPAtomicModel,
-)
-
-
-class WrapperAtomicModel(torch.nn.Module, BaseAtomicModel):
-    """Wrapper model that has an existing model inside
-    with additionally transformation on the input and output for the model.
-
-    Parameters
-    ----------
-    model : BaseAtomicModel
-        A model to be wrapped inside.
-    """
-
-    def __init__(
-        self,
-        model: BaseAtomicModel,
-        **kwargs,
-    ):
-        super().__init__()
-        self.model = model
-
-    def distinguish_types(self) -> bool:
-        """If distinguish different types by sorting."""
-        return self.model.distinguish_types()
-
-    @torch.jit.export
-    def get_rcut(self) -> float:
-        """Get the cut-off radius."""
-        return self.model.get_rcut()
-
-    @torch.jit.export
-    def get_type_map(self) -> List[str]:
-        """Get the type map."""
-        raise self.model.get_type_map()
-
-    def get_sel(self) -> List[int]:
-        return self.model.get_sel()
-
-    def forward_atomic(
-        self,
-        extended_coord: torch.Tensor,
-        extended_atype: torch.Tensor,
-        nlist: torch.Tensor,
-        mapping: Optional[torch.Tensor] = None,
-        fparam: Optional[torch.Tensor] = None,
-        aparam: Optional[torch.Tensor] = None,
-    ) -> Dict[str, torch.Tensor]:
-        """Return atomic prediction.
-
-        Parameters
-        ----------
-        extended_coord
-            coodinates in extended region, (nframes, nall * 3)
-        extended_atype
-            atomic type in extended region, (nframes, nall)
-        nlist
-            neighbor list, (nframes, nloc, nsel).
-        mapping
-            mapps the extended indices to local indices.
-        fparam
-            frame parameter. (nframes, ndf)
-        aparam
-            atomic parameter. (nframes, nloc, nda)
-
-        Returns
-        -------
-        result_dict
-            the result dict, defined by the fitting net output def.
-        """
-        return self.model.forward_atomic(
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping,
-            fparam,
-            aparam,
-        )
-
-    def fitting_output_def(self) -> FittingOutputDef:
-        return self.model.fitting_output_def()
-
-    @staticmethod
-    def serialize(model) -> dict:
-        return {
-            "model": model.serialize(),
-            "model_name": model.__class__.__name__,
-        }
-
-    @staticmethod
-    def deserialize(data) -> "BaseAtomicModel":
-        model = getattr(sys.modules[__name__], data["model_name"]).deserialize(
-            data["model"]
-        )
-        return model
-
-    @torch.jit.export
-    def get_dim_fparam(self) -> int:
-        """Get the number (dimension) of frame parameters of this atomic model."""
-        return self.model.get_dim_fparam()
-
-    @torch.jit.export
-    def get_dim_aparam(self) -> int:
-        """Get the number (dimension) of atomic parameters of this atomic model."""
-        return self.model.get_dim_aparam()
-
-    @torch.jit.export
-    def get_sel_type(self) -> List[int]:
-        """Get the selected atom types of this model.
-
-        Only atoms with selected atom types have atomic contribution
-        to the result of the model.
-        If returning an empty list, all atom types are selected.
-        """
-        return self.model.get_sel_type()
-
-    @torch.jit.export
-    def is_aparam_nall(self) -> bool:
-        """Check whether the shape of atomic parameters is (nframes, nall, ndim).
-
-        If False, the shape is (nframes, nloc, ndim).
-        """
-        return False
-
-
-class DPSpinWrapperAtomicModel(WrapperAtomicModel):
-    """Spin model wrapper with an AtomicModel.
-
-    Parameters
-    ----------
-    backbone_model
-        The backbone model wrapped inside.
-    spin
-        The object containing spin settings.
-    """
-
-    def __init__(
-        self,
-        backbone_model: DPAtomicModel,
-        spin: Spin,
-        **kwargs,
-    ):
-        super().__init__(backbone_model, **kwargs)
-        self.spin = spin
-
-    def serialize(self) -> dict:
-        return {
-            "wrapper_model": WrapperAtomicModel.serialize(self.model),
-            "spin": self.spin.serialize(),
-        }
-
-    @classmethod
-    def deserialize(cls, data) -> "DPSpinWrapperAtomicModel":
-        spin = Spin.deserialize(data["spin"])
-        backbone_model = WrapperAtomicModel.deserialize(data["wrapper_model"])
-        return cls(
-            backbone_model=backbone_model,
-            spin=spin,
-        )
diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py
index acdb0a055d..ec0fe34cea 100644
--- a/deepmd/pt/model/model/__init__.py
+++ b/deepmd/pt/model/model/__init__.py
@@ -29,6 +29,7 @@
     DPModel,
 )
 from .dp_spin_model import (
+    SpinEnergyModel,
     SpinModel,
 )
 from .dp_zbl_model import (
@@ -136,18 +137,17 @@ def get_spin_model(model_params):
     fitting_net = model_params.get("fitting_net", None)
     fitting_net["type"] = fitting_net.get("type", "ener")
     fitting_net["ntypes"] = descriptor.get_ntypes()
-    fitting_net["distinguish_types"] = descriptor.distinguish_types()
+    fitting_net["mixed_types"] = descriptor.mixed_types()
     fitting_net["embedding_width"] = descriptor.get_dim_out()
+    fitting_net["dim_descrpt"] = descriptor.get_dim_out()
     grad_force = "direct" not in fitting_net["type"]
     if not grad_force:
         fitting_net["out_dim"] = descriptor.get_dim_emb()
     if "ener" in fitting_net["type"]:
         fitting_net["return_energy"] = True
     fitting = Fitting(**fitting_net)
-    backbone_model = DPAtomicModel(
-        descriptor, fitting, type_map=model_params["type_map"]
-    )
-    return SpinModel(backbone_model=backbone_model, spin=spin)
+    backbone_model = DPModel(descriptor, fitting, type_map=model_params["type_map"])
+    return SpinEnergyModel(backbone_model=backbone_model, spin=spin)
 
 
 def get_model(model_params):
@@ -166,6 +166,7 @@
     "DPModel",
     "EnergyModel",
     "SpinModel",
+    "SpinEnergyModel",
     "DPZBLModel",
     "make_model",
     "make_hessian_model",
diff --git a/deepmd/pt/model/model/dp_spin_model.py b/deepmd/pt/model/model/dp_spin_model.py
index b38384e1ef..1f32ad45d5 100644
--- a/deepmd/pt/model/model/dp_spin_model.py
+++ b/deepmd/pt/model/model/dp_spin_model.py
@@ -6,28 +6,20 @@
 )
 
 import torch
 
-from deepmd.pt.model.atomic_model import (
-    DPSpinWrapperAtomicModel,
-)
-
-from .make_model import (
-    make_model,
-)
-DPSpinModel_ = make_model(DPSpinWrapperAtomicModel)
-
-
-class SpinModel(DPSpinModel_):
+class SpinModel(torch.nn.Module):
     """A spin model wrapper, with spin input preprocess and output split."""
 
-    model_type = "ener"
+    __USE_SPIN_INPUT__ = True
 
     def __init__(
         self,
-        *args,
-        **kwargs,
+        backbone_model,
+        spin,
     ):
-        super().__init__(*args, **kwargs)
+        super().__init__()
+        self.backbone_model = backbone_model
+        self.spin = spin
 
     def preprocess_spin_input(self, coord, atype, spin):
         nframes, nloc = coord.shape[:-1]
@@ -40,29 +32,41 @@ def preprocess_spin_input(self, coord, atype, spin):
         return coord_spin, atype_spin
 
     def preprocess_spin_output(self, atype, force):
-        nframes, nloc_double = force.shape[:-1]
+        nframes, nloc_double = force.shape[:2]
         nloc = nloc_double // 2
         virtual_scale_mask = self.spin.get_virtual_scale_mask()
         atmoic_mask = torch.gather(
             virtual_scale_mask, -1, index=atype.view(-1)
         ).reshape([nframes, nloc, 1])
-        force_real, force_mag = torch.split(force, [nloc, nloc], dim=-2)
-        force_mag = force_mag * atmoic_mask
+        force_real, force_mag = torch.split(force, [nloc, nloc], dim=1)
+        force_mag = (force_mag.view([nframes, nloc, -1]) * atmoic_mask).view(
+            force_mag.shape
+        )
         return force_real, force_mag, atmoic_mask > 0.0
 
-    def forward(
+    def __getattr__(self, name):
+        """Get attribute from the wrapped model."""
+        if (
+            name == "backbone_model"
+        ):  # torch.nn.Module stores submodules in self.__dict__["_modules"]
+            return self.__dict__["_modules"]["backbone_model"]
+        elif name in self.__dict__:
+            return self.__dict__[name]
+        else:
+            return getattr(self.backbone_model, name)
+
+    def forward_common(
         self,
         coord,
         atype,
+        spin,
         box: Optional[torch.Tensor] = None,
-        spin: Optional[torch.Tensor] = None,
         fparam: Optional[torch.Tensor] = None,
         aparam: Optional[torch.Tensor] = None,
         do_atomic_virial: bool = False,
-        **kwargs,
     ) -> Dict[str, torch.Tensor]:
         coord_updated, atype_updated = self.preprocess_spin_input(coord, atype, spin)
-        model_ret = self.forward_common(
+        model_ret = self.backbone_model.forward_common(
            coord_updated,
            atype_updated,
            box,
@@ -70,23 +74,82 @@ def forward(
             aparam=aparam,
             do_atomic_virial=do_atomic_virial,
         )
+        if self.fitting_net is not None:
+            var_name = self.fitting_net.var_name
+            if self.do_grad(var_name):
+                force_all = model_ret[f"{var_name}_derv_r"]
+                (
+                    model_ret[f"{var_name}_derv_r_real"],
+                    model_ret[f"{var_name}_derv_r_mag"],
+                    model_ret["atmoic_mask"],
+                ) = self.preprocess_spin_output(atype, force_all)
+            else:
+                force_all = model_ret["dforce"]
+                (
+                    model_ret["dforce_real"],
+                    model_ret["dforce_mag"],
+                    model_ret["atmoic_mask"],
+                ) = self.preprocess_spin_output(atype, force_all)
+        return model_ret
+
+    def forward_common_lower(
+        self,
+        extended_coord,
+        extended_atype,
+        nlist,
+        mapping: Optional[torch.Tensor] = None,
+        fparam: Optional[torch.Tensor] = None,
+        aparam: Optional[torch.Tensor] = None,
+        do_atomic_virial: bool = False,
+    ):
+        ## TODO preprocess
+        raise NotImplementedError("forward_common_lower is not implemented for spin models")
+
+class SpinEnergyModel(SpinModel):
+    """A spin model for energy."""
+
+    model_type = "ener"
+
+    def __init__(
+        self,
+        backbone_model,
+        spin,
+    ):
+        super().__init__(backbone_model, spin)
+
+    def forward(
+        self,
+        coord,
+        atype,
+        spin,
+        box: Optional[torch.Tensor] = None,
+        fparam: Optional[torch.Tensor] = None,
+        aparam: Optional[torch.Tensor] = None,
+        do_atomic_virial: bool = False,
+    ) -> Dict[str, torch.Tensor]:
+        model_ret = self.forward_common(
+            coord,
+            atype,
+            spin,
+            box,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
 
         model_predict = {}
         model_predict["atom_energy"] = model_ret["energy"]
         model_predict["energy"] = model_ret["energy_redu"]
-
+        model_predict["atmoic_mask"] = model_ret["atmoic_mask"]
         if self.do_grad("energy"):
-            force_all = model_ret["energy_derv_r"].squeeze(-2)
+            model_predict["force_real"] = model_ret["energy_derv_r_real"].squeeze(-2)
+            model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2)
             if do_atomic_virial:
                 model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
             model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
         else:
-            force_all = model_ret["dforce"]
-            (
-                model_predict["force_real"],
-                model_predict["force_mag"],
-                model_predict["atmoic_mask"],
-            ) = self.preprocess_spin_output(atype, force_all)
+            model_predict["force_real"] = model_ret["dforce_real"]
+            model_predict["force_mag"] = model_ret["dforce_mag"]
         return model_predict
 
     def forward_lower(
diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py
index 259a2a4838..8d71157b60 100644
--- a/deepmd/pt/model/model/dp_zbl_model.py
+++ b/deepmd/pt/model/model/dp_zbl_model.py
@@ -35,7 +35,6 @@ def forward(
         fparam: Optional[torch.Tensor] = None,
         aparam: Optional[torch.Tensor] = None,
         do_atomic_virial: bool = False,
-        **kwargs,
     ) -> Dict[str, torch.Tensor]:
         model_ret = self.forward_common(
             coord,
diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py
index a41c051cff..2afeb2762b 100644
--- a/deepmd/pt/model/model/ener_model.py
+++ b/deepmd/pt/model/model/ener_model.py
@@ -29,7 +29,6 @@ def forward(
         fparam: Optional[torch.Tensor] = None,
         aparam: Optional[torch.Tensor] = None,
         do_atomic_virial: bool = False,
-        **kwargs,
     ) -> Dict[str, torch.Tensor]:
         model_ret = self.forward_common(
             coord,
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index bdfe338e92..353b4fe7c7 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -803,8 +803,8 @@ def get_data(self, is_train=True, task_key="Default"):
         for item in [
             "coord",
             "atype",
-            "box",
             "spin",
+            "box",
         ]:
             if item in batch_data:
                 input_dict[item] = batch_data[item]
diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py
index ae5970649b..602c54b5be 100644
--- a/deepmd/pt/train/wrapper.py
+++ b/deepmd/pt/train/wrapper.py
@@ -158,8 +158,8 @@ def forward(
         self,
         coord,
         atype,
-        box: Optional[torch.Tensor] = None,
         spin: Optional[torch.Tensor] = None,
+        box: Optional[torch.Tensor] = None,
         cur_lr: Optional[torch.Tensor] = None,
         label: Optional[torch.Tensor] = None,
         task_key: Optional[torch.Tensor] = None,
@@ -172,9 +172,15 @@
         assert (
             task_key is not None
         ), f"Multitask model must specify the inference task! Supported tasks are {list(self.model.keys())}."
-        model_pred = self.model[task_key](
-            coord, atype, box=box, spin=spin, do_atomic_virial=do_atomic_virial
-        )
+        input_dict = {
+            "coord": coord,
+            "atype": atype,
+            "box": box,
+            "do_atomic_virial": do_atomic_virial,
+        }
+        if getattr(self.model[task_key], "__USE_SPIN_INPUT__", False):
+            input_dict["spin"] = spin
+        model_pred = self.model[task_key](**input_dict)
         natoms = atype.shape[-1]
         if not self.inference_only and not inference_only:
             loss, more_loss = self.loss[task_key](
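
Reviewer note (not part of the patch): below is a minimal, runnable sketch of the `__USE_SPIN_INPUT__` dispatch that `deepmd/pt/train/wrapper.py` now uses. `ToyModel` and `ToySpinModel` are hypothetical stand-ins rather than classes from this patch; the point is only that the training wrapper assembles the argument dict first and forwards the `spin` tensor only to model classes that opt in via the class-level flag, so non-spin models never see a keyword they do not accept.

    from typing import Optional

    import torch


    class ToyModel(torch.nn.Module):
        """Hypothetical non-spin model: accepts only coord/atype/box."""

        def forward(self, coord, atype, box: Optional[torch.Tensor] = None):
            return {"energy": coord.sum()}


    class ToySpinModel(torch.nn.Module):
        """Hypothetical spin model: opts in to the extra `spin` input."""

        __USE_SPIN_INPUT__ = True

        def forward(self, coord, atype, spin, box: Optional[torch.Tensor] = None):
            # A real spin model would double the atoms with virtual spin
            # sites; here we just consume the extra tensor.
            return {"energy": (coord + spin).sum()}


    def call_model(model, coord, atype, spin=None, box=None):
        # Mirrors the wrapper logic above: build the kwargs dict, then add
        # `spin` only when the model class carries the __USE_SPIN_INPUT__ flag.
        input_dict = {"coord": coord, "atype": atype, "box": box}
        if getattr(model, "__USE_SPIN_INPUT__", False):
            input_dict["spin"] = spin
        return model(**input_dict)


    coord = torch.zeros([1, 4, 3])
    atype = torch.zeros([1, 4], dtype=torch.long)
    spin = torch.ones([1, 4, 3])
    print(call_model(ToyModel(), coord, atype, spin))      # spin is ignored
    print(call_model(ToySpinModel(), coord, atype, spin))  # spin is forwarded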