Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 11, 2023
1 parent 0e8876d commit 0b9c32f
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions deepmd/infer/model_devi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,26 @@
DeepPot,
)


@overload
def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
atomic: Literal[False] = False,
)-> Tuple[np.ndarray, np.ndarray, np.ndarray]: ...
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
...


@overload
def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
*
atomic: Literal[True],
)-> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ...
*atomic: Literal[True],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
...


def calc_model_devi_f(
fs: np.ndarray,
Expand All @@ -59,7 +63,7 @@ def calc_model_devi_f(
value is the level parameter for computing the relative model
deviation of the force.
atomic : bool, default: False
Whether return deviation of force in all atoms
Whether return deviation of force in all atoms
Returns
-------
Expand Down Expand Up @@ -328,7 +332,11 @@ def calc_model_devi(
devi += list(
calc_model_devi_v(virials, real_data["virial"], relative=relative_v)
)
devi_f = list(calc_model_devi_f(forces, real_data["force"], relative=relative, atomic=atomic))
devi_f = list(
calc_model_devi_f(
forces, real_data["force"], relative=relative, atomic=atomic
)
)
devi += devi_f[:3]
devi.append(calc_model_devi_e(energies, real_data["energy"]))
devi = np.vstack(devi).T
Expand Down

0 comments on commit 0b9c32f

Please sign in to comment.