Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Pt Polar Stat keyword #3584

Merged
merged 12 commits into from
Mar 26, 2024
6 changes: 3 additions & 3 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,12 @@

sys_constant_matrix = []
for sys in range(len(sampled)):
nframs = sampled[sys]["type"].shape[0]
nframs = sampled[sys]["atype"].shape[0]

Check warning on line 233 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L233

Added line #L233 was not covered by tests

if sampled[sys]["find_atomic_polarizability"] > 0.0:
sys_atom_polar = compute_stats_from_atomic(
sampled[sys]["atomic_polarizability"].numpy(force=True),
sampled[sys]["type"].numpy(force=True),
sampled[sys]["atype"].numpy(force=True),
)[0]
else:
if not sampled[sys]["find_polarizability"] > 0.0:
Expand All @@ -244,7 +244,7 @@
(nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION
)
for itype in range(self.ntypes):
type_mask = sampled[sys]["type"] == itype
type_mask = sampled[sys]["atype"] == itype

Check warning on line 247 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L247

Added line #L247 was not covered by tests
sys_type_count[:, itype] = type_mask.sum(dim=1).numpy(
force=True
)
Expand Down
9 changes: 7 additions & 2 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,17 @@
if dd not in sys_stat:
sys_stat[dd] = []
sys_stat[dd].append(stat_data[dd])
elif isinstance(stat_data[dd], np.float32):
sys_stat[dd] = stat_data[dd]

Check warning on line 65 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L64-L65

Added lines #L64 - L65 were not covered by tests
anyangml marked this conversation as resolved.
Show resolved Hide resolved
else:
pass

for key in sys_stat:
if sys_stat[key] is None or sys_stat[key][0] is None:
if isinstance(sys_stat[key], np.float32):
pass
elif sys_stat[key] is None or sys_stat[key][0] is None:

Check warning on line 72 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L70-L72

Added lines #L70 - L72 were not covered by tests
sys_stat[key] = None
else:
elif isinstance(stat_data[dd], torch.Tensor):

Check warning on line 74 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L74

Added line #L74 was not covered by tests
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)
lst.append(sys_stat)
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@
for key in sample_dict:
if isinstance(sample_dict[key], list):
sample_dict[key] = [item.to(DEVICE) for item in sample_dict[key]]
if isinstance(sample_dict[key], np.float32):
sample_dict[key] = (

Check warning on line 120 in deepmd/pt/utils/utils.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/utils.py#L119-L120

Added lines #L119 - L120 were not covered by tests
torch.ones(1, dtype=torch.float32, device=DEVICE) * sample_dict[key]
)
else:
if sample_dict[key] is not None:
sample_dict[key] = sample_dict[key].to(DEVICE)
3 changes: 2 additions & 1 deletion source/tests/pt/model/test_polar_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def setUp(self) -> None:
find_atomic_polarizability = torch.rand(1, device=env.DEVICE)
self.sampled = [
{
"type": types,
"atype": types,
"find_atomic_polarizability": find_atomic_polarizability,
"atomic_polarizability": atomic_polarizability,
"polarizability": polarizability,
Expand All @@ -40,6 +40,7 @@ def setUp(self) -> None:
self.all_stat = {
k: [v.numpy(force=True)] for d in self.sampled for k, v in d.items()
}
self.all_stat["type"] = self.all_stat.pop("atype")
self.tfpolar = PolarFittingSeA(
ntypes=ntypes,
dim_descrpt=1,
Expand Down