From 201cf80a22e497075c608fd8d6364f6abd89defc Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 18:10:07 -0500 Subject: [PATCH] fix docstring Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index efeeabaea1..2bef086726 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -125,7 +125,7 @@ def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]: """A decorator that casts and casts back the input and output tensor of a method. - The decorator should be used in a classmethod. + The decorator should be used on an instance method. The decorator will do the following thing: (1) It casts input arrays from the global precision @@ -201,8 +201,8 @@ def safe_cast_array( Parameters ---------- - input : tf.Tensor - Input tensor + input : np.ndarray or None + Input array from_precision : str Array data type that is casted from to_precision : str @@ -210,8 +210,8 @@ def safe_cast_array( Returns ------- - tf.Tensor - casted Tensor + np.ndarray or None + casted array """ if array_api_compat.is_array_api_obj(input): xp = array_api_compat.array_namespace(input)