Skip to content

Commit

Permalink
feat: add constant_matrix calc
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Mar 6, 2024
1 parent fa8e645 commit 97548d3
Showing 1 changed file with 55 additions and 1 deletion.
56 changes: 55 additions & 1 deletion deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
).view(ntypes, 1)
self.shift_diag = shift_diag
self.constant_matrix = torch.zero(self.ntypes)
super().__init__(
var_name=kwargs.pop("var_name", "polar"),
ntypes=ntypes,
Expand Down Expand Up @@ -184,7 +185,60 @@ def compute_output_stats(
The path to the stat file.
"""
pass
if self.shift_diag:
if stat_file_path is not None:
stat_file_path = stat_file_path / "constant_matrix"
if stat_file_path is not None and stat_file_path.is_file():
constant_matrix = stat_file_path.load_numpy()
else:
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged

sys_matrix, polar_bias = [], []
for sys in len(sampled):
if sampled[sys]["find_atomic_polarizability"] > 0.0:
for itype in range(self.ntypes):
# this is a tensor of shape nframes, nall
type_mask = sampled[sys]["type"] == itype
sys_matrix.append(torch.zeros((1, self.ntypes)))
# this gives the number of atoms of type itype in the system
sys_matrix[-1][0, itype] = type_mask.sum().item()
expanded_mask = type_mask.unsqueeze(-1).expand((*type_mask.shape, 9))
polar_bias.append(
torch.sum(
(sampled[sys]["atomic_polarizability"]*expanded_mask).reshape(-1,9),
dim=0,
).reshape((1, 9))
)
else:
if (not sampled[sys]["find_polarizability"] > 0.0):
continue
sys_matrix.append(torch.zeros((1, self.ntypes)))
for itype in range(self.ntypes):
type_mask = sampled[sys]["type"] == itype
sys_matrix[-1][0, itype] = type_mask.sum().item()
polar_bias.append(sampled[sys]["polarizability"].reshape((1, 9)))
matrix, bias = (
torch.cat(sys_matrix, dim=0),
torch.cat(polar_bias, dim=0),
)
atom_polar, _, _, _ = torch.linalg.lstsq(matrix, bias, rcond=None)
constant_matrix = []
for itype in range(self.ntypes):
constant_matrix.append(
torch.mean(torch.diagonal(atom_polar[itype].reshape(3,3)))
)

self.constant_matrix = torch.tensor(constant_matrix, device=env.DEVICE)
if stat_file_path is not None:
stat_file_path.save_numpy(self.constant_matrix.detach().cpu().numpy())





def forward(
self,
Expand Down

0 comments on commit 97548d3

Please sign in to comment.