Skip to content

Commit

Permalink
fix(jax): fix several serialization and jit issues
Browse files Browse the repository at this point in the history
- `deepmd/jax/descriptor/__init__.py` imports SeT and DPA-2 to let them found by the plugin;
- `deepmd/dpmodel/descriptor/dpa1.py` fixes the jit issue regarding to the shape generated by `jnp.prod`. The shape should be static by using `math.prod`.
- `deepmd/jax/model/ener_model.py` and `deepmd/jax/model/dp_zbl_model.py` stop the graident of coordinates when rebuilding the neighbor list. The graient of sort causes an error due to jax-ml/jax#24730.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 5, 2024
1 parent dabedd2 commit 54dc410
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 2 deletions.
5 changes: 3 additions & 2 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import math
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -852,7 +853,7 @@ def cal_g(
):
xp = array_api_compat.array_namespace(ss)
nfnl, nnei = ss.shape[0:2]
shape2 = xp.prod(xp.asarray(ss.shape[2:]))
shape2 = math.prod(ss.shape[2:])
ss = xp.reshape(ss, (nfnl, nnei, shape2))
# nfnl x nnei x ng
gg = self.embeddings[embedding_idx].call(ss)
Expand All @@ -866,7 +867,7 @@ def cal_g_strip(
assert self.embeddings_strip is not None
xp = array_api_compat.array_namespace(ss)
nfnl, nnei = ss.shape[0:2]
shape2 = xp.prod(xp.asarray(ss.shape[2:]))
shape2 = math.prod(ss.shape[2:])
ss = xp.reshape(ss, (nfnl, nnei, shape2))
# nfnl x nnei x ng
gg = self.embeddings_strip[embedding_idx].call(ss)
Expand Down
8 changes: 8 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from deepmd.jax.descriptor.dpa1 import (
DescrptDPA1,
)
from deepmd.jax.descriptor.dpa2 import (
DescrptDPA2,
)
from deepmd.jax.descriptor.hybrid import (
DescrptHybrid,
)
Expand All @@ -14,11 +17,16 @@
from deepmd.jax.descriptor.se_t import (
DescrptSeT,
)
from deepmd.jax.descriptor.se_t_tebd import (
DescrptSeTTebd,
)

__all__ = [
"DescrptSeA",
"DescrptSeR",
"DescrptSeT",
"DescrptSeTTebd",
"DescrptDPA1",
"DescrptDPA2",
"DescrptHybrid",
]
16 changes: 16 additions & 0 deletions deepmd/jax/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
flax_module,
)
from deepmd.jax.env import (
jax,
jnp,
)
from deepmd.jax.model.base_model import (
Expand Down Expand Up @@ -48,3 +49,18 @@ def forward_common_atomic(
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

def format_nlist(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
extra_nlist_sort: bool = False,
):
return DPZBLModelDP.format_nlist(
self,
jax.lax.stop_gradient(extended_coord),
extended_atype,
nlist,
extra_nlist_sort=extra_nlist_sort,
)
16 changes: 16 additions & 0 deletions deepmd/jax/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
flax_module,
)
from deepmd.jax.env import (
jax,
jnp,
)
from deepmd.jax.model.base_model import (
Expand Down Expand Up @@ -48,3 +49,18 @@ def forward_common_atomic(
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

def format_nlist(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
extra_nlist_sort: bool = False,
):
return EnergyModelDP.format_nlist(
self,
jax.lax.stop_gradient(extended_coord),
extended_atype,
nlist,
extra_nlist_sort=extra_nlist_sort,
)

0 comments on commit 54dc410

Please sign in to comment.