Skip to content

Commit

Permalink
add atomic=False option to calc_model_devi_f
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz committed Sep 11, 2023
1 parent 2046592 commit 0e8876d
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions deepmd/infer/model_devi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Literal,
Optional,
Tuple,
overload,
)

import numpy as np
Expand All @@ -20,12 +22,29 @@
DeepPot,
)

@overload
def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
atomic: Literal[False] = False,
)-> Tuple[np.ndarray, np.ndarray, np.ndarray]: ...

@overload
def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
*
atomic: Literal[True],
)-> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ...

def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
) -> Tuple[np.ndarray]:
atomic: bool = False,
) -> Tuple[np.ndarray, ...]:
"""Calculate model deviation of force.
Parameters
Expand All @@ -39,6 +58,8 @@ def calc_model_devi_f(
If given, calculate the relative model deviation of force. The
value is the level parameter for computing the relative model
deviation of the force.
atomic : bool, default: False
Whether return deviation of force in all atoms
Returns
-------
Expand All @@ -49,7 +70,7 @@ def calc_model_devi_f(
avg_devi_f : numpy.ndarray
average deviation of force in all atoms
fs_devi : numpy.ndarray
deviation of force in all atoms
deviation of force in all atoms, returned if atomic=True
"""
if real_f is None:
fs_devi = np.linalg.norm(np.std(fs, axis=0), axis=-1)
Expand All @@ -70,7 +91,9 @@ def calc_model_devi_f(
max_devi_f = np.max(fs_devi, axis=-1)
min_devi_f = np.min(fs_devi, axis=-1)
avg_devi_f = np.mean(fs_devi, axis=-1)
return max_devi_f, min_devi_f, avg_devi_f, fs_devi
if atomic:
return max_devi_f, min_devi_f, avg_devi_f, fs_devi

Check warning on line 95 in deepmd/infer/model_devi.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/model_devi.py#L94-L95

Added lines #L94 - L95 were not covered by tests
return max_devi_f, min_devi_f, avg_devi_f


def calc_model_devi_e(
Expand Down Expand Up @@ -107,7 +130,7 @@ def calc_model_devi_v(
vs: np.ndarray,
real_v: Optional[np.ndarray] = None,
relative: Optional[float] = None,
) -> Tuple[np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Calculate model deviation of virial.
Parameters
Expand Down Expand Up @@ -298,14 +321,14 @@ def calc_model_devi(
devi = [np.arange(coord.shape[0]) * frequency]
if real_data is None:
devi += list(calc_model_devi_v(virials, relative=relative_v))
devi_f = list(calc_model_devi_f(forces, relative=relative))
devi_f = list(calc_model_devi_f(forces, relative=relative, atomic=atomic))
devi += devi_f[:3]

Check warning on line 325 in deepmd/infer/model_devi.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/model_devi.py#L323-L325

Added lines #L323 - L325 were not covered by tests
devi.append(calc_model_devi_e(energies))
else:
devi += list(

Check warning on line 328 in deepmd/infer/model_devi.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/model_devi.py#L328

Added line #L328 was not covered by tests
calc_model_devi_v(virials, real_data["virial"], relative=relative_v)
)
devi_f = list(calc_model_devi_f(forces, real_data["force"], relative=relative))
devi_f = list(calc_model_devi_f(forces, real_data["force"], relative=relative, atomic=atomic))
devi += devi_f[:3]

Check warning on line 332 in deepmd/infer/model_devi.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/model_devi.py#L331-L332

Added lines #L331 - L332 were not covered by tests
devi.append(calc_model_devi_e(energies, real_data["energy"]))
devi = np.vstack(devi).T
Expand Down

0 comments on commit 0e8876d

Please sign in to comment.