From 5db7883457df7d7cdc7523a7354d6fb4a5432cfb Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 29 Feb 2024 16:13:17 +0800 Subject: [PATCH] Add DataRequirementItem --- deepmd/dpmodel/model/base_model.py | 5 +- deepmd/dpmodel/model/dp_model.py | 9 +++- deepmd/pt/model/model/dipole_model.py | 41 ++++++++------ deepmd/pt/model/model/dp_zbl_model.py | 77 ++++++++++++++------------ deepmd/pt/model/model/ener_model.py | 78 +++++++++++++++------------ deepmd/pt/model/model/model.py | 6 ++- deepmd/pt/model/model/polar_model.py | 41 ++++++++------ deepmd/pt/utils/dataloader.py | 7 ++- deepmd/pt/utils/dataset.py | 27 ++++++---- deepmd/utils/data.py | 70 ++++++++++++++++++++++++ 10 files changed, 243 insertions(+), 118 deletions(-) diff --git a/deepmd/dpmodel/model/base_model.py b/deepmd/dpmodel/model/base_model.py index c4b998d763..ee22dec132 100644 --- a/deepmd/dpmodel/model/base_model.py +++ b/deepmd/dpmodel/model/base_model.py @@ -10,6 +10,9 @@ Type, ) +from deepmd.utils.data import ( + DataRequirementItem, +) from deepmd.utils.plugin import ( PluginVariant, make_plugin_registry, @@ -93,7 +96,7 @@ def model_output_type(self) -> str: """Get the output type for the model.""" @abstractmethod - def data_requirement(self) -> dict: + def data_requirement(self) -> List[DataRequirementItem]: """Get the data requirement for the model.""" @abstractmethod diff --git a/deepmd/dpmodel/model/dp_model.py b/deepmd/dpmodel/model/dp_model.py index 705750414b..88243c8742 100644 --- a/deepmd/dpmodel/model/dp_model.py +++ b/deepmd/dpmodel/model/dp_model.py @@ -1,10 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, +) + from deepmd.dpmodel.atomic_model import ( DPAtomicModel, ) from deepmd.dpmodel.model.base_model import ( BaseModel, ) +from deepmd.utils.data import ( + DataRequirementItem, +) from .make_model import ( make_model, @@ -14,6 +21,6 @@ # use "class" to resolve "Variable not allowed in type expression" @BaseModel.register("standard") class DPModel(make_model(DPAtomicModel), BaseModel): - def data_requirement(self) -> dict: + def data_requirement(self) -> List[DataRequirementItem]: """Get the data requirement for the model.""" raise NotImplementedError diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index f6d896b5d8..106202d00c 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -1,11 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Dict, + List, Optional, ) import torch +from deepmd.utils.data import ( + DataRequirementItem, +) + from .dp_model import ( DPModel, ) @@ -92,21 +97,23 @@ def forward_lower( return model_predict @property - def data_requirement(self): - data_requirement = { - "dipole": { - "ndof": 3, - "atomic": False, - "must": False, - "high_prec": False, - "type_sel": self.get_sel_type(), - }, - "atomic_dipole": { - "ndof": 3, - "atomic": True, - "must": False, - "high_prec": False, - "type_sel": self.get_sel_type(), - }, - } + def data_requirement(self) -> List[DataRequirementItem]: + data_requirement = [ + DataRequirementItem( + "dipole", + ndof=3, + atomic=False, + must=False, + high_prec=False, + type_sel=self.get_sel_type(), + ), + DataRequirementItem( + "atomic_dipole", + ndof=3, + atomic=True, + must=False, + high_prec=False, + type_sel=self.get_sel_type(), + ), + ] return data_requirement diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index fd47b4368d..fed9d89bf5 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Dict, + List, Optional, ) @@ -12,6 +13,9 @@ from deepmd.pt.model.model.model import ( BaseModel, ) +from deepmd.utils.data import ( + DataRequirementItem, +) from .make_model import ( make_model, @@ -99,38 +103,43 @@ def forward_lower( return model_predict @property - def data_requirement(self): - data_requirement = { - "energy": { - "ndof": 1, - "atomic": False, - "must": False, - "high_prec": True, - }, - "force": { - "ndof": 3, - "atomic": True, - "must": False, - "high_prec": False, - }, - "virial": { - "ndof": 9, - "atomic": False, - "must": False, - "high_prec": False, - }, - "atom_ener": { - "ndof": 1, - "atomic": True, - "must": False, - "high_prec": False, - }, - "atom_pref": { - "ndof": 1, - "atomic": True, - "must": False, - "high_prec": False, - "repeat": 3, - }, - } + def data_requirement(self) -> List[DataRequirementItem]: + data_requirement = [ + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, + ), + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_ener", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ), + ] return data_requirement diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 1497cbade4..92b2b95e34 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -1,11 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Dict, + List, Optional, ) import torch +from deepmd.utils.data import ( + DataRequirementItem, +) + from .dp_model import ( DPModel, ) @@ -97,38 +102,43 @@ def forward_lower( return model_predict @property - def data_requirement(self): - data_requirement = { - "energy": { - "ndof": 1, - "atomic": False, - "must": False, - "high_prec": True, - }, - "force": { - "ndof": 3, - "atomic": True, - "must": False, - "high_prec": False, - }, - "virial": { - "ndof": 9, - "atomic": False, - "must": False, - "high_prec": False, - }, - "atom_ener": { - "ndof": 1, - "atomic": True, - "must": False, - "high_prec": False, - }, - "atom_pref": { - "ndof": 1, - "atomic": True, - "must": False, - "high_prec": False, - "repeat": 3, - }, - } + def data_requirement(self) -> List[DataRequirementItem]: + data_requirement = [ + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, + ), + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_ener", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ), + ] return data_requirement diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index 0e2afadd14..1b82402747 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -1,11 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + List, Optional, ) from deepmd.dpmodel.model.base_model import ( make_base_model, ) +from deepmd.utils.data import ( + DataRequirementItem, +) from deepmd.utils.path import ( DPPath, ) @@ -85,6 +89,6 @@ def compute_or_load_stat( """ raise NotImplementedError - def data_requirement(self) -> dict: + def data_requirement(self) -> List[DataRequirementItem]: """Get the data requirement for the model.""" raise NotImplementedError diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py index 450f5f2fb5..c23e26afac 100644 --- a/deepmd/pt/model/model/polar_model.py +++ b/deepmd/pt/model/model/polar_model.py @@ -1,11 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Dict, + List, Optional, ) import torch +from deepmd.utils.data import ( + DataRequirementItem, +) + from .dp_model import ( DPModel, ) @@ -76,21 +81,23 @@ def forward_lower( return model_predict @property - def get_data_requirement(self): - data_requirement = { - "polar": { - "ndof": 9, - "atomic": False, - "must": False, - "high_prec": False, - "type_sel": self.get_sel_type(), - }, - "atomic_polar": { - "ndof": 9, - "atomic": True, - "must": False, - "high_prec": False, - "type_sel": self.get_sel_type(), - }, - } + def get_data_requirement(self) -> List[DataRequirementItem]: + data_requirement = [ + DataRequirementItem( + "polar", + ndof=9, + atomic=False, + must=False, + high_prec=False, + type_sel=self.get_sel_type(), + ), + DataRequirementItem( + "atomic_polar", + ndof=9, + atomic=True, + must=False, + high_prec=False, + type_sel=self.get_sel_type(), + ), + ] return data_requirement diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 70993c21a0..b197f46124 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -35,6 +35,9 @@ from deepmd.pt.utils.dataset import ( DeepmdDataSetForLoader, ) +from deepmd.utils.data import ( + DataRequirementItem, +) from deepmd.utils.data_system import ( prob_sys_size_ext, process_sys_probs, @@ -147,10 +150,10 @@ def __getitem__(self, idx): batch["sid"] = idx return batch - def add_data_requirement(self, dict_of_keys): + def add_data_requirement(self, data_requirement: List[DataRequirementItem]): """Add data requirement for each system in multiple systems.""" for system in self.systems: - system.add_data_requirement(dict_of_keys) + system.add_data_requirement(data_requirement) _sentinel = object() diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 9de82778dc..40a513acdf 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -1,11 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, +) + from torch.utils.data import ( Dataset, ) from deepmd.utils.data import ( + DataRequirementItem, DeepmdData, ) @@ -42,17 +47,17 @@ def __getitem__(self, index): b_data["natoms"] = self._natoms_vec return b_data - def add_data_requirement(self, dict_of_keys): + def add_data_requirement(self, data_requirement: List[DataRequirementItem]): """Add data requirement for this data system.""" - for data_key in dict_of_keys: + for data_item in data_requirement: self._data_system.add( - data_key, - dict_of_keys[data_key]["ndof"], - atomic=dict_of_keys[data_key].get("atomic", False), - must=dict_of_keys[data_key].get("must", False), - high_prec=dict_of_keys[data_key].get("high_prec", False), - type_sel=dict_of_keys[data_key].get("type_sel", None), - repeat=dict_of_keys[data_key].get("repeat", 1), - default=dict_of_keys[data_key].get("default", 0.0), - dtype=dict_of_keys[data_key].get("dtype", None), + data_item["key"], + data_item["ndof"], + atomic=data_item["atomic"], + must=data_item["must"], + high_prec=data_item["high_prec"], + type_sel=data_item["type_sel"], + repeat=data_item["repeat"], + default=data_item["default"], + dtype=data_item["dtype"], ) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 9e726fbe19..03e39e1f21 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -666,3 +666,73 @@ def _check_pbc(self, sys_path: DPPath): def _check_mode(self, set_path: DPPath): return (set_path / "real_atom_types.npy").is_file() + + +class DataRequirementItem: + """A class to store the data requirement for data systems. + + Parameters + ---------- + key + The key of the item. The corresponding data is stored in `sys_path/set.*/key.npy` + ndof + The number of dof + atomic + The item is an atomic property. + If False, the size of the data should be nframes x ndof + If True, the size of data should be nframes x natoms x ndof + must + The data file `sys_path/set.*/key.npy` must exist. + If must is False and the data file does not exist, the `data_dict[find_key]` is set to 0.0 + high_prec + Load the data and store in float64, otherwise in float32 + type_sel + Select certain type of atoms + repeat + The data will be repeated `repeat` times. + default : float, default=0. + default value of data + dtype : np.dtype, optional + the dtype of data, overwrites `high_prec` if provided + """ + + def __init__( + self, + key: str, + ndof: int, + atomic: bool = False, + must: bool = False, + high_prec: bool = False, + type_sel: Optional[List[int]] = None, + repeat: int = 1, + default: float = 0.0, + dtype: Optional[np.dtype] = None, + ) -> None: + self.key = key + self.ndof = ndof + self.atomic = atomic + self.must = must + self.high_prec = high_prec + self.type_sel = type_sel + self.repeat = repeat + self.default = default + self.dtype = dtype + self.dict = self.to_dict() + + def to_dict(self) -> dict: + return { + "key": self.key, + "ndof": self.ndof, + "atomic": self.atomic, + "must": self.must, + "high_prec": self.high_prec, + "type_sel": self.type_sel, + "repeat": self.repeat, + "default": self.default, + "dtype": self.dtype, + } + + def __getitem__(self, key: str): + if key not in self.dict: + raise KeyError(key) + return self.dict[key]