diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 76115b2810..b541e665ab 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -22,6 +22,7 @@ env, ) from deepmd.pt.utils.env import ( + PRECISION_DICT, RESERVED_PRECISON_DICT, ) from deepmd.pt.utils.tabulate import ( @@ -311,6 +312,7 @@ def __init__( use_tebd_bias=use_tebd_bias, type_map=type_map, ) + self.prec = PRECISION_DICT[precision] self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd self.trainable = trainable @@ -678,6 +680,8 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) del mapping nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 @@ -693,7 +697,13 @@ def forward( if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if g2 is not None else None, + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 77e9f1d936..3c1db49568 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -27,6 +27,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) from deepmd.pt.utils.nlist import ( build_multiple_neighbor_list, get_multiple_nlist_key, @@ -268,6 +271,7 @@ def init_subclass_params(sub_data, sub_class): ) self.concat_output_tebd = concat_output_tebd self.precision = precision + self.prec = PRECISION_DICT[self.precision] self.smooth = smooth self.exclude_types = exclude_types self.env_protection = env_protection @@ -745,6 +749,9 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) + use_three_body = self.use_three_body nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 @@ -810,7 +817,13 @@ def forward( ) if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index bd1109e8b7..d646179fa2 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -22,6 +22,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) from deepmd.pt.utils.env_mat_stat import ( EnvMatStatSe, ) @@ -237,6 +240,7 @@ def __init__( self.reinit_exclude(exclude_types) self.env_protection = env_protection self.precision = precision + self.prec = PRECISION_DICT[precision] self.trainable_ln = trainable_ln self.ln_eps = ln_eps self.epsilon = 1e-4 @@ -286,12 +290,8 @@ def __init__( self.layers = torch.nn.ModuleList(layers) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.stats = None diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index eadce86963..0554c474f0 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -118,6 +118,7 @@ def __init__( super().__init__() self.type_map = type_map self.compress = False + self.prec = PRECISION_DICT[precision] self.sea = DescrptBlockSeA( rcut, rcut_smth, @@ -337,7 +338,18 @@ def forward( The smooth switch function. """ - return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping) + # cast the input to internal precsion + coord_ext = coord_ext.to(dtype=self.prec) + g1, rot_mat, g2, h2, sw = self.sea.forward( + nlist, coord_ext, atype_ext, None, mapping + ) + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) def set_stat_mean_and_stddev( self, @@ -742,7 +754,6 @@ def forward( ) dmatrix = dmatrix.view(-1, self.nnei, 4) - dmatrix = dmatrix.to(dtype=self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit xyz_scatter = torch.zeros( @@ -811,8 +822,8 @@ def forward( result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron) rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005 return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), - rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, + rot_mat, None, None, sw, diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 6ec02de514..caf1bed6f7 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -227,12 +227,8 @@ def __init__( ) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2 @@ -568,8 +564,6 @@ def forward( # nfnl x nnei x ng # gg = gg_s * gg_t + gg_s gg_t = gg_t.reshape(-1, gg_t.size(-1)) - # Convert all tensors to the required precision at once - ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t)) xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten( self.compress_data[0].contiguous(), self.compress_info[0].cpu().contiguous(), diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index f70fdfa9f1..ab64eab1a7 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -456,6 +456,8 @@ def forward( The smooth switch function. """ + # cast the input to internal precsion + coord_ext = coord_ext.to(dtype=self.prec) del mapping, comm_dict nf = nlist.shape[0] nloc = nlist.shape[1] @@ -474,7 +476,6 @@ def forward( assert self.filter_layers is not None dmatrix = dmatrix.view(-1, self.nnei, 1) - dmatrix = dmatrix.to(dtype=self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit xyz_scatter = torch.zeros( @@ -519,7 +520,7 @@ def forward( None, None, None, - sw, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), ) def set_stat_mean_and_stddev( diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 0eec78fd2f..8da6be8094 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -154,6 +154,7 @@ def __init__( super().__init__() self.type_map = type_map self.compress = False + self.prec = PRECISION_DICT[precision] self.seat = DescrptBlockSeT( rcut, rcut_smth, @@ -373,7 +374,18 @@ def forward( The smooth switch function. """ - return self.seat.forward(nlist, coord_ext, atype_ext, None, mapping) + # cast the input to internal precsion + coord_ext = coord_ext.to(dtype=self.prec) + g1, rot_mat, g2, h2, sw = self.seat.forward( + nlist, coord_ext, atype_ext, None, mapping + ) + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) def set_stat_mean_and_stddev( self, @@ -801,7 +813,6 @@ def forward( protection=self.env_protection, ) dmatrix = dmatrix.view(-1, self.nnei, 4) - dmatrix = dmatrix.to(dtype=self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit result = torch.zeros( @@ -832,8 +843,6 @@ def forward( env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j) if self.compress: ebd_env_ij = env_ij.view(-1, 1) - ebd_env_ij = ebd_env_ij.to(dtype=self.prec) - env_ij = env_ij.to(dtype=self.prec) res_ij = torch.ops.deepmd.tabulate_fusion_se_t( compress_data_ii.contiguous(), compress_info_ii.cpu().contiguous(), @@ -853,7 +862,7 @@ def forward( # xyz_scatter /= (self.nnei * self.nnei) result = result.view(nf, nloc, self.filter_neuron[-1]) return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, None, None, None, diff --git a/deepmd/pt/model/descriptor/se_t_tebd.py b/deepmd/pt/model/descriptor/se_t_tebd.py index 82ccb06f32..e7b4827724 100644 --- a/deepmd/pt/model/descriptor/se_t_tebd.py +++ b/deepmd/pt/model/descriptor/se_t_tebd.py @@ -161,6 +161,7 @@ def __init__( smooth=smooth, seed=child_seed(seed, 1), ) + self.prec = PRECISION_DICT[precision] self.use_econf_tebd = use_econf_tebd self.type_map = type_map self.smooth = smooth @@ -441,12 +442,14 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) del mapping nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 g1_ext = self.type_embedding(extended_atype) g1_inp = g1_ext[:, :nloc, :] - g1, g2, h2, rot_mat, sw = self.se_ttebd( + g1, _, _, _, sw = self.se_ttebd( nlist, extended_coord, extended_atype, @@ -456,7 +459,13 @@ def forward( if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( @@ -540,12 +549,8 @@ def __init__( self.reinit_exclude(exclude_types) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.tebd_dim_input = self.tebd_dim * 2 @@ -849,7 +854,7 @@ def forward( # nf x nl x ng result = res_ij.view(nframes, nloc, self.filter_neuron[-1]) return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, None, None, None, diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index f2137bd004..a5ac086770 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -200,7 +200,8 @@ def forward( The output. """ ori_prec = xx.dtype - xx = xx.to(self.prec) + if not env.DP_DTYPE_PROMOTION_STRICT: + xx = xx.to(self.prec) yy = ( torch.matmul(xx, self.matrix) + self.bias if self.bias is not None @@ -215,7 +216,8 @@ def forward( yy += torch.concat([xx, xx], dim=-1) else: yy = yy - yy = yy.to(ori_prec) + if not env.DP_DTYPE_PROMOTION_STRICT: + yy = yy.to(ori_prec) return yy def serialize(self) -> dict: diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index bc09fa4d0f..af81e54d13 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -180,6 +180,8 @@ def forward( ): nframes, nloc, _ = descriptor.shape assert gr is not None, "Must provide the rotation matrix for dipole fitting." + # cast the input to internal precsion + gr = gr.to(self.prec) # (nframes, nloc, m1) out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 470b420c89..798e271c8f 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -403,7 +403,11 @@ def _forward_common( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, ): - xx = descriptor + # cast the input to internal precsion + xx = descriptor.to(self.prec) + fparam = fparam.to(self.prec) if fparam is not None else None + aparam = aparam.to(self.prec) if aparam is not None else None + if self.remove_vaccum_contribution is not None: # TODO: compute the input for vaccm when remove_vaccum_contribution is set # Ideally, the input for vacuum should be computed; @@ -473,14 +477,16 @@ def _forward_common( outs = torch.zeros( (nf, nloc, net_dim_out), - dtype=env.GLOBAL_PT_FLOAT_PRECISION, + dtype=self.prec, device=descriptor.device, ) # jit assertion if self.mixed_types: - atom_property = self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] + atom_property = self.filter_layers.networks[0](xx) if xx_zeros is not None: atom_property -= self.filter_layers.networks[0](xx_zeros) - outs = outs + atom_property # Shape is [nframes, natoms[0], net_dim_out] + outs = ( + outs + atom_property + self.bias_atom_e[atype].to(self.prec) + ) # Shape is [nframes, natoms[0], net_dim_out] else: for type_i, ll in enumerate(self.filter_layers.networks): mask = (atype == type_i).unsqueeze(-1) @@ -494,13 +500,13 @@ def _forward_common( and not self.remove_vaccum_contribution[type_i] ): atom_property -= ll(xx_zeros) - atom_property = atom_property + self.bias_atom_e[type_i] - atom_property = atom_property * mask + atom_property = atom_property + self.bias_atom_e[type_i].to(self.prec) + atom_property = torch.where(mask, atom_property, 0.0) outs = ( outs + atom_property ) # Shape is [nframes, natoms[0], net_dim_out] # nf x nloc - mask = self.emask(atype) + mask = self.emask(atype).to(torch.bool) # nf x nloc x nod - outs = outs * mask[:, :, None] - return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + outs = torch.where(mask[:, :, None], outs, 0.0) + return {self.var_name: outs} diff --git a/deepmd/pt/model/task/invar_fitting.py b/deepmd/pt/model/task/invar_fitting.py index acdd5b0fda..c339f4690d 100644 --- a/deepmd/pt/model/task/invar_fitting.py +++ b/deepmd/pt/model/task/invar_fitting.py @@ -177,7 +177,10 @@ def forward( ------- - `torch.Tensor`: Total energy with shape [nframes, natoms[0]]. """ - return self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam) + out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ + self.var_name + ] + return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} # make jit happy with torch 2.0.0 exclude_types: list[int] diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 7b7f92c3af..40a569fa80 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -233,11 +233,14 @@ def forward( assert ( gr is not None ), "Must provide the rotation matrix for polarizability fitting." + # cast the input to internal precsion + gr = gr.to(self.prec) # (nframes, nloc, _net_out_dim) out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] - out = out * (self.scale.to(atype.device))[atype] + out = out * (self.scale.to(atype.device).to(self.prec))[atype] + gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3) if self.fit_diag: diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 3ee0b7b54d..81dce669ff 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -15,6 +15,7 @@ ) SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) +DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1" try: # only linux ncpus = len(os.sched_getaffinity(0)) diff --git a/source/tests/pt/model/test_compressed_descriptor_dpa2.py b/source/tests/pt/model/test_compressed_descriptor_dpa2.py index 05b1143eb1..c7b3deba7e 100644 --- a/source/tests/pt/model/test_compressed_descriptor_dpa2.py +++ b/source/tests/pt/model/test_compressed_descriptor_dpa2.py @@ -53,11 +53,6 @@ def eval_pt_descriptor( class TestDescriptorDPA2(unittest.TestCase): def setUp(self): (self.dtype, self.type_one_side) = self.param - if self.dtype == "float32": - self.skipTest("FP32 has bugs:") - # ../../../../deepmd/pt/model/descriptor/repformer_layer.py:521: in forward - # torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni) - # E RuntimeError: expected scalar type Float but found Double if self.dtype == "float32": self.atol = 1e-5 elif self.dtype == "float64": diff --git a/source/tests/pt/model/test_compressed_descriptor_se_atten.py b/source/tests/pt/model/test_compressed_descriptor_se_atten.py index a439255396..edb682b27b 100644 --- a/source/tests/pt/model/test_compressed_descriptor_se_atten.py +++ b/source/tests/pt/model/test_compressed_descriptor_se_atten.py @@ -115,11 +115,6 @@ def test_compressed_forward(self): self.box, ) - if self.dtype == "float32": - result_pt = result_pt.to(torch.float32) - elif self.dtype == "float64": - result_pt = result_pt.to(torch.float64) - self.se_atten.enable_compression(0.5) result_pt_compressed = eval_pt_descriptor( self.se_atten, diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 71da2781ac..0c4121f457 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -262,7 +262,7 @@ def test_permu(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["dipole"]) np.testing.assert_allclose( @@ -303,7 +303,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["dipole"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 1ca563a8c2..4e63145741 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -326,7 +326,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["polarizability"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) diff --git a/source/tests/pt/model/test_property_fitting.py b/source/tests/pt/model/test_property_fitting.py index dfe2725f3b..ad5f3687e9 100644 --- a/source/tests/pt/model/test_property_fitting.py +++ b/source/tests/pt/model/test_property_fitting.py @@ -228,7 +228,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["property"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -399,7 +399,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["property"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))