diff --git a/deepmd/dpmodel/atomic_model/__init__.py b/deepmd/dpmodel/atomic_model/__init__.py index 37f6b8bf28..4f4ef32e03 100644 --- a/deepmd/dpmodel/atomic_model/__init__.py +++ b/deepmd/dpmodel/atomic_model/__init__.py @@ -17,9 +17,18 @@ from .base_atomic_model import ( BaseAtomicModel, ) +from .dipole_atomic_model import ( + DPDipoleAtomicModel, +) +from .dos_atomic_model import ( + DPDOSAtomicModel, +) from .dp_atomic_model import ( DPAtomicModel, ) +from .energy_atomic_model import ( + DPEnergyAtomicModel, +) from .linear_atomic_model import ( DPZBLLinearEnergyAtomicModel, LinearEnergyAtomicModel, @@ -30,12 +39,19 @@ from .pairtab_atomic_model import ( PairTabAtomicModel, ) +from .polar_atomic_model import ( + DPPolarAtomicModel, +) __all__ = [ "make_base_atomic_model", "BaseAtomicModel", "DPAtomicModel", + "DPEnergyAtomicModel", "PairTabAtomicModel", "LinearEnergyAtomicModel", "DPZBLLinearEnergyAtomicModel", + "DPDOSAtomicModel", + "DPPolarAtomicModel", + "DPDipoleAtomicModel", ] diff --git a/deepmd/dpmodel/atomic_model/dipole_atomic_model.py b/deepmd/dpmodel/atomic_model/dipole_atomic_model.py new file mode 100644 index 0000000000..00428f4e95 --- /dev/null +++ b/deepmd/dpmodel/atomic_model/dipole_atomic_model.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from deepmd.dpmodel.fitting.dipole_fitting import ( + DipoleFitting, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPDipoleAtomicModel(DPAtomicModel): + def __init__(self, descriptor, fitting, type_map, **kwargs): + if not isinstance(fitting, DipoleFitting): + raise TypeError( + "fitting must be an instance of DipoleFitting for DPDipoleAtomicModel" + ) + super().__init__(descriptor, fitting, type_map, **kwargs) + + def apply_out_stat( + self, + ret: dict[str, np.ndarray], + atype: np.ndarray, + ): + # dipole not applying bias + return ret diff --git a/deepmd/dpmodel/atomic_model/dos_atomic_model.py b/deepmd/dpmodel/atomic_model/dos_atomic_model.py new file mode 100644 index 0000000000..7ef6d10ebf --- /dev/null +++ b/deepmd/dpmodel/atomic_model/dos_atomic_model.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.fitting.dos_fitting import ( + DOSFittingNet, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPDOSAtomicModel(DPAtomicModel): + def __init__(self, descriptor, fitting, type_map, **kwargs): + if not isinstance(fitting, DOSFittingNet): + raise TypeError( + "fitting must be an instance of DOSFittingNet for DPDOSAtomicModel" + ) + super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/dpmodel/atomic_model/energy_atomic_model.py b/deepmd/dpmodel/atomic_model/energy_atomic_model.py new file mode 100644 index 0000000000..4f9f8ec005 --- /dev/null +++ b/deepmd/dpmodel/atomic_model/energy_atomic_model.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.fitting.ener_fitting import ( + EnergyFittingNet, + InvarFitting, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPEnergyAtomicModel(DPAtomicModel): + def __init__(self, descriptor, fitting, type_map, **kwargs): + if not ( + isinstance(fitting, EnergyFittingNet) or isinstance(fitting, InvarFitting) + ): + raise TypeError( + "fitting must be an instance of EnergyFittingNet or InvarFitting for DPEnergyAtomicModel" + ) + super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/dpmodel/atomic_model/polar_atomic_model.py b/deepmd/dpmodel/atomic_model/polar_atomic_model.py new file mode 100644 index 0000000000..6e1d32ff35 --- /dev/null +++ b/deepmd/dpmodel/atomic_model/polar_atomic_model.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import numpy as np + +from deepmd.dpmodel.fitting.polarizability_fitting import ( + PolarFitting, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPPolarAtomicModel(DPAtomicModel): + def __init__(self, descriptor, fitting, type_map, **kwargs): + if not isinstance(fitting, PolarFitting): + raise TypeError( + "fitting must be an instance of PolarFitting for DPPolarAtomicModel" + ) + super().__init__(descriptor, fitting, type_map, **kwargs) + + def apply_out_stat( + self, + ret: dict[str, np.ndarray], + atype: np.ndarray, + ): + """Apply the stat to each atomic output. + + Parameters + ---------- + ret + The returned dict by the forward_atomic method + atype + The atom types. nf x nloc + + """ + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + + if self.fitting_net.shift_diag: + nframes, nloc = atype.shape + dtype = out_bias[self.bias_keys[0]].dtype + for kk in self.bias_keys: + ntypes = out_bias[kk].shape[0] + temp = np.zeros(ntypes, dtype=dtype) + temp = np.mean( + np.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2), + axis=1, + ) + modified_bias = temp[atype] + + # (nframes, nloc, 1) + modified_bias = ( + modified_bias[..., np.newaxis] * (self.fitting_net.scale[atype]) + ) + + eye = np.eye(3, dtype=dtype) + eye = np.tile(eye, (nframes, nloc, 1, 1)) + # (nframes, nloc, 3, 3) + modified_bias = modified_bias[..., np.newaxis] * eye + + # nf x nloc x odims, out_bias: ntypes x odims + ret[kk] = ret[kk] + modified_bias + return ret diff --git a/deepmd/dpmodel/atomic_model/property_atomic_model.py b/deepmd/dpmodel/atomic_model/property_atomic_model.py index 1c5d2d1900..6f69f8dfb6 100644 --- a/deepmd/dpmodel/atomic_model/property_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/property_atomic_model.py @@ -9,6 +9,9 @@ class DPPropertyAtomicModel(DPAtomicModel): - def __init__(self, descriptor, fitting, type_map, **kwargs) -> None: - assert isinstance(fitting, PropertyFittingNet) + def __init__(self, descriptor, fitting, type_map, **kwargs): + if not isinstance(fitting, PropertyFittingNet): + raise TypeError( + "fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel" + ) super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/dpmodel/model/dipole_model.py b/deepmd/dpmodel/model/dipole_model.py new file mode 100644 index 0000000000..4ca523f79b --- /dev/null +++ b/deepmd/dpmodel/model/dipole_model.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + + +from deepmd.dpmodel.atomic_model import ( + DPDipoleAtomicModel, +) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) + +from .dp_model import ( + DPModelCommon, +) +from .make_model import ( + make_model, +) + +DPDipoleModel_ = make_model(DPDipoleAtomicModel) + + +@BaseModel.register("dipole") +class DipoleModel(DPModelCommon, DPDipoleModel_): + model_type = "dipole" + + def __init__( + self, + *args, + **kwargs, + ): + DPModelCommon.__init__(self) + DPDipoleModel_.__init__(self, *args, **kwargs) diff --git a/deepmd/dpmodel/model/dos_model.py b/deepmd/dpmodel/model/dos_model.py new file mode 100644 index 0000000000..3df887b460 --- /dev/null +++ b/deepmd/dpmodel/model/dos_model.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.dpmodel.atomic_model import ( + DPDOSAtomicModel, +) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) + +from .dp_model import ( + DPModelCommon, +) +from .make_model import ( + make_model, +) + +DPDOSModel_ = make_model(DPDOSAtomicModel) + + +@BaseModel.register("dos") +class DOSModel(DPModelCommon, DPDOSModel_): + model_type = "dos" + + def __init__( + self, + *args, + **kwargs, + ): + DPModelCommon.__init__(self) + DPDOSModel_.__init__(self, *args, **kwargs) diff --git a/deepmd/dpmodel/model/ener_model.py b/deepmd/dpmodel/model/ener_model.py index 643f260bff..e4233eb397 100644 --- a/deepmd/dpmodel/model/ener_model.py +++ b/deepmd/dpmodel/model/ener_model.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from deepmd.dpmodel.atomic_model.dp_atomic_model import ( - DPAtomicModel, +from deepmd.dpmodel.atomic_model import ( + DPEnergyAtomicModel, ) from deepmd.dpmodel.model.base_model import ( BaseModel, @@ -13,7 +13,7 @@ make_model, ) -DPEnergyModel_ = make_model(DPAtomicModel) +DPEnergyModel_ = make_model(DPEnergyAtomicModel) @BaseModel.register("ener") diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 19408e58c4..1d18b70e8e 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( DPAtomicModel, ) @@ -8,18 +10,33 @@ from deepmd.dpmodel.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.dpmodel.fitting.base_fitting import ( + BaseFitting, +) from deepmd.dpmodel.fitting.ener_fitting import ( EnergyFittingNet, ) from deepmd.dpmodel.model.base_model import ( BaseModel, ) +from deepmd.dpmodel.model.dipole_model import ( + DipoleModel, +) +from deepmd.dpmodel.model.dos_model import ( + DOSModel, +) from deepmd.dpmodel.model.dp_zbl_model import ( DPZBLModel, ) from deepmd.dpmodel.model.ener_model import ( EnergyModel, ) +from deepmd.dpmodel.model.polar_model import ( + PolarModel, +) +from deepmd.dpmodel.model.property_model import ( + PropertyModel, +) from deepmd.dpmodel.model.spin_model import ( SpinModel, ) @@ -28,6 +45,29 @@ ) +def _get_standard_model_components(data, ntypes): + # descriptor + data["descriptor"]["ntypes"] = ntypes + data["descriptor"]["type_map"] = copy.deepcopy(data["type_map"]) + descriptor = BaseDescriptor(**data["descriptor"]) + # fitting + fitting_net = data.get("fitting_net", {}) + fitting_net["type"] = fitting_net.get("type", "ener") + fitting_net["ntypes"] = descriptor.get_ntypes() + fitting_net["type_map"] = copy.deepcopy(data["type_map"]) + fitting_net["mixed_types"] = descriptor.mixed_types() + if fitting_net["type"] in ["dipole", "polar"]: + fitting_net["embedding_width"] = descriptor.get_dim_emb() + fitting_net["dim_descrpt"] = descriptor.get_dim_out() + grad_force = "direct" not in fitting_net["type"] + if not grad_force: + fitting_net["out_dim"] = descriptor.get_dim_emb() + if "ener" in fitting_net["type"]: + fitting_net["return_energy"] = True + fitting = BaseFitting(**fitting_net) + return descriptor, fitting, fitting_net["type"] + + def get_standard_model(data: dict) -> EnergyModel: """Get a EnergyModel from a dictionary. @@ -40,29 +80,33 @@ def get_standard_model(data: dict) -> EnergyModel: raise ValueError( "In the DP backend, type_embedding is not at the model level, but within the descriptor. See type embedding documentation for details." ) - data["descriptor"]["type_map"] = data["type_map"] - data["descriptor"]["ntypes"] = len(data["type_map"]) - fitting_type = data["fitting_net"].pop("type") - data["fitting_net"]["type_map"] = data["type_map"] - descriptor = BaseDescriptor( - **data["descriptor"], - ) - if fitting_type == "ener": - fitting = EnergyFittingNet( - ntypes=descriptor.get_ntypes(), - dim_descrpt=descriptor.get_dim_out(), - mixed_types=descriptor.mixed_types(), - **data["fitting_net"], - ) + data = copy.deepcopy(data) + ntypes = len(data["type_map"]) + descriptor, fitting, fitting_net_type = _get_standard_model_components(data, ntypes) + atom_exclude_types = data.get("atom_exclude_types", []) + pair_exclude_types = data.get("pair_exclude_types", []) + + if fitting_net_type == "dipole": + modelcls = DipoleModel + elif fitting_net_type == "polar": + modelcls = PolarModel + elif fitting_net_type == "dos": + modelcls = DOSModel + elif fitting_net_type in ["ener", "direct_force_ener"]: + modelcls = EnergyModel + elif fitting_net_type == "property": + modelcls = PropertyModel else: - raise ValueError(f"Unknown fitting type {fitting_type}") - return EnergyModel( + raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") + + model = modelcls( descriptor=descriptor, fitting=fitting, type_map=data["type_map"], - atom_exclude_types=data.get("atom_exclude_types", []), - pair_exclude_types=data.get("pair_exclude_types", []), + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, ) + return model def get_zbl_model(data: dict) -> DPZBLModel: diff --git a/deepmd/dpmodel/model/polar_model.py b/deepmd/dpmodel/model/polar_model.py new file mode 100644 index 0000000000..994b3556c2 --- /dev/null +++ b/deepmd/dpmodel/model/polar_model.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.dpmodel.atomic_model import ( + DPPolarAtomicModel, +) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) + +from .dp_model import ( + DPModelCommon, +) +from .make_model import ( + make_model, +) + +DPPolarModel_ = make_model(DPPolarAtomicModel) + + +@BaseModel.register("polar") +class PolarModel(DPModelCommon, DPPolarModel_): + model_type = "polar" + + def __init__( + self, + *args, + **kwargs, + ): + DPModelCommon.__init__(self) + DPPolarModel_.__init__(self, *args, **kwargs) diff --git a/deepmd/pt/model/atomic_model/dipole_atomic_model.py b/deepmd/pt/model/atomic_model/dipole_atomic_model.py index fd0879e707..3796aa2e83 100644 --- a/deepmd/pt/model/atomic_model/dipole_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dipole_atomic_model.py @@ -12,8 +12,11 @@ class DPDipoleAtomicModel(DPAtomicModel): - def __init__(self, descriptor, fitting, type_map, **kwargs) -> None: - assert isinstance(fitting, DipoleFittingNet) + def __init__(self, descriptor, fitting, type_map, **kwargs): + if not isinstance(fitting, DipoleFittingNet): + raise TypeError( + "fitting must be an instance of DipoleFittingNet for DPDipoleAtomicModel" + ) super().__init__(descriptor, fitting, type_map, **kwargs) def apply_out_stat( diff --git a/deepmd/pt/model/atomic_model/dos_atomic_model.py b/deepmd/pt/model/atomic_model/dos_atomic_model.py index 1f7d7a9917..2af1a4e052 100644 --- a/deepmd/pt/model/atomic_model/dos_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dos_atomic_model.py @@ -9,6 +9,9 @@ class DPDOSAtomicModel(DPAtomicModel): - def __init__(self, descriptor, fitting, type_map, **kwargs) -> None: - assert isinstance(fitting, DOSFittingNet) + def __init__(self, descriptor, fitting, type_map, **kwargs): + if not isinstance(fitting, DOSFittingNet): + raise TypeError( + "fitting must be an instance of DOSFittingNet for DPDOSAtomicModel" + ) super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/pt/model/atomic_model/energy_atomic_model.py b/deepmd/pt/model/atomic_model/energy_atomic_model.py index 855c1213ec..6d894b4aab 100644 --- a/deepmd/pt/model/atomic_model/energy_atomic_model.py +++ b/deepmd/pt/model/atomic_model/energy_atomic_model.py @@ -11,10 +11,13 @@ class DPEnergyAtomicModel(DPAtomicModel): - def __init__(self, descriptor, fitting, type_map, **kwargs) -> None: - assert ( + def __init__(self, descriptor, fitting, type_map, **kwargs): + if not ( isinstance(fitting, EnergyFittingNet) or isinstance(fitting, EnergyFittingNetDirect) or isinstance(fitting, InvarFitting) - ) + ): + raise TypeError( + "fitting must be an instance of EnergyFittingNet, EnergyFittingNetDirect or InvarFitting for DPEnergyAtomicModel" + ) super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/pt/model/atomic_model/polar_atomic_model.py b/deepmd/pt/model/atomic_model/polar_atomic_model.py index 8ec27b5762..6bd063591f 100644 --- a/deepmd/pt/model/atomic_model/polar_atomic_model.py +++ b/deepmd/pt/model/atomic_model/polar_atomic_model.py @@ -12,8 +12,11 @@ class DPPolarAtomicModel(DPAtomicModel): - def __init__(self, descriptor, fitting, type_map, **kwargs) -> None: - assert isinstance(fitting, PolarFittingNet) + def __init__(self, descriptor, fitting, type_map, **kwargs): + if not isinstance(fitting, PolarFittingNet): + raise TypeError( + "fitting must be an instance of PolarFittingNet for DPPolarAtomicModel" + ) super().__init__(descriptor, fitting, type_map, **kwargs) def apply_out_stat( @@ -40,8 +43,12 @@ def apply_out_stat( for kk in self.bias_keys: ntypes = out_bias[kk].shape[0] temp = torch.zeros(ntypes, dtype=dtype, device=device) - for i in range(ntypes): - temp[i] = torch.mean(torch.diagonal(out_bias[kk][i].reshape(3, 3))) + temp = torch.mean( + torch.diagonal( + out_bias[kk].reshape(ntypes, 3, 3), dim1=-2, dim2=-1 + ), + dim=-1, + ) modified_bias = temp[atype] # (nframes, nloc, 1) diff --git a/deepmd/pt/model/atomic_model/property_atomic_model.py b/deepmd/pt/model/atomic_model/property_atomic_model.py index 6bacaf5d72..1fdc72b2b6 100644 --- a/deepmd/pt/model/atomic_model/property_atomic_model.py +++ b/deepmd/pt/model/atomic_model/property_atomic_model.py @@ -12,8 +12,11 @@ class DPPropertyAtomicModel(DPAtomicModel): - def __init__(self, descriptor, fitting, type_map, **kwargs) -> None: - assert isinstance(fitting, PropertyFittingNet) + def __init__(self, descriptor, fitting, type_map, **kwargs): + if not isinstance(fitting, PropertyFittingNet): + raise TypeError( + "fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel" + ) super().__init__(descriptor, fitting, type_map, **kwargs) def apply_out_stat( diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index 79490bc20c..a24820b74a 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -19,11 +19,11 @@ make_model, ) -DPDOSModel_ = make_model(DPDipoleAtomicModel) +DPDipoleModel_ = make_model(DPDipoleAtomicModel) @BaseModel.register("dipole") -class DipoleModel(DPModelCommon, DPDOSModel_): +class DipoleModel(DPModelCommon, DPDipoleModel_): model_type = "dipole" def __init__( @@ -32,7 +32,7 @@ def __init__( **kwargs, ) -> None: DPModelCommon.__init__(self) - DPDOSModel_.__init__(self, *args, **kwargs) + DPDipoleModel_.__init__(self, *args, **kwargs) def translated_output_def(self): out_def_data = self.model_output_def().get_data() diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py index ea6316dc91..cb72532366 100644 --- a/deepmd/pt/model/model/polar_model.py +++ b/deepmd/pt/model/model/polar_model.py @@ -19,11 +19,11 @@ make_model, ) -DPDOSModel_ = make_model(DPPolarAtomicModel) +DPPolarModel_ = make_model(DPPolarAtomicModel) @BaseModel.register("polar") -class PolarModel(DPModelCommon, DPDOSModel_): +class PolarModel(DPModelCommon, DPPolarModel_): model_type = "polar" def __init__( @@ -32,7 +32,7 @@ def __init__( **kwargs, ) -> None: DPModelCommon.__init__(self) - DPDOSModel_.__init__(self, *args, **kwargs) + DPPolarModel_.__init__(self, *args, **kwargs) def translated_output_def(self): out_def_data = self.model_output_def().get_data() diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py index 4eeb19b1f0..bb38abc5b6 100644 --- a/source/tests/consistent/model/common.py +++ b/source/tests/consistent/model/common.py @@ -34,7 +34,9 @@ class ModelTest: """Useful utilities for model tests.""" - def build_tf_model(self, obj, natoms, coords, atype, box, suffix): + def build_tf_model( + self, obj, natoms, coords, atype, box, suffix, ret_key: str = "energy" + ): t_coord = tf.placeholder( GLOBAL_TF_FLOAT_PRECISION, [None, None, None], name="i_coord" ) @@ -51,13 +53,32 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix): {}, suffix=suffix, ) - return [ - ret["energy"], - ret["atom_ener"], - ret["force"], - ret["virial"], - ret["atom_virial"], - ], { + if ret_key == "energy": + ret_list = [ + ret["energy"], + ret["atom_ener"], + ret["force"], + ret["virial"], + ret["atom_virial"], + ] + elif ret_key == "dos": + ret_list = [ + ret["dos"], + ret["atom_dos"], + ] + elif ret_key == "dipole": + ret_list = [ + ret["global_dipole"], + ret["dipole"], + ] + elif ret_key == "polar": + ret_list = [ + ret["polar"], + ret["global_polar"], + ] + else: + raise NotImplementedError + return ret_list, { t_coord: coords, t_type: atype, t_natoms: natoms, diff --git a/source/tests/consistent/model/test_dos.py b/source/tests/consistent/model/test_dos.py new file mode 100644 index 0000000000..8f0b0309cc --- /dev/null +++ b/source/tests/consistent/model/test_dos.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.dos_model import DOSModel as DOSModelDP +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + INSTALLED_TF, + CommonTest, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.dos_model import DOSModel as DOSModelPT +else: + DOSModelPT = None +if INSTALLED_TF: + from deepmd.tf.model.dos import DOSModel as DOSModelTF +else: + DOSModelTF = None +from deepmd.utils.argcheck import ( + model_args, +) + + +class TestDOS(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + return { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 1.8, + "rcut": 6.0, + "neuron": [2, 4, 8], + "resnet_dt": False, + "axis_neuron": 8, + "precision": "float64", + "seed": 1, + }, + "fitting_net": { + "type": "dos", + "numb_dos": 2, + "neuron": [4, 4, 4], + "resnet_dt": True, + "numb_fparam": 0, + "precision": "float64", + "seed": 1, + }, + } + + tf_class = DOSModelTF + dp_class = DOSModelDP + pt_class = DOSModelPT + args = model_args() + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_tf: + return self.RefBackend.TF + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + @property + def skip_tf(self): + return True # need to fix tf consistency + + @property + def skip_jax(self) -> bool: + return True + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = data.copy() + if cls is DOSModelDP: + return get_model_dp(data) + elif cls is DOSModelPT: + model = get_model_pt(data) + model.atomic_model.out_bias.uniform_() + return model + return cls(**data, **self.additional_data) + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + # TF requires the atype to be sort + idx_map = np.argsort(self.atype.ravel()) + self.atype = self.atype[:, idx_map] + self.coords = self.coords[:, idx_map] + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + return self.build_tf_model( + obj, self.natoms, self.coords, self.atype, self.box, suffix, ret_key="dos" + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_model( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_model( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... + if backend is self.RefBackend.DP: + return ( + ret["dos_redu"].ravel(), + ret["dos"].ravel(), + ) + elif backend is self.RefBackend.PT: + return ( + ret["dos"].ravel(), + ret["atom_dos"].ravel(), + ) + elif backend is self.RefBackend.TF: + return ( + ret[0].ravel(), + ret[1].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}")