Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pt): fix precision #4344

Merged
merged 8 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.tabulate import (
Expand Down Expand Up @@ -311,6 +312,7 @@ def __init__(
use_tebd_bias=use_tebd_bias,
type_map=type_map,
)
self.prec = PRECISION_DICT[precision]
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
self.trainable = trainable
Expand Down Expand Up @@ -678,6 +680,8 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei

"""
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
Expand All @@ -693,7 +697,13 @@ def forward(
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)

return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if g2 is not None else None,
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)
iProzd marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def update_sel(
Expand Down
15 changes: 14 additions & 1 deletion deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
)
from deepmd.pt.utils.nlist import (
build_multiple_neighbor_list,
get_multiple_nlist_key,
Expand Down Expand Up @@ -268,6 +271,7 @@ def init_subclass_params(sub_data, sub_class):
)
self.concat_output_tebd = concat_output_tebd
self.precision = precision
self.prec = PRECISION_DICT[self.precision]
self.smooth = smooth
self.exclude_types = exclude_types
self.env_protection = env_protection
Expand Down Expand Up @@ -745,6 +749,9 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei

"""
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)

use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
Expand Down Expand Up @@ -810,7 +817,13 @@ def forward(
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

@classmethod
def update_sel(
Expand Down
12 changes: 6 additions & 6 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
)
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
Expand Down Expand Up @@ -237,6 +240,7 @@ def __init__(
self.reinit_exclude(exclude_types)
self.env_protection = env_protection
self.precision = precision
self.prec = PRECISION_DICT[precision]
iProzd marked this conversation as resolved.
Show resolved Hide resolved
self.trainable_ln = trainable_ln
self.ln_eps = ln_eps
self.epsilon = 1e-4
Expand Down Expand Up @@ -286,12 +290,8 @@ def __init__(
self.layers = torch.nn.ModuleList(layers)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.stats = None
Expand Down
19 changes: 15 additions & 4 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
super().__init__()
self.type_map = type_map
self.compress = False
self.prec = PRECISION_DICT[precision]
self.sea = DescrptBlockSeA(
rcut,
rcut_smth,
Expand Down Expand Up @@ -337,7 +338,18 @@ def forward(
The smooth switch function.

"""
return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping)
# cast the input to internal precsion
coord_ext = coord_ext.to(dtype=self.prec)
g1, rot_mat, g2, h2, sw = self.sea.forward(
nlist, coord_ext, atype_ext, None, mapping
)
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -742,7 +754,6 @@ def forward(
)

dmatrix = dmatrix.view(-1, self.nnei, 4)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
Expand Down Expand Up @@ -811,8 +822,8 @@ def forward(
result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron)
rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
rot_mat,
None,
None,
sw,
Expand Down
10 changes: 2 additions & 8 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,8 @@ def __init__(
)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2
Expand Down Expand Up @@ -568,8 +564,6 @@ def forward(
# nfnl x nnei x ng
# gg = gg_s * gg_t + gg_s
gg_t = gg_t.reshape(-1, gg_t.size(-1))
# Convert all tensors to the required precision at once
ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t))
xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten(
self.compress_data[0].contiguous(),
self.compress_info[0].cpu().contiguous(),
Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,8 @@ def forward(
The smooth switch function.

"""
# cast the input to internal precsion
coord_ext = coord_ext.to(dtype=self.prec)
del mapping, comm_dict
nf = nlist.shape[0]
nloc = nlist.shape[1]
Expand All @@ -474,7 +476,6 @@ def forward(

assert self.filter_layers is not None
dmatrix = dmatrix.view(-1, self.nnei, 1)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
Expand Down Expand Up @@ -519,7 +520,7 @@ def forward(
None,
None,
None,
sw,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
Expand Down
19 changes: 14 additions & 5 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __init__(
super().__init__()
self.type_map = type_map
self.compress = False
self.prec = PRECISION_DICT[precision]
self.seat = DescrptBlockSeT(
rcut,
rcut_smth,
Expand Down Expand Up @@ -373,7 +374,18 @@ def forward(
The smooth switch function.

"""
return self.seat.forward(nlist, coord_ext, atype_ext, None, mapping)
# cast the input to internal precsion
coord_ext = coord_ext.to(dtype=self.prec)
g1, rot_mat, g2, h2, sw = self.seat.forward(
nlist, coord_ext, atype_ext, None, mapping
)
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -801,7 +813,6 @@ def forward(
protection=self.env_protection,
)
dmatrix = dmatrix.view(-1, self.nnei, 4)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
result = torch.zeros(
Expand Down Expand Up @@ -832,8 +843,6 @@ def forward(
env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j)
if self.compress:
ebd_env_ij = env_ij.view(-1, 1)
ebd_env_ij = ebd_env_ij.to(dtype=self.prec)
env_ij = env_ij.to(dtype=self.prec)
res_ij = torch.ops.deepmd.tabulate_fusion_se_t(
compress_data_ii.contiguous(),
compress_info_ii.cpu().contiguous(),
Expand All @@ -853,7 +862,7 @@ def forward(
# xyz_scatter /= (self.nnei * self.nnei)
result = result.view(nf, nloc, self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
None,
None,
None,
Expand Down
23 changes: 14 additions & 9 deletions deepmd/pt/model/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def __init__(
smooth=smooth,
seed=child_seed(seed, 1),
)
self.prec = PRECISION_DICT[precision]
self.use_econf_tebd = use_econf_tebd
self.type_map = type_map
self.smooth = smooth
Expand Down Expand Up @@ -441,12 +442,14 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei

"""
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
iProzd marked this conversation as resolved.
Show resolved Hide resolved
iProzd marked this conversation as resolved.
Show resolved Hide resolved
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:, :nloc, :]
g1, g2, h2, rot_mat, sw = self.se_ttebd(
g1, _, _, _, sw = self.se_ttebd(
nlist,
extended_coord,
extended_atype,
Expand All @@ -456,7 +459,13 @@ def forward(
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)

return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

@classmethod
def update_sel(
Expand Down Expand Up @@ -540,12 +549,8 @@ def __init__(
self.reinit_exclude(exclude_types)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.tebd_dim_input = self.tebd_dim * 2
Expand Down Expand Up @@ -849,7 +854,7 @@ def forward(
# nf x nl x ng
result = res_ij.view(nframes, nloc, self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
None,
None,
None,
Expand Down
2 changes: 0 additions & 2 deletions deepmd/pt/model/network/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ def forward(
The output.
"""
ori_prec = xx.dtype
xx = xx.to(self.prec)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
yy = (
torch.matmul(xx, self.matrix) + self.bias
if self.bias is not None
Expand All @@ -215,7 +214,6 @@ def forward(
yy += torch.concat([xx, xx], dim=-1)
else:
yy = yy
yy = yy.to(ori_prec)
return yy

def serialize(self) -> dict:
Expand Down
24 changes: 15 additions & 9 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,11 @@ def _forward_common(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
):
xx = descriptor
# cast the input to internal precsion
xx = descriptor.to(self.prec)
fparam = fparam.to(self.prec) if fparam is not None else None
aparam = aparam.to(self.prec) if aparam is not None else None

if self.remove_vaccum_contribution is not None:
# TODO: compute the input for vaccm when remove_vaccum_contribution is set
# Ideally, the input for vacuum should be computed;
Expand Down Expand Up @@ -473,14 +477,16 @@ def _forward_common(

outs = torch.zeros(
(nf, nloc, net_dim_out),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
dtype=self.prec,
device=descriptor.device,
) # jit assertion
if self.mixed_types:
atom_property = self.filter_layers.networks[0](xx) + self.bias_atom_e[atype]
atom_property = self.filter_layers.networks[0](xx)
if xx_zeros is not None:
atom_property -= self.filter_layers.networks[0](xx_zeros)
outs = outs + atom_property # Shape is [nframes, natoms[0], net_dim_out]
outs = (
outs + atom_property + self.bias_atom_e[atype].to(self.prec)
) # Shape is [nframes, natoms[0], net_dim_out]
else:
for type_i, ll in enumerate(self.filter_layers.networks):
mask = (atype == type_i).unsqueeze(-1)
Expand All @@ -494,13 +500,13 @@ def _forward_common(
and not self.remove_vaccum_contribution[type_i]
):
atom_property -= ll(xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property * mask
atom_property = atom_property + self.bias_atom_e[type_i].to(self.prec)
atom_property = torch.where(mask, atom_property, 0.0)
outs = (
outs + atom_property
) # Shape is [nframes, natoms[0], net_dim_out]
# nf x nloc
mask = self.emask(atype)
mask = self.emask(atype).to(torch.bool)
# nf x nloc x nod
outs = outs * mask[:, :, None]
return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)}
outs = torch.where(mask[:, :, None], outs, 0.0)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
return {self.var_name: outs}
Loading
Loading