diff --git a/backend/read_env.py b/backend/read_env.py index c3fe2d5127..ae82778f4e 100644 --- a/backend/read_env.py +++ b/backend/read_env.py @@ -60,6 +60,8 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str, str]: cmake_minimum_required_version = "3.21" cmake_args.append("-DUSE_ROCM_TOOLKIT:BOOL=TRUE") rocm_root = os.environ.get("ROCM_ROOT") + if not rocm_root: + rocm_root = os.environ.get("ROCM_PATH") if rocm_root: cmake_args.append(f"-DCMAKE_HIP_COMPILER_ROCM_ROOT:STRING={rocm_root}") hipcc_flags = os.environ.get("HIP_HIPCC_FLAGS") diff --git a/deepmd/pt/__init__.py b/deepmd/pt/__init__.py index ab61736198..daf3d406e9 100644 --- a/deepmd/pt/__init__.py +++ b/deepmd/pt/__init__.py @@ -4,6 +4,11 @@ from deepmd.pt.cxx_op import ( ENABLE_CUSTOMIZED_OP, ) +from deepmd.utils.entry_point import ( + load_entry_point, +) + +load_entry_point("deepmd.pt") __all__ = [ "ENABLE_CUSTOMIZED_OP", diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index ab404cd8a5..4e5b77f02b 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -604,230 +604,3 @@ def eval_typeebd(self) -> np.ndarray: def get_model_def_script(self) -> str: """Get model defination script.""" return self.model_def_script - - -# For tests only -def eval_model( - model, - coords: Union[np.ndarray, torch.Tensor], - cells: Optional[Union[np.ndarray, torch.Tensor]], - atom_types: Union[np.ndarray, torch.Tensor, List[int]], - spins: Optional[Union[np.ndarray, torch.Tensor]] = None, - atomic: bool = False, - infer_batch_size: int = 2, - denoise: bool = False, -): - model = model.to(DEVICE) - energy_out = [] - atomic_energy_out = [] - force_out = [] - force_mag_out = [] - virial_out = [] - atomic_virial_out = [] - updated_coord_out = [] - logits_out = [] - err_msg = ( - f"All inputs should be the same format, " - f"but found {type(coords)}, {type(cells)}, {type(atom_types)} instead! " - ) - return_tensor = True - if isinstance(coords, torch.Tensor): - if cells is not None: - assert isinstance(cells, torch.Tensor), err_msg - if spins is not None: - assert isinstance(spins, torch.Tensor), err_msg - assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list) - atom_types = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) - elif isinstance(coords, np.ndarray): - if cells is not None: - assert isinstance(cells, np.ndarray), err_msg - if spins is not None: - assert isinstance(spins, np.ndarray), err_msg - assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list) - atom_types = np.array(atom_types, dtype=np.int32) - return_tensor = False - - nframes = coords.shape[0] - if len(atom_types.shape) == 1: - natoms = len(atom_types) - if isinstance(atom_types, torch.Tensor): - atom_types = torch.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape( - nframes, -1 - ) - else: - atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) - else: - natoms = len(atom_types[0]) - - coord_input = torch.tensor( - coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - spin_input = None - if spins is not None: - spin_input = torch.tensor( - spins.reshape([-1, natoms, 3]), - dtype=GLOBAL_PT_FLOAT_PRECISION, - device=DEVICE, - ) - has_spin = getattr(model, "has_spin", False) - if callable(has_spin): - has_spin = has_spin() - type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) - box_input = None - if cells is None: - pbc = False - else: - pbc = True - box_input = torch.tensor( - cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size) - - for ii in range(num_iter): - batch_coord = coord_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] - batch_atype = type_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] - batch_box = None - batch_spin = None - if spin_input is not None: - batch_spin = spin_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] - if pbc: - batch_box = box_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] - input_dict = { - "coord": batch_coord, - "atype": batch_atype, - "box": batch_box, - "do_atomic_virial": atomic, - } - if has_spin: - input_dict["spin"] = batch_spin - batch_output = model(**input_dict) - if isinstance(batch_output, tuple): - batch_output = batch_output[0] - if not return_tensor: - if "energy" in batch_output: - energy_out.append(batch_output["energy"].detach().cpu().numpy()) - if "atom_energy" in batch_output: - atomic_energy_out.append( - batch_output["atom_energy"].detach().cpu().numpy() - ) - if "force" in batch_output: - force_out.append(batch_output["force"].detach().cpu().numpy()) - if "force_mag" in batch_output: - force_mag_out.append(batch_output["force_mag"].detach().cpu().numpy()) - if "virial" in batch_output: - virial_out.append(batch_output["virial"].detach().cpu().numpy()) - if "atom_virial" in batch_output: - atomic_virial_out.append( - batch_output["atom_virial"].detach().cpu().numpy() - ) - if "updated_coord" in batch_output: - updated_coord_out.append( - batch_output["updated_coord"].detach().cpu().numpy() - ) - if "logits" in batch_output: - logits_out.append(batch_output["logits"].detach().cpu().numpy()) - else: - if "energy" in batch_output: - energy_out.append(batch_output["energy"]) - if "atom_energy" in batch_output: - atomic_energy_out.append(batch_output["atom_energy"]) - if "force" in batch_output: - force_out.append(batch_output["force"]) - if "force_mag" in batch_output: - force_mag_out.append(batch_output["force_mag"]) - if "virial" in batch_output: - virial_out.append(batch_output["virial"]) - if "atom_virial" in batch_output: - atomic_virial_out.append(batch_output["atom_virial"]) - if "updated_coord" in batch_output: - updated_coord_out.append(batch_output["updated_coord"]) - if "logits" in batch_output: - logits_out.append(batch_output["logits"]) - if not return_tensor: - energy_out = ( - np.concatenate(energy_out) if energy_out else np.zeros([nframes, 1]) # pylint: disable=no-explicit-dtype - ) - atomic_energy_out = ( - np.concatenate(atomic_energy_out) - if atomic_energy_out - else np.zeros([nframes, natoms, 1]) # pylint: disable=no-explicit-dtype - ) - force_out = ( - np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3]) # pylint: disable=no-explicit-dtype - ) - force_mag_out = ( - np.concatenate(force_mag_out) - if force_mag_out - else np.zeros([nframes, natoms, 3]) # pylint: disable=no-explicit-dtype - ) - virial_out = ( - np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3]) # pylint: disable=no-explicit-dtype - ) - atomic_virial_out = ( - np.concatenate(atomic_virial_out) - if atomic_virial_out - else np.zeros([nframes, natoms, 3, 3]) # pylint: disable=no-explicit-dtype - ) - updated_coord_out = ( - np.concatenate(updated_coord_out) if updated_coord_out else None - ) - logits_out = np.concatenate(logits_out) if logits_out else None - else: - energy_out = ( - torch.cat(energy_out) - if energy_out - else torch.zeros( - [nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - atomic_energy_out = ( - torch.cat(atomic_energy_out) - if atomic_energy_out - else torch.zeros( - [nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - force_out = ( - torch.cat(force_out) - if force_out - else torch.zeros( - [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - force_mag_out = ( - torch.cat(force_mag_out) - if force_mag_out - else torch.zeros( - [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - virial_out = ( - torch.cat(virial_out) - if virial_out - else torch.zeros( - [nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - atomic_virial_out = ( - torch.cat(atomic_virial_out) - if atomic_virial_out - else torch.zeros( - [nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ) - updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None - logits_out = torch.cat(logits_out) if logits_out else None - if denoise: - return updated_coord_out, logits_out - else: - results_dict = { - "energy": energy_out, - "force": force_out, - "virial": virial_out, - } - if has_spin: - results_dict["force_mag"] = force_mag_out - if atomic: - results_dict["atom_energy"] = atomic_energy_out - results_dict["atom_virial"] = atomic_virial_out - return results_dict diff --git a/deepmd/pt/model/__init__.py b/deepmd/pt/model/__init__.py index 8422ac3802..6ceb116d85 100644 --- a/deepmd/pt/model/__init__.py +++ b/deepmd/pt/model/__init__.py @@ -1,6 +1 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from deepmd.utils.entry_point import ( - load_entry_point, -) - -load_entry_point("deepmd.pt") diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index d32b03bd8f..ae160d966c 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1032,10 +1032,13 @@ def save_model(self, save_path, lr=0.0, step=0): if dist.is_available() and dist.is_initialized() else self.wrapper ) - module.train_infos["lr"] = lr + module.train_infos["lr"] = float(lr) module.train_infos["step"] = step + optim_state_dict = deepcopy(self.optimizer.state_dict()) + for item in optim_state_dict["param_groups"]: + item["lr"] = float(item["lr"]) torch.save( - {"model": module.state_dict(), "optimizer": self.optimizer.state_dict()}, + {"model": module.state_dict(), "optimizer": optim_state_dict}, save_path, ) checkpoint_dir = save_path.parent diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index 377953cc35..e794a36cab 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import itertools import os from abc import ( ABC, @@ -373,7 +374,11 @@ def glob(self, pattern: str) -> List["DPPath"]: list of paths """ # got paths starts with current path first, which is faster - subpaths = [ii for ii in self._keys if ii.startswith(self._name)] + subpaths = [ + ii + for ii in itertools.chain(self._keys, self._new_keys) + if ii.startswith(self._name) + ] return [ type(self)(f"{self.root_path}#{pp}", mode=self.mode) for pp in globfilter(subpaths, self._connect_path(pattern)) diff --git a/doc/install/install-from-source.md b/doc/install/install-from-source.md index 6f17a272c6..a725be0133 100644 --- a/doc/install/install-from-source.md +++ b/doc/install/install-from-source.md @@ -155,7 +155,8 @@ The path to the CUDA toolkit directory. CUDA 9.0 or later is supported. NVCC is **Type**: Path; **Default**: Detected automatically -The path to the ROCM toolkit directory. +The path to the ROCM toolkit directory. If `ROCM_ROOT` is not set, it will look for `ROCM_PATH`; if `ROCM_PATH` is also not set, it will be detected using `hipconfig --rocmpath`. + ::: :::{envvar} DP_ENABLE_TENSORFLOW diff --git a/doc/model/dpa2.md b/doc/model/dpa2.md index 3dd97df6ef..5de30ee6b2 100644 --- a/doc/model/dpa2.md +++ b/doc/model/dpa2.md @@ -6,7 +6,7 @@ The DPA-2 model implementation. See https://arxiv.org/abs/2312.15492 for more details. -Training example: `examples/water/dpa2/input_torch.json`. +Training example: `examples/water/dpa2/input_torch_medium.json`, see [README](../../examples/water/dpa2/README.md) for inputs in different levels. ## Data format diff --git a/examples/water/dpa2/README.md b/examples/water/dpa2/README.md new file mode 100644 index 0000000000..aa37d410a8 --- /dev/null +++ b/examples/water/dpa2/README.md @@ -0,0 +1,15 @@ +## Inputs for DPA-2 model + +This directory contains the input files for training the DPA-2 model (currently supporting PyTorch backend only). Depending on your precision/efficiency requirements, we provide three different levels of model complexity: + +- `input_torch_small.json`: Our smallest DPA-2 model, optimized for speed. +- `input_torch_medium.json` (Recommended): Our well-performing DPA-2 model, balancing efficiency and precision. This is a good starting point for most users. +- `input_torch_large.json`: Our most complex model with the highest precision, suitable for very intricate data structures. + +For detailed differences in their configurations, please refer to the table below: + +| Input | Repformer layers | Three-body embedding in Repinit | Pair-wise attention in Repformer | Tuned sub-structures in [#4089](https://github.com/deepmodeling/deepmd-kit/pull/4089) | Description | +| ------------------------- | ---------------- | ------------------------------- | -------------------------------- | ------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | +| `input_torch_small.json` | 3 | ✓ | ✗ | ✓ | Smallest DPA-2 model, optimized for speed. | +| `input_torch_medium.json` | 6 | ✓ | ✓ | ✓ | Recommended well-performing DPA-2 model, balancing efficiency and precision. | +| `input_torch_large.json` | 12 | ✓ | ✓ | ✓ | Most complex model with the highest precision. | diff --git a/examples/water/dpa2/input_torch.json b/examples/water/dpa2/input_torch_large.json similarity index 78% rename from examples/water/dpa2/input_torch.json rename to examples/water/dpa2/input_torch_large.json index ba8f2e5967..568cbc1a94 100644 --- a/examples/water/dpa2/input_torch.json +++ b/examples/water/dpa2/input_torch_large.json @@ -9,8 +9,8 @@ "type": "dpa2", "repinit": { "tebd_dim": 8, - "rcut": 9.0, - "rcut_smth": 8.0, + "rcut": 6.0, + "rcut_smth": 0.5, "nsel": 120, "neuron": [ 25, @@ -18,7 +18,11 @@ 100 ], "axis_neuron": 12, - "activation_function": "tanh" + "activation_function": "tanh", + "three_body_sel": 40, + "three_body_rcut": 4.0, + "three_body_rcut_smth": 3.5, + "use_three_body": true }, "repformer": { "rcut": 4.0, @@ -36,10 +40,16 @@ "update_g1_has_conv": true, "update_g1_has_grrg": true, "update_g1_has_drrd": true, - "update_g1_has_attn": true, - "update_g2_has_g1g1": true, + "update_g1_has_attn": false, + "update_g2_has_g1g1": false, "update_g2_has_attn": true, - "attn2_has_gate": true + "update_style": "res_residual", + "update_residual": 0.01, + "update_residual_init": "norm", + "attn2_has_gate": true, + "use_sqrt_nnei": true, + "g1_out_conv": true, + "g1_out_mlp": true }, "add_tebd_to_repinit_out": false }, @@ -58,7 +68,7 @@ "learning_rate": { "type": "exp", "decay_steps": 5000, - "start_lr": 0.0002, + "start_lr": 0.001, "stop_lr": 3.51e-08, "_comment": "that's all" }, diff --git a/examples/water/dpa2/input_torch_medium.json b/examples/water/dpa2/input_torch_medium.json new file mode 100644 index 0000000000..5b739e6f27 --- /dev/null +++ b/examples/water/dpa2/input_torch_medium.json @@ -0,0 +1,112 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa2", + "repinit": { + "tebd_dim": 8, + "rcut": 6.0, + "rcut_smth": 0.5, + "nsel": 120, + "neuron": [ + 25, + 50, + 100 + ], + "axis_neuron": 12, + "activation_function": "tanh", + "three_body_sel": 40, + "three_body_rcut": 4.0, + "three_body_rcut_smth": 3.5, + "use_three_body": true + }, + "repformer": { + "rcut": 4.0, + "rcut_smth": 3.5, + "nsel": 40, + "nlayers": 6, + "g1_dim": 128, + "g2_dim": 32, + "attn2_hidden": 32, + "attn2_nhead": 4, + "attn1_hidden": 128, + "attn1_nhead": 4, + "axis_neuron": 4, + "update_h2": false, + "update_g1_has_conv": true, + "update_g1_has_grrg": true, + "update_g1_has_drrd": true, + "update_g1_has_attn": false, + "update_g2_has_g1g1": false, + "update_g2_has_attn": true, + "update_style": "res_residual", + "update_residual": 0.01, + "update_residual_init": "norm", + "attn2_has_gate": true, + "use_sqrt_nnei": true, + "g1_out_conv": true, + "g1_out_mlp": true + }, + "add_tebd_to_repinit_out": false + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment": " that's all" + }, + "training": { + "stat_file": "./dpa2.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 2000, + "_comment": "that's all" + } +} diff --git a/examples/water/dpa2/input_torch_small.json b/examples/water/dpa2/input_torch_small.json new file mode 100644 index 0000000000..98147030b6 --- /dev/null +++ b/examples/water/dpa2/input_torch_small.json @@ -0,0 +1,112 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa2", + "repinit": { + "tebd_dim": 8, + "rcut": 6.0, + "rcut_smth": 0.5, + "nsel": 120, + "neuron": [ + 25, + 50, + 100 + ], + "axis_neuron": 12, + "activation_function": "tanh", + "three_body_sel": 40, + "three_body_rcut": 4.0, + "three_body_rcut_smth": 3.5, + "use_three_body": true + }, + "repformer": { + "rcut": 4.0, + "rcut_smth": 3.5, + "nsel": 40, + "nlayers": 3, + "g1_dim": 128, + "g2_dim": 32, + "attn2_hidden": 32, + "attn2_nhead": 4, + "attn1_hidden": 128, + "attn1_nhead": 4, + "axis_neuron": 4, + "update_h2": false, + "update_g1_has_conv": true, + "update_g1_has_grrg": true, + "update_g1_has_drrd": true, + "update_g1_has_attn": false, + "update_g2_has_g1g1": false, + "update_g2_has_attn": false, + "update_style": "res_residual", + "update_residual": 0.01, + "update_residual_init": "norm", + "attn2_has_gate": true, + "use_sqrt_nnei": true, + "g1_out_conv": true, + "g1_out_mlp": true + }, + "add_tebd_to_repinit_out": false + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment": " that's all" + }, + "training": { + "stat_file": "./dpa2.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 2000, + "_comment": "that's all" + } +} diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index 03fe112995..6abb482824 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -52,7 +52,9 @@ p_examples / "dprc" / "generalized_force" / "input.json", p_examples / "water" / "se_e2_a" / "input_torch.json", p_examples / "water" / "se_atten" / "input_torch.json", - p_examples / "water" / "dpa2" / "input_torch.json", + p_examples / "water" / "dpa2" / "input_torch_small.json", + p_examples / "water" / "dpa2" / "input_torch_medium.json", + p_examples / "water" / "dpa2" / "input_torch_large.json", p_examples / "property" / "train" / "input_torch.json", p_examples / "water" / "se_e3_tebd" / "input_torch.json", ) diff --git a/source/tests/pt/common.py b/source/tests/pt/common.py index 8886522360..16b343be8a 100644 --- a/source/tests/pt/common.py +++ b/source/tests/pt/common.py @@ -1,7 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, + Optional, + Union, +) + +import numpy as np +import torch + from deepmd.main import ( main, ) +from deepmd.pt.utils.env import ( + DEVICE, + GLOBAL_PT_FLOAT_PRECISION, +) def run_dp(cmd: str) -> int: @@ -27,3 +40,229 @@ def run_dp(cmd: str) -> int: main(cmds) return 0 + + +def eval_model( + model, + coords: Union[np.ndarray, torch.Tensor], + cells: Optional[Union[np.ndarray, torch.Tensor]], + atom_types: Union[np.ndarray, torch.Tensor, List[int]], + spins: Optional[Union[np.ndarray, torch.Tensor]] = None, + atomic: bool = False, + infer_batch_size: int = 2, + denoise: bool = False, +): + model = model.to(DEVICE) + energy_out = [] + atomic_energy_out = [] + force_out = [] + force_mag_out = [] + virial_out = [] + atomic_virial_out = [] + updated_coord_out = [] + logits_out = [] + err_msg = ( + f"All inputs should be the same format, " + f"but found {type(coords)}, {type(cells)}, {type(atom_types)} instead! " + ) + return_tensor = True + if isinstance(coords, torch.Tensor): + if cells is not None: + assert isinstance(cells, torch.Tensor), err_msg + if spins is not None: + assert isinstance(spins, torch.Tensor), err_msg + assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list) + atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE) + elif isinstance(coords, np.ndarray): + if cells is not None: + assert isinstance(cells, np.ndarray), err_msg + if spins is not None: + assert isinstance(spins, np.ndarray), err_msg + assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list) + atom_types = np.array(atom_types, dtype=np.int32) + return_tensor = False + + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + if isinstance(atom_types, torch.Tensor): + atom_types = torch.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape( + nframes, -1 + ) + else: + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) + else: + natoms = len(atom_types[0]) + + coord_input = torch.tensor( + coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + spin_input = None + if spins is not None: + spin_input = torch.tensor( + spins.reshape([-1, natoms, 3]), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + has_spin = getattr(model, "has_spin", False) + if callable(has_spin): + has_spin = has_spin() + type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) + box_input = None + if cells is None: + pbc = False + else: + pbc = True + box_input = torch.tensor( + cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size) + + for ii in range(num_iter): + batch_coord = coord_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] + batch_atype = type_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] + batch_box = None + batch_spin = None + if spin_input is not None: + batch_spin = spin_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] + if pbc: + batch_box = box_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] + input_dict = { + "coord": batch_coord, + "atype": batch_atype, + "box": batch_box, + "do_atomic_virial": atomic, + } + if has_spin: + input_dict["spin"] = batch_spin + batch_output = model(**input_dict) + if isinstance(batch_output, tuple): + batch_output = batch_output[0] + if not return_tensor: + if "energy" in batch_output: + energy_out.append(batch_output["energy"].detach().cpu().numpy()) + if "atom_energy" in batch_output: + atomic_energy_out.append( + batch_output["atom_energy"].detach().cpu().numpy() + ) + if "force" in batch_output: + force_out.append(batch_output["force"].detach().cpu().numpy()) + if "force_mag" in batch_output: + force_mag_out.append(batch_output["force_mag"].detach().cpu().numpy()) + if "virial" in batch_output: + virial_out.append(batch_output["virial"].detach().cpu().numpy()) + if "atom_virial" in batch_output: + atomic_virial_out.append( + batch_output["atom_virial"].detach().cpu().numpy() + ) + if "updated_coord" in batch_output: + updated_coord_out.append( + batch_output["updated_coord"].detach().cpu().numpy() + ) + if "logits" in batch_output: + logits_out.append(batch_output["logits"].detach().cpu().numpy()) + else: + if "energy" in batch_output: + energy_out.append(batch_output["energy"]) + if "atom_energy" in batch_output: + atomic_energy_out.append(batch_output["atom_energy"]) + if "force" in batch_output: + force_out.append(batch_output["force"]) + if "force_mag" in batch_output: + force_mag_out.append(batch_output["force_mag"]) + if "virial" in batch_output: + virial_out.append(batch_output["virial"]) + if "atom_virial" in batch_output: + atomic_virial_out.append(batch_output["atom_virial"]) + if "updated_coord" in batch_output: + updated_coord_out.append(batch_output["updated_coord"]) + if "logits" in batch_output: + logits_out.append(batch_output["logits"]) + if not return_tensor: + energy_out = ( + np.concatenate(energy_out) if energy_out else np.zeros([nframes, 1]) # pylint: disable=no-explicit-dtype + ) + atomic_energy_out = ( + np.concatenate(atomic_energy_out) + if atomic_energy_out + else np.zeros([nframes, natoms, 1]) # pylint: disable=no-explicit-dtype + ) + force_out = ( + np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3]) # pylint: disable=no-explicit-dtype + ) + force_mag_out = ( + np.concatenate(force_mag_out) + if force_mag_out + else np.zeros([nframes, natoms, 3]) # pylint: disable=no-explicit-dtype + ) + virial_out = ( + np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3]) # pylint: disable=no-explicit-dtype + ) + atomic_virial_out = ( + np.concatenate(atomic_virial_out) + if atomic_virial_out + else np.zeros([nframes, natoms, 3, 3]) # pylint: disable=no-explicit-dtype + ) + updated_coord_out = ( + np.concatenate(updated_coord_out) if updated_coord_out else None + ) + logits_out = np.concatenate(logits_out) if logits_out else None + else: + energy_out = ( + torch.cat(energy_out) + if energy_out + else torch.zeros( + [nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + atomic_energy_out = ( + torch.cat(atomic_energy_out) + if atomic_energy_out + else torch.zeros( + [nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + force_out = ( + torch.cat(force_out) + if force_out + else torch.zeros( + [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + force_mag_out = ( + torch.cat(force_mag_out) + if force_mag_out + else torch.zeros( + [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + virial_out = ( + torch.cat(virial_out) + if virial_out + else torch.zeros( + [nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + atomic_virial_out = ( + torch.cat(atomic_virial_out) + if atomic_virial_out + else torch.zeros( + [nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None + logits_out = torch.cat(logits_out) if logits_out else None + if denoise: + return updated_coord_out, logits_out + else: + results_dict = { + "energy": energy_out, + "force": force_out, + "virial": virial_out, + } + if has_spin: + results_dict["force_mag"] = force_mag_out + if atomic: + results_dict["atom_energy"] = atomic_energy_out + results_dict["atom_virial"] = atomic_virial_out + return results_dict diff --git a/source/tests/pt/model/test_autodiff.py b/source/tests/pt/model/test_autodiff.py index d891583491..1adcff55fc 100644 --- a/source/tests/pt/model/test_autodiff.py +++ b/source/tests/pt/model/test_autodiff.py @@ -21,8 +21,10 @@ dtype = torch.float64 -from .test_permutation import ( +from ..common import ( eval_model, +) +from .test_permutation import ( model_dpa1, model_dpa2, model_hybrid, diff --git a/source/tests/pt/model/test_forward_lower.py b/source/tests/pt/model/test_forward_lower.py index c9857a6343..87a3f5b06e 100644 --- a/source/tests/pt/model/test_forward_lower.py +++ b/source/tests/pt/model/test_forward_lower.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -20,6 +17,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( # model_dpau, model_dpa1, model_dpa2, diff --git a/source/tests/pt/model/test_null_input.py b/source/tests/pt/model/test_null_input.py index 1dca7ee119..a2e0fa66db 100644 --- a/source/tests/pt/model/test_null_input.py +++ b/source/tests/pt/model/test_null_input.py @@ -5,9 +5,6 @@ import numpy as np import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, get_zbl_model, @@ -22,6 +19,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( model_dpa1, model_dpa2, diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index f5edc6ef64..2fbc5fde3c 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -5,9 +5,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -18,6 +15,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) CUR_DIR = os.path.dirname(__file__) diff --git a/source/tests/pt/model/test_permutation_denoise.py b/source/tests/pt/model/test_permutation_denoise.py index 133c48f551..53bf55fb0f 100644 --- a/source/tests/pt/model/test_permutation_denoise.py +++ b/source/tests/pt/model/test_permutation_denoise.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( # model_dpau, model_dpa1, model_dpa2, diff --git a/source/tests/pt/model/test_rot.py b/source/tests/pt/model/test_rot.py index 23bdede923..ca6a6375c8 100644 --- a/source/tests/pt/model/test_rot.py +++ b/source/tests/pt/model/test_rot.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( # model_dpau, model_dos, model_dpa1, diff --git a/source/tests/pt/model/test_rot_denoise.py b/source/tests/pt/model/test_rot_denoise.py index 5fe99a0d7a..9828ba5225 100644 --- a/source/tests/pt/model/test_rot_denoise.py +++ b/source/tests/pt/model/test_rot_denoise.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation_denoise import ( model_dpa1, model_dpa2, diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py index c33dddfab5..9a7040f9cc 100644 --- a/source/tests/pt/model/test_smooth.py +++ b/source/tests/pt/model/test_smooth.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( # model_dpau, model_dos, model_dpa1, diff --git a/source/tests/pt/model/test_smooth_denoise.py b/source/tests/pt/model/test_smooth_denoise.py index 069c578d52..faa892c5d0 100644 --- a/source/tests/pt/model/test_smooth_denoise.py +++ b/source/tests/pt/model/test_smooth_denoise.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation_denoise import ( model_dpa2, ) diff --git a/source/tests/pt/model/test_trans.py b/source/tests/pt/model/test_trans.py index afd70f8995..b62fac1312 100644 --- a/source/tests/pt/model/test_trans.py +++ b/source/tests/pt/model/test_trans.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( # model_dpau, model_dos, model_dpa1, diff --git a/source/tests/pt/model/test_trans_denoise.py b/source/tests/pt/model/test_trans_denoise.py index 2d31d5de50..84ec21929c 100644 --- a/source/tests/pt/model/test_trans_denoise.py +++ b/source/tests/pt/model/test_trans_denoise.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation_denoise import ( model_dpa1, model_dpa2, diff --git a/source/tests/pt/model/test_unused_params.py b/source/tests/pt/model/test_unused_params.py index e225719e7f..3f068d5e5b 100644 --- a/source/tests/pt/model/test_unused_params.py +++ b/source/tests/pt/model/test_unused_params.py @@ -4,9 +4,6 @@ import torch -from deepmd.pt.infer.deep_eval import ( - eval_model, -) from deepmd.pt.model.model import ( get_model, ) @@ -17,6 +14,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..common import ( + eval_model, +) from .test_permutation import ( model_dpa2, )