Skip to content

Commit

Permalink
fix(pt): fix precision
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 12, 2024
1 parent c4a973a commit f0cdbe4
Show file tree
Hide file tree
Showing 13 changed files with 91 additions and 49 deletions.
10 changes: 9 additions & 1 deletion deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,8 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei
"""
# cast the input to internal precision
extended_coord = extended_coord.to(dtype=self.prec)
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
Expand All @@ -693,7 +695,13 @@ def forward(
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)

return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if g2 is not None else None,
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

@classmethod
def update_sel(
Expand Down
15 changes: 14 additions & 1 deletion deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
)
from deepmd.pt.utils.nlist import (
build_multiple_neighbor_list,
get_multiple_nlist_key,
Expand Down Expand Up @@ -268,6 +271,7 @@ def init_subclass_params(sub_data, sub_class):
)
self.concat_output_tebd = concat_output_tebd
self.precision = precision
self.prec = PRECISION_DICT[self.precision]
self.smooth = smooth
self.exclude_types = exclude_types
self.env_protection = env_protection
Expand Down Expand Up @@ -745,6 +749,9 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei
"""
# cast the input to internal precision
extended_coord = extended_coord.to(dtype=self.prec)

use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
Expand Down Expand Up @@ -810,7 +817,13 @@ def forward(
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

@classmethod
def update_sel(
Expand Down
12 changes: 6 additions & 6 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
)
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
Expand Down Expand Up @@ -237,6 +240,7 @@ def __init__(
self.reinit_exclude(exclude_types)
self.env_protection = env_protection
self.precision = precision
self.prec = PRECISION_DICT[precision]
self.trainable_ln = trainable_ln
self.ln_eps = ln_eps
self.epsilon = 1e-4
Expand Down Expand Up @@ -286,12 +290,8 @@ def __init__(
self.layers = torch.nn.ModuleList(layers)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.stats = None
Expand Down
18 changes: 14 additions & 4 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,18 @@ def forward(
The smooth switch function.
"""
return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping)
# cast the input to internal precision
coord_ext = coord_ext.to(dtype=self.prec)
g1, rot_mat, g2, h2, sw = self.sea.forward(
nlist, coord_ext, atype_ext, None, mapping
)
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -742,7 +753,6 @@ def forward(
)

dmatrix = dmatrix.view(-1, self.nnei, 4)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
Expand Down Expand Up @@ -811,8 +821,8 @@ def forward(
result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron)
rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
rot_mat,
None,
None,
sw,
Expand Down
10 changes: 2 additions & 8 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,8 @@ def __init__(
)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2
Expand Down Expand Up @@ -568,8 +564,6 @@ def forward(
# nfnl x nnei x ng
# gg = gg_s * gg_t + gg_s
gg_t = gg_t.reshape(-1, gg_t.size(-1))
# Convert all tensors to the required precision at once
ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t))
xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten(
self.compress_data[0].contiguous(),
self.compress_info[0].cpu().contiguous(),
Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,8 @@ def forward(
The smooth switch function.
"""
# cast the input to internal precision
coord_ext = coord_ext.to(dtype=self.prec)
del mapping, comm_dict
nf = nlist.shape[0]
nloc = nlist.shape[1]
Expand All @@ -474,7 +476,6 @@ def forward(

assert self.filter_layers is not None
dmatrix = dmatrix.view(-1, self.nnei, 1)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
Expand Down Expand Up @@ -519,7 +520,7 @@ def forward(
None,
None,
None,
sw,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
Expand Down
18 changes: 13 additions & 5 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,18 @@ def forward(
The smooth switch function.
"""
return self.seat.forward(nlist, coord_ext, atype_ext, None, mapping)
# cast the input to internal precision
coord_ext = coord_ext.to(dtype=self.prec)
g1, rot_mat, g2, h2, sw = self.seat.forward(
nlist, coord_ext, atype_ext, None, mapping
)
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -801,7 +812,6 @@ def forward(
protection=self.env_protection,
)
dmatrix = dmatrix.view(-1, self.nnei, 4)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
result = torch.zeros(
Expand Down Expand Up @@ -832,8 +842,6 @@ def forward(
env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j)
if self.compress:
ebd_env_ij = env_ij.view(-1, 1)
ebd_env_ij = ebd_env_ij.to(dtype=self.prec)
env_ij = env_ij.to(dtype=self.prec)
res_ij = torch.ops.deepmd.tabulate_fusion_se_t(
compress_data_ii.contiguous(),
compress_info_ii.cpu().contiguous(),
Expand All @@ -853,7 +861,7 @@ def forward(
# xyz_scatter /= (self.nnei * self.nnei)
result = result.view(nf, nloc, self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
None,
None,
None,
Expand Down
22 changes: 13 additions & 9 deletions deepmd/pt/model/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,12 +441,14 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei
"""
# cast the input to internal precision
extended_coord = extended_coord.to(dtype=self.prec)
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:, :nloc, :]
g1, g2, h2, rot_mat, sw = self.se_ttebd(
g1, _, _, _, sw = self.se_ttebd(
nlist,
extended_coord,
extended_atype,
Expand All @@ -456,7 +458,13 @@ def forward(
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)

return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

@classmethod
def update_sel(
Expand Down Expand Up @@ -540,12 +548,8 @@ def __init__(
self.reinit_exclude(exclude_types)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.tebd_dim_input = self.tebd_dim * 2
Expand Down Expand Up @@ -849,7 +853,7 @@ def forward(
# nf x nl x ng
result = res_ij.view(nframes, nloc, self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
None,
None,
None,
Expand Down
2 changes: 0 additions & 2 deletions deepmd/pt/model/network/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ def forward(
The output.
"""
ori_prec = xx.dtype
xx = xx.to(self.prec)
yy = (
torch.matmul(xx, self.matrix) + self.bias
if self.bias is not None
Expand All @@ -215,7 +214,6 @@ def forward(
yy += torch.concat([xx, xx], dim=-1)
else:
yy = yy
yy = yy.to(ori_prec)
return yy

def serialize(self) -> dict:
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def forward(
]
# (nframes * nloc, 1, m1)
out = out.view(-1, 1, self.embedding_width)
# cast from global to gr precision again
out = out.to(dtype=gr.dtype)
# (nframes * nloc, m1, 3)
gr = gr.view(nframes * nloc, self.embedding_width, 3)
# (nframes, nloc, 3)
Expand Down
18 changes: 12 additions & 6 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,11 @@ def _forward_common(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
):
xx = descriptor
# cast the input to internal precision
xx = descriptor.to(self.prec)
fparam = fparam.to(self.prec) if fparam is not None else None
aparam = aparam.to(self.prec) if aparam is not None else None

if self.remove_vaccum_contribution is not None:
# TODO: compute the input for vacuum when remove_vaccum_contribution is set
# Ideally, the input for vacuum should be computed;
Expand Down Expand Up @@ -477,10 +481,12 @@ def _forward_common(
device=descriptor.device,
) # jit assertion
if self.mixed_types:
atom_property = self.filter_layers.networks[0](xx) + self.bias_atom_e[atype]
atom_property = self.filter_layers.networks[0](xx)
if xx_zeros is not None:
atom_property -= self.filter_layers.networks[0](xx_zeros)
outs = outs + atom_property # Shape is [nframes, natoms[0], net_dim_out]
outs = (
outs + atom_property + self.bias_atom_e[atype]
) # Shape is [nframes, natoms[0], net_dim_out]
else:
for type_i, ll in enumerate(self.filter_layers.networks):
mask = (atype == type_i).unsqueeze(-1)
Expand All @@ -495,12 +501,12 @@ def _forward_common(
):
atom_property -= ll(xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property * mask
atom_property = torch.where(mask, atom_property, 0.0)
outs = (
outs + atom_property
) # Shape is [nframes, natoms[0], net_dim_out]
# nf x nloc
mask = self.emask(atype)
mask = self.emask(atype).bool()
# nf x nloc x nod
outs = outs * mask[:, :, None]
outs = torch.where(mask[:, :, None], outs, 0.0)
return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)}
3 changes: 3 additions & 0 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ def forward(
self.var_name
]
out = out * (self.scale.to(atype.device))[atype]
# cast from global to gr precision again
out = out.to(dtype=gr.dtype)

gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3)

if self.fit_diag:
Expand Down
5 changes: 0 additions & 5 deletions source/tests/pt/model/test_compressed_descriptor_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,6 @@ def eval_pt_descriptor(
class TestDescriptorDPA2(unittest.TestCase):
def setUp(self):
(self.dtype, self.type_one_side) = self.param
if self.dtype == "float32":
self.skipTest("FP32 has bugs:")
# ../../../../deepmd/pt/model/descriptor/repformer_layer.py:521: in forward
# torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni)
# E RuntimeError: expected scalar type Float but found Double
if self.dtype == "float32":
self.atol = 1e-5
elif self.dtype == "float64":
Expand Down

0 comments on commit f0cdbe4

Please sign in to comment.