diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index fa4f6d7f37..a35fae90bf 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -114,6 +114,7 @@ def __init__( self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE ).view(ntypes, 1) self.shift_diag = shift_diag + self.constant_matrix = torch.zero(self.ntypes) super().__init__( var_name=kwargs.pop("var_name", "polar"), ntypes=ntypes, @@ -184,7 +185,60 @@ def compute_output_stats( The path to the stat file. """ - pass + if self.shift_diag: + if stat_file_path is not None: + stat_file_path = stat_file_path / "constant_matrix" + if stat_file_path is not None and stat_file_path.is_file(): + constant_matrix = stat_file_path.load_numpy() + else: + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + + sys_matrix, polar_bias = [], [] + for sys in len(sampled): + if sampled[sys]["find_atomic_polarizability"] > 0.0: + for itype in range(self.ntypes): + # this is a tensor of shape nframes, nall + type_mask = sampled[sys]["type"] == itype + sys_matrix.append(torch.zeros((1, self.ntypes))) + # this gives the number of atoms of type itype in the system + sys_matrix[-1][0, itype] = type_mask.sum().item() + expanded_mask = type_mask.unsqueeze(-1).expand((*type_mask.shape, 9)) + polar_bias.append( + torch.sum( + (sampled[sys]["atomic_polarizability"]*expanded_mask).reshape(-1,9), + dim=0, + ).reshape((1, 9)) + ) + else: + if (not sampled[sys]["find_polarizability"] > 0.0): + continue + sys_matrix.append(torch.zeros((1, self.ntypes))) + for itype in range(self.ntypes): + type_mask = sampled[sys]["type"] == itype + sys_matrix[-1][0, itype] = type_mask.sum().item() + polar_bias.append(sampled[sys]["polarizability"].reshape((1, 9))) + matrix, bias = ( + torch.cat(sys_matrix, dim=0), + torch.cat(polar_bias, dim=0), + ) + atom_polar, _, _, _ = torch.linalg.lstsq(matrix, bias, rcond=None) + constant_matrix = [] + for itype in range(self.ntypes): + constant_matrix.append( + torch.mean(torch.diagonal(atom_polar[itype].reshape(3,3))) + ) + + self.constant_matrix = torch.tensor(constant_matrix, device=env.DEVICE) + if stat_file_path is not None: + stat_file_path.save_numpy(self.constant_matrix.detach().cpu().numpy()) + + + + def forward( self,