Skip to content

Commit

Permalink
fix(pt): fix precision (deepmodeling#4344)
Browse files Browse the repository at this point in the history
Tried to implement the decorator as in deepmodeling#4343, but encountered JIT
errors.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Enhanced precision handling across various descriptor classes and
methods, ensuring consistent tensor operations.
- Updated output formats in several classes to improve clarity and
usability.
- Introduced a new environment variable for stricter control over tensor
precision handling.
- Added a new parameter to the `DipoleFittingNet` class for excluding
specific types.

- **Bug Fixes**
- Removed conditions that skipped tests for "float32" data type,
allowing all tests to run consistently.

- **Documentation**
- Improved error messages for dimension mismatches and unsupported
parameters, enhancing user understanding.

- **Tests**
- Adjusted test parameters for consistency in handling `fparam` and
`aparam` across multiple test cases.
- Simplified tensor handling in tests by removing unnecessary type
conversions before compression.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored Nov 13, 2024
1 parent 698b08d commit 47b76c8
Show file tree
Hide file tree
Showing 19 changed files with 114 additions and 64 deletions.
12 changes: 11 additions & 1 deletion deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.tabulate import (
Expand Down Expand Up @@ -311,6 +312,7 @@ def __init__(
use_tebd_bias=use_tebd_bias,
type_map=type_map,
)
self.prec = PRECISION_DICT[precision]
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
self.trainable = trainable
Expand Down Expand Up @@ -678,6 +680,8 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei
"""
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
Expand All @@ -693,7 +697,13 @@ def forward(
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)

return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if g2 is not None else None,
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

@classmethod
def update_sel(
Expand Down
15 changes: 14 additions & 1 deletion deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
)
from deepmd.pt.utils.nlist import (
build_multiple_neighbor_list,
get_multiple_nlist_key,
Expand Down Expand Up @@ -268,6 +271,7 @@ def init_subclass_params(sub_data, sub_class):
)
self.concat_output_tebd = concat_output_tebd
self.precision = precision
self.prec = PRECISION_DICT[self.precision]
self.smooth = smooth
self.exclude_types = exclude_types
self.env_protection = env_protection
Expand Down Expand Up @@ -745,6 +749,9 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei
"""
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)

use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
Expand Down Expand Up @@ -810,7 +817,13 @@ def forward(
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

@classmethod
def update_sel(
Expand Down
12 changes: 6 additions & 6 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
)
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
Expand Down Expand Up @@ -237,6 +240,7 @@ def __init__(
self.reinit_exclude(exclude_types)
self.env_protection = env_protection
self.precision = precision
self.prec = PRECISION_DICT[precision]
self.trainable_ln = trainable_ln
self.ln_eps = ln_eps
self.epsilon = 1e-4
Expand Down Expand Up @@ -286,12 +290,8 @@ def __init__(
self.layers = torch.nn.ModuleList(layers)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.stats = None
Expand Down
19 changes: 15 additions & 4 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
super().__init__()
self.type_map = type_map
self.compress = False
self.prec = PRECISION_DICT[precision]
self.sea = DescrptBlockSeA(
rcut,
rcut_smth,
Expand Down Expand Up @@ -337,7 +338,18 @@ def forward(
The smooth switch function.
"""
return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping)
# cast the input to internal precsion
coord_ext = coord_ext.to(dtype=self.prec)
g1, rot_mat, g2, h2, sw = self.sea.forward(
nlist, coord_ext, atype_ext, None, mapping
)
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -742,7 +754,6 @@ def forward(
)

dmatrix = dmatrix.view(-1, self.nnei, 4)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
Expand Down Expand Up @@ -811,8 +822,8 @@ def forward(
result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron)
rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
rot_mat,
None,
None,
sw,
Expand Down
10 changes: 2 additions & 8 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,8 @@ def __init__(
)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2
Expand Down Expand Up @@ -568,8 +564,6 @@ def forward(
# nfnl x nnei x ng
# gg = gg_s * gg_t + gg_s
gg_t = gg_t.reshape(-1, gg_t.size(-1))
# Convert all tensors to the required precision at once
ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t))
xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten(
self.compress_data[0].contiguous(),
self.compress_info[0].cpu().contiguous(),
Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,8 @@ def forward(
The smooth switch function.
"""
# cast the input to internal precsion
coord_ext = coord_ext.to(dtype=self.prec)
del mapping, comm_dict
nf = nlist.shape[0]
nloc = nlist.shape[1]
Expand All @@ -474,7 +476,6 @@ def forward(

assert self.filter_layers is not None
dmatrix = dmatrix.view(-1, self.nnei, 1)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
Expand Down Expand Up @@ -519,7 +520,7 @@ def forward(
None,
None,
None,
sw,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
Expand Down
19 changes: 14 additions & 5 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __init__(
super().__init__()
self.type_map = type_map
self.compress = False
self.prec = PRECISION_DICT[precision]
self.seat = DescrptBlockSeT(
rcut,
rcut_smth,
Expand Down Expand Up @@ -373,7 +374,18 @@ def forward(
The smooth switch function.
"""
return self.seat.forward(nlist, coord_ext, atype_ext, None, mapping)
# cast the input to internal precsion
coord_ext = coord_ext.to(dtype=self.prec)
g1, rot_mat, g2, h2, sw = self.seat.forward(
nlist, coord_ext, atype_ext, None, mapping
)
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -801,7 +813,6 @@ def forward(
protection=self.env_protection,
)
dmatrix = dmatrix.view(-1, self.nnei, 4)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
result = torch.zeros(
Expand Down Expand Up @@ -832,8 +843,6 @@ def forward(
env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j)
if self.compress:
ebd_env_ij = env_ij.view(-1, 1)
ebd_env_ij = ebd_env_ij.to(dtype=self.prec)
env_ij = env_ij.to(dtype=self.prec)
res_ij = torch.ops.deepmd.tabulate_fusion_se_t(
compress_data_ii.contiguous(),
compress_info_ii.cpu().contiguous(),
Expand All @@ -853,7 +862,7 @@ def forward(
# xyz_scatter /= (self.nnei * self.nnei)
result = result.view(nf, nloc, self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
None,
None,
None,
Expand Down
23 changes: 14 additions & 9 deletions deepmd/pt/model/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def __init__(
smooth=smooth,
seed=child_seed(seed, 1),
)
self.prec = PRECISION_DICT[precision]
self.use_econf_tebd = use_econf_tebd
self.type_map = type_map
self.smooth = smooth
Expand Down Expand Up @@ -441,12 +442,14 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei
"""
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:, :nloc, :]
g1, g2, h2, rot_mat, sw = self.se_ttebd(
g1, _, _, _, sw = self.se_ttebd(
nlist,
extended_coord,
extended_atype,
Expand All @@ -456,7 +459,13 @@ def forward(
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)

return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

@classmethod
def update_sel(
Expand Down Expand Up @@ -540,12 +549,8 @@ def __init__(
self.reinit_exclude(exclude_types)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.tebd_dim_input = self.tebd_dim * 2
Expand Down Expand Up @@ -849,7 +854,7 @@ def forward(
# nf x nl x ng
result = res_ij.view(nframes, nloc, self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
None,
None,
None,
Expand Down
6 changes: 4 additions & 2 deletions deepmd/pt/model/network/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def forward(
The output.
"""
ori_prec = xx.dtype
xx = xx.to(self.prec)
if not env.DP_DTYPE_PROMOTION_STRICT:
xx = xx.to(self.prec)
yy = (
torch.matmul(xx, self.matrix) + self.bias
if self.bias is not None
Expand All @@ -215,7 +216,8 @@ def forward(
yy += torch.concat([xx, xx], dim=-1)
else:
yy = yy
yy = yy.to(ori_prec)
if not env.DP_DTYPE_PROMOTION_STRICT:
yy = yy.to(ori_prec)
return yy

def serialize(self) -> dict:
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def forward(
):
nframes, nloc, _ = descriptor.shape
assert gr is not None, "Must provide the rotation matrix for dipole fitting."
# cast the input to internal precsion
gr = gr.to(self.prec)
# (nframes, nloc, m1)
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
Expand Down
Loading

0 comments on commit 47b76c8

Please sign in to comment.