diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 9c4f67c567..0d369671e6 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -28,6 +28,7 @@ ) from deepmd.utils.out_stat import ( compute_stats_from_redu, + compute_stats_from_atomic, ) from deepmd.utils.path import ( DPPath, @@ -204,26 +205,29 @@ def compute_output_stats( sys_constant_matrix = [] for sys in range(len(sampled)): nframs = sampled[sys]["type"].shape[0] - sys_type_count = np.zeros((nframs, self.ntypes)) - for itype in range(self.ntypes): - type_mask = sampled[sys]["type"] == itype - sys_type_count[:, itype] = type_mask.sum(dim=1) if sampled[sys]["find_atomic_polarizability"] > 0.0: - sys_bias_redu = sampled[sys]["atomic_polarizability"].sum(dim=1) + sys_atom_polar = compute_stats_from_atomic( + sampled[sys]["atomic_polarizability"], + sampled[sys]["type"] + )[0] else: if not sampled[sys]["find_polarizability"] > 0.0: continue - + sys_type_count = np.zeros((nframs, self.ntypes)) + for itype in range(self.ntypes): + type_mask = sampled[sys]["type"] == itype + sys_type_count[:, itype] = type_mask.sum(dim=1) + sys_bias_redu = sampled[sys]["polarizability"] - sys_atom_polar = compute_stats_from_redu( - sys_type_count, sys_bias_redu - )[0] + sys_atom_polar = compute_stats_from_redu( + sys_bias_redu, sys_type_count + )[0] cur_constant_matrix = np.zeros(self.ntypes) for itype in range(self.ntypes): cur_constant_matrix[itype] = torch.mean( - torch.diagonal(sys_atom_polar.T[itype].reshape(3, 3)) + torch.diagonal(sys_atom_polar[itype].reshape(3, 3)) ) sys_constant_matrix.append(cur_constant_matrix) constant_matrix = np.stack(sys_constant_matrix).mean(axis=0) diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index 7ac31809f3..57070633ba 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -193,7 +193,7 @@ def compute_output_stats(self, all_stat): index_lis = [ index for index, w in enumerate(atom_has_polar) - if atom_has_polar[index] == self.sel_type[itype] + if w == self.sel_type[itype] ] # select index in this type sys_matrix.append(np.zeros((1, len(self.sel_type)))) @@ -201,10 +201,8 @@ def compute_output_stats(self, all_stat): polar_bias.append( np.sum( - all_stat["atomic_polarizability"][ss].reshape((-1, 9))[ - index_lis - ], - axis=0, + all_stat["atomic_polarizability"][ss][:, index_lis, :], + axis=(0,1), ).reshape((1, 9)) ) else: # No atomic polar in this system, so it should have global polar