From b7d2b324249a21a1d02d620ba32e4986d3e2ccac Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 17:32:26 -0500 Subject: [PATCH] support dict Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/common.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 42a3136d2c..efeeabaea1 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -172,6 +172,11 @@ def wrapper(self, *args, **kwargs): return tuple( safe_cast_array(vv, self.precision, "global") for vv in returned_tensor ) + elif isinstance(returned_tensor, dict): + return { + kk: safe_cast_array(vv, self.precision, "global") + for kk, vv in returned_tensor.items() + } else: return safe_cast_array(returned_tensor, self.precision, "global")