forked from deepmodeling/deepmd-kit
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(jax): force & virial (deepmodeling#4251)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced new methods `forward_common_atomic` in multiple classes to enhance atomic model predictions and derivative calculations. - Added a new function `get_leading_dims` for better handling of output dimensions. - Added a new function `scatter_sum` for performing reduction operations on tensors. - Updated test methods to include flexible handling of results with the new `SKIP_FLAG` variable. - **Bug Fixes** - Improved numerical stability in calculations by ensuring small values are handled appropriately. - **Tests** - Expanded test outputs to include additional data like forces and virials for more comprehensive testing. - Enhanced backend handling in tests to accommodate new return values based on backend availability. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]>
- Loading branch information
Showing
10 changed files
with
284 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,107 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from typing import ( | ||
Optional, | ||
) | ||
|
||
from deepmd.dpmodel.model.base_model import ( | ||
make_base_model, | ||
) | ||
from deepmd.dpmodel.output_def import ( | ||
get_deriv_name, | ||
get_reduce_name, | ||
) | ||
from deepmd.jax.env import ( | ||
jax, | ||
jnp, | ||
) | ||
|
||
BaseModel = make_base_model() | ||
|
||
|
||
def forward_common_atomic( | ||
self, | ||
extended_coord: jnp.ndarray, | ||
extended_atype: jnp.ndarray, | ||
nlist: jnp.ndarray, | ||
mapping: Optional[jnp.ndarray] = None, | ||
fparam: Optional[jnp.ndarray] = None, | ||
aparam: Optional[jnp.ndarray] = None, | ||
do_atomic_virial: bool = False, | ||
): | ||
atomic_ret = self.atomic_model.forward_common_atomic( | ||
extended_coord, | ||
extended_atype, | ||
nlist, | ||
mapping=mapping, | ||
fparam=fparam, | ||
aparam=aparam, | ||
) | ||
atomic_output_def = self.atomic_output_def() | ||
model_predict = {} | ||
for kk, vv in atomic_ret.items(): | ||
model_predict[kk] = vv | ||
vdef = atomic_output_def[kk] | ||
shap = vdef.shape | ||
atom_axis = -(len(shap) + 1) | ||
if vdef.reducible: | ||
kk_redu = get_reduce_name(kk) | ||
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis) | ||
kk_derv_r, kk_derv_c = get_deriv_name(kk) | ||
if vdef.c_differentiable: | ||
|
||
def eval_output( | ||
cc_ext, | ||
extended_atype, | ||
nlist, | ||
mapping, | ||
fparam, | ||
aparam, | ||
*, | ||
_kk=kk, | ||
_atom_axis=atom_axis, | ||
): | ||
atomic_ret = self.atomic_model.forward_common_atomic( | ||
cc_ext[None, ...], | ||
extended_atype[None, ...], | ||
nlist[None, ...], | ||
mapping=mapping[None, ...] if mapping is not None else None, | ||
fparam=fparam[None, ...] if fparam is not None else None, | ||
aparam=aparam[None, ...] if aparam is not None else None, | ||
) | ||
return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis) | ||
|
||
# extended_coord: [nf, nall, 3] | ||
# ff: [nf, *def, nall, 3] | ||
ff = -jax.vmap(jax.jacrev(eval_output, argnums=0))( | ||
extended_coord, | ||
extended_atype, | ||
nlist, | ||
mapping, | ||
fparam, | ||
aparam, | ||
) | ||
# extended_force: [nf, nall, *def, 3] | ||
def_ndim = len(vdef.shape) | ||
extended_force = jnp.transpose( | ||
ff, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2] | ||
) | ||
|
||
model_predict[kk_derv_r] = extended_force | ||
if vdef.c_differentiable: | ||
assert vdef.r_differentiable | ||
# avr: [nf, *def, nall, 3, 3] | ||
avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord) | ||
# avr: [nf, *def, nall, 9] | ||
avr = jnp.reshape(avr, [*ff.shape[:-1], 9]) | ||
# extended_virial: [nf, nall, *def, 9] | ||
extended_virial = jnp.transpose( | ||
avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2] | ||
) | ||
|
||
# the correction sums to zero, which does not contribute to global virial | ||
# cannot jit | ||
# if do_atomic_virial: | ||
# raise NotImplementedError("Atomic virial is not implemented yet.") | ||
# to [...,3,3] -> [...,9] | ||
model_predict[kk_derv_c] = extended_virial | ||
return model_predict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.