From 83abc7b36665d49a0553c37f45491a6c10ac661d Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 20 Sep 2024 09:19:17 +0800 Subject: [PATCH 1/6] docs(pt): examples for new dpa2 model (#4138) small: 3 layers; w three-body; wo g2 attn; medium: 6 layers; w three-body; w g2 attn; large: 12 layers; w three-body; w g2 attn; ## Summary by CodeRabbit - **New Features** - Introduced comprehensive JSON configuration files for the DPA2 model, enhancing setup for molecular simulations. - Added detailed README documentation outlining model configurations and input files, aiding user selection based on precision and efficiency needs. - Added parameters for three-body interactions to improve model accuracy. - Configured learning rate settings and loss function preferences for better training dynamics. - **Bug Fixes** - Expanded test coverage by including multiple input file variations for the DPA2 example, ensuring more robust testing. - **Documentation** - Updated training example reference for clarity and included links to README for input variations. --- doc/model/dpa2.md | 2 +- examples/water/dpa2/README.md | 15 +++ ...nput_torch.json => input_torch_large.json} | 24 ++-- examples/water/dpa2/input_torch_medium.json | 112 ++++++++++++++++++ examples/water/dpa2/input_torch_small.json | 112 ++++++++++++++++++ source/tests/common/test_examples.py | 4 +- 6 files changed, 260 insertions(+), 9 deletions(-) create mode 100644 examples/water/dpa2/README.md rename examples/water/dpa2/{input_torch.json => input_torch_large.json} (78%) create mode 100644 examples/water/dpa2/input_torch_medium.json create mode 100644 examples/water/dpa2/input_torch_small.json diff --git a/doc/model/dpa2.md b/doc/model/dpa2.md index 3dd97df6ef..5de30ee6b2 100644 --- a/doc/model/dpa2.md +++ b/doc/model/dpa2.md @@ -6,7 +6,7 @@ The DPA-2 model implementation. See https://arxiv.org/abs/2312.15492 for more details. -Training example: `examples/water/dpa2/input_torch.json`. +Training example: `examples/water/dpa2/input_torch_medium.json`, see [README](../../examples/water/dpa2/README.md) for inputs in different levels. ## Data format diff --git a/examples/water/dpa2/README.md b/examples/water/dpa2/README.md new file mode 100644 index 0000000000..aa37d410a8 --- /dev/null +++ b/examples/water/dpa2/README.md @@ -0,0 +1,15 @@ +## Inputs for DPA-2 model + +This directory contains the input files for training the DPA-2 model (currently supporting PyTorch backend only). Depending on your precision/efficiency requirements, we provide three different levels of model complexity: + +- `input_torch_small.json`: Our smallest DPA-2 model, optimized for speed. +- `input_torch_medium.json` (Recommended): Our well-performing DPA-2 model, balancing efficiency and precision. This is a good starting point for most users. +- `input_torch_large.json`: Our most complex model with the highest precision, suitable for very intricate data structures. 
+ +For detailed differences in their configurations, please refer to the table below: + +| Input | Repformer layers | Three-body embedding in Repinit | Pair-wise attention in Repformer | Tuned sub-structures in [#4089](https://github.com/deepmodeling/deepmd-kit/pull/4089) | Description | +| ------------------------- | ---------------- | ------------------------------- | -------------------------------- | ------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | +| `input_torch_small.json` | 3 | ✓ | ✗ | ✓ | Smallest DPA-2 model, optimized for speed. | +| `input_torch_medium.json` | 6 | ✓ | ✓ | ✓ | Recommended well-performing DPA-2 model, balancing efficiency and precision. | +| `input_torch_large.json` | 12 | ✓ | ✓ | ✓ | Most complex model with the highest precision. | diff --git a/examples/water/dpa2/input_torch.json b/examples/water/dpa2/input_torch_large.json similarity index 78% rename from examples/water/dpa2/input_torch.json rename to examples/water/dpa2/input_torch_large.json index ba8f2e5967..568cbc1a94 100644 --- a/examples/water/dpa2/input_torch.json +++ b/examples/water/dpa2/input_torch_large.json @@ -9,8 +9,8 @@ "type": "dpa2", "repinit": { "tebd_dim": 8, - "rcut": 9.0, - "rcut_smth": 8.0, + "rcut": 6.0, + "rcut_smth": 0.5, "nsel": 120, "neuron": [ 25, @@ -18,7 +18,11 @@ 100 ], "axis_neuron": 12, - "activation_function": "tanh" + "activation_function": "tanh", + "three_body_sel": 40, + "three_body_rcut": 4.0, + "three_body_rcut_smth": 3.5, + "use_three_body": true }, "repformer": { "rcut": 4.0, @@ -36,10 +40,16 @@ "update_g1_has_conv": true, "update_g1_has_grrg": true, "update_g1_has_drrd": true, - "update_g1_has_attn": true, - "update_g2_has_g1g1": true, + "update_g1_has_attn": false, + "update_g2_has_g1g1": false, "update_g2_has_attn": true, - "attn2_has_gate": true + "update_style": "res_residual", + "update_residual": 0.01, + "update_residual_init": "norm", + "attn2_has_gate": true, + "use_sqrt_nnei": true, + "g1_out_conv": true, + "g1_out_mlp": true }, "add_tebd_to_repinit_out": false }, @@ -58,7 +68,7 @@ "learning_rate": { "type": "exp", "decay_steps": 5000, - "start_lr": 0.0002, + "start_lr": 0.001, "stop_lr": 3.51e-08, "_comment": "that's all" }, diff --git a/examples/water/dpa2/input_torch_medium.json b/examples/water/dpa2/input_torch_medium.json new file mode 100644 index 0000000000..5b739e6f27 --- /dev/null +++ b/examples/water/dpa2/input_torch_medium.json @@ -0,0 +1,112 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa2", + "repinit": { + "tebd_dim": 8, + "rcut": 6.0, + "rcut_smth": 0.5, + "nsel": 120, + "neuron": [ + 25, + 50, + 100 + ], + "axis_neuron": 12, + "activation_function": "tanh", + "three_body_sel": 40, + "three_body_rcut": 4.0, + "three_body_rcut_smth": 3.5, + "use_three_body": true + }, + "repformer": { + "rcut": 4.0, + "rcut_smth": 3.5, + "nsel": 40, + "nlayers": 6, + "g1_dim": 128, + "g2_dim": 32, + "attn2_hidden": 32, + "attn2_nhead": 4, + "attn1_hidden": 128, + "attn1_nhead": 4, + "axis_neuron": 4, + "update_h2": false, + "update_g1_has_conv": true, + "update_g1_has_grrg": true, + "update_g1_has_drrd": true, + "update_g1_has_attn": false, + "update_g2_has_g1g1": false, + "update_g2_has_attn": true, + "update_style": "res_residual", + "update_residual": 0.01, + "update_residual_init": "norm", + "attn2_has_gate": true, + "use_sqrt_nnei": true, + "g1_out_conv": true, + "g1_out_mlp": 
true + }, + "add_tebd_to_repinit_out": false + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment": " that's all" + }, + "training": { + "stat_file": "./dpa2.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 2000, + "_comment": "that's all" + } +} diff --git a/examples/water/dpa2/input_torch_small.json b/examples/water/dpa2/input_torch_small.json new file mode 100644 index 0000000000..98147030b6 --- /dev/null +++ b/examples/water/dpa2/input_torch_small.json @@ -0,0 +1,112 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa2", + "repinit": { + "tebd_dim": 8, + "rcut": 6.0, + "rcut_smth": 0.5, + "nsel": 120, + "neuron": [ + 25, + 50, + 100 + ], + "axis_neuron": 12, + "activation_function": "tanh", + "three_body_sel": 40, + "three_body_rcut": 4.0, + "three_body_rcut_smth": 3.5, + "use_three_body": true + }, + "repformer": { + "rcut": 4.0, + "rcut_smth": 3.5, + "nsel": 40, + "nlayers": 3, + "g1_dim": 128, + "g2_dim": 32, + "attn2_hidden": 32, + "attn2_nhead": 4, + "attn1_hidden": 128, + "attn1_nhead": 4, + "axis_neuron": 4, + "update_h2": false, + "update_g1_has_conv": true, + "update_g1_has_grrg": true, + "update_g1_has_drrd": true, + "update_g1_has_attn": false, + "update_g2_has_g1g1": false, + "update_g2_has_attn": false, + "update_style": "res_residual", + "update_residual": 0.01, + "update_residual_init": "norm", + "attn2_has_gate": true, + "use_sqrt_nnei": true, + "g1_out_conv": true, + "g1_out_mlp": true + }, + "add_tebd_to_repinit_out": false + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment": " that's all" + }, + "training": { + "stat_file": "./dpa2.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 2000, + "_comment": "that's all" + } +} diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index 03fe112995..6abb482824 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -52,7 +52,9 @@ 
p_examples / "dprc" / "generalized_force" / "input.json", p_examples / "water" / "se_e2_a" / "input_torch.json", p_examples / "water" / "se_atten" / "input_torch.json", - p_examples / "water" / "dpa2" / "input_torch.json", + p_examples / "water" / "dpa2" / "input_torch_small.json", + p_examples / "water" / "dpa2" / "input_torch_medium.json", + p_examples / "water" / "dpa2" / "input_torch_large.json", p_examples / "property" / "train" / "input_torch.json", p_examples / "water" / "se_e3_tebd" / "input_torch.json", ) From c084b2015a978b954cfce8d31c1a26cdef00da6c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 19 Sep 2024 21:26:57 -0400 Subject: [PATCH 2/6] fix(pt): move entry point from `deepmd.pt.model` to `deepmd.pt` (#4146) `deepmd.pt` should be a more reasonable one to load the plugins. ## Summary by CodeRabbit - **New Features** - Enhanced initialization process for the `deepmd.pt` module, allowing for improved dynamic loading and configuration. - **Refactor** - Removed unnecessary imports and function calls related to entry point loading in the `deepmd.pt.model` module, streamlining the module's initialization. Signed-off-by: Jinzhe Zeng --- deepmd/pt/__init__.py | 5 +++++ deepmd/pt/model/__init__.py | 5 ----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/__init__.py b/deepmd/pt/__init__.py index ab61736198..daf3d406e9 100644 --- a/deepmd/pt/__init__.py +++ b/deepmd/pt/__init__.py @@ -4,6 +4,11 @@ from deepmd.pt.cxx_op import ( ENABLE_CUSTOMIZED_OP, ) +from deepmd.utils.entry_point import ( + load_entry_point, +) + +load_entry_point("deepmd.pt") __all__ = [ "ENABLE_CUSTOMIZED_OP", diff --git a/deepmd/pt/model/__init__.py b/deepmd/pt/model/__init__.py index 8422ac3802..6ceb116d85 100644 --- a/deepmd/pt/model/__init__.py +++ b/deepmd/pt/model/__init__.py @@ -1,6 +1 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from deepmd.utils.entry_point import ( - load_entry_point, -) - -load_entry_point("deepmd.pt") From 81b9d206163aa4f24e7249419673fd5464374869 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20L=C3=B8land=20Bore?= <31211370+sigbjobo@users.noreply.github.com> Date: Fri, 20 Sep 2024 22:43:57 +0200 Subject: [PATCH 3/6] Set ROCM_ROOT to ROCM_PATH when it exist (#4150) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This small commit closes issue #4149 by checking if `rocm_root` is not set, and then inferring it from `rocm_path`, if it exists. ## Summary by CodeRabbit - **New Features** - Enhanced flexibility in retrieving the ROCM root directory by checking both `ROCM_ROOT` and `ROCM_PATH` environment variables. - Updated installation documentation to clarify the fallback mechanism for locating the ROCM toolkit, ensuring users have clear guidance on configuration options. 
--------- Signed-off-by: Sigbjørn Løland Bore <31211370+sigbjobo@users.noreply.github.com> --- backend/read_env.py | 2 ++ doc/install/install-from-source.md | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/backend/read_env.py b/backend/read_env.py index c3fe2d5127..ae82778f4e 100644 --- a/backend/read_env.py +++ b/backend/read_env.py @@ -60,6 +60,8 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str, str]: cmake_minimum_required_version = "3.21" cmake_args.append("-DUSE_ROCM_TOOLKIT:BOOL=TRUE") rocm_root = os.environ.get("ROCM_ROOT") + if not rocm_root: + rocm_root = os.environ.get("ROCM_PATH") if rocm_root: cmake_args.append(f"-DCMAKE_HIP_COMPILER_ROCM_ROOT:STRING={rocm_root}") hipcc_flags = os.environ.get("HIP_HIPCC_FLAGS") diff --git a/doc/install/install-from-source.md b/doc/install/install-from-source.md index 6f17a272c6..a725be0133 100644 --- a/doc/install/install-from-source.md +++ b/doc/install/install-from-source.md @@ -155,7 +155,8 @@ The path to the CUDA toolkit directory. CUDA 9.0 or later is supported. NVCC is **Type**: Path; **Default**: Detected automatically -The path to the ROCM toolkit directory. +The path to the ROCM toolkit directory. If `ROCM_ROOT` is not set, it will look for `ROCM_PATH`; if `ROCM_PATH` is also not set, it will be detected using `hipconfig --rocmpath`. + ::: :::{envvar} DP_ENABLE_TENSORFLOW From d224fdd5a44c8227d09feeee88083c3b122f02a1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 20 Sep 2024 23:34:45 -0400 Subject: [PATCH 4/6] fix: fix `DPH5Path.glob` for new keys (#4152) Fix #4151. ## Summary by CodeRabbit - **New Features** - Enhanced path filtering logic to include a broader range of keys when generating subpaths. - **Bug Fixes** - Improved the accuracy of path results returned by the `glob` method. Signed-off-by: Jinzhe Zeng --- deepmd/utils/path.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index 377953cc35..e794a36cab 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import itertools import os from abc import ( ABC, @@ -373,7 +374,11 @@ def glob(self, pattern: str) -> List["DPPath"]: list of paths """ # got paths starts with current path first, which is faster - subpaths = [ii for ii in self._keys if ii.startswith(self._name)] + subpaths = [ + ii + for ii in itertools.chain(self._keys, self._new_keys) + if ii.startswith(self._name) + ] return [ type(self)(f"{self.root_path}#{pp}", mode=self.mode) for pp in globfilter(subpaths, self._connect_path(pattern)) From 532e30952311a64bc830f3db65836c3481b72886 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sat, 21 Sep 2024 11:59:31 +0800 Subject: [PATCH 5/6] fix(pt): make `state_dict` safe for `weights_only` (#4148) See #4147 and #4143. We can first make `state_dict` safe for `weights_only`, then make a breaking change when loading `state_dict` in the future. ## Summary by CodeRabbit - **New Features** - Enhanced model saving functionality by ensuring learning rates are consistently stored as floats, improving type consistency. - **Bug Fixes** - Updated model loading behavior in tests to focus solely on model weights, which may resolve issues related to state dictionary loading. 
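With the learning rate stored as a plain float, the checkpoint contains only tensors and primitive containers, so it can be read back in PyTorch's `weights_only` mode. A minimal sketch (the checkpoint path is a placeholder; the `"model"`/`"optimizer"` keys follow `save_model` in the diff below):

```python
import torch

# weights_only=True restricts unpickling to tensors and primitive types,
# which the saved checkpoint now satisfies.
state_dict = torch.load("model.ckpt.pt", map_location="cpu", weights_only=True)
model_state = state_dict["model"]
optimizer_state = state_dict["optimizer"]
```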
--- deepmd/pt/train/training.py | 7 +++++-- source/tests/pt/test_change_bias.py | 10 +++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index a7b9e25b4e..c3d603dadd 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1030,10 +1030,13 @@ def save_model(self, save_path, lr=0.0, step=0): if dist.is_available() and dist.is_initialized() else self.wrapper ) - module.train_infos["lr"] = lr + module.train_infos["lr"] = float(lr) module.train_infos["step"] = step + optim_state_dict = deepcopy(self.optimizer.state_dict()) + for item in optim_state_dict["param_groups"]: + item["lr"] = float(item["lr"]) torch.save( - {"model": module.state_dict(), "optimizer": self.optimizer.state_dict()}, + {"model": module.state_dict(), "optimizer": optim_state_dict}, save_path, ) checkpoint_dir = save_path.parent diff --git a/source/tests/pt/test_change_bias.py b/source/tests/pt/test_change_bias.py index f76be40b3f..febc439f50 100644 --- a/source/tests/pt/test_change_bias.py +++ b/source/tests/pt/test_change_bias.py @@ -92,7 +92,9 @@ def test_change_bias_with_data(self): run_dp( f"dp --pt change-bias {self.model_path!s} -s {self.data_file[0]} -o {self.model_path_data_bias!s}" ) - state_dict = torch.load(str(self.model_path_data_bias), map_location=DEVICE) + state_dict = torch.load( + str(self.model_path_data_bias), map_location=DEVICE, weights_only=True + ) model_params = state_dict["model"]["_extra_state"]["model_params"] model_for_wrapper = get_model_for_wrapper(model_params) wrapper = ModelWrapper(model_for_wrapper) @@ -114,7 +116,7 @@ def test_change_bias_with_data_sys_file(self): f"dp --pt change-bias {self.model_path!s} -f {tmp_file.name} -o {self.model_path_data_file_bias!s}" ) state_dict = torch.load( - str(self.model_path_data_file_bias), map_location=DEVICE + str(self.model_path_data_file_bias), map_location=DEVICE, weights_only=True ) model_params = state_dict["model"]["_extra_state"]["model_params"] model_for_wrapper = get_model_for_wrapper(model_params) @@ -134,7 +136,9 @@ def test_change_bias_with_user_defined(self): run_dp( f"dp --pt change-bias {self.model_path!s} -b {' '.join([str(_) for _ in user_bias])} -o {self.model_path_user_bias!s}" ) - state_dict = torch.load(str(self.model_path_user_bias), map_location=DEVICE) + state_dict = torch.load( + str(self.model_path_user_bias), map_location=DEVICE, weights_only=True + ) model_params = state_dict["model"]["_extra_state"]["model_params"] model_for_wrapper = get_model_for_wrapper(model_params) wrapper = ModelWrapper(model_for_wrapper) From 6010c7305c551f78c7e9b6ab55984b699578a920 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 21 Sep 2024 10:44:01 -0400 Subject: [PATCH 6/6] chore(pt): move `deepmd.pt.infer.deep_eval.eval_model` to tests (#4153) Per discussion in https://github.com/deepmodeling/deepmd-kit/pull/4142#issuecomment-2359848991. It should not be a public API as it lacks maintainance. ## Summary by CodeRabbit - **New Features** - Introduced a new `eval_model` function in the testing module to enhance model evaluation capabilities with various input configurations. - **Bug Fixes** - Removed the old `eval_model` function from the main module to streamline functionality and improve code organization. - **Refactor** - Consolidated the import of `eval_model` to a common module across multiple test files for better organization and reduced dependencies. 
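Tests now pull the helper from the shared test module instead of the package. A minimal sketch of the updated usage (a fragment from a module under `source/tests/pt/model/`; `model` is assumed to be built elsewhere via `deepmd.pt.model.model.get_model`, and the frame below is hypothetical):

```python
import numpy as np

# Previously: from deepmd.pt.infer.deep_eval import eval_model
from ..common import eval_model  # relative import inside source/tests/pt/model/

# One hypothetical water-like frame: 3 atoms, types indexed into the model's type_map.
coords = np.random.default_rng(0).random((1, 3, 3))
cells = 10.0 * np.eye(3).reshape(1, 3, 3)
atom_types = np.array([0, 1, 1], dtype=np.int32)

out = eval_model(model, coords, cells, atom_types)  # `model` constructed elsewhere
energy, force, virial = out["energy"], out["force"], out["virial"]
```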
--------- Signed-off-by: Jinzhe Zeng Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> --- deepmd/pt/infer/deep_eval.py | 227 ----------------- source/tests/pt/common.py | 239 ++++++++++++++++++ source/tests/pt/model/test_autodiff.py | 4 +- source/tests/pt/model/test_forward_lower.py | 6 +- source/tests/pt/model/test_null_input.py | 6 +- source/tests/pt/model/test_permutation.py | 6 +- .../pt/model/test_permutation_denoise.py | 6 +- source/tests/pt/model/test_rot.py | 6 +- source/tests/pt/model/test_rot_denoise.py | 6 +- source/tests/pt/model/test_smooth.py | 6 +- source/tests/pt/model/test_smooth_denoise.py | 6 +- source/tests/pt/model/test_trans.py | 6 +- source/tests/pt/model/test_trans_denoise.py | 6 +- source/tests/pt/model/test_unused_params.py | 6 +- 14 files changed, 275 insertions(+), 261 deletions(-) diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 353109d650..d5eae71731 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -602,230 +602,3 @@ def eval_typeebd(self) -> np.ndarray: def get_model_def_script(self) -> str: """Get model defination script.""" return self.model_def_script - - -# For tests only -def eval_model( - model, - coords: Union[np.ndarray, torch.Tensor], - cells: Optional[Union[np.ndarray, torch.Tensor]], - atom_types: Union[np.ndarray, torch.Tensor, List[int]], - spins: Optional[Union[np.ndarray, torch.Tensor]] = None, - atomic: bool = False, - infer_batch_size: int = 2, - denoise: bool = False, -): - model = model.to(DEVICE) - energy_out = [] - atomic_energy_out = [] - force_out = [] - force_mag_out = [] - virial_out = [] - atomic_virial_out = [] - updated_coord_out = [] - logits_out = [] - err_msg = ( - f"All inputs should be the same format, " - f"but found {type(coords)}, {type(cells)}, {type(atom_types)} instead! 
" - ) - return_tensor = True - if isinstance(coords, torch.Tensor): - if cells is not None: - assert isinstance(cells, torch.Tensor), err_msg - if spins is not None: - assert isinstance(spins, torch.Tensor), err_msg - assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list) - atom_types = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) - elif isinstance(coords, np.ndarray): - if cells is not None: - assert isinstance(cells, np.ndarray), err_msg - if spins is not None: - assert isinstance(spins, np.ndarray), err_msg - assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list) - atom_types = np.array(atom_types, dtype=np.int32) - return_tensor = False - - nframes = coords.shape[0] - if len(atom_types.shape) == 1: - natoms = len(atom_types) - if isinstance(atom_types, torch.Tensor): - atom_types = torch.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape( - nframes, -1 - ) - else: - atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) - else: - natoms = len(atom_types[0]) - - coord_input = torch.tensor( - coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - spin_input = None - if spins is not None: - spin_input = torch.tensor( - spins.reshape([-1, natoms, 3]), - dtype=GLOBAL_PT_FLOAT_PRECISION, - device=DEVICE, - ) - has_spin = getattr(model, "has_spin", False) - if callable(has_spin): - has_spin = has_spin() - type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) - box_input = None - if cells is None: - pbc = False - else: - pbc = True - box_input = torch.tensor( - cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size) - - for ii in range(num_iter): - batch_coord = coord_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] - batch_atype = type_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] - batch_box = None - batch_spin = None - if spin_input is not None: - batch_spin = spin_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] - if pbc: - batch_box = box_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] - input_dict = { - "coord": batch_coord, - "atype": batch_atype, - "box": batch_box, - "do_atomic_virial": atomic, - } - if has_spin: - input_dict["spin"] = batch_spin - batch_output = model(**input_dict) - if isinstance(batch_output, tuple): - batch_output = batch_output[0] - if not return_tensor: - if "energy" in batch_output: - energy_out.append(batch_output["energy"].detach().cpu().numpy()) - if "atom_energy" in batch_output: - atomic_energy_out.append( - batch_output["atom_energy"].detach().cpu().numpy() - ) - if "force" in batch_output: - force_out.append(batch_output["force"].detach().cpu().numpy()) - if "force_mag" in batch_output: - force_mag_out.append(batch_output["force_mag"].detach().cpu().numpy()) - if "virial" in batch_output: - virial_out.append(batch_output["virial"].detach().cpu().numpy()) - if "atom_virial" in batch_output: - atomic_virial_out.append( - batch_output["atom_virial"].detach().cpu().numpy() - ) - if "updated_coord" in batch_output: - updated_coord_out.append( - batch_output["updated_coord"].detach().cpu().numpy() - ) - if "logits" in batch_output: - logits_out.append(batch_output["logits"].detach().cpu().numpy()) - else: - if "energy" in batch_output: - energy_out.append(batch_output["energy"]) - if "atom_energy" in batch_output: - atomic_energy_out.append(batch_output["atom_energy"]) - if "force" in batch_output: - 
force_out.append(batch_output["force"]) - if "force_mag" in batch_output: - force_mag_out.append(batch_output["force_mag"]) - if "virial" in batch_output: - virial_out.append(batch_output["virial"]) - if "atom_virial" in batch_output: - atomic_virial_out.append(batch_output["atom_virial"]) - if "updated_coord" in batch_output: - updated_coord_out.append(batch_output["updated_coord"]) - if "logits" in batch_output: - logits_out.append(batch_output["logits"]) - if not return_tensor: - energy_out = ( - np.concatenate(energy_out) if energy_out else np.zeros([nframes, 1]) # pylint: disable=no-explicit-dtype - ) - atomic_energy_out = ( - np.concatenate(atomic_energy_out) - if atomic_energy_out - else np.zeros([nframes, natoms, 1]) # pylint: disable=no-explicit-dtype - ) - force_out = ( - np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3]) # pylint: disable=no-explicit-dtype - ) - force_mag_out = ( - np.concatenate(force_mag_out) - if force_mag_out - else np.zeros([nframes, natoms, 3]) # pylint: disable=no-explicit-dtype - ) - virial_out = ( - np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3]) # pylint: disable=no-explicit-dtype - ) - atomic_virial_out = ( - np.concatenate(atomic_virial_out) - if atomic_virial_out - else np.zeros([nframes, natoms, 3, 3]) # pylint: disable=no-explicit-dtype - ) - updated_coord_out = ( - np.concatenate(updated_coord_out) if updated_coord_out else None - ) - logits_out = np.concatenate(logits_out) if logits_out else None - else: - energy_out = ( - torch.cat(energy_out) - if energy_out - else torch.zeros( - [nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - atomic_energy_out = ( - torch.cat(atomic_energy_out) - if atomic_energy_out - else torch.zeros( - [nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - force_out = ( - torch.cat(force_out) - if force_out - else torch.zeros( - [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - force_mag_out = ( - torch.cat(force_mag_out) - if force_mag_out - else torch.zeros( - [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - virial_out = ( - torch.cat(virial_out) - if virial_out - else torch.zeros( - [nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - atomic_virial_out = ( - torch.cat(atomic_virial_out) - if atomic_virial_out - else torch.zeros( - [nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None - logits_out = torch.cat(logits_out) if logits_out else None - if denoise: - return updated_coord_out, logits_out - else: - results_dict = { - "energy": energy_out, - "force": force_out, - "virial": virial_out, - } - if has_spin: - results_dict["force_mag"] = force_mag_out - if atomic: - results_dict["atom_energy"] = atomic_energy_out - results_dict["atom_virial"] = atomic_virial_out - return results_dict diff --git a/source/tests/pt/common.py b/source/tests/pt/common.py index 8886522360..16b343be8a 100644 --- a/source/tests/pt/common.py +++ b/source/tests/pt/common.py @@ -1,7 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, + Optional, + Union, +) + +import numpy as np +import torch + from deepmd.main import ( main, ) +from deepmd.pt.utils.env import ( + DEVICE, + GLOBAL_PT_FLOAT_PRECISION, +) def run_dp(cmd: str) -> int: @@ -27,3 +40,229 @@ def run_dp(cmd: str) -> int: main(cmds) return 0 + + +def eval_model( + model, + 
coords: Union[np.ndarray, torch.Tensor], + cells: Optional[Union[np.ndarray, torch.Tensor]], + atom_types: Union[np.ndarray, torch.Tensor, List[int]], + spins: Optional[Union[np.ndarray, torch.Tensor]] = None, + atomic: bool = False, + infer_batch_size: int = 2, + denoise: bool = False, +): + model = model.to(DEVICE) + energy_out = [] + atomic_energy_out = [] + force_out = [] + force_mag_out = [] + virial_out = [] + atomic_virial_out = [] + updated_coord_out = [] + logits_out = [] + err_msg = ( + f"All inputs should be the same format, " + f"but found {type(coords)}, {type(cells)}, {type(atom_types)} instead! " + ) + return_tensor = True + if isinstance(coords, torch.Tensor): + if cells is not None: + assert isinstance(cells, torch.Tensor), err_msg + if spins is not None: + assert isinstance(spins, torch.Tensor), err_msg + assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list) + atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE) + elif isinstance(coords, np.ndarray): + if cells is not None: + assert isinstance(cells, np.ndarray), err_msg + if spins is not None: + assert isinstance(spins, np.ndarray), err_msg + assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list) + atom_types = np.array(atom_types, dtype=np.int32) + return_tensor = False + + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + if isinstance(atom_types, torch.Tensor): + atom_types = torch.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape( + nframes, -1 + ) + else: + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) + else: + natoms = len(atom_types[0]) + + coord_input = torch.tensor( + coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + spin_input = None + if spins is not None: + spin_input = torch.tensor( + spins.reshape([-1, natoms, 3]), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + has_spin = getattr(model, "has_spin", False) + if callable(has_spin): + has_spin = has_spin() + type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) + box_input = None + if cells is None: + pbc = False + else: + pbc = True + box_input = torch.tensor( + cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size) + + for ii in range(num_iter): + batch_coord = coord_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] + batch_atype = type_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] + batch_box = None + batch_spin = None + if spin_input is not None: + batch_spin = spin_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] + if pbc: + batch_box = box_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] + input_dict = { + "coord": batch_coord, + "atype": batch_atype, + "box": batch_box, + "do_atomic_virial": atomic, + } + if has_spin: + input_dict["spin"] = batch_spin + batch_output = model(**input_dict) + if isinstance(batch_output, tuple): + batch_output = batch_output[0] + if not return_tensor: + if "energy" in batch_output: + energy_out.append(batch_output["energy"].detach().cpu().numpy()) + if "atom_energy" in batch_output: + atomic_energy_out.append( + batch_output["atom_energy"].detach().cpu().numpy() + ) + if "force" in batch_output: + force_out.append(batch_output["force"].detach().cpu().numpy()) + if "force_mag" in batch_output: + force_mag_out.append(batch_output["force_mag"].detach().cpu().numpy()) + if "virial" in batch_output: + 
virial_out.append(batch_output["virial"].detach().cpu().numpy()) + if "atom_virial" in batch_output: + atomic_virial_out.append( + batch_output["atom_virial"].detach().cpu().numpy() + ) + if "updated_coord" in batch_output: + updated_coord_out.append( + batch_output["updated_coord"].detach().cpu().numpy() + ) + if "logits" in batch_output: + logits_out.append(batch_output["logits"].detach().cpu().numpy()) + else: + if "energy" in batch_output: + energy_out.append(batch_output["energy"]) + if "atom_energy" in batch_output: + atomic_energy_out.append(batch_output["atom_energy"]) + if "force" in batch_output: + force_out.append(batch_output["force"]) + if "force_mag" in batch_output: + force_mag_out.append(batch_output["force_mag"]) + if "virial" in batch_output: + virial_out.append(batch_output["virial"]) + if "atom_virial" in batch_output: + atomic_virial_out.append(batch_output["atom_virial"]) + if "updated_coord" in batch_output: + updated_coord_out.append(batch_output["updated_coord"]) + if "logits" in batch_output: + logits_out.append(batch_output["logits"]) + if not return_tensor: + energy_out = ( + np.concatenate(energy_out) if energy_out else np.zeros([nframes, 1]) # pylint: disable=no-explicit-dtype + ) + atomic_energy_out = ( + np.concatenate(atomic_energy_out) + if atomic_energy_out + else np.zeros([nframes, natoms, 1]) # pylint: disable=no-explicit-dtype + ) + force_out = ( + np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3]) # pylint: disable=no-explicit-dtype + ) + force_mag_out = ( + np.concatenate(force_mag_out) + if force_mag_out + else np.zeros([nframes, natoms, 3]) # pylint: disable=no-explicit-dtype + ) + virial_out = ( + np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3]) # pylint: disable=no-explicit-dtype + ) + atomic_virial_out = ( + np.concatenate(atomic_virial_out) + if atomic_virial_out + else np.zeros([nframes, natoms, 3, 3]) # pylint: disable=no-explicit-dtype + ) + updated_coord_out = ( + np.concatenate(updated_coord_out) if updated_coord_out else None + ) + logits_out = np.concatenate(logits_out) if logits_out else None + else: + energy_out = ( + torch.cat(energy_out) + if energy_out + else torch.zeros( + [nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + atomic_energy_out = ( + torch.cat(atomic_energy_out) + if atomic_energy_out + else torch.zeros( + [nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + force_out = ( + torch.cat(force_out) + if force_out + else torch.zeros( + [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + force_mag_out = ( + torch.cat(force_mag_out) + if force_mag_out + else torch.zeros( + [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + virial_out = ( + torch.cat(virial_out) + if virial_out + else torch.zeros( + [nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + atomic_virial_out = ( + torch.cat(atomic_virial_out) + if atomic_virial_out + else torch.zeros( + [nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None + logits_out = torch.cat(logits_out) if logits_out else None + if denoise: + return updated_coord_out, logits_out + else: + results_dict = { + "energy": energy_out, + "force": force_out, + "virial": virial_out, + } + if has_spin: + results_dict["force_mag"] = force_mag_out + if atomic: + results_dict["atom_energy"] = atomic_energy_out + 
results_dict["atom_virial"] = atomic_virial_out + return results_dict diff --git a/source/tests/pt/model/test_autodiff.py b/source/tests/pt/model/test_autodiff.py index d891583491..1adcff55fc 100644 --- a/source/tests/pt/model/test_autodiff.py +++ b/source/tests/pt/model/test_autodiff.py @@ -21,8 +21,10 @@ dtype = torch.float64 -from .test_permutation import ( +from ..common import ( eval_model, +) +from .test_permutation import ( model_dpa1, model_dpa2, model_hybrid, diff --git a/source/tests/pt/model/test_forward_lower.py b/source/tests/pt/model/test_forward_lower.py index c9857a6343..87a3f5b06e 100644 --- a/source/tests/pt/model/test_forward_lower.py +++ b/source/tests/pt/model/test_forward_lower.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -20,6 +17,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( # model_dpau, model_dpa1, model_dpa2, diff --git a/source/tests/pt/model/test_null_input.py b/source/tests/pt/model/test_null_input.py index 1dca7ee119..a2e0fa66db 100644 --- a/source/tests/pt/model/test_null_input.py +++ b/source/tests/pt/model/test_null_input.py @@ -5,9 +5,6 @@ import numpy as np import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, get_zbl_model, @@ -22,6 +19,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( model_dpa1, model_dpa2, diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index f5edc6ef64..2fbc5fde3c 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -5,9 +5,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -18,6 +15,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) CUR_DIR = os.path.dirname(__file__) diff --git a/source/tests/pt/model/test_permutation_denoise.py b/source/tests/pt/model/test_permutation_denoise.py index 133c48f551..53bf55fb0f 100644 --- a/source/tests/pt/model/test_permutation_denoise.py +++ b/source/tests/pt/model/test_permutation_denoise.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( # model_dpau, model_dpa1, model_dpa2, diff --git a/source/tests/pt/model/test_rot.py b/source/tests/pt/model/test_rot.py index 23bdede923..ca6a6375c8 100644 --- a/source/tests/pt/model/test_rot.py +++ b/source/tests/pt/model/test_rot.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( # model_dpau, model_dos, model_dpa1, diff --git a/source/tests/pt/model/test_rot_denoise.py b/source/tests/pt/model/test_rot_denoise.py index 5fe99a0d7a..9828ba5225 100644 --- a/source/tests/pt/model/test_rot_denoise.py +++ b/source/tests/pt/model/test_rot_denoise.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from 
..common import ( + eval_model, +) from .test_permutation_denoise import ( model_dpa1, model_dpa2, diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py index c33dddfab5..9a7040f9cc 100644 --- a/source/tests/pt/model/test_smooth.py +++ b/source/tests/pt/model/test_smooth.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( # model_dpau, model_dos, model_dpa1, diff --git a/source/tests/pt/model/test_smooth_denoise.py b/source/tests/pt/model/test_smooth_denoise.py index 069c578d52..faa892c5d0 100644 --- a/source/tests/pt/model/test_smooth_denoise.py +++ b/source/tests/pt/model/test_smooth_denoise.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation_denoise import ( model_dpa2, ) diff --git a/source/tests/pt/model/test_trans.py b/source/tests/pt/model/test_trans.py index afd70f8995..b62fac1312 100644 --- a/source/tests/pt/model/test_trans.py +++ b/source/tests/pt/model/test_trans.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( # model_dpau, model_dos, model_dpa1, diff --git a/source/tests/pt/model/test_trans_denoise.py b/source/tests/pt/model/test_trans_denoise.py index 2d31d5de50..84ec21929c 100644 --- a/source/tests/pt/model/test_trans_denoise.py +++ b/source/tests/pt/model/test_trans_denoise.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation_denoise import ( model_dpa1, model_dpa2, diff --git a/source/tests/pt/model/test_unused_params.py b/source/tests/pt/model/test_unused_params.py index e225719e7f..3f068d5e5b 100644 --- a/source/tests/pt/model/test_unused_params.py +++ b/source/tests/pt/model/test_unused_params.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( model_dpa2, )