Skip to content

Commit

Permalink
feat(jax): force & virial (#4251)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced new methods `forward_common_atomic` in multiple classes to
enhance atomic model predictions and derivative calculations.
- Added a new function `get_leading_dims` for better handling of output
dimensions.
- Added a new function `scatter_sum` for performing reduction operations
on tensors.
- Updated test methods to include flexible handling of results with the
new `SKIP_FLAG` variable.

- **Bug Fixes**
- Improved numerical stability in calculations by ensuring small values
are handled appropriately.

- **Tests**
- Expanded test outputs to include additional data like forces and
virials for more comprehensive testing.
- Enhanced backend handling in tests to accommodate new return values
based on backend availability.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Oct 29, 2024
1 parent dd36e6c commit 159361d
Show file tree
Hide file tree
Showing 10 changed files with 284 additions and 17 deletions.
30 changes: 25 additions & 5 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,22 +222,42 @@ def call_lower(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
atomic_ret = self.atomic_model.forward_common_atomic(
model_predict = self.forward_common_atomic(
cc_ext,
extended_atype,
nlist,
mapping=mapping,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

def forward_common_atomic(
self,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlist: np.ndarray,
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
do_atomic_virial: bool = False,
):
atomic_ret = self.atomic_model.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
model_predict = fit_output_to_model_output(
return fit_output_to_model_output(
atomic_ret,
self.atomic_output_def(),
cc_ext,
extended_coord,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

forward_lower = call_lower

Expand Down
84 changes: 78 additions & 6 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
get_deriv_name,
get_reduce_name,
)
Expand Down Expand Up @@ -47,6 +48,28 @@ def fit_output_to_model_output(
return model_ret


def get_leading_dims(
vv: np.ndarray,
vdef: OutputVariableDef,
):
"""Get the dimensions of nf x nloc.
Parameters
----------
vv : np.ndarray
The input array from which to compute the leading dimensions.
vdef : OutputVariableDef
The output variable definition containing the shape to exclude from `vv`.
Returns
-------
list
A list of leading dimensions of `vv`, excluding the last `len(vdef.shape)` dimensions.
"""
vshape = vv.shape
return list(vshape[: (len(vshape) - len(vdef.shape))])


def communicate_extended_output(
model_ret: dict[str, np.ndarray],
model_output_def: ModelOutputDef,
Expand All @@ -57,6 +80,7 @@ def communicate_extended_output(
local and ghost (extended) atoms to local atoms.
"""
xp = array_api_compat.get_namespace(mapping)
new_ret = {}
for kk in model_output_def.keys_outp():
vv = model_ret[kk]
Expand All @@ -65,15 +89,63 @@ def communicate_extended_output(
if vdef.reducible:
kk_redu = get_reduce_name(kk)
new_ret[kk_redu] = model_ret[kk_redu]
kk_derv_r, kk_derv_c = get_deriv_name(kk)
mldims = list(mapping.shape)
vldims = get_leading_dims(vv, vdef)
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
# name holders
new_ret[kk_derv_r] = None
if model_ret[kk_derv_r] is not None:
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
# jax only
if array_api_compat.is_jax_array(force):
from deepmd.jax.common import (
scatter_sum,
)

force = scatter_sum(
force,
1,
mapping,
model_ret[kk_derv_r],
)
else:
raise NotImplementedError("Only JAX arrays are supported.")
new_ret[kk_derv_r] = force
else:
# name holders
new_ret[kk_derv_r] = None
if vdef.c_differentiable:
assert vdef.r_differentiable
kk_derv_r, kk_derv_c = get_deriv_name(kk)
new_ret[kk_derv_c] = None
new_ret[kk_derv_c + "_redu"] = None
if model_ret[kk_derv_c] is not None:
derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005
mapping = xp.tile(
mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
)
virial = xp.zeros(
vldims + derv_c_ext_dims,
dtype=vv.dtype,
)
# jax only
if array_api_compat.is_jax_array(virial):
from deepmd.jax.common import (
scatter_sum,
)

virial = scatter_sum(
virial,
1,
mapping,
model_ret[kk_derv_c],
)
else:
raise NotImplementedError("Only JAX arrays are supported.")
new_ret[kk_derv_c] = virial
new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
else:
new_ret[kk_derv_c] = None
new_ret[kk_derv_c + "_redu"] = None
if not do_atomic_virial:
# pop atomic virial, because it is not correctly calculated.
new_ret.pop(kk_derv_c)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def _make_env_mat(
# nf x nloc x nnei x 3
diff = coord_r - coord_l
# nf x nloc x nnei
length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
# the grad of JAX vector_norm is NaN at x=0
diff_ = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff)
length = xp.linalg.vector_norm(diff_, axis=-1, keepdims=True)
# for index 0 nloc atom
length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype)
t0 = 1 / (length + protection)
Expand Down
10 changes: 10 additions & 0 deletions deepmd/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,13 @@ def __dlpack__(self, *args, **kwargs):

def __dlpack_device__(self, *args, **kwargs):
return self.value.__dlpack_device__(*args, **kwargs)


def scatter_sum(input, dim, index: jnp.ndarray, src: jnp.ndarray) -> jnp.ndarray:
"""Reduces all values from the src tensor to the indices specified in the index tensor."""
idx = jnp.arange(input.size, dtype=jnp.int64).reshape(input.shape)
new_idx = jnp.take_along_axis(idx, index, axis=dim).ravel()
shape = input.shape
input = input.ravel()
input = input.at[new_idx].add(src.ravel())
return input.reshape(shape)
1 change: 1 addition & 0 deletions deepmd/jax/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)

jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_debug_nans", True)

__all__ = [
"jax",
Expand Down
101 changes: 101 additions & 0 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,107 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)

from deepmd.dpmodel.model.base_model import (
make_base_model,
)
from deepmd.dpmodel.output_def import (
get_deriv_name,
get_reduce_name,
)
from deepmd.jax.env import (
jax,
jnp,
)

BaseModel = make_base_model()


def forward_common_atomic(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
mapping: Optional[jnp.ndarray] = None,
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
do_atomic_virial: bool = False,
):
atomic_ret = self.atomic_model.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
atomic_output_def = self.atomic_output_def()
model_predict = {}
for kk, vv in atomic_ret.items():
model_predict[kk] = vv
vdef = atomic_output_def[kk]
shap = vdef.shape
atom_axis = -(len(shap) + 1)
if vdef.reducible:
kk_redu = get_reduce_name(kk)
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
kk_derv_r, kk_derv_c = get_deriv_name(kk)
if vdef.c_differentiable:

def eval_output(
cc_ext,
extended_atype,
nlist,
mapping,
fparam,
aparam,
*,
_kk=kk,
_atom_axis=atom_axis,
):
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext[None, ...],
extended_atype[None, ...],
nlist[None, ...],
mapping=mapping[None, ...] if mapping is not None else None,
fparam=fparam[None, ...] if fparam is not None else None,
aparam=aparam[None, ...] if aparam is not None else None,
)
return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)

# extended_coord: [nf, nall, 3]
# ff: [nf, *def, nall, 3]
ff = -jax.vmap(jax.jacrev(eval_output, argnums=0))(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
)
# extended_force: [nf, nall, *def, 3]
def_ndim = len(vdef.shape)
extended_force = jnp.transpose(
ff, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
)

model_predict[kk_derv_r] = extended_force
if vdef.c_differentiable:
assert vdef.r_differentiable
# avr: [nf, *def, nall, 3, 3]
avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord)
# avr: [nf, *def, nall, 9]
avr = jnp.reshape(avr, [*ff.shape[:-1], 9])
# extended_virial: [nf, nall, *def, 9]
extended_virial = jnp.transpose(
avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
)

# the correction sums to zero, which does not contribute to global virial
# cannot jit
# if do_atomic_virial:
# raise NotImplementedError("Atomic virial is not implemented yet.")
# to [...,3,3] -> [...,9]
model_predict[kk_derv_c] = extended_virial
return model_predict
26 changes: 26 additions & 0 deletions deepmd/jax/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Optional,
)

from deepmd.dpmodel.model import EnergyModel as EnergyModelDP
Expand All @@ -10,8 +11,12 @@
from deepmd.jax.common import (
flax_module,
)
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.model.base_model import (
BaseModel,
forward_common_atomic,
)


Expand All @@ -22,3 +27,24 @@ def __setattr__(self, name: str, value: Any) -> None:
if name == "atomic_model":
value = DPAtomicModel.deserialize(value.serialize())
return super().__setattr__(name, value)

def forward_common_atomic(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
mapping: Optional[jnp.ndarray] = None,
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
do_atomic_virial: bool = False,
):
return forward_common_atomic(
self,
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
4 changes: 4 additions & 0 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
"INSTALLED_ARRAY_API_STRICT",
]

SKIP_FLAG = object()


class CommonTest(ABC):
data: ClassVar[dict]
Expand Down Expand Up @@ -362,6 +364,8 @@ def test_dp_consistent_with_ref(self):
data2 = dp_obj.serialize()
np.testing.assert_equal(data1, data2)
for rr1, rr2 in zip(ret1, ret2):
if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG:
continue
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
{},
suffix=suffix,
)
return [ret["energy"], ret["atom_ener"]], {
return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], {
t_coord: coords,
t_type: atype,
t_natoms: natoms,
Expand Down
Loading

0 comments on commit 159361d

Please sign in to comment.