Skip to content

Commit

Permalink
fix: UTs
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Apr 9, 2024
1 parent bb47541 commit f504f07
Show file tree
Hide file tree
Showing 3 changed files with 297 additions and 299 deletions.
11 changes: 5 additions & 6 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def compute_output_stats(
bias_atom_e[kk], bias_atom_g[kk]
)
std_atom_e[kk] = _fill_stat_with_global(std_atom_e[kk], std_atom_g[kk])
else:
if (bias_atom_e[kk] is None) or (std_atom_e[kk] is None):
raise RuntimeError("Fail to compute stat.")

if stat_file_path is not None:
Expand Down Expand Up @@ -492,20 +492,20 @@ def compute_output_stats_global(

if model_pred is None:
unbias_e = {
kk: merged_natoms[kk] @ bias_atom_e[kk].reshape(ntypes, -1) for kk in keys
kk: merged_natoms[kk] @ bias_atom_e[kk].reshape(ntypes, -1) for kk in bias_atom_e.keys()
}
else:
unbias_e = {
kk: model_pred[kk].reshape(nf[kk], -1)
+ merged_natoms[kk] @ bias_atom_e[kk].reshape(ntypes, -1)
for kk in keys
for kk in bias_atom_e.keys()
}
atom_numbs = {kk: merged_natoms[kk].sum(-1) for kk in keys}
atom_numbs = {kk: merged_natoms[kk].sum(-1) for kk in bias_atom_e.keys()}

def rmse(x):
return np.sqrt(np.mean(np.square(x)))

for kk in keys:
for kk in bias_atom_e.keys():
rmse_ae = rmse(
(unbias_e[kk].reshape(nf[kk], -1) - merged_output[kk].reshape(nf[kk], -1))
/ atom_numbs[kk][:, None]
Expand Down Expand Up @@ -576,6 +576,5 @@ def compute_output_stats_atomic(
else:
# this key does not have atomic labels, skip it.
continue

bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e)
return bias_atom_e, std_atom_e
1 change: 0 additions & 1 deletion source/tests/pt/model/test_atomic_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def raise_error():
).reshape(2, 3, 1)
for kk in ["foo"]:
np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4)
assert False


class TestAtomicModelStatMergeGlobalAtomic(
Expand Down
Loading

0 comments on commit f504f07

Please sign in to comment.