diff --git a/deepmd/jax/atomic_model/dp_atomic_model.py b/deepmd/jax/atomic_model/dp_atomic_model.py index 077209e29a..5898fd3ff8 100644 --- a/deepmd/jax/atomic_model/dp_atomic_model.py +++ b/deepmd/jax/atomic_model/dp_atomic_model.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, + Optional, ) from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP @@ -13,6 +14,10 @@ from deepmd.jax.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.jax.env import ( + jax, + jnp, +) from deepmd.jax.fitting.base_fitting import ( BaseFitting, ) @@ -28,3 +33,21 @@ class DPAtomicModel(DPAtomicModelDP): def __setattr__(self, name: str, value: Any) -> None: value = base_atomic_model_set_attr(name, value) return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + ) -> dict[str, jnp.ndarray]: + return super().forward_common_atomic( + extended_coord, + extended_atype, + jax.lax.stop_gradient(nlist), + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) diff --git a/deepmd/jax/atomic_model/linear_atomic_model.py b/deepmd/jax/atomic_model/linear_atomic_model.py index c86e3ef02c..6ce82fa07c 100644 --- a/deepmd/jax/atomic_model/linear_atomic_model.py +++ b/deepmd/jax/atomic_model/linear_atomic_model.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, + Optional, ) from deepmd.dpmodel.atomic_model.linear_atomic_model import ( @@ -20,6 +21,10 @@ flax_module, to_jax_array, ) +from deepmd.jax.env import ( + jax, + jnp, +) @flax_module @@ -36,3 +41,21 @@ def __setattr__(self, name: str, value: Any) -> None: PairTabAtomicModel.deserialize(value[1].serialize()), ] return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + ) -> dict[str, jnp.ndarray]: + return super().forward_common_atomic( + extended_coord, + extended_atype, + jax.lax.stop_gradient(nlist), + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) diff --git a/deepmd/jax/atomic_model/pairtab_atomic_model.py b/deepmd/jax/atomic_model/pairtab_atomic_model.py index 2401362b63..023f4e886a 100644 --- a/deepmd/jax/atomic_model/pairtab_atomic_model.py +++ b/deepmd/jax/atomic_model/pairtab_atomic_model.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, + Optional, ) from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( @@ -14,6 +15,10 @@ flax_module, to_jax_array, ) +from deepmd.jax.env import ( + jax, + jnp, +) @flax_module @@ -25,3 +30,21 @@ def __setattr__(self, name: str, value: Any) -> None: if value is not None: value = ArrayAPIVariable(value) return super().__setattr__(name, value) + + def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + ) -> dict[str, jnp.ndarray]: + return super().forward_common_atomic( + extended_coord, + extended_atype, + jax.lax.stop_gradient(nlist), + mapping=mapping, + fparam=fparam, + aparam=aparam, + )