From 85bd386b8f43e9fb912da978ae6209fb24ffb3ea Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Wed, 11 Sep 2024 13:12:28 +0800 Subject: [PATCH] chore: support preset bias of atomic model output (#4116) ## Summary by CodeRabbit - **New Features** - Introduced a new `preset_out_bias` parameter for enhanced model configuration, allowing users to define biases. - Added documentation for the `preset_out_bias` parameter in the model arguments for improved clarity. - **Bug Fixes** - Implemented validation to ensure the `preset_out_bias` length matches the model's type map, preventing runtime errors. - **Tests** - Added unit tests for the `get_model` function to validate model attributes and ensure proper error handling for the new bias parameter. --------- Signed-off-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Signed-off-by: Jinzhe Zeng Co-authored-by: Han Wang Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng --- .../model/atomic_model/base_atomic_model.py | 7 +- deepmd/pt/model/model/__init__.py | 18 ++++ deepmd/pt/utils/stat.py | 12 +-- deepmd/utils/argcheck.py | 9 ++ source/tests/pt/model/test_get_model.py | 82 +++++++++++++++++++ 5 files changed, 119 insertions(+), 9 deletions(-) create mode 100644 source/tests/pt/model/test_get_model.py diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index d73c794c73..4742fe66a3 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -11,6 +11,7 @@ Union, ) +import numpy as np import torch from deepmd.dpmodel.atomic_model import ( @@ -66,9 +67,9 @@ class BaseAtomicModel(torch.nn.Module, BaseAtomicModel_): of the atomic model. Implemented by removing the pairs from the nlist. rcond : float, optional The condition number for the regression of atomic energy. - preset_out_bias : Dict[str, List[Optional[torch.Tensor]]], optional + preset_out_bias : Dict[str, List[Optional[np.ndarray]]], optional Specifying atomic energy contribution in vacuum. Given by key:value pairs. - The value is a list specifying the bias. the elements can be None or np.array of output shape. + The value is a list specifying the bias. the elements can be None or np.ndarray of output shape. For example: [None, [2.]] means type 0 is not set, type 1 is set to [2.] The `set_davg_zero` key in the descrptor should be set. @@ -80,7 +81,7 @@ def __init__( atom_exclude_types: List[int] = [], pair_exclude_types: List[Tuple[int, int]] = [], rcond: Optional[float] = None, - preset_out_bias: Optional[Dict[str, torch.Tensor]] = None, + preset_out_bias: Optional[Dict[str, np.ndarray]] = None, ): torch.nn.Module.__init__(self) BaseAtomicModel_.__init__(self) diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index ed9cfcd7c2..9e69b6841e 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -151,6 +151,19 @@ def get_zbl_model(model_params): ) +def _convert_preset_out_bias_to_array(preset_out_bias, type_map): + if preset_out_bias is not None: + for kk in preset_out_bias: + if len(preset_out_bias[kk]) != len(type_map): + raise ValueError( + "length of the preset_out_bias should be the same as the type_map" + ) + for jj in range(len(preset_out_bias[kk])): + if preset_out_bias[kk][jj] is not None: + preset_out_bias[kk][jj] = np.array(preset_out_bias[kk][jj]) + return preset_out_bias + + def get_standard_model(model_params): model_params_old = model_params model_params = copy.deepcopy(model_params) @@ -176,6 +189,10 @@ def get_standard_model(model_params): fitting = BaseFitting(**fitting_net) atom_exclude_types = model_params.get("atom_exclude_types", []) pair_exclude_types = model_params.get("pair_exclude_types", []) + preset_out_bias = model_params.get("preset_out_bias") + preset_out_bias = _convert_preset_out_bias_to_array( + preset_out_bias, model_params["type_map"] + ) if fitting_net["type"] == "dipole": modelcls = DipoleModel @@ -196,6 +213,7 @@ def get_standard_model(model_params): type_map=model_params["type_map"], atom_exclude_types=atom_exclude_types, pair_exclude_types=pair_exclude_types, + preset_out_bias=preset_out_bias, ) model.model_def_script = json.dumps(model_params_old) return model diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 8adf21c127..6de70eb175 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -187,8 +187,8 @@ def model_forward_auto_batch_size(*args, **kwargs): def _make_preset_out_bias( ntypes: int, - ibias: List[Optional[np.array]], -) -> Optional[np.array]: + ibias: List[Optional[np.ndarray]], +) -> Optional[np.ndarray]: """Make preset out bias. output: @@ -242,7 +242,7 @@ def compute_output_stats( keys: Union[str, List[str]] = ["energy"], stat_file_path: Optional[DPPath] = None, rcond: Optional[float] = None, - preset_bias: Optional[Dict[str, List[Optional[torch.Tensor]]]] = None, + preset_bias: Optional[Dict[str, List[Optional[np.ndarray]]]] = None, model_forward: Optional[Callable[..., torch.Tensor]] = None, atomic_output: Optional[FittingOutputDef] = None, ): @@ -264,9 +264,9 @@ def compute_output_stats( The path to the stat file. rcond : float, optional The condition number for the regression of atomic energy. - preset_bias : Dict[str, List[Optional[torch.Tensor]]], optional + preset_bias : Dict[str, List[Optional[np.ndarray]]], optional Specifying atomic energy contribution in vacuum. Given by key:value pairs. - The value is a list specifying the bias. the elements can be None or np.array of output shape. + The value is a list specifying the bias. the elements can be None or np.ndarray of output shape. For example: [None, [2.]] means type 0 is not set, type 1 is set to [2.] The `set_davg_zero` key in the descrptor should be set. model_forward : Callable[..., torch.Tensor], optional @@ -405,7 +405,7 @@ def compute_output_stats_global( ntypes: int, keys: List[str], rcond: Optional[float] = None, - preset_bias: Optional[Dict[str, List[Optional[torch.Tensor]]]] = None, + preset_bias: Optional[Dict[str, List[Optional[np.ndarray]]]] = None, model_pred: Optional[Dict[str, np.ndarray]] = None, atomic_output: Optional[FittingOutputDef] = None, ): diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index c2f483e715..4eab9d87df 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -4,6 +4,7 @@ import warnings from typing import ( Callable, + Dict, List, Optional, Union, @@ -1771,6 +1772,7 @@ def model_args(exclude_hybrid=False): doc_spin = "The settings for systems with spin." doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types" doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other." + doc_preset_out_bias = "The preset bias of the atomic output. Is provided as a dict. Taking the energy model that has three atom types for example, the preset_out_bias may be given as `{ 'energy': [null, 0., 1.] }`. In this case the bias of type 1 and 2 are set to 0. and 1., respectively.The set_davg_zero should be set to true." doc_finetune_head = ( "The chosen fitting net to fine-tune on, when doing multi-task fine-tuning. " "If not set or set to 'RANDOM', the fitting net will be randomly initialized." @@ -1833,6 +1835,13 @@ def model_args(exclude_hybrid=False): default=[], doc=doc_only_pt_supported + doc_atom_exclude_types, ), + Argument( + "preset_out_bias", + Dict[str, Optional[float]], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_preset_out_bias, + ), Argument( "srtab_add_bias", bool, diff --git a/source/tests/pt/model/test_get_model.py b/source/tests/pt/model/test_get_model.py new file mode 100644 index 0000000000..c433597d5a --- /dev/null +++ b/source/tests/pt/model/test_get_model.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import numpy as np +import torch + +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) + +dtype = torch.float64 + +model_se_e2_a = { + "type_map": ["O", "H", "B"], + "descriptor": { + "type": "se_e2_a", + "sel": [46, 92, 4], + "rcut_smth": 0.50, + "rcut": 4.00, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 16, + "seed": 1, + }, + "fitting_net": { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 20, + "atom_exclude_types": [1], + "pair_exclude_types": [[1, 2]], + "preset_out_bias": { + "energy": [ + None, + [1.0], + [3.0], + ] + }, +} + + +class TestGetModel(unittest.TestCase): + def test_model_attr(self): + model_params = copy.deepcopy(model_se_e2_a) + self.model = get_model(model_params).to(env.DEVICE) + atomic_model = self.model.atomic_model + self.assertEqual(atomic_model.type_map, ["O", "H", "B"]) + self.assertEqual( + atomic_model.preset_out_bias, + { + "energy": [ + None, + np.array([1.0]), + np.array([3.0]), + ] + }, + ) + self.assertEqual(atomic_model.atom_exclude_types, [1]) + self.assertEqual(atomic_model.pair_exclude_types, [[1, 2]]) + + def test_notset_model_attr(self): + model_params = copy.deepcopy(model_se_e2_a) + model_params.pop("atom_exclude_types") + model_params.pop("pair_exclude_types") + model_params.pop("preset_out_bias") + self.model = get_model(model_params).to(env.DEVICE) + atomic_model = self.model.atomic_model + self.assertEqual(atomic_model.type_map, ["O", "H", "B"]) + self.assertEqual(atomic_model.preset_out_bias, None) + self.assertEqual(atomic_model.atom_exclude_types, []) + self.assertEqual(atomic_model.pair_exclude_types, []) + + def test_preset_wrong_len(self): + model_params = copy.deepcopy(model_se_e2_a) + model_params["preset_out_bias"] = {"energy": [None]} + with self.assertRaises(ValueError): + self.model = get_model(model_params).to(env.DEVICE)