diff --git a/deepmd/model_format/se_e2_a.py b/deepmd/model_format/se_e2_a.py index fe516c8620..28751cad8d 100644 --- a/deepmd/model_format/se_e2_a.py +++ b/deepmd/model_format/se_e2_a.py @@ -223,7 +223,18 @@ def call( Returns ------- descriptor - The descriptor. shape: nf x nloc x ng x axis_neuron + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. """ # nf x nloc x nnei x 4 rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd) @@ -238,15 +249,17 @@ def call( gg = self.cal_g(ss, tt) # nf x nloc x ng x 4 gr += np.einsum("flni,flnj->flij", gg, tr) + # nf x nloc x ng x 4 gr /= self.nnei gr1 = gr[:, :, : self.axis_neuron, :] # nf x nloc x ng x ng1 grrg = np.einsum("flid,fljd->flij", gr, gr1) # nf x nloc x (ng x ng1) grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron) - return grrg + return grrg, gr[..., 1:], None, None, ww def serialize(self) -> dict: + """Serialize the descriptor to dict.""" return { "rcut": self.rcut, "rcut_smth": self.rcut_smth, @@ -271,6 +284,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "DescrptSeA": + """Deserialize from dict.""" data = copy.deepcopy(data) variables = data.pop("@variables") embeddings = data.pop("embeddings") diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index dd34b815c9..23f521b6d8 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -135,12 +135,42 @@ def forward( nlist: torch.Tensor, mapping: Optional[torch.Tensor] = None, ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, not required by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ del mapping nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 g1_ext = self.type_embedding(extended_atype) g1_inp = g1_ext[:, :nloc, :] - g1, env_mat, diff, rot_mat, sw = self.se_atten( + g1, g2, h2, rot_mat, sw = self.se_atten( nlist, extended_coord, extended_atype, @@ -149,4 +179,5 @@ def forward( ) if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, env_mat, diff, rot_mat, sw + + return g1, rot_mat, g2, h2, sw diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index fbdbc91dd9..409b999262 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -329,6 +329,36 @@ def forward( nlist: torch.Tensor, mapping: Optional[torch.Tensor] = None, ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, mapps extended region index to local region. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 # nlists @@ -372,4 +402,4 @@ def forward( ) if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, g2, h2, rot_mat, sw + return g1, rot_mat, g2, h2, sw diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 26887b1b75..141b5dc745 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -256,7 +256,7 @@ def forward( # (nb x nloc) x ng2 x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) - return g1, g2, h2, rot_mat.view(-1, self.dim_emb, 3), sw + return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw def compute_input_stats(self, merged): """Update mean and stddev for descriptor elements.""" diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 10aa66311e..3f42736dca 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -115,12 +115,42 @@ def get_data_process_key(cls, config): def forward( self, - extended_coord: torch.Tensor, - extended_atype: torch.Tensor, + coord_ext: torch.Tensor, + atype_ext: torch.Tensor, nlist: torch.Tensor, mapping: Optional[torch.Tensor] = None, ): - return self.sea.forward(nlist, extended_coord, extended_atype, None, mapping) + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, not required by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. + + """ + return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping) def set_stat_mean_and_stddev( self, @@ -389,7 +419,7 @@ def forward( del extended_atype_embd, mapping nloc = nlist.shape[1] atype = extended_atype[:, :nloc] - dmatrix, diff, _ = prod_env_mat_se_a( + dmatrix, diff, sw = prod_env_mat_se_a( extended_coord, nlist, atype, @@ -438,12 +468,14 @@ def forward( result = torch.matmul( xyz_scatter_1, xyz_scatter_2 ) # shape is [nframes*nall, self.filter_neuron[-1], self.axis_neuron] + result = result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron) + rot_mat = rot_mat.view([-1, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005 return ( - result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), - None, - None, + result, + rot_mat, None, None, + sw, ) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 0c932f42f2..78cba59da7 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -281,9 +281,8 @@ def forward( self.rcut, self.rcut_smth, ) - dmatrix = dmatrix.view( - -1, self.ndescrpt - ) # shape is [nframes*nall, self.ndescrpt] + # [nfxnlocxnnei, self.ndescrpt] + dmatrix = dmatrix.view(-1, self.ndescrpt) nlist_mask = nlist != -1 nlist[nlist == -1] = 0 sw = torch.squeeze(sw, -1) @@ -328,7 +327,7 @@ def forward( return ( result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), ret.view(-1, nloc, self.nnei, self.filter_neuron[-1]), - diff, + dmatrix.view(-1, nloc, self.nnei, 4)[..., 1:], rot_mat.view(-1, self.filter_neuron[-1], 3), sw, ) diff --git a/deepmd/pt/model/model/dp_atomic_model.py b/deepmd/pt/model/model/dp_atomic_model.py index a0f9b25765..853eacb875 100644 --- a/deepmd/pt/model/model/dp_atomic_model.py +++ b/deepmd/pt/model/model/dp_atomic_model.py @@ -14,7 +14,6 @@ Descriptor, ) from deepmd.pt.model.task import ( - DenoiseNet, Fitting, ) @@ -93,40 +92,20 @@ def __init__( sampled=sampled, ) - # Fitting - if fitting_net: - fitting_net["type"] = fitting_net.get("type", "ener") - if self.descriptor_type not in ["se_e2_a"]: - fitting_net["ntypes"] = 1 - else: - fitting_net["ntypes"] = self.descriptor.get_ntype() - fitting_net["use_tebd"] = False - fitting_net["embedding_width"] = self.descriptor.dim_out - - self.grad_force = "direct" not in fitting_net["type"] - if not self.grad_force: - fitting_net["out_dim"] = self.descriptor.dim_emb - if "ener" in fitting_net["type"]: - fitting_net["return_energy"] = True - self.fitting_net = Fitting(**fitting_net) + fitting_net["type"] = fitting_net.get("type", "ener") + if self.descriptor_type not in ["se_e2_a"]: + fitting_net["ntypes"] = 1 else: - self.fitting_net = None - self.grad_force = False - if not self.split_nlist: - self.coord_denoise_net = DenoiseNet( - self.descriptor.dim_out, self.ntypes - 1, self.descriptor.dim_emb - ) - elif self.combination: - self.coord_denoise_net = DenoiseNet( - self.descriptor.dim_out, - self.ntypes - 1, - self.descriptor.dim_emb_list, - self.prefactor, - ) - else: - self.coord_denoise_net = DenoiseNet( - self.descriptor.dim_out, self.ntypes - 1, self.descriptor.dim_emb - ) + fitting_net["ntypes"] = self.descriptor.get_ntype() + fitting_net["use_tebd"] = False + fitting_net["embedding_width"] = self.descriptor.dim_out + + self.grad_force = "direct" not in fitting_net["type"] + if not self.grad_force: + fitting_net["out_dim"] = self.descriptor.dim_emb + if "ener" in fitting_net["type"]: + fitting_net["return_energy"] = True + self.fitting_net = Fitting(**fitting_net) def get_fitting_output_def(self) -> FittingOutputDef: """Get the output def of the fitting net.""" @@ -178,7 +157,7 @@ def forward_atomic( atype = extended_atype[:, :nloc] if self.do_grad(): extended_coord.requires_grad_(True) - descriptor, env_mat, diff, rot_mat, sw = self.descriptor( + descriptor, rot_mat, g2, h2, sw = self.descriptor( extended_coord, extended_atype, nlist, @@ -186,23 +165,5 @@ def forward_atomic( ) assert descriptor is not None # energy, force - if self.fitting_net is not None: - fit_ret = self.fitting_net( - descriptor, atype, atype_tebd=None, rot_mat=rot_mat - ) - # denoise - else: - nlist_list = [nlist] - if not self.split_nlist: - nnei_mask = nlist != -1 - elif self.combination: - nnei_mask = [] - for item in nlist_list: - nnei_mask_item = item != -1 - nnei_mask.append(nnei_mask_item) - else: - env_mat = env_mat[-1] - diff = diff[-1] - nnei_mask = nlist_list[-1] != -1 - fit_ret = self.coord_denoise_net(env_mat, diff, nnei_mask, descriptor, sw) + fit_ret = self.fitting_net(descriptor, atype, atype_tebd=None, rot_mat=rot_mat) return fit_ret diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 7ddcbd5c54..03043e2fcb 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -207,8 +207,11 @@ def forward( inputs ) # Shape is [nframes, nloc, m1] assert list(vec_out.size()) == [nframes, nloc, self.out_dim] + # (nf x nloc) x 1 x od vec_out = vec_out.view(-1, 1, self.out_dim) assert rot_mat is not None + # (nf x nloc) x od x 3 + rot_mat = rot_mat.view(-1, self.out_dim, 3) vec_out = ( torch.bmm(vec_out, rot_mat).squeeze(-2).view(nframes, nloc, 3) ) # Shape is [nframes, nloc, 3] diff --git a/source/tests/common/test_model_format_utils.py b/source/tests/common/test_model_format_utils.py index 22393515ec..da76c53ed9 100644 --- a/source/tests/common/test_model_format_utils.py +++ b/source/tests/common/test_model_format_utils.py @@ -367,4 +367,5 @@ def test_self_consistency( em1 = DescrptSeA.deserialize(em0.serialize()) mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist) mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist) - np.testing.assert_allclose(mm0, mm1) + for ii in [0, 1, 4]: + np.testing.assert_allclose(mm0[ii], mm1[ii]) diff --git a/source/tests/pt/test_permutation_denoise.py b/source/tests/pt/test_permutation_denoise.py index 47bd0360f2..6dd61ab7e4 100644 --- a/source/tests/pt/test_permutation_denoise.py +++ b/source/tests/pt/test_permutation_denoise.py @@ -66,6 +66,7 @@ def test( ) +@unittest.skip("support of the denoise is temporally disabled") class TestDenoiseModelDPA1(unittest.TestCase, PermutationDenoiseTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) @@ -74,6 +75,7 @@ def setUp(self): self.model = get_model(model_params, sampled).to(env.DEVICE) +@unittest.skip("support of the denoise is temporally disabled") class TestDenoiseModelDPA2(unittest.TestCase, PermutationDenoiseTest): def setUp(self): model_params_sample = copy.deepcopy(model_dpa2) diff --git a/source/tests/pt/test_rot_denoise.py b/source/tests/pt/test_rot_denoise.py index cab8de7bec..2cbfd8fd38 100644 --- a/source/tests/pt/test_rot_denoise.py +++ b/source/tests/pt/test_rot_denoise.py @@ -97,6 +97,7 @@ def test( ) +@unittest.skip("support of the denoise is temporally disabled") class TestDenoiseModelDPA1(unittest.TestCase, RotDenoiseTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) @@ -105,6 +106,7 @@ def setUp(self): self.model = get_model(model_params, sampled).to(env.DEVICE) +@unittest.skip("support of the denoise is temporally disabled") class TestDenoiseModelDPA2(unittest.TestCase, RotDenoiseTest): def setUp(self): model_params_sample = copy.deepcopy(model_dpa2) diff --git a/source/tests/pt/test_se_e2_a.py b/source/tests/pt/test_se_e2_a.py index 96a17c2bad..c0a106cb16 100644 --- a/source/tests/pt/test_se_e2_a.py +++ b/source/tests/pt/test_se_e2_a.py @@ -102,7 +102,7 @@ def test_consistency( ) # serialization dd1 = DescrptSeA.deserialize(dd0.serialize()) - rd1, _, _, _, _ = dd1( + rd1, gr1, _, _, sw1 = dd1( torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), torch.tensor(self.nlist, dtype=int, device=env.DEVICE), @@ -116,18 +116,19 @@ def test_consistency( ) # dp impl dd2 = DPDescrptSeA.deserialize(dd0.serialize()) - rd2 = dd2.call( + rd2, gr2, _, _, sw2 = dd2.call( self.coord_ext, self.atype_ext, self.nlist, ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy(), - rd2, - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) + for aa, bb in zip([rd1, gr1, sw1], [rd2, gr2, sw2]): + np.testing.assert_allclose( + aa.detach().cpu().numpy(), + bb, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) # old impl if idt is False and prec == "float64": dd3 = DescrptSeA( @@ -154,18 +155,19 @@ def test_consistency( dd3_state_dict[i] = dd3_state_dict[i].unsqueeze(0) dd3.sea.load_state_dict(dd3_state_dict) - rd3, _, _, _, _ = dd3( + rd3, gr3, _, _, sw3 = dd3( torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), torch.tensor(self.nlist, dtype=int, device=env.DEVICE), ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy(), - rd3.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) + for aa, bb in zip([rd1, gr1, sw1], [rd3, gr3, sw3]): + np.testing.assert_allclose( + aa.detach().cpu().numpy(), + bb.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) def test_jit( self, diff --git a/source/tests/pt/test_smooth_denoise.py b/source/tests/pt/test_smooth_denoise.py index a66e5df957..de89f8dccc 100644 --- a/source/tests/pt/test_smooth_denoise.py +++ b/source/tests/pt/test_smooth_denoise.py @@ -96,6 +96,7 @@ def compare(ret0, ret1): compare(ret0, ret3) +@unittest.skip("support of the denoise is temporally disabled") class TestDenoiseModelDPA2(unittest.TestCase, SmoothDenoiseTest): def setUp(self): model_params_sample = copy.deepcopy(model_dpa2) @@ -116,6 +117,7 @@ def setUp(self): self.aprec = 1e-5 +@unittest.skip("support of the denoise is temporally disabled") class TestDenoiseModelDPA2_1(unittest.TestCase, SmoothDenoiseTest): def setUp(self): model_params_sample = copy.deepcopy(model_dpa2) diff --git a/source/tests/pt/test_trans_denoise.py b/source/tests/pt/test_trans_denoise.py index 360633278c..88b926a3ae 100644 --- a/source/tests/pt/test_trans_denoise.py +++ b/source/tests/pt/test_trans_denoise.py @@ -56,6 +56,7 @@ def test( torch.testing.assert_close(ret0["logits"], ret1["logits"], rtol=prec, atol=prec) +@unittest.skip("support of the denoise is temporally disabled") class TestDenoiseModelDPA1(unittest.TestCase, TransDenoiseTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) @@ -64,6 +65,7 @@ def setUp(self): self.model = get_model(model_params, sampled).to(env.DEVICE) +@unittest.skip("support of the denoise is temporally disabled") class TestDenoiseModelDPA2(unittest.TestCase, TransDenoiseTest): def setUp(self): model_params_sample = copy.deepcopy(model_dpa2)