Skip to content

Commit

Permalink
feat(jax): freeze to StableXLO & DeepEval (#4256)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced support for `.hlo` file extensions in model loading and
saving functionalities.
- Added a `DeepEval` class for enhanced deep learning model evaluation
in molecular simulations.
- Implemented a new `HLO` class for managing model predictions within a
deep learning framework.

- **Bug Fixes**
- Improved handling of suffixes and backend names in test cases for
better consistency.

- **Documentation**
	- Added SPDX license identifier to relevant files.

- **Chores**
	- Refactored internal methods to streamline model prediction processes.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and coderabbitai[bot] authored Oct 30, 2024
1 parent 159361d commit d165fee
Show file tree
Hide file tree
Showing 10 changed files with 875 additions and 41 deletions.
10 changes: 7 additions & 3 deletions deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ class JAXBackend(Backend):
features: ClassVar[Backend.Feature] = (
Backend.Feature.IO
| Backend.Feature.ENTRY_POINT
# | Backend.Feature.DEEP_EVAL
| Backend.Feature.DEEP_EVAL
| Backend.Feature.NEIGHBOR_STAT
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".jax"]
suffixes: ClassVar[list[str]] = [".hlo", ".jax"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
Expand Down Expand Up @@ -71,7 +71,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]:
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
raise NotImplementedError
from deepmd.jax.infer.deep_eval import (
DeepEval,
)

return DeepEval

@property
def neighbor_stat(self) -> type["NeighborStat"]:
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def call(
coord_ext, atype_ext, nlist, self.davg, self.dstd
)
nf, nloc, nnei, _ = rr.shape
sec = xp.asarray(self.sel_cumsum)
sec = self.sel_cumsum

ng = self.neuron[-1]
gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype)
Expand Down
130 changes: 99 additions & 31 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Optional,
)

Expand Down Expand Up @@ -39,6 +40,95 @@
)


def model_call_from_call_lower(
*, # enforce keyword-only arguments
call_lower: Callable[
[
np.ndarray,
np.ndarray,
np.ndarray,
Optional[np.ndarray],
Optional[np.ndarray],
bool,
],
dict[str, np.ndarray],
],
rcut: float,
sel: list[int],
mixed_types: bool,
model_output_def: ModelOutputDef,
coord: np.ndarray,
atype: np.ndarray,
box: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
do_atomic_virial: bool = False,
):
"""Return model prediction from lower interface.
Parameters
----------
coord
The coordinates of the atoms.
shape: nf x (nloc x 3)
atype
The type of atoms. shape: nf x nloc
box
The simulation box. shape: nf x 9
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.
Returns
-------
ret_dict
The result dict of type dict[str,np.ndarray].
The keys are defined by the `ModelOutputDef`.
"""
nframes, nloc = atype.shape[:2]
cc, bb, fp, ap = coord, box, fparam, aparam
del coord, box, fparam, aparam
if bb is not None:
coord_normalized = normalize_coord(
cc.reshape(nframes, nloc, 3),
bb.reshape(nframes, 3, 3),
)
else:
coord_normalized = cc.copy()
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, bb, rcut
)
nlist = build_neighbor_list(
extended_coord,
extended_atype,
nloc,
rcut,
sel,
distinguish_types=not mixed_types,
)
extended_coord = extended_coord.reshape(nframes, -1, 3)
model_predict_lower = call_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
model_predict = communicate_extended_output(
model_predict_lower,
model_output_def,
mapping,
do_atomic_virial=do_atomic_virial,
)
return model_predict


def make_model(T_AtomicModel: type[BaseAtomicModel]):
"""Make a model as a derived class of an atomic model.
Expand Down Expand Up @@ -130,45 +220,23 @@ def call(
The keys are defined by the `ModelOutputDef`.
"""
nframes, nloc = atype.shape[:2]
cc, bb, fp, ap, input_prec = self.input_type_cast(
coord, box=box, fparam=fparam, aparam=aparam
)
del coord, box, fparam, aparam
if bb is not None:
coord_normalized = normalize_coord(
cc.reshape(nframes, nloc, 3),
bb.reshape(nframes, 3, 3),
)
else:
coord_normalized = cc.copy()
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, bb, self.get_rcut()
)
nlist = build_neighbor_list(
extended_coord,
extended_atype,
nloc,
self.get_rcut(),
self.get_sel(),
distinguish_types=not self.mixed_types(),
)
extended_coord = extended_coord.reshape(nframes, -1, 3)
model_predict_lower = self.call_lower(
extended_coord,
extended_atype,
nlist,
mapping,
model_predict = model_call_from_call_lower(
call_lower=self.call_lower,
rcut=self.get_rcut(),
sel=self.get_sel(),
mixed_types=self.mixed_types(),
model_output_def=self.model_output_def(),
coord=cc,
atype=atype,
box=bb,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
model_predict = communicate_extended_output(
model_predict_lower,
self.model_output_def(),
mapping,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def save_dp_model(filename: str, model_dict: dict) -> None:
# use UTC+0 time
"time": str(datetime.datetime.now(tz=datetime.timezone.utc)),
}
if filename_extension == ".dp":
if filename_extension in (".dp", ".hlo"):
variable_counter = Counter()
with h5py.File(filename, "w") as f:
model_dict = traverse_model_dict(
Expand Down Expand Up @@ -141,7 +141,7 @@ def load_dp_model(filename: str) -> dict:
The loaded model dict, including meta information.
"""
filename_extension = Path(filename).suffix
if filename_extension == ".dp":
if filename_extension in {".dp", ".hlo"}:
with h5py.File(filename, "r") as f:
model_dict = json.loads(f.attrs["json"])
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
Expand Down
2 changes: 2 additions & 0 deletions deepmd/jax/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flax import (
nnx,
)
from jax import export as jax_export

jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_debug_nans", True)
Expand All @@ -16,4 +17,5 @@
"jax",
"jnp",
"nnx",
"jax_export",
]
1 change: 1 addition & 0 deletions deepmd/jax/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
Loading

0 comments on commit d165fee

Please sign in to comment.