diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9058decc21..bd36fd6e63 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,7 @@ repos: exclude: ^source/3rdparty - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.8.3 + rev: v0.8.4 hooks: - id: ruff args: ["--fix"] diff --git a/CITATIONS.bib b/CITATIONS.bib index d5524a14f6..52c8045bf3 100644 --- a/CITATIONS.bib +++ b/CITATIONS.bib @@ -128,26 +128,26 @@ @article{Zhang_NpjComputMater_2024_v10_p94 doi = {10.1038/s41524-024-01278-7}, } -@misc{Zhang_2023_DPA2, +@article{Zhang_npjComputMater_2024_v10_p293, annote = {DPA-2}, author = { Duo Zhang and Xinzijian Liu and Xiangyu Zhang and Chengqian Zhang and Chun - Cai and Hangrui Bi and Yiming Du and Xuejian Qin and Jiameng Huang and - Bowen Li and Yifan Shan and Jinzhe Zeng and Yuzhi Zhang and Siyuan Liu and - Yifan Li and Junhan Chang and Xinyan Wang and Shuo Zhou and Jianchuan Liu - and Xiaoshan Luo and Zhenyu Wang and Wanrun Jiang and Jing Wu and Yudi Yang - and Jiyuan Yang and Manyi Yang and Fu-Qiang Gong and Linshuang Zhang and - Mengchao Shi and Fu-Zhi Dai and Darrin M. York and Shi Liu and Tong Zhu and - Zhicheng Zhong and Jian Lv and Jun Cheng and Weile Jia and Mohan Chen and - Guolin Ke and Weinan E and Linfeng Zhang and Han Wang + Cai and Hangrui Bi and Yiming Du and Xuejian Qin and Anyang Peng and + Jiameng Huang and Bowen Li and Yifan Shan and Jinzhe Zeng and Yuzhi Zhang + and Siyuan Liu and Yifan Li and Junhan Chang and Xinyan Wang and Shuo Zhou + and Jianchuan Liu and Xiaoshan Luo and Zhenyu Wang and Wanrun Jiang and + Jing Wu and Yudi Yang and Jiyuan Yang and Manyi Yang and Fu-Qiang Gong and + Linshuang Zhang and Mengchao Shi and Fu-Zhi Dai and Darrin M. York and Shi + Liu and Tong Zhu and Zhicheng Zhong and Jian Lv and Jun Cheng and Weile Jia + and Mohan Chen and Guolin Ke and Weinan E and Linfeng Zhang and Han Wang }, - title = { - {DPA-2: Towards a universal large atomic model for molecular and material - simulation} - }, - publisher = {arXiv}, - year = 2023, - doi = {10.48550/arXiv.2312.15492}, + title = {{DPA-2: a large atomic model as a multi-task learner}}, + journal = {npj Comput. Mater}, + year = 2024, + volume = 10, + number = 1, + pages = 293, + doi = {10.1038/s41524-024-01493-2}, } @article{Zhang_PhysPlasmas_2020_v27_p122704, diff --git a/deepmd/dpmodel/atomic_model/__init__.py b/deepmd/dpmodel/atomic_model/__init__.py index 3d90c738ae..4d882d5e4b 100644 --- a/deepmd/dpmodel/atomic_model/__init__.py +++ b/deepmd/dpmodel/atomic_model/__init__.py @@ -42,6 +42,9 @@ from .polar_atomic_model import ( DPPolarAtomicModel, ) +from .property_atomic_model import ( + DPPropertyAtomicModel, +) __all__ = [ "BaseAtomicModel", @@ -50,6 +53,7 @@ "DPDipoleAtomicModel", "DPEnergyAtomicModel", "DPPolarAtomicModel", + "DPPropertyAtomicModel", "DPZBLLinearEnergyAtomicModel", "LinearEnergyAtomicModel", "PairTabAtomicModel", diff --git a/deepmd/dpmodel/atomic_model/property_atomic_model.py b/deepmd/dpmodel/atomic_model/property_atomic_model.py index 6f69f8dfb6..e3c038e695 100644 --- a/deepmd/dpmodel/atomic_model/property_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/property_atomic_model.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + from deepmd.dpmodel.fitting.property_fitting import ( PropertyFittingNet, ) @@ -15,3 +17,25 @@ def __init__(self, descriptor, fitting, type_map, **kwargs): "fitting must be an instance of PropertyFittingNet for DPPropertyAtomicModel" ) super().__init__(descriptor, fitting, type_map, **kwargs) + + def apply_out_stat( + self, + ret: dict[str, np.ndarray], + atype: np.ndarray, + ): + """Apply the stat to each atomic output. + + In property fitting, each output will be multiplied by label std and then plus the label average value. + + Parameters + ---------- + ret + The returned dict by the forward_atomic method + atype + The atom types. nf x nloc. It is useless in property fitting. + + """ + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in self.bias_keys: + ret[kk] = ret[kk] * out_std[kk][0] + out_bias[kk][0] + return ret diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index e4cadb7b36..55ae331593 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -387,7 +387,7 @@ def __init__( use_tebd_bias: bool = False, type_map: Optional[list[str]] = None, ) -> None: - r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492. + r"""The DPA-2 descriptor[1]_. Parameters ---------- @@ -434,6 +434,11 @@ def __init__( sw: torch.Tensor The switch function for decaying inverse distance. + References + ---------- + .. [1] Zhang, D., Liu, X., Zhang, X. et al. DPA-2: a + large atomic model as a multi-task learner. npj + Comput Mater 10, 293 (2024). https://doi.org/10.1038/s41524-024-01493-2 """ def init_subclass_params(sub_data, sub_class): diff --git a/deepmd/dpmodel/fitting/property_fitting.py b/deepmd/dpmodel/fitting/property_fitting.py index 8b903af00e..6d0aa3546f 100644 --- a/deepmd/dpmodel/fitting/property_fitting.py +++ b/deepmd/dpmodel/fitting/property_fitting.py @@ -41,10 +41,9 @@ class PropertyFittingNet(InvarFitting): this list is of length :math:`N_l + 1`, specifying if the hidden layers and the output layer are trainable. intensive Whether the fitting property is intensive. - bias_method - The method of applying the bias to each atomic output, user can select 'normal' or 'no_bias'. - If 'normal' is used, the computed bias will be added to the atomic output. - If 'no_bias' is used, no bias will be added to the atomic output. + property_name: + The name of fitting property, which should be consistent with the property name in the dataset. + If the data file is named `humo.npy`, this parameter should be "humo". resnet_dt Time-step `dt` in the resnet construction: :math:`y = x + dt * \phi (Wx + b)` @@ -74,7 +73,7 @@ def __init__( rcond: Optional[float] = None, trainable: Union[bool, list[bool]] = True, intensive: bool = False, - bias_method: str = "normal", + property_name: str = "property", resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, @@ -89,9 +88,8 @@ def __init__( ) -> None: self.task_dim = task_dim self.intensive = intensive - self.bias_method = bias_method super().__init__( - var_name="property", + var_name=property_name, ntypes=ntypes, dim_descrpt=dim_descrpt, dim_out=task_dim, @@ -113,9 +111,9 @@ def __init__( @classmethod def deserialize(cls, data: dict) -> "PropertyFittingNet": data = data.copy() - check_version_compatibility(data.pop("@version"), 3, 1) + check_version_compatibility(data.pop("@version"), 4, 1) data.pop("dim_out") - data.pop("var_name") + data["property_name"] = data.pop("var_name") data.pop("tot_ener_zero") data.pop("layer_name") data.pop("use_aparam_as_mask", None) @@ -131,6 +129,8 @@ def serialize(self) -> dict: **InvarFitting.serialize(self), "type": "property", "task_dim": self.task_dim, + "intensive": self.intensive, } + dd["@version"] = 4 return dd diff --git a/deepmd/dpmodel/model/property_model.py b/deepmd/dpmodel/model/property_model.py index 16fdedd36e..9bd07bd349 100644 --- a/deepmd/dpmodel/model/property_model.py +++ b/deepmd/dpmodel/model/property_model.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from deepmd.dpmodel.atomic_model.dp_atomic_model import ( - DPAtomicModel, +from deepmd.dpmodel.atomic_model import ( + DPPropertyAtomicModel, ) from deepmd.dpmodel.model.base_model import ( BaseModel, @@ -13,7 +13,7 @@ make_model, ) -DPPropertyModel_ = make_model(DPAtomicModel) +DPPropertyModel_ = make_model(DPPropertyAtomicModel) @BaseModel.register("property") diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index d9744246d7..5aeb84468d 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -779,9 +779,17 @@ def test_property( tuple[list[np.ndarray], list[int]] arrays with results and their shapes """ - data.add("property", dp.task_dim, atomic=False, must=True, high_prec=True) + var_name = dp.get_var_name() + assert isinstance(var_name, str) + data.add(var_name, dp.task_dim, atomic=False, must=True, high_prec=True) if has_atom_property: - data.add("atom_property", dp.task_dim, atomic=True, must=False, high_prec=True) + data.add( + f"atom_{var_name}", + dp.task_dim, + atomic=True, + must=False, + high_prec=True, + ) if dp.get_dim_fparam() > 0: data.add( @@ -832,12 +840,12 @@ def test_property( aproperty = ret[1] aproperty = aproperty.reshape([numb_test, natoms * dp.task_dim]) - diff_property = property - test_data["property"][:numb_test] + diff_property = property - test_data[var_name][:numb_test] mae_property = mae(diff_property) rmse_property = rmse(diff_property) if has_atom_property: - diff_aproperty = aproperty - test_data["atom_property"][:numb_test] + diff_aproperty = aproperty - test_data[f"atom_{var_name}"][:numb_test] mae_aproperty = mae(diff_aproperty) rmse_aproperty = rmse(diff_aproperty) @@ -854,7 +862,7 @@ def test_property( detail_path = Path(detail_file) for ii in range(numb_test): - test_out = test_data["property"][ii].reshape(-1, 1) + test_out = test_data[var_name][ii].reshape(-1, 1) pred_out = property[ii].reshape(-1, 1) frame_output = np.hstack((test_out, pred_out)) @@ -868,7 +876,7 @@ def test_property( if has_atom_property: for ii in range(numb_test): - test_out = test_data["atom_property"][ii].reshape(-1, 1) + test_out = test_data[f"atom_{var_name}"][ii].reshape(-1, 1) pred_out = aproperty[ii].reshape(-1, 1) frame_output = np.hstack((test_out, pred_out)) diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 159f9bdf60..15e4a56280 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -70,8 +70,6 @@ class DeepEvalBackend(ABC): "dipole_derv_c_redu": "virial", "dos": "atom_dos", "dos_redu": "dos", - "property": "atom_property", - "property_redu": "property", "mask_mag": "mask_mag", "mask": "mask", # old models in v1 @@ -276,6 +274,10 @@ def get_has_spin(self) -> bool: """Check if the model has spin atom types.""" return False + def get_var_name(self) -> str: + """Get the name of the fitting property.""" + raise NotImplementedError + @abstractmethod def get_ntypes_spin(self) -> int: """Get the number of spin atom types of this model. Only used in old implement.""" diff --git a/deepmd/infer/deep_property.py b/deepmd/infer/deep_property.py index 389a0e8512..5944491cc0 100644 --- a/deepmd/infer/deep_property.py +++ b/deepmd/infer/deep_property.py @@ -37,25 +37,41 @@ class DeepProperty(DeepEval): Keyword arguments. """ - @property def output_def(self) -> ModelOutputDef: - """Get the output definition of this model.""" - return ModelOutputDef( + """ + Get the output definition of this model. + But in property_fitting, the output definition is not known until the model is loaded. + So we need to rewrite the output definition after the model is loaded. + See detail in change_output_def. + """ + pass + + def change_output_def(self) -> None: + """ + Change the output definition of this model. + In property_fitting, the output definition is known after the model is loaded. + We need to rewrite the output definition and related information. + """ + self.output_def = ModelOutputDef( FittingOutputDef( [ OutputVariableDef( - "property", - shape=[-1], + self.get_var_name(), + shape=[self.get_task_dim()], reducible=True, atomic=True, + intensive=self.get_intensive(), ), ] ) ) - - def change_output_def(self) -> None: - self.output_def["property"].shape = self.task_dim - self.output_def["property"].intensive = self.get_intensive() + self.deep_eval.output_def = self.output_def + self.deep_eval._OUTDEF_DP2BACKEND[self.get_var_name()] = ( + f"atom_{self.get_var_name()}" + ) + self.deep_eval._OUTDEF_DP2BACKEND[f"{self.get_var_name()}_redu"] = ( + self.get_var_name() + ) @property def task_dim(self) -> int: @@ -120,10 +136,12 @@ def eval( aparam=aparam, **kwargs, ) - atomic_property = results["property"].reshape( + atomic_property = results[self.get_var_name()].reshape( nframes, natoms, self.get_task_dim() ) - property = results["property_redu"].reshape(nframes, self.get_task_dim()) + property = results[f"{self.get_var_name()}_redu"].reshape( + nframes, self.get_task_dim() + ) if atomic: return ( @@ -141,5 +159,9 @@ def get_intensive(self) -> bool: """Get whether the property is intensive.""" return self.deep_eval.get_intensive() + def get_var_name(self) -> str: + """Get the name of the fitting property.""" + return self.deep_eval.get_var_name() + __all__ = ["DeepProperty"] diff --git a/deepmd/pd/infer/deep_eval.py b/deepmd/pd/infer/deep_eval.py index a2f8510f28..c31170ad71 100644 --- a/deepmd/pd/infer/deep_eval.py +++ b/deepmd/pd/infer/deep_eval.py @@ -113,6 +113,7 @@ def __init__( else: # self.dp = paddle.jit.load(self.model_path.split(".json")[0]) raise ValueError(f"Unknown model file format: {self.model_path}!") + self.dp.eval() self.rcut = self.dp.model["Default"].get_rcut() self.type_map = self.dp.model["Default"].get_type_map() if isinstance(auto_batch_size, bool): diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 65e35a1c4b..0f3c7a9732 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -588,15 +588,14 @@ def warm_up_linear(step, warmup_steps): if self.opt_type == "Adam": self.scheduler = paddle.optimizer.lr.LambdaDecay( learning_rate=self.lr_exp.start_lr, - lr_lambda=lambda step: warm_up_linear( - step + self.start_step, self.warmup_steps - ), + lr_lambda=lambda step: warm_up_linear(step, self.warmup_steps), ) self.optimizer = paddle.optimizer.Adam( learning_rate=self.scheduler, parameters=self.wrapper.parameters() ) if optimizer_state_dict is not None and self.restart_training: self.optimizer.set_state_dict(optimizer_state_dict) + self.scheduler.last_epoch -= 1 else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") diff --git a/deepmd/pd/utils/auto_batch_size.py b/deepmd/pd/utils/auto_batch_size.py index 8cdb5ddea2..0eb5e46d5f 100644 --- a/deepmd/pd/utils/auto_batch_size.py +++ b/deepmd/pd/utils/auto_batch_size.py @@ -49,12 +49,8 @@ def is_oom_error(self, e: Exception) -> bool: # several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error, # such as https://github.com/JuliaGPU/CUDA.jl/issues/1924 # (the meaningless error message should be considered as a bug in cusolver) - if isinstance(e, RuntimeError) and ( - "CUDA out of memory." in e.args[0] - or "CUDA driver error: out of memory" in e.args[0] - or "cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR" in e.args[0] - ): + if isinstance(e, MemoryError) and ("ResourceExhaustedError" in e.args[0]): # Release all unoccupied cached memory - # paddle.device.cuda.empty_cache() + paddle.device.cuda.empty_cache() return True return False diff --git a/deepmd/pd/utils/region.py b/deepmd/pd/utils/region.py index f3e3eaa52d..d2600ef16e 100644 --- a/deepmd/pd/utils/region.py +++ b/deepmd/pd/utils/region.py @@ -108,5 +108,5 @@ def normalize_coord( """ icoord = phys2inter(coord, cell) - icoord = paddle.remainder(icoord, paddle.full([], 1.0)) + icoord = paddle.remainder(icoord, paddle.full([], 1.0, dtype=icoord.dtype)) return inter2phys(icoord, cell) diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 59b833d34c..facead838e 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -184,6 +184,15 @@ def get_dim_aparam(self) -> int: def get_intensive(self) -> bool: return self.dp.model["Default"].get_intensive() + def get_var_name(self) -> str: + """Get the name of the property.""" + if hasattr(self.dp.model["Default"], "get_var_name") and callable( + getattr(self.dp.model["Default"], "get_var_name") + ): + return self.dp.model["Default"].get_var_name() + else: + raise NotImplementedError + @property def model_type(self) -> type["DeepEvalWrapper"]: """The the evaluator of the model type.""" @@ -200,7 +209,7 @@ def model_type(self) -> type["DeepEvalWrapper"]: return DeepGlobalPolar elif "wfc" in model_output_type: return DeepWFC - elif "property" in model_output_type: + elif self.get_var_name() in model_output_type: return DeepProperty else: raise RuntimeError("Unknown model type") diff --git a/deepmd/pt/loss/property.py b/deepmd/pt/loss/property.py index 07e394650a..9d42c81b45 100644 --- a/deepmd/pt/loss/property.py +++ b/deepmd/pt/loss/property.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +from typing import ( + Union, +) import torch import torch.nn.functional as F @@ -21,9 +24,13 @@ class PropertyLoss(TaskLoss): def __init__( self, task_dim, + var_name: str, loss_func: str = "smooth_mae", metric: list = ["mae"], beta: float = 1.00, + out_bias: Union[list, None] = None, + out_std: Union[list, None] = None, + intensive: bool = False, **kwargs, ) -> None: r"""Construct a layer to compute loss on property. @@ -32,18 +39,32 @@ def __init__( ---------- task_dim : float The output dimension of property fitting net. + var_name : str + The atomic property to fit, 'energy', 'dipole', and 'polar'. loss_func : str The loss function, such as "smooth_mae", "mae", "rmse". metric : list The metric such as mae, rmse which will be printed. - beta: + beta : float The 'beta' parameter in 'smooth_mae' loss. + out_bias : Union[list, None] + It is the average value of the label. The shape is nkeys * ntypes * task_dim. + In property fitting, nkeys = 1, so the shape is 1 * ntypes * task_dim. + out_std : Union[list, None] + It is the standard deviation of the label. The shape is nkeys * ntypes * task_dim. + In property fitting, nkeys = 1, so the shape is 1 * ntypes * task_dim. + intensive : bool + Whether the property is intensive. """ super().__init__() self.task_dim = task_dim self.loss_func = loss_func self.metric = metric self.beta = beta + self.out_bias = out_bias + self.out_std = out_std + self.intensive = intensive + self.var_name = var_name def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False): """Return loss on properties . @@ -69,34 +90,64 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False Other losses for display. """ model_pred = model(**input_dict) - assert label["property"].shape[-1] == self.task_dim - assert model_pred["property"].shape[-1] == self.task_dim + var_name = self.var_name + nbz = model_pred[var_name].shape[0] + assert model_pred[var_name].shape == (nbz, self.task_dim) + assert label[var_name].shape == (nbz, self.task_dim) + if not self.intensive: + model_pred[var_name] = model_pred[var_name] / natoms + label[var_name] = label[var_name] / natoms + + if self.out_std is None: + out_std = model.atomic_model.out_std[0][0] + else: + out_std = torch.tensor( + self.out_std, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + if out_std.shape != (self.task_dim,): + raise ValueError( + f"Expected out_std to have shape ({self.task_dim},), but got {out_std.shape}" + ) + + if self.out_bias is None: + out_bias = model.atomic_model.out_bias[0][0] + else: + out_bias = torch.tensor( + self.out_bias, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + if out_bias.shape != (self.task_dim,): + raise ValueError( + f"Expected out_bias to have shape ({self.task_dim},), but got {out_bias.shape}" + ) + loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] more_loss = {} # loss if self.loss_func == "smooth_mae": loss += F.smooth_l1_loss( - label["property"], - model_pred["property"], + (label[var_name] - out_bias) / out_std, + (model_pred[var_name] - out_bias) / out_std, reduction="sum", beta=self.beta, ) elif self.loss_func == "mae": loss += F.l1_loss( - label["property"], model_pred["property"], reduction="sum" + (label[var_name] - out_bias) / out_std, + (model_pred[var_name] - out_bias) / out_std, + reduction="sum", ) elif self.loss_func == "mse": loss += F.mse_loss( - label["property"], - model_pred["property"], + (label[var_name] - out_bias) / out_std, + (model_pred[var_name] - out_bias) / out_std, reduction="sum", ) elif self.loss_func == "rmse": loss += torch.sqrt( F.mse_loss( - label["property"], - model_pred["property"], + (label[var_name] - out_bias) / out_std, + (model_pred[var_name] - out_bias) / out_std, reduction="mean", ) ) @@ -106,28 +157,28 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False # more loss if "smooth_mae" in self.metric: more_loss["smooth_mae"] = F.smooth_l1_loss( - label["property"], - model_pred["property"], + label[var_name], + model_pred[var_name], reduction="mean", beta=self.beta, ).detach() if "mae" in self.metric: more_loss["mae"] = F.l1_loss( - label["property"], - model_pred["property"], + label[var_name], + model_pred[var_name], reduction="mean", ).detach() if "mse" in self.metric: more_loss["mse"] = F.mse_loss( - label["property"], - model_pred["property"], + label[var_name], + model_pred[var_name], reduction="mean", ).detach() if "rmse" in self.metric: more_loss["rmse"] = torch.sqrt( F.mse_loss( - label["property"], - model_pred["property"], + label[var_name], + model_pred[var_name], reduction="mean", ) ).detach() @@ -140,10 +191,10 @@ def label_requirement(self) -> list[DataRequirementItem]: label_requirement = [] label_requirement.append( DataRequirementItem( - "property", + self.var_name, ndof=self.task_dim, atomic=False, - must=False, + must=True, high_prec=True, ) ) diff --git a/deepmd/pt/loss/tensor.py b/deepmd/pt/loss/tensor.py index 8f2f937a07..69b133de58 100644 --- a/deepmd/pt/loss/tensor.py +++ b/deepmd/pt/loss/tensor.py @@ -22,6 +22,7 @@ def __init__( pref_atomic: float = 0.0, pref: float = 0.0, inference=False, + enable_atomic_weight: bool = False, **kwargs, ) -> None: r"""Construct a loss for local and global tensors. @@ -40,6 +41,8 @@ def __init__( The prefactor of the weight of global loss. It should be larger than or equal to 0. inference : bool If true, it will output all losses found in output, ignoring the pre-factors. + enable_atomic_weight : bool + If true, atomic weight will be used in the loss calculation. **kwargs Other keyword arguments. """ @@ -50,6 +53,7 @@ def __init__( self.local_weight = pref_atomic self.global_weight = pref self.inference = inference + self.enable_atomic_weight = enable_atomic_weight assert ( self.local_weight >= 0.0 and self.global_weight >= 0.0 @@ -85,6 +89,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False """ model_pred = model(**input_dict) del learning_rate, mae + + if self.enable_atomic_weight: + atomic_weight = label["atom_weight"].reshape([-1, 1]) + else: + atomic_weight = 1.0 + loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] more_loss = {} if ( @@ -103,6 +113,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False diff = (local_tensor_pred - local_tensor_label).reshape( [-1, self.tensor_size] ) + diff = diff * atomic_weight if "mask" in model_pred: diff = diff[model_pred["mask"].reshape([-1]).bool()] l2_local_loss = torch.mean(torch.square(diff)) @@ -171,4 +182,15 @@ def label_requirement(self) -> list[DataRequirementItem]: high_prec=False, ) ) + if self.enable_atomic_weight: + label_requirement.append( + DataRequirementItem( + "atomic_weight", + ndof=1, + atomic=True, + must=False, + high_prec=False, + default=1.0, + ) + ) return label_requirement diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index a64eca0fe9..c83e35dab3 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -125,6 +125,14 @@ def get_type_map(self) -> list[str]: """Get the type map.""" return self.type_map + def get_compute_stats_distinguish_types(self) -> bool: + """Get whether the fitting net computes stats which are not distinguished between different types of atoms.""" + return True + + def get_intensive(self) -> bool: + """Whether the fitting property is intensive.""" + return False + def reinit_atom_exclude( self, exclude_types: list[int] = [], @@ -456,7 +464,6 @@ def change_out_bias( model_forward=self._get_forward_wrapper_func(), rcond=self.rcond, preset_bias=self.preset_out_bias, - atomic_output=self.atomic_output_def(), ) self._store_out_stat(delta_bias, out_std, add=True) elif bias_adjust_mode == "set-by-statistic": @@ -467,7 +474,8 @@ def change_out_bias( stat_file_path=stat_file_path, rcond=self.rcond, preset_bias=self.preset_out_bias, - atomic_output=self.atomic_output_def(), + stats_distinguish_types=self.get_compute_stats_distinguish_types(), + intensive=self.get_intensive(), ) self._store_out_stat(bias_out, std_out) else: diff --git a/deepmd/pt/model/atomic_model/property_atomic_model.py b/deepmd/pt/model/atomic_model/property_atomic_model.py index 1fdc72b2b6..3622c9f476 100644 --- a/deepmd/pt/model/atomic_model/property_atomic_model.py +++ b/deepmd/pt/model/atomic_model/property_atomic_model.py @@ -19,31 +19,31 @@ def __init__(self, descriptor, fitting, type_map, **kwargs): ) super().__init__(descriptor, fitting, type_map, **kwargs) + def get_compute_stats_distinguish_types(self) -> bool: + """Get whether the fitting net computes stats which are not distinguished between different types of atoms.""" + return False + + def get_intensive(self) -> bool: + """Whether the fitting property is intensive.""" + return self.fitting_net.get_intensive() + def apply_out_stat( self, ret: dict[str, torch.Tensor], atype: torch.Tensor, ): """Apply the stat to each atomic output. - This function defines how the bias is applied to the atomic output of the model. + In property fitting, each output will be multiplied by label std and then plus the label average value. Parameters ---------- ret The returned dict by the forward_atomic method atype - The atom types. nf x nloc + The atom types. nf x nloc. It is useless in property fitting. """ - if self.fitting_net.get_bias_method() == "normal": - out_bias, out_std = self._fetch_out_stat(self.bias_keys) - for kk in self.bias_keys: - # nf x nloc x odims, out_bias: ntypes x odims - ret[kk] = ret[kk] + out_bias[kk][atype] - return ret - elif self.fitting_net.get_bias_method() == "no_bias": - return ret - else: - raise NotImplementedError( - "Only 'normal' and 'no_bias' is supported for parameter 'bias_method'." - ) + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in self.bias_keys: + ret[kk] = ret[kk] * out_std[kk][0] + out_bias[kk][0] + return ret diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index c8e430960b..f086a346b6 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -100,7 +100,7 @@ def __init__( use_tebd_bias: bool = False, type_map: Optional[list[str]] = None, ) -> None: - r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492. + r"""The DPA-2 descriptor[1]_. Parameters ---------- @@ -147,6 +147,11 @@ def __init__( sw: torch.Tensor The switch function for decaying inverse distance. + References + ---------- + .. [1] Zhang, D., Liu, X., Zhang, X. et al. DPA-2: a + large atomic model as a multi-task learner. npj + Comput Mater 10, 293 (2024). https://doi.org/10.1038/s41524-024-01493-2 """ super().__init__() diff --git a/deepmd/pt/model/model/property_model.py b/deepmd/pt/model/model/property_model.py index 4581a2bc3e..7c50c75ff1 100644 --- a/deepmd/pt/model/model/property_model.py +++ b/deepmd/pt/model/model/property_model.py @@ -37,8 +37,8 @@ def __init__( def translated_output_def(self): out_def_data = self.model_output_def().get_data() output_def = { - "atom_property": out_def_data["property"], - "property": out_def_data["property_redu"], + f"atom_{self.get_var_name()}": out_def_data[self.get_var_name()], + self.get_var_name(): out_def_data[f"{self.get_var_name()}_redu"], } if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] @@ -62,8 +62,8 @@ def forward( do_atomic_virial=do_atomic_virial, ) model_predict = {} - model_predict["atom_property"] = model_ret["property"] - model_predict["property"] = model_ret["property_redu"] + model_predict[f"atom_{self.get_var_name()}"] = model_ret[self.get_var_name()] + model_predict[self.get_var_name()] = model_ret[f"{self.get_var_name()}_redu"] if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] return model_predict @@ -76,7 +76,12 @@ def get_task_dim(self) -> int: @torch.jit.export def get_intensive(self) -> bool: """Get whether the property is intensive.""" - return self.model_output_def()["property"].intensive + return self.model_output_def()[self.get_var_name()].intensive + + @torch.jit.export + def get_var_name(self) -> str: + """Get the name of the property.""" + return self.get_fitting_net().var_name @torch.jit.export def forward_lower( @@ -102,8 +107,8 @@ def forward_lower( extra_nlist_sort=self.need_sorted_nlist_for_lower(), ) model_predict = {} - model_predict["atom_property"] = model_ret["property"] - model_predict["property"] = model_ret["property_redu"] + model_predict[f"atom_{self.get_var_name()}"] = model_ret[self.get_var_name()] + model_predict[self.get_var_name()] = model_ret[f"{self.get_var_name()}_redu"] if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] return model_predict diff --git a/deepmd/pt/model/task/property.py b/deepmd/pt/model/task/property.py index dec0f1447b..c15e60fe04 100644 --- a/deepmd/pt/model/task/property.py +++ b/deepmd/pt/model/task/property.py @@ -43,17 +43,16 @@ class PropertyFittingNet(InvarFitting): dim_descrpt : int Embedding width per atom. task_dim : int - The dimension of outputs of fitting net. + The dimension of outputs of fitting net. + property_name: + The name of fitting property, which should be consistent with the property name in the dataset. + If the data file is named `humo.npy`, this parameter should be "humo". neuron : list[int] Number of neurons in each hidden layers of the fitting net. bias_atom_p : torch.Tensor, optional Average property per atom for each element. intensive : bool, optional Whether the fitting property is intensive. - bias_method : str, optional - The method of applying the bias to each atomic output, user can select 'normal' or 'no_bias'. - If 'normal' is used, the computed bias will be added to the atomic output. - If 'no_bias' is used, no bias will be added to the atomic output. resnet_dt : bool Using time-step in the ResNet construction. numb_fparam : int @@ -77,11 +76,11 @@ def __init__( self, ntypes: int, dim_descrpt: int, + property_name: str, task_dim: int = 1, neuron: list[int] = [128, 128, 128], bias_atom_p: Optional[torch.Tensor] = None, intensive: bool = False, - bias_method: str = "normal", resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, @@ -94,9 +93,8 @@ def __init__( ) -> None: self.task_dim = task_dim self.intensive = intensive - self.bias_method = bias_method super().__init__( - var_name="property", + var_name=property_name, ntypes=ntypes, dim_descrpt=dim_descrpt, dim_out=task_dim, @@ -113,9 +111,6 @@ def __init__( **kwargs, ) - def get_bias_method(self) -> str: - return self.bias_method - def output_def(self) -> FittingOutputDef: return FittingOutputDef( [ @@ -130,12 +125,16 @@ def output_def(self) -> FittingOutputDef: ] ) + def get_intensive(self) -> bool: + """Whether the fitting property is intensive.""" + return self.intensive + @classmethod def deserialize(cls, data: dict) -> "PropertyFittingNet": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 3, 1) + check_version_compatibility(data.pop("@version", 1), 4, 1) data.pop("dim_out") - data.pop("var_name") + data["property_name"] = data.pop("var_name") obj = super().deserialize(data) return obj @@ -146,7 +145,9 @@ def serialize(self) -> dict: **InvarFitting.serialize(self), "type": "property", "task_dim": self.task_dim, + "intensive": self.intensive, } + dd["@version"] = 4 return dd diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 8ca510492c..eca952d7f8 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1240,7 +1240,11 @@ def get_loss(loss_params, start_lr, _ntypes, _model): return TensorLoss(**loss_params) elif loss_type == "property": task_dim = _model.get_task_dim() + var_name = _model.get_var_name() + intensive = _model.get_intensive() loss_params["task_dim"] = task_dim + loss_params["var_name"] = var_name + loss_params["intensive"] = intensive return PropertyLoss(**loss_params) else: loss_params["starter_learning_rate"] = start_lr diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 1c5e3f1c52..710d392ac3 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -12,9 +12,6 @@ import numpy as np import torch -from deepmd.dpmodel.output_def import ( - FittingOutputDef, -) from deepmd.pt.utils import ( AtomExcludeMask, ) @@ -27,6 +24,7 @@ to_torch_tensor, ) from deepmd.utils.out_stat import ( + compute_stats_do_not_distinguish_types, compute_stats_from_atomic, compute_stats_from_redu, ) @@ -136,11 +134,16 @@ def _post_process_stat( For global statistics, we do not have the std for each type of atoms, thus fake the output std by ones for all the types. + If the shape of out_std is already the same as out_bias, + we do not need to do anything. """ new_std = {} for kk, vv in out_bias.items(): - new_std[kk] = np.ones_like(vv) + if vv.shape == out_std[kk].shape: + new_std[kk] = out_std[kk] + else: + new_std[kk] = np.ones_like(vv) return out_bias, new_std @@ -242,7 +245,8 @@ def compute_output_stats( rcond: Optional[float] = None, preset_bias: Optional[dict[str, list[Optional[np.ndarray]]]] = None, model_forward: Optional[Callable[..., torch.Tensor]] = None, - atomic_output: Optional[FittingOutputDef] = None, + stats_distinguish_types: bool = True, + intensive: bool = False, ): """ Compute the output statistics (e.g. energy bias) for the fitting net from packed data. @@ -272,8 +276,10 @@ def compute_output_stats( If not None, the model will be utilized to generate the original energy prediction, which will be subtracted from the energy label of the data. The difference will then be used to calculate the delta complement energy bias for each type. - atomic_output : FittingOutputDef, optional - The output of atomic model. + stats_distinguish_types : bool, optional + Whether to distinguish different element types in the statistics. + intensive : bool, optional + Whether the fitting target is intensive. """ # try to restore the bias from stat file bias_atom_e, std_atom_e = _restore_from_file(stat_file_path, keys) @@ -362,7 +368,8 @@ def compute_output_stats( rcond, preset_bias, model_pred_g, - atomic_output, + stats_distinguish_types, + intensive, ) bias_atom_a, std_atom_a = compute_output_stats_atomic( sampled, @@ -405,7 +412,8 @@ def compute_output_stats_global( rcond: Optional[float] = None, preset_bias: Optional[dict[str, list[Optional[np.ndarray]]]] = None, model_pred: Optional[dict[str, np.ndarray]] = None, - atomic_output: Optional[FittingOutputDef] = None, + stats_distinguish_types: bool = True, + intensive: bool = False, ): """This function only handle stat computation from reduced global labels.""" # return directly if model predict is empty for global @@ -476,19 +484,22 @@ def compute_output_stats_global( std_atom_e = {} for kk in keys: if kk in stats_input: - if atomic_output is not None and atomic_output.get_data()[kk].intensive: - task_dim = stats_input[kk].shape[1] - assert merged_natoms[kk].shape == (nf[kk], ntypes) - stats_input[kk] = ( - merged_natoms[kk].sum(axis=1).reshape(-1, 1) * stats_input[kk] + if not stats_distinguish_types: + bias_atom_e[kk], std_atom_e[kk] = ( + compute_stats_do_not_distinguish_types( + stats_input[kk], + merged_natoms[kk], + assigned_bias=assigned_atom_ener[kk], + intensive=intensive, + ) + ) + else: + bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_redu( + stats_input[kk], + merged_natoms[kk], + assigned_bias=assigned_atom_ener[kk], + rcond=rcond, ) - assert stats_input[kk].shape == (nf[kk], task_dim) - bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_redu( - stats_input[kk], - merged_natoms[kk], - assigned_bias=assigned_atom_ener[kk], - rcond=rcond, - ) else: # this key does not have global labels, skip it. continue diff --git a/deepmd/tf/loss/tensor.py b/deepmd/tf/loss/tensor.py index aca9182ff6..d7f879b4b4 100644 --- a/deepmd/tf/loss/tensor.py +++ b/deepmd/tf/loss/tensor.py @@ -40,6 +40,7 @@ def __init__(self, jdata, **kwarg) -> None: # YWolfeee: modify, use pref / pref_atomic, instead of pref_weight / pref_atomic_weight self.local_weight = jdata.get("pref_atomic", None) self.global_weight = jdata.get("pref", None) + self.enable_atomic_weight = jdata.get("enable_atomic_weight", False) assert ( self.local_weight is not None and self.global_weight is not None @@ -66,9 +67,18 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix): "global_loss": global_cvt_2_tf_float(0.0), } + if self.enable_atomic_weight: + atomic_weight = tf.reshape(label_dict["atom_weight"], [-1, 1]) + else: + atomic_weight = global_cvt_2_tf_float(1.0) + if self.local_weight > 0.0: + diff = tf.reshape(polar, [-1, self.tensor_size]) - tf.reshape( + atomic_polar_hat, [-1, self.tensor_size] + ) + diff = diff * atomic_weight local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean( - tf.square(self.scale * (polar - atomic_polar_hat)), name="l2_" + suffix + tf.square(self.scale * diff), name="l2_" + suffix ) more_loss["local_loss"] = self.display_if_exist(local_loss, find_atomic) l2_loss += self.local_weight * local_loss @@ -163,4 +173,16 @@ def label_requirement(self) -> list[DataRequirementItem]: type_sel=self.type_sel, ) ) + if self.enable_atomic_weight: + data_requirements.append( + DataRequirementItem( + "atom_weight", + 1, + atomic=True, + must=False, + high_prec=False, + default=1.0, + type_sel=self.type_sel, + ) + ) return data_requirements diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0a8d61fe7e..b076952831 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1458,6 +1458,10 @@ def dpa3_repflow_args(): "The head number of multiple edge messages to update node feature. " "Default is 1, indicating one head edge message." ) + doc_n_multi_edge_message = ( + "The head number of multiple edge messages to update node feature. " + "Default is 1, indicating one head edge message." + ) doc_axis_neuron = "The number of dimension of submatrix in the symmetrization ops." doc_update_angle = ( "Where to update the angle rep. If not, only node and edge rep will be used." @@ -1786,7 +1790,7 @@ def fitting_property(): doc_seed = "Random seed for parameter initialization of the fitting net" doc_task_dim = "The dimension of outputs of fitting net" doc_intensive = "Whether the fitting property is intensive" - doc_bias_method = "The method of applying the bias to each atomic output, user can select 'normal' or 'no_bias'. If 'no_bias' is used, no bias will be added to the atomic output." + doc_property_name = "The names of fitting property, which should be consistent with the property name in the dataset." return [ Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam), Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), @@ -1818,7 +1822,10 @@ def fitting_property(): Argument("task_dim", int, optional=True, default=1, doc=doc_task_dim), Argument("intensive", bool, optional=True, default=False, doc=doc_intensive), Argument( - "bias_method", str, optional=True, default="normal", doc=doc_bias_method + "property_name", + str, + optional=False, + doc=doc_property_name, ), ] @@ -2723,8 +2730,9 @@ def loss_property(): def loss_tensor(): # doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If only `pref` is provided or both are not provided, training will be global mode, i.e. the shape of 'polarizability.npy` or `dipole.npy` should be #frams x [9 or 3]." # doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If only `pref_atomic` is provided, training will be atomic mode, i.e. the shape of `polarizability.npy` or `dipole.npy` should be #frames x ([9 or 3] x #selected atoms). If both `pref` and `pref_atomic` are provided, training will be combined mode, and atomic label should be provided as well." - doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included." - doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #selected atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0." + doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included." + doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0." + doc_enable_atomic_weight = "If true, the atomic loss will be reweighted." return [ Argument( "pref", [float, int], optional=False, default=None, doc=doc_global_weight @@ -2736,6 +2744,13 @@ def loss_tensor(): default=None, doc=doc_local_weight, ), + Argument( + "enable_atomic_weight", + bool, + optional=True, + default=False, + doc=doc_enable_atomic_weight, + ), ] diff --git a/deepmd/utils/out_stat.py b/deepmd/utils/out_stat.py index 4d0d788f8b..ecbd379e2d 100644 --- a/deepmd/utils/out_stat.py +++ b/deepmd/utils/out_stat.py @@ -130,3 +130,64 @@ def compute_stats_from_atomic( output[mask].std(axis=0) if output[mask].size > 0 else np.nan ) return output_bias, output_std + + +def compute_stats_do_not_distinguish_types( + output_redu: np.ndarray, + natoms: np.ndarray, + assigned_bias: Optional[np.ndarray] = None, + intensive: bool = False, +) -> tuple[np.ndarray, np.ndarray]: + """Compute element-independent statistics for property fitting. + + Computes mean and standard deviation of the output, treating all elements equally. + For extensive properties, the output is normalized by the total number of atoms + before computing statistics. + + Parameters + ---------- + output_redu + The reduced output value, shape is [nframes, *(odim0, odim1, ...)]. + natoms + The number of atoms for each atom, shape is [nframes, ntypes]. + Used for normalization of extensive properties and generating uniform bias. + assigned_bias + The assigned output bias, shape is [ntypes, *(odim0, odim1, ...)]. + Set to a tensor of shape (odim0, odim1, ...) filled with nan if the bias + of the type is not assigned. + intensive + Whether the output is intensive or extensive. + If False, the output will be normalized by the total number of atoms before computing statistics. + + Returns + ------- + np.ndarray + The computed output mean(fake bias), shape is [ntypes, *(odim0, odim1, ...)]. + The same bias is used for all atom types. + np.ndarray + The computed output standard deviation, shape is [ntypes, *(odim0, odim1, ...)]. + The same standard deviation is used for all atom types. + """ + natoms = np.array(natoms) # [nf, ntypes] + nf, ntypes = natoms.shape + output_redu = np.array(output_redu) + var_shape = list(output_redu.shape[1:]) + output_redu = output_redu.reshape(nf, -1) + if not intensive: + total_atoms = natoms.sum(axis=1) + output_redu = output_redu / total_atoms[:, np.newaxis] + # check shape + assert output_redu.ndim == 2 + assert natoms.ndim == 2 + assert output_redu.shape[0] == natoms.shape[0] # [nf,1] + + computed_output_bias = np.repeat( + np.mean(output_redu, axis=0)[np.newaxis, :], ntypes, axis=0 + ) + output_std = np.std(output_redu, axis=0) + + computed_output_bias = computed_output_bias.reshape([natoms.shape[1]] + var_shape) # noqa: RUF005 + output_std = output_std.reshape(var_shape) + output_std = np.tile(output_std, (computed_output_bias.shape[0], 1)) + + return computed_output_bias, output_std diff --git a/doc/credits.rst b/doc/credits.rst index 1b39dc1e0e..059746ee0b 100644 --- a/doc/credits.rst +++ b/doc/credits.rst @@ -54,7 +54,7 @@ Cite DeePMD-kit and methods .. bibliography:: :filter: False - Zhang_2023_DPA2 + Zhang_npjComputMater_2024_v10_p293 - If frame-specific parameters (`fparam`, e.g. electronic temperature) is used, diff --git a/doc/development/create-a-model-pt.md b/doc/development/create-a-model-pt.md index 08528cc5f6..7eb75b7026 100644 --- a/doc/development/create-a-model-pt.md +++ b/doc/development/create-a-model-pt.md @@ -180,7 +180,7 @@ The arguments here should be consistent with the class arguments of your new com ## Package new codes You may package new codes into a new Python package if you don't want to contribute it to the main DeePMD-kit repository. -A good example is [DeePMD-GNN](https://github.com/njzjz/deepmd-gnn). +A good example is [DeePMD-GNN](https://gitlab.com/RutgersLBSR/deepmd-gnn). It's crucial to add your new component to `project.entry-points."deepmd.pt"` in `pyproject.toml`: ```toml diff --git a/doc/install/install-from-c-library.md b/doc/install/install-from-c-library.md index d408fb1b67..806be51ca9 100644 --- a/doc/install/install-from-c-library.md +++ b/doc/install/install-from-c-library.md @@ -1,4 +1,4 @@ -# Install from pre-compiled C library {{ tensorflow_icon }}, JAX {{ jax_icon }} +# Install from pre-compiled C library {{ tensorflow_icon }} {{ jax_icon }} :::{note} **Supported backends**: TensorFlow {{ tensorflow_icon }}, JAX {{ jax_icon }} diff --git a/doc/model/dpa2.md b/doc/model/dpa2.md index eb641d6b01..300876bf05 100644 --- a/doc/model/dpa2.md +++ b/doc/model/dpa2.md @@ -4,7 +4,7 @@ **Supported backends**: PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: -The DPA-2 model implementation. See https://arxiv.org/abs/2312.15492 for more details. +The DPA-2 model implementation. See https://doi.org/10.1038/s41524-024-01493-2 for more details. Training example: `examples/water/dpa2/input_torch_medium.json`, see [README](../../examples/water/dpa2/README.md) for inputs in different levels. diff --git a/doc/model/index.rst b/doc/model/index.rst index c067ea4207..5e7ba32486 100644 --- a/doc/model/index.rst +++ b/doc/model/index.rst @@ -16,6 +16,7 @@ Model train-energy-spin train-fitting-tensor train-fitting-dos + train-fitting-property train-se-e2-a-tebd train-se-a-mask train-se-e3-tebd diff --git a/doc/model/train-fitting-property.md b/doc/model/train-fitting-property.md new file mode 100644 index 0000000000..d624a2fa3d --- /dev/null +++ b/doc/model/train-fitting-property.md @@ -0,0 +1,195 @@ +# Fit other properties {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} + +:::{note} +**Supported backends**: PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} +::: + +Here we present an API to DeepProperty model, which can be used to fit other properties like band gap, bulk modulus, critical temperature, etc. + +In this example, we will show you how to train a model to fit properties of `humo`, `lumo` and `band gap`. A complete training input script of the examples can be found in + +```bash +$deepmd_source_dir/examples/property/train +``` + +The training and validation data are also provided our examples. But note that **the data provided along with the examples are of limited amount, and should not be used to train a production model.** + +Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model `, {ref}`learning_rate `, {ref}`loss ` and {ref}`training `. Most keywords remain the same as `ener` mode, and their meaning can be found [here](train-se-atten.md). To fit the `property`, one needs to modify {ref}`model[standard]/fitting_net ` and {ref}`loss `. + +## The fitting Network + +The {ref}`fitting_net ` section tells DP which fitting net to use. + +The JSON of `property` type should be provided like + +```json + "fitting_net" : { + "type": "property", + "intensive": true, + "property_name": "band_prop", + "task_dim": 3, + "neuron": [240,240,240], + "resnet_dt": true, + "fparam": 0, + "seed": 1, + }, +``` + +- `type` specifies which type of fitting net should be used. It should be `property`. +- `intensive` indicates whether the fitting property is intensive. If `intensive` is `true`, the model output is the average of the property contribution of each atom. If `intensive` is `false`, the model output is the sum of the property contribution of each atom. +- `property_name` is the name of the property to be predicted. It should be consistent with the property name in the dataset. In each system, code will read `set.*/{property_name}.npy` file as prediction label if you use NumPy format data. +- `fitting_net/task_dim` is the dimension of model output. It should be consistent with the property dimension in the dataset, which means if the shape of data stored in `set.*/{property_name}.npy` is `batch size * 3`, `fitting_net/task_dim` should be set to 3. +- The rest arguments have the same meaning as they do in `ener` mode. + +## Loss + +DeepProperty supports trainings of the global system (one or more global labels are provided in a frame). For example, when fitting `property`, each frame will provide a `1 x task_dim` vector which gives the fitting properties. + +The loss section should be provided like + +```json + "loss" : { + "type": "property", + "metric": ["mae"], + "loss_func": "smooth_mae" + }, +``` + +- {ref}`type ` should be written as `property` as a distinction from `ener` mode. +- `metric`: The metric for display, which will be printed in `lcurve.out`. This list can include 'smooth_mae', 'mae', 'mse' and 'rmse'. +- `loss_func`: The loss function to minimize, you can use 'mae','smooth_mae', 'mse' and 'rmse'. + +## Training Data Preparation + +The label should be named `{property_name}.npy/raw`, `property_name` is defined by `fitting_net/property_name` in `input.json`. + +To prepare the data, you can use `dpdata` tools, for example: + +```py +import dpdata +import numpy as np +from dpdata.data_type import ( + Axis, + DataType, +) + +property_name = "band_prop" # fittng_net/property_name +task_dim = 3 # fitting_net/task_dim + +# register datatype +datatypes = [ + DataType( + property_name, + np.ndarray, + shape=(Axis.NFRAMES, task_dim), + required=False, + ), +] +datatypes.extend( + [ + DataType( + "energies", + np.ndarray, + shape=(Axis.NFRAMES, 1), + required=False, + ), + DataType( + "forces", + np.ndarray, + shape=(Axis.NFRAMES, Axis.NATOMS, 1), + required=False, + ), + ] +) + +for datatype in datatypes: + dpdata.System.register_data_type(datatype) + dpdata.LabeledSystem.register_data_type(datatype) + +ls = dpdata.MultiSystems() +frame = dpdata.System("POSCAR", fmt="vasp/poscar") +labelframe = dpdata.LabeledSystem() +labelframe.append(frame) +labelframe.data[property_name] = np.array([[-0.236, 0.056, 0.292]], dtype=np.float32) +ls.append(labelframe) +ls.to_deepmd_npy_mixed("deepmd") +``` + +## Train the Model + +The training command is the same as `ener` mode, i.e. + +::::{tab-set} + +:::{tab-item} PyTorch {{ pytorch_icon }} + +```bash +dp --pt train input.json +``` + +::: + +:::: + +The detailed loss can be found in `lcurve.out`: + +``` +# step mae_val mae_trn lr +# If there is no available reference data, rmse_*_{val,trn} will print nan + 1 2.72e-02 2.40e-02 2.0e-04 + 100 1.79e-02 1.34e-02 2.0e-04 + 200 1.45e-02 1.86e-02 2.0e-04 + 300 1.61e-02 4.90e-03 2.0e-04 + 400 2.04e-02 1.05e-02 2.0e-04 + 500 9.09e-03 1.85e-02 2.0e-04 + 600 1.01e-02 5.63e-03 2.0e-04 + 700 1.10e-02 1.76e-02 2.0e-04 + 800 1.14e-02 1.50e-02 2.0e-04 + 900 9.54e-03 2.70e-02 2.0e-04 + 1000 1.00e-02 2.73e-02 2.0e-04 +``` + +## Test the Model + +We can use `dp test` to infer the properties for given frames. + +::::{tab-set} + +:::{tab-item} PyTorch {{ pytorch_icon }} + +```bash + +dp --pt freeze -o frozen_model.pth + +dp --pt test -m frozen_model.pth -s ../data/data_0/ -d ${output_prefix} -n 100 +``` + +::: + +:::: + +if `dp test -d ${output_prefix}` is specified, the predicted properties for each frame are output in the working directory + +``` +${output_prefix}.property.out.0 ${output_prefix}.property.out.1 ${output_prefix}.property.out.2 ${output_prefix}.property.out.3 +``` + +for `*.property.out.*`, it contains matrix with shape of `(2, task_dim)`, + +``` +# ../data/data_0 - 0: data_property pred_property +-2.449000030755996704e-01 -2.315840660495154801e-01 +6.400000303983688354e-02 5.810663314446311983e-02 +3.088999986648559570e-01 2.917143316092784544e-01 +``` + +## Data Normalization + +When `fitting_net/type` is `ener`, the energy bias layer “$e_{bias}$” adds a constant bias to the atomic energy contribution according to the atomic number.i.e., +$$e_{bias} (Z_i) (MLP(D_i))= MLP(D_i) + e_{bias} (Z_i)$$ + +But when `fitting_net/type` is `property`. The property bias layer is used to normalize the property output of the model.i.e., +$$p_{bias} (MLP(D_i))= MLP(D_i) * std+ mean$$ + +1. `std`: The standard deviation of the property label +2. `mean`: The average value of the property label diff --git a/doc/third-party/out-of-deepmd-kit.md b/doc/third-party/out-of-deepmd-kit.md index 12ae5842c7..a04ba9741b 100644 --- a/doc/third-party/out-of-deepmd-kit.md +++ b/doc/third-party/out-of-deepmd-kit.md @@ -6,7 +6,7 @@ The codes of the following interfaces are not a part of the DeePMD-kit package a ### External GNN models (MACE/NequIP) -[DeePMD-GNN](https://github.com/njzjz/deepmd-gnn) is DeePMD-kit plugin for various graph neural network (GNN) models. +[DeePMD-GNN](https://gitlab.com/RutgersLBSR/deepmd-gnn) is DeePMD-kit plugin for various graph neural network (GNN) models. It has interfaced with [MACE](https://github.com/ACEsuit/mace) (PyTorch version) and [NequIP](https://github.com/mir-group/nequip) (PyTorch version). It is also the first example to the DeePMD-kit [plugin mechanism](../development/create-a-model-pt.md#package-new-codes). diff --git a/doc/train/finetuning.md b/doc/train/finetuning.md index cf2f5fde4f..04d86cfc98 100644 --- a/doc/train/finetuning.md +++ b/doc/train/finetuning.md @@ -94,7 +94,7 @@ The model section will be overwritten (except the `type_map` subsection) by that #### Fine-tuning from a multi-task pre-trained model -Additionally, within the PyTorch implementation and leveraging the flexibility offered by the framework and the multi-task training process proposed in DPA2 [paper](https://arxiv.org/abs/2312.15492), +Additionally, within the PyTorch implementation and leveraging the flexibility offered by the framework and the multi-task training process proposed in DPA2 [paper](https://doi.org/10.1038/s41524-024-01493-2), we also support more general multitask pre-trained models, which includes multiple datasets for pre-training. These pre-training datasets share a common descriptor while maintaining their individual fitting nets, as detailed in the paper above. diff --git a/doc/train/multi-task-training.md b/doc/train/multi-task-training.md index 51dffcc5f5..16f6c0e05c 100644 --- a/doc/train/multi-task-training.md +++ b/doc/train/multi-task-training.md @@ -26,7 +26,7 @@ and the Adam optimizer is executed to minimize $L^{(t)}$ for one step to update In the case of multi-GPU parallel training, different GPUs will independently select their tasks. In the DPA-2 model, this multi-task training framework is adopted.[^1] -[^1]: Duo Zhang, Xinzijian Liu, Xiangyu Zhang, Chengqian Zhang, Chun Cai, Hangrui Bi, Yiming Du, Xuejian Qin, Jiameng Huang, Bowen Li, Yifan Shan, Jinzhe Zeng, Yuzhi Zhang, Siyuan Liu, Yifan Li, Junhan Chang, Xinyan Wang, Shuo Zhou, Jianchuan Liu, Xiaoshan Luo, Zhenyu Wang, Wanrun Jiang, Jing Wu, Yudi Yang, Jiyuan Yang, Manyi Yang, Fu-Qiang Gong, Linshuang Zhang, Mengchao Shi, Fu-Zhi Dai, Darrin M. York, Shi Liu, Tong Zhu, Zhicheng Zhong, Jian Lv, Jun Cheng, Weile Jia, Mohan Chen, Guolin Ke, Weinan E, Linfeng Zhang, Han Wang, [arXiv preprint arXiv:2312.15492 (2023)](https://arxiv.org/abs/2312.15492) licensed under a [Creative Commons Attribution (CC BY) license](http://creativecommons.org/licenses/by/4.0/). +[^1]: Duo Zhang, Xinzijian Liu, Xiangyu Zhang, Chengqian Zhang, Chun Cai, Hangrui Bi, Yiming Du, Xuejian Qin, Anyang Peng, Jiameng Huang, Bowen Li, Yifan Shan, Jinzhe Zeng, Yuzhi Zhang, Siyuan Liu, Yifan Li, Junhan Chang, Xinyan Wang, Shuo Zhou, Jianchuan Liu, Xiaoshan Luo, Zhenyu Wang, Wanrun Jiang, Jing Wu, Yudi Yang, Jiyuan Yang, Manyi Yang, Fu-Qiang Gong, Linshuang Zhang, Mengchao Shi, Fu-Zhi Dai, Darrin M. York, Shi Liu, Tong Zhu, Zhicheng Zhong, Jian Lv, Jun Cheng, Weile Jia, Mohan Chen, Guolin Ke, Weinan E, Linfeng Zhang, Han Wang, DPA-2: a large atomic model as a multi-task learner. npj Comput Mater 10, 293 (2024). [DOI: 10.1038/s41524-024-01493-2](https://doi.org/10.1038/s41524-024-01493-2) licensed under a [Creative Commons Attribution (CC BY) license](http://creativecommons.org/licenses/by/4.0/). Compared with the previous TensorFlow implementation, the new support in PyTorch is more flexible and efficient. In particular, it makes multi-GPU parallel training and even tasks beyond DFT possible, diff --git a/examples/property/data/data_0/set.000000/property.npy b/examples/property/data/data_0/set.000000/band_prop.npy similarity index 100% rename from examples/property/data/data_0/set.000000/property.npy rename to examples/property/data/data_0/set.000000/band_prop.npy diff --git a/examples/property/data/data_1/set.000000/property.npy b/examples/property/data/data_1/set.000000/band_prop.npy similarity index 100% rename from examples/property/data/data_1/set.000000/property.npy rename to examples/property/data/data_1/set.000000/band_prop.npy diff --git a/examples/property/data/data_2/set.000000/property.npy b/examples/property/data/data_2/set.000000/band_prop.npy similarity index 100% rename from examples/property/data/data_2/set.000000/property.npy rename to examples/property/data/data_2/set.000000/band_prop.npy diff --git a/examples/property/train/README.md b/examples/property/train/README.md new file mode 100644 index 0000000000..e4dc9ed704 --- /dev/null +++ b/examples/property/train/README.md @@ -0,0 +1,5 @@ +Some explanations of the parameters in `input.json`: + +1. `fitting_net/property_name` is the name of the property to be predicted. It should be consistent with the property name in the dataset. In each system, code will read `set.*/{property_name}.npy` file as prediction label if you use NumPy format data. +2. `fitting_net/task_dim` is the dimension of model output. It should be consistent with the property dimension in the dataset, which means if the shape of data stored in `set.*/{property_name}.npy` is `batch size * 3`, `fitting_net/task_dim` should be set to 3. +3. `fitting/intensive` indicates whether the fitting property is intensive. If `intensive` is `true`, the model output is the average of the property contribution of each atom. If `intensive` is `false`, the model output is the sum of the property contribution of each atom. diff --git a/examples/property/train/input_torch.json b/examples/property/train/input_torch.json index 33eaa28a07..1e6ce00048 100644 --- a/examples/property/train/input_torch.json +++ b/examples/property/train/input_torch.json @@ -33,6 +33,7 @@ "type": "property", "intensive": true, "task_dim": 3, + "property_name": "band_prop", "neuron": [ 240, 240, @@ -53,6 +54,11 @@ }, "loss": { "type": "property", + "metric": [ + "mae" + ], + "loss_func": "smooth_mae", + "beta": 1.0, "_comment": " that's all" }, "training": { diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index c51ae9a8b4..d3cad083bd 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -390,7 +390,13 @@ static inline void _load_library_path(std::string dso_path) { if (!dso_handle) { throw deepmd::deepmd_exception( dso_path + - " is not found! You can add the library directory to LD_LIBRARY_PATH"); + " is not found or fails to load! You can add the library directory to " + "LD_LIBRARY_PATH." +#ifndef _WIN32 + " Error message: " + + std::string(dlerror()) +#endif + ); } } diff --git a/source/lib/src/gpu/cudart/cudart_stub.cc b/source/lib/src/gpu/cudart/cudart_stub.cc index 8083a0a89d..cfbabd6f5e 100644 --- a/source/lib/src/gpu/cudart/cudart_stub.cc +++ b/source/lib/src/gpu/cudart/cudart_stub.cc @@ -25,6 +25,10 @@ void *DP_cudart_dlopen(char *libname) { #endif if (!dso_handle) { std::cerr << "DeePMD-kit: Cannot find " << libname << std::endl; +#ifndef _WIN32 + std::cerr << "DeePMD-kit: Error message: " << std::string(dlerror()) + << std::endl; +#endif return nullptr; } std::cerr << "DeePMD-kit: Successfully load " << libname << std::endl; diff --git a/source/tests/common/test_out_stat.py b/source/tests/common/test_out_stat.py index c175d7c643..0236c39f22 100644 --- a/source/tests/common/test_out_stat.py +++ b/source/tests/common/test_out_stat.py @@ -4,6 +4,7 @@ import numpy as np from deepmd.utils.out_stat import ( + compute_stats_do_not_distinguish_types, compute_stats_from_atomic, compute_stats_from_redu, ) @@ -89,6 +90,58 @@ def test_compute_stats_from_redu_with_assigned_bias(self) -> None: rtol=1e-7, ) + def test_compute_stats_do_not_distinguish_types_intensive(self) -> None: + """Test compute_stats_property function with intensive scenario.""" + bias, std = compute_stats_do_not_distinguish_types( + self.output_redu, self.natoms, intensive=True + ) + # Test shapes + assert bias.shape == (len(self.mean), self.output_redu.shape[1]) + assert std.shape == (len(self.mean), self.output_redu.shape[1]) + + # Test values + for fake_atom_bias in bias: + np.testing.assert_allclose( + fake_atom_bias, np.mean(self.output_redu, axis=0), rtol=1e-7 + ) + for fake_atom_std in std: + np.testing.assert_allclose( + fake_atom_std, np.std(self.output_redu, axis=0), rtol=1e-7 + ) + + def test_compute_stats_do_not_distinguish_types_extensive(self) -> None: + """Test compute_stats_property function with extensive scenario.""" + bias, std = compute_stats_do_not_distinguish_types( + self.output_redu, self.natoms + ) + # Test shapes + assert bias.shape == (len(self.mean), self.output_redu.shape[1]) + assert std.shape == (len(self.mean), self.output_redu.shape[1]) + + # Test values + for fake_atom_bias in bias: + np.testing.assert_allclose( + fake_atom_bias, + np.array( + [ + 6218.91610282, + 7183.82275736, + 4445.23155934, + 5748.23644722, + 5362.8519454, + ] + ), + rtol=1e-7, + ) + for fake_atom_std in std: + np.testing.assert_allclose( + fake_atom_std, + np.array( + [128.78691576, 36.53743668, 105.82372405, 96.43642486, 33.68885327] + ), + rtol=1e-7, + ) + def test_compute_stats_from_atomic(self) -> None: bias, std = compute_stats_from_atomic(self.output, self.atype) np.testing.assert_allclose(bias, self.mean) diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index 3abd672c88..4c359026c7 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -86,6 +86,7 @@ def data(self) -> dict: "seed": 20240217, "task_dim": task_dim, "intensive": intensive, + "property_name": "foo", } @property @@ -186,7 +187,7 @@ def eval_pt(self, pt_obj: Any) -> Any: aparam=torch.from_numpy(self.aparam).to(device=PT_DEVICE) if numb_aparam else None, - )["property"] + )[pt_obj.var_name] .detach() .cpu() .numpy() @@ -207,7 +208,7 @@ def eval_dp(self, dp_obj: Any) -> Any: self.atype.reshape(1, -1), fparam=self.fparam if numb_fparam else None, aparam=self.aparam if numb_aparam else None, - )["property"] + )[dp_obj.var_name] def eval_jax(self, jax_obj: Any) -> Any: ( @@ -225,7 +226,7 @@ def eval_jax(self, jax_obj: Any) -> Any: jnp.asarray(self.atype.reshape(1, -1)), fparam=jnp.asarray(self.fparam) if numb_fparam else None, aparam=jnp.asarray(self.aparam) if numb_aparam else None, - )["property"] + )[jax_obj.var_name] ) def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: @@ -244,7 +245,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: array_api_strict.asarray(self.atype.reshape(1, -1)), fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None, aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None, - )["property"] + )[array_api_strict_obj.var_name] ) def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: diff --git a/source/tests/consistent/model/test_property.py b/source/tests/consistent/model/test_property.py index 29786fb247..75aded98fd 100644 --- a/source/tests/consistent/model/test_property.py +++ b/source/tests/consistent/model/test_property.py @@ -56,6 +56,7 @@ def data(self) -> dict: "fitting_net": { "type": "property", "neuron": [4, 4, 4], + "property_name": "foo", "resnet_dt": True, "numb_fparam": 0, "precision": "float64", @@ -182,14 +183,15 @@ def eval_jax(self, jax_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... + property_name = self.data["fitting_net"]["property_name"] if backend in {self.RefBackend.DP, self.RefBackend.JAX}: return ( - ret["property_redu"].ravel(), - ret["property"].ravel(), + ret[f"{property_name}_redu"].ravel(), + ret[property_name].ravel(), ) elif backend is self.RefBackend.PT: return ( - ret["property"].ravel(), - ret["atom_property"].ravel(), + ret[property_name].ravel(), + ret[f"atom_{property_name}"].ravel(), ) raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/pd/model/test_permutation.py b/source/tests/pd/model/test_permutation.py index 4543348d3b..135c5ea819 100644 --- a/source/tests/pd/model/test_permutation.py +++ b/source/tests/pd/model/test_permutation.py @@ -331,10 +331,10 @@ }, "fitting_net": { "type": "property", + "property_name": "band_property", "task_dim": 3, "neuron": [24, 24, 24], "resnet_dt": True, - "bias_method": "normal", "intensive": True, "seed": 1, }, diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index 5c7b8db9a4..e4eb47a540 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -331,9 +331,9 @@ "fitting_net": { "type": "property", "task_dim": 3, + "property_name": "band_property", "neuron": [24, 24, 24], "resnet_dt": True, - "bias_method": "normal", "intensive": True, "seed": 1, }, diff --git a/source/tests/pt/model/test_property_fitting.py b/source/tests/pt/model/test_property_fitting.py index 305d1be951..6825924bc1 100644 --- a/source/tests/pt/model/test_property_fitting.py +++ b/source/tests/pt/model/test_property_fitting.py @@ -61,7 +61,7 @@ def test_consistency( self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE ) - for nfp, nap, bias_atom_p, intensive, bias_method in itertools.product( + for nfp, nap, bias_atom_p, intensive in itertools.product( [0, 3], [0, 4], [ @@ -69,18 +69,17 @@ def test_consistency( np.array([[11, 12, 13, 4, 15], [16, 17, 18, 9, 20]]), ], [True, False], - ["normal", "no_bias"], ): ft0 = PropertyFittingNet( self.nt, self.dd0.dim_out, task_dim=5, + property_name="foo", numb_fparam=nfp, numb_aparam=nap, mixed_types=self.dd0.mixed_types(), bias_atom_p=bias_atom_p, intensive=intensive, - bias_method=bias_method, seed=GLOBAL_SEED, ).to(env.DEVICE) @@ -120,36 +119,35 @@ def test_consistency( aparam=to_numpy_array(iap), ) np.testing.assert_allclose( - to_numpy_array(ret0["property"]), - ret1["property"], + to_numpy_array(ret0[ft0.var_name]), + ret1[ft1.var_name], ) np.testing.assert_allclose( - to_numpy_array(ret0["property"]), - to_numpy_array(ret2["property"]), + to_numpy_array(ret0[ft0.var_name]), + to_numpy_array(ret2[ft2.var_name]), ) np.testing.assert_allclose( - to_numpy_array(ret0["property"]), - ret3["property"], + to_numpy_array(ret0[ft0.var_name]), + ret3[ft3.var_name], ) def test_jit( self, ) -> None: - for nfp, nap, intensive, bias_method in itertools.product( + for nfp, nap, intensive in itertools.product( [0, 3], [0, 4], [True, False], - ["normal", "no_bias"], ): ft0 = PropertyFittingNet( self.nt, self.dd0.dim_out, task_dim=5, + property_name="foo", numb_fparam=nfp, numb_aparam=nap, mixed_types=self.dd0.mixed_types(), intensive=intensive, - bias_method=bias_method, seed=GLOBAL_SEED, ).to(env.DEVICE) torch.jit.script(ft0) @@ -201,6 +199,7 @@ def test_trans(self) -> None: self.nt, self.dd0.dim_out, task_dim=11, + property_name="bar", numb_fparam=0, numb_aparam=0, mixed_types=self.dd0.mixed_types(), @@ -229,7 +228,7 @@ def test_trans(self) -> None: ) ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) - res.append(ret0["property"]) + res.append(ret0[ft0.var_name]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -257,21 +256,20 @@ def test_rot(self) -> None: # use larger cell to rotate only coord and shift to the center of cell cell_rot = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) - for nfp, nap, intensive, bias_method in itertools.product( + for nfp, nap, intensive in itertools.product( [0, 3], [0, 4], [True, False], - ["normal", "no_bias"], ): ft0 = PropertyFittingNet( self.nt, self.dd0.dim_out, # dim_descrpt - task_dim=9, + task_dim=5, + property_name="bar", numb_fparam=nfp, numb_aparam=nap, mixed_types=self.dd0.mixed_types(), intensive=intensive, - bias_method=bias_method, seed=GLOBAL_SEED, ).to(env.DEVICE) if nfp > 0: @@ -312,7 +310,7 @@ def test_rot(self) -> None: ) ret0 = ft0(rd0, atype, gr0, fparam=ifp, aparam=iap) - res.append(ret0["property"]) + res.append(ret0[ft0.var_name]) np.testing.assert_allclose( to_numpy_array(res[1]), to_numpy_array(res[0]), @@ -324,6 +322,7 @@ def test_permu(self) -> None: self.nt, self.dd0.dim_out, task_dim=8, + property_name="abc", numb_fparam=0, numb_aparam=0, mixed_types=self.dd0.mixed_types(), @@ -353,7 +352,7 @@ def test_permu(self) -> None: ) ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) - res.append(ret0["property"]) + res.append(ret0[ft0.var_name]) np.testing.assert_allclose( to_numpy_array(res[0][:, idx_perm]), @@ -372,6 +371,7 @@ def test_trans(self) -> None: self.nt, self.dd0.dim_out, task_dim=11, + property_name="foo", numb_fparam=0, numb_aparam=0, mixed_types=self.dd0.mixed_types(), @@ -400,7 +400,7 @@ def test_trans(self) -> None: ) ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) - res.append(ret0["property"]) + res.append(ret0[ft0.var_name]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -422,6 +422,7 @@ def setUp(self) -> None: self.nt, self.dd0.dim_out, task_dim=3, + property_name="bar", numb_fparam=0, numb_aparam=0, mixed_types=self.dd0.mixed_types(), diff --git a/source/tests/pt/property/double/nopbc b/source/tests/pt/property/double/nopbc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/source/tests/pt/property/double/set.000000/band_property.npy b/source/tests/pt/property/double/set.000000/band_property.npy new file mode 100644 index 0000000000..042c1a8b0d Binary files /dev/null and b/source/tests/pt/property/double/set.000000/band_property.npy differ diff --git a/source/tests/pt/property/double/set.000000/coord.npy b/source/tests/pt/property/double/set.000000/coord.npy new file mode 100644 index 0000000000..9c781a81f3 Binary files /dev/null and b/source/tests/pt/property/double/set.000000/coord.npy differ diff --git a/source/tests/pt/property/double/set.000000/real_atom_types.npy b/source/tests/pt/property/double/set.000000/real_atom_types.npy new file mode 100644 index 0000000000..3bfe0abd94 Binary files /dev/null and b/source/tests/pt/property/double/set.000000/real_atom_types.npy differ diff --git a/source/tests/pt/property/double/type.raw b/source/tests/pt/property/double/type.raw new file mode 100644 index 0000000000..d677b495ec --- /dev/null +++ b/source/tests/pt/property/double/type.raw @@ -0,0 +1,20 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 diff --git a/source/tests/pt/property/double/type_map.raw b/source/tests/pt/property/double/type_map.raw new file mode 100644 index 0000000000..c8a39f3a9e --- /dev/null +++ b/source/tests/pt/property/double/type_map.raw @@ -0,0 +1,4 @@ +H +C +N +O diff --git a/source/tests/pt/property/input.json b/source/tests/pt/property/input.json index 4e005f8277..44bc1e6005 100644 --- a/source/tests/pt/property/input.json +++ b/source/tests/pt/property/input.json @@ -27,6 +27,7 @@ "fitting_net": { "type": "property", "intensive": true, + "property_name": "band_property", "task_dim": 3, "neuron": [ 100, diff --git a/source/tests/pt/property/single/set.000000/property.npy b/source/tests/pt/property/single/set.000000/band_property.npy similarity index 100% rename from source/tests/pt/property/single/set.000000/property.npy rename to source/tests/pt/property/single/set.000000/band_property.npy diff --git a/source/tests/pt/test_dp_test.py b/source/tests/pt/test_dp_test.py index dbec472cc0..c2915c7ee7 100644 --- a/source/tests/pt/test_dp_test.py +++ b/source/tests/pt/test_dp_test.py @@ -183,7 +183,7 @@ def test_dp_test_1_frame(self) -> None: pred_property = np.loadtxt(self.detail_file + ".property.out.0")[:, 1] np.testing.assert_almost_equal( pred_property, - to_numpy_array(result["property"])[0], + to_numpy_array(result[model.get_var_name()])[0], ) def tearDown(self) -> None: diff --git a/source/tests/pt/test_loss_tensor.py b/source/tests/pt/test_loss_tensor.py new file mode 100644 index 0000000000..5802c0b775 --- /dev/null +++ b/source/tests/pt/test_loss_tensor.py @@ -0,0 +1,464 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import os +import unittest + +import numpy as np +import tensorflow.compat.v1 as tf +import torch + +tf.disable_eager_execution() +from pathlib import ( + Path, +) + +from deepmd.pt.loss import TensorLoss as PTTensorLoss +from deepmd.pt.utils import ( + dp_random, + env, +) +from deepmd.pt.utils.dataset import ( + DeepmdDataSetForLoader, +) +from deepmd.tf.loss.tensor import TensorLoss as TFTensorLoss +from deepmd.utils.data import ( + DataRequirementItem, +) + +from ..seed import ( + GLOBAL_SEED, +) + +CUR_DIR = os.path.dirname(__file__) + + +def get_batch(system, type_map, data_requirement): + dataset = DeepmdDataSetForLoader(system, type_map) + dataset.add_data_requirement(data_requirement) + np_batch, pt_batch = get_single_batch(dataset) + return np_batch, pt_batch + + +def get_single_batch(dataset, index=None): + if index is None: + index = dp_random.choice(np.arange(len(dataset))) + np_batch = dataset[index] + pt_batch = {} + + for key in [ + "coord", + "box", + "atom_dipole", + "dipole", + "atom_polarizability", + "polarizability", + "atype", + "natoms", + ]: + if key in np_batch.keys(): + np_batch[key] = np.expand_dims(np_batch[key], axis=0) + pt_batch[key] = torch.as_tensor(np_batch[key], device=env.DEVICE) + if key in ["coord", "atom_dipole"]: + np_batch[key] = np_batch[key].reshape(1, -1) + np_batch["natoms"] = np_batch["natoms"][0] + return np_batch, pt_batch + + +class LossCommonTest(unittest.TestCase): + def setUp(self) -> None: + self.cur_lr = 1.2 + self.type_map = ["H", "O"] + + # data + tensor_data_requirement = [ + DataRequirementItem( + "atomic_" + self.label_name, + ndof=self.tensor_size, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + self.label_name, + ndof=self.tensor_size, + atomic=False, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atomic_weight", + ndof=1, + atomic=True, + must=False, + high_prec=False, + default=1.0, + ), + ] + np_batch, pt_batch = get_batch( + self.system, self.type_map, tensor_data_requirement + ) + natoms = np_batch["natoms"] + self.nloc = natoms[0] + self.nframes = np_batch["atom_" + self.label_name].shape[0] + rng = np.random.default_rng(GLOBAL_SEED) + + l_atomic_tensor, l_global_tensor = ( + np_batch["atom_" + self.label_name], + np_batch[self.label_name], + ) + p_atomic_tensor, p_global_tensor = ( + np.ones_like(l_atomic_tensor), + np.ones_like(l_global_tensor), + ) + + batch_size = pt_batch["coord"].shape[0] + + # atom_pref = rng.random(size=[batch_size, nloc * 3]) + # drdq = rng.random(size=[batch_size, nloc * 2 * 3]) + atom_weight = rng.random(size=[batch_size, self.nloc]) + + # tf + self.g = tf.Graph() + with self.g.as_default(): + t_cur_lr = tf.placeholder(shape=[], dtype=tf.float64) + t_natoms = tf.placeholder(shape=[None], dtype=tf.int32) + t_patomic_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_pglobal_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_latomic_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_lglobal_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_atom_weight = tf.placeholder(shape=[None, None], dtype=tf.float64) + find_atomic = tf.constant(1.0, dtype=tf.float64) + find_global = tf.constant(1.0, dtype=tf.float64) + find_atom_weight = tf.constant(1.0, dtype=tf.float64) + model_dict = { + self.tensor_name: t_patomic_tensor, + } + label_dict = { + "atom_" + self.label_name: t_latomic_tensor, + "find_atom_" + self.label_name: find_atomic, + self.label_name: t_lglobal_tensor, + "find_" + self.label_name: find_global, + "atom_weight": t_atom_weight, + "find_atom_weight": find_atom_weight, + } + self.tf_loss_sess = self.tf_loss.build( + t_cur_lr, t_natoms, model_dict, label_dict, "" + ) + + self.feed_dict = { + t_cur_lr: self.cur_lr, + t_natoms: natoms, + t_patomic_tensor: p_atomic_tensor, + t_pglobal_tensor: p_global_tensor, + t_latomic_tensor: l_atomic_tensor, + t_lglobal_tensor: l_global_tensor, + t_atom_weight: atom_weight, + } + # pt + self.model_pred = { + self.tensor_name: torch.from_numpy(p_atomic_tensor), + "global_" + self.tensor_name: torch.from_numpy(p_global_tensor), + } + self.label = { + "atom_" + self.label_name: torch.from_numpy(l_atomic_tensor), + "find_" + "atom_" + self.label_name: 1.0, + self.label_name: torch.from_numpy(l_global_tensor), + "find_" + self.label_name: 1.0, + "atom_weight": torch.from_numpy(atom_weight), + "find_atom_weight": 1.0, + } + self.label_absent = { + "atom_" + self.label_name: torch.from_numpy(l_atomic_tensor), + self.label_name: torch.from_numpy(l_global_tensor), + "atom_weight": torch.from_numpy(atom_weight), + } + self.natoms = pt_batch["natoms"] + + def tearDown(self) -> None: + tf.reset_default_graph() + return super().tearDown() + + +class TestAtomicDipoleLoss(LossCommonTest): + def setUp(self) -> None: + self.tensor_name = "dipole" + self.tensor_size = 3 + self.label_name = "dipole" + self.system = str(Path(__file__).parent / "water_tensor/dipole/O78H156") + + self.pref_atomic = 1.0 + self.pref = 0.0 + # tf + self.tf_loss = TFTensorLoss( + { + "pref_atomic": self.pref_atomic, + "pref": self.pref, + }, + tensor_name=self.tensor_name, + tensor_size=self.tensor_size, + label_name=self.label_name, + ) + # pt + self.pt_loss = PTTensorLoss( + self.tensor_name, + self.tensor_size, + self.label_name, + self.pref_atomic, + self.pref, + ) + + super().setUp() + + def test_consistency(self) -> None: + with tf.Session(graph=self.g) as sess: + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict + ) + + def fake_model(): + return self.model_pred + + _, pt_loss, pt_more_loss = self.pt_loss( + {}, + fake_model, + self.label, + self.nloc, + self.cur_lr, + ) + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( + {}, + fake_model, + self.label_absent, + self.nloc, + self.cur_lr, + ) + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) + for key in ["local"]: + self.assertTrue( + np.allclose( + tf_more_loss[f"{key}_loss"], + pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"], + ) + ) + self.assertTrue( + np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"]) + ) + + +class TestAtomicDipoleAWeightLoss(LossCommonTest): + def setUp(self) -> None: + self.tensor_name = "dipole" + self.tensor_size = 3 + self.label_name = "dipole" + self.system = str(Path(__file__).parent / "water_tensor/dipole/O78H156") + + self.pref_atomic = 1.0 + self.pref = 0.0 + # tf + self.tf_loss = TFTensorLoss( + { + "pref_atomic": self.pref_atomic, + "pref": self.pref, + "enable_atomic_weight": True, + }, + tensor_name=self.tensor_name, + tensor_size=self.tensor_size, + label_name=self.label_name, + ) + # pt + self.pt_loss = PTTensorLoss( + self.tensor_name, + self.tensor_size, + self.label_name, + self.pref_atomic, + self.pref, + enable_atomic_weight=True, + ) + + super().setUp() + + def test_consistency(self) -> None: + with tf.Session(graph=self.g) as sess: + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict + ) + + def fake_model(): + return self.model_pred + + _, pt_loss, pt_more_loss = self.pt_loss( + {}, + fake_model, + self.label, + self.nloc, + self.cur_lr, + ) + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( + {}, + fake_model, + self.label_absent, + self.nloc, + self.cur_lr, + ) + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) + for key in ["local"]: + self.assertTrue( + np.allclose( + tf_more_loss[f"{key}_loss"], + pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"], + ) + ) + self.assertTrue( + np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"]) + ) + + +class TestAtomicPolarLoss(LossCommonTest): + def setUp(self) -> None: + self.tensor_name = "polar" + self.tensor_size = 9 + self.label_name = "polarizability" + + self.system = str(Path(__file__).parent / "water_tensor/polar/atomic_system") + + self.pref_atomic = 1.0 + self.pref = 0.0 + # tf + self.tf_loss = TFTensorLoss( + { + "pref_atomic": self.pref_atomic, + "pref": self.pref, + }, + tensor_name=self.tensor_name, + tensor_size=self.tensor_size, + label_name=self.label_name, + ) + # pt + self.pt_loss = PTTensorLoss( + self.tensor_name, + self.tensor_size, + self.label_name, + self.pref_atomic, + self.pref, + ) + + super().setUp() + + def test_consistency(self) -> None: + with tf.Session(graph=self.g) as sess: + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict + ) + + def fake_model(): + return self.model_pred + + _, pt_loss, pt_more_loss = self.pt_loss( + {}, + fake_model, + self.label, + self.nloc, + self.cur_lr, + ) + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( + {}, + fake_model, + self.label_absent, + self.nloc, + self.cur_lr, + ) + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) + for key in ["local"]: + self.assertTrue( + np.allclose( + tf_more_loss[f"{key}_loss"], + pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"], + ) + ) + self.assertTrue( + np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"]) + ) + + +class TestAtomicPolarAWeightLoss(LossCommonTest): + def setUp(self) -> None: + self.tensor_name = "polar" + self.tensor_size = 9 + self.label_name = "polarizability" + + self.system = str(Path(__file__).parent / "water_tensor/polar/atomic_system") + + self.pref_atomic = 1.0 + self.pref = 0.0 + # tf + self.tf_loss = TFTensorLoss( + { + "pref_atomic": self.pref_atomic, + "pref": self.pref, + "enable_atomic_weight": True, + }, + tensor_name=self.tensor_name, + tensor_size=self.tensor_size, + label_name=self.label_name, + ) + # pt + self.pt_loss = PTTensorLoss( + self.tensor_name, + self.tensor_size, + self.label_name, + self.pref_atomic, + self.pref, + enable_atomic_weight=True, + ) + + super().setUp() + + def test_consistency(self) -> None: + with tf.Session(graph=self.g) as sess: + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict + ) + + def fake_model(): + return self.model_pred + + _, pt_loss, pt_more_loss = self.pt_loss( + {}, + fake_model, + self.label, + self.nloc, + self.cur_lr, + ) + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( + {}, + fake_model, + self.label_absent, + self.nloc, + self.cur_lr, + ) + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) + for key in ["local"]: + self.assertTrue( + np.allclose( + tf_more_loss[f"{key}_loss"], + pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"], + ) + ) + self.assertTrue( + np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"]) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 1fbd01c39f..ad52c5db16 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -464,7 +464,7 @@ def setUp(self) -> None: property_input = str(Path(__file__).parent / "property/input.json") with open(property_input) as f: self.config_property = json.load(f) - prop_data_file = [str(Path(__file__).parent / "property/single")] + prop_data_file = [str(Path(__file__).parent / "property/double")] self.config_property["training"]["training_data"]["systems"] = prop_data_file self.config_property["training"]["validation_data"]["systems"] = prop_data_file self.config_property["model"]["descriptor"] = deepcopy(model_dpa1["descriptor"]) diff --git a/source/tests/universal/common/cases/model/model.py b/source/tests/universal/common/cases/model/model.py index cee69d9d6c..06ddd90970 100644 --- a/source/tests/universal/common/cases/model/model.py +++ b/source/tests/universal/common/cases/model/model.py @@ -165,7 +165,7 @@ def setUpClass(cls) -> None: cls.expected_dim_aparam = 0 cls.expected_sel_type = [0, 1] cls.expected_aparam_nall = False - cls.expected_model_output_type = ["property", "mask"] + cls.expected_model_output_type = ["band_prop", "mask"] cls.model_output_equivariant = [] cls.expected_sel = [46, 92] cls.expected_sel_mix = sum(cls.expected_sel) diff --git a/source/tests/universal/dpmodel/fitting/test_fitting.py b/source/tests/universal/dpmodel/fitting/test_fitting.py index db199c02a3..2fe0060003 100644 --- a/source/tests/universal/dpmodel/fitting/test_fitting.py +++ b/source/tests/universal/dpmodel/fitting/test_fitting.py @@ -208,6 +208,8 @@ def FittingParamProperty( "dim_descrpt": dim_descrpt, "mixed_types": mixed_types, "type_map": type_map, + "task_dim": 3, + "property_name": "band_prop", "exclude_types": exclude_types, "seed": GLOBAL_SEED, "precision": precision, diff --git a/source/tests/universal/dpmodel/loss/test_loss.py b/source/tests/universal/dpmodel/loss/test_loss.py index 6473c159da..79c67cdba4 100644 --- a/source/tests/universal/dpmodel/loss/test_loss.py +++ b/source/tests/universal/dpmodel/loss/test_loss.py @@ -189,11 +189,14 @@ def LossParamTensor( def LossParamProperty(): key_to_pref_map = { - "property": 1.0, + "foo": 1.0, } input_dict = { "key_to_pref_map": key_to_pref_map, - "task_dim": 2, + "var_name": "foo", + "out_bias": [0.1, 0.5, 1.2, -0.1, -10], + "out_std": [8, 10, 0.001, -0.2, -10], + "task_dim": 5, } return input_dict