diff --git a/dpdata/stat.py b/dpdata/stat.py index 5ec39570..ed74c258 100644 --- a/dpdata/stat.py +++ b/dpdata/stat.py @@ -2,13 +2,14 @@ from abc import ABCMeta, abstractmethod from functools import lru_cache +from typing import Any import numpy as np from dpdata.system import LabeledSystem, MultiSystems -def mae(errors: np.ndarray) -> np.float64: +def mae(errors: np.ndarray) -> np.floating[Any]: """Compute the mean absolute error (MAE). Parameters @@ -18,13 +19,13 @@ def mae(errors: np.ndarray) -> np.float64: Returns ------- - np.float64 + floating[Any] mean absolute error (MAE) """ return np.mean(np.abs(errors)) -def rmse(errors: np.ndarray) -> np.float64: +def rmse(errors: np.ndarray) -> np.floating[Any]: """Compute the root mean squared error (RMSE). Parameters @@ -34,7 +35,7 @@ def rmse(errors: np.ndarray) -> np.float64: Returns ------- - np.float64 + floating[Any] root mean squared error (RMSE) """ return np.sqrt(np.mean(np.square(errors))) @@ -74,22 +75,22 @@ def f_errors(self) -> np.ndarray: """Force errors.""" @property - def e_mae(self) -> np.float64: + def e_mae(self) -> np.floating[Any]: """Energy MAE.""" return mae(self.e_errors) @property - def e_rmse(self) -> np.float64: + def e_rmse(self) -> np.floating[Any]: """Energy RMSE.""" return rmse(self.e_errors) @property - def f_mae(self) -> np.float64: + def f_mae(self) -> np.floating[Any]: """Force MAE.""" return mae(self.f_errors) @property - def f_rmse(self) -> np.float64: + def f_rmse(self) -> np.floating[Any]: """Force RMSE.""" return rmse(self.f_errors)