Skip to content

Commit

Permalink
feat(jax): energy model (no grad support) (#4226)
Browse files Browse the repository at this point in the history
Add a JAX energy model without gradient support; how to support gradients
still needs discussion.
The Array API is not supported in this PR because it needs more effort (JAX
provides more APIs than the Array API standard).
This PR also fixes a `skip_tf` bug introduced in #3357: because `skip_tf`
was not declared with `@property`, `xx.skip_tf` was always cast to `True`.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Enhanced `BaseAtomicModel` and `DPAtomicModel` classes with improved
array compatibility and new output definitions.
- Introduced new classes and attributes for better model flexibility and
customization.
- Added `EnergyFittingNet` and `DOSFittingNet` for advanced fitting
capabilities.
- New functions `get_standard_model` and `get_model` for flexible model
creation based on input data.
- Added `BaseDescriptor` and `BaseFitting` classes to streamline
descriptor and fitting processes.
	- Introduced `EnergyModel` class for improved atomic model handling.

- **Bug Fixes**
	- Updated serialization logic for consistency across models.

- **Tests**
- Enhanced testing framework to support JAX operations and added methods
for JAX model evaluation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Oct 23, 2024
1 parent b4701da commit c2515ed
Show file tree
Hide file tree
Showing 20 changed files with 289 additions and 30 deletions.
20 changes: 12 additions & 8 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import math
from typing import (
Optional,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import (
NativeOP,
to_numpy_array,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
Expand Down Expand Up @@ -172,17 +174,18 @@ def forward_common_atomic(
ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual.
"""
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]
if self.pair_excl is not None:
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
# exclude neighbors in the nlist
nlist = np.where(pair_mask == 1, nlist, -1)
nlist = xp.where(pair_mask == 1, nlist, -1)

ext_atom_mask = self.make_atom_mask(extended_atype)
ret_dict = self.forward_atomic(
extended_coord,
np.where(ext_atom_mask, extended_atype, 0),
xp.where(ext_atom_mask, extended_atype, 0),
nlist,
mapping=mapping,
fparam=fparam,
Expand All @@ -191,13 +194,13 @@ def forward_common_atomic(
ret_dict = self.apply_out_stat(ret_dict, atype)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].astype(np.int32)
atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32)
if self.atom_excl is not None:
atom_mask *= self.atom_excl.build_type_exclude_mask(atype)

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
out_shape2 = np.prod(out_shape[2:])
out_shape2 = math.prod(out_shape[2:])
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
* atom_mask[:, :, None]
Expand Down Expand Up @@ -232,14 +235,15 @@ def serialize(self) -> dict:
"rcond": self.rcond,
"preset_out_bias": self.preset_out_bias,
"@variables": {
"out_bias": self.out_bias,
"out_std": self.out_std,
"out_bias": to_numpy_array(self.out_bias),
"out_std": to_numpy_array(self.out_std),
},
}

@classmethod
def deserialize(cls, data: dict) -> "BaseAtomicModel":
data = copy.deepcopy(data)
# do not deep copy Descriptor and Fitting class
data = data.copy()
variables = data.pop("@variables")
obj = cls(**data)
for kk in variables.keys():
Expand Down
10 changes: 8 additions & 2 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,20 @@ def serialize(self) -> dict:
)
return dd

# to be overridden by subclasses
base_descriptor_cls = BaseDescriptor
"""The base descriptor class."""
base_fitting_cls = BaseFitting
"""The base fitting class."""

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 2)
data.pop("@class")
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor"))
fitting_obj = BaseFitting.deserialize(data.pop("fitting"))
descriptor_obj = cls.base_descriptor_cls.deserialize(data.pop("descriptor"))
fitting_obj = cls.base_fitting_cls.deserialize(data.pop("fitting"))
data["descriptor"] = descriptor_obj
data["fitting"] = fitting_obj
obj = super().deserialize(data)
Expand Down
35 changes: 16 additions & 19 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Optional,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.atomic_model.base_atomic_model import (
Expand Down Expand Up @@ -75,7 +76,8 @@ def __init__(
else:
self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs)
self.precision_dict = PRECISION_DICT
self.reverse_precision_dict = RESERVED_PRECISON_DICT
# not supported by flax
# self.reverse_precision_dict = RESERVED_PRECISON_DICT
self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION
self.global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION

Expand Down Expand Up @@ -253,9 +255,7 @@ def input_type_cast(
str,
]:
"""Cast the input data to global float type."""
input_prec = self.reverse_precision_dict[
self.precision_dict[coord.dtype.name]
]
input_prec = RESERVED_PRECISON_DICT[self.precision_dict[coord.dtype.name]]
###
### type checking would not pass jit, convert to coord prec anyway
###
Expand All @@ -264,10 +264,7 @@ def input_type_cast(
for vv in [box, fparam, aparam]
]
box, fparam, aparam = _lst
if (
input_prec
== self.reverse_precision_dict[self.global_np_float_precision]
):
if input_prec == RESERVED_PRECISON_DICT[self.global_np_float_precision]:
return coord, box, fparam, aparam, input_prec
else:
pp = self.global_np_float_precision
Expand All @@ -286,8 +283,7 @@ def output_type_cast(
) -> dict[str, np.ndarray]:
"""Convert the model output to the input prec."""
do_cast = (
input_prec
!= self.reverse_precision_dict[self.global_np_float_precision]
input_prec != RESERVED_PRECISON_DICT[self.global_np_float_precision]
)
pp = self.precision_dict[input_prec]
odef = self.model_output_def()
Expand Down Expand Up @@ -366,17 +362,18 @@ def _format_nlist(
nnei: int,
extra_nlist_sort: bool = False,
):
xp = array_api_compat.array_namespace(extended_coord, nlist)
n_nf, n_nloc, n_nnei = nlist.shape
extended_coord = extended_coord.reshape([n_nf, -1, 3])
nall = extended_coord.shape[1]
rcut = self.get_rcut()

if n_nnei < nnei:
# make a copy before revise
ret = np.concatenate(
ret = xp.concat(
[
nlist,
-1 * np.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype),
-1 * xp.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype),
],
axis=-1,
)
Expand All @@ -385,16 +382,16 @@ def _format_nlist(
n_nf, n_nloc, n_nnei = nlist.shape
# make a copy before revise
m_real_nei = nlist >= 0
ret = np.where(m_real_nei, nlist, 0)
ret = xp.where(m_real_nei, nlist, 0)
coord0 = extended_coord[:, :n_nloc, :]
index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2)
coord1 = np.take_along_axis(extended_coord, index, axis=1)
coord1 = xp.take_along_axis(extended_coord, index, axis=1)
coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3)
rr = np.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1)
rr = np.where(m_real_nei, rr, float("inf"))
rr, ret_mapping = np.sort(rr, axis=-1), np.argsort(rr, axis=-1)
ret = np.take_along_axis(ret, ret_mapping, axis=2)
ret = np.where(rr > rcut, -1, ret)
rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1)
rr = xp.where(m_real_nei, rr, float("inf"))
rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1)
ret = xp.take_along_axis(ret, ret_mapping, axis=2)
ret = xp.where(rr > rcut, -1, ret)
ret = ret[..., :nnei]
# not extra_nlist_sort and n_nnei <= nnei:
elif n_nnei == nnei:
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import (
Expand All @@ -23,6 +24,7 @@ def fit_output_to_model_output(
the model output.
"""
xp = array_api_compat.get_namespace(coord_ext)
model_ret = dict(fit_ret.items())
for kk, vv in fit_ret.items():
vdef = fit_output_def[kk]
Expand All @@ -31,7 +33,7 @@ def fit_output_to_model_output(
if vdef.reducible:
kk_redu = get_reduce_name(kk)
# cast to energy prec before reduction
model_ret[kk_redu] = np.sum(
model_ret[kk_redu] = xp.sum(
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
)
if vdef.r_differentiable:
Expand Down
1 change: 1 addition & 0 deletions deepmd/jax/atomic_model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
18 changes: 18 additions & 0 deletions deepmd/jax/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.common import (
to_jax_array,
)
from deepmd.jax.utils.exclude_mask import (
AtomExcludeMask,
PairExcludeMask,
)


def base_atomic_model_set_attr(name, value):
    """Convert an attribute value to its JAX-side form before it is stored.

    Parameters
    ----------
    name
        Name of the attribute being assigned.
    value
        Value about to be assigned.

    Returns
    -------
    The value to store:

    - ``out_bias`` / ``out_std`` are passed through ``to_jax_array``;
    - a non-``None`` ``pair_excl`` / ``atom_excl`` is rebuilt as the
      ``PairExcludeMask`` / ``AtomExcludeMask`` imported from
      ``deepmd.jax.utils.exclude_mask``, using the same ``ntypes`` and
      ``exclude_types``;
    - anything else is returned unchanged.
    """
    if name in {"out_bias", "out_std"}:
        return to_jax_array(value)
    # Exclusion masks may legitimately be unset; leave ``None`` untouched.
    if value is None:
        return value
    if name == "pair_excl":
        return PairExcludeMask(value.ntypes, value.exclude_types)
    if name == "atom_excl":
        return AtomExcludeMask(value.ntypes, value.exclude_types)
    return value
30 changes: 30 additions & 0 deletions deepmd/jax/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP
from deepmd.jax.atomic_model.base_atomic_model import (
base_atomic_model_set_attr,
)
from deepmd.jax.common import (
flax_module,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.fitting.base_fitting import (
BaseFitting,
)


@flax_module
class DPAtomicModel(DPAtomicModelDP):
    """JAX variant of the dpmodel ``DPAtomicModel``.

    Overrides the ``base_descriptor_cls`` / ``base_fitting_cls`` class
    attributes with the JAX ``BaseDescriptor`` / ``BaseFitting`` and passes
    every attribute assignment through ``base_atomic_model_set_attr``.
    """

    #: The base descriptor class.
    base_descriptor_cls = BaseDescriptor
    #: The base fitting class.
    base_fitting_cls = BaseFitting

    def __setattr__(self, name: str, value: Any) -> None:
        # Convert first, then let the parent class do the actual storing.
        converted = base_atomic_model_set_attr(name, value)
        return super().__setattr__(name, converted)
11 changes: 11 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.descriptor.dpa1 import (
DescrptDPA1,
)
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)

__all__ = [
"DescrptSeA",
"DescrptDPA1",
]
9 changes: 9 additions & 0 deletions deepmd/jax/descriptor/base_descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.descriptor.make_base_descriptor import (
make_base_descriptor,
)
from deepmd.jax.env import (
jnp,
)

# JAX-side descriptor base, produced by the dpmodel factory
# `make_base_descriptor` called with `jnp.ndarray` (presumably the backend's
# array type -- confirm against the factory's signature).
BaseDescriptor = make_base_descriptor(jnp.ndarray)
5 changes: 5 additions & 0 deletions deepmd/jax/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
Expand Down Expand Up @@ -76,6 +79,8 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)


@BaseDescriptor.register("dpa1")
@BaseDescriptor.register("se_atten")
@flax_module
class DescrptDPA1(DescrptDPA1DP):
def __setattr__(self, name: str, value: Any) -> None:
Expand Down
5 changes: 5 additions & 0 deletions deepmd/jax/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
Expand All @@ -16,6 +19,8 @@
)


@BaseDescriptor.register("se_e2_a")
@BaseDescriptor.register("se_a")
@flax_module
class DescrptSeA(DescrptSeADP):
def __setattr__(self, name: str, value: Any) -> None:
Expand Down
9 changes: 9 additions & 0 deletions deepmd/jax/fitting/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.fitting.fitting import (
DOSFittingNet,
EnergyFittingNet,
)

__all__ = [
"EnergyFittingNet",
"DOSFittingNet",
]
9 changes: 9 additions & 0 deletions deepmd/jax/fitting/base_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.fitting.make_base_fitting import (
make_base_fitting,
)
from deepmd.jax.env import (
jnp,
)

# JAX-side fitting base, produced by the dpmodel factory `make_base_fitting`
# called with `jnp.ndarray` (presumably the backend's array type -- confirm
# against the factory's signature).
BaseFitting = make_base_fitting(jnp.ndarray)
5 changes: 5 additions & 0 deletions deepmd/jax/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
flax_module,
to_jax_array,
)
from deepmd.jax.fitting.base_fitting import (
BaseFitting,
)
from deepmd.jax.utils.exclude_mask import (
AtomExcludeMask,
)
Expand All @@ -33,13 +36,15 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
return value


@BaseFitting.register("ener")
@flax_module
class EnergyFittingNet(EnergyFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


@BaseFitting.register("dos")
@flax_module
class DOSFittingNet(DOSFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
Expand Down
6 changes: 6 additions & 0 deletions deepmd/jax/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .ener_model import (
EnergyModel,
)

__all__ = ["EnergyModel"]
6 changes: 6 additions & 0 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.model.base_model import (
make_base_model,
)

# JAX-side model base, produced by the dpmodel factory `make_base_model`
# called with no arguments.
BaseModel = make_base_model()
Loading

0 comments on commit c2515ed

Please sign in to comment.