diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index cbd4c0af4a..b89ada98a1 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -208,7 +208,7 @@ def compute_output_stats( if sampled[sys]["find_atomic_polarizability"] > 0.0: sys_atom_polar = compute_stats_from_atomic( - sampled[sys]["atomic_polarizability"], sampled[sys]["type"] + sampled[sys]["atomic_polarizability"].numpy(force=True), sampled[sys]["type"].numpy(force=True) )[0] else: if not sampled[sys]["find_polarizability"] > 0.0: @@ -216,9 +216,9 @@ def compute_output_stats( sys_type_count = np.zeros((nframs, self.ntypes)) for itype in range(self.ntypes): type_mask = sampled[sys]["type"] == itype - sys_type_count[:, itype] = type_mask.sum(dim=1) + sys_type_count[:, itype] = type_mask.sum(dim=1).numpy(force=True) - sys_bias_redu = sampled[sys]["polarizability"] + sys_bias_redu = sampled[sys]["polarizability"].numpy(force=True) sys_atom_polar = compute_stats_from_redu( sys_bias_redu, sys_type_count diff --git a/source/tests/pt/model/test_polar_stat.py b/source/tests/pt/model/test_polar_stat.py index 02faf8ed55..3503d8e64e 100644 --- a/source/tests/pt/model/test_polar_stat.py +++ b/source/tests/pt/model/test_polar_stat.py @@ -36,7 +36,7 @@ def setUp(self) -> None: "find_polarizability": find_polarizability, } ] - self.all_stat = {k: [v.numpy()] for d in self.sampled for k, v in d.items()} + self.all_stat = {k: [v.numpy(force=True)] for d in self.sampled for k, v in d.items()} self.tfpolar = PolarFittingSeA( ntypes=ntypes, dim_descrpt=1,