From f5f73703d0cefb2ace2dc226f5b3c325d234ac5c Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Nov 2024 10:38:01 +0800 Subject: [PATCH] fix uts --- deepmd/pt/model/task/dipole.py | 2 -- deepmd/pt/model/task/fitting.py | 8 ++++---- deepmd/pt/model/task/invar_fitting.py | 5 ++++- deepmd/pt/model/task/polarizability.py | 4 +--- source/tests/pt/model/test_dipole_fitting.py | 4 ++-- source/tests/pt/model/test_polarizability_fitting.py | 2 +- source/tests/pt/model/test_property_fitting.py | 4 ++-- 7 files changed, 14 insertions(+), 15 deletions(-) diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 1642937652..bc09fa4d0f 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -186,8 +186,6 @@ def forward( ] # (nframes * nloc, 1, m1) out = out.view(-1, 1, self.embedding_width) - # cast from global to gr precision again - out = out.to(dtype=gr.dtype) # (nframes * nloc, m1, 3) gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes, nloc, 3) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 7897c7dce9..798e271c8f 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -477,7 +477,7 @@ def _forward_common( outs = torch.zeros( (nf, nloc, net_dim_out), - dtype=env.GLOBAL_PT_FLOAT_PRECISION, + dtype=self.prec, device=descriptor.device, ) # jit assertion if self.mixed_types: @@ -485,7 +485,7 @@ def _forward_common( if xx_zeros is not None: atom_property -= self.filter_layers.networks[0](xx_zeros) outs = ( - outs + atom_property + self.bias_atom_e[atype] + outs + atom_property + self.bias_atom_e[atype].to(self.prec) ) # Shape is [nframes, natoms[0], net_dim_out] else: for type_i, ll in enumerate(self.filter_layers.networks): @@ -500,7 +500,7 @@ def _forward_common( and not self.remove_vaccum_contribution[type_i] ): atom_property -= ll(xx_zeros) - atom_property = atom_property + self.bias_atom_e[type_i] + atom_property = atom_property + self.bias_atom_e[type_i].to(self.prec) atom_property = torch.where(mask, atom_property, 0.0) outs = ( outs + atom_property @@ -509,4 +509,4 @@ def _forward_common( mask = self.emask(atype).to(torch.bool) # nf x nloc x nod outs = torch.where(mask[:, :, None], outs, 0.0) - return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + return {self.var_name: outs} diff --git a/deepmd/pt/model/task/invar_fitting.py b/deepmd/pt/model/task/invar_fitting.py index acdd5b0fda..c339f4690d 100644 --- a/deepmd/pt/model/task/invar_fitting.py +++ b/deepmd/pt/model/task/invar_fitting.py @@ -177,7 +177,10 @@ def forward( ------- - `torch.Tensor`: Total energy with shape [nframes, natoms[0]]. """ - return self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam) + out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ + self.var_name + ] + return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} # make jit happy with torch 2.0.0 exclude_types: list[int] diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 36a205b637..6ec7635377 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -237,9 +237,7 @@ def forward( out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] - out = out * (self.scale.to(atype.device))[atype] - # cast from global to gr precision again - out = out.to(dtype=gr.dtype) + out = out * (self.scale.to(atype.device).to(self.prec))[atype] gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 71da2781ac..0c4121f457 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -262,7 +262,7 @@ def test_permu(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["dipole"]) np.testing.assert_allclose( @@ -303,7 +303,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["dipole"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 1ca563a8c2..4e63145741 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -326,7 +326,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["polarizability"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) diff --git a/source/tests/pt/model/test_property_fitting.py b/source/tests/pt/model/test_property_fitting.py index dfe2725f3b..ad5f3687e9 100644 --- a/source/tests/pt/model/test_property_fitting.py +++ b/source/tests/pt/model/test_property_fitting.py @@ -228,7 +228,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["property"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -399,7 +399,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["property"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))