Skip to content

Commit

Permalink
fix(dpmodel): fix precision (#4343)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced a new environment variable `DP_DTYPE_PROMOTION_STRICT` to
enhance precision handling in TensorFlow tests.
- Added a decorator `@cast_precision` to several descriptor classes,
improving precision management during computations.
- Updated JAX configuration to enable strict dtype promotion based on
the new environment variable.
- Enhanced serialization and deserialization processes to include
precision attributes across multiple classes.

- **Bug Fixes**
- Enhanced type handling and input processing in the `GeneralFitting`
class for better output predictions.
- Improved handling of atomic contributions and exclusions in the
`BaseAtomicModel` class.
- Addressed potential type mismatches during matrix operations in the
`NativeLayer` class.

- **Chores**
- Updated caching mechanisms in the testing workflow to ensure unique
keys based on run parameters.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 14, 2024
1 parent 058e066 commit 6e815a2
Show file tree
Hide file tree
Showing 16 changed files with 175 additions and 34 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
env:
NUM_WORKERS: 0
DP_TEST_TF2_ONLY: 1
DP_DTYPE_PROMOTION_STRICT: 1
if: matrix.group == 1
- run: mv .test_durations .test_durations_${{ matrix.group }}
- name: Upload partial durations
Expand Down
15 changes: 8 additions & 7 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,18 +201,19 @@ def forward_common_atomic(
ret_dict = self.apply_out_stat(ret_dict, atype)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32)
atom_mask = ext_atom_mask[:, :nloc]
if self.atom_excl is not None:
atom_mask *= self.atom_excl.build_type_exclude_mask(atype)
atom_mask = xp.logical_and(
atom_mask, self.atom_excl.build_type_exclude_mask(atype)
)

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
out_shape2 = math.prod(out_shape[2:])
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask
tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr))
ret_dict[kk] = xp.reshape(tmp_arr, out_shape)
ret_dict["mask"] = xp.astype(atom_mask, xp.int32)

return ret_dict

Expand Down
104 changes: 104 additions & 0 deletions deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
ABC,
abstractmethod,
)
from functools import (
wraps,
)
from typing import (
Any,
Callable,
Optional,
overload,
)

import array_api_compat
Expand Down Expand Up @@ -116,6 +121,105 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]:
return np.from_dlpack(x)


def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]:
"""A decorator that casts and casts back the input
and output tensor of a method.
The decorator should be used on an instance method.
The decorator will do the following thing:
(1) It casts input arrays from the global precision
to precision defined by property `precision`.
(2) It casts output arrays from `precision` to
the global precision.
(3) It checks inputs and outputs and only casts when
input or output is an array and its dtype matches
the global precision and `precision`, respectively.
If it does not match (e.g. it is an integer), the decorator
will do nothing on it.
The decorator supports the array API.
Returns
-------
Callable
a decorator that casts and casts back the input and
output array of a method
Examples
--------
>>> class A:
... def __init__(self):
... self.precision = "float32"
...
... @cast_precision
... def f(x: Array, y: Array) -> Array:
... return x**2 + y
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
# only convert tensors
returned_tensor = func(
self,
*[safe_cast_array(vv, "global", self.precision) for vv in args],
**{
kk: safe_cast_array(vv, "global", self.precision)
for kk, vv in kwargs.items()
},
)
if isinstance(returned_tensor, tuple):
return tuple(
safe_cast_array(vv, self.precision, "global") for vv in returned_tensor
)
elif isinstance(returned_tensor, dict):
return {
kk: safe_cast_array(vv, self.precision, "global")
for kk, vv in returned_tensor.items()
}
else:
return safe_cast_array(returned_tensor, self.precision, "global")

return wrapper


@overload
def safe_cast_array(
input: np.ndarray, from_precision: str, to_precision: str
) -> np.ndarray: ...
@overload
def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ...
def safe_cast_array(
input: Optional[np.ndarray], from_precision: str, to_precision: str
) -> Optional[np.ndarray]:
"""Convert an array from a precision to another precision.
If input is not an array or without the specific precision, the method will not
cast it.
Array API is supported.
Parameters
----------
input : np.ndarray or None
Input array
from_precision : str
Array data type that is casted from
to_precision : str
Array data type that casts to
Returns
-------
np.ndarray or None
casted array
"""
if array_api_compat.is_array_api_obj(input):
xp = array_api_compat.array_namespace(input)
if input.dtype == get_xp_precision(xp, from_precision):
return xp.astype(input, get_xp_precision(xp, to_precision))
return input


__all__ = [
"GLOBAL_NP_FLOAT_PRECISION",
"GLOBAL_ENER_FLOAT_PRECISION",
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -330,6 +331,7 @@ def __init__(
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
self.trainable = trainable
self.precision = precision

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -451,6 +453,7 @@ def change_type_map(
obj["davg"] = obj["davg"][remap_index]
obj["dstd"] = obj["dstd"][remap_index]

@cast_precision
def call(
self,
coord_ext,
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -595,6 +596,7 @@ def init_subclass_params(sub_data, sub_class):
self.rcut = self.repinit.get_rcut()
self.ntypes = ntypes
self.sel = self.repinit.sel
self.precision = precision

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -760,6 +762,7 @@ def get_stat_mean_and_stddev(self) -> tuple[list[np.ndarray], list[np.ndarray]]:
stddev_list.append(self.repinit_three_body.stddev)
return mean_list, stddev_list

@cast_precision
def call(
self,
coord_ext: np.ndarray,
Expand Down
14 changes: 5 additions & 9 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand All @@ -30,9 +31,6 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -343,6 +341,7 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -418,9 +417,7 @@ def call(
# nf x nloc x ng x ng1
grrg = np.einsum("flid,fljd->flij", gr, gr1)
# nf x nloc x (ng x ng1)
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype(
GLOBAL_NP_FLOAT_PRECISION
)
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron)
return grrg, gr[..., 1:], None, None, ww

def serialize(self) -> dict:
Expand Down Expand Up @@ -509,6 +506,7 @@ def update_sel(


class DescrptSeAArrayAPI(DescrptSeA):
@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -588,7 +586,5 @@ def call(
# grrg = xp.einsum("flid,fljd->flij", gr, gr1)
grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
# nf x nloc x (ng x ng1)
grrg = xp.astype(
xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), input_dtype
)
grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron))
return grrg, gr[..., 1:], None, None, ww
3 changes: 2 additions & 1 deletion deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
get_xp_precision,
to_numpy_array,
)
Expand Down Expand Up @@ -292,6 +293,7 @@ def cal_g(
gg = self.embeddings[(ll,)].call(ss)
return gg

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -355,7 +357,6 @@ def call(
res_rescale = 1.0 / 5.0
res = xyz_scatter * res_rescale
res = xp.reshape(res, (nf, nloc, ng))
res = xp.astype(res, get_xp_precision(xp, "global"))
return res, None, None, None, ww

def serialize(self) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
get_xp_precision,
to_numpy_array,
)
Expand Down Expand Up @@ -267,6 +268,7 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -320,7 +322,6 @@ def call(
# we don't require atype is the same in all frames
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
rr = xp.reshape(rr, (nf * nloc, nnei, 4))
rr = xp.astype(rr, get_xp_precision(xp, self.precision))

for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
Expand Down Expand Up @@ -352,7 +353,6 @@ def call(
result += res_ij
# nf x nloc x ng
result = xp.reshape(result, (nf, nloc, ng))
result = xp.astype(result, get_xp_precision(xp, "global"))
return result, None, None, None, ww

def serialize(self) -> dict:
Expand Down
5 changes: 3 additions & 2 deletions deepmd/dpmodel/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
get_xp_precision,
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -169,6 +169,7 @@ def __init__(
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
self.trainable = trainable
self.precision = precision

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -290,6 +291,7 @@ def change_type_map(
obj["davg"] = obj["davg"][remap_index]
obj["dstd"] = obj["dstd"][remap_index]

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -744,7 +746,6 @@ def call(
res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
# nf x nl x ng
result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1]))
result = xp.astype(result, get_xp_precision(xp, "global"))
return (
result,
None,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.common import (
cast_precision,
)
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
Expand Down Expand Up @@ -174,6 +177,7 @@ def output_def(self):
]
)

@cast_precision
def call(
self,
descriptor: np.ndarray,
Expand Down
20 changes: 12 additions & 8 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,18 +439,22 @@ def _call_common(
):
assert xx_zeros is not None
atom_property -= self.nets[(type_i,)](xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i, ...]
atom_property = atom_property * xp.astype(mask, atom_property.dtype)
atom_property = xp.where(
mask, atom_property, xp.zeros_like(atom_property)
)
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
outs = self.nets[()](xx) + xp.reshape(
xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
[nf, nloc, net_dim_out],
)
outs = self.nets[()](xx)
if xx_zeros is not None:
outs -= self.nets[()](xx_zeros)
outs += xp.reshape(
xp.take(
xp.astype(self.bias_atom_e, outs.dtype), xp.reshape(atype, [-1]), axis=0
),
[nf, nloc, net_dim_out],
)
# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)
# nf x nloc x nod
outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype)
return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))}
outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
return {self.var_name: outs}
Loading

0 comments on commit 6e815a2

Please sign in to comment.