Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 26, 2024
1 parent c34d8b4 commit 19e0970
Showing 1 changed file with 40 additions and 57 deletions.
97 changes: 40 additions & 57 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,73 +48,56 @@ def forward_common_atomic(
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
kk_derv_r, kk_derv_c = get_deriv_name(kk)
if vdef.c_differentiable:
size = 1
for ii in vdef.shape:
size *= ii

split_ff = []
split_vv = []
for ss in range(size):

def eval_output(
cc_ext,
extended_atype,
nlist,
mapping,
fparam,
aparam,
*,
_kk=kk,
_ss=ss,
_atom_axis=atom_axis,
):
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext[None, ...],
extended_atype[None, ...],
nlist[None, ...],
mapping=mapping[None, ...] if mapping is not None else None,
fparam=fparam[None, ...] if fparam is not None else None,
aparam=aparam[None, ...] if aparam is not None else None,
)
return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)[_ss]

# extended_coord: [nf, nall, 3]
# ffi: [nf, nall, 3]
ffi = -jax.vmap(jax.grad(eval_output, argnums=0))(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
def eval_output(
cc_ext,
extended_atype,
nlist,
mapping,
fparam,
aparam,
*,
_kk=kk,
_atom_axis=atom_axis,
):
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext[None, ...],
extended_atype[None, ...],
nlist[None, ...],
mapping=mapping[None, ...] if mapping is not None else None,
fparam=fparam[None, ...] if fparam is not None else None,
aparam=aparam[None, ...] if aparam is not None else None,
)
# ffi[..., None]: [nf, nall, 3, 1]
# extended_coord[..., None, :]: [nf, nall, 1, 3]
# aviri: [nf, nall, 3, 3]
aviri = ffi[..., None] @ extended_coord[..., None, :]
# aviri: [nf, nall, 9]
aviri = aviri.reshape(*aviri.shape[:-2], 9)
# ffi: [nf, nall, 1, 3]
ffi = ffi[..., None, :]
split_ff.append(ffi)
# aviri: [nf, nall, 1, 9]
aviri = aviri[..., None, :]
split_vv.append(aviri)
out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape
# extended_force: [nf, nall, def_size, 3]
return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)

# extended_coord: [nf, nall, 3]
# ff: [nf, *def, nall, 3]
ff = -jax.vmap(jax.jacrev(eval_output, argnums=0))(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
)
# extended_force: [nf, nall, *def, 3]
extended_force = jnp.concat(split_ff, axis=-2).reshape(
*out_lead_shape, 3
def_ndim = len(vdef.shape)
extended_force = jnp.transpose(
ff, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
)

model_predict[kk_derv_r] = extended_force
if vdef.c_differentiable:
assert vdef.r_differentiable
# extended_virial: [nf, nall, def_size, 9]
# avr: [nf, *def, nall, 3, 3]
avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord)
# avr: [nf, *def, nall, 9]
avr = jnp.reshape(avr, [*ff.shape[:-1], 9])
# extended_virial: [nf, nall, *def, 9]
extended_virial = jnp.concat(split_vv, axis=-2).reshape(
*out_lead_shape, 9
extended_virial = jnp.transpose(
avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
)

# the correction sums to zero, which does not contribute to global virial
# cannot jit
# if do_atomic_virial:
Expand Down

0 comments on commit 19e0970

Please sign in to comment.