
Commit

fix: stat
anyangml committed Apr 2, 2024
1 parent 164898c commit f07533b
Showing 2 changed files with 6 additions and 8 deletions.
8 changes: 4 additions & 4 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -283,16 +283,16 @@ def change_out_bias(
             delta_bias = compute_output_stats(
                 merged,
                 self.get_ntypes(),
-                keys=["energy"],
+                keys=list(self.fitting_output_def().keys()),
                 model_forward=self.get_forward_wrapper_func(),
-            )["energy"]
+            )[list(self.fitting_output_def().keys())[0]]
             self.set_out_bias(delta_bias, add=True)
         elif bias_adjust_mode == "set-by-statistic":
             bias_atom = compute_output_stats(
                 merged,
                 self.get_ntypes(),
-                keys=["energy"],
-            )["energy"]
+                keys=list(self.fitting_output_def().keys()),
+            )[list(self.fitting_output_def().keys())[0]]
             self.set_out_bias(bias_atom)
         else:
             raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
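
Note: the change above drops the hard-coded "energy" key and instead takes the bias key from the fitting's own output definition, so the same code path serves energy, DOS, and polar fittings. A minimal sketch of the pattern with toy stand-ins (the dict and the stats mapping below are hypothetical, not the real deepmd objects):

    output_def = {"dos": None}                  # e.g. a DOS fitting's output definition
    keys = list(output_def.keys())              # ["dos"] instead of a hard-coded ["energy"]
    stats = {k: 0.0 for k in keys}              # stand-in for compute_output_stats(...)
    delta_bias = stats[keys[0]]                 # index the result with the fitting's first key
    print(keys, delta_bias)                     # ['dos'] 0.0
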
6 changes: 2 additions & 4 deletions deepmd/pt/utils/stat.py
@@ -138,7 +138,7 @@ def compute_output_stats(
             model_forward=model_forward,
         )
     elif (
-        ["dos", "polar"] in keys
+        "dos" in keys or "polar" in keys
     ):  # this is the polar fitting or dos fitting which may have keys ['polar'] or ['dos']
         return compute_output_stats_with_atomic(
             merged=merged,
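
Note: the condition above is the substance of this fix. The old expression ["dos", "polar"] in keys asks whether the two-element list itself is a member of keys, which is never true for keys such as ["dos"] or ["polar"], so the branch could not be taken. A quick check of both forms in plain Python (no deepmd imports needed):

    keys = ["dos"]
    print(["dos", "polar"] in keys)             # False: tests membership of the whole list
    print("dos" in keys or "polar" in keys)     # True: the corrected per-key membership test
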
@@ -192,7 +192,6 @@ def compute_output_stats_global_only(
     The difference will then be used to calculate the delta complement energy bias for each type.
     """
     bias_atom_e = restore_from_file(stat_file_path, keys)
-
     if bias_atom_e is None:
         if callable(merged):
             # only get data for once
@@ -394,7 +393,7 @@ def model_forward_auto_batch_size(*args, **kwargs):
         nframs = system["atype"].shape[0]

         for kk in keys:
-            if "find_atom_" + kk > 0.0:
+            if system["find_atom_" + kk] > 0.0:
                 sys_property = system["atom_" + kk].numpy(force=True)
                 if kk not in model_predict:
                     sys_bias = compute_stats_from_atomic(
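
Note: the old test compared the string "find_atom_" + kk with the float 0.0, which raises a TypeError on Python 3 rather than checking anything useful; the fix looks the flag up in the per-system data dict first. A small illustration with a hypothetical system dict shaped like the one in the loop above:

    kk = "dos"
    system = {"find_atom_dos": 1.0}             # hypothetical per-system data entry
    # "find_atom_" + kk > 0.0 would raise TypeError: '>' not supported between 'str' and 'float'
    if system["find_atom_" + kk] > 0.0:         # fixed: compare the stored flag, not the key string
        print("per-atom " + kk + " labels are present")
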
@@ -437,5 +436,4 @@ def model_forward_auto_batch_size(*args, **kwargs):
     if stat_file_path is not None:
         save_to_file(stat_file_path, atom_bias)
     ret = {kk: to_torch_tensor(atom_bias[kk]) for kk in keys}
-
     return ret