Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 18, 2024
1 parent e48eb8b commit 5036545
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ def test_property(

if has_atom_property:
for ii in range(numb_test):
test_out = test_data[f"atom_property"][ii].reshape(-1, 1)
test_out = test_data["atom_property"][ii].reshape(-1, 1)
pred_out = aproperty[ii].reshape(-1, 1)

frame_output = np.hstack((test_out, pred_out))
Expand Down
6 changes: 5 additions & 1 deletion deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,11 @@ def change_out_bias(
)
self._store_out_stat(delta_bias, out_std, add=True)
elif bias_adjust_mode == "set-by-statistic":
property_name = self.fitting_net.property_name if "property_name" in vars(self.fitting_net) else None
property_name = (
self.fitting_net.property_name
if "property_name" in vars(self.fitting_net)
else None
)
bias_out, std_out = compute_output_stats(
sample_merged,
self.get_ntypes(),
Expand Down
12 changes: 7 additions & 5 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,12 @@ def compute_output_stats_global(
if kk in stats_input:
if property_name is not None:
assert len(keys) == 1
bias_atom_e["property"], std_atom_e["property"] = compute_stats_property(
stats_input[kk],
merged_natoms[kk],
assigned_bias=assigned_atom_ener[kk],
bias_atom_e["property"], std_atom_e["property"] = (
compute_stats_property(
stats_input[kk],
merged_natoms[kk],
assigned_bias=assigned_atom_ener[kk],
)
)
return bias_atom_e, std_atom_e
else:
Expand All @@ -498,7 +500,7 @@ def compute_output_stats_global(
else:
# this key does not have global labels, skip it.
continue

bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e)

# unbias_e is only used for print rmse
Expand Down

0 comments on commit 5036545

Please sign in to comment.