From 02b3a6d93ef4b835c1509e2fe3bd0d70278027a0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 1 Feb 2024 16:13:22 -0500 Subject: [PATCH] fix PT tests Signed-off-by: Jinzhe Zeng --- deepmd/infer/deep_eval.py | 4 ++-- deepmd/pt/infer/deep_eval.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 266b3b2d4d..c8cc1186e7 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -322,10 +322,10 @@ def get_dim_aparam(self) -> int: def _get_natoms_and_nframes( self, coords: np.ndarray, - atom_types: Union[List[int], np.ndarray], + atom_types: np.ndarray, mixed_type: bool = False, ) -> Tuple[int, int]: - if mixed_type: + if mixed_type or atom_types.ndim > 1: natoms = len(atom_types[0]) else: natoms = len(atom_types) diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 8eb7e06868..41bf8a4bce 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -126,7 +126,8 @@ def get_ntypes_spin(self): "energy": "atom_energy", "energy_redu": "energy", "energy_derv_r": "force", - "energy_derv_c": "atom_virial", + # not same as TF... + "energy_derv_c": "atomic_virial", "energy_derv_c_redu": "virial", }