From f0cdbe48b8afd675857debac1f3a7d54f2348192 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 12 Nov 2024 22:18:54 +0800 Subject: [PATCH] fix(pt): fix precision --- deepmd/pt/model/descriptor/dpa1.py | 10 ++++++++- deepmd/pt/model/descriptor/dpa2.py | 15 ++++++++++++- deepmd/pt/model/descriptor/repformers.py | 12 +++++----- deepmd/pt/model/descriptor/se_a.py | 18 +++++++++++---- deepmd/pt/model/descriptor/se_atten.py | 10 ++------- deepmd/pt/model/descriptor/se_r.py | 5 +++-- deepmd/pt/model/descriptor/se_t.py | 18 ++++++++++----- deepmd/pt/model/descriptor/se_t_tebd.py | 22 +++++++++++-------- deepmd/pt/model/network/mlp.py | 2 -- deepmd/pt/model/task/dipole.py | 2 ++ deepmd/pt/model/task/fitting.py | 18 ++++++++++----- deepmd/pt/model/task/polarizability.py | 3 +++ .../model/test_compressed_descriptor_dpa2.py | 5 ----- 13 files changed, 91 insertions(+), 49 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 76115b2810..b5355a5be7 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -678,6 +678,8 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) del mapping nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 @@ -693,7 +695,13 @@ def forward( if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if g2 is not None else None, + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 77e9f1d936..3c1db49568 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -27,6 +27,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) from deepmd.pt.utils.nlist import ( build_multiple_neighbor_list, get_multiple_nlist_key, @@ -268,6 +271,7 @@ def init_subclass_params(sub_data, sub_class): ) self.concat_output_tebd = concat_output_tebd self.precision = precision + self.prec = PRECISION_DICT[self.precision] self.smooth = smooth self.exclude_types = exclude_types self.env_protection = env_protection @@ -745,6 +749,9 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) + use_three_body = self.use_three_body nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 @@ -810,7 +817,13 @@ def forward( ) if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index bd1109e8b7..d646179fa2 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -22,6 +22,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) from deepmd.pt.utils.env_mat_stat import ( EnvMatStatSe, ) @@ -237,6 +240,7 @@ def __init__( self.reinit_exclude(exclude_types) self.env_protection = env_protection self.precision = precision + self.prec = PRECISION_DICT[precision] self.trainable_ln = trainable_ln self.ln_eps = ln_eps self.epsilon = 1e-4 @@ -286,12 +290,8 @@ def __init__( self.layers = torch.nn.ModuleList(layers) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.stats = None diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index eadce86963..821b6f3d5a 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -337,7 +337,18 @@ def forward( The smooth switch function. """ - return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping) + # cast the input to internal precsion + coord_ext = coord_ext.to(dtype=self.prec) + g1, rot_mat, g2, h2, sw = self.sea.forward( + nlist, coord_ext, atype_ext, None, mapping + ) + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) def set_stat_mean_and_stddev( self, @@ -742,7 +753,6 @@ def forward( ) dmatrix = dmatrix.view(-1, self.nnei, 4) - dmatrix = dmatrix.to(dtype=self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit xyz_scatter = torch.zeros( @@ -811,8 +821,8 @@ def forward( result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron) rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005 return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), - rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, + rot_mat, None, None, sw, diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 6ec02de514..caf1bed6f7 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -227,12 +227,8 @@ def __init__( ) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2 @@ -568,8 +564,6 @@ def forward( # nfnl x nnei x ng # gg = gg_s * gg_t + gg_s gg_t = gg_t.reshape(-1, gg_t.size(-1)) - # Convert all tensors to the required precision at once - ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t)) xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten( self.compress_data[0].contiguous(), self.compress_info[0].cpu().contiguous(), diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index f70fdfa9f1..ab64eab1a7 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -456,6 +456,8 @@ def forward( The smooth switch function. """ + # cast the input to internal precsion + coord_ext = coord_ext.to(dtype=self.prec) del mapping, comm_dict nf = nlist.shape[0] nloc = nlist.shape[1] @@ -474,7 +476,6 @@ def forward( assert self.filter_layers is not None dmatrix = dmatrix.view(-1, self.nnei, 1) - dmatrix = dmatrix.to(dtype=self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit xyz_scatter = torch.zeros( @@ -519,7 +520,7 @@ def forward( None, None, None, - sw, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), ) def set_stat_mean_and_stddev( diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 0eec78fd2f..001853a699 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -373,7 +373,18 @@ def forward( The smooth switch function. """ - return self.seat.forward(nlist, coord_ext, atype_ext, None, mapping) + # cast the input to internal precsion + coord_ext = coord_ext.to(dtype=self.prec) + g1, rot_mat, g2, h2, sw = self.seat.forward( + nlist, coord_ext, atype_ext, None, mapping + ) + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) def set_stat_mean_and_stddev( self, @@ -801,7 +812,6 @@ def forward( protection=self.env_protection, ) dmatrix = dmatrix.view(-1, self.nnei, 4) - dmatrix = dmatrix.to(dtype=self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit result = torch.zeros( @@ -832,8 +842,6 @@ def forward( env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j) if self.compress: ebd_env_ij = env_ij.view(-1, 1) - ebd_env_ij = ebd_env_ij.to(dtype=self.prec) - env_ij = env_ij.to(dtype=self.prec) res_ij = torch.ops.deepmd.tabulate_fusion_se_t( compress_data_ii.contiguous(), compress_info_ii.cpu().contiguous(), @@ -853,7 +861,7 @@ def forward( # xyz_scatter /= (self.nnei * self.nnei) result = result.view(nf, nloc, self.filter_neuron[-1]) return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, None, None, None, diff --git a/deepmd/pt/model/descriptor/se_t_tebd.py b/deepmd/pt/model/descriptor/se_t_tebd.py index 82ccb06f32..42f1222dff 100644 --- a/deepmd/pt/model/descriptor/se_t_tebd.py +++ b/deepmd/pt/model/descriptor/se_t_tebd.py @@ -441,12 +441,14 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) del mapping nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 g1_ext = self.type_embedding(extended_atype) g1_inp = g1_ext[:, :nloc, :] - g1, g2, h2, rot_mat, sw = self.se_ttebd( + g1, _, _, _, sw = self.se_ttebd( nlist, extended_coord, extended_atype, @@ -456,7 +458,13 @@ def forward( if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( @@ -540,12 +548,8 @@ def __init__( self.reinit_exclude(exclude_types) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.tebd_dim_input = self.tebd_dim * 2 @@ -849,7 +853,7 @@ def forward( # nf x nl x ng result = res_ij.view(nframes, nloc, self.filter_neuron[-1]) return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, None, None, None, diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index f2137bd004..2b8383806b 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -200,7 +200,6 @@ def forward( The output. """ ori_prec = xx.dtype - xx = xx.to(self.prec) yy = ( torch.matmul(xx, self.matrix) + self.bias if self.bias is not None @@ -215,7 +214,6 @@ def forward( yy += torch.concat([xx, xx], dim=-1) else: yy = yy - yy = yy.to(ori_prec) return yy def serialize(self) -> dict: diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index bc09fa4d0f..1642937652 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -186,6 +186,8 @@ def forward( ] # (nframes * nloc, 1, m1) out = out.view(-1, 1, self.embedding_width) + # cast from global to gr precision again + out = out.to(dtype=gr.dtype) # (nframes * nloc, m1, 3) gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes, nloc, 3) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 470b420c89..aef3b6d21a 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -403,7 +403,11 @@ def _forward_common( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, ): - xx = descriptor + # cast the input to internal precsion + xx = descriptor.to(self.prec) + fparam = fparam.to(self.prec) if fparam is not None else None + aparam = aparam.to(self.prec) if aparam is not None else None + if self.remove_vaccum_contribution is not None: # TODO: compute the input for vaccm when remove_vaccum_contribution is set # Ideally, the input for vacuum should be computed; @@ -477,10 +481,12 @@ def _forward_common( device=descriptor.device, ) # jit assertion if self.mixed_types: - atom_property = self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] + atom_property = self.filter_layers.networks[0](xx) if xx_zeros is not None: atom_property -= self.filter_layers.networks[0](xx_zeros) - outs = outs + atom_property # Shape is [nframes, natoms[0], net_dim_out] + outs = ( + outs + atom_property + self.bias_atom_e[atype] + ) # Shape is [nframes, natoms[0], net_dim_out] else: for type_i, ll in enumerate(self.filter_layers.networks): mask = (atype == type_i).unsqueeze(-1) @@ -495,12 +501,12 @@ def _forward_common( ): atom_property -= ll(xx_zeros) atom_property = atom_property + self.bias_atom_e[type_i] - atom_property = atom_property * mask + atom_property = torch.where(mask, atom_property, 0.0) outs = ( outs + atom_property ) # Shape is [nframes, natoms[0], net_dim_out] # nf x nloc - mask = self.emask(atype) + mask = self.emask(atype).bool() # nf x nloc x nod - outs = outs * mask[:, :, None] + outs = torch.where(mask[:, :, None], outs, 0.0) return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 7b7f92c3af..36a205b637 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -238,6 +238,9 @@ def forward( self.var_name ] out = out * (self.scale.to(atype.device))[atype] + # cast from global to gr precision again + out = out.to(dtype=gr.dtype) + gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3) if self.fit_diag: diff --git a/source/tests/pt/model/test_compressed_descriptor_dpa2.py b/source/tests/pt/model/test_compressed_descriptor_dpa2.py index 05b1143eb1..c7b3deba7e 100644 --- a/source/tests/pt/model/test_compressed_descriptor_dpa2.py +++ b/source/tests/pt/model/test_compressed_descriptor_dpa2.py @@ -53,11 +53,6 @@ def eval_pt_descriptor( class TestDescriptorDPA2(unittest.TestCase): def setUp(self): (self.dtype, self.type_one_side) = self.param - if self.dtype == "float32": - self.skipTest("FP32 has bugs:") - # ../../../../deepmd/pt/model/descriptor/repformer_layer.py:521: in forward - # torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni) - # E RuntimeError: expected scalar type Float but found Double if self.dtype == "float32": self.atol = 1e-5 elif self.dtype == "float64":