def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate a method so arrays are cast on the way in and on the way out.

    The decorated function must be an *instance method*: the wrapper takes
    ``self`` as its first argument and reads ``self.precision`` to decide
    which precision the method body works in.

    The wrapper does the following:
    (1) Every positional and keyword argument is passed through
        :func:`safe_cast_array`, casting from the ``"global"`` precision
        to ``self.precision``.
    (2) The return value is passed through :func:`safe_cast_array`,
        casting from ``self.precision`` back to the ``"global"`` precision.
        If the return value is a ``tuple``, each element is handled
        separately and a new ``tuple`` is returned.
    (3) A value is only cast when it is an array whose dtype matches the
        source precision. Anything else (e.g. an integer array, ``None``,
        a Python scalar) is passed through untouched.

    Only ``tuple`` return values are unpacked; arrays nested inside other
    containers (lists, dicts) are returned as they are.

    The decorator supports the array API.

    Parameters
    ----------
    func : Callable
        The method to wrap. Its owner must provide a ``precision`` attribute.

    Returns
    -------
    Callable
        The wrapped method, which casts and casts back the input and
        output arrays of ``func``.

    Examples
    --------
    >>> class A:
    ...     def __init__(self):
    ...         self.precision = "float32"
    ...
    ...     @cast_precision
    ...     def f(self, x: Array, y: Array) -> Array:
    ...         return x**2 + y
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Inputs: global precision -> the precision the method computes in.
        # Non-array arguments fall through safe_cast_array unchanged.
        cast_args = [safe_cast_array(vv, "global", self.precision) for vv in args]
        cast_kwargs = {
            kk: safe_cast_array(vv, "global", self.precision)
            for kk, vv in kwargs.items()
        }
        returned_tensor = func(self, *cast_args, **cast_kwargs)

        # Outputs: the method's precision -> global precision.
        if isinstance(returned_tensor, tuple):
            # Multi-output methods: cast each element independently.
            return tuple(
                safe_cast_array(vv, self.precision, "global") for vv in returned_tensor
            )
        return safe_cast_array(returned_tensor, self.precision, "global")

    return wrapper


def safe_cast_array(input: Any, from_precision: str, to_precision: str) -> Any:
    """Cast an array from one precision to another, if and only if it applies.

    The value is cast only when it is an array-API object *and* its dtype
    equals the dtype that ``from_precision`` resolves to in the array's own
    namespace. In every other case the very same object is returned.

    Array API is supported.

    Parameters
    ----------
    input : Any
        The candidate value. Arrays from any array-API compatible library
        are handled; non-array values are accepted and returned unchanged.
        (The name shadows the ``input`` builtin but is kept because it is
        part of the function's signature.)
    from_precision : str
        Name of the precision the array is expected to have, as understood
        by ``get_xp_precision``.
    to_precision : str
        Name of the precision to cast to, as understood by
        ``get_xp_precision``.

    Returns
    -------
    Any
        The cast array when the dtype matched ``from_precision``;
        otherwise ``input`` itself.
    """
    if not array_api_compat.is_array_api_obj(input):
        # Not an array (None, int, list, ...): nothing to cast.
        return input
    # Resolve dtypes in the array's own namespace so the comparison and the
    # cast both use that library's dtype objects.
    xp = array_api_compat.array_namespace(input)
    if input.dtype == get_xp_precision(xp, from_precision):
        return xp.astype(input, get_xp_precision(xp, to_precision))
    # Array with a different dtype (e.g. integer indices): leave untouched.
    return input
obj["davg"][remap_index] obj["dstd"] = obj["dstd"][remap_index] + @cast_precision def call( self, coord_ext, diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 097be2ef09..d82f136a9f 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -14,6 +14,7 @@ xp_take_along_axis, ) from deepmd.dpmodel.common import ( + cast_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -594,6 +595,7 @@ def init_subclass_params(sub_data, sub_class): self.rcut = self.repinit.get_rcut() self.ntypes = ntypes self.sel = self.repinit.sel + self.precision = precision def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -757,6 +759,7 @@ def get_stat_mean_and_stddev(self) -> tuple[list[np.ndarray], list[np.ndarray]]: stddev_list.append(self.repinit_three_body.stddev) return mean_list, stddev_list + @cast_precision def call( self, coord_ext: np.ndarray, diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 63402b6f84..4ffdd025e8 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -15,6 +15,7 @@ NativeOP, ) from deepmd.dpmodel.common import ( + cast_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -29,9 +30,6 @@ from deepmd.dpmodel.utils.update_sel import ( UpdateSel, ) -from deepmd.env import ( - GLOBAL_NP_FLOAT_PRECISION, -) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -340,6 +338,7 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + @cast_precision def call( self, coord_ext, @@ -415,9 +414,7 @@ def call( # nf x nloc x ng x ng1 grrg = np.einsum("flid,fljd->flij", gr, gr1) # nf x nloc x (ng x ng1) - grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype( - GLOBAL_NP_FLOAT_PRECISION - ) + grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron) return grrg, gr[..., 1:], None, None, ww def 
serialize(self) -> dict: @@ -506,6 +503,7 @@ def update_sel( class DescrptSeAArrayAPI(DescrptSeA): + @cast_precision def call( self, coord_ext, @@ -585,7 +583,5 @@ def call( # grrg = xp.einsum("flid,fljd->flij", gr, gr1) grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4) # nf x nloc x (ng x ng1) - grrg = xp.astype( - xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), input_dtype - ) + grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)) return grrg, gr[..., 1:], None, None, ww diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index d652eb1420..45757c68ec 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -14,6 +14,7 @@ NativeOP, ) from deepmd.dpmodel.common import ( + cast_precision, get_xp_precision, to_numpy_array, ) @@ -289,6 +290,7 @@ def cal_g( gg = self.embeddings[(ll,)].call(ss) return gg + @cast_precision def call( self, coord_ext, @@ -352,7 +354,6 @@ def call( res_rescale = 1.0 / 5.0 res = xyz_scatter * res_rescale res = xp.reshape(res, (nf, nloc, ng)) - res = xp.astype(res, get_xp_precision(xp, "global")) return res, None, None, None, ww def serialize(self) -> dict: diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index be587c77da..38bd660af2 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -14,6 +14,7 @@ NativeOP, ) from deepmd.dpmodel.common import ( + cast_precision, get_xp_precision, to_numpy_array, ) @@ -264,6 +265,7 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + @cast_precision def call( self, coord_ext, @@ -317,7 +319,6 @@ def call( # we don't require atype is the same in all frames exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) rr = xp.reshape(rr, (nf * nloc, nnei, 4)) - rr = xp.astype(rr, get_xp_precision(xp, self.precision)) for embedding_idx in itertools.product( 
range(self.ntypes), repeat=self.embeddings.ndim @@ -349,7 +350,6 @@ def call( result += res_ij # nf x nloc x ng result = xp.reshape(result, (nf, nloc, ng)) - result = xp.astype(result, get_xp_precision(xp, "global")) return result, None, None, None, ww def serialize(self) -> dict: diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index 298f823690..b1b7cfa930 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -16,7 +16,7 @@ xp_take_along_axis, ) from deepmd.dpmodel.common import ( - get_xp_precision, + cast_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -168,6 +168,7 @@ def __init__( self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd self.trainable = trainable + self.precision = precision def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -287,6 +288,7 @@ def change_type_map( obj["davg"] = obj["davg"][remap_index] obj["dstd"] = obj["dstd"][remap_index] + @cast_precision def call( self, coord_ext, @@ -741,7 +743,6 @@ def call( res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei)) # nf x nl x ng result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1])) - result = xp.astype(result, get_xp_precision(xp, "global")) return ( result, None, diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 344dab7ff1..b4691bf8a3 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -364,6 +364,11 @@ def _call_common( """ xp = array_api_compat.array_namespace(descriptor, atype) + descriptor = xp.astype(descriptor, get_xp_precision(xp, self.precision)) + if fparam is not None: + fparam = xp.astype(fparam, get_xp_precision(xp, self.precision)) + if aparam is not None: + aparam = xp.astype(aparam, get_xp_precision(xp, self.precision)) nf, nloc, nd = descriptor.shape net_dim_out = self._net_out_dim() # check input dim @@ -439,18 
+444,24 @@ def _call_common( ): assert xx_zeros is not None atom_property -= self.nets[(type_i,)](xx_zeros) - atom_property = atom_property + self.bias_atom_e[type_i, ...] - atom_property = atom_property * xp.astype(mask, atom_property.dtype) + atom_property = xp.where( + mask, atom_property, xp.zeros_like(atom_property) + ) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + outs = xp.astype(outs, get_xp_precision(xp, "global")) + for type_i in range(self.ntypes): + outs = outs + self.bias_atom_e[type_i, ...] else: - outs = self.nets[()](xx) + xp.reshape( + outs = self.nets[()](xx) + if xx_zeros is not None: + outs -= self.nets[()](xx_zeros) + outs = xp.astype(outs, get_xp_precision(xp, "global")) + outs += xp.reshape( xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0), [nf, nloc, net_dim_out], ) - if xx_zeros is not None: - outs -= self.nets[()](xx_zeros) # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod - outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype) + outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))} diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py index 1b90433b00..02e31ae66e 100644 --- a/deepmd/jax/env.py +++ b/deepmd/jax/env.py @@ -13,6 +13,9 @@ jax.config.update("jax_enable_x64", True) # jax.config.update("jax_debug_nans", True) +if os.environ.get("DP_DTYPE_PROMOTION_STRICT") == "1": + jax.config.update("jax_numpy_dtype_promotion", "strict") + __all__ = [ "jax", "jnp",