diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index c807dbf2ef..970019028c 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -213,7 +213,7 @@ def forward_common_atomic( tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr)) ret_dict[kk] = xp.reshape(tmp_arr, out_shape) - ret_dict["mask"] = atom_mask + ret_dict["mask"] = xp.astype(atom_mask, xp.int32) return ret_dict