Skip to content

Commit

Permalink
store reference std
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Feb 28, 2024
1 parent 9f0b9b7 commit 7c12ad0
Showing 1 changed file with 49 additions and 6 deletions.
55 changes: 49 additions & 6 deletions source/tests/common/test_out_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,19 @@ def setUp(self) -> None:
def test_compute_stats_from_redu(self):
bias, std = compute_stats_from_redu(self.output_redu, self.natoms)
np.testing.assert_allclose(bias, self.mean, rtol=1e-7)
reference_std = np.array(
[
0.01700638138272794,
0.01954897296228177,
0.020281857747683162,
0.010741237959989648,
0.020258211828681347,
]
)
np.testing.assert_allclose(
std,
np.sqrt(self.natoms.mean(axis=0) @ np.square(self.std)),
rtol=1e-1,
reference_std,
rtol=1e-7,
)
# ensure the sum is close
np.testing.assert_allclose(
Expand All @@ -59,10 +68,19 @@ def test_compute_stats_from_redu_with_assigned_bias(self):
)
np.testing.assert_allclose(bias, self.mean, rtol=1e-7)
np.testing.assert_allclose(bias[0], self.mean[0], rtol=1e-14)
reference_std = np.array(
[
0.017015794087883902,
0.019549011723239484,
0.020285565914828625,
0.01074124012073672,
0.020283557003416414,
]
)
np.testing.assert_allclose(
std,
np.sqrt(self.natoms.mean(axis=0) @ np.square(self.std)),
rtol=1e-1,
reference_std,
rtol=1e-7,
)
# ensure the sum is close
np.testing.assert_allclose(
Expand All @@ -74,8 +92,33 @@ def test_compute_stats_from_redu_with_assigned_bias(self):
def test_compute_stats_from_atomic(self):
bias, std = compute_stats_from_atomic(self.output, self.atype)
np.testing.assert_allclose(bias, self.mean)
reference_std = np.array(
[
[
0.0005452949516910239,
0.000686732800598535,
0.00089423457667224,
7.818017989121455e-05,
0.0004758637035637342,
],
[
2.0610161678825724e-05,
0.0007728218734771541,
0.0004754659308165858,
0.0001809007655290948,
0.0008187364708029638,
],
[
0.0007935836092665254,
0.00031176505013516624,
0.0005469653430009186,
0.0005652240916389281,
0.0006087722080071852,
],
]
)
np.testing.assert_allclose(
std,
self.std,
rtol=1e-2,
reference_std,
rtol=1e-7,
)

0 comments on commit 7c12ad0

Please sign in to comment.