Skip to content

Commit

Permalink
feat: add UTs
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Mar 7, 2024
1 parent b81ec36 commit ee28334
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 9 deletions.
23 changes: 17 additions & 6 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(
self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
).view(ntypes, 1)
self.shift_diag = shift_diag
self.constant_matrix = torch.zeros(ntypes)
self.constant_matrix = torch.zeros(ntypes, device=env.DEVICE)
super().__init__(
var_name=kwargs.pop("var_name", "polar"),
ntypes=ntypes,
Expand Down Expand Up @@ -224,13 +224,17 @@ def compute_output_stats(
sys_bias_redu, sys_type_count
)[0]
cur_constant_matrix = np.zeros(self.ntypes)
if sys_atom_polar.shape[0] < self.ntypes:
# pad zeros in case atype.max() + 1 != ntypes
sys_atom_polar = np.concatenate([sys_atom_polar, np.zeros((self.ntypes - sys_atom_polar.shape[0], 9))])
for itype in range(self.ntypes):
cur_constant_matrix[itype] = torch.mean(
torch.diagonal(sys_atom_polar[itype].reshape(3, 3))
cur_constant_matrix[itype] = np.mean(
np.diagonal(sys_atom_polar[itype].reshape(3, 3))
)
sys_constant_matrix.append(cur_constant_matrix)
constant_matrix = np.stack(sys_constant_matrix).mean(axis=0)

# handle nan values.
constant_matrix = np.nan_to_num(constant_matrix)
self.constant_matrix = torch.tensor(constant_matrix, device=env.DEVICE)
if stat_file_path is not None:
stat_file_path.save_numpy(self.constant_matrix.detach().cpu().numpy())
Expand Down Expand Up @@ -268,10 +272,17 @@ def forward(
) # (nframes * nloc, 3, 3)
out = out.view(nframes, nloc, 3, 3)
if self.shift_diag:
# to handle nan in constant_matrix

# (nframes, nloc, 1, 1)
bias = self.constant_matrix[atype].unsqueeze(-1).unsqueeze(-1)
eye = torch.eye(3, device=env.DEVICE)
eye = eye.repeat(nframes, nloc, 1, 1)
# (nframes, nloc, 3, 3)
bias = bias * eye
out = (
out
+ self.constant_matrix[atype]
* torch.eye(3, device=env.DEVICE)
+ bias
* self.scale[atype]
)

Expand Down
7 changes: 4 additions & 3 deletions deepmd/tf/fit/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def compute_output_stats(self, all_stat):
mean_polar = np.zeros([len(self.sel_type), 9])
sys_matrix, polar_bias = [], []
for ss in range(len(all_stat["type"])):
nframes = all_stat["type"][ss].shape[0]
atom_has_polar = [
w for w in all_stat["type"][ss][0] if (w in self.sel_type)
] # select atom with polar
Expand All @@ -201,8 +202,8 @@ def compute_output_stats(self, all_stat):

polar_bias.append(
np.sum(
all_stat["atomic_polarizability"][ss][:, index_lis, :],
axis=(0, 1),
all_stat["atomic_polarizability"][ss][:, index_lis, :]/nframes,
axis=(0,1),
).reshape((1, 9))
)
else: # No atomic polar in this system, so it should have global polar
Expand All @@ -226,7 +227,7 @@ def compute_output_stats(self, all_stat):
sys_matrix[-1][0, itype] = len(index_lis)

# add polar_bias
polar_bias.append(all_stat["polarizability"][ss].reshape((1, 9)))
polar_bias.append(np.mean(all_stat["polarizability"][ss],axis=0).reshape((1, 9)))

matrix, bias = (
np.concatenate(sys_matrix, axis=0),
Expand Down
59 changes: 59 additions & 0 deletions source/tests/consistent/model/test_polar_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import unittest
import numpy as np
import torch
from deepmd.pt.utils import (
env,
)
from deepmd.pt.model.task.polarizability import PolarFittingNet
from deepmd.tf.fit.polar import PolarFittingSeA
from deepmd.pt.utils.utils import (
to_numpy_array,
)

class TestConsistency(unittest.TestCase):
def setUp(self) -> None:
types = torch.randint(0,4,(1,5), device=env.DEVICE)
types = torch.cat((types,types,types), dim=0)
ntypes = 4
atomic_polarizability = torch.rand(3,5,9)
polarizability = torch.rand(3,9)
find_polarizability = torch.rand(1)
find_atomic_polarizability = torch.rand(1)
self.sampled = [{"type": types,
"find_atomic_polarizability": find_atomic_polarizability,
"atomic_polarizability": atomic_polarizability,
"polarizability": polarizability,
"find_polarizability": find_polarizability}]
self.all_stat = {k: [v.numpy()] for d in self.sampled for k, v in d.items()}
self.tfpolar = PolarFittingSeA(
ntypes=ntypes,
dim_descrpt=1,
embedding_width=1,
sel_type=[i for i in range(ntypes)],
)
self.ptpolar = PolarFittingNet(
ntypes=ntypes,
dim_descrpt=1,
embedding_width=1,
)


def test_atomic_consistency(self):
self.tfpolar.compute_output_stats(self.all_stat)
tfbias = self.tfpolar.constant_matrix
self.ptpolar.compute_output_stats(self.sampled)
ptbias = self.ptpolar.constant_matrix
print(tfbias, to_numpy_array(ptbias))
np.testing.assert_allclose(tfbias, to_numpy_array(ptbias))

def test_global_consistency(self):
self.sampled[0]["find_atomic_polarizability"] = -1
self.sampled[0]["polarizability"] = self.sampled[0]["atomic_polarizability"].sum(dim=1)
self.all_stat["find_atomic_polarizability"] = [-1]
self.all_stat["polarizability"] = [(self.all_stat["atomic_polarizability"][0]).sum(axis=1)]
self.tfpolar.compute_output_stats(self.all_stat)
tfbias = self.tfpolar.constant_matrix
self.ptpolar.compute_output_stats(self.sampled)
ptbias = self.ptpolar.constant_matrix
print(tfbias, to_numpy_array(ptbias))
np.testing.assert_allclose(tfbias, to_numpy_array(ptbias))

0 comments on commit ee28334

Please sign in to comment.