Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 2, 2024
1 parent afde28b commit e5e8bac
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def compute_output_stats(
"dipole": "dipole",
}

if "energy" in keys: # this is the energy fitting which may have keys ['energy', 'dforce']
if (

Check warning on line 128 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L128

Added line #L128 was not covered by tests
"energy" in keys
): # this is the energy fitting which may have keys ['energy', 'dforce']
return compute_output_stats_global_only(

Check warning on line 131 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L131

Added line #L131 was not covered by tests
merged=merged,
ntypes=ntypes,
Expand All @@ -135,7 +137,9 @@ def compute_output_stats(
atom_ener=atom_ener,
model_forward=model_forward,
)
elif ["dos", "polar"] in keys: # this is the polar fitting or dos fitting which may have keys ['polar'] or ['dos']
elif (

Check warning on line 140 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L140

Added line #L140 was not covered by tests
["dos", "polar"] in keys
): # this is the polar fitting or dos fitting which may have keys ['polar'] or ['dos']
return compute_output_stats_with_atomic(

Check warning on line 143 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L143

Added line #L143 was not covered by tests
merged=merged,
ntypes=ntypes,
Expand Down Expand Up @@ -174,7 +178,7 @@ def compute_output_stats_global_only(
the lazy function helps by only sampling once.
ntypes : int
The number of atom types.
keys : List[int]
keys : List[int]
The output variable names of a given atomic model, can be found in `fitting_output_def` of the model.
stat_file_path : DPPath, optional
The path to the stat file.
Expand Down Expand Up @@ -323,7 +327,7 @@ def compute_output_stats_with_atomic(
the lazy function helps by only sampling once.
ntypes : int
The number of atom types.
key : str
keys : str
The var_name of the fitting net.
stat_file_path : DPPath, optional
The path to the stat file.
Expand All @@ -337,7 +341,6 @@ def compute_output_stats_with_atomic(
which will be subtracted from the energy label of the data.
The difference will then be used to calculate the delta complement energy bias for each type.
"""

atom_bias = restore_from_file(stat_file_path, keys)

Check warning on line 344 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L344

Added line #L344 was not covered by tests

if atom_bias is None:
Expand Down Expand Up @@ -377,13 +380,13 @@ def model_forward_auto_batch_size(*args, **kwargs):
for kk in keys:
model_predict[kk].append(

Check warning on line 381 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L380-L381

Added lines #L380 - L381 were not covered by tests
to_numpy_array(
sample_predict[kk] # nf x nloc x odims
sample_predict[kk] # nf x nloc x odims
)
)

# this stores all atomic predictions.
model_predict = {kk: np.concatenate(model_predict[kk]) for kk in keys}

Check warning on line 388 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L388

Added line #L388 was not covered by tests

else:
model_predict = {}

Check warning on line 391 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L391

Added line #L391 was not covered by tests

Expand All @@ -394,7 +397,7 @@ def model_forward_auto_batch_size(*args, **kwargs):
for kk in keys:
if "find_atom_" + kk > 0.0:
sys_property = system["atom_" + kk].numpy(force=True)
if not kk in model_predict:
if kk not in model_predict:
sys_bias = compute_stats_from_atomic(

Check warning on line 401 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L397-L401

Added lines #L397 - L401 were not covered by tests
sys_property,
system["atype"].numpy(force=True),
Expand All @@ -413,9 +416,11 @@ def model_forward_auto_batch_size(*args, **kwargs):
)
for itype in range(ntypes):
type_mask = system["atype"] == itype
sys_type_count[:, itype] = type_mask.sum(dim=1).numpy(force=True)
sys_type_count[:, itype] = type_mask.sum(dim=1).numpy(

Check warning on line 419 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L417-L419

Added lines #L417 - L419 were not covered by tests
force=True
)
sys_bias_redu = system[kk].numpy(force=True)
if not kk in model_predict:
if kk not in model_predict:
sys_bias = compute_stats_from_redu(

Check warning on line 424 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L422-L424

Added lines #L422 - L424 were not covered by tests
sys_bias_redu, sys_type_count, rcond=rcond
)[0]
Expand All @@ -427,7 +432,9 @@ def model_forward_auto_batch_size(*args, **kwargs):

atom_bias[kk].append(sys_bias)

Check warning on line 433 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L433

Added line #L433 was not covered by tests
# need to take care shift diag and add atom_ener
atom_bias = {kk: np.nan_to_num(np.stack(vv).mean(axis=0)) for kk, vv in atom_bias.items()}
atom_bias = {

Check warning on line 435 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L435

Added line #L435 was not covered by tests
kk: np.nan_to_num(np.stack(vv).mean(axis=0)) for kk, vv in atom_bias.items()
}
if stat_file_path is not None:
save_to_file(stat_file_path, atom_bias)
ret = {kk: to_torch_tensor(atom_bias[kk]) for kk in keys}

Check warning on line 440 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L438-L440

Added lines #L438 - L440 were not covered by tests
Expand Down

0 comments on commit e5e8bac

Please sign in to comment.