From 0e8876da6c7de9ec5bf2b81ada0678982bc925eb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 11 Sep 2023 04:12:25 +0000 Subject: [PATCH] add atomic=False option to calc_model_devi_f --- deepmd/infer/model_devi.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/deepmd/infer/model_devi.py b/deepmd/infer/model_devi.py index dbf62085c7..305e8c8e90 100644 --- a/deepmd/infer/model_devi.py +++ b/deepmd/infer/model_devi.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Literal, Optional, Tuple, + overload, ) import numpy as np @@ -20,12 +22,29 @@ DeepPot, ) +@overload +def calc_model_devi_f( + fs: np.ndarray, + real_f: Optional[np.ndarray] = None, + relative: Optional[float] = None, + atomic: Literal[False] = False, +)-> Tuple[np.ndarray, np.ndarray, np.ndarray]: ... + +@overload +def calc_model_devi_f( + fs: np.ndarray, + real_f: Optional[np.ndarray] = None, + relative: Optional[float] = None, + * + atomic: Literal[True], +)-> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ... def calc_model_devi_f( fs: np.ndarray, real_f: Optional[np.ndarray] = None, relative: Optional[float] = None, -) -> Tuple[np.ndarray]: + atomic: bool = False, +) -> Tuple[np.ndarray, ...]: """Calculate model deviation of force. Parameters @@ -39,6 +58,8 @@ def calc_model_devi_f( If given, calculate the relative model deviation of force. The value is the level parameter for computing the relative model deviation of the force. + atomic : bool, default: False + Whether return deviation of force in all atoms Returns ------- @@ -49,7 +70,7 @@ def calc_model_devi_f( avg_devi_f : numpy.ndarray average deviation of force in all atoms fs_devi : numpy.ndarray - deviation of force in all atoms + deviation of force in all atoms, returned if atomic=True """ if real_f is None: fs_devi = np.linalg.norm(np.std(fs, axis=0), axis=-1) @@ -70,7 +91,9 @@ def calc_model_devi_f( max_devi_f = np.max(fs_devi, axis=-1) min_devi_f = np.min(fs_devi, axis=-1) avg_devi_f = np.mean(fs_devi, axis=-1) - return max_devi_f, min_devi_f, avg_devi_f, fs_devi + if atomic: + return max_devi_f, min_devi_f, avg_devi_f, fs_devi + return max_devi_f, min_devi_f, avg_devi_f def calc_model_devi_e( @@ -107,7 +130,7 @@ def calc_model_devi_v( vs: np.ndarray, real_v: Optional[np.ndarray] = None, relative: Optional[float] = None, -) -> Tuple[np.ndarray]: +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Calculate model deviation of virial. Parameters @@ -298,14 +321,14 @@ def calc_model_devi( devi = [np.arange(coord.shape[0]) * frequency] if real_data is None: devi += list(calc_model_devi_v(virials, relative=relative_v)) - devi_f = list(calc_model_devi_f(forces, relative=relative)) + devi_f = list(calc_model_devi_f(forces, relative=relative, atomic=atomic)) devi += devi_f[:3] devi.append(calc_model_devi_e(energies)) else: devi += list( calc_model_devi_v(virials, real_data["virial"], relative=relative_v) ) - devi_f = list(calc_model_devi_f(forces, real_data["force"], relative=relative)) + devi_f = list(calc_model_devi_f(forces, real_data["force"], relative=relative, atomic=atomic)) devi += devi_f[:3] devi.append(calc_model_devi_e(energies, real_data["energy"])) devi = np.vstack(devi).T