Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add polar stat constant matrix calculation to PT #3426

Merged
merged 35 commits into from
Mar 11, 2024
Merged
Changes from 2 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
97548d3
feat: add constant_matrix calc
anyangml Mar 6, 2024
1ae79a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
2da5a35
feat: add output
anyangml Mar 6, 2024
fa2dc20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
abb3ca4
fix: typo
anyangml Mar 6, 2024
2503147
fix UTs
anyangml Mar 7, 2024
38969ff
fix: reuse out_stat
anyangml Mar 7, 2024
1734875
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
b261abc
fix: atomic stat
anyangml Mar 7, 2024
b81ec36
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
ee28334
feat: add UTs
anyangml Mar 7, 2024
a742f49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
590aeba
fix: precommit
anyangml Mar 7, 2024
7748723
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
254c198
fix: precommit
anyangml Mar 7, 2024
8699c28
fix: UTs
anyangml Mar 7, 2024
4d22a5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
62e0386
fix UTs
anyangml Mar 7, 2024
78f1483
fix: CUDA
anyangml Mar 7, 2024
b9090c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
cd3ac47
fix: serialize
anyangml Mar 8, 2024
f416c99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
71ae5e3
Merge branch 'devel' into feat/polar_stat
anyangml Mar 8, 2024
cacbfb6
chore: bump version
anyangml Mar 8, 2024
4e8f8fe
fix: refacotr version check
anyangml Mar 8, 2024
3de4709
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
831158c
chore: refactor
anyangml Mar 8, 2024
543fb6e
Merge branch 'devel' into feat/polar_stat
anyangml Mar 8, 2024
5b3cb4a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
2ffca0b
fix: precommit
anyangml Mar 8, 2024
4f63e95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
e4086fb
fix: precommit
anyangml Mar 8, 2024
edee7e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
042c86b
Merge branch 'devel' into feat/polar_stat
anyangml Mar 11, 2024
9ad2c97
fix: typo
anyangml Mar 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
).view(ntypes, 1)
self.shift_diag = shift_diag
self.constant_matrix = torch.zero(self.ntypes)
super().__init__(
var_name=kwargs.pop("var_name", "polar"),
ntypes=ntypes,
Expand Down Expand Up @@ -184,7 +185,63 @@
The path to the stat file.

"""
pass
if self.shift_diag:
if stat_file_path is not None:
stat_file_path = stat_file_path / "constant_matrix"
if stat_file_path is not None and stat_file_path.is_file():
constant_matrix = stat_file_path.load_numpy()
else:
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged

sys_matrix, polar_bias = [], []
for sys in len(sampled):
Fixed Show fixed Hide fixed
if sampled[sys]["find_atomic_polarizability"] > 0.0:
for itype in range(self.ntypes):
# this is a tensor of shape nframes, nall
type_mask = sampled[sys]["type"] == itype
sys_matrix.append(torch.zeros((1, self.ntypes)))
# this gives the number of atoms of type itype in the system
sys_matrix[-1][0, itype] = type_mask.sum().item()
expanded_mask = type_mask.unsqueeze(-1).expand(
(*type_mask.shape, 9)
)
polar_bias.append(
torch.sum(
(
sampled[sys]["atomic_polarizability"]
* expanded_mask
).reshape(-1, 9),
dim=0,
).reshape((1, 9))
)
else:
if not sampled[sys]["find_polarizability"] > 0.0:
continue
sys_matrix.append(torch.zeros((1, self.ntypes)))
for itype in range(self.ntypes):
type_mask = sampled[sys]["type"] == itype
sys_matrix[-1][0, itype] = type_mask.sum().item()
polar_bias.append(
sampled[sys]["polarizability"].reshape((1, 9))
)
matrix, bias = (
torch.cat(sys_matrix, dim=0),
torch.cat(polar_bias, dim=0),
)
atom_polar, _, _, _ = torch.linalg.lstsq(matrix, bias, rcond=None)
constant_matrix = []
for itype in range(self.ntypes):
constant_matrix.append(
torch.mean(torch.diagonal(atom_polar[itype].reshape(3, 3)))
)

self.constant_matrix = torch.tensor(constant_matrix, device=env.DEVICE)
if stat_file_path is not None:
stat_file_path.save_numpy(self.constant_matrix.detach().cpu().numpy())
anyangml marked this conversation as resolved.
Show resolved Hide resolved

def forward(
self,
Expand Down
Loading