Skip to content

Commit

Permalink
support dict
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 12, 2024
1 parent d5e8b49 commit b7d2b32
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ def wrapper(self, *args, **kwargs):
return tuple(
safe_cast_array(vv, self.precision, "global") for vv in returned_tensor
)
elif isinstance(returned_tensor, dict):
return {
kk: safe_cast_array(vv, self.precision, "global")
for kk, vv in returned_tensor.items()
}
else:
return safe_cast_array(returned_tensor, self.precision, "global")

Expand Down

0 comments on commit b7d2b32

Please sign in to comment.