Skip to content

Commit

Permalink
feat(jax): build nlist in the SavedModel & fix nopbc for StableHLO an…
Browse files Browse the repository at this point in the history
…d SavedModel (deepmodeling#4318)

Per our discussion, use TF to build the neighbor list in the SavedModel
format.

Also, fix a bug when the number of ghost atoms is zero. The
polymorphic_shape needs to be larger than 1, and `nghost == 0` triggered
the error. Previously, I also tried `nall` or `nghost - 1` but none of
them worked. Finally, I export two different functions... So now four
functions are stored in the model: calculate virial or not, x nghost is
zero or not. The tests for nopbc are added.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Enhanced model initialization with additional parameters for improved
functionality.
- Introduced functions for neighbor list management and region
transformations in molecular simulations.
- Added new methods for handling atomic virial calculations in model
predictions.
- New functions for transforming model outputs to accommodate local and
ghost atoms.

- **Bug Fixes**
- Improved error handling in model serialization and evaluation
processes.

- **Tests**
- Added comprehensive unit tests for new functionalities, ensuring
consistent behavior across different scenarios, including tests for
neighbor list construction and region transformations.

- **Chores**
	- Updated testing workflow for better organization and efficiency.
- Modified dependency management and linting configurations in
`pyproject.toml`.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 11, 2024
1 parent 023bb9c commit dcbf607
Show file tree
Hide file tree
Showing 17 changed files with 997 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
env:
NUM_WORKERS: 0
- name: Test TF2 eager mode
run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0
run: pytest --cov=deepmd --cov-append source/tests/consistent/io/test_io.py source/jax2tf_tests --durations=0
env:
NUM_WORKERS: 0
DP_TEST_TF2_ONLY: 1
Expand Down
6 changes: 6 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def __init__(
stablehlo_atomic_virial=model_data["@variables"][
"stablehlo_atomic_virial"
].tobytes(),
stablehlo_no_ghost=model_data["@variables"][
"stablehlo_no_ghost"
].tobytes(),
stablehlo_atomic_virial_no_ghost=model_data["@variables"][
"stablehlo_atomic_virial_no_ghost"
].tobytes(),
model_def_script=model_data["model_def_script"],
**model_data["constants"],
)
Expand Down
3 changes: 3 additions & 0 deletions deepmd/jax/jax2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

if not tf.executing_eagerly():
# TF disallow temporary eager execution
Expand All @@ -9,3 +10,5 @@
"If you are converting a model between different backends, "
"considering converting to the `.dp` format first."
)

tnp.experimental_enable_numpy_behavior()
110 changes: 110 additions & 0 deletions deepmd/jax/jax2tf/make_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
)

import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from deepmd.dpmodel.output_def import (
ModelOutputDef,
)
from deepmd.jax.jax2tf.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)
from deepmd.jax.jax2tf.region import (
normalize_coord,
)
from deepmd.jax.jax2tf.transform_output import (
communicate_extended_output,
)


def model_call_from_call_lower(
*, # enforce keyword-only arguments
call_lower: Callable[
[
tnp.ndarray,
tnp.ndarray,
tnp.ndarray,
tnp.ndarray,
tnp.ndarray,
bool,
],
dict[str, tnp.ndarray],
],
rcut: float,
sel: list[int],
mixed_types: bool,
model_output_def: ModelOutputDef,
coord: tnp.ndarray,
atype: tnp.ndarray,
box: tnp.ndarray,
fparam: tnp.ndarray,
aparam: tnp.ndarray,
do_atomic_virial: bool = False,
):
"""Return model prediction from lower interface.
Parameters
----------
coord
The coordinates of the atoms.
shape: nf x (nloc x 3)
atype
The type of atoms. shape: nf x nloc
box
The simulation box. shape: nf x 9
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.
Returns
-------
ret_dict
The result dict of type dict[str,tnp.ndarray].
The keys are defined by the `ModelOutputDef`.
"""
atype_shape = tf.shape(atype)
nframes, nloc = atype_shape[0], atype_shape[1]
cc, bb, fp, ap = coord, box, fparam, aparam
del coord, box, fparam, aparam
if tf.shape(bb)[-1] != 0:
coord_normalized = normalize_coord(
cc.reshape(nframes, nloc, 3),
bb.reshape(nframes, 3, 3),
)
else:
coord_normalized = cc
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, bb, rcut
)
nlist = build_neighbor_list(
extended_coord,
extended_atype,
nloc,
rcut,
sel,
distinguish_types=not mixed_types,
)
extended_coord = extended_coord.reshape(nframes, -1, 3)
model_predict_lower = call_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fp,
aparam=ap,
)
model_predict = communicate_extended_output(
model_predict_lower,
model_output_def,
mapping,
do_atomic_virial=do_atomic_virial,
)
return model_predict
217 changes: 217 additions & 0 deletions deepmd/jax/jax2tf/nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Union,
)

import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from .region import (
to_face_distance,
)


## translated from torch implementation by chatgpt
def build_neighbor_list(
coord: tnp.ndarray,
atype: tnp.ndarray,
nloc: int,
rcut: float,
sel: Union[int, list[int]],
distinguish_types: bool = True,
) -> tnp.ndarray:
"""Build neighbor list for a single frame. keeps nsel neighbors.
Parameters
----------
coord : tnp.ndarray
exptended coordinates of shape [batch_size, nall x 3]
atype : tnp.ndarray
extended atomic types of shape [batch_size, nall]
type < 0 the atom is treat as virtual atoms.
nloc : int
number of local atoms.
rcut : float
cut-off radius
sel : int or list[int]
maximal number of neighbors (of each type).
if distinguish_types==True, nsel should be list and
the length of nsel should be equal to number of
types.
distinguish_types : bool
distinguish different types.
Returns
-------
neighbor_list : tnp.ndarray
Neighbor list of shape [batch_size, nloc, nsel], the neighbors
are stored in an ascending order. If the number of
neighbors is less than nsel, the positions are masked
with -1. The neighbor list of an atom looks like
|------ nsel ------|
xx xx xx xx -1 -1 -1
if distinguish_types==True and we have two types
|---- nsel[0] -----| |---- nsel[1] -----|
xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1
For virtual atoms all neighboring positions are filled with -1.
"""
batch_size = tf.shape(coord)[0]
coord = tnp.reshape(coord, (batch_size, -1))
nall = tf.shape(coord)[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
if tf.size(coord) > 0:
xmax = tnp.max(coord) + 2.0 * rcut
else:
xmax = tf.cast(2.0 * rcut, coord.dtype)
# nf x nall
is_vir = atype < 0
coord1 = tnp.where(
is_vir[:, :, None], xmax, tnp.reshape(coord, (batch_size, nall, 3))
)
coord1 = tnp.reshape(coord1, (batch_size, nall * 3))
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
coord0 = coord1[:, : nloc * 3]
diff = (
tnp.reshape(coord1, [batch_size, -1, 3])[:, None, :, :]
- tnp.reshape(coord0, [batch_size, -1, 3])[:, :, None, :]
)
rr = tf.linalg.norm(diff, axis=-1)
# if central atom has two zero distances, sorting sometimes can not exclude itself
rr -= tf.eye(nloc, nall, dtype=diff.dtype)[tnp.newaxis, :, :]
nlist = tnp.argsort(rr, axis=-1)
rr = tnp.sort(rr, axis=-1)
rr = rr[:, :, 1:]
nlist = nlist[:, :, 1:]
nnei = tf.shape(rr)[2]
if nsel <= nnei:
rr = rr[:, :, :nsel]
nlist = nlist[:, :, :nsel]
else:
rr = tnp.concatenate(
[rr, tnp.ones([batch_size, nloc, nsel - nnei], dtype=rr.dtype) + rcut],
axis=-1,
)
nlist = tnp.concatenate(
[nlist, tnp.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)],
axis=-1,
)
nlist = tnp.where(
tnp.logical_or((rr > rcut), is_vir[:, :nloc, None]),
tnp.full_like(nlist, -1),
nlist,
)

if distinguish_types:
return nlist_distinguish_types(nlist, atype, sel)
else:
return nlist


def nlist_distinguish_types(
nlist: tnp.ndarray,
atype: tnp.ndarray,
sel: list[int],
):
"""Given a nlist that does not distinguish atom types, return a nlist that
distinguish atom types.
"""
nloc = tf.shape(nlist)[1]
ret_nlist = []
tmp_atype = tnp.tile(atype[:, None, :], (1, nloc, 1))
mask = nlist == -1
tnlist_0 = tnp.where(mask, tnp.zeros_like(nlist), nlist)
tnlist = tnp.take_along_axis(tmp_atype, tnlist_0, axis=2)
tnlist = tnp.where(mask, tnp.full_like(tnlist, -1), tnlist)
for ii, ss in enumerate(sel):
pick_mask = tf.cast(tnlist == ii, tnp.int32)
sorted_indices = tnp.argsort(-pick_mask, kind="stable", axis=-1)
pick_mask_sorted = -tnp.sort(-pick_mask, axis=-1)
inlist = tnp.take_along_axis(nlist, sorted_indices, axis=2)
inlist = tnp.where(
~tf.cast(pick_mask_sorted, tf.bool), tnp.full_like(inlist, -1), inlist
)
ret_nlist.append(inlist[..., :ss])
ret = tf.concat(ret_nlist, axis=-1)
return ret


def tf_outer(a, b):
return tf.einsum("i,j->ij", a, b)


## translated from torch implementation by chatgpt
def extend_coord_with_ghosts(
coord: tnp.ndarray,
atype: tnp.ndarray,
cell: tnp.ndarray,
rcut: float,
):
"""Extend the coordinates of the atoms by appending peridoc images.
The number of images is large enough to ensure all the neighbors
within rcut are appended.
Parameters
----------
coord : tnp.ndarray
original coordinates of shape [-1, nloc*3].
atype : tnp.ndarray
atom type of shape [-1, nloc].
cell : tnp.ndarray
simulation cell tensor of shape [-1, 9].
rcut : float
the cutoff radius
Returns
-------
extended_coord: tnp.ndarray
extended coordinates of shape [-1, nall*3].
extended_atype: tnp.ndarray
extended atom type of shape [-1, nall].
index_mapping: tnp.ndarray
mapping extended index to the local index
"""
atype_shape = tf.shape(atype)
nf, nloc = atype_shape[0], atype_shape[1]
# int64 for index
aidx = tf.range(nloc, dtype=tnp.int64)
aidx = tnp.tile(aidx[tnp.newaxis, :], (nf, 1))
if tf.shape(cell)[-1] == 0:
nall = nloc
extend_coord = coord
extend_atype = atype
extend_aidx = aidx
else:
coord = tnp.reshape(coord, (nf, nloc, 3))
cell = tnp.reshape(cell, (nf, 3, 3))
to_face = to_face_distance(cell)
nbuff = tf.cast(tnp.ceil(rcut / to_face), tnp.int64)
nbuff = tnp.max(nbuff, axis=0)
xi = tf.range(-nbuff[0], nbuff[0] + 1, 1, dtype=tnp.int64)
yi = tf.range(-nbuff[1], nbuff[1] + 1, 1, dtype=tnp.int64)
zi = tf.range(-nbuff[2], nbuff[2] + 1, 1, dtype=tnp.int64)
xyz = tf_outer(xi, tnp.asarray([1, 0, 0]))[:, tnp.newaxis, tnp.newaxis, :]
xyz = xyz + tf_outer(yi, tnp.asarray([0, 1, 0]))[tnp.newaxis, :, tnp.newaxis, :]
xyz = xyz + tf_outer(zi, tnp.asarray([0, 0, 1]))[tnp.newaxis, tnp.newaxis, :, :]
xyz = tnp.reshape(xyz, (-1, 3))
xyz = tf.cast(xyz, coord.dtype)
shift_idx = tnp.take(xyz, tnp.argsort(tf.linalg.norm(xyz, axis=1)), axis=0)
ns = tf.shape(shift_idx)[0]
nall = ns * nloc
shift_vec = tnp.einsum("sd,fdk->fsk", shift_idx, cell)
# shift_vec = tnp.tensordot(shift_idx, cell, axes=([1], [1]))
# shift_vec = tnp.transpose(shift_vec, (1, 0, 2))
extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :]
extend_atype = tnp.tile(atype[:, :, tnp.newaxis], (1, ns, 1))
extend_aidx = tnp.tile(aidx[:, :, tnp.newaxis], (1, ns, 1))

return (
tnp.reshape(extend_coord, (nf, nall * 3)),
tnp.reshape(extend_atype, (nf, nall)),
tnp.reshape(extend_aidx, (nf, nall)),
)
Loading

0 comments on commit dcbf607

Please sign in to comment.