diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py
index afe2eaffb6..e36182e712 100644
--- a/deepmd/dpmodel/model/make_model.py
+++ b/deepmd/dpmodel/model/make_model.py
@@ -222,22 +222,42 @@ def call_lower(
                 extended_coord, fparam=fparam, aparam=aparam
             )
             del extended_coord, fparam, aparam
-            atomic_ret = self.atomic_model.forward_common_atomic(
+            model_predict = self.forward_common_atomic(
                 cc_ext,
                 extended_atype,
                 nlist,
                 mapping=mapping,
                 fparam=fp,
                 aparam=ap,
+                do_atomic_virial=do_atomic_virial,
+            )
+            model_predict = self.output_type_cast(model_predict, input_prec)
+            return model_predict
+
+        def forward_common_atomic(
+            self,
+            extended_coord: np.ndarray,
+            extended_atype: np.ndarray,
+            nlist: np.ndarray,
+            mapping: Optional[np.ndarray] = None,
+            fparam: Optional[np.ndarray] = None,
+            aparam: Optional[np.ndarray] = None,
+            do_atomic_virial: bool = False,
+        ):
+            atomic_ret = self.atomic_model.forward_common_atomic(
+                extended_coord,
+                extended_atype,
+                nlist,
+                mapping=mapping,
+                fparam=fparam,
+                aparam=aparam,
             )
-            model_predict = fit_output_to_model_output(
+            return fit_output_to_model_output(
                 atomic_ret,
                 self.atomic_output_def(),
-                cc_ext,
+                extended_coord,
                 do_atomic_virial=do_atomic_virial,
             )
-            model_predict = self.output_type_cast(model_predict, input_prec)
-            return model_predict
 
         forward_lower = call_lower
 
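Note on the refactor above: `call_lower` now delegates the atomic evaluation plus `fit_output_to_model_output` to an overridable `forward_common_atomic` hook, keeping the precision casting in `call_lower` itself, so a backend can swap in its own derivative-aware implementation (the JAX model later in this patch does exactly that). A minimal sketch of the pattern with hypothetical class names, not the deepmd API:

```python
class BaseModelSketch:
    """Hypothetical stand-in for the class built by make_model."""

    def call_lower(self, coords):
        # shared pre/post-processing (input/output casting) stays here
        return self.forward_common_atomic(coords)

    def forward_common_atomic(self, coords):
        return {"energy": sum(coords)}


class JaxModelSketch(BaseModelSketch):
    """Backend override, analogous to EnergyModelJAX below."""

    def forward_common_atomic(self, coords):
        out = super().forward_common_atomic(coords)
        # a real backend would add autodiff forces/virials here (toy values)
        out["energy_derv_r"] = [-2.0 * c for c in coords]
        return out


print(JaxModelSketch().call_lower([0.0, 1.0, 2.0]))
# {'energy': 3.0, 'energy_derv_r': [-0.0, -2.0, -4.0]}
```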
""" + xp = array_api_compat.get_namespace(mapping) new_ret = {} for kk in model_output_def.keys_outp(): vv = model_ret[kk] @@ -65,15 +89,63 @@ def communicate_extended_output( if vdef.reducible: kk_redu = get_reduce_name(kk) new_ret[kk_redu] = model_ret[kk_redu] + kk_derv_r, kk_derv_c = get_deriv_name(kk) + mldims = list(mapping.shape) + vldims = get_leading_dims(vv, vdef) if vdef.r_differentiable: - kk_derv_r, kk_derv_c = get_deriv_name(kk) - # name holders - new_ret[kk_derv_r] = None + if model_ret[kk_derv_r] is not None: + derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005 + mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims))) + mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims) + force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype) + # jax only + if array_api_compat.is_jax_array(force): + from deepmd.jax.common import ( + scatter_sum, + ) + + force = scatter_sum( + force, + 1, + mapping, + model_ret[kk_derv_r], + ) + else: + raise NotImplementedError("Only JAX arrays are supported.") + new_ret[kk_derv_r] = force + else: + # name holders + new_ret[kk_derv_r] = None if vdef.c_differentiable: assert vdef.r_differentiable - kk_derv_r, kk_derv_c = get_deriv_name(kk) - new_ret[kk_derv_c] = None - new_ret[kk_derv_c + "_redu"] = None + if model_ret[kk_derv_c] is not None: + derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005 + mapping = xp.tile( + mapping, [1] * (len(mldims) + len(vdef.shape)) + [3] + ) + virial = xp.zeros( + vldims + derv_c_ext_dims, + dtype=vv.dtype, + ) + # jax only + if array_api_compat.is_jax_array(virial): + from deepmd.jax.common import ( + scatter_sum, + ) + + virial = scatter_sum( + virial, + 1, + mapping, + model_ret[kk_derv_c], + ) + else: + raise NotImplementedError("Only JAX arrays are supported.") + new_ret[kk_derv_c] = virial + new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1) + else: + new_ret[kk_derv_c] = None + new_ret[kk_derv_c + "_redu"] = None if not do_atomic_virial: # pop atomic virial, because it is not correctly calculated. 
diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py
index f4bc333a03..aa8520202e 100644
--- a/deepmd/dpmodel/utils/env_mat.py
+++ b/deepmd/dpmodel/utils/env_mat.py
@@ -61,7 +61,9 @@ def _make_env_mat(
     # nf x nloc x nnei x 3
     diff = coord_r - coord_l
     # nf x nloc x nnei
-    length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
+    # the grad of JAX vector_norm is NaN at x=0
+    diff_ = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff)
+    length = xp.linalg.vector_norm(diff_, axis=-1, keepdims=True)
     # for index 0 nloc atom
     length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype)
     t0 = 1 / (length + protection)
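The clamp above guards the entries where `diff` is exactly zero (e.g. padded or self pairs): the derivative of the Euclidean norm is x/‖x‖, which reverse-mode autodiff evaluates as 0/0 there. A quick demonstration of the failure mode with plain `jnp.linalg.norm` (assumes JAX is installed):

```python
import jax
import jax.numpy as jnp

# d||x||/dx = x / ||x||, which autodiff evaluates as 0/0 = NaN at x = 0
print(jax.grad(lambda x: jnp.linalg.norm(x))(jnp.zeros(3)))  # [nan nan nan]
```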
diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py
index f372e97eb5..59f36d11ad 100644
--- a/deepmd/jax/common.py
+++ b/deepmd/jax/common.py
@@ -95,3 +95,13 @@ def __dlpack__(self, *args, **kwargs):
 
     def __dlpack_device__(self, *args, **kwargs):
         return self.value.__dlpack_device__(*args, **kwargs)
+
+
+def scatter_sum(input, dim, index: jnp.ndarray, src: jnp.ndarray) -> jnp.ndarray:
+    """Reduces all values from the src tensor to the indices specified in the index tensor."""
+    idx = jnp.arange(input.size, dtype=jnp.int64).reshape(input.shape)
+    new_idx = jnp.take_along_axis(idx, index, axis=dim).ravel()
+    shape = input.shape
+    input = input.ravel()
+    input = input.at[new_idx].add(src.ravel())
+    return input.reshape(shape)
diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py
index 5a5a7f6bf0..ee11e17125 100644
--- a/deepmd/jax/env.py
+++ b/deepmd/jax/env.py
@@ -10,6 +10,7 @@
 )
 
 jax.config.update("jax_enable_x64", True)
+# jax.config.update("jax_debug_nans", True)
 
 __all__ = [
     "jax",
diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py
index fee4855da3..8631c85d16 100644
--- a/deepmd/jax/model/base_model.py
+++ b/deepmd/jax/model/base_model.py
@@ -1,6 +1,107 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Optional,
+)
+
 from deepmd.dpmodel.model.base_model import (
     make_base_model,
 )
+from deepmd.dpmodel.output_def import (
+    get_deriv_name,
+    get_reduce_name,
+)
+from deepmd.jax.env import (
+    jax,
+    jnp,
+)
 
 BaseModel = make_base_model()
+
+
+def forward_common_atomic(
+    self,
+    extended_coord: jnp.ndarray,
+    extended_atype: jnp.ndarray,
+    nlist: jnp.ndarray,
+    mapping: Optional[jnp.ndarray] = None,
+    fparam: Optional[jnp.ndarray] = None,
+    aparam: Optional[jnp.ndarray] = None,
+    do_atomic_virial: bool = False,
+):
+    atomic_ret = self.atomic_model.forward_common_atomic(
+        extended_coord,
+        extended_atype,
+        nlist,
+        mapping=mapping,
+        fparam=fparam,
+        aparam=aparam,
+    )
+    atomic_output_def = self.atomic_output_def()
+    model_predict = {}
+    for kk, vv in atomic_ret.items():
+        model_predict[kk] = vv
+        vdef = atomic_output_def[kk]
+        shap = vdef.shape
+        atom_axis = -(len(shap) + 1)
+        if vdef.reducible:
+            kk_redu = get_reduce_name(kk)
+            model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
+            kk_derv_r, kk_derv_c = get_deriv_name(kk)
+            if vdef.r_differentiable:
+
+                def eval_output(
+                    cc_ext,
+                    extended_atype,
+                    nlist,
+                    mapping,
+                    fparam,
+                    aparam,
+                    *,
+                    _kk=kk,
+                    _atom_axis=atom_axis,
+                ):
+                    atomic_ret = self.atomic_model.forward_common_atomic(
+                        cc_ext[None, ...],
+                        extended_atype[None, ...],
+                        nlist[None, ...],
+                        mapping=mapping[None, ...] if mapping is not None else None,
+                        fparam=fparam[None, ...] if fparam is not None else None,
+                        aparam=aparam[None, ...] if aparam is not None else None,
+                    )
+                    return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)
+
+                # extended_coord: [nf, nall, 3]
+                # ff: [nf, *def, nall, 3]
+                ff = -jax.vmap(jax.jacrev(eval_output, argnums=0))(
+                    extended_coord,
+                    extended_atype,
+                    nlist,
+                    mapping,
+                    fparam,
+                    aparam,
+                )
+                # extended_force: [nf, nall, *def, 3]
+                def_ndim = len(vdef.shape)
+                extended_force = jnp.transpose(
+                    ff, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
+                )
+
+                model_predict[kk_derv_r] = extended_force
+                if vdef.c_differentiable:
+                    assert vdef.r_differentiable
+                    # avr: [nf, *def, nall, 3, 3]
+                    avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord)
+                    # avr: [nf, *def, nall, 9]
+                    avr = jnp.reshape(avr, [*ff.shape[:-1], 9])
+                    # extended_virial: [nf, nall, *def, 9]
+                    extended_virial = jnp.transpose(
+                        avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
+                    )
+
+                    # the correction sums to zero, which does not contribute to global virial
+                    # cannot jit
+                    # if do_atomic_virial:
+                    #     raise NotImplementedError("Atomic virial is not implemented yet.")
+                    # to [...,3,3] -> [...,9]
+                    model_predict[kk_derv_c] = extended_virial
+    return model_predict
diff --git a/deepmd/jax/model/ener_model.py b/deepmd/jax/model/ener_model.py
index 79c5a29e88..b1bf568544 100644
--- a/deepmd/jax/model/ener_model.py
+++ b/deepmd/jax/model/ener_model.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
     Any,
+    Optional,
 )
 
 from deepmd.dpmodel.model import EnergyModel as EnergyModelDP
@@ -10,8 +11,12 @@
 from deepmd.jax.common import (
     flax_module,
 )
+from deepmd.jax.env import (
+    jnp,
+)
 from deepmd.jax.model.base_model import (
     BaseModel,
+    forward_common_atomic,
 )
 
 
@@ -22,3 +27,24 @@ def __setattr__(self, name: str, value: Any) -> None:
         if name == "atomic_model":
             value = DPAtomicModel.deserialize(value.serialize())
         return super().__setattr__(name, value)
+
+    def forward_common_atomic(
+        self,
+        extended_coord: jnp.ndarray,
+        extended_atype: jnp.ndarray,
+        nlist: jnp.ndarray,
+        mapping: Optional[jnp.ndarray] = None,
+        fparam: Optional[jnp.ndarray] = None,
+        aparam: Optional[jnp.ndarray] = None,
+        do_atomic_virial: bool = False,
+    ):
+        return forward_common_atomic(
+            self,
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping=mapping,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
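The core of the JAX `forward_common_atomic` above is the `-jax.vmap(jax.jacrev(...))` pattern: differentiate the summed per-atom output of a single frame with respect to the extended coordinates, batch over frames with `vmap`, then build the per-atom virial as an outer product of that derivative with the coordinates (the `einsum`). A self-contained toy with the same pattern, using a made-up quadratic energy instead of the deepmd atomic model:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def total_energy(coords):  # [nall, 3] -> scalar; toy stand-in for the atomic model
    return jnp.sum(coords**2)


coords = jnp.ones((2, 5, 3))  # [nf, nall, 3]
# per-frame Jacobian of the energy, batched over frames, as in base_model.py
ff = -jax.vmap(jax.jacrev(total_energy))(coords)  # [nf, nall, 3]
# per-atom virial as an outer product with the coordinates, as in the einsum above
avr = jnp.einsum("fai,faj->faij", ff, coords)     # [nf, nall, 3, 3]
print(ff.shape, avr.shape)  # (2, 5, 3) (2, 5, 3, 3)
```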
diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py
index 885662c766..bcad7c4502 100644
--- a/source/tests/consistent/common.py
+++ b/source/tests/consistent/common.py
@@ -69,6 +69,8 @@
     "INSTALLED_ARRAY_API_STRICT",
 ]
 
+SKIP_FLAG = object()
+
 
 class CommonTest(ABC):
     data: ClassVar[dict]
@@ -362,6 +364,8 @@ def test_dp_consistent_with_ref(self):
             data2 = dp_obj.serialize()
             np.testing.assert_equal(data1, data2)
         for rr1, rr2 in zip(ret1, ret2):
+            if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG:
+                continue
             np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
             assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"
diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py
index 4112e09cff..11940d9bdf 100644
--- a/source/tests/consistent/model/common.py
+++ b/source/tests/consistent/model/common.py
@@ -51,7 +51,7 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
             {},
             suffix=suffix,
         )
-        return [ret["energy"], ret["atom_ener"]], {
+        return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], {
             t_coord: coords,
             t_type: atype,
             t_natoms: natoms,
diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py
index 78a2aac703..2a358ba7e0 100644
--- a/source/tests/consistent/model/test_ener.py
+++ b/source/tests/consistent/model/test_ener.py
@@ -16,6 +16,7 @@
     INSTALLED_JAX,
     INSTALLED_PT,
     INSTALLED_TF,
+    SKIP_FLAG,
     CommonTest,
     parameterized,
 )
@@ -94,6 +95,21 @@ def data(self) -> dict:
     jax_class = EnergyModelJAX
     args = model_args()
 
+    def get_reference_backend(self):
+        """Get the reference backend.
+
+        We need a reference backend that can reproduce forces.
+        """
+        if not self.skip_pt:
+            return self.RefBackend.PT
+        if not self.skip_tf:
+            return self.RefBackend.TF
+        if not self.skip_jax:
+            return self.RefBackend.JAX
+        if not self.skip_dp:
+            return self.RefBackend.DP
+        raise ValueError("No available reference")
+
     @property
     def skip_tf(self):
         return (
@@ -195,11 +211,26 @@ def eval_jax(self, jax_obj: Any) -> Any:
     def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
         # shape not matched. ravel...
         if backend is self.RefBackend.DP:
-            return (ret["energy_redu"].ravel(), ret["energy"].ravel())
+            return (
+                ret["energy_redu"].ravel(),
+                ret["energy"].ravel(),
+                SKIP_FLAG,
+                SKIP_FLAG,
+            )
         elif backend is self.RefBackend.PT:
-            return (ret["energy"].ravel(), ret["atom_energy"].ravel())
+            return (
+                ret["energy"].ravel(),
+                ret["atom_energy"].ravel(),
+                ret["force"].ravel(),
+                ret["virial"].ravel(),
+            )
         elif backend is self.RefBackend.TF:
-            return (ret[0].ravel(), ret[1].ravel())
+            return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel())
         elif backend is self.RefBackend.JAX:
-            return (ret["energy_redu"].ravel(), ret["energy"].ravel())
+            return (
+                ret["energy_redu"].ravel(),
+                ret["energy"].ravel(),
+                ret["energy_derv_r"].ravel(),
+                ret["energy_derv_c_redu"].ravel(),
+            )
         raise ValueError(f"Unknown backend: {backend}")
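The extended tests above compare raveled forces and virials across backends, with `SKIP_FLAG` marking entries a backend cannot produce (the pure NumPy DP reference leaves the derivative outputs as `None` placeholders outside JAX). Independent of any backend, the autodiff force convention these tests rely on (force = negative coordinate gradient of the energy) can be sanity-checked against central finite differences; a toy check with a hypothetical pair potential, not the test suite's model:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_energy(coords):  # [natoms, 3] -> scalar; hypothetical 1/r pair potential
    diff = coords[:, None, :] - coords[None, :, :]
    mask = jnp.eye(coords.shape[0], dtype=bool)
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + mask)  # pad the diagonal to avoid 0
    return jnp.sum(jnp.where(mask, 0.0, 1.0 / r)) / 2


coords = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.5, 0.0]])
force = -jax.grad(toy_energy)(coords)

# central finite differences, one component at a time
eps = 1e-5
fd = jnp.zeros_like(coords)
for a in range(coords.shape[0]):
    for i in range(3):
        e = jnp.zeros_like(coords).at[a, i].set(eps)
        fd = fd.at[a, i].set(-(toy_energy(coords + e) - toy_energy(coords - e)) / (2 * eps))

assert jnp.allclose(force, fd, atol=1e-8)
```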