diff --git a/deepmd/dpmodel/fitting/__init__.py b/deepmd/dpmodel/fitting/__init__.py index 929a63fda7..866a710a3b 100644 --- a/deepmd/dpmodel/fitting/__init__.py +++ b/deepmd/dpmodel/fitting/__init__.py @@ -2,6 +2,9 @@ from .dipole_fitting import ( DipoleFitting, ) +from .dos_fitting import ( + DOSFittingNet, +) from .ener_fitting import ( EnergyFittingNet, ) @@ -21,4 +24,5 @@ "DipoleFitting", "EnergyFittingNet", "PolarFitting", + "DOSFittingNet", ] diff --git a/deepmd/dpmodel/fitting/dos_fitting.py b/deepmd/dpmodel/fitting/dos_fitting.py new file mode 100644 index 0000000000..7c86d392b0 --- /dev/null +++ b/deepmd/dpmodel/fitting/dos_fitting.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +from typing import ( + TYPE_CHECKING, + List, + Optional, + Union, +) + +import numpy as np + +from deepmd.dpmodel.common import ( + DEFAULT_PRECISION, +) +from deepmd.dpmodel.fitting.invar_fitting import ( + InvarFitting, +) + +if TYPE_CHECKING: + from deepmd.dpmodel.fitting.general_fitting import ( + GeneralFitting, + ) + +from deepmd.utils.version import ( + check_version_compatibility, +) + + +@InvarFitting.register("dos") +class DOSFittingNet(InvarFitting): + def __init__( + self, + ntypes: int, + dim_descrpt: int, + numb_dos: int = 300, + neuron: List[int] = [120, 120, 120], + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + bias_dos: Optional[np.ndarray] = None, + rcond: Optional[float] = None, + trainable: Union[bool, List[bool]] = True, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + mixed_types: bool = False, + exclude_types: List[int] = [], + # not used + seed: Optional[int] = None, + ): + if bias_dos is not None: + self.bias_dos = bias_dos + else: + self.bias_dos = np.zeros((ntypes, numb_dos), dtype=DEFAULT_PRECISION) + super().__init__( + var_name="dos", + ntypes=ntypes, + dim_descrpt=dim_descrpt, + dim_out=numb_dos, + neuron=neuron, + resnet_dt=resnet_dt, + bias_atom=bias_dos, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + rcond=rcond, + trainable=trainable, + activation_function=activation_function, + precision=precision, + mixed_types=mixed_types, + exclude_types=exclude_types, + ) + + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + data["numb_dos"] = data.pop("dim_out") + data.pop("tot_ener_zero", None) + data.pop("var_name", None) + data.pop("layer_name", None) + data.pop("use_aparam_as_mask", None) + data.pop("spin", None) + data.pop("atom_ener", None) + return super().deserialize(data) + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + dd = { + **super().serialize(), + "type": "dos", + } + dd["@variables"]["bias_atom_e"] = self.bias_atom_e + + return dd diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index e9dddae2de..3b0d022562 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -40,6 +40,8 @@ class GeneralFitting(NativeOP, BaseFitting): The dimension of the input descriptor. neuron Number of neurons :math:`N` in each hidden layer of the fitting net + bias_atom_e + Average enery per atom for each element. resnet_dt Time-step `dt` in the resnet construction: :math:`y = x + dt * \phi (Wx + b)` @@ -85,6 +87,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + bias_atom_e: Optional[np.ndarray] = None, rcond: Optional[float] = None, tot_ener_zero: bool = False, trainable: Optional[List[bool]] = None, @@ -125,7 +128,11 @@ def __init__( net_dim_out = self._net_out_dim() # init constants - self.bias_atom_e = np.zeros([self.ntypes, net_dim_out]) + if bias_atom_e is None: + self.bias_atom_e = np.zeros([self.ntypes, net_dim_out]) + else: + assert bias_atom_e.shape == (self.ntypes, net_dim_out) + self.bias_atom_e = bias_atom_e if self.numb_fparam > 0: self.fparam_avg = np.zeros(self.numb_fparam) self.fparam_inv_std = np.ones(self.numb_fparam) diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py index f7c091843b..9bf1731830 100644 --- a/deepmd/dpmodel/fitting/invar_fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -82,6 +82,8 @@ class InvarFitting(GeneralFitting): Number of atomic parameter rcond The condition number for the regression of atomic energy. + bias_atom + Bias for each element. tot_ener_zero Force the total energy to zero. Useful for the charge fitting. trainable @@ -117,10 +119,11 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + bias_atom: Optional[np.ndarray] = None, rcond: Optional[float] = None, tot_ener_zero: bool = False, trainable: Optional[List[bool]] = None, - atom_ener: Optional[List[float]] = [], + atom_ener: Optional[List[float]] = None, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, layer_name: Optional[List[Optional[str]]] = None, @@ -152,6 +155,7 @@ def __init__( numb_fparam=numb_fparam, numb_aparam=numb_aparam, rcond=rcond, + bias_atom_e=bias_atom, tot_ener_zero=tot_ener_zero, trainable=trainable, activation_function=activation_function, diff --git a/deepmd/infer/deep_dos.py b/deepmd/infer/deep_dos.py index d95d2a119f..7823f02999 100644 --- a/deepmd/infer/deep_dos.py +++ b/deepmd/infer/deep_dos.py @@ -56,6 +56,11 @@ def output_def(self) -> ModelOutputDef: ) ) + @property + def numb_dos(self) -> int: + """Get the number of DOS.""" + return self.get_numb_dos() + def eval( self, coords: np.ndarray, diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index f93ec88bde..7a2070e476 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -140,7 +140,8 @@ def get_standard_model(model_params): fitting_net["type"] = fitting_net.get("type", "ener") fitting_net["ntypes"] = descriptor.get_ntypes() fitting_net["mixed_types"] = descriptor.mixed_types() - fitting_net["embedding_width"] = descriptor.get_dim_emb() + if fitting_net["type"] in ["dipole", "polar"]: + fitting_net["embedding_width"] = descriptor.get_dim_emb() fitting_net["dim_descrpt"] = descriptor.get_dim_out() grad_force = "direct" not in fitting_net["type"] if not grad_force: diff --git a/deepmd/pt/model/model/dos_model.py b/deepmd/pt/model/model/dos_model.py new file mode 100644 index 0000000000..680eac41f5 --- /dev/null +++ b/deepmd/pt/model/model/dos_model.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + Optional, +) + +import torch + +from .dp_model import ( + DPModel, +) + + +class DOSModel(DPModel): + model_type = "dos" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def forward( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + model_ret = self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_dos"] = model_ret["dos"] + model_predict["dos"] = model_ret["dos_redu"] + + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + else: + model_predict = model_ret + model_predict["updated_coord"] += coord + return model_predict + + @torch.jit.export + def forward_lower( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ): + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_dos"] = model_ret["dos"] + model_predict["dos"] = model_ret["dos_redu"] + + else: + model_predict = model_ret + return model_predict diff --git a/deepmd/pt/model/model/dp_model.py b/deepmd/pt/model/model/dp_model.py index 138398539a..d7b3c4f4e2 100644 --- a/deepmd/pt/model/model/dp_model.py +++ b/deepmd/pt/model/model/dp_model.py @@ -18,6 +18,9 @@ from deepmd.pt.model.task.dipole import ( DipoleFittingNet, ) +from deepmd.pt.model.task.dos import ( + DOSFittingNet, +) from deepmd.pt.model.task.ener import ( EnergyFittingNet, EnergyFittingNetDirect, @@ -45,6 +48,9 @@ def __new__( from deepmd.pt.model.model.dipole_model import ( DipoleModel, ) + from deepmd.pt.model.model.dos_model import ( + DOSModel, + ) from deepmd.pt.model.model.ener_model import ( EnergyModel, ) @@ -68,6 +74,8 @@ def __new__( cls = DipoleModel elif isinstance(fitting, PolarFittingNet): cls = PolarModel + elif isinstance(fitting, DOSFittingNet): + cls = DOSModel # else: unknown fitting type, fall back to DPModel return super().__new__(cls) diff --git a/deepmd/pt/model/task/dos.py b/deepmd/pt/model/task/dos.py new file mode 100644 index 0000000000..c37b05277a --- /dev/null +++ b/deepmd/pt/model/task/dos.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import logging +from typing import ( + List, + Optional, + Union, +) + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.task.ener import ( + InvarFitting, +) +from deepmd.pt.model.task.fitting import ( + Fitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = env.DEVICE + +log = logging.getLogger(__name__) + + +@Fitting.register("dos") +class DOSFittingNet(InvarFitting): + def __init__( + self, + ntypes: int, + dim_descrpt: int, + numb_dos: int = 300, + neuron: List[int] = [128, 128, 128], + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + rcond: Optional[float] = None, + bias_dos: Optional[torch.Tensor] = None, + trainable: Union[bool, List[bool]] = True, + seed: Optional[int] = None, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + exclude_types: List[int] = [], + mixed_types: bool = True, + ): + if bias_dos is not None: + self.bias_dos = bias_dos + else: + self.bias_dos = torch.zeros( + (ntypes, numb_dos), dtype=dtype, device=env.DEVICE + ) + super().__init__( + var_name="dos", + ntypes=ntypes, + dim_descrpt=dim_descrpt, + dim_out=numb_dos, + neuron=neuron, + bias_atom_e=bias_dos, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + activation_function=activation_function, + precision=precision, + mixed_types=mixed_types, + rcond=rcond, + seed=seed, + exclude_types=exclude_types, + trainable=trainable, + ) + + def output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + self.var_name, + [self.dim_out], + reduciable=True, + r_differentiable=False, + c_differentiable=False, + ), + ] + ) + + @classmethod + def deserialize(cls, data: dict) -> "DOSFittingNet": + data = copy.deepcopy(data) + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + data.pop("var_name", None) + data.pop("tot_ener_zero", None) + data.pop("layer_name", None) + data.pop("use_aparam_as_mask", None) + data.pop("spin", None) + data.pop("atom_ener", None) + data["numb_dos"] = data.pop("dim_out") + obj = super().deserialize(data) + + return obj + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + # dd = super(InvarFitting, self).serialize() + dd = { + **InvarFitting.serialize(self), + "type": "dos", + "dim_out": self.dim_out, + } + dd["@variables"]["bias_atom_e"] = to_numpy_array(self.bias_atom_e) + + return dd + + # make jit happy with torch 2.0.0 + exclude_types: List[int] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index b20d80c629..fc293f70ec 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -274,6 +274,9 @@ def get_loss(loss_params, start_lr, _ntypes, _model): if loss_type == "ener": loss_params["starter_learning_rate"] = start_lr return EnergyStdLoss(**loss_params) + elif loss_type == "dos": + loss_params["starter_learning_rate"] = start_lr + raise NotImplementedError() elif loss_type == "ener_spin": loss_params["starter_learning_rate"] = start_lr return EnergySpinLoss(**loss_params) diff --git a/deepmd/tf/fit/dos.py b/deepmd/tf/fit/dos.py index 0cc5a7df62..aef134da92 100644 --- a/deepmd/tf/fit/dos.py +++ b/deepmd/tf/fit/dos.py @@ -46,6 +46,9 @@ from deepmd.utils.out_stat import ( compute_stats_from_redu, ) +from deepmd.utils.version import ( + check_version_compatibility, +) log = logging.getLogger(__name__) @@ -57,8 +60,10 @@ class DOSFitting(Fitting): Parameters ---------- - descrpt - The descrptor :math:`\mathcal{D}` + ntypes + The ntypes of the descrptor :math:`\mathcal{D}` + dim_descrpt + The dimension of the descrptor :math:`\mathcal{D}` neuron Number of neurons :math:`N` in each hidden layer of the fitting net resnet_dt @@ -94,7 +99,8 @@ class DOSFitting(Fitting): def __init__( self, - descrpt: tf.Tensor, + ntypes: int, + dim_descrpt: int, neuron: List[int] = [120, 120, 120], resnet_dt: bool = True, numb_fparam: int = 0, @@ -112,8 +118,8 @@ def __init__( ) -> None: """Constructor.""" # model param - self.ntypes = descrpt.get_ntypes() - self.dim_descrpt = descrpt.get_dim_out() + self.ntypes = ntypes + self.dim_descrpt = dim_descrpt self.use_aparam_as_mask = use_aparam_as_mask self.numb_fparam = numb_fparam @@ -127,6 +133,7 @@ def __init__( self.seed = seed self.uniform_seed = uniform_seed self.seed_shift = one_layer_rand_seed_shift() + self.activation_function = activation_function self.fitting_activation_fn = get_activation_func(activation_function) self.fitting_precision = get_precision(precision) self.trainable = trainable @@ -145,16 +152,16 @@ def __init__( add_data_requirement( "fparam", self.numb_fparam, atomic=False, must=True, high_prec=False ) - self.fparam_avg = None - self.fparam_std = None - self.fparam_inv_std = None + self.fparam_avg = None + self.fparam_std = None + self.fparam_inv_std = None if self.numb_aparam > 0: add_data_requirement( "aparam", self.numb_aparam, atomic=True, must=True, high_prec=False ) - self.aparam_avg = None - self.aparam_std = None - self.aparam_inv_std = None + self.aparam_avg = None + self.aparam_std = None + self.aparam_inv_std = None self.fitting_net_variables = None self.mixed_prec = None @@ -521,7 +528,11 @@ def build( final_layer = tf.reshape( final_layer, - [tf.shape(inputs)[0] * self.numb_dos, natoms[2 + type_i]], + [ + tf.shape(inputs)[0], + natoms[2 + type_i], + self.numb_dos, + ], ) outs_list.append(final_layer) start_index += natoms[2 + type_i] @@ -550,7 +561,8 @@ def build( ) outs = tf.reshape( - final_layer, [tf.shape(inputs)[0] * self.numb_dos, natoms[0]] + final_layer, + [tf.shape(inputs)[0], natoms[0], self.numb_dos], ) # add bias # self.atom_ener_before = outs @@ -562,7 +574,7 @@ def build( # self.atom_ener_after = outs tf.summary.histogram("fitting_net_output", outs) - return tf.reshape(outs, [-1]) + return outs def init_variables( self, @@ -641,3 +653,84 @@ def get_loss(self, loss: dict, lr) -> Loss: return DOSLoss( **loss, starter_learning_rate=lr.start_lr(), numb_dos=self.get_numb_dos() ) + + @classmethod + def deserialize(cls, data: dict, suffix: str = ""): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data["numb_dos"] = data.pop("dim_out") + fitting = cls(**data) + fitting.fitting_net_variables = cls.deserialize_network( + data["nets"], + suffix=suffix, + ) + fitting.bias_dos = data["@variables"]["bias_atom_e"] + if fitting.numb_fparam > 0: + fitting.fparam_avg = data["@variables"]["fparam_avg"] + fitting.fparam_inv_std = data["@variables"]["fparam_inv_std"] + if fitting.numb_aparam > 0: + fitting.aparam_avg = data["@variables"]["aparam_avg"] + fitting.aparam_inv_std = data["@variables"]["aparam_inv_std"] + return fitting + + def serialize(self, suffix: str = "") -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + data = { + "@class": "Fitting", + "type": "dos", + "@version": 1, + "var_name": "dos", + "ntypes": self.ntypes, + "dim_descrpt": self.dim_descrpt, + # very bad design: type embedding is not passed to the class + # TODO: refactor the class + "mixed_types": False, + "dim_out": self.numb_dos, + "neuron": self.n_neuron, + "resnet_dt": self.resnet_dt, + "numb_fparam": self.numb_fparam, + "numb_aparam": self.numb_aparam, + "rcond": self.rcond, + "trainable": self.trainable, + "activation_function": self.activation_function, + "precision": self.fitting_precision.name, + "exclude_types": [], + "nets": self.serialize_network( + ntypes=self.ntypes, + # TODO: consider type embeddings + ndim=1, + in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam, + out_dim=self.numb_dos, + neuron=self.n_neuron, + activation_function=self.activation_function, + resnet_dt=self.resnet_dt, + variables=self.fitting_net_variables, + suffix=suffix, + ), + "@variables": { + "bias_atom_e": self.bias_dos, + "fparam_avg": self.fparam_avg, + "fparam_inv_std": self.fparam_inv_std, + "aparam_avg": self.aparam_avg, + "aparam_inv_std": self.aparam_inv_std, + }, + } + return data diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 5a35ced0a1..cbcb987c89 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -252,7 +252,7 @@ def test_tf_consistent_with_ref(self): tf_obj = self.tf_class.deserialize(data1, suffix=self.unique_id) ret2, data2 = self.get_tf_ret_serialization_from_cls(tf_obj) ret2 = self.extract_ret(ret2, self.RefBackend.TF) - if tf_obj.__class__.__name__.startswith(("Polar", "Dipole")): + if tf_obj.__class__.__name__.startswith(("Polar", "Dipole", "DOS")): # tf, pt serialization mismatch common_keys = set(data1.keys()) & set(data2.keys()) data1 = {k: data1[k] for k in common_keys} @@ -331,7 +331,7 @@ def test_pt_consistent_with_ref(self): ret2 = self.eval_pt(obj) ret2 = self.extract_ret(ret2, self.RefBackend.PT) data2 = obj.serialize() - if obj.__class__.__name__.startswith(("Polar", "Dipole")): + if obj.__class__.__name__.startswith(("Polar", "Dipole", "DOS")): # tf, pt serialization mismatch common_keys = set(data1.keys()) & set(data2.keys()) data1 = {k: data1[k] for k in common_keys} diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py new file mode 100644 index 0000000000..2832d67641 --- /dev/null +++ b/source/tests/consistent/fitting/test_dos.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, + Tuple, +) + +import numpy as np + +from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingDP +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + INSTALLED_TF, + CommonTest, + parameterized, +) +from .common import ( + FittingTest, +) + +if INSTALLED_PT: + import torch + + from deepmd.pt.model.task.dos import DOSFittingNet as DOSFittingPT + from deepmd.pt.utils.env import DEVICE as PT_DEVICE +else: + DOSFittingPT = object +if INSTALLED_TF: + from deepmd.tf.fit.dos import DOSFitting as DOSFittingTF +else: + DOSFittingTF = object +from deepmd.utils.argcheck import ( + fitting_dos, +) + + +@parameterized( + (True, False), # resnet_dt + ("float64", "float32"), # precision + (True, False), # mixed_types + (0, 1), # numb_fparam + (10, 20), # numb_dos +) +class TestDOS(CommonTest, FittingTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_dos, + ) = self.param + return { + "neuron": [5, 5, 5], + "resnet_dt": resnet_dt, + "precision": precision, + "numb_fparam": numb_fparam, + "seed": 20240217, + "numb_dos": numb_dos, + } + + @property + def skip_tf(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_dos, + ) = self.param + # TODO: mixed_types + return mixed_types or CommonTest.skip_pt + + @property + def skip_pt(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_dos, + ) = self.param + return CommonTest.skip_pt + + tf_class = DOSFittingTF + dp_class = DOSFittingDP + pt_class = DOSFittingPT + args = fitting_dos() + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + self.inputs = np.ones((1, 6, 20), dtype=GLOBAL_NP_FLOAT_PRECISION) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + # inconsistent if not sorted + self.atype.sort() + self.fparam = -np.ones((1,), dtype=GLOBAL_NP_FLOAT_PRECISION) + + @property + def addtional_data(self) -> dict: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_dos, + ) = self.param + return { + "ntypes": self.ntypes, + "dim_descrpt": self.inputs.shape[-1], + "mixed_types": mixed_types, + } + + def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_dos, + ) = self.param + return self.build_tf_fitting( + obj, + self.inputs.ravel(), + self.natoms, + self.atype, + self.fparam if numb_fparam else None, + suffix, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_dos, + ) = self.param + return ( + pt_obj( + torch.from_numpy(self.inputs).to(device=PT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE), + fparam=torch.from_numpy(self.fparam).to(device=PT_DEVICE) + if numb_fparam + else None, + )["dos"] + .detach() + .cpu() + .numpy() + ) + + def eval_dp(self, dp_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_dos, + ) = self.param + return dp_obj( + self.inputs, + self.atype.reshape(1, -1), + fparam=self.fparam if numb_fparam else None, + )["dos"] + + def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: + if backend == self.RefBackend.TF: + # shape is not same + ret = ret[0].reshape(-1, self.natoms[0], 1) + return (ret,) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_dos, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_dos, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/pt/dos/data/set.000/atom_dos.npy b/source/tests/pt/dos/data/set.000/atom_dos.npy new file mode 100644 index 0000000000..22809c1068 Binary files /dev/null and b/source/tests/pt/dos/data/set.000/atom_dos.npy differ diff --git a/source/tests/pt/dos/data/set.000/box.npy b/source/tests/pt/dos/data/set.000/box.npy new file mode 100644 index 0000000000..6265bf150e Binary files /dev/null and b/source/tests/pt/dos/data/set.000/box.npy differ diff --git a/source/tests/pt/dos/data/set.000/coord.npy b/source/tests/pt/dos/data/set.000/coord.npy new file mode 100644 index 0000000000..f33ce430bf Binary files /dev/null and b/source/tests/pt/dos/data/set.000/coord.npy differ diff --git a/source/tests/pt/dos/data/set.000/dos.npy b/source/tests/pt/dos/data/set.000/dos.npy new file mode 100644 index 0000000000..904b23e709 Binary files /dev/null and b/source/tests/pt/dos/data/set.000/dos.npy differ diff --git a/source/tests/pt/dos/data/type.raw b/source/tests/pt/dos/data/type.raw new file mode 100644 index 0000000000..de3c26ec4e --- /dev/null +++ b/source/tests/pt/dos/data/type.raw @@ -0,0 +1,32 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 diff --git a/source/tests/pt/dos/data/type_map.raw b/source/tests/pt/dos/data/type_map.raw new file mode 100644 index 0000000000..a9edc74f38 --- /dev/null +++ b/source/tests/pt/dos/data/type_map.raw @@ -0,0 +1 @@ +H diff --git a/source/tests/pt/dos/input.json b/source/tests/pt/dos/input.json new file mode 100644 index 0000000000..f9330003be --- /dev/null +++ b/source/tests/pt/dos/input.json @@ -0,0 +1,80 @@ +{ + "model": { + "type_map": [ + "H" + ], + "descriptor": { + "type": "se_e2_a", + "sel": [ + 90 + ], + "rcut_smth": 1.8, + "rcut": 6.0, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 8, + "precision": "float64", + "seed": 1 + }, + "fitting_net": { + "type": "dos", + "numb_dos": 250, + "neuron": [ + 120, + 120, + 120 + ], + "resnet_dt": true, + "numb_fparam": 0, + "precision": "float64", + "seed": 1 + } + }, + "loss": { + "type": "dos", + "start_pref_dos": 0.0, + "limit_pref_dos": 0.0, + "start_pref_cdf": 0.0, + "limit_pref_cdf": 0.0, + "start_pref_ados": 1.0, + "limit_pref_ados": 1.0, + "start_pref_acdf": 0.0, + "limit_pref_acdf": 0.0 + }, + "learning_rate": { + "type": "exp", + "start_lr": 0.001, + "stop_lr": 1e-08 + }, + "training": { + "stop_batch": 100000, + "seed": 1, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 1000, + "save_ckpt": "model.ckpt", + "disp_training": true, + "time_training": true, + "profiling": false, + "profiling_file": "timeline.json", + "training_data": { + "systems": [ + "pt/dos/data/" + ], + "set_prefix": "set", + "batch_size": 1 + }, + "validation_data": { + "systems": [ + "pt/dos/data/" + ], + "set_prefix": "set", + "batch_size": 1 + } + }, + "_comment1": "that's all" +} diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index 8ec5c375fd..3d9a4df11e 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -36,6 +36,28 @@ "data_stat_nbatch": 20, } +model_dos = { + "type_map": ["O", "H", "B"], + "descriptor": { + "type": "se_e2_a", + "sel": [46, 92, 4], + "rcut_smth": 0.50, + "rcut": 4.00, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 16, + "seed": 1, + }, + "fitting_net": { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + "type": "dos", + "numb_dos": 5, + }, + "data_stat_nbatch": 20, +} + model_zbl = { "type_map": ["O", "H", "B"], "use_srtab": "source/tests/pt/model/water/data/zbl_tab_potential/H2O_tab_potential.txt", @@ -278,6 +300,13 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) +class TestDOSModelSeA(unittest.TestCase, PermutationTest): + def setUp(self): + model_params = copy.deepcopy(model_dos) + self.type_split = False + self.model = get_model(model_params).to(env.DEVICE) + + class TestEnergyModelDPA1(unittest.TestCase, PermutationTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) diff --git a/source/tests/pt/model/test_rot.py b/source/tests/pt/model/test_rot.py index a12bd063b4..cbf09ecf40 100644 --- a/source/tests/pt/model/test_rot.py +++ b/source/tests/pt/model/test_rot.py @@ -15,6 +15,7 @@ ) from .test_permutation import ( # model_dpau, + model_dos, model_dpa1, model_dpa2, model_hybrid, @@ -139,6 +140,13 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) +class TestDOSModelSeA(unittest.TestCase, RotTest): + def setUp(self): + model_params = copy.deepcopy(model_dos) + self.type_split = False + self.model = get_model(model_params).to(env.DEVICE) + + class TestEnergyModelDPA1(unittest.TestCase, RotTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py index 86e9ed94d7..4f5be912cf 100644 --- a/source/tests/pt/model/test_smooth.py +++ b/source/tests/pt/model/test_smooth.py @@ -15,6 +15,7 @@ ) from .test_permutation import ( # model_dpau, + model_dos, model_dpa1, model_dpa2, model_hybrid, @@ -139,6 +140,14 @@ def setUp(self): self.epsilon, self.aprec = None, None +class TestDOSModelSeA(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dos) + self.type_split = False + self.model = get_model(model_params).to(env.DEVICE) + self.epsilon, self.aprec = None, None + + # @unittest.skip("dpa-1 not smooth at the moment") class TestEnergyModelDPA1(unittest.TestCase, SmoothTest): def setUp(self): diff --git a/source/tests/pt/model/test_trans.py b/source/tests/pt/model/test_trans.py index 359e91d8c8..a0aeefd6b3 100644 --- a/source/tests/pt/model/test_trans.py +++ b/source/tests/pt/model/test_trans.py @@ -15,6 +15,7 @@ ) from .test_permutation import ( # model_dpau, + model_dos, model_dpa1, model_dpa2, model_hybrid, @@ -83,6 +84,13 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) +class TestDOSModelSeA(unittest.TestCase, TransTest): + def setUp(self): + model_params = copy.deepcopy(model_dos) + self.type_split = False + self.model = get_model(model_params).to(env.DEVICE) + + class TestEnergyModelDPA1(unittest.TestCase, TransTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 76055c6f4a..a9ba2fd720 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -17,6 +17,7 @@ ) from .model.test_permutation import ( + model_dos, model_dpa1, model_dpa2, model_hybrid, @@ -96,6 +97,23 @@ def tearDown(self) -> None: DPTrainTest.tearDown(self) +@unittest.skip("loss not implemented") +class TestDOSModelSeA(unittest.TestCase, DPTrainTest): + def setUp(self): + input_json = str(Path(__file__).parent / "dos/input.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "dos/data/")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dos) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + def tearDown(self) -> None: + DPTrainTest.tearDown(self) + + class TestEnergyZBLModelSeA(unittest.TestCase, DPTrainTest): def setUp(self): input_json = str(Path(__file__).parent / "water/zbl.json") diff --git a/source/tests/tf/test_fitting_dos.py b/source/tests/tf/test_fitting_dos.py index a2a54d6287..f9df5fc126 100644 --- a/source/tests/tf/test_fitting_dos.py +++ b/source/tests/tf/test_fitting_dos.py @@ -59,7 +59,8 @@ def test_fitting(self): descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True) jdata["model"]["fitting_net"].pop("type", None) - jdata["model"]["fitting_net"]["descrpt"] = descrpt + jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() + jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out() fitting = DOSFitting(**jdata["model"]["fitting_net"], uniform_seed=True) # model._compute_dstats([test_data['coord']], [test_data['box']], [test_data['type']], [test_data['natoms_vec']], [test_data['default_mesh']]) @@ -189,21 +190,20 @@ def test_fitting(self): ref_atom_dos_1 = [ -0.32495014, - -0.87979356, - -0.26630668, -0.32495882, - -0.87979767, - -0.2663072, + -0.32496842, + -0.32495892, + -0.32495469, + -0.32496075, ] ref_atom_dos_2 = [ - -0.26630917, 0.21549911, - -0.87979638, - -0.26630564, 0.21550413, - -0.87979585, + 0.21551077, + 0.21550547, + 0.21550303, + 0.21550645, ] places = 4 - np.testing.assert_almost_equal(pred_atom_dos[:, 0], ref_atom_dos_1, places) np.testing.assert_almost_equal(pred_atom_dos[:, 50], ref_atom_dos_2, places) diff --git a/source/tests/tf/test_model_dos.py b/source/tests/tf/test_model_dos.py index d88c81c332..9c01b14e32 100644 --- a/source/tests/tf/test_model_dos.py +++ b/source/tests/tf/test_model_dos.py @@ -66,7 +66,8 @@ def test_model(self): descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True) jdata["model"]["fitting_net"].pop("type", None) - jdata["model"]["fitting_net"]["descrpt"] = descrpt + jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() + jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out() fitting = DOSFitting(**jdata["model"]["fitting_net"], uniform_seed=True) model = DOSModel(descrpt, fitting) @@ -123,106 +124,106 @@ def test_model(self): ref_dos = np.array( [ - -2.98834333, - -0.63166985, - -3.37199568, - -1.88397887, - 0.87560992, - 4.85426159, - -1.22677731, - -0.60918118, - 8.80472675, - -1.12006829, - -3.72653765, - -3.03698828, - 3.50906891, - 5.55140795, - -3.34920924, - -4.43507641, - -6.1729281, - -8.34865917, - 0.14371788, - -4.38078479, - -6.43141133, - 4.07791938, - 7.14102837, - -0.52347718, - 0.82663796, - -1.64225631, - -4.63088421, - 3.3910594, - -9.09682274, - 1.61104204, - 4.45900773, - -2.44688559, - -2.83298183, - -2.00733658, - 7.33444256, - 7.09187373, - -1.97065392, - 0.01623084, - -7.48861264, - -1.17790161, - 2.77126775, - -2.55552037, - 3.3518257, - -0.09316856, - -1.94521413, - 0.50089251, - -2.75763233, - -1.94382637, - 1.30562041, - 5.08351043, - -1.90604837, - -0.80030045, - -4.87093267, - 4.18009666, - -2.9011435, - 2.58497143, - 4.47495176, - -0.9639419, - 8.15692179, - 0.48758731, - -0.62264663, - -1.70677258, - -5.51641378, - 3.98621565, - 0.57749944, - 2.9658081, - -4.10467591, - -7.14827888, - 0.02838605, - -2.48630333, - -4.82178216, - -0.7444178, - 2.48224802, - -1.54683936, - 0.46969412, - -0.0960347, - -2.08290541, - 6.357031, - -3.49716615, - 3.28959028, - 7.83932727, - 1.51457023, - -4.14575033, - 0.02007839, - 4.20953773, - 3.66456664, - -4.67441496, - -0.13296372, - -3.77145766, - 1.49368976, - -2.53627817, - -3.14188618, - 0.24991722, - 0.8770123, - 0.16635733, - -3.15391098, - -3.7733242, - -2.25134676, - 1.00975552, - 1.38717682, + -1.98049388, + -4.58033899, + -6.95508968, + -0.79619016, + 15.58478599, + 2.7636959, + -2.99147438, + -6.94430794, + -1.77877141, + -4.5000298, + -3.12026893, + -8.42191319, + 3.8991195, + 4.85271854, + 8.30541908, + -1.0435944, + -4.42713079, + 19.70011955, + -6.53945284, + 0.85064846, + 4.36868488, + 4.77303801, + 3.00829128, + 0.70043584, + -7.69047143, + -0.0647043, + 4.56830405, + -8.67154404, + -4.64015279, + -7.62202078, + -8.97078455, + -5.19685985, + -1.66080276, + -6.03225716, + -4.06780949, + -0.53046979, + 8.3543131, + -1.84893576, + 2.42669245, + -4.26357086, + -11.33995527, + 10.98529887, + -10.70000829, + -4.50179402, + -1.34978505, + -8.83091676, + -11.85324773, + -3.6305035, + 2.89933807, + 4.65750153, + 1.25464578, + -5.06196944, + 10.05305042, + -1.83868447, + -11.57017913, + -2.03900316, + -3.37235187, + -1.37010554, + -2.93769471, + 0.11905709, + 6.99367431, + 3.48640865, + -4.16242817, + 4.44778342, + -0.98405367, + 1.81581506, + -5.31481686, + 8.72426364, + 4.78954098, + 7.67879332, + -5.00417706, + 0.79717914, + -3.20581567, + -2.96034568, + 6.31165294, + 2.9891188, + -12.2013139, + -13.67496037, + 4.77102881, + 2.71353286, + 6.83849229, + -3.50400312, + 1.3839428, + -5.07550528, + -8.5623218, + 17.64081151, + 6.46051807, + 2.89067584, + 14.23057359, + 17.85941763, + -6.46129295, + -3.43602528, + -3.13520203, + 4.45313732, + -5.23012576, + -2.65929557, + -0.66191939, + 4.47530191, + 9.33992973, + -6.29808733, ] ) @@ -230,104 +231,104 @@ def test_model(self): [ -0.33019322, -0.76332506, - -0.32665648, - -0.76601747, - -1.16441856, - -0.13627609, -1.15916671, -0.13280604, - 2.60139518, - 0.44470952, - -0.48316771, - -1.15926141, 2.59680457, 0.46049936, - -0.29459777, - -0.76433726, - -0.52091744, - -1.39903065, -0.49890317, -1.15747878, - 0.66585524, - 0.81804842, - 1.38592217, - -0.18025826, -0.2964021, -0.74953328, - -0.7427461, - 3.27935087, - -1.09340192, - 0.1462458, -0.51982728, -1.40236941, - 0.73902497, - 0.79969456, - 0.50726592, - 0.11403234, 0.64964525, 0.8084967, - -1.27543102, - -0.00571457, - 0.7748912, - -1.42492251, 1.38371838, -0.17366078, - -0.76119888, - -1.26083707, - -1.48263244, - -0.85698727, -0.7374573, 3.28274006, - -0.27029769, - -1.00478711, - -0.67481511, - -0.07978058, -1.09001574, 0.14173437, - 1.4092343, - -0.31785424, - 0.40551362, - -0.71900495, 0.7269307, 0.79545851, - -1.88407155, - 1.83983772, - -1.78413438, - -0.74852344, 0.50059876, 0.1165872, - -0.2139368, - -1.44989426, - -1.96651281, - -0.6031689, -1.28106632, -0.01107711, - 0.48796663, - 0.76500912, - 0.21308153, - -0.85297893, 0.76139868, -1.44547292, - 1.68105021, - -0.30655702, - -1.93123, - -0.34294737, -0.77352498, -1.26982082, - -0.5562998, - -0.22048683, - -0.48641512, - 0.01124872, -1.49597963, -0.86647985, - 1.17310075, - 0.59402879, - -0.705076, - 0.72991794, -0.27728806, -1.00542829, - -0.16289102, - 0.29464248, + -0.67794229, + -0.08898442, + 1.39205396, + -0.30789099, + 0.40393006, + -0.70982912, + -1.88961087, + 1.830906, + -1.78326071, + -0.75013615, + -0.22537904, + -1.47257916, + -1.9756803, + -0.60493323, + 0.48350014, + 0.77676571, + 0.20885468, + -0.84351691, + 1.67501205, + -0.30662021, + -1.92884376, + -0.34021625, + -0.56212664, + -0.22884438, + -0.4891038, + 0.0199886, + 1.16506594, + 0.58068956, + -0.69376438, + 0.74156043, + -0.16360848, + 0.30303168, + -0.88639571, + 1.453683, + 0.79818052, + 1.2796414, + -0.8335433, + 0.13359098, + -0.53425462, + -0.4939294, + 1.05247266, + 0.49770575, + -2.03320073, + -2.27918678, + 0.79462598, + 0.45187804, + 1.13925239, + -0.58410808, + 0.23092918, + -0.84611213, + -1.42726499, + 2.93985879, + 1.07635712, + 0.48092082, + 2.37197063, + 2.97647126, + -1.07670667, + -0.57300341, + -0.52316403, + 0.74274268, + -0.87188274, + -0.44279998, + -0.11060956, + 0.74619435, + 1.55646754, + -1.05043903, ] )