Skip to content

Commit

Permalink
set dtype
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 1, 2024
1 parent 1b3ea6b commit 41863c9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion deepmd/dpmodel/utils/safe_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ def safe_for_vector_norm(x, /, *, axis=None, keepdims=False, ord=2):
xp.linalg.vector_norm(
xp.where(mask, x, xp.ones_like(x)), axis=axis, keepdims=keepdims, ord=ord
),
xp.zeros_like(mask_squeezed),
xp.zeros_like(mask_squeezed, dtype=x.dtype),
)

0 comments on commit 41863c9

Please sign in to comment.