Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add polar stat constant matrix calculation to PT #3426

Merged
merged 35 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
97548d3
feat: add constant_matrix calc
anyangml Mar 6, 2024
1ae79a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
2da5a35
feat: add output
anyangml Mar 6, 2024
fa2dc20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
abb3ca4
fix: typo
anyangml Mar 6, 2024
2503147
fix UTs
anyangml Mar 7, 2024
38969ff
fix: reuse out_stat
anyangml Mar 7, 2024
1734875
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
b261abc
fix: atomic stat
anyangml Mar 7, 2024
b81ec36
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
ee28334
feat: add UTs
anyangml Mar 7, 2024
a742f49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
590aeba
fix: precommit
anyangml Mar 7, 2024
7748723
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
254c198
fix: precommit
anyangml Mar 7, 2024
8699c28
fix: UTs
anyangml Mar 7, 2024
4d22a5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
62e0386
fix UTs
anyangml Mar 7, 2024
78f1483
fix: CUDA
anyangml Mar 7, 2024
b9090c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
cd3ac47
fix: serialize
anyangml Mar 8, 2024
f416c99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
71ae5e3
Merge branch 'devel' into feat/polar_stat
anyangml Mar 8, 2024
cacbfb6
chore: bump version
anyangml Mar 8, 2024
4e8f8fe
fix: refacotr version check
anyangml Mar 8, 2024
3de4709
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
831158c
chore: refactor
anyangml Mar 8, 2024
543fb6e
Merge branch 'devel' into feat/polar_stat
anyangml Mar 8, 2024
5b3cb4a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
2ffca0b
fix: precommit
anyangml Mar 8, 2024
4f63e95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
e4086fb
fix: precommit
anyangml Mar 8, 2024
edee7e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
042c86b
Merge branch 'devel' into feat/polar_stat
anyangml Mar 11, 2024
9ad2c97
fix: typo
anyangml Mar 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Union,
)

import numpy as np

Check warning on line 10 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L10

Added line #L10 was not covered by tests
import torch

from deepmd.dpmodel import (
Expand All @@ -25,6 +26,10 @@
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.out_stat import (

Check warning on line 29 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L29

Added line #L29 was not covered by tests
compute_stats_from_atomic,
compute_stats_from_redu,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -114,6 +119,7 @@
self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
).view(ntypes, 1)
self.shift_diag = shift_diag
self.constant_matrix = torch.zeros(ntypes)

Check warning on line 122 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L122

Added line #L122 was not covered by tests
super().__init__(
var_name=kwargs.pop("var_name", "polar"),
ntypes=ntypes,
Expand Down Expand Up @@ -184,7 +190,50 @@
The path to the stat file.

"""
pass
if self.shift_diag:
if stat_file_path is not None:
stat_file_path = stat_file_path / "constant_matrix"
if stat_file_path is not None and stat_file_path.is_file():
constant_matrix = stat_file_path.load_numpy()

Check warning on line 197 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L193-L197

Added lines #L193 - L197 were not covered by tests
else:
if callable(merged):

Check warning on line 199 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L199

Added line #L199 was not covered by tests
# only get data for once
sampled = merged()

Check warning on line 201 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L201

Added line #L201 was not covered by tests
else:
sampled = merged

Check warning on line 203 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L203

Added line #L203 was not covered by tests

sys_constant_matrix = []
for sys in range(len(sampled)):
nframs = sampled[sys]["type"].shape[0]

Check warning on line 207 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L205-L207

Added lines #L205 - L207 were not covered by tests

if sampled[sys]["find_atomic_polarizability"] > 0.0:
sys_atom_polar = compute_stats_from_atomic(

Check warning on line 210 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L209-L210

Added lines #L209 - L210 were not covered by tests
sampled[sys]["atomic_polarizability"], sampled[sys]["type"]
anyangml marked this conversation as resolved.
Show resolved Hide resolved
)[0]
else:
if not sampled[sys]["find_polarizability"] > 0.0:
continue
sys_type_count = np.zeros((nframs, self.ntypes))
anyangml marked this conversation as resolved.
Show resolved Hide resolved
for itype in range(self.ntypes):
type_mask = sampled[sys]["type"] == itype
sys_type_count[:, itype] = type_mask.sum(dim=1)

Check warning on line 219 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L214-L219

Added lines #L214 - L219 were not covered by tests

sys_bias_redu = sampled[sys]["polarizability"]

Check warning on line 221 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L221

Added line #L221 was not covered by tests

sys_atom_polar = compute_stats_from_redu(

Check warning on line 223 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L223

Added line #L223 was not covered by tests
sys_bias_redu, sys_type_count
)[0]
anyangml marked this conversation as resolved.
Show resolved Hide resolved
cur_constant_matrix = np.zeros(self.ntypes)
anyangml marked this conversation as resolved.
Show resolved Hide resolved
for itype in range(self.ntypes):
cur_constant_matrix[itype] = torch.mean(

Check warning on line 228 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L226-L228

Added lines #L226 - L228 were not covered by tests
torch.diagonal(sys_atom_polar[itype].reshape(3, 3))
)
sys_constant_matrix.append(cur_constant_matrix)
constant_matrix = np.stack(sys_constant_matrix).mean(axis=0)

Check warning on line 232 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L231-L232

Added lines #L231 - L232 were not covered by tests

self.constant_matrix = torch.tensor(constant_matrix, device=env.DEVICE)
if stat_file_path is not None:
stat_file_path.save_numpy(self.constant_matrix.detach().cpu().numpy())

Check warning on line 236 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L234-L236

Added lines #L234 - L236 were not covered by tests
anyangml marked this conversation as resolved.
Show resolved Hide resolved

def forward(
self,
Expand Down Expand Up @@ -218,5 +267,12 @@
"bim,bmj->bij", gr.transpose(1, 2), out
) # (nframes * nloc, 3, 3)
out = out.view(nframes, nloc, 3, 3)
if self.shift_diag:
out = (

Check warning on line 271 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L270-L271

Added lines #L270 - L271 were not covered by tests
out
+ self.constant_matrix[atype]
* torch.eye(3, device=env.DEVICE)
* self.scale[atype]
)

return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}
8 changes: 3 additions & 5 deletions deepmd/tf/fit/polar.py
anyangml marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -193,18 +193,16 @@ def compute_output_stats(self, all_stat):
index_lis = [
index
for index, w in enumerate(atom_has_polar)
if atom_has_polar[index] == self.sel_type[itype]
if w == self.sel_type[itype]
] # select index in this type

sys_matrix.append(np.zeros((1, len(self.sel_type))))
sys_matrix[-1][0, itype] = len(index_lis)

polar_bias.append(
np.sum(
all_stat["atomic_polarizability"][ss].reshape((-1, 9))[
index_lis
],
axis=0,
all_stat["atomic_polarizability"][ss][:, index_lis, :],
axis=(0, 1),
anyangml marked this conversation as resolved.
Show resolved Hide resolved
).reshape((1, 9))
)
else: # No atomic polar in this system, so it should have global polar
Expand Down
Loading