diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 0e72d2239f..e7752c47ca 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -209,23 +209,26 @@ def _make_preset_out_bias( def _fill_stat_with_global( atomic_stat: Union[np.ndarray, None], global_stat: np.ndarray, - ): +): """This function is used to fill atomic stat with global stat. - + Parameters ---------- atomic_stat : Union[np.ndarray, None] The atomic stat. global_stat : np.ndarray The global stat. - - if the atomic stat is None, use global stat. + if the atomic stat is None, use global stat. if the atomic stat is not None, but has nan values (missing atypes), fill with global stat. """ if atomic_stat is None: return global_stat else: - return np.nan_to_num(np.where(np.isnan(atomic_stat) & ~np.isnan(global_stat), global_stat, atomic_stat)) + return np.nan_to_num( + np.where( + np.isnan(atomic_stat) & ~np.isnan(global_stat), global_stat, atomic_stat + ) + ) def compute_output_stats( @@ -383,11 +386,13 @@ def compute_output_stats( std_atom_e[kk] = None # use global bias to fill missing atomic bias if kk in bias_atom_g: - bias_atom_e[kk] = _fill_stat_with_global(bias_atom_e[kk], bias_atom_g[kk]) + bias_atom_e[kk] = _fill_stat_with_global( + bias_atom_e[kk], bias_atom_g[kk] + ) std_atom_e[kk] = _fill_stat_with_global(std_atom_e[kk], std_atom_g[kk]) else: raise RuntimeError("Fail to compute stat.") - + if stat_file_path is not None: _save_to_file(stat_file_path, bias_atom_e, std_atom_e) @@ -566,8 +571,8 @@ def compute_output_stats_atomic( if missing_types > 0: nan_padding = np.empty((missing_types, bias_atom_e[kk].shape[1])) nan_padding.fill(np.nan) - bias_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding],axis=0) - std_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding],axis=0) + bias_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0) + std_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0) else: # this key does not have atomic labels, skip it. continue diff --git a/source/tests/pt/model/test_atomic_bias.py b/source/tests/pt/model/test_atomic_bias.py index bd4f2d5cd6..dc0c55eb53 100644 --- a/source/tests/pt/model/test_atomic_bias.py +++ b/source/tests/pt/model/test_atomic_bias.py @@ -255,7 +255,11 @@ def raise_error(): for kk in ["foo"]: np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) assert False -class TestAtomicModelStatMergeGlobalAtomic(unittest.TestCase, TestCaseSingleFrameWithNlist): + + +class TestAtomicModelStatMergeGlobalAtomic( + unittest.TestCase, TestCaseSingleFrameWithNlist +): def tearDown(self): self.tempdir.cleanup() @@ -398,8 +402,6 @@ def raise_error(): ret3 = cvt_ret(ret3) expected_ret3 = {} # new bias [2, -5] - expected_ret3["foo"] = np.array( - [[3, 4, -2], [6, 0, 1]] - ).reshape(2, 3, 1) + expected_ret3["foo"] = np.array([[3, 4, -2], [6, 0, 1]]).reshape(2, 3, 1) for kk in ["foo"]: np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4)