From f0cdbe48b8afd675857debac1f3a7d54f2348192 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 12 Nov 2024 22:18:54 +0800 Subject: [PATCH 1/8] fix(pt): fix precision --- deepmd/pt/model/descriptor/dpa1.py | 10 ++++++++- deepmd/pt/model/descriptor/dpa2.py | 15 ++++++++++++- deepmd/pt/model/descriptor/repformers.py | 12 +++++----- deepmd/pt/model/descriptor/se_a.py | 18 +++++++++++---- deepmd/pt/model/descriptor/se_atten.py | 10 ++------- deepmd/pt/model/descriptor/se_r.py | 5 +++-- deepmd/pt/model/descriptor/se_t.py | 18 ++++++++++----- deepmd/pt/model/descriptor/se_t_tebd.py | 22 +++++++++++-------- deepmd/pt/model/network/mlp.py | 2 -- deepmd/pt/model/task/dipole.py | 2 ++ deepmd/pt/model/task/fitting.py | 18 ++++++++++----- deepmd/pt/model/task/polarizability.py | 3 +++ .../model/test_compressed_descriptor_dpa2.py | 5 ----- 13 files changed, 91 insertions(+), 49 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 76115b2810..b5355a5be7 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -678,6 +678,8 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) del mapping nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 @@ -693,7 +695,13 @@ def forward( if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if g2 is not None else None, + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 77e9f1d936..3c1db49568 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -27,6 +27,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) from deepmd.pt.utils.nlist import ( build_multiple_neighbor_list, get_multiple_nlist_key, @@ -268,6 +271,7 @@ def init_subclass_params(sub_data, sub_class): ) self.concat_output_tebd = concat_output_tebd self.precision = precision + self.prec = PRECISION_DICT[self.precision] self.smooth = smooth self.exclude_types = exclude_types self.env_protection = env_protection @@ -745,6 +749,9 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) + use_three_body = self.use_three_body nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 @@ -810,7 +817,13 @@ def forward( ) if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index bd1109e8b7..d646179fa2 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -22,6 +22,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) from deepmd.pt.utils.env_mat_stat import ( EnvMatStatSe, ) @@ -237,6 +240,7 @@ def __init__( self.reinit_exclude(exclude_types) self.env_protection = env_protection self.precision = precision + self.prec = PRECISION_DICT[precision] self.trainable_ln = trainable_ln self.ln_eps = ln_eps self.epsilon = 1e-4 @@ -286,12 +290,8 @@ def __init__( self.layers = torch.nn.ModuleList(layers) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.stats = None diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index eadce86963..821b6f3d5a 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -337,7 +337,18 @@ def forward( The smooth switch function. """ - return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping) + # cast the input to internal precsion + coord_ext = coord_ext.to(dtype=self.prec) + g1, rot_mat, g2, h2, sw = self.sea.forward( + nlist, coord_ext, atype_ext, None, mapping + ) + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) def set_stat_mean_and_stddev( self, @@ -742,7 +753,6 @@ def forward( ) dmatrix = dmatrix.view(-1, self.nnei, 4) - dmatrix = dmatrix.to(dtype=self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit xyz_scatter = torch.zeros( @@ -811,8 +821,8 @@ def forward( result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron) rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005 return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), - rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, + rot_mat, None, None, sw, diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 6ec02de514..caf1bed6f7 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -227,12 +227,8 @@ def __init__( ) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2 @@ -568,8 +564,6 @@ def forward( # nfnl x nnei x ng # gg = gg_s * gg_t + gg_s gg_t = gg_t.reshape(-1, gg_t.size(-1)) - # Convert all tensors to the required precision at once - ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t)) xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten( self.compress_data[0].contiguous(), self.compress_info[0].cpu().contiguous(), diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index f70fdfa9f1..ab64eab1a7 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -456,6 +456,8 @@ def forward( The smooth switch function. """ + # cast the input to internal precsion + coord_ext = coord_ext.to(dtype=self.prec) del mapping, comm_dict nf = nlist.shape[0] nloc = nlist.shape[1] @@ -474,7 +476,6 @@ def forward( assert self.filter_layers is not None dmatrix = dmatrix.view(-1, self.nnei, 1) - dmatrix = dmatrix.to(dtype=self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit xyz_scatter = torch.zeros( @@ -519,7 +520,7 @@ def forward( None, None, None, - sw, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), ) def set_stat_mean_and_stddev( diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 0eec78fd2f..001853a699 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -373,7 +373,18 @@ def forward( The smooth switch function. """ - return self.seat.forward(nlist, coord_ext, atype_ext, None, mapping) + # cast the input to internal precsion + coord_ext = coord_ext.to(dtype=self.prec) + g1, rot_mat, g2, h2, sw = self.seat.forward( + nlist, coord_ext, atype_ext, None, mapping + ) + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) def set_stat_mean_and_stddev( self, @@ -801,7 +812,6 @@ def forward( protection=self.env_protection, ) dmatrix = dmatrix.view(-1, self.nnei, 4) - dmatrix = dmatrix.to(dtype=self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit result = torch.zeros( @@ -832,8 +842,6 @@ def forward( env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j) if self.compress: ebd_env_ij = env_ij.view(-1, 1) - ebd_env_ij = ebd_env_ij.to(dtype=self.prec) - env_ij = env_ij.to(dtype=self.prec) res_ij = torch.ops.deepmd.tabulate_fusion_se_t( compress_data_ii.contiguous(), compress_info_ii.cpu().contiguous(), @@ -853,7 +861,7 @@ def forward( # xyz_scatter /= (self.nnei * self.nnei) result = result.view(nf, nloc, self.filter_neuron[-1]) return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, None, None, None, diff --git a/deepmd/pt/model/descriptor/se_t_tebd.py b/deepmd/pt/model/descriptor/se_t_tebd.py index 82ccb06f32..42f1222dff 100644 --- a/deepmd/pt/model/descriptor/se_t_tebd.py +++ b/deepmd/pt/model/descriptor/se_t_tebd.py @@ -441,12 +441,14 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) del mapping nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 g1_ext = self.type_embedding(extended_atype) g1_inp = g1_ext[:, :nloc, :] - g1, g2, h2, rot_mat, sw = self.se_ttebd( + g1, _, _, _, sw = self.se_ttebd( nlist, extended_coord, extended_atype, @@ -456,7 +458,13 @@ def forward( if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( @@ -540,12 +548,8 @@ def __init__( self.reinit_exclude(exclude_types) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.tebd_dim_input = self.tebd_dim * 2 @@ -849,7 +853,7 @@ def forward( # nf x nl x ng result = res_ij.view(nframes, nloc, self.filter_neuron[-1]) return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, None, None, None, diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index f2137bd004..2b8383806b 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -200,7 +200,6 @@ def forward( The output. """ ori_prec = xx.dtype - xx = xx.to(self.prec) yy = ( torch.matmul(xx, self.matrix) + self.bias if self.bias is not None @@ -215,7 +214,6 @@ def forward( yy += torch.concat([xx, xx], dim=-1) else: yy = yy - yy = yy.to(ori_prec) return yy def serialize(self) -> dict: diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index bc09fa4d0f..1642937652 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -186,6 +186,8 @@ def forward( ] # (nframes * nloc, 1, m1) out = out.view(-1, 1, self.embedding_width) + # cast from global to gr precision again + out = out.to(dtype=gr.dtype) # (nframes * nloc, m1, 3) gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes, nloc, 3) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 470b420c89..aef3b6d21a 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -403,7 +403,11 @@ def _forward_common( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, ): - xx = descriptor + # cast the input to internal precsion + xx = descriptor.to(self.prec) + fparam = fparam.to(self.prec) if fparam is not None else None + aparam = aparam.to(self.prec) if aparam is not None else None + if self.remove_vaccum_contribution is not None: # TODO: compute the input for vaccm when remove_vaccum_contribution is set # Ideally, the input for vacuum should be computed; @@ -477,10 +481,12 @@ def _forward_common( device=descriptor.device, ) # jit assertion if self.mixed_types: - atom_property = self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] + atom_property = self.filter_layers.networks[0](xx) if xx_zeros is not None: atom_property -= self.filter_layers.networks[0](xx_zeros) - outs = outs + atom_property # Shape is [nframes, natoms[0], net_dim_out] + outs = ( + outs + atom_property + self.bias_atom_e[atype] + ) # Shape is [nframes, natoms[0], net_dim_out] else: for type_i, ll in enumerate(self.filter_layers.networks): mask = (atype == type_i).unsqueeze(-1) @@ -495,12 +501,12 @@ def _forward_common( ): atom_property -= ll(xx_zeros) atom_property = atom_property + self.bias_atom_e[type_i] - atom_property = atom_property * mask + atom_property = torch.where(mask, atom_property, 0.0) outs = ( outs + atom_property ) # Shape is [nframes, natoms[0], net_dim_out] # nf x nloc - mask = self.emask(atype) + mask = self.emask(atype).bool() # nf x nloc x nod - outs = outs * mask[:, :, None] + outs = torch.where(mask[:, :, None], outs, 0.0) return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 7b7f92c3af..36a205b637 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -238,6 +238,9 @@ def forward( self.var_name ] out = out * (self.scale.to(atype.device))[atype] + # cast from global to gr precision again + out = out.to(dtype=gr.dtype) + gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3) if self.fit_diag: diff --git a/source/tests/pt/model/test_compressed_descriptor_dpa2.py b/source/tests/pt/model/test_compressed_descriptor_dpa2.py index 05b1143eb1..c7b3deba7e 100644 --- a/source/tests/pt/model/test_compressed_descriptor_dpa2.py +++ b/source/tests/pt/model/test_compressed_descriptor_dpa2.py @@ -53,11 +53,6 @@ def eval_pt_descriptor( class TestDescriptorDPA2(unittest.TestCase): def setUp(self): (self.dtype, self.type_one_side) = self.param - if self.dtype == "float32": - self.skipTest("FP32 has bugs:") - # ../../../../deepmd/pt/model/descriptor/repformer_layer.py:521: in forward - # torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni) - # E RuntimeError: expected scalar type Float but found Double if self.dtype == "float32": self.atol = 1e-5 elif self.dtype == "float64": From bb9a5808eb7afa4c19b72452d8c7d1affc4ca477 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 12 Nov 2024 22:56:11 +0800 Subject: [PATCH 2/8] fix uts --- deepmd/pt/model/descriptor/dpa1.py | 2 ++ deepmd/pt/model/descriptor/se_a.py | 1 + deepmd/pt/model/descriptor/se_t.py | 1 + deepmd/pt/model/descriptor/se_t_tebd.py | 1 + 4 files changed, 5 insertions(+) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index b5355a5be7..b541e665ab 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -22,6 +22,7 @@ env, ) from deepmd.pt.utils.env import ( + PRECISION_DICT, RESERVED_PRECISON_DICT, ) from deepmd.pt.utils.tabulate import ( @@ -311,6 +312,7 @@ def __init__( use_tebd_bias=use_tebd_bias, type_map=type_map, ) + self.prec = PRECISION_DICT[precision] self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd self.trainable = trainable diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 821b6f3d5a..0554c474f0 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -118,6 +118,7 @@ def __init__( super().__init__() self.type_map = type_map self.compress = False + self.prec = PRECISION_DICT[precision] self.sea = DescrptBlockSeA( rcut, rcut_smth, diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 001853a699..8da6be8094 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -154,6 +154,7 @@ def __init__( super().__init__() self.type_map = type_map self.compress = False + self.prec = PRECISION_DICT[precision] self.seat = DescrptBlockSeT( rcut, rcut_smth, diff --git a/deepmd/pt/model/descriptor/se_t_tebd.py b/deepmd/pt/model/descriptor/se_t_tebd.py index 42f1222dff..e7b4827724 100644 --- a/deepmd/pt/model/descriptor/se_t_tebd.py +++ b/deepmd/pt/model/descriptor/se_t_tebd.py @@ -161,6 +161,7 @@ def __init__( smooth=smooth, seed=child_seed(seed, 1), ) + self.prec = PRECISION_DICT[precision] self.use_econf_tebd = use_econf_tebd self.type_map = type_map self.smooth = smooth From 2aa2f735046078feda3813e4674b47d70c49714a Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 12 Nov 2024 22:59:38 +0800 Subject: [PATCH 3/8] Update fitting.py --- deepmd/pt/model/task/fitting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index aef3b6d21a..7897c7dce9 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -506,7 +506,7 @@ def _forward_common( outs + atom_property ) # Shape is [nframes, natoms[0], net_dim_out] # nf x nloc - mask = self.emask(atype).bool() + mask = self.emask(atype).to(torch.bool) # nf x nloc x nod outs = torch.where(mask[:, :, None], outs, 0.0) return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} From f5f73703d0cefb2ace2dc226f5b3c325d234ac5c Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Nov 2024 10:38:01 +0800 Subject: [PATCH 4/8] fix uts --- deepmd/pt/model/task/dipole.py | 2 -- deepmd/pt/model/task/fitting.py | 8 ++++---- deepmd/pt/model/task/invar_fitting.py | 5 ++++- deepmd/pt/model/task/polarizability.py | 4 +--- source/tests/pt/model/test_dipole_fitting.py | 4 ++-- source/tests/pt/model/test_polarizability_fitting.py | 2 +- source/tests/pt/model/test_property_fitting.py | 4 ++-- 7 files changed, 14 insertions(+), 15 deletions(-) diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 1642937652..bc09fa4d0f 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -186,8 +186,6 @@ def forward( ] # (nframes * nloc, 1, m1) out = out.view(-1, 1, self.embedding_width) - # cast from global to gr precision again - out = out.to(dtype=gr.dtype) # (nframes * nloc, m1, 3) gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes, nloc, 3) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 7897c7dce9..798e271c8f 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -477,7 +477,7 @@ def _forward_common( outs = torch.zeros( (nf, nloc, net_dim_out), - dtype=env.GLOBAL_PT_FLOAT_PRECISION, + dtype=self.prec, device=descriptor.device, ) # jit assertion if self.mixed_types: @@ -485,7 +485,7 @@ def _forward_common( if xx_zeros is not None: atom_property -= self.filter_layers.networks[0](xx_zeros) outs = ( - outs + atom_property + self.bias_atom_e[atype] + outs + atom_property + self.bias_atom_e[atype].to(self.prec) ) # Shape is [nframes, natoms[0], net_dim_out] else: for type_i, ll in enumerate(self.filter_layers.networks): @@ -500,7 +500,7 @@ def _forward_common( and not self.remove_vaccum_contribution[type_i] ): atom_property -= ll(xx_zeros) - atom_property = atom_property + self.bias_atom_e[type_i] + atom_property = atom_property + self.bias_atom_e[type_i].to(self.prec) atom_property = torch.where(mask, atom_property, 0.0) outs = ( outs + atom_property @@ -509,4 +509,4 @@ def _forward_common( mask = self.emask(atype).to(torch.bool) # nf x nloc x nod outs = torch.where(mask[:, :, None], outs, 0.0) - return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + return {self.var_name: outs} diff --git a/deepmd/pt/model/task/invar_fitting.py b/deepmd/pt/model/task/invar_fitting.py index acdd5b0fda..c339f4690d 100644 --- a/deepmd/pt/model/task/invar_fitting.py +++ b/deepmd/pt/model/task/invar_fitting.py @@ -177,7 +177,10 @@ def forward( ------- - `torch.Tensor`: Total energy with shape [nframes, natoms[0]]. """ - return self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam) + out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ + self.var_name + ] + return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} # make jit happy with torch 2.0.0 exclude_types: list[int] diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 36a205b637..6ec7635377 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -237,9 +237,7 @@ def forward( out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] - out = out * (self.scale.to(atype.device))[atype] - # cast from global to gr precision again - out = out.to(dtype=gr.dtype) + out = out * (self.scale.to(atype.device).to(self.prec))[atype] gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 71da2781ac..0c4121f457 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -262,7 +262,7 @@ def test_permu(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["dipole"]) np.testing.assert_allclose( @@ -303,7 +303,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["dipole"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 1ca563a8c2..4e63145741 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -326,7 +326,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["polarizability"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) diff --git a/source/tests/pt/model/test_property_fitting.py b/source/tests/pt/model/test_property_fitting.py index dfe2725f3b..ad5f3687e9 100644 --- a/source/tests/pt/model/test_property_fitting.py +++ b/source/tests/pt/model/test_property_fitting.py @@ -228,7 +228,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["property"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -399,7 +399,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["property"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) From b1faf56ba22727f1e97eeb6e4362dac02e8a1654 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:15:18 +0800 Subject: [PATCH 5/8] add optional cast --- deepmd/pt/model/network/mlp.py | 5 +++++ deepmd/pt/utils/env.py | 1 + 2 files changed, 6 insertions(+) diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 2b8383806b..582abf4d69 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -32,6 +32,7 @@ ) from deepmd.pt.utils.env import ( DEFAULT_PRECISION, + DP_DTYPE_PROMOTION_STRICT, PRECISION_DICT, ) from deepmd.pt.utils.utils import ( @@ -200,6 +201,8 @@ def forward( The output. """ ori_prec = xx.dtype + if not DP_DTYPE_PROMOTION_STRICT: + xx = xx.to(self.prec) yy = ( torch.matmul(xx, self.matrix) + self.bias if self.bias is not None @@ -214,6 +217,8 @@ def forward( yy += torch.concat([xx, xx], dim=-1) else: yy = yy + if not DP_DTYPE_PROMOTION_STRICT: + yy = yy.to(ori_prec) return yy def serialize(self) -> dict: diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 3ee0b7b54d..81dce669ff 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -15,6 +15,7 @@ ) SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) +DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1" try: # only linux ncpus = len(os.sched_getaffinity(0)) From abfddd895b5399fa49597ee35ce8cfae787b2f8f Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Nov 2024 13:11:28 +0800 Subject: [PATCH 6/8] Update mlp.py --- deepmd/pt/model/network/mlp.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 582abf4d69..a5ac086770 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -32,7 +32,6 @@ ) from deepmd.pt.utils.env import ( DEFAULT_PRECISION, - DP_DTYPE_PROMOTION_STRICT, PRECISION_DICT, ) from deepmd.pt.utils.utils import ( @@ -201,7 +200,7 @@ def forward( The output. """ ori_prec = xx.dtype - if not DP_DTYPE_PROMOTION_STRICT: + if not env.DP_DTYPE_PROMOTION_STRICT: xx = xx.to(self.prec) yy = ( torch.matmul(xx, self.matrix) + self.bias @@ -217,7 +216,7 @@ def forward( yy += torch.concat([xx, xx], dim=-1) else: yy = yy - if not DP_DTYPE_PROMOTION_STRICT: + if not env.DP_DTYPE_PROMOTION_STRICT: yy = yy.to(ori_prec) return yy From 02b75b3d2298dc163d9756f471da18c38cca320f Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Nov 2024 14:23:12 +0800 Subject: [PATCH 7/8] add extra cast for dipole/polar --- deepmd/pt/model/task/dipole.py | 2 ++ deepmd/pt/model/task/polarizability.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index bc09fa4d0f..af81e54d13 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -180,6 +180,8 @@ def forward( ): nframes, nloc, _ = descriptor.shape assert gr is not None, "Must provide the rotation matrix for dipole fitting." + # cast the input to internal precsion + gr = gr.to(self.prec) # (nframes, nloc, m1) out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 6ec7635377..40a569fa80 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -233,6 +233,8 @@ def forward( assert ( gr is not None ), "Must provide the rotation matrix for polarizability fitting." + # cast the input to internal precsion + gr = gr.to(self.prec) # (nframes, nloc, _net_out_dim) out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name From 9003b32e1144ad157c922adf8b941fbeef81ae10 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:37:00 +0800 Subject: [PATCH 8/8] Update test_compressed_descriptor_se_atten.py --- source/tests/pt/model/test_compressed_descriptor_se_atten.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/source/tests/pt/model/test_compressed_descriptor_se_atten.py b/source/tests/pt/model/test_compressed_descriptor_se_atten.py index a439255396..edb682b27b 100644 --- a/source/tests/pt/model/test_compressed_descriptor_se_atten.py +++ b/source/tests/pt/model/test_compressed_descriptor_se_atten.py @@ -115,11 +115,6 @@ def test_compressed_forward(self): self.box, ) - if self.dtype == "float32": - result_pt = result_pt.to(torch.float32) - elif self.dtype == "float64": - result_pt = result_pt.to(torch.float64) - self.se_atten.enable_compression(0.5) result_pt_compressed = eval_pt_descriptor( self.se_atten,