From 0cee8872dd3f99d0c3b438bfba09c9cca4917e4a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 09:05:33 -0400 Subject: [PATCH 01/10] fix(lmp): update print_summary (#4271) 1. Remove out-of-date float prec; 2. Include PyTorch libraries and include directories. ## Summary by CodeRabbit - **New Features** - Updated output messages for build information to enhance clarity, transitioning from TensorFlow-specific references to backend-oriented configurations. - **Bug Fixes** - Improved handling of backend include directories and library paths for better compatibility. - **Documentation** - Enhanced clarity in build information outputs related to backend configurations. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- source/lmp/deepmd_version.h.in | 8 ++++---- source/lmp/pair_deepmd.cpp | 6 ++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/source/lmp/deepmd_version.h.in b/source/lmp/deepmd_version.h.in index 4b99bc7c33..0b74491778 100644 --- a/source/lmp/deepmd_version.h.in +++ b/source/lmp/deepmd_version.h.in @@ -3,8 +3,8 @@ #define GIT_BRANCH @GIT_BRANCH@ #define GIT_DATE @GIT_DATE@ #define DEEPMD_ROOT @CMAKE_INSTALL_PREFIX@ -#define TensorFlow_INCLUDE_DIRS @TensorFlow_INCLUDE_DIRS@ -#define TensorFlow_LIBRARY @TensorFlow_LIBRARY@ +#define BACKEND_INCLUDE_DIRS @BACKEND_INCLUDE_DIRS@ +#define BACKEND_LIBRARY_PATH @BACKEND_LIBRARY_PATH@ #define DPMD_CVT_STR(...) #__VA_ARGS__ #define DPMD_CVT_ASSTR(X) DPMD_CVT_STR(X) #define STR_GIT_SUMM DPMD_CVT_ASSTR(GIT_SUMM) @@ -13,5 +13,5 @@ #define STR_GIT_DATE DPMD_CVT_ASSTR(GIT_DATE) #define STR_FLOAT_PREC DPMD_CVT_ASSTR(FLOAT_PREC) #define STR_DEEPMD_ROOT DPMD_CVT_ASSTR(DEEPMD_ROOT) -#define STR_TensorFlow_INCLUDE_DIRS DPMD_CVT_ASSTR(TensorFlow_INCLUDE_DIRS) -#define STR_TensorFlow_LIBRARY DPMD_CVT_ASSTR(TensorFlow_LIBRARY) +#define STR_BACKEND_INCLUDE_DIRS DPMD_CVT_ASSTR(BACKEND_INCLUDE_DIRS) +#define STR_BACKEND_LIBRARY_PATH DPMD_CVT_ASSTR(BACKEND_LIBRARY_PATH) diff --git a/source/lmp/pair_deepmd.cpp b/source/lmp/pair_deepmd.cpp index 09d97fe460..d741814aa5 100644 --- a/source/lmp/pair_deepmd.cpp +++ b/source/lmp/pair_deepmd.cpp @@ -437,10 +437,8 @@ void PairDeepMD::print_summary(const string pre) const { cout << pre << "source branch: " << STR_GIT_BRANCH << endl; cout << pre << "source commit: " << STR_GIT_HASH << endl; cout << pre << "source commit at: " << STR_GIT_DATE << endl; - cout << pre << "build float prec: " << STR_FLOAT_PREC << endl; - cout << pre << "build with tf inc: " << STR_TensorFlow_INCLUDE_DIRS - << endl; - cout << pre << "build with tf lib: " << STR_TensorFlow_LIBRARY << endl; + cout << pre << "build with inc: " << STR_BACKEND_INCLUDE_DIRS << endl; + cout << pre << "build with lib: " << STR_BACKEND_LIBRARY_PATH << endl; std::cout.rdbuf(sbuf); utils::logmesg(lmp, buffer.str()); From 737f7c8bb77a1a32f76f1f2d72d7099a95db2fc7 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 09:11:43 -0400 Subject: [PATCH 02/10] feat(jax/array-api): hybrid descriptor (#4275) ## Summary by CodeRabbit - **New Features** - Introduced support for the JAX backend in the hybrid descriptor framework. - Added a new `DescrptHybrid` class with specialized attribute handling. - Enhanced testing framework to support additional backends, including JAX and strict array API. - **Bug Fixes** - Improved attribute handling in multiple descriptor classes to ensure proper deserialization and registration. - **Documentation** - Updated documentation to reflect the addition of JAX as a supported backend for hybrid descriptors. --------- Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/hybrid.py | 13 ++++--- deepmd/jax/descriptor/__init__.py | 4 +++ deepmd/jax/descriptor/hybrid.py | 26 ++++++++++++++ doc/model/train-hybrid.md | 4 +-- .../array_api_strict/descriptor/__init__.py | 19 ++++++++++ .../descriptor/base_descriptor.py | 11 ++++++ .../tests/array_api_strict/descriptor/dpa1.py | 5 +++ .../array_api_strict/descriptor/hybrid.py | 24 +++++++++++++ .../array_api_strict/descriptor/se_e2_a.py | 5 +++ .../array_api_strict/descriptor/se_e2_r.py | 5 +++ .../consistent/descriptor/test_hybrid.py | 35 +++++++++++++++++++ 11 files changed, 144 insertions(+), 7 deletions(-) create mode 100644 deepmd/jax/descriptor/hybrid.py create mode 100644 source/tests/array_api_strict/descriptor/base_descriptor.py create mode 100644 source/tests/array_api_strict/descriptor/hybrid.py diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index 4eb14f29cf..0d89902e4a 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel.common import ( @@ -66,7 +67,7 @@ def __init__( ), f"number of atom types in {ii}th descriptor {self.descrpt_list[0].__class__.__name__} does not match others" # if hybrid sel is larger than sub sel, the nlist needs to be cut for each type hybrid_sel = self.get_sel() - self.nlist_cut_idx: list[np.ndarray] = [] + nlist_cut_idx: list[np.ndarray] = [] if self.mixed_types() and not all( descrpt.mixed_types() for descrpt in self.descrpt_list ): @@ -92,7 +93,8 @@ def __init__( cut_idx = np.concatenate( [range(ss, ee) for ss, ee in zip(start_idx, end_idx)] ) - self.nlist_cut_idx.append(cut_idx) + nlist_cut_idx.append(cut_idx) + self.nlist_cut_idx = nlist_cut_idx def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -242,6 +244,7 @@ def call( sw The smooth switch function. """ + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) out_descriptor = [] out_gr = [] out_g2 = None @@ -258,7 +261,7 @@ def call( for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx): # cut the nlist to the correct length if self.mixed_types() == descrpt.mixed_types(): - nl = nlist[:, :, nci] + nl = xp.take(nlist, nci, axis=2) else: # mixed_types is True, but descrpt.mixed_types is False assert nl_distinguish_types is not None @@ -268,8 +271,8 @@ def call( if gr is not None: out_gr.append(gr) - out_descriptor = np.concatenate(out_descriptor, axis=-1) - out_gr = np.concatenate(out_gr, axis=-2) if out_gr else None + out_descriptor = xp.concat(out_descriptor, axis=-1) + out_gr = xp.concat(out_gr, axis=-2) if out_gr else None return out_descriptor, out_gr, out_g2, out_h2, out_sw @classmethod diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py index 3ed096f9c1..cabee5a189 100644 --- a/deepmd/jax/descriptor/__init__.py +++ b/deepmd/jax/descriptor/__init__.py @@ -2,6 +2,9 @@ from deepmd.jax.descriptor.dpa1 import ( DescrptDPA1, ) +from deepmd.jax.descriptor.hybrid import ( + DescrptHybrid, +) from deepmd.jax.descriptor.se_e2_a import ( DescrptSeA, ) @@ -13,4 +16,5 @@ "DescrptSeA", "DescrptSeR", "DescrptDPA1", + "DescrptHybrid", ] diff --git a/deepmd/jax/descriptor/hybrid.py b/deepmd/jax/descriptor/hybrid.py new file mode 100644 index 0000000000..20fc5f838b --- /dev/null +++ b/deepmd/jax/descriptor/hybrid.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("hybrid") +@flax_module +class DescrptHybrid(DescrptHybridDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"nlist_cut_idx"}: + value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value] + elif name in {"descrpt_list"}: + value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value] + + return super().__setattr__(name, value) diff --git a/doc/model/train-hybrid.md b/doc/model/train-hybrid.md index 1219d208a7..da3b40487b 100644 --- a/doc/model/train-hybrid.md +++ b/doc/model/train-hybrid.md @@ -1,7 +1,7 @@ -# Descriptor `"hybrid"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +# Descriptor `"hybrid"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: This descriptor hybridizes multiple descriptors to form a new descriptor. For example, we have a list of descriptors denoted by $\mathcal D_1$, $\mathcal D_2$, ..., $\mathcal D_N$, the hybrid descriptor this the concatenation of the list, i.e. $\mathcal D = (\mathcal D_1, \mathcal D_2, \cdots, \mathcal D_N)$. diff --git a/source/tests/array_api_strict/descriptor/__init__.py b/source/tests/array_api_strict/descriptor/__init__.py index 6ceb116d85..5667fed858 100644 --- a/source/tests/array_api_strict/descriptor/__init__.py +++ b/source/tests/array_api_strict/descriptor/__init__.py @@ -1 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .dpa1 import ( + DescrptDPA1, +) +from .hybrid import ( + DescrptHybrid, +) +from .se_e2_a import ( + DescrptSeA, +) +from .se_e2_r import ( + DescrptSeR, +) + +__all__ = [ + "DescrptSeA", + "DescrptSeR", + "DescrptDPA1", + "DescrptHybrid", +] diff --git a/source/tests/array_api_strict/descriptor/base_descriptor.py b/source/tests/array_api_strict/descriptor/base_descriptor.py new file mode 100644 index 0000000000..2a31895f55 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/base_descriptor.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.make_base_descriptor import ( + make_base_descriptor, +) + +# no type annotations standard in array api +BaseDescriptor = make_base_descriptor(Any) diff --git a/source/tests/array_api_strict/descriptor/dpa1.py b/source/tests/array_api_strict/descriptor/dpa1.py index ebd688e303..d14444f269 100644 --- a/source/tests/array_api_strict/descriptor/dpa1.py +++ b/source/tests/array_api_strict/descriptor/dpa1.py @@ -27,6 +27,9 @@ from ..utils.type_embed import ( TypeEmbedNet, ) +from .base_descriptor import ( + BaseDescriptor, +) class GatedAttentionLayer(GatedAttentionLayerDP): @@ -72,6 +75,8 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseDescriptor.register("dpa1") +@BaseDescriptor.register("se_atten") class DescrptDPA1(DescrptDPA1DP): def __setattr__(self, name: str, value: Any) -> None: if name == "se_atten": diff --git a/source/tests/array_api_strict/descriptor/hybrid.py b/source/tests/array_api_strict/descriptor/hybrid.py new file mode 100644 index 0000000000..aaaa24ed6b --- /dev/null +++ b/source/tests/array_api_strict/descriptor/hybrid.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP + +from ..common import ( + to_array_api_strict_array, +) +from .base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("hybrid") +class DescrptHybrid(DescrptHybridDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"nlist_cut_idx"}: + value = [to_array_api_strict_array(vv) for vv in value] + elif name in {"descrpt_list"}: + value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value] + + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/descriptor/se_e2_a.py b/source/tests/array_api_strict/descriptor/se_e2_a.py index 654b9f8925..17da2aafbf 100644 --- a/source/tests/array_api_strict/descriptor/se_e2_a.py +++ b/source/tests/array_api_strict/descriptor/se_e2_a.py @@ -14,8 +14,13 @@ from ..utils.network import ( NetworkCollection, ) +from .base_descriptor import ( + BaseDescriptor, +) +@BaseDescriptor.register("se_e2_a") +@BaseDescriptor.register("se_a") class DescrptSeA(DescrptSeADP): def __setattr__(self, name: str, value: Any) -> None: if name in {"dstd", "davg"}: diff --git a/source/tests/array_api_strict/descriptor/se_e2_r.py b/source/tests/array_api_strict/descriptor/se_e2_r.py index 839e536cea..b499f4c4c9 100644 --- a/source/tests/array_api_strict/descriptor/se_e2_r.py +++ b/source/tests/array_api_strict/descriptor/se_e2_r.py @@ -14,8 +14,13 @@ from ..utils.network import ( NetworkCollection, ) +from .base_descriptor import ( + BaseDescriptor, +) +@BaseDescriptor.register("se_e2_r") +@BaseDescriptor.register("se_r") class DescrptSeR(DescrptSeRDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"dstd", "davg"}: diff --git a/source/tests/consistent/descriptor/test_hybrid.py b/source/tests/consistent/descriptor/test_hybrid.py index cd52eea5be..c43652b498 100644 --- a/source/tests/consistent/descriptor/test_hybrid.py +++ b/source/tests/consistent/descriptor/test_hybrid.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -28,6 +30,16 @@ from deepmd.tf.descriptor.hybrid import DescrptHybrid as DescrptHybridTF else: DescrptHybridTF = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.hybrid import DescrptHybrid as DescrptHybridJAX +else: + DescrptHybridJAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.hybrid import ( + DescrptHybrid as DescrptHybridStrict, + ) +else: + DescrptHybridStrict = None from deepmd.utils.argcheck import ( descrpt_hybrid_args, ) @@ -68,8 +80,13 @@ def data(self) -> dict: tf_class = DescrptHybridTF dp_class = DescrptHybridDP pt_class = DescrptHybridPT + jax_class = DescrptHybridJAX + array_api_strict_class = DescrptHybridStrict args = descrpt_hybrid_args() + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + def setUp(self): CommonTest.setUp(self) @@ -132,5 +149,23 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],) From 8e27d2f65a4b4064fc9b8b34c603324e1eee1872 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 09:13:15 -0400 Subject: [PATCH 03/10] feat(jax/array-api): dipole/polarizability fitting (#4278) ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced `DipoleFittingNet` and `PolarFittingNet` classes for enhanced fitting functionality. - Expanded support for JAX as a backend for fitting tensors, alongside existing TensorFlow and PyTorch support. - **Bug Fixes** - Improved error handling and parameter validation in the `DipoleFitting` and `PolarFitting` classes. - **Documentation** - Updated documentation to reflect JAX as a supported backend for fitting tensors. - **Tests** - Enhanced testing framework to support evaluations with JAX and Array API Strict, including new test methods and properties. Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/dipole_fitting.py | 10 ++- .../dpmodel/fitting/polarizability_fitting.py | 71 +++++++++++-------- deepmd/jax/fitting/__init__.py | 4 ++ deepmd/jax/fitting/fitting.py | 27 +++++++ doc/model/train-fitting-tensor.md | 4 +- .../tests/array_api_strict/fitting/fitting.py | 21 ++++++ .../tests/consistent/fitting/test_dipole.py | 41 +++++++++++ source/tests/consistent/fitting/test_polar.py | 41 +++++++++++ 8 files changed, 184 insertions(+), 35 deletions(-) diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 01bd60c777..cecba865d0 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( @@ -207,6 +208,7 @@ def call( The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam` """ + xp = array_api_compat.array_namespace(descriptor, atype) nframes, nloc, _ = descriptor.shape assert gr is not None, "Must provide the rotation matrix for dipole fitting." # (nframes, nloc, m1) @@ -214,9 +216,11 @@ def call( self.var_name ] # (nframes * nloc, 1, m1) - out = out.reshape(-1, 1, self.embedding_width) + out = xp.reshape(out, (-1, 1, self.embedding_width)) # (nframes * nloc, m1, 3) - gr = gr.reshape(nframes * nloc, -1, 3) + gr = xp.reshape(gr, (nframes * nloc, -1, 3)) # (nframes, nloc, 3) - out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3) + # out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3) + out = out @ gr + out = xp.reshape(out, (nframes, nloc, 3)) return {self.var_name: out} diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 2d96eec580..b972b45971 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.common import ( @@ -14,6 +15,9 @@ from deepmd.dpmodel import ( DEFAULT_PRECISION, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.fitting.base_fitting import ( BaseFitting, ) @@ -124,23 +128,18 @@ def __init__( self.embedding_width = embedding_width self.fit_diag = fit_diag - self.scale = scale - if self.scale is None: - self.scale = [1.0 for _ in range(ntypes)] + if scale is None: + scale = [1.0 for _ in range(ntypes)] else: - if isinstance(self.scale, list): - assert ( - len(self.scale) == ntypes - ), "Scale should be a list of length ntypes." - elif isinstance(self.scale, float): - self.scale = [self.scale for _ in range(ntypes)] + if isinstance(scale, list): + assert len(scale) == ntypes, "Scale should be a list of length ntypes." + elif isinstance(scale, float): + scale = [scale for _ in range(ntypes)] else: raise ValueError( "Scale must be a list of float of length ntypes or a float." ) - self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape( - ntypes, 1 - ) + self.scale = np.array(scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(ntypes, 1) self.shift_diag = shift_diag self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION) super().__init__( @@ -192,8 +191,8 @@ def serialize(self) -> dict: data["embedding_width"] = self.embedding_width data["fit_diag"] = self.fit_diag data["shift_diag"] = self.shift_diag - data["@variables"]["scale"] = self.scale - data["@variables"]["constant_matrix"] = self.constant_matrix + data["@variables"]["scale"] = to_numpy_array(self.scale) + data["@variables"]["constant_matrix"] = to_numpy_array(self.constant_matrix) return data @classmethod @@ -276,6 +275,7 @@ def call( The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam` """ + xp = array_api_compat.array_namespace(descriptor, atype) nframes, nloc, _ = descriptor.shape assert ( gr is not None @@ -284,28 +284,39 @@ def call( out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] - out = out * self.scale[atype] + # out = out * self.scale[atype, ...] + scale_atype = xp.reshape( + xp.take(self.scale, xp.reshape(atype, [-1]), axis=0), (*atype.shape, 1) + ) + out = out * scale_atype # (nframes * nloc, m1, 3) - gr = gr.reshape(nframes * nloc, -1, 3) + gr = xp.reshape(gr, (nframes * nloc, -1, 3)) if self.fit_diag: - out = out.reshape(-1, self.embedding_width) - out = np.einsum("ij,ijk->ijk", out, gr) + out = xp.reshape(out, (-1, self.embedding_width)) + # out = np.einsum("ij,ijk->ijk", out, gr) + out = out[:, :, None] * gr else: - out = out.reshape(-1, self.embedding_width, self.embedding_width) - out = (out + np.transpose(out, axes=(0, 2, 1))) / 2 - out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) - out = np.einsum( - "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out - ) # (nframes * nloc, 3, 3) - out = out.reshape(nframes, nloc, 3, 3) + out = xp.reshape(out, (-1, self.embedding_width, self.embedding_width)) + out = (out + xp.matrix_transpose(out)) / 2 + # out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3) + out = out @ gr + # out = np.einsum( + # "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out + # ) # (nframes * nloc, 3, 3) + out = xp.matrix_transpose(gr) @ out + out = xp.reshape(out, (nframes, nloc, 3, 3)) if self.shift_diag: - bias = self.constant_matrix[atype] + # bias = self.constant_matrix[atype] + bias = xp.reshape( + xp.take(self.constant_matrix, xp.reshape(atype, [-1]), axis=0), + (nframes, nloc), + ) # (nframes, nloc, 1) - bias = np.expand_dims(bias, axis=-1) * self.scale[atype] - eye = np.eye(3, dtype=descriptor.dtype) - eye = np.tile(eye, (nframes, nloc, 1, 1)) + bias = bias[..., None] * scale_atype + eye = xp.eye(3, dtype=descriptor.dtype) + eye = xp.tile(eye, (nframes, nloc, 1, 1)) # (nframes, nloc, 3, 3) - bias = np.expand_dims(bias, axis=-1) * eye + bias = bias[..., None] * eye out = out + bias return {"polarizability": out} diff --git a/deepmd/jax/fitting/__init__.py b/deepmd/jax/fitting/__init__.py index e72314dcab..226a6d5b43 100644 --- a/deepmd/jax/fitting/__init__.py +++ b/deepmd/jax/fitting/__init__.py @@ -1,10 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from deepmd.jax.fitting.fitting import ( + DipoleFittingNet, DOSFittingNet, EnergyFittingNet, + PolarFittingNet, ) __all__ = [ "EnergyFittingNet", "DOSFittingNet", + "DipoleFittingNet", + "PolarFittingNet", ] diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index cef1f667b3..2a6186ac46 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -3,8 +3,12 @@ Any, ) +from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.dpmodel.fitting.polarizability_fitting import ( + PolarFitting as PolarFittingNetDP, +) from deepmd.jax.common import ( ArrayAPIVariable, flax_module, @@ -53,3 +57,26 @@ class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) return super().__setattr__(name, value) + + +@BaseFitting.register("dipole") +@flax_module +class DipoleFittingNet(DipoleFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + +@BaseFitting.register("polar") +@flax_module +class PolarFittingNet(PolarFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + if name in { + "scale", + "constant_matrix", + }: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + return super().__setattr__(name, value) diff --git a/doc/model/train-fitting-tensor.md b/doc/model/train-fitting-tensor.md index c6b54c69ef..d4d546eccf 100644 --- a/doc/model/train-fitting-tensor.md +++ b/doc/model/train-fitting-tensor.md @@ -1,7 +1,7 @@ -# Fit `tensor` like `Dipole` and `Polarizability` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +# Fit `tensor` like `Dipole` and `Polarizability` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: Unlike `energy`, which is a scalar, one may want to fit some high dimensional physical quantity, like `dipole` (vector) and `polarizability` (matrix, shorted as `polar`). Deep Potential has provided different APIs to do this. In this example, we will show you how to train a model to fit a water system. A complete training input script of the examples can be found in diff --git a/source/tests/array_api_strict/fitting/fitting.py b/source/tests/array_api_strict/fitting/fitting.py index 8b65320203..5a2bd9c58f 100644 --- a/source/tests/array_api_strict/fitting/fitting.py +++ b/source/tests/array_api_strict/fitting/fitting.py @@ -3,8 +3,12 @@ Any, ) +from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.dpmodel.fitting.polarizability_fitting import ( + PolarFitting as PolarFittingNetDP, +) from ..common import ( to_array_api_strict_array, @@ -43,3 +47,20 @@ class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) return super().__setattr__(name, value) + + +class DipoleFittingNet(DipoleFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + +class PolarFittingNet(PolarFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + if name in { + "scale", + "constant_matrix", + }: + value = to_array_api_strict_array(value) + return super().__setattr__(name, value) diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 5d7be1b0e5..55d6c44c34 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -32,6 +34,21 @@ from deepmd.tf.fit.dipole import DipoleFittingSeA as DipoleFittingTF else: DipoleFittingTF = object +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import DipoleFittingNet as DipoleFittingJAX +else: + DipoleFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + DipoleFittingNet as DipoleFittingArrayAPIStrict, + ) +else: + DipoleFittingArrayAPIStrict = object from deepmd.utils.argcheck import ( fitting_dipole, ) @@ -69,7 +86,11 @@ def skip_pt(self) -> bool: tf_class = DipoleFittingTF dp_class = DipoleFittingDP pt_class = DipoleFittingPT + jax_class = DipoleFittingJAX + array_api_strict_class = DipoleFittingArrayAPIStrict args = fitting_dipole() + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) @@ -143,6 +164,26 @@ def eval_dp(self, dp_obj: Any) -> Any: None, )["dipole"] + def eval_jax(self, jax_obj: Any) -> Any: + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + jnp.asarray(self.gr), + None, + )["dipole"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + array_api_strict.asarray(self.gr), + None, + )["dipole"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 6a3465ba24..895974baf9 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -32,6 +34,21 @@ from deepmd.tf.fit.polar import PolarFittingSeA as PolarFittingTF else: PolarFittingTF = object +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import PolarFittingNet as PolarFittingJAX +else: + PolarFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + PolarFittingNet as PolarFittingArrayAPIStrict, + ) +else: + PolarFittingArrayAPIStrict = object from deepmd.utils.argcheck import ( fitting_polar, ) @@ -69,7 +86,11 @@ def skip_pt(self) -> bool: tf_class = PolarFittingTF dp_class = PolarFittingDP pt_class = PolarFittingPT + jax_class = PolarFittingJAX + array_api_strict_class = PolarFittingArrayAPIStrict args = fitting_polar() + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) @@ -143,6 +164,26 @@ def eval_dp(self, dp_obj: Any) -> Any: None, )["polarizability"] + def eval_jax(self, jax_obj: Any) -> Any: + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + jnp.asarray(self.gr), + None, + )["polarizability"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + array_api_strict.asarray(self.gr), + None, + )["polarizability"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same From cdad3122d3a872fd811dc19a5007a68f6a34d0bc Mon Sep 17 00:00:00 2001 From: Chenqqian Zhang <100290172+Chengqian-Zhang@users.noreply.github.com> Date: Thu, 31 Oct 2024 21:15:27 +0800 Subject: [PATCH 04/10] fix(dptest): Wrong dptest results except for energy head (#4280) Solve issue #4249 In `/deepmd/entrypoints/test.py` line 127 `tmap = dp.get_type_map() if isinstance(dp, DeepPot) else None`. If we use `DeepProperty` or `DeepPolar` or `DeepDOS`..... `tmap` is None. So `type_map` is not modified again according to `type_map` in `input.json`. So `atype` is wrong in the model forward process. The model prediction value is wrong. According to @njzjz , It seems that in the `r2` branch, `get_type_map()` is only available in `DeepPot`. After we refactor `DeepEval` in v3, this should not be a problem anymore. I also change the order of `type_map` in UT to ensure that when `type_map` of `input.json` doesn't match the `type_map ` of data, the result of the dptest is still correct. ## Summary by CodeRabbit - **New Features** - Introduced warning logs for unsupported `DeepGlobalPolar` model usage, recommending the `DeepPolar` model instead. - **Bug Fixes** - Simplified logic for obtaining the type map, ensuring consistent retrieval from the updated source. - Adjusted model configuration in tests to influence type interpretation. - **Documentation** - Improved clarity of comments and logging statements for better understanding. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/entrypoints/test.py | 2 +- source/tests/pt/test_dp_test.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index d9ccf392f5..fd0393c914 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -124,7 +124,7 @@ def test( log.info(f"# testing system : {system}") # create data class - tmap = dp.get_type_map() if isinstance(dp, DeepPot) else None + tmap = dp.get_type_map() data = DeepmdData( system, set_prefix="set", diff --git a/source/tests/pt/test_dp_test.py b/source/tests/pt/test_dp_test.py index c18c3286f6..0427f2b14a 100644 --- a/source/tests/pt/test_dp_test.py +++ b/source/tests/pt/test_dp_test.py @@ -152,6 +152,9 @@ def setUp(self): self.config["training"]["training_data"]["systems"] = data_file self.config["training"]["validation_data"]["systems"] = data_file self.config["model"] = deepcopy(model_property) + self.config["model"]["type_map"] = [ + self.config["model"]["type_map"][i] for i in [1, 0, 3, 2] + ] self.input_json = "test_dp_test_property.json" with open(self.input_json, "w") as fp: json.dump(self.config, fp, indent=4) From 0d13911bbcc36bab42d18d62c7a62bd3fa4a8004 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yifan=20Li=E6=9D=8E=E4=B8=80=E5=B8=86?= Date: Thu, 31 Oct 2024 09:15:49 -0400 Subject: [PATCH 05/10] Print the reminder for the illegal memory error in the AutoBatchSize under tf (#4283) #3822 added a reminder for the illegal memory error. However, this reminder is only needed for tf. This PR moves the illegal memory reminder from base class AutoBatchSize to the inherited class under tf. ## Summary by CodeRabbit - **New Features** - Enhanced `AutoBatchSize` class to initialize batch size from an environment variable, improving user guidance on memory management with TensorFlow. - **Bug Fixes** - Removed redundant logging during initialization to streamline the process when GPU resources are available. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/tf/utils/batch_size.py | 16 ++++++++++++++++ deepmd/utils/batch_size.py | 5 ----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/deepmd/tf/utils/batch_size.py b/deepmd/tf/utils/batch_size.py index 33f1ec0da0..438bf36703 100644 --- a/deepmd/tf/utils/batch_size.py +++ b/deepmd/tf/utils/batch_size.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import os + from packaging.version import ( Version, ) @@ -11,9 +13,23 @@ OutOfMemoryError, ) from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase +from deepmd.utils.batch_size import ( + log, +) class AutoBatchSize(AutoBatchSizeBase): + def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None: + super().__init__(initial_batch_size, factor) + DP_INFER_BATCH_SIZE = int(os.environ.get("DP_INFER_BATCH_SIZE", 0)) + if not DP_INFER_BATCH_SIZE > 0: + if self.is_gpu_available(): + log.info( + "If you encounter the error 'an illegal memory access was encountered', this may be due to a TensorFlow issue. " + "To avoid this, set the environment variable DP_INFER_BATCH_SIZE to a smaller value than the last adjusted batch size. " + "The environment variable DP_INFER_BATCH_SIZE controls the inference batch size (nframes * natoms). " + ) + def is_gpu_available(self) -> bool: """Check if GPU is available. diff --git a/deepmd/utils/batch_size.py b/deepmd/utils/batch_size.py index 259fe93bdb..5ab06e55e2 100644 --- a/deepmd/utils/batch_size.py +++ b/deepmd/utils/batch_size.py @@ -61,11 +61,6 @@ def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None: self.maximum_working_batch_size = initial_batch_size if self.is_gpu_available(): self.minimal_not_working_batch_size = 2**31 - log.info( - "If you encounter the error 'an illegal memory access was encountered', this may be due to a TensorFlow issue. " - "To avoid this, set the environment variable DP_INFER_BATCH_SIZE to a smaller value than the last adjusted batch size. " - "The environment variable DP_INFER_BATCH_SIZE controls the inference batch size (nframes * natoms). " - ) else: self.minimal_not_working_batch_size = ( self.maximum_working_batch_size + 1 From 9c767adfba8a56c92501591199db60af11d53f94 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 09:19:12 -0400 Subject: [PATCH 06/10] feat(dpmodel/jax): add entry point for dpmodel and jax backend (#4284) ## Summary by CodeRabbit - **New Features** - Introduced entry point loading functionality for enhanced module initialization in both `dpmodel` and `jax` components of the DeepMD framework. These changes improve the framework's functionality and streamline backend configuration. Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/__init__.py | 7 +++++++ deepmd/jax/__init__.py | 6 ++++++ 2 files changed, 13 insertions(+) diff --git a/deepmd/dpmodel/__init__.py b/deepmd/dpmodel/__init__.py index 6f83f849a3..111c2d6ced 100644 --- a/deepmd/dpmodel/__init__.py +++ b/deepmd/dpmodel/__init__.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.utils.entry_point import ( + load_entry_point, +) + from .common import ( DEFAULT_PRECISION, PRECISION_DICT, @@ -32,3 +36,6 @@ "get_deriv_name", "get_hessian_name", ] + + +load_entry_point("deepmd.dpmodel") diff --git a/deepmd/jax/__init__.py b/deepmd/jax/__init__.py index 2ff078e797..bb5c0a5206 100644 --- a/deepmd/jax/__init__.py +++ b/deepmd/jax/__init__.py @@ -1,2 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """JAX backend.""" + +from deepmd.utils.entry_point import ( + load_entry_point, +) + +load_entry_point("deepmd.jax") From ff04d8bdafa0985c83a9c2418b91466db63f0bc5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 14:56:28 -0400 Subject: [PATCH 07/10] fix(dpmodel/jax): fix fparam and aparam support in DeepEval (#4285) ## Summary by CodeRabbit - **New Features** - Enhanced error messages for improved clarity when input dimensions are incorrect. - Added support for optional fitting and atomic parameters in model evaluations. - **Bug Fixes** - Removed restrictions on providing fitting and atomic parameters, allowing for more flexible evaluations. - **Tests** - Introduced a new test class to validate the handling of fitting and atomic parameters in model evaluations. --------- Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/general_fitting.py | 8 ++-- deepmd/dpmodel/infer/deep_eval.py | 21 +++++++-- deepmd/jax/infer/deep_eval.py | 16 +++++-- deepmd/jax/utils/serialization.py | 8 ++-- source/tests/consistent/io/test_io.py | 56 +++++++++++++++++++++++ 5 files changed, 93 insertions(+), 16 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index e55f57c774..a027e1e59d 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -388,8 +388,8 @@ def _call_common( assert fparam is not None, "fparam should not be None" if fparam.shape[-1] != self.numb_fparam: raise ValueError( - "get an input fparam of dim {fparam.shape[-1]}, ", - "which is not consistent with {self.numb_fparam}.", + f"get an input fparam of dim {fparam.shape[-1]}, " + f"which is not consistent with {self.numb_fparam}." ) fparam = (fparam - self.fparam_avg) * self.fparam_inv_std fparam = xp.tile( @@ -409,8 +409,8 @@ def _call_common( assert aparam is not None, "aparam should not be None" if aparam.shape[-1] != self.numb_aparam: raise ValueError( - "get an input aparam of dim {aparam.shape[-1]}, ", - "which is not consistent with {self.numb_aparam}.", + f"get an input aparam of dim {aparam.shape[-1]}, " + f"which is not consistent with {self.numb_aparam}." ) aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam]) aparam = (aparam - self.aparam_avg) * self.aparam_inv_std diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index c1f3e4630b..5463743ada 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -204,8 +204,6 @@ def eval( The output of the evaluation. The keys are the names of the output variables, and the values are the corresponding output arrays. """ - if fparam is not None or aparam is not None: - raise NotImplementedError # convert all of the input to numpy array atom_types = np.array(atom_types, dtype=np.int32) coords = np.array(coords) @@ -216,7 +214,7 @@ def eval( ) request_defs = self._get_request_defs(atomic) out = self._eval_func(self._eval_model, numb_test, natoms)( - coords, cells, atom_types, request_defs + coords, cells, atom_types, fparam, aparam, request_defs ) return dict( zip( @@ -306,6 +304,8 @@ def _eval_model( coords: np.ndarray, cells: Optional[np.ndarray], atom_types: np.ndarray, + fparam: Optional[np.ndarray], + aparam: Optional[np.ndarray], request_defs: list[OutputVariableDef], ): model = self.dp @@ -323,12 +323,25 @@ def _eval_model( box_input = cells.reshape([-1, 3, 3]) else: box_input = None + if fparam is not None: + fparam_input = fparam.reshape(nframes, self.get_dim_fparam()) + else: + fparam_input = None + if aparam is not None: + aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam()) + else: + aparam_input = None do_atomic_virial = any( x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs ) batch_output = model( - coord_input, type_input, box=box_input, do_atomic_virial=do_atomic_virial + coord_input, + type_input, + box=box_input, + fparam=fparam_input, + aparam=aparam_input, + do_atomic_virial=do_atomic_virial, ) if isinstance(batch_output, tuple): batch_output = batch_output[0] diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 76f044a327..c1967fb0da 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -214,8 +214,6 @@ def eval( The output of the evaluation. The keys are the names of the output variables, and the values are the corresponding output arrays. """ - if fparam is not None or aparam is not None: - raise NotImplementedError # convert all of the input to numpy array atom_types = np.array(atom_types, dtype=np.int32) coords = np.array(coords) @@ -226,7 +224,7 @@ def eval( ) request_defs = self._get_request_defs(atomic) out = self._eval_func(self._eval_model, numb_test, natoms)( - coords, cells, atom_types, request_defs + coords, cells, atom_types, fparam, aparam, request_defs ) return dict( zip( @@ -316,6 +314,8 @@ def _eval_model( coords: np.ndarray, cells: Optional[np.ndarray], atom_types: np.ndarray, + fparam: Optional[np.ndarray], + aparam: Optional[np.ndarray], request_defs: list[OutputVariableDef], ): model = self.dp @@ -333,6 +333,14 @@ def _eval_model( box_input = cells.reshape([-1, 3, 3]) else: box_input = None + if fparam is not None: + fparam_input = fparam.reshape(nframes, self.get_dim_fparam()) + else: + fparam_input = None + if aparam is not None: + aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam()) + else: + aparam_input = None do_atomic_virial = any( x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs @@ -341,6 +349,8 @@ def _eval_model( to_jax_array(coord_input), to_jax_array(type_input), box=to_jax_array(box_input), + fparam=to_jax_array(fparam_input), + aparam=to_jax_array(aparam_input), do_atomic_virial=do_atomic_virial, ) if isinstance(batch_output, tuple): diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index fcfcc8a610..a7d57523e2 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -51,18 +51,16 @@ def deserialize_to_file(model_file: str, data: dict) -> None: model_def_script = data["model_def_script"] call_lower = model.call_lower - nf, nloc, nghost, nfp, nap = jax_export.symbolic_shape( - "nf, nloc, nghost, nfp, nap" - ) + nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost") exported = jax_export.export(jax.jit(call_lower))( jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping - jax.ShapeDtypeStruct((nf, nfp), jnp.float64) + jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64) if model.get_dim_fparam() else None, # fparam - jax.ShapeDtypeStruct((nf, nap), jnp.float64) + jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64) if model.get_dim_aparam() else None, # aparam False, # do_atomic_virial diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index dc0f280d56..af26c41694 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -136,6 +136,8 @@ def test_deep_eval(self): [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], dtype=GLOBAL_NP_FLOAT_PRECISION, ).reshape(1, 9) + natoms = self.atype.shape[1] + nframes = self.atype.shape[0] prefix = "test_consistent_io_" + self.__class__.__name__.lower() rets = [] for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"): @@ -145,10 +147,20 @@ def test_deep_eval(self): reference_data = copy.deepcopy(self.data) self.save_data_to_model(prefix + backend.suffixes[0], reference_data) deep_eval = DeepEval(prefix + backend.suffixes[0]) + if deep_eval.get_dim_fparam() > 0: + fparam = np.ones((nframes, deep_eval.get_dim_fparam())) + else: + fparam = None + if deep_eval.get_dim_aparam() > 0: + aparam = np.ones((nframes, natoms, deep_eval.get_dim_aparam())) + else: + aparam = None ret = deep_eval.eval( self.coords, self.box, self.atype, + fparam=fparam, + aparam=aparam, ) rets.append(ret) for ret in rets[1:]: @@ -199,3 +211,47 @@ def setUp(self): def tearDown(self): IOTest.tearDown(self) + + +class TestDeepPotFparamAparam(unittest.TestCase, IOTest): + def setUp(self): + model_def_script = { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 3, + 6, + ], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "type": "ener", + "neuron": [ + 5, + 5, + ], + "resnet_dt": True, + "precision": "float64", + "atom_ener": [], + "seed": 1, + "numb_fparam": 2, + "numb_aparam": 2, + }, + } + model = get_model(copy.deepcopy(model_def_script)) + self.data = { + "model": model.serialize(), + "backend": "test", + "model_def_script": model_def_script, + } + + def tearDown(self): + IOTest.tearDown(self) From 704db2ff84188424dac86d274e473935854b8524 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 20:20:50 -0400 Subject: [PATCH 08/10] fix(lmp): add pair_deepmd_index arg to fix dplr for multiple deepmd pairs (#4274) Fix #4273. ## Summary by CodeRabbit - **New Features** - Introduced a new optional keyword `pair_deepmd_index` in the `fix dplr` command for enhanced control in simulations. - Updated documentation with clearer instructions and examples for the DPLR model, including training process and simulation setup. - **Bug Fixes** - Improved error handling related to the new `pair_deepmd_index` parameter to ensure proper usage. - **Documentation** - Enhanced descriptions and usability of the DPLR model documentation. Signed-off-by: Jinzhe Zeng --- doc/model/dplr.md | 6 +++++- source/lmp/fix_dplr.cpp | 9 ++++++++- source/lmp/fix_dplr.h | 3 +++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/doc/model/dplr.md b/doc/model/dplr.md index 91c2251346..cf071d4029 100644 --- a/doc/model/dplr.md +++ b/doc/model/dplr.md @@ -198,7 +198,7 @@ fix ID group-ID style_name keyword value ... - three or more keyword/value pairs may be appended ``` -keyword = *model* or *type_associate* or *bond_type* or *efield* +keyword = *model* or *type_associate* or *bond_type* or *efield* or *pair_deepmd_index* *model* value = name name = name of DPLR model file (e.g. frozen_model.pb) (not DW model) *type_associate* values = NR1 NW1 NR2 NW2 ... @@ -208,6 +208,8 @@ keyword = *model* or *type_associate* or *bond_type* or *efield* NBi = bond type of i-th (real atom, Wannier centroid) pair *efield* (optional) values = Ex Ey Ez Ex/Ey/Ez = electric field along x/y/z direction + *pair_deepmd_index* (optional) values = idx + idx = The index of pair_style deepmd, starting from 1, if more than one is used ``` **Examples** @@ -223,6 +225,8 @@ fix_modify 0 virial yes ``` The fix command `dplr` calculates the position of WCs by the DW model and back-propagates the long-range interaction on virtual atoms to real toms. +The fix command must be used after [pair_style `deepmd`](../third-party/lammps-command.md#pair_style-deepmd). +If there are more than 1 pair_style `deepmd`, `pair_deepmd_index` (starting from 1) must be set to assign the index of the pair_style `deepmd`. The atom names specified in [pair_style `deepmd`](../third-party/lammps-command.md#pair_style-deepmd) will be used to determine elements. If it is not set, the training parameter {ref}`type_map ` will be mapped to LAMMPS atom types. diff --git a/source/lmp/fix_dplr.cpp b/source/lmp/fix_dplr.cpp index 8a6be7d840..34fd2515ed 100644 --- a/source/lmp/fix_dplr.cpp +++ b/source/lmp/fix_dplr.cpp @@ -62,6 +62,7 @@ FixDPLR::FixDPLR(LAMMPS *lmp, int narg, char **arg) size_vector = 3; qe2f = force->qe2f; xstyle = ystyle = zstyle = NONE; + pair_deepmd_index = 0; if (strcmp(update->unit_style, "lj") == 0) { error->all(FLERR, @@ -125,6 +126,12 @@ FixDPLR::FixDPLR(LAMMPS *lmp, int narg, char **arg) } sort(bond_type.begin(), bond_type.end()); iarg = iend; + } else if (string(arg[iarg]) == string("pair_deepmd_index")) { + if (iarg + 1 >= narg) { + error->all(FLERR, "Illegal pair_deepmd_index, not provided"); + } + pair_deepmd_index = atoi(arg[iarg + 1]); + iarg += 2; } else { break; } @@ -141,7 +148,7 @@ FixDPLR::FixDPLR(LAMMPS *lmp, int narg, char **arg) error->one(FLERR, e.what()); } - pair_deepmd = (PairDeepMD *)force->pair_match("deepmd", 1); + pair_deepmd = (PairDeepMD *)force->pair_match("deepmd", 1, pair_deepmd_index); if (!pair_deepmd) { error->all(FLERR, "pair_style deepmd should be set before this fix\n"); } diff --git a/source/lmp/fix_dplr.h b/source/lmp/fix_dplr.h index a6822fe4fe..c43296e611 100644 --- a/source/lmp/fix_dplr.h +++ b/source/lmp/fix_dplr.h @@ -80,6 +80,9 @@ class FixDPLR : public Fix { void update_efield_variables(); enum { NONE, CONSTANT, EQUAL }; std::vector type_idx_map; + /* The index of deepmd pair index, which starts from 1. By default 0, which + * works only when there is one deepmd pair. */ + int pair_deepmd_index; }; } // namespace LAMMPS_NS From a4688194e9c42ff285df877fe5fcbff384f5a470 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 22:09:42 -0400 Subject: [PATCH 09/10] feat(jax/array-api): property fitting (#4287) ## Summary by CodeRabbit - **New Features** - Introduced the `PropertyFittingNet` class for enhanced property-specific fitting operations. - Enhanced testing framework to support additional computational backends (JAX and Array API Strict). - **Bug Fixes** - Improved handling of attribute assignments in property fitting. - **Tests** - Added new methods and properties to the testing suite for evaluating property fitting with JAX and Array API Strict. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/jax/fitting/fitting.py | 11 ++++ .../tests/array_api_strict/fitting/fitting.py | 9 +++ .../tests/consistent/fitting/test_property.py | 62 +++++++++++++++++++ 3 files changed, 82 insertions(+) diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index 2a6186ac46..d62681490c 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -9,6 +9,9 @@ from deepmd.dpmodel.fitting.polarizability_fitting import ( PolarFitting as PolarFittingNetDP, ) +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from deepmd.jax.common import ( ArrayAPIVariable, flax_module, @@ -51,6 +54,14 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseFitting.register("property") +@flax_module +class PropertyFittingNet(PropertyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + @BaseFitting.register("dos") @flax_module class DOSFittingNet(DOSFittingNetDP): diff --git a/source/tests/array_api_strict/fitting/fitting.py b/source/tests/array_api_strict/fitting/fitting.py index 5a2bd9c58f..323a49cfe8 100644 --- a/source/tests/array_api_strict/fitting/fitting.py +++ b/source/tests/array_api_strict/fitting/fitting.py @@ -9,6 +9,9 @@ from deepmd.dpmodel.fitting.polarizability_fitting import ( PolarFitting as PolarFittingNetDP, ) +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from ..common import ( to_array_api_strict_array, @@ -43,6 +46,12 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +class PropertyFittingNet(PropertyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index beb21d9c04..4e0fe04f9f 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -17,6 +17,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, CommonTest, parameterized, @@ -32,6 +34,22 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: PropertyFittingPT = object +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import PropertyFittingNet as PropertyFittingJAX +else: + PropertyFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + PropertyFittingNet as PropertyFittingStrict, + ) +else: + PropertyFittingStrict = object + PropertyFittingTF = object @@ -84,9 +102,14 @@ def skip_pt(self) -> bool: def skip_tf(self) -> bool: return True + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + tf_class = PropertyFittingTF dp_class = PropertyFittingDP pt_class = PropertyFittingPT + jax_class = PropertyFittingJAX + array_api_strict_class = PropertyFittingStrict args = fitting_property() def setUp(self): @@ -183,6 +206,45 @@ def eval_dp(self, dp_obj: Any) -> Any: aparam=self.aparam if numb_aparam else None, )["property"] + def eval_jax(self, jax_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + task_dim, + intensive, + ) = self.param + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + fparam=jnp.asarray(self.fparam) if numb_fparam else None, + aparam=jnp.asarray(self.aparam) if numb_aparam else None, + )["property"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + array_api_strict.set_array_api_strict_flags(api_version="2023.12") + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + task_dim, + intensive, + ) = self.param + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None, + aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None, + )["property"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same From 5c321470794245f8d235841dee689107b2c7b593 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Fri, 1 Nov 2024 12:26:13 +0800 Subject: [PATCH 10/10] Feat: Add consistency test for ZBL between dp and pt (#4292) ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced `DPZBLModel`, enhancing energy modeling capabilities. - Added `get_zbl_model` function for creating `DPZBLModel` from input data. - New `DPZBLLinearEnergyAtomicModel` class allows for complex interactions between atomic models. - **Bug Fixes** - Corrected typographical errors in multiple test classes to improve code clarity and consistency in method names. - Updated model type attributes for `DPZBLModel` and `LinearEnergyModel` to reflect accurate classifications. - **Tests** - Added comprehensive unit tests for energy models to ensure functionality across various backends. - Enhanced existing test classes with corrected method names for improved accuracy. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../atomic_model/linear_atomic_model.py | 2 + deepmd/dpmodel/model/dp_zbl_model.py | 66 ++++++ deepmd/dpmodel/model/model.py | 53 +++++ deepmd/pt/model/model/dp_linear_model.py | 2 +- deepmd/pt/model/model/dp_zbl_model.py | 2 +- source/tests/consistent/common.py | 4 +- .../tests/consistent/fitting/test_dipole.py | 2 +- source/tests/consistent/fitting/test_dos.py | 2 +- source/tests/consistent/fitting/test_ener.py | 2 +- source/tests/consistent/fitting/test_polar.py | 2 +- .../tests/consistent/fitting/test_property.py | 2 +- source/tests/consistent/model/test_ener.py | 2 +- .../tests/consistent/model/test_zbl_ener.py | 224 ++++++++++++++++++ .../tests/consistent/test_type_embedding.py | 2 +- 14 files changed, 356 insertions(+), 11 deletions(-) create mode 100644 deepmd/dpmodel/model/dp_zbl_model.py create mode 100644 source/tests/consistent/model/test_zbl_ener.py diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 5d86472674..224fdd145c 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -34,6 +34,7 @@ ) +@BaseAtomicModel.register("linear") class LinearEnergyAtomicModel(BaseAtomicModel): """Linear model make linear combinations of several existing models. @@ -324,6 +325,7 @@ def is_aparam_nall(self) -> bool: return False +@BaseAtomicModel.register("zbl") class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel): """Model linearly combine a list of AtomicModels. diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py new file mode 100644 index 0000000000..ba19785235 --- /dev/null +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + +from deepmd.dpmodel.atomic_model.linear_atomic_model import ( + DPZBLLinearEnergyAtomicModel, +) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) +from deepmd.dpmodel.model.dp_model import ( + DPModelCommon, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) + +from .make_model import ( + make_model, +) + +DPZBLModel_ = make_model(DPZBLLinearEnergyAtomicModel) + + +@BaseModel.register("zbl") +class DPZBLModel(DPZBLModel_): + model_type = "zbl" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel( + train_data, type_map, local_jdata["dpmodel"] + ) + return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index cccd0732cd..c29240214c 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -1,4 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel, +) +from deepmd.dpmodel.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.dpmodel.descriptor.se_e2_a import ( DescrptSeA, ) @@ -8,6 +17,9 @@ from deepmd.dpmodel.model.base_model import ( BaseModel, ) +from deepmd.dpmodel.model.dp_zbl_model import ( + DPZBLModel, +) from deepmd.dpmodel.model.ener_model import ( EnergyModel, ) @@ -55,6 +67,45 @@ def get_standard_model(data: dict) -> EnergyModel: ) +def get_zbl_model(data: dict) -> DPZBLModel: + data["descriptor"]["ntypes"] = len(data["type_map"]) + descriptor = BaseDescriptor(**data["descriptor"]) + fitting_type = data["fitting_net"].pop("type") + if fitting_type == "ener": + fitting = EnergyFittingNet( + ntypes=descriptor.get_ntypes(), + dim_descrpt=descriptor.get_dim_out(), + mixed_types=descriptor.mixed_types(), + **data["fitting_net"], + ) + else: + raise ValueError(f"Unknown fitting type {fitting_type}") + + dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"]) + # pairtab + filepath = data["use_srtab"] + pt_model = PairTabAtomicModel( + filepath, + data["descriptor"]["rcut"], + data["descriptor"]["sel"], + type_map=data["type_map"], + ) + + rmin = data["sw_rmin"] + rmax = data["sw_rmax"] + atom_exclude_types = data.get("atom_exclude_types", []) + pair_exclude_types = data.get("pair_exclude_types", []) + return DPZBLModel( + dp_model, + pt_model, + rmin, + rmax, + type_map=data["type_map"], + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + ) + + def get_spin_model(data: dict) -> SpinModel: """Get a spin model from a dictionary. @@ -100,6 +151,8 @@ def get_model(data: dict): if model_type == "standard": if "spin" in data: return get_spin_model(data) + elif "use_srtab" in data: + return get_zbl_model(data) else: return get_standard_model(data) else: diff --git a/deepmd/pt/model/model/dp_linear_model.py b/deepmd/pt/model/model/dp_linear_model.py index d19070fc5b..4028d77228 100644 --- a/deepmd/pt/model/model/dp_linear_model.py +++ b/deepmd/pt/model/model/dp_linear_model.py @@ -30,7 +30,7 @@ @BaseModel.register("linear_ener") class LinearEnergyModel(DPLinearModel_): - model_type = "ener" + model_type = "linear_ener" def __init__( self, diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index e1ef00f5fe..0f05e3e56d 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -30,7 +30,7 @@ @BaseModel.register("zbl") class DPZBLModel(DPZBLModel_): - model_type = "ener" + model_type = "zbl" def __init__( self, diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index bcad7c4502..734486becb 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -75,7 +75,7 @@ class CommonTest(ABC): data: ClassVar[dict] """Arguments data.""" - addtional_data: ClassVar[dict] = {} + additional_data: ClassVar[dict] = {} """Additional data that will not be checked.""" tf_class: ClassVar[Optional[type]] """TensorFlow model class.""" @@ -128,7 +128,7 @@ def init_backend_cls(self, cls) -> Any: def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" - return cls(**data, **self.addtional_data) + return cls(**data, **self.additional_data) @abstractmethod def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 55d6c44c34..60ee7322c1 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -104,7 +104,7 @@ def setUp(self): self.atype.sort() @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index 774e3f655e..d3de3ef151 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -124,7 +124,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index e32410a0ec..f4e78ce966 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -134,7 +134,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 895974baf9..bd9d013b8d 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -104,7 +104,7 @@ def setUp(self): self.atype.sort() @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index 4e0fe04f9f..a096d4dd68 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -127,7 +127,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 2a358ba7e0..98330ba849 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -130,7 +130,7 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_pt(data) elif cls is EnergyModelJAX: return get_model_jax(data) - return cls(**data, **self.addtional_data) + return cls(**data, **self.additional_data) def setUp(self): CommonTest.setUp(self) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py new file mode 100644 index 0000000000..f37bee0c90 --- /dev/null +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + SKIP_FLAG, + CommonTest, + parameterized, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.dp_zbl_model import DPZBLModel as DPZBLModelPT +else: + DPZBLModelPT = None +import os + +from deepmd.utils.argcheck import ( + model_args, +) + +TESTS_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + + +@parameterized( + ( + [], + [[0, 1]], + ), + ( + [], + [1], + ), +) +class TestEner(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + pair_exclude_types, atom_exclude_types = self.param + return { + "type_map": ["O", "H", "B"], + "use_srtab": f"{TESTS_DIR}/pt/water/data/zbl_tab_potential/H2O_tab_potential.txt", + "smin_alpha": 0.1, + "sw_rmin": 0.2, + "sw_rmax": 4.0, + "pair_exclude_types": pair_exclude_types, + "atom_exclude_types": atom_exclude_types, + "descriptor": { + "type": "se_atten", + "sel": 40, + "rcut_smth": 0.5, + "rcut": 4.0, + "neuron": [3, 6], + "axis_neuron": 2, + "attn": 8, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": False, + "temperature": 1.0, + "set_davg_zero": True, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [5, 5], + "resnet_dt": True, + "seed": 1, + }, + } + + dp_class = DPZBLModelDP + pt_class = DPZBLModelPT + args = model_args() + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_tf: + return self.RefBackend.TF + if not self.skip_jax: + return self.RefBackend.JAX + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + @property + def skip_tf(self): + return True + + @property + def skip_jax(self): + return True + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = data.copy() + if cls is DPZBLModelDP: + return get_model_dp(data) + elif cls is DPZBLModelPT: + return get_model_pt(data) + return cls(**data, **self.additional_data) + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + # TF requires the atype to be sort + idx_map = np.argsort(self.atype.ravel()) + self.atype = self.atype[:, idx_map] + self.coords = self.coords[:, idx_map] + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + return self.build_tf_model( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_model( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_model( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... + if backend is self.RefBackend.DP: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + SKIP_FLAG, + SKIP_FLAG, + ) + elif backend is self.RefBackend.PT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["force"].ravel(), + ret["virial"].ravel(), + ) + elif backend is self.RefBackend.TF: + return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel()) + elif backend is self.RefBackend.JAX: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + ret["energy_derv_r"].ravel(), + ret["energy_derv_c_redu"].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index a4b516ef16..0dd17c841e 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -82,7 +82,7 @@ def data(self) -> dict: skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision,