From f9b0b06e2b04f8b5409e96065761c59bc3fc0d91 Mon Sep 17 00:00:00 2001
From: Anyang Peng <137014849+anyangml@users.noreply.github.com>
Date: Tue, 12 Mar 2024 22:19:11 +0800
Subject: [PATCH] feat: add dos

---
 deepmd/dpmodel/fitting/dos_fitting.py       |  81 ++++++++
 deepmd/pt/model/model/dos_model.py          |  80 ++++++++
 deepmd/pt/model/model/dp_model.py           |   3 ++
 deepmd/pt/model/task/dos.py                 |   6 +-
 deepmd/tf/fit/dos.py                        | 109 +++++++++--
 source/tests/consistent/fitting/test_dos.py | 200 ++++++++++++++++++++
 6 files changed, 467 insertions(+), 12 deletions(-)
 create mode 100644 deepmd/dpmodel/fitting/dos_fitting.py
 create mode 100644 deepmd/pt/model/model/dos_model.py
 create mode 100644 source/tests/consistent/fitting/test_dos.py

diff --git a/deepmd/dpmodel/fitting/dos_fitting.py b/deepmd/dpmodel/fitting/dos_fitting.py
new file mode 100644
index 0000000000..2cf2882074
--- /dev/null
+++ b/deepmd/dpmodel/fitting/dos_fitting.py
@@ -0,0 +1,81 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import copy
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    List,
+    Optional,
+)
+
+from deepmd.dpmodel.common import (
+    DEFAULT_PRECISION,
+)
+from deepmd.dpmodel.fitting.invar_fitting import (
+    InvarFitting,
+)
+
+if TYPE_CHECKING:
+    from deepmd.dpmodel.fitting.general_fitting import (
+        GeneralFitting,
+    )
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
+
+
+@InvarFitting.register("dos")
+class DOSFittingNet(InvarFitting):
+    def __init__(
+        self,
+        ntypes: int,
+        dim_descrpt: int,
+        neuron: List[int] = [120, 120, 120],
+        resnet_dt: bool = True,
+        numb_fparam: int = 0,
+        numb_aparam: int = 0,
+        numb_dos: int = 300,
+        rcond: Optional[float] = None,
+        trainable: Optional[List[bool]] = None,
+        activation_function: str = "tanh",
+        precision: str = DEFAULT_PRECISION,
+        mixed_types: bool = False,
+        exclude_types: List[int] = [],
+        # not used
+        seed: Optional[int] = None,
+    ):
+        super().__init__(
+            var_name="dos",
+            ntypes=ntypes,
+            dim_descrpt=dim_descrpt,
+            dim_out=numb_dos,
+            neuron=neuron,
+            resnet_dt=resnet_dt,
+            numb_fparam=numb_fparam,
+            numb_aparam=numb_aparam,
+            rcond=rcond,
+            trainable=trainable,
+            activation_function=activation_function,
+            precision=precision,
+            mixed_types=mixed_types,
+            exclude_types=exclude_types,
+        )
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "GeneralFitting":
+        data = copy.deepcopy(data)
+        check_version_compatibility(data.pop("@version", 1), 1, 1)
+        data.pop("var_name")
+        data["numb_dos"] = data.pop("dim_out")
+        data.pop("tot_ener_zero")
+        data.pop("layer_name")
+        data.pop("use_aparam_as_mask")
+        data.pop("spin")
+        data.pop("atom_ener")
+        return super().deserialize(data)
+
+    def serialize(self) -> dict:
+        """Serialize the fitting to dict."""
+        return {
+            **super().serialize(),
+            "type": "dos",
+        }
diff --git a/deepmd/pt/model/model/dos_model.py b/deepmd/pt/model/model/dos_model.py
new file mode 100644
index 0000000000..a8ba826851
--- /dev/null
+++ b/deepmd/pt/model/model/dos_model.py
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Dict,
+    Optional,
+)
+
+import torch
+
+from .dp_model import (
+    DPModel,
+)
+
+
+class DOSModel(DPModel):
+    model_type = "dos"
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+    def forward(
+        self,
+        coord,
+        atype,
+        box: Optional[torch.Tensor] = None,
+        fparam: Optional[torch.Tensor] = None,
+        aparam: Optional[torch.Tensor] = None,
+        do_atomic_virial: bool = False,
+    ) -> Dict[str, torch.Tensor]:
+        model_ret = self.forward_common(
+            coord,
+            atype,
+            box,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        if self.get_fitting_net() is not None:
+            model_predict = {}
+            model_predict["atom_dos"] = model_ret["dos"]
+            model_predict["dos"] = model_ret["dos_redu"]
+
+            if "mask" in model_ret:
+                model_predict["mask"] = model_ret["mask"]
+        else:
+            model_predict = model_ret
+            model_predict["updated_coord"] += coord
+        return model_predict
+
+    @torch.jit.export
+    def forward_lower(
+        self,
+        extended_coord,
+        extended_atype,
+        nlist,
+        mapping: Optional[torch.Tensor] = None,
+        fparam: Optional[torch.Tensor] = None,
+        aparam: Optional[torch.Tensor] = None,
+        do_atomic_virial: bool = False,
+    ):
+        model_ret = self.forward_common_lower(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        if self.get_fitting_net() is not None:
+            model_predict = {}
+            model_predict["atom_dos"] = model_ret["dos"]
+            model_predict["dos"] = model_ret["dos_redu"]
+
+        else:
+            model_predict = model_ret
+        return model_predict
diff --git a/deepmd/pt/model/model/dp_model.py b/deepmd/pt/model/model/dp_model.py
index 93f10a5819..38a7b47dbf 100644
--- a/deepmd/pt/model/model/dp_model.py
+++ b/deepmd/pt/model/model/dp_model.py
@@ -54,6 +54,9 @@ def __new__(
         from deepmd.pt.model.model.polar_model import (
             PolarModel,
         )
+        from deepmd.pt.model.model.dos_model import (
+            DOSModel,
+        )
 
         if atomic_model_ is not None:
             fitting = atomic_model_.fitting_net
diff --git a/deepmd/pt/model/task/dos.py b/deepmd/pt/model/task/dos.py
index 49c5a4b351..fe38f2755a 100644
--- a/deepmd/pt/model/task/dos.py
+++ b/deepmd/pt/model/task/dos.py
@@ -9,6 +9,9 @@
 import torch
 
 from deepmd.pt.model.task.fitting import (
+    Fitting,
+)
+from deepmd.pt.model.task.ener import (
     InvarFitting,
 )
 from deepmd.pt.utils import (
@@ -64,11 +67,12 @@ def __init__(
             rcond=rcond,
             seed=seed,
             exclude_types=exclude_types,
+            trainable=trainable,
             **kwargs,
         )
 
     @classmethod
-    def deserialize(cls, data: dict) -> "DOSFittingNet":
+    def deserialize(cls, data: dict) -> "InvarFitting":
         data = copy.deepcopy(data)
         check_version_compatibility(data.pop("@version", 1), 1, 1)
         data.pop("var_name")
diff --git a/deepmd/tf/fit/dos.py b/deepmd/tf/fit/dos.py
index 0cc5a7df62..31f1c7b019 100644
--- a/deepmd/tf/fit/dos.py
+++ b/deepmd/tf/fit/dos.py
@@ -46,6 +46,9 @@
 from deepmd.utils.out_stat import (
     compute_stats_from_redu,
 )
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
 
 log = logging.getLogger(__name__)
 
@@ -57,8 +60,10 @@ class DOSFitting(Fitting):
 
     Parameters
     ----------
-    descrpt
-        The descriptor :math:`\mathcal{D}`
+    ntypes
+        The number of atom types of the descriptor :math:`\mathcal{D}`
+    dim_descrpt
+        The dimension of the descriptor :math:`\mathcal{D}`
     neuron
         Number of neurons :math:`N` in each hidden layer of the fitting net
     resnet_dt
@@ -94,7 +99,8 @@ class DOSFitting(Fitting):
 
     def __init__(
        self,
-        descrpt: tf.Tensor,
+        ntypes: int,
+        dim_descrpt: int,
         neuron: List[int] = [120, 120, 120],
         resnet_dt: bool = True,
         numb_fparam: int = 0,
@@ -112,8 +118,8 @@ def __init__(
     ) -> None:
         """Constructor."""
         # model param
-        self.ntypes = descrpt.get_ntypes()
-        self.dim_descrpt = descrpt.get_dim_out()
+        self.ntypes = ntypes
+        self.dim_descrpt = dim_descrpt
         self.use_aparam_as_mask = use_aparam_as_mask
 
         self.numb_fparam = numb_fparam
@@ -127,6 +133,7 @@ def __init__(
         self.seed = seed
         self.uniform_seed = uniform_seed
         self.seed_shift = one_layer_rand_seed_shift()
+        self.activation_function = activation_function
         self.fitting_activation_fn = get_activation_func(activation_function)
         self.fitting_precision = get_precision(precision)
         self.trainable = trainable
@@ -145,16 +152,16 @@ def __init__(
             add_data_requirement(
                 "fparam", self.numb_fparam, atomic=False, must=True, high_prec=False
             )
-        self.fparam_avg = None
-        self.fparam_std = None
-        self.fparam_inv_std = None
+            self.fparam_avg = None
+            self.fparam_std = None
+            self.fparam_inv_std = None
         if self.numb_aparam > 0:
             add_data_requirement(
                 "aparam", self.numb_aparam, atomic=True, must=True, high_prec=False
             )
-        self.aparam_avg = None
-        self.aparam_std = None
-        self.aparam_inv_std = None
+            self.aparam_avg = None
+            self.aparam_std = None
+            self.aparam_inv_std = None
 
         self.fitting_net_variables = None
         self.mixed_prec = None
@@ -641,3 +648,83 @@ def get_loss(self, loss: dict, lr) -> Loss:
         return DOSLoss(
             **loss, starter_learning_rate=lr.start_lr(), numb_dos=self.get_numb_dos()
         )
+
+    @classmethod
+    def deserialize(cls, data: dict, suffix: str = ""):
+        """Deserialize the model.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized data
+
+        Returns
+        -------
+        Model
+            The deserialized model
+        """
+        data = data.copy()
+        check_version_compatibility(data.pop("@version", 1), 1, 1)
+        data["numb_dos"] = data.pop("dim_out")
+        fitting = cls(**data)
+        fitting.fitting_net_variables = cls.deserialize_network(
+            data["nets"],
+            suffix=suffix,
+        )
+        fitting.bias_dos = data["@variables"]["bias_dos"]
+        if fitting.numb_fparam > 0:
+            fitting.fparam_avg = data["@variables"]["fparam_avg"]
+            fitting.fparam_inv_std = data["@variables"]["fparam_inv_std"]
+        if fitting.numb_aparam > 0:
+            fitting.aparam_avg = data["@variables"]["aparam_avg"]
+            fitting.aparam_inv_std = data["@variables"]["aparam_inv_std"]
+        return fitting
+
+    def serialize(self, suffix: str = "") -> dict:
+        """Serialize the model.
+
+        Returns
+        -------
+        dict
+            The serialized data
+        """
+        data = {
+            "@class": "Fitting",
+            "type": "dos",
+            "@version": 1,
+            "var_name": "dos",
+            "ntypes": self.ntypes,
+            "dim_descrpt": self.dim_descrpt,
+            # very bad design: type embedding is not passed to the class
+            # TODO: refactor the class
+            "mixed_types": False,
+            "dim_out": self.numb_dos,
+            "neuron": self.n_neuron,
+            "resnet_dt": self.resnet_dt,
+            "numb_fparam": self.numb_fparam,
+            "numb_aparam": self.numb_aparam,
+            "rcond": self.rcond,
+            "trainable": self.trainable,
+            "activation_function": self.activation_function,
+            "precision": self.fitting_precision.name,
+            "exclude_types": [],
+            "nets": self.serialize_network(
+                ntypes=self.ntypes,
+                # TODO: consider type embeddings
+                ndim=1,
+                in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam,
+                neuron=self.n_neuron,
+                activation_function=self.activation_function,
+                resnet_dt=self.resnet_dt,
+                variables=self.fitting_net_variables,
+                suffix=suffix,
+            ),
+            "@variables": {
+                "bias_dos": self.bias_dos,
+                "fparam_avg": self.fparam_avg,
+                "fparam_inv_std": self.fparam_inv_std,
+                "aparam_avg": self.aparam_avg,
+                "aparam_inv_std": self.aparam_inv_std,
+            },
+        }
+        return data
\ No newline at end of file
diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py
new file mode 100644
index 0000000000..9af832c550
--- /dev/null
+++ b/source/tests/consistent/fitting/test_dos.py
@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+from typing import (
+    Any,
+    Tuple,
+)
+
+import numpy as np
+
+from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingDP
+from deepmd.env import (
+    GLOBAL_NP_FLOAT_PRECISION,
+)
+
+from ..common import (
+    INSTALLED_PT,
+    INSTALLED_TF,
+    CommonTest,
+    parameterized,
+)
+from .common import (
+    FittingTest,
+)
+
+if INSTALLED_PT:
+    import torch
+
+    from deepmd.pt.model.task.dos import DOSFittingNet as DOSFittingPT
+    from deepmd.pt.utils.env import DEVICE as PT_DEVICE
+else:
+    DOSFittingPT = object
+if INSTALLED_TF:
+    from deepmd.tf.fit.dos import DOSFitting as DOSFittingTF
+else:
+    DOSFittingTF = object
+from deepmd.utils.argcheck import (
+    fitting_dos,
+)
+
+
+@parameterized(
+    (True, False),  # resnet_dt
+    ("float64", "float32"),  # precision
+    (True, False),  # mixed_types
+    (0, 1),  # numb_fparam
+)
+class TestDOS(CommonTest, FittingTest, unittest.TestCase):
+    @property
+    def data(self) -> dict:
+        (
+            resnet_dt,
+            precision,
+            mixed_types,
+            numb_fparam,
+        ) = self.param
+        return {
+            "neuron": [5, 5, 5],
+            "resnet_dt": resnet_dt,
+            "precision": precision,
+            "numb_fparam": numb_fparam,
+            "seed": 20240217,
+        }
+
+    @property
+    def skip_tf(self) -> bool:
+        (
+            resnet_dt,
+            precision,
+            mixed_types,
+            numb_fparam,
+        ) = self.param
+        # TODO: mixed_types
+        return mixed_types or CommonTest.skip_tf
+
+    @property
+    def skip_pt(self) -> bool:
+        (
+            resnet_dt,
+            precision,
+            mixed_types,
+            numb_fparam,
+        ) = self.param
+        return CommonTest.skip_pt
+
+    tf_class = DOSFittingTF
+    dp_class = DOSFittingDP
+    pt_class = DOSFittingPT
+    args = fitting_dos()
+
+    def setUp(self):
+        CommonTest.setUp(self)
+
+        self.ntypes = 2
+        self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
+        self.inputs = np.ones((1, 6, 20), dtype=GLOBAL_NP_FLOAT_PRECISION)
+        self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
+        # inconsistent if not sorted
+        self.atype.sort()
+        self.fparam = -np.ones((1,), dtype=GLOBAL_NP_FLOAT_PRECISION)
+
+    @property
+    def addtional_data(self) -> dict:
+        (
+            resnet_dt,
+            precision,
+            mixed_types,
+ numb_fparam, + ) = self.param + return { + "ntypes": self.ntypes, + "dim_descrpt": self.inputs.shape[-1], + "mixed_types": mixed_types, + } + + def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + ) = self.param + return self.build_tf_fitting( + obj, + self.inputs.ravel(), + self.natoms, + self.atype, + self.fparam if numb_fparam else None, + suffix, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + ) = self.param + return ( + pt_obj( + torch.from_numpy(self.inputs).to(device=PT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE), + fparam=torch.from_numpy(self.fparam).to(device=PT_DEVICE) + if numb_fparam + else None, + )["dos"] + .detach() + .cpu() + .numpy() + ) + + def eval_dp(self, dp_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + ) = self.param + return dp_obj( + self.inputs, + self.atype.reshape(1, -1), + fparam=self.fparam if numb_fparam else None, + )["dos"] + + def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: + if backend == self.RefBackend.TF: + # shape is not same + ret = ret[0].reshape(-1, self.natoms[0], 1) + return (ret,) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}")
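
Usage sketch (editor's note, not part of the patch): a minimal example of how the new
NumPy-backend DOSFittingNet is expected to be driven, mirroring eval_dp in the test
above. The sizes (numb_dos=250, the 20-dimensional all-ones descriptor) and the atype
values are illustrative assumptions, not values taken from the patch.

    import numpy as np

    from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet

    nframes, nloc, dim_descrpt, numb_dos = 1, 6, 20, 250  # assumed sizes

    net = DOSFittingNet(
        ntypes=2,
        dim_descrpt=dim_descrpt,
        numb_dos=numb_dos,
        mixed_types=True,
    )
    # stand-in for a real descriptor, shape (nframes, nloc, dim_descrpt)
    descriptor = np.ones((nframes, nloc, dim_descrpt))
    atype = np.array([[0, 0, 0, 1, 1, 1]], dtype=np.int32)

    # per-atom density of states, shape (nframes, nloc, numb_dos)
    atom_dos = net(descriptor, atype)["dos"]

    # serialize/deserialize round trip, the same path the consistency test exercises
    net2 = DOSFittingNet.deserialize(net.serialize())
    np.testing.assert_allclose(atom_dos, net2(descriptor, atype)["dos"])

Summing atom_dos over the atom axis gives the reduced quantity that the PyTorch
DOSModel above returns as model_predict["dos"] (the "dos_redu" output of
forward_common).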