diff --git a/source/tests/consistent/model/test_dos.py b/source/tests/consistent/model/test_dos.py index c3734cee09..87a83b7f9a 100644 --- a/source/tests/consistent/model/test_dos.py +++ b/source/tests/consistent/model/test_dos.py @@ -6,18 +6,8 @@ import numpy as np -from deepmd.dpmodel.common import ( - to_numpy_array, -) from deepmd.dpmodel.model.dos_model import DOSModel as DOSModelDP from deepmd.dpmodel.model.model import get_model as get_model_dp -from deepmd.dpmodel.utils.nlist import ( - build_neighbor_list, - extend_coord_with_ghosts, -) -from deepmd.dpmodel.utils.region import ( - normalize_coord, -) from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, ) @@ -36,8 +26,6 @@ if INSTALLED_PT: from deepmd.pt.model.model import get_model as get_model_pt from deepmd.pt.model.model.dos_model import DOSModel as DOSModelPT - from deepmd.pt.utils.utils import to_numpy_array as torch_to_numpy - from deepmd.pt.utils.utils import to_torch_tensor as numpy_to_torch else: DOSModelPT = None if INSTALLED_TF: @@ -64,39 +52,27 @@ class TestDOS(CommonTest, ModelTest, unittest.TestCase): def data(self) -> dict: pair_exclude_types, atom_exclude_types = self.param return { - "type_map": [ - "H" - ], + "type_map": ["H"], "descriptor": { - "type": "se_e2_a", - "sel": [ - 90 - ], - "rcut_smth": 1.8, - "rcut": 6.0, - "neuron": [ - 25, - 50, - 100 - ], - "resnet_dt": False, - "axis_neuron": 8, - "precision": "float64", - "seed": 1 + "type": "se_e2_a", + "sel": [90], + "rcut_smth": 1.8, + "rcut": 6.0, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 8, + "precision": "float64", + "seed": 1, }, "fitting_net": { - "type": "dos", - "numb_dos": 250, - "neuron": [ - 120, - 120, - 120 - ], - "resnet_dt": True, - "numb_fparam": 0, - "precision": "float64", - "seed": 1 - } + "type": "dos", + "numb_dos": 250, + "neuron": [120, 120, 120], + "resnet_dt": True, + "numb_fparam": 0, + "precision": "float64", + "seed": 1, + }, } tf_class = DOSModelTF @@ -242,4 +218,3 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret[4].ravel(), ) raise ValueError(f"Unknown backend: {backend}") -